use std::{
any::{Any, TypeId},
sync::{atomic::AtomicBool, Arc},
};
#[cfg(feature = "typed")]
use schemars::{JsonSchema, SchemaGenerator};
use serde::Serialize;
use tokio::sync::{
OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use crate::{
callable::{CallableFetch, CallableParam, SupportsAsync},
store::pointers::OwnedStoreWriteLock,
App, User,
};
use super::pointers::StoreWriteLock;
/// Key identifying a store by the `TypeId` of the data type it holds.
///
/// `#[repr(transparent)]` keeps it layout-identical to the wrapped `TypeId`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StoreId(pub(crate) TypeId);

impl StoreId {
    /// Returns the id used for stores whose data type is `T`.
    pub fn of<T: 'static>() -> Self {
        let type_id = TypeId::of::<T>();
        Self(type_id)
    }
}
/// Type-erased store: the data behind an async `RwLock` plus the metadata and
/// closures needed to work with it without naming its concrete type.
///
/// Every field is either `Copy` or an `Arc`, so cloning shares the same store.
#[derive(Clone)]
pub struct AnyStore {
    /// Key of the store (`StoreId::of` its data type).
    pub(crate) id: StoreId,
    /// Store name (set from `StoreData::name` in `AnyStore::new`).
    pub(crate) name: Arc<str>,
    /// Dirty flag shared by all clones; `AnyStore::new` starts it at `false`.
    /// NOTE(review): presumably set on writes — confirm in `pointers`.
    pub(crate) dirty: Arc<AtomicBool>,
    /// The store's data, erased to `dyn Any`.
    pub(crate) data: Arc<RwLock<dyn Any + Send + Sync>>,
    /// Serializes the erased data for a given user into JSON.
    pub(crate) serializer: Arc<
        dyn Fn(&dyn Any, &User) -> Result<serde_json::Value, StoreSerializeError> + Send + Sync,
    >,
    /// Produces the store's schema description (only with the `typed` feature).
    #[cfg(feature = "typed")]
    pub(crate) desc: Arc<dyn Fn(&mut SchemaGenerator) -> crate::typed::StoreDesc + Send + Sync>,
}
impl std::fmt::Debug for AnyStore {
    /// Formats the store's id, name, dirty flag and data lock.
    ///
    /// The serializer (and, under `typed`, the `desc`) closures are not printable
    /// and are omitted, hence `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnyStore")
            .field("id", &self.id)
            // The name is the only human-readable identifier; `id` is an opaque TypeId.
            .field("name", &self.name)
            .field("dirty", &self.dirty)
            .field("data", &self.data)
            .finish_non_exhaustive()
    }
}
/// Error returned by a store's serializer closure.
#[derive(Debug, thiserror::Error)]
pub enum StoreSerializeError {
    /// Converting the selected view to a `serde_json::Value` failed.
    #[error("failed to serialize store data: {0}")]
    Serialize(#[from] serde_json::Error),
}
/// Typed handle to the store whose data is a `T`.
///
/// Pairs the owning `App` with the type-erased `AnyStore`; `T` exists only at the
/// type level (via `PhantomData`) and is recovered by downcasting on access.
pub struct Store<T: StoreData> {
    // Field order is kept as-is: it determines drop order.
    app: App,
    pub(crate) inner: AnyStore,
    _phantom: std::marker::PhantomData<T>,
}
/// Data that can be held in an app store.
///
/// Implementors supply an initial value via [`StoreData::init`] and a per-user
/// projection via [`StoreData::select`]; `AnyStore::new` serializes that
/// projection to JSON. With the `typed` feature the projection must also
/// implement `JsonSchema`.
pub trait StoreData: Send + Sync + 'static {
    /// View of the data returned by [`StoreData::select`].
    #[cfg(not(feature = "typed"))]
    type Select<'this>: Serialize;
    /// View of the data returned by [`StoreData::select`].
    #[cfg(feature = "typed")]
    type Select<'this>: Serialize + JsonSchema;

    /// Creates the store's initial value.
    fn init() -> Self;

    /// Store name; defaults to the Rust type name of `Self`.
    fn name() -> impl AsRef<str> + Send {
        std::any::type_name::<Self>()
    }

    /// Returns the view of this data to expose to `user`.
    fn select(&self, user: &User) -> Self::Select<'_>;
}
impl AnyStore {
    /// Builds a type-erased store holding a freshly initialised `T`.
    ///
    /// The store is keyed by `StoreId::of::<T>()`, named by `T::name()`, starts
    /// with its dirty flag cleared, and captures a serializer that downcasts the
    /// erased data back to `T` and JSON-encodes `T::select` for the given user.
    ///
    /// # Panics
    ///
    /// The captured serializer panics if it is handed data that is not a `T`.
    pub fn new<T: StoreData>() -> Self {
        Self {
            id: StoreId::of::<T>(),
            name: Arc::from(T::name().as_ref()),
            dirty: Arc::new(AtomicBool::new(false)),
            data: Arc::new(RwLock::new(T::init())),
            serializer: Arc::new(|data, user| {
                // Build the panic message lazily: `expect(&format!(..))` would
                // allocate it on every successful serialization.
                let data = data.downcast_ref::<T>().unwrap_or_else(|| {
                    panic!(
                        "store data is not of expected type {}",
                        std::any::type_name::<T>()
                    )
                });
                serde_json::to_value(T::select(data, user)).map_err(Into::into)
            }),
            #[cfg(feature = "typed")]
            desc: Arc::new(|generator| crate::typed::StoreDesc::new::<T>(generator)),
        }
    }
}
impl<T: StoreData> Store<T> {
    /// Panic message used when the erased store data is not a `T`.
    ///
    /// `AnyStore::new::<T>` stores a `T` under `StoreId::of::<T>()`, so hitting
    /// this indicates a registration bug rather than a recoverable condition.
    const DOWNCAST_FAILED: &str = "failed to downcast store (is the store of the right type?)";

