use ndarray::{s, Array2, Array3, ArrayD, ArrayView1, IxDyn};
use std::fmt;
/// Errors produced by the position-embedding types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PositionError {
    /// RoPE rotates dimensions in pairs, so `head_dim` must be even.
    HeadDimMustBeEven { head_dim: usize },
    /// The requested positions run past the pre-computed cache.
    ///
    /// `offset` is the last requested position and `max` the last position the cache covers.
    SeqOffsetOutOfRange { offset: usize, max: usize },
    /// An input had an unexpected shape (`expected` vs. the `got` that was supplied).
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },
}
impl fmt::Display for PositionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::HeadDimMustBeEven { head_dim } => {
write!(f, "head_dim must be even for RoPE, got {}", head_dim)
}
Self::SeqOffsetOutOfRange { offset, max } => {
write!(
f,
"seq_offset {} is out of range (max pre-computed = {})",
offset, max
)
}
Self::ShapeMismatch { expected, got } => {
write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
}
}
}
}
impl std::error::Error for PositionError {}
/// Rotary position embedding (RoPE) with pre-computed cos/sin tables.
///
/// NOTE(review): `head_dim`, `base` and `max_seq_len` are public, but the caches are only
/// built in `new`; changing these fields afterwards leaves the caches stale. Build a new
/// instance instead of mutating them.
#[derive(Debug, Clone)]
pub struct RotaryPositionEmbedding {
    /// Length of the last axis of the tensors this embedding rotates (checked even in `new`).
    pub head_dim: usize,
    /// Frequency base: pair `i` uses `theta_i = base^(-2i / head_dim)`.
    pub base: f64,
    /// Number of positions (`0..max_seq_len`) held in the caches.
    pub max_seq_len: usize,
    /// `cos(pos * theta_i)`, shape `[max_seq_len, head_dim / 2]`.
    cos_cache: Array2<f64>,
    /// `sin(pos * theta_i)`, shape `[max_seq_len, head_dim / 2]`.
    sin_cache: Array2<f64>,
}
impl RotaryPositionEmbedding {
pub fn new(
head_dim: usize,
max_seq_len: usize,
base: f64,
) -> std::result::Result<Self, PositionError> {
if !head_dim.is_multiple_of(2) {
return Err(PositionError::HeadDimMustBeEven { head_dim });
}
let (cos_cache, sin_cache) = Self::build_cos_sin_cache(head_dim, max_seq_len, base);
Ok(Self {
head_dim,
base,
max_seq_len,
cos_cache,
sin_cache,
})
}
fn build_cos_sin_cache(
head_dim: usize,
max_seq_len: usize,
base: f64,
) -> (Array2<f64>, Array2<f64>) {
let half_dim = head_dim / 2;
let thetas: Vec<f64> = (0..half_dim)
.map(|i| base.powf(-(2.0 * i as f64) / head_dim as f64))
.collect();
let mut cos_cache = Array2::<f64>::zeros((max_seq_len, half_dim));
let mut sin_cache = Array2::<f64>::zeros((max_seq_len, half_dim));
for pos in 0..max_seq_len {
for (i, &theta) in thetas.iter().enumerate() {
let angle = pos as f64 * theta;
cos_cache[[pos, i]] = angle.cos();
sin_cache[[pos, i]] = angle.sin();
}
}
(cos_cache, sin_cache)
}
pub fn apply(
&self,
x: &ArrayD<f64>,
seq_offset: usize,
) -> std::result::Result<ArrayD<f64>, PositionError> {
let shape = x.shape();
let ndim = shape.len();
if ndim < 1 {
return Err(PositionError::ShapeMismatch {
expected: vec![1],
got: shape.to_vec(),
});
}
let last_dim = shape[ndim - 1];
if last_dim != self.head_dim {
return Err(PositionError::ShapeMismatch {
expected: vec![self.head_dim],
got: vec![last_dim],
});
}
let seq_len = shape[0];
if seq_offset + seq_len > self.max_seq_len {
return Err(PositionError::SeqOffsetOutOfRange {
offset: seq_offset + seq_len - 1,
max: self.max_seq_len - 1,
});
}
let half_dim = self.head_dim / 2;
let total = x.len() / self.head_dim;
let x2 = x
.view()
.into_shape_with_order((total, self.head_dim))
.map_err(|_| PositionError::ShapeMismatch {
expected: vec![total, self.head_dim],
got: shape.to_vec(),
})?;
let x_first = x2.slice(s![.., ..half_dim]).to_owned();
let x_second = x2.slice(s![.., half_dim..]).to_owned();
let mut rotated = Array2::<f64>::zeros((total, self.head_dim));
rotated.slice_mut(s![.., ..half_dim]).assign(&(-&x_second));
rotated.slice_mut(s![.., half_dim..]).assign(&x_first);
let positions_per_token = total.checked_div(seq_len).unwrap_or(1);
let mut cos_expanded = Array2::<f64>::zeros((total, half_dim));
let mut sin_expanded = Array2::<f64>::zeros((total, half_dim));
for i in 0..total {
let pos = seq_offset + i / positions_per_token.max(1);
let capped_pos = pos.min(self.max_seq_len - 1);
cos_expanded
.slice_mut(s![i, ..])
.assign(&self.cos_cache.slice(s![capped_pos, ..]));
sin_expanded
.slice_mut(s![i, ..])
.assign(&self.sin_cache.slice(s![capped_pos, ..]));
}
let mut cos_full = Array2::<f64>::zeros((total, self.head_dim));
let mut sin_full = Array2::<f64>::zeros((total, self.head_dim));
cos_full.slice_mut(s![.., ..half_dim]).assign(&cos_expanded);
cos_full.slice_mut(s![.., half_dim..]).assign(&cos_expanded);
sin_full.slice_mut(s![.., ..half_dim]).assign(&sin_expanded);
sin_full.slice_mut(s![.., half_dim..]).assign(&sin_expanded);
let result2 = &x2 * &cos_full + &rotated * &sin_full;
let result = result2
.into_dyn()
.into_shape_with_order(IxDyn(shape))
.map_err(|_| PositionError::ShapeMismatch {
expected: shape.to_vec(),
got: vec![total, self.head_dim],
})?;
Ok(result)
}
pub fn rotate_half(x: &ArrayD<f64>) -> ArrayD<f64> {
let shape = x.shape();
let ndim = shape.len();
if ndim < 1 {
return x.to_owned();
}
let head_dim = shape[ndim - 1];
let half = head_dim / 2;
let total = x.len() / head_dim;
let x2 = x
.view()
.into_shape_with_order((total, head_dim))
.expect("rotate_half reshape");
let x_first = x2.slice(s![.., ..half]).to_owned();
let x_second = x2.slice(s![.., half..]).to_owned();
let mut out = Array2::<f64>::zeros((total, head_dim));
out.slice_mut(s![.., ..half]).assign(&(-&x_second));
out.slice_mut(s![.., half..]).assign(&x_first);
out.into_dyn()
.into_shape_with_order(IxDyn(shape))
.expect("rotate_half final reshape")
}
pub fn frequencies_at(&self, pos: usize) -> ArrayView1<'_, f64> {
let capped = pos.min(self.max_seq_len - 1);
self.cos_cache.slice(s![capped, ..])
}
}
/// Learned, bucketed relative position bias for attention.
///
/// Relative offsets are mapped to buckets (exact buckets for short distances, log-spaced
/// buckets up to `max_distance`, mirroring the T5 scheme), and each (bucket, head) pair
/// holds one scalar.
#[derive(Debug, Clone)]
pub struct RelativePositionBias {
    /// Number of attention heads (columns of the bias table).
    pub num_heads: usize,
    /// Number of relative-position buckets (rows of the bias table).
    pub num_buckets: usize,
    /// Distance at which the log-spaced bucketing saturates.
    pub max_distance: usize,
    /// When true, offsets in each direction get their own half of the buckets.
    pub bidirectional: bool,
    /// Bias table of shape `[num_buckets, num_heads]`; zero in `new`, replaced via `update_biases`.
    biases: Array2<f64>,
}
impl RelativePositionBias {
pub fn new(
num_heads: usize,
num_buckets: usize,
max_distance: usize,
bidirectional: bool,
) -> Self {
Self {
num_heads,
num_buckets,
max_distance,
bidirectional,
biases: Array2::<f64>::zeros((num_buckets, num_heads)),
}
}
pub fn compute_bias(&self, query_len: usize, key_len: usize) -> Array3<f64> {
let mut bias = Array3::<f64>::zeros((self.num_heads, query_len, key_len));
for q in 0..query_len {
for k in 0..key_len {
let relative_position = q as i32 - k as i32;
let bucket = Self::relative_position_bucket(
relative_position,
self.bidirectional,
self.num_buckets,
self.max_distance,
);
for h in 0..self.num_heads {
bias[[h, q, k]] = self.biases[[bucket, h]];
}
}
}
bias
}
fn relative_position_bucket(
relative_position: i32,
bidirectional: bool,
num_buckets: usize,
max_distance: usize,
) -> usize {
let mut n = num_buckets;
let mut relative = relative_position;
if bidirectional {
n /= 2;
if relative_position > 0 {
let pos_bucket =
Self::distance_to_bucket(relative_position as usize, n, max_distance);
return (n + pos_bucket).min(num_buckets - 1);
}
relative = -relative;
} else {
relative = (-relative).max(0);
}
let distance = relative as usize;
Self::distance_to_bucket(distance, n, max_distance).min(num_buckets - 1)
}
fn distance_to_bucket(distance: usize, n: usize, max_distance: usize) -> usize {
if n == 0 {
return 0;
}
let max_exact = n / 2;
if distance < max_exact {
distance
} else {
let clamped = distance.min(max_distance);
let scale = (clamped as f64 / max_exact as f64).ln()
/ (max_distance as f64 / max_exact as f64).ln().max(1e-10);
let bucket_offset = (scale * (n - max_exact) as f64) as usize;
(max_exact + bucket_offset).min(n - 1)
}
}
pub fn update_biases(
&mut self,
new_biases: Array2<f64>,
) -> std::result::Result<(), PositionError> {
let expected = vec![self.num_buckets, self.num_heads];
let got = new_biases.shape().to_vec();
if got != expected {
return Err(PositionError::ShapeMismatch { expected, got });
}
self.biases = new_biases;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor of `shape` with every element equal to `fill`.
    fn make_tensor(shape: &[usize], fill: f64) -> ArrayD<f64> {
        ArrayD::from_elem(IxDyn(shape), fill)
    }

    #[test]
    fn test_rope_new_builds_cache() {
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid head_dim");
        assert_eq!(
            rope.cos_cache.shape(),
            &[16, 4],
            "cos_cache shape [max_seq, half_dim]"
        );
        assert_eq!(
            rope.sin_cache.shape(),
            &[16, 4],
            "sin_cache shape [max_seq, half_dim]"
        );
    }

    #[test]
    fn test_rope_apply_preserves_shape() {
        let rope = RotaryPositionEmbedding::new(8, 32, 10000.0).expect("valid");
        let x = make_tensor(&[4, 8], 1.0);
        let result = rope.apply(&x, 0).expect("apply should succeed");
        assert_eq!(
            result.shape(),
            x.shape(),
            "output shape must match input shape"
        );
    }

    #[test]
    fn test_rope_apply_at_position_zero_is_identity() {
        // cos(0) = 1 and sin(0) = 0, so position 0 must leave the vector untouched.
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid");
        let data: Vec<f64> = (1..=8).map(f64::from).collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 8]), data).expect("build");
        let out = rope.apply(&x, 0).expect("in range");
        for (got, want) in out.iter().zip(x.iter()) {
            assert!((got - want).abs() < 1e-12, "position 0 must be the identity");
        }
    }

    #[test]
    fn test_rope_apply_rotates_pair_by_position_angle() {
        // With head_dim = 2 the only frequency is base^0 = 1, so position p rotates by p rad.
        let rope = RotaryPositionEmbedding::new(2, 4, 10000.0).expect("valid");
        let x = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 0.0, 1.0, 0.0]).expect("build");
        let out: Vec<f64> = rope
            .apply(&x, 0)
            .expect("in range")
            .iter()
            .copied()
            .collect();
        assert!(
            (out[0] - 1.0).abs() < 1e-12 && out[1].abs() < 1e-12,
            "position 0 unchanged"
        );
        assert!(
            (out[2] - 1.0_f64.cos()).abs() < 1e-12,
            "position 1 first half is cos(1)"
        );
        assert!(
            (out[3] - 1.0_f64.sin()).abs() < 1e-12,
            "position 1 second half is sin(1)"
        );
    }

    #[test]
    fn test_rope_apply_out_of_range_errors() {
        let rope = RotaryPositionEmbedding::new(8, 32, 10000.0).expect("valid");
        let x = make_tensor(&[4, 8], 1.0);
        let err = rope
            .apply(&x, 30)
            .expect_err("positions 30..34 exceed the 32-entry cache");
        assert!(
            matches!(
                err,
                PositionError::SeqOffsetOutOfRange {
                    offset: 33,
                    max: 31
                }
            ),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_rope_apply_wrong_head_dim_errors() {
        let rope = RotaryPositionEmbedding::new(8, 32, 10000.0).expect("valid");
        let err = rope
            .apply(&make_tensor(&[2, 6], 0.0), 0)
            .expect_err("last axis is 6, not 8");
        assert!(
            matches!(err, PositionError::ShapeMismatch { .. }),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_rope_frequencies_at_position_zero_are_ones() {
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid");
        let row = rope.frequencies_at(0);
        assert_eq!(row.len(), 4, "one entry per rotated pair");
        assert!(row.iter().all(|&c| (c - 1.0).abs() < 1e-12), "cos(0) = 1");
    }

    #[test]
    fn test_rope_rotate_half_correct() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 4]), data).expect("build");
        let rotated = RotaryPositionEmbedding::rotate_half(&x);
        let flat: Vec<f64> = rotated.iter().copied().collect();
        assert!(
            (flat[0] - (-3.0)).abs() < 1e-9,
            "first element should be -3"
        );
        assert!(
            (flat[1] - (-4.0)).abs() < 1e-9,
            "second element should be -4"
        );
        assert!((flat[2] - 1.0).abs() < 1e-9, "third element should be 1");
        assert!((flat[3] - 2.0).abs() < 1e-9, "fourth element should be 2");
    }

    #[test]
    fn test_rope_head_dim_odd_errors() {
        let result = RotaryPositionEmbedding::new(7, 16, 10000.0);
        assert!(
            matches!(result, Err(PositionError::HeadDimMustBeEven { .. })),
            "odd head_dim should produce HeadDimMustBeEven error"
        );
    }

    #[test]
    fn test_relative_position_bias_compute() {
        let rpb = RelativePositionBias::new(4, 32, 128, true);
        let bias = rpb.compute_bias(6, 10);
        assert_eq!(
            bias.shape(),
            &[4, 6, 10],
            "bias shape must be [num_heads, q_len, k_len]"
        );
    }

    #[test]
    fn test_relative_position_bias_directions_use_distinct_buckets() {
        let forward_bucket = RelativePositionBias::relative_position_bucket(5, true, 32, 64);
        let backward_bucket = RelativePositionBias::relative_position_bucket(-5, true, 32, 64);
        assert_ne!(
            forward_bucket, backward_bucket,
            "forward and backward positions should map to different buckets"
        );
    }

    #[test]
    fn test_relative_position_bucket_zero_distance_is_bucket_zero() {
        assert_eq!(
            RelativePositionBias::relative_position_bucket(0, true, 32, 128),
            0
        );
        assert_eq!(
            RelativePositionBias::relative_position_bucket(0, false, 32, 128),
            0
        );
    }

    #[test]
    fn test_relative_position_bucket_clamping() {
        let bucket = RelativePositionBias::relative_position_bucket(100000, false, 16, 128);
        assert!(bucket < 16, "bucket must be within [0, num_buckets)");
    }

    #[test]
    fn test_update_biases_checks_shape_and_feeds_compute_bias() {
        let mut rpb = RelativePositionBias::new(2, 8, 32, true);
        let err = rpb
            .update_biases(Array2::zeros((8, 3)))
            .expect_err("wrong head count");
        assert!(matches!(err, PositionError::ShapeMismatch { .. }));
        rpb.update_biases(Array2::from_elem((8, 2), 0.5))
            .expect("matching shape");
        let bias = rpb.compute_bias(3, 3);
        assert!(
            bias.iter().all(|&v| (v - 0.5).abs() < 1e-12),
            "every bucket now holds 0.5"
        );
    }
}