1use crate::{
7 codec::{EncodeSize, Read, Write},
8 error::Error,
9 RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{cmp::Ordering, collections::HashSet, hash::Hash};
13
/// Type label passed to `read_ordered_set`, which reports it as the first field of
/// `Error::Invalid` when the items of an encoded `HashSet` are duplicated or out of order.
const HASHSET_TYPE: &str = "HashSet";
15
16fn read_ordered_set<K, F>(
18 buf: &mut impl Buf,
19 len: usize,
20 cfg: &K::Cfg,
21 mut insert: F,
22 set_type: &'static str,
23) -> Result<(), Error>
24where
25 K: Read + Ord,
26 F: FnMut(K) -> bool,
27{
28 let mut last: Option<K> = None;
29 for _ in 0..len {
30 let item = K::read_cfg(buf, cfg)?;
32
33 if let Some(ref last) = last {
35 match item.cmp(last) {
36 Ordering::Equal => return Err(Error::Invalid(set_type, "Duplicate item")),
37 Ordering::Less => return Err(Error::Invalid(set_type, "Items must ascend")),
38 _ => {}
39 }
40 }
41
42 if let Some(last) = last.take() {
44 insert(last);
45 }
46 last = Some(item);
47 }
48
49 if let Some(last) = last {
51 insert(last);
52 }
53
54 Ok(())
55}
56
57impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
58 fn write(&self, buf: &mut impl BufMut) {
59 self.len().write(buf);
60
61 let mut items: Vec<_> = self.iter().collect();
63 items.sort();
64 for item in items {
65 item.write(buf);
66 }
67 }
68}
69
70impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
71 fn encode_size(&self) -> usize {
72 let mut size = self.len().encode_size();
73
74 for item in self {
76 size += item.encode_size();
77 }
78 size
79 }
80}
81
82impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
83 type Cfg = (RangeCfg, K::Cfg);
84
85 fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
86 let len = usize::read_cfg(buf, range)?;
88 let mut set = HashSet::with_capacity(len);
89
90 read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
92
93 Ok(set)
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::fmt::Debug;

    /// Encodes `set`, checks that `encode_size` matches the number of encoded bytes, decodes
    /// the bytes with `(range_cfg, item_cfg)` and asserts the result equals `set`.
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    /// An empty set encodes to a single zero byte (the item count).
    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1);
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    /// Items are written in ascending order regardless of insertion order.
    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);

        // Expected layout: item count, then the items sorted ascending.
        let mut expected = BytesMut::new();
        3usize.write(&mut expected);
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    /// Round-trips 1000-element sets with a length range that accepts exactly 1000 items.
    #[test]
    fn test_large_hashset() {
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        // `usize` items take their own range configuration.
        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    /// Round-trips a set whose items are variable-length `Bytes`.
    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10;
        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    /// Decoding fails when the encoded item count falls outside the configured range.
    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let encoded = set.encode();
        // Two items are encoded but at most one is allowed.
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    /// Decoding fails when an item violates the item configuration.
    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        // Eight bytes long, which exceeds the item range below.
        set.insert(Bytes::from_static(b"longitem"));

        let set_range = 0..=10;
        let restrictive_item_range = ..=5;

        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    /// Decoding rejects items that are not in ascending order.
    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        // Out of order: 2 follows 5.
        2u32.write(&mut encoded);

        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    /// Decoding rejects a repeated item.
    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        // Same value again.
        1u32.write(&mut encoded);

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    /// Decoding fails with `EndOfBuffer` when the input is cut off inside an item.
    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        // Remove the final two bytes so the last item cannot be read in full.
        encoded.truncate(set.encode_size() - 2);

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    /// `decode_cfg` rejects trailing bytes, while `read_cfg` leaves them unread and succeeds.
    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        // One byte beyond the encoded set.
        encoded.put_u8(0xFF);

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // `encoded` is not used afterwards, so it can be read from directly.
        let decoded_set = HashSet::<u32>::read_cfg(&mut encoded, &config_tuple)
            .expect("read_cfg should succeed when trailing bytes follow the set");
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    /// Two sets with the same members encode identically whatever the insertion order.
    #[test]
    fn test_hashset_deterministic_encoding() {
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    /// Pins the exact byte layout for an empty set, a `u8` set and a `Bytes` set.
    #[test]
    fn test_hashset_conformity() {
        // Case 1: empty set.
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: fixed-size items inserted out of order.
        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2);
        1u8.write(&mut expected2);
        2u8.write(&mut expected2);
        5u8.write(&mut expected2);
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Case 3: variable-length items inserted out of order.
        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // The prefix encodes the item count (3), so size it from the actual length rather
        // than from an unrelated constant that merely happens to encode to the same size.
        let expected_size = set3.len().encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }
}