/// Running tallies of estimated compute operations and memory traffic.
///
/// All counters are plain `u64` values starting from zero.
// `Default` produces the same all-zero value as `OpCounter::new()`; deriving it
// satisfies clippy's `new_without_default` lint and lets callers use
// `..Default::default()` / `mem::take`. `Eq` is sound because every field is `u64`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct OpCounter {
    /// Operations attributed to matrix multiplications.
    pub matmul_ops: u64,
    /// Operations attributed to attention.
    pub attention_ops: u64,
    /// Operations attributed to activations.
    pub activation_ops: u64,
    /// Operations attributed to normalization.
    pub normalization_ops: u64,
    /// Bytes read from memory.
    pub memory_read_bytes: u64,
    /// Bytes written to memory.
    pub memory_write_bytes: u64,
}
10
11impl OpCounter {
12 pub const fn new() -> Self {
13 Self {
14 matmul_ops: 0,
15 attention_ops: 0,
16 activation_ops: 0,
17 normalization_ops: 0,
18 memory_read_bytes: 0,
19 memory_write_bytes: 0,
20 }
21 }
22
23 pub fn add_matmul(&mut self, m: usize, n: usize, k: usize) {
24 let ops = (m as u64)
25 .saturating_mul(n as u64)
26 .saturating_mul(k as u64)
27 .saturating_mul(2);
28 self.matmul_ops = self.matmul_ops.saturating_add(ops);
29 }
30
31 pub fn add_attention(&mut self, q: usize, k: usize, d: usize) {
32 let ops = (q as u64)
33 .saturating_mul(k as u64)
34 .saturating_mul(d as u64)
35 .saturating_mul(2);
36 self.attention_ops = self.attention_ops.saturating_add(ops);
37 }
38
39 pub fn add_activation(&mut self, elems: usize) {
40 self.activation_ops = self.activation_ops.saturating_add(elems as u64);
41 }
42
43 pub fn add_normalization(&mut self, elems: usize) {
44 let ops = (elems as u64).saturating_mul(6);
45 self.normalization_ops = self.normalization_ops.saturating_add(ops);
46 }
47
48 pub fn add_memory_read(&mut self, bytes: usize) {
49 self.memory_read_bytes = self.memory_read_bytes.saturating_add(bytes as u64);
50 }
51
52 pub fn add_memory_write(&mut self, bytes: usize) {
53 self.memory_write_bytes = self.memory_write_bytes.saturating_add(bytes as u64);
54 }
55
56 pub fn merge(&mut self, other: &OpCounter) {
57 self.matmul_ops = self.matmul_ops.saturating_add(other.matmul_ops);
58 self.attention_ops = self.attention_ops.saturating_add(other.attention_ops);
59 self.activation_ops = self.activation_ops.saturating_add(other.activation_ops);
60 self.normalization_ops = self.normalization_ops.saturating_add(other.normalization_ops);
61 self.memory_read_bytes = self.memory_read_bytes.saturating_add(other.memory_read_bytes);
62 self.memory_write_bytes = self.memory_write_bytes.saturating_add(other.memory_write_bytes);
63 }
64
65 pub fn reset(&mut self) {
66 *self = Self::new();
67 }
68
69 pub fn total_ops(&self) -> u64 {
70 self.matmul_ops
71 .saturating_add(self.attention_ops)
72 .saturating_add(self.activation_ops)
73 .saturating_add(self.normalization_ops)
74 }
75
76 pub fn total_memory_bytes(&self) -> u64 {
77 self.memory_read_bytes.saturating_add(self.memory_write_bytes)
78 }
79}