Skip to main content

caps_sa/
lcp.rs

1//! Suffix comparison primitives over a generic text.
2//!
3//! Mirrors CaPS-SA's `Suffix_Array::LCP` family (see `include/Suffix_Array.hpp`
4//! and `include/Genomic_Text.hpp`). Performance-critical callers (the
5//! merge-sort and cascade-merge inner loops) construct a [`LcpDispatch`]
6//! **once** at the top of the SA build and pass it through. The dispatch
7//! holds a function pointer chosen by [`is_x86_feature_detected!`] /
8//! [`is_aarch64_feature_detected!`] at construction time, so the hot path
9//! is a single indirect call through a register — no per-call atomic
10//! loads, no per-call feature-detection branches.
11//!
12//! The free-standing [`lcp`] / [`suffix_cmp`] / [`lcp_u8`] helpers remain
13//! for one-off callers (and for the tests in this file). They construct a
14//! [`LcpDispatch`] on every call and are correspondingly slower; algorithm
15//! kernels should prefer the methods on [`LcpDispatch`].
16//!
17//! ## Symbol types
18//!
19//! The whole crate is generic over any [`Symbol`] type. `Symbol` is an
20//! `unsafe` marker for types whose in-memory bytes encode equality
21//! (`a == b` iff their byte representations are equal) — i.e. no padding,
22//! no invalid bit patterns. Blanket impls are provided for every stdlib
23//! integer type and for fixed-size arrays of `Symbol`s, so `u8`, `u16`,
24//! `u32`, `u64`, `[u8; 3]` (24-bit), … all work out of the box. The LCP
25//! function casts `&[S]` to a byte view, runs a single byte-level SIMD
26//! compare (AVX-512BW hybrid → AVX2 → NEON → scalar), then divides the
27//! byte-LCP by `size_of::<S>()` to get the symbol-LCP. Endianness is
28//! irrelevant because the byte-compare resolves equality only; symbol
//! ordering is recovered by the caller comparing the first mismatching
//! symbols with `S`'s native `Ord` (`text[p + lcp].cmp(&text[q + lcp])`).
31
32use std::cmp::Ordering;
33
34use crate::limits::{LimitProvider, PlainText};
35
/// A symbol type for suffix-array construction. Every stdlib integer
/// type (signed and unsigned, `u8` through `u128`, plus `usize` /
/// `isize`) satisfies this, as do `bool`, `char`, and `[T; N]` for any
/// `T: Symbol` (arrays have no padding and inherit `Ord`
/// lexicographically from `T`). To use a custom type as a symbol,
/// implement this trait yourself; if the type contains padding, mark
/// it `#[repr(C, packed)]` first.
///
/// The trait bundles every other bound the algorithm needs from a
/// symbol type ([`Ord`] + [`Copy`] + [`Send`] + [`Sync`] + `'static`),
/// so the public API surface only ever needs `S: Symbol`.
///
/// # Safety
///
/// Implementations must guarantee that byte-equality and value-equality
/// agree in **both** directions: two values compare equal (`==`) if and
/// only if their in-memory bytes are identical. In practice that means
/// **no padding bytes**, no value with more than one byte representation,
/// and no byte pattern shared by two distinct values. The byte-view SIMD
/// LCP path in [`LcpDispatch::lcp`] casts `&[S]` to a `&[u8]` view and
/// compares bytes; if this contract is broken the LCP function returns
/// wrong answers and corrupts the resulting suffix array.
pub unsafe trait Symbol: Ord + Copy + Send + Sync + 'static {}

/// Implements [`Symbol`] for a list of primitive types whose byte
/// representation is canonical (one byte pattern per value, no padding).
macro_rules! impl_symbol_for_primitives {
    ($($t:ty),* $(,)?) => {
        $(
            // SAFETY: every type passed to this macro has no padding and
            // a unique byte representation per value, so byte-equality is
            // exactly value-equality.
            unsafe impl Symbol for $t {}
        )*
    };
}

// Integers: every bit pattern is a distinct valid value.
impl_symbol_for_primitives!(
    u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize,
);

