use std::default::Default;
use std::iter;
use std::marker::PhantomData;
use std::ops::Range;
use crate::alignment::AlignmentOperation;
use crate::pattern_matching::myers::{word_size, BitVec, DistType, State};
/// Stores per-column `State`s of the traceback in a caller-provided slice and
/// creates a [`TracebackHandler`] that walks them.
///
/// Implementations decide how many `State`s a column occupies and where the
/// column at a given slot `pos` lives inside the slice.
pub(super) trait StatesHandler<'a, T, D>
where
    T: BitVec + 'a,
    D: DistType,
{
    /// Handler produced by [`init_traceback`](Self::init_traceback).
    type TracebackHandler: TracebackHandler<'a, T, D>;
    /// Type of one column passed to [`add_state`](Self::add_state); may be unsized.
    type TracebackColumn: ?Sized;
    /// Prepares the handler for `n` column slots and `m` (presumably the pattern
    /// length — confirm). Returns the total number of `State`s the caller must
    /// allocate; `Traceback::new` resizes its state vector to this length.
    fn init(&mut self, n: usize, m: D) -> usize;
    /// Number of `State`s stored per column. `Traceback::print_tb_matrix` uses it
    /// as the chunk size when splitting the state slice into columns.
    fn n_blocks(&self) -> usize;
    /// Initializes column slot `pos` in `states`. `Traceback::new` calls this for
    /// slot 0 before any real column is added. NOTE(review): presumably fills the
    /// slot with maximum-distance states — confirm against implementations.
    fn set_max_state(&self, pos: usize, states: &mut [State<T, D>]);
    /// Stores the column `source` at slot `pos` in `states`.
    fn add_state(&self, source: &Self::TracebackColumn, pos: usize, states: &mut [State<T, D>]);
    /// Returns the distance recorded for column slot `pos`, or `None` if the
    /// implementation cannot provide one.
    fn dist_at(&self, pos: usize, states: &[State<T, D>]) -> Option<D>;
    /// Creates a handler that starts walking back from column slot `pos`.
    /// Returns `None` if no traceback can start there, in which case
    /// `Traceback`'s traceback methods return `None` as well.
    fn init_traceback(
        &self,
        m: D,
        pos: usize,
        states: &'a [State<T, D>],
    ) -> Option<Self::TracebackHandler>;
}
/// Cursor over the stored distance matrix, used by `Traceback` to walk from an
/// end cell back towards the start of the alignment.
///
/// Moves to the left are two-phase: they are prepared with
/// [`prepare_diagonal`](Self::prepare_diagonal) or
/// [`try_prepare_left`](Self::try_prepare_left) and then completed with
/// [`finish_move_left`](Self::finish_move_left).
pub(super) trait TracebackHandler<'a, T, D>
where
    T: BitVec + 'a,
    D: DistType,
{
    /// Distance of the current cell.
    fn dist(&self) -> D;
    /// Distance that `Traceback` compares against [`dist`](Self::dist) to tell a
    /// substitution (`left_dist + 1 == dist`) from a match (`left_dist == dist`).
    /// NOTE(review): presumably the distance of the diagonal/left neighbour — confirm.
    fn left_dist(&self) -> D;
    /// Tries to move one row up, which `Traceback` records as an insertion.
    /// Returns `true` if the move was made.
    fn try_move_up(&mut self) -> bool;
    /// Moves one row up unconditionally. Not called by `Traceback` in this file.
    fn move_up(&mut self);
    /// Tries to prepare a horizontal move, which `Traceback` records as a deletion.
    /// Returns `true` on success; the move is completed by
    /// [`finish_move_left`](Self::finish_move_left).
    fn try_prepare_left(&mut self) -> bool;
    /// Prepares a diagonal move (match or substitution); completed by
    /// [`finish_move_left`](Self::finish_move_left).
    fn prepare_diagonal(&mut self);
    /// Completes a move prepared by `prepare_diagonal` or `try_prepare_left`.
    fn finish_move_left(&mut self);
    /// Returns `true` once the traceback is finished; `Traceback` stops walking then.
    fn done(&self) -> bool;
    /// Debugging helper that dumps the handler's internal state.
    #[allow(dead_code)]
    fn print_state(&self);
}
/// Records the column states produced while scanning a text with the Myers
/// bit-vector algorithm and reconstructs alignments from them.
///
/// Columns are written to a ring of `num_cols + 2` slots (see `Traceback::new`);
/// the storage layout inside each slot is delegated to the `StatesHandler` `H`.
#[derive(Clone, Debug)]
pub(super) struct Traceback<'a, T, D, H>
where
    T: BitVec + 'a,
    D: DistType,
    H: StatesHandler<'a, T, D>,
{
    // Passed to the handler's `init`/`init_traceback`; used as the row count when
    // printing the matrix (presumably the pattern length).
    m: D,
    // Endless iterator over the column slots `0..num_cols + 2`.
    positions: iter::Cycle<Range<usize>>,
    // Strategy that stores/reads columns and creates traceback cursors.
    handler: H,
    // Slot of the most recently stored column.
    pos: usize,
    // Ties the lifetime `'a` and the bit-vector type `T` to the struct.
    _t: PhantomData<&'a T>,
}
impl<'a, T, D, H> Traceback<'a, T, D, H>
where
    T: BitVec,
    D: DistType,
    H: StatesHandler<'a, T, D>,
{
    /// Creates a traceback over `num_cols` text columns and stores `initial_state`
    /// as the first column.
    ///
    /// Two extra slots are reserved beyond `num_cols`. Slot 0 is initialized via
    /// `StatesHandler::set_max_state` and `initial_state` goes into slot 1. Columns
    /// added later via [`add_state`](Self::add_state) start at slot 2. `states` is
    /// resized to the length returned by `StatesHandler::init`; any new elements are
    /// `Default::default()`.
    #[inline]
    pub fn new(
        states: &mut Vec<State<T, D>>,
        initial_state: &H::TracebackColumn,
        num_cols: usize,
        m: D,
        mut handler: H,
    ) -> Self {
        let num_cols = num_cols + 2;
        let n_states = handler.init(num_cols, m);
        let mut tb = Traceback {
            m,
            positions: (0..num_cols).cycle(),
            handler,
            pos: 0,
            _t: PhantomData,
        };
        states.resize_with(n_states, Default::default);
        tb.pos = tb.next_position();
        tb.handler.set_max_state(tb.pos, states);
        tb.add_state(initial_state, states);
        tb
    }

    /// Advances to the next column slot and returns it. Wraps to 0 after the last slot.
    #[inline]
    fn next_position(&mut self) -> usize {
        // `positions` cycles over a non-empty range (at least 2 slots), so it
        // never runs out.
        self.positions
            .next()
            .expect("cycle over a non-empty range is infinite")
    }

    /// Moves to the next column slot, wrapping around, and stores `column` there.
    #[inline]
    pub fn add_state(&mut self, column: &H::TracebackColumn, states: &mut [State<T, D>]) {
        self.pos = self.next_position();
        self.handler.add_state(column, self.pos, states);
    }

    /// Returns the distance stored for column `pos`, counting columns added after
    /// the initial state from 0. Returns `None` if that column lies beyond the most
    /// recently added slot, or if `pos + 2` overflows.
    ///
    /// NOTE(review): after the slot counter wraps around, `self.pos` is a small slot
    /// index, so this bound check assumes callers only ask for columns that are still
    /// stored — confirm against callers.
    #[inline]
    pub fn dist_at(&self, pos: usize, states: &'a [State<T, D>]) -> Option<D> {
        // Skip the two reserved slots (max state and initial state).
        let pos = pos.checked_add(2)?;
        if pos > self.pos {
            return None;
        }
        self.handler.dist_at(pos, states)
    }

    /// Runs the traceback from the most recently added column.
    ///
    /// See [`traceback_at`](Self::traceback_at) for the meaning of `ops` and the
    /// return value.
    #[inline]
    pub fn traceback(
        &self,
        ops: Option<&mut Vec<AlignmentOperation>>,
        states: &'a [State<T, D>],
    ) -> Option<(D, D)> {
        self._traceback_at(self.pos, ops, states)
    }

    /// Runs the traceback from column `pos`, counting columns added after the
    /// initial state from 0.
    ///
    /// If `ops` is given, the visited alignment operations are appended to it.
    /// Returns `(h_offset, dist)`, where `h_offset` is the number of text columns
    /// consumed and `dist` is the distance at the start cell. Returns `None` if `pos`
    /// lies beyond the most recently added column, if `pos + 2` overflows, or if the
    /// handler cannot start a traceback there.
    #[inline]
    pub fn traceback_at(
        &self,
        pos: usize,
        ops: Option<&mut Vec<AlignmentOperation>>,
        states: &'a [State<T, D>],
    ) -> Option<(D, D)> {
        // Skip the two reserved slots (max state and initial state).
        let pos = pos.checked_add(2)?;
        if pos > self.pos {
            return None;
        }
        self._traceback_at(pos, ops, states)
    }

    /// Walks back from column slot `pos` until the handler reports `done()`.
    ///
    /// At each step the operation is chosen in this priority order:
    /// 1. substitution — diagonal move when `left_dist + 1 == dist`;
    /// 2. insertion — vertical move via `try_move_up`;
    /// 3. deletion — horizontal move via `try_prepare_left`;
    /// 4. match — diagonal move with `left_dist == dist`.
    ///
    /// Every operation except an insertion consumes one text column. Operations are
    /// appended to `ops` in the order they are visited, starting at `pos`.
    /// NOTE(review): this is presumably reverse alignment order — confirm callers
    /// reverse it.
    ///
    /// Returns `(text columns consumed, distance at the start cell)`, or `None` if the
    /// handler cannot start a traceback at `pos`.
    #[inline]
    fn _traceback_at(
        &self,
        pos: usize,
        mut ops: Option<&mut Vec<AlignmentOperation>>,
        state_slice: &'a [State<T, D>],
    ) -> Option<(D, D)> {
        use self::AlignmentOperation::*;

        let mut h = self.handler.init_traceback(self.m, pos, state_slice)?;
        let mut h_offset = D::zero();
        let dist = h.dist();
        while !h.done() {
            let op = if h.left_dist().wrapping_add(&D::one()) == h.dist() {
                h.prepare_diagonal();
                Subst
            } else if h.try_move_up() {
                // A vertical move stays within the current text column.
                Ins
            } else if h.try_prepare_left() {
                Del
            } else {
                debug_assert!(h.left_dist() == h.dist());
                h.prepare_diagonal();
                Match
            };
            // Every operation except an insertion ends with a move to the left.
            if !matches!(op, Ins) {
                h_offset += D::one();
                h.finish_move_left();
            }
            if let Some(o) = ops.as_mut() {
                o.push(op);
            }
        }
        Some((h_offset, dist))
    }

    /// Debugging helper: prints the distance matrix reconstructed from the stored
    /// column states to stderr, with the newest column at slot `pos`.
    ///
    /// Columns are read backwards from slot `pos`, wrapping around to the end of
    /// `state_slice`. Reading stops after the first column whose blocks are all new
    /// or max; that column is still printed. Cells with a distance of at least
    /// `D::max_value() >> 1` are printed blank; cells missing from a column show `-`.
    ///
    /// # Panics
    /// Panics if `n_blocks * (pos + 1)` exceeds `state_slice.len()`, or if
    /// `n_blocks` is 0.
    #[allow(dead_code)] // only invoked manually while debugging
    fn print_tb_matrix(&self, pos: usize, state_slice: &'a [State<T, D>]) {
        let n_blocks = self.handler.n_blocks();
        let end = n_blocks * (pos + 1);
        let columns = state_slice[..end]
            .chunks(n_blocks)
            .rev()
            .chain(state_slice.chunks(n_blocks).rev().cycle());
        let m = self.m.to_usize().unwrap();
        let word_bits = word_size::<T>();

        let mut out = vec![];
        for col in columns {
            let mut col_out = Vec::with_capacity(m);
            let mut empty = true;
            for (i, block) in col.iter().enumerate().rev() {
                if !(block.is_new() || block.is_max()) {
                    empty = false;
                }
                // The last block may cover only part of a machine word.
                let rows = if (i + 1) * word_bits <= m {
                    word_bits
                } else {
                    m % word_bits
                };
                let mut cell = *block;
                let mut pos_mask = T::one() << (rows - 1);
                col_out.push(cell.dist);
                for _ in 0..rows {
                    cell.adjust_one_up(pos_mask);
                    pos_mask >>= 1;
                    col_out.push(cell.dist);
                }
            }
            out.push(col_out);
            if empty {
                break;
            }
        }

        for j in (0..m).rev() {
            eprint!("{:>4}: ", m - j + 1);
            for col in out.iter().rev() {
                // Every cell is 4 characters wide so the columns line up with `{:>4?}`.
                match col.get(j) {
                    // Very large values presumably mark unset/overflowed cells.
                    Some(d) if *d >= (D::max_value() >> 1) => eprint!("    "),
                    Some(d) => eprint!("{:>4?}", d),
                    None => eprint!("   -"),
                }
            }
            eprintln!();
        }
    }
}