use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::grad_fns::activation::silu;
use ferrotorch_core::grad_fns::arithmetic::{add, mul};
use ferrotorch_core::grad_fns::shape::reshape;
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::attention::MultiheadAttention;
use crate::dropout::Dropout;
use crate::linear::Linear;
use crate::module::Module;
use crate::norm::LayerNorm;
use crate::parameter::Parameter;
/// Selects which elements of a row are rotated together by RoPE.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RoPEConvention {
    /// Rotates adjacent elements as a pair: `(x[2i], x[2i + 1])`.
    /// This is the default convention.
    #[default]
    Interleaved,
    /// Rotates element `i` of the first half of a row together with
    /// element `i` of the second half: `(x[i], x[i + dim / 2])`.
    HalfRotation,
}

impl std::fmt::Display for RoPEConvention {
    /// Writes the lowercase, snake_case name of the convention.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Interleaved => "interleaved",
            Self::HalfRotation => "half_rotation",
        };
        write!(f, "{label}")
    }
}
/// Autograd node holding what is needed to push a gradient back through a
/// rotary position embedding.
#[derive(Debug)]
struct RoPEBackward<T: Float> {
    /// Tensor that was rotated in the forward pass.
    input: Tensor<T>,
    /// Flattened cosine table; the row for position `p` starts at
    /// `p * half_dim`.
    cos_flat: Vec<T>,
    /// Flattened sine table, laid out like `cos_flat`.
    sin_flat: Vec<T>,
    /// Number of rotated pairs per row.
    half_dim: usize,
    /// Length of the sequence axis (second-to-last axis of `input`).
    seq_len: usize,
    /// Product of all axes in front of the sequence axis.
    batch_dims: usize,
    /// Length of the last axis of `input`.
    dim: usize,
    /// Position assigned to the first sequence element.
    seq_offset: usize,
    /// Pairing layout that was used in the forward pass.
    convention: RoPEConvention,
}

impl<T: Float> GradFn<T> for RoPEBackward<T> {
    /// Applies the inverse rotation (transpose of the forward rotation) to
    /// `grad_output`, row by row.
    ///
    /// Returns a single-element vector: `None` when the input does not
    /// require a gradient, otherwise the gradient with the input's shape,
    /// moved to the input's device when that device is CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let upstream = grad_output.data_vec()?;
        let mut grad_input = Vec::with_capacity(upstream.len());

        for batch_idx in 0..self.batch_dims {
            for step in 0..self.seq_len {
                // Table row for the absolute position of this sequence step.
                let table_at = (self.seq_offset + step) * self.half_dim;
                let cos_row = &self.cos_flat[table_at..table_at + self.half_dim];
                let sin_row = &self.sin_flat[table_at..table_at + self.half_dim];

                let row_at = batch_idx * self.seq_len * self.dim + step * self.dim;
                let grad_row = &upstream[row_at..row_at + self.dim];

                let angles = cos_row.iter().zip(sin_row);
                match self.convention {
                    RoPEConvention::Interleaved => {
                        for (pair, (&cos_val, &sin_val)) in grad_row.chunks_exact(2).zip(angles) {
                            let (g_even, g_odd) = (pair[0], pair[1]);
                            grad_input.push(g_even * cos_val + g_odd * sin_val);
                            grad_input.push(-g_even * sin_val + g_odd * cos_val);
                        }
                    }
                    RoPEConvention::HalfRotation => {
                        let (lower, upper) = grad_row.split_at(self.half_dim);
                        let paired = lower.iter().zip(upper).zip(angles);
                        // Gradients for the first half of the row ...
                        grad_input.extend(paired.clone().map(
                            |((&g_lo, &g_hi), (&cos_val, &sin_val))| {
                                g_lo * cos_val + g_hi * sin_val
                            },
                        ));
                        // ... followed by the gradients for the second half.
                        grad_input.extend(paired.map(
                            |((&g_lo, &g_hi), (&cos_val, &sin_val))| {
                                -g_lo * sin_val + g_hi * cos_val
                            },
                        ));
                    }
                }
            }
        }

        let grad = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;
        // The gradient is computed on the CPU; move it when the input is on CUDA.
        let grad = if self.input.is_cuda() {
            grad.to(self.input.device())?
        } else {
            grad
        };
        Ok(vec![Some(grad)])
    }

    /// The only differentiable input is the rotated tensor.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Name reported for this autograd node.
    fn name(&self) -> &'static str {
        "RoPEBackward"
    }
}
/// Rotary position embedding (RoPE) backed by precomputed cosine and sine
/// tables.
///
/// Both tables are CPU tensors of shape `[max_seq_len, dim / 2]`; entry
/// `(pos, i)` holds the cosine/sine of `pos / base^(2i / dim)`.
#[derive(Debug)]
pub struct RotaryPositionEmbedding<T: Float> {
    /// Required length of the last axis of every tensor passed to `apply`;
    /// always even and non-zero.
    dim: usize,
    /// Number of positions covered by the tables.
    max_seq_len: usize,
    /// Frequency base from which the rotation angles are derived.
    base: f64,
    /// Which elements of a row are rotated together.
    convention: RoPEConvention,
    /// Cosine table, shape `[max_seq_len, dim / 2]`.
    cos_cache: Tensor<T>,
    /// Sine table, shape `[max_seq_len, dim / 2]`.
    sin_cache: Tensor<T>,
}

impl<T: Float> RotaryPositionEmbedding<T> {
    /// Builds the embedding with the default convention
    /// ([`RoPEConvention::Interleaved`]).
    ///
    /// # Errors
    ///
    /// Same as [`Self::with_convention`].
    pub fn new(dim: usize, max_seq_len: usize, base: f64) -> FerrotorchResult<Self> {
        Self::with_convention(dim, max_seq_len, base, RoPEConvention::default())
    }

