use std::sync::{Arc, Weak};
use dashmap::{DashMap, Entry};
use tokio::task::JoinHandle;
use tokio_util::task::TaskTracker;
use uuid::Uuid;
/// Low-level storage operations that `BackingStore` delegates to.
///
/// Entries are addressed by a [`Uuid`] key; the persisted-data operations additionally
/// take an implementor-defined [`Self::PersistPath`]. No method returns a `Result`, so
/// failures cannot be reported through the return type.
pub trait BackingStoreT: Send + Sync + 'static {
/// Location type handed to the path-based operations below.
type PersistPath: Send + Sync;
/// Invoked from the cleanup task scheduled by `Token`'s `Drop` once no live token
/// refers to `key`. Presumably discards the non-persisted data for `key` — confirm
/// against implementors.
fn delete(&self, key: Uuid);
/// Invoked by `BackingStore::blocking_delete_persisted` for a key that is currently
/// recorded as present under `path`.
fn delete_persisted(&self, path: &Self::PersistPath, key: Uuid);
/// Invoked by `BackingStore::register` when `key` has no use-count entry yet.
/// Presumably makes the data persisted under `src_path` loadable by `key` — confirm.
fn register(&self, src_path: &Self::PersistPath, key: Uuid);
/// Invoked by `BackingStore::persist` when `key` is not yet recorded as present
/// under `dest_path`.
fn persist(&self, dest_path: &Self::PersistPath, key: Uuid);
/// Invoked by `BackingStore::blocking_track_path`; the returned keys become the
/// initial "present" set of the resulting `TrackedPath`.
/// NOTE(review): the name suggests this also cleans up `path` — confirm.
fn sanitize_path(&self, path: &Self::PersistPath) -> impl IntoIterator<Item = Uuid>;
/// Invoked by `BackingStore::blocking_sync`. Presumably flushes persisted data
/// under `path` — confirm against implementors.
fn sync_persisted(&self, path: &Self::PersistPath);
}
/// Extension of [`BackingStoreT`] that knows how to write and read values of type `T`.
pub trait Strategy<T>: BackingStoreT {
/// Writes `data` under `key`. Called by `BackingStore::store` before the token for
/// `key` is created.
fn store(&self, key: Uuid, data: &T);
/// Reads the value stored under `key`. Called by `BackingStore::load` with the key
/// of a token. Returns `T` directly, so there is no way to signal a missing key
/// through the return type.
fn load(&self, key: Uuid) -> T;
}
/// Lets a shared `Arc<B>` be used wherever a [`BackingStoreT`] is expected by
/// forwarding every operation, unchanged, to the `B` behind the pointer.
impl<B: BackingStoreT> BackingStoreT for Arc<B> {
    type PersistPath = B::PersistPath;

    fn delete(&self, key: Uuid) {
        <B as BackingStoreT>::delete(&**self, key)
    }

    fn delete_persisted(&self, path: &Self::PersistPath, key: Uuid) {
        <B as BackingStoreT>::delete_persisted(&**self, path, key)
    }

    fn register(&self, src_path: &Self::PersistPath, key: Uuid) {
        <B as BackingStoreT>::register(&**self, src_path, key)
    }

    fn persist(&self, dest_path: &Self::PersistPath, key: Uuid) {
        <B as BackingStoreT>::persist(&**self, dest_path, key)
    }

    fn sanitize_path(&self, path: &Self::PersistPath) -> impl IntoIterator<Item = Uuid> {
        <B as BackingStoreT>::sanitize_path(&**self, path)
    }

    fn sync_persisted(&self, path: &Self::PersistPath) {
        <B as BackingStoreT>::sync_persisted(&**self, path)
    }
}
/// Forwards [`Strategy`] reads and writes to the `B` behind the [`Arc`].
impl<T, B: Strategy<T>> Strategy<T> for Arc<B> {
    fn store(&self, key: Uuid, data: &T) {
        <B as Strategy<T>>::store(&**self, key, data)
    }

