use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
/// In-process [`KeyValueStore`] whose buckets and entries live only in process memory.
///
/// Cloning is cheap: every clone shares the same state through an `Arc`, so all clones
/// (and every bucket handle obtained from them) see the same buckets and entries.
#[derive(Clone)]
pub struct MemoryStorage {
    inner: Arc<MemoryStorageInner>,
}
impl Default for MemoryStorage {
fn default() -> Self {
Self::new()
}
}
/// State shared by a [`MemoryStorage`] and every [`MemoryBucketRef`] it hands out.
struct MemoryStorageInner {
    /// All buckets, keyed by bucket name.
    data: Mutex<HashMap<String, MemoryBucket>>,
    /// Sending half of the store-wide change channel; `insert` publishes messages here so that
    /// `watch` streams can observe writes.
    ///
    /// NOTE(review): the channel is unbounded and is fed whether or not anything is watching,
    /// so messages accumulate until a `watch` stream drains them.
    change_sender: UnboundedSender<(String, String)>,
    /// Receiving half of the change channel. There is a single receiver for the whole store
    /// (shared by all buckets); a `watch` stream keeps this mutex locked while it runs, so only
    /// one stream can consume changes at a time.
    change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
}
/// Handle to one named bucket inside a [`MemoryStorage`].
///
/// The handle stores only the bucket's name. Every operation looks the bucket up again in the
/// shared map, so each call behaves according to whether the bucket is present at call time.
pub struct MemoryBucketRef {
    // Name of the bucket this handle refers to (key into `MemoryStorageInner::data`).
    name: String,
    // Shared store state; the same `Arc` as the owning `MemoryStorage`.
    inner: Arc<MemoryStorageInner>,
}
/// Contents of a single bucket.
struct MemoryBucket {
    // key -> (revision, value), as written by `insert`.
    data: HashMap<String, (u64, String)>,
}
impl MemoryBucket {
fn new() -> Self {
MemoryBucket {
data: HashMap::new(),
}
}
}
impl MemoryStorage {
    /// Creates an empty store with no buckets and a fresh change channel.
    pub fn new() -> Self {
        let (change_sender, change_receiver) = tokio::sync::mpsc::unbounded_channel();
        let shared = MemoryStorageInner {
            data: Mutex::new(HashMap::new()),
            change_sender,
            change_receiver: Mutex::new(change_receiver),
        };
        Self {
            inner: Arc::new(shared),
        }
    }
}
#[async_trait]
impl KeyValueStore for MemoryStorage {
async fn get_or_create_bucket(
&self,
bucket_name: &str,
_ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError> {
let mut locked_data = self.inner.data.lock().await;
locked_data
.entry(bucket_name.to_string())
.or_insert_with(MemoryBucket::new);
Ok(Box::new(MemoryBucketRef {
name: bucket_name.to_string(),
inner: self.inner.clone(),
}))
}
async fn get_bucket(
&self,
bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
let locked_data = self.inner.data.lock().await;
match locked_data.get(bucket_name) {
Some(_) => Ok(Some(Box::new(MemoryBucketRef {
name: bucket_name.to_string(),
inner: self.inner.clone(),
}))),
None => Ok(None),
}
}
}
#[async_trait]
impl KeyValueBucket for MemoryBucketRef {
    /// Stores `value` under `key` at `revision`.
    ///
    /// Outcomes:
    /// - key absent: the entry is created and `Created(revision)` is returned;
    /// - key present at a *different* revision: the entry is overwritten and
    ///   `Created(revision)` is returned;
    /// - key present at the *same* revision: nothing changes and `Exists(revision)` is returned.
    ///
    /// Every write that modifies the bucket (create or overwrite) is published on the store's
    /// change channel, tagged with this bucket's name, so `watch` streams see updates as well
    /// as creations.
    ///
    /// # Errors
    /// `StorageError::MissingBucket` if the bucket is not present in the store.
    async fn insert(
        &self,
        key: String,
        value: String,
        revision: u64,
    ) -> Result<StorageOutcome, StorageError> {
        let mut locked_data = self.inner.data.lock().await;
        let Some(bucket) = locked_data.get_mut(&self.name) else {
            return Err(StorageError::MissingBucket(self.name.clone()));
        };
        match bucket.data.entry(key) {
            Entry::Occupied(entry) if entry.get().0 == revision => {
                // Same revision already stored: nothing changed, so nothing to publish.
                return Ok(StorageOutcome::Exists(revision));
            }
            Entry::Occupied(mut entry) => {
                entry.insert((revision, value.clone()));
            }
            Entry::Vacant(entry) => {
                entry.insert((revision, value.clone()));
            }
        }
        // The channel is shared by every bucket in the store, so tag the message with the
        // bucket name; `watch` uses it to ignore other buckets' writes. The receiver is owned
        // by the same `MemoryStorageInner` we hold an `Arc` to, so a send error would only
        // mean the receiver was closed — treat notification as best-effort.
        let _ = self.inner.change_sender.send((self.name.clone(), value));
        Ok(StorageOutcome::Created(revision))
    }

