1use std::any::Any;
9use std::collections::HashSet;
10use std::fmt;
11use std::sync::atomic::AtomicBool;
12use std::sync::atomic::Ordering::Acquire;
13use std::sync::atomic::Ordering::Release;
14use std::sync::Arc;
15
16use futures::lock::Mutex;
17use futures::StreamExt;
18use indexmap::IndexSet;
19use tracing::debug;
20use tracing::instrument;
21use tracing::trace;
22use tracing::Level;
23
24use super::hints::Flags;
25use super::AsyncNameSetQuery;
26use super::BoxVertexStream;
27use super::Hints;
28use super::NameSet;
29use crate::fmt::write_debug;
30use crate::Result;
31use crate::VertexName;
32
33#[derive(Clone)]
35pub struct SliceSet {
36 inner: NameSet,
37 hints: Hints,
38 skip_count: u64,
39 take_count: Option<u64>,
40
41 skip_cache: Arc<Mutex<HashSet<VertexName>>>,
43 take_cache: Arc<Mutex<IndexSet<VertexName>>>,
45 take_cache_complete: Arc<AtomicBool>,
47}
48
49impl SliceSet {
50 pub fn new(set: NameSet, skip_count: u64, take_count: Option<u64>) -> Self {
51 let hints = set.hints().clone();
52 hints.update_flags_with(|mut f| {
53 f &= Flags::ID_DESC
55 | Flags::ID_ASC
56 | Flags::TOPO_DESC
57 | Flags::HAS_MIN_ID
58 | Flags::HAS_MAX_ID
59 | Flags::EMPTY;
60 if take_count == Some(0) {
62 f |= Flags::EMPTY;
63 }
64 f
65 });
66 Self {
67 inner: set,
68 hints,
69 skip_count,
70 take_count,
71 skip_cache: Default::default(),
72 take_cache: Default::default(),
73 take_cache_complete: Default::default(),
74 }
75 }
76
77 fn is_take_cache_complete(&self) -> bool {
78 self.take_cache_complete.load(Acquire)
79 }
80
81 async fn is_skip_cache_complete(&self) -> bool {
82 self.skip_cache.lock().await.len() as u64 == self.skip_count
83 }
84
85 #[instrument(level=Level::DEBUG)]
86 async fn populate_take_cache(&self) -> Result<()> {
87 assert!(self.take_count.is_some());
90
91 let mut iter = self.iter().await?;
93 while let Some(_) = iter.next().await {}
94 assert!(self.is_take_cache_complete());
95
96 Ok(())
97 }
98}
99
100struct Iter {
101 inner_iter: BoxVertexStream,
102 set: SliceSet,
103 index: u64,
104}
105
/// Only skipped items at inner positions below this bound are recorded in
/// `skip_cache`, which keeps the cache's size bounded.
const SKIP_CACHE_SIZE_THRESHOLD: u64 = 1_000;
107
108impl Iter {
109 async fn next(&mut self) -> Option<Result<VertexName>> {
110 if self.set.is_take_cache_complete() {
111 let index = self.index.max(self.set.skip_count);
113 let take_index = index - self.set.skip_count;
114 let result = {
115 let cache = self.set.take_cache.lock().await;
116 cache.get_index(take_index as _).cloned()
117 };
118 trace!("next(index={}) = {:?} (fast path)", index, &result);
119 self.index = index + 1;
120 return Ok(result).transpose();
121 }
122
123 loop {
124 let index = self.index;
126 trace!("next(index={})", index);
127 let next: Option<VertexName> = match self.inner_iter.next().await {
128 Some(Err(e)) => {
129 self.index = u64::MAX;
130 return Some(Err(e));
131 }
132 Some(Ok(v)) => Some(v),
133 None => None,
134 };
135 self.index += 1;
136
137 if index < self.set.skip_count {
139 if index < SKIP_CACHE_SIZE_THRESHOLD {
140 if let Some(v) = next.as_ref() {
142 let mut cache = self.set.skip_cache.lock().await;
143 cache.insert(v.clone());
144 }
145 }
146 continue;
147 }
148
149 let take_index = index - self.set.skip_count;
151 let should_take: bool = match self.set.take_count {
152 Some(count) => {
153 if take_index < count {
154 let mut cache = self.set.take_cache.lock().await;
156 if take_index == cache.len() as u64 {
157 if let Some(v) = next.as_ref() {
158 cache.insert(v.clone());
159 } else {
160 self.set.take_cache_complete.store(true, Release);
162 }
163 }
164 true
165 } else {
166 self.set.take_cache_complete.store(true, Release);
167 false
168 }
169 }
170 None => {
171 true
174 }
175 };
176 if should_take {
177 return next.map(Ok);
178 } else {
179 return None;
180 }
181 }
182 }
183
184 fn into_stream(self) -> BoxVertexStream {
185 Box::pin(futures::stream::unfold(self, |mut state| async move {
186 let result = state.next().await;
187 result.map(|r| (r, state))
188 }))
189 }
190}
191
192struct TakeCacheRevIter {
193 take_cache: Arc<Mutex<IndexSet<VertexName>>>,
194 index: usize,
195}
196
197impl TakeCacheRevIter {
198 async fn next(&mut self) -> Option<Result<VertexName>> {
199 let index = self.index;
200 self.index += 1;
201 let cache = self.take_cache.lock().await;
202 if index >= cache.len() {
203 None
204 } else {
205 let index = cache.len() - index - 1;
206 cache.get_index(index).cloned().map(Ok)
207 }
208 }
209
210 fn into_stream(self) -> BoxVertexStream {
211 Box::pin(futures::stream::unfold(self, |mut state| async move {
212 let result = state.next().await;
213 result.map(|r| (r, state))
214 }))
215 }
216}
217
218#[async_trait::async_trait]
219impl AsyncNameSetQuery for SliceSet {
220 async fn iter(&self) -> Result<BoxVertexStream> {
221 let inner_iter = self.inner.iter().await?;
222 let iter = Iter {
223 inner_iter,
224 set: self.clone(),
225 index: 0,
226 };
227 Ok(iter.into_stream())
228 }
229
230 async fn iter_rev(&self) -> Result<BoxVertexStream> {
231 if let Some(_take) = self.take_count {
232 self.populate_take_cache().await?;
233 trace!("iter_rev({:0.6?}): use take_cache", self);
234 let iter = TakeCacheRevIter {
239 take_cache: self.take_cache.clone(),
240 index: 0,
241 };
242 Ok(iter.into_stream())
243 } else {
244 trace!("iter_rev({:0.6?}): use inner.iter_rev()", self,);
248 let count = self.count().await?;
249 let iter = self.inner.iter_rev().await?;
250 Ok(Box::pin(iter.take(count)))
251 }
252 }
253
254 async fn count(&self) -> Result<usize> {
255 let count = self.inner.count().await?;
256 let count = (count as u64).max(self.skip_count) - self.skip_count;
258 let count = count.min(self.take_count.unwrap_or(u64::MAX));
260 Ok(count as _)
261 }
262
263 async fn contains(&self, name: &VertexName) -> Result<bool> {
264 if let Some(result) = self.contains_fast(name).await? {
265 return Ok(result);
266 }
267
268 debug!("SliceSet::contains({:.6?}, {:?}) (slow path)", self, name);
269 let mut iter = self.iter().await?;
270 while let Some(item) = iter.next().await {
271 if &item? == name {
272 return Ok(true);
273 }
274 }
275 Ok(false)
276 }
277
278 async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
279 {
281 let take_cache = self.take_cache.lock().await;
282 let is_take_cache_complete = self.is_take_cache_complete();
283 let contains = take_cache.contains(name);
284 match (contains, is_take_cache_complete) {
285 (_, true) | (true, _) => return Ok(Some(contains)),
286 (false, false) => {}
287 }
288 }
289
290 let skip_contains = self.skip_cache.lock().await.contains(name);
293 if skip_contains {
294 return Ok(Some(false));
295 }
296
297 let result = self.inner.contains_fast(name).await?;
299 match (result, self.is_skip_cache_complete().await) {
300 (Some(false), _) => Ok(Some(false)),
302 (Some(true), true) => {
305 debug_assert!(!self.skip_cache.lock().await.contains(name));
307 Ok(Some(true))
308 }
309 (None, false) => Ok(None),
311 (Some(true), false) => Ok(None),
312 (None, true) => Ok(None),
313 }
314 }
315
316 fn as_any(&self) -> &dyn Any {
317 self
318 }
319
320 fn hints(&self) -> &Hints {
321 &self.hints
322 }
323}
324
325impl fmt::Debug for SliceSet {
326 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
327 f.write_str("<slice")?;
328 write_debug(f, &self.inner)?;
329 f.write_str(" [")?;
330 if self.skip_count > 0 {
331 write!(f, "{}", self.skip_count)?;
332 }
333 f.write_str("..")?;
334 if let Some(n) = self.take_count {
335 write!(f, "{}", self.skip_count + n)?;
336 }
337 f.write_str("]>")
338 }
339}
340
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use nonblocking::non_blocking_result as r;

    use super::super::tests::*;
    use super::*;

    #[test]
    fn test_basic() -> Result<()> {
        let orig = NameSet::from("a b c d e f g h i");
        let count = r(orig.count())?;

        // (skip, take, expected count)
        let cases: [(u64, Option<u64>, usize); 9] = [
            (0, None, count),
            (0, Some(0), 0),
            (4, None, count - 4),
            (4, Some(0), 0),
            (0, Some(4), 4),
            (4, Some(4), 4),
            (7, Some(4), 2),
            (20, Some(4), 0),
            (20, Some(0), 0),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(r(set.count())?, expected);
            check_invariants(&set)?;
        }

        Ok(())
    }

    #[test]
    fn test_debug() {
        let orig = NameSet::from("a b c d e f g h i");
        // (skip, take, expected Debug output)
        let cases = [
            (0, None, "<slice <static [a, b, c] + 6 more> [..]>"),
            (4, None, "<slice <static [a, b, c] + 6 more> [4..]>"),
            (4, Some(4), "<slice <static [a, b, c] + 6 more> [4..8]>"),
            (0, Some(4), "<slice <static [a, b, c] + 6 more> [..4]>"),
        ];
        for &(skip, take, expected) in cases.iter() {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(format!("{:?}", set), expected);
        }
    }

    quickcheck::quickcheck! {
        fn test_static_quickcheck(skip_and_take: u8) -> bool {
            // Low nibble: skip (0..=15). High nibble: take, where values
            // above 12 mean "no limit".
            let skip = u64::from(skip_and_take & 0xf);
            let take = u64::from(skip_and_take >> 4);
            let take = if take <= 12 { Some(take) } else { None };
            let orig = NameSet::from("a c b d e f g i h j");
            let set = SliceSet::new(orig, skip, take);
            check_invariants(&set).unwrap();
            true
        }
    }
}