1use crate::messages::{Complex, NeighborData};
20
/// Tuning parameters for coherence-based direct/ambient signal separation.
///
/// The three `*_weight` fields are combined into one coherence score by
/// `CoherenceAnalyzer::analyze` (divided by their sum, so only their ratios
/// matter). That score is clamped to `[min_coherence, max_coherence]` and
/// rescaled onto `[0, 1]`; `SignalSeparator` then turns it into direct/ambient
/// gains using `separation_curve`.
#[derive(Debug, Clone, PartialEq)]
pub struct SeparationConfig {
    /// Weight of the neighbour phase-coherence term.
    pub phase_coherence_weight: f32,
    /// Weight of the transient (spectral-flux) term.
    pub spectral_flux_weight: f32,
    /// Weight of the neighbour magnitude-correlation term.
    pub magnitude_correlation_weight: f32,
    /// Gain applied to the normalised spectral flux before `tanh` shaping of
    /// the transient score.
    pub transient_sensitivity: f32,
    /// Temporal smoothing factor (the builder clamps it to `[0, 0.99]`).
    ///
    /// NOTE(review): nothing in this module reads this field; the analyzer's
    /// running averages use a fixed 0.99 decay. Confirm where it should apply.
    pub temporal_smoothing: f32,
    /// Exponent mapping coherence to the direct-path gain (`coherence^curve`);
    /// values above 1 send more of a partly coherent bin to the ambient output.
    pub separation_curve: f32,
    /// Lower edge of the coherence range that is rescaled onto `[0, 1]`.
    pub min_coherence: f32,
    /// Upper edge of the coherence range that is rescaled onto `[0, 1]`.
    pub max_coherence: f32,
    /// Enables the bin-position-dependent coherence scaling in
    /// `SignalSeparator::separate_with_frequency`.
    pub frequency_smoothing: bool,
}

impl Default for SeparationConfig {
    /// Balanced general-purpose settings (phase 0.4 / flux 0.3 / magnitude 0.3).
    fn default() -> Self {
        Self {
            phase_coherence_weight: 0.4,
            spectral_flux_weight: 0.3,
            magnitude_correlation_weight: 0.3,
            transient_sensitivity: 1.0,
            temporal_smoothing: 0.7,
            separation_curve: 1.5,
            min_coherence: 0.1,
            max_coherence: 0.9,
            frequency_smoothing: true,
        }
    }
}

impl SeparationConfig {
    /// Creates the default configuration; equivalent to `SeparationConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the phase-coherence weight, clamped to `[0, 1]`.
    pub fn with_phase_coherence_weight(mut self, weight: f32) -> Self {
        self.phase_coherence_weight = weight.clamp(0.0, 1.0);
        self
    }

    /// Sets the spectral-flux (transient) weight, clamped to `[0, 1]`.
    pub fn with_spectral_flux_weight(mut self, weight: f32) -> Self {
        self.spectral_flux_weight = weight.clamp(0.0, 1.0);
        self
    }

    /// Sets the magnitude-correlation weight, clamped to `[0, 1]`.
    ///
    /// Completes the builder set: the other two weights already had setters.
    pub fn with_magnitude_correlation_weight(mut self, weight: f32) -> Self {
        self.magnitude_correlation_weight = weight.clamp(0.0, 1.0);
        self
    }

    /// Sets the transient sensitivity; negative values become `0.0`.
    pub fn with_transient_sensitivity(mut self, sensitivity: f32) -> Self {
        self.transient_sensitivity = sensitivity.max(0.0);
        self
    }

    /// Sets the temporal smoothing factor, clamped to `[0, 0.99]`.
    pub fn with_temporal_smoothing(mut self, smoothing: f32) -> Self {
        self.temporal_smoothing = smoothing.clamp(0.0, 0.99);
        self
    }

    /// Sets the separation exponent; values below `0.1` become `0.1`.
    pub fn with_separation_curve(mut self, curve: f32) -> Self {
        self.separation_curve = curve.max(0.1);
        self
    }

