mod intrinsics;
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
use log::debug;
use packed_seq::{u32x8, ChunkIt, PackedNSeq, PaddedIt, Seq};
use seq_hash::KmerHasher;
// Hasher stored in `Sketcher::fwd_hasher`; used by `collect_up_to_bound` when `SketchParams::rc` is false.
// NOTE(review): the const generics `<false, 1>` are defined by `seq_hash`; presumably `false` disables
// reverse-complement (canonical) hashing — confirm against the `seq_hash` docs.
type FwdNtHasher = seq_hash::NtHasher<false, 1>;
// Hasher stored in `Sketcher::rc_hasher`; used by `collect_up_to_bound` when `SketchParams::rc` is true.
type RcNtHasher = seq_hash::NtHasher<true, 1>;
/// A computed sketch, tagged by the algorithm that produced it.
///
/// `Sketcher::sketch_seqs` picks the variant from `SketchParams::alg`. Two sketches can only be
/// compared when they are the same variant.
#[derive(bincode::Encode, bincode::Decode, Debug)]
pub enum Sketch {
/// Produced for `SketchAlg::Bottom`.
BottomSketch(BottomSketch),
/// Produced for `SketchAlg::Bucket`.
BucketSketch(BucketSketch),
}
/// Converts a Jaccard similarity `j` into a Mash distance for k-mer length `k`.
///
/// Evaluates `-ln(2j / (1 + j)) / k`. Identical sets (`j == 1`) give exactly `+0.0`; disjoint
/// sets (`j == 0`) give `+inf`.
///
/// # Panics
/// Panics if `j` is negative or NaN, or if the computed distance is negative or NaN
/// (e.g. for `j > 1`).
fn compute_mash_distance(j: f32, k: usize) -> f32 {
    assert!(j >= 0.0, "Jaccard similarity {j} should not be negative");
    let mash_dist = -(2. * j / (1. + j)).ln() / k as f32;
    assert!(
        mash_dist >= 0.0,
        "Bad mash distance {mash_dist} for jaccard similarity {j}"
    );
    // For `j == 1` the expression above is `-ln(1) / k == -0.0`. `f32::max(-0.0, 0.0)` may
    // return either operand, so normalise the sign of zero explicitly instead.
    if mash_dist == 0.0 {
        0.0
    } else {
        mash_dist
    }
}
impl Sketch {
pub fn to_params(&self) -> SketchParams {
match self {
Sketch::BottomSketch(sketch) => SketchParams {
alg: SketchAlg::Bottom,
rc: sketch.rc,
k: sketch.k,
s: sketch.bottom.len(),
b: 0,
filter_empty: false,
filter_out_n: false, },
Sketch::BucketSketch(sketch) => SketchParams {
alg: SketchAlg::Bucket,
rc: sketch.rc,
k: sketch.k,
s: sketch.buckets.len(),
b: sketch.b,
filter_empty: false,
filter_out_n: false, },
}
}
pub fn jaccard_similarity(&self, other: &Self) -> f32 {
match (self, other) {
(Sketch::BottomSketch(a), Sketch::BottomSketch(b)) => a.jaccard_similarity(b),
(Sketch::BucketSketch(a), Sketch::BucketSketch(b)) => a.jaccard_similarity(b),
_ => panic!("Sketches are of different types!"),
}
}
pub fn mash_distance(&self, other: &Self) -> f32 {
let j = self.jaccard_similarity(other);
let k = match self {
Sketch::BottomSketch(sketch) => sketch.k,
Sketch::BucketSketch(sketch) => sketch.k,
};
compute_mash_distance(j, k)
}
}
/// Per-bucket values of a bucket sketch, stored at a reduced bit width.
///
/// Built by `BitSketch::new`, which keeps only the low `b` bits of every 32-bit input value.
#[derive(bincode::Encode, bincode::Decode, Debug)]
pub enum BitSketch {
/// Full 32-bit values, one per bucket.
B32(Vec<u32>),
/// Low 16 bits of each value, one per bucket.
B16(Vec<u16>),
/// Low 8 bits of each value, one per bucket.
B8(Vec<u8>),
/// Lowest bit of each value; 64 consecutive buckets are packed into one word, bucket `i` of a
/// group at bit `i`.
B1(Vec<u64>),
}
impl BitSketch {
fn new(b: usize, vals: Vec<u32>) -> Self {
match b {
32 => BitSketch::B32(vals),
16 => BitSketch::B16(vals.into_iter().map(|x| x as u16).collect()),
8 => BitSketch::B8(vals.into_iter().map(|x| x as u8).collect()),
1 => BitSketch::B1({
assert_eq!(vals.len() % 64, 0);
vals.chunks_exact(64)
.map(|xs| {
xs.iter()
.enumerate()
.fold(0u64, |bits, (i, x)| bits | (((x & 1) as u64) << i))
})
.collect()
}),
_ => panic!("Unsupported bit width. Must be 1 or 8 or 16 or 32."),
}
}
fn len(&self) -> usize {
match self {
BitSketch::B32(v) => v.len(),
BitSketch::B16(v) => v.len(),
BitSketch::B8(v) => v.len(),
BitSketch::B1(v) => 64 * v.len(),
}
}
}
/// A bottom sketch: the smallest distinct k-mer hashes collected from the input.
#[derive(bincode::Encode, bincode::Decode, Debug)]
pub struct BottomSketch {
/// Copy of `SketchParams::rc` used when sketching.
rc: bool,
/// k-mer length used when sketching.
k: usize,
/// Sorted, deduplicated hashes; `Sketcher::bottom_sketch` resizes this to exactly `s` entries,
/// padding with `u32::MAX` when fewer hashes were found.
bottom: Vec<u32>,
}
impl BottomSketch {
    /// Estimates the Jaccard similarity of two bottom sketches.
    ///
    /// Walks both sorted hash lists in merge order for `len` steps (the `len` smallest elements
    /// of the merged lists) and returns the fraction of those steps where both lists hold the
    /// same hash. Two empty sketches yield NaN (`0 / 0`).
    ///
    /// # Panics
    /// Panics if the sketches differ in `rc`, `k`, or length.
    pub fn jaccard_similarity(&self, other: &Self) -> f32 {
        use std::cmp::Ordering;

        assert_eq!(self.rc, other.rc);
        assert_eq!(self.k, other.k);
        let lhs = &self.bottom;
        let rhs = &other.bottom;
        assert_eq!(lhs.len(), rhs.len());

        let total = lhs.len();
        let (mut li, mut ri) = (0usize, 0usize);
        let mut shared = 0usize;
        // Every step consumes one element of the merged order; equal heads are consumed from
        // both sides at once and count as shared.
        for _ in 0..total {
            match lhs[li].cmp(&rhs[ri]) {
                Ordering::Less => li += 1,
                Ordering::Greater => ri += 1,
                Ordering::Equal => {
                    shared += 1;
                    li += 1;
                    ri += 1;
                }
            }
        }
        shared as f32 / total as f32
    }

