use super::{Backbone, BackboneExt, BackboneType};
use super::blocks::{ConvBNActivation, InvertedResidual as BlockInvertedResidual};
use super::layer::Layer;
use crate::error::CnnResult;
use crate::layers::{self, ActivationType, GlobalAvgPool2d, Linear, TensorShape};
/// Configuration for the legacy MobileNetV3 backbones in this module.
#[derive(Debug, Clone)]
pub struct MobileNetConfig {
    /// Value returned by `input_size()` on the legacy backbones.
    pub input_size: usize,
    /// Multiplier applied to the per-block channel counts when blocks are built.
    pub width_mult: f32,
    /// Value returned by `output_dim()` on the legacy backbones.
    pub output_channels: usize,
}

impl Default for MobileNetConfig {
    /// Defaults: input size 224, width multiplier 1.0, 576 output channels.
    fn default() -> Self {
        let (input_size, width_mult, output_channels) = (224, 1.0, 576);
        MobileNetConfig {
            input_size,
            width_mult,
            output_channels,
        }
    }
}
/// Legacy MobileNetV3-Small backbone that stores raw `f32` weight buffers.
///
/// Deprecated since 2.0.6; the deprecation note names `MobileNetV3` with
/// `BackboneType::MobileNetV3Small` as the replacement.
#[deprecated(since = "2.0.6", note = "Use MobileNetV3 with BackboneType::MobileNetV3Small instead")]
#[derive(Debug, Clone)]
pub struct MobileNetV3Small {
/// Configuration the model was constructed from.
config: MobileNetConfig,
/// Weights of the stem 3x3 convolution.
stem_weights: Vec<f32>,
/// Batch-norm parameters applied to the stem convolution output.
stem_bn: BnParams,
/// Inverted-residual blocks, applied in order by `forward`.
blocks: Vec<InvertedResidual>,
/// Head weight buffer; `new` sizes it to `config.output_channels`.
head_weights: Vec<f32>,
}
/// Legacy MobileNetV3-Large backbone that stores raw `f32` weight buffers.
///
/// Deprecated since 2.0.6; the deprecation note names `MobileNetV3` with
/// `BackboneType::MobileNetV3Large` as the replacement.
#[deprecated(since = "2.0.6", note = "Use MobileNetV3 with BackboneType::MobileNetV3Large instead")]
#[derive(Debug, Clone)]
pub struct MobileNetV3Large {
/// Configuration the model was constructed from.
config: MobileNetConfig,
/// Weights of the stem 3x3 convolution.
stem_weights: Vec<f32>,
/// Batch-norm parameters applied to the stem convolution output.
stem_bn: BnParams,
/// Inverted-residual blocks, applied in order by `forward`.
blocks: Vec<InvertedResidual>,
/// Head weight buffer; `new` sizes it to `config.output_channels`.
head_weights: Vec<f32>,
}
/// Batch-normalisation parameters; every vector has one entry per channel.
#[derive(Debug, Clone)]
struct BnParams {
    gamma: Vec<f32>,
    beta: Vec<f32>,
    mean: Vec<f32>,
    var: Vec<f32>,
}

