// sound_flow 0.3.0
//
// Execute graphs of functions in real time.
// Documentation
use crate::builder::descriptors::{Connection, PendingInstruction};
use crate::builder::error::{BuildError, InputSocket, InvalidConnection};
use crate::graph::BufferAddress;
use crate::{CompiledGraph, Instruction, Node};
use petgraph::prelude::{EdgeRef, NodeIndex};
use petgraph::visit::IntoNodeReferences;
use petgraph::Direction;
use std::any::TypeId;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::iter::repeat_with;

/// Payload of an edge in the compilation dependency graph: which output
/// socket of the source node feeds which input socket of the target node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct EdgeData {
    /// Index into the source node's output sockets.
    from_output_socket: usize,
    /// Index into the target node's input sockets.
    to_input_socket: usize,
}

/// Vertex payload of the compilation dependency graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RelationshipNode<Id> {
    /// Id of the node this vertex stands for.
    id: Id,
    /// Number of incoming edges whose source node has not been compiled yet.
    /// The node becomes ready to compile once this reaches zero.
    missing_dependencies: usize,
}

/// State of an in-progress graph compilation.
///
/// `new` validates the connections and builds the dependency graph.
/// `compile` then moves every entry of `pending_instructions` into
/// `instructions`, allocating `buffers` and filling `lookup` as it goes.
pub(crate) struct Compilation<Id: Hash + Eq + Copy, N: Node> {
    // ---- compiled output ----
    /// Compiled instructions, in the order they were compiled.
    instructions: Vec<Instruction<N>>,
    /// Maps a node id to the position of its instruction in `instructions`.
    lookup: HashMap<Id, usize>,
    /// One buffer per output socket of every compiled node.
    buffers: Vec<N::Data>,

    // ---- compilation input ----
    /// Dependency graph: one vertex per node, one edge per connection.
    relationships: petgraph::Graph<RelationshipNode<Id>, EdgeData>,
    /// Nodes that have not been compiled yet, keyed by id.
    pending_instructions: HashMap<Id, PendingInstruction<N>>,
}

impl<Id: Hash + Eq + Copy, N: Node> Compilation<Id, N> {
    pub fn new(
        nodes: HashMap<Id, PendingInstruction<N>>,
        connections: Vec<Connection<Id>>,
    ) -> Result<Self, BuildError<Id>> {
        let mut relationships = petgraph::Graph::with_capacity(nodes.len(), connections.len());
        // Maps Id to corresponding node in `relationships`
        let mut tmp_graph_lookup = HashMap::with_capacity(nodes.len());

        for (&id, _) in nodes.iter() {
            let node_index = relationships.add_node(RelationshipNode {
                id,
                missing_dependencies: 0,
            });
            tmp_graph_lookup.insert(id, node_index);
        }

        for connection in connections.into_iter() {
            let Some(&from_index) = tmp_graph_lookup.get(&connection.from_node) else {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidFromNode,
                ))
            };

