dag/nameset/
intersection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::cmp::Ordering;
10use std::fmt;
11
12use futures::StreamExt;
13
14use super::hints::Flags;
15use super::AsyncNameSetQuery;
16use super::BoxVertexStream;
17use super::Hints;
18use super::NameSet;
19use crate::fmt::write_debug;
20use crate::Id;
21use crate::Result;
22use crate::VertexName;
23
/// Intersection of 2 sets.
///
/// The iteration order is defined by the first set.
///
/// Note that [`IntersectionSet::new`] may swap the two sets it is given
/// (when the first one is flagged `FULL`), so "first set" here refers to
/// the `lhs` field after that optional swap.
pub struct IntersectionSet {
    /// The set that is iterated; every vertex it yields is a candidate.
    lhs: NameSet,
    /// The set used as a membership filter on the candidates from `lhs`.
    rhs: NameSet,
    /// Hints for the intersection, computed once in `new` from the hints of
    /// `lhs` and `rhs`.
    hints: Hints,
}
32
/// Iteration state for an `IntersectionSet`: walks a stream of candidate
/// vertexes and keeps only those that `rhs` contains.
struct Iter {
    /// Stream of candidate vertexes (obtained from `lhs.iter()` or
    /// `lhs.iter_rev()` by the callers in this file).
    iter: BoxVertexStream,
    /// Set each candidate is tested against.
    rhs: NameSet,
    /// Set to `true` once the stop condition fired; `next` then keeps
    /// returning `None` without polling `iter` again.
    ended: bool,

