drisk_api/
diff.rs

1use hashbrown::{HashMap, HashSet};
2use serde::{Deserialize, Serialize};
3use std::{fmt::Debug, hash::Hash, ops::AddAssign};
4
5/// A differential between two graphs.
6///
7/// Contains a diff for the nodes and edges of a graph. Each diff contains new or updated
8/// items and items that are marked for deletion. The diff will always be internally consistent
9/// if the safe, public methods are used.
10///
11/// `GraphDiff` requires two generic types:
12/// * `Id` is the type used to index nodes in the graph. It requires standard trait bounds for
13/// index types.
14/// * `T` is the type used to represent node property updates. It requires `Default` used
15/// when adding a new node to the diff and `AddAssign` to combine updates.
16///
17/// `GraphDiff`s support composition with `AddAssign`:
18/// ```
19/// use drisk_api::GraphDiff;
20///
21/// let mut diff1: GraphDiff<u32, u32> = GraphDiff::default();
22/// let diff2: GraphDiff<u32, u32> = GraphDiff::default();
23///
24/// // `diff1` will contain all nodes and edges from `diff2`.
25/// // If a node or edge is updated in both, the updates will be combined.
26/// // Updates to the same properties from `diff2` will overwrite updates from `diff1`.
27/// // If a node is deleted in `diff2`, it will be deleted in the combined diff.
28/// diff1 += diff2;
29/// ```
30#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
31pub struct GraphDiff<Id: Hash + Eq + Copy, T: Default + AddAssign, W = f32> {
32    pub(crate) nodes: NodeDiff<Id, T>,
33    pub(crate) edges: EdgeDiff<Id, W>,
34}
35
36impl<Id: Hash + Eq + Copy, T: Default + AddAssign> Default for GraphDiff<Id, T> {
37    fn default() -> GraphDiff<Id, T> {
38        GraphDiff {
39            nodes: NodeDiff {
40                new_or_updated: HashMap::new(),
41                deleted: HashSet::new(),
42            },
43            edges: EdgeDiff {
44                new_or_updated: HashMap::new(),
45                deleted: HashMap::new(),
46            },
47        }
48    }
49}
50
51impl<Id: Hash + Eq + Copy, T: Default + AddAssign, W: Copy + PartialEq> GraphDiff<Id, T, W> {
52    pub fn new() -> GraphDiff<Id, T> {
53        GraphDiff::default()
54    }
55
56    /// Initialse diff from a NodeDiff and an EdgeDiff
57    pub fn from_diffs(nodes: NodeDiff<Id, T>, edges: EdgeDiff<Id, W>) -> GraphDiff<Id, T, W> {
58        GraphDiff { nodes, edges }
59    }
60
61    /// Get a reference to the node diff.
62    pub fn nodes(&self) -> &NodeDiff<Id, T> {
63        &self.nodes
64    }
65
66    /// Get a reference to the new or updated nodes.
67    pub fn new_or_updated_nodes(&self) -> &HashMap<Id, T> {
68        &self.nodes.new_or_updated
69    }
70
71    /// Get a reference to the deleted nodes.
72    pub fn deleted_nodes(&self) -> &HashSet<Id> {
73        &self.nodes.deleted
74    }
75
76    /// Get a reference to the edge diff.
77    pub fn edges(&self) -> &EdgeDiff<Id, W> {
78        &self.edges
79    }
80
81    /// Get a reference to the new or updated edges.
82    pub fn new_or_updated_edges(&self) -> &HashMap<Id, HashMap<Id, W>> {
83        &self.edges.new_or_updated
84    }
85
86    /// Get a reference to the deleted edges.
87    pub fn deleted_edges(&self) -> &HashMap<Id, HashSet<Id>> {
88        &self.edges.deleted
89    }
90
91    /// Returns `true` if the diff contains no nodes or edges (new, updated or deleted).
92    pub fn is_empty(&self) -> bool {
93        self.nodes.new_or_updated.is_empty()
94            && self.nodes.deleted.is_empty()
95            && self.edges.new_or_updated.is_empty()
96            && self.edges.deleted.is_empty()
97    }
98
99    /// Add a new node to the diff. If previously marked as deleted, it will be overwritten.
100    pub fn add_node(&mut self, node_id: &Id) {
101        let _ = self.nodes.new_or_updated.try_insert(*node_id, T::default());
102        self.nodes.deleted.remove(node_id);
103    }
104
105    /// Add or update a node in the diff with an update.
106    /// If previously marked as deleted, it will be overwritten
107    pub fn add_or_update_node(&mut self, node_id: &Id, update: T) {
108        if let Some(node) = self.nodes.new_or_updated.get_mut(node_id) {
109            *node += update;
110        } else {
111            self.nodes.new_or_updated.insert(*node_id, update);
112        }
113        self.nodes.deleted.remove(node_id);
114    }
115
116    /// Get a mutable reference to a node update in the diff. If the node is not
117    /// present, it will be added with an empty update.
118    pub fn get_or_create_mut_node_update(&mut self, node_id: &Id) -> &mut T {
119        if self.nodes.new_or_updated.get(node_id).is_none() {
120            self.add_node(node_id);
121        };
122        self.nodes.new_or_updated.get_mut(node_id).unwrap()
123    }
124
125    /// Use with caution: overwrites the node update to whatever you provide.
126    pub fn set_node_update(&mut self, node_id: &Id, update: T) {
127        self.nodes.new_or_updated.insert(*node_id, update);
128        self.nodes.deleted.remove(node_id);
129    }
130
131    /// Add a new node to be deleted to the diff.
132    /// If present the node will be removed from `new_or_updated`.
133    /// It further updates the edge diff to make sure an edge
134    /// deletion is recorded for all edges connecting to the node.
135    pub fn delete_node(&mut self, node_id: Id) {
136        self.nodes.new_or_updated.remove(&node_id);
137
138        // remove all edges where node_id is predecessor
139        self.edges.new_or_updated.remove(&node_id);
140
141        for (from, to_weight) in self.edges.new_or_updated.iter_mut() {
142            if to_weight.contains_key(&node_id) {
143                self.edges.deleted.entry(*from).or_default().insert(node_id);
144            }
145            // remove all edges where node_id is successor
146            to_weight.remove(&node_id);
147        }
148        self.nodes.deleted.insert(node_id);
149    }
150
151    /// Add a new edge to the diff.
152    /// If previously marked as deleted, it will be overwritten
153    /// If either the from or to nodes are marked as deleted, it will error.
154    pub fn add_edge(
155        &mut self,
156        from: &Id,
157        to: &Id,
158        weight: W,
159    ) -> Result<(), Box<dyn std::error::Error>> {
160        if self.nodes.deleted.contains(from) || self.nodes.deleted.contains(to) {
161            return Err("Either from or to nodes are marked to be deleted".into());
162        }
163        if let Some(inner) = self.edges.deleted.get_mut(from) {
164            inner.remove(to);
165        }
166        if self.edges.deleted.get(from).is_some_and(|e| e.is_empty()) {
167            self.edges.deleted.remove(from);
168        }
169        self.edges
170            .new_or_updated
171            .entry(*from)
172            .or_default()
173            .insert(*to, weight);
174        Ok(())
175    }
176
177    /// Add edges in batch to the dif.
178    pub fn add_edges(
179        &mut self,
180        edges: &HashMap<Id, HashMap<Id, W>>,
181    ) -> Result<(), Box<dyn std::error::Error>> {
182        for (from, to_weight) in edges {
183            for (to, weight) in to_weight {
184                self.add_edge(from, to, *weight)?;
185            }
186        }
187        Ok(())
188    }
189
190    /// Delete edges in batch from the diff.
191    pub fn delete_edges(
192        &mut self,
193        edges: &HashMap<Id, HashSet<Id>>,
194    ) -> Result<(), Box<dyn std::error::Error>> {
195        for (from, to_set) in edges {
196            for to in to_set {
197                self.delete_edge(from, to);
198            }
199        }
200        Ok(())
201    }
202
203    /// # Safety
204    /// Does not check that the node IDs are valid (i.e. not marked as deleted).
205    pub unsafe fn add_edges_unchecked(
206        &mut self,
207        edges: HashMap<Id, HashMap<Id, W>>,
208    ) -> Result<(), Box<dyn std::error::Error>> {
209        for (from, inner_map) in edges {
210            self.edges
211                .new_or_updated
212                .entry(from)
213                .or_default()
214                .extend(inner_map);
215        }
216        Ok(())
217    }
218
219    /// Add a new edge to be deleted to the diff.
220    /// If present, the edge is removed from `new_or_updated`.
221    pub fn delete_edge(&mut self, from: &Id, to: &Id) {
222        self.edges.deleted.entry(*from).or_default().insert(*to);
223
224        let empty_inner_map = match self.edges.new_or_updated.get_mut(from) {
225            None => false,
226            Some(to_weight) => {
227                to_weight.remove(to);
228                to_weight.is_empty()
229            }
230        };
231        if empty_inner_map {
232            self.edges.new_or_updated.remove(from);
233        }
234    }
235
236    /// Clear the diff of all nodes and edges.
237    pub fn clear(&mut self) {
238        self.nodes.new_or_updated.clear();
239        self.nodes.deleted.clear();
240        self.edges.new_or_updated.clear();
241        self.edges.deleted.clear();
242    }
243
244    #[cfg(test)]
245    fn is_internally_consistent(&self) -> bool {
246        for (from, to_weight) in self.edges.new_or_updated.iter() {
247            if self.nodes.deleted.contains(from) {
248                return false;
249            }
250            for (to, _) in to_weight.iter() {
251                if self.nodes.deleted.contains(to) {
252                    return false;
253                }
254            }
255        }
256        for (from, to_set) in self.edges.deleted.iter() {
257            if self.nodes.deleted.contains(from) {
258                return false;
259            }
260            for to in to_set.iter() {
261                if self.nodes.deleted.contains(to) {
262                    return false;
263                }
264            }
265        }
266        true
267    }
268}
269
270impl<Id: Hash + Eq + Copy, T: Default + AddAssign> AddAssign for GraphDiff<Id, T> {
271    fn add_assign(&mut self, other: Self) {
272        *self += other.nodes;
273        *self += other.edges;
274    }
275}
276
277impl<Id: Hash + Eq + Copy, T: Default + AddAssign> AddAssign<EdgeDiff<Id>> for GraphDiff<Id, T> {
278    fn add_assign(&mut self, edges: EdgeDiff<Id>) {
279        for (from, to_weight) in edges.new_or_updated {
280            for (to, weight) in to_weight {
281                let _ = self.add_edge(&from, &to, weight);
282            }
283        }
284        for (from, to) in edges.deleted {
285            for to in to {
286                self.delete_edge(&from, &to);
287            }
288        }
289    }
290}
291
292impl<Id: Hash + Eq + Copy, T: Default + AddAssign> AddAssign<NodeDiff<Id, T>> for GraphDiff<Id, T> {
293    fn add_assign(&mut self, nodes: NodeDiff<Id, T>) {
294        for (node_id, update) in nodes.new_or_updated {
295            self.add_or_update_node(&node_id, update);
296        }
297        for node_id in nodes.deleted {
298            self.delete_node(node_id);
299        }
300    }
301}
302
303/// A diff between the nodes of a graph.
304#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
305#[serde(rename_all = "camelCase")]
306pub struct NodeDiff<Id: Hash + Eq, T> {
307    new_or_updated: HashMap<Id, T>,
308    deleted: HashSet<Id>,
309}
310
311impl<Id: Hash + Eq, T> NodeDiff<Id, T> {
312    pub fn new(new_or_updated: HashMap<Id, T>, deleted: HashSet<Id>) -> NodeDiff<Id, T> {
313        NodeDiff {
314            new_or_updated,
315            deleted,
316        }
317    }
318    pub fn get_new_or_updated(&self) -> &HashMap<Id, T> {
319        &self.new_or_updated
320    }
321    pub fn get_deleted(&self) -> &HashSet<Id> {
322        &self.deleted
323    }
324}
325
326/// A diff between the edges of a graph.
327#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
328#[serde(rename_all = "camelCase")]
329pub struct EdgeDiff<Id: Hash + Eq, W = f32> {
330    new_or_updated: HashMap<Id, HashMap<Id, W>>,
331    deleted: HashMap<Id, HashSet<Id>>,
332}
333
334impl<Id: Hash + Eq> EdgeDiff<Id> {
335    pub fn new(
336        new_or_updated: HashMap<Id, HashMap<Id, f32>>,
337        deleted: HashMap<Id, HashSet<Id>>,
338    ) -> EdgeDiff<Id> {
339        EdgeDiff {
340            new_or_updated,
341            deleted,
342        }
343    }
344    pub fn get_new_or_updated(&self) -> &HashMap<Id, HashMap<Id, f32>> {
345        &self.new_or_updated
346    }
347    pub fn get_deleted(&self) -> &HashMap<Id, HashSet<Id>> {
348        &self.deleted
349    }
350}
351
352#[cfg(test)]
353mod tests {
354
355    use super::*;
356    use crate::node_update::NodeUpdate;
357    use hashbrown::HashMap;
358
359    #[test]
360    fn test_node() {
361        let mut diff = GraphDiff::<usize, NodeUpdate>::new();
362
363        let id = 1;
364        let mut node = NodeUpdate {
365            label: Some("test".to_string()),
366            ..NodeUpdate::default()
367        };
368
369        diff.add_node(&id);
370        diff.add_or_update_node(&id, node.clone());
371        assert_eq!(diff.nodes.new_or_updated.get(&id).unwrap(), &node);
372
373        node.size = Some(10.0);
374        diff.add_or_update_node(&id, node.clone());
375        assert_eq!(diff.nodes.new_or_updated.get(&id).unwrap(), &node);
376
377        let node2 = NodeUpdate {
378            green: Some(5),
379            ..NodeUpdate::default()
380        };
381        diff.add_or_update_node(&id, node2.clone());
382
383        let combined = NodeUpdate {
384            label: Some("test".to_string()),
385            size: Some(10.0),
386            green: Some(5),
387            ..NodeUpdate::default()
388        };
389        assert_eq!(diff.nodes.new_or_updated.get(&id).unwrap(), &combined);
390
391        diff.delete_node(id);
392        assert!(diff.nodes.new_or_updated.is_empty());
393    }
394
395    #[test]
396    fn test_edge() {
397        let mut diff = GraphDiff::<usize, NodeUpdate>::new();
398
399        let from = 1;
400        let to = 2;
401        let weight = 1.0;
402
403        diff.add_edge(&from, &to, weight).unwrap();
404        assert_eq!(
405            diff.edges
406                .new_or_updated
407                .get(&from)
408                .unwrap()
409                .get(&to)
410                .unwrap(),
411            &weight
412        );
413
414        let weight2 = 2.0;
415        diff.add_edge(&from, &to, weight2).unwrap();
416        assert_eq!(
417            diff.edges
418                .new_or_updated
419                .get(&from)
420                .unwrap()
421                .get(&to)
422                .unwrap(),
423            &weight2
424        );
425
426        diff.delete_node(from);
427        assert!(diff.edges.new_or_updated.is_empty());
428    }
429
430    #[test]
431    fn test_add_assign_nodes() {
432        let mut diff1 = GraphDiff::<usize, NodeUpdate>::new();
433        let node = NodeUpdate {
434            label: Some("test".to_string()),
435            ..NodeUpdate::default()
436        };
437        let node_other = NodeUpdate {
438            size: Some(10.0),
439            ..NodeUpdate::default()
440        };
441        diff1.add_node(&1);
442        diff1.add_or_update_node(&1, node.clone());
443        diff1.add_node(&2);
444        diff1.delete_node(3);
445
446        let mut diff2 = GraphDiff::<usize, NodeUpdate>::new();
447        diff2.add_node(&1);
448        diff2.add_or_update_node(&1, node_other.clone());
449        diff2.delete_node(2);
450
451        diff1 += diff2;
452
453        let d1 = diff1.nodes.new_or_updated.get(&1).unwrap();
454        assert_eq!(d1.label.as_ref().unwrap(), "test");
455        assert_eq!(d1.size.unwrap(), 10.0);
456        assert!(!diff1.nodes.new_or_updated.contains_key(&2));
457        assert!(diff1.nodes.deleted.contains(&2));
458        assert!(diff1.nodes.deleted.contains(&3));
459    }
460
461    #[test]
462    fn test_add_assign_edges() {
463        let mut diff1 = GraphDiff::<usize, NodeUpdate>::new();
464        diff1.add_edge(&1, &2, 1.0).unwrap();
465        diff1.add_edge(&1, &3, 2.0).unwrap();
466        diff1.add_edge(&1, &4, 2.0).unwrap();
467        diff1.add_edge(&2, &3, 3.0).unwrap();
468        diff1.add_edge(&3, &1, 4.0).unwrap();
469
470        let mut diff2 = GraphDiff::<usize, NodeUpdate>::new();
471        diff2.add_edge(&1, &2, 5.0).unwrap();
472        diff2.add_edge(&2, &3, 6.0).unwrap();
473        diff2.add_edge(&3, &1, 7.0).unwrap();
474        diff2.delete_edge(&1, &3);
475
476        diff1 += diff2;
477
478        assert_eq!(
479            diff1.edges.new_or_updated.get(&1).unwrap().get(&2).unwrap(),
480            &5.0
481        );
482        assert_eq!(
483            diff1.edges.new_or_updated.get(&2).unwrap().get(&3).unwrap(),
484            &6.0
485        );
486        assert_eq!(
487            diff1.edges.new_or_updated.get(&3).unwrap().get(&1).unwrap(),
488            &7.0
489        );
490        assert_eq!(
491            diff1.edges.new_or_updated.get(&1).unwrap().get(&4).unwrap(),
492            &2.0
493        );
494        assert!(diff1.edges.deleted.get(&1).unwrap().contains(&3));
495    }
496
497    #[test]
498    fn test_add_edges() {
499        let mut diff = GraphDiff::<usize, usize>::new();
500        for i in 0..50 {
501            diff.add_node(&i);
502        }
503
504        for i in 10..20 {
505            diff.delete_node(i);
506        }
507
508        let edges = (0..50usize)
509            .map(|i| {
510                let mut inner = HashMap::new();
511                for j in 0..i {
512                    inner.insert(j, 1f32);
513                }
514                (i, inner)
515            })
516            .collect::<HashMap<usize, HashMap<usize, f32>>>();
517
518        // check can't add if nodes are deleted
519        let mut diff2 = diff.clone();
520        for i in 10..20 {
521            diff2.delete_node(i);
522        }
523        assert!(diff2.add_edges(&edges).is_err());
524
525        for i in 30..40 {
526            diff.delete_node(i);
527        }
528
529        assert!(diff.is_internally_consistent());
530    }
531}