Skip to main content

math_audio_dsp/
rtpghi.rs

1// ============================================================================
2// RTPGHI — Real-Time Phase Gradient Heap Integration
3// ============================================================================
4//
5// State-of-the-art causal phase reconstruction from magnitude-only STFT.
6// Uses phase gradients estimated from log-magnitude differences and
7// integrates phases starting from the highest-magnitude bins (heap-ordered).
8//
9// Reference: Průša, Z. & Søndergaard, P. (2016). "Real-Time Spectrogram
10// Inversion Using Phase Gradient Heap Integration."
11//
12// Key property: causal (only uses past frames), suitable for streaming.
13
14use std::collections::BinaryHeap;
15
/// Phase reconstruction processor using RTPGHI.
///
/// Feed consecutive STFT magnitude frames, in order, to [`RtpghiProcessor::process_frame`]
/// or [`RtpghiProcessor::process_frame_into`]. Each frame is reconstructed with the heap
/// integration of Průša & Søndergaard (2016):
///
/// - phase derivatives are estimated from log-magnitude differences (differences across
///   frequency give the time derivative, differences across time give the frequency
///   derivative);
/// - they are integrated outwards from the loudest bins, either in time from the previous
///   frame or in frequency from an already reconstructed neighbour.
///
/// Assumed analysis convention (the one used by this module's tests): frame `n` starts at
/// sample `n * hop_size` and is multiplied by a Hann window of length `fft_size` without
/// circular shift, so the window centre sits at `fft_size / 2`.
pub struct RtpghiProcessor {
    fft_size: usize,
    hop_size: usize,
    /// Time-frequency spread `γ` (samples²) of the Gaussian approximating the analysis
    /// window; for a Hann window `γ ≈ 0.25645 · fft_size²`.
    gamma: f64,
    /// Floored natural-log magnitudes of the previous frame (all `-inf` if it was silent).
    prev_log_mag: Vec<f64>,
    /// Reconstructed phases of the previous frame, wrapped to `[-π, π]`.
    prev_phase: Vec<f64>,
    /// Time-direction phase derivative (radians per hop) of the previous frame.
    prev_d_phase_time: Vec<f64>,
    /// Whether a previous frame is available.
    has_prev: bool,
    /// Relative tolerance in natural-log units (negative).
    ///
    /// Bins more than this far below the loudest bin of their frame are treated as
    /// negligible and receive zero phase.
    log_mag_tol: f64,

    // Pre-allocated scratch buffers (zero-alloc hot path).
    scratch_log_mag: Vec<f64>,
    scratch_phases: Vec<f64>,
    /// `true` for significant bins of the current frame that still need a phase.
    scratch_pending: Vec<bool>,
    scratch_d_phase_time: Vec<f64>,
    scratch_d_phase_freq: Vec<f64>,
    scratch_heap: Vec<HeapEntry>,
}

/// Entry of the integration max-heap.
#[derive(Clone, Copy, Debug)]
struct HeapEntry {
    /// Floored log-magnitude of the bin (the heap key).
    magnitude: f64,
    /// Frequency bin index.
    bin: usize,
    /// `true` if the bin belongs to the frame being reconstructed, `false` if it belongs
    /// to the previous frame.
    current: bool,
}

// Equality is defined through `cmp` so that `Eq` and `Ord` always agree.
impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    /// Total order: louder entries are greater.
    ///
    /// Ties prefer the lower bin, then the previous frame, so pop order is deterministic.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.magnitude
            .total_cmp(&other.magnitude)
            .then_with(|| other.bin.cmp(&self.bin))
            .then_with(|| other.current.cmp(&self.current))
    }
}

/// Wrap a phase in radians into `[-π, π]`.
fn wrap_phase(phase: f64) -> f64 {
    use std::f64::consts::{PI, TAU};
    (phase + PI).rem_euclid(TAU) - PI
}

impl RtpghiProcessor {
    /// `γ / fft_size²` for a Hann window (LTFAT `pghi_findgamma`).
    const HANN_GAMMA_FACTOR: f64 = 0.25645;

