use crate::cacheable::Cacheable;
use crate::error::FetchError;
use dashmap::DashMap;
use futures::FutureExt;
use futures::future::Shared;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::{Arc, Weak};
/// Result returned by `InFlightRegistry::get_or_fetch` in this file: `Ok(None)` when the fetcher
/// reported no value, otherwise the `Arc` that the `on_fetched` callback resolved to.
pub(crate) type FetchOutput<T> = Result<Option<Arc<T>>, FetchError>;
/// Same shape as `FetchOutput`, but with the cloneable `FetchErrorClone` as the error type so the
/// value can be the output of a `futures::future::Shared` future and copied to every waiter.
type SharedOutput<T> = Result<Option<Arc<T>>, FetchErrorClone>;
/// Cloneable snapshot of a `FetchError`.
///
/// The shared fetch future hands a copy of its output to every waiter, so its error type must be
/// `Clone`; this enum stands in for `FetchError` there. Variants that hold a boxed error inside
/// `FetchError` are kept only as their rendered `Display` text. Converting back produces a new
/// error that carries that text, but not the original error type or its source chain.
#[derive(Debug, Clone)]
pub(crate) enum FetchErrorClone {
/// Mirrors `BackendError::NotFound`.
BackendNotFound,
/// Mirrors `BackendError::Serialization`. With the `serde` feature, `BackendError::WireFormat`
/// is also folded into this variant as rendered text.
BackendSerialization(String),
/// Mirrors `BackendError::Network`.
BackendNetwork(String),
/// Rendered `Display` text of a `BackendError::Other`.
BackendOtherRendered(String),
/// Mirrors `FetchError::Serialization`.
Serialization(String),
/// Mirrors `FetchError::FetcherPanic`.
FetcherPanic {
type_name: &'static str,
message: String,
},
/// Mirrors `FetchError::IdentityMismatch`.
IdentityMismatch {
type_name: &'static str,
},
/// Rendered `Display` text of a `FetchError::Custom`.
CustomRendered(String),
/// Rendered `Display` text of a `FetchError::Insert`. It is converted back as a
/// `FetchError::Custom` whose message is prefixed with "insert during fetch: ".
InsertRendered(String),
}
impl From<FetchErrorClone> for FetchError {
fn from(value: FetchErrorClone) -> Self {
use crate::error::BackendError;
match value {
FetchErrorClone::BackendNotFound => FetchError::Backend(BackendError::NotFound),
FetchErrorClone::BackendSerialization(s) => {
FetchError::Backend(BackendError::Serialization(s))
}
FetchErrorClone::BackendNetwork(s) => FetchError::Backend(BackendError::Network(s)),
FetchErrorClone::BackendOtherRendered(s) => {
FetchError::Backend(BackendError::Other(Box::new(RenderedError(s))))
}
FetchErrorClone::Serialization(s) => FetchError::Serialization(s),
FetchErrorClone::FetcherPanic { type_name, message } => {
FetchError::FetcherPanic { type_name, message }
}
FetchErrorClone::IdentityMismatch { type_name } => {
FetchError::IdentityMismatch { type_name }
}
FetchErrorClone::CustomRendered(s) => FetchError::Custom(Box::new(RenderedError(s))),
FetchErrorClone::InsertRendered(s) => {
FetchError::Custom(Box::new(RenderedError(format!("insert during fetch: {s}"))))
}
}
}
}
/// Concrete error type whose only content is an already-rendered message.
///
/// Used when an error was preserved only as text and must be boxed again as a
/// `std::error::Error`. It reports no `source()`.
#[derive(Debug)]
struct RenderedError(String);

impl std::fmt::Display for RenderedError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emitted verbatim: width/fill/precision flags on the formatter are intentionally ignored.
        let RenderedError(text) = self;
        out.write_str(text.as_str())
    }
}

