use std::{collections::HashSet, sync::Arc};
use get_mut_drop_weak::sync::get_mut_drop_weak;
use tokio::task::JoinSet;
use uuid::Uuid;
use crate::backing_store::{BackingStore, BackingStoreT, Strategy, TrackedPath};
use crate::{Fb, WriteGuard};
impl<T: Send + Sync + 'static, B: Strategy<T>> Fb<T, B> {
    /// Returns a write guard for this `Fb`, first giving `*self` its own copy
    /// of the value if the `Arc` is shared (similar in spirit to
    /// [`Arc::make_mut`]).
    ///
    /// If `get_mut_drop_weak` grants unique access, the value is loaded for
    /// writing in place. Otherwise the current value is read and cloned. The
    /// clone is inserted into the same pool as a new `Fb`, and `*self` is
    /// repointed at that new `Arc` before it is loaded for writing. Other
    /// holders of the previous `Arc` are unaffected and keep the old value.
    pub async fn make_mut(self: &mut Arc<Self>) -> WriteGuard<'_, T, B>
    where
        T: Clone,
    {
        match get_mut_drop_weak(self) {
            Ok(unique) => unique.load_mut().await,
            Err(shared) => {
                let current = shared.load().await;
                let fresh = Arc::new(shared.pool().insert(current.clone()));
                // Release the read guard before the old `Arc` is replaced.
                drop(current);
                *shared = fresh;
                // `fresh` was just created and never cloned, so it is unique.
                Arc::get_mut(shared).unwrap().load_mut().await
            }
        }
    }

    /// Blocking counterpart of [`Fb::make_mut`]: returns a write guard,
    /// first giving `*self` its own copy of the value if the `Arc` is shared.
    ///
    /// The logic is the same as `make_mut`, but it uses `blocking_load` /
    /// `blocking_load_mut` instead of awaiting.
    pub fn blocking_make_mut(self: &mut Arc<Self>) -> WriteGuard<'_, T, B>
    where
        T: Clone,
    {
        match get_mut_drop_weak(self) {
            Ok(unique) => unique.blocking_load_mut(),
            Err(shared) => {
                let current = shared.blocking_load();
                let fresh = Arc::new(shared.pool().insert(current.clone()));
                // Release the read guard before the old `Arc` is replaced.
                drop(current);
                *shared = fresh;
                // `fresh` was just created and never cloned, so it is unique.
                Arc::get_mut(shared).unwrap().blocking_load_mut()
            }
        }
    }
}
pub fn blocking_save<T: Send + Sync + 'static, B: Strategy<T>, R, E>(
store: &Arc<BackingStore<B>>,
arcs: impl IntoIterator<Item = Arc<Fb<T, B>>>,
tracked: &Arc<TrackedPath<B::PersistPath>>,
max_simultaneous_tasks: usize,
change_key: impl FnOnce() -> Result<R, E>,
) -> Result<R, E> {
blocking_save_with(
store,
|persister| {
for arc in arcs {
persister.persist(&arc);
}
Ok::<_, E>(())
},
tracked,
max_simultaneous_tasks,
change_key,
)
}
/// Persists the `Fb`s chosen by `persist_arcs` to `tracked`, then runs
/// `change_key`, then deletes keys that this save no longer references.
///
/// Sequence:
/// 1. Snapshot the keys currently tracked at `tracked`.
/// 2. Run [`prepare_save`], which calls `persist_arcs` and waits for every
///    persist task it spawned.
/// 3. Call `store.blocking_sync` on the tracked path. Presumably this flushes
///    the persisted data; confirm against `BackingStore`.
/// 4. Call `change_key`. Its `Ok` value becomes this function's result.
/// 5. Delete, via [`post_save_cleanup`], every key from the snapshot that was
///    not passed to `Persister::persist` in step 2.
///
/// # Errors
///
/// If `persist_arcs` or `change_key` returns `Err`, that error is returned
/// immediately and no keys are deleted.
///
/// # Panics
///
/// Panics if `max_simultaneous_tasks` is zero, and on any panic raised by
/// `prepare_save` or `post_save_cleanup`.
pub fn blocking_save_with<B: BackingStoreT, R, E>(
    store: &Arc<BackingStore<B>>,
    persist_arcs: impl FnOnce(&mut Persister<B>) -> Result<(), E>,
    tracked: &Arc<TrackedPath<B::PersistPath>>,
    max_simultaneous_tasks: usize,
    change_key: impl FnOnce() -> Result<R, E>,
) -> Result<R, E> {
    // Taken before persisting, so keys added by this save are never
    // deletion candidates.
    let mut stale_keys = tracked.all_keys();
    let runtime = store.runtime_handle();
    let kept_keys = prepare_save(persist_arcs, tracked, max_simultaneous_tasks, runtime)?;

    store.blocking_sync(tracked.path());
    let result = change_key()?;

    // Only keys that this save did not reference are removed.
    stale_keys.retain(|key| !kept_keys.contains(key));
    post_save_cleanup(store, tracked, max_simultaneous_tasks, &stale_keys);
    Ok(result)
}
/// Runs `persist_arcs` against a fresh [`Persister`] and blocks until every
/// persist task it spawned has finished.
///
/// Returns the set of keys passed to [`Persister::persist`]. This includes
/// keys that were already tracked and therefore needed no new task.
///
/// # Errors
///
/// Returns the error from `persist_arcs` unchanged. In that case the
/// `Persister` is dropped without waiting for its tasks. NOTE(review): per
/// tokio's docs, dropping a `JoinSet` aborts the tasks still in it.
///
/// # Panics
///
/// Panics if `max_simultaneous_tasks` is zero, or if a persist task panicked
/// (`JoinSet::join_all` propagates task panics). `Handle::block_on` also
/// panics if called from within an asynchronous execution context.
pub fn prepare_save<B: BackingStoreT, E>(
    persist_arcs: impl FnOnce(&mut Persister<B>) -> Result<(), E>,
    tracked: &Arc<TrackedPath<B::PersistPath>>,
    max_simultaneous_tasks: usize,
    runtime: &tokio::runtime::Handle,
) -> Result<HashSet<Uuid>, E> {
    assert!(max_simultaneous_tasks > 0);
    let mut persister = Persister {
        max_simultaneous_tasks,
        new_keys_set: HashSet::new(),
        join_set: JoinSet::new(),
        runtime: runtime.clone(),
        tracked: Arc::clone(tracked),
    };
    persist_arcs(&mut persister)?;

    let Persister {
        join_set,
        new_keys_set,
        ..
    } = persister;
    // Wait for every task that is still in flight.
    let _: Vec<()> = runtime.block_on(join_set.join_all());
    Ok(new_keys_set)
}
/// Deletes every key in `keys_to_delete` from `tracked` and blocks until all
/// deletions have finished.
///
/// Each deletion calls `store.blocking_delete_persisted(tracked, key)` on a
/// blocking task spawned through the store's task tracker. At most
/// `max_simultaneous_tasks` deletions are in flight at once. When that limit
/// is reached, the caller blocks until one finishes.
///
/// # Panics
///
/// Panics if `max_simultaneous_tasks` is zero, or if any deletion task
/// panicked or failed to join. `Handle::block_on` also panics if called from
/// within an asynchronous execution context.
pub fn post_save_cleanup<B: BackingStoreT>(
    store: &Arc<BackingStore<B>>,
    tracked: &Arc<TrackedPath<B::PersistPath>>,
    max_simultaneous_tasks: usize,
    keys_to_delete: &[Uuid],
) {
    assert!(max_simultaneous_tasks > 0);
    let handle = store.runtime_handle();
    let mut in_flight = JoinSet::new();

    for key in keys_to_delete.iter().copied() {
        // Throttle: wait for one deletion to finish before starting another.
        if in_flight.len() == max_simultaneous_tasks {
            let finished = handle.block_on(in_flight.join_next());
            finished.unwrap().unwrap();
        }

        let task_store = Arc::clone(store);
        let task_tracked = Arc::clone(tracked);
        let tracker = task_store.task_tracker().clone();
        let blocking_handle = handle.clone();
        let delete = move || task_store.blocking_delete_persisted(&task_tracked, key);

        in_flight.spawn_on(
            async move {
                tracker
                    .spawn_blocking_on(delete, &blocking_handle)
                    .await
                    .unwrap()
            },
            handle,
        );
    }

    let _: Vec<()> = handle.block_on(in_flight.join_all());
}
/// Collects the persist tasks for one save and caps how many run at once.
///
/// [`prepare_save`] creates it and hands it to the `persist_arcs` callback.
/// Call [`Persister::persist`] for each `Fb` that should survive the save.
/// Previously tracked keys that are never passed to `persist` become
/// deletion candidates in [`blocking_save_with`].
pub struct Persister<B: BackingStoreT> {
    /// Tracked path being saved to; a clone of this `Arc` goes into each task.
    tracked: Arc<TrackedPath<B::PersistPath>>,
    /// Persist tasks spawned and not yet reaped.
    join_set: JoinSet<()>,
    /// Every key passed to `persist` during this save, whether or not it
    /// needed a new persist task.
    new_keys_set: HashSet<Uuid>,
    /// Upper bound on `join_set.len()`; `prepare_save` asserts it is nonzero.
    max_simultaneous_tasks: usize,
    /// Runtime the tasks are spawned on and blocked on.
    runtime: tokio::runtime::Handle,
}
impl<B: BackingStoreT> Persister<B> {
    /// Records `arc`'s key as part of this save and, if needed, spawns a task
    /// that persists it to the tracked path.
    ///
    /// No task is spawned when the key is already tracked, or when the same
    /// key was already passed to `persist` earlier in this save. In the second
    /// case the earlier call has already either found it tracked or queued its
    /// persist task. If `max_simultaneous_tasks` tasks are already in flight,
    /// this blocks on the runtime until one of them finishes before spawning
    /// the next.
    ///
    /// # Panics
    ///
    /// Panics if a persist task reaped here panicked or failed to join.
    /// `Handle::block_on` also panics if the task limit is reached while this
    /// is called from within an asynchronous execution context.
    pub fn persist<T: Send + Sync + 'static>(&mut self, arc: &Arc<Fb<T, B>>)
    where
        B: Strategy<T>,
    {
        let key = arc.key();
        // `HashSet::insert` returns `false` for a key seen earlier in this save.
        // Because `tracked` may not yet contain a key whose persist task is
        // still running, checking `tracked` alone would spawn a duplicate,
        // concurrent persist of the same key.
        let first_sighting = self.new_keys_set.insert(key);
        if !first_sighting || self.tracked.contains_key(key) {
            return;
        }
        assert!(self.join_set.len() <= self.max_simultaneous_tasks);
        if self.join_set.len() == self.max_simultaneous_tasks {
            // Throttle: reap one in-flight persist before starting another.
            // `join_next` cannot yield `None` here, because the set holds
            // `max_simultaneous_tasks > 0` tasks.
            self.runtime
                .block_on(self.join_set.join_next())
                .unwrap()
                .unwrap();
        }
        let tracked = Arc::clone(&self.tracked);
        let arc = Arc::clone(arc);
        self.join_set.spawn_on(
            async move { arc.spawn_persist(&tracked).await.await.unwrap() },
            &self.runtime,
        );
    }
}