Skip to main content

cyanea_seq/
motif.rs

1//! DNA motif discovery — PWM construction, scanning, and EM-based de novo discovery.
2//!
3//! Complements the generic [`Pssm`](crate::pssm::Pssm) with DNA-specific features:
4//! reverse complement scanning, consensus sequences, and MEME-style EM discovery.
5
6use cyanea_core::{CyaneaError, Result};
7
/// Position weight matrix over the DNA alphabet.
///
/// Columns are indexed A = 0, C = 1, G = 2, T = 3.
#[derive(Clone, Debug)]
pub struct Pwm {
    /// Per-position base probabilities: `matrix[pos] == [p_A, p_C, p_G, p_T]`.
    pub matrix: Vec<[f64; 4]>,
    /// Motif length (number of positions).
    pub length: usize,
}
16
17/// A motif match found by scanning.
18#[derive(Debug, Clone)]
19pub struct MotifMatch {
20    /// Position in the sequence where the match starts.
21    pub position: usize,
22    /// Log-odds score of the match.
23    pub score: f64,
24    /// Which strand the match is on.
25    pub strand: Strand,
26}
27
/// Orientation of a DNA strand.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Strand {
    /// The strand as given in the input sequence.
    Forward,
    /// The reverse-complement strand.
    Reverse,
}
34
35/// A motif discovered by EM.
36#[derive(Debug, Clone)]
37pub struct DiscoveredMotif {
38    /// The discovered PWM.
39    pub pwm: Pwm,
40    /// Sites where the motif was found: `(sequence_index, position)`.
41    pub sites: Vec<(usize, usize)>,
42    /// Log-likelihood ratio score of the motif.
43    pub score: f64,
44}
45
/// Map a nucleotide byte to its PWM column (A → 0, C → 1, G → 2, T → 3).
///
/// Matching is case-insensitive; any other byte (`N`, gaps, `U`, ...) maps
/// to `None`.
fn base_index(b: u8) -> Option<usize> {
    const ALPHABET: &[u8; 4] = b"ACGT";
    let upper = b.to_ascii_uppercase();
    ALPHABET.iter().position(|&c| c == upper)
}
55
56impl Pwm {
57    /// Build a PWM from a set of aligned sequences of equal length.
58    ///
59    /// Adds a pseudocount of 0.25 per base to avoid zero probabilities.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if `sequences` is empty or sequences differ in length.
64    pub fn from_aligned(sequences: &[&[u8]]) -> Result<Self> {
65        if sequences.is_empty() {
66            return Err(CyaneaError::InvalidInput(
67                "at least one sequence is required".into(),
68            ));
69        }
70        let len = sequences[0].len();
71        if len == 0 {
72            return Err(CyaneaError::InvalidInput(
73                "sequences must be non-empty".into(),
74            ));
75        }
76        for s in sequences {
77            if s.len() != len {
78                return Err(CyaneaError::InvalidInput(
79                    "all sequences must have the same length".into(),
80                ));
81            }
82        }
83
84        let n = sequences.len() as f64;
85        let pseudocount = 0.25;
86        let total = n + 4.0 * pseudocount;
87
88        let mut matrix = vec![[0.0f64; 4]; len];
89        for pos in 0..len {
90            let mut counts = [pseudocount; 4];
91            for seq in sequences {
92                if let Some(idx) = base_index(seq[pos]) {
93                    counts[idx] += 1.0;
94                }
95            }
96            for j in 0..4 {
97                matrix[pos][j] = counts[j] / total;
98            }
99        }
100
101        Ok(Self {
102            matrix,
103            length: len,
104        })
105    }
106
107    /// Build a PWM from raw base counts (no pseudocount added).
108    pub fn from_counts(counts: &[[usize; 4]]) -> Self {
109        let mut matrix = Vec::with_capacity(counts.len());
110        for row in counts {
111            let total: usize = row.iter().sum();
112            let t = if total > 0 { total as f64 } else { 1.0 };
113            matrix.push([
114                row[0] as f64 / t,
115                row[1] as f64 / t,
116                row[2] as f64 / t,
117                row[3] as f64 / t,
118            ]);
119        }
120        let length = matrix.len();
121        Self { matrix, length }
122    }
123
124    /// Score a sequence window against this PWM using log-odds.
125    ///
126    /// `background` is `[p_A, p_C, p_G, p_T]`. The window must be
127    /// exactly `self.length` bases long.
128    pub fn score_sequence(&self, seq: &[u8], background: &[f64; 4]) -> f64 {
129        let mut score = 0.0;
130        for (pos, &base) in seq.iter().enumerate().take(self.length) {
131            if let Some(idx) = base_index(base) {
132                let p = self.matrix[pos][idx];
133                let bg = background[idx];
134                if p > 0.0 && bg > 0.0 {
135                    score += (p / bg).log2();
136                }
137            }
138        }
139        score
140    }
141
142    /// Scan a sequence for motif matches above a score threshold.
143    ///
144    /// Checks both forward and reverse complement strands.
145    pub fn scan(
146        &self,
147        seq: &[u8],
148        background: &[f64; 4],
149        threshold: f64,
150    ) -> Vec<MotifMatch> {
151        let mut matches = Vec::new();
152        if seq.len() < self.length {
153            return matches;
154        }
155
156        let rc_pwm = self.reverse_complement();
157
158        for i in 0..=seq.len() - self.length {
159            let window = &seq[i..i + self.length];
160
161            // Forward strand.
162            let fwd_score = self.score_sequence(window, background);
163            if fwd_score >= threshold {
164                matches.push(MotifMatch {
165                    position: i,
166                    score: fwd_score,
167                    strand: Strand::Forward,
168                });
169            }
170
171            // Reverse strand.
172            let rev_score = rc_pwm.score_sequence(window, background);
173            if rev_score >= threshold {
174                matches.push(MotifMatch {
175                    position: i,
176                    score: rev_score,
177                    strand: Strand::Reverse,
178                });
179            }
180        }
181        matches
182    }
183
184    /// Information content at each position (in bits).
185    ///
186    /// IC = 2 - H, where H = -Σ p log2(p).
187    pub fn information_content(&self) -> Vec<f64> {
188        self.matrix
189            .iter()
190            .map(|row| {
191                let entropy: f64 = row
192                    .iter()
193                    .filter(|&&p| p > 0.0)
194                    .map(|&p| -p * p.log2())
195                    .sum();
196                2.0 - entropy
197            })
198            .collect()
199    }
200
201    /// Total information content of the motif (sum across positions).
202    pub fn total_information(&self) -> f64 {
203        self.information_content().iter().sum()
204    }
205
206    /// Consensus sequence (most frequent base at each position).
207    pub fn consensus(&self) -> Vec<u8> {
208        let bases = [b'A', b'C', b'G', b'T'];
209        self.matrix
210            .iter()
211            .map(|row| {
212                let max_idx = row
213                    .iter()
214                    .enumerate()
215                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
216                    .unwrap()
217                    .0;
218                bases[max_idx]
219            })
220            .collect()
221    }
222
223    /// Reverse complement of the PWM.
224    ///
225    /// Reverses position order and swaps A↔T, C↔G columns.
226    pub fn reverse_complement(&self) -> Self {
227        let matrix: Vec<[f64; 4]> = self
228            .matrix
229            .iter()
230            .rev()
231            .map(|row| {
232                // Swap: A(0)↔T(3), C(1)↔G(2)
233                [row[3], row[2], row[1], row[0]]
234            })
235            .collect();
236        Self {
237            length: self.length,
238            matrix,
239        }
240    }
241}
242
243/// Discover motifs in a set of sequences using EM (MEME-style).
244///
245/// Searches for `n_motifs` motifs of `motif_length` bases. Each motif is
246/// found by expectation-maximization, then its sites are masked before
247/// searching for the next.
248///
249/// # Errors
250///
251/// Returns an error if sequences are empty, any sequence is shorter than
252/// `motif_length`, or `motif_length` is zero.
253pub fn discover_motifs(
254    sequences: &[&[u8]],
255    motif_length: usize,
256    n_motifs: usize,
257    max_iter: usize,
258) -> Result<Vec<DiscoveredMotif>> {
259    if sequences.is_empty() {
260        return Err(CyaneaError::InvalidInput(
261            "at least one sequence is required".into(),
262        ));
263    }
264    if motif_length == 0 {
265        return Err(CyaneaError::InvalidInput(
266            "motif_length must be at least 1".into(),
267        ));
268    }
269    for (i, seq) in sequences.iter().enumerate() {
270        if seq.len() < motif_length {
271            return Err(CyaneaError::InvalidInput(format!(
272                "sequence {} (length {}) is shorter than motif_length {}",
273                i,
274                seq.len(),
275                motif_length
276            )));
277        }
278    }
279
280    let background = [0.25f64; 4];
281
282    // Work on owned copies so we can mask discovered sites.
283    let mut working_seqs: Vec<Vec<u8>> = sequences.iter().map(|s| s.to_vec()).collect();
284    let mut motifs = Vec::new();
285
286    for _ in 0..n_motifs {
287        let refs: Vec<&[u8]> = working_seqs.iter().map(|s| s.as_slice()).collect();
288        if let Some(motif) = em_one_motif(&refs, motif_length, max_iter, &background) {
289            // Mask discovered sites with N.
290            for &(seq_idx, pos) in &motif.sites {
291                for j in pos..pos + motif_length {
292                    if j < working_seqs[seq_idx].len() {
293                        working_seqs[seq_idx][j] = b'N';
294                    }
295                }
296            }
297            motifs.push(motif);
298        } else {
299            break;
300        }
301    }
302
303    Ok(motifs)
304}
305
306/// Run one round of EM motif discovery.
307fn em_one_motif(
308    sequences: &[&[u8]],
309    motif_length: usize,
310    max_iter: usize,
311    background: &[f64; 4],
312) -> Option<DiscoveredMotif> {
313    // Find the best seed: subsequence that yields highest initial score.
314    let mut best_pwm: Option<Pwm> = None;
315    let mut best_ll = f64::NEG_INFINITY;
316
317    // Try seeds from the first few sequences.
318    let n_seed_seqs = sequences.len().min(3);
319    for si in 0..n_seed_seqs {
320        if sequences[si].len() < motif_length {
321            continue;
322        }
323        // Sample a few starting positions.
324        let step = (sequences[si].len() - motif_length + 1).max(1);
325        let n_seeds = step.min(10);
326        let stride = step / n_seeds;
327        for seed_start_idx in 0..n_seeds {
328            let pos = seed_start_idx * stride;
329            let seed = &sequences[si][pos..pos + motif_length];
330
331            // Skip seeds with N.
332            if seed.iter().any(|&b| base_index(b).is_none()) {
333                continue;
334            }
335
336            // Initialize PWM from this single seed.
337            let mut pwm = Pwm::from_aligned(&[seed]).ok()?;
338
339            // Run EM iterations.
340            for _ in 0..max_iter {
341                // E-step: compute Z (posterior probability of motif start at each position).
342                let mut weighted_counts = vec![[0.25f64; 4]; motif_length]; // pseudocounts
343                let mut total_weight = 4.0 * 0.25 * motif_length as f64;
344
345                for seq in sequences {
346                    if seq.len() < motif_length {
347                        continue;
348                    }
349                    let n_pos = seq.len() - motif_length + 1;
350
351                    // Compute scores for all positions.
352                    let mut scores: Vec<f64> = Vec::with_capacity(n_pos);
353                    for j in 0..n_pos {
354                        let window = &seq[j..j + motif_length];
355                        if window.iter().any(|&b| base_index(b).is_none()) {
356                            scores.push(f64::NEG_INFINITY);
357                        } else {
358                            scores.push(pwm.score_sequence(window, background));
359                        }
360                    }
361
362                    // Convert to probabilities via softmax.
363                    let max_score = scores
364                        .iter()
365                        .copied()
366                        .filter(|s| s.is_finite())
367                        .fold(f64::NEG_INFINITY, f64::max);
368                    if !max_score.is_finite() {
369                        continue;
370                    }
371
372                    let exp_scores: Vec<f64> = scores
373                        .iter()
374                        .map(|&s| if s.is_finite() { (s - max_score).exp() } else { 0.0 })
375                        .collect();
376                    let sum_exp: f64 = exp_scores.iter().sum();
377                    if sum_exp <= 0.0 {
378                        continue;
379                    }
380
381                    // M-step: accumulate weighted counts.
382                    for j in 0..n_pos {
383                        let z = exp_scores[j] / sum_exp;
384                        if z < 1e-10 {
385                            continue;
386                        }
387                        for p in 0..motif_length {
388                            if let Some(idx) = base_index(seq[j + p]) {
389                                weighted_counts[p][idx] += z;
390                                total_weight += z;
391                            }
392                        }
393                    }
394                }
395
396                // Update PWM from weighted counts.
397                let _ = total_weight; // used implicitly in per-position normalization
398                let mut new_matrix = vec![[0.0f64; 4]; motif_length];
399                for p in 0..motif_length {
400                    let row_total: f64 = weighted_counts[p].iter().sum();
401                    if row_total > 0.0 {
402                        for j in 0..4 {
403                            new_matrix[p][j] = weighted_counts[p][j] / row_total;
404                        }
405                    } else {
406                        new_matrix[p] = [0.25; 4];
407                    }
408                }
409                pwm.matrix = new_matrix;
410            }
411
412            // Compute log-likelihood of this PWM.
413            let ll = compute_ll(sequences, &pwm, background, motif_length);
414            if ll > best_ll {
415                best_ll = ll;
416                best_pwm = Some(pwm);
417            }
418        }
419    }
420
421    let pwm = best_pwm?;
422
423    // Find best site in each sequence.
424    let mut sites = Vec::new();
425    let mut total_score = 0.0;
426    for (si, seq) in sequences.iter().enumerate() {
427        if seq.len() < motif_length {
428            continue;
429        }
430        let mut best_pos = 0;
431        let mut best_score = f64::NEG_INFINITY;
432        for j in 0..=seq.len() - motif_length {
433            let window = &seq[j..j + motif_length];
434            if window.iter().any(|&b| base_index(b).is_none()) {
435                continue;
436            }
437            let s = pwm.score_sequence(window, background);
438            if s > best_score {
439                best_score = s;
440                best_pos = j;
441            }
442        }
443        if best_score.is_finite() && best_score > 0.0 {
444            sites.push((si, best_pos));
445            total_score += best_score;
446        }
447    }
448
449    if sites.is_empty() {
450        return None;
451    }
452
453    Some(DiscoveredMotif {
454        pwm,
455        sites,
456        score: total_score,
457    })
458}
459
460fn compute_ll(sequences: &[&[u8]], pwm: &Pwm, background: &[f64; 4], motif_length: usize) -> f64 {
461    let mut ll = 0.0;
462    for seq in sequences {
463        if seq.len() < motif_length {
464            continue;
465        }
466        let mut best = f64::NEG_INFINITY;
467        for j in 0..=seq.len() - motif_length {
468            let window = &seq[j..j + motif_length];
469            if window.iter().any(|&b| base_index(b).is_none()) {
470                continue;
471            }
472            let s = pwm.score_sequence(window, background);
473            if s > best {
474                best = s;
475            }
476        }
477        if best.is_finite() {
478            ll += best;
479        }
480    }
481    ll
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pwm_from_aligned_sequences() {
        let seqs: Vec<&[u8]> = vec![b"ACGT", b"ACGT", b"ACGT"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        assert_eq!(pwm.length, 4);
        // Position 0 should be heavily A.
        assert!(pwm.matrix[0][0] > pwm.matrix[0][1]);
        assert!(pwm.matrix[0][0] > pwm.matrix[0][2]);
        assert!(pwm.matrix[0][0] > pwm.matrix[0][3]);
    }

    #[test]
    fn pwm_from_aligned_rejects_bad_input() {
        let empty: Vec<&[u8]> = Vec::new();
        assert!(Pwm::from_aligned(&empty).is_err());
        let blank: [&[u8]; 1] = [b""];
        assert!(Pwm::from_aligned(&blank).is_err());
        let ragged: [&[u8]; 2] = [b"ACGT", b"ACG"];
        assert!(Pwm::from_aligned(&ragged).is_err());
    }

    #[test]
    fn pwm_from_counts_normalizes_rows() {
        let pwm = Pwm::from_counts(&[[2, 0, 0, 2], [0, 0, 0, 0]]);
        assert_eq!(pwm.length, 2);
        assert_eq!(pwm.matrix[0], [0.5, 0.0, 0.0, 0.5]);
        // An all-zero row stays all-zero instead of dividing by zero.
        assert_eq!(pwm.matrix[1], [0.0; 4]);
    }

    #[test]
    fn pwm_score_perfect_match() {
        let seqs: Vec<&[u8]> = vec![b"ACGT", b"ACGT", b"ACGT"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        let bg = [0.25; 4];
        let score = pwm.score_sequence(b"ACGT", &bg);
        // Perfect match should have positive score.
        assert!(score > 0.0);
    }

    #[test]
    fn pwm_scan_finds_motif() {
        let seqs: Vec<&[u8]> = vec![b"GATTACA", b"GATTACA"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        let bg = [0.25; 4];
        let target = b"AAAGATTACAAAA";
        let matches = pwm.scan(target, &bg, 0.0);
        // Should find the motif on the forward strand.
        let fwd_matches: Vec<_> = matches
            .iter()
            .filter(|m| m.strand == Strand::Forward)
            .collect();
        assert!(!fwd_matches.is_empty());
        // Best forward match should be at position 3.
        let best = fwd_matches.iter().max_by(|a, b| a.score.partial_cmp(&b.score).unwrap()).unwrap();
        assert_eq!(best.position, 3);
    }

    #[test]
    fn pwm_scan_finds_reverse_strand_motif() {
        let seqs: [&[u8]; 2] = [b"GATTACA", b"GATTACA"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        let bg = [0.25; 4];
        // TGTAATC (the reverse complement of GATTACA) starts at offset 3.
        let target = b"AAATGTAATCAAA";
        let matches = pwm.scan(target, &bg, 0.0);
        let best = matches
            .iter()
            .filter(|m| m.strand == Strand::Reverse)
            .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
            .unwrap();
        assert_eq!(best.position, 3);
    }

    #[test]
    fn pwm_information_content() {
        // Uniform distribution: IC = 0 at each position.
        let pwm = Pwm {
            matrix: vec![[0.25, 0.25, 0.25, 0.25]; 3],
            length: 3,
        };
        let ic = pwm.information_content();
        for &v in &ic {
            assert!(v.abs() < 1e-10);
        }

        // Perfect conservation: IC = 2 at each position.
        let pwm2 = Pwm {
            matrix: vec![[1.0, 0.0, 0.0, 0.0]; 3],
            length: 3,
        };
        let ic2 = pwm2.information_content();
        for &v in &ic2 {
            assert!((v - 2.0).abs() < 1e-10);
        }
    }

    #[test]
    fn pwm_reverse_complement() {
        // A motif "ACG" → reverse complement should be "CGT".
        let pwm = Pwm {
            matrix: vec![
                [1.0, 0.0, 0.0, 0.0], // A
                [0.0, 1.0, 0.0, 0.0], // C
                [0.0, 0.0, 1.0, 0.0], // G
            ],
            length: 3,
        };
        let rc = pwm.reverse_complement();
        // rc position 0 is the complement of original position 2 (G → C),
        // position 1 of position 1 (C → G), and position 2 of position 0 (A → T).
        assert!((rc.matrix[0][1] - 1.0).abs() < 1e-10); // C
        assert!((rc.matrix[1][2] - 1.0).abs() < 1e-10); // G
        assert!((rc.matrix[2][3] - 1.0).abs() < 1e-10); // T
    }

    #[test]
    fn pwm_consensus_and_double_reverse_complement() {
        let seqs: [&[u8]; 2] = [b"GATTACA", b"GATTACA"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        assert_eq!(pwm.consensus(), b"GATTACA".to_vec());
        assert_eq!(pwm.reverse_complement().consensus(), b"TGTAATC".to_vec());
        // Reverse-complementing twice is the identity.
        let back = pwm.reverse_complement().reverse_complement();
        assert_eq!(back.matrix, pwm.matrix);
        assert_eq!(back.length, pwm.length);
    }

    #[test]
    fn em_discovers_planted_motif() {
        // Plant a strong motif "ACGTAC" in homopolymer backgrounds.
        let seqs: Vec<&[u8]> = vec![
            b"TTTTACGTACTTTT",
            b"GGGGACGTACGGGG",
            b"AAAACGTACAAAA",
            b"CCCCACGTACCCCC",
        ];
        let motifs = discover_motifs(&seqs, 6, 1, 20).unwrap();
        assert!(!motifs.is_empty());
        let m = &motifs[0];
        // The discovered motif should have consensus close to ACGTAC.
        let consensus = m.pwm.consensus();
        assert_eq!(&consensus, b"ACGTAC");
    }

    #[test]
    fn discover_motifs_rejects_bad_input() {
        let empty: Vec<&[u8]> = Vec::new();
        assert!(discover_motifs(&empty, 4, 1, 10).is_err());
        let seqs: [&[u8]; 2] = [b"ACGTACGT", b"ACG"];
        // A motif_length of zero is rejected.
        assert!(discover_motifs(&seqs, 0, 1, 10).is_err());
        // The second sequence is shorter than the motif.
        assert!(discover_motifs(&seqs, 4, 1, 10).is_err());
    }

    #[test]
    fn discover_zero_motifs_returns_empty() {
        let seqs: [&[u8]; 2] = [b"ACGTACGT", b"TTACGTAA"];
        assert!(discover_motifs(&seqs, 4, 0, 10).unwrap().is_empty());
    }
}