use std::cmp::{max, Ordering};
use crate::utils::TextSlice;
use crate::alignment::pairwise::{MatchFunc, Scoring};
use petgraph::graph::NodeIndex;
use petgraph::visit::Topo;
use petgraph::{Directed, Graph, Incoming};
/// Sentinel used as "negative infinity" when searching for the best-scoring
/// predecessor cell in `Poa::global`.
///
/// The value is roughly 40% of `i32::MIN`, which leaves headroom for adding
/// gap or mismatch penalties to it before an `i32` would overflow.
pub const MIN_SCORE: i32 = -858_993_459;
/// Graph type backing the partial order alignment: a directed `petgraph`
/// graph with `u8` node weights (the bases added from sequences), `i32` edge
/// weights (incremented each time an alignment reuses an edge) and `usize`
/// indices.
pub type POAGraph = Graph<u8, i32, Directed, usize>;
/// One step of a graph-to-sequence alignment as recorded in the traceback
/// matrix.
///
/// The payloads are the values written by `Poa::global`; a `None` payload
/// marks a cell that has no predecessor node to point back to.
#[derive(Debug, Clone)]
pub enum AlignmentOperation {
/// A query base aligned to a graph node (match or mismatch).
/// `Some((prev, node))` holds the index of the predecessor node and the
/// index of the aligned node; `None` is used for nodes without incoming
/// edges and as the default fill of the matrix.
Match(Option<(usize, usize)>),
/// A graph node skipped by the query. `Some((prev, row))` holds the index
/// of the predecessor node and the matrix row of the skipped node (its
/// node index plus one); `None` marks the column-0 border.
Del(Option<(usize, usize)>),
/// A query base with no graph counterpart. `Some(node)` holds the index of
/// the node whose row the insertion was scored in; `None` marks the row-0
/// border.
Ins(Option<usize>),
}
/// Result of tracing back through a `Traceback` matrix.
pub struct Alignment {
/// Score stored in the cell where the traceback started (row of the last
/// node in topological order, last query column).
pub score: i32,
/// Operations in sequence order (the traceback collects them back-to-front
/// and reverses them before returning).
operations: Vec<AlignmentOperation>,
}
/// A single cell of the traceback matrix: the best score reaching the cell and
/// the operation that produced it.
///
/// Comparison and equality of cells look at `score` only; `op` is ignored.
#[derive(Debug, Clone)]
pub struct TracebackCell {
/// Best alignment score ending in this cell.
score: i32,
/// Operation that led to `score`; followed during traceback.
op: AlignmentOperation,
}
impl Ord for TracebackCell {
fn cmp(&self, other: &TracebackCell) -> Ordering {
self.score.cmp(&other.score)
}
}
impl PartialOrd for TracebackCell {
fn partial_cmp(&self, other: &TracebackCell) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl PartialEq for TracebackCell {
fn eq(&self, other: &TracebackCell) -> bool {
self.score == other.score
}
}
// Marker impl: score equality is a full equivalence relation, which `Ord`
// requires of its implementors.
impl Eq for TracebackCell {}
/// Dynamic-programming matrix of a graph-to-sequence alignment, plus the
/// bookkeeping needed to trace the best path back through it.
pub struct Traceback {
/// Number of graph nodes the matrix was sized for (`m` in `with_capacity`).
rows: usize,
/// Length of the query the matrix was sized for (`n` in `with_capacity`).
cols: usize,
/// Last node visited in topological order by `Poa::global`; the traceback
/// starts from this node's row.
last: NodeIndex<usize>,
/// `(rows + 1) x (cols + 1)` cells when built by `with_capacity`; row 0 and
/// column 0 form the gap-only border. Empty when built by `new`.
matrix: Vec<Vec<TracebackCell>>,
}
impl Traceback {
/// Creates a traceback for `m` graph nodes and a query of length `n`.
///
/// The matrix gets `m + 1` rows of `n + 1` cells, every cell starting out
/// as a zero-score `Match(None)`.
fn with_capacity(m: usize, n: usize) -> Self {
    let blank = TracebackCell {
        score: 0,
        op: AlignmentOperation::Match(None),
    };
    let blank_row = vec![blank; n + 1];
    Traceback {
        rows: m,
        cols: n,
        last: NodeIndex::new(0),
        matrix: vec![blank_row; m + 1],
    }
}
fn initialize_scores(&mut self, gap_open: i32) {
for (i, row) in self
.matrix
.iter_mut()
.enumerate()
.take(self.rows + 1)
.skip(1)
{
row[0] = TracebackCell {
score: (i as i32) * gap_open,
op: AlignmentOperation::Del(None),
};
}
for j in 1..=self.cols {
self.matrix[0][j] = TracebackCell {
score: (j as i32) * gap_open,
op: AlignmentOperation::Ins(None),
};
}
}
fn new() -> Self {
Traceback {
rows: 0,
cols: 0,
last: NodeIndex::new(0),
matrix: Vec::new(),
}
}
/// Overwrites the cell at row `i`, column `j`. Panics if either index is out
/// of bounds.
fn set(&mut self, i: usize, j: usize, cell: TracebackCell) {
    let row = &mut self.matrix[i];
    row[j] = cell;
}
/// Borrows the cell at row `i`, column `j`. Panics if either index is out of
/// bounds.
fn get(&self, i: usize, j: usize) -> &TracebackCell {
    let row = &self.matrix[i];
    &row[j]
}
/// Dumps the score matrix to stdout as a tab-separated table.
///
/// The header row lists the query bases, and every following row starts with
/// the weight of a graph node followed by the scores of that node's row. The
/// border row and column (index 0) are not printed.
pub fn print(&self, g: &Graph<u8, i32, Directed, usize>, query: TextSlice) {
    let width = query.len();
    print!(".\t");
    for base in query.iter() {
        print!("{:?}\t", *base);
    }
    for (node_pos, node) in g.raw_nodes().iter().enumerate() {
        print!("\n{:?}\t", node.weight);
        for col in 1..=width {
            print!("{}.\t", self.get(node_pos + 1, col).score);
        }
    }
    println!();
}
/// Reconstructs the alignment by walking the traceback matrix backwards from
/// the cell at (row of `last`, last query column).
///
/// Returns the score stored in that starting cell together with the
/// operations in sequence order. For a traceback that holds no matrix yet
/// (built by `Traceback::new`), an empty alignment with score 0 is returned
/// instead of panicking on an out-of-bounds index.
///
/// When the path reaches a node without predecessors (`Match(None)`) while
/// query bases are still left, those bases were scored as insertions along
/// the row-0 border. They are emitted as leading `Ins(None)` operations so
/// that the operations account for every query base up to the traceback
/// start.
pub fn alignment(&self) -> Alignment {
    // Nothing has been aligned yet: there is no cell to read a score from.
    if self.matrix.is_empty() {
        return Alignment {
            score: 0,
            operations: Vec::new(),
        };
    }

    let end_row = self.last.index() + 1;
    let mut ops: Vec<AlignmentOperation> = vec![];
    let mut i = end_row;
    let mut j = self.cols;
    while i > 0 && j > 0 {
        let op = &self.matrix[i][j].op;
        ops.push(op.clone());
        match *op {
            AlignmentOperation::Match(Some((p, _))) => {
                i = p + 1;
                j -= 1;
            }
            AlignmentOperation::Del(Some((p, _))) => {
                i = p + 1;
            }
            AlignmentOperation::Ins(Some(p)) => {
                i = p + 1;
                j -= 1;
            }
            AlignmentOperation::Match(None) => {
                // Start node: its score was derived from the border cell
                // (0, j - 1), so continue from there.
                i = 0;
                j -= 1;
            }
            // Border cells: a deletion consumes a graph row, an insertion
            // consumes a query column.
            AlignmentOperation::Del(None) => {
                i -= 1;
            }
            AlignmentOperation::Ins(None) => {
                j -= 1;
            }
        }
    }
    // Only a `Match(None)` step can bring `i` to 0; the remaining `j` query
    // bases precede the first aligned node.
    if i == 0 {
        for _ in 0..j {
            ops.push(AlignmentOperation::Ins(None));
        }
    }
    ops.reverse();
    Alignment {
        score: self.matrix[end_row][self.cols].score,
        operations: ops,
    }
}
}
/// Convenience wrapper bundling a `Poa` graph with the most recent query and
/// the traceback computed for it, so that calls can be chained.
pub struct Aligner<F: MatchFunc> {
/// Traceback produced by the latest `global` call; empty right after `new`.
traceback: Traceback,
/// Sequence most recently passed to `global`; initially the reference.
query: Vec<u8>,
/// The alignment graph together with its scoring scheme.
poa: Poa<F>,
}
impl<F: MatchFunc> Aligner<F> {
    /// Builds an aligner whose graph is the linear chain spelled by
    /// `reference`. No alignment has been computed at this point.
    pub fn new(scoring: Scoring<F>, reference: TextSlice) -> Self {
        let traceback = Traceback::new();
        let query = reference.to_vec();
        let poa = Poa::from_string(scoring, reference);
        Self {
            traceback,
            query,
            poa,
        }
    }

