// oxicuda_vision/blocks/mbconv.rs

//! EfficientNet MBConv (Mobile Inverted Bottleneck Convolution) block.
//!
//! Sandler et al. 2018, "MobileNetV2: Inverted Residuals and Linear Bottlenecks".
//! Tan & Le 2019, "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks".
//!
//! MBConv pipeline:
//! ```text
//! Expand  (1×1, in_ch → in_ch * expand_ratio, ReLU6)
//! DW-Conv (k×k, depthwise: one filter per channel, ReLU6)
//! SE      (squeeze-excite channel gate)
//! Project (1×1, in_ch * expand_ratio → out_ch, linear)
//! Skip    (if stride == 1 && in_ch == out_ch, add the input)
//! ```
//!
//! This implementation uses ReLU6 (as in MobileNetV2); the EfficientNet paper
//! itself uses the Swish (SiLU) activation.
//!
//! For CPU unit testing we use a 1-D "spatial" abstraction:
//! the input tensor is treated as `[batch_size × in_channels]` and the output
//! as `[batch_size × out_channels]`. This preserves all channel-wise
//! computations (expand, SE gate, project, skip) while avoiding the
//! full 2-D convolution, which is prohibitively expensive in pure Rust.
//! The depthwise convolution is approximated by scaling each channel by the
//! mean of its k×k filter, and `stride` only decides whether the skip
//! connection is used.

21use crate::error::{VisionError, VisionResult};
22use crate::handle::LcgRng;
23
24/// Type alias for the RNG used by this module.
25pub type VisionRng = LcgRng;
26
// ─── Activation helpers ───────────────────────────────────────────────────────

/// Rectified linear unit capped from above: returns `x` limited to `[0, 6]`.
///
/// NaN inputs are passed through unchanged (the behavior of [`f32::clamp`]).
#[inline]
fn relu6(x: f32) -> f32 {
    const CAP: f32 = 6.0;
    f32::clamp(x, 0.0, CAP)
}

/// Logistic sigmoid `1 / (1 + e^(-x))`, evaluated without overflowing `exp`.
///
/// Negative inputs use the algebraically equivalent `e^x / (1 + e^x)`, so the
/// argument passed to `exp` is never positive.
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x < 0.0 {
        let ex = x.exp();
        ex / (1.0 + ex)
    } else {
        (1.0 + (-x).exp()).recip()
    }
}

// ─── Configuration ────────────────────────────────────────────────────────────

/// Hyper-parameters describing a single MBConv block.
#[derive(Debug, Clone)]
pub struct MbConvConfig {
    /// Number of channels in each input sample.
    pub in_channels: usize,
    /// Number of channels in each output sample.
    pub out_channels: usize,
    /// Multiplier from `in_channels` to the hidden (expanded) width; must be ≥ 1.
    pub expand_ratio: usize,
    /// Spatial stride (1 = same spatial size, 2 = spatial downsampling).
    /// In the 1-D proxy it only matters as a precondition for the skip
    /// connection (`stride == 1`).
    pub stride: usize,
    /// Side length `k` of the `k × k` depthwise filter (3 or 5 in EfficientNet).
    pub kernel_size: usize,
    /// Squeeze-Excite width as a fraction of `in_channels` (e.g. 0.25).
    pub se_ratio: f32,
    /// Input spatial height. Not used by the 1-D forward pass; reported in
    /// construction errors.
    pub h: usize,
    /// Input spatial width. Not used by the 1-D forward pass; reported in
    /// construction errors.
    pub w: usize,
}

impl MbConvConfig {
    /// Hidden width after the 1×1 expansion: `in_channels * expand_ratio`.
    #[must_use]
    pub fn expanded_channels(&self) -> usize {
        let Self {
            in_channels,
            expand_ratio,
            ..
        } = self;
        in_channels * expand_ratio
    }

    /// Hidden width of the Squeeze-Excite bottleneck.
    ///
    /// Computed as `round(in_channels * se_ratio)` (halves round away from
    /// zero), raised to at least 1 so the SE layers are never empty.
    #[must_use]
    pub fn se_channels(&self) -> usize {
        let scaled = (self.in_channels as f32 * self.se_ratio).round();
        // Float-to-int `as` saturates (negative or NaN → 0); the max lifts that to 1.
        std::cmp::max(scaled as usize, 1)
    }

