use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use parking_lot::Mutex;
use serde::{de::DeserializeOwned, Serialize};
use crate::backend::{memory_backend, MemoryBackend, StorageBackend};
use crate::codec::{Codec, JsonCodec};
use crate::error::StorageError;
use crate::tier::{
AppendCursor, AppendLoadResult, AppendLogMode, AppendLogStorageTier, BaseStorageTier,
KvStorageTier, LoadEntriesOpts, PrefixIter, SnapshotStorageTier,
};
/// Predicate consulted by `SnapshotStorage::save`; snapshots for which it
/// returns `false` are dropped without being stored.
type FilterFn<T> = Box<dyn Fn(&T) -> bool + Send + Sync>;
/// Maps a value to the backend key it is stored under (the snapshot key for
/// `SnapshotStorage`, the bucket key for `AppendLogStorage`).
type KeyOfFn<T> = Box<dyn Fn(&T) -> String + Send + Sync>;
/// Predicate over `(key, value)` consulted by `KvStorage::save`; pairs for
/// which it returns `false` are dropped without being stored.
type KvFilterFn<T> = Box<dyn Fn(&str, &T) -> bool + Send + Sync>;
/// Storage tier that holds a single "latest state" value of type `T`.
///
/// Every accepted `save` replaces the pending snapshot. The snapshot is
/// written through immediately when no `debounce_ms` is configured or when a
/// `compact_every` boundary is crossed; otherwise it waits for `flush`.
pub struct SnapshotStorage<
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T> = JsonCodec,
> {
    /// Backend that receives the encoded snapshots.
    backend: Arc<B>,
    /// Encoder/decoder for `T`.
    codec: C,
    /// Tier name; also the fallback key when no `key_of` is configured.
    name: String,
    /// `None` means every accepted save is written through immediately.
    debounce_ms: Option<u32>,
    /// Force a write every `n` accepted saves (`Some(0)` is rejected at construction).
    compact_every: Option<u32>,
    /// Optional gate; snapshots it rejects are dropped by `save`.
    filter: Option<FilterFn<T>>,
    /// Chooses the backend key for each snapshot.
    key_of: KeyOfFn<T>,
    /// Most recent snapshot that has not been written yet.
    pending: Mutex<Option<T>>,
    /// Number of accepted saves, used to detect `compact_every` boundaries.
    write_count: Mutex<u64>,
    /// Key of the last successful write; `load` reads from this key.
    last_saved_key: Mutex<Option<String>>,
}

/// Construction options for `snapshot_storage` / `memory_snapshot`.
pub struct SnapshotStorageOptions<T: Send + Sync + 'static, C: Codec<T> = JsonCodec> {
    /// Tier name; defaults to the backend's name.
    pub name: Option<String>,
    /// Codec used for `T`.
    pub codec: C,
    /// `None` writes every accepted save through; `Some(_)` leaves it pending
    /// until a `flush` (this type only reports the value, it schedules nothing).
    pub debounce_ms: Option<u32>,
    /// Force a write every `n` accepted saves; must not be `Some(0)`.
    pub compact_every: Option<u32>,
    /// Snapshots for which this returns `false` are dropped.
    pub filter: Option<FilterFn<T>>,
    /// Backend key for each snapshot; defaults to the tier name.
    pub key_of: Option<KeyOfFn<T>>,
}

