Skip to main content

cyanea_struct/
contact.rs

//! Residue-residue contact maps.
//!
//! Provides CA-CA distance matrices and all-atom (minimum heavy-atom distance)
//! contact maps for protein chains.
5
6use cyanea_core::{CyaneaError, Result, Summarizable};
7
8use crate::types::{Chain, Point3D};
9
10use alloc::format;
11use alloc::string::String;
12use alloc::vec;
13use alloc::vec::Vec;
14
/// A symmetric distance matrix for residue-residue contacts.
///
/// The matrix is stored flattened in row-major order: entry `(i, j)` lives at
/// index `i * size + j` of `distances`. The contact queries on this type only
/// read the upper triangle (`i < j`), so each unordered residue pair is
/// considered exactly once and the diagonal is never counted.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ContactMap {
    /// Chain identifier.
    pub chain_id: char,
    /// Number of residues (matrix is size × size).
    pub size: usize,
    /// Row-major n×n distance matrix.
    pub distances: Vec<f64>,
}

impl ContactMap {
    /// Get the distance between residues `i` and `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not smaller than `size`, or if `distances`
    /// holds fewer than `size * size` entries.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        // Without this check an out-of-range column `j` would silently alias
        // an entry in a later row instead of failing.
        assert!(
            i < self.size && j < self.size,
            "residue index ({}, {}) out of range for a {}x{} contact map",
            i,
            j,
            self.size,
            self.size
        );
        self.distances[i * self.size + j]
    }

    /// Iterate over the upper-triangle pairs `(i, j)`, `i < j`, whose distance
    /// is strictly below `cutoff`, in row-major order.
    fn pairs_below(&self, cutoff: f64) -> impl Iterator<Item = (usize, usize)> + '_ {
        let n = self.size;
        (0..n).flat_map(move |i| {
            ((i + 1)..n)
                .filter(move |&j| self.get(i, j) < cutoff)
                .map(move |j| (i, j))
        })
    }

    /// Count residue pairs with distance below the cutoff (excluding diagonal).
    ///
    /// The comparison is strict (`distance < cutoff`); each unordered pair is
    /// counted once.
    pub fn count_contacts(&self, cutoff: f64) -> usize {
        self.pairs_below(cutoff).count()
    }

    /// Return `(i, j)` pairs of residues in contact below the cutoff.
    ///
    /// Pairs always satisfy `i < j` and are returned in row-major order.
    pub fn contacts_below(&self, cutoff: f64) -> Vec<(usize, usize)> {
        self.pairs_below(cutoff).collect()
    }

    /// Contact density: fraction of possible pairs below the cutoff.
    ///
    /// Returns `0.0` when the map has fewer than two residues, since no pair
    /// exists in that case.
    pub fn contact_density(&self, cutoff: f64) -> f64 {
        if self.size < 2 {
            return 0.0;
        }
        // Number of unordered residue pairs: n choose 2.
        let total_pairs = self.size * (self.size - 1) / 2;
        self.count_contacts(cutoff) as f64 / total_pairs as f64
    }
}
68
69impl Summarizable for ContactMap {
70    fn summary(&self) -> String {
71        let contacts_8 = self.count_contacts(8.0);
72        format!(
73            "ContactMap chain {} — {} residues, {} contacts (<8Å)",
74            self.chain_id, self.size, contacts_8,
75        )
76    }
77}
78
79/// Compute a CA-CA distance contact map for a chain.
80///
81/// Each entry (i, j) is the Euclidean distance between the alpha-carbon atoms
82/// of residues i and j. Residues without a CA atom are assigned `f64::INFINITY`.
83pub fn compute_contact_map(chain: &Chain) -> Result<ContactMap> {
84    let n = chain.residues.len();
85    if n == 0 {
86        return Err(CyaneaError::InvalidInput(
87            "cannot compute contact map for empty chain".into(),
88        ));
89    }
90
91    // Extract CA positions
92    let ca_positions: Vec<Option<Point3D>> = chain
93        .residues
94        .iter()
95        .map(|r| r.get_alpha_carbon().map(|a| a.coords))
96        .collect();
97
98    let mut distances = vec![0.0f64; n * n];
99
100    for i in 0..n {
101        for j in (i + 1)..n {
102            let dist = match (&ca_positions[i], &ca_positions[j]) {
103                (Some(pi), Some(pj)) => pi.distance_to(pj),
104                _ => f64::INFINITY,
105            };
106            distances[i * n + j] = dist;
107            distances[j * n + i] = dist;
108        }
109    }
110
111    Ok(ContactMap {
112        chain_id: chain.id,
113        size: n,
114        distances,
115    })
116}
117
118/// Compute an all-atom contact map for a chain.
119///
120/// Each entry (i, j) is the minimum distance between any pair of non-hydrogen
121/// atoms in residues i and j. This gives a tighter contact definition than CA-CA.
122pub fn compute_contact_map_allatom(chain: &Chain) -> Result<ContactMap> {
123    let n = chain.residues.len();
124    if n == 0 {
125        return Err(CyaneaError::InvalidInput(
126            "cannot compute contact map for empty chain".into(),
127        ));
128    }
129
130    #[cfg(feature = "parallel")]
131    let distances = {
132        use rayon::prelude::*;
133        let upper: Vec<Vec<(usize, f64)>> = (0..n)
134            .into_par_iter()
135            .map(|i| {
136                ((i + 1)..n)
137                    .map(|j| {
138                        let mut min_dist = f64::INFINITY;
139                        for a1 in &chain.residues[i].atoms {
140                            if a1.element.as_deref() == Some("H") {
141                                continue;
142                            }
143                            for a2 in &chain.residues[j].atoms {
144                                if a2.element.as_deref() == Some("H") {
145                                    continue;
146                                }
147                                let d = a1.coords.distance_to(&a2.coords);
148                                if d < min_dist {
149                                    min_dist = d;
150                                }
151                            }
152                        }
153                        (j, min_dist)
154                    })
155                    .collect()
156            })
157            .collect();
158        let mut distances = vec![0.0f64; n * n];
159        for (i, row) in upper.into_iter().enumerate() {
160            for (j, d) in row {
161                distances[i * n + j] = d;
162                distances[j * n + i] = d;
163            }
164        }
165        distances
166    };
167
168    #[cfg(not(feature = "parallel"))]
169    let distances = {
170        let mut distances = vec![0.0f64; n * n];
171        for i in 0..n {
172            for j in (i + 1)..n {
173                let mut min_dist = f64::INFINITY;
174                for a1 in &chain.residues[i].atoms {
175                    if a1.element.as_deref() == Some("H") {
176                        continue;
177                    }
178                    for a2 in &chain.residues[j].atoms {
179                        if a2.element.as_deref() == Some("H") {
180                            continue;
181                        }
182                        let d = a1.coords.distance_to(&a2.coords);
183                        if d < min_dist {
184                            min_dist = d;
185                        }
186                    }
187                }
188                distances[i * n + j] = min_dist;
189                distances[j * n + i] = min_dist;
190            }
191        }
192        distances
193    };
194
195    Ok(ContactMap {
196        chain_id: chain.id,
197        size: n,
198        distances,
199    })
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Atom, Chain, Point3D, Residue};

    /// Assert that two distances agree to within 1e-10.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Build a carbon atom with the given name and coordinates.
    fn make_atom(name: &str, x: f64, y: f64, z: f64) -> Atom {
        let coords = Point3D::new(x, y, z);
        Atom {
            serial: 1,
            name: name.into(),
            alt_loc: None,
            coords,
            occupancy: 1.0,
            temp_factor: 0.0,
            element: Some("C".into()),
            charge: None,
            is_hetatm: false,
        }
    }

    /// Four CA-only residues laid out along the x axis, 3.8 apart.
    fn make_test_chain() -> Chain {
        let layout = [
            ("ALA", 1, 0.0),
            ("GLY", 2, 3.8),
            ("VAL", 3, 7.6),
            ("LEU", 4, 11.4),
        ];
        let residues = layout
            .iter()
            .map(|&(name, seq_num, x)| Residue {
                name: name.into(),
                seq_num,
                i_code: None,
                atoms: vec![make_atom("CA", x, 0.0, 0.0)],
            })
            .collect::<Vec<_>>();
        Chain::new('A', residues)
    }

    /// CA-CA contact map of the shared four-residue fixture.
    fn ca_map() -> ContactMap {
        compute_contact_map(&make_test_chain()).unwrap()
    }

    #[test]
    fn ca_contact_map() {
        let cm = ca_map();
        assert_eq!(cm.size, 4);
        // The diagonal is zero.
        assert_close(cm.get(0, 0), 0.0);
        // Neighbouring residues are 3.8 apart.
        assert_close(cm.get(0, 1), 3.8);
        // The matrix is symmetric.
        assert_close(cm.get(0, 1), cm.get(1, 0));
    }

    #[test]
    fn get_distance() {
        assert_close(ca_map().get(0, 2), 7.6);
    }

    #[test]
    fn count_contacts_cutoff() {
        let cm = ca_map();
        // Below 8.0: (0,1)=3.8, (0,2)=7.6, (1,2)=3.8, (1,3)=7.6, (2,3)=3.8.
        assert_eq!(cm.count_contacts(8.0), 5);
        // Below 4.0: only the three adjacent pairs at 3.8.
        assert_eq!(cm.count_contacts(4.0), 3);
    }

    #[test]
    fn contact_density() {
        // 4 residues give 6 pairs; 5 of them are below 8.0.
        assert_close(ca_map().contact_density(8.0), 5.0 / 6.0);
    }

    #[test]
    fn allatom_contact_map() {
        // N, CA, C placed at x0, x0 + 1, x0 + 2 on the x axis.
        let backbone = |x0: f64| {
            vec![
                make_atom("N", x0, 0.0, 0.0),
                make_atom("CA", x0 + 1.0, 0.0, 0.0),
                make_atom("C", x0 + 2.0, 0.0, 0.0),
            ]
        };
        let first = Residue {
            name: "ALA".into(),
            seq_num: 1,
            i_code: None,
            atoms: backbone(0.0),
        };
        let second = Residue {
            name: "GLY".into(),
            seq_num: 2,
            i_code: None,
            atoms: backbone(3.0),
        };
        let chain = Chain::new('A', vec![first, second]);

        let cm = compute_contact_map_allatom(&chain).unwrap();
        // Closest pair: C of the first residue (x=2) and N of the second (x=3).
        assert_close(cm.get(0, 1), 1.0);
    }
}
323}