use crate::error::{NeuralError, Result};
use serde::{Deserialize, Serialize};
/// Configuration describing a MobileNet model variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileNetConfig {
/// Channel-width scaling factor (the "alpha" multiplier); 1.0 is the full-size network.
pub width_multiplier: f64,
/// Square input image edge length in pixels (e.g. 224).
pub input_resolution: usize,
/// Number of classifier output classes.
pub num_classes: usize,
/// Which MobileNet generation this configuration targets.
pub version: MobileNetVersion,
/// Dropout probability. NOTE(review): stored for model construction; no layer in
/// this file consumes it — confirm usage at the call site.
pub dropout_rate: f64,
/// Whether batch normalization is enabled. NOTE(review): also not consumed by the
/// forward paths in this file.
pub use_batch_norm: bool,
}
/// Identifies which MobileNet architecture generation a configuration targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MobileNetVersion {
/// Original depthwise-separable architecture.
V1,
/// Inverted residuals with linear bottlenecks (see `MobileNetV2Block`).
V2,
/// V3 small variant.
V3Small,
/// V3 large variant.
V3Large,
}
impl MobileNetConfig {
    /// Standard MobileNetV1 preset: full width, 224x224 input, 1000 classes.
    ///
    /// The 0.001 dropout rate matches the reference TF-Slim implementation's
    /// `dropout_keep_prob = 0.999`.
    pub fn mobilenet_v1() -> Self {
        Self {
            width_multiplier: 1.0,
            input_resolution: 224,
            num_classes: 1000,
            version: MobileNetVersion::V1,
            dropout_rate: 0.001,
            use_batch_norm: true,
        }
    }

    /// Standard MobileNetV2 preset: full width, 224x224 input, 1000 classes.
    pub fn mobilenet_v2() -> Self {
        Self {
            width_multiplier: 1.0,
            input_resolution: 224,
            num_classes: 1000,
            version: MobileNetVersion::V2,
            dropout_rate: 0.2,
            use_batch_norm: true,
        }
    }

    /// Tiny V2-style preset for heavily constrained devices:
    /// 0.25x width, 128x128 input, 10 classes, no dropout.
    pub fn mobile_lite() -> Self {
        Self {
            width_multiplier: 0.25,
            input_resolution: 128,
            num_classes: 10,
            version: MobileNetVersion::V2,
            dropout_rate: 0.0,
            use_batch_norm: true,
        }
    }

    /// Scales a base channel count by `width_multiplier`, rounding to the
    /// nearest integer.
    ///
    /// A non-zero `base` is floored at 1 so that small multipliers (e.g.
    /// 0.25 applied to a narrow layer) can never produce a degenerate
    /// zero-channel layer; `base == 0` still maps to 0.
    pub fn scaled_channels(&self, base: usize) -> usize {
        if base == 0 {
            return 0;
        }
        (((base as f64) * self.width_multiplier).round() as usize).max(1)
    }

