use crate::{
AudioSampleError, AudioSampleResult, AudioSamples, AudioTypeConversion, LayoutError,
ParameterError, traits::StandardSample,
};
use ndarray::{Array1, ArrayView1};
/// Computes the Pearson correlation coefficient between two signals.
///
/// Mono signals are correlated directly; multi-channel signals are correlated
/// channel-by-channel and the per-channel coefficients are averaged.
///
/// # Errors
/// Returns a `Parameter` error when the two signals differ in channel count,
/// length, or mono/multi-channel layout, and a `Layout` error when a
/// channel's samples are not contiguous in memory.
#[inline]
pub fn correlation<T>(a: &AudioSamples<T>, b: &AudioSamples<T>) -> AudioSampleResult<f64>
where
    T: StandardSample,
{
    // Correlation is only defined for equally-sized signals.
    if a.num_channels() != b.num_channels() || a.samples_per_channel() != b.samples_per_channel() {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "audio_signals",
            "Signals must have the same dimensions for correlation",
        )));
    }
    let a_f = a.as_float();
    let b_f = b.as_float();
    match (a_f.as_mono(), b_f.as_mono()) {
        (Some(a_mono), Some(b_mono)) => correlation_1d(&a_mono.view(), &b_mono.view()),
        (Some(_), None) | (None, Some(_)) => {
            Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "audio_format",
                "Signals must have the same channel configuration",
            )))
        }
        (None, None) => {
            let a_multi = a_f.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            let b_multi = b_f.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            // Shared constructor for the layout error; captures nothing (so it
            // is `Copy`) and can be handed to both `ok_or_else` calls.
            let non_contiguous = || {
                AudioSampleError::Layout(LayoutError::NonContiguous {
                    operation: "signal processing".to_string(),
                    layout_type: "non-contiguous multi-channel samples".to_string(),
                })
            };
            let channels = a_multi.nrows().get();
            // Running sum instead of collecting per-channel results into an
            // intermediate Vec just to average them afterwards.
            let mut total = 0.0;
            for i in 0..channels {
                let a_channel = a_multi.row(i);
                let b_channel = b_multi.row(i);
                total += correlation_1d_slice(
                    a_channel.as_slice().ok_or_else(non_contiguous)?,
                    b_channel.as_slice().ok_or_else(non_contiguous)?,
                )?;
            }
            // `nrows()` is non-zero by construction, so this division is safe.
            Ok(total / channels as f64)
        }
    }
}
/// Computes the mean squared error (MSE) between two signals.
///
/// Mono signals are compared directly; multi-channel signals are compared
/// channel-by-channel and the per-channel MSE values are averaged.
///
/// # Errors
/// Returns a `Parameter` error when the two signals differ in channel count,
/// length, or mono/multi-channel layout, and a `Layout` error when a
/// channel's samples are not contiguous in memory.
#[inline]
pub fn mse<T>(a: &AudioSamples<T>, b: &AudioSamples<T>) -> AudioSampleResult<f64>
where
    T: StandardSample,
{
    // MSE is only defined for equally-sized signals.
    if a.num_channels() != b.num_channels() || a.samples_per_channel() != b.samples_per_channel() {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "audio_signals",
            "Signals must have the same dimensions for MSE",
        )));
    }
    let a_f = a.as_float();
    let b_f = b.as_float();
    match (a_f.as_mono(), b_f.as_mono()) {
        (Some(a_mono), Some(b_mono)) => mse_1d(&a_mono.view(), &b_mono.view()),
        (Some(_), None) | (None, Some(_)) => {
            Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "audio_format",
                "Signals must have the same channel configuration",
            )))
        }
        (None, None) => {
            let a_multi = a_f.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            let b_multi = b_f.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            // Shared constructor for the layout error; captures nothing (so it
            // is `Copy`) and can be handed to both `ok_or_else` calls.
            let non_contiguous = || {
                AudioSampleError::Layout(LayoutError::NonContiguous {
                    operation: "signal processing".to_string(),
                    layout_type: "non-contiguous multi-channel samples".to_string(),
                })
            };
            let channels = a_multi.nrows().get();
            // Running sum instead of collecting per-channel results into an
            // intermediate Vec just to average them afterwards.
            let mut total = 0.0;
            for i in 0..channels {
                let a_channel = a_multi.row(i);
                let b_channel = b_multi.row(i);
                total += mse_1d_slice(
                    a_channel.as_slice().ok_or_else(non_contiguous)?,
                    b_channel.as_slice().ok_or_else(non_contiguous)?,
                )?;
            }
            // `nrows()` is non-zero by construction, so this division is safe.
            Ok(total / channels as f64)
        }
    }
}
/// Computes the signal-to-noise ratio between `signal` and `noise`, in dB.
///
/// Power is the mean of the squared samples over all channels. Returns
/// `f64::INFINITY` when the noise power is exactly zero (silent noise).
///
/// # Errors
/// Returns a `Parameter` error when the two inputs differ in channel count or
/// length, or when a non-mono input cannot be viewed as multi-channel.
#[inline]
pub fn snr<T>(signal: &AudioSamples<T>, noise: &AudioSamples<T>) -> AudioSampleResult<f64>
where
    T: StandardSample,
{
    if signal.num_channels() != noise.num_channels()
        || signal.samples_per_channel() != noise.samples_per_channel()
    {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "audio_signals",
            "Signal and noise must have the same dimensions for SNR",
        )));
    }
    let signal_f = signal.as_float();
    let noise_f = noise.as_float();
    // Mean power (average of squared samples) over every sample of every
    // channel; one shared closure replaces four near-identical computations.
    let mean_power = |audio: &_| -> AudioSampleResult<f64> {
        if let Some(mono) = audio.as_mono() {
            Ok(mono.iter().map(|&x| x * x).sum::<f64>() / mono.len().get() as f64)
        } else {
            let multi = audio.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            Ok(multi.iter().map(|&x| x * x).sum::<f64>() / multi.len().get() as f64)
        }
    };
    let signal_power = mean_power(&signal_f)?;
    let noise_power = mean_power(&noise_f)?;
    // Silent noise means an unbounded ratio rather than a division by zero.
    if noise_power == 0.0 {
        return Ok(f64::INFINITY);
    }
    // Power ratio expressed in decibels.
    Ok(10.0 * (signal_power / noise_power).log10())
}
/// Time-aligns `signal` against `reference` using a cross-correlation search.
///
/// Returns the aligned signal (same length and sample rate as `signal`) and
/// the detected lag in samples. The lag is applied by prepending that many
/// default-valued (silent) samples to `signal` and truncating its tail.
/// Multi-channel inputs are collapsed to a per-sample cross-channel average
/// for the search; the resulting shift is applied to every channel.
///
/// # Errors
/// Returns a `Parameter` error when the channel counts or mono/multi-channel
/// layouts differ, and a `Layout` error when the sample data is not
/// contiguous in memory.
#[inline]
pub fn align_signals<T>(
    reference: &AudioSamples<'_, T>,
    signal: &AudioSamples<'_, T>,
) -> AudioSampleResult<(AudioSamples<'static, T>, usize)>
where
    T: StandardSample,
{
    if reference.num_channels() != signal.num_channels() {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "audio_channels",
            "Signals must have the same number of channels for alignment",
        )));
    }
    let sample_rate = signal.sample_rate();
    let ref_f = reference.as_float();
    let sig_f = signal.as_float();
    // Build mono search buffers: mono passes through; multi-channel is
    // averaged across channels, sample by sample.
    let (ref_data, sig_data) = match (ref_f.as_mono(), sig_f.as_mono()) {
        (Some(ref_mono), Some(sig_mono)) => (ref_mono.to_vec(), sig_mono.to_vec()),
        (None, None) => {
            let ref_multi = ref_f.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            let sig_multi = sig_f.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            let ref_avg: Vec<f64> = (0..ref_multi.ncols().get())
                .map(|i| {
                    ref_multi.column(i).iter().fold(0.0, |acc, x| acc + *x)
                        / ref_multi.nrows().get() as f64
                })
                .collect();
            let sig_avg: Vec<f64> = (0..sig_multi.ncols().get())
                .map(|i| {
                    sig_multi.column(i).iter().fold(0.0, |acc, x| acc + *x)
                        / sig_multi.nrows().get() as f64
                })
                .collect();
            (ref_avg, sig_avg)
        }
        _ => {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "audio_format",
                "Signals must have the same channel configuration",
            )));
        }
    };
    // Search lags up to half the shorter buffer.
    let max_offset = ref_data.len().min(sig_data.len()) / 2;
    let mut best_offset = 0;
    let mut best_correlation = f64::NEG_INFINITY;
    for offset in 0..max_offset {
        // Slide the reference forward and compare it against the START of the
        // signal: if `signal` leads `reference` by `offset` samples, these
        // two windows line up. (The previous code compared ref[offset..] with
        // sig[offset..], which keeps the relative lag constant for every
        // offset and therefore could never detect a shift.)
        let end = (ref_data.len() - offset).min(sig_data.len());
        let correlation = if end > 0 {
            correlation_1d_slice(&ref_data[offset..offset + end], &sig_data[..end])?
        } else {
            0.0
        };
        if correlation > best_correlation {
            best_correlation = correlation;
            best_offset = offset;
        }
    }
    // Apply the detected lag: pad the front of the signal with default
    // (silent) samples and drop the same number from the tail, so the overall
    // length is preserved.
    let aligned_signal = if best_offset > 0 {
        if let Some(mono) = signal.as_mono() {
            let mut aligned_data = vec![T::default(); best_offset];
            aligned_data.extend_from_slice(
                &mono.as_slice().ok_or_else(|| {
                    AudioSampleError::Layout(LayoutError::NonContiguous {
                        operation: "signal alignment".to_string(),
                        layout_type: "non-contiguous mono samples".to_string(),
                    })
                })?[..mono.len().get() - best_offset],
            );
            let aligned_array = Array1::from_vec(aligned_data);
            AudioSamples::new_mono(aligned_array, sample_rate)?
        } else {
            let multi = signal.as_multi_channel().ok_or_else(|| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "audio_format",
                    "Must be multi-channel audio",
                ))
            })?;
            let mut aligned_data = Vec::new();
            for i in 0..multi.nrows().get() {
                // Same padding/truncation, applied per channel.
                let mut row = vec![T::default(); best_offset];
                row.extend_from_slice(
                    &multi.row(i).as_slice().ok_or_else(|| {
                        AudioSampleError::Layout(LayoutError::NonContiguous {
                            operation: "signal alignment".to_string(),
                            layout_type: "non-contiguous multi-channel samples".to_string(),
                        })
                    })?[..multi.ncols().get() - best_offset],
                );
                aligned_data.push(row);
            }
            let aligned_array = ndarray::Array2::from_shape_vec(
                (aligned_data.len(), aligned_data[0].len()),
                aligned_data.into_iter().flatten().collect(),
            )
            .map_err(|e| {
                AudioSampleError::Parameter(ParameterError::invalid_value(
                    "array_shape",
                    format!("Array shape error: {e}"),
                ))
            })?;
            AudioSamples::new_multi_channel(aligned_array, sample_rate)?
        }
    } else {
        // No shift detected: hand back an owned, unchanged copy.
        signal.clone().into_owned()
    };
    Ok((aligned_signal, best_offset))
}
/// Correlates two 1-D array views by delegating to the slice-based helper.
///
/// # Errors
/// Returns a `Layout` error when either view is not backed by contiguous
/// memory, or whatever error `correlation_1d_slice` reports.
fn correlation_1d(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> AudioSampleResult<f64> {
    // One error constructor shared by both views; it captures nothing, so the
    // closure is `Copy` and can be passed to `ok_or_else` twice.
    let non_contiguous = || {
        AudioSampleError::Layout(LayoutError::NonContiguous {
            operation: "correlation calculation".to_string(),
            layout_type: "non-contiguous mono samples".to_string(),
        })
    };
    let a_slice = a.as_slice().ok_or_else(non_contiguous)?;
    let b_slice = b.as_slice().ok_or_else(non_contiguous)?;
    correlation_1d_slice(a_slice, b_slice)
}
/// Pearson correlation coefficient of two equal-length `f64` slices.
///
/// Returns `0.0` for empty slices and for constant (zero-variance) inputs,
/// where the coefficient is mathematically undefined.
///
/// # Errors
/// Returns a `Parameter` error when the slices differ in length.
fn correlation_1d_slice(a: &[f64], b: &[f64]) -> AudioSampleResult<f64> {
    if a.len() != b.len() {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "array_length",
            "Arrays must have the same length for correlation",
        )));
    }
    let n = a.len();
    // Guard against 0/0 -> NaN when both slices are empty; report "no
    // correlation" instead.
    if n == 0 {
        return Ok(0.0);
    }
    let mean_a = a.iter().sum::<f64>() / n as f64;
    let mean_b = b.iter().sum::<f64>() / n as f64;
    // Single pass accumulating covariance and both variances.
    let mut num = 0.0;
    let mut den_a = 0.0;
    let mut den_b = 0.0;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let dx = x - mean_a;
        let dy = y - mean_b;
        num += dx * dy;
        den_a += dx * dx;
        den_b += dy * dy;
    }
    let denominator = (den_a * den_b).sqrt();
    // A constant signal has zero variance; report zero correlation rather
    // than dividing by zero.
    if denominator == 0.0 {
        Ok(0.0)
    } else {
        Ok(num / denominator)
    }
}
/// Mean squared error between two equal-length 1-D array views.
///
/// Returns `0.0` for empty inputs (previously this produced NaN via 0/0).
///
/// # Errors
/// Returns a `Parameter` error when the views differ in length.
fn mse_1d(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> AudioSampleResult<f64> {
    if a.len() != b.len() {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "array_length",
            "Arrays must have the same length for MSE",
        )));
    }
    let n = a.len();
    // Guard against 0/0 -> NaN for empty input.
    if n == 0 {
        return Ok(0.0);
    }
    let sum_squared_diff: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum();
    Ok(sum_squared_diff / n as f64)
}
/// Mean squared error between two equal-length `f64` slices.
///
/// Returns `0.0` for empty inputs (previously this produced NaN via 0/0).
///
/// # Errors
/// Returns a `Parameter` error when the slices differ in length.
fn mse_1d_slice(a: &[f64], b: &[f64]) -> AudioSampleResult<f64> {
    if a.len() != b.len() {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "array_length",
            "Arrays must have the same length for MSE",
        )));
    }
    let n = a.len();
    // Guard against 0/0 -> NaN for empty input.
    if n == 0 {
        return Ok(0.0);
    }
    // Keep the squared-difference sum and the final division as separate
    // steps so each expression matches its name.
    let sum_squared_diff: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum();
    Ok(sum_squared_diff / n as f64)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample_rate;
    use approx_eq::assert_approx_eq;
    use non_empty_slice::non_empty_vec;

    /// A signal correlated with an exact copy of itself yields +1.
    #[test]
    fn test_correlation_identical_signals() {
        let samples: non_empty_slice::NonEmptyVec<f64> = non_empty_vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let first: AudioSamples<'static, f64> =
            AudioSamples::from_mono_vec::<f64>(samples.clone(), sample_rate!(44100));
        let second: AudioSamples<'static, f64> =
            AudioSamples::from_mono_vec::<f64>(samples, sample_rate!(44100));
        let result: f64 = correlation(&first, &second).unwrap();
        assert_approx_eq!(result, 1.0, 1e-10);
    }

    /// Negating every sample flips the correlation to -1.
    #[test]
    fn test_correlation_opposite_signals() {
        let positive: AudioSamples<'static, f64> = AudioSamples::from_mono_vec::<f64>(
            non_empty_vec![1.0f64, 2.0, 3.0, 4.0, 5.0],
            sample_rate!(44100),
        );
        let negated = AudioSamples::from_mono_vec::<f64>(
            non_empty_vec![-1.0f64, -2.0, -3.0, -4.0, -5.0],
            sample_rate!(44100),
        );
        let result: f64 = correlation(&positive, &negated).unwrap();
        assert_approx_eq!(result, -1.0, 1e-10);
    }

    /// Identical signals have zero mean squared error.
    #[test]
    fn test_mse_identical_signals() {
        let left: AudioSamples<'static, f64> = AudioSamples::from_mono_vec::<f64>(
            non_empty_vec![1.0f64, 2.0, 3.0, 4.0, 5.0],
            sample_rate!(44100),
        );
        let right = AudioSamples::from_mono_vec::<f64>(
            non_empty_vec![1.0f64, 2.0, 3.0, 4.0, 5.0],
            sample_rate!(44100),
        );
        let result = mse(&left, &right).unwrap();
        assert_approx_eq!(result, 0.0_f64, 1e-10);
    }

    /// A loud signal over quiet noise gives a positive SNR in dB.
    #[test]
    fn test_snr_calculation() {
        let loud: AudioSamples<'static, f64> = AudioSamples::from_mono_vec::<f64>(
            non_empty_vec![1.0f64, 2.0, 3.0, 4.0, 5.0],
            sample_rate!(44100),
        );
        let quiet = AudioSamples::from_mono_vec::<f64>(
            non_empty_vec![0.1f64, 0.2, 0.1, 0.2, 0.1],
            sample_rate!(44100),
        );
        let db: f64 = snr(&loud, &quiet).unwrap();
        assert!(db > 0.0_f64);
    }

    /// Identical signals need no shifting: offset is zero, length unchanged.
    #[test]
    fn test_align_signals_no_offset() {
        let samples = non_empty_vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let reference = AudioSamples::from_mono_vec::<f64>(samples.clone(), sample_rate!(44100));
        let signal = AudioSamples::from_mono_vec::<f64>(samples, sample_rate!(44100));
        let (aligned, offset) = align_signals::<f64>(&reference, &signal).unwrap();
        assert_eq!(offset, 0);
        assert_eq!(aligned.samples_per_channel(), signal.samples_per_channel());
    }
}