Skip to main content

cyanea_struct/
superposition.rs

1//! Structural superposition via the Kabsch algorithm.
2//!
3//! Finds the optimal rigid-body rotation (and translation) that minimizes RMSD
4//! between two sets of corresponding points.
5
6use cyanea_core::{CyaneaError, Result, Scored};
7
8use crate::geometry::center_of_mass_points;
9use crate::linalg::{svd_3x3, Matrix3x3};
10use crate::types::{Atom, Point3D};
11
12use alloc::format;
13use alloc::vec::Vec;
14
15/// Result of a Kabsch superposition.
16#[derive(Debug, Clone)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18pub struct SuperpositionResult {
19    /// RMSD after optimal superposition.
20    pub rmsd: f64,
21    /// 3x3 rotation matrix (row-major).
22    pub rotation: [[f64; 3]; 3],
23    /// Translation vector applied after rotation.
24    pub translation: Point3D,
25    /// Transformed coordinates of the mobile set after superposition.
26    pub transformed_coords: Vec<Point3D>,
27}
28
29impl Scored for SuperpositionResult {
30    fn score(&self) -> f64 {
31        -self.rmsd
32    }
33}
34
35/// Kabsch superposition on atom coordinates.
36///
37/// `atoms1` is the reference (fixed) set, `atoms2` is the mobile set that gets
38/// rotated and translated to minimize RMSD.
39pub fn kabsch(atoms1: &[&Atom], atoms2: &[&Atom]) -> Result<SuperpositionResult> {
40    let p1: Vec<Point3D> = atoms1.iter().map(|a| a.coords).collect();
41    let p2: Vec<Point3D> = atoms2.iter().map(|a| a.coords).collect();
42    kabsch_points(&p1, &p2)
43}
44
45/// Kabsch superposition on point coordinates.
46pub fn kabsch_points(
47    points1: &[Point3D],
48    points2: &[Point3D],
49) -> Result<SuperpositionResult> {
50    if points1.len() != points2.len() {
51        return Err(CyaneaError::InvalidInput(format!(
52            "point set sizes differ: {} vs {}",
53            points1.len(),
54            points2.len()
55        )));
56    }
57    if points1.len() < 3 {
58        return Err(CyaneaError::InvalidInput(
59            "need at least 3 points for Kabsch superposition".into(),
60        ));
61    }
62
63    let n = points1.len();
64
65    // Step 1: center both sets
66    let com1 = center_of_mass_points(points1);
67    let com2 = center_of_mass_points(points2);
68
69    let centered1: Vec<Point3D> = points1.iter().map(|p| p.sub(&com1)).collect();
70    let centered2: Vec<Point3D> = points2.iter().map(|p| p.sub(&com2)).collect();
71
72    // Step 2: compute cross-covariance matrix H = P2^T * P1
73    let mut h = Matrix3x3::zeros();
74    for i in 0..n {
75        let p = &centered2[i];
76        let q = &centered1[i];
77        h.data[0][0] += p.x * q.x;
78        h.data[0][1] += p.x * q.y;
79        h.data[0][2] += p.x * q.z;
80        h.data[1][0] += p.y * q.x;
81        h.data[1][1] += p.y * q.y;
82        h.data[1][2] += p.y * q.z;
83        h.data[2][0] += p.z * q.x;
84        h.data[2][1] += p.z * q.y;
85        h.data[2][2] += p.z * q.z;
86    }
87
88    // Step 3: SVD of H
89    let svd = svd_3x3(&h);
90
91    // Step 4: R = V * U^T, with reflection correction
92    let v = svd.vt.transpose();
93    let ut = svd.u.transpose();
94    let mut r = v.multiply(&ut);
95
96    // Fix reflection: if det(R) < 0, negate the column of V corresponding
97    // to the smallest singular value (always column 2 after sorting)
98    if r.determinant() < 0.0 {
99        let mut v_fixed = v;
100        for row in 0..3 {
101            v_fixed.data[row][2] = -v_fixed.data[row][2];
102        }
103        r = v_fixed.multiply(&ut);
104    }
105
106    // Step 5: apply rotation and compute translation + RMSD
107    let mut transformed = Vec::with_capacity(n);
108    let mut sum_sq = 0.0;
109    for i in 0..n {
110        let rotated = r.apply(&centered2[i]);
111        let final_point = rotated.add(&com1);
112        let diff = final_point.sub(&points1[i]);
113        sum_sq += diff.dot(&diff);
114        transformed.push(final_point);
115    }
116
117    let rmsd = (sum_sq / n as f64).sqrt();
118
119    Ok(SuperpositionResult {
120        rmsd,
121        rotation: r.data,
122        translation: com1.sub(&r.apply(&com2)),
123        transformed_coords: transformed,
124    })
125}
126
127/// Align two structures using only alpha-carbon atoms.
128pub fn align_structures_by_ca(
129    atoms1: &[&Atom],
130    atoms2: &[&Atom],
131) -> Result<SuperpositionResult> {
132    let ca1: Vec<&Atom> = atoms1.iter().copied().filter(|a| a.is_alpha_carbon()).collect();
133    let ca2: Vec<&Atom> = atoms2.iter().copied().filter(|a| a.is_alpha_carbon()).collect();
134
135    if ca1.len() != ca2.len() {
136        return Err(CyaneaError::InvalidInput(format!(
137            "different number of CA atoms: {} vs {}",
138            ca1.len(),
139            ca2.len()
140        )));
141    }
142
143    kabsch(&ca1, &ca2)
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use crate::types::Atom;

    /// Largest RMSD accepted for a fit that should be exact.
    const TOLERANCE: f64 = 1e-6;

    /// Builds an atom with the given name and position; all other fields take
    /// fixed placeholder values.
    fn make_atom(name: &str, x: f64, y: f64, z: f64) -> Atom {
        let coords = Point3D::new(x, y, z);
        Atom {
            serial: 1,
            name: name.into(),
            alt_loc: None,
            coords,
            occupancy: 1.0,
            temp_factor: 0.0,
            element: None,
            charge: None,
            is_hetatm: false,
        }
    }

    /// The origin plus one unit step along each axis.
    fn unit_tetrahedron() -> Vec<Point3D> {
        vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(0.0, 0.0, 1.0),
        ]
    }

    /// Nine atoms named N, CA, C (repeated three times) placed at
    /// x = 0, 1, ..., 8 on a line with y = 0 and the given z.
    fn straight_backbone(z: f64) -> Vec<Atom> {
        const NAMES: [&str; 3] = ["N", "CA", "C"];
        (0..9usize)
            .map(|i| make_atom(NAMES[i % 3], i as f64, 0.0, z))
            .collect()
    }

    #[test]
    fn identical_points_rmsd_zero() {
        let points = unit_tetrahedron();
        let rmsd = kabsch_points(&points, &points).unwrap().rmsd;
        assert!(rmsd < TOLERANCE, "RMSD should be ~0, got {}", rmsd);
    }

    #[test]
    fn translated_points() {
        let reference = unit_tetrahedron();
        let shift = Point3D::new(10.0, 20.0, 30.0);
        let mobile: Vec<Point3D> = reference
            .iter()
            .map(|p: &Point3D| p.add(&shift))
            .collect();

        let rmsd = kabsch_points(&reference, &mobile).unwrap().rmsd;
        assert!(rmsd < TOLERANCE, "RMSD should be ~0 for translated set, got {}", rmsd);
    }

    #[test]
    fn rotated_points() {
        // Corners of a square in the XY plane, listed counter-clockwise.
        let square = [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0), (0.0, -1.0)];
        let reference: Vec<Point3D> = square
            .iter()
            .map(|&(x, y)| Point3D::new(x, y, 0.0))
            .collect();
        // Starting one corner later is a 90-degree rotation around the Z axis.
        let mobile: Vec<Point3D> = square
            .iter()
            .cycle()
            .skip(1)
            .take(square.len())
            .map(|&(x, y)| Point3D::new(x, y, 0.0))
            .collect();

        let rmsd = kabsch_points(&reference, &mobile).unwrap().rmsd;
        assert!(rmsd < TOLERANCE, "RMSD should be ~0 for rotated set, got {}", rmsd);
    }

    #[test]
    fn mismatched_lengths_error() {
        let origins = |count: usize| vec![Point3D::new(0.0, 0.0, 0.0); count];
        let three = origins(3);
        let four = origins(4);
        assert!(kabsch_points(&three, &four).is_err());
    }

    #[test]
    fn align_by_ca() {
        // Same straight chain, with the second copy shifted by 5 along Z.
        let fixed_atoms = straight_backbone(0.0);
        let moved_atoms = straight_backbone(5.0);
        let fixed_refs: Vec<&Atom> = fixed_atoms.iter().collect();
        let moved_refs: Vec<&Atom> = moved_atoms.iter().collect();

        let outcome = align_structures_by_ca(&fixed_refs, &moved_refs).unwrap();
        assert!(outcome.rmsd < TOLERANCE);
    }
}
246}