            let Some(&to_index) = tmp_graph_lookup.get(&connection.to_node) else {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidToNode,
                ))
            };

            if connection.from_output_socket
                >= nodes
                    .get(&connection.from_node)
                    .unwrap()
                    .sockets
                    .outputs
                    .len()
            {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidFromOutputSocket,
                ));
            }

            if connection.to_input_socket
                >= nodes.get(&connection.to_node).unwrap().sockets.inputs.len()
            {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidToInputSocket,
                ));
            }

            relationships.add_edge(
                from_index,
                to_index,
                EdgeData {
                    from_output_socket: connection.from_output_socket,
                    to_input_socket: connection.to_input_socket,
                },
            );
            relationships
                .node_weight_mut(to_index)
                .unwrap()
                .missing_dependencies += 1;
        }

        Ok(Self {
            instructions: Vec::with_capacity(nodes.len()),
            buffers: Vec::new(),
            lookup: HashMap::with_capacity(nodes.len()),

            relationships,
            pending_instructions: nodes,
        })
    }

    pub fn compile(mut self) -> Result<CompiledGraph<Id, N>, BuildError<Id>> {
        // Nodes whose dependencies have been satisfied - compilation queue.
        // Init with `externals(Direction::Incoming)`: nodes with no dependencies
        let mut visit_queue: VecDeque<_> =
            self.relationships.externals(Direction::Incoming).collect();

        while let Some(node_index) = visit_queue.pop_front() {
            let w = self.relationships.node_weight(node_index).unwrap();
            debug_assert!(w.missing_dependencies == 0);
            let id = w.id;

            let instruction = self.compile_node(node_index, id)?;
            self.instructions.push(instruction);
            self.lookup.insert(id, self.instructions.len() - 1);

            self.update_dependents(&mut visit_queue, node_index);
        }

        self.ensure_compilation_complete()?;

        Ok(CompiledGraph {
            instructions: self.instructions,
            buffers: self.buffers,
            lookup: self.lookup,
        })
    }

    fn update_dependents(&mut self, visit_queue: &mut VecDeque<NodeIndex>, node_index: NodeIndex) {
        // nodes that depend on the just-compiled node
        // (need to collect due to borrow checker constraints)
        let dependents: Vec<_> = self
            .relationships
            .edges_directed(node_index, Direction::Outgoing)
            .map(|e| e.target())
            .collect();

        // decrement their missing dependency count,
        // and enqueue them for compilation when all dependencies have been satisfied
        for dependent in dependents {
            let w = self.relationships.node_weight_mut(dependent).unwrap();
            w.missing_dependencies -= 1;

            if w.missing_dependencies == 0 {
                visit_queue.push_back(dependent)
            }
        }
    }

    fn compile_node(
        &mut self,
        index: NodeIndex,
        node_id: Id,
    ) -> Result<Instruction<N>, BuildError<Id>> {
        let PendingInstruction { node, sockets } =
            self.pending_instructions.remove(&node_id).unwrap();

        let socket_inputs = self.collect_inputs(index, node_id, sockets.inputs)?;

        let socket_outputs = sockets
            .outputs
            .into_iter()
            .map(|socket| {
                self.buffers.push(socket.buffer);
                let buffer_index = self.buffers.len() - 1;

                BufferAddress {
                    index: buffer_index,
                    type_: socket.type_,
                }
            })
            .collect();

        Ok(Instruction {
            node,
            socket_inputs,
            socket_outputs,
        })
    }

    fn collect_inputs(
        &self,
        node_index: NodeIndex,
        node_id: Id,
        input_types: Vec<TypeId>,
    ) -> Result<Vec<BufferAddress>, BuildError<Id>> {
        let mut inputs: Vec<_> = repeat_with(|| Vec::new()).take(input_types.len()).collect();

        for edge in self
            .relationships
            .edges_directed(node_index, Direction::Incoming)
        {
            let EdgeData {
                from_output_socket,
                to_input_socket,
            } = *edge.weight();

            let source_node_id = self.relationships.node_weight(edge.source()).unwrap().id;

            let source_node_index = self.lookup[&source_node_id];
            let buffer_address = self.instructions[source_node_index]
                .output_address(from_output_socket)
                .clone();

            if input_types[to_input_socket] != buffer_address.type_ {
                return Err(BuildError::TypeMismatch(Connection {
                    from_node: source_node_id,
                    from_output_socket,
                    to_node: node_id,
                    to_input_socket,
                }));
            }

            inputs[to_input_socket].push(buffer_address);
        }

        inputs
            .into_iter()
            .enumerate()
            .map(|(input_index, sources)| match sources.len() {
                0 => Err(BuildError::MissingInput(InputSocket {
                    node: node_id,
                    socket_index: input_index,
                })),
                1 => Ok(sources[0]),
                _ => Err(BuildError::MultipleInputs(InputSocket {
                    node: node_id,
                    socket_index: input_index,
                })),
            })
            .collect()
    }

    fn ensure_compilation_complete(&self) -> Result<(), BuildError<Id>> {
        if self.pending_instructions.is_empty() {
            return Ok(());
        }

        let mut trace = Vec::new();

        // Find a node that has missing dependencies.
        // Since there exists a node that was not compiled, it must be missing a dependency.
        let (mut current_id, mut current_weight) = self
            .relationships
            .node_references()
            .find(|(_, node)| node.missing_dependencies != 0)
            .unwrap();

        loop {
            // A node with nonzero missing_dependencies will have an edge from an uncompiled node. Find it.
            let (edge, dependency_node_weight) = self
                .relationships
                .edges_directed(current_id, Direction::Incoming)
                .find_map(|edge| {
                    let source_id = edge.source();
                    let w = self.relationships.node_weight(source_id).unwrap();

                    (w.missing_dependencies != 0).then(|| (edge, w))
                })
                .unwrap();

            let edge_weight = edge.weight();
            trace.push(Connection {
                from_node: dependency_node_weight.id,
                from_output_socket: edge_weight.from_output_socket,
                to_node: current_weight.id,
                to_input_socket: edge_weight.to_input_socket,
            });

            if let Some(cycle_start_index) = trace
                .iter()
                .position(|connection| connection.to_node == dependency_node_weight.id)
            {
                // loop has completed, current edge leads back to cycle_start_index (in reverse)

                trace.drain(0..cycle_start_index);
                return Err(BuildError::CircularDependencies(trace));
            }

            current_weight = dependency_node_weight;
            current_id = edge.source();
        }
    }
}