use crate::error::{Error, Result};
use crate::node::{Name, Node, NodeWrapper};
use dynvec::{Block, RawDynVec};
use std::any::TypeId;
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::ptr::NonNull;
/// A dataflow graph: a set of nodes plus the links that connect node outputs to node inputs.
///
/// Nodes live in slots addressed by [`NodeHandle`]; a slot holding `None` is empty and may
/// be reused by [`Graph::insert`]. `Default` produces a graph with no nodes and no links.
#[derive(Default)]
pub struct Graph {
    /// Node storage indexed by `NodeHandle.0`; `None` marks an empty slot.
    nodes: Vec<Option<NodeWrapper>>,
    /// Every output -> input connection recorded by [`Graph::plug`].
    links: Vec<Link>,
}
impl Graph {
pub fn new() -> Self {
Self::default()
}
/// Adds `node` to the graph and returns a handle to it.
///
/// The node is stored in the first empty (`None`) slot if there is one; otherwise it is
/// appended. Because empty slots are reused, a handle whose slot was emptied may later
/// refer to a different node.
pub fn insert(&mut self, node: impl Node + 'static) -> NodeHandle {
    let wrapper = Some(NodeWrapper::new(node));
    let idx = match self.nodes.iter().position(Option::is_none) {
        Some(idx) => {
            // `position` only yields in-bounds indices, so safe indexing never panics here
            // and no `unsafe` is needed.
            self.nodes[idx] = wrapper;
            idx
        }
        None => {
            self.nodes.push(wrapper);
            self.nodes.len() - 1
        }
    };
    NodeHandle(idx)
}
pub fn remove(&mut self, node: NodeHandle) -> Result<()> {
if node.0 >= self.nodes.len() {
return Err(Error::InvalidNodeHandle(node));
}
let old = self.nodes.get_mut(node.0).take();
if old.is_none() {
Err(Error::InvalidNodeHandle(node))
} else {
self.links
.retain(|link| link.from_node != node && link.to_node != node);
Ok(())
}
}
/// Returns the number of occupied node slots (empty `None` slots are not counted).
pub fn count(&self) -> usize {
    self.nodes.iter().flatten().count()
}
fn node(&self, node: NodeHandle) -> Result<&NodeWrapper> {
self.nodes
.get(node.0)
.and_then(Option::as_ref)
.ok_or(Error::InvalidNodeHandle(node))
}
fn node_mut(&mut self, node: NodeHandle) -> Result<&mut NodeWrapper> {
self.nodes
.get_mut(node.0)
.and_then(Option::as_mut)
.ok_or(Error::InvalidNodeHandle(node))
}
/// Connects output `from_output` of `from_node` to input `to_input` of `to_node`.
///
/// # Errors
///
/// * Propagates lookup failures for either node (e.g. [`Error::InvalidNodeHandle`]) and
///   for the named output/input.
/// * [`Error::Incompatible`] if the output and input carry different value types.
/// * [`Error::Cycle`] (carrying `to_node`) if the link would create a cycle, including a
///   node linked to itself. The graph is left unchanged in that case.
pub fn plug(
    &mut self,
    from_node: NodeHandle,
    from_output: impl Into<Name>,
    to_node: NodeHandle,
    to_input: impl Into<Name>,
) -> Result<()> {
    let from_output = from_output.into();
    let to_input = to_input.into();
    let output_type = self.node(from_node)?.get_output(from_output)?.type_id();
    let input_type = self.node(to_node)?.get_input(to_input)?.type_id();
    if output_type != input_type {
        return Err(Error::Incompatible {
            output: output_type,
            input: input_type,
        });
    }

    /// Returns `true` if `target` is reachable by walking upstream (through
    /// `Graph::dependencies`) from `current`. `visited` records nodes already explored
    /// without finding `target`, so shared upstream nodes are walked only once; without it
    /// the search is exponential on chains of diamond-shaped subgraphs.
    fn reaches_upstream(
        graph: &Graph,
        target: NodeHandle,
        current: NodeHandle,
        visited: &mut [bool],
    ) -> bool {
        for dep in graph.dependencies(current) {
            if dep == target {
                return true;
            }
            if !visited[dep.0] {
                visited[dep.0] = true;
                if reaches_upstream(graph, target, dep, visited) {
                    return true;
                }
            }
        }
        false
    }

    // Adding `from_node -> to_node` closes a cycle exactly when `to_node` already feeds,
    // directly or transitively, into `from_node`. The new link cannot affect that search,
    // so it is only recorded once the check has passed. Link endpoints are always
    // indices into `self.nodes`, which never shrinks, so `visited` covers them all.
    let mut visited = vec![false; self.nodes.len()];
    if to_node == from_node || reaches_upstream(self, to_node, from_node, &mut visited) {
        return Err(Error::Cycle(to_node));
    }
    self.links.push(Link {
        from_node,
        from_output,
        to_node,
        to_input,
    });
    Ok(())
}
/// Iterates over the handles of the nodes that have a link into `node` (its direct
/// upstream dependencies), yielding one handle per such link.
fn dependencies(&self, node: NodeHandle) -> Dependencies {
    let inner = self.links.iter();
    Dependencies { inner, node }
}
/// Builds a [`Sink`] that yields the values of type `T` produced on `from_output` of
/// `from_node`.
///
/// Each call to the sink's `next` processes `from_node` and every node upstream of it
/// exactly once, dependencies first, then takes the produced value. Backing storage is
/// allocated for the sink's output and for the output end of every link, and every
/// linked input is pointed at its output's storage. All nodes in the graph are reset
/// before the sink is returned.
///
/// # Errors
///
/// * Propagates lookup failures for `from_node`, `from_output`, and linked ports.
/// * [`Error::Incompatible`] if `T` is not the value type of the requested output.
///
/// # Panics
///
/// Panics if allocating output storage fails ("Failed to allocate").
pub fn sink<T: 'static>(
    &mut self,
    from_node: NodeHandle,
    from_output: impl Into<Name>,
) -> Result<Sink<'_, T>> {
    let from_output = from_output.into();
    let output_type = self.node(from_node)?.get_output(from_output)?.type_id();
    let input_type = TypeId::of::<T>();
    if output_type != input_type {
        return Err(Error::Incompatible {
            input: input_type,
            output: output_type,
        });
    }
    let node_count = self.count();

    /// Post-order DFS over upstream dependencies: a node is appended to `order` only after
    /// everything it depends on. `visited` ensures each node is scheduled exactly once,
    /// even when it feeds several downstream nodes (a diamond) or feeds one node through
    /// several links. Without it such nodes were processed more than once per step.
    /// Marking before recursing is sound because `plug` rejects cycles.
    fn schedule_upstream(
        graph: &Graph,
        node: NodeHandle,
        visited: &mut [bool],
        order: &mut Vec<NodeHandle>,
    ) {
        if visited[node.0] {
            return;
        }
        visited[node.0] = true;
        for dep in graph.dependencies(node) {
            schedule_upstream(graph, dep, visited, order);
        }
        order.push(node);
    }
    // Link endpoints are always indices into `self.nodes` (which never shrinks), and
    // `from_node` was validated above, so every index fits in `visited`.
    let mut visited = vec![false; self.nodes.len()];
    let mut order = Vec::with_capacity(node_count);
    schedule_upstream(self, from_node, &mut visited, &mut order);

    // The sink's own output goes first so it can be recovered with `first()` below; every
    // other distinct linked output follows.
    let mut outputs = Vec::with_capacity(node_count);
    outputs.push(self.node(from_node)?.get_output(from_output)?);
    for link in &self.links {
        let output = self
            .node(link.from_node)
            .unwrap()
            .get_output(link.from_output)?;
        if !outputs.contains(&output) {
            outputs.push(output);
        }
    }
    let mut pool =
        RawDynVec::with_region(Block::for_layouts(outputs.iter().map(|o| o.layout())));
    for &output in &outputs {
        unsafe {
            let handle = pool
                .insert_raw(output.layout(), std::mem::transmute(output.drop_fn()))
                .expect("Failed to allocate");
            let ptr = NonNull::new_unchecked(pool.get_mut_ptr_raw(handle));
            output.set_target(ptr.cast());
        }
    }
    // Point every linked input at the storage of the output it is plugged into.
    for link in &self.links {
        let input = self.node(link.to_node).unwrap().get_input(link.to_input)?;
        let output = self
            .node(link.from_node)
            .unwrap()
            .get_output(link.from_output)
            .unwrap();
        unsafe {
            input.set_target(output.get_target().unwrap());
        }
    }
    let output_ptr = outputs.first().unwrap().get_target().unwrap().cast();
    for node in self.nodes.iter_mut().flatten() {
        node.reset();
    }
    // Derive the raw node pointers only after the last `&mut` pass over `self.nodes`, and
    // all from one `iter_mut` pass. This way no later mutable reborrow of the node storage
    // is created while they are live; the returned `Sink` then holds the graph borrowed.
    // NOTE(review): aimed at the Stacked/Tree Borrows aliasing rules — confirm under Miri.
    let slots: Vec<Option<NonNull<NodeWrapper>>> = self
        .nodes
        .iter_mut()
        .map(|slot| slot.as_mut().map(NonNull::from))
        .collect();
    let schedule = order
        .into_iter()
        .map(|handle| slots[handle.0].expect("Invalid dependency"))
        .collect();
    Ok(Sink {
        _g: PhantomData,
        schedule,
        _pool: pool,
        output: output_ptr,
    })
}
}
/// Iterator over the direct upstream dependencies of one node, created by
/// `Graph::dependencies`.
///
/// It scans the graph's links and yields the `from_node` of every link whose `to_node` is
/// `node`. A dependency connected through several links is yielded once per link.
pub struct Dependencies<'a> {
    node: NodeHandle,
    inner: std::slice::Iter<'a, Link>,
}

