use alloc::vec::Vec;
use core::mem::MaybeUninit;
use memchr::memchr2;
use crate::{
MATCH_MAX_LEN,
SCORE_MAX,
SCORE_MIN,
Score,
bonus::compute_bonus,
config::{GapConfig, SCORE_MATCH_CONSECUTIVE},
};
/// Returns `true` if every byte of `needle` occurs in `haystack` in order
/// (not necessarily contiguously), comparing ASCII letters case-insensitively.
///
/// Non-letter bytes (punctuation, digits, control and non-ASCII bytes) match
/// only themselves, consistent with the lowercasing used by the scorer. An
/// empty needle always matches.
#[inline]
pub fn has_match(needle: &[u8], haystack: &[u8]) -> bool {
    let mut h = haystack;
    for &nc in needle {
        // `to_ascii_*` only changes ASCII letters; flipping bit 0x20
        // unconditionally would pair e.g. `[` with `{` or ` ` with NUL.
        let lo = nc.to_ascii_lowercase();
        let hi = nc.to_ascii_uppercase();
        match memchr2(lo, hi, h) {
            Some(pos) => h = &h[pos + 1..],
            None => return false,
        }
    }
    true
}
/// Per-call scratch data for the scoring routines in this module.
///
/// Filled in by `setup`. Only the first `needle_len` entries of
/// `lower_needle` are initialised, and only the first `haystack_len` entries
/// of `lower_haystack` and `match_bonus`; the rest stay uninitialised.
struct MatchStruct {
    /// Number of bytes in the needle (never more than `haystack_len`).
    needle_len: usize,
    /// Number of bytes in the haystack (at most `MATCH_MAX_LEN`).
    haystack_len: usize,
    /// Needle bytes with ASCII uppercase letters lowercased.
    lower_needle: [MaybeUninit<u8>; MATCH_MAX_LEN],
    /// Haystack bytes with ASCII uppercase letters lowercased.
    lower_haystack: [MaybeUninit<u8>; MATCH_MAX_LEN],
    /// Bonus for matching at each haystack position, as computed by
    /// `compute_bonus(previous byte, current byte)`.
    match_bonus: [MaybeUninit<f64>; MATCH_MAX_LEN],
}
/// Lowercases an ASCII uppercase letter; every other byte is returned as-is.
#[inline(always)]
fn to_lower(b: u8) -> u8 {
    b.to_ascii_lowercase()
}
/// Returns `true` when `a` and `b` have the same length and are equal after
/// lowercasing ASCII letters. Non-letter bytes must match exactly.
#[inline(always)]
fn eq_ignore_ascii_case(a: &[u8], b: &[u8]) -> bool {
    <[u8]>::eq_ignore_ascii_case(a, b)
}
/// Fills `storage` with the scratch data the scorer needs and returns a
/// reference to it.
///
/// Lowercases `needle` and `haystack` into the fixed-size buffers. It also
/// records the bonus for each haystack byte, using `compute_bonus` with `/` as
/// the byte preceding the haystack start.
///
/// Returns `None`, leaving `storage` uninitialised, when the haystack is longer
/// than `MATCH_MAX_LEN` or the needle is longer than the haystack. On success,
/// only the first `needle_len` / `haystack_len` entries of the array fields
/// are initialised.
#[inline(always)]
fn setup<'a>(
    needle: &[u8],
    haystack: &[u8],
    storage: &'a mut MaybeUninit<MatchStruct>,
) -> Option<&'a MatchStruct> {
    let n = needle.len();
    let m = haystack.len();
    if m > MATCH_MAX_LEN || n > m {
        return None;
    }
    let p = storage.as_mut_ptr();
    // SAFETY: `p` comes from `storage` and is valid for writes. `addr_of_mut!`
    // writes the scalar fields without first creating a reference to the
    // still-uninitialised struct (which `MaybeUninit::as_mut_ptr` documents as UB).
    unsafe {
        core::ptr::addr_of_mut!((*p).needle_len).write(n);
        core::ptr::addr_of_mut!((*p).haystack_len).write(m);
    }
    // SAFETY: both scalar fields are now initialised. Every other field is an
    // array of `MaybeUninit`, which is valid in any state.
    let ms = unsafe { storage.assume_init_mut() };
    for (dst, &b) in ms.lower_needle.iter_mut().zip(needle) {
        *dst = MaybeUninit::new(to_lower(b));
    }
    for (dst, &b) in ms.lower_haystack.iter_mut().zip(haystack) {
        *dst = MaybeUninit::new(to_lower(b));
    }
    // Treat the haystack as if it were preceded by a path separator.
    let mut prev = b'/';
    for (dst, &ch) in ms.match_bonus.iter_mut().zip(haystack) {
        *dst = MaybeUninit::new(compute_bonus(prev, ch));
        prev = ch;
    }
    Some(ms)
}
/// Computes one row of the dynamic-programming tables for needle byte `row`.
///
/// For each haystack column `j` it writes:
/// * `curr_d[j]`: the score of alignments in which needle byte `row` is
///   matched exactly at haystack byte `j`, or `SCORE_MIN` if the bytes differ
///   (or `row > 0 && j == 0`).
/// * `curr_m[j]`: the best score for needle bytes `0..=row` within haystack
///   bytes `0..=j`, adding `gap_score` for every haystack byte skipped after
///   a match.
///
/// `last_d` / `last_m` hold the previous row. They are ignored when
/// `row == 0`. Each `last_*[j]` is read before `curr_*[j]` is written, so the
/// current row may alias the previous one (in-place update).
///
/// # Safety
///
/// * `row < ms.needle_len`.
/// * `curr_d` and `curr_m` must be valid for writes of `ms.haystack_len` `f64`s.
/// * When `row > 0`, `last_d` and `last_m` must be valid for reads of
///   `ms.haystack_len` initialised `f64`s. When `row == 0` they are never read.
/// * Entry `row` of `ms.lower_needle` and entries `0..ms.haystack_len` of
///   `ms.lower_haystack` and `ms.match_bonus` must be initialised.
#[inline(always)]
unsafe fn match_row(
    ms: &MatchStruct,
    row: usize,
    curr_d: *mut f64,
    curr_m: *mut f64,
    last_d: *const f64,
    last_m: *const f64,
    gap: GapConfig,
) {
    let n = ms.needle_len;
    let m = ms.haystack_len;
    let i = row;
    let gap_score = if i == n - 1 { gap.trailing } else { gap.inner };
    // SAFETY: the caller upholds the contract documented above.
    unsafe {
        let ni = ms.lower_needle[i].assume_init();
        let mut prev_score = SCORE_MIN;
        // Previous row's D/M at column `j - 1`; only consulted when `i > 0`.
        let mut prev_d = SCORE_MIN;
        let mut prev_m = SCORE_MIN;
        for j in 0..m {
            // Read the cell above before `curr_*[j]` is written, because the rows
            // may alias. Row 0 has no predecessor and its `last_*` buffers may
            // be uninitialised, so they are never touched there.
            let (above_d, above_m) = if i > 0 {
                (*last_d.add(j), *last_m.add(j))
            } else {
                (SCORE_MIN, SCORE_MIN)
            };
            let hj = ms.lower_haystack[j].assume_init();
            if ni == hj {
                let score = if i == 0 {
                    (j as f64) * gap.leading + ms.match_bonus[j].assume_init()
                } else if j > 0 {
                    f64::max(
                        prev_m + ms.match_bonus[j].assume_init(),
                        // A consecutive match does not also collect the bonus.
                        prev_d + SCORE_MATCH_CONSECUTIVE,
                    )
                } else {
                    SCORE_MIN
                };
                *curr_d.add(j) = score;
                prev_score = f64::max(score, prev_score + gap_score);
            } else {
                *curr_d.add(j) = SCORE_MIN;
                prev_score += gap_score;
            }
            *curr_m.add(j) = prev_score;
            prev_d = above_d;
            prev_m = above_m;
        }
    }
}
/// Scores `needle` against `haystack` using the gap penalties in `gap`.
///
/// Returns `SCORE_MIN` when the needle is empty, when the haystack is longer
/// than `MATCH_MAX_LEN`, or when the needle is longer than the haystack. Returns
/// `SCORE_MAX` when the two are equal ignoring ASCII case. Otherwise it returns
/// the last cell of the final `M` row.
///
/// Only one row of each DP table is kept. It lives on the stack and is updated
/// in place row after row.
///
/// NOTE(review): this does not itself verify that `needle` is a subsequence of
/// `haystack`; callers presumably filter with `has_match` first — confirm.
pub(crate) fn score_raw(needle: &[u8], haystack: &[u8], gap: GapConfig) -> Score {
    if needle.is_empty() {
        return SCORE_MIN;
    }
    let mut storage = MaybeUninit::<MatchStruct>::uninit();
    let Some(ms) = setup(needle, haystack, &mut storage) else {
        return SCORE_MIN;
    };
    let n = ms.needle_len;
    let m = ms.haystack_len;
    if n == m && eq_ignore_ascii_case(needle, haystack) {
        return SCORE_MAX;
    }
    let mut d = [MaybeUninit::<f64>::uninit(); MATCH_MAX_LEN];
    let mut m_row = [MaybeUninit::<f64>::uninit(); MATCH_MAX_LEN];
    // The same buffers act as both the previous and the current row, so the
    // first row's "previous" is whatever they already hold. Seed the `m` cells
    // `match_row` touches so that is never uninitialised memory.
    d[..m].fill(MaybeUninit::new(SCORE_MIN));
    m_row[..m].fill(MaybeUninit::new(SCORE_MIN));
    let dp = d.as_mut_ptr().cast::<f64>();
    let mp = m_row.as_mut_ptr().cast::<f64>();
    // SAFETY: the empty-needle check and `setup` give 1 <= n <= m <= MATCH_MAX_LEN.
    // Both buffers therefore hold at least `m` cells, all initialised (seeded
    // above, then rewritten by every `match_row` call). `match_row` reads each
    // previous-row cell before overwriting it, so aliasing the rows is fine.
    unsafe {
        for i in 0..n {
            match_row(ms, i, dp, mp, dp, mp, gap);
        }
        *mp.add(m - 1)
    }
}
/// Like `score_raw`, but also records in `positions[i]` the haystack index
/// matched by needle byte `i` along one optimal alignment.
///
/// Keeps the full `n × m` `D` and `M` tables on the heap, with row stride `m`,
/// so the alignment can be traced back afterwards.
///
/// # Panics
///
/// Panics if `positions.len() < needle.len()`, unless the function returns
/// early with `SCORE_MIN` (empty needle, haystack too long, or needle longer
/// than haystack).
pub(crate) fn score_positions_raw(
    needle: &[u8],
    haystack: &[u8],
    positions: &mut [usize],
    gap: GapConfig,
) -> Score {
    if needle.is_empty() {
        return SCORE_MIN;
    }
    let mut storage = MaybeUninit::<MatchStruct>::uninit();
    let Some(ms) = setup(needle, haystack, &mut storage) else {
        return SCORE_MIN;
    };
    let n = ms.needle_len;
    let m = ms.haystack_len;
    if n == m && eq_ignore_ascii_case(needle, haystack) {
        for (i, p) in positions[..n].iter_mut().enumerate() {
            *p = i;
        }
        return SCORE_MAX;
    }
    // Size the tables by the actual haystack length rather than MATCH_MAX_LEN.
    // Pre-filling means no cell can ever be read uninitialised; `match_row`
    // then overwrites every cell of every row anyway.
    let mut d_mat: Vec<f64> = alloc::vec![SCORE_MIN; n * m];
    let mut m_mat: Vec<f64> = alloc::vec![SCORE_MIN; n * m];
    let d_base = d_mat.as_mut_ptr();
    let m_base = m_mat.as_mut_ptr();
    // SAFETY: each table holds `n * m` initialised f64s, so for every `i < n`
    // the row pointer `base + i * m` is valid for `m` reads and writes. Row 0
    // is passed as its own "previous" row, which `match_row` supports.
    unsafe {
        match_row(ms, 0, d_base, m_base, d_base, m_base, gap);
        for i in 1..n {
            let last_d = d_base.add((i - 1) * m) as *const f64;
            let last_m = m_base.add((i - 1) * m) as *const f64;
            let curr_d = d_base.add(i * m);
            let curr_m = m_base.add(i * m);
            match_row(ms, i, curr_d, curr_m, last_d, last_m, gap);
        }
    }
    trace_positions(&d_mat, &m_mat, n, m, positions);
    m_mat[(n - 1) * m + (m - 1)]
}

