use std::collections::{HashMap, HashSet};
use crate::mhfp::MhfpFingerprint;
/// Banded locality-sensitive-hashing index over [`MhfpFingerprint`]s.
///
/// Each fingerprint's hash vector is cut into `bands` consecutive slices of
/// `rows` values. Two fingerprints become query candidates for each other when
/// at least one of those slices is identical in both.
pub struct MhfpLshIndex {
/// Number of bands every fingerprint is split into.
bands: usize,
/// Number of consecutive hash values that make up one band.
rows: usize,
/// One map per band: the band's hash values -> indices (into `fps`) of the
/// stored fingerprints that have exactly those values in that band.
buckets: Vec<HashMap<Vec<u64>, Vec<usize>>>,
/// Stored fingerprints in insertion order; a fingerprint's position here is
/// the index returned when it was added and reported by queries.
fps: Vec<MhfpFingerprint>,
}
impl MhfpLshIndex {
    /// Band count used by [`MhfpLshIndex::new`].
    const DEFAULT_BANDS: usize = 16;

    /// Creates an empty index that splits `num_hashes` hash values into 16
    /// equally sized bands.
    ///
    /// # Panics
    ///
    /// Panics if `num_hashes` is not divisible by 16, or if it is 0 (which
    /// would leave zero rows per band).
    pub fn new(num_hashes: usize) -> Self {
        assert!(
            num_hashes.is_multiple_of(Self::DEFAULT_BANDS),
            "num_hashes ({num_hashes}) must be divisible by 16 for default band decomposition"
        );
        Self::with_bands(Self::DEFAULT_BANDS, num_hashes / Self::DEFAULT_BANDS)
    }

    /// Creates an empty index with an explicit band layout.
    ///
    /// Fingerprints added later must carry exactly `bands * rows` hashes.
    ///
    /// # Panics
    ///
    /// Panics if `bands` or `rows` is 0.
    pub fn with_bands(bands: usize, rows: usize) -> Self {
        assert!(bands > 0 && rows > 0, "bands and rows must be > 0");
        MhfpLshIndex {
            bands,
            rows,
            buckets: vec![HashMap::new(); bands],
            fps: Vec::new(),
        }
    }

    /// Stores `fp` and returns the index under which queries will report it.
    ///
    /// Indices are assigned sequentially, starting at 0.
    ///
    /// # Panics
    ///
    /// Panics if `fp` does not contain exactly `bands * rows` hashes.
    pub fn add(&mut self, fp: MhfpFingerprint) -> usize {
        let expected = self.bands * self.rows;
        assert_eq!(
            fp.hashes.len(),
            expected,
            "fingerprint has {} hashes but index expects {} (bands={}, rows={})",
            fp.hashes.len(),
            expected,
            self.bands,
            self.rows
        );
        let idx = self.fps.len();
        // The length check above guarantees `chunks_exact` yields exactly one
        // chunk per band, so the zip pairs every bucket with its band.
        for (bucket, band) in self
            .buckets
            .iter_mut()
            .zip(fp.hashes.chunks_exact(self.rows))
        {
            bucket.entry(band.to_vec()).or_default().push(idx);
        }
        self.fps.push(fp);
        idx
    }

    /// Finds stored fingerprints similar to `fp`.
    ///
    /// A stored fingerprint is a candidate when at least one of its bands is
    /// identical to the same band of `fp`. Each candidate is scored with
    /// `fp.tanimoto(..)` and kept when the score is `>= threshold`.
    ///
    /// Returns `(index, similarity)` pairs sorted by descending similarity;
    /// equal similarities are ordered by ascending index so the output is
    /// deterministic.
    ///
    /// A query whose hash count differs from `bands * rows` does not panic
    /// here: only its complete leading bands (at most `bands` of them) are
    /// looked up. Stored band keys always hold exactly `rows` values, so a
    /// partial trailing band could never match anyway.
    // NOTE(review): how `tanimoto` treats fingerprints of different lengths is
    // defined outside this file — confirm before relying on such queries.
    pub fn query(&self, fp: &MhfpFingerprint, threshold: f64) -> Vec<(usize, f64)> {
        // The set de-duplicates fingerprints that collide in several bands.
        let candidates: HashSet<usize> = self
            .buckets
            .iter()
            .zip(fp.hashes.chunks_exact(self.rows))
            .filter_map(|(bucket, band)| bucket.get(band))
            .flatten()
            .copied()
            .collect();

        let mut results: Vec<(usize, f64)> = candidates
            .into_iter()
            .filter_map(|idx| {
                let sim = fp.tanimoto(&self.fps[idx]);
                (sim >= threshold).then_some((idx, sim))
            })
            .collect();

        // `total_cmp` has no panic path; the index tie-break removes the
        // dependence on hash-set iteration order.
        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        results
    }

