use std::collections::HashMap;
use nalgebra::{Matrix3, SVD, Vector3};
use crate::core::PdbStructure;
use crate::error::PdbError;
use super::rmsd::rmsd_from_coords;
use super::transform::{
AtomSelection, CoordWithResidue, Point3D, apply_transform, center_coords,
extract_coords_by_selection, extract_coords_with_residue_info,
};
/// Return type of [`superpose_coords`]: on success, the transformed mobile
/// coordinates paired with the [`AlignmentResult`] describing the fit.
type SuperposeResult = Result<(Vec<Point3D>, AlignmentResult), PdbError>;
/// Coordinates (with their residue info) grouped per residue.
///
/// The key is built from the first two fields of each entry's residue info.
/// NOTE(review): presumably chain ID and residue sequence number — confirm
/// against the definition of `CoordWithResidue`.
type ResidueCoordMap = HashMap<(String, i32), Vec<CoordWithResidue>>;
/// Outcome of a rigid-body superposition of a mobile coordinate set onto a
/// target coordinate set.
///
/// The transform maps a mobile point `x` to `rotation · x + translation`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// RMSD between the transformed mobile coordinates and the target
    /// coordinates, computed over the atoms used for the fit.
    pub rmsd: f64,
    /// 3x3 rotation matrix stored row-major: `rotation[i][j]` is row `i`,
    /// column `j`.
    pub rotation: [[f64; 3]; 3],
    /// Translation applied after the rotation. It equals
    /// `rotation · (-mobile_centroid) + target_centroid`.
    pub translation: [f64; 3],
    /// Number of coordinate pairs that were used for the fit.
    pub num_atoms: usize,
}
/// RMSD of a single residue, measured after the whole structure has been
/// superposed onto the target.
#[derive(Debug, Clone, PartialEq)]
pub struct PerResidueRmsd {
    /// Key identifying the residue.
    /// NOTE(review): presumably `(chain ID, residue sequence number)` —
    /// confirm against the residue info produced by the extraction helper.
    pub residue_id: (String, i32),
    /// Residue name, taken from the first selected atom of the residue in the
    /// aligned (mobile) structure.
    pub residue_name: String,
    /// RMSD over the selected atoms of this residue.
    pub rmsd: f64,
    /// Number of selected atoms of this residue that entered the RMSD.
    pub num_atoms: usize,
}
/// Computes the proper rotation (determinant `+1`) that best maps the point
/// set `p` onto `q` in the least-squares sense, using the Kabsch algorithm.
///
/// Both point sets are expected to be centered on their centroids already.
/// Points are paired by index; if the slices differ in length, the extra
/// points of the longer one are ignored.
///
/// # Panics
///
/// Panics if the SVD does not provide `U` or `V^T`, which should not happen
/// because both are explicitly requested.
fn kabsch_rotation(p: &[Vector3<f64>], q: &[Vector3<f64>]) -> Matrix3<f64> {
    // Cross-covariance matrix: sum of the outer products p_i * q_i^T.
    let covariance = p
        .iter()
        .zip(q)
        .fold(Matrix3::<f64>::zeros(), |acc, (a, b)| acc + a * b.transpose());

    let decomposition = SVD::new(covariance, true, true);
    let u_t = decomposition
        .u
        .expect("SVD should compute U matrix")
        .transpose();
    let mut v = decomposition
        .v_t
        .expect("SVD should compute V^T matrix")
        .transpose();

    let unconstrained = v * u_t;
    if unconstrained.determinant() < 0.0 {
        // The optimum is a reflection: negate the last column of V so the
        // result becomes a proper rotation.
        v.column_mut(2).scale_mut(-1.0);
        v * u_t
    } else {
        unconstrained
    }
}
/// Superposes `mobile_coords` onto `target_coords` with a rigid-body
/// least-squares fit.
///
/// Coordinates are paired by index. On success returns the transformed mobile
/// coordinates together with an [`AlignmentResult`] holding the RMSD, the
/// rotation, the translation and the number of atoms used.
///
/// # Errors
///
/// * [`PdbError::NoAtomsSelected`] if either slice is empty.
/// * [`PdbError::AtomCountMismatch`] if the slices differ in length.
/// * [`PdbError::InsufficientAtoms`] if fewer than three pairs are given.
/// * Any error returned by `center_coords` or `rmsd_from_coords`.
pub(crate) fn superpose_coords(
    mobile_coords: &[Point3D],
    target_coords: &[Point3D],
) -> SuperposeResult {
    let n_mobile = mobile_coords.len();
    let n_target = target_coords.len();

    if n_mobile == 0 || n_target == 0 {
        return Err(PdbError::NoAtomsSelected(
            "Cannot superpose empty coordinate sets".to_string(),
        ));
    }
    if n_mobile != n_target {
        return Err(PdbError::AtomCountMismatch {
            expected: n_target,
            found: n_mobile,
        });
    }
    if n_mobile < 3 {
        return Err(PdbError::InsufficientAtoms(
            "Need at least 3 atoms for superposition".to_string(),
        ));
    }

    let (mobile_centered, mobile_centroid) = center_coords(mobile_coords)?;
    let (target_centered, target_centroid) = center_coords(target_coords)?;

    // Convert the centered tuples into nalgebra vectors for the Kabsch step.
    let as_vector = |&(x, y, z): &Point3D| Vector3::new(x, y, z);
    let mobile_vec: Vec<Vector3<f64>> = mobile_centered.iter().map(as_vector).collect();
    let target_vec: Vec<Vector3<f64>> = target_centered.iter().map(as_vector).collect();

    let rotation = kabsch_rotation(&mobile_vec, &target_vec);

    // Rotate the centered mobile points, then move them onto the target centroid.
    let aligned: Vec<Point3D> = mobile_vec
        .iter()
        .map(|m| {
            let turned = rotation * m;
            (
                turned.x + target_centroid.0,
                turned.y + target_centroid.1,
                turned.z + target_centroid.2,
            )
        })
        .collect();

    let rmsd = rmsd_from_coords(&aligned, target_coords)?;

    // Export the rotation row by row.
    let row = |r: usize| [rotation[(r, 0)], rotation[(r, 1)], rotation[(r, 2)]];
    let rotation_array = [row(0), row(1), row(2)];

    // Fold the centering into one translation: t = R * (-c_mobile) + c_target,
    // so that the full transform is x -> R * x + t.
    let shift = rotation
        * Vector3::new(
            -mobile_centroid.0,
            -mobile_centroid.1,
            -mobile_centroid.2,
        );
    let translation = [
        shift.x + target_centroid.0,
        shift.y + target_centroid.1,
        shift.z + target_centroid.2,
    ];

    Ok((
        aligned,
        AlignmentResult {
            rmsd,
            rotation: rotation_array,
            translation,
            num_atoms: n_mobile,
        },
    ))
}
/// Fits `mobile` onto `target` using the atoms picked by `selection` and
/// returns the transformed structure together with the fit statistics.
///
/// The transform is derived from the selected atoms only, then handed to
/// `apply_transform` along with the complete `mobile` structure.
///
/// # Errors
///
/// * [`PdbError::NoAtomsSelected`] if the selection matches no atoms in
///   either structure.
/// * Any error produced by [`superpose_coords`] (for example mismatched atom
///   counts or fewer than three atoms).
pub fn align_structures(
    mobile: &PdbStructure,
    target: &PdbStructure,
    selection: AtomSelection,
) -> Result<(PdbStructure, AlignmentResult), PdbError> {
    let mobile_coords = extract_coords_by_selection(mobile, &selection, None);
    let target_coords = extract_coords_by_selection(target, &selection, None);

    if mobile_coords.is_empty() {
        let message = format!(
            "No atoms matching {:?} selection in mobile structure",
            selection
        );
        return Err(PdbError::NoAtomsSelected(message));
    }
    if target_coords.is_empty() {
        let message = format!(
            "No atoms matching {:?} selection in target structure",
            selection
        );
        return Err(PdbError::NoAtomsSelected(message));
    }

    // Only the transform is needed here; the fitted coordinate list is dropped.
    let fit = superpose_coords(&mobile_coords, &target_coords)?.1;
    let moved = apply_transform(mobile, &fit.rotation, &fit.translation);
    Ok((moved, fit))
}
/// Computes the superposition of `mobile` onto `target` for the atoms picked
/// by `selection` without building a transformed structure.
///
/// # Errors
///
/// * [`PdbError::NoAtomsSelected`] if the selection matches no atoms in
///   either structure.
/// * Any error produced by [`superpose_coords`].
pub fn calculate_alignment(
    mobile: &PdbStructure,
    target: &PdbStructure,
    selection: AtomSelection,
) -> Result<AlignmentResult, PdbError> {
    let mobile_coords = extract_coords_by_selection(mobile, &selection, None);
    let target_coords = extract_coords_by_selection(target, &selection, None);

    // The mobile structure is checked first, so it wins when both are empty.
    let empty_selection = if mobile_coords.is_empty() {
        Some(format!(
            "No atoms matching {:?} selection in mobile structure",
            selection
        ))
    } else if target_coords.is_empty() {
        Some(format!(
            "No atoms matching {:?} selection in target structure",
            selection
        ))
    } else {
        None
    };

    match empty_selection {
        Some(message) => Err(PdbError::NoAtomsSelected(message)),
        None => superpose_coords(&mobile_coords, &target_coords).map(|(_, fit)| fit),
    }
}
pub fn per_residue_rmsd(
mobile: &PdbStructure,
target: &PdbStructure,
selection: AtomSelection,
) -> Result<Vec<PerResidueRmsd>, PdbError> {
let (aligned, _) = align_structures(mobile, target, selection.clone())?;
let aligned_with_res = extract_coords_with_residue_info(&aligned, &selection, None);
let target_with_res = extract_coords_with_residue_info(target, &selection, None);
let mut aligned_by_residue: ResidueCoordMap = HashMap::new();
for item in aligned_with_res {
let key = (item.0.0.clone(), item.0.1);
aligned_by_residue.entry(key).or_default().push(item);
}
let mut target_by_residue: ResidueCoordMap = HashMap::new();
for item in target_with_res {
let key = (item.0.0.clone(), item.0.1);
target_by_residue.entry(key).or_default().push(item);
}
let mut results = Vec::new();
for (residue_key, aligned_atoms) in &aligned_by_residue {
if let Some(target_atoms) = target_by_residue.get(residue_key) {
if aligned_atoms.len() == target_atoms.len() && !aligned_atoms.is_empty() {
let aligned_coords: Vec<_> = aligned_atoms.iter().map(|a| a.1).collect();
let target_coords: Vec<_> = target_atoms.iter().map(|a| a.1).collect();
if let Ok(rmsd) = rmsd_from_coords(&aligned_coords, &target_coords) {
results.push(PerResidueRmsd {
residue_id: residue_key.clone(),
residue_name: aligned_atoms[0].0.2.clone(),
rmsd,
num_atoms: aligned_atoms.len(),
});
}
}
}
}
results.sort_by(|a, b| {
a.residue_id
.0
.cmp(&b.residue_id.0)
.then(a.residue_id.1.cmp(&b.residue_id.1))
});
Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::records::Atom;

    /// Builds a carbon `ALA` atom at the given position. The residue sequence
    /// number doubles as the atom serial.
    fn create_atom(x: f64, y: f64, z: f64, name: &str, residue_seq: i32, chain_id: &str) -> Atom {
        Atom {
            serial: residue_seq,
            name: name.to_string(),
            alt_loc: None,
            residue_name: "ALA".to_string(),
            chain_id: chain_id.to_string(),
            residue_seq,
            x,
            y,
            z,
            occupancy: 1.0,
            temp_factor: 0.0,
            element: "C".to_string(),
            ins_code: None,
            is_hetatm: false,
        }
    }

    /// Four `CA` atoms on the x axis, 3.8 apart, one per residue of chain `A`.
    fn create_linear_structure() -> PdbStructure {
        let mut structure = PdbStructure::new();
        structure.atoms = vec![
            create_atom(0.0, 0.0, 0.0, "CA", 1, "A"),
            create_atom(3.8, 0.0, 0.0, "CA", 2, "A"),
            create_atom(7.6, 0.0, 0.0, "CA", 3, "A"),
            create_atom(11.4, 0.0, 0.0, "CA", 4, "A"),
        ];
        structure
    }

    /// Returns a copy of `structure` with every atom shifted by `(tx, ty, tz)`.
    fn translate_structure(structure: &PdbStructure, tx: f64, ty: f64, tz: f64) -> PdbStructure {
        let mut translated = structure.clone();
        for atom in &mut translated.atoms {
            atom.x += tx;
            atom.y += ty;
            atom.z += tz;
        }
        translated
    }

    #[test]
    fn test_kabsch_identity() {
        let points = vec![
            Vector3::new(0.0, 0.0, 0.0),
            Vector3::new(1.0, 0.0, 0.0),
            Vector3::new(0.0, 1.0, 0.0),
        ];
        let rotation = kabsch_rotation(&points, &points);
        assert!((rotation[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((rotation[(1, 1)] - 1.0).abs() < 1e-10);
        assert!((rotation[(2, 2)] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_kabsch_rotation_is_proper() {
        let p = vec![
            Vector3::new(1.0, 0.0, 0.0),
            Vector3::new(0.0, 1.0, 0.0),
            Vector3::new(0.0, 0.0, 1.0),
        ];
        let q = vec![
            Vector3::new(0.0, 1.0, 0.0),
            Vector3::new(-1.0, 0.0, 0.0),
            Vector3::new(0.0, 0.0, 1.0),
        ];
        let rotation = kabsch_rotation(&p, &q);
        assert!((rotation.determinant() - 1.0).abs() < 1e-10);
    }

    /// A mirror image cannot be matched by a rotation. The unconstrained
    /// optimum is a reflection, so this exercises the determinant correction.
    #[test]
    fn test_kabsch_corrects_reflection() {
        let p = vec![
            Vector3::new(1.0, 0.0, 0.0),
            Vector3::new(0.0, 2.0, 0.0),
            Vector3::new(0.0, 0.0, 3.0),
            Vector3::new(1.0, 1.0, 1.0),
        ];
        let q: Vec<Vector3<f64>> = p.iter().map(|v| Vector3::new(v.x, v.y, -v.z)).collect();

        let rotation = kabsch_rotation(&p, &q);

        // Still a proper rotation: determinant +1 and orthonormal.
        assert!((rotation.determinant() - 1.0).abs() < 1e-10);
        let gram = rotation * rotation.transpose();
        assert!((gram - Matrix3::identity()).norm() < 1e-10);
    }

    #[test]
    fn test_superpose_identical() {
        let coords = vec![
            (0.0, 0.0, 0.0),
            (1.0, 0.0, 0.0),
            (0.0, 1.0, 0.0),
            (0.0, 0.0, 1.0),
        ];
        let (aligned, result) = superpose_coords(&coords, &coords).unwrap();
        assert!(result.rmsd < 1e-10);
        assert_eq!(result.num_atoms, 4);
        for (a, o) in aligned.iter().zip(coords.iter()) {
            assert!((a.0 - o.0).abs() < 1e-10);
            assert!((a.1 - o.1).abs() < 1e-10);
            assert!((a.2 - o.2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_superpose_translated() {
        let coords1 = vec![
            (0.0, 0.0, 0.0),
            (1.0, 0.0, 0.0),
            (0.0, 1.0, 0.0),
            (0.0, 0.0, 1.0),
        ];
        let coords2: Vec<_> = coords1
            .iter()
            .map(|(x, y, z)| (x + 10.0, y + 10.0, z + 10.0))
            .collect();
        let (aligned, result) = superpose_coords(&coords2, &coords1).unwrap();
        assert!(result.rmsd < 1e-10);
        for (a, t) in aligned.iter().zip(coords1.iter()) {
            assert!((a.0 - t.0).abs() < 1e-10);
            assert!((a.1 - t.1).abs() < 1e-10);
            assert!((a.2 - t.2).abs() < 1e-10);
        }
    }

    /// The mobile set is the target rotated by 90 degrees about z and then
    /// shifted, so a correct fit must recover a non-trivial rotation. The
    /// reported rotation and translation must also reproduce the aligned
    /// coordinates when applied as `R * x + t`.
    #[test]
    fn test_superpose_rotated_and_translated() {
        let target: Vec<Point3D> = vec![
            (0.0, 0.0, 0.0),
            (1.0, 0.0, 0.0),
            (0.0, 2.0, 0.0),
            (0.0, 0.0, 3.0),
        ];
        // (x, y, z) -> (-y, x, z), then shift by (5, -3, 2).
        let mobile: Vec<Point3D> = target
            .iter()
            .map(|(x, y, z)| (5.0 - y, x - 3.0, z + 2.0))
            .collect();

        let (aligned, result) = superpose_coords(&mobile, &target).unwrap();
        assert!(result.rmsd < 1e-8);
        assert_eq!(result.num_atoms, 4);

        for (a, t) in aligned.iter().zip(target.iter()) {
            assert!((a.0 - t.0).abs() < 1e-8);
            assert!((a.1 - t.1).abs() < 1e-8);
            assert!((a.2 - t.2).abs() < 1e-8);
        }

        let r = &result.rotation;
        let tr = &result.translation;
        for (m, t) in mobile.iter().zip(target.iter()) {
            let x = r[0][0] * m.0 + r[0][1] * m.1 + r[0][2] * m.2 + tr[0];
            let y = r[1][0] * m.0 + r[1][1] * m.1 + r[1][2] * m.2 + tr[1];
            let z = r[2][0] * m.0 + r[2][1] * m.1 + r[2][2] * m.2 + tr[2];
            assert!((x - t.0).abs() < 1e-8);
            assert!((y - t.1).abs() < 1e-8);
            assert!((z - t.2).abs() < 1e-8);
        }
    }

    #[test]
    fn test_superpose_insufficient_atoms() {
        let coords = vec![(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)];
        let result = superpose_coords(&coords, &coords);
        assert!(matches!(result, Err(PdbError::InsufficientAtoms(_))));
    }

    #[test]
    fn test_superpose_mismatched_length() {
        let coords1 = vec![(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)];
        let coords2 = vec![(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)];
        let result = superpose_coords(&coords1, &coords2);
        assert!(matches!(result, Err(PdbError::AtomCountMismatch { .. })));
    }

    #[test]
    fn test_align_structures_identical() {
        let structure = create_linear_structure();
        let (aligned, result) =
            align_structures(&structure, &structure, AtomSelection::CaOnly).unwrap();
        assert!(result.rmsd < 1e-10);
        assert_eq!(result.num_atoms, 4);
        assert_eq!(aligned.atoms.len(), structure.atoms.len());
    }

    #[test]
    fn test_align_structures_translated() {
        let target = create_linear_structure();
        let mobile = translate_structure(&target, 50.0, 50.0, 50.0);
        let (aligned, result) = align_structures(&mobile, &target, AtomSelection::CaOnly).unwrap();
        assert!(result.rmsd < 1e-6);
        for (aligned_atom, target_atom) in aligned.atoms.iter().zip(target.atoms.iter()) {
            assert!((aligned_atom.x - target_atom.x).abs() < 1e-6);
            assert!((aligned_atom.y - target_atom.y).abs() < 1e-6);
            assert!((aligned_atom.z - target_atom.z).abs() < 1e-6);
        }
    }

    #[test]
    fn test_calculate_alignment() {
        let target = create_linear_structure();
        let mobile = translate_structure(&target, 10.0, 20.0, 30.0);
        let result = calculate_alignment(&mobile, &target, AtomSelection::CaOnly).unwrap();
        assert!(result.rmsd < 1e-6);
        assert_eq!(result.num_atoms, 4);
    }

    #[test]
    fn test_per_residue_rmsd() {
        let target = create_linear_structure();
        let mobile = translate_structure(&target, 5.0, 5.0, 5.0);
        let per_res = per_residue_rmsd(&mobile, &target, AtomSelection::CaOnly).unwrap();
        assert_eq!(per_res.len(), 4);
        for r in &per_res {
            assert!(r.rmsd < 1e-6);
            assert_eq!(r.num_atoms, 1);
        }
    }
}