use crate::error::{CharonError, Result};
#[cfg(feature = "candle-backend")]
use candle_core::{Device, Tensor};
use ndarray::Array2;
#[cfg(feature = "ort-backend")]
use ort::session::{
    builder::{GraphOptimizationLevel, SessionBuilder},
    Session,
};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

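/// Inference backend used to run the separation model.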
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelBackend {
    /// ONNX Runtime, via the `ort` crate.
    #[cfg(feature = "ort-backend")]
    OnnxRuntime,
    /// Hugging Face Candle (`candle_core`).
    #[cfg(feature = "candle-backend")]
    Candle,
}

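/// Configuration describing a separation model and the audio it expects.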
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Path to the model file (`.onnx` or `.safetensors`).
    pub model_path: PathBuf,
    /// Explicit backend choice; when `None`, the backend is auto-detected
    /// from the file extension.
    #[serde(skip, default)]
    #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
    pub backend: Option<ModelBackend>,
    /// Sample rate the model expects, in Hz.
    pub sample_rate: u32,
    /// Number of audio channels the model expects.
    pub channels: usize,
    /// Names of the sources (stems) the model separates.
    pub sources: Vec<String>,
    /// Optional chunk length, in samples, for processing audio in segments.
    pub chunk_size: Option<usize>,
}

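// Defaults target 44.1 kHz stereo with the common four-stem split and
// 441_000-sample (10 s) chunks.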
impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            model_path: PathBuf::from("model.onnx"),
            #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
            backend: None,
            sample_rate: 44100,
            channels: 2,
            sources: vec![
                "drums".to_string(),
                "bass".to_string(),
                "vocals".to_string(),
                "other".to_string(),
            ],
            chunk_size: Some(441000),
        }
    }
}

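/// Source-separation model backed by ONNX Runtime.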
#[cfg(feature = "ort-backend")]
pub struct OnnxModel {
    #[allow(dead_code)]
    session: Session,
    config: ModelConfig,
}

#[cfg(feature = "ort-backend")]
impl OnnxModel {
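    /// Builds an ONNX Runtime session from the model file in `config`.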
    pub fn new(config: ModelConfig) -> Result<Self> {
        let session = SessionBuilder::new()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?
            .commit_from_file(&config.model_path)?;

        Ok(Self { session, config })
    }

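    /// Runs separation on a `(channels, samples)` buffer.
    ///
    /// Currently a passthrough stub: each configured source receives a copy
    /// of the input until real ONNX inference is wired in.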
    pub fn infer(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        let num_sources = self.config.sources.len();
        // Placeholder: real ONNX inference is not implemented yet, so every
        // source receives an unmodified copy of the input.
        let separated = vec![input.clone(); num_sources];

        Ok(separated)
    }
}

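/// Source-separation model backed by Candle.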
#[cfg(feature = "candle-backend")]
pub struct CandleModel {
    /// Compute device (CPU, or CUDA when available).
    device: Device,
    config: ModelConfig,
    /// Loaded weights, if the model file existed at construction time.
    model: Option<candle_nn::VarMap>,
}

#[cfg(feature = "candle-backend")]
impl CandleModel {
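    /// Selects a device and loads safetensors weights if the model file exists.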
    pub fn new(config: ModelConfig) -> Result<Self> {
        use candle_core::safetensors;

        // Prefer CUDA when available; wasm builds always run on the CPU.
        let device = if cfg!(target_arch = "wasm32") {
            Device::Cpu
        } else {
            Device::cuda_if_available(0).unwrap_or(Device::Cpu)
        };

        // Load safetensors weights into a VarMap when the file exists;
        // otherwise the model runs as a weightless placeholder.
        let model = if config.model_path.exists() {
            let tensors = safetensors::load(&config.model_path, &device)?;
            let mut varmap = candle_nn::VarMap::new();
            for (name, tensor) in tensors {
                varmap
                    .data()
                    .lock()
                    .unwrap()
                    .insert(name, candle_nn::Var::from_tensor(&tensor)?);
            }
            Some(varmap)
        } else {
            None
        };

        Ok(Self {
            device,
            config,
            model,
        })
    }

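    /// Runs separation on a `(channels, samples)` buffer.
    ///
    /// Currently a passthrough stub: the input is round-tripped through a
    /// Candle tensor and each configured source receives a copy of it.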
    pub fn infer(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        let (channels, samples) = (input.nrows(), input.ncols());
        // Interleave the (channels, samples) array sample-major for the tensor.
        let data: Vec<f32> = input.t().iter().copied().collect();

        let tensor = Tensor::from_vec(data, (samples, channels), &self.device)?;

        // Placeholder: no separation network is wired up yet, so the output is
        // the input tensor whether or not weights were loaded.
        let output = if let Some(ref _model) = self.model {
            tensor.clone()
        } else {
            tensor.clone()
        };

        // Rebuild a (channels, samples) array from the sample-major data and
        // return one copy per source, mirroring the ONNX passthrough stub.
        let output_data: Vec<f32> = output.flatten_all()?.to_vec1()?;
        let mut source_array = Array2::zeros((channels, samples));
        for (idx, &val) in output_data.iter().enumerate() {
            let ch = idx % channels;
            let samp = idx / channels;
            if samp < samples {
                source_array[[ch, samp]] = val;
            }
        }

        let num_sources = self.config.sources.len();
        Ok(vec![source_array; num_sources])
    }
}

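/// A loaded separation model, dispatching to whichever backend is enabled.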
pub enum Model {
    #[cfg(feature = "ort-backend")]
    Onnx(OnnxModel),
    #[cfg(feature = "candle-backend")]
    Candle(CandleModel),
}

impl Model {
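    /// Loads a model, auto-detecting the backend from the file extension when
    /// `config.backend` is not set.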
    pub fn from_config(config: ModelConfig) -> Result<Self> {
        // Auto-detect when no backend is set explicitly: `.onnx` files prefer
        // ONNX Runtime, anything else falls back to Candle (when enabled).
        #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
        let backend = config.backend.or_else(|| {
            if config.model_path.extension()?.to_str()? == "onnx" {
                #[cfg(feature = "ort-backend")]
                return Some(ModelBackend::OnnxRuntime);
            }
            #[cfg(feature = "candle-backend")]
            return Some(ModelBackend::Candle);
            #[allow(unreachable_code)]
            None
        });

        #[cfg(feature = "ort-backend")]
        if matches!(backend, Some(ModelBackend::OnnxRuntime)) {
            return Ok(Model::Onnx(OnnxModel::new(config)?));
        }
        #[cfg(feature = "candle-backend")]
        if matches!(backend, Some(ModelBackend::Candle)) {
            return Ok(Model::Candle(CandleModel::new(config)?));
        }
        Err(CharonError::NotSupported(
            "No ML backend enabled or auto-detected".to_string(),
        ))
    }

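    /// Runs inference on a `(channels, samples)` buffer, returning one buffer
    /// per separated source.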
    #[allow(unreachable_patterns)]
    pub fn infer(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        match self {
            #[cfg(feature = "ort-backend")]
            Model::Onnx(model) => model.infer(input),
            #[cfg(feature = "candle-backend")]
            Model::Candle(model) => model.infer(input),
            #[allow(unreachable_patterns)]
            _ => Err(CharonError::NotSupported(
                "No model backend available".to_string(),
            )),
        }
    }

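    /// Returns the configuration of the underlying model.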
    #[allow(unreachable_patterns)]
    pub fn config(&self) -> &ModelConfig {
        match self {
            #[cfg(feature = "ort-backend")]
            Model::Onnx(model) => &model.config,
            #[cfg(feature = "candle-backend")]
            Model::Candle(model) => &model.config,
            #[allow(unreachable_patterns)]
            _ => panic!("No model backend available"),
        }
    }
}

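/// Discovers model files (`.onnx`, `.safetensors`) in a directory.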
pub struct ModelRegistry {
    /// Directory scanned for model files.
    models_dir: PathBuf,
}

impl ModelRegistry {
    pub fn new<P: AsRef<Path>>(models_dir: P) -> Self {
        Self {
            models_dir: models_dir.as_ref().to_path_buf(),
        }
    }

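    /// Lists model names (file names without extension) found in the models directory.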
    pub fn list_models(&self) -> Result<Vec<String>> {
        let mut models = Vec::new();

        if !self.models_dir.exists() {
            return Ok(models);
        }

        for entry in std::fs::read_dir(&self.models_dir)? {
            let entry = entry?;
            let path = entry.path();
            if path.is_file() {
                if let Some(ext) = path.extension() {
                    if ext == "onnx" || ext == "safetensors" {
                        if let Some(name) = path.file_stem() {
                            models.push(name.to_string_lossy().to_string());
                        }
                    }
                }
            }
        }

        Ok(models)
    }

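    /// Resolves a model name to a file path, preferring `.onnx` over `.safetensors`.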
    pub fn get_model_path(&self, name: &str) -> Option<PathBuf> {
        let onnx_path = self.models_dir.join(format!("{name}.onnx"));
        if onnx_path.exists() {
            return Some(onnx_path);
        }

        let safetensors_path = self.models_dir.join(format!("{name}.safetensors"));
        if safetensors_path.exists() {
            return Some(safetensors_path);
        }

        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_config_default() {
        let config = ModelConfig::default();
        assert_eq!(config.sample_rate, 44100);
        assert_eq!(config.channels, 2);
        assert_eq!(config.sources.len(), 4);
    }
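
    // A small additional check (path and model name here are arbitrary): a
    // registry pointed at a directory that does not exist should report no models.
    #[test]
    fn test_model_registry_missing_dir() {
        let registry = ModelRegistry::new("this/path/should/not/exist");
        assert!(matches!(registry.list_models(), Ok(ref models) if models.is_empty()));
        assert!(registry.get_model_path("any-model").is_none());
    }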

    #[test]
    #[cfg(all(feature = "ort-backend", feature = "candle-backend"))]
    fn test_model_backend_types() {
        assert_ne!(ModelBackend::OnnxRuntime, ModelBackend::Candle);
    }
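
    // Sketch of a guard for builds with no ML backend compiled in: from_config
    // should fail cleanly rather than panic.
    #[test]
    #[cfg(not(any(feature = "ort-backend", feature = "candle-backend")))]
    fn test_from_config_requires_backend() {
        assert!(Model::from_config(ModelConfig::default()).is_err());
    }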
}