1#![allow(unused_imports)]
7
8use num_traits::{Float, FromPrimitive, NumCast};
9use num_complex::Complex;
10use core::marker::PhantomData;
11
12use crate::stft::STFT;
13
14#[derive(Clone, Debug)]
16pub struct Band<T: Float> {
17 pub input: Complex<T>,
18 pub prev_input: Complex<T>,
19 pub output: Complex<T>,
20 pub input_energy: T,
21}
22
23impl<T: Float> Default for Band<T> {
24 fn default() -> Self {
25 Self {
26 input: Complex::new(T::zero(), T::zero()),
27 prev_input: Complex::new(T::zero(), T::zero()),
28 output: Complex::new(T::zero(), T::zero()),
29 input_energy: T::zero(),
30 }
31 }
32}
33
34#[derive(Clone, Debug)]
36pub struct Peak<T: Float> {
37 pub input: T,
38 pub output: T,
39}
40
41#[derive(Clone, Debug)]
43pub struct PitchMapPoint<T: Float> {
44 pub input_bin: T,
45 pub freq_grad: T,
46}
47
48#[derive(Clone, Debug)]
50pub struct Prediction<T: Float> {
51 pub energy: T,
52 pub input: Complex<T>,
53}
54
55impl<T: Float> Default for Prediction<T> {
56 fn default() -> Self {
57 Self {
58 energy: T::zero(),
59 input: Complex::new(T::zero(), T::zero()),
60 }
61 }
62}
63
64impl<T: Float> Prediction<T> {
65 pub fn make_output(&self, phase: Complex<T>) -> Complex<T> {
66 let phase_norm = phase.norm_sqr();
67 let phase = if phase_norm <= T::epsilon() {
68 self.input
69 } else {
70 phase
71 };
72 let phase_norm = phase.norm_sqr() + T::epsilon();
73 phase * Complex::new((self.energy / phase_norm).sqrt(), T::zero())
74 }
75}
76
77pub struct SignalsmithStretch<T: Float> {
79 split_computation: bool,
81 channels: usize,
82 bands: usize,
83
84 block_samples: usize,
86 interval_samples: usize,
87 tmp_buffer: Vec<T>,
88
89 analysis_stft: STFT<T>,
91 synthesis_stft: STFT<T>,
92
93 channel_bands: Vec<Band<T>>,
95 peaks: Vec<Peak<T>>,
96 energy: Vec<T>,
97 smoothed_energy: Vec<T>,
98 output_map: Vec<PitchMapPoint<T>>,
99 channel_predictions: Vec<Prediction<T>>,
100
101 prev_input_offset: i32,
103 silence_counter: usize,
104 did_seek: bool,
105
106 freq_multiplier: T,
108 freq_tonality_limit: T,
109 custom_freq_map: Option<Box<dyn Fn(T) -> T + Send + Sync + 'static>>,
110
111 formant_multiplier: T,
113 inv_formant_multiplier: T,
114 formant_compensation: bool,
115 formant_base_freq: T,
116}
117
118impl<T: Float + FromPrimitive + NumCast + core::ops::AddAssign> SignalsmithStretch<T> {
119 pub fn new() -> Self {
121 Self {
122 split_computation: false,
123 channels: 0,
124 bands: 0,
125 block_samples: 0,
126 interval_samples: 0,
127 tmp_buffer: Vec::new(),
128 analysis_stft: STFT::new(false),
129 synthesis_stft: STFT::new(false),
130 channel_bands: Vec::new(),
131 peaks: Vec::new(),
132 energy: Vec::new(),
133 smoothed_energy: Vec::new(),
134 output_map: Vec::new(),
135 channel_predictions: Vec::new(),
136 prev_input_offset: -1,
137 silence_counter: 0,
138 did_seek: false,
139 freq_multiplier: T::one(),
140 freq_tonality_limit: T::from_f32(0.5).unwrap(),
141 custom_freq_map: None,
142 formant_multiplier: T::one(),
143 inv_formant_multiplier: T::one(),
144 formant_compensation: false,
145 formant_base_freq: T::zero()
146 }
147 }
148
149 pub fn block_samples(&self) -> usize {
151 self.block_samples
152 }
153
154 pub fn interval_samples(&self) -> usize {
156 self.interval_samples
157 }
158
159 pub fn input_latency(&self) -> usize {
161 self.block_samples / 2
162 }
163
164 pub fn output_latency(&self) -> usize {
166 self.block_samples / 2 + if self.split_computation { self.interval_samples } else { 0 }
167 }
168
169 pub fn reset(&mut self) {
171 self.prev_input_offset = -1;
172 for band in &mut self.channel_bands {
173 *band = Band::default();
174 }
175 self.silence_counter = 0;
176 self.did_seek = false;
177 }
178
179 pub fn preset_default(&mut self, n_channels: usize, sample_rate: T, split_computation: bool) {
181 let block_samples = (sample_rate * T::from_f32(0.12).unwrap()).to_usize().unwrap_or(1024);
182 let interval_samples = (sample_rate * T::from_f32(0.03).unwrap()).to_usize().unwrap_or(256);
183 self.configure(n_channels, block_samples, interval_samples, split_computation);
184 }
185
186 pub fn preset_cheaper(&mut self, n_channels: usize, sample_rate: T, split_computation: bool) {
188 let block_samples = (sample_rate * T::from_f32(0.1).unwrap()).to_usize().unwrap_or(1024);
189 let interval_samples = (sample_rate * T::from_f32(0.04).unwrap()).to_usize().unwrap_or(256);
190 self.configure(n_channels, block_samples, interval_samples, split_computation);
191 }
192
193 pub fn configure(&mut self, n_channels: usize, block_samples: usize, interval_samples: usize, split_computation: bool) {
195 self.split_computation = split_computation;
196 self.channels = n_channels;
197 self.block_samples = block_samples;
198 self.interval_samples = interval_samples;
199
200 self.bands = block_samples / 2 + 1;
201
202 self.analysis_stft.configure(n_channels, n_channels, block_samples, block_samples, interval_samples);
204 self.synthesis_stft.configure(n_channels, n_channels, block_samples, block_samples, interval_samples);
205
206 self.tmp_buffer.resize(block_samples + interval_samples, T::zero());
207 self.channel_bands.resize(self.bands * self.channels, Band::default());
208
209 self.peaks.clear();
210 self.peaks.reserve(self.bands / 2);
211 self.energy.resize(self.bands, T::zero());
212 self.smoothed_energy.resize(self.bands, T::zero());
213 self.output_map.resize(self.bands, PitchMapPoint { input_bin: T::zero(), freq_grad: T::one() });
214 self.channel_predictions.resize(self.channels * self.bands, Prediction::default());
215
216 self.reset();
217 }
218
219 pub fn set_transpose_factor(&mut self, multiplier: T, tonality_limit: T) {
221 self.freq_multiplier = multiplier;
222 if tonality_limit > T::zero() {
223 self.freq_tonality_limit = tonality_limit / multiplier.sqrt();
224 } else {
225 self.freq_tonality_limit = T::one();
226 }
227 self.custom_freq_map = None;
228 }
229
230 pub fn set_transpose_semitones(&mut self, semitones: T, tonality_limit: T) {
232 let multiplier = T::from_f32(2.0).unwrap().powf(semitones / T::from_f32(12.0).unwrap());
233 self.set_transpose_factor(multiplier, tonality_limit);
234 }
235
236 pub fn set_freq_map<F>(&mut self, input_to_output: F)
238 where
239 F: Fn(T) -> T + 'static + Send + Sync,
240 {
241 self.custom_freq_map = Some(Box::new(input_to_output));
242 }
243
244 pub fn set_formant_factor(&mut self, multiplier: T, compensate_pitch: bool) {
246 self.formant_multiplier = multiplier;
247 self.inv_formant_multiplier = T::one() / multiplier;
248 self.formant_compensation = compensate_pitch;
249 }
250
251 pub fn set_formant_semitones(&mut self, semitones: T, compensate_pitch: bool) {
253 let multiplier = T::from_f32(2.0).unwrap().powf(semitones / T::from_f32(12.0).unwrap());
254 self.set_formant_factor(multiplier, compensate_pitch);
255 }
256
257 pub fn set_formant_base(&mut self, base_freq: T) {
259 self.formant_base_freq = base_freq;
260 }
261
262 fn bin_to_freq(&self, bin: T) -> T {
264 bin * T::from_f32(22050.0).unwrap() / T::from_usize(self.bands).unwrap()
265 }
266
267 fn freq_to_bin(&self, freq: T) -> T {
269 freq * T::from_usize(self.bands).unwrap() / T::from_f32(22050.0).unwrap()
270 }
271
272 fn map_freq(&self, freq: T) -> T {
274 if let Some(ref custom_map) = self.custom_freq_map {
275 custom_map(freq)
276 } else if freq > self.freq_tonality_limit {
277 freq + (self.freq_multiplier - T::one()) * self.freq_tonality_limit
278 } else {
279 freq * self.freq_multiplier
280 }
281 }
282
283 fn bands_for_channel(&self, channel: usize) -> &[Band<T>] {
285 let start = channel * self.bands;
286 let end = start + self.bands;
287 &self.channel_bands[start..end]
288 }
289
290 fn bands_for_channel_mut(&mut self, channel: usize) -> &mut [Band<T>] {
292 let start = channel * self.bands;
293 let end = start + self.bands;
294 &mut self.channel_bands[start..end]
295 }
296
297 fn predictions_for_channel(&self, channel: usize) -> &[Prediction<T>] {
299 let start = channel * self.bands;
300 let end = start + self.bands;
301 &self.channel_predictions[start..end]
302 }
303
304 fn predictions_for_channel_mut(&mut self, channel: usize) -> &mut [Prediction<T>] {
306 let start = channel * self.bands;
307 let end = start + self.bands;
308 &mut self.channel_predictions[start..end]
309 }
310
311 fn find_peaks(&mut self) {
313 self.peaks.clear();
314
315 let mut start = 0;
316 while start < self.bands {
317 if self.energy[start] > self.smoothed_energy[start] {
318 let mut end = start;
319 let mut band_sum = T::zero();
320 let mut energy_sum = T::zero();
321
322 while end < self.bands && self.energy[end] > self.smoothed_energy[end] {
323 band_sum = band_sum + T::from_usize(end).unwrap() * self.energy[end];
324 energy_sum = energy_sum + self.energy[end];
325 end += 1;
326 }
327
328 let avg_band = band_sum / energy_sum;
329 let avg_freq = self.bin_to_freq(avg_band);
330 self.peaks.push(Peak {
331 input: avg_band,
332 output: self.freq_to_bin(self.map_freq(avg_freq)),
333 });
334
335 start = end;
336 } else {
337 start += 1;
338 }
339 }
340 }
341
342 fn update_output_map(&mut self) {
344 if self.peaks.is_empty() {
345 for b in 0..self.bands {
346 self.output_map[b] = PitchMapPoint {
347 input_bin: T::from_usize(b).unwrap(),
348 freq_grad: T::one(),
349 };
350 }
351 return;
352 }
353
354 let bottom_offset = self.peaks[0].input - self.peaks[0].output;
355 let end_bin = (self.peaks[0].output.ceil()).to_usize().unwrap_or(0).min(self.bands);
356
357 for b in 0..end_bin {
358 self.output_map[b] = PitchMapPoint {
359 input_bin: T::from_usize(b).unwrap() + bottom_offset,
360 freq_grad: T::one(),
361 };
362 }
363
364 for p in 1..self.peaks.len() {
366 let prev = &self.peaks[p - 1];
367 let next = &self.peaks[p];
368
369 let range_scale = T::one() / (next.output - prev.output);
370 let out_offset = prev.input - prev.output;
371 let out_scale = next.input - next.output - prev.input + prev.output;
372 let grad_scale = out_scale * range_scale;
373
374 let start_bin = (prev.output.ceil()).to_usize().unwrap_or(0);
375 let end_bin = (next.output.ceil()).to_usize().unwrap_or(0).min(self.bands);
376
377 for b in start_bin..end_bin {
378 let r = (T::from_usize(b).unwrap() - prev.output) * range_scale;
379 let h = r * r * (T::from_f32(3.0).unwrap() - T::from_f32(2.0).unwrap() * r);
380 let out_b = T::from_usize(b).unwrap() + out_offset + h * out_scale;
381
382 let grad_h = T::from_f32(6.0).unwrap() * r * (T::one() - r);
383 let grad_b = T::one() + grad_h * grad_scale;
384
385 self.output_map[b] = PitchMapPoint {
386 input_bin: out_b,
387 freq_grad: grad_b,
388 };
389 }
390 }
391
392 let top_offset = self.peaks.last().unwrap().input - self.peaks.last().unwrap().output;
393 let start_bin = (self.peaks.last().unwrap().output).to_usize().unwrap_or(0);
394
395 for b in start_bin..self.bands {
396 self.output_map[b] = PitchMapPoint {
397 input_bin: T::from_usize(b).unwrap() + top_offset,
398 freq_grad: T::one(),
399 };
400 }
401 }
402
403 pub fn process<I, O>(&mut self, inputs: I, input_samples: usize, mut outputs: O, output_samples: usize)
405 where
406 I: AsRef<[Vec<T>]>,
407 O: AsMut<[Vec<T>]>,
408 {
409 let inputs = inputs.as_ref();
410 let outputs = outputs.as_mut();
411
412 for c in 0..self.channels.min(inputs.len()).min(outputs.len()) {
414 let input_channel = &inputs[c];
415 let output_channel = &mut outputs[c];
416
417 for i in 0..output_samples.min(output_channel.len()) {
418 let input_idx = (i * input_samples / output_samples).min(input_channel.len().saturating_sub(1));
419 output_channel[i] = input_channel[input_idx];
420 }
421 }
422 }
423}
424
425impl<T: Float + FromPrimitive + NumCast + core::ops::AddAssign> Default for SignalsmithStretch<T> {
426 fn default() -> Self {
427 Self::new()
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for single-precision comparisons in this module.
    const TOLERANCE: f32 = 1e-4;

    /// Sanity check of the complex arithmetic the module relies on.
    #[test]
    fn test_complex_operations() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        let c = a * b;
        assert!((c.re - (-5.0)).abs() < 1e-6);
        assert!((c.im - 10.0).abs() < 1e-6);

        let norm_sq = a.norm_sqr();
        assert!((norm_sq - 5.0).abs() < 1e-6);

        let conj = a.conj();
        assert!((conj.re - 1.0).abs() < 1e-6);
        assert!((conj.im - (-2.0)).abs() < 1e-6);
    }

    /// A default band is entirely zero.
    #[test]
    fn test_band_default() {
        let band: Band<f32> = Band::default();
        assert_eq!(band.input.re, 0.0);
        assert_eq!(band.input.im, 0.0);
        assert_eq!(band.prev_input.re, 0.0);
        assert_eq!(band.prev_input.im, 0.0);
        assert_eq!(band.output.re, 0.0);
        assert_eq!(band.output.im, 0.0);
        assert_eq!(band.input_energy, 0.0);
    }

    /// The output follows the given phase and carries the predicted energy;
    /// a zero phase falls back to the prediction's own input.
    #[test]
    fn test_prediction_make_output() {
        let pred = Prediction::<f32> {
            energy: 4.0,
            input: Complex::new(2.0, 0.0),
        };

        // Usable phase: magnitude becomes sqrt(energy) = 2, direction is kept.
        let output = pred.make_output(Complex::new(1.0, 1.0));
        assert!((output.norm() - 2.0).abs() < TOLERANCE);
        assert!((output.re - output.im).abs() < TOLERANCE);

        // Zero phase: the input (2, 0) is used as the direction.
        let fallback = pred.make_output(Complex::new(0.0, 0.0));
        assert!((fallback.re - 2.0).abs() < TOLERANCE);
        assert!(fallback.im.abs() < TOLERANCE);
    }

    /// A new instance is unconfigured.
    #[test]
    fn test_cute_stretch_new() {
        let stretch = SignalsmithStretch::<f32>::new();
        assert_eq!(stretch.channels, 0);
        assert_eq!(stretch.bands, 0);
        assert_eq!(stretch.block_samples, 0);
        assert_eq!(stretch.interval_samples, 0);
        assert!(stretch.channel_bands.is_empty());
    }

    /// `configure` stores the layout and sizes every per-band buffer.
    #[test]
    fn test_cute_stretch_configure() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(2, 1024, 256, false);

        assert_eq!(stretch.channels, 2);
        assert_eq!(stretch.block_samples, 1024);
        assert_eq!(stretch.interval_samples, 256);
        assert_eq!(stretch.bands, 513);
        assert_eq!(stretch.channel_bands.len(), 2 * 513);
        assert_eq!(stretch.channel_predictions.len(), 2 * 513);
        assert_eq!(stretch.energy.len(), 513);
        assert_eq!(stretch.smoothed_energy.len(), 513);
        assert_eq!(stretch.output_map.len(), 513);
        assert_eq!(stretch.tmp_buffer.len(), 1024 + 256);
    }

    /// A positive tonality limit is divided by sqrt(multiplier); a
    /// non-positive one selects a limit of 1.
    #[test]
    fn test_transpose_factor() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.set_transpose_factor(2.0, 0.5);

        assert_eq!(stretch.freq_multiplier, 2.0);
        assert!((stretch.freq_tonality_limit - (0.5 / 2.0_f32.sqrt())).abs() < 1e-6);

        stretch.set_transpose_factor(2.0, 0.0);
        assert_eq!(stretch.freq_tonality_limit, 1.0);
    }

    /// Twelve semitones are one octave.
    #[test]
    fn test_transpose_semitones() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.set_transpose_semitones(12.0, 0.5);

        assert!((stretch.freq_multiplier - 2.0).abs() < 1e-6);
    }

    /// The formant setter stores the multiplier, its reciprocal and the flag.
    #[test]
    fn test_formant_factor() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.set_formant_factor(1.5, true);

        assert_eq!(stretch.formant_multiplier, 1.5);
        assert!((stretch.inv_formant_multiplier - (1.0 / 1.5)).abs() < 1e-6);
        assert!(stretch.formant_compensation);
    }

    /// Bins 1 and 2 exceed the smoothed energy and form a single peak at
    /// their energy-weighted centre; with the default multiplier of 1 the
    /// peak maps onto itself.
    #[test]
    fn test_find_peaks() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false);
        assert_eq!(stretch.bands, 5);

        // One value per band.
        stretch.energy = vec![0.1, 0.5, 0.8, 0.3, 0.1];
        stretch.smoothed_energy = vec![0.2, 0.3, 0.4, 0.3, 0.2];

        stretch.find_peaks();

        assert_eq!(stretch.peaks.len(), 1);
        let expected_centre = (1.0 * 0.5 + 2.0 * 0.8) / (0.5 + 0.8);
        let peak = &stretch.peaks[0];
        assert!((peak.input - expected_centre).abs() < TOLERANCE);
        assert!((peak.output - peak.input).abs() < 1e-3);
    }

    /// Two peaks that are both shifted up by one bin give a map that is a
    /// constant offset of -1 with unit gradient everywhere.
    #[test]
    fn test_update_output_map() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false);

        stretch.peaks.push(Peak { input: 2.0, output: 3.0 });
        stretch.peaks.push(Peak { input: 5.0, output: 6.0 });

        stretch.update_output_map();

        assert_eq!(stretch.output_map.len(), stretch.bands);
        for (b, point) in stretch.output_map.iter().enumerate() {
            assert!((point.input_bin - (b as f32 - 1.0)).abs() < TOLERANCE);
            assert!((point.freq_grad - 1.0).abs() < TOLERANCE);
        }
    }

    /// Without peaks the map is the identity.
    #[test]
    fn test_update_output_map_without_peaks() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(1, 8, 4, false);

        stretch.update_output_map();

        for (b, point) in stretch.output_map.iter().enumerate() {
            assert_eq!(point.input_bin, b as f32);
            assert_eq!(point.freq_grad, 1.0);
        }
    }

    /// Four input samples stretched to six outputs repeat samples according
    /// to `i * 4 / 6`.
    #[test]
    fn test_process_simple() {
        let mut stretch = SignalsmithStretch::<f32>::new();
        stretch.configure(2, 1024, 256, false);

        let inputs = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![5.0, 6.0, 7.0, 8.0],
        ];
        let mut outputs = vec![
            vec![0.0; 6],
            vec![0.0; 6],
        ];

        stretch.process(&inputs, 4, &mut outputs, 6);

        assert_eq!(outputs[0], vec![1.0, 1.0, 2.0, 3.0, 3.0, 4.0]);
        assert_eq!(outputs[1], vec![5.0, 5.0, 6.0, 7.0, 7.0, 8.0]);
    }
}