use crate::error::{VisionError, VisionResult};
use crate::handle::LcgRng;
/// Random-number generator used to draw initial weights in [`MbConvBlock::new`]
/// (an alias of [`LcgRng`]).
pub type VisionRng = LcgRng;
/// ReLU6 activation: limits `x` to the closed interval `[0, 6]`.
///
/// NaN inputs are returned unchanged (the `f32::clamp` contract).
#[inline]
fn relu6(x: f32) -> f32 {
    const FLOOR: f32 = 0.0;
    const CEILING: f32 = 6.0;
    f32::clamp(x, FLOOR, CEILING)
}
/// Logistic sigmoid `1 / (1 + e^{-x})`, evaluated without overflow.
///
/// The exponent passed to `exp` is always `-|x|`, so it can underflow towards zero
/// but never overflow to infinity. Non-negative inputs use `1 / (1 + e^{-x})`;
/// negative inputs (and NaN) use the equivalent `e^{x} / (1 + e^{x})`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let non_negative = x >= 0.0;
    let exponent = if non_negative { -x } else { x };
    let e = exponent.exp();
    if non_negative {
        1.0 / (1.0 + e)
    } else {
        e / (1.0 + e)
    }
}
/// Hyper-parameters describing one [`MbConvBlock`].
#[derive(Debug, Clone)]
pub struct MbConvConfig {
    /// Number of channels in each input row.
    pub in_channels: usize,
    /// Number of channels in each output row.
    pub out_channels: usize,
    /// Multiplier from `in_channels` to the width of the expanded hidden layer.
    pub expand_ratio: usize,
    /// Stride; the residual connection is only used when this equals 1.
    pub stride: usize,
    /// Side length of the depthwise kernel (`kernel_size * kernel_size` weights per
    /// expanded channel).
    pub kernel_size: usize,
    /// Squeeze-and-excitation ratio, applied to `in_channels` (see [`Self::se_channels`]).
    pub se_ratio: f32,
    /// Nominal input height. NOTE(review): within this module it is only echoed back in
    /// `InvalidImageSize` errors, not used by the forward pass.
    pub h: usize,
    /// Nominal input width. NOTE(review): same caveat as `h`.
    pub w: usize,
}

impl MbConvConfig {
    /// Width of the expanded hidden layer, `in_channels * expand_ratio`.
    #[must_use]
    pub fn expanded_channels(&self) -> usize {
        self.expand_ratio * self.in_channels
    }

    /// Width of the squeeze-and-excitation bottleneck.
    ///
    /// This is `in_channels * se_ratio` rounded to the nearest integer, but never less
    /// than one.
    #[must_use]
    pub fn se_channels(&self) -> usize {
        let scaled = (self.se_ratio * self.in_channels as f32).round() as usize;
        if scaled == 0 {
            1
        } else {
            scaled
        }
    }

