use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use crossbeam_channel::{Receiver, Sender};
use rubato::{FftFixedInOut, Resampler};
use crate::{
alloc::AudioBuffer,
buffer::{ChannelConfig, ChannelConfigOptions},
context::{AsBaseAudioContext, AudioContextRegistration},
process::{AudioParamValues, AudioProcessor},
SampleRate,
};
use super::AudioNode;
/// Message carrying a new shaping curve from the node to its renderer.
///
/// The node encodes a cleared curve (`None`) as an empty vector.
struct CurveMessage(
    /// The curve samples.
    Vec<f32>,
);
/// Oversampling applied around the shaping curve.
///
/// The discriminants are explicit because the value is shared with the
/// renderer as a `u32` (`variant as u32`) and decoded back via `From<u32>`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(clippy::module_name_repetitions)]
pub enum OverSampleType {
    /// No oversampling.
    None = 0,
    /// Upsample by 2 before shaping, downsample afterwards.
    X2 = 1,
    /// Upsample by 4 before shaping, downsample afterwards.
    X4 = 2,
}
impl Default for OverSampleType {
fn default() -> Self {
Self::None
}
}
impl From<u32> for OverSampleType {
fn from(i: u32) -> Self {
use OverSampleType::{None, X2, X4};
match i {
0 => None,
1 => X2,
2 => X4,
_ => unreachable!(),
}
}
}
/// Construction options for [`WaveShaperNode`].
#[allow(clippy::module_name_repetitions)]
pub struct WaveShaperOptions {
    /// Initial shaping curve; `None` means no curve.
    pub curve: Option<Vec<f32>>,
    /// Initial oversampling mode.
    pub oversample: Option<OverSampleType>,
    /// Channel configuration; `None` uses the default configuration.
    pub channel_config: Option<ChannelConfigOptions>,
}
impl Default for WaveShaperOptions {
    /// No curve, no oversampling, and the default channel configuration.
    fn default() -> Self {
        let oversample = Some(OverSampleType::None);
        Self {
            curve: None,
            oversample,
            channel_config: None,
        }
    }
}
/// Control-thread handle of a wave shaper (non-linear distortion) node.
#[allow(clippy::module_name_repetitions)]
pub struct WaveShaperNode {
    /// Registration of this node with its audio context.
    registration: AudioContextRegistration,
    /// Channel count / mode / interpretation settings.
    channel_config: ChannelConfig,
    /// Copy of the current curve, returned by `curve()`.
    curve: Option<Vec<f32>>,
    /// Whether a non-`None` curve has ever been set; a second one is rejected.
    set_curve: bool,
    /// Oversampling mode, shared with the renderer as its `u32` discriminant.
    oversample: Arc<AtomicU32>,
    /// Sends new curves to the renderer.
    sender: Sender<CurveMessage>,
}
impl AudioNode for WaveShaperNode {
    fn registration(&self) -> &AudioContextRegistration {
        let reg = &self.registration;
        reg
    }

    fn channel_config_raw(&self) -> &ChannelConfig {
        let cfg = &self.channel_config;
        cfg
    }

    /// A wave shaper always has exactly one input.
    fn number_of_inputs(&self) -> u32 {
        1
    }

    /// A wave shaper always has exactly one output.
    fn number_of_outputs(&self) -> u32 {
        1
    }
}
impl WaveShaperNode {
    /// Creates a wave shaper registered with `context`.
    ///
    /// `options` falls back to [`WaveShaperOptions::default`] when `None`.
    /// A missing `oversample` option means [`OverSampleType::None`].
    pub fn new<C: AsBaseAudioContext>(context: &C, options: Option<WaveShaperOptions>) -> Self {
        context.base().register(move |registration| {
            let WaveShaperOptions {
                curve,
                oversample,
                channel_config,
            } = options.unwrap_or_default();

            #[allow(clippy::cast_precision_loss)]
            let sample_rate = context.base().sample_rate().0 as usize;
            let channel_config = channel_config.unwrap_or_default().into();
            // An absent oversample option means "no oversampling" rather than a panic.
            let oversample = Arc::new(AtomicU32::new(oversample.unwrap_or_default() as u32));
            let set_curve = curve.is_some();

            // Unbounded so `set_curve` never blocks the control thread. A zero-capacity
            // channel makes every send wait for the render thread's `try_recv`. That
            // deadlocks when nothing is rendering yet, e.g. an offline context before
            // `start_rendering`.
            let (sender, receiver) = crossbeam_channel::unbounded();

            let config = RendererConfig {
                sample_rate,
                curve: curve.clone(),
                oversample: oversample.clone(),
                receiver,
            };
            let renderer = WaveShaperRenderer::new(config);

            let node = Self {
                registration,
                channel_config,
                curve,
                set_curve,
                oversample,
                sender,
            };

            (node, Box::new(renderer))
        })
    }

