charon_audio/processor.rs

//! Audio processing pipeline

use crate::audio::AudioBuffer;
use crate::error::Result;
use crate::models::Model;
use ndarray::Array2;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

/// Processing configuration
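///
/// # Example
///
/// A sketch of overriding the defaults (the field values here are
/// illustrative, not tuned recommendations):
///
/// ```ignore
/// let config = ProcessConfig {
///     segment_length: Some(7.5), // shorter segments, lower peak memory
///     overlap: 0.5,              // more overlap, smoother segment seams
///     ..ProcessConfig::default()
/// };
/// ```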
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessConfig {
    /// Segment length in seconds (for splitting long audio)
    pub segment_length: Option<f64>,
    /// Overlap between segments (0.0 to 1.0)
    pub overlap: f32,
    /// Number of shifts for ensemble prediction
    pub shifts: usize,
    /// Normalize input audio
    pub normalize: bool,
    /// Number of parallel jobs (1 = serial; any other value, including the
    /// default 0, runs segments on Rayon's global thread pool)
    pub num_jobs: usize,
}

impl Default for ProcessConfig {
    fn default() -> Self {
        Self {
            segment_length: Some(10.0),
            overlap: 0.25,
            shifts: 1,
            normalize: true,
            num_jobs: 0,
        }
    }
}

/// Audio processor for source separation
pub struct Processor {
    config: ProcessConfig,
}

impl Processor {
    /// Create new processor
    pub fn new(config: ProcessConfig) -> Self {
        Self { config }
    }

    /// Process audio buffer with model
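    ///
    /// Returns one `AudioBuffer` per source in the model's configuration.
    ///
    /// # Example
    ///
    /// A sketch of the intended call pattern (obtaining the `Model` and
    /// `AudioBuffer` is elided; both come from elsewhere in the crate):
    ///
    /// ```ignore
    /// let processor = Processor::new(ProcessConfig::default());
    /// let stems = processor.process(&model, &audio)?;
    /// assert_eq!(stems.len(), model.config().sources.len());
    /// ```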
    pub fn process(&self, model: &Model, audio: &AudioBuffer) -> Result<Vec<AudioBuffer>> {
        let mut processed_audio = audio.clone();

        // Normalize if requested
        if self.config.normalize {
            let mean = processed_audio.data.mean().unwrap_or(0.0);
            let std = processed_audio.data.std(0.0);
            processed_audio
                .data
                .mapv_inplace(|x| (x - mean) / (std + 1e-8));
        }

        // Check if we need to segment
        let segment_samples = self
            .config
            .segment_length
            .map(|len| (len * processed_audio.sample_rate as f64) as usize);

        let separated = if let Some(seg_len) = segment_samples {
            if processed_audio.samples() > seg_len {
                self.process_segmented(model, &processed_audio, seg_len)?
            } else {
                self.process_single(model, &processed_audio)?
            }
        } else {
            self.process_single(model, &processed_audio)?
        };

        // Denormalize and create audio buffers
        let mut output_buffers = Vec::new();
        for separated_source in separated {
            let mut buffer = AudioBuffer::new(separated_source, audio.sample_rate);

            // Apply inverse normalization if needed. `audio` was never
            // modified, so these are the same statistics used above.
            if self.config.normalize {
                let mean = audio.data.mean().unwrap_or(0.0);
                let std = audio.data.std(0.0);
                buffer.data.mapv_inplace(|x| x * (std + 1e-8) + mean);
            }

            output_buffers.push(buffer);
        }

        Ok(output_buffers)
    }

    /// Process single segment
    fn process_single(&self, model: &Model, audio: &AudioBuffer) -> Result<Vec<Array2<f32>>> {
        if self.config.shifts <= 1 {
            model.infer(&audio.data)
        } else {
            self.process_with_shifts(model, audio)
        }
    }

    /// Process with multiple shifts (ensemble)
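    ///
    /// Each pass rotates the input by a multiple of half a second, runs
    /// inference, rotates the prediction back, and the passes are then
    /// averaged. Separation is not exactly shift-invariant, so averaging
    /// over shifted copies smooths out position-dependent artifacts (the
    /// "shift trick" used by Demucs-style separators).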
    fn process_with_shifts(&self, model: &Model, audio: &AudioBuffer) -> Result<Vec<Array2<f32>>> {
        let shift_amount = audio.sample_rate as usize / 2; // 0.5 second shift
        let num_sources = model.config().sources.len();

        let mut accumulated: Vec<Array2<f32>> =
            vec![Array2::zeros((audio.channels(), audio.samples())); num_sources];

        for shift_idx in 0..self.config.shifts {
            let shift = (shift_idx * shift_amount) % audio.samples();

            // Shift input: rotate left by `shift` samples
            let mut shifted_data = audio.data.clone();
            if shift > 0 {
                let (left, right) = shifted_data.view().split_at(ndarray::Axis(1), shift);
                shifted_data = ndarray::concatenate![ndarray::Axis(1), right, left];
            }

            // Run inference
            let separated = model.infer(&shifted_data)?;

            // Shift back (rotate right by `shift` samples) and accumulate
            for (src_idx, mut source) in separated.into_iter().enumerate() {
                if shift > 0 {
                    let samples = source.ncols();
                    let unshift = samples - shift;
                    let (left, right) = source.view().split_at(ndarray::Axis(1), unshift);
                    source = ndarray::concatenate![ndarray::Axis(1), right, left];
                }
                accumulated[src_idx] = &accumulated[src_idx] + &source;
            }
        }

        // Average the ensemble
        for source in &mut accumulated {
            *source /= self.config.shifts as f32;
        }

        Ok(accumulated)
    }

    /// Process audio in segments with overlap
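    ///
    /// Segments are windowed with a linear crossfade and recombined by
    /// overlap-add; the accumulated window weight is divided out at the
    /// end so interior samples keep unity gain.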
    fn process_segmented(
        &self,
        model: &Model,
        audio: &AudioBuffer,
        segment_length: usize,
    ) -> Result<Vec<Array2<f32>>> {
        let total_samples = audio.samples();
        let overlap_samples = (segment_length as f32 * self.config.overlap) as usize;
        // Guard against `overlap >= 1.0`, which would otherwise make the
        // step zero (or underflow) and keep the loop below from advancing.
        let step = segment_length.saturating_sub(overlap_samples).max(1);

        // Calculate segments
        let mut segments = Vec::new();
        let mut pos = 0;
        while pos < total_samples {
            let end = (pos + segment_length).min(total_samples);
            segments.push((pos, end));
            pos += step;
            if end >= total_samples {
                break;
            }
        }

        let num_sources = model.config().sources.len();
        let channels = audio.channels();

        // Process segments: any `num_jobs` other than 1 runs them on
        // Rayon's global thread pool.
        let segment_results: Vec<Result<Vec<Array2<f32>>>> = if self.config.num_jobs != 1 {
            segments
                .par_iter()
                .map(|&(start, end)| {
                    let segment = audio.data.slice(ndarray::s![.., start..end]).to_owned();
                    model.infer(&segment)
                })
                .collect()
        } else {
            segments
                .iter()
                .map(|&(start, end)| {
                    let segment = audio.data.slice(ndarray::s![.., start..end]).to_owned();
                    model.infer(&segment)
                })
                .collect()
        };

        // Initialize output arrays
        let mut outputs: Vec<Array2<f32>> =
            vec![Array2::zeros((channels, total_samples)); num_sources];
        let mut weight = Array2::zeros((1, total_samples));

        // Combine segments with overlap
        for (segment_idx, result) in segment_results.into_iter().enumerate() {
            let separated = result?;
            let (start, end) = segments[segment_idx];
            let seg_len = end - start;

            // Create fade in/out for overlap
            let fade = self.create_fade_window(seg_len, overlap_samples);

            for (src_idx, source) in separated.into_iter().enumerate() {
                for ch in 0..channels {
                    for i in 0..seg_len {
                        outputs[src_idx][[ch, start + i]] += source[[ch, i]] * fade[i];
                    }
                }
            }

            // Track weights for normalization
            for i in 0..seg_len {
                weight[[0, start + i]] += fade[i];
            }
        }

        // Normalize by overlap weight. Clamp first: the fade-in starts at
        // 0.0, so the first sample has zero accumulated weight and a plain
        // division would produce NaN there.
        weight.mapv_inplace(|w: f32| w.max(1e-8));
        for output in &mut outputs {
            *output /= &weight;
        }

        Ok(outputs)
    }

    /// Create fade window for overlap-add
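    ///
    /// E.g. `create_fade_window(6, 2)` yields
    /// `[0.0, 0.5, 1.0, 1.0, 1.0, 0.5]`: a linear ramp up over the first
    /// `overlap` samples and a linear ramp down over the last.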
    fn create_fade_window(&self, length: usize, overlap: usize) -> Vec<f32> {
        let mut window = vec![1.0; length];

        if overlap > 0 {
            // Clamp so a short final segment (`length < overlap`) cannot
            // underflow the fade-out index arithmetic below.
            let fade_len = overlap.min(length);

            // Fade in
            for (i, win) in window.iter_mut().enumerate().take(fade_len) {
                *win = i as f32 / overlap as f32;
            }

            // Fade out
            for i in 0..fade_len {
                let idx = length - fade_len + i;
                window[idx] = 1.0 - i as f32 / overlap as f32;
            }
        }

        window
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_config_default() {
        let config = ProcessConfig::default();
        assert_eq!(config.overlap, 0.25);
        assert_eq!(config.shifts, 1);
        assert!(config.normalize);
    }

    #[test]
    fn test_fade_window() {
        use approx::assert_abs_diff_eq;

        let processor = Processor::new(ProcessConfig::default());
        let window = processor.create_fade_window(100, 20);

        assert_eq!(window.len(), 100);
        // First sample starts fading in from 0
        assert_abs_diff_eq!(window[0], 0.0, epsilon = 0.01);
        // Last sample is fading out (near 0 but not exactly 0)
        assert!(window[99] < 0.1);
        // Middle should be at full volume
        assert_abs_diff_eq!(window[50], 1.0, epsilon = 0.01);
    }
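
    // Added test (a sketch of the overlap-add invariant): with
    // `step = segment_length - overlap`, position `length - overlap + i`
    // of one segment lines up with position `i` of the next, so the
    // fade-out tail and the fade-in head should sum to ~1.0.
    #[test]
    fn test_fade_windows_sum_to_one_in_overlap() {
        let processor = Processor::new(ProcessConfig::default());
        let (length, overlap) = (100, 20);
        let window = processor.create_fade_window(length, overlap);

        for i in 0..overlap {
            let sum = window[length - overlap + i] + window[i];
            assert!((sum - 1.0).abs() < 1e-6, "weight {sum} at offset {i}");
        }
    }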
}