use ferrotorch_core::grad_fns::activation::silu;
use ferrotorch_core::grad_fns::arithmetic::{add, mul};
use ferrotorch_core::{Float, FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
use crate::attention::MultiheadAttention;
use crate::dropout::Dropout;
use crate::linear::Linear;
use crate::module::Module;
use crate::norm::LayerNorm;
use crate::parameter::Parameter;
/// Rotary position embedding (RoPE) with precomputed cos/sin tables.
///
/// Each adjacent feature pair `(x[2i], x[2i + 1])` of a row at sequence
/// position `pos` is rotated by the angle `pos * theta_i`, where
/// `theta_i = base^(-2i / dim)`. The tables cover positions `0..max_seq_len`
/// and are built once in [`RotaryPositionEmbedding::new`].
#[derive(Debug)]
pub struct RotaryPositionEmbedding<T: Float> {
    /// Size of the rotated (last) dimension; validated to be even and positive.
    dim: usize,
    /// Number of positions covered by the cos/sin tables.
    max_seq_len: usize,
    /// Frequency base from which the per-pair rotation speeds are derived.
    base: f64,
    /// `cos(pos * theta_i)`, shape `[max_seq_len, dim / 2]`.
    cos_cache: Tensor<T>,
    /// `sin(pos * theta_i)`, shape `[max_seq_len, dim / 2]`.
    sin_cache: Tensor<T>,
}

impl<T: Float> RotaryPositionEmbedding<T> {
    /// Builds the cos/sin tables for `max_seq_len` positions of a `dim`-wide
    /// feature axis.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if `dim` is zero or odd,
    /// if `max_seq_len` is zero, or if `base` is not a finite positive number
    /// (such bases produce NaN/inf angles and would silently poison every
    /// output of [`RotaryPositionEmbedding::apply`]).
    pub fn new(dim: usize, max_seq_len: usize, base: f64) -> FerrotorchResult<Self> {
        if dim == 0 || dim % 2 != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("RoPE dim must be even and positive, got {dim}"),
            });
        }
        if max_seq_len == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "RoPE max_seq_len must be positive".into(),
            });
        }
        if !(base.is_finite() && base > 0.0) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("RoPE base must be finite and positive, got {base}"),
            });
        }

        let half_dim = dim / 2;
        // theta_i = base^(-2i / dim), one frequency per feature pair.
        let inv_freqs: Vec<f64> = (0..half_dim)
            .map(|i| 1.0 / base.powf(2.0 * i as f64 / dim as f64))
            .collect();

        let total = max_seq_len * half_dim;
        let mut cos_data = Vec::with_capacity(total);
        let mut sin_data = Vec::with_capacity(total);
        for pos in 0..max_seq_len {
            for &inv_freq in &inv_freqs {
                let angle = pos as f64 * inv_freq;
                cos_data.push(
                    T::from(angle.cos()).expect("cos value in [-1, 1] must be representable in T"),
                );
                sin_data.push(
                    T::from(angle.sin()).expect("sin value in [-1, 1] must be representable in T"),
                );
            }
        }

        let cos_cache = Tensor::from_storage(
            TensorStorage::cpu(cos_data),
            vec![max_seq_len, half_dim],
            false,
        )?;
        let sin_cache = Tensor::from_storage(
            TensorStorage::cpu(sin_data),
            vec![max_seq_len, half_dim],
            false,
        )?;
        Ok(Self {
            dim,
            max_seq_len,
            base,
            cos_cache,
            sin_cache,
        })
    }

    /// Rotates `x` by its sequence positions.
    ///
    /// The last dimension of `x` is the feature axis (must equal `dim`) and
    /// the second-to-last is the sequence axis. All leading dimensions are
    /// flattened and processed independently. Sequence index `s` is rotated
    /// as position `seq_offset + s`, which lets incremental decoding continue
    /// from an earlier position.
    ///
    /// NOTE(review): the result is created with `requires_grad = false`, so it
    /// is presumably detached from the autograd graph of `x`; confirm this is
    /// intended before using `apply` inside a module that is trained.
    ///
    /// # Errors
    ///
    /// - [`FerrotorchError::InvalidArgument`] if `x` has fewer than 2 dims, or
    ///   if `seq_offset + seq_len` exceeds `max_seq_len` (including overflow).
    /// - [`FerrotorchError::ShapeMismatch`] if the last dim of `x` is not `dim`.
    pub fn apply(&self, x: &Tensor<T>, seq_offset: usize) -> FerrotorchResult<Tensor<T>> {
        let shape = x.shape();
        let ndim = shape.len();
        if ndim < 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "RoPE input must be at least 2-D, got {ndim}-D with shape {shape:?}"
                ),
            });
        }
        let last_dim = shape[ndim - 1];
        if last_dim != self.dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "RoPE: last dim of input ({last_dim}) != dim ({})",
                    self.dim
                ),
            });
        }
        let seq_len = shape[ndim - 2];
        // `checked_add` keeps a huge offset from overflowing: a plain `+` would
        // panic in debug builds and wrap past this bounds check in release.
        match seq_offset.checked_add(seq_len) {
            Some(end) if end <= self.max_seq_len => {}
            _ => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "RoPE: seq_offset ({seq_offset}) + seq_len ({seq_len}) > max_seq_len ({})",
                        self.max_seq_len
                    ),
                });
            }
        }

        let half_dim = self.dim / 2;
        let cos_data = self.cos_cache.data()?;
        let sin_data = self.sin_cache.data()?;
        let x_data = x.data()?;
        let total = x.numel();
        let mut output = Vec::with_capacity(total);

        // Rows of `dim` features are contiguous, and because every leading dim
        // is flattened, row `r` sits at sequence index `r % seq_len`. When
        // `seq_len == 0` there are no rows, so the modulo never executes.
        for (row_idx, row) in x_data[..total].chunks_exact(self.dim).enumerate() {
            let table_start = (seq_offset + row_idx % seq_len) * half_dim;
            let cos_row = &cos_data[table_start..table_start + half_dim];
            let sin_row = &sin_data[table_start..table_start + half_dim];
            for ((pair, &cos_val), &sin_val) in row.chunks_exact(2).zip(cos_row).zip(sin_row) {
                let (x_even, x_odd) = (pair[0], pair[1]);
                output.push(x_even * cos_val - x_odd * sin_val);
                output.push(x_even * sin_val + x_odd * cos_val);
            }
        }
        Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)
    }

    /// Width of the rotated feature axis.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of positions covered by the precomputed tables.
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Frequency base the tables were built with.
    #[inline]
    pub fn base(&self) -> f64 {
        self.base
    }
}
/// SwiGLU feed-forward block computing `w3(silu(w1(x)) * w2(x))`.
///
/// `w1` (gate) and `w2` (up) map `in_features -> hidden_features`, and `w3`
/// (down) maps `hidden_features -> in_features`, so the output width equals
/// the input width.
#[derive(Debug)]
pub struct SwiGLU<T: Float> {
    /// Gate projection; its output goes through SiLU.
    w1: Linear<T>,
    /// Up projection; multiplied element-wise with the activated gate.
    w2: Linear<T>,
    /// Down projection back to `in_features`.
    w3: Linear<T>,
    /// Train/eval mode flag.
    training: bool,
}