impl BnParams {
    /// Creates parameters for `channels` channels with `gamma = 1`,
    /// `beta = 0`, `mean = 0` and `var = 1`.
    fn new(channels: usize) -> Self {
        let ones = || vec![1.0_f32; channels];
        let zeros = || vec![0.0_f32; channels];
        BnParams {
            gamma: ones(),
            beta: zeros(),
            mean: zeros(),
            var: ones(),
        }
    }
}
/// Weights and metadata for one inverted-residual block of the legacy models.
///
/// `mid` below denotes `in_channels * expansion`, the width used by the
/// block constructors when sizing the buffers.
#[derive(Debug, Clone)]
struct InvertedResidual {
/// Pointwise expansion weights (`in_channels * mid`); `None` when expansion is 1.
expand_weights: Option<Vec<f32>>,
/// Batch-norm parameters for the expansion output; `None` when expansion is 1.
expand_bn: Option<BnParams>,
/// Depthwise 3x3 weights, 9 values per channel (`9 * mid`).
dw_weights: Vec<f32>,
/// Batch-norm parameters applied after the depthwise convolution.
dw_bn: BnParams,
/// Squeeze-excite reduction weights (`mid * (mid / 4)`); `None` without SE.
se_reduce: Option<Vec<f32>>,
/// Squeeze-excite expansion weights (`(mid / 4) * mid`); `None` without SE.
se_expand: Option<Vec<f32>>,
/// Pointwise projection weights (`mid * out_channels`).
project_weights: Vec<f32>,
/// Batch-norm parameters applied after the projection.
project_bn: BnParams,
/// Input channel count after width scaling.
in_channels: usize,
/// Output channel count after width scaling.
out_channels: usize,
/// Expansion factor relating `in_channels` to the mid width.
expansion: usize,
/// Whether the block was configured with squeeze-excite.
use_se: bool,
/// Set by the constructors when `in_channels == out_channels`.
use_residual: bool,
}
impl MobileNetV3Small {
    /// Creates a model with zero-filled weights from `config`.
    ///
    /// The stem is always 16 channels wide (it is not scaled by
    /// `width_mult`), and the head buffer holds `config.output_channels`
    /// values.
    pub fn new(config: MobileNetConfig) -> Self {
        const STEM_CHANNELS: usize = 16;
        let blocks = Self::create_blocks(&config);
        let head_weights = vec![0.0; config.output_channels];
        Self {
            config,
            stem_weights: vec![0.0; 3 * 3 * 3 * STEM_CHANNELS],
            stem_bn: BnParams::new(STEM_CHANNELS),
            blocks,
            head_weights,
        }
    }

    /// Builds the zero-initialised block list for the small variant.
    ///
    /// Channel counts are multiplied by `config.width_mult` and truncated
    /// towards zero by the `as usize` cast.
    fn create_blocks(config: &MobileNetConfig) -> Vec<InvertedResidual> {
        // Columns: (in_channels, out_channels, expansion, use_se).
        const TABLE: [(usize, usize, usize, bool); 11] = [
            (16, 16, 1, false),
            (16, 24, 4, false),
            (24, 24, 3, false),
            (24, 40, 3, true),
            (40, 40, 3, true),
            (40, 48, 3, true),
            (48, 48, 3, true),
            (48, 96, 6, true),
            (96, 96, 6, true),
            (96, 96, 6, true),
            (96, 96, 6, true),
        ];
        let scale = |channels: usize| (channels as f32 * config.width_mult) as usize;

        let mut blocks = Vec::with_capacity(TABLE.len());
        for &(base_in, base_out, expansion, use_se) in TABLE.iter() {
            let in_channels = scale(base_in);
            let out_channels = scale(base_out);
            let mid = in_channels * expansion;
            let se_hidden = mid / 4;
            let expands = expansion != 1;
            blocks.push(InvertedResidual {
                expand_weights: expands.then(|| vec![0.0; in_channels * mid]),
                expand_bn: expands.then(|| BnParams::new(mid)),
                dw_weights: vec![0.0; 9 * mid],
                dw_bn: BnParams::new(mid),
                se_reduce: use_se.then(|| vec![0.0; mid * se_hidden]),
                se_expand: use_se.then(|| vec![0.0; se_hidden * mid]),
                project_weights: vec![0.0; mid * out_channels],
                project_bn: BnParams::new(out_channels),
                in_channels,
                out_channels,
                expansion,
                use_se,
                use_residual: in_channels == out_channels,
            });
        }
        blocks
    }
}
impl Backbone for MobileNetV3Small {
    /// Runs the stem (3x3 conv, batch norm, hard-swish), every
    /// inverted-residual block in order, and a global average pool.
    ///
    /// NOTE(review): the pooled result is computed over the last block's
    /// `out_channels`, while `output_dim` reports `config.output_channels`,
    /// and `head_weights` is not applied here. Confirm whether callers
    /// expect the two to agree.
    fn forward(&self, input: &[f32], height: usize, width: usize) -> Vec<f32> {
        let bn = &self.stem_bn;
        let stem = layers::conv2d_3x3(input, &self.stem_weights, 3, 16, height, width);
        let normed = layers::batch_norm(&stem, &bn.gamma, &bn.beta, &bn.mean, &bn.var, 1e-5);
        let mut features = layers::hard_swish(&normed);

        // The stem emits 16 channels; afterwards each block defines the width.
        let mut channels = 16;
        for block in self.blocks.iter() {
            features = Self::process_inverted_residual(&features, block, channels);
            channels = block.out_channels;
        }

        layers::global_avg_pool(&features, channels)
    }

