/// Streaming phase estimator for one-sided STFT magnitude frames.
///
/// Frames are fed one at a time; the processor carries the previous frame's
/// log magnitudes and phases between calls and reuses its scratch buffers
/// from call to call.
pub struct RtpghiProcessor {
    // ---- configuration ----
    /// Transform length; every frame carries `fft_size / 2 + 1` bins.
    fft_size: usize,
    /// Distance in samples between consecutive frames.
    hop_size: usize,
    /// Window time-frequency constant that converts log-magnitude
    /// differences into phase steps.
    gamma: f64,
    /// Log-magnitude threshold; bins at or below it are treated as
    /// insignificant.
    log_mag_tol: f64,

    // ---- state carried over from the previous frame ----
    /// Whether `prev_log_mag` / `prev_phase` describe a real previous frame.
    has_prev: bool,
    /// Log magnitudes of the previous frame.
    prev_log_mag: Vec<f64>,
    /// Phases emitted for the previous frame.
    prev_phase: Vec<f64>,

    // ---- per-frame scratch space ----
    scratch_log_mag: Vec<f64>,
    scratch_phases: Vec<f64>,
    scratch_integrated: Vec<bool>,
    scratch_d_phase_time: Vec<f64>,
    scratch_d_phase_freq: Vec<f64>,
    scratch_heap: Vec<HeapEntry>,
}
/// Priority entry used during phase integration: a spectral coefficient
/// identified by `bin`, prioritised by its log magnitude.
///
/// The ordering is total and consistent with equality:
/// * magnitudes are compared with [`f64::total_cmp`], so a NaN can never
///   break sorting or heap invariants;
/// * ties are broken by `bin`, so entries that compare `Equal` are also `==`
///   (the previous derived `PartialEq` compared `bin` while `cmp` ignored it,
///   violating the `Ord` contract).
#[derive(Debug, Clone, Copy)]
struct HeapEntry {
    magnitude: f64,
    bin: usize,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` and `Ord` can never disagree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.magnitude
            .total_cmp(&other.magnitude)
            .then_with(|| self.bin.cmp(&other.bin))
    }
}
impl RtpghiProcessor {
    /// Creates a processor for `fft_size`-point frames (`fft_size / 2 + 1`
    /// magnitude bins) produced with a hop of `hop_size` samples.
    ///
    /// Uses the window constant `gamma = 0.17 * fft_size²` (samples²).
    /// NOTE(review): the PGHI literature gives 0.25645·L² for a Hann window
    /// and 0.17954·L² for a Blackman window of length L; confirm 0.17 matches
    /// the analysis window callers use, or build with
    /// [`RtpghiProcessor::with_gamma`].
    pub fn new(fft_size: usize, hop_size: usize) -> Self {
        let len = fft_size as f64;
        Self::with_gamma(fft_size, hop_size, 0.17 * len * len)
    }

    /// Like [`RtpghiProcessor::new`], but with an explicit window
    /// time-frequency constant `gamma` (samples², i.e. the `λ` of the
    /// equivalent Gaussian `exp(-π t² / λ)`).
    ///
    /// A non-finite or non-positive `gamma` disables the log-magnitude
    /// gradient terms; only the per-bin nominal phase advance remains.
    pub fn with_gamma(fft_size: usize, hop_size: usize, gamma: f64) -> Self {
        let spectrum_size = fft_size / 2 + 1;
        Self {
            fft_size,
            hop_size,
            gamma,
            prev_log_mag: vec![f64::NEG_INFINITY; spectrum_size],
            prev_phase: vec![0.0; spectrum_size],
            has_prev: false,
            log_mag_tol: -60.0,
            scratch_log_mag: vec![0.0; spectrum_size],
            scratch_phases: vec![0.0; spectrum_size],
            scratch_integrated: vec![false; spectrum_size],
            scratch_d_phase_time: vec![0.0; spectrum_size],
            scratch_d_phase_freq: vec![0.0; spectrum_size],
            // The heap holds at most one previous-frame and one current-frame
            // entry per bin, so this capacity avoids reallocating per frame.
            scratch_heap: Vec::with_capacity(2 * spectrum_size),
        }
    }

    /// Allocating convenience wrapper around
    /// [`RtpghiProcessor::process_frame_into`].
    ///
    /// # Panics
    /// Panics if `magnitudes.len() != fft_size / 2 + 1`.
    pub fn process_frame(&mut self, magnitudes: &[f32]) -> Vec<f32> {
        let mut phases = vec![0.0f32; self.fft_size / 2 + 1];
        self.process_frame_into(magnitudes, &mut phases);
        phases
    }

    /// Estimates phases for one frame of STFT magnitudes and writes them to
    /// `phases_out`, wrapped into `[-π, π]`.
    ///
    /// This is real-time phase gradient heap integration (RTPGHI) without
    /// look-ahead:
    /// 1. phase steps are derived from log-magnitude slopes
    ///    ([`Self::compute_phase_gradients`]);
    /// 2. each significant bin receives its phase from exactly one already
    ///    known neighbour — the same bin in the previous frame, or an adjacent
    ///    bin in this frame — visited in order of decreasing magnitude
    ///    ([`Self::integrate_phases`]).
    ///
    /// The first frame after construction or [`Self::reset`] has no history
    /// and yields all-zero phases. Bins whose log magnitude is at or below
    /// the tolerance, and zero/negative/NaN/infinite magnitudes, get phase 0.
    ///
    /// Phases follow the frame-referenced convention (bin `k` nominally
    /// advances by `2π·hop·k/fft_size` per frame). NOTE(review): frequency
    /// integration assumes a zero-phase (centred) analysis window; if the
    /// synthesis stage windows frames starting at index 0, adjacent bins need
    /// an extra relative shift of π — confirm against the ISTFT.
    ///
    /// # Panics
    /// Panics if `magnitudes` or `phases_out` is not `fft_size / 2 + 1` long.
    pub fn process_frame_into(&mut self, magnitudes: &[f32], phases_out: &mut [f32]) {
        let spectrum_size = self.fft_size / 2 + 1;
        assert_eq!(magnitudes.len(), spectrum_size);
        assert_eq!(phases_out.len(), spectrum_size);

        for (log_m, &m) in self.scratch_log_mag.iter_mut().zip(magnitudes) {
            // Non-finite input is treated as silence so every later
            // difference stays finite (ln(inf) would otherwise produce NaN).
            *log_m = if m > 0.0 && m.is_finite() {
                f64::from(m).ln()
            } else {
                f64::NEG_INFINITY
            };
        }
        self.scratch_phases.fill(0.0);
        self.scratch_integrated.fill(false);

        if self.has_prev {
            self.compute_phase_gradients();
            self.integrate_phases();
        }

        self.prev_log_mag.copy_from_slice(&self.scratch_log_mag);
        self.prev_phase.copy_from_slice(&self.scratch_phases);
        self.has_prev = true;
        for (out, &p) in phases_out.iter_mut().zip(&self.scratch_phases) {
            *out = p as f32;
        }
    }

    /// Fills `scratch_d_phase_time` and `scratch_d_phase_freq`.
    ///
    /// With `a = hop_size`, `M = fft_size`, `s` the log magnitude and `γ` the
    /// window constant:
    /// * time step of bin `k`: `2π a k / M + (a M / γ) · ∂s/∂k`, with the
    ///   frequency slope averaged over the previous and current frame
    ///   (trapezoidal rule);
    /// * frequency step at bin `k`: `-(γ / (a M)) · (s_n[k] - s_{n-1}[k])`,
    ///   a backward difference in time because no future frame is available.
    ///
    /// (The previous implementation had these swapped: it scaled the *time*
    /// difference by `γ` into the time step, adding hundreds of radians per
    /// hop for ordinary level changes.)
    fn compute_phase_gradients(&mut self) {
        let tol = self.log_mag_tol;
        let hop = self.hop_size as f64;
        let hop_times_len = hop * self.fft_size as f64;
        let gamma_ok = self.gamma.is_finite() && self.gamma > 0.0;
        // Degenerate configurations (gamma, hop or fft_size of zero) disable
        // the gradient terms instead of dividing by zero.
        let time_coef = if gamma_ok { hop_times_len / self.gamma } else { 0.0 };
        let freq_coef = if gamma_ok && hop_times_len > 0.0 {
            self.gamma / hop_times_len
        } else {
            0.0
        };
        let advance_per_bin = if self.fft_size > 0 {
            std::f64::consts::TAU * hop / self.fft_size as f64
        } else {
            0.0
        };

        let cur = &self.scratch_log_mag;
        let prev = &self.prev_log_mag;
        let steps = self
            .scratch_d_phase_time
            .iter_mut()
            .zip(self.scratch_d_phase_freq.iter_mut());
        for (k, (time_step, freq_step)) in steps.enumerate() {
            let slope = 0.5
                * (Self::log_mag_freq_slope(cur, k, tol) + Self::log_mag_freq_slope(prev, k, tol));
            *time_step = advance_per_bin * k as f64 + time_coef * slope;
            // Both frames must be significant, otherwise the difference is
            // infinite (onset/offset) and carries no usable phase information.
            *freq_step = if cur[k] > tol && prev[k] > tol {
                -freq_coef * (cur[k] - prev[k])
            } else {
                0.0
            };
        }
    }

    /// Assigns phases to every significant bin of the current frame by heap
    /// integration, leaving insignificant bins at 0.
    ///
    /// Heap entries with `bin < n` refer to bin `bin` of the previous frame;
    /// entries with `bin >= n` refer to bin `bin - n` of the current frame.
    fn integrate_phases(&mut self) {
        let tol = self.log_mag_tol;
        let n = self.scratch_log_mag.len();

        // Reuse the scratch allocation for the heap (`From<Vec>` is in place).
        let mut buffer = std::mem::take(&mut self.scratch_heap);
        buffer.clear();
        let mut heap = std::collections::BinaryHeap::from(buffer);

        let log_mag = &self.scratch_log_mag;
        let prev_log_mag = &self.prev_log_mag;
        let prev_phase = &self.prev_phase;
        let time_steps = &self.scratch_d_phase_time;
        let freq_steps = &self.scratch_d_phase_freq;
        let phases = &mut self.scratch_phases;
        let integrated = &mut self.scratch_integrated;

        // Seed with the previous frame's coefficient at every bin that is
        // significant now, so every pending bin is reachable at least by time
        // integration.
        let mut remaining = 0usize;
        for (k, (&cur, &prev)) in log_mag.iter().zip(prev_log_mag).enumerate() {
            if cur > tol {
                remaining += 1;
                heap.push(HeapEntry { magnitude: prev, bin: k });
            }
        }

        while remaining > 0 {
            let top = match heap.pop() {
                Some(top) => top,
                // Defensive: every pending bin has a seed entry, so the heap
                // cannot run dry while bins remain.
                None => break,
            };
            if top.bin < n {
                // Previous-frame coefficient: integrate forward in time.
                let k = top.bin;
                if !integrated[k] {
                    phases[k] = Self::wrap_phase(prev_phase[k] + time_steps[k]);
                    integrated[k] = true;
                    heap.push(HeapEntry { magnitude: log_mag[k], bin: n + k });
                    remaining -= 1;
                }
            } else {
                // Current-frame coefficient: integrate to adjacent bins with
                // the trapezoidal rule across frequency.
                let k = top.bin - n;
                let base = phases[k];
                let above = (k + 1 < n)
                    .then(|| (k + 1, base + 0.5 * (freq_steps[k] + freq_steps[k + 1])));
                let below = k
                    .checked_sub(1)
                    .map(|j| (j, base - 0.5 * (freq_steps[k] + freq_steps[j])));
                for &(j, phase) in [above, below].iter().flatten() {
                    if !integrated[j] && log_mag[j] > tol {
                        phases[j] = Self::wrap_phase(phase);
                        integrated[j] = true;
                        heap.push(HeapEntry { magnitude: log_mag[j], bin: n + j });
                        remaining -= 1;
                    }
                }
            }
        }

        self.scratch_heap = heap.into_vec();
    }

    /// Central difference of `log_mag` across frequency at bin `k`.
    ///
    /// Returns 0 at the first and last bin (at DC and, for even `fft_size`,
    /// Nyquist the log spectrum of a real signal is symmetric, so the central
    /// difference vanishes) and whenever bin `k` or a neighbour is at or
    /// below `tol`, where the difference would be infinite or meaningless.
    fn log_mag_freq_slope(log_mag: &[f64], k: usize, tol: f64) -> f64 {
        if k == 0 || k + 1 >= log_mag.len() {
            return 0.0;
        }
        let (lo, mid, hi) = (log_mag[k - 1], log_mag[k], log_mag[k + 1]);
        if lo > tol && mid > tol && hi > tol {
            0.5 * (hi - lo)
        } else {
            0.0
        }
    }

    /// Maps a finite phase to the equivalent value in `[-π, π)`, keeping
    /// stored phases bounded so precision is not lost over long streams.
    fn wrap_phase(phase: f64) -> f64 {
        (phase + std::f64::consts::PI).rem_euclid(std::f64::consts::TAU) - std::f64::consts::PI
    }

    /// Forgets all history so the next frame is treated as the first one.
    pub fn reset(&mut self) {
        self.prev_log_mag.fill(f64::NEG_INFINITY);
        self.prev_phase.fill(0.0);
        self.has_prev = false;
    }

    /// Returns `fft_size`.
    ///
    /// NOTE(review): this processor uses no look-ahead frame, so any latency
    /// comes from the surrounding STFT/ISTFT framing — confirm this value
    /// against the synthesis stage.
    pub fn latency_samples(&self) -> usize {
        self.fft_size
    }
}
/// Time-stretches a sequence of STFT magnitude frames and estimates phases
/// for the result with [`RtpghiProcessor`].
///
/// `ceil(len * stretch_factor)` output frames are produced. Output frame `i`
/// samples the input at position `i / stretch_factor`, linearly
/// interpolating between neighbouring input frames; the last input frame is
/// held at the end. Phases are estimated for a hop of `hop_size` samples, and
/// one phase frame is returned per output frame.
///
/// Returns an empty vector when `magnitude_frames` is empty or when
/// `stretch_factor` is not a finite positive number. Previously `+inf` slipped
/// past the `<= 0.0` check and panicked with a capacity overflow.
///
/// # Panics
/// Panics inside [`RtpghiProcessor::process_frame`] if an interpolated
/// frame's length is not `fft_size / 2 + 1`.
pub fn stretch_with_rtpghi(
    magnitude_frames: &[Vec<f32>],
    stretch_factor: f64,
    fft_size: usize,
    hop_size: usize,
) -> Vec<Vec<f32>> {
    // `is_finite` rejects both NaN and ±inf.
    if magnitude_frames.is_empty() || !stretch_factor.is_finite() || stretch_factor <= 0.0 {
        return Vec::new();
    }
    let last = magnitude_frames.len() - 1;
    let num_output_frames = (magnitude_frames.len() as f64 * stretch_factor).ceil() as usize;

    let mut processor = RtpghiProcessor::new(fft_size, hop_size);
    // Frames are interpolated and processed one at a time, in order, so the
    // full set of stretched magnitudes is never buffered.
    (0..num_output_frames)
        .map(|i| {
            let src_pos = i as f64 / stretch_factor;
            // Clamp guards against floating-point rounding pushing the
            // position past the final frame.
            let src_idx = (src_pos.floor() as usize).min(last);
            let frame: Vec<f32> = if src_idx < last {
                let frac = (src_pos - src_idx as f64) as f32;
                magnitude_frames[src_idx]
                    .iter()
                    .zip(&magnitude_frames[src_idx + 1])
                    .map(|(&a, &b)| a * (1.0 - frac) + b * frac)
                    .collect()
            } else {
                magnitude_frames[last].clone()
            };
            processor.process_frame(&frame)
        })
        .collect()
}
#[cfg(test)]
mod tests {
// Unit tests for `RtpghiProcessor` and `stretch_with_rtpghi`. Most checks are
// structural (frame counts, finite outputs, state flags) rather than exact
// phase values.
use super::*;
// NOTE(review): `RealFftProcessor` comes from the project's `stft` module; the
// helper below assumes `new_forward_only`, `time_buffer`, `forward()` and
// `freq_buffer` behave as a forward real FFT — confirm against `crate::stft`.
use crate::stft::RealFftProcessor;
// Computes STFT magnitude frames of `signal` with a periodic Hann window of
// length `fft_size` and a hop of `hop_size` samples. Only frames that fit
// entirely inside `signal` are produced; each has `fft_size / 2 + 1` bins.
fn compute_stft_magnitudes(signal: &[f32], fft_size: usize, hop_size: usize) -> Vec<Vec<f32>> {
let spectrum_size = fft_size / 2 + 1;
let window: Vec<f32> = (0..fft_size)
.map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / fft_size as f32).cos()))
.collect();
let mut frames = Vec::new();
let mut fft = RealFftProcessor::new_forward_only(fft_size);
let mut pos = 0;
while pos + fft_size <= signal.len() {
for i in 0..fft_size {
fft.time_buffer[i] = signal[pos + i] * window[i];
}
fft.forward();
// Magnitude = |X[k]| for the non-negative-frequency bins.
let mags: Vec<f32> = fft.freq_buffer[..spectrum_size]
.iter()
.map(|c| (c.re * c.re + c.im * c.im).sqrt())
.collect();
frames.push(mags);
pos += hop_size;
}
frames
}
// A 440 Hz sine sampled at 48 kHz, analysed and "stretched" by 1.0: the frame
// count must be unchanged and every estimated phase must be finite.
#[test]
fn test_identity_stretch() {
let fft_size = 256;
let hop_size = 64;
let sample_rate = 48000.0;
let num_samples = 4096;
let signal: Vec<f32> = (0..num_samples)
.map(|i| {
let t = i as f32 / sample_rate;
(2.0 * std::f32::consts::PI * 440.0 * t).sin()
})
.collect();
let mags = compute_stft_magnitudes(&signal, fft_size, hop_size);
assert!(!mags.is_empty());
let phases = stretch_with_rtpghi(&mags, 1.0, fft_size, hop_size);
assert_eq!(phases.len(), mags.len());
for frame in &phases {
for &p in frame {
assert!(p.is_finite(), "Phase should be finite, got {p}");
}
}
}
// Ten identical frames with magnitudes e^{-k} stretched by 2.0 must yield
// ceil(10 * 2.0) = 20 output frames.
#[test]
fn test_2x_stretch_doubles_frames() {
let fft_size = 256;
let hop_size = 64;
let spectrum_size = fft_size / 2 + 1;
let frame: Vec<f32> = (0..spectrum_size)
.map(|i| (i as f32).exp().recip())
.collect();
let mags = vec![frame; 10];
let stretched = stretch_with_rtpghi(&mags, 2.0, fft_size, hop_size);
assert_eq!(stretched.len(), 20);
}
// Twenty frames whose magnitudes fall linearly with bin index and are
// modulated over time must never produce NaN or infinite phases.
#[test]
fn test_no_nan_inf() {
let fft_size = 512;
let hop_size = 128;
let spectrum_size = fft_size / 2 + 1;
let mut processor = RtpghiProcessor::new(fft_size, hop_size);
for frame_idx in 0..20 {
let mags: Vec<f32> = (0..spectrum_size)
.map(|k| {
let freq_factor = 1.0 - k as f32 / spectrum_size as f32;
let time_factor = 1.0 + 0.5 * (frame_idx as f32 * 0.3).sin();
freq_factor * time_factor
})
.collect();
let phases = processor.process_frame(&mags);
for (k, &p) in phases.iter().enumerate() {
assert!(
p.is_finite(),
"Phase at bin {k}, frame {frame_idx} is not finite: {p}"
);
}
}
}
// Processing a frame sets `has_prev`; `reset` must clear it again.
#[test]
fn test_reset() {
let fft_size = 256;
let hop_size = 64;
let spectrum_size = fft_size / 2 + 1;
let mut processor = RtpghiProcessor::new(fft_size, hop_size);
let mags = vec![0.5; spectrum_size];
let _ = processor.process_frame(&mags);
assert!(processor.has_prev);
processor.reset();
assert!(!processor.has_prev);
}
// No input frames -> no output frames.
#[test]
fn test_empty_stretch() {
let result = stretch_with_rtpghi(&[], 2.0, 256, 64);
assert!(result.is_empty());
}
// All-zero frames (log magnitude -inf in every bin) must still give finite
// phases, including on the second frame, where history exists.
#[test]
fn test_zero_magnitude_bins() {
let fft_size = 256;
let hop_size = 64;
let spectrum_size = fft_size / 2 + 1;
let mut processor = RtpghiProcessor::new(fft_size, hop_size);
let mags = vec![0.0f32; spectrum_size];
let _ = processor.process_frame(&mags);
let phases = processor.process_frame(&mags);
for &p in &phases {
assert!(p.is_finite());
}
}
// The allocating (`process_frame`) and in-place (`process_frame_into`) APIs,
// run on two processors fed identical frames, must agree bin by bin.
#[test]
fn test_process_frame_into_matches_process_frame() {
let fft_size = 512;
let hop_size = 128;
let spectrum_size = fft_size / 2 + 1;
let mut proc_alloc = RtpghiProcessor::new(fft_size, hop_size);
let mut proc_noalloc = RtpghiProcessor::new(fft_size, hop_size);
for frame_idx in 0..15 {
let mags: Vec<f32> = (0..spectrum_size)
.map(|k| {
let freq_factor = 1.0 - k as f32 / spectrum_size as f32;
let time_factor = 1.0 + 0.5 * (frame_idx as f32 * 0.3).sin();
freq_factor * time_factor
})
.collect();
let phases_alloc = proc_alloc.process_frame(&mags);
let mut phases_noalloc = vec![0.0f32; spectrum_size];
proc_noalloc.process_frame_into(&mags, &mut phases_noalloc);
for (k, (&a, &b)) in phases_alloc.iter().zip(phases_noalloc.iter()).enumerate() {
assert!(
(a - b).abs() < 1e-5,
"Mismatch at bin {k}, frame {frame_idx}: alloc={a}, noalloc={b}"
);
}
}
}
// Reusing one output buffer across frames of varying, non-negative
// magnitudes must keep every phase finite.
#[test]
fn test_process_frame_into_no_nan() {
let fft_size = 256;
let hop_size = 64;
let spectrum_size = fft_size / 2 + 1;
let mut processor = RtpghiProcessor::new(fft_size, hop_size);
let mut phases = vec![0.0f32; spectrum_size];
for frame_idx in 0..10 {
let mags: Vec<f32> = (0..spectrum_size)
.map(|k| {
let v = 0.5 + 0.5 * ((frame_idx * k) as f32 * 0.1).sin();
v.max(0.0)
})
.collect();
processor.process_frame_into(&mags, &mut phases);
for (k, &p) in phases.iter().enumerate() {
assert!(
p.is_finite(),
"Phase at bin {k}, frame {frame_idx} is not finite: {p}"
);
}
}
}
}