// ringkernel_audio_fft/processor.rs
1//! Main audio FFT processor orchestrating the full pipeline.
2//!
3//! This module provides the high-level API for processing audio through
4//! the GPU-accelerated FFT bin actor network with direct/ambience separation.
5
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9use parking_lot::RwLock;
10use tracing::info;
11
12use crate::audio_input::{AudioInput, AudioOutput, AudioSource};
13use crate::bin_actor::BinNetwork;
14use crate::error::{AudioFftError, Result};
15use crate::fft::{FftProcessor, IfftProcessor, WindowFunction};
16use crate::mixer::{FrameMixer, MixerConfig};
17use crate::separation::SeparationConfig;
18
/// Builder for [`AudioFftProcessor`].
///
/// Collects FFT, windowing, separation, and mixing settings before the
/// (async) construction of the bin actor network in `build()`.
#[derive(Debug, Clone)]
pub struct AudioFftProcessorBuilder {
    // FFT window length in samples; documented as "must be a power of 2".
    fft_size: usize,
    // Step between successive analysis frames (overlap-add hop).
    hop_size: usize,
    // Optional fixed sample rate; `build()` falls back to 44100 when unset.
    sample_rate: Option<u32>,
    // Analysis/synthesis window function.
    window: WindowFunction,
    // Direct/ambience separation settings handed to the bin network.
    separation_config: SeparationConfig,
    // Output mixing settings (dry/wet, gain).
    mixer_config: MixerConfig,
}
29
30impl Default for AudioFftProcessorBuilder {
31    fn default() -> Self {
32        Self {
33            fft_size: 2048,
34            hop_size: 512,
35            sample_rate: None,
36            window: WindowFunction::Hann,
37            separation_config: SeparationConfig::default(),
38            mixer_config: MixerConfig::default(),
39        }
40    }
41}
42
43impl AudioFftProcessorBuilder {
44    /// Create a new builder with default settings.
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Set the FFT size (must be power of 2).
50    pub fn fft_size(mut self, size: usize) -> Self {
51        self.fft_size = size;
52        self
53    }
54
55    /// Set the hop size (for overlap-add).
56    pub fn hop_size(mut self, size: usize) -> Self {
57        self.hop_size = size;
58        self
59    }
60
61    /// Set the sample rate (optional, will use input's rate if not set).
62    pub fn sample_rate(mut self, rate: u32) -> Self {
63        self.sample_rate = Some(rate);
64        self
65    }
66
67    /// Set the window function.
68    pub fn window(mut self, window: WindowFunction) -> Self {
69        self.window = window;
70        self
71    }
72
73    /// Set the separation configuration.
74    pub fn separation_config(mut self, config: SeparationConfig) -> Self {
75        self.separation_config = config;
76        self
77    }
78
79    /// Set the mixer configuration.
80    pub fn mixer_config(mut self, config: MixerConfig) -> Self {
81        self.mixer_config = config;
82        self
83    }
84
85    /// Use music preset for separation.
86    pub fn music_mode(mut self) -> Self {
87        self.separation_config = SeparationConfig::music_preset();
88        self
89    }
90
91    /// Use speech preset for separation.
92    pub fn speech_mode(mut self) -> Self {
93        self.separation_config = SeparationConfig::speech_preset();
94        self
95    }
96
97    /// Build the processor.
98    pub async fn build(self) -> Result<AudioFftProcessor> {
99        let sample_rate = self.sample_rate.unwrap_or(44100);
100
101        info!(
102            "Building AudioFftProcessor: FFT size={}, hop={}, sample_rate={}",
103            self.fft_size, self.hop_size, sample_rate
104        );
105
106        let num_bins = self.fft_size / 2 + 1;
107        let bin_network = BinNetwork::new(num_bins, self.separation_config.clone()).await?;
108
109        Ok(AudioFftProcessor {
110            fft_size: self.fft_size,
111            hop_size: self.hop_size,
112            sample_rate,
113            window: self.window,
114            separation_config: self.separation_config,
115            mixer_config: self.mixer_config,
116            bin_network: Some(bin_network),
117            frame_counter: AtomicU64::new(0),
118            stats: Arc::new(RwLock::new(ProcessingStats::default())),
119        })
120    }
121}
122
/// Processing statistics.
///
/// A snapshot of counters accumulated during a `process()` call; cheap to
/// clone and safe to hand out to callers.
#[derive(Debug, Clone, Default)]
pub struct ProcessingStats {
    /// Total frames processed.
    pub frames_processed: u64,
    /// Total samples processed (counted on the mixed output).
    pub samples_processed: u64,
    /// Total K2K messages exchanged by the bin network.
    pub k2k_messages: u64,
    /// Average processing time per frame (microseconds).
    pub avg_frame_time_us: f64,
    /// Peak direct signal level (as reported by the mixer).
    pub peak_direct: f32,
    /// Peak ambience signal level (as reported by the mixer).
    pub peak_ambience: f32,
}
139
/// Output from processing.
///
/// Bundles the three synthesized streams (direct, ambience, and the
/// dry/wet mix) together with the stats captured at the end of the run.
#[derive(Debug)]
pub struct ProcessingOutput {
    /// Direct signal output.
    pub direct: AudioOutput,
    /// Ambience signal output.
    pub ambience: AudioOutput,
    /// Mixed output (based on dry/wet settings).
    pub mixed: AudioOutput,
    /// Processing statistics.
    pub stats: ProcessingStats,
}
152
153impl ProcessingOutput {
154    /// Create a new processing output.
155    pub fn new(sample_rate: u32, channels: u8) -> Self {
156        Self {
157            direct: AudioOutput::new(sample_rate, channels),
158            ambience: AudioOutput::new(sample_rate, channels),
159            mixed: AudioOutput::new(sample_rate, channels),
160            stats: ProcessingStats::default(),
161        }
162    }
163}
164
/// Main audio FFT processor with GPU bin actors.
pub struct AudioFftProcessor {
    /// FFT size.
    fft_size: usize,
    /// Hop size.
    hop_size: usize,
    /// Sample rate (reserved for resampling support).
    #[allow(dead_code)]
    sample_rate: u32,
    /// Window function.
    window: WindowFunction,
    /// Separation configuration.
    separation_config: SeparationConfig,
    /// Bin actor network; `None` after `shutdown()` has taken it.
    /// Mixer configuration.
    mixer_config: MixerConfig,
    /// Bin actor network.
    bin_network: Option<BinNetwork>,
    /// Frame counter (atomic so `process` can tick it via `&self` access).
    frame_counter: AtomicU64,
    /// Processing statistics, shared so snapshots stay cheap.
    stats: Arc<RwLock<ProcessingStats>>,
}
187
188impl AudioFftProcessor {
189    /// Create a new builder.
190    pub fn builder() -> AudioFftProcessorBuilder {
191        AudioFftProcessorBuilder::new()
192    }
193
194    /// Get the FFT size.
195    pub fn fft_size(&self) -> usize {
196        self.fft_size
197    }
198
199    /// Get the hop size.
200    pub fn hop_size(&self) -> usize {
201        self.hop_size
202    }
203
204    /// Get the number of frequency bins.
205    pub fn num_bins(&self) -> usize {
206        self.fft_size / 2 + 1
207    }
208
209    /// Get processing statistics.
210    pub fn stats(&self) -> ProcessingStats {
211        self.stats.read().clone()
212    }
213
214    /// Process an audio input and return separated outputs.
215    pub async fn process(&mut self, mut input: AudioInput) -> Result<ProcessingOutput> {
216        // Use input's sample rate if we don't have one
217        let sample_rate = input.sample_rate();
218        let channels = input.channels();
219
220        info!(
221            "Processing audio: {} Hz, {} channels",
222            sample_rate, channels
223        );
224
225        let mut output = ProcessingOutput::new(sample_rate, channels);
226        let mut fft_processor =
227            FftProcessor::with_window(self.fft_size, self.hop_size, sample_rate, self.window)?;
228        let mut ifft_processor =
229            IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?;
230
231        let mut frame_mixer = FrameMixer::new(self.mixer_config.clone());
232
233        let bin_network = self
234            .bin_network
235            .as_mut()
236            .ok_or_else(|| AudioFftError::kernel("Bin network not initialized"))?;
237
238        // Process mono for now (extract first channel if stereo)
239        let mut total_frames = 0u64;
240        let start_time = std::time::Instant::now();
241
242        while let Some(audio_frame) = input.read_frame(self.hop_size * 4)? {
243            // Get mono samples
244            let samples = if channels > 1 {
245                audio_frame.channel_samples(0)
246            } else {
247                audio_frame.samples.clone()
248            };
249
250            // Process through FFT
251            for fft_frame in fft_processor.process_all(&samples) {
252                let frame_id = self.frame_counter.fetch_add(1, Ordering::Relaxed);
253
254                // Send to bin network and get separated results
255                let separated = bin_network
256                    .process_frame(frame_id, &fft_frame, sample_rate, self.fft_size)
257                    .await?;
258
259                // Mix the separated bins
260                let mixed = frame_mixer.process(&separated);
261
262                // IFFT back to time domain
263                let direct_samples = ifft_processor.process_frame(&mixed.direct_bins);
264                let ambience_samples = ifft_processor.process_frame(&mixed.ambience_bins);
265                let mixed_samples = ifft_processor.process_frame(&mixed.bins);
266
267                // Append to outputs
268                output.direct.append(&direct_samples);
269                output.ambience.append(&ambience_samples);
270                output.mixed.append(&mixed_samples);
271
272                total_frames += 1;
273            }
274        }
275
276        // Flush remaining samples
277        if let Some(last_frame) = fft_processor.flush() {
278            let frame_id = self.frame_counter.fetch_add(1, Ordering::Relaxed);
279            let separated = bin_network
280                .process_frame(frame_id, &last_frame, sample_rate, self.fft_size)
281                .await?;
282            let mixed = frame_mixer.process(&separated);
283
284            output
285                .direct
286                .append(&ifft_processor.process_frame(&mixed.direct_bins));
287            output
288                .ambience
289                .append(&ifft_processor.process_frame(&mixed.ambience_bins));
290            output
291                .mixed
292                .append(&ifft_processor.process_frame(&mixed.bins));
293        }
294
295        // Flush IFFT
296        output.direct.append(&ifft_processor.flush());
297        output.ambience.append(&ifft_processor.flush());
298        output.mixed.append(&ifft_processor.flush());
299
300        let elapsed = start_time.elapsed();
301        let avg_time = if total_frames > 0 {
302            elapsed.as_micros() as f64 / total_frames as f64
303        } else {
304            0.0
305        };
306
307        // Update stats
308        let k2k_stats = bin_network.k2k_stats();
309        {
310            let mut stats = self.stats.write();
311            stats.frames_processed = total_frames;
312            stats.samples_processed = output.mixed.samples.len() as u64;
313            stats.k2k_messages = k2k_stats.messages_delivered;
314            stats.avg_frame_time_us = avg_time;
315
316            let (direct_peak, amb_peak, _) = frame_mixer.mixer().peak_levels();
317            stats.peak_direct = direct_peak;
318            stats.peak_ambience = amb_peak;
319        }
320
321        output.stats = self.stats();
322
323        info!(
324            "Processed {} frames in {:?} ({:.1} us/frame)",
325            total_frames, elapsed, avg_time
326        );
327
328        Ok(output)
329    }
330
331    /// Process with streaming output (for real-time use).
332    pub fn process_streaming(&mut self, input: AudioInput) -> Result<StreamingProcessor> {
333        let sample_rate = input.sample_rate();
334
335        Ok(StreamingProcessor {
336            input: Some(input),
337            fft: FftProcessor::with_window(self.fft_size, self.hop_size, sample_rate, self.window)?,
338            ifft_direct: IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?,
339            ifft_ambience: IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?,
340            ifft_mixed: IfftProcessor::with_window(self.fft_size, self.hop_size, self.window)?,
341            sample_rate,
342            fft_size: self.fft_size,
343            hop_size: self.hop_size,
344            mixer: FrameMixer::new(self.mixer_config.clone()),
345            frame_counter: 0,
346        })
347    }
348
349    /// Update the dry/wet mix.
350    pub fn set_dry_wet(&mut self, dry_wet: f32) {
351        self.mixer_config.dry_wet = dry_wet.clamp(0.0, 1.0);
352    }
353
354    /// Update the output gain in dB.
355    pub fn set_gain_db(&mut self, gain_db: f32) {
356        self.mixer_config.output_gain = 10.0_f32.powf(gain_db / 20.0);
357    }
358
359    /// Update separation configuration.
360    pub fn set_separation_config(&mut self, config: SeparationConfig) {
361        self.separation_config = config;
362    }
363
364    /// Shutdown the processor and release resources.
365    pub async fn shutdown(&mut self) -> Result<()> {
366        if let Some(mut network) = self.bin_network.take() {
367            network.stop().await?;
368        }
369        Ok(())
370    }
371}
372
/// Streaming processor for frame-by-frame processing.
pub struct StreamingProcessor {
    /// Audio source; `None` once the caller has no input to drain.
    input: Option<AudioInput>,
    /// Analysis FFT (stateful overlap buffering).
    fft: FftProcessor,
    /// Synthesis IFFT for the direct stream (one per stream so each keeps
    /// its own overlap-add state).
    ifft_direct: IfftProcessor,
    /// Synthesis IFFT for the ambience stream.
    ifft_ambience: IfftProcessor,
    /// Synthesis IFFT for the mixed stream.
    ifft_mixed: IfftProcessor,
    /// Sample rate taken from the input at creation time.
    sample_rate: u32,
    /// FFT size copied from the owning processor.
    fft_size: usize,
    /// Hop size copied from the owning processor.
    hop_size: usize,
    /// Per-stream mixer (dry/wet and gain).
    mixer: FrameMixer,
    /// Monotonic frame id handed to the bin network.
    frame_counter: u64,
}
386
387impl StreamingProcessor {
388    /// Set the dry/wet mix.
389    pub fn set_dry_wet(&mut self, dry_wet: f32) {
390        self.mixer.set_dry_wet(dry_wet);
391    }
392
393    /// Set the output gain in dB.
394    pub fn set_gain_db(&mut self, gain_db: f32) {
395        self.mixer.set_gain_db(gain_db);
396    }
397
398    /// Process the next chunk of audio.
399    /// Returns (direct_samples, ambience_samples, mixed_samples) or None if input exhausted.
400    pub async fn next(
401        &mut self,
402        bin_network: &mut BinNetwork,
403    ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>> {
404        let input = match &mut self.input {
405            Some(input) => input,
406            None => return Ok(None),
407        };
408
409        if input.is_exhausted() {
410            return Ok(None);
411        }
412
413        // Read audio frame
414        let audio_frame = match input.read_frame(self.hop_size * 2)? {
415            Some(frame) => frame,
416            None => return Ok(None),
417        };
418
419        // Get mono samples
420        let samples = if audio_frame.channels > 1 {
421            audio_frame.channel_samples(0)
422        } else {
423            audio_frame.samples.clone()
424        };
425
426        let mut direct_out = Vec::new();
427        let mut ambience_out = Vec::new();
428        let mut mixed_out = Vec::new();
429
430        // Process through FFT
431        for fft_frame in self.fft.process_all(&samples) {
432            let frame_id = self.frame_counter;
433            self.frame_counter += 1;
434
435            // Process through bin network
436            let separated = bin_network
437                .process_frame(frame_id, &fft_frame, self.sample_rate, self.fft_size)
438                .await?;
439
440            // Mix
441            let mixed = self.mixer.process(&separated);
442
443            // IFFT
444            direct_out.extend(self.ifft_direct.process_frame(&mixed.direct_bins));
445            ambience_out.extend(self.ifft_ambience.process_frame(&mixed.ambience_bins));
446            mixed_out.extend(self.ifft_mixed.process_frame(&mixed.bins));
447        }
448
449        Ok(Some((direct_out, ambience_out, mixed_out)))
450    }
451
452    /// Flush remaining samples.
453    pub async fn flush(
454        &mut self,
455        bin_network: &mut BinNetwork,
456    ) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>)> {
457        let mut direct_out = Vec::new();
458        let mut ambience_out = Vec::new();
459        let mut mixed_out = Vec::new();
460
461        // Flush FFT
462        if let Some(last_frame) = self.fft.flush() {
463            let frame_id = self.frame_counter;
464            self.frame_counter += 1;
465
466            let separated = bin_network
467                .process_frame(frame_id, &last_frame, self.sample_rate, self.fft_size)
468                .await?;
469
470            let mixed = self.mixer.process(&separated);
471
472            direct_out.extend(self.ifft_direct.process_frame(&mixed.direct_bins));
473            ambience_out.extend(self.ifft_ambience.process_frame(&mixed.ambience_bins));
474            mixed_out.extend(self.ifft_mixed.process_frame(&mixed.bins));
475        }
476
477        // Flush IFFTs
478        direct_out.extend(self.ifft_direct.flush());
479        ambience_out.extend(self.ifft_ambience.flush());
480        mixed_out.extend(self.ifft_mixed.flush());
481
482        Ok((direct_out, ambience_out, mixed_out))
483    }
484}
485
486/// Simplified API for quick processing.
487pub async fn process_file(
488    input_path: &str,
489    output_dir: &str,
490    dry_wet: f32,
491    gain_db: f32,
492) -> Result<ProcessingStats> {
493    let input = AudioInput::from_file(input_path)?;
494
495    let mut processor = AudioFftProcessor::builder()
496        .fft_size(2048)
497        .hop_size(512)
498        .mixer_config(
499            MixerConfig::new()
500                .with_dry_wet(dry_wet)
501                .with_output_gain(10.0_f32.powf(gain_db / 20.0)),
502        )
503        .build()
504        .await?;
505
506    let output = processor.process(input).await?;
507
508    // Write outputs
509    let base_name = std::path::Path::new(input_path)
510        .file_stem()
511        .and_then(|s| s.to_str())
512        .unwrap_or("output");
513
514    output
515        .direct
516        .write_to_file(format!("{}/{}_direct.wav", output_dir, base_name))?;
517    output
518        .ambience
519        .write_to_file(format!("{}/{}_ambience.wav", output_dir, base_name))?;
520    output
521        .mixed
522        .write_to_file(format!("{}/{}_mixed.wav", output_dir, base_name))?;
523
524    processor.shutdown().await?;
525
526    Ok(output.stats)
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_processor_builder() {
        let processor = AudioFftProcessor::builder()
            .fft_size(1024)
            .hop_size(256)
            .sample_rate(44100)
            .music_mode()
            .build()
            .await
            .unwrap();

        // A 1024-point real FFT yields 1024 / 2 + 1 = 513 bins.
        assert_eq!(processor.fft_size(), 1024);
        assert_eq!(processor.hop_size(), 256);
        assert_eq!(processor.num_bins(), 513);
    }

    #[tokio::test]
    async fn test_processor_with_synthetic_input() {
        let mut processor = AudioFftProcessor::builder()
            .fft_size(512)
            .hop_size(128)
            .sample_rate(44100)
            .build()
            .await
            .unwrap();

        // Half a second of a 440 Hz sine at half amplitude.
        let sample_rate = 44100;
        let duration = 0.5;
        let total = (sample_rate as f32 * duration) as usize;
        let mut signal = Vec::with_capacity(total);
        for i in 0..total {
            let sample =
                (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin() * 0.5;
            signal.push(sample);
        }

        let input = AudioInput::from_samples(signal, sample_rate, 1);
        let output = processor.process(input).await.unwrap();

        // Every stream must produce audio...
        assert!(!output.direct.samples.is_empty());
        assert!(!output.ambience.samples.is_empty());
        assert!(!output.mixed.samples.is_empty());

        // ...and the direct/ambience streams should have similar lengths.
        let len_diff =
            (output.direct.samples.len() as i64 - output.ambience.samples.len() as i64).abs();
        assert!(len_diff < 1000);

        processor.shutdown().await.unwrap();
    }
}