    /// Returns the value stored under `key`, or `None` if the key is absent.
    ///
    /// Unlike the other bucket operations, a missing bucket is not an error here: it yields
    /// `Ok(None)`.
    async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
        let locked_data = self.inner.data.lock().await;
        let Some(bucket) = locked_data.get(&self.name) else {
            return Ok(None);
        };
        Ok(bucket
            .data
            .get(key)
            .map(|(_rev, v)| bytes::Bytes::from(v.clone())))
    }

    /// Removes `key` from the bucket; removing an absent key is not an error.
    ///
    /// No change-channel message is sent, so `watch` streams are not told about deletions.
    ///
    /// # Errors
    /// `StorageError::MissingBucket` if the bucket is not present in the store.
    async fn delete(&self, key: &str) -> Result<(), StorageError> {
        let mut locked_data = self.inner.data.lock().await;
        let Some(bucket) = locked_data.get_mut(&self.name) else {
            return Err(StorageError::MissingBucket(self.name.clone()));
        };
        bucket.data.remove(key);
        Ok(())
    }

    /// Streams this bucket's values: first every value currently stored, then values
    /// published by later `insert` calls on this bucket.
    ///
    /// Values already yielded from the initial snapshot are suppressed if they arrive again on
    /// the change channel.
    ///
    /// NOTE(review): the store has a single change receiver, which this stream keeps locked for
    /// its whole lifetime. A second concurrent `watch` (on any bucket) waits until this stream
    /// is dropped, and messages for other buckets consumed here are discarded. Messages queued
    /// before the watch started are replayed after the snapshot, so an already-overwritten value
    /// can be yielded after a newer one.
    ///
    /// If the bucket does not exist, the stream logs an error and ends without yielding.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
    {
        Ok(Box::pin(async_stream::stream! {
            // Copy the current values out and release the store lock *before* the first
            // `yield`: the stream is suspended at every `yield`, and holding the lock across a
            // suspension would block all other store operations (and deadlock a consumer that
            // writes to the store while handling an item).
            let data_lock = self.inner.data.lock().await;
            let Some(bucket) = data_lock.get(&self.name) else {
                tracing::error!(bucket_name = self.name, "watch: Missing bucket");
                return;
            };
            let snapshot: Vec<String> = bucket.data.values().map(|(_rev, v)| v.clone()).collect();
            drop(data_lock);

            let mut seen: HashSet<String> = HashSet::with_capacity(snapshot.len());
            for value in snapshot {
                seen.insert(value.clone());
                yield bytes::Bytes::from(value);
            }

            let mut rcv_lock = self.inner.change_receiver.lock().await;
            while let Some((bucket_name, value)) = rcv_lock.recv().await {
                // The channel carries writes for every bucket in the store; keep only ours.
                if bucket_name != self.name || seen.contains(&value) {
                    continue;
                }
                yield bytes::Bytes::from(value);
            }
        }))
    }

    /// Returns a copy of every `key -> value` pair currently in the bucket.
    ///
    /// # Errors
    /// `StorageError::MissingBucket` if the bucket is not present in the store.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
        let locked_data = self.inner.data.lock().await;
        let bucket = locked_data
            .get(&self.name)
            .ok_or_else(|| StorageError::MissingBucket(self.name.clone()))?;
        Ok(bucket
            .data
            .iter()
            .map(|(k, (_rev, v))| (k.clone(), bytes::Bytes::from(v.clone())))
            .collect())
    }
}