use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use crate::ecfp::ecfp4;
/// Parameters controlling MHFP (MinHash fingerprint) generation.
#[derive(Clone, Debug)]
pub struct MhfpConfig {
    /// Number of independent hash functions; this is also the length of the
    /// generated fingerprint's `hashes` vector.
    pub num_hashes: usize,
    /// Base seed. Hash function `h` is seeded with `seed.wrapping_add(h)`.
    pub seed: u64,
}

impl Default for MhfpConfig {
    /// 128 hash functions with a base seed of 0.
    fn default() -> Self {
        Self {
            seed: 0,
            num_hashes: 128,
        }
    }
}
/// A MinHash fingerprint: one minimum hash value per hash function.
#[derive(Clone, Debug)]
pub struct MhfpFingerprint {
    /// Minimum hash value observed for each hash function.
    pub hashes: Vec<u64>,
    /// Number of hash functions used to build `hashes`.
    pub num_hashes: usize,
}

impl MhfpFingerprint {
    /// Estimates the Tanimoto (Jaccard) similarity between two fingerprints.
    ///
    /// The estimate is the fraction of hash positions whose minima agree. Only
    /// positions present in both fingerprints are compared, and the count is
    /// divided by that same number. The result is therefore symmetric
    /// (`a.tanimoto(b) == b.tanimoto(a)`) and always lies in `[0.0, 1.0]`, even
    /// when the two fingerprints differ in length.
    ///
    /// Returns `1.0` when there are no positions to compare (either fingerprint
    /// has no hashes).
    ///
    /// Comparing fingerprints built with different seeds is not meaningful;
    /// this method does not check for it.
    pub fn tanimoto(&self, other: &MhfpFingerprint) -> f64 {
        // Use the compared length, not `self.num_hashes`. Dividing by one side's
        // length made the score asymmetric and underestimated it when lengths differ.
        let compared = self.hashes.len().min(other.hashes.len());
        if compared == 0 {
            return 1.0;
        }
        let agreements = self
            .hashes
            .iter()
            .zip(&other.hashes)
            .filter(|(a, b)| a == b)
            .count();
        agreements as f64 / compared as f64
    }
}
/// Computes an MHFP fingerprint of `mol` using `MhfpConfig::default()`.
pub fn mhfp(mol: &chematic_core::Molecule) -> MhfpFingerprint {
    let config = MhfpConfig::default();
    mhfp_with_config(mol, &config)
}
/// Computes a MinHash fingerprint of `mol` from its ECFP4 bit set.
///
/// There are `config.num_hashes` hash functions. Function `h` is seeded with
/// `config.seed.wrapping_add(h)`. It hashes the position of every set ECFP4 bit
/// and keeps the minimum value.
///
/// The hash of a bit position depends only on the seed and the position, never
/// on the molecule. So the same bit in two different molecules hashes to the same
/// value, and the agreement rate between two fingerprints estimates the Jaccard
/// similarity of their bit sets.
///
/// If no ECFP4 bit is set, every entry of `hashes` is `u64::MAX`.
pub fn mhfp_with_config(
    mol: &chematic_core::Molecule,
    config: &MhfpConfig,
) -> MhfpFingerprint {
    let ecfp = ecfp4(mol);
    // NOTE(review): assumes `ecfp4` yields a 2048-bit fingerprint — confirm against crate::ecfp.
    let bit_set: Vec<u64> = (0..2048)
        .filter(|&i| ecfp.get(i))
        .map(|i| i as u64)
        .collect();

    let hashes = (0..config.num_hashes)
        .map(|h| {
            let seed = config.seed.wrapping_add(h as u64);
            bit_set
                .iter()
                .map(|&bit_pos| {
                    // NOTE(review): DefaultHasher's algorithm is not guaranteed to be
                    // stable across Rust releases, so fingerprints persisted to disk
                    // may not be comparable across toolchains.
                    let mut hasher = DefaultHasher::new();
                    hasher.write_u64(seed);
                    hasher.write_u64(bit_pos);
                    hasher.finish()
                })
                .min()
                // Empty bit set: keep the u64::MAX sentinel.
                .unwrap_or(u64::MAX)
        })
        .collect();

    MhfpFingerprint {
        hashes,
        num_hashes: config.num_hashes,
    }
}
pub fn mhfp_128(mol: &chematic_core::Molecule) -> MhfpFingerprint {
mhfp(mol)
}
/// Builds default MHFP fingerprints for both molecules and returns their
/// estimated Tanimoto similarity (see `MhfpFingerprint::tanimoto`).
pub fn tanimoto_mhfp(mol1: &chematic_core::Molecule, mol2: &chematic_core::Molecule) -> f64 {
    mhfp(mol1).tanimoto(&mhfp(mol2))
}
#[cfg(test)]
mod tests {
    use super::*;
    use chematic_smiles::parse;