    /// Very rough parameter-count estimate: one depthwise-separable stage at
    /// the scaled 32-channel stem width, multiplied by four.
    /// NOTE(review): the x4 factor looks like a crude stand-in for repeated
    /// stages — intended for coarse budgeting only, not an exact model size.
    pub fn estimated_param_count(&self) -> usize {
        let c = self.scaled_channels(32);
        let dw_params = 3 * 3 * c;
        let pw_params = c * (c * 2);
        (dw_params + pw_params) * 4
    }
}
/// Depthwise-separable convolution: a per-channel (depthwise) spatial filter
/// followed by a 1x1 (pointwise) channel mixer with bias, as in MobileNetV1.
#[derive(Debug, Clone)]
pub struct DepthwiseSeparableConv {
/// Number of input channels expected by `forward`.
in_ch: usize,
/// Number of output channels produced by the pointwise stage.
out_ch: usize,
/// Depthwise kernel size as (height, width).
kernel_size: (usize, usize),
/// Depthwise filter weights, indexed as `[c * kh * kw + ki * kw + kj]`.
dw_weights: Vec<f32>,
/// Pointwise weights, indexed as `[oc * in_ch + ic]`.
pw_weights: Vec<f32>,
/// One bias per output channel, applied in the pointwise stage.
bias: Vec<f32>,
}
impl DepthwiseSeparableConv {
    /// Builds a depthwise-separable convolution with He-style pseudo-random
    /// weights and zero biases.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if a channel count or a kernel dimension
    /// is zero.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Result<Self> {
        if in_channels == 0 || out_channels == 0 {
            return Err(NeuralError::InvalidArgument(
                "DepthwiseSeparableConv: channel counts must be > 0".to_string(),
            ));
        }
        let (kh, kw) = kernel_size;
        // Reject degenerate kernels up front: a zero dimension would produce
        // empty weight buffers here and nonsensical output shapes in
        // forward().
        if kh == 0 || kw == 0 {
            return Err(NeuralError::InvalidArgument(
                "DepthwiseSeparableConv: kernel dimensions must be > 0".to_string(),
            ));
        }
        let dw_size = in_channels * kh * kw;
        let pw_size = out_channels * in_channels;
        // He-init scales: the depthwise filter's fan-in is kh*kw, the
        // pointwise filter's fan-in is in_channels.
        let dw_scale = (2.0_f32 / (kh * kw) as f32).sqrt();
        let pw_scale = (2.0_f32 / in_channels as f32).sqrt();
        let dw_weights = pseudo_random_weights(dw_size, dw_scale, 1);
        let pw_weights = pseudo_random_weights(pw_size, pw_scale, 2);
        let bias = vec![0.0_f32; out_channels];
        Ok(Self {
            in_ch: in_channels,
            out_ch: out_channels,
            kernel_size,
            dw_weights,
            pw_weights,
            bias,
        })
    }

    /// Number of input channels this layer expects.
    pub fn in_channels(&self) -> usize {
        self.in_ch
    }

    /// Number of output channels this layer produces.
    pub fn out_channels(&self) -> usize {
        self.out_ch
    }

    /// Depthwise kernel size as (height, width).
    pub fn kernel_size(&self) -> (usize, usize) {
        self.kernel_size
    }

    /// Total learnable parameter count (depthwise + pointwise + bias).
    pub fn parameter_count(&self) -> usize {
        self.dw_weights.len() + self.pw_weights.len() + self.bias.len()
    }

    /// Applies the depthwise stage (stride 1, zero padding of `k/2`, ReLU6)
    /// followed by the pointwise stage (1x1 conv + bias, ReLU6).
    ///
    /// `input` is NCHW data flattened to a slice; `input_shape` is
    /// `[batch, channels, height, width]`. For odd kernel sizes the spatial
    /// size is preserved ("same" padding). Returns the flattened output and
    /// its shape.
    ///
    /// # Errors
    /// `ShapeMismatch` if the channel count or the slice length disagrees
    /// with `input_shape`.
    pub fn forward(
        &self,
        input: &[f32],
        input_shape: [usize; 4],
    ) -> Result<(Vec<f32>, [usize; 4])> {
        let [batch, in_ch, h, w] = input_shape;
        if in_ch != self.in_ch {
            return Err(NeuralError::ShapeMismatch(format!(
                "DepthwiseSeparableConv: expected in_ch={}, got {}",
                self.in_ch, in_ch
            )));
        }
        if input.len() != batch * in_ch * h * w {
            return Err(NeuralError::ShapeMismatch(
                "DepthwiseSeparableConv: input slice length mismatch".to_string(),
            ));
        }
        let (kh, kw) = self.kernel_size;
        let padding = (kh / 2, kw / 2);
        let h_out = (h + 2 * padding.0).saturating_sub(kh) + 1;
        let w_out = (w + 2 * padding.1).saturating_sub(kw) + 1;
        // Depthwise stage: each channel is convolved with its own kh x kw
        // filter; no cross-channel mixing happens here.
        let dw_size = batch * in_ch * h_out * w_out;
        let mut dw_out = vec![0.0_f32; dw_size];
        for b in 0..batch {
            for c in 0..in_ch {
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        let mut acc = 0.0_f32;
                        for ki in 0..kh {
                            for kj in 0..kw {
                                let ih = oh + ki;
                                let iw = ow + kj;
                                // wrapping_sub makes out-of-bounds (padded)
                                // taps wrap to huge values that fail the
                                // `< h` / `< w` checks below — this
                                // implements zero padding without signed
                                // arithmetic.
                                let ih_src = ih.wrapping_sub(padding.0);
                                let iw_src = iw.wrapping_sub(padding.1);
                                if ih_src < h && iw_src < w {
                                    let in_idx =
                                        b * in_ch * h * w + c * h * w + ih_src * w + iw_src;
                                    let w_idx = c * kh * kw + ki * kw + kj;
                                    acc += input[in_idx] * self.dw_weights[w_idx];
                                }
                            }
                        }
                        let idx = b * in_ch * h_out * w_out + c * h_out * w_out + oh * w_out + ow;
                        // ReLU6 activation.
                        dw_out[idx] = acc.clamp(0.0, 6.0);
                    }
                }
            }
        }
        // Pointwise stage: a 1x1 convolution mixes channels at each pixel,
        // adds the bias, and applies ReLU6.
        let pw_size = batch * self.out_ch * h_out * w_out;
        let mut pw_out = vec![0.0_f32; pw_size];
        for b in 0..batch {
            for oc in 0..self.out_ch {
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        let mut acc = self.bias[oc];
                        for ic in 0..in_ch {
                            let dw_idx =
                                b * in_ch * h_out * w_out + ic * h_out * w_out + oh * w_out + ow;
                            let pw_idx = oc * in_ch + ic;
                            acc += dw_out[dw_idx] * self.pw_weights[pw_idx];
                        }
                        let out_idx =
                            b * self.out_ch * h_out * w_out + oc * h_out * w_out + oh * w_out + ow;
                        pw_out[out_idx] = acc.clamp(0.0, 6.0);
                    }
                }
            }
        }
        Ok((pw_out, [batch, self.out_ch, h_out, w_out]))
    }
}
/// MobileNetV2 inverted-residual block: optional 1x1 expansion (ReLU6),
/// 3x3 depthwise convolution (ReLU6), linear 1x1 projection, plus an
/// identity skip connection when stride is 1 and channel counts match.
#[derive(Debug, Clone)]
pub struct MobileNetV2Block {
/// Number of input channels expected by `forward`.
in_ch: usize,
/// Number of output channels after the projection.
out_ch: usize,
/// Expansion factor; 1 means no expansion layer is created.
expansion: usize,
/// Spatial stride of the depthwise stage.
stride: usize,
/// 1x1 expansion conv; `None` when `expansion == 1`.
expand_pw: Option<PointwiseConv>,
/// Depthwise stage (only its `dw_weights`/`kernel_size` are used in forward).
dw: DepthwiseSeparableConv,
/// Linear 1x1 projection back down to `out_ch`.
project_pw: PointwiseConv,
/// True when stride == 1 and in_ch == out_ch (identity skip is added).
use_residual: bool,
}
impl MobileNetV2Block {
    /// Builds an inverted-residual block.
    ///
    /// The 1x1 expansion layer is only created when `expansion_factor != 1`;
    /// the identity skip connection is enabled when `stride == 1` and the
    /// channel counts match.
    ///
    /// # Errors
    /// `InvalidArgument` if a channel count, `expansion_factor`, or `stride`
    /// is zero.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        expansion_factor: usize,
        stride: usize,
    ) -> Result<Self> {
        if in_channels == 0 || out_channels == 0 {
            return Err(NeuralError::InvalidArgument(
                "MobileNetV2Block: channel counts must be > 0".to_string(),
            ));
        }
        // A zero expansion factor would create a degenerate zero-channel
        // depthwise stage; fail here with a clear message instead of a
        // confusing downstream error.
        if expansion_factor == 0 {
            return Err(NeuralError::InvalidArgument(
                "MobileNetV2Block: expansion_factor must be >= 1".to_string(),
            ));
        }
        if stride == 0 {
            return Err(NeuralError::InvalidArgument(
                "MobileNetV2Block: stride must be >= 1".to_string(),
            ));
        }
        let expanded_ch = in_channels * expansion_factor;
        let expand_pw = if expansion_factor != 1 {
            Some(PointwiseConv::new(in_channels, expanded_ch)?)
        } else {
            None
        };
        let dw = DepthwiseSeparableConv::new(expanded_ch, expanded_ch, (3, 3))?;
        let project_pw = PointwiseConv::new(expanded_ch, out_channels)?;
        let use_residual = stride == 1 && in_channels == out_channels;
        Ok(Self {
            in_ch: in_channels,
            out_ch: out_channels,
            expansion: expansion_factor,
            stride,
            expand_pw,
            dw,
            project_pw,
            use_residual,
        })
    }

    /// Number of input channels this block expects.
    pub fn in_channels(&self) -> usize {
        self.in_ch
    }

    /// Number of output channels this block produces.
    pub fn out_channels(&self) -> usize {
        self.out_ch
    }

    /// Expansion factor of the hidden (expanded) representation.
    pub fn expansion(&self) -> usize {
        self.expansion
    }

    /// Spatial stride of the depthwise stage.
    pub fn stride(&self) -> usize {
        self.stride
    }

    /// Whether the identity skip connection is active.
    pub fn has_residual(&self) -> bool {
        self.use_residual
    }

    /// Total learnable parameter count across all sub-layers.
    pub fn parameter_count(&self) -> usize {
        let expand = self
            .expand_pw
            .as_ref()
            .map(|p| p.parameter_count())
            .unwrap_or(0);
        expand + self.dw.parameter_count() + self.project_pw.parameter_count()
    }

    /// Runs expansion (if any) -> depthwise conv -> linear projection, and
    /// adds the identity skip when `has_residual()` is true.
    ///
    /// `input` is NCHW data flattened to a slice; `shape` is
    /// `[batch, channels, height, width]`.
    ///
    /// # Errors
    /// `ShapeMismatch` if the channel count or slice length disagrees with
    /// `shape`.
    pub fn forward(&self, input: &[f32], shape: [usize; 4]) -> Result<(Vec<f32>, [usize; 4])> {
        let [batch, in_ch, h, w] = shape;
        if in_ch != self.in_ch {
            return Err(NeuralError::ShapeMismatch(format!(
                "MobileNetV2Block: expected in_ch={}, got {}",
                self.in_ch, in_ch
            )));
        }
        // Validate the slice length up front so a malformed shape fails with
        // a ShapeMismatch error instead of an index panic deep inside a
        // convolution loop.
        if input.len() != batch * in_ch * h * w {
            return Err(NeuralError::ShapeMismatch(
                "MobileNetV2Block: input slice length mismatch".to_string(),
            ));
        }
        // Expansion stage (identity when expansion == 1).
        let (expanded, expanded_shape) = if let Some(ref pw) = self.expand_pw {
            pw.forward_with_relu6(input, shape)?
        } else {
            (input.to_vec(), shape)
        };
        // Depthwise stage with this block's stride; only the depthwise
        // weights of `self.dw` are used here (its pointwise half is unused).
        let (dw_out, dw_shape) = depthwise_only(
            &expanded,
            expanded_shape,
            &self.dw.dw_weights,
            self.dw.kernel_size,
            self.stride,
        )?;
        // Linear bottleneck projection: no activation, per MobileNetV2.
        let (projected, proj_shape) = self
            .project_pw
            .forward_linear(dw_out.as_slice(), dw_shape)?;
        // Identity skip: shapes are guaranteed equal because use_residual
        // requires stride == 1 and in_ch == out_ch.
        let output = if self.use_residual {
            input
                .iter()
                .zip(projected.iter())
                .map(|(a, b)| a + b)
                .collect()
        } else {
            projected
        };
        Ok((output, proj_shape))
    }
}
/// Internal 1x1 ("pointwise") convolution layer: a per-pixel linear map from
/// `in_ch` to `out_ch` channels with a per-output-channel bias.
#[derive(Debug, Clone)]
struct PointwiseConv {
    /// Number of input channels expected by the forward methods.
    in_ch: usize,
    /// Number of output channels produced.
    out_ch: usize,
    /// Row-major weight matrix, indexed as `[oc * in_ch + ic]`.
    weights: Vec<f32>,
    /// One bias value per output channel.
    bias: Vec<f32>,
}
impl PointwiseConv {
    /// He-initialized 1x1 convolution with zero bias.
    ///
    /// # Errors
    /// `InvalidArgument` if either channel count is zero (consistent with
    /// `DepthwiseSeparableConv::new`).
    fn new(in_channels: usize, out_channels: usize) -> Result<Self> {
        if in_channels == 0 || out_channels == 0 {
            return Err(NeuralError::InvalidArgument(
                "PointwiseConv: channel counts must be > 0".to_string(),
            ));
        }
        let size = out_channels * in_channels;
        // He-init scale: the fan-in of a 1x1 conv is the input channel count.
        let scale = (2.0_f32 / in_channels as f32).sqrt();
        Ok(Self {
            in_ch: in_channels,
            out_ch: out_channels,
            weights: pseudo_random_weights(size, scale, 3),
            bias: vec![0.0_f32; out_channels],
        })
    }

