use serde::{Deserialize, Serialize};
/// Identifies the kind of a preprocessing step.
///
/// Serde renames every variant to snake_case, so `ImageNormalization` is
/// written as `image_normalization`, `Custom` as `custom`, and so on.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PreprocessingStepType {
    /// Image normalization step.
    ImageNormalization,
    /// Image resize step.
    ImageResize,
    /// Data augmentation step.
    DataAugmentation,
    /// Text tokenization step.
    TextTokenization,
    /// Audio feature step.
    AudioFeatures,
    /// Video frame step.
    VideoFrames,
    /// A step kind not covered by the named variants.
    ///
    /// NOTE(review): the string is presumably a caller-chosen identifier for
    /// the step; nothing in this file interprets it.
    Custom(String),
}
/// Mean / standard-deviation parameters for image normalization.
///
/// `PartialEq` is derived so two configurations can be compared with `==`
/// (for example in tests, or to detect whether a config differs from a
/// preset). `Eq` is intentionally not derived because the fields hold `f32`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageNormalizationConfig {
    /// Mean values. Every preset in this file supplies exactly three entries
    /// (presumably one per colour channel — confirm against the consumer).
    pub mean: Vec<f32>,
    /// Standard-deviation values, supplied in the same count as `mean` by
    /// every preset in this file.
    pub std: Vec<f32>,
    /// Range-normalization flag; every preset in this file sets it to `true`.
    ///
    /// NOTE(review): presumably requests scaling pixel values into a unit
    /// range before `mean`/`std` are applied — confirm against the consumer.
    pub normalize_range: bool,
}
impl Default for ImageNormalizationConfig {
    /// Returns the default configuration: mean `[0.485, 0.456, 0.406]`,
    /// std `[0.229, 0.224, 0.225]`, with `normalize_range` enabled.
    fn default() -> Self {
        let mean = Vec::from([0.485_f32, 0.456, 0.406]);
        let std = Vec::from([0.229_f32, 0.224, 0.225]);
        Self {
            mean,
            std,
            normalize_range: true,
        }
    }
}
impl ImageNormalizationConfig {
    /// Mean triple shared by the `imagenet`, `dinov2`, `mobilenet` and
    /// `efficientnet` presets.
    const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
    /// Std triple shared by the `imagenet`, `dinov2`, `mobilenet` and
    /// `efficientnet` presets.
    const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
    /// Triple used for both mean and std by the `vit` and `inception` presets.
    const HALF: [f32; 3] = [0.5, 0.5, 0.5];

    /// Builds a preset from fixed triples; every preset enables
    /// `normalize_range`.
    fn preset(mean: [f32; 3], std: [f32; 3]) -> Self {
        Self::custom(mean.to_vec(), std.to_vec(), true)
    }

    /// ImageNet preset: mean `[0.485, 0.456, 0.406]`,
    /// std `[0.229, 0.224, 0.225]`, `normalize_range = true`.
    pub fn imagenet() -> Self {
        Self::preset(Self::IMAGENET_MEAN, Self::IMAGENET_STD)
    }

    /// CLIP preset: mean `[0.4814547, 0.4578275, 0.4082107]`,
    /// std `[0.2686295, 0.2613026, 0.2757771]`, `normalize_range = true`.
    pub fn clip() -> Self {
        Self::preset(
            [0.481_454_7, 0.457_827_5, 0.408_210_7],
            [0.268_629_5, 0.261_302_6, 0.275_777_1],
        )
    }

    /// DINOv2 preset; uses the same values as [`Self::imagenet`].
    pub fn dinov2() -> Self {
        Self::preset(Self::IMAGENET_MEAN, Self::IMAGENET_STD)
    }

    /// ViT preset: mean and std are both `[0.5, 0.5, 0.5]`,
    /// `normalize_range = true`.
    pub fn vit() -> Self {
        Self::preset(Self::HALF, Self::HALF)
    }

    /// Inception preset: mean and std are both `[0.5, 0.5, 0.5]`,
    /// `normalize_range = true`.
    pub fn inception() -> Self {
        Self::preset(Self::HALF, Self::HALF)
    }

    /// MobileNet preset; uses the same values as [`Self::imagenet`].
    pub fn mobilenet() -> Self {
        Self::preset(Self::IMAGENET_MEAN, Self::IMAGENET_STD)
    }

