Skip to main content

chematic_3d/
conformer.rs

1//! Conformer ensemble: a molecule with multiple sets of 3D coordinates.
2
3use std::fmt;
4
5use chematic_core::{AtomIdx, Molecule};
6
7use crate::coords::Coords3D;
8use crate::shape_descriptors::jacobi3;
9
10// ---------------------------------------------------------------------------
11// Error type
12// ---------------------------------------------------------------------------
13
/// Errors produced when building or extending a [`ConformerEnsemble`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConformerError {
    /// A coordinate set does not hold one position per atom of the molecule.
    AtomCountMismatch {
        /// Atom count of the ensemble's molecule.
        expected: usize,
        /// Number of positions in the rejected coordinate set.
        got: usize,
    },
}

impl fmt::Display for ConformerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConformerError::AtomCountMismatch { expected, got } => {
                write!(f, "conformer has {got} atoms but molecule has {expected}")
            }
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for ConformerError {}
28
29// ---------------------------------------------------------------------------
30// ConformerEnsemble
31// ---------------------------------------------------------------------------
32
33/// A molecule paired with zero or more sets of 3D coordinates.
34///
35/// Conformer indices are contiguous; `remove_conformer` shifts all subsequent
36/// indices down by one (Vec::remove semantics).
37pub struct ConformerEnsemble {
38    mol: Molecule,
39    conformers: Vec<Coords3D>,
40}
41
42impl ConformerEnsemble {
43    /// Create an ensemble with no conformers.
44    pub fn new(mol: Molecule) -> Self {
45        Self { mol, conformers: Vec::new() }
46    }
47
48    /// Create an ensemble pre-loaded with one conformer.
49    ///
50    /// Returns an error if `coords.atom_count() != mol.atom_count()`.
51    pub fn with_conformer(mol: Molecule, coords: Coords3D) -> Result<Self, ConformerError> {
52        let expected = mol.atom_count();
53        let got = coords.atom_count();
54        if got != expected {
55            return Err(ConformerError::AtomCountMismatch { expected, got });
56        }
57        Ok(Self { mol, conformers: vec![coords] })
58    }
59
60    /// The molecule (topology only; no coordinates).
61    pub fn mol(&self) -> &Molecule {
62        &self.mol
63    }
64
65    /// Number of conformers currently stored.
66    pub fn conformer_count(&self) -> usize {
67        self.conformers.len()
68    }
69
70    /// Append a conformer.
71    ///
72    /// Returns the index of the newly added conformer, or an error if the
73    /// atom count does not match.
74    pub fn add_conformer(&mut self, coords: Coords3D) -> Result<usize, ConformerError> {
75        let expected = self.mol.atom_count();
76        let got = coords.atom_count();
77        if got != expected {
78            return Err(ConformerError::AtomCountMismatch { expected, got });
79        }
80        let idx = self.conformers.len();
81        self.conformers.push(coords);
82        Ok(idx)
83    }
84
85    /// Return a reference to the conformer at `idx`, or `None` if out of range.
86    pub fn get_conformer(&self, idx: usize) -> Option<&Coords3D> {
87        self.conformers.get(idx)
88    }
89
90    /// Return a mutable reference to the conformer at `idx`, or `None` if out of range.
91    pub fn get_conformer_mut(&mut self, idx: usize) -> Option<&mut Coords3D> {
92        self.conformers.get_mut(idx)
93    }
94
95    /// Remove and return the conformer at `idx`.
96    ///
97    /// All conformers with index > `idx` shift down by one.
98    /// Returns `None` if `idx` is out of range.
99    pub fn remove_conformer(&mut self, idx: usize) -> Option<Coords3D> {
100        if idx < self.conformers.len() {
101            Some(self.conformers.remove(idx))
102        } else {
103            None
104        }
105    }
106
107    /// RMSD between conformers `a` and `b` **without** superposition.
108    ///
109    /// Returns `None` if either index is out of range or the molecule has no atoms.
110    pub fn conformer_rmsd_no_align(&self, a: usize, b: usize) -> Option<f64> {
111        let ca = self.conformers.get(a)?;
112        let cb = self.conformers.get(b)?;
113        let n = self.mol.atom_count();
114        if n == 0 {
115            return Some(0.0);
116        }
117        let sum_sq: f64 = (0..n)
118            .map(|i| {
119                let idx = AtomIdx(i as u32);
120                let pa = ca.get(idx);
121                let pb = cb.get(idx);
122                let dx = pa.x - pb.x;
123                let dy = pa.y - pb.y;
124                let dz = pa.z - pb.z;
125                dx * dx + dy * dy + dz * dz
126            })
127            .sum();
128        Some((sum_sq / n as f64).sqrt())
129    }
130
131    /// Kabsch-aligned RMSD between conformers `a` and `b`.
132    ///
133    /// Finds the rigid-body rotation (no scaling) that minimises RMSD, then
134    /// returns that minimum RMSD.  Returns `None` if either index is out of
135    /// range.
136    pub fn conformer_rmsd(&self, a: usize, b: usize) -> Option<f64> {
137        let ca = self.conformers.get(a)?;
138        let cb = self.conformers.get(b)?;
139        let n = self.mol.atom_count();
140        Some(kabsch_rmsd(ca, cb, n))
141    }
142}
143
144// ---------------------------------------------------------------------------
145// Kabsch RMSD helper
146// ---------------------------------------------------------------------------
147
148fn kabsch_rmsd(coords_a: &Coords3D, coords_b: &Coords3D, n: usize) -> f64 {
149    if n == 0 {
150        return 0.0;
151    }
152
153    let nf = n as f64;
154
155    // Centroids.
156    let mut ca = [0.0f64; 3];
157    let mut cb = [0.0f64; 3];
158    for i in 0..n {
159        let idx = AtomIdx(i as u32);
160        let pa = coords_a.get(idx);
161        let pb = coords_b.get(idx);
162        ca[0] += pa.x; ca[1] += pa.y; ca[2] += pa.z;
163        cb[0] += pb.x; cb[1] += pb.y; cb[2] += pb.z;
164    }
165    for k in 0..3 { ca[k] /= nf; cb[k] /= nf; }
166
167    // Centered coordinates.
168    let mut p = vec![[0.0f64; 3]; n];
169    let mut q = vec![[0.0f64; 3]; n];
170    for i in 0..n {
171        let idx = AtomIdx(i as u32);
172        let pa = coords_a.get(idx);
173        let pb = coords_b.get(idx);
174        p[i] = [pa.x - ca[0], pa.y - ca[1], pa.z - ca[2]];
175        q[i] = [pb.x - cb[0], pb.y - cb[1], pb.z - cb[2]];
176    }
177
178    // H = P^T * Q  (3×3 covariance matrix).
179    let mut h = [[0.0f64; 3]; 3];
180    for i in 0..n {
181        for r in 0..3 {
182            for c in 0..3 {
183                h[r][c] += p[i][r] * q[i][c];
184            }
185        }
186    }
187
188    // H^T * H (symmetric).
189    let mut hth = [[0.0f64; 3]; 3];
190    for r in 0..3 {
191        for c in 0..3 {
192            for k in 0..3 {
193                hth[r][c] += h[k][r] * h[k][c];
194            }
195        }
196    }
197
198    // Eigendecompose H^T * H → V columns are right singular vectors.
199    // evecs[i][j] = component i of eigenvector j (sorted ascending by eigenvalue).
200    let (evals, v) = jacobi3(hth);
201
202    // U = H * V * diag(1/σ).  σ_j = sqrt(evals[j]).
203    let mut hv = [[0.0f64; 3]; 3];
204    for r in 0..3 {
205        for c in 0..3 {
206            for k in 0..3 {
207                hv[r][c] += h[r][k] * v[k][c];
208            }
209        }
210    }
211    let mut u = [[0.0f64; 3]; 3];
212    for j in 0..3 {
213        let sigma = evals[j].max(0.0).sqrt();
214        for r in 0..3 {
215            u[r][j] = if sigma > 1e-10 { hv[r][j] / sigma } else { 0.0 };
216        }
217    }
218
219    // R = V * U^T.  R[r][c] = Σ_k V[r][k] * U[c][k].
220    let mut r_mat = [[0.0f64; 3]; 3];
221    for r in 0..3 {
222        for c in 0..3 {
223            for k in 0..3 {
224                r_mat[r][c] += v[r][k] * u[c][k];
225            }
226        }
227    }
228
229    // Reflection correction: if det(R) < 0, flip V column with smallest σ (col 0).
230    let det = det3(r_mat);
231    let mut v_final = v;
232    if det < 0.0 {
233        for r in 0..3 { v_final[r][0] *= -1.0; }
234        // Recompute R.
235        r_mat = [[0.0f64; 3]; 3];
236        for r in 0..3 {
237            for c in 0..3 {
238                for k in 0..3 {
239                    r_mat[r][c] += v_final[r][k] * u[c][k];
240                }
241            }
242        }
243    }
244
245    // Apply R to q, compute RMSD.
246    let mut sum_sq = 0.0f64;
247    for i in 0..n {
248        for row in 0..3 {
249            let rotated = r_mat[row][0] * q[i][0]
250                + r_mat[row][1] * q[i][1]
251                + r_mat[row][2] * q[i][2];
252            let diff = p[i][row] - rotated;
253            sum_sq += diff * diff;
254        }
255    }
256    (sum_sq / nf).sqrt()
257}
258
/// Determinant of a 3×3 matrix (row-major), evaluated as the scalar triple
/// product `row0 · (row1 × row2)`.
fn det3(m: [[f64; 3]; 3]) -> f64 {
    let [r0, r1, r2] = m;
    let cofactors = [
        r1[1] * r2[2] - r1[2] * r2[1],
        r1[2] * r2[0] - r1[0] * r2[2],
        r1[0] * r2[1] - r1[1] * r2[0],
    ];
    r0[0] * cofactors[0] + r0[1] * cofactors[1] + r0[2] * cofactors[2]
}
264
265// ---------------------------------------------------------------------------
266// Tests
267// ---------------------------------------------------------------------------
268
#[cfg(test)]
mod tests {
    use super::*;
    use chematic_smiles::parse;

