use crate::layers::attention::{dot, softmax};
/// Per-head ALiBi slopes.
///
/// Head `h` (0-based) is assigned `slopes()[h]`; the bias generator multiplies it by
/// the query/key distance to form an additive attention bias.
#[derive(Debug, Clone, PartialEq)]
pub struct AliBiSlopes {
    slopes: Vec<f32>,
    num_heads: usize,
}

impl AliBiSlopes {
    /// Geometric slopes `r, r^2, ..., r^n` with `r = 2^(-8/n)` and `n = num_heads`.
    ///
    /// For 8 heads this yields `1/2, 1/4, ..., 1/256`. The same formula is applied to
    /// every head count; see [`AliBiSlopes::new_extrapolated`] for the reference
    /// treatment of non-power-of-two counts.
    ///
    /// # Panics
    /// Panics if `num_heads == 0`.
    pub fn new(num_heads: usize) -> Self {
        assert!(num_heads > 0, "num_heads must be > 0");
        Self {
            slopes: Self::geometric(num_heads),
            num_heads,
        }
    }

    /// Slopes following the ALiBi reference scheme for arbitrary head counts.
    ///
    /// A power-of-two `num_heads` gets exactly the slopes of [`AliBiSlopes::new`].
    /// Otherwise let `p` be the largest power of two below `num_heads`. The result is
    /// the `p`-head geometric series followed by every other term of the `2p`-head
    /// series, starting with its first term (powers 1, 3, 5, ...), until there are
    /// `num_heads` slopes in total.
    ///
    /// Those extra terms fall strictly between consecutive terms of the `p`-head
    /// series, so every head receives a distinct slope.
    ///
    /// # Panics
    /// Panics if `num_heads == 0`.
    pub fn new_extrapolated(num_heads: usize) -> Self {
        assert!(num_heads > 0, "num_heads must be > 0");
        if num_heads.is_power_of_two() {
            return Self::new(num_heads);
        }
        // Largest power of two strictly below `num_heads`; `num_heads - base < base`,
        // so the odd-power terms of the doubled series always suffice.
        let base = num_heads.next_power_of_two() / 2;
        let mut slopes = Self::geometric(base);
        slopes.extend(
            Self::geometric(2 * base)
                .into_iter()
                .step_by(2)
                .take(num_heads - base),
        );
        Self { slopes, num_heads }
    }

    /// Returns `[r, r^2, ..., r^n]` with `r = 2^(-8/n)`.
    fn geometric(n: usize) -> Vec<f32> {
        let ratio = 2.0_f32.powf(-8.0 / n as f32);
        (1..=n).map(|i| ratio.powi(i as i32)).collect()
    }

    /// All slopes, one per head, in head order.
    #[inline]
    pub fn slopes(&self) -> &[f32] {
        &self.slopes
    }

    /// Slope of head `head`.
    ///
    /// # Panics
    /// Panics if `head >= self.num_heads()`.
    #[inline]
    pub fn get(&self, head: usize) -> f32 {
        self.slopes[head]
    }

    /// Number of heads (equal to `self.slopes().len()`).
    #[inline]
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }
}
/// Generates ALiBi additive attention biases from per-head slopes.
pub struct AliBiBias {
    /// One slope per attention head.
    pub slopes: AliBiSlopes,
}

impl AliBiBias {
    /// Creates a bias generator backed by [`AliBiSlopes::new`].
    ///
    /// # Panics
    /// Panics if `num_heads == 0`.
    pub fn new(num_heads: usize) -> Self {
        Self {
            slopes: AliBiSlopes::new(num_heads),
        }
    }

    /// Bias between key position `k` and the query at `q_pos` for a head with `slope`.
    ///
    /// Uses the absolute distance so keys on either side of the query are penalised.
    /// With a signed distance, keys after the query would receive a positive bias
    /// that grows with distance, which favours far-future tokens whenever no causal
    /// mask is applied.
    #[inline]
    fn bias_at(slope: f32, q_pos: usize, k: usize) -> f32 {
        let distance = (q_pos as f32 - k as f32).abs();
        -slope * distance
    }

    /// Biases of head `head` for the query at `q_pos` against keys `0..kv_len`.
    ///
    /// Entry `k` is `-slope * |q_pos - k|`. It is `0` at `k == q_pos` and becomes more
    /// negative the farther `k` is from the query, in either direction.
    ///
    /// # Panics
    /// Panics if `head >= self.slopes.num_heads()`.
    pub fn bias_for_head(&self, head: usize, q_pos: usize, kv_len: usize) -> Vec<f32> {
        let slope = self.slopes.get(head);
        (0..kv_len)
            .map(|k| Self::bias_at(slope, q_pos, k))
            .collect()
    }

