1extern crate arrayfire;
2extern crate fnv;
3
4use arrayfire::{Array, Dim4, constant, add, col, row, cols, rows, join, lookup, set_col, set_row, maxof, replace, gt, ge, eq, max_all};
5use std::hash::BuildHasherDefault;
6use std::collections::HashMap;
7use fnv::FnvHasher;
8
/// A single bio-sequence element (one byte, e.g. `b'A'`).
pub type BElm = u8;
/// An owned sequence of bio elements; alignment rows may also contain gap symbols.
type BSq = Vec<BElm>;
/// A 0-based position inside an input sequence.
type AlgnPs = usize;

/// Start and end positions (inclusive, 0-based) of the aligned region within one
/// input sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlgnPsPr {
    strt_ps: AlgnPs,
    end_ps: AlgnPs,
}

impl AlgnPsPr {
    /// Position of the first aligned element.
    pub fn strt_ps(&self) -> AlgnPs {
        self.strt_ps
    }

    /// Position of the last aligned element.
    pub fn end_ps(&self) -> AlgnPs {
        self.end_ps
    }
}

type AlgnScr = f32;

/// Result of aligning a pair of sequences.
///
/// The fields stay private; the accessors below give read-only access from
/// outside this module, which previously could only `Debug`-print the result.
#[derive(Debug, Clone, PartialEq)]
pub struct PrAlgn {
    algn: (BSq, BSq),
    algn_ps_pr_pr: (AlgnPsPr, AlgnPsPr),
    algn_scr: AlgnScr,
}

impl PrAlgn {
    /// The two aligned rows (first sequence, second sequence), including gaps.
    pub fn algn(&self) -> (&[BElm], &[BElm]) {
        (&self.algn.0, &self.algn.1)
    }

    /// Aligned region of the first and of the second input sequence.
    pub fn algn_ps_pr_pr(&self) -> &(AlgnPsPr, AlgnPsPr) {
        &self.algn_ps_pr_pr
    }

