use std::collections::HashSet;
use crate::core::PdbStructure;
use crate::records::Atom;
/// Identifies a residue as `(chain_id, residue_seq)`.
pub type ResidueKey = (String, i32);
/// Two residues that are in contact with each other.
///
/// `find_interface_contacts` and `compare_contacts` build pairs with the
/// smaller key (by tuple ordering) first, so the same two residues always
/// produce the same pair.
pub type ContactPair = (ResidueKey, ResidueKey);
/// Contact-agreement statistics produced by `compare_contacts`.
///
/// Derives `PartialEq` so results can be compared directly (for example with
/// `assert_eq!`). Comparison is exact on the floating-point fields.
#[derive(Debug, Clone, PartialEq)]
pub struct ContactResult {
    /// Fraction of native contacts that also appear among the mapped model
    /// contacts.
    pub fnat: f64,
    /// Fraction of mapped model contacts that are not native contacts.
    pub fnonnat: f64,
    /// F1 score, `2·TP / (2·TP + FP + FN)`, of the mapped model contacts
    /// against the native contacts.
    pub f1: f64,
    /// Number of native contacts.
    pub num_native_contacts: usize,
    /// Number of model contacts before residue mapping is applied.
    pub num_model_contacts: usize,
}
/// Finds residue–residue contacts between two chains.
///
/// Two residues are in contact when any non-hydrogen, non-HETATM atom of a
/// residue on `chain_a` lies strictly closer than `threshold` to any such atom
/// of a residue on `chain_b`.
///
/// # Arguments
/// * `structure` - structure whose `atoms` are searched.
/// * `chain_a`, `chain_b` - chain identifiers selecting the two atom groups.
/// * `threshold` - distance cutoff, in the same units as the atom coordinates.
///
/// # Returns
/// The set of contacting residue pairs. Each pair is stored with the smaller
/// key first (tuple ordering on `(chain_id, residue_seq)`). A `threshold` that
/// is zero, negative or NaN yields an empty set.
pub fn find_interface_contacts(
    structure: &PdbStructure,
    chain_a: &str,
    chain_b: &str,
    threshold: f64,
) -> HashSet<ContactPair> {
    // No distance is below a non-positive cutoff. Without this guard a
    // negative threshold would be squared and act like its absolute value.
    if threshold.is_nan() || threshold <= 0.0 {
        return HashSet::new();
    }
    let threshold_sq = threshold * threshold;

    let atoms_a: Vec<&Atom> = structure
        .atoms
        .iter()
        .filter(|a| a.chain_id == chain_a && !a.is_hetatm && !is_hydrogen(a))
        .collect();
    let atoms_b: Vec<&Atom> = structure
        .atoms
        .iter()
        .filter(|a| a.chain_id == chain_b && !a.is_hetatm && !is_hydrogen(a))
        .collect();

    // Every atom in `atoms_a` is on `chain_a` and every atom in `atoms_b` is
    // on `chain_b`, so inside the hot loop a residue pair is fully identified
    // by its two sequence numbers. This avoids allocating chain-id strings for
    // every close atom pair.
    let mut seq_pairs: HashSet<(i32, i32)> = HashSet::new();
    for atom_a in &atoms_a {
        for atom_b in &atoms_b {
            let dx = atom_a.x - atom_b.x;
            let dy = atom_a.y - atom_b.y;
            let dz = atom_a.z - atom_b.z;
            // Compare squared distances to avoid a square root per atom pair.
            if dx * dx + dy * dy + dz * dz < threshold_sq {
                seq_pairs.insert((atom_a.residue_seq, atom_b.residue_seq));
            }
        }
    }

    // Materialise the owned keys once per unique residue pair, in canonical
    // (smaller key first) order.
    seq_pairs
        .into_iter()
        .map(|(seq_a, seq_b)| {
            let key_a = (chain_a.to_string(), seq_a);
            let key_b = (chain_b.to_string(), seq_b);
            if key_a <= key_b {
                (key_a, key_b)
            } else {
                (key_b, key_a)
            }
        })
        .collect()
}
/// Counts atom pairs between two chains that are closer than 2.0 coordinate
/// units.
///
/// Only non-hydrogen, non-HETATM atoms are considered. Every qualifying atom
/// on `chain_a` is compared against every qualifying atom on `chain_b`, and
/// each pair strictly below the cutoff contributes one to the result.
pub fn count_clashes(structure: &PdbStructure, chain_a: &str, chain_b: &str) -> usize {
    let clash_radius = 2.0;
    let limit_sq = clash_radius * clash_radius;

    let side_a: Vec<&Atom> = structure
        .atoms
        .iter()
        .filter(|atom| atom.chain_id == chain_a)
        .filter(|atom| !atom.is_hetatm)
        .filter(|atom| !is_hydrogen(atom))
        .collect();
    let side_b: Vec<&Atom> = structure
        .atoms
        .iter()
        .filter(|atom| atom.chain_id == chain_b)
        .filter(|atom| !atom.is_hetatm)
        .filter(|atom| !is_hydrogen(atom))
        .collect();

    side_a
        .iter()
        .map(|p| {
            side_b
                .iter()
                .filter(|q| {
                    let dx = p.x - q.x;
                    let dy = p.y - q.y;
                    let dz = p.z - q.z;
                    // Squared distance against squared cutoff: no sqrt needed.
                    dx * dx + dy * dy + dz * dz < limit_sq
                })
                .count()
        })
        .sum()
}
/// Scores a model's interface contacts against the native ones.
///
/// Model contacts are first translated through `residue_mapping`; a contact is
/// dropped when either of its residues has no mapping. Translated pairs are
/// re-ordered so the smaller key comes first, then compared with
/// `native_contacts`.
///
/// # Returns
/// A [`ContactResult`] with `fnat`, `fnonnat` and `f1` computed from the
/// translated contacts, plus the sizes of both input sets. When
/// `native_contacts` is empty all three ratios are `0.0`.
pub fn compare_contacts(
    native_contacts: &HashSet<ContactPair>,
    model_contacts: &HashSet<ContactPair>,
    residue_mapping: &std::collections::HashMap<ResidueKey, ResidueKey>,
) -> ContactResult {
    let native_total = native_contacts.len();
    let model_total = model_contacts.len();

    // Nothing to compare against: report zeros rather than dividing by zero.
    if native_total == 0 {
        return ContactResult {
            fnat: 0.0,
            fnonnat: 0.0,
            f1: 0.0,
            num_native_contacts: 0,
            num_model_contacts: model_total,
        };
    }

    // Translate model contacts into native residue numbering.
    let mut translated: HashSet<ContactPair> = HashSet::new();
    for (first, second) in model_contacts {
        if let (Some(a), Some(b)) = (residue_mapping.get(first), residue_mapping.get(second)) {
            let (low, high) = if a <= b { (a, b) } else { (b, a) };
            translated.insert((low.clone(), high.clone()));
        }
    }

    // Translated contacts that are native (true positives); the remainder of
    // each set gives the false positives and false negatives.
    let shared = translated
        .iter()
        .filter(|pair| native_contacts.contains(*pair))
        .count();
    let spurious = translated.len() - shared;
    let missed = native_total - shared;

    let fnat = shared as f64 / native_total as f64;
    let fnonnat = match translated.len() {
        0 => 0.0,
        n => spurious as f64 / n as f64,
    };
    let f1 = if shared > 0 {
        let doubled = 2.0 * shared as f64;
        doubled / (doubled + spurious as f64 + missed as f64)
    } else {
        0.0
    };

    ContactResult {
        fnat,
        fnonnat,
        f1,
        num_native_contacts: native_total,
        num_model_contacts: model_total,
    }
}
/// Finds the residues of two chains that take part in their interface.
///
/// A residue belongs to the interface when any of its non-hydrogen,
/// non-HETATM atoms lies strictly closer than `threshold` to such an atom on
/// the other chain.
///
/// # Arguments
/// * `structure` - structure whose `atoms` are searched.
/// * `chain_a`, `chain_b` - chain identifiers selecting the two atom groups.
/// * `threshold` - distance cutoff, in the same units as the atom coordinates.
///
/// # Returns
/// The `(chain_id, residue_seq)` keys of all interface residues from both
/// chains. A `threshold` that is zero, negative or NaN yields an empty set.
pub fn find_interface_residues(
    structure: &PdbStructure,
    chain_a: &str,
    chain_b: &str,
    threshold: f64,
) -> HashSet<ResidueKey> {
    // No distance is below a non-positive cutoff. Without this guard a
    // negative threshold would be squared and act like its absolute value.
    if threshold.is_nan() || threshold <= 0.0 {
        return HashSet::new();
    }
    let threshold_sq = threshold * threshold;

    let atoms_a: Vec<&Atom> = structure
        .atoms
        .iter()
        .filter(|a| a.chain_id == chain_a && !a.is_hetatm && !is_hydrogen(a))
        .collect();
    let atoms_b: Vec<&Atom> = structure
        .atoms
        .iter()
        .filter(|a| a.chain_id == chain_b && !a.is_hetatm && !is_hydrogen(a))
        .collect();

    // Track plain sequence numbers per chain inside the hot loop; the chain id
    // of each side is fixed by the filters above, so owned keys only need to
    // be built once per interface residue.
    let mut seqs_a: HashSet<i32> = HashSet::new();
    let mut seqs_b: HashSet<i32> = HashSet::new();
    for atom_a in &atoms_a {
        for atom_b in &atoms_b {
            let dx = atom_a.x - atom_b.x;
            let dy = atom_a.y - atom_b.y;
            let dz = atom_a.z - atom_b.z;
            // Compare squared distances to avoid a square root per atom pair.
            if dx * dx + dy * dy + dz * dz < threshold_sq {
                seqs_a.insert(atom_a.residue_seq);
                seqs_b.insert(atom_b.residue_seq);
            }
        }
    }

    seqs_a
        .into_iter()
        .map(|seq| (chain_a.to_string(), seq))
        .chain(seqs_b.into_iter().map(|seq| (chain_b.to_string(), seq)))
        .collect()
}
/// Reports whether `atom` is a hydrogen.
///
/// The element field decides whenever it holds anything other than
/// whitespace (compared case-insensitively against `H`). Only when the
/// element is blank does the atom name decide: the name counts as hydrogen if,
/// after trimming and skipping any leading digits (names such as `1HB`), it
/// starts with `H`.
fn is_hydrogen(atom: &Atom) -> bool {
    let element = atom.element.trim();
    if !element.is_empty() {
        return element.eq_ignore_ascii_case("H");
    }
    // Blank element column (empty or whitespace-only): fall back to the name.
    atom.name
        .trim()
        .trim_start_matches(|c: char| c.is_ascii_digit())
        .starts_with('H')
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::records::Atom;

    /// Builds a non-HETATM carbon atom of an `ALA` residue at `(x, y, z)`.
    fn make_atom(
        serial: i32,
        name: &str,
        chain_id: &str,
        residue_seq: i32,
        x: f64,
        y: f64,
        z: f64,
    ) -> Atom {
        Atom {
            serial,
            name: String::from(name),
            alt_loc: None,
            residue_name: String::from("ALA"),
            chain_id: String::from(chain_id),
            residue_seq,
            ins_code: None,
            is_hetatm: false,
            x,
            y,
            z,
            occupancy: 1.0,
            temp_factor: 20.0,
            element: String::from("C"),
        }
    }

    /// Creates a fresh structure holding exactly the given atoms.
    fn structure_with(atoms: Vec<Atom>) -> PdbStructure {
        let mut structure = PdbStructure::new();
        structure.atoms = atoms;
        structure
    }

    /// Shorthand for an owned residue key.
    fn residue(chain: &str, seq: i32) -> ResidueKey {
        (chain.to_string(), seq)
    }

    /// Two atoms 3.0 apart on different chains form one contact at cutoff 5.0.
    #[test]
    fn test_find_contacts_close() {
        let structure = structure_with(vec![
            make_atom(1, "CA", "A", 1, 0.0, 0.0, 0.0),
            make_atom(2, "CA", "B", 1, 3.0, 0.0, 0.0),
        ]);
        let found = find_interface_contacts(&structure, "A", "B", 5.0);
        assert_eq!(found.len(), 1);
    }

    /// Two atoms 10.0 apart form no contact at cutoff 5.0.
    #[test]
    fn test_find_contacts_far() {
        let structure = structure_with(vec![
            make_atom(1, "CA", "A", 1, 0.0, 0.0, 0.0),
            make_atom(2, "CA", "B", 1, 10.0, 0.0, 0.0),
        ]);
        let found = find_interface_contacts(&structure, "A", "B", 5.0);
        assert_eq!(found.len(), 0);
    }

    /// Comparing a contact set with itself under an identity mapping gives
    /// fnat = 1, fnonnat = 0 and f1 = 1.
    #[test]
    fn test_compare_identical_contacts() {
        let contacts: HashSet<ContactPair> = vec![
            (residue("A", 1), residue("B", 1)),
            (residue("A", 2), residue("B", 2)),
        ]
        .into_iter()
        .collect();

        let mapping: std::collections::HashMap<ResidueKey, ResidueKey> = vec![
            residue("A", 1),
            residue("A", 2),
            residue("B", 1),
            residue("B", 2),
        ]
        .into_iter()
        .map(|k| (k.clone(), k))
        .collect();

        let result = compare_contacts(&contacts, &contacts, &mapping);
        assert!((result.fnat - 1.0).abs() < 1e-10);
        assert!((result.fnonnat - 0.0).abs() < 1e-10);
        assert!((result.f1 - 1.0).abs() < 1e-10);
    }

    /// Residues within the cutoff of the other chain are reported; the
    /// distant residue B:2 is not.
    #[test]
    fn test_interface_residues() {
        let structure = structure_with(vec![
            make_atom(1, "CA", "A", 1, 0.0, 0.0, 0.0),
            make_atom(2, "CA", "A", 2, 3.8, 0.0, 0.0),
            make_atom(3, "CA", "B", 1, 4.0, 0.0, 0.0),
            make_atom(4, "CA", "B", 2, 20.0, 0.0, 0.0),
        ]);
        let iface = find_interface_residues(&structure, "A", "B", 10.0);

        assert!(iface.contains(&residue("A", 1)));
        assert!(iface.contains(&residue("A", 2)));
        assert!(iface.contains(&residue("B", 1)));
        assert!(!iface.contains(&residue("B", 2)));
    }
}