1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
29use ferrotorch_nn::module::{Module, StateDict};
30use ferrotorch_nn::parameter::Parameter;
31use ferrotorch_nn::{Conv2d, GroupNorm, Linear, SiLU};
32
33#[derive(Debug)]
35pub struct ResnetBlock2DTime<T: Float> {
36 pub norm1: GroupNorm<T>,
38 pub conv1: Conv2d<T>,
40 pub time_emb_proj: Linear<T>,
42 pub norm2: GroupNorm<T>,
44 pub conv2: Conv2d<T>,
46 pub conv_shortcut: Option<Conv2d<T>>,
48 activation: SiLU,
49 in_channels: usize,
50 out_channels: usize,
51 training: bool,
52}
53
54impl<T: Float> ResnetBlock2DTime<T> {
55 pub fn new(
62 in_channels: usize,
63 out_channels: usize,
64 temb_channels: usize,
65 norm_num_groups: usize,
66 eps: f64,
67 ) -> FerrotorchResult<Self> {
68 let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
69 let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
70 let time_emb_proj = Linear::<T>::new(temb_channels, out_channels, true)?;
71 let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
72 let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
73 let conv_shortcut = if in_channels == out_channels {
74 None
75 } else {
76 Some(Conv2d::<T>::new(
77 in_channels,
78 out_channels,
79 (1, 1),
80 (1, 1),
81 (0, 0),
82 true,
83 )?)
84 };
85 Ok(Self {
86 norm1,
87 conv1,
88 time_emb_proj,
89 norm2,
90 conv2,
91 conv_shortcut,
92 activation: SiLU::new(),
93 in_channels,
94 out_channels,
95 training: false,
96 })
97 }
98
99 pub fn forward_t(&self, x: &Tensor<T>, temb: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
108 if x.ndim() != 4 || x.shape()[1] != self.in_channels {
109 return Err(FerrotorchError::ShapeMismatch {
110 message: format!(
111 "ResnetBlock2DTime: expected x [B, {}, H, W], got {:?}",
112 self.in_channels,
113 x.shape()
114 ),
115 });
116 }
117 if temb.ndim() != 2 {
118 return Err(FerrotorchError::ShapeMismatch {
119 message: format!(
120 "ResnetBlock2DTime: expected temb [B, temb_channels], got {:?}",
121 temb.shape()
122 ),
123 });
124 }
125 let b = x.shape()[0];
126 let mut h = self.norm1.forward(x)?;
128 h = self.activation.forward(&h)?;
129 h = self.conv1.forward(&h)?;
130 let temb_silu = self.activation.forward(temb)?;
132 let temb_proj = self.time_emb_proj.forward(&temb_silu)?;
133 let temb_4d = temb_proj.reshape_t(&[
134 b as isize,
135 self.out_channels as isize,
136 1,
137 1,
138 ])?;
139 h = ferrotorch_core::grad_fns::arithmetic::add(&h, &temb_4d)?;
140 h = self.norm2.forward(&h)?;
142 h = self.activation.forward(&h)?;
143 h = self.conv2.forward(&h)?;
144 let res = if let Some(sc) = &self.conv_shortcut {
146 sc.forward(x)?
147 } else {
148 x.clone()
149 };
150 ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
151 }
152}
153
154impl<T: Float> Module<T> for ResnetBlock2DTime<T> {
155 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
156 Err(FerrotorchError::InvalidArgument {
157 message: "ResnetBlock2DTime::forward: time-conditioned block requires \
158 a time embedding — call forward_t instead"
159 .into(),
160 })
161 }
162
163 fn parameters(&self) -> Vec<&Parameter<T>> {
164 let mut o = Vec::new();
165 o.extend(self.norm1.parameters());
166 o.extend(self.conv1.parameters());
167 o.extend(self.time_emb_proj.parameters());
168 o.extend(self.norm2.parameters());
169 o.extend(self.conv2.parameters());
170 if let Some(sc) = &self.conv_shortcut {
171 o.extend(sc.parameters());
172 }
173 o
174 }
175 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
176 let mut o = Vec::new();
177 o.extend(self.norm1.parameters_mut());
178 o.extend(self.conv1.parameters_mut());
179 o.extend(self.time_emb_proj.parameters_mut());
180 o.extend(self.norm2.parameters_mut());
181 o.extend(self.conv2.parameters_mut());
182 if let Some(sc) = self.conv_shortcut.as_mut() {
183 o.extend(sc.parameters_mut());
184 }
185 o
186 }
187 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
188 let mut o = Vec::new();
189 for (n, p) in self.norm1.named_parameters() {
190 o.push((format!("norm1.{n}"), p));
191 }
192 for (n, p) in self.conv1.named_parameters() {
193 o.push((format!("conv1.{n}"), p));
194 }
195 for (n, p) in self.time_emb_proj.named_parameters() {
196 o.push((format!("time_emb_proj.{n}"), p));
197 }
198 for (n, p) in self.norm2.named_parameters() {
199 o.push((format!("norm2.{n}"), p));
200 }
201 for (n, p) in self.conv2.named_parameters() {
202 o.push((format!("conv2.{n}"), p));
203 }
204 if let Some(sc) = &self.conv_shortcut {
205 for (n, p) in sc.named_parameters() {
206 o.push((format!("conv_shortcut.{n}"), p));
207 }
208 }
209 o
210 }
211 fn train(&mut self) {
212 self.training = true;
213 }
214 fn eval(&mut self) {
215 self.training = false;
216 }
217 fn is_training(&self) -> bool {
218 self.training
219 }
220 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
221 let extract = |prefix: &str| -> StateDict<T> {
222 let p = format!("{prefix}.");
223 state
224 .iter()
225 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
226 .collect()
227 };
228 if strict {
229 for k in state.keys() {
230 let ok = k.starts_with("norm1.")
231 || k.starts_with("conv1.")
232 || k.starts_with("time_emb_proj.")
233 || k.starts_with("norm2.")
234 || k.starts_with("conv2.")
235 || k.starts_with("conv_shortcut.");
236 if !ok {
237 return Err(FerrotorchError::InvalidArgument {
238 message: format!(
239 "unexpected key in ResnetBlock2DTime state_dict: \"{k}\""
240 ),
241 });
242 }
243 }
244 }
245 self.norm1.load_state_dict(&extract("norm1"), strict)?;
246 self.conv1.load_state_dict(&extract("conv1"), strict)?;
247 self.time_emb_proj
248 .load_state_dict(&extract("time_emb_proj"), strict)?;
249 self.norm2.load_state_dict(&extract("norm2"), strict)?;
250 self.conv2.load_state_dict(&extract("conv2"), strict)?;
251 if let Some(sc) = self.conv_shortcut.as_mut() {
252 sc.load_state_dict(&extract("conv_shortcut"), strict)?;
253 }
254 Ok(())
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds a CPU `f32` tensor of the given shape with every element set to
    /// `0.01`, passing `false` as the final `from_storage` flag.
    fn filled(shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; numel]),
            shape.to_vec(),
            false,
        )
        .unwrap()
    }

    /// Equal channel counts: no shortcut layer and the shape is preserved.
    #[test]
    fn resnet_time_shape_same_channels() {
        let block = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        assert!(block.conv_shortcut.is_none());

        let input = filled(&[1, 16, 4, 4]);
        let time_emb = filled(&[1, 32]);
        let out = block.forward_t(&input, &time_emb).unwrap();
        assert_eq!(out.shape(), &[1, 16, 4, 4]);
    }

    /// Differing channel counts: a shortcut layer exists and the output has
    /// the new channel count.
    #[test]
    fn resnet_time_shape_change_channels() {
        let block = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        assert!(block.conv_shortcut.is_some());

        let input = filled(&[1, 16, 4, 4]);
        let time_emb = filled(&[1, 32]);
        let out = block.forward_t(&input, &time_emb).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
    }

    /// Every sub-layer contributes parameters under its field-name prefix.
    #[test]
    fn resnet_time_named_parameters() {
        let block = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        let expected = [
            "norm1.weight",
            "conv1.weight",
            "time_emb_proj.weight",
            "time_emb_proj.bias",
            "norm2.weight",
            "conv2.weight",
            "conv_shortcut.weight",
        ];
        let names: Vec<String> = block
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}