1use cyanea_core::{CyaneaError, Result};
8
/// Position-specific scoring matrix over an alphabet of `N` symbols.
///
/// Each element of `scores` is one motif position; column `j` of a row holds the
/// natural-log odds of symbol `j` relative to `background[j]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Pssm<const N: usize> {
    /// Per-position natural-log odds scores, one `[f64; N]` row per motif position.
    scores: Vec<[f64; N]>,
    /// Background frequency of each symbol, used as the reference in the log-odds.
    background: [f64; N],
}

/// PSSM over the four-letter DNA alphabet (column order A, C, G, T, matching `dna_mapping`).
pub type PssmDna = Pssm<4>;

/// PSSM over the 20 standard amino acids (column order matching `protein_mapping`).
pub type PssmProtein = Pssm<20>;
25
26impl<const N: usize> Pssm<N> {
27 pub fn from_counts(
37 counts: &[[f64; N]],
38 pseudocount: f64,
39 background: [f64; N],
40 ) -> Result<Self> {
41 if counts.is_empty() {
42 return Err(CyaneaError::InvalidInput(
43 "count matrix must have at least one position".into(),
44 ));
45 }
46 for (i, &bg) in background.iter().enumerate() {
47 if bg <= 0.0 {
48 return Err(CyaneaError::InvalidInput(format!(
49 "background frequency at index {} must be positive, got {}",
50 i, bg
51 )));
52 }
53 }
54
55 let mut scores = Vec::with_capacity(counts.len());
56 for row in counts {
57 let total: f64 = row.iter().sum::<f64>() + pseudocount * N as f64;
58 let mut log_odds = [0.0f64; N];
59 for j in 0..N {
60 let freq = (row[j] + pseudocount) / total;
61 log_odds[j] = (freq / background[j]).ln();
62 }
63 scores.push(log_odds);
64 }
65
66 Ok(Self { scores, background })
67 }
68
69 pub fn len(&self) -> usize {
71 self.scores.len()
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.scores.is_empty()
77 }
78
79 pub fn score(&self, seq: &[u8], mapping: &dyn Fn(u8) -> Option<usize>) -> Result<f64> {
89 if seq.len() != self.len() {
90 return Err(CyaneaError::InvalidInput(format!(
91 "sequence length {} does not match PSSM length {}",
92 seq.len(),
93 self.len()
94 )));
95 }
96 let mut total = 0.0;
97 for (i, &base) in seq.iter().enumerate() {
98 let idx = mapping(base).ok_or_else(|| {
99 CyaneaError::InvalidInput(format!(
100 "unmapped character '{}' at position {}",
101 base as char, i
102 ))
103 })?;
104 total += self.scores[i][idx];
105 }
106 Ok(total)
107 }
108
109 pub fn scan(
114 &self,
115 seq: &[u8],
116 threshold: f64,
117 mapping: &dyn Fn(u8) -> Option<usize>,
118 ) -> Vec<(usize, f64)> {
119 let motif_len = self.len();
120 if seq.len() < motif_len {
121 return Vec::new();
122 }
123 let mut hits = Vec::new();
124 for start in 0..=seq.len() - motif_len {
125 if let Ok(s) = self.score(&seq[start..start + motif_len], mapping) {
126 if s >= threshold {
127 hits.push((start, s));
128 }
129 }
130 }
131 hits
132 }
133
134 pub fn information_content(&self) -> Vec<f64> {
139 self.scores
140 .iter()
141 .map(|row| {
142 let mut ic = 0.0;
143 for j in 0..N {
144 let freq = self.background[j] * row[j].exp();
146 if freq > 0.0 {
147 ic += freq * (freq / self.background[j]).log2();
148 }
149 }
150 ic
151 })
152 .collect()
153 }
154
155 pub fn max_score(&self) -> f64 {
157 self.scores
158 .iter()
159 .map(|row| row.iter().cloned().fold(f64::NEG_INFINITY, f64::max))
160 .sum()
161 }
162
163 pub fn min_score(&self) -> f64 {
165 self.scores
166 .iter()
167 .map(|row| row.iter().cloned().fold(f64::INFINITY, f64::min))
168 .sum()
169 }
170}
171
/// Maps a nucleotide byte to its column in a `PssmDna`: A=0, C=1, G=2, T=3.
///
/// Matching is ASCII case-insensitive. Every other byte (including `N` and `U`) yields `None`.
pub fn dna_mapping(b: u8) -> Option<usize> {
    const ALPHABET: &[u8; 4] = b"ACGT";
    let upper = b.to_ascii_uppercase();
    ALPHABET.iter().position(|&symbol| symbol == upper)
}
182
/// Maps an amino-acid byte to its column in a `PssmProtein`.
///
/// Columns follow the alphabetical one-letter order `ACDEFGHIKLMNPQRSTVWY` (A=0 … Y=19).
/// Matching is ASCII case-insensitive. Non-standard codes such as `B`, `X`, `Z`, and every
/// other byte yield `None`.
pub fn protein_mapping(b: u8) -> Option<usize> {
    const ALPHABET: &[u8; 20] = b"ACDEFGHIKLMNPQRSTVWY";
    let upper = b.to_ascii_uppercase();
    ALPHABET.iter().position(|&residue| residue == upper)
}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal background frequency (0.25) for each of the four nucleotides.
    fn uniform_bg() -> [f64; 4] {
        [0.25; 4]
    }

    /// Builds a DNA PSSM against the uniform background, panicking on invalid input.
    fn dna_pssm(counts: &[[f64; 4]], pseudocount: f64) -> PssmDna {
        PssmDna::from_counts(counts, pseudocount, uniform_bg()).unwrap()
    }

    /// Scores `seq` with the DNA mapping, panicking if scoring fails.
    fn dna_score(pssm: &PssmDna, seq: &[u8]) -> f64 {
        pssm.score(seq, &dna_mapping).unwrap()
    }

    #[test]
    fn uniform_counts_scores_near_zero() {
        let pssm = dna_pssm(&[[10.0; 4]; 3], 0.0);
        assert_eq!(pssm.len(), 3);
        let s = dna_score(&pssm, b"ACG");
        assert!(s.abs() < 1e-10, "expected ~0, got {}", s);
    }

    #[test]
    fn biased_counts_high_score_for_consensus() {
        let pssm = dna_pssm(
            &[
                [100.0, 1.0, 1.0, 1.0],
                [1.0, 100.0, 1.0, 1.0],
                [1.0, 1.0, 100.0, 1.0],
            ],
            0.0,
        );
        let (consensus, mismatch) = (dna_score(&pssm, b"ACG"), dna_score(&pssm, b"TTA"));
        assert!(consensus > mismatch, "consensus {} should beat mismatch {}", consensus, mismatch);
        assert!(consensus > 0.0);
    }

    #[test]
    fn score_known_motif() {
        let pssm = dna_pssm(&[[50.0, 0.0, 0.0, 0.0], [0.0, 50.0, 0.0, 0.0]], 1.0);
        let s = dna_score(&pssm, b"AC");
        assert!(s > 2.0, "expected score > 2.0, got {}", s);
    }

    #[test]
    fn scan_finds_positions() {
        let pssm = dna_pssm(
            &[
                [100.0, 0.0, 0.0, 0.0],
                [0.0, 100.0, 0.0, 0.0],
                [0.0, 0.0, 100.0, 0.0],
            ],
            1.0,
        );
        let positions: Vec<usize> = pssm
            .scan(b"TTACGTTACGTT", 0.0, &dna_mapping)
            .into_iter()
            .map(|(start, _)| start)
            .collect();
        // "ACG" occurs at offsets 2 and 7.
        assert!(positions.contains(&2), "expected hit at 2, got {:?}", positions);
        assert!(positions.contains(&7), "expected hit at 7, got {:?}", positions);
    }

    #[test]
    fn information_content_uniform_is_zero() {
        let ic = dna_pssm(&[[25.0; 4]], 0.0).information_content();
        assert!(ic[0].abs() < 1e-10, "uniform IC should be ~0, got {}", ic[0]);
    }

    #[test]
    fn information_content_conserved_is_two_bits() {
        // A tiny pseudocount keeps the zero-count cells finite without moving IC far from 2.
        let ic = dna_pssm(&[[1000.0, 0.0, 0.0, 0.0]], 0.01).information_content();
        assert!((ic[0] - 2.0).abs() < 0.05, "conserved IC should be ~2 bits, got {}", ic[0]);
    }

    #[test]
    fn error_empty_counts() {
        let empty: [[f64; 4]; 0] = [];
        assert!(PssmDna::from_counts(&empty, 1.0, uniform_bg()).is_err());
    }

    #[test]
    fn error_zero_background() {
        let bg = [0.25, 0.0, 0.25, 0.25];
        assert!(PssmDna::from_counts(&[[10.0; 4]], 1.0, bg).is_err());
    }

    #[test]
    fn error_wrong_seq_length() {
        let pssm = dna_pssm(&[[10.0; 4]; 3], 1.0);
        assert!(pssm.score(b"AC", &dna_mapping).is_err());
    }

    #[test]
    fn error_unmapped_character() {
        let pssm = dna_pssm(&[[10.0; 4]], 1.0);
        assert!(pssm.score(b"X", &dna_mapping).is_err());
    }

    #[test]
    fn min_max_score_bounds() {
        let pssm = dna_pssm(&[[100.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 100.0]], 0.0);
        let (hi, lo) = (pssm.max_score(), pssm.min_score());
        assert!((dna_score(&pssm, b"AT") - hi).abs() < 1e-10);
        assert!((dna_score(&pssm, b"TA") - lo).abs() < 1e-10);
        assert!(hi > lo);
    }

    #[test]
    fn dna_mapping_cases() {
        let cases: [(u8, Option<usize>); 8] = [
            (b'A', Some(0)),
            (b'a', Some(0)),
            (b'C', Some(1)),
            (b'G', Some(2)),
            (b'T', Some(3)),
            (b't', Some(3)),
            (b'N', None),
            (b'X', None),
        ];
        for &(byte, expected) in cases.iter() {
            assert_eq!(dna_mapping(byte), expected);
        }
    }

    #[test]
    fn protein_mapping_cases() {
        let cases: [(u8, Option<usize>); 6] = [
            (b'A', Some(0)),
            (b'Y', Some(19)),
            (b'w', Some(18)),
            (b'K', Some(8)),
            (b'X', None),
            (b'B', None),
        ];
        for &(byte, expected) in cases.iter() {
            assert_eq!(protein_mapping(byte), expected);
        }
    }

    #[test]
    fn scan_short_seq_returns_empty() {
        let pssm = dna_pssm(&[[10.0; 4]; 5], 1.0);
        assert!(pssm.scan(b"ACG", 0.0, &dna_mapping).is_empty());
    }

    #[test]
    fn case_insensitive_scoring() {
        let pssm = dna_pssm(&[[100.0, 0.0, 0.0, 0.0]], 1.0);
        assert!((dna_score(&pssm, b"A") - dna_score(&pssm, b"a")).abs() < 1e-10);
    }
}