    /// Builds the embedding and precomputes the cosine/sine tables for
    /// positions `0..max_seq_len`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when `dim` is zero or odd, when
    /// `max_seq_len` is zero, when `base` is not a finite positive number,
    /// or when the table size `max_seq_len * dim / 2` does not fit in
    /// `usize`. Errors from tensor construction are propagated.
    pub fn with_convention(
        dim: usize,
        max_seq_len: usize,
        base: f64,
        convention: RoPEConvention,
    ) -> FerrotorchResult<Self> {
        if dim == 0 || dim % 2 != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("RoPE dim must be even and positive, got {dim}"),
            });
        }
        if max_seq_len == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "RoPE max_seq_len must be positive".into(),
            });
        }
        // A NaN, infinite, zero or negative base would silently fill the
        // tables with NaN/inf and corrupt every later `apply`.
        if !base.is_finite() || base <= 0.0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("RoPE base must be finite and positive, got {base}"),
            });
        }

        let half_dim = dim / 2;
        let total = max_seq_len.checked_mul(half_dim).ok_or_else(|| {
            FerrotorchError::InvalidArgument {
                message: format!(
                    "RoPE: table size max_seq_len ({max_seq_len}) * dim / 2 ({half_dim}) overflows usize"
                ),
            }
        })?;

        // theta_i = 1 / base^(2i / dim), one frequency per rotated pair.
        let thetas: Vec<f64> = (0..half_dim)
            .map(|i| 1.0 / base.powf(2.0 * i as f64 / dim as f64))
            .collect();

        let mut cos_data = Vec::with_capacity(total);
        let mut sin_data = Vec::with_capacity(total);
        for pos in 0..max_seq_len {
            for &theta in &thetas {
                let angle = pos as f64 * theta;
                cos_data.push(T::from(angle.cos()).expect("cosine in [-1, 1] must convert to T"));
                sin_data.push(T::from(angle.sin()).expect("sine in [-1, 1] must convert to T"));
            }
        }

        let cos_cache = Tensor::from_storage(
            TensorStorage::cpu(cos_data),
            vec![max_seq_len, half_dim],
            false,
        )?;
        let sin_cache = Tensor::from_storage(
            TensorStorage::cpu(sin_data),
            vec![max_seq_len, half_dim],
            false,
        )?;

        Ok(Self {
            dim,
            max_seq_len,
            base,
            convention,
            cos_cache,
            sin_cache,
        })
    }

    /// Rotates every row of `x`, treating the second-to-last axis as the
    /// sequence axis and assigning position `seq_offset + s` to sequence
    /// index `s`.
    ///
    /// The rotation is computed on the CPU; the result is moved to the
    /// device of `x` when that device is CUDA. When gradients are enabled
    /// and `x` requires a gradient, the result is attached to a
    /// `RoPEBackward` node.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when `x` has fewer than two axes or when
    /// `seq_offset + seq_len` exceeds `max_seq_len` (including the case in
    /// which the sum overflows `usize`), and `ShapeMismatch` when the last
    /// axis of `x` differs from `dim`.
    pub fn apply(&self, x: &Tensor<T>, seq_offset: usize) -> FerrotorchResult<Tensor<T>> {
        let shape = x.shape();
        let ndim = shape.len();
        if ndim < 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "RoPE input must be at least 2-D, got {ndim}-D with shape {shape:?}"
                ),
            });
        }
        let last_dim = shape[ndim - 1];
        if last_dim != self.dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!("RoPE: last dim of input ({last_dim}) != dim ({})", self.dim),
            });
        }
        let seq_len = shape[ndim - 2];
        // Checked addition: a wrapped sum would slip past the bound check
        // and index outside the tables below.
        let fits = matches!(
            seq_offset.checked_add(seq_len),
            Some(end) if end <= self.max_seq_len
        );
        if !fits {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "RoPE: seq_offset ({seq_offset}) + seq_len ({seq_len}) > max_seq_len ({})",
                    self.max_seq_len
                ),
            });
        }

        let device = x.device();
        let half_dim = self.dim / 2;
        let cos_data = self.cos_cache.data_vec()?;
        let sin_data = self.sin_cache.data_vec()?;
        let x_data = x.data_vec()?;
        // All leading axes are folded into one batch axis.
        let batch_dims: usize = shape[..ndim - 2].iter().product();
        let mut output = Vec::with_capacity(x.numel());

        for b in 0..batch_dims {
            for s in 0..seq_len {
                let cache_start = (seq_offset + s) * half_dim;
                let cos_row = &cos_data[cache_start..cache_start + half_dim];
                let sin_row = &sin_data[cache_start..cache_start + half_dim];

                let x_start = (b * seq_len + s) * self.dim;
                let x_row = &x_data[x_start..x_start + self.dim];

                let angles = cos_row.iter().zip(sin_row);
                match self.convention {
                    RoPEConvention::Interleaved => {
                        for (pair, (&cos_val, &sin_val)) in x_row.chunks_exact(2).zip(angles) {
                            let (x_even, x_odd) = (pair[0], pair[1]);
                            output.push(x_even * cos_val - x_odd * sin_val);
                            output.push(x_even * sin_val + x_odd * cos_val);
                        }
                    }
                    RoPEConvention::HalfRotation => {
                        let (first, second) = x_row.split_at(half_dim);
                        let paired = first.iter().zip(second).zip(angles);
                        // First half of the output row ...
                        output.extend(paired.clone().map(
                            |((&x_first, &x_second), (&cos_val, &sin_val))| {
                                x_first * cos_val - x_second * sin_val
                            },
                        ));
                        // ... then the second half.
                        output.extend(paired.map(
                            |((&x_first, &x_second), (&cos_val, &sin_val))| {
                                x_first * sin_val + x_second * cos_val
                            },
                        ));
                    }
                }
            }
        }

        let result = if is_grad_enabled() && x.requires_grad() {
            Tensor::from_operation(
                TensorStorage::cpu(output),
                shape.to_vec(),
                Arc::new(RoPEBackward {
                    input: x.clone(),
                    cos_flat: cos_data,
                    sin_flat: sin_data,
                    half_dim,
                    seq_len,
                    batch_dims,
                    dim: self.dim,
                    seq_offset,
                    convention: self.convention,
                }),
            )?
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?
        };

        if device.is_cuda() {
            result.to(device)
        } else {
            Ok(result)
        }
    }

    /// Length of the last axis this embedding accepts.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of positions covered by the precomputed tables.
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Frequency base the tables were built with.
    #[inline]
    pub fn base(&self) -> f64 {
        self.base
    }

    /// Pairing convention used by `apply`.
    #[inline]
    pub fn convention(&self) -> RoPEConvention {
        self.convention
    }
}
/// SwiGLU feed-forward block: `w3(silu(w1(x)) * w2(x))`.
///
/// Note the naming used here: `w1` is the gate projection, `w2` the "up"
/// projection and `w3` the output ("down") projection.
#[derive(Debug)]
pub struct SwiGLU<T: Float> {
    /// Gate projection, `in_features -> hidden_features`.
    w1: Linear<T>,
    /// Up projection, `in_features -> hidden_features`.
    w2: Linear<T>,
    /// Output projection, `hidden_features -> in_features`.
    w3: Linear<T>,
    /// Training-mode flag reported by `is_training`.
    training: bool,
}

