dag/nameset/slice.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::collections::HashSet;
10use std::fmt;
11use std::sync::atomic::AtomicBool;
12use std::sync::atomic::Ordering::Acquire;
13use std::sync::atomic::Ordering::Release;
14use std::sync::Arc;
15
16use futures::lock::Mutex;
17use futures::StreamExt;
18use indexmap::IndexSet;
19use tracing::debug;
20use tracing::instrument;
21use tracing::trace;
22use tracing::Level;
23
24use super::hints::Flags;
25use super::AsyncNameSetQuery;
26use super::BoxVertexStream;
27use super::Hints;
28use super::NameSet;
29use crate::fmt::write_debug;
30use crate::Result;
31use crate::VertexName;
32
/// Slice of a set: skips the first `skip_count` vertexes of `inner`, then
/// takes up to `take_count` vertexes (`None` means no upper bound).
///
/// The caches below are behind `Arc`s, so clones of a `SliceSet` (including
/// the ones held by its iterators) share them: progress made by one iterator
/// is visible to every clone.
#[derive(Clone)]
pub struct SliceSet {
    // The original set being sliced.
    inner: NameSet,
    // Hints derived from `inner`'s hints in `SliceSet::new`.
    hints: Hints,
    // Number of leading vertexes of `inner` to skip.
    skip_count: u64,
    // Maximum number of vertexes to take after skipping. `None`: unbounded.
    take_count: Option<u64>,