    /// Returns the current shaping curve, or `None` if none is set.
    #[must_use]
    pub fn curve(&self) -> Option<&[f32]> {
        self.curve.as_deref()
    }

    /// Sets the shaping curve (or clears it with `None`) and forwards it to the renderer.
    ///
    /// # Panics
    ///
    /// Panics with `"InvalidStateError"` if `curve` is `Some` and a curve was already
    /// set. Also panics if the renderer side of the channel has been dropped.
    pub fn set_curve(&mut self, curve: Option<Vec<f32>>) {
        self.validate_input_curve(curve.clone());
        // `None` is encoded as an empty vector in the message.
        let c = curve.unwrap_or_else(Vec::new);
        self.sender
            .send(CurveMessage(c))
            .expect("Sending CurveMessage failed");
    }

    /// Test-only variant of `set_curve` that updates node state without messaging the renderer.
    #[cfg(test)]
    pub fn set_curve_mock(&mut self, curve: Option<Vec<f32>>) {
        self.validate_input_curve(curve.clone());
        let _c = curve.unwrap_or_else(Vec::new);
    }

    /// Applies the "curve set" rule and stores `curve` on the node.
    ///
    /// The first non-`None` curve is accepted and marks the curve as set. A later
    /// non-`None` curve panics with `"InvalidStateError"`. `None` is always accepted,
    /// clears the stored curve, and leaves the flag unchanged.
    fn validate_input_curve(&mut self, curve: Option<Vec<f32>>) {
        match (self.set_curve, curve) {
            (true, Some(_)) => panic!("InvalidStateError"),
            (false, opt_c @ Some(_)) => {
                self.set_curve = true;
                self.curve = opt_c;
            }
            (_, opt_c) => {
                self.curve = opt_c;
            }
        };
    }

    /// Returns the current oversampling mode.
    #[must_use]
    pub fn oversample(&self) -> OverSampleType {
        self.oversample.load(Ordering::SeqCst).into()
    }

    /// Changes the oversampling mode. The renderer reads it on its next `process` call.
    pub fn set_oversample(&mut self, oversample: OverSampleType) {
        self.oversample.store(oversample as u32, Ordering::SeqCst);
    }
}
/// Everything the renderer needs, handed over from `WaveShaperNode::new`.
struct RendererConfig {
    /// Context sample rate in Hz.
    sample_rate: usize,
    /// Oversampling mode shared with the node.
    oversample: Arc<AtomicU32>,
    /// Initial curve, if any.
    curve: Option<Vec<f32>>,
    /// Receiving end of the curve channel.
    receiver: Receiver<CurveMessage>,
}
/// Render-side half of the wave shaper.
struct WaveShaperRenderer {
    /// Context sample rate, used to (re)build the resamplers.
    sample_rate: usize,
    /// Oversampling mode shared with the node; read on every `process` call.
    oversample: Arc<AtomicU32>,
    /// Channel count the 2x resamplers were built for.
    channels_x2: usize,
    /// Channel count the 4x resamplers were built for.
    channels_x4: usize,
    /// Resampler used before shaping in 2x mode.
    upsampler_x2: FftFixedInOut<f32>,
    /// Resampler used before shaping in 4x mode.
    upsampler_x4: FftFixedInOut<f32>,
    /// Resampler used after shaping in 2x mode.
    downsampler_x2: FftFixedInOut<f32>,
    /// Resampler used after shaping in 4x mode.
    downsampler_x4: FftFixedInOut<f32>,
    /// Current shaping curve.
    curve: Vec<f32>,
    /// Whether the curve should be applied; when false the input is passed through.
    curve_set: bool,
    /// Receives curves sent by the node.
    receiver: Receiver<CurveMessage>,
}
impl AudioProcessor for WaveShaperRenderer {
fn process(
&mut self,
inputs: &[crate::alloc::AudioBuffer],
outputs: &mut [crate::alloc::AudioBuffer],
_params: AudioParamValues,
_timestamp: f64,
_sample_rate: SampleRate,
) {
let input = &inputs[0];
let output = &mut outputs[0];
if let Ok(msg) = self.receiver.try_recv() {
self.curve = msg.0;
}
if !self.curve_set {
self.no_process(input, output);
}
use OverSampleType::*;
match self.oversample.load(Ordering::SeqCst).into() {
None => self.process_none(input, output),
X2 => {
if input.channels().len() != self.channels_x2 {
self.update_2x(input.channels().len());
}
self.process_2x(input, output)
}
X4 => {
if input.channels().len() != self.channels_x4 {
self.update_4x(input.channels().len());
}
self.process_4x(input, output)
}
}
}
fn tail_time(&self) -> bool {
true
}
}
impl WaveShaperRenderer {
    /// Builds the renderer from `config`.
    ///
    /// The resamplers are created for one channel and rebuilt by `update_2x` /
    /// `update_4x` once the actual input channel count is seen. An absent or empty
    /// curve leaves the renderer in pass-through mode.
    fn new(config: RendererConfig) -> Self {
        let RendererConfig {
            sample_rate,
            oversample,
            curve,
            receiver,
        } = config;

        let curve = curve.unwrap_or_default();
        // An empty curve is treated like "no curve", matching how the node
        // encodes `set_curve(None)` in its messages.
        let curve_set = !curve.is_empty();

        let channels_x2 = 1;
        let channels_x4 = 1;

        let upsampler_x2 = FftFixedInOut::<f32>::new(sample_rate, sample_rate * 2, 256, channels_x2);
        let downsampler_x2 =
            FftFixedInOut::<f32>::new(sample_rate, sample_rate / 2, 128, channels_x2);
        let upsampler_x4 = FftFixedInOut::<f32>::new(sample_rate, sample_rate * 4, 512, channels_x4);
        let downsampler_x4 =
            FftFixedInOut::<f32>::new(sample_rate, sample_rate / 4, 128, channels_x4);

        Self {
            sample_rate,
            oversample,
            curve,
            curve_set,
            upsampler_x2,
            upsampler_x4,
            downsampler_x2,
            downsampler_x4,
            channels_x2,
            channels_x4,
            receiver,
        }
    }