impl<T: Float> SwiGLU<T> {
    /// Creates the gate, up and down projections; `bias` applies to all three.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`Linear::new`].
    pub fn new(
        in_features: usize,
        hidden_features: usize,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Field initializers are evaluated top to bottom (gate, up, down).
        Ok(Self {
            w1: Linear::new(in_features, hidden_features, bias)?,
            w2: Linear::new(in_features, hidden_features, bias)?,
            w3: Linear::new(hidden_features, in_features, bias)?,
            training: true,
        })
    }

    /// Runs a `[batch, seq, features]` input by flattening it to
    /// `[batch * seq, features]`, applying the 2-D path, and restoring the
    /// leading dimensions.
    ///
    /// NOTE(review): both reshapes copy raw data into new tensors via
    /// `Tensor::from_storage`, and the result is marked `requires_grad = false`.
    /// That presumably detaches the output from the autograd graph, so no
    /// gradients reach `w1`/`w2`/`w3` or `input` for 3-D inputs. Confirm, and
    /// switch to a graph-preserving reshape if `ferrotorch_core` offers one.
    fn forward_3d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let dims = input.shape();
        let (batch, seq_len, features) = (dims[0], dims[1], dims[2]);
        let rows = batch * seq_len;

        let flattened = Tensor::from_storage(
            TensorStorage::cpu(input.data()?.to_vec()),
            vec![rows, features],
            input.requires_grad(),
        )?;
        let projected = self.forward_2d(&flattened)?;
        let out_features = projected.shape()[1];