impl<T> Default for SnapshotStorageOptions<T>
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// JSON codec, write-through, no compaction, no filter, tier name as key.
    fn default() -> Self {
        Self {
            codec: JsonCodec,
            name: None,
            filter: None,
            key_of: None,
            debounce_ms: None,
            compact_every: None,
        }
    }
}
pub fn snapshot_storage<B, T, C>(
backend: Arc<B>,
opts: SnapshotStorageOptions<T, C>,
) -> SnapshotStorage<B, T, C>
where
B: StorageBackend + ?Sized,
T: Send + Sync + 'static,
C: Codec<T>,
{
assert!(
opts.compact_every != Some(0),
"snapshot_storage: compact_every must be None or Some(n) where n >= 1, got Some(0)",
);
let name = opts.name.unwrap_or_else(|| backend.name().to_string());
let fallback_key = name.clone();
let key_of = opts
.key_of
.unwrap_or_else(|| Box::new(move |_| fallback_key.clone()));
SnapshotStorage {
backend,
codec: opts.codec,
name,
debounce_ms: opts.debounce_ms,
compact_every: opts.compact_every,
filter: opts.filter,
key_of,
pending: Mutex::new(None),
write_count: Mutex::new(0),
last_saved_key: Mutex::new(None),
}
}
/// Builds a `SnapshotStorage` over the backend returned by `memory_backend()`.
///
/// # Panics
/// Panics if `opts.compact_every` is `Some(0)` (see `snapshot_storage`).
pub fn memory_snapshot<T, C>(
    opts: SnapshotStorageOptions<T, C>,
) -> SnapshotStorage<MemoryBackend, T, C>
where
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    let backend = memory_backend();
    snapshot_storage(backend, opts)
}
impl<B, T, C> SnapshotStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    /// Encodes `snapshot` and writes it under the key chosen by `key_of`.
    ///
    /// On success the key is recorded in `last_saved_key` so `load` reads it.
    /// On failure the snapshot is handed back alongside the error, so the
    /// caller can decide whether to re-queue it.
    fn try_flush(
        backend: &B,
        codec: &C,
        key_of: &KeyOfFn<T>,
        last_saved_key: &Mutex<Option<String>>,
        snapshot: T,
    ) -> Result<(), (T, StorageError)> {
        let target = key_of(&snapshot);
        let write_result = match codec.encode(&snapshot) {
            Ok(encoded) => backend.write(&target, &encoded),
            Err(codec_err) => Err(codec_err.into()),
        };
        match write_result {
            Ok(_) => {
                *last_saved_key.lock() = Some(target);
                Ok(())
            }
            Err(storage_err) => Err((snapshot, storage_err)),
        }
    }
}
impl<B, T, C> BaseStorageTier for SnapshotStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn debounce_ms(&self) -> Option<u32> {
        self.debounce_ms
    }

    fn compact_every(&self) -> Option<u32> {
        self.compact_every
    }

    /// Writes the pending snapshot, if there is one.
    ///
    /// # Errors
    /// Returns the codec or backend error from the write. The failed snapshot
    /// is put back into `pending` only if no newer snapshot was saved while
    /// the write was in flight. This way a later retry never replaces newer
    /// state with older state.
    fn flush(&self) -> Result<(), StorageError> {
        // Take under the lock, then release it before doing any I/O.
        let slot = self.pending.lock().take();
        let Some(snapshot) = slot else {
            return Ok(());
        };
        Self::try_flush(
            &*self.backend,
            &self.codec,
            &self.key_of,
            &self.last_saved_key,
            snapshot,
        )
        .map_err(|(failed, err)| {
            let mut pending = self.pending.lock();
            // A concurrent `save` may have stored a newer snapshot; it wins.
            if pending.is_none() {
                *pending = Some(failed);
            }
            err
        })
    }

    /// Discards the pending (unwritten) snapshot.
    fn rollback(&self) -> Result<(), StorageError> {
        *self.pending.lock() = None;
        Ok(())
    }

    fn list_by_prefix_bytes<'a>(
        &'a self,
        prefix: &str,
    ) -> Box<dyn Iterator<Item = Result<(String, Vec<u8>), StorageError>> + 'a> {
        Box::new(PrefixIter::new(&*self.backend, prefix))
    }

    /// Implemented as a plain `flush` of the pending snapshot.
    fn compact(&self) -> Result<(), StorageError> {
        self.flush()
    }
}
impl<B, T, C> SnapshotStorageTier<T> for SnapshotStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    /// Records `snapshot` as the latest state and writes it when due.
    ///
    /// Snapshots rejected by `filter` are dropped (`Ok(())`). An accepted
    /// snapshot replaces any pending one. It is written immediately when no
    /// `debounce_ms` is configured or when this save crosses a
    /// `compact_every` boundary.
    ///
    /// # Errors
    /// Returns the codec or backend error from an immediate write. The failed
    /// snapshot is re-queued only if no newer snapshot arrived meanwhile.
    fn save(&self, snapshot: T) -> Result<(), StorageError> {
        if let Some(filter) = &self.filter {
            if !filter(&snapshot) {
                return Ok(());
            }
        }
        let due: Option<T> = {
            let mut pending = self.pending.lock();
            *pending = Some(snapshot);
            let mut count = self.write_count.lock();
            let prev = *count;
            *count = count.saturating_add(1);
            let new = *count;
            let compact_trigger = matches!(
                self.compact_every,
                Some(n) if n > 0 && (prev / u64::from(n)) != (new / u64::from(n))
            );
            if compact_trigger || self.debounce_ms.is_none() {
                pending.take()
            } else {
                None
            }
        };
        let Some(snap) = due else {
            return Ok(());
        };
        Self::try_flush(
            &*self.backend,
            &self.codec,
            &self.key_of,
            &self.last_saved_key,
            snap,
        )
        .map_err(|(failed, err)| {
            let mut pending = self.pending.lock();
            // Never overwrite a newer snapshot saved while we were writing.
            if pending.is_none() {
                *pending = Some(failed);
            }
            err
        })
    }

    /// Reads the most recently written snapshot back from the backend.
    ///
    /// The key is the one recorded by the last successful write through this
    /// instance, or the tier name if there has been none. A pending
    /// (unflushed) snapshot is not consulted. Missing or empty stored values
    /// yield `Ok(None)`.
    ///
    /// # Errors
    /// Propagates backend read errors and codec decode errors.
    fn load(&self) -> Result<Option<T>, StorageError> {
        // NOTE(review): with a custom `key_of`, an instance that has not
        // written yet reads the tier-name key, which may not be where earlier
        // instances wrote.
        let key = self
            .last_saved_key
            .lock()
            .clone()
            .unwrap_or_else(|| self.name.clone());
        match self.backend.read(&key)? {
            Some(bytes) if !bytes.is_empty() => Ok(Some(self.codec.decode(&bytes)?)),
            _ => Ok(None),
        }
    }
}
/// Storage tier that accumulates entries of type `T` into per-key lists.
///
/// Appended entries are buffered per `key_of` bucket and written on `flush`.
/// A flush happens immediately when no `debounce_ms` is set or when a
/// `compact_every` boundary is crossed. `mode` decides whether a flush
/// replaces or extends the stored list.
pub struct AppendLogStorage<
    B: StorageBackend + ?Sized,
    T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    C: Codec<Vec<T>> = JsonCodec,
