Skip to main content

rnn/runtime/
profile.rs

/// Batch geometry plus per-element storage widths for one runtime configuration.
///
/// All fields are public and are not checked on construction; use
/// [`RuntimeProfile::validate`] to confirm that none of them is zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RuntimeProfile {
    /// Number of sequences in a batch.
    pub batch_size: usize,
    /// Length of each sequence; multiplied by `batch_size` to obtain the token count.
    pub sequence_len: usize,
    /// Storage width, in bytes, of a single parameter.
    pub bytes_per_param: usize,
    /// Storage width, in bytes, of a single activation.
    pub bytes_per_activation: usize,
    /// Storage width, in bytes, of a single `kv` element.
    ///
    /// NOTE(review): the name suggests cached key/value entries; confirm against callers.
    pub bytes_per_kv: usize,
}

impl RuntimeProfile {
    /// Builds a profile whose three storage widths all equal `width` bytes.
    ///
    /// Shared by the precision-specific constructors so the field list lives in one place.
    const fn uniform_width(batch_size: usize, sequence_len: usize, width: usize) -> Self {
        Self {
            batch_size,
            sequence_len,
            bytes_per_param: width,
            bytes_per_activation: width,
            bytes_per_kv: width,
        }
    }

    /// Creates a profile in which parameters, activations and `kv` elements each take 4 bytes.
    ///
    /// `batch_size` and `sequence_len` are stored as given, without validation.
    #[must_use]
    pub const fn fp32(batch_size: usize, sequence_len: usize) -> Self {
        Self::uniform_width(batch_size, sequence_len, 4)
    }

    /// Creates a profile in which parameters, activations and `kv` elements each take 2 bytes.
    ///
    /// `batch_size` and `sequence_len` are stored as given, without validation.
    #[must_use]
    pub const fn fp16(batch_size: usize, sequence_len: usize) -> Self {
        Self::uniform_width(batch_size, sequence_len, 2)
    }

    /// Returns `true` when every field of the profile is non-zero.
    ///
    /// Usable in `const` contexts, so a `const` profile can be checked at compile time.
    #[must_use]
    pub const fn validate(&self) -> bool {
        self.batch_size > 0
            && self.sequence_len > 0
            && self.bytes_per_param > 0
            && self.bytes_per_activation > 0
            && self.bytes_per_kv > 0
    }

    /// Returns `batch_size * sequence_len`, or `None` if the product overflows `usize`.
    ///
    /// No validation is performed: a zero in either field yields `Some(0)`.
    #[must_use]
    pub const fn token_count(&self) -> Option<usize> {
        self.batch_size.checked_mul(self.sequence_len)
    }
}
42}