Skip to main content

cyanea_seq/
masking.rs

1//! Low-complexity and repeat masking for biological sequences.
2//!
3//! Implements DUST (DNA), SEG (protein), tandem repeat detection, and
4//! soft/hard masking application.
5
6use cyanea_core::{CyaneaError, Result};
7
8// ---------------------------------------------------------------------------
9// Types
10// ---------------------------------------------------------------------------
11
/// How to mask identified regions.
///
/// The mode is a plain, copyable tag; it also implements `Hash` so it can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaskMode {
    /// Lowercase the masked bases, leaving the residues readable.
    Soft,
    /// Replace masked positions with `N` (DNA) or `X` (protein).
    Hard,
}
20
/// Source algorithm that identified a masked region.
///
/// Implements `Hash` in addition to equality so regions can be grouped or
/// counted per algorithm with hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaskSource {
    /// Region reported by the DUST low-complexity filter (DNA).
    Dust,
    /// Region reported by the SEG low-complexity filter (protein).
    Seg,
    /// Region reported by tandem repeat detection.
    TandemRepeat,
}
28
29/// A region identified for masking.
30#[derive(Debug, Clone)]
31pub struct MaskedRegion {
32    /// Start position (0-indexed, inclusive).
33    pub start: usize,
34    /// End position (0-indexed, exclusive).
35    pub end: usize,
36    /// Score (algorithm-dependent).
37    pub score: f64,
38    /// Which algorithm found this region.
39    pub source: MaskSource,
40}
41
42/// Result of applying a mask to a sequence.
43#[derive(Debug, Clone)]
44pub struct MaskResult {
45    /// The masked sequence (same length as input).
46    pub sequence: Vec<u8>,
47    /// Regions that were masked.
48    pub regions: Vec<MaskedRegion>,
49    /// Fraction of the sequence that was masked, in [0.0, 1.0].
50    pub masked_fraction: f64,
51}
52
53// ---------------------------------------------------------------------------
54// DUST parameters
55// ---------------------------------------------------------------------------
56
/// Parameters for the DUST low-complexity filter (DNA).
///
/// Parameter sets compare by value, so a configuration can be checked
/// against the defaults with `==`.
#[derive(Debug, Clone, PartialEq)]
pub struct DustParams {
    /// Sliding window size in bases (default: 64).
    pub window: usize,
    /// Score threshold; windows scoring strictly above it are masked (default: 20.0).
    pub threshold: f64,
    /// Join regions separated by at most this many bases (default: 1).
    pub linker: usize,
}

impl Default for DustParams {
    /// Window of 64 bases, threshold 20.0, linker 1.
    fn default() -> Self {
        Self {
            window: 64,
            threshold: 20.0,
            linker: 1,
        }
    }
}
77
78// ---------------------------------------------------------------------------
79// SEG parameters
80// ---------------------------------------------------------------------------
81
/// Parameters for the SEG low-complexity filter (protein).
///
/// Parameter sets compare by value, so a configuration can be checked
/// against the defaults with `==`.
#[derive(Debug, Clone, PartialEq)]
pub struct SegParams {
    /// Sliding window size in residues (default: 12).
    pub window: usize,
    /// A window triggers masking when its entropy is ≤ `lowcut` (default: 2.2).
    pub lowcut: f64,
    /// A triggered region keeps growing while its entropy stays ≤ `highcut`;
    /// extension stops once the entropy would exceed it (default: 2.5).
    pub highcut: f64,
}

impl Default for SegParams {
    /// Window of 12 residues, lowcut 2.2, highcut 2.5.
    fn default() -> Self {
        Self {
            window: 12,
            lowcut: 2.2,
            highcut: 2.5,
        }
    }
}
102
103// ---------------------------------------------------------------------------
104// Tandem repeat parameters
105// ---------------------------------------------------------------------------
106
/// Parameters for tandem repeat detection.
///
/// All fields are integers, so parameter sets support full equality (`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TandemRepeatParams {
    /// Minimum repeat unit period (default: 1). Values below 1 are treated as 1.
    pub min_period: usize,
    /// Maximum repeat unit period (default: 6). Capped at the sequence length.
    pub max_period: usize,
    /// Minimum number of complete repeat copies (default: 3).
    pub min_copies: usize,
}