impl<T: Float> SwiGLU<T> {
    /// Creates the three projections; `bias` is forwarded to each of them.
    /// The block starts in training mode.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Linear::new`.
    pub fn new(in_features: usize, hidden_features: usize, bias: bool) -> FerrotorchResult<Self> {
        let w1 = Linear::new(in_features, hidden_features, bias)?;
        let w2 = Linear::new(in_features, hidden_features, bias)?;
        let w3 = Linear::new(hidden_features, in_features, bias)?;
        Ok(Self {
            w1,
            w2,
            w3,
            training: true,
        })
    }

    /// Handles a 3-D input by flattening the two leading axes, running the
    /// 2-D path and restoring the leading axes afterwards.
    ///
    /// The caller must pass a tensor with exactly three axes.
    fn forward_3d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape();
        let (batch, seq_len, features) = (shape[0], shape[1], shape[2]);
        // The feature size is passed explicitly instead of `-1`: a `-1`
        // axis cannot be inferred when `batch * seq_len` is zero.
        let flat = reshape(input, &[(batch * seq_len) as isize, features as isize])?;
        let output_flat = self.forward_2d(&flat)?;
        let out_features = output_flat.shape()[1];
        reshape(
            &output_flat,
            &[batch as isize, seq_len as isize, out_features as isize],
        )
    }

    /// Computes `w3(silu(w1(input)) * w2(input))` for a 2-D input.
    fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let gate = silu(&self.w1.forward(input)?)?;
        let up = self.w2.forward(input)?;
        let gated = mul(&gate, &up)?;
        self.w3.forward(&gated)
    }
}
impl<T: Float> Module<T> for SwiGLU<T> {
    /// Runs the block on a 2-D or 3-D input; any other rank is rejected
    /// with `InvalidArgument`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let rank = input.ndim();
        if rank == 2 {
            return self.forward_2d(input);
        }
        if rank == 3 {
            return self.forward_3d(input);
        }
        Err(FerrotorchError::InvalidArgument {
            message: format!(
                "SwiGLU expects 2-D or 3-D input, got {}-D with shape {:?}",
                rank,
                input.shape()
            ),
        })
    }

    /// Parameters of `w1`, `w2` and `w3`, in that order.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.w1
            .parameters()
            .into_iter()
            .chain(self.w2.parameters())
            .chain(self.w3.parameters())
            .collect()
    }

    /// Mutable parameters of `w1`, `w2` and `w3`, in that order.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.w1
            .parameters_mut()
            .into_iter()
            .chain(self.w2.parameters_mut())
            .chain(self.w3.parameters_mut())
            .collect()
    }

    /// Parameters with their names prefixed by `w1.`, `w2.` or `w3.`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let gate = self
            .w1
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("w1.{name}"), param));
        let up = self
            .w2
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("w2.{name}"), param));
        let down = self
            .w3
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("w3.{name}"), param));
        gate.chain(up).chain(down).collect()
    }

    /// Switches the block and its three projections to training mode.
    fn train(&mut self) {
        self.training = true;
        for layer in [&mut self.w1, &mut self.w2, &mut self.w3] {
            layer.train();
        }
    }

    /// Switches the block and its three projections to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        for layer in [&mut self.w1, &mut self.w2, &mut self.w3] {
            layer.eval();
        }
    }

    /// Whether the block is currently in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Growing cache of attention keys and values.
///
/// Tensors are 4-D `[B, heads, seq, dim]` and new entries are appended
/// along axis 2, up to `max_seq_len` positions.
#[derive(Debug)]
pub struct KVCache<T: Float> {
    /// All keys seen so far, or `None` while the cache is empty.
    key_cache: Option<Tensor<T>>,
    /// All values seen so far, or `None` while the cache is empty.
    value_cache: Option<Tensor<T>>,
    /// Upper bound on the cached sequence length.
    max_seq_len: usize,
}

impl<T: Float> KVCache<T> {
    /// Creates an empty cache that accepts at most `max_seq_len` positions.
    pub fn new(max_seq_len: usize) -> Self {
        Self {
            key_cache: None,
            value_cache: None,
            max_seq_len,
        }
    }

