autopitch 0.1.0

A modular pitch detection library
Documentation
//! Autocorrelation-based pitch detection
//!
//! This algorithm works by computing the similarity (dot product) of the signal with a delayed version of itself
//! for various lag values. To avoid octave errors, it selects the first significant peak (above 80% of the maximum)
//! rather than the absolute maximum. The lag is then refined using parabolic interpolation for sub-sample accuracy.

#[cfg(test)]
mod tests;

use crate::detect::PitchDetector;

/// Threshold ratio for first-peak detection.
///
/// The best autocorrelation score is multiplied by this ratio to obtain the minimum
/// score a local peak must reach to be accepted (0.8 = 80% of the best score).
const PEAK_THRESHOLD_RATIO: f32 = 0.8;

/// Minimum signal variance required for pitch detection.
///
/// Input whose variance is below this value is treated as silent or constant,
/// and no pitch is reported for it.
const MIN_VARIANCE_THRESHOLD: f32 = 1e-6;

/// Autocorrelation-based pitch detector.
///
/// For every lag in the configured range the detector scores how well the signal lines
/// up with a copy of itself delayed by that lag (the mean of the sample products).
/// The pitch is taken from the first local peak whose score reaches 80% of the best
/// score rather than from the best-scoring lag itself, which guards against octave
/// errors. The best-scoring lag is only used as a fallback when no such peak exists.
///
/// ### Defaults
/// By default, the detector searches between [`Self::DEFAULT_MIN_LAG`] and [`Self::DEFAULT_MAX_LAG`],
/// which covers the human-audible pitch range (~22 Hz to ~2.2 kHz) for 44.1kHz audio.
/// This is wide enough to detect most musical instruments and the full human vocal range,
/// while keeping the computation efficient.
#[derive(Debug, Clone)]
pub struct Autocorrelation {
    /// Minimum lag to test (in samples). A higher value lowers the highest detectable frequency.
    min_lag: usize,
    /// Maximum lag to test (in samples). A higher value lowers the lowest detectable frequency.
    max_lag: usize,
    /// Autocorrelation score for each tested lag, indexed by the lag itself.
    ///
    /// The vector is reused between calls to [`PitchDetector::detect`]
    /// to avoid reallocating memory during repeated or real-time processing.
    lag_scores: Vec<f32>,
}

impl Autocorrelation {
    /// The default minimum lag (in samples) to test for autocorrelation.
    ///
    /// A smaller lag corresponds to higher frequencies (shorter wavelengths).
    ///
    /// At a typical sample rate of **44.1 kHz**, `DEFAULT_MIN_LAG = 20` corresponds to:
    ///
    /// ```text
    /// frequency = sample_rate / lag = 44100 / 20 ≈ 2205 Hz
    /// ```
    ///
    /// This upper frequency bound (~2.2 kHz) covers the highest fundamental frequencies
    /// found in most musical instruments and voices while filtering out very high-pitched
    /// overtones and noise that are not typically perceived as "pitch".
    pub const DEFAULT_MIN_LAG: usize = 20;

    /// The default maximum lag (in samples) to test for autocorrelation.
    ///
    /// A larger lag corresponds to lower frequencies (longer wavelengths).
    ///
    /// At a typical sample rate of **44.1 kHz**, `DEFAULT_MAX_LAG = 2000` corresponds to:
    ///
    /// ```text
    /// frequency = sample_rate / lag = 44100 / 2000 ≈ 22 Hz
    /// ```
    ///
    /// This lower frequency bound (~22 Hz) is close to the **threshold of human hearing**.
    /// Going below this range is unnecessary for most musical or vocal applications and
    /// adds computational cost without improving useful detection results.
    pub const DEFAULT_MAX_LAG: usize = 2000;

