use smol_str::format_smolstr;
use crate::{
Array, Error, Result,
error::{
AllocFailurePayload, ArithmeticOverflowPayload, CapExceededPayload, DtypeMismatchPayload,
EmptyInputPayload, InvariantViolationPayload, LengthMismatchPayload, NonFiniteScalarPayload,
OutOfRangePayload, RankMismatchPayload, UnknownEnumValuePayload,
},
ops::{
self,
fft::{self, FftNorm},
},
};
/// Scale factor of the mel mapping `MEL_HZ_DIV * log10(1 + hz / MEL_HZ_BREAK)`.
const MEL_HZ_DIV: f32 = 2595.0;
/// Frequency (Hz) that `hz` is divided by inside the mel mapping's logarithm.
const MEL_HZ_BREAK: f32 = 700.0;
/// Base raised to `mel / MEL_HZ_DIV` when mapping mel back to Hz.
const MEL_LOG_BASE: f32 = 10.0;
/// Floor value returned by `LogFloor::Whisper`.
const LOG_FLOOR_WHISPER: f32 = 1e-10;
/// Floor value returned by `LogFloor::Kaldi`.
const LOG_FLOOR_KALDI: f32 = 1e-8;
/// Largest `num_frames * n_fft` scatter element count the inverse STFT accepts.
const MAX_OLA_WORK: usize = 64 * 1024 * 1024;
/// Largest `num_frames * n_fft` (and output element count) the forward STFT accepts.
const MAX_STFT_WORK: usize = 64 * 1024 * 1024;
/// Window-sum threshold: overlap-add samples whose summed squared window is
/// not above this value are treated as having no window coverage.
const COVERAGE_EPS: f32 = 1e-10;
/// Where a window shorter than `n_fft` is placed inside the FFT frame.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Default, derive_more::Display, derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
pub enum WindowPad {
    /// The window starts the frame; all zero padding follows it (default).
    #[default]
    Right,
    /// Zero padding is split around the window; an odd remainder goes after it.
    Center,
}

impl WindowPad {
    /// Lowercase name of the variant; this is also the `Display` output.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            WindowPad::Right => "right",
            WindowPad::Center => "center",
        }
    }
}
/// Signal padding strategy applied when an STFT is centered.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Default, derive_more::Display, derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
pub enum PadMode {
    /// Mirror the signal around its edges (the only mode, and the default).
    #[default]
    Reflect,
}

impl PadMode {
    /// Lowercase name of the variant; this is also the `Display` output.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            PadMode::Reflect => "reflect",
        }
    }
}
/// Framing options consumed by `stft_with_config`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StftConfig {
    /// Whether the signal is padded by `n_fft / 2` on both sides before framing.
    center: bool,
    /// Padding strategy used when `center` is set.
    pad_mode: PadMode,
}

impl StftConfig {
    /// Builds a configuration from its two options.
    pub const fn new(center: bool, pad_mode: PadMode) -> Self {
        StftConfig { center, pad_mode }
    }

    /// Whether centered framing is enabled.
    #[inline(always)]
    pub fn center(&self) -> bool {
        self.center
    }

    /// Padding strategy used for centered framing.
    #[inline(always)]
    pub fn pad_mode(&self) -> PadMode {
        self.pad_mode
    }
}

impl Default for StftConfig {
    /// Centered framing with reflect padding.
    fn default() -> Self {
        StftConfig {
            center: true,
            pad_mode: PadMode::Reflect,
        }
    }
}
#[derive(Debug)]
pub struct Spectrum {
data: Array,
n_fft: usize,
hop_length: usize,
win_length: usize,
window_pad: WindowPad,
center: bool,
}
impl Spectrum {
#[inline(always)]
pub fn data_ref(&self) -> &Array {
&self.data
}
#[inline(always)]
pub fn n_fft(&self) -> usize {
self.n_fft
}
#[inline(always)]
pub fn hop_length(&self) -> usize {
self.hop_length
}
#[inline(always)]
pub fn win_length(&self) -> usize {
self.win_length
}
#[inline(always)]
pub fn window_pad(&self) -> WindowPad {
self.window_pad
}
#[inline(always)]
pub fn center(&self) -> bool {
self.center
}
#[inline(always)]
pub fn num_frames(&self) -> usize {
self.data.shape()[0]
}
#[inline(always)]
pub fn n_freqs(&self) -> usize {
self.data.shape()[1]
}
pub fn from_parts(
data: Array,
n_fft: usize,
hop_length: usize,
win_length: usize,
window_pad: WindowPad,
center: bool,
) -> Result<Spectrum> {
let shape = data.shape();
if shape.len() != 2 {
let rank = shape.len() as u32;
return Err(Error::RankMismatch(RankMismatchPayload::new(
"Spectrum::from_parts: data must be 2-D (num_frames, n_freqs)",
rank,
shape,
)));
}
if n_fft == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"Spectrum::from_parts: n_fft",
"must be > 0",
)));
}
if !n_fft.is_multiple_of(2) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Spectrum::from_parts: n_fft",
"must be even (odd n_fft is unsupported because the one-sided spectrum \
cannot be inverted unambiguously)",
format_smolstr!("{n_fft}"),
)));
}
if hop_length == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"Spectrum::from_parts: hop_length",
"must be > 0",
)));
}
if win_length == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"Spectrum::from_parts: win_length",
"must be > 0",
)));
}
if win_length > n_fft {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Spectrum::from_parts: win_length (the window cannot exceed the irfft frame)",
"must be <= n_fft",
format_smolstr!("win_length={win_length}, n_fft={n_fft}"),
)));
}
let num_frames = shape[0];
if num_frames == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"Spectrum::from_parts: num_frames",
"must be > 0",
)));
}
let n_freqs = shape[1];
let expected_n_freqs = n_fft / 2 + 1;
if n_freqs != expected_n_freqs {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"Spectrum::from_parts: data last dim must equal n_fft/2 + 1 \
(the bin count must match the declared n_fft)",
expected_n_freqs,
n_freqs,
)));
}
if data.dtype()? != crate::Dtype::Complex64 {
return Err(Error::DtypeMismatch(DtypeMismatchPayload::new(
crate::Dtype::Complex64,
data.dtype()?,
)));
}
Ok(Spectrum {
data,
n_fft,
hop_length,
win_length,
window_pad,
center,
})
}
}
/// Zero-pads a 1-D window of `win_length` samples out to `n_fft` samples.
///
/// `caller` is used as the context string of every error raised here.
/// With `WindowPad::Right` all padding is appended; with `WindowPad::Center`
/// `total / 2` zeros go in front and the rest behind. When the window already
/// spans `n_fft` a clone is returned.
///
/// # Errors
///
/// `RankMismatch` for a non-1-D window, `LengthMismatch` when the window
/// length differs from `win_length`, `OutOfRange` when `win_length > n_fft`
/// or a pad amount does not fit in `i32`.
fn place_window(
    caller: &'static str,
    w: &Array,
    win_length: usize,
    n_fft: usize,
    window_pad: WindowPad,
) -> Result<Array> {
    if w.ndim() != 1 {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            caller,
            w.ndim() as u32,
            w.shape(),
        )));
    }

    let w_len = w.shape()[0];
    if w_len != win_length {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            caller, win_length, w_len,
        )));
    }

    if win_length > n_fft {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            caller,
            "win_length must be <= n_fft",
            format_smolstr!("win_length={win_length}, n_fft={n_fft}"),
        )));
    }
    if win_length == n_fft {
        // Nothing to pad.
        return w.try_clone();
    }

    let total = n_fft - win_length;
    let low = match window_pad {
        WindowPad::Right => 0usize,
        WindowPad::Center => total / 2,
    };
    let high = total - low;

    // Scalar zero used as the fill value.
    let pad_value = Array::zeros::<f32>(&[0i32; 0])?;

    let pad_to_i32 = |amount: usize, requirement: &'static str| {
        i32::try_from(amount).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                caller,
                requirement,
                format_smolstr!("{amount}"),
            ))
        })
    };
    let low_i32 = pad_to_i32(
        low,
        "window pad-low must fit in i32 (i32::MAX = 2147483647)",
    )?;
    let high_i32 = pad_to_i32(
        high,
        "window pad-high must fit in i32 (i32::MAX = 2147483647)",
    )?;

    ops::shape::pad(
        w,
        &[0_i32],
        &[low_i32],
        &[high_i32],
        &pad_value,
        c"constant",
    )
}
/// Builds the Hann window of `win_length` samples and pads it to `n_fft`
/// according to `window_pad`.
fn frame_window(win_length: usize, n_fft: usize, window_pad: WindowPad) -> Result<Array> {
    hann_window(win_length)
        .and_then(|hann| place_window("frame_window", &hann, win_length, n_fft, window_pad))
}
/// Selectable floor value; see [`LogFloor::value`] for the number each
/// variant resolves to.
#[derive(Debug, Clone, Copy, PartialEq, Default, derive_more::IsVariant)]
pub enum LogFloor {
    /// Resolves to `LOG_FLOOR_WHISPER` (default).
    #[default]
    Whisper,
    /// Resolves to `LOG_FLOOR_KALDI`.
    Kaldi,
    /// Caller-supplied floor; sanitised by [`LogFloor::value`].
    Custom(f32),
}

