//! `ResidualBlock` — generic skip connection wrapper.
//!
//! Wraps any `Module` sub-block and adds the input to its
//! output: `forward(x) = x + sub_block(x)`. Optionally applies a
//! projection Linear if input/output dims differ. Used by ResNet,
//! transformer layers, and other architectures with residual connections.
//!
//! # File
//! `crates/axonml-nn/src/layers/residual.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::activation::ReLU;
use crate::module::Module;
use crate::parameter::Parameter;
use crate::sequential::Sequential;

// =============================================================================
// ResidualBlock
// =============================================================================

/// A generic residual block that wraps any module sequence with a skip connection.
///
/// Computes: `activation(main_path(x) + downsample(x))` where downsample is
/// optional (defaults to identity). This enables gradient flow through the
/// skip connection, improving training of deep networks.
///
/// # Example
/// ```ignore
/// use axonml_nn::prelude::*;
/// use axonml_nn::layers::ResidualBlock;
///
/// // Conv1d residual block
/// let main = Sequential::new()
///     .add(Conv1d::new(64, 64, 3))
///     .add(BatchNorm1d::new(64))
///     .add(ReLU)
///     .add(Conv1d::new(64, 64, 3))
///     .add(BatchNorm1d::new(64));
///
/// let block = ResidualBlock::new(main);
/// ```
pub struct ResidualBlock {
    // The transformed branch; its output is added to the skip connection.
    main_path: Sequential,
    // Optional projection applied to the skip path when the main path
    // changes dimensions; `None` means identity skip.
    downsample: Option<Sequential>,
    // Activation applied after the residual addition; `None` disables it
    // (pre-activation ResNet style). Defaults to ReLU in `new()`.
    activation: Option<Box<dyn Module>>,
    // Training-mode flag, propagated to all sub-modules via `set_training`.
    training: bool,
}
63
64impl ResidualBlock {
65    /// Creates a new residual block with the given main path and ReLU activation.
66    ///
67    /// The skip connection is identity (no downsample). Use `with_downsample()`
68    /// if the main path changes dimensions.
69    pub fn new(main_path: Sequential) -> Self {
70        Self {
71            main_path,
72            downsample: None,
73            activation: Some(Box::new(ReLU)),
74            training: true,
75        }
76    }
77
78    /// Adds a downsample projection for when input/output dimensions differ.
79    ///
80    /// Typically a Conv + BatchNorm to match channel/spatial dimensions.
81    pub fn with_downsample(mut self, downsample: Sequential) -> Self {
82        self.downsample = Some(downsample);
83        self
84    }
85
86    /// Sets a custom activation function applied after the residual addition.
87    ///
88    /// Pass any module implementing `Module` (ReLU, GELU, SiLU, etc.).
89    pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
90        self.activation = Some(Box::new(activation));
91        self
92    }
93
94    /// Removes the post-addition activation (pre-activation ResNet style).
95    pub fn without_activation(mut self) -> Self {
96        self.activation = None;
97        self
98    }
99}
100
101impl Module for ResidualBlock {
102    fn forward(&self, input: &Variable) -> Variable {
103        let identity = match &self.downsample {
104            Some(ds) => ds.forward(input),
105            None => input.clone(),
106        };
107
108        let out = self.main_path.forward(input);
109        let out = out.add_var(&identity);
110
111        match &self.activation {
112            Some(act) => act.forward(&out),
113            None => out,
114        }
115    }
116
117    fn parameters(&self) -> Vec<Parameter> {
118        let mut params = self.main_path.parameters();
119        if let Some(ds) = &self.downsample {
120            params.extend(ds.parameters());
121        }
122        if let Some(act) = &self.activation {
123            params.extend(act.parameters());
124        }
125        params
126    }
127
128    fn named_parameters(&self) -> HashMap<String, Parameter> {
129        let mut params = HashMap::new();
130        for (name, param) in self.main_path.named_parameters() {
131            params.insert(format!("main_path.{name}"), param);
132        }
133        if let Some(ds) = &self.downsample {
134            for (name, param) in ds.named_parameters() {
135                params.insert(format!("downsample.{name}"), param);
136            }
137        }
138        if let Some(act) = &self.activation {
139            for (name, param) in act.named_parameters() {
140                params.insert(format!("activation.{name}"), param);
141            }
142        }
143        params
144    }
145
146    fn set_training(&mut self, training: bool) {
147        self.training = training;
148        self.main_path.set_training(training);
149        if let Some(ds) = &mut self.downsample {
150            ds.set_training(training);
151        }
152        if let Some(act) = &mut self.activation {
153            act.set_training(training);
154        }
155    }
156
157    fn is_training(&self) -> bool {
158        self.training
159    }
160
161    fn name(&self) -> &'static str {
162        "ResidualBlock"
163    }
164}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    /// Fixture: a non-gradient `Variable` of the given shape filled with 1.0.
    fn ones(shape: &[usize]) -> Variable {
        let count: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![1.0; count], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_residual_block_identity_skip() {
        // A dimension-preserving main path needs no downsample.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(body);

        let output = block.forward(&ones(&[2, 32]));

        // Output shape should match input.
        assert_eq!(output.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // Main path changes dimensions: 32 -> 64.
        let body = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));
        // Downsample projects the skip path: 32 -> 64.
        let projection = Sequential::new().add(Linear::new(32, 64));
        let block = ResidualBlock::new(body).with_downsample(projection);

        let output = block.forward(&ones(&[2, 32]));
        assert_eq!(output.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).with_activation(GELU);

        let output = block.forward(&ones(&[2, 16]));
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).without_activation();

        let output = block.forward(&ones(&[2, 16]));
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        // Each Linear contributes a weight and a bias parameter.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(body);

        assert_eq!(block.parameters().len(), 4); // 2 weights + 2 biases
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let body = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let projection = Sequential::new().add_named("proj", Linear::new(32, 32));
        let block = ResidualBlock::new(body).with_downsample(projection);

        let named = block.named_parameters();

        // Keys are prefixed with the sub-module role.
        assert!(named.contains_key("main_path.conv1.weight"));
        assert!(named.contains_key("main_path.conv2.weight"));
        assert!(named.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let body = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));
        let mut block = ResidualBlock::new(body);

        // New blocks default to training mode; the flag round-trips.
        assert!(block.is_training());
        block.set_training(false);
        assert!(!block.is_training());
        block.set_training(true);
        assert!(block.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        // Real use case: Conv1d residual block with downsample to match shapes.
        // Main path: two Conv1d(k=3) reduce time by 4 (20 -> 18 -> 16).
        let body = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));
        // Downsample: Conv1d(k=5) reduces 20 -> 16 to match the main path.
        let projection = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));
        let block = ResidualBlock::new(body).with_downsample(projection);

        // Input: (batch=2, channels=64, time=20).
        let output = block.forward(&ones(&[2, 64, 20]));

        assert_eq!(output.shape()[0], 2);
        assert_eq!(output.shape()[1], 64);
        assert_eq!(output.shape()[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let body = Sequential::new().add(Linear::new(4, 4));
        let block = ResidualBlock::new(body);

        // requires_grad input so backward() populates gradients.
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        let output = block.forward(&input);

        // Reduce to a scalar so backward() can run.
        output.sum().backward();

        // Gradient should flow through both main path and skip connection.
        assert!(!block.parameters().is_empty());
    }
}