llama-gguf 0.14.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
//! Gated DeltaNet (linear attention with delta rule) for Qwen3Next recurrent layers.
//!
//! Implements the autoregressive path of the delta rule following llama.cpp's
//! `build_delta_net_autoregressive` and `build_qwen3next_linear_attn`.
//!
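//! The state update per value head, where S is stored as a [head_v_dim, head_k_dim]
//! matrix and the delta term uses the already-decayed state S' = S_{t-1} * exp(gate):
//!   S_t = S' + beta * (v - S' @ k) ⊗ k^T
//!   output = S_t @ q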
//!
//! Where gate = softplus(alpha + dt_bias) * ssm_a (negative → decay).

use crate::backend::Backend;
use crate::tensor::{DType, Tensor};

use super::error::ModelResult;
use super::layers::{Linear, RMSNorm};
use super::mamba::{MambaConfig, MambaState};

/// Configuration for a DeltaNet layer, derived from GGUF SSM metadata.
#[derive(Debug, Clone)]
pub struct DeltaNetConfig {
    /// Width of the gate (`z`) projection and of the activation fed to `ssm_out`.
    /// `forward` indexes it per value head, so it is expected to equal
    /// `num_v_heads * head_v_dim`.
    pub d_inner: usize,
    /// Presumably the GGUF SSM state size; not read by the DeltaNet code in this file.
    pub d_state: usize,
    /// Number of value heads; each owns one recurrent state matrix.
    pub num_v_heads: usize,
    /// Number of query/key heads. When it differs from `num_v_heads`, Q and K are
    /// tiled up to `num_v_heads` heads.
    pub num_k_heads: usize,
    /// Per-head value dimension (rows of each state matrix).
    pub head_v_dim: usize,
    /// Per-head query/key dimension (columns of each state matrix).
    pub head_k_dim: usize,
    /// Number of taps of the causal depthwise convolution; the conv state keeps
    /// `conv_kernel - 1` past inputs.
    pub conv_kernel: usize,
    /// Width of the fused QKV projection, which is also the conv channel count.
    /// Q and K each take `num_k_heads * head_k_dim` and the remainder is V.
    pub qkv_dim: usize,
}

/// Beta/alpha projection: combined or separate (Qwen3.5 uses separate tensors).
///
/// Beta goes through a sigmoid and scales the delta-rule write into the state.
/// Alpha (plus `ssm_dt_bias`) feeds the decay gate.
pub enum BetaAlphaProjection {
    /// Combined beta+alpha projection (Qwen3Next DeltaNet): [hidden_size, 2 * num_v_heads]
    ///
    /// The output is grouped per key head: `num_v_heads / num_k_heads` beta values
    /// followed by the same number of alpha values.
    Combined(Linear),
    /// Separate beta and alpha projections (Qwen3.5): each [hidden_size, num_v_heads]
    Separate { beta: Linear, alpha: Linear },
}

/// Gated DeltaNet layer for recurrent (non-attention) layers.
///
/// Holds the projection weights and per-head parameters. The recurrent memory lives
/// separately in a `DeltaNetState` that callers pass to `forward`.
pub struct DeltaNetLayer {
    pub config: DeltaNetConfig,
    /// Combined QKV projection [hidden_size, qkv_dim]
    pub attn_qkv: Linear,
    /// Output gate projection [hidden_size, d_inner]; its output `z` multiplies the
    /// normalized head outputs after a SiLU.
    pub attn_gate: Linear,
    /// Beta + Alpha projection (combined or separate)
    pub ssm_ba: BetaAlphaProjection,
    /// 1D convolution kernel [conv_kernel, qkv_dim]
    /// Read as a flat slice indexed `data[channel * conv_kernel + tap]`.
    pub ssm_conv1d_weight: Tensor,
    /// Decay multiplier per value head [num_v_heads] (negative values → state decays).
    /// The gate is `softplus(alpha + dt_bias) * ssm_a`, and the state is scaled by `exp(gate)`.
    pub ssm_a: Tensor,
    /// Decay bias per value head [num_v_heads], added to alpha before the softplus.
    pub ssm_dt_bias: Tensor,
    /// Per-head output RMS normalization [head_v_dim], applied to each value head
    /// before the `silu(z)` gating.
    pub ssm_norm: RMSNorm,
    /// Output projection [d_inner, hidden_size]
    pub ssm_out: Linear,
}

impl std::fmt::Debug for DeltaNetLayer {
    /// Print only the layer configuration; the weight fields are intentionally omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("DeltaNetLayer");
        builder.field("config", &self.config);
        builder.finish()
    }
}

/// Per-layer recurrent state for DeltaNet.
#[derive(Debug, Clone)]
pub struct DeltaNetState {
    /// Convolution history buffer: last (kernel_size - 1) QKV vectors, oldest first.
    /// Layout: [(kernel_size - 1), qkv_dim] stored as [(kernel_size-1) * qkv_dim]
    /// Each decode step shifts the rows left by one and writes the newest input last.
    pub conv_state: Vec<f32>,
    /// SSM state matrices: one [head_v_dim, head_k_dim] per value head, row-major.
    /// Flat: [num_v_heads * head_v_dim * head_k_dim]
    pub ssm_state: Vec<f32>,
}

/// Per-layer recurrent state: either DeltaNet (Qwen3Next) or Mamba.
///
/// The variant is chosen from the model's `RecurrentConfig` when the state is created.
#[derive(Debug, Clone)]
pub enum RecurrentLayerState {
    DeltaNet(DeltaNetState),
    Mamba(MambaState),
}

impl RecurrentLayerState {
    pub fn reset(&mut self) {
        match self {
            Self::DeltaNet(ds) => {
                ds.conv_state.fill(0.0);
                ds.ssm_state.fill(0.0);
            }
            Self::Mamba(ms) => ms.reset(),
        }
    }
}

/// Configuration for recurrent layers (DeltaNet or Mamba).
///
/// Used by `RecurrentState::new` to size the per-layer buffers; every recurrent layer
/// of a model shares the same config.
#[derive(Debug, Clone)]
pub enum RecurrentConfig {
    DeltaNet(DeltaNetConfig),
    Mamba(MambaConfig),
}

/// Recurrent state for all layers (Qwen3Next DeltaNet or Mamba).
#[derive(Debug, Clone)]
pub struct RecurrentState {
    /// One entry per layer index. Entries are `None` for layers not flagged as recurrent
    /// (presumably the regular attention layers — confirm with the model builder).
    pub states: Vec<Option<RecurrentLayerState>>,
}

impl RecurrentState {
    /// Build zero-initialized recurrent state for `num_layers` layers.
    ///
    /// Layer `i` receives a state only if `is_recurrent[i]` is `true`. Layers flagged
    /// `false`, or past the end of `is_recurrent`, get `None`. DeltaNet buffers are sized
    /// from the config: `(conv_kernel - 1) * qkv_dim` for the conv history and
    /// `num_v_heads * head_v_dim * head_k_dim` for the SSM matrices. Mamba state comes
    /// from `MambaState::new`.
    pub fn new(num_layers: usize, is_recurrent: &[bool], config: &RecurrentConfig) -> Self {
        let make_state = || match config {
            RecurrentConfig::DeltaNet(c) => RecurrentLayerState::DeltaNet(DeltaNetState {
                conv_state: vec![0.0; (c.conv_kernel - 1) * c.qkv_dim],
                ssm_state: vec![0.0; c.num_v_heads * c.head_v_dim * c.head_k_dim],
            }),
            RecurrentConfig::Mamba(c) => RecurrentLayerState::Mamba(MambaState::new(c)),
        };

        let states = (0..num_layers)
            .map(|layer| {
                let recurrent = is_recurrent.get(layer).copied().unwrap_or(false);
                recurrent.then(make_state)
            })
            .collect();

        Self { states }
    }

    /// Zero the state of every recurrent layer; non-recurrent (`None`) slots are skipped.
    pub fn reset(&mut self) {
        self.states
            .iter_mut()
            .flatten()
            .for_each(RecurrentLayerState::reset);
    }
}

impl DeltaNetLayer {
    /// Forward pass for a single token (autoregressive decode).
    ///
    /// The pass runs these steps in order:
    /// 1. Project to QKV, gate (`z`) and beta/alpha.
    /// 2. Run the causal depthwise conv followed by SiLU.
    /// 3. L2-normalize Q and K per key head; Q is also scaled by `1/sqrt(head_k_dim)`.
    /// 4. Apply one delta-rule update per value head on `state.ssm_state`.
    /// 5. Apply the per-head gated RMS norm.
    /// 6. Project back with `ssm_out`.
    ///
    /// * `x` — input activations; the last dimension of its shape is used as the
    ///   output width (`hidden_size`).
    /// * `state` — this layer's recurrent state; `conv_state` and `ssm_state` are
    ///   updated in place.
    /// * `backend` — used for the linear projections.
    ///
    /// Returns: [hidden_size]
    ///
    /// # Errors
    /// Propagates errors from the linear projections and the tensor `as_f32` /
    /// `from_f32` conversions.
    ///
    /// # Panics
    /// Panics if tensors or state buffers are shorter than `self.config` implies.
    pub fn forward(
        &self,
        x: &Tensor,
        state: &mut DeltaNetState,
        backend: &dyn Backend,
    ) -> ModelResult<Tensor> {
        let cfg = &self.config;
        let hidden_size = x.shape().last().copied().unwrap_or(0);
        let head_k = cfg.head_k_dim;
        let head_v = cfg.head_v_dim;

        // 1. Project to QKV space and gate (z)
        let mut qkv = Tensor::zeros(vec![cfg.qkv_dim], DType::F32);
        self.attn_qkv.forward(x, &mut qkv, backend)?;

        let mut z_raw = Tensor::zeros(vec![cfg.d_inner], DType::F32);
        self.attn_gate.forward(x, &mut z_raw, backend)?;

        // 2. Project to beta (after sigmoid) and alpha (raw), one value per value head
        let mut beta = vec![0.0f32; cfg.num_v_heads];
        let mut alpha = vec![0.0f32; cfg.num_v_heads];

        match &self.ssm_ba {
            BetaAlphaProjection::Combined(ba_proj) => {
                let mut ba_raw = Tensor::zeros(vec![cfg.num_v_heads * 2], DType::F32);
                ba_proj.forward(x, &mut ba_raw, backend)?;
                let ba_data = ba_raw.as_f32()?;

                // Grouped per key head: [beta_0..beta_{r-1}, alpha_0..alpha_{r-1}],
                // with r = num_v_heads / num_k_heads.
                let kv_ratio = cfg.num_v_heads / cfg.num_k_heads.max(1);
                let ba_per_group = 2 * kv_ratio;

                for kh in 0..cfg.num_k_heads {
                    let group_offset = kh * ba_per_group;
                    for r in 0..kv_ratio {
                        let vh = kh * kv_ratio + r;
                        beta[vh] = sigmoid(ba_data[group_offset + r]);
                        alpha[vh] = ba_data[group_offset + kv_ratio + r];
                    }
                }
            }
            BetaAlphaProjection::Separate {
                beta: beta_proj,
                alpha: alpha_proj,
            } => {
                let mut beta_raw = Tensor::zeros(vec![cfg.num_v_heads], DType::F32);
                let mut alpha_raw = Tensor::zeros(vec![cfg.num_v_heads], DType::F32);
                beta_proj.forward(x, &mut beta_raw, backend)?;
                alpha_proj.forward(x, &mut alpha_raw, backend)?;
                let beta_data = beta_raw.as_f32()?;
                let alpha_data = alpha_raw.as_f32()?;
                for (h, (b, a)) in beta.iter_mut().zip(alpha.iter_mut()).enumerate() {
                    *b = sigmoid(beta_data[h]);
                    *a = alpha_data[h];
                }
            }
        }

        // 3. Compute gate (decay): gate = softplus(alpha + dt_bias) * ssm_a
        let ssm_a_data = self.ssm_a.as_f32()?;
        let dt_bias_data = self.ssm_dt_bias.as_f32()?;
        let gate: Vec<f32> = (0..cfg.num_v_heads)
            .map(|h| softplus(alpha[h] + dt_bias_data[h]) * ssm_a_data[h])
            .collect();

        // 4. Causal 1D convolution on QKV (reads the projection output directly, no copy)
        let mut conv_silu = self.apply_conv1d(qkv.as_f32()?, state)?;

        // 5. Apply SiLU to the convolution output, in place
        for v in conv_silu.iter_mut() {
            *v = silu(*v);
        }

        // 6. Split into Q, K, V and apply L2 normalization
        let q_dim = cfg.num_k_heads * head_k;
        let k_dim = cfg.num_k_heads * head_k;

        let (q_raw, rest) = conv_silu.split_at_mut(q_dim);
        let (k_raw, v_raw) = rest.split_at_mut(k_dim);

        let l2_eps = 1e-6_f32;

        // L2-normalize Q and K per head
        for h in 0..cfg.num_k_heads {
            let offset = h * head_k;
            l2_normalize_inplace(&mut q_raw[offset..offset + head_k], l2_eps);
            l2_normalize_inplace(&mut k_raw[offset..offset + head_k], l2_eps);
        }

        // Scale Q by 1/sqrt(head_k_dim) as in llama.cpp's build_delta_net_autoregressive
        let q_scale = 1.0 / (head_k as f32).sqrt();
        for q in q_raw.iter_mut() {
            *q *= q_scale;
        }

        // 7. Tile (not interleave) Q and K heads up to num_v_heads when the counts differ;
        //    otherwise borrow them as-is.
        let kv_ratio = cfg.num_v_heads / cfg.num_k_heads.max(1);
        let tiled_q: Vec<f32>;
        let tiled_k: Vec<f32>;
        let (q_all, k_all): (&[f32], &[f32]) = if cfg.num_k_heads != cfg.num_v_heads {
            tiled_q = repeat_tile(q_raw, cfg.num_k_heads, head_k, kv_ratio);
            tiled_k = repeat_tile(k_raw, cfg.num_k_heads, head_k, kv_ratio);
            (tiled_q.as_slice(), tiled_k.as_slice())
        } else {
            (&*q_raw, &*k_raw)
        };

        // 8. Delta rule update per value head. Each state matrix is row-major
        //    [head_v_dim, head_k_dim].
        let s_len = head_v * head_k;
        let mut output = vec![0.0f32; cfg.d_inner];
        // Scratch for delta = beta * (v - S @ k), allocated once and reused for every head.
        let mut delta = vec![0.0f32; head_v];

        for vh in 0..cfg.num_v_heads {
            let s = &mut state.ssm_state[vh * s_len..(vh + 1) * s_len];
            let k_h = &k_all[vh * head_k..(vh + 1) * head_k];
            let q_h = &q_all[vh * head_k..(vh + 1) * head_k];
            let v_h = &v_raw[vh * head_v..(vh + 1) * head_v];
            let o_h = &mut output[vh * head_v..(vh + 1) * head_v];

            // Decay state: S *= exp(gate[vh]), with the factor capped at 1e10
            let decay = gate[vh].exp().min(1e10);
            for sv in s.iter_mut() {
                *sv *= decay;
            }

            // delta = (v - S @ k) * beta[vh], computed from the decayed state
            let b = beta[vh];
            for (vi, d) in delta.iter_mut().enumerate() {
                let row = &s[vi * head_k..(vi + 1) * head_k];
                *d = (v_h[vi] - dot_f32(row, k_h)) * b;
            }

            // S += delta ⊗ k^T, then o = S @ q. Output row `vi` depends only on the
            // updated state row `vi`, so both are done in one pass per row.
            for (vi, (&d, o)) in delta.iter().zip(o_h.iter_mut()).enumerate() {
                let row = &mut s[vi * head_k..(vi + 1) * head_k];
                for (sv, &kv) in row.iter_mut().zip(k_h) {
                    *sv += d * kv;
                }
                *o = dot_f32(row, q_h);
            }
        }

        // 9. Gated normalization: result = rms_norm(output) * silu(z), per value head
        let norm_w = self.ssm_norm.weight.as_f32()?;
        let norm_eps = self.ssm_norm.eps;
        let z_data = z_raw.as_f32()?;

        for vh in 0..cfg.num_v_heads {
            let offset = vh * head_v;
            let head = &mut output[offset..offset + head_v];
            let ss: f32 = head.iter().map(|v| v * v).sum::<f32>() / head_v as f32;
            let rms = (ss + norm_eps).sqrt();
            for (d, o) in head.iter_mut().enumerate() {
                let normed = *o / rms * norm_w[d % norm_w.len()];
                *o = normed * silu(z_data[offset + d]);
            }
        }

        // 10. Output projection
        let output_tensor = Tensor::from_f32(&output, vec![cfg.d_inner])?;
        let mut result = Tensor::zeros(vec![hidden_size], DType::F32);
        self.ssm_out.forward(&output_tensor, &mut result, backend)?;

        Ok(result)
    }

    /// Apply 1D causal depthwise convolution using the history buffer in `state`.
    ///
    /// `state.conv_state` holds the previous `conv_kernel - 1` inputs, oldest first.
    /// Kernel taps `0..conv_kernel-1` read those rows, and the final tap reads the
    /// current `qkv`. The state is shifted (oldest row dropped, `qkv` appended) only
    /// AFTER the output is computed.
    ///
    /// Conv weight layout (GGML): [kernel_size, channels] → data[ch * kernel_size + ki]
    ///
    /// With `conv_kernel == 1` there is no history: each output is `qkv[ch] * w[ch]` and
    /// the empty state is left untouched. `conv_kernel` must be at least 1.
    fn apply_conv1d(&self, qkv: &[f32], state: &mut DeltaNetState) -> ModelResult<Vec<f32>> {
        let cfg = &self.config;
        let channels = cfg.qkv_dim;
        let ks = cfg.conv_kernel;
        let buf_len = ks - 1;

        let conv_w = self.ssm_conv1d_weight.as_f32()?;

        // out[ch] = sum_k(input[k][ch] * weight[ch][k]), history taps first, then current input
        let mut out = vec![0.0f32; channels];
        for (ch, o) in out.iter_mut().enumerate() {
            let taps = &conv_w[ch * ks..(ch + 1) * ks];
            let mut sum = 0.0f32;
            for (ki, &w) in taps[..buf_len].iter().enumerate() {
                sum += state.conv_state[ki * channels + ch] * w;
            }
            sum += qkv[ch] * taps[buf_len];
            *o = sum;
        }

        // Shift history left by one row and append the current qkv. A kernel of size 1
        // keeps no history, and `buf_len - 1` would underflow, so skip the update then.
        if buf_len > 0 {
            state
                .conv_state
                .copy_within(channels..buf_len * channels, 0);
            let last_start = (buf_len - 1) * channels;
            state.conv_state[last_start..last_start + channels].copy_from_slice(qkv);
        }

        Ok(out)
    }
}

/// Dot product accumulated left to right in f32 over the shorter of the two slices.
#[inline]
fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y)
}

/// Tile a block of heads `repeat` times: `[num_heads, head_dim]` becomes
/// `[num_heads * repeat, head_dim]` (matching ggml_repeat_4d behavior).
///
/// The whole block is copied consecutively, e.g. 16 heads tiled 2x give
/// `[h0..h15, h0..h15]` — NOT an interleaved `[h0, h0, h1, h1, ...]` layout.
/// Only the first `num_heads * head_dim` elements of `data` are read.
fn repeat_tile(data: &[f32], num_heads: usize, head_dim: usize, repeat: usize) -> Vec<f32> {
    let block_len = num_heads * head_dim;
    let mut tiled = Vec::with_capacity(block_len * repeat);
    for _ in 0..repeat {
        tiled.extend_from_slice(&data[..block_len]);
    }
    tiled
}

/// SiLU (swish) activation, computed as `x / (1 + e^-x)`.
#[inline]
fn silu(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    x / denominator
}

/// Logistic sigmoid `1 / (1 + e^-x)`, mapping any input into `[0, 1]`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}

/// Softplus `ln(1 + e^x)`, a smooth positive approximation of `max(0, x)`.
///
/// For `x > 20` the result equals `x` to f32 precision, so `x` is returned directly;
/// this also avoids evaluating `e^x` for large inputs. Otherwise `ln_1p` is used, so very
/// negative inputs keep their tiny positive value. The naive `(1 + e^x).ln()` form rounds
/// `1 + e^x` to `1` and returns exactly 0.
#[inline]
fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}

/// Scale `v` in place to (approximately) unit L2 norm.
///
/// `eps` is added to the squared norm before the square root. With `eps > 0`, an
/// all-zero vector therefore stays finite (all zeros) instead of becoming NaN.
fn l2_normalize_inplace(v: &mut [f32], eps: f32) {
    let squared_norm = v.iter().map(|x| x * x).sum::<f32>();
    let scale = (squared_norm + eps).sqrt().recip();
    v.iter_mut().for_each(|x| *x *= scale);
}

#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn approx(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_silu() {
        assert!(approx(silu(0.0), 0.0, 1e-6));
        assert!(approx(silu(1.0), 0.7310586, 1e-4));
    }

    #[test]
    fn test_sigmoid() {
        assert!(approx(sigmoid(0.0), 0.5, 1e-6));
    }

    #[test]
    fn test_softplus() {
        assert!(approx(softplus(0.0), 0.6931, 1e-3));
        assert!(approx(softplus(25.0), 25.0, 1e-6));
    }

    #[test]
    fn test_l2_normalize() {
        let mut v = [3.0f32, 4.0];
        l2_normalize_inplace(&mut v, 1e-6);
        assert!(approx(v[0], 0.6, 1e-4));
        assert!(approx(v[1], 0.8, 1e-4));
    }

    #[test]
    fn test_repeat_tile() {
        // 2 heads of dim 2 tiled 3x → 6 heads ordered [h0, h1, h0, h1, h0, h1].
        let heads = [1.0f32, 2.0, 3.0, 4.0];
        let tiled = repeat_tile(&heads, 2, 2, 3);
        let expected: Vec<f32> = heads.iter().copied().cycle().take(12).collect();
        assert_eq!(tiled, expected);
    }
}