Skip to main content

math_audio_dsp/
esprit.rs

// ============================================================================
// ESPRIT Frequency/Phase Estimator
// ============================================================================
//
// Super-resolution frequency estimation using the ESPRIT (Estimation of Signal
// Parameters via Rotational Invariance Techniques) algorithm.
//
// Resolves frequencies beyond FFT bin resolution by exploiting shift-invariance
// in the signal subspace of a Hankel data matrix.
//
// Complexity: dominated by the SVD of the (N - M + 1) x M Hankel matrix, i.e.
// O((N - M + 1) * M^2); with the default M = N / 3 this is O(N^3).
12
13use nalgebra::DMatrix;
14
/// One sinusoidal component recovered from a signal.
///
/// This is a plain data record: the fields are public and no invariants are
/// enforced between them.
#[derive(Clone, Debug)]
pub struct SinusoidEstimate {
    /// Component frequency, in hertz.
    pub frequency: f64,
    /// Component amplitude on a linear (not dB) scale.
    pub amplitude: f64,
    /// Phase angle in radians, within [-π, π].
    pub phase: f64,
}
25
/// Information-theoretic rule used to choose how many components a signal holds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelOrderCriterion {
    /// Minimum Description Length: its complexity penalty scales with the log of
    /// the snapshot count, making it the more conservative choice.
    Mdl,
    /// Akaike Information Criterion: a fixed per-parameter penalty, making it
    /// the less conservative choice.
    Aic,
}
34
35/// Estimate the number of sinusoidal components using information-theoretic criteria.
36///
37/// # Arguments
38/// * `singular_values` - Singular values from SVD of the Hankel matrix (descending order)
39/// * `num_snapshots` - Number of rows (snapshots) in the Hankel matrix
40/// * `criterion` - MDL or AIC
41///
42/// # Returns
43/// Estimated model order (number of sinusoids)
44pub fn estimate_model_order(
45    singular_values: &[f64],
46    num_snapshots: usize,
47    criterion: ModelOrderCriterion,
48) -> usize {
49    let m = singular_values.len();
50    let n = num_snapshots as f64;
51
52    if m == 0 {
53        return 0;
54    }
55
56    let mut best_p = 0;
57    let mut best_cost = f64::INFINITY;
58
59    // Eigenvalues = singular_values^2
60    let eigenvalues: Vec<f64> = singular_values.iter().map(|s| s * s).collect();
61
62    for p in 0..m {
63        let noise_dim = m - p;
64        if noise_dim == 0 {
65            break;
66        }
67
68        // Geometric and arithmetic mean of noise eigenvalues
69        let noise_eigs = &eigenvalues[p..];
70        let arith_mean = noise_eigs.iter().sum::<f64>() / noise_dim as f64;
71
72        if arith_mean <= 0.0 {
73            break;
74        }
75
76        let log_geo_mean =
77            noise_eigs.iter().map(|&e| (e.max(1e-30)).ln()).sum::<f64>() / noise_dim as f64;
78        let geo_mean = log_geo_mean.exp();
79
80        let ratio = geo_mean / arith_mean;
81        if ratio <= 0.0 || !ratio.is_finite() {
82            break;
83        }
84
85        let log_likelihood = -n * noise_dim as f64 * ratio.ln();
86
87        let num_free_params = p as f64 * (2.0 * m as f64 - p as f64);
88
89        let cost = match criterion {
90            ModelOrderCriterion::Mdl => -log_likelihood + 0.5 * num_free_params * n.ln(),
91            ModelOrderCriterion::Aic => -2.0 * log_likelihood + 2.0 * num_free_params,
92        };
93
94        if cost < best_cost {
95            best_cost = cost;
96            best_p = p;
97        }
98    }
99
100    best_p
101}
102
103/// Core ESPRIT algorithm for super-resolution frequency estimation.
104///
105/// # Arguments
106/// * `signal` - Input signal samples
107/// * `sample_rate` - Sample rate in Hz
108/// * `model_order` - Number of sinusoids to estimate (None = auto-detect via MDL)
109/// * `window_size` - Hankel matrix column count (None = signal.len() / 3)
110///
111/// # Returns
112/// Vector of estimated sinusoidal components, sorted by amplitude (descending)
113pub fn esprit(
114    signal: &[f32],
115    sample_rate: f32,
116    model_order: Option<usize>,
117    window_size: Option<usize>,
118) -> Vec<SinusoidEstimate> {
119    let n = signal.len();
120    if n < 4 {
121        return Vec::new();
122    }
123
124    // Default M = N/3
125    let m = window_size.unwrap_or(n / 3).max(2).min(n - 1);
126    let num_rows = n - m + 1;
127
128    if num_rows < 2 || m < 2 {
129        return Vec::new();
130    }
131
132    // Build Hankel matrix X (num_rows x m)
133    let hankel = DMatrix::from_fn(num_rows, m, |i, j| signal[i + j] as f64);
134
135    // SVD
136    let svd = hankel.svd(true, true);
137    let singular_values = svd.singular_values.as_slice();
138
139    // Guard against near-zero signals (degenerate subspace)
140    if singular_values.is_empty() || singular_values[0] < f64::EPSILON * n as f64 {
141        return Vec::new();
142    }
143
144    // Determine model order
145    // For real-valued signals, each sinusoid contributes 2 eigenvalues (conjugate pair),
146    // so we need to double the requested model order.
147    let p = match model_order {
148        Some(p) => (p * 2).min(m - 1).min(num_rows - 1),
149        None => {
150            let auto_p = estimate_model_order(singular_values, num_rows, ModelOrderCriterion::Mdl);
151            // Ensure at least 2 for real signals
152            auto_p.max(2).min(m - 1).min(num_rows - 1)
153        }
154    };
155
156    if p == 0 {
157        return Vec::new();
158    }
159
160    // Extract signal subspace: first P columns of V
161    let v_full = match &svd.v_t {
162        Some(v_t) => v_t.transpose(),
163        None => return Vec::new(),
164    };
165
166    if v_full.ncols() < p || v_full.nrows() < m {
167        return Vec::new();
168    }
169
170    let v_s = v_full.columns(0, p);
171
172    // Shift-invariance: V_1 = V_s[0..m-1, :], V_2 = V_s[1..m, :]
173    let v1 = v_s.rows(0, m - 1).clone_owned();
174    let v2 = v_s.rows(1, m - 1).clone_owned();
175
176    // Compute Phi = pinv(V_1) * V_2
177    let v1_svd = v1.svd(true, true);
178    let phi = match v1_svd.solve(&v2, 1e-10) {
179        Ok(phi) => phi,
180        Err(_) => return Vec::new(),
181    };
182
183    // Extract complex eigenvalues from real matrix
184    // nalgebra::DMatrix<f64>::complex_eigenvalues() handles 2x2 blocks in real Schur form
185    let eigenvalues = phi.complex_eigenvalues();
186
187    let mut estimates = Vec::with_capacity(p);
188    for lambda in eigenvalues.iter() {
189        // Frequency from angle of eigenvalue
190        let angle = lambda.im.atan2(lambda.re);
191        let freq = sample_rate as f64 * angle / (2.0 * std::f64::consts::PI);
192
193        // Only keep positive frequencies below Nyquist
194        if freq > 0.0 && freq < sample_rate as f64 / 2.0 {
195            let amplitude = estimate_amplitude(signal, freq, sample_rate as f64);
196            let phase = estimate_phase(signal, freq, sample_rate as f64);
197
198            estimates.push(SinusoidEstimate {
199                frequency: freq,
200                amplitude,
201                phase,
202            });
203        }
204    }
205
206    // Sort by amplitude (descending)
207    estimates.sort_by(|a, b| {
208        b.amplitude
209            .partial_cmp(&a.amplitude)
210            .unwrap_or(std::cmp::Ordering::Equal)
211    });
212
213    estimates
214}
215
/// Correlate `signal` with a cosine/sine pair at `freq`.
///
/// Returns `(Σ s[n]·cos(ωn), Σ s[n]·sin(ωn))` with `ω = 2π·freq / sample_rate`
/// and `n` starting at 0 for the first sample. The DFT-style coefficient at `ω`
/// is `X(ω) = Σ s[n]·e^{-jωn} = sum_cos - j·sum_sin`.
fn correlate_with_tone(signal: &[f32], freq: f64, sample_rate: f64) -> (f64, f64) {
    let omega = 2.0 * std::f64::consts::PI * freq / sample_rate;
    signal
        .iter()
        .enumerate()
        .fold((0.0, 0.0), |(sum_cos, sum_sin), (i, &s)| {
            let (sin, cos) = (omega * i as f64).sin_cos();
            (sum_cos + s as f64 * cos, sum_sin + s as f64 * sin)
        })
}