    /// Mash distance derived from [`BottomSketch::jaccard_similarity`] and `self.k`.
    pub fn mash_distance(&self, other: &Self) -> f32 {
        compute_mash_distance(self.jaccard_similarity(other), self.k)
    }
}
/// A bucket sketch: one reduced-width value per bucket, plus an optional bitmask of empty buckets.
#[derive(bincode::Encode, bincode::Decode, Debug)]
pub struct BucketSketch {
/// Copy of `SketchParams::rc` used when sketching.
rc: bool,
/// k-mer length used when sketching.
k: usize,
/// Bit width requested for the stored bucket values (`SketchParams::b`).
b: usize,
/// Per-bucket values, built by `BitSketch::new(b, ..)`.
buckets: BitSketch,
/// Bitmask of empty buckets: bit `i` of word `j` marks bucket `64 * j + i`. Left as an empty
/// vector when no bucket was empty or `SketchParams::filter_empty` was false.
empty: Vec<u64>,
}
impl BucketSketch {
    /// Estimates the Jaccard similarity of two bucket sketches.
    ///
    /// The fraction of matching buckets is corrected for the chance of random collisions at the
    /// stored bit width. Buckets marked empty in *both* sketches are excluded from the
    /// denominator.
    ///
    /// # Panics
    /// Panics if the sketches differ in `rc`, `k`, `b`, stored bit-width variant, or number of
    /// buckets.
    pub fn jaccard_similarity(&self, other: &Self) -> f32 {
        assert_eq!(self.rc, other.rc);
        assert_eq!(self.k, other.k);
        assert_eq!(self.b, other.b);
        let both_empty = self.both_empty(other);
        match (&self.buckets, &other.buckets) {
            (BitSketch::B32(a), BitSketch::B32(b)) => Self::inner_similarity(a, b, both_empty),
            (BitSketch::B16(a), BitSketch::B16(b)) => Self::inner_similarity(a, b, both_empty),
            (BitSketch::B8(a), BitSketch::B8(b)) => Self::inner_similarity(a, b, both_empty),
            (BitSketch::B1(a), BitSketch::B1(b)) => Self::b1_similarity(a, b, both_empty),
            _ => panic!("Bit width mismatch"),
        }
    }

