use std::any::Any;
use std::collections::HashSet;
use std::fmt;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Acquire;
use std::sync::atomic::Ordering::Release;
use std::sync::Arc;
use futures::lock::Mutex;
use futures::StreamExt;
use indexmap::IndexSet;
use tracing::debug;
use tracing::instrument;
use tracing::trace;
use tracing::Level;
use super::hints::Flags;
use super::AsyncNameSetQuery;
use super::BoxVertexStream;
use super::Hints;
use super::NameSet;
use crate::fmt::write_debug;
use crate::Result;
use crate::VertexName;
/// A contiguous slice of another set's iteration order: the first
/// `skip_count` vertices of `inner` are skipped, then at most `take_count`
/// vertices are yielded (all remaining ones when `take_count` is `None`).
///
/// The caches are behind `Arc`s, so clones of a `SliceSet` share them.
#[derive(Clone)]
pub struct SliceSet {
    /// The set being sliced.
    inner: NameSet,
    /// Number of leading vertices of `inner` to skip.
    skip_count: u64,
    /// Maximum number of vertices to yield after skipping; `None` = no limit.
    take_count: Option<u64>,
    /// Hints derived from `inner`'s hints.
    hints: Hints,
    /// Skipped vertices observed so far (only the first
    /// `SKIP_CACHE_SIZE_THRESHOLD` skipped positions are recorded).
    skip_cache: Arc<Mutex<HashSet<VertexName>>>,
    /// Taken vertices observed so far, in iteration order.
    take_cache: Arc<Mutex<IndexSet<VertexName>>>,
    /// Set once `take_cache` is known to hold the entire slice.
    take_cache_complete: Arc<AtomicBool>,
}
impl SliceSet {
    /// Creates a slice of `set` that skips the first `skip_count` vertices
    /// and then yields at most `take_count` vertices (`None` = no limit).
    ///
    /// The hints of `set` are copied, keeping only the ordering, min/max id
    /// and `EMPTY` flags; `EMPTY` is added when `take_count` is `Some(0)`.
    pub fn new(set: NameSet, skip_count: u64, take_count: Option<u64>) -> Self {
        let hints = set.hints().clone();
        hints.update_flags_with(|mut f| {
            // Keep only flags that (presumably) still hold for a contiguous
            // sub-range of `set`'s iteration order.
            f &= Flags::ID_DESC
                | Flags::ID_ASC
                | Flags::TOPO_DESC
                | Flags::HAS_MIN_ID
                | Flags::HAS_MAX_ID
                | Flags::EMPTY;
            // Taking nothing always produces an empty set.
            if take_count == Some(0) {
                f |= Flags::EMPTY;
            }
            f
        });
        Self {
            inner: set,
            hints,
            skip_count,
            take_count,
            skip_cache: Default::default(),
            take_cache: Default::default(),
            take_cache_complete: Default::default(),
        }
    }

    /// Whether `take_cache` has been marked as holding the whole slice.
    fn is_take_cache_complete(&self) -> bool {
        self.take_cache_complete.load(Acquire)
    }

    /// Whether `skip_cache` holds all `skip_count` skipped vertices.
    ///
    /// Only skipped positions below `SKIP_CACHE_SIZE_THRESHOLD` are cached, so
    /// this never becomes true for larger `skip_count`s.
    async fn is_skip_cache_complete(&self) -> bool {
        self.skip_cache.lock().await.len() as u64 == self.skip_count
    }

