1use std::collections::BinaryHeap;
15
16pub struct RtpghiProcessor {
18 fft_size: usize,
19 hop_size: usize,
20 gamma: f64,
22 prev_log_mag: Vec<f64>,
24 prev_phase: Vec<f64>,
26 has_prev: bool,
28 log_mag_tol: f64,
30
31 scratch_log_mag: Vec<f64>,
33 scratch_phases: Vec<f64>,
34 scratch_integrated: Vec<bool>,
35 scratch_d_phase_time: Vec<f64>,
36 scratch_d_phase_freq: Vec<f64>,
37 scratch_heap: Vec<HeapEntry>,
38}
39
/// Priority-queue entry for phase integration: a bin identifier keyed by its natural-log
/// magnitude. Larger magnitudes are processed first.
#[derive(Debug, Clone, Copy, PartialEq)]
struct HeapEntry {
    /// Natural-log magnitude used as the priority.
    magnitude: f64,
    /// Identifies the bin this entry refers to.
    bin: usize,
}

// `Eq` (required by `Ord`, and hence by `BinaryHeap`) is layered on the derived `PartialEq`.
// It is only sound while `magnitude` is never NaN, since `NaN != NaN` breaks reflexivity.
impl Eq for HeapEntry {}
48
49impl PartialOrd for HeapEntry {
50 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
51 Some(self.cmp(other))
52 }
53}
54
55impl Ord for HeapEntry {
56 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
57 self.magnitude
58 .partial_cmp(&other.magnitude)
59 .unwrap_or(std::cmp::Ordering::Equal)
60 }
61}
62
63impl RtpghiProcessor {
64 pub fn new(fft_size: usize, hop_size: usize) -> Self {
70 let spectrum_size = fft_size / 2 + 1;
71
72 let gamma = 0.17 * (fft_size as f64) * (fft_size as f64);
75
76 Self {
77 fft_size,
78 hop_size,
79 gamma,
80 prev_log_mag: vec![f64::NEG_INFINITY; spectrum_size],
81 prev_phase: vec![0.0; spectrum_size],
82 has_prev: false,
83 log_mag_tol: -60.0, scratch_log_mag: vec![0.0; spectrum_size],
85 scratch_phases: vec![0.0; spectrum_size],
86 scratch_integrated: vec![false; spectrum_size],
87 scratch_d_phase_time: vec![0.0; spectrum_size],
88 scratch_d_phase_freq: vec![0.0; spectrum_size],
89 scratch_heap: Vec::with_capacity(spectrum_size),
90 }
91 }
92
93 pub fn process_frame(&mut self, magnitudes: &[f32]) -> Vec<f32> {
101 let spectrum_size = self.fft_size / 2 + 1;
102 assert_eq!(
103 magnitudes.len(),
104 spectrum_size,
105 "Expected {} magnitudes, got {}",
106 spectrum_size,
107 magnitudes.len()
108 );
109
110 let log_mag: Vec<f64> = magnitudes
112 .iter()
113 .map(|&m| {
114 if m > 0.0 {
115 (m as f64).ln()
116 } else {
117 f64::NEG_INFINITY
118 }
119 })
120 .collect();
121
122 let mut phases = vec![0.0f64; spectrum_size];
123 let mut integrated = vec![false; spectrum_size];
124
125 if !self.has_prev {
126 self.prev_log_mag = log_mag.clone();
128 self.prev_phase = phases.clone();
129 self.has_prev = true;
130 return phases.iter().map(|&p| p as f32).collect();
131 }
132
133 let hop = self.hop_size as f64;
135 let two_pi = 2.0 * std::f64::consts::PI;
136
137 let d_phase_time: Vec<f64> = (0..spectrum_size)
140 .map(|k| {
141 let omega_k = two_pi * k as f64 / self.fft_size as f64;
143 let expected_advance = omega_k * hop;
144
145 let time_grad =
147 if log_mag[k] > self.log_mag_tol && self.prev_log_mag[k] > self.log_mag_tol {
148 self.gamma * (log_mag[k] - self.prev_log_mag[k])
149 } else {
150 0.0
151 };
152
153 expected_advance + time_grad
154 })
155 .collect();
156
157 let inv_gamma = if self.gamma.abs() > 1e-30 {
160 1.0 / self.gamma
161 } else {
162 0.0
163 };
164 let d_phase_freq: Vec<f64> = (0..spectrum_size)
165 .map(|k| {
166 if k == 0 || k == spectrum_size - 1 {
167 return 0.0;
168 }
169 if log_mag[k] > self.log_mag_tol
170 && log_mag[k - 1] > self.log_mag_tol
171 && log_mag[k + 1] > self.log_mag_tol
172 {
173 inv_gamma * (log_mag[k + 1] - log_mag[k - 1]) / 2.0
174 } else {
175 0.0
176 }
177 })
178 .collect();
179
180 let mut heap = BinaryHeap::new();
182 for (k, &mag) in log_mag.iter().enumerate() {
183 if mag > self.log_mag_tol {
184 heap.push(HeapEntry {
185 magnitude: mag,
186 bin: k,
187 });
188 }
189 }
190
191 while let Some(entry) = heap.pop() {
193 let k = entry.bin;
194 if integrated[k] {
195 continue;
196 }
197
198 let phase_from_time = self.prev_phase[k] + d_phase_time[k];
200
201 let phase_from_freq_below = if k > 0 && integrated[k - 1] {
202 Some(phases[k - 1] + d_phase_freq[k - 1])
203 } else {
204 None
205 };
206
207 let phase_from_freq_above = if k + 1 < spectrum_size && integrated[k + 1] {
208 Some(phases[k + 1] - d_phase_freq[k + 1])
209 } else {
210 None
211 };
212
213 let phase = match (phase_from_freq_below, phase_from_freq_above) {
215 (Some(below), Some(above)) => {
216 let avg = (below + above) / 2.0;
218 if self.prev_log_mag[k] > self.log_mag_tol {
220 (avg + phase_from_time) / 2.0
221 } else {
222 avg
223 }
224 }
225 (Some(below), None) => {
226 if self.prev_log_mag[k] > self.log_mag_tol {
227 (below + phase_from_time) / 2.0
228 } else {
229 below
230 }
231 }
232 (None, Some(above)) => {
233 if self.prev_log_mag[k] > self.log_mag_tol {
234 (above + phase_from_time) / 2.0
235 } else {
236 above
237 }
238 }
239 (None, None) => phase_from_time,
240 };
241
242 phases[k] = phase;
243 integrated[k] = true;
244 }
245
246 for k in 0..spectrum_size {
248 if !integrated[k] {
249 phases[k] = 0.0;
250 }
251 }
252
253 self.prev_log_mag = log_mag;
255 self.prev_phase = phases.clone();
256
257 phases.iter().map(|&p| p as f32).collect()
258 }
259
260 pub fn process_frame_into(&mut self, magnitudes: &[f32], phases_out: &mut [f32]) {
270 let spectrum_size = self.fft_size / 2 + 1;
271 assert_eq!(magnitudes.len(), spectrum_size);
272 assert_eq!(phases_out.len(), spectrum_size);
273
274 let log_mag = &mut self.scratch_log_mag;
275 let phases = &mut self.scratch_phases;
276 let integrated = &mut self.scratch_integrated;
277 let d_phase_time = &mut self.scratch_d_phase_time;
278 let d_phase_freq = &mut self.scratch_d_phase_freq;
279
280 for (i, &m) in magnitudes.iter().enumerate() {
282 log_mag[i] = if m > 0.0 {
283 (m as f64).ln()
284 } else {
285 f64::NEG_INFINITY
286 };
287 }
288
289 for v in phases.iter_mut() {
291 *v = 0.0;
292 }
293 for v in integrated.iter_mut() {
294 *v = false;
295 }
296
297 if !self.has_prev {
298 self.prev_log_mag.copy_from_slice(log_mag);
300 self.prev_phase.copy_from_slice(phases);
301 self.has_prev = true;
302 for (out, &p) in phases_out.iter_mut().zip(phases.iter()) {
303 *out = p as f32;
304 }
305 return;
306 }
307
308 let hop = self.hop_size as f64;
310 let two_pi = 2.0 * std::f64::consts::PI;
311 let gamma = self.gamma;
312 let log_mag_tol = self.log_mag_tol;
313 let fft_size = self.fft_size;
314
315 for k in 0..spectrum_size {
317 let omega_k = two_pi * k as f64 / fft_size as f64;
318 let expected_advance = omega_k * hop;
319 let time_grad = if log_mag[k] > log_mag_tol && self.prev_log_mag[k] > log_mag_tol {
320 gamma * (log_mag[k] - self.prev_log_mag[k])
321 } else {
322 0.0
323 };
324 d_phase_time[k] = expected_advance + time_grad;
325 }
326
327 let inv_gamma = if gamma.abs() > 1e-30 {
329 1.0 / gamma
330 } else {
331 0.0
332 };
333 d_phase_freq[0] = 0.0;
334 if spectrum_size > 1 {
335 d_phase_freq[spectrum_size - 1] = 0.0;
336 }
337 for k in 1..spectrum_size.saturating_sub(1) {
338 d_phase_freq[k] = if log_mag[k] > log_mag_tol
339 && log_mag[k - 1] > log_mag_tol
340 && log_mag[k + 1] > log_mag_tol
341 {
342 inv_gamma * (log_mag[k + 1] - log_mag[k - 1]) / 2.0
343 } else {
344 0.0
345 };
346 }
347
348 self.scratch_heap.clear();
350 for (k, &mag) in log_mag.iter().enumerate() {
351 if mag > log_mag_tol {
352 self.scratch_heap.push(HeapEntry {
353 magnitude: mag,
354 bin: k,
355 });
356 }
357 }
358 self.scratch_heap.sort_unstable_by(|a, b| b.cmp(a));
360
361 for idx in 0..self.scratch_heap.len() {
363 let k = self.scratch_heap[idx].bin;
364 if integrated[k] {
365 continue;
366 }
367
368 let phase_from_time = self.prev_phase[k] + d_phase_time[k];
369
370 let phase_from_freq_below = if k > 0 && integrated[k - 1] {
371 Some(phases[k - 1] + d_phase_freq[k - 1])
372 } else {
373 None
374 };
375
376 let phase_from_freq_above = if k + 1 < spectrum_size && integrated[k + 1] {
377 Some(phases[k + 1] - d_phase_freq[k + 1])
378 } else {
379 None
380 };
381
382 let phase = match (phase_from_freq_below, phase_from_freq_above) {
383 (Some(below), Some(above)) => {
384 let avg = (below + above) / 2.0;
385 if self.prev_log_mag[k] > log_mag_tol {
386 (avg + phase_from_time) / 2.0
387 } else {
388 avg
389 }
390 }
391 (Some(below), None) => {
392 if self.prev_log_mag[k] > log_mag_tol {
393 (below + phase_from_time) / 2.0
394 } else {
395 below
396 }
397 }
398 (None, Some(above)) => {
399 if self.prev_log_mag[k] > log_mag_tol {
400 (above + phase_from_time) / 2.0
401 } else {
402 above
403 }
404 }
405 (None, None) => phase_from_time,
406 };
407
408 phases[k] = phase;
409 integrated[k] = true;
410 }
411
412 for k in 0..spectrum_size {
414 if !integrated[k] {
415 phases[k] = 0.0;
416 }
417 }
418
419 self.prev_log_mag.copy_from_slice(log_mag);
421 self.prev_phase.copy_from_slice(phases);
422
423 for (out, &p) in phases_out.iter_mut().zip(phases.iter()) {
425 *out = p as f32;
426 }
427 }
428
429 pub fn reset(&mut self) {
431 self.prev_log_mag.fill(f64::NEG_INFINITY);
432 self.prev_phase.fill(0.0);
433 self.has_prev = false;
434 }
435
436 pub fn latency_samples(&self) -> usize {
438 self.fft_size
439 }
440}
441
442pub fn stretch_with_rtpghi(
453 magnitude_frames: &[Vec<f32>],
454 stretch_factor: f64,
455 fft_size: usize,
456 hop_size: usize,
457) -> Vec<Vec<f32>> {
458 if magnitude_frames.is_empty() || stretch_factor <= 0.0 {
459 return Vec::new();
460 }
461
462 let num_input_frames = magnitude_frames.len();
463 let num_output_frames = (num_input_frames as f64 * stretch_factor).ceil() as usize;
464
465 let mut stretched_mags = Vec::with_capacity(num_output_frames);
467 for i in 0..num_output_frames {
468 let src_pos = i as f64 / stretch_factor;
469 let src_idx = src_pos.floor() as usize;
470 let frac = (src_pos - src_idx as f64) as f32;
471
472 let frame = if src_idx + 1 < num_input_frames {
473 magnitude_frames[src_idx]
474 .iter()
475 .zip(&magnitude_frames[src_idx + 1])
476 .map(|(&a, &b)| a * (1.0 - frac) + b * frac)
477 .collect()
478 } else if src_idx < num_input_frames {
479 magnitude_frames[src_idx].clone()
480 } else {
481 magnitude_frames.last().unwrap().clone()
482 };
483 stretched_mags.push(frame);
484 }
485
486 let mut processor = RtpghiProcessor::new(fft_size, hop_size);
488 stretched_mags
489 .iter()
490 .map(|mags| processor.process_frame(mags))
491 .collect()
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stft::RealFftProcessor;

    /// Periodic Hann window of `len` samples.
    fn hann_window(len: usize) -> Vec<f32> {
        (0..len)
            .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / len as f32).cos()))
            .collect()
    }

    /// Magnitude frames of `signal`: Hann-windowed `fft_size`-point FFTs every `hop_size`
    /// samples, keeping only frames that fit entirely inside the signal.
    fn compute_stft_magnitudes(signal: &[f32], fft_size: usize, hop_size: usize) -> Vec<Vec<f32>> {
        let bins = fft_size / 2 + 1;
        let window = hann_window(fft_size);
        let mut fft = RealFftProcessor::new_forward_only(fft_size);
        let mut frames = Vec::new();

        let mut start = 0;
        while start + fft_size <= signal.len() {
            let segment = &signal[start..start + fft_size];
            for (i, (&x, &w)) in segment.iter().zip(window.iter()).enumerate() {
                fft.time_buffer[i] = x * w;
            }
            fft.forward();
            frames.push(
                fft.freq_buffer[..bins]
                    .iter()
                    .map(|c| (c.re * c.re + c.im * c.im).sqrt())
                    .collect::<Vec<f32>>(),
            );
            start += hop_size;
        }
        frames
    }

    /// Magnitudes falling linearly with bin index, scaled by a slow per-frame modulation.
    fn modulated_ramp(spectrum_size: usize, frame_idx: usize) -> Vec<f32> {
        let time_factor = 1.0 + 0.5 * (frame_idx as f32 * 0.3).sin();
        (0..spectrum_size)
            .map(|k| (1.0 - k as f32 / spectrum_size as f32) * time_factor)
            .collect()
    }

    #[test]
    fn test_identity_stretch() {
        let (fft_size, hop_size) = (256, 64);
        let sample_rate = 48000.0;

        // 440 Hz sine, 4096 samples.
        let signal: Vec<f32> = (0..4096)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * (i as f32 / sample_rate)).sin())
            .collect();

        let mags = compute_stft_magnitudes(&signal, fft_size, hop_size);
        assert!(!mags.is_empty());

        let phases = stretch_with_rtpghi(&mags, 1.0, fft_size, hop_size);
        assert_eq!(phases.len(), mags.len());
        for &p in phases.iter().flatten() {
            assert!(p.is_finite(), "Phase should be finite, got {p}");
        }
    }

    #[test]
    fn test_2x_stretch_doubles_frames() {
        let (fft_size, hop_size) = (256, 64);
        let frame: Vec<f32> = (0..fft_size / 2 + 1)
            .map(|i| (i as f32).exp().recip())
            .collect();
        let mags = vec![frame; 10];

        let stretched = stretch_with_rtpghi(&mags, 2.0, fft_size, hop_size);
        assert_eq!(stretched.len(), 20);
    }

    #[test]
    fn test_no_nan_inf() {
        let (fft_size, hop_size) = (512, 128);
        let spectrum_size = fft_size / 2 + 1;
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        for frame_idx in 0..20 {
            let phases = processor.process_frame(&modulated_ramp(spectrum_size, frame_idx));
            for (k, p) in phases.into_iter().enumerate() {
                assert!(
                    p.is_finite(),
                    "Phase at bin {k}, frame {frame_idx} is not finite: {p}"
                );
            }
        }
    }

    #[test]
    fn test_reset() {
        let (fft_size, hop_size) = (256, 64);
        let mags = vec![0.5; fft_size / 2 + 1];
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        let _ = processor.process_frame(&mags);
        assert!(processor.has_prev);

        processor.reset();
        assert!(!processor.has_prev);
    }

    #[test]
    fn test_empty_stretch() {
        let result = stretch_with_rtpghi(&[], 2.0, 256, 64);
        assert!(result.is_empty());
    }

    #[test]
    fn test_zero_magnitude_bins() {
        let (fft_size, hop_size) = (256, 64);
        let silence = vec![0.0f32; fft_size / 2 + 1];
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);

        let _ = processor.process_frame(&silence);
        let phases = processor.process_frame(&silence);
        for p in &phases {
            assert!(p.is_finite());
        }
    }

    #[test]
    fn test_process_frame_into_matches_process_frame() {
        let (fft_size, hop_size) = (512, 128);
        let spectrum_size = fft_size / 2 + 1;
        let mut proc_alloc = RtpghiProcessor::new(fft_size, hop_size);
        let mut proc_noalloc = RtpghiProcessor::new(fft_size, hop_size);

        for frame_idx in 0..15 {
            let mags = modulated_ramp(spectrum_size, frame_idx);
            let expected = proc_alloc.process_frame(&mags);
            let mut actual = vec![0.0f32; spectrum_size];
            proc_noalloc.process_frame_into(&mags, &mut actual);

            for (k, (&a, &b)) in expected.iter().zip(actual.iter()).enumerate() {
                assert!(
                    (a - b).abs() < 1e-5,
                    "Mismatch at bin {k}, frame {frame_idx}: alloc={a}, noalloc={b}"
                );
            }
        }
    }

    #[test]
    fn test_process_frame_into_no_nan() {
        let (fft_size, hop_size) = (256, 64);
        let spectrum_size = fft_size / 2 + 1;
        let mut processor = RtpghiProcessor::new(fft_size, hop_size);
        let mut phases = vec![0.0f32; spectrum_size];

        for frame_idx in 0..10 {
            let mags: Vec<f32> = (0..spectrum_size)
                .map(|k| (0.5 + 0.5 * ((frame_idx * k) as f32 * 0.1).sin()).max(0.0))
                .collect();

            processor.process_frame_into(&mags, &mut phases);
            for (k, &p) in phases.iter().enumerate() {
                assert!(
                    p.is_finite(),
                    "Phase at bin {k}, frame {frame_idx} is not finite: {p}"
                );
            }
        }
    }
}