1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
41use ferrotorch_nn::module::{Module, StateDict};
42use ferrotorch_nn::parameter::Parameter;
43use ferrotorch_nn::{Conv2d, GELU, GroupNorm, LayerNorm, Linear};
44
45#[derive(Debug)]
59pub struct Attention<T: Float> {
60 pub dim_head: usize,
62 pub heads: usize,
64 pub inner_dim: usize,
66 pub to_q: Linear<T>,
68 pub to_k: Linear<T>,
70 pub to_v: Linear<T>,
72 pub to_out_0: Linear<T>,
74 query_dim: usize,
75 kv_dim: usize,
76 scale: f64,
77 training: bool,
78}
79
80impl<T: Float> Attention<T> {
81 pub fn new(
95 query_dim: usize,
96 cross_attention_dim: Option<usize>,
97 heads: usize,
98 dim_head: usize,
99 bias: bool,
100 ) -> FerrotorchResult<Self> {
101 let inner_dim = heads * dim_head;
102 let kv_dim = cross_attention_dim.unwrap_or(query_dim);
103 let to_q = Linear::<T>::new(query_dim, inner_dim, bias)?;
104 let to_k = Linear::<T>::new(kv_dim, inner_dim, bias)?;
105 let to_v = Linear::<T>::new(kv_dim, inner_dim, bias)?;
106 let to_out_0 = Linear::<T>::new(inner_dim, query_dim, true)?;
107 let scale = (dim_head as f64).sqrt().recip();
108 Ok(Self {
109 dim_head,
110 heads,
111 inner_dim,
112 to_q,
113 to_k,
114 to_v,
115 to_out_0,
116 query_dim,
117 kv_dim,
118 scale,
119 training: false,
120 })
121 }
122
123 pub fn forward_xattn(
134 &self,
135 hidden_states: &Tensor<T>,
136 encoder_hidden_states: Option<&Tensor<T>>,
137 ) -> FerrotorchResult<Tensor<T>> {
138 if hidden_states.ndim() != 3 || hidden_states.shape()[2] != self.query_dim {
139 return Err(FerrotorchError::ShapeMismatch {
140 message: format!(
141 "Attention::forward_xattn: expected hidden_states [B, N, {}], got {:?}",
142 self.query_dim,
143 hidden_states.shape()
144 ),
145 });
146 }
147 let b = hidden_states.shape()[0];
148 let n = hidden_states.shape()[1];
149 let kv = encoder_hidden_states.unwrap_or(hidden_states);
151 if kv.ndim() != 3 || kv.shape()[0] != b || kv.shape()[2] != self.kv_dim {
152 return Err(FerrotorchError::ShapeMismatch {
153 message: format!(
154 "Attention::forward_xattn: expected kv [B={b}, S, {}], got {:?}",
155 self.kv_dim,
156 kv.shape()
157 ),
158 });
159 }
160 let s = kv.shape()[1];
161
162 let q = self.to_q.forward(hidden_states)?;
164 let k = self.to_k.forward(kv)?;
165 let v = self.to_v.forward(kv)?;
166
167 let h = self.heads;
171 let d = self.dim_head;
172 let q = q
173 .reshape_t(&[b as isize, n as isize, h as isize, d as isize])?
174 .transpose(1, 2)? .contiguous()?
176 .reshape_t(&[(b * h) as isize, n as isize, d as isize])?;
177 let k = k
178 .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
179 .transpose(1, 2)? .contiguous()?
181 .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
182 let v = v
183 .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
184 .transpose(1, 2)? .contiguous()?
186 .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
187
188 let k_t = k.transpose(1, 2)?.contiguous()?; let scores = q.bmm(&k_t)?; let scale_t = T::from(self.scale).ok_or_else(|| FerrotorchError::InvalidArgument {
192 message: "Attention::forward_xattn: failed to cast attention scale into Float".into(),
193 })?;
194 let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
195 let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
196 let probs = scores_scaled.softmax()?; let attended = probs.bmm(&v)?; let attended = attended
202 .reshape_t(&[b as isize, h as isize, n as isize, d as isize])?
203 .transpose(1, 2)? .contiguous()?
205 .reshape_t(&[b as isize, n as isize, self.inner_dim as isize])?;
206
207 self.to_out_0.forward(&attended)
209 }
210}
211
212impl<T: Float> Module<T> for Attention<T> {
213 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
214 self.forward_xattn(input, None)
216 }
217
218 fn parameters(&self) -> Vec<&Parameter<T>> {
219 let mut o = Vec::new();
220 o.extend(self.to_q.parameters());
221 o.extend(self.to_k.parameters());
222 o.extend(self.to_v.parameters());
223 o.extend(self.to_out_0.parameters());
224 o
225 }
226 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
227 let mut o = Vec::new();
228 o.extend(self.to_q.parameters_mut());
229 o.extend(self.to_k.parameters_mut());
230 o.extend(self.to_v.parameters_mut());
231 o.extend(self.to_out_0.parameters_mut());
232 o
233 }
234 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
235 let mut o = Vec::new();
236 for (n, p) in self.to_q.named_parameters() {
237 o.push((format!("to_q.{n}"), p));
238 }
239 for (n, p) in self.to_k.named_parameters() {
240 o.push((format!("to_k.{n}"), p));
241 }
242 for (n, p) in self.to_v.named_parameters() {
243 o.push((format!("to_v.{n}"), p));
244 }
245 for (n, p) in self.to_out_0.named_parameters() {
246 o.push((format!("to_out.0.{n}"), p));
247 }
248 o
249 }
250 fn train(&mut self) {
251 self.training = true;
252 }
253 fn eval(&mut self) {
254 self.training = false;
255 }
256 fn is_training(&self) -> bool {
257 self.training
258 }
259 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
260 let extract = |prefix: &str| -> StateDict<T> {
261 let p = format!("{prefix}.");
262 state
263 .iter()
264 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
265 .collect()
266 };
267 if strict {
268 for k in state.keys() {
269 let ok = k.starts_with("to_q.")
270 || k.starts_with("to_k.")
271 || k.starts_with("to_v.")
272 || k.starts_with("to_out.0.");
273 if !ok {
274 return Err(FerrotorchError::InvalidArgument {
275 message: format!("unexpected key in Attention state_dict: \"{k}\""),
276 });
277 }
278 }
279 }
280 self.to_q.load_state_dict(&extract("to_q"), strict)?;
281 self.to_k.load_state_dict(&extract("to_k"), strict)?;
282 self.to_v.load_state_dict(&extract("to_v"), strict)?;
283 self.to_out_0
284 .load_state_dict(&extract("to_out.0"), strict)?;
285 Ok(())
286 }
287}
288
289#[derive(Debug)]
315pub struct FeedForward<T: Float> {
316 pub net_0_proj: Linear<T>,
318 pub net_2: Linear<T>,
320 activation: GELU,
321 dim_ff: usize,
322 training: bool,
323}
324
325impl<T: Float> FeedForward<T> {
326 pub fn new(dim: usize, mult: usize) -> FerrotorchResult<Self> {
332 let dim_ff = dim * mult;
333 let net_0_proj = Linear::<T>::new(dim, 2 * dim_ff, true)?;
334 let net_2 = Linear::<T>::new(dim_ff, dim, true)?;
335 Ok(Self {
336 net_0_proj,
337 net_2,
338 activation: GELU::new(),
339 dim_ff,
340 training: false,
341 })
342 }
343}
344
345impl<T: Float> Module<T> for FeedForward<T> {
346 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
347 let proj = self.net_0_proj.forward(input)?;
349 let last = proj.ndim() - 1;
351 let parts = proj.chunk(2, last)?;
352 if parts.len() != 2 {
353 return Err(FerrotorchError::ShapeMismatch {
354 message: format!(
355 "FeedForward: chunk(2) returned {} parts (expected 2)",
356 parts.len()
357 ),
358 });
359 }
360 let x = parts[0].contiguous()?;
361 let gate = parts[1].contiguous()?;
362 let gated = self.activation.forward(&gate)?;
363 let activated = ferrotorch_core::grad_fns::arithmetic::mul(&x, &gated)?;
364 self.net_2.forward(&activated)
365 }
366 fn parameters(&self) -> Vec<&Parameter<T>> {
367 let mut o = Vec::new();
368 o.extend(self.net_0_proj.parameters());
369 o.extend(self.net_2.parameters());
370 o
371 }
372 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
373 let mut o = Vec::new();
374 o.extend(self.net_0_proj.parameters_mut());
375 o.extend(self.net_2.parameters_mut());
376 o
377 }
378 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
379 let mut o = Vec::new();
380 for (n, p) in self.net_0_proj.named_parameters() {
381 o.push((format!("net.0.proj.{n}"), p));
382 }
383 for (n, p) in self.net_2.named_parameters() {
384 o.push((format!("net.2.{n}"), p));
385 }
386 o
387 }
388 fn train(&mut self) {
389 self.training = true;
390 }
391 fn eval(&mut self) {
392 self.training = false;
393 }
394 fn is_training(&self) -> bool {
395 self.training
396 }
397 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
398 let extract = |prefix: &str| -> StateDict<T> {
399 let p = format!("{prefix}.");
400 state
401 .iter()
402 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
403 .collect()
404 };
405 if strict {
406 for k in state.keys() {
407 let ok = k.starts_with("net.0.proj.") || k.starts_with("net.2.");
408 if !ok {
409 return Err(FerrotorchError::InvalidArgument {
410 message: format!("unexpected key in FeedForward state_dict: \"{k}\""),
411 });
412 }
413 }
414 }
415 self.net_0_proj
416 .load_state_dict(&extract("net.0.proj"), strict)?;
417 self.net_2.load_state_dict(&extract("net.2"), strict)?;
418 let _ = self.dim_ff;
419 Ok(())
420 }
421}
422
423#[derive(Debug)]
442pub struct BasicTransformerBlock<T: Float> {
443 pub norm1: LayerNorm<T>,
445 pub attn1: Attention<T>,
447 pub norm2: LayerNorm<T>,
449 pub attn2: Attention<T>,
451 pub norm3: LayerNorm<T>,
453 pub ff: FeedForward<T>,
455 dim: usize,
456 training: bool,
457}
458
459impl<T: Float> BasicTransformerBlock<T> {
460 pub fn new(
466 dim: usize,
467 heads: usize,
468 dim_head: usize,
469 cross_attention_dim: usize,
470 ) -> FerrotorchResult<Self> {
471 let norm1 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
475 let attn1 = Attention::<T>::new(dim, None, heads, dim_head, false)?;
476 let norm2 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
477 let attn2 = Attention::<T>::new(dim, Some(cross_attention_dim), heads, dim_head, false)?;
478 let norm3 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
479 let ff = FeedForward::<T>::new(dim, 4)?;
480 Ok(Self {
481 norm1,
482 attn1,
483 norm2,
484 attn2,
485 norm3,
486 ff,
487 dim,
488 training: false,
489 })
490 }
491
492 pub fn forward_xattn(
500 &self,
501 x: &Tensor<T>,
502 encoder_hidden_states: &Tensor<T>,
503 ) -> FerrotorchResult<Tensor<T>> {
504 if x.ndim() != 3 || x.shape()[2] != self.dim {
505 return Err(FerrotorchError::ShapeMismatch {
506 message: format!(
507 "BasicTransformerBlock::forward: expected x [B, N, {}], got {:?}",
508 self.dim,
509 x.shape()
510 ),
511 });
512 }
513 let h1 = self.norm1.forward(x)?;
515 let h1 = self.attn1.forward_xattn(&h1, None)?;
516 let x = ferrotorch_core::grad_fns::arithmetic::add(&h1, x)?;
517 let h2 = self.norm2.forward(&x)?;
519 let h2 = self.attn2.forward_xattn(&h2, Some(encoder_hidden_states))?;
520 let x = ferrotorch_core::grad_fns::arithmetic::add(&h2, &x)?;
521 let h3 = self.norm3.forward(&x)?;
523 let h3 = self.ff.forward(&h3)?;
524 ferrotorch_core::grad_fns::arithmetic::add(&h3, &x)
525 }
526}
527
528impl<T: Float> Module<T> for BasicTransformerBlock<T> {
529 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
530 Err(FerrotorchError::InvalidArgument {
531 message: "BasicTransformerBlock::forward: cross-attn requires \
532 encoder_hidden_states — call forward_xattn instead"
533 .into(),
534 })
535 }
536
537 fn parameters(&self) -> Vec<&Parameter<T>> {
538 let mut o = Vec::new();
539 o.extend(self.norm1.parameters());
540 o.extend(self.attn1.parameters());
541 o.extend(self.norm2.parameters());
542 o.extend(self.attn2.parameters());
543 o.extend(self.norm3.parameters());
544 o.extend(self.ff.parameters());
545 o
546 }
547 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
548 let mut o = Vec::new();
549 o.extend(self.norm1.parameters_mut());
550 o.extend(self.attn1.parameters_mut());
551 o.extend(self.norm2.parameters_mut());
552 o.extend(self.attn2.parameters_mut());
553 o.extend(self.norm3.parameters_mut());
554 o.extend(self.ff.parameters_mut());
555 o
556 }
557 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
558 let mut o = Vec::new();
559 for (n, p) in self.norm1.named_parameters() {
560 o.push((format!("norm1.{n}"), p));
561 }
562 for (n, p) in self.attn1.named_parameters() {
563 o.push((format!("attn1.{n}"), p));
564 }
565 for (n, p) in self.norm2.named_parameters() {
566 o.push((format!("norm2.{n}"), p));
567 }
568 for (n, p) in self.attn2.named_parameters() {
569 o.push((format!("attn2.{n}"), p));
570 }
571 for (n, p) in self.norm3.named_parameters() {
572 o.push((format!("norm3.{n}"), p));
573 }
574 for (n, p) in self.ff.named_parameters() {
575 o.push((format!("ff.{n}"), p));
576 }
577 o
578 }
579 fn train(&mut self) {
580 self.training = true;
581 }
582 fn eval(&mut self) {
583 self.training = false;
584 }
585 fn is_training(&self) -> bool {
586 self.training
587 }
588 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
589 let extract = |prefix: &str| -> StateDict<T> {
590 let p = format!("{prefix}.");
591 state
592 .iter()
593 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
594 .collect()
595 };
596 if strict {
597 for k in state.keys() {
598 let ok = k.starts_with("norm1.")
599 || k.starts_with("attn1.")
600 || k.starts_with("norm2.")
601 || k.starts_with("attn2.")
602 || k.starts_with("norm3.")
603 || k.starts_with("ff.");
604 if !ok {
605 return Err(FerrotorchError::InvalidArgument {
606 message: format!(
607 "unexpected key in BasicTransformerBlock state_dict: \"{k}\""
608 ),
609 });
610 }
611 }
612 }
613 self.norm1.load_state_dict(&extract("norm1"), strict)?;
614 self.attn1.load_state_dict(&extract("attn1"), strict)?;
615 self.norm2.load_state_dict(&extract("norm2"), strict)?;
616 self.attn2.load_state_dict(&extract("attn2"), strict)?;
617 self.norm3.load_state_dict(&extract("norm3"), strict)?;
618 self.ff.load_state_dict(&extract("ff"), strict)?;
619 Ok(())
620 }
621}
622
623#[derive(Debug)]
645pub struct Transformer2DModel<T: Float> {
646 pub norm: GroupNorm<T>,
648 pub proj_in: Conv2d<T>,
650 pub transformer_blocks: Vec<BasicTransformerBlock<T>>,
652 pub proj_out: Conv2d<T>,
654 channels: usize,
655 inner_dim: usize,
656 training: bool,
657}
658
659impl<T: Float> Transformer2DModel<T> {
660 pub fn new(
670 in_channels: usize,
671 heads: usize,
672 dim_head: usize,
673 num_layers: usize,
674 cross_attention_dim: usize,
675 norm_num_groups: usize,
676 ) -> FerrotorchResult<Self> {
677 let inner_dim = heads * dim_head;
678 let norm = GroupNorm::<T>::new(norm_num_groups, in_channels, 1e-6, true)?;
679 let proj_in = Conv2d::<T>::new(in_channels, inner_dim, (1, 1), (1, 1), (0, 0), true)?;
680 let proj_out = Conv2d::<T>::new(inner_dim, in_channels, (1, 1), (1, 1), (0, 0), true)?;
681 let mut transformer_blocks = Vec::with_capacity(num_layers);
682 for _ in 0..num_layers {
683 transformer_blocks.push(BasicTransformerBlock::<T>::new(
684 inner_dim,
685 heads,
686 dim_head,
687 cross_attention_dim,
688 )?);
689 }
690 Ok(Self {
691 norm,
692 proj_in,
693 transformer_blocks,
694 proj_out,
695 channels: in_channels,
696 inner_dim,
697 training: false,
698 })
699 }
700
701 pub fn forward_xattn(
710 &self,
711 x: &Tensor<T>,
712 encoder_hidden_states: &Tensor<T>,
713 ) -> FerrotorchResult<Tensor<T>> {
714 if x.ndim() != 4 || x.shape()[1] != self.channels {
715 return Err(FerrotorchError::ShapeMismatch {
716 message: format!(
717 "Transformer2DModel::forward: expected [B, {}, H, W], got {:?}",
718 self.channels,
719 x.shape()
720 ),
721 });
722 }
723 let b = x.shape()[0];
724 let c = x.shape()[1];
725 let h = x.shape()[2];
726 let w = x.shape()[3];
727 let hw = h * w;
728
729 let residual = x.clone();
730 let mut hidden = self.norm.forward(x)?;
732 hidden = self.proj_in.forward(&hidden)?;
733 let mut hidden_seq = hidden
735 .reshape_t(&[b as isize, self.inner_dim as isize, hw as isize])?
736 .transpose(1, 2)?
737 .contiguous()?;
738 for block in &self.transformer_blocks {
740 hidden_seq = block.forward_xattn(&hidden_seq, encoder_hidden_states)?;
741 }
742 let hidden_back = hidden_seq
744 .transpose(1, 2)?
745 .reshape_t(&[b as isize, self.inner_dim as isize, h as isize, w as isize])?
746 .contiguous()?;
747 let out = self.proj_out.forward(&hidden_back)?;
749 let _ = c;
750 ferrotorch_core::grad_fns::arithmetic::add(&out, &residual)
751 }
752}
753
754impl<T: Float> Module<T> for Transformer2DModel<T> {
755 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
756 Err(FerrotorchError::InvalidArgument {
757 message: "Transformer2DModel::forward: cross-attn requires \
758 encoder_hidden_states — call forward_xattn instead"
759 .into(),
760 })
761 }
762
763 fn parameters(&self) -> Vec<&Parameter<T>> {
764 let mut o = Vec::new();
765 o.extend(self.norm.parameters());
766 o.extend(self.proj_in.parameters());
767 for b in &self.transformer_blocks {
768 o.extend(b.parameters());
769 }
770 o.extend(self.proj_out.parameters());
771 o
772 }
773 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
774 let mut o = Vec::new();
775 o.extend(self.norm.parameters_mut());
776 o.extend(self.proj_in.parameters_mut());
777 for b in &mut self.transformer_blocks {
778 o.extend(b.parameters_mut());
779 }
780 o.extend(self.proj_out.parameters_mut());
781 o
782 }
783 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
784 let mut o = Vec::new();
785 for (n, p) in self.norm.named_parameters() {
786 o.push((format!("norm.{n}"), p));
787 }
788 for (n, p) in self.proj_in.named_parameters() {
789 o.push((format!("proj_in.{n}"), p));
790 }
791 for (i, b) in self.transformer_blocks.iter().enumerate() {
792 for (n, p) in b.named_parameters() {
793 o.push((format!("transformer_blocks.{i}.{n}"), p));
794 }
795 }
796 for (n, p) in self.proj_out.named_parameters() {
797 o.push((format!("proj_out.{n}"), p));
798 }
799 o
800 }
801 fn train(&mut self) {
802 self.training = true;
803 }
804 fn eval(&mut self) {
805 self.training = false;
806 }
807 fn is_training(&self) -> bool {
808 self.training
809 }
810 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
811 let extract = |prefix: &str| -> StateDict<T> {
812 let p = format!("{prefix}.");
813 state
814 .iter()
815 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
816 .collect()
817 };
818 if strict {
819 for k in state.keys() {
820 let ok = k.starts_with("norm.")
821 || k.starts_with("proj_in.")
822 || k.starts_with("transformer_blocks.")
823 || k.starts_with("proj_out.");
824 if !ok {
825 return Err(FerrotorchError::InvalidArgument {
826 message: format!(
827 "unexpected key in Transformer2DModel state_dict: \"{k}\""
828 ),
829 });
830 }
831 }
832 }
833 self.norm.load_state_dict(&extract("norm"), strict)?;
834 self.proj_in.load_state_dict(&extract("proj_in"), strict)?;
835 for (i, b) in self.transformer_blocks.iter_mut().enumerate() {
836 b.load_state_dict(&extract(&format!("transformer_blocks.{i}")), strict)?;
837 }
838 self.proj_out
839 .load_state_dict(&extract("proj_out"), strict)?;
840 Ok(())
841 }
842}
843
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds a CPU `f32` tensor of the given shape with every element set to
    /// 0.01 and the trailing `from_storage` flag set to `false`.
    fn filled(shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; numel]),
            shape.to_vec(),
            false,
        )
        .unwrap()
    }

    /// Asserts that every key in `expected` appears in `names`.
    fn assert_has_keys(names: &[String], expected: &[&str]) {
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    // Self-attention keeps the [B, N, query_dim] shape.
    #[test]
    fn attention_self_shape() {
        let attn = Attention::<f32>::new(16, None, 4, 4, false).unwrap();
        let input = filled(&[1, 5, 16]);
        let out = attn.forward_xattn(&input, None).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    // Cross-attention output follows the query shape, not the encoder shape.
    #[test]
    fn attention_cross_shape() {
        let attn = Attention::<f32>::new(16, Some(24), 4, 4, false).unwrap();
        let input = filled(&[1, 5, 16]);
        let encoder = filled(&[1, 7, 24]);
        let out = attn.forward_xattn(&input, Some(&encoder)).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    // Feed-forward preserves the input shape and exposes the expected keys.
    #[test]
    fn feedforward_shape_and_keys() {
        let ff = FeedForward::<f32>::new(16, 2).unwrap();
        let input = filled(&[1, 5, 16]);
        let out = ff.forward(&input).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
        let names: Vec<String> = ff.named_parameters().into_iter().map(|(n, _)| n).collect();
        assert_has_keys(
            &names,
            &[
                "net.0.proj.weight",
                "net.0.proj.bias",
                "net.2.weight",
                "net.2.bias",
            ],
        );
    }

    // A full block preserves the token shape.
    #[test]
    fn basic_transformer_block_shape() {
        let block = BasicTransformerBlock::<f32>::new(16, 4, 4, 24).unwrap();
        let input = filled(&[1, 5, 16]);
        let encoder = filled(&[1, 7, 24]);
        let out = block.forward_xattn(&input, &encoder).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    // The spatial model returns a feature map shaped like its input.
    #[test]
    fn transformer_2d_shape() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let input = filled(&[1, 16, 3, 3]);
        let encoder = filled(&[1, 5, 24]);
        let out = model.forward_xattn(&input, &encoder).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    // Parameter names carry the nested sub-layer prefixes.
    #[test]
    fn transformer_2d_named_parameters() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let names: Vec<String> = model
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        assert_has_keys(
            &names,
            &[
                "norm.weight",
                "proj_in.weight",
                "proj_in.bias",
                "transformer_blocks.0.norm1.weight",
                "transformer_blocks.0.attn1.to_q.weight",
                "transformer_blocks.0.attn2.to_k.weight",
                "transformer_blocks.0.ff.net.0.proj.weight",
                "transformer_blocks.0.ff.net.2.weight",
                "proj_out.weight",
            ],
        );
    }
}