mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
//! `rotate_buf` pixel permutation.
//!
//! Tracking: [#150](https://github.com/Findit-AI/mlxrs/issues/150).
//! The NEON kernel ships unconditionally: auto-vectorization is
//! compiler-version-dependent and could silently de-vectorize on a
//! toolchain upgrade, so the hand-rolled arm pins the SIMD contract.
//!
//! # The defect class
//!
//! The original `crate::vlm::image::rotate_buf` is a per-pixel
//! `copy_from_slice` loop over the source:
//!
//! ```rust,ignore
//! for y in 0..h_usize {
//!   for x in 0..w_usize {
//!     let (nx, ny) = match rotation { ... };
//!     let src_off = (y * w_usize + x) * channels;
//!     let dst_off = (ny * out_w_usize + nx) * channels;
//!     dst[dst_off..dst_off + channels].copy_from_slice(&src[src_off..src_off + channels]);
//!   }
//! }
//! ```
//!
//! Per pixel: a `copy_from_slice` of `channels` elements (1, 2, 3, or 4,
//! whether the element type is u8, u16, or f32). LLVM auto-vectorizes the inner
//! per-channel copy as a single `LDR Wn`/`STR Wn` for the
//! `channels=4` case, but the outer iteration is **scatter-dominated**
//! — `dst_off` is a row-stride permutation of `src_off`, so successive
//! source pixels write to different output rows. NEON has no scatter,
//! so the SIMD win is bounded by the per-pixel widen (channels=4 case)
//! and the outer-loop unrolling.
//!
//! # The fix — u8 channels=4 specialised NEON kernel
//!
//! The hot path in mlxrs (and the only path where NEON gives a
//! meaningful speedup over LLVM's auto-vec) is **u8 + channels=4**
//! (Rgba8 source from `image::DynamicImage`). For this case the kernel:
//!
//! 1. Reads 4 source pixels (16 bytes) per tile via `vld1q_u8`.
//! 2. Computes the 4 destination offsets per the rotation kind.
//! 3. Writes each pixel as a single 32-bit store
//!    (`core::ptr::write_unaligned::<u32>`) at the destination offset.
//!
//! The destination writes are inherently scattered (each pixel goes to
//! a different row), so the NEON load is the only contiguous step. The
//! kernel matches the auto-vectorized scalar's per-pixel-copy shape but
//! pins the load width at 16 bytes — a guaranteed contract independent
//! of LLVM heuristics.
//!
//! Every other combination takes a scalar path: u8 with 1, 2, or 3
//! channels falls back to the scalar arm inside the dispatcher, while
//! u16 / f32 inputs stay on the T-generic loop in
//! `crate::vlm::image::rotate_buf` and never reach this module.
//! Specialising for `(u8, 4)` covers the dominant Rgba8 image-decode +
//! EXIF-rotate path; the other arms are infrequent enough that the
//! scalar loop's auto-vectorized shape is the right contract.
//!
//! # Correctness class — `Exact`
//!
//! Pure data movement — every output byte equals exactly one input
//! byte. Scalar and NEON arms produce bit-identical output. Differential
//! tests use [`crate::simd::diff::assert_eq_over_lane_sweep`].
//!
//! # Output API
//!
//! The dispatcher writes into a caller-allocated `&mut [u8]` already
//! sized to `src.len()` (per the caller's pre-existing
//! `try_reserve_exact` + `resize` discipline in `rotate_buf`). Unlike
//! the RGB/BGR widen kernels, no `MaybeUninit` is needed — every
//! destination byte is written by exactly one source-pixel store.
//!
//! # Rotation arms (mirrors `RotateKind`)
//!
//! - `Rotate90`        : `(x, y) -> (h - 1 - y, x)` ; out dims `(h, w)`
//! - `Rotate270`       : `(x, y) -> (y, w - 1 - x)` ; out dims `(h, w)`
//! - `Rotate90FlipH`   : `(x, y) -> (y, x)` ; out dims `(h, w)`
//! - `Rotate270FlipH`  : `(x, y) -> (h - 1 - y, w - 1 - x)` ; out dims `(h, w)`
//!
//! All four output dimensions are `(h, w)` (transpose); the NEON arm
//! is parameterised by the rotation kind via a `RotateKind` enum
//! mirror to keep dispatch symmetric with the call site.

use derive_more::{Display, IsVariant};

