use crate::block_iterator::BlockIterator;
use crate::block_iterator_v2::BlockIteratorV2;
use crate::cached_object_store::CachedObjectStore;
use crate::config::PreloadLevel;
use crate::db_state::ManifestCore;
use crate::db_state::SortedRun;
use crate::db_state::SsTableHandle;
use crate::error::SlateDBError;
use crate::format::sst::{SST_FORMAT_VERSION, SST_FORMAT_VERSION_V2};
use crate::iter::{IterationOrder, RowEntryIterator};
use crate::paths::PathResolver;
use crate::tablestore::TableStore;
use bytes::{BufMut, Bytes};
use futures::FutureExt;
use log::{error, warn};
use rand::{Rng, RngCore};
use slatedb_common::clock::SystemClock;
use std::any::Any;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use ulid::Ulid;
use uuid::Uuid;
use futures::StreamExt;
use std::collections::VecDeque;
/// Shared empty key; `compute_index_key` returns a clone of it when there is
/// no previous block to compare against.
static EMPTY_KEY: Bytes = Bytes::new();
/// A cell backed by a tokio `watch` channel carrying `Option<T>`.
///
/// The channel starts out holding `None`. The struct keeps both ends of the
/// channel: the sender is used to publish a value, and the receiver is cloned
/// to hand out [`WatchableOnceCellReader`]s.
#[derive(Clone, Debug)]
pub(crate) struct WatchableOnceCell<T: Clone> {
    /// Receiving end of the watch channel; cloned for every reader.
    rx: tokio::sync::watch::Receiver<Option<T>>,
    /// Sending end of the watch channel; used to store the value.
    tx: tokio::sync::watch::Sender<Option<T>>,
}
/// Read-only handle onto a [`WatchableOnceCell`].
///
/// Holds only the receiving end of the underlying watch channel, so it can
/// observe (or wait for) the value but cannot set it.
#[derive(Clone, Debug)]
pub(crate) struct WatchableOnceCellReader<T: Clone> {
    /// Receiving end of the watch channel carrying the optional value.
    rx: tokio::sync::watch::Receiver<Option<T>>,
}
impl<T: Clone> WatchableOnceCell<T> {
    /// Creates an empty cell (the underlying channel initially holds `None`).
    pub(crate) fn new() -> Self {
        let (tx, rx) = tokio::sync::watch::channel(None);
        Self { tx, rx }
    }

    /// Stores `val` if the cell is still empty.
    ///
    /// Returns `true` when the value was stored (and watchers were notified),
    /// `false` when the cell already held a value, in which case `val` is
    /// discarded and the existing value is left untouched.
    pub(crate) fn write(&self, val: T) -> bool {
        self.tx.send_if_modified(|slot| {
            if slot.is_some() {
                // First write wins; do not notify watchers again.
                false
            } else {
                *slot = Some(val);
                true
            }
        })
    }

    /// Returns a new reader observing this cell.
    pub(crate) fn reader(&self) -> WatchableOnceCellReader<T> {
        let rx = self.rx.clone();
        WatchableOnceCellReader { rx }
    }
}
impl<T: Clone> WatchableOnceCellReader<T> {
    /// Returns a clone of the current value, or `None` if nothing has been
    /// written yet.
    pub(crate) fn read(&self) -> Option<T> {
        let current = self.rx.borrow();
        current.clone()
    }

    /// Waits until the cell holds a value and returns a clone of it.
    ///
    /// # Panics
    ///
    /// Panics with "watch channel closed" if waiting on the channel fails.
    pub(crate) async fn await_value(&mut self) -> T {
        let populated = self
            .rx
            .wait_for(|slot| slot.is_some())
            .await
            .expect("watch channel closed");
        populated.clone().expect("no value found")
    }
}
/// Spawns `future` on `handle` as a background task with panic capture.
///
/// If the future panics, the panic is caught, logged together with the task
/// `name`, and converted into `SlateDBError::BackgroundTaskPanic(name)`.
/// `cleanup_fn` is then invoked with a reference to the final result (success,
/// error, or converted panic) before that result is returned from the task.
pub(crate) fn spawn_bg_task<F, T, C>(
    name: String,
    handle: &tokio::runtime::Handle,
    cleanup_fn: C,
    future: F,
) -> tokio::task::JoinHandle<Result<T, SlateDBError>>
where
    F: Future<Output = Result<T, SlateDBError>> + Send + 'static,
    T: Send + 'static,
    C: FnOnce(&Result<T, SlateDBError>) + Send + 'static,
{
    let guarded = async move {
        let outcome = AssertUnwindSafe(future).catch_unwind().await;
        let result = outcome.unwrap_or_else(|payload| {
            error!(
                "spawned task panicked. [name={}, panic={}]",
                name,
                panic_string(&payload)
            );
            Err(SlateDBError::BackgroundTaskPanic(name))
        });
        // Cleanup always runs, regardless of how the task finished.
        cleanup_fn(&result);
        result
    };
    handle.spawn(guarded)
}
/// Combines two optional values.
///
/// When both are present they are merged with `f(current, next)`; when only
/// one is present it is returned unchanged; when neither is present the
/// result is `None`.
pub(crate) fn merge_options<T>(
    current: Option<T>,
    next: Option<T>,
    f: impl Fn(T, T) -> T,
) -> Option<T> {
    let Some(lhs) = current else {
        return next;
    };
    Some(match next {
        Some(rhs) => f(lhs, rhs),
        None => lhs,
    })
}
/// Returns the key and sequence number of the first entry produced when
/// iterating the last block of `output_sst` in descending order.
///
/// Returns `Ok(None)` when the SST index lists no blocks, when no block is
/// returned for the last index entry, or when the block iterator yields no
/// entry.
///
/// # Errors
///
/// Propagates errors from reading the index or block and from the block
/// iterator, and returns `SlateDBError::InvalidVersion` when the SST's format
/// version is neither `SST_FORMAT_VERSION` nor `SST_FORMAT_VERSION_V2`.
pub(crate) async fn last_written_key_and_seq(
    table_store: Arc<TableStore>,
    output_sst: &SsTableHandle,
) -> Result<Option<(Bytes, u64)>, SlateDBError> {
    let index = table_store.read_index(output_sst, false).await?;
    let num_blocks = index.borrow().block_meta().len();
    let Some(last_block_idx) = num_blocks.checked_sub(1) else {
        // No blocks at all: nothing was written.
        return Ok(None);
    };

    let mut blocks = table_store
        .read_blocks_using_index(output_sst, index, last_block_idx..last_block_idx + 1, false)
        .await?;
    let block = match blocks.pop_front() {
        Some(block) => block,
        None => return Ok(None),
    };

    // Pick the block iterator matching the SST's on-disk format.
    let version = output_sst.format_version;
    let entry = if version == SST_FORMAT_VERSION {
        let mut block_iter = BlockIterator::new(block, IterationOrder::Descending);
        block_iter.init().await?;
        block_iter.next().await?
    } else if version == SST_FORMAT_VERSION_V2 {
        let mut block_iter = BlockIteratorV2::new(block, IterationOrder::Descending);
        block_iter.init().await?;
        block_iter.next().await?
    } else {
        return Err(SlateDBError::InvalidVersion {
            format_name: "SST",
            supported_versions: vec![SST_FORMAT_VERSION, SST_FORMAT_VERSION_V2],
            actual_version: version,
        });
    };

    match entry {
        Some(e) => Ok(Some((e.key, e.seq))),
        None => Ok(None),
    }
}
fn bytes_into_minimal_vec(bytes: &Bytes) -> Vec<u8> {
let mut clamped = Vec::new();
clamped.reserve_exact(bytes.len());
clamped.put_slice(bytes.as_ref());
clamped
}
pub(crate) fn clamp_allocated_size_bytes(bytes: &Bytes) -> Bytes {
bytes_into_minimal_vec(bytes).into()
}
/// Computes the index key for a block.
///
/// With no previous block the empty key is returned; otherwise the key is the
/// lower bound computed from the previous block's last key and this block's
/// first key.
///
/// # Panics
///
/// Panics if a previous key is supplied and either key is empty.
pub(crate) fn compute_index_key(
    prev_block_last_key: Option<Bytes>,
    this_block_first_key: &Bytes,
) -> Bytes {
    match prev_block_last_key {
        None => EMPTY_KEY.clone(),
        Some(prev_key) => compute_lower_bound(&prev_key, this_block_first_key),
    }
}
/// Returns a prefix of `this_block_first_key` chosen by comparing it
/// byte-by-byte with `prev_block_last_key`:
///
/// * the prefix up to and including the first differing byte, if the keys
///   differ within `prev_block_last_key`'s length;
/// * the whole key, if both keys have the same length and no byte differs;
/// * otherwise the prefix one byte longer than `prev_block_last_key`.
///
/// # Panics
///
/// Panics if either key is empty, or if `this_block_first_key` runs out of
/// bytes before a difference is found.
fn compute_lower_bound(prev_block_last_key: &Bytes, this_block_first_key: &Bytes) -> Bytes {
    assert!(!prev_block_last_key.is_empty() && !this_block_first_key.is_empty());
    let prev_len = prev_block_last_key.len();
    let first_diff =
        (0..prev_len).find(|&i| prev_block_last_key[i] != this_block_first_key[i]);
    match first_diff {
        Some(i) => this_block_first_key.slice(..=i),
        None if prev_len == this_block_first_key.len() => this_block_first_key.clone(),
        None => this_block_first_key.slice(..=prev_len),
    }
}
/// Source of freshly generated identifiers.
///
/// This file provides a blanket implementation for every `RngCore`.
pub(crate) trait IdGenerator {
    /// Generates a new `Uuid`.
    fn gen_uuid(&mut self) -> Uuid;
    /// Generates a new `Ulid`, using `clock` as the time source.
    fn gen_ulid(&mut self, clock: &dyn SystemClock) -> Ulid;
}
impl<R: RngCore> IdGenerator for R {
    /// Builds a UUID from 16 random bytes, forcing the high nibble of byte 6
    /// to `0x4` and the top two bits of byte 8 to `10`.
    fn gen_uuid(&mut self) -> Uuid {
        let mut raw = [0u8; 16];
        self.fill_bytes(&mut raw);
        raw[6] = 0x40 | (raw[6] & 0x0f);
        raw[8] = 0x80 | (raw[8] & 0x3f);
        Uuid::from_bytes(raw)
    }

