oxibonsai-model 0.1.4

Qwen3-8B Transformer implementation for OxiBonsai 1-bit inference
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
//! ALiBi (Attention with Linear Biases) positional encoding.
//!
//! ALiBi adds a fixed, non-learned linear bias to attention scores based on the
//! query-key distance. Unlike RoPE, it generalises to longer sequences than seen
//! during training without any learned parameters.
//!
//! Reference: Press et al. "Train Short, Test Long: Attention with Linear Biases
//! Enables Input Length Extrapolation" (ICLR 2022).
//!
//! ## Slope schedule
//!
//! For `n` heads, the slopes follow a geometric sequence:
//! ```text
//! start = 2^(-8/n)
//! slopes[i] = start^(i+1)   for i in 0..n
//! ```
//! For non-power-of-2 head counts an "extrapolated" variant fills in missing
//! slopes by interpolating between adjacent powers of 2.
//!
//! ## Bias formula
//!
//! For causal attention, the bias added to score `(q, k)` for head `h` is:
//! ```text
//! bias[h][q][k] = -slope_h * (q_pos - k_pos)   for k_pos <= q_pos
//! ```
//! This is always ≤ 0 and equals 0 only when `k_pos == q_pos`.

use crate::layers::attention::{dot, softmax};

// ─── AliBiSlopes ─────────────────────────────────────────────────────────────

/// Pre-computed ALiBi slope for each attention head.
///
/// Every head owns one fixed, non-learned slope. Multiplying that slope by
/// the query-key distance yields the linear penalty added to the attention
/// score, so heads with different slopes focus on different recency scales.
#[derive(Debug, Clone)]
pub struct AliBiSlopes {
    /// One slope per head, indexed by head number.
    slopes: Vec<f32>,
    /// Number of heads the slopes were computed for (equals `slopes.len()`).
    num_heads: usize,
}

impl AliBiSlopes {
    /// Geometric slope sequence for `n` heads:
    /// `slopes[i] = (2^(-8/n))^(i+1)` for `i` in `0..n`.
    fn geometric(n: usize) -> Vec<f32> {
        let start = 2.0_f32.powf(-8.0 / n as f32);
        // Head counts are tiny compared to `i32::MAX`, so the cast cannot
        // truncate in practice.
        (1..=n).map(|i| start.powi(i as i32)).collect()
    }

    /// Compute slopes for exactly `num_heads` heads (power-of-2 preferred).
    ///
    /// Formula: `slopes[i] = (2^(-8/n))^(i+1)` for `i` in `0..n`, where
    /// `n = num_heads`.
    ///
    /// # Panics
    /// Panics if `num_heads == 0`.
    pub fn new(num_heads: usize) -> Self {
        assert!(num_heads > 0, "num_heads must be > 0");
        Self {
            slopes: Self::geometric(num_heads),
            num_heads,
        }
    }

    /// Extrapolated variant for non-power-of-2 head counts.
    ///
    /// Follows the scheme of the ALiBi reference implementation: with `b`
    /// the largest power of two below `num_heads`,
    ///
    /// 1. the first `b` heads receive the standard geometric slopes for `b`
    ///    heads;
    /// 2. the remaining `num_heads - b` heads receive every other slope
    ///    (indices 0, 2, 4, …) of the geometric sequence for `2 * b` heads.
    ///
    /// The slopes picked in step 2 fall strictly between those of step 1, so
    /// all `num_heads` slopes are distinct. For an exact power of two the
    /// result is identical to [`AliBiSlopes::new`].
    ///
    /// # Panics
    /// Panics if `num_heads == 0`.
    pub fn new_extrapolated(num_heads: usize) -> Self {
        assert!(num_heads > 0, "num_heads must be > 0");

        if num_heads.is_power_of_two() {
            // Exact power of 2 — no extrapolation needed.
            return Self::new(num_heads);
        }

        // Largest power of two strictly below `num_heads` (which is not a
        // power of two on this path).
        let base = 1usize << (usize::BITS - 1 - num_heads.leading_zeros());

        let mut slopes = Self::geometric(base);
        slopes.reserve(num_heads - base);
        // `num_heads - base < base`, and the stepped sequence has `base`
        // entries, so `take` never runs short.
        slopes.extend(
            Self::geometric(2 * base)
                .into_iter()
                .step_by(2)
                .take(num_heads - base),
        );

        Self { slopes, num_heads }
    }