    fn load(&self, key: Uuid) -> T {
        <B as Strategy<T>>::load(&**self, key)
    }
}
/// Wraps a [`BackingStoreT`] implementation and tracks, per key, whether a live
/// `Token` still refers to it. Blocking work is spawned on `runtime` through
/// `task_tracker`.
pub struct BackingStore<B: BackingStoreT> {
// The wrapped storage implementation every operation is delegated to.
backing: B,
// Per-key weak reference to the key's current `Token`; an entry whose weak
// reference is dead is removed by the cleanup task scheduled from `Token::drop`.
use_counts: DashMap<Uuid, Weak<Token<B>>>,
// Runtime on which `spawn_blocking` schedules blocking closures.
runtime: tokio::runtime::Handle,
// Tracks every task started via `spawn_blocking` so `finished` can wait for them.
task_tracker: TaskTracker,
}
/// Handle for one key in a [`BackingStore`]. Dropping it schedules a blocking
/// cleanup task on the owning store (see the `Drop` impl).
pub(super) struct Token<B: BackingStoreT> {
// Key this token stands for.
key: Uuid,
// Strong reference to the owning store; used by `Drop` to schedule the cleanup.
store: Arc<BackingStore<B>>,
}
/// A persist path together with the set of keys recorded as present under it.
pub struct TrackedPath<P> {
// The path value handed to the backing implementation.
path: P,
// Key set (the unit value carries no data). Seeded from `sanitize_path` and
// updated by `BackingStore::persist` / `BackingStore::blocking_delete_persisted`.
present: DashMap<Uuid, ()>,
}
impl<P> TrackedPath<P> {
    /// Borrows the path this tracker was created for.
    pub fn path(&self) -> &P {
        let Self { path, .. } = self;
        path
    }

    /// Returns a snapshot of every key currently recorded as present.
    ///
    /// The order of the returned keys is whatever order the underlying map iterates in.
    pub fn all_keys(&self) -> Vec<Uuid> {
        Vec::from_iter(self.present.iter().map(|item| *item.key()))
    }

    /// Reports whether `key` is currently recorded as present.
    pub fn contains_key(&self, key: Uuid) -> bool {
        self.present.contains_key(&key)
    }
}
impl<B: BackingStoreT> Drop for Token<B> {
    /// Schedules a blocking cleanup task for this token's key on the owning store.
    ///
    /// The task deletes the key from the backing implementation and removes its
    /// use-count entry, unless the entry is already gone or has meanwhile been
    /// pointed at a new live token. The deletion therefore happens some time after
    /// `drop` returns, not inside it.
    fn drop(&mut self) {
        let key = self.key;
        let owner = Arc::clone(&self.store);
        self.store.spawn_blocking(move || {
            // Holding the entry keeps the key's map shard locked for the whole
            // check-then-delete sequence.
            let slot = match owner.use_counts.entry(key) {
                Entry::Occupied(slot) => slot,
                // An earlier cleanup task already removed the entry.
                Entry::Vacant(_) => return,
            };
            // A non-zero strong count means the entry now refers to a newer, live token.
            let revived = slot.get().strong_count() > 0;
            if revived {
                return;
            }
            owner.backing.delete(key);
            slot.remove();
        });
    }
}
impl<B: BackingStoreT> BackingStore<B> {
    /// Creates a store around `backing` that spawns its blocking work on `runtime`.
    pub fn new(backing: B, runtime: tokio::runtime::Handle) -> Self {
        let use_counts = DashMap::new();
        let task_tracker = TaskTracker::new();
        Self {
            backing,
            use_counts,
            runtime,
            task_tracker,
        }
    }

