use crate::error::{Error, Result};
use numr::autograd::{Var, var_add, var_cat, var_mul, var_narrow, var_reshape, var_sub};
use numr::ops::{IndexingOps, ScalarOps, ShapeOps, TensorOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn apply_rope_packed_impl<R, C>(
client: &C,
x: &Var<R>,
cos_cache: &Var<R>,
sin_cache: &Var<R>,
position_ids: &Tensor<R>,
) -> Result<Var<R>>
where
R: Runtime<DType = numr::dtype::DType>,
C: RuntimeClient<R> + ScalarOps<R> + ShapeOps<R> + TypeConversionOps<R> + IndexingOps<R>,
R::Client: TensorOps<R> + ShapeOps<R> + TypeConversionOps<R>,
{
let x_shape = x.tensor().shape().to_vec();
if x_shape.len() != 3 {
return Err(Error::InvalidArgument {
arg: "x",
reason: format!(
"expected 3D [total_tokens, num_heads, head_dim], got {}D",
x_shape.len()
),
});
}
let total_tokens = x_shape[0];
let num_heads = x_shape[1];
let head_dim = x_shape[2];
if !head_dim.is_multiple_of(2) {
return Err(Error::InvalidArgument {
arg: "x",
reason: format!("head_dim D={} must be even for RoPE", head_dim),
});
}
let half_d = head_dim / 2;
let pid_shape = position_ids.shape();
if pid_shape.len() != 1 || pid_shape[0] != total_tokens {
return Err(Error::InvalidArgument {
arg: "position_ids",
reason: format!(
"expected 1D [total_tokens={}], got {:?}",
total_tokens, pid_shape
),
});
}
let cos_shape = cos_cache.tensor().shape();
let sin_shape = sin_cache.tensor().shape();
if cos_shape.len() != 2 || cos_shape[1] != half_d {
return Err(Error::InvalidArgument {
arg: "cos_cache",
reason: format!("expected [max_seq_len, {}], got {:?}", half_d, cos_shape),
});
}
if sin_shape.len() != 2 || sin_shape[1] != half_d {
return Err(Error::InvalidArgument {
arg: "sin_cache",
reason: format!("expected [max_seq_len, {}], got {:?}", half_d, sin_shape),
});
}
let cos_gathered = client
.embedding_lookup(cos_cache.tensor(), position_ids)
.map_err(Error::Numr)?;
let sin_gathered = client
.embedding_lookup(sin_cache.tensor(), position_ids)
.map_err(Error::Numr)?;
let x_dtype = x.tensor().dtype();
let cos_matched = if cos_gathered.dtype() != x_dtype {
let v = numr::autograd::var_cast(&Var::new(cos_gathered, false), x_dtype, client)
.map_err(Error::Numr)?;
v.tensor().clone()
} else {
cos_gathered
};
let sin_matched = if sin_gathered.dtype() != x_dtype {
let v = numr::autograd::var_cast(&Var::new(sin_gathered, false), x_dtype, client)
.map_err(Error::Numr)?;
v.tensor().clone()
} else {
sin_gathered
};
let cos_reshaped = var_reshape(&Var::new(cos_matched, false), &[total_tokens, 1, half_d])
.map_err(Error::Numr)?;
let sin_reshaped = var_reshape(&Var::new(sin_matched, false), &[total_tokens, 1, half_d])
.map_err(Error::Numr)?;
let x1 = var_narrow(x, -1, 0, half_d).map_err(Error::Numr)?;
let x2 = var_narrow(x, -1, half_d, half_d).map_err(Error::Numr)?;
let x1_cos = var_mul(&x1, &cos_reshaped, client).map_err(Error::Numr)?;
let x2_sin = var_mul(&x2, &sin_reshaped, client).map_err(Error::Numr)?;
let out1 = var_sub(&x1_cos, &x2_sin, client).map_err(Error::Numr)?;
let x1_sin = var_mul(&x1, &sin_reshaped, client).map_err(Error::Numr)?;
let x2_cos = var_mul(&x2, &cos_reshaped, client).map_err(Error::Numr)?;
let out2 = var_add(&x1_sin, &x2_cos, client).map_err(Error::Numr)?;
let _ = num_heads; var_cat(&[&out1, &out2], -1, client).map_err(Error::Numr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use crate::ops::impl_generic::attention::apply_rope_impl;
    use crate::test_utils::cpu_setup;
    use numr::autograd::Var;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};
    use numr::tensor::Tensor;

    /// Wrap a host `f32` buffer as a non-grad `Var` on the CPU device.
    fn f32_var(data: &[f32], shape: &[usize], device: &CpuDevice) -> Var<CpuRuntime> {
        Var::new(Tensor::<CpuRuntime>::from_slice(data, shape, device), false)
    }

    /// Copy a `Var`'s data back to the host as a flat `f32` vector.
    fn to_host_f32(v: &Var<CpuRuntime>) -> Vec<f32> {
        v.tensor().contiguous().unwrap().to_vec::<f32>()
    }

    /// Position ids `0..n` as an `i32` tensor (one sequence with contiguous positions).
    fn seq_pids(n: usize, device: &CpuDevice) -> Tensor<CpuRuntime> {
        let ids: Vec<i32> = (0..n as i32).collect();
        Tensor::<CpuRuntime>::from_slice(&ids, &[n], device)
    }

    /// Assert two buffers are element-wise within `tol`, reporting the first mismatch.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32, what: &str) {
        assert_eq!(actual.len(), expected.len(), "{what}: length mismatch");
        for (i, (&a, &b)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - b).abs() < tol,
                "{what} mismatch at {i}: got {a}, expected {b}"
            );
        }
    }

    /// Reorder a row-major `[a, b, inner]` buffer into `[b, a, inner]`.
    fn swap_leading_axes(src: &[f32], a: usize, b: usize, inner: usize) -> Vec<f32> {
        let mut dst = vec![0.0f32; src.len()];
        for i in 0..a {
            for j in 0..b {
                let from = (i * b + j) * inner;
                let to = (j * a + i) * inner;
                dst[to..to + inner].copy_from_slice(&src[from..from + inner]);
            }
        }
        dst
    }

    /// Host reference for packed half-split RoPE on a `[tokens, heads, head_dim]` buffer
    /// with `[max_seq, head_dim / 2]` cos/sin tables.
    fn rope_packed_reference(
        x: &[f32],
        cos: &[f32],
        sin: &[f32],
        pids: &[usize],
        num_heads: usize,
        head_dim: usize,
    ) -> Vec<f32> {
        let half_d = head_dim / 2;
        let mut out = vec![0.0f32; x.len()];
        for (t, &p) in pids.iter().enumerate() {
            let c = &cos[p * half_d..][..half_d];
            let s = &sin[p * half_d..][..half_d];
            for h in 0..num_heads {
                let base = (t * num_heads + h) * head_dim;
                for j in 0..half_d {
                    let (x1, x2) = (x[base + j], x[base + half_d + j]);
                    out[base + j] = x1 * c[j] - x2 * s[j];
                    out[base + half_d + j] = x1 * s[j] + x2 * c[j];
                }
            }
        }
        out
    }

    #[test]
    fn test_rope_packed_identity_cos1_sin0() {
        let (client, device) = cpu_setup();
        let (total_tokens, num_heads, head_dim, max_seq) = (3, 2, 8, 8);
        let half_d = head_dim / 2;
        let x_data: Vec<f32> = (0..total_tokens * num_heads * head_dim)
            .map(|i| i as f32)
            .collect();
        let x = f32_var(&x_data, &[total_tokens, num_heads, head_dim], &device);
        let cos = f32_var(&vec![1.0f32; max_seq * half_d], &[max_seq, half_d], &device);
        let sin = f32_var(&vec![0.0f32; max_seq * half_d], &[max_seq, half_d], &device);
        let pids = seq_pids(total_tokens, &device);

        let out = apply_rope_packed_impl(&client, &x, &cos, &sin, &pids)
            .expect("apply_rope_packed_impl failed");

        // cos = 1, sin = 0 is the identity rotation.
        assert_close(&to_host_f32(&out), &x_data, 1e-5, "identity test");
    }

    #[test]
    fn test_rope_packed_position_reset() {
        let (client, device) = cpu_setup();
        let (num_heads, head_dim, max_seq) = (1, 4, 8);
        let half_d = head_dim / 2;
        let cos_data: Vec<f32> = (0..max_seq * half_d)
            .map(|i| (i as f32 * 0.5).cos())
            .collect();
        let sin_data: Vec<f32> = (0..max_seq * half_d)
            .map(|i| (i as f32 * 0.5).sin())
            .collect();
        let cos = f32_var(&cos_data, &[max_seq, half_d], &device);
        let sin = f32_var(&sin_data, &[max_seq, half_d], &device);

        // Two packed sequences: [t0 @ pos 0, t1 @ pos 1] followed by [t2 @ pos 0].
        // Every token has the same features, so output differences come only from positions.
        let x_data: Vec<f32> = [1.0f32, 2.0, 3.0, 4.0].repeat(3);
        let x = f32_var(&x_data, &[3, num_heads, head_dim], &device);
        let pid_values = [0usize, 1, 0];
        let pid_i32: Vec<i32> = pid_values.iter().map(|&p| p as i32).collect();
        let pids = Tensor::<CpuRuntime>::from_slice(&pid_i32, &[3], &device);

        let out = apply_rope_packed_impl(&client, &x, &cos, &sin, &pids)
            .expect("apply_rope_packed_impl failed");
        let out_data = to_host_f32(&out);

        let t0 = &out_data[..head_dim];
        let t1 = &out_data[head_dim..2 * head_dim];
        let t2 = &out_data[2 * head_dim..3 * head_dim];
        assert_close(t2, t0, 1e-5, "packed reset: token2 vs token0 (both at position 0)");
        assert!(
            t0.iter().zip(t1).any(|(a, b)| (a - b).abs() > 1e-5),
            "tokens at different positions should produce different RoPE outputs"
        );

        let expected = rope_packed_reference(
            &x_data,
            &cos_data,
            &sin_data,
            &pid_values,
            num_heads,
            head_dim,
        );
        assert_close(&out_data, &expected, 1e-4, "packed reset vs host reference");
    }

    #[test]
    fn test_rope_packed_matches_standard_single_sequence() {
        let (client, device) = cpu_setup();
        let (s, h, d) = (4, 2, 8);
        let half_d = d / 2;
        let x_data: Vec<f32> = (0..s * h * d).map(|i| (i as f32 * 0.1).sin()).collect();
        let cos_data: Vec<f32> = (0..s * half_d).map(|i| (i as f32 * 0.3).cos()).collect();
        let sin_data: Vec<f32> = (0..s * half_d).map(|i| (i as f32 * 0.3).sin()).collect();
        let cos = f32_var(&cos_data, &[s, half_d], &device);
        let sin = f32_var(&sin_data, &[s, half_d], &device);

        let x_packed = f32_var(&x_data, &[s, h, d], &device);
        let pids = seq_pids(s, &device);
        let out_packed =
            apply_rope_packed_impl(&client, &x_packed, &cos, &sin, &pids).expect("packed failed");
        let packed_vec = to_host_f32(&out_packed);

        // The standard kernel takes [batch=1, heads, seq, head_dim]; transpose [S, H, D] to
        // [H, S, D] on the way in and back again on the way out.
        let x_standard = f32_var(&swap_leading_axes(&x_data, s, h, d), &[1, h, s, d], &device);
        let out_standard =
            apply_rope_impl(&client, &x_standard, &cos, &sin).expect("standard failed");
        let standard_vec = swap_leading_axes(&to_host_f32(&out_standard), h, s, d);

        assert_close(&packed_vec, &standard_vec, 1e-5, "packed vs standard");
    }

    #[test]
    fn test_rope_packed_invalid_odd_dim() {
        let (client, device) = cpu_setup();
        let x = f32_var(&[1.0f32; 3], &[1, 1, 3], &device);
        let cos = f32_var(&[1.0f32; 4], &[4, 1], &device);
        let sin = f32_var(&[0.0f32; 4], &[4, 1], &device);
        let pids = Tensor::<CpuRuntime>::from_slice(&[0i32], &[1], &device);

        let result = apply_rope_packed_impl(&client, &x, &cos, &sin, &pids);
        assert!(
            matches!(result, Err(Error::InvalidArgument { arg: "x", .. })),
            "odd head_dim must be rejected as an invalid `x`"
        );
    }

    #[test]
    fn test_rope_packed_invalid_wrong_ndim() {
        let (client, device) = cpu_setup();
        let x = f32_var(&[1.0f32; 8], &[1, 1, 2, 4], &device);
        let cos = f32_var(&[1.0f32; 8], &[4, 2], &device);
        let sin = f32_var(&[0.0f32; 8], &[4, 2], &device);
        let pids = Tensor::<CpuRuntime>::from_slice(&[0i32; 2], &[2], &device);

        let result = apply_rope_packed_impl(&client, &x, &cos, &sin, &pids);
        assert!(
            matches!(result, Err(Error::InvalidArgument { arg: "x", .. })),
            "4-D input must be rejected as an invalid `x`"
        );
    }

    #[test]
    fn test_rope_packed_invalid_position_ids_len() {
        let (client, device) = cpu_setup();
        let x = f32_var(&[1.0f32; 8], &[2, 1, 4], &device);
        let cos = f32_var(&[1.0f32; 8], &[4, 2], &device);
        let sin = f32_var(&[0.0f32; 8], &[4, 2], &device);
        // Three ids for two tokens.
        let pids = Tensor::<CpuRuntime>::from_slice(&[0i32, 1, 2], &[3], &device);

        let result = apply_rope_packed_impl(&client, &x, &cos, &sin, &pids);
        assert!(
            matches!(result, Err(Error::InvalidArgument { arg: "position_ids", .. })),
            "position_ids length must equal total_tokens"
        );
    }

    #[test]
    fn test_rope_packed_invalid_cache_width() {
        let (client, device) = cpu_setup();
        let x = f32_var(&[1.0f32; 8], &[2, 1, 4], &device);
        let good = f32_var(&[0.0f32; 8], &[4, 2], &device);
        // Width head_dim (4) instead of head_dim / 2 (2).
        let wide = f32_var(&[1.0f32; 16], &[4, 4], &device);
        let pids = seq_pids(2, &device);

        assert!(
            matches!(
                apply_rope_packed_impl(&client, &x, &wide, &good, &pids),
                Err(Error::InvalidArgument { arg: "cos_cache", .. })
            ),
            "cos_cache with the wrong width must be rejected"
        );
        assert!(
            matches!(
                apply_rope_packed_impl(&client, &x, &good, &wide, &pids),
                Err(Error::InvalidArgument { arg: "sin_cache", .. })
            ),
            "sin_cache with the wrong width must be rejected"
        );
    }

    #[test]
    fn test_rope_packed_dtype_f32() {
        let (client, device) = cpu_setup();
        let (total, h, d) = (2usize, 1usize, 4usize);
        let x_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x = f32_var(&x_data, &[total, h, d], &device);
        let cos = f32_var(&[1.0f32; 16], &[8, 2], &device);
        let sin = f32_var(&[0.0f32; 16], &[8, 2], &device);
        let pids = Tensor::<CpuRuntime>::from_slice(&[0i32, 1], &[2], &device);

        let out = apply_rope_packed_impl(&client, &x, &cos, &sin, &pids).unwrap();
        assert_eq!(out.tensor().shape(), &[total, h, d]);
        assert!(
            out.tensor().dtype() == x.tensor().dtype(),
            "output dtype must match the input dtype"
        );
        assert_close(&to_host_f32(&out), &x_data, 1e-6, "f32 identity");
    }
}