use std::{
collections::{HashMap, HashSet},
hash::Hash,
marker::PhantomData,
sync::Arc,
};
use arc_swap::ArcSwap;
use dashmap::DashMap;
use crate::{
cell_map::{CellMap, CellMapInner, MapDiff},
subscription::SubscriptionGuard,
traits::{
CellValue,
reactive_keys::{KeyChange, ReactiveKeys},
reactive_map::ReactiveMap,
},
};
/// Bidirectional parent/child index maintained by [`NestedMap`].
///
/// `forward` maps each parent key to the set of child keys indexed under it;
/// `reverse` maps each child key back to its parent key.
#[derive(Clone)]
struct NestState<PK: Hash + Eq + CellValue, K: Hash + Eq + CellValue> {
    forward: HashMap<PK, HashSet<K>>,
    reverse: HashMap<K, PK>,
}
// Written by hand: `#[derive(Default)]` would demand `PK: Default` and
// `K: Default`, which the key types are not required to implement.
impl<PK: Hash + Eq + CellValue, K: Hash + Eq + CellValue> Default for NestState<PK, K> {
    /// An empty index: no parents, no children.
    fn default() -> Self {
        let forward = HashMap::default();
        let reverse = HashMap::default();
        Self { forward, reverse }
    }
}
pub struct NestedMap<PK, K, V>
where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
source_inner: Arc<CellMapInner<K, V>>,
state: Arc<ArcSwap<NestState<PK, K>>>,
reverse_cache: Arc<DashMap<K, PK>>,
_index_guard: Arc<SubscriptionGuard>,
_marker: PhantomData<V>,
}
impl<PK, K, V> Clone for NestedMap<PK, K, V>
where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
fn clone(&self) -> Self {
Self {
source_inner: self.source_inner.clone(),
state: self.state.clone(),
reverse_cache: self.reverse_cache.clone(),
_index_guard: self._index_guard.clone(),
_marker: PhantomData,
}
}
}
impl<PK, K, V> NestedMap<PK, K, V>
where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
pub fn new<M: Send + Sync + 'static>(
source: &CellMap<K, V, M>,
fk: impl Fn(&V) -> PK + Send + Sync + 'static,
) -> Self {
let state = Arc::new(ArcSwap::from_pointee(NestState::default()));
let reverse_cache = Arc::new(DashMap::<K, PK>::new());
let state_for_sub = state.clone();
let reverse_cache_for_sub = reverse_cache.clone();
let fk = Arc::new(fk);
let guard = source.subscribe_diffs(move |diff| {
fn apply<PK, K, V>(
st: &mut NestState<PK, K>,
rc: &DashMap<K, PK>,
diff: &MapDiff<K, V>,
fk: &dyn Fn(&V) -> PK,
) where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
match diff {
MapDiff::Initial { entries } => {
st.forward.clear();
st.reverse.clear();
rc.clear();
for (k, v) in entries {
let pk = fk(v);
st.forward.entry(pk.clone()).or_default().insert(k.clone());
st.reverse.insert(k.clone(), pk.clone());
rc.insert(k.clone(), pk);
}
}
MapDiff::Insert { key, value } => {
let pk = fk(value);
st.forward
.entry(pk.clone())
.or_default()
.insert(key.clone());
st.reverse.insert(key.clone(), pk.clone());
rc.insert(key.clone(), pk);
}
MapDiff::Remove { key, .. } => {
if let Some(old_pk) = st.reverse.remove(key)
&& let Some(set) = st.forward.get_mut(&old_pk)
{
set.remove(key);
if set.is_empty() {
st.forward.remove(&old_pk);
}
}
rc.remove(key);
}
MapDiff::Update { key, new_value, .. } => {
let new_pk = fk(new_value);
if let Some(old_pk) = st.reverse.insert(key.clone(), new_pk.clone())
&& old_pk != new_pk
&& let Some(set) = st.forward.get_mut(&old_pk)
{
set.remove(key);
if set.is_empty() {
st.forward.remove(&old_pk);
}
}
st.forward
.entry(new_pk.clone())
.or_default()
.insert(key.clone());
rc.insert(key.clone(), new_pk);
}
MapDiff::Batch { changes } => {
for change in changes {
apply(st, rc, change, fk);
}
}
}
}
state_for_sub.rcu(|current| {
let mut next = current.as_ref().clone();
apply(&mut next, &reverse_cache_for_sub, diff, fk.as_ref());
next
});
});
Self {
source_inner: source.inner.clone(),
state,
reverse_cache,
_index_guard: Arc::new(guard),
_marker: PhantomData,
}
}
pub fn lookup_parent(&self, child_key: &K) -> Option<PK> {
self.reverse_cache.get(child_key).map(|r| r.value().clone())
}
pub fn children_of(&self, parent_key: &PK) -> Vec<K> {
let snapshot = self.state.load();
snapshot
.forward
.get(parent_key)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default()
}
fn subscribe_source_diffs(
&self,
cb: impl Fn(&MapDiff<K, V>) + Send + Sync + 'static,
) -> SubscriptionGuard {
use crate::traits::Watchable;
let entries: Vec<(K, V)> = self
.source_inner
.data
.iter()
.map(|r| (r.key().clone(), r.value().clone()))
.collect();
cb(&MapDiff::Initial { entries });
let keepalive = self.source_inner.clone();
let diffs = self.source_inner.diffs_cell.clone().lock();
let first = Arc::new(std::sync::atomic::AtomicBool::new(true));
diffs.subscribe(move |signal| {
let _ = &keepalive;
if first.swap(false, std::sync::atomic::Ordering::SeqCst) {
return;
}
if let crate::Signal::Value(diff) = signal {
cb(diff.as_ref());
}
})
}
pub fn subscribe_parent(
&self,
parent_key: PK,
cb: impl Fn(&MapDiff<K, V>) + Send + Sync + 'static,
) -> SubscriptionGuard {
let reverse_cache = self.reverse_cache.clone();
self.subscribe_source_diffs(move |diff| {
fn filter_for_parent<PK, K, V>(
diff: &MapDiff<K, V>,
_parent_key: &PK,
_rc: &DashMap<K, PK>,
fk_match: &dyn Fn(&K) -> bool,
) -> Option<MapDiff<K, V>>
where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
match diff {
MapDiff::Initial { entries } => {
let filtered: Vec<_> = entries
.iter()
.filter(|(k, _)| fk_match(k))
.cloned()
.collect();
if filtered.is_empty() {
None
} else {
Some(MapDiff::Initial { entries: filtered })
}
}
MapDiff::Insert { key, .. }
| MapDiff::Remove { key, .. }
| MapDiff::Update { key, .. } => {
if fk_match(key) {
Some(diff.clone())
} else {
None
}
}
MapDiff::Batch { changes } => {
let filtered: Vec<_> = changes
.iter()
.filter_map(|c| filter_for_parent(c, _parent_key, _rc, fk_match))
.collect();
if filtered.is_empty() {
None
} else {
Some(MapDiff::Batch { changes: filtered })
}
}
}
}
let pk = parent_key.clone();
let fk_match =
|k: &K| -> bool { reverse_cache.get(k).is_some_and(|r| *r.value() == pk) };
if let Some(filtered) = filter_for_parent(diff, &parent_key, &reverse_cache, &fk_match)
{
cb(&filtered);
}
})
}
pub fn subscribe_grouped(
&self,
cb: impl Fn(&PK, &MapDiff<K, V>) + Send + Sync + 'static,
) -> SubscriptionGuard {
let reverse_cache = self.reverse_cache.clone();
let fk_arc = self.state.clone();
self.subscribe_source_diffs(move |diff| {
fn route_diff<PK, K, V>(
diff: &MapDiff<K, V>,
rc: &DashMap<K, PK>,
cb: &dyn Fn(&PK, &MapDiff<K, V>),
) where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
match diff {
MapDiff::Initial { entries } => {
let mut groups: HashMap<PK, Vec<(K, V)>> = HashMap::new();
for (k, v) in entries {
if let Some(pk) = rc.get(k) {
groups
.entry(pk.value().clone())
.or_default()
.push((k.clone(), v.clone()));
}
}
for (pk, group_entries) in groups {
cb(
&pk,
&MapDiff::Initial {
entries: group_entries,
},
);
}
}
MapDiff::Insert { key, .. }
| MapDiff::Remove { key, .. }
| MapDiff::Update { key, .. } => {
if let Some(pk) = rc.get(key) {
cb(pk.value(), diff);
}
}
MapDiff::Batch { changes } => {
let mut groups: HashMap<PK, Vec<MapDiff<K, V>>> = HashMap::new();
fn collect_by_parent<PK, K, V>(
diff: &MapDiff<K, V>,
rc: &DashMap<K, PK>,
groups: &mut HashMap<PK, Vec<MapDiff<K, V>>>,
) where
PK: Hash + Eq + CellValue,
K: Hash + Eq + CellValue,
V: CellValue,
{
match diff {
MapDiff::Insert { key, .. }
| MapDiff::Remove { key, .. }
| MapDiff::Update { key, .. } => {
if let Some(pk) = rc.get(key) {
groups
.entry(pk.value().clone())
.or_default()
.push(diff.clone());
}
}
MapDiff::Batch { changes } => {
for change in changes {
collect_by_parent(change, rc, groups);
}
}
MapDiff::Initial { .. } => {
}
}
}
for change in changes {
collect_by_parent(change, rc, &mut groups);
}
for (pk, diffs) in groups {
if diffs.len() == 1 {
cb(&pk, &diffs[0]);
} else {
cb(&pk, &MapDiff::Batch { changes: diffs });
}
}
}
}
}
let _ = &fk_arc; route_diff(diff, &reverse_cache, &cb);
})
}
}
impl<PK, K, V> ReactiveKeys for NestedMap<PK, K, V>
where
    PK: Hash + Eq + CellValue,
    K: Hash + Eq + CellValue,
    V: CellValue,
{
    type Key = K;

    /// Every key currently present in the source map.
    fn key_set(&self) -> Vec<K> {
        let mut keys = Vec::new();
        for entry in self.source_inner.data.iter() {
            keys.push(entry.key().clone());
        }
        keys
    }

    /// Whether `key` is currently present in the source map.
    fn contains_key(&self, key: &K) -> bool {
        let data = &self.source_inner.data;
        data.contains_key(key)
    }

    /// Forwards source diffs to `cb` as key-level changes, skipping diffs
    /// for which `map_diff_to_key_change` yields nothing.
    fn subscribe_keys(
        &self,
        cb: impl Fn(&KeyChange<K>) + Send + Sync + 'static,
    ) -> SubscriptionGuard {
        use crate::cell_map::map_diff_to_key_change;
        self.subscribe_source_diffs(move |diff| {
            let Some(change) = map_diff_to_key_change(diff) else {
                return;
            };
            cb(&change);
        })
    }
}
impl<PK, K, V> ReactiveMap for NestedMap<PK, K, V>
where
    PK: Hash + Eq + CellValue,
    K: Hash + Eq + CellValue,
    V: CellValue,
{
    type Value = V;

    /// The source map's current value for `key`, if present.
    fn get_value(&self, key: &K) -> Option<V> {
        let entry = self.source_inner.data.get(key)?;
        Some(entry.value().clone())
    }

    /// A copy of every `(key, value)` pair currently in the source map.
    fn snapshot(&self) -> Vec<(K, V)> {
        let mut pairs = Vec::new();
        for entry in self.source_inner.data.iter() {
            let (key, value) = (entry.key(), entry.value());
            pairs.push((key.clone(), value.clone()));
        }
        pairs
    }

    /// Passes the source map's diffs to `cb` unfiltered, starting with an
    /// `Initial` snapshot.
    fn subscribe_diffs_reactive(
        &self,
        cb: impl Fn(&MapDiff<K, V>) + Send + Sync + 'static,
    ) -> SubscriptionGuard {
        Self::subscribe_source_diffs(self, cb)
    }
}
impl<K, V, M> CellMap<K, V, M>
where
    K: Hash + Eq + CellValue,
    V: CellValue,
    M: Send + Sync + 'static,
{
    /// Builds a [`NestedMap`] view of this map whose entries are grouped by the
    /// parent key `fk` derives from each value.
    ///
    /// Shorthand for [`NestedMap::new`].
    pub fn nest<PK>(&self, fk: impl Fn(&V) -> PK + Send + Sync + 'static) -> NestedMap<PK, K, V>
    where
        PK: Hash + Eq + CellValue,
    {
        NestedMap::<PK, K, V>::new(self, fk)
    }
}