/// Pixel-permutation rotation variants. Mirrors
/// `crate::vlm::image::RotateKind` (kept local because that enum is
/// crate-private; this is the dispatcher's parameter type).
///
/// Every variant transposes the image: a `w x h` source always yields an
/// `h x w` output. Each variant's doc gives the `(x, y) -> (nx, ny)`
/// source-to-destination pixel map used by the kernels in this module
/// (`w` / `h` are the source width / height).
///
/// `Display` renders the lowercase tag returned by [`RotateKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, IsVariant)]
#[display("{}", self.as_str())]
pub enum RotateKind {
  /// Clockwise 90° (transpose, then mirror left-right):
  /// `(x, y) -> (h - 1 - y, x)`.
  Rotate90,
  /// Clockwise 270°, i.e. counter-clockwise 90° (transpose, then mirror
  /// top-bottom): `(x, y) -> (y, w - 1 - x)`.
  Rotate270,
  /// `Rotate90` then horizontal flip — collapses to the pure transpose
  /// `(x, y) -> (y, x)`.
  Rotate90FlipH,
  /// `Rotate270` then horizontal flip — collapses to the anti-transpose
  /// `(x, y) -> (h - 1 - y, w - 1 - x)`.
  Rotate270FlipH,
}

impl RotateKind {
  /// Lowercase string tag for this variant (e.g. `"rotate90fliph"`).
  /// The derived `Display` impl formats through this method.
  pub const fn as_str(&self) -> &'static str {
    match self {
      Self::Rotate90 => "rotate90",
      Self::Rotate270 => "rotate270",
      Self::Rotate90FlipH => "rotate90fliph",
      Self::Rotate270FlipH => "rotate270fliph",
    }
  }
}

/// Scalar reference rotation: moves every `channels`-byte pixel of `src`
/// to its rotated position in `dst`.
///
/// Bit-exact with the inner two-level loop of the original `rotate_buf`.
/// This arm is **u8-only** on purpose: the dominant call site is u8
/// (Rgba8 / Rgb8 image decode), and non-u8 element types stay on the
/// T-generic inner loop of `crate::vlm::image::rotate_buf` instead of
/// coming through here.
///
/// `channels` is the byte count of one pixel (1 / 2 / 3 / 4 in practice).
/// Every rotation transposes, so the output is `src_h` pixels wide and
/// `src_w` pixels tall.
///
/// # Preconditions
///
/// - `dst.len() == src.len() == src_w * src_h * channels`.
/// - `src_w * src_h * channels` must not overflow `usize`.
///
/// # Panics
///
/// Both preconditions are checked **unconditionally** (release too). The
/// product is formed with `checked_mul`, and overflow panics explicitly
/// instead of wrapping. A length mismatch on `src` or `dst` fails an
/// `assert_eq!`.
#[inline]
#[doc(hidden)]
pub fn rotate_buf_u8_scalar(
  dst: &mut [u8],
  src: &[u8],
  src_w: usize,
  src_h: usize,
  channels: usize,
  rotation: RotateKind,
) {
  let Some(elements) = src_w.checked_mul(src_h).and_then(|wh| wh.checked_mul(channels)) else {
    panic!("rotate_buf_u8_scalar: dimensions {src_w}x{src_h}x{channels} overflow usize");
  };
  assert_eq!(
    src.len(),
    elements,
    "rotate_buf_u8_scalar: src.len() ({}) must equal src_w * src_h * channels ({src_w} * {src_h} * {channels} = {elements})",
    src.len(),
  );
  assert_eq!(
    dst.len(),
    elements,
    "rotate_buf_u8_scalar: dst.len() ({}) must equal src.len() ({elements})",
    dst.len(),
  );

  // Destination `(column, row)` of the source pixel at `(x, y)`.
  let target = |x: usize, y: usize| -> (usize, usize) {
    match rotation {
      RotateKind::Rotate90 => (src_h - 1 - y, x),
      RotateKind::Rotate270 => (y, src_w - 1 - x),
      RotateKind::Rotate90FlipH => (y, x),
      RotateKind::Rotate270FlipH => (src_h - 1 - y, src_w - 1 - x),
    }
  };

  // The output row length is the source height (transpose).
  let dst_row_len = src_h;
  for y in 0..src_h {
    let row_start = y * src_w;
    for x in 0..src_w {
      let (col, row) = target(x, y);
      let from = (row_start + x) * channels;
      let to = (row * dst_row_len + col) * channels;
      dst[to..to + channels].copy_from_slice(&src[from..from + channels]);
    }
  }
}