impl Default for TandemRepeatParams {
    /// Periods 1 through 6, at least 3 complete copies.
    fn default() -> Self {
        Self {
            min_period: 1,
            max_period: 6,
            min_copies: 3,
        }
    }
}
127
128// ---------------------------------------------------------------------------
129// DUST algorithm
130// ---------------------------------------------------------------------------
131
132/// Compute DUST triplet score for a window.
133///
134/// Count all 64 DNA triplets in the window, then score = Σ(c*(c-1)/2) / (W-2).
135fn dust_score(window: &[u8]) -> f64 {
136    if window.len() < 3 {
137        return 0.0;
138    }
139
140    let mut counts = [0u32; 64];
141    for tri in window.windows(3) {
142        let idx = triplet_index(tri);
143        if let Some(i) = idx {
144            counts[i] += 1;
145        }
146    }
147
148    let mut score = 0.0f64;
149    for &c in &counts {
150        if c > 1 {
151            score += (c as f64) * (c as f64 - 1.0) / 2.0;
152        }
153    }
154
155    let denom = (window.len() as f64) - 2.0;
156    if denom > 0.0 {
157        score / denom
158    } else {
159        0.0
160    }
161}
162
/// Map a DNA triplet to an index in [0, 64).
///
/// Bases are matched case-insensitively as A=0, C=1, G=2 and T/U=3, and the
/// index is the base-4 number `first * 16 + second * 4 + third`. Only the
/// first three bytes of `tri` are read.
///
/// Returns `None` if any of those three bytes is not A/C/G/T/U, or if `tri`
/// holds fewer than three bytes (rather than panicking on the short slice).
fn triplet_index(tri: &[u8]) -> Option<usize> {
    fn base_code(base: u8) -> Option<usize> {
        match base.to_ascii_uppercase() {
            b'A' => Some(0),
            b'C' => Some(1),
            b'G' => Some(2),
            b'T' | b'U' => Some(3),
            _ => None,
        }
    }

    match tri {
        [first, second, third, ..] => {
            Some(base_code(*first)? * 16 + base_code(*second)? * 4 + base_code(*third)?)
        }
        // Too short to form a triplet.
        _ => None,
    }
}
176
177/// Identify low-complexity regions in a DNA sequence using the DUST algorithm.
178///
179/// # Errors
180///
181/// Returns an error if the sequence is empty.
182pub fn dust(seq: &[u8], params: &DustParams) -> Result<Vec<MaskedRegion>> {
183    if seq.is_empty() {
184        return Err(CyaneaError::InvalidInput("sequence is empty".into()));
185    }
186
187    let w = params.window.min(seq.len());
188    if w < 3 {
189        return Ok(Vec::new());
190    }
191
192    let mut raw_regions: Vec<(usize, usize, f64)> = Vec::new();
193
194    for start in 0..=seq.len().saturating_sub(w) {
195        let window = &seq[start..start + w];
196        let score = dust_score(window);
197        if score > params.threshold {
198            raw_regions.push((start, start + w, score));
199        }
200    }
201
202    // Merge overlapping / linker-adjacent regions
203    let merged = merge_regions(&raw_regions, params.linker);
204
205    Ok(merged
206        .into_iter()
207        .map(|(start, end, score)| MaskedRegion {
208            start,
209            end,
210            score,
211            source: MaskSource::Dust,
212        })
213        .collect())
214}
215
216// ---------------------------------------------------------------------------
217// SEG algorithm
218// ---------------------------------------------------------------------------
219
/// Shannon entropy (in bits) of the letter composition of a window.
///
/// Letters are counted case-insensitively over A-Z; every other byte is
/// ignored. A window with no letters at all has entropy `0.0`.
fn aa_entropy(window: &[u8]) -> f64 {
    // One slot per letter A-Z.
    let mut tally = [0u32; 26];
    let letters_only = window
        .iter()
        .map(|b| b.to_ascii_uppercase())
        .filter(|b| b.is_ascii_uppercase());
    for letter in letters_only {
        tally[(letter - b'A') as usize] += 1;
    }

    let letter_count: u32 = tally.iter().sum();
    if letter_count == 0 {
        return 0.0;
    }

    let n = letter_count as f64;
    tally
        .iter()
        .filter(|&&seen| seen > 0)
        .map(|&seen| seen as f64 / n)
        .fold(0.0f64, |entropy, p| entropy - p * p.log2())
}
247
248/// Identify low-complexity regions in a protein sequence using the SEG algorithm.
249///
250/// # Errors
251///
252/// Returns an error if the sequence is empty.
253pub fn seg(seq: &[u8], params: &SegParams) -> Result<Vec<MaskedRegion>> {
254    if seq.is_empty() {
255        return Err(CyaneaError::InvalidInput("sequence is empty".into()));
256    }
257
258    let w = params.window.min(seq.len());
259    if w < 2 {
260        return Ok(Vec::new());
261    }
262
263    let mut raw_regions: Vec<(usize, usize, f64)> = Vec::new();
264
265    for start in 0..=seq.len().saturating_sub(w) {
266        let window = &seq[start..start + w];
267        let ent = aa_entropy(window);
268        if ent <= params.lowcut {
269            // Extend in both directions until entropy >= highcut
270            let mut ext_start = start;
271            let mut ext_end = start + w;
272
273            // Extend left
274            while ext_start > 0 {
275                let candidate = &seq[ext_start - 1..ext_end];
276                if aa_entropy(candidate) <= params.highcut {
277                    ext_start -= 1;
278                } else {
279                    break;
280                }
281            }
282
283            // Extend right
284            while ext_end < seq.len() {
285                let candidate = &seq[ext_start..ext_end + 1];
286                if aa_entropy(candidate) <= params.highcut {
287                    ext_end += 1;
288                } else {
289                    break;
290                }
291            }
292
293            raw_regions.push((ext_start, ext_end, ent));
294        }
295    }
296
297    let merged = merge_regions(&raw_regions, 0);
298
299    Ok(merged
300        .into_iter()
301        .map(|(start, end, score)| MaskedRegion {
302            start,
303            end,
304            score,
305            source: MaskSource::Seg,
306        })
307        .collect())
308}
309
310// ---------------------------------------------------------------------------
311// Tandem repeat detection
312// ---------------------------------------------------------------------------
313
314/// Find tandem repeat regions in a sequence.
315///
316/// # Errors
317///
318/// Returns an error if the sequence is empty.
319pub fn find_tandem_repeats(
320    seq: &[u8],
321    params: &TandemRepeatParams,
322) -> Result<Vec<MaskedRegion>> {
323    if seq.is_empty() {
324        return Err(CyaneaError::InvalidInput("sequence is empty".into()));
325    }
326
327    let min_p = params.min_period.max(1);
328    let max_p = params.max_period.min(seq.len());
329
330    let mut raw_regions: Vec<(usize, usize, f64)> = Vec::new();
331
332    for p in min_p..=max_p {
333        let min_len = p * params.min_copies;
334        if min_len > seq.len() {
335            continue;
336        }
337
338        let mut i = p;
339        while i < seq.len() {
340            // Check if seq[i] matches seq[i-p]
341            if seq[i].to_ascii_uppercase() == seq[i - p].to_ascii_uppercase() {
342                // Found a match — extend the run
343                let run_start = i - p;
344                let mut run_end = i + 1;
345                while run_end < seq.len()
346                    && seq[run_end].to_ascii_uppercase()
347                        == seq[run_end - p].to_ascii_uppercase()
348                {
349                    run_end += 1;
350                }
351                let run_len = run_end - run_start;
352                let copies = run_len / p;
353                if copies >= params.min_copies {
354                    // Trim to complete copies
355                    let trimmed_end = run_start + copies * p;
356                    raw_regions.push((run_start, trimmed_end, copies as f64));
357                }
358                i = run_end;
359            } else {
360                i += 1;
361            }
362        }
363    }
364
365    let merged = merge_regions(&raw_regions, 0);
366
367    Ok(merged
368        .into_iter()
369        .map(|(start, end, score)| MaskedRegion {
370            start,
371            end,
372            score,
373            source: MaskSource::TandemRepeat,
374        })
375        .collect())
376}
377
378// ---------------------------------------------------------------------------
379// Masking application
380// ---------------------------------------------------------------------------
381
382/// Apply masking to a sequence given a set of regions.
383///
384/// - `Soft`: lowercase the masked bases.
385/// - `Hard`: replace with `N` (DNA) or `X` (protein).
386pub fn apply_mask(
387    seq: &[u8],
388    regions: &[MaskedRegion],
389    mode: MaskMode,
390    is_protein: bool,
391) -> MaskResult {
392    let mut out = seq.to_vec();
393    let mut masked_positions = vec![false; seq.len()];
394
395    for region in regions {
396        let start = region.start.min(seq.len());
397        let end = region.end.min(seq.len());
398        for i in start..end {
399            masked_positions[i] = true;
400            match mode {
401                MaskMode::Soft => {
402                    out[i] = out[i].to_ascii_lowercase();
403                }
404                MaskMode::Hard => {
405                    out[i] = if is_protein { b'X' } else { b'N' };
406                }
407            }
408        }
409    }
410
411    let masked_count = masked_positions.iter().filter(|&&m| m).count();
412    let masked_fraction = if seq.is_empty() {
413        0.0
414    } else {
415        masked_count as f64 / seq.len() as f64
416    };
417
418    MaskResult {
419        sequence: out,
420        regions: regions.to_vec(),
421        masked_fraction,
422    }
423}
424
425/// Run DUST and apply masking in one step.
426///
427/// # Errors
428///
429/// Returns an error if the sequence is empty.
430pub fn mask_dust(seq: &[u8], params: &DustParams, mode: MaskMode) -> Result<MaskResult> {
431    let regions = dust(seq, params)?;
432    Ok(apply_mask(seq, &regions, mode, false))
433}
434
435/// Run SEG and apply masking in one step.
436///
437/// # Errors
438///
439/// Returns an error if the sequence is empty.
440pub fn mask_seg(seq: &[u8], params: &SegParams, mode: MaskMode) -> Result<MaskResult> {
441    let regions = seg(seq, params)?;
442    Ok(apply_mask(seq, &regions, mode, true))
443}
444
445// ---------------------------------------------------------------------------
446// Helpers
447// ---------------------------------------------------------------------------
448
/// Merge overlapping or adjacent regions.
///
/// Regions are `(start, end, score)` triples with an exclusive `end`. They
/// are sorted by `start` (stable, so ties keep their input order) and a
/// region is folded into the previous one when its start lies at or before
/// the previous end plus `gap`. A merged region spans to the furthest end
/// seen and keeps the highest score of its members.
///
/// The `end + gap` comparison saturates, so an arbitrarily large `gap`
/// merges everything instead of overflowing.
fn merge_regions(regions: &[(usize, usize, f64)], gap: usize) -> Vec<(usize, usize, f64)> {
    let mut sorted: Vec<(usize, usize, f64)> = regions.to_vec();
    sorted.sort_by_key(|r| r.0);

    let mut merged: Vec<(usize, usize, f64)> = Vec::with_capacity(sorted.len());
    for (start, end, score) in sorted {
        if let Some(last) = merged.last_mut() {
            if start <= last.1.saturating_add(gap) {
                last.1 = last.1.max(end);
                if score > last.2 {
                    last.2 = score;
                }
                continue;
            }
        }
        // First region, or too far from the previous one: start a new group.
        merged.push((start, end, score));
    }

    merged
}
473
474// ---------------------------------------------------------------------------
475// Tests
476// ---------------------------------------------------------------------------
477
#[cfg(test)]
mod tests {
    use super::*;

