use crate::{
arena::{Arena, NodeId},
buffer_pool::BufferPool,
node::{DspNode, NodeRecord},
MAX_NODES,
};
use std::collections::HashMap;
/// A graph of DSP nodes together with a precomputed execution schedule.
///
/// Node records live in `arena`; connections are tracked both on the
/// destination record (its `inputs` slots) and in the private
/// `forward_edges` adjacency map. The schedule (`execution_order` and
/// `levels`) is recomputed by the graph's mutating methods.
pub struct DspGraph {
/// Storage for every node record, created with capacity `MAX_NODES`.
pub arena: Arena<NodeRecord>,
/// Pool from which each node's output buffer is acquired when the node is
/// added and to which it is released when the node is removed.
pub buffers: BufferPool,
/// Flat topological ordering of the nodes (sources before their
/// destinations), rebuilt from `forward_edges`.
pub execution_order: Vec<NodeId>,
/// The same ordering grouped into waves: each inner vector holds the nodes
/// whose in-degree reached zero in the same pass of the topological sort.
pub levels: Vec<Vec<NodeId>>,
/// Node designated via `set_output_node`; `None` until one is set.
pub output_node: Option<NodeId>,
// Adjacency list keyed by the source node's arena index; each entry is a
// `(destination id, destination input slot)` pair.
forward_edges: HashMap<u32, Vec<(NodeId, usize)>>,
// Maps an arena index back to the full `NodeId` of the node stored there.
index_to_id: HashMap<u32, NodeId>,
}
impl DspGraph {
    /// Creates an empty graph whose arena and schedule vectors are
    /// pre-sized for `MAX_NODES` entries.
    pub fn new() -> Self {
        Self {
            arena: Arena::with_capacity(MAX_NODES),
            buffers: BufferPool::default(),
            execution_order: Vec::with_capacity(MAX_NODES),
            levels: Vec::with_capacity(MAX_NODES),
            output_node: None,
            forward_edges: HashMap::new(),
            index_to_id: HashMap::new(),
        }
    }

    /// Adds a node wrapping `processor` and returns its id.
    ///
    /// Returns `None` when no output buffer can be acquired from the pool or
    /// when the arena refuses the insertion. On success the execution
    /// schedule is rebuilt.
    pub fn add_node(&mut self, processor: Box<dyn DspNode>) -> Option<NodeId> {
        let buf = self.buffers.acquire()?;
        let record = NodeRecord::new(processor, buf);
        // NOTE(review): if `insert` fails the record (and the buffer handle it
        // owns) is dropped without `buffers.release` being called. Whether
        // that leaks a pool slot depends on the buffer handle type, which is
        // not visible here — confirm.
        let id = self.arena.insert(record)?;
        self.forward_edges.insert(id.index, Vec::new());
        self.index_to_id.insert(id.index, id);
        self.rebuild_execution_order();
        Some(id)
    }

    /// Removes the node `id`, releasing its output buffer back to the pool.
    ///
    /// Every connection touching the node is dropped: edges pointing at it
    /// are removed from their sources, input slots of downstream nodes that
    /// were fed by it are cleared, and `output_node` is reset if it referred
    /// to the removed node. Returns `false` (and changes nothing) when `id`
    /// is not present in the arena.
    pub fn remove_node(&mut self, id: NodeId) -> bool {
        let record = match self.arena.remove(id) {
            Some(record) => record,
            None => return false,
        };
        self.buffers.release(record.output_buffer);
        self.index_to_id.remove(&id.index);

        // Clear the input slots this node was feeding so that no record keeps
        // a dangling reference to the removed id.
        let outgoing = self.forward_edges.remove(&id.index).unwrap_or_default();
        for (dst, slot) in outgoing {
            if let Some(downstream) = self.arena.get_mut(dst) {
                if let Some(input) = downstream.inputs.get_mut(slot) {
                    if *input == Some(id) {
                        *input = None;
                    }
                }
            }
        }

        // Drop every edge that targeted the removed node.
        for edges in self.forward_edges.values_mut() {
            edges.retain(|(dst, _)| dst.index != id.index);
        }

        if self.output_node == Some(id) {
            self.output_node = None;
        }

        self.rebuild_execution_order();
        true
    }