    /// Mash distance derived from [`BucketSketch::jaccard_similarity`] and `self.k`.
    pub fn mash_distance(&self, other: &Self) -> f32 {
        let j = self.jaccard_similarity(other);
        compute_mash_distance(j, self.k)
    }

    /// Similarity for the 8/16/32-bit variants.
    ///
    /// With `f` the fraction of equal buckets and `bb = 2^bits`, returns
    /// `max(bb * f - 1, 0) / (bb - 1)`, which discounts matches expected from random
    /// collisions of truncated values. `T` must be narrower than 64 bits.
    fn inner_similarity<T: Eq>(a: &[T], b: &[T], both_empty: usize) -> f32 {
        assert_eq!(a.len(), b.len());
        let mismatches = std::iter::zip(a, b).filter(|(x, y)| x != y).count();
        let f = 1.0 - mismatches as f32 / (a.len() - both_empty) as f32;
        // Shift in `u64` rather than `usize`: `1usize << 32` overflows on 32-bit targets.
        let bb = (1u64 << (std::mem::size_of::<T>() * 8)) as f32;
        (bb * f - 1.0).max(0.0) / (bb - 1.0)
    }

    /// Similarity for the 1-bit variant: with `f` the fraction of equal bits, returns
    /// `max(2f - 1, 0)`.
    fn b1_similarity(a: &[u64], b: &[u64], both_empty: usize) -> f32 {
        assert_eq!(a.len(), b.len());
        let differing_bits: u32 = std::iter::zip(a, b)
            .map(|(x, y)| (x ^ y).count_ones())
            .sum();
        let f = 1.0 - differing_bits as f32 / (64 * a.len() - both_empty) as f32;
        (2. * f - 1.).max(0.0)
    }

