//! # llama-models
//!
//! Foundational model blocks for llama.rs:
//! - RMSNorm
//! - RoPE
//! - Attention (scaled dot-product, single-step decode form)
//! - MLP (SwiGLU)
//! - Safetensors weight loading

use std::collections::HashMap;
use std::fs;
use std::path::Path;

use bytemuck::cast_slice;
use safetensors::{Dtype, SafeTensors};
/// Errors for model operations and weight loading.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    /// A vector/matrix argument had an unexpected length; the message carries
    /// the expected-vs-actual description.
    #[error("shape mismatch: {0}")]
    Shape(String),
    /// A weight name was looked up in [`ModelWeights`] but was never loaded.
    #[error("missing weight: {0}")]
    MissingWeight(String),
    /// A safetensors entry used a dtype other than the F32 this crate loads.
    #[error("invalid dtype for {name}: expected F32, got {dtype:?}")]
    InvalidDtype { name: String, dtype: Dtype },
    /// Filesystem failure while reading a weights file.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Parse/validation failure reported by the safetensors crate.
    #[error("safetensors error: {0}")]
    SafeTensors(#[from] safetensors::SafeTensorError),
}

/// Lightweight tensor holder for loaded model weights.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Dimensions as declared in the safetensors header.
    pub shape: Vec<usize>,
    /// Flattened element data, copied verbatim from the file
    /// (row-major layout per the safetensors convention — TODO confirm
    /// against the producing exporter).
    pub data: Vec<f32>,
}

/// Named weight storage loaded from safetensors.
#[derive(Debug, Clone, Default)]
pub struct ModelWeights {
    // Map from tensor name (as written in the safetensors header) to its data.
    tensors: HashMap<String, Tensor>,
}

