1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20
21use crate::activation::ReLU;
22use crate::module::Module;
23use crate::parameter::Parameter;
24use crate::sequential::Sequential;
25
26pub struct ResidualBlock {
52 main_path: Sequential,
53 downsample: Option<Sequential>,
54 activation: Option<Box<dyn Module>>,
55 training: bool,
56}
57
58impl ResidualBlock {
59 pub fn new(main_path: Sequential) -> Self {
64 Self {
65 main_path,
66 downsample: None,
67 activation: Some(Box::new(ReLU)),
68 training: true,
69 }
70 }
71
72 pub fn with_downsample(mut self, downsample: Sequential) -> Self {
76 self.downsample = Some(downsample);
77 self
78 }
79
80 pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
84 self.activation = Some(Box::new(activation));
85 self
86 }
87
88 pub fn without_activation(mut self) -> Self {
90 self.activation = None;
91 self
92 }
93}
94
95impl Module for ResidualBlock {
96 fn forward(&self, input: &Variable) -> Variable {
97 let identity = match &self.downsample {
98 Some(ds) => ds.forward(input),
99 None => input.clone(),
100 };
101
102 let out = self.main_path.forward(input);
103 let out = out.add_var(&identity);
104
105 match &self.activation {
106 Some(act) => act.forward(&out),
107 None => out,
108 }
109 }
110
111 fn parameters(&self) -> Vec<Parameter> {
112 let mut params = self.main_path.parameters();
113 if let Some(ds) = &self.downsample {
114 params.extend(ds.parameters());
115 }
116 if let Some(act) = &self.activation {
117 params.extend(act.parameters());
118 }
119 params
120 }
121
122 fn named_parameters(&self) -> HashMap<String, Parameter> {
123 let mut params = HashMap::new();
124 for (name, param) in self.main_path.named_parameters() {
125 params.insert(format!("main_path.{name}"), param);
126 }
127 if let Some(ds) = &self.downsample {
128 for (name, param) in ds.named_parameters() {
129 params.insert(format!("downsample.{name}"), param);
130 }
131 }
132 if let Some(act) = &self.activation {
133 for (name, param) in act.named_parameters() {
134 params.insert(format!("activation.{name}"), param);
135 }
136 }
137 params
138 }
139
140 fn set_training(&mut self, training: bool) {
141 self.training = training;
142 self.main_path.set_training(training);
143 if let Some(ds) = &mut self.downsample {
144 ds.set_training(training);
145 }
146 if let Some(act) = &mut self.activation {
147 act.set_training(training);
148 }
149 }
150
151 fn is_training(&self) -> bool {
152 self.training
153 }
154
155 fn name(&self) -> &'static str {
156 "ResidualBlock"
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    /// Builds an all-ones input of the given shape, wrapped in a `Variable`
    /// whose boolean flag is `false` (presumably "requires grad": confirm
    /// against `Variable::new`).
    fn ones(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        let tensor = Tensor::from_vec(vec![1.0; numel], shape).expect("tensor creation failed");
        Variable::new(tensor, false)
    }

    /// Default block with a shape-preserving main path keeps the input shape.
    #[test]
    fn test_residual_block_identity_skip() {
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(body);

        let result = block.forward(&ones(&[2, 32]));

        assert_eq!(result.shape(), vec![2, 32]);
    }

    /// A downsample path lets the main path change the feature width.
    #[test]
    fn test_residual_block_with_downsample() {
        let body = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));
        let projection = Sequential::new().add(Linear::new(32, 64));
        let block = ResidualBlock::new(body).with_downsample(projection);

        let result = block.forward(&ones(&[2, 32]));

        assert_eq!(result.shape(), vec![2, 64]);
    }

    /// Swapping the default activation for GELU still produces the input shape.
    #[test]
    fn test_residual_block_custom_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).with_activation(GELU);

        let result = block.forward(&ones(&[2, 16]));

        assert_eq!(result.shape(), vec![2, 16]);
    }

    /// Removing the activation still produces the input shape.
    #[test]
    fn test_residual_block_no_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).without_activation();

        let result = block.forward(&ones(&[2, 16]));

        assert_eq!(result.shape(), vec![2, 16]);
    }

    /// Two `Linear` layers are expected to expose four parameters in total
    /// (presumably weight + bias each, with the default ReLU adding none).
    #[test]
    fn test_residual_block_parameters() {
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(body);

        let collected = block.parameters();

        assert_eq!(collected.len(), 4);
    }

    /// Named parameters carry the `main_path.` / `downsample.` prefixes.
    #[test]
    fn test_residual_block_named_parameters() {
        let body = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let projection = Sequential::new().add_named("proj", Linear::new(32, 32));
        let block = ResidualBlock::new(body).with_downsample(projection);

        let named = block.named_parameters();

        assert!(named.contains_key("main_path.conv1.weight"));
        assert!(named.contains_key("main_path.conv2.weight"));
        assert!(named.contains_key("downsample.proj.weight"));
    }

    /// The training flag starts enabled and follows `set_training`.
    #[test]
    fn test_residual_block_training_mode() {
        let body = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));
        let mut block = ResidualBlock::new(body);

        assert!(block.is_training());

        block.set_training(false);
        assert!(!block.is_training());

        block.set_training(true);
        assert!(block.is_training());
    }

    /// Convolutional main path and downsample path must agree on the output
    /// length: two kernel-3 convolutions on one side, one kernel-5
    /// convolution on the other, both expected to map length 20 to 16.
    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        let body = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));
        let projection = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));
        let block = ResidualBlock::new(body).with_downsample(projection);

        let result = block.forward(&ones(&[2, 64, 20]));

        assert_eq!(result.shape()[0], 2);
        assert_eq!(result.shape()[1], 64);
        assert_eq!(result.shape()[2], 16);
    }

    /// Backward through the block runs without panicking and the block
    /// exposes parameters afterwards.
    #[test]
    fn test_residual_block_gradient_flow() {
        let body = Sequential::new().add(Linear::new(4, 4));
        let block = ResidualBlock::new(body);

        // Unlike the other tests, this input is created with the flag set to
        // `true` so it takes part in the backward pass.
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        let result = block.forward(&input);

        let total = result.sum();
        total.backward();

        let collected = block.parameters();
        assert!(!collected.is_empty());
    }
}