audio_processor_analysis/transient_detection/stft/
mod.rs1use rustfft::num_complex::Complex;
24
25use audio_processor_traits::simple_processor::MonoAudioProcessor;
26use audio_processor_traits::{AudioBuffer, AudioContext};
27use dynamic_thresholds::{DynamicThresholds, DynamicThresholdsParams};
28use power_change::{PowerOfChangeFrames, PowerOfChangeParams};
29
30use crate::fft_processor::{FftDirection, FftProcessor, FftProcessorOptions};
31use crate::window_functions::WindowFunctionType;
32
33mod dynamic_thresholds;
34mod frame_deltas;
35mod power_change;
36
37pub mod markers;
38#[cfg(any(test, feature = "visualization"))]
39pub mod visualization;
40
/// Tuning parameters for [`find_transients`].
///
/// Start from `IterativeTransientDetectionParams::default()` and override individual fields
/// with struct-update syntax, e.g. `IterativeTransientDetectionParams { iteration_count: 2,
/// ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct IterativeTransientDetectionParams {
    /// Size of each FFT frame, in samples. The same number of leading samples is dropped
    /// from the signal returned by [`find_transients`].
    pub fft_size: usize,
    /// Fraction of each frame that overlaps the next one. Resynthesis advances by
    /// `frame_len * (1.0 - fft_overlap_ratio)` samples per frame.
    pub fft_overlap_ratio: f32,
    /// Passed to the power-of-change computation as `spectral_spread_bins` (presumably the
    /// number of neighbouring frequency bins considered — see `power_change`).
    pub power_of_change_spectral_spread: usize,
    /// Passed to the dynamic-threshold computation as its time spread (presumably a number
    /// of neighbouring frames — see `dynamic_thresholds`).
    pub threshold_time_spread: usize,
    /// Passed to the dynamic-threshold computation as its time-spread factor (see
    /// `dynamic_thresholds` for how it scales the threshold).
    pub threshold_time_spread_factor: f32,
    /// Minimum number of bins in a frame whose power-of-change must strictly exceed their
    /// dynamic threshold for that frame to be treated as transient in an iteration.
    pub frequency_bin_change_threshold: usize,
    /// Fraction of a flagged frame's remaining magnitude that is added to the transient
    /// output in each iteration.
    pub iteration_magnitude_factor: f32,
    /// Number of detection/extraction iterations to run.
    pub iteration_count: usize,
}
85
86impl Default for IterativeTransientDetectionParams {
87 fn default() -> Self {
88 let fft_size = 2048;
89 let frequency_bin_change_threshold = 2048 / 4;
90 Self {
91 fft_size,
92 fft_overlap_ratio: 0.75,
93 power_of_change_spectral_spread: 3,
94 threshold_time_spread: 2,
95 threshold_time_spread_factor: 2.0,
96 iteration_magnitude_factor: 0.05,
97 iteration_count: 20,
98 frequency_bin_change_threshold,
99 }
100 }
101}
102
103pub fn find_transients(
163 params: IterativeTransientDetectionParams,
164 data: &mut AudioBuffer<f32>,
165) -> Vec<f32> {
166 let IterativeTransientDetectionParams {
167 fft_size,
168 fft_overlap_ratio,
169 power_of_change_spectral_spread,
170 threshold_time_spread,
171 threshold_time_spread_factor,
172 frequency_bin_change_threshold,
173 iteration_magnitude_factor,
174 iteration_count,
175 } = params;
176
177 log::info!("Performing FFT...");
178 let fft_frames = get_fft_frames(fft_size, fft_overlap_ratio, data);
179
180 log::info!("Finding base function values");
181 let mut magnitude_frames: Vec<Vec<f32>> = get_magnitudes(&fft_frames);
182 let mut transient_magnitude_frames: Vec<Vec<f32>> =
183 initialize_result_transient_magnitude_frames(&mut magnitude_frames);
184
185 for _iteration in 0..iteration_count {
186 let t_results = frame_deltas::calculate_deltas(&magnitude_frames);
187 let f_frames = power_change::calculate_power_of_change(
188 PowerOfChangeParams {
189 spectral_spread_bins: power_of_change_spectral_spread,
190 },
191 &t_results,
192 );
193 let threshold_frames = dynamic_thresholds::calculate_dynamic_thresholds(
194 DynamicThresholdsParams {
195 threshold_time_spread,
196 threshold_time_spread_factor,
197 },
198 &f_frames,
199 );
200
201 let num_changed_bins_frames: Vec<usize> =
202 count_changed_bins_per_frame(f_frames, threshold_frames);
203
204 update_output_and_magnitudes(
205 iteration_magnitude_factor,
206 frequency_bin_change_threshold,
207 num_changed_bins_frames,
208 &mut magnitude_frames,
209 &mut transient_magnitude_frames,
210 );
211 }
212
213 generate_output_frames(
214 fft_size,
215 fft_overlap_ratio,
216 data,
217 &fft_frames,
218 &mut transient_magnitude_frames,
219 )
220}
221
/// Transfers part of the remaining magnitude of every flagged frame into the transient output.
///
/// Frame `i` is flagged when `num_changed_bins_frames[i] >= frequency_bin_change_threshold`.
/// For every bin of a flagged frame, `iteration_magnitude_factor * magnitude` is added to the
/// transient magnitude and the *same* amount is subtracted from `magnitude_frames`. Transient
/// and remaining magnitude therefore always sum to the bin's original magnitude. (Previously
/// the remaining magnitude was reduced by `(1 - factor) * magnitude`, so most of it vanished.)
/// Unflagged frames are left untouched.
///
/// Frames and bins are paired positionally; if the inputs differ in length, the extra
/// trailing entries are ignored instead of causing an out-of-bounds panic.
fn update_output_and_magnitudes(
    iteration_magnitude_factor: f32,
    frequency_bin_change_threshold: usize,
    num_changed_bins_frames: Vec<usize>,
    magnitude_frames: &mut [Vec<f32>],
    transient_magnitude_frames: &mut [Vec<f32>],
) {
    let frames = transient_magnitude_frames
        .iter_mut()
        .zip(magnitude_frames.iter_mut())
        .zip(num_changed_bins_frames);

    for ((transient_frame, magnitude_frame), changed_bins) in frames {
        // The flag depends only on the frame, so decide it once rather than per bin.
        if changed_bins < frequency_bin_change_threshold {
            continue;
        }

        for (transient, magnitude) in transient_frame.iter_mut().zip(magnitude_frame.iter_mut()) {
            let moved = iteration_magnitude_factor * *magnitude;
            *transient += moved;
            *magnitude -= moved;
        }
    }
}
242
243fn count_changed_bins_per_frame(
245 f_frames: PowerOfChangeFrames,
246 threshold_frames: DynamicThresholds,
247) -> Vec<usize> {
248 threshold_frames
249 .buffer
250 .iter()
251 .zip(f_frames.buffer)
252 .map(|(threshold_frame, f_frame)| {
253 threshold_frame
255 .iter()
256 .zip(f_frame)
257 .map(|(threshold, f)| usize::from(f > *threshold))
258 .sum()
260 })
261 .collect()
262}
263
264fn generate_output_frames(
266 fft_size: usize,
267 fft_overlap_ratio: f32,
268 data: &mut AudioBuffer<f32>,
269 fft_frames: &[Vec<Complex<f32>>],
270 transient_magnitude_frames: &mut [Vec<f32>],
271) -> Vec<f32> {
272 let mut planner = rustfft::FftPlanner::new();
273 let fft = planner.plan_fft(fft_size, FftDirection::Inverse);
274 let scratch_size = fft.get_inplace_scratch_len();
275 let mut scratch = Vec::with_capacity(scratch_size);
276 scratch.resize(scratch_size, 0.0.into());
277
278 let mut output = vec![];
279 output.resize(data.num_samples(), 0.0);
280
281 let mut cursor = 0;
282
283 for i in 0..fft_frames.len() {
284 let frame = &fft_frames[i];
285 let mut buffer: Vec<Complex<f32>> = frame
286 .iter()
287 .zip(&transient_magnitude_frames[i])
288 .map(|(input_signal_complex, transient_magnitude)| {
289 Complex::from_polar(*transient_magnitude, input_signal_complex.arg())
290 })
291 .collect();
292
293 fft.process_with_scratch(&mut buffer, &mut scratch);
294 for j in 0..buffer.len() {
295 if cursor + j < output.len() {
296 output[cursor + j] += buffer[j].re;
297 }
298 }
299
300 cursor += (frame.len() as f32 * (1.0 - fft_overlap_ratio)) as usize;
301 }
302
303 let maximum_output = output
304 .iter()
305 .map(|f| f.abs())
306 .max_by(|f1, f2| f1.partial_cmp(f2).unwrap_or(std::cmp::Ordering::Equal))
307 .unwrap_or(0.0);
308 for sample in &mut output {
309 if sample.abs() > maximum_output * 0.05 {
310 *sample /= maximum_output;
311 } else {
312 *sample = 0.0;
313 }
314 }
315
316 output.iter().skip(fft_size).cloned().collect()
320}
321
/// Returns an all-zero frame set with the same shape (number of frames and bins per frame)
/// as `magnitudes`.
///
/// Only reads `magnitudes`, so it takes a shared slice. Callers that pass `&mut Vec<Vec<f32>>`
/// still compile thanks to `&mut T -> &U` deref coercion.
fn initialize_result_transient_magnitude_frames(magnitudes: &[Vec<f32>]) -> Vec<Vec<f32>> {
    magnitudes
        .iter()
        .map(|frame| vec![0.0; frame.len()])
        .collect()
}
328
329fn get_magnitudes(fft_frames: &[Vec<Complex<f32>>]) -> Vec<Vec<f32>> {
330 fft_frames
331 .iter()
332 .map(|frame| {
333 frame
334 .iter()
335 .map(|frequency_bin| frequency_bin.norm())
336 .collect()
337 })
338 .collect()
339}
340
341fn get_fft_frames(
342 fft_size: usize,
343 fft_overlap_ratio: f32,
344 data: &mut AudioBuffer<f32>,
345) -> Vec<Vec<Complex<f32>>> {
346 let mut fft = FftProcessor::new(FftProcessorOptions {
347 size: fft_size,
348 direction: FftDirection::Forward,
349 overlap_ratio: fft_overlap_ratio,
350 window_function: WindowFunctionType::Hann,
353 });
354 let mut fft_frames = vec![];
355
356 let mut context = AudioContext::default();
357 for sample_num in 0..data.num_samples() {
358 let mut input_sample = 0.0;
359 for channel in 0..data.num_channels() {
360 input_sample += data.get(channel, sample_num);
361 }
362
363 let output_sample = fft.m_process(&mut context, input_sample);
364
365 for channel in 0..data.num_channels() {
366 data.set(channel, sample_num, output_sample);
367 }
368
369 if fft.has_changed() {
370 fft_frames.push(fft.buffer().clone());
371 }
372 }
373
374 fft_frames
375}
376
#[cfg(test)]
mod test {
    use audio_processor_testing_helpers::relative_path;