/// Walks filled `D`/`M` tables (row-major, stride `m`, `n` rows) backwards
/// from the last cell. Writes the chosen haystack column for each needle byte
/// into `positions[..n]`.
///
/// Each row is scanned right-to-left and the first qualifying cell is taken,
/// so ties between equally good alignments resolve toward later haystack
/// positions.
fn trace_positions(d: &[f64], mm: &[f64], n: usize, m: usize, positions: &mut [usize]) {
    let mut match_required = false;
    // Exclusive upper bound of the haystack columns still available. Each
    // needle byte must match strictly left of the byte after it.
    let mut col_end = m;
    for i in (0..n).rev() {
        while col_end > 0 {
            col_end -= 1;
            let j = col_end;
            let di = d[i * m + j];
            let mi = mm[i * m + j];
            if di != SCORE_MIN && (match_required || di == mi) {
                // If this score was reached via SCORE_MATCH_CONSECUTIVE, the
                // previous needle byte must be matched at column `j - 1`.
                match_required = i > 0
                    && j > 0
                    && mi == d[(i - 1) * m + (j - 1)] + SCORE_MATCH_CONSECUTIVE;
                positions[i] = j;
                break;
            }
        }
    }
}
/// Scores `needle` against `haystack` using the `GapConfig::CONTIGUOUS` gap
/// penalties.
///
/// This is a thin wrapper over `score_raw`; see it for the special
/// `SCORE_MIN` / `SCORE_MAX` return values.
pub fn score(needle: &[u8], haystack: &[u8]) -> Score {
    let contiguous = GapConfig::CONTIGUOUS;
    score_raw(needle, haystack, contiguous)
}
/// Scores `needle` against `haystack` using the `GapConfig::CONTIGUOUS` gap
/// penalties. Also records in `positions` the haystack index matched by each
/// needle byte.
///
/// This is a thin wrapper over `score_positions_raw`. `positions` should have
/// room for at least `needle.len()` entries.
pub fn score_positions(needle: &[u8], haystack: &[u8], positions: &mut [usize]) -> Score {
    let contiguous = GapConfig::CONTIGUOUS;
    score_positions_raw(needle, haystack, positions, contiguous)
}