> {
    /// Backend holding one encoded `Vec<T>` per key.
    backend: Arc<B>,
    /// Encoder/decoder for a key's entry list.
    codec: C,
    /// Tier name; also the fallback bucket key when no `key_of` is configured.
    name: String,
    /// `None` means every append is flushed immediately.
    debounce_ms: Option<u32>,
    /// Force a flush every `n` appended entries (`Some(0)` is rejected at construction).
    compact_every: Option<u32>,
    /// `Overwrite` replaces the stored list on flush; `Append` extends it.
    mode: AppendLogMode,
    /// Chooses the bucket (backend key) for each entry.
    key_of: KeyOfFn<T>,
    /// Entries not yet written, grouped by backend key.
    pending: Mutex<std::collections::HashMap<String, Vec<T>>>,
    /// Running total of appended entries, for `compact_every` boundaries.
    append_count: Mutex<u64>,
    /// Incremented by every `rollback`; `flush` compares it to abandon stale work.
    rollback_epoch: AtomicU64,
}

/// Construction options for `append_log_storage` / `memory_append_log`.
pub struct AppendLogStorageOptions<T: Send + Sync + 'static, C: Codec<Vec<T>> = JsonCodec> {
    /// Tier name; defaults to the backend's name.
    pub name: Option<String>,
    /// Codec for a key's `Vec<T>`.
    pub codec: C,
    /// `None` flushes on every append; `Some(_)` buffers until `flush`.
    pub debounce_ms: Option<u32>,
    /// Force a flush every `n` appended entries; must not be `Some(0)`.
    pub compact_every: Option<u32>,
    /// Bucket key for each entry; defaults to the tier name.
    pub key_of: Option<KeyOfFn<T>>,
    /// Whether a flush replaces or extends the stored list.
    pub mode: AppendLogMode,
}