    /// Number of buckets flagged empty in both sketches.
    ///
    /// Only the common prefix of the two bitmasks is inspected, so the result is 0 when either
    /// sketch stores no bitmask.
    fn both_empty(&self, other: &Self) -> usize {
        std::iter::zip(&self.empty, &other.empty)
            .map(|(a, b)| (a & b).count_ones())
            .sum::<u32>() as usize
    }
}
// Sketching algorithm selector.
// NOTE: plain `//` comments are used here on purpose; `clap` derive turns `///` doc comments
// into CLI help text, which would change program output.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Eq, PartialEq)]
pub enum SketchAlg {
// Bottom sketch; `Sketcher::sketch_seqs` returns `Sketch::BottomSketch`.
Bottom,
// Bucket sketch; `Sketcher::sketch_seqs` returns `Sketch::BucketSketch`.
Bucket,
}
// Parameters controlling how sequences are sketched; also usable as a `clap` argument group.
//
// NOTE: plain `//` comments are used here on purpose; `clap` derive turns `///` doc comments
// into CLI help text, which would change program output.
#[derive(clap::Args, Copy, Clone, Debug, Eq, PartialEq)]
pub struct SketchParams {
    // Sketching algorithm (`--alg`); defaults to `SketchAlg::Bucket`.
    #[arg(long, default_value_t = SketchAlg::Bucket)]
    #[arg(value_enum)]
    pub alg: SketchAlg,
    // Whether k-mers are hashed with the reverse-complement hasher (`true`) or the forward
    // hasher (`false`).
    //
    // Exposed on the command line as the bare flag `--fwd`: without it `rc` is `true`, and
    // passing it selects forward hashing (`rc = false`). The defaults were previously swapped,
    // so `--fwd` switched reverse-complement hashing *on*; they now agree with the flag name
    // and with the `rc: true` presets in `SketchParams::default`.
    #[arg(
        long = "fwd",
        num_args(0),
        action = clap::builder::ArgAction::Set,
        default_value = "true",
        default_missing_value = "false",
    )]
    pub rc: bool,
    // k-mer length (`-k`).
    #[arg(short, default_value_t = 31)]
    pub k: usize,
    // Sketch size (`-s`): number of hashes in a bottom sketch, or number of buckets.
    #[arg(short, default_value_t = 10000)]
    pub s: usize,
    // Bits stored per bucket (`-b`); `BitSketch::new` accepts 1, 8, 16 or 32. Reset to 0 by
    // `build` for bottom sketches.
    #[arg(short, default_value_t = 8)]
    pub b: usize,
    // Whether a bucket sketch stores a bitmask of empty buckets. Not a CLI option: the value is
    // fixed to `true` when the struct is parsed from the command line.
    #[arg(skip = true)]
    pub filter_empty: bool,
    // `--filter-out-n` flag.
    // NOTE(review): nothing in the visible part of this file reads this field — confirm where
    // it takes effect.
    #[arg(long)]
    pub filter_out_n: bool,
}
/// Reusable sketching engine created by `SketchParams::build`.
pub struct Sketcher {
/// Parameters in effect (with `b` reset to 0 for bottom sketches).
params: SketchParams,
/// Hasher used when `params.rc` is true.
rc_hasher: RcNtHasher,
/// Hasher used when `params.rc` is false.
fwd_hasher: FwdNtHasher,
/// Oversampling factor, in tenths, applied to the hash bound when collecting candidates.
/// It is only ever raised (via `fetch_max`) when a sketching attempt finds too few hashes.
factor: AtomicU64,
}
impl SketchParams {
    /// Builds a [`Sketcher`] for these parameters.
    ///
    /// For bottom sketches `b` is reset to 0 and the initial oversampling factor is 13 (tenths).
    /// For bucket sketches the factor starts at `5 * floor(log2(s))`, but never below 1: a
    /// factor of 0 (which `s == 1` would otherwise produce) gives a hash bound of 0 that can
    /// never grow, so bucket sketching would retry forever.
    ///
    /// # Panics
    /// Panics for bucket sketches with `s == 0`.
    pub fn build(&self) -> Sketcher {
        let mut params = *self;
        let factor = match params.alg {
            SketchAlg::Bottom => {
                params.b = 0;
                13
            }
            SketchAlg::Bucket => (params.s.ilog2() as u64 * 5).max(1),
        };
        Sketcher {
            params,
            rc_hasher: RcNtHasher::new(params.k),
            fwd_hasher: FwdNtHasher::new(params.k),
            factor: AtomicU64::new(factor),
        }
    }

    /// Default preset: bucket sketch with 32768 one-bit buckets, reverse-complement hashing and
    /// the empty-bucket bitmask enabled.
    pub fn default(k: usize) -> Self {
        SketchParams {
            alg: SketchAlg::Bucket,
            rc: true,
            k,
            s: 32768,
            b: 1,
            filter_empty: true,
            filter_out_n: false,
        }
    }

    /// Preset with fewer, wider buckets: 8192 eight-bit buckets, reverse-complement hashing and
    /// no empty-bucket bitmask.
    pub fn default_fast_sketching(k: usize) -> Self {
        SketchParams {
            alg: SketchAlg::Bucket,
            rc: true,
            k,
            s: 8192,
            b: 8,
            filter_empty: false,
            filter_out_n: false,
        }
    }
}
impl Sketcher {
    /// Returns the parameters this sketcher was built with.
    pub fn params(&self) -> &SketchParams {
        &self.params
    }

    /// Sketches a single sequence; shorthand for `sketch_seqs(&[seq])`.
    pub fn sketch(&self, seq: impl Sketchable) -> Sketch {
        self.sketch_seqs(&[seq])
    }

