1use crate::*;
34#[cfg(feature = "hashbrown")]
35use hashbrown::HashMap;
36#[cfg(feature = "hashbrown")]
37use hashbrown::HashSet;
38
39#[cfg(not(feature = "hashbrown"))]
40use std::collections::HashMap;
41#[cfg(not(feature = "hashbrown"))]
42use std::collections::HashSet;
43
44use thiserror::Error;
45
46#[derive(Debug)]
48#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
49pub struct CategorizedGraph<N, E> {
50 pub nodes: SlotMap<NodeID, Node<N>>,
51 pub edges: SlotMap<EdgeID, Edge<E>>,
52 pub categories: HashMap<String, NodeID>,
53}
54
55impl<N, E> GraphInterface for CategorizedGraph<N, E> {
56 type NodeData = N;
57 type EdgeData = E;
58
59 fn nodes(&self) -> impl Iterator<Item = NodeID> {
60 self.nodes.keys()
61 }
62
63 fn node_count(&self) -> usize {
64 self.nodes.len()
65 }
66
67 fn remove_node(&mut self, id: NodeID) -> Result<(), GraphError> {
68 let node = self
69 .nodes
70 .remove(id)
71 .map_or(Err(GraphError::NodeNotFound), |n| Ok(n))?;
72
73 for edge_id in node.connections.iter() {
74 self.remove_edge(*edge_id).or_else(|e| Ok(()))?;
75 }
76
77 Ok(())
78 }
79
80 fn remove_edge(&mut self, id: EdgeID) -> Result<(), GraphError> {
81 let edge = self.edge(id)?;
82 let from = edge.from;
83 let to = edge.to;
84
85 if let Ok(node) = self.node_mut(from) {
86 node.connections.retain(|&x| x != id)
87 }
88
89 if let Ok(node) = self.node_mut(to) {
90 node.connections.retain(|&x| x != id)
91 }
92
93 self.edges
94 .remove(id)
95 .map_or(Err(GraphError::EdgeNotFound), |_| Ok(()))?;
96
97 Ok(())
98 }
99
100 fn add_node(&mut self, data: N) -> NodeID {
101 let id = self.nodes.insert_with_key(|id| Node::new(id, data));
102 id
103 }
104
105 fn add_nodes(&mut self, data: &[N]) -> Vec<NodeID>
106 where
107 N: Clone,
108 {
109 let mut nodes = Vec::new();
110 for data in data {
111 let node = self.add_node(data.clone());
112 nodes.push(node);
113 }
114 nodes
115 }
116
117 fn add_edges(&mut self, data: &[(NodeID, NodeID)]) -> Vec<EdgeID>
118 where
119 E: Default + Clone,
120 N: Clone,
121 {
122 let with_data: Vec<(NodeID, NodeID, E)> = data
123 .iter()
124 .map(|(from, to)| (*from, *to, E::default()))
125 .collect();
126
127 self.add_edges_with_data(&with_data)
128 }
129
130 fn add_edge(&mut self, from: NodeID, to: NodeID, data: E) -> EdgeID {
131 let id = self
132 .edges
133 .insert_with_key(|id| Edge::new(id, from, to, data));
134 if let Some(node) = self.nodes.get_mut(from) {
135 node.add_connection(id);
136 }
137 if let Some(node) = self.nodes.get_mut(to) {
138 node.add_connection(id);
139 }
140 id
141 }
142
143 fn node(&self, id: NodeID) -> Result<&Node<N>, GraphError> {
144 self.nodes.get(id).ok_or(GraphError::NodeNotFound)
145 }
146
147 fn node_mut(&mut self, id: NodeID) -> Result<&mut Node<N>, GraphError> {
148 self.nodes.get_mut(id).ok_or(GraphError::NodeNotFound)
149 }
150
151 fn edge(&self, id: EdgeID) -> Result<&Edge<E>, GraphError> {
152 self.edges.get(id).ok_or(GraphError::EdgeNotFound)
153 }
154
155 fn edge_mut(&mut self, id: EdgeID) -> Result<&mut Edge<E>, GraphError> {
156 self.edges.get_mut(id).ok_or(GraphError::EdgeNotFound)
157 }
158}
159
160impl<N, E> CategorizedGraph<N, E> {
161 pub fn new() -> Self {
162 CategorizedGraph {
163 edges: SlotMap::with_key(),
164 nodes: SlotMap::with_key(),
165 categories: HashMap::new(),
166 }
167 }
168}
169
170#[derive(Clone, Debug, thiserror::Error)]
171#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
172pub enum CategorizedGraphError {
173 #[error("Category identified by `{0}` already exists")]
174 CategoryAlreadyExists(String),
175 #[error("Category identified by `{0}` does not exists")]
176 CategoryNotFound(String),
177}
178
179pub trait Categorized<N, E, C>: GraphInterface<NodeData = N, EdgeData = E> {
181 fn category_id_by_name(&self, category_name: &str) -> Option<&NodeID>;
183
184 fn category_exists(&self, category_name: &str) -> bool {
186 self.category_id_by_name(category_name).is_some()
187 }
188
189 fn add_to_category_by_id(
191 &mut self,
192 category_id: NodeID,
193 nodes: Vec<NodeID>,
194 ) -> Result<(), CategorizedGraphError>
195 where
196 E: Default + Clone,
197 N: Clone,
198 {
199 let category_node = self.node(category_id).map_or(
200 Err(CategorizedGraphError::CategoryNotFound(format!(
201 "NodeID({:?})",
202 category_id
203 ))),
204 |node| Ok(node),
205 )?;
206
207 let edges: Vec<(NodeID, NodeID)> = nodes
208 .iter()
209 .map(|node: &NodeID| (category_node.id, *node))
210 .collect();
211
212 self.add_edges(&edges);
213
214 Ok(())
215 }
216
217 fn insert_category_id_by_name(&mut self, category_name: &str, category_id: NodeID) {
219 }
222
223 fn add_to_category(&mut self, category_name: &str, nodes: Vec<NodeID>) -> NodeID
225 where
226 E: Default + Clone,
227 N: Clone + Default,
228 {
229 let existing: Option<&NodeID> = self.category_id_by_name(category_name);
230 let category_node: NodeID;
231
232 if existing.is_some() {
233 category_node = *existing.unwrap();
234 self.add_to_category_by_id(category_node, nodes).unwrap();
235 } else {
236 category_node = self.add_node(N::default());
237 self.add_to_category_by_id(category_node, nodes).unwrap();
238 self.insert_category_id_by_name(category_name, category_node)
239 }
240
241 category_node
242 }
243
244 fn create_category(
250 &mut self,
251 category: &str,
252 nodes: Vec<NodeID>,
253 data: C,
254 ) -> Result<NodeID, String>
255 where
256 E: Default + Clone,
257 N: Clone + Default;
258
259 fn all_categories(&self) -> Vec<(&String, NodeID)>;
261
262 fn category(&self, category: &str) -> Option<&Node<N>>;
264
265 fn category_exists_by_id(&self, category: NodeID) -> bool {
267 self.category_by_id(category).is_ok()
268 }
269
270 fn category_by_id(&self, category: NodeID) -> Result<&Node<N>, GraphError>;
272
273 fn nodes_by_category_id(&self, category: NodeID) -> Vec<NodeID>;
275
276 fn nodes_by_category(&self, category: &str) -> Vec<NodeID>;
278
279 fn nodes_by_categories(&self, categories: Vec<&str>) -> Vec<NodeID> {
281 categories
282 .iter()
283 .map(|category| self.nodes_by_category(category))
284 .flatten()
285 .collect()
286 }
287
288 fn nodes_by_category_ids(&self, categories: Vec<NodeID>) -> Vec<NodeID> {
290 categories
291 .iter()
292 .map(|category| self.nodes_by_category_id(*category))
293 .flatten()
294 .collect()
295 }
296}
297
298impl<N, E> Categorized<N, E, N> for CategorizedGraph<N, E>
299where
300 Self: GraphInterface<NodeData = N, EdgeData = E>,
301{
302 fn category_id_by_name(&self, category_name: &str) -> Option<&NodeID> {
303 self.categories.get(category_name)
304 }
305
306 fn insert_category_id_by_name(&mut self, category_name: &str, category_id: NodeID) {
307 self.categories
308 .insert(category_name.to_string(), category_id);
309 }
310
311 fn create_category(
312 &mut self,
313 category: &str,
314 nodes: Vec<NodeID>,
315 data: N,
316 ) -> Result<NodeID, String>
317 where
318 E: Default + Clone,
319 N: Clone + Default,
320 {
321 let existing_category: Option<&NodeID> = self.categories.get(category);
322 if existing_category.is_some() {
323 return Err(format!("Category {} already exists", category));
324 }
325 let category_node = self.add_node(data);
326 self.add_to_category(category, nodes);
327 Ok(category_node)
328 }
329
330 fn all_categories(&self) -> Vec<(&String, NodeID)> {
331 self.categories
332 .iter()
333 .map(|(cat, node)| (cat, *node))
334 .collect()
335 }
336
337 fn category(&self, category: &str) -> Option<&Node<N>> {
338 self.categories
339 .get(category)
340 .map(|id| self.node(*id).unwrap())
341 }
342
343 fn category_by_id(&self, category: NodeID) -> Result<&Node<N>, GraphError> {
344 self.node(category)
345 }
346
347 fn nodes_by_category_id(&self, category: NodeID) -> Vec<NodeID> {
348 self.node(category)
349 .and_then(|category_node| {
350 category_node
351 .connections
352 .iter()
353 .filter_map(|edge_id| self.edge(*edge_id).map_or(None, |edge| Some(edge)))
354 .map(|edge| Ok(edge.to))
355 .collect()
356 })
357 .unwrap_or(Vec::new())
358 }
359
360 fn nodes_by_category(&self, category: &str) -> Vec<NodeID> {
361 self.categories
362 .get(category)
363 .map(|id| self.nodes_by_category_id(*id))
364 .unwrap_or(Vec::new())
365 }
366}