Skip to main content

audiofp/classical/
wang.rs

1//! Wang-style landmark fingerprinter.
2//!
3//! The algorithm (Avery Wang, "An Industrial-Strength Audio Search
4//! Algorithm", 2003 — the "Shazam paper"):
5//!
6//! 1. Resample the input to 8 kHz mono *(caller's responsibility)*.
7//! 2. Take a Hann-windowed STFT with `n_fft = 1024`, `hop = 128` →
8//!    62.5 frames/s, 513 frequency bins.
9//! 3. Convert the magnitude spectrogram to dB log-magnitude.
10//! 4. Pick spectral peaks in a 31×31 neighbourhood, capped at 30/s.
11//! 5. For each anchor peak, take the strongest `fan_out` peaks within
12//!    `Δt ∈ [1, target_zone_t]` and `|Δf| ≤ target_zone_f`; pack each
13//!    `(anchor, target)` pair into a 32-bit hash.
14//!
15//! Hash layout (high to low bit):
16//! ```text
17//! [31..23]  f_a_q  (9 bits, anchor frequency, quantised to 512 buckets)
18//! [22..14]  f_b_q  (9 bits, target frequency, same quantisation)
19//! [13.. 0]  Δt     (14 bits, frames between anchor and target, clamped 1..=16383)
20//! ```
21
22use alloc::vec::Vec;
23
24use libm::log10f;
25
26use crate::dsp::peaks::{Peak, PeakPicker, PeakPickerConfig};
27use crate::dsp::stft::{ShortTimeFFT, StftConfig};
28use crate::dsp::windows::WindowKind;
29use crate::{AfpError, AudioBuffer, Fingerprinter, Result, StreamingFingerprinter, TimestampMs};
30
31/// One anchor-target landmark pair packed into a 32-bit hash.
32#[repr(C)]
33#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, bytemuck::Pod, bytemuck::Zeroable)]
34pub struct WangHash {
35    /// 32-bit hash: `f_a_q (9) | f_b_q (9) | Δt (14)`, MSB first.
36    pub hash: u32,
37    /// STFT frame index of the anchor peak.
38    pub t_anchor: u32,
39}
40
41/// All hashes produced by [`Wang`] over an audio buffer.
42#[derive(Clone, Debug)]
43pub struct WangFingerprint {
44    /// Hashes sorted by `(t_anchor, hash)`.
45    pub hashes: Vec<WangHash>,
46    /// Frame rate of the underlying STFT — always 62.5 for `wang-v1`
47    /// (`8000 / 128`).
48    pub frames_per_sec: f32,
49}
50
/// Knobs controlling peak selection and anchor/target pairing in [`Wang`].
#[derive(Debug, Clone)]
pub struct WangConfig {
    /// `F`: how many target peaks are paired with each anchor.
    /// Default 10; embedded builds typically lower this to 5.
    pub fan_out: u16,
    /// Largest allowed `Δt`, in frames, from an anchor to a target.
    /// Default 63.
    pub target_zone_t: u16,
    /// Largest allowed `|Δf|`, in FFT bins, from an anchor to a target.
    /// Default 64.
    pub target_zone_f: u16,
    /// Upper bound on the number of peaks kept per second. Default 30.
    pub peaks_per_sec: u16,
    /// Peaks at or below this magnitude (dB) are discarded. Default −50.
    pub min_anchor_mag_db: f32,
}
66
67impl Default for WangConfig {
68    fn default() -> Self {
69        Self {
70            fan_out: 10,
71            target_zone_t: 63,
72            target_zone_f: 64,
73            peaks_per_sec: 30,
74            min_anchor_mag_db: -50.0,
75        }
76    }
77}
78
/// Sample rate (Hz) the fingerprinter requires of its input.
const WANG_SR: u32 = 8_000;
/// STFT window length in samples.
const WANG_N_FFT: usize = 1024;
/// STFT hop in samples.
const WANG_HOP: usize = 128;
/// STFT frames per second of audio: `8000 / 128 = 62.5`.
const WANG_FRAMES_PER_SEC: f32 = WANG_SR as f32 / WANG_HOP as f32;

/// Quantisation buckets for the 9-bit frequency field.
const WANG_FREQ_BUCKETS: u32 = 512;
/// Half-width of the peak-picking neighbourhood on both axes; the full
/// window is `2 * 15 + 1 = 31` cells wide.
const WANG_PEAK_NEIGHBOURHOOD: usize = 15;

