motif_finder/
lib.rs

1mod alignment;
2mod bwt;
3mod command;
4mod gibbs_sampler;
5mod median_string;
6mod randomized_motif_search;
7mod utils;
8
9use alignment::local_alignment;
10use gibbs_sampler::iterate_gibbs_sampler;
11use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
12use median_string::median_string;
13use randomized_motif_search::iterate_randomized_motif_search;
14use rayon::prelude::*;
15use std::str;
16use std::{
17    collections::{HashMap, HashSet},
18    fs::File,
19};
20use tracing::{error, info, trace};
21
22use bio::io::fasta;
23#[doc(hidden)]
24pub use command::MotifFinder;
25
/// Errors produced by the motif-finding library.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can compare and copy
/// error values (e.g. `assert_eq!(result, Err(Error::NoMotifsFound))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// Not constructed in this file; presumably a catch-all failure used by other modules.
    GenericError,
    /// Not constructed in this file; presumably an I/O failure raised by other modules.
    IOError,
    /// A data file could not be opened; carries the path that was requested.
    FileNotFoundError(String),
    /// Generic invalid-input failure (returned, for example, while building a profile matrix).
    InvalidInputError,
    /// An out-of-range base/column index was hit while building a consensus string.
    InvalidNucleotideError,
    /// Not constructed in this file; presumably an unusable k-mer length.
    InvalidKmerLength,
    /// The requested number of runs was zero.
    InvalidNumberOfRuns,
    /// The requested number of iterations was zero.
    InvalidNumberOfIterations,
    /// Not constructed in this file; presumably an unusable motif length.
    InvalidMotifLength,
    /// No motifs were supplied where at least one was required.
    NoMotifsFound,
    /// A sequence could not be decoded (e.g. a FASTA record that is not valid UTF-8).
    InvalidSequence,
    /// Not constructed in this file; presumably raised by other modules (e.g. alignment
    /// traceback) — verify against the defining module.
    InvalidPointerError,
    /// The number of motifs could not be represented as a progress-bar length (`u64`).
    InvalidNumberMotifs,
}
42
43#[tracing::instrument(skip_all)]
44fn scoring_function(motif_matrix: &[String]) -> usize {
45    // given a motif matrix, generate its score by finding the highest count of nucleotide in a given position
46    // and subtract that count from the total length of the column
47    let mut score = 0;
48    let k = motif_matrix.get(0).unwrap().chars().count();
49    let motifs_length = motif_matrix.len();
50    trace!(motifs_length);
51    // println!("len {}",motifs_length);
52    for i in 0..k {
53        let mut count: HashMap<char, usize> = HashMap::new();
54        for motif in motif_matrix {
55            if let Some(nuc) = motif.chars().nth(i) {
56                *count.entry(nuc).or_insert(0) += 1;
57            } else {
58                continue;
59            }
60        }
61
62        let max = count.iter().max_by_key(|f| f.1).unwrap().1;
63        trace!(max);
64        score += motifs_length - max;
65    }
66    score
67}
68
69#[tracing::instrument(skip_all)]
70fn generate_profile_given_motif_matrix(
71    motif_matrix: &[String],
72    pseudo: bool,
73) -> Result<Vec<Vec<f64>>, Error> {
74    // generate probabilities per column using the count matrix divided by sum of each column
75    let k = motif_matrix[0].len();
76    trace!(k);
77    let count_matrix = generate_count_matrix(motif_matrix, k, pseudo);
78    let mut profile_matrix: Vec<Vec<f64>> = vec![vec![0.0; k]; 4];
79    let sum = motif_matrix.len() as f64;
80    // iterating over each position
81    for i in 0..k {
82        // iterating over each nucleotide base
83        for j in 0..4 {
84            // print_vector_space_delimited(row.clone());
85            // get the row associated with the nucleotide at index i
86            if let Some(row) = profile_matrix.get_mut(j) {
87                row[i] = (*count_matrix.get(j).unwrap().get(i).unwrap()) as f64 / sum;
88            } else {
89                error!("Invalid index for profile matrix");
90                return Err(Error::InvalidInputError);
91            }
92
93            // divide by sum to get the percentage probability
94        }
95    }
96    Ok(profile_matrix)
97}
98
99#[tracing::instrument(skip_all)]
100fn generate_count_matrix(motif_matrix: &[String], k: usize, pseudo: bool) -> Vec<Vec<usize>> {
101    // enumerate motif matrix per nucleotide per position
102    let mut val = 0;
103    if pseudo {
104        val = 1;
105    }
106    let mut count_matrix: Vec<Vec<usize>> = vec![vec![val; k]; 4]; // ACGT = 4
107    for i in 0..k {
108        for motif in motif_matrix {
109            if let Some(index) = match motif.chars().nth(i) {
110                Some('A') => Some(0),
111                Some('C') => Some(1),
112                Some('G') => Some(2),
113                Some('T') => Some(3),
114                _ => None,
115            } {
116                if let Some(count_col) = count_matrix.get_mut(index) {
117                    count_col[i] += 1;
118                }
119            }
120        }
121    }
122    count_matrix
123}
124
125#[tracing::instrument(skip_all)]
126fn generate_probability(kmer: &str, profile: &[Vec<f64>]) -> f64 {
127    // given a kmer and a profile, generate its probability
128    let mut probability = 1.0;
129    for (i, nuc) in kmer.chars().enumerate() {
130        let nuc_index = match nuc {
131            'A' => Some(0),
132            'C' => Some(1),
133            'G' => Some(2),
134            'T' => Some(3),
135            _ => None,
136        };
137        if let Some(nuc_index) = nuc_index {
138            // this should always be true but just in case
139            let current_prob = profile.get(nuc_index).unwrap().get(i).unwrap();
140            probability *= current_prob;
141        }
142    }
143    probability
144}
145
146#[tracing::instrument(skip_all)]
147fn consensus_string(motifs: &[String], k: usize) -> Result<String, Error> {
148    let mut consensus = String::new();
149    let count_matrix = generate_count_matrix(motifs, k, true);
150    for i in 0..k {
151        let mut max = 0;
152        let mut max_index = 0;
153        for j in 0..4 {
154            let count = count_matrix
155                .get(j)
156                .and_then(|row| row.get(i))
157                .ok_or(Error::InvalidNucleotideError)?;
158            if count > &max {
159                max = *count;
160                max_index = j;
161            }
162        }
163        let nuc = match max_index {
164            0 => 'A',
165            1 => 'C',
166            2 => 'G',
167            3 => 'T',
168            _ => return Err(Error::InvalidNucleotideError),
169        };
170        consensus.push(nuc);
171    }
172    Ok(consensus)
173}
174
175#[tracing::instrument(skip_all)]
176pub fn align_motifs_multi_threaded(
177    sequences: &[String],
178    motifs: &[String],
179) -> Result<Vec<(isize, String)>, Error> {
180    let motifs_len = motifs.len();
181    let sequences_len = sequences.len();
182    let pb = ProgressBar::new(
183        motifs_len
184            .try_into()
185            .map_err(|_| Error::InvalidNumberMotifs)?,
186    );
187    let sty =
188        ProgressStyle::with_template(&format!("{{prefix:.bold}}▕{{bar:.{}}}▏{{msg}} ", "9.on_0"))
189            .unwrap()
190            .progress_chars("█▉▊▋▌▍▎▏  ");
191    pb.set_style(
192        ProgressStyle::with_template(
193            "[{elapsed_precise}] {spinner:.9.on_0} {bar:50.9.on_0} {pos:>2}/{len:2} {msg} ({eta})",
194        )
195        .unwrap(),
196    );
197    pb.reset_eta();
198
199    let m = MultiProgress::new();
200
201    let total_pb = m.add(pb);
202    total_pb.println(format!(
203        "Aligning {} unique motifs to {} sequences",
204        motifs_len,
205        sequences.len()
206    ));
207
208    let mut top_five: Vec<(isize, String)> = motifs
209        .par_iter()
210        .progress_with(total_pb.clone())
211        .map(|motif| {
212            let inner = m.add(ProgressBar::new(sequences_len.try_into().unwrap()));
213            inner.set_style(sty.clone());
214            inner.set_prefix(motif.to_string());
215            let mut total_score = 0;
216            let mut highest_score = 0;
217            let mut best_motif = String::from("");
218            for sequence in sequences.iter() {
219                let (score, _v_align, w_align) = local_alignment(sequence, motif, 1, -10, -100)?;
220                if score > highest_score {
221                    highest_score = score;
222                    best_motif = w_align;
223                }
224                total_score += score;
225                inner.inc(1);
226            }
227            inner.finish_and_clear();
228            Ok((total_score, best_motif))
229        })
230        .collect::<Result<Vec<(isize, String)>, Error>>()?;
231
232    total_pb.finish_with_message("Done!");
233    top_five.par_sort_by(|a, b| b.0.cmp(&a.0));
234    top_five.dedup();
235    top_five.truncate(5);
236    Ok(top_five.to_vec())
237}
238
239#[tracing::instrument]
240pub fn load_data(path_to_file: &str, num_entries: usize) -> Result<Vec<String>, Error> {
241    info!("Loading data from '{}'...", path_to_file);
242    let mut sequences = vec![];
243    let file = match File::open(path_to_file) {
244        Ok(file) => file,
245        Err(_) => return Err(Error::FileNotFoundError(path_to_file.to_string())),
246    };
247    let mut records = fasta::Reader::new(file).records();
248    let mut count = 0;
249    while let Some(Ok(record)) = records.next() {
250        count += 1;
251        if count > num_entries {
252            break;
253        }
254        let s = match str::from_utf8(record.seq()) {
255            Ok(v) => v,
256            Err(_e) => return Err(Error::InvalidSequence),
257        }
258        .to_string()
259        .to_uppercase();
260
261        sequences.push(s);
262    }
263    info!("Done loading data: {} entries", sequences.len());
264    Ok(sequences)
265}
266
267#[tracing::instrument(skip(sequences))]
268pub fn run_gibbs_sampler(
269    sequences: &Vec<String>,
270    k: usize,
271    num_runs: usize,
272    num_iterations: usize,
273) -> Result<Vec<String>, Error> {
274    if num_runs == 0 {
275        return Err(Error::InvalidNumberOfRuns);
276    }
277    if num_iterations == 0 {
278        return Err(Error::InvalidNumberOfIterations);
279    }
280
281    iterate_gibbs_sampler(sequences, k, sequences.len(), num_iterations, num_runs)
282}
283
284#[tracing::instrument(skip(sequences))]
285pub fn run_median_string(sequences: &[String], k: usize) -> Result<Vec<String>, Error> {
286    let median_string = median_string(k, sequences)?;
287    info!("Median string: {}", median_string);
288    let vec = vec![median_string];
289    Ok(vec)
290}
291
292#[tracing::instrument(skip(sequences))]
293pub fn run_randomized_motif_search(
294    sequences: &[String],
295    k: usize,
296    num_runs: usize,
297) -> Result<Vec<String>, Error> {
298    if num_runs == 0 {
299        return Err(Error::InvalidNumberOfRuns);
300    }
301    iterate_randomized_motif_search(sequences, k, num_runs)
302}
303
304#[tracing::instrument(skip(motifs))]
305pub fn generate_consensus_string(motifs: &[String], k: usize) -> Result<String, Error> {
306    if motifs.is_empty() {
307        return Err(Error::NoMotifsFound);
308    } else if motifs.len() == 1 {
309        return Ok(motifs[0].clone());
310    }
311    consensus_string(motifs, k)
312}
313
314#[tracing::instrument(skip(motifs))]
315pub fn unique_motifs(motifs: &[String]) -> HashSet<String> {
316    motifs.into_par_iter().cloned().collect::<HashSet<String>>()
317}
318
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    pub fn test_load_data() {
        // `promoters.fasta` holds four records; larger requests are capped at four.
        for (requested, expected) in [(5, 4), (4, 4), (3, 3), (2, 2), (1, 1), (0, 0)] {
            let sequences = load_data("promoters.fasta", requested).unwrap();
            assert_eq!(sequences.len(), expected, "requested {} entries", requested);
        }
    }

    #[test]
    pub fn test_entries_less_than_five() {
        let sequences = load_data("promoters.fasta", 4).unwrap();
        let motifs = run_randomized_motif_search(&sequences, 8, 20).unwrap();
        let top_five = align_motifs_multi_threaded(&sequences, &motifs).unwrap();
        assert!(top_five.len() <= 4);

        let sequences = load_data("promoters.fasta", 2).unwrap();
        assert_eq!(sequences.len(), 2);
        let motifs = run_randomized_motif_search(&sequences, 8, 20).unwrap();
        assert_eq!(motifs.len(), 2);
        let top_five = align_motifs_multi_threaded(&sequences, &motifs).unwrap();
        assert!(top_five.len() <= 2);
    }

    #[test]
    fn count_matrix_tallies_each_base_per_column() {
        let motifs = vec!["ACGT".to_string(), "AAGT".to_string()];
        assert_eq!(
            generate_count_matrix(&motifs, 4, false),
            vec![
                vec![2, 1, 0, 0],
                vec![0, 1, 0, 0],
                vec![0, 0, 2, 0],
                vec![0, 0, 0, 2]
            ]
        );
        assert_eq!(
            generate_count_matrix(&motifs, 4, true),
            vec![
                vec![3, 2, 1, 1],
                vec![1, 2, 1, 1],
                vec![1, 1, 3, 1],
                vec![1, 1, 1, 3]
            ]
        );
    }

    #[test]
    fn score_and_consensus_follow_column_majority() {
        let motifs = vec!["ACGT".to_string(), "AAGT".to_string(), "ACGA".to_string()];
        assert_eq!(scoring_function(&motifs), 2);
        assert_eq!(consensus_string(&motifs, 4).unwrap(), "ACGT");
    }

    #[test]
    fn consensus_wrapper_handles_empty_and_single_inputs() {
        assert!(matches!(
            generate_consensus_string(&[], 4),
            Err(Error::NoMotifsFound)
        ));
        let single = vec!["GATTACA".to_string()];
        assert_eq!(generate_consensus_string(&single, 7).unwrap(), "GATTACA");
    }

    #[test]
    fn profile_columns_sum_to_one_without_pseudocounts() {
        let motifs = vec!["ACGT".to_string(), "AAGT".to_string(), "ACGA".to_string()];
        let profile = generate_profile_given_motif_matrix(&motifs, false).unwrap();
        assert_eq!(profile.len(), 4);
        for column in 0..4 {
            let total: f64 = profile.iter().map(|row| row[column]).sum();
            assert!((total - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn probability_multiplies_profile_entries_and_skips_unknown_bases() {
        let profile = vec![
            vec![0.5, 0.25],
            vec![0.5, 0.25],
            vec![0.0, 0.25],
            vec![0.0, 0.25],
        ];
        assert_eq!(generate_probability("AC", &profile), 0.125);
        assert_eq!(generate_probability("GT", &profile), 0.0);
        assert_eq!(generate_probability("AN", &profile), 0.5);
    }

    #[test]
    fn unique_motifs_removes_duplicates() {
        let motifs = vec!["AAA".to_string(), "CCC".to_string(), "AAA".to_string()];
        let unique = unique_motifs(&motifs);
        assert_eq!(unique.len(), 2);
        assert!(unique.contains("AAA") && unique.contains("CCC"));
    }
}