1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
56use ferrotorch_nn::module::{Module, StateDict};
57use ferrotorch_nn::parameter::Parameter;
58use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
59
60use crate::attention::Transformer2DModel;
61use crate::blocks::{Downsample2D, Upsample2D};
62use crate::resnet_block_time::ResnetBlock2DTime;
63use crate::time_embedding::{TimestepEmbedding, Timesteps};
64use crate::unet_config::UNet2DConditionConfig;
65
66#[derive(Debug)]
81pub struct CrossAttnDownBlock2D<T: Float> {
82 pub resnets: Vec<ResnetBlock2DTime<T>>,
84 pub attentions: Vec<Transformer2DModel<T>>,
86 pub downsamplers_0: Option<Downsample2D<T>>,
88 training: bool,
89}
90
91impl<T: Float> CrossAttnDownBlock2D<T> {
92 #[allow(clippy::too_many_arguments)]
98 pub fn new(
99 in_channels: usize,
100 out_channels: usize,
101 temb_channels: usize,
102 num_layers: usize,
103 attention_head_dim: usize,
104 cross_attention_dim: usize,
105 norm_num_groups: usize,
106 add_downsample: bool,
107 transformer_layers_per_block: usize,
108 ) -> FerrotorchResult<Self> {
109 let heads = attention_head_dim;
117 let dim_head = out_channels / heads;
118 let mut resnets = Vec::with_capacity(num_layers);
119 let mut attentions = Vec::with_capacity(num_layers);
120 for j in 0..num_layers {
121 let in_c = if j == 0 { in_channels } else { out_channels };
122 resnets.push(ResnetBlock2DTime::<T>::new(
123 in_c,
124 out_channels,
125 temb_channels,
126 norm_num_groups,
127 1e-5,
128 )?);
129 attentions.push(Transformer2DModel::<T>::new(
130 out_channels,
131 heads,
132 dim_head,
133 transformer_layers_per_block,
134 cross_attention_dim,
135 norm_num_groups,
136 )?);
137 }
138 let downsamplers_0 = if add_downsample {
139 Some(Downsample2D::<T>::new(out_channels)?)
140 } else {
141 None
142 };
143 Ok(Self {
144 resnets,
145 attentions,
146 downsamplers_0,
147 training: false,
148 })
149 }
150
151 pub fn forward_t(
161 &self,
162 x: &Tensor<T>,
163 temb: &Tensor<T>,
164 encoder_hidden_states: &Tensor<T>,
165 ) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)> {
166 let mut h = x.clone();
167 let mut skips = Vec::with_capacity(self.resnets.len() + 1);
168 for (r, a) in self.resnets.iter().zip(self.attentions.iter()) {
169 h = r.forward_t(&h, temb)?;
170 h = a.forward_xattn(&h, encoder_hidden_states)?;
171 skips.push(h.clone());
172 }
173 if let Some(d) = &self.downsamplers_0 {
174 h = d.forward(&h)?;
175 skips.push(h.clone());
176 }
177 Ok((h, skips))
178 }
179
180 fn named_params_internal(&self) -> Vec<(String, &Parameter<T>)> {
181 let mut o = Vec::new();
182 for (i, r) in self.resnets.iter().enumerate() {
183 for (n, p) in r.named_parameters() {
184 o.push((format!("resnets.{i}.{n}"), p));
185 }
186 }
187 for (i, a) in self.attentions.iter().enumerate() {
188 for (n, p) in a.named_parameters() {
189 o.push((format!("attentions.{i}.{n}"), p));
190 }
191 }
192 if let Some(d) = &self.downsamplers_0 {
193 for (n, p) in d.named_parameters() {
194 o.push((format!("downsamplers.0.{n}"), p));
195 }
196 }
197 o
198 }
199}
200
201impl<T: Float> Module<T> for CrossAttnDownBlock2D<T> {
202 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
203 Err(FerrotorchError::InvalidArgument {
204 message: "CrossAttnDownBlock2D::forward requires (x, temb, ehs) — call \
205 forward_t instead"
206 .into(),
207 })
208 }
209
210 fn parameters(&self) -> Vec<&Parameter<T>> {
211 let mut o = Vec::new();
212 for r in &self.resnets {
213 o.extend(r.parameters());
214 }
215 for a in &self.attentions {
216 o.extend(a.parameters());
217 }
218 if let Some(d) = &self.downsamplers_0 {
219 o.extend(d.parameters());
220 }
221 o
222 }
223 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
224 let mut o = Vec::new();
225 for r in &mut self.resnets {
226 o.extend(r.parameters_mut());
227 }
228 for a in &mut self.attentions {
229 o.extend(a.parameters_mut());
230 }
231 if let Some(d) = self.downsamplers_0.as_mut() {
232 o.extend(d.parameters_mut());
233 }
234 o
235 }
236 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
237 self.named_params_internal()
238 }
239 fn train(&mut self) {
240 self.training = true;
241 }
242 fn eval(&mut self) {
243 self.training = false;
244 }
245 fn is_training(&self) -> bool {
246 self.training
247 }
248 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
249 let extract = |prefix: &str| -> StateDict<T> {
250 let p = format!("{prefix}.");
251 state
252 .iter()
253 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
254 .collect()
255 };
256 if strict {
257 for k in state.keys() {
258 let ok = k.starts_with("resnets.")
259 || k.starts_with("attentions.")
260 || k.starts_with("downsamplers.0.");
261 if !ok {
262 return Err(FerrotorchError::InvalidArgument {
263 message: format!(
264 "unexpected key in CrossAttnDownBlock2D state_dict: \"{k}\""
265 ),
266 });
267 }
268 }
269 }
270 for (i, r) in self.resnets.iter_mut().enumerate() {
271 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
272 }
273 for (i, a) in self.attentions.iter_mut().enumerate() {
274 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
275 }
276 if let Some(d) = self.downsamplers_0.as_mut() {
277 d.load_state_dict(&extract("downsamplers.0"), strict)?;
278 }
279 Ok(())
280 }
281}
282
283#[derive(Debug)]
290pub struct DownBlock2D<T: Float> {
291 pub resnets: Vec<ResnetBlock2DTime<T>>,
293 pub downsamplers_0: Option<Downsample2D<T>>,
295 training: bool,
296}
297
298impl<T: Float> DownBlock2D<T> {
299 pub fn new(
305 in_channels: usize,
306 out_channels: usize,
307 temb_channels: usize,
308 num_layers: usize,
309 norm_num_groups: usize,
310 add_downsample: bool,
311 ) -> FerrotorchResult<Self> {
312 let mut resnets = Vec::with_capacity(num_layers);
313 for j in 0..num_layers {
314 let in_c = if j == 0 { in_channels } else { out_channels };
315 resnets.push(ResnetBlock2DTime::<T>::new(
316 in_c,
317 out_channels,
318 temb_channels,
319 norm_num_groups,
320 1e-5,
321 )?);
322 }
323 let downsamplers_0 = if add_downsample {
324 Some(Downsample2D::<T>::new(out_channels)?)
325 } else {
326 None
327 };
328 Ok(Self {
329 resnets,
330 downsamplers_0,
331 training: false,
332 })
333 }
334
335 pub fn forward_t(
341 &self,
342 x: &Tensor<T>,
343 temb: &Tensor<T>,
344 ) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)> {
345 let mut h = x.clone();
346 let mut skips = Vec::with_capacity(self.resnets.len() + 1);
347 for r in &self.resnets {
348 h = r.forward_t(&h, temb)?;
349 skips.push(h.clone());
350 }
351 if let Some(d) = &self.downsamplers_0 {
352 h = d.forward(&h)?;
353 skips.push(h.clone());
354 }
355 Ok((h, skips))
356 }
357}
358
359impl<T: Float> Module<T> for DownBlock2D<T> {
360 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
361 Err(FerrotorchError::InvalidArgument {
362 message: "DownBlock2D::forward requires (x, temb) — call forward_t instead".into(),
363 })
364 }
365
366 fn parameters(&self) -> Vec<&Parameter<T>> {
367 let mut o = Vec::new();
368 for r in &self.resnets {
369 o.extend(r.parameters());
370 }
371 if let Some(d) = &self.downsamplers_0 {
372 o.extend(d.parameters());
373 }
374 o
375 }
376 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
377 let mut o = Vec::new();
378 for r in &mut self.resnets {
379 o.extend(r.parameters_mut());
380 }
381 if let Some(d) = self.downsamplers_0.as_mut() {
382 o.extend(d.parameters_mut());
383 }
384 o
385 }
386 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
387 let mut o = Vec::new();
388 for (i, r) in self.resnets.iter().enumerate() {
389 for (n, p) in r.named_parameters() {
390 o.push((format!("resnets.{i}.{n}"), p));
391 }
392 }
393 if let Some(d) = &self.downsamplers_0 {
394 for (n, p) in d.named_parameters() {
395 o.push((format!("downsamplers.0.{n}"), p));
396 }
397 }
398 o
399 }
400 fn train(&mut self) {
401 self.training = true;
402 }
403 fn eval(&mut self) {
404 self.training = false;
405 }
406 fn is_training(&self) -> bool {
407 self.training
408 }
409 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
410 let extract = |prefix: &str| -> StateDict<T> {
411 let p = format!("{prefix}.");
412 state
413 .iter()
414 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
415 .collect()
416 };
417 if strict {
418 for k in state.keys() {
419 let ok = k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
420 if !ok {
421 return Err(FerrotorchError::InvalidArgument {
422 message: format!("unexpected key in DownBlock2D state_dict: \"{k}\""),
423 });
424 }
425 }
426 }
427 for (i, r) in self.resnets.iter_mut().enumerate() {
428 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
429 }
430 if let Some(d) = self.downsamplers_0.as_mut() {
431 d.load_state_dict(&extract("downsamplers.0"), strict)?;
432 }
433 Ok(())
434 }
435}
436
437#[derive(Debug)]
455pub struct UNetMidBlock2DCrossAttn<T: Float> {
456 pub resnets: Vec<ResnetBlock2DTime<T>>,
458 pub attentions: Vec<Transformer2DModel<T>>,
460 training: bool,
461}
462
463impl<T: Float> UNetMidBlock2DCrossAttn<T> {
464 pub fn new(
470 channels: usize,
471 temb_channels: usize,
472 attention_head_dim: usize,
473 cross_attention_dim: usize,
474 norm_num_groups: usize,
475 transformer_layers_per_block: usize,
476 ) -> FerrotorchResult<Self> {
477 let heads = attention_head_dim;
479 let dim_head = channels / heads;
480 let r0 =
481 ResnetBlock2DTime::<T>::new(channels, channels, temb_channels, norm_num_groups, 1e-5)?;
482 let attn = Transformer2DModel::<T>::new(
483 channels,
484 heads,
485 dim_head,
486 transformer_layers_per_block,
487 cross_attention_dim,
488 norm_num_groups,
489 )?;
490 let r1 =
491 ResnetBlock2DTime::<T>::new(channels, channels, temb_channels, norm_num_groups, 1e-5)?;
492 Ok(Self {
493 resnets: vec![r0, r1],
494 attentions: vec![attn],
495 training: false,
496 })
497 }
498
499 pub fn forward_t(
505 &self,
506 x: &Tensor<T>,
507 temb: &Tensor<T>,
508 encoder_hidden_states: &Tensor<T>,
509 ) -> FerrotorchResult<Tensor<T>> {
510 let mut h = self.resnets[0].forward_t(x, temb)?;
511 for (a, r) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
512 h = a.forward_xattn(&h, encoder_hidden_states)?;
513 h = r.forward_t(&h, temb)?;
514 }
515 Ok(h)
516 }
517}
518
519impl<T: Float> Module<T> for UNetMidBlock2DCrossAttn<T> {
520 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
521 Err(FerrotorchError::InvalidArgument {
522 message: "UNetMidBlock2DCrossAttn::forward requires (x, temb, ehs) — call \
523 forward_t instead"
524 .into(),
525 })
526 }
527
528 fn parameters(&self) -> Vec<&Parameter<T>> {
529 let mut o = Vec::new();
530 for a in &self.attentions {
534 o.extend(a.parameters());
535 }
536 for r in &self.resnets {
537 o.extend(r.parameters());
538 }
539 o
540 }
541 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
542 let mut o = Vec::new();
543 for a in &mut self.attentions {
544 o.extend(a.parameters_mut());
545 }
546 for r in &mut self.resnets {
547 o.extend(r.parameters_mut());
548 }
549 o
550 }
551 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
552 let mut o = Vec::new();
553 for (i, a) in self.attentions.iter().enumerate() {
554 for (n, p) in a.named_parameters() {
555 o.push((format!("attentions.{i}.{n}"), p));
556 }
557 }
558 for (i, r) in self.resnets.iter().enumerate() {
559 for (n, p) in r.named_parameters() {
560 o.push((format!("resnets.{i}.{n}"), p));
561 }
562 }
563 o
564 }
565 fn train(&mut self) {
566 self.training = true;
567 }
568 fn eval(&mut self) {
569 self.training = false;
570 }
571 fn is_training(&self) -> bool {
572 self.training
573 }
574 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
575 let extract = |prefix: &str| -> StateDict<T> {
576 let p = format!("{prefix}.");
577 state
578 .iter()
579 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
580 .collect()
581 };
582 if strict {
583 for k in state.keys() {
584 let ok = k.starts_with("resnets.") || k.starts_with("attentions.");
585 if !ok {
586 return Err(FerrotorchError::InvalidArgument {
587 message: format!(
588 "unexpected key in UNetMidBlock2DCrossAttn state_dict: \"{k}\""
589 ),
590 });
591 }
592 }
593 }
594 for (i, a) in self.attentions.iter_mut().enumerate() {
595 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
596 }
597 for (i, r) in self.resnets.iter_mut().enumerate() {
598 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
599 }
600 Ok(())
601 }
602}
603
604#[derive(Debug)]
613pub struct CrossAttnUpBlock2D<T: Float> {
614 pub resnets: Vec<ResnetBlock2DTime<T>>,
616 pub attentions: Vec<Transformer2DModel<T>>,
618 pub upsamplers_0: Option<Upsample2D<T>>,
620 training: bool,
621}
622
623impl<T: Float> CrossAttnUpBlock2D<T> {
624 #[allow(clippy::too_many_arguments)]
630 pub fn new(
631 in_channels: usize,
632 out_channels: usize,
633 prev_output_channels: usize,
634 temb_channels: usize,
635 num_layers: usize,
636 attention_head_dim: usize,
637 cross_attention_dim: usize,
638 norm_num_groups: usize,
639 add_upsample: bool,
640 transformer_layers_per_block: usize,
641 ) -> FerrotorchResult<Self> {
642 let heads = attention_head_dim;
644 let dim_head = out_channels / heads;
645 let mut resnets = Vec::with_capacity(num_layers);
646 let mut attentions = Vec::with_capacity(num_layers);
647 for j in 0..num_layers {
660 let res_skip = if j == num_layers - 1 {
661 in_channels
662 } else {
663 out_channels
664 };
665 let resnet_in = if j == 0 {
666 prev_output_channels + res_skip
667 } else {
668 out_channels + res_skip
669 };
670 resnets.push(ResnetBlock2DTime::<T>::new(
671 resnet_in,
672 out_channels,
673 temb_channels,
674 norm_num_groups,
675 1e-5,
676 )?);
677 attentions.push(Transformer2DModel::<T>::new(
678 out_channels,
679 heads,
680 dim_head,
681 transformer_layers_per_block,
682 cross_attention_dim,
683 norm_num_groups,
684 )?);
685 }
686 let upsamplers_0 = if add_upsample {
687 Some(Upsample2D::<T>::new(out_channels)?)
688 } else {
689 None
690 };
691 Ok(Self {
692 resnets,
693 attentions,
694 upsamplers_0,
695 training: false,
696 })
697 }
698
699 pub fn forward_t(
705 &self,
706 x: &Tensor<T>,
707 skips: &[Tensor<T>],
708 temb: &Tensor<T>,
709 encoder_hidden_states: &Tensor<T>,
710 ) -> FerrotorchResult<Tensor<T>> {
711 if skips.len() != self.resnets.len() {
712 return Err(FerrotorchError::InvalidArgument {
713 message: format!(
714 "CrossAttnUpBlock2D::forward_t: got {} skips for {} resnets",
715 skips.len(),
716 self.resnets.len()
717 ),
718 });
719 }
720 let mut h = x.clone();
721 for ((r, a), skip) in self
722 .resnets
723 .iter()
724 .zip(self.attentions.iter())
725 .zip(skips.iter())
726 {
727 h = ferrotorch_core::grad_fns::shape::cat(&[h.clone(), skip.clone()], 1)?;
728 h = r.forward_t(&h, temb)?;
729 h = a.forward_xattn(&h, encoder_hidden_states)?;
730 }
731 if let Some(u) = &self.upsamplers_0 {
732 h = u.forward(&h)?;
733 }
734 Ok(h)
735 }
736}
737
738impl<T: Float> Module<T> for CrossAttnUpBlock2D<T> {
739 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
740 Err(FerrotorchError::InvalidArgument {
741 message: "CrossAttnUpBlock2D::forward requires (x, skips, temb, ehs) — call \
742 forward_t instead"
743 .into(),
744 })
745 }
746 fn parameters(&self) -> Vec<&Parameter<T>> {
747 let mut o = Vec::new();
748 for r in &self.resnets {
749 o.extend(r.parameters());
750 }
751 for a in &self.attentions {
752 o.extend(a.parameters());
753 }
754 if let Some(u) = &self.upsamplers_0 {
755 o.extend(u.parameters());
756 }
757 o
758 }
759 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
760 let mut o = Vec::new();
761 for r in &mut self.resnets {
762 o.extend(r.parameters_mut());
763 }
764 for a in &mut self.attentions {
765 o.extend(a.parameters_mut());
766 }
767 if let Some(u) = self.upsamplers_0.as_mut() {
768 o.extend(u.parameters_mut());
769 }
770 o
771 }
772 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
773 let mut o = Vec::new();
774 for (i, r) in self.resnets.iter().enumerate() {
775 for (n, p) in r.named_parameters() {
776 o.push((format!("resnets.{i}.{n}"), p));
777 }
778 }
779 for (i, a) in self.attentions.iter().enumerate() {
780 for (n, p) in a.named_parameters() {
781 o.push((format!("attentions.{i}.{n}"), p));
782 }
783 }
784 if let Some(u) = &self.upsamplers_0 {
785 for (n, p) in u.named_parameters() {
786 o.push((format!("upsamplers.0.{n}"), p));
787 }
788 }
789 o
790 }
791 fn train(&mut self) {
792 self.training = true;
793 }
794 fn eval(&mut self) {
795 self.training = false;
796 }
797 fn is_training(&self) -> bool {
798 self.training
799 }
800 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
801 let extract = |prefix: &str| -> StateDict<T> {
802 let p = format!("{prefix}.");
803 state
804 .iter()
805 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
806 .collect()
807 };
808 if strict {
809 for k in state.keys() {
810 let ok = k.starts_with("resnets.")
811 || k.starts_with("attentions.")
812 || k.starts_with("upsamplers.0.");
813 if !ok {
814 return Err(FerrotorchError::InvalidArgument {
815 message: format!(
816 "unexpected key in CrossAttnUpBlock2D state_dict: \"{k}\""
817 ),
818 });
819 }
820 }
821 }
822 for (i, r) in self.resnets.iter_mut().enumerate() {
823 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
824 }
825 for (i, a) in self.attentions.iter_mut().enumerate() {
826 a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
827 }
828 if let Some(u) = self.upsamplers_0.as_mut() {
829 u.load_state_dict(&extract("upsamplers.0"), strict)?;
830 }
831 Ok(())
832 }
833}
834
835#[derive(Debug)]
842pub struct UpBlock2D<T: Float> {
843 pub resnets: Vec<ResnetBlock2DTime<T>>,
845 pub upsamplers_0: Option<Upsample2D<T>>,
847 training: bool,
848}
849
850impl<T: Float> UpBlock2D<T> {
851 #[allow(clippy::too_many_arguments)]
857 pub fn new(
858 in_channels: usize,
859 out_channels: usize,
860 prev_output_channels: usize,
861 temb_channels: usize,
862 num_layers: usize,
863 norm_num_groups: usize,
864 add_upsample: bool,
865 ) -> FerrotorchResult<Self> {
866 let mut resnets = Vec::with_capacity(num_layers);
867 for j in 0..num_layers {
868 let res_skip = if j == num_layers - 1 {
869 in_channels
870 } else {
871 out_channels
872 };
873 let resnet_in = if j == 0 {
874 prev_output_channels + res_skip
875 } else {
876 out_channels + res_skip
877 };
878 resnets.push(ResnetBlock2DTime::<T>::new(
879 resnet_in,
880 out_channels,
881 temb_channels,
882 norm_num_groups,
883 1e-5,
884 )?);
885 }
886 let upsamplers_0 = if add_upsample {
887 Some(Upsample2D::<T>::new(out_channels)?)
888 } else {
889 None
890 };
891 Ok(Self {
892 resnets,
893 upsamplers_0,
894 training: false,
895 })
896 }
897
898 pub fn forward_t(
904 &self,
905 x: &Tensor<T>,
906 skips: &[Tensor<T>],
907 temb: &Tensor<T>,
908 ) -> FerrotorchResult<Tensor<T>> {
909 if skips.len() != self.resnets.len() {
910 return Err(FerrotorchError::InvalidArgument {
911 message: format!(
912 "UpBlock2D::forward_t: got {} skips for {} resnets",
913 skips.len(),
914 self.resnets.len()
915 ),
916 });
917 }
918 let mut h = x.clone();
919 for (r, skip) in self.resnets.iter().zip(skips.iter()) {
920 h = ferrotorch_core::grad_fns::shape::cat(&[h.clone(), skip.clone()], 1)?;
921 h = r.forward_t(&h, temb)?;
922 }
923 if let Some(u) = &self.upsamplers_0 {
924 h = u.forward(&h)?;
925 }
926 Ok(h)
927 }
928}
929
930impl<T: Float> Module<T> for UpBlock2D<T> {
931 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
932 Err(FerrotorchError::InvalidArgument {
933 message: "UpBlock2D::forward requires (x, skips, temb) — call forward_t instead".into(),
934 })
935 }
936 fn parameters(&self) -> Vec<&Parameter<T>> {
937 let mut o = Vec::new();
938 for r in &self.resnets {
939 o.extend(r.parameters());
940 }
941 if let Some(u) = &self.upsamplers_0 {
942 o.extend(u.parameters());
943 }
944 o
945 }
946 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
947 let mut o = Vec::new();
948 for r in &mut self.resnets {
949 o.extend(r.parameters_mut());
950 }
951 if let Some(u) = self.upsamplers_0.as_mut() {
952 o.extend(u.parameters_mut());
953 }
954 o
955 }
956 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
957 let mut o = Vec::new();
958 for (i, r) in self.resnets.iter().enumerate() {
959 for (n, p) in r.named_parameters() {
960 o.push((format!("resnets.{i}.{n}"), p));
961 }
962 }
963 if let Some(u) = &self.upsamplers_0 {
964 for (n, p) in u.named_parameters() {
965 o.push((format!("upsamplers.0.{n}"), p));
966 }
967 }
968 o
969 }
970 fn train(&mut self) {
971 self.training = true;
972 }
973 fn eval(&mut self) {
974 self.training = false;
975 }
976 fn is_training(&self) -> bool {
977 self.training
978 }
979 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
980 let extract = |prefix: &str| -> StateDict<T> {
981 let p = format!("{prefix}.");
982 state
983 .iter()
984 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
985 .collect()
986 };
987 if strict {
988 for k in state.keys() {
989 let ok = k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
990 if !ok {
991 return Err(FerrotorchError::InvalidArgument {
992 message: format!("unexpected key in UpBlock2D state_dict: \"{k}\""),
993 });
994 }
995 }
996 }
997 for (i, r) in self.resnets.iter_mut().enumerate() {
998 r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
999 }
1000 if let Some(u) = self.upsamplers_0.as_mut() {
1001 u.load_state_dict(&extract("upsamplers.0"), strict)?;
1002 }
1003 Ok(())
1004 }
1005}
1006
1007#[derive(Debug)]
1013pub enum AnyDownBlock<T: Float> {
1014 CrossAttn(CrossAttnDownBlock2D<T>),
1016 Plain(DownBlock2D<T>),
1018}
1019
1020impl<T: Float> AnyDownBlock<T> {
1021 fn forward_t(
1022 &self,
1023 x: &Tensor<T>,
1024 temb: &Tensor<T>,
1025 encoder_hidden_states: &Tensor<T>,
1026 ) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)> {
1027 match self {
1028 AnyDownBlock::CrossAttn(b) => b.forward_t(x, temb, encoder_hidden_states),
1029 AnyDownBlock::Plain(b) => b.forward_t(x, temb),
1030 }
1031 }
1032
1033 fn parameters(&self) -> Vec<&Parameter<T>> {
1034 match self {
1035 AnyDownBlock::CrossAttn(b) => b.parameters(),
1036 AnyDownBlock::Plain(b) => b.parameters(),
1037 }
1038 }
1039 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1040 match self {
1041 AnyDownBlock::CrossAttn(b) => b.parameters_mut(),
1042 AnyDownBlock::Plain(b) => b.parameters_mut(),
1043 }
1044 }
1045 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1046 match self {
1047 AnyDownBlock::CrossAttn(b) => b.named_parameters(),
1048 AnyDownBlock::Plain(b) => b.named_parameters(),
1049 }
1050 }
1051 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1052 match self {
1053 AnyDownBlock::CrossAttn(b) => b.load_state_dict(state, strict),
1054 AnyDownBlock::Plain(b) => b.load_state_dict(state, strict),
1055 }
1056 }
1057}
1058
1059#[derive(Debug)]
1061pub enum AnyUpBlock<T: Float> {
1062 CrossAttn(CrossAttnUpBlock2D<T>),
1064 Plain(UpBlock2D<T>),
1066}
1067
1068impl<T: Float> AnyUpBlock<T> {
1069 fn forward_t(
1070 &self,
1071 x: &Tensor<T>,
1072 skips: &[Tensor<T>],
1073 temb: &Tensor<T>,
1074 encoder_hidden_states: &Tensor<T>,
1075 ) -> FerrotorchResult<Tensor<T>> {
1076 match self {
1077 AnyUpBlock::CrossAttn(b) => b.forward_t(x, skips, temb, encoder_hidden_states),
1078 AnyUpBlock::Plain(b) => b.forward_t(x, skips, temb),
1079 }
1080 }
1081 fn num_resnets(&self) -> usize {
1082 match self {
1083 AnyUpBlock::CrossAttn(b) => b.resnets.len(),
1084 AnyUpBlock::Plain(b) => b.resnets.len(),
1085 }
1086 }
1087 fn parameters(&self) -> Vec<&Parameter<T>> {
1088 match self {
1089 AnyUpBlock::CrossAttn(b) => b.parameters(),
1090 AnyUpBlock::Plain(b) => b.parameters(),
1091 }
1092 }
1093 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1094 match self {
1095 AnyUpBlock::CrossAttn(b) => b.parameters_mut(),
1096 AnyUpBlock::Plain(b) => b.parameters_mut(),
1097 }
1098 }
1099 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1100 match self {
1101 AnyUpBlock::CrossAttn(b) => b.named_parameters(),
1102 AnyUpBlock::Plain(b) => b.named_parameters(),
1103 }
1104 }
1105 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1106 match self {
1107 AnyUpBlock::CrossAttn(b) => b.load_state_dict(state, strict),
1108 AnyUpBlock::Plain(b) => b.load_state_dict(state, strict),
1109 }
1110 }
1111}
1112
1113#[derive(Debug)]
1119pub struct UNet2DConditionModel<T: Float> {
1120 pub time_proj: Timesteps,
1122 pub time_embedding: TimestepEmbedding<T>,
1124 pub conv_in: Conv2d<T>,
1126 pub down_blocks: Vec<AnyDownBlock<T>>,
1128 pub mid_block: UNetMidBlock2DCrossAttn<T>,
1130 pub up_blocks: Vec<AnyUpBlock<T>>,
1133 pub conv_norm_out: GroupNorm<T>,
1135 pub conv_act: SiLU,
1137 pub conv_out: Conv2d<T>,
1139 pub config: UNet2DConditionConfig,
1141 training: bool,
1142}
1143
1144impl<T: Float> UNet2DConditionModel<T> {
1145 pub fn new(cfg: UNet2DConditionConfig) -> FerrotorchResult<Self> {
1152 cfg.validate()?;
1153 let temb_channels = cfg.time_embed_dim();
1154 let bocs = &cfg.block_out_channels;
1155 let groups = cfg.norm_num_groups;
1156 let num_blocks = bocs.len();
1157
1158 let time_proj = Timesteps::new(bocs[0], cfg.flip_sin_to_cos, cfg.freq_shift)?;
1162 let time_embedding = TimestepEmbedding::<T>::new(bocs[0], temb_channels)?;
1163
1164 let conv_in = Conv2d::<T>::new(cfg.in_channels, bocs[0], (3, 3), (1, 1), (1, 1), true)?;
1166
1167 let mut down_blocks: Vec<AnyDownBlock<T>> = Vec::with_capacity(num_blocks);
1169 let mut prev = bocs[0];
1170 for i in 0..num_blocks {
1171 let out_c = bocs[i];
1172 let is_final = i == num_blocks - 1;
1173 let add_downsample = !is_final;
1174 let block: AnyDownBlock<T> = if cfg.down_block_has_attn[i] {
1175 AnyDownBlock::CrossAttn(CrossAttnDownBlock2D::<T>::new(
1176 prev,
1177 out_c,
1178 temb_channels,
1179 cfg.layers_per_block,
1180 cfg.attention_head_dim,
1181 cfg.cross_attention_dim,
1182 groups,
1183 add_downsample,
1184 cfg.transformer_layers_per_block,
1185 )?)
1186 } else {
1187 AnyDownBlock::Plain(DownBlock2D::<T>::new(
1188 prev,
1189 out_c,
1190 temb_channels,
1191 cfg.layers_per_block,
1192 groups,
1193 add_downsample,
1194 )?)
1195 };
1196 down_blocks.push(block);
1197 prev = out_c;
1198 }
1199
1200 let mid_channels = bocs[num_blocks - 1];
1202 let mid_block = UNetMidBlock2DCrossAttn::<T>::new(
1203 mid_channels,
1204 temb_channels,
1205 cfg.attention_head_dim,
1206 cfg.cross_attention_dim,
1207 groups,
1208 cfg.transformer_layers_per_block,
1209 )?;
1210
1211 let mut up_blocks: Vec<AnyUpBlock<T>> = Vec::with_capacity(num_blocks);
1222 let reversed: Vec<usize> = bocs.iter().rev().copied().collect();
1223 let mut prev_up = mid_channels;
1224 for i in 0..num_blocks {
1225 let out_c = reversed[i];
1226 let in_c = reversed[(i + 1).min(num_blocks - 1)];
1227 let is_final = i == num_blocks - 1;
1228 let add_upsample = !is_final;
1229 let block: AnyUpBlock<T> = if cfg.up_block_has_attn[i] {
1230 AnyUpBlock::CrossAttn(CrossAttnUpBlock2D::<T>::new(
1231 in_c,
1232 out_c,
1233 prev_up,
1234 temb_channels,
1235 cfg.layers_per_block + 1,
1236 cfg.attention_head_dim,
1237 cfg.cross_attention_dim,
1238 groups,
1239 add_upsample,
1240 cfg.transformer_layers_per_block,
1241 )?)
1242 } else {
1243 AnyUpBlock::Plain(UpBlock2D::<T>::new(
1244 in_c,
1245 out_c,
1246 prev_up,
1247 temb_channels,
1248 cfg.layers_per_block + 1,
1249 groups,
1250 add_upsample,
1251 )?)
1252 };
1253 up_blocks.push(block);
1254 prev_up = out_c;
1255 }
1256
1257 let conv_norm_out = GroupNorm::<T>::new(groups, bocs[0], 1e-5, true)?;
1258 let conv_out = Conv2d::<T>::new(bocs[0], cfg.out_channels, (3, 3), (1, 1), (1, 1), true)?;
1259
1260 Ok(Self {
1261 time_proj,
1262 time_embedding,
1263 conv_in,
1264 down_blocks,
1265 mid_block,
1266 up_blocks,
1267 conv_norm_out,
1268 conv_act: SiLU::new(),
1269 conv_out,
1270 config: cfg,
1271 training: false,
1272 })
1273 }
1274
1275 pub fn forward_t(
1288 &self,
1289 sample: &Tensor<T>,
1290 timesteps: &Tensor<T>,
1291 encoder_hidden_states: &Tensor<T>,
1292 ) -> FerrotorchResult<Tensor<T>> {
1293 let cfg = &self.config;
1294 if sample.ndim() != 4 || sample.shape()[1] != cfg.in_channels {
1295 return Err(FerrotorchError::ShapeMismatch {
1296 message: format!(
1297 "UNet2DConditionModel: expected sample [B, {}, H, W], got {:?}",
1298 cfg.in_channels,
1299 sample.shape()
1300 ),
1301 });
1302 }
1303 if timesteps.ndim() != 1 {
1304 return Err(FerrotorchError::ShapeMismatch {
1305 message: format!(
1306 "UNet2DConditionModel: expected timesteps [B], got {:?}",
1307 timesteps.shape()
1308 ),
1309 });
1310 }
1311 if encoder_hidden_states.ndim() != 3
1312 || encoder_hidden_states.shape()[2] != cfg.cross_attention_dim
1313 {
1314 return Err(FerrotorchError::ShapeMismatch {
1315 message: format!(
1316 "UNet2DConditionModel: expected encoder_hidden_states [B, S, {}], got {:?}",
1317 cfg.cross_attention_dim,
1318 encoder_hidden_states.shape()
1319 ),
1320 });
1321 }
1322
1323 let t_enc = self.time_proj.forward_t(timesteps)?;
1325 let temb = self.time_embedding.forward(&t_enc)?;
1326
1327 let mut h = self.conv_in.forward(sample)?;
1329 let mut skips: Vec<Tensor<T>> = Vec::new();
1330 skips.push(h.clone());
1334
1335 for db in &self.down_blocks {
1337 let (out, mut block_skips) = db.forward_t(&h, &temb, encoder_hidden_states)?;
1338 h = out;
1339 skips.append(&mut block_skips);
1340 }
1341
1342 h = self.mid_block.forward_t(&h, &temb, encoder_hidden_states)?;
1344
1345 for ub in &self.up_blocks {
1348 let n = ub.num_resnets();
1349 if skips.len() < n {
1350 return Err(FerrotorchError::InvalidArgument {
1351 message: format!(
1352 "UNet2DConditionModel: up-block needs {n} skips, only {} left",
1353 skips.len()
1354 ),
1355 });
1356 }
1357 let split_at = skips.len() - n;
1358 let popped = skips.split_off(split_at);
1359 let popped_rev: Vec<Tensor<T>> = popped.into_iter().rev().collect();
1368 h = ub.forward_t(&h, &popped_rev, &temb, encoder_hidden_states)?;
1369 }
1370
1371 h = self.conv_norm_out.forward(&h)?;
1373 h = self.conv_act.forward(&h)?;
1374 self.conv_out.forward(&h)
1375 }
1376}
1377
1378impl<T: Float> Module<T> for UNet2DConditionModel<T> {
1379 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1380 Err(FerrotorchError::InvalidArgument {
1381 message: "UNet2DConditionModel::forward requires (sample, timesteps, ehs) — \
1382 call forward_t instead"
1383 .into(),
1384 })
1385 }
1386
1387 fn parameters(&self) -> Vec<&Parameter<T>> {
1388 let mut o = Vec::new();
1389 o.extend(self.time_embedding.parameters());
1390 o.extend(self.conv_in.parameters());
1391 for db in &self.down_blocks {
1392 o.extend(db.parameters());
1393 }
1394 o.extend(self.mid_block.parameters());
1395 for ub in &self.up_blocks {
1396 o.extend(ub.parameters());
1397 }
1398 o.extend(self.conv_norm_out.parameters());
1399 o.extend(self.conv_out.parameters());
1400 o
1401 }
1402
1403 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1404 let mut o = Vec::new();
1405 o.extend(self.time_embedding.parameters_mut());
1406 o.extend(self.conv_in.parameters_mut());
1407 for db in &mut self.down_blocks {
1408 o.extend(db.parameters_mut());
1409 }
1410 o.extend(self.mid_block.parameters_mut());
1411 for ub in &mut self.up_blocks {
1412 o.extend(ub.parameters_mut());
1413 }
1414 o.extend(self.conv_norm_out.parameters_mut());
1415 o.extend(self.conv_out.parameters_mut());
1416 o
1417 }
1418
1419 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1420 let mut o = Vec::new();
1421 for (n, p) in self.time_embedding.named_parameters() {
1422 o.push((format!("time_embedding.{n}"), p));
1423 }
1424 for (n, p) in self.conv_in.named_parameters() {
1425 o.push((format!("conv_in.{n}"), p));
1426 }
1427 for (i, db) in self.down_blocks.iter().enumerate() {
1428 for (n, p) in db.named_parameters() {
1429 o.push((format!("down_blocks.{i}.{n}"), p));
1430 }
1431 }
1432 for (n, p) in self.mid_block.named_parameters() {
1433 o.push((format!("mid_block.{n}"), p));
1434 }
1435 for (i, ub) in self.up_blocks.iter().enumerate() {
1436 for (n, p) in ub.named_parameters() {
1437 o.push((format!("up_blocks.{i}.{n}"), p));
1438 }
1439 }
1440 for (n, p) in self.conv_norm_out.named_parameters() {
1441 o.push((format!("conv_norm_out.{n}"), p));
1442 }
1443 for (n, p) in self.conv_out.named_parameters() {
1444 o.push((format!("conv_out.{n}"), p));
1445 }
1446 o
1447 }
1448 fn train(&mut self) {
1449 self.training = true;
1450 }
1451 fn eval(&mut self) {
1452 self.training = false;
1453 }
1454 fn is_training(&self) -> bool {
1455 self.training
1456 }
1457 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1458 let extract = |prefix: &str| -> StateDict<T> {
1459 let p = format!("{prefix}.");
1460 state
1461 .iter()
1462 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1463 .collect()
1464 };
1465 if strict {
1466 for k in state.keys() {
1467 let ok = k.starts_with("time_embedding.")
1468 || k.starts_with("conv_in.")
1469 || k.starts_with("down_blocks.")
1470 || k.starts_with("mid_block.")
1471 || k.starts_with("up_blocks.")
1472 || k.starts_with("conv_norm_out.")
1473 || k.starts_with("conv_out.");
1474 if !ok {
1475 return Err(FerrotorchError::InvalidArgument {
1476 message: format!(
1477 "unexpected key in UNet2DConditionModel state_dict: \"{k}\""
1478 ),
1479 });
1480 }
1481 }
1482 }
1483 self.time_embedding
1484 .load_state_dict(&extract("time_embedding"), strict)?;
1485 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
1486 for (i, db) in self.down_blocks.iter_mut().enumerate() {
1487 db.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
1488 }
1489 self.mid_block
1490 .load_state_dict(&extract("mid_block"), strict)?;
1491 for (i, ub) in self.up_blocks.iter_mut().enumerate() {
1492 ub.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
1493 }
1494 self.conv_norm_out
1495 .load_state_dict(&extract("conv_norm_out"), strict)?;
1496 self.conv_out
1497 .load_state_dict(&extract("conv_out"), strict)?;
1498 Ok(())
1499 }
1500}
1501
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// A deliberately small four-level config so the tests run quickly.
    ///
    /// It keeps the SD-1.x block layout: cross-attention in the first three down blocks
    /// and the last three up blocks.
    fn tiny_cfg() -> UNet2DConditionConfig {
        UNet2DConditionConfig {
            in_channels: 4,
            out_channels: 4,
            block_out_channels: vec![16, 32, 64, 64],
            layers_per_block: 1,
            attention_head_dim: 8,
            cross_attention_dim: 24,
            norm_num_groups: 4,
            sample_size: 8,
            flip_sin_to_cos: true,
            freq_shift: 0.0,
            transformer_layers_per_block: 1,
            down_block_has_attn: vec![true, true, true, false],
            up_block_has_attn: vec![false, true, true, true],
        }
    }

    /// The full down/mid/up pass must return a tensor shaped like the input sample.
    #[test]
    fn unet_forward_shape() {
        let cfg = tiny_cfg();
        let unet = UNet2DConditionModel::<f32>::new(cfg.clone()).unwrap();
        let sample = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4 * 8 * 8]),
            vec![1, 4, 8, 8],
            false,
        )
        .unwrap();
        let timesteps =
            Tensor::from_storage(TensorStorage::cpu(vec![5.0f32]), vec![1], false).unwrap();
        let ehs = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 7 * 24]),
            vec![1, 7, 24],
            false,
        )
        .unwrap();
        let y = unet.forward_t(&sample, &timesteps, &ehs).unwrap();
        assert_eq!(y.shape(), &[1, 4, 8, 8]);
    }

    /// A representative key from each child must appear in `named_parameters`.
    ///
    /// The checked names follow the canonical dotted layout for this model's submodules.
    #[test]
    fn unet_named_parameters_includes_canonical_keys() {
        let cfg = tiny_cfg();
        let unet = UNet2DConditionModel::<f32>::new(cfg).unwrap();
        let names: Vec<String> = unet
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        for k in [
            "time_embedding.linear_1.weight",
            "time_embedding.linear_2.bias",
            "conv_in.weight",
            "down_blocks.0.resnets.0.norm1.weight",
            "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
            "mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight",
            "mid_block.resnets.1.conv2.weight",
            "up_blocks.0.resnets.0.norm1.weight",
            "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
            "conv_norm_out.weight",
            "conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}