use dashmap::DashMap;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::borrow::Cow;
/// Address of a value in a `Store`: either a global key or a key scoped to one client.
///
/// Both parts are `Cow<'static, str>`, so string literals are held as borrowed data
/// while runtime-built strings are held owned.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum StoreKey {
    /// A key in the shared, global namespace.
    Global(Cow<'static, str>),
    /// A key inside the namespace of a single client.
    Client {
        /// Identifier of the client owning the namespace.
        client: Cow<'static, str>,
        /// Key inside that client's namespace.
        key: Cow<'static, str>,
    },
}

impl StoreKey {
    /// Builds a [`StoreKey::Global`] from anything convertible into a `Cow<'static, str>`.
    pub fn global<K>(key: K) -> Self
    where
        K: Into<Cow<'static, str>>,
    {
        Self::Global(key.into())
    }

    /// Builds a [`StoreKey::Client`] addressing `key` inside the namespace of `client`.
    pub fn client<C, K>(client: C, key: K) -> Self
    where
        C: Into<Cow<'static, str>>,
        K: Into<Cow<'static, str>>,
    {
        let client = client.into();
        let key = key.into();
        Self::Client { client, key }
    }

    /// Returns the namespace label of this key: `"global"` or `"client"`.
    pub fn namespace(&self) -> &'static str {
        match *self {
            Self::Global(..) => "global",
            Self::Client { .. } => "client",
        }
    }
}
/// In-memory store of JSON-serialized values, split into one global namespace and
/// one namespace per client.
///
/// Both maps are `DashMap`s, so every method takes `&self`. `Clone` is derived, which
/// means a clone copies the stored entries instead of sharing them; wrap the store in an
/// `Arc` if several owners need to observe the same data.
#[derive(Debug, Default, Clone)]
pub struct Store {
    /// Values stored under [`StoreKey::Global`].
    global: DashMap<String, Value>,
    /// Values stored under [`StoreKey::Client`], keyed by client and then by key.
    per_client: DashMap<String, DashMap<String, Value>>,
}

impl Store {
    /// Creates an empty store; equivalent to `Store::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes `value` to JSON and stores it under `key`, replacing any previous value.
    ///
    /// The per-client map for a [`StoreKey::Client`] key is created on its first write.
    ///
    /// # Errors
    ///
    /// Returns an error if `value` cannot be converted into a `serde_json::Value`. In that
    /// case the store is left unchanged, because serialization happens before any map is
    /// touched.
    pub fn insert<V: Serialize>(&self, key: StoreKey, value: V) -> anyhow::Result<()> {
        let value = serde_json::to_value(value)
            .map_err(|error| anyhow::anyhow!("Failed to serialize store value: {error}"))?;
        match key {
            StoreKey::Global(key) => {
                self.global.insert(key.into_owned(), value);
            }
            StoreKey::Client { client, key } => {
                let map = self.per_client.entry(client.into_owned()).or_default();
                map.insert(key.into_owned(), value);
            }
        }
        Ok(())
    }

    /// Returns the value stored under `key`, deserialized into `T`.
    ///
    /// Returns `None` both when nothing is stored under `key` and when the stored JSON
    /// cannot be deserialized into `T`. Use [`Store::contains`] to tell the two apart.
    pub fn get<T: DeserializeOwned>(&self, key: StoreKey) -> Option<T> {
        match key {
            // `&Value` implements `Deserializer`, so deserialize straight from the
            // borrowed entry instead of deep-cloning the JSON tree on every read.
            StoreKey::Global(key) => self
                .global
                .get(key.as_ref())
                .and_then(|entry| T::deserialize(entry.value()).ok()),
            StoreKey::Client { client, key } => self
                .per_client
                .get(client.as_ref())
                .and_then(|map| {
                    map.get(key.as_ref())
                        .and_then(|entry| T::deserialize(entry.value()).ok())
                }),
        }
    }

    /// Reports whether any value is stored under `key`, whatever its JSON shape.
    pub fn contains(&self, key: StoreKey) -> bool {
        match key {
            StoreKey::Global(key) => self.global.contains_key(key.as_ref()),
            StoreKey::Client { client, key } => self
                .per_client
                .get(client.as_ref())
                .map(|map| map.contains_key(key.as_ref()))
                .unwrap_or(false),
        }
    }
}