Skip to main content

cyanea_seq/
trim.rs

1//! Quality trimming, adapter removal, and read filtering for FASTQ records.
2//!
//! Three-level API:
4//!
5//! 1. **Low-level functions** operate on `&[u8]` quality/sequence slices and return
6//!    [`TrimRange`] values that can be composed via [`intersect_ranges`].
7//! 2. **High-level functions** operate on [`FastqRecord`] and return
8//!    `Option<FastqRecord>` (None = filtered out).
9//! 3. **[`TrimPipeline`]** builder chains operations in Trimmomatic-style order
10//!    and collects statistics via [`TrimReport`].
11//!
12//! # Example
13//!
14//! ```
15//! use cyanea_seq::trim::{TrimPipeline, adapters};
16//! use cyanea_seq::{FastqRecord, DnaSequence, QualityScores};
17//!
18//! let pipeline = TrimPipeline::new()
19//!     .adapter(adapters::TRUSEQ_PREFIX)
20//!     .leading(3)
21//!     .trailing(3)
22//!     .sliding_window(4, 15.0)
23//!     .min_length(4);
24//!
25//! let seq = DnaSequence::new(b"ACGTACGTACGTACGT").unwrap();
26//! let qual = QualityScores::from_raw(vec![30; 16]);
27//! let record = FastqRecord::new("read1".into(), None, seq, qual).unwrap();
28//!
29//! let result = pipeline.process(&record);
30//! assert!(result.is_some());
31//! ```
32
33use crate::fastq::FastqRecord;
34use crate::quality::QualityScores;
35use crate::seq::ValidatedSeq;
36use crate::alphabet::DnaAlphabet;
37use cyanea_core::{Annotated, Sequence};
38
/// Half-open interval `[start, end)` of read positions that survive trimming.
///
/// Positions are 0-based offsets into the read. A range whose `end` is not
/// greater than `start` keeps nothing.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TrimRange {
    /// First position to keep (inclusive).
    pub start: usize,
    /// One past the last position to keep (exclusive).
    pub end: usize,
}
45
46impl TrimRange {
47    /// Length of the range (0 if empty).
48    pub fn len(&self) -> usize {
49        if self.end > self.start {
50            self.end - self.start
51        } else {
52            0
53        }
54    }
55
56    /// Whether this range is empty.
57    pub fn is_empty(&self) -> bool {
58        self.end <= self.start
59    }
60}
61
62/// Intersect multiple trim ranges, returning the overlap.
63///
64/// Returns a range covering only positions present in *all* input ranges.
65/// If `ranges` is empty, returns an empty range.
66pub fn intersect_ranges(ranges: &[TrimRange]) -> TrimRange {
67    if ranges.is_empty() {
68        return TrimRange { start: 0, end: 0 };
69    }
70    let start = ranges.iter().map(|r| r.start).max().unwrap();
71    let end = ranges.iter().map(|r| r.end).min().unwrap();
72    TrimRange {
73        start,
74        end: end.max(start),
75    }
76}
77
78// ---------------------------------------------------------------------------
79// Low-level trim functions
80// ---------------------------------------------------------------------------
81
82/// Trimmomatic-style sliding window trim.
83///
84/// Scans from the 5' end with a window of `window_size` bases. When the mean
85/// quality in the window drops below `threshold`, the read is cut at the start
86/// of that window. Returns the range of bases to keep.
87pub fn trim_sliding_window(quality: &[u8], window_size: usize, threshold: f64) -> TrimRange {
88    let len = quality.len();
89    if len == 0 || window_size == 0 {
90        return TrimRange { start: 0, end: 0 };
91    }
92
93    let ws = window_size.min(len);
94    let threshold_sum = threshold * ws as f64;
95
96    // Initial window sum
97    let mut sum: u64 = quality[..ws].iter().map(|&q| q as u64).sum();
98    if (sum as f64) < threshold_sum {
99        return TrimRange { start: 0, end: 0 };
100    }
101
102    for i in 1..=(len - ws) {
103        // Slide: remove left, add right
104        sum -= quality[i - 1] as u64;
105        sum += quality[i + ws - 1] as u64;
106        if (sum as f64) < threshold_sum {
107            return TrimRange { start: 0, end: i + ws - 1 };
108        }
109    }
110
111    TrimRange { start: 0, end: len }
112}
113
114/// Trim low-quality bases from the 5' (leading) end.
115///
116/// Removes consecutive bases from the start that have quality below `threshold`.
117pub fn trim_leading(quality: &[u8], threshold: u8) -> TrimRange {
118    let start = quality.iter().position(|&q| q >= threshold).unwrap_or(quality.len());
119    TrimRange {
120        start,
121        end: quality.len(),
122    }
123}
124
125/// Trim low-quality bases from the 3' (trailing) end.
126///
127/// Removes consecutive bases from the end that have quality below `threshold`.
128pub fn trim_trailing(quality: &[u8], threshold: u8) -> TrimRange {
129    let end = quality
130        .iter()
131        .rposition(|&q| q >= threshold)
132        .map(|i| i + 1)
133        .unwrap_or(0);
134    TrimRange { start: 0, end }
135}
136
137/// BWA-style quality trimming from the 3' end.
138///
139/// Scans right-to-left, accumulating `(threshold - Q[i])`. Resets the running
140/// sum to 0 when it goes negative. Cuts at the position where the sum was
141/// maximized. This effectively finds the longest suffix with low aggregate
142/// quality and removes it.
143pub fn trim_quality_3prime(quality: &[u8], threshold: u8) -> TrimRange {
144    let len = quality.len();
145    if len == 0 {
146        return TrimRange { start: 0, end: 0 };
147    }
148
149    let mut max_sum: i64 = 0;
150    let mut sum: i64 = 0;
151    let mut cut_pos = len;
152
153    for i in (0..len).rev() {
154        sum += threshold as i64 - quality[i] as i64;
155        if sum > max_sum {
156            max_sum = sum;
157            cut_pos = i;
158        }
159    }
160
161    TrimRange {
162        start: 0,
163        end: cut_pos,
164    }
165}
166
/// Find the position of a 3' adapter in a sequence.
///
/// Candidate start positions are tried from the 5' end of the read. At each
/// `start`, the read is compared with the adapter over
/// `min(seq.len() - start, adapter.len())` bases. This is the whole adapter
/// while it fits (so internal occurrences are found too), and then a shrinking
/// adapter prefix overlapping the 3' end of the read. The first `start` with at
/// most `max_mismatches` mismatches is returned; that is where to cut.
///
/// Scanning stops once the overlap is shorter than `max(8, adapter.len() / 3)`.
/// As a consequence, adapters shorter than 8 bases can never match.
///
/// If no adapter is found, or either input is empty, returns `seq.len()`.
pub fn find_adapter_3prime(seq: &[u8], adapter: &[u8], max_mismatches: usize) -> usize {
    if seq.is_empty() || adapter.is_empty() {
        return seq.len();
    }

    let min_overlap = (adapter.len() / 3).max(8);
    let mismatches_at = |start: usize, overlap: usize| {
        seq[start..start + overlap]
            .iter()
            .zip(adapter)
            .filter(|(a, b)| a != b)
            .count()
    };

    (0..seq.len())
        .map(|start| (start, (seq.len() - start).min(adapter.len())))
        .take_while(|&(_, overlap)| overlap >= min_overlap)
        .find(|&(start, overlap)| mismatches_at(start, overlap) <= max_mismatches)
        .map_or(seq.len(), |(start, _)| start)
}
203
/// Shannon entropy of base composition, in bits (at most 2.0 for DNA).
///
/// Computes `-sum(p * log2(p))` over the relative frequencies of A, C, G and T,
/// matched case-insensitively. Any other byte (e.g. `N`) is ignored. Returns
/// 0.0 when `seq` contains no A/C/G/T at all, including when it is empty.
pub fn shannon_entropy(seq: &[u8]) -> f64 {
    // Slots: 0 = A, 1 = C, 2 = G, 3 = T.
    let mut tally = [0u64; 4];
    let slots = seq.iter().filter_map(|b| match b.to_ascii_uppercase() {
        b'A' => Some(0),
        b'C' => Some(1),
        b'G' => Some(2),
        b'T' => Some(3),
        _ => None,
    });
    for slot in slots {
        tally[slot] += 1;
    }

    let observed: u64 = tally.iter().sum();
    if observed == 0 {
        return 0.0;
    }

    let total = observed as f64;
    // Accumulate in A, C, G, T order so results are bit-for-bit reproducible.
    tally
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| count as f64 / total)
        .fold(0.0, |h, p| h - p * p.log2())
}
238
239// ---------------------------------------------------------------------------
240// High-level functions
241// ---------------------------------------------------------------------------
242
243/// Apply a trim range to a FASTQ record, producing a new trimmed record.
244///
245/// Returns `None` if the range is empty (nothing left after trimming).
246/// Uses `from_validated()` internally since subslicing validated data is safe.
247pub fn apply_trim(record: &FastqRecord, range: TrimRange) -> Option<FastqRecord> {
248    if range.is_empty() || range.start >= record.sequence().len() {
249        return None;
250    }
251    let end = range.end.min(record.sequence().len());
252    if end <= range.start {
253        return None;
254    }
255
256    let seq_bytes = record.sequence().as_bytes()[range.start..end].to_vec();
257    let qual_bytes = record.quality().as_slice()[range.start..end].to_vec();
258
259    let sequence = ValidatedSeq::<DnaAlphabet>::from_validated(seq_bytes);
260    let quality = QualityScores::from_raw(qual_bytes);
261
262    // FastqRecord::new checks length match, but we guarantee it here.
263    FastqRecord::new(
264        record.name().to_string(),
265        record.description().map(|d| d.to_string()),
266        sequence,
267        quality,
268    )
269    .ok()
270}
271
272/// Remove a 3' adapter from a record.
273///
274/// Returns a new record with the adapter (and everything after it) removed.
275/// If no adapter is found, returns a clone of the original record.
276pub fn trim_adapter(record: &FastqRecord, adapter: &[u8], max_mismatches: usize) -> FastqRecord {
277    let cut = find_adapter_3prime(record.sequence().as_bytes(), adapter, max_mismatches);
278    if cut >= record.sequence().len() {
279        return record.clone();
280    }
281    let range = TrimRange { start: 0, end: cut };
282    apply_trim(record, range).unwrap_or_else(|| record.clone())
283}
284
285/// Filter a record by length.
286///
287/// Returns `None` if the record's length is outside `[min_len, max_len]`.
288pub fn filter_by_length<'a>(record: &'a FastqRecord, min_len: usize, max_len: usize) -> Option<&'a FastqRecord> {
289    let len = record.sequence().len();
290    if len >= min_len && len <= max_len {
291        Some(record)
292    } else {
293        None
294    }
295}
296
297/// Filter a record by low complexity.
298///
299/// Returns `None` if the Shannon entropy is below `min_entropy`.
300pub fn filter_low_complexity<'a>(record: &'a FastqRecord, min_entropy: f64) -> Option<&'a FastqRecord> {
301    if shannon_entropy(record.sequence().as_bytes()) >= min_entropy {
302        Some(record)
303    } else {
304        None
305    }
306}
307
308/// Filter a record by mean quality score.
309///
310/// Returns `None` if the mean quality is below `min_quality`.
311pub fn filter_by_quality<'a>(record: &'a FastqRecord, min_quality: f64) -> Option<&'a FastqRecord> {
312    if record.quality().mean() >= min_quality {
313        Some(record)
314    } else {
315        None
316    }
317}
318
319// ---------------------------------------------------------------------------
320// Adapter constants
321// ---------------------------------------------------------------------------
322
/// Common sequencing adapter sequences.
pub mod adapters {
    /// Illumina TruSeq Universal Adapter.
    ///
    /// NOTE(review): this sequence matches what Illumina lists as the TruSeq
    /// *Read 1* adapter-trimming sequence; the name may be historical. Confirm
    /// before relying on the label.
    pub const TRUSEQ_UNIVERSAL: &[u8] = b"AGATCGGAAGAGCACACGTCTGAACTCCAGTCA";