    /// Connects the output of `src` to input `slot` of `dst`.
    ///
    /// Any connection previously occupying that slot is replaced, including
    /// its forward edge, so the adjacency map never holds stale or duplicate
    /// edges. Returns `false` without modifying the graph when:
    ///
    /// * `src` or `dst` is not present in the arena,
    /// * `slot` is outside the destination's input range, or
    /// * the connection would create a cycle (including `src == dst`), which
    ///   the level-based schedule cannot represent.
    pub fn connect(&mut self, src: NodeId, dst: NodeId, slot: usize) -> bool {
        if self.arena.get(src).is_none() {
            return false;
        }
        let slot_exists = self
            .arena
            .get(dst)
            .map_or(false, |record| record.inputs.get(slot).is_some());
        // A path dst -> ... -> src means the new edge src -> dst closes a loop.
        if !slot_exists || self.is_reachable(dst, src) {
            return false;
        }

        let previous = match self
            .arena
            .get_mut(dst)
            .and_then(|record| record.inputs.get_mut(slot))
        {
            Some(input) => input.replace(src),
            None => return false,
        };
        if let Some(old_src) = previous {
            self.remove_edge(old_src, dst, slot);
        }
        self.forward_edges
            .entry(src.index)
            .or_default()
            .push((dst, slot));

        self.rebuild_execution_order();
        true
    }

    /// Clears input `slot` of `dst`, removing the matching forward edge.
    ///
    /// Returns `true` when `dst` exists and `slot` is in range (even if the
    /// slot was already empty), and `false` otherwise. The schedule is only
    /// rebuilt when a connection was actually removed.
    pub fn disconnect(&mut self, dst: NodeId, slot: usize) -> bool {
        let previous = match self
            .arena
            .get_mut(dst)
            .and_then(|record| record.inputs.get_mut(slot))
        {
            Some(input) => input.take(),
            None => return false,
        };
        if let Some(src) = previous {
            self.remove_edge(src, dst, slot);
            self.rebuild_execution_order();
        }
        true
    }

    /// Removes the forward edge `src -> (dst, slot)` if it is recorded.
    fn remove_edge(&mut self, src: NodeId, dst: NodeId, slot: usize) {
        if let Some(edges) = self.forward_edges.get_mut(&src.index) {
            edges.retain(|(d, s)| !(d.index == dst.index && *s == slot));
        }
    }

    /// Returns `true` when `target` can be reached from `from` by following
    /// forward edges (a node is considered reachable from itself).
    fn is_reachable(&self, from: NodeId, target: NodeId) -> bool {
        let mut visited: std::collections::HashSet<u32> = std::collections::HashSet::new();
        let mut stack = vec![from.index];
        while let Some(idx) = stack.pop() {
            if idx == target.index {
                return true;
            }
            if !visited.insert(idx) {
                continue;
            }
            if let Some(edges) = self.forward_edges.get(&idx) {
                stack.extend(edges.iter().map(|(next, _)| next.index));
            }
        }
        false
    }

