charon_audio/
separator.rs

1//! Main separator API
2
3use crate::audio::{AudioBuffer, AudioFile};
4use crate::error::{CharonError, Result};
5use crate::models::{Model, ModelBackend, ModelConfig};
6use crate::processor::{ProcessConfig, Processor};
7use indicatif::{ProgressBar, ProgressStyle};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::Path;
11
12/// Separator configuration
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SeparatorConfig {
15    /// Model configuration
16    pub model: ModelConfig,
17    /// Processing configuration
18    pub process: ProcessConfig,
19    /// Show progress bars
20    pub show_progress: bool,
21}
22
23impl Default for SeparatorConfig {
24    fn default() -> Self {
25        Self {
26            model: ModelConfig::default(),
27            process: ProcessConfig::default(),
28            show_progress: true,
29        }
30    }
31}
32
33impl SeparatorConfig {
34    /// Create configuration for ONNX backend
35    #[cfg(feature = "ort-backend")]
36    pub fn onnx<P: AsRef<Path>>(model_path: P) -> Self {
37        let mut config = Self::default();
38        config.model.model_path = model_path.as_ref().to_path_buf();
39        config.model.backend = Some(ModelBackend::OnnxRuntime);
40        config
41    }
42
43    /// Create configuration for Candle backend
44    #[cfg(feature = "candle-backend")]
45    pub fn candle<P: AsRef<Path>>(model_path: P) -> Self {
46        let mut config = Self::default();
47        config.model.model_path = model_path.as_ref().to_path_buf();
48        config.model.backend = Some(ModelBackend::Candle);
49        config
50    }
51
52    /// Set number of ensemble shifts
53    pub fn with_shifts(mut self, shifts: usize) -> Self {
54        self.process.shifts = shifts;
55        self
56    }
57
58    /// Set segment length
59    pub fn with_segment_length(mut self, seconds: f64) -> Self {
60        self.process.segment_length = Some(seconds);
61        self
62    }
63
64    /// Enable/disable progress display
65    pub fn with_progress(mut self, show: bool) -> Self {
66        self.show_progress = show;
67        self
68    }
69}
70
71/// Separated audio stems
72pub struct Stems {
73    /// Map of source name to audio buffer
74    pub sources: HashMap<String, AudioBuffer>,
75}
76
77impl Stems {
78    /// Create new stems collection
79    pub fn new(sources: HashMap<String, AudioBuffer>) -> Self {
80        Self { sources }
81    }
82
83    /// Get stem by name
84    pub fn get(&self, name: &str) -> Option<&AudioBuffer> {
85        self.sources.get(name)
86    }
87
88    /// Save all stems to directory
89    pub fn save_all<P: AsRef<Path>>(&self, output_dir: P) -> Result<()> {
90        let output_dir = output_dir.as_ref();
91        std::fs::create_dir_all(output_dir)?;
92
93        for (name, buffer) in &self.sources {
94            let output_path = output_dir.join(format!("{name}.wav"));
95            AudioFile::write_wav(&output_path, buffer)?;
96        }
97
98        Ok(())
99    }
100
101    /// Save specific stem
102    pub fn save<P: AsRef<Path>>(&self, name: &str, path: P) -> Result<()> {
103        let buffer = self
104            .sources
105            .get(name)
106            .ok_or_else(|| CharonError::Audio(format!("Stem '{name}' not found")))?;
107        AudioFile::write_wav(path, buffer)
108    }
109
110    /// List available stem names
111    pub fn list(&self) -> Vec<String> {
112        self.sources.keys().cloned().collect()
113    }
114}
115
116/// Main separator for audio source separation
117pub struct Separator {
118    model: Model,
119    processor: Processor,
120    config: SeparatorConfig,
121}
122
123impl Separator {
124    /// Create new separator from configuration
125    pub fn new(config: SeparatorConfig) -> Result<Self> {
126        let model = Model::from_config(config.model.clone())?;
127        let processor = Processor::new(config.process.clone());
128
129        Ok(Self {
130            model,
131            processor,
132            config,
133        })
134    }
135
136    /// Create separator with default configuration
137    pub fn with_default_model() -> Result<Self> {
138        Self::new(SeparatorConfig::default())
139    }
140
141    /// Separate audio buffer into stems
142    pub fn separate(&self, audio: &AudioBuffer) -> Result<Stems> {
143        // Resample if needed
144        let audio = if audio.sample_rate != self.config.model.sample_rate {
145            audio.resample(self.config.model.sample_rate)?
146        } else {
147            audio.clone()
148        };
149
150        // Convert channels if needed
151        let audio = if audio.channels() != self.config.model.channels {
152            audio.convert_channels(self.config.model.channels)?
153        } else {
154            audio
155        };
156
157        // Create progress bar
158        let pb = if self.config.show_progress {
159            let pb = ProgressBar::new(100);
160            pb.set_style(
161                ProgressStyle::default_bar()
162                    .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
163                    .unwrap()
164                    .progress_chars("=>-"),
165            );
166            pb.set_message("Separating audio...");
167            Some(pb)
168        } else {
169            None
170        };
171
172        // Process audio
173        let separated = self.processor.process(&self.model, &audio)?;
174
175        if let Some(pb) = &pb {
176            pb.finish_with_message("Separation complete!");
177        }
178
179        // Build stems map
180        let mut sources = HashMap::new();
181        for (idx, buffer) in separated.into_iter().enumerate() {
182            if idx < self.config.model.sources.len() {
183                let name = &self.config.model.sources[idx];
184                sources.insert(name.clone(), buffer);
185            }
186        }
187
188        Ok(Stems::new(sources))
189    }
190
191    /// Separate audio from file
192    pub fn separate_file<P: AsRef<Path>>(&self, path: P) -> Result<Stems> {
193        let audio = AudioFile::read(path)?;
194        self.separate(&audio)
195    }
196
197    /// Separate audio and save stems
198    pub fn separate_and_save<P: AsRef<Path>, O: AsRef<Path>>(
199        &self,
200        input_path: P,
201        output_dir: O,
202    ) -> Result<()> {
203        let stems = self.separate_file(input_path)?;
204        stems.save_all(output_dir)
205    }
206
207    /// Batch separate multiple files
208    pub fn separate_batch<P: AsRef<Path>, O: AsRef<Path>>(
209        &self,
210        input_paths: &[P],
211        output_dir: O,
212    ) -> Result<()> {
213        let output_dir = output_dir.as_ref();
214        std::fs::create_dir_all(output_dir)?;
215
216        for (idx, input_path) in input_paths.iter().enumerate() {
217            let input_path = input_path.as_ref();
218            let file_stem = input_path
219                .file_stem()
220                .and_then(|s| s.to_str())
221                .unwrap_or("output");
222
223            let file_output = output_dir.join(file_stem);
224
225            if self.config.show_progress {
226                log::info!(
227                    "Processing file {} of {}: {:?}",
228                    idx + 1,
229                    input_paths.len(),
230                    input_path
231                );
232            }
233
234            self.separate_and_save(input_path, &file_output)?;
235        }
236
237        Ok(())
238    }
239
240    /// Get model configuration
241    pub fn model_config(&self) -> &ModelConfig {
242        &self.config.model
243    }
244
245    /// Get processing configuration
246    pub fn process_config(&self) -> &ProcessConfig {
247        &self.config.process
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration shows progress and uses a 44100 model
    /// sample rate.
    #[test]
    fn test_separator_config_default() {
        let config = SeparatorConfig::default();
        assert!(config.show_progress);
        assert_eq!(config.model.sample_rate, 44100);
    }

    /// Stems can be looked up by name; unknown names yield `None`.
    #[test]
    fn test_stems_creation() {
        let mut sources = HashMap::new();
        let data = ndarray::Array2::zeros((2, 1000));
        sources.insert("vocals".to_string(), AudioBuffer::new(data, 44100));

        let stems = Stems::new(sources);
        assert!(stems.get("vocals").is_some());
        assert!(stems.get("drums").is_none());
    }

    /// The `with_*` builders do not depend on a backend, so this test starts
    /// from the default configuration and compiles whatever features are
    /// enabled.
    #[test]
    fn test_config_builders() {
        let config = SeparatorConfig::default()
            .with_shifts(2)
            .with_segment_length(5.0)
            .with_progress(false);

        assert_eq!(config.process.shifts, 2);
        assert_eq!(config.process.segment_length, Some(5.0));
        assert!(!config.show_progress);
    }

    /// `SeparatorConfig::onnx` only exists with the `ort-backend` feature,
    /// so the test exercising it is gated on the same feature.
    #[cfg(feature = "ort-backend")]
    #[test]
    fn test_onnx_config_builders() {
        let config = SeparatorConfig::onnx("model.onnx")
            .with_shifts(2)
            .with_segment_length(5.0)
            .with_progress(false);

        assert!(matches!(
            config.model.backend,
            Some(ModelBackend::OnnxRuntime)
        ));
        assert_eq!(config.model.model_path.as_path(), Path::new("model.onnx"));
        assert_eq!(config.process.shifts, 2);
        assert_eq!(config.process.segment_length, Some(5.0));
        assert!(!config.show_progress);
    }
}