    /// Return all slopes as a slice.
    #[inline]
    pub fn slopes(&self) -> &[f32] {
        &self.slopes
    }

    /// Return the slope for head `head`.
    ///
    /// # Panics
    /// Panics if `head >= num_heads`.
    #[inline]
    pub fn get(&self, head: usize) -> f32 {
        self.slopes[head]
    }

    /// Number of heads these slopes were computed for.
    #[inline]
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }
}

// ─── AliBiBias ───────────────────────────────────────────────────────────────

/// Produces ALiBi bias values for a set of attention heads and adds them to
/// attention scores.
///
/// For head `h`, query position `q_pos`, and key position `k` the bias is:
/// ```text
/// bias = -slope_h * (q_pos - k)   for k in 0..kv_len
/// ```
/// For `k <= q_pos` the value is ≤ 0 (exactly zero at `k == q_pos`), so
/// earlier tokens are penalised in proportion to their distance from the
/// query. For `k > q_pos` the same formula yields a positive value; callers
/// that attend causally are expected to mask those positions themselves.
pub struct AliBiBias {
    /// Pre-computed slopes, one per head.
    pub slopes: AliBiSlopes,
}

impl AliBiBias {
    /// Build an `AliBiBias` for `num_heads` heads from the standard slope
    /// schedule ([`AliBiSlopes::new`]).
    pub fn new(num_heads: usize) -> Self {
        let slopes = AliBiSlopes::new(num_heads);
        Self { slopes }
    }

    /// ALiBi bias vector of one head for the query at `q_pos`, covering key
    /// positions `0..kv_len`.
    ///
    /// The returned `Vec<f32>` has length `kv_len` and satisfies:
    /// ```text
    /// result[k] = -slope * (q_pos - k)
    /// ```
    ///
    /// # Panics
    /// Panics if `head` is out of range for the stored slopes.
    pub fn bias_for_head(&self, head: usize, q_pos: usize, kv_len: usize) -> Vec<f32> {
        let slope = self.slopes.get(head);
        let query_position = q_pos as f32;
        let mut row = Vec::with_capacity(kv_len);
        for k in 0..kv_len {
            // Signed distance: positive for past keys, negative for future ones.
            let distance = query_position - k as f32;
            row.push(-slope * distance);
        }
        row
    }

    /// Bias vectors of every head for the query at `q_pos`.
    ///
    /// The result has shape `[num_heads][kv_len]`.
    pub fn biases_all_heads(&self, q_pos: usize, kv_len: usize) -> Vec<Vec<f32>> {
        let head_count = self.slopes.num_heads();
        let mut per_head = Vec::with_capacity(head_count);
        for head in 0..head_count {
            per_head.push(self.bias_for_head(head, q_pos, kv_len));
        }
        per_head
    }

    /// Add the ALiBi biases to `scores` in place.
    ///
    /// `scores` is expected to have shape `[num_heads][kv_len]`; `kv_len` is
    /// taken from the first row. Rows beyond `num_heads`, and entries beyond
    /// `kv_len` in any row, are left untouched.
    pub fn apply(&self, scores: &mut [Vec<f32>], q_pos: usize) {
        let kv_len = match scores.first() {
            Some(first_row) => first_row.len(),
            None => 0,
        };
        let head_count = self.slopes.num_heads();
        for (head, row) in scores.iter_mut().enumerate().take(head_count) {
            let head_bias = self.bias_for_head(head, q_pos, kv_len);
            for (score, bias) in row.iter_mut().zip(head_bias) {
                *score += bias;
            }
        }
    }

    /// Biases for a run of consecutive query positions.
    ///
    /// - `q_len`: number of query positions (starting at `q_offset`).
    /// - `kv_len`: number of key/value positions.
    /// - `q_offset`: absolute position of the first query token.
    ///
    /// The result has shape `[q_len][num_heads][kv_len]`.
    pub fn biases_for_sequence(
        &self,
        q_len: usize,
        kv_len: usize,
        q_offset: usize,
    ) -> Vec<Vec<Vec<f32>>> {
        let mut per_query = Vec::with_capacity(q_len);
        for qi in 0..q_len {
            let absolute_pos = q_offset + qi;
            per_query.push(self.biases_all_heads(absolute_pos, kv_len));
        }
        per_query
    }
}

// ─── AliBiConfig ─────────────────────────────────────────────────────────────

