1use cyanea_core::{Annotated, ContentAddressable, Summarizable};
4use sha2::{Digest, Sha256};
5
6use alloc::format;
7use alloc::string::String;
8use alloc::vec::Vec;
9
/// A point (or displacement vector) in 3-D Cartesian space.
///
/// The same type is used both for positions (e.g. atom coordinates) and for
/// vectors derived from them (differences, cross products, ...).
///
/// `Point3D::default()` is the origin, identical to [`Point3D::zero`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Point3D {
    /// X coordinate.
    pub x: f64,
    /// Y coordinate.
    pub y: f64,
    /// Z coordinate.
    pub z: f64,
}

impl Point3D {
    /// Creates a point from its three coordinates.
    pub fn new(x: f64, y: f64, z: f64) -> Self {
        Self { x, y, z }
    }

    /// Returns the origin `(0, 0, 0)`; same value as `Point3D::default()`.
    pub fn zero() -> Self {
        Self::new(0.0, 0.0, 0.0)
    }

    /// Euclidean distance between `self` and `other`.
    pub fn distance_to(&self, other: &Point3D) -> f64 {
        // Length of the difference vector; evaluates exactly
        // sqrt(dx*dx + dy*dy + dz*dz).
        self.sub(other).norm()
    }

    /// Dot (scalar) product `self · other`.
    pub fn dot(&self, other: &Point3D) -> f64 {
        self.x * other.x + self.y * other.y + self.z * other.z
    }

    /// Cross (vector) product `self × other` (right-handed: x × y = z).
    pub fn cross(&self, other: &Point3D) -> Point3D {
        Point3D::new(
            self.y * other.z - self.z * other.y,
            self.z * other.x - self.x * other.z,
            self.x * other.y - self.y * other.x,
        )
    }

    /// Euclidean length of the vector from the origin to `self`.
    pub fn norm(&self) -> f64 {
        self.dot(self).sqrt()
    }

    /// Unit vector pointing in the same direction as `self`.
    ///
    /// Vectors whose norm is below `1e-15` (including the zero vector) have
    /// no meaningful direction; for those the origin is returned rather than
    /// dividing by (almost) zero.
    pub fn normalize(&self) -> Point3D {
        let n = self.norm();
        if n < 1e-15 {
            return Point3D::zero();
        }
        Point3D::new(self.x / n, self.y / n, self.z / n)
    }

    /// Component-wise sum `self + other`.
    pub fn add(&self, other: &Point3D) -> Point3D {
        Point3D::new(self.x + other.x, self.y + other.y, self.z + other.z)
    }

    /// Component-wise difference `self - other`.
    pub fn sub(&self, other: &Point3D) -> Point3D {
        Point3D::new(self.x - other.x, self.y - other.y, self.z - other.z)
    }