// `bool` is one byte restricted to 0x00 / 0x01 and `char` is a `u32`
// restricted to Unicode scalar values. In both cases distinct values have
// distinct bytes and there is no padding. `char` lets callers build
// suffix arrays directly over decoded Unicode text.
impl_symbol_for_primitives!(bool, char);

// SAFETY: arrays of `Symbol`s have no padding (Rust arrays are tightly
// packed) and inherit equality element-wise, so byte-equality of the
// whole array is exactly value-equality. This covers patterns like
// `[u8; 3]` for 24-bit alphabets and `[u32; 2]` for 64-bit-on-32-bit.
unsafe impl<T: Symbol, const N: usize> Symbol for [T; N] {}
79
/// A function-pointer dispatch for byte-level LCP. The architecture-
/// specific pointer is selected once at construction by feature
/// detection; later calls reduce to a register-resident indirect call.
///
/// `LcpDispatch` is `Copy`, `Send`, and `Sync` (a function pointer is
/// all three), so it threads freely through `rayon` boundaries. It is
/// also `Debug`, so it can be embedded in caller structs that derive
/// `Debug`; the formatted field is the selected backend's address.
///
/// The same byte-level function backs every symbol width: the
/// [`Self::lcp`] method casts `&[S]` to a byte view, calls the function
/// with byte-scale offsets, and divides the result by `size_of::<S>()`
/// to recover the symbol-level LCP.
#[derive(Copy, Clone, Debug)]
pub struct LcpDispatch {
    lcp_bytes_fn: LcpBytesFn,
}