    /// Appends `key` and `value` along axis 2 and returns the full cached
    /// key and value tensors.
    ///
    /// The cache is left untouched when an error is returned.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when either tensor is not 4-D or when the
    /// resulting sequence length would exceed `max_seq_len`, and
    /// `ShapeMismatch` when `key` and `value` have different shapes or do
    /// not line up with the cached tensors.
    pub fn update(
        &mut self,
        key: Tensor<T>,
        value: Tensor<T>,
    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
        let both_4d = key.ndim() == 4 && value.ndim() == 4;
        if !both_4d {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "KVCache expects 4-D [B, heads, seq, dim] tensors, \
                     got key {:?}, value {:?}",
                    key.shape(),
                    value.shape()
                ),
            });
        }
        if key.shape() != value.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "KVCache: key shape {:?} != value shape {:?}",
                    key.shape(),
                    value.shape()
                ),
            });
        }

        // Extend the existing entries, or start the cache with this pair.
        let (merged_key, merged_value) =
            if let (Some(past_key), Some(past_value)) = (&self.key_cache, &self.value_cache) {
                let extended_key = concat_along_dim2(past_key, &key)?;
                let extended_value = concat_along_dim2(past_value, &value)?;
                (extended_key, extended_value)
            } else {
                (key.clone(), value.clone())
            };

        let total_seq = merged_key.shape()[2];
        if total_seq > self.max_seq_len {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "KVCache: total sequence length ({total_seq}) exceeds max_seq_len ({})",
                    self.max_seq_len
                ),
            });
        }

        // Commit only after every check has passed.
        self.key_cache = Some(merged_key.clone());
        self.value_cache = Some(merged_value.clone());
        Ok((merged_key, merged_value))
    }

    /// Drops all cached keys and values.
    pub fn reset(&mut self) {
        self.key_cache = None;
        self.value_cache = None;
    }

    /// Number of cached positions (size of axis 2), or 0 when empty.
    pub fn seq_len(&self) -> usize {
        match &self.key_cache {
            Some(cached) => cached.shape()[2],
            None => 0,
        }
    }

    /// Whether nothing has been cached yet.
    pub fn is_empty(&self) -> bool {
        self.key_cache.is_none()
    }

    /// Maximum number of positions the cache accepts.
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
}
/// Concatenates two 4-D tensors along axis 2.
///
/// Axes 0, 1 and 3 must agree, otherwise `ShapeMismatch` is returned. The
/// data is assembled on the CPU into a tensor that does not require a
/// gradient, and is moved to the device of `a` when that device is CUDA.
/// Both inputs are indexed as 4-D; callers must guarantee that rank.
fn concat_along_dim2<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let sa = a.shape();
    let sb = b.shape();
    let compatible = sa[0] == sb[0] && sa[1] == sb[1] && sa[3] == sb[3];
    if !compatible {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "concat_along_dim2: shapes {:?} and {:?} must match on dims 0, 1, 3",
                sa, sb
            ),
        });
    }

    let target = a.device();
    let (batch, heads, dim) = (sa[0], sa[1], sa[3]);
    let (seq_a, seq_b) = (sa[2], sb[2]);
    // Number of elements each input contributes per (batch, head) pair.
    let a_block = seq_a * dim;
    let b_block = seq_b * dim;

    let a_data = a.data_vec()?;
    let b_data = b.data_vec()?;
    let mut merged = Vec::with_capacity(batch * heads * (seq_a + seq_b) * dim);
    // For every (batch, head) pair, emit the rows of `a` followed by the rows of `b`.
    for block in 0..batch * heads {
        merged.extend_from_slice(&a_data[block * a_block..(block + 1) * a_block]);
        merged.extend_from_slice(&b_data[block * b_block..(block + 1) * b_block]);
    }

    let combined = Tensor::from_storage(
        TensorStorage::cpu(merged),
        vec![batch, heads, seq_a + seq_b, dim],
        false,
    )?;
    if target.is_cuda() {
        combined.to(target)
    } else {
        Ok(combined)
    }
}
/// Transformer encoder layer: self-attention and a SwiGLU feed-forward
/// block, each preceded by a layer norm and wrapped in a residual
/// connection.
#[derive(Debug)]
pub struct TransformerEncoderLayer<T: Float> {
    /// Self-attention sub-layer.
    self_attn: MultiheadAttention<T>,
    /// Feed-forward sub-layer.
    ffn: SwiGLU<T>,
    /// Norm applied before self-attention.
    norm1: LayerNorm<T>,
    /// Norm applied before the feed-forward block.
    norm2: LayerNorm<T>,
    /// Dropout applied to the output of both sub-layers.
    dropout: Dropout<T>,
    /// Training-mode flag reported by `is_training`.
    training: bool,
}

impl<T: Float> TransformerEncoderLayer<T> {
    /// Builds the layer; it starts in training mode.
    ///
    /// `d_model` and `num_heads` configure the attention, `d_ff` is the
    /// hidden width of the feed-forward block, `dropout_p` is handed to
    /// `Dropout::new`, `layer_norm_eps` to both layer norms, and `bias` to
    /// the attention and feed-forward projections.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by a sub-module constructor.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Field initialisers run top to bottom, so sub-modules are still
        // constructed in the order attention, ffn, norms, dropout.
        Ok(Self {
            self_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            ffn: SwiGLU::new(d_model, d_ff, bias)?,
            norm1: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm2: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            dropout: Dropout::new(dropout_p)?,
            training: true,
        })
    }
}
impl<T: Float> Module<T> for TransformerEncoderLayer<T> {
    /// Runs the layer on a 3-D `[batch, seq, d_model]` input:
    /// `y = x + dropout(attn(norm1(x)))`, then
    /// `y + dropout(ffn(norm2(y)))`.
    ///
    /// Returns `InvalidArgument` for inputs that are not 3-D.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "TransformerEncoderLayer expects 3-D [batch, seq, d_model], got {:?}",
                    input.shape()
                ),
            });
        }

        // Self-attention branch.
        let attn_branch = {
            let normed = self.norm1.forward(input)?;
            let attended = self.self_attn.forward(&normed)?;
            self.dropout.forward(&attended)?
        };
        let after_attn = add(input, &attn_branch)?;

        // Feed-forward branch.
        let ffn_branch = {
            let normed = self.norm2.forward(&after_attn)?;
            let expanded = self.ffn.forward(&normed)?;
            self.dropout.forward(&expanded)?
        };
        let output = add(&after_attn, &ffn_branch)?;
        Ok(output)
    }

    /// Parameters in the order: attention, feed-forward, `norm1`, `norm2`.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.self_attn
            .parameters()
            .into_iter()
            .chain(self.ffn.parameters())
            .chain(self.norm1.parameters())
            .chain(self.norm2.parameters())
            .collect()
    }

    /// Mutable parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.self_attn
            .parameters_mut()
            .into_iter()
            .chain(self.ffn.parameters_mut())
            .chain(self.norm1.parameters_mut())
            .chain(self.norm2.parameters_mut())
            .collect()
    }

    /// Parameters with names prefixed by the owning field
    /// (`self_attn.`, `ffn.`, `norm1.`, `norm2.`).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let attn = self
            .self_attn
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("self_attn.{name}"), param));
        let ffn = self
            .ffn
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("ffn.{name}"), param));
        let norm1 = self
            .norm1
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("norm1.{name}"), param));
        let norm2 = self
            .norm2
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("norm2.{name}"), param));
        attn.chain(ffn).chain(norm1).chain(norm2).collect()
    }

    /// Puts the layer and every sub-module into training mode.
    fn train(&mut self) {
        self.self_attn.train();
        self.ffn.train();
        self.norm1.train();
        self.norm2.train();
        self.dropout.train();
        self.training = true;
    }

    /// Puts the layer and every sub-module into evaluation mode.
    fn eval(&mut self) {
        self.self_attn.eval();
        self.ffn.eval();
        self.norm1.eval();
        self.norm2.eval();
        self.dropout.eval();
        self.training = false;
    }

    /// Whether the layer is currently in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Transformer decoder layer: self-attention, cross-attention over an
