charon_audio/
models.rs

1//! ML model backends and configuration
2
3use crate::error::{CharonError, Result};
4#[cfg(feature = "candle-backend")]
5use candle_core::{Device, Tensor};
6use ndarray::Array2;
7#[cfg(feature = "ort-backend")]
8use ort::session::{
9    builder::{GraphOptimizationLevel, SessionBuilder},
10    Session,
11};
12use serde::{Deserialize, Serialize};
13use std::path::{Path, PathBuf};
14
/// Model backend types.
///
/// The set of variants depends on the enabled Cargo features: `OnnxRuntime`
/// exists only with `ort-backend`, `Candle` only with `candle-backend`, and
/// with neither feature enabled the enum has no variants at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelBackend {
    /// ONNX Runtime (production-ready, hardware accelerated)
    #[cfg(feature = "ort-backend")]
    OnnxRuntime,
    /// HuggingFace Candle (pure Rust, flexible)
    #[cfg(feature = "candle-backend")]
    Candle,
}
25
26/// Model configuration
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ModelConfig {
29    /// Path to the model file
30    pub model_path: PathBuf,
31    /// Model backend to use (optional, will be inferred if not set)
32    #[serde(skip, default)]
33    #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
34    pub backend: Option<ModelBackend>,
35    /// Expected sample rate
36    pub sample_rate: u32,
37    /// Number of audio channels
38    pub channels: usize,
39    /// Source names (e.g., ["drums", "bass", "vocals", "other"])
40    pub sources: Vec<String>,
41    /// Chunk size for processing (in samples)
42    pub chunk_size: Option<usize>,
43}
44
45impl Default for ModelConfig {
46    fn default() -> Self {
47        Self {
48            model_path: PathBuf::from("model.onnx"),
49            #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
50            backend: None, // Will be inferred from file extension
51            sample_rate: 44100,
52            channels: 2,
53            sources: vec![
54                "drums".to_string(),
55                "bass".to_string(),
56                "vocals".to_string(),
57                "other".to_string(),
58            ],
59            chunk_size: Some(441000), // 10 seconds at 44.1kHz
60        }
61    }
62}
63
/// ONNX Runtime model wrapper.
///
/// Pairs a loaded ONNX Runtime session with the configuration it was built
/// from.
#[cfg(feature = "ort-backend")]
pub struct OnnxModel {
    /// Session loaded from `config.model_path`.
    // Never read yet: `infer` is still a placeholder that does not run it.
    #[allow(dead_code)]
    session: Session,
    /// Configuration the session was created from.
    config: ModelConfig,
}
71
#[cfg(feature = "ort-backend")]
impl OnnxModel {
    /// Build an ONNX Runtime session for `config.model_path`.
    ///
    /// The session is configured with `GraphOptimizationLevel::Level3` and
    /// 4 intra-op threads.
    ///
    /// # Errors
    ///
    /// Returns an error if the session builder cannot be created or
    /// configured, or if the model file cannot be loaded.
    pub fn new(config: ModelConfig) -> Result<Self> {
        let builder = SessionBuilder::new()?;
        let builder = builder.with_optimization_level(GraphOptimizationLevel::Level3)?;
        let builder = builder.with_intra_threads(4)?;
        let session = builder.commit_from_file(&config.model_path)?;

        Ok(Self { session, config })
    }

