charon_audio/
model_zoo.rs

1//! Pre-trained model zoo for easy model access
2
3use crate::error::{CharonError, Result};
4use crate::models::ModelConfig;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8
9/// Model metadata
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelMetadata {
12    pub name: String,
13    pub version: String,
14    pub description: String,
15    pub sources: Vec<String>,
16    pub sample_rate: u32,
17    pub channels: usize,
18    pub file_size_mb: f64,
19    pub download_url: Option<String>,
20}
21
22/// Pre-trained model zoo
23pub struct ModelZoo {
24    models_dir: PathBuf,
25    registry: HashMap<String, ModelMetadata>,
26}
27
28impl ModelZoo {
29    /// Create new model zoo
30    pub fn new<P: AsRef<Path>>(models_dir: P) -> Result<Self> {
31        let models_dir = models_dir.as_ref().to_path_buf();
32        std::fs::create_dir_all(&models_dir)?;
33
34        let mut zoo = Self {
35            models_dir,
36            registry: HashMap::new(),
37        };
38
39        zoo.register_builtin_models();
40        Ok(zoo)
41    }
42
43    /// Register built-in models
44    fn register_builtin_models(&mut self) {
45        self.registry.insert(
46            "demucs-4stems".to_string(),
47            ModelMetadata {
48                name: "demucs-4stems".to_string(),
49                version: "1.0.0".to_string(),
50                description: "Demucs 4-stem separation (drums, bass, vocals, other)".to_string(),
51                sources: vec![
52                    "drums".to_string(),
53                    "bass".to_string(),
54                    "vocals".to_string(),
55                    "other".to_string(),
56                ],
57                sample_rate: 44100,
58                channels: 2,
59                file_size_mb: 150.0,
60                download_url: Some("https://example.com/models/demucs-4stems.onnx".to_string()),
61            },
62        );
63
64        self.registry.insert(
65            "demucs-6stems".to_string(),
66            ModelMetadata {
67                name: "demucs-6stems".to_string(),
68                version: "1.0.0".to_string(),
69                description: "Demucs 6-stem separation (drums, bass, vocals, other, piano, guitar)"
70                    .to_string(),
71                sources: vec![
72                    "drums".to_string(),
73                    "bass".to_string(),
74                    "vocals".to_string(),
75                    "other".to_string(),
76                    "piano".to_string(),
77                    "guitar".to_string(),
78                ],
79                sample_rate: 44100,
80                channels: 2,
81                file_size_mb: 200.0,
82                download_url: Some("https://example.com/models/demucs-6stems.onnx".to_string()),
83            },
84        );
85
86        self.registry.insert(
87            "vocals-only".to_string(),
88            ModelMetadata {
89                name: "vocals-only".to_string(),
90                version: "1.0.0".to_string(),
91                description: "Optimized vocal extraction model".to_string(),
92                sources: vec!["vocals".to_string(), "instrumental".to_string()],
93                sample_rate: 44100,
94                channels: 2,
95                file_size_mb: 80.0,
96                download_url: Some("https://example.com/models/vocals-only.onnx".to_string()),
97            },
98        );
99    }
100
101    /// List available models
102    pub fn list_models(&self) -> Vec<&ModelMetadata> {
103        self.registry.values().collect()
104    }
105
106    /// Get model metadata by name
107    pub fn get_metadata(&self, name: &str) -> Option<&ModelMetadata> {
108        self.registry.get(name)
109    }
110
111    /// Check if model is downloaded
112    pub fn is_downloaded(&self, name: &str) -> bool {
113        self.get_model_path(name).is_some_and(|p| p.exists())
114    }
115
116    /// Get model path
117    pub fn get_model_path(&self, name: &str) -> Option<PathBuf> {
118        let onnx_path = self.models_dir.join(format!("{name}.onnx"));
119        if onnx_path.exists() {
120            return Some(onnx_path);
121        }
122
123        let safetensors_path = self.models_dir.join(format!("{name}.safetensors"));
124        if safetensors_path.exists() {
125            return Some(safetensors_path);
126        }
127
128        None
129    }
130
131    /// Download model (placeholder - requires actual HTTP client)
132    pub fn download_model(&self, name: &str) -> Result<PathBuf> {
133        let metadata = self
134            .get_metadata(name)
135            .ok_or_else(|| CharonError::NotSupported(format!("Model {name} not found")))?;
136
137        let download_url = metadata
138            .download_url
139            .as_ref()
140            .ok_or_else(|| CharonError::NotSupported("No download URL available".to_string()))?;
141
142        let target_path = self.models_dir.join(format!("{name}.onnx"));
143
144        if target_path.exists() {
145            return Ok(target_path);
146        }
147
148        Err(CharonError::NotSupported(format!(
149            "Model download not implemented. Please manually download from: {download_url}"
150        )))
151    }
152
153    /// Load model configuration
154    pub fn load_model(&self, name: &str) -> Result<ModelConfig> {
155        let metadata = self
156            .get_metadata(name)
157            .ok_or_else(|| CharonError::NotSupported(format!("Model {name} not found")))?;
158
159        let model_path = self
160            .get_model_path(name)
161            .ok_or_else(|| CharonError::NotSupported(format!("Model {name} not downloaded")))?;
162
163        Ok(ModelConfig {
164            model_path,
165            #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
166            backend: None,
167            sample_rate: metadata.sample_rate,
168            channels: metadata.channels,
169            sources: metadata.sources.clone(),
170            chunk_size: Some(441000),
171        })
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Return a fresh, per-test, per-process directory under the system temp dir.
    ///
    /// The tests run in parallel, so they must not share on-disk state. Any leftover
    /// from a previous run is removed first.
    fn temp_zoo_dir(test_name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "charon_test_zoo_{test_name}_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn test_model_zoo_creation() {
        let dir = temp_zoo_dir("creation");
        let zoo = ModelZoo::new(&dir).unwrap();
        assert!(dir.is_dir(), "ModelZoo::new should create the models directory");
        assert!(!zoo.list_models().is_empty());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_model_metadata() {
        let dir = temp_zoo_dir("metadata");
        let zoo = ModelZoo::new(&dir).unwrap();
        let metadata = zoo
            .get_metadata("demucs-4stems")
            .expect("built-in model should be registered");
        assert_eq!(metadata.sources.len(), 4);
        assert!(zoo.get_metadata("no-such-model").is_none());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_missing_model_is_not_downloaded() {
        let dir = temp_zoo_dir("missing");
        let zoo = ModelZoo::new(&dir).unwrap();
        assert!(!zoo.is_downloaded("demucs-4stems"));
        assert!(zoo.get_model_path("demucs-4stems").is_none());
        assert!(zoo.load_model("demucs-4stems").is_err());
        let _ = std::fs::remove_dir_all(&dir);
    }
}