    /// Illumina TruSeq Indexed Adapter.
    ///
    /// NOTE(review): this sequence matches Illumina's TruSeq *Read 2*
    /// adapter-trimming sequence. Confirm against Illumina's adapter document.
    pub const TRUSEQ_INDEXED: &[u8] = b"AGATCGGAAGAGCGTCGTGTAGGGAAAGAGTGT";

    /// Nextera Transposase Read 1.
    ///
    /// NOTE(review): this is the 5' transposase adapter as synthesised. 3'
    /// read-through in Nextera libraries is usually trimmed with
    /// `CTGTCTCTTATACACATCT`, so a 3' search for this sequence may find
    /// nothing. Verify for your library prep.
    pub const NEXTERA_READ1: &[u8] = b"TCGTCGGCAGCGTCAGATGTGTATAAGAGACAG";

    /// Nextera Transposase Read 2 (same orientation caveat as [`NEXTERA_READ1`]).
    pub const NEXTERA_READ2: &[u8] = b"GTCTCGTGGGCTCGGAGATGTGTATAAGAGACAG";

    /// Illumina Small RNA 3' Adapter.
    pub const SMALL_RNA_3P: &[u8] = b"TGGAATTCTCGGGTGCCAAGG";

    /// 12-base prefix shared by [`TRUSEQ_UNIVERSAL`] and [`TRUSEQ_INDEXED`].
    pub const TRUSEQ_PREFIX: &[u8] = b"AGATCGGAAGAG";