    /// Returns the configured `output_channels`.
    fn output_dim(&self) -> usize {
        let cfg = &self.config;
        cfg.output_channels
    }

    /// Returns the configured `input_size`.
    fn input_size(&self) -> usize {
        let cfg = &self.config;
        cfg.input_size
    }
}
impl MobileNetV3Large {
    /// Creates a model with zero-filled weights from `config`.
    ///
    /// The stem is always 16 channels wide (it is not scaled by
    /// `width_mult`), and the head buffer holds `config.output_channels`
    /// values.
    pub fn new(config: MobileNetConfig) -> Self {
        const STEM_CHANNELS: usize = 16;
        let blocks = Self::create_blocks(&config);
        let head_weights = vec![0.0; config.output_channels];
        Self {
            config,
            stem_weights: vec![0.0; 3 * 3 * 3 * STEM_CHANNELS],
            stem_bn: BnParams::new(STEM_CHANNELS),
            blocks,
            head_weights,
        }
    }

    /// Builds the zero-initialised block list for the large variant.
    ///
    /// Channel counts are multiplied by `config.width_mult` and truncated
    /// towards zero by the `as usize` cast.
    fn create_blocks(config: &MobileNetConfig) -> Vec<InvertedResidual> {
        // Columns: (in_channels, out_channels, expansion, use_se).
        const TABLE: [(usize, usize, usize, bool); 14] = [
            (16, 16, 1, false),
            (16, 24, 4, false),
            (24, 24, 3, false),
            (24, 40, 3, true),
            (40, 40, 3, true),
            (40, 40, 3, true),
            (40, 80, 6, false),
            (80, 80, 2, false),
            (80, 80, 2, false),
            (80, 112, 6, true),
            (112, 112, 6, true),
            (112, 160, 6, true),
            (160, 160, 6, true),
            (160, 160, 6, true),
        ];
        let scale = |channels: usize| (channels as f32 * config.width_mult) as usize;

        let mut blocks = Vec::with_capacity(TABLE.len());
        for &(base_in, base_out, expansion, use_se) in TABLE.iter() {
            let in_channels = scale(base_in);
            let out_channels = scale(base_out);
            let mid = in_channels * expansion;
            let se_hidden = mid / 4;
            let expands = expansion != 1;
            blocks.push(InvertedResidual {
                expand_weights: expands.then(|| vec![0.0; in_channels * mid]),
                expand_bn: expands.then(|| BnParams::new(mid)),
                dw_weights: vec![0.0; 9 * mid],
                dw_bn: BnParams::new(mid),
                se_reduce: use_se.then(|| vec![0.0; mid * se_hidden]),
                se_expand: use_se.then(|| vec![0.0; se_hidden * mid]),
                project_weights: vec![0.0; mid * out_channels],
                project_bn: BnParams::new(out_channels),
                in_channels,
                out_channels,
                expansion,
                use_se,
                use_residual: in_channels == out_channels,
            });
        }
        blocks
    }
}
impl Backbone for MobileNetV3Large {
    /// Runs the stem (3x3 conv, batch norm, hard-swish), every
    /// inverted-residual block in order, and a global average pool.
    ///
    /// NOTE(review): the pooled result is computed over the last block's
    /// `out_channels`, while `output_dim` reports `config.output_channels`,
    /// and `head_weights` is not applied here. Confirm whether callers
    /// expect the two to agree.
    fn forward(&self, input: &[f32], height: usize, width: usize) -> Vec<f32> {
        let mut x = layers::conv2d_3x3(input, &self.stem_weights, 3, 16, height, width);
        x = layers::batch_norm(
            &x,
            &self.stem_bn.gamma,
            &self.stem_bn.beta,
            &self.stem_bn.mean,
            &self.stem_bn.var,
            1e-5,
        );
        x = layers::hard_swish(&x);

        // The stem emits 16 channels; afterwards each block defines the width.
        let mut current_channels = 16;
        for block in &self.blocks {
            x = Self::process_inverted_residual(&x, block, current_channels);
            current_channels = block.out_channels;
        }
        layers::global_avg_pool(&x, current_channels)
    }

