1use crate::model::{
2 residue::ResidueType,
3 types::{TypeIdx, Vec3},
4};
5use arrayvec::ArrayVec;
6use std::collections::{HashMap, HashSet};
7
/// Upper bound on the number of sidechain atoms a [`Residue`] can hold.
///
/// It sizes the inline `ArrayVec` buffers inside `Residue`, and `Residue::new`
/// panics when handed more atoms than this.
pub const MAX_SIDECHAIN_ATOMS: usize = 18;
9
/// Top-level container pairing the mobile residues with the fixed atom pool and
/// the force-field parameters.
#[derive(Debug, Clone)]
pub struct System {
    /// Residues whose sidechain coordinates can be replaced.
    pub mobile: Vec<Residue>,
    /// Atoms held fixed. NOTE(review): presumably every atom that is not part of a
    /// mobile sidechain — confirm against the code that builds a `System`.
    pub fixed: FixedAtomPool,
    /// Van der Waals and hydrogen-bond parameter tables.
    pub ff: ForceFieldParams,
}
20
/// A packable residue: its backbone frame, dihedrals and per-atom sidechain data.
///
/// The per-atom buffers are fixed-capacity (`MAX_SIDECHAIN_ATOMS`) inline arrays,
/// so a `Residue` owns no heap memory for them.
#[derive(Debug, Clone)]
pub struct Residue {
    /// Residue type; `Residue::new` rejects types that are not packable.
    res_type: ResidueType,
    /// Three backbone anchor positions.
    /// NOTE(review): presumably N, CA, C in that order — confirm.
    anchor: [Vec3; 3],
    /// Backbone dihedral phi. NOTE(review): presumably radians (the tests pass `PI`
    /// for omega) — confirm.
    phi: f32,
    /// Backbone dihedral psi (same units as `phi`).
    psi: f32,
    /// Backbone dihedral omega (same units as `phi`).
    omega: f32,
    /// Sidechain atom coordinates.
    sidechain: ArrayVec<Vec3, MAX_SIDECHAIN_ATOMS>,
    /// Force-field type of each sidechain atom; same length as `sidechain` when
    /// built by `Residue::new`.
    atom_types: ArrayVec<TypeIdx, MAX_SIDECHAIN_ATOMS>,
    /// Charge of each sidechain atom; same length as `sidechain` when built by
    /// `Residue::new`.
    atom_charges: ArrayVec<f32, MAX_SIDECHAIN_ATOMS>,
    /// Per-atom donor reference for hydrogens.
    /// NOTE(review): presumably the index of the heavy atom a hydrogen is bonded to,
    /// with `u8::MAX` for non-hydrogens (the tests fill it with `u8::MAX`) — confirm.
    donor_of_h: ArrayVec<u8, MAX_SIDECHAIN_ATOMS>,
}
43
44pub struct SidechainAtoms<'a> {
46 pub coords: &'a [Vec3],
48 pub types: &'a [TypeIdx],
50 pub charges: &'a [f32],
52 pub donor_of_h: &'a [u8],
54}
55
56impl Residue {
57 pub fn new(
67 res_type: ResidueType,
68 anchor: [Vec3; 3],
69 phi: f32,
70 psi: f32,
71 omega: f32,
72 atoms: SidechainAtoms<'_>,
73 ) -> Option<Self> {
74 if !res_type.is_packable() {
75 return None;
76 }
77 let n = atoms.coords.len();
78 assert!(n <= MAX_SIDECHAIN_ATOMS, "too many sidechain atoms: {n}");
79 assert_eq!(atoms.types.len(), n, "types/coords length mismatch");
80 assert_eq!(atoms.charges.len(), n, "charges/coords length mismatch");
81 assert_eq!(
82 atoms.donor_of_h.len(),
83 n,
84 "donor_of_h/coords length mismatch"
85 );
86 Some(Self {
87 res_type,
88 anchor,
89 phi,
90 psi,
91 omega,
92 sidechain: atoms.coords.iter().copied().collect(),
93 atom_types: atoms.types.iter().copied().collect(),
94 atom_charges: atoms.charges.iter().copied().collect(),
95 donor_of_h: atoms.donor_of_h.iter().copied().collect(),
96 })
97 }
98
99 #[inline]
100 pub fn res_type(&self) -> ResidueType {
101 self.res_type
102 }
103 #[inline]
104 pub fn anchor(&self) -> &[Vec3; 3] {
105 &self.anchor
106 }
107 #[inline]
108 pub fn phi(&self) -> f32 {
109 self.phi
110 }
111 #[inline]
112 pub fn psi(&self) -> f32 {
113 self.psi
114 }
115 #[inline]
116 pub fn omega(&self) -> f32 {
117 self.omega
118 }
119 #[inline]
120 pub fn sidechain(&self) -> &[Vec3] {
121 &self.sidechain
122 }
123 #[inline]
124 pub fn atom_types(&self) -> &[TypeIdx] {
125 &self.atom_types
126 }
127 #[inline]
128 pub fn atom_charges(&self) -> &[f32] {
129 &self.atom_charges
130 }
131 #[inline]
132 pub fn donor_of_h(&self) -> &[u8] {
133 &self.donor_of_h
134 }
135
136 #[inline]
138 pub(crate) fn set_sidechain(&mut self, coords: &[Vec3]) {
139 debug_assert!(
140 coords.len() <= MAX_SIDECHAIN_ATOMS,
141 "coords.len()={} > MAX_SIDECHAIN_ATOMS={}",
142 coords.len(),
143 MAX_SIDECHAIN_ATOMS
144 );
145 self.sidechain.clear();
146 unsafe {
148 self.sidechain
149 .try_extend_from_slice(coords)
150 .unwrap_unchecked()
151 };
152 }
153}
154
/// Structure-of-arrays storage for the fixed atoms.
///
/// The four vectors are parallel: index `k` in each describes the same atom.
/// NOTE(review): nothing here enforces equal lengths — confirm that the builders
/// guarantee it.
#[derive(Debug, Clone)]
pub struct FixedAtomPool {
    /// Atom positions.
    pub positions: Vec<Vec3>,
    /// Force-field type of each atom.
    pub types: Vec<TypeIdx>,
    /// Charge of each atom.
    pub charges: Vec<f32>,
    /// Per-atom donor reference for hydrogens (pool counterpart of
    /// `Residue::donor_of_h`, but `u32`-wide and named `donor_for_h`).
    /// NOTE(review): the tests use `u32::MAX` as the "none" value — confirm.
    pub donor_for_h: Vec<u32>,
}
167
/// All force-field parameter tables used for scoring.
#[derive(Debug, Clone)]
pub struct ForceFieldParams {
    /// Pairwise van der Waals parameters.
    pub vdw: VdwMatrix,
    /// Hydrogen-bond type sets and parameters.
    pub hbond: HBondParams,
}
174
/// Pairwise van der Waals parameter table in one of two functional forms.
#[derive(Debug, Clone)]
pub enum VdwMatrix {
    /// Lennard-Jones parameters per type pair.
    LennardJones(LjMatrix),
    /// Buckingham parameters per type pair.
    Buckingham(BuckMatrix),
}
183
/// Lennard-Jones parameters for one pair of atom types.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LjPair {
    /// NOTE(review): presumably the well depth — confirm against the energy code.
    pub d0: f32,
    /// NOTE(review): presumably the squared equilibrium distance (per the `_sq`
    /// suffix) — confirm.
    pub r0_sq: f32,
}
192
/// Buckingham parameters for one pair of atom types.
///
/// NOTE(review): presumably `a`, `b`, `c` are the coefficients of
/// `A * exp(-B * r) - C / r^6` — confirm against the energy code.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BuckPair {
    /// Coefficient `A` (see the note on the struct).
    pub a: f32,
    /// Coefficient `B` (see the note on the struct).
    pub b: f32,
    /// Coefficient `C` (see the note on the struct).
    pub c: f32,
    /// A squared distance (per the `_sq` suffix).
    /// NOTE(review): its role is not visible here — confirm.
    pub r_max_sq: f32,
    /// An energy-scale value.
    /// NOTE(review): its role is not visible here (name suggests twice some maximum
    /// energy) — confirm.
    pub two_e_max: f32,
}
207
/// Dense symmetric table of [`LjPair`]s indexed by a pair of atom types.
#[derive(Debug, Clone)]
pub struct LjMatrix {
    /// Side length of the table, i.e. the number of atom types.
    n: usize,
    /// Row-major `n * n` entries; entry `(i, j)` lives at `i * n + j`.
    data: Box<[LjPair]>,
}
216
217impl LjMatrix {
218 pub fn new(n: usize, data: Vec<LjPair>) -> Self {
224 assert_eq!(data.len(), n * n, "data.len() must equal n*n");
225 assert!(
226 (0..n).all(|i| (0..i).all(|j| data[i * n + j] == data[j * n + i])),
227 "matrix must be symmetric"
228 );
229 Self {
230 n,
231 data: data.into_boxed_slice(),
232 }
233 }
234
235 #[inline(always)]
237 pub fn get(&self, i: TypeIdx, j: TypeIdx) -> LjPair {
238 self.data[usize::from(i) * self.n + usize::from(j)]
239 }
240}
241
/// Dense symmetric table of [`BuckPair`]s indexed by a pair of atom types.
#[derive(Debug, Clone)]
pub struct BuckMatrix {
    /// Side length of the table, i.e. the number of atom types.
    n: usize,
    /// Row-major `n * n` entries; entry `(i, j)` lives at `i * n + j`.
    data: Box<[BuckPair]>,
}
250
251impl BuckMatrix {
252 pub fn new(n: usize, data: Vec<BuckPair>) -> Self {
258 assert_eq!(data.len(), n * n, "data.len() must equal n*n");
259 assert!(
260 (0..n).all(|i| (0..i).all(|j| data[i * n + j] == data[j * n + i])),
261 "matrix must be symmetric"
262 );
263 Self {
264 n,
265 data: data.into_boxed_slice(),
266 }
267 }
268
269 #[inline(always)]
271 pub fn get(&self, i: TypeIdx, j: TypeIdx) -> BuckPair {
272 self.data[usize::from(i) * self.n + usize::from(j)]
273 }
274}
275
/// Hydrogen-bond type sets plus per-triple parameters.
#[derive(Debug, Clone)]
pub struct HBondParams {
    /// Atom types that can act as the hydrogen of a hydrogen bond.
    h_types: HashSet<TypeIdx>,
    /// Atom types that can act as the acceptor.
    acc_types: HashSet<TypeIdx>,
    /// Parameter pair keyed by `(donor_type, h_type, acc_type)`.
    /// NOTE(review): the meaning of the two `f32`s is not visible here (the tests use
    /// `(5.0, 25.0)`, which looks like a distance and its square) — confirm.
    params: HashMap<(TypeIdx, TypeIdx, TypeIdx), (f32, f32)>,
}
286
287impl HBondParams {
288 pub fn new(
290 h_types: HashSet<TypeIdx>,
291 acc_types: HashSet<TypeIdx>,
292 params: HashMap<(TypeIdx, TypeIdx, TypeIdx), (f32, f32)>,
293 ) -> Self {
294 Self {
295 h_types,
296 acc_types,
297 params,
298 }
299 }
300
301 #[inline]
303 pub fn is_hbond_candidate(&self, ta: TypeIdx, tb: TypeIdx) -> bool {
304 (self.h_types.contains(&ta) && self.acc_types.contains(&tb))
305 || (self.h_types.contains(&tb) && self.acc_types.contains(&ta))
306 }
307
308 #[inline]
310 pub fn get(
311 &self,
312 donor_type: TypeIdx,
313 h_type: TypeIdx,
314 acc_type: TypeIdx,
315 ) -> Option<(f32, f32)> {
316 self.params.get(&(donor_type, h_type, acc_type)).copied()
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::residue::ResidueType;
    use crate::model::types::{TypeIdx, Vec3};
    use std::f32::consts::PI;

    /// All-zero Lennard-Jones entry used for off-diagonal cells.
    const LJ_ZERO: LjPair = LjPair {
        d0: 0.0,
        r0_sq: 0.0,
    };

    /// Lennard-Jones entry that `lj_identity` places on the diagonal.
    const LJ_DIAG: LjPair = LjPair {
        d0: 1.0,
        r0_sq: 4.0,
    };

    /// All-zero Buckingham entry used for off-diagonal cells.
    const BUCK_ZERO: BuckPair = BuckPair {
        a: 0.0,
        b: 0.0,
        c: 0.0,
        r_max_sq: 0.0,
        two_e_max: 0.0,
    };

    /// Buckingham entry that `buck_identity` places on the diagonal.
    const BUCK_DIAG: BuckPair = BuckPair {
        a: 1.0,
        b: 0.5,
        c: 2.0,
        r_max_sq: 4.0,
        two_e_max: 0.1,
    };

    fn v(x: f32, y: f32, z: f32) -> Vec3 {
        Vec3::new(x, y, z)
    }

    fn t(n: u8) -> TypeIdx {
        TypeIdx(n)
    }

    /// Row-major offset of `(row, col)` in an `n x n` table.
    ///
    /// Written as a function instead of inline literals such as `0 * n + 1`, which
    /// clippy rejects (`erasing_op` is deny-by-default, `identity_op` warns).
    fn idx(n: usize, row: usize, col: usize) -> usize {
        row * n + col
    }

    /// Sidechain input with no atoms at all.
    fn no_atoms() -> SidechainAtoms<'static> {
        SidechainAtoms {
            coords: &[],
            types: &[],
            charges: &[],
            donor_of_h: &[],
        }
    }

    /// Calls `Residue::new` for a Ser residue whose four per-atom slices have the
    /// given lengths, so the validation tests differ only in those numbers.
    fn new_ser_with_lengths(
        n_coords: usize,
        n_types: usize,
        n_charges: usize,
        n_donors: usize,
    ) -> Option<Residue> {
        let coords = vec![v(1.0, 1.0, 0.0); n_coords];
        let types = vec![t(1); n_types];
        let charges = vec![0.1f32; n_charges];
        let donor_of_h = vec![u8::MAX; n_donors];
        Residue::new(
            ResidueType::Ser,
            [v(0.0, 0.0, 0.0); 3],
            0.0,
            0.0,
            PI,
            SidechainAtoms {
                coords: &coords,
                types: &types,
                charges: &charges,
                donor_of_h: &donor_of_h,
            },
        )
    }

    /// A valid five-atom Ser residue with distinctive backbone values.
    fn ser_residue() -> Residue {
        let anchor = [v(0.0, 0.0, 0.0), v(1.5, 0.0, 0.0), v(3.0, 0.0, 0.0)];
        let coords = [v(1.0, 1.0, 0.0); 5];
        let types = [t(1); 5];
        let charges = [0.1f32; 5];
        let donor_of_h = [u8::MAX; 5];
        Residue::new(
            ResidueType::Ser,
            anchor,
            -1.0,
            1.0,
            PI,
            SidechainAtoms {
                coords: &coords,
                types: &types,
                charges: &charges,
                donor_of_h: &donor_of_h,
            },
        )
        .unwrap()
    }

    /// `n x n` Lennard-Jones table with `LJ_DIAG` on the diagonal and zeros elsewhere.
    fn lj_identity(n: usize) -> LjMatrix {
        let mut data = vec![LJ_ZERO; n * n];
        for i in 0..n {
            data[idx(n, i, i)] = LJ_DIAG;
        }
        LjMatrix::new(n, data)
    }

    /// `n x n` Buckingham table with `BUCK_DIAG` on the diagonal and zeros elsewhere.
    fn buck_identity(n: usize) -> BuckMatrix {
        let mut data = vec![BUCK_ZERO; n * n];
        for i in 0..n {
            data[idx(n, i, i)] = BUCK_DIAG;
        }
        BuckMatrix::new(n, data)
    }

    fn empty_hbond() -> HBondParams {
        HBondParams::new(HashSet::new(), HashSet::new(), HashMap::new())
    }

    #[test]
    fn residue_new_rejects_non_packable() {
        let anchor = [v(0.0, 0.0, 0.0); 3];
        assert!(Residue::new(ResidueType::Gly, anchor, 0.0, 0.0, PI, no_atoms()).is_none());
        assert!(Residue::new(ResidueType::Ala, anchor, 0.0, 0.0, PI, no_atoms()).is_none());
    }

    #[test]
    fn residue_new_accepts_packable() {
        let r = ser_residue();
        assert_eq!(r.res_type(), ResidueType::Ser);
    }

    #[test]
    fn residue_accessors_round_trip() {
        let r = ser_residue();
        assert_eq!(r.anchor()[1], v(1.5, 0.0, 0.0));
        assert_eq!(r.phi(), -1.0);
        assert_eq!(r.psi(), 1.0);
        assert_eq!(r.omega(), PI);
        assert_eq!(r.sidechain().len(), 5);
        assert_eq!(r.atom_types().len(), 5);
        assert_eq!(r.atom_charges().len(), 5);
        assert_eq!(r.donor_of_h().len(), 5);
    }

    #[test]
    fn residue_set_sidechain_overwrites() {
        let mut r = ser_residue();
        let new_coords = [v(9.0, 9.0, 9.0); 5];
        r.set_sidechain(&new_coords);
        assert_eq!(r.sidechain().len(), 5);
        assert!(r.sidechain().iter().all(|&c| c == v(9.0, 9.0, 9.0)));
    }

    #[test]
    fn residue_set_sidechain_clears_before_write() {
        let mut r = ser_residue();
        r.set_sidechain(&[v(1.0, 2.0, 3.0); 3]);
        assert_eq!(r.sidechain().len(), 3);
        r.set_sidechain(&[v(4.0, 5.0, 6.0); 5]);
        assert_eq!(r.sidechain().len(), 5);
        assert!(r.sidechain().iter().all(|&c| c == v(4.0, 5.0, 6.0)));
    }

    #[test]
    fn lj_matrix_diagonal_lookup() {
        let m = lj_identity(4);
        assert_eq!(m.get(t(0), t(0)), LJ_DIAG);
        assert_eq!(m.get(t(3), t(3)), LJ_DIAG);
    }

    #[test]
    fn lj_matrix_off_diagonal_zero() {
        let m = lj_identity(4);
        assert_eq!(m.get(t(0), t(1)), LJ_ZERO);
        assert_eq!(m.get(t(2), t(3)), LJ_ZERO);
    }

    #[test]
    fn lj_matrix_symmetric_fill() {
        let n = 3usize;
        let pair = LjPair {
            d0: 2.0,
            r0_sq: 8.0,
        };
        let mut data = vec![LJ_ZERO; n * n];
        data[idx(n, 0, 1)] = pair;
        data[idx(n, 1, 0)] = pair;
        let m = LjMatrix::new(n, data);
        assert_eq!(m.get(t(0), t(1)), pair);
        assert_eq!(m.get(t(0), t(1)), m.get(t(1), t(0)));
    }

    #[test]
    fn buck_matrix_diagonal_lookup() {
        let m = buck_identity(4);
        assert_eq!(m.get(t(0), t(0)), BUCK_DIAG);
        assert_eq!(m.get(t(3), t(3)), BUCK_DIAG);
    }

    #[test]
    fn buck_matrix_off_diagonal_zero() {
        let m = buck_identity(4);
        assert_eq!(m.get(t(0), t(1)), BUCK_ZERO);
        assert_eq!(m.get(t(2), t(3)), BUCK_ZERO);
    }

    #[test]
    fn buck_matrix_symmetric_fill() {
        let n = 3usize;
        let mut data = vec![BUCK_ZERO; n * n];
        data[idx(n, 0, 1)] = BUCK_DIAG;
        data[idx(n, 1, 0)] = BUCK_DIAG;
        let m = BuckMatrix::new(n, data);
        assert_eq!(m.get(t(0), t(1)), BUCK_DIAG);
        assert_eq!(m.get(t(0), t(1)), m.get(t(1), t(0)));
    }

    #[test]
    fn hbond_candidate_both_directions() {
        let mut h_types = HashSet::new();
        let mut acc_types = HashSet::new();
        h_types.insert(t(1));
        acc_types.insert(t(2));
        let p = HBondParams::new(h_types, acc_types, HashMap::new());

        assert!(p.is_hbond_candidate(t(1), t(2)));
        assert!(p.is_hbond_candidate(t(2), t(1)));
        assert!(!p.is_hbond_candidate(t(0), t(3)));
    }

    #[test]
    fn hbond_get_returns_params() {
        let mut h_types = HashSet::new();
        let mut acc_types = HashSet::new();
        let mut params = HashMap::new();
        h_types.insert(t(1));
        acc_types.insert(t(3));
        params.insert((t(0), t(1), t(3)), (5.0f32, 25.0f32));
        let p = HBondParams::new(h_types, acc_types, params);

        assert_eq!(p.get(t(0), t(1), t(3)), Some((5.0, 25.0)));
        assert_eq!(p.get(t(0), t(1), t(0)), None);
    }

    #[test]
    fn hbond_empty_never_candidate() {
        let p = empty_hbond();
        assert!(!p.is_hbond_candidate(t(0), t(1)));
    }

    #[test]
    fn system_mobile_len() {
        let system = System {
            mobile: vec![ser_residue(), ser_residue()],
            fixed: FixedAtomPool {
                positions: vec![],
                types: vec![],
                charges: vec![],
                donor_for_h: vec![],
            },
            ff: ForceFieldParams {
                vdw: VdwMatrix::LennardJones(lj_identity(4)),
                hbond: empty_hbond(),
            },
        };
        assert_eq!(system.mobile.len(), 2);
        assert_eq!(system.fixed.positions.len(), 0);
    }

    #[test]
    fn system_fixed_pool_fields_consistent() {
        let n = 3;
        let fixed = FixedAtomPool {
            positions: vec![v(0.0, 0.0, 0.0); n],
            types: vec![t(0); n],
            charges: vec![0.0f32; n],
            donor_for_h: vec![u32::MAX; n],
        };
        assert_eq!(fixed.positions.len(), fixed.types.len());
        assert_eq!(fixed.types.len(), fixed.charges.len());
        assert_eq!(fixed.charges.len(), fixed.donor_for_h.len());
    }

    #[test]
    #[should_panic(expected = "too many sidechain atoms")]
    fn residue_new_panics_when_coords_exceed_max_sidechain_atoms() {
        let n = MAX_SIDECHAIN_ATOMS + 1;
        let _ = new_ser_with_lengths(n, n, n, n);
    }

    #[test]
    #[should_panic(expected = "types/coords length mismatch")]
    fn residue_new_panics_on_types_length_mismatch() {
        let _ = new_ser_with_lengths(5, 3, 5, 5);
    }

    #[test]
    #[should_panic(expected = "charges/coords length mismatch")]
    fn residue_new_panics_on_charges_length_mismatch() {
        let _ = new_ser_with_lengths(5, 5, 4, 5);
    }

    #[test]
    #[should_panic(expected = "donor_of_h/coords length mismatch")]
    fn residue_new_panics_on_donor_of_h_length_mismatch() {
        let _ = new_ser_with_lengths(5, 5, 5, 2);
    }

    #[test]
    #[should_panic(expected = "data.len() must equal n*n")]
    fn lj_matrix_new_panics_on_wrong_data_length() {
        let pair = LjPair {
            d0: 1.0,
            r0_sq: 1.0,
        };
        LjMatrix::new(3, vec![pair; 8]);
    }

    #[test]
    #[should_panic(expected = "matrix must be symmetric")]
    fn lj_matrix_new_panics_on_asymmetric() {
        let mut data = vec![LJ_ZERO; 4];
        data[idx(2, 0, 1)] = LjPair {
            d0: 1.0,
            r0_sq: 1.0,
        };
        LjMatrix::new(2, data);
    }

    #[test]
    #[should_panic(expected = "data.len() must equal n*n")]
    fn buck_matrix_new_panics_on_wrong_data_length() {
        BuckMatrix::new(3, vec![BUCK_ZERO; 8]);
    }

    #[test]
    #[should_panic(expected = "matrix must be symmetric")]
    fn buck_matrix_new_panics_on_asymmetric() {
        let mut data = vec![BUCK_ZERO; 4];
        data[idx(2, 0, 1)] = BUCK_DIAG;
        BuckMatrix::new(2, data);
    }

    #[test]
    #[should_panic(expected = "MAX_SIDECHAIN_ATOMS")]
    fn set_sidechain_panics_on_overflow() {
        let mut r = ser_residue();
        let too_many = vec![v(1.0, 0.0, 0.0); MAX_SIDECHAIN_ATOMS + 1];
        r.set_sidechain(&too_many);
    }
}