Skip to main content

nice_plug_core/util/
stft.rs

1//! Utilities for buffering audio, likely used as part of a short-term Fourier transform.
2
3use std::cmp;
4
5use crate::buffer::{Block, Buffer};
6
/// Some buffer that can be used with the [`StftHelper`].
///
/// [`StftHelper`] reads from implementors with
/// [`get_sample_unchecked()`][Self::get_sample_unchecked()] for every channel index below
/// [`num_channels()`][Self::num_channels()] and every sample index below
/// [`num_samples()`][Self::num_samples()], without performing any bounds checks of its own.
/// Implementations must therefore report dimensions for which every such index pair is in bounds.
pub trait StftInput {
    /// The number of samples in this input. Every channel must contain at least this many samples.
    fn num_samples(&self) -> usize;

    /// The number of channels in this input.
    fn num_channels(&self) -> usize;

    /// Index the buffer without any bounds checks.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `channel < self.num_channels()` and
    /// `sample_idx < self.num_samples()`. Implementations are allowed to skip bounds checks based
    /// on this, so passing out-of-bounds indices may cause undefined behavior.
    unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32;
}
18
19/// The same as [`StftInput`], but with support for writing results back to the buffer
20pub trait StftInputMut: StftInput {
21    /// Get a mutable reference to a sample in the buffer without any bounds checks.
22    unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32;
23}
24
/// Process the input buffer in equal sized blocks, running a callback on each block to transform
/// the block and then writing back the results from the previous block to the buffer. This
/// introduces latency equal to the size of the block.
///
/// Additional inputs can be processed by setting the `NUM_SIDECHAIN_INPUTS` constant. These buffers
/// will not be written to, so they are purely used for analysis. These sidechain inputs will have
/// the same number of channels as the main input.
///
/// TODO: Better name?
/// TODO: `process_analyze_only()` covers the pure analysis use case (e.g. showing spectrums in a
///       GUI), but it shares its ring buffers with the processing functions. Figure out whether a
///       dedicated analysis-only type would be cleaner.
pub struct StftHelper<const NUM_SIDECHAIN_INPUTS: usize = 0> {
    // These ring buffers store the input samples and the already processed output produced by
    // adding overlapping windows. Whenever we reach a new overlapping window, we'll write the
    // already calculated outputs to the main buffer passed to the process function and then process
    // a new block. Every ring buffer's length is the current block size.
    main_input_ring_buffers: Vec<Vec<f32>>,
    main_output_ring_buffers: Vec<Vec<f32>>,
    /// One set of per-channel input ring buffers for every sidechain input.
    sidechain_ring_buffers: [Vec<Vec<f32>>; NUM_SIDECHAIN_INPUTS],

    /// Results from the ring buffers are copied to this scratch buffer before being passed to the
    /// plugin. Needed to handle overlap. Its length is the block size plus `padding`.
    scratch_buffer: Vec<f32>,
    /// One buffer per channel. If padding is used, these accumulate the values that previous
    /// iterations left in the padding area of `scratch_buffer`
    /// (`scratch_buffer[(scratch_buffer.len() - padding)..]`), which are then added back to the
    /// output in the next iteration(s).
    padding_buffers: Vec<Vec<f32>>,

