use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// The role a model plays in the synthesis pipeline.
///
/// `Copy + Eq` so the variant can be matched and compared cheaply;
/// serde derives allow it to round-trip through model manifests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
// Presumably text/phonemes -> acoustic features — inferred from name; confirm against callers.
Acoustic,
// Presumably acoustic features -> waveform — inferred from name; confirm against callers.
Vocoder,
// Presumably grapheme-to-phoneme conversion — inferred from name; confirm against callers.
G2P,
}
/// Metadata describing a single model known to the registry,
/// including identity, capabilities, and local installation state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
// Unique identifier for the model (used as the lookup key by callers).
pub id: String,
// Human-readable display name.
pub name: String,
// Pipeline role of this model (see `ModelType`).
pub model_type: ModelType,
// Language tag, e.g. "en" (exact format not enforced here — confirm convention).
pub language: String,
// Free-form human-readable description.
pub description: String,
// Version string; defaults to "1.0.0" in `ModelInfo::new`.
pub version: String,
// Download/installed size; per the field name this is megabytes — TODO confirm.
pub size_mb: f64,
// Output sample rate in Hz; defaults to 22050 in `ModelInfo::new`.
pub sample_rate: u32,
// Subjective quality rating; defaults to 3.5 (scale not defined here — confirm).
pub quality_score: f32,
// Inference backends this model can run on (e.g. "pytorch"); queried by `supports_backend`.
pub supported_backends: Vec<String>,
// Whether the model is present locally; `installation_path` should be `Some` when true.
pub is_installed: bool,
// Local filesystem path when installed, `None` otherwise.
pub installation_path: Option<String>,
// Arbitrary extra key/value metadata not covered by the typed fields.
pub metadata: HashMap<String, String>,
}
impl ModelInfo {
    /// Creates a new `ModelInfo` with the given identity fields and default
    /// values for everything else: version `"1.0.0"`, size `0.0`, sample rate
    /// `22050` Hz, quality score `3.5`, backend list `["pytorch"]`, not
    /// installed, no installation path, empty metadata.
    ///
    /// Generalized to accept `impl Into<String>` so callers may pass `&str`
    /// or `String`; all existing `String` call sites remain valid.
    pub fn new(
        id: impl Into<String>,
        name: impl Into<String>,
        model_type: ModelType,
        language: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            model_type,
            language: language.into(),
            description: description.into(),
            version: "1.0.0".to_string(),
            size_mb: 0.0,
            sample_rate: 22050,
            quality_score: 3.5,
            supported_backends: vec!["pytorch".to_string()],
            is_installed: false,
            installation_path: None,
            metadata: HashMap::new(),
        }
    }

    /// Returns `true` if `backend` appears (exact, case-sensitive match)
    /// in `supported_backends`.
    pub fn supports_backend(&self, backend: &str) -> bool {
        self.supported_backends.iter().any(|b| b == backend)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` must populate the identity fields verbatim and leave the
    /// model marked as not installed.
    #[test]
    fn test_model_info_creation() {
        let created = ModelInfo::new(
            "test-model".to_string(),
            "Test Model".to_string(),
            ModelType::Acoustic,
            "en".to_string(),
            "A test model".to_string(),
        );
        assert_eq!(created.id, "test-model");
        assert_eq!(created.model_type, ModelType::Acoustic);
        assert!(!created.is_installed);
    }

    /// `supports_backend` must match exactly the entries present in
    /// `supported_backends` and reject anything else.
    #[test]
    fn test_supports_backend() {
        let mut subject = ModelInfo::new(
            "test".to_string(),
            "Test".to_string(),
            ModelType::Vocoder,
            "en".to_string(),
            "Test".to_string(),
        );
        subject.supported_backends = ["pytorch", "onnx"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        for known in &["pytorch", "onnx"] {
            assert!(subject.supports_backend(known));
        }
        assert!(!subject.supports_backend("tensorflow"));
    }
}