impl<T> Default for AppendLogStorageOptions<T>
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// JSON codec, append mode, flush on every append, no compaction,
    /// tier name as the single bucket key.
    fn default() -> Self {
        Self {
            mode: AppendLogMode::Append,
            codec: JsonCodec,
            name: None,
            key_of: None,
            debounce_ms: None,
            compact_every: None,
        }
    }
}
pub fn append_log_storage<B, T, C>(
backend: Arc<B>,
opts: AppendLogStorageOptions<T, C>,
) -> AppendLogStorage<B, T, C>
where
B: StorageBackend + ?Sized,
T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
C: Codec<Vec<T>>,
{
assert!(
opts.compact_every != Some(0),
"append_log_storage: compact_every must be None or Some(n) where n >= 1, got Some(0)",
);
let name = opts.name.unwrap_or_else(|| backend.name().to_string());
let fallback_key = name.clone();
let key_of = opts
.key_of
.unwrap_or_else(|| Box::new(move |_| fallback_key.clone()));
AppendLogStorage {
backend,
codec: opts.codec,
name,
debounce_ms: opts.debounce_ms,
compact_every: opts.compact_every,
mode: opts.mode,
key_of,
pending: Mutex::new(std::collections::HashMap::new()),
append_count: Mutex::new(0),
rollback_epoch: AtomicU64::new(0),
}
}
/// Builds an `AppendLogStorage` over the backend returned by `memory_backend()`.
///
/// # Panics
/// Panics if `opts.compact_every` is `Some(0)` (see `append_log_storage`).
pub fn memory_append_log<T, C>(
    opts: AppendLogStorageOptions<T, C>,
) -> AppendLogStorage<MemoryBackend, T, C>
where
    T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    C: Codec<Vec<T>>,
{
    let backend = memory_backend();
    append_log_storage(backend, opts)
}
impl<B, T, C> AppendLogStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    C: Codec<Vec<T>>,
{
    /// Persists one key's buffered entries according to `self.mode`.
    ///
    /// `Overwrite` replaces the stored list with `bucket`. `Append` reads the
    /// stored list, appends `bucket`, and writes the merged list back.
    ///
    /// On failure, returns exactly the entries that came from `bucket` (never
    /// the previously stored ones) together with the error. The caller can
    /// then re-queue them without duplicating data.
    fn write_bucket(&self, key: &str, bucket: Vec<T>) -> Result<(), (Vec<T>, StorageError)> {
        // `new_from` is the index in `payload` where this flush's entries begin.
        let (mut payload, new_from) = match self.mode {
            AppendLogMode::Overwrite => (bucket, 0),
            AppendLogMode::Append => {
                let existing = match self.backend.read(key) {
                    Ok(e) => e,
                    Err(e) => return Err((bucket, e)),
                };
                let mut merged: Vec<T> = match existing {
                    Some(bytes) if !bytes.is_empty() => match self.codec.decode(&bytes) {
                        Ok(v) => v,
                        Err(e) => return Err((bucket, e.into())),
                    },
                    _ => Vec::new(),
                };
                let new_from = merged.len();
                merged.extend(bucket);
                (merged, new_from)
            }
        };
        let encoded = match self.codec.encode(&payload) {
            Ok(b) => b,
            Err(e) => return Err((payload.split_off(new_from), e.into())),
        };
        if let Err(e) = self.backend.write(key, &encoded) {
            return Err((payload.split_off(new_from), e));
        }
        Ok(())
    }

    /// Puts entries that a failed flush could not write back into `pending`.
    ///
    /// Entries appended while the flush was running are newer, so they are
    /// kept after the re-queued ones. If a `rollback` happened since the flush
    /// started, the entries are dropped rather than resurrected.
    fn requeue_unwritten(
        &self,
        scheduled_epoch: u64,
        unwritten: std::collections::HashMap<String, Vec<T>>,
    ) {
        let mut pending = self.pending.lock();
        // `rollback` bumps the epoch while holding this same lock, so checking
        // here cannot race with its clear.
        if self.rollback_epoch.load(Ordering::Acquire) != scheduled_epoch {
            return;
        }
        for (key, mut older) in unwritten {
            let slot = pending.entry(key).or_default();
            older.append(slot);
            *slot = older;
        }
    }
}

