//! siraph 0.1.2
//!
//! A node-based digital signal processing crate.
//!
//! Documentation
use crate::error::{Error, Result};
use crate::node::{Name, Node, NodeWrapper};

use dynvec::{Block, RawDynVec};

use std::any::TypeId;
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::ptr::NonNull;

/// The central structure of the crate. The graph owns a set of nodes and the
/// links between them, and makes them work together.
///
/// `Default` is derived: the default graph has no nodes and no links.
#[derive(Default)]
pub struct Graph {
    /// Node slots, indexed by `NodeHandle`. `None` marks an empty slot that
    /// `insert` may reuse.
    nodes: Vec<Option<NodeWrapper>>,

    /// Those links are known to be valid and do not create cycles.
    links: Vec<Link>,
}

impl Graph {
    /// Creates a new empty `Graph`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts a new node into this graph.
    pub fn insert(&mut self, node: impl Node + 'static) -> NodeHandle {
        let wrapper = NodeWrapper::new(node);

        match self.nodes.iter().position(Option::is_none) {
            Some(idx) => {
                unsafe { *self.nodes.get_unchecked_mut(idx) = Some(wrapper) };
                NodeHandle(idx)
            }

            None => {
                let idx = self.nodes.len();
                self.nodes.push(Some(wrapper));
                NodeHandle(idx)
            }
        }
    }

    /// Removes a node from the graph. `Error::InvalidNodeHandle` can
    /// be returned if the given `NodeHandle` was invalid (the
    /// node did not exist).
    pub fn remove(&mut self, node: NodeHandle) -> Result<()> {
        if node.0 >= self.nodes.len() {
            // prevent `Vec::remove` to panic
            return Err(Error::InvalidNodeHandle(node));
        }

        let old = self.nodes.get_mut(node.0).take();

        if old.is_none() {
            Err(Error::InvalidNodeHandle(node))
        } else {
            self.links
                .retain(|link| link.from_node != node && link.to_node != node);
            Ok(())
        }
    }

    /// Returns the number of nodes that are owned by this graph.
    pub fn count(&self) -> usize {
        self.nodes.iter().filter(|&o| o.is_some()).count()
    }

    /// Gets a reference to a node of the graph.
    fn node(&self, node: NodeHandle) -> Result<&NodeWrapper> {
        self.nodes
            .get(node.0)
            .and_then(Option::as_ref)
            .ok_or(Error::InvalidNodeHandle(node))
    }

    /// Gets a reference to a node of the graph.
    fn node_mut(&mut self, node: NodeHandle) -> Result<&mut NodeWrapper> {
        self.nodes
            .get_mut(node.0)
            .and_then(Option::as_mut)
            .ok_or(Error::InvalidNodeHandle(node))
    }

    /// Plugs two nodes together.
    /// * `Error::Incompatible` is returned if the types of the input
    /// and the output are not compatible (they are not the same).
    /// * `Error::Cycle` is returned if this new dependency makes
    /// a node depend on itself.
    pub fn plug(
        &mut self,
        from_node: NodeHandle,
        from_output: impl Into<Name>,
        to_node: NodeHandle,
        to_input: impl Into<Name>,
    ) -> Result<()> {
        let from_output = from_output.into();
        let to_input = to_input.into();

        let output_type = self.node(from_node)?.get_output(from_output)?.type_id();
        let input_type = self.node(to_node)?.get_input(to_input)?.type_id();
        if output_type != input_type {
            return Err(Error::Incompatible {
                output: output_type,
                input: input_type,
            });
        }

        let link_idx = self.links.len();
        self.links.push(Link {
            from_node,
            from_output,
            to_node,
            to_input,
        });

        fn check_cycle(graph: &Graph, init: NodeHandle, current: NodeHandle) -> bool {
            for dep in graph.dependencies(current) {
                if dep == init {
                    return true;
                }

                if check_cycle(graph, init, dep) {
                    return true;
                }
            }

            false
        }

        if to_node == from_node || check_cycle(self, to_node, from_node) {
            // the dependency was not valid
            // we have to remove it
            self.links.remove(link_idx);
            return Err(Error::Cycle(to_node));
        }

        Ok(())
    }

    /// Gets the direct dependencies of the given node.
    fn dependencies(&self, node: NodeHandle) -> Dependencies {
        Dependencies {
            node,
            inner: self.links.iter(),
        }
    }

