firewheel_graph/
graph.rs

1use core::any::Any;
2use core::fmt::Debug;
3use core::hash::Hash;
4
5#[cfg(not(feature = "std"))]
6use bevy_platform::prelude::{Box, Vec};
7
8use bevy_platform::collections::HashMap;
9use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
10use firewheel_core::event::NodeEvent;
11use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
12use firewheel_core::StreamInfo;
13use smallvec::SmallVec;
14use thunderdome::Arena;
15
16use crate::error::{AddEdgeError, CompileGraphError, RemoveNodeError};
17use crate::graph::dummy_node::{DummyNode, DummyNodeConfig};
18use crate::FirewheelConfig;
19use firewheel_core::node::{
20    AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
21};
22
23pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
24
25pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
26
27mod compiler;
28mod dummy_node;
29
30#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
31struct EdgeHash {
32    pub src_node: NodeID,
33    pub dst_node: NodeID,
34    pub src_port: PortIdx,
35    pub dst_port: PortIdx,
36}
37
/// The audio graph interface.
///
/// Owns every node and edge in the graph and tracks which edges already
/// exist, so duplicate connections can be detected. It also collects the
/// bookkeeping consumed when a new schedule is compiled: removed nodes,
/// nodes needing `update` calls, and the last node-arena capacity.
pub(crate) struct AudioGraph {
    /// All nodes currently in the graph, including the graph input and
    /// graph output nodes inserted by `AudioGraph::new`.
    nodes: Arena<NodeEntry>,
    /// All edges currently in the graph.
    edges: Arena<Edge>,
    /// Lookup from an edge's endpoints to its ID. Used to de-duplicate
    /// connections and to locate edges to remove.
    existing_edges: HashMap<EdgeHash, EdgeID>,

    /// ID of the graph input node.
    graph_in_id: NodeID,
    /// ID of the graph output node.
    graph_out_id: NodeID,
    /// Set whenever the graph is modified (and by `deactivate` or a failed
    /// schedule send); cleared at the end of a successful `compile`.
    needs_compile: bool,

    /// IDs of nodes removed since the last `compile`; drained into the next
    /// compiled schedule.
    nodes_to_remove_from_schedule: Vec<NodeID>,
    /// Entries of removed nodes, keyed by the ID they had in the graph.
    // NOTE(review): nothing in this file reads or drains this map; presumably
    // it is consumed elsewhere — confirm it does not grow without bound.
    active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
    /// Nodes whose `update` method is called by `AudioGraph::update`. IDs of
    /// nodes that were since removed are culled lazily there.
    nodes_to_call_update_method: Vec<NodeID>,

