ringkernel_audio_fft/
separation.rs

1//! Direct/ambience signal separation algorithm.
2//!
3//! This module implements the coherence-based separation of direct sound
4//! from room ambience using inter-bin phase analysis.
5//!
6//! ## Algorithm Overview
7//!
8//! The separation is based on the observation that:
9//! - **Direct sound** has coherent phase relationships between neighboring frequency bins
10//!   (the phase progression follows a predictable pattern related to time-of-arrival)
11//! - **Ambience/reverb** has random, diffuse phase relationships
12//!
13//! We use multiple cues:
14//! 1. **Inter-bin phase coherence**: Direct sound shows correlated phase between neighbors
15//! 2. **Spectral flux**: Transients (direct attacks) have higher positive flux
16//! 3. **Temporal stability**: Ambience is more temporally stable
17//! 4. **Magnitude correlation**: Direct sound often shows correlated magnitude envelopes
18
19use crate::messages::{Complex, NeighborData};
20
/// Tunable parameters for the direct/ambience separation stage.
///
/// The three `*_weight` fields set the relative influence of each cue. The
/// analyzer divides the weighted sum by the sum of the weights, so they do not
/// have to add up to 1.0.
#[derive(Debug, Clone, PartialEq)]
pub struct SeparationConfig {
    /// Weight for phase coherence in separation (0.0-1.0).
    pub phase_coherence_weight: f32,
    /// Weight for spectral flux in separation (0.0-1.0).
    pub spectral_flux_weight: f32,
    /// Weight for magnitude correlation (0.0-1.0).
    pub magnitude_correlation_weight: f32,
    /// Transient sensitivity (higher = more sensitive to attacks).
    pub transient_sensitivity: f32,
    /// Temporal smoothing factor (0.0 = no smoothing, 1.0 = full smoothing).
    pub temporal_smoothing: f32,
    /// Separation curve exponent (higher = sharper separation).
    pub separation_curve: f32,
    /// Minimum coherence threshold (below this = pure ambience).
    ///
    /// Must be strictly below `max_coherence`; use
    /// [`SeparationConfig::with_coherence_range`] to set both safely.
    pub min_coherence: f32,
    /// Maximum coherence threshold (above this = pure direct).
    pub max_coherence: f32,
    /// Frequency-dependent weighting (lower frequencies get more smoothing).
    pub frequency_smoothing: bool,
}

impl Default for SeparationConfig {
    /// Balanced general-purpose settings.
    fn default() -> Self {
        Self {
            phase_coherence_weight: 0.4,
            spectral_flux_weight: 0.3,
            magnitude_correlation_weight: 0.3,
            transient_sensitivity: 1.0,
            temporal_smoothing: 0.7,
            separation_curve: 1.5,
            min_coherence: 0.1,
            max_coherence: 0.9,
            frequency_smoothing: true,
        }
    }
}

impl SeparationConfig {
    /// Smallest gap between `min_coherence` and `max_coherence` accepted by
    /// [`SeparationConfig::with_coherence_range`]. The analyzer divides by the
    /// gap, so it must stay comfortably away from zero.
    const MIN_COHERENCE_SPAN: f32 = 1e-3;

    /// Create a new configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Clamp `requested` into `[0.0, upper]`, keeping `current` when the
    /// request is NaN (`f32::clamp` would otherwise pass the NaN through).
    fn bounded_or_keep(requested: f32, upper: f32, current: f32) -> f32 {
        if requested.is_nan() {
            current
        } else {
            requested.clamp(0.0, upper)
        }
    }

    /// Set phase coherence weight, clamped to `[0.0, 1.0]`. NaN is ignored.
    #[must_use]
    pub fn with_phase_coherence_weight(mut self, weight: f32) -> Self {
        self.phase_coherence_weight =
            Self::bounded_or_keep(weight, 1.0, self.phase_coherence_weight);
        self
    }

