use std::collections::{BTreeMap, BTreeSet};
use crate::{
angle::Angle,
bond::{Bond, BondOrder},
dihedral::Dihedral,
error::CError,
improper::Improper,
};
/// Bond graph of a set of atoms, plus the angle, dihedral and improper terms
/// derived from it.
///
/// The derived sets are a cache: they are only regenerated from `bonds` when an
/// accessor (`angles`, `dihedrals`, `impropers`) is called while `up_to_date`
/// is false.
#[derive(Default, Debug, Clone)]
pub struct Connectivity {
/// Bonds keyed by atom pair, each mapped to its bond order.
///
/// NOTE(review): this field is `pub(crate)`, but `up_to_date` is private to this
/// module, so crate code that edits `bonds` directly cannot invalidate the
/// cache and may also exceed `biggest_atom` — prefer `add_bond`/`remove_bond`.
pub(crate) bonds: BTreeMap<Bond, BondOrder>,
/// Cached angles derived from `bonds`; current only while `up_to_date` is true.
pub(crate) angles: BTreeSet<Angle>,
/// Cached proper dihedrals derived from `bonds`; current only while `up_to_date` is true.
pub(crate) dihedrals: BTreeSet<Dihedral>,
/// Cached impropers derived from `bonds`; current only while `up_to_date` is true.
pub(crate) impropers: BTreeSet<Improper>,
/// Whether the cached sets reflect the current `bonds`; cleared by bond edits,
/// set again after a rebuild.
up_to_date: bool,
/// Largest atom index ever passed to `add_bond`; sizes the per-atom adjacency
/// list during a rebuild. It is never lowered by `remove_bond`.
biggest_atom: usize,
}
impl Connectivity {
    /// Returns the set of bond angles implied by the current bonds.
    ///
    /// Takes `&mut self` because the derived angle, dihedral and improper sets
    /// are rebuilt lazily here when a bond has been added or removed since the
    /// last rebuild.
    pub fn angles(&mut self) -> &BTreeSet<Angle> {
        self.ensure_up_to_date();
        &self.angles
    }

    /// Returns the set of proper dihedrals implied by the current bonds.
    ///
    /// Rebuilds the derived terms first if the bonds changed since the last rebuild.
    pub fn dihedrals(&mut self) -> &BTreeSet<Dihedral> {
        self.ensure_up_to_date();
        &self.dihedrals
    }

    /// Returns the set of improper dihedrals implied by the current bonds.
    ///
    /// Rebuilds the derived terms first if the bonds changed since the last rebuild.
    pub fn impropers(&mut self) -> &BTreeSet<Improper> {
        self.ensure_up_to_date();
        &self.impropers
    }

    /// Adds a bond between atoms `i` and `j` with order `bond_order`.
    ///
    /// `add_bond(i, j, ..)` and `add_bond(j, i, ..)` name the same bond. If that
    /// bond already exists, its stored order is kept and `bond_order` is ignored.
    pub fn add_bond(&mut self, i: usize, j: usize, bond_order: BondOrder) {
        self.biggest_atom = self.biggest_atom.max(i).max(j);
        // Only a genuinely new bond changes the topology, so only then drop the
        // cached terms (mirrors `remove_bond`, which invalidates only on removal).
        let bond = Bond::new(i, j);
        if let std::collections::btree_map::Entry::Vacant(slot) = self.bonds.entry(bond) {
            slot.insert(bond_order);
            self.up_to_date = false;
        }
    }

    /// Removes the bond between atoms `i` and `j`, if present.
    ///
    /// Removing a bond that does not exist is a no-op and keeps the cached terms.
    pub fn remove_bond(&mut self, i: usize, j: usize) {
        if self.bonds.remove(&Bond::new(i, j)).is_some() {
            self.up_to_date = false;
        }
    }

    /// Returns the order of the bond between atoms `i` and `j`.
    ///
    /// # Errors
    ///
    /// Returns [`CError::GenericError`] if no bond between `i` and `j` exists.
    pub fn bond_order(&self, i: usize, j: usize) -> Result<BondOrder, CError> {
        self.bonds.get(&Bond::new(i, j)).copied().ok_or_else(|| {
            CError::GenericError(format!(
                "out of bounds atomic index. No bond between {i} and {j} exists"
            ))
        })
    }

