1use core::any::Any;
2use core::fmt::Debug;
3use core::hash::Hash;
4
5#[cfg(not(feature = "std"))]
6use bevy_platform::prelude::{Box, Vec};
7
8use bevy_platform::collections::HashMap;
9use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
10use firewheel_core::event::NodeEvent;
11use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
12use firewheel_core::StreamInfo;
13use smallvec::SmallVec;
14use thunderdome::Arena;
15
16use crate::error::{AddEdgeError, CompileGraphError, RemoveNodeError};
17use crate::graph::dummy_node::{DummyNode, DummyNodeConfig};
18use crate::FirewheelConfig;
19use firewheel_core::node::{
20 AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
21};
22
23pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
24
25pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
26
27mod compiler;
28mod dummy_node;
29
30#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
31struct EdgeHash {
32 pub src_node: NodeID,
33 pub dst_node: NodeID,
34 pub src_port: PortIdx,
35 pub dst_port: PortIdx,
36}
37
38pub(crate) struct AudioGraph {
40 nodes: Arena<NodeEntry>,
41 edges: Arena<Edge>,
42 existing_edges: HashMap<EdgeHash, EdgeID>,
43
44 graph_in_id: NodeID,
45 graph_out_id: NodeID,
46 needs_compile: bool,
47
48 nodes_to_remove_from_schedule: Vec<NodeID>,
49 active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
50 nodes_to_call_update_method: Vec<NodeID>,
51
52 prev_node_arena_capacity: usize,
53}
54
55impl AudioGraph {
56 pub fn new(config: &FirewheelConfig) -> Self {
57 let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
58
59 let graph_in_config = DummyNodeConfig {
60 channel_config: ChannelConfig {
61 num_inputs: ChannelCount::ZERO,
62 num_outputs: config.num_graph_inputs,
63 },
64 };
65 let graph_out_config = DummyNodeConfig {
66 channel_config: ChannelConfig {
67 num_inputs: config.num_graph_outputs,
68 num_outputs: ChannelCount::ZERO,
69 },
70 };
71
72 let graph_in_id = NodeID(
73 nodes.insert(NodeEntry::new(
74 AudioNodeInfo::new()
75 .debug_name("graph_in")
76 .channel_config(graph_in_config.channel_config)
77 .into(),
78 Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
79 )),
80 );
81 nodes[graph_in_id.0].id = graph_in_id;
82
83 let graph_out_id = NodeID(
84 nodes.insert(NodeEntry::new(
85 AudioNodeInfo::new()
86 .debug_name("graph_out")
87 .channel_config(graph_out_config.channel_config)
88 .into(),
89 Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
90 )),
91 );
92 nodes[graph_out_id.0].id = graph_out_id;
93
94 Self {
95 nodes,
96 edges: Arena::with_capacity(config.initial_edge_capacity as usize),
97 existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
98 graph_in_id,
99 graph_out_id,
100 needs_compile: true,
101 nodes_to_remove_from_schedule: Vec::with_capacity(
102 config.initial_node_capacity as usize,
103 ),
104 active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
105 nodes_to_call_update_method: Vec::new(),
106 prev_node_arena_capacity: 0,
107 }
108 }
109
110 pub fn graph_in_node(&self) -> NodeID {
112 self.graph_in_id
113 }
114
115 pub fn graph_out_node(&self) -> NodeID {
117 self.graph_out_id
118 }
119
120 pub fn add_node<T: AudioNode + 'static>(
122 &mut self,
123 node: T,
124 config: Option<T::Configuration>,
125 ) -> NodeID {
126 let constructor = Constructor::new(node, config);
127 let info: AudioNodeInfoInner = constructor.info().into();
128 let call_update_method = info.call_update_method;
129
130 let new_id = NodeID(
131 self.nodes
132 .insert(NodeEntry::new(info, Box::new(constructor))),
133 );
134 self.nodes[new_id.0].id = new_id;
135
136 if call_update_method {
137 self.nodes_to_call_update_method.push(new_id);
138 }
139
140 self.needs_compile = true;
141
142 new_id
143 }
144
145 pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
147 let info: AudioNodeInfoInner = node.info().into();
148 let call_update_method = info.call_update_method;
149
150 let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
151 self.nodes[new_id.0].id = new_id;
152
153 if call_update_method {
154 self.nodes_to_call_update_method.push(new_id);
155 }
156
157 self.needs_compile = true;
158
159 new_id
160 }
161
162 pub fn remove_node(
173 &mut self,
174 node_id: NodeID,
175 ) -> Result<SmallVec<[EdgeID; 4]>, RemoveNodeError> {
176 if node_id == self.graph_in_id {
177 return Err(RemoveNodeError::CannotRemoveGraphInNode);
178 }
179 if node_id == self.graph_out_id {
180 return Err(RemoveNodeError::CannotRemoveGraphOutNode);
181 }
182
183 let mut removed_edges = SmallVec::new();
184
185 let Some(node_entry) = self.nodes.remove(node_id.0) else {
186 return Ok(removed_edges);
187 };
188
189 for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
190 removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
191 }
192 for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
193 removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
194 }
195
196 self.nodes_to_remove_from_schedule.push(node_id);
197 self.active_nodes_to_remove.insert(node_id, node_entry);
198
199 self.needs_compile = true;
200
201 Ok(removed_edges)
202 }
203
204 pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
206 self.nodes.get(id.0)
207 }
208
209 pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
211 self.node_state_dyn(id).and_then(|s| s.downcast_ref())
212 }
213
214 pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
216 self.nodes
217 .get(id.0)
218 .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
219 }
220
221 pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
223 self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
224 }
225
226 pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
228 self.nodes
229 .get_mut(id.0)
230 .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
231 }
232
233 pub fn nodes(&self) -> impl Iterator<Item = &NodeEntry> {
235 self.nodes.iter().map(|(_, n)| n)
236 }
237
238 pub fn edges(&self) -> impl Iterator<Item = &Edge> {
240 self.edges.iter().map(|(_, e)| e)
241 }
242
243 pub fn set_graph_channel_config(
247 &mut self,
248 channel_config: ChannelConfig,
249 ) -> SmallVec<[EdgeID; 4]> {
250 let mut removed_edges = SmallVec::new();
251
252 let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
253 if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
254 let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
255 graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
256
257 if channel_config.num_inputs < old_num_inputs {
258 for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
259 removed_edges.append(
260 &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
261 );
262 }
263 }
264
265 self.needs_compile = true;
266 }
267
268 let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
269
270 if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
271 let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
272 graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
273
274 if channel_config.num_outputs < old_num_outputs {
275 for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
276 removed_edges.append(
277 &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
278 );
279 }
280 }
281
282 self.needs_compile = true;
283 }
284
285 removed_edges
286 }
287
288 pub fn connect(
306 &mut self,
307 src_node: NodeID,
308 dst_node: NodeID,
309 ports_src_dst: &[(PortIdx, PortIdx)],
310 check_for_cycles: bool,
311 ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
312 let src_node_entry = self
313 .nodes
314 .get(src_node.0)
315 .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
316 let dst_node_entry = self
317 .nodes
318 .get(dst_node.0)
319 .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
320
321 if src_node.0 == dst_node.0 {
322 return Err(AddEdgeError::CycleDetected);
323 }
324
325 for (src_port, dst_port) in ports_src_dst.iter().copied() {
326 if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
327 return Err(AddEdgeError::OutPortOutOfRange {
328 node: src_node,
329 port_idx: src_port,
330 num_out_ports: src_node_entry.info.channel_config.num_outputs,
331 });
332 }
333 if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
334 return Err(AddEdgeError::InPortOutOfRange {
335 node: dst_node,
336 port_idx: dst_port,
337 num_in_ports: dst_node_entry.info.channel_config.num_inputs,
338 });
339 }
340 }
341
342 let mut edge_ids = SmallVec::new();
343
344 for (src_port, dst_port) in ports_src_dst.iter().copied() {
345 if let Some(id) = self.existing_edges.get(&EdgeHash {
346 src_node,
347 src_port,
348 dst_node,
349 dst_port,
350 }) {
351 edge_ids.push(*id);
353 continue;
354 }
355
356 let new_edge_id = EdgeID(self.edges.insert(Edge {
357 id: EdgeID(thunderdome::Index::DANGLING),
358 src_node,
359 src_port,
360 dst_node,
361 dst_port,
362 }));
363 self.edges[new_edge_id.0].id = new_edge_id;
364 self.existing_edges.insert(
365 EdgeHash {
366 src_node,
367 src_port,
368 dst_node,
369 dst_port,
370 },
371 new_edge_id,
372 );
373
374 edge_ids.push(new_edge_id);
375 }
376
377 if check_for_cycles && self.cycle_detected() {
378 self.disconnect(src_node, dst_node, ports_src_dst);
379
380 return Err(AddEdgeError::CycleDetected);
381 }
382
383 self.needs_compile = true;
384
385 Ok(edge_ids)
386 }
387
388 pub fn disconnect(
399 &mut self,
400 src_node: NodeID,
401 dst_node: NodeID,
402 ports_src_dst: &[(PortIdx, PortIdx)],
403 ) -> bool {
404 let mut any_removed = false;
405
406 for (src_port, dst_port) in ports_src_dst.iter().copied() {
407 if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
408 src_node,
409 src_port,
410 dst_node,
411 dst_port,
412 }) {
413 self.disconnect_by_edge_id(edge_id);
414 any_removed = true;
415 }
416 }
417
418 any_removed
419 }
420
421 pub fn disconnect_all_between(
426 &mut self,
427 src_node: NodeID,
428 dst_node: NodeID,
429 ) -> SmallVec<[EdgeID; 4]> {
430 let mut removed_edges = SmallVec::new();
431
432 if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
433 return removed_edges;
434 };
435
436 for (edge_id, edge) in self.edges.iter() {
437 if edge.src_node == src_node && edge.dst_node == dst_node {
438 removed_edges.push(EdgeID(edge_id));
439 }
440 }
441
442 for &edge_id in removed_edges.iter() {
443 self.disconnect_by_edge_id(edge_id);
444 }
445
446 removed_edges
447 }
448
449 pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
453 if let Some(edge) = self.edges.remove(edge_id.0) {
454 self.existing_edges.remove(&EdgeHash {
455 src_node: edge.src_node,
456 src_port: edge.src_port,
457 dst_node: edge.dst_node,
458 dst_port: edge.dst_port,
459 });
460
461 self.needs_compile = true;
462
463 true
464 } else {
465 false
466 }
467 }
468
469 pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
471 self.edges.get(edge_id.0)
472 }
473
474 fn remove_edges_with_input_port(
475 &mut self,
476 node_id: NodeID,
477 port_idx: PortIdx,
478 ) -> SmallVec<[EdgeID; 4]> {
479 let mut edges_to_remove = SmallVec::new();
480
481 for (edge_id, edge) in self.edges.iter() {
483 if edge.dst_node == node_id && edge.dst_port == port_idx {
484 edges_to_remove.push(EdgeID(edge_id));
485 }
486 }
487
488 for edge_id in edges_to_remove.iter() {
489 self.disconnect_by_edge_id(*edge_id);
490 }
491
492 edges_to_remove
493 }
494
495 fn remove_edges_with_output_port(
496 &mut self,
497 node_id: NodeID,
498 port_idx: PortIdx,
499 ) -> SmallVec<[EdgeID; 4]> {
500 let mut edges_to_remove = SmallVec::new();
501
502 for (edge_id, edge) in self.edges.iter() {
504 if edge.src_node == node_id && edge.src_port == port_idx {
505 edges_to_remove.push(EdgeID(edge_id));
506 }
507 }
508
509 for edge_id in edges_to_remove.iter() {
510 self.disconnect_by_edge_id(*edge_id);
511 }
512
513 edges_to_remove
514 }
515
516 pub fn cycle_detected(&mut self) -> bool {
517 compiler::cycle_detected(
518 &mut self.nodes,
519 &mut self.edges,
520 self.graph_in_id,
521 self.graph_out_id,
522 )
523 }
524
525 pub(crate) fn needs_compile(&self) -> bool {
526 self.needs_compile
527 }
528
529 pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
530 self.needs_compile = true;
531
532 for node in failed_schedule.new_node_processors.iter() {
533 if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
534 node_entry.processor_constructed = false;
535 }
536 }
537 }
538
539 pub(crate) fn deactivate(&mut self) {
540 self.needs_compile = true;
541 }
542
543 pub(crate) fn compile(
544 &mut self,
545 stream_info: &StreamInfo,
546 ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
547 let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
548
549 let mut new_node_processors = Vec::new();
550 for (_, entry) in self.nodes.iter_mut() {
551 if !entry.processor_constructed {
552 entry.processor_constructed = true;
553
554 let cx = ConstructProcessorContext::new(
555 entry.id,
556 stream_info,
557 &mut entry.info.custom_state,
558 );
559
560 new_node_processors.push(NodeHeapData {
561 id: entry.id,
562 processor: entry.dyn_node.construct_processor(cx),
563 is_pre_process: entry.info.channel_config.is_empty(),
564 });
565 }
566 }
567
568 let mut nodes_to_remove = Vec::new();
569 core::mem::swap(
570 &mut self.nodes_to_remove_from_schedule,
571 &mut nodes_to_remove,
572 );
573
574 let new_arena = if self.nodes.capacity() > self.prev_node_arena_capacity {
575 Some(Arena::with_capacity(self.nodes.capacity()))
576 } else {
577 None
578 };
579 self.prev_node_arena_capacity = self.nodes.capacity();
580
581 let schedule_data = Box::new(ScheduleHeapData::new(
582 schedule,
583 nodes_to_remove,
584 new_node_processors,
585 new_arena,
586 ));
587
588 self.needs_compile = false;
589
590 #[cfg(feature = "tracing")]
591 tracing::debug!("compiled new audio graph: {:?}", &schedule_data);
592
593 #[cfg(all(feature = "log", not(feature = "tracing")))]
594 log::debug!("compiled new audio graph: {:?}", &schedule_data);
595
596 Ok(schedule_data)
597 }
598
599 fn compile_internal(
600 &mut self,
601 max_block_frames: usize,
602 ) -> Result<CompiledSchedule, CompileGraphError> {
603 assert!(max_block_frames > 0);
604
605 compiler::compile(
606 &mut self.nodes,
607 &mut self.edges,
608 self.graph_in_id,
609 self.graph_out_id,
610 max_block_frames,
611 )
612 }
613
614 pub(crate) fn update(
615 &mut self,
616 stream_info: Option<&StreamInfo>,
617 event_queue: &mut Vec<NodeEvent>,
618 ) {
619 let mut cull_list = false;
620 for node_id in self.nodes_to_call_update_method.iter() {
621 if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
622 node_entry.dyn_node.update(UpdateContext::new(
623 *node_id,
624 stream_info,
625 &mut node_entry.info.custom_state,
626 event_queue,
627 ));
628 } else {
629 cull_list = true;
630 }
631 }
632
633 if cull_list {
634 self.nodes_to_call_update_method
635 .retain(|node_id| self.nodes.contains(node_id.0));
636 }
637 }
638}