    /// All standard Illumina adapters above, for batch searching.
    ///
    /// [`TRUSEQ_PREFIX`] is not included, since both full TruSeq adapters are.
    pub const ALL_ILLUMINA: &[&[u8]] = &[
        TRUSEQ_UNIVERSAL,
        TRUSEQ_INDEXED,
        NEXTERA_READ1,
        NEXTERA_READ2,
        SMALL_RNA_3P,
    ];
}
346
347// ---------------------------------------------------------------------------
348// TrimPipeline
349// ---------------------------------------------------------------------------
350
/// Quality-trimming strategy for pipeline step 4.
///
/// Selected by `TrimPipeline::sliding_window` or `TrimPipeline::bwa_quality`;
/// at most one strategy is active at a time.
#[derive(Debug, Clone)]
enum QualityTrimAlgo {
    /// No window/BWA quality trimming.
    None,
    /// Trimmomatic-style sliding window (see `trim_sliding_window`).
    SlidingWindow {
        /// Number of bases averaged per window.
        window_size: usize,
        /// Minimum mean window quality.
        threshold: f64,
    },
    /// BWA-style 3' trimming (see `trim_quality_3prime`).
    Bwa {
        /// Quality cutoff used in the running `threshold - Q` sum.
        threshold: u8,
    },
}
358
359/// A configurable read-processing pipeline.
360///
361/// Operations are applied in a fixed order matching Trimmomatic convention:
362/// 1. Adapter trimming
363/// 2. Leading quality trim
364/// 3. Trailing quality trim
365/// 4. Sliding window / BWA quality trim
366/// 5. Length filter
367/// 6. Quality filter
368/// 7. Complexity filter
369///
370/// # Example
371///
372/// ```
373/// use cyanea_seq::trim::TrimPipeline;
374///
375/// let pipeline = TrimPipeline::new()
376///     .leading(3)
377///     .trailing(3)
378///     .sliding_window(4, 15.0)
379///     .min_length(36)
380///     .min_mean_quality(20.0);
381/// ```
382#[derive(Debug, Clone)]
383pub struct TrimPipeline {
384    adapters: Vec<Vec<u8>>,
385    adapter_max_mismatches: usize,
386    leading_threshold: Option<u8>,
387    trailing_threshold: Option<u8>,
388    quality_trim: QualityTrimAlgo,
389    min_length: Option<usize>,
390    max_length: Option<usize>,
391    min_mean_quality: Option<f64>,
392    min_entropy: Option<f64>,
393}
394
395impl TrimPipeline {
396    /// Create a new empty pipeline (no-op by default).
397    pub fn new() -> Self {
398        Self {
399            adapters: Vec::new(),
400            adapter_max_mismatches: 1,
401            leading_threshold: None,
402            trailing_threshold: None,
403            quality_trim: QualityTrimAlgo::None,
404            min_length: None,
405            max_length: None,
406            min_mean_quality: None,
407            min_entropy: None,
408        }
409    }
410
411    /// Add a single adapter sequence to search for and remove.
412    pub fn adapter(mut self, adapter: &[u8]) -> Self {
413        self.adapters.push(adapter.to_vec());
414        self
415    }
416
417    /// Add all standard Illumina adapters.
418    pub fn illumina_adapters(mut self) -> Self {
419        for &a in adapters::ALL_ILLUMINA {
420            self.adapters.push(a.to_vec());
421        }
422        self
423    }
424
425    /// Set the maximum number of mismatches allowed for adapter matching.
426    pub fn adapter_mismatches(mut self, max: usize) -> Self {
427        self.adapter_max_mismatches = max;
428        self
429    }
430
431    /// Trim bases below `threshold` quality from the 5' end.
432    pub fn leading(mut self, threshold: u8) -> Self {
433        self.leading_threshold = Some(threshold);
434        self
435    }
436
437    /// Trim bases below `threshold` quality from the 3' end.
438    pub fn trailing(mut self, threshold: u8) -> Self {
439        self.trailing_threshold = Some(threshold);
440        self
441    }
442
443    /// Use Trimmomatic-style sliding window trimming.
444    ///
445    /// Mutually exclusive with [`bwa_quality`](Self::bwa_quality) — setting
446    /// one clears the other.
447    pub fn sliding_window(mut self, window_size: usize, threshold: f64) -> Self {
448        self.quality_trim = QualityTrimAlgo::SlidingWindow {
449            window_size,
450            threshold,
451        };
452        self
453    }
454
455    /// Use BWA-style 3' quality trimming.
456    ///
457    /// Mutually exclusive with [`sliding_window`](Self::sliding_window) — setting
458    /// one clears the other.
459    pub fn bwa_quality(mut self, threshold: u8) -> Self {
460        self.quality_trim = QualityTrimAlgo::Bwa { threshold };
461        self
462    }
463
464    /// Set the minimum read length (reads shorter than this are discarded).
465    pub fn min_length(mut self, len: usize) -> Self {
466        self.min_length = Some(len);
467        self
468    }
469
470    /// Set the maximum read length (reads longer than this are discarded).
471    pub fn max_length(mut self, len: usize) -> Self {
472        self.max_length = Some(len);
473        self
474    }
475
476    /// Set the minimum mean quality (reads below this are discarded).
477    pub fn min_mean_quality(mut self, quality: f64) -> Self {
478        self.min_mean_quality = Some(quality);
479        self
480    }
481
482    /// Set the minimum Shannon entropy (reads below this are discarded).
483    pub fn min_entropy(mut self, entropy: f64) -> Self {
484        self.min_entropy = Some(entropy);
485        self
486    }
487
488    /// Process a single record through the pipeline.
489    ///
490    /// Returns `Some(trimmed_record)` if the record passes all filters,
491    /// or `None` if it was filtered out.
492    pub fn process(&self, record: &FastqRecord) -> Option<FastqRecord> {
493        let mut current = record.clone();
494
495        // 1. Adapter trimming — try each adapter, take the earliest hit
496        if !self.adapters.is_empty() {
497            let seq = current.sequence().as_bytes();
498            let mut best_cut = seq.len();
499            for adapter in &self.adapters {
500                let cut = find_adapter_3prime(seq, adapter, self.adapter_max_mismatches);
501                if cut < best_cut {
502                    best_cut = cut;
503                }
504            }
505            if best_cut < seq.len() {
506                let range = TrimRange { start: 0, end: best_cut };
507                current = apply_trim(&current, range)?;
508            }
509        }
510
511        // 2-4. Quality trimming via TrimRange composition
512        let quality = current.quality().as_slice();
513        let mut ranges = Vec::new();
514
515        if let Some(threshold) = self.leading_threshold {
516            ranges.push(trim_leading(quality, threshold));
517        }
518
519        if let Some(threshold) = self.trailing_threshold {
520            ranges.push(trim_trailing(quality, threshold));
521        }
522
523        match &self.quality_trim {
524            QualityTrimAlgo::SlidingWindow {
525                window_size,
526                threshold,
527            } => {
528                ranges.push(trim_sliding_window(quality, *window_size, *threshold));
529            }
530            QualityTrimAlgo::Bwa { threshold } => {
531                ranges.push(trim_quality_3prime(quality, *threshold));
532            }
533            QualityTrimAlgo::None => {}
534        }
535
536        if !ranges.is_empty() {
537            // Start with the full range, then intersect with each trim result
538            let full = TrimRange {
539                start: 0,
540                end: quality.len(),
541            };
542            ranges.insert(0, full);
543            let combined = intersect_ranges(&ranges);
544            current = apply_trim(&current, combined)?;
545        }
546
547        // 5. Length filter
548        let len = current.sequence().len();
549        if let Some(min) = self.min_length {
550            if len < min {
551                return None;
552            }
553        }
554        if let Some(max) = self.max_length {
555            if len > max {
556                return None;
557            }
558        }
559
560        // 6. Quality filter
561        if let Some(min_q) = self.min_mean_quality {
562            if current.quality().mean() < min_q {
563                return None;
564            }
565        }
566
567        // 7. Complexity filter
568        if let Some(min_e) = self.min_entropy {
569            if shannon_entropy(current.sequence().as_bytes()) < min_e {
570                return None;
571            }
572        }
573
574        Some(current)
575    }
576
577    /// Process a batch of records, returning only those that pass.
578    pub fn process_batch(&self, records: &[FastqRecord]) -> Vec<FastqRecord> {
579        records.iter().filter_map(|r| self.process(r)).collect()
580    }
581
582    /// Process a batch and collect detailed statistics.
583    pub fn process_batch_with_stats(&self, records: &[FastqRecord]) -> TrimReport {
584        let total_input = records.len();
585        let mut total_bases_input: u64 = 0;
586        let mut total_bases_output: u64 = 0;
587        let mut filtered_by_length: usize = 0;
588        let mut filtered_by_quality: usize = 0;
589        let mut filtered_by_complexity: usize = 0;
590        let mut adapters_found: usize = 0;
591        let mut kept = Vec::new();
592
593        for record in records {
594            total_bases_input += record.sequence().len() as u64;
595
596            // Track adapter detection
597            if !self.adapters.is_empty() {
598                let seq = record.sequence().as_bytes();
599                for adapter in &self.adapters {
600                    if find_adapter_3prime(seq, adapter, self.adapter_max_mismatches) < seq.len() {
601                        adapters_found += 1;
602                        break;
603                    }
604                }
605            }
606
607            // Run the full pipeline and track which filter rejected the record
608            match self.process(record) {
609                Some(trimmed) => {
610                    total_bases_output += trimmed.sequence().len() as u64;
611                    kept.push(trimmed);
612                }
613                None => {
614                    // Determine which filter rejected it by running steps incrementally
615                    let rejection = self.find_rejection_reason(record);
616                    match rejection {
617                        Rejection::Length => filtered_by_length += 1,
618                        Rejection::Quality => filtered_by_quality += 1,
619                        Rejection::Complexity => filtered_by_complexity += 1,
620                        Rejection::Trimmed => filtered_by_length += 1,
621                    }
622                }
623            }
624        }
625
626        TrimReport {
627            kept,
628            total_input,
629            total_output: total_input - filtered_by_length - filtered_by_quality - filtered_by_complexity,
630            filtered_by_length,
631            filtered_by_quality,
632            filtered_by_complexity,
633            adapters_found,
634            total_bases_input,
635            total_bases_output,
636        }
637    }
638}
639
640impl Default for TrimPipeline {
641    fn default() -> Self {
642        Self::new()
643    }
644}
645
646enum Rejection {
647    Trimmed,
648    Length,
649    Quality,
650    Complexity,
651}
652
653impl TrimPipeline {
654    fn find_rejection_reason(&self, record: &FastqRecord) -> Rejection {
655        let mut current = record.clone();
656
657        // Adapter trimming
658        if !self.adapters.is_empty() {
659            let seq = current.sequence().as_bytes();
660            let mut best_cut = seq.len();
661            for adapter in &self.adapters {
662                let cut = find_adapter_3prime(seq, adapter, self.adapter_max_mismatches);
663                if cut < best_cut {
664                    best_cut = cut;
665                }
666            }
667            if best_cut < seq.len() {
668                let range = TrimRange { start: 0, end: best_cut };
669                match apply_trim(&current, range) {
670                    Some(t) => current = t,
671                    None => return Rejection::Trimmed,
672                }
673            }
674        }
675
676        // Quality trimming
677        let quality = current.quality().as_slice();
678        let mut ranges = Vec::new();
679
680        if let Some(threshold) = self.leading_threshold {
681            ranges.push(trim_leading(quality, threshold));
682        }
683        if let Some(threshold) = self.trailing_threshold {
684            ranges.push(trim_trailing(quality, threshold));
685        }
686        match &self.quality_trim {
687            QualityTrimAlgo::SlidingWindow { window_size, threshold } => {
688                ranges.push(trim_sliding_window(quality, *window_size, *threshold));
689            }
690            QualityTrimAlgo::Bwa { threshold } => {
691                ranges.push(trim_quality_3prime(quality, *threshold));
692            }
693            QualityTrimAlgo::None => {}
694        }
695        if !ranges.is_empty() {
696            let full = TrimRange { start: 0, end: quality.len() };
697            ranges.insert(0, full);
698            let combined = intersect_ranges(&ranges);
699            match apply_trim(&current, combined) {
700                Some(t) => current = t,
701                None => return Rejection::Trimmed,
702            }
703        }
704
705        // Length filter
706        let len = current.sequence().len();
707        if let Some(min) = self.min_length {
708            if len < min { return Rejection::Length; }
709        }
710        if let Some(max) = self.max_length {
711            if len > max { return Rejection::Length; }
712        }
713
714        // Quality filter
715        if let Some(min_q) = self.min_mean_quality {
716            if current.quality().mean() < min_q { return Rejection::Quality; }
717        }
718
719        // Complexity filter
720        if let Some(min_e) = self.min_entropy {
721            if shannon_entropy(current.sequence().as_bytes()) < min_e {
722                return Rejection::Complexity;
723            }
724        }
725
726        // Shouldn't reach here if process() returned None, but default to trimmed
727        Rejection::Trimmed
728    }
729}
730
731/// Summary statistics from [`TrimPipeline::process_batch_with_stats`].
732#[derive(Debug, Clone)]
733pub struct TrimReport {
734    /// Records that passed all filters.
735    pub kept: Vec<FastqRecord>,
736    /// Total number of input records.
737    pub total_input: usize,
738    /// Total number of output records.
739    pub total_output: usize,
740    /// Records filtered out for being too short or too long.
741    pub filtered_by_length: usize,
742    /// Records filtered out for low mean quality.
743    pub filtered_by_quality: usize,
744    /// Records filtered out for low complexity.
745    pub filtered_by_complexity: usize,
746    /// Number of records where an adapter was detected.
747    pub adapters_found: usize,
748    /// Total bases in input records.
749    pub total_bases_input: u64,
750    /// Total bases in output records.
751    pub total_bases_output: u64,
752}
753
754// ---------------------------------------------------------------------------
755// Paired-end trimming
756// ---------------------------------------------------------------------------
757
758#[cfg(feature = "std")]
759use crate::paired::PairedFastqRecord;
760
/// What to do with a read pair when exactly one mate survives the pipeline.
#[cfg(feature = "std")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrphanPolicy {
    /// Discard the pair unless both mates pass.
    DropBoth,
    /// Keep a surviving R1 even when its mate R2 was filtered out.
    KeepFirst,
    /// Keep a surviving R2 even when its mate R1 was filtered out.
    KeepSecond,
}
772
/// Outcome of running one read pair through [`TrimPipeline::process_paired`].
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub enum PairedTrimResult {
    /// Both mates survived; holds the trimmed R1 and R2.
    BothPassed(FastqRecord, FastqRecord),
    /// Only R1 survived and the orphan policy kept it.
    OnlyFirst(FastqRecord),
    /// Only R2 survived and the orphan policy kept it.
    OnlySecond(FastqRecord),
    /// Nothing is kept: both mates failed, or the orphan policy discarded the lone survivor.
    Dropped,
}
786
/// Summary statistics from [`TrimPipeline::process_paired_batch_with_stats`].
///
/// Each input pair is counted in exactly one of `both_passed`,
/// `r1_only_passed`, `r2_only_passed` or `both_failed`.
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct PairedTrimReport {
    /// Trimmed pairs in which both mates passed every filter.
    pub kept: Vec<PairedFastqRecord>,
    /// Number of pairs supplied.
    pub total_input: usize,
    /// Pairs in which both mates passed.
    pub both_passed: usize,
    /// Pairs in which only R1 passed.
    pub r1_only_passed: usize,
    /// Pairs in which only R2 passed.
    pub r2_only_passed: usize,
    /// Pairs in which neither mate passed.
    pub both_failed: usize,
    /// Sum of R1 and R2 lengths over all input pairs.
    pub total_bases_input: u64,
    /// Sum of R1 and R2 trimmed lengths over the kept pairs.
    pub total_bases_output: u64,
}
808
#[cfg(feature = "std")]
impl PairedTrimReport {
    /// Number of pairs in which exactly one mate passed.
    pub fn orphans(&self) -> usize {
        let PairedTrimReport {
            r1_only_passed,
            r2_only_passed,
            ..
        } = self;
        r1_only_passed + r2_only_passed
    }

