impl ScaledRoPE {
/// Build a scaled-RoPE module for rotary dimension `dim` with frequency
/// `base` and the given context-scaling strategy.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `dim` is zero or odd — RoPE
/// rotates adjacent `(x0, x1)` element pairs, so the dimension must be a
/// positive even number.
pub fn new(dim: usize, base: f32, scaling: RopeScalingType) -> Result<Self> {
    // Reject degenerate dimensions before doing any frequency math.
    if dim == 0 {
        return Err(RealizarError::InvalidShape {
            reason: "dim must be > 0".to_string(),
        });
    }
    if dim % 2 != 0 {
        return Err(RealizarError::InvalidShape {
            reason: "dim must be even for RoPE".to_string(),
        });
    }
    let (scaled_base, mscale, inv_freq) = Self::compute_frequencies(dim, base, &scaling);
    Ok(Self {
        dim,
        scaling,
        inv_freq,
        mscale,
        scaled_base,
        original_base: base,
    })
}
/// Convenience constructor using the canonical RoPE frequency base of 10 000.
///
/// # Errors
/// Propagates the same `InvalidShape` validation errors as [`Self::new`].
pub fn with_default_base(dim: usize, scaling: RopeScalingType) -> Result<Self> {
    // Default base from the original RoPE formulation.
    const DEFAULT_BASE: f32 = 10000.0;
    Self::new(dim, DEFAULT_BASE, scaling)
}
/// Derive the (possibly scaled) frequency base, the attention magnitude
/// scale (`mscale`), and the per-pair inverse-frequency table for `dim`.
///
/// Linear scaling leaves the base untouched (positions are divided at
/// forward time); NTK-family modes stretch the base instead.
fn compute_frequencies(
    dim: usize,
    base: f32,
    scaling: &RopeScalingType,
) -> (f32, f32, Vec<f32>) {
    #[allow(clippy::cast_precision_loss)]
    let dim_f = dim as f32;
    // NTK-aware scaling stretches the base by scale^(d / (d - 2)).
    let ntk_base = |scale: f32| base * scale.powf(dim_f / (dim_f - 2.0));
    let (scaled_base, mscale) = match scaling {
        // No base change; linear scaling is applied to positions in `forward`.
        RopeScalingType::None | RopeScalingType::Linear { .. } => (base, 1.0),
        RopeScalingType::Ntk { scale } => (ntk_base(*scale), 1.0),
        RopeScalingType::DynamicNtk {
            original_max_len,
            target_max_len,
        } => {
            #[allow(clippy::cast_precision_loss)]
            let scale = (*target_max_len as f32) / (*original_max_len as f32);
            (ntk_base(scale), 1.0)
        },
        RopeScalingType::Yarn {
            original_max_len,
            target_max_len,
            attn_factor,
            beta_fast,
            beta_slow,
        } => {
            // beta_fast / beta_slow only shape the per-frequency ramp in
            // `forward`; they do not affect the scaled base computed here.
            let _ = (beta_fast, beta_slow);
            #[allow(clippy::cast_precision_loss)]
            let scale = (*target_max_len as f32) / (*original_max_len as f32);
            // Attention temperature: honor an explicit attn_factor when
            // positive, otherwise derive one from the length ratio.
            let mscale = if *attn_factor > 0.0 {
                *attn_factor
            } else {
                #[allow(clippy::cast_precision_loss)]
                let log_orig = (*original_max_len as f32).ln();
                (1.0 + scale.ln() / log_orig).sqrt()
            };
            (ntk_base(scale), mscale)
        },
    };
    // inv_freq[i] = scaled_base^(-2i / d) for the first d/2 rotation pairs.
    #[allow(clippy::cast_precision_loss)]
    let inv_freq = (0..dim / 2)
        .map(|i| scaled_base.powf(-2.0 * (i as f32) / dim_f))
        .collect();
    (scaled_base, mscale, inv_freq)
}
/// Apply rotary position embedding to every `dim`-length vector in
/// `input` at sequence `position`.
///
/// Each adjacent `(x0, x1)` pair is rotated by a per-pair angle derived
/// from `inv_freq` and the (scaling-adjusted) position, then multiplied
/// by `mscale`. The output tensor has the same shape as the input.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `input` has no dimensions
/// or its last dimension does not equal `self.dim`.
pub fn forward(&self, input: &Tensor<f32>, position: usize) -> Result<Tensor<f32>> {
    let shape = input.shape();
    if shape.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Input tensor must have at least 1 dimension".to_string(),
        });
    }
    let last_dim = shape[shape.len() - 1];
    if last_dim != self.dim {
        return Err(RealizarError::InvalidShape {
            reason: format!("Expected last dimension {}, got {}", self.dim, last_dim),
        });
    }
    let data = input.data();
    let num_vectors = data.len() / self.dim;
    let half_dim = self.dim / 2;
    // Linear scaling stretches positions; every other mode uses the raw
    // position (their adjustment lives in the scaled base / YaRN ramp).
    #[allow(clippy::cast_precision_loss)]
    let effective_pos = match &self.scaling {
        RopeScalingType::None
        | RopeScalingType::Ntk { .. }
        | RopeScalingType::DynamicNtk { .. }
        | RopeScalingType::Yarn { .. } => position as f32,
        RopeScalingType::Linear { scale } => (position as f32) / scale,
    };
    // The YaRN wavelength cutoffs, length ratio, and interpolated position
    // are loop-invariant; hoist them instead of recomputing per frequency.
    #[allow(clippy::cast_precision_loss)]
    let yarn = if let RopeScalingType::Yarn {
        original_max_len,
        target_max_len,
        beta_fast,
        beta_slow,
        ..
    } = &self.scaling
    {
        let low_freq_wavelen = (*original_max_len as f32) / *beta_slow;
        let high_freq_wavelen = (*original_max_len as f32) / *beta_fast;
        let scale = (*target_max_len as f32) / (*original_max_len as f32);
        // Position under pure linear interpolation, used for long wavelengths.
        let linear_pos = effective_pos / scale;
        Some((low_freq_wavelen, high_freq_wavelen, linear_pos))
    } else {
        None
    };
    // Precompute cos/sin for each rotation pair at this position.
    let mut cos_vals = Vec::with_capacity(half_dim);
    let mut sin_vals = Vec::with_capacity(half_dim);
    #[allow(clippy::cast_precision_loss)]
    for (i, inv_f) in self.inv_freq.iter().enumerate() {
        let angle = inv_f * effective_pos;
        let final_angle = if let Some((low_freq_wavelen, high_freq_wavelen, linear_pos)) = yarn {
            // YaRN: blend the NTK-scaled angle with a linearly interpolated
            // angle, weighted per frequency band.
            let freq = 1.0 / inv_f;
            let wavelength = 2.0 * std::f32::consts::PI * freq;
            // ramp = 0 keeps the NTK angle (short wavelengths);
            // ramp = 1 uses pure linear interpolation (long wavelengths).
            let ramp = if wavelength < high_freq_wavelen {
                0.0
            } else if wavelength > low_freq_wavelen {
                1.0
            } else {
                (wavelength - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen)
            };
            // Angle the *unscaled* base would produce at the interpolated
            // position (depends on i, so it stays inside the loop).
            let orig_inv_f = self
                .original_base
                .powf(-2.0 * (i as f32) / (self.dim as f32));
            let linear_angle = orig_inv_f * linear_pos;
            angle * (1.0 - ramp) + linear_angle * ramp
        } else {
            angle
        };
        cos_vals.push(final_angle.cos());
        sin_vals.push(final_angle.sin());
    }
    // Rotate each (even, odd) pair and apply the attention temperature.
    let mut output = Vec::with_capacity(data.len());
    for vec_idx in 0..num_vectors {
        let offset = vec_idx * self.dim;
        for i in 0..half_dim {
            let x0 = data[offset + 2 * i];
            let x1 = data[offset + 2 * i + 1];
            let (cos_val, sin_val) = (cos_vals[i], sin_vals[i]);
            output.push((x0 * cos_val - x1 * sin_val) * self.mscale);
            output.push((x0 * sin_val + x1 * cos_val) * self.mscale);
        }
    }
    Tensor::from_vec(shape.to_vec(), output)
}
/// Rotary dimension this module was built for (required size of the
/// input tensor's last axis).
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
/// The unscaled frequency base passed to [`Self::new`].
#[must_use]
pub fn original_base(&self) -> f32 {
self.original_base
}
/// Frequency base after scaling adjustment; equals `original_base` for
/// the `None` and `Linear` modes, and is NTK-stretched otherwise.
#[must_use]
pub fn scaled_base(&self) -> f32 {
self.scaled_base
}
/// The configured context-scaling strategy.
#[must_use]
pub fn scaling(&self) -> &RopeScalingType {
&self.scaling
}
/// Per-pair inverse frequencies (`dim / 2` entries), computed from the
/// scaled base in `compute_frequencies`.
#[must_use]
pub fn inv_freq(&self) -> &[f32] {
&self.inv_freq
}
/// Attention magnitude scale multiplied into every rotated output;
/// 1.0 for all modes except YaRN.
#[must_use]
pub fn mscale(&self) -> f32 {
self.mscale
}
/// Factor by which the configured scaling extends the usable context
/// length relative to the original training length.
#[must_use]
pub fn context_length_multiplier(&self) -> f32 {
    match &self.scaling {
        // Unscaled RoPE keeps the original context window.
        RopeScalingType::None => 1.0,
        // Both carry the multiplier directly.
        RopeScalingType::Linear { scale } | RopeScalingType::Ntk { scale } => *scale,
        // Length-pair modes: the multiplier is the target/original ratio.
        RopeScalingType::DynamicNtk {
            original_max_len,
            target_max_len,
        }
        | RopeScalingType::Yarn {
            original_max_len,
            target_max_len,
            ..
        } => {
            let target = *target_max_len as f32;
            let original = *original_max_len as f32;
            target / original
        },
    }
}
}
/// State for ALiBi (Attention with Linear Biases): a set of per-head
/// slope values used to bias attention by token distance.
#[derive(Debug, Clone)]
pub struct ALiBi {
/// Number of attention heads the slopes were generated for.
num_heads: usize,
/// One slope value per head; presumably `slopes.len() == num_heads` —
/// TODO(review): confirm against the constructor (not visible in this chunk).
slopes: Vec<f32>,
}