#[cfg(feature = "channel")]
use crate::channel::{Channel, Receiver, Sender};
use crate::runtime::executor::{ExecutionContext, Executor};
use crate::runtime::garbage_collector::GarbageCollector;
use crate::runtime::graph::NodeContext;
use crate::runtime::{CycleFn, Notifier};
use crate::{Control, Relationship};
use petgraph::prelude::NodeIndex;
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::io;
use std::rc::{Rc, Weak};
#[cfg(feature = "channel")]
use std::sync::Arc;
/// Callback invoked on a node's data when the node's shared state is dropped.
///
/// (`Box<dyn …>` already defaults to a `'static` trait object.)
type OnDrop<T> = Box<dyn FnMut(&mut T)>;
/// Common read-only view over the handle types that can be used as parents in
/// a [`NodeBuilder`].
pub trait NodeHandle {
    /// The node's index in the executor's graph.
    fn index(&self) -> NodeIndex;
    /// The node's depth: one more than its deepest parent, `0` for a root.
    fn depth(&self) -> u32;
    /// The epoch recorded the last time the node's cycle broadcast a change.
    /// Implementations that do not track this may panic.
    fn mut_epoch(&self) -> usize;
}
/// A detached `(index, depth)` snapshot of a node, holding no reference to the
/// node's data.
#[derive(Debug, Clone)]
pub struct RawHandle {
    index: NodeIndex,
    depth: u32,
}
impl RawHandle {
    /// Pairs a graph index with the node's depth.
    const fn new(index: NodeIndex, depth: u32) -> Self {
        RawHandle { index, depth }
    }

    /// The node's index in the executor's graph.
    pub const fn index(&self) -> NodeIndex {
        let index = self.index;
        index
    }