    /// Creates a detector that searches lags from `min_lag` to `max_lag` (inclusive, in samples).
    ///
    /// The score buffer is allocated once here, with one slot per lag up to `max_lag`,
    /// so that detection itself does not allocate.
    ///
    /// A `min_lag` of 0 is raised to 1. Lag 0 compares the signal with itself: it has no
    /// period to report, yet it would normally out-score every real lag and leave the
    /// detector without a usable result.
    ///
    /// If `min_lag` is greater than `max_lag` the search range is empty and the detector
    /// never reports a pitch.
    pub fn new(min_lag: usize, max_lag: usize) -> Self {
        Self {
            min_lag: min_lag.max(1),
            max_lag,
            lag_scores: vec![0.0; max_lag + 1],
        }
    }
}

impl Autocorrelation {
    /// Returns the smallest lag strictly between `min_lag` and `max_lag` whose score is a
    /// local maximum and is at least `threshold`, or `None` if there is no such lag.
    ///
    /// A lag counts as a local maximum when its score is strictly greater than the scores
    /// of both adjacent lags. Taking the first qualifying peak instead of the overall best
    /// one helps avoid octave errors, where a multiple of the true period scores similarly.
    fn find_first_peak(&self, threshold: f32) -> Option<usize> {
        // The end lags are excluded so that both neighbours lie inside the searched range.
        ((self.min_lag + 1)..self.max_lag).find(|&lag| {
            let before = self.lag_scores[lag - 1];
            let here = self.lag_scores[lag];
            let after = self.lag_scores[lag + 1];

            here >= threshold && here > before && here > after
        })
    }
}

impl Default for Autocorrelation {
    /// Builds a detector that searches the default lag range, from
    /// `DEFAULT_MIN_LAG` to `DEFAULT_MAX_LAG`.
    fn default() -> Self {
        let shortest_lag = Self::DEFAULT_MIN_LAG;
        let longest_lag = Self::DEFAULT_MAX_LAG;

        Self::new(shortest_lag, longest_lag)
    }
}

impl PitchDetector for Autocorrelation {
    /// Estimates the pitch of `samples`, in Hz, using autocorrelation.
    ///
    /// This algorithm iterates over possible lag values (sample offsets) within the range
    /// `[min_lag, max_lag]`, scoring each lag with the mean product of the signal and a
    /// copy of itself delayed by that lag.
    ///
    /// To avoid octave errors, it selects the **first significant peak** (above 80% of the
    /// maximum score) rather than the absolute maximum, falling back to the best-scoring
    /// lag when no such peak exists. The selected lag is then refined using parabolic
    /// interpolation for sub-sample accuracy.
    ///
    /// The detected pitch frequency is then calculated as:
    ///
    /// ```text
    /// pitch (Hz) = sample_rate / lag
    /// ```
    ///
    /// # Returns
    /// `None` when `sample_rate` is not a finite, positive number, when `samples` is
    /// empty, silent or constant, or when the search finds no usable lag; otherwise the
    /// estimated pitch.
    fn detect(&mut self, samples: &[f32], sample_rate: f32) -> Option<f32> {
        // Dividing a zero, negative, NaN or infinite rate by a lag cannot yield a pitch.
        if !sample_rate.is_finite() || sample_rate <= 0.0 {
            return None;
        }

        if is_silent(samples) {
            return None;
        }

        // Scores from a previous call must not leak into this one. Lags that are skipped
        // below (because the buffer is too short) keep a score of 0.
        self.lag_scores.fill(0.0);

        // A best lag of 0 means "no lag produced a positive score yet".
        let mut best_lag = 0;
        let mut best_score = 0.0_f32;

        for lag in self.min_lag..=self.max_lag {
            // Once the delayed copy no longer overlaps the signal, larger lags cannot either.
            let Some(overlap_len) = samples.len().checked_sub(lag).filter(|&n| n > 0) else {
                break;
            };

            let dot_product: f32 = samples[..overlap_len]
                .iter()
                .zip(&samples[lag..])
                .fold(0.0, |acc, (current, delayed)| acc + current * delayed);

            // Normalize by the number of summed terms so that long lags, which overlap
            // fewer samples, are not penalized.
            let score = dot_product / overlap_len as f32;
            self.lag_scores[lag] = score;

            if score > best_score {
                best_score = score;
                best_lag = lag;
            }
        }

        if best_lag == 0 {
            return None;
        }

        // Find first peak that's at least 80% of the max score.
        // This avoids octave errors where lag=2N scores similarly to lag=N.
        let threshold = best_score * PEAK_THRESHOLD_RATIO;
        let peak_lag = self.find_first_peak(threshold).unwrap_or(best_lag);

        // Without a usable refinement the integer lag is used as-is (offset 0).
        let delta = parabolic_interpolation(&self.lag_scores, peak_lag).unwrap_or_default();
        Some(sample_rate / (peak_lag as f32 + delta))
    }
}

