// charon_audio/separator.rs
use crate::audio::{AudioBuffer, AudioFile};
use crate::error::{CharonError, Result};
use crate::models::{Model, ModelBackend, ModelConfig};
use crate::processor::{ProcessConfig, Processor};
use indicatif::{ProgressBar, ProgressStyle};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::path::Path;

12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SeparatorConfig {
15 pub model: ModelConfig,
17 pub process: ProcessConfig,
19 pub show_progress: bool,
21}
22
23impl Default for SeparatorConfig {
24 fn default() -> Self {
25 Self {
26 model: ModelConfig::default(),
27 process: ProcessConfig::default(),
28 show_progress: true,
29 }
30 }
31}
32
33impl SeparatorConfig {
34 #[cfg(feature = "ort-backend")]
36 pub fn onnx<P: AsRef<Path>>(model_path: P) -> Self {
37 let mut config = Self::default();
38 config.model.model_path = model_path.as_ref().to_path_buf();
39 config.model.backend = Some(ModelBackend::OnnxRuntime);
40 config
41 }
42
43 #[cfg(feature = "candle-backend")]
45 pub fn candle<P: AsRef<Path>>(model_path: P) -> Self {
46 let mut config = Self::default();
47 config.model.model_path = model_path.as_ref().to_path_buf();
48 config.model.backend = Some(ModelBackend::Candle);
49 config
50 }
51
52 pub fn with_shifts(mut self, shifts: usize) -> Self {
54 self.process.shifts = shifts;
55 self
56 }
57
58 pub fn with_segment_length(mut self, seconds: f64) -> Self {
60 self.process.segment_length = Some(seconds);
61 self
62 }
63
64 pub fn with_progress(mut self, show: bool) -> Self {
66 self.show_progress = show;
67 self
68 }
69}
70
71pub struct Stems {
73 pub sources: HashMap<String, AudioBuffer>,
75}
76
77impl Stems {
78 pub fn new(sources: HashMap<String, AudioBuffer>) -> Self {
80 Self { sources }
81 }
82
83 pub fn get(&self, name: &str) -> Option<&AudioBuffer> {
85 self.sources.get(name)
86 }
87
88 pub fn save_all<P: AsRef<Path>>(&self, output_dir: P) -> Result<()> {
90 let output_dir = output_dir.as_ref();
91 std::fs::create_dir_all(output_dir)?;
92
93 for (name, buffer) in &self.sources {
94 let output_path = output_dir.join(format!("{name}.wav"));
95 AudioFile::write_wav(&output_path, buffer)?;
96 }
97
98 Ok(())
99 }
100
101 pub fn save<P: AsRef<Path>>(&self, name: &str, path: P) -> Result<()> {
103 let buffer = self
104 .sources
105 .get(name)
106 .ok_or_else(|| CharonError::Audio(format!("Stem '{name}' not found")))?;
107 AudioFile::write_wav(path, buffer)
108 }
109
110 pub fn list(&self) -> Vec<String> {
112 self.sources.keys().cloned().collect()
113 }
114}
115
116pub struct Separator {
118 model: Model,
119 processor: Processor,
120 config: SeparatorConfig,
121}
122
123impl Separator {
124 pub fn new(config: SeparatorConfig) -> Result<Self> {
126 let model = Model::from_config(config.model.clone())?;
127 let processor = Processor::new(config.process.clone());
128
129 Ok(Self {
130 model,
131 processor,
132 config,
133 })
134 }
135
136 pub fn with_default_model() -> Result<Self> {
138 Self::new(SeparatorConfig::default())
139 }
140
141 pub fn separate(&self, audio: &AudioBuffer) -> Result<Stems> {
143 let audio = if audio.sample_rate != self.config.model.sample_rate {
145 audio.resample(self.config.model.sample_rate)?
146 } else {
147 audio.clone()
148 };
149
150 let audio = if audio.channels() != self.config.model.channels {
152 audio.convert_channels(self.config.model.channels)?
153 } else {
154 audio
155 };
156
157 let pb = if self.config.show_progress {
159 let pb = ProgressBar::new(100);
160 pb.set_style(
161 ProgressStyle::default_bar()
162 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
163 .unwrap()
164 .progress_chars("=>-"),
165 );
166 pb.set_message("Separating audio...");
167 Some(pb)
168 } else {
169 None
170 };
171
172 let separated = self.processor.process(&self.model, &audio)?;
174
175 if let Some(pb) = &pb {
176 pb.finish_with_message("Separation complete!");
177 }
178
179 let mut sources = HashMap::new();
181 for (idx, buffer) in separated.into_iter().enumerate() {
182 if idx < self.config.model.sources.len() {
183 let name = &self.config.model.sources[idx];
184 sources.insert(name.clone(), buffer);
185 }
186 }
187
188 Ok(Stems::new(sources))
189 }
190
191 pub fn separate_file<P: AsRef<Path>>(&self, path: P) -> Result<Stems> {
193 let audio = AudioFile::read(path)?;
194 self.separate(&audio)
195 }
196
197 pub fn separate_and_save<P: AsRef<Path>, O: AsRef<Path>>(
199 &self,
200 input_path: P,
201 output_dir: O,
202 ) -> Result<()> {
203 let stems = self.separate_file(input_path)?;
204 stems.save_all(output_dir)
205 }
206
207 pub fn separate_batch<P: AsRef<Path>, O: AsRef<Path>>(
209 &self,
210 input_paths: &[P],
211 output_dir: O,
212 ) -> Result<()> {
213 let output_dir = output_dir.as_ref();
214 std::fs::create_dir_all(output_dir)?;
215
216 for (idx, input_path) in input_paths.iter().enumerate() {
217 let input_path = input_path.as_ref();
218 let file_stem = input_path
219 .file_stem()
220 .and_then(|s| s.to_str())
221 .unwrap_or("output");
222
223 let file_output = output_dir.join(file_stem);
224
225 if self.config.show_progress {
226 log::info!(
227 "Processing file {} of {}: {:?}",
228 idx + 1,
229 input_paths.len(),
230 input_path
231 );
232 }
233
234 self.separate_and_save(input_path, &file_output)?;
235 }
236
237 Ok(())
238 }
239
240 pub fn model_config(&self) -> &ModelConfig {
242 &self.config.model
243 }
244
245 pub fn process_config(&self) -> &ProcessConfig {
247 &self.config.process
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_separator_config_default() {
        let config = SeparatorConfig::default();
        assert!(config.show_progress);
        assert_eq!(config.model.sample_rate, 44100);
    }

    #[test]
    fn test_stems_creation() {
        let mut sources = HashMap::new();
        let data = ndarray::Array2::zeros((2, 1000));
        sources.insert("vocals".to_string(), AudioBuffer::new(data, 44100));

        let stems = Stems::new(sources);
        assert!(stems.get("vocals").is_some());
        assert!(stems.get("drums").is_none());
    }

    /// Builder methods are exercised on the default config so this test
    /// compiles regardless of which backend features are enabled.
    #[test]
    fn test_config_builders() {
        let config = SeparatorConfig::default()
            .with_shifts(2)
            .with_segment_length(5.0)
            .with_progress(false);

        assert_eq!(config.process.shifts, 2);
        assert_eq!(config.process.segment_length, Some(5.0));
        assert!(!config.show_progress);
    }

    /// `SeparatorConfig::onnx` only exists with the `ort-backend` feature.
    #[cfg(feature = "ort-backend")]
    #[test]
    fn test_onnx_config_builders() {
        let config = SeparatorConfig::onnx("model.onnx")
            .with_shifts(2)
            .with_segment_length(5.0)
            .with_progress(false);

        assert_eq!(
            config.model.model_path,
            std::path::PathBuf::from("model.onnx")
        );
        assert!(matches!(
            config.model.backend,
            Some(ModelBackend::OnnxRuntime)
        ));
        assert_eq!(config.process.shifts, 2);
        assert_eq!(config.process.segment_length, Some(5.0));
        assert!(!config.show_progress);
    }

    /// `SeparatorConfig::candle` only exists with the `candle-backend` feature.
    #[cfg(feature = "candle-backend")]
    #[test]
    fn test_candle_config() {
        let config = SeparatorConfig::candle("model.safetensors");

        assert_eq!(
            config.model.model_path,
            std::path::PathBuf::from("model.safetensors")
        );
        assert!(matches!(config.model.backend, Some(ModelBackend::Candle)));
    }
}