use crate::pattern_tiling::backend::SimdBackend;
use crate::pattern_tiling::minima::{TracePostProcess, local_minima_indices};
use crate::pattern_tiling::search::{HitRange, Myers};
use crate::pattern_tiling::tqueries::TQueries;
use crate::profiles::Profile;
use crate::search::Match;
use crate::search::Strand;
use crate::search::get_overhang_steps;
use crate::trace::{CostLookup, get_trace};
use pa_types::Cost;
/// Per-lane record of Myers bit-vector states, one entry per text step
/// actually executed for that lane, consumed later by [`V2CostLookup`]
/// during traceback.
pub struct PatternHistory<S: Copy> {
    // One snapshot per non-frozen Myers step; indexed by step number.
    pub steps: Vec<SimdHistoryStep<S>>,
}
impl<S: Copy> Default for PatternHistory<S> {
fn default() -> Self {
Self { steps: Vec::new() }
}
}
/// Snapshot of the Myers vertical-delta bit vectors after one step.
/// Holds the full SIMD word; the per-lane 64-bit view is extracted later
/// via `extract_simd_lane`.
pub struct SimdHistoryStep<S: Copy> {
    // Vertical-positive delta bits after the step.
    pub vp: S,
    // Vertical-negative delta bits after the step.
    pub vn: S,
}
/// Reusable scratch storage for tracing a batch of hit ranges, sized to the
/// SIMD lane count so no per-batch allocation is needed.
pub struct TraceBuffer {
    // Pattern index assigned to each lane for the current batch.
    pub pattern_indices: Vec<usize>,
    // Per-lane (window_start, window_end) in text coordinates; start is the
    // hit start minus a left buffer, clamped to 0.
    pub approx_slices: Vec<(isize, isize)>,
    // Per-lane original (hit_start, hit_end) bounds, unclamped.
    pub range_bounds: Vec<(isize, isize)>,
    // Per-lane alignment outputs (cleared between batches).
    pub per_range_alignments: Vec<Vec<Match>>,
    // Final alignments that survived post-processing for the whole batch.
    pub filtered_alignments: Vec<Match>,
    // Scratch for (position, cost) pairs selected as local minima.
    pub temp_pos_cost: Vec<(isize, isize)>,
    // Number of lanes populated for the current batch.
    pub filled_till: usize,
    // Scratch for all (position, cost) pairs passing the k threshold.
    pub pos_cost_buffer: Vec<(isize, isize)>,
    // Scratch for indices into `pos_cost_buffer` marking local minima.
    pub minima_indices_buffer: Vec<usize>,
}
impl TraceBuffer {
    /// Allocates a trace buffer with one slot per SIMD lane.
    pub fn new(lanes: usize) -> Self {
        let zero_pairs = vec![(0isize, 0isize); lanes];
        Self {
            pattern_indices: vec![0; lanes],
            approx_slices: zero_pairs.clone(),
            range_bounds: zero_pairs,
            per_range_alignments: vec![Vec::new(); lanes],
            // Small head start for the common few-hits case.
            filtered_alignments: Vec::with_capacity(10),
            temp_pos_cost: Vec::new(),
            filled_till: 0,
            pos_cost_buffer: Vec::new(),
            minima_indices_buffer: Vec::new(),
        }
    }

    /// Clears all alignment outputs and scratch state for a fresh batch.
    #[inline(always)]
    pub fn clear_alns(&mut self) {
        self.per_range_alignments.iter_mut().for_each(|v| v.clear());
        self.filtered_alignments.clear();
        self.temp_pos_cost.clear();
        self.filled_till = 0;
    }