    /// Create a new RTPGHI processor.
    ///
    /// # Arguments
    /// * `fft_size` - FFT size, which is also the Hann window length (power of 2 expected
    ///   by the STFT; not checked here)
    /// * `hop_size` - Hop size in samples
    pub fn new(fft_size: usize, hop_size: usize) -> Self {
        let spectrum_size = fft_size / 2 + 1;
        let frame = fft_size as f64;

        Self {
            fft_size,
            hop_size,
            gamma: Self::HANN_GAMMA_FACTOR * frame * frame,
            prev_log_mag: vec![f64::NEG_INFINITY; spectrum_size],
            prev_phase: vec![0.0; spectrum_size],
            prev_d_phase_time: vec![0.0; spectrum_size],
            has_prev: false,
            // -60 dB relative to the frame peak, expressed in natural-log units.
            log_mag_tol: -60.0 / 20.0 * std::f64::consts::LN_10,
            scratch_log_mag: vec![0.0; spectrum_size],
            scratch_phases: vec![0.0; spectrum_size],
            scratch_pending: vec![false; spectrum_size],
            scratch_d_phase_time: vec![0.0; spectrum_size],
            scratch_d_phase_freq: vec![0.0; spectrum_size],
            // Each bin is queued at most twice per frame (previous + current frame).
            scratch_heap: Vec::with_capacity(2 * spectrum_size),
        }
    }

    /// Process one STFT frame: given magnitudes, reconstruct phases.
    ///
    /// # Arguments
    /// * `magnitudes` - Magnitude spectrum (spectrum_size = fft_size/2 + 1)
    ///
    /// # Returns
    /// Reconstructed phase for each bin, wrapped to `[-π, π]`. Bins more than 60 dB below
    /// the frame's loudest bin (and all bins of a silent frame) get zero phase.
    ///
    /// # Panics
    /// If `magnitudes.len() != fft_size/2 + 1`.
    pub fn process_frame(&mut self, magnitudes: &[f32]) -> Vec<f32> {
        let spectrum_size = self.fft_size / 2 + 1;
        assert_eq!(
            magnitudes.len(),
            spectrum_size,
            "Expected {} magnitudes, got {}",
            spectrum_size,
            magnitudes.len()
        );

        // Single implementation: the allocating API only adds the output buffer.
        let mut phases = vec![0.0f32; spectrum_size];
        self.process_frame_into(magnitudes, &mut phases);
        phases
    }

    /// Process one STFT frame without allocations: given magnitudes, write
    /// reconstructed phases into the provided output slice.
    ///
    /// Produces exactly the same values as [`RtpghiProcessor::process_frame`].
    ///
    /// # Arguments
    /// * `magnitudes` - Magnitude spectrum (spectrum_size = fft_size/2 + 1)
    /// * `phases_out` - Output slice for reconstructed phases (same length)
    ///
    /// # Panics
    /// If `magnitudes` or `phases_out` length does not equal `fft_size/2 + 1`.
    pub fn process_frame_into(&mut self, magnitudes: &[f32], phases_out: &mut [f32]) {
        let spectrum_size = self.fft_size / 2 + 1;
        assert_eq!(
            magnitudes.len(),
            spectrum_size,
            "Expected {} magnitudes, got {}",
            spectrum_size,
            magnitudes.len()
        );
        assert_eq!(
            phases_out.len(),
            spectrum_size,
            "Expected {} output phases, got {}",
            spectrum_size,
            phases_out.len()
        );

        let significant = self.prepare_log_magnitudes(magnitudes);
        // A silent previous frame carries no phase information: treat it as absent.
        let use_prev = self.has_prev && self.prev_log_mag[0].is_finite();
        self.compute_gradients(use_prev);
        self.integrate(significant, use_prev);

        for (out, &phase) in phases_out.iter_mut().zip(self.scratch_phases.iter()) {
            *out = phase as f32;
        }

        // The current frame becomes the previous one; swapping keeps this allocation-free.
        std::mem::swap(&mut self.prev_log_mag, &mut self.scratch_log_mag);
        std::mem::swap(&mut self.prev_phase, &mut self.scratch_phases);
        std::mem::swap(&mut self.prev_d_phase_time, &mut self.scratch_d_phase_time);
        self.has_prev = true;
    }

