use std::cell::RefCell;
use thread_local::ThreadLocal;
use crate::fuzzy_matcher::{IndexType, MatchIndices};
use super::banding::{compute_banding, typo_vband_row};
use super::constants::*;
use super::{Atom, CELL_ZERO, Cell, Dir, SWMatrix, Score};
/// Evaluates one cell of the alignment recurrence and returns `(score, dir)`.
///
/// Candidates (all computed branch-free by multiplying with bools):
/// * diagonal — on a match: `diag_score + MATCH_BONUS + bonus`, where `bonus`
///   is `bonus_j` plus `CONSECUTIVE_BONUS` when the diagonal neighbour itself
///   came diagonally, doubled on the first pattern row (`is_first`). On a
///   mismatch: `diag_score - MISMATCH_PENALTY` when typos are allowed,
///   otherwise 0.
/// * up — `up_score - TYPO_PENALTY` when typos are allowed, otherwise 0.
/// * left — `left_score` minus `GAP_OPEN` if the left neighbour came
///   diagonally, else minus `GAP_EXTEND`.
///
/// The returned score is the maximum of the three and is not clamped here.
/// Without typos the up candidate is 0, so the score is never negative; with
/// typos it may be negative, depending on the constants.
/// Ties resolve diagonal > up > left. Without typos the diagonal can only win
/// on an actual character match. When the score is not positive, the
/// direction is forced to the all-zero value (presumably `Dir::None`).
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_cell<const ALLOW_TYPOS: bool>(
    is_match: bool,
    is_first: bool,
    bonus_j: Score,
    diag_score: Score,
    diag_was_diag: bool,
    up_score: Score,
    left_score: Score,
    left_was_diag: bool,
) -> (Score, Dir) {
    // `bonus_j` (+ CONSECUTIVE_BONUS when extending a diagonal run), doubled on
    // the first pattern row.
    let bonus = (bonus_j + CONSECUTIVE_BONUS * (diag_was_diag as Score)) * (1 + is_first as Score);
    // Exactly one of match_val / mismatch_val can be non-zero.
    let match_val = (diag_score + MATCH_BONUS + bonus) * (is_match as Score);
    let mismatch_val = if ALLOW_TYPOS {
        (diag_score - MISMATCH_PENALTY) * (!is_match as Score)
    } else {
        0
    };
    let diag_val = match_val + mismatch_val;
    // Up = skip a pattern character; only a real candidate when typos are allowed.
    let up_val = if ALLOW_TYPOS { up_score - TYPO_PENALTY } else { 0 };
    // Gap cost: GAP_OPEN when leaving a diagonal cell, GAP_EXTEND otherwise.
    let left_val = left_score - (GAP_EXTEND + (GAP_OPEN - GAP_EXTEND) * (left_was_diag as Score));
    let best = diag_val.max(up_val).max(left_val);
    let diag_wins = if ALLOW_TYPOS {
        diag_val >= up_val && diag_val >= left_val
    } else {
        is_match && diag_val >= left_val
    };
    // `up_wins` excludes `diag_wins`, so at most one of them is set.
    let up_wins = ALLOW_TYPOS && !diag_wins && up_val >= left_val;
    // Encode the winner relative to `Dir::Left`: Left - 1 for up, Left - 2 for diagonal.
    let dir_bits: u8 = Dir::Left as u8 - (up_wins as u8) - (diag_wins as u8) * 2;
    let positive = best > 0;
    // Mask is 0xFF when positive and 0x00 otherwise, so non-positive cells get 0.
    let dir_val = dir_bits & (positive as u8).wrapping_neg();
    // SAFETY: NOTE(review): sound only if `Dir` is a one-byte enum whose
    // discriminants include 0, `Left - 2`, `Left - 1` and `Left` (presumably
    // None / Diag / Up / Left) — confirm at the `Dir` definition.
    let dir: Dir = unsafe { std::mem::transmute(dir_val) };
    (best, dir)
}
/// Banded local-alignment DP of `pat` against `cho`.
///
/// Returns the best score found in the last pattern row. When
/// `COMPUTE_INDICES` is set, it also returns the 0-based positions in `cho`
/// that the traceback aligned to equal pattern characters, in ascending
/// order. When `COMPUTE_INDICES` is false, no traceback is done and the index
/// vector is empty.
///
/// * `ALLOW_TYPOS` — enables the mismatch and skip-pattern-character moves in
///   `compute_cell`. The per-row column band then comes from `typo_vband_row`;
///   otherwise it comes from the precomputed `banding.row_bounds`.
/// * `bonuses` — per-position bonus, indexed by 0-based position in `cho`.
/// * `full_buf` / `indices_buf` — per-thread scratch storage reused across calls.
/// * `use_last_match` — on score ties in the last row, prefer the rightmost end
///   column instead of the leftmost.
///
/// Returns `None` in any of these cases:
/// * the pattern is empty;
/// * `compute_banding` rejects the pair;
/// * no last-row cell has a positive score;
/// * with `COMPUTE_INDICES`, the traceback crosses fewer than
///   `banding.min_true_matches` exact matches.
pub(super) fn full_dp<const ALLOW_TYPOS: bool, const COMPUTE_INDICES: bool, C: Atom>(
    cho: &[C],
    pat: &[C],
    bonuses: &[Score],
    respect_case: bool,
    full_buf: &ThreadLocal<RefCell<SWMatrix>>,
    indices_buf: &ThreadLocal<RefCell<MatchIndices>>,
    use_last_match: bool,
) -> Option<(Score, MatchIndices)> {
    let n = pat.len();
    let m = cho.len();
    // An empty pattern cannot produce a positive score. Bail out before the
    // `n - 1` row-bound lookups below could underflow.
    if n == 0 {
        return None;
    }
    let banding = compute_banding::<ALLOW_TYPOS, C>(pat, cho, respect_case)?;
    // 1-based first column that can hold a match; earlier columns are not stored.
    let j_start = banding.j_first;
    // Choice column `j` lives at matrix column `j - col_off`. Matrix column 0
    // is the zero border.
    let col_off = j_start - 1;
    let mcols = m - col_off + 1;
    let jm_max = mcols - 1;
    let mut buf = full_buf
        .get_or(|| RefCell::new(SWMatrix::zero(n + 1, mcols)))
        .borrow_mut();
    buf.resize(n + 1, mcols);
    let base_ptr = buf.data.as_mut_ptr();
    let cols = buf.cols;
    // SAFETY: NOTE(review): assumes `resize` provides `(n + 1) * cols` cells with
    // `cols >= mcols`, and that an all-zero-bytes `Cell` is valid — confirm.
    unsafe {
        // Zero border: all of row 0 and column 0 of every other row.
        std::ptr::write_bytes(base_ptr, 0, mcols);
        for i in 1..=n {
            *base_ptr.add(i * cols) = CELL_ZERO;
        }
    }
    let (row_lo_arr, row_hi_arr) = if !ALLOW_TYPOS {
        let (lo, hi) = banding.row_bounds.as_ref().unwrap();
        (*lo, *hi)
    } else {
        ([0usize; MAX_PAT_LEN], [0usize; MAX_PAT_LEN])
    };
    // 1-based column band `[lo, hi]` (lo clamped to `j_start`) for pattern row `i`.
    // Every pass over a row — computing, zeroing and scanning the last row — must
    // use this same band. Otherwise it reads cells left over from an earlier
    // call in the reused buffer.
    let row_band = |i: usize| -> (usize, usize) {
        let (lo, hi) = if ALLOW_TYPOS {
            typo_vband_row(i, m, banding.bandwidth, banding.j_first)
        } else {
            (row_lo_arr[i - 1], row_hi_arr[i - 1])
        };
        (lo.max(j_start), hi)
    };
    let cho_ptr = cho.as_ptr();
    let bonuses_ptr = bonuses.as_ptr();
    for i in 1..=n {
        let (j_lo, j_hi) = row_band(i);
        if j_lo > j_hi || j_lo > m {
            // Empty band: nothing to compute. Zero the part of this row that row
            // `i + 1` reads (its band plus one column to the left), so the next
            // row does not pick up stale scores.
            if i < n {
                let (nj_lo, nj_hi) = row_band(i + 1);
                if nj_lo <= nj_hi && nj_lo <= m {
                    let zero_lo = (nj_lo - col_off).saturating_sub(1);
                    let zero_hi = (nj_hi - col_off).min(jm_max);
                    // SAFETY: `zero_lo..=zero_hi` lies within the first `mcols`
                    // cells of row `i`.
                    unsafe {
                        let row_ptr = base_ptr.add(i * cols);
                        for k in zero_lo..=zero_hi {
                            *row_ptr.add(k) = CELL_ZERO;
                        }
                    }
                }
            }
            continue;
        }
        let pi = pat[i - 1];
        let is_first = i == 1;
        let jm_lo = j_lo - col_off;
        let jm_hi = j_hi - col_off;
        // SAFETY: both writes are guarded to stay within `1..=jm_max` of row `i`.
        unsafe {
            let row_ptr = base_ptr.add(i * cols);
            // Zero the cells just outside the band. The left one is the `left`
            // neighbour of the band's first column.
            if jm_lo > 1 {
                *row_ptr.add(jm_lo - 1) = CELL_ZERO;
            }
            if jm_hi < jm_max {
                *row_ptr.add(jm_hi + 1) = CELL_ZERO;
            }
        }
        // SAFETY: rows `i - 1` and `i` are both inside the `(n + 1)`-row buffer.
        let prev_ptr: *const _ = unsafe { base_ptr.add((i - 1) * cols) };
        let cur_ptr = unsafe { base_ptr.add(i * cols) };
        for j in j_lo..=j_hi {
            let jm = j - col_off;
            // SAFETY: `j_start <= j` and (assumed band invariant) `j <= m`, so
            // `j - 1` indexes `cho`/`bonuses`, and `jm - 1..=jm` are stored columns.
            let cj = unsafe { *cho_ptr.add(j - 1) };
            let is_match = pi.eq(cj, respect_case);
            let diag_cell = unsafe { *prev_ptr.add(jm - 1) };
            let up_score = if ALLOW_TYPOS {
                let up_cell = unsafe { *prev_ptr.add(jm) };
                up_cell.score()
            } else {
                0
            };
            let left_cell = unsafe { *cur_ptr.add(jm - 1) };
            let (best, dir) = compute_cell::<ALLOW_TYPOS>(
                is_match,
                is_first,
                unsafe { *bonuses_ptr.add(j - 1) },
                diag_cell.score(),
                diag_cell.is_diag(),
                up_score,
                left_cell.score(),
                left_cell.is_diag(),
            );
            unsafe {
                *cur_ptr.add(jm) = Cell::new(best, dir);
            }
        }
    }
    // Pick the end column: the best positive score inside the last row's band.
    // Rows with an empty band were never written, so they are not scanned.
    let (last_j_lo, last_j_hi) = row_band(n);
    let mut best_score: Score = 0;
    let mut best_j = 0usize;
    if last_j_lo <= last_j_hi && last_j_lo <= m {
        // SAFETY: row `n` is the last row of the buffer.
        let last_row_ptr = unsafe { base_ptr.add(n * cols) };
        for j in last_j_lo..=last_j_hi {
            let s = unsafe { (*last_row_ptr.add(j - col_off)).score() };
            // `>=` lets ties move the end rightwards; `>` keeps the first maximum.
            let better = if use_last_match {
                s >= best_score && s > 0
            } else {
                s > best_score
            };
            best_score = if better { s } else { best_score };
            best_j = if better { j } else { best_j };
        }
    }
    if best_score <= 0 {
        return None;
    }
    if !COMPUTE_INDICES {
        // NOTE(review): unlike the traceback path below (and `range_dp`), this
        // path does not enforce `banding.min_true_matches` — confirm intended.
        return Some((best_score, Vec::default()));
    }
    let indices_cell = indices_buf.get_or(|| RefCell::new(Vec::new()));
    let mut indices = indices_cell.borrow_mut();
    indices.clear();
    let mut i = n;
    let mut j = best_j;
    let mut true_matches = 0usize;
    // Walk back along the stored directions until one of these happens: the
    // pattern is exhausted, the stored columns are left, or a cell with no
    // predecessor is reached.
    while i > 0 && j >= j_start {
        // SAFETY: `1 <= i <= n` and `j_start <= j <= best_j`, all inside the matrix.
        let c = unsafe { *base_ptr.add(i * cols).add(j - col_off) };
        match c.dir() {
            Dir::Diag => {
                // Diagonal steps may be typo substitutions; report exact matches only.
                if pat[i - 1].eq(cho[j - 1], respect_case) {
                    indices.push((j - 1) as IndexType);
                    true_matches += 1;
                }
                i -= 1;
                j -= 1;
            }
            Dir::Up => i -= 1,
            Dir::Left => j -= 1,
            Dir::None => break,
        }
    }
    if true_matches < banding.min_true_matches {
        return None;
    }
    // The traceback collected positions right-to-left.
    indices.reverse();
    Some((best_score, indices.to_vec()))
}
/// Runs the same banded DP recurrence as `full_dp` (via `compute_cell`).
/// Instead of match positions, it returns the span of `cho` covered by the
/// best alignment.
///
/// Returns `Some((best_score, begin_0, end_0))`:
/// * `end_0` is the 0-based index of the last-row column holding the best
///   score (inclusive).
/// * `begin_0` is the value of the column cursor `j` when the traceback stops.
///   For a path whose final step is diagonal, this is the 0-based index of
///   the first aligned character of `cho`.
///
/// Returns `None` when any of the following holds:
/// * `compute_banding` rejects the pair;
/// * two consecutive pattern rows contain no positive score (early exit);
///   rows with an empty band count as such rows;
/// * no last-row cell has a positive score;
/// * the traceback crosses fewer than `banding.min_true_matches` exact matches.
///
/// `use_last_match` picks the rightmost end column on score ties instead of
/// the leftmost. `full_buf` is a per-thread scratch matrix reused across calls.
pub(super) fn range_dp<const ALLOW_TYPOS: bool, C: Atom>(
    cho: &[C],
    pat: &[C],
    bonuses: &[Score],
    respect_case: bool,
    full_buf: &ThreadLocal<RefCell<SWMatrix>>,
    use_last_match: bool,
) -> Option<(Score, usize, usize)> {
    let n = pat.len();
    let m = cho.len();
    let banding = compute_banding::<ALLOW_TYPOS, C>(pat, cho, respect_case)?;
    // 1-based first column that can hold a match; earlier columns are not stored.
    let j_start = banding.j_first;
    // Choice column `j` lives at matrix column `j - col_off`. Matrix column 0 is
    // the zero border.
    let col_off = j_start - 1;
    let mcols = m - col_off + 1;
    let mut buf = full_buf
        .get_or(|| RefCell::new(SWMatrix::zero(n + 1, mcols)))
        .borrow_mut();
    buf.resize(n + 1, mcols);
    let base_ptr = buf.data.as_mut_ptr();
    let cols = buf.cols;
    // Zero the border: all of row 0 and column 0 of rows 1..=n. The buffer is
    // reused across calls, so other cells may hold earlier values (unless
    // `resize` clears them).
    // NOTE(review): assumes `resize` provides `(n + 1) * cols` cells with
    // `cols >= mcols`, and that an all-zero-bytes `Cell` is valid — confirm.
    unsafe {
        std::ptr::write_bytes(base_ptr, 0, mcols);
        for i in 1..=n {
            *base_ptr.add(i * cols) = CELL_ZERO;
        }
    }
    // Per-row column bounds from `compute_banding`. They are only used when typos
    // are disallowed; the typo path derives bands from `typo_vband_row`.
    let (row_lo_arr, row_hi_arr) = if !ALLOW_TYPOS {
        let (lo, hi) = banding.row_bounds.as_ref().unwrap();
        (*lo, *hi)
    } else {
        ([0usize; MAX_PAT_LEN], [0usize; MAX_PAT_LEN])
    };
    let cho_ptr = cho.as_ptr();
    let bonuses_ptr = bonuses.as_ptr();
    // Consecutive rows without a positive cell; reaching 2 aborts the search.
    let mut dead_rows = 0u32;
    for i in 1..=n {
        let pi = pat[i - 1];
        let is_first = i == 1;
        // 1-based column band [j_lo, j_hi] for pattern row i.
        let (j_lo, j_hi) = if ALLOW_TYPOS {
            typo_vband_row(i, m, banding.bandwidth, banding.j_first)
        } else {
            (row_lo_arr[i - 1], row_hi_arr[i - 1])
        };
        let j_lo = j_lo.max(j_start);
        if j_lo > j_hi || j_lo > m {
            // Empty band: nothing to compute. Zero the cells of this row that
            // row i + 1 reads (its band plus one column to the left), so it does
            // not see stale values.
            if i < n {
                let (nj_lo, nj_hi) = if ALLOW_TYPOS {
                    typo_vband_row(i + 1, m, banding.bandwidth, banding.j_first)
                } else {
                    (row_lo_arr[i], row_hi_arr[i])
                };
                let nj_lo = nj_lo.max(j_start);
                if nj_lo <= nj_hi && nj_lo <= m {
                    let njm_lo = nj_lo - col_off;
                    let njm_hi = (nj_hi - col_off).min(mcols - 1);
                    let zero_lo = njm_lo.saturating_sub(1);
                    let zero_hi = njm_hi.min(mcols - 1);
                    unsafe {
                        let row_ptr = base_ptr.add(i * cols);
                        for k in zero_lo..=zero_hi {
                            *row_ptr.add(k) = CELL_ZERO;
                        }
                    }
                }
            }
            // A row with an empty band counts as dead.
            dead_rows += 1;
            if dead_rows >= 2 {
                return None;
            }
            continue;
        }
        let jm_lo = j_lo - col_off;
        let jm_hi = j_hi - col_off;
        let jm_max = mcols - 1;
        // Zero the cells just outside the band. The left one is the `left`
        // neighbour of the band's first column. The right one is presumably
        // read by row i + 1.
        unsafe {
            let row_ptr = base_ptr.add(i * cols);
            if jm_lo > 1 {
                *row_ptr.add(jm_lo - 1) = CELL_ZERO;
            }
            if jm_hi < jm_max {
                *row_ptr.add(jm_hi + 1) = CELL_ZERO;
            }
        }
        let (prev_row, cur_row) = unsafe {
            let pr = std::slice::from_raw_parts(base_ptr.add((i - 1) * cols), cols);
            let cr = std::slice::from_raw_parts_mut(base_ptr.add(i * cols), cols);
            (pr, cr)
        };
        let prev_ptr = prev_row.as_ptr();
        let cur_ptr = cur_row.as_mut_ptr();
        let mut row_positive = false;
        for j in j_lo..=j_hi {
            let jm = j - col_off;
            let cj = unsafe { *cho_ptr.add(j - 1) };
            let is_match = pi.eq(cj, respect_case);
            let diag_cell = unsafe { *prev_ptr.add(jm - 1) };
            // The up neighbour is only consulted when typos are allowed.
            let up_score = if ALLOW_TYPOS {
                let up_cell = unsafe { *prev_ptr.add(jm) };
                up_cell.score()
            } else {
                0
            };
            let left_cell = unsafe { *cur_ptr.add(jm - 1) };
            let (best, dir) = compute_cell::<ALLOW_TYPOS>(
                is_match,
                is_first,
                unsafe { *bonuses_ptr.add(j - 1) },
                diag_cell.score(),
                diag_cell.is_diag(),
                up_score,
                left_cell.score(),
                left_cell.is_diag(),
            );
            row_positive |= best > 0;
            unsafe {
                *cur_ptr.add(jm) = Cell::new(best, dir);
            }
        }
        if row_positive {
            dead_rows = 0;
        } else {
            dead_rows += 1;
            if dead_rows >= 2 {
                return None;
            }
        }
    }
    // Pick the end column: the best positive score inside the last row's band.
    let mut best_score: Score = 0;
    let mut best_j = 0usize;
    {
        let (last_j_lo, last_j_hi) = if ALLOW_TYPOS {
            typo_vband_row(n, m, banding.bandwidth, banding.j_first)
        } else {
            (row_lo_arr[n - 1], row_hi_arr[n - 1])
        };
        let last_j_lo = last_j_lo.max(j_start);
        if last_j_lo <= last_j_hi && last_j_lo <= m {
            let last_row_ptr = unsafe { base_ptr.add(n * cols) };
            for j in last_j_lo..=last_j_hi {
                let jm = j - col_off;
                let s = unsafe { (*last_row_ptr.add(jm)).score() };
                // `>=` lets ties move the end rightwards; `>` keeps the first maximum.
                let better = if use_last_match {
                    s >= best_score && s > 0
                } else {
                    s > best_score
                };
                best_score = if better { s } else { best_score };
                best_j = if better { j } else { best_j };
            }
        }
    }
    if best_score <= 0 {
        return None;
    }
    // Inclusive 0-based end of the match in `cho`.
    let end_0 = best_j - 1;
    let mut i = n;
    let mut j = best_j;
    let mut true_matches = 0usize;
    // Walk back along the stored directions until one of these happens: the
    // pattern is exhausted, the stored columns are left, or a cell with no
    // predecessor is reached.
    while i > 0 && j >= j_start {
        let jm = j - col_off;
        let c = unsafe { *base_ptr.add(i * cols).add(jm) };
        match c.dir() {
            Dir::Diag => {
                // Diagonal steps may be typo substitutions; count exact matches only.
                if pat[i - 1].eq(cho[j - 1], respect_case) {
                    true_matches += 1;
                }
                i -= 1;
                j -= 1;
            }
            Dir::Up => {
                i -= 1;
            }
            Dir::Left => {
                j -= 1;
            }
            Dir::None => break,
        }
    }
    if true_matches < banding.min_true_matches {
        return None;
    }
    let begin_0 = j;
    Some((best_score, begin_0, end_0))
}