1use crate::signature::Signature;
2use crate::sketch::Sketch;
3use anyhow::anyhow;
4use anyhow::Result;
5use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
6use serde::{Deserialize, Serialize};
7use std::{
8 fmt::{self, Display, Formatter},
9 ops::DerefMut,
10 sync::Mutex,
11};
12
13#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
14pub struct CompareResult {
15 pub from_name: String,
16 pub to_name: String,
17 pub num_common: usize,
18 pub num_kmers: usize,
19 pub option_num_skipped: Option<usize>,
20 pub reverse: bool,
21 pub estimated_containment: f64,
22}
23
24impl Display for CompareResult {
25 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
26 if self.reverse {
27 write!(
28 f,
29 "{}\t{}\t{}\t{}\t{}\t{}\t{}",
30 self.to_name,
31 self.from_name,
32 self.num_common,
33 self.num_kmers,
34 self.num_common as f64 / self.num_kmers as f64 * 100.0, self.estimated_containment,
36 self.option_num_skipped.unwrap_or(0)
37 )?;
38 Ok(())
39 } else {
40 write!(
41 f,
42 "{}\t{}\t{}\t{}\t{}\t{}\t{}",
43 self.from_name,
44 self.to_name,
45 self.num_common,
46 self.num_kmers,
47 self.num_common as f64 / self.num_kmers as f64 * 100.0,
48 self.estimated_containment,
49 self.option_num_skipped.unwrap_or(0)
50 )
51 }
52 }
53}
54
55pub struct MultiComp {
56 from: Vec<Sketch>,
57 to: Vec<Sketch>,
58 results: Vec<CompareResult>,
59 threads: usize,
60 kmer_size: u8,
61 cutoff: f64,
62 use_stats: bool,
63 gc_bounds: Option<(u8, u8)>,
64}
65
66impl MultiComp {
67 pub fn new(
68 mut from: Vec<Signature>,
69 mut to: Vec<Signature>,
70 threads: usize,
71 cutoff: f64,
72 use_stats: bool,
73 gc_bounds: Option<(u8, u8)>,
74 ) -> Result<Self> {
75 let kmer_size = from
76 .first()
77 .ok_or_else(|| anyhow!("Empty from list"))?
78 .kmer_size;
79
80 Ok(MultiComp {
81 from: from.iter_mut().map(|e| e.collapse()).collect(),
82 to: to.iter_mut().map(|e| e.collapse()).collect(),
83 results: Vec::new(),
84 threads,
85 kmer_size,
86 cutoff,
87 use_stats,
88 gc_bounds,
89 })
90 }
91
92 pub fn compare(&mut self) -> Result<()> {
93 let pool = rayon::ThreadPoolBuilder::new()
94 .num_threads(self.threads)
95 .build()?;
96
97 let results = Mutex::new(Vec::new());
98
99 pool.install(|| {
100 self.from.par_iter().try_for_each(|origin| {
101 self.to.par_iter().try_for_each(|target| {
102 if target.kmer_size != self.kmer_size || origin.kmer_size != self.kmer_size {
103 return Err(anyhow!(
104 "Kmer sizes do not match, expected: {}, got: {}",
105 self.kmer_size,
106 origin.kmer_size
107 ));
108 }
109 let mut comparator =
110 Comparator::new(origin, target, self.use_stats, self.gc_bounds);
111 comparator.compare()?;
112 results
113 .lock()
114 .unwrap()
115 .deref_mut()
116 .push(comparator.finalize());
117 Ok::<(), anyhow::Error>(())
118 })
119 })
120 })?;
121
122 self.results = results.into_inner().unwrap();
123 Ok(())
124 }
125
126 pub fn finalize(self) -> Vec<CompareResult> {
127 self.results
128 .into_iter()
129 .filter(|e| e.num_common as f64 / e.num_kmers as f64 * 100.0 > self.cutoff)
130 .collect()
131 }
132}
133
134pub struct Comparator<'a> {
135 larger: &'a Sketch,
136 smaller: &'a Sketch,
137 num_kmers: usize,
138 num_common: usize,
139 num_skipped: usize,
140 reverse: bool,
141 use_stats: bool,
142 gc_bounds: Option<(u8, u8)>,
143}
144
145impl<'a> Comparator<'a> {
146 pub fn new(
147 sketch_a: &'a Sketch,
148 sketch_b: &'a Sketch,
149 use_stats: bool,
150 gc_bounds: Option<(u8, u8)>,
151 ) -> Self {
152 let (larger, smaller, reverse) = if sketch_a.hashes.len() >= sketch_b.hashes.len() {
153 (sketch_a, sketch_b, false)
155 } else {
156 (sketch_b, sketch_a, true)
158 };
159 Comparator {
160 larger,
161 smaller,
162 num_kmers: 0,
163 num_common: 0,
164 num_skipped: 0,
165 reverse,
166 use_stats,
167 gc_bounds,
168 }
169 }
170
171 #[inline]
176 pub fn compare(&mut self) -> Result<()> {
177 if self.use_stats {
178 for (hash, stats) in &self.smaller.hashes {
179 let smaller_stats = stats.as_ref().ok_or_else(|| anyhow!("Missing stats"))?;
180 self.num_kmers += 1;
181 if let Some(stats) = self.larger.hashes.get(hash) {
182 let larger_stats = stats.as_ref().ok_or_else(|| anyhow!("Missing stats"))?;
183 if self.reverse {
184 if !larger_stats.compare(smaller_stats, self.gc_bounds) {
185 self.num_skipped += 1;
186 } else {
187 self.num_common += 1;
188 }
189 } else if !smaller_stats.compare(larger_stats, self.gc_bounds) {
190 self.num_skipped += 1;
191 } else {
192 self.num_common += 1;
193 }
194 };
195 }
196 } else {
197 for hash in self.smaller.hashes.keys() {
198 self.num_kmers += 1;
199 if self.larger.hashes.contains_key(hash) {
200 self.num_common += 1;
201 };
202 }
203 }
204 Ok(())
205 }
206
207 pub fn finalize(self) -> CompareResult {
208 let larger_fraction = self.larger.num_kmers as f64 / self.larger.max_kmers as f64;
210 let smaller_fraction = self.smaller.num_kmers as f64 / self.smaller.max_kmers as f64;
212 let fraction = if larger_fraction < smaller_fraction {
214 smaller_fraction / larger_fraction
215 } else {
216 larger_fraction / smaller_fraction
217 };
218 let estimated_containment =
219 self.num_common as f64 / self.num_kmers as f64 * fraction * 100.0;
220
221 CompareResult {
222 from_name: self.larger.name.clone(),
223 to_name: self.smaller.name.clone(),
224 num_kmers: self.num_kmers,
225 num_common: self.num_common,
226 option_num_skipped: if self.use_stats {
227 Some(self.num_skipped)
228 } else {
229 None
230 },
231 reverse: self.reverse,
232 estimated_containment,
233 }
234 }
235
236 #[allow(dead_code)]
237 pub fn reset(&mut self) {
238 self.num_kmers = 0;
239 self.num_common = 0;
240 self.num_skipped = 0;
241 }
242}
243
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{CompareResult, Comparator, MultiComp};
    use crate::{
        cli::HashAlgorithms,
        signature::Signature,
        sketch::{Sketch, Stats},
    };

    /// Wraps one sketch into a 21-mer Ahash signature named `file_name`.
    fn single_sketch_signature(file_name: &str, sketch: Sketch) -> Signature {
        Signature {
            file_name: file_name.to_string(),
            sketches: vec![sketch],
            algorithm: HashAlgorithms::Ahash,
            kmer_size: 21,
            max_hash: u64::MAX,
        }
    }

    /// Two stat-less sketches sharing hashes 1 and 2 out of three each.
    #[test]
    fn test_comp_without_stats() {
        let sketch_a = Sketch {
            name: "a".to_string(),
            hashes: {
                let mut hashes = HashMap::default();
                hashes.extend([(1, None), (2, None), (3, None)]);
                hashes
            },
            num_kmers: 3,
            max_kmers: 10,
            kmer_size: 21,
        };
        let sketch_b = Sketch {
            name: "b".to_string(),
            hashes: {
                let mut hashes = HashMap::default();
                hashes.extend([(1, None), (2, None), (4, None)]);
                hashes
            },
            num_kmers: 3,
            max_kmers: 10,
            kmer_size: 21,
        };

        let mut comparator = Comparator::new(&sketch_a, &sketch_b, false, None);
        comparator.compare().unwrap();
        let result = comparator.finalize();

        assert_eq!(result.num_kmers, 3);
        assert_eq!(result.num_common, 2);
        assert_eq!(result.estimated_containment, 66.66666666666666);
        assert_eq!(result.option_num_skipped, None);
        assert_eq!(
            result,
            CompareResult {
                from_name: "a".to_string(),
                to_name: "b".to_string(),
                num_kmers: 3,
                num_common: 2,
                option_num_skipped: None,
                reverse: false,
                estimated_containment: 66.66666666666666,
            }
        );
    }

    /// Stats-based comparison through `MultiComp`: two of the four shared hashes
    /// pass the stats check, two are skipped.
    #[test]
    fn test_multi_comp() {
        let sketch_a = Sketch {
            name: "a".to_string(),
            hashes: {
                let mut hashes = HashMap::default();
                hashes.extend([
                    (1, Some(Stats::new(3, 20))),
                    (2, Some(Stats::new(3, 20))),
                    (3, Some(Stats::new(3, 20))),
                    (4, Some(Stats::new(3, 20))),
                ]);
                hashes
            },
            num_kmers: 4,
            max_kmers: 10,
            kmer_size: 21,
        };
        let sketch_b = Sketch {
            name: "b".to_string(),
            hashes: {
                let mut hashes = HashMap::default();
                hashes.extend([
                    (1, Some(Stats::new(5, 20))),
                    (2, Some(Stats::new(3, 20))),
                    (3, Some(Stats::new(2, 30))),
                    (4, Some(Stats::new(2, 60))),
                ]);
                hashes
            },
            num_kmers: 4,
            max_kmers: 10,
            kmer_size: 21,
        };

        let mut multi = MultiComp::new(
            vec![single_sketch_signature("test", sketch_a)],
            vec![single_sketch_signature("test2", sketch_b)],
            1,
            0.0,
            true,
            Some((10, 10)),
        )
        .unwrap();

        multi.compare().unwrap();
        let results = multi.finalize();

        assert_eq!(results.len(), 1);
        let expected = CompareResult {
            from_name: "test".to_string(),
            to_name: "test2".to_string(),
            num_kmers: 4,
            num_common: 2,
            option_num_skipped: Some(2),
            reverse: false,
            estimated_containment: 50.0,
        };
        assert_eq!(results[0], expected);
        assert_eq!(
            results[0].to_string(),
            "test\ttest2\t2\t4\t50\t50\t2".to_string()
        );
    }
}