// math_audio_dsp/rtpghi.rs

// ============================================================================
// RTPGHI — Real-Time Phase Gradient Heap Integration
// ============================================================================
//
// State-of-the-art causal phase reconstruction from magnitude-only STFT.
// Uses phase gradients estimated from log-magnitude differences and
// integrates phases starting from the highest-magnitude bins (heap-ordered).
//
// Reference: Průša, Z. & Søndergaard, P. (2016). "Real-Time Spectrogram
// Inversion Using Phase Gradient Heap Integration."
//
// Key property: causal (only uses past frames), suitable for streaming.

/// Phase reconstruction processor using RTPGHI.
///
/// Feed one magnitude spectrum per STFT frame, in order. Each call returns one
/// phase per bin that is consistent with the previously processed frames.
///
/// The reconstructed phases target this analysis convention:
/// * forward FFT kernel `e^{-2πi·m·k/M}`, with the frame's time index `k`
///   starting at 0 at the frame start;
/// * an analysis window symmetric about the frame centre (index `M/2`), e.g. a
///   periodic Hann window. Adjacent bins of a steady partial therefore differ
///   by π;
/// * the window is modelled as a Gaussian `exp(-π·t²/γ)` with
///   `γ = 0.25645·M²` (the Hann fit used by LTFAT's `pghi_findgamma`).
pub struct RtpghiProcessor {
    fft_size: usize,
    hop_size: usize,
    /// Time-frequency ratio γ (samples²) of the Gaussian `exp(-π·t²/γ)` that
    /// approximates the analysis window (Hann: γ ≈ 0.25645 · fft_size²).
    gamma: f64,
    /// Previous frame natural-log magnitudes (`-inf` for silent bins).
    prev_log_mag: Vec<f64>,
    /// Previous frame phases, wrapped to [-π, π].
    prev_phase: Vec<f64>,
    /// Previous frame time-direction phase gradient (radians per hop), used for
    /// trapezoidal integration across time.
    prev_d_phase_time: Vec<f64>,
    /// Whether a previous frame exists (false after `new` / `reset`).
    has_prev: bool,
    /// Significance threshold relative to the loudest bin of each frame, as a
    /// natural-log ratio (−60 dB). Quieter bins are not integrated and get zero
    /// phase.
    log_mag_tol: f64,

    // Pre-allocated scratch buffers for process_frame_into (zero-alloc hot path)
    scratch_log_mag: Vec<f64>,
    scratch_phases: Vec<f64>,
    /// `true` for significant bins of the current frame whose phase has not
    /// been assigned yet.
    scratch_pending: Vec<bool>,
    scratch_d_phase_time: Vec<f64>,
    scratch_d_phase_freq: Vec<f64>,
    /// Backing storage of the integration max-heap (at most two entries per bin).
    scratch_heap: Vec<HeapEntry>,
}