    /// Node-arena capacity at the last `compile`. Used to decide whether a
    /// larger arena must be sent along with the next schedule.
    prev_node_arena_capacity: usize,
}
54
55impl AudioGraph {
56    pub fn new(config: &FirewheelConfig) -> Self {
57        let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
58
59        let graph_in_config = DummyNodeConfig {
60            channel_config: ChannelConfig {
61                num_inputs: ChannelCount::ZERO,
62                num_outputs: config.num_graph_inputs,
63            },
64        };
65        let graph_out_config = DummyNodeConfig {
66            channel_config: ChannelConfig {
67                num_inputs: config.num_graph_outputs,
68                num_outputs: ChannelCount::ZERO,
69            },
70        };
71
72        let graph_in_id = NodeID(
73            nodes.insert(NodeEntry::new(
74                AudioNodeInfo::new()
75                    .debug_name("graph_in")
76                    .channel_config(graph_in_config.channel_config)
77                    .into(),
78                Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
79            )),
80        );
81        nodes[graph_in_id.0].id = graph_in_id;
82
83        let graph_out_id = NodeID(
84            nodes.insert(NodeEntry::new(
85                AudioNodeInfo::new()
86                    .debug_name("graph_out")
87                    .channel_config(graph_out_config.channel_config)
88                    .into(),
89                Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
90            )),
91        );
92        nodes[graph_out_id.0].id = graph_out_id;
93
94        Self {
95            nodes,
96            edges: Arena::with_capacity(config.initial_edge_capacity as usize),
97            existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
98            graph_in_id,
99            graph_out_id,
100            needs_compile: true,
101            nodes_to_remove_from_schedule: Vec::with_capacity(
102                config.initial_node_capacity as usize,
103            ),
104            active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
105            nodes_to_call_update_method: Vec::new(),
106            prev_node_arena_capacity: 0,
107        }
108    }
109
110    /// The ID of the graph input node
111    pub fn graph_in_node(&self) -> NodeID {
112        self.graph_in_id
113    }
114
115    /// The ID of the graph output node
116    pub fn graph_out_node(&self) -> NodeID {
117        self.graph_out_id
118    }
119
120    /// Add a node to the audio graph.
121    pub fn add_node<T: AudioNode + 'static>(
122        &mut self,
123        node: T,
124        config: Option<T::Configuration>,
125    ) -> NodeID {
126        let constructor = Constructor::new(node, config);
127        let info: AudioNodeInfoInner = constructor.info().into();
128        let call_update_method = info.call_update_method;
129
130        let new_id = NodeID(
131            self.nodes
132                .insert(NodeEntry::new(info, Box::new(constructor))),
133        );
134        self.nodes[new_id.0].id = new_id;
135
136        if call_update_method {
137            self.nodes_to_call_update_method.push(new_id);
138        }
139
140        self.needs_compile = true;
141
142        new_id
143    }
144
145    /// Add a node to the audio graph which implements the type-erased [`DynAudioNode`] trait.
146    pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
147        let info: AudioNodeInfoInner = node.info().into();
148        let call_update_method = info.call_update_method;
149
150        let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
151        self.nodes[new_id.0].id = new_id;
152
153        if call_update_method {
154            self.nodes_to_call_update_method.push(new_id);
155        }
156
157        self.needs_compile = true;
158
159        new_id
160    }
161
162    /// Remove the given node from the audio graph.
163    ///
164    /// This will automatically remove all edges from the graph that
165    /// were connected to this node.
166    ///
167    /// On success, this returns a list of all edges that were removed
168    /// from the graph as a result of removing this node.
169    ///
170    /// This will return an error if the ID is of the graph input or graph
171    /// output node.
172    pub fn remove_node(
173        &mut self,
174        node_id: NodeID,
175    ) -> Result<SmallVec<[EdgeID; 4]>, RemoveNodeError> {
176        if node_id == self.graph_in_id {
177            return Err(RemoveNodeError::CannotRemoveGraphInNode);
178        }
179        if node_id == self.graph_out_id {
180            return Err(RemoveNodeError::CannotRemoveGraphOutNode);
181        }
182
183        let mut removed_edges = SmallVec::new();
184
185        let Some(node_entry) = self.nodes.remove(node_id.0) else {
186            return Ok(removed_edges);
187        };
188
189        for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
190            removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
191        }
192        for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
193            removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
194        }
195
196        self.nodes_to_remove_from_schedule.push(node_id);
197        self.active_nodes_to_remove.insert(node_id, node_entry);
198
199        self.needs_compile = true;
200
201        Ok(removed_edges)
202    }
203
204    /// Get information about a node in the graph.
205    pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
206        self.nodes.get(id.0)
207    }
208
209    /// Get an immutable reference to the custom state of a node.
210    pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
211        self.node_state_dyn(id).and_then(|s| s.downcast_ref())
212    }
213
214    /// Get a type-erased, immutable reference to the custom state of a node.
215    pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
216        self.nodes
217            .get(id.0)
218            .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
219    }
220
221    /// Get a mutable reference to the custom state of a node.
222    pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
223        self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
224    }
225
226    /// Get a type-erased, mutable reference to the custom state of a node.
227    pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
228        self.nodes
229            .get_mut(id.0)
230            .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
231    }
232
233    /// Get a list of all the existing nodes in the graph.
234    pub fn nodes<'a>(&'a self) -> impl Iterator<Item = &'a NodeEntry> {
235        self.nodes.iter().map(|(_, n)| n)
236    }
237
238    /// Get a list of all the existing edges in the graph.
239    pub fn edges<'a>(&'a self) -> impl Iterator<Item = &'a Edge> {
240        self.edges.iter().map(|(_, e)| e)
241    }
242
243    /// Set the number of input and output channels to and from the audio graph.
244    ///
245    /// Returns the list of edges that were removed.
246    pub fn set_graph_channel_config(
247        &mut self,
248        channel_config: ChannelConfig,
249    ) -> SmallVec<[EdgeID; 4]> {
250        let mut removed_edges = SmallVec::new();
251
252        let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
253        if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
254            let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
255            graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
256
257            if channel_config.num_inputs < old_num_inputs {
258                for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
259                    removed_edges.append(
260                        &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
261                    );
262                }
263            }
264
265            self.needs_compile = true;
266        }
267
268        let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
269
270        if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
271            let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
272            graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
273
274            if channel_config.num_outputs < old_num_outputs {
275                for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
276                    removed_edges.append(
277                        &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
278                    );
279                }
280            }
281
282            self.needs_compile = true;
283        }
284
285        removed_edges
286    }
287
288    /// Add connections (edges) between two nodes to the graph.
289    ///
290    /// * `src_node` - The ID of the source node.
291    /// * `dst_node` - The ID of the destination node.
292    /// * `ports_src_dst` - The port indices for each connection to make,
293    /// where the first value in a tuple is the output port on `src_node`,
294    /// and the second value in that tuple is the input port on `dst_node`.
295    /// * `check_for_cycles` - If `true`, then this will run a check to
296    /// see if adding these edges will create a cycle in the graph, and
297    /// return an error if it does. Note, checking for cycles can be quite
298    /// expensive, so avoid enabling this when calling this method many times
299    /// in a row.
300    ///
301    /// If successful, then this returns a list of edge IDs in order.
302    ///
303    /// If this returns an error, then the audio graph has not been
304    /// modified.
305    pub fn connect(
306        &mut self,
307        src_node: NodeID,
308        dst_node: NodeID,
309        ports_src_dst: &[(PortIdx, PortIdx)],
310        check_for_cycles: bool,
311    ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
312        let src_node_entry = self
313            .nodes
314            .get(src_node.0)
315            .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
316        let dst_node_entry = self
317            .nodes
318            .get(dst_node.0)
319            .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
320
321        if src_node.0 == dst_node.0 {
322            return Err(AddEdgeError::CycleDetected);
323        }
324
325        for (src_port, dst_port) in ports_src_dst.iter().copied() {
326            if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
327                return Err(AddEdgeError::OutPortOutOfRange {
328                    node: src_node,
329                    port_idx: src_port,
330                    num_out_ports: src_node_entry.info.channel_config.num_outputs,
331                });
332            }
333            if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
334                return Err(AddEdgeError::InPortOutOfRange {
335                    node: dst_node,
336                    port_idx: dst_port,
337                    num_in_ports: dst_node_entry.info.channel_config.num_inputs,
338                });
339            }
340        }
341
342        let mut edge_ids = SmallVec::new();
343
344        for (src_port, dst_port) in ports_src_dst.iter().copied() {
345            if let Some(id) = self.existing_edges.get(&EdgeHash {
346                src_node,
347                src_port,
348                dst_node,
349                dst_port,
350            }) {
351                // The caller gave us more than one of the same edge.
352                edge_ids.push(*id);
353                continue;
354            }
355
356            let new_edge_id = EdgeID(self.edges.insert(Edge {
357                id: EdgeID(thunderdome::Index::DANGLING),
358                src_node,
359                src_port,
360                dst_node,
361                dst_port,
362            }));
363            self.edges[new_edge_id.0].id = new_edge_id;
364            self.existing_edges.insert(
365                EdgeHash {
366                    src_node,
367                    src_port,
368                    dst_node,
369                    dst_port,
370                },
371                new_edge_id,
372            );
373
374            edge_ids.push(new_edge_id);
375        }
376
377        if check_for_cycles {
378            if self.cycle_detected() {
379                self.disconnect(src_node, dst_node, ports_src_dst);
380
381                return Err(AddEdgeError::CycleDetected);
382            }
383        }
384
385        self.needs_compile = true;
386
387        Ok(edge_ids)
388    }
389
390    /// Remove connections (edges) between two nodes from the graph.
391    ///
392    /// * `src_node` - The ID of the source node.
393    /// * `dst_node` - The ID of the destination node.
394    /// * `ports_src_dst` - The port indices for each connection to make,
395    /// where the first value in a tuple is the output port on `src_node`,
396    /// and the second value in that tuple is the input port on `dst_node`.
397    ///
398    /// If none of the edges existed in the graph, then `false` will be
399    /// returned.
400    pub fn disconnect(
401        &mut self,
402        src_node: NodeID,
403        dst_node: NodeID,
404        ports_src_dst: &[(PortIdx, PortIdx)],
405    ) -> bool {
406        let mut any_removed = false;
407
408        for (src_port, dst_port) in ports_src_dst.iter().copied() {
409            if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
410                src_node,
411                src_port: src_port.into(),
412                dst_node,
413                dst_port: dst_port.into(),
414            }) {
415                self.disconnect_by_edge_id(edge_id);
416                any_removed = true;
417            }
418        }
419
420        any_removed
421    }
422
423    /// Remove all connections (edges) between two nodes in the graph.
424    ///
425    /// * `src_node` - The ID of the source node.
426    /// * `dst_node` - The ID of the destination node.
427    pub fn disconnect_all_between(
428        &mut self,
429        src_node: NodeID,
430        dst_node: NodeID,
431    ) -> SmallVec<[EdgeID; 4]> {
432        let mut removed_edges = SmallVec::new();
433
434        if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
435            return removed_edges;
436        };
437
438        for (edge_id, edge) in self.edges.iter() {
439            if edge.src_node == src_node && edge.dst_node == dst_node {
440                removed_edges.push(EdgeID(edge_id));
441            }
442        }
443
444        for &edge_id in removed_edges.iter() {
445            self.disconnect_by_edge_id(edge_id);
446        }
447
448        removed_edges
449    }
450
451    /// Remove a connection (edge) via the edge's unique ID.
452    ///
453    /// If the edge did not exist in this graph, then `false` will be returned.
454    pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
455        if let Some(edge) = self.edges.remove(edge_id.0) {
456            self.existing_edges.remove(&EdgeHash {
457                src_node: edge.src_node,
458                src_port: edge.src_port,
459                dst_node: edge.dst_node,
460                dst_port: edge.dst_port,
461            });
462
463            self.needs_compile = true;
464
465            true
466        } else {
467            false
468        }
469    }
470
471    /// Get information about the given [Edge]
472    pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
473        self.edges.get(edge_id.0)
474    }
475
476    fn remove_edges_with_input_port(
477        &mut self,
478        node_id: NodeID,
479        port_idx: PortIdx,
480    ) -> SmallVec<[EdgeID; 4]> {
481        let mut edges_to_remove = SmallVec::new();
482
483        // Remove all existing edges which have this port.
484        for (edge_id, edge) in self.edges.iter() {
485            if edge.dst_node == node_id && edge.dst_port == port_idx {
486                edges_to_remove.push(EdgeID(edge_id));
487            }
488        }
489
490        for edge_id in edges_to_remove.iter() {
491            self.disconnect_by_edge_id(*edge_id);
492        }
493
494        edges_to_remove
495    }
496
497    fn remove_edges_with_output_port(
498        &mut self,
499        node_id: NodeID,
500        port_idx: PortIdx,
501    ) -> SmallVec<[EdgeID; 4]> {
502        let mut edges_to_remove = SmallVec::new();
503
504        // Remove all existing edges which have this port.
505        for (edge_id, edge) in self.edges.iter() {
506            if edge.src_node == node_id && edge.src_port == port_idx {
507                edges_to_remove.push(EdgeID(edge_id));
508            }
509        }
510
511        for edge_id in edges_to_remove.iter() {
512            self.disconnect_by_edge_id(*edge_id);
513        }
514
515        edges_to_remove
516    }
517
518    pub fn cycle_detected(&mut self) -> bool {
519        compiler::cycle_detected(
520            &mut self.nodes,
521            &mut self.edges,
522            self.graph_in_id,
523            self.graph_out_id,
524        )
525    }
526
527    pub(crate) fn needs_compile(&self) -> bool {
528        self.needs_compile
529    }
530
531    pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
532        self.needs_compile = true;
533
534        for node in failed_schedule.new_node_processors.iter() {
535            if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
536                node_entry.processor_constructed = false;
537            }
538        }
539    }
540
541    pub(crate) fn deactivate(&mut self) {
542        self.needs_compile = true;
543    }
544
545    pub(crate) fn compile(
546        &mut self,
547        stream_info: &StreamInfo,
548    ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
549        let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
550
551        let mut new_node_processors = Vec::new();
552        for (_, entry) in self.nodes.iter_mut() {
553            if !entry.processor_constructed {
554                entry.processor_constructed = true;
555
556                let cx = ConstructProcessorContext::new(
557                    entry.id,
558                    stream_info,
559                    &mut entry.info.custom_state,
560                );
561
562                new_node_processors.push(NodeHeapData {
563                    id: entry.id,
564                    processor: entry.dyn_node.construct_processor(cx),
565                    is_pre_process: entry.info.channel_config.is_empty(),
566                });
567            }
568        }
569
570        let mut nodes_to_remove = Vec::new();
571        core::mem::swap(
572            &mut self.nodes_to_remove_from_schedule,
573            &mut nodes_to_remove,
574        );
575
576        let new_arena = if self.nodes.capacity() > self.prev_node_arena_capacity {
577            Some(Arena::with_capacity(self.nodes.capacity()))
578        } else {
579            None
580        };
581        self.prev_node_arena_capacity = self.nodes.capacity();
582
583        let schedule_data = Box::new(ScheduleHeapData::new(
584            schedule,
585            nodes_to_remove,
586            new_node_processors,
587            new_arena,
588        ));
589
590        self.needs_compile = false;
591
592        log::debug!("compiled new audio graph: {:?}", &schedule_data);
593
594        Ok(schedule_data)
595    }
596
597    fn compile_internal(
598        &mut self,
599        max_block_frames: usize,
600    ) -> Result<CompiledSchedule, CompileGraphError> {
601        assert!(max_block_frames > 0);
602
603        compiler::compile(
604            &mut self.nodes,
605            &mut self.edges,
606            self.graph_in_id,
607            self.graph_out_id,
608            max_block_frames,
609        )
610    }
611
612    pub(crate) fn update(
613        &mut self,
614        stream_info: Option<&StreamInfo>,
615        event_queue: &mut Vec<NodeEvent>,
616    ) {
617        let mut cull_list = false;
618        for node_id in self.nodes_to_call_update_method.iter() {
619            if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
620                node_entry.dyn_node.update(UpdateContext::new(
621                    *node_id,
622                    stream_info,
623                    &mut node_entry.info.custom_state,
624                    event_queue,
625                ));
626            } else {
627                cull_list = true;
628            }
629        }
630
631        if cull_list {
632            self.nodes_to_call_update_method
633                .retain(|node_id| self.nodes.contains(node_id.0));
634        }
635    }
636}