Skip to main content

cyanea_seq/
pssm.rs

//! Position-Specific Scoring Matrices (PSSMs) for motif representation and scanning.
//!
//! Build a PSSM from a count matrix, then score or scan sequences against it.
//! Supports any fixed alphabet size via const generics, with [`PssmDna`] and
//! [`PssmProtein`] type aliases for common use cases.

7use cyanea_core::{CyaneaError, Result};
8
/// A Position-Specific Scoring Matrix (PSSM) over an alphabet of `N` symbols.
///
/// Scores are stored as natural-log odds ratios, `ln(freq / background)`,
/// relative to per-symbol background frequencies. Build instances with
/// [`Pssm::from_counts`].
#[derive(Debug, Clone, PartialEq)]
pub struct Pssm<const N: usize> {
    /// Log-odds scores: `scores[pos][symbol]` for each motif position.
    scores: Vec<[f64; N]>,
    /// Background frequency for each symbol, indexed like a score row.
    background: [f64; N],
}

/// PSSM for DNA sequences (A=0, C=1, G=2, T=3; see [`dna_mapping`]).
pub type PssmDna = Pssm<4>;

/// PSSM for protein sequences (20 standard amino acids; see [`protein_mapping`]).
pub type PssmProtein = Pssm<20>;

26impl<const N: usize> Pssm<N> {
27    /// Build a PSSM from a count matrix.
28    ///
29    /// Each row of `counts` represents one position; each column is a symbol.
30    /// A `pseudocount` is added to every cell before converting to frequencies.
31    /// Scores are computed as `ln(freq / background)`.
32    ///
33    /// # Errors
34    ///
35    /// Returns an error if `counts` is empty or any `background` entry is zero.
36    pub fn from_counts(
37        counts: &[[f64; N]],
38        pseudocount: f64,
39        background: [f64; N],
40    ) -> Result<Self> {
41        if counts.is_empty() {
42            return Err(CyaneaError::InvalidInput(
43                "count matrix must have at least one position".into(),
44            ));
45        }
46        for (i, &bg) in background.iter().enumerate() {
47            if bg <= 0.0 {
48                return Err(CyaneaError::InvalidInput(format!(
49                    "background frequency at index {} must be positive, got {}",
50                    i, bg
51                )));
52            }
53        }
54
55        let mut scores = Vec::with_capacity(counts.len());
56        for row in counts {
57            let total: f64 = row.iter().sum::<f64>() + pseudocount * N as f64;
58            let mut log_odds = [0.0f64; N];
59            for j in 0..N {
60                let freq = (row[j] + pseudocount) / total;
61                log_odds[j] = (freq / background[j]).ln();
62            }
63            scores.push(log_odds);
64        }
65
66        Ok(Self { scores, background })
67    }
68
69    /// Number of positions in the motif.
70    pub fn len(&self) -> usize {
71        self.scores.len()
72    }
73
74    /// Returns `true` if the PSSM has zero positions.
75    pub fn is_empty(&self) -> bool {
76        self.scores.is_empty()
77    }
78
79    /// Score a sequence window against this PSSM.
80    ///
81    /// `seq` must be exactly [`self.len()`] bytes. The `mapping` function
82    /// converts each byte to an alphabet index in `0..N`.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if `seq.len() != self.len()` or if `mapping` returns
87    /// `None` for any byte.
88    pub fn score(&self, seq: &[u8], mapping: &dyn Fn(u8) -> Option<usize>) -> Result<f64> {
89        if seq.len() != self.len() {
90            return Err(CyaneaError::InvalidInput(format!(
91                "sequence length {} does not match PSSM length {}",
92                seq.len(),
93                self.len()
94            )));
95        }
96        let mut total = 0.0;
97        for (i, &base) in seq.iter().enumerate() {
98            let idx = mapping(base).ok_or_else(|| {
99                CyaneaError::InvalidInput(format!(
100                    "unmapped character '{}' at position {}",
101                    base as char, i
102                ))
103            })?;
104            total += self.scores[i][idx];
105        }
106        Ok(total)
107    }
108
109    /// Slide the PSSM across `seq` and return all hits at or above `threshold`.
110    ///
111    /// Returns `(position, score)` pairs. Positions with unmapped characters
112    /// are silently skipped.
113    pub fn scan(
114        &self,
115        seq: &[u8],
116        threshold: f64,
117        mapping: &dyn Fn(u8) -> Option<usize>,
118    ) -> Vec<(usize, f64)> {
119        let motif_len = self.len();
120        if seq.len() < motif_len {
121            return Vec::new();
122        }
123        let mut hits = Vec::new();
124        for start in 0..=seq.len() - motif_len {
125            if let Ok(s) = self.score(&seq[start..start + motif_len], mapping) {
126                if s >= threshold {
127                    hits.push((start, s));
128                }
129            }
130        }
131        hits
132    }
133
134    /// Information content (bits) at each position.
135    ///
136    /// `IC_j = sum_c freq_c * log2(freq_c / bg_c)` where frequencies are
137    /// recovered from the stored log-odds scores.
138    pub fn information_content(&self) -> Vec<f64> {
139        self.scores
140            .iter()
141            .map(|row| {
142                let mut ic = 0.0;
143                for j in 0..N {
144                    // score = ln(freq / bg), so freq = bg * exp(score)
145                    let freq = self.background[j] * row[j].exp();
146                    if freq > 0.0 {
147                        ic += freq * (freq / self.background[j]).log2();
148                    }
149                }
150                ic
151            })
152            .collect()
153    }
154
155    /// Maximum possible score (sum of the best symbol at each position).
156    pub fn max_score(&self) -> f64 {
157        self.scores
158            .iter()
159            .map(|row| row.iter().cloned().fold(f64::NEG_INFINITY, f64::max))
160            .sum()
161    }
162
163    /// Minimum possible score (sum of the worst symbol at each position).
164    pub fn min_score(&self) -> f64 {
165        self.scores
166            .iter()
167            .map(|row| row.iter().cloned().fold(f64::INFINITY, f64::min))
168            .sum()
169    }
170}
171
/// Map a DNA base to index (A=0, C=1, G=2, T=3).
///
/// Matching ignores ASCII case; every other byte (including `N`) yields `None`.
pub fn dna_mapping(b: u8) -> Option<usize> {
    // Index in this table is the alphabet index.
    const BASES: &[u8; 4] = b"ACGT";
    let upper = b.to_ascii_uppercase();
    BASES.iter().position(|&base| base == upper)
}

