use crate::{
error::{RealizarError, Result},
tensor::Tensor,
};
/// Rotary position embedding state: the vector width, the frequency base, and
/// the table of inverse frequencies derived from them.
#[derive(Debug, Clone)]
pub struct RoPE {
/// Width of the vectors this instance rotates. `RoPE::new` rejects zero and
/// odd values, and `forward` requires the input's last dimension to equal it.
dim: usize,
/// Base of the frequency progression (`10000.0` when built through
/// `with_default_base`).
base: f32,
/// One inverse frequency per element pair: entry `i` is `base^(-2i / dim)`
/// for `i` in `0..dim / 2`.
inv_freq: Vec<f32>,
}
impl RoPE {
    /// Builds a rotary embedding for vectors of width `dim` with frequency base `base`.
    ///
    /// Element pair `i` (for `i` in `0..dim / 2`) is assigned the inverse
    /// frequency `base^(-2i / dim)`.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::InvalidShape`] when `dim` is zero or odd.
    pub fn new(dim: usize, base: f32) -> Result<Self> {
        // Zero is handled first: it is a multiple of two, but it is not a usable width.
        let rejection = match dim {
            0 => Some("dim must be > 0"),
            d if !d.is_multiple_of(2) => Some("dim must be even for RoPE"),
            _ => None,
        };
        if let Some(reason) = rejection {
            return Err(RealizarError::InvalidShape {
                reason: reason.to_string(),
            });
        }

        // `usize -> f32` is exact for values up to 2^24; larger values are rounded.
        #[allow(clippy::cast_precision_loss)]
        let inv_freq: Vec<f32> = (0..dim / 2)
            .map(|pair| base.powf(-2.0 * (pair as f32) / (dim as f32)))
            .collect();

        Ok(Self {
            dim,
            base,
            inv_freq,
        })
    }

    /// Builds a rotary embedding with the base fixed to `10000.0`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`RoPE::new`].
    pub fn with_default_base(dim: usize) -> Result<Self> {
        const DEFAULT_BASE: f32 = 10000.0;
        Self::new(dim, DEFAULT_BASE)
    }

    /// Rotates `input` for the given `position` and returns a new tensor with
    /// the same shape.
    ///
    /// The flat data is processed as consecutive vectors of `dim` elements.
    /// Within each vector, the pair `(x[2i], x[2i + 1])` is rotated by the
    /// angle `position * inv_freq[i]`. The same `position` is applied to every
    /// vector in the tensor.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::InvalidShape`] when the tensor has no
    /// dimensions or when its last dimension differs from `dim`. Whatever
    /// `Tensor::from_vec` returns for the rotated data is passed through as is.
    pub fn forward(&self, input: &Tensor<f32>, position: usize) -> Result<Tensor<f32>> {
        let shape = input.shape();
        let Some(&last_dim) = shape.last() else {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor must have at least 1 dimension".to_string(),
            });
        };
        if last_dim != self.dim {
            return Err(RealizarError::InvalidShape {
                reason: format!("Expected last dimension {}, got {}", self.dim, last_dim),
            });
        }

        // The angles depend only on `position`, so compute each (cos, sin)
        // once and reuse it for every vector.
        #[allow(clippy::cast_precision_loss)]
        let pos = position as f32;
        let rotations: Vec<(f32, f32)> = self
            .inv_freq
            .iter()
            .map(|&freq| {
                let angle = freq * pos;
                (angle.cos(), angle.sin())
            })
            .collect();

        let data = input.data();
        let mut rotated = Vec::with_capacity(data.len());
        for vector in data.chunks_exact(self.dim) {
            for (pair, &(cos_val, sin_val)) in vector.chunks_exact(2).zip(&rotations) {
                let (x0, x1) = (pair[0], pair[1]);
                // Standard 2-D rotation of (x0, x1) by the pair's angle.
                rotated.push(x0 * cos_val - x1 * sin_val);
                rotated.push(x0 * sin_val + x1 * cos_val);
            }
        }

        Tensor::from_vec(shape.to_vec(), rotated)
    }

    /// Width of the vectors this embedding rotates.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Frequency base this embedding was built with.
    #[must_use]
    pub fn base(&self) -> f32 {
        self.base
    }

    /// Inverse-frequency table, one entry per element pair.
    #[must_use]
    pub fn inv_freq(&self) -> &[f32] {
        &self.inv_freq
    }
}
/// Selects a rotary-embedding scaling strategy together with the parameters
/// that strategy carries. The default value is `RopeScalingType::None`.
///
/// NOTE(review): how each variant is applied is decided by code outside this
/// chunk (presumably the `ScaledRoPE` implementation) — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RopeScalingType {
/// No scaling parameters; this is the `Default` variant.
#[default]
None,
/// Linear scaling driven by a single factor.
Linear {
/// Scaling factor; its exact use is not visible in this chunk.
scale: f32,
},
/// NTK scaling driven by a single factor (presumably "NTK-aware" base
/// rescaling — confirm against the consumer).
Ntk {
/// Scaling factor; its exact use is not visible in this chunk.
scale: f32,
},
/// Dynamic NTK scaling described by an original and a target length.
DynamicNtk {
/// Original maximum length (presumably in positions/tokens — confirm).
original_max_len: usize,
/// Target maximum length (presumably in positions/tokens — confirm).
target_max_len: usize,
},
/// YaRN scaling described by a length pair plus three tuning factors.
Yarn {
/// Original maximum length (presumably in positions/tokens — confirm).
original_max_len: usize,
/// Target maximum length (presumably in positions/tokens — confirm).
target_max_len: usize,
/// Attention factor; its exact use is not visible in this chunk.
attn_factor: f32,
/// "Fast" beta parameter; its exact use is not visible in this chunk.
beta_fast: f32,
/// "Slow" beta parameter; its exact use is not visible in this chunk.
beta_slow: f32,
},
}
/// Rotary position embedding state paired with a [`RopeScalingType`].
///
/// NOTE(review): nothing in this chunk constructs or reads these fields; the
/// implementation presumably comes from the `include!`d `scaled_rope.rs`, so
/// the field notes below are inferred from names and types only — confirm
/// against that file.
#[derive(Debug, Clone)]
pub struct ScaledRoPE {
/// Vector width (presumably with the same meaning as `RoPE`'s `dim`).
dim: usize,
/// Frequency base before any scaling is applied (presumably).
original_base: f32,
/// Frequency base after scaling is applied (presumably).
scaled_base: f32,
/// Scaling strategy and its parameters.
scaling: RopeScalingType,
/// Inverse-frequency table (presumably one entry per element pair).
inv_freq: Vec<f32>,
/// Magnitude scale factor — TODO confirm how it is derived and applied.
mscale: f32,
}
// The rest of this module is spliced in textually by `include!`, so the
// included items compile as part of this module and can access its private
// fields. Presumably these files hold the `ScaledRoPE` implementation and
// ALiBi-related code — confirm against the files themselves.
include!("scaled_rope.rs");
include!("alibi.rs");
include!("position_alibi_get.rs");