/// Internal function-pointer type for the byte-level LCP path.
///
/// Arguments are `(bytes, p, q, max_ctx)`, with `p`, `q` and `max_ctx`
/// all measured in bytes. The return value is the byte-level LCP of
/// `bytes[p..]` and `bytes[q..]`, capped at `max_ctx` and at the end of
/// `bytes`.
///
/// It is an `unsafe fn` because the AVX2 / AVX-512 / NEON variants are
/// `#[target_feature]` gated: calling one is only sound on a CPU that
/// supports the corresponding features. [`LcpDispatch`] upholds this by
/// storing either the scalar fallback or a pointer chosen by runtime
/// feature detection.
type LcpBytesFn = unsafe fn(&[u8], usize, usize, usize) -> usize;
101
102impl LcpDispatch {
103    /// Detect the best LCP implementation for this CPU. Cheap (a couple
104    /// of `is_*_feature_detected!` checks) but does still touch the
105    /// feature-detection cache, so call it **once** per top-level build.
106    pub fn detect() -> Self {
107        Self {
108            lcp_bytes_fn: pick_lcp_bytes_impl(),
109        }
110    }
111
112    /// Forced scalar dispatch — useful for tests and for clients that
113    /// want a deterministic baseline.
114    pub fn scalar() -> Self {
115        Self {
116            lcp_bytes_fn: lcp_bytes_scalar,
117        }
118    }
119
120    /// Longest common prefix of `text[p..]` and `text[q..]` in symbols,
121    /// bounded by `max_ctx`. For any `S: Symbol` of non-zero size, this
122    /// dispatches to the byte-level SIMD path with byte-scaled offsets
123    /// and returns `byte_lcp / size_of::<S>()`.
124    #[inline]
125    pub fn lcp<S: Symbol>(&self, text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
126        let k = std::mem::size_of::<S>();
127        if k == 0 {
128            // ZSTs: `Symbol` permits ZSTs (e.g. a unit `struct Foo;`),
129            // but the byte-view dispatch can't divide by zero. Such a
130            // text has zero bytes regardless of length, so every suffix
131            // is identical and the LCP is just the length-bounded `lim`.
132            let lim_p = text.len().saturating_sub(p).min(max_ctx);
133            let lim_q = text.len().saturating_sub(q).min(max_ctx);
134            return lim_p.min(lim_q);
135        }
136        // SAFETY: `Symbol`'s `unsafe` contract is exactly "bit-equality
137        // is value-equality" — `&[S]` has the same byte representation
138        // as a `&[u8]` view over the same bytes, with no padding to
139        // worry about. `size_of_val(text)` gives the slice's exact
140        // byte length, which Rust's slice invariant already guarantees
141        // fits in `isize`.
142        let bytes =
143            unsafe { std::slice::from_raw_parts(text.as_ptr() as *const u8, size_of_val(text)) };
144        let byte_lcp = unsafe {
145            (self.lcp_bytes_fn)(
146                bytes,
147                p.saturating_mul(k),
148                q.saturating_mul(k),
149                max_ctx.saturating_mul(k),
150            )
151        };
152        byte_lcp / k
153    }
154
155    /// Total order on two suffixes of `text`. Uses [`Self::lcp`] for the
156    /// shared prefix, then resolves the first differing symbol or — if
157    /// both suffixes are exhausted within `max_ctx` — orders by remaining
158    /// length (shorter is smaller, the convention SAIS and CaPS-SA use).
159    ///
160    /// Zero-cost wrapper around [`Self::suffix_cmp_with`] for the
161    /// non-segmented case.
162    #[inline]
163    pub fn suffix_cmp<S: Symbol>(
164        &self,
165        text: &[S],
166        p: usize,
167        q: usize,
168        max_ctx: usize,
169    ) -> Ordering {
170        self.suffix_cmp_with(text, &PlainText::new(text.len()), p, q, max_ctx)
171    }
172
173    /// Like [`Self::suffix_cmp`] but takes a [`LimitProvider`] so the
174    /// suffix lengths used for the LCP-cap and the boundary-tie-break
175    /// come from a segmented view of the text. With [`PlainText`]
176    /// this matches [`Self::suffix_cmp`] exactly.
177    ///
178    /// The boundary tie-break (when both suffixes hit their limit
179    /// before any byte differs) is delegated to
180    /// [`LimitProvider::boundary_order`]; the default
181    /// "shorter-is-smaller" gives the generalised-SA convention,
182    /// custom impls can flip it (e.g. STAR's spacer-as-largest).
183    #[inline]
184    pub fn suffix_cmp_with<S: Symbol, L: LimitProvider>(
185        &self,
186        text: &[S],
187        lp: &L,
188        p: usize,
189        q: usize,
190        max_ctx: usize,
191    ) -> Ordering {
192        let lim_p = lp.lim_at(p);
193        let lim_q = lp.lim_at(q);
194        let lim = lim_p.min(lim_q).min(max_ctx);
195        let common = self.lcp(text, p, q, lim);
196        if common < lim {
197            text[p + common].cmp(&text[q + common])
198        } else {
199            lp.boundary_order(p, lim_p, q, lim_q)
200        }
201    }
202}
203
204/// One-off LCP. Constructs a fresh [`LcpDispatch`] on every call — fine
205/// for tests / introspection, slower than reusing one [`LcpDispatch`]
206/// across an algorithm's inner loop.
207#[inline]
208pub fn lcp<S: Symbol>(text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
209    LcpDispatch::detect().lcp(text, p, q, max_ctx)
210}
211
212/// One-off suffix comparison; see [`lcp`] for the cost note.
213#[inline]
214pub fn suffix_cmp<S: Symbol>(text: &[S], p: usize, q: usize, max_ctx: usize) -> Ordering {
215    LcpDispatch::detect().suffix_cmp(text, p, q, max_ctx)
216}
217
218/// One-off `u8`-typed LCP that auto-selects AVX-512 / AVX2 / NEON /
219/// scalar. Convenience entry point for byte texts; equivalent in cost
220/// to `LcpDispatch::detect().lcp::<u8>(...)` but skips the generic
221/// indirection.
222#[inline]
223pub fn lcp_u8(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
224    let f = pick_lcp_bytes_impl();
225    unsafe { f(text, p, q, max_ctx) }
226}
227
/// Generic scalar LCP at symbol granularity.
///
/// Public so callers that want to bypass the SIMD dispatch can call it
/// directly, e.g. for symbol types that are `Eq` but not [`Symbol`], such
/// as arbitrary newtypes. Elements are compared with `!=` rather than
/// through a byte view, so only `S: Eq` is required.
///
/// The result is capped at `max_ctx` and at the end of `text`. An
/// out-of-range `p` or `q` yields `0`.
#[inline]
pub fn lcp_scalar<S: Eq>(text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
    let len = text.len();
    let limit = len
        .saturating_sub(p)
        .min(len.saturating_sub(q))
        .min(max_ctx);
    if limit == 0 {
        return 0;
    }
    // `limit > 0` implies `p + limit <= len` and `q + limit <= len`, so
    // both windows are in bounds. Iterating the subslices avoids
    // per-element bounds checks.
    let lhs = &text[p..p + limit];
    let rhs = &text[q..q + limit];
    lhs.iter()
        .zip(rhs)
        .position(|(a, b)| a != b)
        .unwrap_or(limit)
}
247
248/// Inspect this CPU's features and return the best [`LcpBytesFn`].
249fn pick_lcp_bytes_impl() -> LcpBytesFn {
250    #[cfg(target_arch = "x86_64")]
251    {
252        // AVX-512BW gives us a 64-byte byte-compare returning a 64-bit
253        // mask register directly — no movemask intrinsic, no extract.
254        // Both `f` (foundation) and `bw` (byte/word ops, for the
255        // `_mm512_cmpeq_epi8_mask` we use) are required.
256        if std::is_x86_feature_detected!("avx512f") && std::is_x86_feature_detected!("avx512bw") {
257            return lcp_bytes_avx512;
258        }
259        if std::is_x86_feature_detected!("avx2") {
260            return lcp_bytes_avx2;
261        }
262    }
263    #[cfg(target_arch = "aarch64")]
264    {
265        if std::arch::is_aarch64_feature_detected!("neon") {
266            return lcp_bytes_neon;
267        }
268    }
269    lcp_bytes_scalar
270}
271
272/// `unsafe fn`-typed scalar — wrapper around [`lcp_scalar`] for the
273/// `u8` instantiation so all dispatch targets share the [`LcpBytesFn`]
274/// signature.
275unsafe fn lcp_bytes_scalar(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
276    lcp_scalar(text, p, q, max_ctx)
277}
278
/// AVX-512BW path: 64-byte vector compares; `_mm512_cmpeq_epi8_mask`
/// returns the per-byte equality mask straight in a 64-bit `__mmask64`
/// register (no movemask round-trip), and `(!mask).trailing_zeros()`
/// gives the first differing byte.
///
/// The function leads with a single 32-byte AVX2 step. This keeps the
/// short-LCP regime (random DNA, where every call typically resolves
/// in the first ≤16 bytes) at AVX2's per-call cost — a 64-byte load +
/// ZMM register usage on a call that exits inside the first 32 bytes
/// is wasted work. Once the head step has shown the first 32 bytes are
/// equal, we switch to the 64-byte stride for the rest of the comparison.
/// That is the regime the upstream genome bench (and this function's
/// reason for existing) actually hits. A single 32-byte step and a byte
/// loop then finish the last < 64 bytes.
///
/// Returns the byte-level LCP of `text[p..]` and `text[q..]`, capped at
/// `max_ctx` and at the end of `text`. An out-of-range `p` or `q` yields
/// `0`.
///
/// # Safety
///
/// The running CPU must support `avx512f` and `avx512bw`;
/// `pick_lcp_bytes_impl` checks both before selecting this function. All
/// loads stay within `text`, so there is no other precondition.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
unsafe fn lcp_bytes_avx512(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    use std::arch::x86_64::{
        __m256i, __m512i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8,
        _mm512_cmpeq_epi8_mask, _mm512_loadu_si512,
    };
    // Comparable window: ends at the shorter suffix, capped by `max_ctx`.
    // `p + lim <= n` and `q + lim <= n` hold by construction, which is what
    // every bounds argument below relies on.
    let n = text.len();
    let lim_p = n.saturating_sub(p).min(max_ctx);
    let lim_q = n.saturating_sub(q).min(max_ctx);
    let lim = lim_p.min(lim_q);
    let ptr = text.as_ptr();

    let mut i = 0usize;
    // 32-byte head: AVX2 compare. If it resolves the LCP we never touch
    // a ZMM register.
    if i + 32 <= lim {
        // SAFETY: AVX2 is implied by AVX-512F; bounds checked above.
        let va = unsafe { _mm256_loadu_si256(ptr.add(p + i) as *const __m256i) };
        let vb = unsafe { _mm256_loadu_si256(ptr.add(q + i) as *const __m256i) };
        let eq = _mm256_cmpeq_epi8(va, vb);
        // One bit per byte; a clear bit marks a differing byte.
        let mask = _mm256_movemask_epi8(eq) as u32;
        if mask != u32::MAX {
            return i + (!mask).trailing_zeros() as usize;
        }
        i += 32;
    }
    // 64-byte body: the first 32 bytes matched, so run at AVX-512 stride.
    //
    // We tried software prefetching (`_mm_prefetch::<_MM_HINT_T0>(ptr+256)`)
    // inside this loop on the theory that overlapping the next stride's
    // memory fetch with current-iteration execution would shave the
    // dominant phase 1 wall. It did not: hum200m, rand100m, and the
    // full GRCh38 32t bench all showed indistinguishable wall vs the
    // bare loop. The Zen 5 hardware prefetcher recognises the strided
    // pattern and is already issuing the same loads, so the explicit
    // hint just adds inner-loop instructions for no benefit. Left bare
    // for clarity.
    while i + 64 <= lim {
        // SAFETY: bounds ensured by the loop condition; unaligned loads.
        let va = unsafe { _mm512_loadu_si512(ptr.add(p + i) as *const __m512i) };
        let vb = unsafe { _mm512_loadu_si512(ptr.add(q + i) as *const __m512i) };
        let mask = _mm512_cmpeq_epi8_mask(va, vb);
        if mask != u64::MAX {
            return i + (!mask).trailing_zeros() as usize;
        }
        i += 64;
    }
    // 32-byte tail: covers the case where the residue past the 64-byte
    // loop is between 32 and 63 bytes.
    if i + 32 <= lim {
        // SAFETY: bounds checked above.
        let va = unsafe { _mm256_loadu_si256(ptr.add(p + i) as *const __m256i) };
        let vb = unsafe { _mm256_loadu_si256(ptr.add(q + i) as *const __m256i) };
        let eq = _mm256_cmpeq_epi8(va, vb);
        let mask = _mm256_movemask_epi8(eq) as u32;
        if mask != u32::MAX {
            return i + (!mask).trailing_zeros() as usize;
        }
        i += 32;
    }
    // Fewer than 32 bytes remain: finish byte by byte.
    while i < lim {
        if text[p + i] != text[q + i] {
            return i;
        }
        i += 1;
    }
    i
}
361
/// AVX2 backend: compares 32 bytes per step. Each `_mm256_cmpeq_epi8`
/// result is collapsed into a 32-bit mask with `_mm256_movemask_epi8`;
/// the lowest clear bit of that mask is the first differing byte. A
/// byte-by-byte loop handles whatever is left after the last full 32-byte
/// window.
///
/// Returns the byte-level LCP of `text[p..]` and `text[q..]`, capped at
/// `max_ctx` and at the end of `text`.
///
/// # Safety
///
/// The running CPU must support AVX2. Reads never go past `text.len()`,
/// so there is no other precondition.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lcp_bytes_avx2(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    use std::arch::x86_64::{__m256i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8};

    const WIDTH: usize = 32;
    let len = text.len();
    // Comparable window: ends at the shorter suffix, capped by `max_ctx`.
    let limit = len.saturating_sub(p).min(len.saturating_sub(q)).min(max_ctx);
    let base = text.as_ptr();

    let mut off = 0usize;
    while off + WIDTH <= limit {
        // SAFETY: `off + WIDTH <= limit` keeps both 32-byte windows inside
        // `text`; `loadu` has no alignment requirement.
        let (lhs, rhs) = unsafe {
            (
                _mm256_loadu_si256(base.add(p + off) as *const __m256i),
                _mm256_loadu_si256(base.add(q + off) as *const __m256i),
            )
        };
        let same = _mm256_movemask_epi8(_mm256_cmpeq_epi8(lhs, rhs)) as u32;
        if same != u32::MAX {
            return off + (!same).trailing_zeros() as usize;
        }
        off += WIDTH;
    }
    // Fewer than 32 bytes remain: finish one byte at a time.
    while off < limit && text[p + off] == text[q + off] {
        off += 1;
    }
    off
}
394
/// NEON path: 16-byte compares, locate the first differing byte via the
/// "shrn by 4" movemask emulation — pack each `vceqq_u8` byte
/// (`0xFF` or `0x00`) into 4 mask bits of a single 64-bit lane, then
/// `trailing_zeros / 4` gives the byte index.
///
/// Returns the byte-level LCP of `text[p..]` and `text[q..]`, capped at
/// `max_ctx` and at the end of `text`.
///
/// # Safety
///
/// The running CPU must support NEON; `pick_lcp_bytes_impl` checks this
/// before selecting the function. Reads never go past `text.len()`, so
/// there is no other precondition.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn lcp_bytes_neon(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    use std::arch::aarch64::{
        vceqq_u8, vget_lane_u64, vld1q_u8, vreinterpret_u64_u8, vreinterpretq_u16_u8, vshrn_n_u16,
    };
    // Comparable window: ends at the shorter suffix, capped by `max_ctx`.
    let n = text.len();
    let lim_p = n.saturating_sub(p).min(max_ctx);
    let lim_q = n.saturating_sub(q).min(max_ctx);
    let lim = lim_p.min(lim_q);
    let ptr = text.as_ptr();

    let mut i = 0usize;
    while i + 16 <= lim {
        // SAFETY: bounds ensured by the loop condition; unaligned loads.
        let va = unsafe { vld1q_u8(ptr.add(p + i)) };
        let vb = unsafe { vld1q_u8(ptr.add(q + i)) };
        let eq = vceqq_u8(va, vb);
        // Treat the 16 compare bytes as 8 u16 lanes, shift each right by 4
        // and narrow to u8. Because each compare byte is all-ones or
        // all-zeros, every input byte `j` ends up as bits `4j..4j+3` of the
        // resulting 64-bit value.
        let narrow = vshrn_n_u16::<4>(vreinterpretq_u16_u8(eq));
        let mask = vget_lane_u64::<0>(vreinterpret_u64_u8(narrow));
        if mask != u64::MAX {
            // 4 mask bits per byte, so divide the bit index by 4.
            return i + ((!mask).trailing_zeros() as usize / 4);
        }
        i += 16;
    }
    // Fewer than 16 bytes remain: finish byte by byte.
    while i < lim {
        if text[p + i] != text[q + i] {
            return i;
        }
        i += 1;
    }
    i
}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lcp_matches_to_first_difference() {
        let text = b"banana";
        // suffix at 0: "banana", at 1: "anana". LCP = 0 (b vs a).
        assert_eq!(lcp(text, 0, 1, usize::MAX), 0);
        // suffix at 1: "anana", at 3: "ana". LCP = 3 ("ana"), after which
        // the shorter suffix ends.
        assert_eq!(lcp(text, 1, 3, usize::MAX), 3);
    }

    #[test]
    fn lcp_respects_max_ctx() {
        let text = b"aaaaaa";
        assert_eq!(lcp(text, 0, 1, 3), 3);
    }

    #[test]
    fn lcp_stops_at_text_end() {
        let text = b"abc";
        // suffix at 0: "abc", at 2: "c". LCP=0.
        assert_eq!(lcp(text, 0, 2, usize::MAX), 0);
        // suffix at 1: "bc", at 1: "bc". LCP=2.
        assert_eq!(lcp(text, 1, 1, usize::MAX), 2);
    }

    #[test]
    fn cmp_lex_order() {
        let text = b"banana";
        // "anana" < "banana"
        assert_eq!(suffix_cmp(text, 1, 0, usize::MAX), Ordering::Less);
        // "ana" < "anana" (prefix)
        assert_eq!(suffix_cmp(text, 3, 1, usize::MAX), Ordering::Less);
        // self-equal
        assert_eq!(suffix_cmp(text, 1, 1, usize::MAX), Ordering::Equal);
    }

    /// Symbols are ordered by `S`'s native `Ord`, not by their in-memory
    /// bytes: `0x0100 > 0x0001` even though, on little-endian targets,
    /// the first byte of `0x0100` (`0x00`) is smaller than that of
    /// `0x0001` (`0x01`).
    #[test]
    fn suffix_cmp_uses_native_ord_not_byte_order() {
        let text = [0x0100u16, 0x0001];
        assert_eq!(suffix_cmp(&text, 0, 1, usize::MAX), Ordering::Greater);
        assert_eq!(suffix_cmp(&text, 1, 0, usize::MAX), Ordering::Less);
    }

    /// Zero-sized symbols (here `[u8; 0]`) take the dedicated ZST branch:
    /// every suffix is identical, so the LCP is the length bound.
    #[test]
    fn zero_sized_symbols_use_length_bound() {
        let text = [[0u8; 0]; 5];
        assert_eq!(lcp(&text, 1, 3, usize::MAX), 2);
        assert_eq!(lcp(&text, 0, 0, 3), 3);
        assert_eq!(lcp(&text, 7, 0, usize::MAX), 0);
        assert_eq!(
            lcp(&text, 1, 3, usize::MAX),
            lcp_scalar(&text, 1, 3, usize::MAX)
        );
    }

    /// SIMD vs scalar agreement across pathological positions: long runs
    /// of identical bytes (exercises full-vector equal branches),
    /// 16-byte, 32-byte and 64-byte boundary differences (covers NEON,
    /// AVX2 and AVX-512 chunk sizes), and unaligned tail bytes.
    #[test]
    fn simd_matches_scalar_on_u8() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xA5A5);

        for diff_at in [0usize, 1, 15, 16, 17, 31, 32, 33, 63, 64, 65, 100] {
            // Build a 400-byte all-'A' text with a 'C' at `diff_at` (always
            // in the first half). The LCP between suffix 200 (all-A) and
            // suffix 0 (has C at diff_at) is exactly diff_at.
            let mut combined = vec![b'A'; 400];
            combined[diff_at] = b'C';
            let got = lcp(&combined, 0, 200, usize::MAX);
            assert_eq!(got, diff_at, "wrong LCP at diff_at={diff_at}");
        }

        for &n in &[1usize, 32, 33, 200, 1000] {
            let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..4u8)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "p={p} q={q} text={text:?}");
            }
        }
    }

    /// Verify the cached-dispatch path matches scalar and both one-off
    /// helpers (`lcp` and `lcp_u8`) on a tricky case spanning AVX2/NEON
    /// vector boundaries.
    #[test]
    fn dispatch_struct_matches_oneoff_and_scalar() {
        let scalar = LcpDispatch::scalar();
        let detected = LcpDispatch::detect();
        let mut text: Vec<u8> = vec![b'A'; 200];
        text[64] = b'T'; // diff right at the second AVX2 boundary
        assert_eq!(scalar.lcp(&text, 0, 100, usize::MAX), 64);
        assert_eq!(detected.lcp(&text, 0, 100, usize::MAX), 64);
        assert_eq!(lcp(&text, 0, 100, usize::MAX), 64);
        assert_eq!(lcp_u8(&text, 0, 100, usize::MAX), 64);
    }

    /// Exercise the AVX-512 32-byte head, 64-byte stride and 32-byte tail.
    /// The differing byte is placed at offsets that straddle each boundary
    /// (0/1, 31-33, 63-65, 95-97, 127/128, 200). The test checks that scalar
    /// finds exactly that offset and that the dispatched path agrees. On a
    /// non-AVX-512 host this devolves to the other SIMD paths but still
    /// verifies correctness.
    #[test]
    fn avx512_boundary_agreement() {
        let detected = LcpDispatch::detect();
        for diff_at in [0usize, 1, 31, 32, 33, 63, 64, 65, 95, 96, 97, 127, 128, 200] {
            let mut text = vec![b'A'; 512];
            text[diff_at] = b'G';
            let want = lcp_scalar(&text, 0, 256, usize::MAX);
            assert_eq!(want, diff_at, "scalar diff_at={diff_at}");
            let got = detected.lcp(&text, 0, 256, usize::MAX);
            assert_eq!(got, want, "diff_at={diff_at}");
        }
    }

    /// SIMD-vs-scalar agreement for `u16` text. The byte-view dispatch
    /// must return symbol-LCPs that match the scalar walk over `&[u16]`.
    /// Covers the case where the first differing byte lands inside a
    /// symbol whose previous bytes were equal (e.g. low byte of u16
    /// equal, high byte differs).
    #[test]
    fn simd_matches_scalar_on_u16() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1357);

        // Place a difference at every interesting byte boundary within
        // and across symbols; the symbol-LCP should equal byte_diff/2.
        let mut text = vec![0u16; 256];
        for byte_diff_at in [0usize, 1, 2, 3, 31, 32, 33, 63, 64, 65, 127, 128, 200] {
            text.iter_mut().for_each(|x| *x = 0xAAAA);
            // Clear the set bits of either the low byte (even offset) or
            // the high byte (odd offset) of `text[byte_diff_at / 2]`. On a
            // little-endian target the differing byte is then exactly
            // `byte_diff_at`. Either way the first differing *symbol* is
            // `byte_diff_at / 2`.
            let sym = byte_diff_at / 2;
            let mask = if byte_diff_at % 2 == 0 {
                0x00FF
            } else {
                0xFF00
            };
            text[sym] ^= mask & 0xAAAA; // toggle the bit pattern
            let got = lcp(&text, 0, 128, usize::MAX);
            // The first differing symbol is `byte_diff_at / 2`.
            assert_eq!(
                got, sym,
                "byte_diff_at={byte_diff_at}, expected symbol {sym}"
            );
        }

        // Random u16 texts: dispatched path must equal scalar.
        for &n in &[1usize, 16, 17, 100, 500] {
            let text: Vec<u16> = (0..n).map(|_| rng.random_range(0..16u16)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "u16 p={p} q={q} n={n}");
            }
        }
    }

    /// Same agreement check at `u32` granularity (4-byte symbols).
    #[test]
    fn simd_matches_scalar_on_u32() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x2468);

        for &n in &[1usize, 8, 9, 100, 500] {
            let text: Vec<u32> = (0..n).map(|_| rng.random_range(0..32u32)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "u32 p={p} q={q} n={n}");
            }
        }
    }

    /// 24-bit alphabet via `[u8; 3]`: every symbol-LCP must equal the
    /// scalar walk's. Exercises a symbol width that doesn't divide any
    /// SIMD chunk evenly.
    #[test]
    fn simd_matches_scalar_on_u8_3() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xACAC);

        for &n in &[1usize, 8, 22, 100, 333] {
            let text: Vec<[u8; 3]> = (0..n)
                .map(|_| {
                    [
                        rng.random_range(0..4u8),
                        rng.random_range(0..4u8),
                        rng.random_range(0..4u8),
                    ]
                })
                .collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "[u8;3] p={p} q={q} n={n}");
            }
        }
    }

    /// `u64` agreement — 8-byte symbols.
    #[test]
    fn simd_matches_scalar_on_u64() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xFEED);

        for &n in &[1usize, 4, 5, 50, 250] {
            let text: Vec<u64> = (0..n).map(|_| rng.random_range(0..64u64)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "u64 p={p} q={q} n={n}");
            }
        }
    }
}