dag/nameset/
difference.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::fmt;
10
11use futures::StreamExt;
12
13use super::hints::Flags;
14use super::AsyncNameSetQuery;
15use super::BoxVertexStream;
16use super::Hints;
17use super::NameSet;
18use crate::fmt::write_debug;
19use crate::Result;
20use crate::VertexName;
21
22/// Subset of `lhs` that does not overlap with `rhs`.
23///
24/// The iteration order is defined by `lhs`.
25pub struct DifferenceSet {
26    lhs: NameSet,
27    rhs: NameSet,
28    hints: Hints,
29}
30
31struct Iter {
32    iter: BoxVertexStream,
33    rhs: NameSet,
34}
35
36impl DifferenceSet {
37    pub fn new(lhs: NameSet, rhs: NameSet) -> Self {
38        let hints = Hints::new_inherit_idmap_dag(lhs.hints());
39        // Inherit flags, min/max Ids from lhs.
40        hints.add_flags(
41            lhs.hints().flags()
42                & (Flags::EMPTY
43                    | Flags::ID_DESC
44                    | Flags::ID_ASC
45                    | Flags::TOPO_DESC
46                    | Flags::FILTER),
47        );
48        if let Some(id) = lhs.hints().min_id() {
49            hints.set_min_id(id);
50        }
51        if let Some(id) = lhs.hints().max_id() {
52            hints.set_max_id(id);
53        }
54        Self { lhs, rhs, hints }
55    }
56}
57
58#[async_trait::async_trait]
59impl AsyncNameSetQuery for DifferenceSet {
60    async fn iter(&self) -> Result<BoxVertexStream> {
61        let iter = Iter {
62            iter: self.lhs.iter().await?,
63            rhs: self.rhs.clone(),
64        };
65        Ok(iter.into_stream())
66    }
67
68    async fn iter_rev(&self) -> Result<BoxVertexStream> {
69        let iter = Iter {
70            iter: self.lhs.iter_rev().await?,
71            rhs: self.rhs.clone(),
72        };
73        Ok(iter.into_stream())
74    }
75
76    async fn contains(&self, name: &VertexName) -> Result<bool> {
77        Ok(self.lhs.contains(name).await? && !self.rhs.contains(name).await?)
78    }
79
80    async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
81        let lhs_contains = self.lhs.contains_fast(name).await?;
82        if lhs_contains == Some(false) {
83            return Ok(Some(false));
84        }
85        let rhs_contains = self.rhs.contains_fast(name).await?;
86        let result = match (lhs_contains, rhs_contains) {
87            (Some(true), Some(false)) => Some(true),
88            (_, Some(true)) | (Some(false), _) => Some(false),
89            (Some(true), None) | (None, _) => None,
90        };
91        Ok(result)
92    }
93
94    fn as_any(&self) -> &dyn Any {
95        self
96    }
97
98    fn hints(&self) -> &Hints {
99        &self.hints
100    }
101}
102
103impl fmt::Debug for DifferenceSet {
104    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105        write!(f, "<diff")?;
106        write_debug(f, &self.lhs)?;
107        write_debug(f, &self.rhs)?;
108        write!(f, ">")
109    }
110}
111
112impl Iter {
113    async fn next(&mut self) -> Option<Result<VertexName>> {
114        loop {
115            let result = self.iter.as_mut().next().await;
116            if let Some(Ok(ref name)) = result {
117                match self.rhs.contains(&name).await {
118                    Err(err) => break Some(Err(err)),
119                    Ok(true) => continue,
120                    _ => {}
121                }
122            }
123            break result;
124        }
125    }
126
127    fn into_stream(self) -> BoxVertexStream {
128        Box::pin(futures::stream::unfold(self, |mut state| async move {
129            let result = state.next().await;
130            result.map(|r| (r, state))
131        }))
132    }
133}
134
#[cfg(test)]
mod tests {
    use nonblocking::non_blocking as nb;

    use super::super::tests::*;
    use super::*;

    /// Build a `DifferenceSet` from two sets created with `VecQuery::from_bytes`.
    fn difference(a: &[u8], b: &[u8]) -> DifferenceSet {
        DifferenceSet::new(
            NameSet::from_query(VecQuery::from_bytes(a)),
            NameSet::from_query(VecQuery::from_bytes(b)),
        )
    }

    /// Fixed example covering iteration order, counting, first/last and membership.
    #[test]
    fn test_difference_basic() -> Result<()> {
        let set = difference(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;

        // Both directions follow the order of the left-hand side.
        assert_eq!(shorten_iter(ni(set.iter())), ["11", "55", "22"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["22", "55", "11"]);

        assert!(!nb(set.is_empty())??);
        assert_eq!(nb(set.count())??, 3);
        assert_eq!(shorten_name(nb(set.first())??.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())??.unwrap()), "22");

        // Survivors of the subtraction.
        for &kept in &[0x11u8, 0x22, 0x55] {
            assert!(nb(set.contains(&to_name(kept)))??);
        }
        // Removed by the right-hand side, or never present on the left.
        for &absent in &[0x33u8, 0x44, 0x66] {
            assert!(!nb(set.contains(&to_name(absent)))??);
        }
        Ok(())
    }

    quickcheck::quickcheck! {
        fn test_difference_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = difference(&a, &b);
            check_invariants(&set).unwrap();

            // The difference can never be larger than its left operand.
            let count = nb(set.count()).unwrap().unwrap();
            assert!(count <= a.len());

            // Nothing from the right operand may remain in the result.
            for &byte in &b {
                let contained = nb(set.contains(&to_name(byte))).unwrap().ok();
                assert!(contained == Some(false));
            }

            true
        }
    }
}