Skip to main content

firewheel_graph/
graph.rs

1use core::any::Any;
2use core::fmt::Debug;
3use core::hash::Hash;
4
5#[cfg(not(feature = "std"))]
6use bevy_platform::prelude::{Box, Vec};
7
8use bevy_platform::collections::HashMap;
9use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
10use firewheel_core::event::NodeEvent;
11use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
12use firewheel_core::StreamInfo;
13use smallvec::SmallVec;
14use thunderdome::Arena;
15
16use crate::error::{AddEdgeError, CompileGraphError, RemoveNodeError};
17use crate::graph::dummy_node::{DummyNode, DummyNodeConfig};
18use crate::FirewheelConfig;
19use firewheel_core::node::{
20    AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
21};
22
23pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
24
25pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
26
27mod compiler;
28mod dummy_node;
29
30#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
31struct EdgeHash {
32    pub src_node: NodeID,
33    pub dst_node: NodeID,
34    pub src_port: PortIdx,
35    pub dst_port: PortIdx,
36}
37
38/// The audio graph interface.
39pub(crate) struct AudioGraph {
40    nodes: Arena<NodeEntry>,
41    edges: Arena<Edge>,
42    existing_edges: HashMap<EdgeHash, EdgeID>,
43
44    graph_in_id: NodeID,
45    graph_out_id: NodeID,
46    needs_compile: bool,
47
48    nodes_to_remove_from_schedule: Vec<NodeID>,
49    active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
50    nodes_to_call_update_method: Vec<NodeID>,
51
52    prev_node_arena_capacity: usize,
53}
54
55impl AudioGraph {
56    pub fn new(config: &FirewheelConfig) -> Self {
57        let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
58
59        let graph_in_config = DummyNodeConfig {
60            channel_config: ChannelConfig {
61                num_inputs: ChannelCount::ZERO,
62                num_outputs: config.num_graph_inputs,
63            },
64        };
65        let graph_out_config = DummyNodeConfig {
66            channel_config: ChannelConfig {
67                num_inputs: config.num_graph_outputs,
68                num_outputs: ChannelCount::ZERO,
69            },
70        };
71
72        let graph_in_id = NodeID(
73            nodes.insert(NodeEntry::new(
74                AudioNodeInfo::new()
75                    .debug_name("graph_in")
76                    .channel_config(graph_in_config.channel_config)
77                    .into(),
78                Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
79            )),
80        );
81        nodes[graph_in_id.0].id = graph_in_id;
82
83        let graph_out_id = NodeID(
84            nodes.insert(NodeEntry::new(
85                AudioNodeInfo::new()
86                    .debug_name("graph_out")
87                    .channel_config(graph_out_config.channel_config)
88                    .into(),
89                Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
90            )),
91        );
92        nodes[graph_out_id.0].id = graph_out_id;
93
94        Self {
95            nodes,
96            edges: Arena::with_capacity(config.initial_edge_capacity as usize),
97            existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
98            graph_in_id,
99            graph_out_id,
100            needs_compile: true,
101            nodes_to_remove_from_schedule: Vec::with_capacity(
102                config.initial_node_capacity as usize,
103            ),
104            active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
105            nodes_to_call_update_method: Vec::new(),
106            prev_node_arena_capacity: 0,
107        }
108    }
109
110    /// The ID of the graph input node
111    pub fn graph_in_node(&self) -> NodeID {
112        self.graph_in_id
113    }
114
115    /// The ID of the graph output node
116    pub fn graph_out_node(&self) -> NodeID {
117        self.graph_out_id
118    }
119
120    /// Add a node to the audio graph.
121    pub fn add_node<T: AudioNode + 'static>(
122        &mut self,
123        node: T,
124        config: Option<T::Configuration>,
125    ) -> NodeID {
126        let constructor = Constructor::new(node, config);
127        let info: AudioNodeInfoInner = constructor.info().into();
128        let call_update_method = info.call_update_method;
129
130        let new_id = NodeID(
131            self.nodes
132                .insert(NodeEntry::new(info, Box::new(constructor))),
133        );
134        self.nodes[new_id.0].id = new_id;
135
136        if call_update_method {
137            self.nodes_to_call_update_method.push(new_id);
138        }
139
140        self.needs_compile = true;
141
142        new_id
143    }
144
145    /// Add a node to the audio graph which implements the type-erased [`DynAudioNode`] trait.
146    pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
147        let info: AudioNodeInfoInner = node.info().into();
148        let call_update_method = info.call_update_method;
149
150        let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
151        self.nodes[new_id.0].id = new_id;
152
153        if call_update_method {
154            self.nodes_to_call_update_method.push(new_id);
155        }
156
157        self.needs_compile = true;
158
159        new_id
160    }
161
162    /// Remove the given node from the audio graph.
163    ///
164    /// This will automatically remove all edges from the graph that
165    /// were connected to this node.
166    ///
167    /// On success, this returns a list of all edges that were removed
168    /// from the graph as a result of removing this node.
169    ///
170    /// This will return an error if the ID is of the graph input or graph
171    /// output node.
172    pub fn remove_node(
173        &mut self,
174        node_id: NodeID,
175    ) -> Result<SmallVec<[EdgeID; 4]>, RemoveNodeError> {
176        if node_id == self.graph_in_id {
177            return Err(RemoveNodeError::CannotRemoveGraphInNode);
178        }
179        if node_id == self.graph_out_id {
180            return Err(RemoveNodeError::CannotRemoveGraphOutNode);
181        }
182
183        let mut removed_edges = SmallVec::new();
184
185        let Some(node_entry) = self.nodes.remove(node_id.0) else {
186            return Ok(removed_edges);
187        };
188
189        for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
190            removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
191        }
192        for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
193            removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
194        }
195
196        self.nodes_to_remove_from_schedule.push(node_id);
197        self.active_nodes_to_remove.insert(node_id, node_entry);
198
199        self.needs_compile = true;
200
201        Ok(removed_edges)
202    }
203
204    /// Get information about a node in the graph.
205    pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
206        self.nodes.get(id.0)
207    }
208
209    /// Get an immutable reference to the custom state of a node.
210    pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
211        self.node_state_dyn(id).and_then(|s| s.downcast_ref())
212    }
213
214    /// Get a type-erased, immutable reference to the custom state of a node.
215    pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
216        self.nodes
217            .get(id.0)
218            .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
219    }
220
221    /// Get a mutable reference to the custom state of a node.
222    pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
223        self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
224    }
225
226    /// Get a type-erased, mutable reference to the custom state of a node.
227    pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
228        self.nodes
229            .get_mut(id.0)
230            .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
231    }
232
233    /// Get a list of all the existing nodes in the graph.
234    pub fn nodes(&self) -> impl Iterator<Item = &NodeEntry> {
235        self.nodes.iter().map(|(_, n)| n)
236    }
237
238    /// Get a list of all the existing edges in the graph.
239    pub fn edges(&self) -> impl Iterator<Item = &Edge> {
240        self.edges.iter().map(|(_, e)| e)
241    }
242
243    /// Set the number of input and output channels to and from the audio graph.
244    ///
245    /// Returns the list of edges that were removed.
246    pub fn set_graph_channel_config(
247        &mut self,
248        channel_config: ChannelConfig,
249    ) -> SmallVec<[EdgeID; 4]> {
250        let mut removed_edges = SmallVec::new();
251
252        let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
253        if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
254            let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
255            graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
256
257            if channel_config.num_inputs < old_num_inputs {
258                for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
259                    removed_edges.append(
260                        &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
261                    );
262                }
263            }
264
265            self.needs_compile = true;
266        }
267
268        let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
269
270        if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
271            let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
272            graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
273
274            if channel_config.num_outputs < old_num_outputs {
275                for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
276                    removed_edges.append(
277                        &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
278                    );
279                }
280            }
281
282            self.needs_compile = true;
283        }
284
285        removed_edges
286    }
287
288    /// Add connections (edges) between two nodes to the graph.
289    ///
290    /// * `src_node` - The ID of the source node.
291    /// * `dst_node` - The ID of the destination node.
292    /// * `ports_src_dst` - The port indices for each connection to make,
293    ///   where the first value in a tuple is the output port on `src_node`,
294    ///   and the second value in that tuple is the input port on `dst_node`.
295    /// * `check_for_cycles` - If `true`, then this will run a check to
296    ///   see if adding these edges will create a cycle in the graph, and
297    ///   return an error if it does. Note, checking for cycles can be quite
298    ///   expensive, so avoid enabling this when calling this method many times
299    ///   in a row.
300    ///
301    /// If successful, then this returns a list of edge IDs in order.
302    ///
303    /// If this returns an error, then the audio graph has not been
304    /// modified.
305    pub fn connect(
306        &mut self,
307        src_node: NodeID,
308        dst_node: NodeID,
309        ports_src_dst: &[(PortIdx, PortIdx)],
310        check_for_cycles: bool,
311    ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
312        let src_node_entry = self
313            .nodes
314            .get(src_node.0)
315            .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
316        let dst_node_entry = self
317            .nodes
318            .get(dst_node.0)
319            .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
320
321        if src_node.0 == dst_node.0 {
322            return Err(AddEdgeError::CycleDetected);
323        }
324
325        for (src_port, dst_port) in ports_src_dst.iter().copied() {
326            if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
327                return Err(AddEdgeError::OutPortOutOfRange {
328                    node: src_node,
329                    port_idx: src_port,
330                    num_out_ports: src_node_entry.info.channel_config.num_outputs,
331                });
332            }
333            if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
334                return Err(AddEdgeError::InPortOutOfRange {
335                    node: dst_node,
336                    port_idx: dst_port,
337                    num_in_ports: dst_node_entry.info.channel_config.num_inputs,
338                });
339            }
340        }
341
342        let mut edge_ids = SmallVec::new();
343
344        for (src_port, dst_port) in ports_src_dst.iter().copied() {
345            if let Some(id) = self.existing_edges.get(&EdgeHash {
346                src_node,
347                src_port,
348                dst_node,
349                dst_port,
350            }) {
351                // The caller gave us more than one of the same edge.
352                edge_ids.push(*id);
353                continue;
354            }
355
356            let new_edge_id = EdgeID(self.edges.insert(Edge {
357                id: EdgeID(thunderdome::Index::DANGLING),
358                src_node,
359                src_port,
360                dst_node,
361                dst_port,
362            }));
363            self.edges[new_edge_id.0].id = new_edge_id;
364            self.existing_edges.insert(
365                EdgeHash {
366                    src_node,
367                    src_port,
368                    dst_node,
369                    dst_port,
370                },
371                new_edge_id,
372            );
373
374            edge_ids.push(new_edge_id);
375        }
376
377        if check_for_cycles && self.cycle_detected() {
378            self.disconnect(src_node, dst_node, ports_src_dst);
379
380            return Err(AddEdgeError::CycleDetected);
381        }
382
383        self.needs_compile = true;
384
385        Ok(edge_ids)
386    }
387
388    /// Remove connections (edges) between two nodes from the graph.
389    ///
390    /// * `src_node` - The ID of the source node.
391    /// * `dst_node` - The ID of the destination node.
392    /// * `ports_src_dst` - The port indices for each connection to make,
393    ///   where the first value in a tuple is the output port on `src_node`,
394    ///   and the second value in that tuple is the input port on `dst_node`.
395    ///
396    /// If none of the edges existed in the graph, then `false` will be
397    /// returned.
398    pub fn disconnect(
399        &mut self,
400        src_node: NodeID,
401        dst_node: NodeID,
402        ports_src_dst: &[(PortIdx, PortIdx)],
403    ) -> bool {
404        let mut any_removed = false;
405
406        for (src_port, dst_port) in ports_src_dst.iter().copied() {
407            if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
408                src_node,
409                src_port,
410                dst_node,
411                dst_port,
412            }) {
413                self.disconnect_by_edge_id(edge_id);
414                any_removed = true;
415            }
416        }
417
418        any_removed
419    }
420
421    /// Remove all connections (edges) between two nodes in the graph.
422    ///
423    /// * `src_node` - The ID of the source node.
424    /// * `dst_node` - The ID of the destination node.
425    pub fn disconnect_all_between(
426        &mut self,
427        src_node: NodeID,
428        dst_node: NodeID,
429    ) -> SmallVec<[EdgeID; 4]> {
430        let mut removed_edges = SmallVec::new();
431
432        if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
433            return removed_edges;
434        };
435
436        for (edge_id, edge) in self.edges.iter() {
437            if edge.src_node == src_node && edge.dst_node == dst_node {
438                removed_edges.push(EdgeID(edge_id));
439            }
440        }
441
442        for &edge_id in removed_edges.iter() {
443            self.disconnect_by_edge_id(edge_id);
444        }
445
446        removed_edges
447    }
448
449    /// Remove a connection (edge) via the edge's unique ID.
450    ///
451    /// If the edge did not exist in this graph, then `false` will be returned.
452    pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
453        if let Some(edge) = self.edges.remove(edge_id.0) {
454            self.existing_edges.remove(&EdgeHash {
455                src_node: edge.src_node,
456                src_port: edge.src_port,
457                dst_node: edge.dst_node,
458                dst_port: edge.dst_port,
459            });
460
461            self.needs_compile = true;
462
463            true
464        } else {
465            false
466        }
467    }
468
469    /// Get information about the given [Edge]
470    pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
471        self.edges.get(edge_id.0)
472    }
473
474    fn remove_edges_with_input_port(
475        &mut self,
476        node_id: NodeID,
477        port_idx: PortIdx,
478    ) -> SmallVec<[EdgeID; 4]> {
479        let mut edges_to_remove = SmallVec::new();
480
481        // Remove all existing edges which have this port.
482        for (edge_id, edge) in self.edges.iter() {
483            if edge.dst_node == node_id && edge.dst_port == port_idx {
484                edges_to_remove.push(EdgeID(edge_id));
485            }
486        }
487
488        for edge_id in edges_to_remove.iter() {
489            self.disconnect_by_edge_id(*edge_id);
490        }
491
492        edges_to_remove
493    }
494
495    fn remove_edges_with_output_port(
496        &mut self,
497        node_id: NodeID,
498        port_idx: PortIdx,
499    ) -> SmallVec<[EdgeID; 4]> {
500        let mut edges_to_remove = SmallVec::new();
501
502        // Remove all existing edges which have this port.
503        for (edge_id, edge) in self.edges.iter() {
504            if edge.src_node == node_id && edge.src_port == port_idx {
505                edges_to_remove.push(EdgeID(edge_id));
506            }
507        }
508
509        for edge_id in edges_to_remove.iter() {
510            self.disconnect_by_edge_id(*edge_id);
511        }
512
513        edges_to_remove
514    }
515
516    pub fn cycle_detected(&mut self) -> bool {
517        compiler::cycle_detected(
518            &mut self.nodes,
519            &mut self.edges,
520            self.graph_in_id,
521            self.graph_out_id,
522        )
523    }
524
525    pub(crate) fn needs_compile(&self) -> bool {
526        self.needs_compile
527    }
528
529    pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
530        self.needs_compile = true;
531
532        for node in failed_schedule.new_node_processors.iter() {
533            if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
534                node_entry.processor_constructed = false;
535            }
536        }
537    }
538
539    pub(crate) fn deactivate(&mut self) {
540        self.needs_compile = true;
541    }
542
543    pub(crate) fn compile(
544        &mut self,
545        stream_info: &StreamInfo,
546    ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
547        let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
548
549        let mut new_node_processors = Vec::new();
550        for (_, entry) in self.nodes.iter_mut() {
551            if !entry.processor_constructed {
552                entry.processor_constructed = true;
553
554                let cx = ConstructProcessorContext::new(
555                    entry.id,
556                    stream_info,
557                    &mut entry.info.custom_state,
558                );
559
560                new_node_processors.push(NodeHeapData {
561                    id: entry.id,
562                    processor: entry.dyn_node.construct_processor(cx),
563                    is_pre_process: entry.info.channel_config.is_empty(),
564                });
565            }
566        }
567
568        let mut nodes_to_remove = Vec::new();
569        core::mem::swap(
570            &mut self.nodes_to_remove_from_schedule,
571            &mut nodes_to_remove,
572        );
573
574        let new_arena = if self.nodes.capacity() > self.prev_node_arena_capacity {
575            Some(Arena::with_capacity(self.nodes.capacity()))
576        } else {
577            None
578        };
579        self.prev_node_arena_capacity = self.nodes.capacity();
580
581        let schedule_data = Box::new(ScheduleHeapData::new(
582            schedule,
583            nodes_to_remove,
584            new_node_processors,
585            new_arena,
586        ));
587
588        self.needs_compile = false;
589
590        #[cfg(feature = "tracing")]
591        tracing::debug!("compiled new audio graph: {:?}", &schedule_data);
592
593        #[cfg(all(feature = "log", not(feature = "tracing")))]
594        log::debug!("compiled new audio graph: {:?}", &schedule_data);
595
596        Ok(schedule_data)
597    }
598
599    fn compile_internal(
600        &mut self,
601        max_block_frames: usize,
602    ) -> Result<CompiledSchedule, CompileGraphError> {
603        assert!(max_block_frames > 0);
604
605        compiler::compile(
606            &mut self.nodes,
607            &mut self.edges,
608            self.graph_in_id,
609            self.graph_out_id,
610            max_block_frames,
611        )
612    }
613
614    pub(crate) fn update(
615        &mut self,
616        stream_info: Option<&StreamInfo>,
617        event_queue: &mut Vec<NodeEvent>,
618    ) {
619        let mut cull_list = false;
620        for node_id in self.nodes_to_call_update_method.iter() {
621            if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
622                node_entry.dyn_node.update(UpdateContext::new(
623                    *node_id,
624                    stream_info,
625                    &mut node_entry.info.custom_state,
626                    event_queue,
627                ));
628            } else {
629                cull_list = true;
630            }
631        }
632
633        if cull_list {
634            self.nodes_to_call_update_method
635                .retain(|node_id| self.nodes.contains(node_id.0));
636        }
637    }
638}