1use itertools::Itertools;
2use na::base::{DMatrix, Vector3};
3use std::collections::{HashMap, HashSet};
4
5use crate::constants::get_element;
6
/// A single atom: an element symbol together with Cartesian coordinates.
///
/// `Debug` and `PartialEq` are derived so atoms can be logged and compared in
/// assertions; `Eq`/`Hash` are not derived because the coordinates are `f64`.
#[derive(Clone, Debug, PartialEq)]
pub struct Atom {
    /// Element symbol, used as the lookup key for element data.
    pub element: String,
    /// Cartesian x coordinate.
    pub x: f64,
    /// Cartesian y coordinate.
    pub y: f64,
    /// Cartesian z coordinate.
    pub z: f64,
}
14
15fn compute_distance_matrix(atoms: &Vec<Atom>) -> DMatrix<f64> {
16 let n_atoms = atoms.len();
17 let mut dist = DMatrix::zeros(n_atoms, n_atoms);
18
19 for i in 0..n_atoms {
20 let a = Vector3::new(atoms[i].x, atoms[i].y, atoms[i].z);
21 for j in 0..i {
22 if i == j {
23 continue;
24 }
25 let b = Vector3::new(atoms[j].x, atoms[j].y, atoms[j].z);
26 let norm = (b - a).norm();
27 dist[(i, j)] = norm;
28 dist[(j, i)] = norm;
29 }
30 }
31
32 dist
33}
34
35fn vdw_distance(symbol_a: &str, symbol_b: &str, scaling_factor: Option<f64>) -> Option<f64> {
36 let a = get_element(&symbol_a).unwrap();
37 let b = get_element(&symbol_b).unwrap();
38
39 let mut r: f64 = 0.0;
40 if let Some(v) = a.van_del_waals_radius {
41 r += v as f64;
42 } else {
43 return None;
44 }
45
46 if let Some(v) = b.van_del_waals_radius {
47 r += v as f64;
48 } else {
49 return None;
50 }
51
52 r *= scaling_factor.unwrap_or(1.0) * 1e-2; Some(r)
54}
55
56fn make_vdw_bond_table(atoms: &Vec<Atom>, scaling_factor: Option<f64>) -> HashMap<String, f64> {
57 let mut table: HashMap<String, f64> = HashMap::new();
58 let unique_elements: HashSet<String> = atoms.iter().map(|a| a.element.clone()).collect();
59
60 let combinations: HashSet<_> = unique_elements
61 .clone()
62 .into_iter()
63 .combinations_with_replacement(2)
64 .collect();
65 let permutations: HashSet<_> = unique_elements.into_iter().permutations(2).collect();
66
67 let union = combinations.union(&permutations);
68 for pair in union {
69 let vdw = vdw_distance(&pair[0], &pair[1], scaling_factor).unwrap_or(-1.0_f64);
70 table.insert(format!("{}-{}", pair[0].clone(), pair[1].clone()), vdw);
71 }
72 table
73}
74
75fn get_adjacency_matrix(atoms: &Vec<Atom>) -> DMatrix<bool> {
76 let n_atoms = atoms.len();
77 let mut interact = DMatrix::from_element(n_atoms, n_atoms, false);
78 let distance = compute_distance_matrix(&atoms);
79 let vdw_table = make_vdw_bond_table(&atoms, Some(0.5));
80
81 for i in 0..n_atoms {
82 let a = atoms[i].element.clone();
83 for j in 0..i {
84 if i == j {
85 continue;
86 }
87 let b = atoms[j].element.clone();
88 let vdw_distance = vdw_table
89 .get(&format!("{}-{}", &a, &b))
90 .cloned()
91 .unwrap_or(-1.0_f64);
92 interact[(i, j)] = distance[(i, j)] <= vdw_distance;
93 interact[(j, i)] = distance[(i, j)] <= vdw_distance;
94 }
95 }
96
97 interact
98}
99
100pub fn get_fragment_indices(atoms: &Vec<Atom>) -> Vec<Vec<usize>> {
101 let mut fragments = Vec::new();
102 let adj = get_adjacency_matrix(&atoms);
103 let mut visited = vec![false; adj.nrows()];
104
105 for i in 0..adj.nrows() {
106 if !visited[i] {
107 let mut fragment = Vec::new();
108 dfs(i, &adj, &mut visited, &mut fragment);
109 fragments.push(fragment);
110 }
111 }
112
113 fragments
114}
115
116fn dfs(
117 atom_index: usize,
118 adjacency_matrix: &DMatrix<bool>,
119 visited: &mut Vec<bool>,
120 fragment: &mut Vec<usize>,
121) {
122 visited[atom_index] = true;
123 fragment.push(atom_index);
124
125 for j in 0..adjacency_matrix.ncols() {
126 if adjacency_matrix[(atom_index, j)] && !visited[j] {
127 dfs(j, adjacency_matrix, visited, fragment);
128 }
129 }
130}
131
132pub fn get_fragments(atoms: &Vec<Atom>) -> Vec<Vec<Atom>> {
133 let frag_index_groups = get_fragment_indices(&atoms);
134 let mut fragments = Vec::new();
135
136 for frag_group in frag_index_groups {
137 let mut fragment = Vec::new();
138 for ifrag in frag_group {
139 fragment.push(atoms[ifrag].clone());
140 }
141 fragments.push(fragment);
142 }
143
144 fragments
145}