    /// Iterates the whole (bounded) slice so that `take_cache` ends up
    /// complete.
    ///
    /// # Errors
    /// Returns the first error produced while iterating `inner`.
    ///
    /// # Panics
    /// Panics if `take_count` is `None`: without a bound the slice is not
    /// cached in `take_cache`.
    #[instrument(level=Level::DEBUG)]
    async fn populate_take_cache(&self) -> Result<()> {
        assert!(self.take_count.is_some());
        let mut iter = self.iter().await?;
        while let Some(item) = iter.next().await {
            // Propagate errors: an interrupted iteration leaves the cache
            // incomplete, which would otherwise trip the assert below.
            item?;
        }
        assert!(self.is_take_cache_complete());
        Ok(())
    }
}
/// Forward iteration state over a `SliceSet`, driven by `inner`'s stream.
struct Iter {
    /// The slice being iterated; its caches are updated as vertices are seen.
    set: SliceSet,
    /// Position (in `inner`'s iteration order) of the next vertex to read;
    /// set to `u64::MAX` after an error has been returned.
    index: u64,
    /// Stream over `set.inner`.
    inner_iter: BoxVertexStream,
}
/// Skipped vertices are remembered in `SliceSet::skip_cache` only when their
/// position is below this value, bounding the cache's memory use.
const SKIP_CACHE_SIZE_THRESHOLD: u64 = 1_000;
impl Iter {
    /// Returns the next vertex of the slice, `Some(Err(_))` if `inner_iter`
    /// failed, or `None` once the slice is exhausted.
    ///
    /// While walking `inner_iter` this records skipped vertices (positions
    /// below `SKIP_CACHE_SIZE_THRESHOLD`) in `skip_cache` and, when a
    /// `take_count` is set, taken vertices in `take_cache`, marking the latter
    /// complete once the end of the slice is known.
    async fn next(&mut self) -> Option<Result<VertexName>> {
        // `index == u64::MAX` means an error was already returned. End the
        // iteration rather than polling `inner_iter` again: continuing would
        // overflow `self.index` and could mark a partial `take_cache` as
        // complete.
        if self.index == u64::MAX {
            return None;
        }

        // Fast path: the whole slice is already available in `take_cache`.
        if self.set.is_take_cache_complete() {
            let index = self.index.max(self.set.skip_count);
            let take_index = index - self.set.skip_count;
            let result = {
                let cache = self.set.take_cache.lock().await;
                cache.get_index(take_index as _).cloned()
            };
            trace!("next(index={}) = {:?} (fast path)", index, &result);
            // Saturate so that a `skip_count` of `u64::MAX` cannot overflow.
            self.index = index.saturating_add(1);
            return Ok(result).transpose();
        }

        loop {
            let index = self.index;
            trace!("next(index={})", index);

            // Stop before pulling another vertex from `inner_iter` once the
            // take window is used up. While still skipping, the saturating
            // subtraction yields 0, so `take_count == Some(0)` ends right away
            // without walking the skipped prefix.
            if let Some(count) = self.set.take_count {
                if index.saturating_sub(self.set.skip_count) >= count {
                    self.set.take_cache_complete.store(true, Release);
                    return None;
                }
            }

            let next: Option<VertexName> = match self.inner_iter.next().await {
                Some(Err(e)) => {
                    self.index = u64::MAX;
                    return Some(Err(e));
                }
                Some(Ok(v)) => Some(v),
                None => None,
            };
            self.index += 1;

            if index < self.set.skip_count {
                match next {
                    Some(v) => {
                        if index < SKIP_CACHE_SIZE_THRESHOLD {
                            self.set.skip_cache.lock().await.insert(v);
                        }
                        continue;
                    }
                    None => {
                        // `inner` ended before `skip_count` vertices were
                        // skipped, so the slice is empty. Do not poll the
                        // exhausted stream again: it would spin up to
                        // `skip_count` times, and some streams (e.g. those
                        // built by `futures::stream::unfold`, like
                        // `into_stream` below) panic when polled after
                        // returning `None`.
                        if self.set.take_count.is_some() {
                            self.set.take_cache_complete.store(true, Release);
                        }
                        return None;
                    }
                }
            }

            if self.set.take_count.is_some() {
                // The early-exit check above guarantees
                // `take_index < take_count` here.
                let take_index = index - self.set.skip_count;
                let mut cache = self.set.take_cache.lock().await;
                // The cache is shared with other iterators; only append when
                // this position has not been filled yet.
                if take_index == cache.len() as u64 {
                    match next.as_ref() {
                        Some(v) => {
                            cache.insert(v.clone());
                        }
                        None => self.set.take_cache_complete.store(true, Release),
                    }
                }
            }
            return next.map(Ok);
        }
    }

