// finch/filtering.rs

1use std::cmp;
2use std::collections::HashMap;
3
4use crate::errors::FinchResult;
5use crate::serialization::Sketch;
6use crate::sketch_schemes::KmerCount;
7use crate::statistics::hist;
8
/// Used to pass around filter options for sketching
///
/// Filtering is only performed when `filter_on` is `Some(true)`; with any other
/// value the remaining fields are carried along but not applied.
#[derive(Clone, Debug, PartialEq)]
pub struct FilterParams {
    /// Whether filtering is enabled; only `Some(true)` turns it on.
    pub filter_on: Option<bool>,
    /// Inclusive `(minimum, maximum)` count a kmer must have to be kept; `None`
    /// leaves that side unbounded. Serialized as `minCopies` / `maxCopies`.
    pub abun_filter: (Option<u32>, Option<u32>),
    /// Level passed to [`guess_filter_threshold`] to derive a minimum count that
    /// drops low-abundance (likely sequencing-error) kmers; a value of `0` (or
    /// less) disables it. Serialized as `errFilter`.
    pub err_filter: f64,
    /// Minimum fraction of a kmer's occurrences that must fall on its less-common
    /// strand to be kept (see [`filter_strands`]); a value of `0` (or less)
    /// disables it. Serialized as `strandFilter`.
    pub strand_filter: f64,
}
17
18impl FilterParams {
19    /// Filter the sketch according to these FilterParams
20    pub fn filter_sketch(&self, sketch: &mut Sketch) {
21        // make a copy of myself so any updates from the data don't change
22        // my own parameters
23        let mut filters_copy = self.clone();
24        filters_copy.filter_counts(&sketch.hashes);
25        // we need to update any parameters that are stricter than the
26        // ones in the existing `sketch.filter_params` to reflect additional
27        // filtering (if they were already more strict leave them be)
28        sketch.filter_params.filter_on = self.filter_on;
29        sketch.filter_params.abun_filter = match self.abun_filter {
30            (Some(l), Some(h)) => (
31                Some(u32::max(l, sketch.filter_params.abun_filter.0.unwrap_or(0))),
32                Some(u32::min(
33                    h,
34                    sketch
35                        .filter_params
36                        .abun_filter
37                        .1
38                        .unwrap_or(u32::max_value()),
39                )),
40            ),
41            (Some(l), None) => (
42                Some(u32::max(l, sketch.filter_params.abun_filter.0.unwrap_or(0))),
43                None,
44            ),
45            (None, Some(h)) => (
46                None,
47                Some(u32::min(
48                    h,
49                    sketch
50                        .filter_params
51                        .abun_filter
52                        .1
53                        .unwrap_or(u32::max_value()),
54                )),
55            ),
56            (None, None) => (None, None),
57        };
58        sketch.filter_params.err_filter =
59            f64::max(sketch.filter_params.err_filter, self.err_filter);
60        sketch.filter_params.strand_filter =
61            f64::max(sketch.filter_params.strand_filter, self.strand_filter);
62    }
63
64    /// Returns the filtered kmer counts.
65    ///
66    /// If the err filter determined a different low_abundance_filter update
67    /// self to that one.
68    pub fn filter_counts(&mut self, hashes: &[KmerCount]) -> Vec<KmerCount> {
69        let filter_on = self.filter_on == Some(true);
70        let mut filtered_hashes = hashes.to_vec();
71
72        if filter_on && self.strand_filter > 0f64 {
73            filtered_hashes = filter_strands(&filtered_hashes, self.strand_filter);
74        }
75
76        if filter_on && self.err_filter > 0f64 {
77            let cutoff = guess_filter_threshold(&filtered_hashes, self.err_filter);
78            if let Some(v) = self.abun_filter.0 {
79                // there's an existing filter so we only use this one if it's stricter
80                if cutoff > v {
81                    self.abun_filter.0 = Some(cutoff);
82                }
83            } else {
84                // no filter set so just use the one we determined
85                self.abun_filter.0 = Some(cutoff);
86            }
87        }
88
89        if filter_on && (self.abun_filter.0.is_some() || self.abun_filter.1.is_some()) {
90            filtered_hashes =
91                filter_abundance(&filtered_hashes, self.abun_filter.0, self.abun_filter.1);
92        }
93
94        filtered_hashes
95    }
96
97    pub fn to_serialized(&self) -> HashMap<String, String> {
98        let mut filter_stats: HashMap<String, String> = HashMap::new();
99        if self.filter_on != Some(true) {
100            return filter_stats;
101        }
102
103        if self.strand_filter > 0f64 {
104            filter_stats.insert(String::from("strandFilter"), self.strand_filter.to_string());
105        }
106        if self.err_filter > 0f64 {
107            filter_stats.insert(String::from("errFilter"), self.err_filter.to_string());
108        }
109        if let Some(v) = self.abun_filter.0 {
110            filter_stats.insert(String::from("minCopies"), v.to_string());
111        }
112        if let Some(v) = self.abun_filter.1 {
113            filter_stats.insert(String::from("maxCopies"), v.to_string());
114        }
115        filter_stats
116    }
117
118    pub fn from_serialized(filters: &HashMap<String, String>) -> FinchResult<Self> {
119        let low_abun = if let Some(min_copies) = filters.get("minCopies") {
120            Some(min_copies.parse()?)
121        } else {
122            None
123        };
124        let high_abun = if let Some(max_copies) = filters.get("maxCopies") {
125            Some(max_copies.parse()?)
126        } else {
127            None
128        };
129        Ok(FilterParams {
130            filter_on: Some(!filters.is_empty()),
131            abun_filter: (low_abun, high_abun),
132            err_filter: filters
133                .get("errFilter")
134                .unwrap_or(&"0".to_string())
135                .parse()?,
136            strand_filter: filters
137                .get("strandFilter")
138                .unwrap_or(&"0".to_string())
139                .parse()?,
140        })
141    }
142}
143
144impl Default for FilterParams {
145    fn default() -> Self {
146        FilterParams {
147            filter_on: Some(false),
148            abun_filter: (None, None),
149            err_filter: 0.,
150            strand_filter: 0.,
151        }
152    }
153}
154
155/// Determines a dynamic filtering threshold for low abundance kmers. The
156/// cutoff returned is the lowest number of counts that should be included
157/// in any final results.
158///
159/// Useful for removing, e.g. low-abundance kmers arising from sequencing
160/// errors
161///
162pub fn guess_filter_threshold(sketch: &[KmerCount], filter_level: f64) -> u32 {
163    let hist_data = hist(sketch);
164    let total_counts = hist_data
165        .iter()
166        .enumerate()
167        .map(|t| (t.0 as u64 + 1) * t.1)
168        .sum::<u64>() as f64;
169    let cutoff_amt = filter_level * total_counts;
170
171    // calculate the coverage that N% of the weighted data is above
172    // note wgt_cutoff is an index now *not* a number of counts
173    let mut wgt_cutoff: usize = 0;
174    let mut cum_count: u64 = 0;
175    for count in &hist_data {
176        cum_count += wgt_cutoff as u64 * *count as u64;
177        if cum_count as f64 > cutoff_amt {
178            break;
179        }
180        wgt_cutoff += 1;
181    }
182
183    // special case if the cutoff is the first value
184    if wgt_cutoff == 0 {
185        return 1;
186    }
187
188    // now find the minima within the window to the left
189    let win_size = cmp::max(1, wgt_cutoff / 20);
190    let mut sum: u64 = hist_data[..win_size].iter().sum();
191    let mut lowest_val = sum;
192    let mut lowest_idx = win_size - 1;
193    for (i, j) in (0..wgt_cutoff - win_size).zip(win_size..wgt_cutoff) {
194        if sum <= lowest_val {
195            lowest_val = sum;
196            lowest_idx = j;
197        }
198        sum -= hist_data[i];
199        sum += hist_data[j];
200    }
201
202    lowest_idx as u32 + 1
203}
204
#[test]
fn test_guess_filter_threshold() {
    // fixtures only differ in hash and count
    let kc = |hash, count| KmerCount {
        hash,
        kmer: vec![],
        count,
        extra_count: 0,
        label: None,
    };

    // degenerate inputs all fall back to the minimum cutoff
    assert_eq!(guess_filter_threshold(&[], 0.2), 1);
    assert_eq!(guess_filter_threshold(&[kc(1, 1)], 0.2), 1);
    assert_eq!(guess_filter_threshold(&[kc(1, 1), kc(2, 1)], 0.2), 1);

    // a lone low-count kmer next to a well-covered one
    assert_eq!(guess_filter_threshold(&[kc(1, 1), kc(2, 9)], 0.2), 8);
    assert_eq!(
        guess_filter_threshold(&[kc(1, 1), kc(2, 10), kc(3, 10), kc(4, 9)], 0.1),
        8
    );
    assert_eq!(
        guess_filter_threshold(&[kc(1, 1), kc(2, 1), kc(3, 2), kc(4, 4)], 0.1),
        1
    );

    // check that we don't overflow
    assert_eq!(guess_filter_threshold(&[kc(2, 2)], 1.), 2);
}
336
337pub fn filter_abundance(
338    sketch: &[KmerCount],
339    low: Option<u32>,
340    high: Option<u32>,
341) -> Vec<KmerCount> {
342    let mut filtered = Vec::new();
343    let lo_threshold = low.unwrap_or(0u32);
344    let hi_threshold = high.unwrap_or(u32::max_value());
345    for kmer in sketch {
346        if lo_threshold <= kmer.count && kmer.count <= hi_threshold {
347            filtered.push(kmer.clone());
348        }
349    }
350    filtered
351}
352
#[test]
fn test_filter_abundance() {
    let kc = |hash, count| KmerCount {
        hash,
        kmer: vec![],
        count,
        extra_count: 0,
        label: None,
    };
    let hashes = |kmers: Vec<KmerCount>| kmers.into_iter().map(|k| k.hash).collect::<Vec<_>>();

    // a minimum equal to every count keeps everything, in order
    let flat = [kc(1, 1), kc(2, 1)];
    assert_eq!(hashes(filter_abundance(&flat, Some(1), None)), [1, 2]);

    let mixed = [kc(1, 1), kc(2, 10), kc(3, 10), kc(4, 9)];
    // lower bound only
    assert_eq!(hashes(filter_abundance(&mixed, Some(9), None)), [2, 3, 4]);
    // both bounds are inclusive
    assert_eq!(hashes(filter_abundance(&mixed, Some(2), Some(9))), [4]);
}
416
417/// Filter out kmers that have a large abundance difference between being seen in the
418/// "forward" and "reverse" orientations (picked arbitrarily which is which).
419///
420/// These tend to be sequencing adapters.
421pub fn filter_strands(sketch: &[KmerCount], ratio_cutoff: f64) -> Vec<KmerCount> {
422    let mut filtered = Vec::new();
423    for kmer in sketch {
424        // "special-case" anything with fewer than 16 kmers -> these are too stochastic to accurately
425        // determine if something is an adapter or not. The odds of randomly picking less than 10%
426        // (0 or 1 reversed kmers) in 16 should be ~ 17 / 2 ** 16 or 1/4000 so we're avoiding
427        // removing "good" kmers
428        if kmer.count < 16 {
429            filtered.push(kmer.clone());
430            continue;
431        }
432
433        // check the forward/reverse ratio and only add if it's within bounds
434        let lowest_strand_count = cmp::min(kmer.extra_count, kmer.count - kmer.extra_count);
435        if (lowest_strand_count as f64 / kmer.count as f64) >= ratio_cutoff {
436            filtered.push(kmer.clone());
437        }
438    }
439    filtered
440}
441
#[test]
fn test_filter_strands() {
    let kc = |hash, count, extra_count| KmerCount {
        hash,
        kmer: vec![],
        count,
        extra_count,
        label: None,
    };
    let hashes = |kmers: Vec<KmerCount>| kmers.into_iter().map(|k| k.hash).collect::<Vec<_>>();

    // below the minimum count every kmer is kept regardless of strand skew
    let sparse = [kc(1, 10, 1), kc(2, 10, 2), kc(3, 10, 8), kc(4, 10, 9)];
    assert_eq!(hashes(filter_strands(&sparse, 0.15)), [1, 2, 3, 4]);

    // at the minimum count, heavily skewed kmers are removed
    let dense = [kc(1, 16, 1), kc(2, 16, 2), kc(3, 16, 8), kc(4, 16, 9)];
    assert_eq!(hashes(filter_strands(&dense, 0.15)), [3, 4]);
}