    /// Fraction of input pairs in which both mates survived (0.0 when there was no input).
    pub fn survival_rate(&self) -> f64 {
        match self.total_input {
            0 => 0.0,
            total => self.both_passed as f64 / total as f64,
        }
    }
}
824
#[cfg(feature = "std")]
impl TrimPipeline {
    /// Process a single read pair through the pipeline.
    ///
    /// Each mate is run through [`process`](Self::process) independently (R1
    /// first). `policy` only decides whether a lone surviving mate is kept.
    pub fn process_paired(
        &self,
        r1: &FastqRecord,
        r2: &FastqRecord,
        policy: OrphanPolicy,
    ) -> PairedTrimResult {
        match (self.process(r1), self.process(r2), policy) {
            (Some(first), Some(second), _) => PairedTrimResult::BothPassed(first, second),
            (Some(first), None, OrphanPolicy::KeepFirst) => PairedTrimResult::OnlyFirst(first),
            (None, Some(second), OrphanPolicy::KeepSecond) => PairedTrimResult::OnlySecond(second),
            _ => PairedTrimResult::Dropped,
        }
    }

    /// Process a batch of pairs, keeping only those where both reads pass.
    ///
    /// Equivalent to [`OrphanPolicy::DropBoth`]. R2 is not processed at all
    /// once R1 has been rejected.
    pub fn process_paired_batch(
        &self,
        pairs: &[PairedFastqRecord],
    ) -> Vec<PairedFastqRecord> {
        let mut survivors = Vec::new();
        for pair in pairs {
            let Some(first) = self.process(pair.r1()) else { continue };
            if let Some(second) = self.process(pair.r2()) {
                survivors.push(PairedFastqRecord::new_unchecked(first, second));
            }
        }
        survivors
    }

