use crate::builder::descriptors::{Connection, PendingInstruction};
use crate::builder::error::{BuildError, InputSocket, InvalidConnection};
use crate::graph::BufferAddress;
use crate::{CompiledGraph, Instruction, Node};
use petgraph::prelude::{EdgeRef, NodeIndex};
use petgraph::visit::IntoNodeReferences;
use petgraph::Direction;
use std::any::TypeId;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::iter::repeat_with;
/// Weight of a dependency edge: records which output socket of the edge's source node
/// feeds which input socket of the edge's target node.
struct EdgeData {
    /// Index into the source node's output sockets.
    from_output_socket: usize,
    /// Index into the target node's input sockets.
    to_input_socket: usize,
}
/// Vertex weight of the dependency graph.
struct RelationshipNode<Id> {
    /// Caller-supplied identifier of the node this vertex represents.
    id: Id,
    /// Number of incoming connections whose source node has not been compiled yet;
    /// the node becomes schedulable once this counter reaches zero.
    missing_dependencies: usize,
}
/// In-progress compilation of a node graph into a flat, dependency-ordered instruction list.
pub(crate) struct Compilation<Id, N>
where
    Id: Hash + Eq + Copy,
    N: Node,
{
    /// Dependency graph: one vertex per node, one edge per connection.
    relationships: petgraph::Graph<RelationshipNode<Id>, EdgeData>,
    /// Nodes that have not been compiled yet, keyed by id.
    pending_instructions: HashMap<Id, PendingInstruction<N>>,
    /// Compiled instructions, each placed after every instruction it reads from.
    instructions: Vec<Instruction<N>>,
    /// Output-socket buffers; a `BufferAddress::index` points into this vector.
    buffers: Vec<N::Data>,
    /// Node id -> index of that node's instruction in `instructions`.
    lookup: HashMap<Id, usize>,
}
impl<Id: Hash + Eq + Copy, N: Node> Compilation<Id, N> {
    /// Builds the dependency graph for `nodes` and validates every connection.
    ///
    /// Each node becomes a vertex. Each connection becomes an edge from its source node to
    /// its destination node, and the destination's `missing_dependencies` counter is
    /// incremented once per incoming edge.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::InvalidConnection`] for the first connection that names an
    /// unknown node or a socket index outside the node's declared sockets.
    pub fn new(
        nodes: HashMap<Id, PendingInstruction<N>>,
        connections: Vec<Connection<Id>>,
    ) -> Result<Self, BuildError<Id>> {
        let mut relationships = petgraph::Graph::with_capacity(nodes.len(), connections.len());
        let mut tmp_graph_lookup = HashMap::with_capacity(nodes.len());
        for &id in nodes.keys() {
            let node_index = relationships.add_node(RelationshipNode {
                id,
                missing_dependencies: 0,
            });
            tmp_graph_lookup.insert(id, node_index);
        }
        for connection in connections {
            let Some(&from_index) = tmp_graph_lookup.get(&connection.from_node) else {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidFromNode,
                ));
            };
            let Some(&to_index) = tmp_graph_lookup.get(&connection.to_node) else {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidToNode,
                ));
            };
            // Both ids were just found in `tmp_graph_lookup`, whose keys are exactly the keys
            // of `nodes`, so these map lookups cannot panic.
            let output_count = nodes[&connection.from_node].sockets.outputs.len();
            if connection.from_output_socket >= output_count {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidFromOutputSocket,
                ));
            }
            let input_count = nodes[&connection.to_node].sockets.inputs.len();
            if connection.to_input_socket >= input_count {
                return Err(BuildError::InvalidConnection(
                    connection,
                    InvalidConnection::InvalidToInputSocket,
                ));
            }
            relationships.add_edge(
                from_index,
                to_index,
                EdgeData {
                    from_output_socket: connection.from_output_socket,
                    to_input_socket: connection.to_input_socket,
                },
            );
            relationships
                .node_weight_mut(to_index)
                .expect("to_index was produced by add_node above")
                .missing_dependencies += 1;
        }
        Ok(Self {
            instructions: Vec::with_capacity(nodes.len()),
            buffers: Vec::new(),
            lookup: HashMap::with_capacity(nodes.len()),
            relationships,
            pending_instructions: nodes,
        })
    }

    /// Consumes the compilation and emits one instruction per node, ordered so that every
    /// node comes after all nodes it has connections from.
    ///
    /// # Errors
    ///
    /// Propagates [`BuildError::TypeMismatch`], [`BuildError::MissingInput`] and
    /// [`BuildError::MultipleInputs`] from node compilation, and returns
    /// [`BuildError::CircularDependencies`] when some nodes can never be scheduled.
    pub fn compile(mut self) -> Result<CompiledGraph<Id, N>, BuildError<Id>> {
        // Kahn's algorithm: start from nodes without incoming connections and schedule a
        // node once all of its incoming connections come from already-compiled nodes.
        let mut visit_queue: VecDeque<_> =
            self.relationships.externals(Direction::Incoming).collect();
        while let Some(node_index) = visit_queue.pop_front() {
            debug_assert!(self.relationships[node_index].missing_dependencies == 0);
            let id = self.relationships[node_index].id;
            let instruction = self.compile_node(node_index, id)?;
            self.lookup.insert(id, self.instructions.len());
            self.instructions.push(instruction);
            self.update_dependents(&mut visit_queue, node_index);
        }
        self.ensure_compilation_complete()?;
        Ok(CompiledGraph {
            instructions: self.instructions,
            buffers: self.buffers,
            lookup: self.lookup,
        })
    }

    /// Marks one incoming connection as satisfied on every direct dependent of `node_index`
    /// and enqueues each dependent whose last missing dependency was just satisfied.
    fn update_dependents(&mut self, visit_queue: &mut VecDeque<NodeIndex>, node_index: NodeIndex) {
        // Collected up front because the loop below needs a mutable borrow of the graph.
        // Parallel edges yield the same target more than once, matching how
        // `missing_dependencies` was incremented once per edge.
        let dependents: Vec<_> = self
            .relationships
            .edges_directed(node_index, Direction::Outgoing)
            .map(|e| e.target())
            .collect();
        for dependent in dependents {
            let weight = &mut self.relationships[dependent];
            weight.missing_dependencies -= 1;
            if weight.missing_dependencies == 0 {
                visit_queue.push_back(dependent);
            }
        }
    }

    /// Turns the pending instruction for `node_id` into a compiled [`Instruction`].
    ///
    /// Resolves the node's input sockets to upstream buffer addresses and appends each
    /// output socket's buffer to `self.buffers`, addressing it by its new index.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::collect_inputs`].
    fn compile_node(
        &mut self,
        index: NodeIndex,
        node_id: Id,
    ) -> Result<Instruction<N>, BuildError<Id>> {
        let PendingInstruction { node, sockets } = self
            .pending_instructions
            .remove(&node_id)
            .expect("each node is scheduled exactly once, so its pending instruction exists");
        let socket_inputs = self.collect_inputs(index, node_id, sockets.inputs)?;
        let socket_outputs = sockets
            .outputs
            .into_iter()
            .map(|socket| {
                let buffer_index = self.buffers.len();
                self.buffers.push(socket.buffer);
                BufferAddress {
                    index: buffer_index,
                    type_: socket.type_,
                }
            })
            .collect();
        Ok(Instruction {
            node,
            socket_inputs,
            socket_outputs,
        })
    }

    /// Resolves every input socket of the node at `node_index` to the buffer address of the
    /// single upstream output connected to it.
    ///
    /// All upstream nodes must already be compiled (present in `self.lookup`), which the
    /// topological scheduling in [`Self::compile`] guarantees.
    ///
    /// # Errors
    ///
    /// - [`BuildError::TypeMismatch`] if a connected output's type differs from the
    ///   input's declared `TypeId`. This is checked while scanning edges, so it is reported
    ///   before any missing or duplicate input.
    /// - [`BuildError::MissingInput`] if an input socket has no connection.
    /// - [`BuildError::MultipleInputs`] if an input socket has more than one connection.
    fn collect_inputs(
        &self,
        node_index: NodeIndex,
        node_id: Id,
        input_types: Vec<TypeId>,
    ) -> Result<Vec<BufferAddress>, BuildError<Id>> {
        // One bucket per input socket so missing and duplicate connections can be detected.
        let mut inputs: Vec<_> = repeat_with(Vec::new).take(input_types.len()).collect();
        for edge in self
            .relationships
            .edges_directed(node_index, Direction::Incoming)
        {
            let EdgeData {
                from_output_socket,
                to_input_socket,
            } = *edge.weight();
            let source_node_id = self.relationships[edge.source()].id;
            let source_instruction = &self.instructions[self.lookup[&source_node_id]];
            let buffer_address = source_instruction
                .output_address(from_output_socket)
                .clone();
            if input_types[to_input_socket] != buffer_address.type_ {
                return Err(BuildError::TypeMismatch(Connection {
                    from_node: source_node_id,
                    from_output_socket,
                    to_node: node_id,
                    to_input_socket,
                }));
            }
            inputs[to_input_socket].push(buffer_address);
        }
        inputs
            .into_iter()
            .enumerate()
            .map(|(socket_index, sources)| match sources.as_slice() {
                [single] => Ok(*single),
                [] => Err(BuildError::MissingInput(InputSocket {
                    node: node_id,
                    socket_index,
                })),
                _ => Err(BuildError::MultipleInputs(InputSocket {
                    node: node_id,
                    socket_index,
                })),
            })
            .collect()
    }

    /// Verifies that every node was compiled and reports a dependency cycle otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::CircularDependencies`] with the connections of one cycle. They
    /// are listed against edge direction: each entry's `from_node` is the next entry's
    /// `to_node`, and the last entry's `from_node` is the first entry's `to_node`.
    fn ensure_compilation_complete(&self) -> Result<(), BuildError<Id>> {
        if self.pending_instructions.is_empty() {
            return Ok(());
        }
        // After scheduling, a node is still pending exactly when its counter never reached
        // zero, which means at least one of its predecessors is still pending too. Walking
        // backwards along such predecessors therefore never gets stuck and, the graph being
        // finite, must revisit a node; the edges walked since that first visit form a cycle.
        let (mut current_index, mut current_weight) = self
            .relationships
            .node_references()
            .find(|(_, node)| node.missing_dependencies != 0)
            .expect("a pending node always has unresolved dependencies");
        let mut trace = Vec::new();
        // Node id -> position in `trace` of the connection leading into that node, so a
        // revisit is detected in O(1) instead of rescanning the trace each step.
        let mut trace_position: HashMap<Id, usize> = HashMap::new();
        loop {
            let (edge, dependency_weight) = self
                .relationships
                .edges_directed(current_index, Direction::Incoming)
                .find_map(|edge| {
                    let weight = &self.relationships[edge.source()];
                    (weight.missing_dependencies != 0).then_some((edge, weight))
                })
                .expect("a pending node always has a pending predecessor");
            let edge_weight = edge.weight();
            trace_position.insert(current_weight.id, trace.len());
            trace.push(Connection {
                from_node: dependency_weight.id,
                from_output_socket: edge_weight.from_output_socket,
                to_node: current_weight.id,
                to_input_socket: edge_weight.to_input_socket,
            });
            if let Some(&cycle_start) = trace_position.get(&dependency_weight.id) {
                // Drop the lead-in path that reached the cycle without being part of it.
                trace.drain(..cycle_start);
                return Err(BuildError::CircularDependencies(trace));
            }
            current_weight = dependency_weight;
            current_index = edge.source();
        }
    }
}