use crate::analysis::{Analysis, analyzer};
use crate::spectrum;
use std::thread;
use threadpool::ThreadPool;
use std::sync::mpsc;
use num::Complex;
use crate::WindowType;
pub fn stft_analysis(audio: &mut Vec<f64>, fft_size: usize, sample_rate: u32, max_num_threads: Option<usize>) -> Vec<Analysis> {
let max_available_threads = match std::thread::available_parallelism() {
Ok(x) => x.get(),
Err(_) => 1
};
let pool_size = match max_num_threads {
Some(x) => {
if x > max_available_threads || x == 0 {
max_available_threads
} else {
x
}
},
None => max_available_threads
};
let rfft_freqs = spectrum::rfftfreq(fft_size, sample_rate);
let stft_imaginary_spectrum: Vec<Vec<Complex<f64>>> = spectrum::rstft(audio, fft_size, fft_size / 2, crate::WindowType::Hamming);
let (stft_magnitude_spectrum, _) = spectrum::complex_to_polar_rstft(&stft_imaginary_spectrum);
let (tx, rx) = mpsc::channel(); let num_threads: usize = match thread::available_parallelism() {
Ok(x) => x.get(),
Err(_) => 1
};
let mut thread_start_indices: Vec<usize> = vec![0; num_threads];
let num_frames_per_thread: usize = f64::ceil(stft_magnitude_spectrum.len() as f64 / num_threads as f64) as usize;
for i in 0..num_threads {
thread_start_indices[i] = num_frames_per_thread * i;
}
let pool = ThreadPool::new(pool_size);
for i in 0..num_threads {
let tx_clone = tx.clone();
let thread_idx = i;
let mut local_magnitude_spectrum: Vec<Vec<f64>> = Vec::with_capacity(num_frames_per_thread);
let start_idx = i * num_frames_per_thread;
let end_idx = usize::min(start_idx + num_frames_per_thread, stft_magnitude_spectrum.len());
for j in start_idx..end_idx {
let mut rfft_frame: Vec<f64> = Vec::with_capacity(stft_magnitude_spectrum[j].len());
for k in 0..stft_magnitude_spectrum[j].len() {
rfft_frame.push(stft_magnitude_spectrum[j][k]);
}
local_magnitude_spectrum.push(rfft_frame);
}
let local_sample_rate = sample_rate;
let local_rfft_freqs = rfft_freqs.clone();
pool.execute(move || {
let mut analyses: Vec<Analysis> = Vec::with_capacity(local_magnitude_spectrum.len());
for j in 0..local_magnitude_spectrum.len() {
analyses.push(analyzer(&local_magnitude_spectrum[j], None, local_sample_rate, &local_rfft_freqs))
}
let _ = match tx_clone.send((thread_idx, analyses)) {
Ok(x) => x,
Err(_) => ()
};
});
}
drop(tx);
let mut results = vec![];
for received_data in rx {
results.push(received_data);
}
results.sort_by_key(|&(index, _)| index);
pool.join();
let mut analyses: Vec<Analysis> = Vec::new();
for i in 0..results.len() {
for j in 0..results[i].1.len() {
analyses.push(results[i].1[j]);
}
}
analyses
}
pub fn rstft(audio: &[f64], fft_size: usize, hop_size: usize, window_type: WindowType, max_num_threads: Option<usize>) -> Vec<Vec<Complex<f64>>> {
let max_available_threads = match std::thread::available_parallelism() {
Ok(x) => x.get(),
Err(_) => 1
};
let pool_size = match max_num_threads {
Some(x) => {
if x > max_available_threads || x == 0 {
max_available_threads
} else {
x
}
},
None => max_available_threads
};
let mut thread_start: Vec<usize> = vec![0; pool_size];
let mut thread_end: Vec<usize> = vec![0; pool_size];
let mut samples_per_thread = audio.len() / pool_size;
let mut i = 1;
loop {
if i > samples_per_thread {
samples_per_thread = i / 2;
break;
}
i *= 2;
}
for i in 0..pool_size {
thread_start[i] = samples_per_thread * i;
thread_end[i] = samples_per_thread * (i + 1);
}
thread_end[pool_size-1] = audio.len();
let mut spectrogram: Vec<Vec<Complex<f64>>> = Vec::new();
let (tx, rx) = mpsc::channel(); let pool = ThreadPool::new(pool_size);
for i in 0..pool_size {
let tx_clone = tx.clone();
let thread_idx = i;
let mut local_audio: Vec<f64> = Vec::with_capacity(samples_per_thread);
for j in thread_start[i]..thread_end[i] {
local_audio.push(audio[j]);
}
let local_fft_size = fft_size;
let local_hop_size = hop_size;
let local_window_type = window_type;
pool.execute(move || {
let local_spectrogram = spectrum::rstft(&local_audio, local_fft_size, local_hop_size, local_window_type);
let _ = match tx_clone.send((thread_idx, local_spectrogram)) {
Ok(x) => x,
Err(_) => ()
};
});
}
drop(tx);
let mut results = vec![];
for received_data in rx {
results.push(received_data);
}
results.sort_by_key(|&(index, _)| index);
pool.join();
for i in 0..results.len() {
for j in 0..results[i].1.len() {
spectrogram.push(results[i].1[j].clone());
}
}
pool.join();
spectrogram
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::spectrum::irstft;

    // Machine-local fixtures. Tests that need them are `#[ignore]`d so `cargo test` passes on
    // other machines; run them explicitly with `cargo test -- --ignored` where the files exist.
    const AUDIO: &str = "D:/Recording/tests/grains.wav";
    const OUTPUT: &str = "D:/Recording/tests/graintest.wav";
    const FFT_SIZE: usize = 2048;

    /// Deterministic 440 Hz sine, so the threading logic can be tested without fixture files.
    fn test_tone(num_samples: usize, sample_rate: u32) -> Vec<f64> {
        let step = 2.0 * std::f64::consts::PI * 440.0 / f64::from(sample_rate);
        (0..num_samples).map(|n| (step * n as f64).sin()).collect()
    }

    #[test]
    fn stft_analysis_returns_one_analysis_per_frame() {
        let sample_rate: u32 = 44_100;
        let mut audio = test_tone(FFT_SIZE * 16, sample_rate);
        let expected_frames =
            spectrum::rstft(&audio, FFT_SIZE, FFT_SIZE / 2, WindowType::Hamming).len();
        let analyses = stft_analysis(&mut audio, FFT_SIZE, sample_rate, Some(4));
        assert_eq!(analyses.len(), expected_frames);
    }

    #[test]
    #[ignore = "requires the machine-local fixture D:/Recording/tests/grains.wav"]
    fn mp_basic_tests1() {
        let path = String::from(AUDIO);
        let audio = crate::read(&path).unwrap_or_else(|_| panic!("could not read audio"));
        let imaginary_spectrogram = rstft(&audio.samples[0], FFT_SIZE, FFT_SIZE / 2, WindowType::Hamming, Some(8));
        let new_audio = irstft(&imaginary_spectrogram, FFT_SIZE, FFT_SIZE / 2, WindowType::Hamming).unwrap();
        let new_audio_file = crate::AudioFile::new_mono(audio.audio_format, audio.sample_rate, new_audio);
        crate::write(OUTPUT, &new_audio_file).unwrap();
    }

    #[test]
    #[ignore = "requires the machine-local fixture D:/Recording/tests/grains.wav"]
    fn mp_basic_tests2() {
        let path = String::from(AUDIO);
        let mut audio = crate::read(&path).unwrap_or_else(|_| panic!("could not read audio"));
        let _ = stft_analysis(&mut audio.samples[0], FFT_SIZE, audio.sample_rate, Some(8));
    }
}