use crate::{
array::Array,
error::{Result, check},
lm::cache::RopeOffset,
stream::default_stream,
};
/// Default RoPE frequency base (ten thousand), used by [`Rope::standard`].
pub const DEFAULT_BASE: f32 = 10_000.0;
/// Borrowed RoPE position offset, dispatched on by [`rope_with_offset`] and
/// [`rope_with_freqs_offset`].
///
/// `Scalar` is a single position offset handed to `mlx_fast_rope`; `Array`
/// borrows an offset array handed to `mlx_fast_rope_dynamic` (this file's
/// tests use one offset per batch row). A cache [`RopeOffset`] converts into
/// this type via `From`/`into()` without cloning its array.
///
/// `derive_more::IsVariant` provides `is_scalar()` / `is_array()` predicates.
#[derive(Debug, Clone, Copy, derive_more::IsVariant)]
pub enum RopeOffsetRef<'a> {
    /// A single position offset.
    Scalar(i32),
    /// A borrowed array of position offsets.
    Array(&'a Array),
}
/// Borrows a cache [`RopeOffset`] as a [`RopeOffsetRef`] without cloning.
///
/// A `Batch` array is borrowed as-is. A `Scalar` position that does not fit
/// in `i32` is clamped to `i32::MAX` instead of wrapping around.
impl<'a> From<&'a RopeOffset> for RopeOffsetRef<'a> {
    fn from(offset: &'a RopeOffset) -> Self {
        match *offset {
            RopeOffset::Batch(ref batch) => Self::Array(batch),
            RopeOffset::Scalar(pos) => Self::Scalar(i32::try_from(pos).unwrap_or(i32::MAX)),
        }
    }
}
/// Applies rotary position embedding (RoPE) to `x` via MLX's `mlx_fast_rope`.
///
/// * `dims` – number of leading features along the last axis to rotate.
/// * `traditional` – forwarded to MLX to select its rotation layout.
/// * `base` – frequency base; MLX derives the frequencies from it.
/// * `scale` – position scale forwarded to MLX.
/// * `offset` – position offset added to every position.
///
/// The call is issued on the default stream.
///
/// # Errors
/// Returns the error reported by MLX if the call fails.
pub fn rope(
    x: &Array,
    dims: i32,
    traditional: bool,
    base: f32,
    scale: f32,
    offset: i32,
) -> Result<Array> {
    // Frequencies come from `base` here, so the `freqs` slot receives an
    // empty, freshly created array handle.
    // SAFETY: `mlx_array_new` takes no arguments; the handle is wrapped at once.
    let empty_freqs = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: as above.
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let with_base = mlxrs_sys::mlx_optional_float {
        has_value: true,
        value: base,
    };
    // SAFETY: every handle passed in outlives the call: `x` is borrowed and
    // `result` / `empty_freqs` are locals still in scope.
    let status = unsafe {
        mlxrs_sys::mlx_fast_rope(
            &mut result.0,
            x.0,
            dims,
            traditional,
            with_base,
            scale,
            offset,
            empty_freqs.0,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Applies RoPE to `x` with an array-valued position offset via MLX's
/// `mlx_fast_rope_dynamic`.
///
/// `dims`, `traditional`, `base`, and `scale` mean the same as in [`rope`].
/// `offset` is an array of position offsets. This file's tests pass one offset
/// per batch row, and show that a length-1 array matches the scalar [`rope`].
///
/// # Errors
/// Returns the error reported by MLX if the call fails.
pub fn rope_dynamic(
    x: &Array,
    dims: i32,
    traditional: bool,
    base: f32,
    scale: f32,
    offset: &Array,
) -> Result<Array> {
    // Frequencies come from `base`, so an empty array fills the `freqs` slot.
    // SAFETY: `mlx_array_new` takes no arguments; the handle is wrapped at once.
    let empty_freqs = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: as above.
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let with_base = mlxrs_sys::mlx_optional_float {
        has_value: true,
        value: base,
    };
    // SAFETY: `x` and `offset` are borrowed and `result` / `empty_freqs` are
    // live locals, so every handle outlives the call.
    let status = unsafe {
        mlxrs_sys::mlx_fast_rope_dynamic(
            &mut result.0,
            x.0,
            dims,
            traditional,
            with_base,
            scale,
            offset.0,
            empty_freqs.0,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Applies RoPE with either kind of offset.
///
/// [`RopeOffsetRef::Array`] is routed to [`rope_dynamic`] and
/// [`RopeOffsetRef::Scalar`] to [`rope`]; all other arguments are forwarded
/// unchanged.
///
/// # Errors
/// Propagates the error of whichever function is called.
pub fn rope_with_offset(
    x: &Array,
    dims: i32,
    traditional: bool,
    base: f32,
    scale: f32,
    offset: RopeOffsetRef<'_>,
) -> Result<Array> {
    match offset {
        RopeOffsetRef::Array(positions) => {
            rope_dynamic(x, dims, traditional, base, scale, positions)
        }
        RopeOffsetRef::Scalar(start) => rope(x, dims, traditional, base, scale, start),
    }
}
/// Applies RoPE to `x` using caller-supplied frequencies instead of a base.
///
/// Calls `mlx_fast_rope` with its optional `base` marked absent and `freqs`
/// passed through. `dims`, `traditional`, `scale`, and `offset` mean the same
/// as in [`rope`].
///
/// # Errors
/// Returns the error reported by MLX if the call fails.
pub fn rope_with_freqs(
    x: &Array,
    dims: i32,
    traditional: bool,
    scale: f32,
    offset: i32,
    freqs: &Array,
) -> Result<Array> {
    // `freqs` stands in for the base, so the optional base carries no value.
    let no_base = mlxrs_sys::mlx_optional_float {
        has_value: false,
        value: 0.0,
    };
    // SAFETY: `mlx_array_new` takes no arguments; the handle is wrapped at once.
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `x` and `freqs` are borrowed and `result` is a live local, so
    // every handle outlives the call.
    let status = unsafe {
        mlxrs_sys::mlx_fast_rope(
            &mut result.0,
            x.0,
            dims,
            traditional,
            no_base,
            scale,
            offset,
            freqs.0,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Applies RoPE to `x` with caller-supplied frequencies and an array-valued
/// position offset.
///
/// Calls `mlx_fast_rope_dynamic` with its optional `base` marked absent and
/// `freqs` passed through. Other arguments mean the same as in
/// [`rope_dynamic`].
///
/// # Errors
/// Returns the error reported by MLX if the call fails.
pub fn rope_dynamic_with_freqs(
    x: &Array,
    dims: i32,
    traditional: bool,
    scale: f32,
    offset: &Array,
    freqs: &Array,
) -> Result<Array> {
    // `freqs` stands in for the base, so the optional base carries no value.
    let no_base = mlxrs_sys::mlx_optional_float {
        has_value: false,
        value: 0.0,
    };
    // SAFETY: `mlx_array_new` takes no arguments; the handle is wrapped at once.
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `x`, `offset`, and `freqs` are borrowed and `result` is a live
    // local, so every handle outlives the call.
    let status = unsafe {
        mlxrs_sys::mlx_fast_rope_dynamic(
            &mut result.0,
            x.0,
            dims,
            traditional,
            no_base,
            scale,
            offset.0,
            freqs.0,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Applies RoPE with caller-supplied frequencies and either kind of offset.
///
/// [`RopeOffsetRef::Array`] is routed to [`rope_dynamic_with_freqs`] and
/// [`RopeOffsetRef::Scalar`] to [`rope_with_freqs`].
///
/// # Errors
/// Propagates the error of whichever function is called.
pub fn rope_with_freqs_offset(
    x: &Array,
    dims: i32,
    traditional: bool,
    scale: f32,
    offset: RopeOffsetRef<'_>,
    freqs: &Array,
) -> Result<Array> {
    match offset {
        RopeOffsetRef::Array(positions) => {
            rope_dynamic_with_freqs(x, dims, traditional, scale, positions, freqs)
        }
        RopeOffsetRef::Scalar(start) => rope_with_freqs(x, dims, traditional, scale, start, freqs),
    }
}
/// Reusable RoPE configuration: the parameters forwarded unchanged to
/// [`rope`] / [`rope_with_offset`] on every application.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Rope {
    /// Number of leading features along the last axis to rotate.
    pub dims: i32,
    /// Forwarded to MLX as its `traditional` flag (selects the rotation layout).
    pub traditional: bool,
    /// Frequency base forwarded to MLX.
    pub base: f32,
    /// Position scale forwarded to MLX.
    pub scale: f32,
}
impl Rope {
pub fn new(dims: i32, traditional: bool, base: f32, scale: f32) -> Self {
Self {
dims,
traditional,
base,
scale,
}
}
pub fn standard(dims: i32) -> Self {
Self::new(dims, false, DEFAULT_BASE, 1.0)
}
pub fn apply(&self, x: &Array, offset: i32) -> Result<Array> {
rope(
x,
self.dims,
self.traditional,
self.base,
self.scale,
offset,
)
}
pub fn apply_with_offset(&self, x: &Array, offset: RopeOffsetRef<'_>) -> Result<Array> {
rope_with_offset(
x,
self.dims,
self.traditional,
self.base,
self.scale,
offset,
)
}
}
#[cfg(test)]
// The expected outputs below carry more decimal digits than an `f32` can hold
// exactly; they are kept verbatim, so clippy's precision lint is silenced here.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    /// Absolute tolerance used by `assert_close`.
    const TOL: f32 = 1e-5;

    /// Shape (1, 1, 2, 4) input holding `0.0..=7.0`: two positions along
    /// axis -2, four features each.
    fn input() -> Array {
        Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &(1, 1, 2, 4)).unwrap()
    }

    /// Asserts equal lengths and element-wise `|got - want| <= TOL`.
    fn assert_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len(), "length mismatch");
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert!(
                (g - w).abs() <= TOL,
                "index {i}: got {g}, want {w} (|Δ|={})",
                (g - w).abs()
            );
        }
    }

    #[test]
    fn non_traditional_offset0() {
        let x = input();
        let mut y = rope(&x, 4, false, DEFAULT_BASE, 1.0, 0).unwrap();
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                0.0, 1.0, 2.0, 3.0, -2.8876167, 4.9297512, 6.6076978, 7.0496492,
            ],
        );
    }

    #[test]
    fn non_traditional_offset2() {
        let x = input();
        let mut y = rope(&x, 4, false, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                -1.8185949, 0.9398040, -0.8322937, 3.0193987, -4.8066900, 4.7877817, -5.3754749,
                7.1468277,
            ],
        );
    }

    #[test]
    fn traditional_offset0() {
        let x = input();
        let mut y = rope(&x, 4, true, DEFAULT_BASE, 1.0, 0).unwrap();
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                0.0, 1.0, 2.0, 3.0, -2.0461457, 6.0673955, 5.9297012, 7.0596490,
            ],
        );
    }

    #[test]
    fn traditional_offset2() {
        let x = input();
        let mut y = rope(&x, 4, true, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                -0.9092974, -0.4161468, 1.9396040, 3.0393974, -4.6655700, -4.3854825, 5.7873317,
                7.1768232,
            ],
        );
    }

    #[test]
    fn scale_half_is_position_interpolation() {
        let x = input();
        let mut y = rope(&x, 4, false, DEFAULT_BASE, 0.5, 0).unwrap();
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                0.0, 1.0, 2.0, 3.0, 0.6337770, 4.9649376, 7.1831975, 7.0249124,
            ],
        );
    }

    #[test]
    fn partial_dims_pass_through_tail() {
        let x = input();
        let mut y = rope(&x, 2, false, DEFAULT_BASE, 1.0, 0).unwrap();
        // Only the first two features rotate; features 2 and 3 are untouched.
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[0.0, 1.0, 2.0, 3.0, -2.0461457, 6.0673955, 6.0, 7.0],
        );
    }

    #[test]
    fn config_apply_matches_free_fn() {
        let x = input();
        let r = Rope::new(4, false, DEFAULT_BASE, 1.0);
        let mut via_config = r.apply(&x, 2).unwrap();
        let mut via_fn = rope(&x, 4, false, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &via_config.to_vec::<f32>().unwrap(),
            &via_fn.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn apply_matches_apply_with_scalar_offset() {
        let x = input();
        let r = Rope::standard(4);
        let mut via_apply = r.apply(&x, 2).unwrap();
        let mut via_ref = r.apply_with_offset(&x, RopeOffsetRef::Scalar(2)).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &via_ref.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn standard_uses_mlx_defaults() {
        let r = Rope::standard(8);
        assert_eq!(r.dims, 8);
        assert!(!r.traditional);
        assert_eq!(r.base, DEFAULT_BASE);
        assert_eq!(r.scale, 1.0);
    }

    /// Shape (2, 1, 2, 4) input: two batch rows, each identical to `input()`.
    fn batch_input() -> Array {
        Array::from_slice::<f32>(
            &[
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
            ],
            &(2, 1, 2, 4),
        )
        .unwrap()
    }

    #[test]
    fn dynamic_per_row_offsets() {
        let x = batch_input();
        let offset = Array::from_slice::<i32>(&[0, 2], &(2,)).unwrap();
        let mut y = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &offset).unwrap();
        // Row 0 must equal the offset-0 result and row 1 the offset-2 result.
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                0.0, 1.0, 2.0, 3.0, -2.8876167, 4.9297512, 6.6076978, 7.0496492, -1.8185949,
                0.9398040, -0.8322937, 3.0193987, -4.8066900, 4.7877817, -5.3754749, 7.1468277,
            ],
        );
    }

    #[test]
    fn dynamic_offsets_swapped_rotate_independently() {
        let x = batch_input();
        let offset = Array::from_slice::<i32>(&[2, 0], &(2,)).unwrap();
        let mut y = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &offset).unwrap();
        assert_close(
            &y.to_vec::<f32>().unwrap(),
            &[
                -1.8185949, 0.9398040, -0.8322937, 3.0193987, -4.8066900, 4.7877817, -5.3754749,
                7.1468277, 0.0, 1.0, 2.0, 3.0, -2.8876167, 4.9297512, 6.6076978, 7.0496492,
            ],
        );
    }

    #[test]
    fn dynamic_scalar_array_matches_scalar_rope() {
        let x = input();
        let offset = Array::from_slice::<i32>(&[2], &(1,)).unwrap();
        let mut via_dynamic = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &offset).unwrap();
        let mut via_scalar = rope(&x, 4, false, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &via_dynamic.to_vec::<f32>().unwrap(),
            &via_scalar.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn rope_with_offset_dispatches_both_arms() {
        let x = batch_input();
        let mut via_scalar =
            rope_with_offset(&x, 4, false, DEFAULT_BASE, 1.0, RopeOffsetRef::Scalar(2)).unwrap();
        let mut via_plain = rope(&x, 4, false, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &via_scalar.to_vec::<f32>().unwrap(),
            &via_plain.to_vec::<f32>().unwrap(),
        );
        let offset = Array::from_slice::<i32>(&[0, 2], &(2,)).unwrap();
        let mut via_dispatch = rope_with_offset(
            &x,
            4,
            false,
            DEFAULT_BASE,
            1.0,
            RopeOffsetRef::Array(&offset),
        )
        .unwrap();
        let mut via_direct = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &offset).unwrap();
        assert_close(
            &via_dispatch.to_vec::<f32>().unwrap(),
            &via_direct.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn apply_with_offset_matches_free_fn() {
        let x = batch_input();
        let r = Rope::standard(4);
        let offset = Array::from_slice::<i32>(&[0, 2], &(2,)).unwrap();
        let mut via_config = r
            .apply_with_offset(&x, RopeOffsetRef::Array(&offset))
            .unwrap();
        let mut via_fn = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &offset).unwrap();
        assert_close(
            &via_config.to_vec::<f32>().unwrap(),
            &via_fn.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn rope_offset_ref_borrows_cache_offset_without_clone() {
        let x = batch_input();
        let arr = Array::from_slice::<i32>(&[0, 2], &(2,)).unwrap();
        let owned = RopeOffset::Batch(arr.try_clone().unwrap());
        let r = Rope::standard(4);
        let mut via_bridge = r.apply_with_offset(&x, (&owned).into()).unwrap();
        let mut via_direct = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &arr).unwrap();
        assert_close(
            &via_bridge.to_vec::<f32>().unwrap(),
            &via_direct.to_vec::<f32>().unwrap(),
        );
        let owned_scalar = RopeOffset::Scalar(2);
        let mut via_scalar_bridge = r.apply_with_offset(&x, (&owned_scalar).into()).unwrap();
        let mut via_scalar = rope(&x, 4, false, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &via_scalar_bridge.to_vec::<f32>().unwrap(),
            &via_scalar.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn rope_offset_ref_scalar_saturates_instead_of_wrapping() {
        let huge = RopeOffset::Scalar(usize::MAX);
        match (&huge).into() {
            RopeOffsetRef::Scalar(p) => assert_eq!(p, i32::MAX),
            RopeOffsetRef::Array(_) => panic!("scalar offset must map to a scalar ref"),
        }
        let at_max = RopeOffset::Scalar(i32::MAX as usize);
        match (&at_max).into() {
            RopeOffsetRef::Scalar(p) => assert_eq!(p, i32::MAX),
            RopeOffsetRef::Array(_) => unreachable!(),
        }
    }

    /// Builds the frequency array `base^(2i / dims)` for `i in 0..dims / 2`.
    /// The `freqs_*` tests use it to compare the explicit-frequency paths
    /// against the `base` paths.
    fn base_freqs(base: f64, dims: usize) -> Array {
        let half = dims / 2;
        let freqs: Vec<f32> = (0..half)
            .map(|i| base.powf((2 * i) as f64 / dims as f64) as f32)
            .collect();
        Array::from_slice::<f32>(&freqs, &(half,)).unwrap()
    }

    #[test]
    fn freqs_path_matches_base_path() {
        let x = input();
        let freqs = base_freqs(DEFAULT_BASE as f64, 4);
        let mut via_freqs = rope_with_freqs(&x, 4, false, 1.0, 2, &freqs).unwrap();
        let mut via_base = rope(&x, 4, false, DEFAULT_BASE, 1.0, 2).unwrap();
        assert_close(
            &via_freqs.to_vec::<f32>().unwrap(),
            &via_base.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn freqs_path_traditional_matches_base_path() {
        let x = input();
        let freqs = base_freqs(DEFAULT_BASE as f64, 4);
        let mut via_freqs = rope_with_freqs(&x, 4, true, 1.0, 0, &freqs).unwrap();
        let mut via_base = rope(&x, 4, true, DEFAULT_BASE, 1.0, 0).unwrap();
        assert_close(
            &via_freqs.to_vec::<f32>().unwrap(),
            &via_base.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn freqs_dynamic_path_matches_base_dynamic() {
        let x = batch_input();
        let freqs = base_freqs(DEFAULT_BASE as f64, 4);
        let offset = Array::from_slice::<i32>(&[0, 2], &(2,)).unwrap();
        let mut via_freqs = rope_dynamic_with_freqs(&x, 4, false, 1.0, &offset, &freqs).unwrap();
        let mut via_base = rope_dynamic(&x, 4, false, DEFAULT_BASE, 1.0, &offset).unwrap();
        assert_close(
            &via_freqs.to_vec::<f32>().unwrap(),
            &via_base.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn freqs_offset_dispatch_matches_both_arms() {
        let x = batch_input();
        let freqs = base_freqs(DEFAULT_BASE as f64, 4);
        let mut via_scalar =
            rope_with_freqs_offset(&x, 4, false, 1.0, RopeOffsetRef::Scalar(2), &freqs).unwrap();
        let mut via_direct_scalar = rope_with_freqs(&x, 4, false, 1.0, 2, &freqs).unwrap();
        assert_close(
            &via_scalar.to_vec::<f32>().unwrap(),
            &via_direct_scalar.to_vec::<f32>().unwrap(),
        );
        let offset = Array::from_slice::<i32>(&[0, 2], &(2,)).unwrap();
        let mut via_array =
            rope_with_freqs_offset(&x, 4, false, 1.0, RopeOffsetRef::Array(&offset), &freqs)
                .unwrap();
        let mut via_direct_array =
            rope_dynamic_with_freqs(&x, 4, false, 1.0, &offset, &freqs).unwrap();
        assert_close(
            &via_array.to_vec::<f32>().unwrap(),
            &via_direct_array.to_vec::<f32>().unwrap(),
        );
    }
}