45impl ModelWeights {
46    pub fn load_safetensors_bytes(bytes: &[u8]) -> Result<Self, ModelError> {
47        let st = SafeTensors::deserialize(bytes)?;
48        let mut tensors = HashMap::new();
49
50        for name in st.names() {
51            let view = st.tensor(name)?;
52            if view.dtype() != Dtype::F32 {
53                return Err(ModelError::InvalidDtype {
54                    name: name.to_string(),
55                    dtype: view.dtype(),
56                });
57            }
58
59            let shape = view.shape().to_vec();
60            let raw: &[f32] = cast_slice(view.data());
61            tensors.insert(
62                name.to_string(),
63                Tensor {
64                    shape,
65                    data: raw.to_vec(),
66                },
67            );
68        }
69
70        Ok(Self { tensors })
71    }
72
73    pub fn load_safetensors_file(path: impl AsRef<Path>) -> Result<Self, ModelError> {
74        let bytes = fs::read(path)?;
75        Self::load_safetensors_bytes(&bytes)
76    }
77
78    pub fn get(&self, name: &str) -> Result<&Tensor, ModelError> {
79        self.tensors
80            .get(name)
81            .ok_or_else(|| ModelError::MissingWeight(name.to_string()))
82    }
83}
84
85/// Root mean square normalization.
86pub fn rms_norm(input: &[f32], weight: &[f32], eps: f32) -> Result<Vec<f32>, ModelError> {
87    if input.len() != weight.len() {
88        return Err(ModelError::Shape(format!(
89            "rms_norm input/weight mismatch: {} != {}",
90            input.len(),
91            weight.len()
92        )));
93    }
94    if input.is_empty() {
95        return Ok(Vec::new());
96    }
97    let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
98    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
99    Ok(input
100        .iter()
101        .zip(weight.iter())
102        .map(|(&x, &w)| x * inv_rms * w)
103        .collect())
104}
105
106/// Apply rotary positional embeddings in-place to query and key vectors.
107///
108/// `q` and `k` are flattened `[n_heads * head_dim]`.
109pub fn apply_rope(
110    q: &mut [f32],
111    k: &mut [f32],
112    position: usize,
113    n_heads: usize,
114    head_dim: usize,
115    base: f32,
116) -> Result<(), ModelError> {
117    let expected = n_heads * head_dim;
118    if q.len() != expected || k.len() != expected {
119        return Err(ModelError::Shape(format!(
120            "rope q/k mismatch: expected {}, got q={}, k={}",
121            expected,
122            q.len(),
123            k.len()
124        )));
125    }
126    if !head_dim.is_multiple_of(2) {
127        return Err(ModelError::Shape(format!(
128            "head_dim must be even for RoPE, got {}",
129            head_dim
130        )));
131    }
132
133    for h in 0..n_heads {
134        let offset = h * head_dim;
135        for i in (0..head_dim).step_by(2) {
136            let theta = (position as f32) / base.powf(i as f32 / head_dim as f32);
137            let (sin_t, cos_t) = theta.sin_cos();
138
139            let q0 = q[offset + i];
140            let q1 = q[offset + i + 1];
141            q[offset + i] = q0 * cos_t - q1 * sin_t;
142            q[offset + i + 1] = q0 * sin_t + q1 * cos_t;
143
144            let k0 = k[offset + i];
145            let k1 = k[offset + i + 1];
146            k[offset + i] = k0 * cos_t - k1 * sin_t;
147            k[offset + i + 1] = k0 * sin_t + k1 * cos_t;
148        }
149    }
150    Ok(())
151}
152
153/// Single-step decode attention:
154/// - `query`: `[n_heads * head_dim]`
155/// - `keys`/`values`: `[seq_len * n_heads * head_dim]`
156/// - output: `[n_heads * head_dim]`
157pub fn attention_decode(
158    query: &[f32],
159    keys: &[f32],
160    values: &[f32],
161    seq_len: usize,
162    n_heads: usize,
163    head_dim: usize,
164) -> Result<Vec<f32>, ModelError> {
165    let q_expected = n_heads * head_dim;
166    let kv_expected = seq_len * n_heads * head_dim;
167    if query.len() != q_expected {
168        return Err(ModelError::Shape(format!(
169            "query shape mismatch: expected {}, got {}",
170            q_expected,
171            query.len()
172        )));
173    }
174    if keys.len() != kv_expected || values.len() != kv_expected {
175        return Err(ModelError::Shape(format!(
176            "kv shape mismatch: expected {}, got k={}, v={}",
177            kv_expected,
178            keys.len(),
179            values.len()
180        )));
181    }
182
183    let mut out = vec![0.0; q_expected];
184    let scale = 1.0 / (head_dim as f32).sqrt();
185
186    for h in 0..n_heads {
187        let qh = &query[h * head_dim..(h + 1) * head_dim];
188
189        let mut scores = vec![0.0f32; seq_len];
190        for (t, score) in scores.iter_mut().enumerate().take(seq_len) {
191            let kh_off = (t * n_heads + h) * head_dim;
192            let kh = &keys[kh_off..kh_off + head_dim];
193            let dot = qh.iter().zip(kh.iter()).map(|(&a, &b)| a * b).sum::<f32>();
194            *score = dot * scale;
195        }
196
197        let probs = softmax(&scores);
198        for (t, &p) in probs.iter().enumerate() {
199            let vh_off = (t * n_heads + h) * head_dim;
200            let vh = &values[vh_off..vh_off + head_dim];
201            let out_h = &mut out[h * head_dim..(h + 1) * head_dim];
202            for i in 0..head_dim {
203                out_h[i] += p * vh[i];
204            }
205        }
206    }
207
208    Ok(out)
209}
210
211/// SwiGLU MLP:
212/// - gate = silu(x * W_gate)
213/// - up = x * W_up
214/// - hidden = gate .* up
215/// - out = hidden * W_down
216pub fn mlp_swiglu(
217    x: &[f32],
218    w_gate: &[f32], // [d_model, d_ff]
219    w_up: &[f32],   // [d_model, d_ff]
220    w_down: &[f32], // [d_ff, d_model]
221    d_model: usize,
222    d_ff: usize,
223) -> Result<Vec<f32>, ModelError> {
224    if x.len() != d_model {
225        return Err(ModelError::Shape(format!(
226            "x shape mismatch: expected {}, got {}",
227            d_model,
228            x.len()
229        )));
230    }
231    if w_gate.len() != d_model * d_ff || w_up.len() != d_model * d_ff {
232        return Err(ModelError::Shape("w_gate/w_up shape mismatch".to_string()));
233    }
234    if w_down.len() != d_ff * d_model {
235        return Err(ModelError::Shape("w_down shape mismatch".to_string()));
236    }
237
238    let gate_pre = matvec_row_major(x, w_gate, d_model, d_ff);
239    let up = matvec_row_major(x, w_up, d_model, d_ff);
240    let hidden: Vec<f32> = gate_pre
241        .iter()
242        .zip(up.iter())
243        .map(|(&g, &u)| silu(g) * u)
244        .collect();
245    Ok(matvec_row_major(&hidden, w_down, d_ff, d_model))
246}
247
248/// Minimal Llama block composition.
249pub struct LlamaBlock;
250
251impl LlamaBlock {
252    pub fn forward(
253        input: &[f32],
254        norm_weight: &[f32],
255        w_gate: &[f32],
256        w_up: &[f32],
257        w_down: &[f32],
258        d_model: usize,
259        d_ff: usize,
260    ) -> Result<Vec<f32>, ModelError> {
261        let x = rms_norm(input, norm_weight, 1e-5)?;
262        mlp_swiglu(&x, w_gate, w_up, w_down, d_model, d_ff)
263    }
264}
265
266/// Minimal Qwen block composition (same block primitives at this stage).
267pub struct QwenBlock;
268
269impl QwenBlock {
270    pub fn forward(
271        input: &[f32],
272        norm_weight: &[f32],
273        w_gate: &[f32],
274        w_up: &[f32],
275        w_down: &[f32],
276        d_model: usize,
277        d_ff: usize,
278    ) -> Result<Vec<f32>, ModelError> {
279        LlamaBlock::forward(input, norm_weight, w_gate, w_up, w_down, d_model, d_ff)
280    }
281}
282
/// Dense vector-matrix product `x * w`, with `w` stored row-major as
/// `[in_dim, out_dim]`; returns `out_dim` elements.
///
/// Walks `w` one row at a time so the inner loop is a slice-to-slice zip —
/// no per-element bounds checks, and LLVM can vectorize it — while keeping
/// the same per-output accumulation order (ascending `j`) as the naive
/// column dot product, so results are bitwise identical.
fn matvec_row_major(x: &[f32], w: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; out_dim];
    if out_dim == 0 {
        // `chunks_exact(0)` panics; an empty result needs no work.
        return out;
    }
    for (&xj, row) in x.iter().zip(w.chunks_exact(out_dim)).take(in_dim) {
        for (acc, &wji) in out.iter_mut().zip(row) {
            *acc += xj * wji;
        }
    }
    out
}