impl LogFloor {
    /// Resolves the floor to a concrete `f32`.
    ///
    /// A `Custom` floor that is non-finite or not strictly positive is
    /// replaced by `f32::MIN_POSITIVE`.
    pub fn value(self) -> f32 {
        match self {
            Self::Whisper => LOG_FLOOR_WHISPER,
            Self::Kaldi => LOG_FLOOR_KALDI,
            Self::Custom(floor) if floor.is_finite() && floor > 0.0 => floor,
            Self::Custom(_) => f32::MIN_POSITIVE,
        }
    }
}
/// Shared implementation behind the public window constructors.
///
/// `name` is the context string used in errors. The coefficients come from
/// `crate::simd::audio::window::symmetric_window` and are wrapped into a 1-D
/// `f32` array of length `n`.
///
/// # Errors
///
/// `OutOfRange` when `n < 2` or `n` does not fit in `i32`; `CapExceeded` when
/// `n` is above `MAX_DECODED_SAMPLES`; plus anything the coefficient
/// generator or array constructor returns.
fn symmetric_window(
    name: &'static str,
    n: usize,
    kind: crate::simd::audio::window::SymWindowKind,
) -> Result<Array> {
    let sample_cap = crate::audio::io::MAX_DECODED_SAMPLES;

    if n < 2 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            name,
            "n must be >= 2",
            format_smolstr!("{n}"),
        )));
    }
    if n > sample_cap {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            name,
            "MAX_DECODED_SAMPLES",
            sample_cap as u64,
            n as u64,
        )));
    }

    let n_i32 = match i32::try_from(n) {
        Ok(value) => value,
        Err(_) => {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                name,
                "n must fit in i32 (i32::MAX = 2147483647)",
                format_smolstr!("{n}"),
            )));
        }
    };

    let coefficients = crate::simd::audio::window::symmetric_window(kind, n)?;
    Array::from_slice::<f32>(&coefficients, &[n_i32])
}
/// Returns a symmetric Hann window of `n` samples as a 1-D `f32` array.
///
/// # Errors
///
/// Propagates the errors of `symmetric_window` (e.g. `n < 2`).
pub fn hann_window(n: usize) -> Result<Array> {
    use crate::simd::audio::window::SymWindowKind;

    symmetric_window("hann_window", n, SymWindowKind::Hann)
}
/// Returns a symmetric Hamming window of `n` samples as a 1-D `f32` array.
///
/// # Errors
///
/// Propagates the errors of `symmetric_window` (e.g. `n < 2`).
pub fn hamming_window(n: usize) -> Result<Array> {
    use crate::simd::audio::window::SymWindowKind;

    symmetric_window("hamming_window", n, SymWindowKind::Hamming)
}
/// Returns a symmetric Blackman window of `n` samples as a 1-D `f32` array.
///
/// # Errors
///
/// Propagates the errors of `symmetric_window` (e.g. `n < 2`).
pub fn blackman_window(n: usize) -> Result<Array> {
    use crate::simd::audio::window::SymWindowKind;

    symmetric_window("blackman_window", n, SymWindowKind::Blackman)
}
/// Returns a symmetric Bartlett window of `n` samples as a 1-D `f32` array.
///
/// # Errors
///
/// Propagates the errors of `symmetric_window` (e.g. `n < 2`).
pub fn bartlett_window(n: usize) -> Result<Array> {
    use crate::simd::audio::window::SymWindowKind;

    symmetric_window("bartlett_window", n, SymWindowKind::Bartlett)
}
pub fn window_from_name(name: &str, n: usize) -> Result<Array> {
match name.to_ascii_lowercase().as_str() {
"hann" | "hanning" => hann_window(n),
"hamming" => hamming_window(n),
"blackman" => blackman_window(n),
"bartlett" => bartlett_window(n),
other => Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
"window_from_name",
other,
&["hann", "hanning", "hamming", "blackman", "bartlett"],
))),
}
}
/// Reflect-pads a 1-D signal by `padding` samples on each side.
///
/// The edge sample itself is not repeated: the prefix is
/// `samples[padding..=1]` reversed and the suffix is the `padding` samples
/// preceding the last one, reversed. `padding == 0` returns a clone without
/// further validation.
///
/// # Errors
///
/// `RankMismatch` for non-1-D input; `OutOfRange` when the signal has fewer
/// than `padding + 1` samples or a length/padding does not fit in `i32`.
fn reflect_pad_1d(samples: &Array, padding: usize) -> Result<Array> {
    if padding == 0 {
        return samples.try_clone();
    }

    let shape = samples.shape();
    if shape.len() != 1 {
        let rank = shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "reflect_pad_1d: expected 1-D input",
            rank,
            shape,
        )));
    }

    let len = shape[0];
    // `len <= padding` is `len < padding + 1` without the addition, which
    // would overflow for `padding == usize::MAX`.
    if len <= padding {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "reflect_pad_1d: samples len for reflect padding",
            "must be >= padding + 1",
            format_smolstr!("len={len}, padding={padding}"),
        )));
    }

    let fits_i32 = |field: &'static str, value: usize| {
        i32::try_from(value).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                field,
                "must fit in i32 (i32::MAX = 2147483647)",
                format_smolstr!("{value}"),
            ))
        })
    };
    let p_i32 = fits_i32("reflect_pad_1d: padding", padding)?;
    let len_i32 = fits_i32("reflect_pad_1d: samples len", len)?;

    // samples[padding], samples[padding - 1], ..., samples[1]
    let prefix = ops::indexing::slice(samples, &[p_i32], &[0], &[-1])?;

    let suffix_start = len_i32 - 2;
    // `padding < len` holds here, so `padding + 1` cannot overflow.
    let suffix_stop = if padding + 1 < len {
        len_i32 - p_i32 - 2
    } else {
        // The reversed slice must run through index 0. A plain stop of -1
        // would not express that, so a sentinel below `-len` is used.
        let sentinel_i64 = -(i64::from(len_i32) + 1);
        i32::try_from(sentinel_i64).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "reflect_pad_1d: reflect-pad sentinel `-(len + 1)` \
                 (len == padding + 1, near i32::MAX boundary)",
                "must fit in i32 (i32::MAX = 2147483647)",
                format_smolstr!("sentinel={sentinel_i64}, len={len}"),
            ))
        })?
    };
    let suffix = ops::indexing::slice(samples, &[suffix_start], &[suffix_stop], &[-1])?;

    ops::shape::concatenate(&[&prefix, samples, &suffix], 0)
}
/// Short-time Fourier transform with the default configuration
/// (centered framing, reflect padding).
///
/// Thin wrapper over `stft_with_config`; see it for validation and errors.
pub fn stft(
    samples: &Array,
    n_fft: usize,
    hop_length: usize,
    win_length: Option<usize>,
    window_pad: WindowPad,
) -> Result<Spectrum> {
    let cfg = StftConfig::default();
    stft_with_config(samples, n_fft, hop_length, win_length, window_pad, &cfg)
}
/// Short-time Fourier transform without centering: frames start at sample 0
/// and the input is not padded.
///
/// Thin wrapper over `stft_with_config`; see it for validation and errors.
pub fn stft_aligned(
    samples: &Array,
    n_fft: usize,
    hop_length: usize,
    win_length: Option<usize>,
    window_pad: WindowPad,
) -> Result<Spectrum> {
    let cfg = StftConfig::new(false, PadMode::Reflect);
    stft_with_config(samples, n_fft, hop_length, win_length, window_pad, &cfg)
}
pub fn stft_with_config(
samples: &Array,
n_fft: usize,
hop_length: usize,
win_length: Option<usize>,
window_pad: WindowPad,
cfg: &StftConfig,
) -> Result<Spectrum> {
if n_fft == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"stft: n_fft",
"must be > 0",
)));
}
if !n_fft.is_multiple_of(2) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"stft: n_fft",
"must be even (odd n_fft is unsupported because the one-sided spectrum \
cannot be inverted unambiguously)",
format_smolstr!("{n_fft}"),
)));
}
if hop_length == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"stft: hop_length",
"must be > 0",
)));
}
let win_length = win_length.unwrap_or(n_fft);
if win_length == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"stft: win_length",
"must be > 0",
)));
}
if win_length > n_fft {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"stft: win_length",
"must be <= n_fft (unsupported)",
format_smolstr!("win_length={win_length}, n_fft={n_fft}"),
)));
}
let shape = samples.shape();
if shape.len() != 1 {
let rank = shape.len() as u32;
return Err(Error::RankMismatch(RankMismatchPayload::new(
"stft: expected 1-D input",
rank,
shape,
)));
}
let samples_len = shape[0];
if samples_len > crate::audio::io::MAX_DECODED_SAMPLES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"stft: input sample count exceeds sample budget \
(would force a reflect-pad allocation proportional to the input)",
"MAX_DECODED_SAMPLES",
crate::audio::io::MAX_DECODED_SAMPLES as u64,
samples_len as u64,
)));
}
if cfg.center() {
let padded_len_budget = samples_len.checked_add(n_fft).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"stft: padded length samples_len + n_fft",
"usize",
[("samples_len", samples_len as u64), ("n_fft", n_fft as u64)],
))
})?;
if padded_len_budget > crate::audio::io::MAX_DECODED_SAMPLES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"stft: padded length (= samples_len + n_fft) exceeds sample budget \
(would force a reflect-pad allocation proportional to the input)",
"MAX_DECODED_SAMPLES",
crate::audio::io::MAX_DECODED_SAMPLES as u64,
padded_len_budget as u64,
)));
}
}
let padded = if cfg.center() {
match cfg.pad_mode() {
PadMode::Reflect => reflect_pad_1d(samples, n_fft / 2)?,
}
} else {
samples.try_clone()?
};
let padded_len = padded.shape()[0];
if padded_len < n_fft {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"stft: padded_len (input is too short for n_fft)",
"must be >= n_fft",
format_smolstr!("padded_len={padded_len}, n_fft={n_fft}"),
)));
}
let num_frames = 1 + (padded_len - n_fft) / hop_length;
if num_frames == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"stft: num_frames (input is too short for n_fft/hop_length)",
"must be >= 1",
format_smolstr!("num_frames=0, n_fft={n_fft}, hop_length={hop_length}"),
)));
}
let last_element_index = (num_frames - 1)
.checked_mul(hop_length)
.and_then(|v| v.checked_add(n_fft))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"stft: reachable element range ((num_frames - 1) * hop_length + n_fft)",
"usize",
[
("num_frames", num_frames as u64),
("hop_length", hop_length as u64),
("n_fft", n_fft as u64),
],
))
})?;
if last_element_index > padded_len {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"stft: derived frame reach (internal invariant violated)",
"must be <= padded_len",
format_smolstr!(
"last_element_index={last_element_index}, padded_len={padded_len}, \
num_frames={num_frames}, hop_length={hop_length}, n_fft={n_fft}"
),
)));
}
let frame_work = num_frames.checked_mul(n_fft).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"stft: frame work count num_frames * n_fft",
"usize",
[("num_frames", num_frames as u64), ("n_fft", n_fft as u64)],
))
})?;
if frame_work > MAX_STFT_WORK {
return Err(Error::CapExceeded(CapExceededPayload::new(
"stft: frame work count (= num_frames * n_fft) exceeds work cap",
"MAX_STFT_WORK",
MAX_STFT_WORK as u64,
frame_work as u64,
)));
}
let out_elems = num_frames.checked_mul(n_fft / 2 + 1).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"stft: output element count num_frames * (n_fft/2 + 1)",
"usize",
[("num_frames", num_frames as u64), ("n_fft", n_fft as u64)],
))
})?;
if out_elems > MAX_STFT_WORK {
return Err(Error::CapExceeded(CapExceededPayload::new(
"stft: output element count (= num_frames * (n_fft/2 + 1)) exceeds work cap",
"MAX_STFT_WORK",
MAX_STFT_WORK as u64,
out_elems as u64,
)));
}
let window = frame_window(win_length, n_fft, window_pad)?;
let num_frames_i32 = i32::try_from(num_frames).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"stft: num_frames",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{num_frames}"),
))
})?;
let n_fft_i32 = i32::try_from(n_fft).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"stft: n_fft",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{n_fft}"),
))
})?;
let hop_i64 = i64::try_from(hop_length).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"stft: hop_length",
"must fit in i64 (i64::MAX = 9223372036854775807)",
format_smolstr!("{hop_length}"),
))
})?;
let shape: &[i32] = &[num_frames_i32, n_fft_i32];
let frames = unsafe { ops::shape::as_strided(&padded, &shape, &[hop_i64, 1], 0)? };
let windowed = ops::arithmetic::multiply(&frames, &window)?;
let data = fft::rfft(&windowed, n_fft_i32, 1, FftNorm::Backward)?;
Ok(Spectrum {
data,
n_fft,
hop_length,
win_length,
window_pad,
center: cfg.center(),
})
}
pub fn istft(spectrum: &Spectrum, length: Option<usize>) -> Result<Array> {
let x = spectrum.data_ref();
let n_fft = spectrum.n_fft();
let hop_length = spectrum.hop_length();
let win_length = spectrum.win_length();
let window_pad = spectrum.window_pad();
let center = spectrum.center();
let shape = x.shape();
if shape.len() != 2 {
let rank = shape.len() as u32;
return Err(Error::RankMismatch(RankMismatchPayload::new(
"istft: expected 2-D (num_frames, n_freqs) spectrum data",
rank,
shape,
)));
}
let num_frames = shape[0];
if matches!(window_pad, WindowPad::Right) && win_length != n_fft {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"istft: win_length with WindowPad::Right \
(right-pad short-window inversion is not a faithful inverse — \
use WindowPad::Center for short-window inversion)",
"must equal n_fft",
format_smolstr!("win_length={win_length}, n_fft={n_fft}"),
)));
}
let frame_width = n_fft;
let t = (num_frames - 1)
.checked_mul(hop_length)
.and_then(|v| v.checked_add(frame_width))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"istft: OLA length ((num_frames - 1) * hop + n_fft)",
"usize",
[
("num_frames", num_frames as u64),
("hop_length", hop_length as u64),
("n_fft", n_fft as u64),
],
))
})?;
if t > crate::audio::io::MAX_DECODED_SAMPLES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"istft: OLA length exceeds cap",
"MAX_DECODED_SAMPLES",
crate::audio::io::MAX_DECODED_SAMPLES as u64,
t as u64,
)));
}
let idx_len = num_frames.checked_mul(frame_width).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"istft: scatter work count num_frames * n_fft",
"usize",
[("num_frames", num_frames as u64), ("n_fft", n_fft as u64)],
))
})?;
if idx_len > MAX_OLA_WORK {
return Err(Error::CapExceeded(CapExceededPayload::new(
"istft: scatter work count (= num_frames * n_fft) exceeds work cap",
"MAX_OLA_WORK",
MAX_OLA_WORK as u64,
idx_len as u64,
)));
}
let idx_len_i32 = i32::try_from(idx_len).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"istft: scatter work count",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{idx_len}"),
))
})?;
let t_i32 = i32::try_from(t).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"istft: OLA length",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{t}"),
))
})?;
let n_fft_i32 = i32::try_from(n_fft).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"istft: n_fft",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{n_fft}"),
))
})?;
let window = frame_window(win_length, n_fft, window_pad)?;
let frames_time = fft::irfft(x, n_fft_i32, 1, FftNorm::Backward)?;
let windowed = ops::arithmetic::multiply(&frames_time, &window)?;
let updates_reconstructed = ops::shape::flatten(&windowed, 0, -1)?;
let window_norm = ops::arithmetic::multiply(&window, &window)?;
let window_norm_row = ops::shape::reshape(&window_norm, &(1usize, frame_width))?;
let window_norm_tiled = ops::shape::broadcast_to(&window_norm_row, &(num_frames, frame_width))?;
let updates_window = ops::shape::flatten(&window_norm_tiled, 0, -1)?;
let mut idx_buf: Vec<i32> = Vec::new();
idx_buf.try_reserve_exact(idx_len).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"istft: index reservation",
"i32 elements",
idx_len as u64,
e,
))
})?;
let frame_width_i32 = i32::try_from(frame_width).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"istft: frame_width (n_fft)",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{frame_width}"),
))
})?;
for m in 0..num_frames {
let off = (m * hop_length) as i32;
for j in 0..frame_width_i32 {
idx_buf.push(off + j);
}
}
let indices = Array::from_slice::<i32>(&idx_buf, &[idx_len_i32])?;
let zeros_recon = Array::zeros::<f32>(&[t_i32])?;
let zeros_wsum = Array::zeros::<f32>(&[t_i32])?;
let reconstructed =
ops::indexing::scatter_add_axis(&zeros_recon, &indices, &updates_reconstructed, 0)?;
let window_sum = ops::indexing::scatter_add_axis(&zeros_wsum, &indices, &updates_window, 0)?;
let pad = n_fft / 2;
let (start_usize, stop_usize) = match (center, length) {
(true, Some(len)) => {
let end = pad.checked_add(len).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"istft: center offset pad + length",
"usize",
[("pad", pad as u64), ("len", len as u64)],
))
})?;
if end > t {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"istft: center offset pad + length",
"must be <= reconstruction length t",
format_smolstr!("pad={pad}, len={len}, end={end}, t={t}"),
)));
}
(pad, end)
}
(true, None) => {
(pad, t - pad)
}
(false, Some(len)) => {
if len > t {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"istft: requested length",
"must be <= reconstruction length t",
format_smolstr!("len={len}, t={t}"),
)));
}
(0usize, len)
}
(false, None) => (0usize, t),
};
let start_i32 = i32::try_from(start_usize).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"istft: trim start",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{start_usize}"),
))
})?;
let stop_i32 = i32::try_from(stop_usize).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"istft: trim stop",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{stop_usize}"),
))
})?;
if start_usize < stop_usize {
let region_wsum = ops::indexing::slice(&window_sum, &[start_i32], &[stop_i32], &[1])?;
let mut region_min = ops::reduction::min(®ion_wsum, false)?;
let min_wsum = region_min.item::<f32>()?;
if min_wsum <= COVERAGE_EPS || min_wsum.is_nan() {
let mut min_idx_arr = ops::misc::argmin(®ion_wsum, None, false)?;
let local_idx = min_idx_arr.item::<u32>()? as usize;
let global_idx = start_usize + local_idx;
return Err(Error::OutOfRange(OutOfRangePayload::new(
"istft: requested output sample window-sum (received no window coverage \
in the overlap-add and is not recoverable; \
the requested region (e.g. center=false head/tail) includes a zero-coverage sample — \
adjust length/center or the window)",
"must be > COVERAGE_EPS (1e-10) and finite",
format_smolstr!(
"global_idx={global_idx}, local_idx={local_idx}, min_wsum={min_wsum:.3e}, \
n_fft={n_fft}, win_length={win_length}, hop={hop_length}, window_pad={window_pad:?}"
),
)));
}
}
let threshold = Array::full::<f32>(&[0i32; 0], COVERAGE_EPS)?;
let mask = ops::comparison::greater(&window_sum, &threshold)?;
let normalized_recon = ops::arithmetic::divide(&reconstructed, &window_sum)?;
let reconstructed = ops::logical::select(&mask, &normalized_recon, &reconstructed)?;
ops::indexing::slice(&reconstructed, &[start_i32], &[stop_i32], &[1])
}
/// Cache key for the overlap-add scatter indices, which depend only on the
/// frame count, the frame width and the hop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct PositionKey {
num_frames: usize,
frame_width: usize,
hop_length: usize,
}
/// Cache key for the overlap-added squared-window buffer, which additionally
/// depends on the window length and its placement inside the frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct NormKey {
n_fft: usize,
hop_length: usize,
win_length: usize,
window_pad: WindowPad,
num_frames: usize,
}
/// Inverse-STFT helper that memoises the data-independent parts of the
/// overlap-add: the scatter indices and the overlap-added squared window.
///
/// Entries are keyed by the framing parameters and are never evicted; call
/// [`ISTFTCache::clear`] to release them.
#[derive(Debug)]
pub struct ISTFTCache {
    /// Scatter indices per `(num_frames, frame_width, hop_length)`.
    position_cache: std::collections::HashMap<PositionKey, Array>,
    /// Overlap-added squared window per full parameter set.
    norm_buffer_cache: std::collections::HashMap<NormKey, Array>,
}