/// Map an amino acid to index (0--19, alphabetical order:
/// A, C, D, E, F, G, H, I, K, L, M, N, P, Q, R, S, T, V, W, Y).
///
/// Matching ignores ASCII case; ambiguity codes (`B`, `Z`, `X`), stops and
/// any other byte yield `None`.
pub fn protein_mapping(b: u8) -> Option<usize> {
    // Index in this table is the alphabet index.
    const AMINO_ACIDS: &[u8; 20] = b"ACDEFGHIKLMNPQRSTVWY";
    let upper = b.to_ascii_uppercase();
    AMINO_ACIDS.iter().position(|&aa| aa == upper)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform DNA background used by most tests.
    fn uniform_bg() -> [f64; 4] {
        [0.25; 4]
    }

    #[test]
    fn uniform_counts_scores_near_zero() {
        let counts = vec![[10.0, 10.0, 10.0, 10.0]; 3];
        let pssm = PssmDna::from_counts(&counts, 0.0, uniform_bg()).unwrap();
        assert_eq!(pssm.len(), 3);
        let s = pssm.score(b"ACG", &dna_mapping).unwrap();
        assert!(s.abs() < 1e-10, "expected ~0, got {}", s);
    }

    #[test]
    fn biased_counts_high_score_for_consensus() {
        // Position 0 strongly favors A, position 1 favors C, position 2 favors G
        let counts = vec![
            [100.0, 1.0, 1.0, 1.0],
            [1.0, 100.0, 1.0, 1.0],
            [1.0, 1.0, 100.0, 1.0],
        ];
        let pssm = PssmDna::from_counts(&counts, 0.0, uniform_bg()).unwrap();
        let consensus = pssm.score(b"ACG", &dna_mapping).unwrap();
        let mismatch = pssm.score(b"TTA", &dna_mapping).unwrap();
        assert!(consensus > mismatch, "consensus {} should beat mismatch {}", consensus, mismatch);
        assert!(consensus > 0.0);
    }

    #[test]
    fn score_known_motif() {
        let counts = vec![
            [50.0, 0.0, 0.0, 0.0],
            [0.0, 50.0, 0.0, 0.0],
        ];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let s = pssm.score(b"AC", &dna_mapping).unwrap();
        // freq = 51/54 ≈ 0.944, ln(0.944/0.25) ≈ 1.329 per position
        assert!(s > 2.0, "expected score > 2.0, got {}", s);
    }

    #[test]
    fn scan_finds_positions() {
        let counts = vec![
            [100.0, 0.0, 0.0, 0.0],
            [0.0, 100.0, 0.0, 0.0],
            [0.0, 0.0, 100.0, 0.0],
        ];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let seq = b"TTACGTTACGTT";
        let hits = pssm.scan(seq, 0.0, &dna_mapping);
        // "ACG" appears at positions 2 and 7
        let positions: Vec<usize> = hits.iter().map(|&(p, _)| p).collect();
        assert!(positions.contains(&2), "expected hit at 2, got {:?}", positions);
        assert!(positions.contains(&7), "expected hit at 7, got {:?}", positions);
    }

    #[test]
    fn scan_hits_match_window_scores() {
        let counts = vec![[10.0, 1.0, 1.0, 1.0], [1.0, 10.0, 1.0, 1.0]];
        let pssm = PssmDna::from_counts(&counts, 0.5, uniform_bg()).unwrap();
        let seq = b"ACGTAC";
        let hits = pssm.scan(seq, f64::NEG_INFINITY, &dna_mapping);
        // Every window is fully mapped, so every start position is a hit.
        assert_eq!(hits.len(), seq.len() - 1);
        for (pos, s) in hits {
            let expected = pssm.score(&seq[pos..pos + 2], &dna_mapping).unwrap();
            assert_eq!(s, expected, "scan and score disagree at {}", pos);
        }
    }

    #[test]
    fn scan_skips_windows_with_unmapped_characters() {
        let counts = vec![[10.0; 4]; 2];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let hits = pssm.scan(b"ACNAC", f64::NEG_INFINITY, &dna_mapping);
        let positions: Vec<usize> = hits.iter().map(|&(p, _)| p).collect();
        // Windows 1 ("CN") and 2 ("NA") contain the unmapped 'N'.
        assert_eq!(positions, vec![0, 3]);
    }

    #[test]
    fn information_content_uniform_is_zero() {
        let counts = vec![[25.0, 25.0, 25.0, 25.0]];
        let pssm = PssmDna::from_counts(&counts, 0.0, uniform_bg()).unwrap();
        let ic = pssm.information_content();
        assert!(ic[0].abs() < 1e-10, "uniform IC should be ~0, got {}", ic[0]);
    }

    #[test]
    fn information_content_conserved_is_two_bits() {
        // Perfectly conserved A (with tiny pseudocount to avoid log(0))
        let counts = vec![[1000.0, 0.0, 0.0, 0.0]];
        let pssm = PssmDna::from_counts(&counts, 0.01, uniform_bg()).unwrap();
        let ic = pssm.information_content();
        assert!((ic[0] - 2.0).abs() < 0.05, "conserved IC should be ~2 bits, got {}", ic[0]);
    }

    #[test]
    fn error_empty_counts() {
        let counts: Vec<[f64; 4]> = vec![];
        let result = PssmDna::from_counts(&counts, 1.0, uniform_bg());
        assert!(result.is_err());
    }

    #[test]
    fn error_zero_background() {
        let counts = vec![[10.0; 4]];
        let result = PssmDna::from_counts(&counts, 1.0, [0.25, 0.0, 0.25, 0.25]);
        assert!(result.is_err());
    }

    #[test]
    fn error_wrong_seq_length() {
        let counts = vec![[10.0; 4]; 3];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let result = pssm.score(b"AC", &dna_mapping);
        assert!(result.is_err());
    }

    #[test]
    fn error_unmapped_character() {
        let counts = vec![[10.0; 4]];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let result = pssm.score(b"X", &dna_mapping);
        assert!(result.is_err());
    }

    #[test]
    fn min_max_score_bounds() {
        let counts = vec![
            [100.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 100.0],
        ];
        let pssm = PssmDna::from_counts(&counts, 0.0, uniform_bg()).unwrap();
        let best = pssm.score(b"AT", &dna_mapping).unwrap();
        let worst = pssm.score(b"TA", &dna_mapping).unwrap();
        assert!((best - pssm.max_score()).abs() < 1e-10);
        assert!((worst - pssm.min_score()).abs() < 1e-10);
        assert!(pssm.max_score() > pssm.min_score());
    }

    #[test]
    fn protein_pssm_scores_consensus() {
        let mut row = [1.0; 20];
        row[protein_mapping(b'W').unwrap()] = 50.0;
        let pssm = PssmProtein::from_counts(&[row], 1.0, [0.05; 20]).unwrap();
        assert_eq!(pssm.len(), 1);
        assert!(!pssm.is_empty());
        let w = pssm.score(b"W", &protein_mapping).unwrap();
        let a = pssm.score(b"a", &protein_mapping).unwrap();
        assert!(w > a, "W ({}) should outscore A ({})", w, a);
    }

    #[test]
    fn dna_mapping_cases() {
        assert_eq!(dna_mapping(b'A'), Some(0));
        assert_eq!(dna_mapping(b'a'), Some(0));
        assert_eq!(dna_mapping(b'C'), Some(1));
        assert_eq!(dna_mapping(b'G'), Some(2));
        assert_eq!(dna_mapping(b'T'), Some(3));
        assert_eq!(dna_mapping(b't'), Some(3));
        assert_eq!(dna_mapping(b'N'), None);
        assert_eq!(dna_mapping(b'X'), None);
    }

    #[test]
    fn protein_mapping_cases() {
        assert_eq!(protein_mapping(b'A'), Some(0));
        assert_eq!(protein_mapping(b'Y'), Some(19));
        assert_eq!(protein_mapping(b'w'), Some(18));
        assert_eq!(protein_mapping(b'K'), Some(8));
        assert_eq!(protein_mapping(b'X'), None);
        assert_eq!(protein_mapping(b'B'), None);
    }

    #[test]
    fn scan_short_seq_returns_empty() {
        let counts = vec![[10.0; 4]; 5];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let hits = pssm.scan(b"ACG", 0.0, &dna_mapping);
        assert!(hits.is_empty());
    }

    #[test]
    fn case_insensitive_scoring() {
        let counts = vec![[100.0, 0.0, 0.0, 0.0]];
        let pssm = PssmDna::from_counts(&counts, 1.0, uniform_bg()).unwrap();
        let upper = pssm.score(b"A", &dna_mapping).unwrap();
        let lower = pssm.score(b"a", &dna_mapping).unwrap();
        assert!((upper - lower).abs() < 1e-10);
    }
}