1use std::collections::HashMap;
10
11use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
12use ferrotorch_nn::module::{Module, StateDict};
13use ferrotorch_nn::parameter::Parameter;
14use ferrotorch_nn::{
15 Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample,
16};
17
/// Residual block: two `GroupNorm -> SiLU -> 3x3 conv` stages plus a skip path.
///
/// The skip path is the identity when the input and output channel counts
/// match; otherwise it is a 1x1 convolution (`conv_shortcut`).
#[derive(Debug)]
pub struct ResnetBlock2D<T: Float> {
    /// GroupNorm applied to the block input (over `in_channels`).
    pub norm1: GroupNorm<T>,
    /// First 3x3 convolution, `in_channels -> out_channels`.
    pub conv1: Conv2d<T>,
    /// GroupNorm applied to the output of `conv1` (over `out_channels`).
    pub norm2: GroupNorm<T>,
    /// Second 3x3 convolution, `out_channels -> out_channels`.
    pub conv2: Conv2d<T>,
    /// 1x1 projection of the input onto `out_channels`; `None` when
    /// `in_channels == out_channels`, in which case the input is added as-is.
    pub conv_shortcut: Option<Conv2d<T>>,
    /// Activation used after both `norm1` and `norm2`.
    activation: SiLU,
    /// Expected size of input dimension 1; checked in `forward`.
    in_channels: usize,
    /// Output channel count passed to `new`.
    // Not read anywhere in this file; the allow silences the dead_code lint.
    #[allow(dead_code)]
    out_channels: usize,
    /// Flag toggled by `train`/`eval`; this block's `forward` does not read it.
    training: bool,
}
56
57impl<T: Float> ResnetBlock2D<T> {
58 pub fn new(
65 in_channels: usize,
66 out_channels: usize,
67 norm_num_groups: usize,
68 eps: f64,
69 ) -> FerrotorchResult<Self> {
70 let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
71 let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
72 let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
73 let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
74 let conv_shortcut = if in_channels == out_channels {
75 None
76 } else {
77 Some(Conv2d::<T>::new(
78 in_channels,
79 out_channels,
80 (1, 1),
81 (1, 1),
82 (0, 0),
83 true,
84 )?)
85 };
86 Ok(Self {
87 norm1,
88 conv1,
89 norm2,
90 conv2,
91 conv_shortcut,
92 activation: SiLU::new(),
93 in_channels,
94 out_channels,
95 training: false,
96 })
97 }
98}
99
100impl<T: Float> Module<T> for ResnetBlock2D<T> {
101 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
102 if input.ndim() != 4 || input.shape()[1] != self.in_channels {
103 return Err(FerrotorchError::ShapeMismatch {
104 message: format!(
105 "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
106 self.in_channels,
107 input.shape()
108 ),
109 });
110 }
111 let mut h = self.norm1.forward(input)?;
113 h = self.activation.forward(&h)?;
114 h = self.conv1.forward(&h)?;
115 h = self.norm2.forward(&h)?;
117 h = self.activation.forward(&h)?;
118 h = self.conv2.forward(&h)?;
119 let res = if let Some(sc) = &self.conv_shortcut {
121 sc.forward(input)?
122 } else {
123 input.clone()
124 };
125 ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
128 }
129
130 fn parameters(&self) -> Vec<&Parameter<T>> {
131 let mut out = Vec::new();
132 out.extend(self.norm1.parameters());
133 out.extend(self.conv1.parameters());
134 out.extend(self.norm2.parameters());
135 out.extend(self.conv2.parameters());
136 if let Some(sc) = &self.conv_shortcut {
137 out.extend(sc.parameters());
138 }
139 out
140 }
141
142 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
143 let mut out = Vec::new();
144 out.extend(self.norm1.parameters_mut());
145 out.extend(self.conv1.parameters_mut());
146 out.extend(self.norm2.parameters_mut());
147 out.extend(self.conv2.parameters_mut());
148 if let Some(sc) = &mut self.conv_shortcut {
149 out.extend(sc.parameters_mut());
150 }
151 out
152 }
153
154 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
155 let mut out = Vec::new();
156 for (n, p) in self.norm1.named_parameters() {
157 out.push((format!("norm1.{n}"), p));
158 }
159 for (n, p) in self.conv1.named_parameters() {
160 out.push((format!("conv1.{n}"), p));
161 }
162 for (n, p) in self.norm2.named_parameters() {
163 out.push((format!("norm2.{n}"), p));
164 }
165 for (n, p) in self.conv2.named_parameters() {
166 out.push((format!("conv2.{n}"), p));
167 }
168 if let Some(sc) = &self.conv_shortcut {
169 for (n, p) in sc.named_parameters() {
170 out.push((format!("conv_shortcut.{n}"), p));
171 }
172 }
173 out
174 }
175
176 fn train(&mut self) {
177 self.training = true;
178 }
179 fn eval(&mut self) {
180 self.training = false;
181 }
182 fn is_training(&self) -> bool {
183 self.training
184 }
185
186 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
187 let extract = |prefix: &str| -> StateDict<T> {
188 let p = format!("{prefix}.");
189 state
190 .iter()
191 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
192 .collect()
193 };
194 if strict {
195 for k in state.keys() {
196 let ok = k.starts_with("norm1.")
197 || k.starts_with("conv1.")
198 || k.starts_with("norm2.")
199 || k.starts_with("conv2.")
200 || k.starts_with("conv_shortcut.");
201 if !ok {
202 return Err(FerrotorchError::InvalidArgument {
203 message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
204 });
205 }
206 }
207 }
208 self.norm1.load_state_dict(&extract("norm1"), strict)?;
209 self.conv1.load_state_dict(&extract("conv1"), strict)?;
210 self.norm2.load_state_dict(&extract("norm2"), strict)?;
211 self.conv2.load_state_dict(&extract("conv2"), strict)?;
212 if let Some(sc) = self.conv_shortcut.as_mut() {
213 sc.load_state_dict(&extract("conv_shortcut"), strict)?;
214 }
215 Ok(())
216 }
217}
218
/// Single-head self-attention over the spatial positions of a 4-D feature
/// map, with a residual connection back to the input.
#[derive(Debug)]
pub struct AttnBlock2D<T: Float> {
    /// GroupNorm applied before the query/key/value projections.
    pub group_norm: GroupNorm<T>,
    /// Query projection, `channels -> channels`.
    pub to_q: Linear<T>,
    /// Key projection, `channels -> channels`.
    pub to_k: Linear<T>,
    /// Value projection, `channels -> channels`.
    pub to_v: Linear<T>,
    /// Output projection, `channels -> channels`; its parameters use the
    /// `to_out.0.` key prefix in `named_parameters` / `load_state_dict`.
    pub to_out_0: Linear<T>,
    /// Expected size of input dimension 1; checked in `forward`.
    channels: usize,
    /// Flag toggled by `train`/`eval`; `forward` does not read it.
    training: bool,
}
259
260impl<T: Float> AttnBlock2D<T> {
261 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
269 let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
270 let to_q = Linear::<T>::new(channels, channels, true)?;
271 let to_k = Linear::<T>::new(channels, channels, true)?;
272 let to_v = Linear::<T>::new(channels, channels, true)?;
273 let to_out_0 = Linear::<T>::new(channels, channels, true)?;
274 Ok(Self {
275 group_norm,
276 to_q,
277 to_k,
278 to_v,
279 to_out_0,
280 channels,
281 training: false,
282 })
283 }
284}
285
286impl<T: Float> Module<T> for AttnBlock2D<T> {
287 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
288 if input.ndim() != 4 || input.shape()[1] != self.channels {
289 return Err(FerrotorchError::ShapeMismatch {
290 message: format!(
291 "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
292 self.channels,
293 input.shape()
294 ),
295 });
296 }
297 let b = input.shape()[0];
298 let c = input.shape()[1];
299 let h = input.shape()[2];
300 let w = input.shape()[3];
301 let hw = h * w;
302
303 let residual = input.clone();
305
306 let hidden = input
311 .reshape_t(&[b as isize, c as isize, hw as isize])?
312 .transpose(1, 2)?
313 .contiguous()?;
314 let normed_hwc = self
316 .group_norm
317 .forward(&hidden.transpose(1, 2)?.contiguous()?)?
318 .transpose(1, 2)?
319 .contiguous()?;
320
321 let q = self.to_q.forward(&normed_hwc)?; let k = self.to_k.forward(&normed_hwc)?; let v = self.to_v.forward(&normed_hwc)?; let scale = (c as f64).sqrt().recip();
334 let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
335 message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
336 })?;
337 let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
338 let k_t = k.transpose(1, 2)?.contiguous()?; let scores = q.bmm(&k_t)?; let scores_scaled =
341 ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
342 let probs = scores_scaled.softmax()?; let attended = probs.bmm(&v)?; let projected = self.to_out_0.forward(&attended)?; let back = projected
350 .transpose(1, 2)? .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
352 .contiguous()?;
353 ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
354 }
355
356 fn parameters(&self) -> Vec<&Parameter<T>> {
357 let mut out = Vec::new();
358 out.extend(self.group_norm.parameters());
359 out.extend(self.to_q.parameters());
360 out.extend(self.to_k.parameters());
361 out.extend(self.to_v.parameters());
362 out.extend(self.to_out_0.parameters());
363 out
364 }
365
366 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
367 let mut out = Vec::new();
368 out.extend(self.group_norm.parameters_mut());
369 out.extend(self.to_q.parameters_mut());
370 out.extend(self.to_k.parameters_mut());
371 out.extend(self.to_v.parameters_mut());
372 out.extend(self.to_out_0.parameters_mut());
373 out
374 }
375
376 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
377 let mut out = Vec::new();
378 for (n, p) in self.group_norm.named_parameters() {
379 out.push((format!("group_norm.{n}"), p));
380 }
381 for (n, p) in self.to_q.named_parameters() {
382 out.push((format!("to_q.{n}"), p));
383 }
384 for (n, p) in self.to_k.named_parameters() {
385 out.push((format!("to_k.{n}"), p));
386 }
387 for (n, p) in self.to_v.named_parameters() {
388 out.push((format!("to_v.{n}"), p));
389 }
390 for (n, p) in self.to_out_0.named_parameters() {
391 out.push((format!("to_out.0.{n}"), p));
392 }
393 out
394 }
395
396 fn train(&mut self) {
397 self.training = true;
398 }
399 fn eval(&mut self) {
400 self.training = false;
401 }
402 fn is_training(&self) -> bool {
403 self.training
404 }
405
406 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
407 let extract = |prefix: &str| -> StateDict<T> {
408 let p = format!("{prefix}.");
409 state
410 .iter()
411 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
412 .collect()
413 };
414 if strict {
415 for k in state.keys() {
416 let ok = k.starts_with("group_norm.")
417 || k.starts_with("to_q.")
418 || k.starts_with("to_k.")
419 || k.starts_with("to_v.")
420 || k.starts_with("to_out.0.");
421 if !ok {
422 return Err(FerrotorchError::InvalidArgument {
423 message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
424 });
425 }
426 }
427 }
428 self.group_norm
429 .load_state_dict(&extract("group_norm"), strict)?;
430 self.to_q.load_state_dict(&extract("to_q"), strict)?;
431 self.to_k.load_state_dict(&extract("to_k"), strict)?;
432 self.to_v.load_state_dict(&extract("to_v"), strict)?;
433 self.to_out_0
434 .load_state_dict(&extract("to_out.0"), strict)?;
435 Ok(())
436 }
437}
438
/// 2x nearest-neighbour spatial upsampling followed by a 3x3 convolution.
#[derive(Debug)]
pub struct Upsample2D<T: Float> {
    /// 3x3 convolution applied after upsampling, `channels -> channels`.
    pub conv: Conv2d<T>,
    /// Expected size of input dimension 1; checked in `forward`.
    channels: usize,
    /// Flag toggled by `train`/`eval`; `forward` does not read it.
    training: bool,
}
455
456impl<T: Float> Upsample2D<T> {
457 pub fn new(channels: usize) -> FerrotorchResult<Self> {
463 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
464 Ok(Self {
465 conv,
466 channels,
467 training: false,
468 })
469 }
470}
471
472impl<T: Float> Module<T> for Upsample2D<T> {
473 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
474 if input.ndim() != 4 || input.shape()[1] != self.channels {
475 return Err(FerrotorchError::ShapeMismatch {
476 message: format!(
477 "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
478 self.channels,
479 input.shape()
480 ),
481 });
482 }
483 let h = input.shape()[2];
484 let w = input.shape()[3];
485 let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
486 let upsampled = up.forward(input)?;
487 self.conv.forward(&upsampled)
488 }
489
490 fn parameters(&self) -> Vec<&Parameter<T>> {
491 self.conv.parameters()
492 }
493 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
494 self.conv.parameters_mut()
495 }
496 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
497 self.conv
498 .named_parameters()
499 .into_iter()
500 .map(|(n, p)| (format!("conv.{n}"), p))
501 .collect()
502 }
503 fn train(&mut self) {
504 self.training = true;
505 }
506 fn eval(&mut self) {
507 self.training = false;
508 }
509 fn is_training(&self) -> bool {
510 self.training
511 }
512 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
513 let conv_sd: StateDict<T> = state
514 .iter()
515 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
516 .collect();
517 if strict {
518 for k in state.keys() {
519 if !k.starts_with("conv.") {
520 return Err(FerrotorchError::InvalidArgument {
521 message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
522 });
523 }
524 }
525 }
526 self.conv.load_state_dict(&conv_sd, strict)
527 }
528}
529
/// Spatial downsampling through a single `channels -> channels` convolution.
///
/// `new` builds it with `(3, 3), (2, 2), (1, 1)` — presumably kernel, stride,
/// padding; confirm against `Conv2d::new`.
#[derive(Debug)]
pub struct Downsample2D<T: Float> {
    /// The downsampling convolution.
    pub conv: Conv2d<T>,
    /// Expected size of input dimension 1; checked in `forward`.
    channels: usize,
    /// Flag toggled by `train`/`eval`; `forward` does not read it.
    training: bool,
}
546
547impl<T: Float> Downsample2D<T> {
548 pub fn new(channels: usize) -> FerrotorchResult<Self> {
554 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
555 Ok(Self {
556 conv,
557 channels,
558 training: false,
559 })
560 }
561}
562
563impl<T: Float> Module<T> for Downsample2D<T> {
564 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
565 if input.ndim() != 4 || input.shape()[1] != self.channels {
566 return Err(FerrotorchError::ShapeMismatch {
567 message: format!(
568 "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
569 self.channels,
570 input.shape()
571 ),
572 });
573 }
574 self.conv.forward(input)
575 }
576 fn parameters(&self) -> Vec<&Parameter<T>> {
577 self.conv.parameters()
578 }
579 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
580 self.conv.parameters_mut()
581 }
582 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
583 self.conv
584 .named_parameters()
585 .into_iter()
586 .map(|(n, p)| (format!("conv.{n}"), p))
587 .collect()
588 }
589 fn train(&mut self) {
590 self.training = true;
591 }
592 fn eval(&mut self) {
593 self.training = false;
594 }
595 fn is_training(&self) -> bool {
596 self.training
597 }
598 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
599 let conv_sd: StateDict<T> = state
600 .iter()
601 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
602 .collect();
603 if strict {
604 for k in state.keys() {
605 if !k.starts_with("conv.") {
606 return Err(FerrotorchError::InvalidArgument {
607 message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
608 });
609 }
610 }
611 }
612 self.conv.load_state_dict(&conv_sd, strict)
613 }
614}
615
/// Decoder stage: a chain of [`ResnetBlock2D`]s optionally followed by an
/// [`Upsample2D`].
#[derive(Debug)]
pub struct UpDecoderBlock2D<T: Float> {
    /// Residual blocks applied in order. As built by `new`, only the first
    /// one changes the channel count (`in_channels -> out_channels`).
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional upsampler applied after the resnets; its parameters use the
    /// `upsamplers.0.` key prefix.
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Flag toggled by `train`/`eval`, which also forward to the children.
    training: bool,
}
637
638impl<T: Float> UpDecoderBlock2D<T> {
639 pub fn new(
646 in_channels: usize,
647 out_channels: usize,
648 num_resnets: usize,
649 norm_num_groups: usize,
650 eps: f64,
651 add_upsample: bool,
652 ) -> FerrotorchResult<Self> {
653 if num_resnets == 0 {
654 return Err(FerrotorchError::InvalidArgument {
655 message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
656 });
657 }
658 let mut resnets = Vec::with_capacity(num_resnets);
659 for i in 0..num_resnets {
660 let in_c = if i == 0 { in_channels } else { out_channels };
661 resnets.push(ResnetBlock2D::<T>::new(
662 in_c,
663 out_channels,
664 norm_num_groups,
665 eps,
666 )?);
667 }
668 let upsamplers_0 = if add_upsample {
669 Some(Upsample2D::<T>::new(out_channels)?)
670 } else {
671 None
672 };
673 Ok(Self {
674 resnets,
675 upsamplers_0,
676 training: false,
677 })
678 }
679}
680
681impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
682 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
683 let mut h = input.clone();
684 for r in &self.resnets {
685 h = r.forward(&h)?;
686 }
687 if let Some(u) = &self.upsamplers_0 {
688 h = u.forward(&h)?;
689 }
690 Ok(h)
691 }
692
693 fn parameters(&self) -> Vec<&Parameter<T>> {
694 let mut out = Vec::new();
695 for r in &self.resnets {
696 out.extend(r.parameters());
697 }
698 if let Some(u) = &self.upsamplers_0 {
699 out.extend(u.parameters());
700 }
701 out
702 }
703
704 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
705 let mut out = Vec::new();
706 for r in &mut self.resnets {
707 out.extend(r.parameters_mut());
708 }
709 if let Some(u) = self.upsamplers_0.as_mut() {
710 out.extend(u.parameters_mut());
711 }
712 out
713 }
714
715 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
716 let mut out = Vec::new();
717 for (i, r) in self.resnets.iter().enumerate() {
718 for (n, p) in r.named_parameters() {
719 out.push((format!("resnets.{i}.{n}"), p));
720 }
721 }
722 if let Some(u) = &self.upsamplers_0 {
723 for (n, p) in u.named_parameters() {
724 out.push((format!("upsamplers.0.{n}"), p));
725 }
726 }
727 out
728 }
729
730 fn train(&mut self) {
731 self.training = true;
732 for r in &mut self.resnets {
733 r.train();
734 }
735 if let Some(u) = self.upsamplers_0.as_mut() {
736 u.train();
737 }
738 }
739 fn eval(&mut self) {
740 self.training = false;
741 for r in &mut self.resnets {
742 r.eval();
743 }
744 if let Some(u) = self.upsamplers_0.as_mut() {
745 u.eval();
746 }
747 }
748 fn is_training(&self) -> bool {
749 self.training
750 }
751
752 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
753 let extract = |prefix: &str| -> StateDict<T> {
754 let p = format!("{prefix}.");
755 state
756 .iter()
757 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
758 .collect()
759 };
760 if strict {
761 for k in state.keys() {
762 let ok =
763 k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
764 if !ok {
765 return Err(FerrotorchError::InvalidArgument {
766 message: format!(
767 "unexpected key in UpDecoderBlock2D state_dict: \"{k}\""
768 ),
769 });
770 }
771 }
772 }
773 for (i, r) in self.resnets.iter_mut().enumerate() {
774 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
775 }
776 if let Some(u) = self.upsamplers_0.as_mut() {
777 u.load_state_dict(&extract("upsamplers.0"), strict)?;
778 } else if strict {
779 for k in state.keys() {
781 if k.starts_with("upsamplers.") {
782 return Err(FerrotorchError::InvalidArgument {
783 message: format!(
784 "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
785 ),
786 });
787 }
788 }
789 }
790 Ok(())
791 }
792}
793
/// Middle block: `resnets[0]`, then alternating attention / resnet pairs.
///
/// As built by `new`: two resnets and one attention block, all over
/// `channels`, so `forward` runs `resnets[0] -> attentions[0] -> resnets[1]`.
#[derive(Debug)]
pub struct UNetMidBlock2D<T: Float> {
    /// Residual blocks; `resnets[i + 1]` runs after `attentions[i]`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Attention blocks interleaved between the resnets.
    pub attentions: Vec<AttnBlock2D<T>>,
    /// Channel count passed to `new`.
    channels: usize,
    /// Flag toggled by `train`/`eval`, which also forward to the children.
    training: bool,
}
817
818impl<T: Float> UNetMidBlock2D<T> {
819 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
825 let resnets = vec![
827 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
828 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
829 ];
830 let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
831 Ok(Self {
832 resnets,
833 attentions,
834 channels,
835 training: false,
836 })
837 }
838}
839
840impl<T: Float> Module<T> for UNetMidBlock2D<T> {
841 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
842 let mut h = self.resnets[0].forward(input)?;
847 for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
848 h = attn.forward(&h)?;
849 h = resnet.forward(&h)?;
850 }
851 Ok(h)
852 }
853
854 fn parameters(&self) -> Vec<&Parameter<T>> {
855 let mut out = Vec::new();
856 for a in &self.attentions {
861 out.extend(a.parameters());
862 }
863 for r in &self.resnets {
864 out.extend(r.parameters());
865 }
866 out
867 }
868
869 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
870 let mut out = Vec::new();
871 for a in &mut self.attentions {
872 out.extend(a.parameters_mut());
873 }
874 for r in &mut self.resnets {
875 out.extend(r.parameters_mut());
876 }
877 out
878 }
879
880 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
881 let mut out = Vec::new();
882 for (i, a) in self.attentions.iter().enumerate() {
885 for (n, p) in a.named_parameters() {
886 out.push((format!("attentions.{i}.{n}"), p));
887 }
888 }
889 for (i, r) in self.resnets.iter().enumerate() {
890 for (n, p) in r.named_parameters() {
891 out.push((format!("resnets.{i}.{n}"), p));
892 }
893 }
894 out
895 }
896
897 fn train(&mut self) {
898 self.training = true;
899 for r in &mut self.resnets {
900 r.train();
901 }
902 for a in &mut self.attentions {
903 a.train();
904 }
905 }
906 fn eval(&mut self) {
907 self.training = false;
908 for r in &mut self.resnets {
909 r.eval();
910 }
911 for a in &mut self.attentions {
912 a.eval();
913 }
914 }
915 fn is_training(&self) -> bool {
916 self.training
917 }
918
919 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
920 let extract = |prefix: &str| -> StateDict<T> {
921 let p = format!("{prefix}.");
922 state
923 .iter()
924 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
925 .collect()
926 };
927 if strict {
928 for k in state.keys() {
929 let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
930 if !ok {
931 return Err(FerrotorchError::InvalidArgument {
932 message: format!(
933 "unexpected key in UNetMidBlock2D state_dict: \"{k}\""
934 ),
935 });
936 }
937 }
938 }
939 for (i, a) in self.attentions.iter_mut().enumerate() {
940 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
941 }
942 for (i, r) in self.resnets.iter_mut().enumerate() {
943 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
944 }
945 let _ = self.channels; let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
948 }
949}
950
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// CPU `f32` tensor of `shape` with every element set to `value`
    /// (gradient tracking disabled).
    fn filled(value: f32, shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Panics if any name in `expected` is absent from `named`.
    fn assert_names_present<P>(named: Vec<(String, P)>, expected: &[&str]) {
        let names: Vec<String> = named.into_iter().map(|(n, _)| n).collect();
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let out = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &value in out.data().unwrap() {
            assert!(value.is_finite());
        }
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let out = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_names_present(
            block.named_parameters(),
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    #[test]
    fn attn_shape_and_residual() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let out = attn.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_names_present(
            attn.named_parameters(),
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let out = up.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 3, 3]);
    }

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let out = mid.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_names_present(
            mid.named_parameters(),
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}