impl<'a> Iterator for Dependencies<'a> {
    type Item = NodeHandle;

    fn next(&mut self) -> Option<Self::Item> {
        let wanted = self.node;
        for link in self.inner.by_ref() {
            if link.to_node == wanted {
                return Some(link.from_node);
            }
        }
        None
    }

    /// Any of the remaining links may or may not match, so only an upper bound is known.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len();
        (0, Some(remaining))
    }
}
/// Handle naming a node slot in a `Graph`: a plain wrapper around the slot's index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeHandle(pub usize);
/// One connection from a named output of one node to a named input of another, as
/// recorded by `Graph::plug`.
#[derive(Clone, Debug)]
pub struct Link {
    /// Node whose output feeds this link.
    from_node: NodeHandle,
    /// Name of the output on `from_node`.
    from_output: Name,
    /// Node whose input receives the value.
    to_node: NodeHandle,
    /// Name of the input on `to_node`.
    to_input: Name,
}
/// Iterator that drives the nodes upstream of one output and yields that output's values.
///
/// Created by `Graph::sink`, which ties it to the `&'a mut Graph` it was built from, so
/// the graph cannot be otherwise accessed while the sink is alive.
pub struct Sink<'a, T> {
    /// Carries the `&'a mut Graph` borrow; holds no data.
    _g: PhantomData<&'a mut Graph>,
    /// Storage backing the connected output values; node outputs/inputs point into it,
    /// so it is kept alive here even though it is never read directly.
    _pool: RawDynVec<Block>,
    /// Nodes to process on each step, in processing order.
    schedule: Vec<NonNull<NodeWrapper>>,
    /// Storage slot (inside `_pool`) holding the sink output's latest value.
    output: NonNull<UnsafeCell<Option<T>>>,
}
impl<'a, T> Sink<'a, T> {
    /// Calls `reset` on every scheduled node, in schedule order.
    pub fn reset(&mut self) {
        self.schedule.iter_mut().for_each(|node| {
            // SAFETY: the pointers were taken from the graph's nodes in `Graph::sink`, and
            // the `&'a mut Graph` borrow carried by `_g` keeps those nodes alive and not
            // reachable through any other reference.
            unsafe { node.as_mut() }.reset()
        });
    }
}
impl<'a, T> Iterator for Sink<'a, T> {
    type Item = T;

    /// Runs `process` on every scheduled node in order, then moves the value out of the
    /// output slot; yields `None` if the slot is empty after the step.
    fn next(&mut self) -> Option<T> {
        self.schedule.iter_mut().for_each(|node| unsafe {
            node.as_mut().process();
        });
        // SAFETY: `output` points at the sink output's slot in the pool owned by this
        // `Sink`, so it is live; the `&mut self` receiver keeps this access exclusive.
        let slot: &mut Option<T> = unsafe { &mut *self.output.as_ref().get() };
        slot.take()
    }
}