1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
extern crate arrayfire;
extern crate fnv;

use arrayfire::{Array, Dim4, constant, add, col, row, cols, rows, join, lookup, set_col, set_row, maxof, replace, gt, ge, eq, max_all};
use std::hash::BuildHasherDefault;
use std::collections::HashMap;
use fnv::FnvHasher;

/// A single sequence element (nucleotide or amino-acid code) as a byte.
pub type BElm = u8;
/// An owned sequence; aligned sequences may also contain the gap symbol `-`.
type BSq = Vec<BElm>;
/// A 0-based position within one of the input sequences.
type AlgnPs = usize;
/// Inclusive, 0-based start and end positions of the aligned region within
/// one input sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlgnPsPr {
  /// Position of the first aligned element.
  pub strt_ps: AlgnPs,
  /// Position of the last aligned element.
  pub end_ps: AlgnPs,
}
/// Score type used for substitution scores, gap penalties and alignment scores.
type AlgnScr = f32;
/// Result of a pairwise local alignment.
///
/// Fields are public so callers outside this crate can read the result
/// returned by `gpu_sw`.
#[derive(Debug, Clone, PartialEq)]
pub struct PrAlgn {
  /// The two aligned sequences, with `-` marking gaps; both have equal length.
  pub algn: (BSq, BSq),
  /// Aligned region within the first and second input sequence respectively.
  pub algn_ps_pr_pr: (AlgnPsPr, AlgnPsPr),
  /// Score of the alignment.
  pub algn_scr: AlgnScr,
}
/// FNV-based hasher (from the `fnv` crate) used for the substitution matrix.
type Hshr = BuildHasherDefault<FnvHasher>;
/// Substitution matrix: the score assigned to each pair of sequence elements.
pub type SbstMt = HashMap<
  (BElm, BElm),
  AlgnScr,
  Hshr,
>;
/// Gap scoring parameters used by [`gpu_sw`].
///
/// Both values are added to the running alignment score, so penalties are
/// normally given as negative numbers (e.g. `AlgnScrSchm::new(-7., -1.)`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AlgnScrSchm {
  /// Added when a gap is opened, i.e. when the preceding cell was not itself
  /// reached by a gap in the same direction.
  gp_opn_pnlty: AlgnScr,
  /// Added for every gap position, including the one that opens the gap.
  gp_extnsn_pnlty: AlgnScr,
}
impl AlgnScrSchm {
  /// Create a scheme from a gap-opening and a gap-extension penalty.
  pub fn new(gp_opn_pnlty: AlgnScr, gp_extnsn_pnlty: AlgnScr) -> AlgnScrSchm {
    AlgnScrSchm { gp_opn_pnlty, gp_extnsn_pnlty }
  }
}
/// An alphabet: the ordered list of sequence elements accepted as input.
pub type Alphbt<'a> = &'a [BElm];
/// Traceback direction code stored in the DP source matrix.
type DpSrc = u32;
/// Element type of the dimension arrays handed to ArrayFire's `Dim4::new`.
type AfDm = u64;

/// Gap symbol inserted into aligned sequences.
const GP: BElm = b'-';
// Traceback direction codes (values of the DP source matrix).
/// Cell was reached from the upper-left cell (both elements aligned).
const DGNL: DpSrc = 0;
/// Cell was reached from the cell above (gap in the second sequence).
const VRTCL: DpSrc = 1;
/// Cell was reached from the cell to the left (gap in the first sequence).
const HRZNTL: DpSrc = 2;