    /// Sketches the combined k-mers of all `seqs` with the configured algorithm.
    pub fn sketch_seqs<'s>(&self, seqs: &[impl Sketchable]) -> Sketch {
        match self.params.alg {
            SketchAlg::Bottom => Sketch::BottomSketch(self.bottom_sketch(seqs)),
            SketchAlg::Bucket => Sketch::BucketSketch(self.bucket_sketch(seqs)),
        }
    }

    /// Total number of k-mer positions over all sequences (`len - k + 1` each).
    ///
    /// Sequences shorter than `k` contribute 0 instead of underflowing.
    fn num_kmers(&self, seqs: &[impl Sketchable]) -> usize {
        seqs.iter()
            .map(|seq| (seq.len() + 1).saturating_sub(self.params.k))
            .sum::<usize>()
    }

    /// Computes `(target, bound)` for one collection attempt.
    ///
    /// `target` is the hash threshold below which about `s` of `n` uniformly distributed
    /// 32-bit hashes are expected to fall; `bound` is `target * factor / 10`, capped at
    /// `u32::MAX`. With no k-mers at all (`n == 0`) the bound is `u32::MAX`, so the caller's
    /// retry loop finishes after a single pass instead of dividing by zero.
    fn hash_bound(&self, n: usize, factor: u64) -> (u128, u32) {
        let max = u128::from(u32::MAX);
        if n == 0 {
            return (max, u32::MAX);
        }
        // Computed in `u128` so the product cannot overflow, including on 32-bit targets.
        let target = max * self.params.s as u128 / n as u128;
        let bound = (target.saturating_mul(u128::from(factor)) / 10).min(max) as u32;
        (target, bound)
    }

    /// Builds a bottom sketch: the `s` smallest distinct hashes, sorted ascending and padded
    /// with `u32::MAX` when fewer than `s` distinct hashes exist.
    ///
    /// Hashes are collected below a bound; when too few are found, the shared oversampling
    /// factor is raised by a quarter and collection is retried from scratch.
    fn bottom_sketch(&self, seqs: &[impl Sketchable]) -> BottomSketch {
        let n = self.num_kmers(seqs);
        let s = self.params.s;
        let mut out = vec![];
        loop {
            let factor = self.factor.load(Relaxed);
            let (_, bound) = self.hash_bound(n, factor);
            self.collect_up_to_bound(seqs, bound, &mut out);
            if bound == u32::MAX || out.len() >= s {
                out.sort_unstable();
                let old_len = out.len();
                out.dedup();
                let new_len = out.len();
                debug!("Deduplicated from {old_len} to {new_len}");
                // Deduplication may drop below `s`; only accept when enough hashes remain or
                // the bound cannot be raised further.
                if bound == u32::MAX || new_len >= s {
                    out.resize(s, u32::MAX);
                    return BottomSketch {
                        rc: self.params.rc,
                        k: self.params.k,
                        bottom: out,
                    };
                }
            }
            let new_factor = factor + factor.div_ceil(4);
            let prev = self.factor.fetch_max(new_factor, Relaxed);
            debug!(
                "Found only {:>10} of {:>10} ({:>6.3}%) Increasing factor from {factor} to {new_factor} (was already {prev})",
                out.len(),
                s,
                out.len() as f32 / s as f32 * 100.,
            );
        }
    }

    /// Builds a bucket sketch: every collected hash `h` is assigned to bucket `h mod s`, each
    /// bucket keeps its minimum hash, and the stored value is that minimum divided by `s`
    /// (reduced to `b` bits by `BitSketch::new`).
    ///
    /// Collection is retried with a larger oversampling factor until every bucket is filled or
    /// the bound has reached `u32::MAX`. In the latter case, remaining empty buckets are
    /// recorded in a bitmask if `filter_empty` is set.
    fn bucket_sketch(&self, seqs: &[impl Sketchable]) -> BucketSketch {
        let n = self.num_kmers(seqs);
        let s = self.params.s;
        let mut out = vec![];
        let mut buckets = vec![u32::MAX; s];
        loop {
            let factor = self.factor.load(Relaxed);
            let (target, bound) = self.hash_bound(n, factor);
            debug!(
                "n {n:>10} s {s} target {target:>10} factor {factor:>3} bound {bound:>10} ({:>6.3}%)",
                bound as f32 / u32::MAX as f32 * 100.0,
            );
            self.collect_up_to_bound(seqs, bound, &mut out);
            let mut num_empty = 0;
            if bound == u32::MAX || out.len() >= s {
                let m = FM32::new(s as u32);
                for &hash in &out {
                    let bucket = m.fastmod(hash);
                    buckets[bucket] = buckets[bucket].min(hash);
                }
                num_empty = buckets.iter().filter(|&&x| x == u32::MAX).count();
                if bound == u32::MAX || num_empty == 0 {
                    if num_empty > 0 {
                        debug!("Found {num_empty} empty buckets.");
                    }
                    let empty = if num_empty > 0 && self.params.filter_empty {
                        debug!("Found {num_empty} empty buckets. Storing bitmask.");
                        // Bit `i` of word `j` marks bucket `64 * j + i` as empty.
                        buckets
                            .chunks(64)
                            .map(|xs| {
                                xs.iter().enumerate().fold(0u64, |bits, (i, x)| {
                                    bits | (((*x == u32::MAX) as u64) << i)
                                })
                            })
                            .collect()
                    } else {
                        vec![]
                    };
                    return BucketSketch {
                        rc: self.params.rc,
                        k: self.params.k,
                        b: self.params.b,
                        empty,
                        buckets: BitSketch::new(
                            self.params.b,
                            buckets.into_iter().map(|x| m.fastdiv(x) as u32).collect(),
                        ),
                    };
                }
            }
            let new_factor = factor + factor.div_ceil(4);
            let prev = self.factor.fetch_max(new_factor, Relaxed);
            debug!(
                "Found only {:>10} of {:>10} ({:>6.3}%, {num_empty:>5} empty) Increasing factor from {factor} to {new_factor} (was already {prev})",
                out.len(),
                s,
                out.len() as f32 / s as f32 * 100.,
            );
        }
    }

    /// Clears `out` and refills it with every k-mer hash of `seqs` that is strictly below
    /// `bound`, using the reverse-complement or forward hasher according to `params.rc`.
    fn collect_up_to_bound(&self, seqs: &[impl Sketchable], bound: u32, out: &mut Vec<u32>) {
        out.clear();
        if self.params.rc {
            for &seq in seqs {
                let hashes = seq.hash_kmers(&self.rc_hasher);
                collect_impl(bound, hashes, out);
            }
        } else {
            for &seq in seqs {
                let hashes = seq.hash_kmers(&self.fwd_hasher);
                collect_impl(bound, hashes, out);
            }
        }
        debug!(
            "Collect up to {bound:>10}: {:>9} ({:>6.3}%)",
            out.len(),
            out.len() as f32 / self.num_kmers(seqs) as f32 * 100.0
        );
    }
}
/// Appends to `out` every hash produced by `hashes` that is strictly below `bound`.
///
/// `hashes` yields `u32x8` chunks. Lane `i` of successive chunks is treated as positions
/// `i * lane_len, i * lane_len + 1, ...`; the last `hashes.padding` positions are padding and
/// are never appended.
fn collect_impl(bound: u32, hashes: PaddedIt<impl ChunkIt<u32x8>>, out: &mut Vec<u32>) {
    let simd_bound = u32x8::splat(bound);
    let mut write_idx = out.len();
    let lane_len = hashes.it.len();
    // Position of the element currently held in each of the 8 lanes.
    let mut idx = u32x8::from(std::array::from_fn(|i| (i * lane_len) as u32));
    // Positions at or beyond this value are padding.
    let max_idx = u32x8::splat((8 * lane_len - hashes.padding) as u32);
    hashes.it.for_each(|hashes| {
        let mask = hashes.cmp_lt(simd_bound);
        let in_bounds = idx.cmp_lt(max_idx);
        if write_idx + 8 > out.capacity() {
            // Publish what has been written so far before growing: `Vec::reserve` only
            // promises to preserve the first `len` elements across a reallocation.
            // SAFETY: `write_idx <= capacity` (room for 8 more elements is ensured before
            // every append), and elements below `write_idx` are assumed to have been
            // initialised by `append_from_mask` — the same assumption the final `set_len`
            // below relies on.
            unsafe { out.set_len(write_idx) };
            out.reserve(out.capacity() + 8);
        }
        // SAFETY: at least 8 slots of capacity are available past `write_idx` (ensured above).
        // NOTE(review): the exact contract of `append_from_mask` lives in `intrinsics`, which is
        // not visible here — confirm that 8 spare slots are all it requires.
        unsafe { intrinsics::append_from_mask(hashes, mask & in_bounds, out, &mut write_idx) };
        idx += u32x8::ONE;
    });
    // SAFETY: see above; `write_idx` never exceeds the capacity.
    unsafe { out.set_len(write_idx) };
}
/// A cheaply copyable sequence view whose k-mers can be hashed for sketching.
pub trait Sketchable: Copy {
/// Length of the sequence; `Sketcher` derives the k-mer count from it as `len - k + 1`.
fn len(&self) -> usize;
/// Hashes the k-mers of the sequence with `hasher`, returned as a padded iterator over
/// `u32x8` chunks of 32-bit hashes.
fn hash_kmers<H: KmerHasher>(self, hasher: &H) -> PaddedIt<impl ChunkIt<u32x8>>;
}
/// Byte slices: length via `Seq::len`, hashes via `KmerHasher::hash_kmers_simd`.
impl Sketchable for &[u8] {
fn len(&self) -> usize {
Seq::len(self)
}
fn hash_kmers<H: KmerHasher>(self, hasher: &H) -> PaddedIt<impl ChunkIt<u32x8>> {
// NOTE(review): the meaning of the second argument (`1`) is defined by `seq_hash` — confirm.
hasher.hash_kmers_simd(self, 1)
}
}
/// `packed_seq::AsciiSeq`: length via `Seq::len`, hashes via `KmerHasher::hash_kmers_simd`.
impl Sketchable for packed_seq::AsciiSeq<'_> {
fn len(&self) -> usize {
Seq::len(self)
}
fn hash_kmers<H: KmerHasher>(self, hasher: &H) -> PaddedIt<impl ChunkIt<u32x8>> {
// NOTE(review): the meaning of the second argument (`1`) is defined by `seq_hash` — confirm.
hasher.hash_kmers_simd(self, 1)
}
}
/// `packed_seq::PackedSeq`: length via `Seq::len`, hashes via `KmerHasher::hash_kmers_simd`.
impl Sketchable for packed_seq::PackedSeq<'_> {
fn len(&self) -> usize {
Seq::len(self)
}
fn hash_kmers<H: KmerHasher>(self, hasher: &H) -> PaddedIt<impl ChunkIt<u32x8>> {
// NOTE(review): the meaning of the second argument (`1`) is defined by `seq_hash` — confirm.
hasher.hash_kmers_simd(self, 1)
}
}
/// `PackedNSeq`: length is that of the inner `seq`; hashing goes through
/// `KmerHasher::hash_valid_kmers_simd` instead of `hash_kmers_simd`.
impl<'s> Sketchable for PackedNSeq<'s> {
fn len(&self) -> usize {
Seq::len(&self.seq)
}
// The explicit `use<'s, 'h, H>` lists the lifetimes and type the returned iterator may capture.
fn hash_kmers<'h, H: KmerHasher>(
self,
hasher: &'h H,
) -> PaddedIt<impl ChunkIt<u32x8> + use<'s, 'h, H>> {
// NOTE(review): presumably this skips k-mers containing ambiguous (`N`) bases — confirm
// against the `seq_hash` docs.
hasher.hash_valid_kmers_simd(self, 1)
}
}
/// Precomputed constants for fast 32-bit modulo and division by a fixed divisor.
#[derive(Copy, Clone, Debug)]
struct FM32 {
/// The divisor, widened to 64 bits.
d: u64,
/// Precomputed multiplier `u64::MAX / d + 1`.
m: u64,
}
impl FM32 {
    /// Precomputes the multiplier for divisor `d`.
    ///
    /// For `d == 1` the multiplier `u64::MAX / 1 + 1` does not fit in 64 bits; it is allowed to
    /// wrap to 0, which still gives the right remainder (always 0), and `fastdiv` handles that
    /// divisor separately.
    ///
    /// # Panics
    /// Panics if `d == 0`.
    #[inline(always)]
    fn new(d: u32) -> Self {
        assert!(d != 0, "FM32 divisor must be non-zero");
        Self {
            d: d as u64,
            m: (u64::MAX / d as u64).wrapping_add(1),
        }
    }

