1use std::{collections::HashMap, hash::Hash, iter, mem};
4
5use bit_set::BitSet;
6use clap::ValueEnum;
7use petgraph::{
8 graph::{EdgeIndex, Graph, NodeIndex},
9 Undirected,
10};
11
12use crate::{
13 molecule::{AtomOrBond, Index, Molecule},
14 nauty::CanonLabeling,
15 utils::node_count_under_edge_mask,
16};
17
/// Algorithm that [`canonize`] uses to compute a canonical [`Labeling`] for a
/// subgraph; selectable on the command line via clap's `ValueEnum`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
pub enum CanonizeMode {
    // NOTE: plain `//` comments on the variants on purpose — `///` doc comments
    // on `ValueEnum` variants would become clap help text.

    // Always canonize with nauty (`CanonLabeling`) on the atom/bond graph.
    Nauty,
    // Not implemented yet: `canonize` panics when given this mode.
    Faulon,
    // Use the dedicated tree canonizer when the subgraph is a tree, nauty otherwise.
    TreeNauty,
    // Not implemented yet: `canonize` panics when given this mode.
    TreeFaulon,
}
32
/// Canonical label of a subgraph, as returned by [`canonize`].
///
/// Labels are compared, ordered and hashed structurally (derived impls); labels
/// of different variants never compare equal.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum Labeling {
    /// Canonical labeling computed by nauty over the atom/bond graph built by
    /// `subgraph_to_cgraph`.
    Nauty(CanonLabeling<AtomOrBond>),

    /// String label for the Faulon modes; `canonize` does not currently
    /// construct this variant.
    Faulon(String),

    /// Byte-string encoding produced by `tree_canonize` for tree subgraphs.
    Tree(Vec<u8>),
}
47
48pub fn canonize(mol: &Molecule, subgraph: &BitSet, mode: CanonizeMode) -> Labeling {
51 match mode {
52 CanonizeMode::Nauty => {
53 let cgraph = subgraph_to_cgraph(mol, subgraph);
54 Labeling::Nauty(CanonLabeling::new(&cgraph))
55 }
56 CanonizeMode::TreeNauty => {
57 if is_tree(mol, subgraph) {
58 Labeling::Tree(tree_canonize(mol, subgraph))
59 } else {
60 let cgraph = subgraph_to_cgraph(mol, subgraph);
61 Labeling::Nauty(CanonLabeling::new(&cgraph))
62 }
63 }
64 _ => {
65 panic!("The chosen --canonize mode is not implemented yet!")
66 }
67 }
68}
69
/// Undirected graph handed to nauty: every atom *and* every bond of a subgraph
/// becomes a node (see `subgraph_to_cgraph`). Edges carry no weight; bond
/// information lives on the bond nodes.
type CGraph = Graph<AtomOrBond, (), Undirected, Index>;
72
73fn subgraph_to_cgraph(mol: &Molecule, subgraph: &BitSet) -> CGraph {
75 let mut h = CGraph::with_capacity(subgraph.len(), 2 * subgraph.len());
76 let mut vtx_map = HashMap::<NodeIndex, NodeIndex>::new();
77 for e in subgraph {
78 let eix = EdgeIndex::new(e);
79 let (src, dst) = mol.graph().edge_endpoints(eix).unwrap();
80 let src_w = mol.graph().node_weight(src).unwrap();
81 let dst_w = mol.graph().node_weight(dst).unwrap();
82 let e_w = mol.graph().edge_weight(eix).unwrap();
83
84 let h_enode = h.add_node(AtomOrBond::Bond(*e_w));
85
86 let h_src = vtx_map
87 .entry(src)
88 .or_insert(h.add_node(AtomOrBond::Atom(*src_w)));
89 h.add_edge(*h_src, h_enode, ());
90
91 let h_dst = vtx_map
92 .entry(dst)
93 .or_insert(h.add_node(AtomOrBond::Atom(*dst_w)));
94 h.add_edge(*h_dst, h_enode, ());
95 }
96 h
97}
98
99fn is_tree(mol: &Molecule, subgraph: &BitSet) -> bool {
101 node_count_under_edge_mask(mol.graph(), subgraph) == subgraph.len() + 1
102}
103
/// Lazily yields `data` framed by an opening `u8::MIN` byte and a closing
/// `u8::MAX` byte.
fn wrap_with_delimiters(data: Vec<u8>) -> impl Iterator<Item = u8> {
    let open = iter::once(u8::MIN);
    let close = iter::once(u8::MAX);
    open.chain(data).chain(close)
}
108
109fn collapse_set(mut set: Vec<Vec<u8>>) -> Vec<u8> {
113 set.sort_unstable();
114 set.into_iter().flat_map(wrap_with_delimiters).collect()
115}
116
/// Computes a canonical byte-string label for a subgraph that is a tree.
///
/// `subgraph` holds edge (bond) indices into `mol.graph()`. The caller must
/// ensure these edges form a tree; `canonize` only calls this after `is_tree`.
/// On other inputs the leaf-stripping loop may panic, or may never terminate
/// when no remaining vertex has degree 1.
///
/// Encoding:
/// - Every vertex starts with a multiset of byte strings holding its
///   element's `repr()`.
/// - Leaves are stripped round by round. When a leaf is removed, the string
///   `[bond repr] ++ collapse_set(leaf's multiset)` is added to its parent's
///   multiset.
/// - `collapse_set` sorts the members and frames each with
///   `u8::MIN`/`u8::MAX`, so the result does not depend on atom or bond
///   numbering.
/// - Stripping stops when at most two vertices remain (the tree's center or
///   bicenter), and those are encoded as the final label.
///
/// NOTE(review): `edges_connecting(..).next()` takes the first edge between two
/// atoms. This assumes the molecule graph has no parallel bonds — confirm.
///
/// # Panics
/// - Panics with "malformed bitset!" if `subgraph` contains an index that is
///   not an edge of `mol.graph()`.
/// - Panics if `subgraph` is empty.
fn tree_canonize(mol: &Molecule, subgraph: &BitSet) -> Vec<u8> {
    let graph = mol.graph();
    let order = graph.node_count();

    // Adjacency restricted to the subgraph's edges, indexed by molecule node
    // index. Sized for the whole molecule, not just the subgraph.
    let mut adjacencies = vec![BitSet::with_capacity(order); order];

    // Per vertex: the (unsorted) multiset of byte strings describing it — its
    // element label plus one entry per child stripped so far.
    let mut partial_canonical_sets = vec![Vec::<Vec<u8>>::new(); order];

    // Subgraph vertices that have not been stripped yet.
    let mut unlabeled_vertices = BitSet::with_capacity(order);

    for ix in subgraph.iter() {
        let (u, v) = graph
            .edge_endpoints(EdgeIndex::new(ix))
            .expect("malformed bitset!");

        // Seed each endpoint with its element label the first time it is seen.
        for node in [u, v] {
            let index = node.index();
            if unlabeled_vertices.contains(index) {
                continue;
            }
            unlabeled_vertices.insert(index);
            let weight = graph.node_weight(node).unwrap();
            partial_canonical_sets[index].push(vec![weight.element().repr()]);
        }

        let (u, v) = (u.index(), v.index());
        adjacencies[u].insert(v);
        adjacencies[v].insert(u);
    }

    while unlabeled_vertices.len() > 2 {
        // Collect the current leaves first, so the whole round is stripped at
        // once. A vertex that becomes a leaf during this round waits for the next.
        let leaves = unlabeled_vertices
            .iter()
            .filter(|&i| adjacencies[i].len() == 1)
            .collect::<Vec<_>>();

        for leaf in leaves {
            // A leaf has exactly one remaining neighbor: its parent.
            let parent = adjacencies[leaf].iter().next().unwrap();
            let edge = graph
                .edges_connecting(NodeIndex::new(parent), NodeIndex::new(leaf))
                .next()
                .unwrap();

            // Child label: the connecting bond's repr followed by the leaf's
            // collapsed multiset. Taking the set leaves an empty Vec behind.
            let mut canonical_label = vec![(*edge.weight()).repr()];
            canonical_label.extend(collapse_set(mem::take(&mut partial_canonical_sets[leaf])));
            partial_canonical_sets[parent].push(canonical_label);

            // Detach the leaf from the remaining tree.
            adjacencies[leaf].clear();
            adjacencies[parent].remove(leaf);
            unlabeled_vertices.remove(leaf);
        }
    }

    if unlabeled_vertices.len() == 2 {
        // Bicenter: two adjacent vertices remain.
        let mut iter = unlabeled_vertices.iter();
        let (u, v) = (iter.next().unwrap(), iter.next().unwrap());
        let edge = graph
            .edges_connecting(NodeIndex::new(u), NodeIndex::new(v))
            .next()
            .unwrap();

        let u = collapse_set(mem::take(&mut partial_canonical_sets[u]));
        let v = collapse_set(mem::take(&mut partial_canonical_sets[v]));

        // Order the two halves by value so the result is independent of which
        // vertex has the smaller index.
        let (first, second) = if u < v { (u, v) } else { (v, u) };

        // Label: central bond repr, then both framed halves.
        [(*edge.weight()).repr()]
            .into_iter()
            .chain(wrap_with_delimiters(first))
            .chain(wrap_with_delimiters(second))
            .collect()
    } else {
        // Single center: its collapsed multiset is the whole label.
        let canonical_root = unlabeled_vertices.iter().next().unwrap();
        let canonical_set = mem::take(&mut partial_canonical_sets[canonical_root]);
        collapse_set(canonical_set)
    }
}
219
/// Sanity checks for the nauty-based canonical labeling on small node-weighted
/// graphs.
#[cfg(test)]
mod tests {
    use super::*;

