1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
38use ferrotorch_nn::module::{Module, StateDict};
39use ferrotorch_nn::parameter::Parameter;
40use ferrotorch_nn::{Conv2d, GroupNorm, Linear, SiLU};
41
/// Residual 2-D convolution block that is additionally conditioned on a time embedding.
///
/// `forward_t` runs `norm1 → SiLU → conv1`, adds the projected time embedding
/// (`SiLU → time_emb_proj`, reshaped to `[B, out_channels, 1, 1]`), then runs
/// `norm2 → SiLU → conv2` and finally adds the residual branch (the input itself, or
/// `conv_shortcut(input)` when a shortcut convolution is present).
#[derive(Debug)]
pub struct ResnetBlock2DTime<T: Float> {
    /// Group normalization applied to the block input (built over `in_channels`).
    pub norm1: GroupNorm<T>,
    /// First convolution, mapping `in_channels` to `out_channels`.
    pub conv1: Conv2d<T>,
    /// Linear projection of the time embedding from `temb_channels` to `out_channels`.
    pub time_emb_proj: Linear<T>,
    /// Group normalization applied after the time embedding has been added
    /// (built over `out_channels`).
    pub norm2: GroupNorm<T>,
    /// Second convolution, mapping `out_channels` to `out_channels`.
    pub conv2: Conv2d<T>,
    /// Convolution applied to the residual branch; `None` when the constructor was
    /// given `in_channels == out_channels`, in which case the input is added as-is.
    pub conv_shortcut: Option<Conv2d<T>>,
    /// Activation shared by the feature path and the time-embedding path.
    activation: SiLU,
    /// Channel count that `forward_t` requires on dim 1 of its input.
    in_channels: usize,
    /// Channel count of the block output; also used to reshape the projected embedding.
    out_channels: usize,
    /// Flag toggled by `train()` / `eval()` and reported by `is_training()`.
    /// `forward_t` does not read it.
    training: bool,
}
62
63impl<T: Float> ResnetBlock2DTime<T> {
64 pub fn new(
71 in_channels: usize,
72 out_channels: usize,
73 temb_channels: usize,
74 norm_num_groups: usize,
75 eps: f64,
76 ) -> FerrotorchResult<Self> {
77 let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
78 let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
79 let time_emb_proj = Linear::<T>::new(temb_channels, out_channels, true)?;
80 let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
81 let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
82 let conv_shortcut = if in_channels == out_channels {
83 None
84 } else {
85 Some(Conv2d::<T>::new(
86 in_channels,
87 out_channels,
88 (1, 1),
89 (1, 1),
90 (0, 0),
91 true,
92 )?)
93 };
94 Ok(Self {
95 norm1,
96 conv1,
97 time_emb_proj,
98 norm2,
99 conv2,
100 conv_shortcut,
101 activation: SiLU::new(),
102 in_channels,
103 out_channels,
104 training: false,
105 })
106 }
107
108 pub fn forward_t(&self, x: &Tensor<T>, temb: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
117 if x.ndim() != 4 || x.shape()[1] != self.in_channels {
118 return Err(FerrotorchError::ShapeMismatch {
119 message: format!(
120 "ResnetBlock2DTime: expected x [B, {}, H, W], got {:?}",
121 self.in_channels,
122 x.shape()
123 ),
124 });
125 }
126 if temb.ndim() != 2 {
127 return Err(FerrotorchError::ShapeMismatch {
128 message: format!(
129 "ResnetBlock2DTime: expected temb [B, temb_channels], got {:?}",
130 temb.shape()
131 ),
132 });
133 }
134 let b = x.shape()[0];
135 let mut h = self.norm1.forward(x)?;
137 h = self.activation.forward(&h)?;
138 h = self.conv1.forward(&h)?;
139 let temb_silu = self.activation.forward(temb)?;
141 let temb_proj = self.time_emb_proj.forward(&temb_silu)?;
142 let temb_4d = temb_proj.reshape_t(&[b as isize, self.out_channels as isize, 1, 1])?;
143 h = ferrotorch_core::grad_fns::arithmetic::add(&h, &temb_4d)?;
144 h = self.norm2.forward(&h)?;
146 h = self.activation.forward(&h)?;
147 h = self.conv2.forward(&h)?;
148 let res = if let Some(sc) = &self.conv_shortcut {
150 sc.forward(x)?
151 } else {
152 x.clone()
153 };
154 ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
155 }
156}
157
158impl<T: Float> Module<T> for ResnetBlock2DTime<T> {
159 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
160 Err(FerrotorchError::InvalidArgument {
161 message: "ResnetBlock2DTime::forward: time-conditioned block requires \
162 a time embedding — call forward_t instead"
163 .into(),
164 })
165 }
166
167 fn parameters(&self) -> Vec<&Parameter<T>> {
168 let mut o = Vec::new();
169 o.extend(self.norm1.parameters());
170 o.extend(self.conv1.parameters());
171 o.extend(self.time_emb_proj.parameters());
172 o.extend(self.norm2.parameters());
173 o.extend(self.conv2.parameters());
174 if let Some(sc) = &self.conv_shortcut {
175 o.extend(sc.parameters());
176 }
177 o
178 }
179 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
180 let mut o = Vec::new();
181 o.extend(self.norm1.parameters_mut());
182 o.extend(self.conv1.parameters_mut());
183 o.extend(self.time_emb_proj.parameters_mut());
184 o.extend(self.norm2.parameters_mut());
185 o.extend(self.conv2.parameters_mut());
186 if let Some(sc) = self.conv_shortcut.as_mut() {
187 o.extend(sc.parameters_mut());
188 }
189 o
190 }
191 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
192 let mut o = Vec::new();
193 for (n, p) in self.norm1.named_parameters() {
194 o.push((format!("norm1.{n}"), p));
195 }
196 for (n, p) in self.conv1.named_parameters() {
197 o.push((format!("conv1.{n}"), p));
198 }
199 for (n, p) in self.time_emb_proj.named_parameters() {
200 o.push((format!("time_emb_proj.{n}"), p));
201 }
202 for (n, p) in self.norm2.named_parameters() {
203 o.push((format!("norm2.{n}"), p));
204 }
205 for (n, p) in self.conv2.named_parameters() {
206 o.push((format!("conv2.{n}"), p));
207 }
208 if let Some(sc) = &self.conv_shortcut {
209 for (n, p) in sc.named_parameters() {
210 o.push((format!("conv_shortcut.{n}"), p));
211 }
212 }
213 o
214 }
215 fn train(&mut self) {
216 self.training = true;
217 }
218 fn eval(&mut self) {
219 self.training = false;
220 }
221 fn is_training(&self) -> bool {
222 self.training
223 }
224 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
225 let extract = |prefix: &str| -> StateDict<T> {
226 let p = format!("{prefix}.");
227 state
228 .iter()
229 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
230 .collect()
231 };
232 if strict {
233 for k in state.keys() {
234 let ok = k.starts_with("norm1.")
235 || k.starts_with("conv1.")
236 || k.starts_with("time_emb_proj.")
237 || k.starts_with("norm2.")
238 || k.starts_with("conv2.")
239 || k.starts_with("conv_shortcut.");
240 if !ok {
241 return Err(FerrotorchError::InvalidArgument {
242 message: format!("unexpected key in ResnetBlock2DTime state_dict: \"{k}\""),
243 });
244 }
245 }
246 }
247 self.norm1.load_state_dict(&extract("norm1"), strict)?;
248 self.conv1.load_state_dict(&extract("conv1"), strict)?;
249 self.time_emb_proj
250 .load_state_dict(&extract("time_emb_proj"), strict)?;
251 self.norm2.load_state_dict(&extract("norm2"), strict)?;
252 self.conv2.load_state_dict(&extract("conv2"), strict)?;
253 if let Some(sc) = self.conv_shortcut.as_mut() {
254 sc.load_state_dict(&extract("conv_shortcut"), strict)?;
255 }
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds a CPU `f32` tensor of the given shape with every element set to `0.01`.
    fn constant_tensor(shape: Vec<usize>) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![0.01f32; numel]), shape, false).unwrap()
    }

    /// With equal in/out channels there is no shortcut conv and the shape is preserved.
    #[test]
    fn resnet_time_shape_same_channels() {
        let block = ResnetBlock2DTime::<f32>::new(16, 16, 32, 4, 1e-5).unwrap();
        assert!(block.conv_shortcut.is_none());

        let input = constant_tensor(vec![1, 16, 4, 4]);
        let time_emb = constant_tensor(vec![1, 32]);

        let output = block.forward_t(&input, &time_emb).unwrap();
        assert_eq!(output.shape(), &[1, 16, 4, 4]);
    }

    /// With differing in/out channels a shortcut conv exists and the channel dim changes.
    #[test]
    fn resnet_time_shape_change_channels() {
        let block = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        assert!(block.conv_shortcut.is_some());

        let input = constant_tensor(vec![1, 16, 4, 4]);
        let time_emb = constant_tensor(vec![1, 32]);

        let output = block.forward_t(&input, &time_emb).unwrap();
        assert_eq!(output.shape(), &[1, 32, 4, 4]);
    }

    /// Every sub-layer contributes parameters under its field-name prefix.
    #[test]
    fn resnet_time_named_parameters() {
        let block = ResnetBlock2DTime::<f32>::new(16, 32, 32, 4, 1e-5).unwrap();
        let names: Vec<String> = block
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();

        let expected = [
            "norm1.weight",
            "conv1.weight",
            "time_emb_proj.weight",
            "time_emb_proj.bias",
            "norm2.weight",
            "conv2.weight",
            "conv_shortcut.weight",
        ];
        for k in expected {
            let found = names.iter().any(|n| n.as_str() == k);
            assert!(found, "missing {k} in {names:?}");
        }
    }
}