use std::cmp::Ordering;
use ndarray::Array2;
use crate::serialization::{Sketch, SketchDistance};
use crate::sketch_schemes::KmerCount;
/// Compares two sketches and reports their containment, Jaccard index and Mash distance.
///
/// When `old_mode` is set, the legacy [`old_distance`] estimator is used. Otherwise
/// [`raw_distance`] is used. It receives the smaller of the two sketches' scale factors
/// (the fourth element of `sketch_params.hash_info()`) when *both* sketches report one,
/// and `0.` otherwise.
///
/// The Mash distance is `-ln(2J / (1 + J)) / k`, where `J` is the Jaccard estimate and
/// `k` comes from the query sketch's parameters. The result is clamped into `[0, 1]`.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn distance(
    query_sketch: &Sketch,
    ref_sketch: &Sketch,
    old_mode: bool,
) -> Result<SketchDistance, &'static str> {
    let (containment, jaccard, common_hashes, total_hashes) = if old_mode {
        old_distance(&query_sketch.hashes, &ref_sketch.hashes)
    } else {
        // The reference's scale is only looked up when the query has one.
        let min_scale = query_sketch
            .sketch_params
            .hash_info()
            .3
            .and_then(|query_scale| {
                ref_sketch
                    .sketch_params
                    .hash_info()
                    .3
                    .map(|ref_scale| f64::min(query_scale, ref_scale))
            })
            .unwrap_or(0.);
        raw_distance(&query_sketch.hashes, &ref_sketch.hashes, min_scale)
    };

    let k = query_sketch.sketch_params.k() as f64;
    let unclamped = -((2.0 * jaccard) / (1.0 + jaccard)).ln() / k;
    // Bound with max-then-min rather than `clamp`: `clamp` would let a NaN estimate
    // through, while `f64::max` turns NaN into 0.
    let mash_distance = f64::min(1f64, f64::max(0f64, unclamped));

    Ok(SketchDistance {
        containment,
        jaccard,
        mash_distance,
        common_hashes,
        total_hashes,
        query: query_sketch.name.to_string(),
        reference: ref_sketch.name.to_string(),
    })
}
/// Compares two hash-sorted sketches and returns
/// `(containment, jaccard, common_hashes, total_hashes)`.
///
/// Both slices must be sorted by non-decreasing `hash`. This is checked only in debug
/// builds. They are merged in lockstep until either one is exhausted:
/// - `i` is the number of query hashes consumed and `j` the number of reference hashes
///   consumed.
/// - `common` counts hashes present in both.
///
/// When `scale > 0.`, any hashes left over in either slice that fall below
/// `u64::MAX / floor(1 / scale)` are also counted as consumed. They enter the union but
/// cannot be shared. (Presumably this is the cut-off a scaled sketch keeps hashes
/// under. NOTE(review): confirm against the scaled sketcher.) With `scale >= 1` the
/// divisor is clamped to 1, so the threshold is `u64::MAX` instead of dividing by zero.
///
/// The returned values are:
/// - `containment = common / j`, or `0.` when no reference hashes were consumed.
/// - `total_hashes = i + j - common`, the size of the compared union.
/// - `jaccard = common / total_hashes`, or `1.` when that union is empty (e.g. an empty
///   sketch compared with `scale == 0.`).
pub fn raw_distance(
    query_hashes: &[KmerCount],
    ref_hashes: &[KmerCount],
    scale: f64,
) -> (f64, f64, u64, u64) {
    /// Returns `true` when the hashes are in non-decreasing order.
    fn kmers_are_sorted(kmer_counts: &[KmerCount]) -> bool {
        kmer_counts
            .windows(2)
            .all(|pair| pair[0].hash <= pair[1].hash)
    }

    debug_assert!(kmers_are_sorted(query_hashes));
    debug_assert!(kmers_are_sorted(ref_hashes));

    let mut i: usize = 0;
    let mut j: usize = 0;
    let mut common: u64 = 0;
    while let (Some(query), Some(refer)) = (query_hashes.get(i), ref_hashes.get(j)) {
        match query.hash.cmp(&refer.hash) {
            Ordering::Less => i += 1,
            Ordering::Greater => j += 1,
            Ordering::Equal => {
                common += 1;
                i += 1;
                j += 1;
            }
        }
    }

    if scale > 0. {
        // `scale.recip() as u64` truncates to 0 for scale > 1 (and for +inf). Clamping
        // the divisor to 1 avoids a divide-by-zero panic. For 0 < scale <= 1 the
        // divisor is already >= 1, so the threshold is unchanged.
        let max_hash = u64::MAX / (scale.recip() as u64).max(1);
        // After the merge, i <= query_hashes.len() and j <= ref_hashes.len(), so these
        // slices are always in bounds.
        i += query_hashes[i..]
            .iter()
            .take_while(|kmer_count| kmer_count.hash < max_hash)
            .count();
        j += ref_hashes[j..]
            .iter()
            .take_while(|kmer_count| kmer_count.hash < max_hash)
            .count();
    }

    let containment = if j == 0 { 0. } else { common as f64 / j as f64 };
    // `common <= i` always holds, so this subtraction cannot underflow.
    let total = i as u64 - common + j as u64;
    let jaccard: f64 = if total == 0 {
        1.
    } else {
        common as f64 / total as f64
    };
    (containment, jaccard, common, total)
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a `KmerCount` list from bare hash values. Every other field is a fixed
    /// placeholder: an empty k-mer, `count` and `extra_count` of 1, and no label.
    fn kc(arr: &[u64]) -> Vec<KmerCount> {
        arr.iter()
            .map(|x| KmerCount {
                hash: *x,
                kmer: vec![],
                count: 1,
                extra_count: 1,
                label: None,
            })
            .collect()
    }

    proptest! {
        // Property: swapping query and reference leaves the unscaled result unchanged.
        // NOTE(review): containment is `common / (reference hashes consumed)` and is not
        // symmetric in general. `raw_distance([0, 1, 2], [1, 2])` yields 1.0 (see
        // test_raw_distance), but the swapped call yields 2/3. The property only holds
        // here because random `u64` vectors almost never share a hash. Consider comparing
        // only `(jaccard, common, total)` or generating overlapping inputs.
        #[test]
        fn test_raw_distance_commutes(mut query_hashes: Vec<u64>, mut ref_hashes: Vec<u64>) {
            // raw_distance requires hash-sorted input.
            query_hashes.sort();
            ref_hashes.sort();
            let lhs = kc(&query_hashes);
            let rhs = kc(&ref_hashes);
            prop_assert_eq!(raw_distance(&lhs, &rhs, 0.), raw_distance(&rhs, &lhs, 0.));
        }
    }

    /// Unscaled comparisons (`scale == 0.`): hashes left over after the shorter sketch
    /// runs out are ignored, and an empty comparison reports containment 0, Jaccard 1.
    #[test]
    fn test_raw_distance() {
        let (cont, jac, com, total) = raw_distance(&kc(&[0, 1, 2]), &kc(&[1, 2]), 0.);
        assert_eq!(cont, 2. / 2.);
        assert_eq!(jac, 2. / 3.);
        assert_eq!(com, 2);
        assert_eq!(total, 3);
        let (cont, jac, com, total) = raw_distance(&kc(&[0, 2]), &kc(&[1, 2]), 0.);
        assert_eq!(cont, 1. / 2.);
        assert_eq!(jac, 1. / 3.);
        assert_eq!(com, 1);
        assert_eq!(total, 3);
        // Disjoint sketches: the query is exhausted before any reference hash is
        // consumed, so containment takes the `j == 0` branch.
        let (cont, jac, com, total) = raw_distance(&kc(&[0, 1]), &kc(&[2, 3]), 0.);
        assert_eq!(cont, 0. / 2.);
        assert_eq!(jac, 0. / 2.);
        assert_eq!(com, 0);
        assert_eq!(total, 2);
        assert_eq!((0., 1., 0, 0), raw_distance(&kc(&[]), &kc(&[]), 0.));
        assert_eq!((0., 1., 0, 0), raw_distance(&kc(&[]), &kc(&[5]), 0.));
    }

    /// Scaled comparisons. With `scale = 1e-18` the cut-off `u64::MAX / floor(1 / scale)`
    /// works out to 18. Trailing hashes below it (15) are still counted after the merge
    /// ends; those at or above it (20) are not.
    #[test]
    fn test_raw_distance_scaled() {
        let (cont, jac, com, total) = raw_distance(&kc(&[10, 15, 20]), &kc(&[15, 20]), 1e-18);
        assert_eq!(cont, 2. / 2.);
        assert_eq!(jac, 2. / 3.);
        assert_eq!(com, 2);
        assert_eq!(total, 3);
        let (cont, jac, com, total) = raw_distance(&kc(&[5, 10, 15]), &kc(&[5, 10]), 1e-18);
        assert_eq!(cont, 2. / 2.);
        assert_eq!(jac, 2. / 3.);
        assert_eq!(com, 2);
        assert_eq!(total, 3);
        let (cont, jac, com, total) = raw_distance(&kc(&[5, 10, 15, 20]), &kc(&[5, 10]), 1e-18);
        assert_eq!(cont, 2. / 2.);
        assert_eq!(jac, 2. / 3.);
        assert_eq!(com, 2);
        assert_eq!(total, 3);
        // Trailing reference hashes below the cut-off count toward containment's denominator.
        let (cont, jac, com, total) = raw_distance(&kc(&[5, 10]), &kc(&[5, 10, 15, 20]), 1e-18);
        assert_eq!(cont, 2. / 3.);
        assert_eq!(jac, 2. / 3.);
        assert_eq!(com, 2);
        assert_eq!(total, 3);
    }

    /// Test-only alternative estimator: merges two sorted sketches for at most
    /// `sketch1.len()` steps, then pads `total` from whichever sketch still has hashes
    /// (capped at `sketch1.len()`). Returns `(common / i, common / total, common, total)`.
    /// Presumably this mirrors the loop described in the Mash paper, per its name.
    /// NOTE(review): confirm. The padding adds `len - 1` rather than the number of
    /// remaining hashes, and `common / i` is NaN when `i == 0`.
    fn mash_paper_distance(sketch2: &[KmerCount], sketch1: &[KmerCount]) -> (f64, f64, u64, u64) {
        let mut i: usize = 0;
        let mut j: usize = 0;
        let mut common: u64 = 0;
        let mut total: u64 = 0;
        let sketch_size = sketch1.len();
        while (total < sketch_size as u64) && (i < sketch1.len()) && (j < sketch2.len()) {
            if sketch1[i].hash < sketch2[j].hash {
                i += 1;
            } else if sketch2[j].hash < sketch1[i].hash {
                j += 1;
            } else {
                i += 1;
                j += 1;
                common += 1;
            }
            total += 1;
        }
        if total < sketch_size as u64 {
            if i < sketch1.len() {
                total += (sketch1.len() - 1) as u64;
            }
            if j < sketch2.len() {
                total += (sketch2.len() - 1) as u64;
            }
            if total > sketch_size as u64 {
                total = sketch_size as u64;
            }
        }
        let containment: f64 = common as f64 / i as f64;
        let jaccard: f64 = common as f64 / total as f64;
        (containment, jaccard, common, total)
    }

    /// Spot-checks `mash_paper_distance` against hand-computed values on the same
    /// inputs used in `test_raw_distance`.
    #[test]
    fn test_mash_compatibility() {
        let (cont, _jac, _com, _total) = mash_paper_distance(&kc(&[0, 1, 2]), &kc(&[1, 2]));
        assert_eq!(cont, 2. / 2.);
        // Only exercised for panics; the result is not checked.
        let (_cont, _jac, _com, _total) = mash_paper_distance(&kc(&[0, 2]), &kc(&[1, 2]));
        let (_cont, jac, com, total) = mash_paper_distance(&kc(&[0, 1]), &kc(&[2, 3]));
        assert_eq!(jac, 0. / 2.);
        assert_eq!(com, 0);
        assert_eq!(total, 2);
    }

    /// End-to-end: two scaled sketchers fed identical input must compare as identical
    /// through `distance` (Jaccard and containment 1, with all three distinct k-mers
    /// pushed being shared).
    #[test]
    fn test_distance_scaled() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sketch_schemes::scaled::ScaledSketcher;
        use crate::sketch_schemes::SketchScheme;
        let mut queue1 = ScaledSketcher::new(3, 0.001, 2, 42);
        queue1.push(b"ca", 0);
        queue1.push(b"cc", 1);
        queue1.push(b"ac", 0);
        queue1.push(b"ac", 1);
        let array1 = queue1.to_sketch();
        let mut queue2 = ScaledSketcher::new(3, 0.001, 2, 42);
        queue2.push(b"ca", 0);
        queue2.push(b"cc", 1);
        queue2.push(b"ac", 0);
        queue2.push(b"ac", 1);
        let array2 = queue2.to_sketch();
        let dist = distance(&array1, &array2, false)?;
        assert_eq!(dist.jaccard, 1.0);
        assert_eq!(dist.containment, 1.0);
        assert_eq!(dist.common_hashes, 3);
        Ok(())
    }
}
/// Legacy ("old mode") sketch comparison returning
/// `(containment, jaccard, common_hashes, total_hashes)`.
///
/// For every reference hash, the query is scanned forward, never past its last
/// element, until it reaches a hash that is not smaller:
/// - `common` counts reference hashes that find an equal query hash.
/// - `total` is the number of reference hashes.
///
/// The ratios are `containment = common / total` and
/// `jaccard = common / (common + 2 * (total - common))`.
///
/// Both slices are assumed to be sorted by ascending hash, because the forward-only
/// scan relies on it. NOTE(review): this is not checked here.
///
/// An empty query simply yields no matches instead of panicking. An empty reference
/// returns `(0., 1., 0, 0)` instead of `NaN` ratios, the same values `raw_distance`
/// reports for an empty comparison.
pub fn old_distance(query_sketch: &[KmerCount], ref_sketch: &[KmerCount]) -> (f64, f64, u64, u64) {
    if ref_sketch.is_empty() {
        return (0., 1., 0, 0);
    }

    let mut i: usize = 0;
    let mut common: u64 = 0;
    for ref_hash in ref_sketch {
        // Bound check first: it keeps `i` a valid index (stopping on the last element)
        // and avoids the `len() - 1` underflow / out-of-bounds index on an empty query.
        while i + 1 < query_sketch.len() && query_sketch[i].hash < ref_hash.hash {
            i += 1;
        }
        if matches!(query_sketch.get(i), Some(query_hash) if query_hash.hash == ref_hash.hash) {
            common += 1;
        }
    }

    let total = ref_sketch.len() as u64;
    let containment: f64 = common as f64 / total as f64;
    // `common <= total`, so the subtraction cannot underflow.
    let jaccard: f64 = common as f64 / (common + 2 * (total - common)) as f64;
    (containment, jaccard, common, total)
}
/// Builds a count matrix of `sketches` against the hashes of `ref_sketch`.
///
/// The result has one row per entry in `sketches` and one column per hash in
/// `ref_sketch`. `result[[i, pos]]` holds the `count` (cast to `i32`) of the sketch-`i`
/// hash equal to `ref_sketch[pos].hash`, or 0 when sketch `i` lacks that hash.
///
/// Both `ref_sketch` and every sketch are assumed to be sorted by ascending hash, since
/// the reference cursor only moves forward. NOTE(review): this is not checked here.
/// An empty `ref_sketch` yields a `sketches.len() x 0` matrix instead of panicking.
pub fn minmer_matrix<U>(ref_sketch: &[KmerCount], sketches: &[U]) -> Array2<i32>
where
    U: AsRef<[KmerCount]>,
{
    let mut result = Array2::<i32>::zeros((sketches.len(), ref_sketch.len()));
    if ref_sketch.is_empty() {
        // No columns to fill. Continuing would index `ref_sketch[0]` and underflow
        // `len() - 1`.
        return result;
    }
    for (i, sketch) in sketches.iter().map(|s| s.as_ref()).enumerate() {
        let mut ref_pos = 0;
        for hash in sketch {
            // Advance the reference cursor, stopping on its last element so
            // `ref_pos` always stays a valid index.
            while ref_pos + 1 < ref_sketch.len() && hash.hash > ref_sketch[ref_pos].hash {
                ref_pos += 1;
            }
            if hash.hash == ref_sketch[ref_pos].hash {
                result[[i, ref_pos]] = hash.count as i32;
            }
        }
    }
    result
}