1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
52use ferrotorch_nn::module::{Module, StateDict};
53use ferrotorch_nn::parameter::Parameter;
54use ferrotorch_nn::{Conv2d, GELU, GroupNorm, LayerNorm, Linear};
55
/// Multi-head attention layer with separate query / key / value projections
/// and an output projection. Used both as self-attention (keys/values taken
/// from the queries' own input) and as cross-attention (keys/values taken from
/// a separate context of width `kv_dim`).
///
/// NOTE(review): parameter key names (`to_q`, `to_k`, `to_v`, `to_out.0`)
/// appear to follow Hugging Face diffusers' layout — confirm against the
/// checkpoints being loaded.
#[derive(Debug)]
pub struct Attention<T: Float> {
    /// Per-head feature width.
    pub dim_head: usize,
    /// Number of attention heads.
    pub heads: usize,
    /// `heads * dim_head`: output width of the q/k/v projections.
    pub inner_dim: usize,
    /// Query projection `query_dim -> inner_dim`.
    pub to_q: Linear<T>,
    /// Key projection `kv_dim -> inner_dim`.
    pub to_k: Linear<T>,
    /// Value projection `kv_dim -> inner_dim`.
    pub to_v: Linear<T>,
    /// Output projection `inner_dim -> query_dim` (constructed with bias).
    pub to_out_0: Linear<T>,
    // Required last dimension of the query input.
    query_dim: usize,
    // Required last dimension of the key/value source.
    kv_dim: usize,
    // `1 / sqrt(dim_head)`, multiplied into the attention scores.
    scale: f64,
    // Train/eval flag toggled via `Module::train` / `Module::eval`.
    training: bool,
}
90
91impl<T: Float> Attention<T> {
92 pub fn new(
106 query_dim: usize,
107 cross_attention_dim: Option<usize>,
108 heads: usize,
109 dim_head: usize,
110 bias: bool,
111 ) -> FerrotorchResult<Self> {
112 let inner_dim = heads * dim_head;
113 let kv_dim = cross_attention_dim.unwrap_or(query_dim);
114 let to_q = Linear::<T>::new(query_dim, inner_dim, bias)?;
115 let to_k = Linear::<T>::new(kv_dim, inner_dim, bias)?;
116 let to_v = Linear::<T>::new(kv_dim, inner_dim, bias)?;
117 let to_out_0 = Linear::<T>::new(inner_dim, query_dim, true)?;
118 let scale = (dim_head as f64).sqrt().recip();
119 Ok(Self {
120 dim_head,
121 heads,
122 inner_dim,
123 to_q,
124 to_k,
125 to_v,
126 to_out_0,
127 query_dim,
128 kv_dim,
129 scale,
130 training: false,
131 })
132 }
133
134 pub fn forward_xattn(
145 &self,
146 hidden_states: &Tensor<T>,
147 encoder_hidden_states: Option<&Tensor<T>>,
148 ) -> FerrotorchResult<Tensor<T>> {
149 if hidden_states.ndim() != 3 || hidden_states.shape()[2] != self.query_dim {
150 return Err(FerrotorchError::ShapeMismatch {
151 message: format!(
152 "Attention::forward_xattn: expected hidden_states [B, N, {}], got {:?}",
153 self.query_dim,
154 hidden_states.shape()
155 ),
156 });
157 }
158 let b = hidden_states.shape()[0];
159 let n = hidden_states.shape()[1];
160 let kv = encoder_hidden_states.unwrap_or(hidden_states);
162 if kv.ndim() != 3 || kv.shape()[0] != b || kv.shape()[2] != self.kv_dim {
163 return Err(FerrotorchError::ShapeMismatch {
164 message: format!(
165 "Attention::forward_xattn: expected kv [B={b}, S, {}], got {:?}",
166 self.kv_dim,
167 kv.shape()
168 ),
169 });
170 }
171 let s = kv.shape()[1];
172
173 let q = self.to_q.forward(hidden_states)?;
175 let k = self.to_k.forward(kv)?;
176 let v = self.to_v.forward(kv)?;
177
178 let h = self.heads;
182 let d = self.dim_head;
183 let q = q
184 .reshape_t(&[b as isize, n as isize, h as isize, d as isize])?
185 .transpose(1, 2)? .contiguous()?
187 .reshape_t(&[(b * h) as isize, n as isize, d as isize])?;
188 let k = k
189 .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
190 .transpose(1, 2)? .contiguous()?
192 .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
193 let v = v
194 .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
195 .transpose(1, 2)? .contiguous()?
197 .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
198
199 let k_t = k.transpose(1, 2)?.contiguous()?; let scores = q.bmm(&k_t)?; let scale_t = T::from(self.scale).ok_or_else(|| FerrotorchError::InvalidArgument {
203 message: "Attention::forward_xattn: failed to cast attention scale into Float".into(),
204 })?;
205 let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
206 let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
207 let probs = scores_scaled.softmax()?; let attended = probs.bmm(&v)?; let attended = attended
213 .reshape_t(&[b as isize, h as isize, n as isize, d as isize])?
214 .transpose(1, 2)? .contiguous()?
216 .reshape_t(&[b as isize, n as isize, self.inner_dim as isize])?;
217
218 self.to_out_0.forward(&attended)
220 }
221}
222
223impl<T: Float> Module<T> for Attention<T> {
224 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
225 self.forward_xattn(input, None)
227 }
228
229 fn parameters(&self) -> Vec<&Parameter<T>> {
230 let mut o = Vec::new();
231 o.extend(self.to_q.parameters());
232 o.extend(self.to_k.parameters());
233 o.extend(self.to_v.parameters());
234 o.extend(self.to_out_0.parameters());
235 o
236 }
237 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
238 let mut o = Vec::new();
239 o.extend(self.to_q.parameters_mut());
240 o.extend(self.to_k.parameters_mut());
241 o.extend(self.to_v.parameters_mut());
242 o.extend(self.to_out_0.parameters_mut());
243 o
244 }
245 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
246 let mut o = Vec::new();
247 for (n, p) in self.to_q.named_parameters() {
248 o.push((format!("to_q.{n}"), p));
249 }
250 for (n, p) in self.to_k.named_parameters() {
251 o.push((format!("to_k.{n}"), p));
252 }
253 for (n, p) in self.to_v.named_parameters() {
254 o.push((format!("to_v.{n}"), p));
255 }
256 for (n, p) in self.to_out_0.named_parameters() {
257 o.push((format!("to_out.0.{n}"), p));
258 }
259 o
260 }
261 fn train(&mut self) {
262 self.training = true;
263 }
264 fn eval(&mut self) {
265 self.training = false;
266 }
267 fn is_training(&self) -> bool {
268 self.training
269 }
270 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
271 let extract = |prefix: &str| -> StateDict<T> {
272 let p = format!("{prefix}.");
273 state
274 .iter()
275 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
276 .collect()
277 };
278 if strict {
279 for k in state.keys() {
280 let ok = k.starts_with("to_q.")
281 || k.starts_with("to_k.")
282 || k.starts_with("to_v.")
283 || k.starts_with("to_out.0.");
284 if !ok {
285 return Err(FerrotorchError::InvalidArgument {
286 message: format!("unexpected key in Attention state_dict: \"{k}\""),
287 });
288 }
289 }
290 }
291 self.to_q.load_state_dict(&extract("to_q"), strict)?;
292 self.to_k.load_state_dict(&extract("to_k"), strict)?;
293 self.to_v.load_state_dict(&extract("to_v"), strict)?;
294 self.to_out_0
295 .load_state_dict(&extract("to_out.0"), strict)?;
296 Ok(())
297 }
298}
299
/// Gated feed-forward network. The input projection produces `2 * dim_ff`
/// features, which are split into a value half and a gate half. The result
/// `value * GELU(gate)` is then projected back to `dim`.
#[derive(Debug)]
pub struct FeedForward<T: Float> {
    /// Input projection `dim -> 2 * dim_ff` (state-dict prefix `net.0.proj`).
    pub net_0_proj: Linear<T>,
    /// Output projection `dim_ff -> dim` (state-dict prefix `net.2`).
    pub net_2: Linear<T>,
    // Activation applied to the gate half.
    activation: GELU,
    // Hidden width `dim * mult`.
    dim_ff: usize,
    // Train/eval flag toggled via `Module::train` / `Module::eval`.
    training: bool,
}
335
336impl<T: Float> FeedForward<T> {
337 pub fn new(dim: usize, mult: usize) -> FerrotorchResult<Self> {
343 let dim_ff = dim * mult;
344 let net_0_proj = Linear::<T>::new(dim, 2 * dim_ff, true)?;
345 let net_2 = Linear::<T>::new(dim_ff, dim, true)?;
346 Ok(Self {
347 net_0_proj,
348 net_2,
349 activation: GELU::new(),
350 dim_ff,
351 training: false,
352 })
353 }
354}
355
356impl<T: Float> Module<T> for FeedForward<T> {
357 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
358 let proj = self.net_0_proj.forward(input)?;
360 let last = proj.ndim() - 1;
362 let parts = proj.chunk(2, last)?;
363 if parts.len() != 2 {
364 return Err(FerrotorchError::ShapeMismatch {
365 message: format!(
366 "FeedForward: chunk(2) returned {} parts (expected 2)",
367 parts.len()
368 ),
369 });
370 }
371 let x = parts[0].contiguous()?;
372 let gate = parts[1].contiguous()?;
373 let gated = self.activation.forward(&gate)?;
374 let activated = ferrotorch_core::grad_fns::arithmetic::mul(&x, &gated)?;
375 self.net_2.forward(&activated)
376 }
377 fn parameters(&self) -> Vec<&Parameter<T>> {
378 let mut o = Vec::new();
379 o.extend(self.net_0_proj.parameters());
380 o.extend(self.net_2.parameters());
381 o
382 }
383 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
384 let mut o = Vec::new();
385 o.extend(self.net_0_proj.parameters_mut());
386 o.extend(self.net_2.parameters_mut());
387 o
388 }
389 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
390 let mut o = Vec::new();
391 for (n, p) in self.net_0_proj.named_parameters() {
392 o.push((format!("net.0.proj.{n}"), p));
393 }
394 for (n, p) in self.net_2.named_parameters() {
395 o.push((format!("net.2.{n}"), p));
396 }
397 o
398 }
399 fn train(&mut self) {
400 self.training = true;
401 }
402 fn eval(&mut self) {
403 self.training = false;
404 }
405 fn is_training(&self) -> bool {
406 self.training
407 }
408 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
409 let extract = |prefix: &str| -> StateDict<T> {
410 let p = format!("{prefix}.");
411 state
412 .iter()
413 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
414 .collect()
415 };
416 if strict {
417 for k in state.keys() {
418 let ok = k.starts_with("net.0.proj.") || k.starts_with("net.2.");
419 if !ok {
420 return Err(FerrotorchError::InvalidArgument {
421 message: format!("unexpected key in FeedForward state_dict: \"{k}\""),
422 });
423 }
424 }
425 }
426 self.net_0_proj
427 .load_state_dict(&extract("net.0.proj"), strict)?;
428 self.net_2.load_state_dict(&extract("net.2"), strict)?;
429 let _ = self.dim_ff;
430 Ok(())
431 }
432}
433
/// Pre-norm transformer block over `[B, N, dim]` tokens. It has three
/// sub-layers, each preceded by a LayerNorm and wrapped in a residual add:
/// self-attention, cross-attention against an external context, and a gated
/// feed-forward.
#[derive(Debug)]
pub struct BasicTransformerBlock<T: Float> {
    /// LayerNorm applied before self-attention.
    pub norm1: LayerNorm<T>,
    /// Self-attention over the token sequence.
    pub attn1: Attention<T>,
    /// LayerNorm applied before cross-attention.
    pub norm2: LayerNorm<T>,
    /// Cross-attention against `encoder_hidden_states`.
    pub attn2: Attention<T>,
    /// LayerNorm applied before the feed-forward.
    pub norm3: LayerNorm<T>,
    /// Gated feed-forward (hidden width `4 * dim`).
    pub ff: FeedForward<T>,
    // Token width; `forward_xattn` rejects inputs whose last dimension differs.
    dim: usize,
    // Train/eval flag toggled via `Module::train` / `Module::eval`.
    training: bool,
}
469
470impl<T: Float> BasicTransformerBlock<T> {
471 pub fn new(
477 dim: usize,
478 heads: usize,
479 dim_head: usize,
480 cross_attention_dim: usize,
481 ) -> FerrotorchResult<Self> {
482 let norm1 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
486 let attn1 = Attention::<T>::new(dim, None, heads, dim_head, false)?;
487 let norm2 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
488 let attn2 = Attention::<T>::new(dim, Some(cross_attention_dim), heads, dim_head, false)?;
489 let norm3 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
490 let ff = FeedForward::<T>::new(dim, 4)?;
491 Ok(Self {
492 norm1,
493 attn1,
494 norm2,
495 attn2,
496 norm3,
497 ff,
498 dim,
499 training: false,
500 })
501 }
502
503 pub fn forward_xattn(
511 &self,
512 x: &Tensor<T>,
513 encoder_hidden_states: &Tensor<T>,
514 ) -> FerrotorchResult<Tensor<T>> {
515 if x.ndim() != 3 || x.shape()[2] != self.dim {
516 return Err(FerrotorchError::ShapeMismatch {
517 message: format!(
518 "BasicTransformerBlock::forward: expected x [B, N, {}], got {:?}",
519 self.dim,
520 x.shape()
521 ),
522 });
523 }
524 let h1 = self.norm1.forward(x)?;
526 let h1 = self.attn1.forward_xattn(&h1, None)?;
527 let x = ferrotorch_core::grad_fns::arithmetic::add(&h1, x)?;
528 let h2 = self.norm2.forward(&x)?;
530 let h2 = self.attn2.forward_xattn(&h2, Some(encoder_hidden_states))?;
531 let x = ferrotorch_core::grad_fns::arithmetic::add(&h2, &x)?;
532 let h3 = self.norm3.forward(&x)?;
534 let h3 = self.ff.forward(&h3)?;
535 ferrotorch_core::grad_fns::arithmetic::add(&h3, &x)
536 }
537}
538
539impl<T: Float> Module<T> for BasicTransformerBlock<T> {
540 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
541 Err(FerrotorchError::InvalidArgument {
542 message: "BasicTransformerBlock::forward: cross-attn requires \
543 encoder_hidden_states — call forward_xattn instead"
544 .into(),
545 })
546 }
547
548 fn parameters(&self) -> Vec<&Parameter<T>> {
549 let mut o = Vec::new();
550 o.extend(self.norm1.parameters());
551 o.extend(self.attn1.parameters());
552 o.extend(self.norm2.parameters());
553 o.extend(self.attn2.parameters());
554 o.extend(self.norm3.parameters());
555 o.extend(self.ff.parameters());
556 o
557 }
558 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
559 let mut o = Vec::new();
560 o.extend(self.norm1.parameters_mut());
561 o.extend(self.attn1.parameters_mut());
562 o.extend(self.norm2.parameters_mut());
563 o.extend(self.attn2.parameters_mut());
564 o.extend(self.norm3.parameters_mut());
565 o.extend(self.ff.parameters_mut());
566 o
567 }
568 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
569 let mut o = Vec::new();
570 for (n, p) in self.norm1.named_parameters() {
571 o.push((format!("norm1.{n}"), p));
572 }
573 for (n, p) in self.attn1.named_parameters() {
574 o.push((format!("attn1.{n}"), p));
575 }
576 for (n, p) in self.norm2.named_parameters() {
577 o.push((format!("norm2.{n}"), p));
578 }
579 for (n, p) in self.attn2.named_parameters() {
580 o.push((format!("attn2.{n}"), p));
581 }
582 for (n, p) in self.norm3.named_parameters() {
583 o.push((format!("norm3.{n}"), p));
584 }
585 for (n, p) in self.ff.named_parameters() {
586 o.push((format!("ff.{n}"), p));
587 }
588 o
589 }
590 fn train(&mut self) {
591 self.training = true;
592 }
593 fn eval(&mut self) {
594 self.training = false;
595 }
596 fn is_training(&self) -> bool {
597 self.training
598 }
599 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
600 let extract = |prefix: &str| -> StateDict<T> {
601 let p = format!("{prefix}.");
602 state
603 .iter()
604 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
605 .collect()
606 };
607 if strict {
608 for k in state.keys() {
609 let ok = k.starts_with("norm1.")
610 || k.starts_with("attn1.")
611 || k.starts_with("norm2.")
612 || k.starts_with("attn2.")
613 || k.starts_with("norm3.")
614 || k.starts_with("ff.");
615 if !ok {
616 return Err(FerrotorchError::InvalidArgument {
617 message: format!(
618 "unexpected key in BasicTransformerBlock state_dict: \"{k}\""
619 ),
620 });
621 }
622 }
623 }
624 self.norm1.load_state_dict(&extract("norm1"), strict)?;
625 self.attn1.load_state_dict(&extract("attn1"), strict)?;
626 self.norm2.load_state_dict(&extract("norm2"), strict)?;
627 self.attn2.load_state_dict(&extract("attn2"), strict)?;
628 self.norm3.load_state_dict(&extract("norm3"), strict)?;
629 self.ff.load_state_dict(&extract("ff"), strict)?;
630 Ok(())
631 }
632}
633
/// Spatial transformer over `[B, C, H, W]` feature maps. The stages are:
/// GroupNorm, then `proj_in`, then a stack of `BasicTransformerBlock`s over the
/// `H * W` flattened pixel tokens, then `proj_out`. The input is added back as
/// a residual.
#[derive(Debug)]
pub struct Transformer2DModel<T: Float> {
    /// Group normalisation over the input channels.
    pub norm: GroupNorm<T>,
    /// Projection `in_channels -> inner_dim` (a `Conv2d` built with
    /// `(1, 1), (1, 1), (0, 0)` — presumably a 1×1 conv).
    pub proj_in: Conv2d<T>,
    /// `num_layers` transformer blocks operating at width `inner_dim`.
    pub transformer_blocks: Vec<BasicTransformerBlock<T>>,
    /// Projection `inner_dim -> in_channels` (same `Conv2d` geometry as `proj_in`).
    pub proj_out: Conv2d<T>,
    // Required channel count of the input (`in_channels`).
    channels: usize,
    // `heads * dim_head`: token width inside the transformer blocks.
    inner_dim: usize,
    // Train/eval flag toggled via `Module::train` / `Module::eval`.
    training: bool,
}
669
670impl<T: Float> Transformer2DModel<T> {
671 pub fn new(
681 in_channels: usize,
682 heads: usize,
683 dim_head: usize,
684 num_layers: usize,
685 cross_attention_dim: usize,
686 norm_num_groups: usize,
687 ) -> FerrotorchResult<Self> {
688 let inner_dim = heads * dim_head;
689 let norm = GroupNorm::<T>::new(norm_num_groups, in_channels, 1e-6, true)?;
690 let proj_in = Conv2d::<T>::new(in_channels, inner_dim, (1, 1), (1, 1), (0, 0), true)?;
691 let proj_out = Conv2d::<T>::new(inner_dim, in_channels, (1, 1), (1, 1), (0, 0), true)?;
692 let mut transformer_blocks = Vec::with_capacity(num_layers);
693 for _ in 0..num_layers {
694 transformer_blocks.push(BasicTransformerBlock::<T>::new(
695 inner_dim,
696 heads,
697 dim_head,
698 cross_attention_dim,
699 )?);
700 }
701 Ok(Self {
702 norm,
703 proj_in,
704 transformer_blocks,
705 proj_out,
706 channels: in_channels,
707 inner_dim,
708 training: false,
709 })
710 }
711
712 pub fn forward_xattn(
721 &self,
722 x: &Tensor<T>,
723 encoder_hidden_states: &Tensor<T>,
724 ) -> FerrotorchResult<Tensor<T>> {
725 if x.ndim() != 4 || x.shape()[1] != self.channels {
726 return Err(FerrotorchError::ShapeMismatch {
727 message: format!(
728 "Transformer2DModel::forward: expected [B, {}, H, W], got {:?}",
729 self.channels,
730 x.shape()
731 ),
732 });
733 }
734 let b = x.shape()[0];
735 let c = x.shape()[1];
736 let h = x.shape()[2];
737 let w = x.shape()[3];
738 let hw = h * w;
739
740 let residual = x.clone();
741 let mut hidden = self.norm.forward(x)?;
743 hidden = self.proj_in.forward(&hidden)?;
744 let mut hidden_seq = hidden
746 .reshape_t(&[b as isize, self.inner_dim as isize, hw as isize])?
747 .transpose(1, 2)?
748 .contiguous()?;
749 for block in &self.transformer_blocks {
751 hidden_seq = block.forward_xattn(&hidden_seq, encoder_hidden_states)?;
752 }
753 let hidden_back = hidden_seq
755 .transpose(1, 2)?
756 .reshape_t(&[b as isize, self.inner_dim as isize, h as isize, w as isize])?
757 .contiguous()?;
758 let out = self.proj_out.forward(&hidden_back)?;
760 let _ = c;
761 ferrotorch_core::grad_fns::arithmetic::add(&out, &residual)
762 }
763}
764
765impl<T: Float> Module<T> for Transformer2DModel<T> {
766 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
767 Err(FerrotorchError::InvalidArgument {
768 message: "Transformer2DModel::forward: cross-attn requires \
769 encoder_hidden_states — call forward_xattn instead"
770 .into(),
771 })
772 }
773
774 fn parameters(&self) -> Vec<&Parameter<T>> {
775 let mut o = Vec::new();
776 o.extend(self.norm.parameters());
777 o.extend(self.proj_in.parameters());
778 for b in &self.transformer_blocks {
779 o.extend(b.parameters());
780 }
781 o.extend(self.proj_out.parameters());
782 o
783 }
784 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
785 let mut o = Vec::new();
786 o.extend(self.norm.parameters_mut());
787 o.extend(self.proj_in.parameters_mut());
788 for b in &mut self.transformer_blocks {
789 o.extend(b.parameters_mut());
790 }
791 o.extend(self.proj_out.parameters_mut());
792 o
793 }
794 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
795 let mut o = Vec::new();
796 for (n, p) in self.norm.named_parameters() {
797 o.push((format!("norm.{n}"), p));
798 }
799 for (n, p) in self.proj_in.named_parameters() {
800 o.push((format!("proj_in.{n}"), p));
801 }
802 for (i, b) in self.transformer_blocks.iter().enumerate() {
803 for (n, p) in b.named_parameters() {
804 o.push((format!("transformer_blocks.{i}.{n}"), p));
805 }
806 }
807 for (n, p) in self.proj_out.named_parameters() {
808 o.push((format!("proj_out.{n}"), p));
809 }
810 o
811 }
812 fn train(&mut self) {
813 self.training = true;
814 }
815 fn eval(&mut self) {
816 self.training = false;
817 }
818 fn is_training(&self) -> bool {
819 self.training
820 }
821 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
822 let extract = |prefix: &str| -> StateDict<T> {
823 let p = format!("{prefix}.");
824 state
825 .iter()
826 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
827 .collect()
828 };
829 if strict {
830 for k in state.keys() {
831 let ok = k.starts_with("norm.")
832 || k.starts_with("proj_in.")
833 || k.starts_with("transformer_blocks.")
834 || k.starts_with("proj_out.");
835 if !ok {
836 return Err(FerrotorchError::InvalidArgument {
837 message: format!(
838 "unexpected key in Transformer2DModel state_dict: \"{k}\""
839 ),
840 });
841 }
842 }
843 }
844 self.norm.load_state_dict(&extract("norm"), strict)?;
845 self.proj_in.load_state_dict(&extract("proj_in"), strict)?;
846 for (i, b) in self.transformer_blocks.iter_mut().enumerate() {
847 b.load_state_dict(&extract(&format!("transformer_blocks.{i}")), strict)?;
848 }
849 self.proj_out
850 .load_state_dict(&extract("proj_out"), strict)?;
851 Ok(())
852 }
853}
854
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// CPU `f32` tensor of `shape`, filled with `0.01`, without grad tracking.
    fn filled(shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![0.01f32; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Just the names from a module's `named_parameters`.
    fn names_of<M: Module<f32>>(module: &M) -> Vec<String> {
        module
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect()
    }

    #[test]
    fn attention_self_shape() {
        let attn = Attention::<f32>::new(16, None, 4, 4, false).unwrap();
        let out = attn.forward_xattn(&filled(&[1, 5, 16]), None).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    #[test]
    fn attention_cross_shape() {
        let attn = Attention::<f32>::new(16, Some(24), 4, 4, false).unwrap();
        let context = filled(&[1, 7, 24]);
        let out = attn
            .forward_xattn(&filled(&[1, 5, 16]), Some(&context))
            .unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    #[test]
    fn feedforward_shape_and_keys() {
        let ff = FeedForward::<f32>::new(16, 2).unwrap();
        let out = ff.forward(&filled(&[1, 5, 16])).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);

        let names = names_of(&ff);
        for k in [
            "net.0.proj.weight",
            "net.0.proj.bias",
            "net.2.weight",
            "net.2.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn basic_transformer_block_shape() {
        let block = BasicTransformerBlock::<f32>::new(16, 4, 4, 24).unwrap();
        let out = block
            .forward_xattn(&filled(&[1, 5, 16]), &filled(&[1, 7, 24]))
            .unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    #[test]
    fn transformer_2d_shape() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let out = model
            .forward_xattn(&filled(&[1, 16, 3, 3]), &filled(&[1, 5, 24]))
            .unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn transformer_2d_named_parameters() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let names = names_of(&model);
        for k in [
            "norm.weight",
            "proj_in.weight",
            "proj_in.bias",
            "transformer_blocks.0.norm1.weight",
            "transformer_blocks.0.attn1.to_q.weight",
            "transformer_blocks.0.attn2.to_k.weight",
            "transformer_blocks.0.ff.net.0.proj.weight",
            "transformer_blocks.0.ff.net.2.weight",
            "proj_out.weight",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}