use rand::{rngs::SmallRng, Rng};
use std::time::Duration;
use crate::{
source::noise::{Blue, WhiteGaussian, WhiteTriangular, WhiteUniform},
BitDepth, ChannelCount, Float, Sample, SampleRate, Source,
};
/// Noise distribution used to generate the dither signal.
///
/// Each variant selects which noise source backs the dither; `TPDF` is the `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Algorithm {
/// Dither generated from the `WhiteGaussian` noise source.
GPDF,
/// Dither generated from `Blue` noise sources, with a separate generator per channel.
HighPass,
/// Dither generated from the `WhiteUniform` noise source.
RPDF,
/// Dither generated from the `WhiteTriangular` noise source (the default).
#[default]
TPDF,
}
/// Internal dispatch over the concrete noise sources, one variant per [`Algorithm`].
///
/// `HighPass` stores a `Vec` of `Blue` generators that is indexed by channel; every other
/// variant stores a single generator that serves all channels.
#[derive(Clone, Debug)]
// Variant names deliberately mirror the all-caps `Algorithm` variants.
#[allow(clippy::upper_case_acronyms)]
enum NoiseGenerator<R: Rng = SmallRng> {
/// Single `WhiteTriangular` generator.
TPDF(WhiteTriangular<R>),
/// Single `WhiteUniform` generator.
RPDF(WhiteUniform<R>),
/// Single `WhiteGaussian` generator.
GPDF(WhiteGaussian<R>),
/// One `Blue` generator per channel, indexed by channel number.
HighPass(Vec<Blue<R>>),
}
impl NoiseGenerator {
    /// Builds the generator that backs `algorithm`.
    ///
    /// The white-noise algorithms wrap a single generator. `HighPass` allocates one
    /// `Blue` generator per channel so that each channel keeps its own filter state.
    fn new(algorithm: Algorithm, sample_rate: SampleRate, channels: ChannelCount) -> Self {
        match algorithm {
            Algorithm::GPDF => Self::GPDF(WhiteGaussian::new(sample_rate)),
            Algorithm::RPDF => Self::RPDF(WhiteUniform::new(sample_rate)),
            Algorithm::TPDF => Self::TPDF(WhiteTriangular::new(sample_rate)),
            Algorithm::HighPass => {
                let per_channel: Vec<_> = std::iter::repeat_with(|| Blue::new(sample_rate))
                    .take(channels.get() as usize)
                    .collect();
                Self::HighPass(per_channel)
            }
        }
    }

    /// Draws the next noise sample for `channel`.
    ///
    /// Only `HighPass` looks at `channel`; it indexes the per-channel generator list and
    /// panics if `channel` is out of range.
    #[inline]
    fn next(&mut self, channel: usize) -> Option<Sample> {
        match self {
            Self::HighPass(per_channel) => per_channel[channel].next(),
            Self::GPDF(noise) => noise.next(),
            Self::RPDF(noise) => noise.next(),
            Self::TPDF(noise) => noise.next(),
        }
    }

    /// Reports which [`Algorithm`] this generator implements.
    #[inline]
    fn algorithm(&self) -> Algorithm {
        match self {
            Self::GPDF(_) => Algorithm::GPDF,
            Self::HighPass(_) => Algorithm::HighPass,
            Self::RPDF(_) => Algorithm::RPDF,
            Self::TPDF(_) => Algorithm::TPDF,
        }
    }

    /// Returns the sample rate the generator was built with.
    ///
    /// For `HighPass` the first per-channel generator is consulted; an empty list is
    /// treated as a broken invariant and panics.
    #[inline]
    fn sample_rate(&self) -> SampleRate {
        match self {
            Self::HighPass(per_channel) => per_channel
                .first()
                .expect("HighPass should have at least one generator")
                .sample_rate(),
            Self::GPDF(noise) => noise.sample_rate(),
            Self::RPDF(noise) => noise.sample_rate(),
            Self::TPDF(noise) => noise.sample_rate(),
        }
    }