/// Settings that control ALiBi-enhanced attention.
#[derive(Debug, Clone)]
pub struct AliBiConfig {
    /// How many query attention heads there are.
    pub num_heads: usize,
    /// Select the extrapolated slope schedule (intended for head counts that
    /// are not a power of 2) instead of the basic one.
    pub use_extrapolated_slopes: bool,
    /// Apply a causal mask: key positions after the query position get a
    /// `-∞` attention score. Almost always `true`.
    pub causal: bool,
}

impl Default for AliBiConfig {
    /// Defaults: 8 heads, basic slope schedule, causal masking enabled.
    fn default() -> Self {
        let num_heads = 8;
        let use_extrapolated_slopes = false;
        let causal = true;
        Self {
            num_heads,
            use_extrapolated_slopes,
            causal,
        }
    }
}

// ─── attention_with_alibi ────────────────────────────────────────────────────

/// Compute grouped-query attention augmented with ALiBi positional biases.
///
/// This is a single-token (decode-step) attention kernel. For every query
/// head it:
///
/// 1. computes scaled dot-product scores against all `kv_len` cached keys of
///    the KV head shared by that query head,
/// 2. adds the ALiBi bias for the head,
/// 3. sets scores of key positions `> q_pos` to `-∞` when `config.causal`,
/// 4. applies softmax and accumulates the weighted sum of the values.
///
/// `kv_len` is derived from `keys.len() / (num_kv_heads * head_dim)`; any
/// trailing partial token in `keys` is ignored. The slope table is rebuilt on
/// every call from `config`.
///
/// # Arguments
///
/// - `query`:       flattened query tensor `[num_heads * head_dim]`
/// - `keys`:        flattened key cache `[kv_len * num_kv_heads * head_dim]`
///   (tokens in outermost dimension)
/// - `values`:      flattened value cache `[kv_len * num_kv_heads * head_dim]`
/// - `config`:      ALiBi attention configuration
/// - `head_dim`:    dimension of each attention head
/// - `num_kv_heads`: number of key-value heads (GQA)
/// - `q_pos`:       absolute position of the query token
///
/// # Returns
///
/// Flattened output `[num_heads * head_dim]`.
///
/// # Panics
///
/// Panics if `num_kv_heads == 0`, if `config.num_heads` is not a positive
/// multiple of `num_kv_heads`, if `values` holds fewer tokens than `keys`,
/// or if `query` is shorter than `num_heads * head_dim`.
// Kept from the original: the flat-slice signature carries many parameters.
#[allow(clippy::too_many_arguments)]
pub fn attention_with_alibi(
    query: &[f32],
    keys: &[f32],
    values: &[f32],
    config: &AliBiConfig,
    head_dim: usize,
    num_kv_heads: usize,
    q_pos: usize,
) -> Vec<f32> {
    let num_heads = config.num_heads;

    // Hard checks: a bad head configuration would otherwise divide by zero or
    // silently index into the wrong token's keys in release builds.
    assert!(num_kv_heads > 0, "num_kv_heads must be > 0");
    assert!(
        num_heads >= num_kv_heads && num_heads % num_kv_heads == 0,
        "num_heads ({num_heads}) must be a positive multiple of num_kv_heads ({num_kv_heads})"
    );
    debug_assert_eq!(query.len(), num_heads * head_dim);

    // Number of floats per cached token (all KV heads).
    let kv_stride = num_kv_heads * head_dim;
    let kv_len = if head_dim > 0 {
        keys.len() / kv_stride
    } else {
        0
    };
    assert!(
        values.len() >= kv_len * kv_stride,
        "values holds fewer than kv_len ({kv_len}) tokens"
    );

    let scale = 1.0_f32 / (head_dim as f32).sqrt();
    let heads_per_kv = num_heads / num_kv_heads;

    let alibi = if config.use_extrapolated_slopes {
        AliBiBias {
            slopes: AliBiSlopes::new_extrapolated(num_heads),
        }
    } else {
        AliBiBias::new(num_heads)
    };

    let mut output = vec![0.0_f32; num_heads * head_dim];

    for q_head in 0..num_heads {
        // Consecutive groups of `heads_per_kv` query heads share one KV head.
        let kv_offset = (q_head / heads_per_kv) * head_dim;
        let q_start = q_head * head_dim;
        let q_vec = &query[q_start..q_start + head_dim];

        // Scaled dot-product scores; keys are token-major.
        let mut scores: Vec<f32> = (0..kv_len)
            .map(|t| {
                let k_start = t * kv_stride + kv_offset;
                dot(q_vec, &keys[k_start..k_start + head_dim]) * scale
            })
            .collect();

        // Add ALiBi bias.
        let biases = alibi.bias_for_head(q_head, q_pos, kv_len);
        for (s, b) in scores.iter_mut().zip(biases.iter()) {
            *s += b;
        }

        // Causal mask: key positions after the query → -∞.
        if config.causal {
            for s in scores.iter_mut().skip(q_pos.saturating_add(1)) {
                *s = f32::NEG_INFINITY;
            }
        }

        softmax(&mut scores);

        // Weighted sum of values, walking the cache token by token so each
        // value vector is read contiguously. Every output element still
        // accumulates its terms in token order, starting from 0.0.
        let out_vec = &mut output[q_start..q_start + head_dim];
        for (t, &weight) in scores.iter().enumerate().take(kv_len) {
            let v_start = t * kv_stride + kv_offset;
            let v_vec = &values[v_start..v_start + head_dim];
            for (o, &v) in out_vec.iter_mut().zip(v_vec.iter()) {
                *o += weight * v;
            }
        }
    }

    output
}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── slope tests ──────────────────────────────────────────────────────────

    /// For a power-of-2 head count both constructors must agree.
    #[test]
    fn test_alibi_slopes_power_of_2() {
        let basic = AliBiSlopes::new(8);
        let extrapolated = AliBiSlopes::new_extrapolated(8);
        assert_eq!(basic.num_heads(), 8);
        assert_eq!(extrapolated.num_heads(), 8);
        for i in 0..8 {
            let (lhs, rhs) = (basic.get(i), extrapolated.get(i));
            assert!(
                (lhs - rhs).abs() < 1e-6,
                "slope mismatch at head {i}: {} vs {}",
                lhs,
                rhs
            );
        }
    }

    /// With 8 heads the start value is 2^(-8/8) = 0.5, so the first slope is
    /// 0.5 and the last is 0.5^8.
    #[test]
    fn test_alibi_slopes_8_heads() {
        let slopes = AliBiSlopes::new(8);
        assert_eq!(slopes.num_heads(), 8);
        assert_eq!(slopes.slopes().len(), 8);

        let expected_first = 0.5_f32;
        let first = slopes.get(0);
        assert!(
            (first - expected_first).abs() < 1e-5,
            "first slope: got {}, expected {}",
            first,
            expected_first
        );

        let expected_last = 0.5_f32.powi(8);
        let last = slopes.get(7);
        assert!(
            (last - expected_last).abs() < 1e-7,
            "last slope: got {}, expected {}",
            last,
            expected_last
        );
    }

    /// Successive slopes of the basic schedule shrink strictly.
    #[test]
    fn test_alibi_slopes_decreasing() {
        let slopes = AliBiSlopes::new(16);
        for i in 1..16 {
            let current = slopes.get(i);
            let previous = slopes.get(i - 1);
            assert!(
                current < previous,
                "slopes not strictly decreasing at index {i}: {} >= {}",
                current,
                previous
            );
        }
    }

    // ── bias tests ───────────────────────────────────────────────────────────

    /// The distance is 0 at `k == q_pos`, so the bias there is 0.
    #[test]
    fn test_alibi_bias_zero_distance() {
        let bias = AliBiBias::new(4);
        let q_pos = 7;
        for head in 0..4 {
            let at_q = bias.bias_for_head(head, q_pos, q_pos + 1)[q_pos];
            assert!(
                at_q.abs() < 1e-8,
                "head {head}: expected zero bias at q_pos={q_pos}, got {at_q}"
            );
        }
    }

    /// Bias is `-slope * (q - k)`: it rises towards 0 as `k` approaches
    /// `q_pos` (keys further in the past are more negative).
    #[test]
    fn test_alibi_bias_increases_with_distance() {
        let bias = AliBiBias::new(4);
        let q_pos = 10;
        let kv_len = 11;
        for head in 0..4 {
            let biases = bias.bias_for_head(head, q_pos, kv_len);
            // Adjacent pairs (k - 1, k) for k in 1..=q_pos.
            for (pair_idx, pair) in biases[..=q_pos].windows(2).enumerate() {
                let k = pair_idx + 1;
                assert!(
                    pair[1] > pair[0],
                    "head {head}: bias should increase with k, but biases[{k}]={} <= biases[{}]={}",
                    pair[1],
                    k - 1,
                    pair[0]
                );
            }
        }
    }

    /// `biases_all_heads` returns `[num_heads][kv_len]`.
    #[test]
    fn test_alibi_biases_all_heads_shape() {
        let num_heads = 6;
        let kv_len = 20;
        let all = AliBiBias::new(num_heads).biases_all_heads(15, kv_len);
        assert_eq!(all.len(), num_heads, "outer dim must equal num_heads");
        for (h, row) in all.iter().enumerate() {
            assert_eq!(row.len(), kv_len, "head {h}: inner dim must equal kv_len");
        }
    }

    /// Applying the bias to all-zero scores leaves 0 at `q_pos` and negative
    /// values at every earlier key position.
    #[test]
    fn test_alibi_apply_modifies_scores() {
        let num_heads = 4;
        let kv_len = 5;
        let q_pos = 4; // last position
        let bias = AliBiBias::new(num_heads);

        let mut scores: Vec<Vec<f32>> = (0..num_heads).map(|_| vec![0.0_f32; kv_len]).collect();
        bias.apply(&mut scores, q_pos);

        for (head, scores_head) in scores.iter().enumerate() {
            let at_query = scores_head[q_pos];
            assert!(
                at_query.abs() < 1e-8,
                "head {head}: score at q_pos should be 0 after ALiBi, got {}",
                at_query
            );
            for (k, &score_k) in scores_head.iter().take(q_pos).enumerate() {
                assert!(
                    score_k < 0.0,
                    "head {head}: score at k={k} should be negative, got {}",
                    score_k
                );
            }
        }
    }

    /// `biases_for_sequence` returns `[q_len][num_heads][kv_len]`.
    #[test]
    fn test_alibi_biases_for_sequence_shape() {
        let num_heads = 4;
        let q_len = 3;
        let kv_len = 8;
        let q_offset = 5;
        let seq_biases = AliBiBias::new(num_heads).biases_for_sequence(q_len, kv_len, q_offset);

        assert_eq!(seq_biases.len(), q_len, "outer dim must equal q_len");
        for (qi, head_biases) in seq_biases.iter().enumerate() {
            assert_eq!(
                head_biases.len(),
                num_heads,
                "q={qi}: second dim must equal num_heads"
            );
            for (h, kv_biases) in head_biases.iter().enumerate() {
                assert_eq!(
                    kv_biases.len(),
                    kv_len,
                    "q={qi} h={h}: inner dim must equal kv_len"
                );
            }
        }
    }

    // ── attention_with_alibi tests ────────────────────────────────────────────

    /// Output has `num_heads * head_dim` entries and every entry is finite.
    #[test]
    fn test_attention_with_alibi_output_shape() {
        let num_heads = 4;
        let num_kv_heads = 2;
        let head_dim = 8;
        let kv_len = 5;
        let q_pos = 4;

        // Cache layout: [kv_len][num_kv_heads][head_dim]
        let cache_len = kv_len * num_kv_heads * head_dim;
        let query = vec![0.1_f32; num_heads * head_dim];
        let keys = vec![0.05_f32; cache_len];
        let values = vec![0.2_f32; cache_len];

        let config = AliBiConfig {
            num_heads,
            use_extrapolated_slopes: false,
            causal: true,
        };

        let output = attention_with_alibi(
            &query,
            &keys,
            &values,
            &config,
            head_dim,
            num_kv_heads,
            q_pos,
        );

        assert_eq!(
            output.len(),
            num_heads * head_dim,
            "output length must be num_heads * head_dim"
        );
        for (i, &v) in output.iter().enumerate() {
            assert!(v.is_finite(), "output[{i}] = {v} is not finite");
        }
    }

    /// 12 is not a power of 2: the extrapolated schedule must still yield 12
    /// slopes, each inside (0, 1).
    #[test]
    fn test_alibi_extrapolated_slopes() {
        let slopes = AliBiSlopes::new_extrapolated(12);
        assert_eq!(slopes.num_heads(), 12);
        assert_eq!(slopes.slopes().len(), 12);
        for (i, &s) in slopes.slopes().iter().enumerate() {
            let inside_unit_interval = s > 0.0 && s < 1.0;
            assert!(
                inside_unit_interval,
                "extrapolated slope[{i}] = {s} out of (0,1)"
            );
        }
    }
}