use crossbeam_channel::{Receiver, Sender};
use ringbuf::traits::{Consumer, Observer, Producer, Split};
use ringbuf::LocalRb;
use rustfft::{num_complex::Complex, FftPlanner};
use tracing::{debug, warn};
use std::cmp;
use std::env;
use std::f32::consts::PI;
use std::str::FromStr;
/// Builds a zero-padded Hann window.
///
/// The returned vector always has `window_length` entries. Entries in
/// `offset..offset + width` (clipped to `window_length`) hold the symmetric
/// Hann weights `sin^2(pi * n / (width - 1))`, where `n` counts from the start
/// of the window; every other entry is `0.0`.
///
/// A `width` of 1 produces a single weight of `1.0` (the formula would
/// otherwise evaluate 0/0 and yield NaN), and a `width` of 0 produces an
/// all-zero vector.
///
/// If the environment variable `DEMO_DISABLE_INPUT_WEIGHTS` is set to a
/// non-empty value, windowing is disabled and every entry is `1.0`.
fn hann(width: usize, offset: usize, window_length: usize) -> Vec<f32> {
    let weights_disabled = env::var("DEMO_DISABLE_INPUT_WEIGHTS")
        .map(|flag| !flag.is_empty())
        .unwrap_or(false);
    if weights_disabled {
        return vec![1.0; window_length];
    }

    let mut samples = vec![0.0; window_length];
    // Clip the window to the buffer; saturating_add keeps a huge width/offset
    // from overflowing before the clip is applied.
    let start = cmp::min(offset, window_length);
    let end = cmp::min(offset.saturating_add(width), window_length);

    let span = match width {
        0 => return samples,
        1 => {
            // Degenerate single-point window: the peak itself.
            for sample in &mut samples[start..end] {
                *sample = 1.0;
            }
            return samples;
        }
        _ => (width - 1) as f32,
    };

    for (n, sample) in samples[start..end].iter_mut().enumerate() {
        *sample = (PI * n as f32 / span).sin().powi(2);
    }
    samples
}
/// Returns the `n` largest values of `vals`, sorted in descending order.
///
/// If `vals` has fewer than `n` entries, all of them are returned (still in
/// descending order). Duplicates are kept. When `n` is 0 the result is empty.
fn topn_vals(vals: &[f32], n: usize) -> Vec<f32> {
    if n == 0 {
        // Nothing can be kept; also avoids inspecting the last element of an
        // empty list below.
        return Vec::new();
    }

    // `top` is kept sorted from largest to smallest and never exceeds `n`.
    let mut top: Vec<f32> = Vec::with_capacity(n.min(vals.len()));
    for &val in vals {
        if top.len() >= n {
            match top.last() {
                // The candidate beats the current smallest: make room for it.
                Some(&smallest) if val > smallest => {
                    top.pop();
                }
                // Not large enough to enter the list.
                _ => continue,
            }
        }
        // First position whose element is not strictly greater than `val`,
        // which preserves the descending order.
        let pos = top.partition_point(|&kept| kept > val);
        top.insert(pos, val);
    }
    top
}
/// Running state carried across successive calls to `compress_amplitudes`.
struct CompressState {
/// Decaying record of the smallest of the top `HIGH_BUCKET_COUNT` amplitudes
/// of a frame; amplitudes at or below it are scaled linearly. Only updated
/// when compression is enabled.
recent_max_low: f32,
/// Decaying record of the largest amplitude seen; it maps to an output of 1.0.
recent_max_high: f32,
/// When `false`, `compress_amplitudes` only divides each value by
/// `recent_max_high` and skips the two-segment compression.
compression_enabled: bool,
}
/// Number of largest amplitudes per frame that `compress_amplitudes` requests
/// from `topn_vals`; the smallest of them feeds `recent_max_low`.
const HIGH_BUCKET_COUNT: usize = 10;
/// Output level that `recent_max_low` maps to: values up to `recent_max_low`
/// are scaled into `0..=0.8`, values above it into the remaining range up to 1.0.
const HIGH_BUCKET_COMPRESSION: f32 = 0.8;
/// Per-call decay factor applied to `recent_max_high` when compression is enabled.
const RECENT_MAX_FAST_MULTIPLIER: f32 = 0.99;
/// Per-call decay factor applied to `recent_max_low`, and to `recent_max_high`
/// when compression is disabled.
const RECENT_MAX_SLOW_MULTIPLIER: f32 = 0.995;
/// Normalizes a frame of amplitudes using decaying running maxima kept in `s`.
///
/// * With compression disabled, every value is divided by a slowly decaying
///   running maximum (`recent_max_high`).
/// * With compression enabled, two thresholds are tracked: `recent_max_low`
///   (the smallest of the frame's top `HIGH_BUCKET_COUNT` values) and
///   `recent_max_high` (the frame maximum). Values up to the low threshold are
///   scaled linearly so that the threshold maps to `HIGH_BUCKET_COMPRESSION`;
///   values above it are mapped linearly onto the remaining range, with the
///   high threshold mapping to 1.0.
///
/// An empty frame is returned as-is without touching the state. While the
/// running maximum is still zero (nothing but silence seen so far) the frame is
/// returned unchanged rather than divided by zero.
fn compress_amplitudes(vals: Vec<f32>, s: &mut CompressState) -> Vec<f32> {
    if vals.is_empty() {
        return vals;
    }

    if !s.compression_enabled {
        let max = vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        s.recent_max_high *= RECENT_MAX_SLOW_MULTIPLIER;
        if max > s.recent_max_high {
            s.recent_max_high = max;
        }
        if s.recent_max_high == 0.0 {
            // Only silence so far: dividing would turn every bin into NaN.
            return vals;
        }
        let peak = s.recent_max_high;
        return vals.into_iter().map(|x| x / peak).collect();
    }

    let top_bucket_vals = topn_vals(&vals, HIGH_BUCKET_COUNT);
    // First entry is the frame maximum, last entry is the smallest of the top
    // values. The list is non-empty because `vals` is non-empty.
    let (max, topn) = match (top_bucket_vals.first(), top_bucket_vals.last()) {
        (Some(&max), Some(&topn)) => (max, topn),
        _ => return vals,
    };

    s.recent_max_low *= RECENT_MAX_SLOW_MULTIPLIER;
    if topn > s.recent_max_low {
        s.recent_max_low = topn;
    }
    s.recent_max_high *= RECENT_MAX_FAST_MULTIPLIER;
    if max > s.recent_max_high {
        s.recent_max_high = max;
    }
    // The high threshold decays faster than the low one; keep them ordered.
    if s.recent_max_low > s.recent_max_high {
        s.recent_max_high = s.recent_max_low;
    }
    if s.recent_max_high == 0.0 {
        return vals;
    }

    let low = s.recent_max_low;
    let high = s.recent_max_high;
    let low_val_divisor = low / HIGH_BUCKET_COMPRESSION;
    // When `high == low` these two become infinite, but the upper segment is
    // then never selected because no value in the frame exceeds `high`.
    let high_val_multiplier = (1.0 - HIGH_BUCKET_COMPRESSION) / (high - low);
    let high_val_adder = 1.0 - high_val_multiplier * high;

    vals.into_iter()
        .map(|val| {
            if val <= low {
                if low_val_divisor == 0.0 {
                    // The low threshold is still zero (fewer than
                    // HIGH_BUCKET_COUNT non-zero bins seen): avoid 0/0.
                    0.0
                } else {
                    val / low_val_divisor
                }
            } else {
                val * high_val_multiplier + high_val_adder
            }
        })
        .collect()
}
/// Computes one amplitude weight per output bucket.
///
/// The range `0..output_frequency` is divided into `output_len` equal buckets
/// and each bucket is evaluated at its upper edge frequency. The weight is a
/// gain of 3.01 dB per octave relative to 1000 (so a frequency of 1000 gets a
/// weight of 1.0), converted from decibels to a linear amplitude factor.
///
/// If the environment variable `DEMO_DISABLE_FREQ_WEIGHTS` is set to a
/// non-empty value, every weight is `1.0`.
fn freq_flattening_weights(output_len: usize, output_frequency: i32) -> Vec<f32> {
    let flattening_disabled = match env::var("DEMO_DISABLE_FREQ_WEIGHTS") {
        Ok(flag) => !flag.is_empty(),
        Err(_) => false,
    };
    if flattening_disabled {
        return vec![1.0; output_len];
    }

    let step = output_frequency as f32 / output_len as f32;
    // Accumulated (rather than multiplied) so each bucket uses the running sum.
    let mut bucket_freq = 0.0_f32;
    (0..output_len)
        .map(|_| {
            bucket_freq += step;
            let gain_db = 3.01 * (bucket_freq / 1000.0).log2();
            10.0_f32.powf(gain_db / 20.0)
        })
        .collect()
}
/// Chooses the FFT length and the per-bucket frequency weights.
///
/// Returns `(fourier_len, weights)`, where `fourier_len` is a multiple of
/// `output_len` and `weights` has `output_len` entries.
///
/// * Without a known input frequency, the length is `2 * output_len` and the
///   weights are computed for 48000.
/// * With a known frequency `f`, the multiplier is 8 for `f >= 192000`, 4 for
///   `f >= 64000` and 2 otherwise, and the weights are computed for `f`
///   divided by half the multiplier.
/// * The environment variable `DEMO_FOURIER_SCALE` overrides the multiplier
///   when an input frequency is known. Overrides that do not parse, are below
///   2 (half the multiplier would be zero and the frequency division would
///   panic), or would overflow are ignored and the automatic choice is used.
fn calculate_fourier_size(output_len: usize, input_freq: Option<i32>) -> (usize, Vec<f32>) {
    let f = match input_freq {
        Some(f) => f,
        None => return (2 * output_len, freq_flattening_weights(output_len, 48000)),
    };

    let scale_override = env::var("DEMO_FOURIER_SCALE")
        .ok()
        .and_then(|raw| usize::from_str(&raw).ok())
        // Keep only multipliers whose half is a non-zero value that fits i32.
        .filter(|&scale| scale >= 2 && scale <= i32::MAX as usize);
    if let Some(scalesz) = scale_override {
        if let Some(fourier_len) = scalesz.checked_mul(output_len) {
            let freq_divisor = (scalesz / 2) as i32;
            return (
                fourier_len,
                freq_flattening_weights(output_len, f / freq_divisor),
            );
        }
    }

    let (scale, freq_divisor) = if f >= 192000 {
        (8, 4)
    } else if f >= 64000 {
        (4, 2)
    } else {
        (2, 1)
    };
    (
        scale * output_len,
        freq_flattening_weights(output_len, f / freq_divisor),
    )
}
/// Audio processing loop: turns incoming sample chunks into normalized spectra.
///
/// Chunks received on `recv_audio` are appended to a ring buffer. Whenever at
/// least one full FFT window of samples is buffered, the oldest window is
/// multiplied by a Hann window, transformed, weighted per frequency bucket,
/// passed through `compress_amplitudes` and sent on `send_processed`. The
/// buffer then advances by half a window, so consecutive windows overlap by 50%.
///
/// * `output_len` - number of values in every vector sent on `send_processed`.
/// * `input_frequency` - passed to `calculate_fourier_size` to choose the FFT
///   length and the frequency weights.
/// * An empty chunk resets the running maxima used for amplitude compression.
///
/// Setting the environment variable `DEMO_DISABLE_COMPRESSION` to any value
/// (including an empty one) disables compression.
///
/// The function returns when either channel is disconnected, or immediately
/// if the computed FFT length is smaller than 2 (for example when
/// `output_len` is 0), since no window can be built in that case.
pub fn process_audio_loop(
    output_len: usize,
    input_frequency: Option<i32>,
    recv_audio: Receiver<Vec<f32>>,
    send_processed: Sender<Vec<f32>>,
) {
    let (input_len, output_freq_weights) = calculate_fourier_size(output_len, input_frequency);
    debug!("Using fourier size: {}", input_len);
    if input_len < 2 {
        // `input_len - 2` below would underflow, and a zero-length window
        // would never consume any buffered samples.
        warn!(
            "exiting audio processing thread, fourier size {} is too small",
            input_len
        );
        return;
    }

    let fft = FftPlanner::new().plan_fft_forward(input_len);
    // The window is two samples narrower than the FFT and starts at index 1,
    // leaving the first and last sample weighted with zero.
    let input_window_scales = hann(input_len - 2, 1, input_len);
    let (mut audio_buf_in, mut audio_buf_out) = LocalRb::new(input_len).split();
    let mut fft_buf = vec![Complex::new(0.0, 0.0); input_len];
    let mut compress_state = CompressState {
        recent_max_low: 0.,
        recent_max_high: 0.,
        compression_enabled: env::var("DEMO_DISABLE_COMPRESSION").is_err(),
    };

    // A receive error means the sending side is gone: stop processing.
    while let Ok(audio) = recv_audio.recv() {
        if audio.is_empty() {
            compress_state.recent_max_low = 0.;
            compress_state.recent_max_high = 0.;
            continue;
        }

        // Grow the ring buffer when the new chunk does not fit, carrying over
        // the samples that are still pending.
        if audio_buf_in.capacity().get() < audio_buf_in.occupied_len() + audio.len() {
            let (mut new_audio_buf_in, new_audio_buf_out) =
                LocalRb::new(audio_buf_in.capacity().get() + audio.len()).split();
            ringbuf::transfer(&mut audio_buf_out, &mut new_audio_buf_in, None);
            audio_buf_in = new_audio_buf_in;
            audio_buf_out = new_audio_buf_out;
        }
        audio_buf_in.push_iter(&mut audio.into_iter());

        while audio_buf_out.occupied_len() >= input_len {
            // The buffered samples may wrap around the end of the ring, so
            // they are exposed as two slices; chain them to read the oldest
            // `input_len` samples in order.
            let (older_audio, newer_audio) = audio_buf_out.as_slices();
            let buffered = older_audio.iter().chain(newer_audio.iter());
            for ((slot, &scale), &sample) in fft_buf
                .iter_mut()
                .zip(input_window_scales.iter())
                .zip(buffered)
            {
                *slot = Complex::new(scale * sample, 0.0);
            }

            // Advance by half a window so the next frame overlaps this one.
            audio_buf_out.skip(input_len / 2);
            fft.process(&mut fft_buf);

            // Zipping with the weights keeps only the first `output_len` bins.
            let flattened_result: Vec<f32> = fft_buf
                .iter()
                .zip(output_freq_weights.iter())
                .map(|(val, weight)| val.norm() * weight)
                .collect();
            let compressed_result = compress_amplitudes(flattened_result, &mut compress_state);
            if let Err(e) = send_processed.send(compressed_result) {
                warn!("exiting audio processing thread, output error: {}", e);
                return;
            }
        }
    }
}