#[cfg(test)]
mod tests {
    use super::*;

    // Layout note: these tests read `bias.data()` as a flat row-major buffer
    // whose shape is `[seq_len, seq_len, num_heads]` (see the shape asserts),
    // so element (i, j, h) lives at `(i * seq_len + j) * num_heads + h`.
    // Index helpers are kept local to each test so they cannot clash with
    // anything brought in by the `include!` at the bottom of this module.

    /// `get_bias(n)` yields an `n x n x num_heads` tensor.
    #[test]
    fn test_alibi_get_bias_shape() {
        let alibi = ALiBi::new(4).expect("alibi");
        let bias = alibi.get_bias(8).expect("bias");
        assert_eq!(bias.shape(), &[8, 8, 4]);
    }

    /// A zero-length sequence is rejected with a descriptive error.
    #[test]
    fn test_alibi_get_bias_zero_seq_error() {
        let alibi = ALiBi::new(4).expect("alibi");
        let err = alibi
            .get_bias(0)
            .expect_err("get_bias(0) must be rejected");
        let msg = err.to_string();
        assert!(
            msg.contains("seq_len must be > 0"),
            "unexpected error message: {}",
            msg
        );
    }

    /// Distance zero means zero bias, for every head (not only head 0).
    #[test]
    fn test_alibi_get_bias_diagonal_is_zero() {
        let alibi = ALiBi::new(2).expect("alibi");
        let bias = alibi.get_bias(3).expect("bias");
        let data = bias.data();
        let (seq_len, num_heads) = (3usize, 2usize);
        for i in 0..seq_len {
            for h in 0..num_heads {
                let idx = (i * seq_len + i) * num_heads + h;
                assert!(
                    data[idx].abs() < 1e-6,
                    "bias[{}, {}, {}] = {}, expected 0",
                    i,
                    i,
                    h,
                    data[idx]
                );
            }
        }
    }

    /// With a single head, every entry equals `-slope * |i - j|`.
    #[test]
    fn test_alibi_get_bias_values() {
        let alibi = ALiBi::new(1).expect("alibi");
        let bias = alibi.get_bias(3).expect("bias");
        let data = bias.data();
        let slope = alibi.slopes()[0];
        // |i - j| for a 3-token sequence; float literals so the element type
        // is inferred from `slope` rather than fixed here.
        let distances = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]];
        for (i, row) in distances.iter().enumerate() {
            for (j, &dist) in row.iter().enumerate() {
                let expected = -slope * dist;
                // num_heads == 1, so the flat index is just i * seq_len + j.
                let got = data[i * 3 + j];
                assert!(
                    (got - expected).abs() < 1e-6,
                    "bias[{}, {}, 0] = {}, expected {}",
                    i,
                    j,
                    got,
                    expected
                );
            }
        }
    }

    /// bias[i, j, h] == bias[j, i, h] for every pair and head.
    #[test]
    fn test_alibi_bias_symmetry() {
        let alibi = ALiBi::new(2).expect("alibi");
        let bias = alibi.get_bias(4).expect("bias");
        let data = bias.data();
        let (seq_len, num_heads) = (4usize, 2usize);
        let at = |i: usize, j: usize, h: usize| data[(i * seq_len + j) * num_heads + h];
        for i in 0..seq_len {
            for j in 0..seq_len {
                for h in 0..num_heads {
                    assert!(
                        (at(i, j, h) - at(j, i, h)).abs() < 1e-6,
                        "bias not symmetric at ({}, {}, head {})",
                        i,
                        j,
                        h
                    );
                }
            }
        }
    }

    /// A clone reports the same head count and slopes as the original.
    #[test]
    fn test_alibi_clone() {
        let alibi = ALiBi::new(4).expect("alibi");
        let cloned = alibi.clone();
        assert_eq!(alibi.num_heads(), cloned.num_heads());
        assert_eq!(alibi.slopes(), cloned.slopes());
    }

    /// Long sequences keep the right shape and the maximal-distance values.
    #[test]
    fn test_alibi_large_seq_len() {
        let alibi = ALiBi::new(8).expect("alibi");
        let bias = alibi.get_bias(256).expect("bias");
        assert_eq!(bias.shape(), &[256, 256, 8]);
        let data = bias.data();
        let (seq_len, num_heads) = (256usize, 8usize);
        let expected = -alibi.slopes()[0] * 255.0;
        // Head 0 at (0, 255) and its mirror (255, 0): the largest distance.
        let corners = [
            (seq_len - 1) * num_heads,
            (seq_len - 1) * seq_len * num_heads,
        ];
        for &idx in &corners {
            assert!(
                (data[idx] - expected).abs() < 1e-4,
                "bias at flat index {} = {}, expected {}",
                idx,
                data[idx],
                expected
            );
        }
    }

    /// A head count that is not a power of two still yields one positive
    /// slope per head.
    #[test]
    fn test_alibi_non_power_of_two_slopes() {
        let alibi = ALiBi::new(6).expect("alibi");
        let slopes = alibi.slopes();
        assert_eq!(slopes.len(), 6);
        // Pinned values for heads 0 and 3 (0.015625 == 2^-6). Heads 4 and 5 —
        // presumably the extras past the nearest lower power of two (4) — are
        // only required to be positive here.
        assert!((slopes[0] - 1.0).abs() < 1e-6);
        assert!((slopes[3] - 0.015625).abs() < 1e-6);
        assert!(slopes[4] > 0.0);
        assert!(slopes[5] > 0.0);
    }

    // Textually pulls further tests into this module so they share the
    // `use super::*` scope above (presumably the RoPE position tests).
    include!("position_rope.rs");
}