    /// Set spectral flux weight, clamped to `[0.0, 1.0]`. NaN is ignored.
    #[must_use]
    pub fn with_spectral_flux_weight(mut self, weight: f32) -> Self {
        self.spectral_flux_weight = Self::bounded_or_keep(weight, 1.0, self.spectral_flux_weight);
        self
    }

    /// Set magnitude correlation weight, clamped to `[0.0, 1.0]`. NaN is ignored.
    #[must_use]
    pub fn with_magnitude_correlation_weight(mut self, weight: f32) -> Self {
        self.magnitude_correlation_weight =
            Self::bounded_or_keep(weight, 1.0, self.magnitude_correlation_weight);
        self
    }

    /// Set transient sensitivity. Negative values (and NaN) become `0.0`.
    #[must_use]
    pub fn with_transient_sensitivity(mut self, sensitivity: f32) -> Self {
        self.transient_sensitivity = sensitivity.max(0.0);
        self
    }

    /// Set temporal smoothing, clamped to `[0.0, 0.99]`. NaN is ignored.
    ///
    /// The upper bound is deliberately just below 1.0.
    #[must_use]
    pub fn with_temporal_smoothing(mut self, smoothing: f32) -> Self {
        self.temporal_smoothing = Self::bounded_or_keep(smoothing, 0.99, self.temporal_smoothing);
        self
    }

    /// Set separation curve exponent. Values below `0.1` (and NaN) become `0.1`.
    #[must_use]
    pub fn with_separation_curve(mut self, curve: f32) -> Self {
        self.separation_curve = curve.max(0.1);
        self
    }

    /// Set both coherence thresholds at once.
    ///
    /// The two bounds are clamped to `[0.0, 1.0]` and swapped if given in the
    /// wrong order. The request is ignored (current thresholds are kept) when
    /// either bound is NaN or when the resulting range is narrower than
    /// `1e-3`, because the analyzer divides by `max_coherence - min_coherence`.
    #[must_use]
    pub fn with_coherence_range(mut self, min: f32, max: f32) -> Self {
        if min.is_nan() || max.is_nan() {
            return self;
        }
        let lower = min.min(max).clamp(0.0, 1.0);
        let upper = min.max(max).clamp(0.0, 1.0);
        if upper - lower >= Self::MIN_COHERENCE_SPAN {
            self.min_coherence = lower;
            self.max_coherence = upper;
        }
        self
    }

    /// Preset for music (more ambience preserved).
    #[must_use]
    pub fn music_preset() -> Self {
        Self {
            phase_coherence_weight: 0.35,
            spectral_flux_weight: 0.25,
            magnitude_correlation_weight: 0.4,
            transient_sensitivity: 0.8,
            temporal_smoothing: 0.8,
            separation_curve: 1.2,
            min_coherence: 0.15,
            max_coherence: 0.85,
            frequency_smoothing: true,
        }
    }

    /// Preset for speech (cleaner separation).
    #[must_use]
    pub fn speech_preset() -> Self {
        Self {
            phase_coherence_weight: 0.5,
            spectral_flux_weight: 0.3,
            magnitude_correlation_weight: 0.2,
            transient_sensitivity: 1.2,
            temporal_smoothing: 0.6,
            separation_curve: 2.0,
            min_coherence: 0.1,
            max_coherence: 0.9,
            frequency_smoothing: true,
        }
    }