    /// Copies the input samples to the output unchanged (used when no curve is set).
    #[inline]
    fn no_process(&self, input: &AudioBuffer, output: &mut AudioBuffer) {
        for (i_data, o_data) in input.channels().iter().zip(output.channels_mut()) {
            for (&i, o) in i_data.iter().zip(o_data.iter_mut()) {
                *o = i;
            }
        }
    }

    /// Applies the curve sample-by-sample at the context rate (no oversampling).
    #[inline]
    fn process_none(&self, input: &AudioBuffer, output: &mut AudioBuffer) {
        for (i_data, o_data) in input.channels().iter().zip(output.channels_mut()) {
            for (&i, o) in i_data.iter().zip(o_data.iter_mut()) {
                *o = self.tick(i);
            }
        }
    }

    /// Upsamples by 2, applies the curve, then downsamples back into `output`.
    #[inline]
    fn process_2x(&mut self, input: &AudioBuffer, output: &mut AudioBuffer) {
        let mut up_wave = self.upsampler_x2.process(input.channels()).unwrap();
        // The upsampled buffer is owned, so shape it in place instead of cloning it.
        for channel in &mut up_wave {
            for sample in channel.iter_mut() {
                *sample = self.tick(*sample);
            }
        }
        let wave_out = self.downsampler_x2.process(&up_wave).unwrap();
        for (i_data, o_data) in wave_out.iter().zip(output.channels_mut()) {
            for (&i, o) in i_data.iter().zip(o_data.iter_mut()) {
                *o = i;
            }
        }
    }

    /// Upsamples by 4, applies the curve, then downsamples back into `output`.
    #[inline]
    fn process_4x(&mut self, input: &AudioBuffer, output: &mut AudioBuffer) {
        let mut up_wave = self.upsampler_x4.process(input.channels()).unwrap();
        // The upsampled buffer is owned, so shape it in place instead of cloning it.
        for channel in &mut up_wave {
            for sample in channel.iter_mut() {
                *sample = self.tick(*sample);
            }
        }
        let wave_out = self.downsampler_x4.process(&up_wave).unwrap();
        for (i_data, o_data) in wave_out.iter().zip(output.channels_mut()) {
            for (&i, o) in i_data.iter().zip(o_data.iter_mut()) {
                *o = i;
            }
        }
    }

