use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, Error, LengthMismatchPayload, NonFiniteScalarPayload,
OutOfRangePayload, RankMismatchPayload, Result,
},
lm::nn::rope::{RopeOffsetRef, rope_with_freqs_offset},
};
use smol_str::format_smolstr;
use super::rope::DEFAULT_BASE;
fn freqs_half(dims: i32) -> Result<usize> {
if dims <= 0 || dims % 2 != 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"scaled RoPE: dims",
"must be a positive even number",
format!("{dims}"),
)));
}
Ok((dims / 2) as usize)
}
/// Returns the unscaled per-pair RoPE frequencies `base^(2i / dims)` for
/// `i in 0..half`, computed in f64.
///
/// The scaling variants in this module start from these values and then
/// rescale them before narrowing to f32 via `freqs_array`.
fn base_pair_freqs(base: f64, dims: i32, half: usize) -> Vec<f64> {
    let denom = f64::from(dims);
    let mut out = Vec::with_capacity(half);
    for pair in 0..half {
        let exponent = (2 * pair) as f64 / denom;
        out.push(base.powf(exponent));
    }
    out
}
fn freqs_array(freqs: &[f64]) -> Result<Array> {
let mut buf: Vec<f32> = Vec::with_capacity(freqs.len());
for &v in freqs {
let f = v as f32;
if !f.is_finite() || f <= 0.0 || !(1.0f32 / f).is_finite() {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"scaled RoPE freq (check scaling config: base / factor / embeddings / betas)",
"must be positive, finite, AND have a finite f32 reciprocal (freqs are inverted as 1/freqs by mlx_fast_rope; zero / subnormal would become +Inf at apply time)",
format_smolstr!("{v}"),
)));
}
buf.push(f);
}
Array::from_slice::<f32>(&buf, &(freqs.len(),))
}
fn finite_scalar(value: f64, what: &'static str) -> Result<f32> {
let v = value as f32;
if v.is_finite() {
Ok(v)
} else {
Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
what, value,
)))
}
}
fn require_finite_input(value: f32, what: &'static str) -> Result<()> {
if value.is_finite() {
Ok(())
} else {
Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
what,
value as f64,
)))
}
}
/// Rejects a caller-supplied f32 input unless it is finite and strictly
/// positive (NaN, infinities, zero, and negatives all fail).
///
/// # Errors
///
/// Returns `Error::OutOfRange` labelled `what`.
fn require_positive_input(value: f32, what: &'static str) -> Result<()> {
    let acceptable = value.is_finite() && value > 0.0;
    if acceptable {
        return Ok(());
    }
    Err(Error::OutOfRange(OutOfRangePayload::new(
        what,
        "must be a positive finite number",
        format_smolstr!("{value}"),
    )))
}
fn scale_leading_dims(x: &Array, dims: i32, mscale: f32) -> Result<Array> {
let ndim = x.ndim();
if ndim == 0 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"scaled RoPE input must have rank >= 1 (at least one axis)",
0,
x.shape().to_vec(),
)));
}
let scalar = Array::from_slice::<f32>(&[mscale], &(1usize,))?.astype(x.dtype()?)?;
let last = ndim - 1;
let head_dim_usize = x.shape()[last];
let head_dim = i32::try_from(head_dim_usize).map_err(|_| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"scaled RoPE head_dim exceeds i32::MAX",
"i32",
[("head_dim", head_dim_usize as u64)],
))
})?;
if head_dim == dims {
return x.multiply(&scalar);
}
if head_dim < dims {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"scaled RoPE: dims (configured rotation width vs input last-axis head_dim)",
"must be <= head_dim (input last-axis)",
format_smolstr!("dims={dims}, head_dim={head_dim}"),
)));
}
let axis = last as i32;
let parts = x.split_sections(&[dims], axis)?;
let head = &parts[0];
let tail = &parts[1];
let scaled_head = head.multiply(&scalar)?;
scaled_head.concatenate_with(&[tail], axis)
}
/// Rotary position embedding with Llama 3-style band-wise frequency scaling.
///
/// The per-pair frequencies are computed once in `Llama3Rope::new` and reused
/// by every `apply` / `apply_with_offset` call.
#[derive(Debug)]
pub struct Llama3Rope {
    /// Number of leading last-axis features that are rotated; validated to be
    /// positive and even at construction.
    dims: i32,
    /// Forwarded unchanged to `rope_with_freqs_offset` as its `traditional` flag.
    traditional: bool,
    /// Precomputed per-pair frequencies (`dims / 2` entries), stored as f32.
    freqs: Array,
}
/// Scaling parameters consumed by `Llama3Rope`.
///
/// With `wavelen = 2π·f` for each base pair frequency `f`, the two wavelength
/// thresholds are `original_max_position_embeddings / low_freq_factor` and
/// `original_max_position_embeddings / high_freq_factor`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Llama3ScalingConfig {
    /// Multiplier applied to frequencies whose wavelength exceeds the
    /// low-frequency threshold (blended in the middle band).
    pub factor: f32,
    /// Divisor of `original_max_position_embeddings` giving the upper
    /// wavelength threshold.
    pub low_freq_factor: f32,
    /// Divisor of `original_max_position_embeddings` giving the lower
    /// wavelength threshold.
    pub high_freq_factor: f32,
    /// Context length the thresholds are derived from (stored as f32).
    pub original_max_position_embeddings: f32,
}
impl Llama3ScalingConfig {
    /// Builds a config from all four fields, given in declaration order.
    pub fn new(
        factor: f32,
        low_freq_factor: f32,
        high_freq_factor: f32,
        original_max_position_embeddings: f32,
    ) -> Self {
        Self {
            original_max_position_embeddings,
            high_freq_factor,
            low_freq_factor,
            factor,
        }
    }
    /// Builds a config with the given `factor` and the remaining fields set
    /// to `low_freq_factor = 1.0`, `high_freq_factor = 4.0`, and
    /// `original_max_position_embeddings = 8192.0`.
    pub fn with_factor(factor: f32) -> Self {
        Self {
            factor,
            low_freq_factor: 1.0,
            high_freq_factor: 4.0,
            original_max_position_embeddings: 8192.0,
        }
    }
}
impl Llama3Rope {
    /// Builds a Llama 3-scaled RoPE over the first `dims` last-axis features.
    ///
    /// The per-pair frequencies are computed once here (via `compute_freqs`)
    /// and cached for every later `apply` call.
    ///
    /// # Errors
    ///
    /// - `Error::OutOfRange` if `dims` is not a positive even number.
    /// - `Error::OutOfRange` if `base` or any field of `scaling` is not a
    ///   positive finite number (checked in that order).
    /// - `Error::OutOfRange` if a computed frequency cannot be stored as a
    ///   positive f32 with a finite reciprocal.
    pub fn new(
        dims: i32,
        base: f32,
        traditional: bool,
        scaling: Llama3ScalingConfig,
    ) -> Result<Self> {
        let half = freqs_half(dims)?;
        let must_be_positive = [
            (base, "base"),
            (scaling.factor, "factor"),
            (scaling.low_freq_factor, "low_freq_factor"),
            (scaling.high_freq_factor, "high_freq_factor"),
            (
                scaling.original_max_position_embeddings,
                "original_max_position_embeddings",
            ),
        ];
        for (value, what) in must_be_positive {
            require_positive_input(value, what)?;
        }
        let raw = Self::compute_freqs(f64::from(base), dims, half, scaling);
        Ok(Self {
            dims,
            traditional,
            freqs: freqs_array(&raw)?,
        })
    }
    /// `new` with `DEFAULT_BASE` (from the sibling `rope` module) and
    /// `traditional = false`.
    pub fn standard(dims: i32, scaling: Llama3ScalingConfig) -> Result<Self> {
        Self::new(dims, DEFAULT_BASE, false, scaling)
    }
    /// Applies the Llama 3 band-wise rescaling to the base pair frequencies.
    ///
    /// With `wavelen = 2π·f`, `low_wl = old_ctx / low_freq_factor` and
    /// `high_wl = old_ctx / high_freq_factor`:
    /// - `wavelen > low_wl`: the frequency becomes `f * factor`;
    /// - `high_wl < wavelen <= low_wl`: `f / ((1 - s) / factor + s)` with
    ///   `s = (old_ctx / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)`;
    /// - otherwise `f` is kept.
    ///
    /// When `high_freq_factor <= low_freq_factor`, the middle band is empty,
    /// so the division by their difference is never evaluated.
    fn compute_freqs(base: f64, dims: i32, half: usize, c: Llama3ScalingConfig) -> Vec<f64> {
        let factor = f64::from(c.factor);
        let low_ff = f64::from(c.low_freq_factor);
        let high_ff = f64::from(c.high_freq_factor);
        let old_ctx = f64::from(c.original_max_position_embeddings);
        let (low_wl, high_wl) = (old_ctx / low_ff, old_ctx / high_ff);
        let mut freqs = base_pair_freqs(base, dims, half);
        for slot in &mut freqs {
            let f = *slot;
            let wavelen = 2.0 * std::f64::consts::PI * f;
            *slot = if wavelen > low_wl {
                f * factor
            } else if wavelen > high_wl {
                let s = (old_ctx / wavelen - low_ff) / (high_ff - low_ff);
                f / ((1.0 - s) / factor + s)
            } else {
                f
            };
        }
        freqs
    }
    /// Rotates `x` starting at the scalar position `offset`.
    pub fn apply(&self, x: &Array, offset: i32) -> Result<Array> {
        self.apply_with_offset(x, RopeOffsetRef::Scalar(offset))
    }
    /// Rotates `x` with the cached frequencies and a rope scale of `1.0`; any
    /// context scaling lives entirely in the precomputed frequencies.
    pub fn apply_with_offset(&self, x: &Array, offset: RopeOffsetRef<'_>) -> Result<Array> {
        let (dims, traditional) = (self.dims, self.traditional);
        rope_with_freqs_offset(x, dims, traditional, 1.0, offset, &self.freqs)
    }
}
/// Rotary position embedding whose per-pair frequencies are multiplied by a
/// caller-supplied `long_factor`, with an optional pre-scale of the input.
#[derive(Debug)]
pub struct SuScaledRope {
    /// Number of leading last-axis features that are rotated; validated to be
    /// positive and even at construction.
    dims: i32,
    /// Multiplier applied to the first `dims` input features before rotation
    /// (skipped when exactly `1.0`). It is either the `long_mscale` override or
    /// is derived from the context-extension ratio.
    scale: f32,
    /// Base pair frequencies multiplied element-wise by `long_factor`, stored
    /// as f32.
    freqs: Array,
}
impl SuScaledRope {
    /// Builds a SuScaled RoPE over the first `dims` features.
    ///
    /// `long_factor` must hold exactly `dims / 2` entries; entry `i` multiplies
    /// the `i`-th base pair frequency. The input pre-scale is `long_mscale`
    /// when given (only required to be finite). Otherwise it is derived from
    /// `ratio = max_position_embeddings / original_max_position_embeddings`:
    /// it is `1.0` when `ratio <= 1`, and
    /// `sqrt(1 + ln(ratio) / ln(original_max_position_embeddings))` otherwise.
    ///
    /// # Errors
    ///
    /// - `Error::OutOfRange` if `dims` is not a positive even number.
    /// - `Error::LengthMismatch` if `long_factor.len() != dims / 2`.
    /// - `Error::OutOfRange` if `base` or any `long_factor` entry is not
    ///   positive and finite, or if a resulting frequency is not storable as a
    ///   positive f32 with a finite reciprocal.
    /// - `Error::NonFiniteScalar` if `long_mscale` is NaN or infinite.
    /// - On the derived-scale path: `Error::OutOfRange` if either embedding
    ///   count is `<= 0`, or if `original_max_position_embeddings <= 1` while
    ///   `ratio > 1`; `Error::NonFiniteScalar` if the derived scale is not
    ///   finite in f32.
    pub fn new(
        dims: i32,
        base: f32,
        max_position_embeddings: i32,
        original_max_position_embeddings: i32,
        long_factor: &[f32],
        long_mscale: Option<f32>,
    ) -> Result<Self> {
        let half = freqs_half(dims)?;
        if long_factor.len() != half {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "SuScaledRoPE long_factor length vs dims/2",
                half,
                long_factor.len(),
            )));
        }
        require_positive_input(base, "base")?;
        long_factor
            .iter()
            .try_for_each(|&lf| require_positive_input(lf, "long_factor entry"))?;
        let scaled: Vec<f64> = base_pair_freqs(f64::from(base), dims, half)
            .into_iter()
            .zip(long_factor.iter().copied())
            .map(|(f, lf)| f64::from(lf) * f)
            .collect();
        let freqs = freqs_array(&scaled)?;
        let scale = if let Some(mscale) = long_mscale {
            require_finite_input(mscale, "long_mscale")?;
            mscale
        } else {
            Self::derived_scale(max_position_embeddings, original_max_position_embeddings)?
        };
        Ok(Self { dims, scale, freqs })
    }
    /// Derives the input pre-scale from the two embedding counts when no
    /// `long_mscale` override was given (see `new` for the formula and errors).
    fn derived_scale(
        max_position_embeddings: i32,
        original_max_position_embeddings: i32,
    ) -> Result<f32> {
        if original_max_position_embeddings <= 0 || max_position_embeddings <= 0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "SuScaledRoPE max_position_embeddings / original_max_position_embeddings (required positive to derive the input scale)",
                "must both be > 0",
                format_smolstr!(
                    "max_position_embeddings={max_position_embeddings}, original_max_position_embeddings={original_max_position_embeddings}"
                ),
            )));
        }
        let ratio = f64::from(max_position_embeddings) / f64::from(original_max_position_embeddings);
        if ratio <= 1.0 {
            return Ok(1.0);
        }
        // ln(1) == 0 would make the formula below divide by zero.
        if original_max_position_embeddings <= 1 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "SuScaledRoPE original_max_position_embeddings (derived-scale path, factor > 1)",
                "must be > 1 (ln(1) = 0 divides the scale to +Inf)",
                format_smolstr!("{original_max_position_embeddings}"),
            )));
        }
        let log_orig = f64::from(original_max_position_embeddings).ln();
        finite_scalar((1.0 + ratio.ln() / log_orig).sqrt(), "input scale")
    }
    /// The input pre-scale applied to the first `dims` features.
    pub fn scale(&self) -> f32 {
        self.scale
    }
    /// Rotates `x` starting at the scalar position `offset`.
    pub fn apply(&self, x: &Array, offset: i32) -> Result<Array> {
        self.apply_with_offset(x, RopeOffsetRef::Scalar(offset))
    }
    /// Pre-scales the first `dims` features of `x` by `scale` (skipped when
    /// the scale is exactly `1.0`), then rotates with the cached frequencies,
    /// `traditional = false`, and a rope scale of `1.0`.
    pub fn apply_with_offset(&self, x: &Array, offset: RopeOffsetRef<'_>) -> Result<Array> {
        let prescaled;
        let input = if self.scale == 1.0 {
            x
        } else {
            prescaled = scale_leading_dims(x, self.dims, self.scale)?;
            &prescaled
        };
        rope_with_freqs_offset(input, self.dims, false, 1.0, offset, &self.freqs)
    }
}
/// Rotary position embedding with YaRN-style frequency blending and an input
/// magnitude scale (`mscale`).
#[derive(Debug)]
pub struct YarnRope {
    /// Number of leading last-axis features that are rotated; validated to be
    /// positive and even at construction.
    dims: i32,
    /// Forwarded unchanged to `rope_with_freqs_offset` as its `traditional` flag.
    traditional: bool,
    /// Ratio of the two `get_mscale` terms computed in `new`. It is applied to
    /// the first `dims` input features unless it is within `f32::EPSILON` of
    /// `1.0`.
    mscale: f32,
    /// Per-pair blend of the interpolated (`scaling_factor * f`) and
    /// extrapolated (`f`) frequencies, stored as f32.
    freqs: Array,
}
/// Scaling parameters consumed by `YarnRope::new`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct YarnConfig {
    /// Multiplier forming the interpolated frequency `scaling_factor * f`. When
    /// it is `> 1`, it also enters `mscale` through `ln(scaling_factor)`.
    pub scaling_factor: f32,
    /// Context length used when locating the correction-dim range.
    pub original_max_position_embeddings: i32,
    /// Rotation count giving the lower correction bound (floored, clamped to `>= 0`).
    pub beta_fast: f32,
    /// Rotation count giving the upper correction bound (ceiled, clamped to `<= dims - 1`).
    pub beta_slow: f32,
    /// Coefficient in the numerator term of the `mscale` ratio.
    pub mscale: f32,
    /// Coefficient in the denominator term of the `mscale` ratio.
    pub mscale_all_dim: f32,
}
impl YarnConfig {
    /// Builds a config with the given `scaling_factor` and the remaining fields
    /// set to `original_max_position_embeddings = 4096`, `beta_fast = 32.0`,
    /// `beta_slow = 1.0`, `mscale = 1.0`, and `mscale_all_dim = 0.0`.
    pub fn new(scaling_factor: f32) -> Self {
        let (beta_fast, beta_slow) = (32.0, 1.0);
        let (mscale, mscale_all_dim) = (1.0, 0.0);
        Self {
            scaling_factor,
            original_max_position_embeddings: 4096,
            beta_fast,
            beta_slow,
            mscale,
            mscale_all_dim,
        }
    }
}
impl YarnRope {
pub fn new(dims: i32, base: f32, traditional: bool, config: YarnConfig) -> Result<Self> {
let half = freqs_half(dims)?;
require_finite_input(base, "base")?;
require_finite_input(config.scaling_factor, "scaling_factor")?;
require_finite_input(config.beta_fast, "beta_fast")?;
require_finite_input(config.beta_slow, "beta_slow")?;
require_finite_input(config.mscale, "mscale")?;
require_finite_input(config.mscale_all_dim, "mscale_all_dim")?;
let base = f64::from(base);
let dims_f = f64::from(dims);
let scaling_factor = f64::from(config.scaling_factor);
let orig_max = f64::from(config.original_max_position_embeddings);
if base <= 1.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"YarnRoPE: base",
"must be > 1 to derive correction dims",
format!("{base}"),
)));
}
if config.original_max_position_embeddings <= 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"YarnRoPE: original_max_position_embeddings",
"must be positive",
format!("{}", config.original_max_position_embeddings),
)));
}
if config.beta_fast <= 0.0 || config.beta_slow <= 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"YarnRoPE beta_fast / beta_slow",
"must both be > 0",
format_smolstr!(
"beta_fast={}, beta_slow={}",
config.beta_fast,
config.beta_slow
),
)));
}
if config.scaling_factor <= 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"YarnRoPE: scaling_factor",
"must be a positive value",
format!("{}", config.scaling_factor),
)));
}
let find_correction_dim = |num_rotations: f64| {
(dims_f * (orig_max / (num_rotations * 2.0 * std::f64::consts::PI)).ln()) / (2.0 * base.ln())
};
let low = find_correction_dim(f64::from(config.beta_fast)).floor();
let high = find_correction_dim(f64::from(config.beta_slow)).ceil();
let low = low.max(0.0);
let high = high.min(dims_f - 1.0);
let get_mscale = |scale: f64, mscale: f64| {
if scale <= 1.0 {
1.0
} else {
0.1 * mscale * scale.ln() + 1.0
}
};
let mscale = finite_scalar(
get_mscale(scaling_factor, f64::from(config.mscale))
/ get_mscale(scaling_factor, f64::from(config.mscale_all_dim)),
"mscale",
)?;
let extra = base_pair_freqs(base, dims, half);
let ramp_max = if (low - high).abs() < f64::EPSILON {
high + 0.001
} else {
high
};
let freqs: Vec<f64> = extra
.into_iter()
.enumerate()
.map(|(i, freq_extra)| {
let freq_inter = scaling_factor * freq_extra;
let linear = (i as f64 - low) / (ramp_max - low);
let ramp = linear.clamp(0.0, 1.0);
let freq_mask = 1.0 - ramp;
(freq_inter * freq_extra) / (freq_inter * freq_mask + freq_extra * (1.0 - freq_mask))
})
.collect();
let freqs = freqs_array(&freqs)?;
Ok(Self {
dims,
traditional,
mscale,
freqs,
})
}
pub fn standard(dims: i32, config: YarnConfig) -> Result<Self> {
Self::new(dims, DEFAULT_BASE, false, config)
}
pub fn mscale(&self) -> f32 {
self.mscale
}
pub fn apply(&self, x: &Array, offset: i32) -> Result<Array> {
self.apply_with_offset(x, RopeOffsetRef::Scalar(offset))
}
pub fn apply_with_offset(&self, x: &Array, offset: RopeOffsetRef<'_>) -> Result<Array> {
if (self.mscale - 1.0).abs() < f32::EPSILON {
rope_with_freqs_offset(x, self.dims, self.traditional, 1.0, offset, &self.freqs)
} else {
let scaled = scale_leading_dims(x, self.dims, self.mscale)?;
rope_with_freqs_offset(
&scaled,
self.dims,
self.traditional,
1.0,
offset,
&self.freqs,
)
}
}
}
#[cfg(test)]
// Hand-traced expectations below (e.g. `4681.4824`, `1.080123`) carry more
// digits than an f32 literal can represent. They are kept verbatim so they read
// exactly as traced; the surplus precision is intentional and harmless.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;
    use crate::{dtype::Dtype, lm::nn::rope::rope_with_freqs};

    /// Absolute tolerance used by `assert_close`.
    const TOL: f32 = 1e-5;

    /// Asserts that `got` and `want` have equal length and agree element-wise
    /// to within `TOL`.
    fn assert_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len(), "length mismatch");
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert!(
                (g - w).abs() <= TOL,
                "index {i}: got {g}, want {w} (|Δ|={})",
                (g - w).abs()
            );
        }
    }

    /// Builds a `(1, 1, 2, head_dim)` f32 array holding `0, 1, 2, ...`.
    fn input(head_dim: usize) -> Array {
        let n = 2 * head_dim;
        let data: Vec<f32> = (0..n).map(|v| v as f32).collect();
        Array::from_slice::<f32>(&data, &(1usize, 1usize, 2usize, head_dim)).unwrap()
    }

    /// `input(head_dim)` cast to `dtype`.
    fn input_dtype(head_dim: usize, dtype: Dtype) -> Array {
        input(head_dim).astype(dtype).unwrap()
    }

    /// Reference SuScaled freqs: `DEFAULT_BASE` pair freqs multiplied
    /// element-wise by `long_factor`, with one pair per `long_factor` entry.
    fn su_reference_freqs(dims: i32, long_factor: &[f32]) -> Array {
        let v: Vec<f64> = base_pair_freqs(f64::from(DEFAULT_BASE), dims, long_factor.len())
            .into_iter()
            .zip(long_factor)
            .map(|(f, &lf)| f64::from(lf) * f)
            .collect();
        freqs_array(&v).unwrap()
    }

    /// Multiplies every element of `x` by `scale` via a one-element f32 array.
    fn scaled_whole(x: &Array, scale: f32) -> Array {
        let scalar = Array::from_slice::<f32>(&[scale], &(1usize,)).unwrap();
        x.multiply(&scalar).unwrap()
    }

    /// Independent YaRN freqs reference mirroring the formula in `YarnRope::new`.
    fn yarn_reference_freqs(
        dims: i32,
        base: f64,
        scaling_factor: f64,
        orig_max: i32,
        beta_fast: f64,
        beta_slow: f64,
    ) -> Array {
        let half = (dims / 2) as usize;
        let dims_f = f64::from(dims);
        let orig = f64::from(orig_max);
        let cdim =
            |n: f64| (dims_f * (orig / (n * 2.0 * std::f64::consts::PI)).ln()) / (2.0 * base.ln());
        let low = cdim(beta_fast).floor().max(0.0);
        let high = cdim(beta_slow).ceil().min(dims_f - 1.0);
        let ramp_max = if (low - high).abs() < f64::EPSILON {
            high + 0.001
        } else {
            high
        };
        let extra = base_pair_freqs(base, dims, half);
        let v: Vec<f64> = extra
            .into_iter()
            .enumerate()
            .map(|(i, fe)| {
                let fi = scaling_factor * fe;
                let ramp = ((i as f64 - low) / (ramp_max - low)).clamp(0.0, 1.0);
                let mask = 1.0 - ramp;
                (fi * fe) / (fi * mask + fe * (1.0 - mask))
            })
            .collect();
        freqs_array(&v).unwrap()
    }

    #[test]
    fn freqs_half_rejects_odd_and_nonpositive() {
        assert!(freqs_half(0).is_err());
        assert!(freqs_half(-2).is_err());
        assert!(freqs_half(3).is_err());
        assert_eq!(freqs_half(8).unwrap(), 4);
    }

    #[test]
    fn base_pair_freqs_matches_formula() {
        let f = base_pair_freqs(10000.0, 8, 4);
        let got: Vec<f32> = f.iter().map(|&v| v as f32).collect();
        assert_close(&got, &[1.0, 10.0, 100.0, 1000.0]);
    }

    #[test]
    fn llama3_freqs_hand_traced() {
        let c = Llama3ScalingConfig::new(8.0, 1.0, 4.0, 8192.0);
        let f = Llama3Rope::compute_freqs(10000.0, 8, 4, c);
        let got: Vec<f32> = f.iter().map(|&v| v as f32).collect();
        assert_close(&got, &[1.0, 10.0, 100.0, 4681.4824]);
    }

    #[test]
    fn llama3_low_freq_band_scaled_by_factor() {
        // low_wl = 8192 / 1000 = 8.192 and high_wl = 8192 / 4000 = 2.048.
        // Pairs 1..=3 (wavelen 62.8, 628, 6283) exceed low_wl, so they are
        // multiplied by factor = 2.
        let c = Llama3ScalingConfig::new(2.0, 1000.0, 4000.0, 8192.0);
        let f = Llama3Rope::compute_freqs(10000.0, 8, 4, c);
        assert_close(
            &[f[1] as f32, f[2] as f32, f[3] as f32],
            &[20.0, 200.0, 2000.0],
        );
        // Pair 0 (wavelen 2π ≈ 6.28) sits in the middle band, so it must land
        // strictly between its unscaled (1.0) and fully scaled (2.0) values.
        assert!(f[0] > 1.0 && f[0] < 2.0, "middle-band freq {} not blended", f[0]);
    }

    #[test]
    fn llama3_apply_matches_freqs_path() {
        let x = input(8);
        let c = Llama3ScalingConfig::with_factor(8.0);
        let r = Llama3Rope::new(8, DEFAULT_BASE, false, c).unwrap();
        let freqs =
            freqs_array(&Llama3Rope::compute_freqs(f64::from(DEFAULT_BASE), 8, 4, c)).unwrap();
        let mut via_apply = r.apply(&x, 3).unwrap();
        let mut via_freqs = rope_with_freqs(&x, 8, false, 1.0, 3, &freqs).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &via_freqs.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn llama3_rejects_nonpositive_inputs() {
        let cases = [
            (
                DEFAULT_BASE,
                Llama3ScalingConfig::new(0.0, 1.0, 4.0, 8192.0),
                "factor=0 must be rejected",
            ),
            (
                DEFAULT_BASE,
                Llama3ScalingConfig::new(-8.0, 1.0, 4.0, 8192.0),
                "negative factor must be rejected",
            ),
            (0.0, Llama3ScalingConfig::with_factor(8.0), "base=0 must be rejected"),
            (
                -10000.0,
                Llama3ScalingConfig::with_factor(8.0),
                "negative base must be rejected",
            ),
            (
                DEFAULT_BASE,
                Llama3ScalingConfig::new(8.0, 0.0, 4.0, 8192.0),
                "low_freq_factor=0 must be rejected",
            ),
            (
                DEFAULT_BASE,
                Llama3ScalingConfig::new(8.0, 1.0, 0.0, 8192.0),
                "high_freq_factor=0 must be rejected",
            ),
            (
                DEFAULT_BASE,
                Llama3ScalingConfig::new(8.0, 1.0, 4.0, 0.0),
                "original_max_position_embeddings=0 must be rejected",
            ),
            (
                DEFAULT_BASE,
                Llama3ScalingConfig::new(8.0, 1.0, 4.0, -8192.0),
                "negative original_max_position_embeddings must be rejected",
            ),
        ];
        for (base, cfg, why) in cases {
            assert!(Llama3Rope::new(8, base, false, cfg).is_err(), "{why}");
        }
    }

    #[test]
    fn su_scaled_freqs_apply_long_factor() {
        let long_factor = [1.0f32, 2.0, 3.0, 4.0];
        let r = SuScaledRope::new(8, DEFAULT_BASE, 131072, 4096, &long_factor, None).unwrap();
        let freqs = Array::from_slice::<f32>(&[1.0, 20.0, 300.0, 4000.0], &(4usize,)).unwrap();
        let x = input(8);
        let scaled = scaled_whole(&x, r.scale());
        let mut manual = rope_with_freqs(&scaled, 8, false, 1.0, 3, &freqs).unwrap();
        let mut via_apply = r.apply(&x, 3).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &manual.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn su_scaled_default_scale_hand_traced() {
        let r = SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &[1.0, 1.0, 1.0, 1.0], None).unwrap();
        assert!(
            (r.scale() - 1.080123).abs() <= TOL,
            "scale {} != 1.080123",
            r.scale()
        );
    }

    #[test]
    fn su_scaled_factor_le_one_scale_is_one() {
        let r = SuScaledRope::new(8, DEFAULT_BASE, 4096, 4096, &[1.0, 1.0, 1.0, 1.0], None).unwrap();
        assert_eq!(r.scale(), 1.0);
    }

    #[test]
    fn su_scaled_long_mscale_override() {
        let r = SuScaledRope::new(8, DEFAULT_BASE, 131072, 4096, &[1.0; 4], Some(2.5)).unwrap();
        assert_eq!(r.scale(), 2.5);
    }

    #[test]
    fn su_scaled_rejects_wrong_long_factor_len() {
        assert!(SuScaledRope::new(8, DEFAULT_BASE, 131072, 4096, &[1.0, 2.0], None).is_err());
    }

    #[test]
    fn su_scaled_apply_equals_manual_scale_then_freqs() {
        let x = input(8);
        let long_factor = [1.0f32, 2.0, 3.0, 4.0];
        let r = SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &long_factor, None).unwrap();
        let freqs = su_reference_freqs(8, &long_factor);
        let scaled = scaled_whole(&x, r.scale());
        let mut manual = rope_with_freqs(&scaled, 8, false, 1.0, 5, &freqs).unwrap();
        let mut via_apply = r.apply(&x, 5).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &manual.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn su_scaled_rejects_nonpositive_embeddings_in_derived_scale() {
        assert!(SuScaledRope::new(8, DEFAULT_BASE, 16384, 0, &[1.0; 4], None).is_err());
        assert!(SuScaledRope::new(8, DEFAULT_BASE, 16384, -4096, &[1.0; 4], None).is_err());
        assert!(SuScaledRope::new(8, DEFAULT_BASE, 0, 4096, &[1.0; 4], None).is_err());
        assert!(SuScaledRope::new(8, DEFAULT_BASE, -1, 4096, &[1.0; 4], None).is_err());
    }

    #[test]
    fn su_scaled_override_skips_embeddings_validation() {
        let r = SuScaledRope::new(8, DEFAULT_BASE, 0, 0, &[1.0; 4], Some(1.5)).unwrap();
        assert_eq!(r.scale(), 1.5);
    }

    #[test]
    fn su_scaled_rejects_orig_max_one_on_derived_path() {
        // ln(1) == 0 would divide the derived scale to +Inf.
        assert!(
            SuScaledRope::new(8, DEFAULT_BASE, 2, 1, &[1.0; 4], None).is_err(),
            "orig_max=1 with factor>1 must be rejected"
        );
    }

    #[test]
    fn su_scaled_rejects_nonfinite_float_inputs() {
        assert!(
            SuScaledRope::new(8, f32::NAN, 16384, 4096, &[1.0; 4], None).is_err(),
            "NaN base"
        );
        assert!(
            SuScaledRope::new(8, f32::INFINITY, 16384, 4096, &[1.0; 4], None).is_err(),
            "Inf base"
        );
        assert!(
            SuScaledRope::new(
                8,
                DEFAULT_BASE,
                16384,
                4096,
                &[f32::NAN, 1.0, 1.0, 1.0],
                None
            )
            .is_err(),
            "NaN long_factor entry"
        );
        assert!(
            SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &[1.0; 4], Some(f32::INFINITY)).is_err(),
            "Inf long_mscale override"
        );
    }

    #[test]
    fn su_scaled_rejects_nonpositive_long_factor_or_base() {
        assert!(
            SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &[0.0, 1.0, 1.0, 1.0], None).is_err(),
            "zero long_factor entry must be rejected"
        );
        assert!(
            SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &[1.0, -2.0, 1.0, 1.0], None).is_err(),
            "negative long_factor entry must be rejected"
        );
        assert!(
            SuScaledRope::new(8, 0.0, 16384, 4096, &[1.0; 4], None).is_err(),
            "base=0 must be rejected"
        );
        assert!(
            SuScaledRope::new(8, -10000.0, 16384, 4096, &[1.0; 4], None).is_err(),
            "negative base must be rejected"
        );
    }

    #[test]
    fn su_scaled_rejects_subnormal_long_factor_with_inf_reciprocal() {
        assert!(
            SuScaledRope::new(
                2,
                DEFAULT_BASE,
                16384,
                4096,
                &[f32::from_bits(1)],
                Some(1.0)
            )
            .is_err(),
            "subnormal long_factor entry (1/f overflows) must be rejected at construction"
        );
    }

    #[test]
    fn su_scaled_valid_inputs_yield_finite_scale_and_freqs() {
        let r = SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &[1.0, 2.0, 3.0, 4.0], None).unwrap();
        assert!(r.scale().is_finite(), "scale must be finite");
        let mut freqs = r.freqs.try_clone().unwrap();
        for v in freqs.to_vec::<f32>().unwrap() {
            assert!(v.is_finite(), "non-finite freq {v}");
        }
    }

    #[test]
    fn su_scaled_scale_one_skip_path_matches_plain_freqs() {
        let long_factor = [1.0f32, 2.0, 3.0, 4.0];
        let r = SuScaledRope::new(8, DEFAULT_BASE, 4096, 4096, &long_factor, None).unwrap();
        assert_eq!(r.scale(), 1.0);
        let x = input(8);
        let freqs = su_reference_freqs(8, &long_factor);
        let mut via_apply = r.apply(&x, 5).unwrap();
        let mut via_freqs = rope_with_freqs(&x, 8, false, 1.0, 5, &freqs).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &via_freqs.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn su_scaled_scale_one_override_skip_path() {
        let long_factor = [1.0f32, 2.0, 3.0, 4.0];
        let r = SuScaledRope::new(8, DEFAULT_BASE, 131072, 4096, &long_factor, Some(1.0)).unwrap();
        assert_eq!(r.scale(), 1.0);
        let x = input(8);
        let freqs = su_reference_freqs(8, &long_factor);
        let mut via_apply = r.apply(&x, 2).unwrap();
        let mut via_freqs = rope_with_freqs(&x, 8, false, 1.0, 2, &freqs).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &via_freqs.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn yarn_freqs_and_mscale_hand_traced() {
        let cfg = YarnConfig {
            scaling_factor: 4.0,
            original_max_position_embeddings: 4096,
            beta_fast: 32.0,
            beta_slow: 1.0,
            mscale: 1.0,
            mscale_all_dim: 0.0,
        };
        let r = YarnRope::new(8, DEFAULT_BASE, false, cfg).unwrap();
        assert!(
            (r.mscale() - 1.138629).abs() <= TOL,
            "mscale {} != 1.138629",
            r.mscale()
        );
        // low = 1, high = 3, so the mask is [1, 1, 0.5, 0] and the freqs are
        // [1, 10, 40000 / 250, 4 * 1000].
        let hand_traced = [1.0, 10.0, 160.0, 4000.0];
        // Check the freqs the rope actually built, not just the test-local
        // reference implementation.
        let mut built = r.freqs.try_clone().unwrap();
        assert_close(&built.to_vec::<f32>().unwrap(), &hand_traced);
        let mut reference = yarn_reference_freqs(8, f64::from(DEFAULT_BASE), 4.0, 4096, 32.0, 1.0);
        assert_close(&reference.to_vec::<f32>().unwrap(), &hand_traced);
    }

    #[test]
    fn yarn_scaling_factor_le_one_mscale_is_one() {
        let cfg = YarnConfig::new(1.0);
        let r = YarnRope::new(8, DEFAULT_BASE, false, cfg).unwrap();
        assert_eq!(r.mscale(), 1.0);
        let x = input(8);
        let freqs = yarn_reference_freqs(8, f64::from(DEFAULT_BASE), 1.0, 4096, 32.0, 1.0);
        let mut via_apply = r.apply(&x, 4).unwrap();
        let mut via_freqs = rope_with_freqs(&x, 8, false, 1.0, 4, &freqs).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &via_freqs.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn yarn_apply_includes_mscale() {
        let cfg = YarnConfig::new(4.0);
        let r = YarnRope::new(8, DEFAULT_BASE, false, cfg).unwrap();
        let mscale = r.mscale();
        assert!((mscale - 1.0).abs() > TOL, "expected non-unit mscale");
        let x = input(8);
        let freqs = yarn_reference_freqs(8, f64::from(DEFAULT_BASE), 4.0, 4096, 32.0, 1.0);
        let scaled = scaled_whole(&x, mscale);
        let mut manual = rope_with_freqs(&scaled, 8, false, 1.0, 6, &freqs).unwrap();
        let mut via_apply = r.apply(&x, 6).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &manual.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn yarn_rejects_base_le_one() {
        let cfg = YarnConfig::new(4.0);
        assert!(YarnRope::new(8, 1.0, false, cfg).is_err());
        assert!(YarnRope::new(8, 0.0, false, cfg).is_err());
        assert!(YarnRope::new(8, -10.0, false, cfg).is_err());
    }

    #[test]
    fn yarn_rejects_nonpositive_orig_max() {
        let mut cfg = YarnConfig::new(4.0);
        cfg.original_max_position_embeddings = 0;
        assert!(YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err());
        cfg.original_max_position_embeddings = -4096;
        assert!(YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err());
    }

    #[test]
    fn yarn_rejects_nonpositive_betas() {
        let mut cfg = YarnConfig::new(4.0);
        cfg.beta_fast = 0.0;
        assert!(YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err());
        let mut cfg = YarnConfig::new(4.0);
        cfg.beta_slow = -1.0;
        assert!(YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err());
    }

    #[test]
    fn yarn_rejects_nonpositive_scaling_factor() {
        let mut cfg = YarnConfig::new(0.0);
        assert!(
            YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err(),
            "scaling_factor=0 must be rejected"
        );
        cfg.scaling_factor = -4.0;
        assert!(
            YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err(),
            "negative scaling_factor must be rejected"
        );
    }

    #[test]
    fn yarn_rejects_nonfinite_float_inputs() {
        assert!(
            YarnRope::new(8, f32::NAN, false, YarnConfig::new(4.0)).is_err(),
            "NaN base"
        );
        let mut cfg = YarnConfig::new(f32::INFINITY);
        assert!(
            YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err(),
            "Inf scaling_factor"
        );
        cfg = YarnConfig::new(4.0);
        cfg.beta_fast = f32::NAN;
        assert!(
            YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err(),
            "NaN beta_fast"
        );
        cfg = YarnConfig::new(4.0);
        cfg.mscale = f32::INFINITY;
        assert!(
            YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err(),
            "Inf mscale"
        );
        cfg = YarnConfig::new(4.0);
        cfg.mscale_all_dim = f32::NAN;
        assert!(
            YarnRope::new(8, DEFAULT_BASE, false, cfg).is_err(),
            "NaN mscale_all_dim"
        );
    }

    #[test]
    fn yarn_valid_inputs_yield_finite_freqs_and_mscale() {
        let cfg = YarnConfig::new(4.0);
        let r = YarnRope::new(8, DEFAULT_BASE, false, cfg).unwrap();
        assert!(r.mscale().is_finite());
        let mut freqs = r.freqs.try_clone().unwrap();
        for v in freqs.to_vec::<f32>().unwrap() {
            assert!(v.is_finite(), "non-finite freq {v}");
        }
    }

    #[test]
    fn scale_leading_dims_partial_keeps_tail() {
        let x = Array::from_slice::<f32>(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &(1usize, 1usize, 2usize, 6usize),
        )
        .unwrap();
        let mut scaled = scale_leading_dims(&x, 4, 2.0).unwrap();
        assert_close(
            &scaled.to_vec::<f32>().unwrap(),
            &[
                2.0, 4.0, 6.0, 8.0, 5.0, 6.0, 14.0, 16.0, 18.0, 20.0, 11.0, 12.0,
            ],
        );
    }

    #[test]
    fn scale_leading_dims_whole_axis() {
        let x =
            Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1usize, 1usize, 1usize, 4usize)).unwrap();
        let mut scaled = scale_leading_dims(&x, 4, 3.0).unwrap();
        assert_close(&scaled.to_vec::<f32>().unwrap(), &[3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn scale_leading_dims_rejects_dims_gt_head_dim() {
        let x = input(4);
        assert!(scale_leading_dims(&x, 8, 2.0).is_err());
    }

    #[test]
    fn su_scaled_partial_dims_apply() {
        let x = Array::from_slice::<f32>(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &(1usize, 1usize, 2usize, 6usize),
        )
        .unwrap();
        let long_factor = [1.0f32, 2.0];
        let r = SuScaledRope::new(4, DEFAULT_BASE, 16384, 4096, &long_factor, None).unwrap();
        let freqs = su_reference_freqs(4, &long_factor);
        let manual_scaled = scale_leading_dims(&x, 4, r.scale()).unwrap();
        let mut manual = rope_with_freqs(&manual_scaled, 4, false, 1.0, 2, &freqs).unwrap();
        let mut via_apply = r.apply(&x, 2).unwrap();
        assert_close(
            &via_apply.to_vec::<f32>().unwrap(),
            &manual.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn scale_leading_dims_preserves_half_dtype() {
        for dtype in [Dtype::F16, Dtype::BF16] {
            let whole = scale_leading_dims(&input_dtype(8, dtype), 8, 2.0).unwrap();
            assert_eq!(whole.dtype().unwrap(), dtype, "whole-axis dtype, {dtype:?}");
            let partial = scale_leading_dims(&input_dtype(8, dtype), 4, 2.0).unwrap();
            assert_eq!(
                partial.dtype().unwrap(),
                dtype,
                "partial-dims dtype, {dtype:?}"
            );
        }
    }

    #[test]
    fn su_scaled_apply_preserves_half_dtype() {
        for dtype in [Dtype::F16, Dtype::BF16] {
            let r = SuScaledRope::new(8, DEFAULT_BASE, 16384, 4096, &[1.0; 4], None).unwrap();
            assert!((r.scale() - 1.0).abs() > TOL, "expected non-unit scale");
            let out = r.apply(&input_dtype(8, dtype), 3).unwrap();
            assert_eq!(
                out.dtype().unwrap(),
                dtype,
                "Su head_dim==dims dtype, {dtype:?}"
            );
            let r_partial = SuScaledRope::new(4, DEFAULT_BASE, 16384, 4096, &[1.0, 2.0], None).unwrap();
            let out_partial = r_partial.apply(&input_dtype(8, dtype), 3).unwrap();
            assert_eq!(
                out_partial.dtype().unwrap(),
                dtype,
                "Su head_dim>dims dtype, {dtype:?}"
            );
        }
    }

    #[test]
    fn yarn_apply_preserves_half_dtype() {
        for dtype in [Dtype::F16, Dtype::BF16] {
            let cfg = YarnConfig::new(4.0);
            let r = YarnRope::new(8, DEFAULT_BASE, false, cfg).unwrap();
            assert!((r.mscale() - 1.0).abs() > TOL, "expected non-unit mscale");
            let out = r.apply(&input_dtype(8, dtype), 6).unwrap();
            assert_eq!(
                out.dtype().unwrap(),
                dtype,
                "YaRN head_dim==dims dtype, {dtype:?}"
            );
            let r_partial = YarnRope::new(4, DEFAULT_BASE, false, cfg).unwrap();
            assert!(
                (r_partial.mscale() - 1.0).abs() > TOL,
                "expected non-unit mscale"
            );
            let out_partial = r_partial.apply(&input_dtype(8, dtype), 6).unwrap();
            assert_eq!(
                out_partial.dtype().unwrap(),
                dtype,
                "YaRN head_dim>dims dtype, {dtype:?}"
            );
        }
    }
}