1use std::any::Any;
9use std::borrow::Cow;
10use std::collections::HashSet;
11use std::fmt;
12use std::sync::atomic::AtomicBool;
13use std::sync::atomic::Ordering::Acquire;
14use std::sync::atomic::Ordering::Release;
15use std::sync::Arc;
16
17use futures::lock::Mutex;
18use futures::StreamExt;
19use indexmap::IndexSet;
20use tracing::debug;
21use tracing::instrument;
22use tracing::trace;
23use tracing::Level;
24
25use super::hints::Flags;
26use super::id_static::IdStaticSet;
27use super::AsyncSetQuery;
28use super::BoxVertexStream;
29use super::Hints;
30use super::Set;
31use crate::fmt::write_debug;
32use crate::Result;
33use crate::Vertex;
34
35#[derive(Clone)]
37pub struct SliceSet {
38 inner: Set,
39 hints: Hints,
40 skip_count: u64,
41 take_count: Option<u64>,
42
43 skip_cache: Arc<Mutex<HashSet<Vertex>>>,
45 take_cache: Arc<Mutex<IndexSet<Vertex>>>,
47 take_cache_complete: Arc<AtomicBool>,
49}
50
51impl SliceSet {
52 pub fn new(set: Set, skip_count: u64, take_count: Option<u64>) -> Self {
53 let hints = set.hints().clone();
54 hints.update_flags_with(|mut f| {
55 f &= Flags::ID_DESC
57 | Flags::ID_ASC
58 | Flags::TOPO_DESC
59 | Flags::HAS_MIN_ID
60 | Flags::HAS_MAX_ID
61 | Flags::EMPTY;
62 if take_count == Some(0) {
64 f |= Flags::EMPTY;
65 }
66 f
67 });
68 Self {
69 inner: set,
70 hints,
71 skip_count,
72 take_count,
73 skip_cache: Default::default(),
74 take_cache: Default::default(),
75 take_cache_complete: Default::default(),
76 }
77 }
78
79 fn is_take_cache_complete(&self) -> bool {
80 self.take_cache_complete.load(Acquire)
81 }
82
83 async fn is_skip_cache_complete(&self) -> bool {
84 self.skip_cache.lock().await.len() as u64 == self.skip_count
85 }
86
87 #[instrument(level=Level::DEBUG)]
88 async fn populate_take_cache(&self) -> Result<()> {
89 assert!(self.take_count.is_some());
92
93 let mut iter = self.iter().await?;
95 while let Some(_) = iter.next().await {}
96 assert!(self.is_take_cache_complete());
97
98 Ok(())
99 }
100}
101
102struct Iter {
103 inner_iter: BoxVertexStream,
104 set: SliceSet,
105 index: u64,
106 ended: bool,
107}
108
109const SKIP_CACHE_SIZE_THRESHOLD: u64 = 1000;
110
111impl Iter {
112 async fn next(&mut self) -> Option<Result<Vertex>> {
113 if self.ended {
114 return None;
115 }
116 if self.set.is_take_cache_complete() {
117 let index = self.index.max(self.set.skip_count);
119 let take_index = index - self.set.skip_count;
120 let result = {
121 let cache = self.set.take_cache.lock().await;
122 cache.get_index(take_index as _).cloned()
123 };
124 trace!("next(index={}) = {:?} (fast path)", index, &result);
125 self.index = index + 1;
126 return Ok(result).transpose();
127 }
128
129 loop {
130 let index = self.index;
132 trace!("next(index={})", index);
133 let next: Option<Vertex> = match self.inner_iter.next().await {
134 Some(Err(e)) => {
135 self.index = u64::MAX;
136 return Some(Err(e));
137 }
138 Some(Ok(v)) => Some(v),
139 None => None,
140 };
141 self.index += 1;
142
143 if index < self.set.skip_count {
145 if index < SKIP_CACHE_SIZE_THRESHOLD {
146 if let Some(v) = next.as_ref() {
148 let mut cache = self.set.skip_cache.lock().await;
149 cache.insert(v.clone());
150 }
151 }
152 continue;
153 }
154
155 let take_index = index - self.set.skip_count;
157 let should_take: bool = match self.set.take_count {
158 Some(count) => {
159 if take_index < count {
160 let mut cache = self.set.take_cache.lock().await;
162 if take_index == cache.len() as u64 {
163 if let Some(v) = next.as_ref() {
164 cache.insert(v.clone());
165 } else {
166 self.set.take_cache_complete.store(true, Release);
168 }
169 }
170 true
171 } else {
172 self.set.take_cache_complete.store(true, Release);
173 false
174 }
175 }
176 None => {
177 true
180 }
181 };
182
183 if next.is_none() {
184 self.ended = true;
185 }
186
187 if should_take {
188 return next.map(Ok);
189 } else {
190 return None;
191 }
192 }
193 }
194
195 fn into_stream(self) -> BoxVertexStream {
196 Box::pin(futures::stream::unfold(self, |mut state| async move {
197 let result = state.next().await;
198 result.map(|r| (r, state))
199 }))
200 }
201}
202
203struct TakeCacheRevIter {
204 take_cache: Arc<Mutex<IndexSet<Vertex>>>,
205 index: usize,
206}
207
208impl TakeCacheRevIter {
209 async fn next(&mut self) -> Option<Result<Vertex>> {
210 let index = self.index;
211 self.index += 1;
212 let cache = self.take_cache.lock().await;
213 if index >= cache.len() {
214 None
215 } else {
216 let index = cache.len() - index - 1;
217 cache.get_index(index).cloned().map(Ok)
218 }
219 }
220
221 fn into_stream(self) -> BoxVertexStream {
222 Box::pin(futures::stream::unfold(self, |mut state| async move {
223 let result = state.next().await;
224 result.map(|r| (r, state))
225 }))
226 }
227}
228
229#[async_trait::async_trait]
230impl AsyncSetQuery for SliceSet {
231 async fn iter(&self) -> Result<BoxVertexStream> {
232 let inner_iter = self.inner.iter().await?;
233 let iter = Iter {
234 inner_iter,
235 set: self.clone(),
236 index: 0,
237 ended: false,
238 };
239 Ok(iter.into_stream())
240 }
241
242 async fn iter_rev(&self) -> Result<BoxVertexStream> {
243 if let Some(_take) = self.take_count {
244 self.populate_take_cache().await?;
245 trace!("iter_rev({:0.6?}): use take_cache", self);
246 let iter = TakeCacheRevIter {
251 take_cache: self.take_cache.clone(),
252 index: 0,
253 };
254 Ok(iter.into_stream())
255 } else {
256 trace!("iter_rev({:0.6?}): use inner.iter_rev()", self,);
260 let count = self.count().await?;
261 let iter = self.inner.iter_rev().await?;
262 let count = count.try_into()?;
263 Ok(Box::pin(iter.take(count)))
264 }
265 }
266
267 async fn count(&self) -> Result<u64> {
268 let count = self.inner.count().await?;
269 let count = (count as u64).max(self.skip_count) - self.skip_count;
271 let count = count.min(self.take_count.unwrap_or(u64::MAX));
273 Ok(count)
274 }
275
276 async fn size_hint(&self) -> (u64, Option<u64>) {
277 let (min, max) = self.inner.size_hint().await;
278 let skip = self.skip_count;
281 let take = self.take_count;
282 let min = match take {
283 None => min.saturating_sub(skip),
284 Some(take) => min.saturating_sub(skip).min(take),
285 };
286 let max = match (max, take) {
287 (Some(max), Some(take)) => Some(max.saturating_sub(skip).min(take)),
288 (Some(max), None) => Some(max.saturating_sub(skip)),
289 (None, Some(take)) => Some(take),
290 (None, None) => None,
291 };
292 (min, max)
293 }
294
295 async fn contains(&self, name: &Vertex) -> Result<bool> {
296 if let Some(result) = self.contains_fast(name).await? {
297 return Ok(result);
298 }
299
300 debug!("SliceSet::contains({:.6?}, {:?}) (slow path)", self, name);
301 let mut iter = self.iter().await?;
302 while let Some(item) = iter.next().await {
303 if &item? == name {
304 return Ok(true);
305 }
306 }
307 Ok(false)
308 }
309
310 async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
311 {
313 let take_cache = self.take_cache.lock().await;
314 let is_take_cache_complete = self.is_take_cache_complete();
315 let contains = take_cache.contains(name);
316 match (contains, is_take_cache_complete) {
317 (_, true) | (true, _) => return Ok(Some(contains)),
318 (false, false) => {}
319 }
320 }
321
322 let skip_contains = self.skip_cache.lock().await.contains(name);
325 if skip_contains {
326 return Ok(Some(false));
327 }
328
329 let result = self.inner.contains_fast(name).await?;
331 match (result, self.is_skip_cache_complete().await) {
332 (Some(false), _) => Ok(Some(false)),
334 (Some(true), true) => {
337 debug_assert!(!self.skip_cache.lock().await.contains(name));
339 Ok(Some(true))
340 }
341 (None, false) => Ok(None),
343 (Some(true), false) => Ok(None),
344 (None, true) => Ok(None),
345 }
346 }
347
348 fn as_any(&self) -> &dyn Any {
349 self
350 }
351
352 fn hints(&self) -> &Hints {
353 &self.hints
354 }
355
356 fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
357 let inner = self.inner.specialized_flatten_id()?.into_owned();
360 let sensitive_flags = Flags::ID_DESC | Flags::ID_ASC;
361 let expected_flags = self.hints().flags() & sensitive_flags;
362 let mut can_use_fast_path = true;
363 let spans = inner.id_set_try_preserving_order()?;
364 if self.skip_count == 0 && spans.count() <= self.take_count.unwrap_or(u64::MAX) {
365 can_use_fast_path = true
366 } else if expected_flags.is_empty() {
367 can_use_fast_path = false;
368 } else if (inner.hints().flags() & sensitive_flags) != expected_flags {
369 can_use_fast_path = false;
370 }
371 if can_use_fast_path {
372 let result = inner.slice_spans(self.skip_count, self.take_count.unwrap_or(u64::MAX));
373 Some(Cow::Owned(result))
374 } else {
375 None
376 }
377 }
378}
379
380impl fmt::Debug for SliceSet {
381 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
382 f.write_str("<slice")?;
383 write_debug(f, &self.inner)?;
384 f.write_str(" [")?;
385 if self.skip_count > 0 {
386 write!(f, "{}", self.skip_count)?;
387 }
388 f.write_str("..")?;
389 if let Some(n) = self.take_count {
390 write!(f, "{}", self.skip_count + n)?;
391 }
392 f.write_str("]>")
393 }
394}
395
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use nonblocking::non_blocking_result as r;

    use super::super::tests::*;
    use super::*;

    /// `count()` and the generic invariants over a table of (skip, take) pairs.
    #[test]
    fn test_basic() -> Result<()> {
        let orig = Set::from("a b c d e f g h i");
        let total = r(orig.count())?;

        // (skip, take, expected count)
        let cases = [
            (0, None, total),
            (0, Some(0), 0),
            (4, None, total - 4),
            (4, Some(0), 0),
            (0, Some(4), 4),
            (4, Some(4), 4),
            (7, Some(4), 2),
            (20, Some(4), 0),
            (20, Some(0), 0),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(r(set.count())?, expected);
            check_invariants(&set)?;
        }

        Ok(())
    }

    /// Debug output renders the slice bounds as a range.
    #[test]
    fn test_debug() {
        let orig = Set::from("a b c d e f g h i");
        let cases = [
            (0, None, "<slice <static [a, b, c] + 6 more> [..]>"),
            (4, None, "<slice <static [a, b, c] + 6 more> [4..]>"),
            (4, Some(4), "<slice <static [a, b, c] + 6 more> [4..8]>"),
            (0, Some(4), "<slice <static [a, b, c] + 6 more> [..4]>"),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(dbg(set), expected);
        }
    }

    /// Invariants hold even when the inner set reports skewed size hints.
    #[test]
    fn test_size_hint_sets() {
        let bytes = b"\x11\x22\x33";
        let limit = bytes.len() + 2;
        for skip in 0..limit {
            for size_hint_adjust in 0..7 {
                let vec_set = Set::from_query(
                    VecQuery::from_bytes(&bytes[..]).adjust_size_hint(size_hint_adjust),
                );
                for take in 0..limit {
                    let set = SliceSet::new(vec_set.clone(), skip as u64, Some(take as u64));
                    check_invariants(&set).unwrap();
                }
                let set = SliceSet::new(vec_set, skip as u64, None);
                check_invariants(&set).unwrap();
            }
        }
    }

    quickcheck::quickcheck! {
        fn test_static_quickcheck(skip_and_take: u8) -> bool {
            // Low nibble: skip. High nibble: take, where values above 12 mean "no limit".
            let skip = u64::from(skip_and_take & 0xf);
            let take = Some(u64::from(skip_and_take >> 4)).filter(|&n| n <= 12);
            let set = SliceSet::new(Set::from("a c b d e f g i h j"), skip, take);
            check_invariants(&set).unwrap();
            true
        }
    }
}