use crate::clock::{MonotonicClock, SystemClock};
use crate::config::DurabilityLevel;
use crate::config::DurabilityLevel::{Memory, Remote};
use crate::db_state::SortedRun;
use crate::error::SlateDBError;
use crate::types::RowEntry;
use bytes::{BufMut, Bytes};
use futures::FutureExt;
use log::error;
use rand::{Rng, RngCore};
use std::any::Any;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::UnboundedSender;
use ulid::Ulid;
use uuid::Uuid;
use futures::StreamExt;
use std::collections::VecDeque;
/// Shared empty key returned by `compute_index_key` when there is no previous block key.
/// `Bytes::new()` is usable in a `static` initializer, so no lazy initialization is needed.
static EMPTY_KEY: Bytes = Bytes::new();
/// A write-once cell whose value can be observed by any number of readers.
///
/// Backed by a `tokio::sync::watch` channel carrying `Option<T>`. The slot starts out as
/// `None`; only the first [`WatchableOnceCell::write`] stores a value and notifies watchers,
/// and every later write is ignored.
#[derive(Clone, Debug)]
pub(crate) struct WatchableOnceCell<T: Clone> {
    rx: tokio::sync::watch::Receiver<Option<T>>,
    tx: tokio::sync::watch::Sender<Option<T>>,
}

/// Read-only handle onto a [`WatchableOnceCell`].
#[derive(Clone)]
pub(crate) struct WatchableOnceCellReader<T: Clone> {
    rx: tokio::sync::watch::Receiver<Option<T>>,
}

impl<T: Clone> WatchableOnceCell<T> {
    /// Creates an empty cell.
    pub(crate) fn new() -> Self {
        let (sender, receiver) = tokio::sync::watch::channel(None);
        Self {
            rx: receiver,
            tx: sender,
        }
    }

    /// Stores `val` if the cell is still empty. If a value is already present, `val` is
    /// dropped, the stored value is left untouched and watchers are not notified.
    pub(crate) fn write(&self, val: T) {
        self.tx.send_if_modified(|slot| {
            if slot.is_none() {
                *slot = Some(val);
                true
            } else {
                false
            }
        });
    }

    /// Returns a new reader observing this cell.
    pub(crate) fn reader(&self) -> WatchableOnceCellReader<T> {
        let rx = self.rx.clone();
        WatchableOnceCellReader { rx }
    }
}

impl<T: Clone> WatchableOnceCellReader<T> {
    /// Returns a clone of the stored value, or `None` if nothing has been written yet.
    pub(crate) fn read(&self) -> Option<T> {
        (*self.rx.borrow()).clone()
    }

    /// Waits until a value is present and returns a clone of it. If the value is already
    /// present, the predicate is satisfied on the first check.
    ///
    /// # Panics
    ///
    /// Panics with `"watch channel closed"` if waiting fails because the channel is closed.
    pub(crate) async fn await_value(&mut self) -> T {
        let ready = self
            .rx
            .wait_for(Option::is_some)
            .await
            .expect("watch channel closed");
        (*ready).clone().expect("no value found")
    }
}
/// Spawns `future` on `handle`, turning a panic into an error and running a cleanup hook.
///
/// The future runs inside `catch_unwind`. If it panics, the panic is logged with `name` and
/// replaced by `Err(SlateDBError::BackgroundTaskPanic(name))`. In every case `cleanup_fn`
/// receives a reference to the final result, and that same result is the task's output
/// yielded by the returned `JoinHandle`.
pub(crate) fn spawn_bg_task<F, T, C>(
    name: String,
    handle: &tokio::runtime::Handle,
    cleanup_fn: C,
    future: F,
) -> tokio::task::JoinHandle<Result<T, SlateDBError>>
where
    F: Future<Output = Result<T, SlateDBError>> + Send + 'static,
    T: Send + 'static,
    C: FnOnce(&Result<T, SlateDBError>) + Send + 'static,
{
    let guarded = AssertUnwindSafe(future).catch_unwind();
    let task = guarded.map(move |outcome| {
        let result = outcome.unwrap_or_else(|payload| {
            error!(
                "spawned task panicked. [name={}, panic={}]",
                name,
                panic_string(&payload)
            );
            Err(SlateDBError::BackgroundTaskPanic(name))
        });
        cleanup_fn(&result);
        result
    });
    handle.spawn(task)
}
/// Returns the tick used as "now" for a read at the given durability level.
///
/// * `Memory` — awaits `mono_clock.now()` and propagates its error.
/// * `Remote` — returns `mono_clock.get_last_durable_tick()` without awaiting.
pub(crate) async fn get_now_for_read(
    mono_clock: Arc<MonotonicClock>,
    durability_level: DurabilityLevel,
) -> Result<i64, SlateDBError> {
    let tick = match durability_level {
        Memory => mono_clock.now().await?,
        Remote => mono_clock.get_last_durable_tick(),
    };
    Ok(tick)
}
/// Returns `true` when `entry` has no `expire_ts`, or when its `expire_ts` is strictly greater
/// than `now`. An entry whose `expire_ts` equals `now` is treated as expired.
pub(crate) fn is_not_expired(entry: &RowEntry, now: i64) -> bool {
    match entry.expire_ts {
        Some(expire_ts) => expire_ts > now,
        None => true,
    }
}
/// Combines two optional values.
///
/// When both are present they are merged with `f(current, next)`. Otherwise the one that is
/// present is returned unchanged, or `None` when neither is.
pub(crate) fn merge_options<T>(
    current: Option<T>,
    next: Option<T>,
    f: impl Fn(T, T) -> T,
) -> Option<T> {
    match (current, next) {
        (Some(a), Some(b)) => Some(f(a, b)),
        (only, None) | (None, only) => only,
    }
}
fn bytes_into_minimal_vec(bytes: &Bytes) -> Vec<u8> {
let mut clamped = Vec::new();
clamped.reserve_exact(bytes.len());
clamped.put_slice(bytes.as_ref());
clamped
}
pub(crate) fn clamp_allocated_size_bytes(bytes: &Bytes) -> Bytes {
bytes_into_minimal_vec(bytes).into()
}
/// Computes the index key recorded for a block.
///
/// With no previous block key (the first block) this is the empty key. Otherwise it is
/// `this_block_first_key` truncated to one byte past its common prefix with
/// `prev_block_last_key`, or the whole key when that is shorter (see `compute_lower_bound`).
pub(crate) fn compute_index_key(
    prev_block_last_key: Option<Bytes>,
    this_block_first_key: &Bytes,
) -> Bytes {
    match prev_block_last_key {
        Some(prev_key) => compute_lower_bound(&prev_key, this_block_first_key),
        None => EMPTY_KEY.clone(),
    }
}
/// Returns the shortest prefix of `this_block_first_key` that extends one byte past the common
/// prefix it shares with `prev_block_last_key`. When the two keys are identical, the whole key
/// is returned.
///
/// # Panics
///
/// * If either key is empty.
/// * With an index-out-of-bounds panic if `this_block_first_key` is a strict prefix of
///   `prev_block_last_key`, because bytes are compared up to `prev_block_last_key.len()`.
fn compute_lower_bound(prev_block_last_key: &Bytes, this_block_first_key: &Bytes) -> Bytes {
    assert!(!prev_block_last_key.is_empty() && !this_block_first_key.is_empty());
    let prev_len = prev_block_last_key.len();
    let first_mismatch =
        (0..prev_len).find(|&i| prev_block_last_key[i] != this_block_first_key[i]);
    match first_mismatch {
        // Keep the shared prefix plus the first byte where the keys diverge.
        Some(i) => this_block_first_key.slice(..=i),
        // The keys are byte-for-byte equal.
        None if this_block_first_key.len() == prev_len => this_block_first_key.clone(),
        // `prev_block_last_key` is a strict prefix: keep one extra byte of the new key.
        None => this_block_first_key.slice(..=prev_len),
    }
}
/// A shareable `u64` sequence backed by an `AtomicU64`; every operation uses `SeqCst`.
#[derive(Debug)]
pub(crate) struct MonotonicSeq {
    val: AtomicU64,
}

