1use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt::{Debug, Formatter, Result as FmtResult};
6use std::hash::{BuildHasher, Hash};
7use std::iter::FromIterator;
8
9use crossbeam_epoch;
10
11#[cfg(feature = "rayon")]
12use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
13
14use crate::raw::config::Trivial as TrivialConfig;
15use crate::raw::{self, Raw};
16
17pub struct ConSet<T, S = RandomState>
64where
65 T: Clone + Hash + Eq + 'static,
66{
67 raw: Raw<TrivialConfig<T>, S>,
68}
69
70impl<T> ConSet<T, RandomState>
71where
72 T: Clone + Hash + Eq + 'static,
73{
74 pub fn new() -> Self {
76 Self::with_hasher(RandomState::default())
77 }
78}
79
80impl<T, S> ConSet<T, S>
81where
82 T: Clone + Hash + Eq + 'static,
83 S: BuildHasher,
84{
85 pub fn with_hasher(hasher: S) -> Self {
87 Self {
88 raw: Raw::with_hasher(hasher),
89 }
90 }
91
92 pub fn insert(&self, value: T) -> Option<T> {
96 let pin = crossbeam_epoch::pin();
97 self.raw.insert(value, &pin).cloned()
98 }
99
100 pub fn get<Q>(&self, key: &Q) -> Option<T>
104 where
105 Q: ?Sized + Eq + Hash,
106 T: Borrow<Q>,
107 {
108 let pin = crossbeam_epoch::pin();
109 self.raw.get(key, &pin).cloned()
110 }
111
112 pub fn contains<Q>(&self, key: &Q) -> bool
117 where
118 Q: ?Sized + Eq + Hash,
119 T: Borrow<Q>,
120 {
121 let pin = crossbeam_epoch::pin();
122 self.raw.get(key, &pin).is_some()
123 }
124
125 pub fn remove<Q>(&self, key: &Q) -> Option<T>
127 where
128 Q: ?Sized + Eq + Hash,
129 T: Borrow<Q>,
130 {
131 let pin = crossbeam_epoch::pin();
132 self.raw.remove(key, &pin).cloned()
133 }
134
135 pub fn is_empty(&self) -> bool {
140 self.raw.is_empty()
141 }
142}
143
144impl<T> Default for ConSet<T, RandomState>
145where
146 T: Clone + Hash + Eq + 'static,
147{
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153impl<T, S> Debug for ConSet<T, S>
154where
155 T: Debug + Clone + Hash + Eq + 'static,
156{
157 fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
158 fmt.debug_set().entries(self.iter()).finish()
159 }
160}
161
162impl<T, S> ConSet<T, S>
163where
164 T: Clone + Hash + Eq + 'static,
165{
166 pub fn iter(&self) -> Iter<T, S> {
168 Iter {
169 inner: raw::iterator::Iter::new(&self.raw),
170 }
171 }
172}
173
174pub struct Iter<'a, T, S>
178where
179 T: Clone + Hash + Eq + 'static,
180{
181 inner: raw::iterator::Iter<'a, TrivialConfig<T>, S>,
182}
183
184impl<'a, T, S> Iterator for Iter<'a, T, S>
185where
186 T: Clone + Hash + Eq + 'static,
187{
188 type Item = T;
189
190 fn next(&mut self) -> Option<T> {
191 self.inner.next().cloned()
192 }
193}
194
195impl<'a, T, S> IntoIterator for &'a ConSet<T, S>
196where
197 T: Clone + Hash + Eq + 'static,
198{
199 type Item = T;
200 type IntoIter = Iter<'a, T, S>;
201
202 fn into_iter(self) -> Self::IntoIter {
203 self.iter()
204 }
205}
206
207impl<'a, T, S> Extend<T> for &'a ConSet<T, S>
208where
209 T: Clone + Hash + Eq + 'static,
210 S: BuildHasher,
211{
212 fn extend<I>(&mut self, iter: I)
213 where
214 I: IntoIterator<Item = T>,
215 {
216 for n in iter {
217 self.insert(n);
218 }
219 }
220}
221
222impl<T, S> Extend<T> for ConSet<T, S>
223where
224 T: Clone + Hash + Eq + 'static,
225 S: BuildHasher,
226{
227 fn extend<I>(&mut self, iter: I)
228 where
229 I: IntoIterator<Item = T>,
230 {
231 let mut me: &ConSet<_, _> = self;
232 me.extend(iter);
233 }
234}
235
236impl<T> FromIterator<T> for ConSet<T>
237where
238 T: Clone + Hash + Eq + 'static,
239{
240 fn from_iter<I>(iter: I) -> Self
241 where
242 I: IntoIterator<Item = T>,
243 {
244 let mut me = ConSet::new();
245 me.extend(iter);
246 me
247 }
248}
249
/// Parallel insertion through a shared reference: rayon worker threads call
/// `insert` on the same set concurrently.
#[cfg(feature = "rayon")]
impl<'a, T, S> ParallelExtend<T> for &'a ConSet<T, S>
where
    T: Clone + Hash + Eq + Send + Sync,
    S: BuildHasher + Sync,
{
    /// Inserts every item produced by `par_iter`, in no particular order.
    fn par_extend<I>(&mut self, par_iter: I)
    where
        I: IntoParallelIterator<Item = T>,
    {
        let set: &ConSet<T, S> = *self;
        par_iter.into_par_iter().for_each(move |value| {
            set.insert(value);
        });
    }
}
265
#[cfg(feature = "rayon")]
impl<T, S> ParallelExtend<T> for ConSet<T, S>
where
    T: Clone + Hash + Eq + Send + Sync,
    S: BuildHasher + Sync,
{
    /// Inserts every item produced by `par_iter`; delegates to the impl for `&ConSet`.
    fn par_extend<I>(&mut self, par_iter: I)
    where
        I: IntoParallelIterator<Item = T>,
    {
        let mut shared: &Self = self;
        <&Self as ParallelExtend<T>>::par_extend(&mut shared, par_iter);
    }
}
280
#[cfg(feature = "rayon")]
impl<T> FromParallelIterator<T> for ConSet<T>
where
    T: Clone + Hash + Eq + Send + Sync,
{
    /// Builds a [`RandomState`](std::collections::hash_map::RandomState)-hashed
    /// set by inserting the items of `iter` in parallel.
    fn from_par_iter<I>(iter: I) -> Self
    where
        I: IntoParallelIterator<Item = T>,
    {
        let set = ConSet::new();
        let mut handle = &set;
        handle.par_extend(iter);
        set
    }
}
295
#[cfg(test)]
mod tests {
    use crossbeam_utils::thread;
    #[cfg(feature = "rayon")]
    use rayon::prelude::*;

