use std::collections::HashSet;
use chematic_core::{AtomIdx, Molecule, MoleculeBuilder};
use crate::ecfp::fnv1a as fnv1a_hash;
/// Tunable parameters for MinHash fingerprint generation.
#[derive(Clone, Debug)]
pub struct MhfpConfig {
    /// Number of min-hash slots in the resulting fingerprint.
    pub num_hashes: usize,
    /// Base seed; slot `h` is hashed with `seed.wrapping_add(h)`.
    pub seed: u64,
    /// Largest neighbourhood radius; shingles are produced for every radius in `0..=radius`.
    pub radius: u32,
}

impl Default for MhfpConfig {
    /// Returns the default configuration: 128 hash slots, seed 0, radius 2.
    fn default() -> Self {
        Self {
            num_hashes: 128,
            seed: 0,
            radius: 2,
        }
    }
}
/// A MinHash fingerprint: one minimum hash value per hash slot.
#[derive(Clone, Debug)]
pub struct MhfpFingerprint {
    /// Per-slot minimum hash values.
    pub hashes: Vec<u64>,
    /// Number of hash slots the fingerprint was generated with.
    pub num_hashes: usize,
}

impl MhfpFingerprint {
    /// Estimates the Tanimoto (Jaccard) similarity between two fingerprints as
    /// the fraction of hash slots holding the same value.
    ///
    /// The result is always in `[0.0, 1.0]` and is symmetric:
    /// `a.tanimoto(&b) == b.tanimoto(&a)`.
    ///
    /// * Two fingerprints with no slots at all compare as identical (`1.0`).
    /// * When the fingerprints differ in length, the slots missing from the
    ///   shorter one count as disagreements, so the denominator is the longer
    ///   length.
    pub fn tanimoto(&self, other: &MhfpFingerprint) -> f64 {
        // Use the actual vector lengths rather than the public `num_hashes`
        // field, which can drift out of sync with `hashes` and would otherwise
        // allow results above 1.0 or an order-dependent denominator.
        let slots = self.hashes.len().max(other.hashes.len());
        if slots == 0 {
            return 1.0;
        }
        let agreements = self
            .hashes
            .iter()
            .zip(&other.hashes)
            .filter(|(lhs, rhs)| lhs == rhs)
            .count();
        agreements as f64 / slots as f64
    }
}
/// Builds a new molecule containing only the atoms listed in `atom_set` and
/// the bonds of `mol` whose two endpoints are both in that set.
///
/// Atoms are added to the new molecule in the order they appear in
/// `atom_set`; each is a clone of the corresponding atom in `mol`.
fn extract_subgraph(mol: &Molecule, atom_set: &[AtomIdx]) -> Molecule {
    use std::collections::HashMap;

    let mut builder = MoleculeBuilder::new();

    // Maps an atom index in `mol` to its index in the molecule being built.
    let mut remap: HashMap<AtomIdx, AtomIdx> = HashMap::with_capacity(atom_set.len());
    for &original in atom_set {
        let copied = builder.add_atom(mol.atom(original).clone());
        remap.insert(original, copied);
    }

    for (_, bond) in mol.bonds() {
        // Keep the bond only when both endpoints were selected.
        let endpoints = remap.get(&bond.atom1).zip(remap.get(&bond.atom2));
        if let Some((&first, &second)) = endpoints {
            // The result of `add_bond` is intentionally ignored.
            let _ = builder.add_bond(first, second, bond.order);
        }
    }

    builder.build()
}
/// Hashes the circular environment of `center` out to `radius` bonds.
///
/// The atoms within `radius` of `center` are extracted into a standalone
/// sub-molecule, rendered with `chematic_smiles::canonical_smiles`, and the
/// bytes of that string are hashed with FNV-1a.
fn circular_smiles_shingle(mol: &Molecule, center: AtomIdx, radius: u32) -> u64 {
    let neighbourhood = atoms_within_radius(mol, center, radius);
    let fragment = extract_subgraph(mol, &neighbourhood);
    fnv1a_hash(chematic_smiles::canonical_smiles(&fragment).as_bytes())
}
/// Collects every atom reachable from `center` in at most `radius` bond steps.
///
/// The result is in breadth-first discovery order and always starts with
/// `center` itself; each atom appears exactly once.
fn atoms_within_radius(mol: &Molecule, center: AtomIdx, radius: u32) -> Vec<AtomIdx> {
    let mut seen = HashSet::new();
    seen.insert(center);

    let mut reached = vec![center];
    // Atoms discovered in the previous step; only these need expanding.
    let mut frontier = vec![center];

    for _ in 0..radius {
        let mut next_layer = Vec::new();
        for &atom in &frontier {
            for (neighbor, _) in mol.neighbors(atom) {
                if seen.insert(neighbor) {
                    next_layer.push(neighbor);
                }
            }
        }
        // Nothing new was found, so larger radii cannot add atoms either.
        if next_layer.is_empty() {
            break;
        }
        reached.extend_from_slice(&next_layer);
        frontier = next_layer;
    }

    reached
}
/// Produces one shingle hash per (atom, radius) pair, for every atom of `mol`
/// and every radius in `0..=radius`.
///
/// Hashes are ordered by atom index, then by increasing radius. Duplicates
/// are not removed.
fn extract_fragment_hashes(mol: &Molecule, radius: u32) -> Vec<u64> {
    let mut hashes = Vec::new();
    for atom in 0..mol.atom_count() {
        let center = AtomIdx(atom as u32);
        for r in 0..=radius {
            hashes.push(circular_smiles_shingle(mol, center, r));
        }
    }
    hashes
}
/// Computes the MinHash fingerprint of `mol` using [`MhfpConfig::default`].
pub fn mhfp(mol: &Molecule) -> MhfpFingerprint {
    let defaults = MhfpConfig::default();
    mhfp_with_config(mol, &defaults)
}
/// Computes the MinHash fingerprint of `mol` with explicit parameters.
///
/// Every shingle hash of the molecule is re-hashed once per slot: the input is
/// the 16-byte little-endian concatenation of the slot seed
/// (`config.seed.wrapping_add(slot)`) and the shingle hash. Each slot keeps
/// the smallest value seen.
///
/// If the molecule yields no shingles, every slot holds `u64::MAX`.
pub fn mhfp_with_config(mol: &Molecule, config: &MhfpConfig) -> MhfpFingerprint {
    let shingles = extract_fragment_hashes(mol, config.radius);

    let hashes: Vec<u64> = (0..config.num_hashes)
        .map(|slot| {
            let seed_bytes = config.seed.wrapping_add(slot as u64).to_le_bytes();
            shingles
                .iter()
                .map(|shingle| {
                    let mut buf = [0u8; 16];
                    buf[..8].copy_from_slice(&seed_bytes);
                    buf[8..].copy_from_slice(&shingle.to_le_bytes());
                    fnv1a_hash(&buf)
                })
                .min()
                // No shingles at all: the slot keeps its sentinel value.
                .unwrap_or(u64::MAX)
        })
        .collect();

    MhfpFingerprint {
        hashes,
        num_hashes: config.num_hashes,
    }
}
/// Alias for [`mhfp`]: computes the fingerprint with the default configuration.
///
/// The `128` in the name matches the default `num_hashes`; this function does
/// not set the slot count itself, it simply forwards to `mhfp`.
pub fn mhfp_128(mol: &Molecule) -> MhfpFingerprint {
mhfp(mol)
}
/// Convenience wrapper: fingerprints both molecules with the default
/// configuration and returns their estimated Tanimoto similarity.
pub fn tanimoto_mhfp(mol1: &Molecule, mol2: &Molecule) -> f64 {
    let first = mhfp(mol1);
    let second = mhfp(mol2);
    first.tanimoto(&second)
}
#[cfg(test)]
mod tests {
    use super::*;
    use chematic_smiles::parse;