impl ISTFTCache {
    /// Creates an empty cache.
    pub fn new() -> ISTFTCache {
        ISTFTCache {
            position_cache: std::collections::HashMap::new(),
            norm_buffer_cache: std::collections::HashMap::new(),
        }
    }

    /// Total number of cached entries across both tables.
    pub fn len(&self) -> usize {
        self.position_cache.len() + self.norm_buffer_cache.len()
    }

    /// `true` when neither table holds an entry.
    pub fn is_empty(&self) -> bool {
        self.position_cache.is_empty() && self.norm_buffer_cache.is_empty()
    }

    /// Drops every cached entry.
    pub fn clear(&mut self) {
        self.position_cache.clear();
        self.norm_buffer_cache.clear();
    }

    /// Inverse short-time Fourier transform by windowed overlap-add, reusing
    /// cached scatter indices and window-sum buffers when the framing
    /// parameters have been seen before.
    ///
    /// The returned range depends on `spectrum.center()` and `length`:
    /// - centered, `Some(len)`: `[n_fft / 2, n_fft / 2 + len)`
    /// - centered, `None`: `[n_fft / 2, t - n_fft / 2)`
    /// - not centered, `Some(len)`: `[0, len)`
    /// - not centered, `None`: `[0, t)`
    ///
    /// where `t = (num_frames - 1) * hop + n_fft`.
    ///
    /// # Errors
    ///
    /// - `OutOfRange` when `WindowPad::Right` is combined with
    ///   `win_length != n_fft`, when the requested range exceeds `t`, when a
    ///   size does not fit in `i32`, or when any requested sample has a
    ///   window-sum that is NaN or not above `COVERAGE_EPS`.
    /// - `ArithmeticOverflow` / `CapExceeded` when the overlap-add length or
    ///   scatter size overflows or exceeds `MAX_DECODED_SAMPLES` /
    ///   `MAX_OLA_WORK`.
    /// - `AllocFailure` when the index buffer cannot be reserved.
    ///
    /// # Panics
    ///
    /// Only if a cache entry inserted earlier in the same call is missing,
    /// which would be an internal bug.
    pub fn istft(&mut self, spectrum: &Spectrum, length: Option<usize>) -> Result<Array> {
        let to_i32 = |field: &'static str, value: usize| {
            i32::try_from(value).map_err(|_| {
                Error::OutOfRange(OutOfRangePayload::new(
                    field,
                    "must fit in i32 (i32::MAX = 2147483647)",
                    format_smolstr!("{value}"),
                ))
            })
        };

        let x = spectrum.data_ref();
        let n_fft = spectrum.n_fft();
        let hop_length = spectrum.hop_length();
        let win_length = spectrum.win_length();
        let window_pad = spectrum.window_pad();
        let center = spectrum.center();
        let num_frames = spectrum.num_frames();

