1use std::cmp;
2use std::collections::HashMap;
3
4use crate::errors::FinchResult;
5use crate::serialization::Sketch;
6use crate::sketch_schemes::KmerCount;
7use crate::statistics::hist;
8
/// Settings that decide which k-mers are dropped from a sketch.
///
/// Filtering is only performed when `filter_on == Some(true)`; any other
/// value (`Some(false)` or `None`) leaves the k-mers untouched.
#[derive(Debug, Clone, PartialEq)]
pub struct FilterParams {
    /// Master switch; only `Some(true)` enables filtering.
    pub filter_on: Option<bool>,
    /// Inclusive `(min, max)` bounds on a k-mer's `count`; `None` leaves that
    /// side unbounded.
    pub abun_filter: (Option<u32>, Option<u32>),
    /// Fraction of the weighted count histogram used to estimate an extra
    /// lower abundance bound; values `<= 0` disable it.
    pub err_filter: f64,
    /// Minimum value of `min(extra_count, count - extra_count) / count` a
    /// k-mer (with `count >= 16`) must reach to be kept; values `<= 0` disable it.
    pub strand_filter: f64,
}
17
18impl FilterParams {
19 pub fn filter_sketch(&self, sketch: &mut Sketch) {
21 let mut filters_copy = self.clone();
24 filters_copy.filter_counts(&sketch.hashes);
25 sketch.filter_params.filter_on = self.filter_on;
29 sketch.filter_params.abun_filter = match self.abun_filter {
30 (Some(l), Some(h)) => (
31 Some(u32::max(l, sketch.filter_params.abun_filter.0.unwrap_or(0))),
32 Some(u32::min(
33 h,
34 sketch
35 .filter_params
36 .abun_filter
37 .1
38 .unwrap_or(u32::max_value()),
39 )),
40 ),
41 (Some(l), None) => (
42 Some(u32::max(l, sketch.filter_params.abun_filter.0.unwrap_or(0))),
43 None,
44 ),
45 (None, Some(h)) => (
46 None,
47 Some(u32::min(
48 h,
49 sketch
50 .filter_params
51 .abun_filter
52 .1
53 .unwrap_or(u32::max_value()),
54 )),
55 ),
56 (None, None) => (None, None),
57 };
58 sketch.filter_params.err_filter =
59 f64::max(sketch.filter_params.err_filter, self.err_filter);
60 sketch.filter_params.strand_filter =
61 f64::max(sketch.filter_params.strand_filter, self.strand_filter);
62 }
63
64 pub fn filter_counts(&mut self, hashes: &[KmerCount]) -> Vec<KmerCount> {
69 let filter_on = self.filter_on == Some(true);
70 let mut filtered_hashes = hashes.to_vec();
71
72 if filter_on && self.strand_filter > 0f64 {
73 filtered_hashes = filter_strands(&filtered_hashes, self.strand_filter);
74 }
75
76 if filter_on && self.err_filter > 0f64 {
77 let cutoff = guess_filter_threshold(&filtered_hashes, self.err_filter);
78 if let Some(v) = self.abun_filter.0 {
79 if cutoff > v {
81 self.abun_filter.0 = Some(cutoff);
82 }
83 } else {
84 self.abun_filter.0 = Some(cutoff);
86 }
87 }
88
89 if filter_on && (self.abun_filter.0.is_some() || self.abun_filter.1.is_some()) {
90 filtered_hashes =
91 filter_abundance(&filtered_hashes, self.abun_filter.0, self.abun_filter.1);
92 }
93
94 filtered_hashes
95 }
96
97 pub fn to_serialized(&self) -> HashMap<String, String> {
98 let mut filter_stats: HashMap<String, String> = HashMap::new();
99 if self.filter_on != Some(true) {
100 return filter_stats;
101 }
102
103 if self.strand_filter > 0f64 {
104 filter_stats.insert(String::from("strandFilter"), self.strand_filter.to_string());
105 }
106 if self.err_filter > 0f64 {
107 filter_stats.insert(String::from("errFilter"), self.err_filter.to_string());
108 }
109 if let Some(v) = self.abun_filter.0 {
110 filter_stats.insert(String::from("minCopies"), v.to_string());
111 }
112 if let Some(v) = self.abun_filter.1 {
113 filter_stats.insert(String::from("maxCopies"), v.to_string());
114 }
115 filter_stats
116 }
117
118 pub fn from_serialized(filters: &HashMap<String, String>) -> FinchResult<Self> {
119 let low_abun = if let Some(min_copies) = filters.get("minCopies") {
120 Some(min_copies.parse()?)
121 } else {
122 None
123 };
124 let high_abun = if let Some(max_copies) = filters.get("maxCopies") {
125 Some(max_copies.parse()?)
126 } else {
127 None
128 };
129 Ok(FilterParams {
130 filter_on: Some(!filters.is_empty()),
131 abun_filter: (low_abun, high_abun),
132 err_filter: filters
133 .get("errFilter")
134 .unwrap_or(&"0".to_string())
135 .parse()?,
136 strand_filter: filters
137 .get("strandFilter")
138 .unwrap_or(&"0".to_string())
139 .parse()?,
140 })
141 }
142}
143
144impl Default for FilterParams {
145 fn default() -> Self {
146 FilterParams {
147 filter_on: Some(false),
148 abun_filter: (None, None),
149 err_filter: 0.,
150 strand_filter: 0.,
151 }
152 }
153}
154
155pub fn guess_filter_threshold(sketch: &[KmerCount], filter_level: f64) -> u32 {
163 let hist_data = hist(sketch);
164 let total_counts = hist_data
165 .iter()
166 .enumerate()
167 .map(|t| (t.0 as u64 + 1) * t.1)
168 .sum::<u64>() as f64;
169 let cutoff_amt = filter_level * total_counts;
170
171 let mut wgt_cutoff: usize = 0;
174 let mut cum_count: u64 = 0;
175 for count in &hist_data {
176 cum_count += wgt_cutoff as u64 * *count as u64;
177 if cum_count as f64 > cutoff_amt {
178 break;
179 }
180 wgt_cutoff += 1;
181 }
182
183 if wgt_cutoff == 0 {
185 return 1;
186 }
187
188 let win_size = cmp::max(1, wgt_cutoff / 20);
190 let mut sum: u64 = hist_data[..win_size].iter().sum();
191 let mut lowest_val = sum;
192 let mut lowest_idx = win_size - 1;
193 for (i, j) in (0..wgt_cutoff - win_size).zip(win_size..wgt_cutoff) {
194 if sum <= lowest_val {
195 lowest_val = sum;
196 lowest_idx = j;
197 }
198 sum -= hist_data[i];
199 sum += hist_data[j];
200 }
201
202 lowest_idx as u32 + 1
203}
204
#[test]
fn test_guess_filter_threshold() {
    // Builds a k-mer with the given hash and count (no strand count, no label).
    let kc = |hash, count| KmerCount {
        hash,
        kmer: vec![],
        count,
        extra_count: 0,
        label: None,
    };

    // Empty and all-singleton sketches fall back to the minimum cutoff.
    assert_eq!(guess_filter_threshold(&[], 0.2), 1);
    assert_eq!(guess_filter_threshold(&[kc(1, 1)], 0.2), 1);
    assert_eq!(guess_filter_threshold(&[kc(1, 1), kc(2, 1)], 0.2), 1);

    // A lone count-1 k-mer far below the others.
    assert_eq!(guess_filter_threshold(&[kc(1, 1), kc(2, 9)], 0.2), 8);
    assert_eq!(
        guess_filter_threshold(&[kc(1, 1), kc(2, 10), kc(3, 10), kc(4, 9)], 0.1),
        8
    );

    // The search stops at the first bucket, so the minimum cutoff is returned.
    assert_eq!(
        guess_filter_threshold(&[kc(1, 1), kc(2, 1), kc(3, 2), kc(4, 4)], 0.1),
        1
    );

    // Requesting the whole weight on a single k-mer.
    assert_eq!(guess_filter_threshold(&[kc(2, 2)], 1.), 2);
}
336
337pub fn filter_abundance(
338 sketch: &[KmerCount],
339 low: Option<u32>,
340 high: Option<u32>,
341) -> Vec<KmerCount> {
342 let mut filtered = Vec::new();
343 let lo_threshold = low.unwrap_or(0u32);
344 let hi_threshold = high.unwrap_or(u32::max_value());
345 for kmer in sketch {
346 if lo_threshold <= kmer.count && kmer.count <= hi_threshold {
347 filtered.push(kmer.clone());
348 }
349 }
350 filtered
351}
352
#[test]
fn test_filter_abundance() {
    // Builds a k-mer with the given hash and count (no strand count, no label).
    let kc = |hash, count| KmerCount {
        hash,
        kmer: vec![],
        count,
        extra_count: 0,
        label: None,
    };

    // A lower bound equal to every count keeps everything, in order.
    let singletons = vec![kc(1, 1), kc(2, 1)];
    let filtered = filter_abundance(&singletons, Some(1), None);
    assert_eq!(filtered.len(), 2);
    assert_eq!(filtered[0].hash, 1);
    assert_eq!(filtered[1].hash, 2);

    let mixed = vec![kc(1, 1), kc(2, 10), kc(3, 10), kc(4, 9)];

    // Lower bound only.
    let filtered = filter_abundance(&mixed, Some(9), None);
    assert_eq!(filtered.len(), 3);
    assert_eq!(filtered[0].hash, 2);
    assert_eq!(filtered[1].hash, 3);
    assert_eq!(filtered[2].hash, 4);

    // Both bounds, inclusive on each side.
    let filtered = filter_abundance(&mixed, Some(2), Some(9));
    assert_eq!(filtered.len(), 1);
    assert_eq!(filtered[0].hash, 4);
}
416
417pub fn filter_strands(sketch: &[KmerCount], ratio_cutoff: f64) -> Vec<KmerCount> {
422 let mut filtered = Vec::new();
423 for kmer in sketch {
424 if kmer.count < 16 {
429 filtered.push(kmer.clone());
430 continue;
431 }
432
433 let lowest_strand_count = cmp::min(kmer.extra_count, kmer.count - kmer.extra_count);
435 if (lowest_strand_count as f64 / kmer.count as f64) >= ratio_cutoff {
436 filtered.push(kmer.clone());
437 }
438 }
439 filtered
440}
441
#[test]
fn test_filter_strands() {
    // Builds a k-mer with the given hash, total count and extra (strand) count.
    let kc = |hash, count, extra_count| KmerCount {
        hash,
        kmer: vec![],
        count,
        extra_count,
        label: None,
    };

    // Below the minimum count every k-mer is kept, however lopsided.
    let low_coverage = vec![kc(1, 10, 1), kc(2, 10, 2), kc(3, 10, 8), kc(4, 10, 9)];
    let filtered = filter_strands(&low_coverage, 0.15);
    assert_eq!(filtered.len(), 4);
    assert_eq!(filtered[0].hash, 1);
    assert_eq!(filtered[3].hash, 4);

    // At count 16 the ratio check applies: 1/16 and 2/16 fall below 0.15.
    let high_coverage = vec![kc(1, 16, 1), kc(2, 16, 2), kc(3, 16, 8), kc(4, 16, 9)];
    let filtered = filter_strands(&high_coverage, 0.15);
    assert_eq!(filtered.len(), 2);
    assert_eq!(filtered[0].hash, 3);
    assert_eq!(filtered[1].hash, 4);
}