use crate::batch::WriteBatchIterator;
use crate::bytes_range::BytesRange;
use crate::error::SlateDBError;
use crate::filter_iterator::FilterIterator;
use crate::iter::{EmptyIterator, KeyValueIterator};
use crate::map_iter::MapIterator;
use crate::merge_iterator::MergeIterator;
use crate::merge_operator::{
MergeOperatorIterator, MergeOperatorRequiredIterator, MergeOperatorType,
};
use crate::types::{KeyValue, RowEntry, ValueDeletable};
use async_trait::async_trait;
use bytes::Bytes;
use parking_lot::Mutex;
use std::ops::RangeBounds;
use std::sync::Arc;
/// Records the smallest and largest keys reported through `track_key`.
///
/// The state sits behind a mutex so it can be updated through a shared
/// reference (`&self`).
#[derive(Debug)]
pub struct DbIteratorRangeTracker {
// Lock-guarded tracking state; see `DbIteratorRangeTrackerInner`.
inner: Mutex<DbIteratorRangeTrackerInner>,
}
/// Mutable state held inside `DbIteratorRangeTracker`'s mutex.
#[derive(Debug)]
struct DbIteratorRangeTrackerInner {
// Smallest key tracked so far; `None` until the first key is tracked.
first_key: Option<Bytes>,
// Largest key tracked so far; `None` until the first key is tracked.
last_key: Option<Bytes>,
// Set to `true` once any key has been tracked.
has_data: bool,
}
impl DbIteratorRangeTracker {
    /// Creates a tracker that has not observed any key yet.
    pub fn new() -> Self {
        let empty = DbIteratorRangeTrackerInner {
            first_key: None,
            last_key: None,
            has_data: false,
        };
        Self {
            inner: Mutex::new(empty),
        }
    }

    /// Records `key`, widening the tracked bounds when it falls outside them.
    ///
    /// The first tracked key becomes both the lower and the upper bound.
    pub fn track_key(&self, key: &Bytes) {
        let mut state = self.inner.lock();

        // Replace the lower bound only when it is unset or `key` sorts before it.
        let below_first = match &state.first_key {
            Some(lowest) => key < lowest,
            None => true,
        };
        if below_first {
            state.first_key = Some(key.clone());
        }

        // Replace the upper bound only when it is unset or `key` sorts after it.
        let above_last = match &state.last_key {
            Some(highest) => key > highest,
            None => true,
        };
        if above_last {
            state.last_key = Some(key.clone());
        }

        state.has_data = true;
    }

    /// Returns the inclusive range `[first_key, last_key]` covering every
    /// tracked key, or `None` when nothing has been tracked.
    pub fn get_range(&self) -> Option<BytesRange> {
        use std::ops::Bound;

        let state = self.inner.lock();
        state
            .first_key
            .as_ref()
            .zip(state.last_key.as_ref())
            .map(|(lowest, highest)| {
                BytesRange::from((
                    Bound::Included(lowest.clone()),
                    Bound::Included(highest.clone()),
                ))
            })
    }

    /// Returns `true` once at least one key has been tracked.
    pub fn has_data(&self) -> bool {
        let state = self.inner.lock();
        state.has_data
    }
}
/// Iterator used when the requested range is a single key.
///
/// It consults its source iterators one at a time, in the order they were
/// supplied to `GetIterator::new`.
struct GetIterator {
// The single key this iterator serves; `seek` rejects any other key.
key: Bytes,
// Source iterators, in lookup order.
iters: Vec<Box<dyn KeyValueIterator + 'static>>,
// Index into `iters` of the source currently being consulted.
idx: usize,
}
impl GetIterator {
    /// Builds a point-lookup iterator for `key`.
    ///
    /// Sources are consulted in this order: the write batch iterator, then
    /// `mem_iters`, then `l0_iters`, then `sr_iters`.
    pub(crate) fn new(
        key: Bytes,
        write_batch_iter: Box<dyn KeyValueIterator + 'static>,
        mem_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        l0_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        sr_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
    ) -> Self {
        let mut iters: Vec<Box<dyn KeyValueIterator + 'static>> = vec![write_batch_iter];
        iters.extend(mem_iters);
        iters.extend(l0_iters);
        iters.extend(sr_iters);

        Self { key, iters, idx: 0 }
    }
}
#[async_trait]
impl KeyValueIterator for GetIterator {
    /// No-op: each source is initialized lazily, right before it is read in
    /// `next_entry`.
    async fn init(&mut self) -> Result<(), SlateDBError> {
        Ok(())
    }

