// ferrum_interfaces/transformer.rs

//! Transformer model weight abstraction.
//!
//! Different model architectures (Qwen3, Llama, Qwen2) implement
//! `TransformerWeights` to provide a uniform weight access interface.
//! This decouples the execution backend from the model architecture.

use crate::tensor::TensorRef;

/// Configuration for a standard transformer decoder.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    pub num_layers: usize,
    pub hidden_size: usize,
    pub num_attention_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub max_seq_len: usize,
    pub rms_norm_eps: f32,
    /// Whether Q/K heads have per-head normalization (Qwen3 has this, Llama doesn't).
    pub has_qk_norm: bool,
}
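
// Convenience accessors for the derived projection widths used by the shape
// contracts below. A minimal sketch: these helpers are illustrative and not
// required by `TransformerWeights` implementations, which may compute the
// same products inline.
impl TransformerConfig {
    /// Total width of the Q projection: `num_attention_heads * head_dim`.
    pub fn q_dim(&self) -> usize {
        self.num_attention_heads * self.head_dim
    }

    /// Width of a single K or V projection: `num_kv_heads * head_dim`.
    /// Under grouped-query attention, `num_kv_heads < num_attention_heads`,
    /// so this is smaller than `q_dim`.
    pub fn kv_dim(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}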

/// Uniform weight access for transformer decoder models.
///
/// This trait abstracts over different model architectures. All standard
/// transformer decoders share the same layer structure:
///
/// ```text
/// embed → N × (norm → QKV → [Q/K norm] → RoPE → Attention → O → norm → MLP) → norm → lm_head
/// ```
///
/// What varies between architectures (Qwen3 vs Llama vs Qwen2):
/// - head_dim: explicit vs derived from hidden_size / num_heads
/// - Q/K normalization: present in Qwen3, absent in Llama
/// - Bias: present in some Qwen2 layers, absent elsewhere
/// - RoPE parameters: differ between architectures
///
/// Weights are returned as `TensorRef` — the same zero-copy handle used
/// throughout the framework. The backend extracts device-specific pointers
/// (CudaSlice, Metal buffer, etc.) from the `TensorRef`.
pub trait TransformerWeights: Send + Sync {
    /// Model configuration.
    fn config(&self) -> &TransformerConfig;

    /// Embedding table: [vocab_size, hidden_size]
    fn embed_weight(&self) -> TensorRef;

    /// Input layer norm weight for layer `i`: [hidden_size]
    fn layer_input_norm_weight(&self, layer: usize) -> TensorRef;

    /// Fused QKV projection weight for layer `i`: [q_dim + 2*kv_dim, hidden_size].
    /// If the model stores Q/K/V separately, the implementation fuses them.
    fn layer_qkv_weight(&self, layer: usize) -> TensorRef;

    /// Q-head normalization weight: [head_dim] (`None` if the architecture doesn't have it)
    fn layer_q_norm_weight(&self, layer: usize) -> Option<TensorRef>;

    /// K-head normalization weight: [head_dim] (`None` if the architecture doesn't have it)
    fn layer_k_norm_weight(&self, layer: usize) -> Option<TensorRef>;

    /// Output projection weight for layer `i`: [hidden_size, q_dim]
    fn layer_o_weight(&self, layer: usize) -> TensorRef;

    /// Post-attention layer norm weight for layer `i`: [hidden_size]
    fn layer_post_norm_weight(&self, layer: usize) -> TensorRef;

    /// Fused gate+up projection weight: [2*intermediate_size, hidden_size]
    fn layer_gate_up_weight(&self, layer: usize) -> TensorRef;

    /// Down projection weight: [hidden_size, intermediate_size]
    fn layer_down_weight(&self, layer: usize) -> TensorRef;

    /// Final RMS norm weight: [hidden_size]
    fn final_norm_weight(&self) -> TensorRef;

    /// LM head projection weight: [vocab_size, hidden_size]
    fn lm_head_weight(&self) -> TensorRef;

    /// RoPE cosine table: [max_seq_len, head_dim/2]
    fn rope_cos(&self) -> TensorRef;

    /// RoPE sine table: [max_seq_len, head_dim/2]
    fn rope_sin(&self) -> TensorRef;
}
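
// What follows are two hedged, illustrative sketches; neither is part of the
// trait contract. `check_weight_shapes` shows how a generic consumer might
// validate the shape contract once at load time, before extracting device
// pointers. `rope_tables_sketch` shows the vanilla RoPE precomputation that
// typically produces the `rope_cos`/`rope_sin` tables; the base (10_000.0
// here) and any frequency scaling are architecture-specific, as noted above.

/// Walks every layer of `weights` and asserts the documented shapes.
/// Illustrative only; assumes the row layouts stated in the method docs.
#[allow(dead_code)]
fn check_weight_shapes<W: TransformerWeights + ?Sized>(weights: &W) {
    let cfg = weights.config();
    let q_dim = cfg.q_dim();
    let kv_dim = cfg.kv_dim();

    assert_eq!(
        weights.embed_weight().shape(),
        &[cfg.vocab_size, cfg.hidden_size][..]
    );
    for layer in 0..cfg.num_layers {
        assert_eq!(
            weights.layer_qkv_weight(layer).shape(),
            &[q_dim + 2 * kv_dim, cfg.hidden_size][..]
        );
        assert_eq!(
            weights.layer_o_weight(layer).shape(),
            &[cfg.hidden_size, q_dim][..]
        );
        // Q/K norm tensors exist exactly when the architecture defines them.
        assert_eq!(weights.layer_q_norm_weight(layer).is_some(), cfg.has_qk_norm);
        assert_eq!(weights.layer_k_norm_weight(layer).is_some(), cfg.has_qk_norm);
    }
    assert_eq!(
        weights.lm_head_weight().shape(),
        &[cfg.vocab_size, cfg.hidden_size][..]
    );
    assert_eq!(
        weights.rope_cos().shape(),
        &[cfg.max_seq_len, cfg.head_dim / 2][..]
    );
}

/// Builds flattened `[max_seq_len, head_dim/2]` cos/sin tables (row-major),
/// using angle = pos * base^(-2i/head_dim). A sketch under the assumption of
/// vanilla RoPE with base 10_000.0; real models may rescale the frequencies.
#[allow(dead_code)]
fn rope_tables_sketch(max_seq_len: usize, head_dim: usize) -> (Vec<f32>, Vec<f32>) {
    let half = head_dim / 2;
    let mut cos = Vec::with_capacity(max_seq_len * half);
    let mut sin = Vec::with_capacity(max_seq_len * half);
    for pos in 0..max_seq_len {
        for i in 0..half {
            // Per-pair inverse frequency: base^(-2i / head_dim).
            let inv_freq = 10_000f32.powf(-2.0 * i as f32 / head_dim as f32);
            let angle = pos as f32 * inv_freq;
            cos.push(angle.cos());
            sin.push(angle.sin());
        }
    }
    (cos, sin)
}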

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{TensorLike, TensorRef};
    use std::any::Any;
    use std::sync::Arc;

    /// Minimal mock tensor for testing (avoids a circular dep with ferrum-testkit).
    #[derive(Debug)]
    struct TestTensor {
        shape: Vec<usize>,
    }

    impl TensorLike for TestTensor {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn shape(&self) -> &[usize] {
            &self.shape
        }
        fn is_contiguous(&self) -> bool {
            true
        }
        fn dtype(&self) -> ferrum_types::DataType {
            ferrum_types::DataType::FP16
        }
        fn device(&self) -> ferrum_types::Device {
            ferrum_types::Device::CPU
        }
        fn view(&self, _start: &[usize], _end: &[usize]) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
        fn reshape(&self, shape: &[usize]) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: shape.to_vec(),
            }))
        }
        fn to_cpu(&self) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
        fn to_device(&self, _: &ferrum_types::Device) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
        fn to_dtype(&self, _: ferrum_types::DataType) -> ferrum_types::Result<TensorRef> {
            Ok(Arc::new(TestTensor {
                shape: self.shape.clone(),
            }))
        }
    }

    struct MockWeights {
        config: TransformerConfig,
    }

    impl MockWeights {
        fn new(num_layers: usize) -> Self {
            Self {
                config: TransformerConfig {
                    num_layers,
                    hidden_size: 64,
                    num_attention_heads: 4,
                    num_kv_heads: 2,
                    head_dim: 16,
                    intermediate_size: 128,
                    vocab_size: 100,
                    max_seq_len: 512,
                    rms_norm_eps: 1e-6,
                    has_qk_norm: true,
                },
            }
        }

        fn mock_tensor(shape: &[usize]) -> TensorRef {
            Arc::new(TestTensor {
                shape: shape.to_vec(),
            })
        }
    }

    impl TransformerWeights for MockWeights {
        fn config(&self) -> &TransformerConfig {
            &self.config
        }
        fn embed_weight(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.vocab_size, self.config.hidden_size])
        }
        fn layer_input_norm_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size])
        }
        fn layer_qkv_weight(&self, _layer: usize) -> TensorRef {
            let q = self.config.num_attention_heads * self.config.head_dim;
            let kv = self.config.num_kv_heads * self.config.head_dim;
            Self::mock_tensor(&[q + 2 * kv, self.config.hidden_size])
        }
        fn layer_q_norm_weight(&self, _layer: usize) -> Option<TensorRef> {
            if self.config.has_qk_norm {
                Some(Self::mock_tensor(&[self.config.head_dim]))
            } else {
                None
            }
        }
        fn layer_k_norm_weight(&self, layer: usize) -> Option<TensorRef> {
            // K-norm mirrors Q-norm in the mock.
            self.layer_q_norm_weight(layer)
        }
        fn layer_o_weight(&self, _layer: usize) -> TensorRef {
            let q = self.config.num_attention_heads * self.config.head_dim;
            Self::mock_tensor(&[self.config.hidden_size, q])
        }
        fn layer_post_norm_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size])
        }
        fn layer_gate_up_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[2 * self.config.intermediate_size, self.config.hidden_size])
        }
        fn layer_down_weight(&self, _layer: usize) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size, self.config.intermediate_size])
        }
        fn final_norm_weight(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.hidden_size])
        }
        fn lm_head_weight(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.vocab_size, self.config.hidden_size])
        }
        fn rope_cos(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.max_seq_len, self.config.head_dim / 2])
        }
        fn rope_sin(&self) -> TensorRef {
            Self::mock_tensor(&[self.config.max_seq_len, self.config.head_dim / 2])
        }
    }

    #[test]
    fn transformer_weights_config() {
        let w = MockWeights::new(4);
        assert_eq!(w.config().num_layers, 4);
        assert_eq!(w.config().hidden_size, 64);
        assert!(w.config().has_qk_norm);
    }

    #[test]
    fn transformer_weights_shapes() {
        let w = MockWeights::new(2);
        let cfg = w.config();

        // Embed: [vocab, hidden]
        assert_eq!(w.embed_weight().shape(), &[100, 64]);

        // QKV: [q_dim + 2*kv_dim, hidden]
        let q_dim = cfg.num_attention_heads * cfg.head_dim; // 4*16 = 64
        let kv_dim = cfg.num_kv_heads * cfg.head_dim; // 2*16 = 32
        assert_eq!(w.layer_qkv_weight(0).shape(), &[q_dim + 2 * kv_dim, 64]);

        // Q/K norm: [head_dim]
        assert_eq!(w.layer_q_norm_weight(0).unwrap().shape(), &[16]);
        assert_eq!(w.layer_k_norm_weight(1).unwrap().shape(), &[16]);

        // Gate+Up: [2*inter, hidden]
        assert_eq!(w.layer_gate_up_weight(0).shape(), &[256, 64]);

        // LM head: [vocab, hidden]
        assert_eq!(w.lm_head_weight().shape(), &[100, 64]);

        // RoPE: [max_seq, head_dim/2]
        assert_eq!(w.rope_cos().shape(), &[512, 8]);
    }

    #[test]
    fn transformer_weights_no_qk_norm() {
        let mut w = MockWeights::new(2);
        w.config.has_qk_norm = false;
        assert!(w.layer_q_norm_weight(0).is_none());
        assert!(w.layer_k_norm_weight(0).is_none());
    }

    #[test]
    fn transformer_weights_all_layers() {
        let w = MockWeights::new(36);
        for i in 0..36 {
            // Every layer should return valid tensors.
            assert!(!w.layer_input_norm_weight(i).shape().is_empty());
            assert!(!w.layer_qkv_weight(i).shape().is_empty());
            assert!(!w.layer_o_weight(i).shape().is_empty());
        }
    }
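
    #[test]
    fn transformer_weights_shape_contract_sketch() {
        // Exercises the illustrative checker and RoPE sketch defined above
        // the test module against the mock implementation.
        let w = MockWeights::new(3);
        check_weight_shapes(&w);

        let (cos, sin) = rope_tables_sketch(512, 16);
        assert_eq!(cos.len(), 512 * 8);
        assert_eq!(sin.len(), cos.len());
        // Position 0 has angle 0 for every pair: cos = 1, sin = 0.
        assert!(cos[..8].iter().all(|&c| (c - 1.0).abs() < 1e-6));
        assert!(sin[..8].iter().all(|&s| s.abs() < 1e-6));
    }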
}