Skip to main content

ferrotorch_diffusion/
resnet_block_time.rs

//! `ResnetBlock2DTime` — the time-conditioned variant of the
//! `ResnetBlock2D` used by the SD UNet.
//!
//! Mirrors `diffusers.models.resnet.ResnetBlock2D` with the default
//! `time_embedding_norm = "default"`; the SD-1.5 UNet instantiates it with
//! `temb_channels = 1280`. Diffusers' optional dropout before `conv2` is not
//! modelled (it is a no-op at the default rate of 0.0):
//!
//! ```text
//! h = silu(norm1(x)); h = conv1(h)
//! t = silu(temb); t = time_emb_proj(t).view(B, out, 1, 1)
//! h = h + t
//! h = silu(norm2(h)); h = conv2(h)
//! r = x if in==out else conv_shortcut(x)
//! out = h + r       (output_scale_factor = 1.0)
//! ```
//!
//! State-dict layout (matches diffusers):
//!
//! ```text
//! norm1.{weight,bias}
//! conv1.{weight,bias}
//! time_emb_proj.{weight,bias}    [out_channels, temb_channels]
//! norm2.{weight,bias}
//! conv2.{weight,bias}
//! conv_shortcut.{weight,bias}    (iff in_channels != out_channels)
//! ```
27
28use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
29use ferrotorch_nn::module::{Module, StateDict};
30use ferrotorch_nn::parameter::Parameter;
31use ferrotorch_nn::{Conv2d, GroupNorm, Linear, SiLU};
32
33/// Time-conditioned residual block.
34#[derive(Debug)]
35pub struct ResnetBlock2DTime<T: Float> {
36    /// First GroupNorm.
37    pub norm1: GroupNorm<T>,
38    /// First Conv2d.
39    pub conv1: Conv2d<T>,
40    /// Linear over the time embedding (`temb_channels -> out_channels`).
41    pub time_emb_proj: Linear<T>,
42    /// Second GroupNorm.
43    pub norm2: GroupNorm<T>,
44    /// Second Conv2d.
45    pub conv2: Conv2d<T>,
46    /// Optional 1x1 shortcut.
47    pub conv_shortcut: Option<Conv2d<T>>,
48    activation: SiLU,
49    in_channels: usize,
50    out_channels: usize,
51    training: bool,
52}
53
54impl<T: Float> ResnetBlock2DTime<T> {
55    /// Build a randomly-initialized time-conditioned resnet block.
56    ///
57    /// # Errors
58    ///
59    /// Returns the underlying [`FerrotorchError`] on bad channel/group
60    /// config.
61    pub fn new(
62        in_channels: usize,
63        out_channels: usize,
64        temb_channels: usize,
65        norm_num_groups: usize,
66        eps: f64,
67    ) -> FerrotorchResult<Self> {
68        let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
69        let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
70        let time_emb_proj = Linear::<T>::new(temb_channels, out_channels, true)?;
71        let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
72        let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
73        let conv_shortcut = if in_channels == out_channels {
74            None
75        } else {
76            Some(Conv2d::<T>::new(
77                in_channels,
78                out_channels,
79                (1, 1),
80                (1, 1),
81                (0, 0),
82                true,
83            )?)
84        };
85        Ok(Self {
86            norm1,
87            conv1,
88            time_emb_proj,
89            norm2,
90            conv2,
91            conv_shortcut,
92            activation: SiLU::new(),
93            in_channels,
94            out_channels,
95            training: false,
96        })
97    }
98
99    /// Forward with the time embedding `temb` (shape `[B, temb_channels]`).
100    ///
101    /// `x` has shape `[B, in_channels, H, W]`; output is
102    /// `[B, out_channels, H, W]`.
103    ///
104    /// # Errors
105    ///
106    /// Returns [`FerrotorchError::ShapeMismatch`] for bad input ranks.
107    pub fn forward_t(&self, x: &Tensor<T>, temb: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
108        if x.ndim() != 4 || x.shape()[1] != self.in_channels {
109            return Err(FerrotorchError::ShapeMismatch {
110                message: format!(
111                    "ResnetBlock2DTime: expected x [B, {}, H, W], got {:?}",
112                    self.in_channels,
113                    x.shape()
114                ),
115            });
116        }
117        if temb.ndim() != 2 {
118            return Err(FerrotorchError::ShapeMismatch {
119                message: format!(
120                    "ResnetBlock2DTime: expected temb [B, temb_channels], got {:?}",
121                    temb.shape()
122                ),
123            });
124        }
125        let b = x.shape()[0];
126        // h = silu(norm1(x)); h = conv1(h)
127        let mut h = self.norm1.forward(x)?;
128        h = self.activation.forward(&h)?;
129        h = self.conv1.forward(&h)?;
130        // Time bias: silu(temb) -> Linear -> [B, out_channels, 1, 1]
131        let temb_silu = self.activation.forward(temb)?;
132        let temb_proj = self.time_emb_proj.forward(&temb_silu)?;
133        let temb_4d = temb_proj.reshape_t(&[
134            b as isize,
135            self.out_channels as isize,
136            1,
137            1,
138        ])?;
139        h = ferrotorch_core::grad_fns::arithmetic::add(&h, &temb_4d)?;
140        // h = silu(norm2(h)); h = conv2(h)
141        h = self.norm2.forward(&h)?;
142        h = self.activation.forward(&h)?;
143        h = self.conv2.forward(&h)?;
144        // Residual.
145        let res = if let Some(sc) = &self.conv_shortcut {
146            sc.forward(x)?
147        } else {
148            x.clone()
149        };
150        ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
151    }
152}
153
154impl<T: Float> Module<T> for ResnetBlock2DTime<T> {
155    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
156        Err(FerrotorchError::InvalidArgument {
157            message: "ResnetBlock2DTime::forward: time-conditioned block requires \
158                      a time embedding — call forward_t instead"
159                .into(),
160        })
161    }
162
163    fn parameters(&self) -> Vec<&Parameter<T>> {
164        let mut o = Vec::new();
165        o.extend(self.norm1.parameters());
166        o.extend(self.conv1.parameters());
167        o.extend(self.time_emb_proj.parameters());
168        o.extend(self.norm2.parameters());
169        o.extend(self.conv2.parameters());
170        if let Some(sc) = &self.conv_shortcut {
171            o.extend(sc.parameters());
172        }
173        o
174    }
175    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
176        let mut o = Vec::new();
177        o.extend(self.norm1.parameters_mut());
178        o.extend(self.conv1.parameters_mut());
179        o.extend(self.time_emb_proj.parameters_mut());
180        o.extend(self.norm2.parameters_mut());
181        o.extend(self.conv2.parameters_mut());
182        if let Some(sc) = self.conv_shortcut.as_mut() {
183            o.extend(sc.parameters_mut());
184        }
185        o
186    }
187    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
188        let mut o = Vec::new();
189        for (n, p) in self.norm1.named_parameters() {
190            o.push((format!("norm1.{n}"), p));
191        }
192        for (n, p) in self.conv1.named_parameters() {
193            o.push((format!("conv1.{n}"), p));
194        }
195        for (n, p) in self.time_emb_proj.named_parameters() {
196            o.push((format!("time_emb_proj.{n}"), p));
197        }
198        for (n, p) in self.norm2.named_parameters() {
199            o.push((format!("norm2.{n}"), p));
200        }
201        for (n, p) in self.conv2.named_parameters() {
202            o.push((format!("conv2.{n}"), p));
203        }
204        if let Some(sc) = &self.conv_shortcut {
205            for (n, p) in sc.named_parameters() {
206                o.push((format!("conv_shortcut.{n}"), p));
207            }
208        }
209        o
210    }
211    fn train(&mut self) {
212        self.training = true;
213    }
214    fn eval(&mut self) {
215        self.training = false;
216    }
217    fn is_training(&self) -> bool {
218        self.training
219    }
220    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
221        let extract = |prefix: &str| -> StateDict<T> {
222            let p = format!("{prefix}.");
223            state
224                .iter()
225                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
226                .collect()
227        };
228        if strict {
229            for k in state.keys() {
230                let ok = k.starts_with("norm1.")
231                    || k.starts_with("conv1.")
232                    || k.starts_with("time_emb_proj.")
233                    || k.starts_with("norm2.")
234                    || k.starts_with("conv2.")
235                    || k.starts_with("conv_shortcut.");
236                if !ok {
237                    return Err(FerrotorchError::InvalidArgument {
238                        message: format!(
239                            "unexpected key in ResnetBlock2DTime state_dict: \"{k}\""
240                        ),
241                    });
242                }
243            }
244        }
245        self.norm1.load_state_dict(&extract("norm1"), strict)?;
246        self.conv1.load_state_dict(&extract("conv1"), strict)?;
247        self.time_emb_proj
248            .load_state_dict(&extract("time_emb_proj"), strict)?;
249        self.norm2.load_state_dict(&extract("norm2"), strict)?;
250        self.conv2.load_state_dict(&extract("conv2"), strict)?;
251        if let Some(sc) = self.conv_shortcut.as_mut() {
252            sc.load_state_dict(&extract("conv_shortcut"), strict)?;
253        }
254        Ok(())
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Build a non-grad CPU `f32` tensor of `shape` with every element set to
    /// `value`.
    fn filled(shape: &[usize], value: f32) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    #[test]
    fn resnet_time_shape_same_channels() {
        let r = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        assert!(r.conv_shortcut.is_none());
        let y = r
            .forward_t(&filled(&[1, 16, 4, 4], 0.01), &filled(&[1, 32], 0.01))
            .unwrap();
        assert_eq!(y.shape(), &[1, 16, 4, 4]);
    }

    #[test]
    fn resnet_time_shape_change_channels() {
        let r = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        assert!(r.conv_shortcut.is_some());
        let y = r
            .forward_t(&filled(&[1, 16, 4, 4], 0.01), &filled(&[1, 32], 0.01))
            .unwrap();
        assert_eq!(y.shape(), &[1, 32, 4, 4]);
    }

    #[test]
    fn resnet_time_named_parameters() {
        let r = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        let names: Vec<String> = r.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "norm1.weight",
            "norm1.bias",
            "conv1.weight",
            "conv1.bias",
            "time_emb_proj.weight",
            "time_emb_proj.bias",
            "norm2.weight",
            "norm2.bias",
            "conv2.weight",
            "conv2.bias",
            "conv_shortcut.weight",
            "conv_shortcut.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
        assert_eq!(r.parameters().len(), names.len());
    }

    #[test]
    fn resnet_time_no_shortcut_params_when_channels_match() {
        let r = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        let named = r.named_parameters();
        assert!(named.iter().all(|(n, _)| !n.starts_with("conv_shortcut.")));
        assert_eq!(r.parameters().len(), named.len());
    }

    #[test]
    fn resnet_time_plain_forward_is_rejected() {
        let r = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        let res = r.forward(&filled(&[1, 16, 4, 4], 0.01));
        assert!(matches!(res, Err(FerrotorchError::InvalidArgument { .. })));
    }

    #[test]
    fn resnet_time_rejects_wrong_input_channels() {
        let r = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        let res = r.forward_t(&filled(&[1, 8, 4, 4], 0.01), &filled(&[1, 32], 0.01));
        assert!(matches!(res, Err(FerrotorchError::ShapeMismatch { .. })));
    }

    #[test]
    fn resnet_time_rejects_wrong_temb_rank() {
        let r = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        let res = r.forward_t(&filled(&[1, 16, 4, 4], 0.01), &filled(&[32], 0.01));
        assert!(matches!(res, Err(FerrotorchError::ShapeMismatch { .. })));
    }

    #[test]
    fn resnet_time_rejects_batch_mismatch() {
        let r = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        let res = r.forward_t(&filled(&[2, 16, 4, 4], 0.01), &filled(&[1, 32], 0.01));
        assert!(res.is_err());
    }
}