        Tensor::from_storage(
            TensorStorage::cpu(projected.data()?.to_vec()),
            vec![batch, seq_len, out_features],
            false,
        )
    }

    /// Applies `w3(silu(w1(x)) * w2(x))` to a 2-D `[rows, features]` input.
    fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let activated_gate = silu(&self.w1.forward(input)?)?;
        let up_projection = self.w2.forward(input)?;
        self.w3.forward(&mul(&activated_gate, &up_projection)?)
    }
}
impl<T: Float> Module<T> for SwiGLU<T> {
    /// Dispatches on rank: 2-D inputs go straight through the projections,
    /// 3-D inputs are flattened first; any other rank is rejected with
    /// `FerrotorchError::InvalidArgument`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        if ndim == 2 {
            return self.forward_2d(input);
        }
        if ndim == 3 {
            return self.forward_3d(input);
        }
        Err(FerrotorchError::InvalidArgument {
            message: format!(
                "SwiGLU expects 2-D or 3-D input, got {}-D with shape {:?}",
                ndim,
                input.shape()
            ),
        })
    }

    /// Parameters of `w1`, then `w2`, then `w3`.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut collected = Vec::new();
        for layer in [&self.w1, &self.w2, &self.w3] {
            collected.extend(layer.parameters());
        }
        collected
    }

    /// Mutable parameters of `w1`, then `w2`, then `w3`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let Self { w1, w2, w3, .. } = self;
        let mut collected = w1.parameters_mut();
        collected.extend(w2.parameters_mut());
        collected.extend(w3.parameters_mut());
        collected
    }

    /// Parameters keyed as `w1.<name>`, `w2.<name>`, `w3.<name>`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut named = Vec::new();
        for (prefix, layer) in [("w1", &self.w1), ("w2", &self.w2), ("w3", &self.w3)] {
            named.extend(
                layer
                    .named_parameters()
                    .into_iter()
                    .map(|(name, param)| (format!("{prefix}.{name}"), param)),
            );
        }
        named
    }

    /// Switches this block and all three projections to training mode.
    fn train(&mut self) {
        self.training = true;
        for layer in [&mut self.w1, &mut self.w2, &mut self.w3] {
            layer.train();
        }
    }

    /// Switches this block and all three projections to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        for layer in [&mut self.w1, &mut self.w2, &mut self.w3] {
            layer.eval();
        }
    }

    /// Current train/eval flag of this block.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Growing key/value cache for incremental attention.
///
/// Each [`KVCache::update`] appends new keys/values along dimension 2 of the
/// cached 4-D tensors (layout `[B, heads, seq, dim]`, as named in the
/// validation error), bounded by `max_seq_len`.
#[derive(Debug)]
pub struct KVCache<T: Float> {
    /// Keys accumulated so far; `None` until the first update.
    key_cache: Option<Tensor<T>>,
    /// Values accumulated so far; set and cleared together with `key_cache`.
    value_cache: Option<Tensor<T>>,
    /// Maximum cached sequence length (size of dimension 2).
    max_seq_len: usize,
}

impl<T: Float> KVCache<T> {
    /// Creates an empty cache that can hold up to `max_seq_len` positions.
    pub fn new(max_seq_len: usize) -> Self {
        Self {
            key_cache: None,
            value_cache: None,
            max_seq_len,
        }
    }

    /// Appends `key`/`value` to the cache and returns the full cached
    /// `(keys, values)`.
    ///
    /// # Errors
    ///
    /// - [`FerrotorchError::InvalidArgument`] if either tensor is not 4-D, or
    ///   if the combined sequence length would exceed `max_seq_len`.
    /// - [`FerrotorchError::ShapeMismatch`] if `key` and `value` differ in
    ///   shape, or disagree with the cached tensors on dims 0, 1 or 3.
    ///
    /// The cache is only modified after every check has passed, so on error
    /// it is left unchanged.
    pub fn update(
        &mut self,
        key: Tensor<T>,
        value: Tensor<T>,
    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
        if key.ndim() != 4 || value.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "KVCache expects 4-D [B, heads, seq, dim] tensors, \
                     got key {:?}, value {:?}",
                    key.shape(),
                    value.shape()
                ),
            });
        }
        if key.shape() != value.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "KVCache: key shape {:?} != value shape {:?}",
                    key.shape(),
                    value.shape()
                ),
            });
        }

        let (full_key, full_value) = match (&self.key_cache, &self.value_cache) {
            (Some(cached_key), Some(cached_value)) => (
                concat_along_dim2(cached_key, &key)?,
                concat_along_dim2(cached_value, &value)?,
            ),
            // First update: the tensors are already owned here, so they become
            // the cache directly instead of being cloned and then cloned again.
            _ => (key, value),
        };

        let total_seq = full_key.shape()[2];
        if total_seq > self.max_seq_len {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "KVCache: total sequence length ({total_seq}) exceeds max_seq_len ({})",
                    self.max_seq_len
                ),
            });
        }

        // One handle is kept in the cache, the other is returned to the caller.
        self.key_cache = Some(full_key.clone());
        self.value_cache = Some(full_value.clone());
        Ok((full_key, full_value))
    }

    /// Drops all cached keys and values.
    pub fn reset(&mut self) {
        self.key_cache = None;
        self.value_cache = None;
    }

    /// Number of cached positions (dimension 2 of the key cache), or 0 when empty.
    pub fn seq_len(&self) -> usize {
        self.key_cache.as_ref().map_or(0, |k| k.shape()[2])
    }

    /// `true` until the first successful update (or after [`KVCache::reset`]).
    pub fn is_empty(&self) -> bool {
        self.key_cache.is_none()
    }

    /// Capacity of the cache in sequence positions.
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
}
/// Concatenates two 4-D tensors along dimension 2.
///
/// For each `(batch, head)` slab, the `seq_a * dim` elements of `a` are
/// followed by the `seq_b * dim` elements of `b`. The result is a new tensor
/// with `requires_grad = false`.
///
/// # Errors
///
/// - [`FerrotorchError::InvalidArgument`] if either input is not 4-D. Without
///   this check, lower-rank inputs would panic on `shape[3]`.
/// - [`FerrotorchError::ShapeMismatch`] if the inputs differ on dims 0, 1 or 3.
fn concat_along_dim2<T: Float>(
    a: &Tensor<T>,
    b: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let sa = a.shape();
    let sb = b.shape();
    if sa.len() != 4 || sb.len() != 4 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("concat_along_dim2: expected 4-D tensors, got {sa:?} and {sb:?}"),
        });
    }
    if sa[0] != sb[0] || sa[1] != sb[1] || sa[3] != sb[3] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "concat_along_dim2: shapes {:?} and {:?} must match on dims 0, 1, 3",
                sa, sb
            ),
        });
    }

    let (batch, heads, seq_a, dim) = (sa[0], sa[1], sa[2], sa[3]);
    let seq_b = sb[2];
    let a_data = a.data()?;
    let b_data = b.data()?;

    // Contiguous element counts of one (batch, head) slab in each input.
    let (slab_a, slab_b) = (seq_a * dim, seq_b * dim);
    let mut output = Vec::with_capacity(batch * heads * (slab_a + slab_b));
    for slab in 0..batch * heads {
        output.extend_from_slice(&a_data[slab * slab_a..(slab + 1) * slab_a]);
        output.extend_from_slice(&b_data[slab * slab_b..(slab + 1) * slab_b]);
    }

    Tensor::from_storage(
        TensorStorage::cpu(output),
        vec![batch, heads, seq_a + seq_b, dim],
        false,
    )
}
/// Pre-norm transformer encoder layer.
///
/// Self-attention is followed by a SwiGLU feed-forward network. Each sub-layer
/// is preceded by a `LayerNorm` and wrapped in a dropout-plus-residual
/// connection.
#[derive(Debug)]
pub struct TransformerEncoderLayer<T: Float> {
    /// Multi-head self-attention over the normalized input.
    self_attn: MultiheadAttention<T>,
    /// Position-wise SwiGLU feed-forward network.
    ffn: SwiGLU<T>,
    /// Normalization applied before self-attention.
    norm1: LayerNorm<T>,
    /// Normalization applied before the feed-forward network.
    norm2: LayerNorm<T>,
    /// Dropout shared by both residual branches.
    dropout: Dropout<T>,
    /// Train/eval mode flag.
    training: bool,
}

