audio_processor_analysis/transient_detection/stft/
mod.rs

1// Augmented Audio: Audio libraries and applications
2// Copyright (c) 2022 Pedro Tacla Yamada
3//
4// The MIT License (MIT)
5//
6// Permission is hereby granted, free of charge, to any person obtaining a copy
7// of this software and associated documentation files (the "Software"), to deal
8// in the Software without restriction, including without limitation the rights
9// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10// copies of the Software, and to permit persons to whom the Software is
11// furnished to do so, subject to the following conditions:
12//
13// The above copyright notice and this permission notice shall be included in
14// all copies or substantial portions of the Software.
15//
16// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22// THE SOFTWARE.
23use rustfft::num_complex::Complex;
24
25use audio_processor_traits::simple_processor::MonoAudioProcessor;
26use audio_processor_traits::{AudioBuffer, AudioContext};
27use dynamic_thresholds::{DynamicThresholds, DynamicThresholdsParams};
28use power_change::{PowerOfChangeFrames, PowerOfChangeParams};
29
30use crate::fft_processor::{FftDirection, FftProcessor, FftProcessorOptions};
31use crate::window_functions::WindowFunctionType;
32
33mod dynamic_thresholds;
34mod frame_deltas;
35mod power_change;
36
37pub mod markers;
38#[cfg(any(test, feature = "visualization"))]
39pub mod visualization;
40
/// Tuning parameters for [`find_transients`].
///
/// Symbols in the field documentation (`v`, `τ`, `β`, `δ`, `N`) refer to the paper linked from
/// [`find_transients`]. The default values quoted below are the ones set by the `Default`
/// implementation of this type.
#[derive(Debug, Clone)]
pub struct IterativeTransientDetectionParams {
    /// Size of the FFT windows, defaults to 2048; at 44.1kHz each frame is roughly 46ms
    /// (2048 / 44100 seconds)
    pub fft_size: usize,
    /// If 0.75 is provided, 3/4 of the windows will overlap. Defaults to 3/4
    pub fft_overlap_ratio: f32,
    /// `v` in the paper (equation `5`)
    ///
    /// Defaults to 3 frequency bins or roughly 60Hz at 44.1kHz sample rate
    pub power_of_change_spectral_spread: usize,
    /// `τ` in the paper (equation 7)
    ///
    /// * When calculating dynamic thresholds, controls how many neighbouring time frames are
    ///   considered
    /// * For example, if `threshold_time_spread_factor` is 2.0, a frequency bin and its
    ///   `spectral_spread` neighbours will have to be 2.0 times the average of the `time_spread`
    ///   time frames' rate of change for this bin
    ///
    /// Defaults to 2
    pub threshold_time_spread: usize,
    /// `β` - `threshold_time_spread_factor` (equation 7)
    ///
    /// * Internal multiplier of dynamic thresholds
    /// * This number affects by what factor a frequency bin needs to change in relation to what
    ///   it has changed in neighbouring frames to be considered a transient
    /// * Higher numbers mean sensitivity is decreased
    ///
    /// Defaults to 2.0
    pub threshold_time_spread_factor: f32,
    /// How many bins should change for a frame to be considered a transient
    ///
    /// Defaults to 1/4 of the default `fft_size` (512 bins)
    pub frequency_bin_change_threshold: usize,
    /// `δ` - `iteration_magnitude_factor` (equation 10)
    ///
    /// * What factor of the magnitude is collected onto the output per iteration
    ///
    /// Defaults to 0.05
    pub iteration_magnitude_factor: f32,
    /// `N` - `iteration_count` (algorithm 1)
    ///
    /// Defaults to 20
    pub iteration_count: usize,
}
85
86impl Default for IterativeTransientDetectionParams {
87    fn default() -> Self {
88        let fft_size = 2048;
89        let frequency_bin_change_threshold = 2048 / 4;
90        Self {
91            fft_size,
92            fft_overlap_ratio: 0.75,
93            power_of_change_spectral_spread: 3,
94            threshold_time_spread: 2,
95            threshold_time_spread_factor: 2.0,
96            iteration_magnitude_factor: 0.05,
97            iteration_count: 20,
98            frequency_bin_change_threshold,
99        }
100    }
101}
102
103/// Implements iterative STFT transient detection for polyphonic signals. Output is a monophonic
104/// audio track of the transient signal. Not real-time safe.
105///
106/// Reference:
107/// * <https://www.researchgate.net/profile/Balaji-Thoshkahna/publication/220723752_A_Transient_Detection_Algorithm_for_Audio_Using_Iterative_Analysis_of_STFT/links/0deec52e6331412aed000000/A-Transient-Detection-Algorithm-for-Audio-Using-Iterative-Analysis-of-STFT.pdf>
108///
109/// Inputs to the algorithm:
110/// * `fft_size` - Size of the FFT window
111/// * `fft_overlap` - Amount of overlap between windows
112/// * `v` - `power_of_change_spectral_spread` (equation 5)
113/// * `β` - `threshold_time_spread_factor` (equation 7)
114///   * Internal multiplier of dynamic thresholds
115///   * This nº affects by what factor a frequency bin's rate of change needs to be higher than its
116///     time-domain neighbours
117///   * Higher nºs means sensitivity is decreased
118/// * `λThr` - `frequency_bin_change_threshold` (equation 10)
119///   * If this amount of frequency bins have changed, this frame will be considered a transient
120/// * `δ` - `iteration_magnitude_factor` (equation 10)
121///   * What factor of the magnitude is collected onto the output per iteration
122/// * `N` - `iteration_count` (algorithm 1)
123///
124/// The algorithm is as follows:
125/// * Perform FFT with overlapping windows at 3/4's ratio (e.g. one 40ms window every 30ms)
126/// * Calculate `M(frame, bin)` **magnitudes for each frame/bin**
127/// * Let `P(frame, bin)` be the output `transient_magnitude_frames`
128///   * These are the transient magnitude frames
129///   * e.g. Magnitudes of the transients, per frequency bin, over time
130/// * for iteration in 0..`N`
131///   * For each frame/bin, calculate **power of change** value `F(frame, bin)`
132///     * First we calculate `T-(frame, bin)` and `T+(frame, bin)`, respectively the deltas in
133///       magnitude with the previous and next frames respectively
134///     * `F(frame, bin)` (**power of change** represents how much its magnitude is higher compared
135///       with next and previous frames, if it's higher than its next/previous frames (0.0
136///       if its not higher than neighbouring time-domain frames)
137///     * For each `bin` this `power_of_change` is summed with its `v` (`power_of_change_spectral_spread`)
138///       neighbour frequency bins
139///     * This is for of 'peak detection' in some way, it's finding frames higher than their
140///       time-domain peers and quantifying how much they're larger than them
141///   * Calculate `dynamic_thresholds` `λ(frame, bin)`
142///     * For this, on every frequency bin, the threshold is the average **power of change** of the
143///       `l` (`threshold_time_spread`) neighbouring time-domain frames, multiplied by a magic
144///       constant `β` (`threshold_time_spread_factor`)
145///     * A frequency bin's threshold is defined by how much its neighbour frequency bins have
146///       changed in this frame* (change being quantified by `F`)
147///   * Calculate `Γ(frame, bin)` (`have_bins_changed`), by flipping a flag to 1 or 0 depending on
148///     whether **power of change**  `F(frame, bin)` is higher than its **dynamic threshold**
149///     `λ(frame, bin)`
150///   * Calculate `ΣΓ` `num_changed_bins`, by counting the number of frequency bins that have
151///     changed in this frame
152///     * Simply sum the above for each frame
153///   * If `ΣΓ(frame)` `num_changed_bins` is higher than `λThr` - `frequency_bin_change_threshold`
154///     * Update `P(frame, bin)` - `transient_magnitude_frames` adding `X(frame, bin)` times `δ`
155///       `iteration_magnitude_factor` onto it
156///     * Subtract `(1 - δ) * X(frame, bin)` from `X(frame, bin)`
157/// * At the end of `N` iterations, perform the inverse fourier transform over the each polar
158///   complex nº frame using magnitudes in `transient_magnitude_frames` and using phase from the
159///   input FFT result
160/// * There may now be extra filtering / smoothing steps to extract data or audio, but the output
161///   should be the transient signal
162pub fn find_transients(
163    params: IterativeTransientDetectionParams,
164    data: &mut AudioBuffer<f32>,
165) -> Vec<f32> {
166    let IterativeTransientDetectionParams {
167        fft_size,
168        fft_overlap_ratio,
169        power_of_change_spectral_spread,
170        threshold_time_spread,
171        threshold_time_spread_factor,
172        frequency_bin_change_threshold,
173        iteration_magnitude_factor,
174        iteration_count,
175    } = params;
176
177    log::info!("Performing FFT...");
178    let fft_frames = get_fft_frames(fft_size, fft_overlap_ratio, data);
179
180    log::info!("Finding base function values");
181    let mut magnitude_frames: Vec<Vec<f32>> = get_magnitudes(&fft_frames);
182    let mut transient_magnitude_frames: Vec<Vec<f32>> =
183        initialize_result_transient_magnitude_frames(&mut magnitude_frames);
184
185    for _iteration in 0..iteration_count {
186        let t_results = frame_deltas::calculate_deltas(&magnitude_frames);
187        let f_frames = power_change::calculate_power_of_change(
188            PowerOfChangeParams {
189                spectral_spread_bins: power_of_change_spectral_spread,
190            },
191            &t_results,
192        );
193        let threshold_frames = dynamic_thresholds::calculate_dynamic_thresholds(
194            DynamicThresholdsParams {
195                threshold_time_spread,
196                threshold_time_spread_factor,
197            },
198            &f_frames,
199        );
200
201        let num_changed_bins_frames: Vec<usize> =
202            count_changed_bins_per_frame(f_frames, threshold_frames);
203
204        update_output_and_magnitudes(
205            iteration_magnitude_factor,
206            frequency_bin_change_threshold,
207            num_changed_bins_frames,
208            &mut magnitude_frames,
209            &mut transient_magnitude_frames,
210        );
211    }
212
213    generate_output_frames(
214        fft_size,
215        fft_overlap_ratio,
216        data,
217        &fft_frames,
218        &mut transient_magnitude_frames,
219    )
220}
221
/// Last step of an iteration: for every frame that counts as a transient, collect
/// `iteration_magnitude_factor * M(frame, bin)` onto the output and subtract
/// `(1.0 - iteration_magnitude_factor) * M(frame, bin)` from the magnitude frames.
///
/// A frame counts as a transient when its entry in `num_changed_bins_frames` is greater than or
/// equal to `frequency_bin_change_threshold`. Frames below the threshold are left untouched in
/// both `magnitude_frames` and `transient_magnitude_frames`.
///
/// Frames and bins are paired positionally between `transient_magnitude_frames` and
/// `magnitude_frames`.
///
/// # Panics
///
/// Panics if `num_changed_bins_frames` has fewer entries than there are paired frames.
///
/// NOTE(review): the referenced paper appears to keep `(1 - δ) * X` as the residual magnitude,
/// whereas this leaves `δ * X` behind — confirm which one is intended before changing it.
fn update_output_and_magnitudes(
    iteration_magnitude_factor: f32,
    frequency_bin_change_threshold: usize,
    num_changed_bins_frames: Vec<usize>,
    magnitude_frames: &mut [Vec<f32>],
    transient_magnitude_frames: &mut [Vec<f32>],
) {
    let paired_frames = transient_magnitude_frames
        .iter_mut()
        .zip(magnitude_frames.iter_mut())
        .enumerate();

    for (frame_index, (transient_frame, magnitude_frame)) in paired_frames {
        // The transient decision is per frame, so it is evaluated once rather than once per bin.
        if num_changed_bins_frames[frame_index] < frequency_bin_change_threshold {
            continue;
        }

        for (transient_magnitude, magnitude) in
            transient_frame.iter_mut().zip(magnitude_frame.iter_mut())
        {
            // The output must be updated first: it uses the magnitude before it is reduced.
            *transient_magnitude += iteration_magnitude_factor * *magnitude;
            *magnitude -= (1.0 - iteration_magnitude_factor) * *magnitude;
        }
    }
}
242
243/// Equations 8 and 9
244fn count_changed_bins_per_frame(
245    f_frames: PowerOfChangeFrames,
246    threshold_frames: DynamicThresholds,
247) -> Vec<usize> {
248    threshold_frames
249        .buffer
250        .iter()
251        .zip(f_frames.buffer)
252        .map(|(threshold_frame, f_frame)| {
253            // Γ
254            threshold_frame
255                .iter()
256                .zip(f_frame)
257                .map(|(threshold, f)| usize::from(f > *threshold))
258                // end Γ
259                .sum()
260        })
261        .collect()
262}
263
264/// Perform inverse FFT over spectrogram frames
265fn generate_output_frames(
266    fft_size: usize,
267    fft_overlap_ratio: f32,
268    data: &mut AudioBuffer<f32>,
269    fft_frames: &[Vec<Complex<f32>>],
270    transient_magnitude_frames: &mut [Vec<f32>],
271) -> Vec<f32> {
272    let mut planner = rustfft::FftPlanner::new();
273    let fft = planner.plan_fft(fft_size, FftDirection::Inverse);
274    let scratch_size = fft.get_inplace_scratch_len();
275    let mut scratch = Vec::with_capacity(scratch_size);
276    scratch.resize(scratch_size, 0.0.into());
277
278    let mut output = vec![];
279    output.resize(data.num_samples(), 0.0);
280
281    let mut cursor = 0;
282
283    for i in 0..fft_frames.len() {
284        let frame = &fft_frames[i];
285        let mut buffer: Vec<Complex<f32>> = frame
286            .iter()
287            .zip(&transient_magnitude_frames[i])
288            .map(|(input_signal_complex, transient_magnitude)| {
289                Complex::from_polar(*transient_magnitude, input_signal_complex.arg())
290            })
291            .collect();
292
293        fft.process_with_scratch(&mut buffer, &mut scratch);
294        for j in 0..buffer.len() {
295            if cursor + j < output.len() {
296                output[cursor + j] += buffer[j].re;
297            }
298        }
299
300        cursor += (frame.len() as f32 * (1.0 - fft_overlap_ratio)) as usize;
301    }
302
303    let maximum_output = output
304        .iter()
305        .map(|f| f.abs())
306        .max_by(|f1, f2| f1.partial_cmp(f2).unwrap_or(std::cmp::Ordering::Equal))
307        .unwrap_or(0.0);
308    for sample in &mut output {
309        if sample.abs() > maximum_output * 0.05 {
310            *sample /= maximum_output;
311        } else {
312            *sample = 0.0;
313        }
314    }
315
316    // TODO Is this the best way to do latency compensation?
317    // We have to skip the first frame because of the buffering delay.
318    // * Should we wait only one overlap window?
319    output.iter().skip(fft_size).cloned().collect()
320}
321
/// Create the zero-filled output frames: one frame per entry of `magnitudes`, each with the same
/// number of bins as the matching magnitude frame. `magnitudes` itself is not modified.
fn initialize_result_transient_magnitude_frames(magnitudes: &mut [Vec<f32>]) -> Vec<Vec<f32>> {
    let mut zeroed_frames = Vec::with_capacity(magnitudes.len());
    for frame in magnitudes.iter() {
        zeroed_frames.push(vec![0.0; frame.len()]);
    }
    zeroed_frames
}
328
329fn get_magnitudes(fft_frames: &[Vec<Complex<f32>>]) -> Vec<Vec<f32>> {
330    fft_frames
331        .iter()
332        .map(|frame| {
333            frame
334                .iter()
335                .map(|frequency_bin| frequency_bin.norm())
336                .collect()
337        })
338        .collect()
339}
340
341fn get_fft_frames(
342    fft_size: usize,
343    fft_overlap_ratio: f32,
344    data: &mut AudioBuffer<f32>,
345) -> Vec<Vec<Complex<f32>>> {
346    let mut fft = FftProcessor::new(FftProcessorOptions {
347        size: fft_size,
348        direction: FftDirection::Forward,
349        overlap_ratio: fft_overlap_ratio,
350        // On the paper this is a [Blackman-Harris](https://en.wikipedia.org/wiki/Window_function)
351        // window
352        window_function: WindowFunctionType::Hann,
353    });
354    let mut fft_frames = vec![];
355
356    let mut context = AudioContext::default();
357    for sample_num in 0..data.num_samples() {
358        let mut input_sample = 0.0;
359        for channel in 0..data.num_channels() {
360            input_sample += data.get(channel, sample_num);
361        }
362
363        let output_sample = fft.m_process(&mut context, input_sample);
364
365        for channel in 0..data.num_channels() {
366            data.set(channel, sample_num, output_sample);
367        }
368
369        if fft.has_changed() {
370            fft_frames.push(fft.buffer().clone());
371        }
372    }
373
374    fft_frames
375}
376
// Integration-style test: runs the detector over an audio file and writes a plot plus a
// `.transients.wav` file next to the input.
#[cfg(test)]
mod test {
    use audio_processor_testing_helpers::relative_path;