impl<B, T, C> BaseStorageTier for AppendLogStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    C: Codec<Vec<T>>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn debounce_ms(&self) -> Option<u32> {
        self.debounce_ms
    }

    fn compact_every(&self) -> Option<u32> {
        self.compact_every
    }

    /// Writes every buffered bucket to the backend, one key at a time.
    ///
    /// If a `rollback` happens mid-flush, the remaining buckets are discarded
    /// and `Ok(())` is returned.
    ///
    /// # Errors
    /// Stops at the first read, decode, encode or write failure and returns
    /// it. The failed bucket and every bucket not yet attempted are re-queued
    /// ahead of anything appended meanwhile, unless a rollback intervened.
    /// Buckets already written are not re-queued.
    fn flush(&self) -> Result<(), StorageError> {
        // Read the epoch and take the buffer under one lock so they always
        // describe the same generation (rollback changes both under this lock).
        let (scheduled_epoch, mut buckets) = {
            let mut pending = self.pending.lock();
            (
                self.rollback_epoch.load(Ordering::Acquire),
                std::mem::take(&mut *pending),
            )
        };
        let keys: Vec<String> = buckets.keys().cloned().collect();
        for key in keys {
            if self.rollback_epoch.load(Ordering::Acquire) != scheduled_epoch {
                return Ok(());
            }
            let bucket = match buckets.remove(&key) {
                Some(b) if !b.is_empty() => b,
                _ => continue,
            };
            if let Err((unwritten, err)) = self.write_bucket(&key, bucket) {
                buckets.insert(key, unwritten);
                self.requeue_unwritten(scheduled_epoch, buckets);
                return Err(err);
            }
        }
        Ok(())
    }

    /// Discards all buffered entries and invalidates any in-progress flush.
    fn rollback(&self) -> Result<(), StorageError> {
        // Bump the epoch while holding the pending lock so `flush` and
        // `requeue_unwritten` observe the epoch change and the clear together.
        let mut pending = self.pending.lock();
        self.rollback_epoch.fetch_add(1, Ordering::AcqRel);
        pending.clear();
        Ok(())
    }

    fn list_by_prefix_bytes<'a>(
        &'a self,
        prefix: &str,
    ) -> Box<dyn Iterator<Item = Result<(String, Vec<u8>), StorageError>> + 'a> {
        Box::new(PrefixIter::new(&*self.backend, prefix))
    }
}
impl<B, T, C> AppendLogStorageTier<T> for AppendLogStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    C: Codec<Vec<T>>,
{
    /// Buffers `entries` under their `key_of` buckets and flushes when due.
    ///
    /// A flush runs immediately when no `debounce_ms` is configured, or when
    /// this call crosses a `compact_every` boundary of the running entry count.
    ///
    /// # Errors
    /// Propagates any error from the triggered `flush`.
    fn append_entries(&self, entries: &[T]) -> Result<(), StorageError> {
        if entries.is_empty() {
            return Ok(());
        }
        let trigger_now = {
            let mut pending = self.pending.lock();
            for entry in entries {
                let k = (self.key_of)(entry);
                pending.entry(k).or_default().push(entry.clone());
            }
            let mut count = self.append_count.lock();
            let prev = *count;
            *count = count.saturating_add(entries.len() as u64);
            let new = *count;
            let compact_trigger = matches!(
                self.compact_every,
                Some(n) if n > 0 && (prev / u64::from(n)) != (new / u64::from(n))
            );
            compact_trigger || self.debounce_ms.is_none()
        };
        if trigger_now {
            self.flush()?;
        }
        Ok(())
    }

    fn mode(&self) -> AppendLogMode {
        self.mode
    }

