sourmash/index/
mod.rs

1//! # Indexing structures for fast similarity search
2//!
3//! An index organizes signatures to allow for fast similarity search.
4//! Some indices also support containment searches.
5
6pub mod linear;
7
8#[cfg(not(target_arch = "wasm32"))]
9#[cfg(feature = "branchwater")]
10pub mod revindex;
11
12pub mod search;
13
14use std::path::Path;
15
16use getset::{CopyGetters, Getters, Setters};
17use log::trace;
18use serde::{Deserialize, Serialize};
19use stats::{median, stddev};
20use typed_builder::TypedBuilder;
21
22use crate::ani_utils::{ani_ci_from_containment, ani_from_containment};
23use crate::encodings::Idx;
24use crate::index::search::{search_minhashes, search_minhashes_containment};
25use crate::prelude::*;
26use crate::selection::Selection;
27use crate::signature::SigsTrait;
28use crate::sketch::minhash::KmerMinHash;
29use crate::storage::SigStore;
30use crate::Error::CannotUpsampleScaled;
31use crate::Result;
32
33#[derive(TypedBuilder, CopyGetters, Getters, Setters, Serialize, Deserialize, Debug, PartialEq)]
34pub struct GatherResult {
35    #[getset(get_copy = "pub")]
36    intersect_bp: u64,
37
38    #[getset(get_copy = "pub")]
39    f_orig_query: f64,
40
41    #[getset(get_copy = "pub")]
42    f_match: f64,
43
44    #[getset(get_copy = "pub")]
45    f_unique_to_query: f64,
46
47    #[getset(get_copy = "pub")]
48    f_unique_weighted: f64,
49
50    #[getset(get_copy = "pub")]
51    average_abund: f64,
52
53    #[getset(get_copy = "pub")]
54    median_abund: f64,
55
56    #[getset(get_copy = "pub")]
57    std_abund: f64,
58
59    #[getset(get = "pub")]
60    filename: String,
61
62    #[getset(get = "pub")]
63    name: String,
64
65    #[getset(get = "pub")]
66    md5: String,
67
68    #[serde(skip)]
69    match_: SigStore,
70
71    #[getset(get_copy = "pub")]
72    f_match_orig: f64,
73
74    #[getset(get_copy = "pub")]
75    unique_intersect_bp: u64,
76
77    #[getset(get_copy = "pub")]
78    gather_result_rank: u32,
79
80    #[getset(get_copy = "pub")]
81    remaining_bp: u64,
82
83    #[getset(get_copy = "pub")]
84    n_unique_weighted_found: u64,
85
86    #[getset(get_copy = "pub")]
87    total_weighted_hashes: u64,
88
89    #[getset(get_copy = "pub")]
90    sum_weighted_found: u64,
91
92    #[getset(get_copy = "pub")]
93    query_containment_ani: f64,
94
95    #[getset(get_copy = "pub")]
96    #[serde(skip_serializing_if = "Option::is_none")]
97    query_containment_ani_ci_low: Option<f64>,
98
99    #[getset(get_copy = "pub")]
100    #[serde(skip_serializing_if = "Option::is_none")]
101    query_containment_ani_ci_high: Option<f64>,
102
103    #[getset(get_copy = "pub")]
104    match_containment_ani: f64,
105
106    #[getset(get_copy = "pub")]
107    #[serde(skip_serializing_if = "Option::is_none")]
108    match_containment_ani_ci_low: Option<f64>,
109
110    #[getset(get_copy = "pub")]
111    #[serde(skip_serializing_if = "Option::is_none")]
112    match_containment_ani_ci_high: Option<f64>,
113
114    #[getset(get_copy = "pub")]
115    average_containment_ani: f64,
116
117    #[getset(get_copy = "pub")]
118    max_containment_ani: f64,
119}
120
121impl GatherResult {
122    pub fn get_match(&self) -> Signature {
123        self.match_.clone().into()
124    }
125}
126
// Counter keyed by `Idx`.
// NOTE(review): not used in this part of the file; presumably tallies hits per
// dataset index for the index implementations in the submodules — confirm.
type SigCounter = counter::Counter<Idx>;
128
129pub trait Index<'a> {
130    type Item: Comparable<Self::Item>;
131    //type SignatureIterator: Iterator<Item = Self::Item>;
132
133    fn find<F>(&self, search_fn: F, sig: &Self::Item, threshold: f64) -> Result<Vec<&Self::Item>>
134    where
135        F: Fn(&dyn Comparable<Self::Item>, &Self::Item, f64) -> bool,
136    {
137        Ok(self
138            .signature_refs()
139            .into_iter()
140            .flat_map(|node| {
141                if search_fn(&node, sig, threshold) {
142                    Some(node)
143                } else {
144                    None
145                }
146            })
147            .collect())
148    }
149
150    fn search(
151        &self,
152        sig: &Self::Item,
153        threshold: f64,
154        containment: bool,
155    ) -> Result<Vec<&Self::Item>> {
156        if containment {
157            self.find(search_minhashes_containment, sig, threshold)
158        } else {
159            self.find(search_minhashes, sig, threshold)
160        }
161    }
162
163    //fn gather(&self, sig: &Self::Item, threshold: f64) -> Result<Vec<&Self::Item>>;
164
165    fn insert(&mut self, node: Self::Item) -> Result<()>;
166
167    fn batch_insert(&mut self, nodes: Vec<Self::Item>) -> Result<()> {
168        for node in nodes {
169            self.insert(node)?;
170        }
171
172        Ok(())
173    }
174
175    fn save<P: AsRef<Path>>(&self, path: P) -> Result<()>;
176
177    fn load<P: AsRef<Path>>(path: P) -> Result<()>;
178
179    fn signatures(&self) -> Vec<Self::Item>;
180
181    fn signature_refs(&self) -> Vec<&Self::Item>;
182
183    fn len(&self) -> usize {
184        self.signature_refs().len()
185    }
186
187    fn is_empty(&self) -> bool {
188        self.len() == 0
189    }
190
191    /*
192    fn iter_signatures(&self) -> Self::SignatureIterator;
193    */
194}
195
196impl<N, L> Comparable<L> for &N
197where
198    N: Comparable<L>,
199{
200    fn similarity(&self, other: &L) -> f64 {
201        (*self).similarity(other)
202    }
203
204    fn containment(&self, other: &L) -> f64 {
205        (*self).containment(other)
206    }
207}
208
209#[allow(clippy::too_many_arguments)]
210pub fn calculate_gather_stats(
211    orig_query: &KmerMinHash,
212    remaining_query: KmerMinHash,
213    match_sig: SigStore,
214    match_size: usize,
215    gather_result_rank: u32,
216    sum_weighted_found: u64,
217    total_weighted_hashes: u64,
218    calc_abund_stats: bool,
219    calc_ani_ci: bool,
220    confidence: Option<f64>,
221) -> Result<(GatherResult, (Vec<u64>, u64))> {
222    // get match_mh
223    let match_mh = match_sig.minhash().expect("cannot retrieve sketch");
224
225    // it's ok to downsample match, but query is often big and repeated,
226    // so we do not allow downsampling of query in this function.
227    if match_mh.scaled() > remaining_query.scaled() {
228        return Err(CannotUpsampleScaled);
229    }
230
231    let match_mh = match_mh
232        .clone()
233        .downsample_scaled(remaining_query.scaled())
234        .expect("cannot downsample match");
235
236    // calculate intersection
237    let isect = match_mh
238        .intersection(&remaining_query)
239        .expect("could not do intersection");
240    let isect_size = isect.0.len();
241    trace!("isect_size: {isect_size}");
242    trace!("query.size: {}", remaining_query.size());
243
244    //bp remaining in subtracted query
245    let remaining_bp =
246        (remaining_query.size() - isect_size) as u64 * remaining_query.scaled() as u64;
247
248    // stats for this match vs original query
249    let (intersect_orig, _) = match_mh.intersection_size(orig_query).unwrap();
250    let intersect_bp = match_mh.scaled() as u64 * intersect_orig;
251    let f_orig_query = intersect_orig as f64 / orig_query.size() as f64;
252    let f_match_orig = intersect_orig as f64 / match_mh.size() as f64;
253
254    // stats for this match vs current (subtracted) query
255    let f_match = match_size as f64 / match_mh.size() as f64;
256    let unique_intersect_bp = match_mh.scaled() as u64 * isect_size as u64;
257    let f_unique_to_query = isect_size as f64 / orig_query.size() as f64;
258
259    // // get ANI values
260    let ksize = match_mh.ksize() as f64;
261    let query_containment_ani = ani_from_containment(f_orig_query, ksize);
262    let match_containment_ani = ani_from_containment(f_match_orig, ksize);
263    let mut query_containment_ani_ci_low = None;
264    let mut query_containment_ani_ci_high = None;
265    let mut match_containment_ani_ci_low = None;
266    let mut match_containment_ani_ci_high = None;
267
268    if calc_ani_ci {
269        let n_unique_kmers = match_mh.n_unique_kmers();
270        let (qani_low, qani_high) = ani_ci_from_containment(
271            f_unique_to_query,
272            ksize,
273            match_mh.scaled(),
274            n_unique_kmers,
275            confidence,
276        )?;
277        query_containment_ani_ci_low = Some(qani_low);
278        query_containment_ani_ci_high = Some(qani_high);
279
280        let (mani_low, mani_high) = ani_ci_from_containment(
281            f_match,
282            ksize,
283            match_mh.scaled(),
284            n_unique_kmers,
285            confidence,
286        )?;
287        match_containment_ani_ci_low = Some(mani_low);
288        match_containment_ani_ci_high = Some(mani_high);
289    }
290
291    let average_containment_ani = (query_containment_ani + match_containment_ani) / 2.0;
292    let max_containment_ani = f64::max(query_containment_ani, match_containment_ani);
293
294    // set up non-abundance weighted values
295    let mut f_unique_weighted = f_unique_to_query;
296    let mut average_abund = 1.0;
297    let mut median_abund = 1.0;
298    let mut std_abund = 0.0;
299    // should these default to the unweighted numbers?
300    let mut n_unique_weighted_found = 0;
301    let mut sum_total_weighted_found = 0;
302
303    // If abundance, calculate abund-related metrics (vs current query)
304    if calc_abund_stats {
305        // take abunds from subtracted query
306        let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&remaining_query) {
307            Ok((abunds, unique_weighted_found)) => (abunds, unique_weighted_found),
308            Err(e) => {
309                return Err(e);
310            }
311        };
312
313        n_unique_weighted_found = unique_weighted_found;
314        sum_total_weighted_found = sum_weighted_found + n_unique_weighted_found;
315        f_unique_weighted = n_unique_weighted_found as f64 / total_weighted_hashes as f64;
316
317        average_abund = n_unique_weighted_found as f64 / abunds.len() as f64;
318
319        // todo: try to avoid clone for these?
320        median_abund = median(abunds.iter().cloned()).unwrap();
321        std_abund = stddev(abunds.iter().cloned());
322    }
323
324    let result = GatherResult::builder()
325        .intersect_bp(intersect_bp)
326        .f_orig_query(f_orig_query)
327        .f_match(f_match)
328        .f_unique_to_query(f_unique_to_query)
329        .f_unique_weighted(f_unique_weighted)
330        .average_abund(average_abund)
331        .median_abund(median_abund)
332        .std_abund(std_abund)
333        .filename(match_sig.filename())
334        .name(match_sig.name())
335        .md5(match_sig.md5sum())
336        .match_(match_sig)
337        .f_match_orig(f_match_orig)
338        .unique_intersect_bp(unique_intersect_bp)
339        .gather_result_rank(gather_result_rank)
340        .remaining_bp(remaining_bp)
341        .n_unique_weighted_found(n_unique_weighted_found)
342        .query_containment_ani(query_containment_ani)
343        .query_containment_ani_ci_low(query_containment_ani_ci_low)
344        .query_containment_ani_ci_high(query_containment_ani_ci_high)
345        .match_containment_ani_ci_low(match_containment_ani_ci_low)
346        .match_containment_ani_ci_high(match_containment_ani_ci_high)
347        .match_containment_ani(match_containment_ani)
348        .average_containment_ani(average_containment_ani)
349        .max_containment_ani(max_containment_ani)
350        .sum_weighted_found(sum_total_weighted_found)
351        .total_weighted_hashes(total_weighted_hashes)
352        .build();
353    Ok((result, isect))
354}
355
#[cfg(test)]
mod test_calculate_gather_stats {
    use super::*;
    use crate::cmd::ComputeParameters;
    use crate::encodings::HashFunctions;
    use crate::signature::Signature;
    use crate::sketch::minhash::KmerMinHash;
    use crate::sketch::Sketch;
    // use std::f64::EPSILON;
    // TODO: use f64::EPSILON when we bump MSRV
    // Tolerance for the ANI comparisons below.
    const EPSILON: f64 = 0.01;

