use ndarray::{Array2, Array4, ArrayView4, Axis};
/// Parameters describing a rotary position embedding (RoPE).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RopeConfig {
    /// Size of each attention head's feature dimension; `RopeTables::new` requires it to be even.
    pub head_dim: usize,
    /// Base of the geometric frequency progression (`1 / base^(2k / head_dim)`).
    pub base: f32,
    /// `true` rotates adjacent pairs `(2k, 2k + 1)`; `false` rotates `(k, k + head_dim / 2)`.
    pub interleaved: bool,
}
impl RopeConfig {
    /// Number of rotation pairs / distinct frequencies, i.e. `head_dim / 2` (rounded down).
    #[inline]
    #[must_use]
    pub fn half(&self) -> usize {
        self.head_dim / 2
    }
}
/// Precomputed rotary cosine/sine tables, as produced by `RopeTables::new`.
///
/// Row `n` of `cos` / `sin` holds the values for sequence position `n`. When built by
/// `RopeTables::new`, the tables have `head_dim / 2` columns for the split-half layout
/// and `head_dim` columns (each value duplicated) for the interleaved layout.
#[derive(Clone, Debug)]
pub struct RopeTables {
    /// Configuration the tables were generated from.
    pub cfg: RopeConfig,
    /// Number of positions (rows) covered by `cos` and `sin`.
    pub seq_len: usize,
    /// Cosine of each position/frequency phase.
    pub cos: Array2<f32>,
    /// Sine of each position/frequency phase.
    pub sin: Array2<f32>,
}
impl RopeTables {
    /// Precomputes cosine/sine tables for positions `0..seq_len`.
    ///
    /// Frequency `k` (for `k in 0..head_dim / 2`) is `1 / base^(2k / head_dim)`, and the
    /// phase of position `n` at frequency `k` is `n * freq_k`, all in `f32`.
    ///
    /// * Split-half layout (`interleaved == false`): tables are `(seq_len, head_dim / 2)`.
    /// * Interleaved layout: tables are `(seq_len, head_dim)`, with the value for
    ///   frequency `k` stored at both column `2k` and column `2k + 1`.
    ///
    /// # Panics
    /// Panics if `cfg.head_dim` is odd.
    pub fn new(cfg: RopeConfig, seq_len: usize) -> Self {
        assert!(
            cfg.head_dim.is_multiple_of(2),
            "head_dim must be even for RoPE"
        );
        let dim = cfg.head_dim as f32;
        let inv_freq: Vec<f32> = (0..cfg.half())
            .map(|k| 1.0_f32 / cfg.base.powf((2 * k) as f32 / dim))
            .collect();
        // Phase (angle) for a given position and frequency index.
        let angle = |pos: usize, k: usize| pos as f32 * inv_freq[k];

        let (cos, sin) = if cfg.interleaved {
            // Columns 2k and 2k + 1 both map back to frequency k via integer halving.
            let width = cfg.head_dim;
            (
                Array2::from_shape_fn((seq_len, width), |(pos, j)| angle(pos, j / 2).cos()),
                Array2::from_shape_fn((seq_len, width), |(pos, j)| angle(pos, j / 2).sin()),
            )
        } else {
            let width = cfg.half();
            (
                Array2::from_shape_fn((seq_len, width), |(pos, k)| angle(pos, k).cos()),
                Array2::from_shape_fn((seq_len, width), |(pos, k)| angle(pos, k).sin()),
            )
        };

        Self {
            cfg,
            seq_len,
            cos,
            sin,
        }
    }
}
/// Reference (loop-based) application of rotary position embeddings to `t`.
///
/// `t` is 4-D. Axis 2 is the sequence position, whose row `pos` selects row `pos` of
/// `tables.cos` / `tables.sin`. Axis 3 is the head dimension. Axes 0 and 1 (presumably
/// batch and heads) are processed independently. Each pair of lanes `(lo, hi)` is
/// rotated as `lo' = lo*cos - hi*sin`, `hi' = hi*cos + lo*sin`, where the pairs are
/// `(2k, 2k + 1)` in the interleaved layout and `(k, k + head_dim / 2)` otherwise.
/// Any trailing lane beyond `2 * (head_dim / 2)` is copied through unchanged.
///
/// # Panics
/// Panics if the input's sequence length exceeds `tables.seq_len`, or if its last
/// dimension differs from `tables.cfg.head_dim`.
pub fn apply_rotary_emb_ref(t: &ArrayView4<f32>, tables: &RopeTables) -> Array4<f32> {
    let (b, h, seq, d) = t.dim();
    assert!(
        seq <= tables.seq_len,
        "input seq_len {seq} > table seq_len {}",
        tables.seq_len
    );
    assert_eq!(d, tables.cfg.head_dim, "input head_dim mismatch");
    let half = tables.cfg.half();
    let rotated_len = 2 * half;
    // Maps rotation pair `k` to (first lane, second lane, table column).
    let pair = |k: usize| -> (usize, usize, usize) {
        if tables.cfg.interleaved {
            (2 * k, 2 * k + 1, 2 * k)
        } else {
            (k, k + half, k)
        }
    };

    let mut out = Array4::<f32>::zeros((b, h, seq, d));
    for (mut out_batch, in_batch) in out.outer_iter_mut().zip(t.outer_iter()) {
        for (mut out_head, in_head) in out_batch.outer_iter_mut().zip(in_batch.outer_iter()) {
            let rows = out_head.outer_iter_mut().zip(in_head.outer_iter());
            for (pos, (mut dst, src)) in rows.enumerate() {
                for k in 0..half {
                    let (lo, hi, col) = pair(k);
                    let cos = tables.cos[(pos, col)];
                    let sin = tables.sin[(pos, col)];
                    let (first, second) = (src[lo], src[hi]);
                    dst[lo] = first * cos - second * sin;
                    dst[hi] = second * cos + first * sin;
                }
                // Pass through any lane that does not belong to a rotation pair.
                for j in rotated_len..d {
                    dst[j] = src[j];
                }
            }
        }
    }
    out
}
/// Returns the cosine table widened to one column per head-dimension lane.
///
/// In the interleaved layout the stored table is returned as a copy. Otherwise the
/// stored (half-width) table is concatenated with itself along the column axis, so the
/// second half repeats the first. This matches the `(k, k + head_dim / 2)` pairing of
/// the split-half layout.
pub fn cos_table_full(tables: &RopeTables) -> Array2<f32> {
    widen_table(tables.cfg.interleaved, &tables.cos)
}
/// Returns the sine table widened to one column per head-dimension lane.
///
/// Same layout rules as `cos_table_full`, applied to `tables.sin`.
pub fn sin_table_full(tables: &RopeTables) -> Array2<f32> {
    widen_table(tables.cfg.interleaved, &tables.sin)
}
/// Shared body of `cos_table_full` / `sin_table_full`: copy an interleaved table as-is,
/// or duplicate a split-half table's columns side by side.
fn widen_table(interleaved: bool, table: &Array2<f32>) -> Array2<f32> {
    if interleaved {
        table.clone()
    } else {
        ndarray::concatenate(Axis(1), &[table.view(), table.view()])
            .expect("concatenating a table with itself cannot mismatch row counts")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    #[test]
    fn lang_freqs_match_python_formula() {
        let cfg = RopeConfig {
            head_dim: 16,
            base: 100_000.0,
            interleaved: false,
        };
        let t = RopeTables::new(cfg, 1);
        for k in 0..cfg.half() {
            assert_abs_diff_eq!(t.cos[(0, k)], 1.0, epsilon = 1e-7);
            assert_abs_diff_eq!(t.sin[(0, k)], 0.0, epsilon = 1e-7);
        }
        let t2 = RopeTables::new(cfg, 2);
        assert_abs_diff_eq!(t2.cos[(1, 0)], 1.0_f32.cos(), epsilon = 1e-6);
        assert_abs_diff_eq!(t2.sin[(1, 0)], 1.0_f32.sin(), epsilon = 1e-6);
        let k_last = cfg.half() - 1;
        let f_last = 1.0_f32 / cfg.base.powf((2 * k_last) as f32 / cfg.head_dim as f32);
        assert_abs_diff_eq!(t2.cos[(1, k_last)], f_last.cos(), epsilon = 1e-6);
    }
    fn assert_arrays_close(a: &Array4<f32>, b: &Array4<f32>, eps: f32) {
        assert_eq!(a.shape(), b.shape());
        for (x, y) in a.iter().zip(b.iter()) {
            assert!(
                (x - y).abs() <= eps,
                "values differ: {} vs {} (eps {})",
                x,
                y,
                eps
            );
        }
    }
    #[test]
    fn identity_at_position_zero_non_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 100_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 4);
        let x = Array4::from_shape_fn((2, 3, 1, 8), |(b, h, _, d)| {
            (b as f32) * 0.1 + (h as f32) * 0.01 + (d as f32) * 0.001
        });
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        assert_arrays_close(&x, &y, 1e-6);
    }
    #[test]
    fn identity_at_position_zero_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 100_000.0,
            interleaved: true,
        };
        let tables = RopeTables::new(cfg, 4);
        let x = Array4::from_shape_fn((2, 3, 1, 8), |(b, h, _, d)| {
            (b as f32) * 0.1 + (h as f32) * 0.01 + (d as f32) * 0.001
        });
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        assert_arrays_close(&x, &y, 1e-6);
    }
    #[test]
    fn rotation_preserves_norm_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: true,
        };
        let tables = RopeTables::new(cfg, 5);
        let x = Array4::from_shape_fn((1, 2, 5, 8), |(_, _, t, d)| ((t * 8 + d + 1) as f32).sin());
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        for h in 0..2 {
            for t in 0..5 {
                for k in 0..cfg.half() {
                    let pre = x[(0, h, t, 2 * k)].powi(2) + x[(0, h, t, 2 * k + 1)].powi(2);
                    let post = y[(0, h, t, 2 * k)].powi(2) + y[(0, h, t, 2 * k + 1)].powi(2);
                    assert_abs_diff_eq!(pre, post, epsilon = 1e-5);
                }
            }
        }
    }
    #[test]
    fn rotation_preserves_norm_non_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 5);
        let x = Array4::from_shape_fn((1, 2, 5, 8), |(_, _, t, d)| ((t * 8 + d + 1) as f32).cos());
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        let half = cfg.half();
        for h in 0..2 {
            for t in 0..5 {
                for k in 0..half {
                    let pre = x[(0, h, t, k)].powi(2) + x[(0, h, t, k + half)].powi(2);
                    let post = y[(0, h, t, k)].powi(2) + y[(0, h, t, k + half)].powi(2);
                    assert_abs_diff_eq!(pre, post, epsilon = 1e-5);
                }
            }
        }
    }
    /// Pins the rotation direction and pairing, which the identity/norm tests above
    /// cannot detect (a sign flip or swapped pair still preserves the norm).
    #[test]
    fn rotation_matches_explicit_formula_non_interleaved() {
        let cfg = RopeConfig {
            head_dim: 4,
            base: 10_000.0,
            interleaved: false,
        };
        let seq = 3;
        let tables = RopeTables::new(cfg, seq);
        let x = Array4::from_shape_fn((1, 1, seq, 4), |(_, _, t, d)| (t * 4 + d + 1) as f32);
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        let half = cfg.half();
        for t in 0..seq {
            for k in 0..half {
                let freq = 1.0_f32 / cfg.base.powf((2 * k) as f32 / cfg.head_dim as f32);
                let theta = t as f32 * freq;
                let (lo, hi) = (x[(0, 0, t, k)], x[(0, 0, t, k + half)]);
                let want_lo = lo * theta.cos() - hi * theta.sin();
                let want_hi = hi * theta.cos() + lo * theta.sin();
                assert_abs_diff_eq!(y[(0, 0, t, k)], want_lo, epsilon = 1e-5);
                assert_abs_diff_eq!(y[(0, 0, t, k + half)], want_hi, epsilon = 1e-5);
            }
        }
    }
    /// The interleaved layout must behave exactly like the split-half layout after
    /// permuting lane `2k` -> `k` and lane `2k + 1` -> `k + half`.
    #[test]
    fn interleaved_matches_permuted_non_interleaved() {
        let base = 10_000.0;
        let (head_dim, seq) = (8, 4);
        let half = head_dim / 2;
        let inter = RopeTables::new(
            RopeConfig {
                head_dim,
                base,
                interleaved: true,
            },
            seq,
        );
        let split = RopeTables::new(
            RopeConfig {
                head_dim,
                base,
                interleaved: false,
            },
            seq,
        );
        let x = Array4::from_shape_fn((1, 2, seq, head_dim), |(_, h, t, d)| {
            ((h * 31 + t * 8 + d) as f32 * 0.37).sin()
        });
        let x_split = Array4::from_shape_fn((1, 2, seq, head_dim), |(b, h, t, d)| {
            if d < half {
                x[(b, h, t, 2 * d)]
            } else {
                x[(b, h, t, 2 * (d - half) + 1)]
            }
        });
        let y_inter = apply_rotary_emb_ref(&x.view(), &inter);
        let y_split = apply_rotary_emb_ref(&x_split.view(), &split);
        for h in 0..2 {
            for t in 0..seq {
                for k in 0..half {
                    assert_abs_diff_eq!(
                        y_inter[(0, h, t, 2 * k)],
                        y_split[(0, h, t, k)],
                        epsilon = 1e-6
                    );
                    assert_abs_diff_eq!(
                        y_inter[(0, h, t, 2 * k + 1)],
                        y_split[(0, h, t, k + half)],
                        epsilon = 1e-6
                    );
                }
            }
        }
    }
    #[test]
    fn full_cos_sin_tables_double_correctly() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: false,
        };
        let t = RopeTables::new(cfg, 3);
        let cos = cos_table_full(&t);
        let sin = sin_table_full(&t);
        assert_eq!(cos.shape(), &[3, 8]);
        assert_eq!(sin.shape(), &[3, 8]);
        for ti in 0..3 {
            for k in 0..4 {
                assert_eq!(cos[(ti, k)], cos[(ti, k + 4)]);
                assert_eq!(sin[(ti, k)], sin[(ti, k + 4)]);
            }
        }
    }
    #[test]
    fn full_tables_interleaved_are_returned_as_is() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: true,
        };
        let t = RopeTables::new(cfg, 3);
        assert_eq!(cos_table_full(&t), t.cos);
        assert_eq!(sin_table_full(&t), t.sin);
    }
    #[test]
    #[should_panic(expected = "head_dim must be even")]
    fn odd_head_dim_is_rejected() {
        let cfg = RopeConfig {
            head_dim: 5,
            base: 10_000.0,
            interleaved: false,
        };
        let _ = RopeTables::new(cfg, 1);
    }
    #[test]
    #[should_panic(expected = "table seq_len")]
    fn rejects_sequence_longer_than_table() {
        let cfg = RopeConfig {
            head_dim: 4,
            base: 10_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 2);
        let x = Array4::<f32>::zeros((1, 1, 3, 4));
        let _ = apply_rotary_emb_ref(&x.view(), &tables);
    }
}