    /// EfficientNet preset; uses the same values as [`Self::imagenet`].
    pub fn efficientnet() -> Self {
        Self::preset(Self::IMAGENET_MEAN, Self::IMAGENET_STD)
    }

    /// Builds a configuration from caller-supplied values.
    ///
    /// The arguments are stored as given; no validation (such as matching
    /// lengths of `mean` and `std`) is performed here.
    pub fn custom(mean: Vec<f32>, std: Vec<f32>, normalize_range: bool) -> Self {
        Self {
            mean,
            std,
            normalize_range,
        }
    }
}
/// Selects how an image is resized to the configured dimensions.
///
/// Serde renames the variants to snake_case (`exact`, `fit`, `fill`,
/// `stretch`). The precise behaviour of each mode is defined by the code
/// that performs the resize, which is not part of this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DatasetResizeMode {
    /// `exact` mode.
    Exact,
    /// `fit` mode; used by every resize preset in this file.
    Fit,
    /// `fill` mode.
    Fill,
    /// `stretch` mode.
    Stretch,
}
/// Target dimensions and resampling options for an image resize step.
///
/// `PartialEq` and `Eq` are derived (all field types support them) so that
/// configurations can be compared with `==`, matching the enums in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ImageResizeConfig {
    /// Target width (presumably in pixels — confirm against the consumer).
    pub width: u32,
    /// Target height (presumably in pixels — confirm against the consumer).
    pub height: u32,
    /// Resize mode to apply.
    pub mode: DatasetResizeMode,
    /// Name of the resampling filter; the presets in this file use
    /// `"bilinear"` and `"bicubic"`.
    pub filter: String,
}
impl Default for ImageResizeConfig {
    /// Returns the default configuration: 224×224, `Fit` mode, `"bilinear"`
    /// filter.
    fn default() -> Self {
        // Width and height share one value because the default target is square.
        let side = 224;
        Self {
            width: side,
            height: side,
            mode: DatasetResizeMode::Fit,
            filter: String::from("bilinear"),
        }
    }
}
impl ImageResizeConfig {
    /// Builds a square preset of `side`×`side` in `Fit` mode with the given
    /// filter name; every named preset below has this shape.
    fn square_fit(side: u32, filter: &str) -> Self {
        Self::custom(side, side, DatasetResizeMode::Fit, filter)
    }

    /// ResNet preset: 224×224, `Fit`, `"bilinear"`.
    pub fn resnet() -> Self {
        Self::square_fit(224, "bilinear")
    }

    /// CLIP preset: 224×224, `Fit`, `"bicubic"`.
    pub fn clip() -> Self {
        Self::square_fit(224, "bicubic")
    }

    /// DINOv2 preset: 518×518, `Fit`, `"bicubic"`.
    pub fn dinov2() -> Self {
        Self::square_fit(518, "bicubic")
    }

    /// ViT-Base preset: 224×224, `Fit`, `"bicubic"`.
    pub fn vit_base() -> Self {
        Self::square_fit(224, "bicubic")
    }

    /// ViT-Large preset: 384×384, `Fit`, `"bicubic"`.
    pub fn vit_large() -> Self {
        Self::square_fit(384, "bicubic")
    }

    /// Inception-v3 preset: 299×299, `Fit`, `"bicubic"`.
    pub fn inception_v3() -> Self {
        Self::square_fit(299, "bicubic")
    }

    /// EfficientNet-B0 preset: 224×224, `Fit`, `"bicubic"`.
    pub fn efficientnet_b0() -> Self {
        Self::square_fit(224, "bicubic")
    }

    /// EfficientNet-B7 preset: 600×600, `Fit`, `"bicubic"`.
    pub fn efficientnet_b7() -> Self {
        Self::square_fit(600, "bicubic")
    }

    /// YOLO preset: 640×640, `Fit`, `"bilinear"`.
    pub fn yolo() -> Self {
        Self::square_fit(640, "bilinear")
    }

    /// Builds a configuration from caller-supplied values.
    ///
    /// `filter` is copied into an owned `String`; none of the arguments are
    /// validated here.
    pub fn custom(width: u32, height: u32, mode: DatasetResizeMode, filter: &str) -> Self {
        let filter = String::from(filter);
        Self {
            width,
            height,
            mode,
            filter,
        }
    }
}