use crate::bitvec::BitVec2048;
/// Computes the Tanimoto similarity between `query` and every fingerprint in `db`.
///
/// The returned vector is dense and index-aligned with `db`: `result[j]` is the score of
/// `query` against `db[j]`. An empty `db` yields an empty vector.
///
/// The query's popcount is computed once and reused for every comparison.
pub fn tanimoto_slice(query: &BitVec2048, db: &[BitVec2048]) -> Vec<f32> {
    if db.is_empty() {
        return Vec::new();
    }

    let query_bits = query.popcount();
    let mut scores = Vec::with_capacity(db.len());
    for target in db {
        scores.push(query.tanimoto_with_counts(target, query_bits, target.popcount()));
    }
    scores
}
/// Computes all pairwise Tanimoto similarities between `queries` and `db`.
///
/// The result is a flat, row-major `queries.len() × db.len()` matrix. The score of
/// `queries[i]` against `db[j]` is stored at index `i * db.len() + j`. If either input is
/// empty, an empty vector is returned.
///
/// Each fingerprint's popcount is computed exactly once and reused across all pairs.
pub fn tanimoto_matrix(queries: &[BitVec2048], db: &[BitVec2048]) -> Vec<f32> {
    if queries.is_empty() || db.is_empty() {
        return Vec::new();
    }

    let query_bits: Vec<u32> = queries.iter().map(|query| query.popcount()).collect();
    let target_bits: Vec<u32> = db.iter().map(|target| target.popcount()).collect();

    let mut matrix = Vec::with_capacity(queries.len() * db.len());
    for (query, &q_bits) in queries.iter().zip(&query_bits) {
        // One row per query: its score against every database entry, in db order.
        matrix.extend(
            db.iter()
                .zip(&target_bits)
                .map(|(target, &t_bits)| query.tanimoto_with_counts(target, q_bits, t_bits)),
        );
    }
    matrix
}
/// Returns the `k` database entries most similar to `query`, as `(db_index, score)` pairs.
///
/// Results are ordered by descending Tanimoto score. Entries with equal scores are ordered
/// by ascending database index, so the output (including which tied entries make the cut
/// at rank `k`) is deterministic. A NaN score is not expected from Tanimoto but is guarded
/// against: it ranks below every real score.
///
/// `k` is clamped to `db.len()`. If `k == 0` or `db` is empty, an empty vector is returned.
pub fn top_k_similar(query: &BitVec2048, db: &[BitVec2048], k: usize) -> Vec<(usize, f32)> {
    if k == 0 || db.is_empty() {
        return vec![];
    }
    let k = k.min(db.len());
    let mut indexed: Vec<(usize, f32)> = tanimoto_slice(query, db)
        .into_iter()
        .enumerate()
        .collect();
    // Partition in O(n) so the best k hits occupy indexed[..k], then sort only those.
    // Both steps use the same total order, so the chosen set and its ordering agree.
    indexed.select_nth_unstable_by(k - 1, rank_by_score_desc_then_index);
    indexed.truncate(k);
    indexed.sort_unstable_by(rank_by_score_desc_then_index);
    indexed
}