    // Skipped vertexes. Updated during iteration. Only the first
    // `SKIP_CACHE_SIZE_THRESHOLD` skipped vertexes are recorded.
    skip_cache: Arc<Mutex<HashSet<VertexName>>>,
    // Taken vertexes, in iteration order. Updated during iteration; only
    // populated when `take_count` is set.
    take_cache: Arc<Mutex<IndexSet<VertexName>>>,
    // If take_cache is complete.
    take_cache_complete: Arc<AtomicBool>,
}
48
49impl SliceSet {
50    pub fn new(set: NameSet, skip_count: u64, take_count: Option<u64>) -> Self {
51        let hints = set.hints().clone();
52        hints.update_flags_with(|mut f| {
53            // Only keep compatible flags.
54            f &= Flags::ID_DESC
55                | Flags::ID_ASC
56                | Flags::TOPO_DESC
57                | Flags::HAS_MIN_ID
58                | Flags::HAS_MAX_ID
59                | Flags::EMPTY;
60            // Add EMPTY hints if take_count is 0.
61            if take_count == Some(0) {
62                f |= Flags::EMPTY;
63            }
64            f
65        });
66        Self {
67            inner: set,
68            hints,
69            skip_count,
70            take_count,
71            skip_cache: Default::default(),
72            take_cache: Default::default(),
73            take_cache_complete: Default::default(),
74        }
75    }
76
77    fn is_take_cache_complete(&self) -> bool {
78        self.take_cache_complete.load(Acquire)
79    }
80
81    async fn is_skip_cache_complete(&self) -> bool {
82        self.skip_cache.lock().await.len() as u64 == self.skip_count
83    }
84
85    #[instrument(level=Level::DEBUG)]
86    async fn populate_take_cache(&self) -> Result<()> {
87        // See Iter::next. If take_count is not set, the "take" can be unbounded,
88        // and take_cache won't be populated.
89        assert!(self.take_count.is_some());
90
91        // Use iter() to populate take_cache.
92        let mut iter = self.iter().await?;
93        while let Some(_) = iter.next().await {}
94        assert!(self.is_take_cache_complete());
95
96        Ok(())
97    }
98}
99
/// Forward iterator over a `SliceSet` that updates the set's shared caches
/// as it advances.
struct Iter {
    // Stream over the inner (unsliced) set.
    inner_iter: BoxVertexStream,
    // Clone of the slice; shares skip/take caches with the original.
    set: SliceSet,
    // Position in the inner set of the next vertex to read. Set to
    // `u64::MAX` when `inner_iter` yields an error.
    index: u64,
}
105
/// Upper bound on how many skipped vertexes are recorded in `skip_cache`.
/// Skips longer than this never yield a complete skip cache.
const SKIP_CACHE_SIZE_THRESHOLD: u64 = 1_000;
107
108impl Iter {
109    async fn next(&mut self) -> Option<Result<VertexName>> {
110        if self.set.is_take_cache_complete() {
111            // Fast path - no need to use inner_iter.
112            let index = self.index.max(self.set.skip_count);
113            let take_index = index - self.set.skip_count;
114            let result = {
115                let cache = self.set.take_cache.lock().await;
116                cache.get_index(take_index as _).cloned()
117            };
118            trace!("next(index={}) = {:?} (fast path)", index, &result);
119            self.index = index + 1;
120            return Ok(result).transpose();
121        }
122
123        loop {
124            // Slow path - use inner_iter.
125            let index = self.index;
126            trace!("next(index={})", index);
127            let next: Option<VertexName> = match self.inner_iter.next().await {
128                Some(Err(e)) => {
129                    self.index = u64::MAX;
130                    return Some(Err(e));
131                }
132                Some(Ok(v)) => Some(v),
133                None => None,
134            };
135            self.index += 1;
136
137            // Skip?
138            if index < self.set.skip_count {
139                if index < SKIP_CACHE_SIZE_THRESHOLD {
140                    // Update skip_cache.
141                    if let Some(v) = next.as_ref() {
142                        let mut cache = self.set.skip_cache.lock().await;
143                        cache.insert(v.clone());
144                    }
145                }
146                continue;
147            }
148
149            // Take?
150            let take_index = index - self.set.skip_count;
151            let should_take: bool = match self.set.take_count {
152                Some(count) => {
153                    if take_index < count {
154                        // Update take_cache.
155                        let mut cache = self.set.take_cache.lock().await;
156                        if take_index == cache.len() as u64 {
157                            if let Some(v) = next.as_ref() {
158                                cache.insert(v.clone());
159                            } else {
160                                // No more item in the original set.
161                                self.set.take_cache_complete.store(true, Release);
162                            }
163                        }
164                        true
165                    } else {
166                        self.set.take_cache_complete.store(true, Release);
167                        false
168                    }
169                }
170                None => {
171                    // Do not update take_cache, since the inner
172                    // set can be quite large.
173                    true
174                }
175            };
176            if should_take {
177                return next.map(Ok);
178            } else {
179                return None;
180            }
181        }
182    }
183
184    fn into_stream(self) -> BoxVertexStream {
185        Box::pin(futures::stream::unfold(self, |mut state| async move {
186            let result = state.next().await;
187            result.map(|r| (r, state))
188        }))
189    }
190}
191
/// Iterates a `SliceSet`'s `take_cache` from the last vertex to the first.
struct TakeCacheRevIter {
    // The take_cache shared with the `SliceSet` it came from.
    take_cache: Arc<Mutex<IndexSet<VertexName>>>,
    // Offset, counted from the end of the cache, of the next vertex to yield.
    index: usize,
}
196
197impl TakeCacheRevIter {
198    async fn next(&mut self) -> Option<Result<VertexName>> {
199        let index = self.index;
200        self.index += 1;
201        let cache = self.take_cache.lock().await;
202        if index >= cache.len() {
203            None
204        } else {
205            let index = cache.len() - index - 1;
206            cache.get_index(index).cloned().map(Ok)
207        }
208    }
209
210    fn into_stream(self) -> BoxVertexStream {
211        Box::pin(futures::stream::unfold(self, |mut state| async move {
212            let result = state.next().await;
213            result.map(|r| (r, state))
214        }))
215    }
216}
217
218#[async_trait::async_trait]
219impl AsyncNameSetQuery for SliceSet {
220    async fn iter(&self) -> Result<BoxVertexStream> {
221        let inner_iter = self.inner.iter().await?;
222        let iter = Iter {
223            inner_iter,
224            set: self.clone(),
225            index: 0,
226        };
227        Ok(iter.into_stream())
228    }
229
230    async fn iter_rev(&self) -> Result<BoxVertexStream> {
231        if let Some(_take) = self.take_count {
232            self.populate_take_cache().await?;
233            trace!("iter_rev({:0.6?}): use take_cache", self);
234            // Use take_cache to answer RevIter. This is probably better
235            // than using inner.iter_rev(), if take_count is small:
236            //     [<----------------------------]
237            //     [skip][take][...(need skip)...]
238            let iter = TakeCacheRevIter {
239                take_cache: self.take_cache.clone(),
240                index: 0,
241            };
242            Ok(iter.into_stream())
243        } else {
244            // Unbounded "take_count". Reuse inner.rev_iter().
245            //     [<-------------------]
246            //     [skip][<---take------]
247            trace!("iter_rev({:0.6?}): use inner.iter_rev()", self,);
248            let count = self.count().await?;
249            let iter = self.inner.iter_rev().await?;
250            Ok(Box::pin(iter.take(count)))
251        }
252    }
253
254    async fn count(&self) -> Result<usize> {
255        let count = self.inner.count().await?;
256        // consider skip_count
257        let count = (count as u64).max(self.skip_count) - self.skip_count;
258        // consider take_count
259        let count = count.min(self.take_count.unwrap_or(u64::MAX));
260        Ok(count as _)
261    }
262
263    async fn contains(&self, name: &VertexName) -> Result<bool> {
264        if let Some(result) = self.contains_fast(name).await? {
265            return Ok(result);
266        }
267
268        debug!("SliceSet::contains({:.6?}, {:?}) (slow path)", self, name);
269        let mut iter = self.iter().await?;
270        while let Some(item) = iter.next().await {
271            if &item? == name {
272                return Ok(true);
273            }
274        }
275        Ok(false)
276    }
277
278    async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
279        // Check take_cache.
280        {
281            let take_cache = self.take_cache.lock().await;
282            let is_take_cache_complete = self.is_take_cache_complete();
283            let contains = take_cache.contains(name);
284            match (contains, is_take_cache_complete) {
285                (_, true) | (true, _) => return Ok(Some(contains)),
286                (false, false) => {}
287            }
288        }
289
290        // Check skip_cache.
291        // Assumes one vertex only occurs once in a set.
292        let skip_contains = self.skip_cache.lock().await.contains(name);
293        if skip_contains {
294            return Ok(Some(false));
295        }
296
297        // Check with the original set.
298        let result = self.inner.contains_fast(name).await?;
299        match (result, self.is_skip_cache_complete().await) {
300            // Not in the original set. Slice is a subset. Result: false.
301            (Some(false), _) => Ok(Some(false)),
302            // In the original set. Skip cache is completed _and_ checked
303            // above (name was _not_ skipped). Result: true.
304            (Some(true), true) => {
305                // skip_cache was checked above
306                debug_assert!(!self.skip_cache.lock().await.contains(name));
307                Ok(Some(true))
308            }
309            // Unsure cases.
310            (None, false) => Ok(None),
311            (Some(true), false) => Ok(None),
312            (None, true) => Ok(None),
313        }
314    }
315
316    fn as_any(&self) -> &dyn Any {
317        self
318    }
319
320    fn hints(&self) -> &Hints {
321        &self.hints
322    }
323}
324
325impl fmt::Debug for SliceSet {
326    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
327        f.write_str("<slice")?;
328        write_debug(f, &self.inner)?;
329        f.write_str(" [")?;
330        if self.skip_count > 0 {
331            write!(f, "{}", self.skip_count)?;
332        }
333        f.write_str("..")?;
334        if let Some(n) = self.take_count {
335            write!(f, "{}", self.skip_count + n)?;
336        }
337        f.write_str("]>")
338    }
339}
340
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use nonblocking::non_blocking_result as r;

    use super::super::tests::*;
    use super::*;

    #[test]
    fn test_basic() -> Result<()> {
        let orig = NameSet::from("a b c d e f g h i");
        let total = r(orig.count())?;

        // (skip, take, expected count), checked in order.
        let cases = [
            (0, None, total),
            (0, Some(0), 0),
            (4, None, total - 4),
            (4, Some(0), 0),
            (0, Some(4), 4),
            (4, Some(4), 4),
            (7, Some(4), 2),
            (20, Some(4), 0),
            (20, Some(0), 0),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(r(set.count())?, expected);
            check_invariants(&set)?;
        }

        Ok(())
    }

    #[test]
    fn test_debug() {
        let orig = NameSet::from("a b c d e f g h i");
        let cases = [
            (0, None, "<slice <static [a, b, c] + 6 more> [..]>"),
            (4, None, "<slice <static [a, b, c] + 6 more> [4..]>"),
            (4, Some(4), "<slice <static [a, b, c] + 6 more> [4..8]>"),
            (0, Some(4), "<slice <static [a, b, c] + 6 more> [..4]>"),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(format!("{:?}", set), expected);
        }
    }

    quickcheck::quickcheck! {
        fn test_static_quickcheck(skip_and_take: u8) -> bool {
            // Low nibble: skip. High nibble: take (values above 12 mean unbounded).
            let skip = u64::from(skip_and_take & 0xf);
            let raw_take = u64::from(skip_and_take >> 4);
            let take = if raw_take <= 12 { Some(raw_take) } else { None };
            let set = SliceSet::new(NameSet::from("a c b d e f g i h j"), skip, take);
            check_invariants(&set).unwrap();
            true
        }
    }
}