1use cyanea_core::{CyaneaError, Result};
7
/// Position weight matrix (PWM) over the DNA alphabet.
///
/// Each row of `matrix` describes one motif position and holds four values in
/// the column order `A`, `C`, `G`, `T` (the order used by `base_index`). The
/// constructors in this file fill the rows with probabilities.
#[derive(Debug, Clone, PartialEq)]
pub struct Pwm {
    /// One `[A, C, G, T]` row per motif position.
    pub matrix: Vec<[f64; 4]>,
    /// Number of motif positions; the constructors set it to `matrix.len()`.
    pub length: usize,
}
16
17#[derive(Debug, Clone)]
19pub struct MotifMatch {
20 pub position: usize,
22 pub score: f64,
24 pub strand: Strand,
26}
27
/// Orientation of a motif match relative to the scanned sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Strand {
    /// The PWM matched the sequence as given.
    Forward,
    /// The reverse-complemented PWM matched the sequence.
    Reverse,
}
34
35#[derive(Debug, Clone)]
37pub struct DiscoveredMotif {
38 pub pwm: Pwm,
40 pub sites: Vec<(usize, usize)>,
42 pub score: f64,
44}
45
/// Maps a nucleotide byte to its PWM column index.
///
/// `A`, `C`, `G`, `T` (either case) map to `0..=3` in that order. Any other
/// byte, including ambiguity codes such as `N`, yields `None`.
fn base_index(b: u8) -> Option<usize> {
    match b {
        b'A' | b'a' => Some(0),
        b'C' | b'c' => Some(1),
        b'G' | b'g' => Some(2),
        b'T' | b't' => Some(3),
        _ => None,
    }
}
55
56impl Pwm {
57 pub fn from_aligned(sequences: &[&[u8]]) -> Result<Self> {
65 if sequences.is_empty() {
66 return Err(CyaneaError::InvalidInput(
67 "at least one sequence is required".into(),
68 ));
69 }
70 let len = sequences[0].len();
71 if len == 0 {
72 return Err(CyaneaError::InvalidInput(
73 "sequences must be non-empty".into(),
74 ));
75 }
76 for s in sequences {
77 if s.len() != len {
78 return Err(CyaneaError::InvalidInput(
79 "all sequences must have the same length".into(),
80 ));
81 }
82 }
83
84 let n = sequences.len() as f64;
85 let pseudocount = 0.25;
86 let total = n + 4.0 * pseudocount;
87
88 let mut matrix = vec![[0.0f64; 4]; len];
89 for pos in 0..len {
90 let mut counts = [pseudocount; 4];
91 for seq in sequences {
92 if let Some(idx) = base_index(seq[pos]) {
93 counts[idx] += 1.0;
94 }
95 }
96 for j in 0..4 {
97 matrix[pos][j] = counts[j] / total;
98 }
99 }
100
101 Ok(Self {
102 matrix,
103 length: len,
104 })
105 }
106
107 pub fn from_counts(counts: &[[usize; 4]]) -> Self {
109 let mut matrix = Vec::with_capacity(counts.len());
110 for row in counts {
111 let total: usize = row.iter().sum();
112 let t = if total > 0 { total as f64 } else { 1.0 };
113 matrix.push([
114 row[0] as f64 / t,
115 row[1] as f64 / t,
116 row[2] as f64 / t,
117 row[3] as f64 / t,
118 ]);
119 }
120 let length = matrix.len();
121 Self { matrix, length }
122 }
123
124 pub fn score_sequence(&self, seq: &[u8], background: &[f64; 4]) -> f64 {
129 let mut score = 0.0;
130 for (pos, &base) in seq.iter().enumerate().take(self.length) {
131 if let Some(idx) = base_index(base) {
132 let p = self.matrix[pos][idx];
133 let bg = background[idx];
134 if p > 0.0 && bg > 0.0 {
135 score += (p / bg).log2();
136 }
137 }
138 }
139 score
140 }
141
142 pub fn scan(
146 &self,
147 seq: &[u8],
148 background: &[f64; 4],
149 threshold: f64,
150 ) -> Vec<MotifMatch> {
151 let mut matches = Vec::new();
152 if seq.len() < self.length {
153 return matches;
154 }
155
156 let rc_pwm = self.reverse_complement();
157
158 for i in 0..=seq.len() - self.length {
159 let window = &seq[i..i + self.length];
160
161 let fwd_score = self.score_sequence(window, background);
163 if fwd_score >= threshold {
164 matches.push(MotifMatch {
165 position: i,
166 score: fwd_score,
167 strand: Strand::Forward,
168 });
169 }
170
171 let rev_score = rc_pwm.score_sequence(window, background);
173 if rev_score >= threshold {
174 matches.push(MotifMatch {
175 position: i,
176 score: rev_score,
177 strand: Strand::Reverse,
178 });
179 }
180 }
181 matches
182 }
183
184 pub fn information_content(&self) -> Vec<f64> {
188 self.matrix
189 .iter()
190 .map(|row| {
191 let entropy: f64 = row
192 .iter()
193 .filter(|&&p| p > 0.0)
194 .map(|&p| -p * p.log2())
195 .sum();
196 2.0 - entropy
197 })
198 .collect()
199 }
200
201 pub fn total_information(&self) -> f64 {
203 self.information_content().iter().sum()
204 }
205
206 pub fn consensus(&self) -> Vec<u8> {
208 let bases = [b'A', b'C', b'G', b'T'];
209 self.matrix
210 .iter()
211 .map(|row| {
212 let max_idx = row
213 .iter()
214 .enumerate()
215 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
216 .unwrap()
217 .0;
218 bases[max_idx]
219 })
220 .collect()
221 }
222
223 pub fn reverse_complement(&self) -> Self {
227 let matrix: Vec<[f64; 4]> = self
228 .matrix
229 .iter()
230 .rev()
231 .map(|row| {
232 [row[3], row[2], row[1], row[0]]
234 })
235 .collect();
236 Self {
237 length: self.length,
238 matrix,
239 }
240 }
241}
242
243pub fn discover_motifs(
254 sequences: &[&[u8]],
255 motif_length: usize,
256 n_motifs: usize,
257 max_iter: usize,
258) -> Result<Vec<DiscoveredMotif>> {
259 if sequences.is_empty() {
260 return Err(CyaneaError::InvalidInput(
261 "at least one sequence is required".into(),
262 ));
263 }
264 if motif_length == 0 {
265 return Err(CyaneaError::InvalidInput(
266 "motif_length must be at least 1".into(),
267 ));
268 }
269 for (i, seq) in sequences.iter().enumerate() {
270 if seq.len() < motif_length {
271 return Err(CyaneaError::InvalidInput(format!(
272 "sequence {} (length {}) is shorter than motif_length {}",
273 i,
274 seq.len(),
275 motif_length
276 )));
277 }
278 }
279
280 let background = [0.25f64; 4];
281
282 let mut working_seqs: Vec<Vec<u8>> = sequences.iter().map(|s| s.to_vec()).collect();
284 let mut motifs = Vec::new();
285
286 for _ in 0..n_motifs {
287 let refs: Vec<&[u8]> = working_seqs.iter().map(|s| s.as_slice()).collect();
288 if let Some(motif) = em_one_motif(&refs, motif_length, max_iter, &background) {
289 for &(seq_idx, pos) in &motif.sites {
291 for j in pos..pos + motif_length {
292 if j < working_seqs[seq_idx].len() {
293 working_seqs[seq_idx][j] = b'N';
294 }
295 }
296 }
297 motifs.push(motif);
298 } else {
299 break;
300 }
301 }
302
303 Ok(motifs)
304}
305
306fn em_one_motif(
308 sequences: &[&[u8]],
309 motif_length: usize,
310 max_iter: usize,
311 background: &[f64; 4],
312) -> Option<DiscoveredMotif> {
313 let mut best_pwm: Option<Pwm> = None;
315 let mut best_ll = f64::NEG_INFINITY;
316
317 let n_seed_seqs = sequences.len().min(3);
319 for si in 0..n_seed_seqs {
320 if sequences[si].len() < motif_length {
321 continue;
322 }
323 let step = (sequences[si].len() - motif_length + 1).max(1);
325 let n_seeds = step.min(10);
326 let stride = step / n_seeds;
327 for seed_start_idx in 0..n_seeds {
328 let pos = seed_start_idx * stride;
329 let seed = &sequences[si][pos..pos + motif_length];
330
331 if seed.iter().any(|&b| base_index(b).is_none()) {
333 continue;
334 }
335
336 let mut pwm = Pwm::from_aligned(&[seed]).ok()?;
338
339 for _ in 0..max_iter {
341 let mut weighted_counts = vec![[0.25f64; 4]; motif_length]; let mut total_weight = 4.0 * 0.25 * motif_length as f64;
344
345 for seq in sequences {
346 if seq.len() < motif_length {
347 continue;
348 }
349 let n_pos = seq.len() - motif_length + 1;
350
351 let mut scores: Vec<f64> = Vec::with_capacity(n_pos);
353 for j in 0..n_pos {
354 let window = &seq[j..j + motif_length];
355 if window.iter().any(|&b| base_index(b).is_none()) {
356 scores.push(f64::NEG_INFINITY);
357 } else {
358 scores.push(pwm.score_sequence(window, background));
359 }
360 }
361
362 let max_score = scores
364 .iter()
365 .copied()
366 .filter(|s| s.is_finite())
367 .fold(f64::NEG_INFINITY, f64::max);
368 if !max_score.is_finite() {
369 continue;
370 }
371
372 let exp_scores: Vec<f64> = scores
373 .iter()
374 .map(|&s| if s.is_finite() { (s - max_score).exp() } else { 0.0 })
375 .collect();
376 let sum_exp: f64 = exp_scores.iter().sum();
377 if sum_exp <= 0.0 {
378 continue;
379 }
380
381 for j in 0..n_pos {
383 let z = exp_scores[j] / sum_exp;
384 if z < 1e-10 {
385 continue;
386 }
387 for p in 0..motif_length {
388 if let Some(idx) = base_index(seq[j + p]) {
389 weighted_counts[p][idx] += z;
390 total_weight += z;
391 }
392 }
393 }
394 }
395
396 let _ = total_weight; let mut new_matrix = vec![[0.0f64; 4]; motif_length];
399 for p in 0..motif_length {
400 let row_total: f64 = weighted_counts[p].iter().sum();
401 if row_total > 0.0 {
402 for j in 0..4 {
403 new_matrix[p][j] = weighted_counts[p][j] / row_total;
404 }
405 } else {
406 new_matrix[p] = [0.25; 4];
407 }
408 }
409 pwm.matrix = new_matrix;
410 }
411
412 let ll = compute_ll(sequences, &pwm, background, motif_length);
414 if ll > best_ll {
415 best_ll = ll;
416 best_pwm = Some(pwm);
417 }
418 }
419 }
420
421 let pwm = best_pwm?;
422
423 let mut sites = Vec::new();
425 let mut total_score = 0.0;
426 for (si, seq) in sequences.iter().enumerate() {
427 if seq.len() < motif_length {
428 continue;
429 }
430 let mut best_pos = 0;
431 let mut best_score = f64::NEG_INFINITY;
432 for j in 0..=seq.len() - motif_length {
433 let window = &seq[j..j + motif_length];
434 if window.iter().any(|&b| base_index(b).is_none()) {
435 continue;
436 }
437 let s = pwm.score_sequence(window, background);
438 if s > best_score {
439 best_score = s;
440 best_pos = j;
441 }
442 }
443 if best_score.is_finite() && best_score > 0.0 {
444 sites.push((si, best_pos));
445 total_score += best_score;
446 }
447 }
448
449 if sites.is_empty() {
450 return None;
451 }
452
453 Some(DiscoveredMotif {
454 pwm,
455 sites,
456 score: total_score,
457 })
458}
459
460fn compute_ll(sequences: &[&[u8]], pwm: &Pwm, background: &[f64; 4], motif_length: usize) -> f64 {
461 let mut ll = 0.0;
462 for seq in sequences {
463 if seq.len() < motif_length {
464 continue;
465 }
466 let mut best = f64::NEG_INFINITY;
467 for j in 0..=seq.len() - motif_length {
468 let window = &seq[j..j + motif_length];
469 if window.iter().any(|&b| base_index(b).is_none()) {
470 continue;
471 }
472 let s = pwm.score_sequence(window, background);
473 if s > best {
474 best = s;
475 }
476 }
477 if best.is_finite() {
478 ll += best;
479 }
480 }
481 ll
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// A column where every sequence has `A` must favour `A`.
    #[test]
    fn pwm_from_aligned_sequences() {
        let seqs: Vec<&[u8]> = vec![b"ACGT", b"ACGT", b"ACGT"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        assert_eq!(pwm.length, 4);
        assert!(pwm.matrix[0][0] > pwm.matrix[0][1]);
        assert!(pwm.matrix[0][0] > pwm.matrix[0][2]);
        assert!(pwm.matrix[0][0] > pwm.matrix[0][3]);
    }

    /// The training sequence itself scores above the background.
    #[test]
    fn pwm_score_perfect_match() {
        let seqs: Vec<&[u8]> = vec![b"ACGT", b"ACGT", b"ACGT"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        let bg = [0.25; 4];
        let score = pwm.score_sequence(b"ACGT", &bg);
        assert!(score > 0.0);
    }

    /// The best forward hit is at the planted motif's start.
    #[test]
    fn pwm_scan_finds_motif() {
        let seqs: Vec<&[u8]> = vec![b"GATTACA", b"GATTACA"];
        let pwm = Pwm::from_aligned(&seqs).unwrap();
        let bg = [0.25; 4];
        let target = b"AAAGATTACAAAA";
        let matches = pwm.scan(target, &bg, 0.0);
        let fwd_matches: Vec<_> = matches
            .iter()
            .filter(|m| m.strand == Strand::Forward)
            .collect();
        assert!(!fwd_matches.is_empty());
        let best = fwd_matches
            .iter()
            .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
            .unwrap();
        assert_eq!(best.position, 3);
    }

    /// Uniform rows carry 0 bits; fully determined rows carry 2 bits.
    #[test]
    fn pwm_information_content() {
        let pwm = Pwm {
            matrix: vec![[0.25, 0.25, 0.25, 0.25]; 3],
            length: 3,
        };
        let ic = pwm.information_content();
        for &v in &ic {
            assert!(v.abs() < 1e-10);
        }

        let pwm2 = Pwm {
            matrix: vec![[1.0, 0.0, 0.0, 0.0]; 3],
            length: 3,
        };
        let ic2 = pwm2.information_content();
        for &v in &ic2 {
            assert!((v - 2.0).abs() < 1e-10);
        }
    }

    /// The reverse complement of `ACG` is `CGT`.
    #[test]
    fn pwm_reverse_complement() {
        let pwm = Pwm {
            matrix: vec![
                [1.0, 0.0, 0.0, 0.0], // A
                [0.0, 1.0, 0.0, 0.0], // C
                [0.0, 0.0, 1.0, 0.0], // G
            ],
            length: 3,
        };
        let rc = pwm.reverse_complement();
        assert!((rc.matrix[0][1] - 1.0).abs() < 1e-10); // C
        assert!((rc.matrix[1][2] - 1.0).abs() < 1e-10); // G
        assert!((rc.matrix[2][3] - 1.0).abs() < 1e-10); // T
    }

    /// EM recovers a motif planted in every sequence.
    #[test]
    fn em_discovers_planted_motif() {
        let seqs: Vec<&[u8]> = vec![
            b"TTTTACGTACTTTT",
            b"GGGGACGTACGGGG",
            b"AAAACGTACAAAA",
            b"CCCCACGTACCCCC",
        ];
        let motifs = discover_motifs(&seqs, 6, 1, 20).unwrap();
        assert!(!motifs.is_empty());
        let m = &motifs[0];
        let consensus = m.pwm.consensus();
        assert_eq!(&consensus, b"ACGTAC");
    }
}