/// Run the Smith Waterman algorithm on GPU.
/// # Examples
///
/// ```rust
/// use self::gpu_sw::{gpu_sw, gt_alphbt, SbstMt, AlgnScrSchm};
///
/// let is_dna = true;
/// let alphbt = gt_alphbt(is_dna);
/// let b_sq_pr = (&b"GGTTGACTA"[..], &b"TGTTACGG"[..]);
/// let mut sbst_mt = SbstMt::default();
/// for &alphbt_elm_1 in alphbt.iter() {
///   for &alphbt_elm_2 in alphbt.iter() {
///     sbst_mt.insert((alphbt_elm_1, alphbt_elm_2), if alphbt_elm_1 == alphbt_elm_2 {1.} else {-1.});
///   }
/// }
/// let algn_scr_schm = AlgnScrSchm::new(-7., -1.);
/// let pr_algn = gpu_sw(&b_sq_pr, &sbst_mt, &algn_scr_schm, is_dna);
/// println!("{:?}.", &pr_algn);
/// ```
///
pub fn gpu_sw(b_sq_pr: &(&[BElm], &[BElm]), sbst_mt: &SbstMt, algn_scr_schm: &AlgnScrSchm, is_dna: bool) -> PrAlgn {
  let b_sq_ln_pr = (b_sq_pr.0.len(), b_sq_pr.1.len());
  let alphbt = gt_alphbt(is_dna);
  let gpu_b_sq_pr = (
    Array::new(&b_sq_pr.0.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.0 as AfDm, 1, 1, 1])),
    Array::new(&b_sq_pr.1.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.1 as AfDm, 1, 1, 1])),
  );
  let alphbt_ln = alphbt.len();
  let mut hsh_sbst_mt = vec![vec![0.; alphbt_ln]; alphbt_ln];
  for (b_elm_pr, &sbst_scr) in sbst_mt {
    hsh_sbst_mt[gt_hsh_b_elm(b_elm_pr.0, alphbt) as usize][gt_hsh_b_elm(b_elm_pr.1, alphbt) as usize] = sbst_scr;
  }
  let hsh_sbst_mt = hsh_sbst_mt.iter().flat_map(|sbst_scr| sbst_scr.clone()).collect::<Vec<AlgnScr>>();
  let gpu_sbst_mt = Array::new(&hsh_sbst_mt, Dim4::new(&[alphbt_ln as AfDm, alphbt_ln as AfDm, 1, 1]));
  let scr_mt = lookup(&lookup(&gpu_sbst_mt, &gpu_b_sq_pr.0, 0), &gpu_b_sq_pr.1, 1);
  let scr_mt_dms = scr_mt.dims();
  let mut dp_mt = constant(0. as AlgnScr, Dim4::new(&[scr_mt_dms[0] + 1, scr_mt_dms[1] + 1, 1, 1]));
  let (gp_opn_pnlty, gp_extnsn_pnlty) = (algn_scr_schm.gp_opn_pnlty, algn_scr_schm.gp_extnsn_pnlty);
  let dp_mt_dms = dp_mt.dims();
  let mut src_mt = constant(DGNL, Dim4::new(&[1, dp_mt_dms[1], 1, 1]));
  let rw_tl_dms = Dim4::new(&[1, scr_mt_dms[1], 1, 1]);
  let cl_dms = Dim4::new(&[1, 1, 1, 1]);
  for i in 1 .. dp_mt_dms[0] {
    let prvs_rw = row(&dp_mt, i - 1);
    let mut prvs_rw_tl = cols(&prvs_rw, 1, dp_mt_dms[1] - 1).copy();
    replace(&mut prvs_rw_tl, &eq(&cols(&row(&src_mt, i - 1), 1, dp_mt_dms[1] - 1), &VRTCL, false), &add(&cols(&prvs_rw, 1, dp_mt_dms[1] - 1), &gp_opn_pnlty, false));
    let prvs_rw_tl = add(&prvs_rw_tl, &gp_extnsn_pnlty, false);
    let prvs_rw_hd = add(&cols(&prvs_rw, 0, dp_mt_dms[1] - 2), &cols(&row(&scr_mt, i - 1), 0, dp_mt_dms[1] - 2), false);
    let mut nw_rw_tl = constant(DGNL, rw_tl_dms);
    replace(&mut nw_rw_tl, &ge(&prvs_rw_hd, &prvs_rw_tl, false), &constant(VRTCL, rw_tl_dms));
    src_mt = join(0, &src_mt, &join(1, &constant(DGNL, cl_dms), &nw_rw_tl));
    dp_mt = set_row(&dp_mt, &join(1, &col(&row(&dp_mt, i), 0), &maxof(&maxof(&prvs_rw_tl, &prvs_rw_hd, false), &constant(0. as AlgnScr, rw_tl_dms), false)), i);
  }
  let clmn_tl_dms = Dim4::new(&[scr_mt_dms[0], 1, 1, 1]);
  for i in 1 .. dp_mt_dms[1] {
    let prvs_clmn = col(&dp_mt, i - 1);
    let mut prvs_clmn_tl = rows(&prvs_clmn, 1, dp_mt_dms[0] - 1).copy();
    replace(&mut prvs_clmn_tl, &eq(&rows(&col(&src_mt, i - 1), 1, dp_mt_dms[0] - 1), &HRZNTL, false), &add(&rows(&prvs_clmn, 1, dp_mt_dms[0] - 1), &gp_opn_pnlty, false));
    let prvs_clmn_tl = add(&prvs_clmn_tl, &gp_extnsn_pnlty, false);
    let prvs_clmn_hd = add(&rows(&prvs_clmn, 0, dp_mt_dms[0] - 2), &rows(&col(&scr_mt, i - 1), 0, dp_mt_dms[0] - 2), false);
    let crnt_clmn_tl = rows(&col(&dp_mt, i), 1, dp_mt_dms[0] - 1);
    let mx = maxof(&prvs_clmn_hd, &crnt_clmn_tl, false);
    let mut src_clmn_tl = rows(&col(&src_mt, i), 1, dp_mt_dms[0] - 1).copy();
    replace(&mut src_clmn_tl, &gt(&crnt_clmn_tl, &prvs_clmn_hd, false), &constant(DGNL, clmn_tl_dms));
    replace(&mut src_clmn_tl, &ge(&mx, &prvs_clmn_tl, false), &constant(HRZNTL, clmn_tl_dms));
    src_mt = set_col(&src_mt, &join(0, &row(&col(&src_mt, i), 0), &src_clmn_tl), i);
    dp_mt = set_col(&dp_mt, &join(0, &row(&col(&dp_mt, i), 0), &maxof(&prvs_clmn_tl, &mx, false)), i);
  }
  let dp_mt_elm_nm = dp_mt.elements() as usize;
  let mut cpu_dp_mt = vec![0. as AlgnScr; dp_mt_elm_nm];
  let mut cpu_src_mt = vec![DGNL; dp_mt_elm_nm];
  dp_mt.host(&mut cpu_dp_mt);
  src_mt.host(&mut cpu_src_mt);
  let mx_scr = max_all(&dp_mt).0 as AlgnScr;
  let mut dp_mt = vec![vec![0. as AlgnScr; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
  let mut src_mt = vec![vec![DGNL; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
  let mut fnd_ps_pr = (0, 0);
  let mut is_ps_pr_fnd = false;
  for (i, (&dp_mt_elm, &src)) in cpu_dp_mt.iter().zip(cpu_src_mt.iter()).enumerate() {
    let ps_pr = (i % (dp_mt_dms[0] as usize), i / (dp_mt_dms[0] as usize));
    dp_mt[ps_pr.0][ps_pr.1] = dp_mt_elm;
    src_mt[ps_pr.0][ps_pr.1] = src;
    if !is_ps_pr_fnd && dp_mt[ps_pr.0][ps_pr.1] == mx_scr {
      fnd_ps_pr = ps_pr;
      is_ps_pr_fnd = true;
    }
  }
  let mut pr_algn = (Vec::new(), Vec::new());
  let (mut i, mut j) = fnd_ps_pr;
  let (mut prvs_i, mut prvs_j) = (i, j);
  while i > 0 || j > 0 {
    if dp_mt[i][j] == 0. {
      break;
    }
    prvs_i = i;
    prvs_j = j;
    if j == 0 {
      pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
      pr_algn.1.insert(0, GP);
      i -= 1;
      continue;
    } else if i == 0 {
      pr_algn.0.insert(0, GP);
      pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
      j -= 1;
      continue;
    }
    let src = src_mt[i][j];
    if src == DGNL {
      pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
      pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
      i -= 1;
      j -= 1;
    } else if src == VRTCL {
      pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
      pr_algn.1.insert(0, GP);
      i -= 1;
    } else {
      pr_algn.0.insert(0, GP);
      pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
      j -= 1;
    }
  }
  let pr_algn = PrAlgn {
    algn: pr_algn,
    algn_ps_pr_pr: (AlgnPsPr {strt_ps: if prvs_i == 0 {0} else {prvs_i - 1}, end_ps: fnd_ps_pr.0 - 1}, AlgnPsPr {strt_ps: if prvs_j == 0 {0} else {prvs_j - 1}, end_ps: fnd_ps_pr.1 - 1}),
    algn_scr: mx_scr,
  };
  pr_algn
}

/// Return the alphabet of accepted sequence elements.
///
/// With `is_dna` this is the IUPAC nucleotide codes (including `U` and the
/// ambiguity codes); otherwise it is the 20 standard amino-acid codes. Each
/// alphabet lists the upper-case letters first, then the same letters in lower
/// case. The gap symbol `-` is not included. Order matters: an element's
/// position in this slice is its index in the substitution matrix.
pub fn gt_alphbt<'a>(is_dna: bool) -> Alphbt<'a> {
  match is_dna {
    true => b"ACGTURYSWKMBDHVNacgturyswkmbdhvn",
    false => b"ARNDCEQGHILKMFPSTWYVarndceqghilkmfpstwyv",
  }
}

/// Encode a sequence element as its index within `alphbt` (first occurrence).
///
/// # Panics
///
/// Panics with "Failed to get hashed bio elem." if `b_elm` does not occur in
/// `alphbt`.
fn gt_hsh_b_elm(b_elm: BElm, alphbt: Alphbt) -> BElm {
  let (hsh, _) = alphbt
    .iter()
    .enumerate()
    .find(|&(_, &alphbt_elm)| alphbt_elm == b_elm)
    .expect("Failed to get hashed bio elem.");
  // Alphabets returned by `gt_alphbt` have at most 40 entries, so this fits a byte.
  hsh as BElm
}

#[cfg(test)]
mod tsts {
  use super::{gpu_sw, gt_alphbt, SbstMt, AlgnScrSchm};
  use super::arrayfire::{set_device, info};
  use super::std::str::from_utf8;

  /// Align a small DNA pair with +3/-3 substitution scores and a gap cost of 2
  /// per position (opening penalty 0), and check the resulting local alignment.
  #[test]
  fn tst_gpu_sw() {
    set_device(0);
    info();
    let is_dna = true;
    let alphbt = gt_alphbt(is_dna);
    let b_sq_pr = (&b"GGTTGACTA"[..], &b"TGTTACGG"[..]);
    println!("Seq. pair to align:");
    println!("{}", from_utf8(&b_sq_pr.0).expect("Failed to get Bio seq."));
    println!("{}", from_utf8(&b_sq_pr.1).expect("Failed to get Bio seq."));
    let mut sbst_mt = SbstMt::default();
    for &alphbt_elm_1 in alphbt.iter() {
      for &alphbt_elm_2 in alphbt.iter() {
        sbst_mt.insert((alphbt_elm_1, alphbt_elm_2), if alphbt_elm_1 == alphbt_elm_2 {3.} else {-3.});
      }
    }
    let algn_scr_schm = AlgnScrSchm::new(-0., -2.);
    let pr_algn = gpu_sw(&b_sq_pr, &sbst_mt, &algn_scr_schm, is_dna);
    println!("{:?}", &pr_algn);
    println!("Alignment:");
    println!("{}", from_utf8(&pr_algn.algn.0).expect("Failed to get alignment."));
    println!("{}", from_utf8(&pr_algn.algn.1).expect("Failed to get alignment."));
    // GTTGAC / GTT-AC: 5 matches (+15) and one gap position (-2).
    assert_eq!(pr_algn.algn_scr, 13.);
    assert_eq!(&pr_algn.algn.0[..], &b"GTTGAC"[..]);
    assert_eq!(&pr_algn.algn.1[..], &b"GTT-AC"[..]);
    assert_eq!((pr_algn.algn_ps_pr_pr.0.strt_ps, pr_algn.algn_ps_pr_pr.0.end_ps), (1, 6));
    assert_eq!((pr_algn.algn_ps_pr_pr.1.strt_ps, pr_algn.algn_ps_pr_pr.1.end_ps), (1, 5));
  }
}