    /// Returns the next entry from the first source that still yields one.
    ///
    /// A tombstone ends the lookup with `Ok(None)`. Any other entry is
    /// returned as-is. A source that yields nothing is skipped and the next
    /// one is tried. Errors from a source's `init` or `next_entry` are
    /// propagated.
    async fn next_entry(&mut self) -> Result<Option<RowEntry>, SlateDBError> {
        while let Some(source) = self.iters.get_mut(self.idx) {
            source.init().await?;
            match source.next_entry().await? {
                Some(entry) if matches!(entry.value, ValueDeletable::Tombstone) => {
                    return Ok(None);
                }
                Some(entry) => return Ok(Some(entry)),
                // This source is exhausted; move on to the next one.
                None => self.idx += 1,
            }
        }
        Ok(None)
    }

    /// Accepts only a seek to the iterator's own key.
    ///
    /// Any other key fails with `SlateDBError::SeekKeyOutOfRange`, reporting
    /// the single-key range this iterator covers.
    async fn seek(&mut self, next_key: &[u8]) -> Result<(), SlateDBError> {
        if next_key == self.key {
            Ok(())
        } else {
            Err(SlateDBError::SeekKeyOutOfRange {
                key: next_key.to_vec(),
                range: BytesRange::from(self.key.clone()..=self.key.clone()),
            })
        }
    }
}
/// Iterator used when the requested range is not a single key.
///
/// All work is forwarded to a single combined iterator built in
/// `ScanIterator::new`.
struct ScanIterator {
// Combined iterator over the write batch, memtable, L0 and sorted-run sources.
delegate: Box<dyn KeyValueIterator + 'static>,
}
impl ScanIterator {
    /// Builds a scan iterator over all sources.
    ///
    /// `mem_iters`, `l0_iters` and `sr_iters` are each combined into their own
    /// `MergeIterator`. Those three, preceded by the write batch iterator, are
    /// then combined into one outer `MergeIterator`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `MergeIterator::new`.
    pub(crate) fn new(
        write_batch_iter: Box<dyn KeyValueIterator + 'static>,
        mem_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        l0_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        sr_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
    ) -> Result<Self, SlateDBError> {
        let mem_merged: Box<dyn KeyValueIterator + 'static> =
            Box::new(MergeIterator::new(mem_iters)?);
        let l0_merged: Box<dyn KeyValueIterator + 'static> =
            Box::new(MergeIterator::new(l0_iters)?);
        let sr_merged: Box<dyn KeyValueIterator + 'static> =
            Box::new(MergeIterator::new(sr_iters)?);

        // The order of the sources matches the original construction.
        let sources = vec![write_batch_iter, mem_merged, l0_merged, sr_merged];
        let delegate: Box<dyn KeyValueIterator + 'static> =
            Box::new(MergeIterator::new(sources)?);

        Ok(Self { delegate })
    }
}
// `ScanIterator` adds no logic of its own: every trait method forwards to `delegate`.
#[async_trait]
impl KeyValueIterator for ScanIterator {
// Initializes the combined iterator.
async fn init(&mut self) -> Result<(), SlateDBError> {
self.delegate.init().await
}
// Returns whatever the combined iterator yields next.
async fn next_entry(&mut self) -> Result<Option<RowEntry>, SlateDBError> {
self.delegate.next_entry().await
}
// Forwards the seek to the combined iterator.
async fn seek(&mut self, next_key: &[u8]) -> Result<(), SlateDBError> {
self.delegate.seek(next_key).await
}
}
/// Iterator over key-value pairs within a key range.
///
/// Once the underlying iterator returns an error, that error is remembered
/// and returned by every later `next`/`seek` call.
pub struct DbIterator {
// Range this iterator was created for; `seek` rejects keys outside it.
range: BytesRange,
// Underlying iterator chain assembled in `DbIterator::new`.
iter: Box<dyn KeyValueIterator + 'static>,
// First error produced by `iter`, if any; makes the iterator permanently fail.
invalidated_error: Option<SlateDBError>,
// Key of the most recently returned key-value pair; `seek` must go past it.
last_key: Option<Bytes>,
// Optional tracker that is told about every key returned to the caller.
range_tracker: Option<Arc<DbIteratorRangeTracker>>,
}
impl DbIterator {
    /// Assembles and initializes an iterator over `range`.
    ///
    /// * `write_batch_iter` - optional write batch source; an `EmptyIterator`
    ///   is used when it is `None`. It is not passed through `apply_filters`.
    /// * `mem_iters`, `l0_iters`, `sr_iters` - sources that are each wrapped by
    ///   `apply_filters` with `max_seq` and `now`.
    /// * `range_tracker` - when present, receives every key returned to the
    ///   caller.
    /// * `merge_operator` - when present, the chain is wrapped in a
    ///   `MergeOperatorIterator`; otherwise in a
    ///   `MergeOperatorRequiredIterator`.
    ///
    /// A `range` that is a single key (`range.as_point()`) is served by a
    /// `GetIterator`; any other range by a `ScanIterator`.
    ///
    /// # Errors
    ///
    /// Propagates errors from building the `ScanIterator` and from
    /// initializing the assembled iterator.
    pub(crate) async fn new(
        range: BytesRange,
        write_batch_iter: Option<WriteBatchIterator>,
        mem_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        l0_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        sr_iters: impl IntoIterator<Item = Box<dyn KeyValueIterator + 'static>>,
        max_seq: Option<u64>,
        range_tracker: Option<Arc<DbIteratorRangeTracker>>,
        now: i64,
        merge_operator: Option<MergeOperatorType>,
    ) -> Result<Self, SlateDBError> {
        let write_batch_iter: Box<dyn KeyValueIterator + 'static> = match write_batch_iter {
            Some(batch_iter) => Box::new(batch_iter),
            None => Box::new(EmptyIterator::new()),
        };

