use std::collections::HashMap;
use std::fs;
use std::path::Path;

use safetensors::{Dtype, SafeTensors};

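/// Errors surfaced while loading checkpoints or validating tensor shapes.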
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    #[error("shape mismatch: {0}")]
    Shape(String),
    #[error("missing weight: {0}")]
    MissingWeight(String),
    #[error("invalid dtype for {name}: expected F32, got {dtype:?}")]
    InvalidDtype { name: String, dtype: Dtype },
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    #[error("safetensors error: {0}")]
    SafeTensors(#[from] safetensors::SafeTensorError),
}

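/// A dense tensor: its logical shape plus flat `f32` storage in row-major
/// order (the layout safetensors uses).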
#[derive(Debug, Clone)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

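/// Named `f32` tensors loaded from a safetensors checkpoint.
///
/// A minimal usage sketch (the file name and tensor key here are
/// illustrative placeholders, not names this module requires):
///
/// ```text
/// let weights = ModelWeights::load_safetensors_file("model.safetensors")?;
/// let w = weights.get("w")?;
/// assert_eq!(w.data.len(), w.shape.iter().product::<usize>());
/// ```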
#[derive(Debug, Clone, Default)]
pub struct ModelWeights {
    tensors: HashMap<String, Tensor>,
}

impl ModelWeights {
    pub fn load_safetensors_bytes(bytes: &[u8]) -> Result<Self, ModelError> {
        let st = SafeTensors::deserialize(bytes)?;
        let mut tensors = HashMap::new();

        for name in st.names() {
            let view = st.tensor(name)?;
            if view.dtype() != Dtype::F32 {
                return Err(ModelError::InvalidDtype {
                    name: name.to_string(),
                    dtype: view.dtype(),
                });
            }

            let shape = view.shape().to_vec();
            // safetensors hands back raw little-endian bytes with no
            // alignment guarantee, so decode element-wise rather than
            // reinterpreting the slice (a direct cast can panic when the
            // buffer is misaligned for f32).
            let data: Vec<f32> = view
                .data()
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                .collect();
            tensors.insert(name.to_string(), Tensor { shape, data });
        }

        Ok(Self { tensors })
    }

    pub fn load_safetensors_file(path: impl AsRef<Path>) -> Result<Self, ModelError> {
        let bytes = fs::read(path)?;
        Self::load_safetensors_bytes(&bytes)
    }

    pub fn get(&self, name: &str) -> Result<&Tensor, ModelError> {
        self.tensors
            .get(name)
            .ok_or_else(|| ModelError::MissingWeight(name.to_string()))
    }
}

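/// RMSNorm: `y[i] = x[i] * w[i] / sqrt(mean(x^2) + eps)`. Unlike LayerNorm
/// there is no mean subtraction and no bias term.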
pub fn rms_norm(input: &[f32], weight: &[f32], eps: f32) -> Result<Vec<f32>, ModelError> {
    if input.len() != weight.len() {
        return Err(ModelError::Shape(format!(
            "rms_norm input/weight mismatch: {} != {}",
            input.len(),
            weight.len()
        )));
    }
    if input.is_empty() {
        return Ok(Vec::new());
    }
    let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    Ok(input
        .iter()
        .zip(weight.iter())
        .map(|(&x, &w)| x * inv_rms * w)
        .collect())
}

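/// Rotary position embedding (RoPE), interleaved-pair convention: within
/// each head, each pair `(x[2i], x[2i+1])` is rotated in place by the angle
/// `theta_i = position / base^(2i / head_dim)`. Note that some
/// implementations (e.g. Hugging Face Llama) instead pair dimension `d`
/// with `d + head_dim/2`; loading such weights requires matching that
/// convention.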
pub fn apply_rope(
    q: &mut [f32],
    k: &mut [f32],
    position: usize,
    n_heads: usize,
    head_dim: usize,
    base: f32,
) -> Result<(), ModelError> {
    let expected = n_heads * head_dim;
    if q.len() != expected || k.len() != expected {
        return Err(ModelError::Shape(format!(
            "rope q/k mismatch: expected {}, got q={}, k={}",
            expected,
            q.len(),
            k.len()
        )));
    }
    if !head_dim.is_multiple_of(2) {
        return Err(ModelError::Shape(format!(
            "head_dim must be even for RoPE, got {}",
            head_dim
        )));
    }

    for h in 0..n_heads {
        let offset = h * head_dim;
        for i in (0..head_dim).step_by(2) {
            let theta = (position as f32) / base.powf(i as f32 / head_dim as f32);
            let (sin_t, cos_t) = theta.sin_cos();

            let q0 = q[offset + i];
            let q1 = q[offset + i + 1];
            q[offset + i] = q0 * cos_t - q1 * sin_t;
            q[offset + i + 1] = q0 * sin_t + q1 * cos_t;

            let k0 = k[offset + i];
            let k1 = k[offset + i + 1];
            k[offset + i] = k0 * cos_t - k1 * sin_t;
            k[offset + i + 1] = k0 * sin_t + k1 * cos_t;
        }
    }
    Ok(())
}

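/// Single-step (decode-time) attention for one query position against a
/// K/V cache laid out as `[seq_len, n_heads, head_dim]`: per head it takes
/// scaled dot-product scores `q . k / sqrt(head_dim)`, a softmax over all
/// `seq_len` cached positions, then a probability-weighted sum of the
/// values. No causal mask is applied because the cache holds only past
/// positions.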
pub fn attention_decode(
    query: &[f32],
    keys: &[f32],
    values: &[f32],
    seq_len: usize,
    n_heads: usize,
    head_dim: usize,
) -> Result<Vec<f32>, ModelError> {
    let q_expected = n_heads * head_dim;
    let kv_expected = seq_len * n_heads * head_dim;
    if query.len() != q_expected {
        return Err(ModelError::Shape(format!(
            "query shape mismatch: expected {}, got {}",
            q_expected,
            query.len()
        )));
    }
    if keys.len() != kv_expected || values.len() != kv_expected {
        return Err(ModelError::Shape(format!(
            "kv shape mismatch: expected {}, got k={}, v={}",
            kv_expected,
            keys.len(),
            values.len()
        )));
    }

    let mut out = vec![0.0; q_expected];
    let scale = 1.0 / (head_dim as f32).sqrt();

    for h in 0..n_heads {
        let qh = &query[h * head_dim..(h + 1) * head_dim];

        // Scaled dot-product scores against every cached key for this head.
        let mut scores = vec![0.0f32; seq_len];
        for (t, score) in scores.iter_mut().enumerate() {
            let kh_off = (t * n_heads + h) * head_dim;
            let kh = &keys[kh_off..kh_off + head_dim];
            let dot = qh.iter().zip(kh.iter()).map(|(&a, &b)| a * b).sum::<f32>();
            *score = dot * scale;
        }

        // Probability-weighted sum of the cached values.
        let probs = softmax(&scores);
        let out_h = &mut out[h * head_dim..(h + 1) * head_dim];
        for (t, &p) in probs.iter().enumerate() {
            let vh_off = (t * n_heads + h) * head_dim;
            let vh = &values[vh_off..vh_off + head_dim];
            for i in 0..head_dim {
                out_h[i] += p * vh[i];
            }
        }
    }

    Ok(out)
}

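/// SwiGLU feed-forward: `down(silu(x . W_gate) * (x . W_up))`, with every
/// weight matrix stored row-major as `[in_dim, out_dim]`.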
pub fn mlp_swiglu(
    x: &[f32],
    w_gate: &[f32],
    w_up: &[f32],
    w_down: &[f32],
    d_model: usize,
    d_ff: usize,
) -> Result<Vec<f32>, ModelError> {
    if x.len() != d_model {
        return Err(ModelError::Shape(format!(
            "x shape mismatch: expected {}, got {}",
            d_model,
            x.len()
        )));
    }
    if w_gate.len() != d_model * d_ff || w_up.len() != d_model * d_ff {
        return Err(ModelError::Shape("w_gate/w_up shape mismatch".to_string()));
    }
    if w_down.len() != d_ff * d_model {
        return Err(ModelError::Shape("w_down shape mismatch".to_string()));
    }

    let gate_pre = matvec_row_major(x, w_gate, d_model, d_ff);
    let up = matvec_row_major(x, w_up, d_model, d_ff);
    let hidden: Vec<f32> = gate_pre
        .iter()
        .zip(up.iter())
        .map(|(&g, &u)| silu(g) * u)
        .collect();
    Ok(matvec_row_major(&hidden, w_down, d_ff, d_model))
}

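/// A simplified Llama-style block: RMSNorm followed by a SwiGLU MLP. A full
/// transformer block would also include the attention sub-layer and
/// residual connections, which are omitted here.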
pub struct LlamaBlock;

impl LlamaBlock {
    pub fn forward(
        input: &[f32],
        norm_weight: &[f32],
        w_gate: &[f32],
        w_up: &[f32],
        w_down: &[f32],
        d_model: usize,
        d_ff: usize,
    ) -> Result<Vec<f32>, ModelError> {
        let x = rms_norm(input, norm_weight, 1e-5)?;
        mlp_swiglu(&x, w_gate, w_up, w_down, d_model, d_ff)
    }
}

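/// A Qwen-style block. With the attention sub-layer omitted, the MLP path
/// (RMSNorm + SwiGLU) is identical to Llama's, so this simply delegates.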
pub struct QwenBlock;

impl QwenBlock {
    pub fn forward(
        input: &[f32],
        norm_weight: &[f32],
        w_gate: &[f32],
        w_up: &[f32],
        w_down: &[f32],
        d_model: usize,
        d_ff: usize,
    ) -> Result<Vec<f32>, ModelError> {
        LlamaBlock::forward(input, norm_weight, w_gate, w_up, w_down, d_model, d_ff)
    }
}

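/// Vector-matrix product `y[i] = sum_j x[j] * w[j * out_dim + i]` for a
/// weight matrix stored row-major with shape `[in_dim, out_dim]`.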
fn matvec_row_major(x: &[f32], w: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    let mut out = vec![0.0; out_dim];
    for i in 0..out_dim {
        let mut acc = 0.0f32;
        for j in 0..in_dim {
            acc += x[j] * w[j * out_dim + i];
        }
        out[i] = acc;
    }
    out
}

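/// SiLU (swish) activation: `x * sigmoid(x) = x / (1 + e^(-x))`.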
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

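/// Numerically stable softmax: subtracting the max before `exp` keeps large
/// logits from overflowing.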
fn softmax(x: &[f32]) -> Vec<f32> {
    let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut exps: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    if sum > 0.0 {
        for e in &mut exps {
            *e /= sum;
        }
    }
    exps
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;
    use safetensors::tensor::{serialize, TensorView};
    use std::collections::BTreeMap;

    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (idx, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() <= tol,
                "idx={} actual={} expected={} tol={}",
                idx,
                a,
                e,
                tol
            );
        }
    }

    #[test]
    fn rms_norm_matches_reference() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let w = [1.0, 1.0, 1.0, 1.0];
        let y = rms_norm(&x, &w, 1e-5).unwrap();
        let expected = [0.36514813, 0.73029625, 1.0954444, 1.4605925];
        assert_close(&y, &expected, 1e-5);
    }

    #[test]
    fn rope_matches_reference() {
        let mut q = [1.0, 0.0, 0.0, 1.0];
        let mut k = q;
        apply_rope(&mut q, &mut k, 1, 1, 4, 10_000.0).unwrap();
        let expected = [0.5403023, 0.84147096, -0.009999833, 0.99995];
        assert_close(&q, &expected, 1e-5);
        assert_close(&k, &expected, 1e-5);
    }

    #[test]
    fn attention_matches_reference() {
        let query = [1.0, 0.0];
        let keys = [1.0, 0.0, 0.0, 1.0];
        let values = [10.0, 0.0, 0.0, 20.0];
        let out = attention_decode(&query, &keys, &values, 2, 1, 2).unwrap();
        let expected = [6.697615, 6.60477];
        assert_close(&out, &expected, 1e-5);
    }

    #[test]
    fn swiglu_and_blocks_produce_expected_shape() {
        let x = [0.5, -1.0];
        let norm = [1.0, 1.0];
        let w_gate = [1.0, 0.0, 0.0, 1.0];
        let w_up = [0.5, 0.0, 0.0, 0.5];
        let w_down = [1.0, 0.0, 0.0, 1.0];

        let y_llama = LlamaBlock::forward(&x, &norm, &w_gate, &w_up, &w_down, 2, 2).unwrap();
        let y_qwen = QwenBlock::forward(&x, &norm, &w_gate, &w_up, &w_down, 2, 2).unwrap();

        assert_eq!(y_llama.len(), 2);
        assert_eq!(y_qwen.len(), 2);
        assert_close(&y_llama, &y_qwen, 1e-6);
    }

    #[test]
    fn loads_weights_from_safetensors() {
        let tensor = [1.0f32, 2.0, 3.0, 4.0];
        let view = TensorView::new(Dtype::F32, vec![2, 2], cast_slice(&tensor)).unwrap();
        let mut map = BTreeMap::new();
        map.insert("w".to_string(), view);
        let bytes = serialize(map, &None).unwrap();

        let weights = ModelWeights::load_safetensors_bytes(&bytes).unwrap();
        let w = weights.get("w").unwrap();
        assert_eq!(w.shape, vec![2, 2]);
        assert_eq!(w.data, tensor);
    }
}