    /// End-to-end check of `calculate_gather_stats` on a small hand-built
    /// abundance-tracking match/query pair.
    #[test]
    fn test_calculate_gather_stats() {
        let scaled = 10;

        // Match sketch: four hashes shared with the query plus one (11) that
        // the query does not contain.
        let mut match_mh = KmerMinHash::new(scaled, 31, HashFunctions::Murmur64Dna, 42, true, 0);
        for (hash, abund) in [(1, 5), (3, 3), (5, 2), (8, 2), (11, 2)] {
            match_mh.add_hash_with_abundance(hash, abund);
        }

        // Wrap the sketch in a signature carrying a filename and a name.
        let params = ComputeParameters::builder()
            .ksizes(vec![31])
            .scaled(scaled)
            .build();
        let mut match_sig = Signature::from_params(&params);
        match_sig.reset_sketches();
        match_sig.push(Sketch::MinHash(match_mh.clone()));
        match_sig.set_filename("match-filename");
        match_sig.set_name("match-name");

        eprintln!("num_sketches: {:?}", match_sig.size());
        eprintln!("match_md5: {:?}", match_sig.md5sum());

        // Query sketch: hashes 6 and 10 are absent from the match.
        let mut orig_query = KmerMinHash::new(scaled, 31, HashFunctions::Murmur64Dna, 42, true, 0);
        for (hash, abund) in [(1, 3), (3, 2), (5, 1), (6, 1), (8, 1), (10, 1)] {
            orig_query.add_hash_with_abundance(hash, abund);
        }

        // Nothing has been subtracted yet, so the remaining query is a copy of
        // the original one.
        let query = orig_query.clone();
        let total_weighted_hashes = orig_query.sum_abunds();

        let match_size = 4;
        let gather_result_rank = 0;
        let calc_abund_stats = true;
        let calc_ani_ci = false;
        let (result, _isect) = calculate_gather_stats(
            &orig_query,
            query,
            match_sig.into(),
            match_size,
            gather_result_rank,
            0,
            total_weighted_hashes.try_into().unwrap(),
            calc_abund_stats,
            calc_ani_ci,
            None,
        )
        .unwrap();

        // identity and bookkeeping fields
        assert_eq!(result.filename(), "match-filename");
        assert_eq!(result.name(), "match-name");
        assert_eq!(result.md5(), "f54b271a62fb7e2856e7b8a33e741b6e");
        assert_eq!(result.gather_result_rank, 0);
        assert_eq!(result.remaining_bp, 20);

        // results from match vs current query
        assert_eq!(result.f_match, 0.8);
        assert_eq!(result.unique_intersect_bp, 40);
        assert_eq!(result.f_unique_to_query, 4.0 / 6.0);
        eprintln!("{}", result.f_unique_weighted);
        assert_eq!(result.f_unique_weighted, 7. / 9.);
        assert_eq!(result.average_abund, 1.75);
        assert_eq!(result.median_abund, 1.5);
        assert_eq!(result.std_abund, 0.82915619758885);

        // results from match vs orig_query
        assert_eq!(result.intersect_bp, 40);
        assert_eq!(result.f_orig_query, 4.0 / 6.0);
        assert_eq!(result.f_match_orig, 4.0 / 5.0);

        // ANI values are only compared to within EPSILON
        let close_to = |got: f64, want: f64| (got - want).abs() < EPSILON;
        assert!(close_to(result.average_containment_ani, 0.98991665567826));
        assert!(close_to(result.match_containment_ani, 0.9928276657672302));
        assert!(close_to(result.query_containment_ani, 0.9870056455892898));
        assert!(close_to(result.max_containment_ani, 0.9928276657672302));

        // abundance-weighted totals
        assert_eq!(result.total_weighted_hashes, 9);
        assert_eq!(result.n_unique_weighted_found, 7);
        assert_eq!(result.sum_weighted_found, 7);
    }
}