        if matches!(window_pad, WindowPad::Right) && win_length != n_fft {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "ISTFTCache::istft: win_length with WindowPad::Right \
                 (right-pad short-window inversion is not a faithful inverse — \
                 use WindowPad::Center for short-window inversion)",
                "must equal n_fft",
                format_smolstr!("win_length={win_length}, n_fft={n_fft}"),
            )));
        }

        // ---- overlap-add geometry and caps --------------------------------
        let frame_width = n_fft;
        let t = (num_frames - 1)
            .checked_mul(hop_length)
            .and_then(|span| span.checked_add(frame_width))
            .ok_or_else(|| {
                Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                    "ISTFTCache::istft: OLA length ((num_frames - 1) * hop + n_fft)",
                    "usize",
                    [
                        ("num_frames", num_frames as u64),
                        ("hop_length", hop_length as u64),
                        ("n_fft", n_fft as u64),
                    ],
                ))
            })?;
        if t > crate::audio::io::MAX_DECODED_SAMPLES {
            return Err(Error::CapExceeded(CapExceededPayload::new(
                "ISTFTCache::istft: OLA length exceeds cap",
                "MAX_DECODED_SAMPLES",
                crate::audio::io::MAX_DECODED_SAMPLES as u64,
                t as u64,
            )));
        }
        let t_i32 = to_i32("ISTFTCache::istft: OLA length", t)?;
        let n_fft_i32 = to_i32("ISTFTCache::istft: n_fft", n_fft)?;
        let idx_len = num_frames.checked_mul(frame_width).ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "ISTFTCache::istft: scatter work count num_frames * n_fft",
                "usize",
                [("num_frames", num_frames as u64), ("n_fft", n_fft as u64)],
            ))
        })?;
        if idx_len > MAX_OLA_WORK {
            return Err(Error::CapExceeded(CapExceededPayload::new(
                "ISTFTCache::istft: scatter work count (= num_frames * n_fft) exceeds work cap",
                "MAX_OLA_WORK",
                MAX_OLA_WORK as u64,
                idx_len as u64,
            )));
        }
        let idx_len_i32 = to_i32("ISTFTCache::istft: scatter work count", idx_len)?;

        // ---- cached scatter indices ---------------------------------------
        let pos_key = PositionKey {
            num_frames,
            frame_width,
            hop_length,
        };
        if let std::collections::hash_map::Entry::Vacant(slot) = self.position_cache.entry(pos_key)
        {
            let mut idx_buf: Vec<i32> = Vec::new();
            idx_buf.try_reserve_exact(idx_len).map_err(|e| {
                Error::AllocFailure(AllocFailurePayload::new(
                    "ISTFTCache::istft: index reservation",
                    "i32 elements",
                    idx_len as u64,
                    e,
                ))
            })?;
            let frame_width_i32 = to_i32("ISTFTCache::istft: frame_width (n_fft)", frame_width)?;
            for m in 0..num_frames {
                // `m * hop_length <= t`, and `t` was shown to fit in i32
                // above, so the cast cannot truncate.
                let off = (m * hop_length) as i32;
                idx_buf.extend((0..frame_width_i32).map(|j| off + j));
            }
            let indices = Array::from_slice::<i32>(&idx_buf, &[idx_len_i32])?;
            slot.insert(indices);
        }

        // ---- cached overlap-added squared window --------------------------
        let norm_key = NormKey {
            n_fft,
            hop_length,
            win_length,
            window_pad,
            num_frames,
        };
        if !self.norm_buffer_cache.contains_key(&norm_key) {
            let window = frame_window(win_length, n_fft, window_pad)?;
            let window_norm = ops::arithmetic::multiply(&window, &window)?;
            let window_norm_row = ops::shape::reshape(&window_norm, &(1usize, frame_width))?;
            let window_norm_tiled =
                ops::shape::broadcast_to(&window_norm_row, &(num_frames, frame_width))?;
            let updates_window = ops::shape::flatten(&window_norm_tiled, 0, -1)?;
            let indices = self
                .position_cache
                .get(&pos_key)
                .expect("position_cache populated for pos_key above");
            let zeros_wsum = Array::zeros::<f32>(&[t_i32])?;
            let window_sum =
                ops::indexing::scatter_add_axis(&zeros_wsum, indices, &updates_window, 0)?;
            self.norm_buffer_cache.insert(norm_key, window_sum);
        }

        // ---- data-dependent overlap-add -------------------------------------
        let frames_time = fft::irfft(x, n_fft_i32, 1, FftNorm::Backward)?;
        let window = frame_window(win_length, n_fft, window_pad)?;
        let windowed = ops::arithmetic::multiply(&frames_time, &window)?;
        let updates_reconstructed = ops::shape::flatten(&windowed, 0, -1)?;
        let indices = self
            .position_cache
            .get(&pos_key)
            .expect("position_cache populated for pos_key above");
        let window_sum = self
            .norm_buffer_cache
            .get(&norm_key)
            .expect("norm_buffer_cache populated for norm_key above");
        let zeros_recon = Array::zeros::<f32>(&[t_i32])?;
        let reconstructed =
            ops::indexing::scatter_add_axis(&zeros_recon, indices, &updates_reconstructed, 0)?;

        // ---- output range ---------------------------------------------------
        let pad = n_fft / 2;
        let (start_usize, stop_usize) = match (center, length) {
            (true, Some(len)) => {
                let end = pad.checked_add(len).ok_or_else(|| {
                    Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                        "ISTFTCache::istft: center offset pad + length",
                        "usize",
                        [("pad", pad as u64), ("len", len as u64)],
                    ))
                })?;
                if end > t {
                    return Err(Error::OutOfRange(OutOfRangePayload::new(
                        "ISTFTCache::istft: center offset pad + length",
                        "must be <= reconstruction length t",
                        format_smolstr!("pad={pad}, len={len}, end={end}, t={t}"),
                    )));
                }
                (pad, end)
            }
            (true, None) => (pad, t - pad),
            (false, Some(len)) => {
                if len > t {
                    return Err(Error::OutOfRange(OutOfRangePayload::new(
                        "ISTFTCache::istft: requested length",
                        "must be <= reconstruction length t",
                        format_smolstr!("len={len}, t={t}"),
                    )));
                }
                (0usize, len)
            }
            (false, None) => (0usize, t),
        };
        let start_i32 = to_i32("ISTFTCache::istft: trim start", start_usize)?;
        let stop_i32 = to_i32("ISTFTCache::istft: trim stop", stop_usize)?;

        // ---- coverage check on the requested range --------------------------
        if start_usize < stop_usize {
            let region_wsum = ops::indexing::slice(window_sum, &[start_i32], &[stop_i32], &[1])?;
            let mut region_min = ops::reduction::min(&region_wsum, false)?;
            let min_wsum = region_min.item::<f32>()?;
            if min_wsum <= COVERAGE_EPS || min_wsum.is_nan() {
                let mut min_idx_arr = ops::misc::argmin(&region_wsum, None, false)?;
                let local_idx = min_idx_arr.item::<u32>()? as usize;
                let global_idx = start_usize + local_idx;
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "ISTFTCache::istft: requested output sample window-sum (received no window coverage \
                     in the overlap-add and is not recoverable; \
                     the requested region (e.g. center=false head/tail) includes a zero-coverage sample — \
                     adjust length/center or the window)",
                    "must be > COVERAGE_EPS (1e-10) and finite",
                    format_smolstr!(
                        "global_idx={global_idx}, local_idx={local_idx}, min_wsum={min_wsum:.3e}, \
                         n_fft={n_fft}, win_length={win_length}, hop={hop_length}, window_pad={window_pad:?}"
                    ),
                )));
            }
        }

        // ---- normalise where covered, then trim -----------------------------
        let threshold = Array::full::<f32>(&[0i32; 0], COVERAGE_EPS)?;
        let mask = ops::comparison::greater(window_sum, &threshold)?;
        let normalized_recon = ops::arithmetic::divide(&reconstructed, window_sum)?;
        // Samples outside the mask keep their un-normalised value.
        let reconstructed = ops::logical::select(&mask, &normalized_recon, &reconstructed)?;
        ops::indexing::slice(&reconstructed, &[start_i32], &[stop_i32], &[1])
    }
}

