use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use rand::Rng as _;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent};
#[derive(Clone, Debug)]
enum MemoryEvent {
Put { key: String, value: bytes::Bytes },
Delete { key: String },
}
#[derive(Clone)]
pub struct MemoryStore {
inner: Arc<MemoryStoreInner>,
connection_id: u64,
}
impl Default for MemoryStore {
fn default() -> Self {
Self::new()
}
}
struct MemoryStoreInner {
data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
change_sender: UnboundedSender<MemoryEvent>,
change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
}
pub struct MemoryBucketRef {
name: String,
inner: Arc<MemoryStoreInner>,
}
struct MemoryBucket {
data: HashMap<String, (u64, bytes::Bytes)>,
}
impl MemoryBucket {
fn new() -> Self {
MemoryBucket {
data: HashMap::new(),
}
}
}
impl MemoryStore {
pub(super) fn new() -> Self {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
MemoryStore {
inner: Arc::new(MemoryStoreInner {
data: parking_lot::Mutex::new(HashMap::new()),
change_sender: tx,
change_receiver: tokio::sync::Mutex::new(rx),
}),
connection_id: rand::rng().random(),
}
}
}
#[async_trait]
impl Store for MemoryStore {
type Bucket = MemoryBucketRef;
async fn get_or_create_bucket(
&self,
bucket_name: &str,
_ttl: Option<Duration>,
) -> Result<Self::Bucket, StoreError> {
let mut locked_data = self.inner.data.lock();
locked_data
.entry(bucket_name.to_string())
.or_insert_with(MemoryBucket::new);
Ok(MemoryBucketRef {
name: bucket_name.to_string(),
inner: self.inner.clone(),
})
}
async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
let locked_data = self.inner.data.lock();
match locked_data.get(bucket_name) {
Some(_) => Ok(Some(MemoryBucketRef {
name: bucket_name.to_string(),
inner: self.inner.clone(),
})),
None => Ok(None),
}
}
fn connection_id(&self) -> u64 {
self.connection_id
}
fn shutdown(&self) {}
}
#[async_trait]
impl Bucket for MemoryBucketRef {
async fn insert(
&self,
key: &Key,
value: bytes::Bytes,
revision: u64,
) -> Result<StoreOutcome, StoreError> {
let mut locked_data = self.inner.data.lock();
let mut b = locked_data.get_mut(&self.name);
let Some(bucket) = b.as_mut() else {
return Err(StoreError::MissingBucket(self.name.to_string()));
};
let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => {
e.insert((revision, value.clone()));
let _ = self.inner.change_sender.send(MemoryEvent::Put {
key: key.to_string(),
value,
});
StoreOutcome::Created(revision)
}
Entry::Occupied(mut entry) => {
let (rev, _v) = entry.get();
if *rev == revision {
StoreOutcome::Exists(revision)
} else {
entry.insert((revision, value));
StoreOutcome::Created(revision)
}
}
};
Ok(outcome)
}
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock();
let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None);
};
Ok(bucket.data.get(&key.0).map(|(_, v)| v.clone()))
}
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let mut locked_data = self.inner.data.lock();
let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StoreError::MissingBucket(self.name.to_string()));
};
if bucket.data.remove(&key.0).is_some() {
let _ = self.inner.change_sender.send(MemoryEvent::Delete {
key: key.to_string(),
});
}
Ok(())
}
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
let mut existing_items = vec![];
let mut seen_keys = HashSet::new();
let data_lock = self.inner.data.lock();
let Some(bucket) = data_lock.get(&self.name) else {
return Err(StoreError::MissingBucket(self.name.to_string()));
};
for (key, (_rev, v)) in &bucket.data {
seen_keys.insert(key.clone());
let item = KeyValue::new(Key::new(key.clone()), v.clone());
existing_items.push(WatchEvent::Put(item));
}
drop(data_lock);
Ok(Box::pin(async_stream::stream! {
for event in existing_items {
yield event;
}
let mut rcv_lock = self.inner.change_receiver.lock().await;
loop {
match rcv_lock.recv().await {
None => {
break;
},
Some(MemoryEvent::Put { key, value }) => {
if seen_keys.contains(&key) {
continue;
}
let item = KeyValue::new(Key::new(key), value);
yield WatchEvent::Put(item);
},
Some(MemoryEvent::Delete { key }) => {
yield WatchEvent::Delete(Key::new(key));
}
}
}
}))
}
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock();
match locked_data.get(&self.name) {
Some(bucket) => {
let mut out = HashMap::new();
for (k, (_rev, v)) in bucket.data.iter() {
let key = Key::new([self.name.clone(), k.to_string()].join("/"));
let value = v.clone();
out.insert(key, value);
}
Ok(out)
}
None => Err(StoreError::MissingBucket(self.name.clone())),
}
}
}
#[cfg(test)]
mod tests {
use crate::storage::kv::{Bucket as _, Key, MemoryStore, Store as _};
use std::collections::HashSet;
#[tokio::test]
async fn test_entries_full_path() {
let m = MemoryStore::new();
let bucket = m.get_or_create_bucket("bucket1", None).await.unwrap();
let _ = bucket
.insert(&Key::new("key1".to_string()), "value1".into(), 0)
.await
.unwrap();
let _ = bucket
.insert(&Key::new("key2".to_string()), "value2".into(), 0)
.await
.unwrap();
let entries = bucket.entries().await.unwrap();
let keys: HashSet<Key> = entries.into_keys().collect();
assert!(keys.contains(&Key::new("bucket1/key1".to_string())));
assert!(keys.contains(&Key::new("bucket1/key2".to_string())));
}
}