gpu_sw/
lb.rs

1extern crate arrayfire;
2extern crate fnv;
3
4use arrayfire::{Array, Dim4, constant, add, col, row, cols, rows, join, lookup, set_col, set_row, maxof, replace, gt, ge, eq, max_all};
5use std::hash::BuildHasherDefault;
6use std::collections::HashMap;
7use fnv::FnvHasher;
8
/// One residue symbol of a biological sequence, stored as a raw byte.
9pub type BElm = u8;
/// Owned sequence of residue bytes (also used for aligned rows, which may contain `GP`).
10type BSq = Vec<BElm>;
/// Position type used for the start/end fields of `AlgnPsPr`.
11type AlgnPs = usize;
12#[derive(Debug)]
13struct AlgnPsPr {
14  strt_ps: AlgnPs,
15  end_ps: AlgnPs,
16}
17type AlgnScr = f32;
18#[derive(Debug)]
19pub struct PrAlgn {
20  algn: (BSq, BSq),
21  algn_ps_pr_pr: (AlgnPsPr, AlgnPsPr),
22  algn_scr: AlgnScr,
23}
/// Hasher builder (FNV) plugged into `SbstMt`.
24type Hshr = BuildHasherDefault<FnvHasher>;
/// Substitution matrix: maps a pair of residue bytes to its score.
/// Build one with `SbstMt::default()` and `insert`, as in the `gpu_sw` example.
25pub type SbstMt = HashMap<(BElm, BElm), AlgnScr, Hshr>;
26pub struct AlgnScrSchm {
27  gp_opn_pnlty: AlgnScr,
28  gp_extnsn_pnlty: AlgnScr,
29}
30impl AlgnScrSchm {
31  pub fn new(gp_opn_pnlty: AlgnScr, gp_extnsn_pnlty: AlgnScr) -> AlgnScrSchm {
32    AlgnScrSchm {
33      gp_opn_pnlty: gp_opn_pnlty,
34      gp_extnsn_pnlty: gp_extnsn_pnlty,
35    }
36  }
37}
/// Integer type that sequence lengths are cast to before being passed to `Dim4::new`.
38type AfDm = u64;
/// Borrowed residue alphabet; a residue's index in this slice is the value
/// `gt_hsh_b_elm` maps it to.
39pub type Alphbt<'a> = &'a[BElm];
/// Type of the trace-back codes (`DGNL`, `VRTCL`, `HRZNTL`).
40type DpSrc = u32;
41
42const GP: BElm = '-' as BElm;
43const DGNL: DpSrc = 0;
44const VRTCL: DpSrc = DGNL + 1;
45const HRZNTL: DpSrc = VRTCL + 1;
46
47/// Run the Smith Waterman algorithm on GPU.
48/// # Examples
49///
50/// ```rust
51/// use self::gpu_sw::{gpu_sw, gt_alphbt, SbstMt, AlgnScrSchm};
52///
53/// let is_dna = true;
54/// let alphbt = gt_alphbt(is_dna);
55/// let b_sq_pr = (&b"GGTTGACTA"[..], &b"TGTTACGG"[..]);
56/// let mut sbst_mt = SbstMt::default();
57/// for &alphbt_elm_1 in alphbt.iter() {
58///   for &alphbt_elm_2 in alphbt.iter() {
59///     sbst_mt.insert((alphbt_elm_1, alphbt_elm_2), if alphbt_elm_1 == alphbt_elm_2 {1.} else {-1.});
60///   }
61/// }
62/// let algn_scr_schm = AlgnScrSchm::new(-7., -1.);
63/// let pr_algn = gpu_sw(&b_sq_pr, &sbst_mt, &algn_scr_schm, is_dna);
64/// println!("{:?}.", &pr_algn);
65/// ```
66///
67pub fn gpu_sw(b_sq_pr: &(&[BElm], &[BElm]), sbst_mt: &SbstMt, algn_scr_schm: &AlgnScrSchm, is_dna: bool) -> PrAlgn {
68  let b_sq_ln_pr = (b_sq_pr.0.len(), b_sq_pr.1.len());
69  let alphbt = gt_alphbt(is_dna);
70  let gpu_b_sq_pr = (
71    Array::new(&b_sq_pr.0.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.0 as AfDm, 1, 1, 1])),
72    Array::new(&b_sq_pr.1.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.1 as AfDm, 1, 1, 1])),
73  );
74  let alphbt_ln = alphbt.len();
75  let mut hsh_sbst_mt = vec![vec![0.; alphbt_ln]; alphbt_ln];
76  for (b_elm_pr, &sbst_scr) in sbst_mt {
77    hsh_sbst_mt[gt_hsh_b_elm(b_elm_pr.0, alphbt) as usize][gt_hsh_b_elm(b_elm_pr.1, alphbt) as usize] = sbst_scr;
78  }
79  let hsh_sbst_mt = hsh_sbst_mt.iter().flat_map(|sbst_scr| sbst_scr.clone()).collect::<Vec<AlgnScr>>();
80  let gpu_sbst_mt = Array::new(&hsh_sbst_mt, Dim4::new(&[alphbt_ln as AfDm, alphbt_ln as AfDm, 1, 1]));
81  let scr_mt = lookup(&lookup(&gpu_sbst_mt, &gpu_b_sq_pr.0, 0), &gpu_b_sq_pr.1, 1);
82  let scr_mt_dms = scr_mt.dims();
83  let mut dp_mt = constant(0. as AlgnScr, Dim4::new(&[scr_mt_dms[0] + 1, scr_mt_dms[1] + 1, 1, 1]));
84  let (gp_opn_pnlty, gp_extnsn_pnlty) = (algn_scr_schm.gp_opn_pnlty, algn_scr_schm.gp_extnsn_pnlty);
85  let dp_mt_dms = dp_mt.dims();
86  let mut src_mt = constant(DGNL, Dim4::new(&[1, dp_mt_dms[1], 1, 1]));
87  let rw_tl_dms = Dim4::new(&[1, scr_mt_dms[1], 1, 1]);
88  let cl_dms = Dim4::new(&[1, 1, 1, 1]);
89  for i in 1 .. dp_mt_dms[0] {
90    let prvs_rw = row(&dp_mt, i - 1);
91    let mut prvs_rw_tl = cols(&prvs_rw, 1, dp_mt_dms[1] - 1).copy();
92    replace(&mut prvs_rw_tl, &eq(&cols(&row(&src_mt, i - 1), 1, dp_mt_dms[1] - 1), &VRTCL, false), &add(&cols(&prvs_rw, 1, dp_mt_dms[1] - 1), &gp_opn_pnlty, false));
93    let prvs_rw_tl = add(&prvs_rw_tl, &gp_extnsn_pnlty, false);
94    let prvs_rw_hd = add(&cols(&prvs_rw, 0, dp_mt_dms[1] - 2), &cols(&row(&scr_mt, i - 1), 0, dp_mt_dms[1] - 2), false);
95    let mut nw_rw_tl = constant(DGNL, rw_tl_dms);
96    replace(&mut nw_rw_tl, &ge(&prvs_rw_hd, &prvs_rw_tl, false), &constant(VRTCL, rw_tl_dms));
97    src_mt = join(0, &src_mt, &join(1, &constant(DGNL, cl_dms), &nw_rw_tl));
98    dp_mt = set_row(&dp_mt, &join(1, &col(&row(&dp_mt, i), 0), &maxof(&maxof(&prvs_rw_tl, &prvs_rw_hd, false), &constant(0. as AlgnScr, rw_tl_dms), false)), i);
99  }
100  let clmn_tl_dms = Dim4::new(&[scr_mt_dms[0], 1, 1, 1]);
101  for i in 1 .. dp_mt_dms[1] {
102    let prvs_clmn = col(&dp_mt, i - 1);
103    let mut prvs_clmn_tl = rows(&prvs_clmn, 1, dp_mt_dms[0] - 1).copy();
104    replace(&mut prvs_clmn_tl, &eq(&rows(&col(&src_mt, i - 1), 1, dp_mt_dms[0] - 1), &HRZNTL, false), &add(&rows(&prvs_clmn, 1, dp_mt_dms[0] - 1), &gp_opn_pnlty, false));
105    let prvs_clmn_tl = add(&prvs_clmn_tl, &gp_extnsn_pnlty, false);
106    let prvs_clmn_hd = add(&rows(&prvs_clmn, 0, dp_mt_dms[0] - 2), &rows(&col(&scr_mt, i - 1), 0, dp_mt_dms[0] - 2), false);
107    let crnt_clmn_tl = rows(&col(&dp_mt, i), 1, dp_mt_dms[0] - 1);
108    let mx = maxof(&prvs_clmn_hd, &crnt_clmn_tl, false);
109    let mut src_clmn_tl = rows(&col(&src_mt, i), 1, dp_mt_dms[0] - 1).copy();
110    replace(&mut src_clmn_tl, &gt(&crnt_clmn_tl, &prvs_clmn_hd, false), &constant(DGNL, clmn_tl_dms));
111    replace(&mut src_clmn_tl, &ge(&mx, &prvs_clmn_tl, false), &constant(HRZNTL, clmn_tl_dms));
112    src_mt = set_col(&src_mt, &join(0, &row(&col(&src_mt, i), 0), &src_clmn_tl), i);
113    dp_mt = set_col(&dp_mt, &join(0, &row(&col(&dp_mt, i), 0), &maxof(&prvs_clmn_tl, &mx, false)), i);
114  }
115  let dp_mt_elm_nm = dp_mt.elements() as usize;
116  let mut cpu_dp_mt = vec![0. as AlgnScr; dp_mt_elm_nm];
117  let mut cpu_src_mt = vec![DGNL; dp_mt_elm_nm];
118  dp_mt.host(&mut cpu_dp_mt);
119  src_mt.host(&mut cpu_src_mt);
120  let mx_scr = max_all(&dp_mt).0 as AlgnScr;
121  let mut dp_mt = vec![vec![0. as AlgnScr; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
122  let mut src_mt = vec![vec![DGNL; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
123  let mut fnd_ps_pr = (0, 0);
124  let mut is_ps_pr_fnd = false;
125  for (i, (&dp_mt_elm, &src)) in cpu_dp_mt.iter().zip(cpu_src_mt.iter()).enumerate() {
126    let ps_pr = (i % (dp_mt_dms[0] as usize), i / (dp_mt_dms[0] as usize));
127    dp_mt[ps_pr.0][ps_pr.1] = dp_mt_elm;
128    src_mt[ps_pr.0][ps_pr.1] = src;
129    if !is_ps_pr_fnd && dp_mt[ps_pr.0][ps_pr.1] == mx_scr {
130      fnd_ps_pr = ps_pr;
131      is_ps_pr_fnd = true;
132    }
133  }
134  let mut pr_algn = (Vec::new(), Vec::new());
135  let (mut i, mut j) = fnd_ps_pr;
136  let (mut prvs_i, mut prvs_j) = (i, j);
137  while i > 0 || j > 0 {
138    if dp_mt[i][j] == 0. {
139      break;
140    }
141    prvs_i = i;
142    prvs_j = j;
143    if j == 0 {
144      pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
145      pr_algn.1.insert(0, GP);
146      i -= 1;
147      continue;
148    } else if i == 0 {
149      pr_algn.0.insert(0, GP);
150      pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
151      j -= 1;
152      continue;
153    }
154    let src = src_mt[i][j];
155    if src == DGNL {
156      pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
157      pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
158      i -= 1;
159      j -= 1;
160    } else if src == VRTCL {
161      pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
162      pr_algn.1.insert(0, GP);
163      i -= 1;
164    } else {
165      pr_algn.0.insert(0, GP);
166      pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
167      j -= 1;
168    }
169  }
170  let pr_algn = PrAlgn {
171    algn: pr_algn,
172    algn_ps_pr_pr: (AlgnPsPr {strt_ps: if prvs_i == 0 {0} else {prvs_i - 1}, end_ps: fnd_ps_pr.0 - 1}, AlgnPsPr {strt_ps: if prvs_j == 0 {0} else {prvs_j - 1}, end_ps: fnd_ps_pr.1 - 1}),
173    algn_scr: mx_scr,
174  };
175  pr_algn
176}
177
178pub fn gt_alphbt<'a>(is_dna: bool) -> Alphbt<'a> {
179  if is_dna {
180    b"ACGTURYSWKMBDHVNacgturyswkmbdhvn"
181  } else {
182    b"ARNDCEQGHILKMFPSTWYVarndceqghilkmfpstwyv"
183  }
184}
185
186fn gt_hsh_b_elm(b_elm: BElm, alphbt: Alphbt) -> BElm {
187  alphbt.iter().position(|&alphbt_elm| alphbt_elm == b_elm).expect("Failed to get hashed bio elem.") as BElm
188}
189
#[cfg(test)]
mod tsts {
  use super::{gpu_sw, gt_alphbt, SbstMt, AlgnScrSchm};
  use super::arrayfire::{set_device, info};
  use super::std::str::from_utf8;

  /// Smoke test: aligns two short DNA sequences on arrayfire device 0,
  /// prints the result and checks basic invariants of the returned alignment.
  #[test]
  fn tst_gpu_sw() {
    set_device(0);
    info();
    let is_dna = true;
    let alphbt = gt_alphbt(is_dna);
    let b_sq_pr = (&b"GGTTGACTA"[..], &b"TGTTACGG"[..]);
    println!("Seq. pair to align:");
    println!("{}", from_utf8(&b_sq_pr.0).expect("Failed to get Bio seq."));
    println!("{}", from_utf8(&b_sq_pr.1).expect("Failed to get Bio seq."));
    // Match scores 3, mismatch scores -3.
    let mut sbst_mt = SbstMt::default();
    for &alphbt_elm_1 in alphbt.iter() {
      for &alphbt_elm_2 in alphbt.iter() {
        sbst_mt.insert((alphbt_elm_1, alphbt_elm_2), if alphbt_elm_1 == alphbt_elm_2 {3.} else {-3.});
      }
    }
    let algn_scr_schm = AlgnScrSchm::new(-0., -2.);
    let pr_algn = gpu_sw(&b_sq_pr, &sbst_mt, &algn_scr_schm, is_dna);
    println!("{:?}", &pr_algn);
    println!("Alignment:");
    println!("{}", from_utf8(&pr_algn.algn.0).expect("Failed to get alignment."));
    println!("{}", from_utf8(&pr_algn.algn.1).expect("Failed to get alignment."));
    // Every trace-back step emits one symbol per row, so both rows must have
    // the same length.
    assert_eq!(pr_algn.algn.0.len(), pr_algn.algn.1.len());
    // The DP matrix keeps a zero border, so the best score is never negative.
    assert!(pr_algn.algn_scr >= 0.);
  }
}