    /// Sets the coherence range that is rescaled onto `[0, 1]`.
    ///
    /// Both bounds are clamped to `[0, 1]` and stored in ascending order, so the
    /// resulting config always has `min_coherence <= max_coherence`. A NaN bound
    /// is treated as `0.0` (`f32::max` ignores NaN).
    pub fn with_coherence_range(mut self, min: f32, max: f32) -> Self {
        let a = min.max(0.0).min(1.0);
        let b = max.max(0.0).min(1.0);
        self.min_coherence = a.min(b);
        self.max_coherence = a.max(b);
        self
    }

    /// Enables or disables the per-bin frequency scaling of coherence.
    pub fn with_frequency_smoothing(mut self, enabled: bool) -> Self {
        self.frequency_smoothing = enabled;
        self
    }

    /// Preset for music: magnitude correlation weighted highest, heavier
    /// smoothing, and a gentler separation curve than the default.
    pub fn music_preset() -> Self {
        Self {
            phase_coherence_weight: 0.35,
            spectral_flux_weight: 0.25,
            magnitude_correlation_weight: 0.4,
            transient_sensitivity: 0.8,
            temporal_smoothing: 0.8,
            separation_curve: 1.2,
            min_coherence: 0.15,
            max_coherence: 0.85,
            frequency_smoothing: true,
        }
    }

    /// Preset for speech: phase coherence weighted highest, higher transient
    /// sensitivity and a steeper separation curve than the default.
    pub fn speech_preset() -> Self {
        Self {
            phase_coherence_weight: 0.5,
            spectral_flux_weight: 0.3,
            magnitude_correlation_weight: 0.2,
            transient_sensitivity: 1.2,
            temporal_smoothing: 0.6,
            separation_curve: 2.0,
            min_coherence: 0.1,
            max_coherence: 0.9,
            frequency_smoothing: true,
        }
    }

    /// Aggressive preset: steepest curve, widest coherence range, highest
    /// transient sensitivity, and frequency smoothing disabled.
    pub fn aggressive_preset() -> Self {
        Self {
            phase_coherence_weight: 0.45,
            spectral_flux_weight: 0.35,
            magnitude_correlation_weight: 0.2,
            transient_sensitivity: 1.5,
            temporal_smoothing: 0.5,
            separation_curve: 2.5,
            min_coherence: 0.05,
            max_coherence: 0.95,
            frequency_smoothing: false,
        }
    }
}