/// NEON 4-pixel-tile u8 rotation for `channels = 4` (RGBA). Reads 16
/// bytes (4 pixels) per tile via `vld1q_u8`, then issues four scattered
/// 32-bit destination stores via `core::ptr::write_unaligned::<u32>`.
/// The `src_w % 4` pixels left at the end of each row go through a
/// per-pixel u32 copy. Output is bit-identical to
/// [`rotate_buf_u8_scalar`] with `channels = 4`.
///
/// # Safety
///
/// 1. NEON must be available on the executing CPU. Caller obligation;
///    discharged by [`rotate_buf_u8`].
/// 2. Pixels must be 4 bytes (`channels == 4`). The kernel takes no
///    `channels` argument and hard-codes the 4-byte pixel / u32 store.
/// 3. `src.len() == dst.len() == src_w * src_h * 4` — asserted
///    **unconditionally** here. A violation therefore panics before the
///    pointer loop is reached.
///
/// # Panics
///
/// Panics if `src_w * src_h * 4` overflows `usize`, or if either slice
/// length differs from that product.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
unsafe fn rotate_buf_u8_channels4_neon(
  dst: &mut [u8],
  src: &[u8],
  src_w: usize,
  src_h: usize,
  rotation: RotateKind,
) {
  let channels = 4usize;
  let elements = src_w
    .checked_mul(src_h)
    .and_then(|wh| wh.checked_mul(channels))
    .unwrap_or_else(|| {
      panic!("rotate_buf_u8_channels4_neon: dimensions {src_w}x{src_h}x4 overflow usize")
    });
  assert_eq!(
    src.len(),
    elements,
    "rotate_buf_u8_channels4_neon: src.len() ({}) must equal src_w * src_h * 4 ({} * {} * 4 = {})",
    src.len(),
    src_w,
    src_h,
    elements,
  );
  assert_eq!(
    dst.len(),
    elements,
    "rotate_buf_u8_channels4_neon: dst.len() ({}) must equal src.len() ({})",
    dst.len(),
    elements,
  );

  // Every rotation transposes: out_w == src_h, out_h == src_w.
  let out_w = src_h;
  // Pixels `[0, tiled_w)` of each row go through the 16-byte tile loop,
  // and `[tiled_w, src_w)` through the scalar tail. This is
  // loop-invariant, so it is computed once rather than once per row.
  let tiled_w = src_w - (src_w % 4);

  // SAFETY:
  // - Tile loads: a tile at `(x, y)` reads 16 bytes at
  //   `src_off = (y * src_w + x) * 4` with `x + 4 <= tiled_w <= src_w`,
  //   so `src_off + 16 <= (y + 1) * src_w * 4 <= src_h * src_w * 4
  //   == src.len()`. A tail pixel reads 4 bytes at `src_off` with
  //   `x < src_w`, which has the same bound.
  // - Destination stores: every rotation arm yields `nx < src_h`
  //   (== out_w) and `ny < src_w`. Hence `dst_off + 4
  //   = (ny * out_w + nx + 1) * 4 <= ((src_w - 1) * src_h + src_h) * 4
  //   == dst.len()`.
  // - No offset product can overflow: each is bounded by `elements`,
  //   which was computed with `checked_mul` above.
  // - Scratch reads at `lane * 4` with `lane < 4` stay inside the
  //   16-byte array.
  // - `dst` is `&mut [u8]`, so any bit pattern is a valid value. Its
  //   base pointer is only 1-aligned, hence `read_unaligned` /
  //   `write_unaligned` for every u32 access.
  // - NEON availability is the caller's obligation (precondition #1).
  unsafe {
    let src_base = src.as_ptr();
    let dst_base = dst.as_mut_ptr();

    for y in 0..src_h {
      let mut x = 0usize;
      while x + 4 <= tiled_w {
        let src_off = (y * src_w + x) * channels;
        // Load 16 source bytes (4 pixels × 4 channels).
        let tile = core::arch::aarch64::vld1q_u8(src_base.add(src_off));

        // Spill the tile to a 16-byte stack scratch so each 4-byte pixel
        // can be re-read as a u32 for its scattered destination store.
        // This replaces four `vgetq_lane_u32` extracts.
        let mut scratch = [0u8; 16];
        core::arch::aarch64::vst1q_u8(scratch.as_mut_ptr(), tile);

        for lane in 0..4 {
          let xx = x + lane;
          let (nx, ny) = match rotation {
            RotateKind::Rotate90 => (src_h - 1 - y, xx),
            RotateKind::Rotate270 => (y, src_w - 1 - xx),
            RotateKind::Rotate90FlipH => (y, xx),
            RotateKind::Rotate270FlipH => (src_h - 1 - y, src_w - 1 - xx),
          };
          let dst_off = (ny * out_w + nx) * channels;
          // Both accesses must be unaligned. `scratch` is a `[u8; 16]`
          // (alignment 1). `dst_off` is a multiple of 4, but the base of
          // a `&mut [u8]` carries no alignment guarantee, so
          // `dst_base + dst_off` need not be 4-byte aligned.
          let pixel: u32 = core::ptr::read_unaligned(scratch.as_ptr().add(lane * 4).cast::<u32>());
          core::ptr::write_unaligned(dst_base.add(dst_off).cast::<u32>(), pixel);
        }
        x += 4;
      }
      // Tail (`src_w % 4` < 4 pixels) — scalar per-pixel copy.
      while x < src_w {
        let (nx, ny) = match rotation {
          RotateKind::Rotate90 => (src_h - 1 - y, x),
          RotateKind::Rotate270 => (y, src_w - 1 - x),
          RotateKind::Rotate90FlipH => (y, x),
          RotateKind::Rotate270FlipH => (src_h - 1 - y, src_w - 1 - x),
        };
        let src_off = (y * src_w + x) * channels;
        let dst_off = (ny * out_w + nx) * channels;
        let pixel: u32 = core::ptr::read_unaligned(src_base.add(src_off).cast::<u32>());
        core::ptr::write_unaligned(dst_base.add(dst_off).cast::<u32>(), pixel);
        x += 1;
      }
    }
  }
}

