use crate::handles::{AsyncCache, Cache};
use std::collections::VecDeque;
use std::future::Future;
use std::hash::{BuildHasher, Hash};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Resumable position of a batched scan over the cache's shards.
///
/// A scan resumes by re-opening shard `shard_index` and skipping the first
/// `items_seen_in_shard` entries it yields.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Cursor {
    /// Index of the shard currently being scanned.
    shard_index: usize,
    /// How many entries of that shard have already been scanned (expired ones included).
    items_seen_in_shard: usize,
}
/// Default number of entries buffered per refill by the batched cache iterators.
pub const DEFAULT_ITER_BATCH_SIZE: usize = 64;
/// Blocking iterator over the non-expired entries of a [`Cache`].
///
/// Entries are copied out of the shards in batches into an internal buffer. The scan
/// position between batches is kept in a [`Cursor`].
pub struct Iter<'a, K, V, H>
where
    K: Send,
    V: Send + Sync,
{
    /// Cache being iterated.
    cache: &'a Cache<K, V, H>,
    /// Entries collected by the most recent refill that have not been yielded yet.
    buffer: VecDeque<(K, Arc<V>)>,
    /// Where the next refill resumes scanning.
    cursor: Cursor,
    /// Upper bound on the number of entries collected per refill.
    batch_size: usize,
    /// Becomes `true` once every shard has been scanned.
    finished: bool,
}
impl<'a, K, V, H> Iter<'a, K, V, H>
where
    K: Eq + Hash + Clone + Send,
    V: Send + Sync,
    H: BuildHasher + Clone,
{
    /// Creates an iterator that scans `cache` in batches of at most `batch_size` entries.
    ///
    /// A `batch_size` of `0` is treated as `1`. With a zero batch, `refill_buffer` would never
    /// collect anything, and iteration would end before visiting a single entry.
    pub(crate) fn new(cache: &'a Cache<K, V, H>, batch_size: usize) -> Self {
        let batch_size = batch_size.max(1);
        Self {
            cache,
            buffer: VecDeque::with_capacity(batch_size),
            cursor: Cursor::default(),
            batch_size,
            finished: false,
        }
    }

    /// Scans forward from the cursor until the buffer holds `batch_size` entries or every
    /// shard has been scanned. Once the last shard is exhausted, marks the iterator `finished`.
    ///
    /// Each shard's read lock is held only while that shard is being scanned. Entries for
    /// which `is_expired(time_to_idle)` is true are not buffered, but they still advance the
    /// cursor.
    ///
    /// NOTE(review): the cursor is a positional count re-applied with `skip`. If a shard is
    /// modified between refills, entries may be skipped or yielded twice, depending on the
    /// map's iteration order — confirm this weak consistency is acceptable. `skip` also
    /// re-walks the already-scanned prefix of the shard on every refill.
    fn refill_buffer(&mut self) {
        if self.finished {
            return;
        }
        let cache = self.cache;
        let shards = &cache.shared.store.shards;
        let num_shards = shards.len();
        while self.cursor.shard_index < num_shards && self.buffer.len() < self.batch_size {
            let guard = shards[self.cursor.shard_index].map.read();
            if self.cursor.items_seen_in_shard >= guard.len() {
                // Current shard fully scanned: move on to the next one.
                self.cursor.shard_index += 1;
                self.cursor.items_seen_in_shard = 0;
                continue;
            }
            let needed = self.batch_size - self.buffer.len();
            let mut scanned = 0;
            for (key, entry) in guard
                .iter()
                .skip(self.cursor.items_seen_in_shard)
                .take(needed)
            {
                scanned += 1;
                if !entry.is_expired(cache.shared.time_to_idle) {
                    self.buffer.push_back((key.clone(), entry.value()));
                }
            }
            // Count every scanned entry, expired or not, so the next refill skips past them.
            self.cursor.items_seen_in_shard += scanned;
        }
        if self.cursor.shard_index >= num_shards {
            self.finished = true;
        }
    }
}
impl<'a, K, V, H> Iterator for Iter<'a, K, V, H>
where
    K: Eq + Hash + Clone + Send,
    V: Send + Sync,
    H: BuildHasher + Clone,
{
    type Item = (K, Arc<V>);

    /// Yields the next buffered entry. If the buffer is empty and the scan is not yet
    /// finished, it is refilled from the cache first.
    fn next(&mut self) -> Option<Self::Item> {
        let needs_refill = self.buffer.is_empty() && !self.finished;
        if needs_refill {
            self.refill_buffer();
        }
        self.buffer.pop_front()
    }
}
/// Boxed future performing one asynchronous refill for [`IterStream`].
///
/// Resolves to the collected entries, the cursor to resume from, and whether every shard
/// has been scanned.
type RefillFuture<K, V> =
    Pin<Box<dyn Future<Output = (VecDeque<(K, Arc<V>)>, Cursor, bool)> + Send + 'static>>;