    /// Biases for every head at query position `q_pos`, shaped `[num_heads][kv_len]`.
    pub fn biases_all_heads(&self, q_pos: usize, kv_len: usize) -> Vec<Vec<f32>> {
        (0..self.slopes.num_heads())
            .map(|head| self.bias_for_head(head, q_pos, kv_len))
            .collect()
    }

    /// Adds the biases for the query at `q_pos` to `scores` in place.
    ///
    /// `scores[h]` holds head `h`'s raw scores against keys `0..scores[h].len()`, and
    /// each row is biased over its own full length. Rows beyond the number of heads
    /// are left unchanged. No intermediate bias buffers are allocated.
    pub fn apply(&self, scores: &mut [Vec<f32>], q_pos: usize) {
        for (row, &slope) in scores.iter_mut().zip(self.slopes.slopes()) {
            for (k, s) in row.iter_mut().enumerate() {
                *s += Self::bias_at(slope, q_pos, k);
            }
        }
    }

    /// Biases for `q_len` consecutive queries, the first at absolute position
    /// `q_offset`, shaped `[q_len][num_heads][kv_len]`.
    pub fn biases_for_sequence(
        &self,
        q_len: usize,
        kv_len: usize,
        q_offset: usize,
    ) -> Vec<Vec<Vec<f32>>> {
        (0..q_len)
            .map(|qi| self.biases_all_heads(q_offset + qi, kv_len))
            .collect()
    }
}
/// Settings consumed by [`attention_with_alibi`].
#[derive(Debug, Clone)]
pub struct AliBiConfig {
    /// Number of query heads; one ALiBi slope is generated per head.
    pub num_heads: usize,
    /// Use `AliBiSlopes::new_extrapolated` instead of `AliBiSlopes::new` for the slopes.
    pub use_extrapolated_slopes: bool,
    /// Mask keys positioned after the query (`k > q_pos`) with `-inf` before softmax.
    pub causal: bool,
}