    use audio_processor_file::{AudioFileProcessor, OutputAudioFileProcessor};
    use audio_processor_traits::{AudioProcessor, AudioProcessorSettings};

    use super::*;

    /// Read an input file for testing.
    ///
    /// Returns a single-channel buffer holding the sum of all input channels, truncated to at
    /// most 10 seconds at the default settings' sample rate. Panics if the file cannot be opened.
    fn read_input_file(input_file_path: &str) -> AudioBuffer<f32> {
        log::info!("Reading input file input_file={}", input_file_path);
        let settings = AudioProcessorSettings::default();
        let mut input = AudioFileProcessor::from_path(
            audio_garbage_collector::handle(),
            settings,
            input_file_path,
        )
        .unwrap();
        let mut context = AudioContext::from(settings);

        input.prepare(&mut context);
        let input_buffer = input.buffer();
        let mut buffer = AudioBuffer::empty();

        // We read at most 10s of audio & mono it.
        let max_len = (settings.sample_rate() * 10.0) as usize;
        // The output length is taken from the first channel, capped at `max_len`.
        buffer.resize(1, input_buffer[0].len().min(max_len));
        for channel in input_buffer.iter() {
            for (sample_index, sample) in channel.iter().enumerate().take(max_len) {
                // Accumulate every input channel onto channel 0 of the output.
                buffer.set(0, sample_index, *sample + buffer.get(0, sample_index));
            }
        }
        buffer
    }

