mod blocks;
mod layer;
mod mobilenet;
pub use layer::Layer;
pub use blocks::{ConvBNActivation, InvertedResidual, InvertedResidualConfig, SqueezeExcitation};
pub use mobilenet::{MobileNetV3, MobileNetV3Config};
pub use mobilenet::{MobileNetConfig, MobileNetV3Large, MobileNetV3Small};
use crate::error::CnnResult;
use crate::layers::TensorShape;
/// Supported feature-extractor backbone architectures.
///
/// Each variant carries static metadata (feature width, display name, expected
/// input geometry) that can be queried without building the network. All
/// metadata accessors are `const fn`, so they are usable in const contexts.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum BackboneType {
    /// MobileNetV3-Small; reports a 576-wide feature output.
    MobileNetV3Small,
    /// MobileNetV3-Large; reports a 960-wide feature output.
    MobileNetV3Large,
}

impl BackboneType {
    /// Width of the feature vector this architecture produces.
    ///
    /// NOTE(review): these constants are hard-coded here rather than derived from
    /// `MobileNetV3Config::small`/`large`; keep the two in sync.
    pub const fn output_dim(&self) -> usize {
        match self {
            BackboneType::MobileNetV3Small => 576,
            BackboneType::MobileNetV3Large => 960,
        }
    }

    /// Human-readable architecture name; also used as the `Display` output.
    pub const fn name(&self) -> &'static str {
        match self {
            BackboneType::MobileNetV3Small => "MobileNetV3-Small",
            BackboneType::MobileNetV3Large => "MobileNetV3-Large",
        }
    }

    /// Expected input spatial resolution: a square 224 x 224 for every variant.
    pub const fn input_size(&self) -> (usize, usize) {
        (224, 224)
    }

    /// Expected number of input channels: 3 for every variant.
    pub const fn input_channels(&self) -> usize {
        3
    }
}

impl std::fmt::Display for BackboneType {
    /// Writes [`BackboneType::name`].
    ///
    /// `Formatter::pad` is used (rather than `write!(f, "{}", ..)`) so width,
    /// fill, alignment and precision flags such as `{:>20}` are honoured; a
    /// nested `write!` formats with a fresh default spec and silently drops them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
/// Minimal inference interface implemented by every backbone.
///
/// The `Send + Sync` supertraits mean any implementor (including a boxed
/// `dyn BackboneExt`) can be moved to and shared between threads.
pub trait Backbone: Send + Sync {
    /// Runs the network on `input`, an image of `height` x `width` pixels.
    ///
    /// NOTE(review): the channel count and memory layout expected in `input`
    /// (e.g. CHW vs HWC) and whether the result is features or logits are not
    /// specified by this trait; confirm against the implementations.
    fn forward(&self, input: &[f32], height: usize, width: usize) -> Vec<f32>;
    /// Width of the feature output reported by this backbone.
    fn output_dim(&self) -> usize;
    /// Expected input size.
    ///
    /// NOTE(review): this returns a single `usize`, whereas
    /// `BackboneType::input_size` returns a `(usize, usize)` pair; presumably
    /// this is the side length of a square input, verify against implementations.
    fn input_size(&self) -> usize;
}
/// Extended backbone interface: identification, parameter accounting, and
/// shape-aware forward passes that report failures through `CnnResult`.
pub trait BackboneExt: Backbone {
    /// Architecture implemented by this backbone instance.
    fn backbone_type(&self) -> BackboneType;

    /// Parameter count reported by the implementation.
    fn num_params(&self) -> usize;

    /// Forward pass where the input layout is described by `input_shape`.
    fn forward_with_shape(&self, input: &[f32], input_shape: &TensorShape) -> CnnResult<Vec<f32>>;

    /// Feature-extraction pass where the input layout is described by `input_shape`.
    fn forward_features(&self, input: &[f32], input_shape: &TensorShape) -> CnnResult<Vec<f32>>;

