Skip to main content

rnn/model_config/
model_config.rs

/// Static shape hyperparameters of a transformer model.
///
/// A value is only meaningful once [`TransformerConfig::validate`] succeeds:
/// every field must be non-zero and `hidden_size` must be a multiple of
/// `num_heads`. All derived-quantity methods re-validate before computing.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TransformerConfig {
    /// Vocabulary size; multiplied by `hidden_size` to size the embedding.
    pub vocab_size: usize,
    /// Largest `sequence_len` accepted by the runtime-shape methods.
    pub context_len: usize,
    /// Model width; must be divisible by `num_heads`.
    pub hidden_size: usize,
    /// Inner width of the feed-forward up/down projections.
    pub ffw_size: usize,
    /// Number of layers; scales the per-layer parameter and activation counts.
    pub num_layers: usize,
    /// Number of attention heads; must evenly divide `hidden_size`.
    pub num_heads: usize,
}

/// Failure modes reported by [`TransformerConfig`] methods.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigError {
    /// A config field or runtime dimension is zero, `hidden_size` is not a
    /// multiple of `num_heads`, or `sequence_len` exceeds `context_len`.
    Invalid,
    /// A derived count does not fit in `usize`.
    Overflow,
}

impl TransformerConfig {
    /// Checks the structural invariants of the configuration.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if any field is zero or if
    /// `hidden_size` is not divisible by `num_heads`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        let fields = [
            self.vocab_size,
            self.context_len,
            self.hidden_size,
            self.ffw_size,
            self.num_layers,
            self.num_heads,
        ];
        if fields.contains(&0) {
            return Err(ConfigError::Invalid);
        }
        // `num_heads` is known to be non-zero here, so `%` cannot panic.
        if self.hidden_size % self.num_heads != 0 {
            return Err(ConfigError::Invalid);
        }
        Ok(())
    }

    /// Returns the per-head width, `hidden_size / num_heads`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails
    /// [`validate`](Self::validate).
    pub fn attention_head_dim(&self) -> Result<usize, ConfigError> {
        self.validate()?;
        Ok(self.hidden_size / self.num_heads)
    }

    /// Returns `num_layers * parameters_per_layer + embedding_parameter_count`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration is invalid and
    /// [`ConfigError::Overflow`] if any intermediate or the total exceeds
    /// `usize::MAX`.
    pub fn approximate_parameter_count(&self) -> Result<usize, ConfigError> {
        self.validate()?;

        let all_layers = self
            .layer_weights()?
            .checked_mul(self.num_layers)
            .ok_or(ConfigError::Overflow)?;
        let embedding = self.embedding_weights()?;
        all_layers.checked_add(embedding).ok_or(ConfigError::Overflow)
    }

    /// Returns the weight count of a single layer:
    /// `3*h*h` (QKV) + `h*h` (output projection) + `2*h*f` (feed-forward
    /// up and down), where `h = hidden_size` and `f = ffw_size`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration is invalid and
    /// [`ConfigError::Overflow`] if the count exceeds `usize::MAX`.
    pub fn parameters_per_layer(&self) -> Result<usize, ConfigError> {
        self.validate()?;
        self.layer_weights()
    }

    /// Returns `vocab_size * hidden_size`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration is invalid and
    /// [`ConfigError::Overflow`] if the product exceeds `usize::MAX`.
    pub fn embedding_parameter_count(&self) -> Result<usize, ConfigError> {
        self.validate()?;
        self.embedding_weights()
    }

    /// Returns `batch_size * sequence_len` for a valid runtime shape.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration is invalid, if
    /// either argument is zero, or if `sequence_len > context_len`; returns
    /// [`ConfigError::Overflow`] if the product exceeds `usize::MAX`.
    pub fn total_token_elements(&self, batch_size: usize, sequence_len: usize) -> Result<usize, ConfigError> {
        // Single source of truth for shape rules, so every runtime-size
        // method rejects sequences longer than the configured context.
        self.validate_runtime_shape(batch_size, sequence_len)?;
        batch_size
            .checked_mul(sequence_len)
            .ok_or(ConfigError::Overflow)
    }

    /// Returns `batch_size * sequence_len * hidden_size * num_layers`.
    ///
    /// # Errors
    ///
    /// Same as [`total_token_elements`](Self::total_token_elements), plus
    /// [`ConfigError::Overflow`] if the full product exceeds `usize::MAX`.
    pub fn activation_elements(&self, batch_size: usize, sequence_len: usize) -> Result<usize, ConfigError> {
        let tokens = self.total_token_elements(batch_size, sequence_len)?;
        tokens
            .checked_mul(self.hidden_size)
            .and_then(|per_layer| per_layer.checked_mul(self.num_layers))
            .ok_or(ConfigError::Overflow)
    }

    /// Returns `2 * batch_size * sequence_len * hidden_size * num_layers`,
    /// i.e. twice [`activation_elements`](Self::activation_elements).
    ///
    /// # Errors
    ///
    /// Same as [`activation_elements`](Self::activation_elements), plus
    /// [`ConfigError::Overflow`] if doubling exceeds `usize::MAX`.
    pub fn kv_cache_elements(&self, batch_size: usize, sequence_len: usize) -> Result<usize, ConfigError> {
        // The factor of two presumably covers one key and one value tensor,
        // as the method name suggests.
        self.activation_elements(batch_size, sequence_len)?
            .checked_mul(2)
            .ok_or(ConfigError::Overflow)
    }

    /// Checks that a `(batch_size, sequence_len)` pair is usable with this
    /// configuration.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration is invalid, if
    /// either argument is zero, or if `sequence_len > context_len`.
    pub fn validate_runtime_shape(&self, batch_size: usize, sequence_len: usize) -> Result<(), ConfigError> {
        self.validate()?;
        if batch_size == 0 || sequence_len == 0 || sequence_len > self.context_len {
            return Err(ConfigError::Invalid);
        }
        Ok(())
    }

    /// Per-layer weight count without validation; callers validate first.
    fn layer_weights(&self) -> Result<usize, ConfigError> {
        let h = self.hidden_size;
        let f = self.ffw_size;

        let h_sq = h.checked_mul(h).ok_or(ConfigError::Overflow)?;
        let h_f = h.checked_mul(f).ok_or(ConfigError::Overflow)?;

        let qkv = h_sq.checked_mul(3).ok_or(ConfigError::Overflow)?;
        let o_proj = h_sq;
        // Up and down projections have the same element count (h*f).
        let ff_up = h_f;
        let ff_down = h_f;

        qkv.checked_add(o_proj)
            .and_then(|sum| sum.checked_add(ff_up))
            .and_then(|sum| sum.checked_add(ff_down))
            .ok_or(ConfigError::Overflow)
    }

    /// Embedding weight count without validation; callers validate first.
    fn embedding_weights(&self) -> Result<usize, ConfigError> {
        self.vocab_size
            .checked_mul(self.hidden_size)
            .ok_or(ConfigError::Overflow)
    }
}