/// external memory and a SwiGLU feed-forward block, each preceded by a
/// layer norm and wrapped in a residual connection.
#[derive(Debug)]
pub struct TransformerDecoderLayer<T: Float> {
    /// Self-attention over the decoder input.
    self_attn: MultiheadAttention<T>,
    /// Attention whose keys and values come from the memory tensor.
    cross_attn: MultiheadAttention<T>,
    /// Feed-forward sub-layer.
    ffn: SwiGLU<T>,
    /// Norm applied before self-attention.
    norm1: LayerNorm<T>,
    /// Norm applied before cross-attention.
    norm2: LayerNorm<T>,
    /// Norm applied before the feed-forward block.
    norm3: LayerNorm<T>,
    /// Dropout applied to the output of all three sub-layers.
    dropout: Dropout<T>,
    /// Training-mode flag reported by `is_training`.
    training: bool,
}

impl<T: Float> TransformerDecoderLayer<T> {
    /// Builds the layer; it starts in training mode.
    ///
    /// `d_model` and `num_heads` configure both attention modules, `d_ff`
    /// is the hidden width of the feed-forward block, `dropout_p` is handed
    /// to `Dropout::new`, `layer_norm_eps` to the three layer norms, and
    /// `bias` to the attention and feed-forward projections.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by a sub-module constructor.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Field initialisers run top to bottom, preserving the original
        // construction order of the sub-modules.
        Ok(Self {
            self_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            cross_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            ffn: SwiGLU::new(d_model, d_ff, bias)?,
            norm1: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm2: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm3: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            dropout: Dropout::new(dropout_p)?,
            training: true,
        })
    }

    /// Runs the layer on `input`, attending to `memory` in the
    /// cross-attention step. Both tensors must be 3-D.
    ///
    /// Steps, each followed by dropout and a residual add: self-attention
    /// on `norm1(input)`, cross-attention with `norm2(..)` as query and
    /// `memory` as key/value, feed-forward on `norm3(..)`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when either tensor is not 3-D; errors from
    /// the sub-modules are propagated.
    pub fn forward_with_memory(
        &self,
        input: &Tensor<T>,
        memory: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        let both_3d = input.ndim() == 3 && memory.ndim() == 3;
        if !both_3d {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "TransformerDecoderLayer expects 3-D inputs, \
                     got input {:?}, memory {:?}",
                    input.shape(),
                    memory.shape()
                ),
            });
        }

        // Self-attention branch. The final `true` is presumably the
        // causal-mask switch of `forward_qkv` -- confirm against its docs.
        let self_branch = {
            let normed = self.norm1.forward(input)?;
            let attended = self.self_attn.forward_qkv(&normed, &normed, &normed, true)?;
            self.dropout.forward(&attended)?
        };
        let after_self = add(input, &self_branch)?;

        // Cross-attention branch: query from the decoder, key/value from memory.
        let cross_branch = {
            let normed = self.norm2.forward(&after_self)?;
            let attended = self.cross_attn.forward_qkv(&normed, memory, memory, false)?;
            self.dropout.forward(&attended)?
        };
        let after_cross = add(&after_self, &cross_branch)?;

        // Feed-forward branch.
        let ffn_branch = {
            let normed = self.norm3.forward(&after_cross)?;
            let expanded = self.ffn.forward(&normed)?;
            self.dropout.forward(&expanded)?
        };
        let output = add(&after_cross, &ffn_branch)?;
        Ok(output)
    }
}
impl<T: Float> Module<T> for TransformerDecoderLayer<T> {
    /// Runs the layer with `input` doubling as the cross-attention memory.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let memory = input;
        self.forward_with_memory(input, memory)
    }

    /// Parameters in the order: self-attention, cross-attention,
    /// feed-forward, `norm1`, `norm2`, `norm3`.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.self_attn
            .parameters()
            .into_iter()
            .chain(self.cross_attn.parameters())
            .chain(self.ffn.parameters())
            .chain(self.norm1.parameters())
            .chain(self.norm2.parameters())
            .chain(self.norm3.parameters())
            .collect()
    }

    /// Mutable parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.self_attn
            .parameters_mut()
            .into_iter()
            .chain(self.cross_attn.parameters_mut())
            .chain(self.ffn.parameters_mut())
            .chain(self.norm1.parameters_mut())
            .chain(self.norm2.parameters_mut())
            .chain(self.norm3.parameters_mut())
            .collect()
    }

    /// Parameters with names prefixed by the owning field (`self_attn.`,
    /// `cross_attn.`, `ffn.`, `norm1.`, `norm2.`, `norm3.`).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let self_attn = self
            .self_attn
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("self_attn.{name}"), param));
        let cross_attn = self
            .cross_attn
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("cross_attn.{name}"), param));
        let ffn = self
            .ffn
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("ffn.{name}"), param));
        let norm1 = self
            .norm1
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("norm1.{name}"), param));
        let norm2 = self
            .norm2
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("norm2.{name}"), param));
        let norm3 = self
            .norm3
            .named_parameters()
            .into_iter()
            .map(|(name, param)| (format!("norm3.{name}"), param));
        self_attn
            .chain(cross_attn)
            .chain(ffn)
            .chain(norm1)
            .chain(norm2)
            .chain(norm3)
            .collect()
    }

    /// Puts the layer and every sub-module into training mode.
    fn train(&mut self) {
        self.self_attn.train();
        self.cross_attn.train();
        self.ffn.train();
        self.norm1.train();
        self.norm2.train();
        self.norm3.train();
        self.dropout.train();
        self.training = true;
    }

    /// Puts the layer and every sub-module into evaluation mode.
    fn eval(&mut self) {
        self.self_attn.eval();
        self.cross_attn.eval();
        self.ffn.eval();
        self.norm1.eval();
        self.norm2.eval();
        self.norm3.eval();
        self.dropout.eval();
        self.training = false;
    }

    /// Whether the layer is currently in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