/// Total order on `(db_index, score)` hits: higher score first, NaN last, then lower index first.
fn rank_by_score_desc_then_index(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
    // Reporting NaN as `Equal` to everything would make the comparator intransitive.
    // The slice selection/sort APIs may then misorder or panic. Mapping NaN to -inf and
    // using `total_cmp` gives a genuine total order.
    let key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    key(b.1)
        .total_cmp(&key(a.1))
        .then_with(|| a.0.cmp(&b.0))
}
// Unit tests for the bulk Tanimoto helpers. Fingerprints are built from SMILES strings via
// the crate's `ecfp4` and the `chematic_smiles` parser.
#[cfg(test)]
mod tests {
use super::*;
use crate::ecfp::ecfp4;
use chematic_smiles::parse;
// Builds an ECFP4 fingerprint for a SMILES string; panics (via `unwrap`) if parsing fails,
// so every fixture SMILES below must be valid.
fn fp(smi: &str) -> BitVec2048 {
ecfp4(&parse(smi).unwrap())
}
// The bulk path (popcount hoisted, `tanimoto_with_counts`) must agree with the per-pair
// `BitVec2048::tanimoto` (cast to f32) within 1e-5, including an identical pair (last entry).
#[test]
fn test_tanimoto_slice_parity_with_single() {
let q = fp("c1ccccc1");
let db = vec![fp("CC"), fp("c1ccccc1N"), fp("CCO"), fp("c1ccccc1")];
let bulk = tanimoto_slice(&q, &db);
for (i, bv) in db.iter().enumerate() {
let single = q.tanimoto(bv) as f32;
assert!(
(bulk[i] - single).abs() < 1e-5,
"slice[{i}] = {:.6} ≠ single {:.6}", bulk[i], single
);
}
}
// Output is dense: one score per db entry even for a dissimilar molecule (no filtering of
// low/zero scores), and the score lies in [0, 1].
#[test]
fn test_tanimoto_slice_dense_no_zero_filter() {
let q = fp("c1ccccc1");
let db = vec![fp("CC")];
let result = tanimoto_slice(&q, &db);
assert_eq!(result.len(), 1);
assert!((0.0..=1.0).contains(&result[0]));
}
// Pins the convention that two all-zero fingerprints (0/0 case) score 1.0.
#[test]
fn test_tanimoto_slice_all_zero_convention() {
let zero = BitVec2048::new();
let result = tanimoto_slice(&zero, &[zero.clone()]);
assert_eq!(result.len(), 1);
assert!((result[0] - 1.0).abs() < 1e-6, "both zero → 1.0");
}
#[test]
fn test_tanimoto_slice_empty_db() {
let q = fp("CC");
let result = tanimoto_slice(&q, &[]);
assert!(result.is_empty());
}
// The matrix is flat row-major: entry (i, j) lives at i * db.len() + j and must equal
// row i computed independently with `tanimoto_slice`.
#[test]
fn test_tanimoto_matrix_row_major() {
let queries = vec![fp("CC"), fp("c1ccccc1")];
let db = vec![fp("CCO"), fp("c1ccccc1N"), fp("CCCC")];
let matrix = tanimoto_matrix(&queries, &db);
assert_eq!(matrix.len(), queries.len() * db.len());
for (i, q) in queries.iter().enumerate() {
let row = tanimoto_slice(q, &db);
for (j, &expected) in row.iter().enumerate() {
assert!(
(matrix[i * db.len() + j] - expected).abs() < 1e-6,
"matrix[{i}][{j}] mismatch"
);
}
}
}
// Either side being empty yields an empty matrix.
#[test]
fn test_tanimoto_matrix_empty() {
let q = vec![fp("CC")];
assert!(tanimoto_matrix(&q, &[]).is_empty());
assert!(tanimoto_matrix(&[], &q).is_empty());
}
// Hits come back with non-increasing scores.
#[test]
fn test_top_k_similar_sorted_descending() {
let query = fp("c1ccccc1");
let db = vec![fp("CC"), fp("c1ccccc1"), fp("c1ccccc1N"), fp("CCCC")];
let k = 3;
let hits = top_k_similar(&query, &db, k);
assert_eq!(hits.len(), k);
for w in hits.windows(2) {
assert!(
w[0].1 >= w[1].1,
"not sorted: {:.4} < {:.4}", w[0].1, w[1].1
);
}
}
// Compares top_k_similar against a full descending sort truncated to k.
// NOTE(review): the reference comparator looks only at scores, so tied entries keep
// whatever order the unstable sort leaves them in. The index assertion therefore assumes
// ties come out in the same order as top_k_similar produces. Several aliphatic fixtures
// (CC, CCO, CCCC) may tie against benzene — verify if this test becomes flaky.
#[test]
fn test_top_k_similar_matches_sorted_slice() {
let query = fp("c1ccccc1");
let db = vec![fp("CC"), fp("c1ccccc1"), fp("CCO"), fp("c1ccccc1N"), fp("CCCC")];
let k = 3;
let hits = top_k_similar(&query, &db, k);
let mut all: Vec<(usize, f32)> = tanimoto_slice(&query, &db)
.into_iter().enumerate().collect();
all.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
all.truncate(k);
for (i, (hit_idx, hit_score)) in hits.iter().enumerate() {
assert_eq!(*hit_idx, all[i].0, "index mismatch at rank {i}");
assert!((hit_score - all[i].1).abs() < 1e-6, "score mismatch at rank {i}");
}
}
#[test]
fn test_top_k_similar_k_zero() {
let query = fp("CC");
let db = vec![fp("CCC")];
assert!(top_k_similar(&query, &db, 0).is_empty());
}
// k larger than the database is clamped rather than padded or panicking.
#[test]
fn test_top_k_similar_k_exceeds_db() {
let query = fp("CC");
let db = vec![fp("CCC"), fp("CCCC")];
let hits = top_k_similar(&query, &db, 100);
assert_eq!(hits.len(), db.len(), "returns all when k > db.len()");
}
#[test]
fn test_top_k_similar_empty_db() {
let query = fp("CC");
assert!(top_k_similar(&query, &[], 5).is_empty());
}
// A molecule compared with itself must score exactly 1.0 (within f32 tolerance).
#[test]
fn test_self_similarity_is_one() {
let query = fp("c1ccccc1N");
let db = vec![fp("c1ccccc1N")];
let hits = top_k_similar(&query, &db, 1);
assert_eq!(hits.len(), 1);
assert!((hits[0].1 - 1.0).abs() < 1e-6, "self-similarity must be 1.0");
}
}