1use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12
13use crate::module::Module;
14use crate::parameter::Parameter;
15use crate::sequential::Sequential;
16use crate::activation::ReLU;
17
/// A residual (skip-connection) block: `output = activation(main_path(x) + skip)`,
/// where `skip` is the raw input, or `downsample(input)` when a downsample
/// path is configured (see `forward` below).
pub struct ResidualBlock {
    // Transformation applied to the input (the "residual" branch).
    main_path: Sequential,
    // Optional projection for the skip connection so its output can be
    // added to the main path's output; `None` means identity skip.
    downsample: Option<Sequential>,
    // Activation applied after the residual addition; defaults to ReLU in
    // `new`, replaceable via `with_activation`, removable via `without_activation`.
    activation: Option<Box<dyn Module>>,
    // Training-mode flag; propagated to all sub-modules by `set_training`.
    training: bool,
}
49
50impl ResidualBlock {
51 pub fn new(main_path: Sequential) -> Self {
56 Self {
57 main_path,
58 downsample: None,
59 activation: Some(Box::new(ReLU)),
60 training: true,
61 }
62 }
63
64 pub fn with_downsample(mut self, downsample: Sequential) -> Self {
68 self.downsample = Some(downsample);
69 self
70 }
71
72 pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
76 self.activation = Some(Box::new(activation));
77 self
78 }
79
80 pub fn without_activation(mut self) -> Self {
82 self.activation = None;
83 self
84 }
85}
86
87impl Module for ResidualBlock {
88 fn forward(&self, input: &Variable) -> Variable {
89 let identity = match &self.downsample {
90 Some(ds) => ds.forward(input),
91 None => input.clone(),
92 };
93
94 let out = self.main_path.forward(input);
95 let out = out.add_var(&identity);
96
97 match &self.activation {
98 Some(act) => act.forward(&out),
99 None => out,
100 }
101 }
102
103 fn parameters(&self) -> Vec<Parameter> {
104 let mut params = self.main_path.parameters();
105 if let Some(ds) = &self.downsample {
106 params.extend(ds.parameters());
107 }
108 if let Some(act) = &self.activation {
109 params.extend(act.parameters());
110 }
111 params
112 }
113
114 fn named_parameters(&self) -> HashMap<String, Parameter> {
115 let mut params = HashMap::new();
116 for (name, param) in self.main_path.named_parameters() {
117 params.insert(format!("main_path.{name}"), param);
118 }
119 if let Some(ds) = &self.downsample {
120 for (name, param) in ds.named_parameters() {
121 params.insert(format!("downsample.{name}"), param);
122 }
123 }
124 if let Some(act) = &self.activation {
125 for (name, param) in act.named_parameters() {
126 params.insert(format!("activation.{name}"), param);
127 }
128 }
129 params
130 }
131
132 fn set_training(&mut self, training: bool) {
133 self.training = training;
134 self.main_path.set_training(training);
135 if let Some(ds) = &mut self.downsample {
136 ds.set_training(training);
137 }
138 if let Some(act) = &mut self.activation {
139 act.set_training(training);
140 }
141 }
142
143 fn is_training(&self) -> bool {
144 self.training
145 }
146
147 fn name(&self) -> &'static str {
148 "ResidualBlock"
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;
    use crate::layers::{Linear, BatchNorm1d, Conv1d};
    use crate::activation::{ReLU, GELU};

    // Same-shape main path: the skip connection is a pure identity.
    #[test]
    fn test_residual_block_identity_skip() {
        let main = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));

        let block = ResidualBlock::new(main);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[2, 32]).unwrap(),
            false,
        );
        let output = block.forward(&input);

        assert_eq!(output.shape(), vec![2, 32]);
    }

    // Main path widens 32 -> 64, so the skip needs a matching projection.
    #[test]
    fn test_residual_block_with_downsample() {
        let main = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));

        let downsample = Sequential::new()
            .add(Linear::new(32, 64));

        let block = ResidualBlock::new(main).with_downsample(downsample);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[2, 32]).unwrap(),
            false,
        );
        let output = block.forward(&input);
        assert_eq!(output.shape(), vec![2, 64]);
    }

    // GELU replaces the default ReLU after the residual addition.
    #[test]
    fn test_residual_block_custom_activation() {
        let main = Sequential::new()
            .add(Linear::new(16, 16));

        let block = ResidualBlock::new(main).with_activation(GELU);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 16]).unwrap(),
            false,
        );
        let output = block.forward(&input);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    // No post-addition activation: output is the raw residual sum.
    #[test]
    fn test_residual_block_no_activation() {
        let main = Sequential::new()
            .add(Linear::new(16, 16));

        let block = ResidualBlock::new(main).without_activation();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 16]).unwrap(),
            false,
        );
        let output = block.forward(&input);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        let main = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));

        let block = ResidualBlock::new(main);
        let params = block.parameters();
        // Two Linear layers, two parameter tensors each (weight + bias).
        assert_eq!(params.len(), 4);
    }

    // Names must be prefixed by sub-module: main_path. / downsample.
    #[test]
    fn test_residual_block_named_parameters() {
        let main = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));

        let downsample = Sequential::new()
            .add_named("proj", Linear::new(32, 32));

        let block = ResidualBlock::new(main).with_downsample(downsample);
        let params = block.named_parameters();

        assert!(params.contains_key("main_path.conv1.weight"));
        assert!(params.contains_key("main_path.conv2.weight"));
        assert!(params.contains_key("downsample.proj.weight"));
    }

    // set_training must toggle the block-level flag both ways.
    #[test]
    fn test_residual_block_training_mode() {
        let main = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));

        let mut block = ResidualBlock::new(main);
        assert!(block.is_training());

        block.set_training(false);
        assert!(!block.is_training());

        block.set_training(true);
        assert!(block.is_training());
    }

    // Length math: main path 20 -> 18 -> 16 (two kernel-3 convs, no padding),
    // downsample 20 -> 16 (one kernel-5 conv), so the add shapes line up.
    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        let main = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));

        let downsample = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));

        let block = ResidualBlock::new(main).with_downsample(downsample);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 64 * 20], &[2, 64, 20]).unwrap(),
            false,
        );
        let output = block.forward(&input);

        assert_eq!(output.shape()[0], 2);
        assert_eq!(output.shape()[1], 64);
        assert_eq!(output.shape()[2], 16);
    }

    // Smoke test: backward through the residual add must not panic.
    #[test]
    fn test_residual_block_gradient_flow() {
        let main = Sequential::new()
            .add(Linear::new(4, 4));

        let block = ResidualBlock::new(main);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );
        let output = block.forward(&input);

        let sum = output.sum();
        sum.backward();

        let params = block.parameters();
        assert!(!params.is_empty());
    }
}