    /// Prepare the current frame's log-magnitudes.
    ///
    /// Fills `scratch_log_mag` with natural-log magnitudes floored at
    /// `peak + log_mag_tol`, and marks the bins above that floor in `scratch_pending`.
    /// Returns the number of marked (significant) bins.
    fn prepare_log_magnitudes(&mut self, magnitudes: &[f32]) -> usize {
        let mut peak = f64::NEG_INFINITY;
        for (log_mag, &mag) in self.scratch_log_mag.iter_mut().zip(magnitudes) {
            // Zero, negative, NaN and infinite magnitudes carry no usable level information.
            *log_mag = if mag > 0.0 && mag.is_finite() {
                f64::from(mag).ln()
            } else {
                f64::NEG_INFINITY
            };
            peak = peak.max(*log_mag);
        }

        // For a silent frame the floor is -inf: nothing is significant and every entry stays
        // -inf. Otherwise every floored entry is finite, so differences below never see NaN.
        let floor = peak + self.log_mag_tol;
        let mut significant = 0;
        for (log_mag, pending) in self
            .scratch_log_mag
            .iter_mut()
            .zip(self.scratch_pending.iter_mut())
        {
            *pending = *log_mag > floor;
            if *pending {
                significant += 1;
            }
            *log_mag = (*log_mag).max(floor);
        }
        significant
    }

    /// Compute the current frame's phase derivatives (PGHI, Gaussian-window approximation).
    ///
    /// * Time direction (radians per hop): `2π·a·k/M + (a·M/γ)·∂ₖ log s`, using a centred
    ///   difference across bins.
    /// * Frequency direction (radians per bin, excluding the π step added during
    ///   integration): `-(γ/(a·M))·∂ₙ log s`, using a causal backward difference against
    ///   the previous frame (zero when there is no usable previous frame).
    fn compute_gradients(&mut self, use_prev: bool) {
        let spectrum_size = self.scratch_log_mag.len();
        let hop = self.hop_size as f64;
        let frame = self.fft_size as f64;
        let hop_frame = hop * frame;
        let tgrad_scale = if self.gamma > 0.0 { hop_frame / self.gamma } else { 0.0 };
        let fgrad_scale = if hop_frame > 0.0 { self.gamma / hop_frame } else { 0.0 };
        let bin_advance = if frame > 0.0 {
            std::f64::consts::TAU * hop / frame
        } else {
            0.0
        };

        let log_mag = &self.scratch_log_mag;
        let silent = !log_mag[0].is_finite();
        for k in 0..spectrum_size {
            // The magnitude spectrum of a real signal is symmetric about DC and Nyquist, so
            // the centred difference is taken as zero at the band edges.
            let slope = if silent || k == 0 || k + 1 == spectrum_size {
                0.0
            } else {
                0.5 * (log_mag[k + 1] - log_mag[k - 1])
            };
            self.scratch_d_phase_time[k] = bin_advance * k as f64 + tgrad_scale * slope;
            self.scratch_d_phase_freq[k] = if use_prev && !silent {
                -fgrad_scale * (log_mag[k] - self.prev_log_mag[k])
            } else {
                0.0
            };
        }
    }