impl std::error::Error for RenderedError {}
/// Converts a `FetchError` into its cloneable snapshot so it can be the output of a shared
/// fetch future.
///
/// Structured variants are moved over unchanged. Boxed errors (`BackendError::Other`,
/// `FetchError::Custom`, `FetchError::Insert`) are kept only as their `Display` text. With the
/// `serde` feature, so is `BackendError::WireFormat`.
pub(crate) fn into_clone(err: FetchError) -> FetchErrorClone {
    use crate::error::BackendError;

    match err {
        FetchError::Backend(backend) => match backend {
            BackendError::NotFound => FetchErrorClone::BackendNotFound,
            BackendError::Serialization(text) => FetchErrorClone::BackendSerialization(text),
            BackendError::Network(text) => FetchErrorClone::BackendNetwork(text),
            // Wire-format failures are reported through the serialization variant.
            #[cfg(feature = "serde")]
            BackendError::WireFormat(e) => FetchErrorClone::BackendSerialization(e.to_string()),
            BackendError::Other(e) => FetchErrorClone::BackendOtherRendered(e.to_string()),
        },
        FetchError::Serialization(text) => FetchErrorClone::Serialization(text),
        FetchError::FetcherPanic { type_name, message } => {
            FetchErrorClone::FetcherPanic { type_name, message }
        }
        FetchError::IdentityMismatch { type_name } => {
            FetchErrorClone::IdentityMismatch { type_name }
        }
        FetchError::Custom(e) => FetchErrorClone::CustomRendered(e.to_string()),
        FetchError::Insert(e) => FetchErrorClone::InsertRendered(e.to_string()),
    }
}
/// Type-erased fetch future wrapped in `Shared`, so several waiters can await the same fetch and
/// each receives a clone of its output.
type SharedFetchFuture<T> = Shared<Pin<Box<dyn Future<Output = SharedOutput<T>> + Send>>>;
/// Owning handle to an in-flight fetch. Each waiter holds one for as long as it awaits.
type StrongFetch<T> = Arc<SharedFetchFuture<T>>;
/// Non-owning handle stored in the registry map, so the map by itself never keeps a fetch alive.
type WeakFetch<T> = Weak<SharedFetchFuture<T>>;
/// Drop guard owned by a fetch future; it unregisters that fetch from the in-flight map.
///
/// On drop, the entry for `id` is removed only if it still refers to this particular fetch
/// (pointer-equal to `self_weak`). If a newer fetch has already replaced the entry, the map is
/// left untouched.
struct SlotGuard<T: Cacheable> {
    /// In-flight map shared with the owning `InFlightRegistry`.
    pending: Arc<DashMap<T::Id, WeakFetch<T>>>,
    /// Key the fetch was registered under.
    id: T::Id,
    /// Weak handle to the fetch that owns this guard, used purely as its identity.
    self_weak: WeakFetch<T>,
}