    /// Creates a handle to the store for `T` registered on `app`.
    ///
    /// # Panics
    ///
    /// Panics if no store keyed by `StoreId::of::<T>()` is present in the app's
    /// store map.
    pub async fn new(app: App) -> Self {
        let id = StoreId::of::<T>();
        let inner = app
            .inner
            .state
            .stores
            .read()
            .await
            .get(&id)
            .cloned()
            // Name the missing store; a bare "store not found" gives no clue which one.
            .unwrap_or_else(|| panic!("store not found: {}", std::any::type_name::<T>()));
        Store {
            app,
            inner,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Acquires shared read access to the store's data as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the stored data is not a `T`.
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        RwLockReadGuard::map(self.inner.data.read().await, |inner| {
            inner.downcast_ref::<T>().expect(Self::DOWNCAST_FAILED)
        })
    }

    /// Acquires exclusive write access to the store's data as a `T`.
    ///
    /// The guard is wrapped in a `StoreWriteLock` built from this handle's app
    /// and type-erased store (presumably to flag the store dirty — confirm in
    /// `pointers`).
    ///
    /// # Panics
    ///
    /// Panics if the stored data is not a `T`.
    pub async fn write(&self) -> StoreWriteLock<'_, T> {
        StoreWriteLock::new(
            &self.app,
            &self.inner,
            RwLockWriteGuard::map(self.inner.data.write().await, |inner| {
                inner.downcast_mut::<T>().expect(Self::DOWNCAST_FAILED)
            }),
        )
    }

    /// Like [`Store::read`], but the guard owns a clone of the data `Arc`, so it
    /// is not tied to the lifetime of `&self`.
    ///
    /// # Panics
    ///
    /// Panics if the stored data is not a `T`.
    pub async fn read_owned(&self) -> OwnedRwLockReadGuard<dyn Any + Send + Sync, T> {
        OwnedRwLockReadGuard::map(self.inner.data.clone().read_owned().await, |inner| {
            inner.downcast_ref::<T>().expect(Self::DOWNCAST_FAILED)
        })
    }

    /// Like [`Store::write`], but returns an `OwnedStoreWriteLock` holding clones
    /// of the app and this handle together with an owned write guard.
    ///
    /// # Panics
    ///
    /// Panics if the stored data is not a `T`.
    pub async fn write_owned(&self) -> OwnedStoreWriteLock<T> {
        OwnedStoreWriteLock {
            app: self.app.clone(),
            store: self.clone(),
            guard: OwnedRwLockWriteGuard::map(
                self.inner.data.clone().write_owned().await,
                |inner| inner.downcast_mut::<T>().expect(Self::DOWNCAST_FAILED),
            ),
        }
    }
}
impl<T: StoreData, Ctx: CallableFetch<App> + Send + Sync, Init: Send + Sync>
    CallableParam<Ctx, Init> for Store<T>
{
    type Error = std::convert::Infallible;

    /// Extracts a `Store<T>` handle using the `App` fetched from the callable context.
    ///
    /// # Panics
    ///
    /// Panics if no store for `T` is registered. `Error` is `Infallible`, so a
    /// missing store cannot be reported through the returned `Result`.
    async fn extract(ctx: &mut Ctx, _init: &Init) -> Result<Self, Self::Error> {
        let app = ctx.fetch();
        let id = StoreId::of::<T>();
        let existing_store = app.inner.state.stores.read().await.get(&id).cloned();
        let Some(store) = existing_store else {
            // `TypeId`'s Debug output is an opaque hash, so also print the type name.
            panic!(
                "store not found when trying to extract parameter: {:?} ({})",
                id,
                std::any::type_name::<T>()
            );
        };
        Ok(Store {
            app: app.clone(),
            inner: store,
            _phantom: std::marker::PhantomData,
        })
    }
}
/// Opts `Store<T>` into the `SupportsAsync` marker trait
/// (presumably allowing it as a parameter of async callables — confirm in `callable`).
impl<T> SupportsAsync for Store<T> where T: StoreData {}
impl<T: StoreData> Clone for Store<T> {
fn clone(&self) -> Self {
Self {
app: self.app.clone(),
inner: self.inner.clone(),
_phantom: std::marker::PhantomData,
}
}
}