/// Public dispatcher: writes a rotated copy of the u8 pixel buffer `src`
/// into the separate, caller-allocated `dst` (not in place).
///
/// On `aarch64`, `channels == 4` is routed to the NEON kernel when
/// `crate::simd::is_neon_available()` reports NEON. Every other case runs
/// [`rotate_buf_u8_scalar`], including other channel counts, CPUs without
/// NEON, and all non-`aarch64` targets.
///
/// Used by `crate::vlm::image::rotate_buf` for the u8 element-type
/// arms (Luma8 / LumaA8 / Rgb8 / Rgba8).
///
/// # Preconditions
///
/// - `src.len() == dst.len() == src_w * src_h * channels` — asserted
///   unconditionally.
/// - `src_w * src_h * channels` does not overflow `usize` — checked
///   via `checked_mul` BEFORE the size-equality assertions. A wrapped
///   product therefore cannot slip past the size checks and let the
///   unsafe NEON kernel compute offsets from the unwrapped loop dims.
///   The wired `crate::vlm::image::rotate_buf` caller already
///   pre-checks, but this public entry is reachable directly.
///
/// # Panics
///
/// Panics explicitly on `src_w * src_h * channels` overflow instead of
/// wrapping silently. Wrapping could let an under-sized buffer satisfy
/// the size-equality assertion and reach the unsafe kernel. Also panics
/// when either slice length differs from that product.
///
/// # Correctness class
///
/// `Exact` — bit-identical output between scalar and NEON.
#[inline]
#[doc(hidden)]
pub fn rotate_buf_u8(
  dst: &mut [u8],
  src: &[u8],
  src_w: usize,
  src_h: usize,
  channels: usize,
  rotation: RotateKind,
) {
  // Overflow guard first. In release mode a wrapping product could be
  // small enough for an under-sized `src` / `dst` to pass the length
  // asserts below. The NEON kernel would then issue out-of-bounds
  // `vld1q_u8` / `write_unaligned` (UB).
  let Some(elements) = src_w.checked_mul(src_h).and_then(|wh| wh.checked_mul(channels)) else {
    panic!("simd::vlm::rotate_buf_u8: dimensions {src_w}x{src_h}x{channels} overflow usize");
  };
  assert_eq!(
    src.len(),
    elements,
    "simd::vlm::rotate_buf_u8: src.len() ({}) must equal src_w * src_h * channels ({src_w} * {src_h} * {channels} = {elements})",
    src.len(),
  );
  assert_eq!(
    dst.len(),
    elements,
    "simd::vlm::rotate_buf_u8: dst.len() ({}) must equal src.len() ({elements})",
    dst.len(),
  );

  #[cfg(target_arch = "aarch64")]
  {
    let neon_rgba = channels == 4 && crate::simd::is_neon_available();
    if neon_rgba {
      // SAFETY: NEON reported available; `channels == 4`; both slice
      // lengths asserted equal to `elements`, which was derived via
      // `checked_mul`, so per-pixel offsets stay within the buffers.
      unsafe { rotate_buf_u8_channels4_neon(dst, src, src_w, src_h, rotation) };
      return;
    }
  }
  rotate_buf_u8_scalar(dst, src, src_w, src_h, channels, rotation);
}

#[cfg(test)]
mod tests {
  //! Scalar vs dispatcher Exact differential tests + edge coverage for the rotate.

  use super::{RotateKind, rotate_buf_u8, rotate_buf_u8_scalar};

  /// Build a deterministic source buffer of `w * h * channels` bytes.
  fn src(w: usize, h: usize, channels: usize) -> Vec<u8> {
    (0..(w * h * channels)).map(|i| (i % 251) as u8).collect()
  }

  fn rotate_via(
    dispatch: bool,
    w: usize,
    h: usize,
    channels: usize,
    rotation: RotateKind,
  ) -> Vec<u8> {
    let s = src(w, h, channels);
    let mut d = vec![0u8; s.len()];
    if dispatch {
      rotate_buf_u8(&mut d, &s, w, h, channels, rotation);
    } else {
      rotate_buf_u8_scalar(&mut d, &s, w, h, channels, rotation);
    }
    d
  }

  #[test]
  fn rotate_buf_u8_channels4_scalar_matches_dispatcher_exact() {
    // Sweep over interesting widths (boundaries around multiples of 4).
    for &w in &[1usize, 4, 5, 7, 8, 16, 17, 33] {
      for &h in &[1usize, 2, 4, 8, 17] {
        for &rotation in &[
          RotateKind::Rotate90,
          RotateKind::Rotate270,
          RotateKind::Rotate90FlipH,
          RotateKind::Rotate270FlipH,
        ] {
          let s = rotate_via(false, w, h, 4, rotation);
          let d = rotate_via(true, w, h, 4, rotation);
          assert_eq!(
            s, d,
            "Exact mismatch (w={w}, h={h}, channels=4, rotation={rotation:?})"
          );
        }
      }
    }
  }