    /// Runs [`Self::blocking_track_path`] as a tracked blocking task and returns its handle.
    pub fn track_path(
        self: &Arc<Self>,
        path: B::PersistPath,
    ) -> JoinHandle<TrackedPath<B::PersistPath>> {
        let store = Arc::clone(self);
        let job = move || store.blocking_track_path(path);
        self.spawn_blocking(job)
    }

    /// Builds a [`TrackedPath`] for `path`, seeding its key set with whatever the
    /// backing implementation's `sanitize_path` returns.
    pub fn blocking_track_path(
        self: &Arc<Self>,
        path: B::PersistPath,
    ) -> TrackedPath<B::PersistPath> {
        let present = key_map(self.backing.sanitize_path(&path));
        TrackedPath { path, present }
    }

    /// Spawns `f` as a blocking task on this store's runtime, registered with the
    /// store's task tracker so that [`Self::finished`] waits for it.
    pub fn spawn_blocking<R: Send + 'static>(
        self: &Arc<Self>,
        f: impl FnOnce() -> R + Send + 'static,
    ) -> JoinHandle<R> {
        let tracker = &self.task_tracker;
        tracker.spawn_blocking_on(f, &self.runtime)
    }

    /// Returns the runtime handle blocking tasks are spawned on.
    pub(crate) fn runtime_handle(&self) -> &tokio::runtime::Handle {
        let Self { runtime, .. } = self;
        runtime
    }

    /// Returns the tracker that records every task spawned through this store.
    pub fn task_tracker(&self) -> &TaskTracker {
        let Self { task_tracker, .. } = self;
        task_tracker
    }

    /// Closes the task tracker and waits for the tracked tasks to complete.
    pub async fn finished(&self) {
        let tracker = &self.task_tracker;
        tracker.close();
        tracker.wait().await;
    }

    /// Writes `data` under `key` through the backing implementation and returns the
    /// token that represents the key.
    ///
    /// # Panics
    ///
    /// Panics if the use-count table already has an entry for `key`. The check does
    /// not look at whether the token behind that entry is still alive, so an entry
    /// left by a dropped token whose cleanup task has not run yet also triggers it.
    pub(super) fn store<T>(self: &Arc<Self>, key: Uuid, data: &T) -> Arc<Token<B>>
    where
        B: Strategy<T>,
    {
        // The vacant slot keeps the key's shard locked until the weak reference is
        // inserted, so the write and the bookkeeping happen under one lock.
        let slot = match self.use_counts.entry(key) {
            Entry::Occupied(_) => panic!("Token already exists for key: {}", key),
            Entry::Vacant(slot) => slot,
        };
        self.backing.store(key, data);
        let token = Arc::new(Token {
            key,
            store: Arc::clone(self),
        });
        slot.insert(Arc::downgrade(&token));
        token
    }

    /// Reads the value stored under `token`'s key through the backing implementation.
    pub(super) fn load<T>(&self, token: &Token<B>) -> T
    where
        B: Strategy<T>,
    {
        let key = token.key;
        self.backing.load(key)
    }

    /// Persists `token`'s key under `tracked`'s path, unless the key is already
    /// recorded as present there, and then records it as present.
    pub(super) fn persist(&self, token: &Token<B>, tracked: &TrackedPath<B::PersistPath>) {
        let key = token.key;
        // The slot is held across the backing call, keeping the key's shard locked.
        let Entry::Vacant(slot) = tracked.present.entry(key) else {
            return;
        };
        self.backing.persist(&tracked.path, key);
        slot.insert(());
    }

    /// Obtains a token for a key that is recorded as present in `tracked`.
    ///
    /// Returns `None` when `tracked` does not contain `key`. Otherwise:
    /// * if a live token for `key` exists, that token is returned;
    /// * if a use-count entry exists but its token is dead, the entry is re-pointed
    ///   at a fresh token without calling the backing implementation;
    /// * if no entry exists, the backing implementation's `register` is called
    ///   first and a fresh token is recorded.
    pub(super) fn register(
        self: &Arc<Self>,
        key: Uuid,
        tracked: &TrackedPath<B::PersistPath>,
    ) -> Option<Arc<Token<B>>> {
        // Held until return: a read guard on the key's entry in the tracked set.
        let _present_guard = tracked.present.get(&key)?;
        let mut slot = match self.use_counts.entry(key) {
            Entry::Occupied(occupied) => {
                if let Some(live) = occupied.get().upgrade() {
                    return Some(live);
                }
                // Dead token whose cleanup has not run yet: reuse its entry.
                occupied.into_ref()
            }
            Entry::Vacant(vacant) => {
                self.backing.register(&tracked.path, key);
                vacant.insert(Weak::new())
            }
        };
        let token = Arc::new(Token {
            key,
            store: Arc::clone(self),
        });
        *slot = Arc::downgrade(&token);
        Some(token)
    }

    /// Runs [`Self::blocking_sync`] for `tracked`'s path as a tracked blocking task.
    pub fn sync(self: &Arc<Self>, tracked: &Arc<TrackedPath<B::PersistPath>>) -> JoinHandle<()> {
        let store = Arc::clone(self);
        let shared_path = Arc::clone(tracked);
        self.spawn_blocking(move || store.blocking_sync(shared_path.path()))
    }

    /// Calls the backing implementation's `sync_persisted` for `path`.
    pub fn blocking_sync(&self, path: &B::PersistPath) {
        self.backing.sync_persisted(path);
    }

    /// Runs [`Self::blocking_delete_persisted`] as a tracked blocking task.
    pub fn delete_persisted(
        self: &Arc<Self>,
        tracked: &Arc<TrackedPath<B::PersistPath>>,
        key: Uuid,
    ) -> JoinHandle<()> {
        let store = Arc::clone(self);
        let shared_path = Arc::clone(tracked);
        self.spawn_blocking(move || store.blocking_delete_persisted(&shared_path, key))
    }

    /// Deletes the persisted data for `key` under `tracked`'s path and removes the
    /// key from the tracked set. Does nothing if the key is not recorded as present.
    pub fn blocking_delete_persisted(&self, tracked: &TrackedPath<B::PersistPath>, key: Uuid) {
        // The slot is held across the backing call, keeping the key's shard locked.
        let Entry::Occupied(slot) = tracked.present.entry(key) else {
            return;
        };
        self.backing.delete_persisted(tracked.path(), key);
        slot.remove();
    }
}
fn key_map(all_keys: impl IntoIterator<Item = Uuid>) -> DashMap<Uuid, ()> {
DashMap::from_iter(all_keys.into_iter().map(|key| (key, ())))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backing implementation whose operations all do nothing and which reports
    /// no keys for any path.
    struct NoopBacking;

    impl BackingStoreT for NoopBacking {
        type PersistPath = ();

        fn delete(&self, _key: Uuid) {}

        fn delete_persisted(&self, _path: &(), _key: Uuid) {}

        fn register(&self, _src_path: &(), _key: Uuid) {}

        fn persist(&self, _dest_path: &(), _key: Uuid) {}

        fn sanitize_path(&self, _path: &()) -> impl IntoIterator<Item = Uuid> {
            std::iter::empty::<Uuid>()
        }

        fn sync_persisted(&self, _path: &()) {}
    }

    /// Registers and immediately drops a token for the same key twice in a row, then
    /// waits for all cleanup tasks. The test passes if nothing panics or hangs.
    #[tokio::test]
    async fn register_drop_race() {
        let runtime = tokio::runtime::Handle::current();
        let store = Arc::new(BackingStore::new(NoopBacking, runtime));
        let key = Uuid::new_v4();
        let tracked = TrackedPath {
            path: (),
            present: key_map([key]),
        };
        for _ in 0..2 {
            // The second registration may run before the first drop's cleanup task.
            let token = store.register(key, &tracked).unwrap();
            drop(token);
        }
        store.finished().await;
    }
}