use crate::bitpacking::compute_block;
use crate::delta_encoding::H;
use crate::delta_encoding::V;
use crate::profiles::Profile;
use crate::search::init_deltas_for_overshoot_all_lanes;
use crate::search::init_deltas_for_overshoot_scalar;
use pa_types::Cigar;
use pa_types::Cost;
use pa_types::I;
use crate::LANES;
use crate::S;
use crate::bitpacking::compute_block_simd;
use crate::search::{Match, Strand};
use std::array::from_fn;
/// Random access into an alignment cost matrix.
///
/// Implementors answer "what is the DP cost at this cell?" so that traceback code can walk
/// the matrix without knowing how it is stored.
pub trait CostLookup {
    /// Returns the cost at text column `i` and pattern row `j`.
    fn get(&self, i: usize, j: usize) -> Cost;
}
/// Bit-packed DP matrix of one pattern against one text, as produced by the `fill` functions.
///
/// Absolute costs are not stored. `deltas` holds one [`V`] per pattern row `0..=q` for every
/// 64-column text chunk, at index `chunk * (q + 1) + row`. [`CostLookup::get`] sums these
/// deltas back into a cost.
#[derive(Clone, Debug, Default)]
pub struct CostMatrix {
    /// Pattern length; every text chunk contributes `q + 1` entries to `deltas`.
    q: usize,
    /// Per-chunk, per-row column deltas in chunk-major order.
    deltas: Vec<V>,
    /// When set, the cost in column 0 at row `j` is `floor(j * alpha)` instead of `j`.
    pub alpha: Option<f32>,
    /// When set together with `alpha`, only the first `max_overhang` rows of column 0 are
    /// scaled by `alpha`; every row beyond that adds 1.
    pub max_overhang: Option<usize>,
}
impl CostLookup for CostMatrix {
    /// Reconstructs the cost at text column `i`, pattern row `j`.
    ///
    /// Starts from the column-0 cost of row `j`. It then adds the total delta of every complete
    /// 64-column chunk left of `i`, plus the delta of the first `i % 64` columns of the next
    /// chunk.
    #[inline(always)]
    fn get(&self, i: usize, j: usize) -> Cost {
        // Column 0: plain `j`, or the (possibly capped) overhang cost when `alpha` is set.
        let base = match (self.alpha, self.max_overhang) {
            (Some(alpha), Some(mo)) => {
                (j.min(mo) as f32 * alpha).floor() as Cost + j.saturating_sub(mo) as Cost
            }
            (Some(alpha), None) => (j as f32 * alpha).floor() as Cost,
            (None, _) => j as Cost,
        };

        let stride = self.q + 1;
        let full_chunks = i / 64;
        let mut total = base;
        for chunk in 0..full_chunks {
            total += self.deltas[j + chunk * stride].value();
        }
        if i % 64 != 0 {
            total += self.deltas[j + full_chunks * stride].value_of_prefix(i as I % 64);
        }
        total
    }
}
/// Computes the bit-packed DP deltas for aligning `query` against `text` and stores them in `cm`.
///
/// The text is processed in windows of 64 characters. `len` sets how many columns are covered
/// (`len.div_ceil(64)` windows), and positions past the end of `text` are padded with `b'N'`.
/// For each window, `cm.deltas` receives an all-zero entry for row 0 followed by one entry per
/// query character. This is the layout [`CostLookup::get`] reads back.
///
/// `alpha` and `max_overhang` are stored in `cm` and passed to
/// `init_deltas_for_overshoot_scalar` to set up the initial per-row carries.
///
/// # Panics
/// - if `text` is longer than `len`;
/// - if `alpha` is set but the profile `P` does not support overhang.
pub fn fill<P: Profile>(
    query: &[u8],
    text: &[u8],
    len: usize,
    cm: &mut CostMatrix,
    alpha: Option<f32>,
    max_overhang: Option<usize>,
) {
    assert!(text.len() <= len);
    if alpha.is_some() && !P::supports_overhang() {
        panic!(
            "Overhang is not supported for {:?}",
            std::any::type_name::<P>()
        );
    }

    let rows = query.len();
    let num_chunks = len.div_ceil(64);
    cm.alpha = alpha;
    cm.max_overhang = max_overhang;
    cm.q = rows;
    cm.deltas.clear();
    cm.deltas.reserve((rows + 1) * num_chunks);

    let (profiler, query_profile) = P::encode_pattern(query);
    // Carry passed from one window to the next, one entry per query row.
    let mut carry = vec![H(1, 0); rows];
    init_deltas_for_overshoot_scalar(&mut carry, alpha, max_overhang);
    let mut text_profile = P::alloc_out();

    for chunk in 0..num_chunks {
        let start = (64 * chunk).min(text.len());
        let end = (start + 64).min(text.len());
        let mut window = [b'N'; 64];
        window[..end - start].copy_from_slice(&text[start..end]);
        profiler.encode_ref(&window, &mut text_profile);

        let mut v = V::zero();
        cm.deltas.push(v);
        for (row, h) in carry.iter_mut().enumerate() {
            compute_block::<P>(h, &mut v, &query_profile[row], &text_profile);
            cm.deltas.push(v);
        }
    }
}
/// Fills up to [`LANES`] cost matrices at once, aligning the same `pattern` against each text in
/// `texts` using one SIMD lane per text.
///
/// Every text is processed in windows of 64 characters, padded with `b'N'` past its end.
/// `max_len` determines how many windows are computed for all lanes. Afterwards `m[lane]` holds
/// the deltas for `texts[lane]` in the layout [`CostLookup::get`] expects. Entries of `m` at or
/// beyond `texts.len()` are reset but receive no deltas.
///
/// # Panics
/// - if more than `LANES` texts are given, or any text is longer than `max_len`;
/// - if `alpha` is set but the profile `P` does not support overhang.
pub fn simd_fill<P: Profile>(
    pattern: &[u8],
    texts: &[&[u8]],
    max_len: usize,
    m: &mut [CostMatrix; LANES],
    alpha: Option<f32>,
    max_overhang: Option<usize>,
) {
    assert!(texts.len() <= LANES);
    if alpha.is_some() && !P::supports_overhang() {
        panic!(
            "Overhang is not supported for {:?}",
            std::any::type_name::<P>()
        );
    }
    let lanes = texts.len();
    for text in texts {
        assert!(text.len() <= max_len);
    }

    let (profiler, pattern_profile) = P::encode_pattern(pattern);
    let rows = pattern.len();
    let num_chunks = max_len.div_ceil(64);

    // Reset every output matrix, including lanes that stay unused.
    for cm in m.iter_mut() {
        cm.alpha = alpha;
        cm.max_overhang = max_overhang;
        cm.q = rows;
        cm.deltas.clear();
        cm.deltas.reserve((rows + 1) * num_chunks);
    }

    // Carries between consecutive windows: one SIMD vector per pattern row.
    let mut hp: Vec<S> = vec![S::splat(1); rows];
    let mut hm: Vec<S> = vec![S::splat(0); rows];
    init_deltas_for_overshoot_all_lanes(&mut hp, alpha, max_overhang);
    let mut text_profile: [_; LANES] = from_fn(|_| P::alloc_out());

    for chunk in 0..num_chunks {
        // Encode this 64-character window of every text, padding past its end with `N`.
        for (text, profile) in texts.iter().zip(text_profile.iter_mut()) {
            let start = (64 * chunk).min(text.len());
            let end = (start + 64).min(text.len());
            let mut window = [b'N'; 64];
            window[..end - start].copy_from_slice(&text[start..end]);
            profiler.encode_ref(&window, profile);
        }

        let mut vp = S::splat(0);
        let mut vm = S::splat(0);
        // Row 0 is the all-zero top boundary.
        for (lane, cm) in m.iter_mut().enumerate().take(lanes) {
            cm.deltas.push(V::from(vp.as_array()[lane], vm.as_array()[lane]));
        }
        for row in 0..rows {
            let eq = from_fn(|lane| P::eq(&pattern_profile[row], &text_profile[lane])).into();
            compute_block_simd(&mut hp[row], &mut hm[row], &mut vp, &mut vm, eq);
            for (lane, cm) in m.iter_mut().enumerate().take(lanes) {
                cm.deltas.push(V::from(vp.as_array()[lane], vm.as_array()[lane]));
            }
        }
    }

    for cm in &m[..lanes] {
        assert_eq!(cm.deltas.len(), num_chunks * (cm.q + 1));
    }
}
/// Fills up to [`LANES`] cost matrices at once, aligning `patterns[lane]` against `texts[lane]`
/// in each SIMD lane.
///
/// Every text is processed in windows of 64 characters, padded with `b'N'` past its end.
/// `max_len` determines how many windows are computed for all lanes. All lanes share one set of
/// row carries and one matrix shape, which is taken from the first pattern.
///
/// # Panics
/// - if `patterns` is empty, or the patterns do not all have the same length;
/// - if more than `LANES` texts are given, or any text is longer than `max_len`;
/// - if `alpha` is set but the profile `P` does not support overhang.
pub fn simd_fill_multipattern<P: Profile>(
    patterns: &[&[u8]],
    texts: &[&[u8]],
    max_len: usize,
    m: &mut [CostMatrix; LANES],
    alpha: Option<f32>,
    max_overhang: Option<usize>,
) {
    assert!(texts.len() <= LANES);
    assert!(!patterns.is_empty(), "At least one pattern is required");
    if alpha.is_some() && !P::supports_overhang() {
        panic!(
            "Overhang is not supported for {:?}",
            std::any::type_name::<P>()
        );
    }
    let lanes = texts.len();
    // Same precondition as `fill` and `simd_fill`. Otherwise the tail of a longer text would
    // silently be left out of its matrix.
    for text in texts {
        assert!(text.len() <= max_len);
    }
    let pattern = patterns[0];
    // Every lane's matrix (`q` and the number of rows computed) is sized from the first pattern,
    // so all patterns must share its length.
    assert!(
        patterns.iter().all(|p| p.len() == pattern.len()),
        "All patterns must have the same length"
    );
    let (profiler, pattern_profiles) = P::encode_patterns(patterns);
    let num_chunks = max_len.div_ceil(64);
    log::debug!("max len {max_len} num_chunks {num_chunks}");
    for m in &mut *m {
        m.alpha = alpha;
        m.max_overhang = max_overhang;
        m.q = pattern.len();
        m.deltas.clear();
        m.deltas.reserve((m.q + 1) * num_chunks);
    }
    // Carries between consecutive windows: one SIMD vector per pattern row.
    let mut hp: Vec<S> = vec![S::splat(1); pattern.len()];
    let mut hm: Vec<S> = vec![S::splat(0); pattern.len()];
    init_deltas_for_overshoot_all_lanes(&mut hp, alpha, max_overhang);
    let mut text_profile: [_; LANES] = from_fn(|_| P::alloc_out());
    for i in 0..num_chunks {
        // Encode this 64-character window of every text, padding past its end with `N`.
        for lane in 0..lanes {
            let mut slice = [b'N'; 64];
            let block = texts[lane].get(64 * i..).unwrap_or_default();
            let block = block.get(..64).unwrap_or(block);
            slice[..block.len()].copy_from_slice(block);
            profiler.encode_ref(&slice, &mut text_profile[lane]);
        }
        let mut vp = S::splat(0);
        let mut vm = S::splat(0);
        // Row 0 is the all-zero top boundary.
        for lane in 0..lanes {
            let v = V::from(vp.as_array()[lane], vm.as_array()[lane]);
            m[lane].deltas.push(v);
        }
        for j in 0..pattern.len() {
            // Each lane compares its own pattern's row `j` against its own text window.
            let eq = from_fn(|lane| P::eq(&pattern_profiles[j][lane], &text_profile[lane])).into();
            compute_block_simd(&mut hp[j], &mut hm[j], &mut vp, &mut vm, eq);
            for lane in 0..lanes {
                let v = V::from(vp.as_array()[lane], vm.as_array()[lane]);
                m[lane].deltas.push(v);
            }
        }
    }
    for lane in 0..lanes {
        assert_eq!(m[lane].deltas.len(), num_chunks * (m[lane].q + 1));
    }
}
/// Traces back an optimal alignment of `pattern` that ends at absolute text position `end_pos`.
///
/// `m` must hold the DP costs for `text`. Its column `i` corresponds to absolute position
/// `text_offset + i`, and its row `j` to the first `j` pattern characters. The walk proceeds
/// greedily from the end cell. At each step it prefers a match, then a substitution, a deletion
/// (text character consumed) and finally an insertion (pattern character consumed). It stops
/// once the whole pattern is consumed.
///
/// Overhang handling (only meaningful when `alpha` is set):
/// - If `end_pos` lies past the end of `text`, the excess pattern characters are treated as a
///   right overhang. They are removed from `pattern_end` and charged `floor(overshoot * alpha)`
///   on top of the matrix cost.
/// - If the walk reaches column 0 with pattern characters left, they become a left overhang
///   (`pattern_start`) and are charged like the column-0 cost in [`CostMatrix`].
///
/// The returned [`Match`] has `pattern_idx`/`text_idx` set to 0 and strand `Fwd`. `text_end` is
/// the column where the traced alignment ends.
///
/// # Panics
/// - if `end_pos < text_offset`;
/// - if `end_pos` lies past the end of `text` while `alpha` is `None`;
/// - if no predecessor with a consistent cost is found. The message names the offending
///   character when the query or text contains one the profile `P` does not accept.
pub fn get_trace<P: Profile>(
    pattern: &[u8],
    text_offset: usize,
    end_pos: usize,
    text: &[u8],
    m: &impl CostLookup,
    alpha: Option<f32>,
    max_overhang: Option<usize>,
) -> Match {
    assert!(
        text_offset <= end_pos,
        "end_pos ({end_pos}) must not precede text_offset ({text_offset})"
    );
    let mut j = pattern.len();
    let mut i = end_pos - text_offset;
    let cost = |j: usize, i: usize| -> Cost { m.get(i, j) };
    log::debug!("Trace ({j}, {i}) end pos {end_pos} offset {text_offset}");
    let mut g = cost(j, i);
    let mut total_cost = g;
    log::debug!("Initial cost at ({j}, {i}) is {g}");
    let mut cigar = Cigar::default();
    let mut pattern_start = 0;
    let mut pattern_end = pattern.len();
    if i > text.len() {
        // Right overhang: step diagonally back to the last real text column. `g` is deliberately
        // left unchanged. NOTE(review): this presumably relies on the padding past the text end
        // matching every pattern character, so the cost is constant along that diagonal.
        let overshoot = i - text.len();
        pattern_end -= overshoot;
        let alpha = alpha.expect(
            "an end position past the end of the text requires `alpha` (overhang) to be set",
        );
        // NOTE(review): unlike the left overhang below, `max_overhang` is not applied here;
        // confirm this matches the cost reported by the search.
        let overshoot_cost = (overshoot as f32 * alpha).floor() as Cost;
        total_cost += overshoot_cost;
        i -= overshoot;
        j -= overshoot;
        log::debug!("Trace from ({j}, {i}) for total cost {total_cost}");
        log::debug!("Right overshoot {overshoot} for cost {overshoot_cost}");
    } else {
        log::debug!("Trace from ({j}, {i}) for total cost {total_cost}");
    }
    // The traced alignment ends at this column; record it before `i` is walked back.
    let text_end = text_offset + i;
    loop {
        if j == 0 {
            break;
        }
        // Left overhang: the remaining `j` pattern characters hang off the start of the text.
        if i == 0
            && let Some(alpha) = alpha
        {
            pattern_start = j;
            let overshoot_cost = if let Some(mo) = max_overhang {
                (j.min(mo) as f32 * alpha).floor() as Cost + j.saturating_sub(mo) as Cost
            } else {
                (j as f32 * alpha).floor() as Cost
            };
            g -= overshoot_cost;
            break;
        }
        if i > 0 && cost(j - 1, i - 1) == g && P::is_match(pattern[j - 1], text[i - 1]) {
            cigar.push(pa_types::CigarOp::Match);
            j -= 1;
            i -= 1;
            continue;
        }
        // Every remaining move costs exactly 1.
        g -= 1;
        if i > 0 && cost(j - 1, i - 1) == g {
            cigar.push(pa_types::CigarOp::Sub);
            j -= 1;
            i -= 1;
            continue;
        }
        if i > 0 && cost(j, i - 1) == g {
            cigar.push(pa_types::CigarOp::Del);
            i -= 1;
            continue;
        }
        if cost(j - 1, i) == g {
            cigar.push(pa_types::CigarOp::Ins);
            j -= 1;
            continue;
        }
        if !P::valid_seq(&[pattern[j - 1]]) {
            panic!(
                "Trace failed, because the query contains non-{:?} character {} at position {}. (Use `profiles::Iupac` instead of `profiles::Dna`.)",
                std::any::type_name::<P>(),
                pattern[j - 1] as char,
                j - 1
            );
        }
        if !P::valid_seq(&[text[i - 1]]) {
            panic!(
                "Trace failed, because the text contains non-{:?} character {} at position {}. (Use `profiles::Iupac` instead of `profiles::Dna`.)",
                std::any::type_name::<P>(),
                text[i - 1] as char,
                i - 1
            );
        }
        panic!(
            "Trace failed! No ancestor found of {j} {i} at distance {}",
            g + 1
        );
    }
    assert_eq!(g, 0, "Remaining cost after the trace must be 0.");
    cigar.reverse();
    Match {
        pattern_idx: 0,
        text_idx: 0,
        cost: total_cost,
        text_start: text_offset + i,
        text_end,
        pattern_start,
        pattern_end,
        strand: Strand::Fwd,
        cigar,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profiles::Dna;

    #[test]
    fn test_traceback() {
        let query = b"ATTTTCCCGGGGATTTT".as_slice();
        let text2: &[u8] = b"ATTTTGGGGATTTT".as_slice();
        let mut cost_matrix = Default::default();
        fill::<Dna>(query, text2, text2.len(), &mut cost_matrix, None, None);

        // Without `alpha`, column 0 costs one per pattern character.
        assert_eq!(cost_matrix.get(0, 0), 0);
        assert_eq!(cost_matrix.get(0, 5), 5);
        // The query carries an extra `CCC` the text lacks; 17 - 14 = 3 is also a lower bound.
        assert_eq!(cost_matrix.get(text2.len(), query.len()), 3);

        let trace = get_trace::<Dna>(query, 0, text2.len(), text2, &cost_matrix, None, None);
        println!("Trace: {:?}", trace);
        assert_eq!(trace.cost, 3);
        assert_eq!(trace.text_end, text2.len());
        assert_eq!(trace.pattern_start, 0);
        assert_eq!(trace.pattern_end, query.len());
    }

    #[test]
    fn test_traceback_simd() {
        let query = b"ATTTTCCCGGGGATTTT".as_slice();
        let text1 = b"ATTTTCCCGGGGATTTT".as_slice();
        let text2 = b"ATTTTGGGGATTTT".as_slice();
        let text3 = b"TGGGGATTTT".as_slice();
        let text4 = b"TTTTTTTTTTATTTTGGGGATTTT".as_slice();
        let texts = [text1, text2, text3, text4];
        let mut cost_matrix = Default::default();
        simd_fill::<Dna>(query, &texts, text4.len(), &mut cost_matrix, None, None);

        // text1: exact copy. text2: missing `CCC`. text3: only 10 characters for a 17-long
        // query. text4: contains text2 and has no `C`, so 3 is both achievable and a lower bound.
        let expected = [0, 3, 7, 3];
        for (lane, (text, want)) in texts.iter().zip(expected).enumerate() {
            let trace =
                get_trace::<Dna>(query, 0, text.len(), text, &cost_matrix[lane], None, None);
            println!("Trace: {:?}", trace);
            assert_eq!(trace.cost, want, "lane {lane}");
            assert_eq!(trace.text_end, text.len(), "lane {lane}");
        }
    }
}