use std::{
collections::HashMap,
hash::Hash,
sync::{Arc, Mutex},
};
use uuid::Uuid;
use crate::{
cell_map::MapDiff,
map_query::{MapDiffSink, MapQuery, MapQueryInstall},
subscription::SubscriptionGuard,
traits::CellValue,
};
/// Callback of one registered subscriber; invoked with each diff the share forwards
/// to its subscriber list.
type DiffSubscriber<K, V> = Arc<dyn Fn(&MapDiff<K, V>) + Send + Sync>;
/// Deferred installation of the upstream query. Being `FnOnce`, it can run at most
/// once: it receives the fan-out sink and returns the upstream's subscription guards.
type UpstreamInstall<K, V> =
Box<dyn FnOnce(MapDiffSink<K, V>) -> Vec<SubscriptionGuard> + Send + Sync>;
/// State shared (behind an `Arc`) by every handle of one shared map query.
pub(crate) struct SharedMapQueryInner<K, V>
where
K: CellValue + Hash + Eq,
V: CellValue,
{
/// Deferred upstream install; `Some` until the first `install` call takes it.
upstream: Mutex<Option<UpstreamInstall<K, V>>>,
/// Guards returned by the upstream install; drained once the subscriber count drops to zero.
upstream_guards: Mutex<Vec<SubscriptionGuard>>,
/// Cached map, updated by applying every upstream diff; replayed as an `Initial`
/// snapshot to subscribers that join after the upstream was installed.
state: Mutex<HashMap<K, V>>,
/// Registered subscribers. The whole `Arc<Vec<_>>` is replaced on add/remove, so the
/// fan-out can clone the `Arc` and iterate it without holding this lock.
subscribers: parking_lot::Mutex<Arc<Vec<(Uuid, DiffSubscriber<K, V>)>>>,
}
impl<K, V> SharedMapQueryInner<K, V>
where
    K: CellValue + Hash + Eq,
    V: CellValue,
{
    /// Registers `cb` under `id`.
    ///
    /// The subscriber list is copy-on-write: a fresh `Vec` is built and swapped in under
    /// the lock, so readers that already cloned the previous `Arc` keep an unchanged
    /// list. The replaced list is dropped only after the lock has been released
    /// (presumably so no callback destructor can re-enter the lock).
    fn add_subscriber(&self, id: Uuid, cb: DiffSubscriber<K, V>) {
        let previous = {
            let mut slot = self.subscribers.lock();
            let mut updated: Vec<(Uuid, DiffSubscriber<K, V>)> =
                Vec::with_capacity(slot.len() + 1);
            updated.extend(slot.iter().cloned());
            updated.push((id, cb));
            std::mem::replace(&mut *slot, Arc::new(updated))
        };
        drop(previous);
    }

    /// Removes every subscriber registered under `id` and returns how many remain.
    ///
    /// Uses the same copy-on-write swap as `add_subscriber`; the replaced list (and any
    /// callback only it referenced) is dropped after the lock is released.
    fn remove_subscriber(&self, id: Uuid) -> usize {
        let (count, previous) = {
            let mut slot = self.subscribers.lock();
            let mut kept: Vec<(Uuid, DiffSubscriber<K, V>)> = Vec::with_capacity(slot.len());
            for entry in slot.iter() {
                if entry.0 != id {
                    kept.push(entry.clone());
                }
            }
            kept.shrink_to_fit();
            let count = kept.len();
            (count, std::mem::replace(&mut *slot, Arc::new(kept)))
        };
        drop(previous);
        count
    }

    /// Applies `diff` to the cached map under its lock.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    fn apply_diff(&self, diff: &MapDiff<K, V>) {
        Self::apply_diff_into(&mut self.state.lock().expect("share state poisoned"), diff);
    }

    /// Applies `diff` to `state`.
    ///
    /// `Initial` replaces the whole map, `Insert`/`Update` upsert the key, `Remove`
    /// deletes it, and `Batch` applies its changes in order (recursively).
    fn apply_diff_into(state: &mut HashMap<K, V>, diff: &MapDiff<K, V>) {
        match diff {
            MapDiff::Initial { entries } => {
                state.clear();
                state.reserve(entries.len());
                state.extend(entries.iter().map(|(k, v)| (k.clone(), v.clone())));
            }
            MapDiff::Insert { key, value } | MapDiff::Update { key, new_value: value, .. } => {
                state.insert(key.clone(), value.clone());
            }
            MapDiff::Remove { key, .. } => {
                state.remove(key);
            }
            MapDiff::Batch { changes } => {
                for change in changes {
                    Self::apply_diff_into(state, change);
                }
            }
        }
    }

    /// Builds an `Initial` diff holding a copy of every cached entry.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    fn snapshot_initial(&self) -> MapDiff<K, V> {
        let cached = self.state.lock().expect("share state poisoned");
        let mut entries: Vec<(K, V)> = Vec::with_capacity(cached.len());
        for (key, value) in cached.iter() {
            entries.push((key.clone(), value.clone()));
        }
        MapDiff::Initial { entries }
    }
}
/// A map query whose upstream is installed at most once and shared by many subscribers.
///
/// Cloning is cheap (an `Arc` clone) and every clone refers to the same shared state.
/// The first `install` on any clone installs the upstream query. Later installs are sent
/// an `Initial` snapshot of the cached map and are then registered to receive the diffs
/// forwarded from the upstream.
pub struct SharedMapQuery<K, V>
where
K: CellValue + Hash + Eq,
V: CellValue,
{
inner: Arc<SharedMapQueryInner<K, V>>,
}
impl<K, V> Clone for SharedMapQuery<K, V>
where
K: CellValue + Hash + Eq,
V: CellValue,
{
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<K, V> SharedMapQuery<K, V>
where
K: CellValue + Hash + Eq,
V: CellValue,
{
pub fn new<Q: MapQuery<K, V>>(q: Q) -> Self {
let upstream: UpstreamInstall<K, V> = Box::new(move |sink| q.install(sink));
Self {
inner: Arc::new(SharedMapQueryInner {
upstream: Mutex::new(Some(upstream)),
upstream_guards: Mutex::new(Vec::new()),
state: Mutex::new(HashMap::new()),
subscribers: parking_lot::Mutex::new(Arc::new(Vec::new())),
}),
}
}
}
impl<K, V> MapQueryInstall<K, V> for SharedMapQuery<K, V>
where
    K: CellValue + Hash + Eq,
    V: CellValue,
{
    /// Registers `sink` on the shared stream and returns the guard that unregisters it.
    ///
    /// * The first `install` across all clones takes the deferred upstream install and
    ///   registers `sink`. It then installs a fan-out sink upstream: every upstream diff
    ///   is applied to the cached map and then forwarded to every registered subscriber.
    /// * Later installs find the upstream already taken. `sink` is sent an `Initial`
    ///   snapshot of the cached map and is then registered for subsequent diffs.
    ///
    /// The returned guard unregisters `sink`; when no subscribers remain, the upstream
    /// guards are dropped. The guard owns a strong reference to the shared state. As a
    /// result, the upstream stays installed for as long as any subscriber guard is alive,
    /// even after every `SharedMapQuery` handle (including the one consumed here) has
    /// been dropped. Previously only a weak reference was kept, so consuming the last
    /// handle tore the upstream down while subscribers were still attached.
    ///
    /// NOTE(review): known limitations not addressed here:
    /// * A late subscriber's snapshot and its registration are not atomic with respect to
    ///   upstream diffs. A diff emitted concurrently from another thread can be missed,
    ///   or delivered after an `Initial` that already contains it.
    /// * The upstream install is `FnOnce`. After the last subscriber leaves and the
    ///   upstream guards are dropped, a later subscriber only receives the stale cached
    ///   snapshot and no further diffs.
    ///
    /// # Panics
    /// If one of the internal std mutexes is poisoned.
    fn install(self, sink: MapDiffSink<K, V>) -> Vec<SubscriptionGuard> {
        let id = Uuid::new_v4();
        // Move the strong reference out of the consumed handle; it ends up owned by the
        // returned guard so the shared state outlives every handle while subscribed.
        let inner = self.inner;

        let upstream_take = {
            let mut slot = inner.upstream.lock().expect("share upstream poisoned");
            slot.take()
        };

        if let Some(install_fn) = upstream_take {
            // Register first so any emission made while the upstream is being installed
            // already reaches `sink`.
            inner.add_subscriber(id, sink);
            // Weak on purpose: the fan-out is handed to the upstream, presumably retained
            // by the subscription whose guards `inner` stores, so a strong reference here
            // could form a reference cycle.
            let weak = Arc::downgrade(&inner);
            let fanout: MapDiffSink<K, V> = Arc::new(move |diff: &MapDiff<K, V>| {
                let Some(inner) = weak.upgrade() else {
                    return;
                };
                inner.apply_diff(diff);
                // Clone the list's `Arc` so callbacks run without the subscribers lock.
                let subs = Arc::clone(&*inner.subscribers.lock());
                for (_, cb) in subs.iter() {
                    cb(diff);
                }
            });
            let guards = install_fn(fanout);
            let mut slot = inner
                .upstream_guards
                .lock()
                .expect("share upstream_guards poisoned");
            slot.extend(guards);
        } else {
            let initial = inner.snapshot_initial();
            sink(&initial);
            inner.add_subscriber(id, sink);
        }

        // `inner` (strong) is moved into the guard: dropping the last subscriber guard is
        // what releases the upstream, not dropping the last handle.
        vec![SubscriptionGuard::from_callback(move || {
            let remaining = inner.remove_subscriber(id);
            if remaining == 0 {
                // Move the guards out under the lock and drop them after releasing it,
                // so their teardown cannot re-enter this mutex.
                let drained: Vec<SubscriptionGuard> = {
                    let mut slot = inner
                        .upstream_guards
                        .lock()
                        .expect("share upstream_guards poisoned");
                    std::mem::take(&mut *slot)
                };
                drop(drained);
            }
        })]
    }
}
// `SharedMapQuery` also implements `MapQuery`, so it can be used wherever a `MapQuery`
// is expected (including calling `.share()` on it again). The impl body is empty; all
// provided behaviour comes from the trait's default items.
// NOTE(review): `private_bounds` is silenced without a recorded reason. Presumably a
// bound involved here (for example, `MapQuery` depending on the crate-internal
// `MapQueryInstall` trait) is less visible than this impl. Confirm which bound triggers
// the lint and state it here.
#[allow(private_bounds)]
impl<K, V> MapQuery<K, V> for SharedMapQuery<K, V>
where
K: CellValue + Hash + Eq,
V: CellValue,
{
}
/// Extension trait that adds `.share()` to map queries.
pub trait MapQueryShareExt<K, V>: MapQuery<K, V>
where
    K: CellValue + Hash + Eq,
    V: CellValue,
{
    /// Converts this query into a [`SharedMapQuery`] whose clones all share a single
    /// upstream installation. Equivalent to `SharedMapQuery::new(self)`.
    fn share(self) -> SharedMapQuery<K, V> {
        SharedMapQuery::<K, V>::new::<Self>(self)
    }
}
// Blanket impl: every `MapQuery<K, V>` gets `.share()` through the trait's provided
// default method.
impl<K, V, Q> MapQueryShareExt<K, V> for Q
where
K: CellValue + Hash + Eq,
V: CellValue,
Q: MapQuery<K, V>,
{
}