    /// Heap integration of the current frame (RTPGHI, Průša & Søndergaard 2016).
    ///
    /// Entries are popped loudest first:
    /// * A previous-frame entry gives its bin a time step (trapezoidal rule over the two
    ///   frames' time derivatives).
    /// * A current-frame entry gives its still-pending neighbours a frequency step
    ///   (trapezoidal rule plus the π per-bin step of a window centred at `fft_size / 2`).
    ///
    /// When nothing is left to propagate from (first frame, or after silence), the loudest
    /// pending bin is seeded with zero phase. Writes `scratch_phases`; bins that are never
    /// integrated keep zero phase.
    fn integrate(&mut self, mut remaining: usize, use_prev: bool) {
        use std::f64::consts::PI;

        let mut heap = BinaryHeap::from(std::mem::take(&mut self.scratch_heap));
        if use_prev {
            // Every significant bin is reachable in time from the previous frame; those
            // entries are keyed by the previous frame's magnitude.
            for (bin, &pending) in self.scratch_pending.iter().enumerate() {
                if pending {
                    heap.push(HeapEntry {
                        magnitude: self.prev_log_mag[bin],
                        bin,
                        current: false,
                    });
                }
            }
        }

        let log_mag = &self.scratch_log_mag;
        let d_time = &self.scratch_d_phase_time;
        let d_freq = &self.scratch_d_phase_freq;
        let prev_phase = &self.prev_phase;
        let prev_d_time = &self.prev_d_phase_time;
        let pending = &mut self.scratch_pending;
        let phases = &mut self.scratch_phases;
        let spectrum_size = phases.len();
        phases.fill(0.0);

        while remaining > 0 {
            let entry = match heap.pop() {
                Some(entry) => entry,
                None => {
                    let seed = (0..spectrum_size)
                        .filter(|&k| pending[k])
                        .max_by(|&x, &y| log_mag[x].total_cmp(&log_mag[y]).then(y.cmp(&x)))
                        .expect("`remaining > 0` implies at least one pending bin");
                    phases[seed] = 0.0;
                    pending[seed] = false;
                    remaining -= 1;
                    heap.push(HeapEntry {
                        magnitude: log_mag[seed],
                        bin: seed,
                        current: true,
                    });
                    continue;
                }
            };

            let k = entry.bin;
            if !entry.current {
                // Time step from the previous frame.
                if pending[k] {
                    phases[k] = wrap_phase(prev_phase[k] + 0.5 * (prev_d_time[k] + d_time[k]));
                    pending[k] = false;
                    remaining -= 1;
                    heap.push(HeapEntry {
                        magnitude: log_mag[k],
                        bin: k,
                        current: true,
                    });
                }
                continue;
            }

            // Frequency steps to the neighbours of an already integrated bin.
            if k + 1 < spectrum_size && pending[k + 1] {
                phases[k + 1] = wrap_phase(phases[k] + 0.5 * (d_freq[k] + d_freq[k + 1]) + PI);
                pending[k + 1] = false;
                remaining -= 1;
                heap.push(HeapEntry {
                    magnitude: log_mag[k + 1],
                    bin: k + 1,
                    current: true,
                });
            }
            if k > 0 && pending[k - 1] {
                phases[k - 1] = wrap_phase(phases[k] - 0.5 * (d_freq[k - 1] + d_freq[k]) - PI);
                pending[k - 1] = false;
                remaining -= 1;
                heap.push(HeapEntry {
                    magnitude: log_mag[k - 1],
                    bin: k - 1,
                    current: true,
                });
            }
        }

        // Hand the (emptied) allocation back for the next frame.
        let mut storage = heap.into_vec();
        storage.clear();
        self.scratch_heap = storage;
    }

    /// Reset the processor state; the next frame is treated as the first one.
    pub fn reset(&mut self) {
        self.prev_log_mag.fill(f64::NEG_INFINITY);
        self.prev_phase.fill(0.0);
        self.prev_d_phase_time.fill(0.0);
        self.has_prev = false;
    }