/// SiLU (swish) activation, `x * sigmoid(x)`, computed directly as
/// `x / (1 + e^(-x))`.
fn silu(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    x / denom
}

/// Numerically stable softmax: shifts by the maximum before exponentiating
/// so large logits cannot overflow. An empty slice maps to an empty vector.
fn softmax(x: &[f32]) -> Vec<f32> {
    let mut max = f32::NEG_INFINITY;
    for &v in x {
        max = f32::max(max, v);
    }

    let mut sum = 0.0f32;
    let mut out = Vec::with_capacity(x.len());
    for &v in x {
        let e = (v - max).exp();
        sum += e;
        out.push(e);
    }

    // Guard against dividing by zero (empty input leaves sum at 0).
    if sum > 0.0 {
        for e in &mut out {
            *e /= sum;
        }
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::{serialize, TensorView};
    use std::collections::BTreeMap;

    /// Assert two float slices are equal element-wise within `tol`,
    /// reporting the first offending index on failure.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (idx, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() <= tol,
                "idx={} actual={} expected={} tol={}",
                idx,
                a,
                e,
                tol
            );
        }
    }

    #[test]
    fn rms_norm_matches_reference() {
        // Reference values computed from numpy for eps=1e-5.
        let x = [1.0, 2.0, 3.0, 4.0];
        let w = [1.0, 1.0, 1.0, 1.0];
        let y = rms_norm(&x, &w, 1e-5).unwrap();
        let expected = [0.36514813, 0.73029625, 1.0954444, 1.4605925];
        assert_close(&y, &expected, 1e-5);
    }

    #[test]
    fn rope_matches_reference() {
        // Single head, head_dim=4, position=1: pair 0 rotates by 1 rad,
        // pair 1 by 1/base^(2/4) = 0.01 rad.
        let mut q = [1.0, 0.0, 0.0, 1.0];
        let mut k = q;
        apply_rope(&mut q, &mut k, 1, 1, 4, 10_000.0).unwrap();
        let expected = [0.5403023, 0.84147096, -0.009999833, 0.99995];
        assert_close(&q, &expected, 1e-5);
        // q and k start identical, so they must rotate identically.
        assert_close(&k, &expected, 1e-5);
    }

    #[test]
    fn attention_matches_reference() {
        let query = [1.0, 0.0];
        let keys = [1.0, 0.0, 0.0, 1.0]; // seq=2, heads=1, dim=2
        let values = [10.0, 0.0, 0.0, 20.0];
        let out = attention_decode(&query, &keys, &values, 2, 1, 2).unwrap();
        // softmax([1/sqrt(2), 0]) blended over the two value rows.
        let expected = [6.697615, 6.60477];
        assert_close(&out, &expected, 1e-5);
    }

    #[test]
    fn swiglu_and_blocks_produce_expected_shape() {
        let x = [0.5, -1.0];
        let norm = [1.0, 1.0];
        let w_gate = [1.0, 0.0, 0.0, 1.0];
        let w_up = [0.5, 0.0, 0.0, 0.5];
        let w_down = [1.0, 0.0, 0.0, 1.0];

        let y_llama = LlamaBlock::forward(&x, &norm, &w_gate, &w_up, &w_down, 2, 2).unwrap();
        let y_qwen = QwenBlock::forward(&x, &norm, &w_gate, &w_up, &w_down, 2, 2).unwrap();

        assert_eq!(y_llama.len(), 2);
        assert_eq!(y_qwen.len(), 2);
        // QwenBlock delegates to LlamaBlock, so outputs must match exactly.
        assert_close(&y_llama, &y_qwen, 1e-6);
    }

    #[test]
    fn loads_weights_from_safetensors() {
        // Round-trip a 2x2 F32 tensor: serialize -> load -> get.
        let tensor = [1.0f32, 2.0, 3.0, 4.0];
        let view = TensorView::new(Dtype::F32, vec![2, 2], cast_slice(&tensor)).unwrap();
        let mut map = BTreeMap::new();
        map.insert("w".to_string(), view);
        let bytes = serialize(map, &None).unwrap();

        let weights = ModelWeights::load_safetensors_bytes(&bytes).unwrap();
        let w = weights.get("w").unwrap();
        assert_eq!(w.shape, vec![2, 2]);
        assert_eq!(w.data, tensor);
    }
}