    /// Total learnable parameter count (weights + biases).
    fn parameter_count(&self) -> usize {
        self.weights.len() + self.bias.len()
    }

    /// 1x1 conv + bias with ReLU6 activation (expansion stage).
    fn forward_with_relu6(
        &self,
        input: &[f32],
        shape: [usize; 4],
    ) -> Result<(Vec<f32>, [usize; 4])> {
        let [_, in_ch, _, _] = shape;
        if in_ch != self.in_ch {
            return Err(NeuralError::ShapeMismatch(format!(
                "PointwiseConv: in_ch mismatch {} vs {}",
                self.in_ch, in_ch
            )));
        }
        self.apply(input, shape, true)
    }

    /// 1x1 conv + bias with a linear (identity) activation, as used by the
    /// MobileNetV2 projection layer.
    fn forward_linear(&self, input: &[f32], shape: [usize; 4]) -> Result<(Vec<f32>, [usize; 4])> {
        let [_, in_ch, _, _] = shape;
        if in_ch != self.in_ch {
            return Err(NeuralError::ShapeMismatch(format!(
                "PointwiseConv(linear): in_ch mismatch {} vs {}",
                self.in_ch, in_ch
            )));
        }
        self.apply(input, shape, false)
    }

    /// Shared 1x1-convolution kernel for both forward variants; `relu6`
    /// selects the activation. The two wrappers above validate the channel
    /// count first so their original error messages are preserved.
    fn apply(&self, input: &[f32], shape: [usize; 4], relu6: bool) -> Result<(Vec<f32>, [usize; 4])> {
        let [batch, in_ch, h, w] = shape;
        // Guard the slice length so the indexing below cannot panic on a
        // malformed input.
        if input.len() != batch * in_ch * h * w {
            return Err(NeuralError::ShapeMismatch(
                "PointwiseConv: input slice length mismatch".to_string(),
            ));
        }
        let out_size = batch * self.out_ch * h * w;
        let mut out = vec![0.0_f32; out_size];
        for b in 0..batch {
            for oc in 0..self.out_ch {
                for ph in 0..h {
                    for pw_pos in 0..w {
                        let mut acc = self.bias[oc];
                        for ic in 0..in_ch {
                            let in_idx = b * in_ch * h * w + ic * h * w + ph * w + pw_pos;
                            acc += input[in_idx] * self.weights[oc * in_ch + ic];
                        }
                        let out_idx = b * self.out_ch * h * w + oc * h * w + ph * w + pw_pos;
                        out[out_idx] = if relu6 { acc.clamp(0.0, 6.0) } else { acc };
                    }
                }
            }
        }
        Ok((out, [batch, self.out_ch, h, w]))
    }
}
fn depthwise_only(
input: &[f32],
shape: [usize; 4],
weights: &[f32],
kernel_size: (usize, usize),
stride: usize,
) -> Result<(Vec<f32>, [usize; 4])> {
let [batch, channels, h, w] = shape;
let (kh, kw) = kernel_size;
let padding = (kh / 2, kw / 2);
let h_out = if stride == 1 {
h
} else {
(h + 2 * padding.0).saturating_sub(kh) / stride + 1
};
let w_out = if stride == 1 {
w
} else {
(w + 2 * padding.1).saturating_sub(kw) / stride + 1
};
let mut out = vec![0.0_f32; batch * channels * h_out * w_out];
for b in 0..batch {
for c in 0..channels {
for oh in 0..h_out {
for ow in 0..w_out {
let mut acc = 0.0_f32;
for ki in 0..kh {
for kj in 0..kw {
let ih = oh * stride + ki;
let iw = ow * stride + kj;
let ih_src = ih.wrapping_sub(padding.0);
let iw_src = iw.wrapping_sub(padding.1);
if ih_src < h && iw_src < w {
let in_idx = b * channels * h * w + c * h * w + ih_src * w + iw_src;
let w_idx = c * kh * kw + ki * kw + kj;
acc += input[in_idx] * weights[w_idx];
}
}
}
let out_idx =
b * channels * h_out * w_out + c * h_out * w_out + oh * w_out + ow;
out[out_idx] = acc.clamp(0.0, 6.0);
}
}
}
}
Ok((out, [batch, channels, h_out, w_out]))
}
/// Utilities for fitting a model into a mobile deployment budget:
/// int8 quantization, magnitude pruning, and size estimation.
pub struct MobileOptimizer {
/// Maximum allowed model size, in kilobytes.
pub size_budget_kb: f64,
/// Largest tolerated accuracy drop as a fraction; clamped to [0, 1] in `new`.
pub max_accuracy_drop: f64,
}
impl MobileOptimizer {
    /// Creates an optimizer with a size budget (KB) and a tolerated accuracy
    /// drop; the drop is clamped into [0, 1].
    ///
    /// # Errors
    /// `InvalidArgument` if `size_budget_kb` is not strictly positive.
    pub fn new(size_budget_kb: f64, max_accuracy_drop: f64) -> Result<Self> {
        if size_budget_kb <= 0.0 {
            return Err(NeuralError::InvalidArgument(
                "size_budget_kb must be > 0".to_string(),
            ));
        }
        Ok(Self {
            size_budget_kb,
            max_accuracy_drop: max_accuracy_drop.clamp(0.0, 1.0),
        })
    }