    /// Adapts this iterator into a `BoxVertexStream`.
    fn into_stream(self) -> BoxVertexStream {
        Box::pin(futures::stream::unfold(self, |mut state| async move {
            let result = state.next().await;
            result.map(|r| (r, state))
        }))
    }
}
/// Reverse iteration over `SliceSet::take_cache`; `iter_rev` constructs it
/// after calling `populate_take_cache`.
struct TakeCacheRevIter {
    /// Number of vertices already yielded, counted from the end of the cache.
    index: usize,
    /// Shared cache of taken vertices, stored in forward order.
    take_cache: Arc<Mutex<IndexSet<VertexName>>>,
}
impl TakeCacheRevIter {
    /// Yields the cached vertices from last to first, then `None`.
    async fn next(&mut self) -> Option<Result<VertexName>> {
        let offset = self.index;
        self.index += 1;
        let cache = self.take_cache.lock().await;
        let len = cache.len();
        if offset < len {
            // Step `offset` positions back from the final entry.
            cache.get_index(len - 1 - offset).cloned().map(Ok)
        } else {
            None
        }
    }

    /// Adapts this iterator into a `BoxVertexStream`.
    fn into_stream(self) -> BoxVertexStream {
        let stream = futures::stream::unfold(self, |mut this| async move {
            this.next().await.map(|item| (item, this))
        });
        Box::pin(stream)
    }
}
#[async_trait::async_trait]
impl AsyncNameSetQuery for SliceSet {
    /// Iterates the slice in `inner`'s order, filling `skip_cache` and (when
    /// `take_count` is set) `take_cache` as a side effect.
    async fn iter(&self) -> Result<BoxVertexStream> {
        let inner_iter = self.inner.iter().await?;
        let iter = Iter {
            inner_iter,
            set: self.clone(),
            index: 0,
        };
        Ok(iter.into_stream())
    }

    /// Iterates the slice in reverse order.
    ///
    /// With a `take_count`, the bounded slice is first collected into
    /// `take_cache` and then replayed backwards. Without one, the slice is a
    /// suffix of `inner`, so it is read as the first `count()` items of
    /// `inner.iter_rev()` (assuming `iter_rev` is the exact reverse of `iter`).
    async fn iter_rev(&self) -> Result<BoxVertexStream> {
        if self.take_count.is_some() {
            self.populate_take_cache().await?;
            trace!("iter_rev({:0.6?}): use take_cache", self);
            let iter = TakeCacheRevIter {
                take_cache: self.take_cache.clone(),
                index: 0,
            };
            Ok(iter.into_stream())
        } else {
            trace!("iter_rev({:0.6?}): use inner.iter_rev()", self,);
            let count = self.count().await?;
            let iter = self.inner.iter_rev().await?;
            Ok(Box::pin(iter.take(count)))
        }
    }

    /// `inner.count() - skip_count` (clamped at 0), capped at `take_count`.
    async fn count(&self) -> Result<usize> {
        let count = self.inner.count().await?;
        let count = (count as u64).max(self.skip_count) - self.skip_count;
        let count = count.min(self.take_count.unwrap_or(u64::MAX));
        Ok(count as _)
    }