/// Refines the estimated lag using parabolic interpolation.
///
/// This improves sub-sample pitch accuracy by fitting a parabola through
/// three autocorrelation values around the peak: (lag-1, lag, lag+1).
///
/// The formula used is:
///
/// ```text
/// δ = (y₋₁ - y₊₁) / [2 * (y₋₁ - 2y₀ + y₊₁)]
/// ```
///
/// **Note:**
/// This is the vertex `-b / 2a` of the parabola `a·x² + b·x + c` through the three
/// points. Some references write it as `(y₊₁ - y₋₁) / [2 * (2y₀ - y₋₁ - y₊₁)]`, with
/// both numerator and denominator negated; that is the same value. Negating only one
/// of the two would flip the sign of δ.
///
/// # Returns
/// The fractional offset δ in samples, where:
/// - δ > 0 → the true peak is slightly to the right of the integer lag
/// - δ < 0 → the true peak is slightly to the left
///
/// Returns `None` if `best_lag` has no neighbor on either side within `scores`, or if
/// the curve is too flat or poorly defined to determine a peak reliably. Flatness is
/// judged relative to the magnitude of the three scores, so the outcome does not
/// depend on how loud the signal is.
///
/// # Reference
/// This technique is described in:
/// - Julius O. Smith III, *Spectral Audio Signal Processing*, Stanford University,
///   [Quadratic Interpolation of Spectral Peaks](https://www.dsprelated.com/freebooks/sasp/Quadratic_Interpolation_Spectral_Peaks.html)
fn parabolic_interpolation(scores: &[f32], best_lag: usize) -> Option<f32> {
    let y_m = *scores.get(best_lag.checked_sub(1)?)?;
    let y_0 = *scores.get(best_lag)?;
    let y_p = *scores.get(best_lag.checked_add(1)?)?;

    let numerator = y_m - y_p;
    let denominator = 2.0 * (y_m - 2.0 * y_0 + y_p);

    // `f32::EPSILON` is the spacing of floats near 1.0, so it only works as a tolerance
    // once scaled to the magnitude of the values involved. The scores grow with signal
    // power, so an unscaled tolerance would reject every quiet signal.
    let scale = y_m.abs().max(y_0.abs()).max(y_p.abs());
    if denominator.is_nan() || denominator.abs() <= f32::EPSILON * scale {
        return None;
    }

    Some(numerator / denominator)
}

/// Reports whether `samples` carries no meaningful variation.
///
/// An empty slice counts as silent. Otherwise the population variance of the samples
/// is compared against `MIN_VARIANCE_THRESHOLD`; anything below it — silence or a
/// constant (DC) level — is treated as silent so that flat input does not produce a
/// false pitch.
fn is_silent(samples: &[f32]) -> bool {
    if let [] = samples {
        return true;
    }

    let count = samples.len() as f32;
    let average = samples.iter().sum::<f32>() / count;

    let squared_deviations: f32 = samples
        .iter()
        .map(|sample| (sample - average).powi(2))
        .sum();

    squared_deviations / count < MIN_VARIANCE_THRESHOLD
}