    /// Maps one sample through the curve with linear interpolation.
    ///
    /// With `N` curve entries, `v = (N - 1) / 2 * (input + 1)`. If `v <= 0` the result
    /// is the first entry, and if `v >= N - 1` it is the last entry. Otherwise it
    /// interpolates between `curve[floor(v)]` and `curve[floor(v) + 1]`. An empty
    /// curve yields `0.0`, and a single-entry curve always yields that entry.
    #[inline]
    fn tick(&self, input: f32) -> f32 {
        let len = self.curve.len();
        match len {
            0 => return 0.,
            // With one entry `v` is always 0 for finite input; returning early also
            // avoids indexing `curve[1]` when `input` is NaN or infinite.
            1 => return self.curve[0],
            _ => {}
        }

        #[allow(clippy::cast_precision_loss)]
        let n = len as f32;
        let v = (n - 1.) / 2. * (input + 1.);

        if v <= 0. {
            self.curve[0]
        } else if v >= n - 1. {
            // `>=` rather than `>`: at exactly `N - 1` (input == 1.0) the interpolation
            // branch would read `curve[N]` and panic.
            self.curve[len - 1]
        } else {
            let k = v.floor();
            let f = v - k;
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let k = k as usize;
            (1. - f) * self.curve[k] + f * self.curve[k + 1]
        }
    }

    /// Rebuilds the 2x resamplers for `channels_x2` channels.
    #[inline]
    fn update_2x(&mut self, channels_x2: usize) {
        self.channels_x2 = channels_x2;
        self.upsampler_x2 =
            FftFixedInOut::<f32>::new(self.sample_rate, self.sample_rate * 2, 256, channels_x2);
        self.downsampler_x2 =
            FftFixedInOut::<f32>::new(self.sample_rate, self.sample_rate / 2, 128, channels_x2);
    }

    /// Rebuilds the 4x resamplers for `channels_x4` channels.
    #[inline]
    fn update_4x(&mut self, channels_x4: usize) {
        self.channels_x4 = channels_x4;
        self.upsampler_x4 =
            FftFixedInOut::<f32>::new(self.sample_rate, self.sample_rate * 4, 512, channels_x4);
        self.downsampler_x4 =
            FftFixedInOut::<f32>::new(self.sample_rate, self.sample_rate / 4, 128, channels_x4);
    }
}
#[cfg(test)]
mod test {
    use crate::{
        context::{AsBaseAudioContext, OfflineAudioContext},
        node::WaveShaperOptions,
        SampleRate,
    };

    use super::{OverSampleType, WaveShaperNode};

    const LENGTH: usize = 555;

    /// Stereo offline context at 44.1 kHz, shared by every test.
    fn offline_context() -> OfflineAudioContext {
        OfflineAudioContext::new(2, LENGTH, SampleRate(44_100))
    }

    #[test]
    fn build_with_new() {
        let ctx = offline_context();
        let _shaper = WaveShaperNode::new(&ctx, None);
    }

    #[test]
    fn build_with_factory_func() {
        let ctx = offline_context();
        let _shaper = ctx.create_wave_shaper();
    }

    #[test]
    fn default_audio_params_are_correct_with_no_options() {
        let ctx = offline_context();
        let shaper = WaveShaperNode::new(&ctx, None);

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    #[test]
    fn default_audio_params_are_correct_with_default_options() {
        let ctx = offline_context();
        let shaper = WaveShaperNode::new(&ctx, Some(WaveShaperOptions::default()));

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    #[test]
    fn options_sets_audio_params() {
        let mut ctx = offline_context();
        let shaper = WaveShaperNode::new(
            &ctx,
            Some(WaveShaperOptions {
                curve: Some(vec![1.0]),
                oversample: Some(OverSampleType::X2),
                ..Default::default()
            }),
        );

        ctx.start_rendering();

        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }

    #[test]
    #[should_panic]
    fn change_a_curve_for_another_curve_should_panic() {
        let mut ctx = offline_context();
        let mut shaper = WaveShaperNode::new(
            &ctx,
            Some(WaveShaperOptions {
                curve: Some(vec![1.0]),
                oversample: Some(OverSampleType::X2),
                ..Default::default()
            }),
        );
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        // A second non-`None` curve is rejected: this call is expected to panic.
        shaper.set_curve_mock(Some(vec![2.0]));
        shaper.set_oversample(OverSampleType::X4);

        ctx.start_rendering();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    #[test]
    fn change_none_for_curve_after_build() {
        let mut ctx = offline_context();
        let mut shaper = WaveShaperNode::new(
            &ctx,
            Some(WaveShaperOptions {
                curve: None,
                oversample: Some(OverSampleType::X2),
                ..Default::default()
            }),
        );
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve_mock(Some(vec![2.0]));
        shaper.set_oversample(OverSampleType::X4);

        ctx.start_rendering();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }
}