    use super::*;
    use crate::raw::tests::NoHasher;
    use crate::raw::LEVEL_CELLS;

    const TEST_THREADS: usize = 4;
    const TEST_BATCH: usize = 10000;
    const TEST_BATCH_SMALL: usize = 100;
    const TEST_REP: usize = 20;

    #[test]
    fn debug_when_empty() {
        let set: ConSet<String> = ConSet::new();
        assert_eq!("{}", &format!("{:?}", set));
    }

    #[test]
    fn debug_when_has_elements() {
        let set: ConSet<&str> = ConSet::new();
        assert!(set.insert("hello").is_none());
        assert!(set.insert("world").is_none());
        let actual = format!("{:?}", set);

        // Iteration order is unspecified, so accept exactly the two valid orderings.
        // (Comparing sorted characters would also accept scrambled output.)
        let expected = ["{\"hello\", \"world\"}", "{\"world\", \"hello\"}"];
        assert!(
            expected.contains(&actual.as_str()),
            "unexpected Debug output: {}",
            actual
        );
    }

    #[test]
    fn debug_when_elements_are_added_and_removed() {
        let set: ConSet<&str> = ConSet::new();
        assert_eq!("{}", &format!("{:?}", set));
        assert!(set.insert("hello").is_none());
        assert!(set.insert("hello").is_some());
        assert!(set.insert("hello").is_some());
        assert_eq!("{\"hello\"}", &format!("{:?}", set));
        assert!(set.remove("hello").is_some());
        assert_eq!("{}", &format!("{:?}", set));
    }

    #[test]
    fn create_destroy() {
        let set: ConSet<String> = ConSet::new();
        drop(set);
    }

    #[test]
    fn lookup_empty() {
        let set: ConSet<String> = ConSet::new();
        assert!(set.get("hello").is_none());
    }

    #[test]
    fn insert_lookup() {
        let set = ConSet::new();
        assert!(set.insert("hello").is_none());
        assert!(set.get("world").is_none());
        let found = set.get("hello").unwrap();
        assert_eq!("hello", found);
    }

    #[test]
    fn insert_many() {
        let set = ConSet::new();
        for i in 0..TEST_BATCH * LEVEL_CELLS {
            assert!(set.insert(i).is_none());
        }

        for i in 0..TEST_BATCH * LEVEL_CELLS {
            assert_eq!(i, set.get(&i).unwrap());
        }
    }

    #[test]
    fn par_insert_many() {
        for _ in 0..TEST_REP {
            let set: ConSet<usize> = ConSet::new();
            thread::scope(|s| {
                for t in 0..TEST_THREADS {
                    let set = &set;
                    s.spawn(move |_| {
                        for i in 0..TEST_BATCH {
                            let num = t * TEST_BATCH + i;
                            assert!(set.insert(num).is_none());
                        }
                    });
                }
            })
            .unwrap();

            for i in 0..TEST_BATCH * TEST_THREADS {
                assert_eq!(set.get(&i).unwrap(), i);
            }
        }
    }

    #[test]
    fn par_get_many() {
        for _ in 0..TEST_REP {
            let set = ConSet::new();
            for i in 0..TEST_BATCH * TEST_THREADS {
                assert!(set.insert(i).is_none());
            }
            thread::scope(|s| {
                for t in 0..TEST_THREADS {
                    let set = &set;
                    s.spawn(move |_| {
                        for i in 0..TEST_BATCH {
                            let num = t * TEST_BATCH + i;
                            assert_eq!(set.get(&num).unwrap(), num);
                        }
                    });
                }
            })
            .unwrap();
        }
    }