/// Linear magnitude floor applied before taking the logarithm.
const WANG_LOG_FLOOR: f32 = 1e-6;
/// The same floor expressed as power. The dB conversion works on
/// `log10(power)` rather than `log10(magnitude)`, so the STFT never has
/// to take a per-bin `sqrt`. Equivalent to `WANG_LOG_FLOOR.powi(2)`.
const WANG_LOG_FLOOR_POWER: f32 = WANG_LOG_FLOOR * WANG_LOG_FLOOR;
91
92/// Wang offline fingerprinter.
93///
94/// # Example
95///
96/// ```
97/// use audiofp::{AudioBuffer, Fingerprinter, SampleRate};
98/// use audiofp::classical::Wang;
99///
100/// let mut fp = Wang::default();
101/// // 3 seconds of silence — produces an empty fingerprint, not an error.
102/// let samples = vec![0.0_f32; 8_000 * 3];
103/// let buf = AudioBuffer { samples: &samples, rate: SampleRate::HZ_8000 };
104/// let fpr = fp.extract(buf).unwrap();
105/// assert_eq!(fpr.frames_per_sec, 62.5);
106/// assert!(fpr.hashes.is_empty());
107/// ```
108pub struct Wang {
109    cfg: WangConfig,
110    stft: ShortTimeFFT,
111    /// Cached peak picker — pools its scratch buffers across calls so
112    /// repeated `extract` invocations don't re-allocate.
113    picker: PeakPicker,
114    /// Pooled log-magnitude buffer reused between calls.
115    log_spec: Vec<f32>,
116}
117
118impl Default for Wang {
119    fn default() -> Self {
120        Self::new(WangConfig::default())
121    }
122}
123
124impl Wang {
125    /// Build a Wang extractor with the given config.
126    #[must_use]
127    pub fn new(cfg: WangConfig) -> Self {
128        let stft = ShortTimeFFT::new(StftConfig {
129            n_fft: WANG_N_FFT,
130            hop: WANG_HOP,
131            window: WindowKind::Hann,
132            // No reflect-padding: hashes are most stable when the first
133            // frame starts at sample 0 of the input buffer.
134            center: false,
135        });
136        let picker = PeakPicker::new(PeakPickerConfig {
137            neighborhood_t: WANG_PEAK_NEIGHBOURHOOD,
138            neighborhood_f: WANG_PEAK_NEIGHBOURHOOD,
139            min_magnitude: cfg.min_anchor_mag_db,
140            target_per_sec: cfg.peaks_per_sec as usize,
141        });
142        Self {
143            cfg,
144            stft,
145            picker,
146            log_spec: Vec::new(),
147        }
148    }
149}
150
151impl Fingerprinter for Wang {
152    type Output = WangFingerprint;
153    type Config = WangConfig;
154
155    fn name(&self) -> &'static str {
156        "wang-v1"
157    }
158
159    fn config(&self) -> &Self::Config {
160        &self.cfg
161    }
162
163    fn required_sample_rate(&self) -> u32 {
164        WANG_SR
165    }
166
167    fn min_samples(&self) -> usize {
168        WANG_SR as usize * 2
169    }
170
171    fn extract(&mut self, audio: AudioBuffer<'_>) -> Result<Self::Output> {
172        if audio.rate.hz() != WANG_SR {
173            return Err(AfpError::UnsupportedSampleRate(audio.rate.hz()));
174        }
175        if audio.samples.len() < self.min_samples() {
176            return Err(AfpError::AudioTooShort {
177                needed: self.min_samples(),
178                got: audio.samples.len(),
179            });
180        }
181
182        // Compute power (|X|²) directly from the FFT — skips a per-bin
183        // sqrt that the dB conversion would immediately undo.
184        // 20 · log10(sqrt(p)) ≡ 10 · log10(p).
185        let (power_flat, n_frames, n_bins) = self.stft.power_flat(audio.samples);
186        if n_frames == 0 {
187            return Ok(WangFingerprint {
188                hashes: Vec::new(),
189                frames_per_sec: WANG_FRAMES_PER_SEC,
190            });
191        }
192
193        // Convert power → dB log-magnitude in-place into the pooled buffer.
194        self.log_spec.clear();
195        self.log_spec.resize(power_flat.len(), 0.0);
196        for (i, &p) in power_flat.iter().enumerate() {
197            self.log_spec[i] = 10.0 * log10f(p.max(WANG_LOG_FLOOR_POWER));
198        }
199
200        let peaks = self
201            .picker
202            .pick(&self.log_spec, n_frames, n_bins, WANG_FRAMES_PER_SEC);
203
204        let mut hashes = build_hashes(&peaks, &self.cfg);
205        // Stable, deterministic ordering for round-trip and golden tests.
206        hashes.sort_unstable_by_key(|h| (h.t_anchor, h.hash));
207
208        Ok(WangFingerprint {
209            hashes,
210            frames_per_sec: WANG_FRAMES_PER_SEC,
211        })
212    }
213}
214
215/// Wrapper that orders `Peak`s such that the **smallest** magnitude (with
216/// the largest position as tiebreak) compares **greatest**. Used as the
217/// element of a max-heap to maintain the top-K largest candidates with
218/// `O(N log K)` work instead of an `O(N log N)` full sort.
219#[derive(Copy, Clone)]
220struct MinByMag<'a>(&'a Peak);
221
222impl PartialEq for MinByMag<'_> {
223    fn eq(&self, o: &Self) -> bool {
224        self.0.mag == o.0.mag && self.0.t_frame == o.0.t_frame && self.0.f_bin == o.0.f_bin
225    }
226}
227impl Eq for MinByMag<'_> {}
228impl PartialOrd for MinByMag<'_> {
229    fn partial_cmp(&self, o: &Self) -> Option<core::cmp::Ordering> {
230        Some(self.cmp(o))
231    }
232}
233impl Ord for MinByMag<'_> {
234    fn cmp(&self, o: &Self) -> core::cmp::Ordering {
235        // Reverse mag ordering (smallest first). Reverse position ordering
236        // (largest position first) so the final sort's deterministic
237        // (mag desc, pos asc) ordering still wins for kept elements.
238        o.0.mag
239            .partial_cmp(&self.0.mag)
240            .unwrap_or(core::cmp::Ordering::Equal)
241            .then_with(|| (o.0.t_frame, o.0.f_bin).cmp(&(self.0.t_frame, self.0.f_bin)))
242    }
243}
244
245/// Walk `peaks` (sorted by `(t_frame, f_bin)`) and emit landmark hashes.
246fn build_hashes(peaks: &[Peak], cfg: &WangConfig) -> Vec<WangHash> {
247    let mut hashes = Vec::with_capacity(peaks.len() * cfg.fan_out as usize);
248    let target_zone_t = cfg.target_zone_t as i32;
249    let target_zone_f = cfg.target_zone_f as i32;
250    let fan_out = cfg.fan_out as usize;
251
252    let mut heap: alloc::collections::BinaryHeap<MinByMag> =
253        alloc::collections::BinaryHeap::with_capacity(fan_out + 1);
254    let mut targets: Vec<&Peak> = Vec::with_capacity(fan_out);
255
256    for (i, anchor) in peaks.iter().enumerate() {
257        heap.clear();
258        for target in &peaks[i + 1..] {
259            let dt = target.t_frame as i32 - anchor.t_frame as i32;
260            if dt < 1 {
261                continue;
262            }
263            if dt > target_zone_t {
264                // Peaks are sorted by t_frame, so once we exceed the zone
265                // for this anchor, no later peak can fit either.
266                break;
267            }
268            let df = target.f_bin as i32 - anchor.f_bin as i32;
269            if df.abs() > target_zone_f {
270                continue;
271            }
272            heap.push(MinByMag(target));
273            if heap.len() > fan_out {
274                // Drop the current smallest — the heap top, by our reversed Ord.
275                heap.pop();
276            }
277        }
278
279        // Drain the heap and re-sort the kept K for deterministic emission.
280        targets.clear();
281        targets.extend(heap.drain().map(|w| w.0));
282        targets.sort_unstable_by(|a, b| {
283            b.mag
284                .partial_cmp(&a.mag)
285                .unwrap_or(core::cmp::Ordering::Equal)
286                .then_with(|| (a.t_frame, a.f_bin).cmp(&(b.t_frame, b.f_bin)))
287        });
288
289        for target in &targets {
290            let f_a_q = quantise_freq(anchor.f_bin);
291            let f_b_q = quantise_freq(target.f_bin);
292            let dt = ((target.t_frame - anchor.t_frame) & 0x3FFF).max(1);
293            let hash = ((f_a_q & 0x1FF) << 23) | ((f_b_q & 0x1FF) << 14) | (dt & 0x3FFF);
294            hashes.push(WangHash {
295                hash,
296                t_anchor: anchor.t_frame,
297            });
298        }
299    }
300    hashes
301}
302
303/// FFT gives 513 bins; pack into 9 bits (512 buckets) per spec.
304#[inline]
305fn quantise_freq(bin: u16) -> u32 {
306    (bin as u32 * WANG_FREQ_BUCKETS) / 513
307}
308
309/// Owned wrapper around `Peak` whose `Ord` reverses magnitude (and
310/// position tiebreak), so a `BinaryHeap<MinByMagOwned>` of size `K`
311/// behaves as a min-heap that retains the top-K largest peaks.
312#[derive(Copy, Clone)]
313struct MinByMagOwned(Peak);
314
315impl PartialEq for MinByMagOwned {
316    fn eq(&self, o: &Self) -> bool {
317        self.0.mag == o.0.mag && self.0.t_frame == o.0.t_frame && self.0.f_bin == o.0.f_bin
318    }
319}
320impl Eq for MinByMagOwned {}
321impl PartialOrd for MinByMagOwned {
322    fn partial_cmp(&self, o: &Self) -> Option<core::cmp::Ordering> {
323        Some(self.cmp(o))
324    }
325}
326impl Ord for MinByMagOwned {
327    fn cmp(&self, o: &Self) -> core::cmp::Ordering {
328        o.0.mag
329            .partial_cmp(&self.0.mag)
330            .unwrap_or(core::cmp::Ordering::Equal)
331            .then_with(|| (o.0.t_frame, o.0.f_bin).cmp(&(self.0.t_frame, self.0.f_bin)))
332    }
333}
334
335/// Anchor pending finalisation, with its top-K target heap.
336struct PendingAnchor {
337    peak: Peak,
338    targets: alloc::collections::BinaryHeap<MinByMagOwned>,
339}
340
341/// Streaming Wang fingerprinter — fully incremental.
342///
343/// Maintains a rolling spectrogram window (`2·neighborhood_t + 1` rows),
344/// detects peaks frame-by-frame as they ripen, accumulates per-second
345/// candidate buckets, finalises buckets via the per-second adaptive
346/// threshold, and grows per-anchor target heaps until each anchor's
347/// target zone is fully observed. Per-push CPU cost is proportional to
348/// the number of new frames (not the total stream length).
349///
350/// The output hash multiset matches what [`Wang::extract`] would produce
351/// for the same total input — verified by the `streaming_offline_*`
352/// tests, including the 1-sample-per-push pathological case.
353///
354/// # Example
355///
356/// ```
357/// use audiofp::{SampleRate, StreamingFingerprinter};
358/// use audiofp::classical::StreamingWang;
359///
360/// let mut s = StreamingWang::default();
361/// // Feed 4 seconds of silence in two chunks; nothing should be emitted.
362/// let zeros = vec![0.0_f32; 8_000 * 2];
363/// assert!(s.push(&zeros).is_empty());
364/// assert!(s.push(&zeros).is_empty());
365/// assert!(s.flush().is_empty());
366/// ```
367pub struct StreamingWang {
368    cfg: WangConfig,
369
370    // Front-end.
371    stft: ShortTimeFFT,
372    sample_carry: alloc::vec::Vec<f32>,
373
374    // Rolling log-power spectrogram window (contiguous, row-major).
375    // Capacity = `2*nbht + 1` rows.
376    spec: alloc::vec::Vec<f32>,
377    spec_n_rows: usize,
378    spec_n_bins: usize,
379    spec_first_frame: u32,
380
381    // Frame counter and detection cursor.
382    n_frames_total: u32,
383    last_pd_frame: i32,
384
385    // Pooled peak-detection scratch.
386    pd_max: alloc::vec::Vec<f32>,
387    pd_temp: alloc::vec::Vec<f32>,
388    pd_col_in: alloc::vec::Vec<f32>,
389    pd_col_out: alloc::vec::Vec<f32>,
390
391    // Reusable scratch row for STFT output.
392    frame_scratch: alloc::vec::Vec<f32>,
393
394    // Per-second adaptive thresholding.
395    bucket_pending: alloc::collections::BTreeMap<u32, alloc::vec::Vec<Peak>>,
396    last_finalized_bucket: i32,
397
398    // Anchors awaiting finalisation, in t-order.
399    pending_anchors: alloc::collections::VecDeque<PendingAnchor>,
400}
401
402impl Default for StreamingWang {
403    fn default() -> Self {
404        Self::new(WangConfig::default())
405    }
406}
407
408impl StreamingWang {
409    /// Build a streaming Wang extractor with the given config.
410    #[must_use]
411    pub fn new(cfg: WangConfig) -> Self {
412        let stft = ShortTimeFFT::new(StftConfig {
413            n_fft: WANG_N_FFT,
414            hop: WANG_HOP,
415            window: WindowKind::Hann,
416            center: false,
417        });
418        let n_bins = stft.n_bins();
419        let window_capacity = 2 * WANG_PEAK_NEIGHBOURHOOD + 1;
420        Self {
421            cfg,
422            stft,
423            sample_carry: alloc::vec::Vec::new(),
424            spec: alloc::vec![0.0_f32; window_capacity * n_bins],
425            spec_n_rows: 0,
426            spec_n_bins: n_bins,
427            spec_first_frame: 0,
428            n_frames_total: 0,
429            last_pd_frame: -1,
430            pd_max: alloc::vec::Vec::new(),
431            pd_temp: alloc::vec::Vec::new(),
432            pd_col_in: alloc::vec::Vec::new(),
433            pd_col_out: alloc::vec::Vec::new(),
434            frame_scratch: alloc::vec![0.0_f32; n_bins],
435            bucket_pending: alloc::collections::BTreeMap::new(),
436            last_finalized_bucket: -1,
437            pending_anchors: alloc::collections::VecDeque::new(),
438        }
439    }
440
441    /// Borrow the configuration this stream was built with.
442    #[must_use]
443    pub fn config(&self) -> &WangConfig {
444        &self.cfg
445    }
446
447    /// Frames an anchor must have *after* it before all of its targets
448    /// are observed. Used only for [`latency_ms`] — emission timing in
449    /// the incremental implementation is driven by anchor finalisation.
450    ///
451    /// [`latency_ms`]: StreamingWang::latency_ms
452    fn lookahead_frames(&self) -> u32 {
453        self.cfg.target_zone_t as u32
454            + WANG_PEAK_NEIGHBOURHOOD as u32
455            + WANG_FRAMES_PER_SEC.ceil() as u32
456    }
457
458    /// Append the current contents of `self.frame_scratch` to the
459    /// rolling spec buffer, dropping the oldest row if at capacity.
460    /// Avoids the per-frame `Vec::clone` the borrow checker would
461    /// otherwise force on a `(&mut self, &[f32])` signature.
462    fn append_frame_scratch_row(&mut self) {
463        debug_assert_eq!(self.frame_scratch.len(), self.spec_n_bins);
464        let cap = 2 * WANG_PEAK_NEIGHBOURHOOD + 1;
465        if self.spec_n_rows == cap {
466            self.spec.copy_within(self.spec_n_bins.., 0);
467            self.spec_first_frame += 1;
468            self.spec_n_rows -= 1;
469        }
470        let dst_start = self.spec_n_rows * self.spec_n_bins;
471        let n_bins = self.spec_n_bins;
472        // Disjoint borrow: `self.spec` (mut) and `self.frame_scratch`
473        // (shared) are different fields of `self`, so this is sound.
474        self.spec[dst_start..dst_start + n_bins].copy_from_slice(&self.frame_scratch);
475        self.spec_n_rows += 1;
476    }
477
478    /// Run rolling-max on the current spec buffer and extract peaks at
479    /// rows `[from_row_inclusive, to_row_inclusive]` (in spec-buffer-relative
480    /// indices). Push survivors into [`bucket_pending`].
481    fn detect_rows(&mut self, from_row: usize, to_row: usize) {
482        if self.spec_n_rows == 0 || from_row > to_row {
483            return;
484        }
485        let n_rows = self.spec_n_rows;
486        let n_bins = self.spec_n_bins;
487        let used = n_rows * n_bins;
488
489        self.pd_max.clear();
490        self.pd_max.resize(used, 0.0);
491        self.pd_temp.clear();
492        self.pd_temp.resize(used, 0.0);
493        self.pd_col_in.clear();
494        self.pd_col_in.resize(n_rows, 0.0);
495        self.pd_col_out.clear();
496        self.pd_col_out.resize(n_rows, 0.0);
497
498        crate::dsp::peaks::rolling_max_2d_pooled(
499            &self.spec[..used],
500            n_rows,
501            n_bins,
502            WANG_PEAK_NEIGHBOURHOOD,
503            WANG_PEAK_NEIGHBOURHOOD,
504            &mut self.pd_max,
505            &mut self.pd_temp,
506            &mut self.pd_col_in,
507            &mut self.pd_col_out,
508        );
509
510        for row in from_row..=to_row {
511            if row >= n_rows {
512                break;
513            }
514            let abs_f = self.spec_first_frame + row as u32;
515            let bucket = (abs_f as f32 / WANG_FRAMES_PER_SEC) as u32;
516            for bin in 0..n_bins {
517                let idx = row * n_bins + bin;
518                let v = self.spec[idx];
519                if v > self.cfg.min_anchor_mag_db && v >= self.pd_max[idx] {
520                    let peak = Peak {
521                        t_frame: abs_f,
522                        f_bin: bin as u16,
523                        _pad: 0,
524                        mag: v,
525                    };
526                    self.bucket_pending.entry(bucket).or_default().push(peak);
527                }
528            }
529        }
530    }
531
532    /// Finalise one bucket: apply per-second adaptive threshold (top
533    /// `peaks_per_sec` by magnitude), then for each surviving peak in
534    /// `(t, f)` order, grow target heaps of older anchors and register
535    /// the peak as a new anchor.
536    fn finalize_bucket(&mut self, bucket: u32) {
537        let mut peaks = match self.bucket_pending.remove(&bucket) {
538            Some(p) => p,
539            None => return,
540        };
541        // Match the offline picker's `adaptive_per_second`: sort by mag
542        // desc only, no positional tiebreak.
543        peaks.sort_unstable_by(|a, b| {
544            b.mag
545                .partial_cmp(&a.mag)
546                .unwrap_or(core::cmp::Ordering::Equal)
547        });
548        peaks.truncate(self.cfg.peaks_per_sec as usize);
549        // Re-sort by `(t, f)` so downstream iteration matches the offline
550        // hash builder's order.
551        peaks.sort_unstable_by_key(|p| (p.t_frame, p.f_bin));
552
553        let target_zone_t = self.cfg.target_zone_t as i32;
554        let target_zone_f = self.cfg.target_zone_f as i32;
555        let fan_out = self.cfg.fan_out as usize;
556
557        for peak in peaks {
558            // Add this peak as a TARGET to every still-pending anchor whose
559            // zone covers it.
560            for anchor in self.pending_anchors.iter_mut() {
561                let dt = peak.t_frame as i32 - anchor.peak.t_frame as i32;
562                if dt < 1 || dt > target_zone_t {
563                    continue;
564                }
565                let df = peak.f_bin as i32 - anchor.peak.f_bin as i32;
566                if df.abs() > target_zone_f {
567                    continue;
568                }
569                anchor.targets.push(MinByMagOwned(peak));
570                if anchor.targets.len() > fan_out {
571                    anchor.targets.pop();
572                }
573            }
574            // Register this peak as a new ANCHOR.
575            self.pending_anchors.push_back(PendingAnchor {
576                peak,
577                targets: alloc::collections::BinaryHeap::with_capacity(fan_out + 1),
578            });
579        }
580        self.last_finalized_bucket = bucket as i32;
581    }
582
583    /// Finalise every bucket whose ALL frames have been peak-detected.
584    /// Conservative: bucket B is finalisable iff `bucket(last_pd_frame) > B`.
585    fn finalize_buckets(&mut self) {
586        if self.last_pd_frame < 0 {
587            return;
588        }
589        let current_bucket = (self.last_pd_frame as f32 / WANG_FRAMES_PER_SEC) as i32;
590        let to_finalize: alloc::vec::Vec<u32> = self
591            .bucket_pending
592            .keys()
593            .filter(|&&b| (b as i32) > self.last_finalized_bucket && (b as i32) < current_bucket)
594            .cloned()
595            .collect();
596        for bucket in to_finalize {
597            self.finalize_bucket(bucket);
598        }
599    }
600
601    /// Pop anchors whose target zone is fully observed (i.e. the bucket
602    /// containing the last possible target frame has been finalised),
603    /// build hashes from their accumulated target heap, and return them.
604    fn emit_finalized_anchors(&mut self) -> alloc::vec::Vec<(TimestampMs, WangHash)> {
605        let mut emitted = alloc::vec::Vec::new();
606        while let Some(front) = self.pending_anchors.front() {
607            let last_target_frame = front.peak.t_frame + self.cfg.target_zone_t as u32;
608            let last_target_bucket = (last_target_frame as f32 / WANG_FRAMES_PER_SEC) as i32;
609            if self.last_finalized_bucket < last_target_bucket {
610                break;
611            }
612            let anchor = self.pending_anchors.pop_front().unwrap();
613            self.build_hashes_for_anchor(anchor, &mut emitted);
614        }
615        emitted
616    }
617
618    /// Drain an anchor's target heap, sort by `(mag desc, position asc)`
619    /// for deterministic emission, then emit the corresponding hashes.
620    fn build_hashes_for_anchor(
621        &self,
622        anchor: PendingAnchor,
623        out: &mut alloc::vec::Vec<(TimestampMs, WangHash)>,
624    ) {
625        let mut targets: alloc::vec::Vec<Peak> = anchor.targets.into_iter().map(|w| w.0).collect();
626        targets.sort_unstable_by(|a, b| {
627            b.mag
628                .partial_cmp(&a.mag)
629                .unwrap_or(core::cmp::Ordering::Equal)
630                .then_with(|| (a.t_frame, a.f_bin).cmp(&(b.t_frame, b.f_bin)))
631        });
632        for target in &targets {
633            let f_a_q = quantise_freq(anchor.peak.f_bin);
634            let f_b_q = quantise_freq(target.f_bin);
635            let dt = ((target.t_frame - anchor.peak.t_frame) & 0x3FFF).max(1);
636            let hash = ((f_a_q & 0x1FF) << 23) | ((f_b_q & 0x1FF) << 14) | (dt & 0x3FFF);
637            let t_ms = (anchor.peak.t_frame as u64 * WANG_HOP as u64 * 1000) / WANG_SR as u64;
638            out.push((
639                TimestampMs(t_ms),
640                WangHash {
641                    hash,
642                    t_anchor: anchor.peak.t_frame,
643                },
644            ));
645        }
646    }
647}
648
649impl StreamingFingerprinter for StreamingWang {
650    type Frame = WangHash;
651
652    fn push(&mut self, samples: &[f32]) -> alloc::vec::Vec<(TimestampMs, Self::Frame)> {
653        self.sample_carry.extend_from_slice(samples);
654
655        let nbht = WANG_PEAK_NEIGHBOURHOOD as u32;
656
657        // 1. Compute new STFT frames one at a time, detecting peaks at
658        // each frame as soon as it becomes ripe (i.e. its full forward
659        // neighbourhood is in the buffer).
660        //
661        // Walk frames with an offset cursor so we drain `sample_carry`
662        // exactly once at the end of the call instead of shifting the
663        // tail by `WANG_HOP` after every frame; the loop becomes
664        // O(frames) instead of O(frames × buffer).
665        let mut off = 0usize;
666        while self.sample_carry.len() - off >= WANG_N_FFT {
667            self.stft.process_frame_power(
668                &self.sample_carry[off..off + WANG_N_FFT],
669                &mut self.frame_scratch,
670            );
671            for v in self.frame_scratch.iter_mut() {
672                *v = 10.0 * libm::log10f(v.max(WANG_LOG_FLOOR_POWER));
673            }
674            // Append `self.frame_scratch` directly via disjoint field
675            // borrow, avoiding a per-frame `Vec::clone` of the row.
676            self.append_frame_scratch_row();
677
678            let frame_idx = self.n_frames_total;
679            self.n_frames_total += 1;
680            off += WANG_HOP;
681
682            // After adding frame `frame_idx`, frame `frame_idx - nbht`
683            // becomes ripe (its forward neighbourhood is now in the
684            // buffer; backward neighbourhood is offline-equivalent
685            // because the buffer's left edge matches the offline
686            // saturating clip when applicable).
687            if frame_idx >= nbht {
688                let abs_ripe = frame_idx - nbht;
689                let row_idx = (abs_ripe - self.spec_first_frame) as usize;
690                self.detect_rows(row_idx, row_idx);
691                self.last_pd_frame = abs_ripe as i32;
692            }
693        }
694
695        if off > 0 {
696            self.sample_carry.drain(0..off);
697        }
698
699        // 2. Finalise any buckets whose frames are all detected.
700        self.finalize_buckets();
701
702        // 3. Emit hashes for anchors whose target zone is fully observed.
703        self.emit_finalized_anchors()
704    }
705
706    fn flush(&mut self) -> alloc::vec::Vec<(TimestampMs, Self::Frame)> {
707        // Detect peaks at remaining frames (those whose forward context
708        // would otherwise extend past end-of-stream — same boundary the
709        // offline picker handles via `saturating_sub`).
710        if self.spec_n_rows > 0 && self.n_frames_total > 0 {
711            let detect_to_abs = self.n_frames_total as i32 - 1;
712            if detect_to_abs > self.last_pd_frame {
713                let from_abs = (self.last_pd_frame + 1).max(self.spec_first_frame as i32) as u32;
714                let to_abs = detect_to_abs as u32;
715                let from_row = (from_abs - self.spec_first_frame) as usize;
716                let to_row = (to_abs - self.spec_first_frame) as usize;
717                self.detect_rows(from_row, to_row);
718                self.last_pd_frame = detect_to_abs;
719            }
720        }
721
722        // Finalise every remaining bucket — no more peaks can arrive.
723        let buckets: alloc::vec::Vec<u32> = self.bucket_pending.keys().cloned().collect();
724        for bucket in buckets {
725            self.finalize_bucket(bucket);
726        }
727
728        // Emit every remaining anchor — no more targets can arrive.
729        let mut emitted = alloc::vec::Vec::new();
730        while let Some(anchor) = self.pending_anchors.pop_front() {
731            self.build_hashes_for_anchor(anchor, &mut emitted);
732        }
733        emitted
734    }
735
736    fn latency_ms(&self) -> u32 {
737        (self.lookahead_frames() * WANG_HOP as u32 * 1000) / WANG_SR
738    }
739}
740
741#[cfg(test)]
742mod tests {
743    use super::*;
744    use crate::SampleRate;
745    use alloc::vec;
746    use core::f32::consts::PI;
747
748    fn synthetic_audio(seed: u32, len: usize) -> Vec<f32> {
749        // Two-tone with low-amplitude noise: stable across runs (no rng),
750        // but rich enough to produce many peaks.
751        let mut out = Vec::with_capacity(len);
752        let mut x: u32 = seed.max(1);
753        for n in 0..len {
754            // xorshift32 — deterministic noise.
755            x ^= x << 13;
756            x ^= x >> 17;
757            x ^= x << 5;
758            let noise = ((x as i32 as f32) / (i32::MAX as f32)) * 0.05;
759            let t = n as f32 / 8_000.0;
760            let s = 0.5 * libm::sinf(2.0 * PI * 880.0 * t)
761                + 0.3 * libm::sinf(2.0 * PI * 1320.0 * t)
762                + noise;
763            out.push(s);
764        }
765        out
766    }
767
768    #[test]
769    fn rejects_wrong_sample_rate() {
770        let mut fp = Wang::default();
771        let samples = vec![0.0_f32; 16_000];
772        let buf = AudioBuffer {
773            samples: &samples,
774            rate: SampleRate::HZ_16000,
775        };
776        match fp.extract(buf) {
777            Err(AfpError::UnsupportedSampleRate(16_000)) => {}
778            other => panic!("expected UnsupportedSampleRate(16000), got {other:?}"),
779        }
780    }
781
782    #[test]
783    fn rejects_short_audio() {
784        let mut fp = Wang::default();
785        let samples = vec![0.0_f32; 8_000]; // 1 second, need 2
786        let buf = AudioBuffer {
787            samples: &samples,
788            rate: SampleRate::HZ_8000,
789        };
790        match fp.extract(buf) {
791            Err(AfpError::AudioTooShort {
792                needed: 16_000,
793                got: 8_000,
794            }) => {}
795            other => panic!("expected AudioTooShort, got {other:?}"),
796        }
797    }
798
799    #[test]
800    fn silence_gives_empty_fingerprint() {
801        let mut fp = Wang::default();
802        let samples = vec![0.0_f32; 8_000 * 3];
803        let buf = AudioBuffer {
804            samples: &samples,
805            rate: SampleRate::HZ_8000,
806        };
807        let fpr = fp.extract(buf).unwrap();
808        assert_eq!(fpr.frames_per_sec, 62.5);
809        assert!(fpr.hashes.is_empty());
810    }
811
/// Five seconds of synthetic audio must yield a non-empty fingerprint whose
/// hashes are ordered by `(t_anchor, hash)`.
#[test]
fn synthetic_signal_produces_hashes() {
    let audio = synthetic_audio(0xC0FFEE, 8_000 * 5);
    let mut wang = Wang::default();
    let fingerprint = wang
        .extract(AudioBuffer {
            samples: &audio,
            rate: SampleRate::HZ_8000,
        })
        .unwrap();
    assert!(
        !fingerprint.hashes.is_empty(),
        "expected hashes from a 5s tone"
    );

    // Every adjacent pair must be non-decreasing under the (t_anchor, hash) key.
    for pair in fingerprint.hashes.windows(2) {
        let earlier = (pair[0].t_anchor, pair[0].hash);
        let later = (pair[1].t_anchor, pair[1].hash);
        assert!(earlier <= later);
    }
}
827
/// Two independent extractors fed the same samples must produce identical
/// hash sequences.
#[test]
fn extraction_is_deterministic() {
    let samples = synthetic_audio(0xDEAD, 8_000 * 4);

    // Each run builds its own extractor, so the two runs share no state.
    let run_once = || {
        let mut wang = Wang::default();
        wang.extract(AudioBuffer {
            samples: &samples,
            rate: SampleRate::HZ_8000,
        })
        .unwrap()
    };
    let first = run_once();
    let second = run_once();

    assert_eq!(first.hashes.len(), second.hashes.len());
    for (lhs, rhs) in first.hashes.iter().zip(second.hashes.iter()) {
        assert_eq!(lhs, rhs);
    }
}
851
/// Signals generated from different seeds must not fingerprint to the same
/// hash sequence.
#[test]
fn different_signals_diverge() {
    let first_signal = synthetic_audio(0x1111, 8_000 * 3);
    let second_signal = synthetic_audio(0x2222, 8_000 * 3);

    // A single extractor is reused for both signals.
    let mut wang = Wang::default();
    let first_buffer = AudioBuffer {
        samples: &first_signal,
        rate: SampleRate::HZ_8000,
    };
    let first_print = wang.extract(first_buffer).unwrap();
    let second_buffer = AudioBuffer {
        samples: &second_signal,
        rate: SampleRate::HZ_8000,
    };
    let second_print = wang.extract(second_buffer).unwrap();

    // The seeds drive different noise streams, so the hashes must differ somewhere.
    assert_ne!(first_print.hashes, second_print.hashes);
}
873
/// Builds a single anchor/target pair by hand and checks that each field can
/// be decoded back out of the packed 32-bit hash.
#[test]
fn hash_packing_round_trips() {
    // One anchor and one target placed so that exactly one pair is formed.
    let anchor = Peak {
        t_frame: 100,
        f_bin: 50,
        _pad: 0,
        mag: -10.0,
    };
    let target = Peak {
        t_frame: 110,
        f_bin: 70,
        _pad: 0,
        mag: -12.0,
    };
    let peaks = alloc::vec![anchor, target];
    let hashes = build_hashes(&peaks, &WangConfig::default());
    assert_eq!(hashes.len(), 1);

    // Layout, high to low: 9-bit anchor bucket | 9-bit target bucket | 14-bit Δt.
    let packed = hashes[0].hash;
    let anchor_bucket = (packed >> 23) & 0x1FF;
    let target_bucket = (packed >> 14) & 0x1FF;
    let frame_delta = packed & 0x3FFF;
    assert_eq!(anchor_bucket, quantise_freq(50));
    assert_eq!(target_bucket, quantise_freq(70));
    assert_eq!(frame_delta, 10);
    assert_eq!(hashes[0].t_anchor, 100);
}
905
/// The default streaming latency must equal the total look-ahead in milliseconds.
#[test]
fn streaming_latency_matches_lookahead() {
    // Look-ahead in frames: 63 (target zone) + 15 (picker) + 63 (adaptive bucket) = 141.
    // 141 frames * 128 samples per hop / 8000 Hz = 2.256 s.
    let stream = StreamingWang::default();
    let latency = stream.latency_ms();
    assert_eq!(latency, 2_256);
}
912
/// Pushing an empty slice and then flushing must both return nothing.
#[test]
fn streaming_empty_push_is_empty() {
    let mut stream = StreamingWang::default();
    let pushed = stream.push(&[]);
    assert!(pushed.is_empty());
    let flushed = stream.flush();
    assert!(flushed.is_empty());
}
919
/// Four seconds of all-zero samples must produce no hashes, neither on push
/// nor on flush.
#[test]
fn streaming_silence_emits_nothing() {
    let silence = vec![0.0_f32; 8_000 * 4];
    let mut stream = StreamingWang::default();
    let pushed = stream.push(&silence);
    assert!(pushed.is_empty());
    let flushed = stream.flush();
    assert!(flushed.is_empty());
}
927
/// Splits `total` into deterministic pseudo-random chunk sizes drawn from an
/// xorshift32 generator.
///
/// Every returned size is at least 1 and the sizes sum to exactly `total`
/// (an empty vector when `total == 0`). Each draw is `x % max_chunk`, so sizes
/// stay strictly below `max_chunk`, except that a draw of 0 is bumped to 1; the
/// final chunk is additionally capped at whatever remains.
///
/// A `seed` of 0 is replaced by 1. A `max_chunk` of 0 is treated as 1 so the
/// modulus can never be zero (which would panic).
fn chunk_sizes(seed: u32, total: usize, max_chunk: usize) -> Vec<usize> {
    // Guard the modulus: `x % 0` panics.
    let modulus = max_chunk.max(1);
    let mut x = seed.max(1);
    let mut out = Vec::new();
    let mut remaining = total;
    while remaining > 0 {
        // One xorshift32 step.
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        let n = ((x as usize) % modulus).max(1).min(remaining);
        out.push(n);
        remaining -= n;
    }
    out
}
943
/// Feeds the same audio through `StreamingWang` in fixed-size chunks and
/// checks that the emitted hashes do not depend on the chunk size: every
/// candidate size is compared against a 1-second-chunk baseline.
#[test]
fn streaming_chunk_size_invariant() {
    let samples = synthetic_audio(0xFACE, 8_000 * 4);

    // Runs a fresh streamer over `samples` and returns its hashes, sorted so
    // that emission order does not affect the comparison.
    let hashes_for = |chunk_len: usize| -> Vec<WangHash> {
        let mut stream = StreamingWang::default();
        let mut hashes = Vec::new();
        for block in samples.chunks(chunk_len) {
            for (_, hash) in stream.push(block) {
                hashes.push(hash);
            }
        }
        for (_, hash) in stream.flush() {
            hashes.push(hash);
        }
        hashes.sort_unstable_by_key(|h| (h.t_anchor, h.hash));
        hashes
    };

    // Reference run: 1-second chunks.
    let baseline = hashes_for(8_000);
    for chunk_size in [128, 1024, 4321, 16_000] {
        let candidate = hashes_for(chunk_size);
        assert_eq!(
            candidate, baseline,
            "chunk_size = {chunk_size} produced different hashes than 8000",
        );
    }
}
971
/// Streaming extraction over pseudo-randomly sized chunks must yield the same
/// multiset of hashes as a single offline extraction of the same audio.
#[test]
fn streaming_offline_equivalence() {
    let samples = synthetic_audio(0xBEEF, 8_000 * 6);

    // Offline reference over the whole buffer.
    let mut offline = Wang::default();
    let reference = offline
        .extract(AudioBuffer {
            samples: &samples,
            rate: SampleRate::HZ_8000,
        })
        .unwrap();

    // Streaming run: peel successive chunks off the front of the buffer.
    let mut streaming = StreamingWang::default();
    let mut online = Vec::new();
    let mut rest: &[f32] = &samples;
    for n in chunk_sizes(0xCAFE, samples.len(), 4_000) {
        let (head, tail) = rest.split_at(n);
        online.extend(streaming.push(head).into_iter().map(|(_, h)| h));
        rest = tail;
    }
    online.extend(streaming.flush().into_iter().map(|(_, h)| h));

    // Compare as multisets: sort both sides with the same key first.
    let mut expected: Vec<WangHash> = reference.hashes;
    let mut actual: Vec<WangHash> = online;
    expected.sort_unstable_by_key(|h| (h.t_anchor, h.hash));
    actual.sort_unstable_by_key(|h| (h.t_anchor, h.hash));
    assert_eq!(expected.len(), actual.len(), "hash count mismatch");
    assert_eq!(expected, actual, "hash sequences differ");
}
1009
/// On the same audio, `fan_out = 3` must produce strictly fewer hashes than
/// `fan_out = 10`.
#[test]
fn smaller_fan_out_yields_fewer_hashes() {
    let samples = synthetic_audio(0xFEED, 8_000 * 4);

    // Extracts with the default config, overriding only `fan_out`, and
    // returns the number of hashes produced.
    let hash_count = |fan_out| {
        let mut wang = Wang::new(WangConfig {
            fan_out,
            ..WangConfig::default()
        });
        let fingerprint = wang
            .extract(AudioBuffer {
                samples: &samples,
                rate: SampleRate::HZ_8000,
            })
            .unwrap();
        fingerprint.hashes.len()
    };

    let wide = hash_count(10);
    let narrow = hash_count(3);
    assert!(narrow < wide, "narrow={} wide={}", narrow, wide);
}
1039
/// `quantise_freq` must map every STFT bin into a valid bucket and never
/// decrease as the bin index grows.
#[test]
fn quantise_freq_covers_full_range() {
    // Bin 0 maps to bucket 0; bin 512, the last of the 513 STFT bins, must
    // still map below the bucket count.
    assert_eq!(quantise_freq(0), 0);
    assert!(quantise_freq(512) < WANG_FREQ_BUCKETS);

    // Sweep every bin: buckets are in range and monotonic non-decreasing.
    let mut previous_bucket = 0;
    for bin in 0..=512_u16 {
        let bucket = quantise_freq(bin);
        assert!(bucket >= previous_bucket);
        assert!(bucket < WANG_FREQ_BUCKETS);
        previous_bucket = bucket;
    }
}
1054
/// Streaming one sample per push must still yield the same hashes as an
/// offline extraction of the whole buffer.
#[test]
fn streaming_with_one_sample_chunks_still_matches_offline() {
    let samples = synthetic_audio(0xABCD, 8_000 * 3);

    let mut offline = Wang::default();
    let reference = offline
        .extract(AudioBuffer {
            samples: &samples,
            rate: SampleRate::HZ_8000,
        })
        .unwrap();

    // Single-sample pushes are the pathological case for an incremental
    // streaming implementation.
    let mut stream = StreamingWang::default();
    let mut online = Vec::new();
    for single in samples.chunks(1) {
        online.extend(stream.push(single).into_iter().map(|(_, h)| h));
    }
    online.extend(stream.flush().into_iter().map(|(_, h)| h));

    let mut expected = reference.hashes;
    let mut actual = online;
    expected.sort_unstable_by_key(|h| (h.t_anchor, h.hash));
    actual.sort_unstable_by_key(|h| (h.t_anchor, h.hash));
    assert_eq!(expected, actual);
}
1081
/// Streams 30 s of synthetic audio in 256-sample pushes, records the
/// high-water mark of each internal streaming buffer, and checks that each
/// stays under a fixed ceiling and that `flush` drains the pending state.
#[test]
fn streaming_state_stays_bounded_under_long_input() {
    let secs = 30usize;
    let chunk = 256usize;
    let samples = synthetic_audio(7, WANG_SR as usize * secs);
    let max_spec_rows = 2 * WANG_PEAK_NEIGHBOURHOOD + 1;

    let mut stream = StreamingWang::default();

    // High-water marks observed across all pushes.
    let (mut peak_carry, mut peak_spec_rows) = (0usize, 0usize);
    let (mut peak_bucket_pending, mut peak_anchors) = (0usize, 0usize);

    for block in samples.chunks(chunk) {
        let _ = stream.push(block);

        peak_carry = peak_carry.max(stream.sample_carry.len());
        peak_spec_rows = peak_spec_rows.max(stream.spec_n_rows);
        peak_bucket_pending = peak_bucket_pending.max(stream.bucket_pending.len());
        peak_anchors = peak_anchors.max(stream.pending_anchors.len());

        // Structural invariants that must hold after every single push.
        assert!(stream.sample_carry.len() < WANG_N_FFT);
        assert!(stream.spec_n_rows <= max_spec_rows);
    }

    // Ceilings on the high-water marks at the default config
    // (peaks_per_sec=30, target_zone_t=63 frames ≈ 1 s of bucket coverage,
    // fan_out=5). They are deliberately tight so that a regression which
    // inflates any buffer is caught.
    assert_eq!(
        peak_spec_rows, max_spec_rows,
        "spec window should fill once the stream is long enough",
    );
    assert!(peak_carry < WANG_N_FFT, "peak_carry {peak_carry}");
    assert!(
        peak_bucket_pending <= 3,
        "bucket_pending peaked at {peak_bucket_pending} (steady state should be ≤ 2)",
    );
    // About 1 s of finalised buckets × peaks_per_sec=30 ≈ 30 anchors, plus
    // modest headroom for the boundary between adjacent buckets.
    assert!(
        peak_anchors <= 40,
        "pending_anchors peaked at {peak_anchors} (expected ≤ 40)",
    );

    // After a flush nothing may remain pending.
    let _ = stream.flush();
    assert_eq!(stream.bucket_pending.len(), 0);
    assert_eq!(stream.pending_anchors.len(), 0);
}
1139
/// `build_hashes` must pair an anchor only with peaks inside its target zone.
///
/// Five hand-built peaks are arranged so that exactly one pair qualifies:
/// anchor `(t=0, f=100)` with target `(t=5, f=110)`. The packed hash is
/// decoded to confirm that it is this pair, because `t_anchor == 0` alone
/// would also match the other frame-0 peak.
#[test]
fn target_zone_filters_far_peaks() {
    let peak = |t_frame, f_bin| Peak {
        t_frame,
        f_bin,
        _pad: 0,
        mag: 0.0,
    };
    let mut peaks = alloc::vec![
        // The anchor that should produce the single hash.
        peak(0, 100),
        // Same frame as the anchor: Δt = 0 < 1, so never its target.
        peak(0, 200),
        // Δt = 70 from the frame-0 anchors exceeds target_zone_t (63).
        peak(70, 100),
        // Δt = 5 and |Δf| = 10 from (0, 100): the one valid target.
        peak(5, 110),
        // |Δf| = 200 from (0, 100) exceeds the frequency limit of 64.
        peak(5, 300),
    ];
    // The peaks must be sorted by (t_frame, f_bin) for the "break on
    // dt > zone" optimisation in `build_hashes` to fire correctly.
    peaks.sort_unstable_by_key(|p| (p.t_frame, p.f_bin));

    let cfg = WangConfig::default();
    let hashes = build_hashes(&peaks, &cfg);

    // Every other candidate pair is rejected:
    // - anchor (0, 200): |Δf| is 90 to (5, 110) and 100 to (5, 300), both > 64;
    // - anchors at frame 5: (70, 100) is Δt = 65 away, which is > 63.
    assert_eq!(hashes.len(), 1);
    assert_eq!(hashes[0].t_anchor, 0);

    // Decode the hash (9-bit anchor bucket | 9-bit target bucket | 14-bit Δt)
    // to pin the pair down to (0, 100) → (5, 110).
    let packed = hashes[0].hash;
    assert_eq!((packed >> 23) & 0x1FF, quantise_freq(100));
    assert_eq!((packed >> 14) & 0x1FF, quantise_freq(110));
    assert_eq!(packed & 0x3FFF, 5);
}
1192}