    /// Returns the configured `output_channels`.
    fn output_dim(&self) -> usize {
        self.config.output_channels
    }

    /// Returns the configured `input_size`.
    fn input_size(&self) -> usize {
        self.config.input_size
    }
}

// `process_inverted_residual` lives in an inherent impl: it takes this
// module's private `InvertedResidual`, so it cannot be a member of the
// imported `Backbone` trait and must not be declared inside the trait impl.
impl MobileNetV3Large {
    /// Applies one inverted-residual block to `input`.
    ///
    /// `input` is indexed as `position * channels + channel`, i.e. a flat
    /// buffer with the channel as the fastest-varying index. The spatial
    /// extent is derived as `input.len() / in_channels` and treated as a
    /// square (`h = w = floor(sqrt(spatial))`) for the depthwise convolution.
    ///
    /// Steps: optional pointwise expansion + BN + hard-swish, depthwise 3x3
    /// (zero padding, stride 1) + BN + hard-swish, optional squeeze-excite
    /// gating, pointwise projection + BN, and a residual add when the block
    /// allows it and the channel counts match.
    ///
    /// # Panics
    /// Panics if `in_channels` is zero (division by zero) or if the block's
    /// weight buffers are smaller than the indices derived from
    /// `in_channels` and the block metadata.
    fn process_inverted_residual(
        input: &[f32],
        block: &InvertedResidual,
        in_channels: usize,
    ) -> Vec<f32> {
        let spatial = input.len() / in_channels;
        let h = (spatial as f32).sqrt() as usize;
        let w = h;

        let mut x = input.to_vec();
        let mut current_c = in_channels;

        // Pointwise (1x1) expansion; skipped for blocks built without it.
        if let (Some(weights), Some(bn)) = (&block.expand_weights, &block.expand_bn) {
            let exp_c = block.expansion * in_channels;
            let mut expanded = vec![0.0f32; spatial * exp_c];
            for s in 0..spatial {
                for oc in 0..exp_c {
                    let mut sum = 0.0f32;
                    for ic in 0..current_c {
                        sum += x[s * current_c + ic] * weights[oc * current_c + ic];
                    }
                    expanded[s * exp_c + oc] = sum;
                }
            }
            x = layers::batch_norm(&expanded, &bn.gamma, &bn.beta, &bn.mean, &bn.var, 1e-5);
            x = layers::hard_swish(&x);
            current_c = exp_c;
        }

        // Depthwise 3x3 convolution; taps that fall outside the map are skipped.
        let dw_c = current_c;
        let mut dw_out = vec![0.0f32; spatial * dw_c];
        for oh in 0..h {
            for ow in 0..w {
                for c in 0..dw_c {
                    let mut sum = 0.0f32;
                    for kh in 0..3 {
                        for kw in 0..3 {
                            let ih = oh as isize + kh as isize - 1;
                            let iw = ow as isize + kw as isize - 1;
                            if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
                                let idx = (ih as usize * w + iw as usize) * dw_c + c;
                                sum += x[idx] * block.dw_weights[c * 9 + kh * 3 + kw];
                            }
                        }
                    }
                    dw_out[(oh * w + ow) * dw_c + c] = sum;
                }
            }
        }
        x = layers::batch_norm(
            &dw_out,
            &block.dw_bn.gamma,
            &block.dw_bn.beta,
            &block.dw_bn.mean,
            &block.dw_bn.var,
            1e-5,
        );
        x = layers::hard_swish(&x);

        // Squeeze-excite: per-channel mean -> reduce (ReLU) -> expand (sigmoid) -> scale.
        if let (Some(reduce), Some(expand)) = (&block.se_reduce, &block.se_expand) {
            let se_c = reduce.len() / dw_c;

            let mut pooled = vec![0.0f32; dw_c];
            for c in 0..dw_c {
                let mut sum = 0.0f32;
                for s in 0..spatial {
                    sum += x[s * dw_c + c];
                }
                pooled[c] = sum / spatial as f32;
            }

            let mut squeezed = vec![0.0f32; se_c];
            for o in 0..se_c {
                let mut sum = 0.0f32;
                for i in 0..dw_c {
                    sum += pooled[i] * reduce[o * dw_c + i];
                }
                squeezed[o] = sum.max(0.0);
            }

            let mut scale = vec![0.0f32; dw_c];
            for o in 0..dw_c {
                let mut sum = 0.0f32;
                for i in 0..se_c {
                    sum += squeezed[i] * expand[o * se_c + i];
                }
                scale[o] = 1.0 / (1.0 + (-sum).exp());
            }

            for s in 0..spatial {
                for c in 0..dw_c {
                    x[s * dw_c + c] *= scale[c];
                }
            }
        }