    #[test]
    fn no_collisions() {
        // `NoHasher` forces every element into the same hash, exercising collision handling.
        let set = ConSet::with_hasher(NoHasher);
        for i in 0..TEST_BATCH_SMALL {
            assert!(set.insert(i).is_none());
        }
        for i in 0..TEST_BATCH_SMALL {
            assert_eq!(i, set.get(&i).unwrap());
        }
        for i in 0..TEST_BATCH_SMALL {
            assert_eq!(i, set.insert(i).unwrap());
        }
    }

    #[test]
    fn simple_remove() {
        let set = ConSet::new();
        assert!(set.remove(&42).is_none());
        assert!(set.insert(42).is_none());
        assert_eq!(42, set.get(&42).unwrap());
        assert_eq!(42, set.remove(&42).unwrap());
        assert!(set.get(&42).is_none());
        assert!(set.is_empty());
        assert!(set.remove(&42).is_none());
        assert!(set.is_empty());
    }

    fn remove_many_inner<S: BuildHasher>(mut set: ConSet<usize, S>, len: usize) {
        for i in 0..len {
            assert!(set.insert(i).is_none());
        }
        for i in 0..len {
            assert_eq!(i, set.get(&i).unwrap());
            assert_eq!(i, set.remove(&i).unwrap());
            assert!(set.get(&i).is_none());
            set.raw.assert_pruned();
        }

        assert!(set.is_empty());
    }

    #[test]
    fn remove_many() {
        remove_many_inner(ConSet::new(), TEST_BATCH);
    }

    #[test]
    fn remove_many_collision() {
        remove_many_inner(ConSet::with_hasher(NoHasher), TEST_BATCH_SMALL);
    }

    #[test]
    fn collision_remove_one_left() {
        let mut set = ConSet::with_hasher(NoHasher);
        set.insert(1);
        set.insert(2);

        set.raw.assert_pruned();

        assert!(set.remove(&2).is_some());
        set.raw.assert_pruned();

        assert!(set.remove(&1).is_some());

        set.raw.assert_pruned();
        assert!(set.is_empty());
    }

    #[test]
    fn collision_remove_one_left_with_str() {
        let mut set = ConSet::with_hasher(NoHasher);
        set.insert("hello");
        set.insert("world");

        set.raw.assert_pruned();

        assert!(set.remove("world").is_some());
        set.raw.assert_pruned();

        assert!(set.remove("hello").is_some());

        set.raw.assert_pruned();
        assert!(set.is_empty());
    }

    #[test]
    fn remove_par() {
        let mut set = ConSet::new();
        for i in 0..TEST_THREADS * TEST_BATCH {
            set.insert(i);
        }

        thread::scope(|s| {
            for t in 0..TEST_THREADS {
                let set = &set;
                s.spawn(move |_| {
                    for i in 0..TEST_BATCH {
                        let num = t * TEST_BATCH + i;
                        let val = set.remove(&num).unwrap();
                        assert_eq!(num, val);
                    }
                });
            }
        })
        .unwrap();

        set.raw.assert_pruned();
        assert!(set.is_empty());
    }

    fn iter_test_inner<S: BuildHasher>(set: ConSet<usize, S>) {
        for i in 0..TEST_BATCH_SMALL {
            assert!(set.insert(i).is_none());
        }

        let mut extracted = set.iter().collect::<Vec<_>>();

        extracted.sort();
        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }

    #[test]
    fn iter() {
        let set = ConSet::new();
        iter_test_inner(set);
    }

    #[test]
    fn iter_collision() {
        let set = ConSet::with_hasher(NoHasher);
        iter_test_inner(set);
    }

    #[test]
    fn collect() {
        let set = (0..TEST_BATCH_SMALL).collect::<ConSet<_>>();

        let mut extracted = set.iter().collect::<Vec<_>>();
        extracted.sort();
        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }

    /// Several OS threads concurrently `extend` the same set through `&ConSet`.
    #[test]
    fn par_extend() {
        let set = ConSet::new();

        thread::scope(|s| {
            for t in 0..TEST_THREADS {
                let mut set = &set;
                s.spawn(move |_| {
                    let start = t * TEST_BATCH_SMALL;
                    let iter = start..start + TEST_BATCH_SMALL;
                    set.extend(iter);
                });
            }
        })
        .unwrap();

        let mut extracted = set.iter().collect::<Vec<_>>();

        extracted.sort();
        let expected = (0..TEST_THREADS * TEST_BATCH_SMALL).collect::<Vec<_>>();

        assert_eq!(expected, extracted);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn rayon_extend() {
        let mut set = ConSet::new();
        set.par_extend((0..TEST_BATCH_SMALL).into_par_iter());

        let mut extracted = set.iter().collect::<Vec<_>>();
        extracted.par_sort();

        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn rayon_from_par_iter() {
        let set = ConSet::from_par_iter((0..TEST_BATCH_SMALL).into_par_iter());

        let mut extracted = set.iter().collect::<Vec<_>>();
        extracted.sort();

        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }
}