impl<T: Cacheable> Drop for SlotGuard<T> {
    fn drop(&mut self) {
        let own = &self.self_weak;
        let _ = self
            .pending
            .remove_if(&self.id, |_, registered| registered.ptr_eq(own));
    }
}
pub(crate) struct InFlightRegistry<T: Cacheable> {
pending: Arc<DashMap<T::Id, WeakFetch<T>>>,
}
impl<T: Cacheable> InFlightRegistry<T> {
pub(crate) fn new() -> Self {
Self {
pending: Arc::new(DashMap::new()),
}
}
pub(crate) async fn get_or_fetch<F, Fut, OnFetched, OnFetchedFut>(
&self,
id: &T::Id,
fetcher: F,
on_fetched: OnFetched,
) -> FetchOutput<T>
where
F: FnOnce(T::Id) -> Fut + Send + 'static,
Fut: Future<Output = Result<Option<T>, FetchError>> + Send + 'static,
OnFetched: FnOnce(T::Id, Arc<T>) -> OnFetchedFut + Send + 'static,
OnFetchedFut: Future<Output = Arc<T>> + Send + 'static,
{
let strong: StrongFetch<T> = match self.pending.entry(id.clone()) {
dashmap::mapref::entry::Entry::Occupied(mut e) => match e.get().upgrade() {
Some(strong) => strong,
None => {
let strong = build_fetch::<T, _, _, _, _>(
id.clone(),
fetcher,
on_fetched,
&self.pending,
);
e.insert(Arc::downgrade(&strong));
strong
}
},
dashmap::mapref::entry::Entry::Vacant(e) => {
let strong =
build_fetch::<T, _, _, _, _>(id.clone(), fetcher, on_fetched, &self.pending);
e.insert(Arc::downgrade(&strong));
strong
}
};
let shared = (*strong).clone();
let _strong_holder = strong;
let out: SharedOutput<T> = shared.await;
out.map_err(FetchError::from)
}
}
/// Builds the shared future for a single fetch of `id` and returns the owning handle to it.
///
/// The handle is created with `Arc::new_cyclic`, so the future can carry a weak handle to
/// itself. Its `SlotGuard` uses that handle to remove exactly this fetch's entry from `pending`
/// once the inner future is dropped, which happens after it completes or when the last handle to
/// the shared future goes away. Nothing runs until the shared future is first polled; the guard
/// is created on that first poll.
///
/// Outcome of the inner future:
/// - fetcher returned `Ok(Some(value))`: `value.id()` must equal `id`, otherwise the result is
///   `IdentityMismatch`. On a match, the value is wrapped in an `Arc` and passed to `on_fetched`.
///   The `Arc` that `on_fetched` resolves to is what every waiter receives.
/// - `Ok(None)` maps to `Ok(None)`.
/// - `Err(e)` maps to the cloneable form of `e`.
/// - a panic in the fetcher, including while it is being called, maps to `FetcherPanic`, with
///   the panic message when the payload is a `String` or `&'static str`.
///
/// Only the fetcher is wrapped in `catch_unwind`. A panic in `value.id()`, the id comparison or
/// `on_fetched` unwinds through the shared future's poll. NOTE(review): with
/// `futures::future::Shared` that poisons the future for any other waiter — confirm that is
/// acceptable.
fn build_fetch<T, F, Fut, OnFetched, OnFetchedFut>(
    id: T::Id,
    fetcher: F,
    on_fetched: OnFetched,
    pending: &Arc<DashMap<T::Id, WeakFetch<T>>>,
) -> StrongFetch<T>
where
    T: Cacheable,
    F: FnOnce(T::Id) -> Fut + Send + 'static,
    Fut: Future<Output = Result<Option<T>, FetchError>> + Send + 'static,
    OnFetched: FnOnce(T::Id, Arc<T>) -> OnFetchedFut + Send + 'static,
    OnFetchedFut: Future<Output = Arc<T>> + Send + 'static,
{
    let pending = pending.clone();
    let type_name = std::any::type_name::<T>();
    let id_for_guard = id.clone();
    let id_for_fetch = id.clone();
    let id_for_on_fetched = id;
    let pending_for_guard = pending;
    Arc::new_cyclic(move |self_weak: &WeakFetch<T>| {
        let self_weak_for_guard = self_weak.clone();
        let inner: Pin<Box<dyn Future<Output = SharedOutput<T>> + Send>> = Box::pin(async move {
            // Named binding (not `_`) so the guard lives until this async block finishes or is
            // dropped, and only then unregisters the fetch.
            let _slot_guard = SlotGuard {
                pending: pending_for_guard,
                id: id_for_guard,
                self_weak: self_weak_for_guard,
            };
            // The fetcher is invoked inside the async block, so a panic raised synchronously by
            // the call itself is caught too.
            let result = AssertUnwindSafe(async move { fetcher(id_for_fetch).await })
                .catch_unwind()
                .await;
            match result {
                Ok(Ok(Some(value))) => {
                    if value.id() != id_for_on_fetched {
                        return Err(FetchErrorClone::IdentityMismatch { type_name });
                    }
                    let arc = Arc::new(value);
                    let canonical = on_fetched(id_for_on_fetched, arc).await;
                    Ok(Some(canonical))
                }
                Ok(Ok(None)) => Ok(None),
                Ok(Err(e)) => Err(into_clone(e)),
                Err(panic_payload) => {
                    let message = if let Some(s) = panic_payload.downcast_ref::<String>() {
                        s.clone()
                    } else if let Some(s) = panic_payload.downcast_ref::<&'static str>() {
                        (*s).to_string()
                    } else {
                        // Non-string payload (e.g. `std::panic::panic_any(42)`): use the same
                        // placeholder std's default panic hook prints, rather than an empty
                        // message indistinguishable from `panic!("")`.
                        String::from("Box<dyn Any>")
                    };
                    Err(FetchErrorClone::FetcherPanic { type_name, message })
                }
            }
        });
        inner.shared()
    })
}