    /// Output shape of the feature pass for an input of `input_shape`.
    ///
    /// The default keeps the batch size, uses [`Backbone::output_dim`] as the
    /// channel count, and collapses the spatial dimensions to `1 x 1`
    /// (presumably reflecting global pooling; implementations that keep a
    /// spatial map should override this).
    fn feature_output_shape(&self, input_shape: &TensorShape) -> TensorShape {
        let channels = self.output_dim();
        let batch = input_shape.n;
        TensorShape {
            n: batch,
            c: channels,
            h: 1,
            w: 1,
        }
    }
}
/// Builds the backbone selected by `backbone_type`, boxed as a trait object so
/// the architecture can be chosen at runtime.
///
/// `num_classes` is forwarded unchanged to the matching `MobileNetV3Config`
/// constructor.
///
/// # Errors
///
/// Propagates any error returned by `MobileNetV3::new`.
pub fn create_backbone(
    backbone_type: BackboneType,
    num_classes: usize,
) -> CnnResult<Box<dyn BackboneExt>> {
    // Only the configuration differs between variants; construction is shared.
    let config = match backbone_type {
        BackboneType::MobileNetV3Small => MobileNetV3Config::small(num_classes),
        BackboneType::MobileNetV3Large => MobileNetV3Config::large(num_classes),
    };
    let model = MobileNetV3::new(config)?;
    Ok(Box::new(model))
}
/// Builds a concrete MobileNetV3-Small network from
/// `MobileNetV3Config::small(num_classes)`.
///
/// # Errors
///
/// Propagates any error returned by `MobileNetV3::new`.
pub fn mobilenet_v3_small(num_classes: usize) -> CnnResult<MobileNetV3> {
    MobileNetV3::new(MobileNetV3Config::small(num_classes))
}
/// Builds a concrete MobileNetV3-Large network from
/// `MobileNetV3Config::large(num_classes)`.
///
/// # Errors
///
/// Propagates any error returned by `MobileNetV3::new`.
pub fn mobilenet_v3_large(num_classes: usize) -> CnnResult<MobileNetV3> {
    MobileNetV3::new(MobileNetV3Config::large(num_classes))
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_TYPES: [BackboneType; 2] =
        [BackboneType::MobileNetV3Small, BackboneType::MobileNetV3Large];

    #[test]
    fn test_backbone_type_output_dim() {
        assert_eq!(BackboneType::MobileNetV3Small.output_dim(), 576);
        assert_eq!(BackboneType::MobileNetV3Large.output_dim(), 960);
    }

    #[test]
    fn test_backbone_type_name() {
        assert_eq!(BackboneType::MobileNetV3Small.name(), "MobileNetV3-Small");
        assert_eq!(BackboneType::MobileNetV3Large.name(), "MobileNetV3-Large");
    }

    #[test]
    fn test_backbone_type_display_matches_name() {
        for ty in ALL_TYPES.iter() {
            assert_eq!(ty.to_string(), ty.name());
        }
    }

    #[test]
    fn test_backbone_type_input_geometry() {
        for ty in ALL_TYPES.iter() {
            assert_eq!(ty.input_size(), (224, 224));
            assert_eq!(ty.input_channels(), 3);
        }
    }

    #[test]
    fn test_create_backbone_small() {
        let backbone = create_backbone(BackboneType::MobileNetV3Small, 1000).unwrap();
        assert_eq!(backbone.backbone_type(), BackboneType::MobileNetV3Small);
        assert_eq!(backbone.output_dim(), 576);
    }

    #[test]
    fn test_create_backbone_large() {
        let backbone = create_backbone(BackboneType::MobileNetV3Large, 1000).unwrap();
        assert_eq!(backbone.backbone_type(), BackboneType::MobileNetV3Large);
        assert_eq!(backbone.output_dim(), 960);
    }

    /// The hard-coded `BackboneType::output_dim` table must agree with what the
    /// factory actually builds for every variant.
    #[test]
    fn test_create_backbone_matches_type_metadata() {
        for &ty in ALL_TYPES.iter() {
            let backbone = create_backbone(ty, 1000).unwrap();
            assert_eq!(backbone.backbone_type(), ty);
            assert_eq!(backbone.output_dim(), ty.output_dim());
        }
    }

    #[test]
    fn test_standalone_constructors() {
        // Explicit trait paths avoid any ambiguity with inherent methods on MobileNetV3.
        let small = mobilenet_v3_small(1000).unwrap();
        assert_eq!(BackboneExt::backbone_type(&small), BackboneType::MobileNetV3Small);
        assert_eq!(Backbone::output_dim(&small), 576);

        let large = mobilenet_v3_large(1000).unwrap();
        assert_eq!(BackboneExt::backbone_type(&large), BackboneType::MobileNetV3Large);
        assert_eq!(Backbone::output_dim(&large), 960);
    }

    #[test]
    fn test_backward_compat_small() {
        let config = MobileNetConfig::default();
        let model = MobileNetV3Small::new(config);
        assert_eq!(model.output_dim(), 576);
    }

    #[test]
    fn test_backward_compat_large() {
        let config = MobileNetConfig {
            output_channels: 960,
            ..Default::default()
        };
        let model = MobileNetV3Large::new(config);
        assert_eq!(model.output_dim(), 960);
    }
}