142pub struct CoherenceAnalyzer {
144 config: SeparationConfig,
145 phase_coherence_avg: f32,
147 magnitude_avg: f32,
149 flux_avg: f32,
151 frame_count: u64,
153}
154
155impl CoherenceAnalyzer {
156 pub fn new(config: SeparationConfig) -> Self {
158 Self {
159 config,
160 phase_coherence_avg: 0.0,
161 magnitude_avg: 0.0,
162 flux_avg: 0.0,
163 frame_count: 0,
164 }
165 }
166
167 pub fn analyze(
169 &mut self,
170 current: &Complex,
171 left_neighbor: Option<&NeighborData>,
172 right_neighbor: Option<&NeighborData>,
173 _phase_derivative: f32,
174 spectral_flux: f32,
175 ) -> (f32, f32) {
176 self.frame_count += 1;
177
178 let phase_coherence = self.compute_phase_coherence(current, left_neighbor, right_neighbor);
180
181 let magnitude_correlation =
183 self.compute_magnitude_correlation(current, left_neighbor, right_neighbor);
184
185 let transient = self.compute_transient_score(spectral_flux);
187
188 let alpha = 0.99;
190 self.phase_coherence_avg =
191 self.phase_coherence_avg * alpha + phase_coherence * (1.0 - alpha);
192 self.magnitude_avg = self.magnitude_avg * alpha + current.magnitude() * (1.0 - alpha);
193 self.flux_avg = self.flux_avg * alpha + spectral_flux * (1.0 - alpha);
194
195 let coherence = self.config.phase_coherence_weight * phase_coherence
197 + self.config.magnitude_correlation_weight * magnitude_correlation
198 + self.config.spectral_flux_weight * transient;
199
200 let total_weight = self.config.phase_coherence_weight
202 + self.config.magnitude_correlation_weight
203 + self.config.spectral_flux_weight;
204
205 let coherence = if total_weight > 0.0 {
206 (coherence / total_weight).clamp(self.config.min_coherence, self.config.max_coherence)
207 } else {
208 0.5
209 };
210
211 let coherence = (coherence - self.config.min_coherence)
213 / (self.config.max_coherence - self.config.min_coherence);
214
215 (coherence.clamp(0.0, 1.0), transient)
216 }
217
218 fn compute_phase_coherence(
220 &self,
221 current: &Complex,
222 left: Option<&NeighborData>,
223 right: Option<&NeighborData>,
224 ) -> f32 {
225 let current_phase = current.phase();
226 let mut coherence_sum = 0.0;
227 let mut count = 0;
228
229 if let Some(left_data) = left {
231 let phase_diff = self.wrapped_phase_diff(current_phase, left_data.phase);
232 let coherence = (-phase_diff.abs() * 2.0).exp();
235 coherence_sum += coherence;
236 count += 1;
237 }
238
239 if let Some(right_data) = right {
241 let phase_diff = self.wrapped_phase_diff(current_phase, right_data.phase);
242 let coherence = (-phase_diff.abs() * 2.0).exp();
243 coherence_sum += coherence;
244 count += 1;
245 }
246
247 if let (Some(left_data), Some(right_data)) = (left, right) {
249 let left_deriv = left_data.phase_derivative;
251 let right_deriv = right_data.phase_derivative;
252 let deriv_diff = (left_deriv - right_deriv).abs();
253 let deriv_coherence = (-deriv_diff).exp();
254 coherence_sum += deriv_coherence * 0.5;
255 count += 1;
256 }
257
258 if count > 0 {
259 coherence_sum / count as f32
260 } else {
261 0.5 }
263 }
264
265 fn compute_magnitude_correlation(
267 &self,
268 current: &Complex,
269 left: Option<&NeighborData>,
270 right: Option<&NeighborData>,
271 ) -> f32 {
272 let current_mag = current.magnitude();
273 let mut correlation_sum = 0.0;
274 let mut count = 0;
275
276 if let Some(left_data) = left {
277 let left_mag = left_data.magnitude;
279 if left_mag > 1e-10 && current_mag > 1e-10 {
280 let ratio = (current_mag / left_mag).ln().abs();
281 let correlation = (-ratio * 0.5).exp();
283 correlation_sum += correlation;
284 count += 1;
285 }
286 }
287
288 if let Some(right_data) = right {
289 let right_mag = right_data.magnitude;
290 if right_mag > 1e-10 && current_mag > 1e-10 {
291 let ratio = (current_mag / right_mag).ln().abs();
292 let correlation = (-ratio * 0.5).exp();
293 correlation_sum += correlation;
294 count += 1;
295 }
296 }
297
298 if let (Some(left_data), Some(right_data)) = (left, right) {
300 let left_flux = left_data.spectral_flux;
301 let right_flux = right_data.spectral_flux;
302 let avg_flux = (left_flux + right_flux) / 2.0;
303 if avg_flux > 1e-6 {
304 let flux_ratio = (left_flux - right_flux).abs() / avg_flux;
305 let flux_correlation = (-flux_ratio).exp();
306 correlation_sum += flux_correlation * 0.5;
307 count += 1;
308 }
309 }
310
311 if count > 0 {
312 correlation_sum / count as f32
313 } else {
314 0.5
315 }
316 }
317
318 fn compute_transient_score(&self, spectral_flux: f32) -> f32 {
320 let threshold = self.flux_avg * 2.0 + 0.01;
322 let normalized_flux = spectral_flux / threshold;
323
324 let shaped = (normalized_flux * self.config.transient_sensitivity).tanh();
326
327 shaped.clamp(0.0, 1.0)
328 }
329
330 fn wrapped_phase_diff(&self, phase1: f32, phase2: f32) -> f32 {
332 let mut diff = phase1 - phase2;
333 while diff > std::f32::consts::PI {
334 diff -= 2.0 * std::f32::consts::PI;
335 }
336 while diff < -std::f32::consts::PI {
337 diff += 2.0 * std::f32::consts::PI;
338 }
339 diff
340 }
341
342 pub fn reset(&mut self) {
344 self.phase_coherence_avg = 0.0;
345 self.magnitude_avg = 0.0;
346 self.flux_avg = 0.0;
347 self.frame_count = 0;
348 }
349}
350
351pub struct SignalSeparator {
353 config: SeparationConfig,
354}
355
356impl SignalSeparator {
357 pub fn new(config: SeparationConfig) -> Self {
359 Self { config }
360 }
361
362 pub fn separate(&self, value: Complex, coherence: f32) -> (Complex, Complex) {
364 let direct_ratio = coherence.powf(self.config.separation_curve);
366 let ambient_ratio = 1.0 - direct_ratio;
367
368 let direct = value.scale(direct_ratio);
369 let ambient = value.scale(ambient_ratio);
370
371 (direct, ambient)
372 }
373
374 pub fn separate_with_frequency(
376 &self,
377 value: Complex,
378 coherence: f32,
379 bin_index: u32,
380 total_bins: u32,
381 ) -> (Complex, Complex) {
382 let mut adjusted_coherence = coherence;
383
384 if self.config.frequency_smoothing {
385 let freq_ratio = bin_index as f32 / total_bins as f32;
388 let freq_factor = 0.8 + 0.4 * freq_ratio; adjusted_coherence = coherence * freq_factor;
391 adjusted_coherence = adjusted_coherence.clamp(0.0, 1.0);
392 }
393
394 self.separate(value, adjusted_coherence)
395 }
396
397 pub fn config(&self) -> &SeparationConfig {
399 &self.config
400 }
401
402 pub fn set_config(&mut self, config: SeparationConfig) {
404 self.config = config;
405 }
406}
407
408pub struct StereoSeparator {
410 left_analyzer: CoherenceAnalyzer,
411 right_analyzer: CoherenceAnalyzer,
412 separator: SignalSeparator,
413 cross_channel_weight: f32,
415}
416
417impl StereoSeparator {
418 pub fn new(config: SeparationConfig) -> Self {
420 Self {
421 left_analyzer: CoherenceAnalyzer::new(config.clone()),
422 right_analyzer: CoherenceAnalyzer::new(config.clone()),
423 separator: SignalSeparator::new(config),
424 cross_channel_weight: 0.3,
425 }
426 }
427
428 #[allow(clippy::too_many_arguments)]
430 pub fn process_stereo(
431 &mut self,
432 left_bin: &Complex,
433 right_bin: &Complex,
434 left_neighbors: (Option<&NeighborData>, Option<&NeighborData>),
435 right_neighbors: (Option<&NeighborData>, Option<&NeighborData>),
436 left_phase_deriv: f32,
437 right_phase_deriv: f32,
438 left_flux: f32,
439 right_flux: f32,
440 bin_index: u32,
441 total_bins: u32,
442 ) -> ((Complex, Complex), (Complex, Complex)) {
443 let (left_coherence, _left_transient) = self.left_analyzer.analyze(
445 left_bin,
446 left_neighbors.0,
447 left_neighbors.1,
448 left_phase_deriv,
449 left_flux,
450 );
451
452 let (right_coherence, _right_transient) = self.right_analyzer.analyze(
453 right_bin,
454 right_neighbors.0,
455 right_neighbors.1,
456 right_phase_deriv,
457 right_flux,
458 );
459
460 let cross_coherence = self.compute_cross_channel_coherence(left_bin, right_bin);
462
463 let combined_left_coherence = left_coherence * (1.0 - self.cross_channel_weight)
465 + cross_coherence * self.cross_channel_weight;
466 let combined_right_coherence = right_coherence * (1.0 - self.cross_channel_weight)
467 + cross_coherence * self.cross_channel_weight;
468
469 let left_separated = self.separator.separate_with_frequency(
471 *left_bin,
472 combined_left_coherence,
473 bin_index,
474 total_bins,
475 );
476 let right_separated = self.separator.separate_with_frequency(
477 *right_bin,
478 combined_right_coherence,
479 bin_index,
480 total_bins,
481 );
482
483 (left_separated, right_separated)
484 }
485
486 fn compute_cross_channel_coherence(&self, left: &Complex, right: &Complex) -> f32 {
488 let left_mag = left.magnitude();
490 let right_mag = right.magnitude();
491
492 if left_mag < 1e-10 || right_mag < 1e-10 {
493 return 0.5;
494 }
495
496 let mag_ratio = (left_mag / right_mag).ln().abs();
498 let mag_coherence = (-mag_ratio * 0.5).exp();
499
500 let phase_diff = self.wrapped_phase_diff(left.phase(), right.phase());
502 let phase_coherence = (-phase_diff.abs() * 2.0).exp();
503
504 0.6 * phase_coherence + 0.4 * mag_coherence
506 }
507
508 fn wrapped_phase_diff(&self, phase1: f32, phase2: f32) -> f32 {
509 let mut diff = phase1 - phase2;
510 while diff > std::f32::consts::PI {
511 diff -= 2.0 * std::f32::consts::PI;
512 }
513 while diff < -std::f32::consts::PI {
514 diff += 2.0 * std::f32::consts::PI;
515 }
516 diff
517 }
518
519 pub fn reset(&mut self) {
521 self.left_analyzer.reset();
522 self.right_analyzer.reset();
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_separation_config_presets() {
        let music = SeparationConfig::music_preset();
        assert!(music.temporal_smoothing > 0.7);

        let speech = SeparationConfig::speech_preset();
        assert!(speech.separation_curve > 1.5);

        let aggressive = SeparationConfig::aggressive_preset();
        assert!(aggressive.transient_sensitivity > 1.0);
    }

    #[test]
    fn test_coherence_analyzer() {
        let config = SeparationConfig::default();
        let mut analyzer = CoherenceAnalyzer::new(config);

        let value = Complex::new(1.0, 0.0);
        let (coherence, transient) = analyzer.analyze(&value, None, None, 0.0, 0.0);

        assert!((0.0..=1.0).contains(&coherence));
        assert!((0.0..=1.0).contains(&transient));
    }

    #[test]
    fn test_signal_separator() {
        let config = SeparationConfig::default();
        let separator = SignalSeparator::new(config);

        let value = Complex::new(1.0, 0.0);

        // High coherence: mostly direct.
        let (direct, ambient) = separator.separate(value, 0.9);
        assert!(direct.magnitude() > ambient.magnitude());

        // Low coherence: mostly ambient.
        let (direct2, ambient2) = separator.separate(value, 0.1);
        assert!(ambient2.magnitude() > direct2.magnitude());
    }

    #[test]
    fn test_separation_preserves_energy() {
        let config = SeparationConfig::default();
        let separator = SignalSeparator::new(config);

        let value = Complex::new(3.0, 4.0);
        let original_energy = value.magnitude_squared();

        for coherence in [0.0, 0.25, 0.5, 0.75, 1.0] {
            let (direct, ambient) = separator.separate(value, coherence);
            let separated_energy = direct.magnitude_squared() + ambient.magnitude_squared();
            // Gains g and 1 - g with g in [0, 1] satisfy g^2 + (1 - g)^2 <= 1, so the
            // split can never add energy; allow only floating-point rounding.
            assert!(
                separated_energy <= original_energy * (1.0 + 1e-5),
                "coherence {}: {} > {}",
                coherence,
                separated_energy,
                original_energy
            );
        }
    }

    #[test]
    fn test_separation_is_monotonic_in_coherence() {
        let separator = SignalSeparator::new(SeparationConfig::default());
        let value = Complex::new(3.0, 4.0);

        let mut prev_direct = -1.0_f32;
        let mut prev_ambient = f32::INFINITY;
        for step in 0..=10 {
            let coherence = step as f32 / 10.0;
            let (direct, ambient) = separator.separate(value, coherence);
            assert!(direct.magnitude() >= prev_direct, "direct shrank at {}", coherence);
            assert!(ambient.magnitude() <= prev_ambient, "ambient grew at {}", coherence);
            prev_direct = direct.magnitude();
            prev_ambient = ambient.magnitude();
        }
    }
}