impl MonotonicSeq {
    /// Creates a sequence whose current value is `initial_value`.
    pub fn new(initial_value: u64) -> Self {
        let val = AtomicU64::new(initial_value);
        Self { val }
    }

    /// Atomically increments the sequence and returns the value *after* the increment.
    pub fn next(&self) -> u64 {
        let previous = self.val.fetch_add(1, SeqCst);
        previous + 1
    }

    /// Unconditionally overwrites the current value. This can move the value backwards.
    pub fn store(&self, value: u64) {
        self.val.store(value, SeqCst);
    }

    /// Returns the current value.
    pub fn load(&self) -> u64 {
        self.val.load(SeqCst)
    }

    /// Raises the current value to `value` if `value` is larger; never lowers it.
    pub fn store_if_greater(&self, value: u64) {
        let _ = self.val.fetch_max(value, SeqCst);
    }
}
/// Sends on a channel and reports a closed channel using a recorded close result.
pub trait SendSafely<T> {
    /// Sends `message`. If the send fails, the error returned is derived from
    /// `closed_result_reader` rather than from the channel error itself.
    fn send_safely(
        &self,
        closed_result_reader: WatchableOnceCellReader<Result<(), SlateDBError>>,
        message: T,
    ) -> Result<(), SlateDBError>;
}

// The `panic!` below is reached only when the send fails and no close result has been
// recorded in `closed_result_reader`.
#[allow(clippy::panic, clippy::disallowed_methods)]
impl<T> SendSafely<T> for UnboundedSender<T> {
    /// On send failure, returns `SlateDBError::Closed` when the recorded close result is
    /// `Ok(())`, or the recorded error otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the send fails and no close result has been recorded.
    #[inline]
    fn send_safely(
        &self,
        closed_result_reader: WatchableOnceCellReader<Result<(), SlateDBError>>,
        message: T,
    ) -> Result<(), SlateDBError> {
        self.send(message)
            .or_else(|send_err| match closed_result_reader.read() {
                Some(Ok(())) => Err(SlateDBError::Closed),
                Some(Err(err)) => Err(err),
                None => panic!("Failed to send message to unbounded channel: {}", send_err),
            })
    }
}
/// Generates random identifiers; implemented for every `RngCore`.
pub trait IdGenerator {
    /// Returns a random UUID with the version-4 and RFC 4122 variant bits set.
    fn gen_uuid(&mut self) -> Uuid;
    /// Returns a ULID whose timestamp is `clock.now()` in milliseconds, with random low bits.
    fn gen_ulid(&mut self, clock: &dyn SystemClock) -> Ulid;
}

impl<R: RngCore> IdGenerator for R {
    fn gen_uuid(&mut self) -> Uuid {
        let mut raw = [0u8; 16];
        self.fill_bytes(&mut raw);
        // High nibble of byte 6 is the version (4 = random).
        raw[6] = 0x40 | (raw[6] & 0x0f);
        // Top two bits of byte 8 are the RFC 4122 variant (`10`).
        raw[8] = 0x80 | (raw[8] & 0x3f);
        Uuid::from_bytes(raw)
    }

    /// # Panics
    ///
    /// Panics if `clock.now()` is before the Unix epoch (a negative millisecond timestamp).
    fn gen_ulid(&mut self, clock: &dyn SystemClock) -> Ulid {
        let millis = clock.now().timestamp_millis();
        let timestamp = u64::try_from(millis).expect("timestamp outside u64 range in gen_ulid");
        Ulid::from_parts(timestamp, self.random::<u128>())
    }
}
/// Races `future` against a `duration`-long sleep on `clock`.
///
/// Returns the future's own result if it completes first, otherwise `Err(error_fn())`. The
/// `select!` is `biased`, so `future` is polled before the sleep on every poll; if both are
/// ready together, the future's result wins.
pub async fn timeout<T, Err>(
    clock: Arc<dyn SystemClock>,
    duration: Duration,
    error_fn: impl FnOnce() -> Err,
    future: impl Future<Output = Result<T, Err>> + Send,
) -> Result<T, Err> {
    let deadline = clock.sleep(duration);
    tokio::select! {
        biased;
        outcome = future => outcome,
        _ = deadline => Err(error_fn()),
    }
}
/// Packs bits into bytes, most-significant bit first.
///
/// Each byte fills from bit 7 down to bit 0. [`BitWriter::finish`] emits any trailing partial
/// byte, with its unused low bits left as zero.
pub(crate) struct BitWriter {
    /// Bytes that have been completely filled.
    buf: Vec<u8>,
    /// Byte currently being assembled.
    cur: u8,
    /// Number of bits already placed in `cur` (0..8 between calls).
    n: u8,
}

