use std::collections::HashMap;
use std::sync::Arc;
use futures_util::{Stream, StreamExt};
use parking_lot::RwLock;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
/// Default upper bound, in bytes, on the size of a single value (1 MiB).
const DEFAULT_MAX_VALUE_SIZE: usize = 1024 * 1024;
/// Default upper bound on the number of keys held by one store.
const DEFAULT_MAX_KEYS: usize = 10000;
/// Capacity handed to the broadcast channel that carries change events.
const WATCH_CHANNEL_CAPACITY: usize = 1024;
/// Errors reported by the key/value store and its backends.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvError {
    /// The requested key is not present.
    NotFound,
    /// The value is larger than the configured per-value limit.
    ValueTooLarge,
    /// Storing another key would exceed the configured key quota.
    QuotaExceeded,
    /// The key is empty, too long, or contains a disallowed character.
    InvalidKey,
    /// Any other storage-level failure, with a human-readable description.
    Storage(String),
}

impl std::fmt::Display for KvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `Storage` carries a payload; every other variant maps to a
        // fixed message that can be written out directly.
        let fixed = match self {
            Self::Storage(detail) => return write!(f, "storage error: {detail}"),
            Self::NotFound => "key not found",
            Self::ValueTooLarge => "value too large",
            Self::QuotaExceeded => "storage quota exceeded",
            Self::InvalidKey => "invalid key format",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for KvError {}
/// A stored value together with its optional expiry deadline.
#[derive(Debug, Clone)]
pub struct KvEntry {
    /// Raw bytes of the value.
    pub value: Vec<u8>,
    /// Instant at or after which the entry counts as expired; `None` means
    /// the entry never expires.
    pub expires_at: Option<std::time::Instant>,
}

impl KvEntry {
    /// Creates an entry that never expires.
    #[must_use]
    pub fn new(value: Vec<u8>) -> Self {
        Self {
            value,
            expires_at: None,
        }
    }

    /// Creates an entry that expires `ttl_ns` nanoseconds from now.
    ///
    /// If the deadline cannot be represented as an `Instant`, the entry is
    /// created without an expiry instead of panicking.
    #[must_use]
    pub fn with_ttl(value: Vec<u8>, ttl_ns: u64) -> Self {
        let ttl = std::time::Duration::from_nanos(ttl_ns);
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`
        // in that case, which this type already interprets as "no expiry".
        let expires_at = std::time::Instant::now().checked_add(ttl);
        Self { value, expires_at }
    }

    /// Returns `true` once the current time has reached the expiry deadline.
    /// Entries without a deadline are never expired.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.expires_at
            .is_some_and(|deadline| std::time::Instant::now() >= deadline)
    }
}
/// The kind of change a [`KvEvent`] describes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvEventKind {
/// A value was written for the key.
Set,
/// The key was removed.
Delete,
}
/// A change notification published on the store's broadcast channel.
#[derive(Debug, Clone)]
pub struct KvEvent {
/// Key the change applies to.
pub key: String,
/// Whether the key was written or removed.
pub kind: KvEventKind,
/// Payload of the change; the store sends `Some(new_bytes)` with `Set`
/// events and `None` with `Delete` events.
pub value: Option<Vec<u8>>,
}
/// Asynchronous storage backend that `KvStore`'s `*_async` methods forward to
/// when one has been attached with `KvStore::with_backend`.
///
/// Each method mirrors the signature of the same-named local `KvStore`
/// operation. The store forwards arguments unchanged, so validation, limits
/// and TTL handling are the implementation's responsibility.
#[async_trait::async_trait]
pub trait KvBackend: Send + Sync + std::fmt::Debug {
/// Fetches the value stored under `key`, if any.
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, KvError>;
/// Stores `value` under `key`.
async fn set(&self, key: &str, value: &[u8]) -> Result<(), KvError>;
/// Stores `value` under `key` with a time-to-live given in nanoseconds.
async fn set_with_ttl(&self, key: &str, value: &[u8], ttl_ns: u64) -> Result<(), KvError>;
/// Removes `key`; the boolean presumably reports whether it was present
/// (that is what the local store returns) — confirm per implementation.
async fn delete(&self, key: &str) -> Result<bool, KvError>;
/// Reports whether `key` is present.
async fn exists(&self, key: &str) -> Result<bool, KvError>;
/// Lists keys that start with `prefix`.
async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, KvError>;
/// Adds `delta` to the counter stored under `key` and returns the new value.
async fn increment(&self, key: &str, delta: i64) -> Result<i64, KvError>;
/// Writes `new` only if the current value equals `expected` (`None` meaning
/// "absent"); returns whether the write happened.
async fn compare_and_swap(
&self,
key: &str,
expected: Option<&[u8]>,
new: &[u8],
) -> Result<bool, KvError>;
}
/// In-memory key/value store with optional per-entry TTLs, change events and
/// an optional asynchronous backend.
///
/// Cloning produces another handle to the same data: `inner` sits behind an
/// `Arc`, so every clone reads and writes the same map. The two limits are
/// plain fields and are copied into each clone.
#[derive(Clone)]
pub struct KvStore {
/// Shared map of entries, including expired ones not yet purged.
inner: Arc<RwLock<HashMap<String, KvEntry>>>,
/// Largest accepted value, in bytes.
max_value_size: usize,
/// Largest number of keys the map may hold.
max_keys: usize,
/// Sending half of the broadcast channel carrying `KvEvent`s.
events: broadcast::Sender<KvEvent>,
/// Backend used by the `*_async` methods when present.
backend: Option<Arc<dyn KvBackend>>,
}
impl std::fmt::Debug for KvStore {
    /// Prints a summary (entry count, limits, whether a backend is attached)
    /// rather than dumping the stored data.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take the snapshot up front so the read lock is released before
        // any formatting happens.
        let len = self.inner.read().len();
        let clustered = self.backend.is_some();

        let mut out = f.debug_struct("KvStore");
        out.field("len", &len);
        out.field("max_value_size", &self.max_value_size);
        out.field("max_keys", &self.max_keys);
        out.field("clustered", &clustered);
        out.finish_non_exhaustive()
    }
}
impl Default for KvStore {
fn default() -> Self {
Self::new()
}
}
impl KvStore {
/// Creates an empty store with the default limits and no backend.
#[must_use]
pub fn new() -> Self {
    // The receiver returned here is not kept; watchers obtain their own
    // receivers later from the stored sender.
    let (sender, _initial_rx) = broadcast::channel(WATCH_CHANNEL_CAPACITY);
    let entries = HashMap::new();
    Self {
        inner: Arc::new(RwLock::new(entries)),
        events: sender,
        backend: None,
        max_keys: DEFAULT_MAX_KEYS,
        max_value_size: DEFAULT_MAX_VALUE_SIZE,
    }
}
#[must_use]
pub fn with_backend(mut self, backend: Arc<dyn KvBackend>) -> Self {
self.backend = Some(backend);
self
}
/// Returns `true` when a backend has been attached to this handle.
#[must_use]
pub fn is_clustered(&self) -> bool {
self.backend.is_some()
}
#[must_use]
pub fn with_max_value_size(mut self, size: usize) -> Self {
self.max_value_size = size;
self
}
#[must_use]
pub fn with_max_keys(mut self, count: usize) -> Self {
self.max_keys = count;
self
}
/// Sets the per-value size limit, in bytes, on this handle.
///
/// The limit is a plain field of the handle, so clones made earlier keep
/// their own value.
pub fn set_max_value_size(&mut self, size: usize) {
self.max_value_size = size;
}
/// Sets the key quota on this handle.
///
/// The quota is a plain field of the handle, so clones made earlier keep
/// their own value.
pub fn set_max_keys(&mut self, count: usize) {
self.max_keys = count;
}
/// Returns the largest value size, in bytes, this handle accepts.
#[must_use]
pub fn max_value_size(&self) -> usize {
self.max_value_size
}
/// Returns the key quota enforced by this handle.
#[must_use]
pub fn max_keys(&self) -> usize {
self.max_keys
}
/// Checks that `key` is acceptable to the store.
///
/// A key must be 1 to 1024 bytes long and consist only of alphanumeric
/// characters (as defined by `char::is_alphanumeric`) or one of `-_./:`.
///
/// # Errors
/// Returns [`KvError::InvalidKey`] when any rule is violated.
pub fn validate_key(key: &str) -> Result<(), KvError> {
    // Length is checked first so over-long keys are rejected without
    // scanning their characters.
    if !(1..=1024).contains(&key.len()) {
        return Err(KvError::InvalidKey);
    }
    let allowed = |ch: char| ch.is_alphanumeric() || matches!(ch, '-' | '_' | '.' | '/' | ':');
    if key.chars().all(allowed) {
        Ok(())
    } else {
        Err(KvError::InvalidKey)
    }
}
/// Drops every expired entry from the map.
///
/// Takes the write lock for the duration of the sweep. No events are
/// emitted for the removed entries.
pub fn clean_expired(&self) {
    self.inner
        .write()
        .retain(|_key, entry| !entry.is_expired());
}
/// Publishes a change event for `key` on the broadcast channel.
///
/// Delivery is best-effort: the result of `send` is deliberately discarded,
/// so a missing audience never fails the operation that triggered the event.
fn emit(&self, key: &str, kind: KvEventKind, value: Option<Vec<u8>>) {
    let event = KvEvent {
        key: key.to_owned(),
        kind,
        value,
    };
    let _ = self.events.send(event);
}
/// Returns a copy of the value stored under `key`, or `None` when the key
/// is absent or expired.
///
/// Every call first purges all expired entries, which takes the write lock
/// and scans the whole map.
///
/// # Errors
/// Returns [`KvError::InvalidKey`] if `key` fails validation.
pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>, KvError> {
    Self::validate_key(key)?;
    self.clean_expired();
    let entries = self.inner.read();
    // An entry may have expired between the sweep and this lookup, so the
    // expiry check is repeated here.
    let live = entries.get(key).filter(|entry| !entry.is_expired());
    Ok(live.map(|entry| entry.value.clone()))
}
/// Like [`get`](Self::get), but decodes the value as UTF-8 text.
///
/// # Errors
/// Propagates errors from `get`, and returns [`KvError::Storage`] when the
/// stored bytes are not valid UTF-8.
pub fn get_string(&self, key: &str) -> Result<Option<String>, KvError> {
    let Some(raw) = self.get(key)? else {
        return Ok(None);
    };
    match String::from_utf8(raw) {
        Ok(text) => Ok(Some(text)),
        Err(e) => Err(KvError::Storage(format!("invalid UTF-8: {e}"))),
    }
}
/// Stores `value` under `key` without an expiry and emits a `Set` event.
///
/// Overwriting an existing key never counts against the quota. When a new
/// key would exceed the quota, expired entries are purged first so that
/// dead entries do not block writes.
///
/// # Errors
/// * [`KvError::InvalidKey`] if `key` fails validation.
/// * [`KvError::ValueTooLarge`] if `value` exceeds the size limit.
/// * [`KvError::QuotaExceeded`] if the store is full of live entries.
pub fn set(&self, key: &str, value: &[u8]) -> Result<(), KvError> {
    Self::validate_key(key)?;
    if value.len() > self.max_value_size {
        return Err(KvError::ValueTooLarge);
    }
    {
        let mut kv = self.inner.write();
        if !kv.contains_key(key) && kv.len() >= self.max_keys {
            // Expired entries are invisible to readers but still occupy
            // slots; reclaim them before refusing the write.
            kv.retain(|_, entry| !entry.is_expired());
            if kv.len() >= self.max_keys {
                return Err(KvError::QuotaExceeded);
            }
        }
        kv.insert(key.to_string(), KvEntry::new(value.to_vec()));
    }
    // The lock is released before notifying watchers.
    self.emit(key, KvEventKind::Set, Some(value.to_vec()));
    Ok(())
}
/// Stores the UTF-8 bytes of `value` under `key`; see [`set`](Self::set)
/// for the rules and errors.
pub fn set_string(&self, key: &str, value: &str) -> Result<(), KvError> {
    let bytes = value.as_bytes();
    self.set(key, bytes)
}
/// Stores `value` under `key` with a time-to-live of `ttl_ns` nanoseconds
/// and emits a `Set` event.
///
/// Overwriting an existing key never counts against the quota. When a new
/// key would exceed the quota, expired entries are purged first so that
/// dead entries do not block writes.
///
/// # Errors
/// * [`KvError::InvalidKey`] if `key` fails validation.
/// * [`KvError::ValueTooLarge`] if `value` exceeds the size limit.
/// * [`KvError::QuotaExceeded`] if the store is full of live entries.
pub fn set_with_ttl(&self, key: &str, value: &[u8], ttl_ns: u64) -> Result<(), KvError> {
    Self::validate_key(key)?;
    if value.len() > self.max_value_size {
        return Err(KvError::ValueTooLarge);
    }
    {
        let mut kv = self.inner.write();
        if !kv.contains_key(key) && kv.len() >= self.max_keys {
            // Expired entries are invisible to readers but still occupy
            // slots; reclaim them before refusing the write.
            kv.retain(|_, entry| !entry.is_expired());
            if kv.len() >= self.max_keys {
                return Err(KvError::QuotaExceeded);
            }
        }
        kv.insert(key.to_string(), KvEntry::with_ttl(value.to_vec(), ttl_ns));
    }
    // The lock is released before notifying watchers.
    self.emit(key, KvEventKind::Set, Some(value.to_vec()));
    Ok(())
}
/// Removes `key` and reports whether a live entry was removed.
///
/// An entry that has already expired is still dropped from the map, but it
/// is treated as absent: the result is `false` and no `Delete` event is
/// emitted, matching what `get` and `exists` report for that key.
///
/// # Errors
/// Returns [`KvError::InvalidKey`] if `key` fails validation.
pub fn delete(&self, key: &str) -> Result<bool, KvError> {
    Self::validate_key(key)?;
    let removed = {
        let mut kv = self.inner.write();
        kv.remove(key).is_some_and(|entry| !entry.is_expired())
    };
    if removed {
        self.emit(key, KvEventKind::Delete, None);
    }
    Ok(removed)
}
/// Returns `true` when `key` holds a live (non-expired) entry.
///
/// The key is not validated. Like `get`, this purges all expired entries
/// before looking.
#[must_use]
pub fn exists(&self, key: &str) -> bool {
    self.clean_expired();
    match self.inner.read().get(key) {
        Some(entry) => !entry.is_expired(),
        None => false,
    }
}
/// Returns the live keys that start with `prefix`, in the map's iteration
/// order (callers that need a stable order must sort).
///
/// The current implementation never returns `Err`.
pub fn list_keys(&self, prefix: &str) -> Result<Vec<String>, KvError> {
    self.clean_expired();
    let entries = self.inner.read();
    let mut matching = Vec::new();
    for (name, entry) in entries.iter() {
        if name.starts_with(prefix) && !entry.is_expired() {
            matching.push(name.clone());
        }
    }
    Ok(matching)
}
/// Adds `delta` to the decimal integer stored under `key` and returns the
/// new value, emitting a `Set` event with its textual form.
///
/// A missing or expired key counts as `0`. The addition saturates at the
/// `i64` bounds. The result is stored without an expiry. When a new key
/// would exceed the quota, expired entries are purged first so that dead
/// entries do not block the write.
///
/// # Errors
/// * [`KvError::InvalidKey`] if `key` fails validation.
/// * [`KvError::Storage`] if the current value is not a UTF-8 decimal `i64`.
/// * [`KvError::QuotaExceeded`] if the store is full of live entries.
pub fn increment(&self, key: &str, delta: i64) -> Result<i64, KvError> {
    Self::validate_key(key)?;
    let (new_value, bytes) = {
        let mut kv = self.inner.write();
        let current: i64 = match kv.get(key) {
            // Parse straight from the stored bytes; no copy is needed.
            Some(entry) if !entry.is_expired() => std::str::from_utf8(&entry.value)
                .map_err(|e| KvError::Storage(format!("invalid number: {e}")))?
                .parse::<i64>()
                .map_err(|e| KvError::Storage(format!("invalid number: {e}")))?,
            _ => 0,
        };
        let new_value = current.saturating_add(delta);
        if !kv.contains_key(key) && kv.len() >= self.max_keys {
            // Expired entries are invisible to readers but still occupy
            // slots; reclaim them before refusing the write.
            kv.retain(|_, entry| !entry.is_expired());
            if kv.len() >= self.max_keys {
                return Err(KvError::QuotaExceeded);
            }
        }
        let bytes = new_value.to_string().into_bytes();
        kv.insert(key.to_string(), KvEntry::new(bytes.clone()));
        (new_value, bytes)
    };
    // The lock is released before notifying watchers.
    self.emit(key, KvEventKind::Set, Some(bytes));
    Ok(new_value)
}
/// Atomically replaces the value of `key` with `new_value` if its current
/// value equals `expected`.
///
/// `expected == None` means "the key must be absent or expired". Returns
/// `Ok(true)` and emits a `Set` event when the swap happened, `Ok(false)`
/// otherwise. The new entry is stored without an expiry.
///
/// The quota applies only when the swap adds a key to the map. Replacing an
/// expired entry that is still present does not grow the map and is allowed,
/// and expired entries are purged before a new key is refused.
///
/// # Errors
/// * [`KvError::InvalidKey`] if `key` fails validation.
/// * [`KvError::ValueTooLarge`] if `new_value` exceeds the size limit.
/// * [`KvError::QuotaExceeded`] if the store is full of live entries.
pub fn compare_and_swap(
    &self,
    key: &str,
    expected: Option<&[u8]>,
    new_value: &[u8],
) -> Result<bool, KvError> {
    Self::validate_key(key)?;
    if new_value.len() > self.max_value_size {
        return Err(KvError::ValueTooLarge);
    }
    let swapped = {
        let mut kv = self.inner.write();
        // Compare inside a block so the shared borrow of the map ends
        // before the map is mutated below.
        let matches = {
            let current = kv
                .get(key)
                .filter(|entry| !entry.is_expired())
                .map(|entry| entry.value.as_slice());
            current == expected
        };
        if matches {
            // Use physical presence, not logical presence: overwriting an
            // expired-but-unpurged entry does not add a key.
            if !kv.contains_key(key) && kv.len() >= self.max_keys {
                kv.retain(|_, entry| !entry.is_expired());
                if kv.len() >= self.max_keys {
                    return Err(KvError::QuotaExceeded);
                }
            }
            kv.insert(key.to_string(), KvEntry::new(new_value.to_vec()));
        }
        matches
    };
    if swapped {
        self.emit(key, KvEventKind::Set, Some(new_value.to_vec()));
    }
    Ok(swapped)
}
/// Reads `key` from the backend when one is attached, otherwise from the
/// local map via [`get`](Self::get).
///
/// With a backend the call is forwarded unchanged; no local validation is
/// performed by this method.
pub async fn get_async(&self, key: &str) -> Result<Option<Vec<u8>>, KvError> {
    if let Some(backend) = &self.backend {
        return backend.get(key).await;
    }
    self.get(key)
}
/// Writes `value` to the backend when one is attached, otherwise to the
/// local map via [`set`](Self::set).
///
/// With a backend the call is forwarded unchanged; this method applies no
/// local validation or limits and emits no local event.
pub async fn set_async(&self, key: &str, value: &[u8]) -> Result<(), KvError> {
    if let Some(backend) = &self.backend {
        return backend.set(key, value).await;
    }
    self.set(key, value)
}
/// Writes `value` with a TTL of `ttl_ns` nanoseconds to the backend when
/// one is attached, otherwise to the local map via
/// [`set_with_ttl`](Self::set_with_ttl).
///
/// With a backend the call is forwarded unchanged; this method applies no
/// local validation or limits and emits no local event.
pub async fn set_with_ttl_async(
    &self,
    key: &str,
    value: &[u8],
    ttl_ns: u64,
) -> Result<(), KvError> {
    if let Some(backend) = &self.backend {
        return backend.set_with_ttl(key, value, ttl_ns).await;
    }
    self.set_with_ttl(key, value, ttl_ns)
}
/// Deletes `key` in the backend when one is attached, otherwise in the
/// local map via [`delete`](Self::delete).
///
/// With a backend the call is forwarded unchanged and no local event is
/// emitted by this method.
pub async fn delete_async(&self, key: &str) -> Result<bool, KvError> {
    if let Some(backend) = &self.backend {
        return backend.delete(key).await;
    }
    self.delete(key)
}
/// Asks the backend whether `key` exists when one is attached, otherwise
/// answers from the local map via [`exists`](Self::exists).
///
/// The local path cannot fail, so it always yields `Ok`.
pub async fn exists_async(&self, key: &str) -> Result<bool, KvError> {
    if let Some(backend) = &self.backend {
        return backend.exists(key).await;
    }
    let present = self.exists(key);
    Ok(present)
}
/// Lists keys starting with `prefix` from the backend when one is attached,
/// otherwise from the local map via [`list_keys`](Self::list_keys).
pub async fn list_keys_async(&self, prefix: &str) -> Result<Vec<String>, KvError> {
    if let Some(backend) = &self.backend {
        return backend.list_keys(prefix).await;
    }
    self.list_keys(prefix)
}
/// Increments the counter under `key` in the backend when one is attached,
/// otherwise in the local map via [`increment`](Self::increment).
///
/// With a backend the call is forwarded unchanged and no local event is
/// emitted by this method.
pub async fn increment_async(&self, key: &str, delta: i64) -> Result<i64, KvError> {
    if let Some(backend) = &self.backend {
        return backend.increment(key, delta).await;
    }
    self.increment(key, delta)
}
/// Performs a compare-and-swap in the backend when one is attached,
/// otherwise in the local map via
/// [`compare_and_swap`](Self::compare_and_swap).
///
/// With a backend the call is forwarded unchanged; this method applies no
/// local validation or limits and emits no local event.
pub async fn compare_and_swap_async(
    &self,
    key: &str,
    expected: Option<&[u8]>,
    new: &[u8],
) -> Result<bool, KvError> {
    if let Some(backend) = &self.backend {
        return backend.compare_and_swap(key, expected, new).await;
    }
    self.compare_and_swap(key, expected, new)
}
/// Removes every entry from the local map.
///
/// No events are emitted for the removed keys, and an attached backend is
/// not touched.
pub fn clear(&self) {
    let mut entries = self.inner.write();
    entries.clear();
}
/// Returns a new receiver on the store's broadcast channel of `KvEvent`s.
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<KvEvent> {
self.events.subscribe()
}
/// Returns a stream of the events whose key starts with `prefix`.
///
/// The stream is backed by a fresh subscription created at call time.
/// Items the underlying stream reports as errors are skipped silently
/// rather than surfaced to the consumer.
pub fn watch_prefix(&self, prefix: impl Into<String>) -> impl Stream<Item = KvEvent> {
    let prefix = prefix.into();
    BroadcastStream::new(self.events.subscribe()).filter_map(move |res| {
        // The decision is made synchronously, so an already-completed
        // future suffices and the prefix need not be cloned per event.
        let matched = match res {
            Ok(event) if event.key.starts_with(&prefix) => Some(event),
            _ => None,
        };
        std::future::ready(matched)
    })
}
}
/// Process-wide store slot, filled at most once by `set_global_kv`.
static GLOBAL_KV: std::sync::OnceLock<KvStore> = std::sync::OnceLock::new();
/// Installs `store` as the process-wide store.
///
/// Only the first call takes effect; later calls log a warning and drop the
/// store they were given.
pub fn set_global_kv(store: KvStore) {
    match GLOBAL_KV.set(store) {
        Ok(()) => {}
        Err(_) => {
            tracing::warn!("global KvStore already set; ignoring duplicate set_global_kv call");
        }
    }
}
/// Returns a handle to the process-wide store, or `None` if
/// `set_global_kv` has not been called yet.
///
/// The handle is a clone, so it shares its data with every other handle.
#[must_use]
pub fn global_kv() -> Option<KvStore> {
    let store = GLOBAL_KV.get()?;
    Some(store.clone())
}
// Unit tests for the local store, the watch channel, backend routing of the
// `*_async` methods, and the process-wide accessor.
#[cfg(test)]
mod tests {
use super::*;
use futures_util::StreamExt;
// A value written with `set` is readable both as bytes and as a string.
#[test]
fn set_and_get() {
let store = KvStore::new();
store.set("foo", b"bar").unwrap();
assert_eq!(store.get("foo").unwrap(), Some(b"bar".to_vec()));
assert_eq!(store.get_string("foo").unwrap(), Some("bar".to_string()));
}
// Reading an unknown key is not an error; it yields `None`.
#[test]
fn get_missing_returns_none() {
let store = KvStore::new();
assert_eq!(store.get("missing").unwrap(), None);
}
// An entry with a 1 ms TTL is gone after sleeping 5 ms.
#[test]
fn ttl_expiry() {
let store = KvStore::new();
store.set_with_ttl("temp", b"v", 1_000_000).unwrap();
std::thread::sleep(std::time::Duration::from_millis(5));
assert_eq!(store.get("temp").unwrap(), None);
assert!(!store.exists("temp"));
}
// `delete` returns true only the first time a key is removed.
#[test]
fn delete_reports_existence() {
let store = KvStore::new();
store.set("k", b"v").unwrap();
assert!(store.delete("k").unwrap());
assert!(!store.delete("k").unwrap());
assert_eq!(store.get("k").unwrap(), None);
}
// `list_keys` filters by prefix; results are sorted here because the store
// does not guarantee an order.
#[test]
fn list_keys_prefix() {
let store = KvStore::new();
store.set("a/1", b"1").unwrap();
store.set("a/2", b"2").unwrap();
store.set("b/1", b"3").unwrap();
let mut keys = store.list_keys("a/").unwrap();
keys.sort();
assert_eq!(keys, vec!["a/1".to_string(), "a/2".to_string()]);
}
// A missing counter starts at 0; negative deltas are accepted.
#[test]
fn increment() {
let store = KvStore::new();
assert_eq!(store.increment("counter", 5).unwrap(), 5);
assert_eq!(store.increment("counter", 3).unwrap(), 8);
assert_eq!(store.increment("counter", -10).unwrap(), -2);
}
// Incrementing past `i64::MAX` saturates instead of overflowing.
#[test]
fn increment_saturates() {
let store = KvStore::new();
store.set("c", i64::MAX.to_string().as_bytes()).unwrap();
assert_eq!(store.increment("c", 1).unwrap(), i64::MAX);
}
// CAS succeeds when the expectation matches (including "absent") and fails
// without modifying the value otherwise.
#[test]
fn compare_and_swap_hit_and_miss() {
let store = KvStore::new();
assert!(store.compare_and_swap("k", None, b"v1").unwrap());
assert!(store.compare_and_swap("k", Some(b"v1"), b"v2").unwrap());
assert!(!store.compare_and_swap("k", Some(b"v1"), b"v3").unwrap());
assert_eq!(store.get("k").unwrap(), Some(b"v2".to_vec()));
}
// A new key beyond the quota is rejected, but overwriting an existing key
// is still allowed.
#[test]
fn quota_exceeded() {
let store = KvStore::new().with_max_keys(2);
store.set("a", b"1").unwrap();
store.set("b", b"2").unwrap();
assert_eq!(store.set("c", b"3"), Err(KvError::QuotaExceeded));
assert!(store.set("a", b"x").is_ok());
}
// Values longer than the configured limit are rejected.
#[test]
fn value_too_large() {
let store = KvStore::new().with_max_value_size(4);
assert_eq!(store.set("k", b"toolong"), Err(KvError::ValueTooLarge));
}
// Empty keys and keys containing a space are invalid.
#[test]
fn invalid_key() {
let store = KvStore::new();
assert_eq!(store.set("", b"v"), Err(KvError::InvalidKey));
assert_eq!(store.set("bad key", b"v"), Err(KvError::InvalidKey));
}
// A clone is another handle onto the same data.
#[test]
fn clone_shares_state() {
let a = KvStore::new();
let b = a.clone();
a.set("k", b"v").unwrap();
assert_eq!(b.get("k").unwrap(), Some(b"v".to_vec()));
}
// A subscriber created before a write receives the matching `Set` event.
#[tokio::test]
async fn watch_receives_set_event() {
let store = KvStore::new();
let mut rx = store.subscribe();
store.set("watched", b"hello").unwrap();
let event = rx.recv().await.unwrap();
assert_eq!(event.key, "watched");
assert_eq!(event.kind, KvEventKind::Set);
assert_eq!(event.value, Some(b"hello".to_vec()));
}
// `watch_prefix` skips events for keys outside the prefix.
#[tokio::test]
async fn watch_prefix_filters() {
let store = KvStore::new();
let mut stream = Box::pin(store.watch_prefix("user/"));
store.set("other/1", b"x").unwrap();
store.set("user/1", b"y").unwrap();
let event = stream.next().await.unwrap();
assert_eq!(event.key, "user/1");
assert_eq!(event.value, Some(b"y".to_vec()));
}
// Deleting a key produces a `Delete` event without a payload. The
// subscription is created after the initial `set`, so that event is not seen.
#[tokio::test]
async fn watch_receives_delete_event() {
let store = KvStore::new();
store.set("k", b"v").unwrap();
let mut rx = store.subscribe();
store.delete("k").unwrap();
let event = rx.recv().await.unwrap();
assert_eq!(event.kind, KvEventKind::Delete);
assert_eq!(event.key, "k");
assert_eq!(event.value, None);
}
// In-memory backend double that also records the name of every operation
// invoked on it.
#[derive(Debug, Default)]
struct MockBackend {
map: std::sync::Mutex<HashMap<String, Vec<u8>>>,
calls: std::sync::Mutex<Vec<String>>,
}
impl MockBackend {
// Appends `op` to the call log.
fn record(&self, op: &str) {
self.calls.lock().unwrap().push(op.to_string());
}
// Returns a snapshot of the call log.
fn calls(&self) -> Vec<String> {
self.calls.lock().unwrap().clone()
}
}
// Minimal backend behaviour: no validation, no quotas, and the TTL argument
// is ignored.
#[async_trait::async_trait]
impl KvBackend for MockBackend {
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, KvError> {
self.record("get");
Ok(self.map.lock().unwrap().get(key).cloned())
}
async fn set(&self, key: &str, value: &[u8]) -> Result<(), KvError> {
self.record("set");
self.map
.lock()
.unwrap()
.insert(key.to_string(), value.to_vec());
Ok(())
}
async fn set_with_ttl(&self, key: &str, value: &[u8], _ttl_ns: u64) -> Result<(), KvError> {
self.record("set_with_ttl");
self.map
.lock()
.unwrap()
.insert(key.to_string(), value.to_vec());
Ok(())
}
async fn delete(&self, key: &str) -> Result<bool, KvError> {
self.record("delete");
Ok(self.map.lock().unwrap().remove(key).is_some())
}
async fn exists(&self, key: &str) -> Result<bool, KvError> {
self.record("exists");
Ok(self.map.lock().unwrap().contains_key(key))
}
async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, KvError> {
self.record("list_keys");
Ok(self
.map
.lock()
.unwrap()
.keys()
.filter(|k| k.starts_with(prefix))
.cloned()
.collect())
}
async fn increment(&self, key: &str, delta: i64) -> Result<i64, KvError> {
self.record("increment");
let mut map = self.map.lock().unwrap();
// Unparseable or missing values count as 0 in this mock.
let current: i64 = map
.get(key)
.map_or(0, |v| String::from_utf8_lossy(v).parse().unwrap_or(0));
let new = current + delta;
map.insert(key.to_string(), new.to_string().into_bytes());
Ok(new)
}
async fn compare_and_swap(
&self,
key: &str,
expected: Option<&[u8]>,
new: &[u8],
) -> Result<bool, KvError> {
self.record("compare_and_swap");
let mut map = self.map.lock().unwrap();
let current = map.get(key).map(Vec::as_slice);
if current == expected {
map.insert(key.to_string(), new.to_vec());
Ok(true)
} else {
Ok(false)
}
}
}
// `is_clustered` is true exactly when a backend has been attached.
#[test]
fn is_clustered_reflects_backend() {
let local = KvStore::new();
assert!(!local.is_clustered());
let clustered = KvStore::new().with_backend(Arc::new(MockBackend::default()));
assert!(clustered.is_clustered());
}
// With a backend attached, every `*_async` call reaches the backend and the
// local map stays empty.
#[tokio::test]
async fn async_routes_to_backend_when_clustered() {
let backend = Arc::new(MockBackend::default());
let store = KvStore::new().with_backend(backend.clone());
store.set_async("foo", b"bar").await.unwrap();
assert_eq!(store.get_async("foo").await.unwrap(), Some(b"bar".to_vec()));
assert!(store.exists_async("foo").await.unwrap());
store.set_with_ttl_async("ttlk", b"v", 1_000).await.unwrap();
assert_eq!(store.increment_async("counter", 5).await.unwrap(), 5);
assert_eq!(store.increment_async("counter", 3).await.unwrap(), 8);
assert!(store
.compare_and_swap_async("cas", None, b"v1")
.await
.unwrap());
let mut keys = store.list_keys_async("").await.unwrap();
keys.sort();
assert_eq!(
keys,
vec![
"cas".to_string(),
"counter".to_string(),
"foo".to_string(),
"ttlk".to_string(),
]
);
assert!(store.delete_async("foo").await.unwrap());
assert!(!store.exists_async("foo").await.unwrap());
// The synchronous API reads the local map, which the backend path never wrote.
assert_eq!(store.get("foo").unwrap(), None);
assert!(!store.exists("counter"));
assert_eq!(store.list_keys("").unwrap(), Vec::<String>::new());
let calls = backend.calls();
for op in [
"set",
"get",
"exists",
"set_with_ttl",
"increment",
"compare_and_swap",
"list_keys",
"delete",
] {
assert!(
calls.contains(&op.to_string()),
"missing backend call: {op}"
);
}
}
// Without a backend, the `*_async` methods behave like their synchronous
// counterparts on the local map.
#[tokio::test]
async fn async_uses_local_when_not_clustered() {
let store = KvStore::new();
store.set_async("foo", b"bar").await.unwrap();
assert_eq!(store.get("foo").unwrap(), Some(b"bar".to_vec()));
assert_eq!(
store.get_async("foo").await.unwrap(),
store.get("foo").unwrap()
);
assert_eq!(
store.exists_async("foo").await.unwrap(),
store.exists("foo")
);
store
.set_with_ttl_async("ttlk", b"v", 1_000_000_000)
.await
.unwrap();
assert!(store.exists("ttlk"));
assert_eq!(store.increment_async("c", 4).await.unwrap(), 4);
assert_eq!(store.increment("c", 0).unwrap(), 4);
assert!(store
.compare_and_swap_async("cas", None, b"v1")
.await
.unwrap());
assert_eq!(store.get("cas").unwrap(), Some(b"v1".to_vec()));
let mut a = store.list_keys_async("").await.unwrap();
let mut b = store.list_keys("").unwrap();
a.sort();
b.sort();
assert_eq!(a, b);
assert!(store.delete_async("foo").await.unwrap());
assert_eq!(store.get("foo").unwrap(), None);
}
// Two handles obtained from `global_kv` observe each other's writes.
#[test]
fn global_kv_accessor_shares_state() {
set_global_kv(KvStore::new());
let a = global_kv().expect("global KvStore should be set after set_global_kv");
let b = global_kv().expect("global KvStore should still be set");
a.set("global-kv-share-test", b"shared").unwrap();
assert_eq!(
b.get("global-kv-share-test").unwrap(),
Some(b"shared".to_vec()),
"writes through one global_kv() clone must be visible through another"
);
}
}