Skip to main content

dag/set/
slice.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::borrow::Cow;
10use std::collections::HashSet;
11use std::fmt;
12use std::sync::atomic::AtomicBool;
13use std::sync::atomic::Ordering::Acquire;
14use std::sync::atomic::Ordering::Release;
15use std::sync::Arc;
16
17use futures::lock::Mutex;
18use futures::StreamExt;
19use indexmap::IndexSet;
20use tracing::debug;
21use tracing::instrument;
22use tracing::trace;
23use tracing::Level;
24
25use super::hints::Flags;
26use super::id_static::IdStaticSet;
27use super::AsyncSetQuery;
28use super::BoxVertexStream;
29use super::Hints;
30use super::Set;
31use crate::fmt::write_debug;
32use crate::Result;
33use crate::Vertex;
34
35/// Slice of a set.
36#[derive(Clone)]
37pub struct SliceSet {
38    inner: Set,
39    hints: Hints,
40    skip_count: u64,
41    take_count: Option<u64>,
42
43    // Skipped vertexes. Updated during iteration.
44    skip_cache: Arc<Mutex<HashSet<Vertex>>>,
45    // Taken vertexes. Updated during iteration.
46    take_cache: Arc<Mutex<IndexSet<Vertex>>>,
47    // If take_cache is complete.
48    take_cache_complete: Arc<AtomicBool>,
49}
50
51impl SliceSet {
52    pub fn new(set: Set, skip_count: u64, take_count: Option<u64>) -> Self {
53        let hints = set.hints().clone();
54        hints.update_flags_with(|mut f| {
55            // Only keep compatible flags.
56            f &= Flags::ID_DESC
57                | Flags::ID_ASC
58                | Flags::TOPO_DESC
59                | Flags::HAS_MIN_ID
60                | Flags::HAS_MAX_ID
61                | Flags::EMPTY;
62            // Add EMPTY hints if take_count is 0.
63            if take_count == Some(0) {
64                f |= Flags::EMPTY;
65            }
66            f
67        });
68        Self {
69            inner: set,
70            hints,
71            skip_count,
72            take_count,
73            skip_cache: Default::default(),
74            take_cache: Default::default(),
75            take_cache_complete: Default::default(),
76        }
77    }
78
79    fn is_take_cache_complete(&self) -> bool {
80        self.take_cache_complete.load(Acquire)
81    }
82
83    async fn is_skip_cache_complete(&self) -> bool {
84        self.skip_cache.lock().await.len() as u64 == self.skip_count
85    }
86
87    #[instrument(level=Level::DEBUG)]
88    async fn populate_take_cache(&self) -> Result<()> {
89        // See Iter::next. If take_count is not set, the "take" can be unbounded,
90        // and take_cache won't be populated.
91        assert!(self.take_count.is_some());
92
93        // Use iter() to populate take_cache.
94        let mut iter = self.iter().await?;
95        while let Some(_) = iter.next().await {}
96        assert!(self.is_take_cache_complete());
97
98        Ok(())
99    }
100}
101
102struct Iter {
103    inner_iter: BoxVertexStream,
104    set: SliceSet,
105    index: u64,
106    ended: bool,
107}
108
/// Skipped vertexes are only recorded in `skip_cache` while their position in
/// the inner set is below this threshold.
const SKIP_CACHE_SIZE_THRESHOLD: u64 = 1_000;
110
111impl Iter {
112    async fn next(&mut self) -> Option<Result<Vertex>> {
113        if self.ended {
114            return None;
115        }
116        if self.set.is_take_cache_complete() {
117            // Fast path - no need to use inner_iter.
118            let index = self.index.max(self.set.skip_count);
119            let take_index = index - self.set.skip_count;
120            let result = {
121                let cache = self.set.take_cache.lock().await;
122                cache.get_index(take_index as _).cloned()
123            };
124            trace!("next(index={}) = {:?} (fast path)", index, &result);
125            self.index = index + 1;
126            return Ok(result).transpose();
127        }
128
129        loop {
130            // Slow path - use inner_iter.
131            let index = self.index;
132            trace!("next(index={})", index);
133            let next: Option<Vertex> = match self.inner_iter.next().await {
134                Some(Err(e)) => {
135                    self.index = u64::MAX;
136                    return Some(Err(e));
137                }
138                Some(Ok(v)) => Some(v),
139                None => None,
140            };
141            self.index += 1;
142
143            // Skip?
144            if index < self.set.skip_count {
145                if index < SKIP_CACHE_SIZE_THRESHOLD {
146                    // Update skip_cache.
147                    if let Some(v) = next.as_ref() {
148                        let mut cache = self.set.skip_cache.lock().await;
149                        cache.insert(v.clone());
150                    }
151                }
152                continue;
153            }
154
155            // Take?
156            let take_index = index - self.set.skip_count;
157            let should_take: bool = match self.set.take_count {
158                Some(count) => {
159                    if take_index < count {
160                        // Update take_cache.
161                        let mut cache = self.set.take_cache.lock().await;
162                        if take_index == cache.len() as u64 {
163                            if let Some(v) = next.as_ref() {
164                                cache.insert(v.clone());
165                            } else {
166                                // No more item in the original set.
167                                self.set.take_cache_complete.store(true, Release);
168                            }
169                        }
170                        true
171                    } else {
172                        self.set.take_cache_complete.store(true, Release);
173                        false
174                    }
175                }
176                None => {
177                    // Do not update take_cache, since the inner
178                    // set can be quite large.
179                    true
180                }
181            };
182
183            if next.is_none() {
184                self.ended = true;
185            }
186
187            if should_take {
188                return next.map(Ok);
189            } else {
190                return None;
191            }
192        }
193    }
194
195    fn into_stream(self) -> BoxVertexStream {
196        Box::pin(futures::stream::unfold(self, |mut state| async move {
197            let result = state.next().await;
198            result.map(|r| (r, state))
199        }))
200    }
201}
202
203struct TakeCacheRevIter {
204    take_cache: Arc<Mutex<IndexSet<Vertex>>>,
205    index: usize,
206}
207
208impl TakeCacheRevIter {
209    async fn next(&mut self) -> Option<Result<Vertex>> {
210        let index = self.index;
211        self.index += 1;
212        let cache = self.take_cache.lock().await;
213        if index >= cache.len() {
214            None
215        } else {
216            let index = cache.len() - index - 1;
217            cache.get_index(index).cloned().map(Ok)
218        }
219    }
220
221    fn into_stream(self) -> BoxVertexStream {
222        Box::pin(futures::stream::unfold(self, |mut state| async move {
223            let result = state.next().await;
224            result.map(|r| (r, state))
225        }))
226    }
227}
228
229#[async_trait::async_trait]
230impl AsyncSetQuery for SliceSet {
231    async fn iter(&self) -> Result<BoxVertexStream> {
232        let inner_iter = self.inner.iter().await?;
233        let iter = Iter {
234            inner_iter,
235            set: self.clone(),
236            index: 0,
237            ended: false,
238        };
239        Ok(iter.into_stream())
240    }
241
242    async fn iter_rev(&self) -> Result<BoxVertexStream> {
243        if let Some(_take) = self.take_count {
244            self.populate_take_cache().await?;
245            trace!("iter_rev({:0.6?}): use take_cache", self);
246            // Use take_cache to answer RevIter. This is probably better
247            // than using inner.iter_rev(), if take_count is small:
248            //     [<----------------------------]
249            //     [skip][take][...(need skip)...]
250            let iter = TakeCacheRevIter {
251                take_cache: self.take_cache.clone(),
252                index: 0,
253            };
254            Ok(iter.into_stream())
255        } else {
256            // Unbounded "take_count". Reuse inner.rev_iter().
257            //     [<-------------------]
258            //     [skip][<---take------]
259            trace!("iter_rev({:0.6?}): use inner.iter_rev()", self,);
260            let count = self.count().await?;
261            let iter = self.inner.iter_rev().await?;
262            let count = count.try_into()?;
263            Ok(Box::pin(iter.take(count)))
264        }
265    }
266
267    async fn count(&self) -> Result<u64> {
268        let count = self.inner.count().await?;
269        // consider skip_count
270        let count = (count as u64).max(self.skip_count) - self.skip_count;
271        // consider take_count
272        let count = count.min(self.take_count.unwrap_or(u64::MAX));
273        Ok(count)
274    }
275
276    async fn size_hint(&self) -> (u64, Option<u64>) {
277        let (min, max) = self.inner.size_hint().await;
278        // [0 .. min .. max]
279        // [ skip ][--- take ---]
280        let skip = self.skip_count;
281        let take = self.take_count;
282        let min = match take {
283            None => min.saturating_sub(skip),
284            Some(take) => min.saturating_sub(skip).min(take),
285        };
286        let max = match (max, take) {
287            (Some(max), Some(take)) => Some(max.saturating_sub(skip).min(take)),
288            (Some(max), None) => Some(max.saturating_sub(skip)),
289            (None, Some(take)) => Some(take),
290            (None, None) => None,
291        };
292        (min, max)
293    }
294
295    async fn contains(&self, name: &Vertex) -> Result<bool> {
296        if let Some(result) = self.contains_fast(name).await? {
297            return Ok(result);
298        }
299
300        debug!("SliceSet::contains({:.6?}, {:?}) (slow path)", self, name);
301        let mut iter = self.iter().await?;
302        while let Some(item) = iter.next().await {
303            if &item? == name {
304                return Ok(true);
305            }
306        }
307        Ok(false)
308    }
309
310    async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
311        // Check take_cache.
312        {
313            let take_cache = self.take_cache.lock().await;
314            let is_take_cache_complete = self.is_take_cache_complete();
315            let contains = take_cache.contains(name);
316            match (contains, is_take_cache_complete) {
317                (_, true) | (true, _) => return Ok(Some(contains)),
318                (false, false) => {}
319            }
320        }
321
322        // Check skip_cache.
323        // Assumes one vertex only occurs once in a set.
324        let skip_contains = self.skip_cache.lock().await.contains(name);
325        if skip_contains {
326            return Ok(Some(false));
327        }
328
329        // Check with the original set.
330        let result = self.inner.contains_fast(name).await?;
331        match (result, self.is_skip_cache_complete().await) {
332            // Not in the original set. Slice is a subset. Result: false.
333            (Some(false), _) => Ok(Some(false)),
334            // In the original set. Skip cache is completed _and_ checked
335            // above (name was _not_ skipped). Result: true.
336            (Some(true), true) => {
337                // skip_cache was checked above
338                debug_assert!(!self.skip_cache.lock().await.contains(name));
339                Ok(Some(true))
340            }
341            // Unsure cases.
342            (None, false) => Ok(None),
343            (Some(true), false) => Ok(None),
344            (None, true) => Ok(None),
345        }
346    }
347
348    fn as_any(&self) -> &dyn Any {
349        self
350    }
351
352    fn hints(&self) -> &Hints {
353        &self.hints
354    }
355
356    fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
357        // Attention! `inner` might have lost order. So we might not have a fast path.
358        // For example, this is flawed:
359        let inner = self.inner.specialized_flatten_id()?.into_owned();
360        let sensitive_flags = Flags::ID_DESC | Flags::ID_ASC;
361        let expected_flags = self.hints().flags() & sensitive_flags;
362        let mut can_use_fast_path = true;
363        let spans = inner.id_set_try_preserving_order()?;
364        if self.skip_count == 0 && spans.count() <= self.take_count.unwrap_or(u64::MAX) {
365            can_use_fast_path = true
366        } else if expected_flags.is_empty() {
367            can_use_fast_path = false;
368        } else if (inner.hints().flags() & sensitive_flags) != expected_flags {
369            can_use_fast_path = false;
370        }
371        if can_use_fast_path {
372            let result = inner.slice_spans(self.skip_count, self.take_count.unwrap_or(u64::MAX));
373            Some(Cow::Owned(result))
374        } else {
375            None
376        }
377    }
378}
379
380impl fmt::Debug for SliceSet {
381    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
382        f.write_str("<slice")?;
383        write_debug(f, &self.inner)?;
384        f.write_str(" [")?;
385        if self.skip_count > 0 {
386            write!(f, "{}", self.skip_count)?;
387        }
388        f.write_str("..")?;
389        if let Some(n) = self.take_count {
390            write!(f, "{}", self.skip_count + n)?;
391        }
392        f.write_str("]>")
393    }
394}
395
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use nonblocking::non_blocking_result as r;

    use super::super::tests::*;
    use super::*;

    /// Checks `count()` and the shared set invariants for several
    /// combinations of skip and take.
    #[test]
    fn test_basic() -> Result<()> {
        let orig = Set::from("a b c d e f g h i");
        let count = r(orig.count())?;

        // (skip, take, expected count)
        let cases = [
            (0, None, count),
            (0, Some(0), 0),
            (4, None, count - 4),
            (4, Some(0), 0),
            (0, Some(4), 4),
            (4, Some(4), 4),
            (7, Some(4), 2),
            (20, Some(4), 0),
            (20, Some(0), 0),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(r(set.count())?, expected);
            check_invariants(&set)?;
        }

        Ok(())
    }

    /// Checks the `Debug` output for open and closed ranges.
    #[test]
    fn test_debug() {
        let orig = Set::from("a b c d e f g h i");

        // (skip, take, expected debug output)
        let cases: [(u64, Option<u64>, &str); 4] = [
            (0, None, "<slice <static [a, b, c] + 6 more> [..]>"),
            (4, None, "<slice <static [a, b, c] + 6 more> [4..]>"),
            (4, Some(4), "<slice <static [a, b, c] + 6 more> [4..8]>"),
            (0, Some(4), "<slice <static [a, b, c] + 6 more> [..4]>"),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(dbg(set), expected);
        }
    }

    /// Checks the invariants on sets with adjusted size hints, for every
    /// bounded take up to two past the set size, and for an unbounded take.
    #[test]
    fn test_size_hint_sets() {
        let bytes = b"\x11\x22\x33";
        let limit = bytes.len() + 2;
        for skip in 0..limit {
            for size_hint_adjust in 0..7 {
                let query = VecQuery::from_bytes(&bytes[..]).adjust_size_hint(size_hint_adjust);
                let vec_set = Set::from_query(query);
                // Bounded takes first, then the unbounded one.
                let takes = (0..limit)
                    .map(|take| Some(take as u64))
                    .chain(std::iter::once(None));
                for take in takes {
                    let set = SliceSet::new(vec_set.clone(), skip as _, take);
                    check_invariants(&set).unwrap();
                }
            }
        }
    }

    quickcheck::quickcheck! {
        fn test_static_quickcheck(skip_and_take: u8) -> bool {
            // Low 4 bits: skip. High 4 bits: take (above 12 means unbounded).
            let skip = u64::from(skip_and_take & 0xf);
            let raw_take = u64::from(skip_and_take >> 4);
            let take = if raw_take <= 12 { Some(raw_take) } else { None };
            let orig = Set::from("a c b d e f g i h j");
            let set = SliceSet::new(orig, skip, take);
            check_invariants(&set).unwrap();
            true
        }
    }
}
493}