1use core::any::Any;
2use core::fmt::Debug;
3use core::hash::Hash;
4
5#[cfg(not(feature = "std"))]
6use bevy_platform::prelude::{Box, Vec};
7
8use bevy_platform::collections::HashMap;
9use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
10use firewheel_core::event::NodeEvent;
11use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
12use firewheel_core::StreamInfo;
13use smallvec::SmallVec;
14use thunderdome::Arena;
15
16use crate::error::{AddEdgeError, CompileGraphError, RemoveNodeError};
17use crate::graph::dummy_node::{DummyNode, DummyNodeConfig};
18use crate::FirewheelConfig;
19use firewheel_core::node::{
20 AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
21};
22
23pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
24
25pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
26
27mod compiler;
28mod dummy_node;
29
30#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
31struct EdgeHash {
32 pub src_node: NodeID,
33 pub dst_node: NodeID,
34 pub src_port: PortIdx,
35 pub dst_port: PortIdx,
36}
37
38pub(crate) struct AudioGraph {
40 nodes: Arena<NodeEntry>,
41 edges: Arena<Edge>,
42 existing_edges: HashMap<EdgeHash, EdgeID>,
43
44 graph_in_id: NodeID,
45 graph_out_id: NodeID,
46 needs_compile: bool,
47
48 nodes_to_remove_from_schedule: Vec<NodeID>,
49 active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
50 nodes_to_call_update_method: Vec<NodeID>,
51
52 prev_node_arena_capacity: usize,
53}
54
55impl AudioGraph {
56 pub fn new(config: &FirewheelConfig) -> Self {
57 let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
58
59 let graph_in_config = DummyNodeConfig {
60 channel_config: ChannelConfig {
61 num_inputs: ChannelCount::ZERO,
62 num_outputs: config.num_graph_inputs,
63 },
64 };
65 let graph_out_config = DummyNodeConfig {
66 channel_config: ChannelConfig {
67 num_inputs: config.num_graph_outputs,
68 num_outputs: ChannelCount::ZERO,
69 },
70 };
71
72 let graph_in_id = NodeID(
73 nodes.insert(NodeEntry::new(
74 AudioNodeInfo::new()
75 .debug_name("graph_in")
76 .channel_config(graph_in_config.channel_config)
77 .into(),
78 Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
79 )),
80 );
81 nodes[graph_in_id.0].id = graph_in_id;
82
83 let graph_out_id = NodeID(
84 nodes.insert(NodeEntry::new(
85 AudioNodeInfo::new()
86 .debug_name("graph_out")
87 .channel_config(graph_out_config.channel_config)
88 .into(),
89 Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
90 )),
91 );
92 nodes[graph_out_id.0].id = graph_out_id;
93
94 Self {
95 nodes,
96 edges: Arena::with_capacity(config.initial_edge_capacity as usize),
97 existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
98 graph_in_id,
99 graph_out_id,
100 needs_compile: true,
101 nodes_to_remove_from_schedule: Vec::with_capacity(
102 config.initial_node_capacity as usize,
103 ),
104 active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
105 nodes_to_call_update_method: Vec::new(),
106 prev_node_arena_capacity: 0,
107 }
108 }
109
110 pub fn graph_in_node(&self) -> NodeID {
112 self.graph_in_id
113 }
114
115 pub fn graph_out_node(&self) -> NodeID {
117 self.graph_out_id
118 }
119
120 pub fn add_node<T: AudioNode + 'static>(
122 &mut self,
123 node: T,
124 config: Option<T::Configuration>,
125 ) -> NodeID {
126 let constructor = Constructor::new(node, config);
127 let info: AudioNodeInfoInner = constructor.info().into();
128 let call_update_method = info.call_update_method;
129
130 let new_id = NodeID(
131 self.nodes
132 .insert(NodeEntry::new(info, Box::new(constructor))),
133 );
134 self.nodes[new_id.0].id = new_id;
135
136 if call_update_method {
137 self.nodes_to_call_update_method.push(new_id);
138 }
139
140 self.needs_compile = true;
141
142 new_id
143 }
144
145 pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
147 let info: AudioNodeInfoInner = node.info().into();
148 let call_update_method = info.call_update_method;
149
150 let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
151 self.nodes[new_id.0].id = new_id;
152
153 if call_update_method {
154 self.nodes_to_call_update_method.push(new_id);
155 }
156
157 self.needs_compile = true;
158
159 new_id
160 }
161
162 pub fn remove_node(
173 &mut self,
174 node_id: NodeID,
175 ) -> Result<SmallVec<[EdgeID; 4]>, RemoveNodeError> {
176 if node_id == self.graph_in_id {
177 return Err(RemoveNodeError::CannotRemoveGraphInNode);
178 }
179 if node_id == self.graph_out_id {
180 return Err(RemoveNodeError::CannotRemoveGraphOutNode);
181 }
182
183 let mut removed_edges = SmallVec::new();
184
185 let Some(node_entry) = self.nodes.remove(node_id.0) else {
186 return Ok(removed_edges);
187 };
188
189 for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
190 removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
191 }
192 for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
193 removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
194 }
195
196 self.nodes_to_remove_from_schedule.push(node_id);
197 self.active_nodes_to_remove.insert(node_id, node_entry);
198
199 self.needs_compile = true;
200
201 Ok(removed_edges)
202 }
203
204 pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
206 self.nodes.get(id.0)
207 }
208
209 pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
211 self.node_state_dyn(id).and_then(|s| s.downcast_ref())
212 }
213
214 pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
216 self.nodes
217 .get(id.0)
218 .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
219 }
220
221 pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
223 self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
224 }
225
226 pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
228 self.nodes
229 .get_mut(id.0)
230 .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
231 }
232
233 pub fn nodes<'a>(&'a self) -> impl Iterator<Item = &'a NodeEntry> {
235 self.nodes.iter().map(|(_, n)| n)
236 }
237
238 pub fn edges<'a>(&'a self) -> impl Iterator<Item = &'a Edge> {
240 self.edges.iter().map(|(_, e)| e)
241 }
242
243 pub fn set_graph_channel_config(
247 &mut self,
248 channel_config: ChannelConfig,
249 ) -> SmallVec<[EdgeID; 4]> {
250 let mut removed_edges = SmallVec::new();
251
252 let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
253 if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
254 let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
255 graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
256
257 if channel_config.num_inputs < old_num_inputs {
258 for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
259 removed_edges.append(
260 &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
261 );
262 }
263 }
264
265 self.needs_compile = true;
266 }
267
268 let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
269
270 if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
271 let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
272 graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
273
274 if channel_config.num_outputs < old_num_outputs {
275 for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
276 removed_edges.append(
277 &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
278 );
279 }
280 }
281
282 self.needs_compile = true;
283 }
284
285 removed_edges
286 }
287
288 pub fn connect(
306 &mut self,
307 src_node: NodeID,
308 dst_node: NodeID,
309 ports_src_dst: &[(PortIdx, PortIdx)],
310 check_for_cycles: bool,
311 ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
312 let src_node_entry = self
313 .nodes
314 .get(src_node.0)
315 .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
316 let dst_node_entry = self
317 .nodes
318 .get(dst_node.0)
319 .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
320
321 if src_node.0 == dst_node.0 {
322 return Err(AddEdgeError::CycleDetected);
323 }
324
325 for (src_port, dst_port) in ports_src_dst.iter().copied() {
326 if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
327 return Err(AddEdgeError::OutPortOutOfRange {
328 node: src_node,
329 port_idx: src_port,
330 num_out_ports: src_node_entry.info.channel_config.num_outputs,
331 });
332 }
333 if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
334 return Err(AddEdgeError::InPortOutOfRange {
335 node: dst_node,
336 port_idx: dst_port,
337 num_in_ports: dst_node_entry.info.channel_config.num_inputs,
338 });
339 }
340 }
341
342 let mut edge_ids = SmallVec::new();
343
344 for (src_port, dst_port) in ports_src_dst.iter().copied() {
345 if let Some(id) = self.existing_edges.get(&EdgeHash {
346 src_node,
347 src_port,
348 dst_node,
349 dst_port,
350 }) {
351 edge_ids.push(*id);
353 continue;
354 }
355
356 let new_edge_id = EdgeID(self.edges.insert(Edge {
357 id: EdgeID(thunderdome::Index::DANGLING),
358 src_node,
359 src_port,
360 dst_node,
361 dst_port,
362 }));
363 self.edges[new_edge_id.0].id = new_edge_id;
364 self.existing_edges.insert(
365 EdgeHash {
366 src_node,
367 src_port,
368 dst_node,
369 dst_port,
370 },
371 new_edge_id,
372 );
373
374 edge_ids.push(new_edge_id);
375 }
376
377 if check_for_cycles {
378 if self.cycle_detected() {
379 self.disconnect(src_node, dst_node, ports_src_dst);
380
381 return Err(AddEdgeError::CycleDetected);
382 }
383 }
384
385 self.needs_compile = true;
386
387 Ok(edge_ids)
388 }
389
390 pub fn disconnect(
401 &mut self,
402 src_node: NodeID,
403 dst_node: NodeID,
404 ports_src_dst: &[(PortIdx, PortIdx)],
405 ) -> bool {
406 let mut any_removed = false;
407
408 for (src_port, dst_port) in ports_src_dst.iter().copied() {
409 if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
410 src_node,
411 src_port: src_port.into(),
412 dst_node,
413 dst_port: dst_port.into(),
414 }) {
415 self.disconnect_by_edge_id(edge_id);
416 any_removed = true;
417 }
418 }
419
420 any_removed
421 }
422
423 pub fn disconnect_all_between(
428 &mut self,
429 src_node: NodeID,
430 dst_node: NodeID,
431 ) -> SmallVec<[EdgeID; 4]> {
432 let mut removed_edges = SmallVec::new();
433
434 if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
435 return removed_edges;
436 };
437
438 for (edge_id, edge) in self.edges.iter() {
439 if edge.src_node == src_node && edge.dst_node == dst_node {
440 removed_edges.push(EdgeID(edge_id));
441 }
442 }
443
444 for &edge_id in removed_edges.iter() {
445 self.disconnect_by_edge_id(edge_id);
446 }
447
448 removed_edges
449 }
450
451 pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
455 if let Some(edge) = self.edges.remove(edge_id.0) {
456 self.existing_edges.remove(&EdgeHash {
457 src_node: edge.src_node,
458 src_port: edge.src_port,
459 dst_node: edge.dst_node,
460 dst_port: edge.dst_port,
461 });
462
463 self.needs_compile = true;
464
465 true
466 } else {
467 false
468 }
469 }
470
471 pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
473 self.edges.get(edge_id.0)
474 }
475
476 fn remove_edges_with_input_port(
477 &mut self,
478 node_id: NodeID,
479 port_idx: PortIdx,
480 ) -> SmallVec<[EdgeID; 4]> {
481 let mut edges_to_remove = SmallVec::new();
482
483 for (edge_id, edge) in self.edges.iter() {
485 if edge.dst_node == node_id && edge.dst_port == port_idx {
486 edges_to_remove.push(EdgeID(edge_id));
487 }
488 }
489
490 for edge_id in edges_to_remove.iter() {
491 self.disconnect_by_edge_id(*edge_id);
492 }
493
494 edges_to_remove
495 }
496
497 fn remove_edges_with_output_port(
498 &mut self,
499 node_id: NodeID,
500 port_idx: PortIdx,
501 ) -> SmallVec<[EdgeID; 4]> {
502 let mut edges_to_remove = SmallVec::new();
503
504 for (edge_id, edge) in self.edges.iter() {
506 if edge.src_node == node_id && edge.src_port == port_idx {
507 edges_to_remove.push(EdgeID(edge_id));
508 }
509 }
510
511 for edge_id in edges_to_remove.iter() {
512 self.disconnect_by_edge_id(*edge_id);
513 }
514
515 edges_to_remove
516 }
517
518 pub fn cycle_detected(&mut self) -> bool {
519 compiler::cycle_detected(
520 &mut self.nodes,
521 &mut self.edges,
522 self.graph_in_id,
523 self.graph_out_id,
524 )
525 }
526
527 pub(crate) fn needs_compile(&self) -> bool {
528 self.needs_compile
529 }
530
531 pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
532 self.needs_compile = true;
533
534 for node in failed_schedule.new_node_processors.iter() {
535 if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
536 node_entry.processor_constructed = false;
537 }
538 }
539 }
540
541 pub(crate) fn deactivate(&mut self) {
542 self.needs_compile = true;
543 }
544
545 pub(crate) fn compile(
546 &mut self,
547 stream_info: &StreamInfo,
548 ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
549 let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
550
551 let mut new_node_processors = Vec::new();
552 for (_, entry) in self.nodes.iter_mut() {
553 if !entry.processor_constructed {
554 entry.processor_constructed = true;
555
556 let cx = ConstructProcessorContext::new(
557 entry.id,
558 stream_info,
559 &mut entry.info.custom_state,
560 );
561
562 new_node_processors.push(NodeHeapData {
563 id: entry.id,
564 processor: entry.dyn_node.construct_processor(cx),
565 is_pre_process: entry.info.channel_config.is_empty(),
566 });
567 }
568 }
569
570 let mut nodes_to_remove = Vec::new();
571 core::mem::swap(
572 &mut self.nodes_to_remove_from_schedule,
573 &mut nodes_to_remove,
574 );
575
576 let new_arena = if self.nodes.capacity() > self.prev_node_arena_capacity {
577 Some(Arena::with_capacity(self.nodes.capacity()))
578 } else {
579 None
580 };
581 self.prev_node_arena_capacity = self.nodes.capacity();
582
583 let schedule_data = Box::new(ScheduleHeapData::new(
584 schedule,
585 nodes_to_remove,
586 new_node_processors,
587 new_arena,
588 ));
589
590 self.needs_compile = false;
591
592 log::debug!("compiled new audio graph: {:?}", &schedule_data);
593
594 Ok(schedule_data)
595 }
596
597 fn compile_internal(
598 &mut self,
599 max_block_frames: usize,
600 ) -> Result<CompiledSchedule, CompileGraphError> {
601 assert!(max_block_frames > 0);
602
603 compiler::compile(
604 &mut self.nodes,
605 &mut self.edges,
606 self.graph_in_id,
607 self.graph_out_id,
608 max_block_frames,
609 )
610 }
611
612 pub(crate) fn update(
613 &mut self,
614 stream_info: Option<&StreamInfo>,
615 event_queue: &mut Vec<NodeEvent>,
616 ) {
617 let mut cull_list = false;
618 for node_id in self.nodes_to_call_update_method.iter() {
619 if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
620 node_entry.dyn_node.update(UpdateContext::new(
621 *node_id,
622 stream_info,
623 &mut node_entry.info.custom_state,
624 event_queue,
625 ));
626 } else {
627 cull_list = true;
628 }
629 }
630
631 if cull_list {
632 self.nodes_to_call_update_method
633 .retain(|node_id| self.nodes.contains(node_id.0));
634 }
635 }
636}