    // --- DUST ---

    #[test]
    fn dust_homopolymer() {
        // Long poly-A run should score very high
        let seq = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
        let regions = dust(seq, &DustParams::default()).unwrap();
        assert!(!regions.is_empty(), "homopolymer should be masked");
    }

    #[test]
    fn dust_random_dna() {
        // Diverse sequence should not be masked with default threshold
        let seq = b"ACGTACGTACGTTGCATGCATGCAACGTACGTACGTTGCATGCATGCAACGTACGTACGTTGCA";
        let regions = dust(seq, &DustParams::default()).unwrap();
        assert!(regions.is_empty(), "diverse DNA should not be masked");
    }

    #[test]
    fn dust_dinucleotide_repeat() {
        // AT repeat produces score ~15 (only 2 distinct triplets), below default 20.
        // Use a lower threshold to detect it.
        let seq = b"ATATATATATATATATATATATATATATATATATATATATATATATATATATATATATATATATAT";
        let params = DustParams {
            threshold: 10.0,
            ..Default::default()
        };
        let regions = dust(seq, &params).unwrap();
        assert!(!regions.is_empty(), "dinucleotide repeat should be masked at threshold 10");
    }

    #[test]
    fn dust_empty() {
        let result = dust(b"", &DustParams::default());
        assert!(result.is_err());
    }