    /// Size in bytes of `num_weights` values stored at `bits_per_weight`
    /// bits each, rounded up to whole bytes.
    pub fn estimate_size_bytes(num_weights: usize, bits_per_weight: u8) -> usize {
        (num_weights * bits_per_weight as usize).div_ceil(8)
    }

    /// Symmetric per-tensor int8 quantization. Returns the quantized values
    /// and a scale such that `dequantized = q as f32 * scale`.
    ///
    /// # Errors
    /// `InvalidArgument` on an empty slice.
    pub fn quantize_int8(weights: &[f32]) -> Result<(Vec<i8>, f32)> {
        if weights.is_empty() {
            return Err(NeuralError::InvalidArgument(
                "quantize_int8: empty weights".to_string(),
            ));
        }
        let abs_max = weights.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
        // An all-zero tensor maps to scale 1.0 to avoid dividing by zero.
        let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
        let quantized: Vec<i8> = weights
            .iter()
            .map(|&w| (w / scale).round().clamp(-128.0, 127.0) as i8)
            .collect();
        Ok((quantized, scale))
    }

    /// Zeroes the smallest-magnitude weights in place. `sparsity` is the
    /// target fraction of zeros (clamped to [0, 1]). Weights strictly below
    /// the magnitude threshold are cleared, so the largest-magnitude weight
    /// always survives even at sparsity 1.0.
    pub fn magnitude_prune(weights: &mut [f32], sparsity: f64) {
        if weights.is_empty() || sparsity <= 0.0 {
            return;
        }
        let n = weights.len();
        let mut sorted_abs: Vec<f32> = weights.iter().map(|v| v.abs()).collect();
        // total_cmp provides a NaN-safe total order; the previous
        // partial_cmp/unwrap_or(Equal) fallback left the sort order
        // unspecified whenever a NaN was present.
        sorted_abs.sort_unstable_by(f32::total_cmp);
        let cutoff_idx = ((sparsity.clamp(0.0, 1.0) * n as f64) as usize).min(n.saturating_sub(1));
        let threshold = sorted_abs[cutoff_idx];
        for w in weights.iter_mut() {
            if w.abs() < threshold {
                *w = 0.0;
            }
        }
    }

