modular_decomposition/md_tree/
mod.rs1mod functions;
2
3use std::fmt::{Debug, Display, Formatter};
4use std::iter::from_fn;
5
6use petgraph::graph::DiGraph;
7use petgraph::{Incoming, Outgoing};
8
9use crate::index::make_index;
10
// Declares the crate-internal index type `NodeIndex` via `crate::index::make_index`.
// NOTE(review): the generated type's representation is defined by the macro in
// `crate::index` — confirm there. Because this module defines its own `NodeIndex`,
// petgraph's index type is always written out as `petgraph::graph::NodeIndex`
// elsewhere in this file.
make_index!(pub(crate) NodeIndex);
12
/// The kind of a strong module in a modular decomposition tree.
///
/// The declaration order of the variants is significant: the derived
/// `PartialOrd`/`Ord` compare `Prime < Series < Parallel < Node(_)`, and
/// `Node` values are then compared by their vertex.
#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum ModuleKind<NodeId: Copy + PartialEq> {
    /// An inner module that is neither series nor parallel (in standard
    /// modular-decomposition terms, its quotient graph is prime).
    Prime,
    /// An inner module whose children are pairwise adjacent; for example, the
    /// root of the decomposition of a complete graph.
    Series,
    /// An inner module whose children are pairwise non-adjacent.
    Parallel,
    /// A leaf module consisting of the single vertex it carries.
    Node(NodeId),
}

impl<NodeId: Debug + Copy + PartialEq> Debug for ModuleKind<NodeId> {
    /// Formats inner modules by name (`Prime`, `Series`, `Parallel`) and a leaf
    /// `Node(v)` as the `Debug` output of `v` alone, without a `Node(..)` wrapper.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ModuleKind::Prime => f.write_str("Prime"),
            ModuleKind::Series => f.write_str("Series"),
            ModuleKind::Parallel => f.write_str("Parallel"),
            // Delegate with the caller's formatter so flags such as `{:#?}` or a
            // width reach the vertex's own `Debug` impl instead of being dropped
            // by a fresh `write!` formatting pass.
            ModuleKind::Node(v) => Debug::fmt(v, f),
        }
    }
}
50
51#[derive(Clone, Debug)]
62pub struct MDTree<NodeId: Copy + PartialEq> {
63 tree: DiGraph<ModuleKind<NodeId>, ()>,
64 root: ModuleIndex,
65}
66
67#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
69pub struct ModuleIndex(pub(crate) petgraph::graph::NodeIndex);
70
71impl Debug for ModuleIndex {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 f.debug_tuple("ModuleIndex").field(&self.0.index()).finish()
74 }
75}
76
77impl ModuleIndex {
78 pub fn new(x: usize) -> Self {
80 Self(petgraph::graph::NodeIndex::new(x))
81 }
82
83 pub fn index(&self) -> usize {
85 self.0.index()
86 }
87}
88
89impl<NodeId: Copy + PartialEq> MDTree<NodeId> {
90 pub(crate) fn from_digraph(tree: DiGraph<ModuleKind<NodeId>, ()>) -> Result<Self, NullGraphError> {
100 if tree.node_count() == 0 {
101 return Err(NullGraphError);
102 }
103 let root = tree.externals(Incoming).next().expect("non-null trees must have a root");
104 let root = ModuleIndex(root);
105 Ok(Self { tree, root })
106 }
107
108 #[inline(always)]
112 pub fn strong_module_count(&self) -> usize {
113 self.tree.node_count()
114 }
115
116 #[inline(always)]
118 pub fn root(&self) -> ModuleIndex {
119 self.root
120 }
121
122 pub fn module_kind(&self, module: ModuleIndex) -> Option<&ModuleKind<NodeId>> {
126 self.tree.node_weight(module.0)
127 }
128
129 pub fn module_kinds(&self) -> impl Iterator<Item = &ModuleKind<NodeId>> {
131 self.tree.node_weights()
132 }
133
134 pub fn children(&self, module: ModuleIndex) -> impl Iterator<Item = ModuleIndex> + '_ {
136 self.tree.neighbors_directed(module.0, Outgoing).map(ModuleIndex)
137 }
138
139 pub fn nodes(&self, module: ModuleIndex) -> impl Iterator<Item = NodeId> + '_ {
144 let mut stack = vec![module.0];
145 from_fn(move || {
146 while let Some(next) = stack.pop() {
147 if let Some(ModuleKind::Node(node)) = self.tree.node_weight(next) {
148 return Some(*node);
149 } else {
150 let children = self.tree.neighbors_directed(next, Outgoing);
151 stack.extend(children);
152 }
153 }
154 None
155 })
156 }
157
158 pub fn into_digraph(self) -> DiGraph<ModuleKind<NodeId>, ()> {
197 self.tree
198 }
199}
200
/// Error returned by `MDTree::from_digraph` when the supplied graph has no
/// nodes.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct NullGraphError;