    /// Exact membership test; scans the slice when `contains_fast` cannot
    /// decide.
    async fn contains(&self, name: &VertexName) -> Result<bool> {
        if let Some(result) = self.contains_fast(name).await? {
            return Ok(result);
        }
        debug!("SliceSet::contains({:.6?}, {:?}) (slow path)", self, name);
        let mut iter = self.iter().await?;
        while let Some(item) = iter.next().await {
            if &item? == name {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Cheap membership test using the caches and `inner.contains_fast`.
    /// Returns `Ok(None)` when the answer cannot be determined cheaply.
    async fn contains_fast(&self, name: &VertexName) -> Result<Option<bool>> {
        // A hit in take_cache, or a complete take_cache, is authoritative.
        {
            let take_cache = self.take_cache.lock().await;
            let is_take_cache_complete = self.is_take_cache_complete();
            let contains = take_cache.contains(name);
            match (contains, is_take_cache_complete) {
                (_, true) | (true, _) => return Ok(Some(contains)),
                (false, false) => {}
            }
        }

        // A skipped vertex is never part of the slice.
        let skip_contains = self.skip_cache.lock().await.contains(name);
        if skip_contains {
            return Ok(Some(false));
        }

        let result = self.inner.contains_fast(name).await?;
        match (result, self.is_skip_cache_complete().await) {
            // Not in `inner` at all.
            (Some(false), _) => Ok(Some(false)),
            // In `inner`, not among the (fully known) skipped vertices, and
            // there is no upper bound: it must be in the slice.
            (Some(true), true) if self.take_count.is_none() => {
                debug_assert!(!self.skip_cache.lock().await.contains(name));
                Ok(Some(true))
            }
            // With a `take_count` the vertex may lie past the take window,
            // and the incomplete take_cache cannot rule that out.
            (Some(true), _) | (None, _) => Ok(None),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn hints(&self) -> &Hints {
        &self.hints
    }
}
impl fmt::Debug for SliceSet {
    /// Formats as `<slice INNER [SKIP..END]>`. `SKIP` is omitted when zero.
    /// `END` (`skip_count + take_count`) is omitted when there is no take limit.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("<slice")?;
        write_debug(f, &self.inner)?;
        f.write_str(" [")?;
        if self.skip_count > 0 {
            write!(f, "{}", self.skip_count)?;
        }
        f.write_str("..")?;
        if let Some(n) = self.take_count {
            // Saturate: `skip_count + take_count` can exceed `u64::MAX` (e.g.
            // a take of `u64::MAX`), which would panic in debug builds.
            write!(f, "{}", self.skip_count.saturating_add(n))?;
        }
        f.write_str("]>")
    }
}
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use nonblocking::non_blocking_result as r;

    use super::super::tests::*;
    use super::*;

    /// `count()` matches the expected slice size and invariants hold for a
    /// range of skip/take combinations, including out-of-range ones.
    #[test]
    fn test_basic() -> Result<()> {
        let orig = NameSet::from("a b c d e f g h i");
        let total = r(orig.count())?;
        // (skip_count, take_count, expected count)
        let cases = [
            (0, None, total),
            (0, Some(0), 0),
            (4, None, total - 4),
            (4, Some(0), 0),
            (0, Some(4), 4),
            (4, Some(4), 4),
            (7, Some(4), 2),
            (20, Some(4), 0),
            (20, Some(0), 0),
        ];
        for &(skip, take, expected) in &cases {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(r(set.count())?, expected);
            check_invariants(&set)?;
        }
        Ok(())
    }

    /// `Debug` renders the slice range in `[skip..end]` form.
    #[test]
    fn test_debug() {
        let orig = NameSet::from("a b c d e f g h i");
        let cases = [
            (0, None, "<slice <static [a, b, c] + 6 more> [..]>"),
            (4, None, "<slice <static [a, b, c] + 6 more> [4..]>"),
            (4, Some(4), "<slice <static [a, b, c] + 6 more> [4..8]>"),
            (0, Some(4), "<slice <static [a, b, c] + 6 more> [..4]>"),
        ];
        for &(skip, take, expected) in &cases {
            let set = SliceSet::new(orig.clone(), skip, take);
            assert_eq!(format!("{:?}", set), expected);
        }
    }

    quickcheck::quickcheck! {
        fn test_static_quickcheck(skip_and_take: u8) -> bool {
            // Low nibble: skip count. High nibble: take count, where values
            // above 12 mean "no limit".
            let skip = u64::from(skip_and_take & 0xf);
            let take = u64::from(skip_and_take >> 4);
            let take = if take <= 12 { Some(take) } else { None };
            let orig = NameSet::from("a c b d e f g i h j");
            let set = SliceSet::new(orig, skip, take);
            check_invariants(&set).unwrap();
            true
        }
    }
}