    /// Optional fast path for stop.
    stop_condition: Option<StopCondition>,
}
41
42impl Iter {
43    async fn next(&mut self) -> Option<Result<VertexName>> {
44        if self.ended {
45            return None;
46        }
47        loop {
48            let result = self.iter.as_mut().next().await;
49            if let Some(Ok(ref name)) = result {
50                match self.rhs.contains(&name).await {
51                    Err(err) => break Some(Err(err)),
52                    Ok(false) => {
53                        // Check if we can stop iteration early using hints.
54                        if let Some(ref cond) = self.stop_condition {
55                            if let Some(id_convert) = self.rhs.id_convert() {
56                                if let Ok(Some(id)) = id_convert.vertex_id_optional(&name).await {
57                                    if cond.should_stop_with_id(id) {
58                                        self.ended = true;
59                                        return None;
60                                    }
61                                }
62                            }
63                        }
64                        continue;
65                    }
66                    Ok(true) => {}
67                }
68            }
69            break result;
70        }
71    }
72
73    fn into_stream(self) -> BoxVertexStream {
74        Box::pin(futures::stream::unfold(self, |mut state| async move {
75            let result = state.next().await;
76            result.map(|r| (r, state))
77        }))
78    }
79}
80
/// Describes when iteration can stop early: as soon as a rejected candidate's
/// id compares to `id` with the result `order`.
struct StopCondition {
    /// The comparison result (candidate id vs. `id`) that triggers the stop.
    order: Ordering,
    /// The boundary id that candidate ids are compared against.
    id: Id,
}
85
86impl StopCondition {
87    fn should_stop_with_id(&self, id: Id) -> bool {
88        id.cmp(&self.id) == self.order
89    }
90}
91
92impl IntersectionSet {
93    pub fn new(lhs: NameSet, rhs: NameSet) -> Self {
94        // More efficient if `lhs` is smaller. Swap `lhs` and `rhs` if `lhs` is `FULL`.
95        let (lhs, rhs) = if lhs.hints().contains(Flags::FULL)
96            && !rhs.hints().contains(Flags::FULL)
97            && !rhs.hints().contains(Flags::FILTER)
98            && lhs.hints().dag_version() >= rhs.hints().dag_version()
99        {
100            (rhs, lhs)
101        } else {
102            (lhs, rhs)
103        };
104
105        let hints = Hints::new_inherit_idmap_dag(lhs.hints());
106        hints.add_flags(
107            lhs.hints().flags()
108                & (Flags::EMPTY
109                    | Flags::ID_DESC
110                    | Flags::ID_ASC
111                    | Flags::TOPO_DESC
112                    | Flags::FILTER),
113        );
114        // Only keep the ANCESTORS flag if lhs and rhs use a compatible Dag.
115        if lhs.hints().dag_version() >= rhs.hints().dag_version() {
116            hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
117        }
118        let (rhs_min_id, rhs_max_id) = if hints.id_map_version() >= rhs.hints().id_map_version() {
119            // rhs ids are all known by lhs.
120            (rhs.hints().min_id(), rhs.hints().max_id())
121        } else {
122            (None, None)
123        };
124        match (lhs.hints().min_id(), rhs_min_id) {
125            (Some(id), None) | (None, Some(id)) => {
126                hints.set_min_id(id);
127            }
128            (Some(id1), Some(id2)) => {
129                hints.set_min_id(id1.max(id2));
130            }
131            (None, None) => {}
132        }
133        match (lhs.hints().max_id(), rhs_max_id) {
134            (Some(id), None) | (None, Some(id)) => {
135                hints.set_max_id(id);
136            }
137            (Some(id1), Some(id2)) => {
138                hints.set_max_id(id1.min(id2));
139            }
140            (None, None) => {}
141        }
142        Self { lhs, rhs, hints }
143    }
144
145    fn is_rhs_id_map_comapatible(&self) -> bool {
146        let lhs_version = self.lhs.hints().id_map_version();
147        let rhs_version = self.rhs.hints().id_map_version();
148        lhs_version == rhs_version || (lhs_version > rhs_version && rhs_version > None)
149    }
150}
151
152#[async_trait::async_trait]
153impl AsyncNameSetQuery for IntersectionSet {
154    async fn iter(&self) -> Result<BoxVertexStream> {
155        let stop_condition = if !self.is_rhs_id_map_comapatible() {
156            None
157        } else if self.lhs.hints().contains(Flags::ID_ASC) {
158            if let Some(id) = self.rhs.hints().max_id() {
159                Some(StopCondition {
160                    id,
161                    order: Ordering::Greater,
162                })
163            } else {
164                None
165            }
166        } else if self.lhs.hints().contains(Flags::ID_DESC) {
167            if let Some(id) = self.rhs.hints().min_id() {
168                Some(StopCondition {
169                    id,
170                    order: Ordering::Less,
171                })
172            } else {
173                None
174            }
175        } else {
176            None
177        };
178
179        let iter = Iter {
180            iter: self.lhs.iter().await?,
181            rhs: self.rhs.clone(),
182            ended: false,
183            stop_condition,
184        };
185        Ok(iter.into_stream())
186    }
187
188    async fn iter_rev(&self) -> Result<BoxVertexStream> {
189        let stop_condition = if !self.is_rhs_id_map_comapatible() {
190            None
191        } else if self.lhs.hints().contains(Flags::ID_DESC) {
192            if let Some(id) = self.rhs.hints().max_id() {
193                Some(StopCondition {
194                    id,
195                    order: Ordering::Greater,
196                })
197            } else {
198                None
199            }
200        } else if self.lhs.hints().contains(Flags::ID_ASC) {
201            if let Some(id) = self.rhs.hints().min_id() {
202                Some(StopCondition {
203                    id,
204                    order: Ordering::Less,
205                })
206            } else {
207                None
208            }
209        } else {
210            None
211        };
212
213        let iter = Iter {
214            iter: self.lhs.iter_rev().await?,
215            rhs: self.rhs.clone(),
216            ended: false,
217            stop_condition,
218        };
219        Ok(iter.into_stream())
220    }
221
222    async fn contains(&self, name: &VertexName) -> Result<bool> {
223        Ok(self.lhs.contains(name).await? && self.rhs.contains(name).await?)
224    }
225
226    async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
227        for set in &[&self.lhs, &self.rhs] {
228            let contains = set.contains_fast(name).await?;
229            match contains {
230                Some(false) | None => return Ok(contains),
231                Some(true) => {}
232            }
233        }
234        Ok(Some(true))
235    }
236
237    fn as_any(&self) -> &dyn Any {
238        self
239    }
240
241    fn hints(&self) -> &Hints {
242        &self.hints
243    }
244}
245
246impl fmt::Debug for IntersectionSet {
247    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248        write!(f, "<and")?;
249        write_debug(f, &self.lhs)?;
250        write_debug(f, &self.rhs)?;
251        write!(f, ">")
252    }
253}
254
#[cfg(test)]
// Several tests pass `b.clone()` even where `b` is not used again afterwards,
// which this lint would otherwise flag.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::collections::HashSet;

    use super::super::id_lazy::test_utils::lazy_set;
    use super::super::id_lazy::test_utils::lazy_set_inherit;
    use super::super::tests::*;
    use super::*;
    use crate::Id;

    /// Build an `IntersectionSet` from two byte slices, each turned into a
    /// set via `VecQuery::from_bytes`.
    fn intersection(a: &[u8], b: &[u8]) -> IntersectionSet {
        let a = NameSet::from_query(VecQuery::from_bytes(a));
        let b = NameSet::from_query(VecQuery::from_bytes(b));
        IntersectionSet::new(a, b)
    }

    /// Iteration follows the order of the first set; only common elements
    /// are yielded, and elements missing from either side are not contained.
    #[test]
    fn test_intersection_basic() -> Result<()> {
        let set = intersection(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "44"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["44", "33"]);
        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count())?, 2);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "33");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "44");
        for &b in b"\x11\x22\x55\x66".iter() {
            assert!(!nb(set.contains(&to_name(b)))?);
        }
        Ok(())
    }

    /// The early-stop fast path is observable here because the min/max id
    /// hints set on `b` are deliberately narrower than its real contents.
    #[test]
    fn test_intersection_min_max_id_fast_path() {
        // The min_ids are intentionally wrong to test the fast paths.
        let a = lazy_set(&[0x70, 0x60, 0x50, 0x40, 0x30, 0x20]);
        let b = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &a);
        let a = NameSet::from_query(a);
        let b = NameSet::from_query(b);
        a.hints().add_flags(Flags::ID_DESC);
        b.hints().set_min_id(Id(0x40));
        b.hints().set_max_id(Id(0x50));

        let set = IntersectionSet::new(a, b.clone());
        // No "20" - filtered out by min id fast path.
        assert_eq!(shorten_iter(ni(set.iter())), ["70", "50", "40"]);
        // No "70" - filtered out by max id fast path.
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["20", "40", "50"]);

        // Test the reversed sort order.
        let a = lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]);
        let b = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &a);
        let a = NameSet::from_query(a);
        let b = NameSet::from_query(b);
        a.hints().add_flags(Flags::ID_ASC);
        b.hints().set_min_id(Id(0x40));
        b.hints().set_max_id(Id(0x50));
        let set = IntersectionSet::new(a, b.clone());
        // No "70".
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50"]);
        // No "20".
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40"]);

        // If two sets have incompatible IdMap, fast paths are not used.
        // (`a` is rebuilt from a fresh `lazy_set`, while `b` still inherits
        // from the previous one.)
        let a = NameSet::from_query(lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]));
        a.hints().add_flags(Flags::ID_ASC);
        let set = IntersectionSet::new(a, b.clone());
        // Should contain "70" and "20".
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50", "70"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40", "20"]);
    }

    quickcheck::quickcheck! {
        // Property test over arbitrary byte vectors: the intersection is no
        // larger than either input, and the elements of `a` it contains are
        // the same as the elements of `b` it contains.
        fn test_intersection_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = intersection(&a, &b);
            check_invariants(&set).unwrap();

            let count = nb(set.count()).unwrap();
            assert!(count <= a.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &a);
            assert!(count <= b.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &b);

            let contains_a: HashSet<u8> = a.into_iter().filter(|&b| nb(set.contains(&to_name(b))).ok() == Some(true)).collect();
            let contains_b: HashSet<u8> = b.into_iter().filter(|&b| nb(set.contains(&to_name(b))).ok() == Some(true)).collect();
            assert_eq!(contains_a, contains_b);

            true
        }
    }
}