Skip to main content

ferrotorch_diffusion/
resnet_block_time.rs

1//! `ResnetBlock2DTime` — the time-conditioned variant of the
2//! `ResnetBlock2D` used by the SD UNet.
3//!
4//! Identical to `diffusers.models.resnet.ResnetBlock2D` configured with
5//! `temb_channels = 1280` and the default `time_embedding_norm = "default"`
6//! (the SD-1.5 setting):
7//!
8//! ```text
9//! h = silu(norm1(x)); h = conv1(h)
10//! t = silu(temb); t = time_emb_proj(t).view(B, out, 1, 1)
11//! h = h + t
12//! h = silu(norm2(h)); h = conv2(h)
13//! r = x if in==out else conv_shortcut(x)
14//! out = h + r       (output_scale_factor = 1.0)
15//! ```
16//!
17//! State-dict layout (matches diffusers):
18//!
19//! ```text
20//! norm1.{weight,bias}
21//! conv1.{weight,bias}
22//! time_emb_proj.{weight,bias}    [out_channels, temb_channels]
23//! norm2.{weight,bias}
24//! conv2.{weight,bias}
25//! conv_shortcut.{weight,bias}    (iff in_channels != out_channels)
26//! ```
27//!
28//! ## REQ status (per `.design/ferrotorch-diffusion/resnet_block_time.md`)
29//!
30//! | REQ | Status | Evidence |
31//! |---|---|---|
//! | REQ-1 | SHIPPED | `ResnetBlock2DTime::new` at `ferrotorch-diffusion/src/resnet_block_time.rs:70..106`; consumer: `ferrotorch-diffusion/src/unet.rs:109`, `unet.rs:302`, `unet.rs:468`, `unet.rs:478`, `unet.rs:657`, `unet.rs:865` all build resnets |
//! | REQ-2 | SHIPPED | `forward_t` at `ferrotorch-diffusion/src/resnet_block_time.rs:116..155`; consumer: every UNet block in `ferrotorch-diffusion/src/unet.rs` calls `resnet.forward_t(&h, temb)?` |
//! | REQ-3 | SHIPPED | `named_parameters` at `ferrotorch-diffusion/src/resnet_block_time.rs:191..214` and `load_state_dict` at `resnet_block_time.rs:224..257`; consumer: `ferrotorch-diffusion/src/safetensors_loader.rs:151..175` routes HF UNet keys through this layout |
//! | REQ-4 | SHIPPED | `Module::forward` error guard at `ferrotorch-diffusion/src/resnet_block_time.rs:159..165`; consumer: every UNet caller uses `forward_t` explicitly so the guard surfaces clearly |
36
37use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
38use ferrotorch_nn::module::{Module, StateDict};
39use ferrotorch_nn::parameter::Parameter;
40use ferrotorch_nn::{Conv2d, GroupNorm, Linear, SiLU};
41
42/// Time-conditioned residual block.
43#[derive(Debug)]
44pub struct ResnetBlock2DTime<T: Float> {
45    /// First GroupNorm.
46    pub norm1: GroupNorm<T>,
47    /// First Conv2d.
48    pub conv1: Conv2d<T>,
49    /// Linear over the time embedding (`temb_channels -> out_channels`).
50    pub time_emb_proj: Linear<T>,
51    /// Second GroupNorm.
52    pub norm2: GroupNorm<T>,
53    /// Second Conv2d.
54    pub conv2: Conv2d<T>,
55    /// Optional 1x1 shortcut.
56    pub conv_shortcut: Option<Conv2d<T>>,
57    activation: SiLU,
58    in_channels: usize,
59    out_channels: usize,
60    training: bool,
61}
62
63impl<T: Float> ResnetBlock2DTime<T> {
64    /// Build a randomly-initialized time-conditioned resnet block.
65    ///
66    /// # Errors
67    ///
68    /// Returns the underlying [`FerrotorchError`] on bad channel/group
69    /// config.
70    pub fn new(
71        in_channels: usize,
72        out_channels: usize,
73        temb_channels: usize,
74        norm_num_groups: usize,
75        eps: f64,
76    ) -> FerrotorchResult<Self> {
77        let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
78        let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
79        let time_emb_proj = Linear::<T>::new(temb_channels, out_channels, true)?;
80        let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
81        let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
82        let conv_shortcut = if in_channels == out_channels {
83            None
84        } else {
85            Some(Conv2d::<T>::new(
86                in_channels,
87                out_channels,
88                (1, 1),
89                (1, 1),
90                (0, 0),
91                true,
92            )?)
93        };
94        Ok(Self {
95            norm1,
96            conv1,
97            time_emb_proj,
98            norm2,
99            conv2,
100            conv_shortcut,
101            activation: SiLU::new(),
102            in_channels,
103            out_channels,
104            training: false,
105        })
106    }
107
108    /// Forward with the time embedding `temb` (shape `[B, temb_channels]`).
109    ///
110    /// `x` has shape `[B, in_channels, H, W]`; output is
111    /// `[B, out_channels, H, W]`.
112    ///
113    /// # Errors
114    ///
115    /// Returns [`FerrotorchError::ShapeMismatch`] for bad input ranks.
116    pub fn forward_t(&self, x: &Tensor<T>, temb: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
117        if x.ndim() != 4 || x.shape()[1] != self.in_channels {
118            return Err(FerrotorchError::ShapeMismatch {
119                message: format!(
120                    "ResnetBlock2DTime: expected x [B, {}, H, W], got {:?}",
121                    self.in_channels,
122                    x.shape()
123                ),
124            });
125        }
126        if temb.ndim() != 2 {
127            return Err(FerrotorchError::ShapeMismatch {
128                message: format!(
129                    "ResnetBlock2DTime: expected temb [B, temb_channels], got {:?}",
130                    temb.shape()
131                ),
132            });
133        }
134        let b = x.shape()[0];
135        // h = silu(norm1(x)); h = conv1(h)
136        let mut h = self.norm1.forward(x)?;
137        h = self.activation.forward(&h)?;
138        h = self.conv1.forward(&h)?;
139        // Time bias: silu(temb) -> Linear -> [B, out_channels, 1, 1]
140        let temb_silu = self.activation.forward(temb)?;
141        let temb_proj = self.time_emb_proj.forward(&temb_silu)?;
142        let temb_4d = temb_proj.reshape_t(&[b as isize, self.out_channels as isize, 1, 1])?;
143        h = ferrotorch_core::grad_fns::arithmetic::add(&h, &temb_4d)?;
144        // h = silu(norm2(h)); h = conv2(h)
145        h = self.norm2.forward(&h)?;
146        h = self.activation.forward(&h)?;
147        h = self.conv2.forward(&h)?;
148        // Residual.
149        let res = if let Some(sc) = &self.conv_shortcut {
150            sc.forward(x)?
151        } else {
152            x.clone()
153        };
154        ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
155    }
156}
157
158impl<T: Float> Module<T> for ResnetBlock2DTime<T> {
159    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
160        Err(FerrotorchError::InvalidArgument {
161            message: "ResnetBlock2DTime::forward: time-conditioned block requires \
162                      a time embedding — call forward_t instead"
163                .into(),
164        })
165    }
166
167    fn parameters(&self) -> Vec<&Parameter<T>> {
168        let mut o = Vec::new();
169        o.extend(self.norm1.parameters());
170        o.extend(self.conv1.parameters());
171        o.extend(self.time_emb_proj.parameters());
172        o.extend(self.norm2.parameters());
173        o.extend(self.conv2.parameters());
174        if let Some(sc) = &self.conv_shortcut {
175            o.extend(sc.parameters());
176        }
177        o
178    }
179    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
180        let mut o = Vec::new();
181        o.extend(self.norm1.parameters_mut());
182        o.extend(self.conv1.parameters_mut());
183        o.extend(self.time_emb_proj.parameters_mut());
184        o.extend(self.norm2.parameters_mut());
185        o.extend(self.conv2.parameters_mut());
186        if let Some(sc) = self.conv_shortcut.as_mut() {
187            o.extend(sc.parameters_mut());
188        }
189        o
190    }
191    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
192        let mut o = Vec::new();
193        for (n, p) in self.norm1.named_parameters() {
194            o.push((format!("norm1.{n}"), p));
195        }
196        for (n, p) in self.conv1.named_parameters() {
197            o.push((format!("conv1.{n}"), p));
198        }
199        for (n, p) in self.time_emb_proj.named_parameters() {
200            o.push((format!("time_emb_proj.{n}"), p));
201        }
202        for (n, p) in self.norm2.named_parameters() {
203            o.push((format!("norm2.{n}"), p));
204        }
205        for (n, p) in self.conv2.named_parameters() {
206            o.push((format!("conv2.{n}"), p));
207        }
208        if let Some(sc) = &self.conv_shortcut {
209            for (n, p) in sc.named_parameters() {
210                o.push((format!("conv_shortcut.{n}"), p));
211            }
212        }
213        o
214    }
215    fn train(&mut self) {
216        self.training = true;
217    }
218    fn eval(&mut self) {
219        self.training = false;
220    }
221    fn is_training(&self) -> bool {
222        self.training
223    }
224    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
225        let extract = |prefix: &str| -> StateDict<T> {
226            let p = format!("{prefix}.");
227            state
228                .iter()
229                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
230                .collect()
231        };
232        if strict {
233            for k in state.keys() {
234                let ok = k.starts_with("norm1.")
235                    || k.starts_with("conv1.")
236                    || k.starts_with("time_emb_proj.")
237                    || k.starts_with("norm2.")
238                    || k.starts_with("conv2.")
239                    || k.starts_with("conv_shortcut.");
240                if !ok {
241                    return Err(FerrotorchError::InvalidArgument {
242                        message: format!("unexpected key in ResnetBlock2DTime state_dict: \"{k}\""),
243                    });
244                }
245            }
246        }
247        self.norm1.load_state_dict(&extract("norm1"), strict)?;
248        self.conv1.load_state_dict(&extract("conv1"), strict)?;
249        self.time_emb_proj
250            .load_state_dict(&extract("time_emb_proj"), strict)?;
251        self.norm2.load_state_dict(&extract("norm2"), strict)?;
252        self.conv2.load_state_dict(&extract("conv2"), strict)?;
253        if let Some(sc) = self.conv_shortcut.as_mut() {
254            sc.load_state_dict(&extract("conv_shortcut"), strict)?;
255        }
256        Ok(())
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// CPU tensor of the given shape, every element `0.01`, no grad tracking.
    fn constant(shape: &[usize]) -> Tensor<f32> {
        let numel = shape.iter().product::<usize>();
        Tensor::from_storage(TensorStorage::cpu(vec![0.01f32; numel]), shape.to_vec(), false)
            .unwrap()
    }

    #[test]
    fn resnet_time_shape_same_channels() {
        let block = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        // Equal channel counts: identity residual, no shortcut conv.
        assert!(block.conv_shortcut.is_none());
        let input = constant(&[1, 16, 4, 4]);
        let time = constant(&[1, 32]);
        let out = block.forward_t(&input, &time).unwrap();
        assert_eq!(out.shape(), &[1, 16, 4, 4]);
    }

    #[test]
    fn resnet_time_shape_change_channels() {
        let block = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        // Channel change: the 1x1 shortcut must exist.
        assert!(block.conv_shortcut.is_some());
        let input = constant(&[1, 16, 4, 4]);
        let time = constant(&[1, 32]);
        let out = block.forward_t(&input, &time).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
    }

    #[test]
    fn resnet_time_named_parameters() {
        let block = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        let names: Vec<String> = block
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let expected = [
            "norm1.weight",
            "conv1.weight",
            "time_emb_proj.weight",
            "time_emb_proj.bias",
            "norm2.weight",
            "conv2.weight",
            "conv_shortcut.weight",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}