use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::grad_fns::activation::silu;
use ferrotorch_core::grad_fns::arithmetic::{add, mul};
use ferrotorch_core::grad_fns::shape::reshape;
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::attention::MultiheadAttention;
use crate::dropout::Dropout;
use crate::linear::Linear;
use crate::module::Module;
use crate::norm::LayerNorm;
use crate::parameter::Parameter;
/// Channel-pairing layout used by a rotary position embedding.
///
/// * [`RoPEConvention::Interleaved`] (the default) rotates adjacent pairs
///   `(x[2i], x[2i + 1])`.
/// * [`RoPEConvention::HalfRotation`] rotates `(x[i], x[i + dim / 2])`, i.e.
///   the first half of the feature vector against the second half.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RoPEConvention {
    /// Rotate adjacent channels `2i` and `2i + 1` together.
    #[default]
    Interleaved,
    /// Rotate channel `i` together with channel `i + dim / 2`.
    HalfRotation,
}

impl std::fmt::Display for RoPEConvention {
    /// Writes the lowercase identifier of the convention.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Interleaved => "interleaved",
            Self::HalfRotation => "half_rotation",
        };
        f.write_str(label)
    }
}
/// Autograd node recorded by [`RotaryPositionEmbedding::apply`].
///
/// The forward pass rotates every channel pair by `[c, -s; s, c]`. The
/// vector-Jacobian product is therefore the transposed rotation
/// `[c, s; -s, c]` applied to the upstream gradient.
#[derive(Debug)]
struct RoPEBackward<T: Float> {
    /// Tensor that was rotated; its shape and device are reused for the gradient.
    input: Tensor<T>,
    /// Row-major cosine table with `half_dim` columns, indexed from `seq_offset`.
    cos_flat: Vec<T>,
    /// Row-major sine table with the same layout as `cos_flat`.
    sin_flat: Vec<T>,
    half_dim: usize,
    seq_len: usize,
    /// Product of all leading (batch-like) dimensions of `input`.
    batch_dims: usize,
    dim: usize,
    /// Table row corresponding to sequence element 0.
    seq_offset: usize,
    convention: RoPEConvention,
}

impl<T: Float> GradFn<T> for RoPEBackward<T> {
    /// Rotates each upstream gradient pair by the inverse angle.
    ///
    /// Yields `[None]` when `input` does not require a gradient. The gradient
    /// is assembled on the CPU and moved to `input`'s device if that is CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let upstream = grad_output.data_vec()?;
        let half = self.half_dim;
        let mut grad = Vec::with_capacity(upstream.len());

        // A "row" is one feature vector; row = b * seq_len + s.
        for row in 0..self.batch_dims * self.seq_len {
            let base = row * self.dim;
            let table = (self.seq_offset + row % self.seq_len) * half;
            let cos = &self.cos_flat[table..table + half];
            let sin = &self.sin_flat[table..table + half];

            match self.convention {
                RoPEConvention::Interleaved => {
                    for (k, (&c, &sn)) in cos.iter().zip(sin).enumerate() {
                        let g_even = upstream[base + 2 * k];
                        let g_odd = upstream[base + 2 * k + 1];
                        grad.push(g_even * c + g_odd * sn);
                        grad.push(-g_even * sn + g_odd * c);
                    }
                }
                RoPEConvention::HalfRotation => {
                    let lo = &upstream[base..base + half];
                    let hi = &upstream[base + half..base + 2 * half];
                    grad.extend(
                        lo.iter()
                            .zip(hi)
                            .zip(cos.iter().zip(sin))
                            .map(|((&g1, &g2), (&c, &sn))| g1 * c + g2 * sn),
                    );
                    grad.extend(
                        lo.iter()
                            .zip(hi)
                            .zip(cos.iter().zip(sin))
                            .map(|((&g1, &g2), (&c, &sn))| -g1 * sn + g2 * c),
                    );
                }
            }
        }

        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(grad),
            self.input.shape().to_vec(),
            false,
        )?;
        let grad_input = if self.input.is_cuda() {
            grad_input.to(self.input.device())?
        } else {
            grad_input
        };
        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "RoPEBackward"
    }
}
/// Rotary position embedding (RoPE) backed by precomputed cosine/sine tables.
///
/// The tables hold one row per position in `0..max_seq_len` and one column
/// per rotated channel pair (`dim / 2` columns).
#[derive(Debug)]
pub struct RotaryPositionEmbedding<T: Float> {
    /// Feature size rotated by this embedding (validated even and positive).
    dim: usize,
    /// Number of positions covered by the cos/sin tables.
    max_seq_len: usize,
    /// Base of the geometric inverse-frequency progression.
    base: f64,
    /// Channel-pairing layout used when rotating inputs.
    convention: RoPEConvention,
    /// Context-extension scheme applied to the inverse frequencies.
    scaling: RoPEScaling,
    /// `[max_seq_len, dim / 2]` table of `cos(pos * inv_freq[i])`.
    cos_cache: Tensor<T>,
    /// `[max_seq_len, dim / 2]` table of `sin(pos * inv_freq[i])`.
    sin_cache: Tensor<T>,
}
/// Context-length extension schemes for [`RotaryPositionEmbedding`].
///
/// Every variant rescales the per-pair inverse frequencies
/// `1 / base^(2i / dim)` before the cos/sin tables are built; see
/// `compute_scaled_inv_freq`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RoPEScaling {
    /// Classical RoPE: frequencies are used unchanged (the default).
    #[default]
    None,
    /// Position interpolation: every inverse frequency is divided by `factor`.
    Linear {
        factor: f64,
    },
    /// NTK-aware scaling: the base becomes `base * factor^(dim / (dim - 2))`.
    ///
    /// `original_max_pos_embeddings` is stored but not read by the frequency
    /// computation.
    NtkAware {
        factor: f64,
        original_max_pos_embeddings: usize,
    },
    /// YaRN: pairs that complete at least `beta_fast` rotations over
    /// `original_max_pos_embeddings` positions keep their frequency. Pairs with
    /// at most `beta_slow` rotations are divided by `factor`. Pairs in between
    /// are blended with a linear ramp. Boundaries are rounded to whole pair
    /// indices.
    ///
    /// Only the frequencies are adjusted here; no attention-temperature factor
    /// is applied to the cos/sin tables.
    Yarn {
        factor: f64,
        original_max_pos_embeddings: usize,
        beta_fast: f64,
        beta_slow: f64,
    },
}

impl RoPEScaling {
    /// YaRN scaling with `beta_fast = 32.0` and `beta_slow = 1.0`.
    pub const fn yarn_default(factor: f64, original_max_pos_embeddings: usize) -> Self {
        RoPEScaling::Yarn {
            factor,
            original_max_pos_embeddings,
            beta_fast: 32.0,
            beta_slow: 1.0,
        }
    }
}

/// Returns the (fractional) frequency-pair index `i` whose inverse frequency
/// `1 / base^(2i / dim)` completes `num_rotations` full turns over
/// `original_max_pos_embeddings` positions.
///
/// The result is in the same index space as the `dim / 2` frequencies.
fn yarn_find_correction_dim(
    num_rotations: f64,
    dim: usize,
    base: f64,
    original_max_pos_embeddings: usize,
) -> f64 {
    (dim as f64
        * (original_max_pos_embeddings as f64 / (num_rotations * 2.0 * std::f64::consts::PI))
            .ln())
        / (2.0 * base.ln())
}

/// Pair-index bounds `(low, high)` of YaRN's blending ramp.
///
/// `low` is floored from `low_rot` rotations and `high` is ceiled from
/// `high_rot` rotations. Both are clamped to `[0, dim - 1]`.
fn yarn_find_correction_range(
    low_rot: f64,
    high_rot: f64,
    dim: usize,
    base: f64,
    original_max_pos_embeddings: usize,
) -> (f64, f64) {
    let low = yarn_find_correction_dim(low_rot, dim, base, original_max_pos_embeddings).floor();
    let high = yarn_find_correction_dim(high_rot, dim, base, original_max_pos_embeddings).ceil();
    // `saturating_sub` keeps a zero `dim` from underflowing.
    (low.max(0.0), high.min(dim.saturating_sub(1) as f64))
}

/// Unscaled inverse frequencies `1 / base^(2i / dim)` for `i in 0..dim / 2`.
fn compute_base_inv_freq(dim: usize, base: f64) -> Vec<f64> {
    let half = dim / 2;
    (0..half)
        .map(|i| 1.0 / base.powf(2.0 * i as f64 / dim as f64))
        .collect()
}

/// Inverse frequencies for the `dim / 2` rotated pairs under `scaling`.
pub(crate) fn compute_scaled_inv_freq(dim: usize, base: f64, scaling: RoPEScaling) -> Vec<f64> {
    match scaling {
        RoPEScaling::None => compute_base_inv_freq(dim, base),
        RoPEScaling::Linear { factor } => compute_base_inv_freq(dim, base)
            .into_iter()
            .map(|v| v / factor)
            .collect(),
        RoPEScaling::NtkAware { factor, .. } => {
            let exp = dim as f64 / (dim as f64 - 2.0);
            let base_scaled = base * factor.powf(exp);
            compute_base_inv_freq(dim, base_scaled)
        }
        RoPEScaling::Yarn {
            factor,
            original_max_pos_embeddings,
            beta_fast,
            beta_slow,
        } => {
            // `low`/`high` are already pair indices (the same space as the
            // `dim / 2` frequencies below), so the ramp uses them directly.
            let (low, high) = yarn_find_correction_range(
                beta_fast,
                beta_slow,
                dim,
                base,
                original_max_pos_embeddings,
            );
            // Avoid a zero-width ramp (division by zero).
            let denom = if high == low { 0.001 } else { high - low };
            (0..dim / 2)
                .map(|i| {
                    let pos_freq = base.powf(2.0 * i as f64 / dim as f64);
                    let extrapolation = 1.0 / pos_freq;
                    let interpolation = 1.0 / (factor * pos_freq);
                    // ramp = 0 keeps the original frequency; ramp = 1 fully interpolates.
                    let ramp = ((i as f64 - low) / denom).clamp(0.0, 1.0);
                    let keep = 1.0 - ramp;
                    interpolation * (1.0 - keep) + extrapolation * keep
                })
                .collect()
        }
    }
}
impl<T: Float> RotaryPositionEmbedding<T> {
    /// Classical RoPE with the default (interleaved) convention and no scaling.
    ///
    /// # Errors
    /// See [`with_scaling`](Self::with_scaling).
    pub fn new(dim: usize, max_seq_len: usize, base: f64) -> FerrotorchResult<Self> {
        Self::with_scaling(
            dim,
            max_seq_len,
            base,
            RoPEConvention::default(),
            RoPEScaling::None,
        )
    }

