// axonml_nn/layers/residual.rs
use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20
21use crate::activation::ReLU;
22use crate::module::Module;
23use crate::parameter::Parameter;
24use crate::sequential::Sequential;
25
/// A residual ("skip connection") block: the input is run through
/// `main_path`, added to a skip branch, and optionally passed through a
/// final activation.
///
/// The skip branch is the identity by default; `downsample`, when set,
/// replaces it (used when the main path changes the output shape — see
/// `forward`).
pub struct ResidualBlock {
    // Sub-network applied to the input on the residual branch.
    main_path: Sequential,
    // Optional projection for the skip branch; `None` means identity skip.
    downsample: Option<Sequential>,
    // Activation applied after the addition; `None` returns the raw sum.
    activation: Option<Box<dyn Module>>,
    // Training-mode flag, propagated to sub-modules by `set_training`.
    training: bool,
}
57
58impl ResidualBlock {
59 pub fn new(main_path: Sequential) -> Self {
64 Self {
65 main_path,
66 downsample: None,
67 activation: Some(Box::new(ReLU)),
68 training: true,
69 }
70 }
71
72 pub fn with_downsample(mut self, downsample: Sequential) -> Self {
76 self.downsample = Some(downsample);
77 self
78 }
79
80 pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
84 self.activation = Some(Box::new(activation));
85 self
86 }
87
88 pub fn without_activation(mut self) -> Self {
90 self.activation = None;
91 self
92 }
93}
94
95impl Module for ResidualBlock {
96 fn forward(&self, input: &Variable) -> Variable {
97 let identity = match &self.downsample {
98 Some(ds) => ds.forward(input),
99 None => input.clone(),
100 };
101
102 let out = self.main_path.forward(input);
103 let out = out.add_var(&identity);
104
105 match &self.activation {
106 Some(act) => act.forward(&out),
107 None => out,
108 }
109 }
110
111 fn parameters(&self) -> Vec<Parameter> {
112 let mut params = self.main_path.parameters();
113 if let Some(ds) = &self.downsample {
114 params.extend(ds.parameters());
115 }
116 if let Some(act) = &self.activation {
117 params.extend(act.parameters());
118 }
119 params
120 }
121
122 fn named_parameters(&self) -> HashMap<String, Parameter> {
123 let mut params = HashMap::new();
124 for (name, param) in self.main_path.named_parameters() {
125 params.insert(format!("main_path.{name}"), param);
126 }
127 if let Some(ds) = &self.downsample {
128 for (name, param) in ds.named_parameters() {
129 params.insert(format!("downsample.{name}"), param);
130 }
131 }
132 if let Some(act) = &self.activation {
133 for (name, param) in act.named_parameters() {
134 params.insert(format!("activation.{name}"), param);
135 }
136 }
137 params
138 }
139
140 fn set_training(&mut self, training: bool) {
141 self.training = training;
142 self.main_path.set_training(training);
143 if let Some(ds) = &mut self.downsample {
144 ds.set_training(training);
145 }
146 if let Some(act) = &mut self.activation {
147 act.set_training(training);
148 }
149 }
150
151 fn is_training(&self) -> bool {
152 self.training
153 }
154
155 fn name(&self) -> &'static str {
156 "ResidualBlock"
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    #[test]
    fn test_residual_block_identity_skip() {
        // Width-preserving main path: the identity skip applies directly.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(body);

        let x = Variable::new(Tensor::from_vec(vec![1.0; 64], &[2, 32]).unwrap(), false);
        let y = block.forward(&x);

        assert_eq!(y.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // Main path widens 32 -> 64, so the skip branch needs a projection.
        let body = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));
        let proj = Sequential::new().add(Linear::new(32, 64));
        let block = ResidualBlock::new(body).with_downsample(proj);

        let x = Variable::new(Tensor::from_vec(vec![1.0; 64], &[2, 32]).unwrap(), false);
        let y = block.forward(&x);

        assert_eq!(y.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).with_activation(GELU);

        let x = Variable::new(Tensor::from_vec(vec![1.0; 32], &[2, 16]).unwrap(), false);
        let y = block.forward(&x);

        assert_eq!(y.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).without_activation();

        let x = Variable::new(Tensor::from_vec(vec![1.0; 32], &[2, 16]).unwrap(), false);
        let y = block.forward(&x);

        assert_eq!(y.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        // Two Linear layers -> two weights + two biases = four parameters.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(body);

        let params = block.parameters();
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let body = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let proj = Sequential::new().add_named("proj", Linear::new(32, 32));
        let block = ResidualBlock::new(body).with_downsample(proj);

        let named = block.named_parameters();

        assert!(named.contains_key("main_path.conv1.weight"));
        assert!(named.contains_key("main_path.conv2.weight"));
        assert!(named.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let body = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));
        let mut block = ResidualBlock::new(body);

        // Blocks start in training mode and the flag round-trips.
        assert!(block.is_training());
        block.set_training(false);
        assert!(!block.is_training());
        block.set_training(true);
        assert!(block.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        let body = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));
        let proj = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));
        let block = ResidualBlock::new(body).with_downsample(proj);

        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 64 * 20], &[2, 64, 20]).unwrap(),
            false,
        );
        let y = block.forward(&x);

        assert_eq!(y.shape()[0], 2);
        assert_eq!(y.shape()[1], 64);
        assert_eq!(y.shape()[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let body = Sequential::new().add(Linear::new(4, 4));
        let block = ResidualBlock::new(body);

        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );
        let y = block.forward(&x);

        // Backward through the block should complete without panicking.
        let loss = y.sum();
        loss.backward();

        let params = block.parameters();
        assert!(!params.is_empty());
    }
}