1use crate::Diff;
5use serde::Deserialize;
6use serde::Serialize;
7use std::collections::HashSet;
8use std::fmt::Debug;
9use std::fmt::Formatter;
10
11use std::hash::Hash;
12use std::iter::empty;
13
14#[derive(Clone, PartialEq, Serialize, Deserialize)]
15pub struct DiffHashSet<T: Diff>(#[serde(skip_serializing_if = "Option::is_none")] pub Option<Vec<InnerValue<T>>>);
16
17#[derive(Clone, PartialEq, Serialize, Deserialize)]
18#[serde(untagged)]
19pub enum InnerValue<T: Diff> {
20 Add(<T as Diff>::Type),
21 Remove { remove: <T as Diff>::Type },
22}
23
24impl<T> Diff for HashSet<T>
25where
26 T: Debug + Clone + PartialEq + Eq + Diff + Hash + for<'de> Deserialize<'de> + Serialize,
27{
28 type Type = DiffHashSet<T>;
29
30 fn diff(&self, other: &Self) -> crate::Result<Self::Type> {
31 Ok(DiffHashSet(if self == other {
32 None
33 } else {
34 let mut val_diffs: Vec<InnerValue<T>> = vec![];
35 for add in other.difference(self) {
36 let add = add.clone().into_diff()?;
37 val_diffs.push(InnerValue::Add(add));
38 }
39
40 for remove in self.difference(other) {
41 let remove = remove.clone().into_diff()?;
42 val_diffs.push(InnerValue::Remove { remove });
43 }
44
45 Some(val_diffs)
46 }))
47 }
48
49 fn merge(&self, diff: Self::Type) -> crate::Result<Self> {
50 match diff.0 {
51 None => Ok(self.clone()),
52 Some(val_diffs) => {
53 let mut new: Self = self.clone();
54 for val_diff in val_diffs {
55 match val_diff {
56 InnerValue::Add(val) => {
57 new.insert(<T>::from_diff(val)?);
58 }
59 InnerValue::Remove { remove } => {
60 new.remove(&(<T>::from_diff(remove)?));
61 }
62 }
63 }
64 Ok(new)
65 }
66 }
67 }
68
69 fn into_diff(self) -> crate::Result<Self::Type> {
70 Ok(DiffHashSet(if self.is_empty() {
71 None
72 } else {
73 let mut diffs: Vec<InnerValue<T>> = vec![];
74 for val in self {
75 diffs.push(InnerValue::Add(val.into_diff()?));
76 }
77 Some(diffs)
78 }))
79 }
80
81 fn from_diff(diff: Self::Type) -> crate::Result<Self> {
82 let mut set = Self::new();
83 if let Some(vals) = diff.0 {
84 for val in vals {
85 match val {
86 InnerValue::Add(val) => {
87 set.insert(<T>::from_diff(val)?);
88 }
89 InnerValue::Remove { remove } => {
90 let val = <T>::from_diff(remove)?;
91 set.remove(&val);
92 }
93 }
94 }
95 }
96 Ok(set)
97 }
98}
99
100impl<T> Debug for DiffHashSet<T>
101where
102 T: Debug + Diff,
103{
104 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
105 write!(f, "DiffHashSet")?;
106 let mut buf = f.debug_list();
107 if let Some(d) = &self.0 {
108 buf.entries(d.iter());
109 } else {
110 buf.entries(empty::<Vec<InnerValue<T>>>());
111 }
112 buf.finish()
113 }
114}
115
116impl<T> Debug for InnerValue<T>
117where
118 T: Debug + Diff,
119{
120 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
121 match &self {
122 Self::Add(val) => f.debug_tuple("Add").field(val).finish(),
123 Self::Remove { remove } => f.debug_tuple("Remove").field(remove).finish(),
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a `HashSet` from a comma-separated list of expressions; a
    /// trailing comma is accepted.
    macro_rules! set {
        ($($val:expr),* $(,)?) => {{
            let mut set = HashSet::new();
            $( set.insert($val); )*
            set
        }};
    }

    /// Equal sets produce a `None` diff, and merging it is a no-op.
    #[test]
    fn test_hashset_diff() {
        let s: HashSet<String> = set! {
            "test".into(),
            "foo".into(),
        };

        let s1: HashSet<String> = set! {
            "test".into(),
            "foo".into(),
        };

        let diff = s.diff(&s1).unwrap();
        let expected = DiffHashSet(None);

        assert_eq!(diff, expected);
        let s2 = s.merge(diff).unwrap();

        assert_eq!(s, s2);
        assert_eq!(s1, s2);
    }

    /// Differing sets produce exactly one `Add` ("bar") and one `Remove`
    /// ("faux"); the diff survives a JSON round trip unchanged, and merging it
    /// onto the source set yields the target set.
    #[test]
    fn test_hashset_diff_add_and_remove() {
        let s: HashSet<String> = set! {
            "test".into(),
            "foo".into(),
            "faux".into(),
        };

        let s1: HashSet<String> = set! {
            "test".into(),
            "foo".into(),
            "bar".into(),
        };

        let diff = s.diff(&s1).unwrap();

        let changes = diff
            .0
            .as_ref()
            .expect("the sets differ, so the diff must carry changes");
        assert_eq!(changes.len(), 2);
        assert_eq!(
            changes.iter().filter(|c| matches!(c, InnerValue::Add(_))).count(),
            1
        );
        assert_eq!(
            changes
                .iter()
                .filter(|c| matches!(c, InnerValue::Remove { .. }))
                .count(),
            1
        );

        let rendered = format!("{:?}", diff);
        assert!(rendered.starts_with("DiffHashSet["));
        assert!(rendered.contains("Add("));
        assert!(rendered.contains("Remove("));

        let json = serde_json::to_string(&diff).unwrap();
        let decoded: DiffHashSet<String> = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, diff);

        assert_eq!(s.merge(decoded).unwrap(), s1);
    }

    /// `into_diff` followed by `from_diff` reproduces the set, including the
    /// empty set, whose diff is `None`.
    #[test]
    fn test_hashset_into_diff_from_diff_round_trip() {
        let s: HashSet<String> = set! {
            "a".into(),
            "b".into(),
        };
        let diff = s.clone().into_diff().unwrap();
        assert_eq!(HashSet::<String>::from_diff(diff).unwrap(), s);

        let empty_set: HashSet<String> = HashSet::new();
        let diff = empty_set.clone().into_diff().unwrap();
        assert_eq!(diff, DiffHashSet(None));
        assert_eq!(HashSet::<String>::from_diff(diff).unwrap(), empty_set);
    }
}