use std::cmp::Ordering;
use crate::limits::{LimitProvider, PlainText};
/// Marker trait for element types that can serve as text symbols for the LCP
/// and suffix-comparison routines in this module.
///
/// # Safety
///
/// [`LcpDispatch::lcp`] views a `&[S]` as a plain byte slice and derives the
/// symbol-level LCP from a byte-wise comparison. Implementors must therefore
/// guarantee that:
///
/// * every byte of a value is initialized (the type has no padding bytes), so
///   reading the slice as `&[u8]` is sound; and
/// * two values are equal under `==` exactly when their in-memory bytes are
///   identical.
pub unsafe trait Symbol: Ord + Copy + Send + Sync + 'static {}
/// Emits an empty `unsafe impl Symbol` for each type in a comma-separated
/// list; a trailing comma is accepted.
macro_rules! impl_symbol_for_primitives {
    ($($prim:ty),* $(,)?) => {
        $(
            // SAFETY: the macro is only invoked with primitive integers, which
            // have no padding bytes and whose `==` is bit-for-bit equality.
            unsafe impl Symbol for $prim {}
        )*
    };
}

// Unsigned integers first, then their signed counterparts.
impl_symbol_for_primitives!(
    u8, u16, u32, u64, u128, usize,
    i8, i16, i32, i64, i128, isize,
);
// SAFETY: an array `[T; N]` is laid out as `N` consecutive `T` values with no
// additional padding, and array equality is element-wise, so the byte-level
// guarantees required of a `Symbol` carry over from `T`. For `N == 0` the
// array is zero-sized; `LcpDispatch::lcp` handles size-0 symbols separately.
unsafe impl<T: Symbol, const N: usize> Symbol for [T; N] {}
/// Handle that remembers which byte-level LCP kernel to call.
///
/// Obtain one from [`LcpDispatch::detect`] or [`LcpDispatch::scalar`]. The
/// handle holds a single function pointer, so it is cheap to copy and to keep
/// around for repeated queries.
#[derive(Copy, Clone, Debug)]
pub struct LcpDispatch {
    /// Kernel that [`LcpDispatch::lcp`] invokes on the byte view of the text.
    lcp_bytes_fn: LcpBytesFn,
}