    use crate::{coords::Point3, dg::generate_coords};

    fn make_ensemble() -> ConformerEnsemble {
        let mol = parse("CCC").unwrap();
        let c = generate_coords(&mol);
        ConformerEnsemble::with_conformer(mol, c).unwrap()
    }

    // --- Construction and basic access --------------------------------------

    #[test]
    fn new_has_zero_conformers() {
        let mol = parse("C").unwrap();
        let ens = ConformerEnsemble::new(mol);
        assert_eq!(ens.conformer_count(), 0);
    }

    #[test]
    fn with_conformer_has_one() {
        let ens = make_ensemble();
        assert_eq!(ens.conformer_count(), 1);
    }

    #[test]
    fn with_conformer_wrong_atom_count_errors() {
        let mol = parse("CC").unwrap();
        let wrong = Coords3D::new_zeroed(3);
        let result = ConformerEnsemble::with_conformer(mol, wrong);
        assert!(matches!(
            result,
            Err(ConformerError::AtomCountMismatch { expected: 2, got: 3 })
        ));
    }

    #[test]
    fn add_conformer_increments_count() {
        let mol = parse("CC").unwrap();
        let c1 = generate_coords(&mol);
        let c2 = generate_coords(&mol);
        let mut ens = ConformerEnsemble::with_conformer(mol, c1).unwrap();
        let idx = ens.add_conformer(c2).unwrap();
        assert_eq!(idx, 1);
        assert_eq!(ens.conformer_count(), 2);
    }

