//! Residual Block - Generic Skip Connection Layer
//!
//! # File
//! `crates/axonml-nn/src/layers/residual.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::activation::ReLU;
use crate::module::Module;
use crate::parameter::Parameter;
use crate::sequential::Sequential;

// =============================================================================
// ResidualBlock
// =============================================================================

/// A generic residual block that wraps any module sequence with a skip connection.
///
/// Computes: `activation(main_path(x) + downsample(x))` where downsample is
/// optional (defaults to identity). This enables gradient flow through the
/// skip connection, improving training of deep networks.
///
/// # Example
/// ```ignore
/// use axonml_nn::prelude::*;
/// use axonml_nn::layers::ResidualBlock;
///
/// // Conv1d residual block. The k=3 convolutions shrink the time axis
/// // (20 -> 16 for a length-20 input), so the skip connection needs a
/// // downsample to match shapes.
/// let main = Sequential::new()
///     .add(Conv1d::new(64, 64, 3))
///     .add(BatchNorm1d::new(64))
///     .add(ReLU)
///     .add(Conv1d::new(64, 64, 3))
///     .add(BatchNorm1d::new(64));
///
/// // A single k=5 convolution shrinks the time axis by the same amount.
/// let downsample = Sequential::new()
///     .add(Conv1d::new(64, 64, 5))
///     .add(BatchNorm1d::new(64));
///
/// let block = ResidualBlock::new(main).with_downsample(downsample);
/// ```
pub struct ResidualBlock {
    main_path: Sequential,
    downsample: Option<Sequential>,
    activation: Option<Box<dyn Module>>,
    training: bool,
}

impl ResidualBlock {
    /// Creates a new residual block with the given main path and ReLU activation.
    ///
    /// The skip connection is identity (no downsample). Use `with_downsample()`
    /// if the main path changes dimensions.
    pub fn new(main_path: Sequential) -> Self {
        Self {
            main_path,
            downsample: None,
            activation: Some(Box::new(ReLU)),
            training: true,
        }
    }

    /// Adds a downsample projection for when input/output dimensions differ.
    ///
    /// Typically a Conv + BatchNorm to match channel/spatial dimensions.
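    ///
    /// # Example
    /// A minimal sketch (mirroring the tests below), where the main path maps
    /// 32 -> 64 features and the skip needs a matching projection:
    /// ```ignore
    /// let main = Sequential::new()
    ///     .add(Linear::new(32, 64))
    ///     .add(ReLU)
    ///     .add(Linear::new(64, 64));
    ///
    /// let block = ResidualBlock::new(main)
    ///     .with_downsample(Sequential::new().add(Linear::new(32, 64)));
    /// ```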
    pub fn with_downsample(mut self, downsample: Sequential) -> Self {
        self.downsample = Some(downsample);
        self
    }

    /// Sets a custom activation function applied after the residual addition.
    ///
    /// Accepts any type implementing `Module` (`ReLU`, `GELU`, `SiLU`, etc.).
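    ///
    /// # Example
    /// A one-liner, as exercised in the tests below, swapping the default
    /// `ReLU` for `GELU`:
    /// ```ignore
    /// let block = ResidualBlock::new(main).with_activation(GELU);
    /// ```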
    pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
        self.activation = Some(Box::new(activation));
        self
    }

    /// Removes the post-addition activation (pre-activation ResNet style).
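    ///
    /// # Example
    /// ```ignore
    /// // The residual sum is returned as-is; useful when the next block
    /// // normalizes and activates first (pre-activation ordering).
    /// let block = ResidualBlock::new(main).without_activation();
    /// ```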
    pub fn without_activation(mut self) -> Self {
        self.activation = None;
        self
    }
}

impl Module for ResidualBlock {
    fn forward(&self, input: &Variable) -> Variable {
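        // Skip path: project the input if a downsample was supplied,
        // otherwise pass it through unchanged (identity).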
        let identity = match &self.downsample {
            Some(ds) => ds.forward(input),
            None => input.clone(),
        };

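        // Main path output plus the skip connection (element-wise add).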
        let out = self.main_path.forward(input);
        let out = out.add_var(&identity);

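        // Optional post-addition activation; `None` for pre-activation blocks.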
        match &self.activation {
            Some(act) => act.forward(&out),
            None => out,
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.main_path.parameters();
        if let Some(ds) = &self.downsample {
            params.extend(ds.parameters());
        }
        if let Some(act) = &self.activation {
            params.extend(act.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.main_path.named_parameters() {
            params.insert(format!("main_path.{name}"), param);
        }
        if let Some(ds) = &self.downsample {
            for (name, param) in ds.named_parameters() {
                params.insert(format!("downsample.{name}"), param);
            }
        }
        if let Some(act) = &self.activation {
            for (name, param) in act.named_parameters() {
                params.insert(format!("activation.{name}"), param);
            }
        }
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
        self.main_path.set_training(training);
        if let Some(ds) = &mut self.downsample {
            ds.set_training(training);
        }
        if let Some(act) = &mut self.activation {
            act.set_training(training);
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &'static str {
        "ResidualBlock"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    #[test]
    fn test_residual_block_identity_skip() {
        // Main path that preserves dimensions
        let main = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));

        let block = ResidualBlock::new(main);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 64], &[2, 32]).unwrap(), false);
        let output = block.forward(&input);

        // Output shape should match input
        assert_eq!(output.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // Main path changes dimensions: 32 -> 64
        let main = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));

        // Downsample projects input: 32 -> 64
        let downsample = Sequential::new().add(Linear::new(32, 64));

        let block = ResidualBlock::new(main).with_downsample(downsample);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 64], &[2, 32]).unwrap(), false);
        let output = block.forward(&input);
        assert_eq!(output.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let main = Sequential::new().add(Linear::new(16, 16));

        let block = ResidualBlock::new(main).with_activation(GELU);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 32], &[2, 16]).unwrap(), false);
        let output = block.forward(&input);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let main = Sequential::new().add(Linear::new(16, 16));

        let block = ResidualBlock::new(main).without_activation();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 32], &[2, 16]).unwrap(), false);
        let output = block.forward(&input);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        let main = Sequential::new()
            .add(Linear::new(32, 32)) // weight(32x32) + bias(32) = 1056
            .add(Linear::new(32, 32)); // weight(32x32) + bias(32) = 1056

        let block = ResidualBlock::new(main);
        let params = block.parameters();
        assert_eq!(params.len(), 4); // 2 weights + 2 biases
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let main = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));

        let downsample = Sequential::new().add_named("proj", Linear::new(32, 32));

        let block = ResidualBlock::new(main).with_downsample(downsample);
        let params = block.named_parameters();

        assert!(params.contains_key("main_path.conv1.weight"));
        assert!(params.contains_key("main_path.conv2.weight"));
        assert!(params.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let main = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));

        let mut block = ResidualBlock::new(main);
        assert!(block.is_training());

        block.set_training(false);
        assert!(!block.is_training());

        block.set_training(true);
        assert!(block.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        // Real use case: Conv1d residual block with downsample to match dimensions
        // Main path: 2 Conv1d(k=3) reduces time by 4 (20 -> 18 -> 16)
        let main = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));

        // Downsample projects the skip connection to the main path's output
        // shape: a Conv1d with kernel=5 also reduces 20 -> 16.
        let downsample = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));

        let block = ResidualBlock::new(main).with_downsample(downsample);

        // Input: (batch=2, channels=64, time=20)
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 64 * 20], &[2, 64, 20]).unwrap(),
            false,
        );
        let output = block.forward(&input);

        assert_eq!(output.shape()[0], 2);
        assert_eq!(output.shape()[1], 64);
        assert_eq!(output.shape()[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let main = Sequential::new().add(Linear::new(4, 4));

        let block = ResidualBlock::new(main);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );
        let output = block.forward(&input);

        // Sum to scalar for backward
        let sum = output.sum();
        sum.backward();

        // Gradient should flow through both main path and skip connection
        let params = block.parameters();
        assert!(!params.is_empty());
    }
}