        // Pointwise (1x1) projection to the block's output width.
        let out_c = block.out_channels;
        let mut projected = vec![0.0f32; spatial * out_c];
        for s in 0..spatial {
            for oc in 0..out_c {
                let mut sum = 0.0f32;
                for ic in 0..dw_c {
                    sum += x[s * dw_c + ic] * block.project_weights[oc * dw_c + ic];
                }
                projected[s * out_c + oc] = sum;
            }
        }
        let mut output = layers::batch_norm(
            &projected,
            &block.project_bn.gamma,
            &block.project_bn.beta,
            &block.project_bn.mean,
            &block.project_bn.var,
            1e-5,
        );

        // Residual connection, added in place instead of cloning the output.
        if block.use_residual && in_channels == out_c {
            for (out, &skip) in output.iter_mut().zip(input) {
                *out += skip;
            }
        }
        output
    }
}
// Inherent helper so that `Self::process_inverted_residual` resolves for the small model.
impl MobileNetV3Small {
/// Applies one inverted-residual block by delegating to
/// `MobileNetV3Large::process_inverted_residual`; both legacy variants
/// share that single implementation.
fn process_inverted_residual(input: &[f32], block: &InvertedResidual, in_channels: usize) -> Vec<f32> {
MobileNetV3Large::process_inverted_residual(input, block, in_channels)
}
}
/// Description of one inverted-residual block used to build `MobileNetV3`.
#[derive(Clone, Debug)]
struct BlockConfig {
/// Input channels of the block, before width scaling.
in_channels: usize,
/// Expanded (hidden) channels, before width scaling.
expanded_channels: usize,
/// Output channels, before width scaling.
out_channels: usize,
/// Kernel size passed to the block constructor.
kernel_size: usize,
/// Stride passed to the block constructor.
stride: usize,
/// Whether the block uses squeeze-excite.
use_se: bool,
/// Activation passed to the block constructor.
activation: ActivationType,
}
impl BlockConfig {
fn new(
in_c: usize,
exp_c: usize,
out_c: usize,
kernel: usize,
stride: usize,
use_se: bool,
activation: ActivationType,
) -> Self {
Self {
in_channels: in_c,
expanded_channels: exp_c,
out_channels: out_c,
kernel_size: kernel,
stride,
use_se,
activation,
}
}
}
/// Configuration for the unified `MobileNetV3` model.
#[derive(Clone, Debug)]
pub struct MobileNetV3Config {
/// Value reported by `input_size()`; both presets use 224.
pub input_size: usize,
/// Number of input channels fed to the stem; both presets use 3.
pub input_channels: usize,
/// Multiplier applied to channel counts through `scale_channels`.
pub width_mult: f32,
/// Number of classifier outputs; `0` builds the model without a classifier.
pub num_classes: usize,
/// Dropout rate. NOTE(review): stored here but not read by the code in this file.
pub dropout: f32,
/// Which variant this configuration describes.
pub variant: BackboneType,
/// Per-block settings, in execution order.
block_configs: Vec<BlockConfig>,
/// Preset value 1024 (small) / 1280 (large). NOTE(review): no layer in
/// this file is built from it — confirm intended use.
pub last_channels: usize,
/// Unscaled output channels of the last convolution (576 small / 960 large).
pub feature_dim: usize,
}
impl MobileNetV3Config {
pub fn small(num_classes: usize) -> Self {
let block_configs = vec![
BlockConfig::new(16, 16, 16, 3, 2, true, ActivationType::ReLU),
BlockConfig::new(16, 72, 24, 3, 2, false, ActivationType::ReLU),
BlockConfig::new(24, 88, 24, 3, 1, false, ActivationType::ReLU),
BlockConfig::new(24, 96, 40, 5, 2, true, ActivationType::HardSwish),
BlockConfig::new(40, 240, 40, 5, 1, true, ActivationType::HardSwish),
BlockConfig::new(40, 240, 40, 5, 1, true, ActivationType::HardSwish),
BlockConfig::new(40, 120, 48, 5, 1, true, ActivationType::HardSwish),
BlockConfig::new(48, 144, 48, 5, 1, true, ActivationType::HardSwish),
BlockConfig::new(48, 288, 96, 5, 2, true, ActivationType::HardSwish),
BlockConfig::new(96, 576, 96, 5, 1, true, ActivationType::HardSwish),
BlockConfig::new(96, 576, 96, 5, 1, true, ActivationType::HardSwish),
];
Self {
input_size: 224,
input_channels: 3,
width_mult: 1.0,
num_classes,
dropout: 0.2,
variant: BackboneType::MobileNetV3Small,
block_configs,
last_channels: 1024,
feature_dim: 576,
}
}
pub fn large(num_classes: usize) -> Self {
let block_configs = vec![
BlockConfig::new(16, 16, 16, 3, 1, false, ActivationType::ReLU),
BlockConfig::new(16, 64, 24, 3, 2, false, ActivationType::ReLU),
BlockConfig::new(24, 72, 24, 3, 1, false, ActivationType::ReLU),
BlockConfig::new(24, 72, 40, 5, 2, true, ActivationType::ReLU),
BlockConfig::new(40, 120, 40, 5, 1, true, ActivationType::ReLU),
BlockConfig::new(40, 120, 40, 5, 1, true, ActivationType::ReLU),
BlockConfig::new(40, 240, 80, 3, 2, false, ActivationType::HardSwish),
BlockConfig::new(80, 200, 80, 3, 1, false, ActivationType::HardSwish),
BlockConfig::new(80, 184, 80, 3, 1, false, ActivationType::HardSwish),
BlockConfig::new(80, 184, 80, 3, 1, false, ActivationType::HardSwish),
BlockConfig::new(80, 480, 112, 3, 1, true, ActivationType::HardSwish),
BlockConfig::new(112, 672, 112, 3, 1, true, ActivationType::HardSwish),
BlockConfig::new(112, 672, 160, 5, 2, true, ActivationType::HardSwish),
BlockConfig::new(160, 960, 160, 5, 1, true, ActivationType::HardSwish),
BlockConfig::new(160, 960, 160, 5, 1, true, ActivationType::HardSwish),
];
Self {
input_size: 224,
input_channels: 3,
width_mult: 1.0,
num_classes,
dropout: 0.2,
variant: BackboneType::MobileNetV3Large,
block_configs,
last_channels: 1280,
feature_dim: 960,
}
}
pub fn width_mult(mut self, mult: f32) -> Self {
self.width_mult = mult;
self
}
pub fn dropout(mut self, rate: f32) -> Self {
self.dropout = rate;
self
}
fn scale_channels(&self, channels: usize) -> usize {
((channels as f32 * self.width_mult).round() as usize).max(1)
}
}
/// Unified MobileNetV3 model assembled from layer objects; the variant is
/// selected by its `MobileNetV3Config`.
#[derive(Clone, Debug)]
pub struct MobileNetV3 {
/// Configuration the model was built from.
config: MobileNetV3Config,
/// First convolution applied to the input.
stem: ConvBNActivation,
/// Inverted-residual blocks, applied in order after the stem.
blocks: Vec<BlockInvertedResidual>,
/// Pointwise convolution applied after the last block.
last_conv: ConvBNActivation,
/// Global average pooling applied to the last convolution's output.
pool: GlobalAvgPool2d,
/// Linear classifier; `None` when the config's `num_classes` is 0.
classifier: Option<Linear>,
}
impl MobileNetV3 {
    /// Builds the network described by `config`.
    ///
    /// Layers are created in this order: the stem convolution, one
    /// inverted-residual block per entry of the config's block table, a
    /// pointwise convolution to the scaled `feature_dim`, global average
    /// pooling, and — only when `num_classes > 0` — a linear classifier
    /// from the pooled features to `num_classes`.
    ///
    /// `config.last_channels` is not consumed here: the classifier maps the
    /// pooled features directly to the class scores.
    ///
    /// # Errors
    /// Propagates any error returned by the layer constructors.
    pub fn new(config: MobileNetV3Config) -> CnnResult<Self> {
        let stem_out = config.scale_channels(16);
        // NOTE(review): the literals 3, 2, 1, 1 are presumably kernel,
        // stride, padding and groups — confirm against `ConvBNActivation::new`.
        let stem = ConvBNActivation::new(
            config.input_channels,
            stem_out,
            3,
            2,
            1,
            1,
            ActivationType::HardSwish,
        )?;

        // Each block consumes the previous block's (scaled) output width.
        let mut blocks = Vec::with_capacity(config.block_configs.len());
        let mut in_channels = stem_out;
        for bc in &config.block_configs {
            let exp_channels = config.scale_channels(bc.expanded_channels);
            let out_channels = config.scale_channels(bc.out_channels);
            let block = BlockInvertedResidual::create(
                in_channels,
                exp_channels,
                out_channels,
                bc.kernel_size,
                bc.stride,
                bc.use_se,
                bc.activation,
            )?;
            blocks.push(block);
            in_channels = out_channels;
        }

        let feature_dim = config.scale_channels(config.feature_dim);
        let last_conv =
            ConvBNActivation::pointwise(in_channels, feature_dim, ActivationType::HardSwish)?;
        let pool = GlobalAvgPool2d::new();

        let classifier = if config.num_classes > 0 {
            Some(Linear::new(feature_dim, config.num_classes, true)?)
        } else {
            None
        };

        Ok(Self {
            config,
            stem,
            blocks,
            last_conv,
            pool,
            classifier,
        })
    }

