1use std::collections::HashMap;
24
25use axonml_autograd::Variable;
26
27use crate::activation::ReLU;
28use crate::module::Module;
29use crate::parameter::Parameter;
30use crate::sequential::Sequential;
31
/// A residual block computing `activation(main_path(x) + shortcut(x))`.
///
/// The shortcut branch is the input itself unless a projection has been
/// attached with [`ResidualBlock::with_downsample`], in which case it is
/// `downsample(x)`. The two branch outputs are combined with `add_var`, so they
/// must be compatible for that operation (presumably identical shapes — TODO
/// confirm against `Variable::add_var`). The trailing activation defaults to
/// [`ReLU`] and can be replaced or removed with the builder methods.
pub struct ResidualBlock {
    /// Layers applied to the input on the main branch.
    main_path: Sequential,
    /// Optional projection applied to the input on the skip branch; `None` means identity.
    downsample: Option<Sequential>,
    /// Optional activation applied to the branch sum; `None` returns the raw sum.
    activation: Option<Box<dyn Module>>,
    /// Training-mode flag reported by `is_training` and updated by `set_training`.
    training: bool,
}
63
64impl ResidualBlock {
65 pub fn new(main_path: Sequential) -> Self {
70 Self {
71 main_path,
72 downsample: None,
73 activation: Some(Box::new(ReLU)),
74 training: true,
75 }
76 }
77
78 pub fn with_downsample(mut self, downsample: Sequential) -> Self {
82 self.downsample = Some(downsample);
83 self
84 }
85
86 pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
90 self.activation = Some(Box::new(activation));
91 self
92 }
93
94 pub fn without_activation(mut self) -> Self {
96 self.activation = None;
97 self
98 }
99}
100
101impl Module for ResidualBlock {
102 fn forward(&self, input: &Variable) -> Variable {
103 let identity = match &self.downsample {
104 Some(ds) => ds.forward(input),
105 None => input.clone(),
106 };
107
108 let out = self.main_path.forward(input);
109 let out = out.add_var(&identity);
110
111 match &self.activation {
112 Some(act) => act.forward(&out),
113 None => out,
114 }
115 }
116
117 fn parameters(&self) -> Vec<Parameter> {
118 let mut params = self.main_path.parameters();
119 if let Some(ds) = &self.downsample {
120 params.extend(ds.parameters());
121 }
122 if let Some(act) = &self.activation {
123 params.extend(act.parameters());
124 }
125 params
126 }
127
128 fn named_parameters(&self) -> HashMap<String, Parameter> {
129 let mut params = HashMap::new();
130 for (name, param) in self.main_path.named_parameters() {
131 params.insert(format!("main_path.{name}"), param);
132 }
133 if let Some(ds) = &self.downsample {
134 for (name, param) in ds.named_parameters() {
135 params.insert(format!("downsample.{name}"), param);
136 }
137 }
138 if let Some(act) = &self.activation {
139 for (name, param) in act.named_parameters() {
140 params.insert(format!("activation.{name}"), param);
141 }
142 }
143 params
144 }
145
146 fn set_training(&mut self, training: bool) {
147 self.training = training;
148 self.main_path.set_training(training);
149 if let Some(ds) = &mut self.downsample {
150 ds.set_training(training);
151 }
152 if let Some(act) = &mut self.activation {
153 act.set_training(training);
154 }
155 }
156
157 fn is_training(&self) -> bool {
158 self.training
159 }
160
161 fn name(&self) -> &'static str {
162 "ResidualBlock"
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    #[test]
    fn test_residual_block_identity_skip() {
        // The width stays at 32 end to end, so the identity shortcut lines up.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));
        let residual = ResidualBlock::new(body);

        let data = Tensor::from_vec(vec![1.0; 64], &[2, 32]).expect("tensor creation failed");
        let x = Variable::new(data, false);
        let y = residual.forward(&x);

        assert_eq!(y.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // The main path widens 32 -> 64, so the skip branch needs a matching projection.
        let body = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));
        let projection = Sequential::new().add(Linear::new(32, 64));
        let residual = ResidualBlock::new(body).with_downsample(projection);

        let data = Tensor::from_vec(vec![1.0; 64], &[2, 32]).expect("tensor creation failed");
        let x = Variable::new(data, false);
        let y = residual.forward(&x);

        assert_eq!(y.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let residual = ResidualBlock::new(body).with_activation(GELU);

        let data = Tensor::from_vec(vec![1.0; 32], &[2, 16]).expect("tensor creation failed");
        let x = Variable::new(data, false);
        let y = residual.forward(&x);

        assert_eq!(y.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let residual = ResidualBlock::new(body).without_activation();

        let data = Tensor::from_vec(vec![1.0; 32], &[2, 16]).expect("tensor creation failed");
        let x = Variable::new(data, false);
        let y = residual.forward(&x);

        assert_eq!(y.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        // Two Linear layers, presumably one weight and one bias each; ReLU adds none.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));
        let residual = ResidualBlock::new(body);

        let all = residual.parameters();
        assert_eq!(all.len(), 4);
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let body = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let projection = Sequential::new().add_named("proj", Linear::new(32, 32));
        let residual = ResidualBlock::new(body).with_downsample(projection);

        let by_name = residual.named_parameters();

        // Keys are prefixed with the branch that owns them.
        assert!(by_name.contains_key("main_path.conv1.weight"));
        assert!(by_name.contains_key("main_path.conv2.weight"));
        assert!(by_name.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let body = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));
        let mut residual = ResidualBlock::new(body);

        // Freshly built blocks start in training mode.
        assert!(residual.is_training());

        residual.set_training(false);
        assert!(!residual.is_training());

        residual.set_training(true);
        assert!(residual.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        // Two kernel-3 convs shrink length 20 -> 18 -> 16 (assuming no padding);
        // one kernel-5 conv on the skip branch also lands on 16.
        let body = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));
        let projection = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));
        let residual = ResidualBlock::new(body).with_downsample(projection);

        let data = Tensor::from_vec(vec![1.0; 2 * 64 * 20], &[2, 64, 20])
            .expect("tensor creation failed");
        let x = Variable::new(data, false);
        let y = residual.forward(&x);

        let dims = y.shape();
        assert_eq!(dims[0], 2);
        assert_eq!(dims[1], 64);
        assert_eq!(dims[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let body = Sequential::new().add(Linear::new(4, 4));
        let residual = ResidualBlock::new(body);

        let data =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed");
        let x = Variable::new(data, true);
        let y = residual.forward(&x);

        // Reduce to a scalar so backward() has a single root.
        let total = y.sum();
        total.backward();

        let trainable = residual.parameters();
        assert!(!trainable.is_empty());
    }
}