/// Asynchronous, batched stream over the non-expired entries of an [`AsyncCache`].
///
/// The stream owns a clone of the cache handle, so it carries no borrow lifetime.
pub struct IterStream<K, V, H>
where
    K: Send,
    V: Send + Sync,
{
    /// Owned handle to the cache being streamed.
    cache: AsyncCache<K, V, H>,
    /// Entries from the last completed refill that have not been yielded yet.
    buffer: VecDeque<(K, Arc<V>)>,
    /// Where the next refill resumes scanning.
    cursor: Cursor,
    /// Upper bound on the number of entries collected per refill.
    batch_size: usize,
    /// Becomes `true` once every shard has been scanned.
    finished: bool,
    /// Refill in progress, kept across `Pending` polls.
    refill_future: Option<RefillFuture<K, V>>,
}
impl<K, V, H> IterStream<K, V, H>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Send + Sync,
    H: BuildHasher + Clone + Send,
{
    /// Creates a stream over a clone of `cache`, refilling in batches of at most
    /// `batch_size` entries.
    ///
    /// A `batch_size` of `0` is treated as `1`. With a zero batch, a refill could complete
    /// without collecting any entry and without reaching the end of the cache, so the stream
    /// could not make progress.
    pub(crate) fn new(cache: &AsyncCache<K, V, H>, batch_size: usize) -> Self {
        let batch_size = batch_size.max(1);
        Self {
            cache: cache.clone(),
            buffer: VecDeque::with_capacity(batch_size),
            cursor: Cursor::default(),
            batch_size,
            finished: false,
            refill_future: None,
        }
    }
}
// No field of `IterStream` is structurally pinned. The only self-referential state is the
// refill future, which already lives behind its own `Pin<Box<..>>`. Declaring the type
// `Unpin` is therefore sound, and it lets `poll_next` use the safe `Pin::get_mut` instead of
// `get_unchecked_mut`.
impl<K: Send, V: Send + Sync, H> Unpin for IterStream<K, V, H> {}