/// Estimate the amplitude of a sinusoid at `freq` by projecting `signal` onto
/// a cosine/sine pair at that frequency.
///
/// Returns `2·|X(ω)| / N`. This equals the amplitude `A` of `A·cos(ωn + φ)`
/// exactly when the tone completes a whole number of cycles in the window, and
/// approximately otherwise. Leakage from the negative-frequency image and from
/// other components biases the result. This is a single-bin DFT projection,
/// not a full least-squares fit.
///
/// Returns `0.0` for an empty signal.
fn estimate_amplitude(signal: &[f32], freq: f64, sample_rate: f64) -> f64 {
    if signal.is_empty() {
        return 0.0;
    }
    let (sum_cos, sum_sin) = correlate_with_tone(signal, freq, sample_rate);
    2.0 * sum_cos.hypot(sum_sin) / signal.len() as f64
}

/// Estimate the phase of a sinusoid at `freq`.
///
/// Uses the DFT convention: for `s[n] = A·cos(ωn + φ)`, with `n = 0` at the
/// first sample, the result approximates `φ`, i.e. `arg X(ω)`. The result lies
/// in `[-π, π]`; it is exact under the same whole-number-of-cycles condition
/// as [`estimate_amplitude`].
fn estimate_phase(signal: &[f32], freq: f64, sample_rate: f64) -> f64 {
    let (sum_cos, sum_sin) = correlate_with_tone(signal, freq, sample_rate);
    // X = sum_cos - j·sum_sin, so arg X = atan2(-sum_sin, sum_cos).
    (-sum_sin).atan2(sum_cos)
}
248
249/// Convenience wrapper: estimate frequencies from a signal.
250///
251/// # Arguments
252/// * `signal` - Input signal samples
253/// * `sample_rate` - Sample rate in Hz
254/// * `max_sinusoids` - Maximum number of sinusoids to return
255///
256/// # Returns
257/// Sorted vector of estimated frequencies in Hz
258pub fn estimate_frequencies(signal: &[f32], sample_rate: f32, max_sinusoids: usize) -> Vec<f64> {
259    let estimates = esprit(signal, sample_rate, Some(max_sinusoids), None);
260    let mut freqs: Vec<f64> = estimates
261        .iter()
262        .take(max_sinusoids)
263        .map(|e| e.frequency)
264        .collect();
265    freqs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
266    freqs
267}
268
269// ============================================================================
270// Tests
271// ============================================================================
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate `num_samples` samples of `amplitude * sin(2π·freq·t + phase)`.
    fn gen_sinusoid(
        freq: f64,
        amplitude: f64,
        phase: f64,
        sample_rate: f64,
        num_samples: usize,
    ) -> Vec<f32> {
        (0..num_samples)
            .map(|i| {
                let t = i as f64 / sample_rate;
                (amplitude * (2.0 * std::f64::consts::PI * freq * t + phase).sin()) as f32
            })
            .collect()
    }

    #[test]
    fn test_pure_tone_1000_5_hz() {
        let sample_rate = 48000.0_f32;
        let freq = 1000.5;
        let signal = gen_sinusoid(freq, 1.0, 0.0, sample_rate as f64, 1024);

        let estimates = esprit(&signal, sample_rate, Some(1), None);
        assert!(
            !estimates.is_empty(),
            "ESPRIT should find at least one component"
        );

        let est_freq = estimates[0].frequency;
        let error = (est_freq - freq).abs();
        assert!(
            error < 1.0,
            "ESPRIT frequency error {error:.3} Hz exceeds 1 Hz threshold (estimated {est_freq:.3} vs actual {freq})"
        );
    }

    #[test]
    fn test_two_close_tones() {
        let sample_rate = 48000.0_f32;
        let signal1 = gen_sinusoid(1000.0, 1.0, 0.0, sample_rate as f64, 2048);
        let signal2 = gen_sinusoid(1050.0, 0.8, 0.5, sample_rate as f64, 2048);
        let signal: Vec<f32> = signal1.iter().zip(&signal2).map(|(&a, &b)| a + b).collect();

        let freqs = estimate_frequencies(&signal, sample_rate, 4);
        assert!(
            freqs.len() >= 2,
            "Should find at least 2 frequencies, found {}",
            freqs.len()
        );

        // Check that both frequencies are close to expected values
        let has_1000 = freqs.iter().any(|&f| (f - 1000.0).abs() < 5.0);
        let has_1050 = freqs.iter().any(|&f| (f - 1050.0).abs() < 5.0);
        assert!(has_1000, "Should find ~1000 Hz in {freqs:?}");
        assert!(has_1050, "Should find ~1050 Hz in {freqs:?}");
    }

    #[test]
    fn test_white_noise_low_model_order() {
        // Zero-mean pseudo-random noise in [-1, 1] from a 64-bit LCG. The top 32
        // bits are taken as a u32 so the whole unit interval is covered. The
        // previous `>> 33` only reached [0, 0.5), which left a -0.5 DC offset in
        // the "white" noise.
        let mut rng_state: u64 = 12345;
        let signal: Vec<f32> = (0..512)
            .map(|_| {
                rng_state = rng_state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((rng_state >> 32) as u32 as f32 / u32::MAX as f32) * 2.0 - 1.0
            })
            .collect();

        let num_cols = signal.len() / 3;
        let num_rows = signal.len() - num_cols + 1;
        let hankel = DMatrix::from_fn(num_rows, num_cols, |i, j| signal[i + j] as f64);
        let svd = hankel.svd(false, false);

        let p = estimate_model_order(
            svd.singular_values.as_slice(),
            num_rows,
            ModelOrderCriterion::Mdl,
        );

        assert!(p <= 5, "White noise model order should be small, got {p}");
    }

    #[test]
    fn test_three_sinusoids_known_answer() {
        let sample_rate = 48000.0_f32;
        let freqs_expected = [440.0, 880.0, 1320.0];
        let amps = [1.0, 0.5, 0.3];
        let phases = [0.0, 0.7, -1.2];

        let num_samples = 4096;
        let mut signal = vec![0.0f32; num_samples];
        for ((&freq, &amp), &phase) in freqs_expected.iter().zip(&amps).zip(&phases) {
            let s = gen_sinusoid(freq, amp, phase, sample_rate as f64, num_samples);
            for (i, &v) in s.iter().enumerate() {
                signal[i] += v;
            }
        }

        let estimates = esprit(&signal, sample_rate, Some(3), None);
        assert!(
            estimates.len() >= 3,
            "Should find 3 sinusoids, found {}",
            estimates.len()
        );

        let est_freqs = estimate_frequencies(&signal, sample_rate, 3);
        for &expected in &freqs_expected {
            assert!(
                est_freqs.iter().any(|&f| (f - expected).abs() < 3.0),
                "Expected frequency {expected} Hz not found in {est_freqs:?}"
            );
        }
    }

    #[test]
    fn test_near_zero_signal_no_phantom() {
        // Regression: near-zero signals should not produce phantom frequencies
        let signal = vec![1e-38f32; 512];
        let result = esprit(&signal, 48000.0, Some(2), None);
        assert!(
            result.is_empty(),
            "Near-zero signal should produce no estimates, got {} components",
            result.len()
        );
    }

    #[test]
    fn test_all_zero_signal() {
        let signal = vec![0.0f32; 512];
        let result = esprit(&signal, 48000.0, Some(2), None);
        assert!(
            result.is_empty(),
            "All-zero signal should produce no estimates"
        );
    }

    #[test]
    fn test_empty_signal() {
        let result = esprit(&[], 48000.0, Some(1), None);
        assert!(result.is_empty());
    }

    #[test]
    fn test_very_short_signal() {
        let result = esprit(&[1.0, 2.0], 48000.0, Some(1), None);
        assert!(result.is_empty());
    }

    #[test]
    fn test_estimate_frequencies_convenience() {
        let sample_rate = 48000.0_f32;
        let signal = gen_sinusoid(500.0, 1.0, 0.0, sample_rate as f64, 2048);
        let freqs = estimate_frequencies(&signal, sample_rate, 2);
        assert!(!freqs.is_empty());
        assert!(
            (freqs[0] - 500.0).abs() < 2.0,
            "Expected ~500 Hz, got {:.2}",
            freqs[0]
        );
    }
}