    /// Builds the small preset with `num_classes` classifier outputs.
    pub fn small(num_classes: usize) -> CnnResult<Self> {
        Self::new(MobileNetV3Config::small(num_classes))
    }

    /// Builds the large preset with `num_classes` classifier outputs.
    pub fn large(num_classes: usize) -> CnnResult<Self> {
        Self::new(MobileNetV3Config::large(num_classes))
    }

    /// Returns the configuration the model was built from.
    pub fn config(&self) -> &MobileNetV3Config {
        &self.config
    }

    /// Returns the number of inverted-residual blocks.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Returns the stem convolution.
    pub fn stem(&self) -> &ConvBNActivation {
        &self.stem
    }

    /// Returns the inverted-residual blocks in execution order.
    pub fn blocks(&self) -> &[BlockInvertedResidual] {
        &self.blocks
    }

    /// Returns the final pointwise convolution.
    pub fn last_conv(&self) -> &ConvBNActivation {
        &self.last_conv
    }

    /// Returns the classifier, if the model was built with one.
    pub fn classifier(&self) -> Option<&Linear> {
        self.classifier.as_ref()
    }

    /// Runs stem, blocks, last convolution and global pooling, tracking the
    /// tensor shape after every layer.
    ///
    /// # Errors
    /// Propagates the first error returned by any layer's `forward`.
    fn forward_features_impl(&self, input: &[f32], input_shape: &TensorShape) -> CnnResult<Vec<f32>> {
        let mut shape = *input_shape;

        // The stem reads the caller's slice directly; no defensive copy of
        // the input is needed because every layer returns a fresh buffer.
        let stem_shape = self.stem.output_shape(&shape);
        let mut x = self.stem.forward(input, &shape)?;
        shape = stem_shape;

        for block in &self.blocks {
            let block_shape = block.output_shape(&shape);
            x = block.forward(&x, &shape)?;
            shape = block_shape;
        }

        let last_shape = self.last_conv.output_shape(&shape);
        x = self.last_conv.forward(&x, &shape)?;
        shape = last_shape;

        self.pool.forward(&x, &shape)
    }
}
impl Backbone for MobileNetV3 {
    /// Extracts pooled features for a single image (batch size 1).
    ///
    /// If feature extraction returns an error, the error is discarded and a
    /// zero vector of length `output_dim()` is returned instead.
    fn forward(&self, input: &[f32], height: usize, width: usize) -> Vec<f32> {
        let shape = TensorShape::new(1, self.config.input_channels, height, width);
        match self.forward_features_impl(input, &shape) {
            Ok(features) => features,
            Err(_) => vec![0.0; self.output_dim()],
        }
    }