    /// Get the latency in samples (reported as one FFT frame).
    pub fn latency_samples(&self) -> usize {
        self.fft_size
    }
}
441
442/// Convenience: time-stretch a signal using RTPGHI for phase reconstruction.
443///
444/// # Arguments
445/// * `magnitudes_frames` - Sequence of magnitude spectra (one per STFT frame)
446/// * `stretch_factor` - Time stretch factor (2.0 = twice as long)
447/// * `fft_size` - FFT size used for STFT
448/// * `hop_size` - Original hop size
449///
450/// # Returns
451/// Reconstructed phases for stretched output (interpolated magnitude frames)
452pub fn stretch_with_rtpghi(
453    magnitude_frames: &[Vec<f32>],
454    stretch_factor: f64,
455    fft_size: usize,
456    hop_size: usize,
457) -> Vec<Vec<f32>> {
458    if magnitude_frames.is_empty() || stretch_factor <= 0.0 {
459        return Vec::new();
460    }
461
462    let num_input_frames = magnitude_frames.len();
463    let num_output_frames = (num_input_frames as f64 * stretch_factor).ceil() as usize;
464
465    // Interpolate magnitudes
466    let mut stretched_mags = Vec::with_capacity(num_output_frames);
467    for i in 0..num_output_frames {
468        let src_pos = i as f64 / stretch_factor;
469        let src_idx = src_pos.floor() as usize;
470        let frac = (src_pos - src_idx as f64) as f32;
471
472        let frame = if src_idx + 1 < num_input_frames {
473            magnitude_frames[src_idx]
474                .iter()
475                .zip(&magnitude_frames[src_idx + 1])
476                .map(|(&a, &b)| a * (1.0 - frac) + b * frac)
477                .collect()
478        } else if src_idx < num_input_frames {
479            magnitude_frames[src_idx].clone()
480        } else {
481            magnitude_frames.last().unwrap().clone()
482        };
483        stretched_mags.push(frame);
484    }
485
486    // Reconstruct phases
487    let mut processor = RtpghiProcessor::new(fft_size, hop_size);
488    stretched_mags
489        .iter()
490        .map(|mags| processor.process_frame(mags))
491        .collect()
492}
493
494// ============================================================================
495// Tests
496// ============================================================================
497
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stft::RealFftProcessor;

    /// Magnitude spectra of `signal`, one per Hann-windowed frame taken every `hop_size`
    /// samples; trailing samples that do not fill a whole frame are ignored.
    fn compute_stft_magnitudes(signal: &[f32], fft_size: usize, hop_size: usize) -> Vec<Vec<f32>> {
        let spectrum_size = fft_size / 2 + 1;
        let two_pi = 2.0 * std::f32::consts::PI;
        let window: Vec<f32> = (0..fft_size)
            .map(|i| 0.5 * (1.0 - (two_pi * i as f32 / fft_size as f32).cos()))
            .collect();

        let mut fft = RealFftProcessor::new_forward_only(fft_size);
        let mut frames = Vec::new();
        let mut start = 0;
        while start + fft_size <= signal.len() {
            let segment = &signal[start..start + fft_size];
            for (i, (&sample, &gain)) in segment.iter().zip(window.iter()).enumerate() {
                fft.time_buffer[i] = sample * gain;
            }
            fft.forward();

            let magnitudes: Vec<f32> = fft.freq_buffer[..spectrum_size]
                .iter()
                .map(|bin| (bin.re * bin.re + bin.im * bin.im).sqrt())
                .collect();
            frames.push(magnitudes);
            start += hop_size;
        }
        frames
    }

    /// Spectrum that rolls off linearly with frequency and whose level oscillates with
    /// the frame index (shared by several tests).
    fn ramp_spectrum(spectrum_size: usize, frame_idx: usize) -> Vec<f32> {
        let time_factor = 1.0 + 0.5 * (frame_idx as f32 * 0.3).sin();
        (0..spectrum_size)
            .map(|k| (1.0 - k as f32 / spectrum_size as f32) * time_factor)
            .collect()
    }

    #[test]
    fn test_identity_stretch() {
        let (fft_size, hop_size) = (256, 64);
        let sample_rate = 48000.0;

        // A 440 Hz sine, 4096 samples long.
        let signal: Vec<f32> = (0..4096)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * (i as f32 / sample_rate)).sin())
            .collect();

        let mags = compute_stft_magnitudes(&signal, fft_size, hop_size);
        assert!(!mags.is_empty());

        // A factor of 1.0 must keep the frame count unchanged.
        let phases = stretch_with_rtpghi(&mags, 1.0, fft_size, hop_size);
        assert_eq!(phases.len(), mags.len());

        for p in phases.iter().flatten() {
            assert!(p.is_finite(), "Phase should be finite, got {p}");
        }
    }

    #[test]
    fn test_2x_stretch_doubles_frames() {
        let (fft_size, hop_size) = (256, 64);
        let spectrum_size = fft_size / 2 + 1;

        let frame: Vec<f32> = (0..spectrum_size)
            .map(|i| (i as f32).exp().recip())
            .collect();
        let mags = vec![frame; 10];

        assert_eq!(stretch_with_rtpghi(&mags, 2.0, fft_size, hop_size).len(), 20);
    }

    #[test]
    fn test_no_nan_inf() {
        let (fft_size, hop_size) = (512, 128);
        let spectrum_size = fft_size / 2 + 1;
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        for frame_idx in 0..20 {
            let phases = processor.process_frame(&ramp_spectrum(spectrum_size, frame_idx));
            for (k, p) in phases.into_iter().enumerate() {
                assert!(
                    p.is_finite(),
                    "Phase at bin {k}, frame {frame_idx} is not finite: {p}"
                );
            }
        }
    }

    #[test]
    fn test_reset() {
        let (fft_size, hop_size) = (256, 64);
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        processor.process_frame(&vec![0.5; fft_size / 2 + 1]);
        assert!(processor.has_prev);

        processor.reset();
        assert!(!processor.has_prev);
    }

    #[test]
    fn test_empty_stretch() {
        assert!(stretch_with_rtpghi(&[], 2.0, 256, 64).is_empty());
    }

    #[test]
    fn test_zero_magnitude_bins() {
        let (fft_size, hop_size) = (256, 64);
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);
        let silence = vec![0.0f32; fft_size / 2 + 1];

        processor.process_frame(&silence);
        for p in processor.process_frame(&silence) {
            assert!(p.is_finite());
        }
    }

    /// Both entry points must agree frame by frame.
    #[test]
    fn test_process_frame_into_matches_process_frame() {
        let (fft_size, hop_size) = (512, 128);
        let spectrum_size = fft_size / 2 + 1;

        let mut reference = RtpghiProcessor::new(fft_size, hop_size);
        let mut streaming = RtpghiProcessor::new(fft_size, hop_size);
        let mut streamed = vec![0.0f32; spectrum_size];

        for frame_idx in 0..15 {
            let mags = ramp_spectrum(spectrum_size, frame_idx);
            let expected = reference.process_frame(&mags);
            streaming.process_frame_into(&mags, &mut streamed);

            for (k, (&a, &b)) in expected.iter().zip(streamed.iter()).enumerate() {
                assert!(
                    (a - b).abs() < 1e-5,
                    "Mismatch at bin {k}, frame {frame_idx}: alloc={a}, noalloc={b}"
                );
            }
        }
    }

    /// The zero-alloc entry point yields finite phases and does not panic.
    #[test]
    fn test_process_frame_into_no_nan() {
        let (fft_size, hop_size) = (256, 64);
        let spectrum_size = fft_size / 2 + 1;
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);
        let mut phases = vec![0.0f32; spectrum_size];

        for frame_idx in 0..10usize {
            let mags: Vec<f32> = (0..spectrum_size)
                .map(|k| (0.5 + 0.5 * ((frame_idx * k) as f32 * 0.1).sin()).max(0.0))
                .collect();

            processor.process_frame_into(&mags, &mut phases);
            for (k, &p) in phases.iter().enumerate() {
                assert!(
                    p.is_finite(),
                    "Phase at bin {k}, frame {frame_idx} is not finite: {p}"
                );
            }
        }
    }
}