    /// Recomputes `execution_order` and `levels` from the forward edges.
    ///
    /// Nodes are peeled off in waves: the first wave contains every node
    /// with no incoming edge, and each following wave contains the nodes
    /// whose last remaining incoming edge came from the previous waves.
    /// Each wave is sorted by arena index so the schedule does not depend on
    /// `HashMap` iteration order. Nodes that never reach in-degree zero
    /// (i.e. nodes on or behind a cycle) are left out of the schedule.
    fn rebuild_execution_order(&mut self) {
        self.execution_order.clear();
        self.levels.clear();

        let mut in_degree: HashMap<u32, usize> =
            self.index_to_id.keys().map(|&k| (k, 0)).collect();
        for (dst, _) in self.forward_edges.values().flatten() {
            *in_degree.entry(dst.index).or_insert(0) += 1;
        }

        let mut current_wave: Vec<u32> = in_degree
            .iter()
            .filter(|&(_, &deg)| deg == 0)
            .map(|(&idx, _)| idx)
            .collect();
        current_wave.sort_unstable();

        while !current_wave.is_empty() {
            let mut level_ids: Vec<NodeId> = Vec::with_capacity(current_wave.len());
            let mut next_wave: Vec<u32> = Vec::new();
            for idx in &current_wave {
                if let Some(&id) = self.index_to_id.get(idx) {
                    level_ids.push(id);
                    self.execution_order.push(id);
                }
                if let Some(edges) = self.forward_edges.get(idx) {
                    for (dst, _) in edges {
                        if let Some(deg) = in_degree.get_mut(&dst.index) {
                            if *deg > 0 {
                                *deg -= 1;
                                if *deg == 0 {
                                    next_wave.push(dst.index);
                                }
                            }
                        }
                    }
                }
            }
            self.levels.push(level_ids);
            next_wave.sort_unstable();
            current_wave = next_wave;
        }
    }

    /// Records `id` as the graph's output node. The id is stored as given;
    /// it is not checked against the arena.
    pub fn set_output_node(&mut self, id: NodeId) {
        self.output_node = Some(id);
    }
}
impl Default for DspGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{node::DspNode, param::ParamBlock, BUFFER_SIZE, MAX_INPUTS};
    use proptest::prelude::*;

    /// Minimal processor used to populate the graph; it ignores its inputs
    /// and writes zeros to its output.
    struct TestNode;

    impl DspNode for TestNode {
        fn process(
            &mut self,
            _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            *output = [0.0; BUFFER_SIZE];
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }

    proptest! {
        /// For every connection that is still wired after all `connect`
        /// calls, the source must sit on a strictly earlier level than the
        /// destination. Only edges from a lower to a higher node position
        /// are used, so the generated graph is always acyclic.
        #[test]
        fn prop_topological_level_ordering_invariant(
            num_nodes in 1usize..=20,
            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50)
        ) {
            let mut graph = DspGraph::new();
            let node_ids: Vec<NodeId> = (0..num_nodes)
                .filter_map(|_| graph.add_node(Box::new(TestNode)))
                .collect();

            // Keep only in-range edges that point "forward" in creation order.
            let usable: Vec<(usize, usize, usize)> = edges
                .iter()
                .copied()
                .filter(|&(s, d, _)| s < num_nodes && d < num_nodes && s < d)
                .collect();

            for &(s, d, slot) in &usable {
                graph.connect(node_ids[s], node_ids[d], slot);
            }

            let level_of: HashMap<u32, usize> = graph
                .levels
                .iter()
                .enumerate()
                .flat_map(|(lvl, members)| members.iter().map(move |id| (id.index, lvl)))
                .collect();

            for &(s, d, slot) in &usable {
                let src = node_ids[s];
                let dst = node_ids[d];

                // A later connect may have overwritten this slot; skip those.
                let still_wired = graph
                    .arena
                    .get(dst)
                    .map_or(false, |record| record.inputs[slot] == Some(src));
                if !still_wired {
                    continue;
                }

                let src_level = level_of.get(&src.index).copied();
                let dst_level = level_of.get(&dst.index).copied();
                if let (Some(src_lvl), Some(dst_lvl)) = (src_level, dst_level) {
                    prop_assert!(
                        src_lvl < dst_lvl,
                        "Level ordering violated: node {} at level {} → node {} at level {}",
                        src.index, src_lvl, dst.index, dst_lvl
                    );
                }
            }
        }
    }
}