    /// Builds a ULID from the clock's current millisecond timestamp and 128
    /// random bits drawn from this generator.
    ///
    /// # Panics
    ///
    /// Panics if the clock's millisecond timestamp does not fit in a `u64`
    /// (for example, when it is negative).
    fn gen_ulid(&mut self, clock: &dyn SystemClock) -> Ulid {
        let millis = clock.now().timestamp_millis();
        let now = u64::try_from(millis).expect("timestamp outside u64 range in gen_ulid");
        let entropy: u128 = self.random();
        Ulid::from_parts(now, entropy)
    }
}
/// Accumulates individual bits into a byte buffer, most significant bit
/// first.
pub(crate) struct BitWriter {
    /// Completed bytes.
    buf: Vec<u8>,
    /// Byte currently being filled.
    cur: u8,
    /// Number of bits already stored in `cur` (0..=7).
    n: u8,
}

impl BitWriter {
    /// Creates an empty writer.
    pub(crate) fn new() -> Self {
        Self {
            buf: Vec::new(),
            cur: 0,
            n: 0,
        }
    }

    /// Appends a single bit.
    pub(crate) fn push(&mut self, bit: bool) {
        // Bits fill the current byte from the top (bit 7) downwards.
        self.cur |= u8::from(bit) << (7 - self.n);
        self.n += 1;
        if self.n == 8 {
            self.flush_byte();
        }
    }

    /// Appends the low `bits` bits of `value`, highest of those bits first.
    pub(crate) fn push32(&mut self, value: u32, bits: u8) {
        (0..bits)
            .rev()
            .for_each(|shift| self.push((value >> shift) & 1 == 1));
    }

    /// Appends the low `bits` bits of `value`, highest of those bits first.
    pub(crate) fn push64(&mut self, value: u64, bits: u8) {
        (0..bits)
            .rev()
            .for_each(|shift| self.push((value >> shift) & 1 == 1));
    }

    /// Moves the current byte into the buffer and starts a fresh one.
    fn flush_byte(&mut self) {
        let completed = std::mem::take(&mut self.cur);
        self.buf.push(completed);
        self.n = 0;
    }