impl<K, V, H> futures_util::Stream for IterStream<K, V, H>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
    H: BuildHasher + Clone + Send + Sync + 'static,
{
    type Item = (K, Arc<V>);

    /// Yields the next buffered entry. When the buffer is empty, starts (or resumes) an
    /// asynchronous refill that scans shards from the cursor.
    ///
    /// Returns `Pending` only when the in-flight refill future does. That future is kept in
    /// `refill_future` so partially completed work is not restarted on the next poll.
    /// Returns `Ready(None)` once every shard has been scanned and the buffer is drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            if let Some(item) = this.buffer.pop_front() {
                return Poll::Ready(Some(item));
            }
            if this.finished {
                return Poll::Ready(None);
            }
            if let Some(ref mut fut) = this.refill_future {
                let (batch, new_cursor, finished_flag) = match fut.as_mut().poll(cx) {
                    Poll::Ready(output) => output,
                    Poll::Pending => return Poll::Pending,
                };
                this.refill_future = None;
                this.cursor = new_cursor;
                this.finished = finished_flag;
                this.buffer.extend(batch);
                // Loop back: the batch is non-empty or `finished` is now set.
            } else {
                let cache_clone = this.cache.clone();
                let cursor_snapshot = this.cursor;
                // A zero batch would let a refill finish with no entries and no progress.
                let batch_size = this.batch_size.max(1);
                let time_to_idle = cache_clone.shared.time_to_idle;
                let fut = Box::pin(async move {
                    let mut cursor = cursor_snapshot;
                    let mut local_buf = VecDeque::new();
                    let num_shards = cache_clone.shared.store.shards.len();
                    while cursor.shard_index < num_shards && local_buf.len() < batch_size {
                        let shard = &cache_clone.shared.store.shards[cursor.shard_index];
                        let guard = shard.map.read_async().await;
                        if cursor.items_seen_in_shard >= guard.len() {
                            // Current shard fully scanned: move on to the next one.
                            cursor.shard_index += 1;
                            cursor.items_seen_in_shard = 0;
                            continue;
                        }
                        let needed = batch_size - local_buf.len();
                        let mut scanned = 0;
                        for (key, entry) in
                            guard.iter().skip(cursor.items_seen_in_shard).take(needed)
                        {
                            scanned += 1;
                            if !entry.is_expired(time_to_idle) {
                                local_buf.push_back((key.clone(), entry.value()));
                            }
                        }
                        // Expired entries still advance the cursor so they are not rescanned.
                        cursor.items_seen_in_shard += scanned;
                    }
                    let finished = cursor.shard_index >= num_shards;
                    (local_buf, cursor, finished)
                });
                // Loop back to poll the freshly created refill with the current context.
                this.refill_future = Some(fut);
            }
        }
    }
}
/// Blocking iterator that works one shard at a time. It copies the shard's keys under the
/// shard's read lock, then looks each key up in the cache after the lock is released.
pub struct SnapshotIter<'a, K, V, H>
where
    K: Send,
    V: Send + Sync,
{
    /// Cache being iterated.
    cache: &'a Cache<K, V, H>,
    /// Keys copied from the most recently loaded shard.
    shard_keys: Vec<K>,
    /// Position in `shard_keys` of the next key to look up.
    key_idx: usize,
    /// Index of the next shard to load.
    shard_idx: usize,
}
impl<'a, K, V, H> SnapshotIter<'a, K, V, H>
where
    K: Eq + Hash + Clone + Send,
    V: Send + Sync,
    H: BuildHasher + Clone,
{
    /// Creates an iterator positioned before the first shard of `cache`.
    pub(crate) fn new(cache: &'a Cache<K, V, H>) -> Self {
        Self {
            cache,
            shard_idx: 0,
            key_idx: 0,
            shard_keys: Vec::new(),
        }
    }

    /// Advances to the next non-empty shard and copies its keys into `shard_keys`.
    ///
    /// Returns `false` when no shard remains. The shard's read lock is released before
    /// returning.
    fn load_next_shard(&mut self) -> bool {
        let total = self.cache.shared.store.shards.len();
        for idx in self.shard_idx..total {
            self.shard_idx = idx + 1;
            let guard = self.cache.shared.store.shards[idx].map.read();
            if guard.is_empty() {
                continue;
            }
            self.shard_keys = guard.keys().cloned().collect();
            self.key_idx = 0;
            return true;
        }
        false
    }
}
impl<'a, K, V, H> Iterator for SnapshotIter<'a, K, V, H>
where
    K: Eq + Hash + Clone + Send,
    V: Send + Sync,
    H: BuildHasher + Clone,
{
    type Item = (K, Arc<V>);

    /// Looks up the remaining snapshot keys of the current shard, then moves on to the
    /// next shard.
    ///
    /// Keys for which `fetch` returns `None` are skipped. Presumably these are entries
    /// removed or expired since the snapshot — confirm.
    /// NOTE(review): every yielded entry goes through `fetch`; confirm that counting as an
    /// access here is intended.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            while let Some(key) = self.shard_keys.get(self.key_idx) {
                self.key_idx += 1;
                if let Some(value) = self.cache.fetch(key) {
                    return Some((key.clone(), value));
                }
            }
            if !self.load_next_shard() {
                return None;
            }
        }
    }
}
/// Asynchronous counterpart of the snapshot iterator. Shard by shard, it copies the keys
/// under the shard's async read lock, then fetches each key with `AsyncCache::fetch`.
pub struct AsyncSnapshotIter<'a, K, V, H>
where
    K: Send,
    V: Send + Sync,
{
    /// Cache being iterated.
    cache: &'a AsyncCache<K, V, H>,
    /// Keys copied from the most recently loaded shard.
    shard_keys: Vec<K>,
    /// Position in `shard_keys` of the next key to look up.
    key_idx: usize,
    /// Index of the next shard to load.
    shard_idx: usize,
}
impl<'a, K, V, H> AsyncSnapshotIter<'a, K, V, H>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Send + Sync,
    H: BuildHasher + Clone + Send,
{
    /// Creates an iterator positioned before the first shard of `cache`.
    pub(crate) fn new(cache: &'a AsyncCache<K, V, H>) -> Self {
        Self {
            cache,
            shard_idx: 0,
            key_idx: 0,
            shard_keys: Vec::new(),
        }
    }

    /// Returns the next entry whose key was present in a shard snapshot and is still
    /// fetchable. Returns `None` once every shard has been visited.
    ///
    /// Keys for which `fetch` resolves to `None` are skipped.
    pub async fn next(&mut self) -> Option<(K, Arc<V>)> {
        loop {
            while let Some(key) = self.shard_keys.get(self.key_idx) {
                self.key_idx += 1;
                if let Some(value) = self.cache.fetch(key).await {
                    return Some((key.clone(), value));
                }
            }
            if !self.load_next_shard().await {
                return None;
            }
        }
    }

    /// Advances to the next non-empty shard and copies its keys into `shard_keys`.
    ///
    /// Returns `false` when no shard remains. The shard's read guard is dropped before
    /// returning.
    async fn load_next_shard(&mut self) -> bool {
        let total = self.cache.shared.store.shards.len();
        for idx in self.shard_idx..total {
            self.shard_idx = idx + 1;
            let guard = self.cache.shared.store.shards[idx].map.read_async().await;
            if guard.is_empty() {
                continue;
            }
            self.shard_keys = guard.keys().cloned().collect();
            self.key_idx = 0;
            return true;
        }
        false
    }
}