    /// The current write position in our ring buffers. Whenever this reaches a multiple of the
    /// window interval, a new window is processed.
    current_pos: usize,
    /// The current amount of zero padding appended to every yielded block. This is also the
    /// length of every buffer in `padding_buffers`.
    padding: usize,
}
59
/// Marker struct for the version without sidechaining.
///
/// `process_overlap_add()` passes one reference to this per sidechain slot. Its [`StftInput`]
/// implementation reports zero channels and zero samples, and its `get_sample_unchecked()` ignores
/// the indices and always returns `0.0`. That makes it harmless to read at any index.
struct NoSidechain;
62
63impl StftInput for Buffer<'_> {
64    #[inline]
65    fn num_samples(&self) -> usize {
66        self.samples()
67    }
68
69    #[inline]
70    fn num_channels(&self) -> usize {
71        self.channels()
72    }
73
74    #[inline]
75    unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
76        unsafe {
77            *self
78                .as_slice_immutable()
79                .get_unchecked(channel)
80                .get_unchecked(sample_idx)
81        }
82    }
83}
84
85impl StftInputMut for Buffer<'_> {
86    #[inline]
87    unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
88        unsafe {
89            self.as_slice()
90                .get_unchecked_mut(channel)
91                .get_unchecked_mut(sample_idx)
92        }
93    }
94}
95
96impl StftInput for Block<'_, '_> {
97    #[inline]
98    fn num_samples(&self) -> usize {
99        self.samples()
100    }
101
102    #[inline]
103    fn num_channels(&self) -> usize {
104        self.channels()
105    }
106
107    #[inline]
108    unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
109        unsafe { *self.get_unchecked(channel).get_unchecked(sample_idx) }
110    }
111}
112
113impl StftInputMut for Block<'_, '_> {
114    #[inline]
115    unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
116        unsafe {
117            self.get_unchecked_mut(channel)
118                .get_unchecked_mut(sample_idx)
119        }
120    }
121}
122
123impl StftInput for [&[f32]] {
124    #[inline]
125    fn num_samples(&self) -> usize {
126        if self.is_empty() { 0 } else { self[0].len() }
127    }
128
129    #[inline]
130    fn num_channels(&self) -> usize {
131        self.len()
132    }
133
134    #[inline]
135    unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
136        unsafe { *self.get_unchecked(channel).get_unchecked(sample_idx) }
137    }
138}
139
140impl StftInput for [&mut [f32]] {
141    #[inline]
142    fn num_samples(&self) -> usize {
143        if self.is_empty() { 0 } else { self[0].len() }
144    }
145
146    #[inline]
147    fn num_channels(&self) -> usize {
148        self.len()
149    }
150
151    #[inline]
152    unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
153        unsafe { *self.get_unchecked(channel).get_unchecked(sample_idx) }
154    }
155}
156
157impl StftInputMut for [&mut [f32]] {
158    #[inline]
159    unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
160        unsafe {
161            self.get_unchecked_mut(channel)
162                .get_unchecked_mut(sample_idx)
163        }
164    }
165}
166
167impl StftInput for NoSidechain {
168    fn num_samples(&self) -> usize {
169        0
170    }
171
172    fn num_channels(&self) -> usize {
173        0
174    }
175
176    unsafe fn get_sample_unchecked(&self, _channel: usize, _sample_idx: usize) -> f32 {
177        0.0
178    }
179}
180
181impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
182    /// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the
183    /// given maximum block size. When the option is set, then every yielded sample buffer will have
184    /// this many zero samples appended at the end of the block. Call
185    /// [`set_block_size()`][Self::set_block_size()] afterwards if you do not need the full capacity
186    /// upfront. If the padding option is non zero, then all yielded blocks will have that many
187    /// zeroes added to the end of it and the results stored in the padding area will be added to
188    /// the outputs in the next iteration(s). You may also change how much padding is added with
189    /// [`set_padding()`][Self::set_padding()].
190    ///
191    /// # Panics
192    ///
193    /// Panics if `num_channels == 0 || max_block_size == 0`.
194    pub fn new(num_channels: usize, max_block_size: usize, max_padding: usize) -> Self {
195        assert_ne!(num_channels, 0);
196        assert_ne!(max_block_size, 0);
197
198        Self {
199            main_input_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
200            main_output_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
201            // Kinda hacky way to initialize an array of non-copy types
202            sidechain_ring_buffers: [(); NUM_SIDECHAIN_INPUTS]
203                .map(|_| vec![vec![0.0; max_block_size]; num_channels]),
204
205            // When padding is used this scratch buffer will have a bunch of zeroes added to it
206            // after copying a block of audio to it
207            scratch_buffer: vec![0.0; max_block_size + max_padding],
208            padding_buffers: vec![vec![0.0; max_padding]; num_channels],
209
210            current_pos: 0,
211            padding: max_padding,
212        }
213    }
214
215    /// Change the current block size. This will clear the buffers, causing the next block to output
216    /// silence.
217    ///
218    /// # Panics
219    ///
220    /// Will panic if `block_size > max_block_size`.
221    pub fn set_block_size(&mut self, block_size: usize) {
222        assert!(block_size <= self.main_input_ring_buffers[0].capacity());
223
224        self.update_buffers(block_size);
225    }
226
227    /// Change the current padding amount. This will clear the buffers, causing the next block to
228    /// output silence.
229    ///
230    /// # Panics
231    ///
232    /// Will panic if `padding > max_padding`.
233    pub fn set_padding(&mut self, padding: usize) {
234        assert!(padding <= self.padding_buffers[0].capacity());
235
236        self.padding = padding;
237        self.update_buffers(self.main_input_ring_buffers[0].len());
238    }
239
240    /// The number of channels this `StftHelper` was configured for
241    pub fn num_channels(&self) -> usize {
242        self.main_input_ring_buffers.len()
243    }
244
245    /// The maximum block size supported by this instance.
246    pub fn max_block_size(&self) -> usize {
247        self.main_input_ring_buffers.capacity()
248    }
249
250    /// The maximum amount of padding supported by this instance.
251    pub fn max_padding(&self) -> usize {
252        self.padding_buffers[0].capacity()
253    }
254
255    /// The amount of latency introduced when processing audio through this [`StftHelper`].
256    pub fn latency_samples(&self) -> u32 {
257        self.main_input_ring_buffers[0].len() as u32
258    }
259
260    /// Process the audio in `main_buffer` in small overlapping blocks, adding up the results for
261    /// the main buffer so they can eventually be written back to the host one block later. This
262    /// means that this function will introduce one block of latency. This can be compensated by
263    /// calling [`InitContext::set_latency()`][`crate::context::init::InitContext::set_latency_samples()`]
264    /// in your plugin's initialization function.
265    ///
266    /// If a padding value was specified in [`new()`][Self::new()], then the yielded blocks will
267    /// have that many zeroes appended at the end of them. The padding values will be added to the
268    /// next block before `process_cb()` is called.
269    ///
270    /// Since there are a couple different ways to do it, any window functions needs to be applied
271    /// in the callbacks. Check the [`nice_plug_core::util::window`][crate::util::window] module for more
272    /// information.
273    ///
274    /// For efficiency's sake this function will reuse the same vector for all calls to
275    /// `process_cb`. This means you can only access a single channel's worth of windowed data at a
276    /// time. The arguments to that function are `process_cb(channel_idx, real_fft_buffer)`.
277    /// `real_fft_buffer` will be a slice of `block_size` real valued samples. This can be passed
278    /// directly to an FFT algorithm.
279    ///
280    /// # Panics
281    ///
282    /// Panics if `main_buffer` or the buffers in `sidechain_buffers` do not have the same number of
283    /// channels as this [`StftHelper`], or if the sidechain buffers do not contain the same number of
284    /// samples as the main buffer.
285    ///
286    /// TODO: Add more useful ways to do STFT and other buffered operations. I just went with this
287    ///       approach because it's what I needed myself, but generic combinators like this could
288    ///       also be useful for other operations.
289    pub fn process_overlap_add<M, F>(
290        &mut self,
291        main_buffer: &mut M,
292        overlap_times: usize,
293        mut process_cb: F,
294    ) where
295        M: StftInputMut,
296        F: FnMut(usize, &mut [f32]),
297    {
298        self.process_overlap_add_sidechain(
299            main_buffer,
300            [&NoSidechain; NUM_SIDECHAIN_INPUTS],
301            overlap_times,
302            |channel_idx, sidechain_idx, real_fft_scratch_buffer| {
303                if sidechain_idx.is_none() {
304                    process_cb(channel_idx, real_fft_scratch_buffer);
305                }
306            },
307        );
308    }
309
310    /// The same as [`process_overlap_add()`][Self::process_overlap_add()], but with sidechain
311    /// inputs that can be analyzed before the main input gets processed.
312    ///
313    /// The extra argument in the process function is `sidechain_buffer_idx`, which will be `None`
314    /// for the main buffer.
315    pub fn process_overlap_add_sidechain<M, S, F>(
316        &mut self,
317        main_buffer: &mut M,
318        sidechain_buffers: [&S; NUM_SIDECHAIN_INPUTS],
319        overlap_times: usize,
320        mut process_cb: F,
321    ) where
322        M: StftInputMut,
323        S: StftInput,
324        F: FnMut(usize, Option<usize>, &mut [f32]),
325    {
326        assert_eq!(
327            main_buffer.num_channels(),
328            self.main_input_ring_buffers.len()
329        );
330        assert!(overlap_times > 0);
331
332        // We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying
333        // already processed samples from `main_ring_buffers` in into `main_buffer`
334        let main_buffer_len = main_buffer.num_samples();
335        let num_channels = main_buffer.num_channels();
336        let block_size = self.main_input_ring_buffers[0].len();
337        let window_interval = (block_size / overlap_times) as i32;
338        let mut already_processed_samples = 0;
339        while already_processed_samples < main_buffer_len {
340            let remaining_samples = main_buffer_len - already_processed_samples;
341            let samples_until_next_window = ((window_interval - self.current_pos as i32 - 1)
342                .rem_euclid(window_interval)
343                + 1) as usize;
344            let samples_to_process = samples_until_next_window.min(remaining_samples);
345
346            // Copy the input from `main_buffer` to the ring buffer while copying last block's
347            // result from the buffer to `main_buffer`
348            // TODO: This might be able to be sped up a bit with SIMD
349
350            // For the main buffer
351            for sample_offset in 0..samples_to_process {
352                for channel_idx in 0..num_channels {
353                    let sample = unsafe {
354                        main_buffer.get_sample_unchecked_mut(
355                            channel_idx,
356                            already_processed_samples + sample_offset,
357                        )
358                    };
359                    let input_ring_buffer_sample = unsafe {
360                        self.main_input_ring_buffers
361                            .get_unchecked_mut(channel_idx)
362                            .get_unchecked_mut(self.current_pos + sample_offset)
363                    };
364                    let output_ring_buffer_sample = unsafe {
365                        self.main_output_ring_buffers
366                            .get_unchecked_mut(channel_idx)
367                            .get_unchecked_mut(self.current_pos + sample_offset)
368                    };
369                    *input_ring_buffer_sample = *sample;
370                    *sample = *output_ring_buffer_sample;
371                    // Very important, or else we'll overlap-add ourselves into a feedback hell
372                    *output_ring_buffer_sample = 0.0;
373                }
374            }
375
376            // And for the sidechain buffers we only need to copy the inputs
377            for (sidechain_buffer, sidechain_ring_buffers) in sidechain_buffers
378                .iter()
379                .zip(self.sidechain_ring_buffers.iter_mut())
380            {
381                for sample_offset in 0..samples_to_process {
382                    for channel_idx in 0..num_channels {
383                        let sample = unsafe {
384                            sidechain_buffer.get_sample_unchecked(
385                                channel_idx,
386                                already_processed_samples + sample_offset,
387                            )
388                        };
389                        let ring_buffer_sample = unsafe {
390                            sidechain_ring_buffers
391                                .get_unchecked_mut(channel_idx)
392                                .get_unchecked_mut(self.current_pos + sample_offset)
393                        };
394                        *ring_buffer_sample = sample;
395                    }
396                }
397            }
398
399            already_processed_samples += samples_to_process;
400            self.current_pos = (self.current_pos + samples_to_process) % block_size;
401
402            // At this point we either have `already_processed_samples == main_buffer_len`, or
403            // `self.current_pos % window_interval == 0`. If it's the latter, then we can process a
404            // new block.
405            if samples_to_process == samples_until_next_window {
406                // Because we're processing in smaller windows, the input ring buffers sadly does
407                // not always contain the full contiguous range we're interested in because they map
408                // wrap around. Because premade FFT algorithms typically can't handle this, we'll
409                // start with copying the wrapped ranges from our ring buffers to the scratch
410                // buffer. Then we apply the windowing function and this it along to
411                for (sidechain_idx, sidechain_ring_buffers) in
412                    self.sidechain_ring_buffers.iter().enumerate()
413                {
414                    for (channel_idx, sidechain_ring_buffer) in
415                        sidechain_ring_buffers.iter().enumerate()
416                    {
417                        copy_ring_to_scratch_buffer(
418                            &mut self.scratch_buffer,
419                            self.current_pos,
420                            sidechain_ring_buffer,
421                        );
422                        if self.padding > 0 {
423                            self.scratch_buffer[block_size..].fill(0.0);
424                        }
425
426                        process_cb(channel_idx, Some(sidechain_idx), &mut self.scratch_buffer);
427                    }
428                }
429
430                for (channel_idx, ((input_ring_buffer, output_ring_buffer), padding_buffer)) in self
431                    .main_input_ring_buffers
432                    .iter()
433                    .zip(self.main_output_ring_buffers.iter_mut())
434                    .zip(self.padding_buffers.iter_mut())
435                    .enumerate()
436                {
437                    copy_ring_to_scratch_buffer(
438                        &mut self.scratch_buffer,
439                        self.current_pos,
440                        input_ring_buffer,
441                    );
442                    if self.padding > 0 {
443                        self.scratch_buffer[block_size..].fill(0.0);
444                    }
445
446                    process_cb(channel_idx, None, &mut self.scratch_buffer);
447
448                    // Add the padding from the last iteration (for this channel) to the scratch
449                    // buffer before it is copied to the output ring buffer. In case the padding is
450                    // longer than the block size, then this will cause everything else to be
451                    // shifted to the left so it can be added in the iteration after this.
452                    if self.padding > 0 {
453                        let padding_to_copy = cmp::min(self.padding, block_size);
454                        for (scratch_sample, padding_sample) in self.scratch_buffer
455                            [..padding_to_copy]
456                            .iter_mut()
457                            .zip(&mut padding_buffer[..padding_to_copy])
458                        {
459                            *scratch_sample += *padding_sample;
460                        }
461
462                        // Any remaining padding tail should be moved towards the start of the
463                        // buffer
464                        padding_buffer.copy_within(padding_to_copy.., 0);
465
466                        // And we obviously don't want this to feedback
467                        padding_buffer[self.padding - padding_to_copy..].fill(0.0);
468                    }
469
470                    // The actual overlap-add part of the equation
471                    add_scratch_to_ring_buffer(
472                        &self.scratch_buffer,
473                        self.current_pos,
474                        output_ring_buffer,
475                    );
476
477                    // And the data from the padding area should be saved so it can be added to next
478                    // iteration's scratch buffer. Like mentioned above, the padding can be larger
479                    // than the block size so we also need to do overlap-add here.
480                    if self.padding > 0 {
481                        for (padding_sample, scratch_sample) in padding_buffer
482                            .iter_mut()
483                            .zip(&mut self.scratch_buffer[block_size..])
484                        {
485                            *padding_sample += *scratch_sample;
486                        }
487                    }
488                }
489            }
490        }
491    }
492
493    /// Similar to [`process_overlap_add()`][Self::process_overlap_add()], but without the inverse
494    /// STFT part. `buffer` will only ever be read from. This can be useful for providing FFT data
495    /// for a spectrum analyzer in a plugin GUI. These is still a delay to the analysis equal to the
496    /// block size.
497    pub fn process_analyze_only<B, F>(
498        &mut self,
499        buffer: &B,
500        overlap_times: usize,
501        mut analyze_cb: F,
502    ) where
503        B: StftInput,
504        F: FnMut(usize, &mut [f32]),
505    {
506        assert_eq!(buffer.num_channels(), self.main_input_ring_buffers.len());
507        assert!(overlap_times > 0);
508
509        // See `process_overlap_add_sidechain` for an annotated version
510        let main_buffer_len = buffer.num_samples();
511        let num_channels = buffer.num_channels();
512        let block_size = self.main_input_ring_buffers[0].len();
513        let window_interval = (block_size / overlap_times) as i32;
514        let mut already_processed_samples = 0;
515        while already_processed_samples < main_buffer_len {
516            let remaining_samples = main_buffer_len - already_processed_samples;
517            let samples_until_next_window = ((window_interval - self.current_pos as i32 - 1)
518                .rem_euclid(window_interval)
519                + 1) as usize;
520            let samples_to_process = samples_until_next_window.min(remaining_samples);
521
522            for sample_offset in 0..samples_to_process {
523                for channel_idx in 0..num_channels {
524                    let sample = unsafe {
525                        buffer.get_sample_unchecked(
526                            channel_idx,
527                            already_processed_samples + sample_offset,
528                        )
529                    };
530                    let input_ring_buffer_sample = unsafe {
531                        self.main_input_ring_buffers
532                            .get_unchecked_mut(channel_idx)
533                            .get_unchecked_mut(self.current_pos + sample_offset)
534                    };
535                    *input_ring_buffer_sample = sample;
536                }
537            }
538
539            already_processed_samples += samples_to_process;
540            self.current_pos = (self.current_pos + samples_to_process) % block_size;
541
542            if samples_to_process == samples_until_next_window {
543                for (channel_idx, input_ring_buffer) in
544                    self.main_input_ring_buffers.iter().enumerate()
545                {
546                    copy_ring_to_scratch_buffer(
547                        &mut self.scratch_buffer,
548                        self.current_pos,
549                        input_ring_buffer,
550                    );
551                    if self.padding > 0 {
552                        self.scratch_buffer[block_size..].fill(0.0);
553                    }
554
555                    analyze_cb(channel_idx, &mut self.scratch_buffer);
556                }
557            }
558        }
559    }
560
561    fn update_buffers(&mut self, block_size: usize) {
562        for main_ring_buffer in &mut self.main_input_ring_buffers {
563            main_ring_buffer.resize(block_size, 0.0);
564            main_ring_buffer.fill(0.0);
565        }
566        for main_ring_buffer in &mut self.main_output_ring_buffers {
567            main_ring_buffer.resize(block_size, 0.0);
568            main_ring_buffer.fill(0.0);
569        }
570        for sidechain_ring_buffers in &mut self.sidechain_ring_buffers {
571            for sidechain_ring_buffer in sidechain_ring_buffers {
572                sidechain_ring_buffer.resize(block_size, 0.0);
573                sidechain_ring_buffer.fill(0.0);
574            }
575        }
576        self.scratch_buffer.resize(block_size + self.padding, 0.0);
577        self.scratch_buffer.fill(0.0);
578
579        for padding_buffer in &mut self.padding_buffers {
580            // In case this changed since the last call, like in `set_padding()`
581            padding_buffer.resize(self.padding, 0.0);
582            padding_buffer.fill(0.0);
583        }
584
585        self.current_pos = 0;
586    }
587}
588
/// Copy the contents of `ring_buffer` into the start of `scratch_buffer`. Copying begins with the
/// sample at `current_pos` and wraps around to the start of the ring buffer. This is a free
/// function rather than a method because the ring buffers are borrowed from the same `self` whose
/// scratch buffer is being written to.
#[inline]
fn copy_ring_to_scratch_buffer(
    scratch_buffer: &mut [f32],
    current_pos: usize,
    ring_buffer: &[f32],
) {
    // `from_pos` runs until the end of the ring buffer, `wrapped` is the part before `current_pos`
    let (wrapped, from_pos) = ring_buffer.split_at(current_pos);
    let (from_pos_dst, remainder) = scratch_buffer.split_at_mut(from_pos.len());
    from_pos_dst.copy_from_slice(from_pos);
    remainder[..wrapped.len()].copy_from_slice(wrapped);
}
603
/// Overlap-add the first `ring_buffer.len()` samples of `scratch_buffer` into `ring_buffer`. The
/// first scratch sample lands at `current_pos`, and later samples wrap around to the start of the
/// ring buffer. Samples read from this ring buffer for the host's outputs must be cleared
/// afterwards, or they would keep feeding back into later windows.
#[inline]
fn add_scratch_to_ring_buffer(scratch_buffer: &[f32], current_pos: usize, ring_buffer: &mut [f32]) {
    // TODO: This could also use some SIMD
    let (wrapped, from_pos) = ring_buffer.split_at_mut(current_pos);
    let (scratch_for_tail, scratch_rest) = scratch_buffer.split_at(from_pos.len());

    for (ring_sample, scratch_sample) in from_pos.iter_mut().zip(scratch_for_tail) {
        *ring_sample += *scratch_sample;
    }

    let scratch_for_wrapped = &scratch_rest[..wrapped.len()];
    for (ring_sample, scratch_sample) in wrapped.iter_mut().zip(scratch_for_wrapped) {
        *ring_sample += *scratch_sample;
    }
}