    /// The node's depth at the time the snapshot was taken.
    pub const fn depth(&self) -> u32 {
        let depth = self.depth;
        depth
    }
}
impl NodeHandle for RawHandle {
fn index(&self) -> NodeIndex {
self.index
}
fn depth(&self) -> u32 {
self.depth
}
fn mut_epoch(&self) -> usize {
unimplemented!("RawHandle doesn't track mutations")
}
}
/// A shared, strong handle to a node's state; cloning shares the same state.
pub struct Node<T: 'static>(Rc<UnsafeCell<NodeInner<T>>>);
impl<T: 'static> NodeHandle for Node<T> {
    fn index(&self) -> NodeIndex {
        // Inherent methods take priority over trait methods in path resolution.
        Node::index(self)
    }

    fn depth(&self) -> u32 {
        Node::depth(self)
    }

    fn mut_epoch(&self) -> usize {
        Node::mut_epoch(self)
    }
}
impl<T: 'static> Node<T> {
pub(crate) fn uninitialized(data: T, name: Option<String>, gc: GarbageCollector) -> Self {
Self {
0: Rc::new(UnsafeCell::new(NodeInner {
data,
name,
on_drop: None,
gc,
index: NodeIndex::new(0),
mut_epoch: 0,
depth: 0,
})),
}
}
#[inline(always)]
pub fn raw_handle(&self) -> RawHandle {
RawHandle::new(self.index(), self.depth())
}
#[inline(always)]
pub fn name(&self) -> Option<&str> {
unsafe { self.get_inner().name.as_deref() }
}
#[inline(always)]
pub fn downgrade(&self) -> WeakNode<T> {
WeakNode(Rc::downgrade(&self.0))
}
#[inline(always)]
pub fn upgrade(self) -> Option<ExclusiveNode<T>> {
if Rc::strong_count(&self.0) == 1 && Rc::weak_count(&self.0) == 1 {
Some(ExclusiveNode(self.0))
} else {
None
}
}
#[inline(always)]
pub fn index(&self) -> NodeIndex {
unsafe { self.get_inner().index }
}
#[inline(always)]
pub fn borrow(&self) -> &T {
unsafe { &self.get_inner().data }
}
#[inline(always)]
fn borrow_mut(&self) -> &mut T {
unsafe { &mut self.get_inner_mut().data }
}
pub unsafe fn get(&self) -> &T {
unsafe { &self.get_inner().data }
}
pub unsafe fn get_mut(&mut self) -> &mut T {
unsafe { &mut self.get_inner_mut().data }
}
#[inline(always)]
unsafe fn get_inner(&self) -> &NodeInner<T> {
unsafe { &*self.0.get() }
}
#[inline(always)]
unsafe fn get_inner_mut(&self) -> &mut NodeInner<T> {
unsafe { &mut *self.0.get() }
}
#[inline(always)]
pub fn depth(&self) -> u32 {
unsafe { self.get_inner().depth }
}
#[inline(always)]
pub fn mut_epoch(&self) -> usize {
unsafe { self.get_inner().mut_epoch }
}
#[inline(always)]
fn set_mut_epoch(&mut self, epoch: usize) {
unsafe {
self.get_inner_mut().mut_epoch = epoch;
}
}
}
impl<T: 'static> Clone for Node<T> {
#[inline(always)]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// A node handle obtained via `Node::upgrade` when it was the sole strong owner,
/// which permits safe mutable access to the data.
pub struct ExclusiveNode<T: 'static>(Rc<UnsafeCell<NodeInner<T>>>);
impl<T: 'static> NodeHandle for ExclusiveNode<T> {
    fn index(&self) -> NodeIndex {
        // Inherent methods take priority over trait methods in path resolution.
        ExclusiveNode::index(self)
    }

    fn depth(&self) -> u32 {
        ExclusiveNode::depth(self)
    }

    fn mut_epoch(&self) -> usize {
        ExclusiveNode::mut_epoch(self)
    }
}
impl<T: 'static> ExclusiveNode<T> {
#[inline(always)]
pub fn downgrade(self) -> Node<T> {
Node(self.0)
}
#[inline(always)]
pub fn raw_handle(&self) -> RawHandle {
RawHandle::new(self.index(), self.depth())
}
#[inline(always)]
pub fn name(&self) -> Option<&str> {
unsafe { self.get_inner().name.as_deref() }
}
#[inline(always)]
pub fn index(&self) -> NodeIndex {
unsafe { self.get_inner().index }
}
#[inline(always)]
pub fn depth(&self) -> u32 {
unsafe { self.get_inner().depth }
}
#[inline(always)]
pub fn mut_epoch(&self) -> usize {
unsafe { self.get_inner().mut_epoch }
}
#[inline(always)]
pub fn borrow(&self) -> &T {
unsafe { &self.get_inner().data }
}
#[inline(always)]
pub fn borrow_mut(&mut self) -> &mut T {
unsafe { &mut self.get_inner_mut().data }
}
pub unsafe fn get(&self) -> &T {
unsafe { &self.get_inner().data }
}
pub unsafe fn get_mut(&mut self) -> &mut T {
unsafe { &mut self.get_inner_mut().data }
}
#[inline(always)]
unsafe fn get_inner(&self) -> &NodeInner<T> {
unsafe { &*self.0.get() }
}
#[inline(always)]
unsafe fn get_inner_mut(&self) -> &mut NodeInner<T> {
unsafe { &mut *self.0.get() }
}
}
pub struct WeakNode<T: 'static>(Weak<UnsafeCell<NodeInner<T>>>);
impl<T: 'static> WeakNode<T> {
#[inline(always)]
pub fn upgrade(&self) -> Option<Node<T>> {
self.0.upgrade().map(|rc| Node(rc))
}
}
/// Shared state behind the `Node`, `ExclusiveNode` and `WeakNode` handles.
///
/// Held in an `Rc<UnsafeCell<..>>`; the handles mutate it through the cell's
/// raw pointer rather than through borrow checking.
/// NOTE: field order is drop order (after `Drop::drop` runs), so `data` is dropped first.
struct NodeInner<T: 'static> {
    // The user value the node's cycle function operates on.
    data: T,
    // Optional name set via `NodeBuilder::named`.
    name: Option<String>,
    // Callback run on `data` by `Drop for NodeInner`, if installed.
    on_drop: Option<OnDrop<T>>,
    // Receives `mark_for_sweep(index)` when this state is dropped.
    gc: GarbageCollector,
    // Index in the executor's graph; `NodeIndex::new(0)` until the builder registers the node.
    index: NodeIndex,
    // Epoch of the most recent cycle whose result was a broadcast.
    mut_epoch: usize,
    // One more than the deepest parent's depth, or 0 for a node without parents.
    depth: u32,
}
impl<T: 'static> Drop for NodeInner<T> {
    /// Reports this node's graph index to the garbage collector, then runs the
    /// user's `on_drop` callback (if one was installed) on the node's data.
    fn drop(&mut self) {
        self.gc.mark_for_sweep(self.index);
        // `if let` instead of `Option::map` for a side effect (clippy `option_map_unit_fn`).
        if let Some(mut on_drop) = self.on_drop.take() {
            on_drop(&mut self.data);
        }
    }
}
/// Collects a node's data, name, parents, lifecycle callbacks and panic policy
/// before the node is registered with an `Executor`.
/// NOTE: field order is drop order for an unbuilt builder.
pub struct NodeBuilder<T: 'static> {
    // Value moved into the node when it is built.
    data: T,
    // Optional name set via `named`.
    name: Option<String>,
    // Parents keyed by `(index, depth)`; a non-trigger never overwrites an existing entry.
    parents: HashMap<(NodeIndex, u32), Relationship>,
    // Called once while building, after the node has been added to the graph.
    on_init: Option<Box<dyn FnMut(&mut Executor, &mut T, NodeIndex) + 'static>>,
    // Moved onto the node at the end of building; runs when the node's state is dropped.
    on_drop: Option<OnDrop<T>>,
    // When `false`, panics from the cycle function are caught and become `Control::Sweep`.
    allow_panic: bool,
}
impl<T: 'static> NodeBuilder<T> {
    /// Starts building a node that will own `data`.
    ///
    /// The builder starts with no name, no parents and no callbacks. Panics in the
    /// cycle function are caught unless [`NodeBuilder::allow_panic`] enables them.
    pub fn new(data: T) -> Self {
        Self {
            data,
            name: None,
            parents: HashMap::new(),
            on_init: None,
            on_drop: None,
            allow_panic: false,
        }
    }

    /// Sets a human-readable name for the node.
    pub fn named(mut self, name: String) -> Self {
        self.name = Some(name);
        self
    }

    /// Records `parent` as a parent of the node being built.
    ///
    /// Parents are keyed by `(index, depth)`. If the parent is already present,
    /// a trigger relationship overwrites the stored one. Any other relationship
    /// leaves the existing entry unchanged, so a trigger is never downgraded.
    #[inline]
    pub fn add_relationship<N: NodeHandle>(
        mut self,
        parent: &N,
        relationship: Relationship,
    ) -> Self {
        match self.parents.entry((parent.index(), parent.depth())) {
            Entry::Occupied(mut occupied) => {
                if relationship.is_trigger() {
                    occupied.insert(relationship);
                }
            }
            Entry::Vacant(vacant) => {
                vacant.insert(relationship);
            }
        }
        self
    }

    /// Applies [`NodeBuilder::add_relationship`] with the same `relationship`
    /// to every handle in `parents`.
    #[inline]
    pub fn add_many_relationships<'a, N: NodeHandle + 'a>(
        self,
        parents: impl IntoIterator<Item = &'a N>,
        relationship: Relationship,
    ) -> Self {
        parents
            .into_iter()
            .fold(self, |s, parent| s.add_relationship(parent, relationship))
    }

    /// Adds `parent` with [`Relationship::Trigger`].
    #[inline]
    pub fn triggered_by<N: NodeHandle>(self, parent: &N) -> Self {
        self.add_relationship(parent, Relationship::Trigger)
    }

    /// Adds every handle in `parents` with [`Relationship::Trigger`].
    #[inline]
    pub fn triggered_by_many<'a, N: NodeHandle + 'a>(
        self,
        parents: impl IntoIterator<Item = &'a N>,
    ) -> Self {
        self.add_many_relationships(parents, Relationship::Trigger)
    }

    /// Adds `parent` with [`Relationship::Observe`].
    #[inline]
    pub fn observer_of<N: NodeHandle>(self, parent: &N) -> Self {
        self.add_relationship(parent, Relationship::Observe)
    }

    /// Adds every handle in `parents` with [`Relationship::Observe`].
    #[inline]
    pub fn observer_of_many<'a, N: NodeHandle + 'a>(
        self,
        parents: impl IntoIterator<Item = &'a N>,
    ) -> Self {
        self.add_many_relationships(parents, Relationship::Observe)
    }

    /// Sets a callback that runs once while the node is built.
    ///
    /// It runs after the node has been added to the graph. It receives the
    /// executor, the node's data and the node's index.
    ///
    /// # Panics
    /// If `on_init` was already set.
    pub fn on_init<F>(mut self, on_init: F) -> Self
    where
        F: FnMut(&mut Executor, &mut T, NodeIndex) + 'static,
    {
        assert!(self.on_init.is_none(), "cannot set on_init twice");
        self.on_init = Some(Box::new(on_init));
        self
    }

    /// Sets a callback that runs on the node's data when the node's state is dropped.
    ///
    /// # Panics
    /// If `on_drop` was already set.
    pub fn on_drop<F>(mut self, on_drop: F) -> Self
    where
        F: FnMut(&mut T) + 'static,
    {
        assert!(self.on_drop.is_none(), "cannot set on_drop twice");
        self.on_drop = Some(Box::new(on_drop));
        self
    }

    /// Chooses the panic policy for the cycle function.
    ///
    /// With `true`, panics are not caught by the node wrapper. With `false`
    /// (the default), a panic is caught and the cycle returns [`Control::Sweep`].
    pub fn allow_panic(mut self, allow_panic: bool) -> Self {
        self.allow_panic = allow_panic;
        self
    }

    /// Registers the node with `executor` and returns a handle to it.
    ///
    /// The executor holds only a weak reference to the node. Once no strong
    /// handle remains, the node's next cycle returns [`Control::Sweep`]. When a
    /// cycle returns a broadcasting control, the node's mutation epoch is set to
    /// the context's epoch.
    pub fn build<F>(self, executor: &mut Executor, cycle_fn: F) -> Node<T>
    where
        F: FnMut(&mut T, &mut ExecutionContext) -> Control + 'static,
    {
        self.install(executor, move |node: &Node<T>, allow_panic: bool| {
            Self::weak_cycle(node, allow_panic, cycle_fn)
        })
    }

    /// Like [`NodeBuilder::build`], and also registers a [`Notifier`] for the node.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn build_with_notifier<F>(
        self,
        executor: &mut Executor,
        cycle_fn: F,
    ) -> io::Result<(Node<T>, Notifier)>
    where
        F: FnMut(&mut T, &mut ExecutionContext) -> Control + 'static,
    {
        let node = self.build(executor, cycle_fn);
        let notifier = executor.register_notifier(node.index());
        Ok((node, notifier))
    }

    /// Like [`NodeBuilder::build`], and also creates a bounded channel into the node.
    ///
    /// The channel's receiver is handed to `cycle_fn` on every cycle. The
    /// returned [`Sender`] is tied to a notifier registered for the node.
    ///
    /// # Panics
    /// If `capacity` is 0.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    #[cfg(feature = "channel")]
    pub fn build_with_channel<U, F>(
        self,
        executor: &mut Executor,
        capacity: usize,
        mut cycle_fn: F,
    ) -> io::Result<(Node<T>, Sender<U>)>
    where
        U: 'static,
        F: FnMut(&mut T, &mut ExecutionContext, &Receiver<U>) -> Control + 'static,
    {
        assert!(capacity > 0, "capacity must be greater than 0");
        let chan = Arc::new(Channel::new(capacity));
        let receiver = Receiver::new(chan.clone());
        let node = self.install(executor, move |node: &Node<T>, allow_panic: bool| {
            Self::weak_cycle(node, allow_panic, move |data, ctx| {
                cycle_fn(data, ctx, &receiver)
            })
        });
        let notifier = executor.register_notifier(node.index());
        let tx = Sender::new(chan, notifier);
        Ok((node, tx))
    }

    /// Registers a node whose cycle function owns a strong handle to it.
    ///
    /// No handle is returned. The node lives for as long as the executor keeps
    /// its cycle function.
    pub fn spawn<F>(self, executor: &mut Executor, mut cycle_fn: F)
    where
        F: FnMut(&mut T, &mut ExecutionContext) -> Control + 'static,
    {
        // The handle returned by `install` is dropped here; the clone captured
        // below keeps the node alive.
        let _ = self.install(executor, move |node: &Node<T>, allow_panic: bool| -> CycleFn {
            let mut state = node.clone();
            Box::new(move |ctx: &mut ExecutionContext| {
                Self::run_cycle(&mut state, ctx, allow_panic, &mut cycle_fn)
            })
        });
    }

    /// Depth of a node with the given parents: one more than the deepest
    /// parent, or `0` for a node without parents.
    fn depth_for(parents: &HashMap<(NodeIndex, u32), Relationship>) -> u32 {
        parents
            .keys()
            .map(|&(_, parent_depth)| parent_depth + 1)
            .max()
            .unwrap_or(0)
    }

    /// Runs one cycle of `cycle_fn` on `node`'s data.
    ///
    /// When the result is a broadcast, the node's mutation epoch is set to
    /// `ctx.epoch()`. Unless `allow_panic` is set, a panic in `cycle_fn` is
    /// caught and reported as [`Control::Sweep`], and the epoch is left unchanged.
    fn run_cycle<F>(
        node: &mut Node<T>,
        ctx: &mut ExecutionContext,
        allow_panic: bool,
        cycle_fn: &mut F,
    ) -> Control
    where
        F: FnMut(&mut T, &mut ExecutionContext) -> Control,
    {
        let ctrl = if allow_panic {
            cycle_fn(node.borrow_mut(), ctx)
        } else {
            let attempt = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                cycle_fn(node.borrow_mut(), ctx)
            }));
            match attempt {
                Ok(ctrl) => ctrl,
                Err(_) => return Control::Sweep,
            }
        };
        if ctrl.is_broadcast() {
            node.set_mut_epoch(ctx.epoch());
        }
        ctrl
    }

    /// Wraps `cycle_fn` into a [`CycleFn`] that holds only a weak reference to
    /// `node`. The wrapper returns [`Control::Sweep`] once the node has been dropped.
    fn weak_cycle<F>(node: &Node<T>, allow_panic: bool, mut cycle_fn: F) -> CycleFn
    where
        F: FnMut(&mut T, &mut ExecutionContext) -> Control + 'static,
    {
        let state = node.downgrade();
        Box::new(move |ctx: &mut ExecutionContext| match state.upgrade() {
            Some(mut node) => Self::run_cycle(&mut node, ctx, allow_panic, &mut cycle_fn),
            None => Control::Sweep,
        })
    }

    /// Shared build path.
    ///
    /// Creates the node, adds it to the executor's graph with the cycle function
    /// from `make_cycle`, and records its index and depth. It then wires up the
    /// parent edges, enables the node's depth in the scheduler, runs `on_init`,
    /// and finally installs `on_drop`.
    ///
    /// `make_cycle` receives the new node and the builder's `allow_panic` flag.
    fn install<M>(self, executor: &mut Executor, make_cycle: M) -> Node<T>
    where
        M: FnOnce(&Node<T>, bool) -> CycleFn,
    {
        let Self {
            data,
            name,
            parents,
            on_init,
            on_drop,
            allow_panic,
        } = self;
        let node = Node::uninitialized(data, name, executor.garbage_collector());
        let depth = Self::depth_for(&parents);
        let cycle_fn = make_cycle(&node, allow_panic);
        let idx = executor.graph().add_node(NodeContext::new(cycle_fn, depth));
        // SAFETY: `node` was created above and no other reference into its inner
        // state is live. The cycle function holds another handle but is never
        // invoked here. NOTE(review): this assumes `on_init` does not make the
        // executor run this node's cycle function — confirm.
        // Every write below goes through this ONE `&mut`. Deriving a second
        // `&mut NodeInner` while `inner` is still used (as the previous code did
        // for `on_init`) is aliasing UB.
        let inner = unsafe { node.get_inner_mut() };
        inner.index = idx;
        inner.depth = depth;
        for (&(parent, _), &relationship) in &parents {
            executor.graph().add_edge(parent, idx, relationship);
        }
        executor.scheduler().enable_depth(depth);
        if let Some(mut on_init) = on_init {
            on_init(executor, &mut inner.data, idx);
        }
        // Installed only after `on_init` has run, as before.
        inner.on_drop = on_drop;
        node
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::ExecutionMode;

    /// Builds a node around `data` whose cycle function never reports a change.
    fn inert_node<T: 'static>(executor: &mut Executor, data: T) -> Node<T> {
        NodeBuilder::new(data).build(executor, |_, _| Control::Unchanged)
    }

    #[test]
    fn test_exclusive_node() {
        let mut executor = Executor::new(ExecutionMode::Spin);
        let node = inert_node(&mut executor, 23);
        let exclusive = node
            .upgrade()
            .expect("a freshly built node should be exclusively held");
        let node = exclusive.downgrade();
        // A second strong handle must prevent exclusive access.
        let _cloned = std::hint::black_box(node.clone());
        assert!(node.upgrade().is_none());
    }

    #[test]
    fn test_relationships() {
        let mut executor = Executor::new(ExecutionMode::Spin);
        let parent = inert_node(&mut executor, ());
        let key = (parent.index(), parent.depth());
        // (relationship added, whether the stored entry must be a trigger afterwards)
        let steps = [
            (Relationship::Observe, false),
            (Relationship::Observe, false),
            (Relationship::Trigger, true),
            (Relationship::Trigger, true),
            (Relationship::Observe, true),
        ];
        let mut builder = NodeBuilder::new(23);
        for (relationship, expect_trigger) in steps {
            builder = builder.add_relationship(&parent, relationship);
            assert_eq!(builder.parents.len(), 1);
            let stored = &builder.parents[&key];
            if expect_trigger {
                assert!(stored.is_trigger());
            } else {
                assert!(stored.is_observe());
            }
        }
    }
}