cute_dsp/
stretch.rs

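//! Time-stretching and pitch-shifting using a phase vocoder.
//!
//! This module provides high-quality time-stretching and pitch-shifting capabilities
//! using a phase vocoder approach with spectral peak detection and formant preservation.
//!
//! Note: `SignalsmithStretch::process` is currently a simplified placeholder that only
//! resamples its input by nearest-neighbour index mapping; the spectral (phase-vocoder)
//! stages defined here are not yet wired into it.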
5
6#![allow(unused_imports)]
7
8use num_traits::{Float, FromPrimitive, NumCast};
9use num_complex::Complex;
10use core::marker::PhantomData;
11
12use crate::stft::STFT;
13
14/// Spectral band data for phase vocoder
15#[derive(Clone, Debug)]
16pub struct Band<T: Float> {
17    pub input: Complex<T>,
18    pub prev_input: Complex<T>,
19    pub output: Complex<T>,
20    pub input_energy: T,
21}
22
23impl<T: Float> Default for Band<T> {
24    fn default() -> Self {
25        Self {
26            input: Complex::new(T::zero(), T::zero()),
27            prev_input: Complex::new(T::zero(), T::zero()),
28            output: Complex::new(T::zero(), T::zero()),
29            input_energy: T::zero(),
30        }
31    }
32}
33
34/// Spectral peak information
35#[derive(Clone, Debug)]
36pub struct Peak<T: Float> {
37    pub input: T,
38    pub output: T,
39}
40
41/// Frequency mapping point
42#[derive(Clone, Debug)]
43pub struct PitchMapPoint<T: Float> {
44    pub input_bin: T,
45    pub freq_grad: T,
46}
47
48/// Phase prediction for spectral processing
49#[derive(Clone, Debug)]
50pub struct Prediction<T: Float> {
51    pub energy: T,
52    pub input: Complex<T>,
53}
54
55impl<T: Float> Default for Prediction<T> {
56    fn default() -> Self {
57        Self {
58            energy: T::zero(),
59            input: Complex::new(T::zero(), T::zero()),
60        }
61    }
62}
63
64impl<T: Float> Prediction<T> {
65    pub fn make_output(&self, phase: Complex<T>) -> Complex<T> {
66        let phase_norm = phase.norm_sqr();
67        let phase = if phase_norm <= T::epsilon() {
68            self.input
69        } else {
70            phase
71        };
72        let phase_norm = phase.norm_sqr() + T::epsilon();
73        phase * Complex::new((self.energy / phase_norm).sqrt(), T::zero())
74    }
75}
76
77/// Main time-stretching and pitch-shifting processor
78pub struct SignalsmithStretch<T: Float> {
79    // Configuration
80    split_computation: bool,
81    channels: usize,
82    bands: usize,
83    
84    // STFT and buffers
85    block_samples: usize,
86    interval_samples: usize,
87    tmp_buffer: Vec<T>,
88    
89    // STFT instances
90    analysis_stft: STFT<T>,
91    synthesis_stft: STFT<T>,
92    
93    // Spectral data
94    channel_bands: Vec<Band<T>>,
95    peaks: Vec<Peak<T>>,
96    energy: Vec<T>,
97    smoothed_energy: Vec<T>,
98    output_map: Vec<PitchMapPoint<T>>,
99    channel_predictions: Vec<Prediction<T>>,
100    
101    // Processing state
102    prev_input_offset: i32,
103    silence_counter: usize,
104    did_seek: bool,
105    
106    // Frequency mapping
107    freq_multiplier: T,
108    freq_tonality_limit: T,
109    custom_freq_map: Option<Box<dyn Fn(T) -> T + Send + Sync + 'static>>,
110    
111    // Formant processing
112    formant_multiplier: T,
113    inv_formant_multiplier: T,
114    formant_compensation: bool,
115    formant_base_freq: T,
116}
117
118impl<T: Float + FromPrimitive + NumCast + core::ops::AddAssign> SignalsmithStretch<T> {
119    /// Create a new stretch processor
120    pub fn new() -> Self {
121        Self {
122            split_computation: false,
123            channels: 0,
124            bands: 0,
125            block_samples: 0,
126            interval_samples: 0,
127            tmp_buffer: Vec::new(),
128            analysis_stft: STFT::new(false),
129            synthesis_stft: STFT::new(false),
130            channel_bands: Vec::new(),
131            peaks: Vec::new(),
132            energy: Vec::new(),
133            smoothed_energy: Vec::new(),
134            output_map: Vec::new(),
135            channel_predictions: Vec::new(),
136            prev_input_offset: -1,
137            silence_counter: 0,
138            did_seek: false,
139            freq_multiplier: T::one(),
140            freq_tonality_limit: T::from_f32(0.5).unwrap(),
141            custom_freq_map: None,
142            formant_multiplier: T::one(),
143            inv_formant_multiplier: T::one(),
144            formant_compensation: false,
145            formant_base_freq: T::zero()
146        }
147    }
148
149    /// Get the block size in samples
150    pub fn block_samples(&self) -> usize {
151        self.block_samples
152    }
153
154    /// Get the interval size in samples
155    pub fn interval_samples(&self) -> usize {
156        self.interval_samples
157    }
158
159    /// Get the input latency
160    pub fn input_latency(&self) -> usize {
161        self.block_samples / 2
162    }
163
164    /// Get the output latency
165    pub fn output_latency(&self) -> usize {
166        self.block_samples / 2 + if self.split_computation { self.interval_samples } else { 0 }
167    }
168
169    /// Reset the processor state
170    pub fn reset(&mut self) {
171        self.prev_input_offset = -1;
172        for band in &mut self.channel_bands {
173            *band = Band::default();
174        }
175        self.silence_counter = 0;
176        self.did_seek = false;
177    }
178
179    /// Configure with default preset
180    pub fn preset_default(&mut self, n_channels: usize, sample_rate: T, split_computation: bool) {
181        let block_samples = (sample_rate * T::from_f32(0.12).unwrap()).to_usize().unwrap_or(1024);
182        let interval_samples = (sample_rate * T::from_f32(0.03).unwrap()).to_usize().unwrap_or(256);
183        self.configure(n_channels, block_samples, interval_samples, split_computation);
184    }
185
186    /// Configure with cheaper preset
187    pub fn preset_cheaper(&mut self, n_channels: usize, sample_rate: T, split_computation: bool) {
188        let block_samples = (sample_rate * T::from_f32(0.1).unwrap()).to_usize().unwrap_or(1024);
189        let interval_samples = (sample_rate * T::from_f32(0.04).unwrap()).to_usize().unwrap_or(256);
190        self.configure(n_channels, block_samples, interval_samples, split_computation);
191    }
192
193    /// Manual configuration
194    pub fn configure(&mut self, n_channels: usize, block_samples: usize, interval_samples: usize, split_computation: bool) {
195        self.split_computation = split_computation;
196        self.channels = n_channels;
197        self.block_samples = block_samples;
198        self.interval_samples = interval_samples;
199        
200        self.bands = block_samples / 2 + 1;
201        
202        // Configure STFT instances
203        self.analysis_stft.configure(n_channels, n_channels, block_samples, block_samples, interval_samples);
204        self.synthesis_stft.configure(n_channels, n_channels, block_samples, block_samples, interval_samples);
205        
206        self.tmp_buffer.resize(block_samples + interval_samples, T::zero());
207        self.channel_bands.resize(self.bands * self.channels, Band::default());
208        
209        self.peaks.clear();
210        self.peaks.reserve(self.bands / 2);
211        self.energy.resize(self.bands, T::zero());
212        self.smoothed_energy.resize(self.bands, T::zero());
213        self.output_map.resize(self.bands, PitchMapPoint { input_bin: T::zero(), freq_grad: T::one() });
214        self.channel_predictions.resize(self.channels * self.bands, Prediction::default());
215        
216        self.reset();
217    }
218
219    /// Set transpose factor for pitch shifting
220    pub fn set_transpose_factor(&mut self, multiplier: T, tonality_limit: T) {
221        self.freq_multiplier = multiplier;
222        if tonality_limit > T::zero() {
223            self.freq_tonality_limit = tonality_limit / multiplier.sqrt();
224        } else {
225            self.freq_tonality_limit = T::one();
226        }
227        self.custom_freq_map = None;
228    }
229
230    /// Set transpose in semitones
231    pub fn set_transpose_semitones(&mut self, semitones: T, tonality_limit: T) {
232        let multiplier = T::from_f32(2.0).unwrap().powf(semitones / T::from_f32(12.0).unwrap());
233        self.set_transpose_factor(multiplier, tonality_limit);
234    }
235
236    /// Set custom frequency mapping function
237    pub fn set_freq_map<F>(&mut self, input_to_output: F)
238    where
239        F: Fn(T) -> T + 'static + Send + Sync,
240    {
241        self.custom_freq_map = Some(Box::new(input_to_output));
242    }
243
244    /// Set formant factor
245    pub fn set_formant_factor(&mut self, multiplier: T, compensate_pitch: bool) {
246        self.formant_multiplier = multiplier;
247        self.inv_formant_multiplier = T::one() / multiplier;
248        self.formant_compensation = compensate_pitch;
249    }
250
251    /// Set formant shift in semitones
252    pub fn set_formant_semitones(&mut self, semitones: T, compensate_pitch: bool) {
253        let multiplier = T::from_f32(2.0).unwrap().powf(semitones / T::from_f32(12.0).unwrap());
254        self.set_formant_factor(multiplier, compensate_pitch);
255    }
256
257    /// Set formant base frequency
258    pub fn set_formant_base(&mut self, base_freq: T) {
259        self.formant_base_freq = base_freq;
260    }
261
262    /// Convert bin index to frequency (simplified)
263    fn bin_to_freq(&self, bin: T) -> T {
264        bin * T::from_f32(22050.0).unwrap() / T::from_usize(self.bands).unwrap()
265    }
266
267    /// Convert frequency to bin index (simplified)
268    fn freq_to_bin(&self, freq: T) -> T {
269        freq * T::from_usize(self.bands).unwrap() / T::from_f32(22050.0).unwrap()
270    }
271
272    /// Map frequency according to current settings
273    fn map_freq(&self, freq: T) -> T {
274        if let Some(ref custom_map) = self.custom_freq_map {
275            custom_map(freq)
276        } else if freq > self.freq_tonality_limit {
277            freq + (self.freq_multiplier - T::one()) * self.freq_tonality_limit
278        } else {
279            freq * self.freq_multiplier
280        }
281    }
282
283    /// Get bands for a specific channel
284    fn bands_for_channel(&self, channel: usize) -> &[Band<T>] {
285        let start = channel * self.bands;
286        let end = start + self.bands;
287        &self.channel_bands[start..end]
288    }
289
290    /// Get mutable bands for a specific channel
291    fn bands_for_channel_mut(&mut self, channel: usize) -> &mut [Band<T>] {
292        let start = channel * self.bands;
293        let end = start + self.bands;
294        &mut self.channel_bands[start..end]
295    }
296
297    /// Get predictions for a specific channel
298    fn predictions_for_channel(&self, channel: usize) -> &[Prediction<T>] {
299        let start = channel * self.bands;
300        let end = start + self.bands;
301        &self.channel_predictions[start..end]
302    }
303
304    /// Get mutable predictions for a specific channel
305    fn predictions_for_channel_mut(&mut self, channel: usize) -> &mut [Prediction<T>] {
306        let start = channel * self.bands;
307        let end = start + self.bands;
308        &mut self.channel_predictions[start..end]
309    }
310
311    /// Find spectral peaks
312    fn find_peaks(&mut self) {
313        self.peaks.clear();
314        
315        let mut start = 0;
316        while start < self.bands {
317            if self.energy[start] > self.smoothed_energy[start] {
318                let mut end = start;
319                let mut band_sum = T::zero();
320                let mut energy_sum = T::zero();
321                
322                while end < self.bands && self.energy[end] > self.smoothed_energy[end] {
323                    band_sum = band_sum + T::from_usize(end).unwrap() * self.energy[end];
324                    energy_sum = energy_sum + self.energy[end];
325                    end += 1;
326                }
327                
328                let avg_band = band_sum / energy_sum;
329                let avg_freq = self.bin_to_freq(avg_band);
330                self.peaks.push(Peak {
331                    input: avg_band,
332                    output: self.freq_to_bin(self.map_freq(avg_freq)),
333                });
334                
335                start = end;
336            } else {
337                start += 1;
338            }
339        }
340    }
341
342    /// Update output frequency mapping
343    fn update_output_map(&mut self) {
344        if self.peaks.is_empty() {
345            for b in 0..self.bands {
346                self.output_map[b] = PitchMapPoint {
347                    input_bin: T::from_usize(b).unwrap(),
348                    freq_grad: T::one(),
349                };
350            }
351            return;
352        }
353
354        let bottom_offset = self.peaks[0].input - self.peaks[0].output;
355        let end_bin = (self.peaks[0].output.ceil()).to_usize().unwrap_or(0).min(self.bands);
356        
357        for b in 0..end_bin {
358            self.output_map[b] = PitchMapPoint {
359                input_bin: T::from_usize(b).unwrap() + bottom_offset,
360                freq_grad: T::one(),
361            };
362        }
363
364        // Interpolate between peaks
365        for p in 1..self.peaks.len() {
366            let prev = &self.peaks[p - 1];
367            let next = &self.peaks[p];
368            
369            let range_scale = T::one() / (next.output - prev.output);
370            let out_offset = prev.input - prev.output;
371            let out_scale = next.input - next.output - prev.input + prev.output;
372            let grad_scale = out_scale * range_scale;
373            
374            let start_bin = (prev.output.ceil()).to_usize().unwrap_or(0);
375            let end_bin = (next.output.ceil()).to_usize().unwrap_or(0).min(self.bands);
376            
377            for b in start_bin..end_bin {
378                let r = (T::from_usize(b).unwrap() - prev.output) * range_scale;
379                let h = r * r * (T::from_f32(3.0).unwrap() - T::from_f32(2.0).unwrap() * r);
380                let out_b = T::from_usize(b).unwrap() + out_offset + h * out_scale;
381                
382                let grad_h = T::from_f32(6.0).unwrap() * r * (T::one() - r);
383                let grad_b = T::one() + grad_h * grad_scale;
384                
385                self.output_map[b] = PitchMapPoint {
386                    input_bin: out_b,
387                    freq_grad: grad_b,
388                };
389            }
390        }
391
392        let top_offset = self.peaks.last().unwrap().input - self.peaks.last().unwrap().output;
393        let start_bin = (self.peaks.last().unwrap().output).to_usize().unwrap_or(0);
394        
395        for b in start_bin..self.bands {
396            self.output_map[b] = PitchMapPoint {
397                input_bin: T::from_usize(b).unwrap() + top_offset,
398                freq_grad: T::one(),
399            };
400        }
401    }
402
403    /// Main processing function (simplified)
404    pub fn process<I, O>(&mut self, inputs: I, input_samples: usize, mut outputs: O, output_samples: usize)
405    where
406        I: AsRef<[Vec<T>]>,
407        O: AsMut<[Vec<T>]>,
408    {
409        let inputs = inputs.as_ref();
410        let outputs = outputs.as_mut();
411        
412        // Simplified processing - just copy input to output for now
413        for c in 0..self.channels.min(inputs.len()).min(outputs.len()) {
414            let input_channel = &inputs[c];
415            let output_channel = &mut outputs[c];
416            
417            for i in 0..output_samples.min(output_channel.len()) {
418                let input_idx = (i * input_samples / output_samples).min(input_channel.len().saturating_sub(1));
419                output_channel[i] = input_channel[input_idx];
420            }
421        }
422    }
423}
424
425impl<T: Float + FromPrimitive + NumCast + core::ops::AddAssign> Default for SignalsmithStretch<T> {
426    fn default() -> Self {
427        Self::new()
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within 1e-4.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-4,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_complex_operations() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        let c = a * b;
        assert!((c.re - (-5.0)).abs() < 1e-6);
        assert!((c.im - 10.0).abs() < 1e-6);

        let norm_sq = a.norm_sqr();
        assert!((norm_sq - 5.0).abs() < 1e-6);

        let conj = a.conj();
        assert!((conj.re - 1.0).abs() < 1e-6);
        assert!((conj.im - (-2.0)).abs() < 1e-6);
    }

    #[test]
    fn test_band_default() {
        let band: Band<f32> = Band::default();
        assert_eq!(band.input, Complex::new(0.0, 0.0));
        assert_eq!(band.prev_input, Complex::new(0.0, 0.0));
        assert_eq!(band.output, Complex::new(0.0, 0.0));
        assert_eq!(band.input_energy, 0.0);
    }

    #[test]
    fn test_prediction_make_output() {
        let mut pred = Prediction::<f32>::default();
        pred.energy = 4.0;
        pred.input = Complex::new(2.0, 0.0);

        // A usable phase keeps its direction and is rescaled to |out| = sqrt(energy).
        let output = pred.make_output(Complex::new(1.0, 1.0));
        assert_close(output.norm(), 2.0);
        assert_close(output.re, output.im);

        // A zero phase falls back to the direction of the stored input.
        let fallback = pred.make_output(Complex::new(0.0, 0.0));
        assert_close(fallback.re, 2.0);
        assert_close(fallback.im, 0.0);
    }

    #[test]
    fn test_cute_stretch_new() {
        let stretch = SignalsmithStretch::<f32>::new();
        assert_eq!(stretch.channels, 0);
        assert_eq!(stretch.bands, 0);
        assert_eq!(stretch.block_samples, 0);
        assert_eq!(stretch.freq_multiplier, 1.0);
        assert_eq!(stretch.freq_tonality_limit, 0.5);
        assert!(stretch.custom_freq_map.is_none());
    }

    #[test]
    fn test_cute_stretch_configure() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(2, 1024, 256, false);

        assert_eq!(stretch.channels, 2);
        assert_eq!(stretch.block_samples, 1024);
        assert_eq!(stretch.interval_samples, 256);
        assert_eq!(stretch.bands, 513);
        assert_eq!(stretch.channel_bands.len(), 2 * 513);
        assert_eq!(stretch.channel_predictions.len(), 2 * 513);
        assert_eq!(stretch.energy.len(), 513);
        assert_eq!(stretch.smoothed_energy.len(), 513);
        assert_eq!(stretch.output_map.len(), 513);
        assert_eq!(stretch.tmp_buffer.len(), 1024 + 256);
    }

    #[test]
    fn test_latency() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 1024, 256, false);
        assert_eq!(stretch.input_latency(), 512);
        assert_eq!(stretch.output_latency(), 512);

        stretch.configure(1, 1024, 256, true);
        assert_eq!(stretch.output_latency(), 512 + 256);
    }

    #[test]
    fn test_reset_clears_state() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false);
        stretch.channel_bands[0].input_energy = 3.0;
        stretch.did_seek = true;
        stretch.silence_counter = 7;
        stretch.prev_input_offset = 5;

        stretch.reset();

        assert_eq!(stretch.channel_bands[0].input_energy, 0.0);
        assert!(!stretch.did_seek);
        assert_eq!(stretch.silence_counter, 0);
        assert_eq!(stretch.prev_input_offset, -1);
    }

    #[test]
    fn test_transpose_factor() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.set_freq_map(|f| f);
        stretch.set_transpose_factor(2.0, 0.5);

        assert_eq!(stretch.freq_multiplier, 2.0);
        assert_close(stretch.freq_tonality_limit, 0.5 / 2.0_f32.sqrt());
        assert!(stretch.custom_freq_map.is_none());

        // A non-positive limit disables the tonality limit.
        stretch.set_transpose_factor(2.0, 0.0);
        assert_eq!(stretch.freq_tonality_limit, 1.0);
    }

    #[test]
    fn test_transpose_semitones() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.set_transpose_semitones(12.0, 0.5);
        assert_close(stretch.freq_multiplier, 2.0);

        stretch.set_transpose_semitones(-12.0, 0.5);
        assert_close(stretch.freq_multiplier, 0.5);
    }

    #[test]
    fn test_formant_factor() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.set_formant_factor(1.5, true);

        assert_eq!(stretch.formant_multiplier, 1.5);
        assert_close(stretch.inv_formant_multiplier, 1.0 / 1.5);
        assert!(stretch.formant_compensation);

        stretch.set_formant_semitones(12.0, false);
        assert_close(stretch.formant_multiplier, 2.0);
        assert_close(stretch.inv_formant_multiplier, 0.5);
        assert!(!stretch.formant_compensation);
    }

    #[test]
    fn test_find_peaks() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false); // 5 bands

        stretch.energy = vec![0.1, 0.5, 0.8, 0.3, 0.1];
        stretch.smoothed_energy = vec![0.2, 0.3, 0.4, 0.3, 0.2];

        stretch.find_peaks();

        // Bands 1..=2 form the only run above the threshold; centre = (0.5 + 1.6) / 1.3.
        assert_eq!(stretch.peaks.len(), 1);
        let peak = &stretch.peaks[0];
        assert_close(peak.input, 2.1 / 1.3);
        // With no transposition the peak maps onto itself.
        assert_close(peak.output, peak.input);
    }

    #[test]
    fn test_update_output_map() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false); // 5 bands

        stretch.peaks.push(Peak { input: 2.0, output: 3.0 });
        stretch.peaks.push(Peak { input: 5.0, output: 6.0 });

        stretch.update_output_map();

        // Both peaks share an offset of -1, so the whole map is a -1 shift.
        assert_eq!(stretch.output_map.len(), stretch.bands);
        let expected = [-1.0, 0.0, 1.0, 2.0, 3.0];
        for (point, &want) in stretch.output_map.iter().zip(expected.iter()) {
            assert_close(point.input_bin, want);
            assert_close(point.freq_grad, 1.0);
        }
    }

    #[test]
    fn test_update_output_map_without_peaks_is_identity() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false);

        stretch.update_output_map();

        for (b, point) in stretch.output_map.iter().enumerate() {
            assert_close(point.input_bin, b as f32);
            assert_close(point.freq_grad, 1.0);
        }
    }

    #[test]
    fn test_process_simple() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(2, 1024, 256, false);

        let inputs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        let mut outputs = vec![vec![0.0; 6], vec![0.0; 6]];

        stretch.process(&inputs, 4, &mut outputs, 6);

        // Nearest-neighbour mapping: index = i * 4 / 6.
        assert_eq!(outputs[0], vec![1.0, 1.0, 2.0, 3.0, 3.0, 4.0]);
        assert_eq!(outputs[1], vec![5.0, 5.0, 6.0, 7.0, 7.0, 8.0]);
    }
}