caps_sa/lcp.rs
1//! Suffix comparison primitives over a generic text.
2//!
3//! Mirrors CaPS-SA's `Suffix_Array::LCP` family (see `include/Suffix_Array.hpp`
4//! and `include/Genomic_Text.hpp`). Performance-critical callers (the
5//! merge-sort and cascade-merge inner loops) construct a [`LcpDispatch`]
6//! **once** at the top of the SA build and pass it through. The dispatch
7//! holds a function pointer chosen by [`is_x86_feature_detected!`] /
8//! [`is_aarch64_feature_detected!`] at construction time, so the hot path
9//! is a single indirect call through a register — no per-call atomic
10//! loads, no per-call feature-detection branches.
11//!
12//! The free-standing [`lcp`] / [`suffix_cmp`] / [`lcp_u8`] helpers remain
13//! for one-off callers (and for the tests in this file). They construct a
14//! [`LcpDispatch`] on every call and are correspondingly slower; algorithm
15//! kernels should prefer the methods on [`LcpDispatch`].
16//!
17//! ## Symbol types
18//!
19//! The whole crate is generic over any [`Symbol`] type. `Symbol` is an
20//! `unsafe` marker for types whose in-memory bytes encode equality
21//! (`a == b` iff their byte representations are equal) — i.e. no padding,
22//! no invalid bit patterns. Blanket impls are provided for every stdlib
23//! integer type and for fixed-size arrays of `Symbol`s, so `u8`, `u16`,
24//! `u32`, `u64`, `[u8; 3]` (24-bit), … all work out of the box. The LCP
25//! function casts `&[S]` to a byte view, runs a single byte-level SIMD
26//! compare (AVX-512BW hybrid → AVX2 → NEON → scalar), then divides the
27//! byte-LCP by `size_of::<S>()` to get the symbol-LCP. Endianness is
//! irrelevant because the byte-compare resolves equality only; symbol
//! ordering is recovered by the caller's
//! `text[p + lcp].cmp(&text[q + lcp])` using `S`'s native `Ord`.
31
32use std::cmp::Ordering;
33
34use crate::limits::{LimitProvider, PlainText};
35
/// A symbol type for suffix-array construction. All stdlib unsigned
/// and signed integer types satisfy this, as does `[T; N]` for any
/// `T: Symbol` (arrays have no padding and inherit `Ord`
/// lexicographically from `T`). To use a custom type as a symbol,
/// implement this trait yourself; if the type contains padding, mark
/// it `#[repr(C, packed)]` first.
///
/// The trait bundles every other bound the algorithm needs from a
/// symbol type ([`Ord`] + [`Copy`] + [`Send`] + [`Sync`] + `'static`),
/// so the public API surface only ever needs `S: Symbol`.
///
/// The trait has no methods: it is a pure marker whose only content is
/// the safety contract below.
///
/// # Safety
///
/// Implementations must guarantee that two values compare equal **if
/// and only if** their in-memory byte representations are identical.
/// Concretely that means:
///
/// * **no padding bytes** — otherwise one value could have two
///   different byte representations (e.g. via uninitialised padding);
/// * **no two distinct values sharing a bit pattern** — otherwise
///   byte-equality would report equality for values that differ.
///
/// The byte-view SIMD LCP path in [`LcpDispatch::lcp`] casts `&[S]`
/// to a `&[u8]` view and compares bytes; if either condition is
/// violated the LCP function would return wrong answers and corrupt
/// the resulting suffix array.
pub unsafe trait Symbol: Ord + Copy + Send + Sync + 'static {}
59
// Implements `Symbol` for every type in a comma-separated list (a
// trailing comma is accepted). Each type gets its own empty
// `unsafe impl`, since `Symbol` is a marker trait with no items.
macro_rules! impl_symbol_for_primitives {
    ($($t:ty),* $(,)?) => {
        $(
            // SAFETY: stdlib primitive integers have no padding and every
            // bit pattern is a valid value, so byte-equality is exactly
            // value-equality.
            unsafe impl Symbol for $t {}
        )*
    };
}
// Every stdlib unsigned and signed integer width, including the
// pointer-sized `usize` / `isize`.
impl_symbol_for_primitives!(
    u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize,
);
73
// Fixed-size arrays of symbols are themselves symbols.
//
// SAFETY: arrays of `Symbol`s have no padding (Rust arrays are tightly
// packed) and inherit equality element-wise, so byte-equality of the
// whole array is exactly value-equality. This covers patterns like
// `[u8; 3]` for 24-bit alphabets and `[u32; 2]` for 64-bit-on-32-bit.
unsafe impl<T: Symbol, const N: usize> Symbol for [T; N] {}
79
/// A function-pointer dispatch for byte-level LCP. The architecture-
/// specific pointer is selected once at construction by feature
/// detection; later calls reduce to a single indirect call through
/// the stored pointer.
///
/// `LcpDispatch` is `Copy`, `Send`, and `Sync` (a function pointer is
/// all three), so it threads freely through `rayon` boundaries.
///
/// The same byte-level function backs every symbol width: the
/// [`Self::lcp`] method casts `&[S]` to a byte view, calls the function
/// with byte-scale offsets, and divides the result by `size_of::<S>()`
/// to recover the symbol-level LCP.
#[derive(Copy, Clone)]
pub struct LcpDispatch {
    // Byte-level kernel picked at construction time. Private, so the
    // only way to obtain a value is through this type's constructors.
    lcp_bytes_fn: LcpBytesFn,
}
95
/// Internal function-pointer type for the byte-level LCP path.
///
/// The arguments are, in order, the byte text, the two suffix start
/// offsets, and the maximum context — all measured in bytes. The
/// return value is the length in bytes of the common prefix.
///
/// `unsafe fn` because the AVX2 / AVX-512 / NEON variants are
/// `#[target_feature]` gated; the dispatch's owner has already
/// verified CPU support.
type LcpBytesFn = unsafe fn(&[u8], usize, usize, usize) -> usize;
101
102impl LcpDispatch {
103 /// Detect the best LCP implementation for this CPU. Cheap (a couple
104 /// of `is_*_feature_detected!` checks) but does still touch the
105 /// feature-detection cache, so call it **once** per top-level build.
106 pub fn detect() -> Self {
107 Self {
108 lcp_bytes_fn: pick_lcp_bytes_impl(),
109 }
110 }
111
112 /// Forced scalar dispatch — useful for tests and for clients that
113 /// want a deterministic baseline.
114 pub fn scalar() -> Self {
115 Self {
116 lcp_bytes_fn: lcp_bytes_scalar,
117 }
118 }
119
120 /// Longest common prefix of `text[p..]` and `text[q..]` in symbols,
121 /// bounded by `max_ctx`. For any `S: Symbol` of non-zero size, this
122 /// dispatches to the byte-level SIMD path with byte-scaled offsets
123 /// and returns `byte_lcp / size_of::<S>()`.
124 #[inline]
125 pub fn lcp<S: Symbol>(&self, text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
126 let k = std::mem::size_of::<S>();
127 if k == 0 {
128 // ZSTs: `Symbol` permits ZSTs (e.g. a unit `struct Foo;`),
129 // but the byte-view dispatch can't divide by zero. Such a
130 // text has zero bytes regardless of length, so every suffix
131 // is identical and the LCP is just the length-bounded `lim`.
132 let lim_p = text.len().saturating_sub(p).min(max_ctx);
133 let lim_q = text.len().saturating_sub(q).min(max_ctx);
134 return lim_p.min(lim_q);
135 }
136 // SAFETY: `Symbol`'s `unsafe` contract is exactly "bit-equality
137 // is value-equality" — `&[S]` has the same byte representation
138 // as a `&[u8]` view over the same bytes, with no padding to
139 // worry about. `size_of_val(text)` gives the slice's exact
140 // byte length, which Rust's slice invariant already guarantees
141 // fits in `isize`.
142 let bytes =
143 unsafe { std::slice::from_raw_parts(text.as_ptr() as *const u8, size_of_val(text)) };
144 let byte_lcp = unsafe {
145 (self.lcp_bytes_fn)(
146 bytes,
147 p.saturating_mul(k),
148 q.saturating_mul(k),
149 max_ctx.saturating_mul(k),
150 )
151 };
152 byte_lcp / k
153 }
154
155 /// Total order on two suffixes of `text`. Uses [`Self::lcp`] for the
156 /// shared prefix, then resolves the first differing symbol or — if
157 /// both suffixes are exhausted within `max_ctx` — orders by remaining
158 /// length (shorter is smaller, the convention SAIS and CaPS-SA use).
159 ///
160 /// Zero-cost wrapper around [`Self::suffix_cmp_with`] for the
161 /// non-segmented case.
162 #[inline]
163 pub fn suffix_cmp<S: Symbol>(
164 &self,
165 text: &[S],
166 p: usize,
167 q: usize,
168 max_ctx: usize,
169 ) -> Ordering {
170 self.suffix_cmp_with(text, &PlainText::new(text.len()), p, q, max_ctx)
171 }
172
173 /// Like [`Self::suffix_cmp`] but takes a [`LimitProvider`] so the
174 /// suffix lengths used for the LCP-cap and the boundary-tie-break
175 /// come from a segmented view of the text. With [`PlainText`]
176 /// this matches [`Self::suffix_cmp`] exactly.
177 ///
178 /// The boundary tie-break (when both suffixes hit their limit
179 /// before any byte differs) is delegated to
180 /// [`LimitProvider::boundary_order`]; the default
181 /// "shorter-is-smaller" gives the generalised-SA convention,
182 /// custom impls can flip it (e.g. STAR's spacer-as-largest).
183 #[inline]
184 pub fn suffix_cmp_with<S: Symbol, L: LimitProvider>(
185 &self,
186 text: &[S],
187 lp: &L,
188 p: usize,
189 q: usize,
190 max_ctx: usize,
191 ) -> Ordering {
192 let lim_p = lp.lim_at(p);
193 let lim_q = lp.lim_at(q);
194 let lim = lim_p.min(lim_q).min(max_ctx);
195 let common = self.lcp(text, p, q, lim);
196 if common < lim {
197 text[p + common].cmp(&text[q + common])
198 } else {
199 lp.boundary_order(p, lim_p, q, lim_q)
200 }
201 }
202}
203
204/// One-off LCP. Constructs a fresh [`LcpDispatch`] on every call — fine
205/// for tests / introspection, slower than reusing one [`LcpDispatch`]
206/// across an algorithm's inner loop.
207#[inline]
208pub fn lcp<S: Symbol>(text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
209 LcpDispatch::detect().lcp(text, p, q, max_ctx)
210}
211
212/// One-off suffix comparison; see [`lcp`] for the cost note.
213#[inline]
214pub fn suffix_cmp<S: Symbol>(text: &[S], p: usize, q: usize, max_ctx: usize) -> Ordering {
215 LcpDispatch::detect().suffix_cmp(text, p, q, max_ctx)
216}
217
218/// One-off `u8`-typed LCP that auto-selects AVX-512 / AVX2 / NEON /
219/// scalar. Convenience entry point for byte texts; equivalent in cost
220/// to `LcpDispatch::detect().lcp::<u8>(...)` but skips the generic
221/// indirection.
222#[inline]
223pub fn lcp_u8(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
224 let f = pick_lcp_bytes_impl();
225 unsafe { f(text, p, q, max_ctx) }
226}
227
/// Portable, symbol-at-a-time LCP.
///
/// Returns the number of leading symbols shared by `text[p..]` and
/// `text[q..]`, never looking further than `max_ctx` symbols or past
/// the end of `text`. An offset at or beyond `text.len()` gives 0.
///
/// Public so that callers who cannot (or need not) use the SIMD
/// dispatch — e.g. symbol types that are merely `Eq` and do not
/// implement `Symbol` — can call it directly. It works at symbol
/// granularity and never goes through a byte view.
#[inline]
pub fn lcp_scalar<S: Eq>(text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
    let len = text.len();
    // Comparable window: what is left after each offset, capped by the
    // context bound. Saturating so out-of-range offsets give an empty
    // window instead of underflowing.
    let bound = len
        .saturating_sub(p)
        .min(len.saturating_sub(q))
        .min(max_ctx);
    // First offset where the suffixes disagree, or the whole window if
    // they never do.
    (0..bound)
        .find(|&off| text[p + off] != text[q + off])
        .unwrap_or(bound)
}
247
248/// Inspect this CPU's features and return the best [`LcpBytesFn`].
249fn pick_lcp_bytes_impl() -> LcpBytesFn {
250 #[cfg(target_arch = "x86_64")]
251 {
252 // AVX-512BW gives us a 64-byte byte-compare returning a 64-bit
253 // mask register directly — no movemask intrinsic, no extract.
254 // Both `f` (foundation) and `bw` (byte/word ops, for the
255 // `_mm512_cmpeq_epi8_mask` we use) are required.
256 if std::is_x86_feature_detected!("avx512f") && std::is_x86_feature_detected!("avx512bw") {
257 return lcp_bytes_avx512;
258 }
259 if std::is_x86_feature_detected!("avx2") {
260 return lcp_bytes_avx2;
261 }
262 }
263 #[cfg(target_arch = "aarch64")]
264 {
265 if std::arch::is_aarch64_feature_detected!("neon") {
266 return lcp_bytes_neon;
267 }
268 }
269 lcp_bytes_scalar
270}
271
272/// `unsafe fn`-typed scalar — wrapper around [`lcp_scalar`] for the
273/// `u8` instantiation so all dispatch targets share the [`LcpBytesFn`]
274/// signature.
275unsafe fn lcp_bytes_scalar(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
276 lcp_scalar(text, p, q, max_ctx)
277}
278
279/// AVX-512BW path: 64-byte vector compares; `_mm512_cmpeq_epi8_mask`
280/// returns the per-byte equality mask straight in a 64-bit `__mmask64`
281/// register (no movemask round-trip), and `(!mask).trailing_zeros()`
282/// gives the first differing byte.
283///
284/// The function leads with a single 32-byte AVX2 step. This keeps the
285/// short-LCP regime (random DNA, where every call typically resolves
286/// in the first ≤16 bytes) at AVX2's per-call cost — a 64-byte load +
287/// ZMM register usage on a call that exits inside the first 32 bytes
288/// is wasted work. Once we've established the LCP exceeds 32 bytes we
289/// switch to the 64-byte stride for the rest of the comparison, which
290/// is the regime the upstream genome bench (and this function's
291/// reason for existing) actually hits.
292#[cfg(target_arch = "x86_64")]
293#[target_feature(enable = "avx512f,avx512bw")]
294unsafe fn lcp_bytes_avx512(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
295 use std::arch::x86_64::{
296 __m256i, __m512i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8,
297 _mm512_cmpeq_epi8_mask, _mm512_loadu_si512,
298 };
299 let n = text.len();
300 let lim_p = n.saturating_sub(p).min(max_ctx);
301 let lim_q = n.saturating_sub(q).min(max_ctx);
302 let lim = lim_p.min(lim_q);
303 let ptr = text.as_ptr();
304
305 let mut i = 0usize;
306 // 32-byte head: AVX2 compare. If it resolves the LCP we never touch
307 // a ZMM register.
308 if i + 32 <= lim {
309 // SAFETY: AVX2 is implied by AVX-512F; bounds checked above.
310 let va = unsafe { _mm256_loadu_si256(ptr.add(p + i) as *const __m256i) };
311 let vb = unsafe { _mm256_loadu_si256(ptr.add(q + i) as *const __m256i) };
312 let eq = _mm256_cmpeq_epi8(va, vb);
313 let mask = _mm256_movemask_epi8(eq) as u32;
314 if mask != u32::MAX {
315 return i + (!mask).trailing_zeros() as usize;
316 }
317 i += 32;
318 }
319 // 64-byte body: LCP exceeds 32 bytes, run at AVX-512 stride.
320 //
321 // We tried software prefetching (`_mm_prefetch::<_MM_HINT_T0>(ptr+256)`)
322 // inside this loop on the theory that overlapping the next stride's
323 // memory fetch with current-iteration execution would shave the
324 // dominant phase 1 wall. It did not: hum200m, rand100m, and the
325 // full GRCh38 32t bench all showed indistinguishable wall vs the
326 // bare loop. The Zen 5 hardware prefetcher recognises the strided
327 // pattern and is already issuing the same loads, so the explicit
328 // hint just adds inner-loop instructions for no benefit. Left bare
329 // for clarity.
330 while i + 64 <= lim {
331 // SAFETY: bounds ensured by the loop condition; unaligned loads.
332 let va = unsafe { _mm512_loadu_si512(ptr.add(p + i) as *const __m512i) };
333 let vb = unsafe { _mm512_loadu_si512(ptr.add(q + i) as *const __m512i) };
334 let mask = _mm512_cmpeq_epi8_mask(va, vb);
335 if mask != u64::MAX {
336 return i + (!mask).trailing_zeros() as usize;
337 }
338 i += 64;
339 }
340 // 32-byte tail: covers the case where the residue past the 64-byte
341 // loop is between 32 and 63 bytes.
342 if i + 32 <= lim {
343 // SAFETY: bounds checked above.
344 let va = unsafe { _mm256_loadu_si256(ptr.add(p + i) as *const __m256i) };
345 let vb = unsafe { _mm256_loadu_si256(ptr.add(q + i) as *const __m256i) };
346 let eq = _mm256_cmpeq_epi8(va, vb);
347 let mask = _mm256_movemask_epi8(eq) as u32;
348 if mask != u32::MAX {
349 return i + (!mask).trailing_zeros() as usize;
350 }
351 i += 32;
352 }
353 while i < lim {
354 if text[p + i] != text[q + i] {
355 return i;
356 }
357 i += 1;
358 }
359 i
360}
361
/// AVX2 kernel: byte-level LCP of `text[p..]` and `text[q..]`, capped
/// at `max_ctx` bytes and at the end of `text`.
///
/// Compares 32 bytes at a time. `_mm256_movemask_epi8` turns the
/// per-byte equality result into a 32-bit mask, and the number of
/// trailing zeros of the inverted mask is the index of the first
/// differing byte. Whatever is left after the last full 32-byte chunk
/// is compared byte by byte.
///
/// # Safety
///
/// The CPU must support the `avx2` feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lcp_bytes_avx2(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    use std::arch::x86_64::{__m256i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8};
    let len = text.len();
    // Number of bytes that may be compared; zero if either offset is
    // out of range.
    let lim = len
        .saturating_sub(p)
        .min(len.saturating_sub(q))
        .min(max_ctx);
    let base = text.as_ptr();

    // Bytes already known to match; never exceeds `lim`.
    let mut done = 0usize;
    while lim - done >= 32 {
        // SAFETY: at least 32 bytes remain within the limit at both
        // offsets; loads are unaligned.
        let lhs = unsafe { _mm256_loadu_si256(base.add(p + done).cast::<__m256i>()) };
        let rhs = unsafe { _mm256_loadu_si256(base.add(q + done).cast::<__m256i>()) };
        let differing = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(lhs, rhs)) as u32);
        if differing != 0 {
            return done + differing.trailing_zeros() as usize;
        }
        done += 32;
    }

    // Fewer than 32 bytes left: finish byte by byte.
    for off in done..lim {
        if text[p + off] != text[q + off] {
            return off;
        }
    }
    lim
}
394
/// NEON kernel: byte-level LCP of `text[p..]` and `text[q..]`, capped
/// at `max_ctx` bytes and at the end of `text`.
///
/// Compares 16 bytes at a time. NEON has no movemask, so it is
/// emulated with the "shrn by 4" trick: each `vceqq_u8` result byte
/// (`0xFF` or `0x00`) is narrowed to 4 bits, packing the 16 results
/// into one 64-bit lane. The trailing-zero count of the inverted mask,
/// divided by 4, is then the index of the first differing byte.
/// Whatever is left after the last full 16-byte chunk is compared byte
/// by byte.
///
/// # Safety
///
/// The CPU must support the `neon` feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn lcp_bytes_neon(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    use std::arch::aarch64::{
        vceqq_u8, vget_lane_u64, vld1q_u8, vreinterpret_u64_u8, vreinterpretq_u16_u8, vshrn_n_u16,
    };
    let len = text.len();
    // Number of bytes that may be compared; zero if either offset is
    // out of range.
    let lim = len
        .saturating_sub(p)
        .min(len.saturating_sub(q))
        .min(max_ctx);
    let base = text.as_ptr();

    // Bytes already known to match; never exceeds `lim`.
    let mut done = 0usize;
    while lim - done >= 16 {
        // SAFETY: at least 16 bytes remain within the limit at both
        // offsets; loads are unaligned.
        let lhs = unsafe { vld1q_u8(base.add(p + done)) };
        let rhs = unsafe { vld1q_u8(base.add(q + done)) };
        let eq_bytes = vceqq_u8(lhs, rhs);
        // 4 mask bits per input byte, all 16 bytes in one u64.
        let nibbles = vshrn_n_u16::<4>(vreinterpretq_u16_u8(eq_bytes));
        let differing = !vget_lane_u64::<0>(vreinterpret_u64_u8(nibbles));
        if differing != 0 {
            return done + (differing.trailing_zeros() as usize / 4);
        }
        done += 16;
    }

    // Fewer than 16 bytes left: finish byte by byte.
    for off in done..lim {
        if text[p + off] != text[q + off] {
            return off;
        }
    }
    lim
}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// The LCP stops at the first differing symbol, or at the end of
    /// the shorter suffix.
    #[test]
    fn lcp_matches_to_first_difference() {
        let text = b"banana";
        // suffix at 0: "banana", at 1: "anana". LCP = 0 (b vs a).
        assert_eq!(lcp(text, 0, 1, usize::MAX), 0);
        // suffix at 1: "anana", at 3: "ana". LCP = 3 ("ana"), then diff
        // (n vs end).
        assert_eq!(lcp(text, 1, 3, usize::MAX), 3);
    }

    /// `max_ctx` caps the result even when the suffixes keep matching.
    #[test]
    fn lcp_respects_max_ctx() {
        let text = b"aaaaaa";
        assert_eq!(lcp(text, 0, 1, 3), 3);
    }

    /// The comparison never runs past the end of the text.
    #[test]
    fn lcp_stops_at_text_end() {
        let text = b"abc";
        // suffix at 0: "abc", at 2: "c". LCP=0.
        assert_eq!(lcp(text, 0, 2, usize::MAX), 0);
        // suffix at 1: "bc", at 1: "bc". LCP=2.
        assert_eq!(lcp(text, 1, 1, usize::MAX), 2);
    }

    /// Suffix comparison is lexicographic, with a proper prefix
    /// ordering before the longer suffix.
    #[test]
    fn cmp_lex_order() {
        let text = b"banana";
        // "anana" < "banana"
        assert_eq!(suffix_cmp(text, 1, 0, usize::MAX), Ordering::Less);
        // "ana" < "anana" (prefix)
        assert_eq!(suffix_cmp(text, 3, 1, usize::MAX), Ordering::Less);
        // self-equal
        assert_eq!(suffix_cmp(text, 1, 1, usize::MAX), Ordering::Equal);
    }

    /// SIMD vs scalar agreement across pathological positions: long runs
    /// of identical bytes (exercises full-vector equal branches),
    /// 32-byte and 16-byte boundary differences (covers AVX2 and NEON
    /// chunk sizes), and unaligned tail bytes.
    #[test]
    fn simd_matches_scalar_on_u8() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xA5A5);

        for diff_at in [0usize, 1, 31, 32, 33, 63, 64, 65, 100] {
            // Build a 400-byte text of all A's with a single C planted
            // at `diff_at` (always inside the first 200 bytes). Suffix
            // 200 is therefore all A's, while suffix 0 has its C at
            // offset `diff_at`, so their LCP is exactly `diff_at`.
            let mut combined = vec![b'A'; 400];
            combined[diff_at] = b'C';
            let got = lcp(&combined, 0, 200, usize::MAX);
            assert_eq!(got, diff_at, "wrong LCP at diff_at={diff_at}");
        }

        // Random 4-letter texts: the dispatched path must equal scalar.
        for &n in &[1usize, 32, 33, 200, 1000] {
            let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..4u8)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "p={p} q={q} text={text:?}");
            }
        }
    }

    /// Verify the cached-dispatch path matches both scalar and the
    /// one-off helper on a tricky case spanning AVX2/NEON vector
    /// boundaries.
    #[test]
    fn dispatch_struct_matches_oneoff_and_scalar() {
        let scalar = LcpDispatch::scalar();
        let detected = LcpDispatch::detect();
        let mut text: Vec<u8> = vec![b'A'; 200];
        text[64] = b'T'; // diff right after the second 32-byte AVX2 chunk
        assert_eq!(scalar.lcp(&text, 0, 100, usize::MAX), 64);
        assert_eq!(detected.lcp(&text, 0, 100, usize::MAX), 64);
        // The free-standing one-off helper must agree as well.
        assert_eq!(lcp(&text, 0, 100, usize::MAX), 64);
    }

    /// Exercise the AVX-512 32-byte head, 64-byte stride and 32-byte
    /// tail: plant the differing byte at offsets that straddle each
    /// chunk boundary (0/1, 31/32/33, 63/64/65, 95/96/97, 127/128, plus
    /// 200 deep inside the body) and confirm the dispatched path
    /// returns exactly the planted offset. On a non-AVX-512 host this
    /// devolves to the other SIMD paths but still verifies correctness.
    #[test]
    fn avx512_boundary_agreement() {
        let detected = LcpDispatch::detect();
        for diff_at in [0usize, 1, 31, 32, 33, 63, 64, 65, 95, 96, 97, 127, 128, 200] {
            let mut text = vec![b'A'; 512];
            text[diff_at] = b'G';
            let got = detected.lcp(&text, 0, 256, usize::MAX);
            assert_eq!(got, diff_at, "diff_at={diff_at}");
        }
    }

    /// SIMD-vs-scalar agreement for `u16` text. The byte-view dispatch
    /// must return symbol-LCPs that match the scalar walk over `&[u16]`.
    /// Covers the case where the first differing byte lands inside a
    /// symbol whose previous bytes were equal (e.g. low byte of u16
    /// equal, high byte differs).
    #[test]
    fn simd_matches_scalar_on_u16() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1357);

        // Place a difference at every interesting byte boundary within
        // and across symbols; the symbol-LCP should equal byte_diff/2.
        let mut text = vec![0u16; 256];
        for byte_diff_at in [0usize, 1, 2, 3, 31, 32, 33, 63, 64, 65, 127, 128, 200] {
            text.iter_mut().for_each(|x| *x = 0xAAAA);
            // Zero out one whole byte of `text[byte_diff_at / 2]`: the
            // low byte for even `byte_diff_at`, the high byte for odd.
            // Either way exactly that symbol changes.
            let sym = byte_diff_at / 2;
            let mask = if byte_diff_at % 2 == 0 {
                0x00FF
            } else {
                0xFF00
            };
            text[sym] ^= mask & 0xAAAA; // 0xAA ^ 0xAA == 0x00 in the selected byte
            let got = lcp(&text, 0, 128, usize::MAX);
            // The first differing symbol is `byte_diff_at / 2`.
            assert_eq!(
                got, sym,
                "byte_diff_at={byte_diff_at}, expected symbol {sym}"
            );
        }

        // Random u16 texts: dispatched path must equal scalar.
        for &n in &[1usize, 16, 17, 100, 500] {
            let text: Vec<u16> = (0..n).map(|_| rng.random_range(0..16u16)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "u16 p={p} q={q} n={n}");
            }
        }
    }

    /// Same agreement check at `u32` granularity (4-byte symbols).
    #[test]
    fn simd_matches_scalar_on_u32() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x2468);

        for &n in &[1usize, 8, 9, 100, 500] {
            let text: Vec<u32> = (0..n).map(|_| rng.random_range(0..32u32)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "u32 p={p} q={q} n={n}");
            }
        }
    }

    /// 24-bit alphabet via `[u8; 3]`: every symbol-LCP must equal the
    /// scalar walk's. Exercises a symbol width that doesn't divide any
    /// SIMD chunk evenly.
    #[test]
    fn simd_matches_scalar_on_u8_3() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xACAC);

        for &n in &[1usize, 8, 22, 100, 333] {
            let text: Vec<[u8; 3]> = (0..n)
                .map(|_| {
                    [
                        rng.random_range(0..4u8),
                        rng.random_range(0..4u8),
                        rng.random_range(0..4u8),
                    ]
                })
                .collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "[u8;3] p={p} q={q} n={n}");
            }
        }
    }

    /// `u64` agreement — 8-byte symbols.
    #[test]
    fn simd_matches_scalar_on_u64() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xFEED);

        for &n in &[1usize, 4, 5, 50, 250] {
            let text: Vec<u64> = (0..n).map(|_| rng.random_range(0..64u64)).collect();
            for _ in 0..20 {
                let p = rng.random_range(0..n);
                let q = rng.random_range(0..n);
                let want = lcp_scalar(&text, p, q, usize::MAX);
                let got = lcp(&text, p, q, usize::MAX);
                assert_eq!(got, want, "u64 p={p} q={q} n={n}");
            }
        }
    }
}