ferrum_interfaces/transformer.rs

use crate::tensor::TensorRef;

/// Static configuration for a decoder-only transformer: layer count, hidden
/// and attention dimensions, vocabulary size, and normalization settings.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    pub num_layers: usize,
    pub hidden_size: usize,
    pub num_attention_heads: usize,
    /// Number of key/value heads (may be smaller than `num_attention_heads`
    /// for grouped-query attention).
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub max_seq_len: usize,
    pub rms_norm_eps: f32,
    /// Whether the model applies per-head norm weights to queries and keys.
    pub has_qk_norm: bool,
}

/// Read-only access to the parameters of a decoder-only transformer.
///
/// Per-layer accessors take a layer index in `0..config().num_layers` and
/// return shared tensor handles; optional tensors (QK-norm) return `None`
/// when the model does not have them.
pub trait TransformerWeights: Send + Sync {
    /// Model configuration describing dimensions and settings.
    fn config(&self) -> &TransformerConfig;

    /// Token embedding matrix.
    fn embed_weight(&self) -> TensorRef;

    /// Input (pre-attention) norm weight for `layer`.
    fn layer_input_norm_weight(&self, layer: usize) -> TensorRef;

    /// Fused query/key/value projection weight for `layer`.
    fn layer_qkv_weight(&self, layer: usize) -> TensorRef;

    /// Per-head query norm weight for `layer`, if the model uses QK-norm.
    fn layer_q_norm_weight(&self, layer: usize) -> Option<TensorRef>;

    /// Per-head key norm weight for `layer`, if the model uses QK-norm.
    fn layer_k_norm_weight(&self, layer: usize) -> Option<TensorRef>;

    /// Attention output projection weight for `layer`.
    fn layer_o_weight(&self, layer: usize) -> TensorRef;

    /// Post-attention norm weight for `layer`.
    fn layer_post_norm_weight(&self, layer: usize) -> TensorRef;

    /// Fused gate/up projection weight for the MLP in `layer`.
    fn layer_gate_up_weight(&self, layer: usize) -> TensorRef;

    /// Down projection weight for the MLP in `layer`.
    fn layer_down_weight(&self, layer: usize) -> TensorRef;

    /// Final norm weight applied before the LM head.
    fn final_norm_weight(&self) -> TensorRef;

    /// Output (LM head) projection weight.
    fn lm_head_weight(&self) -> TensorRef;

    /// Precomputed RoPE cosine table.
    fn rope_cos(&self) -> TensorRef;

    /// Precomputed RoPE sine table.
    fn rope_sin(&self) -> TensorRef;
}
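
// Illustrative sketch, not part of the original interface: the fused QKV row
// count implied by a `TransformerConfig`, matching the layout the mock
// implementation in the tests below assumes (query rows first, then key and
// value rows sized by the KV-head count). The trait itself does not prescribe
// tensor shapes; this only documents that convention.
#[allow(dead_code)]
fn fused_qkv_rows(cfg: &TransformerConfig) -> usize {
    // num_attention_heads * head_dim query rows, plus key and value rows
    // sized by num_kv_heads (grouped-query attention).
    cfg.num_attention_heads * cfg.head_dim + 2 * cfg.num_kv_heads * cfg.head_dim
}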

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{TensorLike, TensorRef};
    use std::any::Any;
    use std::sync::Arc;

    /// Minimal `TensorLike` stand-in that only tracks a shape.
    #[derive(Debug)]
    struct TestTensor {
        shape: Vec<usize>,
    }

    impl TensorLike for TestTensor {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn shape(&self) -> &[usize] {
            &self.shape
        }
        fn is_contiguous(&self) -> bool {
            true
        }
        fn dtype(&self) -> ferrum_types::DataType {
            ferrum_types::DataType::FP16
        }
        fn device(&self) -> ferrum_types::Device {
            ferrum_types::Device::CPU
        }
        fn view(&self, _start: &[usize], _end: &[usize]) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
        fn reshape(&self, shape: &[usize]) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: shape.to_vec(),
            }))
        }
        fn to_cpu(&self) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
        fn to_device(&self, _: &ferrum_types::Device) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
        fn to_dtype(&self, _: ferrum_types::DataType) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
    }

    /// Mock `TransformerWeights` whose tensors carry only shape information.
    struct MockWeights {
        config: TransformerConfig,
    }

    impl MockWeights {
        fn new(num_layers: usize) -> Self {
            Self {
                config: TransformerConfig {
                    num_layers,
                    hidden_size: 64,
                    num_attention_heads: 4,
                    num_kv_heads: 2,
                    head_dim: 16,
                    intermediate_size: 128,
                    vocab_size: 100,
                    max_seq_len: 512,
                    rms_norm_eps: 1e-6,
                    has_qk_norm: true,
                },
            }
        }

        fn mock_tensor(shape: &[usize]) -> TensorRef {
            Arc::new(TestTensor {
                shape: shape.to_vec(),
            })
        }
    }

    impl TransformerWeights for MockWeights {
        fn config(&self) -> &TransformerConfig {
            &self.config
        }
        fn embed_weight(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.vocab_size, self.config.hidden_size])
        }
        fn layer_input_norm_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size])
        }
        fn layer_qkv_weight(&self, _layer: usize) -> TensorRef {
            let q = self.config.num_attention_heads * self.config.head_dim;
            let kv = self.config.num_kv_heads * self.config.head_dim;
            Self::mock_tensor(&[q + 2 * kv, self.config.hidden_size])
        }
        fn layer_q_norm_weight(&self, _layer: usize) -> Option<TensorRef> {
            if self.config.has_qk_norm {
                Some(Self::mock_tensor(&[self.config.head_dim]))
            } else {
                None
            }
        }
        fn layer_k_norm_weight(&self, layer: usize) -> Option<TensorRef> {
            self.layer_q_norm_weight(layer)
        }
        fn layer_o_weight(&self, _layer: usize) -> TensorRef {
            let q = self.config.num_attention_heads * self.config.head_dim;
            Self::mock_tensor(&[self.config.hidden_size, q])
        }
        fn layer_post_norm_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size])
        }
        fn layer_gate_up_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[2 * self.config.intermediate_size, self.config.hidden_size])
        }
        fn layer_down_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size, self.config.intermediate_size])
        }
        fn final_norm_weight(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size])
        }
        fn lm_head_weight(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.vocab_size, self.config.hidden_size])
        }
        fn rope_cos(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.max_seq_len, self.config.head_dim / 2])
        }
        fn rope_sin(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.max_seq_len, self.config.head_dim / 2])
        }
    }

    #[test]
    fn transformer_weights_config() {
        let w = MockWeights::new(4);
        assert_eq!(w.config().num_layers, 4);
        assert_eq!(w.config().hidden_size, 64);
        assert!(w.config().has_qk_norm);
    }

    #[test]
    fn transformer_weights_shapes() {
        let w = MockWeights::new(2);
        let cfg = w.config();

        assert_eq!(w.embed_weight().shape(), &[100, 64]);

        let q_dim = cfg.num_attention_heads * cfg.head_dim;
        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
        assert_eq!(w.layer_qkv_weight(0).shape(), &[q_dim + 2 * kv_dim, 64]);

        assert_eq!(w.layer_q_norm_weight(0).unwrap().shape(), &[16]);
        assert_eq!(w.layer_k_norm_weight(1).unwrap().shape(), &[16]);

        assert_eq!(w.layer_gate_up_weight(0).shape(), &[256, 64]);

        assert_eq!(w.lm_head_weight().shape(), &[100, 64]);

        assert_eq!(w.rope_cos().shape(), &[512, 8]);
        assert_eq!(w.rope_sin().shape(), &[512, 8]);
    }

    #[test]
    fn transformer_weights_no_qk_norm() {
        let mut w = MockWeights::new(2);
        w.config.has_qk_norm = false;
        assert!(w.layer_q_norm_weight(0).is_none());
        assert!(w.layer_k_norm_weight(0).is_none());
    }

    #[test]
    fn transformer_weights_all_layers() {
        let w = MockWeights::new(36);
        for i in 0..36 {
            assert!(!w.layer_input_norm_weight(i).shape().is_empty());
            assert!(!w.layer_qkv_weight(i).shape().is_empty());
            assert!(!w.layer_o_weight(i).shape().is_empty());
        }
    }
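
    // Illustrative addition, not from the original test suite: the trait is
    // object-safe, so a backend can hold weights behind `&dyn TransformerWeights`
    // and walk layers generically.
    #[test]
    fn transformer_weights_as_trait_object() {
        let w = MockWeights::new(3);
        let weights: &dyn TransformerWeights = &w;
        for layer in 0..weights.config().num_layers {
            assert!(!weights.layer_qkv_weight(layer).shape().is_empty());
        }
        assert_eq!(weights.embed_weight().shape(), weights.lm_head_weight().shape());
    }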
}