jam_rs/
compare.rs

1use crate::signature::Signature;
2use crate::sketch::Sketch;
3use anyhow::anyhow;
4use anyhow::Result;
5use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
6use serde::{Deserialize, Serialize};
7use std::{
8    fmt::{self, Display, Formatter},
9    ops::DerefMut,
10    sync::Mutex,
11};
12
13#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
14pub struct CompareResult {
15    pub from_name: String,
16    pub to_name: String,
17    pub num_common: usize,
18    pub num_kmers: usize,
19    pub option_num_skipped: Option<usize>,
20    pub reverse: bool,
21    pub estimated_containment: f64,
22}
23
24impl Display for CompareResult {
25    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
26        if self.reverse {
27            write!(
28                f,
29                "{}\t{}\t{}\t{}\t{}\t{}\t{}",
30                self.to_name,
31                self.from_name,
32                self.num_common,
33                self.num_kmers,
34                self.num_common as f64 / self.num_kmers as f64 * 100.0, // Percent
35                self.estimated_containment,
36                self.option_num_skipped.unwrap_or(0)
37            )?;
38            Ok(())
39        } else {
40            write!(
41                f,
42                "{}\t{}\t{}\t{}\t{}\t{}\t{}",
43                self.from_name,
44                self.to_name,
45                self.num_common,
46                self.num_kmers,
47                self.num_common as f64 / self.num_kmers as f64 * 100.0,
48                self.estimated_containment,
49                self.option_num_skipped.unwrap_or(0)
50            )
51        }
52    }
53}
54
55pub struct MultiComp {
56    from: Vec<Sketch>,
57    to: Vec<Sketch>,
58    results: Vec<CompareResult>,
59    threads: usize,
60    kmer_size: u8,
61    cutoff: f64,
62    use_stats: bool,
63    gc_bounds: Option<(u8, u8)>,
64}
65
66impl MultiComp {
67    pub fn new(
68        mut from: Vec<Signature>,
69        mut to: Vec<Signature>,
70        threads: usize,
71        cutoff: f64,
72        use_stats: bool,
73        gc_bounds: Option<(u8, u8)>,
74    ) -> Result<Self> {
75        let kmer_size = from
76            .first()
77            .ok_or_else(|| anyhow!("Empty from list"))?
78            .kmer_size;
79
80        Ok(MultiComp {
81            from: from.iter_mut().map(|e| e.collapse()).collect(),
82            to: to.iter_mut().map(|e| e.collapse()).collect(),
83            results: Vec::new(),
84            threads,
85            kmer_size,
86            cutoff,
87            use_stats,
88            gc_bounds,
89        })
90    }
91
92    pub fn compare(&mut self) -> Result<()> {
93        let pool = rayon::ThreadPoolBuilder::new()
94            .num_threads(self.threads)
95            .build()?;
96
97        let results = Mutex::new(Vec::new());
98
99        pool.install(|| {
100            self.from.par_iter().try_for_each(|origin| {
101                self.to.par_iter().try_for_each(|target| {
102                    if target.kmer_size != self.kmer_size || origin.kmer_size != self.kmer_size {
103                        return Err(anyhow!(
104                            "Kmer sizes do not match, expected: {}, got: {}",
105                            self.kmer_size,
106                            origin.kmer_size
107                        ));
108                    }
109                    let mut comparator =
110                        Comparator::new(origin, target, self.use_stats, self.gc_bounds);
111                    comparator.compare()?;
112                    results
113                        .lock()
114                        .unwrap()
115                        .deref_mut()
116                        .push(comparator.finalize());
117                    Ok::<(), anyhow::Error>(())
118                })
119            })
120        })?;
121
122        self.results = results.into_inner().unwrap();
123        Ok(())
124    }
125
126    pub fn finalize(self) -> Vec<CompareResult> {
127        self.results
128            .into_iter()
129            .filter(|e| e.num_common as f64 / e.num_kmers as f64 * 100.0 > self.cutoff)
130            .collect()
131    }
132}
133
134pub struct Comparator<'a> {
135    larger: &'a Sketch,
136    smaller: &'a Sketch,
137    num_kmers: usize,
138    num_common: usize,
139    num_skipped: usize,
140    reverse: bool,
141    use_stats: bool,
142    gc_bounds: Option<(u8, u8)>,
143}
144
145impl<'a> Comparator<'a> {
146    pub fn new(
147        sketch_a: &'a Sketch,
148        sketch_b: &'a Sketch,
149        use_stats: bool,
150        gc_bounds: Option<(u8, u8)>,
151    ) -> Self {
152        let (larger, smaller, reverse) = if sketch_a.hashes.len() >= sketch_b.hashes.len() {
153            // DATABASE, INPUT -> Reverse = false
154            (sketch_a, sketch_b, false)
155        } else {
156            // INPUT, DATABASE -> Reverse = true
157            (sketch_b, sketch_a, true)
158        };
159        Comparator {
160            larger,
161            smaller,
162            num_kmers: 0,
163            num_common: 0,
164            num_skipped: 0,
165            reverse,
166            use_stats,
167            gc_bounds,
168        }
169    }
170
171    // Stats handling:
172    // GC & Size for the original contig are stored in the Stats struct
173    // This comparison is always in relation to the query sketch
174    // If reverse is true, the query sketch is the larger sketch
175    #[inline]
176    pub fn compare(&mut self) -> Result<()> {
177        if self.use_stats {
178            for (hash, stats) in &self.smaller.hashes {
179                let smaller_stats = stats.as_ref().ok_or_else(|| anyhow!("Missing stats"))?;
180                self.num_kmers += 1;
181                if let Some(stats) = self.larger.hashes.get(hash) {
182                    let larger_stats = stats.as_ref().ok_or_else(|| anyhow!("Missing stats"))?;
183                    if self.reverse {
184                        if !larger_stats.compare(smaller_stats, self.gc_bounds) {
185                            self.num_skipped += 1;
186                        } else {
187                            self.num_common += 1;
188                        }
189                    } else if !smaller_stats.compare(larger_stats, self.gc_bounds) {
190                        self.num_skipped += 1;
191                    } else {
192                        self.num_common += 1;
193                    }
194                };
195            }
196        } else {
197            for hash in self.smaller.hashes.keys() {
198                self.num_kmers += 1;
199                if self.larger.hashes.contains_key(hash) {
200                    self.num_common += 1;
201                };
202            }
203        }
204        Ok(())
205    }
206
207    pub fn finalize(self) -> CompareResult {
208        // Eg 0.1
209        let larger_fraction = self.larger.num_kmers as f64 / self.larger.max_kmers as f64;
210        // Eg 1.0
211        let smaller_fraction = self.smaller.num_kmers as f64 / self.smaller.max_kmers as f64;
212        // How much smaller is the smaller sketch
213        let fraction = if larger_fraction < smaller_fraction {
214            smaller_fraction / larger_fraction
215        } else {
216            larger_fraction / smaller_fraction
217        };
218        let estimated_containment =
219            self.num_common as f64 / self.num_kmers as f64 * fraction * 100.0;
220
221        CompareResult {
222            from_name: self.larger.name.clone(),
223            to_name: self.smaller.name.clone(),
224            num_kmers: self.num_kmers,
225            num_common: self.num_common,
226            option_num_skipped: if self.use_stats {
227                Some(self.num_skipped)
228            } else {
229                None
230            },
231            reverse: self.reverse,
232            estimated_containment,
233        }
234    }
235
236    #[allow(dead_code)]
237    pub fn reset(&mut self) {
238        self.num_kmers = 0;
239        self.num_common = 0;
240        self.num_skipped = 0;
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::{Comparator, MultiComp};
    use crate::{
        compare::CompareResult,
        signature::Signature,
        sketch::{Sketch, Stats},
    };

    /// Wraps a single sketch in a signature with the settings shared by the
    /// multi-comparison test.
    fn signature_of(file_name: &str, sketch: Sketch) -> Signature {
        Signature {
            file_name: file_name.to_string(),
            sketches: vec![sketch],
            algorithm: crate::cli::HashAlgorithms::Ahash,
            kmer_size: 21,
            max_hash: u64::MAX,
        }
    }

    /// Two sketches sharing two of three hashes, compared without stats.
    #[test]
    fn test_comp_without_stats() {
        let sketch_a = Sketch {
            name: "a".to_string(),
            hashes: vec![(1, None), (2, None), (3, None)]
                .into_iter()
                .collect(),
            num_kmers: 3,
            max_kmers: 10,
            kmer_size: 21,
        };
        let sketch_b = Sketch {
            name: "b".to_string(),
            hashes: vec![(1, None), (2, None), (4, None)]
                .into_iter()
                .collect(),
            num_kmers: 3,
            max_kmers: 10,
            kmer_size: 21,
        };

        let mut comparator = Comparator::new(&sketch_a, &sketch_b, false, None);
        comparator.compare().unwrap();
        let result = comparator.finalize();

        assert_eq!(result.num_kmers, 3);
        assert_eq!(result.num_common, 2);
        assert_eq!(result.estimated_containment, 66.66666666666666);
        assert_eq!(result.option_num_skipped, None);

        let constructed_result = CompareResult {
            from_name: "a".to_string(),
            to_name: "b".to_string(),
            num_kmers: 3,
            num_common: 2,
            option_num_skipped: None,
            reverse: false,
            estimated_containment: 66.66666666666666,
        };
        assert_eq!(result, constructed_result);
    }

    /// One signature against one signature with stats enabled: two of the
    /// four shared hashes are expected to be skipped by the stats check.
    #[test]
    fn test_multi_comp() {
        let sketch_a = Sketch {
            name: "a".to_string(),
            hashes: vec![
                (1, Some(Stats::new(3, 20))),
                (2, Some(Stats::new(3, 20))),
                (3, Some(Stats::new(3, 20))),
                (4, Some(Stats::new(3, 20))),
            ]
            .into_iter()
            .collect(),
            num_kmers: 4,
            max_kmers: 10,
            kmer_size: 21,
        };
        let sketch_b = Sketch {
            name: "b".to_string(),
            hashes: vec![
                (1, Some(Stats::new(5, 20))),
                (2, Some(Stats::new(3, 20))),
                (3, Some(Stats::new(2, 30))),
                (4, Some(Stats::new(2, 60))),
            ]
            .into_iter()
            .collect(),
            num_kmers: 4,
            max_kmers: 10,
            kmer_size: 21,
        };

        let mut multi = MultiComp::new(
            vec![signature_of("test", sketch_a)],
            vec![signature_of("test2", sketch_b)],
            1,
            0.0,
            true,
            Some((10, 10)),
        )
        .unwrap();

        multi.compare().unwrap();
        let res = multi.finalize();

        assert_eq!(res.len(), 1);
        let expected = CompareResult {
            from_name: "test".to_string(),
            to_name: "test2".to_string(),
            num_kmers: 4,
            num_common: 2,
            option_num_skipped: Some(2),
            reverse: false,
            estimated_containment: 50.0,
        };
        assert_eq!(res[0], expected);

        assert_eq!(res[0].to_string(), "test\ttest2\t2\t4\t50\t50\t2");
    }
}