// axonml_nn/layers/residual.rs
1//! Residual Block - Generic Skip Connection Layer
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/residual.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20
21use crate::activation::ReLU;
22use crate::module::Module;
23use crate::parameter::Parameter;
24use crate::sequential::Sequential;
25
26// =============================================================================
27// ResidualBlock
28// =============================================================================
29
30/// A generic residual block that wraps any module sequence with a skip connection.
31///
32/// Computes: `activation(main_path(x) + downsample(x))` where downsample is
33/// optional (defaults to identity). This enables gradient flow through the
34/// skip connection, improving training of deep networks.
35///
36/// # Example
37/// ```ignore
38/// use axonml_nn::prelude::*;
39/// use axonml_nn::layers::ResidualBlock;
40///
41/// // Conv1d residual block
42/// let main = Sequential::new()
43///     .add(Conv1d::new(64, 64, 3))
44///     .add(BatchNorm1d::new(64))
45///     .add(ReLU)
46///     .add(Conv1d::new(64, 64, 3))
47///     .add(BatchNorm1d::new(64));
48///
49/// let block = ResidualBlock::new(main);
50/// ```
pub struct ResidualBlock {
    /// The transformation branch `F(x)`: the layer sequence whose output is
    /// added to the skip connection.
    main_path: Sequential,
    /// Optional projection applied to the skip connection when `main_path`
    /// changes dimensions; `None` means an identity skip.
    downsample: Option<Sequential>,
    /// Activation applied after the residual addition (ReLU by default);
    /// `None` disables it (pre-activation ResNet style).
    activation: Option<Box<dyn Module>>,
    /// Current training/eval flag, propagated to all child modules via
    /// `set_training`.
    training: bool,
}
57
58impl ResidualBlock {
59    /// Creates a new residual block with the given main path and ReLU activation.
60    ///
61    /// The skip connection is identity (no downsample). Use `with_downsample()`
62    /// if the main path changes dimensions.
63    pub fn new(main_path: Sequential) -> Self {
64        Self {
65            main_path,
66            downsample: None,
67            activation: Some(Box::new(ReLU)),
68            training: true,
69        }
70    }
71
72    /// Adds a downsample projection for when input/output dimensions differ.
73    ///
74    /// Typically a Conv + BatchNorm to match channel/spatial dimensions.
75    pub fn with_downsample(mut self, downsample: Sequential) -> Self {
76        self.downsample = Some(downsample);
77        self
78    }
79
80    /// Sets a custom activation function applied after the residual addition.
81    ///
82    /// Pass any module implementing `Module` (ReLU, GELU, SiLU, etc.).
83    pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
84        self.activation = Some(Box::new(activation));
85        self
86    }
87
88    /// Removes the post-addition activation (pre-activation ResNet style).
89    pub fn without_activation(mut self) -> Self {
90        self.activation = None;
91        self
92    }
93}
94
95impl Module for ResidualBlock {
96    fn forward(&self, input: &Variable) -> Variable {
97        let identity = match &self.downsample {
98            Some(ds) => ds.forward(input),
99            None => input.clone(),
100        };
101
102        let out = self.main_path.forward(input);
103        let out = out.add_var(&identity);
104
105        match &self.activation {
106            Some(act) => act.forward(&out),
107            None => out,
108        }
109    }
110
111    fn parameters(&self) -> Vec<Parameter> {
112        let mut params = self.main_path.parameters();
113        if let Some(ds) = &self.downsample {
114            params.extend(ds.parameters());
115        }
116        if let Some(act) = &self.activation {
117            params.extend(act.parameters());
118        }
119        params
120    }
121
122    fn named_parameters(&self) -> HashMap<String, Parameter> {
123        let mut params = HashMap::new();
124        for (name, param) in self.main_path.named_parameters() {
125            params.insert(format!("main_path.{name}"), param);
126        }
127        if let Some(ds) = &self.downsample {
128            for (name, param) in ds.named_parameters() {
129                params.insert(format!("downsample.{name}"), param);
130            }
131        }
132        if let Some(act) = &self.activation {
133            for (name, param) in act.named_parameters() {
134                params.insert(format!("activation.{name}"), param);
135            }
136        }
137        params
138    }
139
140    fn set_training(&mut self, training: bool) {
141        self.training = training;
142        self.main_path.set_training(training);
143        if let Some(ds) = &mut self.downsample {
144            ds.set_training(training);
145        }
146        if let Some(act) = &mut self.activation {
147            act.set_training(training);
148        }
149    }
150
151    fn is_training(&self) -> bool {
152        self.training
153    }
154
155    fn name(&self) -> &'static str {
156        "ResidualBlock"
157    }
158}
159
160// =============================================================================
161// Tests
162// =============================================================================
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    /// Builds a non-grad input variable of the given shape, filled with ones.
    fn ones(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![1.0; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_residual_block_identity_skip() {
        // A dimension-preserving trunk needs no downsample projection.
        let trunk = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));
        let blk = ResidualBlock::new(trunk);

        let y = blk.forward(&ones(&[2, 32]));

        // The residual block must leave the shape unchanged.
        assert_eq!(y.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // Trunk changes dimensions: 32 -> 64.
        let trunk = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));
        // The skip branch must project 32 -> 64 to match.
        let proj = Sequential::new().add(Linear::new(32, 64));
        let blk = ResidualBlock::new(trunk).with_downsample(proj);

        let y = blk.forward(&ones(&[2, 32]));
        assert_eq!(y.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let blk = ResidualBlock::new(Sequential::new().add(Linear::new(16, 16)))
            .with_activation(GELU);

        let y = blk.forward(&ones(&[2, 16]));
        assert_eq!(y.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let blk = ResidualBlock::new(Sequential::new().add(Linear::new(16, 16)))
            .without_activation();

        let y = blk.forward(&ones(&[2, 16]));
        assert_eq!(y.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        // Two Linear layers -> 2 weights + 2 biases.
        let trunk = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));

        let blk = ResidualBlock::new(trunk);
        assert_eq!(blk.parameters().len(), 4);
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let trunk = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let proj = Sequential::new().add_named("proj", Linear::new(32, 32));

        let blk = ResidualBlock::new(trunk).with_downsample(proj);
        let named = blk.named_parameters();

        // Names are prefixed by the owning branch.
        assert!(named.contains_key("main_path.conv1.weight"));
        assert!(named.contains_key("main_path.conv2.weight"));
        assert!(named.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let trunk = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));
        let mut blk = ResidualBlock::new(trunk);

        // Blocks start in training mode and toggle cleanly.
        assert!(blk.is_training());
        blk.set_training(false);
        assert!(!blk.is_training());
        blk.set_training(true);
        assert!(blk.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        // Real use case: Conv1d residual block where both branches must land
        // on the same time dimension. Two Conv1d(k=3) shrink time 20 -> 16.
        let trunk = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));
        // One Conv1d(k=5) shrinks the skip branch 20 -> 16 to match.
        let proj = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));
        let blk = ResidualBlock::new(trunk).with_downsample(proj);

        // Input: (batch=2, channels=64, time=20).
        let y = blk.forward(&ones(&[2, 64, 20]));

        assert_eq!(y.shape()[0], 2);
        assert_eq!(y.shape()[1], 64);
        assert_eq!(y.shape()[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let blk = ResidualBlock::new(Sequential::new().add(Linear::new(4, 4)));

        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );

        // Reduce to a scalar so backward() can run; gradient should flow
        // through both the main path and the skip connection.
        blk.forward(&x).sum().backward();

        assert!(!blk.parameters().is_empty());
    }
}
334}