use kira_cdh_compat_lsh::lsh::LshIndex;
#[derive(Clone, Debug)]
pub struct ClusterOptions {
pub similarity_threshold: f32,
pub max_candidates: usize,
pub sort_by_length: bool,
pub parallel_sim_checks: bool,
pub min_coverage_short: Option<f32>,
pub min_coverage_long: Option<f32>,
pub kmer_k: Option<u32>,
}
#[derive(Clone, Debug)]
pub struct ClusterResult {
pub clusters: Vec<Vec<usize>>,
}
pub struct GreedyClusterer<'a> {
pub index: &'a LshIndex,
pub signatures: &'a [Vec<u64>],
pub seq_lengths: Option<&'a [u32]>,
pub options: ClusterOptions,
}
impl<'a> GreedyClusterer<'a> {
pub fn run(&self) -> ClusterResult {
self.validate_inputs();
let n = self.signatures.len();
if n == 0 {
return ClusterResult {
clusters: Vec::new(),
};
}
let order = self.order_indices();
let mut clusters: Vec<Vec<usize>> = Vec::new();
let mut assigned = vec![false; n];
for &i in &order {
if assigned[i] {
continue;
}
let mut cluster = Vec::with_capacity(64);
cluster.push(i);
assigned[i] = true;
let cand_ids = self
.index
.query_candidates(&self.signatures[i], self.options.max_candidates);
let cands: Vec<usize> = cand_ids
.into_iter()
.map(|(_, seq_id)| seq_id as usize)
.filter(|&j| j != i && !assigned[j])
.collect();
let sig_i = &self.signatures[i];
let thr = self.options.similarity_threshold;
if self.options.parallel_sim_checks {
#[cfg(feature = "parallel")]
{
use rayon::prelude::*;
let passed: Vec<usize> = cands
.par_iter()
.copied()
.filter(|&j| {
let sim = jaccard_minhash(sig_i, &self.signatures[j]);
if sim < thr {
return false;
}
if self.coverage_filters_enabled() {
self.passes_coverage(i, j, sim)
} else {
true
}
})
.collect();
for j in passed {
if !assigned[j] {
assigned[j] = true;
cluster.push(j);
}
}
}
#[cfg(not(feature = "parallel"))]
{
let mut to_add = Vec::new();
for j in cands {
if assigned[j] {
continue;
}
let sim = jaccard_minhash(sig_i, &self.signatures[j]);
if sim >= thr
&& (!self.coverage_filters_enabled() || self.passes_coverage(i, j, sim))
{
to_add.push(j);
}
}
for j in to_add {
if !assigned[j] {
assigned[j] = true;
cluster.push(j);
}
}
}
} else {
for j in cands {
if assigned[j] {
continue;
}
let sim = jaccard_minhash(sig_i, &self.signatures[j]);
if sim < thr {
continue;
}
if self.coverage_filters_enabled() && !self.passes_coverage(i, j, sim) {
continue;
}
assigned[j] = true;
cluster.push(j);
}
}
clusters.push(cluster);
}
ClusterResult { clusters }
}
#[inline]
fn coverage_filters_enabled(&self) -> bool {
self.options.min_coverage_short.is_some() || self.options.min_coverage_long.is_some()
}
fn passes_coverage(&self, i: usize, j: usize, sim: f32) -> bool {
let k = self
.options
.kmer_k
.expect("Coverage thresholds require `kmer_k`");
let lens = self
.seq_lengths
.expect("Coverage thresholds require `seq_lengths`");
let (li, lj) = (lens[i] as i64, lens[j] as i64);
let (short_idx, long_idx, ls, ll) = if li <= lj {
(i, j, li, lj)
} else {
(j, i, lj, li)
};
let k_i = k as i64;
let a = (ls - k_i + 1).max(0) as f64; let b = (ll - k_i + 1).max(0) as f64;
if a == 0.0 || b == 0.0 {
return false;
}
let s = sim as f64;
let inter = s * (a + b) / (1.0 + s);
let cov_short = inter / a;
let cov_long = inter / b;
if let Some(min_s) = self.options.min_coverage_short {
if cov_short < min_s as f64 {
return false;
}
}
if let Some(min_l) = self.options.min_coverage_long {
if cov_long < min_l as f64 {
return false;
}
}
let _ = (short_idx, long_idx);
true
}
fn validate_inputs(&self) {
if self.signatures.is_empty() {
return;
}
let m = self.signatures[0].len();
debug_assert!(
self.signatures.iter().all(|s| s.len() == m),
"All signatures must be of equal length"
);
if self.options.sort_by_length {
let lens = self
.seq_lengths
.expect("sort_by_length=true requires `seq_lengths`");
debug_assert_eq!(
lens.len(),
self.signatures.len(),
"`seq_lengths` must match signatures length"
);
}
if self.coverage_filters_enabled() {
debug_assert!(
self.options.kmer_k.is_some(),
"Coverage thresholds require `kmer_k`"
);
let lens = self
.seq_lengths
.expect("coverage thresholds require `seq_lengths`");
debug_assert_eq!(
lens.len(),
self.signatures.len(),
"`seq_lengths` must match signatures length"
);
}
}
fn order_indices(&self) -> Vec<usize> {
let n = self.signatures.len();
let mut idx: Vec<usize> = (0..n).collect();
if self.options.sort_by_length {
let lens = self.seq_lengths.expect("seq_lengths is required");
idx.sort_unstable_by_key(|&i| std::cmp::Reverse(lens[i]));
}
idx
}
}
pub fn jaccard_minhash(a: &[u64], b: &[u64]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "signatures must be equal length");
let mut same = 0usize;
for i in 0..a.len() {
if a[i] == b[i] {
same += 1;
}
}
same as f32 / (a.len() as f32)
}
#[cfg(test)]
mod tests {
use super::*;
use rand::{Rng, SeedableRng, rngs::StdRng};
fn make_sig(base: u64, len: usize) -> Vec<u64> {
(0..len).map(|i| base.wrapping_add(i as u64)).collect()
}
#[test]
fn test_jaccard_minhash() {
let a = vec![1, 2, 3, 4];
let b = vec![1, 2, 9, 9];
assert!((jaccard_minhash(&a, &b) - 0.5).abs() < 1e-6);
}
#[test]
fn test_order_with_lengths() {
let idx = LshIndex::with_params(kira_cdh_compat_lsh::lsh::LshParams::new(8, 4).unwrap());
let sigs = vec![make_sig(0, 4), make_sig(10, 4), make_sig(20, 4)];
let lens = vec![100, 300, 200];
let gc = GreedyClusterer {
index: &idx,
signatures: &sigs,
seq_lengths: Some(&lens),
options: ClusterOptions {
similarity_threshold: 0.5,
max_candidates: 10,
sort_by_length: true,
parallel_sim_checks: false,
min_coverage_short: None,
min_coverage_long: None,
kmer_k: None,
},
};
let order = gc.order_indices();
assert_eq!(order, vec![1, 2, 0]); }
#[test]
fn test_coverage_gate_basic() {
let a = vec![1, 2, 3, 4];
let b = vec![1, 2, 3, 9];
let sim = jaccard_minhash(&a, &b);
assert!((sim - 0.75).abs() < 1e-6);
let idx = LshIndex::with_params(kira_cdh_compat_lsh::lsh::LshParams::new(8, 4).unwrap());
let sigs = vec![a, b];
let lens = vec![100, 120];
let gc = GreedyClusterer {
index: &idx,
signatures: &sigs,
seq_lengths: Some(&lens),
options: ClusterOptions {
similarity_threshold: 0.0, max_candidates: 0,
sort_by_length: false,
parallel_sim_checks: false,
min_coverage_short: Some(0.5),
min_coverage_long: Some(0.4),
kmer_k: Some(10),
},
};
gc.validate_inputs();
let ok = gc.passes_coverage(0, 1, 0.75);
assert!(ok);
}
}