impl Display for NullGraphError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "graph does not contain any nodes or edges")
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl std::error::Error for NullGraphError {}
212
#[cfg(test)]
mod test {
    use petgraph::graph::{DiGraph, NodeIndex};
    use petgraph::visit::IntoNodeIdentifiers;
    use petgraph::Outgoing;

    use crate::md_tree::NullGraphError;
    use crate::tests::{complete_graph, pace2023_exact_024};
    use crate::{modular_decomposition, MDTree, ModuleIndex, ModuleKind};

    /// Each strong module of the PACE 2023 instance, paired with its sorted
    /// vertex set, must match the known decomposition.
    #[test]
    fn nodes() {
        let graph = pace2023_exact_024();
        let md = modular_decomposition(&graph).unwrap();

        let mut module_nodes: Vec<(ModuleKind<_>, Vec<_>)> = Vec::new();
        for raw in 0..md.strong_module_count() {
            let module = ModuleIndex::new(raw);
            let kind = *md.module_kind(module).unwrap();
            let mut members: Vec<_> = md.nodes(module).map(|v| v.index()).collect();
            members.sort();
            module_nodes.push((kind, members));
        }
        module_nodes.sort();

        for (kind, members) in &module_nodes {
            // A singleton module must be the leaf for its vertex.
            if members.len() == 1 {
                assert_eq!(*kind, ModuleKind::Node(NodeIndex::new(members[0])));
            }
            // A module spanning the whole graph must list every vertex.
            if members.len() == graph.node_count() {
                let all: Vec<_> = graph.node_identifiers().map(|v| v.index()).collect();
                assert_eq!(*members, all);
            }
        }

        // Only the inner modules are compared against the expected table.
        module_nodes.retain(|(_, members)| members.len() > 1);

        let expected = [
            (
                ModuleKind::Prime,
                vec![
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21, 22, 23, 26, 27, 28, 29, 30, 31,
                    32, 33, 34, 35, 36, 37, 38, 39,
                ],
            ),
            (ModuleKind::Series, vec![17, 18, 19]),
            (ModuleKind::Series, vec![24, 25]),
            (
                ModuleKind::Parallel,
                vec![
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
                    27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
                ],
            ),
            (ModuleKind::Parallel, vec![20, 30]),
        ];
        assert_eq!(module_nodes, expected);
    }

    /// The `MDTree` accessors and the raw digraph returned by `into_digraph`
    /// describe the same tree.
    #[test]
    fn mdtree_and_digraph_are_equivalent() {
        let graph = complete_graph(5);
        let md = modular_decomposition(&graph).unwrap();
        let root = md.root();

        assert_eq!(md.module_kind(root), Some(&ModuleKind::Series));
        let first_child = md.children(root).next().unwrap();
        assert_eq!(md.module_kind(first_child), Some(&ModuleKind::Node(NodeIndex::new(0))));

        // Repeat the same checks through the petgraph representation.
        let digraph = md.into_digraph();
        let raw_root = NodeIndex::new(root.index());
        let first_raw_child = digraph.neighbors_directed(raw_root, Outgoing).next().unwrap();
        assert_eq!(digraph.node_weight(raw_root), Some(&ModuleKind::Series));
        assert_eq!(digraph.node_weight(first_raw_child), Some(&ModuleKind::Node(NodeIndex::new(0))));
    }

    /// An empty digraph cannot be turned into an `MDTree`.
    #[test]
    fn null_graph_error() {
        let empty: DiGraph<ModuleKind<NodeIndex>, ()> = Default::default();
        let err = MDTree::from_digraph(empty).unwrap_err();
        assert_eq!(err, NullGraphError);
        assert_eq!(err.to_string(), "graph does not contain any nodes or edges");
    }

    /// `ModuleIndex` debug output shows only the raw index.
    #[test]
    fn module_index_fmt() {
        let rendered = format!("{:?}", ModuleIndex::new(42));
        assert_eq!(rendered, "ModuleIndex(42)");
    }
}