    /// Process a batch of pairs and collect detailed statistics.
    ///
    /// Both mates of every pair are always processed, so orphan counts are exact.
    pub fn process_paired_batch_with_stats(
        &self,
        pairs: &[PairedFastqRecord],
    ) -> PairedTrimReport {
        let mut report = PairedTrimReport {
            kept: Vec::new(),
            total_input: pairs.len(),
            both_passed: 0,
            r1_only_passed: 0,
            r2_only_passed: 0,
            both_failed: 0,
            total_bases_input: 0,
            total_bases_output: 0,
        };

        for pair in pairs {
            let (first, second) = (pair.r1(), pair.r2());
            report.total_bases_input += first.sequence().len() as u64;
            report.total_bases_input += second.sequence().len() as u64;

            match (self.process(first), self.process(second)) {
                (Some(a), Some(b)) => {
                    report.total_bases_output += a.sequence().len() as u64;
                    report.total_bases_output += b.sequence().len() as u64;
                    report.both_passed += 1;
                    report.kept.push(PairedFastqRecord::new_unchecked(a, b));
                }
                (Some(_), None) => report.r1_only_passed += 1,
                (None, Some(_)) => report.r2_only_passed += 1,
                (None, None) => report.both_failed += 1,
            }
        }

        report
    }
}
924
925// ---------------------------------------------------------------------------
926// Tests
927// ---------------------------------------------------------------------------
928
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DnaSequence;

    /// Helper to create a FastqRecord for testing.
    fn make_record(seq: &[u8], quals: &[u8]) -> FastqRecord {
        let sequence = DnaSequence::new(seq).unwrap();
        let quality = QualityScores::from_raw(quals.to_vec());
        FastqRecord::new("test".into(), None, sequence, quality).unwrap()
    }

    // --- Sliding window ---

    #[test]
    fn sliding_window_all_high() {
        let q = &[30, 30, 30, 30, 30, 30, 30, 30];
        let r = trim_sliding_window(q, 4, 15.0);
        assert_eq!(r, TrimRange { start: 0, end: 8 });
    }

    #[test]
    fn sliding_window_mid_drop() {
        // Quality drops in the middle. Window means (size 4):
        //   start 0: [30, 30, 30, 30] = 30.0
        //   start 1: [30, 30, 30,  5] = 23.75
        //   start 2: [30, 30,  5,  5] = 17.5
        //   start 3: [30,  5,  5,  5] = 11.25  <- first window below 15
        // Only bounds are checked: the cut keeps at least the Q30 prefix and
        // ends no later than the first failing window.
        let q = &[30, 30, 30, 30, 5, 5, 5, 5];
        let r = trim_sliding_window(q, 4, 15.0);
        assert!(r.end <= 7);
        assert!(r.end >= 4);
    }

    #[test]
    fn sliding_window_immediate_drop() {
        let q = &[2, 2, 2, 2];
        let r = trim_sliding_window(q, 4, 15.0);
        assert_eq!(r, TrimRange { start: 0, end: 0 });
    }

    #[test]
    fn sliding_window_window_1() {
        // Window size 1 = per-base threshold
        let q = &[30, 30, 5, 30];
        let r = trim_sliding_window(q, 1, 15.0);
        assert_eq!(r, TrimRange { start: 0, end: 2 });
    }

    #[test]
    fn sliding_window_last_window_drop() {
        // Window means (size 4): starts 0..=3 give 30, 30, 23, 16 (all >= 15).
        // Only the last window, start 4 = [30, 2, 2, 2] with mean 9, falls
        // below 15, so at least the final base must be trimmed.
        let q = &[30, 30, 30, 30, 30, 2, 2, 2];
        let r = trim_sliding_window(q, 4, 15.0);
        assert!(r.end < 8);
    }

    // --- Leading / trailing ---

    #[test]
    fn leading_no_trim() {
        let q = &[30, 30, 30, 30];
        assert_eq!(trim_leading(q, 20), TrimRange { start: 0, end: 4 });
    }

    #[test]
    fn leading_partial() {
        let q = &[2, 5, 30, 30];
        assert_eq!(trim_leading(q, 20), TrimRange { start: 2, end: 4 });
    }

    #[test]
    fn trailing_no_trim() {
        let q = &[30, 30, 30, 30];
        assert_eq!(trim_trailing(q, 20), TrimRange { start: 0, end: 4 });
    }

    #[test]
    fn trailing_partial() {
        let q = &[30, 30, 5, 2];
        assert_eq!(trim_trailing(q, 20), TrimRange { start: 0, end: 2 });
    }

    #[test]
    fn leading_trailing_combined() {
        let q = &[2, 5, 30, 30, 5, 2];
        let r1 = trim_leading(q, 20);
        let r2 = trim_trailing(q, 20);
        let combined = intersect_ranges(&[r1, r2]);
        assert_eq!(combined, TrimRange { start: 2, end: 4 });
    }

    // --- BWA quality trim ---

    #[test]
    fn bwa_clean_read() {
        let q = &[30, 30, 30, 30, 30];
        let r = trim_quality_3prime(q, 20);
        assert_eq!(r.end, 5);
    }

    #[test]
    fn bwa_3prime_ramp_down() {
        let q = &[30, 30, 30, 10, 5, 2];
        let r = trim_quality_3prime(q, 20);
        assert!(r.end <= 3);
    }

    #[test]
    fn bwa_all_low() {
        let q = &[2, 2, 2, 2];
        let r = trim_quality_3prime(q, 20);
        assert_eq!(r.end, 0);
    }

    #[test]
    fn bwa_isolated_low_base() {
        // One low base among high — BWA should keep most of the read
        let q = &[30, 30, 5, 30, 30];
        let r = trim_quality_3prime(q, 20);
        assert_eq!(r.end, 5);
    }

    // --- Adapter detection ---

    #[test]
    fn adapter_exact_match() {
        // Read ends with the adapter prefix, which starts at position 16.
        let seq = b"ACGTACGTACGTAACCAGATCGGAAGAG";
        let cut = find_adapter_3prime(seq, adapters::TRUSEQ_PREFIX, 0);
        assert_eq!(cut, 16);
    }

    #[test]
    fn adapter_partial_3prime() {
        // The 32-base read ends with only the first 12 bases ("AGATCGGAAGAG")
        // of the 33-base adapter; that partial adapter starts at position 20.
        let adapter = b"AGATCGGAAGAGCACACGTCTGAACTCCAGTCA";
        let seq = b"ACGTACGTACGTACGTAACCAGATCGGAAGAG";
        let cut = find_adapter_3prime(seq, adapter, 0);
        assert_eq!(cut, 20);
    }

    #[test]
    fn adapter_one_mismatch() {
        // Adapter with one mismatch
        let seq = b"ACGTACGTACGTAACCAGATCGGAATAG";
        // Mismatch at position 25 (G→T in "AAGAG" → "AATAG")
        let cut = find_adapter_3prime(seq, adapters::TRUSEQ_PREFIX, 1);
        assert_eq!(cut, 16);
    }

    #[test]
    fn adapter_too_many_mismatches() {
        let seq = b"ACGTACGTACGTAACCNNNNNNNNNNN";
        let cut = find_adapter_3prime(seq, adapters::TRUSEQ_PREFIX, 1);
        assert_eq!(cut, seq.len());
    }

    #[test]
    fn adapter_no_adapter() {
        let seq = b"ACGTACGTACGTACGT";
        let cut = find_adapter_3prime(seq, adapters::TRUSEQ_PREFIX, 1);
        assert_eq!(cut, seq.len());
    }

    // --- Shannon entropy ---

    #[test]
    fn entropy_homopolymer() {
        let e = shannon_entropy(b"AAAAAAAAAA");
        assert!(e.abs() < 1e-10);
    }

    #[test]
    fn entropy_equiprobable() {
        let e = shannon_entropy(b"ACGTACGTACGTACGT");
        assert!((e - 2.0).abs() < 1e-10);
    }

    #[test]
    fn entropy_dinucleotide() {
        let e = shannon_entropy(b"ACACACACACACACAC");
        assert!((e - 1.0).abs() < 1e-10);
    }

    #[test]
    fn entropy_empty() {
        assert_eq!(shannon_entropy(b""), 0.0);
    }

    // --- Filter functions ---

    #[test]
    fn filter_length_pass() {
        let r = make_record(b"ACGTACGT", &[30; 8]);
        assert!(filter_by_length(&r, 4, 100).is_some());
    }

    #[test]
    fn filter_length_too_short() {
        let r = make_record(b"ACG", &[30; 3]);
        assert!(filter_by_length(&r, 4, 100).is_none());
    }

    #[test]
    fn filter_length_too_long() {
        let r = make_record(b"ACGTACGT", &[30; 8]);
        assert!(filter_by_length(&r, 1, 4).is_none());
    }

    #[test]
    fn filter_quality_pass() {
        let r = make_record(b"ACGT", &[30, 30, 30, 30]);
        assert!(filter_by_quality(&r, 20.0).is_some());
    }

    #[test]
    fn filter_quality_fail() {
        let r = make_record(b"ACGT", &[5, 5, 5, 5]);
        assert!(filter_by_quality(&r, 20.0).is_none());
    }

    #[test]
    fn filter_complexity_pass() {
        let r = make_record(b"ACGTACGTACGTACGT", &[30; 16]);
        assert!(filter_low_complexity(&r, 1.5).is_some());
    }

    #[test]
    fn filter_complexity_fail() {
        let r = make_record(b"AAAAAAAAAAAAAAAA", &[30; 16]);
        assert!(filter_low_complexity(&r, 1.0).is_none());
    }

    // --- TrimRange intersection ---

    #[test]
    fn intersect_overlapping() {
        let r = intersect_ranges(&[
            TrimRange { start: 0, end: 8 },
            TrimRange { start: 2, end: 10 },
        ]);
        assert_eq!(r, TrimRange { start: 2, end: 8 });
    }

    #[test]
    fn intersect_non_overlapping() {
        let r = intersect_ranges(&[
            TrimRange { start: 0, end: 3 },
            TrimRange { start: 5, end: 10 },
        ]);
        assert!(r.is_empty());
    }

    // --- apply_trim ---

    #[test]
    fn apply_trim_basic() {
        let r = make_record(b"ACGTACGT", &[10, 20, 30, 40, 30, 20, 10, 5]);
        let trimmed = apply_trim(&r, TrimRange { start: 2, end: 6 }).unwrap();
        assert_eq!(trimmed.sequence().as_bytes(), b"GTAC");
        assert_eq!(trimmed.quality().as_slice(), &[30, 40, 30, 20]);
    }

    #[test]
    fn apply_trim_empty_range() {
        let r = make_record(b"ACGT", &[30; 4]);
        assert!(apply_trim(&r, TrimRange { start: 5, end: 2 }).is_none());
    }

    // --- Pipeline integration ---

    #[test]
    fn pipeline_noop() {
        let pipeline = TrimPipeline::new();
        let r = make_record(b"ACGTACGT", &[30; 8]);
        let result = pipeline.process(&r).unwrap();
        assert_eq!(result.sequence().as_bytes(), b"ACGTACGT");
    }

    #[test]
    fn pipeline_full() {
        let pipeline = TrimPipeline::new()
            .leading(20)
            .trailing(20)
            .sliding_window(4, 15.0)
            .min_length(2);

        let r = make_record(b"ACGTACGT", &[5, 30, 30, 30, 30, 30, 30, 5]);
        let result = pipeline.process(&r).unwrap();
        // Leading trims first base (Q5 < 20), trailing trims last (Q5 < 20)
        assert_eq!(result.sequence().as_bytes(), b"CGTACG");
    }

    #[test]
    fn pipeline_everything_filtered() {
        let pipeline = TrimPipeline::new().min_mean_quality(40.0);
        let r = make_record(b"ACGT", &[10, 10, 10, 10]);
        assert!(pipeline.process(&r).is_none());
    }

    #[test]
    fn pipeline_batch_stats() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let records = vec![
            make_record(b"ACGT", &[30, 30, 30, 30]), // passes
            make_record(b"ACGT", &[5, 5, 5, 5]),     // filtered by quality
            make_record(b"ACGT", &[25, 25, 25, 25]), // passes
        ];
        let report = pipeline.process_batch_with_stats(&records);
        assert_eq!(report.total_input, 3);
        assert_eq!(report.kept.len(), 2);
        assert_eq!(report.filtered_by_quality, 1);
    }

    // --- Edge cases ---

    #[test]
    fn empty_sequence() {
        let seq = DnaSequence::new(b"").unwrap();
        let qual = QualityScores::from_raw(vec![]);
        let r = FastqRecord::new("empty".into(), None, seq, qual).unwrap();

        assert_eq!(trim_sliding_window(&[], 4, 15.0), TrimRange { start: 0, end: 0 });
        assert!(TrimPipeline::new().process(&r).is_some());
    }

    #[test]
    fn single_base() {
        let r = make_record(b"A", &[30]);
        let pipeline = TrimPipeline::new().min_length(1);
        assert!(pipeline.process(&r).is_some());
    }

    #[test]
    fn uniform_quality() {
        let q = &[20, 20, 20, 20, 20];
        assert_eq!(trim_sliding_window(q, 4, 20.0), TrimRange { start: 0, end: 5 });
        assert_eq!(trim_leading(q, 20), TrimRange { start: 0, end: 5 });
        assert_eq!(trim_trailing(q, 20), TrimRange { start: 0, end: 5 });
    }
}
1271
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::types::DnaSequence;
    use proptest::prelude::*;

    /// Strategy yielding a random ACGT read of length `1..=max_len` together
    /// with a same-length vector of quality scores in `0..=41`.
    fn dna_and_quality(max_len: usize) -> impl Strategy<Value = (Vec<u8>, Vec<u8>)> {
        (1..=max_len).prop_flat_map(|len| {
            let seq = proptest::collection::vec(
                prop_oneof![Just(b'A'), Just(b'C'), Just(b'G'), Just(b'T')],
                len,
            );
            let qual = proptest::collection::vec(0..=41u8, len);
            (seq, qual)
        })
    }

    proptest! {
        // Trimming may shorten or drop a read, but never lengthen it.
        #[test]
        fn trimmed_never_longer(
            (seq, qual) in dna_and_quality(200)
        ) {
            let record = {
                let s = DnaSequence::new(&seq).unwrap();
                // `qual` is not used again, so move it instead of cloning.
                let q = QualityScores::from_raw(qual);
                FastqRecord::new("test".into(), None, s, q).unwrap()
            };
            let pipeline = TrimPipeline::new()
                .leading(10)
                .trailing(10)
                .sliding_window(4, 15.0);
            if let Some(trimmed) = pipeline.process(&record) {
                prop_assert!(trimmed.sequence().len() <= record.sequence().len());
            }
        }

        // The intersection of two ranges is never inverted; when both inputs
        // are non-empty it starts no earlier than either input and ends no
        // later than both (or is empty).
        #[test]
        fn intersect_valid_subrange(
            s1 in 0..50usize,
            e1 in 0..50usize,
            s2 in 0..50usize,
            e2 in 0..50usize,
        ) {
            let r1 = TrimRange { start: s1, end: e1 };
            let r2 = TrimRange { start: s2, end: e2 };
            let result = intersect_ranges(&[r1, r2]);
            prop_assert!(result.start <= result.end);
            if !r1.is_empty() && !r2.is_empty() {
                prop_assert!(result.start >= s1.max(s2));
                prop_assert!(result.end <= e1.min(e2).max(result.start));
            }
        }
    }
}
1326
#[cfg(test)]
#[cfg(feature = "std")]
mod paired_tests {
    use super::*;
    use crate::paired::PairedFastqRecord;
    use crate::types::DnaSequence;

