1use std::collections::HashMap;
10
11use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
12use ferrotorch_nn::module::{Module, StateDict};
13use ferrotorch_nn::parameter::Parameter;
14use ferrotorch_nn::{
15 Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample,
16};
17
/// Residual block: two `GroupNorm -> SiLU -> 3x3 Conv2d` stages plus a skip connection.
///
/// The skip path is the identity when the channel counts match and a 1x1
/// convolution (`conv_shortcut`) otherwise. Field names double as the parameter
/// key prefixes used by `named_parameters` and `load_state_dict`.
#[derive(Debug)]
pub struct ResnetBlock2D<T: Float> {
    /// Group normalization over the `in_channels` input.
    pub norm1: GroupNorm<T>,
    /// 3x3 convolution `in_channels -> out_channels`.
    pub conv1: Conv2d<T>,
    /// Group normalization over the `out_channels` intermediate.
    pub norm2: GroupNorm<T>,
    /// 3x3 convolution `out_channels -> out_channels`.
    pub conv2: Conv2d<T>,
    /// 1x1 projection of the input; `Some` only when `in_channels != out_channels`.
    pub conv_shortcut: Option<Conv2d<T>>,
    /// SiLU nonlinearity applied after each normalization.
    activation: SiLU,
    /// Channel count `forward` requires at axis 1 of its input.
    in_channels: usize,
    // Kept for `Debug` output / future use; no visible code reads it, hence the allow.
    #[allow(dead_code)]
    out_channels: usize,
    /// Mode flag reported by `is_training`.
    training: bool,
}
56
57impl<T: Float> ResnetBlock2D<T> {
58 pub fn new(
65 in_channels: usize,
66 out_channels: usize,
67 norm_num_groups: usize,
68 eps: f64,
69 ) -> FerrotorchResult<Self> {
70 let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
71 let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
72 let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
73 let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
74 let conv_shortcut = if in_channels == out_channels {
75 None
76 } else {
77 Some(Conv2d::<T>::new(
78 in_channels,
79 out_channels,
80 (1, 1),
81 (1, 1),
82 (0, 0),
83 true,
84 )?)
85 };
86 Ok(Self {
87 norm1,
88 conv1,
89 norm2,
90 conv2,
91 conv_shortcut,
92 activation: SiLU::new(),
93 in_channels,
94 out_channels,
95 training: false,
96 })
97 }
98}
99
100impl<T: Float> Module<T> for ResnetBlock2D<T> {
101 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
102 if input.ndim() != 4 || input.shape()[1] != self.in_channels {
103 return Err(FerrotorchError::ShapeMismatch {
104 message: format!(
105 "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
106 self.in_channels,
107 input.shape()
108 ),
109 });
110 }
111 let mut h = self.norm1.forward(input)?;
113 h = self.activation.forward(&h)?;
114 h = self.conv1.forward(&h)?;
115 h = self.norm2.forward(&h)?;
117 h = self.activation.forward(&h)?;
118 h = self.conv2.forward(&h)?;
119 let res = if let Some(sc) = &self.conv_shortcut {
121 sc.forward(input)?
122 } else {
123 input.clone()
124 };
125 ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
128 }
129
130 fn parameters(&self) -> Vec<&Parameter<T>> {
131 let mut out = Vec::new();
132 out.extend(self.norm1.parameters());
133 out.extend(self.conv1.parameters());
134 out.extend(self.norm2.parameters());
135 out.extend(self.conv2.parameters());
136 if let Some(sc) = &self.conv_shortcut {
137 out.extend(sc.parameters());
138 }
139 out
140 }
141
142 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
143 let mut out = Vec::new();
144 out.extend(self.norm1.parameters_mut());
145 out.extend(self.conv1.parameters_mut());
146 out.extend(self.norm2.parameters_mut());
147 out.extend(self.conv2.parameters_mut());
148 if let Some(sc) = &mut self.conv_shortcut {
149 out.extend(sc.parameters_mut());
150 }
151 out
152 }
153
154 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
155 let mut out = Vec::new();
156 for (n, p) in self.norm1.named_parameters() {
157 out.push((format!("norm1.{n}"), p));
158 }
159 for (n, p) in self.conv1.named_parameters() {
160 out.push((format!("conv1.{n}"), p));
161 }
162 for (n, p) in self.norm2.named_parameters() {
163 out.push((format!("norm2.{n}"), p));
164 }
165 for (n, p) in self.conv2.named_parameters() {
166 out.push((format!("conv2.{n}"), p));
167 }
168 if let Some(sc) = &self.conv_shortcut {
169 for (n, p) in sc.named_parameters() {
170 out.push((format!("conv_shortcut.{n}"), p));
171 }
172 }
173 out
174 }
175
176 fn train(&mut self) {
177 self.training = true;
178 }
179 fn eval(&mut self) {
180 self.training = false;
181 }
182 fn is_training(&self) -> bool {
183 self.training
184 }
185
186 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
187 let extract = |prefix: &str| -> StateDict<T> {
188 let p = format!("{prefix}.");
189 state
190 .iter()
191 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
192 .collect()
193 };
194 if strict {
195 for k in state.keys() {
196 let ok = k.starts_with("norm1.")
197 || k.starts_with("conv1.")
198 || k.starts_with("norm2.")
199 || k.starts_with("conv2.")
200 || k.starts_with("conv_shortcut.");
201 if !ok {
202 return Err(FerrotorchError::InvalidArgument {
203 message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
204 });
205 }
206 }
207 }
208 self.norm1.load_state_dict(&extract("norm1"), strict)?;
209 self.conv1.load_state_dict(&extract("conv1"), strict)?;
210 self.norm2.load_state_dict(&extract("norm2"), strict)?;
211 self.conv2.load_state_dict(&extract("conv2"), strict)?;
212 if let Some(sc) = self.conv_shortcut.as_mut() {
213 sc.load_state_dict(&extract("conv_shortcut"), strict)?;
214 }
215 Ok(())
216 }
217}
218
/// Single-head spatial self-attention block with a residual connection.
///
/// Attention runs over the `H * W` positions of a `[B, C, H, W]` input. Field names
/// double as parameter key prefixes; `to_out_0` is exposed as `to_out.0`.
#[derive(Debug)]
pub struct AttnBlock2D<T: Float> {
    /// Group normalization applied before the projections.
    pub group_norm: GroupNorm<T>,
    /// Query projection, `channels -> channels`.
    pub to_q: Linear<T>,
    /// Key projection, `channels -> channels`.
    pub to_k: Linear<T>,
    /// Value projection, `channels -> channels`.
    pub to_v: Linear<T>,
    /// Output projection, `channels -> channels` (parameter prefix `to_out.0`).
    pub to_out_0: Linear<T>,
    /// Channel count `forward` requires at axis 1 of its input.
    channels: usize,
    /// Mode flag reported by `is_training`.
    training: bool,
}
259
260impl<T: Float> AttnBlock2D<T> {
261 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
269 let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
270 let to_q = Linear::<T>::new(channels, channels, true)?;
271 let to_k = Linear::<T>::new(channels, channels, true)?;
272 let to_v = Linear::<T>::new(channels, channels, true)?;
273 let to_out_0 = Linear::<T>::new(channels, channels, true)?;
274 Ok(Self {
275 group_norm,
276 to_q,
277 to_k,
278 to_v,
279 to_out_0,
280 channels,
281 training: false,
282 })
283 }
284}
285
286impl<T: Float> Module<T> for AttnBlock2D<T> {
287 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
288 if input.ndim() != 4 || input.shape()[1] != self.channels {
289 return Err(FerrotorchError::ShapeMismatch {
290 message: format!(
291 "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
292 self.channels,
293 input.shape()
294 ),
295 });
296 }
297 let b = input.shape()[0];
298 let c = input.shape()[1];
299 let h = input.shape()[2];
300 let w = input.shape()[3];
301 let hw = h * w;
302
303 let residual = input.clone();
305
306 let hidden = input
311 .reshape_t(&[b as isize, c as isize, hw as isize])?
312 .transpose(1, 2)?
313 .contiguous()?;
314 let normed_hwc = self
316 .group_norm
317 .forward(&hidden.transpose(1, 2)?.contiguous()?)?
318 .transpose(1, 2)?
319 .contiguous()?;
320
321 let q = self.to_q.forward(&normed_hwc)?; let k = self.to_k.forward(&normed_hwc)?; let v = self.to_v.forward(&normed_hwc)?; let scale = (c as f64).sqrt().recip();
334 let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
335 message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
336 })?;
337 let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
338 let k_t = k.transpose(1, 2)?.contiguous()?; let scores = q.bmm(&k_t)?; let scores_scaled =
341 ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
342 let probs = scores_scaled.softmax()?; let attended = probs.bmm(&v)?; let projected = self.to_out_0.forward(&attended)?; let back = projected
350 .transpose(1, 2)? .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
352 .contiguous()?;
353 ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
354 }
355
356 fn parameters(&self) -> Vec<&Parameter<T>> {
357 let mut out = Vec::new();
358 out.extend(self.group_norm.parameters());
359 out.extend(self.to_q.parameters());
360 out.extend(self.to_k.parameters());
361 out.extend(self.to_v.parameters());
362 out.extend(self.to_out_0.parameters());
363 out
364 }
365
366 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
367 let mut out = Vec::new();
368 out.extend(self.group_norm.parameters_mut());
369 out.extend(self.to_q.parameters_mut());
370 out.extend(self.to_k.parameters_mut());
371 out.extend(self.to_v.parameters_mut());
372 out.extend(self.to_out_0.parameters_mut());
373 out
374 }
375
376 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
377 let mut out = Vec::new();
378 for (n, p) in self.group_norm.named_parameters() {
379 out.push((format!("group_norm.{n}"), p));
380 }
381 for (n, p) in self.to_q.named_parameters() {
382 out.push((format!("to_q.{n}"), p));
383 }
384 for (n, p) in self.to_k.named_parameters() {
385 out.push((format!("to_k.{n}"), p));
386 }
387 for (n, p) in self.to_v.named_parameters() {
388 out.push((format!("to_v.{n}"), p));
389 }
390 for (n, p) in self.to_out_0.named_parameters() {
391 out.push((format!("to_out.0.{n}"), p));
392 }
393 out
394 }
395
396 fn train(&mut self) {
397 self.training = true;
398 }
399 fn eval(&mut self) {
400 self.training = false;
401 }
402 fn is_training(&self) -> bool {
403 self.training
404 }
405
406 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
407 let extract = |prefix: &str| -> StateDict<T> {
408 let p = format!("{prefix}.");
409 state
410 .iter()
411 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
412 .collect()
413 };
414 if strict {
415 for k in state.keys() {
416 let ok = k.starts_with("group_norm.")
417 || k.starts_with("to_q.")
418 || k.starts_with("to_k.")
419 || k.starts_with("to_v.")
420 || k.starts_with("to_out.0.");
421 if !ok {
422 return Err(FerrotorchError::InvalidArgument {
423 message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
424 });
425 }
426 }
427 }
428 self.group_norm
429 .load_state_dict(&extract("group_norm"), strict)?;
430 self.to_q.load_state_dict(&extract("to_q"), strict)?;
431 self.to_k.load_state_dict(&extract("to_k"), strict)?;
432 self.to_v.load_state_dict(&extract("to_v"), strict)?;
433 self.to_out_0
434 .load_state_dict(&extract("to_out.0"), strict)?;
435 Ok(())
436 }
437}
438
/// 2x nearest-neighbour spatial upsampling followed by a channel-preserving 3x3 conv.
#[derive(Debug)]
pub struct Upsample2D<T: Float> {
    /// 3x3 convolution `channels -> channels` applied after upsampling.
    pub conv: Conv2d<T>,
    /// Channel count `forward` requires at axis 1 of its input.
    channels: usize,
    /// Mode flag reported by `is_training`.
    training: bool,
}
455
456impl<T: Float> Upsample2D<T> {
457 pub fn new(channels: usize) -> FerrotorchResult<Self> {
463 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
464 Ok(Self {
465 conv,
466 channels,
467 training: false,
468 })
469 }
470}
471
472impl<T: Float> Module<T> for Upsample2D<T> {
473 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
474 if input.ndim() != 4 || input.shape()[1] != self.channels {
475 return Err(FerrotorchError::ShapeMismatch {
476 message: format!(
477 "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
478 self.channels,
479 input.shape()
480 ),
481 });
482 }
483 let h = input.shape()[2];
484 let w = input.shape()[3];
485 let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
486 let upsampled = up.forward(input)?;
487 self.conv.forward(&upsampled)
488 }
489
490 fn parameters(&self) -> Vec<&Parameter<T>> {
491 self.conv.parameters()
492 }
493 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
494 self.conv.parameters_mut()
495 }
496 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
497 self.conv
498 .named_parameters()
499 .into_iter()
500 .map(|(n, p)| (format!("conv.{n}"), p))
501 .collect()
502 }
503 fn train(&mut self) {
504 self.training = true;
505 }
506 fn eval(&mut self) {
507 self.training = false;
508 }
509 fn is_training(&self) -> bool {
510 self.training
511 }
512 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
513 let conv_sd: StateDict<T> = state
514 .iter()
515 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
516 .collect();
517 if strict {
518 for k in state.keys() {
519 if !k.starts_with("conv.") {
520 return Err(FerrotorchError::InvalidArgument {
521 message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
522 });
523 }
524 }
525 }
526 self.conv.load_state_dict(&conv_sd, strict)
527 }
528}
529
/// Spatial downsampling via a stride-2, channel-preserving 3x3 convolution.
#[derive(Debug)]
pub struct Downsample2D<T: Float> {
    /// 3x3 convolution `channels -> channels` with stride 2.
    pub conv: Conv2d<T>,
    /// Channel count `forward` requires at axis 1 of its input.
    channels: usize,
    /// Mode flag reported by `is_training`.
    training: bool,
}
546
547impl<T: Float> Downsample2D<T> {
548 pub fn new(channels: usize) -> FerrotorchResult<Self> {
554 let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
555 Ok(Self {
556 conv,
557 channels,
558 training: false,
559 })
560 }
561}
562
563impl<T: Float> Module<T> for Downsample2D<T> {
564 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
565 if input.ndim() != 4 || input.shape()[1] != self.channels {
566 return Err(FerrotorchError::ShapeMismatch {
567 message: format!(
568 "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
569 self.channels,
570 input.shape()
571 ),
572 });
573 }
574 self.conv.forward(input)
575 }
576 fn parameters(&self) -> Vec<&Parameter<T>> {
577 self.conv.parameters()
578 }
579 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
580 self.conv.parameters_mut()
581 }
582 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
583 self.conv
584 .named_parameters()
585 .into_iter()
586 .map(|(n, p)| (format!("conv.{n}"), p))
587 .collect()
588 }
589 fn train(&mut self) {
590 self.training = true;
591 }
592 fn eval(&mut self) {
593 self.training = false;
594 }
595 fn is_training(&self) -> bool {
596 self.training
597 }
598 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
599 let conv_sd: StateDict<T> = state
600 .iter()
601 .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
602 .collect();
603 if strict {
604 for k in state.keys() {
605 if !k.starts_with("conv.") {
606 return Err(FerrotorchError::InvalidArgument {
607 message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
608 });
609 }
610 }
611 }
612 self.conv.load_state_dict(&conv_sd, strict)
613 }
614}
615
/// Decoder stage: a stack of `ResnetBlock2D`s optionally followed by an `Upsample2D`.
///
/// Parameter keys are `resnets.{i}.*` and, when present, `upsamplers.0.*`.
#[derive(Debug)]
pub struct UpDecoderBlock2D<T: Float> {
    /// Residual blocks applied in order; the first adapts the input channel count.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional 2x upsampler applied after the resnets (key prefix `upsamplers.0`).
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Mode flag reported by `is_training`.
    training: bool,
}
637
638impl<T: Float> UpDecoderBlock2D<T> {
639 pub fn new(
646 in_channels: usize,
647 out_channels: usize,
648 num_resnets: usize,
649 norm_num_groups: usize,
650 eps: f64,
651 add_upsample: bool,
652 ) -> FerrotorchResult<Self> {
653 if num_resnets == 0 {
654 return Err(FerrotorchError::InvalidArgument {
655 message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
656 });
657 }
658 let mut resnets = Vec::with_capacity(num_resnets);
659 for i in 0..num_resnets {
660 let in_c = if i == 0 { in_channels } else { out_channels };
661 resnets.push(ResnetBlock2D::<T>::new(
662 in_c,
663 out_channels,
664 norm_num_groups,
665 eps,
666 )?);
667 }
668 let upsamplers_0 = if add_upsample {
669 Some(Upsample2D::<T>::new(out_channels)?)
670 } else {
671 None
672 };
673 Ok(Self {
674 resnets,
675 upsamplers_0,
676 training: false,
677 })
678 }
679}
680
681impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
682 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
683 let mut h = input.clone();
684 for r in &self.resnets {
685 h = r.forward(&h)?;
686 }
687 if let Some(u) = &self.upsamplers_0 {
688 h = u.forward(&h)?;
689 }
690 Ok(h)
691 }
692
693 fn parameters(&self) -> Vec<&Parameter<T>> {
694 let mut out = Vec::new();
695 for r in &self.resnets {
696 out.extend(r.parameters());
697 }
698 if let Some(u) = &self.upsamplers_0 {
699 out.extend(u.parameters());
700 }
701 out
702 }
703
704 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
705 let mut out = Vec::new();
706 for r in &mut self.resnets {
707 out.extend(r.parameters_mut());
708 }
709 if let Some(u) = self.upsamplers_0.as_mut() {
710 out.extend(u.parameters_mut());
711 }
712 out
713 }
714
715 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
716 let mut out = Vec::new();
717 for (i, r) in self.resnets.iter().enumerate() {
718 for (n, p) in r.named_parameters() {
719 out.push((format!("resnets.{i}.{n}"), p));
720 }
721 }
722 if let Some(u) = &self.upsamplers_0 {
723 for (n, p) in u.named_parameters() {
724 out.push((format!("upsamplers.0.{n}"), p));
725 }
726 }
727 out
728 }
729
730 fn train(&mut self) {
731 self.training = true;
732 for r in &mut self.resnets {
733 r.train();
734 }
735 if let Some(u) = self.upsamplers_0.as_mut() {
736 u.train();
737 }
738 }
739 fn eval(&mut self) {
740 self.training = false;
741 for r in &mut self.resnets {
742 r.eval();
743 }
744 if let Some(u) = self.upsamplers_0.as_mut() {
745 u.eval();
746 }
747 }
748 fn is_training(&self) -> bool {
749 self.training
750 }
751
752 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
753 let extract = |prefix: &str| -> StateDict<T> {
754 let p = format!("{prefix}.");
755 state
756 .iter()
757 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
758 .collect()
759 };
760 if strict {
761 for k in state.keys() {
762 let ok =
763 k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
764 if !ok {
765 return Err(FerrotorchError::InvalidArgument {
766 message: format!(
767 "unexpected key in UpDecoderBlock2D state_dict: \"{k}\""
768 ),
769 });
770 }
771 }
772 }
773 for (i, r) in self.resnets.iter_mut().enumerate() {
774 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
775 }
776 if let Some(u) = self.upsamplers_0.as_mut() {
777 u.load_state_dict(&extract("upsamplers.0"), strict)?;
778 } else if strict {
779 for k in state.keys() {
781 if k.starts_with("upsamplers.") {
782 return Err(FerrotorchError::InvalidArgument {
783 message: format!(
784 "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
785 ),
786 });
787 }
788 }
789 }
790 Ok(())
791 }
792}
793
/// Encoder stage: a stack of `ResnetBlock2D`s optionally followed by a `Downsample2D`.
///
/// Parameter keys are `resnets.{i}.*` and, when present, `downsamplers.0.*`.
#[derive(Debug)]
pub struct DownEncoderBlock2D<T: Float> {
    /// Residual blocks applied in order; the first adapts the input channel count.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional stride-2 downsampler applied after the resnets (key prefix
    /// `downsamplers.0`).
    pub downsamplers_0: Option<Downsample2D<T>>,
    /// Mode flag reported by `is_training`.
    training: bool,
}
824
825impl<T: Float> DownEncoderBlock2D<T> {
826 pub fn new(
834 in_channels: usize,
835 out_channels: usize,
836 num_resnets: usize,
837 norm_num_groups: usize,
838 eps: f64,
839 add_downsample: bool,
840 ) -> FerrotorchResult<Self> {
841 if num_resnets == 0 {
842 return Err(FerrotorchError::InvalidArgument {
843 message: "DownEncoderBlock2D: num_resnets must be > 0".into(),
844 });
845 }
846 let mut resnets = Vec::with_capacity(num_resnets);
847 for i in 0..num_resnets {
848 let in_c = if i == 0 { in_channels } else { out_channels };
849 resnets.push(ResnetBlock2D::<T>::new(
850 in_c,
851 out_channels,
852 norm_num_groups,
853 eps,
854 )?);
855 }
856 let downsamplers_0 = if add_downsample {
857 Some(Downsample2D::<T>::new(out_channels)?)
858 } else {
859 None
860 };
861 Ok(Self {
862 resnets,
863 downsamplers_0,
864 training: false,
865 })
866 }
867}
868
869impl<T: Float> Module<T> for DownEncoderBlock2D<T> {
870 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
871 let mut h = input.clone();
872 for r in &self.resnets {
873 h = r.forward(&h)?;
874 }
875 if let Some(d) = &self.downsamplers_0 {
876 h = d.forward(&h)?;
877 }
878 Ok(h)
879 }
880
881 fn parameters(&self) -> Vec<&Parameter<T>> {
882 let mut out = Vec::new();
883 for r in &self.resnets {
884 out.extend(r.parameters());
885 }
886 if let Some(d) = &self.downsamplers_0 {
887 out.extend(d.parameters());
888 }
889 out
890 }
891
892 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
893 let mut out = Vec::new();
894 for r in &mut self.resnets {
895 out.extend(r.parameters_mut());
896 }
897 if let Some(d) = self.downsamplers_0.as_mut() {
898 out.extend(d.parameters_mut());
899 }
900 out
901 }
902
903 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
904 let mut out = Vec::new();
905 for (i, r) in self.resnets.iter().enumerate() {
906 for (n, p) in r.named_parameters() {
907 out.push((format!("resnets.{i}.{n}"), p));
908 }
909 }
910 if let Some(d) = &self.downsamplers_0 {
911 for (n, p) in d.named_parameters() {
912 out.push((format!("downsamplers.0.{n}"), p));
913 }
914 }
915 out
916 }
917
918 fn train(&mut self) {
919 self.training = true;
920 for r in &mut self.resnets {
921 r.train();
922 }
923 if let Some(d) = self.downsamplers_0.as_mut() {
924 d.train();
925 }
926 }
927 fn eval(&mut self) {
928 self.training = false;
929 for r in &mut self.resnets {
930 r.eval();
931 }
932 if let Some(d) = self.downsamplers_0.as_mut() {
933 d.eval();
934 }
935 }
936 fn is_training(&self) -> bool {
937 self.training
938 }
939
940 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
941 let extract = |prefix: &str| -> StateDict<T> {
942 let p = format!("{prefix}.");
943 state
944 .iter()
945 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
946 .collect()
947 };
948 if strict {
949 for k in state.keys() {
950 let ok =
951 k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
952 if !ok {
953 return Err(FerrotorchError::InvalidArgument {
954 message: format!(
955 "unexpected key in DownEncoderBlock2D state_dict: \"{k}\""
956 ),
957 });
958 }
959 }
960 }
961 for (i, r) in self.resnets.iter_mut().enumerate() {
962 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
963 }
964 if let Some(d) = self.downsamplers_0.as_mut() {
965 d.load_state_dict(&extract("downsamplers.0"), strict)?;
966 } else if strict {
967 for k in state.keys() {
968 if k.starts_with("downsamplers.") {
969 return Err(FerrotorchError::InvalidArgument {
970 message: format!(
971 "DownEncoderBlock2D has no downsampler but state_dict contains \"{k}\""
972 ),
973 });
974 }
975 }
976 }
977 Ok(())
978 }
979}
980
/// Middle block: `resnets[0]`, then alternating `attentions[i]` / `resnets[i + 1]`.
///
/// Parameter keys are `attentions.{i}.*` and `resnets.{i}.*`.
#[derive(Debug)]
pub struct UNetMidBlock2D<T: Float> {
    /// Residual blocks; `new` creates two, both `channels -> channels`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Attention blocks interleaved between the resnets; `new` creates one.
    pub attentions: Vec<AttnBlock2D<T>>,
    /// Feature channel count the block was built for.
    channels: usize,
    /// Mode flag reported by `is_training`.
    training: bool,
}
1004
1005impl<T: Float> UNetMidBlock2D<T> {
1006 pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
1012 let resnets = vec![
1014 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1015 ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1016 ];
1017 let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
1018 Ok(Self {
1019 resnets,
1020 attentions,
1021 channels,
1022 training: false,
1023 })
1024 }
1025}
1026
1027impl<T: Float> Module<T> for UNetMidBlock2D<T> {
1028 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1029 let mut h = self.resnets[0].forward(input)?;
1034 for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
1035 h = attn.forward(&h)?;
1036 h = resnet.forward(&h)?;
1037 }
1038 Ok(h)
1039 }
1040
1041 fn parameters(&self) -> Vec<&Parameter<T>> {
1042 let mut out = Vec::new();
1043 for a in &self.attentions {
1048 out.extend(a.parameters());
1049 }
1050 for r in &self.resnets {
1051 out.extend(r.parameters());
1052 }
1053 out
1054 }
1055
1056 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1057 let mut out = Vec::new();
1058 for a in &mut self.attentions {
1059 out.extend(a.parameters_mut());
1060 }
1061 for r in &mut self.resnets {
1062 out.extend(r.parameters_mut());
1063 }
1064 out
1065 }
1066
1067 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1068 let mut out = Vec::new();
1069 for (i, a) in self.attentions.iter().enumerate() {
1072 for (n, p) in a.named_parameters() {
1073 out.push((format!("attentions.{i}.{n}"), p));
1074 }
1075 }
1076 for (i, r) in self.resnets.iter().enumerate() {
1077 for (n, p) in r.named_parameters() {
1078 out.push((format!("resnets.{i}.{n}"), p));
1079 }
1080 }
1081 out
1082 }
1083
1084 fn train(&mut self) {
1085 self.training = true;
1086 for r in &mut self.resnets {
1087 r.train();
1088 }
1089 for a in &mut self.attentions {
1090 a.train();
1091 }
1092 }
1093 fn eval(&mut self) {
1094 self.training = false;
1095 for r in &mut self.resnets {
1096 r.eval();
1097 }
1098 for a in &mut self.attentions {
1099 a.eval();
1100 }
1101 }
1102 fn is_training(&self) -> bool {
1103 self.training
1104 }
1105
1106 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1107 let extract = |prefix: &str| -> StateDict<T> {
1108 let p = format!("{prefix}.");
1109 state
1110 .iter()
1111 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1112 .collect()
1113 };
1114 if strict {
1115 for k in state.keys() {
1116 let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
1117 if !ok {
1118 return Err(FerrotorchError::InvalidArgument {
1119 message: format!(
1120 "unexpected key in UNetMidBlock2D state_dict: \"{k}\""
1121 ),
1122 });
1123 }
1124 }
1125 }
1126 for (i, a) in self.attentions.iter_mut().enumerate() {
1127 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
1128 }
1129 for (i, r) in self.resnets.iter_mut().enumerate() {
1130 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
1131 }
1132 let _ = self.channels; let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
1135 }
1136}
1137
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds a CPU tensor (no grad) of `shape` with every element equal to `value`.
    fn filled(value: f32, shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Collects the parameter names reported by `module`.
    fn param_names<M: Module<f32>>(module: &M) -> Vec<String> {
        module
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect()
    }

    /// Asserts that every name in `expected` appears in `names`.
    fn assert_has_all(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let out = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let out = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_has_all(
            &param_names(&block),
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    #[test]
    fn attn_shape_and_residual() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let out = attn.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_has_all(
            &param_names(&attn),
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let out = up.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 3, 3]);
    }

    #[test]
    fn down_encoder_block_shape_with_downsample() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, &[1, 8, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_shape_no_downsample() {
        let block = DownEncoderBlock2D::<f32>::new(16, 16, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, &[1, 16, 2, 2])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_named_parameters_layout() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        assert_has_all(
            &param_names(&block),
            &[
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.0.conv_shortcut.weight",
                "resnets.1.norm1.weight",
                "downsamplers.0.conv.weight",
                "downsamplers.0.conv.bias",
            ],
        );
    }

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let out = mid.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_has_all(
            &param_names(&mid),
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}
1341}