    /// Returns one page of persisted entries, starting at `opts.cursor`.
    ///
    /// Entries come from the keys the backend lists for prefix
    /// `opts.key_filter` (empty prefix when `None`), concatenated in sorted
    /// key order. If the backend cannot list, the filter key itself or the
    /// tier name is used. Only entries already written to the backend are
    /// visible; buffered entries are not included. A returned cursor means
    /// more entries follow.
    ///
    /// # Errors
    /// Propagates backend list/read errors (except `BackendNoListSupport`)
    /// and codec decode errors.
    fn load_entries(&self, opts: LoadEntriesOpts<'_>) -> Result<AppendLoadResult<T>, StorageError> {
        let mut keys = match self.backend.list(opts.key_filter.unwrap_or("")) {
            Ok(ks) => ks,
            Err(StorageError::BackendNoListSupport { .. }) => match opts.key_filter {
                Some(k) => vec![k.to_string()],
                None => vec![self.name.clone()],
            },
            Err(e) => return Err(e),
        };
        // Deterministic key order keeps entry positions (and cursors) stable.
        keys.sort_unstable();
        let start: u64 = opts.cursor.map_or(0, |c| c.position);
        let page_size = opts.page_size.filter(|n| *n > 0);
        // Decode one entry past the page so we can tell whether more remain.
        // Saturating because `start` comes from a caller-supplied cursor and
        // plain `+` would overflow (and panic in debug builds) near u64::MAX.
        let want_decoded_at_least =
            page_size.map(|n| start.saturating_add(u64::from(n)).saturating_add(1));
        let mut decoded: Vec<T> = Vec::new();
        let mut total_seen: u64 = 0;
        for k in keys {
            if let Some(want) = want_decoded_at_least {
                if total_seen >= want {
                    break;
                }
            }
            if let Some(bytes) = self.backend.read(&k)? {
                if !bytes.is_empty() {
                    let entries: Vec<T> = self.codec.decode(&bytes)?;
                    total_seen = total_seen.saturating_add(entries.len() as u64);
                    decoded.extend(entries);
                }
            }
        }
        let start_idx: usize = usize::try_from(start)
            .unwrap_or(usize::MAX)
            .min(decoded.len());
        let mut window: Vec<T> = decoded.split_off(start_idx);
        let next_cursor: Option<AppendCursor> = match page_size {
            Some(n) => {
                let n_usize: usize = (n as usize).min(window.len());
                let has_more = window.len() > n_usize;
                window.truncate(n_usize);
                if has_more {
                    Some(AppendCursor::from_position(start.saturating_add(u64::from(n))))
                } else {
                    None
                }
            }
            None => None,
        };
        Ok(AppendLoadResult {
            entries: window,
            cursor: next_cursor,
        })
    }
}
/// Storage tier mapping string keys to values of type `T`.
///
/// Saves are buffered per key. The buffer is written through immediately when
/// no `debounce_ms` is set or when a `compact_every` boundary is crossed, and
/// otherwise on `flush`.
pub struct KvStorage<
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T> = JsonCodec,
> {
    /// Backend holding one encoded value per key.
    backend: Arc<B>,
    /// Encoder/decoder for `T`.
    codec: C,
    /// Tier name.
    name: String,
    /// `None` means every accepted save is written through immediately.
    debounce_ms: Option<u32>,
    /// Force a write every `n` accepted saves (`Some(0)` is rejected at construction).
    compact_every: Option<u32>,
    /// Optional gate over `(key, value)`; rejected pairs are dropped by `save`.
    filter: Option<KvFilterFn<T>>,
    /// Latest unwritten value per key.
    pending: Mutex<std::collections::HashMap<String, T>>,
    /// Number of accepted saves, used to detect `compact_every` boundaries.
    write_count: Mutex<u64>,
}

/// Construction options for `kv_storage` / `memory_kv`.
pub struct KvStorageOptions<T: Send + Sync + 'static, C: Codec<T> = JsonCodec> {
    /// Tier name; defaults to the backend's name.
    pub name: Option<String>,
    /// Codec used for `T`.
    pub codec: C,
    /// `None` writes every accepted save through; `Some(_)` buffers until `flush`.
    pub debounce_ms: Option<u32>,
    /// Force a write every `n` accepted saves; must not be `Some(0)`.
    pub compact_every: Option<u32>,
    /// Pairs for which this returns `false` are dropped.
    pub filter: Option<KvFilterFn<T>>,
}