    /// Preset for aggressive separation.
    #[must_use]
    pub fn aggressive_preset() -> Self {
        Self {
            phase_coherence_weight: 0.45,
            spectral_flux_weight: 0.35,
            magnitude_correlation_weight: 0.2,
            transient_sensitivity: 1.5,
            temporal_smoothing: 0.5,
            separation_curve: 2.5,
            min_coherence: 0.05,
            max_coherence: 0.95,
            frequency_smoothing: false,
        }
    }
}
141
142/// Coherence analyzer for bin-to-bin relationships.
143pub struct CoherenceAnalyzer {
144    config: SeparationConfig,
145    /// Running average of phase coherence.
146    phase_coherence_avg: f32,
147    /// Running average of magnitude.
148    magnitude_avg: f32,
149    /// Running average of spectral flux.
150    flux_avg: f32,
151    /// Frame count for averaging.
152    frame_count: u64,
153}
154
155impl CoherenceAnalyzer {
156    /// Create a new coherence analyzer.
157    pub fn new(config: SeparationConfig) -> Self {
158        Self {
159            config,
160            phase_coherence_avg: 0.0,
161            magnitude_avg: 0.0,
162            flux_avg: 0.0,
163            frame_count: 0,
164        }
165    }
166
167    /// Analyze coherence and return (coherence, transient) scores.
168    pub fn analyze(
169        &mut self,
170        current: &Complex,
171        left_neighbor: Option<&NeighborData>,
172        right_neighbor: Option<&NeighborData>,
173        _phase_derivative: f32,
174        spectral_flux: f32,
175    ) -> (f32, f32) {
176        self.frame_count += 1;
177
178        // 1. Compute phase coherence with neighbors
179        let phase_coherence = self.compute_phase_coherence(current, left_neighbor, right_neighbor);
180
181        // 2. Compute magnitude correlation
182        let magnitude_correlation =
183            self.compute_magnitude_correlation(current, left_neighbor, right_neighbor);
184
185        // 3. Compute transient score based on spectral flux
186        let transient = self.compute_transient_score(spectral_flux);
187
188        // 4. Update running averages (for adaptive thresholds)
189        let alpha = 0.99;
190        self.phase_coherence_avg =
191            self.phase_coherence_avg * alpha + phase_coherence * (1.0 - alpha);
192        self.magnitude_avg = self.magnitude_avg * alpha + current.magnitude() * (1.0 - alpha);
193        self.flux_avg = self.flux_avg * alpha + spectral_flux * (1.0 - alpha);
194
195        // 5. Combine cues with weights
196        let coherence = self.config.phase_coherence_weight * phase_coherence
197            + self.config.magnitude_correlation_weight * magnitude_correlation
198            + self.config.spectral_flux_weight * transient;
199
200        // Normalize and clamp
201        let total_weight = self.config.phase_coherence_weight
202            + self.config.magnitude_correlation_weight
203            + self.config.spectral_flux_weight;
204
205        let coherence = if total_weight > 0.0 {
206            (coherence / total_weight).clamp(self.config.min_coherence, self.config.max_coherence)
207        } else {
208            0.5
209        };
210
211        // Rescale to 0-1 range
212        let coherence = (coherence - self.config.min_coherence)
213            / (self.config.max_coherence - self.config.min_coherence);
214
215        (coherence.clamp(0.0, 1.0), transient)
216    }
217
218    /// Compute phase coherence with neighbors.
219    fn compute_phase_coherence(
220        &self,
221        current: &Complex,
222        left: Option<&NeighborData>,
223        right: Option<&NeighborData>,
224    ) -> f32 {
225        let current_phase = current.phase();
226        let mut coherence_sum = 0.0;
227        let mut count = 0;
228
229        // Compare phase with left neighbor
230        if let Some(left_data) = left {
231            let phase_diff = self.wrapped_phase_diff(current_phase, left_data.phase);
232            // Coherent signals have small phase differences or differences that
233            // follow a linear progression
234            let coherence = (-phase_diff.abs() * 2.0).exp();
235            coherence_sum += coherence;
236            count += 1;
237        }
238
239        // Compare phase with right neighbor
240        if let Some(right_data) = right {
241            let phase_diff = self.wrapped_phase_diff(current_phase, right_data.phase);
242            let coherence = (-phase_diff.abs() * 2.0).exp();
243            coherence_sum += coherence;
244            count += 1;
245        }
246
247        // Check phase derivative consistency between neighbors
248        if let (Some(left_data), Some(right_data)) = (left, right) {
249            // For coherent signals, the phase derivative should vary smoothly
250            let left_deriv = left_data.phase_derivative;
251            let right_deriv = right_data.phase_derivative;
252            let deriv_diff = (left_deriv - right_deriv).abs();
253            let deriv_coherence = (-deriv_diff).exp();
254            coherence_sum += deriv_coherence * 0.5;
255            count += 1;
256        }
257
258        if count > 0 {
259            coherence_sum / count as f32
260        } else {
261            0.5 // Default to neutral if no neighbors
262        }
263    }
264
265    /// Compute magnitude correlation with neighbors.
266    fn compute_magnitude_correlation(
267        &self,
268        current: &Complex,
269        left: Option<&NeighborData>,
270        right: Option<&NeighborData>,
271    ) -> f32 {
272        let current_mag = current.magnitude();
273        let mut correlation_sum = 0.0;
274        let mut count = 0;
275
276        if let Some(left_data) = left {
277            // Compute correlation based on relative magnitudes
278            let left_mag = left_data.magnitude;
279            if left_mag > 1e-10 && current_mag > 1e-10 {
280                let ratio = (current_mag / left_mag).ln().abs();
281                // Similar magnitudes indicate coherent source
282                let correlation = (-ratio * 0.5).exp();
283                correlation_sum += correlation;
284                count += 1;
285            }
286        }
287
288        if let Some(right_data) = right {
289            let right_mag = right_data.magnitude;
290            if right_mag > 1e-10 && current_mag > 1e-10 {
291                let ratio = (current_mag / right_mag).ln().abs();
292                let correlation = (-ratio * 0.5).exp();
293                correlation_sum += correlation;
294                count += 1;
295            }
296        }
297
298        // Check flux correlation (coherent sources have correlated flux)
299        if let (Some(left_data), Some(right_data)) = (left, right) {
300            let left_flux = left_data.spectral_flux;
301            let right_flux = right_data.spectral_flux;
302            let avg_flux = (left_flux + right_flux) / 2.0;
303            if avg_flux > 1e-6 {
304                let flux_ratio = (left_flux - right_flux).abs() / avg_flux;
305                let flux_correlation = (-flux_ratio).exp();
306                correlation_sum += flux_correlation * 0.5;
307                count += 1;
308            }
309        }
310
311        if count > 0 {
312            correlation_sum / count as f32
313        } else {
314            0.5
315        }
316    }
317
318    /// Compute transient score from spectral flux.
319    fn compute_transient_score(&self, spectral_flux: f32) -> f32 {
320        // Normalize flux relative to running average
321        let threshold = self.flux_avg * 2.0 + 0.01;
322        let normalized_flux = spectral_flux / threshold;
323
324        // Apply sensitivity and sigmoid-like shaping
325        let shaped = (normalized_flux * self.config.transient_sensitivity).tanh();
326
327        shaped.clamp(0.0, 1.0)
328    }
329
330    /// Calculate wrapped phase difference.
331    fn wrapped_phase_diff(&self, phase1: f32, phase2: f32) -> f32 {
332        let mut diff = phase1 - phase2;
333        while diff > std::f32::consts::PI {
334            diff -= 2.0 * std::f32::consts::PI;
335        }
336        while diff < -std::f32::consts::PI {
337            diff += 2.0 * std::f32::consts::PI;
338        }
339        diff
340    }
341
342    /// Reset the analyzer state.
343    pub fn reset(&mut self) {
344        self.phase_coherence_avg = 0.0;
345        self.magnitude_avg = 0.0;
346        self.flux_avg = 0.0;
347        self.frame_count = 0;
348    }
349}
350
351/// Signal separator that applies the coherence analysis to split signals.
352pub struct SignalSeparator {
353    config: SeparationConfig,
354}
355
356impl SignalSeparator {
357    /// Create a new signal separator.
358    pub fn new(config: SeparationConfig) -> Self {
359        Self { config }
360    }
361
362    /// Separate a complex value into direct and ambient components.
363    pub fn separate(&self, value: Complex, coherence: f32) -> (Complex, Complex) {
364        // Apply separation curve
365        let direct_ratio = coherence.powf(self.config.separation_curve);
366        let ambient_ratio = 1.0 - direct_ratio;
367
368        let direct = value.scale(direct_ratio);
369        let ambient = value.scale(ambient_ratio);
370
371        (direct, ambient)
372    }
373
374    /// Separate with frequency-dependent adjustment.
375    pub fn separate_with_frequency(
376        &self,
377        value: Complex,
378        coherence: f32,
379        bin_index: u32,
380        total_bins: u32,
381    ) -> (Complex, Complex) {
382        let mut adjusted_coherence = coherence;
383
384        if self.config.frequency_smoothing {
385            // Lower frequencies get more smoothing (less separation)
386            // Higher frequencies can have sharper separation
387            let freq_ratio = bin_index as f32 / total_bins as f32;
388            let freq_factor = 0.8 + 0.4 * freq_ratio; // 0.8 at DC, 1.2 at Nyquist
389
390            adjusted_coherence = coherence * freq_factor;
391            adjusted_coherence = adjusted_coherence.clamp(0.0, 1.0);
392        }
393
394        self.separate(value, adjusted_coherence)
395    }
396
397    /// Get configuration.
398    pub fn config(&self) -> &SeparationConfig {
399        &self.config
400    }
401
402    /// Update configuration.
403    pub fn set_config(&mut self, config: SeparationConfig) {
404        self.config = config;
405    }
406}
407
408/// Stereo separation for maintaining spatial information.
409pub struct StereoSeparator {
410    left_analyzer: CoherenceAnalyzer,
411    right_analyzer: CoherenceAnalyzer,
412    separator: SignalSeparator,
413    /// Cross-channel coherence weight.
414    cross_channel_weight: f32,
415}
416
417impl StereoSeparator {
418    /// Create a new stereo separator.
419    pub fn new(config: SeparationConfig) -> Self {
420        Self {
421            left_analyzer: CoherenceAnalyzer::new(config.clone()),
422            right_analyzer: CoherenceAnalyzer::new(config.clone()),
423            separator: SignalSeparator::new(config),
424            cross_channel_weight: 0.3,
425        }
426    }
427
428    /// Process stereo bins and return separated results.
429    #[allow(clippy::too_many_arguments)]
430    pub fn process_stereo(
431        &mut self,
432        left_bin: &Complex,
433        right_bin: &Complex,
434        left_neighbors: (Option<&NeighborData>, Option<&NeighborData>),
435        right_neighbors: (Option<&NeighborData>, Option<&NeighborData>),
436        left_phase_deriv: f32,
437        right_phase_deriv: f32,
438        left_flux: f32,
439        right_flux: f32,
440        bin_index: u32,
441        total_bins: u32,
442    ) -> ((Complex, Complex), (Complex, Complex)) {
443        // Analyze each channel
444        let (left_coherence, _left_transient) = self.left_analyzer.analyze(
445            left_bin,
446            left_neighbors.0,
447            left_neighbors.1,
448            left_phase_deriv,
449            left_flux,
450        );
451
452        let (right_coherence, _right_transient) = self.right_analyzer.analyze(
453            right_bin,
454            right_neighbors.0,
455            right_neighbors.1,
456            right_phase_deriv,
457            right_flux,
458        );
459
460        // Cross-channel coherence (correlated L/R = direct source)
461        let cross_coherence = self.compute_cross_channel_coherence(left_bin, right_bin);
462
463        // Combine with cross-channel information
464        let combined_left_coherence = left_coherence * (1.0 - self.cross_channel_weight)
465            + cross_coherence * self.cross_channel_weight;
466        let combined_right_coherence = right_coherence * (1.0 - self.cross_channel_weight)
467            + cross_coherence * self.cross_channel_weight;
468
469        // Separate each channel
470        let left_separated = self.separator.separate_with_frequency(
471            *left_bin,
472            combined_left_coherence,
473            bin_index,
474            total_bins,
475        );
476        let right_separated = self.separator.separate_with_frequency(
477            *right_bin,
478            combined_right_coherence,
479            bin_index,
480            total_bins,
481        );
482
483        (left_separated, right_separated)
484    }
485
486    /// Compute cross-channel coherence.
487    fn compute_cross_channel_coherence(&self, left: &Complex, right: &Complex) -> f32 {
488        // Compute correlation between left and right channels
489        let left_mag = left.magnitude();
490        let right_mag = right.magnitude();
491
492        if left_mag < 1e-10 || right_mag < 1e-10 {
493            return 0.5;
494        }
495
496        // Magnitude similarity
497        let mag_ratio = (left_mag / right_mag).ln().abs();
498        let mag_coherence = (-mag_ratio * 0.5).exp();
499
500        // Phase similarity (mono sources have similar phase)
501        let phase_diff = self.wrapped_phase_diff(left.phase(), right.phase());
502        let phase_coherence = (-phase_diff.abs() * 2.0).exp();
503
504        // Combine (high correlation in both = likely direct sound)
505        0.6 * phase_coherence + 0.4 * mag_coherence
506    }
507
508    fn wrapped_phase_diff(&self, phase1: f32, phase2: f32) -> f32 {
509        let mut diff = phase1 - phase2;
510        while diff > std::f32::consts::PI {
511            diff -= 2.0 * std::f32::consts::PI;
512        }
513        while diff < -std::f32::consts::PI {
514            diff += 2.0 * std::f32::consts::PI;
515        }
516        diff
517    }
518
519    /// Reset both analyzers.
520    pub fn reset(&mut self) {
521        self.left_analyzer.reset();
522        self.right_analyzer.reset();
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Each preset keeps the tuning trait that sets it apart.
    #[test]
    fn test_separation_config_presets() {
        assert!(SeparationConfig::music_preset().temporal_smoothing > 0.7);
        assert!(SeparationConfig::speech_preset().separation_curve > 1.5);
        assert!(SeparationConfig::aggressive_preset().transient_sensitivity > 1.0);
    }

    /// An isolated bin (no neighbors, no flux) still yields scores in 0..=1.
    #[test]
    fn test_coherence_analyzer() {
        let mut analyzer = CoherenceAnalyzer::new(SeparationConfig::default());
        let bin = Complex::new(1.0, 0.0);

        let (coherence, transient) = analyzer.analyze(&bin, None, None, 0.0, 0.0);

        assert!((0.0..=1.0).contains(&coherence));
        assert!((0.0..=1.0).contains(&transient));
    }

    /// The coherence score decides which component receives most of the signal.
    #[test]
    fn test_signal_separator() {
        let separator = SignalSeparator::new(SeparationConfig::default());
        let input = Complex::new(1.0, 0.0);

        // High coherence: the direct part dominates.
        let (direct, ambient) = separator.separate(input, 0.9);
        assert!(direct.magnitude() > ambient.magnitude());

        // Low coherence: the ambient part dominates.
        let (direct2, ambient2) = separator.separate(input, 0.1);
        assert!(ambient2.magnitude() > direct2.magnitude());
    }

    /// Splitting must not create energy: the summed energy of both parts
    /// stays within 110% of the input energy for every coherence value.
    #[test]
    fn test_separation_preserves_energy() {
        let separator = SignalSeparator::new(SeparationConfig::default());

        let input = Complex::new(3.0, 4.0); // magnitude = 5
        let original_energy = input.magnitude_squared();

        for coherence in [0.0, 0.25, 0.5, 0.75, 1.0] {
            let (direct, ambient) = separator.separate(input, coherence);
            let separated_energy = direct.magnitude_squared() + ambient.magnitude_squared();
            assert!(separated_energy <= original_energy * 1.1);
        }
    }
}