    /// Runs `find_transients` over the drum loop fixture, checks the output length, draws the
    /// result and writes the transients to a `.transients.wav` file.
    #[test]
    fn test_transient_detector() {
        use visualization::draw;

        wisual_logger::init_from_env();

        let output_path = relative_path!("./src/transient_detection/stft.png");

        // let input_path = relative_path!("../../../../input-files/C3-loop.mp3");
        let input_path = relative_path!("./hiphop-drum-loop.mp3");
        let transients_file_path = format!("{}.transients.wav", input_path);
        let mut input = read_input_file(&input_path);
        // Copy the input samples before `find_transients` takes the buffer mutably.
        let frames: Vec<f32> = input.channel(0).iter().cloned().collect();
        // Absolute peak of the input, used below to scale the output back up.
        let max_input = frames
            .iter()
            .map(|f| f.abs())
            .max_by(|f1, f2| f1.partial_cmp(f2).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap();

        let transients = find_transients(
            IterativeTransientDetectionParams {
                iteration_count: 2,
                ..IterativeTransientDetectionParams::default()
            },
            &mut input,
        );
        // The output drops the first `fft_size` samples, so it is shorter than the input by
        // exactly that amount.
        assert_eq!(
            frames.len() - IterativeTransientDetectionParams::default().fft_size,
            transients.len()
        );
        draw(&output_path, &frames, &transients);

        let settings = AudioProcessorSettings {
            input_channels: 1,
            output_channels: 1,
            ..AudioProcessorSettings::default()
        };
        let mut output_processor =
            OutputAudioFileProcessor::from_path(settings, &transients_file_path);
        output_processor.prepare(settings);
        // match input signal
        let transients: Vec<f32> = transients.iter().map(|f| f * max_input).collect();
        let mut buffer = AudioBuffer::from_interleaved(1, &transients);
        output_processor
            .process(&mut buffer)
            .expect("Failed to write transients to file");
    }
}