    /// RoPE with an explicit channel-pairing convention and no scaling.
    ///
    /// # Errors
    /// See [`with_scaling`](Self::with_scaling).
    pub fn with_convention(
        dim: usize,
        max_seq_len: usize,
        base: f64,
        convention: RoPEConvention,
    ) -> FerrotorchResult<Self> {
        Self::with_scaling(dim, max_seq_len, base, convention, RoPEScaling::None)
    }

    /// Fully configured constructor: builds `[max_seq_len, dim / 2]` cos/sin
    /// tables from the (possibly scaled) inverse frequencies.
    ///
    /// # Errors
    /// `InvalidArgument` in any of these cases:
    /// * `dim` is zero or odd.
    /// * `max_seq_len` is zero.
    /// * `base` is not finite and positive.
    /// * A scaling `factor` is not finite and positive.
    /// * YaRN is given `original_max_pos_embeddings == 0` or non-positive betas.
    pub fn with_scaling(
        dim: usize,
        max_seq_len: usize,
        base: f64,
        convention: RoPEConvention,
        scaling: RoPEScaling,
    ) -> FerrotorchResult<Self> {
        if dim == 0 || dim % 2 != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("RoPE dim must be even and positive, got {dim}"),
            });
        }
        if max_seq_len == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "RoPE max_seq_len must be positive".into(),
            });
        }
        // A non-positive or non-finite base makes `base.powf(..)` yield NaN/inf
        // frequencies, which would silently poison every table entry.
        if !(base.is_finite() && base > 0.0) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("RoPE base must be finite and > 0, got {base}"),
            });
        }
        if let RoPEScaling::Linear { factor }
        | RoPEScaling::NtkAware { factor, .. }
        | RoPEScaling::Yarn { factor, .. } = scaling
        {
            if !(factor.is_finite() && factor > 0.0) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("RoPE scaling factor must be finite and > 0, got {factor}"),
                });
            }
        }
        // YaRN takes logarithms of these values; non-positive inputs give NaN.
        if let RoPEScaling::Yarn {
            original_max_pos_embeddings,
            beta_fast,
            beta_slow,
            ..
        } = scaling
        {
            let betas_ok = beta_fast.is_finite()
                && beta_fast > 0.0
                && beta_slow.is_finite()
                && beta_slow > 0.0;
            if original_max_pos_embeddings == 0 || !betas_ok {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "RoPE YaRN needs original_max_pos_embeddings > 0 and finite, positive \
                         beta_fast/beta_slow, got {original_max_pos_embeddings}, \
                         {beta_fast}, {beta_slow}"
                    ),
                });
            }
        }

        let half_dim = dim / 2;
        let thetas = compute_scaled_inv_freq(dim, base, scaling);
        let total = max_seq_len * half_dim;
        let mut cos_data = Vec::with_capacity(total);
        let mut sin_data = Vec::with_capacity(total);
        for pos in 0..max_seq_len {
            for &theta in &thetas {
                let angle = pos as f64 * theta;
                // cos/sin lie in [-1, 1]; the conversion to `T` is expected to succeed.
                cos_data.push(T::from(angle.cos()).unwrap());
                sin_data.push(T::from(angle.sin()).unwrap());
            }
        }
        let cos_cache = Tensor::from_storage(
            TensorStorage::cpu(cos_data),
            vec![max_seq_len, half_dim],
            false,
        )?;
        let sin_cache = Tensor::from_storage(
            TensorStorage::cpu(sin_data),
            vec![max_seq_len, half_dim],
            false,
        )?;
        Ok(Self {
            dim,
            max_seq_len,
            base,
            convention,
            scaling,
            cos_cache,
            sin_cache,
        })
    }

    /// Rotates `x` by the position-dependent angles of this embedding.
    ///
    /// The last axis of `x` is the feature axis (must equal `dim`). The
    /// second-to-last axis is the sequence axis. All leading axes are treated
    /// as batch. Sequence element `s` is rotated as position `seq_offset + s`.
    ///
    /// The rotation is computed on the CPU and moved back to `x`'s device when
    /// that is CUDA. An autograd node is recorded when gradients are enabled
    /// and `x` requires grad.
    ///
    /// # Errors
    /// * `InvalidArgument` if `x` is less than 2-D, or if
    ///   `seq_offset + seq_len` exceeds `max_seq_len` (including on overflow).
    /// * `ShapeMismatch` if the last dimension differs from `dim`.
    pub fn apply(&self, x: &Tensor<T>, seq_offset: usize) -> FerrotorchResult<Tensor<T>> {
        let shape = x.shape();
        let ndim = shape.len();
        if ndim < 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "RoPE input must be at least 2-D, got {ndim}-D with shape {shape:?}"
                ),
            });
        }
        let last_dim = shape[ndim - 1];
        if last_dim != self.dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!("RoPE: last dim of input ({last_dim}) != dim ({})", self.dim),
            });
        }
        let seq_len = shape[ndim - 2];
        // `checked_add` turns an overflowing offset into an error instead of a
        // debug-build panic or a release-build wrap-around.
        let end = match seq_offset.checked_add(seq_len) {
            Some(end) if end <= self.max_seq_len => end,
            _ => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "RoPE: seq_offset ({seq_offset}) + seq_len ({seq_len}) > max_seq_len ({})",
                        self.max_seq_len
                    ),
                });
            }
        };

        let device = x.device();
        let half_dim = self.dim / 2;
        // Copy only the table rows this call needs. The autograd node keeps
        // them alive, so copying the whole cache would retain
        // `max_seq_len * dim` values per call.
        let window = seq_offset * half_dim..end * half_dim;
        let cos_data = self.cos_cache.data()?[window.clone()].to_vec();
        let sin_data = self.sin_cache.data()?[window].to_vec();
        let x_data = x.data_vec()?;
        let batch_dims: usize = shape[..ndim - 2].iter().product();
        let mut output = Vec::with_capacity(x.numel());

        match self.convention {
            RoPEConvention::Interleaved => {
                for b in 0..batch_dims {
                    for s in 0..seq_len {
                        let cache_start = s * half_dim;
                        let x_start = (b * seq_len + s) * self.dim;
                        for i in 0..half_dim {
                            let x_even = x_data[x_start + 2 * i];
                            let x_odd = x_data[x_start + 2 * i + 1];
                            let cos_val = cos_data[cache_start + i];
                            let sin_val = sin_data[cache_start + i];
                            output.push(x_even * cos_val - x_odd * sin_val);
                            output.push(x_even * sin_val + x_odd * cos_val);
                        }
                    }
                }
            }
            RoPEConvention::HalfRotation => {
                for b in 0..batch_dims {
                    for s in 0..seq_len {
                        let cache_start = s * half_dim;
                        let x_start = (b * seq_len + s) * self.dim;
                        for i in 0..half_dim {
                            let x_first = x_data[x_start + i];
                            let x_second = x_data[x_start + half_dim + i];
                            let cos_val = cos_data[cache_start + i];
                            let sin_val = sin_data[cache_start + i];
                            output.push(x_first * cos_val - x_second * sin_val);
                        }
                        for i in 0..half_dim {
                            let x_first = x_data[x_start + i];
                            let x_second = x_data[x_start + half_dim + i];
                            let cos_val = cos_data[cache_start + i];
                            let sin_val = sin_data[cache_start + i];
                            output.push(x_first * sin_val + x_second * cos_val);
                        }
                    }
                }
            }
        }

        let result = if is_grad_enabled() && x.requires_grad() {
            Tensor::from_operation(
                TensorStorage::cpu(output),
                shape.to_vec(),
                Arc::new(RoPEBackward {
                    input: x.clone(),
                    cos_flat: cos_data,
                    sin_flat: sin_data,
                    half_dim,
                    seq_len,
                    batch_dims,
                    dim: self.dim,
                    // The tables above already start at `seq_offset`.
                    seq_offset: 0,
                    convention: self.convention,
                }),
            )?
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?
        };
        if device.is_cuda() {
            result.to(device)
        } else {
            Ok(result)
        }
    }

    /// Feature size rotated by this embedding.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of positions covered by the precomputed tables.
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Base of the inverse-frequency progression.
    #[inline]
    pub fn base(&self) -> f64 {
        self.base
    }

    /// Channel-pairing convention.
    #[inline]
    pub fn convention(&self) -> RoPEConvention {
        self.convention
    }

    /// Frequency scaling scheme.
    #[inline]
    pub fn scaling(&self) -> RoPEScaling {
        self.scaling
    }
}
/// Gated feed-forward block: `w3(silu(w1(x)) * w2(x))`.
#[derive(Debug)]
pub struct SwiGLU<T: Float> {
    /// Gate projection `in -> hidden`; its output goes through SiLU.
    w1: Linear<T>,
    /// Value projection `in -> hidden`, multiplied with the gate.
    w2: Linear<T>,
    /// Output projection `hidden -> in`.
    w3: Linear<T>,
    training: bool,
}