    /// Returns the configured `feature_dim` after width scaling.
    fn output_dim(&self) -> usize {
        let cfg = &self.config;
        cfg.scale_channels(cfg.feature_dim)
    }

    /// Returns the configured `input_size`.
    fn input_size(&self) -> usize {
        let cfg = &self.config;
        cfg.input_size
    }
}
impl BackboneExt for MobileNetV3 {
    /// Returns the variant stored in the configuration.
    fn backbone_type(&self) -> BackboneType {
        self.config.variant
    }

    /// Sums the parameter counts of the stem, all blocks, the last
    /// convolution and, when present, the classifier.
    fn num_params(&self) -> usize {
        let block_params: usize = self.blocks.iter().map(|block| block.num_params()).sum();
        let classifier_params = self
            .classifier
            .as_ref()
            .map_or(0, |classifier| classifier.num_params());
        self.stem.num_params() + block_params + self.last_conv.num_params() + classifier_params
    }

    /// Extracts pooled features and, when a classifier exists, feeds them
    /// through it; otherwise returns the features unchanged.
    fn forward_with_shape(&self, input: &[f32], input_shape: &TensorShape) -> CnnResult<Vec<f32>> {
        let features = self.forward_features_impl(input, input_shape)?;
        match self.classifier.as_ref() {
            None => Ok(features),
            Some(head) => {
                let pooled_shape = TensorShape::new(input_shape.n, self.output_dim(), 1, 1);
                head.forward(&features, &pooled_shape)
            }
        }
    }