    /// Whether `param_count` parameters fit in the size budget, assuming
    /// uncompressed 32-bit floats per parameter.
    pub fn fits_budget(&self, param_count: usize) -> bool {
        let bytes = Self::estimate_size_bytes(param_count, 32);
        (bytes as f64 / 1024.0) <= self.size_budget_kb
    }
}
/// Deterministic weight initializer: a 64-bit LCG (Knuth MMIX constants)
/// mapped to uniform values in `[-scale, scale]`. Not cryptographic; the
/// per-layer `seed_offset` gives each layer a distinct sequence.
fn pseudo_random_weights(n: usize, scale: f32, seed_offset: u64) -> Vec<f32> {
    let mut state: u64 = 0xDEAD_BEEF_0000_0001u64.wrapping_add(seed_offset);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Take the top 31 bits and normalize by their maximum so `u`
            // spans the full [0, 1] range. The previous divisor (u32::MAX,
            // ~2^32 for a 31-bit value) capped `u` near 0.5, which made
            // every generated weight negative instead of uniform in
            // [-scale, scale].
            let u = (state >> 33) as f32 / ((1u64 << 31) - 1) as f32;
            (u * 2.0 - 1.0) * scale
        })
        .collect()
}
#[cfg(test)]
mod tests {
use super::*;
// V1 preset carries the expected resolution, version, and full width.
#[test]
fn test_mobile_net_config_v1() {
let cfg = MobileNetConfig::mobilenet_v1();
assert_eq!(cfg.input_resolution, 224);
assert_eq!(cfg.version, MobileNetVersion::V1);
assert!((cfg.width_multiplier - 1.0).abs() < 1e-6);
}
// V2 preset reports the V2 version tag.
#[test]
fn test_mobile_net_config_v2() {
let cfg = MobileNetConfig::mobilenet_v2();
assert_eq!(cfg.version, MobileNetVersion::V2);
}
// A 0.5 width multiplier halves base channel counts exactly.
#[test]
fn test_scaled_channels() {
let cfg = MobileNetConfig {
width_multiplier: 0.5,
..MobileNetConfig::mobilenet_v2()
};
assert_eq!(cfg.scaled_channels(32), 16);
assert_eq!(cfg.scaled_channels(64), 32);
}
// Construction succeeds and reports the requested channel counts.
#[test]
fn test_depthwise_separable_conv_creation() {
let dsc = DepthwiseSeparableConv::new(4, 8, (3, 3)).expect("dsc ok");
assert_eq!(dsc.in_channels(), 4);
assert_eq!(dsc.out_channels(), 8);
assert!(dsc.parameter_count() > 0);
}
// Stride-1 forward with a 3x3 kernel preserves the 8x8 spatial size
// ("same" padding) and maps 2 -> 4 channels.
#[test]
fn test_depthwise_separable_conv_forward() {
let dsc = DepthwiseSeparableConv::new(2, 4, (3, 3)).expect("dsc ok");
let input = vec![0.5_f32; 2 * 8 * 8];
let (output, out_shape) = dsc.forward(&input, [1, 2, 8, 8]).expect("forward ok");
let [b, c, h, w] = out_shape;
assert_eq!(b, 1);
assert_eq!(c, 4);
assert_eq!(h, 8); assert_eq!(w, 8);
assert_eq!(output.len(), b * c * h * w);
}
// Layer expects 4 input channels; a 2-channel shape must be rejected.
#[test]
fn test_depthwise_separable_conv_channel_mismatch_err() {
let dsc = DepthwiseSeparableConv::new(4, 8, (3, 3)).expect("dsc ok");
let input = vec![0.0_f32; 2 * 4 * 4]; let result = dsc.forward(&input, [1, 2, 4, 4]);
assert!(result.is_err());
}
// Differing in/out channel counts disable the residual connection even
// at stride 1.
#[test]
fn test_mobilenet_v2_block_creation() {
let block = MobileNetV2Block::new(32, 16, 6, 1).expect("block ok");
assert_eq!(block.in_channels(), 32);
assert_eq!(block.out_channels(), 16);
assert!(!block.has_residual()); }
// Stride 1 with matching channel counts enables the residual connection.
#[test]
fn test_mobilenet_v2_block_residual() {
let block = MobileNetV2Block::new(16, 16, 6, 1).expect("block ok");
assert!(block.has_residual());
}
// Stride-1 residual block keeps both channel count and spatial size.
#[test]
fn test_mobilenet_v2_block_forward() {
let block = MobileNetV2Block::new(8, 8, 6, 1).expect("block ok");
let input = vec![0.1_f32; 8 * 4 * 4]; let (output, out_shape) = block.forward(&input, [1, 8, 4, 4]).expect("fwd ok");
let [b, c, _h, _w] = out_shape;
assert_eq!(b, 1);
assert_eq!(c, 8);
assert_eq!(output.len(), 8 * 4 * 4);
}
// Stride 2 halves (roughly) the spatial size and disables the residual.
#[test]
fn test_mobilenet_v2_block_stride2() {
let block = MobileNetV2Block::new(8, 16, 6, 2).expect("block ok");
assert!(!block.has_residual());
let input = vec![0.1_f32; 8 * 8 * 8]; let (output, out_shape) = block.forward(&input, [1, 8, 8, 8]).expect("fwd ok");
let [b, c, h, w] = out_shape;
assert_eq!(b, 1);
assert_eq!(c, 16);
assert!(h <= 4 && w <= 4, "expected ≤4, got h={h} w={w}");
assert_eq!(output.len(), b * c * h * w);
}
// Round-tripping through int8 quantization stays within ~1/127 of the
// original values for a well-spread tensor.
#[test]
fn test_mobile_optimizer_quantize_int8() {
let weights = vec![0.5_f32, -0.5, 1.0, -1.0, 0.0];
let (q, scale) = MobileOptimizer::quantize_int8(&weights).expect("ok");
assert_eq!(q.len(), weights.len());
let dequant: Vec<f32> = q.iter().map(|&v| v as f32 * scale).collect();
for (orig, deq) in weights.iter().zip(dequant.iter()) {
assert!((orig - deq).abs() < 0.01, "orig={orig} deq={deq}");
}
}
// 60% sparsity on five weights should zero at least the smallest two.
#[test]
fn test_mobile_optimizer_prune() {
let mut weights = vec![0.01_f32, 0.5, 0.001, 1.0, 0.002];
MobileOptimizer::magnitude_prune(&mut weights, 0.6);
let zeros = weights.iter().filter(|&&v| v == 0.0).count();
assert!(zeros >= 2, "expected ≥2 zeros, got {zeros}");
}
// 10 float32 params fit a 1000 KB budget; 10M params (40 MB) do not.
#[test]
fn test_mobile_optimizer_budget() {
let opt = MobileOptimizer::new(1000.0, 0.01).expect("ok");
assert!(opt.fits_budget(10));
assert!(!opt.fits_budget(10_000_000));
}
// Zero input channels must be rejected at construction time.
#[test]
fn test_depthwise_separable_conv_zero_channels_err() {
assert!(DepthwiseSeparableConv::new(0, 8, (3, 3)).is_err());
}
}