    /// Run inference on audio data.
    ///
    /// Placeholder: the ONNX session is not executed. The function returns one
    /// unmodified copy of `input` per configured source name. A real
    /// implementation still needs to convert `input` into an ONNX tensor, run
    /// the session, and parse the output tensors.
    ///
    /// # Errors
    ///
    /// Never fails in its current placeholder form.
    pub fn infer(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        let per_source_copies = std::iter::repeat(input)
            .take(self.config.sources.len())
            .cloned()
            .collect();

        Ok(per_source_copies)
    }
}
99
/// Candle model wrapper (for pure Rust inference).
///
/// Holds the compute device, the configuration the model was built from, and
/// the loaded weights when the model file was found.
#[cfg(feature = "candle-backend")]
pub struct CandleModel {
    /// Device on which input tensors are created.
    device: Device,
    /// Configuration the model was created from.
    config: ModelConfig,
    /// Weights loaded from `config.model_path`, or `None` if that file did not
    /// exist when the model was created.
    model: Option<candle_nn::VarMap>,
}
107
#[cfg(feature = "candle-backend")]
impl CandleModel {
    /// Create a new Candle model.
    ///
    /// Uses CUDA device 0 when available (never on `wasm32`), otherwise the
    /// CPU. If `config.model_path` exists it is loaded as a safetensors file
    /// and every tensor is stored in a [`candle_nn::VarMap`]. If the path does
    /// not exist, the model is created without weights (see
    /// [`Self::has_weights`]).
    ///
    /// # Errors
    ///
    /// Returns an error if the safetensors file cannot be loaded or a tensor
    /// cannot be wrapped in a variable.
    pub fn new(config: ModelConfig) -> Result<Self> {
        use candle_core::safetensors;

        let device = if cfg!(target_arch = "wasm32") {
            Device::Cpu
        } else {
            Device::cuda_if_available(0).unwrap_or(Device::Cpu)
        };

        let model = if config.model_path.exists() {
            let tensors = safetensors::load(&config.model_path, &device)?;
            let varmap = candle_nn::VarMap::new();
            {
                // Take the lock once for the whole load instead of once per
                // tensor. The map was created just above and is not shared, so
                // the mutex cannot be poisoned.
                let mut vars = varmap
                    .data()
                    .lock()
                    .expect("a freshly created VarMap mutex cannot be poisoned");
                for (name, tensor) in tensors {
                    vars.insert(name, candle_nn::Var::from_tensor(&tensor)?);
                }
            }
            Some(varmap)
        } else {
            None
        };

        Ok(Self {
            device,
            config,
            model,
        })
    }

    /// Whether weights were loaded from `config.model_path` by [`Self::new`].
    pub fn has_weights(&self) -> bool {
        self.model.is_some()
    }

    /// Run inference on audio data.
    ///
    /// `input` is indexed as `(channels, samples)`, and each returned array
    /// has that same shape. One array is returned per entry in
    /// `config.sources`; if there are no sources, the result is empty.
    ///
    /// Placeholder: no forward pass is run yet, even when weights are loaded.
    /// Every source receives an unmodified copy of `input`, which matches the
    /// ONNX backend's placeholder.
    ///
    /// # Errors
    ///
    /// Returns an error if a tensor operation fails, or
    /// `CharonError::NotSupported` if the model output does not have shape
    /// `(sources, samples, channels)`.
    pub fn infer(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        let num_sources = self.config.sources.len();
        if num_sources == 0 {
            return Ok(Vec::new());
        }
        let (channels, samples) = input.dim();

        // Interleave channels: element (ch, s) lands at s * channels + ch.
        let interleaved: Vec<f32> = input.t().iter().copied().collect();
        let tensor = Tensor::from_vec(interleaved, (samples, channels), &self.device)?;

        // TODO: replace with a real forward pass over `self.model`. Until then,
        // emit one copy of the input per source, giving the
        // (sources, samples, channels) layout that the decoder below expects.
        let copies = vec![tensor; num_sources];
        let output = Tensor::stack(&copies, 0)?;

        let dims = output.dims3()?;
        if dims != (num_sources, samples, channels) {
            return Err(CharonError::NotSupported(format!(
                "Unsupported model output shape {dims:?}, expected ({num_sources}, {samples}, {channels})"
            )));
        }

        let data: Vec<f32> = output.flatten_all()?.to_vec1()?;
        Ok(Self::split_sources(&data, num_sources, channels, samples))
    }

