torsh-vision 0.1.2

Computer vision utilities for ToRSh deep learning framework
Documentation
use std::collections::HashMap;
use torsh_core::{device::DeviceType, error::Result};
use torsh_nn::prelude::*;
use torsh_nn::{Module, Parameter};
use torsh_tensor::Tensor;

use crate::{ModelConfig, VisionModel};

/// MobileNetV1 model - placeholder implementation
pub struct MobileNetV1 {
    conv1: Conv2d,
    conv2: Conv2d,
    fc: Linear,
    num_classes: usize,
    is_training: bool,
}

impl MobileNetV1 {
    pub fn new(config: ModelConfig, _width_mult: f32) -> Self {
        Self {
            conv1: Conv2d::new(3, 32, (3, 3), (2, 2), (1, 1), (1, 1), false, 1),
            conv2: Conv2d::new(32, 64, (3, 3), (1, 1), (1, 1), (1, 1), false, 1),
            fc: Linear::new(64, config.num_classes, true),
            num_classes: config.num_classes,
            is_training: true,
        }
    }

    pub fn mobilenet_v1_1_0(config: ModelConfig) -> Self {
        Self::new(config, 1.0)
    }

    pub fn mobilenet_v1_0_75(config: ModelConfig) -> Self {
        Self::new(config, 0.75)
    }

    pub fn mobilenet_v1_0_5(config: ModelConfig) -> Self {
        Self::new(config, 0.5)
    }

    pub fn mobilenet_v1_0_25(config: ModelConfig) -> Self {
        Self::new(config, 0.25)
    }
}

impl Module for MobileNetV1 {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let mut x = self.conv1.forward(input)?;
        x = self.conv2.forward(&x)?;
        self.fc.forward(&x)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.extend(self.conv1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.fc.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }

    fn training(&self) -> bool {
        self.is_training
    }

    fn train(&mut self) {
        self.is_training = true;
        self.conv1.train();
        self.conv2.train();
        self.fc.train();
    }

    fn eval(&mut self) {
        self.is_training = false;
        self.conv1.eval();
        self.conv2.eval();
        self.fc.eval();
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.conv1.to_device(device)?;
        self.conv2.to_device(device)?;
        self.fc.to_device(device)?;
        Ok(())
    }

    fn set_training(&mut self, training: bool) {
        if training {
            self.train();
        } else {
            self.eval();
        }
    }
}

impl VisionModel for MobileNetV1 {
    fn num_classes(&self) -> usize {
        self.num_classes
    }

    fn input_size(&self) -> (usize, usize) {
        (224, 224)
    }

    fn name(&self) -> &str {
        "MobileNetV1"
    }
}

/// MobileNetV2 model - placeholder implementation
pub struct MobileNetV2 {
    conv1: Conv2d,
    conv2: Conv2d,
    fc: Linear,
    num_classes: usize,
    is_training: bool,
}

impl MobileNetV2 {
    pub fn new(config: ModelConfig, _width_mult: f32) -> Self {
        Self {
            conv1: Conv2d::new(3, 32, (3, 3), (2, 2), (1, 1), (1, 1), false, 1),
            conv2: Conv2d::new(32, 64, (3, 3), (1, 1), (1, 1), (1, 1), false, 1),
            fc: Linear::new(64, config.num_classes, true),
            num_classes: config.num_classes,
            is_training: true,
        }
    }

    pub fn mobilenet_v2_1_0(config: ModelConfig) -> Self {
        Self::new(config, 1.0)
    }

    pub fn mobilenet_v2_0_75(config: ModelConfig) -> Self {
        Self::new(config, 0.75)
    }

    pub fn mobilenet_v2_0_5(config: ModelConfig) -> Self {
        Self::new(config, 0.5)
    }
}

impl Module for MobileNetV2 {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let mut x = self.conv1.forward(input)?;
        x = self.conv2.forward(&x)?;
        self.fc.forward(&x)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.extend(self.conv1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.fc.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }

    fn training(&self) -> bool {
        self.is_training
    }

    fn train(&mut self) {
        self.is_training = true;
        self.conv1.train();
        self.conv2.train();
        self.fc.train();
    }

    fn eval(&mut self) {
        self.is_training = false;
        self.conv1.eval();
        self.conv2.eval();
        self.fc.eval();
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.conv1.to_device(device)?;
        self.conv2.to_device(device)?;
        self.fc.to_device(device)?;
        Ok(())
    }

    fn set_training(&mut self, training: bool) {
        if training {
            self.train();
        } else {
            self.eval();
        }
    }
}

impl VisionModel for MobileNetV2 {
    fn num_classes(&self) -> usize {
        self.num_classes
    }

    fn input_size(&self) -> (usize, usize) {
        (224, 224)
    }

    fn name(&self) -> &str {
        "MobileNetV2"
    }
}