use std::collections::{HashMap, HashSet};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::structure::Molecule;
/// Tunable parameters for contact-based clustering of structures.
#[derive(Debug, Clone)]
pub struct ClusteringConfig {
    /// Largest atom-to-atom distance (in the same units as the atom
    /// coordinates) at which two residues are considered to be in contact.
    pub contact_distance: f64,
    /// Minimum fraction of common contacts one structure must share with
    /// another to accept it as a neighbor.
    pub fcc_cutoff: f64,
    /// Multiplier applied to `fcc_cutoff` to obtain the relaxed threshold
    /// that the reverse-direction fraction must also reach.
    pub strictness: f64,
    /// Clusters with fewer members than this are dropped from the final result.
    pub min_cluster_size: usize,
}

impl Default for ClusteringConfig {
    /// Returns the stock settings: contact distance `5.0`, cutoff `0.60`,
    /// strictness `0.75` and a minimum cluster size of `4`.
    fn default() -> Self {
        ClusteringConfig {
            contact_distance: 5.0,
            fcc_cutoff: 0.60,
            strictness: 0.75,
            min_cluster_size: 4,
        }
    }
}
/// One cluster produced by the clustering routines in this file.
#[derive(Debug, Clone)]
pub struct ClusterResult {
/// Index of the element that was selected as the center of the cluster.
pub center_idx: usize,
/// Indices of every element in the cluster. As built by `cluster_elements`,
/// this includes the center and is sorted in ascending order.
pub members: Vec<usize>,
/// Number of members; `cluster_elements` sets this to `members.len()`.
pub size: usize,
}
/// Per-structure bookkeeping used while clustering.
struct Element {
// Indices of the elements this one accepts as neighbors. `create_elements`
// does not exclude an element's own index, so a self-entry may be present;
// `cluster_elements` skips it when collecting members.
neighbors: HashSet<usize>,
}
/// Maps a residue, keyed by `(chain id, residue sequence number)`, to the
/// `(x, y, z)` coordinates of the atoms collected for that residue.
type ResidueCoordMap = HashMap<(char, i16), Vec<(f64, f64, f64)>>;
/// Collects the inter-chain residue contacts of `molecule`.
///
/// Atoms whose element (after trimming) is `"H"` are ignored. Two residues on
/// different chains are in contact when at least one pair of their remaining
/// atoms lies within `contact_distance` (inclusive) of each other.
///
/// Each contact is reported once as `"<chain> <resseq> <chain> <resseq>"`,
/// with the residue on the smaller chain identifier written first.
pub fn calculate_contacts(molecule: &Molecule, contact_distance: f64) -> HashSet<String> {
    // Group the coordinates of all non-hydrogen atoms by residue.
    let mut residues: ResidueCoordMap = HashMap::new();
    for atom in &molecule.0 {
        if atom.element.trim() != "H" {
            residues
                .entry((atom.chainid, atom.resseq))
                .or_default()
                .push((atom.x, atom.y, atom.z));
        }
    }

    let entries: Vec<_> = residues.iter().collect();
    let mut contacts = HashSet::new();

    // Visit every unordered pair of residues exactly once.
    for (idx, &(&(chain_a, res_a), atoms_a)) in entries.iter().enumerate() {
        for &(&(chain_b, res_b), atoms_b) in &entries[idx + 1..] {
            if chain_a == chain_b {
                continue;
            }

            // `any` stops at the first atom pair that is close enough.
            let touching = atoms_a.iter().any(|&(xa, ya, za)| {
                atoms_b.iter().any(|&(xb, yb, zb)| {
                    let dist =
                        ((xa - xb).powi(2) + (ya - yb).powi(2) + (za - zb).powi(2)).sqrt();
                    dist <= contact_distance
                })
            });

            if touching {
                // Normalize the label so the smaller chain id comes first.
                let label = if chain_a < chain_b {
                    format!("{} {} {} {}", chain_a, res_a, chain_b, res_b)
                } else {
                    format!("{} {} {} {}", chain_b, res_b, chain_a, res_a)
                };
                contacts.insert(label);
            }
        }
    }

    contacts
}
/// Computes the fraction of common contacts between two contact sets.
///
/// Returns `(|x ∩ y| / |x|, |x ∩ y| / |y|)`. If either set is empty the
/// result is `(0.0, 0.0)`.
fn calculate_fcc(x: &HashSet<String>, y: &HashSet<String>) -> (f64, f64) {
    match (x.len(), y.len()) {
        // An empty set shares nothing and would otherwise divide by zero.
        (0, _) | (_, 0) => (0.0, 0.0),
        (len_x, len_y) => {
            let shared = x.intersection(y).count() as f64;
            (shared / len_x as f64, shared / len_y as f64)
        }
    }
}
/// Evaluates `calculate_fcc` for every ordered pair of contact sets.
///
/// The result holds one `(i, j, fcc, fcc_v)` tuple per pair, including the
/// `i == j` pairs, listed in row-major order (`i` outer, `j` inner). With the
/// `parallel` feature enabled the pairs are evaluated through rayon.
fn calculate_pairwise_fcc(contact_sets: &[HashSet<String>]) -> Vec<(usize, usize, f64, f64)> {
    #[cfg(feature = "parallel")]
    {
        let n = contact_sets.len();
        // Flatten the (i, j) grid into a single index range; for n == 0 the
        // range is empty, so the division below is never evaluated.
        (0..n * n)
            .into_par_iter()
            .map(|flat| {
                let (i, j) = (flat / n, flat % n);
                let (fcc, fcc_v) = calculate_fcc(&contact_sets[i], &contact_sets[j]);
                (i, j, fcc, fcc_v)
            })
            .collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        contact_sets
            .iter()
            .enumerate()
            .flat_map(|(i, set_i)| {
                contact_sets.iter().enumerate().map(move |(j, set_j)| {
                    let (fcc, fcc_v) = calculate_fcc(set_i, set_j);
                    (i, j, fcc, fcc_v)
                })
            })
            .collect()
    }
}
/// Builds the neighbor graph used for clustering.
///
/// One [`Element`] is created for each index in `0..n_structures`. For every
/// `(i, j, fcc, fcc_v)` tuple, `j` becomes a neighbor of `i` when `fcc`
/// reaches `config.fcc_cutoff` and `fcc_v` reaches the relaxed cutoff
/// (`fcc_cutoff * strictness`); `i` becomes a neighbor of `j` under the
/// mirrored condition.
///
/// # Panics
///
/// Panics if a tuple that passes a threshold refers to an index outside
/// `0..n_structures`.
fn create_elements(
    pairwise_fcc: Vec<(usize, usize, f64, f64)>,
    n_structures: usize,
    config: &ClusteringConfig,
) -> HashMap<usize, Element> {
    let mut elements: HashMap<usize, Element> = (0..n_structures)
        .map(|idx| {
            let blank = Element {
                neighbors: HashSet::new(),
            };
            (idx, blank)
        })
        .collect();

    let relaxed_cutoff = config.fcc_cutoff * config.strictness;

    for (i, j, fcc, fcc_v) in pairwise_fcc {
        let i_accepts_j = fcc >= config.fcc_cutoff && fcc_v >= relaxed_cutoff;
        let j_accepts_i = fcc_v >= config.fcc_cutoff && fcc >= relaxed_cutoff;

        if i_accepts_j {
            elements.get_mut(&i).unwrap().neighbors.insert(j);
        }
        if j_accepts_i {
            elements.get_mut(&j).unwrap().neighbors.insert(i);
        }
    }

    elements
}
/// Greedily partitions the neighbor graph into clusters.
///
/// Each round selects, among the elements not yet assigned to a cluster, the
/// one with the most unassigned neighbors (ties go to the lowest index). That
/// element becomes a cluster center and claims all of its still-unassigned
/// neighbors. Rounds repeat until every element is assigned, so elements
/// without usable neighbors end up as single-member clusters.
///
/// Every returned cluster lists its members (center included) in ascending
/// order. Clusters are ordered by decreasing size; clusters of equal size keep
/// the order in which they were formed.
fn cluster_elements(mut elements: HashMap<usize, Element>) -> Vec<ClusterResult> {
    let mut used: HashSet<usize> = HashSet::new();
    let mut clusters = Vec::new();

    loop {
        // Score each unassigned element once per round. The comparator ranks
        // by free-neighbor count and then by *lower* index, which makes the
        // choice independent of the map's iteration order.
        let best = elements
            .iter()
            .filter(|(idx, _)| !used.contains(*idx))
            .map(|(&idx, element)| {
                let free = element
                    .neighbors
                    .iter()
                    .filter(|n| !used.contains(*n))
                    .count();
                (free, idx)
            })
            .max_by(|a, b| a.0.cmp(&b.0).then_with(|| b.1.cmp(&a.1)));

        let center = match best {
            Some((_, idx)) => idx,
            None => break,
        };

        // Claim the unassigned neighbors; a self-entry must not be duplicated.
        let mut members: Vec<usize> = elements[&center]
            .neighbors
            .iter()
            .copied()
            .filter(|n| *n != center && !used.contains(n))
            .collect();
        used.extend(members.iter().copied());
        used.insert(center);

        members.push(center);
        // Members are distinct, so an unstable sort yields the same order.
        members.sort_unstable();

        clusters.push(ClusterResult {
            center_idx: center,
            size: members.len(),
            members,
        });
        elements.remove(&center);
    }

    // Stable sort: equally sized clusters stay in formation order.
    clusters.sort_by(|a, b| b.size.cmp(&a.size));
    clusters
}
/// Clusters `structures` by the similarity of their inter-chain contacts.
///
/// The pipeline computes one contact set per structure, scores every ordered
/// pair of sets, builds the neighbor graph and clusters it greedily. Clusters
/// smaller than `config.min_cluster_size` are removed from the result. The
/// indices in the returned clusters are positions in `structures`.
///
/// An empty `structures` slice yields an empty vector. With the `parallel`
/// feature enabled the contact sets are computed through rayon.
pub fn cluster_structures(
    structures: &[Molecule],
    config: &ClusteringConfig,
) -> Vec<ClusterResult> {
    if structures.is_empty() {
        return Vec::new();
    }

    // Shared by both build variants; only one of them is compiled in.
    let contacts_of = |mol: &Molecule| calculate_contacts(mol, config.contact_distance);

    #[cfg(feature = "parallel")]
    let contact_sets: Vec<HashSet<String>> = structures.par_iter().map(contacts_of).collect();
    #[cfg(not(feature = "parallel"))]
    let contact_sets: Vec<HashSet<String>> = structures.iter().map(contacts_of).collect();

    let elements = create_elements(
        calculate_pairwise_fcc(&contact_sets),
        structures.len(),
        config,
    );

    // `retain` keeps the surviving clusters in their existing order.
    let mut clusters = cluster_elements(elements);
    clusters.retain(|cluster| cluster.size >= config.min_cluster_size);
    clusters
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOLERANCE: f64 = 0.001;

    /// Builds a contact set from string literals.
    fn contact_set(entries: &[&str]) -> HashSet<String> {
        entries.iter().map(|entry| entry.to_string()).collect()
    }

    #[test]
    fn test_clustering_config_default() {
        let ClusteringConfig {
            contact_distance,
            fcc_cutoff,
            strictness,
            min_cluster_size,
        } = ClusteringConfig::default();
        assert!((contact_distance - 5.0).abs() < TOLERANCE);
        assert!((fcc_cutoff - 0.60).abs() < TOLERANCE);
        assert!((strictness - 0.75).abs() < TOLERANCE);
        assert_eq!(min_cluster_size, 4);
    }

    #[test]
    fn test_calculate_fcc_identical() {
        let contacts = contact_set(&["A 1 B 2", "A 3 B 4"]);
        let (forward, reverse) = calculate_fcc(&contacts, &contacts);
        assert!((forward - 1.0).abs() < TOLERANCE);
        assert!((reverse - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_calculate_fcc_disjoint() {
        let left = contact_set(&["A 1 B 2"]);
        let right = contact_set(&["A 5 B 6"]);
        let (forward, reverse) = calculate_fcc(&left, &right);
        assert!((forward - 0.0).abs() < TOLERANCE);
        assert!((reverse - 0.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_calculate_fcc_partial_overlap() {
        let left = contact_set(&["A 1 B 2", "A 3 B 4"]);
        let right = contact_set(&["A 1 B 2", "A 5 B 6"]);
        let (forward, reverse) = calculate_fcc(&left, &right);
        assert!((forward - 0.5).abs() < TOLERANCE);
        assert!((reverse - 0.5).abs() < TOLERANCE);
    }

    #[test]
    fn test_calculate_fcc_empty() {
        let left = contact_set(&[]);
        let right = contact_set(&[]);
        let (forward, reverse) = calculate_fcc(&left, &right);
        assert!((forward - 0.0).abs() < TOLERANCE);
        assert!((reverse - 0.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_cluster_structures_empty() {
        let no_structures: &[Molecule] = &[];
        let clusters = cluster_structures(no_structures, &ClusteringConfig::default());
        assert!(clusters.is_empty());
    }

    /// Needs the PDB fixtures under `data/` to be present.
    #[test]
    fn test_calculate_contacts_with_real_data() {
        use crate::commands::run::combine_molecules;
        use crate::structure::read_pdb;

        let receptor_model = read_pdb(&String::from("data/2oob_A.pdb"));
        let ligand_model = read_pdb(&String::from("data/2oob_B.pdb"));
        let complex = combine_molecules(&receptor_model.0[0], &ligand_model.0[0]);

        let contacts = calculate_contacts(&complex, 5.0);
        assert!(!contacts.is_empty(), "Should detect inter-chain contacts");

        for contact in &contacts {
            let fields: Vec<&str> = contact.split_whitespace().collect();
            assert_eq!(
                fields.len(),
                4,
                "Contact format should be 'chain res chain res'"
            );
            assert_ne!(fields[0], fields[2], "Contacts should be inter-chain");
        }
    }
}