    /// Adapts the generator to a new sample rate and channel count.
    ///
    /// A different sample rate rebuilds the generator from scratch. With an unchanged
    /// sample rate only `HighPass` needs work: its generator list is grown or truncated to
    /// `channels` entries, keeping the state of the generators that remain.
    #[inline]
    fn update_parameters(&mut self, sample_rate: SampleRate, channels: ChannelCount) {
        if self.sample_rate() != sample_rate {
            *self = Self::new(self.algorithm(), sample_rate, channels);
            return;
        }

        if let Self::HighPass(per_channel) = self {
            per_channel.resize_with(channels.get() as usize, || Blue::new(sample_rate));
        }
    }
}
/// Source adapter that combines every sample of `input` with scaled noise.
///
/// The noise comes from the generator selected by an [`Algorithm`] and is scaled by
/// `lsb_amplitude` before being applied to the input sample.
#[derive(Clone, Debug)]
pub struct Dither<I> {
// Wrapped source whose samples are dithered.
input: I,
// Noise generator for the currently selected algorithm.
noise: NoiseGenerator,
// Channel index handed to the noise generator for the next sample.
current_channel: usize,
// Countdown of samples left in the current span; `None` when the input reports no span length.
remaining_in_span: Option<usize>,
// Scale applied to each noise sample: 1 / 2^(target_bits - 1).
lsb_amplitude: Float,
}
impl<I> Dither<I>
where
    I: Source,
{
    /// Wraps `input` so that its samples are dithered for quantization to `target_bits`.
    ///
    /// The noise is scaled to one least-significant bit of a signed `target_bits` signal,
    /// i.e. `1 / 2^(target_bits - 1)`. The sample rate, channel count and span length are
    /// read from `input` once, here, to set up the noise generator and span tracking.
    ///
    /// Bit depths above 64 are clamped to the 64-bit amplitude (`2^-63`) instead of
    /// overflowing the shift.
    pub fn new(input: I, target_bits: BitDepth, algorithm: Algorithm) -> Self {
        // Shifting a u64 by 64 or more panics in debug builds and, in release builds, masks
        // the shift amount (a depth of 65 would yield an amplitude of 1.0, i.e. full-scale
        // noise). Clamping keeps the result well defined for any bit depth.
        let shift = (target_bits.get() - 1).min(63);
        // The f64 intermediate keeps the division exact before narrowing to `Float`.
        let lsb_amplitude = (1.0 / (1_u64 << shift) as f64) as Float;

        let sample_rate = input.sample_rate();
        let channels = input.channels();
        let active_span_len = input.current_span_len();

        Self {
            input,
            noise: NoiseGenerator::new(algorithm, sample_rate, channels),
            current_channel: 0,
            remaining_in_span: active_span_len,
            lsb_amplitude,
        }
    }

    /// Switches to a different dither algorithm.
    ///
    /// Does nothing when `algorithm` is already active. Otherwise the noise generator is
    /// rebuilt from the input's current sample rate and channel count, discarding the state
    /// of the previous generator.
    pub fn set_algorithm(&mut self, algorithm: Algorithm) {
        if self.noise.algorithm() != algorithm {
            self.noise =
                NoiseGenerator::new(algorithm, self.input.sample_rate(), self.input.channels());
        }
    }

    /// Returns the dither algorithm currently in use.
    #[inline]
    pub fn algorithm(&self) -> Algorithm {
        self.noise.algorithm()
    }
}
// Yields each input sample with scaled dither noise subtracted from it.
impl<I> Iterator for Dither<I>
where
I: Source,
{
type Item = Sample;
/// Returns the next dithered sample, or `None` once the input is exhausted.
///
/// # Panics
///
/// Panics if the noise generator yields `None`, or if the `HighPass` generator list has
/// no entry for the current channel index.
#[inline]
fn next(&mut self) -> Option<Self::Item> {
// Count this sample against the current span. When no span length is known (`None`),
// the boundary handling below never triggers.
if let Some(ref mut remaining) = self.remaining_in_span {
*remaining = remaining.saturating_sub(1);
}
// Pull the input sample before querying its parameters, so the channel count and
// sample rate read below are whatever the input reports after this sample.
let input_sample = self.input.next()?;
let num_channels = self.input.channels();
// Span exhausted: re-sync the noise generator and span tracking with the input.
// NOTE(review): the sample that brings the counter to zero is drawn after
// `current_channel` has been reset to 0 — confirm this matches the input's span
// semantics for multi-channel `HighPass` dither.
if self.remaining_in_span == Some(0) {
self.noise
.update_parameters(self.input.sample_rate(), num_channels);
self.current_channel = 0;
self.remaining_in_span = self.input.current_span_len();
}
let noise_sample = self
.noise
.next(self.current_channel)
.expect("Noise generator should always produce samples");
// Advance to the next interleaved channel, wrapping after the last one.
self.current_channel = (self.current_channel + 1) % num_channels.get() as usize;
// Subtract the noise, scaled to one LSB of the target bit depth.
Some(input_sample - noise_sample * self.lsb_amplitude)
}
/// Forwards the input's size hint; `next` yields one output sample per input sample.
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.input.size_hint()
}
}
// `size_hint` is forwarded from the input, so `Dither` is exact-size whenever the input is.
impl<I> ExactSizeIterator for Dither<I> where I: Source + ExactSizeIterator {}
// Every stream property is forwarded unchanged from the wrapped input.
impl<I> Source for Dither<I>
where
I: Source,
{
/// Returns the input's current span length.
#[inline]
fn current_span_len(&self) -> Option<usize> {
self.input.current_span_len()
}
/// Returns the input's channel count.
#[inline]
fn channels(&self) -> ChannelCount {
self.input.channels()
}
/// Returns the input's sample rate.
#[inline]
fn sample_rate(&self) -> SampleRate {
self.input.sample_rate()
}
/// Returns the input's total duration.
#[inline]
fn total_duration(&self) -> Option<Duration> {
self.input.total_duration()
}
/// Seeks the wrapped input to `pos`, returning the input's result unchanged.
///
/// NOTE(review): `remaining_in_span` and `current_channel` are not refreshed after a
/// seek — confirm that the input's span position and channel phase are unaffected by
/// seeking, otherwise span-boundary detection in `next` may drift.
#[inline]
fn try_seek(&mut self, pos: Duration) -> Result<(), crate::source::SeekError> {
self.input.try_seek(pos)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::source::{SineWave, Source};
    use crate::{nz, BitDepth, SampleRate};

    const TEST_SAMPLE_RATE: SampleRate = nz!(44100);
    const TEST_BIT_DEPTH: BitDepth = nz!(16);

    /// Dithered output must stay finite and within two LSBs of the undithered signal.
    #[test]
    fn test_dither_adds_noise() {
        let source = SineWave::new(440.0).take_duration(std::time::Duration::from_millis(10));
        let mut dithered = Dither::new(source.clone(), TEST_BIT_DEPTH, Algorithm::TPDF);
        let mut undithered = source;

        let dithered_samples: Vec<Sample> = (0..10).filter_map(|_| dithered.next()).collect();
        let undithered_samples: Vec<Sample> = (0..10).filter_map(|_| undithered.next()).collect();

        let lsb = 1.0 / (1_i64 << (TEST_BIT_DEPTH.get() - 1)) as Float;

        for (i, (&dithered_sample, &undithered_sample)) in dithered_samples
            .iter()
            .zip(undithered_samples.iter())
            .enumerate()
        {
            assert!(
                dithered_sample.is_finite(),
                "Dithered sample {} should be finite",
                i
            );

            let diff = (dithered_sample - undithered_sample).abs();
            let max_expected_diff = lsb * 2.0;
            assert!(
                diff <= max_expected_diff,
                "Dither noise too large: sample {}, diff {}, max expected {}",
                i,
                diff,
                max_expected_diff
            );
        }
    }

    /// High-pass dither must be high-pass within each channel and uncorrelated across
    /// channels.
    #[test]
    fn test_highpass_dither_multichannel_independence() {
        use crate::source::Zero;

        // A silent stereo input means the output is exactly the (negated, scaled) noise.
        let constant_source = Zero::new(nz!(2), TEST_SAMPLE_RATE);
        let mut dithered = Dither::new(constant_source, TEST_BIT_DEPTH, Algorithm::HighPass);

        // De-interleave L, R, L, R, ... into separate channels.
        let samples: Vec<Sample> = dithered.by_ref().take(1000).collect();
        let left: Vec<Sample> = samples.iter().step_by(2).copied().collect();
        let right: Vec<Sample> = samples.iter().skip(1).step_by(2).copied().collect();
        assert_eq!(left.len(), 500);
        assert_eq!(right.len(), 500);

        // Only the sign of the lag-1 autocorrelation matters, so no normalization is needed.
        let left_autocorr: Float =
            left.windows(2).map(|w| w[0] * w[1]).sum::<Float>() / (left.len() - 1) as Float;
        let right_autocorr: Float =
            right.windows(2).map(|w| w[0] * w[1]).sum::<Float>() / (right.len() - 1) as Float;
        assert!(
            left_autocorr < 0.0,
            "Left channel should have negative autocorr (high-pass), got {}",
            left_autocorr
        );
        assert!(
            right_autocorr < 0.0,
            "Right channel should have negative autocorr (high-pass), got {}",
            right_autocorr
        );

        // Normalize to a correlation coefficient in [-1, 1]. The raw products are on the
        // order of LSB^2, so an unnormalized sum could never exceed a fixed threshold and
        // the check would be vacuous.
        let dot: Float = left.iter().zip(right.iter()).map(|(l, r)| l * r).sum();
        let left_energy: Float = left.iter().map(|l| l * l).sum();
        let right_energy: Float = right.iter().map(|r| r * r).sum();
        let cross_corr = dot / (left_energy.sqrt() * right_energy.sqrt());

        // For 500 independent pairs the coefficient's standard deviation is roughly 0.05,
        // so 0.3 leaves a wide margin while still catching a generator shared between
        // channels (which would give about -0.5).
        assert!(
            cross_corr.abs() < 0.3,
            "Channels should be independent, cross-correlation should be near 0, got {}",
            cross_corr
        );
    }
}