impl<T: Float> TransformerEncoderLayer<T> {
    /// Builds an encoder layer.
    ///
    /// - `d_model`: model width.
    /// - `num_heads`: number of attention heads.
    /// - `d_ff`: SwiGLU hidden width.
    /// - `dropout_p`: passed to [`Dropout::new`].
    /// - `layer_norm_eps`: passed to both layer norms. Their third constructor
    ///   argument is `true`, presumably enabling the affine parameters.
    /// - `bias`: passed to the attention and feed-forward projections.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a sub-module constructor.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Initializers run in the order written, so construction order is
        // self_attn, ffn, norm1, norm2, dropout.
        Ok(Self {
            self_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            ffn: SwiGLU::new(d_model, d_ff, bias)?,
            norm1: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm2: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            dropout: Dropout::new(dropout_p)?,
            training: true,
        })
    }
}
impl<T: Float> Module<T> for TransformerEncoderLayer<T> {
    /// Runs the layer on a 3-D `[batch, seq, d_model]` input.
    ///
    /// Computes `h = x + dropout(attn(norm1(x)))`, then
    /// `h + dropout(ffn(norm2(h)))`. Returns `FerrotorchError::InvalidArgument`
    /// for inputs of any other rank.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "TransformerEncoderLayer expects 3-D [batch, seq, d_model], got {:?}",
                    input.shape()
                ),
            });
        }

        // Sub-layer 1: pre-norm self-attention with a residual connection.
        let attended = self.self_attn.forward(&self.norm1.forward(input)?)?;
        let hidden = add(input, &self.dropout.forward(&attended)?)?;

        // Sub-layer 2: pre-norm feed-forward with a residual connection.
        let transformed = self.ffn.forward(&self.norm2.forward(&hidden)?)?;
        add(&hidden, &self.dropout.forward(&transformed)?)
    }

    /// Parameters of self_attn, ffn, norm1 and norm2, in that order.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.self_attn
            .parameters()
            .into_iter()
            .chain(self.ffn.parameters())
            .chain(self.norm1.parameters())
            .chain(self.norm2.parameters())
            .collect()
    }

    /// Mutable parameters of self_attn, ffn, norm1 and norm2, in that order.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let Self {
            self_attn,
            ffn,
            norm1,
            norm2,
            ..
        } = self;
        self_attn
            .parameters_mut()
            .into_iter()
            .chain(ffn.parameters_mut())
            .chain(norm1.parameters_mut())
            .chain(norm2.parameters_mut())
            .collect()
    }

    /// Parameters keyed by sub-module prefix (`self_attn.`, `ffn.`, `norm1.`, `norm2.`).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        // Prepends `scope.` to every name produced by a sub-module.
        fn scoped<'a, P>(
            scope: &'static str,
            entries: Vec<(String, &'a P)>,
        ) -> impl Iterator<Item = (String, &'a P)> {
            entries
                .into_iter()
                .map(move |(name, param)| (format!("{scope}.{name}"), param))
        }

        scoped("self_attn", self.self_attn.named_parameters())
            .chain(scoped("ffn", self.ffn.named_parameters()))
            .chain(scoped("norm1", self.norm1.named_parameters()))
            .chain(scoped("norm2", self.norm2.named_parameters()))
            .collect()
    }

    /// Switches the layer and every sub-module (including dropout) to training mode.
    fn train(&mut self) {
        let Self {
            self_attn,
            ffn,
            norm1,
            norm2,
            dropout,
            training,
        } = self;
        *training = true;
        self_attn.train();
        ffn.train();
        norm1.train();
        norm2.train();
        dropout.train();
    }

    /// Switches the layer and every sub-module (including dropout) to evaluation mode.
    fn eval(&mut self) {
        let Self {
            self_attn,
            ffn,
            norm1,
            norm2,
            dropout,
            training,
        } = self;
        *training = false;
        self_attn.eval();
        ffn.eval();
        norm1.eval();
        norm2.eval();
        dropout.eval();
    }

    /// Current train/eval flag of this layer.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Pre-norm transformer decoder layer.