    #[test]
    fn add_conformer_wrong_atom_count_errors() {
        let mol = parse("CC").unwrap();
        let wrong = Coords3D::new_zeroed(5);
        let mut ens = ConformerEnsemble::new(mol);
        let err = ens.add_conformer(wrong).unwrap_err();
        assert!(matches!(err, ConformerError::AtomCountMismatch { expected: 2, got: 5 }));
        assert_eq!(ens.conformer_count(), 0, "a rejected conformer must not be stored");
    }

    #[test]
    fn conformer_error_display_names_both_counts() {
        let err = ConformerError::AtomCountMismatch { expected: 2, got: 5 };
        assert_eq!(err.to_string(), "conformer has 5 atoms but molecule has 2");
    }

    #[test]
    fn get_conformer_out_of_range_returns_none() {
        let ens = make_ensemble();
        assert!(ens.get_conformer(99).is_none());
    }

    #[test]
    fn get_conformer_mut_edits_in_place() {
        let mut ens = make_ensemble();
        ens.get_conformer_mut(0)
            .unwrap()
            .set(AtomIdx(0), Point3::new(7.0, 8.0, 9.0));
        let p = ens.get_conformer(0).unwrap().get(AtomIdx(0));
        assert!((p.x - 7.0).abs() < 1e-12);
        assert!((p.y - 8.0).abs() < 1e-12);
        assert!((p.z - 9.0).abs() < 1e-12);
        assert!(ens.get_conformer_mut(99).is_none());
    }

    // --- remove_conformer ---------------------------------------------------

    #[test]
    fn remove_conformer_decrements_count() {
        let mut ens = make_ensemble();
        let removed = ens.remove_conformer(0);
        assert!(removed.is_some());
        assert_eq!(ens.conformer_count(), 0);
    }

    #[test]
    fn remove_conformer_shifts_indices() {
        let mol = parse("C").unwrap();
        let n = mol.atom_count();
        let mut ens = ConformerEnsemble::new(mol);

        // Add three conformers with distinct x-coordinates for atom 0.
        for x in [1.0f64, 2.0, 3.0] {
            let mut c = Coords3D::new_zeroed(n);
            c.set(AtomIdx(0), Point3::new(x, 0.0, 0.0));
            ens.add_conformer(c).unwrap();
        }

        // Remove index 0; what was index 1 (x=2) is now index 0.
        ens.remove_conformer(0).unwrap();
        assert_eq!(ens.conformer_count(), 2);
        assert!((ens.get_conformer(0).unwrap().get(AtomIdx(0)).x - 2.0).abs() < 1e-10);
    }