  #[test]
  fn rotate_buf_u8_channels3_scalar_matches_dispatcher_exact() {
    // channels=3 routes to the scalar arm; verify the dispatcher
    // produces the same output as the scalar reference.
    for &w in &[1usize, 4, 17] {
      for &h in &[1usize, 4, 8] {
        for &rotation in &[
          RotateKind::Rotate90,
          RotateKind::Rotate270,
          RotateKind::Rotate90FlipH,
          RotateKind::Rotate270FlipH,
        ] {
          let s = rotate_via(false, w, h, 3, rotation);
          let d = rotate_via(true, w, h, 3, rotation);
          assert_eq!(
            s, d,
            "Exact mismatch (w={w}, h={h}, channels=3, rotation={rotation:?})"
          );
        }
      }
    }
  }

  #[test]
  fn rotate_buf_u8_rotate90_pin() {
    // 2x2 RGBA, Rotate90: per scalar arm `(nx, ny) = (h-1-y, x)`,
    // dst_off = (ny * out_w + nx) * 4 where out_w = src_h = 2.
    let w = 2;
    let h = 2;
    let channels = 4;
    let s: Vec<u8> = (0..16).map(|i| i as u8).collect();
    let mut d = vec![0u8; 16];
    rotate_buf_u8(&mut d, &s, w, h, channels, RotateKind::Rotate90);
    // src (0, 0) = [0..4]; nx=h-1-0=1, ny=0; dst_off = (0*2 + 1)*4 = 4
    assert_eq!(&d[4..8], &s[0..4], "Rotate90: src(0,0) → dst[4..8]");
    // src (1, 0) = [4..8]; nx=h-1-0=1, ny=1; dst_off = (1*2 + 1)*4 = 12
    assert_eq!(&d[12..16], &s[4..8], "Rotate90: src(1,0) → dst[12..16]");
    // src (0, 1) = [8..12]; nx=h-1-1=0, ny=0; dst_off = (0*2 + 0)*4 = 0
    assert_eq!(&d[0..4], &s[8..12], "Rotate90: src(0,1) → dst[0..4]");
    // src (1, 1) = [12..16]; nx=h-1-1=0, ny=1; dst_off = (1*2 + 0)*4 = 8
    assert_eq!(&d[8..12], &s[12..16], "Rotate90: src(1,1) → dst[8..12]");
  }

  #[test]
  fn rotate_buf_u8_double_rotate_round_trip() {
    // Rotate90 and Rotate270 are inverses: a clockwise 90° turn followed
    // by a clockwise 270° turn is a net 360°. Applying Rotate270 to the
    // Rotate90 output (with the swapped dims) must therefore recover the
    // input exactly. A non-square source makes a w/h mix-up visible.
    let w = 4;
    let h = 3;
    let channels = 4;
    let s: Vec<u8> = (0..(w * h * channels)).map(|i| (i % 251) as u8).collect();

    // Rotate90: src (w=4, h=3) → out (w=h=3, h=w=4)
    let mut once = vec![0u8; s.len()];
    rotate_buf_u8(&mut once, &s, w, h, channels, RotateKind::Rotate90);
    // Rotate270 on the once-rotated buffer (w=3, h=4) → recovers (w=4, h=3)
    let mut twice = vec![0u8; s.len()];
    rotate_buf_u8(&mut twice, &once, h, w, channels, RotateKind::Rotate270);
    assert_eq!(twice, s, "Rotate90 ∘ Rotate270 should be identity");
  }