    /// Helper to create a FastqRecord for testing.
    fn make_record(seq: &[u8], quals: &[u8]) -> FastqRecord {
        let sequence = DnaSequence::new(seq).unwrap();
        let quality = QualityScores::from_raw(quals.to_vec());
        FastqRecord::new("test".into(), None, sequence, quality).unwrap()
    }

    /// Helper to create a read pair from two (sequence, quality) inputs.
    fn make_pair(
        seq1: &[u8],
        quals1: &[u8],
        seq2: &[u8],
        quals2: &[u8],
    ) -> PairedFastqRecord {
        PairedFastqRecord::new_unchecked(make_record(seq1, quals1), make_record(seq2, quals2))
    }

    #[test]
    fn paired_both_pass() {
        let pipeline = TrimPipeline::new().min_mean_quality(10.0);
        let r1 = make_record(b"ACGT", &[30; 4]);
        let r2 = make_record(b"TGCA", &[30; 4]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::DropBoth) {
            PairedTrimResult::BothPassed(_, _) => {}
            other => panic!("expected BothPassed, got {:?}", other),
        }
    }

    #[test]
    fn paired_r1_fails_drop_both() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let r1 = make_record(b"ACGT", &[5; 4]);
        let r2 = make_record(b"TGCA", &[30; 4]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::DropBoth) {
            PairedTrimResult::Dropped => {}
            other => panic!("expected Dropped, got {:?}", other),
        }
    }

    #[test]
    fn paired_r1_fails_keep_second() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let r1 = make_record(b"ACGT", &[5; 4]);
        let r2 = make_record(b"TGCA", &[30; 4]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::KeepSecond) {
            PairedTrimResult::OnlySecond(_) => {}
            other => panic!("expected OnlySecond, got {:?}", other),
        }
    }

    #[test]
    fn paired_r2_fails_keep_first() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let r1 = make_record(b"ACGT", &[30; 4]);
        let r2 = make_record(b"TGCA", &[5; 4]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::KeepFirst) {
            PairedTrimResult::OnlyFirst(_) => {}
            other => panic!("expected OnlyFirst, got {:?}", other),
        }
    }

    #[test]
    fn paired_r2_fails_drop_both() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let r1 = make_record(b"ACGT", &[30; 4]);
        let r2 = make_record(b"TGCA", &[5; 4]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::DropBoth) {
            PairedTrimResult::Dropped => {}
            other => panic!("expected Dropped, got {:?}", other),
        }
    }

    #[test]
    fn paired_both_fail() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let r1 = make_record(b"ACGT", &[5; 4]);
        let r2 = make_record(b"TGCA", &[5; 4]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::KeepFirst) {
            PairedTrimResult::Dropped => {}
            other => panic!("expected Dropped, got {:?}", other),
        }
    }

    #[test]
    fn paired_batch_drop_both() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let pairs = vec![
            make_pair(b"ACGT", &[30; 4], b"TGCA", &[30; 4]),
            make_pair(b"ACGT", &[5; 4], b"TGCA", &[30; 4]),
        ];
        let kept = pipeline.process_paired_batch(&pairs);
        assert_eq!(kept.len(), 1);
        // The survivor must be the all-Q30 pair, with both mates intact.
        assert_eq!(kept[0].r1().sequence().as_bytes(), b"ACGT");
        assert_eq!(kept[0].r2().sequence().as_bytes(), b"TGCA");
        assert_eq!(kept[0].r1().quality().as_slice(), &[30; 4]);
        assert_eq!(kept[0].r2().quality().as_slice(), &[30; 4]);
    }

    #[test]
    fn paired_batch_empty() {
        let pipeline = TrimPipeline::new();
        let pairs: Vec<PairedFastqRecord> = vec![];
        assert!(pipeline.process_paired_batch(&pairs).is_empty());
    }

    #[test]
    fn paired_batch_stats() {
        let pipeline = TrimPipeline::new().min_mean_quality(20.0);
        let pairs = vec![
            make_pair(b"ACGT", &[30; 4], b"TGCA", &[30; 4]), // both pass
            make_pair(b"ACGT", &[5; 4], b"TGCA", &[30; 4]),  // r1 fails
            make_pair(b"ACGT", &[30; 4], b"TGCA", &[5; 4]),  // r2 fails
            make_pair(b"ACGT", &[5; 4], b"TGCA", &[5; 4]),   // both fail
        ];
        let report = pipeline.process_paired_batch_with_stats(&pairs);
        assert_eq!(report.total_input, 4);
        assert_eq!(report.both_passed, 1);
        assert_eq!(report.r1_only_passed, 1);
        assert_eq!(report.r2_only_passed, 1);
        assert_eq!(report.both_failed, 1);
        assert_eq!(report.orphans(), 2);
        assert!((report.survival_rate() - 0.25).abs() < 1e-10);
        assert_eq!(report.kept.len(), 1);
        // Input counts every read (4 pairs x 8 bases); output counts only the
        // kept pair (orphaned survivors are discarded).
        assert_eq!(report.total_bases_input, 32);
        assert_eq!(report.total_bases_output, 8);
    }

    #[test]
    fn paired_batch_stats_bases() {
        let pipeline = TrimPipeline::new();
        let pairs = vec![make_pair(b"ACGTACGT", &[30; 8], b"TGCA", &[30; 4])];
        let report = pipeline.process_paired_batch_with_stats(&pairs);
        assert_eq!(report.total_bases_input, 12);
        assert_eq!(report.total_bases_output, 12);
    }

    #[test]
    fn paired_batch_stats_empty() {
        let pipeline = TrimPipeline::new();
        let pairs: Vec<PairedFastqRecord> = vec![];
        let report = pipeline.process_paired_batch_with_stats(&pairs);
        assert_eq!(report.total_input, 0);
        assert_eq!(report.both_passed, 0);
        assert!(report.survival_rate().abs() < 1e-10);
    }

    #[test]
    fn paired_full_pipeline() {
        let pipeline = TrimPipeline::new()
            .adapter(b"AGATCGGAAGAG")
            .leading(3)
            .trailing(3)
            .sliding_window(4, 15.0)
            .min_length(4);

        let r1 = make_record(b"ACGTACGTACGTACGT", &[30; 16]);
        let r2 = make_record(b"TGCATGCATGCATGCA", &[30; 16]);
        match pipeline.process_paired(&r1, &r2, OrphanPolicy::DropBoth) {
            PairedTrimResult::BothPassed(_, _) => {}
            other => panic!("expected BothPassed, got {:?}", other),
        }
    }
}
1482
#[cfg(test)]
#[cfg(feature = "std")]
mod paired_proptests {
    use super::*;
    use crate::types::DnaSequence;
    use proptest::prelude::*;

