Skip to main content

ferrotorch_nn/
se.rs

1//! Squeeze-and-Excitation (SE) block — Hu et al. 2018, *Squeeze-and-Excitation
2//! Networks*.
3//!
4//! Mirrors `torchvision.ops.misc.SqueezeExcitation`, including the
5//! `named_children()` order and the use of 1×1 [`Conv2d`] (NOT [`Linear`])
6//! for both the squeeze and excitation projections. Per-channel attention
7//! via global average pool → 1×1 conv → activation → 1×1 conv →
8//! scale_activation → broadcast multiply.
9//!
10//! ```text
11//! x: [B, C, H, W]
12//!     │
13//!     ├─────────────────────────────────────┐
14//!     ▼                                     │
15//! avgpool([B,C,1,1]) → fc1([B,sq,1,1])     │
16//!     → activation                          │
17//!     → fc2([B,C,1,1])                      │
18//!     → scale_activation                    │
19//!     │                                     │
20//!     └────────────── (broadcast *) ────────┘
21//!                     ▼
22//!                  [B, C, H, W]
23//! ```
24//!
25//! Used by MobileNetV3 (with [`HardSigmoid`](crate::activation::HardSigmoid)
26//! as `scale_activation`) and EfficientNet (with [`Sigmoid`] as
27//! `scale_activation`). The default [`SqueezeExcitation::new`] constructor
28//! ([`ReLU`] + [`Sigmoid`]) matches `torchvision.ops.SqueezeExcitation`'s
//! default. To supply a different activation pair, use
//! [`SqueezeExcitation::new_with_activations`].
31//!
32//! # Module trait surface
33//!
34//! - [`named_parameters`](Module::named_parameters): `fc1.weight`,
35//!   `fc1.bias`, `fc2.weight`, `fc2.bias` — exactly the four keys
36//!   produced by torchvision's `SqueezeExcitation`.
37//! - [`named_children`](Module::named_children): `avgpool`, `fc1`,
38//!   `activation`, `fc2`, `scale_activation` — same order torchvision
39//!   exposes through `Sequential`-style submodule naming.
40//!
41//! # Differentiability
42//!
43//! Forward composes only [`Module::forward`]-shaped primitives that
44//! already track gradients (`Conv2d`, `AdaptiveAvgPool2d`, `ReLU`,
45//! `Sigmoid`, `HardSigmoid`, `mul`), so backward flows end-to-end.
46//!
47//! ## REQ status (per `.design/ferrotorch-nn/se.md`)
48//!
49//! | REQ | Status | Evidence |
50//! |---|---|---|
51//! | REQ-1 | SHIPPED | the `SqueezeExcitation<T>` struct here; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:247` + `ferrotorch-vision/src/models/mobilenet.rs:56` + `efficientnet.rs:39` |
52//! | REQ-2 | SHIPPED | the `SqueezeExcitation::new` constructor here (delegates to `new_with_activations`); non-test consumer: re-export at `lib.rs:247` + the two vision consumers |
53//! | REQ-3 | SHIPPED | the `new_with_activations` constructor on `SqueezeExcitation` here; non-test consumer: `mobilenet.rs:56` uses `HardSigmoid` scale; `efficientnet.rs:39` uses `Sigmoid` |
54//! | REQ-4 | SHIPPED | the `forward` method on `SqueezeExcitation` here; non-test consumer: re-export at `lib.rs:247` + MobileNetV3 and EfficientNet forwards |
55//! | REQ-5 | SHIPPED | the `impl<T: Float> Module<T> for SqueezeExcitation<T>` with `parameters` / `parameters_mut` / `named_parameters` here; non-test consumer: re-export at `lib.rs:247` |
56//! | REQ-6 | SHIPPED | the `children` + `named_children` methods inside the Module impl here; non-test consumer: re-export at `lib.rs:247` |
57//! | REQ-7 | SHIPPED | the `train` / `eval` methods inside the Module impl here (forwards to both boxed activations); non-test consumer: re-export at `lib.rs:247` |
58//! | REQ-8 | SHIPPED | the `impl<T: Float> std::fmt::Debug for SqueezeExcitation<T>` here; non-test consumer: re-export at `lib.rs:247` |
59//! | REQ-9 | SHIPPED | forward body composes only differentiable primitives (Conv2d, AdaptiveAvgPool2d, dyn Module activations, mul) here; non-test consumer: re-export at `lib.rs:247` |
60//! | REQ-10 | SHIPPED | `Send + Sync` bound automatic from Conv2d + boxed activation bounds; pinned by `se_is_send_sync` in `mod tests`; non-test consumer: re-export at `lib.rs:247` |
61
62use ferrotorch_core::grad_fns::arithmetic::mul;
63use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
64
65use crate::activation::{ReLU, Sigmoid};
66use crate::conv::Conv2d;
67use crate::module::Module;
68use crate::parameter::Parameter;
69use crate::pooling::AdaptiveAvgPool2d;
70
/// Squeeze-and-Excitation block.
///
/// Channel-wise attention: pool → squeeze → excite → scale. The input is
/// multiplied by a per-channel gate computed from its own global average.
/// See the module docs for the full diagram.
///
/// The default constructor ([`Self::new`]) uses ReLU as the inner
/// activation and Sigmoid as the scale activation, matching torchvision's
/// `SqueezeExcitation` default; [`Self::new_with_activations`] lets
/// callers swap either (e.g. ReLU + HardSigmoid for MobileNetV3, or
/// SiLU + Sigmoid for EfficientNet).
pub struct SqueezeExcitation<T: Float> {
    /// Global-average-pool to `[B, C, 1, 1]`.
    avgpool: AdaptiveAvgPool2d,
    /// 1×1 convolution: `[B, C, 1, 1] → [B, sq, 1, 1]`.
    fc1: Conv2d<T>,
    /// Inner activation between fc1 and fc2 (default ReLU).
    activation: Box<dyn Module<T>>,
    /// 1×1 convolution: `[B, sq, 1, 1] → [B, C, 1, 1]`.
    fc2: Conv2d<T>,
    /// Output activation that gates the input (default Sigmoid).
    scale_activation: Box<dyn Module<T>>,
    /// Whether the module is in training mode. Starts as `true` and is
    /// toggled by `Module::train` / `Module::eval`.
    training: bool,
}
93
94impl<T: Float> std::fmt::Debug for SqueezeExcitation<T> {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("SqueezeExcitation")
97            .field("fc1", &self.fc1)
98            .field("fc2", &self.fc2)
99            .field("training", &self.training)
100            .finish()
101    }
102}
103
104impl<T: Float> SqueezeExcitation<T> {
105    /// Create a new SE block with the default activations
106    /// (ReLU squeeze, Sigmoid scale).
107    ///
108    /// `input_channels` is `C` in the diagram above; `squeeze_channels`
109    /// is the bottleneck width (typically `C / r` for a reduction ratio
110    /// `r`).
111    ///
112    /// # Errors
113    /// Returns [`FerrotorchError::InvalidArgument`] if either channel
114    /// count is zero.
115    pub fn new(input_channels: usize, squeeze_channels: usize) -> FerrotorchResult<Self> {
116        Self::new_with_activations(
117            input_channels,
118            squeeze_channels,
119            Box::new(ReLU::new()),
120            Box::new(Sigmoid::new()),
121        )
122    }
123
124    /// Create a new SE block with caller-supplied activations.
125    ///
126    /// MobileNetV3-Small uses ReLU + HardSigmoid; EfficientNet uses
127    /// SiLU + Sigmoid. Both are constructible from this entry point.
128    ///
129    /// # Errors
130    /// Returns [`FerrotorchError::InvalidArgument`] if either channel
131    /// count is zero. Errors from [`Conv2d::new`] (negative shape,
132    /// allocation failure) bubble up.
133    pub fn new_with_activations(
134        input_channels: usize,
135        squeeze_channels: usize,
136        activation: Box<dyn Module<T>>,
137        scale_activation: Box<dyn Module<T>>,
138    ) -> FerrotorchResult<Self> {
139        if input_channels == 0 || squeeze_channels == 0 {
140            return Err(FerrotorchError::InvalidArgument {
141                message: format!(
142                    "SqueezeExcitation: input_channels and squeeze_channels must be > 0 \
143                     (got input_channels={input_channels}, squeeze_channels={squeeze_channels})"
144                ),
145            });
146        }
147
148        // 1×1 convolutions with bias — torchvision's `SqueezeExcitation`
149        // uses bias=True for both fc1 and fc2 (they appear unset, which
150        // defaults to True in `nn.Conv2d`).
151        let fc1 = Conv2d::new(
152            input_channels,
153            squeeze_channels,
154            (1, 1),
155            (1, 1),
156            (0, 0),
157            true,
158        )?;
159        let fc2 = Conv2d::new(
160            squeeze_channels,
161            input_channels,
162            (1, 1),
163            (1, 1),
164            (0, 0),
165            true,
166        )?;
167        let avgpool = AdaptiveAvgPool2d::new((1, 1));
168
169        Ok(Self {
170            avgpool,
171            fc1,
172            activation,
173            fc2,
174            scale_activation,
175            training: true,
176        })
177    }
178
179    /// Forward pass.
180    ///
181    /// `input` must be 4-D `[B, C, H, W]` with `C == input_channels`.
182    pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
183        // 1) Squeeze: global average pool to [B, C, 1, 1].
184        let scale = Module::<T>::forward(&self.avgpool, input)?;
185        // 2) fc1 → activation → fc2.
186        let scale = self.fc1.forward(&scale)?;
187        let scale = self.activation.forward(&scale)?;
188        let scale = self.fc2.forward(&scale)?;
189        // 3) scale_activation gates the [B, C, 1, 1] tensor.
190        let scale = self.scale_activation.forward(&scale)?;
191        // 4) Broadcast multiply: [B, C, H, W] * [B, C, 1, 1].
192        mul(input, &scale)
193    }
194}
195
196impl<T: Float> Module<T> for SqueezeExcitation<T> {
197    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
198        self.forward(input)
199    }
200
201    fn parameters(&self) -> Vec<&Parameter<T>> {
202        let mut p = Vec::new();
203        p.extend(self.fc1.parameters());
204        p.extend(self.fc2.parameters());
205        p
206    }
207
208    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
209        let mut p = Vec::new();
210        p.extend(self.fc1.parameters_mut());
211        p.extend(self.fc2.parameters_mut());
212        p
213    }
214
215    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
216        let mut p = Vec::new();
217        for (n, param) in self.fc1.named_parameters() {
218            p.push((format!("fc1.{n}"), param));
219        }
220        for (n, param) in self.fc2.named_parameters() {
221            p.push((format!("fc2.{n}"), param));
222        }
223        p
224    }
225
226    fn children(&self) -> Vec<&dyn Module<T>> {
227        vec![
228            &self.avgpool,
229            &self.fc1,
230            self.activation.as_ref(),
231            &self.fc2,
232            self.scale_activation.as_ref(),
233        ]
234    }
235
236    fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
237        vec![
238            ("avgpool".to_string(), &self.avgpool as &dyn Module<T>),
239            ("fc1".to_string(), &self.fc1),
240            ("activation".to_string(), self.activation.as_ref()),
241            ("fc2".to_string(), &self.fc2),
242            (
243                "scale_activation".to_string(),
244                self.scale_activation.as_ref(),
245            ),
246        ]
247    }
248
249    fn train(&mut self) {
250        self.training = true;
251        self.activation.train();
252        self.scale_activation.train();
253    }
254
255    fn eval(&mut self) {
256        self.training = false;
257        self.activation.eval();
258        self.scale_activation.eval();
259    }
260
261    fn is_training(&self) -> bool {
262        self.training
263    }
264}
265
266// ===========================================================================
267// Tests
268// ===========================================================================
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{HardSigmoid, SiLU};
    use ferrotorch_core::storage::TensorStorage;

    /// Wrap `data` in a CPU tensor of the given shape. The last argument is
    /// forwarded verbatim as the third argument of `Tensor::from_storage`
    /// (the gradient test passes `true` for the tensor it differentiates).
    fn cpu_tensor(data: Vec<f32>, shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), requires_grad).unwrap()
    }

    /// 4-D convenience wrapper around [`cpu_tensor`] with the flag off.
    fn cpu_tensor_4d(data: Vec<f32>, shape: [usize; 4]) -> Tensor<f32> {
        cpu_tensor(data, &shape, false)
    }

    /// 1×1 convolution whose weights all equal `weight` and whose biases
    /// all equal `bias`. Built through `Conv2d::from_parts` because that is
    /// the route these tests use to inject a bias deterministically.
    fn constant_pointwise_conv(
        in_channels: usize,
        out_channels: usize,
        weight: f32,
        bias: f32,
    ) -> Conv2d<f32> {
        let w = cpu_tensor(
            vec![weight; out_channels * in_channels],
            &[out_channels, in_channels, 1, 1],
            false,
        );
        let b = cpu_tensor(vec![bias; out_channels], &[out_channels], false);
        Conv2d::from_parts(w, Some(b), (1, 1), (0, 0)).unwrap()
    }

    #[test]
    fn se_construction_smoke() {
        let se = SqueezeExcitation::<f32>::new(16, 4).expect("SE construction");
        assert_eq!(se.fc1.parameters().len(), 2); // weight + bias
        assert_eq!(se.fc2.parameters().len(), 2);
    }

    /// Evidence #6: named_parameters returns exactly fc1.{weight,bias},
    /// fc2.{weight,bias} — matching torchvision's SqueezeExcitation
    /// state_dict key set.
    #[test]
    fn se_named_parameters_match_torchvision() {
        let se = SqueezeExcitation::<f32>::new(16, 4).unwrap();
        let names: Vec<String> = se.named_parameters().into_iter().map(|(n, _)| n).collect();
        assert_eq!(names, ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]);
    }

    /// Evidence #7: named_children order matches torchvision's
    /// `(avgpool, fc1, activation, fc2, scale_activation)`.
    #[test]
    fn se_named_children_match_torchvision_order() {
        let se = SqueezeExcitation::<f32>::new(16, 4).unwrap();
        let names: Vec<String> = se.named_children().into_iter().map(|(n, _)| n).collect();
        assert_eq!(
            names,
            ["avgpool", "fc1", "activation", "fc2", "scale_activation"]
        );
    }

    /// Evidence #4: SE primitive forward equals manually-composed
    /// AdaptiveAvgPool2d + Conv2d(1×1) + ReLU + Conv2d(1×1) + Sigmoid +
    /// broadcast multiply.
    #[test]
    fn se_forward_matches_manual_composition() {
        // Build the SE block, then run a parallel manual pipeline that
        // shares its convolutions. Both see the same deterministic input
        // and must agree to within float noise.
        let mut se = SqueezeExcitation::<f32>::new(8, 2).unwrap();

        // Swap in small constant weights and biases so the intermediate
        // magnitudes stay finite and the result is reproducible.
        se.fc1 = constant_pointwise_conv(8, 2, 0.05, 0.01);
        se.fc2 = constant_pointwise_conv(2, 8, 0.07, 0.02);

        // 1×8×4×4 input, deterministic.
        let n = /* B*C*H*W = 1*8*4*4 */ 8 * 4 * 4;
        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let x = cpu_tensor_4d(data, [1, 8, 4, 4]);

        let out_se = se.forward(&x).unwrap();

        // Manual pipeline: AdaptiveAvgPool2d → Conv2d(fc1) → ReLU →
        // Conv2d(fc2) → Sigmoid → mul.
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let m_relu = ReLU::new();
        let m_sig = Sigmoid::new();
        let p = Module::<f32>::forward(&pool, &x).unwrap();
        let p = se.fc1.forward(&p).unwrap();
        let p = m_relu.forward(&p).unwrap();
        let p = se.fc2.forward(&p).unwrap();
        let p = m_sig.forward(&p).unwrap();
        let manual = mul(&x, &p).unwrap();

        let a = out_se.data().unwrap();
        let m = manual.data().unwrap();
        assert_eq!(a.len(), m.len());
        for (i, (se_v, manual_v)) in a.iter().zip(m.iter()).enumerate() {
            assert!(
                (se_v - manual_v).abs() < 1e-6,
                "SE primitive vs manual mismatch at {i}: se={se_v} manual={manual_v}"
            );
        }
    }

    /// Probe-before-fix (Evidence #3): hand-computed reference for a
    /// trivial 1×4×8×8 input where every element is 1.0. All fc weights
    /// and biases are 0, so the gate is sigmoid(0) = 0.5 and the output is
    /// 0.5 * input = 0.5 everywhere.
    #[test]
    fn se_probe_handcomputed_reference() {
        // Inputs all 1.0, fc1 weights/bias = 0, fc2 weights/bias = 0.
        // After avgpool: [1,4,1,1] all 1.0.
        // After fc1: [1,2,1,1] all 0.0.
        // After ReLU: [1,2,1,1] all 0.0.
        // After fc2: [1,4,1,1] all 0.0.
        // After Sigmoid: [1,4,1,1] all 0.5.
        // Final: input * 0.5 = 0.5 everywhere.
        let mut se = SqueezeExcitation::<f32>::new(4, 2).unwrap();
        se.fc1 = constant_pointwise_conv(4, 2, 0.0, 0.0);
        se.fc2 = constant_pointwise_conv(2, 4, 0.0, 0.0);

        let n = /* B*C*H*W = 1*4*8*8 */ 4 * 8 * 8;
        let x = cpu_tensor_4d(vec![1.0_f32; n], [1, 4, 8, 8]);
        let out = se.forward(&x).unwrap();
        let data = out.data().unwrap();
        for &v in data.iter() {
            assert!(
                (v - 0.5).abs() < 1e-6,
                "expected gate output 0.5 everywhere, got {v}"
            );
        }
    }

    /// Evidence #5: backward by finite differences vs analytic gradient
    /// (small input; tolerance 1e-2). Confirms forward is differentiable
    /// end-to-end.
    #[test]
    fn se_backward_finite_differences() {
        use ferrotorch_core::grad_fns::reduction::sum;

        let se = SqueezeExcitation::<f32>::new(4, 2).unwrap();

        let shape = [1_usize, 4, 4, 4];
        let n: usize = shape.iter().product();
        let data: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.05).sin()).collect();
        let x = cpu_tensor(data.clone(), &shape, true);

        // Analytic gradient of sum(SE(x)) with respect to x.
        let out = se.forward(&x).unwrap();
        let loss = sum(&out).unwrap();
        loss.backward().unwrap();
        let grad = x.grad().unwrap().expect("x should carry grad");
        let analytic = grad.data().unwrap().to_vec();

        // Scalar loss for an arbitrary input, evaluated without autograd.
        let loss_at = |values: Vec<f32>| -> f32 {
            let probe = cpu_tensor(values, &shape, false);
            let total: f32 = se.forward(&probe).unwrap().data().unwrap().iter().sum();
            total
        };

        // Central differences on a handful of elements.
        let h = 1e-3_f32;
        for &i in &[0_usize, 7, 25, 50, n - 1] {
            let mut plus = data.clone();
            plus[i] += h;
            let mut minus = data.clone();
            minus[i] -= h;
            let fd = (loss_at(plus) - loss_at(minus)) / (2.0 * h);
            assert!(
                (analytic[i] - fd).abs() < 1e-2,
                "SE backward FD mismatch at {i}: analytic={} fd={}",
                analytic[i],
                fd
            );
        }
    }

    #[test]
    fn se_with_hardsigmoid_scale_smoke() {
        // V3-style: ReLU + HardSigmoid.
        let se: SqueezeExcitation<f32> = SqueezeExcitation::new_with_activations(
            8,
            2,
            Box::new(ReLU::new()),
            Box::new(HardSigmoid::new()),
        )
        .unwrap();
        let x = cpu_tensor_4d(vec![0.1_f32; 8 * 6 * 6], [1, 8, 6, 6]);
        let out = se.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn se_with_silu_sigmoid_smoke() {
        // EfficientNet-style: SiLU + Sigmoid.
        let se: SqueezeExcitation<f32> = SqueezeExcitation::new_with_activations(
            16,
            4,
            Box::new(SiLU::new()),
            Box::new(Sigmoid::new()),
        )
        .unwrap();
        let x = cpu_tensor_4d(vec![0.05_f32; 16 * 4 * 4], [1, 16, 4, 4]);
        let out = se.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 16, 4, 4]);
    }

    /// Validate that SE is `Send + Sync` so it can compose into any
    /// model whose `Module` bound requires both. (`Box<dyn Module<T>>`
    /// inside the struct must propagate these bounds.)
    #[test]
    fn se_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SqueezeExcitation<f32>>();
    }
}