1use std::collections::HashMap;
10
11use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
12use ferrotorch_nn::module::{Module, StateDict};
13use ferrotorch_nn::parameter::Parameter;
14use ferrotorch_nn::{Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample};
15
16#[derive(Debug)]
36pub struct ResnetBlock2D<T: Float> {
37 pub norm1: GroupNorm<T>,
39 pub conv1: Conv2d<T>,
41 pub norm2: GroupNorm<T>,
43 pub conv2: Conv2d<T>,
45 pub conv_shortcut: Option<Conv2d<T>>,
48 activation: SiLU,
49 in_channels: usize,
50 #[allow(dead_code)]
51 out_channels: usize,
52 training: bool,
53}
54
55impl<T: Float> ResnetBlock2D<T> {
56 pub fn new(
63 in_channels: usize,
64 out_channels: usize,
65 norm_num_groups: usize,
66 eps: f64,
67 ) -> FerrotorchResult<Self> {
68 let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
69 let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
70 let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
71 let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
72 let conv_shortcut = if in_channels == out_channels {
73 None
74 } else {
75 Some(Conv2d::<T>::new(
76 in_channels,
77 out_channels,
78 (1, 1),
79 (1, 1),
80 (0, 0),
81 true,
82 )?)
83 };
84 Ok(Self {
85 norm1,
86 conv1,
87 norm2,
88 conv2,
89 conv_shortcut,
90 activation: SiLU::new(),
91 in_channels,
92 out_channels,
93 training: false,
94 })
95 }
96}
97
98impl<T: Float> Module<T> for ResnetBlock2D<T> {
99 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
100 if input.ndim() != 4 || input.shape()[1] != self.in_channels {
101 return Err(FerrotorchError::ShapeMismatch {
102 message: format!(
103 "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
104 self.in_channels,
105 input.shape()
106 ),
107 });
108 }
109 let mut h = self.norm1.forward(input)?;
111 h = self.activation.forward(&h)?;
112 h = self.conv1.forward(&h)?;
113 h = self.norm2.forward(&h)?;
115 h = self.activation.forward(&h)?;
116 h = self.conv2.forward(&h)?;
117 let res = if let Some(sc) = &self.conv_shortcut {
119 sc.forward(input)?
120 } else {
121 input.clone()
122 };
123 ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
126 }
127
128 fn parameters(&self) -> Vec<&Parameter<T>> {
129 let mut out = Vec::new();
130 out.extend(self.norm1.parameters());
131 out.extend(self.conv1.parameters());
132 out.extend(self.norm2.parameters());
133 out.extend(self.conv2.parameters());
134 if let Some(sc) = &self.conv_shortcut {
135 out.extend(sc.parameters());
136 }
137 out
138 }
139
140 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
141 let mut out = Vec::new();
142 out.extend(self.norm1.parameters_mut());
143 out.extend(self.conv1.parameters_mut());
144 out.extend(self.norm2.parameters_mut());
145 out.extend(self.conv2.parameters_mut());
146 if let Some(sc) = &mut self.conv_shortcut {
147 out.extend(sc.parameters_mut());
148 }
149 out
150 }
151
152 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
153 let mut out = Vec::new();
154 for (n, p) in self.norm1.named_parameters() {
155 out.push((format!("norm1.{n}"), p));
156 }
157 for (n, p) in self.conv1.named_parameters() {
158 out.push((format!("conv1.{n}"), p));
159 }
160 for (n, p) in self.norm2.named_parameters() {
161 out.push((format!("norm2.{n}"), p));
162 }
163 for (n, p) in self.conv2.named_parameters() {
164 out.push((format!("conv2.{n}"), p));
165 }
166 if let Some(sc) = &self.conv_shortcut {
167 for (n, p) in sc.named_parameters() {
168 out.push((format!("conv_shortcut.{n}"), p));
169 }
170 }
171 out
172 }
173
174 fn train(&mut self) {
175 self.training = true;
176 }
177 fn eval(&mut self) {
178 self.training = false;
179 }
180 fn is_training(&self) -> bool {
181 self.training
182 }
183
184 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
185 let extract = |prefix: &str| -> StateDict<T> {
186 let p = format!("{prefix}.");
187 state
188 .iter()
189 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
190 .collect()
191 };
192 if strict {
193 for k in state.keys() {
194 let ok = k.starts_with("norm1.")
195 || k.starts_with("conv1.")
196 || k.starts_with("norm2.")
197 || k.starts_with("conv2.")
198 || k.starts_with("conv_shortcut.");
199 if !ok {
200 return Err(FerrotorchError::InvalidArgument {
201 message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
202 });
203 }
204 }
205 }
206 self.norm1.load_state_dict(&extract("norm1"), strict)?;
207 self.conv1.load_state_dict(&extract("conv1"), strict)?;
208 self.norm2.load_state_dict(&extract("norm2"), strict)?;
209 self.conv2.load_state_dict(&extract("conv2"), strict)?;
210 if let Some(sc) = self.conv_shortcut.as_mut() {
211 sc.load_state_dict(&extract("conv_shortcut"), strict)?;
212 }
213 Ok(())
214 }
215}
216
217#[derive(Debug)]
243pub struct AttnBlock2D<T: Float> {
244 pub group_norm: GroupNorm<T>,
246 pub to_q: Linear<T>,
248 pub to_k: Linear<T>,
250 pub to_v: Linear<T>,
252 pub to_out_0: Linear<T>,
254 channels: usize,
255 training: bool,
256}
257
258impl<T: Float> AttnBlock2D<T> {
259 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
267 let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
268 let to_q = Linear::<T>::new(channels, channels, true)?;
269 let to_k = Linear::<T>::new(channels, channels, true)?;
270 let to_v = Linear::<T>::new(channels, channels, true)?;
271 let to_out_0 = Linear::<T>::new(channels, channels, true)?;
272 Ok(Self {
273 group_norm,
274 to_q,
275 to_k,
276 to_v,
277 to_out_0,
278 channels,
279 training: false,
280 })
281 }
282}
283
284impl<T: Float> Module<T> for AttnBlock2D<T> {
285 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
286 if input.ndim() != 4 || input.shape()[1] != self.channels {
287 return Err(FerrotorchError::ShapeMismatch {
288 message: format!(
289 "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
290 self.channels,
291 input.shape()
292 ),
293 });
294 }
295 let b = input.shape()[0];
296 let c = input.shape()[1];
297 let h = input.shape()[2];
298 let w = input.shape()[3];
299 let hw = h * w;
300
301 let residual = input.clone();
303
304 let hidden = input
309 .reshape_t(&[b as isize, c as isize, hw as isize])?
310 .transpose(1, 2)?
311 .contiguous()?;
312 let normed_hwc = self
314 .group_norm
315 .forward(&hidden.transpose(1, 2)?.contiguous()?)?
316 .transpose(1, 2)?
317 .contiguous()?;
318
319 let q = self.to_q.forward(&normed_hwc)?; let k = self.to_k.forward(&normed_hwc)?; let v = self.to_v.forward(&normed_hwc)?; let scale = (c as f64).sqrt().recip();
332 let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
333 message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
334 })?;
335 let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
336 let k_t = k.transpose(1, 2)?.contiguous()?; let scores = q.bmm(&k_t)?; let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
339 let probs = scores_scaled.softmax()?; let attended = probs.bmm(&v)?; let projected = self.to_out_0.forward(&attended)?; let back = projected
347 .transpose(1, 2)? .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
349 .contiguous()?;
350 ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
351 }
352
353 fn parameters(&self) -> Vec<&Parameter<T>> {
354 let mut out = Vec::new();
355 out.extend(self.group_norm.parameters());
356 out.extend(self.to_q.parameters());
357 out.extend(self.to_k.parameters());
358 out.extend(self.to_v.parameters());
359 out.extend(self.to_out_0.parameters());
360 out
361 }
362
363 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
364 let mut out = Vec::new();
365 out.extend(self.group_norm.parameters_mut());
366 out.extend(self.to_q.parameters_mut());
367 out.extend(self.to_k.parameters_mut());
368 out.extend(self.to_v.parameters_mut());
369 out.extend(self.to_out_0.parameters_mut());
370 out
371 }
372
373 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
374 let mut out = Vec::new();
375 for (n, p) in self.group_norm.named_parameters() {
376 out.push((format!("group_norm.{n}"), p));
377 }
378 for (n, p) in self.to_q.named_parameters() {
379 out.push((format!("to_q.{n}"), p));
380 }
381 for (n, p) in self.to_k.named_parameters() {
382 out.push((format!("to_k.{n}"), p));
383 }
384 for (n, p) in self.to_v.named_parameters() {
385 out.push((format!("to_v.{n}"), p));
386 }
387 for (n, p) in self.to_out_0.named_parameters() {
388 out.push((format!("to_out.0.{n}"), p));
389 }
390 out
391 }
392
393 fn train(&mut self) {
394 self.training = true;
395 }
396 fn eval(&mut self) {
397 self.training = false;
398 }
399 fn is_training(&self) -> bool {
400 self.training
401 }
402
403 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
404 let extract = |prefix: &str| -> StateDict<T> {
405 let p = format!("{prefix}.");
406 state
407 .iter()
408 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
409 .collect()
410 };
411 if strict {
412 for k in state.keys() {
413 let ok = k.starts_with("group_norm.")
414 || k.starts_with("to_q.")
415 || k.starts_with("to_k.")
416 || k.starts_with("to_v.")
417 || k.starts_with("to_out.0.");
418 if !ok {
419 return Err(FerrotorchError::InvalidArgument {
420 message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
421 });
422 }
423 }
424 }
425 self.group_norm
426 .load_state_dict(&extract("group_norm"), strict)?;
427 self.to_q.load_state_dict(&extract("to_q"), strict)?;
428 self.to_k.load_state_dict(&extract("to_k"), strict)?;
429 self.to_v.load_state_dict(&extract("to_v"), strict)?;
430 self.to_out_0
431 .load_state_dict(&extract("to_out.0"), strict)?;
432 Ok(())
433 }
434}
435
436#[derive(Debug)]
446pub struct Upsample2D<T: Float> {
447 pub conv: Conv2d<T>,
449 channels: usize,
450 training: bool,
451}
452
453impl<T: Float> Upsample2D<T> {
454 pub fn new(channels: usize) -> FerrotorchResult<Self> {
460 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
461 Ok(Self {
462 conv,
463 channels,
464 training: false,
465 })
466 }
467}
468
469impl<T: Float> Module<T> for Upsample2D<T> {
470 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
471 if input.ndim() != 4 || input.shape()[1] != self.channels {
472 return Err(FerrotorchError::ShapeMismatch {
473 message: format!(
474 "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
475 self.channels,
476 input.shape()
477 ),
478 });
479 }
480 let h = input.shape()[2];
481 let w = input.shape()[3];
482 let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
483 let upsampled = up.forward(input)?;
484 self.conv.forward(&upsampled)
485 }
486
487 fn parameters(&self) -> Vec<&Parameter<T>> {
488 self.conv.parameters()
489 }
490 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
491 self.conv.parameters_mut()
492 }
493 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
494 self.conv
495 .named_parameters()
496 .into_iter()
497 .map(|(n, p)| (format!("conv.{n}"), p))
498 .collect()
499 }
500 fn train(&mut self) {
501 self.training = true;
502 }
503 fn eval(&mut self) {
504 self.training = false;
505 }
506 fn is_training(&self) -> bool {
507 self.training
508 }
509 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
510 let conv_sd: StateDict<T> = state
511 .iter()
512 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
513 .collect();
514 if strict {
515 for k in state.keys() {
516 if !k.starts_with("conv.") {
517 return Err(FerrotorchError::InvalidArgument {
518 message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
519 });
520 }
521 }
522 }
523 self.conv.load_state_dict(&conv_sd, strict)
524 }
525}
526
527#[derive(Debug)]
537pub struct Downsample2D<T: Float> {
538 pub conv: Conv2d<T>,
540 channels: usize,
541 training: bool,
542}
543
544impl<T: Float> Downsample2D<T> {
545 pub fn new(channels: usize) -> FerrotorchResult<Self> {
551 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
552 Ok(Self {
553 conv,
554 channels,
555 training: false,
556 })
557 }
558}
559
560impl<T: Float> Module<T> for Downsample2D<T> {
561 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
562 if input.ndim() != 4 || input.shape()[1] != self.channels {
563 return Err(FerrotorchError::ShapeMismatch {
564 message: format!(
565 "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
566 self.channels,
567 input.shape()
568 ),
569 });
570 }
571 self.conv.forward(input)
572 }
573 fn parameters(&self) -> Vec<&Parameter<T>> {
574 self.conv.parameters()
575 }
576 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
577 self.conv.parameters_mut()
578 }
579 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
580 self.conv
581 .named_parameters()
582 .into_iter()
583 .map(|(n, p)| (format!("conv.{n}"), p))
584 .collect()
585 }
586 fn train(&mut self) {
587 self.training = true;
588 }
589 fn eval(&mut self) {
590 self.training = false;
591 }
592 fn is_training(&self) -> bool {
593 self.training
594 }
595 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
596 let conv_sd: StateDict<T> = state
597 .iter()
598 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
599 .collect();
600 if strict {
601 for k in state.keys() {
602 if !k.starts_with("conv.") {
603 return Err(FerrotorchError::InvalidArgument {
604 message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
605 });
606 }
607 }
608 }
609 self.conv.load_state_dict(&conv_sd, strict)
610 }
611}
612
613#[derive(Debug)]
627pub struct UpDecoderBlock2D<T: Float> {
628 pub resnets: Vec<ResnetBlock2D<T>>,
630 pub upsamplers_0: Option<Upsample2D<T>>,
632 training: bool,
633}
634
635impl<T: Float> UpDecoderBlock2D<T> {
636 pub fn new(
643 in_channels: usize,
644 out_channels: usize,
645 num_resnets: usize,
646 norm_num_groups: usize,
647 eps: f64,
648 add_upsample: bool,
649 ) -> FerrotorchResult<Self> {
650 if num_resnets == 0 {
651 return Err(FerrotorchError::InvalidArgument {
652 message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
653 });
654 }
655 let mut resnets = Vec::with_capacity(num_resnets);
656 for i in 0..num_resnets {
657 let in_c = if i == 0 { in_channels } else { out_channels };
658 resnets.push(ResnetBlock2D::<T>::new(
659 in_c,
660 out_channels,
661 norm_num_groups,
662 eps,
663 )?);
664 }
665 let upsamplers_0 = if add_upsample {
666 Some(Upsample2D::<T>::new(out_channels)?)
667 } else {
668 None
669 };
670 Ok(Self {
671 resnets,
672 upsamplers_0,
673 training: false,
674 })
675 }
676}
677
678impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
679 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
680 let mut h = input.clone();
681 for r in &self.resnets {
682 h = r.forward(&h)?;
683 }
684 if let Some(u) = &self.upsamplers_0 {
685 h = u.forward(&h)?;
686 }
687 Ok(h)
688 }
689
690 fn parameters(&self) -> Vec<&Parameter<T>> {
691 let mut out = Vec::new();
692 for r in &self.resnets {
693 out.extend(r.parameters());
694 }
695 if let Some(u) = &self.upsamplers_0 {
696 out.extend(u.parameters());
697 }
698 out
699 }
700
701 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
702 let mut out = Vec::new();
703 for r in &mut self.resnets {
704 out.extend(r.parameters_mut());
705 }
706 if let Some(u) = self.upsamplers_0.as_mut() {
707 out.extend(u.parameters_mut());
708 }
709 out
710 }
711
712 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
713 let mut out = Vec::new();
714 for (i, r) in self.resnets.iter().enumerate() {
715 for (n, p) in r.named_parameters() {
716 out.push((format!("resnets.{i}.{n}"), p));
717 }
718 }
719 if let Some(u) = &self.upsamplers_0 {
720 for (n, p) in u.named_parameters() {
721 out.push((format!("upsamplers.0.{n}"), p));
722 }
723 }
724 out
725 }
726
727 fn train(&mut self) {
728 self.training = true;
729 for r in &mut self.resnets {
730 r.train();
731 }
732 if let Some(u) = self.upsamplers_0.as_mut() {
733 u.train();
734 }
735 }
736 fn eval(&mut self) {
737 self.training = false;
738 for r in &mut self.resnets {
739 r.eval();
740 }
741 if let Some(u) = self.upsamplers_0.as_mut() {
742 u.eval();
743 }
744 }
745 fn is_training(&self) -> bool {
746 self.training
747 }
748
749 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
750 let extract = |prefix: &str| -> StateDict<T> {
751 let p = format!("{prefix}.");
752 state
753 .iter()
754 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
755 .collect()
756 };
757 if strict {
758 for k in state.keys() {
759 let ok = k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
760 if !ok {
761 return Err(FerrotorchError::InvalidArgument {
762 message: format!("unexpected key in UpDecoderBlock2D state_dict: \"{k}\""),
763 });
764 }
765 }
766 }
767 for (i, r) in self.resnets.iter_mut().enumerate() {
768 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
769 }
770 if let Some(u) = self.upsamplers_0.as_mut() {
771 u.load_state_dict(&extract("upsamplers.0"), strict)?;
772 } else if strict {
773 for k in state.keys() {
775 if k.starts_with("upsamplers.") {
776 return Err(FerrotorchError::InvalidArgument {
777 message: format!(
778 "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
779 ),
780 });
781 }
782 }
783 }
784 Ok(())
785 }
786}
787
788#[derive(Debug)]
807pub struct DownEncoderBlock2D<T: Float> {
808 pub resnets: Vec<ResnetBlock2D<T>>,
813 pub downsamplers_0: Option<Downsample2D<T>>,
816 training: bool,
817}
818
819impl<T: Float> DownEncoderBlock2D<T> {
820 pub fn new(
828 in_channels: usize,
829 out_channels: usize,
830 num_resnets: usize,
831 norm_num_groups: usize,
832 eps: f64,
833 add_downsample: bool,
834 ) -> FerrotorchResult<Self> {
835 if num_resnets == 0 {
836 return Err(FerrotorchError::InvalidArgument {
837 message: "DownEncoderBlock2D: num_resnets must be > 0".into(),
838 });
839 }
840 let mut resnets = Vec::with_capacity(num_resnets);
841 for i in 0..num_resnets {
842 let in_c = if i == 0 { in_channels } else { out_channels };
843 resnets.push(ResnetBlock2D::<T>::new(
844 in_c,
845 out_channels,
846 norm_num_groups,
847 eps,
848 )?);
849 }
850 let downsamplers_0 = if add_downsample {
851 Some(Downsample2D::<T>::new(out_channels)?)
852 } else {
853 None
854 };
855 Ok(Self {
856 resnets,
857 downsamplers_0,
858 training: false,
859 })
860 }
861}
862
863impl<T: Float> Module<T> for DownEncoderBlock2D<T> {
864 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
865 let mut h = input.clone();
866 for r in &self.resnets {
867 h = r.forward(&h)?;
868 }
869 if let Some(d) = &self.downsamplers_0 {
870 h = d.forward(&h)?;
871 }
872 Ok(h)
873 }
874
875 fn parameters(&self) -> Vec<&Parameter<T>> {
876 let mut out = Vec::new();
877 for r in &self.resnets {
878 out.extend(r.parameters());
879 }
880 if let Some(d) = &self.downsamplers_0 {
881 out.extend(d.parameters());
882 }
883 out
884 }
885
886 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
887 let mut out = Vec::new();
888 for r in &mut self.resnets {
889 out.extend(r.parameters_mut());
890 }
891 if let Some(d) = self.downsamplers_0.as_mut() {
892 out.extend(d.parameters_mut());
893 }
894 out
895 }
896
897 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
898 let mut out = Vec::new();
899 for (i, r) in self.resnets.iter().enumerate() {
900 for (n, p) in r.named_parameters() {
901 out.push((format!("resnets.{i}.{n}"), p));
902 }
903 }
904 if let Some(d) = &self.downsamplers_0 {
905 for (n, p) in d.named_parameters() {
906 out.push((format!("downsamplers.0.{n}"), p));
907 }
908 }
909 out
910 }
911
912 fn train(&mut self) {
913 self.training = true;
914 for r in &mut self.resnets {
915 r.train();
916 }
917 if let Some(d) = self.downsamplers_0.as_mut() {
918 d.train();
919 }
920 }
921 fn eval(&mut self) {
922 self.training = false;
923 for r in &mut self.resnets {
924 r.eval();
925 }
926 if let Some(d) = self.downsamplers_0.as_mut() {
927 d.eval();
928 }
929 }
930 fn is_training(&self) -> bool {
931 self.training
932 }
933
934 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
935 let extract = |prefix: &str| -> StateDict<T> {
936 let p = format!("{prefix}.");
937 state
938 .iter()
939 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
940 .collect()
941 };
942 if strict {
943 for k in state.keys() {
944 let ok = k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
945 if !ok {
946 return Err(FerrotorchError::InvalidArgument {
947 message: format!(
948 "unexpected key in DownEncoderBlock2D state_dict: \"{k}\""
949 ),
950 });
951 }
952 }
953 }
954 for (i, r) in self.resnets.iter_mut().enumerate() {
955 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
956 }
957 if let Some(d) = self.downsamplers_0.as_mut() {
958 d.load_state_dict(&extract("downsamplers.0"), strict)?;
959 } else if strict {
960 for k in state.keys() {
961 if k.starts_with("downsamplers.") {
962 return Err(FerrotorchError::InvalidArgument {
963 message: format!(
964 "DownEncoderBlock2D has no downsampler but state_dict contains \"{k}\""
965 ),
966 });
967 }
968 }
969 }
970 Ok(())
971 }
972}
973
974#[derive(Debug)]
989pub struct UNetMidBlock2D<T: Float> {
990 pub resnets: Vec<ResnetBlock2D<T>>,
992 pub attentions: Vec<AttnBlock2D<T>>,
994 channels: usize,
995 training: bool,
996}
997
998impl<T: Float> UNetMidBlock2D<T> {
999 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
1005 let resnets = vec![
1007 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1008 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1009 ];
1010 let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
1011 Ok(Self {
1012 resnets,
1013 attentions,
1014 channels,
1015 training: false,
1016 })
1017 }
1018}
1019
1020impl<T: Float> Module<T> for UNetMidBlock2D<T> {
1021 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1022 let mut h = self.resnets[0].forward(input)?;
1027 for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
1028 h = attn.forward(&h)?;
1029 h = resnet.forward(&h)?;
1030 }
1031 Ok(h)
1032 }
1033
1034 fn parameters(&self) -> Vec<&Parameter<T>> {
1035 let mut out = Vec::new();
1036 for a in &self.attentions {
1041 out.extend(a.parameters());
1042 }
1043 for r in &self.resnets {
1044 out.extend(r.parameters());
1045 }
1046 out
1047 }
1048
1049 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1050 let mut out = Vec::new();
1051 for a in &mut self.attentions {
1052 out.extend(a.parameters_mut());
1053 }
1054 for r in &mut self.resnets {
1055 out.extend(r.parameters_mut());
1056 }
1057 out
1058 }
1059
1060 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1061 let mut out = Vec::new();
1062 for (i, a) in self.attentions.iter().enumerate() {
1065 for (n, p) in a.named_parameters() {
1066 out.push((format!("attentions.{i}.{n}"), p));
1067 }
1068 }
1069 for (i, r) in self.resnets.iter().enumerate() {
1070 for (n, p) in r.named_parameters() {
1071 out.push((format!("resnets.{i}.{n}"), p));
1072 }
1073 }
1074 out
1075 }
1076
1077 fn train(&mut self) {
1078 self.training = true;
1079 for r in &mut self.resnets {
1080 r.train();
1081 }
1082 for a in &mut self.attentions {
1083 a.train();
1084 }
1085 }
1086 fn eval(&mut self) {
1087 self.training = false;
1088 for r in &mut self.resnets {
1089 r.eval();
1090 }
1091 for a in &mut self.attentions {
1092 a.eval();
1093 }
1094 }
1095 fn is_training(&self) -> bool {
1096 self.training
1097 }
1098
1099 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1100 let extract = |prefix: &str| -> StateDict<T> {
1101 let p = format!("{prefix}.");
1102 state
1103 .iter()
1104 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1105 .collect()
1106 };
1107 if strict {
1108 for k in state.keys() {
1109 let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
1110 if !ok {
1111 return Err(FerrotorchError::InvalidArgument {
1112 message: format!("unexpected key in UNetMidBlock2D state_dict: \"{k}\""),
1113 });
1114 }
1115 }
1116 }
1117 for (i, a) in self.attentions.iter_mut().enumerate() {
1118 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
1119 }
1120 for (i, r) in self.resnets.iter_mut().enumerate() {
1121 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
1122 }
1123 let _ = self.channels; let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
1126 }
1127}
1128
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// A CPU `f32` tensor of the given `[B, C, H, W]` shape, every element `value`.
    fn filled(value: f32, shape: [usize; 4]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Asserts that each name in `expected` appears in `names`.
    fn assert_has_names(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    /// Collects just the names from a module's `named_parameters`.
    fn param_names<M: Module<f32>>(m: &M) -> Vec<String> {
        m.named_parameters().into_iter().map(|(n, _)| n).collect()
    }

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let out = block.forward(&filled(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        assert!(out.data().unwrap().iter().all(|v| v.is_finite()));
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let out = block.forward(&filled(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_has_names(
            &param_names(&block),
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    #[test]
    fn attn_shape_and_residual() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let out = attn.forward(&filled(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_has_names(
            &param_names(&attn),
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let out = up.forward(&filled(0.05, [1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, [1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, [1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 3, 3]);
    }

    #[test]
    fn down_encoder_block_shape_with_downsample() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, [1, 8, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_shape_no_downsample() {
        let block = DownEncoderBlock2D::<f32>::new(16, 16, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, [1, 16, 2, 2])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_named_parameters_layout() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        assert_has_names(
            &param_names(&block),
            &[
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.0.conv_shortcut.weight",
                "resnets.1.norm1.weight",
                "downsamplers.0.conv.weight",
                "downsamplers.0.conv.bias",
            ],
        );
    }

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let out = mid.forward(&filled(0.05, [1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_has_names(
            &param_names(&mid),
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}