    #[test]
    fn dust_short() {
        let regions = dust(b"AC", &DustParams::default()).unwrap();
        assert!(regions.is_empty());
    }

    // --- SEG ---

    #[test]
    fn seg_poly_ala() {
        let seq = b"AAAAAAAAAAAA";
        let regions = seg(seq, &SegParams::default()).unwrap();
        assert!(!regions.is_empty(), "poly-Ala should be low complexity");
    }

    #[test]
    fn seg_diverse_protein() {
        let seq = b"MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVK";
        let params = SegParams::default();
        let regions = seg(seq, &params).unwrap();
        // Hemoglobin sequence is mostly high complexity
        // Allow some regions but most should be unmasked
        let total_masked: usize = regions.iter().map(|r| r.end - r.start).sum();
        assert!(
            total_masked < seq.len() / 2,
            "diverse protein should be mostly unmasked, masked {} of {}",
            total_masked,
            seq.len()
        );
    }

    #[test]
    fn seg_empty() {
        let result = seg(b"", &SegParams::default());
        assert!(result.is_err());
    }

    #[test]
    fn seg_extension() {
        // Very low complexity should extend beyond initial window
        let seq = b"QQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ";
        let regions = seg(seq, &SegParams::default()).unwrap();
        // Assert unconditionally: an empty result must fail the test rather
        // than silently skip the extension check.
        assert!(!regions.is_empty(), "poly-Q should be low complexity");
        assert!(
            regions[0].end - regions[0].start > 12,
            "should extend beyond initial window"
        );
    }