    /// Consumes the writer and returns the written bytes; a trailing partial
    /// byte is emitted with its unused low bits left as zero.
    pub(crate) fn finish(mut self) -> Vec<u8> {
        if self.n != 0 {
            self.flush_byte();
        }
        self.buf
    }
}
/// Reads individual bits from a byte slice, most significant bit first.
pub(crate) struct BitReader<'a> {
    /// Bytes being read.
    buf: &'a [u8],
    /// Index of the byte containing the next bit.
    byte_pos: usize,
    /// Position of the next bit within that byte (0 = most significant).
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `buf`.
    pub(crate) fn new(buf: &'a [u8]) -> Self {
        Self {
            buf,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Reads the next bit, or returns `None` once the input is exhausted.
    pub(crate) fn read_bit(&mut self) -> Option<bool> {
        let byte = *self.buf.get(self.byte_pos)?;
        let bit = (byte >> (7 - self.bit_pos)) & 1 == 1;
        if self.bit_pos == 7 {
            // Last bit of this byte consumed: move on to the next byte.
            self.bit_pos = 0;
            self.byte_pos += 1;
        } else {
            self.bit_pos += 1;
        }
        Some(bit)
    }

    /// Reads `bits` bits into a `u32`, first bit read ending up most
    /// significant. Returns `None` if the input runs out part-way; bits
    /// consumed up to that point stay consumed.
    pub(crate) fn read32(&mut self, bits: u8) -> Option<u32> {
        (0..bits).try_fold(0u32, |acc, _| {
            let bit = self.read_bit()?;
            Some((acc << 1) | u32::from(bit))
        })
    }

    /// Reads `bits` bits into a `u64`, first bit read ending up most
    /// significant. Returns `None` if the input runs out part-way; bits
    /// consumed up to that point stay consumed.
    pub(crate) fn read64(&mut self, bits: u8) -> Option<u64> {
        (0..bits).try_fold(0u64, |acc, _| {
            let bit = self.read_bit()?;
            Some((acc << 1) | u64::from(bit))
        })
    }
}
/// Interprets the low `bits` bits of `val` as a two's-complement signed
/// integer and widens it to `i32`. Bits of `val` above `bits` are ignored.
///
/// Edge widths are handled explicitly instead of overflowing the shift:
/// * `bits == 0` yields `0` (a zero-width field carries no value);
/// * `bits >= 32` reinterprets all 32 bits of `val` as an `i32`.
pub(crate) fn sign_extend(val: u32, bits: u8) -> i32 {
    match bits {
        0 => 0,
        1..=31 => {
            let shift = 32 - u32::from(bits);
            // Move the field's sign bit into bit 31, then arithmetic-shift
            // back so the sign is replicated through the upper bits.
            ((val << shift) as i32) >> shift
        }
        _ => val as i32,
    }
}
/// Returns the total number of SSTs (`l0_count` plus every view in `srs`),
/// limited to at most `cap` and then raised to at least 1.
pub(crate) fn compute_max_parallel(l0_count: usize, srs: &[SortedRun], cap: usize) -> usize {
    let sorted_run_ssts: usize = srs.iter().map(|sr| sr.sst_views.len()).sum();
    let capped = std::cmp::min(l0_count + sorted_run_ssts, cap);
    // The floor of 1 applies even when `cap` is 0.
    std::cmp::max(capped, 1)
}
/// Sums, across all `sorted_runs`, the estimated sizes of the SST views that
/// precede the view reported by `find_last_sst_with_range_covering_key(key)`.
///
/// Runs for which that lookup returns `None` contribute nothing.
pub(crate) fn estimate_bytes_before_key(sorted_runs: &[SortedRun], key: &Bytes) -> u64 {
    let mut total: u64 = 0;
    for run in sorted_runs {
        if let Some(idx) = run.find_last_sst_with_range_covering_key(key) {
            // Only views strictly before `idx` are counted.
            let preceding: u64 = run
                .sst_views
                .iter()
                .take(idx)
                .map(|sst| sst.estimate_size())
                .sum();
            total += preceding;
        }
    }
    total
}
/// Applies `f` to every input, running at most `max_parallel` (minimum 1) of
/// the resulting futures at a time, and gathers the `Some` outputs.
///
/// Outputs are collected in completion order, not input order; `Ok(None)`
/// outputs are dropped. All futures are driven to completion before results
/// are inspected; if any produced an error, the first error among the
/// collected results is returned.
#[allow(clippy::redundant_closure)]
pub(crate) async fn build_concurrent<I, T, F, Fut>(
    inputs: I,
    max_parallel: usize,
    f: F,
) -> Result<VecDeque<T>, SlateDBError>
where
    I: IntoIterator,
    I::Item: Send,
    T: Send,
    F: Fn(I::Item) -> Fut + Send,
    Fut: std::future::Future<Output = Result<Option<T>, SlateDBError>> + Send,
{
    let pending = inputs.into_iter().map(move |it| f(it));
    let results: Vec<Result<Option<T>, SlateDBError>> = futures::stream::iter(pending)
        .buffer_unordered(max_parallel.max(1))
        .collect()
        .await;
    // `transpose` turns Ok(None) into None (filtered out) and keeps both
    // Ok(Some(_)) and Err(_); collecting into a Result stops at the first Err.
    results.into_iter().filter_map(Result::transpose).collect()
}
/// Renders a panic payload as a human-readable string.
///
/// Recognised payload types, tried in order: `Result<(), SlateDBError>`
/// (`"ok"` or the error text), `SlateDBError`, `Box<dyn std::error::Error>`,
/// `String`, and `&str`. Any other payload yields a message containing the
/// payload's type id.
pub(crate) fn panic_string(panic: &Box<dyn Any + Send>) -> String {
    if let Some(result) = panic.downcast_ref::<Result<(), SlateDBError>>() {
        return match result {
            Err(e) => e.to_string(),
            Ok(()) => "ok".to_string(),
        };
    }
    if let Some(err) = panic.downcast_ref::<SlateDBError>() {
        return err.to_string();
    }
    if let Some(err) = panic.downcast_ref::<Box<dyn std::error::Error>>() {
        return err.to_string();
    }
    if let Some(message) = panic.downcast_ref::<String>() {
        return message.clone();
    }
    if let Some(message) = panic.downcast_ref::<&str>() {
        return message.to_string();
    }
    // Deref twice so the type id is that of the payload, not of the Box.
    format!(
        "task panicked with unknown type [type_id=`{:?}`]",
        (**panic).type_id()
    )
}
/// Splits the outcome of a `catch_unwind`-style call into a task result and
/// an optional panic payload.
///
/// A normal completion passes its result through with no payload; a panic
/// becomes `SlateDBError::BackgroundTaskPanic(name)` with the payload
/// returned alongside.
pub(crate) fn split_unwind_result(
    name: String,
    unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>>,
) -> (
    Result<(), SlateDBError>,
    Option<Box<dyn std::any::Any + Send>>,
) {
    let payload = match unwind_result {
        Ok(result) => return (result, None),
        Err(payload) => payload,
    };
    (Err(SlateDBError::BackgroundTaskPanic(name)), Some(payload))
}
/// Splits the outcome of awaiting a tokio `JoinHandle` into a task result and
/// an optional panic payload.
///
/// * Task completed: its own result is passed through, no payload.
/// * Task cancelled: `SlateDBError::BackgroundTaskCancelled(name)`, no payload.
/// * Task panicked: `SlateDBError::BackgroundTaskPanic(name)` plus the payload.
pub(crate) fn split_join_result(
    name: String,
    join_result: Result<Result<(), SlateDBError>, tokio::task::JoinError>,
) -> (
    Result<(), SlateDBError>,
    Option<Box<dyn std::any::Any + Send>>,
) {
    let join_error = match join_result {
        Ok(task_result) => return (task_result, None),
        Err(join_error) => join_error,
    };
    if join_error.is_cancelled() {
        return (Err(SlateDBError::BackgroundTaskCancelled(name)), None);
    }
    // Not cancelled, so the join error is treated as a panic.
    let payload = join_error.into_panic();
    (Err(SlateDBError::BackgroundTaskPanic(name)), Some(payload))
}
/// Formats a byte count using SI (power-of-1000) units.
///
/// Values below 1000 are printed as whole bytes (`"999 B"`); larger values
/// are scaled to the largest fitting unit up to `EB` and printed with two
/// decimals (`"1.50 KB"`).
///
/// Rounding to two decimals can carry a value up to the next unit boundary
/// (e.g. 999_999 bytes is 999.999 KB, which rounds to 1000.00). In that case
/// the value is promoted so the output reads `"1.00 MB"` rather than
/// `"1000.00 KB"`.
pub(crate) fn format_bytes_si(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB", "PB", "EB"];
    const FACTOR: f64 = 1000.0;
    if bytes < 1000 {
        return format!("{} B", bytes);
    }
    let mut value = bytes as f64;
    let mut unit_index = 0;
    while value >= FACTOR && unit_index < UNITS.len() - 1 {
        value /= FACTOR;
        unit_index += 1;
    }
    let mut rendered = format!("{:.2}", value);
    // After the loop `value < FACTOR` (unless already at the last unit), so
    // the only way to see four integer digits is a rounding carry to exactly
    // "1000.00"; promote that to "1.00" of the next unit.
    if unit_index < UNITS.len() - 1 && rendered == "1000.00" {
        unit_index += 1;
        rendered = String::from("1.00");
    }
    format!("{} {}", rendered, UNITS[unit_index])
}
/// Appends `value` to `buf` as a variable-length integer: 7 bits per byte,
/// least significant group first, with the high bit set on every byte except
/// the last.
#[allow(dead_code)]
pub(crate) fn encode_varint(buf: &mut Vec<u8>, mut value: u32) {
    loop {
        let low_bits = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            // Final group: no continuation bit.
            buf.push(low_bits);
            return;
        }
        buf.push(low_bits | 0x80);
    }
}
/// Decodes a variable-length integer (7 bits per byte, least significant
/// group first, high bit marking continuation) from the front of `buf` and
/// advances `buf` past the consumed bytes.
///
/// # Panics
///
/// Panics if `buf` runs out before a byte without the continuation bit is
/// seen. With overflow checks enabled it also panics when the encoding spans
/// more than five bytes.
#[allow(dead_code)]
pub(crate) fn decode_varint(buf: &mut &[u8]) -> u32 {
    let mut decoded = 0u32;
    // Index of the 7-bit group currently being read.
    let mut group = 0;
    loop {
        let current = buf[0];
        *buf = &buf[1..];
        decoded |= u32::from(current & 0x7F) << (7 * group);
        if current < 0x80 {
            return decoded;
        }
        group += 1;
    }
}
/// Returns how many bytes `encode_varint` emits for `value` (1 to 5).
#[allow(dead_code)]
pub(crate) fn varint_len(value: u32) -> usize {
    // Number of significant bits, treating 0 as needing one bit so that it
    // still occupies a single byte.
    let significant_bits = (u32::BITS - (value | 1).leading_zeros()) as usize;
    // Each encoded byte carries 7 payload bits.
    significant_bits.div_ceil(7)
}
/// Asks the cached object store to load the SST files referenced by the
/// manifest, according to `preload_level`.
///
/// * `Some(PreloadLevel::AllSst)`: L0 SSTs followed by every SST of every
///   compacted sorted run.
/// * `Some(PreloadLevel::L0Sst)`: L0 SSTs only.
/// * `None`: nothing is loaded.
///
/// Loading is best-effort: a failure is logged as a warning and the function
/// still returns `Ok(())`. No load call is issued when there are no paths.
pub(crate) async fn preload_cache_from_manifest(
    core: &ManifestCore,
    cached_obj_store: &CachedObjectStore,
    path_resolver: &PathResolver,
    preload_level: Option<PreloadLevel>,
    max_cache_size: usize,
) -> Result<(), SlateDBError> {
    match preload_level {
        None => {}
        Some(PreloadLevel::L0Sst) => {
            let l0_sst_paths: Vec<object_store::path::Path> = core
                .l0
                .iter()
                .map(|view| path_resolver.table_path(&view.sst.id))
                .collect();
            if !l0_sst_paths.is_empty() {
                let outcome = cached_obj_store
                    .load_files_to_cache(l0_sst_paths, max_cache_size)
                    .await;
                if let Err(e) = outcome {
                    warn!("Failed to preload L0 SSTs to cache: {:?}", e);
                }
            }
        }
        Some(PreloadLevel::AllSst) => {
            let compacted_count: usize =
                core.compacted.iter().map(|sr| sr.sst_views.len()).sum();
            let mut all_sst_paths: Vec<object_store::path::Path> =
                Vec::with_capacity(core.l0.len() + compacted_count);
            // L0 first, then the compacted sorted runs in manifest order.
            for view in core.l0.iter() {
                all_sst_paths.push(path_resolver.table_path(&view.sst.id));
            }
            for sr in core.compacted.iter() {
                for view in sr.sst_views.iter() {
                    all_sst_paths.push(path_resolver.table_path(&view.sst.id));
                }
            }
            if !all_sst_paths.is_empty() {
                let outcome = cached_obj_store
                    .load_files_to_cache(all_sst_paths, max_cache_size)
                    .await;
                if let Err(e) = outcome {
                    warn!("Failed to preload all SSTs to cache: {:?}", e);
                }
            }
        }
    }
    Ok(())
}
/// Wrapper around an `async_channel::Sender` paired with a reader of a
/// "closed" result cell.
///
/// When a send on the channel fails, `send` consults `closed` to decide which
/// error to report.
pub(crate) struct SafeSender<T> {
    /// Underlying channel sender.
    tx: async_channel::Sender<T>,
    /// Reader for the recorded close result, consulted when a send fails.
    closed: WatchableOnceCellReader<Result<(), SlateDBError>>,
}
impl<T> SafeSender<T> {
    /// Wraps an existing sender together with the reader used to look up the
    /// close result when a send fails.
    pub(crate) fn new(
        tx: async_channel::Sender<T>,
        closed: WatchableOnceCellReader<Result<(), SlateDBError>>,
    ) -> Self {
        Self { tx, closed }
    }

