use ferrotorch_core::grad_fns::arithmetic::mul;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
use crate::activation::{ReLU, Sigmoid};
use crate::conv::Conv2d;
use crate::module::Module;
use crate::parameter::Parameter;
use crate::pooling::AdaptiveAvgPool2d;
/// Squeeze-and-Excitation block.
///
/// `forward` average-pools the input to a 1x1 spatial size, then runs it through
/// `fc1` -> `activation` -> `fc2` -> `scale_activation`. It multiplies the
/// original input by the resulting scale tensor via `mul`.
pub struct SqueezeExcitation<T: Float> {
    /// Adaptive average pool with a (1, 1) output size (the "squeeze" step).
    avgpool: AdaptiveAvgPool2d,
    /// 1x1 convolution (with bias) from `input_channels` to `squeeze_channels`.
    fc1: Conv2d<T>,
    /// Activation applied to the output of `fc1` (`ReLU` when built via `new`).
    activation: Box<dyn Module<T>>,
    /// 1x1 convolution (with bias) from `squeeze_channels` back to `input_channels`.
    fc2: Conv2d<T>,
    /// Activation applied to the output of `fc2` to produce the scale that
    /// multiplies the input (`Sigmoid` when built via `new`).
    scale_activation: Box<dyn Module<T>>,
    /// Training-mode flag returned by `Module::is_training`; starts as `true`.
    training: bool,
}
impl<T: Float> std::fmt::Debug for SqueezeExcitation<T> {
    /// Formats the block as `SqueezeExcitation { fc1, fc2, training }`.
    ///
    /// The pooling layer and the boxed activations are intentionally left out:
    /// the activations are trait objects that are not required to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            fc1, fc2, training, ..
        } = self;
        let mut builder = f.debug_struct("SqueezeExcitation");
        builder.field("fc1", fc1);
        builder.field("fc2", fc2);
        builder.field("training", training);
        builder.finish()
    }
}
impl<T: Float> SqueezeExcitation<T> {
    /// Builds a Squeeze-and-Excitation block with the default activations:
    /// `ReLU` between the two projections and `Sigmoid` as the scale activation.
    ///
    /// # Errors
    ///
    /// Same as [`SqueezeExcitation::new_with_activations`].
    pub fn new(input_channels: usize, squeeze_channels: usize) -> FerrotorchResult<Self> {
        let activation: Box<dyn Module<T>> = Box::new(ReLU::new());
        let scale_activation: Box<dyn Module<T>> = Box::new(Sigmoid::new());
        Self::new_with_activations(
            input_channels,
            squeeze_channels,
            activation,
            scale_activation,
        )
    }