impl Default for AliBiConfig {
    /// Eight heads, standard (non-extrapolated) slopes, causal masking enabled.
    fn default() -> Self {
        let (num_heads, use_extrapolated_slopes, causal) = (8, false, true);
        AliBiConfig {
            num_heads,
            use_extrapolated_slopes,
            causal,
        }
    }
}
/// Single-query multi-head attention with ALiBi position biases and grouped KV heads.
///
/// Layouts (flattened, as indexed below):
/// - `query`: `[num_heads][head_dim]`, head `h` at `h * head_dim..`.
/// - `keys` / `values`: `[kv_len][num_kv_heads][head_dim]`. `kv_len` is inferred as
///   `keys.len() / (num_kv_heads * head_dim)`, and is 0 when `head_dim == 0`.
///
/// Query head `h` reads KV head `h / (num_heads / num_kv_heads)`. Each score is
/// `dot(q, k) / sqrt(head_dim)` plus that head's ALiBi bias. When `config.causal` is
/// set, keys with position `> q_pos` are set to `-inf` before `softmax`.
///
/// Returns `[num_heads][head_dim]` flattened, head `h` at `h * head_dim..`.
///
/// # Panics
/// Panics in any of these cases:
/// - `num_kv_heads == 0`.
/// - `config.num_heads` is not a positive multiple of `num_kv_heads`.
/// - `query` is shorter than `num_heads * head_dim`.
/// - `values` is shorter than the `kv_len` rows implied by `keys`.
#[allow(clippy::too_many_arguments)]
pub fn attention_with_alibi(
    query: &[f32],
    keys: &[f32],
    values: &[f32],
    config: &AliBiConfig,
    head_dim: usize,
    num_kv_heads: usize,
    q_pos: usize,
) -> Vec<f32> {
    let num_heads = config.num_heads;
    // A zero or non-dividing KV head count would divide by zero below, or map query
    // heads onto KV slots that belong to the next time step.
    assert!(
        num_kv_heads > 0 && num_heads >= num_kv_heads && num_heads % num_kv_heads == 0,
        "num_heads ({num_heads}) must be a positive multiple of num_kv_heads ({num_kv_heads})"
    );
    debug_assert_eq!(query.len(), num_heads * head_dim);

    // Each time step stores `num_kv_heads` vectors of `head_dim` floats.
    let row_stride = num_kv_heads * head_dim;
    let kv_len = if row_stride > 0 {
        keys.len() / row_stride
    } else {
        0
    };
    let scale = 1.0_f32 / (head_dim as f32).sqrt();
    let heads_per_kv = num_heads / num_kv_heads;
    let alibi = AliBiBias {
        slopes: if config.use_extrapolated_slopes {
            AliBiSlopes::new_extrapolated(num_heads)
        } else {
            AliBiSlopes::new(num_heads)
        },
    };

    let mut output = vec![0.0_f32; num_heads * head_dim];
    for q_head in 0..num_heads {
        // Offset of this head's KV vector within each time-step row.
        let kv_offset = (q_head / heads_per_kv) * head_dim;
        let q_start = q_head * head_dim;
        let q_vec = &query[q_start..q_start + head_dim];

        let biases = alibi.bias_for_head(q_head, q_pos, kv_len);
        let mut scores: Vec<f32> = biases
            .iter()
            .enumerate()
            .map(|(t, &bias)| {
                if config.causal && t > q_pos {
                    // Future key: excluded from the softmax, so skip its dot product.
                    f32::NEG_INFINITY
                } else {
                    let k_start = t * row_stride + kv_offset;
                    dot(q_vec, &keys[k_start..k_start + head_dim]) * scale + bias
                }
            })
            .collect();
        softmax(&mut scores);

        // Weighted sum of value vectors. Iterating time steps in the outer loop reads
        // each value row contiguously and keeps the per-element summation order.
        let out = &mut output[q_start..q_start + head_dim];
        for (t, &weight) in scores.iter().enumerate() {
            let v_start = t * row_stride + kv_offset;
            for (o, &v) in out.iter_mut().zip(&values[v_start..v_start + head_dim]) {
                *o += weight * v;
            }
        }
    }
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alibi_slopes_power_of_2() {
        let s1 = AliBiSlopes::new(8);
        let s2 = AliBiSlopes::new_extrapolated(8);
        assert_eq!(s1.num_heads(), 8);
        assert_eq!(s2.num_heads(), 8);
        for i in 0..8 {
            assert!(
                (s1.get(i) - s2.get(i)).abs() < 1e-6,
                "slope mismatch at head {i}: {} vs {}",
                s1.get(i),
                s2.get(i)
            );
        }
    }

    #[test]
    fn test_alibi_slopes_8_heads() {
        let slopes = AliBiSlopes::new(8);
        assert_eq!(slopes.num_heads(), 8);
        assert_eq!(slopes.slopes().len(), 8);
        let expected_first = 0.5_f32;
        assert!(
            (slopes.get(0) - expected_first).abs() < 1e-5,
            "first slope: got {}, expected {}",
            slopes.get(0),
            expected_first
        );
        let expected_last = 0.5_f32.powi(8);
        assert!(
            (slopes.get(7) - expected_last).abs() < 1e-7,
            "last slope: got {}, expected {}",
            slopes.get(7),
            expected_last
        );
    }

    #[test]
    fn test_alibi_slopes_4_heads_exact() {
        // r = 2^(-8/4) = 1/4, so the slopes are 4^-1 .. 4^-4.
        let slopes = AliBiSlopes::new(4);
        let expected = [0.25_f32, 0.0625, 2.0_f32.powi(-6), 2.0_f32.powi(-8)];
        assert_eq!(slopes.slopes().len(), expected.len());
        for (i, (&got, &want)) in slopes.slopes().iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-7,
                "head {i}: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn test_alibi_slopes_decreasing() {
        let slopes = AliBiSlopes::new(16);
        for i in 1..16 {
            assert!(
                slopes.get(i) < slopes.get(i - 1),
                "slopes not strictly decreasing at index {i}: {} >= {}",
                slopes.get(i),
                slopes.get(i - 1)
            );
        }
    }

    #[test]
    fn test_alibi_bias_zero_distance() {
        let bias = AliBiBias::new(4);
        for head in 0..4 {
            let q_pos = 7;
            let biases = bias.bias_for_head(head, q_pos, q_pos + 1);
            let at_q = biases[q_pos];
            assert!(
                at_q.abs() < 1e-8,
                "head {head}: expected zero bias at q_pos={q_pos}, got {at_q}"
            );
        }
    }

    #[test]
    fn test_alibi_bias_increases_with_distance() {
        let bias = AliBiBias::new(4);
        let q_pos = 10;
        let kv_len = 11;
        for head in 0..4 {
            let biases = bias.bias_for_head(head, q_pos, kv_len);
            for k in 1..=q_pos {
                assert!(
                    biases[k] > biases[k - 1],
                    "head {head}: bias should increase with k, but biases[{k}]={} <= biases[{}]={}",
                    biases[k],
                    k - 1,
                    biases[k - 1]
                );
            }
        }
    }

    #[test]
    fn test_alibi_biases_all_heads_shape() {
        let num_heads = 6;
        let kv_len = 20;
        let bias = AliBiBias::new(num_heads);
        let all = bias.biases_all_heads(15, kv_len);
        assert_eq!(all.len(), num_heads, "outer dim must equal num_heads");
        for (h, row) in all.iter().enumerate() {
            assert_eq!(row.len(), kv_len, "head {h}: inner dim must equal kv_len");
        }
    }

    #[test]
    fn test_alibi_apply_modifies_scores() {
        let num_heads = 4;
        let kv_len = 5;
        let bias = AliBiBias::new(num_heads);
        let mut scores: Vec<Vec<f32>> = vec![vec![0.0_f32; kv_len]; num_heads];
        let q_pos = 4;
        bias.apply(&mut scores, q_pos);
        for (head, scores_head) in scores.iter().enumerate() {
            assert!(
                scores_head[q_pos].abs() < 1e-8,
                "head {head}: score at q_pos should be 0 after ALiBi, got {}",
                scores_head[q_pos]
            );
            for (k, &score_k) in scores_head[..q_pos].iter().enumerate() {
                assert!(
                    score_k < 0.0,
                    "head {head}: score at k={k} should be negative, got {}",
                    score_k
                );
            }
        }
    }

    #[test]
    fn test_alibi_biases_for_sequence_shape() {
        let num_heads = 4;
        let q_len = 3;
        let kv_len = 8;
        let q_offset = 5;
        let bias = AliBiBias::new(num_heads);
        let seq_biases = bias.biases_for_sequence(q_len, kv_len, q_offset);
        assert_eq!(seq_biases.len(), q_len, "outer dim must equal q_len");
        for (qi, head_biases) in seq_biases.iter().enumerate() {
            assert_eq!(
                head_biases.len(),
                num_heads,
                "q={qi}: second dim must equal num_heads"
            );
            for (h, kv_biases) in head_biases.iter().enumerate() {
                assert_eq!(
                    kv_biases.len(),
                    kv_len,
                    "q={qi} h={h}: inner dim must equal kv_len"
                );
            }
        }
    }

    #[test]
    fn test_alibi_biases_for_sequence_matches_per_position() {
        let bias = AliBiBias::new(4);
        let (q_len, kv_len, q_offset) = (3, 8, 5);
        let seq = bias.biases_for_sequence(q_len, kv_len, q_offset);
        for (qi, per_head) in seq.iter().enumerate() {
            assert_eq!(
                per_head,
                &bias.biases_all_heads(q_offset + qi, kv_len),
                "q={qi}: sequence biases must equal per-position biases"
            );
        }
    }

    #[test]
    fn test_attention_with_alibi_output_shape() {
        let num_heads = 4;
        let num_kv_heads = 2;
        let head_dim = 8;
        let kv_len = 5;
        let q_pos = 4;
        let query = vec![0.1_f32; num_heads * head_dim];
        let keys = vec![0.05_f32; kv_len * num_kv_heads * head_dim];
        let values = vec![0.2_f32; kv_len * num_kv_heads * head_dim];
        let config = AliBiConfig {
            num_heads,
            use_extrapolated_slopes: false,
            causal: true,
        };
        let output = attention_with_alibi(
            &query,
            &keys,
            &values,
            &config,
            head_dim,
            num_kv_heads,
            q_pos,
        );
        assert_eq!(
            output.len(),
            num_heads * head_dim,
            "output length must be num_heads * head_dim"
        );
        for (i, &v) in output.iter().enumerate() {
            assert!(v.is_finite(), "output[{i}] = {v} is not finite");
        }
    }

    #[test]
    fn test_attention_causal_first_position_copies_first_value() {
        // With q_pos = 0 and a causal mask only key 0 is visible, so each query head
        // must output its KV head's value vector at t = 0.
        let num_heads = 4;
        let num_kv_heads = 2;
        let head_dim = 3;
        let kv_len = 3;
        let q_pos = 0;
        let query = vec![1.0_f32; num_heads * head_dim];
        let keys = vec![0.5_f32; kv_len * num_kv_heads * head_dim];
        let values: Vec<f32> = (0..kv_len * num_kv_heads * head_dim)
            .map(|i| i as f32)
            .collect();
        let config = AliBiConfig {
            num_heads,
            use_extrapolated_slopes: false,
            causal: true,
        };
        let output = attention_with_alibi(
            &query,
            &keys,
            &values,
            &config,
            head_dim,
            num_kv_heads,
            q_pos,
        );
        let heads_per_kv = num_heads / num_kv_heads;
        for h in 0..num_heads {
            let kv_head = h / heads_per_kv;
            for d in 0..head_dim {
                let expected = values[kv_head * head_dim + d];
                let got = output[h * head_dim + d];
                assert!(
                    (got - expected).abs() < 1e-5,
                    "head {h} dim {d}: got {got}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn test_alibi_extrapolated_slopes() {
        let slopes = AliBiSlopes::new_extrapolated(12);
        assert_eq!(slopes.num_heads(), 12);
        assert_eq!(slopes.slopes().len(), 12);
        for (i, &s) in slopes.slopes().iter().enumerate() {
            assert!(
                s > 0.0 && s < 1.0,
                "extrapolated slope[{i}] = {s} out of (0,1)"
            );
        }
    }
}