    /// Returns the number of stored fingerprints.
    pub fn len(&self) -> usize {
        self.fps.len()
    }

    /// Returns `true` when no fingerprint has been added yet.
    pub fn is_empty(&self) -> bool {
        self.fps.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chematic_smiles::parse;
    use crate::mhfp::{mhfp, MhfpConfig};

    /// Parses `smiles` and fingerprints it with `mhfp`.
    fn fp(smiles: &str) -> MhfpFingerprint {
        mhfp(&parse(smiles).unwrap())
    }

    /// Parses `smiles` and fingerprints it with the explicit configuration `cfg`.
    fn fp_cfg(smiles: &str, cfg: &MhfpConfig) -> MhfpFingerprint {
        crate::mhfp::mhfp_with_config(&parse(smiles).unwrap(), cfg)
    }

    /// Querying an index that holds nothing yields nothing.
    #[test]
    fn test_lsh_empty_query() {
        let hits = MhfpLshIndex::new(128).query(&fp("CC"), 0.5);
        assert!(hits.is_empty(), "empty index should return no results");
    }

    /// A stored fingerprint is returned for itself with similarity 1.0.
    #[test]
    fn test_lsh_self_similarity() {
        let benzene = fp("c1ccccc1");
        let mut index = MhfpLshIndex::new(128);
        index.add(benzene.clone());

        let hits = index.query(&benzene, 0.99);
        assert!(!hits.is_empty(), "self should be found at threshold 0.99");

        let (top_idx, top_sim) = hits[0];
        assert_eq!(top_idx, 0, "self should be index 0");
        assert!(
            (top_sim - 1.0).abs() < 1e-9,
            "self-similarity should be 1.0, got {}",
            top_sim
        );
    }

    /// Every returned hit satisfies the requested threshold.
    #[test]
    fn test_lsh_threshold_filtering() {
        let benzene_fp = fp("c1ccccc1");
        let ethane_fp = fp("CC");

        let mut index = MhfpLshIndex::new(128);
        index.add(benzene_fp.clone());
        index.add(ethane_fp);

        // Pick a cut-off strictly above the benzene/ethane similarity.
        let sim_benz_eth = benzene_fp.tanimoto(&fp("CC"));
        let high_threshold = sim_benz_eth + 0.1;

        let hits = index.query(&benzene_fp, high_threshold);
        for &(_, sim) in &hits {
            assert!(sim >= high_threshold, "all results must meet threshold, got {}", sim);
        }
    }

    /// A fingerprint is found among other molecules stored before it.
    #[test]
    fn test_lsh_similar_mols_found() {
        let benzene = fp("c1ccccc1");
        let toluene = fp("Cc1ccccc1");
        let ethane = fp("CC");

        let mut index = MhfpLshIndex::new(128);
        index.add(ethane);
        index.add(toluene);
        index.add(benzene.clone());

        // Benzene was added last, so it sits at index 2.
        let hits = index.query(&benzene, 0.99);
        let found_self = hits.iter().any(|&(i, _)| i == 2);
        assert!(found_self, "benzene should find itself in the index");

        let sim = benzene.tanimoto(&fp("Cc1ccccc1"));
        assert!(sim > 0.0, "benzene-toluene similarity should be > 0");
    }

    /// Hits come back ordered from most to least similar.
    #[test]
    fn test_lsh_sorted_descending() {
        let library = ["c1ccccc1", "Cc1ccccc1", "c1ccc2ccccc2c1", "CC", "CCC"];
        let mut index = MhfpLshIndex::new(128);
        for smiles in library.iter() {
            index.add(fp(smiles));
        }

        let hits = index.query(&fp("c1ccccc1"), 0.0);
        for pair in hits.windows(2) {
            let (earlier, later) = (pair[0].1, pair[1].1);
            assert!(
                earlier >= later,
                "results should be sorted descending: {} >= {}",
                earlier,
                later
            );
        }
    }

    /// A non-default band layout (32 bands x 4 rows) still finds the stored fingerprint.
    #[test]
    fn test_lsh_custom_bands() {
        let cfg = MhfpConfig {
            num_hashes: 128,
            seed: 0,
            radius: 2,
        };
        let benzene = fp_cfg("c1ccccc1", &cfg);

        let mut index = MhfpLshIndex::with_bands(32, 4);
        index.add(benzene.clone());

        let hits = index.query(&benzene, 0.99);
        assert!(!hits.is_empty(), "custom bands: self should be found");
        assert!((hits[0].1 - 1.0).abs() < 1e-9);
    }
}