charon_audio/
model_zoo.rs1use crate::error::{CharonError, Result};
4use crate::models::ModelConfig;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelMetadata {
12 pub name: String,
13 pub version: String,
14 pub description: String,
15 pub sources: Vec<String>,
16 pub sample_rate: u32,
17 pub channels: usize,
18 pub file_size_mb: f64,
19 pub download_url: Option<String>,
20}
21
22pub struct ModelZoo {
24 models_dir: PathBuf,
25 registry: HashMap<String, ModelMetadata>,
26}
27
28impl ModelZoo {
29 pub fn new<P: AsRef<Path>>(models_dir: P) -> Result<Self> {
31 let models_dir = models_dir.as_ref().to_path_buf();
32 std::fs::create_dir_all(&models_dir)?;
33
34 let mut zoo = Self {
35 models_dir,
36 registry: HashMap::new(),
37 };
38
39 zoo.register_builtin_models();
40 Ok(zoo)
41 }
42
43 fn register_builtin_models(&mut self) {
45 self.registry.insert(
46 "demucs-4stems".to_string(),
47 ModelMetadata {
48 name: "demucs-4stems".to_string(),
49 version: "1.0.0".to_string(),
50 description: "Demucs 4-stem separation (drums, bass, vocals, other)".to_string(),
51 sources: vec![
52 "drums".to_string(),
53 "bass".to_string(),
54 "vocals".to_string(),
55 "other".to_string(),
56 ],
57 sample_rate: 44100,
58 channels: 2,
59 file_size_mb: 150.0,
60 download_url: Some("https://example.com/models/demucs-4stems.onnx".to_string()),
61 },
62 );
63
64 self.registry.insert(
65 "demucs-6stems".to_string(),
66 ModelMetadata {
67 name: "demucs-6stems".to_string(),
68 version: "1.0.0".to_string(),
69 description: "Demucs 6-stem separation (drums, bass, vocals, other, piano, guitar)"
70 .to_string(),
71 sources: vec![
72 "drums".to_string(),
73 "bass".to_string(),
74 "vocals".to_string(),
75 "other".to_string(),
76 "piano".to_string(),
77 "guitar".to_string(),
78 ],
79 sample_rate: 44100,
80 channels: 2,
81 file_size_mb: 200.0,
82 download_url: Some("https://example.com/models/demucs-6stems.onnx".to_string()),
83 },
84 );
85
86 self.registry.insert(
87 "vocals-only".to_string(),
88 ModelMetadata {
89 name: "vocals-only".to_string(),
90 version: "1.0.0".to_string(),
91 description: "Optimized vocal extraction model".to_string(),
92 sources: vec!["vocals".to_string(), "instrumental".to_string()],
93 sample_rate: 44100,
94 channels: 2,
95 file_size_mb: 80.0,
96 download_url: Some("https://example.com/models/vocals-only.onnx".to_string()),
97 },
98 );
99 }
100
101 pub fn list_models(&self) -> Vec<&ModelMetadata> {
103 self.registry.values().collect()
104 }
105
106 pub fn get_metadata(&self, name: &str) -> Option<&ModelMetadata> {
108 self.registry.get(name)
109 }
110
111 pub fn is_downloaded(&self, name: &str) -> bool {
113 self.get_model_path(name).is_some_and(|p| p.exists())
114 }
115
116 pub fn get_model_path(&self, name: &str) -> Option<PathBuf> {
118 let onnx_path = self.models_dir.join(format!("{name}.onnx"));
119 if onnx_path.exists() {
120 return Some(onnx_path);
121 }
122
123 let safetensors_path = self.models_dir.join(format!("{name}.safetensors"));
124 if safetensors_path.exists() {
125 return Some(safetensors_path);
126 }
127
128 None
129 }
130
131 pub fn download_model(&self, name: &str) -> Result<PathBuf> {
133 let metadata = self
134 .get_metadata(name)
135 .ok_or_else(|| CharonError::NotSupported(format!("Model {name} not found")))?;
136
137 let download_url = metadata
138 .download_url
139 .as_ref()
140 .ok_or_else(|| CharonError::NotSupported("No download URL available".to_string()))?;
141
142 let target_path = self.models_dir.join(format!("{name}.onnx"));
143
144 if target_path.exists() {
145 return Ok(target_path);
146 }
147
148 Err(CharonError::NotSupported(format!(
149 "Model download not implemented. Please manually download from: {download_url}"
150 )))
151 }
152
153 pub fn load_model(&self, name: &str) -> Result<ModelConfig> {
155 let metadata = self
156 .get_metadata(name)
157 .ok_or_else(|| CharonError::NotSupported(format!("Model {name} not found")))?;
158
159 let model_path = self
160 .get_model_path(name)
161 .ok_or_else(|| CharonError::NotSupported(format!("Model {name} not downloaded")))?;
162
163 Ok(ModelConfig {
164 model_path,
165 #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
166 backend: None,
167 sample_rate: metadata.sample_rate,
168 channels: metadata.channels,
169 sources: metadata.sources.clone(),
170 chunk_size: Some(441000),
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-test scratch directory, removed on drop.
    ///
    /// Tests run in parallel, so each one gets its own directory (keyed by process id and a
    /// label) instead of sharing a fixed path, and nothing is left behind in the temp dir.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(label: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "charon_test_zoo_{}_{label}",
                std::process::id()
            ));
            // Start clean in case an earlier run with a recycled pid left files behind.
            let _ = std::fs::remove_dir_all(&dir);
            Self(dir)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_model_zoo_creation() {
        let dir = ScratchDir::new("creation");
        let zoo = ModelZoo::new(&dir.0).unwrap();
        assert!(dir.0.is_dir());
        assert!(!zoo.list_models().is_empty());
    }

    #[test]
    fn test_model_metadata() {
        let dir = ScratchDir::new("metadata");
        let zoo = ModelZoo::new(&dir.0).unwrap();
        let metadata = zoo.get_metadata("demucs-4stems");
        assert!(metadata.is_some());
        assert_eq!(metadata.unwrap().sources.len(), 4);
        assert_eq!(zoo.get_metadata("demucs-6stems").unwrap().sources.len(), 6);
        assert!(zoo.get_metadata("no-such-model").is_none());
    }

    #[test]
    fn test_missing_weights_are_reported() {
        let dir = ScratchDir::new("missing");
        let zoo = ModelZoo::new(&dir.0).unwrap();
        assert!(!zoo.is_downloaded("demucs-4stems"));
        assert!(zoo.get_model_path("demucs-4stems").is_none());
        assert!(zoo.load_model("demucs-4stems").is_err());
        assert!(zoo.download_model("demucs-4stems").is_err());
        assert!(zoo.download_model("no-such-model").is_err());
    }

    #[test]
    fn test_existing_onnx_weights_are_used() {
        let dir = ScratchDir::new("onnx");
        let zoo = ModelZoo::new(&dir.0).unwrap();
        let path = dir.0.join("vocals-only.onnx");
        std::fs::write(&path, b"").unwrap();

        assert!(zoo.is_downloaded("vocals-only"));
        assert_eq!(zoo.get_model_path("vocals-only"), Some(path.clone()));
        assert_eq!(zoo.download_model("vocals-only").unwrap(), path);

        let config = zoo.load_model("vocals-only").unwrap();
        assert_eq!(config.model_path, path);
        assert_eq!(
            config.sources,
            vec!["vocals".to_string(), "instrumental".to_string()]
        );
    }
}