impl BitWriter {
    /// Creates an empty writer.
    pub(crate) fn new() -> Self {
        Self {
            buf: Vec::new(),
            cur: 0,
            n: 0,
        }
    }

    /// Appends a single bit.
    pub(crate) fn push(&mut self, bit: bool) {
        // Free positions count down from the MSB (bit 7).
        let position = 7 - self.n;
        self.cur |= u8::from(bit) << position;
        self.n += 1;
        if self.n == 8 {
            self.flush_byte();
        }
    }

    /// Appends the low `bits` bits of `value`, most significant of those first.
    pub(crate) fn push32(&mut self, value: u32, bits: u8) {
        for shift in (0..bits).rev() {
            self.push(((value >> shift) & 1) == 1);
        }
    }

    /// Appends the low `bits` bits of `value`, most significant of those first.
    pub(crate) fn push64(&mut self, value: u64, bits: u8) {
        for shift in (0..bits).rev() {
            self.push(((value >> shift) & 1) == 1);
        }
    }

    /// Moves the completed `cur` byte into `buf` and starts an empty one.
    fn flush_byte(&mut self) {
        self.buf.push(self.cur);
        self.cur = 0;
        self.n = 0;
    }

    /// Consumes the writer and returns the packed bytes, including a partial final byte.
    pub(crate) fn finish(mut self) -> Vec<u8> {
        if self.n != 0 {
            self.buf.push(self.cur);
        }
        self.buf
    }
}
/// Reads bits MSB-first from a byte slice, in the same layout `BitWriter` produces.
pub(crate) struct BitReader<'a> {
    /// Source bytes.
    buf: &'a [u8],
    /// Index of the byte currently being read.
    byte_pos: usize,
    /// Offset of the next bit within `buf[byte_pos]`, counted from the MSB (0..8).
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `buf`.
    pub(crate) fn new(buf: &'a [u8]) -> Self {
        Self {
            buf,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Returns the next bit, or `None` once every byte has been consumed.
    pub(crate) fn read_bit(&mut self) -> Option<bool> {
        let byte = *self.buf.get(self.byte_pos)?;
        let bit = ((byte >> (7 - self.bit_pos)) & 1) == 1;
        if self.bit_pos == 7 {
            self.bit_pos = 0;
            self.byte_pos += 1;
        } else {
            self.bit_pos += 1;
        }
        Some(bit)
    }

    /// Reads `bits` bits as an unsigned value, first bit most significant.
    ///
    /// Returns `None` if the input runs out; bits read before that point remain consumed.
    pub(crate) fn read32(&mut self, bits: u8) -> Option<u32> {
        let mut acc = 0u32;
        for _ in 0..bits {
            acc = (acc << 1) | u32::from(self.read_bit()?);
        }
        Some(acc)
    }

    /// Reads `bits` bits as an unsigned value, first bit most significant.
    ///
    /// Returns `None` if the input runs out; bits read before that point remain consumed.
    pub(crate) fn read64(&mut self, bits: u8) -> Option<u64> {
        let mut acc = 0u64;
        for _ in 0..bits {
            acc = (acc << 1) | u64::from(self.read_bit()?);
        }
        Some(acc)
    }
}
/// Interprets the low `bits` bits of `val` as a two's-complement signed integer.
///
/// Bits of `val` above position `bits` are ignored. A width of 0 yields 0, and a width of
/// 32 reinterprets `val` as an `i32` directly.
///
/// # Panics
///
/// Panics if `bits > 32`.
pub(crate) fn sign_extend(val: u32, bits: u8) -> i32 {
    assert!(bits <= 32, "sign_extend: bit width {} exceeds 32", bits);
    if bits == 0 {
        // An empty field has no sign bit. Without this guard, `val << 32` would overflow:
        // a panic in debug builds, and a wrong result (`val as i32`) in release builds.
        return 0;
    }
    let shift = 32 - bits;
    // Move the field's sign bit to bit 31, then arithmetic-shift back to replicate it.
    ((val << shift) as i32) >> shift
}
/// Chooses a parallelism level from the number of SSTs involved.
///
/// The count is `l0_count` plus the SSTs of every sorted run in `srs`. It is capped at `cap`
/// but never drops below 1, even when there are no SSTs or `cap` is 0.
pub(crate) fn compute_max_parallel(l0_count: usize, srs: &[SortedRun], cap: usize) -> usize {
    let sorted_run_ssts: usize = srs.iter().map(|sr| sr.ssts.len()).sum();
    let total_ssts = l0_count + sorted_run_ssts;
    std::cmp::max(std::cmp::min(total_ssts, cap), 1)
}
#[allow(clippy::redundant_closure)]
pub(crate) async fn build_concurrent<I, T, F, Fut>(
inputs: I,
max_parallel: usize,
f: F,
) -> Result<VecDeque<T>, SlateDBError>
where
I: IntoIterator,
I::Item: Send,
T: Send,
F: Fn(I::Item) -> Fut + Send,
Fut: std::future::Future<Output = Result<Option<T>, SlateDBError>> + Send,
{
let mut out = VecDeque::new();
let results = futures::stream::iter(inputs.into_iter().map(move |it| f(it)))
.buffer_unordered(max_parallel.max(1))
.collect::<Vec<_>>()
.await;
for r in results {
match r {
Ok(Some(t)) => out.push_back(t),
Ok(None) => {}
Err(e) => return Err(e),
}
}
Ok(out)
}
/// Renders a panic payload as a human-readable message.
///
/// Payload types are checked in this order:
/// * `Result<(), SlateDBError>` — `"ok"` for `Ok(())`, otherwise the error's `Display` text.
/// * `SlateDBError` — its `Display` text.
/// * `Box<dyn Error + Send + Sync>` or `Box<dyn Error + Send>` — the inner error's text.
/// * `String` or `&str` — the message itself.
///
/// Any other payload yields a fallback message containing the payload's `TypeId`.
#[allow(dead_code)]
pub fn panic_string(panic: &Box<dyn Any + Send>) -> String {
    if let Some(result) = panic.downcast_ref::<Result<(), SlateDBError>>() {
        match result {
            Ok(()) => "ok".to_string(),
            Err(e) => e.to_string(),
        }
    } else if let Some(err) = panic.downcast_ref::<SlateDBError>() {
        err.to_string()
    // A payload stored in `Box<dyn Any + Send>` must itself be `Send`. A plain
    // `Box<dyn Error>` is `!Send` and can never appear here, so only the `Send`
    // trait-object forms are checked.
    } else if let Some(err) = panic.downcast_ref::<Box<dyn std::error::Error + Send + Sync>>() {
        err.to_string()
    } else if let Some(err) = panic.downcast_ref::<Box<dyn std::error::Error + Send>>() {
        err.to_string()
    } else if let Some(err) = panic.downcast_ref::<String>() {
        err.clone()
    } else if let Some(err) = panic.downcast_ref::<&str>() {
        err.to_string()
    } else {
        format!(
            "task panicked with unknown type [type_id=`{:?}`]",
            // Deref through the Box so we report the payload's type, not the Box's.
            (**panic).type_id()
        )
    }
}
/// Splits the outcome of a `catch_unwind`-wrapped task into a result and an optional payload.
///
/// * `Ok(result)` — the task returned normally; `result` is passed through with no payload.
/// * `Err(payload)` — the task panicked; returns `Err(SlateDBError::BackgroundTaskPanic(name))`
///   together with the original panic payload.
pub(crate) fn split_unwind_result(
    name: String,
    unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>>,
) -> (
    Result<(), SlateDBError>,
    Option<Box<dyn std::any::Any + Send>>,
) {
    match unwind_result {
        // `result` is already owned, so hand it back directly instead of cloning it.
        Ok(result) => (result, None),
        Err(payload) => (Err(SlateDBError::BackgroundTaskPanic(name)), Some(payload)),
    }
}
/// Splits a tokio `JoinHandle` outcome into a result and an optional panic payload.
///
/// * `Ok(task_result)` — the task finished; its own result is passed through with no payload.
/// * A cancelled task — `Err(SlateDBError::BackgroundTaskCancelled(name))`, no payload.
/// * Any other join error — `Err(SlateDBError::BackgroundTaskPanic(name))` together with the
///   panic payload extracted via `JoinError::into_panic`.
pub(crate) fn split_join_result(
    name: String,
    join_result: Result<Result<(), SlateDBError>, tokio::task::JoinError>,
) -> (
    Result<(), SlateDBError>,
    Option<Box<dyn std::any::Any + Send>>,
) {
    match join_result {
        // `task_result` is already owned, so hand it back directly instead of cloning it.
        Ok(task_result) => (task_result, None),
        Err(join_error) if join_error.is_cancelled() => {
            (Err(SlateDBError::BackgroundTaskCancelled(name)), None)
        }
        Err(join_error) => {
            // Not cancelled, so the join error carries a panic payload.
            let payload = join_error.into_panic();
            (Err(SlateDBError::BackgroundTaskPanic(name)), Some(payload))
        }
    }
}
/// Formats a byte count using decimal (SI, powers of 1000) units.
///
/// Counts below 1000 print as a whole number of bytes (`"999 B"`). Larger counts are scaled
/// to the largest unit that keeps the *displayed* value below 1000, and printed with two
/// decimals (`"1.50 KB"`, `"18.45 EB"` for `u64::MAX`). Because the unit is chosen after
/// accounting for rounding, 999_999 renders as `"1.00 MB"` rather than `"1000.00 KB"`.
pub fn format_bytes_si(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB", "PB", "EB"];
    const FACTOR: f64 = 1000.0;
    if bytes < 1000 {
        return format!("{} B", bytes);
    }
    let mut value = bytes as f64;
    let mut unit_index = 0;
    // Compare the value as it will be printed (rounded to hundredths). Otherwise values just
    // under the next unit would round up to "1000.00" in the current unit.
    while (value * 100.0).round() >= FACTOR * 100.0 && unit_index < UNITS.len() - 1 {
        value /= FACTOR;
        unit_index += 1;
    }
    format!("{:.2} {}", value, UNITS[unit_index])
}
// Unit tests for this module's helpers. Some tests rely on crate-internal test utilities
// (`crate::test_utils::TestClock`) or on the optional `test-util` feature (`MockSystemClock`).
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use crate::clock::MonotonicClock;
    use crate::error::SlateDBError;
    use crate::test_utils::TestClock;
    use crate::utils::{
        build_concurrent, bytes_into_minimal_vec, clamp_allocated_size_bytes, compute_index_key,
        compute_max_parallel, format_bytes_si, panic_string, spawn_bg_task, BitReader, BitWriter,
        WatchableOnceCell,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use parking_lot::Mutex;
    use std::any::Any;
    use std::collections::VecDeque;
    use std::sync::atomic::Ordering::SeqCst;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Records the result handed to a `spawn_bg_task` cleanup callback so tests can inspect it.
    struct ResultCaptor<T: Clone> {
        error: Mutex<Option<Result<T, SlateDBError>>>,
    }

    impl<T: Clone> ResultCaptor<T> {
        fn new() -> Self {
            Self {
                error: Mutex::new(None),
            }
        }

        /// Stores a clone of `result`; asserts that nothing had been captured before.
        fn capture(&self, result: &Result<T, SlateDBError>) {
            let mut guard = self.error.lock();
            let prev = guard.replace(result.clone());
            assert!(prev.is_none());
        }

        /// Returns a clone of the captured result, if any.
        fn captured(&self) -> Option<Result<T, SlateDBError>> {
            self.error.lock().clone()
        }
    }

    // With no previous block, the index key is empty.
    #[test]
    fn test_should_return_empty_for_index_of_first_block() {
        let this_block_first_key = Bytes::from(vec![0x01, 0x02, 0x03]);
        let result = compute_index_key(None, &this_block_first_key);
        assert_eq!(result, &Bytes::new());
    }

    // Cases cover: divergence at the first/middle byte, previous key as a strict prefix,
    // divergence at the last byte, and identical keys.
    #[rstest]
    #[case(Some("aaaac"), "abaaa", "ab")]
    #[case(Some("ababc"), "abacd", "abac")]
    #[case(Some("cc"), "ccccccc", "ccc")]
    #[case(Some("eed"), "eee", "eee")]
    #[case(Some("abcdef"), "abcdef", "abcdef")]
    fn test_should_compute_index_key(
        #[case] prev_block_last_key: Option<&'static str>,
        #[case] this_block_first_key: &'static str,
        #[case] expected_index_key: &'static str,
    ) {
        assert_eq!(
            compute_index_key(
                prev_block_last_key.map(|s| Bytes::from(s.to_string())),
                &Bytes::from(this_block_first_key.to_string())
            ),
            Bytes::from_static(expected_index_key.as_bytes())
        );
    }

    // `compute_lower_bound` asserts that both keys are non-empty.
    #[rstest]
    #[case(Some(""), "a")]
    #[case(Some("a"), "")]
    #[should_panic]
    fn test_should_panic_on_empty_keys(
        #[case] prev_block_last_key: Option<&'static str>,
        #[case] this_block_first_key: &'static str,
    ) {
        compute_index_key(
            prev_block_last_key.map(|s| Bytes::from(s.to_string())),
            &Bytes::from(this_block_first_key.to_string()),
        );
    }

    // The cleanup hook sees the task's own error, and the JoinHandle yields it too.
    #[tokio::test]
    async fn test_should_cleanup_when_task_exits_with_error() {
        let captor = Arc::new(ResultCaptor::new());
        let handle = tokio::runtime::Handle::current();
        let captor2 = captor.clone();
        let task = spawn_bg_task(
            "test".to_string(),
            &handle,
            move |err| captor2.capture(err),
            async { Err(SlateDBError::Fenced) },
        );
        let result: Result<(), SlateDBError> = task.await.expect("join failure");
        assert!(matches!(result, Err(SlateDBError::Fenced)));
        assert!(matches!(captor.captured(), Some(Err(SlateDBError::Fenced))));
    }

    // A panic is caught and surfaced as `BackgroundTaskPanic`, both to the hook and the caller.
    #[tokio::test]
    async fn test_should_cleanup_when_task_panics() {
        let monitored = async {
            panic!("oops");
        };
        let captor = Arc::new(ResultCaptor::new());
        let handle = tokio::runtime::Handle::current();
        let captor2 = captor.clone();
        let task = spawn_bg_task(
            "test".to_string(),
            &handle,
            move |err| captor2.capture(err),
            monitored,
        );
        let result: Result<(), SlateDBError> = task.await.expect("join failure");
        assert!(matches!(result, Err(SlateDBError::BackgroundTaskPanic(_))));
        assert!(matches!(
            captor.captured(),
            Some(Err(SlateDBError::BackgroundTaskPanic(_)))
        ));
    }

    // A successful task still runs the cleanup hook, with `Ok(())`.
    #[tokio::test]
    async fn test_should_cleanup_when_task_exits() {
        let captor = Arc::new(ResultCaptor::new());
        let handle = tokio::runtime::Handle::current();
        let captor2 = captor.clone();
        let task = spawn_bg_task(
            "test".to_string(),
            &handle,
            move |err| captor2.capture(err),
            async { Ok(()) },
        );
        let result: Result<(), SlateDBError> = task.await.expect("join failure");
        assert!(matches!(result, Ok(())));
        assert!(matches!(captor.captured(), Some(Ok(()))));
    }

    // The second write is ignored; the first value sticks.
    #[tokio::test]
    async fn test_should_only_write_register_once() {
        let register = WatchableOnceCell::new();
        let reader = register.reader();
        assert_eq!(reader.read(), None);
        register.write(123);
        assert_eq!(reader.read(), Some(123));
        register.write(456);
        assert_eq!(reader.read(), Some(123));
    }

    // `await_value` resolves after the write, and resolves again for a value already present.
    #[tokio::test]
    async fn test_should_return_on_await_written_register() {
        let register = WatchableOnceCell::new();
        let mut reader = register.reader();
        let h = tokio::spawn(async move {
            assert_eq!(reader.await_value().await, 123);
            assert_eq!(reader.await_value().await, 123);
        });
        register.write(123);
        h.await.unwrap();
    }

    // Moving the underlying clock backwards makes `MonotonicClock::now` report `InvalidClockTick`.
    #[tokio::test]
    async fn test_monotonicity_enforcement_on_mono_clock() {
        let clock = Arc::new(TestClock::new());
        let mono_clock = MonotonicClock::new(clock.clone(), 0);
        clock.ticker.store(10, SeqCst);
        mono_clock.now().await.unwrap();
        clock.ticker.store(5, SeqCst);
        if let Err(SlateDBError::InvalidClockTick {
            last_tick,
            next_tick,
        }) = mono_clock.now().await
        {
            assert_eq!(last_tick, 10);
            assert_eq!(next_tick, 5);
        } else {
            panic!("Expected InvalidClockTick from mono_clock")
        }
    }

    // `set_last_tick` with a value below the last observed tick is rejected.
    #[tokio::test]
    async fn test_monotonicity_enforcement_on_mono_clock_set_tick() {
        let clock = Arc::new(TestClock::new());
        let mono_clock = MonotonicClock::new(clock.clone(), 0);
        clock.ticker.store(10, SeqCst);
        mono_clock.now().await.unwrap();
        if let Err(SlateDBError::InvalidClockTick {
            last_tick,
            next_tick,
        }) = mono_clock.set_last_tick(5)
        {
            assert_eq!(last_tick, 10);
            assert_eq!(next_tick, 5);
        } else {
            panic!("Expected InvalidClockTick from mono_clock")
        }
    }

    // With paused time, the delegate clock catches up (to 101) before the advanced
    // 100ms elapses, so `now()` succeeds.
    #[tokio::test(start_paused = true)]
    async fn test_await_valid_tick() {
        let delegate_clock = Arc::new(TestClock::new());
        let mono_clock = MonotonicClock::new(delegate_clock.clone(), 100);
        tokio::spawn({
            let delegate_clock = delegate_clock.clone();
            async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                delegate_clock.ticker.store(101, SeqCst);
            }
        });
        let tick_future = mono_clock.now();
        tokio::time::advance(Duration::from_millis(100)).await;
        let result = tick_future.await;
        assert_eq!(result.unwrap(), 101);
    }

    // The delegate clock never catches up with the last tick (100), so `now()` fails.
    #[tokio::test(start_paused = true)]
    async fn test_await_valid_tick_failure() {
        let delegate_clock = Arc::new(TestClock::new());
        let mono_clock = MonotonicClock::new(delegate_clock.clone(), 100);
        let tick_future = mono_clock.now();
        tokio::time::advance(Duration::from_millis(110)).await;
        let result = tick_future.await;
        assert!(result.is_err());
    }

    // A 1024-byte slice of a 2048-byte buffer is copied into a Vec of capacity exactly 1024.
    #[test]
    fn test_should_clamp_bytes_to_minimal_vec() {
        let mut bytes = BytesMut::with_capacity(2048);
        bytes.put_bytes(0u8, 2048);
        let bytes = bytes.freeze();
        let slice = bytes.slice(100..1124);
        let clamped = bytes_into_minimal_vec(&slice);
        assert_eq!(slice.as_ref(), clamped.as_slice());
        assert_eq!(clamped.capacity(), 1024);
    }

    // The clamped copy has equal contents but lives in a different allocation.
    #[test]
    fn test_should_clamp_bytes_and_preserve_data() {
        let mut bytes = BytesMut::with_capacity(2048);
        bytes.put_bytes(0u8, 2048);
        let bytes = bytes.freeze();
        let slice = bytes.slice(100..1124);
        let clamped = clamp_allocated_size_bytes(&slice);
        assert_eq!(clamped, slice);
        assert_ne!(clamped.as_ptr(), slice.as_ptr());
    }

    // A future that is already complete wins; `error_fn` must never run.
    #[tokio::test]
    #[cfg(feature = "test-util")]
    async fn test_timeout_completes_before_expiry() {
        use crate::{clock::MockSystemClock, utils::timeout};
        let clock = Arc::new(MockSystemClock::new());
        let completed_future = async { Ok::<_, SlateDBError>(42) };
        let timeout_future = timeout(
            clock,
            Duration::from_millis(100),
            || unreachable!(),
            completed_future,
        );
        let result = timeout_future.await;
        assert_eq!(result.unwrap(), 42);
    }

    // A never-completing future times out once a background task advances the mock clock.
    #[tokio::test]
    #[cfg(feature = "test-util")]
    async fn test_timeout_expires() {
        use std::sync::atomic::AtomicBool;
        use crate::clock::{MockSystemClock, SystemClock};
        use crate::utils::timeout;
        let clock = Arc::new(MockSystemClock::new());
        let never_completes = std::future::pending::<Result<(), SlateDBError>>();
        let timeout_duration = Duration::from_millis(100);
        let timeout_future = timeout(
            clock.clone(),
            timeout_duration,
            || SlateDBError::TransactionalObjectTimeout {
                timeout: timeout_duration,
            },
            never_completes,
        );
        let done = Arc::new(AtomicBool::new(false));
        let this_done = done.clone();
        // Keep advancing the mock clock until the timeout has fired.
        tokio::spawn(async move {
            while !this_done.load(SeqCst) {
                clock.advance(Duration::from_millis(100)).await;
                tokio::task::yield_now().await;
            }
        });
        let result = timeout_future.await;
        done.store(true, SeqCst);
        assert!(matches!(
            result,
            Err(SlateDBError::TransactionalObjectTimeout { .. })
        ));
    }

    // NOTE(review): structurally identical to `test_timeout_completes_before_expiry`. The mock
    // clock is never advanced here, so the sleep is presumably never ready at the same time as
    // the future, and this does not isolate the effect of `biased`.
    #[tokio::test]
    #[cfg(feature = "test-util")]
    async fn test_timeout_respects_biased_select() {
        use crate::{clock::MockSystemClock, utils::timeout};
        let clock = Arc::new(MockSystemClock::new());
        let completes_immediately = async { Ok::<_, SlateDBError>(42) };
        let timeout_future = timeout(
            clock,
            Duration::from_millis(100),
            || unreachable!(),
            completes_immediately,
        );
        let result = timeout_future.await;
        assert_eq!(result.unwrap(), 42);
    }

    // Each case: individual pushes, then push32 calls, then push64 calls, then the
    // expected packed bytes.
    #[rstest]
    #[case("alternating_bits", vec![true, false, true, false, true, false, true, false], vec![], vec![], vec![0xAA])]
    #[case("u64_value", vec![], vec![(0xAB, 8)], vec![], vec![0xAB])]
    #[case("partial_and_multiple_bytes", vec![true, false], vec![(0x3F, 6), (0xCD, 8)], vec![], vec![0xBF, 0xCD])]
    #[case("empty_writer", vec![], vec![], vec![], vec![])]
    #[case("single_bit_true", vec![true], vec![], vec![], vec![0x80])]
    #[case("single_bit_false", vec![false], vec![], vec![], vec![0x00])]
    #[case("all_zeros", vec![false, false, false, false, false, false, false, false], vec![], vec![], vec![0x00])]
    #[case("all_ones", vec![true, true, true, true, true, true, true, true], vec![], vec![], vec![0xFF])]
    #[case("partial_byte_padding", vec![true, false, true], vec![], vec![], vec![0xA0])]
    #[case("push32_single_bit", vec![], vec![(1, 1)], vec![], vec![0x80])]
    #[case("push32_zero_bits", vec![], vec![(0xFF, 0)], vec![], vec![])]
    #[case("push32_max_bits", vec![], vec![(0xDEADBEEF, 32)], vec![], vec![0xDE, 0xAD, 0xBE, 0xEF])]
    #[case("push64_operations", vec![], vec![], vec![(0x123456789ABCDEF0, 64)], vec![0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0])]
    #[case("push64_partial", vec![], vec![], vec![(0xABCD, 12)], vec![0xBC, 0xD0])]
    #[case("mixed_operations", vec![true], vec![(0x7F, 7)], vec![], vec![0xFF])]
    #[case("boundary_crossing", vec![true, false, true, false], vec![(0xF0, 4)], vec![], vec![0xA0])]
    #[case("multiple_partial_bytes", vec![true], vec![(0x5, 3), (0x2, 2), (0x1, 2)], vec![], vec![0xD9])]
    fn test_bit_writer(
        #[case] _description: &str,
        #[case] individual_bits: Vec<bool>,
        #[case] push32_operations: Vec<(u32, u8)>,
        #[case] push64_operations: Vec<(u64, u8)>,
        #[case] expected: Vec<u8>,
    ) {
        let mut writer = BitWriter::new();
        for bit in individual_bits {
            writer.push(bit);
        }
        for (value, bits) in push32_operations {
            writer.push32(value, bits);
        }
        for (value, bits) in push64_operations {
            writer.push64(value, bits);
        }
        let result = writer.finish();
        assert_eq!(result, expected);
    }

    // Mirror of `test_bit_writer`. Each case: input bytes, then expected individual bits,
    // then expected read32 values, then expected read64 values.
    #[rstest]
    #[case("alternating_bits", vec![0xAA], vec![true, false, true, false, true, false, true, false], vec![], vec![])]
    #[case("u64_value", vec![0xAB], vec![], vec![(0xAB, 8)], vec![])]
    #[case("partial_and_multiple_bytes", vec![0xBF, 0xCD], vec![true, false], vec![(0x3F, 6), (0xCD, 8)], vec![])]
    #[case("empty_reader", vec![], vec![], vec![], vec![])]
    #[case("single_bit_true", vec![0x80], vec![true], vec![], vec![])]
    #[case("single_bit_false", vec![0x00], vec![false], vec![], vec![])]
    #[case("all_zeros", vec![0x00], vec![false, false, false, false, false, false, false, false], vec![], vec![])]
    #[case("all_ones", vec![0xFF], vec![true, true, true, true, true, true, true, true], vec![], vec![])]
    #[case("partial_byte_padding", vec![0xA0], vec![true, false, true], vec![], vec![])]
    #[case("push32_single_bit", vec![0x80], vec![], vec![(1, 1)], vec![])]
    #[case("push32_zero_bits", vec![], vec![], vec![], vec![])]
    #[case("push32_max_bits", vec![0xDE, 0xAD, 0xBE, 0xEF], vec![], vec![(0xDEADBEEF, 32)], vec![])]
    #[case("push64_operations", vec![0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0], vec![], vec![], vec![(0x123456789ABCDEF0, 64)])]
    #[case("push64_partial", vec![0xBC, 0xD0], vec![], vec![], vec![(0xBCD, 12)])]
    #[case("mixed_operations", vec![0xFF], vec![true], vec![(0x7F, 7)], vec![])]
    #[case("boundary_crossing", vec![0xA0], vec![true, false, true, false], vec![(0x0, 4)], vec![])]
    #[case("multiple_partial_bytes", vec![0xD9], vec![true], vec![(0x5, 3), (0x2, 2), (0x1, 2)], vec![])]
    fn test_bit_reader(
        #[case] _description: &str,
        #[case] input_bytes: Vec<u8>,
        #[case] expected_individual_bits: Vec<bool>,
        #[case] expected_read32_operations: Vec<(u32, u8)>,
        #[case] expected_read64_operations: Vec<(u64, u8)>,
    ) {
        let mut reader = BitReader::new(&input_bytes);
        for expected_bit in expected_individual_bits {
            let actual_bit = reader.read_bit();
            assert_eq!(actual_bit, Some(expected_bit));
        }
        for (expected_value, bits) in expected_read32_operations {
            let actual_value = reader.read32(bits);
            assert_eq!(actual_value, Some(expected_value));
        }
        for (expected_value, bits) in expected_read64_operations {
            let actual_value = reader.read64(bits);
            assert_eq!(actual_value, Some(expected_value));
        }
        // Any bit left in the input must be zero padding.
        if !input_bytes.is_empty() {
            let next_bit = reader.read_bit();
            if let Some(bit) = next_bit {
                assert!(!bit);
            }
        }
    }

    // After all input bits are consumed, every read returns `None`.
    #[test]
    fn test_bit_reader_exhaustion() {
        // A single byte with all eight bits set.
        let bytes = vec![0xFF];
        let mut reader = BitReader::new(&bytes);
        for _ in 0..8 {
            assert_eq!(reader.read_bit(), Some(true));
        }
        assert_eq!(reader.read_bit(), None);
        assert_eq!(reader.read32(1), None);
        assert_eq!(reader.read64(1), None);
    }

    #[rstest]
    #[case(0, "0 B")]
    #[case(1, "1 B")]
    #[case(999, "999 B")]
    #[case(1_000, "1.00 KB")]
    #[case(1_500, "1.50 KB")]
    #[case(1_000_000, "1.00 MB")]
    #[case(1_500_000, "1.50 MB")]
    #[case(1_000_000_000, "1.00 GB")]
    #[case(1_000_000_000_000, "1.00 TB")]
    #[case(1_000_000_000_000_000, "1.00 PB")]
    #[case(1_000_000_000_000_000_000, "1.00 EB")]
    #[case(u64::MAX, "18.45 EB")]
    fn test_format_bytes_si(#[case] bytes: u64, #[case] expected: &str) {
        assert_eq!(format_bytes_si(bytes), expected);
    }

    // Below the cap the SST count is used, above it the cap applies, and the result is
    // never less than 1.
    #[test]
    fn test_compute_max_parallel_min_and_cap() {
        assert_eq!(compute_max_parallel(5, &[], 8), 5);
        assert_eq!(compute_max_parallel(10, &[], 8), 8);
        assert_eq!(compute_max_parallel(0, &[], 0), 1);
    }

    // `Ok(None)` results are dropped. Output order is completion order, so sort before comparing.
    #[tokio::test]
    async fn test_build_iters_concurrent_option_filters_none() {
        let inputs = 0..10usize;
        let out: VecDeque<usize> = build_concurrent(inputs, 4, |x| async move {
            if x % 2 == 0 {
                Ok(Some(x * 3))
            } else {
                Ok(None)
            }
        })
        .await
        .expect("should succeed");
        let mut got: Vec<_> = out.into_iter().collect();
        got.sort_unstable();
        assert_eq!(got, vec![0, 6, 12, 18, 24]);
    }

    // A single failing input makes the whole call fail.
    #[tokio::test]
    async fn test_build_iters_concurrent_option_error() {
        let inputs = 0..10usize;
        let res: Result<VecDeque<usize>, SlateDBError> =
            build_concurrent(inputs, 3, |x| async move {
                if x == 5 {
                    Err(SlateDBError::Fenced)
                } else {
                    Ok(Some(x))
                }
            })
            .await;
        assert!(res.is_err());
    }

    // Tracks the peak number of simultaneously running futures and checks it never exceeds
    // `max_parallel`.
    #[tokio::test]
    async fn test_build_iters_concurrent_option_respects_max_parallel() {
        let inputs = 0..16usize;
        let max_parallel = 4;
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let res: Result<VecDeque<()>, SlateDBError> = build_concurrent(inputs, max_parallel, {
            let in_flight = in_flight.clone();
            let peak = peak.clone();
            move |_x| {
                let in_flight = in_flight.clone();
                let peak = peak.clone();
                async move {
                    let cur = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(cur, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(15)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    Ok(Some(()))
                }
            }
        })
        .await;
        assert!(res.is_ok());
        let observed_peak = peak.load(Ordering::SeqCst);
        assert!(
            observed_peak <= max_parallel,
            "observed peak {} exceeds max_parallel {}",
            observed_peak,
            max_parallel
        );
    }

    #[test]
    fn panic_string_handles_slatedb_error() {
        let err = SlateDBError::InvalidDBState;
        let payload: Box<dyn Any + Send> = Box::new(err.clone());
        let msg = panic_string(&payload);
        assert_eq!(msg, err.to_string());
    }

    #[test]
    fn panic_string_handles_slatedb_result() {
        let payload: Box<Result<(), SlateDBError>> = Box::new(Err(SlateDBError::InvalidDBState));
        let msg = panic_string(&(payload as Box<dyn Any + Send>));
        assert_eq!(msg, SlateDBError::InvalidDBState.to_string());
    }

    #[test]
    fn panic_string_handles_string() {
        let s: Box<dyn Any + Send> = Box::new(String::from("hello"));
        let msg = panic_string(&s);
        assert_eq!(msg, "hello");
    }

    #[test]
    fn panic_string_handles_static_str() {
        let s: Box<dyn Any + Send> = Box::new("boom");
        let msg = panic_string(&s);
        assert_eq!(msg, "boom");
    }

    // NOTE(review): the payload here is an unboxed `std::io::Error`, not a boxed error trait
    // object, so this test actually exercises the unknown-type fallback path.
    #[test]
    fn panic_string_falls_back_for_boxed_error_trait_object() {
        let err_box: Box<dyn Any + Send> = Box::new(std::io::Error::other("oh no"));
        let msg = panic_string(&err_box);
        assert!(msg.contains("task panicked with unknown type"));
    }

    #[test]
    fn panic_string_falls_back_for_unknown_type() {
        #[derive(Clone, Debug)]
        struct MyType;
        let msg = panic_string(&(Box::new(MyType) as Box<dyn Any + Send>));
        assert!(msg.contains("task panicked with unknown type"));
    }

    #[test]
    fn test_split_unwind_result_ok_ok() {
        let unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>> =
            Ok(Ok(()));
        let (result, payload) = super::split_unwind_result("test".to_string(), unwind_result);
        assert!(result.is_ok());
        assert!(payload.is_none());
    }

    #[test]
    fn test_split_unwind_result_ok_error() {
        let unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>> =
            Ok(Err(SlateDBError::Fenced));
        let (result, payload) = super::split_unwind_result("test".to_string(), unwind_result);
        assert!(matches!(result, Err(SlateDBError::Fenced)));
        assert!(payload.is_none());
    }

    // A panic maps to `BackgroundTaskPanic(name)` and the original payload is handed back intact.
    #[test]
    fn test_split_unwind_result_panic() {
        let panic_msg = "something went wrong";
        let unwind_result: Result<Result<(), SlateDBError>, Box<dyn std::any::Any + Send>> =
            Err(Box::new(panic_msg));
        let (result, payload) = super::split_unwind_result("test_task".to_string(), unwind_result);
        assert!(matches!(
            result,
            Err(SlateDBError::BackgroundTaskPanic(ref name)) if name == "test_task"
        ));
        assert!(payload.is_some());
        if let Some(p) = payload {
            if let Some(msg) = p.downcast_ref::<&str>() {
                assert_eq!(msg, &panic_msg);
            } else {
                panic!("expected &str, got {:?}", p);
            }
        }
    }

    #[test]
    fn test_split_join_result_ok_ok() {
        let join_result: Result<Result<(), SlateDBError>, tokio::task::JoinError> = Ok(Ok(()));
        let (result, payload) = super::split_join_result("test".to_string(), join_result);
        assert!(result.is_ok());
        assert!(payload.is_none());
    }

    #[test]
    fn test_split_join_result_ok_error() {
        let join_result: Result<Result<(), SlateDBError>, tokio::task::JoinError> =
            Ok(Err(SlateDBError::Fenced));
        let (result, payload) = super::split_join_result("test".to_string(), join_result);
        assert!(matches!(result, Err(SlateDBError::Fenced)));
        assert!(payload.is_none());
    }

    // Aborting a never-ending task yields a cancelled JoinError, which maps to
    // `BackgroundTaskCancelled` with no payload.
    #[tokio::test]
    async fn test_split_join_result_cancelled() {
        let handle = tokio::spawn(async {
            loop {
                tokio::time::sleep(Duration::from_secs(1)).await;
            }
        });
        handle.abort();
        let join_result = handle.await;
        let (result, payload) = super::split_join_result("test_task".to_string(), join_result);
        assert!(matches!(
            result,
            Err(SlateDBError::BackgroundTaskCancelled(ref name)) if name == "test_task"
        ));
        assert!(payload.is_none());
    }

    // A panicking spawned task maps to `BackgroundTaskPanic` and its `&str` payload is preserved.
    #[tokio::test]
    async fn test_split_join_result_panic() {
        let handle = tokio::spawn(async {
            panic!("something went wrong");
        });
        let join_result = handle.await;
        let (result, payload) = super::split_join_result("test_task".to_string(), join_result);
        assert!(matches!(
            result,
            Err(SlateDBError::BackgroundTaskPanic(ref name)) if name == "test_task"
        ));
        assert!(payload.is_some());
        if let Some(p) = payload {
            if let Some(msg) = p.downcast_ref::<&str>() {
                assert_eq!(msg, &"something went wrong");
            } else {
                panic!("expected &str, got {:?}", p);
            }
        }
    }
}