use super::rope_common::validate_and_prepare;
use crate::error::{Error, Result};
use numr::autograd::{Var, var_add, var_cat, var_mul, var_narrow, var_sub};
use numr::ops::{ScalarOps, ShapeOps, TensorOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
pub fn apply_rope_impl<R, C>(
client: &C,
x: &Var<R>,
cos_cache: &Var<R>,
sin_cache: &Var<R>,
) -> Result<Var<R>>
where
R: Runtime<DType = numr::dtype::DType>,
C: RuntimeClient<R> + ScalarOps<R> + ShapeOps<R> + TypeConversionOps<R>,
R::Client: TensorOps<R> + ShapeOps<R> + TypeConversionOps<R>,
{
let (_shape, _seq_len, half_d, cos_reshaped, sin_reshaped) =
validate_and_prepare::<R, C>(client, x, cos_cache, sin_cache)?;
let x1 = var_narrow(x, -1, 0, half_d).map_err(Error::Numr)?;
let x2 = var_narrow(x, -1, half_d, half_d).map_err(Error::Numr)?;
let x1_cos = var_mul(&x1, &cos_reshaped, client).map_err(Error::Numr)?;
let x2_sin = var_mul(&x2, &sin_reshaped, client).map_err(Error::Numr)?;
let out1 = var_sub(&x1_cos, &x2_sin, client).map_err(Error::Numr)?;
let x1_sin = var_mul(&x1, &sin_reshaped, client).map_err(Error::Numr)?;
let x2_cos = var_mul(&x2, &cos_reshaped, client).map_err(Error::Numr)?;
let out2 = var_add(&x1_sin, &x2_cos, client).map_err(Error::Numr)?;
var_cat(&[&out1, &out2], -1, client).map_err(Error::Numr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;
    use numr::tensor::Tensor;

    /// Absolute tolerance for element-wise `f32` comparisons.
    const TOL: f32 = 1e-5;

    /// The output keeps the `[B, H, S, D]` shape of the input.
    #[test]
    fn test_rope_output_shape() {
        let (client, device) = cpu_setup();
        // Builds a `Var` (second `Var::new` argument `false`) from flat data.
        let leaf = |data: &[f32], shape: &[usize]| {
            Var::new(Tensor::<CpuRuntime>::from_slice(data, shape, &device), false)
        };

        let (batch, heads, seq, dim) = (1, 2, 4, 8);
        let half = dim / 2;

        let x = leaf(&vec![1.0f32; batch * heads * seq * dim], &[batch, heads, seq, dim]);
        let cos = leaf(&vec![1.0f32; seq * half], &[seq, half]);
        let sin = leaf(&vec![0.0f32; seq * half], &[seq, half]);

        let rotated = apply_rope_impl(&client, &x, &cos, &sin).unwrap();
        assert_eq!(rotated.tensor().shape(), &[batch, heads, seq, dim]);
    }

    /// With `cos = 1` and `sin = 0` the rotation is the identity, so the
    /// output must reproduce the input element for element.
    #[test]
    fn test_rope_identity_with_zero_angle() {
        let (client, device) = cpu_setup();
        let leaf = |data: &[f32], shape: &[usize]| {
            Var::new(Tensor::<CpuRuntime>::from_slice(data, shape, &device), false)
        };

        let x_data: Vec<f32> = (0..16).map(|n| n as f32).collect();
        let x = leaf(&x_data, &[1, 1, 2, 8]);
        let cos = leaf(&[1.0f32; 8], &[2, 4]);
        let sin = leaf(&[0.0f32; 8], &[2, 4]);

        let rotated = apply_rope_impl(&client, &x, &cos, &sin).unwrap();
        let out_data: Vec<f32> = rotated.tensor().contiguous().to_vec();

        let pairs = out_data.iter().copied().zip(x_data.iter().copied());
        for (idx, (got, want)) in pairs.enumerate() {
            assert!(
                (got - want).abs() < TOL,
                "mismatch at {}: got {}, expected {}",
                idx,
                got,
                want
            );
        }
    }

    /// With `cos = 0` and `sin = 1` the halves `[1, 2]` and `[3, 4]` map to
    /// `[-3, -4]` (`lo*cos - hi*sin`) and `[1, 2]` (`lo*sin + hi*cos`).
    #[test]
    fn test_rope_90_degree_rotation() {
        let (client, device) = cpu_setup();
        let leaf = |data: &[f32], shape: &[usize]| {
            Var::new(Tensor::<CpuRuntime>::from_slice(data, shape, &device), false)
        };

        let x = leaf(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 1, 4]);
        let cos = leaf(&[0.0f32, 0.0], &[1, 2]);
        let sin = leaf(&[1.0f32, 1.0], &[1, 2]);

        let rotated = apply_rope_impl(&client, &x, &cos, &sin).unwrap();
        let out_data: Vec<f32> = rotated.tensor().contiguous().to_vec();

        let expected = [-3.0f32, -4.0, 1.0, 2.0];
        // Index explicitly so a too-short output still fails the test.
        for (idx, want) in expected.iter().enumerate() {
            assert!((out_data[idx] - want).abs() < TOL);
        }
    }

    /// A cache with more rows than the sequence length (8 rows for `S = 2`)
    /// must give the same result as a cache with exactly `S` rows.
    ///
    /// NOTE(review): both caches hold the same identity values (`cos = 1`,
    /// `sin = 0`) in every row, so this checks that an oversized cache is
    /// accepted but cannot detect which rows are selected from it.
    #[test]
    fn test_rope_narrowing_matches_exact_cache() {
        let (client, device) = cpu_setup();
        let leaf = |data: &[f32], shape: &[usize]| {
            Var::new(Tensor::<CpuRuntime>::from_slice(data, shape, &device), false)
        };

        let x_data: Vec<f32> = (0..16).map(|n| n as f32 * 0.1).collect();
        let x = leaf(&x_data, &[1, 1, 2, 8]);

        // Caches sized exactly to the sequence length.
        let cos_exact = leaf(&[1.0f32; 8], &[2, 4]);
        let sin_exact = leaf(&[0.0f32; 8], &[2, 4]);
        // Caches with more rows than the sequence length.
        let cos_large = leaf(&[1.0f32; 32], &[8, 4]);
        let sin_large = leaf(&[0.0f32; 32], &[8, 4]);

        let out_exact = apply_rope_impl(&client, &x, &cos_exact, &sin_exact).unwrap();
        let out_narrowed = apply_rope_impl(&client, &x, &cos_large, &sin_large).unwrap();

        let exact_data: Vec<f32> = out_exact.tensor().contiguous().to_vec();
        let narrowed_data: Vec<f32> = out_narrowed.tensor().contiguous().to_vec();
        assert_eq!(exact_data.len(), narrowed_data.len());

        let pairs = exact_data.iter().copied().zip(narrowed_data.iter().copied());
        for (i, (a, b)) in pairs.enumerate() {
            assert!(
                (a - b).abs() < TOL,
                "mismatch at {i}: exact={a}, narrowed={b}"
            );
        }
    }

    /// An odd last dimension (here `D = 3`) must be reported as an error.
    #[test]
    fn test_rope_invalid_odd_dim() {
        let (client, device) = cpu_setup();
        let leaf = |data: &[f32], shape: &[usize]| {
            Var::new(Tensor::<CpuRuntime>::from_slice(data, shape, &device), false)
        };

        let x = leaf(&[1.0f32; 3], &[1, 1, 1, 3]);
        let cos = leaf(&[1.0f32], &[1, 1]);
        let sin = leaf(&[0.0f32], &[1, 1]);

        assert!(apply_rope_impl(&client, &x, &cos, &sin).is_err());
    }
}