    /// Multiplies every component by the scalar `s`.
    pub fn scale(&self, s: f64) -> Point3D {
        Point3D::new(self.x * s, self.y * s, self.z * s)
    }
}
102
103#[derive(Debug, Clone)]
105#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
106pub struct Atom {
107 pub serial: u32,
109 pub name: String,
111 pub alt_loc: Option<char>,
113 pub coords: Point3D,
115 pub occupancy: f64,
117 pub temp_factor: f64,
119 pub element: Option<String>,
121 pub charge: Option<i8>,
123 pub is_hetatm: bool,
125}
126
127impl Atom {
128 pub fn is_backbone(&self) -> bool {
130 let trimmed = self.name.trim();
131 matches!(trimmed, "N" | "CA" | "C" | "O")
132 }
133
134 pub fn is_alpha_carbon(&self) -> bool {
136 self.name.trim() == "CA"
137 }
138}
139
140#[derive(Debug, Clone)]
142#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
143pub struct Residue {
144 pub name: String,
146 pub seq_num: i32,
148 pub i_code: Option<char>,
150 pub atoms: Vec<Atom>,
152}
153
154impl Residue {
155 pub fn get_atom(&self, name: &str) -> Option<&Atom> {
157 self.atoms.iter().find(|a| a.name.trim() == name)
158 }
159
160 pub fn get_alpha_carbon(&self) -> Option<&Atom> {
162 self.get_atom("CA")
163 }
164
165 pub fn backbone_atoms(&self) -> Vec<&Atom> {
167 self.atoms.iter().filter(|a| a.is_backbone()).collect()
168 }
169
170 pub fn center_of_mass(&self) -> Point3D {
172 if self.atoms.is_empty() {
173 return Point3D::zero();
174 }
175 let mut sum = Point3D::zero();
176 for atom in &self.atoms {
177 sum = sum.add(&atom.coords);
178 }
179 sum.scale(1.0 / self.atoms.len() as f64)
180 }
181}
182
183impl Annotated for Residue {
184 fn name(&self) -> &str {
185 &self.name
186 }
187}
188
189#[derive(Debug, Clone)]
191#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
192pub struct Chain {
193 pub id: char,
195 pub residues: Vec<Residue>,
197 chain_id_str: String,
199}
200
201impl Chain {
202 pub fn new(id: char, residues: Vec<Residue>) -> Self {
204 Self {
205 id,
206 residues,
207 chain_id_str: format!("Chain {}", id),
208 }
209 }
210
211 pub fn residue_count(&self) -> usize {
213 self.residues.len()
214 }
215
216 pub fn atom_count(&self) -> usize {
218 self.residues.iter().map(|r| r.atoms.len()).sum()
219 }
220
221 pub fn alpha_carbons(&self) -> Vec<&Atom> {
223 self.residues
224 .iter()
225 .filter_map(|r| r.get_alpha_carbon())
226 .collect()
227 }
228}
229
230impl Annotated for Chain {
231 fn name(&self) -> &str {
232 &self.chain_id_str
233 }
234}
235
236#[derive(Debug, Clone)]
238#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
239pub struct Structure {
240 pub id: String,
242 pub chains: Vec<Chain>,
244}
245
246impl Structure {
247 pub fn chain_count(&self) -> usize {
249 self.chains.len()
250 }
251
252 pub fn residue_count(&self) -> usize {
254 self.chains.iter().map(|c| c.residue_count()).sum()
255 }
256
257 pub fn atom_count(&self) -> usize {
259 self.chains.iter().map(|c| c.atom_count()).sum()
260 }
261
262 pub fn get_chain(&self, id: char) -> Option<&Chain> {
264 self.chains.iter().find(|c| c.id == id)
265 }
266
267 pub fn all_atoms(&self) -> Vec<&Atom> {
269 self.chains
270 .iter()
271 .flat_map(|c| c.residues.iter().flat_map(|r| r.atoms.iter()))
272 .collect()
273 }
274
275 pub fn alpha_carbons(&self) -> Vec<&Atom> {
277 self.chains.iter().flat_map(|c| c.alpha_carbons()).collect()
278 }
279
280 pub fn center_of_mass(&self) -> Point3D {
282 let atoms = self.all_atoms();
283 if atoms.is_empty() {
284 return Point3D::zero();
285 }
286 let mut sum = Point3D::zero();
287 for atom in &atoms {
288 sum = sum.add(&atom.coords);
289 }
290 sum.scale(1.0 / atoms.len() as f64)
291 }
292}
293
294impl Annotated for Structure {
295 fn name(&self) -> &str {
296 &self.id
297 }
298}
299
300impl Summarizable for Structure {
301 fn summary(&self) -> String {
302 format!(
303 "Structure {} — {} chain(s), {} residue(s), {} atom(s)",
304 self.id,
305 self.chain_count(),
306 self.residue_count(),
307 self.atom_count(),
308 )
309 }
310}
311
312impl ContentAddressable for Structure {
313 fn content_hash(&self) -> String {
314 let mut hasher = Sha256::new();
315 hasher.update(self.id.as_bytes());
316 for chain in &self.chains {
317 hasher.update(&[chain.id as u8]);
318 for residue in &chain.residues {
319 hasher.update(residue.name.as_bytes());
320 hasher.update(&residue.seq_num.to_le_bytes());
321 for atom in &residue.atoms {
322 hasher.update(atom.name.as_bytes());
323 hasher.update(&atom.coords.x.to_le_bytes());
324 hasher.update(&atom.coords.y.to_le_bytes());
325 hasher.update(&atom.coords.z.to_le_bytes());
326 }
327 }
328 }
329 hex::encode(hasher.finalize())
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Builds a plain ATOM record named `name` at `(x, y, z)` with neutral
    /// values (full occupancy, zero B-factor, no element/charge/alt-loc).
    fn make_atom(name: &str, x: f64, y: f64, z: f64) -> Atom {
        Atom {
            serial: 1,
            name: name.into(),
            alt_loc: None,
            coords: Point3D::new(x, y, z),
            occupancy: 1.0,
            temp_factor: 0.0,
            element: None,
            charge: None,
            is_hetatm: false,
        }
    }

    /// Builds a residue without an insertion code.
    fn make_residue(name: &str, seq_num: i32, atoms: Vec<Atom>) -> Residue {
        Residue {
            name: name.into(),
            seq_num,
            i_code: None,
            atoms,
        }
    }

    #[test]
    fn point3d_arithmetic() {
        let a = Point3D::new(1.0, 2.0, 3.0);
        let b = Point3D::new(4.0, 5.0, 6.0);
        assert_eq!(a.add(&b), Point3D::new(5.0, 7.0, 9.0));
        assert_eq!(a.sub(&b), Point3D::new(-3.0, -3.0, -3.0));
        assert!((a.dot(&b) - 32.0).abs() < 1e-10);
        assert!((a.scale(2.0).x - 2.0).abs() < 1e-10);
        assert!((a.distance_to(&b) - (27.0_f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn point3d_cross_product() {
        let x = Point3D::new(1.0, 0.0, 0.0);
        let y = Point3D::new(0.0, 1.0, 0.0);
        let z = x.cross(&y);
        assert!((z.x).abs() < 1e-10);
        assert!((z.y).abs() < 1e-10);
        assert!((z.z - 1.0).abs() < 1e-10);
    }

    #[test]
    fn point3d_norm_and_normalize() {
        let v = Point3D::new(3.0, 4.0, 0.0);
        assert!((v.norm() - 5.0).abs() < 1e-10);
        let u = v.normalize();
        assert!((u.norm() - 1.0).abs() < 1e-10);
        assert!((u.x - 0.6).abs() < 1e-10);
        // A zero vector has no direction; normalize falls back to the origin.
        assert_eq!(Point3D::zero().normalize(), Point3D::zero());
    }

    #[test]
    fn atom_backbone_detection() {
        let ca = make_atom("CA", 0.0, 0.0, 0.0);
        let cb = make_atom("CB", 0.0, 0.0, 0.0);
        let n = make_atom("N", 0.0, 0.0, 0.0);
        assert!(ca.is_backbone());
        assert!(ca.is_alpha_carbon());
        assert!(!cb.is_backbone());
        assert!(!cb.is_alpha_carbon());
        assert!(n.is_backbone());
    }

    #[test]
    fn residue_get_alpha_carbon() {
        let r = make_residue(
            "ALA",
            1,
            vec![
                make_atom("N", 0.0, 0.0, 0.0),
                make_atom("CA", 1.0, 0.0, 0.0),
                make_atom("C", 2.0, 0.0, 0.0),
            ],
        );
        assert!(r.get_alpha_carbon().is_some());
        assert_eq!(r.backbone_atoms().len(), 3);
    }

    #[test]
    fn residue_center_of_mass_is_unweighted_centroid() {
        let r = make_residue(
            "ALA",
            1,
            vec![make_atom("N", 0.0, 0.0, 0.0), make_atom("CA", 2.0, 4.0, 6.0)],
        );
        assert_eq!(r.center_of_mass(), Point3D::new(1.0, 2.0, 3.0));
        let empty = make_residue("GLY", 2, vec![]);
        assert_eq!(empty.center_of_mass(), Point3D::zero());
    }

    #[test]
    fn structure_queries() {
        let a = Chain::new(
            'A',
            vec![
                make_residue(
                    "ALA",
                    1,
                    vec![make_atom("N", 0.0, 0.0, 0.0), make_atom("CA", 1.0, 0.0, 0.0)],
                ),
                make_residue("GLY", 2, vec![make_atom("CA", 3.0, 0.0, 0.0)]),
            ],
        );
        let b = Chain::new(
            'B',
            vec![make_residue("SER", 1, vec![make_atom("OG", 5.0, 0.0, 0.0)])],
        );
        let s = Structure {
            id: "2XYZ".into(),
            chains: vec![a, b],
        };
        assert_eq!(s.chain_count(), 2);
        assert_eq!(s.residue_count(), 3);
        assert_eq!(s.atom_count(), 4);
        assert_eq!(s.all_atoms().len(), 4);
        assert_eq!(s.alpha_carbons().len(), 2);
        assert!(s.get_chain('B').is_some());
        assert!(s.get_chain('C').is_none());
        assert_eq!(s.get_chain('A').map(|c| c.name()), Some("Chain A"));
        // (0 + 1 + 3 + 5) / 4 = 2.25 along x.
        assert_eq!(s.center_of_mass(), Point3D::new(2.25, 0.0, 0.0));
    }

    #[test]
    fn structure_summary_and_hash() {
        let chain = Chain::new(
            'A',
            vec![make_residue("GLY", 1, vec![make_atom("CA", 1.0, 2.0, 3.0)])],
        );
        let s = Structure {
            id: "1ABC".into(),
            chains: vec![chain],
        };
        assert!(s.summary().contains("1ABC"));
        assert!(s.summary().contains("1 chain"));
        assert!(s.summary().contains("1 residue"));
        assert!(s.summary().contains("1 atom"));

        let hash = s.content_hash();
        assert_eq!(hash.len(), 64);
        assert_eq!(hash, s.content_hash());
    }

    #[test]
    fn content_hash_tracks_coordinates_and_chain_ids() {
        let build = |chain_id: char, x: f64| Structure {
            id: "1ABC".into(),
            chains: vec![Chain::new(
                chain_id,
                vec![make_residue("GLY", 1, vec![make_atom("CA", x, 0.0, 0.0)])],
            )],
        };
        let base = build('A', 1.0).content_hash();
        assert_eq!(base, build('A', 1.0).content_hash());
        assert_ne!(base, build('A', 1.5).content_hash());
        assert_ne!(base, build('B', 1.0).content_hash());
    }
}