Skip to main content

rnn/profiler/
profiler.rs

/// Accumulates operation counts and memory-traffic byte counts for profiling.
///
/// Every counter is a `u64`, and every update saturates at `u64::MAX` instead of
/// wrapping or panicking. A counter that has overflowed therefore reads as `u64::MAX`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct OpCounter {
    /// Operations recorded by [`OpCounter::add_matmul`] (`2 * m * n * k` per call).
    pub matmul_ops: u64,
    /// Operations recorded by [`OpCounter::add_attention`] (`2 * q * k * d` per call).
    pub attention_ops: u64,
    /// Operations recorded by [`OpCounter::add_activation`] (one per element).
    pub activation_ops: u64,
    /// Operations recorded by [`OpCounter::add_normalization`] (six per element).
    pub normalization_ops: u64,
    /// Bytes recorded by [`OpCounter::add_memory_read`].
    pub memory_read_bytes: u64,
    /// Bytes recorded by [`OpCounter::add_memory_write`].
    pub memory_write_bytes: u64,
}

impl OpCounter {
    /// Returns a counter with every field set to zero.
    ///
    /// Equivalent to [`OpCounter::default`], but also usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            matmul_ops: 0,
            attention_ops: 0,
            activation_ops: 0,
            normalization_ops: 0,
            memory_read_bytes: 0,
            memory_write_bytes: 0,
        }
    }

    /// Converts a `usize` count to `u64`.
    ///
    /// The conversion saturates at `u64::MAX` on targets where `usize` is wider than
    /// 64 bits, which keeps it consistent with the saturating arithmetic used everywhere else.
    fn widen(x: usize) -> u64 {
        u64::try_from(x).unwrap_or(u64::MAX)
    }

    /// Adds `2 * m * n * k` to [`matmul_ops`](Self::matmul_ops).
    ///
    /// The factor of 2 presumably counts each multiply-accumulate as two operations,
    /// with `k` as the shared inner dimension. Confirm this against call sites.
    pub fn add_matmul(&mut self, m: usize, n: usize, k: usize) {
        let ops = Self::widen(m)
            .saturating_mul(Self::widen(n))
            .saturating_mul(Self::widen(k))
            .saturating_mul(2);
        self.matmul_ops = self.matmul_ops.saturating_add(ops);
    }

    /// Adds `2 * q * k * d` to [`attention_ops`](Self::attention_ops).
    ///
    /// NOTE(review): this matches the cost of a single product with inner dimension `d`
    /// (for example, QKᵀ). If callers expect it to also cover the attention-weights × V
    /// product, it undercounts. Verify how callers use it.
    pub fn add_attention(&mut self, q: usize, k: usize, d: usize) {
        let ops = Self::widen(q)
            .saturating_mul(Self::widen(k))
            .saturating_mul(Self::widen(d))
            .saturating_mul(2);
        self.attention_ops = self.attention_ops.saturating_add(ops);
    }

    /// Adds one operation per element to [`activation_ops`](Self::activation_ops).
    pub fn add_activation(&mut self, elems: usize) {
        self.activation_ops = self.activation_ops.saturating_add(Self::widen(elems));
    }

    /// Adds six operations per element to [`normalization_ops`](Self::normalization_ops).
    ///
    /// The per-element cost of 6 is a fixed estimate. How it breaks down into individual
    /// operations is not recorded here.
    pub fn add_normalization(&mut self, elems: usize) {
        let ops = Self::widen(elems).saturating_mul(6);
        self.normalization_ops = self.normalization_ops.saturating_add(ops);
    }

    /// Adds `bytes` to [`memory_read_bytes`](Self::memory_read_bytes).
    pub fn add_memory_read(&mut self, bytes: usize) {
        self.memory_read_bytes = self.memory_read_bytes.saturating_add(Self::widen(bytes));
    }

    /// Adds `bytes` to [`memory_write_bytes`](Self::memory_write_bytes).
    pub fn add_memory_write(&mut self, bytes: usize) {
        self.memory_write_bytes = self.memory_write_bytes.saturating_add(Self::widen(bytes));
    }

    /// Adds every counter of `other` into `self`, saturating per field.
    pub fn merge(&mut self, other: &OpCounter) {
        // Destructure exhaustively so that adding a field to `OpCounter` is a compile
        // error here until `merge` handles it as well.
        let OpCounter {
            matmul_ops,
            attention_ops,
            activation_ops,
            normalization_ops,
            memory_read_bytes,
            memory_write_bytes,
        } = *other;
        self.matmul_ops = self.matmul_ops.saturating_add(matmul_ops);
        self.attention_ops = self.attention_ops.saturating_add(attention_ops);
        self.activation_ops = self.activation_ops.saturating_add(activation_ops);
        self.normalization_ops = self.normalization_ops.saturating_add(normalization_ops);
        self.memory_read_bytes = self.memory_read_bytes.saturating_add(memory_read_bytes);
        self.memory_write_bytes = self.memory_write_bytes.saturating_add(memory_write_bytes);
    }

    /// Sets every counter back to zero.
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Returns the saturating sum of all operation counters.
    ///
    /// The memory byte counters are not included.
    #[must_use]
    pub fn total_ops(&self) -> u64 {
        self.matmul_ops
            .saturating_add(self.attention_ops)
            .saturating_add(self.activation_ops)
            .saturating_add(self.normalization_ops)
    }

    /// Returns the saturating sum of bytes read and bytes written.
    #[must_use]
    pub fn total_memory_bytes(&self) -> u64 {
        self.memory_read_bytes.saturating_add(self.memory_write_bytes)
    }
}