    /// Strategy yielding a random ACGT read of length `1..=max_len` paired with
    /// a same-length vector of quality scores in `0..=41`.
    fn dna_and_quality(max_len: usize) -> impl Strategy<Value = (Vec<u8>, Vec<u8>)> {
        (1..=max_len).prop_flat_map(|len| {
            let bases = proptest::collection::vec(
                prop_oneof![Just(b'A'), Just(b'C'), Just(b'G'), Just(b'T')],
                len,
            );
            let phred = proptest::collection::vec(0..=41u8, len);
            (bases, phred)
        })
    }

    /// Build a record named "test" from raw sequence bytes and quality scores,
    /// panicking if either is rejected.
    fn build_record(seq: &[u8], qual: Vec<u8>) -> FastqRecord {
        FastqRecord::new(
            "test".into(),
            None,
            DnaSequence::new(seq).unwrap(),
            QualityScores::from_raw(qual),
        )
        .unwrap()
    }

    proptest! {
        // Paired trimming never makes either surviving mate longer than its input.
        #[test]
        fn paired_trimmed_never_longer(
            (seq1, qual1) in dna_and_quality(100),
            (seq2, qual2) in dna_and_quality(100),
        ) {
            let r1 = build_record(&seq1, qual1);
            let r2 = build_record(&seq2, qual2);
            let (len1, len2) = (r1.sequence().len(), r2.sequence().len());

            let pipeline = TrimPipeline::new()
                .leading(10)
                .trailing(10)
                .sliding_window(4, 15.0);

            match pipeline.process_paired(&r1, &r2, OrphanPolicy::KeepFirst) {
                PairedTrimResult::BothPassed(t1, t2) => {
                    prop_assert!(t1.sequence().len() <= len1);
                    prop_assert!(t2.sequence().len() <= len2);
                }
                PairedTrimResult::OnlyFirst(t1) => {
                    prop_assert!(t1.sequence().len() <= len1);
                }
                _ => {}
            }
        }
    }
}