1use std::sync::Arc;
40
41use ferrotorch_core::autograd::no_grad::is_grad_enabled;
42use ferrotorch_core::grad_fns::activation::silu;
43use ferrotorch_core::grad_fns::arithmetic::{add, mul};
44use ferrotorch_core::grad_fns::shape::reshape;
45use ferrotorch_core::tensor::GradFn;
46use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
47
48use crate::attention::MultiheadAttention;
49use crate::dropout::Dropout;
50use crate::linear::Linear;
51use crate::module::Module;
52use crate::norm::LayerNorm;
53use crate::parameter::Parameter;
54
/// How the last (feature) dimension is split into the coordinate pairs that
/// RoPE rotates together.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RoPEConvention {
    /// Pairs adjacent features `(2i, 2i + 1)`. This is the default.
    #[default]
    Interleaved,
    /// Pairs feature `i` of the first half with feature `i + dim / 2` of the
    /// second half.
    HalfRotation,
}
73
74impl std::fmt::Display for RoPEConvention {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 match self {
77 RoPEConvention::Interleaved => write!(f, "interleaved"),
78 RoPEConvention::HalfRotation => write!(f, "half_rotation"),
79 }
80 }
81}
82
/// Autograd node recorded by [`RotaryPositionEmbedding::apply`].
///
/// Stores the forward input plus flattened cos/sin values so that `backward`
/// can apply the transposed rotation to the incoming gradient.
#[derive(Debug)]
struct RoPEBackward<T: Float> {
    /// Forward input; its `requires_grad`, shape and device drive `backward`.
    input: Tensor<T>,
    /// Flattened cos values, read at `(seq_offset + s) * half_dim + i`.
    cos_flat: Vec<T>,
    /// Flattened sin values, same layout as `cos_flat`.
    sin_flat: Vec<T>,
    /// `dim / 2`: number of rotated pairs per row.
    half_dim: usize,
    /// Size of the sequence axis (second-to-last dim of `input`).
    seq_len: usize,
    /// Product of all `input` dims before the sequence axis.
    batch_dims: usize,
    /// Size of the last (feature) dim of `input`.
    dim: usize,
    /// Row of `cos_flat` / `sin_flat` that corresponds to sequence index 0.
    seq_offset: usize,
    /// Pairing layout used by the forward rotation.
    convention: RoPEConvention,
}
104
105impl<T: Float> GradFn<T> for RoPEBackward<T> {
106 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
107 let da = if self.input.requires_grad() {
108 let go_data = grad_output.data_vec()?;
109 let total = go_data.len();
110 let mut grad_input = Vec::with_capacity(total);
111
112 match self.convention {
113 RoPEConvention::Interleaved => {
114 for b in 0..self.batch_dims {
115 for s in 0..self.seq_len {
116 let cache_start = (self.seq_offset + s) * self.half_dim;
117 let go_start = b * self.seq_len * self.dim + s * self.dim;
118
119 for i in 0..self.half_dim {
120 let go_even = go_data[go_start + 2 * i];
121 let go_odd = go_data[go_start + 2 * i + 1];
122 let cos_val = self.cos_flat[cache_start + i];
123 let sin_val = self.sin_flat[cache_start + i];
124
125 grad_input.push(go_even * cos_val + go_odd * sin_val);
127 grad_input.push(-go_even * sin_val + go_odd * cos_val);
128 }
129 }
130 }
131 }
132 RoPEConvention::HalfRotation => {
133 for b in 0..self.batch_dims {
134 for s in 0..self.seq_len {
135 let cache_start = (self.seq_offset + s) * self.half_dim;
136 let go_start = b * self.seq_len * self.dim + s * self.dim;
137
138 for i in 0..self.half_dim {
140 let go_first = go_data[go_start + i];
141 let go_second = go_data[go_start + self.half_dim + i];
142 let cos_val = self.cos_flat[cache_start + i];
143 let sin_val = self.sin_flat[cache_start + i];
144
145 grad_input.push(go_first * cos_val + go_second * sin_val);
146 }
147 for i in 0..self.half_dim {
149 let go_first = go_data[go_start + i];
150 let go_second = go_data[go_start + self.half_dim + i];
151 let cos_val = self.cos_flat[cache_start + i];
152 let sin_val = self.sin_flat[cache_start + i];
153
154 grad_input.push(-go_first * sin_val + go_second * cos_val);
155 }
156 }
157 }
158 }
159 }
160
161 let g = Tensor::from_storage(
162 TensorStorage::cpu(grad_input),
163 self.input.shape().to_vec(),
164 false,
165 )?;
166 Some(if self.input.is_cuda() {
167 g.to(self.input.device())?
168 } else {
169 g
170 })
171 } else {
172 None
173 };
174 Ok(vec![da])
175 }
176
177 fn inputs(&self) -> Vec<&Tensor<T>> {
178 vec![&self.input]
179 }
180
181 fn name(&self) -> &'static str {
182 "RoPEBackward"
183 }
184}
185
/// Rotary position embedding (RoPE) with precomputed cos/sin tables.
///
/// Inputs are rotated by [`RotaryPositionEmbedding::apply`].
#[derive(Debug)]
pub struct RotaryPositionEmbedding<T: Float> {
    /// Size of the rotated (last) dimension; `with_scaling` requires it to be
    /// even and non-zero.
    dim: usize,
    /// Number of positions covered by the cached tables.
    max_seq_len: usize,
    /// Frequency base passed at construction.
    base: f64,
    /// Pairing layout used by the rotation.
    convention: RoPEConvention,
    /// Frequency-scaling strategy used to build the tables.
    scaling: RoPEScaling,
    /// `cos(pos * theta_i)`, shape `[max_seq_len, dim / 2]`.
    cos_cache: Tensor<T>,
    /// `sin(pos * theta_i)`, shape `[max_seq_len, dim / 2]`.
    sin_cache: Tensor<T>,
}
220
/// Strategy used to rescale the RoPE inverse frequencies.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RoPEScaling {
    /// No scaling: `theta_i = base^(-2i / dim)`. This is the default.
    #[default]
    None,

    /// Divides every inverse frequency by `factor`.
    Linear {
        /// Divisor; `RotaryPositionEmbedding::with_scaling` rejects
        /// non-finite or non-positive values.
        factor: f64,
    },

    /// NTK-aware scaling: frequencies are recomputed with the base replaced by
    /// `base * factor^(dim / (dim - 2))`.
    NtkAware {
        /// Scaling factor; `with_scaling` rejects non-finite or non-positive
        /// values.
        factor: f64,
        /// NOTE(review): not read by `compute_scaled_inv_freq` in this file.
        original_max_pos_embeddings: usize,
    },

    /// YaRN: per-pair blend between the original frequencies and the
    /// frequencies divided by `factor`, ramped between two correction bounds.
    Yarn {
        /// Interpolation factor; `with_scaling` rejects non-finite or
        /// non-positive values.
        factor: f64,
        /// Position count used to locate the ramp's correction bounds.
        original_max_pos_embeddings: usize,
        /// Rotation count that defines the lower ramp bound
        /// (`yarn_default` uses 32.0).
        beta_fast: f64,
        /// Rotation count that defines the upper ramp bound
        /// (`yarn_default` uses 1.0).
        beta_slow: f64,
    },
}
280
281impl RoPEScaling {
282 pub const fn yarn_default(factor: f64, original_max_pos_embeddings: usize) -> Self {
284 RoPEScaling::Yarn {
285 factor,
286 original_max_pos_embeddings,
287 beta_fast: 32.0,
288 beta_slow: 1.0,
289 }
290 }
291}
292
/// Returns the (fractional) pair index `i` whose RoPE frequency
/// `base^(-2i / dim)` completes `num_rotations` full turns over
/// `original_max_pos_embeddings` positions:
/// `dim * ln(L / (2*pi*num_rotations)) / (2 * ln(base))`.
fn yarn_find_correction_dim(
    num_rotations: f64,
    dim: usize,
    base: f64,
    original_max_pos_embeddings: usize,
) -> f64 {
    let turn_length = num_rotations * 2.0 * std::f64::consts::PI;
    let log_ratio = (original_max_pos_embeddings as f64 / turn_length).ln();
    dim as f64 * log_ratio / (2.0 * base.ln())
}
307
308fn yarn_find_correction_range(
310 low_rot: f64,
311 high_rot: f64,
312 dim: usize,
313 base: f64,
314 original_max_pos_embeddings: usize,
315) -> (f64, f64) {
316 let low = yarn_find_correction_dim(low_rot, dim, base, original_max_pos_embeddings).floor();
317 let high = yarn_find_correction_dim(high_rot, dim, base, original_max_pos_embeddings).ceil();
318 (low.max(0.0), high.min((dim - 1) as f64))
319}
320
/// Unscaled RoPE inverse frequencies `theta_i = base^(-2i / dim)` for
/// `i in 0..dim / 2`.
fn compute_base_inv_freq(dim: usize, base: f64) -> Vec<f64> {
    let width = dim as f64;
    (0..dim / 2)
        .map(|pair| base.powf(2.0 * pair as f64 / width).recip())
        .collect()
}
328
329pub(crate) fn compute_scaled_inv_freq(dim: usize, base: f64, scaling: RoPEScaling) -> Vec<f64> {
334 match scaling {
335 RoPEScaling::None => compute_base_inv_freq(dim, base),
336
337 RoPEScaling::Linear { factor } => {
338 let mut iv = compute_base_inv_freq(dim, base);
339 for v in iv.iter_mut() {
340 *v /= factor;
341 }
342 iv
343 }
344
345 RoPEScaling::NtkAware { factor, .. } => {
346 let exp = dim as f64 / (dim as f64 - 2.0);
351 let base_scaled = base * factor.powf(exp);
352 compute_base_inv_freq(dim, base_scaled)
353 }
354
355 RoPEScaling::Yarn {
356 factor,
357 original_max_pos_embeddings,
358 beta_fast,
359 beta_slow,
360 } => {
361 let half = dim / 2;
362 let pos_freqs: Vec<f64> = (0..half)
363 .map(|i| base.powf(2.0 * i as f64 / dim as f64))
364 .collect();
365 let extrapolation: Vec<f64> = pos_freqs.iter().map(|p| 1.0 / p).collect();
366 let interpolation: Vec<f64> = pos_freqs.iter().map(|p| 1.0 / (factor * p)).collect();
367
368 let (low, high) = yarn_find_correction_range(
369 beta_fast,
370 beta_slow,
371 dim,
372 base,
373 original_max_pos_embeddings,
374 );
375 let (low, high) = (low / 2.0, high / 2.0);
378
379 let denom = if high == low { 0.001 } else { high - low };
384 (0..half)
385 .map(|i| {
386 let t = ((i as f64 - low) / denom).clamp(0.0, 1.0);
387 let mask = 1.0 - t;
389 interpolation[i] * (1.0 - mask) + extrapolation[i] * mask
390 })
391 .collect()
392 }
393 }
394}
395
396impl<T: Float> RotaryPositionEmbedding<T> {
397 pub fn new(dim: usize, max_seq_len: usize, base: f64) -> FerrotorchResult<Self> {
410 Self::with_scaling(
411 dim,
412 max_seq_len,
413 base,
414 RoPEConvention::default(),
415 RoPEScaling::None,
416 )
417 }
418
419 pub fn with_convention(
425 dim: usize,
426 max_seq_len: usize,
427 base: f64,
428 convention: RoPEConvention,
429 ) -> FerrotorchResult<Self> {
430 Self::with_scaling(dim, max_seq_len, base, convention, RoPEScaling::None)
431 }
432
433 pub fn with_scaling(
439 dim: usize,
440 max_seq_len: usize,
441 base: f64,
442 convention: RoPEConvention,
443 scaling: RoPEScaling,
444 ) -> FerrotorchResult<Self> {
445 if dim == 0 || dim % 2 != 0 {
446 return Err(FerrotorchError::InvalidArgument {
447 message: format!("RoPE dim must be even and positive, got {dim}"),
448 });
449 }
450 if max_seq_len == 0 {
451 return Err(FerrotorchError::InvalidArgument {
452 message: "RoPE max_seq_len must be positive".into(),
453 });
454 }
455 if let RoPEScaling::Linear { factor }
456 | RoPEScaling::NtkAware { factor, .. }
457 | RoPEScaling::Yarn { factor, .. } = scaling
458 {
459 if !(factor.is_finite() && factor > 0.0) {
460 return Err(FerrotorchError::InvalidArgument {
461 message: format!("RoPE scaling factor must be finite and > 0, got {factor}"),
462 });
463 }
464 }
465
466 let half_dim = dim / 2;
467 let thetas = compute_scaled_inv_freq(dim, base, scaling);
468
469 let total = max_seq_len * half_dim;
472 let mut cos_data = Vec::with_capacity(total);
473 let mut sin_data = Vec::with_capacity(total);
474
475 for pos in 0..max_seq_len {
476 for &theta in &thetas {
477 let angle = pos as f64 * theta;
478 cos_data.push(T::from(angle.cos()).unwrap());
479 sin_data.push(T::from(angle.sin()).unwrap());
480 }
481 }
482
483 let cos_cache = Tensor::from_storage(
484 TensorStorage::cpu(cos_data),
485 vec![max_seq_len, half_dim],
486 false,
487 )?;
488 let sin_cache = Tensor::from_storage(
489 TensorStorage::cpu(sin_data),
490 vec![max_seq_len, half_dim],
491 false,
492 )?;
493
494 Ok(Self {
495 dim,
496 max_seq_len,
497 base,
498 convention,
499 scaling,
500 cos_cache,
501 sin_cache,
502 })
503 }
504
505 pub fn apply(&self, x: &Tensor<T>, seq_offset: usize) -> FerrotorchResult<Tensor<T>> {
521 let shape = x.shape();
522 let ndim = shape.len();
523 if ndim < 2 {
524 return Err(FerrotorchError::InvalidArgument {
525 message: format!(
526 "RoPE input must be at least 2-D, got {ndim}-D with shape {shape:?}"
527 ),
528 });
529 }
530
531 let last_dim = shape[ndim - 1];
532 if last_dim != self.dim {
533 return Err(FerrotorchError::ShapeMismatch {
534 message: format!("RoPE: last dim of input ({last_dim}) != dim ({})", self.dim),
535 });
536 }
537
538 let seq_len = shape[ndim - 2];
539 if seq_offset + seq_len > self.max_seq_len {
540 return Err(FerrotorchError::InvalidArgument {
541 message: format!(
542 "RoPE: seq_offset ({seq_offset}) + seq_len ({seq_len}) > max_seq_len ({})",
543 self.max_seq_len
544 ),
545 });
546 }
547
548 let device = x.device();
549 let half_dim = self.dim / 2;
550 let cos_data = self.cos_cache.data_vec()?;
551 let sin_data = self.sin_cache.data_vec()?;
552 let x_data = x.data_vec()?;
553
554 let batch_dims: usize = shape[..ndim - 2].iter().product();
556
557 let total = x.numel();
558 let mut output = Vec::with_capacity(total);
559
560 match self.convention {
561 RoPEConvention::Interleaved => {
562 for b in 0..batch_dims {
564 for s in 0..seq_len {
565 let pos = seq_offset + s;
566 let cache_start = pos * half_dim;
567 let x_start = b * seq_len * self.dim + s * self.dim;
568
569 for i in 0..half_dim {
570 let x_even = x_data[x_start + 2 * i];
571 let x_odd = x_data[x_start + 2 * i + 1];
572 let cos_val = cos_data[cache_start + i];
573 let sin_val = sin_data[cache_start + i];
574
575 output.push(x_even * cos_val - x_odd * sin_val);
576 output.push(x_even * sin_val + x_odd * cos_val);
577 }
578 }
579 }
580 }
581 RoPEConvention::HalfRotation => {
582 for b in 0..batch_dims {
585 for s in 0..seq_len {
586 let pos = seq_offset + s;
587 let cache_start = pos * half_dim;
588 let x_start = b * seq_len * self.dim + s * self.dim;
589
590 for i in 0..half_dim {
592 let x_first = x_data[x_start + i];
593 let x_second = x_data[x_start + half_dim + i];
594 let cos_val = cos_data[cache_start + i];
595 let sin_val = sin_data[cache_start + i];
596
597 output.push(x_first * cos_val - x_second * sin_val);
598 }
599 for i in 0..half_dim {
601 let x_first = x_data[x_start + i];
602 let x_second = x_data[x_start + half_dim + i];
603 let cos_val = cos_data[cache_start + i];
604 let sin_val = sin_data[cache_start + i];
605
606 output.push(x_first * sin_val + x_second * cos_val);
607 }
608 }
609 }
610 }
611 }
612
613 let result = if is_grad_enabled() && x.requires_grad() {
614 Tensor::from_operation(
615 TensorStorage::cpu(output),
616 shape.to_vec(),
617 Arc::new(RoPEBackward {
618 input: x.clone(),
619 cos_flat: cos_data,
620 sin_flat: sin_data,
621 half_dim,
622 seq_len,
623 batch_dims,
624 dim: self.dim,
625 seq_offset,
626 convention: self.convention,
627 }),
628 )?
629 } else {
630 Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?
631 };
632 if device.is_cuda() {
633 result.to(device)
634 } else {
635 Ok(result)
636 }
637 }
638
639 #[inline]
641 pub fn dim(&self) -> usize {
642 self.dim
643 }
644
645 #[inline]
647 pub fn max_seq_len(&self) -> usize {
648 self.max_seq_len
649 }
650
651 #[inline]
653 pub fn base(&self) -> f64 {
654 self.base
655 }
656
657 #[inline]
659 pub fn convention(&self) -> RoPEConvention {
660 self.convention
661 }
662
663 #[inline]
665 pub fn scaling(&self) -> RoPEScaling {
666 self.scaling
667 }
668}
669
/// SwiGLU feed-forward block computing `w3(silu(w1(x)) * w2(x))`.
#[derive(Debug)]
pub struct SwiGLU<T: Float> {
    /// Gate projection `in_features -> hidden_features`, passed through SiLU.
    w1: Linear<T>,
    /// Up projection `in_features -> hidden_features`.
    w2: Linear<T>,
    /// Down projection `hidden_features -> in_features`.
    w3: Linear<T>,
    /// Train/eval flag reported by `is_training`.
    training: bool,
}
703
704impl<T: Float> SwiGLU<T> {
705 pub fn new(in_features: usize, hidden_features: usize, bias: bool) -> FerrotorchResult<Self> {
714 let w1 = Linear::new(in_features, hidden_features, bias)?;
715 let w2 = Linear::new(in_features, hidden_features, bias)?;
716 let w3 = Linear::new(hidden_features, in_features, bias)?;
717
718 Ok(Self {
719 w1,
720 w2,
721 w3,
722 training: true,
723 })
724 }
725
726 fn forward_3d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
731 let shape = input.shape();
732 let batch = shape[0];
733 let seq_len = shape[1];
734
735 let flat = reshape(input, &[(batch * seq_len) as isize, -1])?;
737
738 let output_flat = self.forward_2d(&flat)?;
739
740 let out_features = output_flat.shape()[1];
742 reshape(
743 &output_flat,
744 &[batch as isize, seq_len as isize, out_features as isize],
745 )
746 }
747
748 fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
750 let w1_out = self.w1.forward(input)?;
752 let gate = silu(&w1_out)?;
753
754 let up = self.w2.forward(input)?;
756
757 let gated = mul(&gate, &up)?;
759
760 self.w3.forward(&gated)
762 }
763}
764
765impl<T: Float> Module<T> for SwiGLU<T> {
766 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
767 match input.ndim() {
768 2 => self.forward_2d(input),
769 3 => self.forward_3d(input),
770 _ => Err(FerrotorchError::InvalidArgument {
771 message: format!(
772 "SwiGLU expects 2-D or 3-D input, got {}-D with shape {:?}",
773 input.ndim(),
774 input.shape()
775 ),
776 }),
777 }
778 }
779
780 fn parameters(&self) -> Vec<&Parameter<T>> {
781 let mut params = self.w1.parameters();
782 params.extend(self.w2.parameters());
783 params.extend(self.w3.parameters());
784 params
785 }
786
787 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
788 let mut params = self.w1.parameters_mut();
789 params.extend(self.w2.parameters_mut());
790 params.extend(self.w3.parameters_mut());
791 params
792 }
793
794 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
795 let mut params = Vec::new();
796 for (name, param) in self.w1.named_parameters() {
797 params.push((format!("w1.{name}"), param));
798 }
799 for (name, param) in self.w2.named_parameters() {
800 params.push((format!("w2.{name}"), param));
801 }
802 for (name, param) in self.w3.named_parameters() {
803 params.push((format!("w3.{name}"), param));
804 }
805 params
806 }
807
808 fn train(&mut self) {
809 self.training = true;
810 self.w1.train();
811 self.w2.train();
812 self.w3.train();
813 }
814
815 fn eval(&mut self) {
816 self.training = false;
817 self.w1.eval();
818 self.w2.eval();
819 self.w3.eval();
820 }
821
822 fn is_training(&self) -> bool {
823 self.training
824 }
825}
826
/// Non-sequence dimensions that a [`KVCache`] checks every update against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CacheDims {
    /// Dim 0 of the `[B, kv_heads, seq, dim]` tensors.
    batch: usize,
    /// Dim 1 of the `[B, kv_heads, seq, dim]` tensors.
    num_kv_heads: usize,
    /// Dim 3 of the `[B, kv_heads, seq, dim]` tensors.
    head_dim: usize,
}
840
/// Key/value cache that grows along the sequence axis.
///
/// Keys and values are `[B, kv_heads, seq, dim]` tensors. Each
/// [`KVCache::update`] concatenates along dim 2.
#[derive(Debug)]
pub struct KVCache<T: Float> {
    /// Accumulated keys, or `None` when empty.
    key_cache: Option<Tensor<T>>,
    /// Accumulated values, or `None` when empty.
    value_cache: Option<Tensor<T>>,
    /// Maximum accumulated sequence length accepted by `update`.
    max_seq_len: usize,
    /// Pinned non-sequence dims; set by `with_dims` or by the first `update`.
    dims: Option<CacheDims>,
}
876
877impl<T: Float> KVCache<T> {
878 pub fn new(max_seq_len: usize) -> Self {
882 Self {
883 key_cache: None,
884 value_cache: None,
885 max_seq_len,
886 dims: None,
887 }
888 }
889
890 pub fn with_dims(
898 max_seq_len: usize,
899 batch: usize,
900 num_kv_heads: usize,
901 head_dim: usize,
902 ) -> Self {
903 Self {
904 key_cache: None,
905 value_cache: None,
906 max_seq_len,
907 dims: Some(CacheDims {
908 batch,
909 num_kv_heads,
910 head_dim,
911 }),
912 }
913 }
914
915 pub fn update(
930 &mut self,
931 key: Tensor<T>,
932 value: Tensor<T>,
933 ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
934 if key.ndim() != 4 || value.ndim() != 4 {
935 return Err(FerrotorchError::InvalidArgument {
936 message: format!(
937 "KVCache expects 4-D [B, kv_heads, seq, dim] tensors, \
938 got key {:?}, value {:?}",
939 key.shape(),
940 value.shape()
941 ),
942 });
943 }
944
945 if key.shape() != value.shape() {
946 return Err(FerrotorchError::ShapeMismatch {
947 message: format!(
948 "KVCache: key shape {:?} != value shape {:?}",
949 key.shape(),
950 value.shape()
951 ),
952 });
953 }
954
955 let ks = key.shape();
956 let incoming = CacheDims {
957 batch: ks[0],
958 num_kv_heads: ks[1],
959 head_dim: ks[3],
960 };
961
962 match &self.dims {
963 Some(expected) if expected != &incoming => {
964 return Err(FerrotorchError::ShapeMismatch {
965 message: format!(
966 "KVCache: update shape [B={}, kv_heads={}, _, dim={}] does not \
967 match pinned dims [B={}, kv_heads={}, _, dim={}]",
968 incoming.batch,
969 incoming.num_kv_heads,
970 incoming.head_dim,
971 expected.batch,
972 expected.num_kv_heads,
973 expected.head_dim,
974 ),
975 });
976 }
977 None => self.dims = Some(incoming),
978 _ => {}
979 }
980
981 let (full_key, full_value) = match (&self.key_cache, &self.value_cache) {
982 (Some(ck), Some(cv)) => {
983 let fk = concat_along_dim2(ck, &key)?;
984 let fv = concat_along_dim2(cv, &value)?;
985 (fk, fv)
986 }
987 _ => (key.clone(), value.clone()),
988 };
989
990 let total_seq = full_key.shape()[2];
992 if total_seq > self.max_seq_len {
993 return Err(FerrotorchError::InvalidArgument {
994 message: format!(
995 "KVCache: total sequence length ({total_seq}) exceeds max_seq_len ({})",
996 self.max_seq_len
997 ),
998 });
999 }
1000
1001 self.key_cache = Some(full_key.clone());
1002 self.value_cache = Some(full_value.clone());
1003
1004 Ok((full_key, full_value))
1005 }
1006
1007 pub fn reset(&mut self) {
1012 self.key_cache = None;
1013 self.value_cache = None;
1014 }
1015
1016 pub fn seq_len(&self) -> usize {
1018 self.key_cache.as_ref().map(|k| k.shape()[2]).unwrap_or(0)
1019 }
1020
1021 pub fn is_empty(&self) -> bool {
1023 self.key_cache.is_none()
1024 }
1025
1026 #[inline]
1028 pub fn max_seq_len(&self) -> usize {
1029 self.max_seq_len
1030 }
1031
1032 pub fn num_kv_heads(&self) -> Option<usize> {
1037 self.dims.map(|d| d.num_kv_heads)
1038 }
1039
1040 pub fn head_dim(&self) -> Option<usize> {
1042 self.dims.map(|d| d.head_dim)
1043 }
1044
1045 pub fn batch_size(&self) -> Option<usize> {
1047 self.dims.map(|d| d.batch)
1048 }
1049}
1050
1051fn concat_along_dim2<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1055 let sa = a.shape();
1056 let sb = b.shape();
1057
1058 if sa[0] != sb[0] || sa[1] != sb[1] || sa[3] != sb[3] {
1059 return Err(FerrotorchError::ShapeMismatch {
1060 message: format!(
1061 "concat_along_dim2: shapes {:?} and {:?} must match on dims 0, 1, 3",
1062 sa, sb
1063 ),
1064 });
1065 }
1066
1067 let device = a.device();
1068 let (batch, heads, seq_a, dim) = (sa[0], sa[1], sa[2], sa[3]);
1069 let seq_b = sb[2];
1070 let seq_out = seq_a + seq_b;
1071
1072 let a_data = a.data_vec()?;
1073 let b_data = b.data_vec()?;
1074
1075 let mut output = Vec::with_capacity(batch * heads * seq_out * dim);
1076
1077 for ba in 0..batch {
1078 for h in 0..heads {
1079 let a_start = (ba * heads + h) * seq_a * dim;
1081 output.extend_from_slice(&a_data[a_start..a_start + seq_a * dim]);
1082 let b_start = (ba * heads + h) * seq_b * dim;
1084 output.extend_from_slice(&b_data[b_start..b_start + seq_b * dim]);
1085 }
1086 }
1087
1088 let result = Tensor::from_storage(
1089 TensorStorage::cpu(output),
1090 vec![batch, heads, seq_out, dim],
1091 false,
1092 )?;
1093 if device.is_cuda() {
1094 result.to(device)
1095 } else {
1096 Ok(result)
1097 }
1098}
1099
/// Pre-norm transformer encoder layer.
///
/// Computes `h = x + dropout(self_attn(norm1(x)))`, then
/// `h + dropout(ffn(norm2(h)))`.
#[derive(Debug)]
pub struct TransformerEncoderLayer<T: Float> {
    /// Self-attention sub-layer.
    self_attn: MultiheadAttention<T>,
    /// SwiGLU feed-forward sub-layer.
    ffn: SwiGLU<T>,
    /// LayerNorm applied before self-attention.
    norm1: LayerNorm<T>,
    /// LayerNorm applied before the feed-forward block.
    norm2: LayerNorm<T>,
    /// Dropout module used on both residual branches.
    dropout: Dropout<T>,
    /// Train/eval flag reported by `is_training`.
    training: bool,
}
1129
1130impl<T: Float> TransformerEncoderLayer<T> {
1131 pub fn new(
1142 d_model: usize,
1143 num_heads: usize,
1144 d_ff: usize,
1145 dropout_p: f64,
1146 layer_norm_eps: f64,
1147 bias: bool,
1148 ) -> FerrotorchResult<Self> {
1149 let self_attn = MultiheadAttention::new(d_model, num_heads, bias)?;
1150 let ffn = SwiGLU::new(d_model, d_ff, bias)?;
1151 let norm1 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1152 let norm2 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1153 let dropout = Dropout::new(dropout_p)?;
1154
1155 Ok(Self {
1156 self_attn,
1157 ffn,
1158 norm1,
1159 norm2,
1160 dropout,
1161 training: true,
1162 })
1163 }
1164
1165 pub fn set_dropout_p(&mut self, p: f64) -> FerrotorchResult<()> {
1170 self.dropout.set_p(p)
1171 }
1172
1173 pub fn dropout_p(&self) -> f64 {
1175 self.dropout.p()
1176 }
1177}
1178
1179impl<T: Float> Module<T> for TransformerEncoderLayer<T> {
1180 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1184 if input.ndim() != 3 {
1185 return Err(FerrotorchError::InvalidArgument {
1186 message: format!(
1187 "TransformerEncoderLayer expects 3-D [batch, seq, d_model], got {:?}",
1188 input.shape()
1189 ),
1190 });
1191 }
1192
1193 let normed1 = self.norm1.forward(input)?;
1195 let attn_out = self.self_attn.forward(&normed1)?;
1196 let attn_out = self.dropout.forward(&attn_out)?;
1197 let residual1 = add(input, &attn_out)?;
1198
1199 let normed2 = self.norm2.forward(&residual1)?;
1201 let ffn_out = self.ffn.forward(&normed2)?;
1202 let ffn_out = self.dropout.forward(&ffn_out)?;
1203 let residual2 = add(&residual1, &ffn_out)?;
1204
1205 Ok(residual2)
1206 }
1207
1208 fn parameters(&self) -> Vec<&Parameter<T>> {
1209 let mut params = Vec::new();
1210 params.extend(self.self_attn.parameters());
1211 params.extend(self.ffn.parameters());
1212 params.extend(self.norm1.parameters());
1213 params.extend(self.norm2.parameters());
1214 params
1216 }
1217
1218 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1219 let mut params = Vec::new();
1220 params.extend(self.self_attn.parameters_mut());
1221 params.extend(self.ffn.parameters_mut());
1222 params.extend(self.norm1.parameters_mut());
1223 params.extend(self.norm2.parameters_mut());
1224 params
1225 }
1226
1227 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1228 let mut params = Vec::new();
1229 for (name, param) in self.self_attn.named_parameters() {
1230 params.push((format!("self_attn.{name}"), param));
1231 }
1232 for (name, param) in self.ffn.named_parameters() {
1233 params.push((format!("ffn.{name}"), param));
1234 }
1235 for (name, param) in self.norm1.named_parameters() {
1236 params.push((format!("norm1.{name}"), param));
1237 }
1238 for (name, param) in self.norm2.named_parameters() {
1239 params.push((format!("norm2.{name}"), param));
1240 }
1241 params
1242 }
1243
1244 fn train(&mut self) {
1245 self.training = true;
1246 self.self_attn.train();
1247 self.ffn.train();
1248 self.norm1.train();
1249 self.norm2.train();
1250 self.dropout.train();
1251 }
1252
1253 fn eval(&mut self) {
1254 self.training = false;
1255 self.self_attn.eval();
1256 self.ffn.eval();
1257 self.norm1.eval();
1258 self.norm2.eval();
1259 self.dropout.eval();
1260 }
1261
1262 fn is_training(&self) -> bool {
1263 self.training
1264 }
1265}
1266
/// Pre-norm transformer decoder layer.
///
/// Consists of self-attention, cross-attention over a `memory` tensor, and a
/// SwiGLU feed-forward block. Each sub-layer is preceded by its own LayerNorm
/// and added back through a residual connection.
#[derive(Debug)]
pub struct TransformerDecoderLayer<T: Float> {
    /// Self-attention over the decoder input.
    self_attn: MultiheadAttention<T>,
    /// Attention from the decoder stream (queries) to `memory` (keys/values).
    cross_attn: MultiheadAttention<T>,
    /// SwiGLU feed-forward sub-layer.
    ffn: SwiGLU<T>,
    /// LayerNorm before self-attention.
    norm1: LayerNorm<T>,
    /// LayerNorm before cross-attention.
    norm2: LayerNorm<T>,
    /// LayerNorm before the feed-forward block.
    norm3: LayerNorm<T>,
    /// Dropout module used on all three residual branches.
    dropout: Dropout<T>,
    /// Train/eval flag reported by `is_training`.
    training: bool,
}
1300
1301impl<T: Float> TransformerDecoderLayer<T> {
1302 pub fn new(
1313 d_model: usize,
1314 num_heads: usize,
1315 d_ff: usize,
1316 dropout_p: f64,
1317 layer_norm_eps: f64,
1318 bias: bool,
1319 ) -> FerrotorchResult<Self> {
1320 let self_attn = MultiheadAttention::new(d_model, num_heads, bias)?;
1321 let cross_attn = MultiheadAttention::new(d_model, num_heads, bias)?;
1322 let ffn = SwiGLU::new(d_model, d_ff, bias)?;
1323 let norm1 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1324 let norm2 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1325 let norm3 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1326 let dropout = Dropout::new(dropout_p)?;
1327
1328 Ok(Self {
1329 self_attn,
1330 cross_attn,
1331 ffn,
1332 norm1,
1333 norm2,
1334 norm3,
1335 dropout,
1336 training: true,
1337 })
1338 }
1339
1340 pub fn forward_with_memory(
1351 &self,
1352 input: &Tensor<T>,
1353 memory: &Tensor<T>,
1354 ) -> FerrotorchResult<Tensor<T>> {
1355 if input.ndim() != 3 || memory.ndim() != 3 {
1356 return Err(FerrotorchError::InvalidArgument {
1357 message: format!(
1358 "TransformerDecoderLayer expects 3-D inputs, \
1359 got input {:?}, memory {:?}",
1360 input.shape(),
1361 memory.shape()
1362 ),
1363 });
1364 }
1365
1366 let normed1 = self.norm1.forward(input)?;
1368 let self_attn_out = self
1369 .self_attn
1370 .forward_qkv(&normed1, &normed1, &normed1, true)?;
1371 let self_attn_out = self.dropout.forward(&self_attn_out)?;
1372 let residual1 = add(input, &self_attn_out)?;
1373
1374 let normed2 = self.norm2.forward(&residual1)?;
1376 let cross_attn_out = self
1377 .cross_attn
1378 .forward_qkv(&normed2, memory, memory, false)?;
1379 let cross_attn_out = self.dropout.forward(&cross_attn_out)?;
1380 let residual2 = add(&residual1, &cross_attn_out)?;
1381
1382 let normed3 = self.norm3.forward(&residual2)?;
1384 let ffn_out = self.ffn.forward(&normed3)?;
1385 let ffn_out = self.dropout.forward(&ffn_out)?;
1386 let residual3 = add(&residual2, &ffn_out)?;
1387
1388 Ok(residual3)
1389 }
1390}
1391
1392impl<T: Float> Module<T> for TransformerDecoderLayer<T> {
1393 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1398 self.forward_with_memory(input, input)
1399 }
1400
1401 fn parameters(&self) -> Vec<&Parameter<T>> {
1402 let mut params = Vec::new();
1403 params.extend(self.self_attn.parameters());
1404 params.extend(self.cross_attn.parameters());
1405 params.extend(self.ffn.parameters());
1406 params.extend(self.norm1.parameters());
1407 params.extend(self.norm2.parameters());
1408 params.extend(self.norm3.parameters());
1409 params
1410 }
1411
1412 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1413 let mut params = Vec::new();
1414 params.extend(self.self_attn.parameters_mut());
1415 params.extend(self.cross_attn.parameters_mut());
1416 params.extend(self.ffn.parameters_mut());
1417 params.extend(self.norm1.parameters_mut());
1418 params.extend(self.norm2.parameters_mut());
1419 params.extend(self.norm3.parameters_mut());
1420 params
1421 }
1422
1423 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1424 let mut params = Vec::new();
1425 for (name, param) in self.self_attn.named_parameters() {
1426 params.push((format!("self_attn.{name}"), param));
1427 }
1428 for (name, param) in self.cross_attn.named_parameters() {
1429 params.push((format!("cross_attn.{name}"), param));
1430 }
1431 for (name, param) in self.ffn.named_parameters() {
1432 params.push((format!("ffn.{name}"), param));
1433 }
1434 for (name, param) in self.norm1.named_parameters() {
1435 params.push((format!("norm1.{name}"), param));
1436 }
1437 for (name, param) in self.norm2.named_parameters() {
1438 params.push((format!("norm2.{name}"), param));
1439 }
1440 for (name, param) in self.norm3.named_parameters() {
1441 params.push((format!("norm3.{name}"), param));
1442 }
1443 params
1444 }
1445
1446 fn train(&mut self) {
1447 self.training = true;
1448 self.self_attn.train();
1449 self.cross_attn.train();
1450 self.ffn.train();
1451 self.norm1.train();
1452 self.norm2.train();
1453 self.norm3.train();
1454 self.dropout.train();
1455 }
1456
1457 fn eval(&mut self) {
1458 self.training = false;
1459 self.self_attn.eval();
1460 self.cross_attn.eval();
1461 self.ffn.eval();
1462 self.norm1.eval();
1463 self.norm2.eval();
1464 self.norm3.eval();
1465 self.dropout.eval();
1466 }
1467
1468 fn is_training(&self) -> bool {
1469 self.training
1470 }
1471}
1472
/// Stack of [`TransformerEncoderLayer`]s with an optional final LayerNorm.
#[derive(Debug)]
pub struct TransformerEncoder<T: Float> {
    /// Layers applied in order.
    layers: Vec<TransformerEncoderLayer<T>>,
    /// Final LayerNorm over `d_model`, present when built with `final_norm = true`.
    norm: Option<LayerNorm<T>>,
    /// Train/eval flag reported by `is_training`.
    training: bool,
}
1494
1495impl<T: Float> TransformerEncoder<T> {
1496 #[allow(clippy::too_many_arguments)]
1512 pub fn new(
1513 d_model: usize,
1514 num_heads: usize,
1515 num_layers: usize,
1516 d_ff: usize,
1517 dropout_p: f64,
1518 layer_norm_eps: f64,
1519 bias: bool,
1520 final_norm: bool,
1521 ) -> FerrotorchResult<Self> {
1522 if num_layers == 0 {
1523 return Err(FerrotorchError::InvalidArgument {
1524 message: "TransformerEncoder: num_layers must be > 0".into(),
1525 });
1526 }
1527
1528 let mut layers = Vec::with_capacity(num_layers);
1529 for _ in 0..num_layers {
1530 layers.push(TransformerEncoderLayer::new(
1531 d_model,
1532 num_heads,
1533 d_ff,
1534 dropout_p,
1535 layer_norm_eps,
1536 bias,
1537 )?);
1538 }
1539
1540 let norm = if final_norm {
1541 Some(LayerNorm::new(vec![d_model], layer_norm_eps, true)?)
1542 } else {
1543 None
1544 };
1545
1546 Ok(Self {
1547 layers,
1548 norm,
1549 training: true,
1550 })
1551 }
1552
1553 #[inline]
1555 pub fn num_layers(&self) -> usize {
1556 self.layers.len()
1557 }
1558}
1559
1560impl<T: Float> Module<T> for TransformerEncoder<T> {
1561 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1565 let mut output = input.clone();
1566 for layer in &self.layers {
1567 output = layer.forward(&output)?;
1568 }
1569 if let Some(ref norm) = self.norm {
1570 output = norm.forward(&output)?;
1571 }
1572 Ok(output)
1573 }
1574
1575 fn parameters(&self) -> Vec<&Parameter<T>> {
1576 let mut params = Vec::new();
1577 for (i, layer) in self.layers.iter().enumerate() {
1578 let _ = i; params.extend(layer.parameters());
1580 }
1581 if let Some(ref norm) = self.norm {
1582 params.extend(norm.parameters());
1583 }
1584 params
1585 }
1586
1587 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1588 let mut params = Vec::new();
1589 for layer in &mut self.layers {
1590 params.extend(layer.parameters_mut());
1591 }
1592 if let Some(ref mut norm) = self.norm {
1593 params.extend(norm.parameters_mut());
1594 }
1595 params
1596 }
1597
1598 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1599 let mut params = Vec::new();
1600 for (i, layer) in self.layers.iter().enumerate() {
1601 for (name, param) in layer.named_parameters() {
1602 params.push((format!("layers.{i}.{name}"), param));
1603 }
1604 }
1605 if let Some(ref norm) = self.norm {
1606 for (name, param) in norm.named_parameters() {
1607 params.push((format!("norm.{name}"), param));
1608 }
1609 }
1610 params
1611 }
1612
1613 fn train(&mut self) {
1614 self.training = true;
1615 for layer in &mut self.layers {
1616 layer.train();
1617 }
1618 if let Some(ref mut norm) = self.norm {
1619 norm.train();
1620 }
1621 }
1622
1623 fn eval(&mut self) {
1624 self.training = false;
1625 for layer in &mut self.layers {
1626 layer.eval();
1627 }
1628 if let Some(ref mut norm) = self.norm {
1629 norm.eval();
1630 }
1631 }
1632
1633 fn is_training(&self) -> bool {
1634 self.training
1635 }
1636}
1637
/// Stack of [`TransformerDecoderLayer`]s with an optional final LayerNorm.
#[derive(Debug)]
pub struct TransformerDecoder<T: Float> {
    /// Layers applied in order; each cross-attends to the same `memory`.
    layers: Vec<TransformerDecoderLayer<T>>,
    /// Final LayerNorm over `d_model`, present when built with `final_norm = true`.
    norm: Option<LayerNorm<T>>,
    /// Train/eval flag reported by `is_training`.
    training: bool,
}
1661
1662impl<T: Float> TransformerDecoder<T> {
1663 #[allow(clippy::too_many_arguments)]
1679 pub fn new(
1680 d_model: usize,
1681 num_heads: usize,
1682 num_layers: usize,
1683 d_ff: usize,
1684 dropout_p: f64,
1685 layer_norm_eps: f64,
1686 bias: bool,
1687 final_norm: bool,
1688 ) -> FerrotorchResult<Self> {
1689 if num_layers == 0 {
1690 return Err(FerrotorchError::InvalidArgument {
1691 message: "TransformerDecoder: num_layers must be > 0".into(),
1692 });
1693 }
1694
1695 let mut layers = Vec::with_capacity(num_layers);
1696 for _ in 0..num_layers {
1697 layers.push(TransformerDecoderLayer::new(
1698 d_model,
1699 num_heads,
1700 d_ff,
1701 dropout_p,
1702 layer_norm_eps,
1703 bias,
1704 )?);
1705 }
1706
1707 let norm = if final_norm {
1708 Some(LayerNorm::new(vec![d_model], layer_norm_eps, true)?)
1709 } else {
1710 None
1711 };
1712
1713 Ok(Self {
1714 layers,
1715 norm,
1716 training: true,
1717 })
1718 }
1719
1720 pub fn forward_with_memory(
1731 &self,
1732 input: &Tensor<T>,
1733 memory: &Tensor<T>,
1734 ) -> FerrotorchResult<Tensor<T>> {
1735 let mut output = input.clone();
1736 for layer in &self.layers {
1737 output = layer.forward_with_memory(&output, memory)?;
1738 }
1739 if let Some(ref norm) = self.norm {
1740 output = norm.forward(&output)?;
1741 }
1742 Ok(output)
1743 }
1744
1745 #[inline]
1747 pub fn num_layers(&self) -> usize {
1748 self.layers.len()
1749 }
1750}
1751
1752impl<T: Float> Module<T> for TransformerDecoder<T> {
1753 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1758 self.forward_with_memory(input, input)
1759 }
1760
1761 fn parameters(&self) -> Vec<&Parameter<T>> {
1762 let mut params = Vec::new();
1763 for layer in &self.layers {
1764 params.extend(layer.parameters());
1765 }
1766 if let Some(ref norm) = self.norm {
1767 params.extend(norm.parameters());
1768 }
1769 params
1770 }
1771
1772 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1773 let mut params = Vec::new();
1774 for layer in &mut self.layers {
1775 params.extend(layer.parameters_mut());
1776 }
1777 if let Some(ref mut norm) = self.norm {
1778 params.extend(norm.parameters_mut());
1779 }
1780 params
1781 }
1782
1783 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1784 let mut params = Vec::new();
1785 for (i, layer) in self.layers.iter().enumerate() {
1786 for (name, param) in layer.named_parameters() {
1787 params.push((format!("layers.{i}.{name}"), param));
1788 }
1789 }
1790 if let Some(ref norm) = self.norm {
1791 for (name, param) in norm.named_parameters() {
1792 params.push((format!("norm.{name}"), param));
1793 }
1794 }
1795 params
1796 }
1797
1798 fn train(&mut self) {
1799 self.training = true;
1800 for layer in &mut self.layers {
1801 layer.train();
1802 }
1803 if let Some(ref mut norm) = self.norm {
1804 norm.train();
1805 }
1806 }
1807
1808 fn eval(&mut self) {
1809 self.training = false;
1810 for layer in &mut self.layers {
1811 layer.eval();
1812 }
1813 if let Some(ref mut norm) = self.norm {
1814 norm.eval();
1815 }
1816 }
1817
1818 fn is_training(&self) -> bool {
1819 self.training
1820 }
1821}
1822
/// Encoder–decoder transformer: a [`TransformerEncoder`] whose output is passed
/// as the `memory` tensor to a [`TransformerDecoder`].
///
/// Construct with [`Transformer::new`] and run with
/// [`Transformer::forward_transformer`]. A freshly built model starts in
/// training mode.
#[derive(Debug)]
pub struct Transformer<T: Float> {
    /// Encoder stack applied to the source sequence.
    encoder: TransformerEncoder<T>,
    /// Decoder stack applied to the target sequence, given the encoder output as memory.
    decoder: TransformerDecoder<T>,
    /// Training-mode flag reported by `is_training`.
    training: bool,
}
1853
1854impl<T: Float> Transformer<T> {
1855 #[allow(clippy::too_many_arguments)]
1868 pub fn new(
1869 d_model: usize,
1870 num_heads: usize,
1871 num_encoder_layers: usize,
1872 num_decoder_layers: usize,
1873 d_ff: usize,
1874 dropout_p: f64,
1875 layer_norm_eps: f64,
1876 bias: bool,
1877 ) -> FerrotorchResult<Self> {
1878 let encoder = TransformerEncoder::new(
1879 d_model,
1880 num_heads,
1881 num_encoder_layers,
1882 d_ff,
1883 dropout_p,
1884 layer_norm_eps,
1885 bias,
1886 true, )?;
1888 let decoder = TransformerDecoder::new(
1889 d_model,
1890 num_heads,
1891 num_decoder_layers,
1892 d_ff,
1893 dropout_p,
1894 layer_norm_eps,
1895 bias,
1896 true, )?;
1898
1899 Ok(Self {
1900 encoder,
1901 decoder,
1902 training: true,
1903 })
1904 }
1905
1906 pub fn forward_transformer(
1917 &self,
1918 src: &Tensor<T>,
1919 tgt: &Tensor<T>,
1920 ) -> FerrotorchResult<Tensor<T>> {
1921 let memory = self.encoder.forward(src)?;
1922 self.decoder.forward_with_memory(tgt, &memory)
1923 }
1924
1925 #[inline]
1927 pub fn num_encoder_layers(&self) -> usize {
1928 self.encoder.num_layers()
1929 }
1930
1931 #[inline]
1933 pub fn num_decoder_layers(&self) -> usize {
1934 self.decoder.num_layers()
1935 }
1936}
1937
1938impl<T: Float> Module<T> for Transformer<T> {
1939 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1944 self.forward_transformer(input, input)
1945 }
1946
1947 fn parameters(&self) -> Vec<&Parameter<T>> {
1948 let mut params = self.encoder.parameters();
1949 params.extend(self.decoder.parameters());
1950 params
1951 }
1952
1953 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1954 let mut params = self.encoder.parameters_mut();
1955 params.extend(self.decoder.parameters_mut());
1956 params
1957 }
1958
1959 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1960 let mut params = Vec::new();
1961 for (name, param) in self.encoder.named_parameters() {
1962 params.push((format!("encoder.{name}"), param));
1963 }
1964 for (name, param) in self.decoder.named_parameters() {
1965 params.push((format!("decoder.{name}"), param));
1966 }
1967 params
1968 }
1969
1970 fn train(&mut self) {
1971 self.training = true;
1972 self.encoder.train();
1973 self.decoder.train();
1974 }
1975
1976 fn eval(&mut self) {
1977 self.training = false;
1978 self.encoder.eval();
1979 self.decoder.eval();
1980 }
1981
1982 fn is_training(&self) -> bool {
1983 self.training
1984 }
1985}
1986
1987#[cfg(test)]
1992mod tests {
1993 use super::*;
1994
1995 #[test]
2000 fn test_rope_construction() {
2001 let rope = RotaryPositionEmbedding::<f32>::new(64, 512, 10000.0);
2002 assert!(rope.is_ok());
2003 let rope = rope.unwrap();
2004 assert_eq!(rope.dim(), 64);
2005 assert_eq!(rope.max_seq_len(), 512);
2006 assert_eq!(rope.base(), 10000.0);
2007 }
2008
2009 #[test]
2010 fn test_rope_odd_dim_rejected() {
2011 assert!(RotaryPositionEmbedding::<f32>::new(63, 512, 10000.0).is_err());
2012 }
2013
2014 #[test]
2015 fn test_rope_zero_dim_rejected() {
2016 assert!(RotaryPositionEmbedding::<f32>::new(0, 512, 10000.0).is_err());
2017 }
2018
2019 #[test]
2020 fn test_rope_zero_seq_rejected() {
2021 assert!(RotaryPositionEmbedding::<f32>::new(64, 0, 10000.0).is_err());
2022 }
2023
2024 #[test]
2025 fn test_rope_output_shape_2d() {
2026 let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2027 let x = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2029 let y = rope.apply(&x, 0).unwrap();
2030 assert_eq!(y.shape(), &[4, 8]);
2031 }
2032
2033 #[test]
2034 fn test_rope_output_shape_3d() {
2035 let rope = RotaryPositionEmbedding::<f32>::new(16, 256, 10000.0).unwrap();
2036 let x = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
2038 let y = rope.apply(&x, 0).unwrap();
2039 assert_eq!(y.shape(), &[2, 10, 16]);
2040 }
2041
2042 #[test]
2043 fn test_rope_output_shape_4d() {
2044 let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2045 let x = ferrotorch_core::zeros::<f32>(&[2, 4, 6, 8]).unwrap();
2047 let y = rope.apply(&x, 0).unwrap();
2048 assert_eq!(y.shape(), &[2, 4, 6, 8]);
2049 }
2050
2051 #[test]
2052 fn test_rope_with_offset() {
2053 let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2054 let x = ferrotorch_core::ones::<f32>(&[4, 8]).unwrap();
2055 let y = rope.apply(&x, 10).unwrap();
2057 assert_eq!(y.shape(), &[4, 8]);
2058 }
2059
2060 #[test]
2061 fn test_rope_offset_overflow_rejected() {
2062 let rope = RotaryPositionEmbedding::<f32>::new(8, 16, 10000.0).unwrap();
2063 let x = ferrotorch_core::zeros::<f32>(&[10, 8]).unwrap();
2065 assert!(rope.apply(&x, 10).is_err());
2066 }
2067
2068 #[test]
2069 fn test_rope_position_zero_is_identity() {
2070 let rope = RotaryPositionEmbedding::<f64>::new(4, 64, 10000.0).unwrap();
2072 let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2073 let y = rope.apply(&x, 0).unwrap();
2074 let y_data = y.data().unwrap();
2075 let x_data = x.data().unwrap();
2076 for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
2077 assert!(
2078 (xv - yv).abs() < 1e-10,
2079 "position 0 should be identity, index {i}: x={xv}, y={yv}"
2080 );
2081 }
2082 }
2083
2084 #[test]
2085 fn test_rope_values_are_finite() {
2086 let rope = RotaryPositionEmbedding::<f32>::new(16, 512, 10000.0).unwrap();
2087 let x = ferrotorch_core::ones::<f32>(&[2, 4, 10, 16]).unwrap();
2088 let y = rope.apply(&x, 0).unwrap();
2089 for &v in y.data().unwrap() {
2090 assert!(v.is_finite(), "RoPE produced non-finite value: {v}");
2091 }
2092 }
2093
2094 #[test]
2095 fn test_rope_wrong_dim_rejected() {
2096 let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2097 let x = ferrotorch_core::zeros::<f32>(&[4, 10]).unwrap(); assert!(rope.apply(&x, 0).is_err());
2099 }
2100
2101 #[test]
2104 fn test_rope_scaling_default_is_none() {
2105 let rope = RotaryPositionEmbedding::<f32>::new(16, 128, 10000.0).unwrap();
2106 assert_eq!(rope.scaling(), RoPEScaling::None);
2107 }
2108
2109 #[test]
2110 fn test_rope_scaling_none_matches_classical() {
2111 let a = RotaryPositionEmbedding::<f64>::new(16, 32, 10000.0).unwrap();
2113 let b = RotaryPositionEmbedding::<f64>::with_scaling(
2114 16,
2115 32,
2116 10000.0,
2117 RoPEConvention::default(),
2118 RoPEScaling::None,
2119 )
2120 .unwrap();
2121 let x = ferrotorch_core::from_slice(
2122 &(0..16).map(|i| i as f64 * 0.1).collect::<Vec<_>>(),
2123 &[1, 16],
2124 )
2125 .unwrap();
2126 let ya = a.apply(&x, 7).unwrap();
2127 let yb = b.apply(&x, 7).unwrap();
2128 for (va, vb) in ya.data().unwrap().iter().zip(yb.data().unwrap().iter()) {
2129 assert!((va - vb).abs() < 1e-12);
2130 }
2131 }
2132
2133 #[test]
2134 fn test_rope_scaling_linear_halves_angles() {
2135 let scaled = RotaryPositionEmbedding::<f64>::with_scaling(
2138 8,
2139 64,
2140 10000.0,
2141 RoPEConvention::default(),
2142 RoPEScaling::Linear { factor: 2.0 },
2143 )
2144 .unwrap();
2145 let plain = RotaryPositionEmbedding::<f64>::new(8, 64, 10000.0).unwrap();
2146
2147 let x = ferrotorch_core::ones::<f64>(&[1, 8]).unwrap();
2150 let y_scaled = scaled.apply(&x, 8).unwrap();
2151 let y_plain = plain.apply(&x, 4).unwrap();
2152 for (a, b) in y_scaled
2153 .data()
2154 .unwrap()
2155 .iter()
2156 .zip(y_plain.data().unwrap().iter())
2157 {
2158 assert!(
2159 (a - b).abs() < 1e-6,
2160 "scaled(pos=8) should match plain(pos=4): {a} vs {b}"
2161 );
2162 }
2163 }
2164
2165 #[test]
2166 fn test_rope_scaling_ntk_inv_freq() {
2167 use super::compute_scaled_inv_freq;
2172
2173 let dim = 64;
2174 let base = 10000.0;
2175 let factor = 4.0;
2176 let ntk = compute_scaled_inv_freq(
2177 dim,
2178 base,
2179 RoPEScaling::NtkAware {
2180 factor,
2181 original_max_pos_embeddings: 2048,
2182 },
2183 );
2184 let plain = compute_scaled_inv_freq(dim, base, RoPEScaling::None);
2185 assert_eq!(ntk.len(), 32);
2186 assert_eq!(plain.len(), 32);
2187
2188 assert!(
2190 (ntk[0] - plain[0]).abs() < 1e-15,
2191 "NTK inv_freq[0] should equal plain inv_freq[0]: ntk={}, plain={}",
2192 ntk[0],
2193 plain[0]
2194 );
2195
2196 let ratio = ntk[31] / plain[31];
2199 let expected = 1.0 / factor;
2200 assert!(
2201 (ratio - expected).abs() < 0.05,
2202 "NTK inv_freq[31]/plain ratio should be ~{expected}: got {ratio}"
2203 );
2204 }
2205
2206 #[test]
2207 fn test_rope_scaling_linear_inv_freq_halved() {
2208 use super::compute_scaled_inv_freq;
2209 let lin = compute_scaled_inv_freq(8, 10000.0, RoPEScaling::Linear { factor: 2.0 });
2210 let plain = compute_scaled_inv_freq(8, 10000.0, RoPEScaling::None);
2211 for (a, b) in lin.iter().zip(plain.iter()) {
2212 assert!(
2213 (a - b / 2.0).abs() < 1e-15,
2214 "linear should halve: {a} vs {b}/2"
2215 );
2216 }
2217 }
2218
2219 #[test]
2220 fn test_rope_scaling_yarn_inv_freq_piecewise() {
2221 use super::compute_scaled_inv_freq;
2224 let dim = 64;
2225 let base = 10000.0;
2226 let factor = 4.0;
2227 let yarn = compute_scaled_inv_freq(dim, base, RoPEScaling::yarn_default(factor, 2048));
2228 let plain = compute_scaled_inv_freq(dim, base, RoPEScaling::None);
2229
2230 assert!(
2232 (yarn[0] - plain[0]).abs() < 1e-12,
2233 "YARN[0] (extrapolation) should equal plain[0]: {} vs {}",
2234 yarn[0],
2235 plain[0]
2236 );
2237 let expected_low = plain[dim / 2 - 1] / factor;
2239 let ratio = yarn[dim / 2 - 1] / expected_low;
2240 assert!(
2241 (ratio - 1.0).abs() < 0.1,
2242 "YARN[dim/2-1] (interpolation) should approx equal plain/factor: {} vs {}",
2243 yarn[dim / 2 - 1],
2244 expected_low
2245 );
2246 }
2247
2248 #[test]
2249 fn test_rope_scaling_yarn_constructs() {
2250 let rope = RotaryPositionEmbedding::<f32>::with_scaling(
2251 64,
2252 256,
2253 10000.0,
2254 RoPEConvention::default(),
2255 RoPEScaling::yarn_default(2.0, 2048),
2256 )
2257 .unwrap();
2258 assert!(matches!(rope.scaling(), RoPEScaling::Yarn { .. }));
2259 let x = ferrotorch_core::ones::<f32>(&[1, 64]).unwrap();
2260 for &v in rope.apply(&x, 0).unwrap().data().unwrap() {
2261 assert!(v.is_finite());
2262 }
2263 }
2264
2265 #[test]
2266 fn test_rope_scaling_rejects_zero_factor() {
2267 let r = RotaryPositionEmbedding::<f32>::with_scaling(
2268 8,
2269 16,
2270 10000.0,
2271 RoPEConvention::default(),
2272 RoPEScaling::Linear { factor: 0.0 },
2273 );
2274 assert!(r.is_err());
2275 }
2276
2277 #[test]
2278 fn test_rope_scaling_rejects_negative_factor() {
2279 let r = RotaryPositionEmbedding::<f32>::with_scaling(
2280 8,
2281 16,
2282 10000.0,
2283 RoPEConvention::default(),
2284 RoPEScaling::NtkAware {
2285 factor: -2.0,
2286 original_max_pos_embeddings: 2048,
2287 },
2288 );
2289 assert!(r.is_err());
2290 }
2291
2292 #[test]
2293 fn test_rope_scaling_accessor() {
2294 let rope = RotaryPositionEmbedding::<f32>::with_scaling(
2295 16,
2296 64,
2297 10000.0,
2298 RoPEConvention::default(),
2299 RoPEScaling::Linear { factor: 4.0 },
2300 )
2301 .unwrap();
2302 assert_eq!(rope.scaling(), RoPEScaling::Linear { factor: 4.0 });
2303 }
2304
2305 #[test]
2310 fn test_rope_half_rotation_construction() {
2311 let rope = RotaryPositionEmbedding::<f32>::with_convention(
2312 8,
2313 128,
2314 10000.0,
2315 RoPEConvention::HalfRotation,
2316 )
2317 .unwrap();
2318 assert_eq!(rope.convention(), RoPEConvention::HalfRotation);
2319 }
2320
2321 #[test]
2322 fn test_rope_half_rotation_output_shape() {
2323 let rope = RotaryPositionEmbedding::<f32>::with_convention(
2324 8,
2325 128,
2326 10000.0,
2327 RoPEConvention::HalfRotation,
2328 )
2329 .unwrap();
2330 let x = ferrotorch_core::zeros::<f32>(&[2, 4, 8]).unwrap();
2331 let y = rope.apply(&x, 0).unwrap();
2332 assert_eq!(y.shape(), &[2, 4, 8]);
2333 }
2334
2335 #[test]
2336 fn test_rope_half_rotation_position_zero_is_identity() {
2337 let rope = RotaryPositionEmbedding::<f64>::with_convention(
2339 4,
2340 64,
2341 10000.0,
2342 RoPEConvention::HalfRotation,
2343 )
2344 .unwrap();
2345 let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2346 let y = rope.apply(&x, 0).unwrap();
2347 let x_data = x.data().unwrap();
2348 let y_data = y.data().unwrap();
2349 for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
2350 assert!(
2351 (xv - yv).abs() < 1e-10,
2352 "half-rot pos 0 should be identity, index {i}: x={xv}, y={yv}"
2353 );
2354 }
2355 }
2356
2357 #[test]
2358 fn test_rope_half_rotation_correctness() {
2359 let rope = RotaryPositionEmbedding::<f64>::with_convention(
2365 4,
2366 64,
2367 10000.0,
2368 RoPEConvention::HalfRotation,
2369 )
2370 .unwrap();
2371
2372 let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2374 let y = rope.apply(&x, 1).unwrap();
2375
2376 let cos_data = rope.cos_cache.data().unwrap();
2378 let sin_data = rope.sin_cache.data().unwrap();
2379 let c0 = cos_data[2];
2381 let c1 = cos_data[3];
2382 let s0 = sin_data[2];
2383 let s1 = sin_data[3];
2384
2385 let expected = [
2386 1.0 * c0 - 3.0 * s0,
2387 2.0 * c1 - 4.0 * s1,
2388 1.0 * s0 + 3.0 * c0,
2389 2.0 * s1 + 4.0 * c1,
2390 ];
2391
2392 let y_data = y.data().unwrap();
2393 for (i, (&actual, &exp)) in y_data.iter().zip(expected.iter()).enumerate() {
2394 assert!(
2395 (actual - exp).abs() < 1e-10,
2396 "half-rot index {i}: actual={actual}, expected={exp}"
2397 );
2398 }
2399 }
2400
2401 #[test]
2402 fn test_rope_interleaved_vs_half_rotation_differ() {
2403 let rope_il = RotaryPositionEmbedding::<f64>::with_convention(
2405 4,
2406 64,
2407 10000.0,
2408 RoPEConvention::Interleaved,
2409 )
2410 .unwrap();
2411 let rope_hr = RotaryPositionEmbedding::<f64>::with_convention(
2412 4,
2413 64,
2414 10000.0,
2415 RoPEConvention::HalfRotation,
2416 )
2417 .unwrap();
2418
2419 let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2420 let y_il = rope_il.apply(&x, 1).unwrap();
2421 let y_hr = rope_hr.apply(&x, 1).unwrap();
2422
2423 let il_data = y_il.data().unwrap();
2425 let hr_data = y_hr.data().unwrap();
2426 let any_differ = il_data
2427 .iter()
2428 .zip(hr_data.iter())
2429 .any(|(&a, &b)| (a - b).abs() > 1e-10);
2430 assert!(
2431 any_differ,
2432 "interleaved and half-rotation should produce different outputs at pos > 0"
2433 );
2434 }
2435
2436 #[test]
2437 fn test_rope_default_convention_is_interleaved() {
2438 let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2439 assert_eq!(rope.convention(), RoPEConvention::Interleaved);
2440 }
2441
2442 #[test]
2447 fn test_swiglu_construction() {
2448 let swiglu = SwiGLU::<f32>::new(64, 128, true);
2449 assert!(swiglu.is_ok());
2450 }
2451
2452 #[test]
2453 fn test_swiglu_forward_shape_2d() {
2454 let swiglu = SwiGLU::<f32>::new(16, 32, true).unwrap();
2455 let input = ferrotorch_core::zeros::<f32>(&[4, 16]).unwrap();
2456 let output = swiglu.forward(&input).unwrap();
2457 assert_eq!(output.shape(), &[4, 16]);
2458 }
2459
2460 #[test]
2461 fn test_swiglu_forward_shape_3d() {
2462 let swiglu = SwiGLU::<f32>::new(16, 32, false).unwrap();
2463 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
2464 let output = swiglu.forward(&input).unwrap();
2465 assert_eq!(output.shape(), &[2, 5, 16]);
2466 }
2467
2468 #[test]
2469 fn test_swiglu_forward_values_finite() {
2470 let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
2471 let input = ferrotorch_core::ones::<f32>(&[2, 3, 8]).unwrap();
2472 let output = swiglu.forward(&input).unwrap();
2473 for &v in output.data().unwrap() {
2474 assert!(v.is_finite(), "SwiGLU produced non-finite value: {v}");
2475 }
2476 }
2477
2478 #[test]
2479 fn test_swiglu_1d_rejected() {
2480 let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
2481 let input = ferrotorch_core::zeros::<f32>(&[8]).unwrap();
2482 assert!(swiglu.forward(&input).is_err());
2483 }
2484
2485 #[test]
2486 fn test_swiglu_parameters() {
2487 let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
2488 let params = swiglu.parameters();
2489 assert_eq!(params.len(), 6);
2491
2492 let named = swiglu.named_parameters();
2493 let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
2494 assert!(names.contains(&"w1.weight"));
2495 assert!(names.contains(&"w1.bias"));
2496 assert!(names.contains(&"w2.weight"));
2497 assert!(names.contains(&"w2.bias"));
2498 assert!(names.contains(&"w3.weight"));
2499 assert!(names.contains(&"w3.bias"));
2500 }
2501
2502 #[test]
2503 fn test_swiglu_parameters_no_bias() {
2504 let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
2505 let params = swiglu.parameters();
2506 assert_eq!(params.len(), 3);
2508 }
2509
2510 #[test]
2511 fn test_swiglu_train_eval() {
2512 let mut swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
2513 assert!(swiglu.is_training());
2514 swiglu.eval();
2515 assert!(!swiglu.is_training());
2516 swiglu.train();
2517 assert!(swiglu.is_training());
2518 }
2519
2520 #[test]
2525 fn test_kv_cache_new_empty() {
2526 let cache = KVCache::<f32>::new(1024);
2527 assert!(cache.is_empty());
2528 assert_eq!(cache.seq_len(), 0);
2529 assert_eq!(cache.max_seq_len(), 1024);
2530 }
2531
2532 #[test]
2533 fn test_kv_cache_single_update() {
2534 let mut cache = KVCache::<f32>::new(128);
2535 let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2537 let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2538 let (fk, fv) = cache.update(k, v).unwrap();
2539 assert_eq!(fk.shape(), &[1, 2, 3, 4]);
2540 assert_eq!(fv.shape(), &[1, 2, 3, 4]);
2541 assert_eq!(cache.seq_len(), 3);
2542 }
2543
2544 #[test]
2545 fn test_kv_cache_append() {
2546 let mut cache = KVCache::<f32>::new(128);
2547 let k1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2549 let v1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2550 cache.update(k1, v1).unwrap();
2551 assert_eq!(cache.seq_len(), 3);
2552
2553 let k2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
2555 let v2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
2556 let (fk, fv) = cache.update(k2, v2).unwrap();
2557 assert_eq!(fk.shape(), &[1, 2, 5, 4]); assert_eq!(fv.shape(), &[1, 2, 5, 4]);
2559 assert_eq!(cache.seq_len(), 5);
2560 }
2561
2562 #[test]
2563 fn test_kv_cache_reset() {
2564 let mut cache = KVCache::<f32>::new(128);
2565 let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2566 let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2567 cache.update(k, v).unwrap();
2568 assert_eq!(cache.seq_len(), 3);
2569
2570 cache.reset();
2571 assert!(cache.is_empty());
2572 assert_eq!(cache.seq_len(), 0);
2573 }
2574
2575 #[test]
2576 fn test_kv_cache_overflow_rejected() {
2577 let mut cache = KVCache::<f32>::new(4);
2578 let k = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
2579 let v = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
2580 assert!(cache.update(k, v).is_err());
2582 }
2583
2584 #[test]
2585 fn test_kv_cache_shape_mismatch_rejected() {
2586 let mut cache = KVCache::<f32>::new(128);
2587 let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2588 let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 8]).unwrap(); assert!(cache.update(k, v).is_err());
2590 }
2591
2592 #[test]
2593 fn test_kv_cache_values_preserved() {
2594 let mut cache = KVCache::<f64>::new(128);
2595 let k1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
2597 let v1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
2598 cache.update(k1, v1).unwrap();
2599
2600 let k2_data = vec![2.0f64; 3];
2602 let k2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
2603 let v2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
2604 let (fk, _fv) = cache.update(k2, v2).unwrap();
2605
2606 assert_eq!(fk.shape(), &[1, 1, 3, 3]); let fk_data = fk.data().unwrap();
2608 for &v in &fk_data[..6] {
2610 assert!((v - 1.0).abs() < 1e-10, "expected 1.0, got {v}");
2611 }
2612 for &v in &fk_data[6..9] {
2613 assert!((v - 2.0).abs() < 1e-10, "expected 2.0, got {v}");
2614 }
2615 }
2616
2617 #[test]
2620 fn test_kv_cache_gqa_stores_at_kv_head_granularity() {
2621 let mut cache = KVCache::<f32>::new(8192);
2623 let k = ferrotorch_core::zeros::<f32>(&[1, 8, 3, 128]).unwrap();
2624 let v = ferrotorch_core::zeros::<f32>(&[1, 8, 3, 128]).unwrap();
2625 let (fk, _) = cache.update(k, v).unwrap();
2626 assert_eq!(fk.shape(), &[1, 8, 3, 128]);
2627 assert_eq!(cache.num_kv_heads(), Some(8));
2628 assert_eq!(cache.head_dim(), Some(128));
2629 assert_eq!(cache.batch_size(), Some(1));
2630 }
2631
2632 #[test]
2633 fn test_kv_cache_with_dims_pre_declares_shape() {
2634 let cache = KVCache::<f32>::with_dims(8192, 1, 8, 128);
2635 assert_eq!(cache.num_kv_heads(), Some(8));
2636 assert_eq!(cache.head_dim(), Some(128));
2637 assert_eq!(cache.batch_size(), Some(1));
2638 assert!(cache.is_empty());
2639 }
2640
2641 #[test]
2642 fn test_kv_cache_with_dims_rejects_first_update_mismatch() {
2643 let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2645 let k = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 16]).unwrap();
2646 let v = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 16]).unwrap();
2647 assert!(cache.update(k, v).is_err());
2648 }
2649
2650 #[test]
2651 fn test_kv_cache_with_dims_rejects_head_dim_mismatch() {
2652 let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2653 let k = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 32]).unwrap(); let v = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 32]).unwrap();
2655 assert!(cache.update(k, v).is_err());
2656 }
2657
2658 #[test]
2659 fn test_kv_cache_with_dims_rejects_batch_mismatch() {
2660 let mut cache = KVCache::<f32>::with_dims(128, 2, 4, 8);
2661 let k = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 8]).unwrap(); let v = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 8]).unwrap();
2663 assert!(cache.update(k, v).is_err());
2664 }
2665
2666 #[test]
2667 fn test_kv_cache_with_dims_accepts_matching_update() {
2668 let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2669 let k = ferrotorch_core::ones::<f32>(&[1, 8, 3, 16]).unwrap();
2670 let v = ferrotorch_core::ones::<f32>(&[1, 8, 3, 16]).unwrap();
2671 assert!(cache.update(k, v).is_ok());
2672 assert_eq!(cache.seq_len(), 3);
2673 }
2674
2675 #[test]
2676 fn test_kv_cache_inferred_dims_reject_subsequent_mismatch() {
2677 let mut cache = KVCache::<f32>::new(128);
2679 let k1 = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 16]).unwrap();
2680 let v1 = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 16]).unwrap();
2681 cache.update(k1, v1).unwrap();
2682 assert_eq!(cache.num_kv_heads(), Some(8));
2683
2684 let k2 = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap(); let v2 = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
2686 assert!(cache.update(k2, v2).is_err());
2687 }
2688
2689 #[test]
2690 fn test_kv_cache_dims_not_yet_pinned_on_fresh_new() {
2691 let cache = KVCache::<f32>::new(128);
2692 assert_eq!(cache.num_kv_heads(), None);
2693 assert_eq!(cache.head_dim(), None);
2694 assert_eq!(cache.batch_size(), None);
2695 }
2696
2697 #[test]
2698 fn test_kv_cache_reset_preserves_pinned_dims() {
2699 let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2700 let k = ferrotorch_core::ones::<f32>(&[1, 8, 2, 16]).unwrap();
2701 let v = ferrotorch_core::ones::<f32>(&[1, 8, 2, 16]).unwrap();
2702 cache.update(k, v).unwrap();
2703 cache.reset();
2704 assert!(cache.is_empty());
2705 assert_eq!(cache.num_kv_heads(), Some(8));
2707 let bad = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
2708 assert!(cache.update(bad.clone(), bad).is_err());
2709 }
2710
2711 #[test]
2712 fn test_kv_cache_gqa_prefill_then_decode_preserves_all_positions() {
2713 let build = |seed: u64, shape: &[usize]| {
2721 let numel: usize = shape.iter().product();
2722 let data: Vec<f32> = (0..numel)
2723 .map(|i| ((i as u64).wrapping_mul(seed) % 997) as f32 * 0.001)
2724 .collect();
2725 ferrotorch_core::from_slice(&data, shape).unwrap()
2726 };
2727
2728 let (b, h, s_prefill, s_decode, d) = (1usize, 8usize, 4usize, 1usize, 16usize);
2730 let s_full = s_prefill + s_decode;
2731
2732 let k_prefill = build(7, &[b, h, s_prefill, d]);
2733 let v_prefill = build(11, &[b, h, s_prefill, d]);
2734 let k_decode = build(13, &[b, h, s_decode, d]);
2735 let v_decode = build(17, &[b, h, s_decode, d]);
2736
2737 let mut cache = KVCache::<f32>::with_dims(16, b, h, d);
2738 cache.update(k_prefill.clone(), v_prefill.clone()).unwrap();
2739 let (fk, fv) = cache.update(k_decode.clone(), v_decode.clone()).unwrap();
2740 assert_eq!(fk.shape(), &[b, h, s_full, d]);
2741 assert_eq!(fv.shape(), &[b, h, s_full, d]);
2742
2743 let fk_data = fk.data_vec().unwrap();
2744 let fv_data = fv.data_vec().unwrap();
2745 let kp = k_prefill.data_vec().unwrap();
2746 let vp = v_prefill.data_vec().unwrap();
2747 let kd = k_decode.data_vec().unwrap();
2748 let vd = v_decode.data_vec().unwrap();
2749
2750 let full_idx = |bi, hi, si, di| ((bi * h + hi) * s_full + si) * d + di;
2752 let src_idx = |bi, hi, si, di, s_len| ((bi * h + hi) * s_len + si) * d + di;
2753
2754 for bi in 0..b {
2755 for hi in 0..h {
2756 for si in 0..s_full {
2757 for di in 0..d {
2758 let out = full_idx(bi, hi, si, di);
2759 let (exp_k, exp_v) = if si < s_prefill {
2760 let src = src_idx(bi, hi, si, di, s_prefill);
2761 (kp[src], vp[src])
2762 } else {
2763 let src = src_idx(bi, hi, si - s_prefill, di, s_decode);
2764 (kd[src], vd[src])
2765 };
2766 assert!(
2767 (fk_data[out] - exp_k).abs() < 1e-6,
2768 "k mismatch at [b={bi}, h={hi}, s={si}, d={di}]: got {}, want {exp_k}",
2769 fk_data[out]
2770 );
2771 assert!(
2772 (fv_data[out] - exp_v).abs() < 1e-6,
2773 "v mismatch at [b={bi}, h={hi}, s={si}, d={di}]: got {}, want {exp_v}",
2774 fv_data[out]
2775 );
2776 }
2777 }
2778 }
2779 }
2780 }
2781
2782 #[test]
2787 fn test_encoder_layer_construction() {
2788 let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
2789 assert!(layer.is_ok());
2790 }
2791
2792 #[test]
2793 fn test_encoder_layer_forward_shape() {
2794 let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
2795 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
2796 let output = layer.forward(&input).unwrap();
2797 assert_eq!(output.shape(), &[2, 5, 16]);
2798 }
2799
2800 #[test]
2801 fn test_encoder_layer_forward_values_finite() {
2802 let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2803 let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
2804 let output = layer.forward(&input).unwrap();
2805 for &v in output.data().unwrap() {
2806 assert!(
2807 v.is_finite(),
2808 "TransformerEncoderLayer produced non-finite value: {v}"
2809 );
2810 }
2811 }
2812
2813 #[test]
2814 fn test_encoder_layer_2d_rejected() {
2815 let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
2816 let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2817 assert!(layer.forward(&input).is_err());
2818 }
2819
2820 #[test]
2821 fn test_encoder_layer_parameters_count() {
2822 let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2823 let params = layer.parameters();
2824 assert_eq!(params.len(), 18);
2830 }
2831
2832 #[test]
2833 fn test_encoder_layer_train_eval() {
2834 let mut layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
2835 assert!(layer.is_training());
2836 layer.eval();
2837 assert!(!layer.is_training());
2838 layer.train();
2839 assert!(layer.is_training());
2840 }
2841
2842 #[test]
2843 fn test_encoder_layer_is_send_sync() {
2844 fn assert_send_sync<T: Send + Sync>() {}
2845 assert_send_sync::<TransformerEncoderLayer<f32>>();
2846 assert_send_sync::<TransformerEncoderLayer<f64>>();
2847 }
2848
2849 #[test]
2854 fn test_decoder_layer_construction() {
2855 let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
2856 assert!(layer.is_ok());
2857 }
2858
2859 #[test]
2860 fn test_decoder_layer_forward_shape() {
2861 let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
2862 let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
2864 let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
2865 let output = layer.forward_with_memory(&tgt, &memory).unwrap();
2866 assert_eq!(output.shape(), &[2, 4, 16]);
2867 }
2868
2869 #[test]
2870 fn test_decoder_layer_self_forward_shape() {
2871 let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2873 let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2874 let output = layer.forward(&input).unwrap();
2875 assert_eq!(output.shape(), &[1, 3, 8]);
2876 }
2877
2878 #[test]
2879 fn test_decoder_layer_forward_values_finite() {
2880 let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2881 let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
2882 let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
2883 let output = layer.forward_with_memory(&tgt, &mem).unwrap();
2884 for &v in output.data().unwrap() {
2885 assert!(
2886 v.is_finite(),
2887 "TransformerDecoderLayer produced non-finite value: {v}"
2888 );
2889 }
2890 }
2891
2892 #[test]
2893 fn test_decoder_layer_2d_rejected() {
2894 let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
2895 let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2896 let memory = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2897 assert!(layer.forward_with_memory(&input, &memory).is_err());
2898 }
2899
2900 #[test]
2901 fn test_decoder_layer_parameters_count() {
2902 let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2903 let params = layer.parameters();
2904 assert_eq!(params.len(), 28);
2912 }
2913
2914 #[test]
2915 fn test_decoder_layer_train_eval() {
2916 let mut layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
2917 assert!(layer.is_training());
2918 layer.eval();
2919 assert!(!layer.is_training());
2920 layer.train();
2921 assert!(layer.is_training());
2922 }
2923
2924 #[test]
2925 fn test_decoder_layer_is_send_sync() {
2926 fn assert_send_sync<T: Send + Sync>() {}
2927 assert_send_sync::<TransformerDecoderLayer<f32>>();
2928 assert_send_sync::<TransformerDecoderLayer<f64>>();
2929 }
2930
2931 #[test]
2936 fn test_encoder_construction() {
2937 let enc = TransformerEncoder::<f32>::new(16, 4, 3, 32, 0.0, 1e-5, true, true);
2938 assert!(enc.is_ok());
2939 assert_eq!(enc.unwrap().num_layers(), 3);
2940 }
2941
2942 #[test]
2943 fn test_encoder_zero_layers_rejected() {
2944 assert!(TransformerEncoder::<f32>::new(16, 4, 0, 32, 0.0, 1e-5, true, true).is_err());
2945 }
2946
2947 #[test]
2948 fn test_encoder_forward_shape() {
2949 let enc = TransformerEncoder::<f32>::new(16, 4, 2, 32, 0.0, 1e-5, false, true).unwrap();
2950 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
2951 let output = enc.forward(&input).unwrap();
2952 assert_eq!(output.shape(), &[2, 5, 16]);
2953 }
2954
2955 #[test]
2956 fn test_encoder_forward_no_final_norm() {
2957 let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, false, false).unwrap();
2958 let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2959 let output = enc.forward(&input).unwrap();
2960 assert_eq!(output.shape(), &[1, 3, 8]);
2961 }
2962
2963 #[test]
2964 fn test_encoder_forward_values_finite() {
2965 let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
2966 let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
2967 let output = enc.forward(&input).unwrap();
2968 for &v in output.data().unwrap() {
2969 assert!(
2970 v.is_finite(),
2971 "TransformerEncoder produced non-finite value: {v}"
2972 );
2973 }
2974 }
2975
2976 #[test]
2977 fn test_encoder_parameters_with_final_norm() {
2978 let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
2979 assert_eq!(enc.parameters().len(), 38);
2983 }
2984
2985 #[test]
2986 fn test_encoder_named_parameters_have_layer_prefix() {
2987 let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
2988 let named = enc.named_parameters();
2989 let has_layer_0 = named.iter().any(|(n, _)| n.starts_with("layers.0."));
2991 let has_layer_1 = named.iter().any(|(n, _)| n.starts_with("layers.1."));
2992 let has_norm = named.iter().any(|(n, _)| n.starts_with("norm."));
2993 assert!(has_layer_0, "missing layers.0.* in named_parameters");
2994 assert!(has_layer_1, "missing layers.1.* in named_parameters");
2995 assert!(has_norm, "missing norm.* in named_parameters");
2996 }
2997
2998 #[test]
2999 fn test_encoder_train_eval() {
3000 let mut enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.1, 1e-5, false, false).unwrap();
3001 assert!(enc.is_training());
3002 enc.eval();
3003 assert!(!enc.is_training());
3004 enc.train();
3005 assert!(enc.is_training());
3006 }
3007
3008 #[test]
3009 fn test_encoder_is_send_sync() {
3010 fn assert_send_sync<T: Send + Sync>() {}
3011 assert_send_sync::<TransformerEncoder<f32>>();
3012 assert_send_sync::<TransformerEncoder<f64>>();
3013 }
3014
3015 #[test]
3020 fn test_decoder_construction() {
3021 let dec = TransformerDecoder::<f32>::new(16, 4, 3, 32, 0.0, 1e-5, true, true);
3022 assert!(dec.is_ok());
3023 assert_eq!(dec.unwrap().num_layers(), 3);
3024 }
3025
3026 #[test]
3027 fn test_decoder_zero_layers_rejected() {
3028 assert!(TransformerDecoder::<f32>::new(16, 4, 0, 32, 0.0, 1e-5, true, true).is_err());
3029 }
3030
3031 #[test]
3032 fn test_decoder_forward_with_memory_shape() {
3033 let dec = TransformerDecoder::<f32>::new(16, 4, 2, 32, 0.0, 1e-5, false, true).unwrap();
3034 let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
3035 let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
3036 let output = dec.forward_with_memory(&tgt, &memory).unwrap();
3037 assert_eq!(output.shape(), &[2, 4, 16]);
3038 }
3039
3040 #[test]
3041 fn test_decoder_forward_values_finite() {
3042 let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
3043 let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
3044 let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
3045 let output = dec.forward_with_memory(&tgt, &mem).unwrap();
3046 for &v in output.data().unwrap() {
3047 assert!(
3048 v.is_finite(),
3049 "TransformerDecoder produced non-finite value: {v}"
3050 );
3051 }
3052 }
3053
3054 #[test]
3055 fn test_decoder_parameters_with_final_norm() {
3056 let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
3057 assert_eq!(dec.parameters().len(), 58);
3061 }
3062
3063 #[test]
3064 fn test_decoder_named_parameters_have_layer_prefix() {
3065 let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
3066 let named = dec.named_parameters();
3067 let has_layer_0 = named.iter().any(|(n, _)| n.starts_with("layers.0."));
3068 let has_layer_1 = named.iter().any(|(n, _)| n.starts_with("layers.1."));
3069 let has_norm = named.iter().any(|(n, _)| n.starts_with("norm."));
3070 assert!(has_layer_0, "missing layers.0.* in named_parameters");
3071 assert!(has_layer_1, "missing layers.1.* in named_parameters");
3072 assert!(has_norm, "missing norm.* in named_parameters");
3073 }
3074
3075 #[test]
3076 fn test_decoder_train_eval() {
3077 let mut dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.1, 1e-5, false, false).unwrap();
3078 assert!(dec.is_training());
3079 dec.eval();
3080 assert!(!dec.is_training());
3081 dec.train();
3082 assert!(dec.is_training());
3083 }
3084
3085 #[test]
3086 fn test_decoder_is_send_sync() {
3087 fn assert_send_sync<T: Send + Sync>() {}
3088 assert_send_sync::<TransformerDecoder<f32>>();
3089 assert_send_sync::<TransformerDecoder<f64>>();
3090 }
3091
3092 #[test]
3097 fn test_transformer_construction() {
3098 let t = Transformer::<f32>::new(16, 4, 2, 2, 32, 0.0, 1e-5, true);
3099 assert!(t.is_ok());
3100 let t = t.unwrap();
3101 assert_eq!(t.num_encoder_layers(), 2);
3102 assert_eq!(t.num_decoder_layers(), 2);
3103 }
3104
3105 #[test]
3106 fn test_transformer_forward_shape() {
3107 let t = Transformer::<f32>::new(16, 4, 2, 2, 32, 0.0, 1e-5, false).unwrap();
3108 let src = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
3109 let tgt = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
3110 let output = t.forward_transformer(&src, &tgt).unwrap();
3111 assert_eq!(output.shape(), &[2, 5, 16]);
3112 }
3113
3114 #[test]
3115 fn test_transformer_self_forward_shape() {
3116 let t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.0, 1e-5, false).unwrap();
3118 let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
3119 let output = t.forward(&input).unwrap();
3120 assert_eq!(output.shape(), &[1, 3, 8]);
3121 }
3122
3123 #[test]
3124 fn test_transformer_forward_values_finite() {
3125 let t = Transformer::<f32>::new(8, 2, 2, 2, 16, 0.0, 1e-5, true).unwrap();
3126 let src = ferrotorch_core::ones::<f32>(&[1, 4, 8]).unwrap();
3127 let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
3128 let output = t.forward_transformer(&src, &tgt).unwrap();
3129 for &v in output.data().unwrap() {
3130 assert!(v.is_finite(), "Transformer produced non-finite value: {v}");
3131 }
3132 }
3133
3134 #[test]
3135 fn test_transformer_parameters_count() {
3136 let t = Transformer::<f32>::new(8, 2, 2, 2, 16, 0.0, 1e-5, true).unwrap();
3137 assert_eq!(t.parameters().len(), 96);
3141 }
3142
3143 #[test]
3144 fn test_transformer_named_parameters_prefixed() {
3145 let t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.0, 1e-5, true).unwrap();
3146 let named = t.named_parameters();
3147 let has_encoder = named.iter().any(|(n, _)| n.starts_with("encoder."));
3148 let has_decoder = named.iter().any(|(n, _)| n.starts_with("decoder."));
3149 assert!(has_encoder, "missing encoder.* in named_parameters");
3150 assert!(has_decoder, "missing decoder.* in named_parameters");
3151 }
3152
3153 #[test]
3154 fn test_transformer_train_eval() {
3155 let mut t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.1, 1e-5, false).unwrap();
3156 assert!(t.is_training());
3157 t.eval();
3158 assert!(!t.is_training());
3159 t.train();
3160 assert!(t.is_training());
3161 }
3162
3163 #[test]
3164 fn test_transformer_is_send_sync() {
3165 fn assert_send_sync<T: Send + Sync>() {}
3166 assert_send_sync::<Transformer<f32>>();
3167 assert_send_sync::<Transformer<f64>>();
3168 }
3169}