    /// Rebuilds the derived terms if a bond was added or removed since the last rebuild.
    fn ensure_up_to_date(&mut self) {
        if !self.up_to_date {
            self.recalculate();
        }
    }

    /// Builds an adjacency list indexed by atom: entry `a` lists every atom bonded to `a`.
    fn adjacency(&self) -> Vec<Vec<usize>> {
        // `vec![Vec::with_capacity(4); n]` clones its prototype and cloned vectors
        // do not keep spare capacity, so such a hint would be ineffective; start empty.
        let mut bonded_to: Vec<Vec<usize>> = vec![Vec::new(); self.biggest_atom + 1];
        for bond in self.bonds.keys() {
            let (a, b) = (bond[0], bond[1]);
            debug_assert!(a < bonded_to.len());
            debug_assert!(b < bonded_to.len());
            bonded_to[a].push(b);
            bonded_to[b].push(a);
        }
        bonded_to
    }

    /// Regenerates `angles`, `dihedrals` and `impropers` from `bonds` and marks them current.
    fn recalculate(&mut self) {
        self.angles.clear();
        self.dihedrals.clear();
        self.impropers.clear();

        let bonded_to = self.adjacency();

        // Each angle is reached once from each of its two bonds; the set keeps a
        // single copy. This relies on `Angle::new` treating mirrored orderings as
        // equal, as the tests expect.
        for bond in self.bonds.keys() {
            let (i, j) = (bond[0], bond[1]);
            for &k in &bonded_to[i] {
                if k != j {
                    self.angles.insert(Angle::new(k, i, j));
                }
            }
            for &k in &bonded_to[j] {
                if k != i {
                    self.angles.insert(Angle::new(i, j, k));
                }
            }
        }

        // Extend every angle i-j-k outwards through its end atoms into dihedrals,
        // and around its middle atom `j` into impropers. The `!=` checks skip atoms
        // that are already part of the angle (e.g. closing a 3-membered ring).
        for angle in &self.angles {
            let (i, j, k) = (angle[0], angle[1], angle[2]);
            for &m in &bonded_to[i] {
                if m != j && m != k {
                    self.dihedrals.insert(Dihedral::new(m, i, j, k));
                }
            }
            for &m in &bonded_to[k] {
                if m != i && m != j {
                    self.dihedrals.insert(Dihedral::new(i, j, k, m));
                }
            }
            for &m in &bonded_to[j] {
                if m != i && m != k {
                    self.impropers.insert(Improper::new(i, j, k, m));
                }
            }
        }
        self.up_to_date = true;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_bond() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        assert_eq!(connectivity.bonds.len(), 1);
        assert!(connectivity.bonds.contains_key(&Bond::new(1, 2)));
        assert_eq!(connectivity.bond_order(1, 2).unwrap(), BondOrder::Single);
        // Re-adding an existing bond keeps the original order.
        connectivity.add_bond(1, 2, BondOrder::Double);
        assert_eq!(connectivity.bonds.len(), 1);
        assert_eq!(connectivity.bond_order(1, 2).unwrap(), BondOrder::Single);
        // Bonds are unordered: (2, 1) is the same bond as (1, 2).
        connectivity.add_bond(2, 1, BondOrder::Single);
        assert_eq!(connectivity.bonds.len(), 1);
        assert_eq!(connectivity.bond_order(1, 2).unwrap(), BondOrder::Single);
    }

