firewheel_graph/
graph.rs

1mod compiler;
2
3use core::any::Any;
4use core::fmt::Debug;
5use core::hash::Hash;
6
7#[cfg(not(feature = "std"))]
8use bevy_platform::prelude::{Box, Vec};
9
10use bevy_platform::collections::HashMap;
11use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
12use firewheel_core::event::NodeEvent;
13use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
14use firewheel_core::StreamInfo;
15use smallvec::SmallVec;
16use thunderdome::Arena;
17
18use crate::error::{AddEdgeError, CompileGraphError, RemoveNodeError};
19use crate::FirewheelConfig;
20use firewheel_core::node::{
21    dummy::{DummyNode, DummyNodeConfig},
22    AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
23};
24
25pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
26
27pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
28
29#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
30struct EdgeHash {
31    pub src_node: NodeID,
32    pub dst_node: NodeID,
33    pub src_port: PortIdx,
34    pub dst_port: PortIdx,
35}
36
37/// The audio graph interface.
38pub(crate) struct AudioGraph {
39    nodes: Arena<NodeEntry>,
40    edges: Arena<Edge>,
41    existing_edges: HashMap<EdgeHash, EdgeID>,
42
43    graph_in_id: NodeID,
44    graph_out_id: NodeID,
45    needs_compile: bool,
46
47    nodes_to_remove_from_schedule: Vec<NodeID>,
48    active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
49    nodes_to_call_update_method: Vec<NodeID>,
50
51    prev_node_arena_capacity: usize,
52}
53
54impl AudioGraph {
55    pub fn new(config: &FirewheelConfig) -> Self {
56        let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
57
58        let graph_in_config = DummyNodeConfig {
59            channel_config: ChannelConfig {
60                num_inputs: ChannelCount::ZERO,
61                num_outputs: config.num_graph_inputs,
62            },
63        };
64        let graph_out_config = DummyNodeConfig {
65            channel_config: ChannelConfig {
66                num_inputs: config.num_graph_outputs,
67                num_outputs: ChannelCount::ZERO,
68            },
69        };
70
71        let graph_in_id = NodeID(
72            nodes.insert(NodeEntry::new(
73                AudioNodeInfo::new()
74                    .debug_name("graph_in")
75                    .channel_config(graph_in_config.channel_config)
76                    .into(),
77                Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
78            )),
79        );
80        nodes[graph_in_id.0].id = graph_in_id;
81
82        let graph_out_id = NodeID(
83            nodes.insert(NodeEntry::new(
84                AudioNodeInfo::new()
85                    .debug_name("graph_out")
86                    .channel_config(graph_out_config.channel_config)
87                    .into(),
88                Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
89            )),
90        );
91        nodes[graph_out_id.0].id = graph_out_id;
92
93        Self {
94            nodes,
95            edges: Arena::with_capacity(config.initial_edge_capacity as usize),
96            existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
97            graph_in_id,
98            graph_out_id,
99            needs_compile: true,
100            nodes_to_remove_from_schedule: Vec::with_capacity(
101                config.initial_node_capacity as usize,
102            ),
103            active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
104            nodes_to_call_update_method: Vec::new(),
105            prev_node_arena_capacity: 0,
106        }
107    }
108
109    /// The ID of the graph input node
110    pub fn graph_in_node(&self) -> NodeID {
111        self.graph_in_id
112    }
113
114    /// The ID of the graph output node
115    pub fn graph_out_node(&self) -> NodeID {
116        self.graph_out_id
117    }
118
119    /// Add a node to the audio graph.
120    pub fn add_node<T: AudioNode + 'static>(
121        &mut self,
122        node: T,
123        config: Option<T::Configuration>,
124    ) -> NodeID {
125        let constructor = Constructor::new(node, config);
126        let info: AudioNodeInfoInner = constructor.info().into();
127        let call_update_method = info.call_update_method;
128
129        let new_id = NodeID(
130            self.nodes
131                .insert(NodeEntry::new(info, Box::new(constructor))),
132        );
133        self.nodes[new_id.0].id = new_id;
134
135        if call_update_method {
136            self.nodes_to_call_update_method.push(new_id);
137        }
138
139        self.needs_compile = true;
140
141        new_id
142    }
143
144    /// Add a node to the audio graph which implements the type-erased [`DynAudioNode`] trait.
145    pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
146        let info: AudioNodeInfoInner = node.info().into();
147        let call_update_method = info.call_update_method;
148
149        let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
150        self.nodes[new_id.0].id = new_id;
151
152        if call_update_method {
153            self.nodes_to_call_update_method.push(new_id);
154        }
155
156        self.needs_compile = true;
157
158        new_id
159    }
160
161    /// Remove the given node from the audio graph.
162    ///
163    /// This will automatically remove all edges from the graph that
164    /// were connected to this node.
165    ///
166    /// On success, this returns a list of all edges that were removed
167    /// from the graph as a result of removing this node.
168    ///
169    /// This will return an error if the ID is of the graph input or graph
170    /// output node.
171    pub fn remove_node(
172        &mut self,
173        node_id: NodeID,
174    ) -> Result<SmallVec<[EdgeID; 4]>, RemoveNodeError> {
175        if node_id == self.graph_in_id {
176            return Err(RemoveNodeError::CannotRemoveGraphInNode);
177        }
178        if node_id == self.graph_out_id {
179            return Err(RemoveNodeError::CannotRemoveGraphOutNode);
180        }
181
182        let mut removed_edges = SmallVec::new();
183
184        let Some(node_entry) = self.nodes.remove(node_id.0) else {
185            return Ok(removed_edges);
186        };
187
188        for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
189            removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
190        }
191        for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
192            removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
193        }
194
195        self.nodes_to_remove_from_schedule.push(node_id);
196        self.active_nodes_to_remove.insert(node_id, node_entry);
197
198        self.needs_compile = true;
199
200        Ok(removed_edges)
201    }
202
203    /// Get information about a node in the graph.
204    pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
205        self.nodes.get(id.0)
206    }
207
208    /// Get an immutable reference to the custom state of a node.
209    pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
210        self.node_state_dyn(id).and_then(|s| s.downcast_ref())
211    }
212
213    /// Get a type-erased, immutable reference to the custom state of a node.
214    pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
215        self.nodes
216            .get(id.0)
217            .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
218    }
219
220    /// Get a mutable reference to the custom state of a node.
221    pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
222        self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
223    }
224
225    /// Get a type-erased, mutable reference to the custom state of a node.
226    pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
227        self.nodes
228            .get_mut(id.0)
229            .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
230    }
231
232    /// Get a list of all the existing nodes in the graph.
233    pub fn nodes<'a>(&'a self) -> impl Iterator<Item = &'a NodeEntry> {
234        self.nodes.iter().map(|(_, n)| n)
235    }
236
237    /// Get a list of all the existing edges in the graph.
238    pub fn edges<'a>(&'a self) -> impl Iterator<Item = &'a Edge> {
239        self.edges.iter().map(|(_, e)| e)
240    }
241
242    /// Set the number of input and output channels to and from the audio graph.
243    ///
244    /// Returns the list of edges that were removed.
245    pub fn set_graph_channel_config(
246        &mut self,
247        channel_config: ChannelConfig,
248    ) -> SmallVec<[EdgeID; 4]> {
249        let mut removed_edges = SmallVec::new();
250
251        let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
252        if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
253            let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
254            graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
255
256            if channel_config.num_inputs < old_num_inputs {
257                for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
258                    removed_edges.append(
259                        &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
260                    );
261                }
262            }
263
264            self.needs_compile = true;
265        }
266
267        let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
268
269        if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
270            let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
271            graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
272
273            if channel_config.num_outputs < old_num_outputs {
274                for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
275                    removed_edges.append(
276                        &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
277                    );
278                }
279            }
280
281            self.needs_compile = true;
282        }
283
284        removed_edges
285    }
286
287    /// Add connections (edges) between two nodes to the graph.
288    ///
289    /// * `src_node` - The ID of the source node.
290    /// * `dst_node` - The ID of the destination node.
291    /// * `ports_src_dst` - The port indices for each connection to make,
292    /// where the first value in a tuple is the output port on `src_node`,
293    /// and the second value in that tuple is the input port on `dst_node`.
294    /// * `check_for_cycles` - If `true`, then this will run a check to
295    /// see if adding these edges will create a cycle in the graph, and
296    /// return an error if it does. Note, checking for cycles can be quite
297    /// expensive, so avoid enabling this when calling this method many times
298    /// in a row.
299    ///
300    /// If successful, then this returns a list of edge IDs in order.
301    ///
302    /// If this returns an error, then the audio graph has not been
303    /// modified.
304    pub fn connect(
305        &mut self,
306        src_node: NodeID,
307        dst_node: NodeID,
308        ports_src_dst: &[(PortIdx, PortIdx)],
309        check_for_cycles: bool,
310    ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
311        let src_node_entry = self
312            .nodes
313            .get(src_node.0)
314            .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
315        let dst_node_entry = self
316            .nodes
317            .get(dst_node.0)
318            .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
319
320        if src_node.0 == dst_node.0 {
321            return Err(AddEdgeError::CycleDetected);
322        }
323
324        for (src_port, dst_port) in ports_src_dst.iter().copied() {
325            if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
326                return Err(AddEdgeError::OutPortOutOfRange {
327                    node: src_node,
328                    port_idx: src_port,
329                    num_out_ports: src_node_entry.info.channel_config.num_outputs,
330                });
331            }
332            if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
333                return Err(AddEdgeError::InPortOutOfRange {
334                    node: dst_node,
335                    port_idx: dst_port,
336                    num_in_ports: dst_node_entry.info.channel_config.num_inputs,
337                });
338            }
339        }
340
341        let mut edge_ids = SmallVec::new();
342
343        for (src_port, dst_port) in ports_src_dst.iter().copied() {
344            if let Some(id) = self.existing_edges.get(&EdgeHash {
345                src_node,
346                src_port,
347                dst_node,
348                dst_port,
349            }) {
350                // The caller gave us more than one of the same edge.
351                edge_ids.push(*id);
352                continue;
353            }
354
355            let new_edge_id = EdgeID(self.edges.insert(Edge {
356                id: EdgeID(thunderdome::Index::DANGLING),
357                src_node,
358                src_port,
359                dst_node,
360                dst_port,
361            }));
362            self.edges[new_edge_id.0].id = new_edge_id;
363            self.existing_edges.insert(
364                EdgeHash {
365                    src_node,
366                    src_port,
367                    dst_node,
368                    dst_port,
369                },
370                new_edge_id,
371            );
372
373            edge_ids.push(new_edge_id);
374        }
375
376        if check_for_cycles {
377            if self.cycle_detected() {
378                self.disconnect(src_node, dst_node, ports_src_dst);
379
380                return Err(AddEdgeError::CycleDetected);
381            }
382        }
383
384        self.needs_compile = true;
385
386        Ok(edge_ids)
387    }
388
389    /// Remove connections (edges) between two nodes from the graph.
390    ///
391    /// * `src_node` - The ID of the source node.
392    /// * `dst_node` - The ID of the destination node.
393    /// * `ports_src_dst` - The port indices for each connection to make,
394    /// where the first value in a tuple is the output port on `src_node`,
395    /// and the second value in that tuple is the input port on `dst_node`.
396    ///
397    /// If none of the edges existed in the graph, then `false` will be
398    /// returned.
399    pub fn disconnect(
400        &mut self,
401        src_node: NodeID,
402        dst_node: NodeID,
403        ports_src_dst: &[(PortIdx, PortIdx)],
404    ) -> bool {
405        let mut any_removed = false;
406
407        for (src_port, dst_port) in ports_src_dst.iter().copied() {
408            if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
409                src_node,
410                src_port: src_port.into(),
411                dst_node,
412                dst_port: dst_port.into(),
413            }) {
414                self.disconnect_by_edge_id(edge_id);
415                any_removed = true;
416            }
417        }
418
419        any_removed
420    }
421
422    /// Remove all connections (edges) between two nodes in the graph.
423    ///
424    /// * `src_node` - The ID of the source node.
425    /// * `dst_node` - The ID of the destination node.
426    pub fn disconnect_all_between(
427        &mut self,
428        src_node: NodeID,
429        dst_node: NodeID,
430    ) -> SmallVec<[EdgeID; 4]> {
431        let mut removed_edges = SmallVec::new();
432
433        if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
434            return removed_edges;
435        };
436
437        for (edge_id, edge) in self.edges.iter() {
438            if edge.src_node == src_node && edge.dst_node == dst_node {
439                removed_edges.push(EdgeID(edge_id));
440            }
441        }
442
443        for &edge_id in removed_edges.iter() {
444            self.disconnect_by_edge_id(edge_id);
445        }
446
447        removed_edges
448    }
449
450    /// Remove a connection (edge) via the edge's unique ID.
451    ///
452    /// If the edge did not exist in this graph, then `false` will be returned.
453    pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
454        if let Some(edge) = self.edges.remove(edge_id.0) {
455            self.existing_edges.remove(&EdgeHash {
456                src_node: edge.src_node,
457                src_port: edge.src_port,
458                dst_node: edge.dst_node,
459                dst_port: edge.dst_port,
460            });
461
462            self.needs_compile = true;
463
464            true
465        } else {
466            false
467        }
468    }
469
470    /// Get information about the given [Edge]
471    pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
472        self.edges.get(edge_id.0)
473    }
474
475    fn remove_edges_with_input_port(
476        &mut self,
477        node_id: NodeID,
478        port_idx: PortIdx,
479    ) -> SmallVec<[EdgeID; 4]> {
480        let mut edges_to_remove = SmallVec::new();
481
482        // Remove all existing edges which have this port.
483        for (edge_id, edge) in self.edges.iter() {
484            if edge.dst_node == node_id && edge.dst_port == port_idx {
485                edges_to_remove.push(EdgeID(edge_id));
486            }
487        }
488
489        for edge_id in edges_to_remove.iter() {
490            self.disconnect_by_edge_id(*edge_id);
491        }
492
493        edges_to_remove
494    }
495
496    fn remove_edges_with_output_port(
497        &mut self,
498        node_id: NodeID,
499        port_idx: PortIdx,
500    ) -> SmallVec<[EdgeID; 4]> {
501        let mut edges_to_remove = SmallVec::new();
502
503        // Remove all existing edges which have this port.
504        for (edge_id, edge) in self.edges.iter() {
505            if edge.src_node == node_id && edge.src_port == port_idx {
506                edges_to_remove.push(EdgeID(edge_id));
507            }
508        }
509
510        for edge_id in edges_to_remove.iter() {
511            self.disconnect_by_edge_id(*edge_id);
512        }
513
514        edges_to_remove
515    }
516
517    pub fn cycle_detected(&mut self) -> bool {
518        compiler::cycle_detected(
519            &mut self.nodes,
520            &mut self.edges,
521            self.graph_in_id,
522            self.graph_out_id,
523        )
524    }
525
526    pub(crate) fn needs_compile(&self) -> bool {
527        self.needs_compile
528    }
529
530    pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
531        self.needs_compile = true;
532
533        for node in failed_schedule.new_node_processors.iter() {
534            if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
535                node_entry.processor_constructed = false;
536            }
537        }
538    }
539
540    pub(crate) fn deactivate(&mut self) {
541        self.needs_compile = true;
542    }
543
544    pub(crate) fn compile(
545        &mut self,
546        stream_info: &StreamInfo,
547    ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
548        let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
549
550        let mut new_node_processors = Vec::new();
551        for (_, entry) in self.nodes.iter_mut() {
552            if !entry.processor_constructed {
553                entry.processor_constructed = true;
554
555                let cx = ConstructProcessorContext::new(
556                    entry.id,
557                    stream_info,
558                    &mut entry.info.custom_state,
559                );
560
561                new_node_processors.push(NodeHeapData {
562                    id: entry.id,
563                    processor: entry.dyn_node.construct_processor(cx),
564                });
565            }
566        }
567
568        let mut nodes_to_remove = Vec::new();
569        core::mem::swap(
570            &mut self.nodes_to_remove_from_schedule,
571            &mut nodes_to_remove,
572        );
573
574        let new_arena = if self.nodes.capacity() > self.prev_node_arena_capacity {
575            Some(Arena::with_capacity(self.nodes.capacity()))
576        } else {
577            None
578        };
579        self.prev_node_arena_capacity = self.nodes.capacity();
580
581        let schedule_data = Box::new(ScheduleHeapData::new(
582            schedule,
583            nodes_to_remove,
584            new_node_processors,
585            new_arena,
586        ));
587
588        self.needs_compile = false;
589
590        log::debug!("compiled new audio graph: {:?}", &schedule_data);
591
592        Ok(schedule_data)
593    }
594
595    fn compile_internal(
596        &mut self,
597        max_block_frames: usize,
598    ) -> Result<CompiledSchedule, CompileGraphError> {
599        assert!(max_block_frames > 0);
600
601        compiler::compile(
602            &mut self.nodes,
603            &mut self.edges,
604            self.graph_in_id,
605            self.graph_out_id,
606            max_block_frames,
607        )
608    }
609
610    pub(crate) fn update(
611        &mut self,
612        stream_info: Option<&StreamInfo>,
613        event_queue: &mut Vec<NodeEvent>,
614    ) {
615        let mut cull_list = false;
616        for node_id in self.nodes_to_call_update_method.iter() {
617            if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
618                node_entry.dyn_node.update(UpdateContext::new(
619                    *node_id,
620                    stream_info,
621                    &mut node_entry.info.custom_state,
622                    event_queue,
623                ));
624            } else {
625                cull_list = true;
626            }
627        }
628
629        if cull_list {
630            self.nodes_to_call_update_method
631                .retain(|node_id| self.nodes.contains(node_id.0));
632        }
633    }
634}