    /// Extracts pooled features without applying the classifier.
    fn forward_features(&self, input: &[f32], input_shape: &TensorShape) -> CnnResult<Vec<f32>> {
        self.forward_features_impl(input, input_shape)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a buffer of `0.5` values covering every element of `shape`.
    fn constant_input(shape: &TensorShape) -> Vec<f32> {
        vec![0.5; shape.numel()]
    }

    #[test]
    fn test_mobilenet_v3_small_creation() {
        let legacy = MobileNetV3Small::new(MobileNetConfig::default());
        assert_eq!(legacy.output_dim(), 576);
    }

    #[test]
    fn test_mobilenet_v3_large_creation() {
        let legacy = MobileNetV3Large::new(MobileNetConfig {
            output_channels: 960,
            ..Default::default()
        });
        assert_eq!(legacy.output_dim(), 960);
    }

    #[test]
    fn test_unified_mobilenet_v3_small() {
        let net = MobileNetV3::small(1000).unwrap();
        assert_eq!(net.backbone_type(), BackboneType::MobileNetV3Small);
        assert_eq!(net.output_dim(), 576);
        assert!(net.num_params() > 0);
    }

    #[test]
    fn test_unified_mobilenet_v3_large() {
        let net = MobileNetV3::large(1000).unwrap();
        assert_eq!(net.backbone_type(), BackboneType::MobileNetV3Large);
        assert_eq!(net.output_dim(), 960);
        assert!(net.num_params() > 0);
    }

    #[test]
    fn test_mobilenet_v3_config() {
        let preset = MobileNetV3Config::small(1000);
        assert_eq!(preset.input_size, 224);
        assert_eq!(preset.input_channels, 3);
        assert_eq!(preset.num_classes, 1000);
        assert_eq!(preset.feature_dim, 576);
    }

    #[test]
    fn test_mobilenet_v3_forward() {
        // `0` classes builds the model without a classifier.
        let net = MobileNetV3::small(0).unwrap();
        let shape = TensorShape::new(1, 3, 224, 224);
        let pixels = constant_input(&shape);
        let features = net.forward_features(&pixels, &shape).unwrap();
        assert_eq!(features.len(), 576);
    }

    #[test]
    fn test_mobilenet_v3_with_classifier() {
        let net = MobileNetV3::small(1000).unwrap();
        let shape = TensorShape::new(1, 3, 224, 224);
        let pixels = constant_input(&shape);
        let logits = net.forward_with_shape(&pixels, &shape).unwrap();
        assert_eq!(logits.len(), 1000);
    }

    #[test]
    fn test_mobilenet_v3_batch() {
        let net = MobileNetV3::small(0).unwrap();
        let batch_size = 2;
        let shape = TensorShape::new(batch_size, 3, 224, 224);
        let pixels = constant_input(&shape);
        let features = net.forward_features(&pixels, &shape).unwrap();
        assert_eq!(features.len(), batch_size * 576);
    }
}