    /// Loads `ranges` into the per-lane slots, growing the buffers when the
    /// batch exceeds any previous size. Each lane's approximate window
    /// begins `left_buffer` characters before its hit start, clamped to the
    /// start of the text.
    #[inline(always)]
    pub fn populate(&mut self, ranges: &[HitRange], left_buffer: usize) {
        let n = ranges.len();
        if self.pattern_indices.len() < n {
            self.pattern_indices.resize(n, 0);
            self.approx_slices.resize(n, (0isize, 0isize));
            self.range_bounds.resize(n, (0isize, 0isize));
            self.per_range_alignments.resize(n, Vec::new());
        }
        for (slot, hit) in ranges.iter().enumerate() {
            let window_start = hit.start.saturating_sub(left_buffer as isize).max(0);
            self.pattern_indices[slot] = hit.pattern_idx;
            self.approx_slices[slot] = (window_start, hit.end);
            self.range_bounds[slot] = (hit.start, hit.end);
            self.per_range_alignments[slot].clear();
        }
        self.filled_till = n;
    }
}
/// [`CostLookup`] adapter that recovers DP costs for a single SIMD lane by
/// popcounting the vp/vn history recorded in a `Myers` searcher.
pub(crate) struct V2CostLookup<'a, B: SimdBackend, P: Profile> {
    // Searcher whose `history` holds the recorded bit-vector steps.
    searcher: &'a Myers<B, P>,
    // Which SIMD lane of each history step to read.
    lane_idx: usize,
}
impl<B: SimdBackend, P: Profile> CostLookup for V2CostLookup<'_, B, P> {
    /// Cost of the DP cell at text step `i`, pattern prefix length `j`,
    /// recovered as (#vp bits) - (#vn bits) within the prefix mask of the
    /// recorded step for this lane.
    #[inline(always)]
    fn get(&self, i: usize, j: usize) -> Cost {
        // An empty pattern prefix always costs 0.
        if j == 0 {
            return 0;
        }
        // History is indexed from the step after the initial state, so
        // step 0 of the text maps to history entry 0 via i - 1.
        let step_idx = i as isize - 1;
        let pattern_pos = j as isize - 1;
        // Mask selecting pattern bits 0..=pattern_pos. The branch avoids
        // evaluating `1u64 << 64` (overflow) when pattern_pos >= 63.
        let mask = if pattern_pos >= 63 {
            !0u64
        } else {
            (1u64 << (pattern_pos + 1)) - 1
        };
        if step_idx < 0 {
            // Before any text step: cost is the popcount of the masked
            // alpha pattern. NOTE(review): presumably `alpha_pattern` marks
            // prefix positions charged at the left boundary — confirm
            // against the searcher's construction.
            (self.searcher.alpha_pattern & mask).count_ones() as Cost
        } else {
            let step_data = &self.searcher.history[self.lane_idx].steps[step_idx as usize];
            let vp_bits = extract_simd_lane::<B>(step_data.vp, self.lane_idx);
            let vn_bits = extract_simd_lane::<B>(step_data.vn, self.lane_idx);
            let pos = (vp_bits & mask).count_ones() as Cost;
            let neg = (vn_bits & mask).count_ones() as Cost;
            pos - neg
        }
    }
}
/// Runs `overhang_steps` additional Myers steps past the end of the text
/// (using an all-ones equality input) and records each step into every
/// lane's history, so alignments overhanging the text suffix can be traced
/// back later. Only called when overhang costing (`alpha`) is enabled.
#[inline(always)]
fn handle_suffix_overhangs<B: SimdBackend, P: Profile>(
    searcher: &mut Myers<B, P>,
    last_bit_shift: u32,
    last_bit_mask: B::Simd,
    batch_size: usize,
    overhang_steps: usize,
) {
    // Raw pointer lets us update the block while `searcher.history` (a
    // disjoint field) is mutated in the same loop body.
    let blocks_ptr = searcher.blocks.as_mut_ptr();
    let all_ones = searcher.all_ones;
    for _i in 0..overhang_steps {
        unsafe {
            // SAFETY: `blocks_ptr` points at a live element of
            // `searcher.blocks`, which is not reallocated or otherwise
            // aliased inside this loop; the only other mutation touches
            // `searcher.history`, a different field.
            let block = &mut *blocks_ptr;
            let (vp_out, vn_out, _cost_out) = Myers::<B, P>::myers_step(
                block.vp,
                block.vn,
                block.cost,
                all_ones, all_ones,
                last_bit_shift,
                last_bit_mask,
            );
            // Every lane records the full SIMD word; the per-lane 64-bit
            // view is extracted later by `extract_simd_lane`.
            for lane in 0..batch_size {
                searcher.history[lane].steps.push(SimdHistoryStep {
                    vp: vp_out,
                    vn: vn_out,
                });
            }
            block.vp = vp_out;
            block.vn = vn_out;
        }
    }
}
/// Traces back each candidate end position in `positions_and_costs` for the
/// given lane and appends the resulting alignments to `out`.
#[inline(always)]
fn traceback_positions<B: SimdBackend, P: Profile>(
    positions_and_costs: &[(isize, isize)],
    searcher: &Myers<B, P>,
    lane: usize,
    pattern_idx: usize,
    t_queries: &TQueries<B, P>,
    approx_start: isize,
    text: &[u8],
    out: &mut Vec<Match>,
) {
    // The cost component has already been used for filtering; only the end
    // position matters for the traceback itself.
    out.extend(positions_and_costs.iter().map(|&(end_pos, _)| {
        traceback_single(
            searcher,
            lane,
            pattern_idx,
            t_queries,
            (approx_start, end_pos),
            text,
        )
    }));
}
/// For each active lane: scans every candidate end position in its hit
/// range, keeps those whose cost is <= `k` (adding an alpha-scaled penalty
/// for positions that overshoot the text end), then traces back either all
/// survivors or only the local cost minima, pushing the alignments into
/// `buffer.filtered_alignments`.
#[inline(always)]
fn trace_passing_alignments<B: SimdBackend, P: Profile>(
    batch_size: usize,
    buffer: &mut TraceBuffer,
    searcher: &Myers<B, P>,
    t_queries: &TQueries<B, P>,
    text: &[u8],
    k: u32,
    post: TracePostProcess,
) {
    let text_len = text.len();
    let pattern_length = t_queries.pattern_length;
    let k_isize = k as isize;
    let max_valid_text_pos = text_len.saturating_sub(1) as isize;
    for lane in 0..batch_size {
        let (r_start, r_end) = buffer.range_bounds[lane];
        let approx_start = buffer.approx_slices[lane].0;
        let pattern_idx = buffer.pattern_indices[lane];
        let cost_lookup = V2CostLookup {
            searcher,
            lane_idx: lane,
        };
        // Scratch is reused across lanes; clear before collecting.
        buffer.pos_cost_buffer.clear();
        for pos in r_start..=r_end {
            // Map the absolute text position to a history step index for
            // this lane's window (window starts at approx_start).
            let step_idx = pos - approx_start;
            let i = (step_idx + 1).max(0) as usize;
            let j = pattern_length;
            let mut cost = cost_lookup.get(i, j) as isize;
            // Positions past the text end are suffix overhangs: charge
            // floor(alpha) per overshot character when alpha != 1.0.
            if pos > max_valid_text_pos && searcher.alpha != 1.0 {
                let overshoot = (pos - max_valid_text_pos) as usize;
                cost += (overshoot as f32 * searcher.alpha).floor() as isize;
            }
            if cost <= k_isize {
                buffer.pos_cost_buffer.push((pos, cost));
            }
        }
        match post {
            // Trace back every position passing the threshold.
            TracePostProcess::All => {
                traceback_positions::<B, P>(
                    &buffer.pos_cost_buffer,
                    searcher,
                    lane,
                    pattern_idx,
                    t_queries,
                    approx_start,
                    text,
                    &mut buffer.filtered_alignments,
                );
            }
            // Trace back only local minima of the cost-vs-position curve.
            TracePostProcess::LocalMinima => {
                local_minima_indices(&buffer.pos_cost_buffer, &mut buffer.minima_indices_buffer);
                buffer.temp_pos_cost.clear();
                buffer.temp_pos_cost.extend(
                    buffer
                        .minima_indices_buffer
                        .iter()
                        .map(|&idx| buffer.pos_cost_buffer[idx]),
                );
                traceback_positions::<B, P>(
                    &buffer.temp_pos_cost,
                    searcher,
                    lane,
                    pattern_idx,
                    t_queries,
                    approx_start,
                    text,
                    &mut buffer.filtered_alignments,
                );
            }
        }
    }
}
/// Re-runs the Myers automaton over each hit's approximate window (one hit
/// per SIMD lane), recording per-step vp/vn history, then traces back every
/// end position whose cost passes `k` into `buffer.filtered_alignments`.
///
/// Lanes whose window is shorter than the longest in the batch are
/// "frozen": their state is held constant via masking once exhausted, and
/// frozen steps are not recorded in their history.
///
/// # Panics
/// Panics if `ranges.len() > B::LANES`.
#[inline(always)]
pub fn trace_batch_ranges<B: SimdBackend, P: Profile>(
    searcher: &mut Myers<B, P>,
    t_queries: &TQueries<B, P>,
    text: &[u8],
    ranges: &[HitRange],
    k: u32,
    post: TracePostProcess,
    alpha: Option<f32>,
    max_overhang: Option<usize>,
    buffer: &mut TraceBuffer,
) {
    assert!(ranges.len() <= B::LANES, "Batch size must be <= LANES");
    if ranges.is_empty() {
        return;
    }
    // Each window extends pattern_length + k characters left of its hit.
    let left_buffer = t_queries.pattern_length + k as usize;
    buffer.clear_alns();
    buffer.populate(ranges, left_buffer);
    let batch_size = buffer.filled_till;
    searcher.ensure_capacity(1, buffer.filled_till);
    // Keep only the low `pattern_length` bits of the alpha pattern.
    // NOTE(review): a pattern_length of 0 would shift by 64 and panic in
    // debug builds — presumably ruled out by the caller; confirm.
    let length_mask = (!0u64) >> (64usize.saturating_sub(t_queries.pattern_length));
    searcher.search_prep(
        1,
        t_queries.n_queries,
        t_queries.pattern_length,
        searcher.alpha_pattern & length_mask,
    );
    // Reset per-lane histories, pre-reserving for the expected step count.
    for i in 0..buffer.filled_till {
        searcher.history[i].steps.clear();
        searcher.history[i].steps.reserve(left_buffer);
    }
    let last_bit_shift = (t_queries.pattern_length - 1) as u32;
    let last_bit_mask = B::splat_one() << last_bit_shift;
    let all_ones = B::splat_all_ones();
    let zero_scalar = B::scalar_from_i64(0);
    let one_mask = <B as SimdBackend>::mask_word_to_scalar(!0u64);
    // Raw pointer lets the block be updated while `searcher.history` (a
    // disjoint field) is mutated in the same loop body.
    let blocks_ptr = searcher.blocks.as_mut_ptr();
    let text_ptr = text.as_ptr();
    let text_len = text.len();
    // The longest window across lanes drives the total step count.
    let mut max_len = 0;
    for slice in buffer.approx_slices.iter().take(batch_size) {
        let len = (slice.1 - slice.0 + 1) as usize;
        if len > max_len {
            max_len = len;
        }
    }
    let overhang_steps = get_overhang_steps(
        t_queries.pattern_length,
        k as usize,
        alpha.unwrap_or(1.0),
        max_overhang,
    );
    let mut eq_arr = B::LaneArray::default();
    let mut keep_mask_arr = B::LaneArray::default();
    for i in 0..max_len {
        unsafe {
            // SAFETY: `blocks_ptr` points at a live element of
            // `searcher.blocks`, which is neither reallocated nor aliased
            // elsewhere in this loop.
            let block = &mut *blocks_ptr;
            let eq_slice = eq_arr.as_mut();
            let keep_slice = keep_mask_arr.as_mut();
            // Build per-lane equality words; lanes whose position falls
            // outside the text get a zero keep-mask and are frozen.
            for lane in 0..batch_size {
                let q_idx = buffer.pattern_indices[lane];
                let start = buffer.approx_slices[lane].0;
                let abs_pos = (i as isize) + start;
                if abs_pos >= 0 && (abs_pos as usize) < text_len {
                    // SAFETY: abs_pos is bounds-checked just above.
                    let cur_char = *text_ptr.add(abs_pos as usize);
                    let enc = P::encode_char(cur_char) as usize;
                    eq_slice[lane] = B::mask_word_to_scalar(t_queries.peq_masks[enc][q_idx]);
                    keep_slice[lane] = one_mask;
                } else {
                    eq_slice[lane] = zero_scalar;
                    keep_slice[lane] = zero_scalar;
                }
            }
            let eq = B::from_array(eq_arr);
            let keep_mask = B::from_array(keep_mask_arr);
            let freeze_mask = all_ones ^ keep_mask;
            let (vp_new, vn_new, cost_new) = Myers::<B, P>::myers_step(
                block.vp,
                block.vn,
                block.cost,
                eq,
                all_ones,
                last_bit_shift,
                last_bit_mask,
            );
            // Active lanes take the new state; frozen lanes keep the old.
            let vp_masked = (vp_new & keep_mask) | (block.vp & freeze_mask);
            let vn_masked = (vn_new & keep_mask) | (block.vn & freeze_mask);
            let cost_masked = (cost_new & keep_mask) | (block.cost & freeze_mask);
            let freeze_arr = B::to_array(freeze_mask);
            let freeze_slice = freeze_arr.as_ref();
            // Record history only for lanes that actually advanced, so
            // history index == per-lane step count.
            for lane in 0..batch_size {
                let is_frozen = B::scalar_to_u64(freeze_slice[lane]) != 0;
                if !is_frozen {
                    searcher.history[lane].steps.push(SimdHistoryStep {
                        vp: vp_masked,
                        vn: vn_masked,
                    });
                }
            }
            block.vp = vp_masked;
            block.vn = vn_masked;
            block.cost = cost_masked;
        }
    }
    // With overhang costing enabled, run extra steps past the text end so
    // suffix overhangs appear in the history.
    if alpha.is_some() {
        handle_suffix_overhangs(
            searcher,
            last_bit_shift,
            last_bit_mask,
            batch_size,
            overhang_steps,
        );
    }
    trace_passing_alignments(batch_size, buffer, searcher, t_queries, text, k, post);
}
/// Extracts lane `lane` of a SIMD value as a plain `u64`.
#[inline(always)]
fn extract_simd_lane<B: SimdBackend>(simd_val: B::Simd, lane: usize) -> u64 {
    let lanes = B::to_array(simd_val);
    let scalar = lanes.as_ref()[lane];
    B::scalar_to_u64(scalar)
}
/// Runs a full traceback for one hit ending at `slice.1` (inclusive) with
/// its window starting at `slice.0`, and returns the resulting `Match`
/// with its pattern index mapped back to the original query set and the
/// strand set from the query's position in the forward/rc layout.
#[inline(always)]
fn traceback_single<B: SimdBackend, P: Profile>(
    searcher: &Myers<B, P>,
    lane_idx: usize,
    original_pattern_idx: usize,
    t_queries: &TQueries<B, P>,
    slice: (isize, isize),
    text: &[u8],
) -> Match {
    let (window_start, end_inclusive) = slice;
    let pattern = &t_queries.queries[original_pattern_idx];
    let approx_start = window_start as usize;
    // End is exclusive for the trace; it may exceed the text when the
    // alignment overhangs, so clamp only the slice taken from `text`.
    let end_pos_exclusive = (end_inclusive + 1).max(0) as usize;
    let clipped_end = end_pos_exclusive.min(text.len());
    let window = &text[approx_start..clipped_end];
    // Overhang costing is active only when alpha differs from 1.0.
    let alpha = (searcher.alpha != 1.0).then_some(searcher.alpha);
    let cost_lookup = V2CostLookup { searcher, lane_idx };
    let mut m = get_trace::<P>(
        pattern,
        approx_start,
        end_pos_exclusive,
        window,
        &cost_lookup,
        alpha,
        searcher.max_overhang,
    );
    // Queries are laid out as [forward..., rc...]; fold the index back and
    // derive the strand from which half it came from.
    let n_orig = t_queries.n_original_queries;
    m.pattern_idx = original_pattern_idx % n_orig;
    m.strand = if original_pattern_idx < n_orig {
        Strand::Fwd
    } else {
        Strand::Rc
    };
    m
}