use std::{
cell::RefCell,
fmt::Debug,
hash::Hash,
sync::{Arc, OnceLock},
};
use hashbrown::HashSet;
use map::ConcurrentMap;
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
mod map;
/// A memoizing query graph for one revision, able to reuse results from the
/// revision it was incremented from.
pub struct Graph<Q, R> {
    /// Node slots for queries resolved (or being resolved) in this revision.
    new: QueryNodeMap<Q, R>,
    /// Node slots of the previous revision: the `new` map of the graph this one was
    /// incremented from, or a fresh `ConcurrentMap::new()` for `Graph::new`.
    old: QueryNodeMap<Q, R>,
    /// User-supplied computation invoked whenever a query must be (re)resolved.
    resolver: Box<dyn ResolveQuery<Q, R> + 'static>,
}
/// A resolved query in one revision of the graph.
#[derive(Debug)]
struct Node<Q, R> {
    /// The query's value, either freshly computed or carried over from the previous revision.
    result: R,
    /// Change flag computed by `Graph::resolve` relative to the previous revision;
    /// dependents validating against this node read it to decide whether they must recompute.
    changed: bool,
    /// Queries this node's resolution issued through `QueryResolver::query` (its
    /// dependencies); shared via `Arc` with later nodes that reuse this result.
    edges_from: Arc<HashSet<Q>>,
}
/// One query's node slot: initialized at most once via `OnceLock`, and reference-counted
/// so it can be used after the map lookup that produced it has returned.
type NodeSlot<Q, R> = Arc<OnceLock<Node<Q, R>>>;
/// Shared map from each query to its node slot for one revision of a graph.
type QueryNodeMap<Q, R> = Arc<ConcurrentMap<Q, NodeSlot<Q, R>>>;
impl<Q: Debug + Clone + Eq + Hash, R: Debug + Clone> Debug for Graph<Q, R> {
    /// Prints both revision maps. The resolver is skipped because `ResolveQuery`
    /// does not require `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            new,
            old,
            resolver: _,
        } = self;
        let mut builder = f.debug_struct("Graph");
        builder.field("new", new);
        builder.field("old", old);
        builder.finish()
    }
}
impl<Q: Clone + Eq + Hash + Send + Sync, R: Clone + Eq + Send + Sync> Graph<Q, R> {
    /// Creates a graph with no previous revision, so every query is resolved from scratch.
    pub fn new(resolver: impl ResolveQuery<Q, R> + 'static) -> Arc<Self> {
        Arc::new(Self {
            new: Arc::new(ConcurrentMap::new()),
            old: Arc::new(ConcurrentMap::new()),
            resolver: Box::new(resolver),
        })
    }

    /// Returns the result of `q` in this revision, computing it on first use.
    ///
    /// Each query's node lives in a `OnceLock`, so (assuming `ConcurrentMap::get_or_insert`
    /// hands every caller the same slot) it is resolved at most once per graph and
    /// concurrent callers wait for that single initialization.
    ///
    /// Dependency cycles are not detected: a query that transitively depends on itself
    /// re-enters `OnceLock::get_or_init` on its own slot, which the standard library
    /// documents as unspecified (currently a deadlock).
    pub fn query(self: &Arc<Self>, q: Q) -> R {
        let slot = self.get_node(&q);
        let node = slot.get_or_init(|| self.resolve(q));
        node.result.clone()
    }

    /// Returns this revision's (possibly still uninitialized) slot for `q`, creating it if absent.
    fn get_node(self: &Arc<Self>, q: &Q) -> Arc<OnceLock<Node<Q, R>>> {
        self.new
            .get_or_insert(q.clone(), || Arc::new(OnceLock::default()))
    }

    /// Computes the node for `q`, reusing the previous revision's result when possible.
    ///
    /// * No previous entry for `q`: resolve from scratch; the node counts as changed.
    /// * Previous slot exists but was never initialized: resolve from scratch and compare
    ///   against whatever that slot holds once resolution finishes.
    /// * Previous node with no recorded dependencies: always resolve again (presumably it
    ///   reads inputs the graph does not track).
    /// * Otherwise validate the previous dependencies in parallel; re-resolve only if one
    ///   of them changed, else carry the old result forward unchanged.
    fn resolve(self: &Arc<Self>, q: Q) -> Node<Q, R> {
        let Some(old_slot) = self.old.get(&q) else {
            return self.recompute(q, |_| true);
        };
        let Some(old_node) = old_slot.get() else {
            // Re-read the slot after resolving: it may have been filled in the meantime.
            return self.recompute(q, |result| {
                old_slot.get().map_or(true, |prev| prev.result != *result)
            });
        };

        // NOTE(review): this thread may be inside `get_or_init` for `q`'s slot while rayon
        // lets it steal unrelated validation work; if that work needs `q`, the thread would
        // re-enter its own slot. Confirm this cannot happen for the intended workloads.
        let must_recompute = old_node.edges_from.is_empty()
            || old_node.edges_from.par_iter().any(|parent| {
                let slot = self.get_node(parent);
                let node = slot.get_or_init(|| self.resolve(parent.clone()));
                node.changed
            });

        if must_recompute {
            self.recompute(q, |result| *result != old_node.result)
        } else {
            // No dependency changed: the previous result is still valid.
            Node {
                result: old_node.result.clone(),
                edges_from: Arc::clone(&old_node.edges_from),
                changed: false,
            }
        }
    }

