dag/nameset/
difference.rs1use std::any::Any;
9use std::fmt;
10
11use futures::StreamExt;
12
13use super::hints::Flags;
14use super::AsyncNameSetQuery;
15use super::BoxVertexStream;
16use super::Hints;
17use super::NameSet;
18use crate::fmt::write_debug;
19use crate::Result;
20use crate::VertexName;
21
22pub struct DifferenceSet {
26 lhs: NameSet,
27 rhs: NameSet,
28 hints: Hints,
29}
30
31struct Iter {
32 iter: BoxVertexStream,
33 rhs: NameSet,
34}
35
36impl DifferenceSet {
37 pub fn new(lhs: NameSet, rhs: NameSet) -> Self {
38 let hints = Hints::new_inherit_idmap_dag(lhs.hints());
39 hints.add_flags(
41 lhs.hints().flags()
42 & (Flags::EMPTY
43 | Flags::ID_DESC
44 | Flags::ID_ASC
45 | Flags::TOPO_DESC
46 | Flags::FILTER),
47 );
48 if let Some(id) = lhs.hints().min_id() {
49 hints.set_min_id(id);
50 }
51 if let Some(id) = lhs.hints().max_id() {
52 hints.set_max_id(id);
53 }
54 Self { lhs, rhs, hints }
55 }
56}
57
58#[async_trait::async_trait]
59impl AsyncNameSetQuery for DifferenceSet {
60 async fn iter(&self) -> Result<BoxVertexStream> {
61 let iter = Iter {
62 iter: self.lhs.iter().await?,
63 rhs: self.rhs.clone(),
64 };
65 Ok(iter.into_stream())
66 }
67
68 async fn iter_rev(&self) -> Result<BoxVertexStream> {
69 let iter = Iter {
70 iter: self.lhs.iter_rev().await?,
71 rhs: self.rhs.clone(),
72 };
73 Ok(iter.into_stream())
74 }
75
76 async fn contains(&self, name: &VertexName) -> Result<bool> {
77 Ok(self.lhs.contains(name).await? && !self.rhs.contains(name).await?)
78 }
79
80 async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
81 let lhs_contains = self.lhs.contains_fast(name).await?;
82 if lhs_contains == Some(false) {
83 return Ok(Some(false));
84 }
85 let rhs_contains = self.rhs.contains_fast(name).await?;
86 let result = match (lhs_contains, rhs_contains) {
87 (Some(true), Some(false)) => Some(true),
88 (_, Some(true)) | (Some(false), _) => Some(false),
89 (Some(true), None) | (None, _) => None,
90 };
91 Ok(result)
92 }
93
94 fn as_any(&self) -> &dyn Any {
95 self
96 }
97
98 fn hints(&self) -> &Hints {
99 &self.hints
100 }
101}
102
103impl fmt::Debug for DifferenceSet {
104 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105 write!(f, "<diff")?;
106 write_debug(f, &self.lhs)?;
107 write_debug(f, &self.rhs)?;
108 write!(f, ">")
109 }
110}
111
112impl Iter {
113 async fn next(&mut self) -> Option<Result<VertexName>> {
114 loop {
115 let result = self.iter.as_mut().next().await;
116 if let Some(Ok(ref name)) = result {
117 match self.rhs.contains(&name).await {
118 Err(err) => break Some(Err(err)),
119 Ok(true) => continue,
120 _ => {}
121 }
122 }
123 break result;
124 }
125 }
126
127 fn into_stream(self) -> BoxVertexStream {
128 Box::pin(futures::stream::unfold(self, |mut state| async move {
129 let result = state.next().await;
130 result.map(|r| (r, state))
131 }))
132 }
133}
134
#[cfg(test)]
mod tests {
    use nonblocking::non_blocking as nb;

    use super::super::tests::*;
    use super::*;

    /// Builds the difference `a - b`, with each side created through
    /// `VecQuery::from_bytes`.
    fn difference(a: &[u8], b: &[u8]) -> DifferenceSet {
        let lhs = NameSet::from_query(VecQuery::from_bytes(a));
        let rhs = NameSet::from_query(VecQuery::from_bytes(b));
        DifferenceSet::new(lhs, rhs)
    }

    /// Fixed example: order follows the left side, shared names are removed.
    #[test]
    fn test_difference_basic() -> Result<()> {
        let set = difference(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;

        // Iteration in both directions.
        assert_eq!(shorten_iter(ni(set.iter())), ["11", "55", "22"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["22", "55", "11"]);

        // Size and boundary elements.
        assert!(!nb(set.is_empty())??);
        assert_eq!(nb(set.count())??, 3);
        assert_eq!(shorten_name(nb(set.first())??.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())??.unwrap()), "22");

        // Membership: names only on the left are kept, the others are not.
        for present in b"\x11\x22\x55".iter().copied() {
            assert!(nb(set.contains(&to_name(present)))??);
        }
        for absent in b"\x33\x44\x66".iter().copied() {
            assert!(!nb(set.contains(&to_name(absent)))??);
        }
        Ok(())
    }

    quickcheck::quickcheck! {
        fn test_difference_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = difference(&a, &b);
            check_invariants(&set).unwrap();

            // The difference never has more elements than the left side.
            let count = nb(set.count()).unwrap().unwrap();
            assert!(count <= a.len());

            // Nothing from the right side may be reported as contained.
            let excluded =
                |byte: u8| nb(set.contains(&to_name(byte))).unwrap().ok() == Some(false);
            assert!(b.iter().copied().all(excluded));

            true
        }
    }
}