    // --- Tandem repeats ---

    #[test]
    fn tandem_dinucleotide() {
        let seq = b"ACACACACACACACACAC";
        let regions = find_tandem_repeats(seq, &TandemRepeatParams::default()).unwrap();
        assert!(!regions.is_empty(), "AC repeat should be found");
    }

    #[test]
    fn tandem_trinucleotide() {
        let seq = b"CAGCAGCAGCAGCAGCAG";
        let regions = find_tandem_repeats(seq, &TandemRepeatParams::default()).unwrap();
        assert!(!regions.is_empty(), "CAG repeat should be found");
    }

    #[test]
    fn tandem_min_copies() {
        let seq = b"ACACAC"; // 3 copies of AC
        let params = TandemRepeatParams {
            min_copies: 4,
            ..Default::default()
        };
        let regions = find_tandem_repeats(seq, &params).unwrap();
        // 3 copies < 4 minimum at period 2. Period 1 cannot match either
        // (no two adjacent bases are equal), and every longer period needs
        // more than 6 bases for 4 copies, so nothing may be reported at all.
        assert!(
            regions.is_empty(),
            "3 copies should not meet min_copies=4 for period 2"
        );
    }

    #[test]
    fn tandem_empty() {
        let result = find_tandem_repeats(b"", &TandemRepeatParams::default());
        assert!(result.is_err());
    }

    // --- Masking ---

    #[test]
    fn soft_mask_output() {
        let seq = b"ACGTACGT";
        let regions = vec![MaskedRegion {
            start: 2,
            end: 5,
            score: 1.0,
            source: MaskSource::Dust,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Soft, false);
        assert_eq!(result.sequence, b"ACgtaCGT");
        assert_eq!(result.sequence.len(), seq.len());
    }

    #[test]
    fn hard_mask_dna() {
        let seq = b"ACGTACGT";
        let regions = vec![MaskedRegion {
            start: 0,
            end: 4,
            score: 1.0,
            source: MaskSource::Dust,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Hard, false);
        assert_eq!(result.sequence, b"NNNNACGT");
    }

    #[test]
    fn hard_mask_protein() {
        let seq = b"MVHLTPEE";
        let regions = vec![MaskedRegion {
            start: 1,
            end: 3,
            score: 1.0,
            source: MaskSource::Seg,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Hard, true);
        assert_eq!(result.sequence, b"MXXLTPEE");
    }

    #[test]
    fn masked_fraction() {
        let seq = b"AAAAAAAA"; // 8 bases
        let regions = vec![MaskedRegion {
            start: 0,
            end: 4,
            score: 1.0,
            source: MaskSource::Dust,
        }];
        let result = apply_mask(seq, &regions, MaskMode::Soft, false);
        assert!((result.masked_fraction - 0.5).abs() < 1e-10);
    }

    #[test]
    fn mask_preserves_length() {
        let seq = b"ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT";
        let result = mask_dust(seq, &DustParams::default(), MaskMode::Soft).unwrap();
        assert_eq!(result.sequence.len(), seq.len());
    }
}
670
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // Strategy producing non-empty DNA sequences of at most `max_len` bases,
    // each base drawn from A, C, G and T.
    fn dna_seq(max_len: usize) -> impl Strategy<Value = Vec<u8>> {
        let base = prop_oneof![Just(b'A'), Just(b'C'), Just(b'G'), Just(b'T')];
        proptest::collection::vec(base, 1..=max_len)
    }

    proptest! {
        // Soft DUST masking must never change the length of the sequence.
        #[test]
        fn mask_preserves_length(bases in dna_seq(200)) {
            let masked = mask_dust(&bases, &DustParams::default(), MaskMode::Soft).unwrap();
            prop_assert_eq!(masked.sequence.len(), bases.len());
        }
    }
}
690}