    /// Score of the alignment.
    pub fn algn_scr(&self) -> AlgnScr {
        self.algn_scr
    }
}
24type Hshr = BuildHasherDefault<FnvHasher>;
25pub type SbstMt = HashMap<(BElm, BElm), AlgnScr, Hshr>;
26pub struct AlgnScrSchm {
27 gp_opn_pnlty: AlgnScr,
28 gp_extnsn_pnlty: AlgnScr,
29}
30impl AlgnScrSchm {
31 pub fn new(gp_opn_pnlty: AlgnScr, gp_extnsn_pnlty: AlgnScr) -> AlgnScrSchm {
32 AlgnScrSchm {
33 gp_opn_pnlty: gp_opn_pnlty,
34 gp_extnsn_pnlty: gp_extnsn_pnlty,
35 }
36 }
37}
38type AfDm = u64;
39pub type Alphbt<'a> = &'a[BElm];
40type DpSrc = u32;
41
42const GP: BElm = '-' as BElm;
43const DGNL: DpSrc = 0;
44const VRTCL: DpSrc = DGNL + 1;
45const HRZNTL: DpSrc = VRTCL + 1;
46
47pub fn gpu_sw(b_sq_pr: &(&[BElm], &[BElm]), sbst_mt: &SbstMt, algn_scr_schm: &AlgnScrSchm, is_dna: bool) -> PrAlgn {
68 let b_sq_ln_pr = (b_sq_pr.0.len(), b_sq_pr.1.len());
69 let alphbt = gt_alphbt(is_dna);
70 let gpu_b_sq_pr = (
71 Array::new(&b_sq_pr.0.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.0 as AfDm, 1, 1, 1])),
72 Array::new(&b_sq_pr.1.iter().map(|&b_elm| gt_hsh_b_elm(b_elm, alphbt)).collect::<BSq>(), Dim4::new(&[b_sq_ln_pr.1 as AfDm, 1, 1, 1])),
73 );
74 let alphbt_ln = alphbt.len();
75 let mut hsh_sbst_mt = vec![vec![0.; alphbt_ln]; alphbt_ln];
76 for (b_elm_pr, &sbst_scr) in sbst_mt {
77 hsh_sbst_mt[gt_hsh_b_elm(b_elm_pr.0, alphbt) as usize][gt_hsh_b_elm(b_elm_pr.1, alphbt) as usize] = sbst_scr;
78 }
79 let hsh_sbst_mt = hsh_sbst_mt.iter().flat_map(|sbst_scr| sbst_scr.clone()).collect::<Vec<AlgnScr>>();
80 let gpu_sbst_mt = Array::new(&hsh_sbst_mt, Dim4::new(&[alphbt_ln as AfDm, alphbt_ln as AfDm, 1, 1]));
81 let scr_mt = lookup(&lookup(&gpu_sbst_mt, &gpu_b_sq_pr.0, 0), &gpu_b_sq_pr.1, 1);
82 let scr_mt_dms = scr_mt.dims();
83 let mut dp_mt = constant(0. as AlgnScr, Dim4::new(&[scr_mt_dms[0] + 1, scr_mt_dms[1] + 1, 1, 1]));
84 let (gp_opn_pnlty, gp_extnsn_pnlty) = (algn_scr_schm.gp_opn_pnlty, algn_scr_schm.gp_extnsn_pnlty);
85 let dp_mt_dms = dp_mt.dims();
86 let mut src_mt = constant(DGNL, Dim4::new(&[1, dp_mt_dms[1], 1, 1]));
87 let rw_tl_dms = Dim4::new(&[1, scr_mt_dms[1], 1, 1]);
88 let cl_dms = Dim4::new(&[1, 1, 1, 1]);
89 for i in 1 .. dp_mt_dms[0] {
90 let prvs_rw = row(&dp_mt, i - 1);
91 let mut prvs_rw_tl = cols(&prvs_rw, 1, dp_mt_dms[1] - 1).copy();
92 replace(&mut prvs_rw_tl, &eq(&cols(&row(&src_mt, i - 1), 1, dp_mt_dms[1] - 1), &VRTCL, false), &add(&cols(&prvs_rw, 1, dp_mt_dms[1] - 1), &gp_opn_pnlty, false));
93 let prvs_rw_tl = add(&prvs_rw_tl, &gp_extnsn_pnlty, false);
94 let prvs_rw_hd = add(&cols(&prvs_rw, 0, dp_mt_dms[1] - 2), &cols(&row(&scr_mt, i - 1), 0, dp_mt_dms[1] - 2), false);
95 let mut nw_rw_tl = constant(DGNL, rw_tl_dms);
96 replace(&mut nw_rw_tl, &ge(&prvs_rw_hd, &prvs_rw_tl, false), &constant(VRTCL, rw_tl_dms));
97 src_mt = join(0, &src_mt, &join(1, &constant(DGNL, cl_dms), &nw_rw_tl));
98 dp_mt = set_row(&dp_mt, &join(1, &col(&row(&dp_mt, i), 0), &maxof(&maxof(&prvs_rw_tl, &prvs_rw_hd, false), &constant(0. as AlgnScr, rw_tl_dms), false)), i);
99 }
100 let clmn_tl_dms = Dim4::new(&[scr_mt_dms[0], 1, 1, 1]);
101 for i in 1 .. dp_mt_dms[1] {
102 let prvs_clmn = col(&dp_mt, i - 1);
103 let mut prvs_clmn_tl = rows(&prvs_clmn, 1, dp_mt_dms[0] - 1).copy();
104 replace(&mut prvs_clmn_tl, &eq(&rows(&col(&src_mt, i - 1), 1, dp_mt_dms[0] - 1), &HRZNTL, false), &add(&rows(&prvs_clmn, 1, dp_mt_dms[0] - 1), &gp_opn_pnlty, false));
105 let prvs_clmn_tl = add(&prvs_clmn_tl, &gp_extnsn_pnlty, false);
106 let prvs_clmn_hd = add(&rows(&prvs_clmn, 0, dp_mt_dms[0] - 2), &rows(&col(&scr_mt, i - 1), 0, dp_mt_dms[0] - 2), false);
107 let crnt_clmn_tl = rows(&col(&dp_mt, i), 1, dp_mt_dms[0] - 1);
108 let mx = maxof(&prvs_clmn_hd, &crnt_clmn_tl, false);
109 let mut src_clmn_tl = rows(&col(&src_mt, i), 1, dp_mt_dms[0] - 1).copy();
110 replace(&mut src_clmn_tl, >(&crnt_clmn_tl, &prvs_clmn_hd, false), &constant(DGNL, clmn_tl_dms));
111 replace(&mut src_clmn_tl, &ge(&mx, &prvs_clmn_tl, false), &constant(HRZNTL, clmn_tl_dms));
112 src_mt = set_col(&src_mt, &join(0, &row(&col(&src_mt, i), 0), &src_clmn_tl), i);
113 dp_mt = set_col(&dp_mt, &join(0, &row(&col(&dp_mt, i), 0), &maxof(&prvs_clmn_tl, &mx, false)), i);
114 }
115 let dp_mt_elm_nm = dp_mt.elements() as usize;
116 let mut cpu_dp_mt = vec![0. as AlgnScr; dp_mt_elm_nm];
117 let mut cpu_src_mt = vec![DGNL; dp_mt_elm_nm];
118 dp_mt.host(&mut cpu_dp_mt);
119 src_mt.host(&mut cpu_src_mt);
120 let mx_scr = max_all(&dp_mt).0 as AlgnScr;
121 let mut dp_mt = vec![vec![0. as AlgnScr; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
122 let mut src_mt = vec![vec![DGNL; dp_mt_dms[1] as usize]; dp_mt_dms[0] as usize];
123 let mut fnd_ps_pr = (0, 0);
124 let mut is_ps_pr_fnd = false;
125 for (i, (&dp_mt_elm, &src)) in cpu_dp_mt.iter().zip(cpu_src_mt.iter()).enumerate() {
126 let ps_pr = (i % (dp_mt_dms[0] as usize), i / (dp_mt_dms[0] as usize));
127 dp_mt[ps_pr.0][ps_pr.1] = dp_mt_elm;
128 src_mt[ps_pr.0][ps_pr.1] = src;
129 if !is_ps_pr_fnd && dp_mt[ps_pr.0][ps_pr.1] == mx_scr {
130 fnd_ps_pr = ps_pr;
131 is_ps_pr_fnd = true;
132 }
133 }
134 let mut pr_algn = (Vec::new(), Vec::new());
135 let (mut i, mut j) = fnd_ps_pr;
136 let (mut prvs_i, mut prvs_j) = (i, j);
137 while i > 0 || j > 0 {
138 if dp_mt[i][j] == 0. {
139 break;
140 }
141 prvs_i = i;
142 prvs_j = j;
143 if j == 0 {
144 pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
145 pr_algn.1.insert(0, GP);
146 i -= 1;
147 continue;
148 } else if i == 0 {
149 pr_algn.0.insert(0, GP);
150 pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
151 j -= 1;
152 continue;
153 }
154 let src = src_mt[i][j];
155 if src == DGNL {
156 pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
157 pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
158 i -= 1;
159 j -= 1;
160 } else if src == VRTCL {
161 pr_algn.0.insert(0, b_sq_pr.0[i - 1]);
162 pr_algn.1.insert(0, GP);
163 i -= 1;
164 } else {
165 pr_algn.0.insert(0, GP);
166 pr_algn.1.insert(0, b_sq_pr.1[j - 1]);
167 j -= 1;
168 }
169 }
170 let pr_algn = PrAlgn {
171 algn: pr_algn,
172 algn_ps_pr_pr: (AlgnPsPr {strt_ps: if prvs_i == 0 {0} else {prvs_i - 1}, end_ps: fnd_ps_pr.0 - 1}, AlgnPsPr {strt_ps: if prvs_j == 0 {0} else {prvs_j - 1}, end_ps: fnd_ps_pr.1 - 1}),
173 algn_scr: mx_scr,
174 };
175 pr_algn
176}
177
178pub fn gt_alphbt<'a>(is_dna: bool) -> Alphbt<'a> {
179 if is_dna {
180 b"ACGTURYSWKMBDHVNacgturyswkmbdhvn"
181 } else {
182 b"ARNDCEQGHILKMFPSTWYVarndceqghilkmfpstwyv"
183 }
184}
185
186fn gt_hsh_b_elm(b_elm: BElm, alphbt: Alphbt) -> BElm {
187 alphbt.iter().position(|&alphbt_elm| alphbt_elm == b_elm).expect("Failed to get hashed bio elem.") as BElm
188}
189
#[cfg(test)]
mod tsts {
    use super::{gpu_sw, gt_alphbt, SbstMt, AlgnScrSchm};
    use super::arrayfire::{set_device, info};
    use super::std::str::from_utf8;

    /// Aligns GGTTGACTA with TGTTACGG on GPU device 0.
    ///
    /// The scoring is match +3, mismatch -3, gap open -0 and gap extension -2,
    /// which amounts to a linear gap of -2. The test then checks the best local
    /// alignment GTTGAC / GTT-AC (score 13) and its positions.
    #[test]
    fn tst_gpu_sw() {
        set_device(0);
        info();
        let is_dna = true;
        let alphbt = gt_alphbt(is_dna);
        let b_sq_pr = (&b"GGTTGACTA"[..], &b"TGTTACGG"[..]);
        println!("Seq. pair to align:");
        println!("{}", from_utf8(&b_sq_pr.0).expect("Failed to get Bio seq."));
        println!("{}", from_utf8(&b_sq_pr.1).expect("Failed to get Bio seq."));
        let mut sbst_mt = SbstMt::default();
        for &alphbt_elm_1 in alphbt.iter() {
            for &alphbt_elm_2 in alphbt.iter() {
                sbst_mt.insert((alphbt_elm_1, alphbt_elm_2), if alphbt_elm_1 == alphbt_elm_2 {3.} else {-3.});
            }
        }
        let algn_scr_schm = AlgnScrSchm::new(-0., -2.);
        let pr_algn = gpu_sw(&b_sq_pr, &sbst_mt, &algn_scr_schm, is_dna);
        println!("{:?}", &pr_algn);
        println!("Alignment:");
        println!("{}", from_utf8(&pr_algn.algn.0).expect("Failed to get alignment."));
        println!("{}", from_utf8(&pr_algn.algn.1).expect("Failed to get alignment."));
        // Pin the result so a regression actually fails the test instead of
        // only changing the printed output.
        assert_eq!(pr_algn.algn_scr, 13.);
        assert_eq!(from_utf8(&pr_algn.algn.0).expect("Failed to get alignment."), "GTTGAC");
        assert_eq!(from_utf8(&pr_algn.algn.1).expect("Failed to get alignment."), "GTT-AC");
        assert_eq!((pr_algn.algn_ps_pr_pr.0.strt_ps, pr_algn.algn_ps_pr_pr.0.end_ps), (1, 6));
        assert_eq!((pr_algn.algn_ps_pr_pr.1.strt_ps, pr_algn.algn_ps_pr_pr.1.end_ps), (1, 5));
    }
}