// rnn/model_config/model_config.rs

/// Size hyperparameters of a transformer model.
///
/// A configuration is considered well-formed when every field is non-zero and
/// `hidden_size` is an exact multiple of `num_heads`; see
/// [`TransformerConfig::validate`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TransformerConfig {
    /// Vocabulary size; multiplied by `hidden_size` to size the embedding table.
    pub vocab_size: usize,
    /// Largest `sequence_len` accepted by [`TransformerConfig::validate_runtime_shape`].
    pub context_len: usize,
    /// Model width; must be divisible by `num_heads`.
    pub hidden_size: usize,
    /// Inner width of the feed-forward projections.
    pub ffw_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
}

/// Errors produced while validating a [`TransformerConfig`] or sizing tensors from it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigError {
    /// A field or runtime dimension was zero, `hidden_size` was not divisible by
    /// `num_heads`, or a sequence length exceeded `context_len`.
    Invalid,
    /// A size computation did not fit in `usize`.
    Overflow,
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ConfigError::Invalid => "invalid transformer configuration or runtime shape",
            ConfigError::Overflow => "arithmetic overflow while computing a size",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ConfigError {}

impl TransformerConfig {
    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if any field is zero or if
    /// `hidden_size` is not a multiple of `num_heads`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        let dims = [
            self.vocab_size,
            self.context_len,
            self.hidden_size,
            self.ffw_size,
            self.num_layers,
            self.num_heads,
        ];
        // The zero check runs first, so the modulo below never divides by zero.
        if dims.contains(&0) || self.hidden_size % self.num_heads != 0 {
            return Err(ConfigError::Invalid);
        }
        Ok(())
    }

    /// Returns the per-head width, `hidden_size / num_heads`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails [`Self::validate`].
    pub fn attention_head_dim(&self) -> Result<usize, ConfigError> {
        self.validate()?;
        Ok(self.hidden_size / self.num_heads)
    }

    /// Estimates the total weight count as
    /// `num_layers * parameters_per_layer + vocab_size * hidden_size`.
    ///
    /// Only the terms counted by [`Self::parameters_per_layer`] and
    /// [`Self::embedding_parameter_count`] are included.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails
    /// [`Self::validate`], or [`ConfigError::Overflow`] if any intermediate
    /// value does not fit in `usize`.
    pub fn approximate_parameter_count(&self) -> Result<usize, ConfigError> {
        let per_layer = self.parameters_per_layer()?;
        let all_layers = per_layer
            .checked_mul(self.num_layers)
            .ok_or(ConfigError::Overflow)?;
        let embedding = self.embedding_parameter_count()?;
        all_layers.checked_add(embedding).ok_or(ConfigError::Overflow)
    }

    /// Counts the weights of one layer: `3*h*h` (fused q/k/v projection) plus
    /// `h*h` (output projection) plus `2*h*f` (feed-forward up and down), where
    /// `h = hidden_size` and `f = ffw_size`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails
    /// [`Self::validate`], or [`ConfigError::Overflow`] if the count does not
    /// fit in `usize`.
    pub fn parameters_per_layer(&self) -> Result<usize, ConfigError> {
        self.validate()?;

        let h = self.hidden_size;
        let f = self.ffw_size;

        // Each product is computed once and reused for both matrices of that shape.
        let h_squared = h.checked_mul(h).ok_or(ConfigError::Overflow)?;
        let h_times_f = h.checked_mul(f).ok_or(ConfigError::Overflow)?;

        let qkv = h_squared.checked_mul(3).ok_or(ConfigError::Overflow)?;
        let o_proj = h_squared;
        let ff_up = h_times_f;
        let ff_down = h_times_f;

        qkv.checked_add(o_proj)
            .and_then(|sum| sum.checked_add(ff_up))
            .and_then(|sum| sum.checked_add(ff_down))
            .ok_or(ConfigError::Overflow)
    }

    /// Returns `vocab_size * hidden_size`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails
    /// [`Self::validate`], or [`ConfigError::Overflow`] if the product does not
    /// fit in `usize`.
    pub fn embedding_parameter_count(&self) -> Result<usize, ConfigError> {
        self.validate()?;
        self.vocab_size
            .checked_mul(self.hidden_size)
            .ok_or(ConfigError::Overflow)
    }

    /// Returns `batch_size * sequence_len`.
    ///
    /// `sequence_len` is not compared against `context_len` here; use
    /// [`Self::validate_runtime_shape`] for that check.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails
    /// [`Self::validate`] or either argument is zero, and
    /// [`ConfigError::Overflow`] if the product does not fit in `usize`.
    pub fn total_token_elements(
        &self,
        batch_size: usize,
        sequence_len: usize,
    ) -> Result<usize, ConfigError> {
        self.validate()?;
        if batch_size == 0 || sequence_len == 0 {
            return Err(ConfigError::Invalid);
        }

        batch_size
            .checked_mul(sequence_len)
            .ok_or(ConfigError::Overflow)
    }

    /// Returns `batch_size * sequence_len * hidden_size * num_layers`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::total_token_elements`], plus
    /// [`ConfigError::Overflow`] if the full product does not fit in `usize`.
    pub fn activation_elements(
        &self,
        batch_size: usize,
        sequence_len: usize,
    ) -> Result<usize, ConfigError> {
        let tokens = self.total_token_elements(batch_size, sequence_len)?;
        tokens
            .checked_mul(self.hidden_size)
            .and_then(|per_layer| per_layer.checked_mul(self.num_layers))
            .ok_or(ConfigError::Overflow)
    }

    /// Returns twice [`Self::activation_elements`] for the same shape.
    ///
    /// NOTE(review): the factor of two presumably accounts for one key and one
    /// value tensor per layer; confirm against the cache implementation.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::activation_elements`], plus
    /// [`ConfigError::Overflow`] if doubling does not fit in `usize`.
    pub fn kv_cache_elements(
        &self,
        batch_size: usize,
        sequence_len: usize,
    ) -> Result<usize, ConfigError> {
        self.activation_elements(batch_size, sequence_len)?
            .checked_mul(2)
            .ok_or(ConfigError::Overflow)
    }

    /// Checks that a batch of the given shape can be run with this configuration.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Invalid`] if the configuration fails
    /// [`Self::validate`], if either argument is zero, or if `sequence_len`
    /// exceeds `context_len`.
    pub fn validate_runtime_shape(
        &self,
        batch_size: usize,
        sequence_len: usize,
    ) -> Result<(), ConfigError> {
        self.validate()?;
        if batch_size == 0 || sequence_len == 0 || sequence_len > self.context_len {
            return Err(ConfigError::Invalid);
        }
        Ok(())
    }
}
119}