1use std::any::Any;
9use std::cmp::Ordering;
10use std::fmt;
11
12use futures::StreamExt;
13
14use super::hints::Flags;
15use super::AsyncNameSetQuery;
16use super::BoxVertexStream;
17use super::Hints;
18use super::NameSet;
19use crate::fmt::write_debug;
20use crate::Id;
21use crate::Result;
22use crate::VertexName;
23
/// Intersection of two sets.
///
/// Iteration walks `lhs` (in `lhs` order) and keeps the vertices that `rhs`
/// contains. Note that `IntersectionSet::new` may swap the operands it was
/// given before storing them here.
pub struct IntersectionSet {
    /// The side that is iterated; the result follows its order.
    lhs: NameSet,
    /// The side used as a membership filter.
    rhs: NameSet,
    /// Hints describing the result, derived from both sides in `new`.
    hints: Hints,
}
32
33struct Iter {
34 iter: BoxVertexStream,
35 rhs: NameSet,
36 ended: bool,
37
38 stop_condition: Option<StopCondition>,
40}
41
42impl Iter {
43 async fn next(&mut self) -> Option<Result<VertexName>> {
44 if self.ended {
45 return None;
46 }
47 loop {
48 let result = self.iter.as_mut().next().await;
49 if let Some(Ok(ref name)) = result {
50 match self.rhs.contains(&name).await {
51 Err(err) => break Some(Err(err)),
52 Ok(false) => {
53 if let Some(ref cond) = self.stop_condition {
55 if let Some(id_convert) = self.rhs.id_convert() {
56 if let Ok(Some(id)) = id_convert.vertex_id_optional(&name).await {
57 if cond.should_stop_with_id(id) {
58 self.ended = true;
59 return None;
60 }
61 }
62 }
63 }
64 continue;
65 }
66 Ok(true) => {}
67 }
68 }
69 break result;
70 }
71 }
72
73 fn into_stream(self) -> BoxVertexStream {
74 Box::pin(futures::stream::unfold(self, |mut state| async move {
75 let result = state.next().await;
76 result.map(|r| (r, state))
77 }))
78 }
79}
80
81struct StopCondition {
82 order: Ordering,
83 id: Id,
84}
85
86impl StopCondition {
87 fn should_stop_with_id(&self, id: Id) -> bool {
88 id.cmp(&self.id) == self.order
89 }
90}
91
92impl IntersectionSet {
93 pub fn new(lhs: NameSet, rhs: NameSet) -> Self {
94 let (lhs, rhs) = if lhs.hints().contains(Flags::FULL)
96 && !rhs.hints().contains(Flags::FULL)
97 && !rhs.hints().contains(Flags::FILTER)
98 && lhs.hints().dag_version() >= rhs.hints().dag_version()
99 {
100 (rhs, lhs)
101 } else {
102 (lhs, rhs)
103 };
104
105 let hints = Hints::new_inherit_idmap_dag(lhs.hints());
106 hints.add_flags(
107 lhs.hints().flags()
108 & (Flags::EMPTY
109 | Flags::ID_DESC
110 | Flags::ID_ASC
111 | Flags::TOPO_DESC
112 | Flags::FILTER),
113 );
114 if lhs.hints().dag_version() >= rhs.hints().dag_version() {
116 hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
117 }
118 let (rhs_min_id, rhs_max_id) = if hints.id_map_version() >= rhs.hints().id_map_version() {
119 (rhs.hints().min_id(), rhs.hints().max_id())
121 } else {
122 (None, None)
123 };
124 match (lhs.hints().min_id(), rhs_min_id) {
125 (Some(id), None) | (None, Some(id)) => {
126 hints.set_min_id(id);
127 }
128 (Some(id1), Some(id2)) => {
129 hints.set_min_id(id1.max(id2));
130 }
131 (None, None) => {}
132 }
133 match (lhs.hints().max_id(), rhs_max_id) {
134 (Some(id), None) | (None, Some(id)) => {
135 hints.set_max_id(id);
136 }
137 (Some(id1), Some(id2)) => {
138 hints.set_max_id(id1.min(id2));
139 }
140 (None, None) => {}
141 }
142 Self { lhs, rhs, hints }
143 }
144
145 fn is_rhs_id_map_comapatible(&self) -> bool {
146 let lhs_version = self.lhs.hints().id_map_version();
147 let rhs_version = self.rhs.hints().id_map_version();
148 lhs_version == rhs_version || (lhs_version > rhs_version && rhs_version > None)
149 }
150}
151
152#[async_trait::async_trait]
153impl AsyncNameSetQuery for IntersectionSet {
154 async fn iter(&self) -> Result<BoxVertexStream> {
155 let stop_condition = if !self.is_rhs_id_map_comapatible() {
156 None
157 } else if self.lhs.hints().contains(Flags::ID_ASC) {
158 if let Some(id) = self.rhs.hints().max_id() {
159 Some(StopCondition {
160 id,
161 order: Ordering::Greater,
162 })
163 } else {
164 None
165 }
166 } else if self.lhs.hints().contains(Flags::ID_DESC) {
167 if let Some(id) = self.rhs.hints().min_id() {
168 Some(StopCondition {
169 id,
170 order: Ordering::Less,
171 })
172 } else {
173 None
174 }
175 } else {
176 None
177 };
178
179 let iter = Iter {
180 iter: self.lhs.iter().await?,
181 rhs: self.rhs.clone(),
182 ended: false,
183 stop_condition,
184 };
185 Ok(iter.into_stream())
186 }
187
188 async fn iter_rev(&self) -> Result<BoxVertexStream> {
189 let stop_condition = if !self.is_rhs_id_map_comapatible() {
190 None
191 } else if self.lhs.hints().contains(Flags::ID_DESC) {
192 if let Some(id) = self.rhs.hints().max_id() {
193 Some(StopCondition {
194 id,
195 order: Ordering::Greater,
196 })
197 } else {
198 None
199 }
200 } else if self.lhs.hints().contains(Flags::ID_ASC) {
201 if let Some(id) = self.rhs.hints().min_id() {
202 Some(StopCondition {
203 id,
204 order: Ordering::Less,
205 })
206 } else {
207 None
208 }
209 } else {
210 None
211 };
212
213 let iter = Iter {
214 iter: self.lhs.iter_rev().await?,
215 rhs: self.rhs.clone(),
216 ended: false,
217 stop_condition,
218 };
219 Ok(iter.into_stream())
220 }
221
222 async fn contains(&self, name: &VertexName) -> Result<bool> {
223 Ok(self.lhs.contains(name).await? && self.rhs.contains(name).await?)
224 }
225
226 async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
227 for set in &[&self.lhs, &self.rhs] {
228 let contains = set.contains_fast(name).await?;
229 match contains {
230 Some(false) | None => return Ok(contains),
231 Some(true) => {}
232 }
233 }
234 Ok(Some(true))
235 }
236
237 fn as_any(&self) -> &dyn Any {
238 self
239 }
240
241 fn hints(&self) -> &Hints {
242 &self.hints
243 }
244}
245
246impl fmt::Debug for IntersectionSet {
247 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248 write!(f, "<and")?;
249 write_debug(f, &self.lhs)?;
250 write_debug(f, &self.rhs)?;
251 write!(f, ">")
252 }
253}
254
#[cfg(test)]
// `b.clone()` is passed to `IntersectionSet::new` even on its last use in
// `test_intersection_min_max_id_fast_path` (b is not used afterwards), which
// this lint would flag; presumably kept for symmetry between the three cases.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::collections::HashSet;

    use super::super::id_lazy::test_utils::lazy_set;
    use super::super::id_lazy::test_utils::lazy_set_inherit;
    use super::super::tests::*;
    use super::*;
    use crate::Id;

    /// Builds an `IntersectionSet` from two byte lists, each wrapped as a
    /// `VecQuery`-backed `NameSet`.
    fn intersection(a: &[u8], b: &[u8]) -> IntersectionSet {
        let a = NameSet::from_query(VecQuery::from_bytes(a));
        let b = NameSet::from_query(VecQuery::from_bytes(b));
        IntersectionSet::new(a, b)
    }

    /// Results follow the order of the first set; vertices present in only one
    /// side are excluded.
    #[test]
    fn test_intersection_basic() -> Result<()> {
        let set = intersection(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "44"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["44", "33"]);
        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count())?, 2);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "33");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "44");
        for &b in b"\x11\x22\x55\x66".iter() {
            assert!(!nb(set.contains(&to_name(b)))?);
        }
        Ok(())
    }

    /// `b`'s min/max id hints (0x40..=0x50) do not cover 0x70 and 0x20, which
    /// `b` does contain. The expected sequences therefore show where the
    /// id-range early exit cut iteration short, compared with the last case
    /// where no shortcut applies.
    #[test]
    fn test_intersection_min_max_id_fast_path() {
        // lhs sorted by descending id.
        let a = lazy_set(&[0x70, 0x60, 0x50, 0x40, 0x30, 0x20]);
        let b = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &a);
        let a = NameSet::from_query(a);
        let b = NameSet::from_query(b);
        a.hints().add_flags(Flags::ID_DESC);
        b.hints().set_min_id(Id(0x40));
        b.hints().set_max_id(Id(0x50));

        let set = IntersectionSet::new(a, b.clone());
        // Forward: stops at 0x30 (absent from b, below min 0x40), so 0x20 is never reached.
        assert_eq!(shorten_iter(ni(set.iter())), ["70", "50", "40"]);
        // Reverse: stops at 0x60 (absent from b, above max 0x50), so 0x70 is never reached.
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["20", "40", "50"]);

        // Same data, lhs sorted by ascending id.
        let a = lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]);
        let b = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &a);
        let a = NameSet::from_query(a);
        let b = NameSet::from_query(b);
        a.hints().add_flags(Flags::ID_ASC);
        b.hints().set_min_id(Id(0x40));
        b.hints().set_max_id(Id(0x50));
        let set = IntersectionSet::new(a, b.clone());
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40"]);

        // A freshly built `a` presumably has an id map unrelated to `b`'s, so the
        // id-range shortcut must not apply and all common vertices are returned.
        let a = NameSet::from_query(lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]));
        a.hints().add_flags(Flags::ID_ASC);
        let set = IntersectionSet::new(a, b.clone());
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50", "70"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40", "20"]);
    }

    quickcheck::quickcheck! {
        // Property: the intersection is no larger than either input, and
        // membership agrees whether probed with the elements of `a` or of `b`.
        fn test_intersection_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = intersection(&a, &b);
            check_invariants(&set).unwrap();

            let count = nb(set.count()).unwrap();
            assert!(count <= a.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &a);
            assert!(count <= b.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &b);

            let contains_a: HashSet<u8> = a.into_iter().filter(|&b| nb(set.contains(&to_name(b))).ok() == Some(true)).collect();
            let contains_b: HashSet<u8> = b.into_iter().filter(|&b| nb(set.contains(&to_name(b))).ok() == Some(true)).collect();
            assert_eq!(contains_a, contains_b);

            true
        }
    }
}