    /// Creates a `Sink` that allows the caller to retreive data from this `Graph`.
    /// * `Error::Incompatible` is returned if `T` is not the same type as
    /// the type of `from_output`.
    pub fn sink<T: 'static>(
        &mut self,
        from_node: NodeHandle,
        from_output: impl Into<Name>,
    ) -> Result<Sink<'_, T>> {
        let from_output = from_output.into();

        // check potential type errors
        let output_type = self.node(from_node)?.get_output(from_output)?.type_id();
        let input_type = TypeId::of::<T>();
        if output_type != input_type {
            return Err(Error::Incompatible {
                input: input_type,
                output: output_type,
            });
        }

        let node_count = self.count();

        // create the schedule
        let mut schedule = Vec::with_capacity(node_count);

        fn add_deps_to_vec(
            vec: &mut Vec<NonNull<NodeWrapper>>,
            graph: &mut Graph,
            node: NodeHandle,
        ) {
            let deps: Vec<NodeHandle> = graph.dependencies(node).collect();
            for dep in deps {
                add_deps_to_vec(vec, graph, dep);
            }

            let ptr = NonNull::from(graph.node_mut(node).expect("Invalid dependency"));

            vec.push(ptr);
        }

        add_deps_to_vec(&mut schedule, self, from_node);

        // create the value pool
        let mut outputs = Vec::with_capacity(node_count);

        outputs.push(
            self.node(from_node)
                .unwrap()
                .get_output(from_output)
                .unwrap(),
        );

        for link in &self.links {
            let output = self
                .node(link.from_node)
                .unwrap()
                .get_output(link.from_output)?;
            if !outputs.contains(&output) {
                outputs.push(output);
            }
        }

        let mut pool =
            RawDynVec::with_region(Block::for_layouts(outputs.iter().map(|o| o.layout())));

        // link the outputs
        for &output in &outputs {
            unsafe {
                let handle = pool
                    .insert_raw(output.layout(), std::mem::transmute(output.drop_fn()))
                    .expect("Failed to allocate");
                let ptr = NonNull::new_unchecked(pool.get_mut_ptr_raw(handle));
                output.set_target(ptr.cast());
            }
        }

        // now that all outputs are linked
        // we can link the inputs
        for link in &self.links {
            let input = self.node(link.to_node).unwrap().get_input(link.to_input)?;
            let output = self
                .node(link.from_node)
                .unwrap()
                .get_output(link.from_output)
                .unwrap();

            unsafe {
                input.set_target(output.get_target().unwrap());
            }
        }

        let output_ptr = outputs.first().unwrap().get_target().unwrap().cast();

        // initialize the nodes
        for node in &mut self.nodes {
            if let Some(node) = node {
                node.reset();
            }
        }

        Ok(Sink {
            _g: PhantomData,
            schedule,
            // the pool must stay alive as long as the sink is used
            _pool: pool,
            output: output_ptr,
        })
    }
}

/// An iterator over the direct dependencies of a node within a `Graph`.
///
/// It yields the `from_node` of every link whose `to_node` is `node`, in link
/// order. If several links connect the same pair of nodes, that dependency is
/// yielded once per link.
pub struct Dependencies<'a> {
    /// The node whose dependencies are listed.
    node: NodeHandle,
    /// The links of the graph that have not been scanned yet.
    inner: std::slice::Iter<'a, Link>,
}

impl<'a> Iterator for Dependencies<'a> {
    type Item = NodeHandle;

    /// Scans forward to the next link that ends at the tracked node and yields
    /// the node that link starts from.
    fn next(&mut self) -> Option<NodeHandle> {
        let target = self.node;
        for link in self.inner.by_ref() {
            if link.to_node == target {
                return Some(link.from_node);
            }
        }
        None
    }

    /// Every remaining link may or may not match, so only an upper bound is known.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len();
        (0, Some(remaining))
    }
}

/// A handle to a node: the index of the node's slot inside its `Graph`.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct NodeHandle(pub usize);

/// A directed connection from one node's output to another node's input.
#[derive(Debug, Clone)]
pub struct Link {
    /// Node whose output feeds the link.
    from_node: NodeHandle,
    /// Name of the output of `from_node` that is read.
    from_output: Name,
    /// Node whose input receives the value.
    to_node: NodeHandle,
    /// Name of the input of `to_node` that is fed.
    to_input: Name,
}

/// Pulls the values of one node output out of a `Graph`; created by `Graph::sink`.
///
/// Each call to `next` runs `process` on every entry of the schedule, in order,
/// then takes the value found in the output slot. The `'a` lifetime keeps the
/// `Graph` mutably borrowed while the sink exists, because `schedule` holds raw
/// pointers into the graph's nodes.
pub struct Sink<'a, T> {
    _g: PhantomData<&'a mut Graph>,
    // this pool must be kept alive while the sink is used.
    _pool: RawDynVec<Block>,
    // Nodes to process; each one is pushed after its dependencies.
    schedule: Vec<NonNull<NodeWrapper>>,
    // Pool slot linked to the sink's output (see `Graph::sink`); `next` takes
    // the value out of it.
    output: NonNull<UnsafeCell<Option<T>>>,
}

impl<'a, T> Sink<'a, T> {
    /// Calls `reset` on every node in this sink's schedule, in schedule order.
    pub fn reset(&mut self) {
        self.schedule.iter_mut().for_each(|ptr| {
            // SAFETY: the pointers were created from `&mut NodeWrapper`s of the graph,
            // which stays mutably borrowed for `'a` through `_g`.
            let node = unsafe { ptr.as_mut() };
            node.reset();
        });
    }
}

impl<'a, T> Iterator for Sink<'a, T> {
    type Item = T;

    /// Runs one processing step over the whole schedule, then moves the value out
    /// of the output slot, leaving `None` behind.
    fn next(&mut self) -> Option<T> {
        for ptr in self.schedule.iter_mut() {
            // SAFETY: the pointers come from `&mut NodeWrapper`s of the graph, which
            // stays mutably borrowed for `'a` through `_g`.
            unsafe {
                ptr.as_mut().process();
            }
        }

        // SAFETY(review): `output` points into `_pool`, which this sink owns and keeps
        // alive; no other reference to the slot is held across this call.
        let slot: *mut Option<T> = unsafe { self.output.as_ref() }.get();
        unsafe { (*slot).take() }
    }
}