dag/nameset/
union.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::fmt;
10
11use futures::StreamExt;
12
13use super::hints::Flags;
14use super::AsyncNameSetQuery;
15use super::BoxVertexStream;
16use super::Hints;
17use super::NameSet;
18use crate::fmt::write_debug;
19use crate::Result;
20use crate::VertexName;
21
22/// Union of 2 sets.
23///
24/// The order is preserved. The first set is iterated first, then the second set
25/// is iterated, with duplicated names skipped.
26pub struct UnionSet {
27    sets: [NameSet; 2],
28    hints: Hints,
29}
30
31impl UnionSet {
32    pub fn new(lhs: NameSet, rhs: NameSet) -> Self {
33        let hints = Hints::union(&[lhs.hints(), rhs.hints()]);
34        if hints.id_map().is_some() {
35            if let (Some(id1), Some(id2)) = (lhs.hints().min_id(), rhs.hints().min_id()) {
36                hints.set_min_id(id1.min(id2));
37            }
38            if let (Some(id1), Some(id2)) = (lhs.hints().max_id(), rhs.hints().max_id()) {
39                hints.set_max_id(id1.max(id2));
40            }
41        };
42        hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
43        if lhs.hints().contains(Flags::FILTER) || rhs.hints().contains(Flags::FILTER) {
44            hints.add_flags(Flags::FILTER);
45        }
46        Self {
47            sets: [lhs, rhs],
48            hints,
49        }
50    }
51}
52
53#[async_trait::async_trait]
54impl AsyncNameSetQuery for UnionSet {
55    async fn iter(&self) -> Result<BoxVertexStream> {
56        debug_assert_eq!(self.sets.len(), 2);
57        let diff = self.sets[1].clone() - self.sets[0].clone();
58        let diff_iter = diff.iter().await?;
59        let set0_iter = self.sets[0].iter().await?;
60        let iter = set0_iter.chain(diff_iter);
61        Ok(Box::pin(iter))
62    }
63
64    async fn iter_rev(&self) -> Result<BoxVertexStream> {
65        debug_assert_eq!(self.sets.len(), 2);
66        let diff = self.sets[1].clone() - self.sets[0].clone();
67        let diff_iter = diff.iter_rev().await?;
68        let set0_iter = self.sets[0].iter_rev().await?;
69        let iter = diff_iter.chain(set0_iter);
70        Ok(Box::pin(iter))
71    }
72
73    async fn count(&self) -> Result<usize> {
74        debug_assert_eq!(self.sets.len(), 2);
75        // This is more efficient if sets[0] is a large set that has a fast path
76        // for "count()".
77        let mut count = self.sets[0].count().await?;
78        let mut iter = self.sets[1].iter().await?;
79        while let Some(item) = iter.next().await {
80            let name = item?;
81            if !self.sets[0].contains(&name).await? {
82                count += 1;
83            }
84        }
85        Ok(count)
86    }
87
88    async fn is_empty(&self) -> Result<bool> {
89        for set in &self.sets {
90            if !set.is_empty().await? {
91                return Ok(false);
92            }
93        }
94        Ok(true)
95    }
96
97    async fn contains(&self, name: &VertexName) -> Result<bool> {
98        for set in &self.sets {
99            if set.contains(name).await? {
100                return Ok(true);
101            }
102        }
103        Ok(false)
104    }
105
106    async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
107        for set in &self.sets {
108            if let Some(result) = set.contains_fast(name).await? {
109                return Ok(Some(result));
110            }
111        }
112        Ok(None)
113    }
114
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118
119    fn hints(&self) -> &Hints {
120        &self.hints
121    }
122}
123
124impl fmt::Debug for UnionSet {
125    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126        write!(f, "<or")?;
127        write_debug(f, &self.sets[0])?;
128        write_debug(f, &self.sets[1])?;
129        write!(f, ">")
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::tests::*;
    use super::*;

    /// Builds a `UnionSet` whose operands are `VecQuery` sets made from the
    /// given byte slices.
    fn union(a: &[u8], b: &[u8]) -> UnionSet {
        UnionSet::new(
            NameSet::from_query(VecQuery::from_bytes(a)),
            NameSet::from_query(VecQuery::from_bytes(b)),
        )
    }

    #[test]
    fn test_union_basic() -> Result<()> {
        // The two operands overlap; the union must yield each name once.
        let set = union(b"\x11\x33\x22", b"\x44\x11\x55\x33");
        check_invariants(&set)?;

        let forward = shorten_iter(ni(set.iter()));
        assert_eq!(forward, ["11", "33", "22", "44", "55"]);
        let backward = shorten_iter(ni(set.iter_rev()));
        assert_eq!(backward, ["55", "44", "22", "33", "11"]);

        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count())?, 5);

        let first = nb(set.first())?.unwrap();
        assert_eq!(shorten_name(first), "11");
        let last = nb(set.last())?.unwrap();
        assert_eq!(shorten_name(last), "55");

        for &present in b"\x11\x22\x33\x44\x55".iter() {
            assert!(nb(set.contains(&to_name(present)))?);
        }
        for &absent in b"\x66\x77\x88".iter() {
            assert!(!nb(set.contains(&to_name(absent)))?);
        }
        Ok(())
    }

    quickcheck::quickcheck! {
        fn test_union_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = union(&a, &b);
            check_invariants(&set).unwrap();

            // The union can never be larger than both inputs put together.
            let count = nb(set.count()).unwrap();
            assert!(count <= a.len() + b.len());

            // The count must equal the number of distinct input bytes.
            let distinct: HashSet<_> = a.iter().chain(b.iter()).cloned().collect();
            assert_eq!(count, distinct.len());

            // Every input byte must be reported as a member.
            let has = |byte: u8| nb(set.contains(&to_name(byte))).ok() == Some(true);
            assert!(a.iter().all(|&byte| has(byte)));
            assert!(b.iter().all(|&byte| has(byte)));

            true
        }
    }
}