impl<T: Float> SwiGLU<T> {
    /// Creates the three projections; `bias` is forwarded to each of them.
    pub fn new(in_features: usize, hidden_features: usize, bias: bool) -> FerrotorchResult<Self> {
        // Struct-literal fields are evaluated in the order written, so the
        // projections are initialised w1, w2, w3.
        Ok(Self {
            w1: Linear::new(in_features, hidden_features, bias)?,
            w2: Linear::new(in_features, hidden_features, bias)?,
            w3: Linear::new(hidden_features, in_features, bias)?,
            training: true,
        })
    }

    /// Flattens `[batch, seq, features]` to rows, runs the 2-D path, and
    /// restores the leading two dimensions.
    fn forward_3d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let dims = input.shape();
        let (batch, seq_len) = (dims[0], dims[1]);
        let rows = reshape(input, &[(batch * seq_len) as isize, -1])?;
        let projected = self.forward_2d(&rows)?;
        let width = projected.shape()[1] as isize;
        reshape(&projected, &[batch as isize, seq_len as isize, width])
    }

    /// Applies the gated projection to a `[rows, features]` input.
    fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let gate = silu(&self.w1.forward(input)?)?;
        let value = self.w2.forward(input)?;
        self.w3.forward(&mul(&gate, &value)?)
    }
}

impl<T: Float> Module<T> for SwiGLU<T> {
    /// Accepts 2-D `[rows, features]` or 3-D `[batch, seq, features]` input.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        if ndim == 2 {
            return self.forward_2d(input);
        }
        if ndim == 3 {
            return self.forward_3d(input);
        }
        Err(FerrotorchError::InvalidArgument {
            message: format!(
                "SwiGLU expects 2-D or 3-D input, got {}-D with shape {:?}",
                ndim,
                input.shape()
            ),
        })
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.w1
            .parameters()
            .into_iter()
            .chain(self.w2.parameters())
            .chain(self.w3.parameters())
            .collect()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.w1
            .parameters_mut()
            .into_iter()
            .chain(self.w2.parameters_mut())
            .chain(self.w3.parameters_mut())
            .collect()
    }

    /// Parameter names are prefixed with `w1.`, `w2.` or `w3.`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let children = vec![
            ("w1", self.w1.named_parameters()),
            ("w2", self.w2.named_parameters()),
            ("w3", self.w3.named_parameters()),
        ];
        children
            .into_iter()
            .flat_map(|(prefix, named)| {
                named
                    .into_iter()
                    .map(move |(name, param)| (format!("{prefix}.{name}"), param))
            })
            .collect()
    }

    fn train(&mut self) {
        self.training = true;
        for proj in [&mut self.w1, &mut self.w2, &mut self.w3] {
            proj.train();
        }
    }

    fn eval(&mut self) {
        self.training = false;
        for proj in [&mut self.w1, &mut self.w2, &mut self.w3] {
            proj.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
/// Non-sequence dimensions of a cached `[B, kv_heads, seq, head_dim]` tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CacheDims {
    batch: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

/// Append-only key/value cache for incremental attention.
///
/// Keys and values are `[B, kv_heads, seq, head_dim]` tensors. Each
/// [`update`](Self::update) appends along the sequence axis (dim 2) and
/// returns the full cached history.
#[derive(Debug)]
pub struct KVCache<T: Float> {
    /// Concatenated keys so far; `None` while empty.
    key_cache: Option<Tensor<T>>,
    /// Concatenated values so far; `None` while empty.
    value_cache: Option<Tensor<T>>,
    /// Upper bound on the cached sequence length.
    max_seq_len: usize,
    /// Expected `B` / `kv_heads` / `head_dim`. Pinned by
    /// [`with_dims`](Self::with_dims) or inferred from the first successful update.
    dims: Option<CacheDims>,
}

impl<T: Float> KVCache<T> {
    /// Empty cache whose dims are inferred from the first successful update.
    pub fn new(max_seq_len: usize) -> Self {
        Self {
            key_cache: None,
            value_cache: None,
            max_seq_len,
            dims: None,
        }
    }

    /// Empty cache with `batch`, `num_kv_heads` and `head_dim` pinned up front.
    pub fn with_dims(
        max_seq_len: usize,
        batch: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Self {
        Self {
            key_cache: None,
            value_cache: None,
            max_seq_len,
            dims: Some(CacheDims {
                batch,
                num_kv_heads,
                head_dim,
            }),
        }
    }

    /// Appends `key`/`value` along the sequence axis and returns the full
    /// cached `(keys, values)`.
    ///
    /// On error the cache is left unchanged, including the recorded dims.
    ///
    /// # Errors
    /// * `InvalidArgument` if either tensor is not 4-D, or if the cached
    ///   length would exceed `max_seq_len`.
    /// * `ShapeMismatch` if key and value shapes differ, or if `B`,
    ///   `kv_heads` or `head_dim` differ from the recorded dims.
    pub fn update(
        &mut self,
        key: Tensor<T>,
        value: Tensor<T>,
    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
        if key.ndim() != 4 || value.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "KVCache expects 4-D [B, kv_heads, seq, dim] tensors, \
                     got key {:?}, value {:?}",
                    key.shape(),
                    value.shape()
                ),
            });
        }
        if key.shape() != value.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "KVCache: key shape {:?} != value shape {:?}",
                    key.shape(),
                    value.shape()
                ),
            });
        }
        let ks = key.shape();
        let incoming = CacheDims {
            batch: ks[0],
            num_kv_heads: ks[1],
            head_dim: ks[3],
        };
        if let Some(expected) = self.dims {
            if expected != incoming {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "KVCache: update shape [B={}, kv_heads={}, _, dim={}] does not \
                         match pinned dims [B={}, kv_heads={}, _, dim={}]",
                        incoming.batch,
                        incoming.num_kv_heads,
                        incoming.head_dim,
                        expected.batch,
                        expected.num_kv_heads,
                        expected.head_dim,
                    ),
                });
            }
        }
        // Check capacity before concatenating, so an oversized update neither
        // pays for the copy nor mutates the cache.
        let total_seq = self.seq_len() + ks[2];
        if total_seq > self.max_seq_len {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "KVCache: total sequence length ({total_seq}) exceeds max_seq_len ({})",
                    self.max_seq_len
                ),
            });
        }
        let (full_key, full_value) = match (&self.key_cache, &self.value_cache) {
            (Some(ck), Some(cv)) => (
                concat_along_dim2(ck, &key)?,
                concat_along_dim2(cv, &value)?,
            ),
            _ => (key, value),
        };
        // Record the dims only once the update has fully succeeded.
        self.dims = Some(incoming);
        self.key_cache = Some(full_key.clone());
        self.value_cache = Some(full_value.clone());
        Ok((full_key, full_value))
    }

    /// Drops the cached tensors.
    ///
    /// The recorded dims (pinned or inferred) are kept, so later updates must
    /// still match them.
    pub fn reset(&mut self) {
        self.key_cache = None;
        self.value_cache = None;
    }

    /// Number of cached sequence positions (0 when empty).
    pub fn seq_len(&self) -> usize {
        self.key_cache.as_ref().map_or(0, |k| k.shape()[2])
    }

    /// `true` when no keys are cached.
    pub fn is_empty(&self) -> bool {
        self.key_cache.is_none()
    }

    /// Maximum number of sequence positions the cache accepts.
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Recorded number of key/value heads, if known.
    pub fn num_kv_heads(&self) -> Option<usize> {
        self.dims.map(|d| d.num_kv_heads)
    }

    /// Recorded head dimension, if known.
    pub fn head_dim(&self) -> Option<usize> {
        self.dims.map(|d| d.head_dim)
    }

    /// Recorded batch size, if known.
    pub fn batch_size(&self) -> Option<usize> {
        self.dims.map(|d| d.batch)
    }
}
/// Concatenates two `[B, H, S, D]` tensors along dim 2 (the sequence axis).
///
/// Data is copied through the CPU. The result is moved to `a`'s device when
/// that is CUDA.
///
/// # Errors
/// `ShapeMismatch` if either input is not 4-D, or if the inputs differ on
/// dims 0, 1 or 3.
fn concat_along_dim2<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let sa = a.shape();
    let sb = b.shape();
    // Guard the rank first: the indexing below would otherwise panic.
    if sa.len() != 4 || sb.len() != 4 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "concat_along_dim2: expected 4-D tensors, got {:?} and {:?}",
                sa, sb
            ),
        });
    }
    if sa[0] != sb[0] || sa[1] != sb[1] || sa[3] != sb[3] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "concat_along_dim2: shapes {:?} and {:?} must match on dims 0, 1, 3",
                sa, sb
            ),
        });
    }
    let device = a.device();
    let (batch, heads, seq_a, dim) = (sa[0], sa[1], sa[2], sa[3]);
    let seq_b = sb[2];
    let seq_out = seq_a + seq_b;
    let a_data = a.data_vec()?;
    let b_data = b.data_vec()?;
    // Each (batch, head) pair owns one contiguous `[S, D]` plane in both inputs.
    let (plane_a, plane_b) = (seq_a * dim, seq_b * dim);
    let mut output = Vec::with_capacity(batch * heads * seq_out * dim);
    for bh in 0..batch * heads {
        output.extend_from_slice(&a_data[bh * plane_a..(bh + 1) * plane_a]);
        output.extend_from_slice(&b_data[bh * plane_b..(bh + 1) * plane_b]);
    }
    let result = Tensor::from_storage(
        TensorStorage::cpu(output),
        vec![batch, heads, seq_out, dim],
        false,
    )?;
    if device.is_cuda() {
        result.to(device)
    } else {
        Ok(result)
    }
}
/// Pre-norm transformer encoder layer with a SwiGLU feed-forward block.
///
/// Computes `h = x + dropout(attn(norm1(x)))`, then
/// `h + dropout(ffn(norm2(h)))`. One `Dropout` module is shared by both
/// residual branches.
#[derive(Debug)]
pub struct TransformerEncoderLayer<T: Float> {
    self_attn: MultiheadAttention<T>,
    ffn: SwiGLU<T>,
    norm1: LayerNorm<T>,
    norm2: LayerNorm<T>,
    dropout: Dropout<T>,
    training: bool,
}

