use std::collections::HashMap;
use torsh_core::{device::DeviceType, error::Result};
use torsh_nn::prelude::*;
use torsh_nn::{Module, Parameter};
use torsh_tensor::Tensor;
use crate::{ModelConfig, VisionModel};
pub struct MobileNetV1 {
conv1: Conv2d,
conv2: Conv2d,
fc: Linear,
num_classes: usize,
is_training: bool,
}
impl MobileNetV1 {
pub fn new(config: ModelConfig, _width_mult: f32) -> Self {
Self {
conv1: Conv2d::new(3, 32, (3, 3), (2, 2), (1, 1), (1, 1), false, 1),
conv2: Conv2d::new(32, 64, (3, 3), (1, 1), (1, 1), (1, 1), false, 1),
fc: Linear::new(64, config.num_classes, true),
num_classes: config.num_classes,
is_training: true,
}
}
pub fn mobilenet_v1_1_0(config: ModelConfig) -> Self {
Self::new(config, 1.0)
}
pub fn mobilenet_v1_0_75(config: ModelConfig) -> Self {
Self::new(config, 0.75)
}
pub fn mobilenet_v1_0_5(config: ModelConfig) -> Self {
Self::new(config, 0.5)
}
pub fn mobilenet_v1_0_25(config: ModelConfig) -> Self {
Self::new(config, 0.25)
}
}
impl Module for MobileNetV1 {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let mut x = self.conv1.forward(input)?;
x = self.conv2.forward(&x)?;
self.fc.forward(&x)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.extend(self.conv1.parameters());
params.extend(self.conv2.parameters());
params.extend(self.fc.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.is_training
}
fn train(&mut self) {
self.is_training = true;
self.conv1.train();
self.conv2.train();
self.fc.train();
}
fn eval(&mut self) {
self.is_training = false;
self.conv1.eval();
self.conv2.eval();
self.fc.eval();
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.conv1.to_device(device)?;
self.conv2.to_device(device)?;
self.fc.to_device(device)?;
Ok(())
}
fn set_training(&mut self, training: bool) {
if training {
self.train();
} else {
self.eval();
}
}
}
impl VisionModel for MobileNetV1 {
fn num_classes(&self) -> usize {
self.num_classes
}
fn input_size(&self) -> (usize, usize) {
(224, 224)
}
fn name(&self) -> &str {
"MobileNetV1"
}
}
pub struct MobileNetV2 {
conv1: Conv2d,
conv2: Conv2d,
fc: Linear,
num_classes: usize,
is_training: bool,
}
impl MobileNetV2 {
pub fn new(config: ModelConfig, _width_mult: f32) -> Self {
Self {
conv1: Conv2d::new(3, 32, (3, 3), (2, 2), (1, 1), (1, 1), false, 1),
conv2: Conv2d::new(32, 64, (3, 3), (1, 1), (1, 1), (1, 1), false, 1),
fc: Linear::new(64, config.num_classes, true),
num_classes: config.num_classes,
is_training: true,
}
}
pub fn mobilenet_v2_1_0(config: ModelConfig) -> Self {
Self::new(config, 1.0)
}
pub fn mobilenet_v2_0_75(config: ModelConfig) -> Self {
Self::new(config, 0.75)
}
pub fn mobilenet_v2_0_5(config: ModelConfig) -> Self {
Self::new(config, 0.5)
}
}
impl Module for MobileNetV2 {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let mut x = self.conv1.forward(input)?;
x = self.conv2.forward(&x)?;
self.fc.forward(&x)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.extend(self.conv1.parameters());
params.extend(self.conv2.parameters());
params.extend(self.fc.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.is_training
}
fn train(&mut self) {
self.is_training = true;
self.conv1.train();
self.conv2.train();
self.fc.train();
}
fn eval(&mut self) {
self.is_training = false;
self.conv1.eval();
self.conv2.eval();
self.fc.eval();
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.conv1.to_device(device)?;
self.conv2.to_device(device)?;
self.fc.to_device(device)?;
Ok(())
}
fn set_training(&mut self, training: bool) {
if training {
self.train();
} else {
self.eval();
}
}
}
impl VisionModel for MobileNetV2 {
fn num_classes(&self) -> usize {
self.num_classes
}
fn input_size(&self) -> (usize, usize) {
(224, 224)
}
fn name(&self) -> &str {
"MobileNetV2"
}
}