// axonml_nn/layers/residual.rs
//! Residual Block - Generic Skip Connection Layer
//!
//! Provides a reusable residual connection wrapper that can be used with
//! any inner module (Conv1d, Conv2d, Linear sequences, etc.).
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::module::Module;
use crate::parameter::Parameter;
use crate::sequential::Sequential;
use crate::activation::ReLU;

// =============================================================================
// ResidualBlock
// =============================================================================
21
/// A generic residual block that wraps any module sequence with a skip connection.
///
/// Computes: `activation(main_path(x) + downsample(x))` where downsample is
/// optional (defaults to identity). This enables gradient flow through the
/// skip connection, improving training of deep networks.
///
/// # Example
/// ```ignore
/// use axonml_nn::prelude::*;
/// use axonml_nn::layers::ResidualBlock;
///
/// // Conv1d residual block
/// let main = Sequential::new()
///     .add(Conv1d::new(64, 64, 3))
///     .add(BatchNorm1d::new(64))
///     .add(ReLU)
///     .add(Conv1d::new(64, 64, 3))
///     .add(BatchNorm1d::new(64));
///
/// let block = ResidualBlock::new(main);
/// ```
pub struct ResidualBlock {
    /// Modules applied to the input on the main (transform) path.
    main_path: Sequential,
    /// Optional projection for the skip connection, used when the main path
    /// changes dimensions; `None` means the skip is the identity.
    downsample: Option<Sequential>,
    /// Activation applied after the residual addition; `None` disables it
    /// (pre-activation ResNet style).
    activation: Option<Box<dyn Module>>,
    /// Training-mode flag; `set_training` propagates it to all submodules.
    training: bool,
}
49
50impl ResidualBlock {
51    /// Creates a new residual block with the given main path and ReLU activation.
52    ///
53    /// The skip connection is identity (no downsample). Use `with_downsample()`
54    /// if the main path changes dimensions.
55    pub fn new(main_path: Sequential) -> Self {
56        Self {
57            main_path,
58            downsample: None,
59            activation: Some(Box::new(ReLU)),
60            training: true,
61        }
62    }
63
64    /// Adds a downsample projection for when input/output dimensions differ.
65    ///
66    /// Typically a Conv + BatchNorm to match channel/spatial dimensions.
67    pub fn with_downsample(mut self, downsample: Sequential) -> Self {
68        self.downsample = Some(downsample);
69        self
70    }
71
72    /// Sets a custom activation function applied after the residual addition.
73    ///
74    /// Pass any module implementing `Module` (ReLU, GELU, SiLU, etc.).
75    pub fn with_activation<M: Module + 'static>(mut self, activation: M) -> Self {
76        self.activation = Some(Box::new(activation));
77        self
78    }
79
80    /// Removes the post-addition activation (pre-activation ResNet style).
81    pub fn without_activation(mut self) -> Self {
82        self.activation = None;
83        self
84    }
85}
86
87impl Module for ResidualBlock {
88    fn forward(&self, input: &Variable) -> Variable {
89        let identity = match &self.downsample {
90            Some(ds) => ds.forward(input),
91            None => input.clone(),
92        };
93
94        let out = self.main_path.forward(input);
95        let out = out.add_var(&identity);
96
97        match &self.activation {
98            Some(act) => act.forward(&out),
99            None => out,
100        }
101    }
102
103    fn parameters(&self) -> Vec<Parameter> {
104        let mut params = self.main_path.parameters();
105        if let Some(ds) = &self.downsample {
106            params.extend(ds.parameters());
107        }
108        if let Some(act) = &self.activation {
109            params.extend(act.parameters());
110        }
111        params
112    }
113
114    fn named_parameters(&self) -> HashMap<String, Parameter> {
115        let mut params = HashMap::new();
116        for (name, param) in self.main_path.named_parameters() {
117            params.insert(format!("main_path.{name}"), param);
118        }
119        if let Some(ds) = &self.downsample {
120            for (name, param) in ds.named_parameters() {
121                params.insert(format!("downsample.{name}"), param);
122            }
123        }
124        if let Some(act) = &self.activation {
125            for (name, param) in act.named_parameters() {
126                params.insert(format!("activation.{name}"), param);
127            }
128        }
129        params
130    }
131
132    fn set_training(&mut self, training: bool) {
133        self.training = training;
134        self.main_path.set_training(training);
135        if let Some(ds) = &mut self.downsample {
136            ds.set_training(training);
137        }
138        if let Some(act) = &mut self.activation {
139            act.set_training(training);
140        }
141    }
142
143    fn is_training(&self) -> bool {
144        self.training
145    }
146
147    fn name(&self) -> &'static str {
148        "ResidualBlock"
149    }
150}
151
// =============================================================================
// Tests
// =============================================================================
155
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;
    use crate::layers::{Linear, BatchNorm1d, Conv1d};
    use crate::activation::{ReLU, GELU};

    /// Builds a non-grad Variable of the given shape filled with 1.0.
    fn ones(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![1.0; numel], shape).unwrap(), false)
    }

    #[test]
    fn test_residual_block_identity_skip() {
        // A dimension-preserving main path lets the skip stay identity.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));

        let block = ResidualBlock::new(body);
        let out = block.forward(&ones(&[2, 32]));

        // Output shape must match input shape.
        assert_eq!(out.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // Main path widens features: 32 -> 64.
        let body = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));

        // Skip must be projected 32 -> 64 to allow the addition.
        let skip_proj = Sequential::new().add(Linear::new(32, 64));

        let block = ResidualBlock::new(body).with_downsample(skip_proj);
        let out = block.forward(&ones(&[2, 32]));
        assert_eq!(out.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).with_activation(GELU);

        let out = block.forward(&ones(&[2, 16]));
        assert_eq!(out.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let body = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(body).without_activation();

        let out = block.forward(&ones(&[2, 16]));
        assert_eq!(out.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        // Two Linear(32, 32) layers -> 2 weights + 2 biases = 4 parameter tensors.
        let body = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));

        let block = ResidualBlock::new(body);
        assert_eq!(block.parameters().len(), 4);
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let body = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let skip_proj = Sequential::new().add_named("proj", Linear::new(32, 32));

        let block = ResidualBlock::new(body).with_downsample(skip_proj);
        let named = block.named_parameters();

        // Keys are prefixed by the owning subpath.
        assert!(named.contains_key("main_path.conv1.weight"));
        assert!(named.contains_key("main_path.conv2.weight"));
        assert!(named.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let body = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));

        let mut block = ResidualBlock::new(body);

        // Starts in training mode; toggling flips and restores the flag.
        assert!(block.is_training());
        block.set_training(false);
        assert!(!block.is_training());
        block.set_training(true);
        assert!(block.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        // Real use case: Conv1d residual block with a projected skip.
        // Main path: two Conv1d(k=3) shrink time 20 -> 18 -> 16.
        let body = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));

        // Skip path: a single Conv1d(k=5) shrinks time 20 -> 16 in one step.
        let skip_proj = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));

        let block = ResidualBlock::new(body).with_downsample(skip_proj);

        // Input: (batch=2, channels=64, time=20).
        let out = block.forward(&ones(&[2, 64, 20]));

        assert_eq!(out.shape()[0], 2);
        assert_eq!(out.shape()[1], 64);
        assert_eq!(out.shape()[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let body = Sequential::new().add(Linear::new(4, 4));
        let block = ResidualBlock::new(body);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );

        // Reduce to a scalar so backward() has a single root.
        block.forward(&input).sum().backward();

        // Gradients should flow through both the main path and the skip.
        assert!(!block.parameters().is_empty());
    }
}
331}