1use cyanea_core::{CyaneaError, Result, Summarizable};
7
8use crate::types::{Chain, Point3D};
9
10use alloc::format;
11use alloc::string::String;
12use alloc::vec;
13use alloc::vec::Vec;
14
/// Pairwise residue distance matrix for a single chain.
///
/// Built by [`compute_contact_map`] (Cα–Cα distances) or
/// [`compute_contact_map_allatom`] (closest heavy-atom distances).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ContactMap {
    /// Identifier of the chain the map was computed from.
    pub chain_id: char,
    /// Number of residues; the matrix is `size × size`.
    pub size: usize,
    /// Flattened row-major matrix: entry `(i, j)` lives at `i * size + j`.
    /// The builders in this module fill it symmetrically, leave the diagonal
    /// at `0.0`, and store `f64::INFINITY` when a distance cannot be computed.
    /// Units presumably follow the atom coordinates (the summary labels its
    /// 8.0 cutoff as Å) — TODO confirm.
    pub distances: Vec<f64>,
}
26
27impl ContactMap {
28 pub fn get(&self, i: usize, j: usize) -> f64 {
30 self.distances[i * self.size + j]
31 }
32
33 pub fn count_contacts(&self, cutoff: f64) -> usize {
35 let mut count = 0;
36 for i in 0..self.size {
37 for j in (i + 1)..self.size {
38 if self.get(i, j) < cutoff {
39 count += 1;
40 }
41 }
42 }
43 count
44 }
45
46 pub fn contacts_below(&self, cutoff: f64) -> Vec<(usize, usize)> {
48 let mut contacts = Vec::new();
49 for i in 0..self.size {
50 for j in (i + 1)..self.size {
51 if self.get(i, j) < cutoff {
52 contacts.push((i, j));
53 }
54 }
55 }
56 contacts
57 }
58
59 pub fn contact_density(&self, cutoff: f64) -> f64 {
61 if self.size < 2 {
62 return 0.0;
63 }
64 let total_pairs = self.size * (self.size - 1) / 2;
65 self.count_contacts(cutoff) as f64 / total_pairs as f64
66 }
67}
68
69impl Summarizable for ContactMap {
70 fn summary(&self) -> String {
71 let contacts_8 = self.count_contacts(8.0);
72 format!(
73 "ContactMap chain {} — {} residues, {} contacts (<8Å)",
74 self.chain_id, self.size, contacts_8,
75 )
76 }
77}
78
79pub fn compute_contact_map(chain: &Chain) -> Result<ContactMap> {
84 let n = chain.residues.len();
85 if n == 0 {
86 return Err(CyaneaError::InvalidInput(
87 "cannot compute contact map for empty chain".into(),
88 ));
89 }
90
91 let ca_positions: Vec<Option<Point3D>> = chain
93 .residues
94 .iter()
95 .map(|r| r.get_alpha_carbon().map(|a| a.coords))
96 .collect();
97
98 let mut distances = vec![0.0f64; n * n];
99
100 for i in 0..n {
101 for j in (i + 1)..n {
102 let dist = match (&ca_positions[i], &ca_positions[j]) {
103 (Some(pi), Some(pj)) => pi.distance_to(pj),
104 _ => f64::INFINITY,
105 };
106 distances[i * n + j] = dist;
107 distances[j * n + i] = dist;
108 }
109 }
110
111 Ok(ContactMap {
112 chain_id: chain.id,
113 size: n,
114 distances,
115 })
116}
117
118pub fn compute_contact_map_allatom(chain: &Chain) -> Result<ContactMap> {
123 let n = chain.residues.len();
124 if n == 0 {
125 return Err(CyaneaError::InvalidInput(
126 "cannot compute contact map for empty chain".into(),
127 ));
128 }
129
130 #[cfg(feature = "parallel")]
131 let distances = {
132 use rayon::prelude::*;
133 let upper: Vec<Vec<(usize, f64)>> = (0..n)
134 .into_par_iter()
135 .map(|i| {
136 ((i + 1)..n)
137 .map(|j| {
138 let mut min_dist = f64::INFINITY;
139 for a1 in &chain.residues[i].atoms {
140 if a1.element.as_deref() == Some("H") {
141 continue;
142 }
143 for a2 in &chain.residues[j].atoms {
144 if a2.element.as_deref() == Some("H") {
145 continue;
146 }
147 let d = a1.coords.distance_to(&a2.coords);
148 if d < min_dist {
149 min_dist = d;
150 }
151 }
152 }
153 (j, min_dist)
154 })
155 .collect()
156 })
157 .collect();
158 let mut distances = vec![0.0f64; n * n];
159 for (i, row) in upper.into_iter().enumerate() {
160 for (j, d) in row {
161 distances[i * n + j] = d;
162 distances[j * n + i] = d;
163 }
164 }
165 distances
166 };
167
168 #[cfg(not(feature = "parallel"))]
169 let distances = {
170 let mut distances = vec![0.0f64; n * n];
171 for i in 0..n {
172 for j in (i + 1)..n {
173 let mut min_dist = f64::INFINITY;
174 for a1 in &chain.residues[i].atoms {
175 if a1.element.as_deref() == Some("H") {
176 continue;
177 }
178 for a2 in &chain.residues[j].atoms {
179 if a2.element.as_deref() == Some("H") {
180 continue;
181 }
182 let d = a1.coords.distance_to(&a2.coords);
183 if d < min_dist {
184 min_dist = d;
185 }
186 }
187 }
188 distances[i * n + j] = min_dist;
189 distances[j * n + i] = min_dist;
190 }
191 }
192 distances
193 };
194
195 Ok(ContactMap {
196 chain_id: chain.id,
197 size: n,
198 distances,
199 })
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Atom, Chain, Point3D, Residue};

    /// Builds a carbon atom named `name` at `(x, y, z)`.
    fn make_atom(name: &str, x: f64, y: f64, z: f64) -> Atom {
        Atom {
            serial: 1,
            name: name.into(),
            alt_loc: None,
            coords: Point3D::new(x, y, z),
            occupancy: 1.0,
            temp_factor: 0.0,
            element: Some("C".into()),
            charge: None,
            is_hetatm: false,
        }
    }

    /// Same as `make_atom`, but with the element set to hydrogen.
    fn make_hydrogen(name: &str, x: f64, y: f64, z: f64) -> Atom {
        let mut atom = make_atom(name, x, y, z);
        atom.element = Some("H".into());
        atom
    }

    /// Four residues whose Cα atoms lie on the x axis, 3.8 apart.
    fn make_test_chain() -> Chain {
        Chain::new(
            'A',
            vec![
                Residue {
                    name: "ALA".into(),
                    seq_num: 1,
                    i_code: None,
                    atoms: vec![make_atom("CA", 0.0, 0.0, 0.0)],
                },
                Residue {
                    name: "GLY".into(),
                    seq_num: 2,
                    i_code: None,
                    atoms: vec![make_atom("CA", 3.8, 0.0, 0.0)],
                },
                Residue {
                    name: "VAL".into(),
                    seq_num: 3,
                    i_code: None,
                    atoms: vec![make_atom("CA", 7.6, 0.0, 0.0)],
                },
                Residue {
                    name: "LEU".into(),
                    seq_num: 4,
                    i_code: None,
                    atoms: vec![make_atom("CA", 11.4, 0.0, 0.0)],
                },
            ],
        )
    }

    #[test]
    fn ca_contact_map() {
        let chain = make_test_chain();
        let cm = compute_contact_map(&chain).unwrap();
        assert_eq!(cm.size, 4);
        // Diagonal is zero.
        assert!((cm.get(0, 0)).abs() < 1e-10);
        // Adjacent residues are 3.8 apart.
        assert!((cm.get(0, 1) - 3.8).abs() < 1e-10);
        // The map is symmetric.
        assert!((cm.get(0, 1) - cm.get(1, 0)).abs() < 1e-10);
    }

    #[test]
    fn get_distance() {
        let chain = make_test_chain();
        let cm = compute_contact_map(&chain).unwrap();
        assert!((cm.get(0, 2) - 7.6).abs() < 1e-10);
    }

    #[test]
    fn count_contacts_cutoff() {
        let chain = make_test_chain();
        let cm = compute_contact_map(&chain).unwrap();
        // Pairs below 8.0: (0,1) (0,2) (1,2) (1,3) (2,3).
        assert_eq!(cm.count_contacts(8.0), 5);
        // Pairs below 4.0: only the adjacent ones.
        assert_eq!(cm.count_contacts(4.0), 3);
    }

    #[test]
    fn contacts_below_lists_upper_triangle_pairs() {
        let cm = compute_contact_map(&make_test_chain()).unwrap();
        assert_eq!(cm.contacts_below(4.0), vec![(0, 1), (1, 2), (2, 3)]);
    }

    #[test]
    fn contact_density() {
        let chain = make_test_chain();
        let cm = compute_contact_map(&chain).unwrap();
        let density = cm.contact_density(8.0);
        // 5 contacts out of 6 possible pairs.
        assert!((density - 5.0 / 6.0).abs() < 1e-10);
    }

    #[test]
    fn density_of_single_residue_is_zero() {
        let chain = Chain::new(
            'A',
            vec![Residue {
                name: "ALA".into(),
                seq_num: 1,
                i_code: None,
                atoms: vec![make_atom("CA", 0.0, 0.0, 0.0)],
            }],
        );
        let cm = compute_contact_map(&chain).unwrap();
        assert_eq!(cm.size, 1);
        assert_eq!(cm.contact_density(8.0), 0.0);
    }

    #[test]
    fn empty_chain_is_rejected() {
        let chain = Chain::new('A', vec![]);
        assert!(compute_contact_map(&chain).is_err());
        assert!(compute_contact_map_allatom(&chain).is_err());
    }

    #[test]
    fn residue_without_atoms_is_infinitely_far() {
        let chain = Chain::new(
            'A',
            vec![
                Residue {
                    name: "ALA".into(),
                    seq_num: 1,
                    i_code: None,
                    atoms: vec![make_atom("CA", 0.0, 0.0, 0.0)],
                },
                Residue {
                    name: "GLY".into(),
                    seq_num: 2,
                    i_code: None,
                    atoms: vec![],
                },
            ],
        );
        let cm = compute_contact_map(&chain).unwrap();
        assert!(cm.get(0, 1).is_infinite());
        assert!(cm.get(1, 0).is_infinite());
        assert_eq!(cm.get(1, 1), 0.0);
        assert_eq!(cm.count_contacts(8.0), 0);

        let cm_all = compute_contact_map_allatom(&chain).unwrap();
        assert!(cm_all.get(0, 1).is_infinite());
    }

    #[test]
    fn summary_mentions_counts() {
        let cm = compute_contact_map(&make_test_chain()).unwrap();
        assert!(cm.summary().contains("4 residues, 5 contacts"));
    }

    #[test]
    fn allatom_contact_map() {
        let chain = Chain::new(
            'A',
            vec![
                Residue {
                    name: "ALA".into(),
                    seq_num: 1,
                    i_code: None,
                    atoms: vec![
                        make_atom("N", 0.0, 0.0, 0.0),
                        make_atom("CA", 1.0, 0.0, 0.0),
                        make_atom("C", 2.0, 0.0, 0.0),
                    ],
                },
                Residue {
                    name: "GLY".into(),
                    seq_num: 2,
                    i_code: None,
                    atoms: vec![
                        make_atom("N", 3.0, 0.0, 0.0),
                        make_atom("CA", 4.0, 0.0, 0.0),
                        make_atom("C", 5.0, 0.0, 0.0),
                    ],
                },
            ],
        );
        let cm = compute_contact_map_allatom(&chain).unwrap();
        // Closest pair is C(2.0) – N(3.0).
        assert!((cm.get(0, 1) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn allatom_ignores_hydrogens() {
        let chain = Chain::new(
            'A',
            vec![
                Residue {
                    name: "ALA".into(),
                    seq_num: 1,
                    i_code: None,
                    atoms: vec![
                        make_atom("C", 0.0, 0.0, 0.0),
                        make_hydrogen("H", 2.5, 0.0, 0.0),
                    ],
                },
                Residue {
                    name: "GLY".into(),
                    seq_num: 2,
                    i_code: None,
                    atoms: vec![make_atom("C", 4.0, 0.0, 0.0)],
                },
            ],
        );
        let cm = compute_contact_map_allatom(&chain).unwrap();
        // The hydrogen (1.5 away) must not count; the heavy atoms are 4.0 apart.
        assert!((cm.get(0, 1) - 4.0).abs() < 1e-10);
        assert!((cm.get(1, 0) - 4.0).abs() < 1e-10);
    }
}