impl<T: Float> TransformerEncoderLayer<T> {
    /// Builds the attention, feed-forward, normalisation and dropout modules.
    ///
    /// `bias` is forwarded to the attention and SwiGLU projections.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Fields are evaluated in the order written (same order as declared).
        Ok(Self {
            self_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            ffn: SwiGLU::new(d_model, d_ff, bias)?,
            norm1: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm2: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            dropout: Dropout::new(dropout_p)?,
            training: true,
        })
    }
}

impl<T: Float> Module<T> for TransformerEncoderLayer<T> {
    /// Runs the layer on a `[batch, seq, d_model]` input.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "TransformerEncoderLayer expects 3-D [batch, seq, d_model], got {:?}",
                    input.shape()
                ),
            });
        }
        // Attention sub-layer with residual connection.
        let attended = self.self_attn.forward(&self.norm1.forward(input)?)?;
        let hidden = add(input, &self.dropout.forward(&attended)?)?;
        // Feed-forward sub-layer with residual connection.
        let fed = self.ffn.forward(&self.norm2.forward(&hidden)?)?;
        add(&hidden, &self.dropout.forward(&fed)?)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.self_attn
            .parameters()
            .into_iter()
            .chain(self.ffn.parameters())
            .chain(self.norm1.parameters())
            .chain(self.norm2.parameters())
            .collect()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.self_attn
            .parameters_mut()
            .into_iter()
            .chain(self.ffn.parameters_mut())
            .chain(self.norm1.parameters_mut())
            .chain(self.norm2.parameters_mut())
            .collect()
    }

    /// Names are prefixed with the owning submodule (`self_attn.`, `ffn.`, ...).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let children = vec![
            ("self_attn", self.self_attn.named_parameters()),
            ("ffn", self.ffn.named_parameters()),
            ("norm1", self.norm1.named_parameters()),
            ("norm2", self.norm2.named_parameters()),
        ];
        children
            .into_iter()
            .flat_map(|(prefix, named)| {
                named
                    .into_iter()
                    .map(move |(name, param)| (format!("{prefix}.{name}"), param))
            })
            .collect()
    }

    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.ffn.train();
        self.norm1.train();
        self.norm2.train();
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.ffn.eval();
        self.norm1.eval();
        self.norm2.eval();
        self.dropout.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
/// Pre-norm transformer decoder layer: causal self-attention, cross-attention
/// over `memory`, and a SwiGLU feed-forward block, each wrapped in a residual
/// connection. One `Dropout` module is shared by all three branches.
#[derive(Debug)]
pub struct TransformerDecoderLayer<T: Float> {
    self_attn: MultiheadAttention<T>,
    cross_attn: MultiheadAttention<T>,
    ffn: SwiGLU<T>,
    norm1: LayerNorm<T>,
    norm2: LayerNorm<T>,
    norm3: LayerNorm<T>,
    dropout: Dropout<T>,
    training: bool,
}

impl<T: Float> TransformerDecoderLayer<T> {
    /// Builds both attention blocks, the feed-forward block, three norms and dropout.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Fields are evaluated in the order written (same order as declared).
        Ok(Self {
            self_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            cross_attn: MultiheadAttention::new(d_model, num_heads, bias)?,
            ffn: SwiGLU::new(d_model, d_ff, bias)?,
            norm1: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm2: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            norm3: LayerNorm::new(vec![d_model], layer_norm_eps, true)?,
            dropout: Dropout::new(dropout_p)?,
            training: true,
        })
    }

    /// Runs the layer on `input`, attending over `memory` in the cross-attention block.
    ///
    /// # Errors
    /// `InvalidArgument` unless both tensors are 3-D.
    pub fn forward_with_memory(
        &self,
        input: &Tensor<T>,
        memory: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 3 || memory.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "TransformerDecoderLayer expects 3-D inputs, \
                     got input {:?}, memory {:?}",
                    input.shape(),
                    memory.shape()
                ),
            });
        }
        // Causal self-attention (the final `true` flag) + residual.
        let q = self.norm1.forward(input)?;
        let attended = self.self_attn.forward_qkv(&q, &q, &q, true)?;
        let hidden = add(input, &self.dropout.forward(&attended)?)?;
        // Non-causal cross-attention over `memory` + residual.
        let q = self.norm2.forward(&hidden)?;
        let crossed = self.cross_attn.forward_qkv(&q, memory, memory, false)?;
        let hidden = add(&hidden, &self.dropout.forward(&crossed)?)?;
        // Feed-forward + residual.
        let fed = self.ffn.forward(&self.norm3.forward(&hidden)?)?;
        add(&hidden, &self.dropout.forward(&fed)?)
    }
}

impl<T: Float> Module<T> for TransformerDecoderLayer<T> {
    /// Uses `input` as its own cross-attention memory.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        self.forward_with_memory(input, input)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.self_attn
            .parameters()
            .into_iter()
            .chain(self.cross_attn.parameters())
            .chain(self.ffn.parameters())
            .chain(self.norm1.parameters())
            .chain(self.norm2.parameters())
            .chain(self.norm3.parameters())
            .collect()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.self_attn
            .parameters_mut()
            .into_iter()
            .chain(self.cross_attn.parameters_mut())
            .chain(self.ffn.parameters_mut())
            .chain(self.norm1.parameters_mut())
            .chain(self.norm2.parameters_mut())
            .chain(self.norm3.parameters_mut())
            .collect()
    }

    /// Names are prefixed with the owning submodule (`self_attn.`, `cross_attn.`, ...).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let children = vec![
            ("self_attn", self.self_attn.named_parameters()),
            ("cross_attn", self.cross_attn.named_parameters()),
            ("ffn", self.ffn.named_parameters()),
            ("norm1", self.norm1.named_parameters()),
            ("norm2", self.norm2.named_parameters()),
            ("norm3", self.norm3.named_parameters()),
        ];
        children
            .into_iter()
            .flat_map(|(prefix, named)| {
                named
                    .into_iter()
                    .map(move |(name, param)| (format!("{prefix}.{name}"), param))
            })
            .collect()
    }

    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.cross_attn.train();
        self.ffn.train();
        self.norm1.train();
        self.norm2.train();
        self.norm3.train();
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.cross_attn.eval();
        self.ffn.eval();
        self.norm1.eval();
        self.norm2.eval();
        self.norm3.eval();
        self.dropout.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
/// Stack of [`TransformerEncoderLayer`]s with an optional final `LayerNorm`.
#[derive(Debug)]
pub struct TransformerEncoder<T: Float> {
    layers: Vec<TransformerEncoderLayer<T>>,
    norm: Option<LayerNorm<T>>,
    training: bool,
}