// Unit tests for the rotary embedding, SwiGLU, the KV cache and the
// transformer encoder/decoder layers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// ---------------------------------------------------------------------
// RotaryPositionEmbedding
// ---------------------------------------------------------------------
// Valid arguments construct successfully and the getters echo them back.
#[test]
fn test_rope_construction() {
let rope = RotaryPositionEmbedding::<f32>::new(64, 512, 10000.0);
assert!(rope.is_ok());
let rope = rope.unwrap();
assert_eq!(rope.dim(), 64);
assert_eq!(rope.max_seq_len(), 512);
assert_eq!(rope.base(), 10000.0);
}
// An odd `dim` is rejected.
#[test]
fn test_rope_odd_dim_rejected() {
assert!(RotaryPositionEmbedding::<f32>::new(63, 512, 10000.0).is_err());
}
// A zero `dim` is rejected.
#[test]
fn test_rope_zero_dim_rejected() {
assert!(RotaryPositionEmbedding::<f32>::new(0, 512, 10000.0).is_err());
}
// A zero `max_seq_len` is rejected.
#[test]
fn test_rope_zero_seq_rejected() {
assert!(RotaryPositionEmbedding::<f32>::new(64, 0, 10000.0).is_err());
}
// `apply` keeps the shape of a 2-D `[seq, dim]` input.
#[test]
fn test_rope_output_shape_2d() {
let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
let x = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
let y = rope.apply(&x, 0).unwrap();
assert_eq!(y.shape(), &[4, 8]);
}
// `apply` keeps the shape of a 3-D input.
#[test]
fn test_rope_output_shape_3d() {
let rope = RotaryPositionEmbedding::<f32>::new(16, 256, 10000.0).unwrap();
let x = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
let y = rope.apply(&x, 0).unwrap();
assert_eq!(y.shape(), &[2, 10, 16]);
}
// `apply` keeps the shape of a 4-D input.
#[test]
fn test_rope_output_shape_4d() {
let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
let x = ferrotorch_core::zeros::<f32>(&[2, 4, 6, 8]).unwrap();
let y = rope.apply(&x, 0).unwrap();
assert_eq!(y.shape(), &[2, 4, 6, 8]);
}
// A non-zero offset inside the table range is accepted; only the shape is checked.
#[test]
fn test_rope_with_offset() {
let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
let x = ferrotorch_core::ones::<f32>(&[4, 8]).unwrap();
let y = rope.apply(&x, 10).unwrap();
assert_eq!(y.shape(), &[4, 8]);
}
// offset 10 + seq_len 10 exceeds max_seq_len 16 and must be rejected.
#[test]
fn test_rope_offset_overflow_rejected() {
let rope = RotaryPositionEmbedding::<f32>::new(8, 16, 10000.0).unwrap();
let x = ferrotorch_core::zeros::<f32>(&[10, 8]).unwrap();
assert!(rope.apply(&x, 10).is_err());
}
// At position 0 every angle is 0, so the rotation must leave the input unchanged.
#[test]
fn test_rope_position_zero_is_identity() {
let rope = RotaryPositionEmbedding::<f64>::new(4, 64, 10000.0).unwrap();
let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
let y = rope.apply(&x, 0).unwrap();
let y_data = y.data().unwrap();
let x_data = x.data().unwrap();
for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
assert!(
(xv - yv).abs() < 1e-10,
"position 0 should be identity, index {i}: x={xv}, y={yv}"
);
}
}
// Every output element is finite for an all-ones 4-D input.
#[test]
fn test_rope_values_are_finite() {
let rope = RotaryPositionEmbedding::<f32>::new(16, 512, 10000.0).unwrap();
let x = ferrotorch_core::ones::<f32>(&[2, 4, 10, 16]).unwrap();
let y = rope.apply(&x, 0).unwrap();
for &v in y.data().unwrap() {
assert!(v.is_finite(), "RoPE produced non-finite value: {v}");
}
}
// The last axis of the input (10) must equal the embedding `dim` (8).
#[test]
fn test_rope_wrong_dim_rejected() {
let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
let x = ferrotorch_core::zeros::<f32>(&[4, 10]).unwrap(); assert!(rope.apply(&x, 0).is_err());
}
// `with_convention` stores the requested convention.
#[test]
fn test_rope_half_rotation_construction() {
let rope = RotaryPositionEmbedding::<f32>::with_convention(
8,
128,
10000.0,
RoPEConvention::HalfRotation,
)
.unwrap();
assert_eq!(rope.convention(), RoPEConvention::HalfRotation);
}
// The half-rotation convention keeps the input shape.
#[test]
fn test_rope_half_rotation_output_shape() {
let rope = RotaryPositionEmbedding::<f32>::with_convention(
8,
128,
10000.0,
RoPEConvention::HalfRotation,
)
.unwrap();
let x = ferrotorch_core::zeros::<f32>(&[2, 4, 8]).unwrap();
let y = rope.apply(&x, 0).unwrap();
assert_eq!(y.shape(), &[2, 4, 8]);
}
// Position 0 is the identity for the half-rotation convention as well.
#[test]
fn test_rope_half_rotation_position_zero_is_identity() {
let rope = RotaryPositionEmbedding::<f64>::with_convention(
4,
64,
10000.0,
RoPEConvention::HalfRotation,
)
.unwrap();
let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
let y = rope.apply(&x, 0).unwrap();
let x_data = x.data().unwrap();
let y_data = y.data().unwrap();
for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
assert!(
(xv - yv).abs() < 1e-10,
"half-rot pos 0 should be identity, index {i}: x={xv}, y={yv}"
);
}
}
// Compares the half-rotation output at position 1 against values computed
// by hand from the cached tables.
#[test]
fn test_rope_half_rotation_correctness() {
let rope = RotaryPositionEmbedding::<f64>::with_convention(
4,
64,
10000.0,
RoPEConvention::HalfRotation,
)
.unwrap();
let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
let y = rope.apply(&x, 1).unwrap();
let cos_data = rope.cos_cache.data().unwrap();
let sin_data = rope.sin_cache.data().unwrap();
// With dim = 4 each table row has 2 entries, so position 1 is at indices 2 and 3.
let c0 = cos_data[2];
let c1 = cos_data[3];
let s0 = sin_data[2];
let s1 = sin_data[3];
// Pairs are (x[0], x[2]) and (x[1], x[3]).
let expected = [
1.0 * c0 - 3.0 * s0,
2.0 * c1 - 4.0 * s1,
1.0 * s0 + 3.0 * c0,
2.0 * s1 + 4.0 * c1,
];
let y_data = y.data().unwrap();
for (i, (&actual, &exp)) in y_data.iter().zip(expected.iter()).enumerate() {
assert!(
(actual - exp).abs() < 1e-10,
"half-rot index {i}: actual={actual}, expected={exp}"
);
}
}
// The two conventions must give different results away from position 0.
#[test]
fn test_rope_interleaved_vs_half_rotation_differ() {
let rope_il = RotaryPositionEmbedding::<f64>::with_convention(
4,
64,
10000.0,
RoPEConvention::Interleaved,
)
.unwrap();
let rope_hr = RotaryPositionEmbedding::<f64>::with_convention(
4,
64,
10000.0,
RoPEConvention::HalfRotation,
)
.unwrap();
let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
let y_il = rope_il.apply(&x, 1).unwrap();
let y_hr = rope_hr.apply(&x, 1).unwrap();
let il_data = y_il.data().unwrap();
let hr_data = y_hr.data().unwrap();
let any_differ = il_data
.iter()
.zip(hr_data.iter())
.any(|(&a, &b)| (a - b).abs() > 1e-10);
assert!(
any_differ,
"interleaved and half-rotation should produce different outputs at pos > 0"
);
}
// `new` uses the interleaved convention.
#[test]
fn test_rope_default_convention_is_interleaved() {
let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
assert_eq!(rope.convention(), RoPEConvention::Interleaved);
}
// ---------------------------------------------------------------------
// SwiGLU
// ---------------------------------------------------------------------
// Construction with bias succeeds.
#[test]
fn test_swiglu_construction() {
let swiglu = SwiGLU::<f32>::new(64, 128, true);
assert!(swiglu.is_ok());
}
// A 2-D input keeps its shape.
#[test]
fn test_swiglu_forward_shape_2d() {
let swiglu = SwiGLU::<f32>::new(16, 32, true).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[4, 16]).unwrap();
let output = swiglu.forward(&input).unwrap();
assert_eq!(output.shape(), &[4, 16]);
}
// A 3-D input keeps its shape.
#[test]
fn test_swiglu_forward_shape_3d() {
let swiglu = SwiGLU::<f32>::new(16, 32, false).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
let output = swiglu.forward(&input).unwrap();
assert_eq!(output.shape(), &[2, 5, 16]);
}
// Every output element is finite for an all-ones input.
#[test]
fn test_swiglu_forward_values_finite() {
let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
let input = ferrotorch_core::ones::<f32>(&[2, 3, 8]).unwrap();
let output = swiglu.forward(&input).unwrap();
for &v in output.data().unwrap() {
assert!(v.is_finite(), "SwiGLU produced non-finite value: {v}");
}
}
// 1-D inputs are rejected.
#[test]
fn test_swiglu_1d_rejected() {
let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[8]).unwrap();
assert!(swiglu.forward(&input).is_err());
}
// With bias there are 6 parameters, named `<layer>.weight` / `<layer>.bias`.
#[test]
fn test_swiglu_parameters() {
let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
let params = swiglu.parameters();
assert_eq!(params.len(), 6);
let named = swiglu.named_parameters();
let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
assert!(names.contains(&"w1.weight"));
assert!(names.contains(&"w1.bias"));
assert!(names.contains(&"w2.weight"));
assert!(names.contains(&"w2.bias"));
assert!(names.contains(&"w3.weight"));
assert!(names.contains(&"w3.bias"));
}
// Without bias only the 3 weights remain.
#[test]
fn test_swiglu_parameters_no_bias() {
let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
let params = swiglu.parameters();
assert_eq!(params.len(), 3);
}
// A new block is in training mode; `eval` and `train` toggle the flag.
#[test]
fn test_swiglu_train_eval() {
let mut swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
assert!(swiglu.is_training());
swiglu.eval();
assert!(!swiglu.is_training());
swiglu.train();
assert!(swiglu.is_training());
}
// ---------------------------------------------------------------------
// KVCache
// ---------------------------------------------------------------------
// A new cache is empty and reports the configured maximum.
#[test]
fn test_kv_cache_new_empty() {
let cache = KVCache::<f32>::new(1024);
assert!(cache.is_empty());
assert_eq!(cache.seq_len(), 0);
assert_eq!(cache.max_seq_len(), 1024);
}
// The first update returns the inputs unchanged in shape and sets the length.
#[test]
fn test_kv_cache_single_update() {
let mut cache = KVCache::<f32>::new(128);
let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
let (fk, fv) = cache.update(k, v).unwrap();
assert_eq!(fk.shape(), &[1, 2, 3, 4]);
assert_eq!(fv.shape(), &[1, 2, 3, 4]);
assert_eq!(cache.seq_len(), 3);
}
// A second update is appended along axis 2: 3 + 2 = 5 positions.
#[test]
fn test_kv_cache_append() {
let mut cache = KVCache::<f32>::new(128);
let k1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
let v1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
cache.update(k1, v1).unwrap();
assert_eq!(cache.seq_len(), 3);
let k2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
let v2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
let (fk, fv) = cache.update(k2, v2).unwrap();
assert_eq!(fk.shape(), &[1, 2, 5, 4]); assert_eq!(fv.shape(), &[1, 2, 5, 4]);
assert_eq!(cache.seq_len(), 5);
}
// `reset` empties the cache again.
#[test]
fn test_kv_cache_reset() {
let mut cache = KVCache::<f32>::new(128);
let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
cache.update(k, v).unwrap();
assert_eq!(cache.seq_len(), 3);
cache.reset();
assert!(cache.is_empty());
assert_eq!(cache.seq_len(), 0);
}
// 5 positions do not fit into a cache limited to 4.
#[test]
fn test_kv_cache_overflow_rejected() {
let mut cache = KVCache::<f32>::new(4);
let k = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
let v = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
assert!(cache.update(k, v).is_err());
}
// Key and value must have identical shapes (last axis 4 vs 8 here).
#[test]
fn test_kv_cache_shape_mismatch_rejected() {
let mut cache = KVCache::<f32>::new(128);
let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 8]).unwrap(); assert!(cache.update(k, v).is_err());
}
// Appended data follows the existing data: six 1.0 values, then three 2.0 values.
#[test]
fn test_kv_cache_values_preserved() {
let mut cache = KVCache::<f64>::new(128);
let k1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
let v1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
cache.update(k1, v1).unwrap();
let k2_data = vec![2.0f64; 1 * 1 * 1 * 3];
let k2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
let v2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
let (fk, _fv) = cache.update(k2, v2).unwrap();
assert_eq!(fk.shape(), &[1, 1, 3, 3]); let fk_data = fk.data().unwrap();
for &v in &fk_data[..6] {
assert!((v - 1.0).abs() < 1e-10, "expected 1.0, got {v}");
}
for &v in &fk_data[6..9] {
assert!((v - 2.0).abs() < 1e-10, "expected 2.0, got {v}");
}
}
// ---------------------------------------------------------------------
// TransformerEncoderLayer
// ---------------------------------------------------------------------
// Construction with valid arguments succeeds.
#[test]
fn test_encoder_layer_construction() {
let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
assert!(layer.is_ok());
}
// The output has the same `[batch, seq, d_model]` shape as the input.
#[test]
fn test_encoder_layer_forward_shape() {
let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
let output = layer.forward(&input).unwrap();
assert_eq!(output.shape(), &[2, 5, 16]);
}
// Every output element is finite for an all-ones input.
#[test]
fn test_encoder_layer_forward_values_finite() {
let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
let output = layer.forward(&input).unwrap();
for &v in output.data().unwrap() {
assert!(
v.is_finite(),
"TransformerEncoderLayer produced non-finite value: {v}"
);
}
}
// 2-D inputs are rejected.
#[test]
fn test_encoder_layer_2d_rejected() {
let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
assert!(layer.forward(&input).is_err());
}
// With bias the layer is expected to expose 18 parameters in total.
#[test]
fn test_encoder_layer_parameters_count() {
let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
let params = layer.parameters();
assert_eq!(params.len(), 18);
}
// A new layer is in training mode; `eval` and `train` toggle the flag.
#[test]
fn test_encoder_layer_train_eval() {
let mut layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
assert!(layer.is_training());
layer.eval();
assert!(!layer.is_training());
layer.train();
assert!(layer.is_training());
}
// Compile-time check that the layer is `Send + Sync` for both float widths.
#[test]
fn test_encoder_layer_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<TransformerEncoderLayer<f32>>();
assert_send_sync::<TransformerEncoderLayer<f64>>();
}
// ---------------------------------------------------------------------
// TransformerDecoderLayer
// ---------------------------------------------------------------------
// Construction with valid arguments succeeds.
#[test]
fn test_decoder_layer_construction() {
let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
assert!(layer.is_ok());
}
// The output follows the target shape even when the memory is longer.
#[test]
fn test_decoder_layer_forward_shape() {
let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
let output = layer.forward_with_memory(&tgt, &memory).unwrap();
assert_eq!(output.shape(), &[2, 4, 16]);
}
// `Module::forward` (input used as its own memory) keeps the input shape.
#[test]
fn test_decoder_layer_self_forward_shape() {
let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
let output = layer.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 3, 8]);
}
// Every output element is finite for all-ones target and memory.
#[test]
fn test_decoder_layer_forward_values_finite() {
let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
let output = layer.forward_with_memory(&tgt, &mem).unwrap();
for &v in output.data().unwrap() {
assert!(
v.is_finite(),
"TransformerDecoderLayer produced non-finite value: {v}"
);
}
}
// 2-D target and memory are rejected.
#[test]
fn test_decoder_layer_2d_rejected() {
let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
let memory = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
assert!(layer.forward_with_memory(&input, &memory).is_err());
}
// With bias the layer is expected to expose 28 parameters in total.
#[test]
fn test_decoder_layer_parameters_count() {
let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
let params = layer.parameters();
assert_eq!(params.len(), 28);
}
// A new layer is in training mode; `eval` and `train` toggle the flag.
#[test]
fn test_decoder_layer_train_eval() {
let mut layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
assert!(layer.is_training());
layer.eval();
assert!(!layer.is_training());
layer.train();
assert!(layer.is_training());
}
// Compile-time check that the layer is `Send + Sync` for both float widths.
#[test]
fn test_decoder_layer_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<TransformerDecoderLayer<f32>>();
assert_send_sync::<TransformerDecoderLayer<f64>>();
}
}