    /// Whether the block input can be added to its output: this requires matching
    /// channel counts and a stride of 1.
    #[must_use]
    pub fn has_skip(&self) -> bool {
        self.in_channels == self.out_channels && self.stride == 1
    }
}
/// A simplified MBConv (inverted-residual) block operating on per-channel values.
///
/// The stages are: 1×1 expansion, a depthwise stage, squeeze-and-excitation, and a
/// 1×1 projection, with an optional residual add. Weight matrices are stored as flat
/// row-major `Vec<f32>`s. Shapes below are `(rows, cols)` in terms of the channel
/// counts derived from [`MbConvConfig`].
#[derive(Debug, Clone)]
pub struct MbConvBlock {
    /// Expansion weights, `(expanded_channels, in_channels)`.
    expand_w: Vec<f32>,
    /// Expansion bias, one entry per expanded channel.
    expand_b: Vec<f32>,
    /// Depthwise weights, `kernel_size * kernel_size` consecutive values per expanded channel.
    dw_w: Vec<f32>,
    /// Depthwise bias, one entry per expanded channel.
    dw_b: Vec<f32>,
    /// SE reduction weights, `(se_channels, expanded_channels)`.
    se_fc1_w: Vec<f32>,
    /// SE reduction bias, one entry per SE channel.
    se_fc1_b: Vec<f32>,
    /// SE expansion weights, `(expanded_channels, se_channels)`.
    se_fc2_w: Vec<f32>,
    /// SE expansion bias, one entry per expanded channel.
    se_fc2_b: Vec<f32>,
    /// Projection weights, `(out_channels, expanded_channels)`.
    proj_w: Vec<f32>,
    /// Projection bias, one entry per output channel.
    proj_b: Vec<f32>,
    /// Configuration the weights were created from.
    config: MbConvConfig,
    /// Cached `config.has_skip()`: whether the input row is added to the output row.
    has_skip: bool,
}
impl MbConvBlock {
    /// Builds a block with Xavier/Glorot-uniform weights drawn from `rng` and
    /// zero-initialised biases.
    ///
    /// # Errors
    ///
    /// * [`VisionError::InvalidImageSize`] if `in_channels` or `out_channels` is zero.
    /// * [`VisionError::InvalidEmbedDim`] if `expand_ratio` or `kernel_size` is zero.
    /// * [`VisionError::NonPositiveTemperature`] if `se_ratio` is not a finite number
    ///   greater than zero (NaN and infinities are rejected).
    pub fn new(config: MbConvConfig, rng: &mut VisionRng) -> VisionResult<Self> {
        if config.in_channels == 0 || config.out_channels == 0 {
            return Err(VisionError::InvalidImageSize {
                height: config.h,
                width: config.w,
                channels: config.in_channels,
            });
        }
        if config.expand_ratio == 0 {
            return Err(VisionError::InvalidEmbedDim(0));
        }
        // A zero-sized kernel would make the per-channel kernel mean in `forward` 0/0,
        // i.e. NaN. No dedicated error variant is visible here, so this mirrors the
        // `expand_ratio` check above.
        if config.kernel_size == 0 {
            return Err(VisionError::InvalidEmbedDim(0));
        }
        // Written as a negated positive test so NaN is rejected (`NaN <= 0.0` is false).
        // Infinity is rejected because `se_channels()` would saturate to `usize::MAX`,
        // making the SE weight allocation overflow.
        if !(config.se_ratio.is_finite() && config.se_ratio > 0.0) {
            return Err(VisionError::NonPositiveTemperature(config.se_ratio));
        }
        let in_ch = config.in_channels;
        let exp_ch = config.expanded_channels();
        let out_ch = config.out_channels;
        let k = config.kernel_size;
        let se_ch = config.se_channels();
        let has_skip = config.has_skip();
        // Row-major (fan_out, fan_in) matrix with entries `(u * 2 - 1) * limit`, where
        // `u = rng.next_f32()` (presumably in [0, 1)) and
        // `limit = sqrt(6 / (fan_in + fan_out))`.
        let xavier = |fan_in: usize, fan_out: usize, rng: &mut VisionRng| -> Vec<f32> {
            let limit = (6.0_f32 / (fan_in + fan_out) as f32).sqrt();
            let n = fan_out * fan_in;
            (0..n)
                .map(|_| (rng.next_f32() * 2.0 - 1.0) * limit)
                .collect()
        };
        let expand_w = xavier(in_ch, exp_ch, rng);
        let expand_b = vec![0.0_f32; exp_ch];
        let dw_w = xavier(k * k, exp_ch, rng);
        let dw_b = vec![0.0_f32; exp_ch];
        let se_fc1_w = xavier(exp_ch, se_ch, rng);
        let se_fc1_b = vec![0.0_f32; se_ch];
        let se_fc2_w = xavier(se_ch, exp_ch, rng);
        let se_fc2_b = vec![0.0_f32; exp_ch];
        let proj_w = xavier(exp_ch, out_ch, rng);
        let proj_b = vec![0.0_f32; out_ch];
        Ok(Self {
            expand_w,
            expand_b,
            dw_w,
            dw_b,
            se_fc1_w,
            se_fc1_b,
            se_fc2_w,
            se_fc2_b,
            proj_w,
            proj_b,
            config,
            has_skip,
        })
    }

    /// Whether `forward` adds the input row to the output row (residual connection).
    #[must_use]
    pub fn has_skip(&self) -> bool {
        self.has_skip
    }

