// firewheel_graph/graph.rs

1mod compiler;
2
3use core::any::Any;
4use core::fmt::Debug;
5use core::hash::Hash;
6
7use bevy_platform::collections::HashMap;
8use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
9use firewheel_core::event::NodeEvent;
10use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
11use firewheel_core::StreamInfo;
12use smallvec::SmallVec;
13use thunderdome::Arena;
14
15use crate::error::{AddEdgeError, CompileGraphError};
16use crate::FirewheelConfig;
17use firewheel_core::node::{
18    dummy::{DummyNode, DummyNodeConfig},
19    AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
20};
21
22pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
23
24pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
25
26#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
27struct EdgeHash {
28    pub src_node: NodeID,
29    pub dst_node: NodeID,
30    pub src_port: PortIdx,
31    pub dst_port: PortIdx,
32}
33
34/// The audio graph interface.
35pub(crate) struct AudioGraph {
36    nodes: Arena<NodeEntry>,
37    edges: Arena<Edge>,
38    existing_edges: HashMap<EdgeHash, EdgeID>,
39
40    graph_in_id: NodeID,
41    graph_out_id: NodeID,
42    needs_compile: bool,
43
44    nodes_to_remove_from_schedule: Vec<NodeID>,
45    active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
46    nodes_to_call_update_method: Vec<NodeID>,
47    event_queue_capacity: usize,
48}
49
50impl AudioGraph {
51    pub fn new(config: &FirewheelConfig) -> Self {
52        let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
53
54        let graph_in_config = DummyNodeConfig {
55            channel_config: ChannelConfig {
56                num_inputs: ChannelCount::ZERO,
57                num_outputs: config.num_graph_inputs,
58            },
59        };
60        let graph_out_config = DummyNodeConfig {
61            channel_config: ChannelConfig {
62                num_inputs: config.num_graph_outputs,
63                num_outputs: ChannelCount::ZERO,
64            },
65        };
66
67        let graph_in_id = NodeID(
68            nodes.insert(NodeEntry::new(
69                AudioNodeInfo::new()
70                    .debug_name("graph_in")
71                    .channel_config(graph_in_config.channel_config)
72                    .uses_events(false)
73                    .into(),
74                Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
75            )),
76        );
77        nodes[graph_in_id.0].id = graph_in_id;
78
79        let graph_out_id = NodeID(
80            nodes.insert(NodeEntry::new(
81                AudioNodeInfo::new()
82                    .debug_name("graph_out")
83                    .channel_config(graph_out_config.channel_config)
84                    .uses_events(false)
85                    .into(),
86                Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
87            )),
88        );
89        nodes[graph_out_id.0].id = graph_out_id;
90
91        Self {
92            nodes,
93            edges: Arena::with_capacity(config.initial_edge_capacity as usize),
94            existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
95            graph_in_id,
96            graph_out_id,
97            needs_compile: true,
98            nodes_to_remove_from_schedule: Vec::with_capacity(
99                config.initial_node_capacity as usize,
100            ),
101            active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
102            nodes_to_call_update_method: Vec::new(),
103            event_queue_capacity: 0, // This will be overwritten later once activated.
104        }
105    }
106
107    /// The ID of the graph input node
108    pub fn graph_in_node(&self) -> NodeID {
109        self.graph_in_id
110    }
111
112    /// The ID of the graph output node
113    pub fn graph_out_node(&self) -> NodeID {
114        self.graph_out_id
115    }
116
117    /// Add a node to the audio graph.
118    pub fn add_node<T: AudioNode + 'static>(
119        &mut self,
120        node: T,
121        config: Option<T::Configuration>,
122    ) -> NodeID {
123        let constructor = Constructor::new(node, config);
124        let info: AudioNodeInfoInner = constructor.info().into();
125        let call_update_method = info.call_update_method;
126
127        let new_id = NodeID(
128            self.nodes
129                .insert(NodeEntry::new(info, Box::new(constructor))),
130        );
131        self.nodes[new_id.0].id = new_id;
132
133        if call_update_method {
134            self.nodes_to_call_update_method.push(new_id);
135        }
136
137        self.needs_compile = true;
138
139        new_id
140    }
141
142    /// Add a node to the audio graph which implements the type-erased [`DynAudioNode`] trait.
143    pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
144        let info: AudioNodeInfoInner = node.info().into();
145        let call_update_method = info.call_update_method;
146
147        let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
148        self.nodes[new_id.0].id = new_id;
149
150        if call_update_method {
151            self.nodes_to_call_update_method.push(new_id);
152        }
153
154        self.needs_compile = true;
155
156        new_id
157    }
158
159    /// Remove the given node from the audio graph.
160    ///
161    /// This will automatically remove all edges from the graph that
162    /// were connected to this node.
163    ///
164    /// On success, this returns a list of all edges that were removed
165    /// from the graph as a result of removing this node.
166    ///
167    /// This will return an error if a node with the given ID does not
168    /// exist in the graph, or if the ID is of the graph input or graph
169    /// output node.
170    pub fn remove_node(&mut self, node_id: NodeID) -> Result<SmallVec<[EdgeID; 4]>, ()> {
171        if node_id == self.graph_in_id || node_id == self.graph_out_id {
172            return Err(());
173        }
174
175        let node_entry = self.nodes.remove(node_id.0).ok_or(())?;
176
177        let mut removed_edges = SmallVec::new();
178
179        for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
180            removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
181        }
182        for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
183            removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
184        }
185
186        self.nodes_to_remove_from_schedule.push(node_id);
187        self.active_nodes_to_remove.insert(node_id, node_entry);
188
189        self.needs_compile = true;
190
191        Ok(removed_edges)
192    }
193
194    /// Get information about a node in the graph.
195    pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
196        self.nodes.get(id.0)
197    }
198
199    /// Get an immutable reference to the custom state of a node.
200    pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
201        self.node_state_dyn(id).and_then(|s| s.downcast_ref())
202    }
203
204    /// Get a type-erased, immutable reference to the custom state of a node.
205    pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
206        self.nodes
207            .get(id.0)
208            .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
209    }
210
211    /// Get a mutable reference to the custom state of a node.
212    pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
213        self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
214    }
215
216    /// Get a type-erased, mutable reference to the custom state of a node.
217    pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
218        self.nodes
219            .get_mut(id.0)
220            .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
221    }
222
223    /// Get a list of all the existing nodes in the graph.
224    pub fn nodes<'a>(&'a self) -> impl Iterator<Item = &'a NodeEntry> {
225        self.nodes.iter().map(|(_, n)| n)
226    }
227
228    /// Get a list of all the existing edges in the graph.
229    pub fn edges<'a>(&'a self) -> impl Iterator<Item = &'a Edge> {
230        self.edges.iter().map(|(_, e)| e)
231    }
232
233    /// Set the number of input and output channels to and from the audio graph.
234    ///
235    /// Returns the list of edges that were removed.
236    pub fn set_graph_channel_config(
237        &mut self,
238        channel_config: ChannelConfig,
239    ) -> SmallVec<[EdgeID; 4]> {
240        let mut removed_edges = SmallVec::new();
241
242        let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
243        if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
244            let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
245            graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
246
247            if channel_config.num_inputs < old_num_inputs {
248                for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
249                    removed_edges.append(
250                        &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
251                    );
252                }
253            }
254
255            self.needs_compile = true;
256        }
257
258        let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
259
260        if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
261            let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
262            graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
263
264            if channel_config.num_outputs < old_num_outputs {
265                for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
266                    removed_edges.append(
267                        &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
268                    );
269                }
270            }
271
272            self.needs_compile = true;
273        }
274
275        removed_edges
276    }
277
278    /// Add connections (edges) between two nodes to the graph.
279    ///
280    /// * `src_node` - The ID of the source node.
281    /// * `dst_node` - The ID of the destination node.
282    /// * `ports_src_dst` - The port indices for each connection to make,
283    /// where the first value in a tuple is the output port on `src_node`,
284    /// and the second value in that tuple is the input port on `dst_node`.
285    /// * `check_for_cycles` - If `true`, then this will run a check to
286    /// see if adding these edges will create a cycle in the graph, and
287    /// return an error if it does. Note, checking for cycles can be quite
288    /// expensive, so avoid enabling this when calling this method many times
289    /// in a row.
290    ///
291    /// If successful, then this returns a list of edge IDs in order.
292    ///
293    /// If this returns an error, then the audio graph has not been
294    /// modified.
295    pub fn connect(
296        &mut self,
297        src_node: NodeID,
298        dst_node: NodeID,
299        ports_src_dst: &[(PortIdx, PortIdx)],
300        check_for_cycles: bool,
301    ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
302        let src_node_entry = self
303            .nodes
304            .get(src_node.0)
305            .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
306        let dst_node_entry = self
307            .nodes
308            .get(dst_node.0)
309            .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
310
311        if src_node.0 == dst_node.0 {
312            return Err(AddEdgeError::CycleDetected);
313        }
314
315        for (src_port, dst_port) in ports_src_dst.iter().copied() {
316            if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
317                return Err(AddEdgeError::OutPortOutOfRange {
318                    node: src_node,
319                    port_idx: src_port,
320                    num_out_ports: src_node_entry.info.channel_config.num_outputs,
321                });
322            }
323            if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
324                return Err(AddEdgeError::InPortOutOfRange {
325                    node: dst_node,
326                    port_idx: dst_port,
327                    num_in_ports: dst_node_entry.info.channel_config.num_inputs,
328                });
329            }
330        }
331
332        let mut edge_ids = SmallVec::new();
333
334        for (src_port, dst_port) in ports_src_dst.iter().copied() {
335            if let Some(id) = self.existing_edges.get(&EdgeHash {
336                src_node,
337                src_port,
338                dst_node,
339                dst_port,
340            }) {
341                // The caller gave us more than one of the same edge.
342                edge_ids.push(*id);
343                continue;
344            }
345
346            let new_edge_id = EdgeID(self.edges.insert(Edge {
347                id: EdgeID(thunderdome::Index::DANGLING),
348                src_node,
349                src_port,
350                dst_node,
351                dst_port,
352            }));
353            self.edges[new_edge_id.0].id = new_edge_id;
354            self.existing_edges.insert(
355                EdgeHash {
356                    src_node,
357                    src_port,
358                    dst_node,
359                    dst_port,
360                },
361                new_edge_id,
362            );
363
364            edge_ids.push(new_edge_id);
365        }
366
367        if check_for_cycles {
368            if self.cycle_detected() {
369                self.disconnect(src_node, dst_node, ports_src_dst);
370
371                return Err(AddEdgeError::CycleDetected);
372            }
373        }
374
375        self.needs_compile = true;
376
377        Ok(edge_ids)
378    }
379
380    /// Remove connections (edges) between two nodes from the graph.
381    ///
382    /// * `src_node` - The ID of the source node.
383    /// * `dst_node` - The ID of the destination node.
384    /// * `ports_src_dst` - The port indices for each connection to make,
385    /// where the first value in a tuple is the output port on `src_node`,
386    /// and the second value in that tuple is the input port on `dst_node`.
387    ///
388    /// If none of the edges existed in the graph, then `false` will be
389    /// returned.
390    pub fn disconnect(
391        &mut self,
392        src_node: NodeID,
393        dst_node: NodeID,
394        ports_src_dst: &[(PortIdx, PortIdx)],
395    ) -> bool {
396        let mut any_removed = false;
397
398        for (src_port, dst_port) in ports_src_dst.iter().copied() {
399            if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
400                src_node,
401                src_port: src_port.into(),
402                dst_node,
403                dst_port: dst_port.into(),
404            }) {
405                self.disconnect_by_edge_id(edge_id);
406                any_removed = true;
407            }
408        }
409
410        any_removed
411    }
412
413    /// Remove all connections (edges) between two nodes in the graph.
414    ///
415    /// * `src_node` - The ID of the source node.
416    /// * `dst_node` - The ID of the destination node.
417    pub fn disconnect_all_between(
418        &mut self,
419        src_node: NodeID,
420        dst_node: NodeID,
421    ) -> SmallVec<[EdgeID; 4]> {
422        let mut removed_edges = SmallVec::new();
423
424        if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
425            return removed_edges;
426        };
427
428        for (edge_id, edge) in self.edges.iter() {
429            if edge.src_node == src_node && edge.dst_node == dst_node {
430                removed_edges.push(EdgeID(edge_id));
431            }
432        }
433
434        for &edge_id in removed_edges.iter() {
435            self.disconnect_by_edge_id(edge_id);
436        }
437
438        removed_edges
439    }
440
441    /// Remove a connection (edge) via the edge's unique ID.
442    ///
443    /// If the edge did not exist in this graph, then `false` will be returned.
444    pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
445        if let Some(edge) = self.edges.remove(edge_id.0) {
446            self.existing_edges.remove(&EdgeHash {
447                src_node: edge.src_node,
448                src_port: edge.src_port,
449                dst_node: edge.dst_node,
450                dst_port: edge.dst_port,
451            });
452
453            self.needs_compile = true;
454
455            true
456        } else {
457            false
458        }
459    }
460
461    /// Get information about the given [Edge]
462    pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
463        self.edges.get(edge_id.0)
464    }
465
466    fn remove_edges_with_input_port(
467        &mut self,
468        node_id: NodeID,
469        port_idx: PortIdx,
470    ) -> SmallVec<[EdgeID; 4]> {
471        let mut edges_to_remove = SmallVec::new();
472
473        // Remove all existing edges which have this port.
474        for (edge_id, edge) in self.edges.iter() {
475            if edge.dst_node == node_id && edge.dst_port == port_idx {
476                edges_to_remove.push(EdgeID(edge_id));
477            }
478        }
479
480        for edge_id in edges_to_remove.iter() {
481            self.disconnect_by_edge_id(*edge_id);
482        }
483
484        edges_to_remove
485    }
486
487    fn remove_edges_with_output_port(
488        &mut self,
489        node_id: NodeID,
490        port_idx: PortIdx,
491    ) -> SmallVec<[EdgeID; 4]> {
492        let mut edges_to_remove = SmallVec::new();
493
494        // Remove all existing edges which have this port.
495        for (edge_id, edge) in self.edges.iter() {
496            if edge.src_node == node_id && edge.src_port == port_idx {
497                edges_to_remove.push(EdgeID(edge_id));
498            }
499        }
500
501        for edge_id in edges_to_remove.iter() {
502            self.disconnect_by_edge_id(*edge_id);
503        }
504
505        edges_to_remove
506    }
507
508    pub fn cycle_detected(&mut self) -> bool {
509        compiler::cycle_detected(
510            &mut self.nodes,
511            &mut self.edges,
512            self.graph_in_id,
513            self.graph_out_id,
514        )
515    }
516
517    pub(crate) fn needs_compile(&self) -> bool {
518        self.needs_compile
519    }
520
521    pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
522        self.needs_compile = true;
523
524        for node in failed_schedule.new_node_processors.iter() {
525            if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
526                node_entry.activated = false;
527            }
528        }
529    }
530
531    pub(crate) fn node_capacity(&self) -> usize {
532        self.nodes.capacity()
533    }
534
535    pub(crate) fn deactivate(&mut self) {
536        for (_, entry) in self.nodes.iter_mut() {
537            entry.activated = false;
538        }
539        self.needs_compile = true;
540    }
541
542    pub(crate) fn compile(
543        &mut self,
544        stream_info: &StreamInfo,
545    ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
546        let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
547
548        let mut new_node_processors = Vec::new();
549        for (_, entry) in self.nodes.iter_mut() {
550            if !entry.activated {
551                entry.activated = true;
552
553                let event_buffer_indices = if entry.info.uses_events {
554                    Vec::with_capacity(self.event_queue_capacity)
555                } else {
556                    Vec::new()
557                };
558
559                let cx = ConstructProcessorContext::new(
560                    entry.id,
561                    stream_info,
562                    &mut entry.info.custom_state,
563                );
564
565                new_node_processors.push(NodeHeapData {
566                    id: entry.id,
567                    processor: entry.dyn_node.construct_processor(cx),
568                    event_buffer_indices,
569                });
570            }
571        }
572
573        let mut nodes_to_remove = Vec::new();
574        core::mem::swap(
575            &mut self.nodes_to_remove_from_schedule,
576            &mut nodes_to_remove,
577        );
578
579        let schedule_data = Box::new(ScheduleHeapData::new(
580            schedule,
581            nodes_to_remove,
582            new_node_processors,
583        ));
584
585        self.needs_compile = false;
586
587        log::debug!("compiled new audio graph: {:?}", &schedule_data);
588
589        Ok(schedule_data)
590    }
591
592    fn compile_internal(
593        &mut self,
594        max_block_frames: usize,
595    ) -> Result<CompiledSchedule, CompileGraphError> {
596        assert!(max_block_frames > 0);
597
598        compiler::compile(
599            &mut self.nodes,
600            &mut self.edges,
601            self.graph_in_id,
602            self.graph_out_id,
603            max_block_frames,
604        )
605    }
606
607    pub(crate) fn update(
608        &mut self,
609        stream_info: Option<&StreamInfo>,
610        event_queue: &mut Vec<NodeEvent>,
611    ) {
612        let mut cull_list = false;
613        for node_id in self.nodes_to_call_update_method.iter() {
614            if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
615                node_entry.dyn_node.update(UpdateContext::new(
616                    *node_id,
617                    stream_info,
618                    &mut node_entry.info.custom_state,
619                    event_queue,
620                ));
621            } else {
622                cull_list = true;
623            }
624        }
625
626        if cull_list {
627            self.nodes_to_call_update_method
628                .retain(|node_id| self.nodes.contains(node_id.0));
629        }
630    }
631}