boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
//! Multi-Head Latent Attention (MLA) module
//!
//! DeepSeek-V2-style attention with low-rank KV compression.
//! Compresses the KV cache from O(L * num_heads * head_dim * 2) down to O(L * (kv_lora_rank + rope_head_dim)).
//!
//! Architecture:
//! - Q path: optional low-rank compression (q_down → norm → q_up) or direct projection
//! - KV path: compress → split (c_kv, k_pe) → norm c_kv → decompress → split (k_nope, v)
//! - Decoupled RoPE: applied only to q_pe and k_pe portions
//! - Attention: Q=[q_nope, q_pe], K=[k_nope, k_pe], V=v
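//!
//! # Example
//!
//! A minimal forward-pass sketch. Illustrative only: the import path, the
//! runtime `R`, `client`, `device`, and the input `hidden` are assumptions
//! standing in for whatever numr backend is in use.
//!
//! ```ignore
//! use boostr::nn::attention::mla::{Mla, MlaConfig};
//!
//! // hidden=4096, 32 heads, kv_lora_rank=512, no Q compression, rope dim=64
//! let config = MlaConfig::deepseek_v2(4096, 32, 512, 0, 64, 8192);
//! let mla = Mla::<R>::from_config(&config, &device)?;
//!
//! // hidden: Var<R> with shape [batch, seq_len, 4096]
//! let out = mla.forward(&client, &hidden)?;
//! assert_eq!(out.shape(), hidden.shape());
//! ```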

use crate::error::{Error, Result};
use crate::nn::{Linear, MaybeQuantLinear, RmsNorm, RoPE, VarBuilder};
use crate::ops::RoPEOps;
use crate::ops::impl_generic::attention::mla::scaled_dot_product_attention_impl;
use crate::ops::impl_generic::attention::rope::apply_rope_impl;
use crate::quant::traits::QuantMatmulOps;
use numr::autograd::{Var, var_broadcast_to, var_cat, var_narrow, var_permute, var_reshape};
use numr::dtype::DType;
use numr::ops::{
    BinaryOps, NormalizationOps, ReduceOps, ScalarOps, ShapeOps, TensorOps, TypeConversionOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;

/// MLA configuration
#[derive(Debug, Clone)]
pub struct MlaConfig {
    /// Hidden dimension
    pub hidden_size: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dimension per head for Q/K nope portion
    pub head_dim: usize,
    /// Dimension per head for values (can differ from head_dim)
    pub head_dim_v: usize,
    /// KV compression latent dimension
    pub kv_lora_rank: usize,
    /// Q compression latent dimension (0 = no compression)
    pub q_lora_rank: usize,
    /// Decoupled RoPE dimension
    pub rope_head_dim: usize,
    /// Maximum sequence length
    pub max_seq_len: usize,
    /// RoPE base theta
    pub rope_theta: f32,
    /// Whether to use RMSNorm on compressed representations
    pub use_norm: bool,
    /// RMSNorm epsilon
    pub norm_eps: f32,
}

impl MlaConfig {
    /// Create config with DeepSeek-V2 defaults
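    ///
    /// A small sketch of the derived dimensions:
    ///
    /// ```ignore
    /// // head_dim is derived as hidden_size / num_heads
    /// let cfg = MlaConfig::deepseek_v2(4096, 32, 512, 0, 64, 8192);
    /// assert_eq!(cfg.head_dim, 128);      // 4096 / 32
    /// assert_eq!(cfg.qk_head_dim(), 192); // head_dim + rope_head_dim
    /// assert!(!cfg.q_uses_lora());        // q_lora_rank == 0
    /// ```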
    pub fn deepseek_v2(
        hidden_size: usize,
        num_heads: usize,
        kv_lora_rank: usize,
        q_lora_rank: usize,
        rope_head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        let head_dim = hidden_size / num_heads;
        Self {
            hidden_size,
            num_heads,
            head_dim,
            head_dim_v: head_dim,
            kv_lora_rank,
            q_lora_rank,
            rope_head_dim,
            max_seq_len,
            rope_theta: 10000.0,
            use_norm: true,
            norm_eps: 1e-6,
        }
    }

    pub fn validate(&self) -> Result<()> {
        if self.hidden_size == 0 || self.num_heads == 0 {
            return Err(Error::ModelError {
                reason: "hidden_size and num_heads must be > 0".into(),
            });
        }
        if self.kv_lora_rank == 0 {
            return Err(Error::ModelError {
                reason: "kv_lora_rank must be > 0 for MLA".into(),
            });
        }
        if self.rope_head_dim > self.head_dim {
            return Err(Error::ModelError {
                reason: format!(
                    "rope_head_dim ({}) > head_dim ({})",
                    self.rope_head_dim, self.head_dim
                ),
            });
        }
        Ok(())
    }

    /// Total Q/K dimension per head (nope + pe)
    pub fn qk_head_dim(&self) -> usize {
        self.head_dim + self.rope_head_dim
    }

    /// Whether Q uses low-rank compression
    pub fn q_uses_lora(&self) -> bool {
        self.q_lora_rank > 0
    }
}

/// Multi-Head Latent Attention (MLA) layer from DeepSeek-V2.
///
/// Implements low-rank KV compression with decoupled RoPE:
/// - **Q path**: optional low-rank compression (`q_down` → norm → `q_up`) or direct projection
/// - **KV path**: `kv_compress` → split into `(c_kv, k_pe)` → norm `c_kv` → `kv_decompress` → split into `(k_nope, v)`
/// - **RoPE**: applied separately to `q_pe` and `k_pe` (decoupled from compressed latent)
/// - **Output**: `softmax(Q·K^T / √d) · V` projected through `o_proj`
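///
/// As a rough size comparison: a standard MHA cache stores
/// `2 * num_heads * head_dim` values per token, while MLA caches only
/// `kv_lora_rank + rope_head_dim`. With `num_heads = 32`, `head_dim = 128`,
/// `kv_lora_rank = 512`, and `rope_head_dim = 64`, that is 8192 vs. 576
/// values per token (roughly 14x smaller). Note that this module's `forward`
/// recomputes K and V over the full sequence each call; a decode-time latent
/// cache is not implemented here.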
pub struct Mla<R: Runtime> {
    // Q path
    q_down: Option<MaybeQuantLinear<R>>,
    q_up: MaybeQuantLinear<R>,
    q_norm: Option<RmsNorm<R>>,

    // KV path
    kv_compress: MaybeQuantLinear<R>,
    kv_norm: Option<RmsNorm<R>>,
    kv_decompress: MaybeQuantLinear<R>,

    // Output
    o_proj: MaybeQuantLinear<R>,

    // RoPE
    rope: RoPE<R>,

    // Config
    num_heads: usize,
    head_dim: usize,
    head_dim_v: usize,
    rope_head_dim: usize,
    kv_lora_rank: usize,
    scale: f64,
}

impl<R: Runtime<DType = DType>> Mla<R> {
    /// Create MLA from config with random/zero weights (for testing/training)
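    ///
    /// A hedged sketch (the runtime `R` and `device` are placeholders). All
    /// projection weights start as zeros, so this is only useful as a shape
    /// harness until real weights are loaded or initialized:
    ///
    /// ```ignore
    /// let cfg = MlaConfig::deepseek_v2(4096, 32, 512, 0, 64, 8192);
    /// let mla = Mla::<R>::from_config(&cfg, &device)?;
    /// ```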
    pub fn from_config(config: &MlaConfig, device: &R::Device) -> Result<Self> {
        config.validate()?;

        let h = config.hidden_size;
        let nh = config.num_heads;
        let qk_dim = config.qk_head_dim();
        let dt = DType::F32;

        let (q_down, q_up, q_norm) = if config.q_uses_lora() {
            let q_down = MaybeQuantLinear::Standard(Linear::new(
                Tensor::<R>::zeros(&[config.q_lora_rank, h], dt, device),
                None,
                true,
            ));
            let q_up = MaybeQuantLinear::Standard(Linear::new(
                Tensor::<R>::zeros(&[nh * qk_dim, config.q_lora_rank], dt, device),
                None,
                true,
            ));
            let q_norm = if config.use_norm {
                Some(RmsNorm::new(
                    Tensor::<R>::ones(&[config.q_lora_rank], dt, device),
                    config.norm_eps,
                    true,
                ))
            } else {
                None
            };
            (Some(q_down), q_up, q_norm)
        } else {
            let q_up = MaybeQuantLinear::Standard(Linear::new(
                Tensor::<R>::zeros(&[nh * qk_dim, h], dt, device),
                None,
                true,
            ));
            (None, q_up, None)
        };

        let kv_compress = MaybeQuantLinear::Standard(Linear::new(
            Tensor::<R>::zeros(&[config.kv_lora_rank + config.rope_head_dim, h], dt, device),
            None,
            true,
        ));
        let kv_norm = if config.use_norm {
            Some(RmsNorm::new(
                Tensor::<R>::ones(&[config.kv_lora_rank], dt, device),
                config.norm_eps,
                true,
            ))
        } else {
            None
        };
        let kv_decompress = MaybeQuantLinear::Standard(Linear::new(
            Tensor::<R>::zeros(
                &[
                    nh * (config.head_dim + config.head_dim_v),
                    config.kv_lora_rank,
                ],
                dt,
                device,
            ),
            None,
            true,
        ));

        let o_proj = MaybeQuantLinear::Standard(Linear::new(
            Tensor::<R>::zeros(&[h, nh * config.head_dim_v], dt, device),
            None,
            true,
        ));

        let rope = RoPE::<R>::precompute_freqs(
            config.max_seq_len,
            config.rope_head_dim,
            config.rope_theta,
            None,
            device,
        );

        let scale = 1.0 / (qk_dim as f64).sqrt();

        Ok(Self {
            q_down,
            q_up,
            q_norm,
            kv_compress,
            kv_norm,
            kv_decompress,
            o_proj,
            rope,
            num_heads: nh,
            head_dim: config.head_dim,
            head_dim_v: config.head_dim_v,
            rope_head_dim: config.rope_head_dim,
            kv_lora_rank: config.kv_lora_rank,
            scale,
        })
    }

    /// Load MLA from pretrained weights via VarBuilder
    ///
    /// Weight names follow HuggingFace DeepSeek-V2 conventions:
    /// - `q_a_proj` / `q_b_proj` (Q down/up if q_lora_rank > 0)
    /// - `q_proj` (direct Q if q_lora_rank == 0)
    /// - `q_a_layernorm`
    /// - `kv_a_proj_with_mqa` (KV compression)
    /// - `kv_a_layernorm`
    /// - `kv_b_proj` (KV decompression)
    /// - `o_proj`
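    ///
    /// A loading sketch (illustrative; assumes `vb` currently points at a
    /// decoder layer, so `self_attn` is the attention submodule prefix):
    ///
    /// ```ignore
    /// let mut attn_vb = vb.pp("self_attn");
    /// let mla = Mla::from_varbuilder(&mut attn_vb, &config)?;
    /// ```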
    pub fn from_varbuilder(vb: &mut VarBuilder<R>, config: &MlaConfig) -> Result<Self> {
        config.validate()?;

        let nh = config.num_heads;
        let qk_dim = config.qk_head_dim();

        // Q path
        let (q_down, q_up, q_norm) = if config.q_uses_lora() {
            let q_down = vb.pp("q_a_proj").take_maybe_quant_linear("weight", None)?;
            let q_up = vb.pp("q_b_proj").take_maybe_quant_linear("weight", None)?;

            let q_norm = if config.use_norm {
                let mut qn_vb = vb.pp("q_a_layernorm");
                Some(RmsNorm::new(
                    qn_vb.take_tensor("weight")?,
                    config.norm_eps,
                    false,
                ))
            } else {
                None
            };
            (Some(q_down), q_up, q_norm)
        } else {
            let q_up = vb.pp("q_proj").take_maybe_quant_linear("weight", None)?;
            (None, q_up, None)
        };

        // KV path
        let kv_compress = vb
            .pp("kv_a_proj_with_mqa")
            .take_maybe_quant_linear("weight", None)?;

        let kv_norm = if config.use_norm {
            let mut kvn_vb = vb.pp("kv_a_layernorm");
            Some(RmsNorm::new(
                kvn_vb.take_tensor("weight")?,
                config.norm_eps,
                false,
            ))
        } else {
            None
        };

        let kv_decompress = vb.pp("kv_b_proj").take_maybe_quant_linear("weight", None)?;

        // Output
        let o_proj = vb.pp("o_proj").take_maybe_quant_linear("weight", None)?;

        // RoPE
        let rope = RoPE::<R>::precompute_freqs(
            config.max_seq_len,
            config.rope_head_dim,
            config.rope_theta,
            None,
            vb.device(),
        );

        let scale = 1.0 / (qk_dim as f64).sqrt();

        Ok(Self {
            q_down,
            q_up,
            q_norm,
            kv_compress,
            kv_norm,
            kv_decompress,
            o_proj,
            rope,
            num_heads: nh,
            head_dim: config.head_dim,
            head_dim_v: config.head_dim_v,
            rope_head_dim: config.rope_head_dim,
            kv_lora_rank: config.kv_lora_rank,
            scale,
        })
    }

    /// Forward pass: [B, S, hidden] → [B, S, hidden]
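    ///
    /// Shape flow, as implemented below (`B` = batch, `S` = seq_len, `H` = num_heads):
    ///
    /// ```text
    /// hidden [B, S, hidden]
    ///   Q    → [B, H, S, head_dim + rope_head_dim]   (nope ++ rope halves)
    ///   c_kv → [B, S, kv_lora_rank], shared k_pe [B, S, rope_head_dim]
    ///   K    → [B, H, S, head_dim + rope_head_dim]
    ///   V    → [B, H, S, head_dim_v]
    ///   attn → [B, S, H * head_dim_v] → o_proj → [B, S, hidden]
    /// ```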
    pub fn forward<C>(&self, client: &C, hidden: &Var<R>) -> Result<Var<R>>
    where
        C: RuntimeClient<R>
            + TensorOps<R>
            + ScalarOps<R>
            + ReduceOps<R>
            + NormalizationOps<R>
            + ShapeOps<R>
            + BinaryOps<R>
            + TypeConversionOps<R>
            + QuantMatmulOps<R>
            + RoPEOps<R>,
        R::Client: TensorOps<R> + ScalarOps<R>,
    {
        let shape = hidden.shape().to_vec();
        let batch = shape[0];
        let seq_len = shape[1];
        let qk_dim = self.head_dim + self.rope_head_dim;

        // === Q path ===
        let q = if let Some(q_down) = &self.q_down {
            let q_latent = q_down.forward(client, hidden)?;
            let q_latent = if let Some(norm) = &self.q_norm {
                norm.forward(client, &q_latent)?
            } else {
                q_latent
            };
            self.q_up.forward(client, &q_latent)?
        } else {
            self.q_up.forward(client, hidden)?
        };

        // [B, S, num_heads * qk_dim] → [B, S, H, qk_dim] → [B, H, S, qk_dim]
        let q = var_reshape(&q, &[batch, seq_len, self.num_heads, qk_dim]).map_err(Error::Numr)?;
        let q = var_permute(&q, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let q = var_contiguous(&q);

        // Split Q into nope and pe
        let q_nope = var_narrow(&q, 3, 0, self.head_dim).map_err(Error::Numr)?;
        let q_nope = var_contiguous(&q_nope);
        let q_pe = var_narrow(&q, 3, self.head_dim, self.rope_head_dim).map_err(Error::Numr)?;
        let q_pe = var_contiguous(&q_pe);

        // === KV path ===
        // Compress: [B, S, hidden] → [B, S, kv_lora_rank + rope_head_dim]
        let kv_compressed = self.kv_compress.forward(client, hidden)?;

        // Split: c_kv [B, S, kv_lora_rank], k_pe_raw [B, S, rope_head_dim]
        let c_kv = var_narrow(&kv_compressed, 2, 0, self.kv_lora_rank).map_err(Error::Numr)?;
        let c_kv = var_contiguous(&c_kv);
        let k_pe_raw = var_narrow(&kv_compressed, 2, self.kv_lora_rank, self.rope_head_dim)
            .map_err(Error::Numr)?;
        let k_pe_raw = var_contiguous(&k_pe_raw);

        // Normalize c_kv
        let c_kv = if let Some(norm) = &self.kv_norm {
            norm.forward(client, &c_kv)?
        } else {
            c_kv
        };

        // Decompress: [B, S, kv_lora_rank] → [B, S, num_heads * (head_dim + head_dim_v)]
        let kv = self.kv_decompress.forward(client, &c_kv)?;
        let kv = var_reshape(
            &kv,
            &[
                batch,
                seq_len,
                self.num_heads,
                self.head_dim + self.head_dim_v,
            ],
        )
        .map_err(Error::Numr)?;
        // → [B, H, S, head_dim + head_dim_v]
        let kv = var_permute(&kv, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let kv = var_contiguous(&kv);

        // Split K_nope and V
        let k_nope = var_narrow(&kv, 3, 0, self.head_dim).map_err(Error::Numr)?;
        let k_nope = var_contiguous(&k_nope);
        let v = var_narrow(&kv, 3, self.head_dim, self.head_dim_v).map_err(Error::Numr)?;
        let v = var_contiguous(&v);

        // K_pe: [B, S, rope_head_dim] → [B, 1, S, rope_head_dim] → [B, H, S, rope_head_dim]
        let k_pe = var_reshape(&k_pe_raw, &[batch, 1, seq_len, self.rope_head_dim])
            .map_err(Error::Numr)?;
        let k_pe = var_broadcast_to(&k_pe, &[batch, self.num_heads, seq_len, self.rope_head_dim])
            .map_err(Error::Numr)?;
        let k_pe = var_contiguous(&k_pe);

        // Apply RoPE to q_pe and k_pe (decoupled)
        let q_pe = apply_rope_impl(client, &q_pe, self.rope.cos_cache(), self.rope.sin_cache())?;
        let k_pe = apply_rope_impl(client, &k_pe, self.rope.cos_cache(), self.rope.sin_cache())?;

        // Concatenate: Q = [q_nope, q_pe], K = [k_nope, k_pe]
        let q = var_cat(&[&q_nope, &q_pe], 3, client).map_err(Error::Numr)?;
        let k = var_cat(&[&k_nope, &k_pe], 3, client).map_err(Error::Numr)?;

        // Attention: Q,K [B, H, S, qk_dim], V [B, H, S, head_dim_v]
        let attn_out = scaled_dot_product_attention_impl(client, &q, &k, &v, self.scale, true)?;

        // [B, H, S, head_dim_v] → [B, S, H, head_dim_v] → [B, S, H*head_dim_v]
        let attn_out = var_permute(&attn_out, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let attn_out = var_contiguous(&attn_out);
        let attn_out = var_reshape(
            &attn_out,
            &[batch, seq_len, self.num_heads * self.head_dim_v],
        )
        .map_err(Error::Numr)?;

        // Output projection
        self.o_proj.forward(client, &attn_out)
    }
}

/// Make a Var contiguous (copies data when the underlying layout is non-contiguous).
fn var_contiguous<R: Runtime>(v: &Var<R>) -> Var<R> {
    Var::new(v.tensor().contiguous(), v.requires_grad())
}