impl Default for ISTFTCache {
    /// Same as [`ISTFTCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Converts a frequency in Hz to mel: `MEL_HZ_DIV * log10(1 + hz / MEL_HZ_BREAK)`.
#[inline]
fn hz_to_mel(hz: f32) -> f32 {
    let ratio = hz / MEL_HZ_BREAK;
    let log_term = (1.0 + ratio).log10();
    MEL_HZ_DIV * log_term
}
/// Converts mel back to Hz: `MEL_HZ_BREAK * (MEL_LOG_BASE^(mel / MEL_HZ_DIV) - 1)`.
#[inline]
fn mel_to_hz(mel: f32) -> f32 {
    let exponent = mel / MEL_HZ_DIV;
    let growth = MEL_LOG_BASE.powf(exponent);
    MEL_HZ_BREAK * (growth - 1.0)
}
/// `f64` counterpart of the Hz-to-mel mapping; the `f32` constants are
/// widened before use.
#[inline]
fn hz_to_mel_f64(hz: f64) -> f64 {
    let scale = f64::from(MEL_HZ_DIV);
    let corner = f64::from(MEL_HZ_BREAK);
    scale * (1.0 + hz / corner).log10()
}
/// `f64` counterpart of the mel-to-Hz mapping; the `f32` constants are
/// widened before use.
#[inline]
fn mel_to_hz_f64(mel: f64) -> f64 {
    let scale = f64::from(MEL_HZ_DIV);
    let corner = f64::from(MEL_HZ_BREAK);
    let base = f64::from(MEL_LOG_BASE);
    corner * (base.powf(mel / scale) - 1.0)
}
/// Numeric precision used when building a mel filter bank.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Default, derive_more::Display, derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
pub enum MelPrecision {
    /// The default path.
    #[default]
    Standard,
    /// Routes `mel_filter_bank_with` to its `f64` implementation.
    Precise,
}

impl MelPrecision {
    /// Lowercase name of the variant; this is also the `Display` output.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            MelPrecision::Standard => "standard",
            MelPrecision::Precise => "precise",
        }
    }
}
/// Builds a mel filter bank at `MelPrecision::Standard`.
///
/// Thin wrapper over `mel_filter_bank_with`; see it for validation and errors.
pub fn mel_filter_bank(
    n_mels: usize,
    n_fft: usize,
    sample_rate: u32,
    f_min: f32,
    f_max: Option<f32>,
) -> Result<Array> {
    let precision = MelPrecision::Standard;
    mel_filter_bank_with(n_mels, n_fft, sample_rate, f_min, f_max, precision)
}
pub fn mel_filter_bank_with(
n_mels: usize,
n_fft: usize,
sample_rate: u32,
f_min: f32,
f_max: Option<f32>,
precision: MelPrecision,
) -> Result<Array> {
if n_fft == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"mel_filter_bank: n_fft",
"must be > 0",
)));
}
if n_mels == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"mel_filter_bank: n_mels",
"must be > 0",
)));
}
if sample_rate == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"mel_filter_bank: sample_rate",
"must be > 0",
)));
}
let f_max = f_max.unwrap_or((sample_rate / 2) as f32);
if !(f_min >= 0.0 && f_max > f_min) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"mel_filter_bank: f_min / f_max",
"must satisfy f_min >= 0.0 and f_max > f_min",
format_smolstr!("f_min={f_min}, f_max={f_max}"),
)));
}
let n_freqs = n_fft / 2 + 1;
let n_pts = n_mels.checked_add(2).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"mel_filter_bank: n_mels + 2",
"usize",
[("n_mels", n_mels as u64)],
))
})?;
let bank_len = n_mels.checked_mul(n_freqs).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"mel_filter_bank: n_mels * n_freqs",
"usize",
[("n_mels", n_mels as u64), ("n_freqs", n_freqs as u64)],
))
})?;
let n_mels_i32 = i32::try_from(n_mels).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mel_filter_bank: n_mels",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{n_mels}"),
))
})?;
let n_freqs_i32 = i32::try_from(n_freqs).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mel_filter_bank: n_freqs",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{n_freqs}"),
))
})?;
if precision.is_precise() {
return mel_filter_bank_f64(
n_mels,
n_freqs,
n_pts,
bank_len,
n_mels_i32,
n_freqs_i32,
sample_rate,
f_min,
f_max,
);
}
let nyq = (sample_rate / 2) as f32;
let denom = (n_freqs as f32 - 1.0).max(1.0);
let mut all_freqs: Vec<f32> = Vec::new();
all_freqs.try_reserve_exact(n_freqs).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"mel_filter_bank: all_freqs reservation",
"f32 elements",
n_freqs as u64,
e,
))
})?;
for i in 0..n_freqs {
all_freqs.push(i as f32 * nyq / denom);
}
let m_min = hz_to_mel(f_min);
let m_max = hz_to_mel(f_max);
let m_denom = (n_pts as f32 - 1.0).max(1.0);
let mut f_pts: Vec<f32> = Vec::new();
f_pts.try_reserve_exact(n_pts).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"mel_filter_bank: f_pts reservation",
"f32 elements",
n_pts as u64,
e,
))
})?;
for i in 0..n_pts {
let m = m_min + (m_max - m_min) * (i as f32) / m_denom;
f_pts.push(mel_to_hz(m));
}
let mut bank: Vec<f32> = Vec::new();
bank.try_reserve_exact(bank_len).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"mel_filter_bank: bank reservation",
"f32 elements",
bank_len as u64,
e,
))
})?;
let spare = bank.spare_capacity_mut();
crate::simd::audio::mel_triangle::mel_filter_bank_rows(
&mut spare[..bank_len],
&all_freqs,
&f_pts,
n_mels,
);
unsafe { bank.set_len(bank_len) };
Array::from_slice::<f32>(&bank, &[n_mels_i32, n_freqs_i32])
}
/// f64 implementation behind `MelPrecision::Precise`.
///
/// The caller has already validated the inputs and derived `n_freqs`, `n_pts`
/// (`n_mels + 2`), `bank_len` (`n_mels * n_freqs`) and the i32 shape values.
/// All frequencies, mel points and triangle weights are computed in f64; the
/// finished bank is cast to f32 element by element before building the `Array`.
///
/// # Errors
/// `AllocFailure` when a working buffer cannot be reserved, plus whatever
/// `Array::from_slice` reports.
#[allow(clippy::too_many_arguments)]
fn mel_filter_bank_f64(
  n_mels: usize,
  n_freqs: usize,
  n_pts: usize,
  bank_len: usize,
  n_mels_i32: i32,
  n_freqs_i32: i32,
  sample_rate: u32,
  f_min: f32,
  f_max: f32,
) -> Result<Array> {
  // Linear bin frequencies from 0 to sample_rate / 2 (integer division).
  let nyq = f64::from(sample_rate / 2);
  let denom = (n_freqs as f64 - 1.0).max(1.0);
  let mut all_freqs: Vec<f64> = Vec::new();
  all_freqs.try_reserve_exact(n_freqs).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "mel_filter_bank: all_freqs reservation (precise)",
      "f64 elements",
      n_freqs as u64,
      e,
    ))
  })?;
  all_freqs.extend((0..n_freqs).map(|bin| bin as f64 * nyq / denom));

  // Triangle edges, evenly spaced on the mel axis.
  let m_min = hz_to_mel_f64(f64::from(f_min));
  let m_max = hz_to_mel_f64(f64::from(f_max));
  let m_denom = (n_pts as f64 - 1.0).max(1.0);
  let mut f_pts: Vec<f64> = Vec::new();
  f_pts.try_reserve_exact(n_pts).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "mel_filter_bank: f_pts reservation (precise)",
      "f64 elements",
      n_pts as u64,
      e,
    ))
  })?;
  f_pts.extend((0..n_pts).map(|idx| {
    let mel = m_min + (m_max - m_min) * (idx as f64) / m_denom;
    mel_to_hz_f64(mel)
  }));

  let mut bank: Vec<f64> = Vec::new();
  bank.try_reserve_exact(bank_len).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "mel_filter_bank: bank reservation (precise)",
      "f64 elements",
      bank_len as u64,
      e,
    ))
  })?;
  for row in 0..n_mels {
    let [left, center, right] = [f_pts[row], f_pts[row + 1], f_pts[row + 2]];
    let rise = center - left;
    let fall = right - center;
    if rise <= 0.0 || fall <= 0.0 {
      // Degenerate triangle: the whole row is zero.
      bank.extend(std::iter::repeat(0.0).take(n_freqs));
    } else {
      bank.extend(all_freqs.iter().map(|&freq| {
        let up = (freq - left) / rise;
        let down = (right - freq) / fall;
        up.min(down).max(0.0)
      }));
    }
  }

  // Narrow to f32 only once every weight has been computed in f64.
  let mut bank_f32: Vec<f32> = Vec::new();
  bank_f32.try_reserve_exact(bank_len).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "mel_filter_bank: bank f32 cast reservation (precise)",
      "f32 elements",
      bank_len as u64,
      e,
    ))
  })?;
  bank_f32.extend(bank.iter().map(|&weight| weight as f32));
  Array::from_slice::<f32>(&bank_f32, &[n_mels_i32, n_freqs_i32])
}
/// Maximum number of filter banks kept in each thread's mel filter cache.
/// When a new bank is inserted into a full cache, the entry at the front of
/// the list is removed first.
pub const MEL_FILTER_CACHE_CAP: usize = 8;
/// Identity of one cached mel filter bank.
///
/// The float parameters are stored as their raw bit patterns so the key can be
/// `Eq`; two keys match only when `f_min` / `f_max` are bit-for-bit identical.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MelFilterCacheKey {
  n_mels: usize,
  n_fft: usize,
  sample_rate: u32,
  f_min_bits: u32,
  f_max_bits: Option<u32>,
  precision: MelPrecision,
}
impl MelFilterCacheKey {
  /// Captures the filter-bank parameters, converting the floats to bit patterns.
  fn new(
    n_mels: usize,
    n_fft: usize,
    sample_rate: u32,
    f_min: f32,
    f_max: Option<f32>,
    precision: MelPrecision,
  ) -> Self {
    let f_min_bits = f_min.to_bits();
    let f_max_bits = f_max.map(|hz| hz.to_bits());
    MelFilterCacheKey {
      n_mels,
      n_fft,
      sample_rate,
      f_min_bits,
      f_max_bits,
      precision,
    }
  }
}
// Per-thread store of previously built mel filter banks, kept as a plain list
// of (key, bank) pairs. It starts empty; lookups scan it linearly and
// `clear_mel_filter_cache` empties it.
thread_local! {
static MEL_FILTER_CACHE: std::cell::RefCell<Vec<(MelFilterCacheKey, Array)>> =
const { std::cell::RefCell::new(Vec::new()) };
}
/// Cached variant of the mel filter bank builder using
/// [`MelPrecision::Standard`].
///
/// Forwards every argument unchanged to [`mel_filter_bank_cached_with`].
pub fn mel_filter_bank_cached(
  n_mels: usize,
  n_fft: usize,
  sample_rate: u32,
  f_min: f32,
  f_max: Option<f32>,
) -> Result<Array> {
  let precision = MelPrecision::Standard;
  mel_filter_bank_cached_with(n_mels, n_fft, sample_rate, f_min, f_max, precision)
}
/// Returns a mel filter bank for the given parameters, reusing a per-thread
/// cache of previously built banks.
///
/// On a hit, a clone of the cached bank is returned and the entry is moved to
/// the back of the cache (most recently used). On a miss, the bank is built
/// with [`mel_filter_bank_with`], stored at the back, and a clone is returned;
/// if the cache already holds `MEL_FILTER_CACHE_CAP` entries, the one at the
/// front (least recently used) is dropped first.
///
/// # Errors
/// Propagates errors from [`mel_filter_bank_with`] and from `Array::try_clone`.
/// A failed clone on the hit path leaves the cache untouched.
pub fn mel_filter_bank_cached_with(
  n_mels: usize,
  n_fft: usize,
  sample_rate: u32,
  f_min: f32,
  f_max: Option<f32>,
  precision: MelPrecision,
) -> Result<Array> {
  let key = MelFilterCacheKey::new(n_mels, n_fft, sample_rate, f_min, f_max, precision);
  let hit = MEL_FILTER_CACHE.with(|cell| -> Result<Option<Array>> {
    let mut cache = cell.borrow_mut();
    let Some(pos) = cache.iter().position(|(k, _)| *k == key) else {
      return Ok(None);
    };
    // Clone while the entry is still stored: if the clone fails, `?` returns
    // early and the cached bank is not lost.
    let clone = cache[pos].1.try_clone()?;
    // Move the hit to the back (most recently used) without taking it out.
    cache[pos..].rotate_left(1);
    Ok(Some(clone))
  })?;
  if let Some(arr) = hit {
    return Ok(arr);
  }
  // Miss: build outside of any cache borrow, then store the new bank.
  let bank = mel_filter_bank_with(n_mels, n_fft, sample_rate, f_min, f_max, precision)?;
  let for_caller = bank.try_clone()?;
  MEL_FILTER_CACHE.with(|cell| {
    let mut cache = cell.borrow_mut();
    if cache.len() >= MEL_FILTER_CACHE_CAP {
      // Evict the least recently used entry (front of the list).
      let _ = cache.remove(0);
    }
    cache.push((key, bank));
  });
  Ok(for_caller)
}
/// Drops every filter bank held in the calling thread's mel filter cache.
pub fn clear_mel_filter_cache() {
  MEL_FILTER_CACHE.with(|cell| {
    let mut entries = cell.borrow_mut();
    entries.clear();
  });
}
/// Computes a mel-scaled power spectrogram of `samples`.
///
/// Steps: STFT (with `WindowPad::Right`), squared magnitude, then a matrix
/// product of the cached standard-precision mel filter bank with the transposed
/// power spectrum.
///
/// # Errors
/// Propagates any error from the STFT, the element-wise operations, the filter
/// bank construction, the transpose or the matrix product.
#[allow(clippy::too_many_arguments)]
pub fn mel_spectrogram(
  samples: &Array,
  n_fft: usize,
  hop_length: usize,
  win_length: Option<usize>,
  n_mels: usize,
  sample_rate: u32,
  f_min: f32,
  f_max: Option<f32>,
) -> Result<Array> {
  let spectrum = stft(samples, n_fft, hop_length, win_length, WindowPad::Right)?;
  // |X|^2
  let power = spectrum.data_ref().abs()?.square()?;
  let filters = mel_filter_bank_cached(n_mels, n_fft, sample_rate, f_min, f_max)?;
  let power_transposed = power.transpose()?;
  ops::linalg_basic::matmul(&filters, &power_transposed)
}
/// Log-mel spectrogram using the default [`LogFloor`].
///
/// Forwards every argument unchanged to [`log_mel_spectrogram_with`].
#[allow(clippy::too_many_arguments)]
pub fn log_mel_spectrogram(
  samples: &Array,
  n_fft: usize,
  hop_length: usize,
  win_length: Option<usize>,
  n_mels: usize,
  sample_rate: u32,
  f_min: f32,
  f_max: Option<f32>,
) -> Result<Array> {
  let floor = LogFloor::default();
  log_mel_spectrogram_with(
    samples,
    n_fft,
    hop_length,
    win_length,
    n_mels,
    sample_rate,
    f_min,
    f_max,
    floor,
  )
}
/// Log-mel spectrogram with an explicit floor.
///
/// Computes [`mel_spectrogram`], clamps every value from below to
/// `floor.value()` and takes the element-wise logarithm of the result.
///
/// # Errors
/// Propagates any error from [`mel_spectrogram`] or the array operations.
#[allow(clippy::too_many_arguments)]
pub fn log_mel_spectrogram_with(
  samples: &Array,
  n_fft: usize,
  hop_length: usize,
  win_length: Option<usize>,
  n_mels: usize,
  sample_rate: u32,
  f_min: f32,
  f_max: Option<f32>,
  floor: LogFloor,
) -> Result<Array> {
  let mel_power = mel_spectrogram(
    samples,
    n_fft,
    hop_length,
    win_length,
    n_mels,
    sample_rate,
    f_min,
    f_max,
  )?;
  // Scalar (rank-0) array holding the floor value.
  let floor_scalar = Array::full::<f32>(&[0i32; 0], floor.value())?;
  ops::arithmetic::maximum(&mel_power, &floor_scalar)?.log()
}
// Largest sample count the lfilter routines accept; tied to the decoder's cap.
const MAX_LFILTER_SAMPLES: usize = crate::audio::io::MAX_DECODED_SAMPLES;
// Largest total element count (n_samples * n_channels) accepted by
// `integrated_loudness`; tied to the decoder's cap.
const MAX_LOUDNESS_SAMPLES: usize = crate::audio::io::MAX_DECODED_SAMPLES;
// Cap on the per-block mean-square table size (num_blocks * n_channels * 8 bytes).
const MAX_LOUDNESS_BLOCK_BYTES: usize = 64 * 1024 * 1024;
// Cap on num_blocks * block_samples * n_channels, checked before any block is summed.
const MAX_LOUDNESS_WORK: usize = 256 * 1024 * 1024;
// Per-channel weights, applied by channel index when summing mean squares.
// NOTE(review): presumably the BS.1770 L, R, C, Ls, Rs ordering — confirm.
const BS1770_CHANNEL_GAINS: [f64; 5] = [1.0, 1.0, 1.0, 1.41, 1.41];
// Blocks quieter than this are excluded by the absolute gate.
const BS1770_ABSOLUTE_THRESHOLD_LUFS: f64 = -70.0;
// Subtracted from the absolute-gated loudness to obtain the relative gate threshold.
const BS1770_RELATIVE_OFFSET_LUFS: f64 = 10.0;
// Constant term of the loudness formula: offset + 10 * log10(weighted mean square).
const BS1770_LOUDNESS_OFFSET_LUFS: f64 = -0.691;
// Maximum channel count accepted by `integrated_loudness`.
const BS1770_MAX_CHANNELS: usize = 5;
/// Applies the IIR/FIR filter described by numerator `b` and denominator `a`
/// to a 1-D signal.
///
/// The samples are read as f32, promoted to f64 for the recursion, and the
/// filtered result is narrowed back to f32.
///
/// # Errors
/// * `RankMismatch` when `data` is not 1-D.
/// * `OutOfRange` when the sample count does not fit in `i32`.
/// * `CapExceeded` when the sample count exceeds `MAX_LFILTER_SAMPLES`.
/// * `AllocFailure` when a working buffer cannot be reserved.
/// * Any error from the coefficient validation in the f64 filter core, or from
///   the array conversions.
pub fn lfilter(b: &[f64], a: &[f64], data: &Array) -> Result<Array> {
  if data.ndim() != 1 {
    return Err(Error::RankMismatch(RankMismatchPayload::new(
      "lfilter: only supports 1-D input",
      data.ndim() as u32,
      data.shape(),
    )));
  }
  let n_samples = data.shape()[0];
  let n_samples_i32 = i32::try_from(n_samples).map_err(|_| {
    Error::OutOfRange(OutOfRangePayload::new(
      "lfilter: n_samples",
      "must fit in i32 (i32::MAX = 2147483647)",
      format_smolstr!("{n_samples}"),
    ))
  })?;
  if n_samples > MAX_LFILTER_SAMPLES {
    return Err(Error::CapExceeded(CapExceededPayload::new(
      "lfilter: sample count exceeds cap",
      "MAX_LFILTER_SAMPLES",
      MAX_LFILTER_SAMPLES as u64,
      n_samples as u64,
    )));
  }

  // Promote the input to f64 so the recursion runs in double precision.
  let input_f32 = data.try_clone()?.to_vec::<f32>()?;
  let mut input_f64: Vec<f64> = Vec::new();
  input_f64.try_reserve_exact(n_samples).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: input promotion reservation",
      "f64 samples",
      n_samples as u64,
      e,
    ))
  })?;
  input_f64.extend(input_f32.iter().map(|&v| f64::from(v)));

  let filtered = lfilter_f64(b, a, &input_f64)?;

  // Narrow the result back to f32 for the output array.
  let mut output: Vec<f32> = Vec::new();
  output.try_reserve_exact(n_samples).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: output reservation",
      "f32 samples",
      n_samples as u64,
      e,
    ))
  })?;
  output.extend(filtered.iter().map(|&v| v as f32));
  Array::from_slice::<f32>(&output, &[n_samples_i32])
}
/// f64 filter core: returns `x` filtered by the rational transfer function with
/// numerator `b` and denominator `a`.
///
/// Both coefficient sets are divided by `a[0]` first. The recursion keeps
/// `max(len(a), len(b)) - 1` state values and updates them once per sample;
/// coefficients past the end of the shorter set are treated as zero. An empty
/// `b` yields an all-zero output of the same length as `x`.
///
/// # Errors
/// * `EmptyInput` when `a` is empty.
/// * `InvariantViolation` when `a[0] == 0`.
/// * `CapExceeded` when `x` is longer than `MAX_LFILTER_SAMPLES`.
/// * `AllocFailure` when a buffer cannot be reserved.
fn lfilter_f64(b: &[f64], a: &[f64], x: &[f64]) -> Result<Vec<f64>> {
  /// Coefficient at `idx`, or zero past the end of `coeffs`.
  fn tap(coeffs: &[f64], idx: usize) -> f64 {
    coeffs.get(idx).copied().unwrap_or(0.0)
  }

  let Some(&a0) = a.first() else {
    return Err(Error::EmptyInput(EmptyInputPayload::new(
      "lfilter: filter denominator (a)",
    )));
  };
  if a0 == 0.0 {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "lfilter: filter denominator a[0]",
      "must be non-zero (a[0] != 0)",
    )));
  }
  let n_samples = x.len();
  if n_samples > MAX_LFILTER_SAMPLES {
    return Err(Error::CapExceeded(CapExceededPayload::new(
      "lfilter: sample count exceeds cap",
      "MAX_LFILTER_SAMPLES",
      MAX_LFILTER_SAMPLES as u64,
      n_samples as u64,
    )));
  }
  if b.is_empty() {
    // No numerator taps: the output is identically zero.
    let mut zeros: Vec<f64> = Vec::new();
    zeros.try_reserve_exact(n_samples).map_err(|e| {
      Error::AllocFailure(AllocFailurePayload::new(
        "lfilter: zero-output reservation",
        "f64 samples",
        n_samples as u64,
        e,
      ))
    })?;
    zeros.resize(n_samples, 0.0);
    return Ok(zeros);
  }

  // Normalise both coefficient sets by a[0].
  let mut b_norm: Vec<f64> = Vec::new();
  b_norm.try_reserve_exact(b.len()).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: numerator (b) normalize reservation",
      "f64 taps",
      b.len() as u64,
      e,
    ))
  })?;
  b_norm.extend(b.iter().map(|&bv| bv / a0));
  let mut a_norm: Vec<f64> = Vec::new();
  a_norm.try_reserve_exact(a.len()).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: denominator (a) normalize reservation",
      "f64 taps",
      a.len() as u64,
      e,
    ))
  })?;
  a_norm.extend(a.iter().map(|&av| av / a0));

  let state_len = a_norm.len().max(b_norm.len()) - 1;
  let mut y: Vec<f64> = Vec::new();
  y.try_reserve_exact(n_samples).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: output reservation",
      "f64 samples",
      n_samples as u64,
      e,
    ))
  })?;
  if state_len == 0 {
    // Single-tap filter: a pure gain.
    let b0 = b_norm[0];
    y.extend(x.iter().map(|&sample| b0 * sample));
    return Ok(y);
  }

  let mut state: Vec<f64> = Vec::new();
  state.try_reserve_exact(state_len).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: state reservation",
      "f64 taps",
      state_len as u64,
      e,
    ))
  })?;
  state.resize(state_len, 0.0);

  let last = state_len - 1;
  for &sample in x {
    let output = b_norm[0] * sample + state[0];
    // Shift the delay line down one slot, folding in this sample's contribution.
    for i in 1..state_len {
      state[i - 1] = state[i] + tap(&b_norm, i) * sample - tap(&a_norm, i) * output;
    }
    state[last] = tap(&b_norm, state_len) * sample - tap(&a_norm, state_len) * output;
    y.push(output);
  }
  Ok(y)
}
/// In-place f64 filter core: overwrites `x` with its filtered version.
///
/// Performs the same validation and `a[0]` normalisation as the allocating
/// filter core. Three-tap numerator and denominator pairs are handed to
/// `simd::audio::lfilter::lfilter_biquad`; every other order runs the scalar
/// recursion here. An empty `b` zeroes `x`.
///
/// # Errors
/// * `EmptyInput` when `a` is empty.
/// * `InvariantViolation` when `a[0] == 0`.
/// * `CapExceeded` when `x` is longer than `MAX_LFILTER_SAMPLES`.
/// * `AllocFailure` when a coefficient or state buffer cannot be reserved.
fn lfilter_f64_in_place(b: &[f64], a: &[f64], x: &mut [f64]) -> Result<()> {
  /// Coefficient at `idx`, or zero past the end of `coeffs`.
  fn tap(coeffs: &[f64], idx: usize) -> f64 {
    coeffs.get(idx).copied().unwrap_or(0.0)
  }

  let Some(&a0) = a.first() else {
    return Err(Error::EmptyInput(EmptyInputPayload::new(
      "lfilter: filter denominator (a)",
    )));
  };
  if a0 == 0.0 {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "lfilter: filter denominator a[0]",
      "must be non-zero (a[0] != 0)",
    )));
  }
  let n_samples = x.len();
  if n_samples > MAX_LFILTER_SAMPLES {
    return Err(Error::CapExceeded(CapExceededPayload::new(
      "lfilter: sample count exceeds cap (in-place)",
      "MAX_LFILTER_SAMPLES",
      MAX_LFILTER_SAMPLES as u64,
      n_samples as u64,
    )));
  }
  if b.is_empty() {
    // No numerator taps: the output is identically zero.
    x.fill(0.0);
    return Ok(());
  }

  // Normalise both coefficient sets by a[0].
  let mut b_norm: Vec<f64> = Vec::new();
  b_norm.try_reserve_exact(b.len()).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter (in-place): numerator (b) normalize reservation",
      "f64 taps",
      b.len() as u64,
      e,
    ))
  })?;
  b_norm.extend(b.iter().map(|&bv| bv / a0));
  let mut a_norm: Vec<f64> = Vec::new();
  a_norm.try_reserve_exact(a.len()).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter (in-place): denominator (a) normalize reservation",
      "f64 taps",
      a.len() as u64,
      e,
    ))
  })?;
  a_norm.extend(a.iter().map(|&av| av / a0));

  let state_len = a_norm.len().max(b_norm.len()) - 1;
  if state_len == 0 {
    // Single-tap filter: a pure gain.
    let b0 = b_norm[0];
    for slot in x.iter_mut() {
      *slot *= b0;
    }
    return Ok(());
  }
  let is_full_biquad = state_len == 2 && b_norm.len() == 3 && a_norm.len() == 3;
  if is_full_biquad {
    crate::simd::audio::lfilter::lfilter_biquad(x, &b_norm, &a_norm);
    return Ok(());
  }

  let mut state: Vec<f64> = Vec::new();
  state.try_reserve_exact(state_len).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "lfilter: state reservation",
      "f64 taps",
      state_len as u64,
      e,
    ))
  })?;
  state.resize(state_len, 0.0);

  let last = state_len - 1;
  for slot in x.iter_mut() {
    let sample = *slot;
    let output = b_norm[0] * sample + state[0];
    // Shift the delay line down one slot, folding in this sample's contribution.
    for i in 1..state_len {
      state[i - 1] = state[i] + tap(&b_norm, i) * sample - tap(&a_norm, i) * output;
    }
    state[last] = tap(&b_norm, state_len) * sample - tap(&a_norm, state_len) * output;
    *slot = output;
  }
  Ok(())
}
/// Derives normalised biquad coefficients `(b, a)` for a high-shelf or
/// high-pass section.
///
/// `gain_db` sets the shelf amplitude (`10^(gain_db / 40)`), `q_factor` and
/// `center_freq` (relative to `rate`) set the bandwidth term and the angular
/// frequency. The raw coefficients are divided by the raw `a0`, so the returned
/// denominator always starts with `1.0`.
fn bs1770_biquad_coefficients(
  gain_db: f64,
  q_factor: f64,
  center_freq: f64,
  rate: f64,
  filter_type: BiquadKind,
) -> ([f64; 3], [f64; 3]) {
  let amplitude = 10.0_f64.powf(gain_db / 40.0);
  let omega = 2.0 * std::f64::consts::PI * (center_freq / rate);
  let alpha = omega.sin() / (2.0 * q_factor);
  let cos_omega = omega.cos();

  // Raw (un-normalised) numerator and denominator for the requested section.
  let (raw_b, raw_a): ([f64; 3], [f64; 3]) = match filter_type {
    BiquadKind::HighShelf => {
      let a_plus = amplitude + 1.0;
      let a_minus = amplitude - 1.0;
      let shelf_slope = 2.0 * amplitude.sqrt() * alpha;
      (
        [
          amplitude * (a_plus + a_minus * cos_omega + shelf_slope),
          -2.0 * amplitude * (a_minus + a_plus * cos_omega),
          amplitude * (a_plus + a_minus * cos_omega - shelf_slope),
        ],
        [
          a_plus - a_minus * cos_omega + shelf_slope,
          2.0 * (a_minus - a_plus * cos_omega),
          a_plus - a_minus * cos_omega - shelf_slope,
        ],
      )
    }
    BiquadKind::HighPass => (
      [
        (1.0 + cos_omega) / 2.0,
        -(1.0 + cos_omega),
        (1.0 + cos_omega) / 2.0,
      ],
      [1.0 + alpha, -2.0 * cos_omega, 1.0 - alpha],
    ),
  };

  // Normalise so that the leading denominator coefficient is exactly 1.
  let a0 = raw_a[0];
  (raw_b.map(|coeff| coeff / a0), [1.0, raw_a[1] / a0, raw_a[2] / a0])
}
/// The two biquad shapes `bs1770_biquad_coefficients` can produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::IsVariant)]
enum BiquadKind {
  HighShelf,
  HighPass,
}
impl BiquadKind {
  /// Snake-case name of the variant.
  #[allow(dead_code)]
  const fn as_str(&self) -> &'static str {
    match *self {
      BiquadKind::HighPass => "high_pass",
      BiquadKind::HighShelf => "high_shelf",
    }
  }
}
/// Applies the two-stage weighting filter to one channel.
///
/// The samples are promoted to f64, then filtered in place by a high-shelf
/// section (4 dB, Q = 1/sqrt(2), 1500 Hz) followed by a high-pass section
/// (Q = 0.5, 38 Hz), both designed for sample rate `rate`.
///
/// # Errors
/// `AllocFailure` when the f64 buffer cannot be reserved, plus any error from
/// the in-place filter core.
fn k_weight_channel(channel: &[f32], rate: u32) -> Result<Vec<f64>> {
  let fs = f64::from(rate);
  let (shelf_b, shelf_a) = bs1770_biquad_coefficients(
    4.0,
    1.0 / std::f64::consts::SQRT_2,
    1500.0,
    fs,
    BiquadKind::HighShelf,
  );
  let (pass_b, pass_a) = bs1770_biquad_coefficients(0.0, 0.5, 38.0, fs, BiquadKind::HighPass);

  let sample_count = channel.len();
  let mut samples: Vec<f64> = Vec::new();
  samples.try_reserve_exact(sample_count).map_err(|e| {
    Error::AllocFailure(AllocFailurePayload::new(
      "k_weight_channel: input promotion reservation",
      "f64 samples",
      sample_count as u64,
      e,
    ))
  })?;
  samples.extend(channel.iter().map(|&v| f64::from(v)));

  // Stage 1: high shelf; stage 2: high pass. Order matters.
  lfilter_f64_in_place(&shelf_b, &shelf_a, &mut samples)?;
  lfilter_f64_in_place(&pass_b, &pass_a, &mut samples)?;
  Ok(samples)
}
/// Computes the gated integrated loudness of `data` sampled at `rate` Hz.
///
/// `data` is either 1-D (mono) or 2-D with shape `(n_samples, n_channels)` and at
/// most `BS1770_MAX_CHANNELS` channels. `block_size` is the gating block length in
/// seconds (it is multiplied by `rate` to get samples) and `overlap` is the
/// fraction in `[0, 1)` by which consecutive blocks overlap.
///
/// Each channel is filtered with `k_weight_channel`, the mean square of every
/// block is computed, and blocks are gated twice: first against
/// `BS1770_ABSOLUTE_THRESHOLD_LUFS`, then against a relative threshold derived
/// from the absolute-gated loudness. The result is
/// `BS1770_LOUDNESS_OFFSET_LUFS + 10 * log10(weighted gated mean square)`; it is
/// negative infinity when no block survives gating.
///
/// # Errors
/// * `OutOfRange` for `rate == 0`, a non-finite or non-positive `block_size`, an
///   `overlap` outside `[0, 1)`, more than 5 channels, a block shorter than one
///   sample, audio shorter than one block, or an unusable derived block count.
/// * `EmptyInput` for a 2-D input with zero channels.
/// * `RankMismatch` when `data` is neither 1-D nor 2-D.
/// * `ArithmeticOverflow` / `CapExceeded` when a size computation overflows or
///   exceeds one of the `MAX_LOUDNESS_*` caps.
/// * `LengthMismatch` when the extracted sample buffer does not match the shape.
/// * `AllocFailure` when a working buffer cannot be reserved.
pub fn integrated_loudness(data: &Array, rate: u32, block_size: f64, overlap: f64) -> Result<f64> {
// --- Scalar parameter validation -------------------------------------------
if rate == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: rate",
"must be > 0",
"0",
)));
}
// Written in positive form so that NaN is rejected as well.
if !(block_size > 0.0 && block_size.is_finite()) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: block_size",
"must be a finite value > 0",
format!("{block_size}"),
)));
}
if !((0.0..1.0).contains(&overlap) && overlap.is_finite()) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: overlap",
"must be a finite value in [0, 1)",
format!("{overlap}"),
)));
}
// --- Shape validation: 1-D is mono, 2-D is (n_samples, n_channels) ---------
let shape = data.shape();
let (n_samples, n_channels) = match shape.len() {
1 => (shape[0], 1usize),
2 => {
let (n_samples, n_channels) = (shape[0], shape[1]);
if n_channels > BS1770_MAX_CHANNELS {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: n_channels",
"must be at most 5 (BS.1770 standard limit)",
format!("{n_channels}"),
)));
}
if n_channels == 0 {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"integrated_loudness: audio channels (must have at least 1 channel)",
)));
}
(n_samples, n_channels)
}
other => {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"integrated_loudness: data must be 1-D (mono) or 2-D (n_samples, n_channels)",
other as u32,
data.shape(),
)));
}
};
// --- Size caps, all checked before any sample data is copied ---------------
let total_elements = n_samples.checked_mul(n_channels).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"integrated_loudness: total element count n_samples * n_channels",
"usize",
[
("n_samples", n_samples as u64),
("n_channels", n_channels as u64),
],
))
})?;
if total_elements > MAX_LOUDNESS_SAMPLES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"integrated_loudness: total element count (= n_samples * n_channels) exceeds cap",
"MAX_LOUDNESS_SAMPLES",
MAX_LOUDNESS_SAMPLES as u64,
total_elements as u64,
)));
}
let rate_f64 = f64::from(rate);
// Block length in (fractional) samples.
let block_samples_f64 = block_size * rate_f64;
if !block_samples_f64.is_finite() || block_samples_f64 < 1.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: block_size * rate",
"must be finite and >= 1 sample",
format_smolstr!("block_samples={block_samples_f64}, block_size={block_size}, rate={rate}"),
)));
}
// The audio must hold at least one full block (equal length is accepted).
if (n_samples as f64) < block_samples_f64 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: audio length (samples)",
"must be greater than the block size",
format!("{n_samples} samples (block_size*rate = {block_samples_f64:.1} samples)"),
)));
}
// Fraction of a block by which each successive block advances.
let step = 1.0 - overlap;
let duration_seconds = n_samples as f64 / rate_f64;
// Number of gating blocks; the quotient is rounded half-to-even before adding one.
let num_blocks_f64 =
((duration_seconds - block_size) / (block_size * step)).round_ties_even() + 1.0;
if !num_blocks_f64.is_finite() || num_blocks_f64 < 1.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: derived num_blocks",
"must be finite and >= 1",
format_smolstr!(
"num_blocks={num_blocks_f64}, duration={duration_seconds}, \
block_size={block_size}, overlap={overlap}"
),
)));
}
if num_blocks_f64 > usize::MAX as f64 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: derived num_blocks",
"must fit in usize",
format_smolstr!(
"num_blocks={num_blocks_f64}, duration={duration_seconds}, \
block_size={block_size}, overlap={overlap}"
),
)));
}
let num_blocks = num_blocks_f64 as usize;
if num_blocks == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"integrated_loudness: derived num_blocks",
"must be >= 1",
"0",
)));
}
// Footprint of the per-channel, per-block mean-square table.
let block_cells = num_blocks.checked_mul(n_channels).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"integrated_loudness: block cells num_blocks * n_channels",
"usize",
[
("num_blocks", num_blocks as u64),
("n_channels", n_channels as u64),
],
))
})?;
let block_bytes = block_cells
.checked_mul(std::mem::size_of::<f64>())
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"integrated_loudness: block bytes block_cells * 8",
"usize",
[("block_cells", block_cells as u64)],
))
})?;
if block_bytes > MAX_LOUDNESS_BLOCK_BYTES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"integrated_loudness: mean-square byte footprint exceeds cap",
"MAX_LOUDNESS_BLOCK_BYTES",
MAX_LOUDNESS_BLOCK_BYTES as u64,
block_bytes as u64,
)));
}
// Upper bound on sample visits: heavily overlapping blocks re-read the same
// samples, so the work can far exceed the input length.
let block_samples_usize: usize = block_samples_f64.ceil() as usize;
let total_work = num_blocks
.checked_mul(block_samples_usize)
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"integrated_loudness: partial work num_blocks * block_samples",
"usize",
[
("num_blocks", num_blocks as u64),
("block_samples", block_samples_usize as u64),
],
))
})?
.checked_mul(n_channels)
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"integrated_loudness: total work num_blocks * block_samples * n_channels",
"usize",
[
("num_blocks", num_blocks as u64),
("block_samples", block_samples_usize as u64),
("n_channels", n_channels as u64),
],
))
})?;
if total_work > MAX_LOUDNESS_WORK {
return Err(Error::CapExceeded(CapExceededPayload::new(
"integrated_loudness: total sample-visit work \
(= num_blocks * block_samples * n_channels) exceeds cap",
"MAX_LOUDNESS_WORK",
MAX_LOUDNESS_WORK as u64,
total_work as u64,
)));
}
// --- Data extraction -------------------------------------------------------
let raw_f32 = data.try_clone()?.to_vec::<f32>()?;
if raw_f32.len() != total_elements {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"integrated_loudness: internal shape mismatch — raw_f32 sample count",
total_elements,
raw_f32.len(),
)));
}
// mean_square[channel][block], zero-initialised.
let mut mean_square: Vec<Vec<f64>> = Vec::new();
mean_square.try_reserve_exact(n_channels).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"integrated_loudness: mean_square channels reservation",
"Vec<f64> rows",
n_channels as u64,
e,
))
})?;
for _ in 0..n_channels {
let mut row: Vec<f64> = Vec::new();
row.try_reserve_exact(num_blocks).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"integrated_loudness: mean_square blocks reservation",
"f64 blocks",
num_blocks as u64,
e,
))
})?;
row.resize(num_blocks, 0.0);
mean_square.push(row);
}
// Scratch buffer for one de-interleaved channel, reused across channels.
let mut chan_f32: Vec<f32> = Vec::new();
chan_f32.try_reserve_exact(n_samples).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"integrated_loudness: chan_f32 reservation",
"f32 samples",
n_samples as u64,
e,
))
})?;
// --- Per-channel weighting and per-block mean square -----------------------
for (c, ms_row) in mean_square.iter_mut().enumerate() {
chan_f32.clear();
// Gather channel c with a stride of n_channels.
// NOTE(review): assumes `to_vec` yields row-major (sample-major) order — confirm.
for i in 0..n_samples {
chan_f32.push(raw_f32[i * n_channels + c]);
}
let weighted = k_weight_channel(&chan_f32, rate)?;
for (block_index, ms_cell) in ms_row.iter_mut().enumerate() {
let bi_f64 = block_index as f64;
// Block bounds in samples; the float-to-usize casts truncate toward zero.
let lower_f64 = block_size * (bi_f64 * step) * rate_f64;
let upper_f64 = block_size * (bi_f64 * step + 1.0) * rate_f64;
let lower = lower_f64 as usize;
let upper = (upper_f64 as usize).min(weighted.len());
// A block that falls entirely past the end keeps its zero mean square.
if upper <= lower {
continue;
}
// Always divided by the nominal block length, even if the block was clipped.
*ms_cell = crate::simd::sum_of_squares(&weighted[lower..upper]) / block_samples_f64;
}
drop(weighted);
}
// Release the large sample buffers before the gating passes.
drop(raw_f32);
drop(chan_f32);
// --- Loudness of every block (channel-weighted sum of mean squares) --------
let mut block_loudness: Vec<f64> = Vec::new();
block_loudness.try_reserve_exact(num_blocks).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"integrated_loudness: block_loudness reservation",
"f64 blocks",
num_blocks as u64,
e,
))
})?;
for b in 0..num_blocks {
let mut weighted_sum = 0.0_f64;
// `zip` stops at the shorter side, so only the first n_channels gains are used.
for (gain, ms_row) in BS1770_CHANNEL_GAINS.iter().zip(mean_square.iter()) {
weighted_sum += gain * ms_row[b];
}
block_loudness.push(BS1770_LOUDNESS_OFFSET_LUFS + 10.0 * weighted_sum.log10());
}
// For each channel: mean of the mean-square values over the blocks whose
// loudness satisfies `pred`; NaN when no block qualifies.
let gated_mean_per_channel = |pred: &dyn Fn(f64) -> bool| -> Vec<f64> {
let mut out = Vec::with_capacity(n_channels);
for ms_row in mean_square.iter() {
let mut acc = 0.0_f64;
let mut count: usize = 0;
for (b, &l) in block_loudness.iter().enumerate() {
if pred(l) {
acc += ms_row[b];
count += 1;
}
}
if count == 0 {
out.push(f64::NAN);
} else {
out.push(acc / count as f64);
}
}
out
};
// --- Pass 1: absolute gate -------------------------------------------------
let gated_mean_square_abs = gated_mean_per_channel(&|l| l >= BS1770_ABSOLUTE_THRESHOLD_LUFS);
let mut weighted_abs = 0.0_f64;
for (gain, &gms) in BS1770_CHANNEL_GAINS
.iter()
.zip(gated_mean_square_abs.iter())
{
weighted_abs += gain * gms;
}
// If pass 1 produced NaN the threshold is NaN and every comparison below is false.
let relative_threshold =
BS1770_LOUDNESS_OFFSET_LUFS + 10.0 * weighted_abs.log10() - BS1770_RELATIVE_OFFSET_LUFS;
// --- Pass 2: relative gate (must also clear the absolute threshold) --------
let mut gated_mean_square_rel =
gated_mean_per_channel(&|l| l > relative_threshold && l > BS1770_ABSOLUTE_THRESHOLD_LUFS);
// Channels with no surviving block contribute zero instead of NaN.
for v in gated_mean_square_rel.iter_mut() {
if v.is_nan() {
*v = 0.0;
}
}
let mut weighted_rel = 0.0_f64;
for (gain, &gms) in BS1770_CHANNEL_GAINS
.iter()
.zip(gated_mean_square_rel.iter())
{
weighted_rel += gain * gms;
}
// log10(0) is -inf, so fully gated-out audio reports negative infinity.
Ok(BS1770_LOUDNESS_OFFSET_LUFS + 10.0 * weighted_rel.log10())
}
pub fn normalize_loudness(
data: &Array,
input_loudness: f64,
target_loudness: f64,
) -> Result<Array> {
if !input_loudness.is_finite() {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"normalize_loudness: input_loudness",
"must be finite (NaN/±inf would yield a non-finite gain)",
format!("{input_loudness}"),
)));
}
if !target_loudness.is_finite() {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"normalize_loudness: target_loudness",
"must be finite (NaN/±inf would yield a non-finite gain)",
format!("{target_loudness}"),
)));
}
let delta = target_loudness - input_loudness;
let gain_f64 = 10.0_f64.powf(delta / 20.0);
let gain = gain_f64 as f32;
let gain_arr = Array::full::<f32>(&[0i32; 0], gain)?;
ops::arithmetic::multiply(data, &gain_arr)
}
/// Scales `data` so that its largest absolute sample equals
/// `10^(target_peak_db / 20)`.
///
/// # Errors
/// * `OutOfRange` when `target_peak_db` is NaN/infinite or the input is pure
///   silence (peak of exactly zero).
/// * `EmptyInput` when `data` has no elements.
/// * `NonFiniteScalar` when the measured peak, the target amplitude or the
///   resulting gain is not finite.
/// * Any error from the underlying array operations.
pub fn normalize_peak(data: &Array, target_peak_db: f64) -> Result<Array> {
  if !target_peak_db.is_finite() {
    return Err(Error::OutOfRange(OutOfRangePayload::new(
      "normalize_peak: target_peak_db",
      "must be finite (NaN/±inf cannot represent a dB level)",
      format!("{target_peak_db}"),
    )));
  }
  if data.size() == 0 {
    return Err(Error::EmptyInput(EmptyInputPayload::new(
      "normalize_peak: data must be non-empty (max over an empty array is undefined)",
    )));
  }

  // Measure the current peak: max(|data|) reduced to a scalar.
  let magnitudes = data.abs()?;
  let mut peak_arr = ops::reduction::max(&magnitudes, false)?;
  let current_peak = peak_arr.item::<f32>()?;
  if !current_peak.is_finite() {
    return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
      "normalize_peak: current peak max(|data|) (the input contains a NaN or infinite sample)",
      f64::from(current_peak),
    )));
  }
  if current_peak == 0.0 {
    return Err(Error::OutOfRange(OutOfRangePayload::new(
      "normalize_peak: current peak max(|data|)",
      "must be > 0 — an all-silence input cannot be peak-normalized (the gain would divide by zero)",
      "0",
    )));
  }

  // Target amplitude first, then the gain that maps the current peak onto it.
  let target_linear = 10.0_f64.powf(target_peak_db / 20.0) as f32;
  if !target_linear.is_finite() {
    return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
      "normalize_peak: target amplitude 10^(target_peak_db / 20) \
       (target_peak_db is too large to represent as a finite f32 gain)",
      f64::from(target_linear),
    )));
  }
  let gain = target_linear / current_peak;
  if !gain.is_finite() {
    return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
      "normalize_peak: gain (= target_linear / current_peak) \
       (the current peak is too small for target_peak_db — scaling overflows to non-finite)",
      f64::from(gain),
    )));
  }
  let scale = Array::full::<f32>(&[0i32; 0], gain)?;
  ops::arithmetic::multiply(data, &scale)
}
#[cfg(test)]
mod tests;