autoeq_cea2034/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::collections::HashMap;
4use std::error::Error;
5
6use ndarray::concatenate;
7use ndarray::s;
8use ndarray::{Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11/// A struct to hold frequency and SPL data.
12/// Re-exported from the main autoeq crate for compatibility
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Curve {
15    /// Frequency points in Hz
16    pub freq: Array1<f64>,
17    /// Sound Pressure Level in dB
18    pub spl: Array1<f64>,
19    /// Phase in degrees (optional)
20    #[serde(default, skip_serializing_if = "Option::is_none")]
21    pub phase: Option<Array1<f64>>,
22}
23
24/// A single directivity measurement at a specific angle
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct DirectivityCurve {
27    /// Angle in degrees (e.g., -60, -50, ..., 0, ..., 50, 60)
28    pub angle: f64,
29    /// Frequency points in Hz
30    pub freq: Array1<f64>,
31    /// Sound Pressure Level in dB
32    pub spl: Array1<f64>,
33}
34
35/// Complete directivity data for horizontal and vertical planes
36///
37/// Contains SPL measurements at multiple angles for both horizontal
38/// and vertical planes, as typically provided by spinorama.org.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct DirectivityData {
41    /// Horizontal plane measurements (typically -60° to +60°)
42    pub horizontal: Vec<DirectivityCurve>,
43    /// Vertical plane measurements (typically -60° to +60°)
44    pub vertical: Vec<DirectivityCurve>,
45}
46
47/// Convert SPL values to pressure values
48///
49/// # Arguments
50/// * `spl` - Array of SPL values
51///
52/// # Returns
53/// * Array of pressure values
54///
55/// # Formula
56/// pressure = 10^((spl-105)/20)
57fn spl2pressure(spl: &Array1<f64>) -> Array1<f64> {
58    // 10^((spl-105)/20)
59    spl.mapv(|v| 10f64.powf((v - 105.0) / 20.0))
60}
61
62/// Convert pressure values to SPL values
63///
64/// # Arguments
65/// * `p` - Array of pressure values
66///
67/// # Returns
68/// * Array of SPL values
69///
70/// # Formula
71/// spl = 20*log10(p) + 105
72fn pressure2spl(p: &Array1<f64>) -> Array1<f64> {
73    // 20*log10(p) + 105
74    p.mapv(|v| 20.0 * v.log10() + 105.0)
75}
76
77/// Convert SPL values to squared pressure values
78///
79/// # Arguments
80/// * `spl` - 2D array of SPL values
81///
82/// # Returns
83/// * 2D array of squared pressure values
84///
85/// # Details
86/// Computes pressure values from SPL and then squares them using vectorized operations
87fn spl2pressure2(spl: &Array2<f64>) -> Array2<f64> {
88    // Vectorized: 10^((spl-105)/20) then square
89    spl.mapv(|v| {
90        let p = 10f64.powf((v - 105.0) / 20.0);
91        p * p
92    })
93}
94
95/// Compute the CEA2034 spinorama from SPL data (internal implementation)
96///
97/// # Arguments
98/// * `spl` - 2D array of SPL measurements
99/// * `idx` - Indices for grouping measurements
100/// * `weights` - Weights for computing weighted averages
101///
102/// # Returns
103/// * 2D array representing the CEA2034 spinorama
104///
105/// # Details
106/// Computes various CEA2034 curves including On Axis, Listening Window,
107/// Early Reflections, Sound Power, and Predicted In-Room response.
108fn cea2034_array(spl: &Array2<f64>, idx: &[Vec<usize>], weights: &Array1<f64>) -> Array2<f64> {
109    let len_spl = spl.shape()[1];
110    let p2 = spl2pressure2(spl);
111    let idx_sp = idx.len() - 1;
112    let idx_lw = 0usize;
113    let idx_er = 1usize;
114    let idx_pir = idx_sp + 1;
115
116    let mut cea = Array2::<f64>::zeros((idx.len() + 1, len_spl));
117
118    for (i, idx_val) in idx.iter().enumerate().take(idx_sp) {
119        let curve = apply_rms(&p2, idx_val);
120        cea.row_mut(i).assign(&curve);
121    }
122
123    // ER: indices 2..=6 per original logic - vectorized
124    let er_rows = cea.slice(s![2..=6, ..]);
125    let er_pressures = er_rows.mapv(|v| {
126        let p = 10f64.powf((v - 105.0) / 20.0);
127        p * p
128    });
129    let er_p2_sum = er_pressures.sum_axis(Axis(0));
130    let er_p = er_p2_sum.mapv(|v| (v / 5.0).sqrt());
131    let er_spl = pressure2spl(&er_p);
132    cea.row_mut(idx_er).assign(&er_spl);
133
134    // SP weighted
135    let sp_curve = apply_weighted_rms(&p2, &idx[idx_sp], weights);
136    cea.row_mut(idx_sp).assign(&sp_curve);
137
138    // PIR - vectorized computation
139    let lw_p = spl2pressure(&cea.row(idx_lw).to_owned());
140    let er_p = spl2pressure(&cea.row(idx_er).to_owned());
141    let sp_p = spl2pressure(&cea.row(idx_sp).to_owned());
142
143    let lw2 = lw_p.mapv(|v| v * v);
144    let er2 = er_p.mapv(|v| v * v);
145    let sp2 = sp_p.mapv(|v| v * v);
146
147    let pir = (lw2.mapv(|v| 0.12 * v) + er2.mapv(|v| 0.44 * v) + sp2.mapv(|v| 0.44 * v))
148        .mapv(|v| v.sqrt());
149    let pir_spl = pressure2spl(&pir);
150    cea.row_mut(idx_pir).assign(&pir_spl);
151
152    cea
153}
154
155/// Apply RMS averaging to pressure squared values
156///
157/// # Arguments
158/// * `p2` - 2D array of squared pressure values
159/// * `idx` - Indices of rows to include in RMS calculation
160///
161/// # Returns
162/// * Array of SPL values after RMS averaging
163///
164/// # Formula
165/// rms = sqrt(sum(p2\[idx\]) / len) then converted to SPL
166fn apply_rms(p2: &Array2<f64>, idx: &[usize]) -> Array1<f64> {
167    // Vectorized sum using select and sum_axis
168    let selected_rows = p2.select(Axis(0), idx);
169    let sum_rows = selected_rows.sum_axis(Axis(0));
170    let len_idx = idx.len() as f64;
171    let r = sum_rows.mapv(|v| (v / len_idx).sqrt());
172    pressure2spl(&r)
173}
174
175/// Apply weighted RMS averaging to pressure squared values
176///
177/// # Arguments
178/// * `p2` - 2D array of squared pressure values
179/// * `idx` - Indices of rows to include in weighted RMS calculation
180/// * `weights` - Weights for each row
181///
182/// # Returns
183/// * Array of SPL values after weighted RMS averaging
184///
185/// # Formula
186/// weighted_rms = sqrt(sum(p2\[idx\] * weights\[idx\]) / sum(weights)) then converted to SPL
187fn apply_weighted_rms(p2: &Array2<f64>, idx: &[usize], weights: &Array1<f64>) -> Array1<f64> {
188    // Vectorized weighted sum
189    let selected_rows = p2.select(Axis(0), idx);
190    let selected_weights = weights.select(Axis(0), idx);
191    let sum_w = selected_weights.sum();
192
193    // Broadcast weights to match row dimensions and compute weighted sum
194    let weighted_rows = &selected_rows * &selected_weights.insert_axis(Axis(1));
195    let acc = weighted_rows.sum_axis(Axis(0));
196    let r = acc.mapv(|v| (v / sum_w).sqrt());
197    pressure2spl(&r)
198}
199
200/// Compute Mean Absolute Deviation (MAD) for a slice of SPL values
201///
202/// # Arguments
203/// * `spl` - Array of SPL values
204/// * `imin` - Start index (inclusive)
205/// * `imax` - End index (exclusive)
206///
207/// # Returns
208/// * Mean absolute deviation value
209///
210/// # Formula
211/// mad = mean(|x - mean(x)|)
212fn mad(spl: &Array1<f64>, imin: usize, imax: usize) -> f64 {
213    let slice = spl.slice(s![imin..imax]).to_owned();
214    let m = slice.mean().unwrap_or(0.0);
215    let diffs = slice.mapv(|v| (v - m).abs());
216    diffs.mean().unwrap_or(f64::NAN)
217}
218
219/// Compute the coefficient of determination (R-squared) between two arrays
220///
221/// # Arguments
222/// * `x` - First array of values
223/// * `y` - Second array of values
224///
225/// # Returns
226/// * R-squared value (Pearson correlation coefficient squared)
227fn r_squared(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
228    // Vectorized Pearson correlation squared
229    let n = x.len() as f64;
230    if n == 0.0 {
231        return f64::NAN;
232    }
233    let mx = x.mean().unwrap_or(0.0);
234    let my = y.mean().unwrap_or(0.0);
235
236    // Vectorized computation of deviations
237    let dx = x.mapv(|v| v - mx);
238    let dy = y.mapv(|v| v - my);
239
240    let num = (&dx * &dy).sum();
241    let sxx = (&dx * &dx).sum();
242    let syy = (&dy * &dy).sum();
243
244    if sxx == 0.0 || syy == 0.0 {
245        return f64::NAN;
246    }
247    let r = num / (sxx.sqrt() * syy.sqrt());
248    r * r
249}
250
251// ---------------- Pure Rust API below ----------------
252
253/// Compute the CEA2034 spinorama from SPL data
254///
255/// # Arguments
256/// * `spl` - 2D array of SPL measurements
257/// * `idx` - Indices for grouping measurements
258/// * `weights` - Weights for computing weighted averages
259///
260/// # Returns
261/// * 2D array representing the CEA2034 spinorama
262pub fn cea2034(spl: &Array2<f64>, idx: &[Vec<usize>], weights: &Array1<f64>) -> Array2<f64> {
263    cea2034_array(spl, idx, weights)
264}
265
/// Generate 1/`count`-octave bands anchored on a 1290 Hz reference centre.
///
/// The centres are spaced by `2^(1/count)`. `(count*10 + 1) / 2` centres are
/// generated below the reference and up to the same number above it. Each band
/// spans `centre / 2^(1/(2*count))` to `centre * 2^(1/(2*count))`. Bands above
/// the reference whose lower edge exceeds 20 kHz are dropped.
///
/// # Arguments
/// * `count` - Number of bands per octave
///
/// # Returns
/// * `(low, centre, high)` tuples in ascending frequency order
///
/// # Panics
/// * If `count` is less than 2
pub fn octave(count: usize) -> Vec<(f64, f64, f64)> {
    assert!(count >= 2, "count (N) must be >= 2");

    const REFERENCE_HZ: f64 = 1290.0;
    let step = 2.0_f64.powf(1.0 / count as f64);
    let half_band = 2.0_f64.powf(1.0 / (2.0 * count as f64));
    let steps: i32 = (count as i32 * 10 + 1) / 2;

    let below = (1..=steps).rev().map(|k| REFERENCE_HZ / step.powi(k));
    let above = (1..=steps)
        .map(|k| REFERENCE_HZ * step.powi(k))
        .filter(|&centre| centre / half_band <= 20000.0);

    below
        .chain(std::iter::once(REFERENCE_HZ))
        .chain(above)
        .map(|centre| (centre / half_band, centre, centre * half_band))
        .collect()
}
298
299/// Compute octave band intervals for a given frequency array
300///
301/// # Arguments
302/// * `count` - Number of bands per octave
303/// * `freq` - Array of frequencies
304///
305/// # Returns
306/// * Vector of tuples representing (start_index, end_index) for each band
307pub fn octave_intervals(count: usize, freq: &Array1<f64>) -> Vec<(usize, usize)> {
308    let bands = octave(count);
309
310    // Python logic: band_min_freq = max(100, min_freq)
311    let min_freq = freq[0];
312    let band_min_freq = 100.0_f64.max(min_freq);
313
314    let mut out = Vec::new();
315    for (low, center, high) in bands.into_iter() {
316        if center < band_min_freq || center > 12000.0 {
317            continue; // skip bands outside desired range
318        }
319        // Match Python: dfu.loc[(dfu.Freq >= band_min) & (dfu.Freq <= band_max)]
320        // Python uses inclusive bounds on both ends
321        let imin = freq.iter().position(|&f| f >= low).unwrap_or(freq.len());
322        let imax = freq.iter().position(|&f| f > high).unwrap_or(freq.len());
323        out.push((imin, imax));
324    }
325    out
326}
327
328/// Compute the Narrow Band Deviation (NBD) metric
329///
330/// # Arguments
331/// * `intervals` - Vector of (start_index, end_index) tuples for frequency bands
332/// * `spl` - SPL measurements
333///
334/// # Returns
335/// * NBD value as f64
336pub fn nbd(intervals: &[(usize, usize)], spl: &Array1<f64>) -> f64 {
337    let mut sum = 0.0;
338    let mut cnt = 0.0;
339    for &(imin, imax) in intervals.iter() {
340        let v = mad(spl, imin, imax);
341        if v.is_finite() {
342            sum += v;
343            cnt += 1.0;
344        }
345    }
346    if cnt == 0.0 { f64::NAN } else { sum / cnt }
347}
348
349/// Compute the Low Frequency Extension (LFX) metric
350///
351/// # Arguments
352/// * `freq` - Frequency array
353/// * `lw` - Listening window SPL measurements
354/// * `sp` - Sound power SPL measurements
355///
356/// # Returns
357/// * LFX value as f64 (log10 of the frequency)
358pub fn lfx(freq: &Array1<f64>, lw: &Array1<f64>, sp: &Array1<f64>) -> f64 {
359    // Match Python behavior:
360    // LW reference is mean(LW) over [300 Hz, 10 kHz], inclusive on both ends.
361    // Implemented by indices: [first f >= 300] .. [first f > 10000]
362    let lw_min = freq.iter().position(|&f| f >= 300.0).unwrap_or(freq.len());
363    let lw_max = freq.iter().position(|&f| f > 10000.0).unwrap_or(freq.len());
364    if lw_min >= lw_max {
365        return (300.0_f64).log10();
366    }
367    let lw_ref = lw.slice(s![lw_min..lw_max]).mean().unwrap_or(0.0) - 6.0;
368    // Collect indices where freq <= 300 Hz AND SP <= (LW_ref)
369    let mut indices: Vec<usize> = Vec::new();
370    for (i, (&f, &spv)) in freq.iter().zip(sp.iter()).enumerate() {
371        if f <= 300.0 && spv <= lw_ref {
372            indices.push(i);
373        }
374    }
375    if indices.is_empty() {
376        // No frequency bin meets the -6 dB criterion → fall back to lowest frequency
377        return freq[0].log10();
378    }
379
380    // Identify the first contiguous group of indices (as in Python implementation)
381    let mut last_idx = indices[0];
382    for &idx in indices.iter().skip(1) {
383        if idx == last_idx + 1 {
384            last_idx = idx;
385        } else {
386            break; // stop at the end of the first consecutive block
387        }
388    }
389
390    // Use the next frequency bin (pos + 1) to align with Python behavior
391    let next_idx = last_idx + 1;
392    if next_idx < freq.len() {
393        freq[next_idx].log10()
394    } else {
395        // Some measurements might end at/below 300 Hz, use default per Python
396        (300.0_f64).log10()
397    }
398}
399
400/// Compute the Smoothness Metric (SM)
401///
402/// # Arguments
403/// * `freq` - Frequency array
404/// * `spl` - SPL measurements
405///
406/// # Returns
407/// * SM value as f64 (R-squared value)
408pub fn sm(freq: &Array1<f64>, spl: &Array1<f64>) -> f64 {
409    let f_min = freq.iter().position(|&f| f > 100.0).unwrap_or(freq.len());
410    let f_max = freq
411        .iter()
412        .position(|&f| f >= 16000.0)
413        .unwrap_or(freq.len());
414    if f_min >= f_max {
415        return f64::NAN;
416    }
417    let x: Array1<f64> = freq.slice(s![f_min..f_max]).mapv(|v| v.log10());
418    let y: Array1<f64> = spl.slice(s![f_min..f_max]).to_owned();
419    r_squared(&x, &y)
420}
421
/// Component metrics of the CEA2034 preference model and the resulting score.
#[derive(Clone, Debug)]
pub struct ScoreMetrics {
    /// Narrow Band Deviation of the on-axis curve.
    pub nbd_on: f64,
    /// Narrow Band Deviation of the predicted in-room curve.
    pub nbd_pir: f64,
    /// Low Frequency Extension, as `log10` of a frequency in Hz.
    pub lfx: f64,
    /// Smoothness (an `r^2` value) of the predicted in-room curve.
    pub sm_pir: f64,
    /// Overall preference score combining the four metrics above.
    pub pref_score: f64,
}
436
437/// Compute all CEA2034 metrics and preference score
438///
439/// # Arguments
440/// * `freq` - Frequency array
441/// * `intervals` - Octave band intervals
442/// * `on` - On-axis SPL measurements
443/// * `lw` - Listening window SPL measurements
444/// * `sp` - Sound power SPL measurements
445/// * `pir` - Predicted in-room SPL measurements
446///
447/// # Returns
448/// * ScoreMetrics struct containing all computed metrics
449pub fn score(
450    freq: &Array1<f64>,
451    intervals: &[(usize, usize)],
452    on: &Array1<f64>,
453    lw: &Array1<f64>,
454    sp: &Array1<f64>,
455    pir: &Array1<f64>,
456) -> ScoreMetrics {
457    let nbd_on = nbd(intervals, on);
458    let nbd_pir = nbd(intervals, pir);
459    let sm_pir = sm(freq, pir);
460    let lfx_val = lfx(freq, lw, sp);
461    let pref = 12.69 - 2.49 * nbd_on - 2.99 * nbd_pir - 4.31 * lfx_val + 2.32 * sm_pir;
462    ScoreMetrics {
463        nbd_on,
464        nbd_pir,
465        lfx: lfx_val,
466        sm_pir,
467        pref_score: pref,
468    }
469}
470
471/// Compute CEA2034 metrics and preference score for a PEQ filter
472///
473/// # Arguments
474/// * `freq` - Frequency array
475/// * `idx` - Indices for grouping measurements
476/// * `intervals` - Octave band intervals
477/// * `weights` - Weights for computing weighted averages
478/// * `spl_h` - Horizontal SPL measurements
479/// * `spl_v` - Vertical SPL measurements
480/// * `peq` - PEQ filter response
481///
482/// # Returns
483/// * Tuple containing (spinorama data, ScoreMetrics)
484///
485/// # Panics
486/// * If peq length doesn't match SPL columns
487pub fn score_peq(
488    freq: &Array1<f64>,
489    idx: &[Vec<usize>],
490    intervals: &[(usize, usize)],
491    weights: &Array1<f64>,
492    spl_h: &Array2<f64>,
493    spl_v: &Array2<f64>,
494    peq: &Array1<f64>,
495) -> (Array2<f64>, ScoreMetrics) {
496    assert_eq!(
497        peq.len(),
498        spl_h.shape()[1],
499        "peq length must match SPL columns"
500    );
501    assert_eq!(
502        peq.len(),
503        spl_v.shape()[1],
504        "peq length must match SPL columns"
505    );
506
507    // add PEQ to each row using broadcasting
508    let peq_broadcast = peq.view().insert_axis(Axis(0));
509    let spl_h_peq = spl_h + &peq_broadcast;
510    let spl_v_peq = spl_v + &peq_broadcast;
511
512    let spl_full =
513        concatenate(Axis(0), &[spl_h_peq.view(), spl_v_peq.view()]).expect("concatenate failed");
514    let spin_nd = cea2034_array(&spl_full, idx, weights);
515
516    // Prepare rows for scoring
517    let on = spl_h_peq.row(17).to_owned();
518    let lw = spin_nd.row(0).to_owned();
519    let sp_row = spin_nd.row(spin_nd.shape()[0] - 2).to_owned();
520    let pir = spin_nd.row(spin_nd.shape()[0] - 1).to_owned();
521
522    let metrics = score(freq, intervals, &on, &lw, &sp_row, &pir);
523    (spin_nd, metrics)
524}
525
526/// Compute approximate CEA2034 metrics and preference score for a PEQ filter
527///
528/// This is a simplified version of score_peq that works directly with pre-computed
529/// LW, SP, and PIR curves rather than computing them from raw measurements.
530///
531/// # Arguments
532/// * `freq` - Frequency array
533/// * `intervals` - Octave band intervals
534/// * `lw` - Listening window SPL measurements
535/// * `sp` - Sound power SPL measurements
536/// * `pir` - Predicted in-room SPL measurements
537/// * `on` - On-axis SPL measurements
538/// * `peq` - PEQ filter response
539///
540/// # Returns
541/// * ScoreMetrics struct containing all computed metrics
542pub fn score_peq_approx(
543    freq: &Array1<f64>,
544    intervals: &[(usize, usize)],
545    lw: &Array1<f64>,
546    sp: &Array1<f64>,
547    pir: &Array1<f64>,
548    on: &Array1<f64>,
549    peq: &Array1<f64>,
550) -> ScoreMetrics {
551    let on2 = on + peq;
552    let lw2 = lw + peq;
553    let sp2 = sp + peq;
554    let pir2 = pir + peq;
555    score(freq, intervals, &on2, &lw2, &sp2, &pir2)
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn octave_count_2_includes_reference_center() {
        let bands = octave(2);
        // find the center equal to 1290
        assert!(bands.iter().any(|&(_l, c, _h)| (c - 1290.0).abs() < 1e-9));
    }

    #[test]
    fn octave_count_2_band_layout() {
        let bands = octave(2);
        // 10 centres below 1290 Hz, the reference itself, and the 8 above it whose
        // lower edge stays at or below 20 kHz.
        assert_eq!(bands.len(), 19);
        for &(low, center, high) in &bands {
            assert!(low < center && center < high);
            // Each half-octave band spans a factor of sqrt(2).
            assert!((high / low - 2.0_f64.sqrt()).abs() < 1e-12);
        }
        for pair in bands.windows(2) {
            assert!(pair[0].1 < pair[1].1);
        }
    }

    #[test]
    fn nbd_simple_mean_of_mads() {
        let spl = Array1::from(vec![0.0, 1.0, 2.0, 1.0, 0.0]);
        // two intervals: [0..3) and [2..5); both windows have MAD 2/3
        let intervals = vec![(0, 3), (2, 5)];
        let v = nbd(&intervals, &spl);
        assert!((v - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn nbd_all_empty_intervals_is_nan() {
        let spl = Array1::from(vec![0.0, 1.0, 2.0]);
        assert!(nbd(&[(1, 1)], &spl).is_nan());
        assert!(nbd(&[], &spl).is_nan());
    }

    #[test]
    fn sm_is_one_for_log_linear_response() {
        // Bins strictly inside (100 Hz, 16 kHz) lie exactly on a line in log-frequency.
        let freq = Array1::from(vec![50.0, 200.0, 400.0, 800.0, 1600.0, 20000.0]);
        let spl = freq.mapv(|f: f64| 3.0 * f.log10() + 70.0);
        assert!((sm(&freq, &spl) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn score_peq_approx_matches_score_when_peq_zero() {
        // Simple synthetic data
        let freq = Array1::from(vec![100.0, 1000.0, 10000.0]);
        let intervals = vec![(0, 3)];
        let on = Array1::from(vec![80.0, 85.0, 82.0]);
        let lw = Array1::from(vec![81.0, 84.0, 83.0]);
        let sp = Array1::from(vec![79.0, 83.0, 81.0]);
        let pir = Array1::from(vec![80.5, 84.0, 82.0]);
        let zero = Array1::zeros(freq.len());

        let m1 = score(&freq, &intervals, &on, &lw, &sp, &pir);
        let m2 = score_peq_approx(&freq, &intervals, &lw, &sp, &pir, &on, &zero);

        assert!((m1.nbd_on - m2.nbd_on).abs() < 1e-12);
        assert!((m1.nbd_pir - m2.nbd_pir).abs() < 1e-12);
        assert!((m1.lfx - m2.lfx).abs() < 1e-12);
        assert!((m1.sm_pir - m2.sm_pir).abs() < 1e-12);
        assert!((m1.pref_score - m2.pref_score).abs() < 1e-12);
    }

    #[test]
    fn lfx_next_bin_after_first_block() {
        // Frequencies spanning below and above 300 and up to 12k
        let freq = Array1::from(vec![
            50.0, 100.0, 200.0, 300.0, 500.0, 1000.0, 5000.0, 10000.0, 12000.0,
        ]);
        // LW constant 80 dB; LW_ref = 80 - 6 = 74
        let lw = Array1::from(vec![80.0; 9]);
        // SP <= LW_ref for first two bins only (50, 100). First block ends at index 1.
        // Next bin is index 2 -> 200 Hz
        let sp = Array1::from(vec![70.0, 73.0, 75.0, 76.0, 80.0, 80.0, 80.0, 80.0, 80.0]);
        let val = lfx(&freq, &lw, &sp);
        assert!((val - 200.0_f64.log10()).abs() < 1e-12);
    }

    #[test]
    fn lfx_no_indices_falls_back_to_first_freq() {
        let freq = Array1::from(vec![
            50.0, 100.0, 200.0, 300.0, 500.0, 1000.0, 5000.0, 10000.0, 12000.0,
        ]);
        let lw = Array1::from(vec![80.0; 9]);
        // All SP > LW_ref (74) for <= 300
        let sp = Array1::from(vec![75.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0]);
        let val = lfx(&freq, &lw, &sp);
        assert!((val - 50.0_f64.log10()).abs() < 1e-12);
    }

    #[test]
    fn lfx_next_index_oob_defaults_to_300() {
        let freq = Array1::from(vec![100.0, 200.0, 300.0]);
        let lw = Array1::from(vec![80.0, 80.0, 80.0]);
        // All SP <= LW_ref (74) for <= 300 => indices [0,1,2]; next index OOB
        let sp = Array1::from(vec![70.0, 70.0, 70.0]);
        let val = lfx(&freq, &lw, &sp);
        assert!((val - 300.0_f64.log10()).abs() < 1e-12);
    }
}
636
637/// Compute Predicted In-Room (PIR) response from LW, ER, and SP measurements
638///
639/// # Arguments
640/// * `lw` - Listening window SPL measurements
641/// * `er` - Early reflections SPL measurements
642/// * `sp` - Sound power SPL measurements
643///
644/// # Returns
645/// * PIR SPL measurements
646pub fn compute_pir_from_lw_er_sp(
647    lw: &Array1<f64>,
648    er: &Array1<f64>,
649    sp: &Array1<f64>,
650) -> Array1<f64> {
651    let lw_p = spl2pressure(lw);
652    let er_p = spl2pressure(er);
653    let sp_p = spl2pressure(sp);
654    let lw2 = lw_p.mapv(|v| v * v);
655    let er2 = er_p.mapv(|v| v * v);
656    let sp2 = sp_p.mapv(|v| v * v);
657    let pir_p2 = lw2.mapv(|v| 0.12 * v) + &er2.mapv(|v| 0.44 * v) + &sp2.mapv(|v| 0.44 * v);
658    let pir_p = pir_p2.mapv(|v| v.sqrt());
659    pressure2spl(&pir_p)
660}
661
662/// Compute CEA2034 metrics for speaker performance evaluation
663///
664/// # Arguments
665/// * `freq` - Frequency grid for computation
666/// * `cea_plot_data` - Cached plot data (may be updated if fetched)
667/// * `peq` - Optional PEQ response to apply to metrics
668///
669/// # Returns
670/// * Result containing ScoreMetrics or an error
671///
672/// # Details
673/// Computes CEA2034 metrics including preference score, Narrow Band Deviation (NBD),
674/// Low Frequency Extension (LFX), and Smoothness Metric (SM) for various curves.
675pub async fn compute_cea2034_metrics(
676    freq: &Array1<f64>,
677    cea2034_data: &HashMap<String, Curve>,
678    peq: Option<&Array1<f64>>,
679) -> Result<ScoreMetrics, Box<dyn Error>> {
680    let on = &cea2034_data.get("On Axis").unwrap().spl;
681    let lw = &cea2034_data.get("Listening Window").unwrap().spl;
682    let sp = &cea2034_data.get("Sound Power").unwrap().spl;
683    let pir = &cea2034_data.get("Estimated In-Room Response").unwrap().spl;
684
685    // 1/2 octave intervals for band metrics
686    let intervals = octave_intervals(2, freq);
687
688    // Use provided PEQ or assume zero PEQ
689    let peq_arr = peq.cloned().unwrap_or_else(|| Array1::zeros(freq.len()));
690
691    Ok(score_peq_approx(
692        freq, &intervals, lw, sp, pir, on, &peq_arr,
693    ))
694}
695
#[cfg(test)]
mod pir_helpers_tests {
    use super::{
        compute_cea2034_metrics, compute_pir_from_lw_er_sp, octave_intervals, pressure2spl, score,
        score_peq_approx, spl2pressure, ScoreMetrics,
    };
    use crate::Curve;
    use ndarray::Array1;
    use std::collections::HashMap;

    /// Two-point dataset shared by the async metrics tests.
    struct Dataset {
        freq: Array1<f64>,
        on: Array1<f64>,
        lw: Array1<f64>,
        sp: Array1<f64>,
        pir: Array1<f64>,
        curves: HashMap<String, Curve>,
    }

    /// Build the dataset, with PIR precomputed from LW/ER/SP.
    fn dataset() -> Dataset {
        let freq = Array1::from(vec![100.0, 1000.0]);
        let on = Array1::from(vec![80.0_f64, 85.0_f64]);
        let lw = Array1::from(vec![81.0_f64, 84.0_f64]);
        let er = Array1::from(vec![79.0_f64, 83.0_f64]);
        let sp = Array1::from(vec![78.0_f64, 82.0_f64]);
        let pir = compute_pir_from_lw_er_sp(&lw, &er, &sp);

        // CEA2034 map in the shape `compute_cea2034_metrics` expects.
        let curves = [
            ("On Axis", &on),
            ("Listening Window", &lw),
            ("Sound Power", &sp),
            ("Estimated In-Room Response", &pir),
        ]
        .iter()
        .map(|&(name, spl)| {
            (
                name.to_string(),
                Curve {
                    freq: freq.clone(),
                    spl: spl.clone(),
                    phase: None,
                },
            )
        })
        .collect();

        Dataset {
            freq,
            on,
            lw,
            sp,
            pir,
            curves,
        }
    }

    fn assert_close(got: f64, expected: f64) {
        assert!(
            (got - expected).abs() < 1e-12,
            "got {got}, expected {expected}"
        );
    }

    /// Like `assert_close`, but also accepts both values being NaN.
    fn assert_close_or_both_nan(got: f64, expected: f64) {
        if !(got.is_nan() && expected.is_nan()) {
            assert_close(got, expected);
        }
    }

    fn assert_metrics_match(got: &ScoreMetrics, expected: &ScoreMetrics) {
        assert_close(got.nbd_on, expected.nbd_on);
        assert_close(got.nbd_pir, expected.nbd_pir);
        assert_close(got.lfx, expected.lfx);
        assert_close_or_both_nan(got.sm_pir, expected.sm_pir);
        assert_close_or_both_nan(got.pref_score, expected.pref_score);
    }

    #[test]
    fn spl_pressure_roundtrip_is_identity() {
        let spl = Array1::from(vec![60.0, 80.0, 100.0]);
        let p = spl2pressure(&spl);
        let spl2 = pressure2spl(&p);
        for (a, b) in spl.iter().zip(spl2.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    #[test]
    fn pir_equals_input_when_all_equal() {
        let lw = Array1::from(vec![80.0, 80.0, 80.0]);
        let er = Array1::from(vec![80.0, 80.0, 80.0]);
        let sp = Array1::from(vec![80.0, 80.0, 80.0]);
        let pir = compute_pir_from_lw_er_sp(&lw, &er, &sp);
        for v in pir.iter() {
            assert!((*v - 80.0).abs() < 1e-12);
        }
    }

    #[test]
    fn pir_reflects_er_sp_weighting() {
        // ER and SP have higher weights than LW (0.44 each vs 0.12)
        let lw = Array1::from(vec![70.0, 70.0, 70.0]);
        let er = Array1::from(vec![80.0, 80.0, 80.0]);
        let sp = Array1::from(vec![80.0, 80.0, 80.0]);
        let pir = compute_pir_from_lw_er_sp(&lw, &er, &sp);
        for v in pir.iter() {
            assert!(*v > 75.0 && *v < 81.0);
        }
    }

    #[tokio::test]
    async fn metrics_with_precomputed_curves() {
        let d = dataset();

        let got = compute_cea2034_metrics(&d.freq, &d.curves, None)
            .await
            .expect("metrics");

        let intervals = octave_intervals(2, &d.freq);
        let expected = score(&d.freq, &intervals, &d.on, &d.lw, &d.sp, &d.pir);
        assert_metrics_match(&got, &expected);
    }

    #[tokio::test]
    async fn metrics_with_precomputed_curves_and_peq_matches_approx() {
        let d = dataset();
        let peq = Array1::from(vec![1.0_f64, -1.0_f64]);

        let got = compute_cea2034_metrics(&d.freq, &d.curves, Some(&peq))
            .await
            .expect("metrics with peq");

        let intervals = octave_intervals(2, &d.freq);
        let expected =
            score_peq_approx(&d.freq, &intervals, &d.lw, &d.sp, &d.pir, &d.on, &peq);
        assert_metrics_match(&got, &expected);
    }
}