1use std::any::Any;
9use std::fmt;
10
11use futures::StreamExt;
12
13use super::hints::Flags;
14use super::AsyncNameSetQuery;
15use super::BoxVertexStream;
16use super::Hints;
17use super::NameSet;
18use crate::fmt::write_debug;
19use crate::Result;
20use crate::VertexName;
21
/// Union of two [`NameSet`]s.
///
/// Forward iteration yields every vertex of the first operand, followed by
/// the vertices of the second operand that are not in the first one.
/// Reverse iteration visits those "extra" vertices of the second operand
/// (in reverse) before the first operand (in reverse).
pub struct UnionSet {
    /// The two operands: `sets[0]` is `lhs` and `sets[1]` is `rhs` as passed
    /// to [`UnionSet::new`].
    sets: [NameSet; 2],
    /// Hints derived from both operands when the union is constructed.
    hints: Hints,
}
30
31impl UnionSet {
32 pub fn new(lhs: NameSet, rhs: NameSet) -> Self {
33 let hints = Hints::union(&[lhs.hints(), rhs.hints()]);
34 if hints.id_map().is_some() {
35 if let (Some(id1), Some(id2)) = (lhs.hints().min_id(), rhs.hints().min_id()) {
36 hints.set_min_id(id1.min(id2));
37 }
38 if let (Some(id1), Some(id2)) = (lhs.hints().max_id(), rhs.hints().max_id()) {
39 hints.set_max_id(id1.max(id2));
40 }
41 };
42 hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
43 if lhs.hints().contains(Flags::FILTER) || rhs.hints().contains(Flags::FILTER) {
44 hints.add_flags(Flags::FILTER);
45 }
46 Self {
47 sets: [lhs, rhs],
48 hints,
49 }
50 }
51}
52
53#[async_trait::async_trait]
54impl AsyncNameSetQuery for UnionSet {
55 async fn iter(&self) -> Result<BoxVertexStream> {
56 debug_assert_eq!(self.sets.len(), 2);
57 let diff = self.sets[1].clone() - self.sets[0].clone();
58 let diff_iter = diff.iter().await?;
59 let set0_iter = self.sets[0].iter().await?;
60 let iter = set0_iter.chain(diff_iter);
61 Ok(Box::pin(iter))
62 }
63
64 async fn iter_rev(&self) -> Result<BoxVertexStream> {
65 debug_assert_eq!(self.sets.len(), 2);
66 let diff = self.sets[1].clone() - self.sets[0].clone();
67 let diff_iter = diff.iter_rev().await?;
68 let set0_iter = self.sets[0].iter_rev().await?;
69 let iter = diff_iter.chain(set0_iter);
70 Ok(Box::pin(iter))
71 }
72
73 async fn count(&self) -> Result<usize> {
74 debug_assert_eq!(self.sets.len(), 2);
75 let mut count = self.sets[0].count().await?;
78 let mut iter = self.sets[1].iter().await?;
79 while let Some(item) = iter.next().await {
80 let name = item?;
81 if !self.sets[0].contains(&name).await? {
82 count += 1;
83 }
84 }
85 Ok(count)
86 }
87
88 async fn is_empty(&self) -> Result<bool> {
89 for set in &self.sets {
90 if !set.is_empty().await? {
91 return Ok(false);
92 }
93 }
94 Ok(true)
95 }
96
97 async fn contains(&self, name: &VertexName) -> Result<bool> {
98 for set in &self.sets {
99 if set.contains(name).await? {
100 return Ok(true);
101 }
102 }
103 Ok(false)
104 }
105
106 async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
107 for set in &self.sets {
108 if let Some(result) = set.contains_fast(name).await? {
109 return Ok(Some(result));
110 }
111 }
112 Ok(None)
113 }
114
115 fn as_any(&self) -> &dyn Any {
116 self
117 }
118
119 fn hints(&self) -> &Hints {
120 &self.hints
121 }
122}
123
124impl fmt::Debug for UnionSet {
125 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126 write!(f, "<or")?;
127 write_debug(f, &self.sets[0])?;
128 write_debug(f, &self.sets[1])?;
129 write!(f, ">")
130 }
131}
132
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::tests::*;
    use super::*;

    /// Builds a `UnionSet` from two byte slices; each byte names one vertex.
    fn union(a: &[u8], b: &[u8]) -> UnionSet {
        let to_set = |bytes: &[u8]| NameSet::from_query(VecQuery::from_bytes(bytes));
        UnionSet::new(to_set(a), to_set(b))
    }

    #[test]
    fn test_union_basic() -> Result<()> {
        // 11 and 33 appear in both operands and must be listed only once.
        let set = union(b"\x11\x33\x22", b"\x44\x11\x55\x33");
        check_invariants(&set)?;

        let forward = shorten_iter(ni(set.iter()));
        assert_eq!(forward, ["11", "33", "22", "44", "55"]);
        let backward = shorten_iter(ni(set.iter_rev()));
        assert_eq!(backward, ["55", "44", "22", "33", "11"]);

        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count())?, 5);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "55");

        for present in b"\x11\x22\x33\x44\x55".iter().copied() {
            assert!(nb(set.contains(&to_name(present)))?);
        }
        for absent in b"\x66\x77\x88".iter().copied() {
            assert!(!nb(set.contains(&to_name(absent)))?);
        }
        Ok(())
    }

    quickcheck::quickcheck! {
        fn test_union_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = union(&a, &b);
            check_invariants(&set).unwrap();

            let count = nb(set.count()).unwrap();
            assert!(count <= a.len() + b.len());

            let distinct: HashSet<u8> = a.iter().chain(&b).copied().collect();
            assert_eq!(count, distinct.len());

            let is_member = |v: u8| nb(set.contains(&to_name(v))).ok() == Some(true);
            assert!(a.iter().copied().all(&is_member));
            assert!(b.iter().copied().all(&is_member));

            true
        }
    }
}