    /// `true` when the residual shortcut applies: the block must keep both the
    /// spatial resolution (`stride == 1`) and the channel count.
    #[must_use]
    pub fn has_skip(&self) -> bool {
        let same_width = self.in_channels == self.out_channels;
        same_width && self.stride == 1
    }
}

90// ─── MbConvBlock ─────────────────────────────────────────────────────────────
91
92/// EfficientNet MBConv inverted residual block.
93///
94/// Weight layout (row-major):
95/// - `expand_w`: `[exp_ch × in_ch]`  (1×1 expand conv, bias: `expand_b [exp_ch]`)
96/// - `dw_w`    : `[exp_ch × k × k]`  (depthwise weights, bias: `dw_b [exp_ch]`)
97/// - `se_fc1_w`: `[se_ch × exp_ch]`  (SE squeeze, bias: `se_fc1_b [se_ch]`)
98/// - `se_fc2_w`: `[exp_ch × se_ch]`  (SE excite,  bias: `se_fc2_b [exp_ch]`)
99/// - `proj_w`  : `[out_ch × exp_ch]` (1×1 project conv, bias: `proj_b [out_ch]`)
100pub struct MbConvBlock {
101    expand_w: Vec<f32>,
102    expand_b: Vec<f32>,
103    dw_w: Vec<f32>,
104    dw_b: Vec<f32>,
105    se_fc1_w: Vec<f32>,
106    se_fc1_b: Vec<f32>,
107    se_fc2_w: Vec<f32>,
108    se_fc2_b: Vec<f32>,
109    proj_w: Vec<f32>,
110    proj_b: Vec<f32>,
111    config: MbConvConfig,
112    has_skip: bool,
113}
114
115impl MbConvBlock {
116    /// Construct a new MBConv block with Xavier-uniform weight initialization.
117    ///
118    /// # Errors
119    ///
120    /// - [`VisionError::InvalidImageSize`] if `in_channels` or `out_channels` is 0.
121    /// - [`VisionError::InvalidEmbedDim`] if `expand_ratio` is 0.
122    /// - [`VisionError::NonPositiveTemperature`] if `se_ratio ≤ 0`.
123    pub fn new(config: MbConvConfig, rng: &mut VisionRng) -> VisionResult<Self> {
124        if config.in_channels == 0 || config.out_channels == 0 {
125            return Err(VisionError::InvalidImageSize {
126                height: config.h,
127                width: config.w,
128                channels: config.in_channels,
129            });
130        }
131        if config.expand_ratio == 0 {
132            return Err(VisionError::InvalidEmbedDim(0));
133        }
134        if config.se_ratio <= 0.0 {
135            return Err(VisionError::NonPositiveTemperature(config.se_ratio));
136        }
137
138        let in_ch = config.in_channels;
139        let exp_ch = config.expanded_channels();
140        let out_ch = config.out_channels;
141        let k = config.kernel_size;
142        let se_ch = config.se_channels();
143        let has_skip = config.has_skip();
144
145        let xavier = |fan_in: usize, fan_out: usize, rng: &mut VisionRng| -> Vec<f32> {
146            let limit = (6.0_f32 / (fan_in + fan_out) as f32).sqrt();
147            let n = fan_out * fan_in;
148            (0..n)
149                .map(|_| (rng.next_f32() * 2.0 - 1.0) * limit)
150                .collect()
151        };
152
153        // Expand: [exp_ch × in_ch]
154        let expand_w = xavier(in_ch, exp_ch, rng);
155        let expand_b = vec![0.0_f32; exp_ch];
156
157        // Depthwise: [exp_ch × k × k]
158        let dw_w = xavier(k * k, exp_ch, rng);
159        let dw_b = vec![0.0_f32; exp_ch];
160
161        // SE fc1: [se_ch × exp_ch]
162        let se_fc1_w = xavier(exp_ch, se_ch, rng);
163        let se_fc1_b = vec![0.0_f32; se_ch];
164
165        // SE fc2: [exp_ch × se_ch]
166        let se_fc2_w = xavier(se_ch, exp_ch, rng);
167        let se_fc2_b = vec![0.0_f32; exp_ch];
168
169        // Project: [out_ch × exp_ch]
170        let proj_w = xavier(exp_ch, out_ch, rng);
171        let proj_b = vec![0.0_f32; out_ch];
172
173        Ok(Self {
174            expand_w,
175            expand_b,
176            dw_w,
177            dw_b,
178            se_fc1_w,
179            se_fc1_b,
180            se_fc2_w,
181            se_fc2_b,
182            proj_w,
183            proj_b,
184            config,
185            has_skip,
186        })
187    }
188
189    /// Whether this block uses a skip (residual) connection.
190    #[must_use]
191    pub fn has_skip(&self) -> bool {
192        self.has_skip
193    }
194
195    /// Forward pass.
196    ///
197    /// # Input / output layout
198    ///
199    /// `x`: `[batch_size × in_channels]` row-major.
200    /// Returns: `[batch_size × out_channels]` row-major.
201    ///
202    /// ## Pipeline (per sample)
203    ///
204    /// 1. **Expand** : `h = ReLU6(W_exp · x + b_exp)`
205    ///    output: `[exp_ch]`
206    /// 2. **Depthwise** : channel-wise multiplication by `dw_w[c, :]` mean (1D proxy),
207    ///    then bias + ReLU6.
208    ///    Full 2-D depthwise conv would require H×W spatial input; here we use
209    ///    `dw_w[c]` as a per-channel scale (mean of k×k filter weights).
210    /// 3. **SE** : global pool → FC1+ReLU → FC2+sigmoid → broadcast multiply.
211    /// 4. **Project** : `out = W_proj · h_se + b_proj`  (no activation).
212    /// 5. **Skip** : if `has_skip`, `out += x`.
213    ///
214    /// # Errors
215    ///
216    /// - [`VisionError::DimensionMismatch`] if `x.len() != batch_size * in_channels`.
217    pub fn forward(&self, x: &[f32], batch_size: usize) -> VisionResult<Vec<f32>> {
218        let in_ch = self.config.in_channels;
219        let exp_ch = self.config.expanded_channels();
220        let out_ch = self.config.out_channels;
221        let se_ch = self.config.se_channels();
222        let k = self.config.kernel_size;
223
224        if x.len() != batch_size * in_ch {
225            return Err(VisionError::DimensionMismatch {
226                expected: batch_size * in_ch,
227                got: x.len(),
228            });
229        }
230
231        let mut out = vec![0.0_f32; batch_size * out_ch];
232
233        for b in 0..batch_size {
234            let x_row = &x[b * in_ch..(b + 1) * in_ch];
235
236            // ── 1. Expand ──────────────────────────────────────────────────
237            let h_exp: Vec<f32> = (0..exp_ch)
238                .map(|i| {
239                    let acc = self.expand_b[i]
240                        + x_row
241                            .iter()
242                            .enumerate()
243                            .map(|(j, &xj)| self.expand_w[i * in_ch + j] * xj)
244                            .sum::<f32>();
245                    relu6(acc)
246                })
247                .collect();
248
249            // ── 2. Depthwise (1-D proxy) ───────────────────────────────────
250            // Per-channel scale = mean of the k×k filter weights.
251            let h_dw: Vec<f32> = (0..exp_ch)
252                .map(|c| {
253                    let w_slice = &self.dw_w[c * k * k..(c + 1) * k * k];
254                    let w_mean: f32 = w_slice.iter().sum::<f32>() / (k * k) as f32;
255                    relu6(h_exp[c] * w_mean + self.dw_b[c])
256                })
257                .collect();
258
259            // ── 3. Squeeze-Excite ──────────────────────────────────────────
260            // 3a. Global avg pool (1-D: pool = h_dw directly).
261            let pooled = &h_dw;
262
263            // 3b. FC1: [se_ch × exp_ch], ReLU
264            let se_h1: Vec<f32> = (0..se_ch)
265                .map(|i| {
266                    let acc = self.se_fc1_b[i]
267                        + pooled
268                            .iter()
269                            .enumerate()
270                            .map(|(j, &pj)| self.se_fc1_w[i * exp_ch + j] * pj)
271                            .sum::<f32>();
272                    acc.max(0.0)
273                })
274                .collect();
275
276            // 3c. FC2: [exp_ch × se_ch], sigmoid
277            let se_gate: Vec<f32> = (0..exp_ch)
278                .map(|i| {
279                    let acc = self.se_fc2_b[i]
280                        + se_h1
281                            .iter()
282                            .enumerate()
283                            .map(|(j, &sj)| self.se_fc2_w[i * se_ch + j] * sj)
284                            .sum::<f32>();
285                    sigmoid(acc)
286                })
287                .collect();
288
289            // 3d. Channel-wise gate multiply.
290            let h_se: Vec<f32> = h_dw
291                .iter()
292                .zip(se_gate.iter())
293                .map(|(&hd, &sg)| hd * sg)
294                .collect();
295
296            // ── 4. Project ─────────────────────────────────────────────────
297            let mut y: Vec<f32> = (0..out_ch)
298                .map(|i| {
299                    self.proj_b[i]
300                        + h_se
301                            .iter()
302                            .enumerate()
303                            .map(|(j, &hj)| self.proj_w[i * exp_ch + j] * hj)
304                            .sum::<f32>()
305                    // no activation
306                })
307                .collect();
308
309            // ── 5. Skip connection ─────────────────────────────────────────
310            if self.has_skip {
311                for (yi, &xi) in y.iter_mut().zip(x_row.iter()) {
312                    *yi += xi;
313                }
314            }
315
316            out[b * out_ch..(b + 1) * out_ch].copy_from_slice(&y);
317        }
318
319        Ok(out)
320    }
321}
322
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    fn rng() -> LcgRng {
        LcgRng::new(42)
    }

    fn default_config() -> MbConvConfig {
        MbConvConfig {
            in_channels: 16,
            out_channels: 16,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 8,
            w: 8,
        }
    }

    /// Deterministic input of `batch * channels` values drawn from an RNG seeded with `seed`.
    fn make_input(batch: usize, channels: usize, seed: u64) -> Vec<f32> {
        let mut r = LcgRng::new(seed);
        (0..batch * channels).map(|_| r.next_f32()).collect()
    }

    // 1 ─ output shape
    #[test]
    fn output_shape() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(4, 16, 1);
        let out = block.forward(&x, 4).expect("forward should succeed");
        assert_eq!(out.len(), 4 * 16);
    }

    // 2 ─ output finite
    #[test]
    fn output_finite() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(2, 16, 2);
        let out = block.forward(&x, 2).expect("forward should succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.is_finite(), "out[{i}] = {v}");
        }
    }

    // 3 ─ expand_ratio=1 works
    #[test]
    fn expand_ratio_1_works() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 1,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(3, 8, 3);
        let out = block.forward(&x, 3).expect("forward should succeed");
        assert_eq!(out.len(), 3 * 8);
    }

    // 4 ─ has_skip correct (stride=1, same channels → skip)
    #[test]
    fn has_skip_correct_same_channels() {
        let cfg = default_config(); // stride=1, in=out=16
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(block.has_skip());
    }

    // 5 ─ no skip when in_channels != out_channels
    #[test]
    fn no_skip_different_channels() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(!block.has_skip());
    }

    // 6 ─ relu6 clamps to [0, 6]
    #[test]
    fn relu6_clamps_at_6() {
        assert!((relu6(10.0) - 6.0).abs() < 1e-7);
        assert!((relu6(-1.0) - 0.0).abs() < 1e-7);
        assert!((relu6(3.0) - 3.0).abs() < 1e-7);
    }

    // 7 ─ batch_size varies
    #[test]
    fn batch_size_varies() {
        let cfg = default_config();
        for &bs in &[1_usize, 2, 8] {
            let mut r = LcgRng::new(bs as u64);
            let block = MbConvBlock::new(cfg.clone(), &mut r).expect("new should succeed");
            let x = make_input(bs, 16, bs as u64);
            let out = block.forward(&x, bs).expect("forward should succeed");
            assert_eq!(out.len(), bs * 16);
        }
    }

    // 8 ─ stride=2 config accepted (output shape unchanged in 1-D abstraction)
    #[test]
    fn stride_2_config_accepted() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            expand_ratio: 6,
            stride: 2,
            kernel_size: 5,
            se_ratio: 0.25,
            h: 8,
            w: 8,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(!block.has_skip()); // stride=2 → no skip
        let x = make_input(2, 8, 9);
        let out = block.forward(&x, 2).expect("forward should succeed");
        assert_eq!(out.len(), 2 * 16);
    }

    // 9 ─ expand_ratio=0 returns error
    #[test]
    fn expand_ratio_0_error() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 0,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let result = MbConvBlock::new(cfg, &mut r);
        assert!(result.is_err());
    }

    // 10 ─ se_ratio=0 returns error
    #[test]
    fn se_ratio_zero_error() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.0,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let result = MbConvBlock::new(cfg, &mut r);
        assert!(result.is_err());
    }

    // 11 ─ se_ratio changes the SE hidden width
    #[test]
    fn se_ratio_affects_se_channels() {
        let cfg1 = MbConvConfig {
            se_ratio: 0.25,
            ..default_config()
        };
        let cfg2 = MbConvConfig {
            se_ratio: 0.5,
            ..default_config()
        };
        assert!(cfg1.se_channels() < cfg2.se_channels());
    }

    // 12 ─ dimension mismatch error
    #[test]
    fn dimension_mismatch_error() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let wrong_x = vec![0.0_f32; 2 * 8]; // wrong: 8 channels not 16
        let result = block.forward(&wrong_x, 2);
        assert!(result.is_err());
    }

    // 13 ─ sigmoid midpoint and symmetry σ(x) + σ(−x) = 1
    #[test]
    fn sigmoid_midpoint_and_symmetry() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-7);
        for &x in &[0.5_f32, 2.0, 10.0, 50.0] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-6, "x = {x}");
        }
        assert!(sigmoid(-200.0).is_finite());
        assert!(sigmoid(200.0).is_finite());
    }

    // 14 ─ with a zeroed projection, a skip block outputs its input unchanged
    #[test]
    fn skip_adds_input_when_projection_is_zero() {
        let mut block = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        block.proj_w.fill(0.0);
        block.proj_b.fill(0.0);
        let x = make_input(3, 16, 11);
        let out = block.forward(&x, 3).expect("forward should succeed");
        for (i, (&o, &xi)) in out.iter().zip(&x).enumerate() {
            assert!((o - xi).abs() < 1e-7, "out[{i}] = {o}, expected {xi}");
        }
    }

    // 15 ─ without a skip, a zeroed projection leaves only the projection bias
    #[test]
    fn no_skip_output_is_bias_when_projection_weights_zero() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            ..default_config()
        };
        let mut block = MbConvBlock::new(cfg, &mut rng()).expect("new should succeed");
        block.proj_w.fill(0.0);
        block.proj_b.fill(0.5);
        let x = make_input(2, 8, 12);
        let out = block.forward(&x, 2).expect("forward should succeed");
        assert_eq!(out.len(), 2 * 16);
        for (i, &o) in out.iter().enumerate() {
            assert!((o - 0.5).abs() < 1e-7, "out[{i}] = {o}");
        }
    }

    // 16 ─ each sample in a batch is processed independently
    #[test]
    fn samples_are_processed_independently() {
        let block = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let x = make_input(3, 16, 21);
        let batched = block.forward(&x, 3).expect("forward should succeed");
        for (b, row) in x.chunks_exact(16).enumerate() {
            let single = block.forward(row, 1).expect("forward should succeed");
            assert_eq!(&batched[b * 16..(b + 1) * 16], single.as_slice(), "sample {b}");
        }
    }

    // 17 ─ the same seed yields the same block
    #[test]
    fn same_seed_same_output() {
        let a = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let b = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let x = make_input(2, 16, 5);
        let out_a = a.forward(&x, 2).expect("forward should succeed");
        let out_b = b.forward(&x, 2).expect("forward should succeed");
        assert_eq!(out_a, out_b);
    }

    // 18 ─ an empty batch yields an empty output
    #[test]
    fn empty_batch_returns_empty_output() {
        let block = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let out = block.forward(&[], 0).expect("forward should succeed");
        assert!(out.is_empty());
    }

    // 19 ─ zero input channels are rejected
    #[test]
    fn zero_channels_error() {
        let cfg = MbConvConfig {
            in_channels: 0,
            ..default_config()
        };
        assert!(MbConvBlock::new(cfg, &mut rng()).is_err());
    }
}