/// Max-heap entry for heap integration: one bin of either the previous or the
/// current frame, prioritised by its log-magnitude.
struct HeapEntry {
    /// Log-magnitude used as the heap priority.
    magnitude: f64,
    /// Frequency bin index.
    bin: usize,
    /// `true` if the entry refers to the previous frame (its phase is known and
    /// is propagated forward in time); `false` for the current frame (its phase
    /// is propagated to neighbouring bins).
    from_prev: bool,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Keep equality consistent with `Ord`, as the `Ord` contract requires.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` is a lawful total order, even for NaN. Ties are broken so
        // that lower bins and current-frame entries pop first. The first rule
        // makes the integration order deterministic. The second lets a bin be
        // reached through frequency propagation from an already integrated
        // louder neighbour before its own time step is considered.
        self.magnitude
            .total_cmp(&other.magnitude)
            .then_with(|| other.bin.cmp(&self.bin))
            .then_with(|| other.from_prev.cmp(&self.from_prev))
    }
}

impl RtpghiProcessor {
    /// Gaussian fit of the Hann window: γ = HANN_GAMMA_FACTOR · fft_size².
    const HANN_GAMMA_FACTOR: f64 = 0.25645;

    /// Bins this many dB below the loudest bin of a frame are insignificant.
    const SIGNIFICANCE_DB: f64 = -60.0;

    /// Create a new RTPGHI processor.
    ///
    /// # Arguments
    /// * `fft_size` - FFT size (frame length M; a Hann analysis window is assumed)
    /// * `hop_size` - Hop size in samples
    pub fn new(fft_size: usize, hop_size: usize) -> Self {
        let spectrum_size = fft_size / 2 + 1;
        let m = fft_size as f64;

        Self {
            fft_size,
            hop_size,
            gamma: Self::HANN_GAMMA_FACTOR * m * m,
            prev_log_mag: vec![f64::NEG_INFINITY; spectrum_size],
            prev_phase: vec![0.0; spectrum_size],
            prev_d_phase_time: vec![0.0; spectrum_size],
            has_prev: false,
            // Amplitude dB -> natural-log ratio: ln(10^(dB/20)).
            log_mag_tol: Self::SIGNIFICANCE_DB / 20.0 * std::f64::consts::LN_10,
            scratch_log_mag: vec![0.0; spectrum_size],
            scratch_phases: vec![0.0; spectrum_size],
            scratch_pending: vec![false; spectrum_size],
            scratch_d_phase_time: vec![0.0; spectrum_size],
            scratch_d_phase_freq: vec![0.0; spectrum_size],
            // Each bin is pushed at most twice (once per frame role).
            scratch_heap: Vec::with_capacity(2 * spectrum_size),
        }
    }

    /// Process one STFT frame: given magnitudes, reconstruct phases.
    ///
    /// # Arguments
    /// * `magnitudes` - Magnitude spectrum (spectrum_size = fft_size/2 + 1)
    ///
    /// # Returns
    /// Reconstructed phase values for each bin.
    ///
    /// This is a thin allocation wrapper over [`Self::process_frame_into`].
    /// For real-time / allocation-free streaming, call `process_frame_into` directly.
    pub fn process_frame(&mut self, magnitudes: &[f32]) -> Vec<f32> {
        let n = self.fft_size / 2 + 1;
        let mut out = vec![0.0f32; n];
        self.process_frame_into(magnitudes, &mut out);
        out
    }

    /// Process one STFT frame without allocations: given magnitudes, write
    /// reconstructed phases into the provided output slice.
    ///
    /// This is RTPGHI without look-ahead. Let `a = hop_size`, `M = fft_size`,
    /// `s = ln|X|` and `n` be the frame index.
    /// * Time-direction gradient (radians per hop), from the *frequency*
    ///   derivative of `s`:
    ///   `φ_t[m] = 2π·a·m/M + (a·M/γ)·(s[m+1] − s[m−1])/2`.
    /// * Frequency-direction gradient (radians per bin), from the backward
    ///   *time* difference of `s`:
    ///   `φ_ω[m] = −(γ/(a·M))·(s_n[m] − s_{n−1}[m])`.
    /// * Significant bins are visited loudest-first through a max-heap. A bin
    ///   reached from the previous frame is integrated across time with the
    ///   trapezoidal rule. Every newly assigned bin then propagates its phase
    ///   to its neighbours across frequency, adding π per bin step for the
    ///   frame-centred window.
    ///
    /// The first frame after construction or `reset` has no history. Each
    /// connected run of significant bins is integrated across frequency only,
    /// starting from zero phase at its loudest bin.
    ///
    /// A gradient correction is used only where all log-magnitudes involved
    /// are significant, i.e. within `log_mag_tol` of this frame's loudest bin;
    /// otherwise the correction is zero. Insignificant bins receive zero phase.
    /// Output phases lie in [-π, π].
    ///
    /// # Arguments
    /// * `magnitudes` - Magnitude spectrum (spectrum_size = fft_size/2 + 1)
    /// * `phases_out` - Output slice for reconstructed phases (same length)
    ///
    /// # Panics
    /// If `magnitudes` or `phases_out` length does not equal `fft_size/2 + 1`.
    pub fn process_frame_into(&mut self, magnitudes: &[f32], phases_out: &mut [f32]) {
        use std::f64::consts::{PI, TAU};

        let spectrum_size = self.fft_size / 2 + 1;
        assert_eq!(magnitudes.len(), spectrum_size);
        assert_eq!(phases_out.len(), spectrum_size);

        let has_prev = self.has_prev;
        let log_mag = &mut self.scratch_log_mag;
        let phases = &mut self.scratch_phases;
        let pending = &mut self.scratch_pending;
        let d_phase_time = &mut self.scratch_d_phase_time;
        let d_phase_freq = &mut self.scratch_d_phase_freq;

        // Natural-log magnitudes. Zero, negative and NaN inputs become -inf.
        for (s, &mag) in log_mag.iter_mut().zip(magnitudes) {
            *s = if mag > 0.0 {
                f64::from(mag).ln()
            } else {
                f64::NEG_INFINITY
            };
        }

        // A value is significant if it is finite and within `log_mag_tol` of
        // this frame's loudest bin. A silent frame (all -inf) or a frame
        // containing +inf therefore has no significant bins.
        let max_log = log_mag.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let threshold = max_log + self.log_mag_tol;
        let significant = move |s: f64| s.is_finite() && s > threshold;

        // Scale factors. Degenerate sizes (FFT or hop size 0) fall back to zero
        // instead of producing NaN.
        let a = self.hop_size as f64;
        let m_len = self.fft_size as f64;
        let gamma = self.gamma;
        let bin_advance = if m_len > 0.0 { TAU * a / m_len } else { 0.0 };
        let time_scale = if gamma > 0.0 { a * m_len / gamma } else { 0.0 };
        let freq_scale = if a * m_len > 0.0 {
            -gamma / (a * m_len)
        } else {
            0.0
        };

        // Time-direction gradient φ_t (radians per hop) from the centred
        // frequency difference. At DC and Nyquist the spectrum of a real signal
        // is mirror-symmetric, so that difference is zero there.
        for k in 0..spectrum_size {
            let correction = if k > 0
                && k + 1 < spectrum_size
                && significant(log_mag[k - 1])
                && significant(log_mag[k + 1])
            {
                time_scale * (log_mag[k + 1] - log_mag[k - 1]) / 2.0
            } else {
                0.0
            };
            d_phase_time[k] = bin_advance * k as f64 + correction;
        }

        // Frequency-direction gradient φ_ω (radians per bin) from the backward
        // time difference. No look-ahead frame is used, so this is causal.
        for k in 0..spectrum_size {
            let prev = self.prev_log_mag[k];
            d_phase_freq[k] = if significant(log_mag[k]) && significant(prev) {
                freq_scale * (log_mag[k] - prev)
            } else {
                0.0
            };
        }

        // Seed the heap with one "previous frame" entry per significant bin.
        // Without history, the seed priority is the current magnitude and a
        // popped seed simply starts at zero phase.
        let mut heap = std::collections::BinaryHeap::from(std::mem::take(&mut self.scratch_heap));
        heap.clear();
        for k in 0..spectrum_size {
            phases[k] = 0.0;
            pending[k] = significant(log_mag[k]);
            if pending[k] {
                let magnitude = if has_prev {
                    self.prev_log_mag[k]
                } else {
                    log_mag[k]
                };
                heap.push(HeapEntry {
                    magnitude,
                    bin: k,
                    from_prev: true,
                });
            }
        }

        while let Some(HeapEntry { bin: k, from_prev, .. }) = heap.pop() {
            if from_prev {
                // Integrate across time (trapezoidal rule), unless the bin was
                // already reached across frequency.
                if pending[k] {
                    let value = if has_prev {
                        Self::wrap_phase(
                            self.prev_phase[k]
                                + 0.5 * (self.prev_d_phase_time[k] + d_phase_time[k]),
                        )
                    } else {
                        0.0
                    };
                    phases[k] = value;
                    pending[k] = false;
                    heap.push(HeapEntry {
                        magnitude: log_mag[k],
                        bin: k,
                        from_prev: false,
                    });
                }
                continue;
            }

            // Propagate the (already assigned) phase of bin k across frequency.
            if k + 1 < spectrum_size && pending[k + 1] {
                let j = k + 1;
                let value =
                    Self::wrap_phase(phases[k] + PI + 0.5 * (d_phase_freq[k] + d_phase_freq[j]));
                phases[j] = value;
                pending[j] = false;
                heap.push(HeapEntry {
                    magnitude: log_mag[j],
                    bin: j,
                    from_prev: false,
                });
            }
            if k > 0 && pending[k - 1] {
                let j = k - 1;
                let value =
                    Self::wrap_phase(phases[k] - PI - 0.5 * (d_phase_freq[k] + d_phase_freq[j]));
                phases[j] = value;
                pending[j] = false;
                heap.push(HeapEntry {
                    magnitude: log_mag[j],
                    bin: j,
                    from_prev: false,
                });
            }
        }
        // Hand the (now empty) allocation back for the next frame.
        self.scratch_heap = heap.into_vec();

        // Store for next frame
        self.prev_log_mag.copy_from_slice(log_mag);
        self.prev_phase.copy_from_slice(phases);
        self.prev_d_phase_time.copy_from_slice(d_phase_time);
        self.has_prev = true;

        // Write output
        for (out, &p) in phases_out.iter_mut().zip(phases.iter()) {
            *out = p as f32;
        }
    }

    /// Reset the processor state so the next frame is treated as the first one.
    pub fn reset(&mut self) {
        self.prev_log_mag.fill(f64::NEG_INFINITY);
        self.prev_phase.fill(0.0);
        self.prev_d_phase_time.fill(0.0);
        self.has_prev = false;
    }

    /// Get the latency in samples (one FFT frame). The phase reconstruction
    /// itself uses no look-ahead frames.
    pub fn latency_samples(&self) -> usize {
        self.fft_size
    }

    /// Wrap a phase into [-π, π]. The value π itself can appear through
    /// rounding in `rem_euclid`.
    fn wrap_phase(phase: f64) -> f64 {
        use std::f64::consts::{PI, TAU};
        (phase + PI).rem_euclid(TAU) - PI
    }
}

291/// Convenience: time-stretch a signal using RTPGHI for phase reconstruction.
292///
293/// # Arguments
294/// * `magnitudes_frames` - Sequence of magnitude spectra (one per STFT frame)
295/// * `stretch_factor` - Time stretch factor (2.0 = twice as long)
296/// * `fft_size` - FFT size used for STFT
297/// * `hop_size` - Original hop size
298///
299/// # Returns
300/// Reconstructed phases for stretched output (interpolated magnitude frames)
301pub fn stretch_with_rtpghi(
302    magnitude_frames: &[Vec<f32>],
303    stretch_factor: f64,
304    fft_size: usize,
305    hop_size: usize,
306) -> Vec<Vec<f32>> {
307    if magnitude_frames.is_empty() || stretch_factor <= 0.0 {
308        return Vec::new();
309    }
310
311    let num_input_frames = magnitude_frames.len();
312    let num_output_frames = (num_input_frames as f64 * stretch_factor).ceil() as usize;
313
314    // Interpolate magnitudes
315    let mut stretched_mags = Vec::with_capacity(num_output_frames);
316    for i in 0..num_output_frames {
317        let src_pos = i as f64 / stretch_factor;
318        let src_idx = src_pos.floor() as usize;
319        let frac = (src_pos - src_idx as f64) as f32;
320
321        let frame = if src_idx + 1 < num_input_frames {
322            magnitude_frames[src_idx]
323                .iter()
324                .zip(&magnitude_frames[src_idx + 1])
325                .map(|(&a, &b)| a * (1.0 - frac) + b * frac)
326                .collect()
327        } else if src_idx < num_input_frames {
328            magnitude_frames[src_idx].clone()
329        } else {
330            magnitude_frames.last().unwrap().clone()
331        };
332        stretched_mags.push(frame);
333    }
334
335    // Reconstruct phases
336    let mut processor = RtpghiProcessor::new(fft_size, hop_size);
337    stretched_mags
338        .iter()
339        .map(|mags| processor.process_frame(mags))
340        .collect()
341}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stft::RealFftProcessor;

    /// Helper: compute STFT magnitudes of a signal (periodic Hann window,
    /// frames starting every `hop_size` samples).
    fn compute_stft_magnitudes(signal: &[f32], fft_size: usize, hop_size: usize) -> Vec<Vec<f32>> {
        let spectrum_size = fft_size / 2 + 1;
        let window: Vec<f32> = (0..fft_size)
            .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / fft_size as f32).cos()))
            .collect();

        let mut frames = Vec::new();
        let mut fft = RealFftProcessor::new_forward_only(fft_size);

        let mut pos = 0;
        while pos + fft_size <= signal.len() {
            for i in 0..fft_size {
                fft.time_buffer[i] = signal[pos + i] * window[i];
            }
            fft.forward();

            let mags: Vec<f32> = fft.freq_buffer[..spectrum_size]
                .iter()
                .map(|c| (c.re * c.re + c.im * c.im).sqrt())
                .collect();
            frames.push(mags);
            pos += hop_size;
        }

        frames
    }

    #[test]
    fn test_identity_stretch() {
        let fft_size = 256;
        let hop_size = 64;
        let sample_rate = 48000.0;

        // Generate a pure tone
        let num_samples = 4096;
        let signal: Vec<f32> = (0..num_samples)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * 440.0 * t).sin()
            })
            .collect();

        let mags = compute_stft_magnitudes(&signal, fft_size, hop_size);
        assert!(!mags.is_empty());

        // Identity stretch (factor = 1.0)
        let phases = stretch_with_rtpghi(&mags, 1.0, fft_size, hop_size);
        assert_eq!(phases.len(), mags.len());

        // All phases should be finite
        for frame in &phases {
            for &p in frame {
                assert!(p.is_finite(), "Phase should be finite, got {p}");
            }
        }
    }

    #[test]
    fn test_2x_stretch_doubles_frames() {
        let fft_size = 256;
        let hop_size = 64;

        // Simple magnitude frames
        let spectrum_size = fft_size / 2 + 1;
        let frame: Vec<f32> = (0..spectrum_size)
            .map(|i| (i as f32).exp().recip())
            .collect();
        let mags = vec![frame; 10];

        let stretched = stretch_with_rtpghi(&mags, 2.0, fft_size, hop_size);
        assert_eq!(stretched.len(), 20);
    }

    #[test]
    fn test_no_nan_inf() {
        let fft_size = 512;
        let hop_size = 128;
        let spectrum_size = fft_size / 2 + 1;

        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        // Process several frames with varying magnitudes
        for frame_idx in 0..20 {
            let mags: Vec<f32> = (0..spectrum_size)
                .map(|k| {
                    let freq_factor = 1.0 - k as f32 / spectrum_size as f32;
                    let time_factor = 1.0 + 0.5 * (frame_idx as f32 * 0.3).sin();
                    freq_factor * time_factor
                })
                .collect();

            let phases = processor.process_frame(&mags);
            for (k, &p) in phases.iter().enumerate() {
                assert!(
                    p.is_finite(),
                    "Phase at bin {k}, frame {frame_idx} is not finite: {p}"
                );
            }
        }
    }

    /// `reset` must clear all history: afterwards the processor has to behave
    /// exactly like a freshly constructed one.
    #[test]
    fn test_reset() {
        let fft_size = 256;
        let hop_size = 64;
        let spectrum_size = fft_size / 2 + 1;

        let first: Vec<f32> = (0..spectrum_size).map(|k| 1.0 / (1.0 + k as f32)).collect();
        let second: Vec<f32> = (0..spectrum_size)
            .map(|k| 0.5 + 0.25 * (k as f32 * 0.2).sin())
            .collect();

        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        // Process then reset
        let _ = processor.process_frame(&first);
        let _ = processor.process_frame(&second);
        assert!(processor.has_prev);

        processor.reset();
        assert!(!processor.has_prev);

        let mut fresh = RtpghiProcessor::new(fft_size, hop_size);
        for mags in [&first, &second] {
            assert_eq!(processor.process_frame(mags), fresh.process_frame(mags));
        }
    }

    #[test]
    fn test_empty_stretch() {
        let result = stretch_with_rtpghi(&[], 2.0, 256, 64);
        assert!(result.is_empty());
    }

    #[test]
    fn test_zero_magnitude_bins() {
        let fft_size = 256;
        let hop_size = 64;
        let spectrum_size = fft_size / 2 + 1;

        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        // All-zero magnitudes
        let mags = vec![0.0f32; spectrum_size];
        let _ = processor.process_frame(&mags);
        let phases = processor.process_frame(&mags);

        for &p in &phases {
            assert!(p.is_finite());
        }
    }

    /// Verify that `process_frame_into` produces the same results as `process_frame`.
    #[test]
    fn test_process_frame_into_matches_process_frame() {
        let fft_size = 512;
        let hop_size = 128;
        let spectrum_size = fft_size / 2 + 1;

        let mut proc_alloc = RtpghiProcessor::new(fft_size, hop_size);
        let mut proc_noalloc = RtpghiProcessor::new(fft_size, hop_size);

        for frame_idx in 0..15 {
            let mags: Vec<f32> = (0..spectrum_size)
                .map(|k| {
                    let freq_factor = 1.0 - k as f32 / spectrum_size as f32;
                    let time_factor = 1.0 + 0.5 * (frame_idx as f32 * 0.3).sin();
                    freq_factor * time_factor
                })
                .collect();

            let phases_alloc = proc_alloc.process_frame(&mags);
            let mut phases_noalloc = vec![0.0f32; spectrum_size];
            proc_noalloc.process_frame_into(&mags, &mut phases_noalloc);

            for (k, (&a, &b)) in phases_alloc.iter().zip(phases_noalloc.iter()).enumerate() {
                assert!(
                    (a - b).abs() < 1e-5,
                    "Mismatch at bin {k}, frame {frame_idx}: alloc={a}, noalloc={b}"
                );
            }
        }
    }

    /// Verify that `process_frame_into` produces finite phases and does not panic.
    #[test]
    fn test_process_frame_into_no_nan() {
        let fft_size = 256;
        let hop_size = 64;
        let spectrum_size = fft_size / 2 + 1;

        let mut processor = RtpghiProcessor::new(fft_size, hop_size);
        let mut phases = vec![0.0f32; spectrum_size];

        for frame_idx in 0..10 {
            let mags: Vec<f32> = (0..spectrum_size)
                .map(|k| {
                    let v = 0.5 + 0.5 * ((frame_idx * k) as f32 * 0.1).sin();
                    v.max(0.0)
                })
                .collect();

            processor.process_frame_into(&mags, &mut phases);
            for (k, &p) in phases.iter().enumerate() {
                assert!(
                    p.is_finite(),
                    "Phase at bin {k}, frame {frame_idx} is not finite: {p}"
                );
            }
        }
    }
}