    /// Fingerprinting the same molecule twice yields identical hashes.
    #[test]
    fn test_mhfp_consistency() {
        let mol = parse("CC").unwrap();
        let fp1 = mhfp(&mol);
        let fp2 = mhfp(&mol);
        assert_eq!(fp1.hashes, fp2.hashes);
        assert!((fp1.tanimoto(&fp2) - 1.0).abs() < 1e-6);
    }

    /// Similarity between distinct molecules stays inside [0, 1].
    #[test]
    fn test_mhfp_different_molecules() {
        let mol1 = parse("CC").unwrap();
        let mol2 = parse("CCC").unwrap();
        let fp1 = mhfp(&mol1);
        let fp2 = mhfp(&mol2);
        let similarity = fp1.tanimoto(&fp2);
        assert!((0.0..=1.0).contains(&similarity));
    }

    /// Similarity does not depend on argument order.
    #[test]
    fn test_mhfp_symmetry() {
        let mol1 = parse("CC").unwrap();
        let mol2 = parse("CCC").unwrap();
        let sim12 = tanimoto_mhfp(&mol1, &mol2);
        let sim21 = tanimoto_mhfp(&mol2, &mol1);
        assert!((sim12 - sim21).abs() < 1e-10);
    }

    /// A custom slot count is reflected in the fingerprint.
    #[test]
    fn test_mhfp_config() {
        let mol = parse("CC").unwrap();
        let config = MhfpConfig { num_hashes: 64, seed: 42, radius: 2 };
        let fp = mhfp_with_config(&mol, &config);
        assert_eq!(fp.num_hashes, 64);
    }

