// identity_diff/hashset.rs

// Copyright 2020-2021 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

4use crate::Diff;
5use serde::Deserialize;
6use serde::Serialize;
7use std::collections::HashSet;
8use std::fmt::Debug;
9use std::fmt::Formatter;
10
11use std::hash::Hash;
12use std::iter::empty;
13
/// Diff representation of a [`HashSet`].
///
/// `None` means there is nothing to apply: `diff` produces it when the two sets
/// compare equal, and `into_diff` produces it for an empty set. `Some` holds the
/// list of element additions/removals, which `merge` applies in order.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct DiffHashSet<T: Diff>(#[serde(skip_serializing_if = "Option::is_none")] pub Option<Vec<InnerValue<T>>>);
16
/// A single change recorded in a [`DiffHashSet`].
///
/// Serialized `untagged`: an `Add` is written as the bare element diff, while a
/// `Remove` is written as a map with a single `remove` field.
///
/// NOTE(review): untagged deserialization tries the variants in declaration
/// order, so `Add` is attempted first. If `<T as Diff>::Type` can itself be
/// deserialized from a map containing a `remove` key (e.g. a struct whose fields
/// are all optional and which ignores unknown keys), a removal would be read back
/// as an addition — confirm this cannot happen for the element types in use.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InnerValue<T: Diff> {
  /// Insert the element reconstructed from this diff.
  Add(<T as Diff>::Type),
  /// Remove the element reconstructed from this diff.
  Remove { remove: <T as Diff>::Type },
}
23
24impl<T> Diff for HashSet<T>
25where
26  T: Debug + Clone + PartialEq + Eq + Diff + Hash + for<'de> Deserialize<'de> + Serialize,
27{
28  type Type = DiffHashSet<T>;
29
30  fn diff(&self, other: &Self) -> crate::Result<Self::Type> {
31    Ok(DiffHashSet(if self == other {
32      None
33    } else {
34      let mut val_diffs: Vec<InnerValue<T>> = vec![];
35      for add in other.difference(self) {
36        let add = add.clone().into_diff()?;
37        val_diffs.push(InnerValue::Add(add));
38      }
39
40      for remove in self.difference(other) {
41        let remove = remove.clone().into_diff()?;
42        val_diffs.push(InnerValue::Remove { remove });
43      }
44
45      Some(val_diffs)
46    }))
47  }
48
49  fn merge(&self, diff: Self::Type) -> crate::Result<Self> {
50    match diff.0 {
51      None => Ok(self.clone()),
52      Some(val_diffs) => {
53        let mut new: Self = self.clone();
54        for val_diff in val_diffs {
55          match val_diff {
56            InnerValue::Add(val) => {
57              new.insert(<T>::from_diff(val)?);
58            }
59            InnerValue::Remove { remove } => {
60              new.remove(&(<T>::from_diff(remove)?));
61            }
62          }
63        }
64        Ok(new)
65      }
66    }
67  }
68
69  fn into_diff(self) -> crate::Result<Self::Type> {
70    Ok(DiffHashSet(if self.is_empty() {
71      None
72    } else {
73      let mut diffs: Vec<InnerValue<T>> = vec![];
74      for val in self {
75        diffs.push(InnerValue::Add(val.into_diff()?));
76      }
77      Some(diffs)
78    }))
79  }
80
81  fn from_diff(diff: Self::Type) -> crate::Result<Self> {
82    let mut set = Self::new();
83    if let Some(vals) = diff.0 {
84      for val in vals {
85        match val {
86          InnerValue::Add(val) => {
87            set.insert(<T>::from_diff(val)?);
88          }
89          InnerValue::Remove { remove } => {
90            let val = <T>::from_diff(remove)?;
91            set.remove(&val);
92          }
93        }
94      }
95    }
96    Ok(set)
97  }
98}
99
100impl<T> Debug for DiffHashSet<T>
101where
102  T: Debug + Diff,
103{
104  fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
105    write!(f, "DiffHashSet")?;
106    let mut buf = f.debug_list();
107    if let Some(d) = &self.0 {
108      buf.entries(d.iter());
109    } else {
110      buf.entries(empty::<Vec<InnerValue<T>>>());
111    }
112    buf.finish()
113  }
114}
115
116impl<T> Debug for InnerValue<T>
117where
118  T: Debug + Diff,
119{
120  fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
121    match &self {
122      Self::Add(val) => f.debug_tuple("Add").field(val).finish(),
123      Self::Remove { remove } => f.debug_tuple("Remove").field(remove).finish(),
124    }
125  }
126}
127
#[cfg(test)]
mod tests {
  use super::*;
  use std::collections::HashSet;

  /// Builds a `HashSet` from a comma-separated list of expressions.
  macro_rules! set {
    ($($val:expr),* $(,)?) => {{
      let mut set = HashSet::new();
      $( set.insert($val); )*
      set
    }};
  }

  #[test]
  fn test_hashset_diff() {
    let s: HashSet<String> = set! {
        "test".into(),
        "foo".into(),
    };

    let s1: HashSet<String> = set! {
        "test".into(),
        "foo".into(),
    };

    let diff = s.diff(&s1).unwrap();
    let expected = DiffHashSet(None);

    assert_eq!(diff, expected);
    let s2 = s.merge(diff).unwrap();

    assert_eq!(s, s2);
    assert_eq!(s1, s2);
  }

  #[test]
  fn test_hashset_diff_add_and_remove() {
    let s: HashSet<String> = set! {
        "test".into(),
        "foo".into(),
        "faux".into(),
    };

    let s1: HashSet<String> = set! {
        "test".into(),
        "foo".into(),
        "bar".into(),
    };

    let diff = s.diff(&s1).unwrap();

    // Additions are emitted before removals; with one of each the order is fixed.
    let expected: DiffHashSet<String> = DiffHashSet(Some(vec![
      InnerValue::<String>::Add("bar".to_string().into_diff().unwrap()),
      InnerValue::<String>::Remove {
        remove: "faux".to_string().into_diff().unwrap(),
      },
    ]));
    assert_eq!(diff, expected);

    // The diff must survive a JSON round trip unchanged.
    let json = serde_json::to_string(&diff).unwrap();
    let decoded: DiffHashSet<String> = serde_json::from_str(&json).unwrap();
    assert_eq!(decoded, diff);

    // Applying the decoded diff turns `s` into `s1`.
    assert_eq!(s.merge(decoded).unwrap(), s1);
  }

  #[test]
  fn test_hashset_into_and_from_diff() {
    let s: HashSet<String> = set! {
        "test".into(),
        "foo".into(),
    };

    // A full diff rebuilds the original set from scratch.
    let diff = s.clone().into_diff().unwrap();
    assert_eq!(<HashSet<String> as Diff>::from_diff(diff).unwrap(), s);

    // An empty set maps to "no diff", and "no diff" maps back to an empty set.
    let empty_set: HashSet<String> = HashSet::new();
    assert_eq!(empty_set.clone().into_diff().unwrap(), DiffHashSet(None));
    assert_eq!(
      <HashSet<String> as Diff>::from_diff(DiffHashSet(None)).unwrap(),
      empty_set
    );
  }
}