1use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9use parking_lot::RwLock;
10use tracing::info;
11
12use crate::audio_input::{AudioInput, AudioOutput, AudioSource};
13use crate::bin_actor::BinNetwork;
14use crate::error::{AudioFftError, Result};
15use crate::fft::{FftProcessor, IfftProcessor, WindowFunction};
16use crate::mixer::{FrameMixer, MixerConfig};
17use crate::separation::SeparationConfig;
18
/// Builder for [`AudioFftProcessor`].
///
/// Obtain one via [`AudioFftProcessorBuilder::new`] or [`AudioFftProcessor::builder`],
/// chain the setters, then call `build().await`.
#[derive(Debug, Clone)]
pub struct AudioFftProcessorBuilder {
    /// FFT frame length in samples; the processor works on `fft_size / 2 + 1` bins.
    fft_size: usize,
    /// Hop size handed to the FFT/IFFT processors (samples between successive frames).
    hop_size: usize,
    /// Sample rate stored on the built processor; `build` falls back to 44100 when unset.
    /// NOTE(review): `AudioFftProcessor::process` uses the input's own sample rate instead.
    sample_rate: Option<u32>,
    /// Window function passed to both the FFT and inverse-FFT processors.
    window: WindowFunction,
    /// Configuration the bin network is created with in `build`.
    separation_config: SeparationConfig,
    /// Configuration used to create the per-run frame mixer.
    mixer_config: MixerConfig,
}
29
30impl Default for AudioFftProcessorBuilder {
31 fn default() -> Self {
32 Self {
33 fft_size: 2048,
34 hop_size: 512,
35 sample_rate: None,
36 window: WindowFunction::Hann,
37 separation_config: SeparationConfig::default(),
38 mixer_config: MixerConfig::default(),
39 }
40 }
41}
42
43impl AudioFftProcessorBuilder {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn fft_size(mut self, size: usize) -> Self {
51 self.fft_size = size;
52 self
53 }
54
55 pub fn hop_size(mut self, size: usize) -> Self {
57 self.hop_size = size;
58 self
59 }
60
61 pub fn sample_rate(mut self, rate: u32) -> Self {
63 self.sample_rate = Some(rate);
64 self
65 }
66
67 pub fn window(mut self, window: WindowFunction) -> Self {
69 self.window = window;
70 self
71 }
72
73 pub fn separation_config(mut self, config: SeparationConfig) -> Self {
75 self.separation_config = config;
76 self
77 }
78
79 pub fn mixer_config(mut self, config: MixerConfig) -> Self {
81 self.mixer_config = config;
82 self
83 }
84
85 pub fn music_mode(mut self) -> Self {
87 self.separation_config = SeparationConfig::music_preset();
88 self
89 }
90
91 pub fn speech_mode(mut self) -> Self {
93 self.separation_config = SeparationConfig::speech_preset();
94 self
95 }
96
97 pub async fn build(self) -> Result<AudioFftProcessor> {
99 let sample_rate = self.sample_rate.unwrap_or(44100);
100
101 info!(
102 "Building AudioFftProcessor: FFT size={}, hop={}, sample_rate={}",
103 self.fft_size, self.hop_size, sample_rate
104 );
105
106 let num_bins = self.fft_size / 2 + 1;
107 let bin_network = BinNetwork::new(num_bins, self.separation_config.clone()).await?;
108
109 Ok(AudioFftProcessor {
110 fft_size: self.fft_size,
111 hop_size: self.hop_size,
112 sample_rate,
113 window: self.window,
114 separation_config: self.separation_config,
115 mixer_config: self.mixer_config,
116 bin_network: Some(bin_network),
117 frame_counter: AtomicU64::new(0),
118 stats: Arc::new(RwLock::new(ProcessingStats::default())),
119 })
120 }
121}
122
/// Statistics recorded at the end of each [`AudioFftProcessor::process`] call.
#[derive(Debug, Clone, Default)]
pub struct ProcessingStats {
    /// Number of FFT frames processed during the run.
    pub frames_processed: u64,
    /// Number of samples in the mixed output of the run.
    pub samples_processed: u64,
    /// `messages_delivered` as reported by the bin network's K2K statistics.
    pub k2k_messages: u64,
    /// Average wall-clock time per processed frame, in microseconds.
    pub avg_frame_time_us: f64,
    /// Direct-stream peak level reported by the mixer (units as defined by the mixer).
    pub peak_direct: f32,
    /// Ambience-stream peak level reported by the mixer (units as defined by the mixer).
    pub peak_ambience: f32,
}
139
/// Result of [`AudioFftProcessor::process`]: three resynthesized streams plus statistics.
#[derive(Debug)]
pub struct ProcessingOutput {
    /// Resynthesized direct component.
    pub direct: AudioOutput,
    /// Resynthesized ambience component.
    pub ambience: AudioOutput,
    /// Resynthesized mixer output.
    pub mixed: AudioOutput,
    /// Snapshot of the processor statistics taken when processing finished.
    pub stats: ProcessingStats,
}
152
153impl ProcessingOutput {
154 pub fn new(sample_rate: u32, channels: u8) -> Self {
156 Self {
157 direct: AudioOutput::new(sample_rate, channels),
158 ambience: AudioOutput::new(sample_rate, channels),
159 mixed: AudioOutput::new(sample_rate, channels),
160 stats: ProcessingStats::default(),
161 }
162 }
163}
164
/// Separates audio into direct and ambience components using FFT analysis, a per-bin
/// network (`BinNetwork`) and inverse-FFT resynthesis. Create one with
/// [`AudioFftProcessor::builder`].
pub struct AudioFftProcessor {
    /// FFT frame length in samples.
    fft_size: usize,
    /// Hop size handed to the FFT/IFFT processors.
    hop_size: usize,
    // Only stored: `process`/`process_streaming` take the sample rate from the input
    // itself, so nothing reads this field yet — hence the dead_code allowance.
    #[allow(dead_code)]
    sample_rate: u32,
    /// Window used for both analysis and synthesis.
    window: WindowFunction,
    /// Current separation settings. The bin network is created from the value held at
    /// `build` time; later changes are only stored.
    separation_config: SeparationConfig,
    /// Mixer settings; each `process` call creates a fresh mixer from these.
    mixer_config: MixerConfig,
    /// The bin network; `None` once `shutdown` has run.
    bin_network: Option<BinNetwork>,
    /// Frame id source, incremented per frame and never reset between `process` calls.
    frame_counter: AtomicU64,
    /// Statistics written at the end of each `process` call.
    stats: Arc<RwLock<ProcessingStats>>,
}
187
188impl AudioFftProcessor {
189 pub fn builder() -> AudioFftProcessorBuilder {
191 AudioFftProcessorBuilder::new()
192 }
193
194 pub fn fft_size(&self) -> usize {
196 self.fft_size
197 }
198
199 pub fn hop_size(&self) -> usize {
201 self.hop_size
202 }
203
204 pub fn num_bins(&self) -> usize {
206 self.fft_size / 2 + 1
207 }
208
209 pub fn stats(&self) -> ProcessingStats {
211 self.stats.read().clone()
212 }
213
214 pub async fn process(&mut self, mut input: AudioInput) -> Result<ProcessingOutput> {
216 let sample_rate = input.sample_rate();
218 let channels = input.channels();
219
220 info!(
221 "Processing audio: {} Hz, {} channels",
222 sample_rate, channels
223 );
224
225 let mut output = ProcessingOutput::new(sample_rate, channels);
226 let mut fft_processor =
227 FftProcessor::with_window(self.fft_size, self.hop_size, sample_rate, self.window)?;
228 let mut ifft_processor =
229 IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?;
230
231 let mut frame_mixer = FrameMixer::new(self.mixer_config.clone());
232
233 let bin_network = self
234 .bin_network
235 .as_mut()
236 .ok_or_else(|| AudioFftError::kernel("Bin network not initialized"))?;
237
238 let mut total_frames = 0u64;
240 let start_time = std::time::Instant::now();
241
242 while let Some(audio_frame) = input.read_frame(self.hop_size * 4)? {
243 let samples = if channels > 1 {
245 audio_frame.channel_samples(0)
246 } else {
247 audio_frame.samples.clone()
248 };
249
250 for fft_frame in fft_processor.process_all(&samples) {
252 let frame_id = self.frame_counter.fetch_add(1, Ordering::Relaxed);
253
254 let separated = bin_network
256 .process_frame(frame_id, &fft_frame, sample_rate, self.fft_size)
257 .await?;
258
259 let mixed = frame_mixer.process(&separated);
261
262 let direct_samples = ifft_processor.process_frame(&mixed.direct_bins);
264 let ambience_samples = ifft_processor.process_frame(&mixed.ambience_bins);
265 let mixed_samples = ifft_processor.process_frame(&mixed.bins);
266
267 output.direct.append(&direct_samples);
269 output.ambience.append(&ambience_samples);
270 output.mixed.append(&mixed_samples);
271
272 total_frames += 1;
273 }
274 }
275
276 if let Some(last_frame) = fft_processor.flush() {
278 let frame_id = self.frame_counter.fetch_add(1, Ordering::Relaxed);
279 let separated = bin_network
280 .process_frame(frame_id, &last_frame, sample_rate, self.fft_size)
281 .await?;
282 let mixed = frame_mixer.process(&separated);
283
284 output
285 .direct
286 .append(&ifft_processor.process_frame(&mixed.direct_bins));
287 output
288 .ambience
289 .append(&ifft_processor.process_frame(&mixed.ambience_bins));
290 output
291 .mixed
292 .append(&ifft_processor.process_frame(&mixed.bins));
293 }
294
295 output.direct.append(&ifft_processor.flush());
297 output.ambience.append(&ifft_processor.flush());
298 output.mixed.append(&ifft_processor.flush());
299
300 let elapsed = start_time.elapsed();
301 let avg_time = if total_frames > 0 {
302 elapsed.as_micros() as f64 / total_frames as f64
303 } else {
304 0.0
305 };
306
307 let k2k_stats = bin_network.k2k_stats();
309 {
310 let mut stats = self.stats.write();
311 stats.frames_processed = total_frames;
312 stats.samples_processed = output.mixed.samples.len() as u64;
313 stats.k2k_messages = k2k_stats.messages_delivered;
314 stats.avg_frame_time_us = avg_time;
315
316 let (direct_peak, amb_peak, _) = frame_mixer.mixer().peak_levels();
317 stats.peak_direct = direct_peak;
318 stats.peak_ambience = amb_peak;
319 }
320
321 output.stats = self.stats();
322
323 info!(
324 "Processed {} frames in {:?} ({:.1} us/frame)",
325 total_frames, elapsed, avg_time
326 );
327
328 Ok(output)
329 }
330
331 pub fn process_streaming(&mut self, input: AudioInput) -> Result<StreamingProcessor> {
333 let sample_rate = input.sample_rate();
334
335 Ok(StreamingProcessor {
336 input: Some(input),
337 fft: FftProcessor::with_window(self.fft_size, self.hop_size, sample_rate, self.window)?,
338 ifft_direct: IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?,
339 ifft_ambience: IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?,
340 ifft_mixed: IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?,
341 sample_rate,
342 fft_size: self.fft_size,
343 hop_size: self.hop_size,
344 mixer: FrameMixer::new(self.mixer_config.clone()),
345 frame_counter: 0,
346 })
347 }
348
349 pub fn set_dry_wet(&mut self, dry_wet: f32) {
351 self.mixer_config.dry_wet = dry_wet.clamp(0.0, 1.0);
352 }
353
354 pub fn set_gain_db(&mut self, gain_db: f32) {
356 self.mixer_config.output_gain = 10.0_f32.powf(gain_db / 20.0);
357 }
358
359 pub fn set_separation_config(&mut self, config: SeparationConfig) {
361 self.separation_config = config;
362 }
363
364 pub async fn shutdown(&mut self) -> Result<()> {
366 if let Some(mut network) = self.bin_network.take() {
367 network.stop().await?;
368 }
369 Ok(())
370 }
371}
372
/// Incremental counterpart of [`AudioFftProcessor::process`]: call `next` repeatedly to
/// obtain output chunks, then `flush` to drain buffered samples. The caller supplies the
/// `BinNetwork` on each call.
pub struct StreamingProcessor {
    /// Audio source; `next` returns `Ok(None)` when this is `None`.
    input: Option<AudioInput>,
    /// Analysis FFT.
    fft: FftProcessor,
    /// Inverse FFT for the direct stream (each stream keeps its own synthesis state).
    ifft_direct: IfftProcessor,
    /// Inverse FFT for the ambience stream.
    ifft_ambience: IfftProcessor,
    /// Inverse FFT for the mixed stream.
    ifft_mixed: IfftProcessor,
    /// Sample rate of `input`, forwarded to the bin network.
    sample_rate: u32,
    /// FFT frame length in samples, forwarded to the bin network.
    fft_size: usize,
    /// Hop size; `next` requests `hop_size * 2` samples per call.
    hop_size: usize,
    /// Per-stream frame mixer.
    mixer: FrameMixer,
    /// Id assigned to the next frame sent to the bin network.
    frame_counter: u64,
}
386
387impl StreamingProcessor {
388 pub fn set_dry_wet(&mut self, dry_wet: f32) {
390 self.mixer.set_dry_wet(dry_wet);
391 }
392
393 pub fn set_gain_db(&mut self, gain_db: f32) {
395 self.mixer.set_gain_db(gain_db);
396 }
397
398 pub async fn next(
401 &mut self,
402 bin_network: &mut BinNetwork,
403 ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>> {
404 let input = match &mut self.input {
405 Some(input) => input,
406 None => return Ok(None),
407 };
408
409 if input.is_exhausted() {
410 return Ok(None);
411 }
412
413 let audio_frame = match input.read_frame(self.hop_size * 2)? {
415 Some(frame) => frame,
416 None => return Ok(None),
417 };
418
419 let samples = if audio_frame.channels > 1 {
421 audio_frame.channel_samples(0)
422 } else {
423 audio_frame.samples.clone()
424 };
425
426 let mut direct_out = Vec::new();
427 let mut ambience_out = Vec::new();
428 let mut mixed_out = Vec::new();
429
430 for fft_frame in self.fft.process_all(&samples) {
432 let frame_id = self.frame_counter;
433 self.frame_counter += 1;
434
435 let separated = bin_network
437 .process_frame(frame_id, &fft_frame, self.sample_rate, self.fft_size)
438 .await?;
439
440 let mixed = self.mixer.process(&separated);
442
443 direct_out.extend(self.ifft_direct.process_frame(&mixed.direct_bins));
445 ambience_out.extend(self.ifft_ambience.process_frame(&mixed.ambience_bins));
446 mixed_out.extend(self.ifft_mixed.process_frame(&mixed.bins));
447 }
448
449 Ok(Some((direct_out, ambience_out, mixed_out)))
450 }
451
452 pub async fn flush(
454 &mut self,
455 bin_network: &mut BinNetwork,
456 ) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>)> {
457 let mut direct_out = Vec::new();
458 let mut ambience_out = Vec::new();
459 let mut mixed_out = Vec::new();
460
461 if let Some(last_frame) = self.fft.flush() {
463 let frame_id = self.frame_counter;
464 self.frame_counter += 1;
465
466 let separated = bin_network
467 .process_frame(frame_id, &last_frame, self.sample_rate, self.fft_size)
468 .await?;
469
470 let mixed = self.mixer.process(&separated);
471
472 direct_out.extend(self.ifft_direct.process_frame(&mixed.direct_bins));
473 ambience_out.extend(self.ifft_ambience.process_frame(&mixed.ambience_bins));
474 mixed_out.extend(self.ifft_mixed.process_frame(&mixed.bins));
475 }
476
477 direct_out.extend(self.ifft_direct.flush());
479 ambience_out.extend(self.ifft_ambience.flush());
480 mixed_out.extend(self.ifft_mixed.flush());
481
482 Ok((direct_out, ambience_out, mixed_out))
483 }
484}
485
486pub async fn process_file(
488 input_path: &str,
489 output_dir: &str,
490 dry_wet: f32,
491 gain_db: f32,
492) -> Result<ProcessingStats> {
493 let input = AudioInput::from_file(input_path)?;
494
495 let mut processor = AudioFftProcessor::builder()
496 .fft_size(2048)
497 .hop_size(512)
498 .mixer_config(
499 MixerConfig::new()
500 .with_dry_wet(dry_wet)
501 .with_output_gain(10.0_f32.powf(gain_db / 20.0)),
502 )
503 .build()
504 .await?;
505
506 let output = processor.process(input).await?;
507
508 let base_name = std::path::Path::new(input_path)
510 .file_stem()
511 .and_then(|s| s.to_str())
512 .unwrap_or("output");
513
514 output
515 .direct
516 .write_to_file(format!("{}/{}_direct.wav", output_dir, base_name))?;
517 output
518 .ambience
519 .write_to_file(format!("{}/{}_ambience.wav", output_dir, base_name))?;
520 output
521 .mixed
522 .write_to_file(format!("{}/{}_mixed.wav", output_dir, base_name))?;
523
524 processor.shutdown().await?;
525
526 Ok(output.stats)
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder's size settings are reflected by the built processor.
    #[tokio::test]
    async fn test_processor_builder() {
        let built = AudioFftProcessor::builder()
            .fft_size(1024)
            .hop_size(256)
            .sample_rate(44100)
            .music_mode()
            .build()
            .await
            .unwrap();

        assert_eq!(built.fft_size(), 1024);
        assert_eq!(built.hop_size(), 256);
        assert_eq!(built.num_bins(), 513);
    }

    /// Half a second of a 440 Hz sine yields non-empty, similarly sized streams.
    #[tokio::test]
    async fn test_processor_with_synthetic_input() {
        let mut engine = AudioFftProcessor::builder()
            .fft_size(512)
            .hop_size(128)
            .sample_rate(44100)
            .build()
            .await
            .unwrap();

        let rate = 44100;
        let seconds = 0.5;
        let sample_count = (rate as f32 * seconds) as usize;

        let mut tone: Vec<f32> = Vec::with_capacity(sample_count);
        for i in 0..sample_count {
            let phase = 2.0 * std::f32::consts::PI * 440.0 * i as f32 / rate as f32;
            tone.push(phase.sin() * 0.5);
        }

        let result = engine
            .process(AudioInput::from_samples(tone, rate, 1))
            .await
            .unwrap();

        for stream in [&result.direct, &result.ambience, &result.mixed] {
            assert!(!stream.samples.is_empty());
        }

        let direct_len = result.direct.samples.len() as i64;
        let ambience_len = result.ambience.samples.len() as i64;
        assert!((direct_len - ambience_len).abs() < 1000);

        engine.shutdown().await.unwrap();
    }
}