    #[test]
    fn test_mhfp_simple() {
        let mol = parse("CC").unwrap();
        let fp = mhfp(&mol);
        assert_eq!(fp.num_hashes, 128);
        assert_eq!(fp.hashes.len(), 128);
    }

    #[test]
    fn test_mhfp_identical() {
        let mol = parse("CC").unwrap();
        let fp1 = mhfp(&mol);
        let fp2 = mhfp(&mol);
        assert_eq!(fp1.hashes, fp2.hashes);
        assert!((fp1.tanimoto(&fp2) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mhfp_different_molecules() {
        let mol1 = parse("CC").unwrap();
        let mol2 = parse("CCC").unwrap();
        let fp1 = mhfp(&mol1);
        let fp2 = mhfp(&mol2);
        let similarity = fp1.tanimoto(&fp2);
        assert!((0.0..=1.0).contains(&similarity));
    }

    #[test]
    fn test_mhfp_symmetry() {
        let mol1 = parse("CC").unwrap();
        let mol2 = parse("CCC").unwrap();
        let sim12 = tanimoto_mhfp(&mol1, &mol2);
        let sim21 = tanimoto_mhfp(&mol2, &mol1);
        assert!((sim12 - sim21).abs() < 1e-10);
    }

    #[test]
    fn test_mhfp_config() {
        let mol = parse("CC").unwrap();
        let config = MhfpConfig {
            num_hashes: 64,
            seed: 42,
        };
        let fp = mhfp_with_config(&mol, &config);
        assert_eq!(fp.num_hashes, 64);
        assert_eq!(fp.hashes.len(), 64);
    }

    /// Two independently parsed copies of the same SMILES must be near-identical.
    #[test]
    fn test_mhfp_similar_molecules() {
        let mol1 = parse("CC").unwrap();
        let mol2 = parse("CC").unwrap();
        let fp1 = mhfp(&mol1);
        let fp2 = mhfp(&mol2);
        let similarity = fp1.tanimoto(&fp2);
        assert!(similarity > 0.9);
    }

    /// A single-atom molecule still yields a full-length fingerprint.
    /// (Previously named `test_mhfp_empty_molecule`, although "C" is not empty.)
    #[test]
    fn test_mhfp_single_atom_molecule() {
        let mol = parse("C").unwrap();
        let fp = mhfp(&mol);
        assert_eq!(fp.num_hashes, 128);
        assert_eq!(fp.hashes.len(), 128);
    }

    #[test]
    fn test_mhfp_bounds() {
        let mol1 = parse("CC").unwrap();
        let mol2 = parse("CCCCCCCC").unwrap();
        let fp1 = mhfp(&mol1);
        let fp2 = mhfp(&mol2);
        let similarity = fp1.tanimoto(&fp2);
        assert!((0.0..=1.0).contains(&similarity));
    }

    #[test]
    fn test_mhfp_128_length() {
        let mol = parse("CCC").unwrap();
        let fp = mhfp_128(&mol);
        assert_eq!(fp.num_hashes, 128);
        assert_eq!(fp.hashes.len(), 128);
    }

    /// The same molecule with the same non-default config must hash identically.
    #[test]
    fn test_mhfp_seed_is_deterministic() {
        let mol = parse("CCC").unwrap();
        let config = MhfpConfig {
            num_hashes: 32,
            seed: 7,
        };
        let fp1 = mhfp_with_config(&mol, &config);
        let fp2 = mhfp_with_config(&mol, &config);
        assert_eq!(fp1.hashes, fp2.hashes);
    }

    /// A zero-length fingerprint compares as fully similar to itself rather than
    /// dividing by zero.
    #[test]
    fn test_mhfp_zero_hashes() {
        let mol = parse("CC").unwrap();
        let config = MhfpConfig {
            num_hashes: 0,
            seed: 0,
        };
        let fp = mhfp_with_config(&mol, &config);
        assert!(fp.hashes.is_empty());
        assert!((fp.tanimoto(&fp) - 1.0).abs() < 1e-12);
    }
}