fast_graph/
categories.rs

1//! # A graph with category nodes.
2//!
//! The [CategorizedGraph] struct uses a hash map to map category names ([String]) to category nodes ([NodeID]); the edges of a category node lead to the nodes that belong to that category.
//! There are also some useful extra functions to query categories and their nodes, and a [Categorized] trait that can be implemented for a custom struct if needed.
//!
//! In other words, this is a simple extension to the graph that allows efficient and easy grouping of nodes by name.
7//!
8//! # Example
9//! ```
10//! use fast_graph::*;
11//!
12//! #[derive(Clone, Debug, Default, PartialEq)]
13//! #[cfg_attr(feature = "serde", derive(serde::Serialize))]
14//! enum NodeData {
15//!     String(String),
16//!     CategoryName(String),
17//!     #[default]
18//!     None,
19//! }
20//!
21//! let mut graph: CategorizedGraph<NodeData, ()> = CategorizedGraph::new();
22//!
23//! let node1 = graph.add_node(NodeData::String("Node 1".into()));
24//! let node2 = graph.add_node(NodeData::String("Node 2".into()));
25//! let node3 = graph.add_node(NodeData::String("Node 3".into()));
26//!
27//! let category1 = graph.create_category("Category 1", vec![node1, node2], NodeData::CategoryName("Category 1".into())).unwrap();
28//! let category2 = graph.add_to_category("Category 2", vec![node3]);
29//!
30//! assert_eq!(graph.all_categories().len(), 2);
31//! ```
32
33use crate::*;
34#[cfg(feature = "hashbrown")]
35use hashbrown::HashMap;
36#[cfg(feature = "hashbrown")]
37use hashbrown::HashSet;
38
39#[cfg(not(feature = "hashbrown"))]
40use std::collections::HashMap;
41#[cfg(not(feature = "hashbrown"))]
42use std::collections::HashSet;
43
44use thiserror::Error;
45
/// A graph with category nodes and a hash map that maps category names to those nodes.
///
/// A category is stored as an ordinary node in `nodes`. Its members are the `to`
/// endpoints of the edges that leave it, and `categories` maps each category name
/// to the [NodeID] of its category node.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CategorizedGraph<N, E> {
    /// Every node in the graph, including category nodes.
    pub nodes: SlotMap<NodeID, Node<N>>,
    /// Every edge in the graph, including the category -> member edges.
    pub edges: SlotMap<EdgeID, Edge<E>>,
    /// Category name -> [NodeID] of the corresponding category node.
    pub categories: HashMap<String, NodeID>,
}
54
55impl<N, E> GraphInterface for CategorizedGraph<N, E> {
56    type NodeData = N;
57    type EdgeData = E;
58
59    fn nodes(&self) -> impl Iterator<Item = NodeID> {
60        self.nodes.keys()
61    }
62
63    fn node_count(&self) -> usize {
64        self.nodes.len()
65    }
66
67    fn remove_node(&mut self, id: NodeID) -> Result<(), GraphError> {
68        let node = self
69            .nodes
70            .remove(id)
71            .map_or(Err(GraphError::NodeNotFound), |n| Ok(n))?;
72
73        for edge_id in node.connections.iter() {
74            self.remove_edge(*edge_id).or_else(|e| Ok(()))?;
75        }
76
77        Ok(())
78    }
79
80    fn remove_edge(&mut self, id: EdgeID) -> Result<(), GraphError> {
81        let edge = self.edge(id)?;
82        let from = edge.from;
83        let to = edge.to;
84
85        if let Ok(node) = self.node_mut(from) {
86            node.connections.retain(|&x| x != id)
87        }
88
89        if let Ok(node) = self.node_mut(to) {
90            node.connections.retain(|&x| x != id)
91        }
92
93        self.edges
94            .remove(id)
95            .map_or(Err(GraphError::EdgeNotFound), |_| Ok(()))?;
96
97        Ok(())
98    }
99
100    fn add_node(&mut self, data: N) -> NodeID {
101        let id = self.nodes.insert_with_key(|id| Node::new(id, data));
102        id
103    }
104
105    fn add_nodes(&mut self, data: &[N]) -> Vec<NodeID>
106    where
107        N: Clone,
108    {
109        let mut nodes = Vec::new();
110        for data in data {
111            let node = self.add_node(data.clone());
112            nodes.push(node);
113        }
114        nodes
115    }
116
117    fn add_edges(&mut self, data: &[(NodeID, NodeID)]) -> Vec<EdgeID>
118    where
119        E: Default + Clone,
120        N: Clone,
121    {
122        let with_data: Vec<(NodeID, NodeID, E)> = data
123            .iter()
124            .map(|(from, to)| (*from, *to, E::default()))
125            .collect();
126
127        self.add_edges_with_data(&with_data)
128    }
129
130    fn add_edge(&mut self, from: NodeID, to: NodeID, data: E) -> EdgeID {
131        let id = self
132            .edges
133            .insert_with_key(|id| Edge::new(id, from, to, data));
134        if let Some(node) = self.nodes.get_mut(from) {
135            node.add_connection(id);
136        }
137        if let Some(node) = self.nodes.get_mut(to) {
138            node.add_connection(id);
139        }
140        id
141    }
142
143    fn node(&self, id: NodeID) -> Result<&Node<N>, GraphError> {
144        self.nodes.get(id).ok_or(GraphError::NodeNotFound)
145    }
146
147    fn node_mut(&mut self, id: NodeID) -> Result<&mut Node<N>, GraphError> {
148        self.nodes.get_mut(id).ok_or(GraphError::NodeNotFound)
149    }
150
151    fn edge(&self, id: EdgeID) -> Result<&Edge<E>, GraphError> {
152        self.edges.get(id).ok_or(GraphError::EdgeNotFound)
153    }
154
155    fn edge_mut(&mut self, id: EdgeID) -> Result<&mut Edge<E>, GraphError> {
156        self.edges.get_mut(id).ok_or(GraphError::EdgeNotFound)
157    }
158}
159
160impl<N, E> CategorizedGraph<N, E> {
161    pub fn new() -> Self {
162        CategorizedGraph {
163            edges: SlotMap::with_key(),
164            nodes: SlotMap::with_key(),
165            categories: HashMap::new(),
166        }
167    }
168}
169
170#[derive(Clone, Debug, thiserror::Error)]
171#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
172pub enum CategorizedGraphError {
173    #[error("Category identified by `{0}` already exists")]
174    CategoryAlreadyExists(String),
175    #[error("Category identified by `{0}` does not exists")]
176    CategoryNotFound(String),
177}
178
179/// Methods for a graph with categories.
180pub trait Categorized<N, E, C>: GraphInterface<NodeData = N, EdgeData = E> {
181    /// Returns the category ID by name. In the standard implementation this is a hashmap lookup.
182    fn category_id_by_name(&self, category_name: &str) -> Option<&NodeID>;
183
184    /// Checks if the category exists by name.
185    fn category_exists(&self, category_name: &str) -> bool {
186        self.category_id_by_name(category_name).is_some()
187    }
188
189    /// Adds a list of nodes to a category by ID. Returns `Ok(())` if successful, otherwise returns Error([CategorizedGraphError::CategoryNotFound]).
190    fn add_to_category_by_id(
191        &mut self,
192        category_id: NodeID,
193        nodes: Vec<NodeID>,
194    ) -> Result<(), CategorizedGraphError>
195    where
196        E: Default + Clone,
197        N: Clone,
198    {
199        let category_node = self.node(category_id).map_or(
200            Err(CategorizedGraphError::CategoryNotFound(format!(
201                "NodeID({:?})",
202                category_id
203            ))),
204            |node| Ok(node),
205        )?;
206
207        let edges: Vec<(NodeID, NodeID)> = nodes
208            .iter()
209            .map(|node: &NodeID| (category_node.id, *node))
210            .collect();
211
212        self.add_edges(&edges);
213
214        Ok(())
215    }
216
217    /// In the default implementation this is used to insert the category ID into the hashmap.
218    fn insert_category_id_by_name(&mut self, category_name: &str, category_id: NodeID) {
219        // Default implementation (optional logic)
220        // You can leave this empty or provide some default behavior
221    }
222
223    /// If the category does not exist, it is created. Returns the [NodeID] of the category.
224    fn add_to_category(&mut self, category_name: &str, nodes: Vec<NodeID>) -> NodeID
225    where
226        E: Default + Clone,
227        N: Clone + Default,
228    {
229        let existing: Option<&NodeID> = self.category_id_by_name(category_name);
230        let category_node: NodeID;
231
232        if existing.is_some() {
233            category_node = *existing.unwrap();
234            self.add_to_category_by_id(category_node, nodes).unwrap();
235        } else {
236            category_node = self.add_node(N::default());
237            self.add_to_category_by_id(category_node, nodes).unwrap();
238            self.insert_category_id_by_name(category_name, category_node)
239        }
240
241        category_node
242    }
243
244    /// Creates a new category [Node] with the given name, nodes, and (optionally) data.
245    ///
246    /// Returns the [NodeID] of the category if successful, otherwise returns Error(CategorizedGraphError::CategoryAlreadyExists).
247    ///
248    /// An empty vector of nodes can be passed.
249    fn create_category(
250        &mut self,
251        category: &str,
252        nodes: Vec<NodeID>,
253        data: C,
254    ) -> Result<NodeID, String>
255    where
256        E: Default + Clone,
257        N: Clone + Default;
258
259    /// Returns a list of all categories.
260    fn all_categories(&self) -> Vec<(&String, NodeID)>;
261
262    /// Returns the category node by name.
263    fn category(&self, category: &str) -> Option<&Node<N>>;
264
265    /// Checks if the category exists by ID.
266    fn category_exists_by_id(&self, category: NodeID) -> bool {
267        self.category_by_id(category).is_ok()
268    }
269
270    /// Returns the category node by ID.
271    fn category_by_id(&self, category: NodeID) -> Result<&Node<N>, GraphError>;
272
273    /// Returns a list of nodes in the category by ID.
274    fn nodes_by_category_id(&self, category: NodeID) -> Vec<NodeID>;
275
276    /// Returns a list of nodes in the category by name.
277    fn nodes_by_category(&self, category: &str) -> Vec<NodeID>;
278
279    /// Returns a list of nodes in the categories by name.
280    fn nodes_by_categories(&self, categories: Vec<&str>) -> Vec<NodeID> {
281        categories
282            .iter()
283            .map(|category| self.nodes_by_category(category))
284            .flatten()
285            .collect()
286    }
287
288    /// Returns a list of nodes in the categories by ID.
289    fn nodes_by_category_ids(&self, categories: Vec<NodeID>) -> Vec<NodeID> {
290        categories
291            .iter()
292            .map(|category| self.nodes_by_category_id(*category))
293            .flatten()
294            .collect()
295    }
296}
297
298impl<N, E> Categorized<N, E, N> for CategorizedGraph<N, E>
299where
300    Self: GraphInterface<NodeData = N, EdgeData = E>,
301{
302    fn category_id_by_name(&self, category_name: &str) -> Option<&NodeID> {
303        self.categories.get(category_name)
304    }
305
306    fn insert_category_id_by_name(&mut self, category_name: &str, category_id: NodeID) {
307        self.categories
308            .insert(category_name.to_string(), category_id);
309    }
310
311    fn create_category(
312        &mut self,
313        category: &str,
314        nodes: Vec<NodeID>,
315        data: N,
316    ) -> Result<NodeID, String>
317    where
318        E: Default + Clone,
319        N: Clone + Default,
320    {
321        let existing_category: Option<&NodeID> = self.categories.get(category);
322        if existing_category.is_some() {
323            return Err(format!("Category {} already exists", category));
324        }
325        let category_node = self.add_node(data);
326        self.add_to_category(category, nodes);
327        Ok(category_node)
328    }
329
330    fn all_categories(&self) -> Vec<(&String, NodeID)> {
331        self.categories
332            .iter()
333            .map(|(cat, node)| (cat, *node))
334            .collect()
335    }
336
337    fn category(&self, category: &str) -> Option<&Node<N>> {
338        self.categories
339            .get(category)
340            .map(|id| self.node(*id).unwrap())
341    }
342
343    fn category_by_id(&self, category: NodeID) -> Result<&Node<N>, GraphError> {
344        self.node(category)
345    }
346
347    fn nodes_by_category_id(&self, category: NodeID) -> Vec<NodeID> {
348        self.node(category)
349            .and_then(|category_node| {
350                category_node
351                    .connections
352                    .iter()
353                    .filter_map(|edge_id| self.edge(*edge_id).map_or(None, |edge| Some(edge)))
354                    .map(|edge| Ok(edge.to))
355                    .collect()
356            })
357            .unwrap_or(Vec::new())
358    }
359
360    fn nodes_by_category(&self, category: &str) -> Vec<NodeID> {
361        self.categories
362            .get(category)
363            .map(|id| self.nodes_by_category_id(*id))
364            .unwrap_or(Vec::new())
365    }
366}