    #[test]
    fn test_remove_bond() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Double);
        connectivity.add_bond(3, 4, BondOrder::Triple);
        assert_eq!(connectivity.bonds.len(), 3);
        connectivity.remove_bond(1, 2);
        assert_eq!(connectivity.bonds.len(), 2);
        assert!(!connectivity.bonds.contains_key(&Bond::new(1, 2)));
        // Removing a missing bond is a no-op.
        connectivity.remove_bond(1, 2);
        assert_eq!(connectivity.bonds.len(), 2);
    }

    #[test]
    fn test_remove_bond_refreshes_derived_terms() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Single);
        connectivity.add_bond(3, 4, BondOrder::Single);
        // Populate the cache before removing a bond.
        assert_eq!(connectivity.angles().len(), 2);
        assert_eq!(connectivity.dihedrals().len(), 1);
        connectivity.remove_bond(3, 4);
        let angles = connectivity.angles().clone();
        assert_eq!(angles.len(), 1);
        assert!(angles.contains(&Angle::new(1, 2, 3)));
        assert!(connectivity.dihedrals().is_empty());
        assert!(connectivity.bond_order(3, 4).is_err());
    }

    #[test]
    fn test_bond_order() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Double);
        connectivity.add_bond(3, 4, BondOrder::Triple);
        assert_eq!(connectivity.bond_order(1, 2).unwrap(), BondOrder::Single);
        assert_eq!(connectivity.bond_order(2, 3).unwrap(), BondOrder::Double);
        assert_eq!(connectivity.bond_order(3, 4).unwrap(), BondOrder::Triple);
        assert_eq!(connectivity.bond_order(2, 1).unwrap(), BondOrder::Single);
        assert_eq!(connectivity.bond_order(3, 2).unwrap(), BondOrder::Double);
        assert_eq!(connectivity.bond_order(4, 3).unwrap(), BondOrder::Triple);
        assert!(connectivity.bond_order(1, 3).is_err());
    }

    #[test]
    fn test_empty_connectivity() {
        let mut connectivity = Connectivity::default();
        assert!(connectivity.angles().is_empty());
        assert!(connectivity.dihedrals().is_empty());
        assert!(connectivity.impropers().is_empty());
        assert!(connectivity.bond_order(1, 2).is_err());
    }

    #[test]
    fn test_angles_generation() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Single);
        let angles = connectivity.angles();
        assert_eq!(angles.len(), 1);
        assert!(angles.contains(&Angle::new(1, 2, 3)));
        connectivity.add_bond(2, 4, BondOrder::Single);
        let angles = connectivity.angles();
        assert_eq!(angles.len(), 3);
        assert!(angles.contains(&Angle::new(1, 2, 3)));
        assert!(angles.contains(&Angle::new(1, 2, 4)));
        assert!(angles.contains(&Angle::new(3, 2, 4)));
    }

    #[test]
    fn test_three_membered_ring_has_no_dihedrals() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Single);
        connectivity.add_bond(3, 1, BondOrder::Single);
        assert_eq!(connectivity.angles().len(), 3);
        assert!(connectivity.dihedrals().is_empty());
        assert!(connectivity.impropers().is_empty());
    }

    #[test]
    fn test_dihedrals_generation() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Single);
        connectivity.add_bond(3, 4, BondOrder::Single);
        let dihedrals = connectivity.dihedrals();
        assert_eq!(dihedrals.len(), 1);
        assert!(dihedrals.contains(&Dihedral::new(1, 2, 3, 4)));
        connectivity.add_bond(3, 5, BondOrder::Single);
        let dihedrals = connectivity.dihedrals().clone();
        let impropers = connectivity.impropers().clone();
        assert_eq!(dihedrals.len(), 2);
        assert_eq!(impropers.len(), 1);
        assert!(dihedrals.contains(&Dihedral::new(1, 2, 3, 4)));
        assert!(dihedrals.contains(&Dihedral::new(1, 2, 3, 5)));
        assert!(impropers.contains(&Improper::new(2, 3, 4, 5)));
    }

    #[test]
    fn test_impropers_generation() {
        let mut connectivity = Connectivity::default();
        connectivity.add_bond(1, 2, BondOrder::Single);
        connectivity.add_bond(2, 3, BondOrder::Single);
        connectivity.add_bond(2, 4, BondOrder::Single);
        let impropers = connectivity.impropers();
        assert_eq!(impropers.len(), 1);
        assert!(impropers.contains(&Improper::new(1, 2, 3, 4)));
        connectivity.add_bond(2, 5, BondOrder::Single);
        let impropers = connectivity.impropers();
        assert_eq!(impropers.len(), 4);
        assert!(impropers.contains(&Improper::new(1, 2, 3, 4)));
        assert!(impropers.contains(&Improper::new(1, 2, 3, 5)));
        assert!(impropers.contains(&Improper::new(1, 2, 4, 5)));
        assert!(impropers.contains(&Improper::new(3, 2, 4, 5)));
    }
}