    /// Builds a Squeeze-and-Excitation block with caller-supplied activations.
    ///
    /// * `input_channels` - channel count of the tensor being rescaled.
    /// * `squeeze_channels` - width of the bottleneck between `fc1` and `fc2`.
    /// * `activation` - applied after `fc1`.
    /// * `scale_activation` - applied after `fc2`; its output multiplies the input.
    ///
    /// The new block starts in training mode.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` if either channel count is
    /// zero. Propagates any error from `Conv2d::new`.
    pub fn new_with_activations(
        input_channels: usize,
        squeeze_channels: usize,
        activation: Box<dyn Module<T>>,
        scale_activation: Box<dyn Module<T>>,
    ) -> FerrotorchResult<Self> {
        let any_zero = [input_channels, squeeze_channels].contains(&0);
        if any_zero {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "SqueezeExcitation: input_channels and squeeze_channels must be > 0 \
                     (got input_channels={input_channels}, squeeze_channels={squeeze_channels})"
                ),
            });
        }

        // Both projections are 1x1 convolutions: stride 1, no padding, with bias.
        let pointwise =
            |in_ch: usize, out_ch: usize| Conv2d::<T>::new(in_ch, out_ch, (1, 1), (1, 1), (0, 0), true);

        // fc1 is constructed before fc2 so parameter initialisation happens in
        // the same order as before.
        let fc1 = pointwise(input_channels, squeeze_channels)?;
        let fc2 = pointwise(squeeze_channels, input_channels)?;

        Ok(Self {
            avgpool: AdaptiveAvgPool2d::new((1, 1)),
            fc1,
            activation,
            fc2,
            scale_activation,
            training: true,
        })
    }

    /// Runs the block: pool -> fc1 -> activation -> fc2 -> scale_activation,
    /// then multiplies `input` by the resulting scale with `mul`.
    ///
    /// NOTE(review): the tests use 4-D `(N, C, H, W)` inputs. This code relies on
    /// `mul` broadcasting the pooled `(.., 1, 1)` scale back over `input`.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by any stage or by `mul`.
    pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let squeezed = Module::<T>::forward(&self.avgpool, input)?;
        let gate = self
            .fc1
            .forward(&squeezed)
            .and_then(|hidden| self.activation.forward(&hidden))
            .and_then(|hidden| self.fc2.forward(&hidden))
            .and_then(|logits| self.scale_activation.forward(&logits))?;
        mul(input, &gate)
    }
}
impl<T: Float> Module<T> for SqueezeExcitation<T> {
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
self.forward(input)
}
fn parameters(&self) -> Vec<&Parameter<T>> {
let mut p = Vec::new();
p.extend(self.fc1.parameters());
p.extend(self.fc2.parameters());
p
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
let mut p = Vec::new();
p.extend(self.fc1.parameters_mut());
p.extend(self.fc2.parameters_mut());
p
}
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
let mut p = Vec::new();
for (n, param) in self.fc1.named_parameters() {
p.push((format!("fc1.{n}"), param));
}
for (n, param) in self.fc2.named_parameters() {
p.push((format!("fc2.{n}"), param));
}
p
}
fn children(&self) -> Vec<&dyn Module<T>> {
vec![
&self.avgpool,
&self.fc1,
self.activation.as_ref(),
&self.fc2,
self.scale_activation.as_ref(),
]
}
fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
vec![
("avgpool".to_string(), &self.avgpool as &dyn Module<T>),
("fc1".to_string(), &self.fc1),
("activation".to_string(), self.activation.as_ref()),
("fc2".to_string(), &self.fc2),
(
"scale_activation".to_string(),
self.scale_activation.as_ref(),
),
]
}
fn train(&mut self) {
self.training = true;
self.activation.train();
self.scale_activation.train();
}
fn eval(&mut self) {
self.training = false;
self.activation.eval();
self.scale_activation.eval();
}
fn is_training(&self) -> bool {
self.training
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{HardSigmoid, SiLU};
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a non-grad CPU tensor from `data` with the given 4-D `shape`.
    fn cpu_tensor_4d(data: Vec<f32>, shape: [usize; 4]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false).unwrap()
    }

    /// Builds a non-grad CPU tensor of `shape` with every element equal to `value`.
    fn cpu_filled(value: f32, shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Builds a 1x1 convolution (stride 1, no padding) from `in_ch` to `out_ch`
    /// channels. All weights equal `weight` and all biases equal `bias`.
    fn constant_pointwise_conv(in_ch: usize, out_ch: usize, weight: f32, bias: f32) -> Conv2d<f32> {
        let w = cpu_filled(weight, &[out_ch, in_ch, 1, 1]);
        let b = cpu_filled(bias, &[out_ch]);
        Conv2d::from_parts(w, Some(b), (1, 1), (0, 0)).unwrap()
    }

    #[test]
    fn se_construction_smoke() {
        let se = SqueezeExcitation::<f32>::new(16, 4).expect("SE construction");
        assert_eq!(se.fc1.parameters().len(), 2);
        assert_eq!(se.fc2.parameters().len(), 2);
    }

    #[test]
    fn se_rejects_zero_channels() {
        for (input_channels, squeeze_channels) in [(0, 4), (16, 0), (0, 0)] {
            let result = SqueezeExcitation::<f32>::new(input_channels, squeeze_channels);
            assert!(
                matches!(result, Err(FerrotorchError::InvalidArgument { .. })),
                "expected InvalidArgument for ({input_channels}, {squeeze_channels})"
            );
        }
    }

    #[test]
    fn se_named_parameters_match_torchvision() {
        let se = SqueezeExcitation::<f32>::new(16, 4).unwrap();
        let names: Vec<String> = se.named_parameters().into_iter().map(|(n, _)| n).collect();
        assert_eq!(
            names,
            vec![
                "fc1.weight".to_string(),
                "fc1.bias".to_string(),
                "fc2.weight".to_string(),
                "fc2.bias".to_string(),
            ]
        );
    }

    #[test]
    fn se_named_children_match_torchvision_order() {
        let se = SqueezeExcitation::<f32>::new(16, 4).unwrap();
        let names: Vec<String> = se.named_children().into_iter().map(|(n, _)| n).collect();
        assert_eq!(
            names,
            vec![
                "avgpool".to_string(),
                "fc1".to_string(),
                "activation".to_string(),
                "fc2".to_string(),
                "scale_activation".to_string(),
            ]
        );
    }

    #[test]
    fn se_forward_matches_manual_composition() {
        let mut se = SqueezeExcitation::<f32>::new(8, 2).unwrap();
        se.fc1 = constant_pointwise_conv(8, 2, 0.05, 0.01);
        se.fc2 = constant_pointwise_conv(2, 8, 0.07, 0.02);

        let n = 8 * 4 * 4;
        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let x = cpu_tensor_4d(data, [1, 8, 4, 4]);
        let out_se = se.forward(&x).unwrap();

        // Recompute the same pipeline by hand with fresh default activations.
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let m_relu = ReLU::new();
        let m_sig = Sigmoid::new();
        let p = Module::<f32>::forward(&pool, &x).unwrap();
        let p = se.fc1.forward(&p).unwrap();
        let p = m_relu.forward(&p).unwrap();
        let p = se.fc2.forward(&p).unwrap();
        let p = m_sig.forward(&p).unwrap();
        let manual = mul(&x, &p).unwrap();

        let a = out_se.data().unwrap();
        let m = manual.data().unwrap();
        assert_eq!(a.len(), m.len());
        for (i, (&got, &want)) in a.iter().zip(m.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "SE primitive vs manual mismatch at {i}: se={got} manual={want}"
            );
        }
    }

    #[test]
    fn se_probe_handcomputed_reference() {
        let mut se = SqueezeExcitation::<f32>::new(4, 2).unwrap();
        se.fc1 = constant_pointwise_conv(4, 2, 0.0, 0.0);
        se.fc2 = constant_pointwise_conv(2, 4, 0.0, 0.0);

        let x = cpu_tensor_4d(vec![1.0_f32; 4 * 8 * 8], [1, 4, 8, 8]);
        let out = se.forward(&x).unwrap();

        // All-zero weights and biases give a gate of sigmoid(0) = 0.5, so every
        // element of the all-ones input is scaled to 0.5.
        let data = out.data().unwrap();
        for &v in data.iter() {
            assert!(
                (v - 0.5).abs() < 1e-6,
                "expected gate output 0.5 everywhere, got {v}"
            );
        }
    }

    #[test]
    fn se_backward_finite_differences() {
        use ferrotorch_core::grad_fns::reduction::sum;

        let se = SqueezeExcitation::<f32>::new(4, 2).unwrap();
        let n = 4 * 4 * 4;
        let data: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.05).sin()).collect();
        let x =
            Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![1, 4, 4, 4], true).unwrap();
        let out = se.forward(&x).unwrap();
        let loss = sum(&out).unwrap();
        loss.backward().unwrap();
        let grad = x.grad().unwrap().expect("x should carry grad");
        let analytic = grad.data().unwrap().to_vec();

        // Central differences of sum(se(x)) at a few representative indices.
        let h = 1e-3_f32;
        for &i in &[0_usize, 7, 25, 50, n - 1] {
            let mut plus = data.clone();
            plus[i] += h;
            let mut minus = data.clone();
            minus[i] -= h;
            let xp = Tensor::from_storage(TensorStorage::cpu(plus), vec![1, 4, 4, 4], false).unwrap();
            let xm =
                Tensor::from_storage(TensorStorage::cpu(minus), vec![1, 4, 4, 4], false).unwrap();
            let lp: f32 = se.forward(&xp).unwrap().data().unwrap().iter().sum();
            let lm: f32 = se.forward(&xm).unwrap().data().unwrap().iter().sum();
            let fd = (lp - lm) / (2.0 * h);
            assert!(
                (analytic[i] - fd).abs() < 1e-2,
                "SE backward FD mismatch at {i}: analytic={} fd={}",
                analytic[i],
                fd
            );
        }
    }

    #[test]
    fn se_with_hardsigmoid_scale_smoke() {
        let se: SqueezeExcitation<f32> = SqueezeExcitation::new_with_activations(
            8,
            2,
            Box::new(ReLU::new()),
            Box::new(HardSigmoid::new()),
        )
        .unwrap();
        let x = cpu_tensor_4d(vec![0.1_f32; 8 * 6 * 6], [1, 8, 6, 6]);
        let out = se.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn se_with_silu_sigmoid_smoke() {
        let se: SqueezeExcitation<f32> = SqueezeExcitation::new_with_activations(
            16,
            4,
            Box::new(SiLU::new()),
            Box::new(Sigmoid::new()),
        )
        .unwrap();
        let x = cpu_tensor_4d(vec![0.05_f32; 16 * 4 * 4], [1, 16, 4, 4]);
        let out = se.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 16, 4, 4]);
    }

    #[test]
    fn se_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SqueezeExcitation<f32>>();
    }
}