    /// Merges the stored query into the graph along the stored traceback.
    /// Returns `self` for chaining.
    pub fn add_to_graph(&mut self) -> &mut Self {
        self.poa
            .add_alignment(&self.traceback.alignment(), &self.query);
        self
    }

    /// Returns the alignment described by the stored traceback.
    pub fn alignment(&self) -> Alignment {
        let traceback = &self.traceback;
        traceback.alignment()
    }

    /// Globally aligns `query` against the graph, remembering both the query
    /// and the resulting traceback. Returns `self` for chaining.
    pub fn global(&mut self, query: TextSlice) -> &mut Self {
        self.query.clear();
        self.query.extend_from_slice(query);
        self.traceback = self.poa.global(query);
        self
    }

    /// Borrows the underlying alignment graph.
    pub fn graph(&self) -> &POAGraph {
        let poa = &self.poa;
        &poa.graph
    }
}
/// A partial order alignment graph paired with the scoring scheme used to
/// align new sequences against it.
pub struct Poa<F: MatchFunc> {
/// Scoring scheme; `global` reads its `match_fn` and `gap_open`.
scoring: Scoring<F>,
/// The alignment graph: node weights are bases, edge weights count how
/// often an edge has been used by added sequences.
pub graph: POAGraph,
}
impl<F: MatchFunc> Poa<F> {
pub fn new(scoring: Scoring<F>, graph: POAGraph) -> Self {
Poa { scoring, graph }
}
/// Builds a linear graph from `seq`: one node per base, each connected to the
/// next by an edge of weight 1.
///
/// An empty `seq` yields an empty graph. Previously this case panicked,
/// because `seq.len() - 1` underflowed and `seq[0]` was out of bounds. Note
/// that `global` still requires a non-empty graph.
pub fn from_string(scoring: Scoring<F>, seq: TextSlice) -> Self {
    let mut graph: Graph<u8, i32, Directed, usize> =
        Graph::with_capacity(seq.len(), seq.len().saturating_sub(1));
    // Node added in the previous iteration, if any.
    let mut prev: Option<NodeIndex<usize>> = None;
    for &base in seq.iter() {
        let node = graph.add_node(base);
        if let Some(prev_node) = prev {
            graph.add_edge(prev_node, node, 1);
        }
        prev = Some(node);
    }
    Poa { scoring, graph }
}
/// Globally aligns `query` against the graph and returns the filled
/// traceback matrix.
///
/// Nodes are processed in topological order. For every (node, query base)
/// pair the cell takes the best of: a match/mismatch or a deletion coming
/// from any predecessor node, and an insertion coming from the previous
/// column of the same row. Nodes without predecessors only consider a match
/// from the row-0 border. Gaps are charged `gap_open` each.
///
/// `std::cmp::max` returns its second argument on ties, so among equally
/// scored candidates the one passed last wins (the insertion, then the
/// deletion, then later predecessors).
///
/// # Panics
///
/// Panics if the graph has no nodes.
pub fn global(&self, query: TextSlice) -> Traceback {
    assert!(self.graph.node_count() != 0);

    let node_total = self.graph.node_count();
    let query_len = query.len();
    let mut traceback = Traceback::with_capacity(node_total, query_len);
    traceback.initialize_scores(self.scoring.gap_open);
    let origin = TracebackCell {
        score: 0,
        op: AlignmentOperation::Match(None),
    };
    traceback.set(0, 0, origin);

    let mut order = Topo::new(&self.graph);
    while let Some(node) = order.next(&self.graph) {
        let node_base = self.graph.raw_nodes()[node.index()].weight;
        // Row 0 is the border, so node `k` lives in row `k + 1`.
        let row = node.index() + 1;
        traceback.last = node;
        let predecessors: Vec<NodeIndex<usize>> =
            self.graph.neighbors_directed(node, Incoming).collect();

        for (offset, query_base) in query.iter().enumerate() {
            let col = offset + 1;
            let from_graph = if predecessors.is_empty() {
                TracebackCell {
                    score: traceback.get(0, col - 1).score
                        + self.scoring.match_fn.score(node_base, *query_base),
                    op: AlignmentOperation::Match(None),
                }
            } else {
                let seed = TracebackCell {
                    score: MIN_SCORE,
                    op: AlignmentOperation::Match(None),
                };
                predecessors.iter().fold(seed, |best, pred| {
                    let pred_row: usize = pred.index() + 1;
                    let diagonal = TracebackCell {
                        score: traceback.get(pred_row, col - 1).score
                            + self.scoring.match_fn.score(node_base, *query_base),
                        op: AlignmentOperation::Match(Some((pred_row - 1, row - 1))),
                    };
                    let vertical = TracebackCell {
                        score: traceback.get(pred_row, col).score + self.scoring.gap_open,
                        op: AlignmentOperation::Del(Some((pred_row - 1, row))),
                    };
                    max(best, max(diagonal, vertical))
                })
            };
            let horizontal = TracebackCell {
                score: traceback.get(row, col - 1).score + self.scoring.gap_open,
                op: AlignmentOperation::Ins(Some(row - 1)),
            };
            traceback.set(row, col, max(from_graph, horizontal));
        }
    }
    traceback
}
/// Returns the indices of the graph edges connecting consecutive
/// `Match(Some(..))` operations of `aln`, starting from node 0.
///
/// All other operations are skipped without moving the current node.
///
/// # Panics
///
/// Panics if two consecutive matched nodes are not joined by an edge.
pub fn edges(&self, aln: Alignment) -> Vec<usize> {
    let mut trail = Vec::new();
    let mut cursor: NodeIndex<usize> = NodeIndex::new(0);
    for op in aln.operations {
        match op {
            AlignmentOperation::Match(Some((_, p))) => {
                let target = NodeIndex::new(p);
                let edge = self.graph.find_edge(cursor, target).unwrap();
                trail.push(edge.index());
                cursor = target;
            }
            AlignmentOperation::Match(None)
            | AlignmentOperation::Ins(_)
            | AlignmentOperation::Del(_) => {}
        }
    }
    trail
}
/// Threads `seq` into the graph along the path described by `aln`.
///
/// Starting from node 0, each operation is handled as follows:
/// * a match against a node carrying the same base (or the query base
///   `b'X'`) reuses that node, incrementing the weight of the connecting
///   edge or creating it with weight 1;
/// * a mismatch or an insertion relative to a node adds a new node for the
///   query base, linked from the current node with weight 1;
/// * `Match(None)` and `Ins(None)` only advance the position in `seq`;
/// * deletions change nothing.
pub fn add_alignment(&mut self, aln: &Alignment, seq: TextSlice) {
    // Node the next edge will start from.
    let mut anchor: NodeIndex<usize> = NodeIndex::new(0);
    // Position of the next unconsumed base in `seq`.
    let mut pos: usize = 0;
    for op in &aln.operations {
        match *op {
            AlignmentOperation::Del(_) => {}
            AlignmentOperation::Match(None) | AlignmentOperation::Ins(None) => {
                pos += 1;
            }
            AlignmentOperation::Ins(Some(_)) => {
                let fresh = self.graph.add_node(seq[pos]);
                self.graph.add_edge(anchor, fresh, 1);
                anchor = fresh;
                pos += 1;
            }
            AlignmentOperation::Match(Some((_, p))) => {
                let base = seq[pos];
                let aligned = NodeIndex::new(p);
                let reuse_node = base == self.graph.raw_nodes()[p].weight || base == b'X';
                if reuse_node {
                    if let Some(edge) = self.graph.find_edge(anchor, aligned) {
                        *self.graph.edge_weight_mut(edge).unwrap() += 1;
                    } else {
                        self.graph.add_edge(anchor, aligned, 1);
                    }
                    anchor = aligned;
                } else {
                    let fresh = self.graph.add_node(base);
                    self.graph.add_edge(anchor, fresh, 1);
                    anchor = fresh;
                }
                pos += 1;
            }
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::alignment::pairwise::Scoring;
use petgraph::graph::NodeIndex;
// A nine-base reference must become a directed chain of 9 nodes and 8 edges.
#[test]
fn test_init_graph() {
    let unit_scoring = Scoring::new(-1, 0, |a: u8, b: u8| if a == b { 1i32 } else { -1i32 });
    let chain = Poa::from_string(unit_scoring, b"123456789");
    let graph = &chain.graph;
    assert!(graph.is_directed());
    assert_eq!(graph.node_count(), 9);
    assert_eq!(graph.edge_count(), 8);
}
// Global alignment scores of several queries against a linear reference.
#[test]
fn test_alignment() {
    let unit_scoring = Scoring::new(-1, 0, |a: u8, b: u8| if a == b { 1i32 } else { -1i32 });
    let poa = Poa::from_string(unit_scoring, b"GATTACA");
    let cases: [(&[u8], i32); 3] = [
        (&b"GCATGCU"[..], 0),
        (&b"GCATGCUx"[..], -1),
        (&b"xCATGCU"[..], -2),
    ];
    for &(query, expected) in cases.iter() {
        let alignment = poa.global(query).alignment();
        assert_eq!(alignment.score, expected);
    }
}
// Aligning against a graph with a hand-built two-node `A` branch between
// nodes 1 and 2.
#[test]
fn test_branched_alignment() {
    let unit_scoring = Scoring::new(-1, 0, |a: u8, b: u8| if a == b { 1i32 } else { -1i32 });
    let reference = b"TTTTT";
    let query = b"TTATT";
    let mut poa = Poa::from_string(unit_scoring, reference);

    let branch_start: NodeIndex<usize> = NodeIndex::new(1);
    let branch_end: NodeIndex<usize> = NodeIndex::new(2);
    let first_a = poa.graph.add_node(b'A');
    let second_a = poa.graph.add_node(b'A');
    let links = [
        (branch_start, first_a),
        (first_a, second_a),
        (second_a, branch_end),
    ];
    for &(from, to) in links.iter() {
        poa.graph.add_edge(from, to, 1);
    }

    assert_eq!(poa.global(query).alignment().score, 3);
}
// Adding an aligned sequence to a branched graph must create the expected
// number of edges and connect the new nodes to the reference path.
#[test]
fn test_alt_branched_alignment() {
    let unit_scoring = Scoring::new(-1, 0, |a: u8, b: u8| if a == b { 1i32 } else { -1i32 });
    let reference = b"TTCCTTAA";
    let query = b"TTTTGGAA";
    let mut poa = Poa::from_string(unit_scoring, reference);

    let branch_start: NodeIndex<usize> = NodeIndex::new(1);
    let branch_end: NodeIndex<usize> = NodeIndex::new(2);
    let first_a = poa.graph.add_node(b'A');
    let second_a = poa.graph.add_node(b'A');
    let links = [
        (branch_start, first_a),
        (first_a, second_a),
        (second_a, branch_end),
    ];
    for &(from, to) in links.iter() {
        poa.graph.add_edge(from, to, 1);
    }

    let alignment = poa.global(query).alignment();
    poa.add_alignment(&alignment, query);

    assert_eq!(poa.graph.edge_count(), 14);
    let expected_edges: [(usize, usize); 2] = [(5, 10), (11, 6)];
    for &(from, to) in expected_edges.iter() {
        assert!(poa
            .graph
            .contains_edge(NodeIndex::new(from), NodeIndex::new(to)));
    }
}
// Scores before and after merging a sequence into a graph that carries a
// hand-built three-node `C` branch.
#[test]
fn test_insertion_on_branch() {
    let unit_scoring = Scoring::new(-1, 0, |a: u8, b: u8| if a == b { 1i32 } else { -1i32 });
    let reference = b"TTCCGGTTTAA";
    let first_query = b"TTGGTATGGGAA";
    let second_query = b"TTGGTTTGCGAA";
    let mut poa = Poa::from_string(unit_scoring, reference);

    let branch_start: NodeIndex<usize> = NodeIndex::new(1);
    let branch_end: NodeIndex<usize> = NodeIndex::new(2);
    let c1 = poa.graph.add_node(b'C');
    let c2 = poa.graph.add_node(b'C');
    let c3 = poa.graph.add_node(b'C');
    // Link the branch as a chain: start -> c1 -> c2 -> c3 -> end.
    let chain = [branch_start, c1, c2, c3, branch_end];
    for pair in chain.windows(2) {
        poa.graph.add_edge(pair[0], pair[1], 1);
    }

    let first = poa.global(first_query).alignment();
    assert_eq!(first.score, 2);
    poa.add_alignment(&first, first_query);

    let second = poa.global(second_query).alignment();
    assert_eq!(second.score, 10);
}
// The `Aligner` wrapper supports chaining `global` and `add_to_graph`.
#[test]
fn test_poa_method_chaining() {
    let unit_scoring = Scoring::new(-1, 0, |a: u8, b: u8| if a == b { 1i32 } else { -1i32 });
    let mut aligner = Aligner::new(unit_scoring, b"TTCCGGTTTAA");
    aligner.global(b"TTGGTATGGGAA").add_to_graph();
    aligner.global(b"TTGGTTTGCGAA").add_to_graph();
    let final_score = aligner.alignment().score;
    assert_eq!(final_score, 10);
}
}