dasp_rs/features/spectral.rs

use ndarray::{s, stack, Array1, Array2, Axis};
use rayon::prelude::*;
use crate::signal_processing::time_frequency::{stft, cqt};
use crate::signal_processing::time_domain::{autocorrelate, log_energy};
use crate::hz_to_midi;
use ndarray_linalg::{Solve, Eig};
use num_complex::Complex;
use thiserror::Error;
use crate::core::io::{AudioError, AudioData};
use crate::utils::frequency::fft_frequencies;

/// Custom error types for spectral signal processing operations.
///
/// This enum defines errors specific to spectral feature extraction and analysis.
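/// A minimal sketch of matching on a variant (assuming the error type is exported
/// alongside the functions below):
/// ```
/// use dasp_rs::signal_processing::spectral::SpectralError;
/// let err = SpectralError::InvalidParameter("n_fft must be positive".to_string());
/// assert!(matches!(err, SpectralError::InvalidParameter(_)));
/// ```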
#[derive(Error, Debug)]
pub enum SpectralError {
    /// Error when a parameter (e.g., frame length, hop length) is invalid.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Error when input dimensions or sizes are insufficient or mismatched.
    #[error("Invalid input size: {0}")]
    InvalidSize(String),

    /// Error during numerical computations (e.g., matrix solving, eigenvalue decomposition).
    #[error("Numerical error: {0}")]
    Numerical(String),

    /// Wraps an AudioError from the core module (e.g., from time-domain functions).
    #[error("Audio processing error: {0}")]
    Audio(#[from] AudioError),

    /// Wraps a time-domain processing error (e.g., from autocorrelation or energy computation).
    #[error("Time-domain processing error: {0}")]
    TimeDomain(String),

    /// Wraps a time-frequency processing error (e.g., from STFT or CQT).
    #[error("Time-frequency error: {0}")]
    TimeFrequency(String),
}

/// Computes chroma features using Short-Time Fourier Transform (STFT).
///
/// Maps spectral energy to 12 pitch classes based on a power spectrogram.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed power spectrogram.
/// * `norm` - Optional normalization factor.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing a 2D array of shape `(12, n_frames)`
/// with chroma features, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::chroma_stft;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let chroma = chroma_stft(&signal, None, None, None, None).unwrap();
/// assert_eq!(chroma.shape()[0], 12);
/// ```
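/// Passing a pre-computed power spectrogram skips the internal STFT; only the bin
/// and frame counts of `s` matter here (a sketch with placeholder values, not real
/// audio, and assuming `fft_frequencies` yields `n_fft / 2 + 1` bins):
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::chroma_stft;
/// use ndarray::Array2;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// // 1025 bins = 2048 / 2 + 1, three frames of uniform placeholder energy.
/// let s = Array2::<f32>::ones((1025, 3));
/// let chroma = chroma_stft(&signal, Some(&s), None, None, None).unwrap();
/// assert_eq!(chroma.shape(), &[12, 3]);
/// ```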
pub fn chroma_stft(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    norm: Option<f32>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);

    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".into(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".into(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm().powi(2)),
    };

    let n_bins = s.shape()[0];
    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));

    let pitch_classes: Vec<usize> = (1..n_bins)
        .map(|bin| {
            let midi = hz_to_midi(&[freqs[bin]])[0];
            (midi.round() as isize).rem_euclid(12) as usize
        })
        .collect();

    if let Some(norm_val) = norm {
        if norm_val <= 0.0 {
            return Err(SpectralError::InvalidParameter(
                "Normalization factor must be positive".into(),
            ));
        }
    }

    let n_frames = s.shape()[1];

    let chroma_cols: Vec<Array1<f32>> = (0..n_frames)
        .into_par_iter()
        .map(|frame| {
            let mut temp = Array1::zeros(12);
            for bin in 1..n_bins {
                let pitch_class = pitch_classes[bin - 1];
                temp[pitch_class] += s[[bin, frame]];
            }
            if let Some(norm_val) = norm {
                temp /= norm_val;
            }
            temp
        })
        .collect();

    let views: Vec<_> = chroma_cols.iter().map(|col| col.view()).collect();
    let chroma = stack(Axis(1), views.as_slice()).map_err(|e| {
        SpectralError::Numerical(format!("Failed to stack chroma columns: {}", e))
    })?;

    Ok(chroma)
}

/// Computes chroma features using Constant-Q Transform (CQT).
///
/// Maps CQT spectral energy to 12 pitch classes.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `c` - Optional pre-computed CQT magnitude spectrogram.
/// * `hop_length` - Optional hop length (defaults to 512).
/// * `fmin` - Optional minimum frequency (defaults to 32.70 Hz, C1).
/// * `bins_per_octave` - Optional bins per octave (defaults to 12).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing a 2D array of shape `(12, n_frames)`
/// with chroma features, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::chroma_cqt;
/// let signal = AudioData { samples: vec![0.0; 22050], sample_rate: 44100, channels: 1 };
/// let chroma = chroma_cqt(&signal, None, None, None, None).unwrap();
/// assert_eq!(chroma.shape()[0], 12);
/// ```
pub fn chroma_cqt(
    signal: &AudioData,
    c: Option<&Array2<f32>>,
    hop_length: Option<usize>,
    fmin: Option<f32>,
    bins_per_octave: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let hop = hop_length.unwrap_or(512);
    let fmin = fmin.unwrap_or(32.70);
    let bpo = bins_per_octave.unwrap_or(12);

    if hop == 0 {
        return Err(SpectralError::InvalidParameter("hop_length must be positive".into()));
    }
    if fmin <= 0.0 {
        return Err(SpectralError::InvalidParameter("fmin must be positive".into()));
    }
    if bpo == 0 {
        return Err(SpectralError::InvalidParameter("bins_per_octave must be positive".into()));
    }

    let nyquist = signal.sample_rate as f32 / 2.0;
    if fmin >= nyquist {
        return Err(SpectralError::InvalidParameter("fmin must be less than Nyquist frequency".into()));
    }
    let max_bin = (nyquist / fmin).log2() * bpo as f32;
    let n_bins = max_bin.floor() as usize + 1;

    let c: Array2<f32> = match c {
        Some(c_mag) => c_mag.to_owned(),
        None => cqt(signal, Some(hop), Some(fmin), Some(n_bins))
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let mut pitch_classes = Vec::with_capacity(n_bins);
    for bin in 0..n_bins {
        let freq = fmin * 2.0f32.powf(bin as f32 / bpo as f32);
        if !freq.is_finite() || freq <= 0.0 {
            return Err(SpectralError::Numerical("Invalid frequency computed for bin".into()));
        }
        let midi = hz_to_midi(&[freq])[0];
        if !midi.is_finite() {
            return Err(SpectralError::Numerical("Invalid MIDI value computed".into()));
        }
        let pitch_class = (midi.round() as usize) % 12;
        pitch_classes.push(pitch_class);
    }

    let n_frames = c.shape()[1];
    let chroma_cols: Vec<Array1<f32>> = (0..n_frames)
        .into_par_iter()
        .map(|frame| {
            let mut chroma_frame = Array1::zeros(12);
            for bin in 0..n_bins {
                let pc = pitch_classes[bin];
                chroma_frame[pc] += c[[bin, frame]];
            }
            chroma_frame
        })
        .collect();

    let views: Vec<_> = chroma_cols.iter().map(|col| col.view()).collect();
    let chroma = stack(Axis(1), views.as_slice())
        .map_err(|e| SpectralError::Numerical(e.to_string()))?;

    Ok(chroma)
}

/// Computes Chroma Energy Normalized Statistics (CENS) features.
///
/// Normalizes chroma features over a window to emphasize energy distribution.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `c` - Optional pre-computed CQT magnitude spectrogram.
/// * `hop_length` - Optional hop length (defaults to 512).
/// * `fmin` - Optional minimum frequency (defaults to 32.70 Hz).
/// * `bins_per_octave` - Optional bins per octave (defaults to 12).
/// * `win_length` - Optional window length for normalization (defaults to 41).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing a 2D array of shape `(12, n_frames)`
/// with CENS features, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::chroma_cens;
/// let signal = AudioData { samples: vec![0.0; 22050], sample_rate: 44100, channels: 1 };
/// let cens = chroma_cens(&signal, None, None, None, None, None).unwrap();
/// assert_eq!(cens.shape()[0], 12);
/// ```
pub fn chroma_cens(
    signal: &AudioData,
    c: Option<&Array2<f32>>,
    hop_length: Option<usize>,
    fmin: Option<f32>,
    bins_per_octave: Option<usize>,
    win_length: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let win = win_length.unwrap_or(41);
    if win == 0 {
        return Err(SpectralError::InvalidParameter(
            "win_length must be positive".to_string(),
        ));
    }
    let chroma = chroma_cqt(signal, c, hop_length, fmin, bins_per_octave)?;
    let half_win = win / 2;
    let mut cens = Array2::zeros(chroma.dim());
    for t in 0..chroma.shape()[1] {
        let start = t.saturating_sub(half_win);
        let end = (t + half_win + 1).min(chroma.shape()[1]);
        let slice = chroma.slice(s![.., start..end]);
        let norm = slice
            .mapv(|x| x.powi(2))
            .sum_axis(Axis(1))
            .mapv(f32::sqrt);
        for p in 0..12 {
            cens[[p, t]] = if norm[p] > 1e-6 {
                chroma[[p, t]] / norm[p]
            } else {
                0.0
            };
        }
    }
    Ok(cens)
}

/// Computes a mel spectrogram.
///
/// Projects spectral energy onto mel-frequency bands.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed power spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `n_mels` - Optional number of mel bands (defaults to 128).
/// * `fmin` - Optional minimum frequency (defaults to 0 Hz).
/// * `fmax` - Optional maximum frequency (defaults to sr/2).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing a 2D array of shape `(n_mels, n_frames)`
/// with the mel spectrogram, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::melspectrogram;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let mel = melspectrogram(&signal, None, None, None, None, None, None).unwrap();
/// assert_eq!(mel.shape()[0], 128);
/// ```
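/// A sketch with a custom resolution (40 bands between 20 Hz and 8 kHz); the frame
/// count depends on the STFT layout, so only the band dimension is asserted:
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::melspectrogram;
/// let signal = AudioData { samples: vec![0.0; 4096], sample_rate: 22050, channels: 1 };
/// let mel = melspectrogram(&signal, None, None, None, Some(40), Some(20.0), Some(8000.0)).unwrap();
/// assert_eq!(mel.shape()[0], 40);
/// ```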
pub fn melspectrogram(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    n_mels: Option<usize>,
    fmin: Option<f32>,
    fmax: Option<f32>,
) -> Result<Array2<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let n_mels = n_mels.unwrap_or(128);
    let fmin = fmin.unwrap_or(0.0);
    let fmax = fmax.unwrap_or(signal.sample_rate as f32 / 2.0);
    if n_fft == 0 || hop == 0 || n_mels == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft, hop_length, and n_mels must be positive".to_string(),
        ));
    }
    if fmin < 0.0 || fmax <= fmin || fmax > signal.sample_rate as f32 / 2.0 {
        return Err(SpectralError::InvalidParameter(
            "fmin and fmax must satisfy 0 <= fmin < fmax <= sr/2".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm().powi(2)),
    };

    let mel_f = crate::mel_frequencies(Some(n_mels), Some(fmin), Some(fmax), None);
    let mut mel_s = Array2::zeros((n_mels, s.shape()[1]));
    let fft_f = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    for m in 0..n_mels {
        let f_low = if m == 0 { fmin } else { mel_f[m - 1] };
        let f_center = mel_f[m];
        let f_high = mel_f.get(m + 1).copied().unwrap_or(fmax);
        for (bin, &f) in fft_f.iter().enumerate() {
            let weight = if f >= f_low && f <= f_high {
                if f <= f_center {
                    (f - f_low) / (f_center - f_low)
                } else {
                    (f_high - f) / (f_high - f_center)
                }
            } else {
                0.0
            };
            for t in 0..s.shape()[1] {
                mel_s[[m, t]] += s[[bin, t]] * weight.max(0.0);
            }
        }
    }
    Ok(mel_s)
}

/// Computes Mel-frequency cepstral coefficients (MFCCs).
///
/// Extracts MFCCs from a mel spectrogram using a type-II DCT.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed mel spectrogram.
/// * `n_mfcc` - Optional number of MFCCs (defaults to 20).
/// * `dct_type` - Optional DCT type (defaults to 2; only 2 is supported).
/// * `norm` - Optional normalization type ("ortho" or None).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing a 2D array of shape `(n_mfcc, n_frames)`
/// with MFCCs, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::mfcc;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let coeffs = mfcc(&signal, None, None, None, None).unwrap();
/// assert_eq!(coeffs.shape()[0], 20);
/// ```
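/// A sketch feeding a pre-computed (placeholder) mel spectrogram and requesting
/// orthonormal scaling; the input shape fixes the output frame count:
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::mfcc;
/// use ndarray::Array2;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let mel = Array2::<f32>::ones((128, 4));
/// let coeffs = mfcc(&signal, Some(&mel), Some(13), None, Some("ortho")).unwrap();
/// assert_eq!(coeffs.shape(), &[13, 4]);
/// ```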
pub fn mfcc(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_mfcc: Option<usize>,
    dct_type: Option<i32>,
    norm: Option<&str>,
) -> Result<Array2<f32>, SpectralError> {
    let n_mfcc = n_mfcc.unwrap_or(20);
    let dct_type = dct_type.unwrap_or(2);
    if n_mfcc == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_mfcc must be positive".to_string(),
        ));
    }
    if dct_type != 2 {
        return Err(SpectralError::InvalidParameter(
            "Only DCT type 2 is supported".to_string(),
        ));
    }
    if let Some(n) = norm {
        if n != "ortho" {
            return Err(SpectralError::InvalidParameter(
                "norm must be 'ortho' or None".to_string(),
            ));
        }
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => melspectrogram(signal, None, None, None, None, None, None)?,
    };
    let log_s = s.mapv(|x| x.max(1e-10).ln());
    let mut mfcc = Array2::zeros((n_mfcc, s.shape()[1]));
    for t in 0..s.shape()[1] {
        for k in 0..n_mfcc {
            let mut sum = 0.0;
            for n in 0..s.shape()[0] {
                sum += log_s[[n, t]] * (std::f32::consts::PI * k as f32 * (n as f32 + 0.5) / s.shape()[0] as f32).cos();
            }
            mfcc[[k, t]] = sum * if k == 0 { 1.0 / f32::sqrt(2.0) } else { 1.0 } * 2.0 / s.shape()[0] as f32;
        }
    }
    if norm == Some("ortho") {
        mfcc *= f32::sqrt(2.0 / s.shape()[0] as f32);
    }
    Ok(mfcc)
}

/// Computes root mean square (RMS) energy.
///
/// Calculates RMS energy per frame from either the signal or a spectrogram.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed spectrogram.
/// * `frame_length` - Optional frame length (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to frame_length/4).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing RMS values per frame, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::rms;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let rms = rms(&signal, None, None, None).unwrap();
/// assert_eq!(rms.len(), 1);
/// ```
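/// A sketch computing per-frame RMS from a pre-computed (placeholder) magnitude
/// spectrogram; each column of ones gives an RMS of exactly 1.0:
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::rms;
/// use ndarray::Array2;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let s = Array2::<f32>::ones((1025, 5));
/// let energy = rms(&signal, Some(&s), None, None).unwrap();
/// assert_eq!(energy.len(), 5);
/// assert!((energy[0] - 1.0).abs() < 1e-6);
/// ```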
pub fn rms(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    frame_length: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array1<f32>, SpectralError> {
    let frame_len = frame_length.unwrap_or(2048);
    let hop = hop_length.unwrap_or(frame_len / 4);
    if frame_len == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "frame_length and hop_length must be positive".to_string(),
        ));
    }

    match s {
        Some(s) => Ok(s.map_axis(Axis(0), |row| {
            f32::sqrt(row.iter().map(|x| x.powi(2)).sum::<f32>() / row.len() as f32)
        })),
        None => {
            if signal.samples.len() < frame_len {
                return Err(SpectralError::InvalidSize(
                    "Signal length must be at least frame_length".to_string(),
                ));
            }
            let n_frames = (signal.samples.len() - frame_len) / hop + 1;
            let mut rms = Array1::zeros(n_frames);
            for i in 0..n_frames {
                let start = i * hop;
                let slice = &signal.samples[start..(start + frame_len).min(signal.samples.len())];
                rms[i] = f32::sqrt(slice.iter().map(|x| x.powi(2)).sum::<f32>() / slice.len() as f32);
            }
            Ok(rms)
        }
    }
}

/// Computes spectral centroid frequencies.
///
/// Represents the "center of mass" of the spectrum per frame.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing centroid frequencies per frame, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_centroid;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let centroid = spectral_centroid(&signal, None, None, None).unwrap();
/// assert!(!centroid.is_empty());
/// ```
pub fn spectral_centroid(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array1<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    Ok(s.axis_iter(Axis(1))
        .map(|frame| {
            let total = frame.sum();
            if total > 1e-6 {
                frame.dot(&Array1::from_vec(freqs.clone())) / total
            } else {
                0.0
            }
        })
        .collect())
}

/// Computes spectral bandwidth.
///
/// Measures the spread of the spectrum around the centroid per frame.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `p` - Optional power for the bandwidth calculation (defaults to 2).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing bandwidth values per frame, or an error.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_bandwidth;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let bandwidth = spectral_bandwidth(&signal, None, None, None, None).unwrap();
/// assert!(!bandwidth.is_empty());
/// ```
pub fn spectral_bandwidth(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    p: Option<i32>,
) -> Result<Array1<f32>, SpectralError> {
    let p = p.unwrap_or(2);
    if p <= 0 {
        return Err(SpectralError::InvalidParameter(
            "p must be positive".to_string(),
        ));
    }
    let centroid = spectral_centroid(signal, s, n_fft, hop_length)?;
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };
    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    Ok(s.axis_iter(Axis(1))
        .zip(centroid.iter())
        .map(|(frame, &c)| {
            let total = frame.sum();
            if total > 1e-6 {
                let dev = frame
                    .iter()
                    .zip(freqs.iter())
                    .map(|(&s, &f)| s * (f - c).powi(p))
                    .fold(0.0, |acc, x| acc + x)
                    / total;
                dev.powf(1.0 / p as f32)
            } else {
                0.0
            }
        })
        .collect())
}

/// Computes spectral contrast across frequency bands.
///
/// Calculates the difference between peaks and valleys in subbands.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `n_bands` - Optional number of frequency bands (defaults to 6).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing contrast values of shape `(n_bands + 1, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_contrast;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let contrast = spectral_contrast(&signal, None, None, None, None).unwrap();
/// assert_eq!(contrast.shape()[0], 7);
/// ```
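/// A sketch with an explicit band count; the output always has `n_bands + 1` rows
/// because the range below the first band edge is kept as its own row:
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_contrast;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let contrast = spectral_contrast(&signal, None, None, None, Some(4)).unwrap();
/// assert_eq!(contrast.shape()[0], 5);
/// ```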
pub fn spectral_contrast(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    n_bands: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let n_bands = n_bands.unwrap_or(6);
    if n_fft == 0 || hop == 0 || n_bands == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft, hop_length, and n_bands must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    let band_edges = Array1::logspace(
        2.0,
        0.0,
        f32::log2(signal.sample_rate as f32 / 2.0),
        n_bands + 1,
    );
    let mut contrast = Array2::zeros((n_bands + 1, s.shape()[1]));
    for t in 0..s.shape()[1] {
        for b in 0..n_bands + 1 {
            let f_low = if b == 0 { 0.0 } else { band_edges[b - 1] };
            let f_high = band_edges[b];
            let slice = s.slice(s![.., t]);
            let band: Vec<f32> = slice
                .iter()
                .zip(freqs.iter())
                .filter(|&(_, &f)| f >= f_low && f <= f_high)
                .map(|(&s, _)| s)
                .collect();
            if !band.is_empty() {
                let mut sorted = band;
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let peak = sorted[sorted.len() - 1];
                let valley = sorted[0];
                contrast[[b, t]] = peak - valley;
            }
        }
    }
    Ok(contrast)
}

/// Computes spectral flatness.
///
/// Measures the uniformity of the spectrum per frame (geometric mean / arithmetic mean).
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing flatness values per frame.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_flatness;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let flatness = spectral_flatness(&signal, None, None, None).unwrap();
/// assert!(!flatness.is_empty());
/// ```
pub fn spectral_flatness(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array1<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm().max(1e-10)),
    };

    Ok(s.axis_iter(Axis(1))
        .map(|frame| {
            let log_frame = frame.mapv(f32::ln);
            let geo_mean = log_frame.sum() / frame.len() as f32;
            let arith_mean = frame.sum() / frame.len() as f32;
            f32::exp(geo_mean) / arith_mean
        })
        .collect())
}

/// Computes spectral roll-off frequency.
///
/// Finds the frequency below which a specified percentage of total spectral energy lies.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `roll_percent` - Optional roll-off percentage (defaults to 0.85).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing roll-off frequencies per frame.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_rolloff;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let rolloff = spectral_rolloff(&signal, None, None, None, None).unwrap();
/// assert!(!rolloff.is_empty());
/// ```
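/// A sketch with a pre-computed (placeholder) magnitude spectrogram and a 95%
/// roll-off target; one frequency is returned per frame:
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_rolloff;
/// use ndarray::Array2;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let s = Array2::<f32>::ones((1025, 2));
/// let rolloff = spectral_rolloff(&signal, Some(&s), None, None, Some(0.95)).unwrap();
/// assert_eq!(rolloff.len(), 2);
/// ```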
pub fn spectral_rolloff(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    roll_percent: Option<f32>,
) -> Result<Array1<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let roll_percent = roll_percent.unwrap_or(0.85);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if roll_percent <= 0.0 || roll_percent > 1.0 {
        return Err(SpectralError::InvalidParameter(
            "roll_percent must be between 0 and 1".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    Ok(s.axis_iter(Axis(1))
        .map(|frame| {
            let total_energy = frame.sum();
            let target_energy = total_energy * roll_percent;
            let mut cum_energy = 0.0;
            for (f, &s) in freqs.iter().zip(frame.iter()) {
                cum_energy += s;
                if cum_energy >= target_energy {
                    return *f;
                }
            }
            freqs[freqs.len() - 1]
        })
        .collect())
}

/// Computes polynomial fit coefficients for spectral features.
///
/// Fits a polynomial to each frame's spectral magnitude.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `order` - Optional polynomial order (defaults to 1).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing coefficients of shape `(order + 1, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::poly_features;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let coeffs = poly_features(&signal, None, None, None, None).unwrap();
/// assert_eq!(coeffs.shape()[0], 2);
/// ```
pub fn poly_features(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    order: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let order = order.unwrap_or(1);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let mut coeffs = Array2::zeros((order + 1, s.shape()[1]));
    let x = Array1::linspace(0.0, s.shape()[0] as f32 - 1.0, s.shape()[0]);
    for t in 0..s.shape()[1] {
        let y_t = s.slice(s![.., t]).to_owned();
        let poly = polyfit(&x, &y_t, order);
        for (i, &c) in poly.iter().enumerate() {
            coeffs[[i, t]] = c;
        }
    }
    Ok(coeffs)
}

/// Computes Tonnetz features from chroma.
///
/// Projects chroma features onto a 6-dimensional tonal space.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `chroma` - Optional pre-computed chroma features of shape `(12, n_frames)`.
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing Tonnetz features of shape `(6, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::tonnetz;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let tonnetz = tonnetz(&signal, None).unwrap();
/// assert_eq!(tonnetz.shape()[0], 6);
/// ```
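/// A sketch projecting a pre-computed (placeholder) chroma matrix; the frame count
/// is carried through unchanged:
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::tonnetz;
/// use ndarray::Array2;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let chroma = Array2::<f32>::ones((12, 3));
/// let features = tonnetz(&signal, Some(&chroma)).unwrap();
/// assert_eq!(features.shape(), &[6, 3]);
/// ```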
pub fn tonnetz(
    signal: &AudioData,
    chroma: Option<&Array2<f32>>,
) -> Result<Array2<f32>, SpectralError> {
    // Only compute chroma from the signal when no pre-computed matrix is supplied.
    let chroma_owned;
    let chroma = match chroma {
        Some(c) => c,
        None => {
            chroma_owned = chroma_stft(signal, None, None, None, None)?;
            &chroma_owned
        }
    };
    if chroma.shape()[0] != 12 {
        return Err(SpectralError::InvalidSize(
            "Chroma must have 12 pitch classes".to_string(),
        ));
    }
    let transform = Array2::from_shape_vec(
        (6, 12),
        vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // Fifths
            0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // Minor thirds
            0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // Major thirds
            0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // Minor sevenths
            0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // Major seconds
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, // Tritones
        ],
    )
    .unwrap();
    Ok(transform.dot(chroma))
}

/// Fits a polynomial to data points by least squares.
///
/// Helper function for polynomial feature extraction.
///
/// # Arguments
/// * `x` - X-coordinates.
/// * `y` - Y-coordinates.
/// * `order` - Polynomial order.
///
/// # Returns
/// Returns a vector of polynomial coefficients (constant term first), or zeros if solving fails.
fn polyfit(x: &Array1<f32>, y: &Array1<f32>, order: usize) -> Vec<f32> {
    let n = order + 1;
    // Vandermonde design matrix: one row per sample, one column per power of x.
    let mut a = Array2::zeros((x.len(), n));
    for i in 0..x.len() {
        for j in 0..n {
            a[[i, j]] = x[i].powi(j as i32);
        }
    }
    // The system is overdetermined, so solve the square normal equations
    // (A^T A) c = A^T y rather than calling `solve` on the rectangular matrix.
    let ata = a.t().dot(&a);
    let aty = a.t().dot(y);
    ata.solve(&aty)
        .unwrap_or_else(|_| Array1::zeros(n))
        .to_vec()
}
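
// A small sanity check for the least-squares fit above; a hypothetical test sketch,
// assuming a LAPACK backend is available to `ndarray_linalg` in test builds.
#[cfg(test)]
mod polyfit_tests {
    use super::*;

    #[test]
    fn polyfit_recovers_a_straight_line() {
        // y = 2x + 1 should be recovered up to floating-point error.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let y = x.mapv(|v| 2.0 * v + 1.0);
        let coeffs = polyfit(&x, &y, 1);
        assert!((coeffs[0] - 1.0).abs() < 1e-3);
        assert!((coeffs[1] - 2.0).abs() < 1e-3);
    }
}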

/// Computes spectral flux.
///
/// Measures the change in spectral magnitude between consecutive frames.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing flux values per frame.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_flux;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let flux = spectral_flux(&signal, None, None, None).unwrap();
/// assert!(!flux.is_empty());
/// ```
pub fn spectral_flux(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array1<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let mut flux = Array1::zeros(s.shape()[1]);
    for t in 1..s.shape()[1] {
        let diff = &s.slice(s![.., t]) - &s.slice(s![.., t - 1]);
        flux[t] = diff.mapv(|x| x.powi(2)).sum().sqrt();
    }
    Ok(flux)
}

/// Computes spectral entropy.
///
/// Calculates the entropy of the normalized spectrum per frame.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing entropy values per frame.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_entropy;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let entropy = spectral_entropy(&signal, None, None, None).unwrap();
/// assert!(!entropy.is_empty());
/// ```
pub fn spectral_entropy(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array1<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    Ok(s.axis_iter(Axis(1))
        .map(|frame| {
            let sum = frame.sum();
            if sum <= 1e-10 {
                0.0
            } else {
                let p = frame.mapv(|x| x / sum);
                -p.mapv(|x| if x > 1e-10 { x * x.ln() } else { 0.0 }).sum()
            }
        })
        .collect())
}

/// Computes pitch chroma features.
///
/// Normalizes spectral energy across 12 pitch classes per frame.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing normalized pitch chroma features of shape `(12, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::pitch_chroma;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let chroma = pitch_chroma(&signal, None, None, None).unwrap();
/// assert_eq!(chroma.shape()[0], 12);
/// ```
pub fn pitch_chroma(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    if n_fft == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft and hop_length must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    let mut chroma = Array2::zeros((12, s.shape()[1]));
    for t in 0..s.shape()[1] {
        let frame = s.column(t);
        for (bin, &f) in freqs.iter().enumerate() {
            if frame[bin] > 0.0 {
                let midi = crate::hz_to_midi(&[f])[0];
                let pitch_class = midi.round() as usize % 12;
                chroma[[pitch_class, t]] += frame[bin];
            }
        }
    }
    for t in 0..chroma.shape()[1] {
        let sum = chroma.column(t).sum();
        if sum > 1e-6 {
            chroma.column_mut(t).mapv_inplace(|x| x / sum);
        }
    }
    Ok(chroma)
}

/// Applies cepstral mean and variance normalization (CMVN).
///
/// Normalizes features by subtracting the mean and optionally dividing by the standard deviation.
///
/// # Arguments
/// * `features` - Input feature matrix.
/// * `axis` - Optional axis for normalization (-1 for time, 0 for features; defaults to -1).
/// * `variance` - Optional flag to normalize variance (defaults to true).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing the normalized feature matrix.
///
/// # Examples
/// ```
/// use dasp_rs::signal_processing::spectral::cmvn;
/// use ndarray::Array2;
/// let features = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
/// let normalized = cmvn(&features, None, None).unwrap();
/// assert_eq!(normalized.shape(), &[2, 3]);
/// ```
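/// A sketch normalizing across the feature axis instead of the time axis (axis = 0):
/// ```
/// use dasp_rs::signal_processing::spectral::cmvn;
/// use ndarray::Array2;
/// let features = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
/// let normalized = cmvn(&features, Some(0), None).unwrap();
/// assert_eq!(normalized.shape(), &[3, 2]);
/// ```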
pub fn cmvn(
    features: &Array2<f32>,
    axis: Option<isize>,
    variance: Option<bool>,
) -> Result<Array2<f32>, SpectralError> {
    let axis = axis.unwrap_or(-1);
    let do_variance = variance.unwrap_or(true);
    let ax = if axis < 0 { 1 } else { 0 };

    if features.shape()[ax] < 2 {
        return Err(SpectralError::InvalidSize(
            "Feature dimension too small for normalization".to_string(),
        ));
    }

    let mut normalized = features.to_owned();
    let means = normalized
        .mean_axis(Axis(ax))
        .ok_or(SpectralError::Numerical("Failed to compute mean".to_string()))?;
    for i in 0..normalized.shape()[1 - ax] {
        for j in 0..normalized.shape()[ax] {
            // `i` indexes the retained axis, `j` the axis being normalized over.
            let idx = if ax == 1 { [i, j] } else { [j, i] };
            normalized[idx] -= means[i];
        }
    }

    if do_variance {
        let variances = normalized
            .mapv(|x| x.powi(2))
            .mean_axis(Axis(ax))
            .ok_or(SpectralError::Numerical("Failed to compute variance".to_string()))?;
        let std_devs = variances.mapv(|x| (x + 1e-10).sqrt());
        for i in 0..normalized.shape()[1 - ax] {
            for j in 0..normalized.shape()[ax] {
                let idx = if ax == 1 { [i, j] } else { [j, i] };
                normalized[idx] /= std_devs[i];
            }
        }
    }

    Ok(normalized)
}

/// Performs Harmonic-Percussive Source Separation (HPSS).
///
/// Separates the spectrogram into harmonic and percussive components using median filtering.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed power spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `harm_win` - Optional window size for the harmonic component (defaults to 31).
/// * `perc_win` - Optional window size for the percussive component (defaults to 31).
///
/// # Returns
/// Returns `Result<(Array2<f32>, Array2<f32>), SpectralError>` containing the separated
/// harmonic and percussive components.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::hpss;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let (harmonic, percussive) = hpss(&signal, None, None, None, None, None).unwrap();
/// assert_eq!(harmonic.shape(), percussive.shape());
/// ```
pub fn hpss(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    harm_win: Option<usize>,
    perc_win: Option<usize>,
) -> Result<(Array2<f32>, Array2<f32>), SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let harm_win = harm_win.unwrap_or(31);
    let perc_win = perc_win.unwrap_or(31);
    if n_fft == 0 || hop == 0 || harm_win == 0 || perc_win == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft, hop_length, harm_win, and perc_win must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm().powi(2)),
    };

    let mut harmonic = Array2::zeros(s.dim());
    for f in 0..s.shape()[0] {
        let row = s.index_axis(Axis(0), f);
        for t in 0..s.shape()[1] {
            let start = t.saturating_sub(harm_win / 2);
            let end = (t + harm_win / 2 + 1).min(s.shape()[1]);
            let mut slice: Vec<f32> = row.slice(s![start..end]).to_vec();
            slice.sort_by(|a, b| a.partial_cmp(b).unwrap());
            harmonic[[f, t]] = slice[slice.len() / 2];
        }
    }

    let mut percussive = Array2::zeros(s.dim());
    for t in 0..s.shape()[1] {
        let col = s.index_axis(Axis(1), t);
        for f in 0..s.shape()[0] {
            let start = f.saturating_sub(perc_win / 2);
            let end = (f + perc_win / 2 + 1).min(s.shape()[0]);
            let mut slice: Vec<f32> = col.slice(s![start..end]).to_vec();
            slice.sort_by(|a, b| a.partial_cmp(b).unwrap());
            percussive[[f, t]] = slice[slice.len() / 2];
        }
    }

    let total = harmonic.clone() + percussive.clone();
    let harm_mask = &harmonic / &total.mapv(|x| if x > 0.0 { x } else { 1.0 });
    let perc_mask = &percussive / &total.mapv(|x| if x > 0.0 { x } else { 1.0 });
    Ok((s.to_owned() * &harm_mask, s.to_owned() * &perc_mask))
}

/// Estimates pitch using autocorrelation.
///
/// Detects pitch by finding peaks in the autocorrelation function.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `frame_length` - Optional frame length (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to frame_length/4).
/// * `fmin` - Optional minimum frequency (defaults to 50 Hz).
/// * `fmax` - Optional maximum frequency (defaults to 500 Hz).
///
/// # Returns
/// Returns `Result<Array1<f32>, SpectralError>` containing pitch estimates in Hz per frame.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::pitch_autocorr;
/// let signal = AudioData { samples: vec![0.01; 2048], sample_rate: 44100, channels: 1 };
/// let pitch = pitch_autocorr(&signal, None, None, None, None).unwrap();
/// assert_eq!(pitch.len(), 1);
/// ```
pub fn pitch_autocorr(
    signal: &AudioData,
    frame_length: Option<usize>,
    hop_length: Option<usize>,
    fmin: Option<f32>,
    fmax: Option<f32>,
) -> Result<Array1<f32>, SpectralError> {
    let frame_len = frame_length.unwrap_or(2048);
    let hop = hop_length.unwrap_or(frame_len / 4);
    let fmin = fmin.unwrap_or(50.0);
    let fmax = fmax.unwrap_or(500.0);
    if frame_len == 0 || hop == 0 {
        return Err(SpectralError::InvalidParameter(
            "frame_length and hop_length must be positive".to_string(),
        ));
    }
    if fmin <= 0.0 || fmax <= fmin || fmax > signal.sample_rate as f32 / 2.0 {
        return Err(SpectralError::InvalidParameter(
            "fmin and fmax must satisfy 0 < fmin < fmax <= sr/2".to_string(),
        ));
    }
    if signal.samples.len() < frame_len {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least frame_length".to_string(),
        ));
    }

    let n_frames = (signal.samples.len() - frame_len) / hop + 1;
    let mut pitch = Array1::zeros(n_frames);

    for i in 0..n_frames {
        let start = i * hop;
        let frame = &signal.samples[start..(start + frame_len).min(signal.samples.len())];
        let frame_audio = AudioData {
            samples: frame.to_vec(),
            sample_rate: signal.sample_rate,
            channels: signal.channels,
        };
        let autocorr = autocorrelate(&frame_audio, Some(frame_len))
            .map_err(|e| SpectralError::TimeDomain(e.to_string()))?;
        // Search for the strongest autocorrelation peak within the lag range
        // corresponding to [fmin, fmax].
        let lag_min = (signal.sample_rate as f32 / fmax).round() as usize;
        let lag_max = ((signal.sample_rate as f32 / fmin).round() as usize).min(autocorr.len());
        if lag_min >= lag_max {
            pitch[i] = 0.0;
            continue;
        }
        let max_idx = autocorr[lag_min..lag_max]
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0)
            + lag_min;
        pitch[i] = if max_idx > 0 {
            signal.sample_rate as f32 / max_idx as f32
        } else {
            0.0
        };
    }

    Ok(pitch)
}

/// Computes features for voice activity detection (VAD).
///
/// Extracts log energy, zero-crossing rate, and spectral flatness.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `frame_length` - Optional frame length (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to frame_length/4).
/// * `n_fft` - Optional FFT window size (defaults to 2048).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing features of shape `(3, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::vad_features;
/// let signal = AudioData { samples: vec![0.01; 2048], sample_rate: 44100, channels: 1 };
/// let vad = vad_features(&signal, None, None, None).unwrap();
/// assert_eq!(vad.shape(), &[3, 1]);
/// ```
pub fn vad_features(
    signal: &AudioData,
    frame_length: Option<usize>,
    hop_length: Option<usize>,
    n_fft: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let frame_len = frame_length.unwrap_or(2048);
    let hop = hop_length.unwrap_or(frame_len / 4);
    let n_fft = n_fft.unwrap_or(2048);
    if frame_len == 0 || hop == 0 || n_fft == 0 {
        return Err(SpectralError::InvalidParameter(
            "frame_length, hop_length, and n_fft must be positive".to_string(),
        ));
    }
    if signal.samples.len() < frame_len || signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least max(frame_length, n_fft)".to_string(),
        ));
    }

    let n_frames = (signal.samples.len() - frame_len) / hop + 1;
    let energy = log_energy(signal, Some(frame_len), Some(hop))
        .map_err(|e| SpectralError::TimeDomain(e.to_string()))?;
    let zcr = crate::features::zero_crossing_rate(&signal.samples, Some(frame_len), Some(hop));
    let s = stft(&signal.samples, Some(n_fft), Some(hop), None)
        .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
        .mapv(|x| x.norm());
    let flatness = s.axis_iter(Axis(1))
        .map(|frame| {
            let geo_mean = frame.mapv(|x| x.max(1e-10).ln()).mean().unwrap().exp();
            let arith_mean = frame.mean().unwrap();
            if arith_mean > 1e-10 {
                geo_mean / arith_mean
            } else {
                0.0
            }
        })
        .collect::<Array1<f32>>();

    let mut features = Array2::zeros((3, n_frames));
    for i in 0..n_frames {
        features[[0, i]] = energy[i];
        features[[1, i]] = zcr[i];
        features[[2, i]] = flatness[i];
    }
    Ok(features)
}

/// Computes spectral subband centroids.
///
/// Calculates the centroid frequency for each subband per frame.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `s` - Optional pre-computed magnitude spectrogram.
/// * `n_fft` - Optional FFT window size (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to n_fft/4).
/// * `n_bands` - Optional number of subbands (defaults to 4).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing subband centroids of shape `(n_bands, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::spectral_subband_centroids;
/// let signal = AudioData { samples: vec![0.0; 2048], sample_rate: 44100, channels: 1 };
/// let centroids = spectral_subband_centroids(&signal, None, None, None, None).unwrap();
/// assert_eq!(centroids.shape()[0], 4);
/// ```
pub fn spectral_subband_centroids(
    signal: &AudioData,
    s: Option<&Array2<f32>>,
    n_fft: Option<usize>,
    hop_length: Option<usize>,
    n_bands: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let n_fft = n_fft.unwrap_or(2048);
    let hop = hop_length.unwrap_or(n_fft / 4);
    let n_bands = n_bands.unwrap_or(4);
    if n_fft == 0 || hop == 0 || n_bands == 0 {
        return Err(SpectralError::InvalidParameter(
            "n_fft, hop_length, and n_bands must be positive".to_string(),
        ));
    }
    if signal.samples.len() < n_fft {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least n_fft".to_string(),
        ));
    }

    let s = match s {
        Some(s) => s.to_owned(),
        None => stft(&signal.samples, Some(n_fft), Some(hop), None)
            .map_err(|e| SpectralError::TimeFrequency(e.to_string()))?
            .mapv(|x| x.norm()),
    };

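    // Split [0, sr/2] into `n_bands` equal-width bands; each centroid is the
    // magnitude-weighted mean frequency of the FFT bins that fall inside the band.
    // Empty or near-silent bands fall back to the band's midpoint frequency.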
    let freqs = fft_frequencies(Some(signal.sample_rate), Some(n_fft));
    let band_edges = Array1::linspace(0.0, signal.sample_rate as f32 / 2.0, n_bands + 1);
    let mut centroids = Array2::zeros((n_bands, s.shape()[1]));
    for t in 0..s.shape()[1] {
        for b in 0..n_bands {
            let f_low = band_edges[b];
            let f_high = band_edges[b + 1];
            let subband: Vec<(f32, f32)> = freqs
                .iter()
                .zip(s.column(t))
                .filter(|(f, _)| **f >= f_low && **f < f_high)
                .map(|(f, s)| (*f, *s))
                .collect();
            if subband.is_empty() {
                centroids[[b, t]] = (f_low + f_high) / 2.0;
            } else {
                let total_energy = subband.iter().map(|(_, s)| s).sum::<f32>();
                centroids[[b, t]] = if total_energy > 1e-10 {
                    subband.iter().map(|(f, s)| f * s).sum::<f32>() / total_energy
                } else {
                    (f_low + f_high) / 2.0
                };
            }
        }
    }
    Ok(centroids)
}

/// Estimates formant frequencies using LPC.
///
/// Extracts resonant frequencies from the vocal tract model.
///
/// # Arguments
/// * `signal` - The input audio signal.
/// * `n_formants` - Optional number of formants to extract (defaults to 3).
/// * `frame_length` - Optional frame length (defaults to 2048).
/// * `hop_length` - Optional hop length (defaults to frame_length/4).
///
/// # Returns
/// Returns `Result<Array2<f32>, SpectralError>` containing formant frequencies of shape `(n_formants, n_frames)`.
///
/// # Examples
/// ```
/// use dasp_rs::io::core::AudioData;
/// use dasp_rs::signal_processing::spectral::formant_frequencies;
/// let samples: Vec<f32> = (0..4096).map(|i| ((i * 7919) % 32768) as f32 / 16384.0 - 1.0).collect();
/// let signal = AudioData { samples, sample_rate: 8000, channels: 1 };
/// let formants = formant_frequencies(&signal, None, None, None).unwrap();
/// // (4096 - 2048) / 512 + 1 = 5 frames
/// assert_eq!(formants.shape(), &[3, 5]);
/// ```
pub fn formant_frequencies(
    signal: &AudioData,
    n_formants: Option<usize>,
    frame_length: Option<usize>,
    hop_length: Option<usize>,
) -> Result<Array2<f32>, SpectralError> {
    let n_formants = n_formants.unwrap_or(3);
    let frame_len = frame_length.unwrap_or(2048);
    let hop = hop_length.unwrap_or(frame_len / 4);
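    // Rule-of-thumb LPC order: roughly two poles per kHz of sample rate, plus two.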
    let order = (2.0 * signal.sample_rate as f32 / 1000.0).round() as usize + 2;
    if frame_len == 0 || hop == 0 || n_formants == 0 {
        return Err(SpectralError::InvalidParameter(
            "frame_length, hop_length, and n_formants must be positive".to_string(),
        ));
    }
    if signal.samples.len() < frame_len {
        return Err(SpectralError::InvalidSize(
            "Signal length must be at least frame_length".to_string(),
        ));
    }

    let n_frames = (signal.samples.len() - frame_len) / hop + 1;
    let mut formants = Array2::zeros((n_formants, n_frames));

    for i in 0..n_frames {
        let start = i * hop;
        let frame_slice = &signal.samples[start..(start + frame_len).min(signal.samples.len())];
        let frame = AudioData {
            samples: frame_slice.to_vec(),
            sample_rate: signal.sample_rate,
            channels: signal.channels,
        };
        let lpc_coeffs = lpc(&frame, order)?;
        let roots = polynomial_roots(&lpc_coeffs)?;
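        // Each vocal-tract resonance shows up as a complex-conjugate root pair of the
        // LPC polynomial; its frequency is the root's angle mapped onto the sample
        // rate: f = |arg(r)| * sr / (2 * PI).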
        let mut freqs: Vec<f32> = roots
            .iter()
            .filter_map(|r| {
                // Keep only one root per conjugate pair (positive imaginary part),
                // otherwise every resonance would be counted twice.
                if r.im > 1e-6 {
                    let freq = r.arg().abs() * signal.sample_rate as f32 / (2.0 * std::f32::consts::PI);
                    if freq > 50.0 && freq < signal.sample_rate as f32 / 2.0 {
                        Some(freq)
                    } else {
                        None
                    }
                } else {
                    None
                }
            })
            .collect();
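        // Sort ascending so the lowest resonances map to F1, F2, ... in the output.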
        freqs.sort_by(|a, b| a.partial_cmp(b).unwrap());
        for (j, &f) in freqs.iter().take(n_formants).enumerate() {
            formants[[j, i]] = f;
        }
    }
    Ok(formants)
}

/// Computes Linear Predictive Coding (LPC) coefficients.
///
/// Helper function for formant estimation.
///
/// # Arguments
/// * `frame` - Audio frame as AudioData.
/// * `order` - LPC order.
///
/// # Returns
/// Returns `Result<Vec<f32>, SpectralError>` containing the `order + 1` coefficients
/// `[1.0, a_1, ..., a_order]` of the prediction-error filter.
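///
/// # Examples
/// A rough sketch of the expected output layout (not compiled; this helper is private):
/// ```ignore
/// // `frame` is any AudioData frame with at least `order + 1` samples and non-zero energy.
/// let coeffs = lpc(&frame, 12)?;
/// assert_eq!(coeffs.len(), 13); // [1.0, a_1, ..., a_12]
/// ```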
fn lpc(frame: &AudioData, order: usize) -> Result<Vec<f32>, SpectralError> {
    if frame.samples.len() < order + 1 {
        return Err(SpectralError::InvalidSize(
            "Frame length must be at least LPC order + 1".to_string(),
        ));
    }
    let autocorr = autocorrelate(frame, Some(order + 1))
        .map_err(|e| SpectralError::TimeDomain(e.to_string()))?;
    if autocorr[0] <= 1e-10 {
        return Err(SpectralError::Numerical(
            "Frame energy too low for LPC".to_string(),
        ));
    }

    // Levinson-Durbin recursion on the autocorrelation sequence. `a` holds the
    // prediction-error filter A(z) = 1 + a[1] z^-1 + ... + a[order] z^-order and
    // `e` is the running prediction-error power.
    let mut a = vec![0.0; order + 1];
    a[0] = 1.0;
    let mut e = autocorr[0];
    let mut tmp = vec![0.0; order + 1];

    for i in 1..=order {
        // Reflection coefficient for order i.
        let mut lambda = 0.0;
        for j in 0..i {
            lambda -= a[j] * autocorr[i - j];
        }
        lambda /= e;
        // Order update: a_new[j] = a[j] + lambda * a[i - j], with a[0] kept at 1.
        for j in 1..i {
            tmp[j] = a[j] + lambda * a[i - j];
        }
        a[1..i].copy_from_slice(&tmp[1..i]);
        a[i] = lambda;
        e *= 1.0 - lambda * lambda;
        if e <= 1e-10 {
            return Err(SpectralError::Numerical(
                "LPC instability detected".to_string(),
            ));
        }
    }
    Ok(a)
}

/// Computes roots of a polynomial.
///
/// Helper function for formant estimation.
///
/// # Arguments
/// * `coeffs` - Polynomial coefficients (highest degree first).
///
/// # Returns
/// Returns `Result<Vec<Complex<f32>>, SpectralError>` containing complex roots.
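///
/// # Examples
/// A rough sketch of the intended behaviour (not compiled; this helper is private):
/// ```ignore
/// // x^2 - 3x + 2 = (x - 1)(x - 2), so the roots are 1 and 2.
/// let roots = polynomial_roots(&[1.0, -3.0, 2.0])?;
/// assert_eq!(roots.len(), 2);
/// ```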
fn polynomial_roots(coeffs: &[f32]) -> Result<Vec<Complex<f32>>, SpectralError> {
    if coeffs.len() <= 1 {
        return Ok(vec![]);
    }

    // Roots are the eigenvalues of the companion matrix of the monic polynomial
    // obtained by dividing through by the leading coefficient.
    let n = coeffs.len() - 1;
    let a_lead = coeffs[0];
    if a_lead.abs() < 1e-10 {
        return Err(SpectralError::Numerical(
            "Leading coefficient too small".to_string(),
        ));
    }
    let mut companion = Array2::zeros((n, n));
    for i in 0..n - 1 {
        companion[[i + 1, i]] = 1.0;
    }
    // With coefficients ordered highest degree first, the coefficient of x^i is
    // coeffs[n - i], so the last column holds -coeffs[n - i] / a_lead.
    for i in 0..n {
        companion[[i, n - 1]] = -coeffs[n - i] / a_lead;
    }

    let eigenvalues = companion
        .eig()
        .map_err(|e| SpectralError::Numerical(format!("Eigenvalue computation failed: {}", e)))?;
    Ok(eigenvalues.0.to_vec())
}
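
// A minimal, self-contained sanity check for `polynomial_roots`, kept as a sketch:
// it assumes only the highest-degree-first coefficient convention documented above.
#[cfg(test)]
mod spectral_helper_tests {
    use super::*;

    #[test]
    fn quadratic_roots_are_recovered() {
        // x^2 - 3x + 2 = (x - 1)(x - 2)
        let mut roots = polynomial_roots(&[1.0, -3.0, 2.0]).unwrap();
        roots.sort_by(|a, b| a.re.partial_cmp(&b.re).unwrap());
        assert!((roots[0].re - 1.0).abs() < 1e-4 && roots[0].im.abs() < 1e-4);
        assert!((roots[1].re - 2.0).abs() < 1e-4 && roots[1].im.abs() < 1e-4);
    }
}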