use super::rope_common::validate_and_prepare;
use crate::error::{Error, Result};
use numr::autograd::{Var, var_add, var_cat, var_mul, var_narrow, var_sub};
use numr::ops::{ScalarOps, ShapeOps, TensorOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
pub fn apply_rope_interleaved_impl<R, C>(
client: &C,
x: &Var<R>,
cos_cache: &Var<R>,
sin_cache: &Var<R>,
) -> Result<Var<R>>
where
R: Runtime<DType = numr::dtype::DType>,
C: RuntimeClient<R> + ScalarOps<R> + ShapeOps<R> + TypeConversionOps<R>,
R::Client: TensorOps<R> + ShapeOps<R> + TypeConversionOps<R>,
{
let (shape, seq_len, half_d, cos_reshaped, sin_reshaped) =
validate_and_prepare::<R, C>(client, x, cos_cache, sin_cache)?;
let b = shape[0];
let h = shape[1];
let d = shape[3];
let total_bhsd = b * h * seq_len;
let x_flat = numr::autograd::var_reshape(x, &[total_bhsd, half_d, 2]).map_err(Error::Numr)?;
let x_even_3d = var_narrow(&x_flat, -1, 0, 1).map_err(Error::Numr)?; let x_odd_3d = var_narrow(&x_flat, -1, 1, 1).map_err(Error::Numr)?;
let zero_shape = &[total_bhsd, half_d, 1];
let zero = Var::new(
numr::tensor::Tensor::<R>::zeros(zero_shape, x.tensor().dtype(), x.tensor().device()),
false,
);
let x_even_contig = var_add(&x_even_3d, &zero, client).map_err(Error::Numr)?;
let x_odd_contig = var_add(&x_odd_3d, &zero, client).map_err(Error::Numr)?;
let x_even = numr::autograd::var_reshape(&x_even_contig, &[b, h, seq_len, half_d])
.map_err(Error::Numr)?;
let x_odd = numr::autograd::var_reshape(&x_odd_contig, &[b, h, seq_len, half_d])
.map_err(Error::Numr)?;
let even_cos = var_mul(&x_even, &cos_reshaped, client).map_err(Error::Numr)?;
let odd_sin = var_mul(&x_odd, &sin_reshaped, client).map_err(Error::Numr)?;
let out_even = var_sub(&even_cos, &odd_sin, client).map_err(Error::Numr)?;
let even_sin = var_mul(&x_even, &sin_reshaped, client).map_err(Error::Numr)?;
let odd_cos = var_mul(&x_odd, &cos_reshaped, client).map_err(Error::Numr)?;
let out_odd = var_add(&even_sin, &odd_cos, client).map_err(Error::Numr)?;
let out_even_3d =
numr::autograd::var_reshape(&out_even, &[total_bhsd, half_d, 1]).map_err(Error::Numr)?;
let out_odd_3d =
numr::autograd::var_reshape(&out_odd, &[total_bhsd, half_d, 1]).map_err(Error::Numr)?;
let interleaved = var_cat(&[&out_even_3d, &out_odd_3d], -1, client).map_err(Error::Numr)?;
numr::autograd::var_reshape(&interleaved, &[b, h, seq_len, d]).map_err(Error::Numr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;
    use numr::tensor::Tensor;

    /// The output must have exactly the shape of the input.
    #[test]
    fn test_rope_interleaved_output_shape() {
        let (client, device) = cpu_setup();
        let (b, h, s, d) = (1, 2, 4, 8);
        let half = d / 2;

        let x_values = vec![1.0f32; b * h * s * d];
        let cos_values = vec![1.0f32; s * half];
        let sin_values = vec![0.0f32; s * half];

        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&x_values, &[b, h, s, d], &device),
            false,
        );
        let cos = Var::new(
            Tensor::<CpuRuntime>::from_slice(&cos_values, &[s, half], &device),
            false,
        );
        let sin = Var::new(
            Tensor::<CpuRuntime>::from_slice(&sin_values, &[s, half], &device),
            false,
        );

        let rotated = apply_rope_interleaved_impl(&client, &x, &cos, &sin).unwrap();
        assert_eq!(rotated.tensor().shape(), &[b, h, s, d]);
    }

    /// With cos = 1 and sin = 0 (zero rotation angle) the input must come
    /// back unchanged.
    #[test]
    fn test_rope_interleaved_identity_with_zero_angle() {
        let (client, device) = cpu_setup();
        let x_data: Vec<f32> = (0..8).map(|i| i as f32).collect();

        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&x_data, &[1, 1, 1, 8], &device),
            false,
        );
        let cos = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32; 4], &[1, 4], &device),
            false,
        );
        let sin = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.0f32; 4], &[1, 4], &device),
            false,
        );

        let rotated = apply_rope_interleaved_impl(&client, &x, &cos, &sin).unwrap();
        let out_data: Vec<f32> = rotated.tensor().contiguous().to_vec();

        out_data
            .iter()
            .zip(x_data.iter())
            .enumerate()
            .for_each(|(idx, (&got, &want))| {
                assert!(
                    (got - want).abs() < 1e-5,
                    "mismatch at {}: got {}, expected {}",
                    idx,
                    got,
                    want
                );
            });
    }

    /// With cos = 0 and sin = 1 (a 90 degree rotation) every pair (a, b)
    /// must map to (-b, a).
    #[test]
    fn test_rope_interleaved_90_degree_rotation() {
        let (client, device) = cpu_setup();

        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 1, 4], &device),
            false,
        );
        let cos = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[1, 2], &device),
            false,
        );
        let sin = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[1, 2], &device),
            false,
        );

        let rotated = apply_rope_interleaved_impl(&client, &x, &cos, &sin).unwrap();
        let out_data: Vec<f32> = rotated.tensor().contiguous().to_vec();

        // (1, 2) -> (-2, 1) and (3, 4) -> (-4, 3).
        let expected = [-2.0f32, 1.0, -4.0, 3.0];
        for (idx, want) in expected.iter().enumerate() {
            assert!(
                (out_data[idx] - want).abs() < 1e-5,
                "got {}",
                out_data[idx]
            );
        }
    }
}