mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
//! `lfilter` IIR recurrence (BS.1770 K-weighting hot loop).
//!
//! Tracking: [#154](https://github.com/Findit-AI/mlxrs/issues/154).
//! The IIR recurrence has a loop-carried dependency (every `y[n]` depends
//! on `y[n-1]`, the next state cell on `y[n]`, etc. — no within-stream
//! parallelism), so a wide NEON kernel over the sample axis is
//! mathematically impossible without changing the algorithm.
//!
//! # Approaches
//!
//! Three approaches (in order, easiest → research-grade):
//!
//! 1. **`state_len == 0` FIR fast-path** — degenerate 1-tap FIR
//!    (`y[n] = b[0] * x[n]`) is trivially parallel; ship a NEON
//!    `f64x2`-wide multiply-and-store. Cosmetic for the actual
//!    K-weighting workload (which uses `state_len == 2` biquads and
//!    never hits this arm), but a legit NEON kernel.
//! 2. **Block-IIR parallel-prefix scan** — express the biquad as
//!    state-space `s[n] = A * s[n-1] + B * x[n]`, process B-sample
//!    blocks via parallel-prefix of `A^k` matrix products.
//!    **REJECTED** before implementation: matrix-power accumulation
//!    cannot produce bit-identical output to the per-sample
//!    recurrence (different evaluation order, different rounding),
//!    and the project mandate is "LUFS bit-exactness MUST be
//!    preserved" for the `integrated_loudness` consumer. A NEON arm
//!    that drifts from scalar by 1 ulp at the 401st sample would
//!    silently bias `integrated_loudness` against the EBU R128
//!    reference.
//! 3. **Hand-unrolled biquad specialization** — `state_len == 2`
//!    fast-path with the 3-tap b / 3-tap a inner loop fully
//!    flattened into 5 scalar muls + 4 scalar adds/subs per sample (no
//!    `b_norm.get(i).copied().unwrap_or(0.0)` bounds-check, no
//!    `for i in 1..state_len` loop), so LLVM can schedule the
//!    fp pipeline tighter. Same per-sample serial dependency, same
//!    bit-identical math: the unrolled body performs the same
//!    multiplies / adds in the same order as the generic loop, so
//!    every IEEE 754 rounding step is unchanged.
//!
//! # Bench results (M2 Pro, release)
//!
//! Generated by `mlxrs/benches/simd_lfilter.rs` with `--warm-up-time 1
//! --measurement-time 2 --sample-size 30` (criterion). Each row reports
//! the point-estimate throughput, i.e. the middle number of criterion's
//! `[lower estimate upper]` confidence-interval triple. The bench file
//! is the authoritative data source.
//!
//! ## Biquad IN-PLACE lane-sweep (single-stage HS coefficients @ 48 kHz)
//!
//! ```text
//! n        generic        biquad_scalar  biquad_dispatch  dispatch vs generic
//! 1024     336 Melem/s    544 Melem/s    543 Melem/s      +61.5%
//! 4096     332 Melem/s    513 Melem/s    489 Melem/s      +47.0%
//! 16384    313 Melem/s    468 Melem/s    476 Melem/s      +52.3%
//! 48000    302 Melem/s    484 Melem/s    480 Melem/s      +58.8%
//! 65536    309 Melem/s    479 Melem/s    480 Melem/s      +55.1%
//! 192000   329 Melem/s    497 Melem/s    502 Melem/s      +52.6%
//! 480000   336 Melem/s    492 Melem/s    484 Melem/s      +44.2%
//! ```
//!
//! IN-PLACE dispatch beats generic at every benched length
//! (+44% to +62%), including the realistic long-channel
//! sizes (192k = 4 s @ 48 kHz, 480k = 10 s @ 48 kHz —
//! `k_weight_channel` operates on FULL audio channels, not 48 k-sample
//! fixtures). `biquad_scalar` and `biquad_dispatch` are close at most
//! lengths (both routes flatten the inner `for i in 1..state_len` loop
//! to the same 3-tap unrolled body; the NEON-annotated arm is the same
//! body under `target_feature(enable = "neon")`). At `n=4096` the
//! criterion confidence intervals are non-overlapping —
//! `biquad_scalar` at 513.04 Melem/s `[499.03, 526.59]` versus
//! `biquad_dispatch` at 488.84 Melem/s `[485.23, 492.99]` — so the
//! NEON-annotated arm carries a small consistent ~5% deficit at that
//! length, well inside the larger run-to-run baseline movement
//! documented below.
//!
//! ## Biquad OUT-OF-PLACE lane-sweep (single-stage HS @ 48 kHz)
//!
//! Dispatch path = `Vec::with_capacity(n) + extend_from_slice(x) +
//! in_place_kernel(&mut y, b, a)`. Adds a full `n_samples` f64 memcpy
//! beyond the in-place kernel's work.
//!
//! ```text
//! n        generic        dispatch       dispatch vs generic
//! 1024     459 Melem/s    463 Melem/s    +0.9%
//! 4096     460 Melem/s    463 Melem/s    +0.7%
//! 16384    454 Melem/s    441 Melem/s    -2.9%     dispatch LOSES
//! 48000    460 Melem/s    454 Melem/s    -1.2%     dispatch LOSES
//! 65536    458 Melem/s    454 Melem/s    -0.9%     dispatch LOSES
//! 192000   373 Melem/s    461 Melem/s    +23.5%
//! 480000   339 Melem/s    453 Melem/s    +33.5%
//! ```
//!
//! Mixed picture. Mid-size losses (16k-65k) are 1-3% — within run-to-run
//! variance — but a fixed length-gated threshold cannot be confidently
//! drawn from noise. Long-channel wins (23-34%) are decisive but driven
//! by `generic_out_of_place`'s cache-locality drop at the L2 boundary;
//! `dispatch_out_of_place` happens to hold its throughput because its
//! `extend_from_slice` already warmed the cache for the in-place kernel.
//!
//! ## K-weighting chain (HS → HP, the actual `integrated_loudness` consumer)
//!
//! ```text
//! n        generic        biquad_scalar  biquad_dispatch  dispatch vs generic
//! 48000    160 Melem/s    243 Melem/s    245 Melem/s      +52.7%
//! 192000   174 Melem/s    241 Melem/s    251 Melem/s      +44.4%
//! 480000   185 Melem/s    237 Melem/s    242 Melem/s      +30.5%
//! ```
//!
//! Dispatch beats generic on the K-weighting chain across all benched
//! lengths (+30% to +53%), including 4 s and 10 s
//! @ 48 kHz long channels. This is the consumer that drives the ship
//! decision for the in-place arm — but see the baseline stability
//! caveat below before treating the magnitude as settled.
//!
//! ## Baseline stability
//!
//! Generic-loop baselines vary substantially run-to-run while dispatch
//! itself drifts by only single-digit percent across most rows (and
//! ~+11% on one in-place row). Two independent runs of the same
//! byte-identical harness, row by row:
//!
//! ```text
//! Biquad IN-PLACE lane sweep (single-stage HS @ 48 kHz)
//!
//! n        generic prior   generic curr    dispatch prior  dispatch curr   prior ratio
//! 1024     465 Melem/s     336 Melem/s     487 Melem/s     543 Melem/s     +4.7%
//! 4096     461 Melem/s     332 Melem/s     478 Melem/s     489 Melem/s     +3.6%
//! 16384    462 Melem/s     313 Melem/s     476 Melem/s     476 Melem/s     +3.0%
//! 48000    462 Melem/s     302 Melem/s     480 Melem/s     480 Melem/s     +3.9%
//! 65536    463 Melem/s     309 Melem/s     484 Melem/s     480 Melem/s     +4.5%
//! 192000   463 Melem/s     329 Melem/s     482 Melem/s     502 Melem/s     +4.1%
//! 480000   464 Melem/s     336 Melem/s     482 Melem/s     484 Melem/s     +3.8%
//!
//! K-weighting chain (HS → HP)
//!
//! n        generic prior   generic curr    dispatch prior  dispatch curr   prior ratio
//! 48000    235 Melem/s     160 Melem/s     254 Melem/s     245 Melem/s     +8.1%
//! 192000   236 Melem/s     174 Melem/s     257 Melem/s     251 Melem/s     +9.0%
//! 480000   234 Melem/s     185 Melem/s    ~256 Melem/s     242 Melem/s     +9.4%
//! ```
//!
//! Generic baselines vary by roughly 21-35% across both arms between
//! runs (in-place rows: 27.6% to 34.6%; K-weight rows: 20.9%
//! to 31.9%). Dispatch values themselves also vary across runs (e.g.
//! in-place at `n=1024` ranges 487–543 Melem/s, about +11.5%, and the
//! K-weighting rows move by -2.3% to -5.5% row by row), but
//! dispatch consistently beats the generic baseline at every benched
//! length. The mechanism for the generic-baseline movement is not
//! identified (likely a thermal, scheduler, or codegen-side effect on a
//! route that is not actively tuned); the observable consequence is that
//! the larger ratios (+44% to +62% in-place, +30% to +53% K-weight) are
//! substantially driven by that baseline movement, not by an
//! independently-established dispatch improvement.
//!
//! ### Conservative lower-bound win (same-run ratios)
//!
//! The conservative basis below uses a single run's own row-paired
//! dispatch-vs-generic ratios — SAME-RUN observed ratios from one run,
//! NOT a cross-run mixed envelope. Row by row:
//!
//! - **IN-PLACE biquad dispatch beats generic by +3.0% to +4.7%**
//!   across the lane sweep (minimum +3.0% at n=16384; maximum +4.7% at
//!   n=1024). The larger +44% to +62% ratios should not be treated as
//!   settled until reproduced across independent runs.
//! - **K-weight chain dispatch beats generic by +8.1% to +9.4%**
//!   across the sweep (minimum +8.1% at n=48000; maximum +9.4% at
//!   n=480000). The larger +30% to +53% is similarly contingent on the
//!   depressed generic baseline reproducing.
//!
//! ### Ship basis
//!
//! Keeping the in-place arm wired is grounded on the CONSERVATIVE
//! lower-bound win read from a single run's own row-paired ratios
//! (+3.0% to +4.7% on the in-place lane sweep, +8.1% to +9.4% on the
//! K-weighting chain), not on the larger numbers and not on a cross-run
//! mixed envelope. A few-percent win that survives across two
//! independent runs at every benched length is sufficient to keep
//! wiring on the K-weighting hot path; the larger ratios would be a
//! nice-to-have IF reproduced, but the ship basis does not depend on
//! them. The per-arm ship calls in the **Ship calls** section below
//! follow that conservative framing.
//!
//! ## FIR fast-path (`state_len == 0`)
//!
//! ```text
//! n       scalar          dispatch (NEON)
//! 1024    12.1 Gelem/s    7.5 Gelem/s     NEON SLOWER (autovec wins)
//! 4096    10.0 Gelem/s    7.1 Gelem/s     NEON SLOWER
//! 16384   14.9 Gelem/s    7.9 Gelem/s     NEON SLOWER
//! 48000   10.4 Gelem/s    7.8 Gelem/s     NEON SLOWER
//! 65536   10.3 Gelem/s    7.7 Gelem/s     NEON SLOWER
//! ```
//!
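//! LLVM autovectorizes `for (dst, &src) in out.iter_mut().zip(x.iter())
//! { *dst = b0 * src; }` at f64 width better than a hand-rolled
//! `vmulq_n_f64` 2-lane NEON tile. The likely cause (not profiled) is
//! instruction-level parallelism: the autovectorized loop is typically
//! unrolled so several independent 128-bit load/multiply/store groups
//! are in flight per iteration, keeping more of the core's SIMD pipes
//! busy, whereas the hand-rolled loop issues a single 2-lane group per
//! iteration.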
//!
//! # Realistic workload context
//!
//! `crate::audio::dsp::k_weight_channel` (private; the only in-tree
//! consumer of the biquad fast-path, used by
//! [`crate::audio::dsp::integrated_loudness`]) operates on FULL audio
//! channels — up to `MAX_DECODED_SAMPLES = 64 Mi samples` per channel.
//! The lane sweep deliberately spans both short (1024 samples) and
//! realistic long (480000 samples = 10 s @ 48 kHz) sizes; ship
//! decisions cite the long-channel numbers, not the 48 k-sample
//! fixture alone.
//!
//! # Ship calls
//!
//! Keep a dispatcher wired only where the benchmark shows it beating the
//! existing generic / scalar loop on the path it would replace; prefer
//! the generic loop when in doubt:
//!
//! - **SHIP + WIRE biquad specialization at the in-place path only**
//!   (`crate::audio::dsp::lfilter_f64_in_place`) — grounded on the
//!   conservative lower-bound win read from a single run's own
//!   row-paired ratios (+3.0% to +4.7% on the in-place lane sweep,
//!   +8.1% to +9.4% on the K-weighting chain; SAME-RUN ratios, not a
//!   cross-run mixed envelope). See the **Baseline stability** section
//!   above for the row-accurate tables. The larger ratios
//!   (+44% to +62% in-place, +30% to +53% K-weight) are inflated by
//!   depressed generic baselines, not an independently-established
//!   dispatch improvement, and are not the basis of the ship call.
//!   Bit-exact against the generic kernel (asserted by the
//!   `biquad_bit_exact_vs_generic_*` tests below over 48000-sample
//!   HS / HP / chained fixtures); LUFS measurements through
//!   [`crate::audio::dsp::integrated_loudness`] remain byte-identical
//!   to pre-SIMD output.
//! - **DO NOT WIRE biquad specialization at the out-of-place path**
//!   (`crate::audio::dsp::lfilter_f64`) — the public out-of-place
//!   wrapper is not the K-weighting hot path (`integrated_loudness`
//!   calls the in-place kernel directly through `k_weight_channel`).
//!   Out-of-place dispatch loses 1-3% at mid sizes (16k-65k, within
//!   variance) and the mixed picture across the sweep gives no
//!   confident length-gating threshold. Preferring the generic loop
//!   when in doubt, the out-of-place wrapper runs its single-pass
//!   generic loop (no extra `extend_from_slice` memcpy).
//! - **SHIP FIR kernel but DO NOT WIRE** — the kernel lives here as
//!   a regression guard + building block, but the dispatcher does NOT
//!   replace the existing scalar loops in [`crate::audio::dsp`]. LLVM
//!   autovectorizes the scalar f64 multiply loop better than the
//!   hand-rolled NEON tile on M2 hardware. Bit-exactness is preserved
//!   either way; the kernel exists for callers that want the
//!   dispatcher behaviour explicitly.
//! - **SHIP biquad kernel even on arms NOT WIRED** — the
//!   `lfilter_biquad` / `lfilter_biquad_scalar` / `lfilter_biquad_neon`
//!   trio stays available + bit-exact-tested as a regression guard
//!   AND for future callers (e.g. anyone wanting the dispatcher
//!   behaviour explicitly for a long in-place biquad pass).
//! - **DO NOT SHIP block-IIR parallel-prefix** — rejected before
//!   implementation on the bit-exactness mandate (see Approach 2
//!   above).
//!
//! # Correctness class — `Exact` (bit-identical)
//!
//! The LUFS pipeline asserts bit-exact match against the EBU R128
//! reference via the existing `integrated_loudness` tests; a NEON
//! arm with even a 1-ulp drift on a single sample propagates into
//! the gated mean-square reduction and biases the final LUFS read.
//! The differential test asserts `to_bits()`-level equality on
//! every output f64 across the scalar / dispatcher arms, AND uses
//! the actual K-weighting fixture (1 s @ 48 kHz through the two
//! biquads) so any drift surfaces immediately rather than waiting
//! on the noisier `integrated_loudness` end-to-end test.

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::{vld1q_f64, vmulq_n_f64, vst1q_f64};

/// Scalar reference for the FIR fast-path: writes `out[n] = b0 * x[n]`
/// for every sample, producing bit-for-bit the same values as the
/// `state_len == 0` arm of the [`crate::audio::dsp`] kernel.
///
/// Compiled on **every** target (no `target_arch` gate): it pins down
/// the math contract, serves as the oracle for the differential tests,
/// and is the fallback route of [`lfilter_fir_b0`].
///
/// # Panics
///
/// Panics — in release builds as well — when `out.len() != x.len()`.
#[inline]
#[doc(hidden)]
pub fn lfilter_fir_b0_scalar(out: &mut [f64], x: &[f64], b0: f64) {
  assert_eq!(
    out.len(),
    x.len(),
    "lfilter_fir_b0_scalar: out.len() ({}) must equal x.len() ({})",
    out.len(),
    x.len(),
  );
  // One independent multiply per element; LLVM autovectorizes this.
  out
    .iter_mut()
    .zip(x)
    .for_each(|(dst, &src)| *dst = b0 * src);
}

/// FIR fast-path NEON kernel — computes `out[i] = b0 * x[i]` two `f64`
/// lanes (`float64x2_t`) at a time.
///
/// # Algorithm
///
/// `x` / `out` are split at the largest even length. The even prefix is
/// walked in exact 2-element chunks; for each chunk:
/// 1. `vld1q_f64` loads the two input samples.
/// 2. `vmulq_n_f64` multiplies both lanes by the broadcast scalar `b0`.
/// 3. `vst1q_f64` stores the two products into the matching `out` chunk.
///
/// An odd-length input leaves one trailing sample, which is handed to
/// [`lfilter_fir_b0_scalar`].
///
/// # Bit-exactness
///
/// Each lane of `vmulq_n_f64` is a single IEEE 754 f64 multiply, the
/// same operation as the scalar `b0 * src`: identical rounding, NaN
/// propagation and denormal handling. Grouping elements into 2-lane
/// tiles does not change any per-element result.
///
/// # Safety
///
/// 1. NEON must be available on the executing CPU. This is the caller's
///    obligation; the dispatcher [`lfilter_fir_b0`] discharges it.
/// 2. `out.len() == x.len()` — checked here by an **unconditional**
///    assert (release builds included), so a mismatch panics instead of
///    reading or writing out of bounds.
///
/// `vld1q_f64` / `vst1q_f64` accept unaligned addresses at full
/// throughput on aarch64.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
pub(crate) unsafe fn lfilter_fir_b0_neon(out: &mut [f64], x: &[f64], b0: f64) {
  assert_eq!(
    out.len(),
    x.len(),
    "lfilter_fir_b0_neon: out.len() ({}) must equal x.len() ({})",
    out.len(),
    x.len(),
  );

  // Largest even length: the part that fills whole 2-lane tiles.
  let paired_len = x.len() & !1;
  let (out_paired, out_tail) = out.split_at_mut(paired_len);
  let (x_paired, x_tail) = x.split_at(paired_len);

  for (dst, src) in out_paired.chunks_exact_mut(2).zip(x_paired.chunks_exact(2)) {
    // SAFETY: `chunks_exact{,_mut}(2)` yields slices of exactly two
    // contiguous f64s, so the 2-lane load from `src` and the 2-lane store
    // into `dst` stay in bounds. NEON availability is the caller's
    // obligation (Safety #1).
    unsafe {
      let lanes = vld1q_f64(src.as_ptr());
      vst1q_f64(dst.as_mut_ptr(), vmulq_n_f64(lanes, b0));
    }
  }

  // At most one sample is left over for an odd-length input.
  if !x_tail.is_empty() {
    lfilter_fir_b0_scalar(out_tail, x_tail, b0);
  }
}

/// FIR fast-path dispatcher: computes `out[n] = b0 * x[n]`, routed to
/// the NEON kernel (`lfilter_fir_b0_neon`) on aarch64 CPUs that report
/// NEON and to [`lfilter_fir_b0_scalar`] everywhere else.
///
/// # Panics
///
/// Panics (release builds included) when `out.len() != x.len()`.
///
/// # Correctness class
///
/// `Exact` — every element is one f64 multiply, so the scalar and NEON
/// routes produce bit-identical output. See the module-level
/// "Correctness class" section.
#[inline]
#[doc(hidden)]
pub fn lfilter_fir_b0(out: &mut [f64], x: &[f64], b0: f64) {
  assert_eq!(
    out.len(),
    x.len(),
    "simd::audio::lfilter_fir_b0: out.len() ({}) must equal x.len() ({})",
    out.len(),
    x.len(),
  );

  #[cfg(target_arch = "aarch64")]
  if crate::simd::is_neon_available() {
    // SAFETY: the runtime probe just confirmed NEON on this CPU (Safety #1
    // of `lfilter_fir_b0_neon`); the equal-length requirement (#2) was
    // asserted unconditionally at the top of this function.
    unsafe { lfilter_fir_b0_neon(out, x, b0) };
    return;
  }

  lfilter_fir_b0_scalar(out, x, b0);
}

/// Scalar reference for the biquad specialization (`state_len == 2`,
/// `b.len() == a.len() == 3`), filtering `x` in place.
///
/// The generic kernel's inner `for i in 1..state_len` loop and its
/// `b_norm.get(i).copied().unwrap_or(0.0)` lookups are unrolled into
/// straight-line code: 5 fp multiplies + 4 fp adds/subs per sample, the
/// irreducible work of a direct-form II transposed biquad. With the
/// coefficients held in locals instead of indexed per tap, LLVM has
/// more freedom to schedule the fp pipeline.
///
/// # Algorithm
///
/// For each sample, with both state cells starting at zero:
/// ```text
/// output       = b0 * sample + state0
/// state0_next  = state1 + b1 * sample - a1 * output
/// state1_next  = b2 * sample - a2 * output
/// ```
///
/// `b` and `a` must already be normalized by `a[0]` (the generic
/// kernel's `b_norm` / `a_norm` step); `a[0]` itself is never read.
/// The early returns (`b.is_empty()`, `state_len == 0`) and the
/// `a[0] != 0` / `n_samples <= MAX_LFILTER_SAMPLES` checks are the
/// caller's job — this function is only the inner loop.
///
/// # Panics
///
/// Panics (release builds included) unless `b.len() == 3` and
/// `a.len() == 3`.
///
/// # Bit-exactness
///
/// The unrolled body issues the same multiplies and adds, in the same
/// left-to-right order, as the generic kernel's `state_len == 2` path
/// in `crate::audio::dsp`:
/// - `output = b0 * sample + state[0]`
/// - inner-loop step `i = 1`: `state[0] = state[1] + b1 * sample - a1 * output`
/// - final cell: `state[1] = b2 * sample - a2 * output`
///
/// Every `+` / `-` / `*` is its own IEEE 754 operation (no `mul_add`, no
/// reassociation), so the output is bit-identical to the generic kernel
/// for any 3-tap biquad input.
#[inline]
#[doc(hidden)]
pub fn lfilter_biquad_scalar(x: &mut [f64], b: &[f64], a: &[f64]) {
  assert_eq!(
    b.len(),
    3,
    "lfilter_biquad_scalar: b.len() must be 3 (got {})",
    b.len(),
  );
  assert_eq!(
    a.len(),
    3,
    "lfilter_biquad_scalar: a.len() must be 3 (got {})",
    a.len(),
  );

  let (b0, b1, b2) = (b[0], b[1], b[2]);
  let (a1, a2) = (a[1], a[2]);

  // Transposed direct-form II delay line, zero-initialized.
  let (mut z0, mut z1) = (0.0_f64, 0.0_f64);

  for y in x {
    let input = *y;
    let out = b0 * input + z0;
    // Left-to-right, exactly like the generic kernel, and `z0` must be
    // updated from the OLD `z1` before `z1` is overwritten:
    //   z0 = (z1 + b1*input) - a1*out
    //   z1 = b2*input - a2*out
    z0 = z1 + b1 * input - a1 * out;
    z1 = b2 * input - a2 * out;
    *y = out;
  }
}

/// Biquad specialization "NEON" arm.
///
/// **NOTE**: The IIR recurrence is purely serial — every `output[n]`
/// depends on `state0` / `state1`, which were produced from
/// `output[n-1]`. There is no within-stream parallelism a NEON wide
/// load could exploit, so this arm is the SAME body as
/// [`lfilter_biquad_scalar`]; the only difference is the
/// `#[target_feature(enable = "neon")]` attribute.
///
/// That attribute does NOT let LLVM fuse `x * y + z` into an FMA: Rust
/// never contracts separate floating-point operations (fusion only
/// happens through an explicit `mul_add`). That is precisely why this
/// arm stays bit-exact with the scalar reference. On the standard
/// aarch64 targets `neon` is already part of the default feature set,
/// so the annotation is expected to change little or nothing in the
/// generated code. Treat this arm as a codegen experiment, not a real
/// wide-SIMD kernel.
///
/// # Safety
///
/// 1. NEON must be available on the executing CPU. Caller obligation;
///    the dispatcher [`lfilter_biquad`] discharges it.
/// 2. `b.len() == 3` and `a.len() == 3` — asserted **unconditionally**
///    here (release builds included), so a violation panics rather than
///    causing undefined behavior.
///
/// # Panics
///
/// Panics unless `b.len() == 3` and `a.len() == 3`.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
pub(crate) unsafe fn lfilter_biquad_neon(x: &mut [f64], b: &[f64], a: &[f64]) {
  assert_eq!(
    b.len(),
    3,
    "lfilter_biquad_neon: b.len() must be 3 (got {})",
    b.len(),
  );
  assert_eq!(
    a.len(),
    3,
    "lfilter_biquad_neon: a.len() must be 3 (got {})",
    a.len(),
  );

  // The body performs no unsafe operations: every step is a plain f64
  // multiply, add or subtract, well-defined for any IEEE 754 input, and
  // each one is rounded individually (Rust does not contract them into
  // FMAs), so the math contract matches `lfilter_biquad_scalar` exactly.
  // The function is `unsafe fn` only to follow the convention that
  // `target_feature` kernels are `unsafe fn` (running it on a CPU without
  // NEON is the caller's responsibility). An inner `unsafe { }` block
  // here would contain no unsafe operation and would only trip rustc's
  // `unused_unsafe` lint.
  let b0 = b[0];
  let b1 = b[1];
  let b2 = b[2];
  let a1 = a[1];
  let a2 = a[2];

  let mut state0 = 0.0_f64;
  let mut state1 = 0.0_f64;

  for slot in x.iter_mut() {
    let sample = *slot;
    let output = b0 * sample + state0;
    // Same left-to-right order as the generic kernel; `state0` is
    // updated from the OLD `state1` before `state1` is overwritten.
    state0 = state1 + b1 * sample - a1 * output;
    state1 = b2 * sample - a2 * output;
    *slot = output;
  }
}

/// Biquad specialization dispatcher — filters `x` in place through a
/// `state_len == 2` biquad. It uses the `target_feature(enable = "neon")`
/// arm (`lfilter_biquad_neon`) on aarch64 CPUs that report NEON and
/// [`lfilter_biquad_scalar`] everywhere else.
///
/// # Panics
///
/// Panics (release builds included) unless `b.len() == 3` and
/// `a.len() == 3`.
///
/// # Preconditions
///
/// `b` and `a` must already be divided by `a[0]`, matching the generic
/// kernel's `b_norm` / `a_norm` step.
///
/// # Correctness class
///
/// `Exact` — both arms run the same 5 multiplies + 4 adds/subs per
/// sample in the same evaluation order, so their outputs are
/// bit-identical; only the `target_feature` annotation differs.
#[inline]
#[doc(hidden)]
pub fn lfilter_biquad(x: &mut [f64], b: &[f64], a: &[f64]) {
  assert_eq!(
    b.len(),
    3,
    "simd::audio::lfilter_biquad: b.len() must be 3 (got {})",
    b.len(),
  );
  assert_eq!(
    a.len(),
    3,
    "simd::audio::lfilter_biquad: a.len() must be 3 (got {})",
    a.len(),
  );

  #[cfg(target_arch = "aarch64")]
  if crate::simd::is_neon_available() {
    // SAFETY: the runtime probe just confirmed NEON on this CPU (Safety #1
    // of `lfilter_biquad_neon`); both coefficient-length requirements
    // (#2) were asserted unconditionally above.
    unsafe { lfilter_biquad_neon(x, b, a) };
    return;
  }

  lfilter_biquad_scalar(x, b, a);
}

#[cfg(test)]
mod tests {
  //! Scalar vs dispatcher + scalar vs NEON differential tests for both the
  //! FIR fast path and the biquad specialization, bit-exactness against a
  //! transcribed generic-kernel oracle, and edge / precondition coverage
  //! for the lfilter.

  use super::{lfilter_biquad, lfilter_biquad_scalar, lfilter_fir_b0, lfilter_fir_b0_scalar};
  use crate::simd::diff::{assert_eq_over_lane_sweep, lane_sweep_lengths};

  /// Lengths for the direct NEON-vs-scalar sweeps: empty, sub-tile,
  /// around the 2-lane `float64x2_t` tile boundary, and bulk sizes.
  #[cfg(target_arch = "aarch64")]
  const NEON_SWEEP_LENGTHS: [usize; 9] = [0, 1, 2, 3, 7, 8, 9, 64, 1024];

  /// Generate a deterministic f64 sample stream.
  fn gen_samples(n: usize) -> Vec<f64> {
    (0..n)
      .map(|i| {
        let mag = 0.1 + (i as f64) * 0.0007;
        if i.is_multiple_of(2) { mag } else { -mag }
      })
      .collect()
  }

  /// Assert `lhs` and `rhs` have equal length and are bit-identical
  /// element-wise (compared via `to_bits`), reporting the first
  /// differing index under `label`.
  #[track_caller]
  fn assert_bits_eq(label: &str, lhs: &[f64], rhs: &[f64]) {
    assert_eq!(lhs.len(), rhs.len(), "{label}: length mismatch");
    for (i, (l, r)) in lhs.iter().zip(rhs).enumerate() {
      assert_eq!(
        l.to_bits(),
        r.to_bits(),
        "{label} i={i}: not bit-identical (lhs={l}, rhs={r})"
      );
    }
  }

  // ── FIR fast-path (state_len == 0) ────────────────────────────────

  /// Test adapter — scalar FIR fast-path against a fixed `b0`.
  fn fir_scalar(src: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0_f64; src.len()];
    lfilter_fir_b0_scalar(&mut out, src, 0.5_f64);
    out
  }

  /// Test adapter — dispatcher FIR fast-path against the same `b0`.
  fn fir_dispatch(src: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0_f64; src.len()];
    lfilter_fir_b0(&mut out, src, 0.5_f64);
    out
  }

  /// `Exact` differential — scalar vs dispatcher over the lane sweep
  /// at `lanes = 2` (matches the NEON 2-lane `float64x2_t` tile width).
  #[test]
  fn fir_scalar_matches_dispatcher_exact() {
    assert_eq_over_lane_sweep(2, fir_scalar, fir_dispatch, gen_samples);
  }

  /// Lane-sweep covers FIR-relevant boundary lengths.
  #[test]
  fn fir_lane_sweep_covers_tile_boundaries() {
    let sweep = lane_sweep_lengths(2);
    assert_eq!(sweep, [0, 1, 1, 2, 3, 3, 4, 6, 7]);
  }

  /// Direct NEON-arm adapter, aarch64-only.
  #[cfg(target_arch = "aarch64")]
  fn fir_neon(src: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0_f64; src.len()];
    // SAFETY: caller guards on `is_neon_available()`; size is `n`
    // exactly; kernel writes every slot.
    unsafe { super::lfilter_fir_b0_neon(&mut out, src, 0.5_f64) };
    out
  }

  /// NEON-vs-scalar bit-identical assertion via direct kernel call.
  #[cfg(target_arch = "aarch64")]
  #[test]
  fn fir_neon_matches_scalar_bit_identical() {
    if !crate::simd::is_neon_available() {
      return;
    }
    for &n in &NEON_SWEEP_LENGTHS {
      let src = gen_samples(n);
      assert_bits_eq(&format!("fir scalar vs NEON n={n}"), &fir_scalar(&src), &fir_neon(&src));
    }
  }

  // ── Biquad specialization (state_len == 2) ────────────────────────

  /// BS.1770 K-weighting high-shelf coefficients at 48 kHz (b, a).
  /// Pre-normalized by `a[0]`. Matches the runtime values produced by
  /// `bs1770_biquad_coefficients(4.0, 1/sqrt(2), 1500.0, 48000.0,
  /// HighShelf)` in [`crate::audio::dsp`].
  fn k_weight_hs_coeffs_48k() -> ([f64; 3], [f64; 3]) {
    use core::f64::consts::PI;
    let gain_db = 4.0_f64;
    let q = 1.0 / 2.0_f64.sqrt();
    let fc = 1500.0_f64;
    let rate = 48000.0_f64;
    let amplitude = 10.0_f64.powf(gain_db / 40.0);
    let omega = 2.0 * PI * (fc / rate);
    let alpha = omega.sin() / (2.0 * q);
    let cos_omega = omega.cos();
    let sqrt_a = amplitude.sqrt();
    let b0 = amplitude * ((amplitude + 1.0) + (amplitude - 1.0) * cos_omega + 2.0 * sqrt_a * alpha);
    let b1 = -2.0 * amplitude * ((amplitude - 1.0) + (amplitude + 1.0) * cos_omega);
    let b2 = amplitude * ((amplitude + 1.0) + (amplitude - 1.0) * cos_omega - 2.0 * sqrt_a * alpha);
    let a0 = (amplitude + 1.0) - (amplitude - 1.0) * cos_omega + 2.0 * sqrt_a * alpha;
    let a1 = 2.0 * ((amplitude - 1.0) - (amplitude + 1.0) * cos_omega);
    let a2 = (amplitude + 1.0) - (amplitude - 1.0) * cos_omega - 2.0 * sqrt_a * alpha;
    ([b0 / a0, b1 / a0, b2 / a0], [1.0, a1 / a0, a2 / a0])
  }

  /// BS.1770 K-weighting high-pass coefficients at 48 kHz.
  fn k_weight_hp_coeffs_48k() -> ([f64; 3], [f64; 3]) {
    use core::f64::consts::PI;
    let q = 0.5_f64;
    let fc = 38.0_f64;
    let rate = 48000.0_f64;
    let omega = 2.0 * PI * (fc / rate);
    let alpha = omega.sin() / (2.0 * q);
    let cos_omega = omega.cos();
    let b0 = (1.0 + cos_omega) / 2.0;
    let b1 = -(1.0 + cos_omega);
    let b2 = (1.0 + cos_omega) / 2.0;
    let a0 = 1.0 + alpha;
    let a1 = -2.0 * cos_omega;
    let a2 = 1.0 - alpha;
    ([b0 / a0, b1 / a0, b2 / a0], [1.0, a1 / a0, a2 / a0])
  }

  /// Reference oracle — the generic kernel from `audio::dsp`, transcribed
  /// here rather than called, so this module does not depend back on
  /// `audio::dsp` (which itself uses `simd::audio`). The body must match
  /// `lfilter_f64_in_place` operation-for-operation for the
  /// pre-normalized `state_len == 2` case; keep the two in sync.
  fn lfilter_generic_in_place_reference(b: &[f64], a: &[f64], x: &mut [f64]) {
    let b0 = b[0];
    let mut state = vec![0.0_f64; a.len().max(b.len()) - 1];
    for slot in x.iter_mut() {
      let sample = *slot;
      let output = b0 * sample + state[0];
      let state_len = state.len();
      for i in 1..state_len {
        let feedforward = b.get(i).copied().unwrap_or(0.0) * sample;
        let feedback = a.get(i).copied().unwrap_or(0.0) * output;
        state[i - 1] = state[i] + feedforward - feedback;
      }
      let feedforward_last = b.get(state_len).copied().unwrap_or(0.0) * sample;
      let feedback_last = a.get(state_len).copied().unwrap_or(0.0) * output;
      state[state_len - 1] = feedforward_last - feedback_last;
      *slot = output;
    }
  }

  /// Run the biquad dispatcher over a fresh copy of `src`.
  fn biquad_dispatch(src: &[f64], b: &[f64], a: &[f64]) -> Vec<f64> {
    let mut y = src.to_vec();
    lfilter_biquad(&mut y, b, a);
    y
  }

  /// Run the transcribed generic oracle over a fresh copy of `src`.
  fn biquad_generic(src: &[f64], b: &[f64], a: &[f64]) -> Vec<f64> {
    let mut y = src.to_vec();
    lfilter_generic_in_place_reference(b, a, &mut y);
    y
  }

  /// CRITICAL — bit-exactness against the generic `lfilter_f64_in_place`
  /// kernel for the actual K-weighting workload. A 1-ulp drift at any
  /// sample would bias the LUFS measurement against the EBU R128
  /// reference. Tests the high-shelf coefficients at 48 kHz over 1 s
  /// of synthetic audio.
  #[test]
  fn biquad_bit_exact_vs_generic_hs_48k() {
    let (b, a) = k_weight_hs_coeffs_48k();
    let src = gen_samples(48_000);
    assert_bits_eq(
      "biquad HS dispatcher vs generic",
      &biquad_dispatch(&src, &b, &a),
      &biquad_generic(&src, &b, &a),
    );
  }

  /// CRITICAL — bit-exactness against the generic kernel for the
  /// high-pass coefficients at 48 kHz. Same shape as the high-shelf
  /// test.
  #[test]
  fn biquad_bit_exact_vs_generic_hp_48k() {
    let (b, a) = k_weight_hp_coeffs_48k();
    let src = gen_samples(48_000);
    assert_bits_eq(
      "biquad HP dispatcher vs generic",
      &biquad_dispatch(&src, &b, &a),
      &biquad_generic(&src, &b, &a),
    );
  }

  /// CRITICAL — bit-exactness when CHAINED (HS → HP), which is the
  /// actual `k_weight_channel` shape. Catches any state-cell aliasing
  /// or input-aliasing bugs that show up only when two biquads operate
  /// on the same buffer in sequence, so both passes deliberately reuse
  /// one buffer per side.
  #[test]
  fn biquad_bit_exact_vs_generic_chained_k_weight_48k() {
    let (hs_b, hs_a) = k_weight_hs_coeffs_48k();
    let (hp_b, hp_a) = k_weight_hp_coeffs_48k();
    let src = gen_samples(48_000);

    let mut y_dispatch = src.clone();
    lfilter_biquad(&mut y_dispatch, &hs_b, &hs_a);
    lfilter_biquad(&mut y_dispatch, &hp_b, &hp_a);

    let mut y_generic = src;
    lfilter_generic_in_place_reference(&hs_b, &hs_a, &mut y_generic);
    lfilter_generic_in_place_reference(&hp_b, &hp_a, &mut y_generic);

    assert_bits_eq("chained K-weight dispatcher vs generic", &y_dispatch, &y_generic);
  }

  /// Scalar arm matches dispatcher arm (single-pass differential).
  #[test]
  fn biquad_scalar_matches_dispatcher() {
    let (b, a) = k_weight_hs_coeffs_48k();
    let src = gen_samples(12_345);

    let mut y_scalar = src.clone();
    lfilter_biquad_scalar(&mut y_scalar, &b, &a);

    assert_bits_eq(
      "biquad scalar vs dispatcher",
      &y_scalar,
      &biquad_dispatch(&src, &b, &a),
    );
  }

  /// NEON-vs-scalar bit-identical assertion via direct biquad kernel
  /// call, for both K-weighting stages. The dispatcher tests only reach
  /// the NEON arm implicitly; this pins it independently.
  #[cfg(target_arch = "aarch64")]
  #[test]
  fn biquad_neon_matches_scalar_bit_identical() {
    if !crate::simd::is_neon_available() {
      return;
    }
    for (b, a) in [k_weight_hs_coeffs_48k(), k_weight_hp_coeffs_48k()] {
      for &n in &NEON_SWEEP_LENGTHS {
        let src = gen_samples(n);

        let mut y_scalar = src.clone();
        lfilter_biquad_scalar(&mut y_scalar, &b, &a);

        let mut y_neon = src;
        // SAFETY: NEON availability was checked above; `b` and `a` are
        // `[f64; 3]`, satisfying the kernel's length preconditions.
        unsafe { super::lfilter_biquad_neon(&mut y_neon, &b, &a) };

        assert_bits_eq(&format!("biquad scalar vs NEON n={n}"), &y_scalar, &y_neon);
      }
    }
  }

  /// Empty input is a no-op.
  #[test]
  fn biquad_empty_is_noop() {
    let (b, a) = k_weight_hs_coeffs_48k();
    let mut x: Vec<f64> = vec![];
    lfilter_biquad(&mut x, &b, &a);
    assert!(x.is_empty());
  }

  // ── Precondition guards (asserted in release builds too) ──────────

  #[test]
  #[should_panic(expected = "lfilter_biquad_scalar: b.len() must be 3")]
  fn biquad_scalar_panics_on_wrong_b_len() {
    let mut x = [0.0_f64; 4];
    let b = [1.0_f64; 2];
    let a = [1.0_f64; 3];
    lfilter_biquad_scalar(&mut x, &b, &a);
  }

  #[test]
  #[should_panic(expected = "simd::audio::lfilter_biquad: b.len() must be 3")]
  fn biquad_dispatch_panics_on_wrong_b_len() {
    let mut x = [0.0_f64; 4];
    let b = [1.0_f64; 2];
    let a = [1.0_f64; 3];
    lfilter_biquad(&mut x, &b, &a);
  }

  #[test]
  #[should_panic(expected = "simd::audio::lfilter_biquad: a.len() must be 3")]
  fn biquad_dispatch_panics_on_wrong_a_len() {
    let mut x = [0.0_f64; 4];
    let b = [1.0_f64; 3];
    let a = [1.0_f64; 4];
    lfilter_biquad(&mut x, &b, &a);
  }

  #[test]
  #[should_panic(expected = "lfilter_fir_b0_scalar: out.len()")]
  fn fir_scalar_panics_on_size_mismatch() {
    let mut out = [0.0_f64; 4];
    let src = [0.0_f64; 6];
    lfilter_fir_b0_scalar(&mut out, &src, 0.5);
  }
}