Skip to main content

provable_contracts/kernels/
alibi.rs

//! ALiBi (Attention with Linear Biases) kernel.
//!
//! Matches `alibi-kernel-v1.yaml`: `scores[h, i, j] += -m_h * |i - j|`, where
//! `m_h = 2^(-8 * (h + 1) / H)` for the zero-based head index `h` of `H` heads
//! (so the first head uses `2^(-8/H)` and the last uses `2^(-8)`).
//!
//! The kernel is provided by three backends:
//! - `fn alibi_bias_scalar(...)` -- pure Rust scalar reference (ground truth)
//! - `unsafe fn alibi_bias_avx2(...)` -- AVX2 (`#[target_feature(enable = "avx2")]`) entry point
//! - `fn alibi_ptx() -> &'static str` -- PTX assembly source string
10
11// ────────────────────────────────────────────────────────────────────────────
12// Scalar implementation
13// ────────────────────────────────────────────────────────────────────────────
14
/// Return the ALiBi slope `m_h` for zero-based head index `h` out of `num_heads`.
///
/// Evaluates `m_h = 2^(-8 * (h + 1) / num_heads)`; with 8 heads the slopes run
/// from `1/2` (head 0) down to `1/256` (head 7).
#[inline]
pub fn alibi_slope(h: usize, num_heads: usize) -> f32 {
    let head_rank = (h + 1) as f32;
    let total_heads = num_heads as f32;
    f32::powf(2.0, -8.0 * head_rank / total_heads)
}
23
24/// Add ALiBi bias to attention scores (scalar reference).
25///
26/// `scores` is `num_heads x seq_len x seq_len` (row-major).
27/// For head h, position i, position j: `scores[h,i,j] += -m_h * |i - j|`.
28///
29/// # Panics
30/// Panics if dimensions don't match.
31pub fn alibi_bias_scalar(scores: &mut [f32], num_heads: usize, seq_len: usize) {
32    assert_eq!(
33        scores.len(),
34        num_heads * seq_len * seq_len,
35        "scores dimension mismatch"
36    );
37
38    let head_stride = seq_len * seq_len;
39
40    for h in 0..num_heads {
41        let slope = alibi_slope(h, num_heads);
42        let base = h * head_stride;
43
44        for i in 0..seq_len {
45            for j in 0..seq_len {
46                let dist = i.abs_diff(j);
47                scores[base + i * seq_len + j] -= slope * (dist as f32);
48            }
49        }
50    }
51}
52
53// ────────────────────────────────────────────────────────────────────────────
54// AVX2 implementation
55// ────────────────────────────────────────────────────────────────────────────
56
57/// AVX2 ALiBi bias -- delegates to scalar.
58///
59/// # Safety
60/// Requires AVX2 support.
61#[cfg(target_arch = "x86_64")]
62#[target_feature(enable = "avx2")]
63pub unsafe fn alibi_bias_avx2(scores: &mut [f32], num_heads: usize, seq_len: usize) {
64    alibi_bias_scalar(scores, num_heads, seq_len);
65}
66
67// ────────────────────────────────────────────────────────────────────────────
68// PTX implementation
69// ────────────────────────────────────────────────────────────────────────────
70
/// PTX assembly for ALiBi bias addition.
///
/// Launch with one thread block per head (`gridDim.x >= num_heads`; surplus
/// blocks exit immediately) and any `blockDim.x`. Each block walks the
/// `seq_len * seq_len` scores of its head with a block-stride loop over the
/// flattened index `i * seq_len + j`, so sequences whose score matrix exceeds
/// the block size are fully covered. The per-head base offset is computed in
/// 64 bits; the in-head index is 32-bit, so `seq_len` must be below 65536.
///
/// The slope is computed with `ex2.approx.f32`, so it may differ slightly from
/// the scalar `alibi_slope`.
pub fn alibi_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry alibi_kernel(
    .param .u64 SCORES,
    .param .u32 NUM_HEADS,
    .param .u32 SEQ_LEN
) {
    .reg .u32 %tx, %ntx, %head, %num_heads, %seq_len, %head_stride;
    .reg .u32 %idx, %i, %j, %dist, %head_rank;
    .reg .u64 %scores_ptr, %head_ptr, %head_off, %addr, %off64;
    .reg .f32 %slope, %dist_f, %bias, %score, %exp, %neg8, %h_f, %nh_f;
    .reg .pred %p_head, %p_done, %p_ge;

    // Copy special registers into plain registers (the predefined
    // thread-index identifier must not be redeclared).
    mov.u32 %tx, %tid.x;
    mov.u32 %ntx, %ntid.x;
    mov.u32 %head, %ctaid.x;

    ld.param.u64 %scores_ptr, [SCORES];
    ld.param.u32 %num_heads, [NUM_HEADS];
    ld.param.u32 %seq_len, [SEQ_LEN];

    // Blocks beyond the last head have nothing to do.
    setp.ge.u32 %p_head, %head, %num_heads;
    @%p_head bra EXIT;

    mul.lo.u32 %head_stride, %seq_len, %seq_len;

    // slope = 2^(-8 * (head+1) / num_heads)
    add.u32 %head_rank, %head, 1;
    cvt.rn.f32.u32 %h_f, %head_rank;
    cvt.rn.f32.u32 %nh_f, %num_heads;
    mov.f32 %neg8, 0fC1000000;
    mul.f32 %exp, %neg8, %h_f;
    div.rn.f32 %exp, %exp, %nh_f;
    ex2.approx.f32 %slope, %exp;

    // head_ptr = scores + 4 * head * head_stride, computed in 64 bits
    mul.wide.u32 %head_off, %head, %head_stride;
    shl.b64 %head_off, %head_off, 2;
    add.u64 %head_ptr, %scores_ptr, %head_off;

    // Block-stride loop over idx = i * seq_len + j
    mov.u32 %idx, %tx;
SCORE_LOOP:
    setp.ge.u32 %p_done, %idx, %head_stride;
    @%p_done bra EXIT;

    div.u32 %i, %idx, %seq_len;
    rem.u32 %j, %idx, %seq_len;

    // dist = |i - j|
    setp.ge.u32 %p_ge, %i, %j;
    @%p_ge bra CALC_DIST_FORWARD;
    sub.u32 %dist, %j, %i;
    bra APPLY_BIAS;
CALC_DIST_FORWARD:
    sub.u32 %dist, %i, %j;

APPLY_BIAS:
    cvt.rn.f32.u32 %dist_f, %dist;
    mul.f32 %bias, %slope, %dist_f;

    // head_ptr[idx] -= bias
    mul.wide.u32 %off64, %idx, 4;
    add.u64 %addr, %head_ptr, %off64;
    ld.global.f32 %score, [%addr];
    sub.f32 %score, %score, %bias;
    st.global.f32 [%addr], %score;

    add.u32 %idx, %idx, %ntx;
    bra SCORE_LOOP;

EXIT:
    ret;
}
"#
}
141
142// ────────────────────────────────────────────────────────────────────────────
143// Tests
144// ────────────────────────────────────────────────────────────────────────────
145
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_alibi_slopes() {
        // 8 heads: slopes should be 2^(-1), 2^(-2), ..., 2^(-8)
        let slopes: Vec<f32> = (0..8).map(|h| alibi_slope(h, 8)).collect();
        assert!((slopes[0] - 0.5).abs() < 1e-6);
        assert!((slopes[1] - 0.25).abs() < 1e-6);
        assert!((slopes[7] - 1.0 / 256.0).abs() < 1e-6);
        // Slopes must be monotonically decreasing
        for i in 1..8 {
            assert!(slopes[i] < slopes[i - 1], "slopes not decreasing at {i}");
        }
    }

    #[test]
    fn test_alibi_diagonal_zero() {
        // On the diagonal (i==j), bias should be 0 (distance = 0)
        let seq_len = 4;
        let num_heads = 2;
        let mut scores = vec![1.0f32; num_heads * seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        for h in 0..num_heads {
            for i in 0..seq_len {
                let idx = h * seq_len * seq_len + i * seq_len + i;
                assert_eq!(
                    scores[idx], 1.0,
                    "diagonal should be unchanged at h={h} i={i}"
                );
            }
        }
    }

    #[test]
    fn test_alibi_negative_bias() {
        // Off-diagonal elements should have scores decreased
        let seq_len = 3;
        let num_heads = 1;
        let mut scores = vec![0.0f32; seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        // All off-diagonal should be negative (bias subtracted from 0)
        for i in 0..seq_len {
            for j in 0..seq_len {
                if i != j {
                    assert!(
                        scores[i * seq_len + j] < 0.0,
                        "off-diagonal [{i},{j}] should be negative, got {}",
                        scores[i * seq_len + j]
                    );
                }
            }
        }
    }

    #[test]
    fn test_alibi_known_values() {
        // 2 heads: slopes are 2^(-8*1/2) = 1/16 and 2^(-8*2/2) = 1/256.
        let (num_heads, seq_len) = (2usize, 3usize);
        let mut scores = vec![0.0f32; num_heads * seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        for &(h, slope) in &[(0usize, 1.0f32 / 16.0), (1, 1.0 / 256.0)] {
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let got = scores[(h * seq_len + i) * seq_len + j];
                    let expected = -slope * (i.abs_diff(j) as f32);
                    assert!(
                        (got - expected).abs() < 1e-6,
                        "h={h} [{i},{j}]: {got} vs {expected}"
                    );
                }
            }
        }
    }

    #[test]
    fn test_alibi_empty_sequence() {
        let mut scores: Vec<f32> = Vec::new();
        alibi_bias_scalar(&mut scores, 4, 0);
        assert!(scores.is_empty());
    }

    #[test]
    #[should_panic(expected = "scores dimension mismatch")]
    fn test_alibi_dimension_mismatch_panics() {
        let mut scores = vec![0.0f32; 5];
        alibi_bias_scalar(&mut scores, 1, 2);
    }

    #[test]
    fn test_alibi_symmetry() {
        // |i-j| = |j-i|, so bias is symmetric for each head
        let seq_len = 5;
        let num_heads = 2;
        let mut scores = vec![0.0f32; num_heads * seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        for h in 0..num_heads {
            let base = h * seq_len * seq_len;
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let a = scores[base + i * seq_len + j];
                    let b = scores[base + j * seq_len + i];
                    assert!(
                        (a - b).abs() < 1e-6,
                        "asymmetry at h={h} [{i},{j}]: {a} vs {b}"
                    );
                }
            }
        }
    }

    proptest! {
        #[test]
        fn prop_alibi_slopes_positive(num_heads in 1usize..17) {
            for h in 0..num_heads {
                let s = alibi_slope(h, num_heads);
                prop_assert!(s > 0.0, "slope must be positive, got {s} at h={h}");
                prop_assert!(s <= 1.0, "slope must be <= 1, got {s} at h={h}");
            }
        }

        #[test]
        fn prop_alibi_output_finite(
            num_heads in 1usize..5,
            seq_len in 1usize..8,
        ) {
            let mut scores = vec![0.0f32; num_heads * seq_len * seq_len];
            alibi_bias_scalar(&mut scores, num_heads, seq_len);

            for (idx, &val) in scores.iter().enumerate() {
                prop_assert!(val.is_finite(), "scores[{idx}] = {val} not finite");
            }
        }
    }

    #[test]
    fn test_alibi_ptx_structure() {
        let ptx = alibi_ptx();
        assert!(ptx.contains(".entry alibi_kernel"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_alibi_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // Rows shorter than, equal to, and longer than one 8-lane vector,
        // including lengths that leave a remainder, with non-uniform input.
        for &(num_heads, seq_len) in &[(2usize, 4usize), (1, 1), (3, 8), (2, 13), (4, 33)] {
            let init: Vec<f32> = (0..num_heads * seq_len * seq_len)
                .map(|k| (k % 7) as f32 * 0.5 - 1.25)
                .collect();
            let mut scalar_scores = init.clone();
            let mut avx2_scores = init;
            alibi_bias_scalar(&mut scalar_scores, num_heads, seq_len);
            // SAFETY: AVX2 support was checked above.
            unsafe { alibi_bias_avx2(&mut avx2_scores, num_heads, seq_len) };
            assert_eq!(
                scalar_scores, avx2_scores,
                "scalar/AVX2 mismatch for num_heads={num_heads} seq_len={seq_len}"
            );
        }
    }
}