impl<T: Float> TransformerEncoder<T> {
    /// Builds `num_layers` identical encoder layers. When `final_norm` is set,
    /// the stack is followed by a `LayerNorm` over `d_model`.
    ///
    /// # Errors
    /// `InvalidArgument` when `num_layers == 0`. Errors from constructing the
    /// layers or the final norm are propagated.
    #[allow(clippy::too_many_arguments)] // flat hyper-parameter list mirrors the layer constructor
    pub fn new(
        d_model: usize,
        num_heads: usize,
        num_layers: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
        final_norm: bool,
    ) -> FerrotorchResult<Self> {
        if num_layers == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "TransformerEncoder: num_layers must be > 0".into(),
            });
        }
        let mut layers = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layers.push(TransformerEncoderLayer::new(
                d_model,
                num_heads,
                d_ff,
                dropout_p,
                layer_norm_eps,
                bias,
            )?);
        }
        let norm = if final_norm {
            Some(LayerNorm::new(vec![d_model], layer_norm_eps, true)?)
        } else {
            None
        };
        Ok(Self {
            layers,
            norm,
            training: true,
        })
    }

    /// Number of encoder layers in the stack.
    #[inline]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl<T: Float> Module<T> for TransformerEncoder<T> {
    /// Applies every layer in order, then the final norm if present.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        if let Some(ref norm) = self.norm {
            output = norm.forward(&output)?;
        }
        Ok(output)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.parameters());
        }
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = Vec::new();
        for layer in &mut self.layers {
            params.extend(layer.parameters_mut());
        }
        if let Some(ref mut norm) = self.norm {
            params.extend(norm.parameters_mut());
        }
        params
    }

    /// Layer parameters are named `layers.{i}.*`; the final norm's are `norm.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = Vec::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.push((format!("layers.{i}.{name}"), param));
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.push((format!("norm.{name}"), param));
            }
        }
        params
    }

    fn train(&mut self) {
        self.training = true;
        for layer in &mut self.layers {
            layer.train();
        }
        if let Some(ref mut norm) = self.norm {
            norm.train();
        }
    }

    fn eval(&mut self) {
        self.training = false;
        for layer in &mut self.layers {
            layer.eval();
        }
        if let Some(ref mut norm) = self.norm {
            norm.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
/// Stack of [`TransformerDecoderLayer`]s with an optional final `LayerNorm`.
#[derive(Debug)]
pub struct TransformerDecoder<T: Float> {
    layers: Vec<TransformerDecoderLayer<T>>,
    norm: Option<LayerNorm<T>>,
    training: bool,
}

impl<T: Float> TransformerDecoder<T> {
    /// Builds `num_layers` identical decoder layers. When `final_norm` is set,
    /// the stack is followed by a `LayerNorm` over `d_model`.
    ///
    /// # Errors
    /// `InvalidArgument` when `num_layers == 0`. The first error from
    /// constructing a layer or the final norm is propagated.
    #[allow(clippy::too_many_arguments)] // flat hyper-parameter list mirrors the layer constructor
    pub fn new(
        d_model: usize,
        num_heads: usize,
        num_layers: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
        final_norm: bool,
    ) -> FerrotorchResult<Self> {
        if num_layers == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "TransformerDecoder: num_layers must be > 0".into(),
            });
        }
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderLayer::new(
                    d_model,
                    num_heads,
                    d_ff,
                    dropout_p,
                    layer_norm_eps,
                    bias,
                )
            })
            .collect::<FerrotorchResult<Vec<_>>>()?;
        let norm = final_norm
            .then(|| LayerNorm::new(vec![d_model], layer_norm_eps, true))
            .transpose()?;
        Ok(Self {
            layers,
            norm,
            training: true,
        })
    }

    /// Feeds `input` through every layer (each cross-attending over `memory`),
    /// then applies the final norm if present.
    pub fn forward_with_memory(
        &self,
        input: &Tensor<T>,
        memory: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        let hidden = self
            .layers
            .iter()
            .try_fold(input.clone(), |hidden, layer| {
                layer.forward_with_memory(&hidden, memory)
            })?;
        match &self.norm {
            Some(norm) => norm.forward(&hidden),
            None => Ok(hidden),
        }
    }

    /// Number of decoder layers in the stack.
    #[inline]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl<T: Float> Module<T> for TransformerDecoder<T> {
    /// Uses `input` as its own cross-attention memory.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        self.forward_with_memory(input, input)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params: Vec<&Parameter<T>> = self
            .layers
            .iter()
            .flat_map(|layer| layer.parameters())
            .collect();
        if let Some(norm) = &self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params: Vec<&mut Parameter<T>> = self
            .layers
            .iter_mut()
            .flat_map(|layer| layer.parameters_mut())
            .collect();
        if let Some(norm) = self.norm.as_mut() {
            params.extend(norm.parameters_mut());
        }
        params
    }

    /// Layer parameters are named `layers.{i}.*`; the final norm's are `norm.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut named: Vec<(String, &Parameter<T>)> = self
            .layers
            .iter()
            .enumerate()
            .flat_map(|(i, layer)| {
                layer
                    .named_parameters()
                    .into_iter()
                    .map(move |(name, param)| (format!("layers.{i}.{name}"), param))
            })
            .collect();
        if let Some(norm) = &self.norm {
            named.extend(
                norm.named_parameters()
                    .into_iter()
                    .map(|(name, param)| (format!("norm.{name}"), param)),
            );
        }
        named
    }

    fn train(&mut self) {
        self.training = true;
        self.layers.iter_mut().for_each(|layer| layer.train());
        if let Some(norm) = self.norm.as_mut() {
            norm.train();
        }
    }

    fn eval(&mut self) {
        self.training = false;
        self.layers.iter_mut().for_each(|layer| layer.eval());
        if let Some(norm) = self.norm.as_mut() {
            norm.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
/// Encoder-decoder transformer; both stacks end with a final `LayerNorm`.
#[derive(Debug)]
pub struct Transformer<T: Float> {
    encoder: TransformerEncoder<T>,
    decoder: TransformerDecoder<T>,
    training: bool,
}

impl<T: Float> Transformer<T> {
    /// Builds the encoder (`num_encoder_layers`) and decoder
    /// (`num_decoder_layers`) stacks with shared hyper-parameters.
    ///
    /// # Errors
    /// Propagates construction errors from either stack, including a zero
    /// layer count.
    #[allow(clippy::too_many_arguments)] // flat hyper-parameter list mirrors the stack constructors
    pub fn new(
        d_model: usize,
        num_heads: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        d_ff: usize,
        dropout_p: f64,
        layer_norm_eps: f64,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Both stacks always close with a final LayerNorm.
        let final_norm = true;
        let encoder = TransformerEncoder::new(
            d_model,
            num_heads,
            num_encoder_layers,
            d_ff,
            dropout_p,
            layer_norm_eps,
            bias,
            final_norm,
        )?;
        let decoder = TransformerDecoder::new(
            d_model,
            num_heads,
            num_decoder_layers,
            d_ff,
            dropout_p,
            layer_norm_eps,
            bias,
            final_norm,
        )?;
        Ok(Self {
            encoder,
            decoder,
            training: true,
        })
    }

    /// Encodes `src` into a memory tensor, then decodes `tgt` against it.
    pub fn forward_transformer(
        &self,
        src: &Tensor<T>,
        tgt: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        let memory = self.encoder.forward(src)?;
        self.decoder.forward_with_memory(tgt, &memory)
    }

    /// Number of encoder layers.
    #[inline]
    pub fn num_encoder_layers(&self) -> usize {
        self.encoder.num_layers()
    }

    /// Number of decoder layers.
    #[inline]
    pub fn num_decoder_layers(&self) -> usize {
        self.decoder.num_layers()
    }
}

impl<T: Float> Module<T> for Transformer<T> {
    /// Uses `input` as both the source and the target sequence.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        self.forward_transformer(input, input)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        self.encoder
            .parameters()
            .into_iter()
            .chain(self.decoder.parameters())
            .collect()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        self.encoder
            .parameters_mut()
            .into_iter()
            .chain(self.decoder.parameters_mut())
            .collect()
    }

    /// Names are prefixed with `encoder.` or `decoder.`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let children = vec![
            ("encoder", self.encoder.named_parameters()),
            ("decoder", self.decoder.named_parameters()),
        ];
        children
            .into_iter()
            .flat_map(|(prefix, named)| {
                named
                    .into_iter()
                    .map(move |(name, param)| (format!("{prefix}.{name}"), param))
            })
            .collect()
    }

    fn train(&mut self) {
        self.training = true;
        self.encoder.train();
        self.decoder.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.encoder.eval();
        self.decoder.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for rotary position embeddings (conventions and frequency
    // scaling), SwiGLU, KVCache, and the Transformer encoder/decoder stack.
    use super::*;

    // ----- RotaryPositionEmbedding: construction and validation -----

    #[test]
    fn test_rope_construction() {
        let rope = RotaryPositionEmbedding::<f32>::new(64, 512, 10000.0);
        assert!(rope.is_ok());
        let rope = rope.unwrap();
        assert_eq!(rope.dim(), 64);
        assert_eq!(rope.max_seq_len(), 512);
        assert_eq!(rope.base(), 10000.0);
    }

    #[test]
    fn test_rope_odd_dim_rejected() {
        assert!(RotaryPositionEmbedding::<f32>::new(63, 512, 10000.0).is_err());
    }

    #[test]
    fn test_rope_zero_dim_rejected() {
        assert!(RotaryPositionEmbedding::<f32>::new(0, 512, 10000.0).is_err());
    }

    #[test]
    fn test_rope_zero_seq_rejected() {
        assert!(RotaryPositionEmbedding::<f32>::new(64, 0, 10000.0).is_err());
    }

    // ----- RotaryPositionEmbedding: apply -----

    #[test]
    fn test_rope_output_shape_2d() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[4, 8]);
    }

    #[test]
    fn test_rope_output_shape_3d() {
        let rope = RotaryPositionEmbedding::<f32>::new(16, 256, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[2, 10, 16]);
    }

    #[test]
    fn test_rope_output_shape_4d() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 4, 6, 8]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[2, 4, 6, 8]);
    }

    #[test]
    fn test_rope_with_offset() {
        // Only checks that an in-range non-zero offset is accepted and the
        // shape is preserved.
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[4, 8]).unwrap();
        let y = rope.apply(&x, 10).unwrap();
        assert_eq!(y.shape(), &[4, 8]);
    }

    #[test]
    fn test_rope_offset_overflow_rejected() {
        // 10 rows starting at offset 10 would need positions up to 19, but
        // the table was built with max_seq_len = 16.
        let rope = RotaryPositionEmbedding::<f32>::new(8, 16, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[10, 8]).unwrap();
        assert!(rope.apply(&x, 10).is_err());
    }

    #[test]
    fn test_rope_position_zero_is_identity() {
        let rope = RotaryPositionEmbedding::<f64>::new(4, 64, 10000.0).unwrap();
        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        let y_data = y.data().unwrap();
        let x_data = x.data().unwrap();
        for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
            assert!(
                (xv - yv).abs() < 1e-10,
                "position 0 should be identity, index {i}: x={xv}, y={yv}"
            );
        }
    }

    #[test]
    fn test_rope_values_are_finite() {
        let rope = RotaryPositionEmbedding::<f32>::new(16, 512, 10000.0).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[2, 4, 10, 16]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "RoPE produced non-finite value: {v}");
        }
    }

    #[test]
    fn test_rope_wrong_dim_rejected() {
        // Last dimension (10) does not match the embedding dim (8).
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[4, 10]).unwrap();
        assert!(rope.apply(&x, 0).is_err());
    }

    // ----- RotaryPositionEmbedding: frequency scaling -----

    #[test]
    fn test_rope_scaling_default_is_none() {
        let rope = RotaryPositionEmbedding::<f32>::new(16, 128, 10000.0).unwrap();
        assert_eq!(rope.scaling(), RoPEScaling::None);
    }

    #[test]
    fn test_rope_scaling_none_matches_classical() {
        let a = RotaryPositionEmbedding::<f64>::new(16, 32, 10000.0).unwrap();
        let b = RotaryPositionEmbedding::<f64>::with_scaling(
            16,
            32,
            10000.0,
            RoPEConvention::default(),
            RoPEScaling::None,
        )
        .unwrap();
        let x = ferrotorch_core::from_slice(
            &(0..16).map(|i| i as f64 * 0.1).collect::<Vec<_>>(),
            &[1, 16],
        )
        .unwrap();
        let ya = a.apply(&x, 7).unwrap();
        let yb = b.apply(&x, 7).unwrap();
        for (va, vb) in ya.data().unwrap().iter().zip(yb.data().unwrap().iter()) {
            assert!((va - vb).abs() < 1e-12);
        }
    }

    #[test]
    fn test_rope_scaling_linear_halves_angles() {
        // Linear scaling by 2 halves every inverse frequency, so rotating at
        // position 8 with scaling should equal rotating at position 4 without.
        let scaled = RotaryPositionEmbedding::<f64>::with_scaling(
            8,
            64,
            10000.0,
            RoPEConvention::default(),
            RoPEScaling::Linear { factor: 2.0 },
        )
        .unwrap();
        let plain = RotaryPositionEmbedding::<f64>::new(8, 64, 10000.0).unwrap();
        let x = ferrotorch_core::ones::<f64>(&[1, 8]).unwrap();
        let y_scaled = scaled.apply(&x, 8).unwrap();
        let y_plain = plain.apply(&x, 4).unwrap();
        for (a, b) in y_scaled
            .data()
            .unwrap()
            .iter()
            .zip(y_plain.data().unwrap().iter())
        {
            assert!(
                (a - b).abs() < 1e-6,
                "scaled(pos=8) should match plain(pos=4): {a} vs {b}"
            );
        }
    }

    #[test]
    fn test_rope_scaling_ntk_inv_freq() {
        // NTK-aware scaling should leave inv_freq[0] untouched and shrink the
        // last inverse frequency by roughly 1/factor.
        let dim = 64;
        let base = 10000.0;
        let factor = 4.0;
        let ntk = compute_scaled_inv_freq(
            dim,
            base,
            RoPEScaling::NtkAware {
                factor,
                original_max_pos_embeddings: 2048,
            },
        );
        let plain = compute_scaled_inv_freq(dim, base, RoPEScaling::None);
        // One inverse frequency per rotated pair: dim / 2 = 32.
        assert_eq!(ntk.len(), 32);
        assert_eq!(plain.len(), 32);
        assert!(
            (ntk[0] - plain[0]).abs() < 1e-15,
            "NTK inv_freq[0] should equal plain inv_freq[0]: ntk={}, plain={}",
            ntk[0],
            plain[0]
        );
        let ratio = ntk[31] / plain[31];
        let expected = 1.0 / factor;
        assert!(
            (ratio - expected).abs() < 0.05,
            "NTK inv_freq[31]/plain ratio should be ~{expected}: got {ratio}"
        );
    }

    #[test]
    fn test_rope_scaling_linear_inv_freq_halved() {
        let lin = compute_scaled_inv_freq(8, 10000.0, RoPEScaling::Linear { factor: 2.0 });
        let plain = compute_scaled_inv_freq(8, 10000.0, RoPEScaling::None);
        for (a, b) in lin.iter().zip(plain.iter()) {
            assert!((a - b / 2.0).abs() < 1e-15, "linear should halve: {a} vs {b}/2");
        }
    }

    #[test]
    fn test_rope_scaling_yarn_inv_freq_piecewise() {
        // YARN is piecewise: index 0 is in the extrapolation regime (unchanged),
        // the last index is in the interpolation regime (~ plain / factor).
        let dim = 64;
        let base = 10000.0;
        let factor = 4.0;
        let yarn = compute_scaled_inv_freq(
            dim,
            base,
            RoPEScaling::yarn_default(factor, 2048),
        );
        let plain = compute_scaled_inv_freq(dim, base, RoPEScaling::None);
        assert!(
            (yarn[0] - plain[0]).abs() < 1e-12,
            "YARN[0] (extrapolation) should equal plain[0]: {} vs {}",
            yarn[0],
            plain[0]
        );
        let expected_low = plain[dim / 2 - 1] / factor;
        let ratio = yarn[dim / 2 - 1] / expected_low;
        assert!(
            (ratio - 1.0).abs() < 0.1,
            "YARN[dim/2-1] (interpolation) should approx equal plain/factor: {} vs {}",
            yarn[dim / 2 - 1],
            expected_low
        );
    }

    #[test]
    fn test_rope_scaling_yarn_constructs() {
        let rope = RotaryPositionEmbedding::<f32>::with_scaling(
            64,
            256,
            10000.0,
            RoPEConvention::default(),
            RoPEScaling::yarn_default(2.0, 2048),
        )
        .unwrap();
        assert!(matches!(rope.scaling(), RoPEScaling::Yarn { .. }));
        let x = ferrotorch_core::ones::<f32>(&[1, 64]).unwrap();
        for &v in rope.apply(&x, 0).unwrap().data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_rope_scaling_rejects_zero_factor() {
        let r = RotaryPositionEmbedding::<f32>::with_scaling(
            8,
            16,
            10000.0,
            RoPEConvention::default(),
            RoPEScaling::Linear { factor: 0.0 },
        );
        assert!(r.is_err());
    }

    #[test]
    fn test_rope_scaling_rejects_negative_factor() {
        let r = RotaryPositionEmbedding::<f32>::with_scaling(
            8,
            16,
            10000.0,
            RoPEConvention::default(),
            RoPEScaling::NtkAware {
                factor: -2.0,
                original_max_pos_embeddings: 2048,
            },
        );
        assert!(r.is_err());
    }

    #[test]
    fn test_rope_scaling_accessor() {
        let rope = RotaryPositionEmbedding::<f32>::with_scaling(
            16,
            64,
            10000.0,
            RoPEConvention::default(),
            RoPEScaling::Linear { factor: 4.0 },
        )
        .unwrap();
        assert_eq!(rope.scaling(), RoPEScaling::Linear { factor: 4.0 });
    }

    // ----- RotaryPositionEmbedding: rotation conventions -----

    #[test]
    fn test_rope_half_rotation_construction() {
        let rope = RotaryPositionEmbedding::<f32>::with_convention(
            8,
            128,
            10000.0,
            RoPEConvention::HalfRotation,
        )
        .unwrap();
        assert_eq!(rope.convention(), RoPEConvention::HalfRotation);
    }

    #[test]
    fn test_rope_half_rotation_output_shape() {
        let rope = RotaryPositionEmbedding::<f32>::with_convention(
            8,
            128,
            10000.0,
            RoPEConvention::HalfRotation,
        )
        .unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 4, 8]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        assert_eq!(y.shape(), &[2, 4, 8]);
    }

    #[test]
    fn test_rope_half_rotation_position_zero_is_identity() {
        let rope = RotaryPositionEmbedding::<f64>::with_convention(
            4,
            64,
            10000.0,
            RoPEConvention::HalfRotation,
        )
        .unwrap();
        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
        let y = rope.apply(&x, 0).unwrap();
        let x_data = x.data().unwrap();
        let y_data = y.data().unwrap();
        for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
            assert!(
                (xv - yv).abs() < 1e-10,
                "half-rot pos 0 should be identity, index {i}: x={xv}, y={yv}"
            );
        }
    }

    #[test]
    fn test_rope_half_rotation_correctness() {
        let rope = RotaryPositionEmbedding::<f64>::with_convention(
            4,
            64,
            10000.0,
            RoPEConvention::HalfRotation,
        )
        .unwrap();
        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
        let y = rope.apply(&x, 1).unwrap();
        // The caches hold dim / 2 = 2 values per position, so position 1
        // occupies flat indices 2 and 3.
        let cos_data = rope.cos_cache.data().unwrap();
        let sin_data = rope.sin_cache.data().unwrap();
        let c0 = cos_data[2];
        let c1 = cos_data[3];
        let s0 = sin_data[2];
        let s1 = sin_data[3];
        // Half-rotation pairs element i with element i + dim / 2.
        let expected = [
            1.0 * c0 - 3.0 * s0,
            2.0 * c1 - 4.0 * s1,
            1.0 * s0 + 3.0 * c0,
            2.0 * s1 + 4.0 * c1,
        ];
        let y_data = y.data().unwrap();
        for (i, (&actual, &exp)) in y_data.iter().zip(expected.iter()).enumerate() {
            assert!(
                (actual - exp).abs() < 1e-10,
                "half-rot index {i}: actual={actual}, expected={exp}"
            );
        }
    }

    #[test]
    fn test_rope_interleaved_vs_half_rotation_differ() {
        let rope_il = RotaryPositionEmbedding::<f64>::with_convention(
            4,
            64,
            10000.0,
            RoPEConvention::Interleaved,
        )
        .unwrap();
        let rope_hr = RotaryPositionEmbedding::<f64>::with_convention(
            4,
            64,
            10000.0,
            RoPEConvention::HalfRotation,
        )
        .unwrap();
        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
        let y_il = rope_il.apply(&x, 1).unwrap();
        let y_hr = rope_hr.apply(&x, 1).unwrap();
        let il_data = y_il.data().unwrap();
        let hr_data = y_hr.data().unwrap();
        let any_differ = il_data
            .iter()
            .zip(hr_data.iter())
            .any(|(&a, &b)| (a - b).abs() > 1e-10);
        assert!(
            any_differ,
            "interleaved and half-rotation should produce different outputs at pos > 0"
        );
    }

    #[test]
    fn test_rope_default_convention_is_interleaved() {
        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
        assert_eq!(rope.convention(), RoPEConvention::Interleaved);
    }

    // ----- SwiGLU -----

    #[test]
    fn test_swiglu_construction() {
        let swiglu = SwiGLU::<f32>::new(64, 128, true);
        assert!(swiglu.is_ok());
    }

    #[test]
    fn test_swiglu_forward_shape_2d() {
        let swiglu = SwiGLU::<f32>::new(16, 32, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[4, 16]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        assert_eq!(output.shape(), &[4, 16]);
    }

    #[test]
    fn test_swiglu_forward_shape_3d() {
        let swiglu = SwiGLU::<f32>::new(16, 32, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_swiglu_forward_values_finite() {
        let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[2, 3, 8]).unwrap();
        let output = swiglu.forward(&input).unwrap();
        for &v in output.data().unwrap() {
            assert!(v.is_finite(), "SwiGLU produced non-finite value: {v}");
        }
    }

    #[test]
    fn test_swiglu_1d_rejected() {
        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[8]).unwrap();
        assert!(swiglu.forward(&input).is_err());
    }

    #[test]
    fn test_swiglu_parameters() {
        // Three projections (w1, w2, w3), each with a weight and a bias.
        let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
        let params = swiglu.parameters();
        assert_eq!(params.len(), 6);
        let named = swiglu.named_parameters();
        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
        assert!(names.contains(&"w1.weight"));
        assert!(names.contains(&"w1.bias"));
        assert!(names.contains(&"w2.weight"));
        assert!(names.contains(&"w2.bias"));
        assert!(names.contains(&"w3.weight"));
        assert!(names.contains(&"w3.bias"));
    }

    #[test]
    fn test_swiglu_parameters_no_bias() {
        // Without bias only the three weights remain.
        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        let params = swiglu.parameters();
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_swiglu_train_eval() {
        let mut swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
        assert!(swiglu.is_training());
        swiglu.eval();
        assert!(!swiglu.is_training());
        swiglu.train();
        assert!(swiglu.is_training());
    }

    // ----- KVCache -----
    // Tensors in these tests are shaped [batch, kv_heads, seq, head_dim];
    // the cache grows along dim 2.

    #[test]
    fn test_kv_cache_new_empty() {
        let cache = KVCache::<f32>::new(1024);
        assert!(cache.is_empty());
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.max_seq_len(), 1024);
    }

    #[test]
    fn test_kv_cache_single_update() {
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let (fk, fv) = cache.update(k, v).unwrap();
        assert_eq!(fk.shape(), &[1, 2, 3, 4]);
        assert_eq!(fv.shape(), &[1, 2, 3, 4]);
        assert_eq!(cache.seq_len(), 3);
    }

    #[test]
    fn test_kv_cache_append() {
        let mut cache = KVCache::<f32>::new(128);
        let k1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        cache.update(k1, v1).unwrap();
        assert_eq!(cache.seq_len(), 3);
        let k2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
        let v2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
        let (fk, fv) = cache.update(k2, v2).unwrap();
        // 3 cached positions + 2 new ones.
        assert_eq!(fk.shape(), &[1, 2, 5, 4]);
        assert_eq!(fv.shape(), &[1, 2, 5, 4]);
        assert_eq!(cache.seq_len(), 5);
    }

    #[test]
    fn test_kv_cache_reset() {
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        cache.update(k, v).unwrap();
        assert_eq!(cache.seq_len(), 3);
        cache.reset();
        assert!(cache.is_empty());
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn test_kv_cache_overflow_rejected() {
        // 5 positions do not fit a cache created with max_seq_len = 4.
        let mut cache = KVCache::<f32>::new(4);
        let k = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_shape_mismatch_rejected() {
        // k and v disagree on the last dimension (4 vs 8).
        let mut cache = KVCache::<f32>::new(128);
        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 8]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_values_preserved() {
        let mut cache = KVCache::<f64>::new(128);
        let k1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
        let v1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
        cache.update(k1, v1).unwrap();
        // One new position of head_dim 3, filled with 2.0.
        let k2_data = [2.0f64; 3];
        let k2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
        let v2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
        let (fk, _fv) = cache.update(k2, v2).unwrap();
        assert_eq!(fk.shape(), &[1, 1, 3, 3]);
        let fk_data = fk.data().unwrap();
        // First 2 positions x 3 dims are the original ones...
        for &v in &fk_data[..6] {
            assert!((v - 1.0).abs() < 1e-10, "expected 1.0, got {v}");
        }
        // ...followed by the 3 appended twos.
        for &v in &fk_data[6..9] {
            assert!((v - 2.0).abs() < 1e-10, "expected 2.0, got {v}");
        }
    }

    #[test]
    fn test_kv_cache_gqa_stores_at_kv_head_granularity() {
        let mut cache = KVCache::<f32>::new(8192);
        let k = ferrotorch_core::zeros::<f32>(&[1, 8, 3, 128]).unwrap();
        let v = ferrotorch_core::zeros::<f32>(&[1, 8, 3, 128]).unwrap();
        let (fk, _) = cache.update(k, v).unwrap();
        assert_eq!(fk.shape(), &[1, 8, 3, 128]);
        assert_eq!(cache.num_kv_heads(), Some(8));
        assert_eq!(cache.head_dim(), Some(128));
        assert_eq!(cache.batch_size(), Some(1));
    }

    #[test]
    fn test_kv_cache_with_dims_pre_declares_shape() {
        let cache = KVCache::<f32>::with_dims(8192, 1, 8, 128);
        assert_eq!(cache.num_kv_heads(), Some(8));
        assert_eq!(cache.head_dim(), Some(128));
        assert_eq!(cache.batch_size(), Some(1));
        assert!(cache.is_empty());
    }

    #[test]
    fn test_kv_cache_with_dims_rejects_first_update_mismatch() {
        // Declared 8 kv heads, first update supplies 4.
        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
        let k = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 16]).unwrap();
        let v = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 16]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_with_dims_rejects_head_dim_mismatch() {
        // Declared head_dim 16, update supplies 32.
        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
        let k = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 32]).unwrap();
        let v = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 32]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_with_dims_rejects_batch_mismatch() {
        // Declared batch 2, update supplies batch 1.
        let mut cache = KVCache::<f32>::with_dims(128, 2, 4, 8);
        let k = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 8]).unwrap();
        let v = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 8]).unwrap();
        assert!(cache.update(k, v).is_err());
    }

    #[test]
    fn test_kv_cache_with_dims_accepts_matching_update() {
        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
        let k = ferrotorch_core::ones::<f32>(&[1, 8, 3, 16]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 8, 3, 16]).unwrap();
        assert!(cache.update(k, v).is_ok());
        assert_eq!(cache.seq_len(), 3);
    }

    #[test]
    fn test_kv_cache_inferred_dims_reject_subsequent_mismatch() {
        // The first update pins 8 kv heads; a later 4-head update must fail.
        let mut cache = KVCache::<f32>::new(128);
        let k1 = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 16]).unwrap();
        let v1 = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 16]).unwrap();
        cache.update(k1, v1).unwrap();
        assert_eq!(cache.num_kv_heads(), Some(8));
        let k2 = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
        let v2 = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
        assert!(cache.update(k2, v2).is_err());
    }

    #[test]
    fn test_kv_cache_dims_not_yet_pinned_on_fresh_new() {
        let cache = KVCache::<f32>::new(128);
        assert_eq!(cache.num_kv_heads(), None);
        assert_eq!(cache.head_dim(), None);
        assert_eq!(cache.batch_size(), None);
    }

    #[test]
    fn test_kv_cache_reset_preserves_pinned_dims() {
        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
        let k = ferrotorch_core::ones::<f32>(&[1, 8, 2, 16]).unwrap();
        let v = ferrotorch_core::ones::<f32>(&[1, 8, 2, 16]).unwrap();
        cache.update(k, v).unwrap();
        cache.reset();
        assert!(cache.is_empty());
        assert_eq!(cache.num_kv_heads(), Some(8));
        let bad = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
        assert!(cache.update(bad.clone(), bad).is_err());
    }

    #[test]
    fn test_kv_cache_gqa_prefill_then_decode_preserves_all_positions() {
        // Deterministic, seed-dependent fill so every element is distinguishable.
        let build = |seed: u64, shape: &[usize]| {
            let numel: usize = shape.iter().product();
            let data: Vec<f32> = (0..numel)
                .map(|i| ((i as u64).wrapping_mul(seed) % 997) as f32 * 0.001)
                .collect();
            ferrotorch_core::from_slice(&data, shape).unwrap()
        };
        let (b, h, s_prefill, s_decode, d) = (1usize, 8usize, 4usize, 1usize, 16usize);
        let s_full = s_prefill + s_decode;
        let k_prefill = build(7, &[b, h, s_prefill, d]);
        let v_prefill = build(11, &[b, h, s_prefill, d]);
        let k_decode = build(13, &[b, h, s_decode, d]);
        let v_decode = build(17, &[b, h, s_decode, d]);
        let mut cache = KVCache::<f32>::with_dims(16, b, h, d);
        cache.update(k_prefill.clone(), v_prefill.clone()).unwrap();
        let (fk, fv) = cache.update(k_decode.clone(), v_decode.clone()).unwrap();
        assert_eq!(fk.shape(), &[b, h, s_full, d]);
        assert_eq!(fv.shape(), &[b, h, s_full, d]);
        let fk_data = fk.data_vec().unwrap();
        let fv_data = fv.data_vec().unwrap();
        let kp = k_prefill.data_vec().unwrap();
        let vp = v_prefill.data_vec().unwrap();
        let kd = k_decode.data_vec().unwrap();
        let vd = v_decode.data_vec().unwrap();
        // Row-major offsets into [b, h, s, d] tensors of the given seq length.
        let full_idx = |bi, hi, si, di| ((bi * h + hi) * s_full + si) * d + di;
        let src_idx = |bi, hi, si, di, s_len| ((bi * h + hi) * s_len + si) * d + di;
        for bi in 0..b {
            for hi in 0..h {
                for si in 0..s_full {
                    for di in 0..d {
                        let out = full_idx(bi, hi, si, di);
                        // Positions before s_prefill come from the prefill
                        // chunk; the rest come from the decode chunk.
                        let (exp_k, exp_v) = if si < s_prefill {
                            let src = src_idx(bi, hi, si, di, s_prefill);
                            (kp[src], vp[src])
                        } else {
                            let src = src_idx(bi, hi, si - s_prefill, di, s_decode);
                            (kd[src], vd[src])
                        };
                        assert!(
                            (fk_data[out] - exp_k).abs() < 1e-6,
                            "k mismatch at [b={bi}, h={hi}, s={si}, d={di}]: got {}, want {exp_k}",
                            fk_data[out]
                        );
                        assert!(
                            (fv_data[out] - exp_v).abs() < 1e-6,
                            "v mismatch at [b={bi}, h={hi}, s={si}, d={di}]: got {}, want {exp_v}",
                            fv_data[out]
                        );
                    }
                }
            }
        }
    }

    // ----- TransformerEncoderLayer -----

    #[test]
    fn test_encoder_layer_construction() {
        let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_encoder_layer_forward_shape() {
        let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_encoder_layer_forward_values_finite() {
        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let output = layer.forward(&input).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "TransformerEncoderLayer produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_encoder_layer_2d_rejected() {
        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        assert!(layer.forward(&input).is_err());
    }

    #[test]
    fn test_encoder_layer_parameters_count() {
        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let params = layer.parameters();
        assert_eq!(params.len(), 18);
    }

    #[test]
    fn test_encoder_layer_train_eval() {
        let mut layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
        assert!(layer.is_training());
        layer.eval();
        assert!(!layer.is_training());
        layer.train();
        assert!(layer.is_training());
    }

    #[test]
    fn test_encoder_layer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransformerEncoderLayer<f32>>();
        assert_send_sync::<TransformerEncoderLayer<f64>>();
    }

    // ----- TransformerDecoderLayer -----

    #[test]
    fn test_decoder_layer_construction() {
        let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_decoder_layer_forward_shape() {
        let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
        let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
        let output = layer.forward_with_memory(&tgt, &memory).unwrap();
        // Output follows the target sequence length, not the memory's.
        assert_eq!(output.shape(), &[2, 4, 16]);
    }

    #[test]
    fn test_decoder_layer_self_forward_shape() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_decoder_layer_forward_values_finite() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
        let output = layer.forward_with_memory(&tgt, &mem).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "TransformerDecoderLayer produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_decoder_layer_2d_rejected() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        assert!(layer.forward_with_memory(&input, &memory).is_err());
    }

    #[test]
    fn test_decoder_layer_parameters_count() {
        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
        let params = layer.parameters();
        assert_eq!(params.len(), 28);
    }

    #[test]
    fn test_decoder_layer_train_eval() {
        let mut layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
        assert!(layer.is_training());
        layer.eval();
        assert!(!layer.is_training());
        layer.train();
        assert!(layer.is_training());
    }

    #[test]
    fn test_decoder_layer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransformerDecoderLayer<f32>>();
        assert_send_sync::<TransformerDecoderLayer<f64>>();
    }

    // ----- TransformerEncoder -----

    #[test]
    fn test_encoder_construction() {
        let enc = TransformerEncoder::<f32>::new(16, 4, 3, 32, 0.0, 1e-5, true, true);
        assert!(enc.is_ok());
        assert_eq!(enc.unwrap().num_layers(), 3);
    }

    #[test]
    fn test_encoder_zero_layers_rejected() {
        assert!(TransformerEncoder::<f32>::new(16, 4, 0, 32, 0.0, 1e-5, true, true).is_err());
    }

    #[test]
    fn test_encoder_forward_shape() {
        let enc = TransformerEncoder::<f32>::new(16, 4, 2, 32, 0.0, 1e-5, false, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = enc.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_encoder_forward_no_final_norm() {
        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, false, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let output = enc.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_encoder_forward_values_finite() {
        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let output = enc.forward(&input).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "TransformerEncoder produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_encoder_parameters_with_final_norm() {
        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
        assert_eq!(enc.parameters().len(), 38);
    }

    #[test]
    fn test_encoder_named_parameters_have_layer_prefix() {
        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
        let named = enc.named_parameters();
        let has_layer_0 = named.iter().any(|(n, _)| n.starts_with("layers.0."));
        let has_layer_1 = named.iter().any(|(n, _)| n.starts_with("layers.1."));
        let has_norm = named.iter().any(|(n, _)| n.starts_with("norm."));
        assert!(has_layer_0, "missing layers.0.* in named_parameters");
        assert!(has_layer_1, "missing layers.1.* in named_parameters");
        assert!(has_norm, "missing norm.* in named_parameters");
    }

    #[test]
    fn test_encoder_train_eval() {
        let mut enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.1, 1e-5, false, false).unwrap();
        assert!(enc.is_training());
        enc.eval();
        assert!(!enc.is_training());
        enc.train();
        assert!(enc.is_training());
    }

    #[test]
    fn test_encoder_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransformerEncoder<f32>>();
        assert_send_sync::<TransformerEncoder<f64>>();
    }

    // ----- TransformerDecoder -----

    #[test]
    fn test_decoder_construction() {
        let dec = TransformerDecoder::<f32>::new(16, 4, 3, 32, 0.0, 1e-5, true, true);
        assert!(dec.is_ok());
        assert_eq!(dec.unwrap().num_layers(), 3);
    }

    #[test]
    fn test_decoder_zero_layers_rejected() {
        assert!(TransformerDecoder::<f32>::new(16, 4, 0, 32, 0.0, 1e-5, true, true).is_err());
    }

    #[test]
    fn test_decoder_forward_with_memory_shape() {
        let dec = TransformerDecoder::<f32>::new(16, 4, 2, 32, 0.0, 1e-5, false, true).unwrap();
        let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
        let output = dec.forward_with_memory(&tgt, &memory).unwrap();
        assert_eq!(output.shape(), &[2, 4, 16]);
    }

    #[test]
    fn test_decoder_forward_values_finite() {
        let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
        let output = dec.forward_with_memory(&tgt, &mem).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "TransformerDecoder produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_decoder_parameters_with_final_norm() {
        let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
        assert_eq!(dec.parameters().len(), 58);
    }

    #[test]
    fn test_decoder_named_parameters_have_layer_prefix() {
        let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
        let named = dec.named_parameters();
        let has_layer_0 = named.iter().any(|(n, _)| n.starts_with("layers.0."));
        let has_layer_1 = named.iter().any(|(n, _)| n.starts_with("layers.1."));
        let has_norm = named.iter().any(|(n, _)| n.starts_with("norm."));
        assert!(has_layer_0, "missing layers.0.* in named_parameters");
        assert!(has_layer_1, "missing layers.1.* in named_parameters");
        assert!(has_norm, "missing norm.* in named_parameters");
    }

    #[test]
    fn test_decoder_train_eval() {
        let mut dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.1, 1e-5, false, false).unwrap();
        assert!(dec.is_training());
        dec.eval();
        assert!(!dec.is_training());
        dec.train();
        assert!(dec.is_training());
    }

    #[test]
    fn test_decoder_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransformerDecoder<f32>>();
        assert_send_sync::<TransformerDecoder<f64>>();
    }

    // ----- Transformer (encoder + decoder) -----

    #[test]
    fn test_transformer_construction() {
        let t = Transformer::<f32>::new(16, 4, 2, 2, 32, 0.0, 1e-5, true);
        assert!(t.is_ok());
        let t = t.unwrap();
        assert_eq!(t.num_encoder_layers(), 2);
        assert_eq!(t.num_decoder_layers(), 2);
    }

    #[test]
    fn test_transformer_forward_shape() {
        let t = Transformer::<f32>::new(16, 4, 2, 2, 32, 0.0, 1e-5, false).unwrap();
        let src = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
        let tgt = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = t.forward_transformer(&src, &tgt).unwrap();
        // Output follows the target sequence length.
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_transformer_self_forward_shape() {
        let t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.0, 1e-5, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let output = t.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_transformer_forward_values_finite() {
        let t = Transformer::<f32>::new(8, 2, 2, 2, 16, 0.0, 1e-5, true).unwrap();
        let src = ferrotorch_core::ones::<f32>(&[1, 4, 8]).unwrap();
        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
        let output = t.forward_transformer(&src, &tgt).unwrap();
        for &v in output.data().unwrap() {
            assert!(
                v.is_finite(),
                "Transformer produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_transformer_parameters_count() {
        // 96 = 38 + 58, presumably the encoder and decoder counts asserted
        // above for the same (d_model, heads, layers, ff) configuration.
        let t = Transformer::<f32>::new(8, 2, 2, 2, 16, 0.0, 1e-5, true).unwrap();
        assert_eq!(t.parameters().len(), 96);
    }

    #[test]
    fn test_transformer_named_parameters_prefixed() {
        let t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.0, 1e-5, true).unwrap();
        let named = t.named_parameters();
        let has_encoder = named.iter().any(|(n, _)| n.starts_with("encoder."));
        let has_decoder = named.iter().any(|(n, _)| n.starts_with("decoder."));
        assert!(has_encoder, "missing encoder.* in named_parameters");
        assert!(has_decoder, "missing decoder.* in named_parameters");
    }

    #[test]
    fn test_transformer_train_eval() {
        let mut t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.1, 1e-5, false).unwrap();
        assert!(t.is_training());
        t.eval();
        assert!(!t.is_training());
        t.train();
        assert!(t.is_training());
    }

    #[test]
    fn test_transformer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Transformer<f32>>();
        assert_send_sync::<Transformer<f64>>();
    }
}