ferrotorch_nn/se.rs
1//! Squeeze-and-Excitation (SE) block — Hu et al. 2018, *Squeeze-and-Excitation
2//! Networks*.
3//!
//! Mirrors `torchvision.ops.misc.SqueezeExcitation`, including its submodule
//! names and the use of 1×1 [`Conv2d`] (NOT `Linear`) for both the squeeze
//! and excitation projections. Per-channel attention is computed via global
//! average pool → 1×1 conv → activation → 1×1 conv → scale_activation →
//! broadcast multiply.
9//!
10//! ```text
//! x: [B, C, H, W]
//!   │
//!   ├─────────────────────────────────────┐
//!   ▼                                     │
//!   avgpool([B,C,1,1]) → fc1([B,sq,1,1])  │
//!     → activation                        │
//!     → fc2([B,C,1,1])                    │
//!     → scale_activation                  │
//!   │                                     │
//!   └────────────── (broadcast *) ────────┘
//!                         ▼
//!                   [B, C, H, W]
23//! ```
24//!
25//! Used by MobileNetV3 (with [`HardSigmoid`](crate::activation::HardSigmoid)
26//! as `scale_activation`) and EfficientNet (with [`Sigmoid`] as
27//! `scale_activation`). The default [`SqueezeExcitation::new`] constructor
28//! ([`ReLU`] + [`Sigmoid`]) matches `torchvision.ops.SqueezeExcitation`'s
//! default. For other activation choices (such as the MobileNetV3 and
//! EfficientNet variants above), use
//! [`SqueezeExcitation::new_with_activations`].
31//!
32//! # Module trait surface
33//!
34//! - [`named_parameters`](Module::named_parameters): `fc1.weight`,
35//! `fc1.bias`, `fc2.weight`, `fc2.bias` — exactly the four keys
36//! produced by torchvision's `SqueezeExcitation`.
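//! - [`named_children`](Module::named_children): `avgpool`, `fc1`,
//!   `activation`, `fc2`, `scale_activation`, in data-flow order. The names
//!   match torchvision's submodule attributes, but torchvision's `__init__`
//!   registers `fc2` before `activation`, so its own `named_children()`
//!   yields `avgpool, fc1, fc2, activation, scale_activation`.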
40//!
41//! # Differentiability
42//!
43//! Forward composes only [`Module::forward`]-shaped primitives that
44//! already track gradients (`Conv2d`, `AdaptiveAvgPool2d`, `ReLU`,
45//! `Sigmoid`, `HardSigmoid`, `mul`), so backward flows end-to-end.
46//!
47//! ## REQ status (per `.design/ferrotorch-nn/se.md`)
48//!
49//! | REQ | Status | Evidence |
50//! |---|---|---|
51//! | REQ-1 | SHIPPED | the `SqueezeExcitation<T>` struct here; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:247` + `ferrotorch-vision/src/models/mobilenet.rs:56` + `efficientnet.rs:39` |
52//! | REQ-2 | SHIPPED | the `SqueezeExcitation::new` constructor here (delegates to `new_with_activations`); non-test consumer: re-export at `lib.rs:247` + the two vision consumers |
53//! | REQ-3 | SHIPPED | the `new_with_activations` constructor on `SqueezeExcitation` here; non-test consumer: `mobilenet.rs:56` uses `HardSigmoid` scale; `efficientnet.rs:39` uses `Sigmoid` |
54//! | REQ-4 | SHIPPED | the `forward` method on `SqueezeExcitation` here; non-test consumer: re-export at `lib.rs:247` + MobileNetV3 and EfficientNet forwards |
55//! | REQ-5 | SHIPPED | the `impl<T: Float> Module<T> for SqueezeExcitation<T>` with `parameters` / `parameters_mut` / `named_parameters` here; non-test consumer: re-export at `lib.rs:247` |
56//! | REQ-6 | SHIPPED | the `children` + `named_children` methods inside the Module impl here; non-test consumer: re-export at `lib.rs:247` |
57//! | REQ-7 | SHIPPED | the `train` / `eval` methods inside the Module impl here (forwards to both boxed activations); non-test consumer: re-export at `lib.rs:247` |
58//! | REQ-8 | SHIPPED | the `impl<T: Float> std::fmt::Debug for SqueezeExcitation<T>` here; non-test consumer: re-export at `lib.rs:247` |
59//! | REQ-9 | SHIPPED | forward body composes only differentiable primitives (Conv2d, AdaptiveAvgPool2d, dyn Module activations, mul) here; non-test consumer: re-export at `lib.rs:247` |
60//! | REQ-10 | SHIPPED | `Send + Sync` bound automatic from Conv2d + boxed activation bounds; pinned by `se_is_send_sync` in `mod tests`; non-test consumer: re-export at `lib.rs:247` |
61
62use ferrotorch_core::grad_fns::arithmetic::mul;
63use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
64
65use crate::activation::{ReLU, Sigmoid};
66use crate::conv::Conv2d;
67use crate::module::Module;
68use crate::parameter::Parameter;
69use crate::pooling::AdaptiveAvgPool2d;
70
71/// Squeeze-and-Excitation block.
72///
73/// Channel-wise attention: pool → squeeze → excite → scale.
74/// See module docs for the full diagram. The default constructor
75/// ([`Self::new`]) uses ReLU as the inner activation and Sigmoid
76/// as the scale activation, matching torchvision's
77/// `SqueezeExcitation` default; [`Self::new_with_activations`] lets
78/// callers swap either (e.g. SiLU + HardSigmoid for MobileNetV3).
79pub struct SqueezeExcitation<T: Float> {
80 /// Global-average-pool to `[B, C, 1, 1]`.
81 avgpool: AdaptiveAvgPool2d,
82 /// 1×1 convolution: `[B, C, 1, 1] → [B, sq, 1, 1]`.
83 fc1: Conv2d<T>,
84 /// Inner activation between fc1 and fc2 (default ReLU).
85 activation: Box<dyn Module<T>>,
86 /// 1×1 convolution: `[B, sq, 1, 1] → [B, C, 1, 1]`.
87 fc2: Conv2d<T>,
88 /// Output activation that gates the input (default Sigmoid).
89 scale_activation: Box<dyn Module<T>>,
90 /// Whether the module is in training mode.
91 training: bool,
92}
93
94impl<T: Float> std::fmt::Debug for SqueezeExcitation<T> {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("SqueezeExcitation")
97 .field("fc1", &self.fc1)
98 .field("fc2", &self.fc2)
99 .field("training", &self.training)
100 .finish()
101 }
102}
103
104impl<T: Float> SqueezeExcitation<T> {
105 /// Create a new SE block with the default activations
106 /// (ReLU squeeze, Sigmoid scale).
107 ///
108 /// `input_channels` is `C` in the diagram above; `squeeze_channels`
109 /// is the bottleneck width (typically `C / r` for a reduction ratio
110 /// `r`).
111 ///
112 /// # Errors
113 /// Returns [`FerrotorchError::InvalidArgument`] if either channel
114 /// count is zero.
115 pub fn new(input_channels: usize, squeeze_channels: usize) -> FerrotorchResult<Self> {
116 Self::new_with_activations(
117 input_channels,
118 squeeze_channels,
119 Box::new(ReLU::new()),
120 Box::new(Sigmoid::new()),
121 )
122 }
123
124 /// Create a new SE block with caller-supplied activations.
125 ///
126 /// MobileNetV3-Small uses ReLU + HardSigmoid; EfficientNet uses
127 /// SiLU + Sigmoid. Both are constructible from this entry point.
128 ///
129 /// # Errors
130 /// Returns [`FerrotorchError::InvalidArgument`] if either channel
131 /// count is zero. Errors from [`Conv2d::new`] (negative shape,
132 /// allocation failure) bubble up.
133 pub fn new_with_activations(
134 input_channels: usize,
135 squeeze_channels: usize,
136 activation: Box<dyn Module<T>>,
137 scale_activation: Box<dyn Module<T>>,
138 ) -> FerrotorchResult<Self> {
139 if input_channels == 0 || squeeze_channels == 0 {
140 return Err(FerrotorchError::InvalidArgument {
141 message: format!(
142 "SqueezeExcitation: input_channels and squeeze_channels must be > 0 \
143 (got input_channels={input_channels}, squeeze_channels={squeeze_channels})"
144 ),
145 });
146 }
147
148 // 1×1 convolutions with bias — torchvision's `SqueezeExcitation`
149 // uses bias=True for both fc1 and fc2 (they appear unset, which
150 // defaults to True in `nn.Conv2d`).
151 let fc1 = Conv2d::new(
152 input_channels,
153 squeeze_channels,
154 (1, 1),
155 (1, 1),
156 (0, 0),
157 true,
158 )?;
159 let fc2 = Conv2d::new(
160 squeeze_channels,
161 input_channels,
162 (1, 1),
163 (1, 1),
164 (0, 0),
165 true,
166 )?;
167 let avgpool = AdaptiveAvgPool2d::new((1, 1));
168
169 Ok(Self {
170 avgpool,
171 fc1,
172 activation,
173 fc2,
174 scale_activation,
175 training: true,
176 })
177 }
178
179 /// Forward pass.
180 ///
181 /// `input` must be 4-D `[B, C, H, W]` with `C == input_channels`.
182 pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
183 // 1) Squeeze: global average pool to [B, C, 1, 1].
184 let scale = Module::<T>::forward(&self.avgpool, input)?;
185 // 2) fc1 → activation → fc2.
186 let scale = self.fc1.forward(&scale)?;
187 let scale = self.activation.forward(&scale)?;
188 let scale = self.fc2.forward(&scale)?;
189 // 3) scale_activation gates the [B, C, 1, 1] tensor.
190 let scale = self.scale_activation.forward(&scale)?;
191 // 4) Broadcast multiply: [B, C, H, W] * [B, C, 1, 1].
192 mul(input, &scale)
193 }
194}
195
196impl<T: Float> Module<T> for SqueezeExcitation<T> {
197 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
198 self.forward(input)
199 }
200
201 fn parameters(&self) -> Vec<&Parameter<T>> {
202 let mut p = Vec::new();
203 p.extend(self.fc1.parameters());
204 p.extend(self.fc2.parameters());
205 p
206 }
207
208 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
209 let mut p = Vec::new();
210 p.extend(self.fc1.parameters_mut());
211 p.extend(self.fc2.parameters_mut());
212 p
213 }
214
215 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
216 let mut p = Vec::new();
217 for (n, param) in self.fc1.named_parameters() {
218 p.push((format!("fc1.{n}"), param));
219 }
220 for (n, param) in self.fc2.named_parameters() {
221 p.push((format!("fc2.{n}"), param));
222 }
223 p
224 }
225
226 fn children(&self) -> Vec<&dyn Module<T>> {
227 vec![
228 &self.avgpool,
229 &self.fc1,
230 self.activation.as_ref(),
231 &self.fc2,
232 self.scale_activation.as_ref(),
233 ]
234 }
235
236 fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
237 vec![
238 ("avgpool".to_string(), &self.avgpool as &dyn Module<T>),
239 ("fc1".to_string(), &self.fc1),
240 ("activation".to_string(), self.activation.as_ref()),
241 ("fc2".to_string(), &self.fc2),
242 (
243 "scale_activation".to_string(),
244 self.scale_activation.as_ref(),
245 ),
246 ]
247 }
248
249 fn train(&mut self) {
250 self.training = true;
251 self.activation.train();
252 self.scale_activation.train();
253 }
254
255 fn eval(&mut self) {
256 self.training = false;
257 self.activation.eval();
258 self.scale_activation.eval();
259 }
260
261 fn is_training(&self) -> bool {
262 self.training
263 }
264}
265
266// ===========================================================================
267// Tests
268// ===========================================================================
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{HardSigmoid, SiLU};
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU `f32` tensor (no grad tracking) with a 4-D shape.
    fn cpu_tensor_4d(data: Vec<f32>, shape: [usize; 4]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false).unwrap()
    }

    /// Builds a 1×1 `Conv2d<f32>` via `Conv2d::from_parts`, with an
    /// `[out_ch, in_ch, 1, 1]` weight filled with `w` and an `[out_ch]` bias
    /// filled with `b`, so tests can pin fc1/fc2 deterministically.
    fn conv1x1_filled(out_ch: usize, in_ch: usize, w: f32, b: f32) -> Conv2d<f32> {
        let weight = Tensor::from_storage(
            TensorStorage::cpu(vec![w; out_ch * in_ch]),
            vec![out_ch, in_ch, 1, 1],
            false,
        )
        .unwrap();
        let bias =
            Tensor::from_storage(TensorStorage::cpu(vec![b; out_ch]), vec![out_ch], false).unwrap();
        Conv2d::from_parts(weight, Some(bias), (1, 1), (0, 0)).unwrap()
    }

    #[test]
    fn se_construction_smoke() {
        let se = SqueezeExcitation::<f32>::new(16, 4).expect("SE construction");
        assert_eq!(se.fc1.parameters().len(), 2); // weight + bias
        assert_eq!(se.fc2.parameters().len(), 2);
    }

    /// Zero channel counts are rejected up front with `InvalidArgument`.
    #[test]
    fn se_rejects_zero_channels() {
        assert!(matches!(
            SqueezeExcitation::<f32>::new(0, 4),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        assert!(matches!(
            SqueezeExcitation::<f32>::new(16, 0),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
    }

    /// Evidence #6: named_parameters returns exactly fc1.{weight,bias},
    /// fc2.{weight,bias} — matching torchvision's SqueezeExcitation
    /// state_dict key set.
    #[test]
    fn se_named_parameters_match_torchvision() {
        let se = SqueezeExcitation::<f32>::new(16, 4).unwrap();
        let names: Vec<String> = se.named_parameters().into_iter().map(|(n, _)| n).collect();
        assert_eq!(
            names,
            vec![
                "fc1.weight".to_string(),
                "fc1.bias".to_string(),
                "fc2.weight".to_string(),
                "fc2.bias".to_string(),
            ]
        );
    }

    /// Evidence #7: named_children uses torchvision's submodule names, listed
    /// in data-flow order `(avgpool, fc1, activation, fc2, scale_activation)`.
    /// (torchvision itself registers `fc2` before `activation`.)
    #[test]
    fn se_named_children_in_data_flow_order() {
        let se = SqueezeExcitation::<f32>::new(16, 4).unwrap();
        let names: Vec<String> = se.named_children().into_iter().map(|(n, _)| n).collect();
        assert_eq!(
            names,
            vec![
                "avgpool".to_string(),
                "fc1".to_string(),
                "activation".to_string(),
                "fc2".to_string(),
                "scale_activation".to_string(),
            ]
        );
    }

    /// Evidence #4: SE primitive forward equals manually-composed
    /// AdaptiveAvgPool2d + Conv2d(1×1) + ReLU + Conv2d(1×1) + Sigmoid +
    /// broadcast multiply.
    #[test]
    fn se_forward_matches_manual_composition() {
        // Pin fc1/fc2 to small deterministic weights so intermediate
        // magnitudes stay finite. Then compare element-wise against a manual
        // pipeline that reuses the same convolutions.
        let mut se = SqueezeExcitation::<f32>::new(8, 2).unwrap();
        se.fc1 = conv1x1_filled(2, 8, 0.05, 0.01);
        se.fc2 = conv1x1_filled(8, 2, 0.07, 0.02);

        // 1×8×4×4 input, deterministic.
        let n = /* B*C*H*W = 1*8*4*4 */ 8 * 4 * 4;
        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let x = cpu_tensor_4d(data, [1, 8, 4, 4]);

        let out_se = se.forward(&x).unwrap();

        // Manual pipeline: AdaptiveAvgPool2d → fc1 → ReLU → fc2 → Sigmoid → mul.
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let m_relu = ReLU::new();
        let m_sig = Sigmoid::new();
        let p = Module::<f32>::forward(&pool, &x).unwrap();
        let p = se.fc1.forward(&p).unwrap();
        let p = m_relu.forward(&p).unwrap();
        let p = se.fc2.forward(&p).unwrap();
        let p = m_sig.forward(&p).unwrap();
        let manual = mul(&x, &p).unwrap();

        let a = out_se.data().unwrap();
        let m = manual.data().unwrap();
        assert_eq!(a.len(), m.len());
        for (i, (got, want)) in a.iter().zip(m.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "SE primitive vs manual mismatch at {i}: se={got} manual={want}"
            );
        }
    }

    /// Probe-before-fix (Evidence #3): hand-computed reference for a trivial
    /// all-ones 1×4×8×8 input with every fc weight and bias set to zero.
    /// The gate is then sigmoid(0) = 0.5, so the output is 0.5 everywhere.
    #[test]
    fn se_probe_handcomputed_reference() {
        // After avgpool: [1,4,1,1] all 1.0.
        // After fc1 (zero weights/bias): [1,2,1,1] all 0.0.
        // After ReLU: [1,2,1,1] all 0.0.
        // After fc2 (zero weights/bias): [1,4,1,1] all 0.0.
        // After Sigmoid: [1,4,1,1] all 0.5.
        // Final: input * 0.5 = 0.5 everywhere.
        let mut se = SqueezeExcitation::<f32>::new(4, 2).unwrap();
        se.fc1 = conv1x1_filled(2, 4, 0.0, 0.0);
        se.fc2 = conv1x1_filled(4, 2, 0.0, 0.0);

        let n = /* B*C*H*W = 1*4*8*8 */ 4 * 8 * 8;
        let x = cpu_tensor_4d(vec![1.0_f32; n], [1, 4, 8, 8]);
        let out = se.forward(&x).unwrap();
        let data = out.data().unwrap();
        for &v in data.iter() {
            assert!(
                (v - 0.5).abs() < 1e-6,
                "expected gate output 0.5 everywhere, got {v}"
            );
        }
    }

    /// Evidence #5: backward by finite differences vs analytic gradient
    /// (small input; tolerance 1e-2). Confirms forward is differentiable
    /// end-to-end.
    #[test]
    fn se_backward_finite_differences() {
        use ferrotorch_core::grad_fns::reduction::sum;
        let se = SqueezeExcitation::<f32>::new(4, 2).unwrap();

        let n = /* B*C*H*W = 1*4*4*4 */ 4 * 4 * 4;
        let data: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.05).sin()).collect();
        let x =
            Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![1, 4, 4, 4], true).unwrap();

        let out = se.forward(&x).unwrap();
        let loss = sum(&out).unwrap();
        loss.backward().unwrap();
        let grad = x.grad().unwrap().expect("x should carry grad");
        let analytic = grad.data().unwrap().to_vec();

        // Central differences on a handful of elements.
        let h = 1e-3_f32;
        for &i in &[0_usize, 7, 25, 50, n - 1] {
            let mut p = data.clone();
            p[i] += h;
            let xp = Tensor::from_storage(TensorStorage::cpu(p), vec![1, 4, 4, 4], false).unwrap();
            let mut m = data.clone();
            m[i] -= h;
            let xm = Tensor::from_storage(TensorStorage::cpu(m), vec![1, 4, 4, 4], false).unwrap();
            let lp: f32 = se.forward(&xp).unwrap().data().unwrap().iter().sum();
            let lm: f32 = se.forward(&xm).unwrap().data().unwrap().iter().sum();
            let fd = (lp - lm) / (2.0 * h);
            assert!(
                (analytic[i] - fd).abs() < 1e-2,
                "SE backward FD mismatch at {i}: analytic={} fd={}",
                analytic[i],
                fd
            );
        }
    }

    /// `train` / `eval` flip the block's own training flag.
    #[test]
    fn se_train_eval_toggles_flag() {
        let mut se = SqueezeExcitation::<f32>::new(8, 2).unwrap();
        assert!(se.is_training());
        se.eval();
        assert!(!se.is_training());
        se.train();
        assert!(se.is_training());
    }

    #[test]
    fn se_with_hardsigmoid_scale_smoke() {
        // V3-style: ReLU + HardSigmoid.
        let se: SqueezeExcitation<f32> = SqueezeExcitation::new_with_activations(
            8,
            2,
            Box::new(ReLU::new()),
            Box::new(HardSigmoid::new()),
        )
        .unwrap();
        let x = cpu_tensor_4d(vec![0.1_f32; 8 * 6 * 6], [1, 8, 6, 6]);
        let out = se.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn se_with_silu_sigmoid_smoke() {
        // EfficientNet-style: SiLU + Sigmoid.
        let se: SqueezeExcitation<f32> = SqueezeExcitation::new_with_activations(
            16,
            4,
            Box::new(SiLU::new()),
            Box::new(Sigmoid::new()),
        )
        .unwrap();
        let x = cpu_tensor_4d(vec![0.05_f32; 16 * 4 * 4], [1, 16, 4, 4]);
        let out = se.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 16, 4, 4]);
    }

    /// Validate that SE is `Send + Sync` so it can compose into any
    /// model whose `Module` bound requires both. (`Box<dyn Module<T>>`
    /// inside the struct must propagate these bounds.)
    #[test]
    fn se_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SqueezeExcitation<f32>>();
    }
}