    /// Radius 0 returns only the centre; radius 1 adds its neighbours.
    #[test]
    fn test_atoms_within_radius() {
        let mol = parse("CCCC").unwrap();
        let center = AtomIdx(1);
        let r0 = atoms_within_radius(&mol, center, 0);
        assert_eq!(r0.len(), 1);
        let r1 = atoms_within_radius(&mol, center, 1);
        assert!(r1.len() >= 2);
    }

    /// A non-empty molecule produces at least one shingle hash.
    #[test]
    fn test_extract_fragment_hashes() {
        let mol = parse("CC").unwrap();
        let hashes = extract_fragment_hashes(&mol, 2);
        assert!(!hashes.is_empty());
    }

    /// The fingerprint must not depend on the atom order of the input SMILES.
    #[test]
    fn test_mhfp_canonical_same_mol_different_parse() {
        let mol1 = parse("CCO").unwrap();
        let mol2 = parse("OCC").unwrap();
        let fp1 = mhfp(&mol1);
        let fp2 = mhfp(&mol2);
        assert!((fp1.tanimoto(&fp2) - 1.0).abs() < 1e-6,
            "Same molecule with reversed SMILES should have Tanimoto=1.0, got {}", fp1.tanimoto(&fp2));
    }

    /// Ethane should be closer to propane than to benzene.
    #[test]
    fn test_mhfp_similar_mols_higher_than_dissimilar() {
        let ethane = parse("CC").unwrap();
        let propane = parse("CCC").unwrap();
        let benzene = parse("c1ccccc1").unwrap();
        let sim_ec = tanimoto_mhfp(&ethane, &propane);
        let sim_eb = tanimoto_mhfp(&ethane, &benzene);
        assert!(sim_ec > sim_eb,
            "ethane~propane ({:.3}) should be > ethane~benzene ({:.3})", sim_ec, sim_eb);
    }

    /// Different aromatic nitrogen heterocycles must be distinguishable.
    #[test]
    fn test_mhfp_pyridine_vs_pyrrole() {
        let pyridine = parse("c1ccncc1").unwrap();
        let pyrrole = parse("c1cc[nH]c1").unwrap();
        let sim = tanimoto_mhfp(&pyridine, &pyrrole);
        assert!(sim < 1.0,
            "pyridine and pyrrole should NOT be identical (got sim={sim:.3})");
    }

    /// Primary, secondary and tertiary amines must yield different fingerprints.
    #[test]
    fn test_mhfp_amine_h_count() {
        let methylamine = parse("CN").unwrap();
        let dimethylamine = parse("CNC").unwrap();
        let trimethylamine = parse("CN(C)C").unwrap();
        let fp_m = mhfp(&methylamine);
        let fp_dm = mhfp(&dimethylamine);
        let fp_tm = mhfp(&trimethylamine);
        assert!(fp_m.tanimoto(&fp_dm) < 1.0);
        assert!(fp_dm.tanimoto(&fp_tm) < 1.0);
    }

    /// Changing the radius changes the resulting hashes.
    #[test]
    fn test_mhfp_radius_effect() {
        let mol = parse("c1ccccc1CC").unwrap();
        let fp_r1 = mhfp_with_config(&mol, &MhfpConfig { radius: 1, ..Default::default() });
        let fp_r3 = mhfp_with_config(&mol, &MhfpConfig { radius: 3, ..Default::default() });
        assert_ne!(fp_r1.hashes, fp_r3.hashes);
    }
}