impl<T> Default for KvStorageOptions<T>
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// JSON codec, write-through, no compaction, no filter.
    fn default() -> Self {
        Self {
            codec: JsonCodec,
            name: None,
            filter: None,
            debounce_ms: None,
            compact_every: None,
        }
    }
}
pub fn kv_storage<B, T, C>(backend: Arc<B>, opts: KvStorageOptions<T, C>) -> KvStorage<B, T, C>
where
B: StorageBackend + ?Sized,
T: Send + Sync + 'static,
C: Codec<T>,
{
assert!(
opts.compact_every != Some(0),
"kv_storage: compact_every must be None or Some(n) where n >= 1, got Some(0)",
);
let name = opts.name.unwrap_or_else(|| backend.name().to_string());
KvStorage {
backend,
codec: opts.codec,
name,
debounce_ms: opts.debounce_ms,
compact_every: opts.compact_every,
filter: opts.filter,
pending: Mutex::new(std::collections::HashMap::new()),
write_count: Mutex::new(0),
}
}
/// Builds a `KvStorage` over the backend returned by `memory_backend()`.
///
/// # Panics
/// Panics if `opts.compact_every` is `Some(0)` (see `kv_storage`).
pub fn memory_kv<T, C>(opts: KvStorageOptions<T, C>) -> KvStorage<MemoryBackend, T, C>
where
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    let backend = memory_backend();
    kv_storage(backend, opts)
}
impl<B, T, C> BaseStorageTier for KvStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn debounce_ms(&self) -> Option<u32> {
        self.debounce_ms
    }

    fn compact_every(&self) -> Option<u32> {
        self.compact_every
    }

    /// Writes every pending key/value pair to the backend.
    ///
    /// # Errors
    /// Stops at the first encode or write failure and returns it. The failed
    /// pair and all pairs not yet attempted go back into `pending`. A value
    /// saved for the same key while the flush was running takes precedence
    /// over the re-queued one. Pairs already written are not re-queued.
    fn flush(&self) -> Result<(), StorageError> {
        // Take the whole buffer, then do I/O without holding the lock.
        let mut entries = std::mem::take(&mut *self.pending.lock());
        let keys: Vec<String> = entries.keys().cloned().collect();
        for key in keys {
            let Some(value) = entries.remove(&key) else {
                continue;
            };
            let failure: Option<StorageError> = match self.codec.encode(&value) {
                Ok(bytes) => self.backend.write(&key, &bytes).err(),
                Err(e) => Some(e.into()),
            };
            if let Some(err) = failure {
                entries.insert(key, value);
                let mut pending = self.pending.lock();
                for (k, v) in entries {
                    // Keep any newer value saved concurrently for this key.
                    pending.entry(k).or_insert(v);
                }
                return Err(err);
            }
        }
        Ok(())
    }

    /// Discards all pending (unwritten) values.
    fn rollback(&self) -> Result<(), StorageError> {
        self.pending.lock().clear();
        Ok(())
    }

    fn list_by_prefix_bytes<'a>(
        &'a self,
        prefix: &str,
    ) -> Box<dyn Iterator<Item = Result<(String, Vec<u8>), StorageError>> + 'a> {
        Box::new(PrefixIter::new(&*self.backend, prefix))
    }
}
impl<B, T, C> KvStorageTier<T> for KvStorage<B, T, C>
where
    B: StorageBackend + ?Sized,
    T: Send + Sync + 'static,
    C: Codec<T>,
{
    /// Buffers `value` under `key` and flushes when due.
    ///
    /// Pairs rejected by `filter` are dropped (`Ok(())`). A flush runs
    /// immediately when no `debounce_ms` is configured, or when this save
    /// crosses a `compact_every` boundary.
    ///
    /// # Errors
    /// Propagates any error from the triggered `flush`.
    fn save(&self, key: &str, value: T) -> Result<(), StorageError> {
        if let Some(filter) = &self.filter {
            if !filter(key, &value) {
                return Ok(());
            }
        }
        let trigger_now = {
            self.pending.lock().insert(key.to_string(), value);
            let mut count = self.write_count.lock();
            let prev = *count;
            *count = count.saturating_add(1);
            let new = *count;
            let compact_trigger = matches!(
                self.compact_every,
                Some(n) if n > 0 && (prev / u64::from(n)) != (new / u64::from(n))
            );
            compact_trigger || self.debounce_ms.is_none()
        };
        if trigger_now {
            self.flush()?;
        }
        Ok(())
    }

    /// Reads and decodes the value stored under `key`.
    ///
    /// Only the backend is consulted (not pending values). A missing or empty
    /// stored value yields `Ok(None)`.
    fn load(&self, key: &str) -> Result<Option<T>, StorageError> {
        match self.backend.read(key)? {
            Some(bytes) if !bytes.is_empty() => Ok(Some(self.codec.decode(&bytes)?)),
            _ => Ok(None),
        }
    }

    /// Removes `key` from the pending buffer and from the backend.
    ///
    /// # Errors
    /// Propagates the backend delete error. The buffered value, if any, has
    /// already been discarded by then.
    fn delete(&self, key: &str) -> Result<(), StorageError> {
        // Drop the buffered value first. Otherwise a concurrent `flush` could
        // pick it up and write the key back after the backend delete.
        self.pending.lock().remove(key);
        self.backend.delete(key)?;
        Ok(())
    }

    /// Lists backend keys starting with `prefix` (pending values are not included).
    fn list(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
        self.backend.list(prefix)
    }
}