use crate::arguments::Args;
use crate::atoms::Atoms;
use crate::grid::Grid;
use crate::hash::{IntMap, IntSet};
use crate::progress::{Bar, HiddenBar, ProgressBar};
use crate::threading::parallel_prune;
use crate::utils::{cross, norm, subtract, vdot};
use crate::voxel_map::EncodedAtom;
/// A critical point found on the grid, together with the atoms that define it.
#[derive(Clone)]
pub struct CriticalPoint {
    /// Grid index of the critical point (used to index `density` and passed
    /// to `Grid::to_cartesian` elsewhere in this module).
    pub position: isize,
    /// What type of critical point this is.
    pub kind: CriticalPointKind,
    /// Atoms associated with the point; each carries an encoded periodic image.
    pub atoms: Box<[EncodedAtom]>,
}
impl CriticalPoint {
    /// Bundles a grid index, a kind and an atom list into a `CriticalPoint`.
    pub fn new(
        position: isize,
        kind: CriticalPointKind,
        atoms: Box<[EncodedAtom]>,
    ) -> Self {
        Self {
            position,
            kind,
            atoms,
        }
    }
}
/// Classification of a critical point.
///
/// The derived `Ord` follows declaration order
/// (`Nuclei < Bond < Ring < Cage < Blank`); reordering the variants changes
/// any comparison or sort that callers may perform on kinds.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum CriticalPointKind {
    /// Nucleus critical point.
    Nuclei,
    /// Bond critical point.
    Bond,
    /// Ring critical point.
    Ring,
    /// Cage critical point.
    Cage,
    /// Placeholder carrying no atoms (e.g. emitted by `nuclei_ordering` for
    /// an atom that had no nucleus candidate).
    Blank,
}
/// Order-independent key for a critical point's atom list.
///
/// The atoms are sorted, shifted so that the smallest atom sits in the zero
/// periodic image, and sorted again, so the same atom set listed in a
/// different order (or with a common image offset) maps to one key.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CriticalPointKey(Vec<EncodedAtom>);
impl CriticalPointKey {
    /// Builds the key from `cp`, consuming it.
    pub fn from_cp(cp: CriticalPoint) -> Self {
        // `cp` is owned, so reuse its buffer instead of copying it.
        let mut atoms = cp.atoms.into_vec();
        // Sort first so the anchor is the smallest atom, giving a canonical shift.
        atoms.sort_unstable();
        if let Some(anchor) = atoms.first() {
            let image = anchor.image();
            for atom in atoms.iter_mut() {
                *atom = atom.image_sub(image);
            }
        }
        // The image shift can change the ordering, so sort again.
        atoms.sort_unstable();
        Self(atoms)
    }
    /// Returns the normalised, sorted atom list.
    pub fn into_box(self) -> Box<[EncodedAtom]> {
        self.0.into_boxed_slice()
    }
}
/// Selects one nucleus critical point per atom.
///
/// Candidates in `nuclei` are grouped by the atom index of their first atom.
/// For each atom (in atom order) the candidate with the highest `density`
/// value at its grid position is returned; on ties the later candidate wins.
/// Every other candidate for that atom is rewritten in place with its atom's
/// image shifted by `atoms.lattice.closest_image(...)` relative to the chosen
/// maximum. NOTE(review): presumably this points them at the periodic image
/// nearest the true maximum; confirm against `Lattice::closest_image`.
///
/// Atoms with no candidate yield a `Blank` placeholder with no atoms, so the
/// result always has `atoms.positions.len()` entries.
///
/// # Panics
/// If a candidate has no atoms, refers to an atom index out of range, or has
/// a position outside `density`.
pub fn nuclei_ordering(
    nuclei: &mut [CriticalPoint],
    density: &[f64],
    atoms: &Atoms,
    grid: &Grid,
    visible_bar: bool,
) -> Vec<CriticalPoint> {
    let atom_len = atoms.positions.len();
    let progress_bar: Box<dyn ProgressBar> = if visible_bar {
        Box::new(Bar::new(
            nuclei.len() + atom_len,
            String::from("Pruning Nucleus Critical Points"),
        ))
    } else {
        Box::new(HiddenBar {})
    };

    // Bucket candidate indices by the atom they belong to.
    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); atom_len];
    for (candidate, cp) in nuclei.iter().enumerate() {
        buckets[cp.atoms[0].atom_index() as usize].push(candidate);
        progress_bar.tick();
    }

    let mut ordered = Vec::with_capacity(atom_len);
    for bucket in buckets {
        progress_bar.tick();
        let strongest = bucket.iter().copied().max_by(|&lhs, &rhs| {
            density[nuclei[lhs].position as usize]
                .total_cmp(&density[nuclei[rhs].position as usize])
        });
        let best = match strongest {
            Some(best) => best,
            None => {
                ordered.push(CriticalPoint::new(
                    0,
                    CriticalPointKind::Blank,
                    Box::new([]),
                ));
                continue;
            }
        };
        let true_maximum = nuclei[best].clone();
        if bucket.len() > 1 {
            let true_position = grid.to_cartesian(true_maximum.position);
            for &other in bucket.iter().filter(|&&i| i != best) {
                if let Some(cp) = nuclei.get_mut(other) {
                    let position = grid.to_cartesian(cp.position);
                    let image =
                        atoms.lattice.closest_image(true_position, position);
                    *cp = CriticalPoint::new(
                        cp.position,
                        CriticalPointKind::Nuclei,
                        Box::new([cp.atoms[0].image_sub(image)]),
                    );
                }
            }
        }
        ordered.push(true_maximum);
    }
    ordered
}
/// Runs `parallel_prune` over the bond critical points.
///
/// The validator accepts every bond (`|_| true`), so any filtering comes from
/// `parallel_prune` itself (which also receives `density`). A progress bar is
/// drawn unless `args.silent` is set; `args.threads` is forwarded.
pub fn bond_pruning(
    bonds: &[CriticalPoint],
    density: &[f64],
    args: &Args,
) -> Vec<CriticalPoint> {
    let progress_bar: Box<dyn ProgressBar> = if args.silent {
        Box::new(HiddenBar {})
    } else {
        Box::new(Bar::new(
            bonds.len(),
            String::from("Pruning Bond Critical Points"),
        ))
    };
    parallel_prune(bonds, density, |_| true, args.threads, progress_bar)
}
/// Builds a per-atom neighbour table from bond critical points.
///
/// Entry `i` lists the bonded partners of atom `i`, each shifted with
/// `image_sub` so that it is expressed relative to atom `i`'s own periodic
/// image. `ring_pruning` looks neighbours up in exactly this relative form.
/// Both ends of a bond are now normalised symmetrically; previously the first
/// atom's partner was stored unshifted, which only matched when the first atom
/// already sat in the zero image.
///
/// # Panics
/// If a bond has fewer than two atoms or an atom index is `>= atom_len`.
pub fn bond_adjacency(
    bonds: &[CriticalPoint],
    atom_len: usize,
) -> Vec<Vec<EncodedAtom>> {
    let mut adjacency: Vec<Vec<EncodedAtom>> = vec![Vec::new(); atom_len];
    for bond in bonds {
        let (first, second) = (bond.atoms[0], bond.atoms[1]);
        adjacency[first.atom_index() as usize]
            .push(second.image_sub(first.image()));
        adjacency[second.atom_index() as usize]
            .push(first.image_sub(second.image()));
    }
    adjacency
}
/// Runs `parallel_prune` over ring critical points with a cycle validator.
///
/// `bond_adjancy[i]` is expected to list atom `i`'s bonded partners relative
/// to atom `i`'s image (presumably the output of `bond_adjacency`; confirm at
/// the call site). The validator accepts a ring only if:
/// - walking from atom 0, every visited atom has exactly two bonded partners
///   among the ring's atoms;
/// - the walk, always stepping to the partner it did not just come from,
///   returns to atom 0 after visiting exactly `cp.atoms.len()` atoms.
///
/// Rings with fewer than three atoms are rejected up front. They could never
/// pass the walk, and an empty atom list would otherwise panic on indexing.
pub fn ring_pruning(
    rings: &[CriticalPoint],
    bond_adjancy: &[Vec<EncodedAtom>],
    density: &[f64],
    args: &Args,
) -> Vec<CriticalPoint> {
    let threads = args.threads;
    let progress_bar: Box<dyn ProgressBar> = match args.silent {
        true => Box::new(HiddenBar {}),
        false => Box::new(Bar::new(
            rings.len(),
            String::from("Pruning Ring Critical Points"),
        )),
    };
    let validator = |cp: &CriticalPoint| {
        let ring: &[EncodedAtom] = &cp.atoms;
        if ring.len() < 3 {
            return false;
        }
        let mut previous_index = 0;
        let mut current_index = 0;
        let mut next_index = 0;
        let mut steps = 1;
        loop {
            let current = ring[current_index];
            let graph = &bond_adjancy[current.atom_index() as usize];
            let mut intrabonds = 0;
            for (i, encoded_atom) in ring.iter().enumerate() {
                // Shift into the current atom's frame before looking it up.
                if graph.contains(&encoded_atom.image_sub(current.image())) {
                    intrabonds += 1;
                    if i != previous_index {
                        next_index = i;
                    }
                }
            }
            if intrabonds != 2 {
                return false;
            }
            // Back at the start, or the walk is running away: stop.
            if next_index == 0 || steps > ring.len() {
                break;
            }
            steps += 1;
            previous_index = current_index;
            current_index = next_index;
        }
        steps == ring.len()
    };
    parallel_prune(rings, density, validator, threads, progress_bar)
}
pub fn cage_pruning(
cages: &[CriticalPoint],
ordered_nuclei: &[CriticalPoint],
density: &[f64],
atoms: &Atoms,
grid: &Grid,
args: &Args,
) -> Vec<CriticalPoint> {
let threads = args.threads;
let progress_bar: Box<dyn ProgressBar> = match args.silent {
true => Box::new(HiddenBar {}),
false => Box::new(Bar::new(
cages.len(),
String::from("Pruning Cage Critical Points"),
)),
};
_ = |cp: &CriticalPoint| {
if cp.atoms.len() < 4 {
return false;
}
let positions: Vec<[f64; 3]> = cp.atoms[..3]
.iter()
.map(|encoded_atom| {
let (atom_num, encoded_image) = encoded_atom.decode_partial();
let mut position = grid
.to_cartesian(ordered_nuclei[atom_num as usize].position);
let image = match encoded_image.is_zero() {
true => [0., 0., 0.],
false => {
let image = encoded_image.decode();
atoms.lattice.fractional_to_cartesian([
image[0] as f64,
image[1] as f64,
image[2] as f64,
])
}
};
position.iter_mut().zip(image).for_each(|(f, i)| *f += i);
position
})
.collect();
let vec_1 = subtract(positions[1], positions[0]);
let vec_2 = subtract(positions[2], positions[0]);
let mut plane = cross(vec_1, vec_2);
let plane_normal = norm(plane);
plane.iter_mut().for_each(|f| *f /= plane_normal);
for encoded_atom in cp.atoms[3..].iter() {
let (atom_num, encoded_image) = encoded_atom.decode_partial();
let mut position =
grid.to_cartesian(ordered_nuclei[atom_num as usize].position);
let image = match encoded_image.is_zero() {
true => [0., 0., 0.],
false => {
let image = encoded_image.decode();
atoms.lattice.fractional_to_cartesian([
image[0] as f64,
image[1] as f64,
image[2] as f64,
])
}
};
position.iter_mut().zip(image).for_each(|(f, i)| *f += i);
let vec_3 = subtract(position, positions[0]);
let mut plane_t = cross(vec_1, vec_3);
let plane_normal = norm(plane_t);
plane_t.iter_mut().for_each(|f| *f /= plane_normal);
if vdot(plane, plane_t).abs() < 0.995 {
return true;
}
}
false
};
parallel_prune(cages, density, |_| true, threads, progress_bar)
}
/// Removes critical points whose atom set is already covered by a kept point.
///
/// Points are processed largest-first (by atom count), so a potential superset
/// is always kept before anything it could absorb. A point is dropped when,
/// for some kept point and some atom of that point, shifting the kept point's
/// atoms by that atom's image (`image_sub`) yields a set that contains every
/// atom of the candidate. Otherwise the point is kept. Which of several
/// equal-sized duplicates survives follows the unstable sort order.
pub fn critical_point_merge(mut cps: Vec<CriticalPoint>) -> Vec<CriticalPoint> {
    cps.sort_unstable_by(|a, b| b.atoms.len().cmp(&a.atoms.len()));
    let mut merged_points: Vec<CriticalPoint> = Vec::with_capacity(cps.len());
    // Atom index (image ignored) -> indices into `merged_points` containing it.
    let mut inverted_index: IntMap<u32, Vec<usize>> = IntMap::default();
    'critical: for cp in cps {
        // Kept points that contain every atom index of `cp`; `None` when a
        // single atom index has no match, so no superset is possible.
        let mut superset: Option<IntSet<usize>> = None;
        for atom in cp.atoms.iter() {
            let matches = match inverted_index.get(&atom.atom_index()) {
                Some(matches) => matches,
                None => {
                    superset = None;
                    break;
                }
            };
            match superset {
                None => superset = Some(matches.iter().copied().collect()),
                Some(ref mut set) => {
                    set.retain(|id| matches.contains(id));
                    if set.is_empty() {
                        superset = None;
                        break;
                    }
                }
            }
        }
        if let Some(candidates) = superset {
            // Loop-invariant: the candidate's atoms (with images) as a set.
            let set_sub: IntSet<EncodedAtom> = cp.atoms.iter().copied().collect();
            for &index in candidates.iter() {
                let cp_super = &merged_points[index];
                // Try every atom of the kept point as the zero-image anchor.
                let covered = cp_super.atoms.iter().any(|anchor| {
                    let new_anchor = anchor.image();
                    let rotated_super: IntSet<EncodedAtom> = cp_super
                        .atoms
                        .iter()
                        .map(|a| a.image_sub(new_anchor))
                        .collect();
                    set_sub.is_subset(&rotated_super)
                });
                if covered {
                    continue 'critical;
                }
            }
        }
        let i = merged_points.len();
        for atom in cp.atoms.iter() {
            inverted_index.entry(atom.atom_index()).or_default().push(i);
        }
        // `cps` is owned, so move the point in instead of cloning it.
        merged_points.push(cp);
    }
    merged_points
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        atoms::Lattice,
        voxel_map::{EncodedAtom, EncodedImage},
    };

    /// Builds a critical point whose atoms all sit in the zero periodic image.
    fn create_cp(
        pos: isize,
        kind: CriticalPointKind,
        atom_ids: &[u32],
    ) -> CriticalPoint {
        let home_atoms: Box<[EncodedAtom]> = atom_ids
            .iter()
            .map(|&id| EncodedAtom::new(id, EncodedImage::new([0, 0, 0])))
            .collect();
        CriticalPoint::new(pos, kind, home_atoms)
    }

    #[test]
    fn test_critical_point_key_sorting() {
        let unsorted = create_cp(0, CriticalPointKind::Bond, &[2, 1, 3]);
        let key = CriticalPointKey::from_cp(unsorted);
        let ids: Vec<u32> = key
            .into_box()
            .iter()
            .map(|atom| atom.atom_index())
            .collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn test_critical_point_key_translation() {
        // Both atoms share the same non-zero image, so normalising against
        // the anchor should move every atom into the zero image.
        let shifted = CriticalPoint::new(
            0,
            CriticalPointKind::Bond,
            Box::new([
                EncodedAtom::new(1, EncodedImage::new([1, 0, 0])),
                EncodedAtom::new(2, EncodedImage::new([1, 0, 0])),
            ]),
        );
        let key = CriticalPointKey::from_cp(shifted);
        for atom in key.into_box().iter() {
            assert!(
                atom.image().is_zero(),
                "Image should be normalised to zero"
            );
        }
    }

    #[test]
    fn test_nuclei_ordering_simple() {
        let atoms = Atoms::new(
            Lattice::new([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
            vec![[0.0, 0.0, 0.0]],
            String::with_capacity(0),
        );
        let grid = Grid::new(
            [10, 10, 10],
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            [0.0, 0.0, 0.0],
        );
        let mut density = vec![0.0; 30];
        density[10] = 1.0;
        density[20] = 5.0;
        // Two candidates for atom 0; the denser one must win.
        let weak = create_cp(10, CriticalPointKind::Nuclei, &[0]);
        let strong = create_cp(20, CriticalPointKind::Nuclei, &[0]);
        let mut candidates = vec![weak, strong];
        let ordered =
            nuclei_ordering(&mut candidates, &density, &atoms, &grid, false);
        assert_eq!(ordered.len(), 1);
        assert_eq!(ordered[0].position, 20);
    }

    #[test]
    fn test_nuclei_ordering_multiple_atoms() {
        let atoms = Atoms::new(
            Lattice::new([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
            vec![[0.0, 0.0, 0.0], [0., 0.5, 0.0]],
            String::with_capacity(0),
        );
        let grid = Grid::new(
            [10, 10, 10],
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            [0.0, 0.0, 0.0],
        );
        let mut density = vec![0.0; 100];
        density[10] = 5.0;
        density[50] = 3.0;
        // One candidate per atom; output must follow atom order.
        let first_atom = create_cp(10, CriticalPointKind::Nuclei, &[0]);
        let second_atom = create_cp(50, CriticalPointKind::Nuclei, &[1]);
        let mut candidates = vec![first_atom, second_atom];
        let ordered =
            nuclei_ordering(&mut candidates, &density, &atoms, &grid, false);
        assert_eq!(ordered.len(), 2);
        assert_eq!(ordered[0].position, 10);
        assert_eq!(ordered[1].position, 50);
    }

    #[test]
    fn test_critical_point_merge_exact_duplicate() {
        let original = create_cp(10, CriticalPointKind::Bond, &[1, 2]);
        let duplicate = create_cp(11, CriticalPointKind::Bond, &[1, 2]);
        let merged = critical_point_merge(vec![original, duplicate]);
        assert_eq!(merged.len(), 1);
    }

    #[test]
    fn test_critical_point_merge_subset() {
        // The bond's atoms are a subset of the ring's, so only the ring stays.
        let ring = create_cp(10, CriticalPointKind::Ring, &[1, 2, 3]);
        let bond = create_cp(11, CriticalPointKind::Bond, &[1, 2]);
        let merged = critical_point_merge(vec![ring, bond]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].atoms.len(), 3);
    }

    #[test]
    fn test_critical_point_merge_distinct() {
        let left = create_cp(10, CriticalPointKind::Bond, &[1, 2]);
        let right = create_cp(11, CriticalPointKind::Bond, &[3, 4]);
        let merged = critical_point_merge(vec![left, right]);
        assert_eq!(merged.len(), 2);
    }
}