wolf_graph/
graph.rs

1use std::sync::{Arc, RwLock};
2use std::borrow::Cow;
3use std::collections::BTreeMap;
4
5use anyhow::{Result, bail};
6#[cfg(feature = "serde")]
7use serde::{ser::{Serialize, Serializer}, de::{self, Deserialize, Deserializer}};
8
9use crate::{Edges, MutableGraph, Nodes, VisitableGraph};
10use crate::{
11    ids::{EdgeID, NodeID},
12    Error,
13    details::{edge::Edge, node::Node}
14};
15
16/// A general graph data structure with value semantics.
17///
18/// The graph is parameterized by three types:
19/// - `GData` is the type of the graph data.
20/// - `NData` is the type of the node data.
21/// - `EData` is the type of the edge data.
22///
23/// The graph is implemented as a set of nodes and edges. Nodes are identified by
24/// unique `NodeID` values, and edges are identified by unique `EdgeID` values.
25/// The graph and its nodes and edges can have associated data, or they can be
26/// the unit type `()`.
27#[derive(Debug, Clone)]
28pub struct Graph<GData, NData, EData>
29where
30    GData: Clone + 'static,
31    NData: Clone + 'static,
32    EData: Clone + 'static,
33{
34    nodes: Arc<RwLock<BTreeMap<NodeID, Cow<'static, Node<NData>>>>>,
35    edges: Arc<RwLock<BTreeMap<EdgeID, Cow<'static, Edge<EData>>>>>,
36    data: Cow<'static, GData>,
37}
38
39/// A convenience type for a graph with no data.
40pub type BlankGraph = Graph<(), (), ()>;
41
42// Internal Utilities
43impl<GData, NData, EData> Graph<GData, NData, EData>
44where
45    GData: Clone + 'static,
46    NData: Clone + 'static,
47    EData: Clone + 'static,
48{
49    fn _node(&self, id: &NodeID) -> Result<Cow<'static, Node<NData>>, Error> {
50        self.nodes.read().unwrap().get(id).cloned()
51            .ok_or(Error::NodeNotFound(id.clone()))
52    }
53
54    fn _edge(&self, id: &EdgeID) -> Result<Cow<'static, Edge<EData>>, Error> {
55        self.edges.read().unwrap().get(id).cloned()
56            .ok_or(Error::EdgeNotFound(id.clone()))
57    }
58
59    fn _update_nodes(&mut self, f: impl FnOnce(&mut BTreeMap<NodeID, Cow<'static, Node<NData>>>)) {
60        let mut nodes = (*self.nodes.read().unwrap()).clone();
61        f(&mut nodes);
62        self.nodes = Arc::new(RwLock::new(nodes));
63    }
64
65    fn _update_edges(&mut self, f: impl FnOnce(&mut BTreeMap<EdgeID, Cow<'static, Edge<EData>>>)) {
66        let mut edges = (*self.edges.read().unwrap()).clone();
67        f(&mut edges);
68        self.edges = Arc::new(RwLock::new(edges));
69    }
70
71    fn _remove_edge(&mut self, edge: &mut Edge<EData>) {
72        let source = edge.source.clone();
73        let target = edge.target.clone();
74        self._update_edges(|edges| {
75            edges.remove(&edge.id);
76        });
77        self._update_nodes(|nodes| {
78            if source == target {
79                // Self-loop
80                let node = nodes.get_mut(&source).unwrap()
81                    .removing_out_edge(&edge.id)
82                    .removing_in_edge(&edge.id);
83                nodes.insert(source, Cow::Owned(node));
84            } else {
85                let source_node = nodes.get_mut(&source).unwrap().removing_out_edge(&edge.id);
86                let target_node = nodes.get_mut(&target).unwrap().removing_in_edge(&edge.id);
87                nodes.insert(source, Cow::Owned(source_node));
88                nodes.insert(target, Cow::Owned(target_node));
89            }
90        });
91    }
92
93    fn _remove_node(&mut self, id: &NodeID) {
94        self._update_nodes(|nodes| {
95            nodes.remove(id);
96        });
97    }
98
99    fn _clear_edges(&mut self, node: &mut Node<NData>) {
100        for edge_id in self.incident_edges(&node.id).unwrap() {
101            let mut edge = self._edge(&edge_id).unwrap().into_owned();
102            self._remove_edge(&mut edge);
103        }
104    }
105}
106
107// Constructor where `GData` is specified directly or is not `Default`
108impl<GData, NData, EData> Graph<GData, NData, EData>
109where
110    GData: Clone + 'static,
111    NData: Clone + 'static,
112    EData: Clone + 'static,
113{
114    pub fn new_with_data(data: GData) -> Self {
115        Graph {
116            nodes: Arc::new(RwLock::new(BTreeMap::new())),
117            edges: Arc::new(RwLock::new(BTreeMap::new())),
118            data: Cow::Owned(data),
119        }
120    }
121}
122
123// Constructor where `GData` implements `Default`
124impl<GData, NData, EData> Graph<GData, NData, EData>
125where
126    GData: Clone + Default + 'static,
127    NData: Clone + 'static,
128    EData: Clone + 'static,
129{
130    pub fn new() -> Self {
131        Graph {
132            nodes: Arc::new(RwLock::new(BTreeMap::new())),
133            edges: Arc::new(RwLock::new(BTreeMap::new())),
134            data: Cow::Owned(GData::default()),
135        }
136    }
137}
138
139impl<GData, NData, EData> Default for Graph<GData, NData, EData>
140where
141    GData: Clone + Default + 'static,
142    NData: Clone + 'static,
143    EData: Clone + 'static,
144{
145    fn default() -> Self {
146        Graph::new()
147    }
148}
149
150impl<GData, NData, EData> MutableGraph for Graph<GData, NData, EData>
151where
152    GData: Clone + 'static,
153    NData: Clone + 'static,
154    EData: Clone + 'static,
155{
156    fn add_node_with_data(&mut self, id: impl AsRef<NodeID>, data: NData) -> Result<()> {
157        let id = id.as_ref();
158        if self.nodes.read().unwrap().contains_key(id) {
159            bail!(Error::DuplicateNode(id.clone()))
160        } else {
161            let node = Node::new(id.clone(), data);
162            self._update_nodes(|nodes| {
163                nodes.insert(id.clone(), Cow::Owned(node));
164            });
165            Ok(())
166        }
167    }
168
169    fn add_edge_with_data(&mut self, id: impl AsRef<EdgeID>, source: impl AsRef<NodeID>, target: impl AsRef<NodeID>, data: Self::EData) -> Result<()> {
170        let id = id.as_ref();
171        let source = source.as_ref();
172        let target = target.as_ref();
173        if self.edges.read().unwrap().contains_key(id) {
174            bail!(Error::DuplicateEdge(id.clone()))
175        } else if !self.nodes.read().unwrap().contains_key(source) {
176            bail!(Error::NodeNotFound(source.clone()))
177        } else if !self.nodes.read().unwrap().contains_key(target) {
178            bail!(Error::NodeNotFound(target.clone()))
179        } else {
180            self._update_edges(|edges| {
181                edges.insert(id.clone(), Cow::Owned(Edge::new(id.clone(), source.clone(), target.clone(), data)));
182            });
183
184            self._update_nodes(|nodes| {
185                if source == target {
186                    // Self-loop
187                    let node = nodes.get_mut(source).unwrap()
188                        .inserting_out_edge(id.clone())
189                        .inserting_in_edge(id.clone());
190                    nodes.insert(source.clone(), Cow::Owned(node));
191                } else {
192                    let source_node = nodes.get_mut(source).unwrap().inserting_out_edge(id.clone());
193                    let target_node = nodes.get_mut(target).unwrap().inserting_in_edge(id.clone());
194                    nodes.insert(source.clone(), Cow::Owned(source_node));
195                    nodes.insert(target.clone(), Cow::Owned(target_node));
196                }
197            });
198
199            Ok(())
200        }
201    }
202
203    fn remove_edge(&mut self, id: impl AsRef<EdgeID>) -> Result<()> {
204        let edge = self._edge(id.as_ref())?;
205        self._remove_edge(&mut edge.into_owned());
206        Ok(())
207    }
208
209    fn clear_edges(&mut self, id: impl AsRef<NodeID>) -> Result<()> {
210        let mut node = self._node(id.as_ref())?.into_owned();
211        self._clear_edges(&mut node);
212        Ok(())
213    }
214
215    fn remove_node(&mut self, id: impl AsRef<NodeID>) -> Result<()> {
216        let id = id.as_ref();
217        let mut node = self._node(id)?.into_owned();
218        self._clear_edges(&mut node);
219        self._remove_node(id);
220        Ok(())
221    }
222
223    fn move_edge(&mut self, id: impl AsRef<EdgeID>, new_source: impl AsRef<NodeID>, new_target: impl AsRef<NodeID>) -> Result<()> {
224        let id = id.as_ref();
225        let new_source = new_source.as_ref();
226        let new_target = new_target.as_ref();
227        self._node(new_source)?;
228        self._node(new_target)?;
229        let mut edge = self._edge(id.as_ref())?.into_owned();
230        let old_source = edge.source.clone();
231        let old_target = edge.target.clone();
232        if &old_source != new_source || &old_target != new_target {
233            edge.source = new_source.clone();
234            edge.target = new_target.clone();
235            self._update_edges(|edges| {
236                edges.insert(id.clone(), Cow::Owned(edge));
237            });
238            self._update_nodes(|nodes| {
239                if &old_source != new_source && &old_target != new_target {
240                    let old_source_node = nodes.get_mut(&old_source).unwrap().removing_out_edge(id);
241                    let old_target_node = nodes.get_mut(&old_target).unwrap().removing_in_edge(id);
242                    nodes.insert(old_source.clone(), Cow::Owned(old_source_node));
243                    nodes.insert(old_target.clone(), Cow::Owned(old_target_node));
244                    let new_source_node = nodes.get_mut(new_source).unwrap().inserting_out_edge(id.clone());
245                    let new_target_node = nodes.get_mut(new_target).unwrap().inserting_in_edge(id.clone());
246                    nodes.insert(new_source.clone(), Cow::Owned(new_source_node));
247                    nodes.insert(new_target.clone(), Cow::Owned(new_target_node));
248                } else if &old_source != new_source {
249                    let old_source_node = nodes.get_mut(&old_source).unwrap().removing_out_edge(id);
250                    nodes.insert(old_source.clone(), Cow::Owned(old_source_node));
251                    let new_source_node = nodes.get_mut(new_source).unwrap().inserting_out_edge(id.clone());
252                    nodes.insert(new_source.clone(), Cow::Owned(new_source_node));
253                } else if &old_target != new_target {
254                    let old_target_node = nodes.get_mut(&old_target).unwrap().removing_in_edge(id);
255                    nodes.insert(old_target.clone(), Cow::Owned(old_target_node));
256                    let new_target_node = nodes.get_mut(new_target).unwrap().inserting_in_edge(id.clone());
257                    nodes.insert(new_target.clone(), Cow::Owned(new_target_node));
258                }
259            });
260        }
261        Ok(())
262    }
263
264    fn set_data(&mut self, data: GData) {
265        self.data = Cow::Owned(data);
266    }
267
268    fn set_node_data(&mut self, id: impl AsRef<NodeID>, data: NData) -> Result<()> {
269        let id = id.as_ref();
270        self._node(id)?;
271        self._update_nodes(|nodes| {
272            let node = nodes.get_mut(id).unwrap().setting_data(data);
273            nodes.insert(id.clone(), Cow::Owned(node));
274        });
275        Ok(())
276    }
277
278    fn set_edge_data(&mut self, id: impl AsRef<EdgeID>, data: EData) -> Result<()> {
279        let id = id.as_ref();
280        self._edge(id)?;
281        self._update_edges(|edges| {
282            let edge = edges.get_mut(id).unwrap().setting_data(data);
283            edges.insert(id.clone(), Cow::Owned(edge));
284        });
285        Ok(())
286    }
287
288    fn with_data(&mut self, transform: &dyn Fn(&mut Self::GData)) {
289        let mut data = self.data.clone().into_owned();
290        transform(&mut data);
291        self.set_data(data);
292    }
293
294    fn with_node_data(&mut self, id: impl AsRef<NodeID>, transform: &dyn Fn(&mut Self::NData)) -> Result<()> {
295        let id = id.as_ref();
296        let mut data = self.node_data(id)?.into_owned();
297        transform(&mut data);
298        self.set_node_data(id, data)?;
299        Ok(())
300    }
301
302    fn with_edge_data(&mut self, id: impl AsRef<EdgeID>, transform: &dyn Fn(&mut Self::EData)) -> Result<()> {
303        let id = id.as_ref();
304        let mut data = self.edge_data(id)?.into_owned();
305        transform(&mut data);
306        self.set_edge_data(id, data)?;
307        Ok(())
308    }
309}
310
311// Queries
312impl<GData, NData, EData> VisitableGraph for Graph<GData, NData, EData>
313where
314    GData: Clone + 'static,
315    NData: Clone + 'static,
316    EData: Clone + 'static,
317{
318    type GData = GData;
319    type NData = NData;
320    type EData = EData;
321
322    fn is_empty(&self) -> bool {
323        self.nodes.read().unwrap().is_empty()
324    }
325
326    fn node_count(&self) -> usize {
327        self.nodes.read().unwrap().len()
328    }
329
330    fn edge_count(&self) -> usize {
331        self.edges.read().unwrap().len()
332    }
333
334    fn all_nodes(&self) -> Nodes {
335        self.nodes.read().unwrap().keys().cloned().collect()
336    }
337
338    fn all_edges(&self) -> Edges {
339        self.edges.read().unwrap().keys().cloned().collect()
340    }
341
342    fn has_node(&self, id: impl AsRef<NodeID>) -> bool {
343        self.nodes.read().unwrap().contains_key(id.as_ref())
344    }
345
346    fn has_edge(&self, id: impl AsRef<EdgeID>) -> bool {
347        self.edges.read().unwrap().contains_key(id.as_ref())
348    }
349
350    fn has_edge_from_to(&self, source: impl AsRef<NodeID>, target: impl AsRef<NodeID>) -> bool {
351        self.edges.read().unwrap()
352            .values().any(|edge| &edge.source == source.as_ref() && &edge.target == target.as_ref())
353    }
354
355    fn has_edge_between(&self, a: impl AsRef<NodeID>, b: impl AsRef<NodeID>) -> bool {
356        let a = a.as_ref();
357        let b = b.as_ref();
358        self.has_edge_from_to(a, b) || self.has_edge_from_to(b, a)
359    }
360
361    fn source(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
362        Ok(self._edge(id.as_ref())?.source.clone())
363    }
364
365    fn target(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
366        Ok(self._edge(id.as_ref())?.target.clone())
367    }
368
369    fn endpoints(&self, id: impl AsRef<EdgeID>) -> Result<(NodeID, NodeID)> {
370        let edge = self._edge(id.as_ref())?;
371        Ok((edge.source.clone(), edge.target.clone()))
372    }
373
374    fn out_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
375        Ok(self._node(id.as_ref())?.out_edges.read().unwrap().iter().cloned().collect())
376    }
377
378    fn in_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
379        Ok(self._node(id.as_ref())?.in_edges.read().unwrap().iter().cloned().collect())
380    }
381
382    fn incident_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
383        let id = id.as_ref();
384        let mut edges = self.out_edges(id)?;
385        edges.extend(self.in_edges(id)?);
386        Ok(edges)
387    }
388
389    fn out_degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
390        Ok(self._node(id.as_ref())?.out_edges.read().unwrap().len())
391    }
392
393    fn in_degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
394        Ok(self._node(id.as_ref())?.in_edges.read().unwrap().len())
395    }
396
397    fn degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
398        let id = id.as_ref();
399        Ok(self.in_degree(id)? + self.out_degree(id)?)
400    }
401
402    fn successors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
403        Ok(self.out_edges(id)?.iter()
404            .map(|edge| self.target(edge).unwrap()).collect())
405    }
406
407    fn predecessors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
408        Ok(self.in_edges(id)?.iter()
409            .map(|edge| self.source(edge).unwrap()).collect())
410    }
411
412    fn neighbors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
413        let id = id.as_ref();
414        let mut neighbors = self.successors(id)?;
415        neighbors.extend(self.predecessors(id)?);
416        Ok(neighbors.into_iter().collect())
417    }
418
419    fn has_successors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
420        Ok(self.out_degree(id)? > 0)
421    }
422
423    fn has_predecessors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
424        Ok(self.in_degree(id)? > 0)
425    }
426
427    fn has_neighbors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
428        Ok(self.degree(id)? > 0)
429    }
430
431    fn data(&self) -> &GData {
432        &self.data
433    }
434
435    fn node_data(&self, id: impl AsRef<NodeID>) -> Result<Cow<'static, NData>> {
436        Ok(self._node(id.as_ref())?.data.clone())
437    }
438
439    fn edge_data(&self, id: impl AsRef<EdgeID>) -> Result<Cow<'static, EData>> {
440        Ok(self._edge(id.as_ref())?.data.clone())
441    }
442
443    fn all_roots(&self) -> Nodes {
444        self.nodes.read().unwrap().iter()
445            .filter(|(_, node)| node.in_edges.read().unwrap().is_empty())
446            .map(|(id, _)| id).cloned().collect()
447    }
448
449    fn all_leaves(&self) -> Nodes {
450        self.nodes.read().unwrap().iter()
451            .filter(|(_, node)| node.out_edges.read().unwrap().is_empty())
452            .map(|(id, _)| id).cloned().collect()
453    }
454
455    fn non_roots(&self) -> Nodes {
456        self.nodes.read().unwrap().iter()
457            .filter(|(_, node)| !node.in_edges.read().unwrap().is_empty())
458            .map(|(id, _)| id).cloned().collect()
459    }
460
461    fn non_leaves(&self) -> Nodes {
462        self.nodes.read().unwrap().iter()
463            .filter(|(_, node)| !node.out_edges.read().unwrap().is_empty())
464            .map(|(id, _)| id).cloned().collect()
465    }
466
467    fn all_internals(&self) -> Nodes {
468        self.nodes.read().unwrap().iter()
469            .filter(|(_, node)| {
470                !node.in_edges.read().unwrap().is_empty() && !node.out_edges.read().unwrap().is_empty()
471            })
472            .map(|(id, _)| id).cloned().collect()
473    }
474
475    fn is_leaf(&self, id: impl AsRef<NodeID>) -> Result<bool> {
476        Ok(self._node(id.as_ref())?.out_edges.read().unwrap().is_empty())
477    }
478
479    fn is_root(&self, id: impl AsRef<NodeID>) -> Result<bool> {
480        Ok(self._node(id.as_ref())?.in_edges.read().unwrap().is_empty())
481    }
482
483    fn is_internal(&self, id: impl AsRef<NodeID>) -> Result<bool> {
484        let node = self._node(id.as_ref())?;
485        Ok(!node.in_edges.read().unwrap().is_empty() && !node.out_edges.read().unwrap().is_empty())
486    }
487}
488
489impl<GData, NData, EData> PartialEq for Graph<GData, NData, EData>
490where
491    GData: Clone + PartialEq + 'static,
492    NData: Clone + PartialEq + 'static,
493    EData: Clone + PartialEq + 'static,
494{
495    fn eq(&self, other: &Self) -> bool {
496        self.nodes.read().unwrap().iter().eq(other.nodes.read().unwrap().iter())
497            && self.edges.read().unwrap().iter().eq(other.edges.read().unwrap().iter())
498            && self.data == other.data
499    }
500}
501
#[cfg(feature = "serde")]
impl<GData, NData, EData> Serialize for Graph<GData, NData, EData>
where
    GData: Clone + Serialize + 'static,
    NData: Clone + Serialize + 'static,
    EData: Clone + Serialize + 'static,
{
    /// Writes the graph as the tuple `(nodes, edges, data)`, or as
    /// `(nodes, edges)` when `GData` is zero-sized.
    ///
    /// Nodes and edges are serialized by reference. `&T` serializes exactly
    /// like `T`, so the output is unchanged, but no per-element clone is made.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let node_map = self.nodes.read().unwrap();
        let edge_map = self.edges.read().unwrap();

        let nodes: Vec<&Node<NData>> = node_map.values().map(|node| &**node).collect();
        let edges: Vec<&Edge<EData>> = edge_map.values().map(|edge| &**edge).collect();

        if std::mem::size_of::<GData>() > 0 {
            (&nodes, &edges, &self.data).serialize(serializer)
        } else {
            (&nodes, &edges).serialize(serializer)
        }
    }
}
530
#[cfg(feature = "serde")]
impl<'de, GData, NData, EData> Deserialize<'de> for Graph<GData, NData, EData>
where
    GData: Clone + Deserialize<'de> + 'static + Default,
    NData: Clone + Deserialize<'de> + 'static + Default,
    EData: Clone + Deserialize<'de> + 'static + Default,
{
    /// Rebuilds a graph from the tuple layout written by `Serialize`:
    /// `(nodes, edges, data)`, or `(nodes, edges)` when `GData` is zero-sized.
    ///
    /// Nodes are inserted first, then edges, through `add_node_with_data` and
    /// `add_edge_with_data`. Only each node's ID and data are taken from the
    /// input; its edge lists are rebuilt as the edges are added. Any rejection
    /// from those calls (duplicate IDs, edges naming unknown nodes) becomes a
    /// deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let (mut graph, nodes, edges): (Self, Vec<Node<NData>>, Vec<Edge<EData>>) =
            if std::mem::size_of::<GData>() == 0 {
                let (nodes, edges) =
                    <(Vec<Node<NData>>, Vec<Edge<EData>>)>::deserialize(deserializer)?;
                (Graph::new(), nodes, edges)
            } else {
                let (nodes, edges, graph_data) =
                    <(Vec<Node<NData>>, Vec<Edge<EData>>, GData)>::deserialize(deserializer)?;
                (Graph::new_with_data(graph_data), nodes, edges)
            };

        for node in nodes {
            let payload = node.data.into_owned();
            graph
                .add_node_with_data(&node.id, payload)
                .map_err(de::Error::custom)?;
        }

        for edge in edges {
            let payload = edge.data.into_owned();
            graph
                .add_edge_with_data(&edge.id, &edge.source, &edge.target, payload)
                .map_err(de::Error::custom)?;
        }

        Ok(graph)
    }
}
571
// If Serde and SerdeJSON are both present, add convenience methods to serialize
// a Graph to JSON.
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<GData, NData, EData> Graph<GData, NData, EData>
where
    GData: Clone + Serialize + 'static,
    NData: Clone + Serialize + 'static,
    EData: Clone + Serialize + 'static,
{
    /// Serializes the graph to a compact JSON string, returning any failure
    /// to the caller.
    ///
    /// # Errors
    /// Returns the `serde_json` error if the graph, node, or edge data cannot
    /// be serialized to JSON.
    pub fn try_to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string(self)
    }

    /// Serializes the graph to a compact JSON string.
    ///
    /// # Panics
    /// Panics if serialization fails. Use `try_to_json` to handle the error
    /// instead.
    pub fn to_json(&self) -> String {
        self.try_to_json()
            .expect("graph could not be serialized to JSON")
    }
}
585
// With both Serde and SerdeJSON enabled, a Graph can be parsed straight from a
// JSON string.
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<'de, GData, NData, EData> Graph<GData, NData, EData>
where
    GData: Clone + Default + Deserialize<'de> + 'static,
    NData: Clone + Default + Deserialize<'de> + 'static,
    EData: Clone + Default + Deserialize<'de> + 'static,
{
    /// Parses a graph from JSON in the layout produced by `to_json`.
    ///
    /// # Errors
    /// Returns the `serde_json` error for malformed JSON, a layout mismatch,
    /// or a graph that is rejected while being rebuilt (for example, duplicate
    /// IDs or an edge naming a node that does not exist).
    pub fn from_json(json: &'de str) -> Result<Self, serde_json::Error> {
        let graph: Self = serde_json::from_str(json)?;
        Ok(graph)
    }
}
599
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<GData, NData, EData> std::fmt::Display for Graph<GData, NData, EData>
where
    GData: Clone + Serialize + 'static,
    NData: Clone + Serialize + 'static,
    EData: Clone + Serialize + 'static,
{
    /// Renders the graph as its compact JSON form (the output of `to_json`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let json = self.to_json();
        f.write_str(&json)
    }
}
610}