1use cyanea_core::{CyaneaError, Result};
7
/// How masked positions are rewritten in the output sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaskMode {
    /// Masked residues are converted to ASCII lowercase.
    Soft,
    /// Masked residues are overwritten with a placeholder byte
    /// (`X` when masking protein, `N` otherwise).
    Hard,
}
20
/// Identifies which detector reported a masked region.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaskSource {
    /// Reported by the DUST triplet-redundancy scan.
    Dust,
    /// Reported by the SEG entropy scan.
    Seg,
    /// Reported by the tandem-repeat finder.
    TandemRepeat,
}
28
29#[derive(Debug, Clone)]
31pub struct MaskedRegion {
32 pub start: usize,
34 pub end: usize,
36 pub score: f64,
38 pub source: MaskSource,
40}
41
42#[derive(Debug, Clone)]
44pub struct MaskResult {
45 pub sequence: Vec<u8>,
47 pub regions: Vec<MaskedRegion>,
49 pub masked_fraction: f64,
51}
52
/// Tuning parameters for the DUST scan.
#[derive(Clone, Debug)]
pub struct DustParams {
    /// Sliding-window length in bases; clamped to the sequence length when scanning.
    pub window: usize,
    /// A window is flagged when its score is strictly greater than this value.
    pub threshold: f64,
    /// Flagged windows separated by at most this many positions are merged.
    pub linker: usize,
}

impl Default for DustParams {
    /// Window of 64, threshold of 20.0, linker of 1.
    fn default() -> Self {
        let (window, threshold, linker) = (64, 20.0, 1);
        DustParams {
            window,
            threshold,
            linker,
        }
    }
}
77
/// Tuning parameters for the SEG scan.
#[derive(Clone, Debug)]
pub struct SegParams {
    /// Trigger-window length in residues; clamped to the sequence length when scanning.
    pub window: usize,
    /// A window triggers when its Shannon entropy (bits) is at most this value.
    pub lowcut: f64,
    /// A triggered segment keeps extending while its entropy (bits) stays at
    /// most this value.
    pub highcut: f64,
}

impl Default for SegParams {
    /// Window of 12, low cut of 2.2 bits, high cut of 2.5 bits.
    fn default() -> Self {
        let (window, lowcut, highcut) = (12, 2.2, 2.5);
        SegParams {
            window,
            lowcut,
            highcut,
        }
    }
}
102
/// Tuning parameters for the tandem-repeat finder.
#[derive(Clone, Debug)]
pub struct TandemRepeatParams {
    /// Shortest repeat unit to look for; values below 1 are treated as 1.
    pub min_period: usize,
    /// Longest repeat unit to look for; clamped to the sequence length.
    pub max_period: usize,
    /// Minimum number of whole unit copies a run must contain to be reported.
    pub min_copies: usize,
}