///
/// The layer applies self-attention, then cross-attention over an external
/// `memory`, then a SwiGLU feed-forward network. Each sub-layer is preceded by
/// a `LayerNorm` and wrapped in a dropout-plus-residual connection.
#[derive(Debug)]
pub struct TransformerDecoderLayer<T: Float> {
    /// Self-attention over the decoder sequence.
    self_attn: MultiheadAttention<T>,
    /// Attention from decoder queries to `memory` keys/values.
    cross_attn: MultiheadAttention<T>,
    /// Position-wise SwiGLU feed-forward network.
    ffn: SwiGLU<T>,
    /// Normalization applied before self-attention.
    norm1: LayerNorm<T>,
    /// Normalization applied before cross-attention.
    norm2: LayerNorm<T>,
    /// Normalization applied before the feed-forward network.
    norm3: LayerNorm<T>,
    /// Dropout shared by all three residual branches.
    dropout: Dropout<T>,
    /// Train/eval mode flag.
    training: bool,
}

impl<T: Float> TransformerDecoderLayer<T> {
    /// Builds a decoder layer.
    ///
    /// - `d_model`: model width.
    /// - `num_heads`: number of heads for both attention blocks.
    /// - `d_ff`: SwiGLU hidden width.
    /// - `dropout_p`: passed to [`Dropout::new`].
    /// - `layer_norm_eps`: passed to all three layer norms. Their third
    ///   constructor argument is `true`, presumably enabling the affine
    ///   parameters.
    /// - `bias`: passed to the attention and feed-forward projections.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a sub-module constructor.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Initializers run in the order written, matching the field order.
        Ok(Self {
            self_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            cross_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            ffn: SwiGLU::new(d_model, d_ff, bias)?,
            norm1: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm2: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm3: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            dropout: Dropout::new(dropout_p)?,
            training: true,
        })
    }

    /// Runs the layer on a 3-D decoder `input` attending to a 3-D `memory`.
    ///
    /// `memory` is used unnormalized as both keys and values of the
    /// cross-attention. Returns `FerrotorchError::InvalidArgument` if either
    /// tensor is not 3-D; other shape checks are left to the sub-modules.
    pub fn forward_with_memory(
        &self,
        input: &Tensor<T>,
        memory: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 3 || memory.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "TransformerDecoderLayer expects 3-D inputs, got input {:?}, memory {:?}",
                    input.shape(),
                    memory.shape()
                ),
            });
        }

        // Sub-layer 1: self-attention. The trailing `true` presumably enables
        // the causal mask (confirm against `MultiheadAttention::forward_qkv`).
        let normed = self.norm1.forward(input)?;
        let attended = self
            .self_attn
            .forward_qkv(&normed, &normed, &normed, true)?;
        let hidden = add(input, &self.dropout.forward(&attended)?)?;

        // Sub-layer 2: cross-attention, with queries from the decoder and
        // keys/values from `memory`.
        let normed = self.norm2.forward(&hidden)?;
        let attended = self
            .cross_attn
            .forward_qkv(&normed, memory, memory, false)?;
        let hidden = add(&hidden, &self.dropout.forward(&attended)?)?;

        // Sub-layer 3: feed-forward network.
        let transformed = self.ffn.forward(&self.norm3.forward(&hidden)?)?;
        add(&hidden, &self.dropout.forward(&transformed)?)
    }
}
impl<T: Float> Module<T> for TransformerDecoderLayer<T> {
    /// Runs the layer with `input` doubling as the cross-attention memory.
    ///
    /// NOTE(review): cross-attention is invoked with its flag set to `false`,
    /// presumably meaning no causal mask. With `memory = input`, each position
    /// could then attend to later positions of `input`. Confirm this is
    /// acceptable before relying on `forward` for autoregressive decoding.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let memory = input;
        self.forward_with_memory(input, memory)
    }

    /// Parameters of self_attn, cross_attn, ffn, norm1, norm2 and norm3, in that order.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.self_attn
            .parameters()
            .into_iter()
            .chain(self.cross_attn.parameters())
            .chain(self.ffn.parameters())
            .chain(self.norm1.parameters())
            .chain(self.norm2.parameters())
            .chain(self.norm3.parameters())
            .collect()
    }

    /// Mutable parameters in the same order as [`Module::parameters`].
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let Self {
            self_attn,
            cross_attn,
            ffn,
            norm1,
            norm2,
            norm3,
            ..
        } = self;
        self_attn
            .parameters_mut()
            .into_iter()
            .chain(cross_attn.parameters_mut())
            .chain(ffn.parameters_mut())
            .chain(norm1.parameters_mut())
            .chain(norm2.parameters_mut())
            .chain(norm3.parameters_mut())
            .collect()
    }

    /// Parameters keyed by sub-module prefix (`self_attn.`, `cross_attn.`, `ffn.`, `norm1.`–`norm3.`).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        // Prepends `scope.` to every name produced by a sub-module.
        fn scoped<'a, P>(
            scope: &'static str,
            entries: Vec<(String, &'a P)>,
        ) -> impl Iterator<Item = (String, &'a P)> {
            entries
                .into_iter()
                .map(move |(name, param)| (format!("{scope}.{name}"), param))
        }

        scoped("self_attn", self.self_attn.named_parameters())
            .chain(scoped("cross_attn", self.cross_attn.named_parameters()))
            .chain(scoped("ffn", self.ffn.named_parameters()))
            .chain(scoped("norm1", self.norm1.named_parameters()))
            .chain(scoped("norm2", self.norm2.named_parameters()))
            .chain(scoped("norm3", self.norm3.named_parameters()))
            .collect()
    }

    /// Switches the layer and every sub-module (including dropout) to training mode.
    fn train(&mut self) {
        let Self {
            self_attn,
            cross_attn,
            ffn,
            norm1,
            norm2,
            norm3,
            dropout,
            training,
        } = self;
        *training = true;
        self_attn.train();
        cross_attn.train();
        ffn.train();
        norm1.train();
        norm2.train();
        norm3.train();
        dropout.train();
    }

    /// Switches the layer and every sub-module (including dropout) to evaluation mode.
    fn eval(&mut self) {
        let Self {
            self_attn,
            cross_attn,
            ffn,
            norm1,
            norm2,
            norm3,
            dropout,
            training,
        } = self;
        *training = false;
        self_attn.eval();
        cross_attn.eval();
        ffn.eval();
        norm1.eval();
        norm2.eval();
        norm3.eval();
        dropout.eval();
    }

    /// Current train/eval flag of this layer.
    fn is_training(&self) -> bool {
        self.training
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- RotaryPositionEmbedding -------------------------------------------

    #[test]
    fn test_rope_construction() {
        let rope = RotaryPositionEmbedding::<f32>::new(64, 512, 10000.0);
        assert!(rope.is_ok());
        let rope = rope.unwrap();
        assert_eq!(rope.dim(), 64);
        assert_eq!(rope.max_seq_len(), 512);
        assert_eq!(rope.base(), 10000.0);
    }

    #[test]
    fn test_rope_odd_dim_rejected() {
        assert!(RotaryPositionEmbedding::<f32>::new(63, 512, 10000.0).is_err());
    }

    #[test]
    fn test_rope_zero_dim_rejected() {
        assert!(RotaryPositionEmbedding::<f32>::new(0, 512, 10000.0).is_err());
    }

    #[test]
    fn test_rope_zero_seq_rejected() {
        assert!(RotaryPositionEmbedding::<f32>::new(64, 0, 10000.0).is_err());
    }

    #[test]
    fn test_rope_output_shape_2d() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[4, 8]);
    }

    #[test]
    fn test_rope_output_shape_3d() {
        let rope = RotaryPositionEmbedding::<f32>::new(16, 256, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[2, 10, 16]);
    }

    #[test]
    fn test_rope_output_shape_4d() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 4, 6, 8]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[2, 4, 6, 8]);
    }

    #[test]
    fn test_rope_with_offset() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[4, 8]).unwrap();
        let y = rope.apply(&x, 10).unwrap();
        assert_eq!(y.shape(), &[4, 8]);
    }

    #[test]
    fn test_rope_offset_overflow_rejected() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 16, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[10, 8]).unwrap();
        assert!(rope.apply(&x, 10).is_err());
    }

    #[test]
    fn test_rope_position_zero_is_identity() {
        let rope = RotaryPositionEmbedding::<f64>::new(4, 64, 10000.0).unwrap();
        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        let y_data = y.data().unwrap();
        let x_data = x.data().unwrap();
        for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
            assert!(
                (xv - yv).abs() < 1e-10,
                "position 0 should be identity, index {i}: x={xv}, y={yv}"
            );
        }
    }

    #[test]
    fn test_rope_rotation_matches_closed_form() {
        // dim = 2 gives the single frequency theta_0 = base^0 = 1, so the row
        // at position 1 is rotated by exactly 1 radian.
        let rope = RotaryPositionEmbedding::<f64>::new(2, 8, 10000.0).unwrap();
        let x = ferrotorch_core::from_slice(&[1.0, 0.0], &[1, 2]).unwrap();
        let y = rope.apply(&x, 1).unwrap();
        let y_data = y.data().unwrap();
        assert!((y_data[0] - 1.0f64.cos()).abs() < 1e-12);
        assert!((y_data[1] - 1.0f64.sin()).abs() < 1e-12);
    }

    #[test]
    fn test_rope_preserves_pair_norm() {
        // A rotation never changes the length of each (even, odd) pair.
        let rope = RotaryPositionEmbedding::<f64>::new(4, 64, 10000.0).unwrap();
        let x = ferrotorch_core::from_slice(&[3.0, 4.0, 6.0, 8.0], &[1, 4]).unwrap();
        let y = rope.apply(&x, 7).unwrap();
        let y_data = y.data().unwrap();
        let norm0 = (y_data[0] * y_data[0] + y_data[1] * y_data[1]).sqrt();
        let norm1 = (y_data[2] * y_data[2] + y_data[3] * y_data[3]).sqrt();
        assert!((norm0 - 5.0).abs() < 1e-10, "pair 0 norm changed: {norm0}");
        assert!((norm1 - 10.0).abs() < 1e-10, "pair 1 norm changed: {norm1}");
    }

    #[test]
    fn test_rope_values_are_finite() {
        let rope = RotaryPositionEmbedding::<f32>::new(16, 512, 10000.0).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[2, 4, 10, 16]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "RoPE produced non-finite value: {v}");
        }
    }

    #[test]
    fn test_rope_wrong_dim_rejected() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[4, 10]).unwrap();
        assert!(rope.apply(&x, 0).is_err());
    }

    // --- SwiGLU -------------------------------------------------------------

    #[test]
    fn test_swiglu_construction() {
        let swiglu = SwiGLU::<f32>::new(64, 128, true);
        assert!(swiglu.is_ok());
    }

    #[test]
    fn test_swiglu_forward_shape_2d() {
        let swiglu = SwiGLU::<f32>::new(16, 32, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[4, 16]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        assert_eq!(output.shape(), &[4, 16]);
    }

    #[test]
    fn test_swiglu_forward_shape_3d() {
        let swiglu = SwiGLU::<f32>::new(16, 32, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_swiglu_forward_values_finite() {
        let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[2, 3, 8]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        for &v in output.data().unwrap() {
            assert!(v.is_finite(), "SwiGLU produced non-finite value: {v}");
        }
    }

    #[test]
    fn test_swiglu_zero_input_without_bias_is_zero() {
        // silu(0) = 0, so without bias every stage of the 3-D path stays zero.
        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 3, 8]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        for &v in output.data().unwrap() {
            assert_eq!(v, 0.0, "expected exact zero, got {v}");
        }
    }

    #[test]
    fn test_swiglu_1d_rejected() {
        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[8]).unwrap();
        assert!(swiglu.forward(&input).is_err());
    }

    #[test]
    fn test_swiglu_parameters() {
        let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
        let params = swiglu.parameters();
        assert_eq!(params.len(), 6);
        let named = swiglu.named_parameters();
        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
        for expected in [
            "w1.weight",
            "w1.bias",
            "w2.weight",
            "w2.bias",
            "w3.weight",
            "w3.bias",
        ] {
            assert!(names.contains(&expected), "missing parameter {expected}");
        }
    }

    #[test]
    fn test_swiglu_parameters_no_bias() {
        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        let params = swiglu.parameters();
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_swiglu_train_eval() {
        let mut swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        assert!(swiglu.is_training());
        swiglu.eval();
        assert!(!swiglu.is_training());
        swiglu.train();
        assert!(swiglu.is_training());
    }

    // --- KVCache ------------------------------------------------------------

    #[test]
    fn test_kv_cache_new_empty() {
        let cache = KVCache::<f32>::new(1024);
        assert!(cache.is_empty());
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.max_seq_len(), 1024);
    }

    #[test]
    fn test_kv_cache_single_update() {
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let (fk, fv) = cache.update(k, v).unwrap();
        assert_eq!(fk.shape(), &[1, 2, 3, 4]);
        assert_eq!(fv.shape(), &[1, 2, 3, 4]);
        assert_eq!(cache.seq_len(), 3);
    }

    #[test]
    fn test_kv_cache_append() {
        let mut cache = KVCache::<f32>::new(128);
        let k1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        cache.update(k1, v1).unwrap();
        assert_eq!(cache.seq_len(), 3);
        let k2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
        let v2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
        let (fk, fv) = cache.update(k2, v2).unwrap();
        assert_eq!(fk.shape(), &[1, 2, 5, 4]);
        assert_eq!(fv.shape(), &[1, 2, 5, 4]);
        assert_eq!(cache.seq_len(), 5);
    }

    #[test]
    fn test_kv_cache_reset() {
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        cache.update(k, v).unwrap();
        assert_eq!(cache.seq_len(), 3);
        cache.reset();
        assert!(cache.is_empty());
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn test_kv_cache_overflow_rejected() {
        let mut cache = KVCache::<f32>::new(4);
        let k = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_overflow_leaves_cache_unchanged() {
        let mut cache = KVCache::<f32>::new(4);
        let k1 = ferrotorch_core::ones::<f32>(&[1, 1, 3, 2]).unwrap();
        let v1 = ferrotorch_core::ones::<f32>(&[1, 1, 3, 2]).unwrap();
        cache.update(k1, v1).unwrap();
        let k2 = ferrotorch_core::ones::<f32>(&[1, 1, 2, 2]).unwrap();
        let v2 = ferrotorch_core::ones::<f32>(&[1, 1, 2, 2]).unwrap();
        assert!(cache.update(k2, v2).is_err());
        assert_eq!(cache.seq_len(), 3);
    }

    #[test]
    fn test_kv_cache_shape_mismatch_rejected() {
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 8]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_head_mismatch_with_cache_rejected() {
        let mut cache = KVCache::<f32>::new(128);
        let k1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        cache.update(k1, v1).unwrap();
        let k2 = ferrotorch_core::ones::<f32>(&[1, 3, 1, 4]).unwrap();
        let v2 = ferrotorch_core::ones::<f32>(&[1, 3, 1, 4]).unwrap();
        assert!(cache.update(k2, v2).is_err());
        assert_eq!(cache.seq_len(), 3);
    }

    #[test]
    fn test_kv_cache_non_4d_rejected() {
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[2, 3, 4]).unwrap();
        assert!(cache.update(k, v).is_err());
        assert!(cache.is_empty());
    }

    #[test]
    fn test_kv_cache_values_preserved() {
        let mut cache = KVCache::<f64>::new(128);
        let k1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
        let v1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
        cache.update(k1, v1).unwrap();
        let k2_data = vec![2.0f64; 3];
        let k2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
        let v2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
        let (fk, _fv) = cache.update(k2, v2).unwrap();
        assert_eq!(fk.shape(), &[1, 1, 3, 3]);
        let fk_data = fk.data().unwrap();
        for &v in &fk_data[..6] {
            assert!((v - 1.0).abs() < 1e-10, "expected 1.0, got {v}");
        }
        for &v in &fk_data[6..9] {
            assert!((v - 2.0).abs() < 1e-10, "expected 2.0, got {v}");
        }
    }

    // --- TransformerEncoderLayer --------------------------------------------

    #[test]
    fn test_encoder_layer_construction() {
        let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_encoder_layer_forward_shape() {
        let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_encoder_layer_forward_values_finite() {
        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let output = layer.forward(&input).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "TransformerEncoderLayer produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_encoder_layer_2d_rejected() {
        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        assert!(layer.forward(&input).is_err());
    }

    #[test]
    fn test_encoder_layer_parameters_count() {
        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let params = layer.parameters();
        assert_eq!(params.len(), 18);
    }

    #[test]
    fn test_encoder_layer_train_eval() {
        let mut layer =
            TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
        assert!(layer.is_training());
        layer.eval();
        assert!(!layer.is_training());
        layer.train();
        assert!(layer.is_training());
    }

    #[test]
    fn test_encoder_layer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransformerEncoderLayer<f32>>();
        assert_send_sync::<TransformerEncoderLayer<f64>>();
    }

    // --- TransformerDecoderLayer --------------------------------------------

    #[test]
    fn test_decoder_layer_construction() {
        let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_decoder_layer_forward_shape() {
        let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
        let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
        let output = layer.forward_with_memory(&tgt, &memory).unwrap();
        assert_eq!(output.shape(), &[2, 4, 16]);
    }

    #[test]
    fn test_decoder_layer_self_forward_shape() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_decoder_layer_forward_values_finite() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
        let output = layer.forward_with_memory(&tgt, &mem).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "TransformerDecoderLayer produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_decoder_layer_2d_rejected() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        assert!(layer.forward_with_memory(&input, &memory).is_err());
    }

    #[test]
    fn test_decoder_layer_parameters_count() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let params = layer.parameters();
        assert_eq!(params.len(), 28);
    }

    #[test]
    fn test_decoder_layer_train_eval() {
        let mut layer =
            TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
        assert!(layer.is_training());
        layer.eval();
        assert!(!layer.is_training());
        layer.train();
        assert!(layer.is_training());
    }

    #[test]
    fn test_decoder_layer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransformerDecoderLayer<f32>>();
        assert_send_sync::<TransformerDecoderLayer<f64>>();
    }
}