charon_audio/
processor.rs1use crate::audio::AudioBuffer;
4use crate::error::Result;
5use crate::models::Model;
6use ndarray::Array2;
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProcessConfig {
13 pub segment_length: Option<f64>,
15 pub overlap: f32,
17 pub shifts: usize,
19 pub normalize: bool,
21 pub num_jobs: usize,
23}
24
25impl Default for ProcessConfig {
26 fn default() -> Self {
27 Self {
28 segment_length: Some(10.0),
29 overlap: 0.25,
30 shifts: 1,
31 normalize: true,
32 num_jobs: 0,
33 }
34 }
35}
36
37pub struct Processor {
39 config: ProcessConfig,
40}
41
42impl Processor {
43 pub fn new(config: ProcessConfig) -> Self {
45 Self { config }
46 }
47
48 pub fn process(&self, model: &Model, audio: &AudioBuffer) -> Result<Vec<AudioBuffer>> {
50 let mut processed_audio = audio.clone();
51
52 if self.config.normalize {
54 let mean = processed_audio.data.mean().unwrap_or(0.0);
55 let std = processed_audio.data.std(0.0);
56 processed_audio
57 .data
58 .mapv_inplace(|x| (x - mean) / (std + 1e-8));
59 }
60
61 let segment_samples = self
63 .config
64 .segment_length
65 .map(|len| (len * processed_audio.sample_rate as f64) as usize);
66
67 let separated = if let Some(seg_len) = segment_samples {
68 if processed_audio.samples() > seg_len {
69 self.process_segmented(model, &processed_audio, seg_len)?
70 } else {
71 self.process_single(model, &processed_audio)?
72 }
73 } else {
74 self.process_single(model, &processed_audio)?
75 };
76
77 let mut output_buffers = Vec::new();
79 for separated_source in separated {
80 let mut buffer = AudioBuffer::new(separated_source, audio.sample_rate);
81
82 if self.config.normalize {
84 let mean = audio.data.mean().unwrap_or(0.0);
85 let std = audio.data.std(0.0);
86 buffer.data.mapv_inplace(|x| x * (std + 1e-8) + mean);
87 }
88
89 output_buffers.push(buffer);
90 }
91
92 Ok(output_buffers)
93 }
94
95 fn process_single(&self, model: &Model, audio: &AudioBuffer) -> Result<Vec<Array2<f32>>> {
97 if self.config.shifts <= 1 {
98 model.infer(&audio.data)
99 } else {
100 self.process_with_shifts(model, audio)
101 }
102 }
103
104 fn process_with_shifts(&self, model: &Model, audio: &AudioBuffer) -> Result<Vec<Array2<f32>>> {
106 let shift_amount = audio.sample_rate as usize / 2; let num_sources = model.config().sources.len();
108
109 let mut accumulated: Vec<Array2<f32>> =
110 vec![Array2::zeros((audio.channels(), audio.samples())); num_sources];
111
112 for shift_idx in 0..self.config.shifts {
113 let shift = (shift_idx * shift_amount) % audio.samples();
114
115 let mut shifted_data = audio.data.clone();
117 if shift > 0 {
118 let (left, right) = shifted_data.view().split_at(ndarray::Axis(1), shift);
119 shifted_data = ndarray::concatenate![ndarray::Axis(1), right, left];
120 }
121
122 let separated = model.infer(&shifted_data)?;
124
125 for (src_idx, mut source) in separated.into_iter().enumerate() {
127 if shift > 0 {
128 let samples = source.ncols();
129 let unshift = samples - shift;
130 let (left, right) = source.view().split_at(ndarray::Axis(1), unshift);
131 source = ndarray::concatenate![ndarray::Axis(1), right, left];
132 }
133 accumulated[src_idx] = &accumulated[src_idx] + &source;
134 }
135 }
136
137 for source in &mut accumulated {
139 *source /= self.config.shifts as f32;
140 }
141
142 Ok(accumulated)
143 }
144
145 fn process_segmented(
147 &self,
148 model: &Model,
149 audio: &AudioBuffer,
150 segment_length: usize,
151 ) -> Result<Vec<Array2<f32>>> {
152 let total_samples = audio.samples();
153 let overlap_samples = (segment_length as f32 * self.config.overlap) as usize;
154 let step = segment_length - overlap_samples;
155
156 let mut segments = Vec::new();
158 let mut pos = 0;
159 while pos < total_samples {
160 let end = (pos + segment_length).min(total_samples);
161 segments.push((pos, end));
162 pos += step;
163 if end >= total_samples {
164 break;
165 }
166 }
167
168 let num_sources = model.config().sources.len();
169 let channels = audio.channels();
170
171 let segment_results: Vec<Result<Vec<Array2<f32>>>> = if self.config.num_jobs != 1 {
173 segments
174 .par_iter()
175 .map(|&(start, end)| {
176 let segment = audio.data.slice(ndarray::s![.., start..end]).to_owned();
177 model.infer(&segment)
178 })
179 .collect()
180 } else {
181 segments
182 .iter()
183 .map(|&(start, end)| {
184 let segment = audio.data.slice(ndarray::s![.., start..end]).to_owned();
185 model.infer(&segment)
186 })
187 .collect()
188 };
189
190 let mut outputs: Vec<Array2<f32>> =
192 vec![Array2::zeros((channels, total_samples)); num_sources];
193 let mut weight = Array2::zeros((1, total_samples));
194
195 for (segment_idx, result) in segment_results.into_iter().enumerate() {
197 let separated = result?;
198 let (start, end) = segments[segment_idx];
199 let seg_len = end - start;
200
201 let fade = self.create_fade_window(seg_len, overlap_samples);
203
204 for (src_idx, source) in separated.into_iter().enumerate() {
205 for ch in 0..channels {
206 for i in 0..seg_len {
207 outputs[src_idx][[ch, start + i]] += source[[ch, i]] * fade[i];
208 }
209 }
210 }
211
212 for i in 0..seg_len {
214 weight[[0, start + i]] += fade[i];
215 }
216 }
217
218 for output in &mut outputs {
220 *output /= &weight;
221 }
222
223 Ok(outputs)
224 }
225
226 fn create_fade_window(&self, length: usize, overlap: usize) -> Vec<f32> {
228 let mut window = vec![1.0; length];
229
230 if overlap > 0 {
231 for (i, win) in window.iter_mut().enumerate().take(overlap.min(length)) {
233 let t = i as f32 / overlap as f32;
234 *win = t;
235 }
236
237 for i in 0..overlap.min(length) {
239 let idx = length - overlap + i;
240 if idx < length {
241 let t = i as f32 / overlap as f32;
242 window[idx] = 1.0 - t;
243 }
244 }
245 }
246
247 window
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The default configuration uses 25 % overlap, one shift and normalization.
    #[test]
    fn test_process_config_default() {
        let ProcessConfig {
            overlap,
            shifts,
            normalize,
            ..
        } = ProcessConfig::default();

        assert_eq!(overlap, 0.25);
        assert_eq!(shifts, 1);
        assert!(normalize);
    }

    /// A 100-sample window with a 20-sample overlap fades in, holds, and fades out.
    #[test]
    fn test_fade_window() {
        let window = Processor::new(ProcessConfig::default()).create_fade_window(100, 20);

        assert_eq!(window.len(), 100);
        // Starts silent.
        assert_abs_diff_eq!(window[0], 0.0, epsilon = 0.01);
        // Nearly silent again at the end.
        assert!(window[99] < 0.1);
        // Full weight in the middle.
        assert_abs_diff_eq!(window[50], 1.0, epsilon = 0.01);
    }
}