impl Default for TandemRepeatParams {
    /// Periods 1 through 6, at least 3 copies.
    fn default() -> Self {
        let (min_period, max_period, min_copies) = (1, 6, 3);
        TandemRepeatParams {
            min_period,
            max_period,
            min_copies,
        }
    }
}
127
128fn dust_score(window: &[u8]) -> f64 {
136 if window.len() < 3 {
137 return 0.0;
138 }
139
140 let mut counts = [0u32; 64];
141 for tri in window.windows(3) {
142 let idx = triplet_index(tri);
143 if let Some(i) = idx {
144 counts[i] += 1;
145 }
146 }
147
148 let mut score = 0.0f64;
149 for &c in &counts {
150 if c > 1 {
151 score += (c as f64) * (c as f64 - 1.0) / 2.0;
152 }
153 }
154
155 let denom = (window.len() as f64) - 2.0;
156 if denom > 0.0 {
157 score / denom
158 } else {
159 0.0
160 }
161}
162
/// Encodes the first three bases of `tri` as a base-4 number in `0..64`.
///
/// Bases map case-insensitively as `A = 0`, `C = 1`, `G = 2`, `T`/`U` = 3,
/// with the first base as the most significant digit.
///
/// Returns `None` when any of the three bytes is not one of those bases, or
/// when `tri` holds fewer than three bytes (instead of panicking on an
/// out-of-bounds index).
fn triplet_index(tri: &[u8]) -> Option<usize> {
    tri.get(..3)?.iter().try_fold(0usize, |acc, &base| {
        let digit = match base.to_ascii_uppercase() {
            b'A' => 0,
            b'C' => 1,
            b'G' => 2,
            b'T' | b'U' => 3,
            _ => return None,
        };
        Some(acc * 4 + digit)
    })
}
176
177pub fn dust(seq: &[u8], params: &DustParams) -> Result<Vec<MaskedRegion>> {
183 if seq.is_empty() {
184 return Err(CyaneaError::InvalidInput("sequence is empty".into()));
185 }
186
187 let w = params.window.min(seq.len());
188 if w < 3 {
189 return Ok(Vec::new());
190 }
191
192 let mut raw_regions: Vec<(usize, usize, f64)> = Vec::new();
193
194 for start in 0..=seq.len().saturating_sub(w) {
195 let window = &seq[start..start + w];
196 let score = dust_score(window);
197 if score > params.threshold {
198 raw_regions.push((start, start + w, score));
199 }
200 }
201
202 let merged = merge_regions(&raw_regions, params.linker);
204
205 Ok(merged
206 .into_iter()
207 .map(|(start, end, score)| MaskedRegion {
208 start,
209 end,
210 score,
211 source: MaskSource::Dust,
212 })
213 .collect())
214}
215
/// Shannon entropy, in bits, of the letter composition of `window`.
///
/// ASCII letters are counted case-insensitively in 26 bins; every other byte
/// is ignored. Returns `0.0` when the window contains no letters.
fn aa_entropy(window: &[u8]) -> f64 {
    let mut tally = [0u32; 26];
    for &byte in window.iter().filter(|b| b.is_ascii_alphabetic()) {
        tally[usize::from(byte.to_ascii_uppercase() - b'A')] += 1;
    }

    let letters: u32 = tally.iter().sum();
    if letters == 0 {
        return 0.0;
    }

    let denom = f64::from(letters);
    tally.iter().filter(|&&n| n > 0).fold(0.0f64, |bits, &n| {
        let p = f64::from(n) / denom;
        bits - p * p.log2()
    })
}
247
248pub fn seg(seq: &[u8], params: &SegParams) -> Result<Vec<MaskedRegion>> {
254 if seq.is_empty() {
255 return Err(CyaneaError::InvalidInput("sequence is empty".into()));
256 }
257
258 let w = params.window.min(seq.len());
259 if w < 2 {
260 return Ok(Vec::new());
261 }
262
263 let mut raw_regions: Vec<(usize, usize, f64)> = Vec::new();
264
265 for start in 0..=seq.len().saturating_sub(w) {
266 let window = &seq[start..start + w];
267 let ent = aa_entropy(window);
268 if ent <= params.lowcut {
269 let mut ext_start = start;
271 let mut ext_end = start + w;
272
273 while ext_start > 0 {
275 let candidate = &seq[ext_start - 1..ext_end];
276 if aa_entropy(candidate) <= params.highcut {
277 ext_start -= 1;
278 } else {
279 break;
280 }
281 }
282
283 while ext_end < seq.len() {
285 let candidate = &seq[ext_start..ext_end + 1];
286 if aa_entropy(candidate) <= params.highcut {
287 ext_end += 1;
288 } else {
289 break;
290 }
291 }
292
293 raw_regions.push((ext_start, ext_end, ent));
294 }
295 }
296
297 let merged = merge_regions(&raw_regions, 0);
298
299 Ok(merged
300 .into_iter()
301 .map(|(start, end, score)| MaskedRegion {
302 start,
303 end,
304 score,
305 source: MaskSource::Seg,
306 })
307 .collect())
308}
309
310pub fn find_tandem_repeats(
320 seq: &[u8],
321 params: &TandemRepeatParams,
322) -> Result<Vec<MaskedRegion>> {
323 if seq.is_empty() {
324 return Err(CyaneaError::InvalidInput("sequence is empty".into()));
325 }
326
327 let min_p = params.min_period.max(1);
328 let max_p = params.max_period.min(seq.len());
329
330 let mut raw_regions: Vec<(usize, usize, f64)> = Vec::new();
331
332 for p in min_p..=max_p {
333 let min_len = p * params.min_copies;
334 if min_len > seq.len() {
335 continue;
336 }
337
338 let mut i = p;
339 while i < seq.len() {
340 if seq[i].to_ascii_uppercase() == seq[i - p].to_ascii_uppercase() {
342 let run_start = i - p;
344 let mut run_end = i + 1;
345 while run_end < seq.len()
346 && seq[run_end].to_ascii_uppercase()
347 == seq[run_end - p].to_ascii_uppercase()
348 {
349 run_end += 1;
350 }
351 let run_len = run_end - run_start;
352 let copies = run_len / p;
353 if copies >= params.min_copies {
354 let trimmed_end = run_start + copies * p;
356 raw_regions.push((run_start, trimmed_end, copies as f64));
357 }
358 i = run_end;
359 } else {
360 i += 1;
361 }
362 }
363 }
364
365 let merged = merge_regions(&raw_regions, 0);
366
367 Ok(merged
368 .into_iter()
369 .map(|(start, end, score)| MaskedRegion {
370 start,
371 end,
372 score,
373 source: MaskSource::TandemRepeat,
374 })
375 .collect())
376}
377
378pub fn apply_mask(
387 seq: &[u8],
388 regions: &[MaskedRegion],
389 mode: MaskMode,
390 is_protein: bool,
391) -> MaskResult {
392 let mut out = seq.to_vec();
393 let mut masked_positions = vec![false; seq.len()];
394
395 for region in regions {
396 let start = region.start.min(seq.len());
397 let end = region.end.min(seq.len());
398 for i in start..end {
399 masked_positions[i] = true;
400 match mode {
401 MaskMode::Soft => {
402 out[i] = out[i].to_ascii_lowercase();
403 }
404 MaskMode::Hard => {
405 out[i] = if is_protein { b'X' } else { b'N' };
406 }
407 }
408 }
409 }
410
411 let masked_count = masked_positions.iter().filter(|&&m| m).count();
412 let masked_fraction = if seq.is_empty() {
413 0.0
414 } else {
415 masked_count as f64 / seq.len() as f64
416 };
417
418 MaskResult {
419 sequence: out,
420 regions: regions.to_vec(),
421 masked_fraction,
422 }
423}
424
425pub fn mask_dust(seq: &[u8], params: &DustParams, mode: MaskMode) -> Result<MaskResult> {
431 let regions = dust(seq, params)?;
432 Ok(apply_mask(seq, ®ions, mode, false))
433}
434
435pub fn mask_seg(seq: &[u8], params: &SegParams, mode: MaskMode) -> Result<MaskResult> {
441 let regions = seg(seq, params)?;
442 Ok(apply_mask(seq, ®ions, mode, true))
443}
444
/// Coalesces `(start, end, score)` intervals that overlap or lie within
/// `gap` positions of each other.
///
/// The input is sorted by `start` (stably) and swept once. An interval is
/// folded into the previous merged interval when its `start` is at most that
/// interval's `end + gap`; the merged interval keeps the larger `end` and the
/// larger `score`. The result is ordered by `start`.
///
/// `end + gap` is computed with saturating arithmetic, so an arbitrarily
/// large `gap` merges everything instead of overflowing.
fn merge_regions(regions: &[(usize, usize, f64)], gap: usize) -> Vec<(usize, usize, f64)> {
    let mut sorted = regions.to_vec();
    sorted.sort_by_key(|r| r.0);

    let mut merged: Vec<(usize, usize, f64)> = Vec::with_capacity(sorted.len());
    for (start, end, score) in sorted {
        match merged.last_mut() {
            Some(last) if start <= last.1.saturating_add(gap) => {
                last.1 = last.1.max(end);
                if score > last.2 {
                    last.2 = score;
                }
            }
            _ => merged.push((start, end, score)),
        }
    }

    merged
}
473
/// Unit tests for the DUST, SEG and tandem-repeat detectors and for mask application.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- DUST ----

    #[test]
    fn dust_homopolymer() {
        let seq = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
        let regions = dust(seq, &DustParams::default()).unwrap();
        assert!(!regions.is_empty(), "homopolymer should be masked");
    }

    #[test]
    fn dust_random_dna() {
        let seq = b"ACGTACGTACGTTGCATGCATGCAACGTACGTACGTTGCATGCATGCAACGTACGTACGTTGCA";
        let regions = dust(seq, &DustParams::default()).unwrap();
        assert!(regions.is_empty(), "diverse DNA should not be masked");
    }

    #[test]
    fn dust_dinucleotide_repeat() {
        let seq = b"ATATATATATATATATATATATATATATATATATATATATATATATATATATATATATATATATAT";
        // Uses a threshold below the default of 20.0.
        let params = DustParams {
            threshold: 10.0,
            ..Default::default()
        };
        let regions = dust(seq, &params).unwrap();
        assert!(!regions.is_empty(), "dinucleotide repeat should be masked at threshold 10");
    }

    #[test]
    fn dust_empty() {
        let result = dust(b"", &DustParams::default());
        assert!(result.is_err());
    }

    #[test]
    fn dust_short() {
        // Fewer than three bases: no triplet can be formed.
        let regions = dust(b"AC", &DustParams::default()).unwrap();
        assert!(regions.is_empty());
    }

    // ---- SEG ----

    #[test]
    fn seg_poly_ala() {
        let seq = b"AAAAAAAAAAAA";
        let regions = seg(seq, &SegParams::default()).unwrap();
        assert!(!regions.is_empty(), "poly-Ala should be low complexity");
    }

    #[test]
    fn seg_diverse_protein() {
        let seq = b"MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVK";
        let params = SegParams::default();
        let regions = seg(seq, &params).unwrap();
        let total_masked: usize = regions.iter().map(|r| r.end - r.start).sum();
        assert!(
            total_masked < seq.len() / 2,
            "diverse protein should be mostly unmasked, masked {} of {}",
            total_masked,
            seq.len()
        );
    }

    #[test]
    fn seg_empty() {
        let result = seg(b"", &SegParams::default());
        assert!(result.is_err());
    }

    #[test]
    fn seg_extension() {
        let seq = b"QQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ";
        let regions = seg(seq, &SegParams::default()).unwrap();
        // The length check only runs when at least one region is reported.
        if !regions.is_empty() {
            assert!(
                regions[0].end - regions[0].start > 12,
                "should extend beyond initial window"
            );
        }
    }

    // ---- Tandem repeats ----

    #[test]
    fn tandem_dinucleotide() {
        let seq = b"ACACACACACACACACAC";
        let regions = find_tandem_repeats(seq, &TandemRepeatParams::default()).unwrap();
        assert!(!regions.is_empty(), "AC repeat should be found");
    }

    #[test]
    fn tandem_trinucleotide() {
        let seq = b"CAGCAGCAGCAGCAGCAG";
        let regions = find_tandem_repeats(seq, &TandemRepeatParams::default()).unwrap();
        assert!(!regions.is_empty(), "CAG repeat should be found");
    }

    #[test]
    fn tandem_min_copies() {
        let seq = b"ACACAC";
        let params = TandemRepeatParams {
            min_copies: 4,
            ..Default::default()
        };
        let regions = find_tandem_repeats(seq, &params).unwrap();
        // Four copies of a period-2 unit would span at least 8 positions.
        let p2_regions: Vec<_> = regions
            .iter()
            .filter(|r| (r.end - r.start) >= 8)
            .collect();
        assert!(p2_regions.is_empty(), "3 copies should not meet min_copies=4 for period 2");
    }

    #[test]
    fn tandem_empty() {
        let result = find_tandem_repeats(b"", &TandemRepeatParams::default());
        assert!(result.is_err());
    }

    // ---- Mask application ----

    #[test]
    fn soft_mask_output() {
        let seq = b"ACGTACGT";
        let regions = vec![MaskedRegion {
            start: 2,
            end: 5,
            score: 1.0,
            source: MaskSource::Dust,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Soft, false);
        assert_eq!(result.sequence, b"ACgtaCGT");
        assert_eq!(result.sequence.len(), seq.len());
    }

    #[test]
    fn hard_mask_dna() {
        let seq = b"ACGTACGT";
        let regions = vec![MaskedRegion {
            start: 0,
            end: 4,
            score: 1.0,
            source: MaskSource::Dust,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Hard, false);
        assert_eq!(result.sequence, b"NNNNACGT");
    }

    #[test]
    fn hard_mask_protein() {
        let seq = b"MVHLTPEE";
        let regions = vec![MaskedRegion {
            start: 1,
            end: 3,
            score: 1.0,
            source: MaskSource::Seg,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Hard, true);
        assert_eq!(result.sequence, b"MXXLTPEE");
    }

    #[test]
    fn masked_fraction() {
        let seq = b"AAAAAAAA";
        // Four of eight positions are covered.
        let regions = vec![MaskedRegion {
            start: 0,
            end: 4,
            score: 1.0,
            source: MaskSource::Dust,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Soft, false);
        assert!((result.masked_fraction - 0.5).abs() < 1e-10);
    }

    #[test]
    fn mask_preserves_length() {
        let seq = b"ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT";
        let result = mask_dust(seq, &DustParams::default(), MaskMode::Soft).unwrap();
        assert_eq!(result.sequence.len(), seq.len());
    }
}
670
// Property-based tests built on the `proptest` crate.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy yielding sequences of 1 to `max_len` bytes, each one of
    /// `A`, `C`, `G` or `T`.
    fn dna_seq(max_len: usize) -> impl Strategy<Value = Vec<u8>> {
        proptest::collection::vec(
            prop_oneof![Just(b'A'), Just(b'C'), Just(b'G'), Just(b'T')],
            1..=max_len,
        )
    }

    proptest! {
        // Soft-masking with the default DUST parameters must return a
        // sequence of the same length as the input.
        #[test]
        fn mask_preserves_length(seq in dna_seq(200)) {
            let result = mask_dust(&seq, &DustParams::default(), MaskMode::Soft).unwrap();
            prop_assert_eq!(result.sequence.len(), seq.len());
        }
    }
}