    use audio_processor_file::{AudioFileProcessor, OutputAudioFileProcessor};
    use audio_processor_traits::{AudioProcessor, AudioProcessorSettings};

    use super::*;

    /// Loads at most the first ten seconds of `input_file_path` and sums all of its channels
    /// into a single-channel buffer sized to the first channel (capped at ten seconds).
    fn read_input_file(input_file_path: &str) -> AudioBuffer<f32> {
        log::info!("Reading input file input_file={}", input_file_path);
        let settings = AudioProcessorSettings::default();
        let mut file_processor = AudioFileProcessor::from_path(
            audio_garbage_collector::handle(),
            settings,
            input_file_path,
        )
        .unwrap();
        let mut context = AudioContext::from(settings);
        file_processor.prepare(&mut context);

        let channels = file_processor.buffer();
        let ten_seconds = (settings.sample_rate() * 10.0) as usize;
        let mono_len = channels[0].len().min(ten_seconds);

        let mut mono = AudioBuffer::empty();
        mono.resize(1, mono_len);
        for channel in channels.iter() {
            for (index, value) in channel.iter().enumerate().take(ten_seconds) {
                let mixed = *value + mono.get(0, index);
                mono.set(0, index, mixed);
            }
        }
        mono
    }

    /// End-to-end smoke test: detects transients in a drum loop, checks the output length,
    /// draws a plot, and writes the (re-scaled) transient signal to a WAV file.
    #[test]
    fn test_transient_detector() {
        use visualization::draw;

        wisual_logger::init_from_env();

        let output_path = relative_path!("./src/transient_detection/stft.png");
        let input_path = relative_path!("./hiphop-drum-loop.mp3");
        let transients_file_path = format!("{}.transients.wav", input_path);

        let mut input = read_input_file(&input_path);
        let original_samples: Vec<f32> = input.channel(0).iter().copied().collect();
        let peak_input = original_samples
            .iter()
            .map(|sample| sample.abs())
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap();

        let params = IterativeTransientDetectionParams {
            iteration_count: 2,
            ..IterativeTransientDetectionParams::default()
        };
        let transients = find_transients(params, &mut input);

        let expected_len =
            original_samples.len() - IterativeTransientDetectionParams::default().fft_size;
        assert_eq!(expected_len, transients.len());
        draw(&output_path, &original_samples, &transients);

        let settings = AudioProcessorSettings {
            input_channels: 1,
            output_channels: 1,
            ..AudioProcessorSettings::default()
        };
        let mut writer = OutputAudioFileProcessor::from_path(settings, &transients_file_path);
        writer.prepare(settings);

        // Undo the peak normalisation so the file is at the input's level.
        let scaled_transients: Vec<f32> = transients
            .iter()
            .map(|sample| sample * peak_input)
            .collect();
        let mut buffer = AudioBuffer::from_interleaved(1, &scaled_transients);
        writer
            .process(&mut buffer)
            .expect("Failed to write transients to file");
    }
}