use super::rope_standard::apply_rope_impl;
use crate::error::{Error, Result};
use numr::autograd::Var;
use numr::ops::{ScalarOps, ShapeOps, TensorOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
pub fn apply_rope_yarn_impl<R, C>(
client: &C,
x: &Var<R>,
cos_cache: &Var<R>,
sin_cache: &Var<R>,
attn_scale: f32,
) -> Result<Var<R>>
where
R: Runtime<DType = numr::dtype::DType>,
C: RuntimeClient<R> + ScalarOps<R> + ShapeOps<R> + TypeConversionOps<R>,
R::Client: TensorOps<R> + ShapeOps<R> + ScalarOps<R> + TypeConversionOps<R>,
{
let rotated = apply_rope_impl(client, x, cos_cache, sin_cache)?;
if (attn_scale - 1.0).abs() < 1e-7 {
return Ok(rotated);
}
numr::autograd::var_mul_scalar(&rotated, attn_scale as f64, client).map_err(Error::Numr)
}
#[cfg(test)]
mod tests {
    use super::super::rope_standard::apply_rope_impl;
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::autograd::Var;
    use numr::runtime::cpu::CpuRuntime;
    use numr::tensor::Tensor;

    /// With `attn_scale == 1.0` the YaRN wrapper must produce the same values
    /// as calling `apply_rope_impl` directly.
    #[test]
    fn test_rope_yarn_identity_scale_1() {
        let (client, device) = cpu_setup();

        // 16 input values laid out as [1, 1, 2, 8].
        let x_data: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&x_data, &[1, 1, 2, 8], &device),
            false,
        );
        // cos = 1 and sin = 0 for every cache entry.
        let cos = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32; 8], &[2, 4], &device),
            false,
        );
        let sin = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.0f32; 8], &[2, 4], &device),
            false,
        );

        let rope_out = apply_rope_impl(&client, &x, &cos, &sin).unwrap();
        let yarn_out = apply_rope_yarn_impl(&client, &x, &cos, &sin, 1.0).unwrap();

        let rope_data: Vec<f32> = rope_out.tensor().contiguous().to_vec();
        let yarn_data: Vec<f32> = yarn_out.tensor().contiguous().to_vec();

        // `zip` stops at the shorter side, so without these checks an empty or
        // truncated output would let the loop below pass without comparing anything.
        assert_eq!(
            rope_data.len(),
            x_data.len(),
            "rope output has unexpected element count"
        );
        assert_eq!(
            yarn_data.len(),
            rope_data.len(),
            "yarn and rope outputs differ in element count"
        );

        for (i, (&a, &b)) in rope_data.iter().zip(yarn_data.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "mismatch at {}: rope={}, yarn={}",
                i,
                a,
                b
            );
        }
    }

    /// With a non-unit `attn_scale` every output element must equal the plain
    /// RoPE output multiplied by that scale.
    #[test]
    fn test_rope_yarn_scaling() {
        // Single source of truth for the scale passed in and the expected factor.
        const SCALE: f32 = 2.0;

        let (client, device) = cpu_setup();

        // 8 input values laid out as [1, 1, 1, 8].
        let x_data: Vec<f32> = (0..8).map(|i| i as f32 * 0.1).collect();
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&x_data, &[1, 1, 1, 8], &device),
            false,
        );
        // cos = 1 and sin = 0 for every cache entry.
        let cos = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32; 4], &[1, 4], &device),
            false,
        );
        let sin = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.0f32; 4], &[1, 4], &device),
            false,
        );

        let rope_out = apply_rope_impl(&client, &x, &cos, &sin).unwrap();
        let yarn_out = apply_rope_yarn_impl(&client, &x, &cos, &sin, SCALE).unwrap();

        let rope_data: Vec<f32> = rope_out.tensor().contiguous().to_vec();
        let yarn_data: Vec<f32> = yarn_out.tensor().contiguous().to_vec();

        // Guard against a vacuous pass: `zip` silently truncates to the shorter input.
        assert_eq!(
            rope_data.len(),
            x_data.len(),
            "rope output has unexpected element count"
        );
        assert_eq!(
            yarn_data.len(),
            rope_data.len(),
            "yarn and rope outputs differ in element count"
        );

        for (i, (&a, &b)) in rope_data.iter().zip(yarn_data.iter()).enumerate() {
            let expected = a * SCALE;
            assert!(
                (b - expected).abs() < 1e-5,
                "mismatch at {}: rope={}, yarn={} (expected {})",
                i,
                a,
                b,
                expected
            );
        }
    }
}