    /// Runs the user resolver for `q` with a fresh dependency recorder and builds its node.
    ///
    /// `is_changed` receives the freshly computed result and decides the node's `changed` flag.
    fn recompute(self: &Arc<Self>, q: Q, is_changed: impl FnOnce(&R) -> bool) -> Node<Q, R> {
        let recorder = Arc::new(QueryResolver::new(Arc::clone(self)));
        let result = self.resolver.resolve(q, Arc::clone(&recorder));
        Node {
            changed: is_changed(&result),
            result,
            edges_from: Arc::new(recorder.edges_from.take()),
        }
    }

    /// Starts the next revision using `resolver`.
    ///
    /// The returned graph treats this graph's `new` map as its `old` map and lazily
    /// validates and reuses those results as queries are asked.
    pub fn increment(
        self: &Arc<Self>,
        resolver: impl ResolveQuery<Q, R> + 'static,
    ) -> Arc<Self> {
        Arc::new(Self {
            new: Arc::new(ConcurrentMap::new()),
            old: Arc::clone(&self.new),
            resolver: Box::new(resolver),
        })
    }
}
/// Handle passed to [`ResolveQuery::resolve`] for evaluating sub-queries.
///
/// Every query issued through [`QueryResolver::query`] is recorded as a dependency of the
/// query currently being resolved. The graph stores that set on the resulting node and uses
/// it in the next revision to decide whether the node must be recomputed.
pub struct QueryResolver<Q, R> {
    graph: Arc<Graph<Q, R>>,
    edges_from: EdgeRecorder<Q>,
}

// No manual `Send`/`Sync` impls: the recorder is `Mutex`-backed, so the compiler derives
// both exactly when they are sound. User code may share the resolver across threads
// (e.g. via rayon) and call `query` concurrently without data races.

/// Thread-safe set of the dependencies recorded by one [`QueryResolver`].
struct EdgeRecorder<Q>(std::sync::Mutex<HashSet<Q>>);

impl<Q: Eq + Hash> EdgeRecorder<Q> {
    fn new() -> Self {
        Self(std::sync::Mutex::new(HashSet::new()))
    }

    /// Locks the set, ignoring poisoning. This mirrors a `RefCell`, whose borrow is released
    /// on unwind: a panic during one recording does not make the set unusable afterwards.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashSet<Q>> {
        self.0.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records `q` as a dependency.
    fn insert(&self, q: Q) {
        self.lock().insert(q);
    }

    /// Removes and returns every dependency recorded so far.
    fn take(&self) -> HashSet<Q> {
        std::mem::take(&mut *self.lock())
    }
}

impl<Q: Clone + Eq + Hash + Send + Sync, R: Clone + Eq + Send + Sync> QueryResolver<Q, R> {
    fn new(graph: Arc<Graph<Q, R>>) -> Self {
        Self {
            graph,
            edges_from: EdgeRecorder::new(),
        }
    }

    /// Returns the result of `q` and records `q` as a dependency of the query being resolved.
    ///
    /// May be called concurrently from several threads sharing this resolver.
    pub fn query(&self, q: Q) -> R {
        let result = self.graph.query(q.clone());
        self.edges_from.insert(q);
        result
    }
}
/// User-supplied computation for a query.
///
/// The graph stores one implementation as `Box<dyn ResolveQuery<Q, R>>` and invokes it from
/// whichever thread needs a result, including rayon worker threads; hence `Send + Sync`.
pub trait ResolveQuery<Q, R>: Send + Sync {
    /// Computes the result of `q`.
    ///
    /// Sub-queries should go through `resolver.query(..)`: only those are recorded as
    /// dependencies and re-validated in later revisions. A query that records no
    /// dependencies is resolved again in every revision.
    fn resolve(&self, q: Q, resolver: Arc<QueryResolver<Q, R>>) -> R;
}