extern crate arrayfire;
extern crate fnv;
use arrayfire::{Array, Dim4, constant, add, col, row, cols, rows, join, lookup, set_col, set_row, maxof, replace, gt, ge, eq, max_all};
use std::hash::BuildHasherDefault;
use std::collections::HashMap;
use fnv::FnvHasher;
/// A single biological sequence element (one nucleotide or amino-acid byte).
pub type BElm = u8;
/// A biological sequence, or one row of an alignment.
type BSq = Vec<BElm>;
/// A 0-based position inside a sequence.
type AlgnPs = usize;
/// Start and end position (both inclusive) of the aligned region within one sequence.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AlgnPsPr {
    strt_ps: AlgnPs,
    end_ps: AlgnPs,
}
/// An alignment score.
type AlgnScr = f32;
/// Result of a pairwise alignment.
///
/// Derives `Clone` and `PartialEq` so results can be copied and compared (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct PrAlgn {
    /// The two aligned rows; both have the same length and gaps are written as `-`.
    algn: (BSq, BSq),
    /// Aligned region in the first and in the second sequence, respectively.
    algn_ps_pr_pr: (AlgnPsPr, AlgnPsPr),
    /// Score of the alignment.
    algn_scr: AlgnScr,
}
// Hasher builder that plugs the `fnv` crate's `FnvHasher` into `HashMap` instead of the
// std default hasher.
type Hshr = BuildHasherDefault<FnvHasher>;
/// Substitution matrix: the score for each pair of sequence elements.
///
/// `gpu_sw` treats pairs absent from the map as scoring `0.`, and every element used in a
/// key must belong to the alphabet returned by `gt_alphbt`.
pub type SbstMt = HashMap<(BElm, BElm), AlgnScr, Hshr>;
/// Gap scoring scheme for affine-gap alignment.
///
/// Both values are *added* to the running score by `gpu_sw`. A gap of length `k`
/// therefore contributes `gp_opn_pnlty + k * gp_extnsn_pnlty`, so penalties are normally
/// given as non-positive numbers, e.g. `AlgnScrSchm::new(-10., -1.)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AlgnScrSchm {
    /// Score added once when a gap is opened.
    gp_opn_pnlty: AlgnScr,
    /// Score added for every position a gap spans, including the first one.
    gp_extnsn_pnlty: AlgnScr,
}
impl AlgnScrSchm {
    /// Creates a scheme from the gap-open and gap-extension scores.
    pub fn new(gp_opn_pnlty: AlgnScr, gp_extnsn_pnlty: AlgnScr) -> AlgnScrSchm {
        AlgnScrSchm {
            gp_opn_pnlty,
            gp_extnsn_pnlty,
        }
    }
}
/// Integer type used for ArrayFire dimension sizes when building `Dim4` values.
type AfDm = u64;
/// A set of accepted sequence elements; an element's index in it serves as its "hash".
pub type Alphbt<'a> = &'a [BElm];
/// Code stored in the source (traceback) matrix for every DP cell.
type DpSrc = u32;
/// Gap symbol written into alignment rows.
const GP: BElm = b'-';
/// Cell reached by a diagonal move (one element of each sequence is aligned).
const DGNL: DpSrc = 0;
/// Cell reached by a vertical move (consumes an element of the first sequence only).
const VRTCL: DpSrc = 1;
/// Cell reached by a horizontal move (consumes an element of the second sequence only).
const HRZNTL: DpSrc = 2;
pub fn gpu_sw(b_sq_pr: &(&[BElm], &[BElm]), sbst_mt: &SbstMt, algn_scr_schm: &AlgnScrSchm, is_dna: bool) -> PrAlgn {
let b_sq_ln_pr = (b_sq_pr.0.len(), b_sq_pr.1.len());
let alphbt = gt_alphbt(is_dna);
let gpu_b_sq_pr = (
Array::new(&b_sq_pr.0.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.0 as AfDm, 1, 1, 1])),
Array::new(&b_sq_pr.1.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.1 as AfDm, 1, 1, 1])),
);
let alphbt_ln = alphbt.len();
let mut hsh_sbst_mt = vec![vec![0.; alphbt_ln]; alphbt_ln];
for (b_elm_pr, &sbst_scr) in sbst_mt {
hsh_sbst_mt[gt_hsh_b_elm(b_elm_pr.0, alphbt) as usize][gt_hsh_b_elm(b_elm_pr.1, alphbt) as usize] = sbst_scr;
}
let hsh_sbst_mt = hsh_sbst_mt.iter().flat_map(|sbst_scr| sbst_scr.clone()).collect::<Vec<AlgnScr>>();
let gpu_sbst_mt = Array::new(&hsh_sbst_mt, Dim4::new(&[alphbt_ln as AfDm, alphbt_ln as AfDm, 1, 1]));
let scr_mt = lookup(&lookup(&gpu_sbst_mt, &gpu_b_sq_pr.0, 0), &gpu_b_sq_pr.1, 1);
let scr_mt_dms = scr_mt.dims();
let mut dp_mt = constant(0. as AlgnScr, Dim4::new(&[scr_mt_dms[0] + 1, scr_mt_dms[1] + 1, 1, 1]));
let (gp_opn_pnlty, gp_extnsn_pnlty) = (algn_scr_schm.gp_opn_pnlty, algn_scr_schm.gp_extnsn_pnlty);
let dp_mt_dms = dp_mt.dims();
let mut src_mt = constant(DGNL, Dim4::new(&[1, dp_mt_dms[1], 1, 1]));
let rw_tl_dms = Dim4::new(&[1, scr_mt_dms[1], 1, 1]);
let cl_dms = Dim4::new(&[1, 1, 1, 1]);
for i in 1 .. dp_mt_dms[0] {
let prvs_rw = row(&dp_mt, i - 1);
let mut prvs_rw_tl = cols(&prvs_rw, 1, dp_mt_dms[1] - 1).copy();
replace(&mut prvs_rw_tl, &eq(&cols(&row(&src_mt, i - 1), 1, dp_mt_dms[1] - 1), &VRTCL, false), &add(&cols(&prvs_rw, 1, dp_mt_dms[1] - 1), &gp_opn_pnlty, false));
let prvs_rw_tl = add(&prvs_rw_tl, &gp_extnsn_pnlty, false);
let prvs_rw_hd = add(&cols(&prvs_rw, 0, dp_mt_dms[1] - 2), &cols(&row(&scr_mt, i - 1), 0, dp_mt_dms[1] - 2), false);
let mut nw_rw_tl = constant(DGNL, rw_tl_dms);
replace(&mut nw_rw_tl, &ge(&prvs_rw_hd, &prvs_rw_tl, false), &constant(VRTCL, rw_tl_dms));
src_mt = join(0, &src_mt, &join(1, &constant(DGNL, cl_dms), &nw_rw_tl));
dp_mt = set_row(&dp_mt, &join(1, &col(&row(&dp_mt, i), 0), &maxof(&maxof(&prvs_rw_tl, &prvs_rw_hd, false), &constant(0. as AlgnScr, rw_tl_dms), false)), i);
}
let clmn_tl_dms = Dim4::new(&[scr_mt_dms[0], 1, 1, 1]);
for i in 1 .. dp_mt_dms[1] {
let prvs_clmn = col(&dp_mt, i - 1);
let mut prvs_clmn_tl = rows(&prvs_clmn, 1, dp_mt_dms[0] - 1).copy();
replace(&mut prvs_clmn_tl, &eq(&rows(&col(&src_mt, i - 1), 1, dp_mt_dms[0] - 1), &HRZNTL, false), &add(&rows(&prvs_clmn, 1, dp_mt_dms[0] - 1), &gp_opn_pnlty, false));
let prvs_clmn_tl = add(&prvs_clmn_tl, &gp_extnsn_pnlty, false);
let prvs_clmn_hd = add(&rows(&prvs_clmn, 0, dp_mt_dms[0] - 2), &rows(&col(&scr_mt, i - 1), 0, dp_mt_dms[0] - 2), false);
let crnt_clmn_tl = rows(&col(&dp_mt, i), 1, dp_mt_dms[0] - 1);
let mx = maxof(&prvs_clmn_hd, &crnt_clmn_tl, false);
let mut src_clmn_tl = rows(&col(&src_mt, i), 1, dp_mt_dms[0] - 1).copy();
replace(&mut src_clmn_tl, >(&crnt_clmn_tl, &prvs_clmn_hd, false), &constant(DGNL, clmn_tl_dms));
replace(&mut src_clmn_tl, &ge(&mx, &prvs_clmn_tl, false), &constant(HRZNTL, clmn_tl_dms));
src_mt = set_col(&src_mt, &join(0, &row(&col(&src_mt, i), 0), &src_clmn_tl), i);
dp_mt = set_col(&dp_mt, &join(0, &row(&col(&dp_mt, i), 0), &maxof(&prvs_clmn_tl, &mx, false)), i);
}
let dp_mt_elm_nm = dp_mt.elements() as usize;
let mut cpu_dp_mt = vec![0. as AlgnScr; dp_mt_elm_nm];
let mut cpu_src_mt = vec![DGNL; dp_mt_elm_nm];
dp_mt.host(&mut cpu_dp_mt);
src_mt.host(&mut cpu_src_mt);
let mx_scr = max_all(&dp_mt).0 as AlgnScr;
let mut dp_mt = vec![vec![0. as AlgnScr; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
let mut src_mt = vec![vec![DGNL; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
let mut fnd_ps_pr = (0, 0);
let mut is_ps_pr_fnd = false;
for (i, (&dp_mt_elm, &src)) in cpu_dp_mt.iter().zip(cpu_src_mt.iter()).enumerate() {
let ps_pr = (i % (dp_mt_dms[0] as usize), i / (dp_mt_dms[0] as usize));
dp_mt[ps_pr.0][ps_pr.1] = dp_mt_elm;
src_mt[ps_pr.0][ps_pr.1] = src;
if !is_ps_pr_fnd && dp_mt[ps_pr.0][ps_pr.1] == mx_scr {
fnd_ps_pr = ps_pr;
is_ps_pr_fnd = true;
}
}
let mut pr_algn = (Vec::new(), Vec::new());
let (mut i, mut j) = fnd_ps_pr;
let (mut prvs_i, mut prvs_j) = (i, j);
while i > 0 || j > 0 {
if dp_mt[i][j] == 0. {
break;
}
prvs_i = i;
prvs_j = j;
if j == 0 {
pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
pr_algn.1.insert(0, GP);
i -= 1;
continue;
} else if i == 0 {
pr_algn.0.insert(0, GP);
pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
j -= 1;
continue;
}
let src = src_mt[i][j];
if src == DGNL {
pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
i -= 1;
j -= 1;
} else if src == VRTCL {
pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
pr_algn.1.insert(0, GP);
i -= 1;
} else {
pr_algn.0.insert(0, GP);
pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
j -= 1;
}
}
let pr_algn = PrAlgn {
algn: pr_algn,
algn_ps_pr_pr: (AlgnPsPr {strt_ps: if prvs_i == 0 {0} else {prvs_i - 1}, end_ps: fnd_ps_pr.0 - 1}, AlgnPsPr {strt_ps: if prvs_j == 0 {0} else {prvs_j - 1}, end_ps: fnd_ps_pr.1 - 1}),
algn_scr: mx_scr,
};
pr_algn
}
pub fn gt_alphbt<'a>(is_dna: bool) -> Alphbt<'a> {
if is_dna {
b"ACGTURYSWKMBDHVNacgturyswkmbdhvn"
} else {
b"ARNDCEQGHILKMFPSTWYVarndceqghilkmfpstwyv"
}
}
/// Maps a sequence element to its index in `alphbt`. The index is used as a row/column
/// of the substitution matrix.
///
/// # Panics
/// Panics with "Failed to get hashed bio elem." when `b_elm` does not occur in `alphbt`.
fn gt_hsh_b_elm(b_elm: BElm, alphbt: Alphbt) -> BElm {
    for (idx, &alphbt_elm) in alphbt.iter().enumerate() {
        if alphbt_elm == b_elm {
            return idx as BElm;
        }
    }
    panic!("Failed to get hashed bio elem.")
}
#[cfg(test)]
mod tsts {
    use super::{gpu_sw, gt_alphbt, SbstMt, AlgnScrSchm, BElm, GP};
    use super::arrayfire::{set_device, info};
    use super::std::str::from_utf8;
    /// Returns `true` when `ndl` occurs as a contiguous run inside `hystck`.
    fn is_sb_sq(ndl: &[BElm], hystck: &[BElm]) -> bool {
        ndl.is_empty() || hystck.windows(ndl.len()).any(|wndw| wndw == ndl)
    }
    /// Strips the gap symbols from one row of an alignment.
    fn dgp(algn_rw: &[BElm]) -> Vec<BElm> {
        algn_rw.iter().cloned().filter(|&b_elm| b_elm != GP).collect()
    }
    /// Aligns a small DNA pair on device 0, prints the result, and checks the structural
    /// invariants every local alignment must satisfy.
    #[test]
    fn tst_gpu_sw() {
        set_device(0);
        info();
        let is_dna = true;
        let alphbt = gt_alphbt(is_dna);
        let b_sq_pr = (&b"GGTTGACTA"[..], &b"TGTTACGG"[..]);
        println!("Seq. pair to align:");
        println!("{}", from_utf8(&b_sq_pr.0).expect("Failed to get Bio seq."));
        println!("{}", from_utf8(&b_sq_pr.1).expect("Failed to get Bio seq."));
        let mut sbst_mt = SbstMt::default();
        for &alphbt_elm_1 in alphbt.iter() {
            for &alphbt_elm_2 in alphbt.iter() {
                sbst_mt.insert((alphbt_elm_1, alphbt_elm_2), if alphbt_elm_1 == alphbt_elm_2 {3.} else {-3.});
            }
        }
        let algn_scr_schm = AlgnScrSchm::new(-0., -2.);
        let pr_algn = gpu_sw(&b_sq_pr, &sbst_mt, &algn_scr_schm, is_dna);
        println!("{:?}", &pr_algn);
        println!("Alignment:");
        println!("{}", from_utf8(&pr_algn.algn.0).expect("Failed to get alignment."));
        println!("{}", from_utf8(&pr_algn.algn.1).expect("Failed to get alignment."));
        let algn_0 = &pr_algn.algn.0;
        let algn_1 = &pr_algn.algn.1;
        // Both rows of an alignment span the same number of columns.
        assert_eq!(algn_0.len(), algn_1.len());
        // The pair shares identical residues (+3 each), so the best local score is
        // positive and the alignment is non-empty.
        assert!(pr_algn.algn_scr > 0.);
        assert!(!algn_0.is_empty());
        // No column may be a gap in both rows.
        assert!(algn_0.iter().zip(algn_1.iter()).all(|(&b_elm_0, &b_elm_1)| b_elm_0 != GP || b_elm_1 != GP));
        // With gaps removed, each row must be a contiguous piece of its input sequence.
        assert!(is_sb_sq(&dgp(algn_0), b_sq_pr.0));
        assert!(is_sb_sq(&dgp(algn_1), b_sq_pr.1));
    }
}