    #[test]
    fn remove_conformer_out_of_range_returns_none() {
        let mut ens = make_ensemble();
        assert!(ens.remove_conformer(99).is_none());
        assert_eq!(ens.conformer_count(), 1);
    }

    // --- RMSD ---------------------------------------------------------------

    #[test]
    fn rmsd_no_align_same_conformer_is_zero() {
        let ens = make_ensemble();
        let rmsd = ens.conformer_rmsd_no_align(0, 0).unwrap();
        assert!(rmsd.abs() < 1e-10, "self-RMSD should be 0, got {rmsd}");
    }

    #[test]
    fn rmsd_no_align_translation_equals_offset() {
        // Every atom moves by exactly 10 along x, so the unaligned RMSD is 10.
        let mol = parse("CC").unwrap();
        let n = mol.atom_count();
        let mut c1 = Coords3D::new_zeroed(n);
        let mut c2 = Coords3D::new_zeroed(n);
        for i in 0..n {
            c1.set(AtomIdx(i as u32), Point3::new(i as f64, 0.0, 0.0));
            c2.set(AtomIdx(i as u32), Point3::new(i as f64 + 10.0, 0.0, 0.0));
        }
        let mut ens = ConformerEnsemble::with_conformer(mol, c1).unwrap();
        ens.add_conformer(c2).unwrap();
        let rmsd = ens.conformer_rmsd_no_align(0, 1).unwrap();
        assert!((rmsd - 10.0).abs() < 1e-12, "expected unaligned RMSD 10, got {rmsd}");
    }

    #[test]
    fn kabsch_rmsd_same_conformer_is_zero() {
        let ens = make_ensemble();
        let rmsd = ens.conformer_rmsd(0, 0).unwrap();
        assert!(rmsd.abs() < 1e-8, "Kabsch self-RMSD should be 0, got {rmsd}");
    }

    #[test]
    fn kabsch_rmsd_pure_translation_is_zero() {
        // After Kabsch superposition, a pure translation should give RMSD = 0.
        let mol = parse("CCC").unwrap();
        let n = mol.atom_count();
        let base = generate_coords(&mol);
        let mut shifted = Coords3D::new_zeroed(n);
        let offset = 5.0;
        for i in 0..n {
            let p = base.get(AtomIdx(i as u32));
            shifted.set(AtomIdx(i as u32), Point3::new(p.x + offset, p.y + offset, p.z + offset));
        }
        let mut ens = ConformerEnsemble::with_conformer(mol, base).unwrap();
        ens.add_conformer(shifted).unwrap();
        let rmsd = ens.conformer_rmsd(0, 1).unwrap();
        assert!(rmsd < 1e-6, "pure-translation Kabsch RMSD should be ~0, got {rmsd}");
    }

    #[test]
    fn kabsch_rmsd_mirror_image_is_finite_and_non_negative() {
        let mol = parse("CCC").unwrap();
        let c1 = generate_coords(&mol);
        let n = mol.atom_count();
        // Mirror every atom through the yz-plane.
        let mut c2 = Coords3D::new_zeroed(n);
        for i in 0..n {
            let p = c1.get(AtomIdx(i as u32));
            c2.set(AtomIdx(i as u32), Point3::new(-p.x, p.y, p.z));
        }
        let mut ens = ConformerEnsemble::with_conformer(mol, c1).unwrap();
        ens.add_conformer(c2).unwrap();
        let rmsd = ens.conformer_rmsd(0, 1).unwrap();
        // Whether a mirror image can be superimposed by a proper rotation
        // depends on the geometry (any planar set can), so only the sanity of
        // the result is checked here.
        assert!(rmsd.is_finite() && rmsd >= 0.0, "RMSD must be finite and non-negative, got {rmsd}");
    }

    #[test]
    fn kabsch_rmsd_out_of_range_returns_none() {
        let ens = make_ensemble();
        assert!(ens.conformer_rmsd(0, 99).is_none());
        assert!(ens.conformer_rmsd(99, 0).is_none());
    }

    #[test]
    fn rmsd_no_align_out_of_range_returns_none() {
        let ens = make_ensemble();
        assert!(ens.conformer_rmsd_no_align(0, 99).is_none());
        assert!(ens.conformer_rmsd_no_align(99, 0).is_none());
    }
}