/// Returns the ALiBi slope for head `h` (zero-based) out of `num_heads` heads.
///
/// The slope is `2^(-8 * (h + 1) / num_heads)`. Head 0 therefore gets the largest
/// slope, and each later head's slope is the previous one scaled by `2^(-8 / num_heads)`.
/// For `num_heads = 8` this gives `1/2, 1/4, ..., 1/256`.
#[inline]
pub fn alibi_slope(h: usize, num_heads: usize) -> f32 {
    let head_ordinal = (h + 1) as f32;
    let head_count = num_heads as f32;
    let two: f32 = 2.0;
    two.powf(-8.0 * head_ordinal / head_count)
}
/// Subtracts the ALiBi linear position bias from `scores` in place (portable path).
///
/// `scores` is read as `num_heads` contiguous row-major `seq_len x seq_len` blocks.
/// Entry `(h, i, j)` is decremented by `alibi_slope(h, num_heads) * |i - j|`, so
/// diagonal entries are unchanged and every other entry decreases.
///
/// # Panics
/// Panics with `"scores dimension mismatch"` if `scores.len()` is not
/// `num_heads * seq_len * seq_len`. This includes the case where that product
/// overflows `usize`.
pub fn alibi_bias_scalar(scores: &mut [f32], num_heads: usize, seq_len: usize) {
    // Checked product: a size that wrapped around in a release build must not
    // be able to satisfy the length check by accident.
    let expected_len = num_heads
        .checked_mul(seq_len)
        .and_then(|n| n.checked_mul(seq_len));
    assert_eq!(
        Some(scores.len()),
        expected_len,
        "scores dimension mismatch"
    );
    if scores.is_empty() {
        // num_heads == 0 or seq_len == 0: nothing to do, and chunks_exact_mut(0)
        // would panic.
        return;
    }

    let head_stride = seq_len * seq_len;
    // Walking head and row slices lets the optimizer drop per-element bounds checks.
    for (h, head) in scores.chunks_exact_mut(head_stride).enumerate() {
        let slope = alibi_slope(h, num_heads);
        for (i, row) in head.chunks_exact_mut(seq_len).enumerate() {
            for (j, score) in row.iter_mut().enumerate() {
                *score -= slope * (i.abs_diff(j) as f32);
            }
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn alibi_bias_avx2(scores: &mut [f32], num_heads: usize, seq_len: usize) {
alibi_bias_scalar(scores, num_heads, seq_len);
}
/// Returns PTX source (ISA 8.5, `sm_90`) for the `alibi_kernel` entry point.
///
/// The kernel subtracts `2^(-8 * (head + 1) / num_heads) * |i - j|` from every
/// entry of a row-major `[num_heads, seq_len, seq_len]` f32 buffer in global
/// memory, mirroring `alibi_bias_scalar`. The slope uses `ex2.approx.f32`, so
/// GPU results may differ slightly from the CPU path.
///
/// Launch contract:
/// * `gridDim.x` = `num_heads`: one CTA per head; surplus CTAs exit without
///   touching memory.
/// * `blockDim.x` may be any value >= 1. Each CTA covers its head's `seq_len^2`
///   scores with a block-stride loop, so heads larger than one block are fully
///   processed.
/// * Parameters: `SCORES` (device pointer), `NUM_HEADS` (u32), `SEQ_LEN` (u32).
///
/// `seq_len^2` is computed in 32 bits, so `seq_len` must be below 65536.
/// Byte offsets into `SCORES` are computed in 64 bits.
pub fn alibi_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

.visible .entry alibi_kernel(
    .param .u64 SCORES,
    .param .u32 NUM_HEADS,
    .param .u32 SEQ_LEN
) {
    .reg .u32 %tx, %bdim, %head, %num_heads, %seq_len, %head_stride;
    .reg .u32 %idx, %i, %j, %dist, %head_p1;
    .reg .u64 %scores_ptr, %head_ptr, %byte_off, %addr;
    .reg .f32 %slope, %dist_f, %bias, %score, %exponent, %neg8, %h_f, %nh_f;
    .reg .pred %p_head_oob, %p_done, %p_ge;

    ld.param.u64 %scores_ptr, [SCORES];
    ld.param.u32 %num_heads, [NUM_HEADS];
    ld.param.u32 %seq_len, [SEQ_LEN];
    // Kernel pointer params are generic addresses; convert for ld/st.global.
    cvta.to.global.u64 %scores_ptr, %scores_ptr;

    mov.u32 %tx, %tid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %head, %ctaid.x;

    // CTAs beyond num_heads must not touch memory.
    setp.ge.u32 %p_head_oob, %head, %num_heads;
    @%p_head_oob bra EXIT;

    // slope = 2^(-8 * (head + 1) / num_heads)
    add.u32 %head_p1, %head, 1;
    cvt.rn.f32.u32 %h_f, %head_p1;
    cvt.rn.f32.u32 %nh_f, %num_heads;
    mov.f32 %neg8, 0fC1000000;
    mul.f32 %exponent, %neg8, %h_f;
    div.rn.f32 %exponent, %exponent, %nh_f;
    ex2.approx.f32 %slope, %exponent;

    // head_ptr = scores + head * seq_len^2 * 4, computed in 64 bits.
    mul.lo.u32 %head_stride, %seq_len, %seq_len;
    mul.wide.u32 %head_ptr, %head, %head_stride;
    shl.b64 %head_ptr, %head_ptr, 2;
    add.u64 %head_ptr, %scores_ptr, %head_ptr;

    // Block-stride loop over the flattened (i * seq_len + j) index.
    mov.u32 %idx, %tx;
LOOP:
    setp.ge.u32 %p_done, %idx, %head_stride;
    @%p_done bra EXIT;
    // i = idx / seq_len, j = idx % seq_len, dist = |i - j|
    div.u32 %i, %idx, %seq_len;
    rem.u32 %j, %idx, %seq_len;
    setp.ge.u32 %p_ge, %i, %j;
    @%p_ge bra DIST_FORWARD;
    sub.u32 %dist, %j, %i;
    bra APPLY_BIAS;
DIST_FORWARD:
    sub.u32 %dist, %i, %j;
APPLY_BIAS:
    cvt.rn.f32.u32 %dist_f, %dist;
    mul.f32 %bias, %slope, %dist_f;
    mul.wide.u32 %byte_off, %idx, 4;
    add.u64 %addr, %head_ptr, %byte_off;
    ld.global.f32 %score, [%addr];
    sub.f32 %score, %score, %bias;
    st.global.f32 [%addr], %score;
    add.u32 %idx, %idx, %bdim;
    bra LOOP;
EXIT:
    ret;
}
"#
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_alibi_slopes() {
        let slopes: Vec<f32> = (0..8).map(|h| alibi_slope(h, 8)).collect();
        assert!((slopes[0] - 0.5).abs() < 1e-6);
        assert!((slopes[1] - 0.25).abs() < 1e-6);
        assert!((slopes[7] - 1.0 / 256.0).abs() < 1e-6);
        for (k, pair) in slopes.windows(2).enumerate() {
            let i = k + 1;
            assert!(pair[1] < pair[0], "slopes not decreasing at {i}");
        }
    }

    #[test]
    fn test_alibi_diagonal_zero() {
        let seq_len = 4;
        let num_heads = 2;
        let mut scores = vec![1.0f32; num_heads * seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);
        for h in 0..num_heads {
            for i in 0..seq_len {
                let idx = h * seq_len * seq_len + i * seq_len + i;
                assert_eq!(
                    scores[idx], 1.0,
                    "diagonal should be unchanged at h={h} i={i}"
                );
            }
        }
    }

    #[test]
    fn test_alibi_negative_bias() {
        let seq_len = 3;
        let num_heads = 1;
        let mut scores = vec![0.0f32; seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);
        for i in 0..seq_len {
            for j in 0..seq_len {
                if i != j {
                    assert!(
                        scores[i * seq_len + j] < 0.0,
                        "off-diagonal [{i},{j}] should be negative, got {}",
                        scores[i * seq_len + j]
                    );
                }
            }
        }
    }

    #[test]
    fn test_alibi_symmetry() {
        let seq_len = 5;
        let num_heads = 2;
        let mut scores = vec![0.0f32; num_heads * seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);
        for h in 0..num_heads {
            let base = h * seq_len * seq_len;
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let a = scores[base + i * seq_len + j];
                    let b = scores[base + j * seq_len + i];
                    assert!(
                        (a - b).abs() < 1e-6,
                        "asymmetry at h={h} [{i},{j}]: {a} vs {b}"
                    );
                }
            }
        }
    }

    /// Starting from zero scores, every entry must equal `-slope(h) * |i - j|`.
    #[test]
    fn test_alibi_exact_values() {
        let num_heads = 2;
        let seq_len = 3;
        let head_stride = seq_len * seq_len;
        let mut scores = vec![0.0f32; num_heads * head_stride];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);
        // With 2 heads the slopes are 2^-4 and 2^-8.
        assert!((scores[2] + 2.0 / 16.0).abs() < 1e-7);
        assert!((scores[head_stride + 2] + 2.0 / 256.0).abs() < 1e-7);
        for h in 0..num_heads {
            let slope = alibi_slope(h, num_heads);
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let expected = -slope * (i.abs_diff(j) as f32);
                    let got = scores[h * head_stride + i * seq_len + j];
                    assert_eq!(got, expected, "h={h} [{i},{j}]");
                }
            }
        }
    }

    #[test]
    fn test_alibi_empty_shapes() {
        let mut empty: Vec<f32> = Vec::new();
        alibi_bias_scalar(&mut empty, 0, 4);
        alibi_bias_scalar(&mut empty, 3, 0);
        assert!(empty.is_empty());
    }

    #[test]
    #[should_panic(expected = "scores dimension mismatch")]
    fn test_alibi_dimension_mismatch_panics() {
        let mut scores = vec![0.0f32; 5];
        alibi_bias_scalar(&mut scores, 1, 2);
    }

    proptest! {
        #[test]
        fn prop_alibi_slopes_positive(num_heads in 1usize..17) {
            for h in 0..num_heads {
                let s = alibi_slope(h, num_heads);
                prop_assert!(s > 0.0, "slope must be positive, got {s} at h={h}");
                prop_assert!(s <= 1.0, "slope must be <= 1, got {s} at h={h}");
            }
        }

        #[test]
        fn prop_alibi_output_finite(
            num_heads in 1usize..5,
            seq_len in 1usize..8,
        ) {
            let mut scores = vec![0.0f32; num_heads * seq_len * seq_len];
            alibi_bias_scalar(&mut scores, num_heads, seq_len);
            for (idx, &val) in scores.iter().enumerate() {
                prop_assert!(val.is_finite(), "scores[{idx}] = {val} not finite");
            }
        }
    }

    #[test]
    fn test_alibi_ptx_structure() {
        let ptx = alibi_ptx();
        assert!(ptx.contains(".entry alibi_kernel"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("ret;"));
    }

    /// The AVX2 path must match the scalar path exactly. The shapes cover rows
    /// shorter than one 8-lane register, exactly one register, ragged tails,
    /// several registers, and empty shapes. Inputs are non-uniform so that
    /// misplaced lanes would be detected.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_alibi_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let shapes = [
            (2usize, 4usize),
            (1, 1),
            (3, 7),
            (2, 8),
            (2, 9),
            (1, 17),
            (2, 33),
            (0, 5),
            (4, 0),
        ];
        for &(num_heads, seq_len) in &shapes {
            let len = num_heads * seq_len * seq_len;
            let init: Vec<f32> = (0..len).map(|k| k as f32 * 0.25 - 3.0).collect();
            let mut scalar_scores = init.clone();
            let mut avx2_scores = init;
            alibi_bias_scalar(&mut scalar_scores, num_heads, seq_len);
            // SAFETY: AVX2 support was verified at runtime above.
            unsafe { alibi_bias_avx2(&mut avx2_scores, num_heads, seq_len) };
            assert_eq!(
                scalar_scores, avx2_scores,
                "mismatch for num_heads={num_heads} seq_len={seq_len}"
            );
        }
    }
}