    /// Returns `h % d` using two multiplications instead of a division.
    #[inline(always)]
    fn fastmod(self, h: u32) -> usize {
        let lowbits = self.m.wrapping_mul(h as u64);
        ((lowbits as u128 * self.d as u128) >> 64) as usize
    }

    /// Returns `h / d` using one multiplication instead of a division.
    #[inline(always)]
    fn fastdiv(self, h: u32) -> usize {
        // The multiplier trick cannot represent division by 1 (it would need `m == 2^64`).
        if self.d == 1 {
            return h as usize;
        }
        ((self.m as u128 * h as u128) >> 64) as u32 as usize
    }
}
// Unit tests for sketch construction and comparison. They rely on random sequences from
// `packed_seq` and random parameters from `rand`, so inputs differ between runs.
#[cfg(test)]
mod test {
use std::hint::black_box;
use super::*;
use packed_seq::SeqVec;
// Bottom sketches of short random sequences have exactly `s` entries and are sorted, both with
// forward hashing (`s` equal to the number of k-mers) and with `rc: true` (`s` capped at 10).
#[test]
fn test() {
let b = 16;
let k = 31;
for n in 31..100 {
// NOTE(review): `f` is the second argument of `PackedNSeqVec::random`; presumably the
// fraction of ambiguous bases — confirm.
for f in [0.0, 0.01, 0.03] {
let s = n - k + 1;
let seq = packed_seq::PackedNSeqVec::random(n, f);
let sketcher = crate::SketchParams {
alg: SketchAlg::Bottom,
rc: false,
k,
s,
b,
filter_empty: false,
filter_out_n: true,
}
.build();
let bottom = sketcher.bottom_sketch(&[seq.as_slice()]).bottom;
assert_eq!(bottom.len(), s);
assert!(bottom.is_sorted());
let s = s.min(10);
let seq = packed_seq::PackedNSeqVec::random(n, f);
let sketcher = crate::SketchParams {
alg: SketchAlg::Bottom,
rc: true,
k,
s,
b,
filter_empty: false,
filter_out_n: true,
}
.build();
let bottom = sketcher.bottom_sketch(&[seq.as_slice()]).bottom;
assert_eq!(bottom.len(), s);
assert!(bottom.is_sorted());
}
}
}
// With `rc: true`, a sequence and its reverse complement must produce the same bottom sketch,
// for random `k`, sequence length `n` and sketch size `s`.
#[test]
fn rc() {
let b = 32;
for k in (0..10).map(|_| rand::random_range(1..=64)) {
for n in (0..10).map(|_| rand::random_range(k..1000)) {
for s in (0..10).map(|_| rand::random_range(0..n - k + 1)) {
for f in [0.0, 0.001, 0.01] {
let seq = packed_seq::PackedNSeqVec::random(n, f);
let sketcher = crate::SketchParams {
alg: SketchAlg::Bottom,
rc: true,
k,
s,
b,
filter_empty: false,
filter_out_n: true,
}
.build();
let bottom = sketcher.bottom_sketch(&[seq.as_slice()]).bottom;
assert_eq!(bottom.len(), s);
assert!(bottom.is_sorted());
let seq_rc = seq.as_slice().to_revcomp();
let bottom_rc = sketcher.bottom_sketch(&[seq_rc.as_slice()]).bottom;
assert_eq!(bottom, bottom_rc);
}
}
}
}
}
// The Mash distance of a sketch to itself is 0 for every algorithm / `filter_empty` combination.
#[test]
fn equal_dist() {
let s = 1000;
let k = 10;
let n = 300;
let b = 8;
let seq = packed_seq::AsciiSeqVec::random(n);
for (alg, filter_empty) in [
(SketchAlg::Bottom, false),
(SketchAlg::Bucket, false),
(SketchAlg::Bucket, true),
] {
let sketcher = crate::SketchParams {
alg,
rc: false,
k,
s,
b,
filter_empty,
filter_out_n: false,
}
.build();
let sketch = sketcher.sketch(seq.as_slice());
assert_eq!(sketch.mash_distance(&sketch), 0.0);
}
}
// Smoke test: sketching two random sequences and computing their Mash distance must not panic
// for any supported bit width and a range of (mostly short) sequence lengths.
#[test]
fn fuzz_short() {
let s = 1024;
let k = 10;
for b in [1, 8, 16, 32] {
for n in [10, 20, 40, 80, 150, 300, 500, 1000, 2000] {
let seq1 = packed_seq::AsciiSeqVec::random(n);
let seq2 = packed_seq::AsciiSeqVec::random(n);
for (alg, filter_empty) in [
(SketchAlg::Bottom, false),
(SketchAlg::Bucket, false),
(SketchAlg::Bucket, true),
] {
let sketcher = crate::SketchParams {
alg,
rc: false,
k,
s,
b,
filter_empty,
filter_out_n: false,
}
.build();
let s1 = sketcher.sketch(seq1.as_slice());
let s2 = sketcher.sketch(seq2.as_slice());
// The value is intentionally discarded; only the absence of a panic is checked.
s1.mash_distance(&s2);
}
}
}
}
// Runs `collect_impl` on a synthetic, padding-free chunk iterator and prints the collected
// hashes; there are no assertions.
#[test]
fn test_collect() {
let mut out = vec![];
let n = black_box(8000);
let it = (0..n).map(|x| u32x8::splat((x as u32).wrapping_mul(546786567)));
let padded_it = PaddedIt { it, padding: 0 };
let bound = black_box(u32::MAX / 10);
collect_impl(bound, padded_it, &mut out);
eprintln!("{out:?}");
}
}