  #[test]
  #[should_panic(
    expected = "simd::vlm::rotate_buf_u8: src.len() (3) must equal src_w * src_h * channels"
  )]
  fn rotate_buf_u8_panics_on_size_mismatch() {
    let s = vec![0u8; 3]; // WRONG: should be 2*2*4 = 16
    let mut d = vec![0u8; 16];
    rotate_buf_u8(&mut d, &s, 2, 2, 4, RotateKind::Rotate90);
  }

  /// Wrap-arith defence. The wired `crate::vlm::image::rotate_buf` caller
  /// pre-checks `src_w * src_h * channels` via `checked_mul`, but the
  /// public dispatcher entry is also reachable directly (e.g. via a
  /// `pub use` from a future caller, or via the in-crate `unsafe`
  /// neighbours that share the symbol). In release mode a wrapping
  /// multiply could let a small `elements` value pass the size-equality
  /// assertion and reach the unsafe NEON kernel. There the per-pixel
  /// offset math, computed from the unwrapped loop dims, would produce
  /// out-of-bounds offsets and trigger UB. `checked_mul` must therefore
  /// land BEFORE the asserts.
  #[test]
  #[should_panic(expected = "overflow usize")]
  fn rotate_buf_u8_panics_on_dimension_overflow() {
    // src_w * src_h would already saturate (usize::MAX/2 + 1) * 2 → wrap.
    // We give a small `src` + `dst` so allocation succeeds and the
    // dimension overflow is the only failure mode.
    let s = vec![0u8; 16];
    let mut d = vec![0u8; 16];
    rotate_buf_u8(&mut d, &s, usize::MAX / 2 + 1, 2, 4, RotateKind::Rotate90);
  }

  #[test]
  fn rotate_buf_u8_rotate90_flip_h_collapses() {
    // Rotate90FlipH: per scalar arm `(nx, ny) = (y, x)` — pure
    // transpose. dst_off = (ny * out_w + nx) * 4 with out_w = src_h = 2.
    let w = 2;
    let h = 2;
    let channels = 4;
    let s: Vec<u8> = (0..16).map(|i| i as u8).collect();
    let mut d = vec![0u8; 16];
    rotate_buf_u8(&mut d, &s, w, h, channels, RotateKind::Rotate90FlipH);
    // src (0, 0) = [0..4]; nx=0, ny=0; dst_off = 0. dst[0..4] = src[0..4]
    assert_eq!(&d[0..4], &s[0..4]);
    // src (1, 0) = [4..8]; nx=0, ny=1; dst_off = (1*2 + 0)*4 = 8.
    assert_eq!(&d[8..12], &s[4..8]);
    // src (0, 1) = [8..12]; nx=1, ny=0; dst_off = (0*2 + 1)*4 = 4.
    assert_eq!(&d[4..8], &s[8..12]);
    // src (1, 1) = [12..16]; nx=1, ny=1; dst_off = (1*2 + 1)*4 = 12.
    assert_eq!(&d[12..16], &s[12..16]);
  }

  /// Non-square pins for all four rotations through the dispatcher. The
  /// 2x2 pins above cannot tell `src_w` from `src_h` (or `out_w` from
  /// `out_h`), so a swapped-dimension bug would pass them; a 3x2 source
  /// exposes it. Expected outputs are listed as source-pixel indices in
  /// destination order, derived by hand from the scalar coordinate maps.
  /// Run with `channels = 1` (scalar arm) and `channels = 4` (the NEON
  /// arm on aarch64; width 3 exercises its per-pixel tail only).
  #[test]
  fn rotate_buf_u8_non_square_pins_all_rotations() {
    // Source pixels (row-major, w = 3, h = 2):
    //   0 1 2
    //   3 4 5
    // Output is transposed: out_w = 2, out_h = 3.
    let cases = [
      (RotateKind::Rotate90, [3usize, 0, 4, 1, 5, 2]),
      (RotateKind::Rotate270, [2, 5, 1, 4, 0, 3]),
      (RotateKind::Rotate90FlipH, [0, 3, 1, 4, 2, 5]),
      (RotateKind::Rotate270FlipH, [5, 2, 4, 1, 3, 0]),
    ];
    let (w, h) = (3usize, 2usize);
    for channels in [1usize, 4] {
      let s = src(w, h, channels);
      for (rotation, order) in cases {
        let mut d = vec![0u8; s.len()];
        rotate_buf_u8(&mut d, &s, w, h, channels, rotation);
        let expected: Vec<u8> = order
          .iter()
          .flat_map(|&p| s[p * channels..(p + 1) * channels].iter().copied())
          .collect();
        assert_eq!(
          d, expected,
          "non-square pin mismatch (channels={channels}, rotation={rotation:?})"
        );
      }
    }
  }

  /// [`RotateKind::as_str`] is only reached via the `Display` derive
  /// (`#[display("{}", self.as_str())]`); the differential tests format
  /// with `{:?}` (Debug), so the four `as_str` arms are otherwise never
  /// executed. Pin every variant's lowercase tag here and assert the
  /// `Display` impl routes through `as_str` (so `to_string()` matches).
  #[test]
  fn rotate_kind_as_str_and_display_all_variants() {
    let cases = [
      (RotateKind::Rotate90, "rotate90"),
      (RotateKind::Rotate270, "rotate270"),
      (RotateKind::Rotate90FlipH, "rotate90fliph"),
      (RotateKind::Rotate270FlipH, "rotate270fliph"),
    ];
    for (kind, tag) in cases {
      assert_eq!(kind.as_str(), tag, "as_str mismatch for {kind:?}");
      // `Display` is derived as `self.as_str()`, so `to_string()` must
      // equal the tag — this drives the same match arms via the
      // formatter path.
      assert_eq!(kind.to_string(), tag, "Display mismatch for {kind:?}");
    }
  }

  /// Scalar reference, dimension-overflow arm. The existing
  /// [`rotate_buf_u8_panics_on_dimension_overflow`] exercises the
  /// *dispatcher*'s `checked_mul`; this drives the **scalar** kernel's
  /// own pre-assert overflow guard (a distinct `unwrap_or_else` panic
  /// closure). `src_w * src_h` saturates `usize`, so the product wraps
  /// without `checked_mul` and the explicit `panic!` is the only
  /// correct response. Small `src` / `dst` so the overflow is the only
  /// failure mode (the asserts come *after* the checked math).
  #[test]
  #[should_panic(expected = "rotate_buf_u8_scalar: dimensions")]
  fn rotate_buf_u8_scalar_panics_on_dimension_overflow() {
    let s = vec![0u8; 16];
    let mut d = vec![0u8; 16];
    rotate_buf_u8_scalar(&mut d, &s, usize::MAX / 2 + 1, 2, 4, RotateKind::Rotate90);
  }

  /// Scalar reference, `src.len()` size-mismatch arm. The existing
  /// size-mismatch test only drives the *dispatcher*; calling the
  /// scalar kernel directly with a too-short `src` (correct,
  /// non-overflowing dims so the `checked_mul` passes) exercises the
  /// scalar arm's first `assert_eq!` and its full message-formatting
  /// args (`src.len()`, `src_w`, `src_h`, `channels`, `elements`).
  #[test]
  #[should_panic(
    expected = "rotate_buf_u8_scalar: src.len() (3) must equal src_w * src_h * channels (2 * 2 * 4 = 16)"
  )]
  fn rotate_buf_u8_scalar_panics_on_src_size_mismatch() {
    let s = vec![0u8; 3]; // WRONG: 2*2*4 = 16
    let mut d = vec![0u8; 16];
    rotate_buf_u8_scalar(&mut d, &s, 2, 2, 4, RotateKind::Rotate90);
  }

  /// Scalar reference, `dst.len()` size-mismatch arm. `src` is sized
  /// correctly (so the first `assert_eq!` passes), but `dst` is
  /// undersized — driving the scalar arm's **second** `assert_eq!` and
  /// its message-formatting args (`dst.len()`, `elements`). This is the
  /// only shape that reaches the second scalar assert (a `src`
  /// mismatch short-circuits at the first).
  #[test]
  #[should_panic(expected = "rotate_buf_u8_scalar: dst.len() (3) must equal src.len() (16)")]
  fn rotate_buf_u8_scalar_panics_on_dst_size_mismatch() {
    let s = vec![0u8; 16]; // correct: 2*2*4
    let mut d = vec![0u8; 3]; // WRONG
    rotate_buf_u8_scalar(&mut d, &s, 2, 2, 4, RotateKind::Rotate90);
  }

  /// Dispatcher, `dst.len()` size-mismatch arm. The existing
  /// [`rotate_buf_u8_panics_on_size_mismatch`] supplies a too-short
  /// `src` (hitting the dispatcher's *first* `assert_eq!`); here `src`
  /// is correct and `dst` is undersized, so the dispatcher's **second**
  /// `assert_eq!` (`dst.len()` vs `src.len()`) and its message args are
  /// covered. The mismatch is caught before any routing to the unsafe
  /// NEON kernel.
  #[test]
  #[should_panic(expected = "simd::vlm::rotate_buf_u8: dst.len() (3) must equal src.len() (16)")]
  fn rotate_buf_u8_dispatch_panics_on_dst_size_mismatch() {
    let s = vec![0u8; 16]; // correct: 2*2*4
    let mut d = vec![0u8; 3]; // WRONG
    rotate_buf_u8(&mut d, &s, 2, 2, 4, RotateKind::Rotate90);
  }

  /// NEON kernel, dimension-overflow arm. The `channels=4` NEON kernel
  /// has its own pre-assert `checked_mul` guard (distinct panic closure
  /// from the dispatcher's). Driven through the kernel **directly** so
  /// the guard is covered even if dispatcher routing changes; gated on
  /// `is_neon_available()` so it no-ops (forcing the expected panic
  /// message) on non-NEON CPUs / `mlxrs_force_scalar`. `src_w * src_h`
  /// saturates `usize`; small `src` / `dst` so the overflow is the only
  /// failure mode, and the panic fires at the `unwrap_or_else` closure
  /// **before** any pointer arithmetic (so no UB despite the
  /// intentionally undersized buffers).
  #[cfg(target_arch = "aarch64")]
  #[test]
  #[should_panic(expected = "rotate_buf_u8_channels4_neon: dimensions")]
  fn rotate_buf_u8_neon_panics_on_dimension_overflow() {
    if !crate::simd::is_neon_available() {
      // Force the expected panic without invoking the kernel — the
      // guard under test only applies when the NEON arm is reachable.
      panic!("rotate_buf_u8_channels4_neon: dimensions (skipped — NEON unavailable)");
    }
    let s = vec![0u8; 16];
    let mut d = vec![0u8; 16];
    // SAFETY: `is_neon_available()` checked immediately above
    // (precondition #1); `channels == 4` is implicit in the kernel
    // (precondition #2). The kernel's `checked_mul` overflow guard
    // fires before any `vld1q_u8` / `write_unaligned`, so the
    // intentionally tiny `src` / `dst` are never dereferenced.
    unsafe {
      super::rotate_buf_u8_channels4_neon(&mut d, &s, usize::MAX / 2 + 1, 2, RotateKind::Rotate90)
    };
  }

  /// NEON kernel, `src.len()` size-mismatch arm. Non-overflowing dims
  /// (so the `checked_mul` passes) but a too-short `src` drives the
  /// kernel's first `assert_eq!` and its message args (`src.len()`,
  /// `src_w`, `src_h`, `elements`). The assert sits **before** the
  /// `unsafe` tile-loop block, so no OOB load occurs. Gated +
  /// no-op-panic on non-NEON, matching the bgr_widen NEON precondition
  /// tests.
  #[cfg(target_arch = "aarch64")]
  #[test]
  #[should_panic(
    expected = "rotate_buf_u8_channels4_neon: src.len() (3) must equal src_w * src_h * 4 (2 * 2 * 4 = 16)"
  )]
  fn rotate_buf_u8_neon_panics_on_src_size_mismatch() {
    if !crate::simd::is_neon_available() {
      panic!(
        "rotate_buf_u8_channels4_neon: src.len() (3) must equal src_w * src_h * 4 (2 * 2 * 4 = 16) (skipped — NEON unavailable)"
      );
    }
    let s = vec![0u8; 3]; // WRONG: 2*2*4 = 16
    let mut d = vec![0u8; 16];
    // SAFETY: NEON checked above. The size `assert_eq!` precedes the
    // `unsafe` pointer loop, so the kernel panics before any
    // `vld1q_u8` / `write_unaligned` — the undersized `src` is never
    // read past its length.
    unsafe { super::rotate_buf_u8_channels4_neon(&mut d, &s, 2, 2, RotateKind::Rotate90) };
  }

  /// NEON kernel, `dst.len()` size-mismatch arm. `src` is correct (so
  /// the first `assert_eq!` passes) and `dst` is undersized — driving
  /// the kernel's **second** `assert_eq!` and its message args
  /// (`dst.len()`, `elements`). As with the src-mismatch test, the
  /// assert precedes the `unsafe` block, so no OOB write occurs.
  #[cfg(target_arch = "aarch64")]
  #[test]
  #[should_panic(
    expected = "rotate_buf_u8_channels4_neon: dst.len() (3) must equal src.len() (16)"
  )]
  fn rotate_buf_u8_neon_panics_on_dst_size_mismatch() {
    if !crate::simd::is_neon_available() {
      panic!(
        "rotate_buf_u8_channels4_neon: dst.len() (3) must equal src.len() (16) (skipped — NEON unavailable)"
      );
    }
    let s = vec![0u8; 16]; // correct: 2*2*4
    let mut d = vec![0u8; 3]; // WRONG
    // SAFETY: NEON checked above. The `dst` size `assert_eq!` precedes
    // the `unsafe` pointer loop, so the kernel panics before any
    // `write_unaligned` — the undersized `dst` is never written.
    unsafe { super::rotate_buf_u8_channels4_neon(&mut d, &s, 2, 2, RotateKind::Rotate90) };
  }

  /// NEON kernel vs scalar reference, **direct** bit-identical
  /// differential (not via the dispatcher) so the `(u8, channels=4)`
  /// NEON arm is covered independent of dispatcher routing. Sweeps
  /// widths straddling the 4-pixel tile boundary (so both the
  /// `while x + 4 <= row_x` body and the `while x < src_w` scalar tail
  /// of the kernel run) and several heights. No-op-skips on non-NEON
  /// CPUs / `mlxrs_force_scalar`.
  #[cfg(target_arch = "aarch64")]
  #[test]
  fn rotate_buf_u8_neon_matches_scalar_bit_identical() {
    if !crate::simd::is_neon_available() {
      return;
    }
    let channels = 4usize;
    for &w in &[1usize, 2, 3, 4, 5, 7, 8, 9, 16, 17] {
      for &h in &[1usize, 2, 3, 5, 8] {
        for &rotation in &[
          RotateKind::Rotate90,
          RotateKind::Rotate270,
          RotateKind::Rotate90FlipH,
          RotateKind::Rotate270FlipH,
        ] {
          let s = src(w, h, channels);
          let mut scalar = vec![0u8; s.len()];
          rotate_buf_u8_scalar(&mut scalar, &s, w, h, channels, rotation);
          let mut neon = vec![0u8; s.len()];
          // SAFETY: `is_neon_available()` checked at the top of the
          // test (precondition #1); `channels == 4` (precondition #2);
          // `scalar`/`neon`/`s` are all sized exactly `w*h*4`
          // (precondition #3, asserted inside the kernel too).
          unsafe { super::rotate_buf_u8_channels4_neon(&mut neon, &s, w, h, rotation) };
          assert_eq!(
            neon, scalar,
            "NEON vs scalar Exact mismatch (w={w}, h={h}, channels=4, rotation={rotation:?})"
          );
        }
      }
    }
  }
}