    /// Split a flat buffer laid out as `(sources, samples, channels)` into one
    /// `(channels, samples)` array per source.
    ///
    /// The caller guarantees `data.len() == num_sources * samples * channels`.
    fn split_sources(
        data: &[f32],
        num_sources: usize,
        channels: usize,
        samples: usize,
    ) -> Vec<Array2<f32>> {
        let per_source = samples * channels;
        (0..num_sources)
            .map(|source| {
                let chunk = &data[source * per_source..(source + 1) * per_source];
                Array2::from_shape_fn((channels, samples), |(ch, s)| chunk[s * channels + ch])
            })
            .collect()
    }
}
179
/// Generic model interface.
///
/// Holds one variant per compiled-in backend. With no backend feature enabled
/// the enum has no variants and cannot be constructed.
pub enum Model {
    /// Model executed with ONNX Runtime.
    #[cfg(feature = "ort-backend")]
    Onnx(OnnxModel),
    /// Model executed with Candle.
    #[cfg(feature = "candle-backend")]
    Candle(CandleModel),
}
187
188impl Model {
189    /// Create a new model from configuration
190    pub fn from_config(config: ModelConfig) -> Result<Self> {
191        // Infer backend from file extension if not specified
192        #[cfg(any(feature = "ort-backend", feature = "candle-backend"))]
193        let backend = config.backend.or_else(|| {
194            if config.model_path.extension()?.to_str()? == "onnx" {
195                #[cfg(feature = "ort-backend")]
196                return Some(ModelBackend::OnnxRuntime);
197            }
198            #[cfg(feature = "candle-backend")]
199            return Some(ModelBackend::Candle);
200            #[allow(unreachable_code)]
201            None
202        });
203
204        #[cfg(feature = "ort-backend")]
205        if matches!(backend, Some(ModelBackend::OnnxRuntime)) {
206            return Ok(Model::Onnx(OnnxModel::new(config)?));
207        }
208        #[cfg(feature = "candle-backend")]
209        if matches!(backend, Some(ModelBackend::Candle)) {
210            return Ok(Model::Candle(CandleModel::new(config)?));
211        }
212        Err(CharonError::NotSupported(
213            "No ML backend enabled or auto-detected".to_string(),
214        ))
215    }
216
217    /// Run inference
218    #[allow(unreachable_patterns)]
219    pub fn infer(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
220        match self {
221            #[cfg(feature = "ort-backend")]
222            Model::Onnx(model) => model.infer(input),
223            #[cfg(feature = "candle-backend")]
224            Model::Candle(model) => model.infer(input),
225            #[allow(unreachable_patterns)]
226            _ => Err(CharonError::NotSupported(
227                "No model backend available".to_string(),
228            )),
229        }
230    }
231
232    /// Get model configuration
233    #[allow(unreachable_patterns)]
234    pub fn config(&self) -> &ModelConfig {
235        match self {
236            #[cfg(feature = "ort-backend")]
237            Model::Onnx(model) => &model.config,
238            #[cfg(feature = "candle-backend")]
239            Model::Candle(model) => &model.config,
240            #[allow(unreachable_patterns)]
241            _ => panic!("No model backend available"),
242        }
243    }
244}
245
/// Model registry for managing pre-trained models.
///
/// Looks up `.onnx` and `.safetensors` model files in a single directory.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Directory that is scanned for model files.
    models_dir: PathBuf,
}
250
251impl ModelRegistry {
252    /// Create new model registry
253    pub fn new<P: AsRef<Path>>(models_dir: P) -> Self {
254        Self {
255            models_dir: models_dir.as_ref().to_path_buf(),
256        }
257    }
258
259    /// List available models
260    pub fn list_models(&self) -> Result<Vec<String>> {
261        let mut models = Vec::new();
262
263        if !self.models_dir.exists() {
264            return Ok(models);
265        }
266
267        for entry in std::fs::read_dir(&self.models_dir)? {
268            let entry = entry?;
269            let path = entry.path();
270            if path.is_file() {
271                if let Some(ext) = path.extension() {
272                    if ext == "onnx" || ext == "safetensors" {
273                        if let Some(name) = path.file_stem() {
274                            models.push(name.to_string_lossy().to_string());
275                        }
276                    }
277                }
278            }
279        }
280
281        Ok(models)
282    }
283
284    /// Get model path by name
285    pub fn get_model_path(&self, name: &str) -> Option<PathBuf> {
286        let onnx_path = self.models_dir.join(format!("{name}.onnx"));
287        if onnx_path.exists() {
288            return Some(onnx_path);
289        }
290
291        let safetensors_path = self.models_dir.join(format!("{name}.safetensors"));
292        if safetensors_path.exists() {
293            return Some(safetensors_path);
294        }
295
296        None
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is a stereo, 44.1 kHz, four-source model.
    #[test]
    fn test_model_config_default() {
        let ModelConfig {
            sample_rate,
            channels,
            sources,
            ..
        } = ModelConfig::default();

        assert_eq!(sample_rate, 44100);
        assert_eq!(channels, 2);
        assert_eq!(sources.len(), 4);
    }

    /// With both backends compiled in, their variants compare unequal.
    #[test]
    #[cfg(all(feature = "ort-backend", feature = "candle-backend"))]
    fn test_model_backend_types() {
        let (onnx, candle) = (ModelBackend::OnnxRuntime, ModelBackend::Candle);
        assert_ne!(onnx, candle);
    }
}