    use petgraph::algo::is_isomorphic_matching;

    /// Builds the 3-vertex path `w[0] - w[1] - w[2]` with the given node weights.
    fn path3(weights: [u8; 3]) -> Graph<u8, (), Undirected> {
        let mut g = Graph::<u8, (), Undirected>::new_undirected();
        let n0 = g.add_node(weights[0]);
        let n1 = g.add_node(weights[1]);
        let n2 = g.add_node(weights[2]);
        g.add_edge(n0, n1, ());
        g.add_edge(n1, n2, ());
        g
    }

    /// Non-isomorphic weighted paths must receive different canonical labelings.
    #[test]
    fn noncanonical() {
        let p3_010 = path3([0, 1, 0]);
        let p3_001 = path3([0, 0, 1]);

        let repr_a = CanonLabeling::new(&p3_010);
        let repr_b = CanonLabeling::new(&p3_001);

        assert_ne!(repr_a, repr_b);
    }

    /// Cross-check with petgraph that these two weighted paths really are
    /// non-isomorphic, so the test above is meaningful.
    #[test]
    fn nonisomorphic() {
        let p3_010 = path3([0, 1, 0]);
        let p3_001 = path3([0, 0, 1]);

        assert!(!is_isomorphic_matching(
            &p3_001,
            &p3_010,
            |e0, e1| e0 == e1,
            |n0, n1| n0 == n1
        ))
    }
}