    /// Runs the block on a batch of channel vectors.
    ///
    /// `x` is read as `batch_size` consecutive rows of `in_channels` values. The result
    /// holds `batch_size` consecutive rows of `out_channels` values.
    ///
    /// Spatial size (`h`, `w`) is not modelled: each channel carries a single value.
    /// The depthwise stage therefore scales it by the mean of that channel's
    /// `kernel_size * kernel_size` weights.
    ///
    /// Per row, the stages are:
    /// 1. 1×1 expansion + ReLU6.
    /// 2. Depthwise scale + bias + ReLU6.
    /// 3. Squeeze-and-excitation gate (dense + ReLU, then dense + sigmoid), multiplied in.
    /// 4. 1×1 projection.
    /// 5. The input row added on when [`Self::has_skip`] is true.
    ///
    /// # Errors
    ///
    /// [`VisionError::DimensionMismatch`] if `x.len() != batch_size * in_channels`.
    pub fn forward(&self, x: &[f32], batch_size: usize) -> VisionResult<Vec<f32>> {
        let in_ch = self.config.in_channels;
        let exp_ch = self.config.expanded_channels();
        let out_ch = self.config.out_channels;
        let se_ch = self.config.se_channels();
        let k2 = self.config.kernel_size * self.config.kernel_size;

        // Saturating so an overflowing product can never wrap around and spuriously
        // match `x.len()`. A slice of `f32` can never hold `usize::MAX` elements.
        let expected = batch_size.saturating_mul(in_ch);
        if x.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: x.len(),
            });
        }

        // The depthwise kernels do not depend on the input, so their per-channel means
        // are computed once rather than once per batch row.
        let dw_scale: Vec<f32> = self
            .dw_w
            .chunks_exact(k2)
            .map(|w| w.iter().sum::<f32>() / k2 as f32)
            .collect();

        // Scratch buffers reused for every batch row.
        let mut hidden = vec![0.0_f32; exp_ch];
        let mut se_hidden = vec![0.0_f32; se_ch];
        let mut gate = vec![0.0_f32; exp_ch];

        // Saturating: an impossible output size panics in the allocator ("capacity
        // overflow") instead of wrapping to a silently truncated buffer.
        let mut out = vec![0.0_f32; batch_size.saturating_mul(out_ch)];
        // Channel counts are all >= 1 (validated in `new`), so the chunk sizes are
        // non-zero and both iterators yield exactly `batch_size` rows.
        for (x_row, y) in x.chunks_exact(in_ch).zip(out.chunks_exact_mut(out_ch)) {
            // 1×1 expansion followed, per channel, by the collapsed depthwise stage.
            Self::affine_into(&self.expand_w, &self.expand_b, x_row, &mut hidden);
            for ((h, &scale), &bias) in hidden.iter_mut().zip(&dw_scale).zip(&self.dw_b) {
                let expanded = relu6(*h);
                *h = relu6(expanded * scale + bias);
            }

            // Squeeze-and-excitation: the depthwise output is the "pooled" descriptor.
            Self::affine_into(&self.se_fc1_w, &self.se_fc1_b, &hidden, &mut se_hidden);
            for s in &mut se_hidden {
                *s = s.max(0.0);
            }
            Self::affine_into(&self.se_fc2_w, &self.se_fc2_b, &se_hidden, &mut gate);
            for (h, &g) in hidden.iter_mut().zip(&gate) {
                *h *= sigmoid(g);
            }

            // 1×1 projection straight into this row of the output, then the residual.
            Self::affine_into(&self.proj_w, &self.proj_b, &hidden, y);
            if self.has_skip {
                for (yi, &xi) in y.iter_mut().zip(x_row) {
                    *yi += xi;
                }
            }
        }
        Ok(out)
    }

    /// Dense layer: `out[i] = bias[i] + Σ_j weights[i * x.len() + j] * x[j]`.
    ///
    /// `weights` is row-major with one row of `x.len()` values per output element.
    /// `x` must be non-empty. The products are summed in index order, then added to
    /// the bias.
    fn affine_into(weights: &[f32], bias: &[f32], x: &[f32], out: &mut [f32]) {
        debug_assert_eq!(out.len(), bias.len());
        debug_assert_eq!(weights.len(), bias.len() * x.len());
        let rows = weights.chunks_exact(x.len());
        for ((o, &b), row) in out.iter_mut().zip(bias).zip(rows) {
            *o = b + row.iter().zip(x).map(|(&w, &xj)| w * xj).sum::<f32>();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Deterministic RNG used to build blocks in tests.
    fn rng() -> LcgRng {
        LcgRng::new(42)
    }

    /// A same-width, stride-1 configuration, so the residual connection is active.
    fn default_config() -> MbConvConfig {
        MbConvConfig {
            in_channels: 16,
            out_channels: 16,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 8,
            w: 8,
        }
    }

    /// `batch * channels` pseudo-random values from a seeded RNG.
    fn make_input(batch: usize, channels: usize, seed: u64) -> Vec<f32> {
        let mut r = LcgRng::new(seed);
        (0..batch * channels).map(|_| r.next_f32()).collect()
    }

    #[test]
    fn output_shape() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(4, 16, 1);
        let out = block.forward(&x, 4).expect("forward should succeed");
        assert_eq!(out.len(), 4 * 16);
    }

    #[test]
    fn output_finite() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(2, 16, 2);
        let out = block.forward(&x, 2).expect("forward should succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.is_finite(), "out[{i}] = {v}");
        }
    }

    #[test]
    fn expand_ratio_1_works() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 1,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(3, 8, 3);
        let out = block.forward(&x, 3).expect("forward should succeed");
        assert_eq!(out.len(), 3 * 8);
    }

    #[test]
    fn has_skip_correct_same_channels() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(block.has_skip());
    }

    #[test]
    fn no_skip_different_channels() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(!block.has_skip());
    }

    #[test]
    fn relu6_clamps_at_6() {
        assert!((relu6(10.0) - 6.0).abs() < 1e-7);
        assert!((relu6(-1.0) - 0.0).abs() < 1e-7);
        assert!((relu6(3.0) - 3.0).abs() < 1e-7);
    }

    #[test]
    fn sigmoid_is_stable_at_extremes() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-7);
        let hi = sigmoid(100.0);
        let lo = sigmoid(-100.0);
        assert!(hi.is_finite() && (hi - 1.0).abs() < 1e-6, "sigmoid(100) = {hi}");
        assert!(lo.is_finite() && (0.0..1e-6).contains(&lo), "sigmoid(-100) = {lo}");
    }

    #[test]
    fn batch_size_varies() {
        let cfg = default_config();
        for &bs in &[1_usize, 2, 8] {
            let mut r = LcgRng::new(bs as u64);
            let block = MbConvBlock::new(cfg.clone(), &mut r).expect("new should succeed");
            let x = make_input(bs, 16, bs as u64);
            let out = block.forward(&x, bs).expect("forward should succeed");
            assert_eq!(out.len(), bs * 16);
        }
    }

    #[test]
    fn zero_batch_returns_empty_output() {
        let block = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let out = block.forward(&[], 0).expect("forward should succeed");
        assert!(out.is_empty());
    }

    #[test]
    fn forward_is_deterministic_for_same_seed() {
        let a = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let b = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let x = make_input(3, 16, 7);
        let out_a = a.forward(&x, 3).expect("forward should succeed");
        let out_b = b.forward(&x, 3).expect("forward should succeed");
        assert_eq!(out_a, out_b);
    }

    #[test]
    fn identical_rows_give_identical_outputs() {
        let block = MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        let row = make_input(1, 16, 11);
        let x: Vec<f32> = row.iter().chain(row.iter()).copied().collect();
        let out = block.forward(&x, 2).expect("forward should succeed");
        assert_eq!(&out[..16], &out[16..]);
    }

    #[test]
    fn stride_2_config_accepted() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            expand_ratio: 6,
            stride: 2,
            kernel_size: 5,
            se_ratio: 0.25,
            h: 8,
            w: 8,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(!block.has_skip());
        let x = make_input(2, 8, 9);
        let out = block.forward(&x, 2).expect("forward should succeed");
        assert_eq!(out.len(), 2 * 16);
    }

    #[test]
    fn zero_channels_error() {
        let cfg = MbConvConfig {
            in_channels: 0,
            ..default_config()
        };
        assert!(MbConvBlock::new(cfg, &mut rng()).is_err());
        let cfg = MbConvConfig {
            out_channels: 0,
            ..default_config()
        };
        assert!(MbConvBlock::new(cfg, &mut rng()).is_err());
    }

    #[test]
    fn expand_ratio_0_error() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 0,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let result = MbConvBlock::new(cfg, &mut r);
        assert!(result.is_err());
    }

    #[test]
    fn se_ratio_zero_error() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.0,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let result = MbConvBlock::new(cfg, &mut r);
        assert!(result.is_err());
    }

    #[test]
    fn se_ratio_affects_se_channels() {
        let cfg1 = MbConvConfig {
            se_ratio: 0.25,
            ..default_config()
        };
        let cfg2 = MbConvConfig {
            se_ratio: 0.5,
            ..default_config()
        };
        assert!(cfg1.se_channels() < cfg2.se_channels());
    }

    #[test]
    fn dimension_mismatch_error() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let wrong_x = vec![0.0_f32; 2 * 8];
        let result = block.forward(&wrong_x, 2);
        assert!(result.is_err());
    }
}