    /// Creates an unbounded channel and returns its sender wrapped in a
    /// `SafeSender`, along with the raw receiver.
    pub(crate) fn unbounded_channel(
        closed: WatchableOnceCellReader<Result<(), SlateDBError>>,
    ) -> (Self, async_channel::Receiver<T>) {
        let (tx, rx) = async_channel::unbounded();
        let sender = Self::new(tx, closed);
        (sender, rx)
    }

    /// Sends `message` without blocking.
    ///
    /// # Errors
    ///
    /// If the send fails and a close result has been recorded, returns
    /// `SlateDBError::Closed` for a recorded `Ok(())`, or the recorded error.
    ///
    /// # Panics
    ///
    /// Panics if the send fails while no close result has been recorded.
    #[inline]
    #[allow(clippy::panic, clippy::disallowed_methods)]
    pub(crate) fn send(&self, message: T) -> Result<(), SlateDBError> {
        let Err(e) = self.tx.try_send(message) else {
            return Ok(());
        };
        match self.closed.read() {
            Some(Ok(())) => Err(SlateDBError::Closed),
            Some(Err(err)) => Err(err),
            None => panic!("Failed to send message to unbounded channel: {}", e),
        }
    }
}
// Implemented by hand so that cloning does not require `T: Clone`.
impl<T> Clone for SafeSender<T> {
    /// Clones both the channel sender and the close-result reader.
    fn clone(&self) -> Self {
        let Self { tx, closed } = self;
        Self {
            tx: tx.clone(),
            closed: closed.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use slatedb_common::MockSystemClock;
use crate::clock::MonotonicClock;
use crate::db_state::{SortedRun, SsTableHandle, SsTableId, SsTableInfo, SsTableView};
use crate::error::SlateDBError;
use crate::format::sst::SST_FORMAT_VERSION_LATEST;
use crate::sst_builder::BlockFormat;
use crate::types::RowEntry;
use crate::utils::{
build_concurrent, bytes_into_minimal_vec, clamp_allocated_size_bytes, compute_index_key,
compute_max_parallel, estimate_bytes_before_key, format_bytes_si, last_written_key_and_seq,
panic_string, spawn_bg_task, BitReader, BitWriter, WatchableOnceCell,
};
use crate::Db;
use bytes::{BufMut, Bytes, BytesMut};
use object_store::memory::InMemory;
use parking_lot::Mutex;
use std::any::Any;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use ulid::Ulid;
struct ResultCaptor<T: Clone> {
error: Mutex<Option<Result<T, SlateDBError>>>,
}
impl<T: Clone> ResultCaptor<T> {
fn new() -> Self {
Self {
error: Mutex::new(None),
}
}
fn capture(&self, result: &Result<T, SlateDBError>) {
let mut guard = self.error.lock();
let prev = guard.replace(result.clone());
assert!(prev.is_none());
}
fn captured(&self) -> Option<Result<T, SlateDBError>> {
self.error.lock().clone()
}
}
fn make_sst_view(start_key: &str, size: u64) -> SsTableView {
let info = SsTableInfo {
first_entry: Some(Bytes::from(start_key.as_bytes().to_vec())),
index_offset: size.saturating_sub(1),
index_len: 1,
..Default::default()
};
SsTableView::identity(SsTableHandle::new(
SsTableId::Compacted(Ulid::new()),
SST_FORMAT_VERSION_LATEST,
info,
))
}
#[test]
fn test_should_return_empty_for_index_of_first_block() {
let this_block_first_key = Bytes::from(vec![0x01, 0x02, 0x03]);
let result = compute_index_key(None, &this_block_first_key);
assert_eq!(result, &Bytes::new());
}
#[rstest]
#[case(Some("aaaac"), "abaaa", "ab")]
#[case(Some("ababc"), "abacd", "abac")]
#[case(Some("cc"), "ccccccc", "ccc")]
#[case(Some("eed"), "eee", "eee")]
#[case(Some("abcdef"), "abcdef", "abcdef")]
fn test_should_compute_index_key(
#[case] prev_block_last_key: Option<&'static str>,
#[case] this_block_first_key: &'static str,
#[case] expected_index_key: &'static str,
) {
assert_eq!(
compute_index_key(
prev_block_last_key.map(|s| Bytes::from(s.to_string())),
&Bytes::from(this_block_first_key.to_string())
),
Bytes::from_static(expected_index_key.as_bytes())
);
}
#[rstest]
#[case(Some(""), "a")]
#[case(Some("a"), "")]
#[should_panic]
fn test_should_panic_on_empty_keys(
#[case] prev_block_last_key: Option<&'static str>,
#[case] this_block_first_key: &'static str,
) {
compute_index_key(
prev_block_last_key.map(|s| Bytes::from(s.to_string())),
&Bytes::from(this_block_first_key.to_string()),
);
}
#[tokio::test]
async fn test_should_cleanup_when_task_exits_with_error() {
let captor = Arc::new(ResultCaptor::new());
let handle = tokio::runtime::Handle::current();
let captor2 = captor.clone();
let task = spawn_bg_task(
"test".to_string(),
&handle,
move |err| captor2.capture(err),
async { Err(SlateDBError::Fenced) },
);
let result: Result<(), SlateDBError> = task.await.expect("join failure");
assert!(matches!(result, Err(SlateDBError::Fenced)));
assert!(matches!(captor.captured(), Some(Err(SlateDBError::Fenced))));
}
#[tokio::test]
async fn test_should_cleanup_when_task_panics() {
let monitored = async {
panic!("oops");
};
let captor = Arc::new(ResultCaptor::new());
let handle = tokio::runtime::Handle::current();
let captor2 = captor.clone();
let task = spawn_bg_task(
"test".to_string(),
&handle,
move |err| captor2.capture(err),
monitored,
);
let result: Result<(), SlateDBError> = task.await.expect("join failure");
assert!(matches!(result, Err(SlateDBError::BackgroundTaskPanic(_))));
assert!(matches!(
captor.captured(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
}
#[tokio::test]
async fn test_should_cleanup_when_task_exits() {
let captor = Arc::new(ResultCaptor::new());
let handle = tokio::runtime::Handle::current();
let captor2 = captor.clone();
let task = spawn_bg_task(
"test".to_string(),
&handle,
move |err| captor2.capture(err),
async { Ok(()) },
);
let result: Result<(), SlateDBError> = task.await.expect("join failure");
assert!(matches!(result, Ok(())));
assert!(matches!(captor.captured(), Some(Ok(()))));
}
#[tokio::test]
async fn test_should_only_write_register_once() {
let register = WatchableOnceCell::new();
let reader = register.reader();
assert_eq!(reader.read(), None);
register.write(123);
assert_eq!(reader.read(), Some(123));
register.write(456);
assert_eq!(reader.read(), Some(123));
}
#[tokio::test]
async fn test_should_return_on_await_written_register() {
let register = WatchableOnceCell::new();
let mut reader = register.reader();
let h = tokio::spawn(async move {
assert_eq!(reader.await_value().await, 123);
assert_eq!(reader.await_value().await, 123);
});
register.write(123);
h.await.unwrap();
}
#[tokio::test]
async fn test_monotonicity_enforcement_on_mono_clock() {
let clock = Arc::new(MockSystemClock::new());
let mono_clock = MonotonicClock::new(clock.clone(), 0);
clock.set(10);
mono_clock.now().await.unwrap();
clock.set(5);
if let Err(SlateDBError::InvalidClockTick {
last_tick,
next_tick,
}) = mono_clock.now().await
{
assert_eq!(last_tick, 10);
assert_eq!(next_tick, 5);
} else {
panic!("Expected InvalidClockTick from mono_clock")
}
}
#[tokio::test]
async fn test_monotonicity_enforcement_on_mono_clock_set_tick() {
let clock = Arc::new(MockSystemClock::new());
let mono_clock = MonotonicClock::new(clock.clone(), 0);
clock.set(10);
mono_clock.now().await.unwrap();
if let Err(SlateDBError::InvalidClockTick {
last_tick,
next_tick,
}) = mono_clock.set_last_tick(5)
{
assert_eq!(last_tick, 10);
assert_eq!(next_tick, 5);
} else {
panic!("Expected InvalidClockTick from mono_clock")
}
}
#[tokio::test(start_paused = true)]
async fn test_await_valid_tick() {
let delegate_clock = Arc::new(MockSystemClock::new());
let mono_clock = MonotonicClock::new(delegate_clock.clone(), 100);
tokio::spawn({
let delegate_clock = delegate_clock.clone();
async move {
tokio::time::sleep(Duration::from_millis(50)).await;
delegate_clock.set(101);
}
});
let tick_future = mono_clock.now();
tokio::time::advance(Duration::from_millis(100)).await;
let result = tick_future.await;
assert_eq!(result.unwrap(), 101);
}
#[tokio::test(start_paused = true)]
async fn test_await_valid_tick_failure() {
let delegate_clock = Arc::new(MockSystemClock::new());
let mono_clock = MonotonicClock::new(delegate_clock.clone(), 100);
let tick_future = mono_clock.now();
tokio::time::advance(Duration::from_millis(110)).await;
let result = tick_future.await;
assert!(result.is_err());
}
#[test]
fn test_should_clamp_bytes_to_minimal_vec() {
let mut bytes = BytesMut::with_capacity(2048);
bytes.put_bytes(0u8, 2048);
let bytes = bytes.freeze();
let slice = bytes.slice(100..1124);
let clamped = bytes_into_minimal_vec(&slice);
assert_eq!(slice.as_ref(), clamped.as_slice());
assert_eq!(clamped.capacity(), 1024);
}
#[test]
fn test_should_clamp_bytes_and_preserve_data() {
let mut bytes = BytesMut::with_capacity(2048);
bytes.put_bytes(0u8, 2048);
let bytes = bytes.freeze();
let slice = bytes.slice(100..1124);
let clamped = clamp_allocated_size_bytes(&slice);
assert_eq!(clamped, slice);
assert_ne!(clamped.as_ptr(), slice.as_ptr());
}
#[tokio::test]
async fn test_last_written_key_and_seq_from_output_sst() {
let os = Arc::new(InMemory::new());
let path = "testdb-last-written".to_string();
let clock = Arc::new(MockSystemClock::new());
let db = Db::builder(path, os.clone())
.with_system_clock(clock.clone())
.build()
.await
.unwrap();
let table_store = db.inner.table_store.clone();
let mut sst_builder = table_store.table_builder();
sst_builder
.add(RowEntry::new_value(b"a", b"1", 1))
.await
.unwrap();
sst_builder
.add(RowEntry::new_value(b"b", b"2", 2))
.await
.unwrap();
let encoded_sst = sst_builder.build().await.unwrap();
let _sst1 = table_store
.write_sst(&SsTableId::Compacted(Ulid::new()), encoded_sst, false)
.await
.unwrap();
let mut sst_builder = table_store.table_builder();
sst_builder
.add(RowEntry::new_value(b"m", b"3", 3))
.await
.unwrap();
sst_builder
.add(RowEntry::new_value(b"z", b"4", 4))
.await
.unwrap();
let encoded_sst = sst_builder.build().await.unwrap();
let sst2 = table_store
.write_sst(&SsTableId::Compacted(Ulid::new()), encoded_sst, false)
.await
.unwrap();
let (last_key, last_seq) = last_written_key_and_seq(table_store.clone(), &sst2)
.await
.unwrap()
.expect("missing last entry");
assert_eq!(last_key, Bytes::from(b"z".as_slice()));
assert_eq!(last_seq, 4);
}
#[tokio::test]
async fn should_get_last_written_key_and_seq_from_v1_sst() {
let os = Arc::new(InMemory::new());
let path = "testdb-last-written-v1".to_string();
let clock = Arc::new(MockSystemClock::new());
let db = Db::builder(path, os.clone())
.with_system_clock(clock.clone())
.build()
.await
.unwrap();
let table_store = db.inner.table_store.clone();
let mut sst_builder = table_store
.table_builder()
.with_block_format(BlockFormat::V1);
sst_builder
.add(RowEntry::new_value(b"aaa", b"1", 10))
.await
.unwrap();
sst_builder
.add(RowEntry::new_value(b"zzz", b"2", 20))
.await
.unwrap();
let encoded_sst = sst_builder.build().await.unwrap();
let sst = table_store
.write_sst(&SsTableId::Compacted(Ulid::new()), encoded_sst, false)
.await
.unwrap();
let (last_key, last_seq) = last_written_key_and_seq(table_store.clone(), &sst)
.await
.unwrap()
.expect("missing last entry");
assert_eq!(last_key, Bytes::from(b"zzz".as_slice()));
assert_eq!(last_seq, 20);
}
#[rstest]
#[case("alternating_bits", vec![true, false, true, false, true, false, true, false], vec![], vec![], vec![0xAA])]
#[case("u64_value", vec![], vec![(0xAB, 8)], vec![], vec![0xAB])]
#[case("partial_and_multiple_bytes", vec![true, false], vec![(0x3F, 6), (0xCD, 8)], vec![], vec![0xBF, 0xCD])]
#[case("empty_writer", vec![], vec![], vec![], vec![])]
#[case("single_bit_true", vec![true], vec![], vec![], vec![0x80])]
#[case("single_bit_false", vec![false], vec![], vec![], vec![0x00])]
#[case("all_zeros", vec![false, false, false, false, false, false, false, false], vec![], vec![], vec![0x00])]
#[case("all_ones", vec![true, true, true, true, true, true, true, true], vec![], vec![], vec![0xFF])]
#[case("partial_byte_padding", vec![true, false, true], vec![], vec![], vec![0xA0])]
#[case("push32_single_bit", vec![], vec![(1, 1)], vec![], vec![0x80])]
#[case("push32_zero_bits", vec![], vec![(0xFF, 0)], vec![], vec![])]
#[case("push32_max_bits", vec![], vec![(0xDEADBEEF, 32)], vec![], vec![0xDE, 0xAD, 0xBE, 0xEF])]
#[case("push64_operations", vec![], vec![], vec![(0x123456789ABCDEF0, 64)], vec![0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0])]
#[case("push64_partial", vec![], vec![], vec![(0xABCD, 12)], vec![0xBC, 0xD0])]
#[case("mixed_operations", vec![true], vec![(0x7F, 7)], vec![], vec![0xFF])]
#[case("boundary_crossing", vec![true, false, true, false], vec![(0xF0, 4)], vec![], vec![0xA0])]
#[case("multiple_partial_bytes", vec![true], vec![(0x5, 3), (0x2, 2), (0x1, 2)], vec![], vec![0xD9])]
fn test_bit_writer(
#[case] _description: &str,
#[case] individual_bits: Vec<bool>,
#[case] push32_operations: Vec<(u32, u8)>,
#[case] push64_operations: Vec<(u64, u8)>,
#[case] expected: Vec<u8>,
) {
let mut writer = BitWriter::new();
for bit in individual_bits {
writer.push(bit);
}
for (value, bits) in push32_operations {
writer.push32(value, bits);
}
for (value, bits) in push64_operations {
writer.push64(value, bits);
}
let result = writer.finish();
assert_eq!(result, expected);
}
#[rstest]
#[case("alternating_bits", vec![0xAA], vec![true, false, true, false, true, false, true, false], vec![], vec![])]
#[case("u64_value", vec![0xAB], vec![], vec![(0xAB, 8)], vec![])]
#[case("partial_and_multiple_bytes", vec![0xBF, 0xCD], vec![true, false], vec![(0x3F, 6), (0xCD, 8)], vec![])]
#[case("empty_reader", vec![], vec![], vec![], vec![])]
#[case("single_bit_true", vec![0x80], vec![true], vec![], vec![])]
#[case("single_bit_false", vec![0x00], vec![false], vec![], vec![])]
#[case("all_zeros", vec![0x00], vec![false, false, false, false, false, false, false, false], vec![], vec![])]
#[case("all_ones", vec![0xFF], vec![true, true, true, true, true, true, true, true], vec![], vec![])]
#[case("partial_byte_padding", vec![0xA0], vec![true, false, true], vec![], vec![])]
#[case("push32_single_bit", vec![0x80], vec![], vec![(1, 1)], vec![])]
#[case("push32_zero_bits", vec![], vec![], vec![], vec![])]
#[case("push32_max_bits", vec![0xDE, 0xAD, 0xBE, 0xEF], vec![], vec![(0xDEADBEEF, 32)], vec![])]
#[case("push64_operations", vec![0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0], vec![], vec![], vec![(0x123456789ABCDEF0, 64)])]
#[case("push64_partial", vec![0xBC, 0xD0], vec![], vec![], vec![(0xBCD, 12)])]
#[case("mixed_operations", vec![0xFF], vec![true], vec![(0x7F, 7)], vec![])]
#[case("boundary_crossing", vec![0xA0], vec![true, false, true, false], vec![(0x0, 4)], vec![])]
#[case("multiple_partial_bytes", vec![0xD9], vec![true], vec![(0x5, 3), (0x2, 2), (0x1, 2)], vec![])]
fn test_bit_reader(
#[case] _description: &str,
#[case] input_bytes: Vec<u8>,
#[case] expected_individual_bits: Vec<bool>,
#[case] expected_read32_operations: Vec<(u32, u8)>,
#[case] expected_read64_operations: Vec<(u64, u8)>,
) {
let mut reader = BitReader::new(&input_bytes);
for expected_bit in expected_individual_bits {
let actual_bit = reader.read_bit();
assert_eq!(actual_bit, Some(expected_bit));
}
for (expected_value, bits) in expected_read32_operations {
let actual_value = reader.read32(bits);
assert_eq!(actual_value, Some(expected_value));
}
for (expected_value, bits) in expected_read64_operations {
let actual_value = reader.read64(bits);
assert_eq!(actual_value, Some(expected_value));
}
if !input_bytes.is_empty() {
let next_bit = reader.read_bit();
if let Some(bit) = next_bit {
assert!(!bit);
}
}
}
#[test]
fn test_bit_reader_exhaustion() {
let bytes = vec![0xFF]; let mut reader = BitReader::new(&bytes);
for _ in 0..8 {
assert_eq!(reader.read_bit(), Some(true));
}
assert_eq!(reader.read_bit(), None);
assert_eq!(reader.read32(1), None);
assert_eq!(reader.read64(1), None);
}
#[rstest]
#[case(0, "0 B")]
#[case(1, "1 B")]
#[case(999, "999 B")]
#[case(1_000, "1.00 KB")]
#[case(1_500, "1.50 KB")]
#[case(1_000_000, "1.00 MB")]
#[case(1_500_000, "1.50 MB")]
#[case(1_000_000_000, "1.00 GB")]
#[case(1_000_000_000_000, "1.00 TB")]
#[case(1_000_000_000_000_000, "1.00 PB")]
#[case(1_000_000_000_000_000_000, "1.00 EB")]
#[case(u64::MAX, "18.45 EB")]
fn test_format_bytes_si(#[case] bytes: u64, #[case] expected: &str) {
assert_eq!(format_bytes_si(bytes), expected);
}
#[test]
fn test_compute_max_parallel_min_and_cap() {
assert_eq!(compute_max_parallel(5, &[], 8), 5);
assert_eq!(compute_max_parallel(10, &[], 8), 8);
assert_eq!(compute_max_parallel(0, &[], 0), 1);
}
#[test]
fn test_estimate_bytes_before_key() {
let run1 = SortedRun {
id: 1,
sst_views: vec![
make_sst_view("a", 10),
make_sst_view("k", 20), make_sst_view("z", 30),
],
};
let run2 = SortedRun {
id: 2,
sst_views: vec![make_sst_view("b", 40), make_sst_view("f", 50)],
};
let key = Bytes::from("m");
let total = estimate_bytes_before_key(&[run1, run2], &key);
assert_eq!(total, 10 + 40);
}
#[tokio::test]
async fn test_build_iters_concurrent_option_filters_none() {
let inputs = 0..10usize;
let out: VecDeque<usize> = build_concurrent(inputs, 4, |x| async move {
if x % 2 == 0 {
Ok(Some(x * 3))
} else {
Ok(None)
}
})
.await
.expect("should succeed");
let mut got: Vec<_> = out.into_iter().collect();
got.sort_unstable();
assert_eq!(got, vec![0, 6, 12, 18, 24]);
}
#[tokio::test]
async fn test_build_iters_concurrent_option_error() {
let inputs = 0..10usize;
let res: Result<VecDeque<usize>, SlateDBError> =
build_concurrent(inputs, 3, |x| async move {
if x == 5 {
Err(SlateDBError::Fenced)
} else {
Ok(Some(x))
}
})
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn test_build_iters_concurrent_option_respects_max_parallel() {
let inputs = 0..16usize;
let max_parallel = 4;
let in_flight = Arc::new(AtomicUsize::new(0));
let peak = Arc::new(AtomicUsize::new(0));
let res: Result<VecDeque<()>, SlateDBError> = build_concurrent(inputs, max_parallel, {
let in_flight = in_flight.clone();
let peak = peak.clone();
move |_x| {
let in_flight = in_flight.clone();
let peak = peak.clone();
async move {
let cur = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
peak.fetch_max(cur, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(15)).await;
in_flight.fetch_sub(1, Ordering::SeqCst);
Ok(Some(()))
}
}
})
.await;
assert!(res.is_ok());
let observed_peak = peak.load(Ordering::SeqCst);
assert!(
observed_peak <= max_parallel,
"observed peak {} exceeds max_parallel {}",
observed_peak,
max_parallel
);
}
#[test]
fn panic_string_handles_slatedb_error() {
let err = SlateDBError::InvalidDBState;
let payload: Box<dyn Any + Send> = Box::new(err.clone());
let msg = panic_string(&payload);
assert_eq!(msg, err.to_string());
}
#[test]
fn panic_string_handles_slatedb_result() {
let payload: Box<Result<(), SlateDBError>> = Box::new(Err(SlateDBError::InvalidDBState));
let msg = panic_string(&(payload as Box<dyn Any + Send>));
assert_eq!(msg, SlateDBError::InvalidDBState.to_string());
}
#[test]
fn panic_string_handles_string() {
let s: Box<dyn Any + Send> = Box::new(String::from("hello"));
let msg = panic_string(&s);
assert_eq!(msg, "hello");
}
#[test]
fn panic_string_handles_static_str() {
let s: Box<dyn Any + Send> = Box::new("boom");
let msg = panic_string(&s);
assert_eq!(msg, "boom");
}
#[test]
fn panic_string_falls_back_for_boxed_error_trait_object() {
let err_box: Box<dyn Any + Send> = Box::new(std::io::Error::other("oh no"));
let msg = panic_string(&err_box);
assert!(msg.contains("task panicked with unknown type"));
}
#[test]
fn panic_string_falls_back_for_unknown_type() {
#[derive(Clone, Debug)]
struct MyType;
let msg = panic_string(&(Box::new(MyType) as Box<dyn Any + Send>));
assert!(msg.contains("task panicked with unknown type"));
}
#[test]
fn test_split_unwind_result_ok_ok() {
let unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>> =
Ok(Ok(()));
let (result, payload) = super::split_unwind_result("test".to_string(), unwind_result);
assert!(result.is_ok());
assert!(payload.is_none());
}
#[test]
fn test_split_unwind_result_ok_error() {
let unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>> =
Ok(Err(SlateDBError::Fenced));
let (result, payload) = super::split_unwind_result("test".to_string(), unwind_result);
assert!(matches!(result, Err(SlateDBError::Fenced)));
assert!(payload.is_none());
}
/// An unwind result carrying a panic payload maps to `BackgroundTaskPanic`
/// tagged with the task name, and the original payload is handed back intact.
#[test]
fn test_split_unwind_result_panic() {
    type Unwound = Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>>;

    let panic_msg = "something went wrong";
    let panicked: Unwound = Err(Box::new(panic_msg));
    let (outcome, panic_payload) =
        super::split_unwind_result(String::from("test_task"), panicked);

    assert!(matches!(
        outcome,
        Err(SlateDBError::BackgroundTaskPanic(ref task_name)) if task_name == "test_task"
    ));

    // The payload must still be the `&str` that was boxed above.
    assert!(panic_payload.is_some());
    if let Some(boxed) = panic_payload {
        match boxed.downcast_ref::<&str>() {
            Some(text) => assert_eq!(text, &panic_msg),
            None => panic!("expected &str, got {:?}", boxed),
        }
    }
}
/// A join result of `Ok(Ok(()))` splits into an `Ok` result and no panic
/// payload.
#[test]
fn test_split_join_result_ok_ok() {
    type Joined = Result<Result<(), SlateDBError>, tokio::task::JoinError>;

    let completed: Joined = Ok(Ok(()));
    let (outcome, panic_payload) = super::split_join_result(String::from("test"), completed);

    assert!(outcome.is_ok());
    assert!(panic_payload.is_none());
}
/// A join result of `Ok(Err(Fenced))` passes the inner error through unchanged
/// and yields no panic payload.
#[test]
fn test_split_join_result_ok_error() {
    type Joined = Result<Result<(), SlateDBError>, tokio::task::JoinError>;

    let failed: Joined = Ok(Err(SlateDBError::Fenced));
    let (outcome, panic_payload) = super::split_join_result(String::from("test"), failed);

    assert!(matches!(outcome, Err(SlateDBError::Fenced)));
    assert!(panic_payload.is_none());
}
/// Joining an aborted task maps to `BackgroundTaskCancelled` tagged with the
/// task name, with no panic payload.
#[tokio::test]
async fn test_split_join_result_cancelled() {
    // The spawned body loops forever, so the join handle can only resolve
    // because of the abort issued right after spawning.
    let task = tokio::spawn(async {
        loop {
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    });
    task.abort();
    let joined = task.await;

    let (outcome, panic_payload) = super::split_join_result(String::from("test_task"), joined);

    assert!(matches!(
        outcome,
        Err(SlateDBError::BackgroundTaskCancelled(ref task_name)) if task_name == "test_task"
    ));
    assert!(panic_payload.is_none());
}
/// Joining a task that panicked maps to `BackgroundTaskPanic` tagged with the
/// task name, and the panic payload is handed back as the original `&str`.
#[tokio::test]
async fn test_split_join_result_panic() {
    let task = tokio::spawn(async {
        panic!("something went wrong");
    });
    let joined = task.await;

    let (outcome, panic_payload) = super::split_join_result(String::from("test_task"), joined);

    assert!(matches!(
        outcome,
        Err(SlateDBError::BackgroundTaskPanic(ref task_name)) if task_name == "test_task"
    ));

    // The payload must be the literal message passed to `panic!` above.
    assert!(panic_payload.is_some());
    if let Some(boxed) = panic_payload {
        match boxed.downcast_ref::<&str>() {
            Some(text) => assert_eq!(text, &"something went wrong"),
            None => panic!("expected &str, got {:?}", boxed),
        }
    }
}
/// Checks the encoded length reported for a value.
///
/// The cases sit on either side of each 7-bit boundary (2^7, 2^14, 2^21,
/// 2^28), plus the extremes `0` and `u32::MAX`.
#[rstest]
#[case(0, 1)]
#[case(1, 1)]
#[case(127, 1)]
#[case(128, 2)]
#[case(16383, 2)]
#[case(16384, 3)]
#[case(2097151, 3)]
#[case(2097152, 4)]
#[case(268435455, 4)]
#[case(268435456, 5)]
#[case(u32::MAX, 5)]
fn should_calculate_varint_len(#[case] value: u32, #[case] expected_len: usize) {
    assert_eq!(super::varint_len(value), expected_len);
}
/// Checks the exact bytes produced for a value.
///
/// In the expected vectors each byte carries seven value bits, least
/// significant group first, and every byte except the last has its high bit
/// set.
#[rstest]
#[case(0, vec![0x00])]
#[case(1, vec![0x01])]
#[case(127, vec![0x7F])]
#[case(128, vec![0x80, 0x01])]
#[case(255, vec![0xFF, 0x01])]
#[case(300, vec![0xAC, 0x02])]
#[case(16384, vec![0x80, 0x80, 0x01])]
#[case(u32::MAX, vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F])]
fn should_encode_varint(#[case] value: u32, #[case] expected: Vec<u8>) {
    let mut encoded = Vec::new();
    super::encode_varint(&mut encoded, value);

    assert_eq!(encoded, expected);
}
/// Checks that each byte sequence decodes to the expected value and that the
/// decoder consumes the whole input.
#[rstest]
#[case(vec![0x00], 0)]
#[case(vec![0x01], 1)]
#[case(vec![0x7F], 127)]
#[case(vec![0x80, 0x01], 128)]
#[case(vec![0xFF, 0x01], 255)]
#[case(vec![0xAC, 0x02], 300)]
#[case(vec![0x80, 0x80, 0x01], 16384)]
#[case(vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F], u32::MAX)]
fn should_decode_varint(#[case] bytes: Vec<u8>, #[case] expected: u32) {
    let mut remaining = bytes.as_slice();
    let decoded = super::decode_varint(&mut remaining);

    assert_eq!(decoded, expected);
    // Nothing may be left over: every case is exactly one varint long.
    assert!(remaining.is_empty());
}
/// Encodes a value, decodes it again, and checks three things: the value
/// survives, the decoder consumes every encoded byte, and the encoded size
/// matches `varint_len`.
#[rstest]
#[case(0)]
#[case(1)]
#[case(127)]
#[case(128)]
#[case(255)]
#[case(16383)]
#[case(16384)]
#[case(2097151)]
#[case(2097152)]
#[case(268435455)]
#[case(268435456)]
#[case(u32::MAX)]
fn should_roundtrip_varint(#[case] value: u32) {
    let mut encoded = Vec::new();
    super::encode_varint(&mut encoded, value);

    let mut remaining: &[u8] = &encoded;
    let roundtripped = super::decode_varint(&mut remaining);

    assert_eq!(roundtripped, value);
    assert!(remaining.is_empty());
    assert_eq!(encoded.len(), super::varint_len(value));
}
/// Decoding must stop after the varint itself and leave the bytes that follow
/// it in the input slice.
#[test]
fn should_decode_varint_with_trailing_data() {
    // First two bytes encode 300; the last three are unrelated trailing data.
    let input: [u8; 5] = [0xAC, 0x02, 0xFF, 0xAB, 0xCD];
    let mut cursor: &[u8] = &input;

    let decoded = super::decode_varint(&mut cursor);

    assert_eq!(decoded, 300);
    assert_eq!(cursor, &[0xFF, 0xAB, 0xCD]);
}
/// Several varints written back to back into one buffer must decode in the
/// same order, with the buffer fully consumed afterwards.
#[test]
fn should_decode_multiple_varints() {
    let inputs: [u32; 4] = [1, 300, 16384, u32::MAX];

    let mut encoded = Vec::new();
    for &value in &inputs {
        super::encode_varint(&mut encoded, value);
    }

    // Decode one value per input before checking anything, so every decode
    // call runs against the shared, advancing slice.
    let mut remaining: &[u8] = &encoded;
    let decoded: Vec<u32> = inputs
        .iter()
        .map(|_| super::decode_varint(&mut remaining))
        .collect();

    assert_eq!(decoded, inputs);
    assert!(remaining.is_empty());
}
}