1use std::collections::HashMap;
22
23use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
24use ferrotorch_nn::module::{Module, StateDict};
25use ferrotorch_nn::parameter::Parameter;
26use ferrotorch_nn::{Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample};
27
/// Residual convolution block.
///
/// `forward` runs `norm1 -> SiLU -> conv1 -> norm2 -> SiLU -> conv2` and adds
/// the result to the input, which is first passed through `conv_shortcut`
/// when that projection exists.
#[derive(Debug)]
pub struct ResnetBlock2D<T: Float> {
    /// Group normalization applied to the block input (built for `in_channels`).
    pub norm1: GroupNorm<T>,
    /// First convolution, mapping `in_channels` to `out_channels`.
    pub conv1: Conv2d<T>,
    /// Group normalization applied after `conv1` (built for `out_channels`).
    pub norm2: GroupNorm<T>,
    /// Second convolution, mapping `out_channels` to `out_channels`.
    pub conv2: Conv2d<T>,
    /// Projection applied to the skip path; `None` when the block was built
    /// with equal input and output channel counts.
    pub conv_shortcut: Option<Conv2d<T>>,
    /// Activation used after both normalizations.
    activation: SiLU,
    /// Channel count `forward` requires on dimension 1 of its input.
    in_channels: usize,
    /// Output channel count; stored at construction but not read by the
    /// methods visible in this file, hence the lint allowance.
    #[allow(dead_code)]
    out_channels: usize,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
66
67impl<T: Float> ResnetBlock2D<T> {
68 pub fn new(
75 in_channels: usize,
76 out_channels: usize,
77 norm_num_groups: usize,
78 eps: f64,
79 ) -> FerrotorchResult<Self> {
80 let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
81 let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
82 let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
83 let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
84 let conv_shortcut = if in_channels == out_channels {
85 None
86 } else {
87 Some(Conv2d::<T>::new(
88 in_channels,
89 out_channels,
90 (1, 1),
91 (1, 1),
92 (0, 0),
93 true,
94 )?)
95 };
96 Ok(Self {
97 norm1,
98 conv1,
99 norm2,
100 conv2,
101 conv_shortcut,
102 activation: SiLU::new(),
103 in_channels,
104 out_channels,
105 training: false,
106 })
107 }
108}
109
110impl<T: Float> Module<T> for ResnetBlock2D<T> {
111 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
112 if input.ndim() != 4 || input.shape()[1] != self.in_channels {
113 return Err(FerrotorchError::ShapeMismatch {
114 message: format!(
115 "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
116 self.in_channels,
117 input.shape()
118 ),
119 });
120 }
121 let mut h = self.norm1.forward(input)?;
123 h = self.activation.forward(&h)?;
124 h = self.conv1.forward(&h)?;
125 h = self.norm2.forward(&h)?;
127 h = self.activation.forward(&h)?;
128 h = self.conv2.forward(&h)?;
129 let res = if let Some(sc) = &self.conv_shortcut {
131 sc.forward(input)?
132 } else {
133 input.clone()
134 };
135 ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
138 }
139
140 fn parameters(&self) -> Vec<&Parameter<T>> {
141 let mut out = Vec::new();
142 out.extend(self.norm1.parameters());
143 out.extend(self.conv1.parameters());
144 out.extend(self.norm2.parameters());
145 out.extend(self.conv2.parameters());
146 if let Some(sc) = &self.conv_shortcut {
147 out.extend(sc.parameters());
148 }
149 out
150 }
151
152 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
153 let mut out = Vec::new();
154 out.extend(self.norm1.parameters_mut());
155 out.extend(self.conv1.parameters_mut());
156 out.extend(self.norm2.parameters_mut());
157 out.extend(self.conv2.parameters_mut());
158 if let Some(sc) = &mut self.conv_shortcut {
159 out.extend(sc.parameters_mut());
160 }
161 out
162 }
163
164 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
165 let mut out = Vec::new();
166 for (n, p) in self.norm1.named_parameters() {
167 out.push((format!("norm1.{n}"), p));
168 }
169 for (n, p) in self.conv1.named_parameters() {
170 out.push((format!("conv1.{n}"), p));
171 }
172 for (n, p) in self.norm2.named_parameters() {
173 out.push((format!("norm2.{n}"), p));
174 }
175 for (n, p) in self.conv2.named_parameters() {
176 out.push((format!("conv2.{n}"), p));
177 }
178 if let Some(sc) = &self.conv_shortcut {
179 for (n, p) in sc.named_parameters() {
180 out.push((format!("conv_shortcut.{n}"), p));
181 }
182 }
183 out
184 }
185
186 fn train(&mut self) {
187 self.training = true;
188 }
189 fn eval(&mut self) {
190 self.training = false;
191 }
192 fn is_training(&self) -> bool {
193 self.training
194 }
195
196 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
197 let extract = |prefix: &str| -> StateDict<T> {
198 let p = format!("{prefix}.");
199 state
200 .iter()
201 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
202 .collect()
203 };
204 if strict {
205 for k in state.keys() {
206 let ok = k.starts_with("norm1.")
207 || k.starts_with("conv1.")
208 || k.starts_with("norm2.")
209 || k.starts_with("conv2.")
210 || k.starts_with("conv_shortcut.");
211 if !ok {
212 return Err(FerrotorchError::InvalidArgument {
213 message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
214 });
215 }
216 }
217 }
218 self.norm1.load_state_dict(&extract("norm1"), strict)?;
219 self.conv1.load_state_dict(&extract("conv1"), strict)?;
220 self.norm2.load_state_dict(&extract("norm2"), strict)?;
221 self.conv2.load_state_dict(&extract("conv2"), strict)?;
222 if let Some(sc) = self.conv_shortcut.as_mut() {
223 sc.load_state_dict(&extract("conv_shortcut"), strict)?;
224 }
225 Ok(())
226 }
227}
228
/// Self-attention block over the spatial positions of a 4-D feature map.
///
/// `forward` normalizes the input, projects it to queries, keys and values,
/// applies scaled dot-product attention across the `H * W` positions, and
/// adds the projected result back onto the input.
#[derive(Debug)]
pub struct AttnBlock2D<T: Float> {
    /// Group normalization applied before the projections.
    pub group_norm: GroupNorm<T>,
    /// Query projection (`channels -> channels`).
    pub to_q: Linear<T>,
    /// Key projection (`channels -> channels`).
    pub to_k: Linear<T>,
    /// Value projection (`channels -> channels`).
    pub to_v: Linear<T>,
    /// Output projection; exposed under the `to_out.0.` key prefix.
    pub to_out_0: Linear<T>,
    /// Channel count `forward` requires on dimension 1 of its input.
    channels: usize,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
269
270impl<T: Float> AttnBlock2D<T> {
271 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
279 let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
280 let to_q = Linear::<T>::new(channels, channels, true)?;
281 let to_k = Linear::<T>::new(channels, channels, true)?;
282 let to_v = Linear::<T>::new(channels, channels, true)?;
283 let to_out_0 = Linear::<T>::new(channels, channels, true)?;
284 Ok(Self {
285 group_norm,
286 to_q,
287 to_k,
288 to_v,
289 to_out_0,
290 channels,
291 training: false,
292 })
293 }
294}
295
296impl<T: Float> Module<T> for AttnBlock2D<T> {
297 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
298 if input.ndim() != 4 || input.shape()[1] != self.channels {
299 return Err(FerrotorchError::ShapeMismatch {
300 message: format!(
301 "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
302 self.channels,
303 input.shape()
304 ),
305 });
306 }
307 let b = input.shape()[0];
308 let c = input.shape()[1];
309 let h = input.shape()[2];
310 let w = input.shape()[3];
311 let hw = h * w;
312
313 let residual = input.clone();
315
316 let hidden = input
321 .reshape_t(&[b as isize, c as isize, hw as isize])?
322 .transpose(1, 2)?
323 .contiguous()?;
324 let normed_hwc = self
326 .group_norm
327 .forward(&hidden.transpose(1, 2)?.contiguous()?)?
328 .transpose(1, 2)?
329 .contiguous()?;
330
331 let q = self.to_q.forward(&normed_hwc)?; let k = self.to_k.forward(&normed_hwc)?; let v = self.to_v.forward(&normed_hwc)?; let scale = (c as f64).sqrt().recip();
344 let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
345 message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
346 })?;
347 let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
348 let k_t = k.transpose(1, 2)?.contiguous()?; let scores = q.bmm(&k_t)?; let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
351 let probs = scores_scaled.softmax()?; let attended = probs.bmm(&v)?; let projected = self.to_out_0.forward(&attended)?; let back = projected
359 .transpose(1, 2)? .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
361 .contiguous()?;
362 ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
363 }
364
365 fn parameters(&self) -> Vec<&Parameter<T>> {
366 let mut out = Vec::new();
367 out.extend(self.group_norm.parameters());
368 out.extend(self.to_q.parameters());
369 out.extend(self.to_k.parameters());
370 out.extend(self.to_v.parameters());
371 out.extend(self.to_out_0.parameters());
372 out
373 }
374
375 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
376 let mut out = Vec::new();
377 out.extend(self.group_norm.parameters_mut());
378 out.extend(self.to_q.parameters_mut());
379 out.extend(self.to_k.parameters_mut());
380 out.extend(self.to_v.parameters_mut());
381 out.extend(self.to_out_0.parameters_mut());
382 out
383 }
384
385 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
386 let mut out = Vec::new();
387 for (n, p) in self.group_norm.named_parameters() {
388 out.push((format!("group_norm.{n}"), p));
389 }
390 for (n, p) in self.to_q.named_parameters() {
391 out.push((format!("to_q.{n}"), p));
392 }
393 for (n, p) in self.to_k.named_parameters() {
394 out.push((format!("to_k.{n}"), p));
395 }
396 for (n, p) in self.to_v.named_parameters() {
397 out.push((format!("to_v.{n}"), p));
398 }
399 for (n, p) in self.to_out_0.named_parameters() {
400 out.push((format!("to_out.0.{n}"), p));
401 }
402 out
403 }
404
405 fn train(&mut self) {
406 self.training = true;
407 }
408 fn eval(&mut self) {
409 self.training = false;
410 }
411 fn is_training(&self) -> bool {
412 self.training
413 }
414
415 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
416 let extract = |prefix: &str| -> StateDict<T> {
417 let p = format!("{prefix}.");
418 state
419 .iter()
420 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
421 .collect()
422 };
423 if strict {
424 for k in state.keys() {
425 let ok = k.starts_with("group_norm.")
426 || k.starts_with("to_q.")
427 || k.starts_with("to_k.")
428 || k.starts_with("to_v.")
429 || k.starts_with("to_out.0.");
430 if !ok {
431 return Err(FerrotorchError::InvalidArgument {
432 message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
433 });
434 }
435 }
436 }
437 self.group_norm
438 .load_state_dict(&extract("group_norm"), strict)?;
439 self.to_q.load_state_dict(&extract("to_q"), strict)?;
440 self.to_k.load_state_dict(&extract("to_k"), strict)?;
441 self.to_v.load_state_dict(&extract("to_v"), strict)?;
442 self.to_out_0
443 .load_state_dict(&extract("to_out.0"), strict)?;
444 Ok(())
445 }
446}
447
/// Upsampling stage: nearest-neighbour resize to twice the height and width,
/// followed by a convolution.
#[derive(Debug)]
pub struct Upsample2D<T: Float> {
    /// Convolution applied after the resize (`channels -> channels`).
    pub conv: Conv2d<T>,
    /// Channel count `forward` requires on dimension 1 of its input.
    channels: usize,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
464
465impl<T: Float> Upsample2D<T> {
466 pub fn new(channels: usize) -> FerrotorchResult<Self> {
472 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
473 Ok(Self {
474 conv,
475 channels,
476 training: false,
477 })
478 }
479}
480
481impl<T: Float> Module<T> for Upsample2D<T> {
482 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
483 if input.ndim() != 4 || input.shape()[1] != self.channels {
484 return Err(FerrotorchError::ShapeMismatch {
485 message: format!(
486 "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
487 self.channels,
488 input.shape()
489 ),
490 });
491 }
492 let h = input.shape()[2];
493 let w = input.shape()[3];
494 let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
495 let upsampled = up.forward(input)?;
496 self.conv.forward(&upsampled)
497 }
498
499 fn parameters(&self) -> Vec<&Parameter<T>> {
500 self.conv.parameters()
501 }
502 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
503 self.conv.parameters_mut()
504 }
505 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
506 self.conv
507 .named_parameters()
508 .into_iter()
509 .map(|(n, p)| (format!("conv.{n}"), p))
510 .collect()
511 }
512 fn train(&mut self) {
513 self.training = true;
514 }
515 fn eval(&mut self) {
516 self.training = false;
517 }
518 fn is_training(&self) -> bool {
519 self.training
520 }
521 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
522 let conv_sd: StateDict<T> = state
523 .iter()
524 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
525 .collect();
526 if strict {
527 for k in state.keys() {
528 if !k.starts_with("conv.") {
529 return Err(FerrotorchError::InvalidArgument {
530 message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
531 });
532 }
533 }
534 }
535 self.conv.load_state_dict(&conv_sd, strict)
536 }
537}
538
/// Downsampling stage implemented as a single convolution.
#[derive(Debug)]
pub struct Downsample2D<T: Float> {
    /// Convolution performing the downsampling (`channels -> channels`).
    pub conv: Conv2d<T>,
    /// Channel count `forward` requires on dimension 1 of its input.
    channels: usize,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
555
556impl<T: Float> Downsample2D<T> {
557 pub fn new(channels: usize) -> FerrotorchResult<Self> {
563 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
564 Ok(Self {
565 conv,
566 channels,
567 training: false,
568 })
569 }
570}
571
572impl<T: Float> Module<T> for Downsample2D<T> {
573 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
574 if input.ndim() != 4 || input.shape()[1] != self.channels {
575 return Err(FerrotorchError::ShapeMismatch {
576 message: format!(
577 "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
578 self.channels,
579 input.shape()
580 ),
581 });
582 }
583 self.conv.forward(input)
584 }
585 fn parameters(&self) -> Vec<&Parameter<T>> {
586 self.conv.parameters()
587 }
588 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
589 self.conv.parameters_mut()
590 }
591 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
592 self.conv
593 .named_parameters()
594 .into_iter()
595 .map(|(n, p)| (format!("conv.{n}"), p))
596 .collect()
597 }
598 fn train(&mut self) {
599 self.training = true;
600 }
601 fn eval(&mut self) {
602 self.training = false;
603 }
604 fn is_training(&self) -> bool {
605 self.training
606 }
607 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
608 let conv_sd: StateDict<T> = state
609 .iter()
610 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
611 .collect();
612 if strict {
613 for k in state.keys() {
614 if !k.starts_with("conv.") {
615 return Err(FerrotorchError::InvalidArgument {
616 message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
617 });
618 }
619 }
620 }
621 self.conv.load_state_dict(&conv_sd, strict)
622 }
623}
624
/// Decoder stage: a chain of residual blocks optionally followed by an
/// upsampling stage.
#[derive(Debug)]
pub struct UpDecoderBlock2D<T: Float> {
    /// Residual blocks applied in order; exposed under `resnets.<index>.`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional upsampler applied after the resnets; exposed under `upsamplers.0.`.
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
646
647impl<T: Float> UpDecoderBlock2D<T> {
648 pub fn new(
655 in_channels: usize,
656 out_channels: usize,
657 num_resnets: usize,
658 norm_num_groups: usize,
659 eps: f64,
660 add_upsample: bool,
661 ) -> FerrotorchResult<Self> {
662 if num_resnets == 0 {
663 return Err(FerrotorchError::InvalidArgument {
664 message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
665 });
666 }
667 let mut resnets = Vec::with_capacity(num_resnets);
668 for i in 0..num_resnets {
669 let in_c = if i == 0 { in_channels } else { out_channels };
670 resnets.push(ResnetBlock2D::<T>::new(
671 in_c,
672 out_channels,
673 norm_num_groups,
674 eps,
675 )?);
676 }
677 let upsamplers_0 = if add_upsample {
678 Some(Upsample2D::<T>::new(out_channels)?)
679 } else {
680 None
681 };
682 Ok(Self {
683 resnets,
684 upsamplers_0,
685 training: false,
686 })
687 }
688}
689
690impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
691 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
692 let mut h = input.clone();
693 for r in &self.resnets {
694 h = r.forward(&h)?;
695 }
696 if let Some(u) = &self.upsamplers_0 {
697 h = u.forward(&h)?;
698 }
699 Ok(h)
700 }
701
702 fn parameters(&self) -> Vec<&Parameter<T>> {
703 let mut out = Vec::new();
704 for r in &self.resnets {
705 out.extend(r.parameters());
706 }
707 if let Some(u) = &self.upsamplers_0 {
708 out.extend(u.parameters());
709 }
710 out
711 }
712
713 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
714 let mut out = Vec::new();
715 for r in &mut self.resnets {
716 out.extend(r.parameters_mut());
717 }
718 if let Some(u) = self.upsamplers_0.as_mut() {
719 out.extend(u.parameters_mut());
720 }
721 out
722 }
723
724 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
725 let mut out = Vec::new();
726 for (i, r) in self.resnets.iter().enumerate() {
727 for (n, p) in r.named_parameters() {
728 out.push((format!("resnets.{i}.{n}"), p));
729 }
730 }
731 if let Some(u) = &self.upsamplers_0 {
732 for (n, p) in u.named_parameters() {
733 out.push((format!("upsamplers.0.{n}"), p));
734 }
735 }
736 out
737 }
738
739 fn train(&mut self) {
740 self.training = true;
741 for r in &mut self.resnets {
742 r.train();
743 }
744 if let Some(u) = self.upsamplers_0.as_mut() {
745 u.train();
746 }
747 }
748 fn eval(&mut self) {
749 self.training = false;
750 for r in &mut self.resnets {
751 r.eval();
752 }
753 if let Some(u) = self.upsamplers_0.as_mut() {
754 u.eval();
755 }
756 }
757 fn is_training(&self) -> bool {
758 self.training
759 }
760
761 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
762 let extract = |prefix: &str| -> StateDict<T> {
763 let p = format!("{prefix}.");
764 state
765 .iter()
766 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
767 .collect()
768 };
769 if strict {
770 for k in state.keys() {
771 let ok = k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
772 if !ok {
773 return Err(FerrotorchError::InvalidArgument {
774 message: format!("unexpected key in UpDecoderBlock2D state_dict: \"{k}\""),
775 });
776 }
777 }
778 }
779 for (i, r) in self.resnets.iter_mut().enumerate() {
780 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
781 }
782 if let Some(u) = self.upsamplers_0.as_mut() {
783 u.load_state_dict(&extract("upsamplers.0"), strict)?;
784 } else if strict {
785 for k in state.keys() {
787 if k.starts_with("upsamplers.") {
788 return Err(FerrotorchError::InvalidArgument {
789 message: format!(
790 "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
791 ),
792 });
793 }
794 }
795 }
796 Ok(())
797 }
798}
799
/// Encoder stage: a chain of residual blocks optionally followed by a
/// downsampling stage.
#[derive(Debug)]
pub struct DownEncoderBlock2D<T: Float> {
    /// Residual blocks applied in order; exposed under `resnets.<index>.`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional downsampler applied after the resnets; exposed under `downsamplers.0.`.
    pub downsamplers_0: Option<Downsample2D<T>>,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
830
831impl<T: Float> DownEncoderBlock2D<T> {
832 pub fn new(
840 in_channels: usize,
841 out_channels: usize,
842 num_resnets: usize,
843 norm_num_groups: usize,
844 eps: f64,
845 add_downsample: bool,
846 ) -> FerrotorchResult<Self> {
847 if num_resnets == 0 {
848 return Err(FerrotorchError::InvalidArgument {
849 message: "DownEncoderBlock2D: num_resnets must be > 0".into(),
850 });
851 }
852 let mut resnets = Vec::with_capacity(num_resnets);
853 for i in 0..num_resnets {
854 let in_c = if i == 0 { in_channels } else { out_channels };
855 resnets.push(ResnetBlock2D::<T>::new(
856 in_c,
857 out_channels,
858 norm_num_groups,
859 eps,
860 )?);
861 }
862 let downsamplers_0 = if add_downsample {
863 Some(Downsample2D::<T>::new(out_channels)?)
864 } else {
865 None
866 };
867 Ok(Self {
868 resnets,
869 downsamplers_0,
870 training: false,
871 })
872 }
873}
874
875impl<T: Float> Module<T> for DownEncoderBlock2D<T> {
876 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
877 let mut h = input.clone();
878 for r in &self.resnets {
879 h = r.forward(&h)?;
880 }
881 if let Some(d) = &self.downsamplers_0 {
882 h = d.forward(&h)?;
883 }
884 Ok(h)
885 }
886
887 fn parameters(&self) -> Vec<&Parameter<T>> {
888 let mut out = Vec::new();
889 for r in &self.resnets {
890 out.extend(r.parameters());
891 }
892 if let Some(d) = &self.downsamplers_0 {
893 out.extend(d.parameters());
894 }
895 out
896 }
897
898 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
899 let mut out = Vec::new();
900 for r in &mut self.resnets {
901 out.extend(r.parameters_mut());
902 }
903 if let Some(d) = self.downsamplers_0.as_mut() {
904 out.extend(d.parameters_mut());
905 }
906 out
907 }
908
909 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
910 let mut out = Vec::new();
911 for (i, r) in self.resnets.iter().enumerate() {
912 for (n, p) in r.named_parameters() {
913 out.push((format!("resnets.{i}.{n}"), p));
914 }
915 }
916 if let Some(d) = &self.downsamplers_0 {
917 for (n, p) in d.named_parameters() {
918 out.push((format!("downsamplers.0.{n}"), p));
919 }
920 }
921 out
922 }
923
924 fn train(&mut self) {
925 self.training = true;
926 for r in &mut self.resnets {
927 r.train();
928 }
929 if let Some(d) = self.downsamplers_0.as_mut() {
930 d.train();
931 }
932 }
933 fn eval(&mut self) {
934 self.training = false;
935 for r in &mut self.resnets {
936 r.eval();
937 }
938 if let Some(d) = self.downsamplers_0.as_mut() {
939 d.eval();
940 }
941 }
942 fn is_training(&self) -> bool {
943 self.training
944 }
945
946 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
947 let extract = |prefix: &str| -> StateDict<T> {
948 let p = format!("{prefix}.");
949 state
950 .iter()
951 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
952 .collect()
953 };
954 if strict {
955 for k in state.keys() {
956 let ok = k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
957 if !ok {
958 return Err(FerrotorchError::InvalidArgument {
959 message: format!(
960 "unexpected key in DownEncoderBlock2D state_dict: \"{k}\""
961 ),
962 });
963 }
964 }
965 }
966 for (i, r) in self.resnets.iter_mut().enumerate() {
967 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
968 }
969 if let Some(d) = self.downsamplers_0.as_mut() {
970 d.load_state_dict(&extract("downsamplers.0"), strict)?;
971 } else if strict {
972 for k in state.keys() {
973 if k.starts_with("downsamplers.") {
974 return Err(FerrotorchError::InvalidArgument {
975 message: format!(
976 "DownEncoderBlock2D has no downsampler but state_dict contains \"{k}\""
977 ),
978 });
979 }
980 }
981 }
982 Ok(())
983 }
984}
985
/// Middle stage: a residual block followed by alternating attention and
/// residual blocks.
#[derive(Debug)]
pub struct UNetMidBlock2D<T: Float> {
    /// Residual blocks; the first runs before any attention, the rest are
    /// paired with `attentions`. Exposed under `resnets.<index>.`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Attention blocks, each run before its paired resnet. Exposed under
    /// `attentions.<index>.`.
    pub attentions: Vec<AttnBlock2D<T>>,
    /// Channel count the stage was built for.
    channels: usize,
    /// Flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
1009
1010impl<T: Float> UNetMidBlock2D<T> {
1011 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
1017 let resnets = vec![
1019 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1020 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1021 ];
1022 let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
1023 Ok(Self {
1024 resnets,
1025 attentions,
1026 channels,
1027 training: false,
1028 })
1029 }
1030}
1031
1032impl<T: Float> Module<T> for UNetMidBlock2D<T> {
1033 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1034 let mut h = self.resnets[0].forward(input)?;
1039 for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
1040 h = attn.forward(&h)?;
1041 h = resnet.forward(&h)?;
1042 }
1043 Ok(h)
1044 }
1045
1046 fn parameters(&self) -> Vec<&Parameter<T>> {
1047 let mut out = Vec::new();
1048 for a in &self.attentions {
1053 out.extend(a.parameters());
1054 }
1055 for r in &self.resnets {
1056 out.extend(r.parameters());
1057 }
1058 out
1059 }
1060
1061 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1062 let mut out = Vec::new();
1063 for a in &mut self.attentions {
1064 out.extend(a.parameters_mut());
1065 }
1066 for r in &mut self.resnets {
1067 out.extend(r.parameters_mut());
1068 }
1069 out
1070 }
1071
1072 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1073 let mut out = Vec::new();
1074 for (i, a) in self.attentions.iter().enumerate() {
1077 for (n, p) in a.named_parameters() {
1078 out.push((format!("attentions.{i}.{n}"), p));
1079 }
1080 }
1081 for (i, r) in self.resnets.iter().enumerate() {
1082 for (n, p) in r.named_parameters() {
1083 out.push((format!("resnets.{i}.{n}"), p));
1084 }
1085 }
1086 out
1087 }
1088
1089 fn train(&mut self) {
1090 self.training = true;
1091 for r in &mut self.resnets {
1092 r.train();
1093 }
1094 for a in &mut self.attentions {
1095 a.train();
1096 }
1097 }
1098 fn eval(&mut self) {
1099 self.training = false;
1100 for r in &mut self.resnets {
1101 r.eval();
1102 }
1103 for a in &mut self.attentions {
1104 a.eval();
1105 }
1106 }
1107 fn is_training(&self) -> bool {
1108 self.training
1109 }
1110
1111 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1112 let extract = |prefix: &str| -> StateDict<T> {
1113 let p = format!("{prefix}.");
1114 state
1115 .iter()
1116 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1117 .collect()
1118 };
1119 if strict {
1120 for k in state.keys() {
1121 let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
1122 if !ok {
1123 return Err(FerrotorchError::InvalidArgument {
1124 message: format!("unexpected key in UNetMidBlock2D state_dict: \"{k}\""),
1125 });
1126 }
1127 }
1128 }
1129 for (i, a) in self.attentions.iter_mut().enumerate() {
1130 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
1131 }
1132 for (i, r) in self.resnets.iter_mut().enumerate() {
1133 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
1134 }
1135 let _ = self.channels; let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
1138 }
1139}
1140
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds a CPU tensor of `shape` with every element equal to `value`.
    fn constant(value: f32, shape: [usize; 4]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(
            TensorStorage::cpu(vec![value; numel]),
            shape.to_vec(),
            false,
        )
        .unwrap()
    }

    /// Asserts that every entry of `expected` appears among the module's
    /// parameter names.
    fn assert_has_parameters<M: Module<f32>>(module: &M, expected: &[&str]) {
        let names: Vec<String> = module
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    // ---- ResnetBlock2D -------------------------------------------------

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let y = block.forward(&constant(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 32, 4, 4]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let y = block.forward(&constant(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_has_parameters(
            &block,
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    // ---- AttnBlock2D ---------------------------------------------------

    #[test]
    fn attn_shape_and_residual() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let y = attn.forward(&constant(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 32, 4, 4]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_has_parameters(
            &attn,
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    // ---- Upsample2D / UpDecoderBlock2D ---------------------------------

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let y = up.forward(&constant(0.05, [1, 8, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let y = block.forward(&constant(0.05, [1, 16, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let y = block.forward(&constant(0.05, [1, 8, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 8, 3, 3]);
    }

    // ---- DownEncoderBlock2D --------------------------------------------

    #[test]
    fn down_encoder_block_shape_with_downsample() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        let y = block.forward(&constant(0.05, [1, 8, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_shape_no_downsample() {
        let block = DownEncoderBlock2D::<f32>::new(16, 16, 2, 4, 1e-6, false).unwrap();
        let y = block.forward(&constant(0.05, [1, 16, 2, 2])).unwrap();
        assert_eq!(y.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_named_parameters_layout() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        assert_has_parameters(
            &block,
            &[
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.0.conv_shortcut.weight",
                "resnets.1.norm1.weight",
                "downsamplers.0.conv.weight",
                "downsamplers.0.conv.bias",
            ],
        );
    }

    // ---- UNetMidBlock2D ------------------------------------------------

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let y = mid.forward(&constant(0.05, [1, 16, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_has_parameters(
            &mid,
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}