        let mem_iters = apply_filters(mem_iters, max_seq, now);
        let l0_iters = apply_filters(l0_iters, max_seq, now);
        let sr_iters = apply_filters(sr_iters, max_seq, now);

        // Point lookups and scans use different source-combining iterators.
        let base: Box<dyn KeyValueIterator + 'static> = if let Some(key) = range.as_point() {
            Box::new(GetIterator::new(
                key.clone(),
                write_batch_iter,
                mem_iters,
                l0_iters,
                sr_iters,
            ))
        } else {
            Box::new(ScanIterator::new(
                write_batch_iter,
                mem_iters,
                l0_iters,
                sr_iters,
            )?)
        };

        let mut iter: Box<dyn KeyValueIterator + 'static> = match merge_operator {
            Some(merge_operator) => Box::new(MergeOperatorIterator::new(
                merge_operator,
                base,
                true,
                now,
                None,
            )),
            None => Box::new(MergeOperatorRequiredIterator::new(base)),
        };
        iter.init().await?;

        Ok(DbIterator {
            range,
            iter,
            invalidated_error: None,
            last_key: None,
            range_tracker,
        })
    }

    /// Returns the next key-value pair, or `None` when the iterator is
    /// exhausted. Errors are converted to the public `crate::Error` type.
    pub async fn next(&mut self) -> Result<Option<KeyValue>, crate::Error> {
        let outcome = self.next_key_value().await;
        outcome.map_err(Into::into)
    }

    /// Crate-internal form of `next` that keeps the `SlateDBError` type.
    ///
    /// If the iterator was invalidated by an earlier error, that error is
    /// returned again. On success the returned key is remembered as the last
    /// key and reported to the range tracker, if one is attached.
    pub(crate) async fn next_key_value(&mut self) -> Result<Option<KeyValue>, SlateDBError> {
        if let Some(error) = &self.invalidated_error {
            return Err(error.clone());
        }

        let fetched = self.iter.next().await;
        let fetched = self.maybe_invalidate(fetched);

        if let Ok(Some(kv)) = &fetched {
            self.last_key = Some(kv.key.clone());
            if let Some(tracker) = &self.range_tracker {
                tracker.track_key(&kv.key);
            }
        }
        fetched
    }

    /// Remembers the error of a failed `result` so that later calls fail with
    /// it too, then hands `result` back unchanged.
    fn maybe_invalidate<T: Clone>(
        &mut self,
        result: Result<T, SlateDBError>,
    ) -> Result<T, SlateDBError> {
        result.map_err(|error| {
            self.invalidated_error = Some(error.clone());
            error
        })
    }

    /// Moves the iterator forward to `next_key`.
    ///
    /// # Errors
    ///
    /// * the remembered error, if the iterator was invalidated earlier;
    /// * `SeekKeyOutOfRange` if `next_key` lies outside the iterator's range;
    /// * `SeekKeyLessThanLastReturnedKey` if `next_key` is less than or equal
    ///   to the last key returned;
    /// * any error from the underlying iterator's `seek`, which also
    ///   invalidates this iterator.
    pub async fn seek<K: AsRef<[u8]>>(&mut self, next_key: K) -> Result<(), crate::Error> {
        let target = next_key.as_ref();

        if let Some(error) = &self.invalidated_error {
            return Err(error.clone().into());
        }

        if !self.range.contains(&target) {
            return Err(SlateDBError::SeekKeyOutOfRange {
                key: target.to_vec(),
                range: self.range.clone(),
            }
            .into());
        }

        // Seeking to, or before, the last returned key would rewind the iterator.
        let rewinds = self
            .last_key
            .as_ref()
            .is_some_and(|last_key| target <= *last_key);
        if rewinds {
            return Err(SlateDBError::SeekKeyLessThanLastReturnedKey.into());
        }

        let sought = self.iter.seek(target).await;
        self.maybe_invalidate(sought).map_err(Into::into)
    }
}
/// Wraps every iterator in `iters` with the standard read-path adapters and
/// boxes the result.
///
/// Each iterator is wrapped first by `FilterIterator::new_with_max_seq` using
/// `max_seq`, then by `MapIterator::new_with_ttl_now` using `now`. The output
/// keeps the input order.
pub(crate) fn apply_filters<T>(
    iters: impl IntoIterator<Item = T>,
    max_seq: Option<u64>,
    now: i64,
) -> Vec<Box<dyn KeyValueIterator>>
where
    T: KeyValueIterator + 'static,
{
    let mut wrapped: Vec<Box<dyn KeyValueIterator>> = Vec::new();
    for source in iters {
        let seq_bounded = FilterIterator::new_with_max_seq(source, max_seq);
        let ttl_aware = MapIterator::new_with_ttl_now(seq_bounded, now);
        wrapped.push(Box::new(ttl_aware));
    }
    wrapped
}
// Unit tests for `DbIterator`: error invalidation, sequence-number filtering,
// seek ordering rules, write batch sources, and TTL handling.
#[cfg(test)]
mod tests {
use crate::batch::{WriteBatch, WriteBatchIterator};
use crate::bytes_range::BytesRange;
use crate::db_iter::DbIterator;
use crate::error::SlateDBError;
use crate::iter::{IterationOrder, KeyValueIterator};
use crate::test_utils::TestIterator;
use crate::types::RowEntry;
use bytes::Bytes;
use std::collections::VecDeque;
// Once an error is recorded, both `next` and `seek` must keep returning it.
#[tokio::test]
async fn test_invalidated_iterator() {
let mem_iters: VecDeque<Box<dyn KeyValueIterator + 'static>> = VecDeque::new();
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
mem_iters,
VecDeque::new(),
VecDeque::new(),
None,
None,
0,
None,
)
.await
.unwrap();
// Simulate an earlier failure by setting the remembered error directly.
iter.invalidated_error = Some(SlateDBError::ChecksumMismatch);
let result = iter.next().await;
let err = result.expect_err("Failed to return invalidated iterator");
assert_invalidated_iterator_error(err);
let result = iter.seek(Bytes::new()).await;
let err = result.expect_err("Failed to return invalidated iterator");
assert_invalidated_iterator_error(err);
}
// Checks that `err` renders as the checksum-mismatch message.
fn assert_invalidated_iterator_error(err: crate::Error) {
assert_eq!(err.to_string(), "Data error: checksum mismatch");
}
// With `max_seq = 100`, the seq-110 version of key1 must be skipped and the
// seq-96 version returned; the key is returned only once.
#[tokio::test]
async fn test_sequence_number_filtering() {
let mem_iter1 = TestIterator::new()
.with_entry(b"key1", b"value1", 96)
.with_entry(b"key1", b"value2", 110);
let mem_iter2 = TestIterator::new().with_entry(b"key1", b"value3", 95);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![
Box::new(mem_iter1) as Box<dyn KeyValueIterator + 'static>,
Box::new(mem_iter2) as Box<dyn KeyValueIterator + 'static>,
],
VecDeque::new(),
VecDeque::new(),
Some(100),
None,
0,
None,
)
.await
.unwrap();
let result = iter.next().await.unwrap();
assert!(result.is_some());
let kv = result.unwrap();
assert_eq!(kv.key, Bytes::from("key1"));
assert_eq!(kv.value, Bytes::from("value1"));
assert!(iter.next().await.unwrap().is_none());
}
// After key1 has been returned, seeking to key1 or to an earlier key must
// fail, while seeking forward to key2 must succeed.
#[tokio::test]
async fn test_seek_cannot_rewind() {
let mem_iter = TestIterator::new()
.with_entry(b"key1", b"value1", 1)
.with_entry(b"key2", b"value2", 2);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![Box::new(mem_iter) as Box<dyn KeyValueIterator + 'static>],
VecDeque::new(),
VecDeque::new(),
None,
None,
0,
None,
)
.await
.unwrap();
let first = iter.next().await.unwrap().unwrap();
assert_eq!(first.key, Bytes::from_static(b"key1"));
// Seeking to the key that was just returned is rejected.
let err = iter.seek(b"key1").await.unwrap_err();
assert_eq!(
err.to_string(),
"Invalid error: cannot seek to a key less than the last returned key"
);
// Seeking to a key before the last returned key is rejected too.
let err = iter.seek(b"key0").await.unwrap_err();
assert_eq!(
err.to_string(),
"Invalid error: cannot seek to a key less than the last returned key"
);
// A forward seek is allowed and iteration continues from there.
iter.seek(b"key2").await.unwrap();
let kv = iter.next().await.unwrap().unwrap();
assert_eq!(kv.key, Bytes::from_static(b"key2"));
assert!(iter.next().await.unwrap().is_none());
}
// Entries that exist only in the write batch are returned in key order.
#[tokio::test]
async fn test_dbiterator_with_writebatch() {
let mut batch = WriteBatch::new();
batch.put(b"key1", b"value1");
batch.put(b"key3", b"value3");
let wb_iter = WriteBatchIterator::new(batch.clone(), .., IterationOrder::Ascending);
let mem_iters: VecDeque<Box<dyn KeyValueIterator + 'static>> = VecDeque::new();
let mut iter = DbIterator::new(
BytesRange::from(..),
Some(wb_iter),
mem_iters,
VecDeque::new(),
VecDeque::new(),
None,
None,
0,
None,
)
.await
.unwrap();
let kv1 = iter.next().await.unwrap().unwrap();
assert_eq!(kv1.key, Bytes::from_static(b"key1"));
assert_eq!(kv1.value, Bytes::from_static(b"value1"));
let kv2 = iter.next().await.unwrap().unwrap();
assert_eq!(kv2.key, Bytes::from_static(b"key3"));
assert_eq!(kv2.value, Bytes::from_static(b"value3"));
let kv3 = iter.next().await.unwrap();
assert!(kv3.is_none());
}
// Runs the same three entries (key1 expires at 50, key2 expires at 100,
// key3 never expires) against four values of `now`: 49, 50, 100 and 200.
#[tokio::test]
async fn test_dbiterator_with_ttl_filtering() {
// Case 1: now = 49, before any expiry, so all three keys are returned.
let mut entry1 = RowEntry::new_value(b"key1", b"value1", 1);
entry1.create_ts = Some(0);
entry1.expire_ts = Some(50);
let mut entry2 = RowEntry::new_value(b"key2", b"value2", 2);
entry2.create_ts = Some(0);
entry2.expire_ts = Some(100);
let mut entry3 = RowEntry::new_value(b"key3", b"value3", 3);
entry3.create_ts = Some(0);
entry3.expire_ts = None;
let mem_iter = TestIterator::new()
.with_row_entry(entry1)
.with_row_entry(entry2)
.with_row_entry(entry3);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![Box::new(mem_iter) as Box<dyn KeyValueIterator + 'static>],
VecDeque::new(),
VecDeque::new(),
None,
None,
49,
None,
)
.await
.unwrap();
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key1")
);
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key2")
);
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key3")
);
assert!(iter.next().await.unwrap().is_none());
// Case 2: now = 50, exactly key1's expiry time, so key1 is no longer returned.
let mut entry1 = RowEntry::new_value(b"key1", b"value1", 1);
entry1.create_ts = Some(0);
entry1.expire_ts = Some(50);
let mut entry2 = RowEntry::new_value(b"key2", b"value2", 2);
entry2.create_ts = Some(0);
entry2.expire_ts = Some(100);
let mut entry3 = RowEntry::new_value(b"key3", b"value3", 3);
entry3.create_ts = Some(0);
entry3.expire_ts = None;
let mem_iter = TestIterator::new()
.with_row_entry(entry1)
.with_row_entry(entry2)
.with_row_entry(entry3);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![Box::new(mem_iter) as Box<dyn KeyValueIterator + 'static>],
VecDeque::new(),
VecDeque::new(),
None,
None,
50,
None,
)
.await
.unwrap();
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key2")
);
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key3")
);
assert!(iter.next().await.unwrap().is_none());
// Case 3: now = 100, exactly key2's expiry time, so only key3 remains.
let mut entry1 = RowEntry::new_value(b"key1", b"value1", 1);
entry1.create_ts = Some(0);
entry1.expire_ts = Some(50);
let mut entry2 = RowEntry::new_value(b"key2", b"value2", 2);
entry2.create_ts = Some(0);
entry2.expire_ts = Some(100);
let mut entry3 = RowEntry::new_value(b"key3", b"value3", 3);
entry3.create_ts = Some(0);
entry3.expire_ts = None;
let mem_iter = TestIterator::new()
.with_row_entry(entry1)
.with_row_entry(entry2)
.with_row_entry(entry3);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![Box::new(mem_iter) as Box<dyn KeyValueIterator + 'static>],
VecDeque::new(),
VecDeque::new(),
None,
None,
100,
None,
)
.await
.unwrap();
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key3")
);
assert!(iter.next().await.unwrap().is_none());
// Case 4: now = 200, well past both expiries; key3 has no expiry and is still returned.
let mut entry1 = RowEntry::new_value(b"key1", b"value1", 1);
entry1.create_ts = Some(0);
entry1.expire_ts = Some(50);
let mut entry2 = RowEntry::new_value(b"key2", b"value2", 2);
entry2.create_ts = Some(0);
entry2.expire_ts = Some(100);
let mut entry3 = RowEntry::new_value(b"key3", b"value3", 3);
entry3.create_ts = Some(0);
entry3.expire_ts = None;
let mem_iter = TestIterator::new()
.with_row_entry(entry1)
.with_row_entry(entry2)
.with_row_entry(entry3);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![Box::new(mem_iter) as Box<dyn KeyValueIterator + 'static>],
VecDeque::new(),
VecDeque::new(),
None,
None,
200,
None,
)
.await
.unwrap();
assert_eq!(
iter.next().await.unwrap().unwrap().key,
Bytes::from_static(b"key3")
);
assert!(iter.next().await.unwrap().is_none());
}
// A newer (seq 100) but already expired value must not let the older (seq 50),
// non-expiring value of the same key show through: nothing is returned.
#[tokio::test]
async fn test_dbiterator_expired_value_hides_older_valid_value() {
let mut newer_entry = RowEntry::new_value(b"key1", b"newer_value", 100);
newer_entry.create_ts = Some(0);
newer_entry.expire_ts = Some(50);
let mut older_entry = RowEntry::new_value(b"key1", b"older_value", 50);
older_entry.create_ts = Some(0);
older_entry.expire_ts = None;
let mem_iter = TestIterator::new()
.with_row_entry(newer_entry)
.with_row_entry(older_entry);
let mut iter = DbIterator::new(
BytesRange::from(..),
None,
vec![Box::new(mem_iter) as Box<dyn KeyValueIterator + 'static>],
VecDeque::new(),
VecDeque::new(),
None,
None,
100, None,
)
.await
.unwrap();
assert!(iter.next().await.unwrap().is_none());
}
}