1use cyanea_core::{CyaneaError, Result, Scored};
7
8use crate::geometry::center_of_mass_points;
9use crate::linalg::{svd_3x3, Matrix3x3};
10use crate::types::{Atom, Point3D};
11
12use alloc::format;
13use alloc::vec::Vec;
14
/// Outcome of superimposing a mobile point set onto a target point set.
///
/// Produced by [`kabsch_points`]; the second ("mobile") input set is the one
/// that is rotated and translated onto the first ("target") set.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SuperpositionResult {
    /// Root-mean-square deviation between the transformed mobile points and
    /// the target points, in the same units as the input coordinates.
    pub rmsd: f64,
    /// Rotation applied to the centred mobile points, copied from the
    /// internal `Matrix3x3` data and indexed as `[row][col]`.
    pub rotation: [[f64; 3]; 3],
    /// Translation computed as `target_centre - R · mobile_centre`; assuming
    /// `Matrix3x3::apply` is a plain matrix-vector product, `R · p + translation`
    /// places a mobile point `p` in the target frame.
    pub translation: Point3D,
    /// Mobile points after superposition, in the same order as the input.
    pub transformed_coords: Vec<Point3D>,
}
28
29impl Scored for SuperpositionResult {
30 fn score(&self) -> f64 {
31 -self.rmsd
32 }
33}
34
35pub fn kabsch(atoms1: &[&Atom], atoms2: &[&Atom]) -> Result<SuperpositionResult> {
40 let p1: Vec<Point3D> = atoms1.iter().map(|a| a.coords).collect();
41 let p2: Vec<Point3D> = atoms2.iter().map(|a| a.coords).collect();
42 kabsch_points(&p1, &p2)
43}
44
45pub fn kabsch_points(
47 points1: &[Point3D],
48 points2: &[Point3D],
49) -> Result<SuperpositionResult> {
50 if points1.len() != points2.len() {
51 return Err(CyaneaError::InvalidInput(format!(
52 "point set sizes differ: {} vs {}",
53 points1.len(),
54 points2.len()
55 )));
56 }
57 if points1.len() < 3 {
58 return Err(CyaneaError::InvalidInput(
59 "need at least 3 points for Kabsch superposition".into(),
60 ));
61 }
62
63 let n = points1.len();
64
65 let com1 = center_of_mass_points(points1);
67 let com2 = center_of_mass_points(points2);
68
69 let centered1: Vec<Point3D> = points1.iter().map(|p| p.sub(&com1)).collect();
70 let centered2: Vec<Point3D> = points2.iter().map(|p| p.sub(&com2)).collect();
71
72 let mut h = Matrix3x3::zeros();
74 for i in 0..n {
75 let p = ¢ered2[i];
76 let q = ¢ered1[i];
77 h.data[0][0] += p.x * q.x;
78 h.data[0][1] += p.x * q.y;
79 h.data[0][2] += p.x * q.z;
80 h.data[1][0] += p.y * q.x;
81 h.data[1][1] += p.y * q.y;
82 h.data[1][2] += p.y * q.z;
83 h.data[2][0] += p.z * q.x;
84 h.data[2][1] += p.z * q.y;
85 h.data[2][2] += p.z * q.z;
86 }
87
88 let svd = svd_3x3(&h);
90
91 let v = svd.vt.transpose();
93 let ut = svd.u.transpose();
94 let mut r = v.multiply(&ut);
95
96 if r.determinant() < 0.0 {
99 let mut v_fixed = v;
100 for row in 0..3 {
101 v_fixed.data[row][2] = -v_fixed.data[row][2];
102 }
103 r = v_fixed.multiply(&ut);
104 }
105
106 let mut transformed = Vec::with_capacity(n);
108 let mut sum_sq = 0.0;
109 for i in 0..n {
110 let rotated = r.apply(¢ered2[i]);
111 let final_point = rotated.add(&com1);
112 let diff = final_point.sub(&points1[i]);
113 sum_sq += diff.dot(&diff);
114 transformed.push(final_point);
115 }
116
117 let rmsd = (sum_sq / n as f64).sqrt();
118
119 Ok(SuperpositionResult {
120 rmsd,
121 rotation: r.data,
122 translation: com1.sub(&r.apply(&com2)),
123 transformed_coords: transformed,
124 })
125}
126
127pub fn align_structures_by_ca(
129 atoms1: &[&Atom],
130 atoms2: &[&Atom],
131) -> Result<SuperpositionResult> {
132 let ca1: Vec<&Atom> = atoms1.iter().copied().filter(|a| a.is_alpha_carbon()).collect();
133 let ca2: Vec<&Atom> = atoms2.iter().copied().filter(|a| a.is_alpha_carbon()).collect();
134
135 if ca1.len() != ca2.len() {
136 return Err(CyaneaError::InvalidInput(format!(
137 "different number of CA atoms: {} vs {}",
138 ca1.len(),
139 ca2.len()
140 )));
141 }
142
143 kabsch(&ca1, &ca2)
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use crate::types::Atom;

    fn make_atom(name: &str, x: f64, y: f64, z: f64) -> Atom {
        Atom {
            serial: 1,
            name: name.into(),
            alt_loc: None,
            coords: Point3D::new(x, y, z),
            occupancy: 1.0,
            temp_factor: 0.0,
            element: None,
            charge: None,
            is_hetatm: false,
        }
    }

    /// Squared Euclidean distance between two points.
    fn dist_sq(a: &Point3D, b: &Point3D) -> f64 {
        let d = a.sub(b);
        d.dot(&d)
    }

    /// Determinant of a 3x3 matrix stored as rows.
    fn det3(m: &[[f64; 3]; 3]) -> f64 {
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    }

    #[test]
    fn identical_points_rmsd_zero() {
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(0.0, 0.0, 1.0),
        ];
        let result = kabsch_points(&points, &points).unwrap();
        assert!(result.rmsd < 1e-6, "RMSD should be ~0, got {}", result.rmsd);
    }

    #[test]
    fn translated_points() {
        let p1 = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(0.0, 0.0, 1.0),
        ];
        let p2: Vec<Point3D> = p1
            .iter()
            .map(|p: &Point3D| p.add(&Point3D::new(10.0, 20.0, 30.0)))
            .collect();
        let result = kabsch_points(&p1, &p2).unwrap();
        assert!(result.rmsd < 1e-6, "RMSD should be ~0 for translated set, got {}", result.rmsd);
    }

    #[test]
    fn translated_points_recover_translation_and_coords() {
        // Non-coplanar set: the optimal rotation is uniquely the identity, so
        // the translation must undo the shift exactly.
        let p1 = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(0.0, 0.0, 1.0),
        ];
        let shift = Point3D::new(10.0, 20.0, 30.0);
        let p2: Vec<Point3D> = p1.iter().map(|p: &Point3D| p.add(&shift)).collect();
        let result = kabsch_points(&p1, &p2).unwrap();

        let expected = Point3D::new(-10.0, -20.0, -30.0);
        assert!(
            dist_sq(&result.translation, &expected) < 1e-8,
            "translation should undo the shift, got {:?}",
            result.translation
        );
        assert_eq!(result.transformed_coords.len(), p1.len());
        for (got, want) in result.transformed_coords.iter().zip(p1.iter()) {
            assert!(dist_sq(got, want) < 1e-8, "transformed point {:?} != {:?}", got, want);
        }
    }

    #[test]
    fn rotated_points() {
        let p1 = vec![
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(-1.0, 0.0, 0.0),
            Point3D::new(0.0, -1.0, 0.0),
        ];
        let p2 = vec![
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(-1.0, 0.0, 0.0),
            Point3D::new(0.0, -1.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
        ];
        let result = kabsch_points(&p1, &p2).unwrap();
        assert!(result.rmsd < 1e-6, "RMSD should be ~0 for rotated set, got {}", result.rmsd);
    }

    #[test]
    fn mirrored_points_yield_proper_rotation() {
        // A chiral tetrahedron and its mirror image through z = 0: the
        // unconstrained optimum is a reflection, so the determinant
        // correction must kick in and leave a non-zero residual.
        let p1 = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 2.0, 0.0),
            Point3D::new(0.0, 0.0, 3.0),
        ];
        let p2: Vec<Point3D> = p1.iter().map(|p| Point3D::new(p.x, p.y, -p.z)).collect();
        let result = kabsch_points(&p1, &p2).unwrap();

        let det = det3(&result.rotation);
        assert!(det > 0.999 && det < 1.001, "rotation must be proper, det = {}", det);
        assert!(result.rmsd > 1e-3, "mirror image cannot be superimposed, got RMSD {}", result.rmsd);
    }

    #[test]
    fn score_is_negated_rmsd() {
        let p1 = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 2.0, 0.0),
            Point3D::new(0.0, 0.0, 3.0),
        ];
        let p2: Vec<Point3D> = p1.iter().map(|p| Point3D::new(p.x, p.y, -p.z)).collect();
        let result = kabsch_points(&p1, &p2).unwrap();
        assert_eq!(result.score(), -result.rmsd);
    }

    #[test]
    fn mismatched_lengths_error() {
        let p1 = vec![Point3D::new(0.0, 0.0, 0.0); 3];
        let p2 = vec![Point3D::new(0.0, 0.0, 0.0); 4];
        assert!(kabsch_points(&p1, &p2).is_err());
    }

    #[test]
    fn too_few_points_error() {
        let p = vec![Point3D::new(0.0, 0.0, 0.0), Point3D::new(1.0, 0.0, 0.0)];
        assert!(kabsch_points(&p, &p).is_err());
    }

    #[test]
    fn align_by_ca() {
        let atoms1 = vec![
            make_atom("N", 0.0, 0.0, 0.0),
            make_atom("CA", 1.0, 0.0, 0.0),
            make_atom("C", 2.0, 0.0, 0.0),
            make_atom("N", 3.0, 0.0, 0.0),
            make_atom("CA", 4.0, 0.0, 0.0),
            make_atom("C", 5.0, 0.0, 0.0),
            make_atom("N", 6.0, 0.0, 0.0),
            make_atom("CA", 7.0, 0.0, 0.0),
            make_atom("C", 8.0, 0.0, 0.0),
        ];
        let atoms2 = vec![
            make_atom("N", 0.0, 0.0, 5.0),
            make_atom("CA", 1.0, 0.0, 5.0),
            make_atom("C", 2.0, 0.0, 5.0),
            make_atom("N", 3.0, 0.0, 5.0),
            make_atom("CA", 4.0, 0.0, 5.0),
            make_atom("C", 5.0, 0.0, 5.0),
            make_atom("N", 6.0, 0.0, 5.0),
            make_atom("CA", 7.0, 0.0, 5.0),
            make_atom("C", 8.0, 0.0, 5.0),
        ];
        let refs1: Vec<&Atom> = atoms1.iter().collect();
        let refs2: Vec<&Atom> = atoms2.iter().collect();
        let result = align_structures_by_ca(&refs1, &refs2).unwrap();
        assert!(result.rmsd < 1e-6);
    }

    #[test]
    fn align_by_ca_count_mismatch_error() {
        let atoms1 = vec![
            make_atom("CA", 0.0, 0.0, 0.0),
            make_atom("CA", 1.0, 0.0, 0.0),
            make_atom("CA", 2.0, 1.0, 0.0),
        ];
        let atoms2 = vec![
            make_atom("CA", 0.0, 0.0, 0.0),
            make_atom("CA", 1.0, 0.0, 0.0),
        ];
        let refs1: Vec<&Atom> = atoms1.iter().collect();
        let refs2: Vec<&Atom> = atoms2.iter().collect();
        assert!(align_structures_by_ca(&refs1, &refs2).is_err());
    }
}