/// Signature shared by every byte-level kernel: `(text, p, q, max_ctx)`, with
/// both offsets and the cap expressed in bytes, returning the common prefix
/// length in bytes. The pointer is `unsafe` because the SIMD kernels are
/// `#[target_feature]` functions that must only run on CPUs providing those
/// features.
type LcpBytesFn = unsafe fn(&[u8], usize, usize, usize) -> usize;
impl LcpDispatch {
pub fn detect() -> Self {
Self {
lcp_bytes_fn: pick_lcp_bytes_impl(),
}
}
pub fn scalar() -> Self {
Self {
lcp_bytes_fn: lcp_bytes_scalar,
}
}
#[inline]
pub fn lcp<S: Symbol>(&self, text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
let k = std::mem::size_of::<S>();
if k == 0 {
let lim_p = text.len().saturating_sub(p).min(max_ctx);
let lim_q = text.len().saturating_sub(q).min(max_ctx);
return lim_p.min(lim_q);
}
let bytes =
unsafe { std::slice::from_raw_parts(text.as_ptr() as *const u8, size_of_val(text)) };
let byte_lcp = unsafe {
(self.lcp_bytes_fn)(
bytes,
p.saturating_mul(k),
q.saturating_mul(k),
max_ctx.saturating_mul(k),
)
};
byte_lcp / k
}
#[inline]
pub fn suffix_cmp<S: Symbol>(
&self,
text: &[S],
p: usize,
q: usize,
max_ctx: usize,
) -> Ordering {
self.suffix_cmp_with(text, &PlainText::new(text.len()), p, q, max_ctx)
}
#[inline]
pub fn suffix_cmp_with<S: Symbol, L: LimitProvider>(
&self,
text: &[S],
lp: &L,
p: usize,
q: usize,
max_ctx: usize,
) -> Ordering {
let lim_p = lp.lim_at(p);
let lim_q = lp.lim_at(q);
let lim = lim_p.min(lim_q).min(max_ctx);
let common = self.lcp(text, p, q, lim);
if common < lim {
text[p + common].cmp(&text[q + common])
} else {
lp.boundary_order(p, lim_p, q, lim_q)
}
}
}
/// One-shot convenience wrapper around [`LcpDispatch::lcp`].
///
/// A fresh dispatcher is built with [`LcpDispatch::detect`] on every call;
/// callers issuing many queries can build one `LcpDispatch` and reuse it.
#[inline]
pub fn lcp<S: Symbol>(text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
    let dispatch = LcpDispatch::detect();
    dispatch.lcp(text, p, q, max_ctx)
}
/// One-shot convenience wrapper around [`LcpDispatch::suffix_cmp`].
///
/// A fresh dispatcher is built with [`LcpDispatch::detect`] on every call;
/// callers issuing many comparisons can build one `LcpDispatch` and reuse it.
#[inline]
pub fn suffix_cmp<S: Symbol>(text: &[S], p: usize, q: usize, max_ctx: usize) -> Ordering {
    let dispatch = LcpDispatch::detect();
    dispatch.suffix_cmp(text, p, q, max_ctx)
}
/// Byte-slice LCP that calls the kernel returned by `pick_lcp_bytes_impl`
/// directly, without constructing an [`LcpDispatch`].
///
/// Returns how many leading bytes `text[p..]` and `text[q..]` share, capped
/// at `max_ctx`.
#[inline]
pub fn lcp_u8(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    // SAFETY: `pick_lcp_bytes_impl` only returns a SIMD kernel after runtime
    // detection of the CPU features that kernel needs.
    unsafe { pick_lcp_bytes_impl()(text, p, q, max_ctx) }
}
/// Portable reference LCP: counts the leading elements shared by `text[p..]`
/// and `text[q..]`, stopping at `max_ctx` elements or at the end of `text`.
///
/// Offsets at or beyond `text.len()` give `0`; the function never panics.
#[inline]
pub fn lcp_scalar<S: Eq>(text: &[S], p: usize, q: usize, max_ctx: usize) -> usize {
    // `get` returns `None` for a start offset past the end, where the
    // comparable length is zero.
    let (Some(tail_p), Some(tail_q)) = (text.get(p..), text.get(q..)) else {
        return 0;
    };
    let limit = tail_p.len().min(tail_q.len()).min(max_ctx);
    tail_p
        .iter()
        .zip(tail_q)
        .take(limit)
        .position(|(a, b)| a != b)
        .unwrap_or(limit)
}
/// Chooses the byte-level LCP kernel for the running CPU.
///
/// On x86_64 the AVX-512 kernel is chosen when both `avx512f` and `avx512bw`
/// are detected, otherwise the AVX2 kernel when `avx2` is detected. On
/// aarch64 the NEON kernel is chosen when `neon` is detected. In every other
/// case the scalar kernel is returned.
fn pick_lcp_bytes_impl() -> LcpBytesFn {
    #[cfg(target_arch = "x86_64")]
    {
        let has_avx512 = std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw");
        if has_avx512 {
            return lcp_bytes_avx512;
        }

        let has_avx2 = std::is_x86_feature_detected!("avx2");
        if has_avx2 {
            return lcp_bytes_avx2;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        let has_neon = std::arch::is_aarch64_feature_detected!("neon");
        if has_neon {
            return lcp_bytes_neon;
        }
    }

    lcp_bytes_scalar
}
/// Scalar kernel with the `LcpBytesFn` signature; delegates to `lcp_scalar`
/// over bytes.
///
/// # Safety
///
/// No requirements beyond those of a safe function; it is `unsafe` only so
/// that it coerces to the same function-pointer type as the SIMD kernels.
unsafe fn lcp_bytes_scalar(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    lcp_scalar::<u8>(text, p, q, max_ctx)
}
/// Byte-wise LCP of `text[p..]` and `text[q..]`, capped at `max_ctx`, using
/// 256-bit and 512-bit vector compares.
///
/// The comparable length is `min(len - p, len - q, max_ctx)`, where an offset
/// past the end contributes 0. Within that length the function compares one
/// 32-byte block, then 64-byte blocks, then at most one more 32-byte block,
/// and finally single bytes.
///
/// # Safety
///
/// The executing CPU must support the `avx512f` and `avx512bw` features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
unsafe fn lcp_bytes_avx512(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
use std::arch::x86_64::{
__m256i, __m512i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8,
_mm512_cmpeq_epi8_mask, _mm512_loadu_si512,
};
let n = text.len();
// Bytes available from each start offset, clamped to the caller's cap.
let lim_p = n.saturating_sub(p).min(max_ctx);
let lim_q = n.saturating_sub(q).min(max_ctx);
let lim = lim_p.min(lim_q);
let ptr = text.as_ptr();
let mut i = 0usize;
// A single 32-byte compare runs before the 64-byte loop.
// NOTE(review): presumably so that early mismatches are resolved with the
// narrower compare — confirm the intent with the author.
if i + 32 <= lim {
// SAFETY: `i + 32 <= lim` with `lim <= n - p` and `lim <= n - q`, so both
// 32-byte loads stay inside `text`; `loadu` has no alignment requirement.
let va = unsafe { _mm256_loadu_si256(ptr.add(p + i) as *const __m256i) };
let vb = unsafe { _mm256_loadu_si256(ptr.add(q + i) as *const __m256i) };
let eq = _mm256_cmpeq_epi8(va, vb);
// One bit per byte, set where the two bytes are equal.
let mask = _mm256_movemask_epi8(eq) as u32;
if mask != u32::MAX {
// The lowest clear bit marks the first mismatching byte.
return i + (!mask).trailing_zeros() as usize;
}
i += 32;
}
// Main loop: 64 bytes per iteration.
while i + 64 <= lim {
// SAFETY: `i + 64 <= lim`, so both 64-byte unaligned loads are in bounds.
let va = unsafe { _mm512_loadu_si512(ptr.add(p + i) as *const __m512i) };
let vb = unsafe { _mm512_loadu_si512(ptr.add(q + i) as *const __m512i) };
// 64-bit mask, one bit per byte, set where the bytes are equal.
let mask = _mm512_cmpeq_epi8_mask(va, vb);
if mask != u64::MAX {
return i + (!mask).trailing_zeros() as usize;
}
i += 64;
}
// Fewer than 64 bytes remain here; handle one more 32-byte block if it fits.
if i + 32 <= lim {
// SAFETY: `i + 32 <= lim`, so both 32-byte unaligned loads are in bounds.
let va = unsafe { _mm256_loadu_si256(ptr.add(p + i) as *const __m256i) };
let vb = unsafe { _mm256_loadu_si256(ptr.add(q + i) as *const __m256i) };
let eq = _mm256_cmpeq_epi8(va, vb);
let mask = _mm256_movemask_epi8(eq) as u32;
if mask != u32::MAX {
return i + (!mask).trailing_zeros() as usize;
}
i += 32;
}
// Scalar tail for the remaining bytes (fewer than 32).
while i < lim {
if text[p + i] != text[q + i] {
return i;
}
i += 1;
}
i
}
/// Byte-wise LCP of `text[p..]` and `text[q..]`, capped at `max_ctx`, using
/// 32-byte AVX2 compares followed by a scalar tail.
///
/// Offsets at or beyond `text.len()` give `0`.
///
/// # Safety
///
/// The executing CPU must support the `avx2` feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lcp_bytes_avx2(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
    use std::arch::x86_64::{__m256i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8};

    /// Bytes covered by one 256-bit compare.
    const LANES: usize = 32;

    // A start offset past the end leaves nothing to compare.
    let (Some(tail_p), Some(tail_q)) = (text.get(p..), text.get(q..)) else {
        return 0;
    };
    let limit = tail_p.len().min(tail_q.len()).min(max_ctx);
    let (lhs, rhs) = (&tail_p[..limit], &tail_q[..limit]);

    let mut matched = 0usize;
    for (block_p, block_q) in lhs.chunks_exact(LANES).zip(rhs.chunks_exact(LANES)) {
        // SAFETY: `chunks_exact` yields slices of exactly `LANES` readable
        // bytes, and `loadu` has no alignment requirement.
        let va = unsafe { _mm256_loadu_si256(block_p.as_ptr() as *const __m256i) };
        let vb = unsafe { _mm256_loadu_si256(block_q.as_ptr() as *const __m256i) };
        // One bit per byte, set where the two bytes are equal.
        let equal_lanes = _mm256_movemask_epi8(_mm256_cmpeq_epi8(va, vb)) as u32;
        if equal_lanes != u32::MAX {
            // The run of low set bits is the number of matching leading bytes.
            return matched + equal_lanes.trailing_ones() as usize;
        }
        matched += LANES;
    }

    // Scalar tail: fewer than `LANES` bytes remain.
    let rest_p = &lhs[matched..];
    let rest_q = &rhs[matched..];
    let tail_match = rest_p
        .iter()
        .zip(rest_q)
        .position(|(a, b)| a != b)
        .unwrap_or(rest_p.len());
    matched + tail_match
}
/// Byte-wise LCP of `text[p..]` and `text[q..]`, capped at `max_ctx`, using
/// 16-byte NEON compares followed by a scalar tail.
///
/// The comparable length is `min(len - p, len - q, max_ctx)`, where an offset
/// past the end contributes 0.
///
/// # Safety
///
/// The executing CPU must support the `neon` feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn lcp_bytes_neon(text: &[u8], p: usize, q: usize, max_ctx: usize) -> usize {
use std::arch::aarch64::{
vceqq_u8, vget_lane_u64, vld1q_u8, vreinterpret_u64_u8, vreinterpretq_u16_u8, vshrn_n_u16,
};
let n = text.len();
// Bytes available from each start offset, clamped to the caller's cap.
let lim_p = n.saturating_sub(p).min(max_ctx);
let lim_q = n.saturating_sub(q).min(max_ctx);
let lim = lim_p.min(lim_q);
let ptr = text.as_ptr();
let mut i = 0usize;
while i + 16 <= lim {
// SAFETY: `i + 16 <= lim` with `lim <= n - p` and `lim <= n - q`, so both
// 16-byte loads stay inside `text`.
let va = unsafe { vld1q_u8(ptr.add(p + i)) };
let vb = unsafe { vld1q_u8(ptr.add(q + i)) };
// Each lane becomes 0xFF where the bytes are equal and 0x00 otherwise.
let eq = vceqq_u8(va, vb);
// Shifting every 16-bit lane right by 4 and narrowing to 8 bits keeps
// 4 bits per original byte, packing the 16 lanes into a 64-bit mask.
let narrow = vshrn_n_u16::<4>(vreinterpretq_u16_u8(eq));
let mask = vget_lane_u64::<0>(vreinterpret_u64_u8(narrow));
if mask != u64::MAX {
// Each byte owns 4 mask bits, so the first clear bit index divided
// by 4 is the index of the first mismatching byte.
return i + ((!mask).trailing_zeros() as usize / 4);
}
i += 16;
}
// Scalar tail for the remaining bytes (fewer than 16).
while i < lim {
if text[p + i] != text[q + i] {
return i;
}
i += 1;
}
i
}
// Unit tests for the LCP and suffix-comparison entry points. The randomized
// tests use a seeded `StdRng`, so every run sees the same inputs.
#[cfg(test)]
mod tests {
use super::*;
// LCP counts matching symbols up to the first difference.
#[test]
fn lcp_matches_to_first_difference() {
let text = b"banana";
assert_eq!(lcp(text, 0, 1, usize::MAX), 0);
assert_eq!(lcp(text, 1, 3, usize::MAX), 3);
}
// The result never exceeds `max_ctx`, even when more symbols match.
#[test]
fn lcp_respects_max_ctx() {
let text = b"aaaaaa";
assert_eq!(lcp(text, 0, 1, 3), 3);
}
// The comparison stops at the end of the text; comparing a suffix with
// itself yields its full remaining length.
#[test]
fn lcp_stops_at_text_end() {
let text = b"abc";
assert_eq!(lcp(text, 0, 2, usize::MAX), 0);
assert_eq!(lcp(text, 1, 1, usize::MAX), 2);
}
// `suffix_cmp` orders suffixes lexicographically: a proper prefix sorts
// first and a suffix equals itself.
#[test]
fn cmp_lex_order() {
let text = b"banana";
assert_eq!(suffix_cmp(text, 1, 0, usize::MAX), Ordering::Less);
assert_eq!(suffix_cmp(text, 3, 1, usize::MAX), Ordering::Less);
assert_eq!(suffix_cmp(text, 1, 1, usize::MAX), Ordering::Equal);
}
// Detected kernel versus `lcp_scalar` on bytes: first with a single
// mismatch planted around the 32- and 64-byte block boundaries, then on
// random small-alphabet texts with random offsets.
#[test]
fn simd_matches_scalar_on_u8() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xA5A5);
for diff_at in [0usize, 1, 31, 32, 33, 63, 64, 65, 100] {
let mut combined = vec![b'A'; 400];
combined[diff_at] = b'C';
let got = lcp(&combined, 0, 200, usize::MAX);
assert_eq!(got, diff_at, "wrong LCP at diff_at={diff_at}");
}
for &n in &[1usize, 32, 33, 200, 1000] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..4u8)).collect();
for _ in 0..20 {
let p = rng.random_range(0..n);
let q = rng.random_range(0..n);
let want = lcp_scalar(&text, p, q, usize::MAX);
let got = lcp(&text, p, q, usize::MAX);
assert_eq!(got, want, "p={p} q={q} text={text:?}");
}
}
}
// The scalar and the detected `LcpDispatch` must agree on a mismatch that
// sits exactly at byte 64.
#[test]
fn dispatch_struct_matches_oneoff_and_scalar() {
let scalar = LcpDispatch::scalar();
let detected = LcpDispatch::detect();
let mut text: Vec<u8> = vec![b'A'; 200];
text[64] = b'T'; assert_eq!(scalar.lcp(&text, 0, 100, usize::MAX), 64);
assert_eq!(detected.lcp(&text, 0, 100, usize::MAX), 64);
}
// Mismatch positions around every multiple of 32 up to 128, covering the
// transitions between the 32-byte and 64-byte steps of the AVX-512 kernel
// (whichever kernel `detect()` picks is what actually runs).
#[test]
fn avx512_boundary_agreement() {
let detected = LcpDispatch::detect();
for diff_at in [0usize, 1, 31, 32, 33, 63, 64, 65, 95, 96, 97, 127, 128, 200] {
let mut text = vec![b'A'; 512];
text[diff_at] = b'G';
let got = detected.lcp(&text, 0, 256, usize::MAX);
assert_eq!(got, diff_at, "diff_at={diff_at}");
}
}
// Two-byte symbols: a change in either byte of a symbol must be reported
// at that symbol's index, followed by a randomized comparison against
// `lcp_scalar`.
#[test]
fn simd_matches_scalar_on_u16() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0x1357);
let mut text = vec![0u16; 256];
for byte_diff_at in [0usize, 1, 2, 3, 31, 32, 33, 63, 64, 65, 127, 128, 200] {
// Reset the text, then alter one byte of symbol `byte_diff_at / 2`.
text.iter_mut().for_each(|x| *x = 0xAAAA);
let sym = byte_diff_at / 2;
// Even offsets change the low byte of the symbol, odd offsets the high byte.
let mask = if byte_diff_at % 2 == 0 {
0x00FF
} else {
0xFF00
};
text[sym] ^= mask & 0xAAAA; let got = lcp(&text, 0, 128, usize::MAX);
assert_eq!(
got, sym,
"byte_diff_at={byte_diff_at}, expected symbol {sym}"
);
}
for &n in &[1usize, 16, 17, 100, 500] {
let text: Vec<u16> = (0..n).map(|_| rng.random_range(0..16u16)).collect();
for _ in 0..20 {
let p = rng.random_range(0..n);
let q = rng.random_range(0..n);
let want = lcp_scalar(&text, p, q, usize::MAX);
let got = lcp(&text, p, q, usize::MAX);
assert_eq!(got, want, "u16 p={p} q={q} n={n}");
}
}
}
// Randomized comparison against `lcp_scalar` for four-byte symbols.
#[test]
fn simd_matches_scalar_on_u32() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0x2468);
for &n in &[1usize, 8, 9, 100, 500] {
let text: Vec<u32> = (0..n).map(|_| rng.random_range(0..32u32)).collect();
for _ in 0..20 {
let p = rng.random_range(0..n);
let q = rng.random_range(0..n);
let want = lcp_scalar(&text, p, q, usize::MAX);
let got = lcp(&text, p, q, usize::MAX);
assert_eq!(got, want, "u32 p={p} q={q} n={n}");
}
}
}
// Randomized comparison against `lcp_scalar` for `[u8; 3]` symbols, whose
// size is not a power of two.
#[test]
fn simd_matches_scalar_on_u8_3() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xACAC);
for &n in &[1usize, 8, 22, 100, 333] {
let text: Vec<[u8; 3]> = (0..n)
.map(|_| {
[
rng.random_range(0..4u8),
rng.random_range(0..4u8),
rng.random_range(0..4u8),
]
})
.collect();
for _ in 0..20 {
let p = rng.random_range(0..n);
let q = rng.random_range(0..n);
let want = lcp_scalar(&text, p, q, usize::MAX);
let got = lcp(&text, p, q, usize::MAX);
assert_eq!(got, want, "[u8;3] p={p} q={q} n={n}");
}
}
}
// Randomized comparison against `lcp_scalar` for eight-byte symbols.
#[test]
fn simd_matches_scalar_on_u64() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xFEED);
for &n in &[1usize, 4, 5, 50, 250] {
let text: Vec<u64> = (0..n).map(|_| rng.random_range(0..64u64)).collect();
for _ in 0..20 {
let p = rng.random_range(0..n);
let q = rng.random_range(0..n);
let want = lcp_scalar(&text, p, q, usize::MAX);
let got = lcp(&text, p, q, usize::MAX);
assert_eq!(got, want, "u64 p={p} q={q} n={n}");
}
}
}
}