1mod alignment;
2mod bwt;
3mod command;
4mod gibbs_sampler;
5mod median_string;
6mod randomized_motif_search;
7mod utils;
8
9use alignment::local_alignment;
10use gibbs_sampler::iterate_gibbs_sampler;
11use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
12use median_string::median_string;
13use randomized_motif_search::iterate_randomized_motif_search;
14use rayon::prelude::*;
15use std::str;
16use std::{
17 collections::{HashMap, HashSet},
18 fs::File,
19};
20use tracing::{error, info, trace};
21
22use bio::io::fasta;
23#[doc(hidden)]
24pub use command::MotifFinder;
25
/// Error type shared by the motif-finding routines of this crate.
#[derive(Debug)]
pub enum Error {
    /// Catch-all failure with no more specific variant.
    GenericError,
    /// An input/output failure.
    IOError,
    /// The input file could not be opened; carries the path that was attempted.
    FileNotFoundError(String),
    /// Input rejected as malformed (e.g. a motif matrix a profile cannot be built from).
    InvalidInputError,
    /// A nucleotide could not be resolved (e.g. while building a consensus string).
    InvalidNucleotideError,
    /// Invalid k-mer length.
    InvalidKmerLength,
    /// The requested number of runs is invalid (the public runners reject 0).
    InvalidNumberOfRuns,
    /// The requested number of iterations is invalid (the Gibbs runner rejects 0).
    InvalidNumberOfIterations,
    /// Invalid motif length.
    InvalidMotifLength,
    /// No motifs were available (e.g. an empty motif list given to the consensus builder).
    NoMotifsFound,
    /// A sequence could not be decoded (e.g. a FASTA record that is not valid UTF-8).
    InvalidSequence,
    /// Invalid pointer / index.
    InvalidPointerError,
    /// The number of motifs could not be used (e.g. it does not fit a progress-bar length).
    InvalidNumberMotifs,
}
42
43#[tracing::instrument(skip_all)]
44fn scoring_function(motif_matrix: &[String]) -> usize {
45 let mut score = 0;
48 let k = motif_matrix.get(0).unwrap().chars().count();
49 let motifs_length = motif_matrix.len();
50 trace!(motifs_length);
51 for i in 0..k {
53 let mut count: HashMap<char, usize> = HashMap::new();
54 for motif in motif_matrix {
55 if let Some(nuc) = motif.chars().nth(i) {
56 *count.entry(nuc).or_insert(0) += 1;
57 } else {
58 continue;
59 }
60 }
61
62 let max = count.iter().max_by_key(|f| f.1).unwrap().1;
63 trace!(max);
64 score += motifs_length - max;
65 }
66 score
67}
68
69#[tracing::instrument(skip_all)]
70fn generate_profile_given_motif_matrix(
71 motif_matrix: &[String],
72 pseudo: bool,
73) -> Result<Vec<Vec<f64>>, Error> {
74 let k = motif_matrix[0].len();
76 trace!(k);
77 let count_matrix = generate_count_matrix(motif_matrix, k, pseudo);
78 let mut profile_matrix: Vec<Vec<f64>> = vec![vec![0.0; k]; 4];
79 let sum = motif_matrix.len() as f64;
80 for i in 0..k {
82 for j in 0..4 {
84 if let Some(row) = profile_matrix.get_mut(j) {
87 row[i] = (*count_matrix.get(j).unwrap().get(i).unwrap()) as f64 / sum;
88 } else {
89 error!("Invalid index for profile matrix");
90 return Err(Error::InvalidInputError);
91 }
92
93 }
95 }
96 Ok(profile_matrix)
97}
98
99#[tracing::instrument(skip_all)]
100fn generate_count_matrix(motif_matrix: &[String], k: usize, pseudo: bool) -> Vec<Vec<usize>> {
101 let mut val = 0;
103 if pseudo {
104 val = 1;
105 }
106 let mut count_matrix: Vec<Vec<usize>> = vec![vec![val; k]; 4]; for i in 0..k {
108 for motif in motif_matrix {
109 if let Some(index) = match motif.chars().nth(i) {
110 Some('A') => Some(0),
111 Some('C') => Some(1),
112 Some('G') => Some(2),
113 Some('T') => Some(3),
114 _ => None,
115 } {
116 if let Some(count_col) = count_matrix.get_mut(index) {
117 count_col[i] += 1;
118 }
119 }
120 }
121 }
122 count_matrix
123}
124
125#[tracing::instrument(skip_all)]
126fn generate_probability(kmer: &str, profile: &[Vec<f64>]) -> f64 {
127 let mut probability = 1.0;
129 for (i, nuc) in kmer.chars().enumerate() {
130 let nuc_index = match nuc {
131 'A' => Some(0),
132 'C' => Some(1),
133 'G' => Some(2),
134 'T' => Some(3),
135 _ => None,
136 };
137 if let Some(nuc_index) = nuc_index {
138 let current_prob = profile.get(nuc_index).unwrap().get(i).unwrap();
140 probability *= current_prob;
141 }
142 }
143 probability
144}
145
146#[tracing::instrument(skip_all)]
147fn consensus_string(motifs: &[String], k: usize) -> Result<String, Error> {
148 let mut consensus = String::new();
149 let count_matrix = generate_count_matrix(motifs, k, true);
150 for i in 0..k {
151 let mut max = 0;
152 let mut max_index = 0;
153 for j in 0..4 {
154 let count = count_matrix
155 .get(j)
156 .and_then(|row| row.get(i))
157 .ok_or(Error::InvalidNucleotideError)?;
158 if count > &max {
159 max = *count;
160 max_index = j;
161 }
162 }
163 let nuc = match max_index {
164 0 => 'A',
165 1 => 'C',
166 2 => 'G',
167 3 => 'T',
168 _ => return Err(Error::InvalidNucleotideError),
169 };
170 consensus.push(nuc);
171 }
172 Ok(consensus)
173}
174
175#[tracing::instrument(skip_all)]
176pub fn align_motifs_multi_threaded(
177 sequences: &[String],
178 motifs: &[String],
179) -> Result<Vec<(isize, String)>, Error> {
180 let motifs_len = motifs.len();
181 let sequences_len = sequences.len();
182 let pb = ProgressBar::new(
183 motifs_len
184 .try_into()
185 .map_err(|_| Error::InvalidNumberMotifs)?,
186 );
187 let sty =
188 ProgressStyle::with_template(&format!("{{prefix:.bold}}▕{{bar:.{}}}▏{{msg}} ", "9.on_0"))
189 .unwrap()
190 .progress_chars("█▉▊▋▌▍▎▏ ");
191 pb.set_style(
192 ProgressStyle::with_template(
193 "[{elapsed_precise}] {spinner:.9.on_0} {bar:50.9.on_0} {pos:>2}/{len:2} {msg} ({eta})",
194 )
195 .unwrap(),
196 );
197 pb.reset_eta();
198
199 let m = MultiProgress::new();
200
201 let total_pb = m.add(pb);
202 total_pb.println(format!(
203 "Aligning {} unique motifs to {} sequences",
204 motifs_len,
205 sequences.len()
206 ));
207
208 let mut top_five: Vec<(isize, String)> = motifs
209 .par_iter()
210 .progress_with(total_pb.clone())
211 .map(|motif| {
212 let inner = m.add(ProgressBar::new(sequences_len.try_into().unwrap()));
213 inner.set_style(sty.clone());
214 inner.set_prefix(motif.to_string());
215 let mut total_score = 0;
216 let mut highest_score = 0;
217 let mut best_motif = String::from("");
218 for sequence in sequences.iter() {
219 let (score, _v_align, w_align) = local_alignment(sequence, motif, 1, -10, -100)?;
220 if score > highest_score {
221 highest_score = score;
222 best_motif = w_align;
223 }
224 total_score += score;
225 inner.inc(1);
226 }
227 inner.finish_and_clear();
228 Ok((total_score, best_motif))
229 })
230 .collect::<Result<Vec<(isize, String)>, Error>>()?;
231
232 total_pb.finish_with_message("Done!");
233 top_five.par_sort_by(|a, b| b.0.cmp(&a.0));
234 top_five.dedup();
235 top_five.truncate(5);
236 Ok(top_five.to_vec())
237}
238
239#[tracing::instrument]
240pub fn load_data(path_to_file: &str, num_entries: usize) -> Result<Vec<String>, Error> {
241 info!("Loading data from '{}'...", path_to_file);
242 let mut sequences = vec![];
243 let file = match File::open(path_to_file) {
244 Ok(file) => file,
245 Err(_) => return Err(Error::FileNotFoundError(path_to_file.to_string())),
246 };
247 let mut records = fasta::Reader::new(file).records();
248 let mut count = 0;
249 while let Some(Ok(record)) = records.next() {
250 count += 1;
251 if count > num_entries {
252 break;
253 }
254 let s = match str::from_utf8(record.seq()) {
255 Ok(v) => v,
256 Err(_e) => return Err(Error::InvalidSequence),
257 }
258 .to_string()
259 .to_uppercase();
260
261 sequences.push(s);
262 }
263 info!("Done loading data: {} entries", sequences.len());
264 Ok(sequences)
265}
266
267#[tracing::instrument(skip(sequences))]
268pub fn run_gibbs_sampler(
269 sequences: &Vec<String>,
270 k: usize,
271 num_runs: usize,
272 num_iterations: usize,
273) -> Result<Vec<String>, Error> {
274 if num_runs == 0 {
275 return Err(Error::InvalidNumberOfRuns);
276 }
277 if num_iterations == 0 {
278 return Err(Error::InvalidNumberOfIterations);
279 }
280
281 iterate_gibbs_sampler(sequences, k, sequences.len(), num_iterations, num_runs)
282}
283
284#[tracing::instrument(skip(sequences))]
285pub fn run_median_string(sequences: &[String], k: usize) -> Result<Vec<String>, Error> {
286 let median_string = median_string(k, sequences)?;
287 info!("Median string: {}", median_string);
288 let vec = vec![median_string];
289 Ok(vec)
290}
291
292#[tracing::instrument(skip(sequences))]
293pub fn run_randomized_motif_search(
294 sequences: &[String],
295 k: usize,
296 num_runs: usize,
297) -> Result<Vec<String>, Error> {
298 if num_runs == 0 {
299 return Err(Error::InvalidNumberOfRuns);
300 }
301 iterate_randomized_motif_search(sequences, k, num_runs)
302}
303
304#[tracing::instrument(skip(motifs))]
305pub fn generate_consensus_string(motifs: &[String], k: usize) -> Result<String, Error> {
306 if motifs.is_empty() {
307 return Err(Error::NoMotifsFound);
308 } else if motifs.len() == 1 {
309 return Ok(motifs[0].clone());
310 }
311 consensus_string(motifs, k)
312}
313
314#[tracing::instrument(skip(motifs))]
315pub fn unique_motifs(motifs: &[String]) -> HashSet<String> {
316 motifs.into_par_iter().cloned().collect::<HashSet<String>>()
317}
318
#[cfg(test)]
mod test {
    use super::{align_motifs_multi_threaded, load_data, run_randomized_motif_search};

    #[test]
    fn test_load_data() {
        // `promoters.fasta` holds four records, so larger requests are capped at four.
        for &(requested, expected) in [(5, 4), (4, 4), (3, 3), (2, 2), (1, 1), (0, 0)].iter() {
            let sequences = load_data("promoters.fasta", requested).unwrap();
            assert_eq!(sequences.len(), expected);
        }
    }

    #[test]
    fn test_entries_less_than_five() {
        let four = load_data("promoters.fasta", 4).unwrap();
        let motifs = run_randomized_motif_search(&four, 8, 20).unwrap();
        let ranked = align_motifs_multi_threaded(&four, &motifs).unwrap();
        assert!(ranked.len() <= 4);

        let two = load_data("promoters.fasta", 2).unwrap();
        assert_eq!(two.len(), 2);
        let motifs = run_randomized_motif_search(&two, 8, 20).unwrap();
        assert_eq!(motifs.len(), 2);
        let ranked = align_motifs_multi_threaded(&two, &motifs).unwrap();
        assert!(ranked.len() <= 2);
    }
}