1mod compiler;
2
3use core::any::Any;
4use core::fmt::Debug;
5use core::hash::Hash;
6
7#[cfg(not(feature = "std"))]
8use bevy_platform::prelude::{Box, Vec};
9
10use bevy_platform::collections::HashMap;
11use firewheel_core::channel_config::{ChannelConfig, ChannelCount};
12use firewheel_core::event::NodeEvent;
13use firewheel_core::node::{ConstructProcessorContext, UpdateContext};
14use firewheel_core::StreamInfo;
15use smallvec::SmallVec;
16use thunderdome::Arena;
17
18use crate::error::{AddEdgeError, CompileGraphError, RemoveNodeError};
19use crate::FirewheelConfig;
20use firewheel_core::node::{
21 dummy::{DummyNode, DummyNodeConfig},
22 AudioNode, AudioNodeInfo, AudioNodeInfoInner, Constructor, DynAudioNode, NodeID,
23};
24
25pub(crate) use self::compiler::{CompiledSchedule, NodeHeapData, ScheduleHeapData};
26
27pub use self::compiler::{Edge, EdgeID, NodeEntry, PortIdx};
28
29#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
30struct EdgeHash {
31 pub src_node: NodeID,
32 pub dst_node: NodeID,
33 pub src_port: PortIdx,
34 pub dst_port: PortIdx,
35}
36
37pub(crate) struct AudioGraph {
39 nodes: Arena<NodeEntry>,
40 edges: Arena<Edge>,
41 existing_edges: HashMap<EdgeHash, EdgeID>,
42
43 graph_in_id: NodeID,
44 graph_out_id: NodeID,
45 needs_compile: bool,
46
47 nodes_to_remove_from_schedule: Vec<NodeID>,
48 active_nodes_to_remove: HashMap<NodeID, NodeEntry>,
49 nodes_to_call_update_method: Vec<NodeID>,
50
51 prev_node_arena_capacity: usize,
52}
53
54impl AudioGraph {
55 pub fn new(config: &FirewheelConfig) -> Self {
56 let mut nodes = Arena::with_capacity(config.initial_node_capacity as usize);
57
58 let graph_in_config = DummyNodeConfig {
59 channel_config: ChannelConfig {
60 num_inputs: ChannelCount::ZERO,
61 num_outputs: config.num_graph_inputs,
62 },
63 };
64 let graph_out_config = DummyNodeConfig {
65 channel_config: ChannelConfig {
66 num_inputs: config.num_graph_outputs,
67 num_outputs: ChannelCount::ZERO,
68 },
69 };
70
71 let graph_in_id = NodeID(
72 nodes.insert(NodeEntry::new(
73 AudioNodeInfo::new()
74 .debug_name("graph_in")
75 .channel_config(graph_in_config.channel_config)
76 .into(),
77 Box::new(Constructor::new(DummyNode, Some(graph_in_config))),
78 )),
79 );
80 nodes[graph_in_id.0].id = graph_in_id;
81
82 let graph_out_id = NodeID(
83 nodes.insert(NodeEntry::new(
84 AudioNodeInfo::new()
85 .debug_name("graph_out")
86 .channel_config(graph_out_config.channel_config)
87 .into(),
88 Box::new(Constructor::new(DummyNode, Some(graph_out_config))),
89 )),
90 );
91 nodes[graph_out_id.0].id = graph_out_id;
92
93 Self {
94 nodes,
95 edges: Arena::with_capacity(config.initial_edge_capacity as usize),
96 existing_edges: HashMap::with_capacity(config.initial_edge_capacity as usize),
97 graph_in_id,
98 graph_out_id,
99 needs_compile: true,
100 nodes_to_remove_from_schedule: Vec::with_capacity(
101 config.initial_node_capacity as usize,
102 ),
103 active_nodes_to_remove: HashMap::with_capacity(config.initial_node_capacity as usize),
104 nodes_to_call_update_method: Vec::new(),
105 prev_node_arena_capacity: 0,
106 }
107 }
108
109 pub fn graph_in_node(&self) -> NodeID {
111 self.graph_in_id
112 }
113
114 pub fn graph_out_node(&self) -> NodeID {
116 self.graph_out_id
117 }
118
119 pub fn add_node<T: AudioNode + 'static>(
121 &mut self,
122 node: T,
123 config: Option<T::Configuration>,
124 ) -> NodeID {
125 let constructor = Constructor::new(node, config);
126 let info: AudioNodeInfoInner = constructor.info().into();
127 let call_update_method = info.call_update_method;
128
129 let new_id = NodeID(
130 self.nodes
131 .insert(NodeEntry::new(info, Box::new(constructor))),
132 );
133 self.nodes[new_id.0].id = new_id;
134
135 if call_update_method {
136 self.nodes_to_call_update_method.push(new_id);
137 }
138
139 self.needs_compile = true;
140
141 new_id
142 }
143
144 pub fn add_dyn_node<T: DynAudioNode + 'static>(&mut self, node: T) -> NodeID {
146 let info: AudioNodeInfoInner = node.info().into();
147 let call_update_method = info.call_update_method;
148
149 let new_id = NodeID(self.nodes.insert(NodeEntry::new(info, Box::new(node))));
150 self.nodes[new_id.0].id = new_id;
151
152 if call_update_method {
153 self.nodes_to_call_update_method.push(new_id);
154 }
155
156 self.needs_compile = true;
157
158 new_id
159 }
160
161 pub fn remove_node(
172 &mut self,
173 node_id: NodeID,
174 ) -> Result<SmallVec<[EdgeID; 4]>, RemoveNodeError> {
175 if node_id == self.graph_in_id {
176 return Err(RemoveNodeError::CannotRemoveGraphInNode);
177 }
178 if node_id == self.graph_out_id {
179 return Err(RemoveNodeError::CannotRemoveGraphOutNode);
180 }
181
182 let mut removed_edges = SmallVec::new();
183
184 let Some(node_entry) = self.nodes.remove(node_id.0) else {
185 return Ok(removed_edges);
186 };
187
188 for port_idx in 0..node_entry.info.channel_config.num_inputs.get() {
189 removed_edges.append(&mut self.remove_edges_with_input_port(node_id, port_idx));
190 }
191 for port_idx in 0..node_entry.info.channel_config.num_outputs.get() {
192 removed_edges.append(&mut self.remove_edges_with_output_port(node_id, port_idx));
193 }
194
195 self.nodes_to_remove_from_schedule.push(node_id);
196 self.active_nodes_to_remove.insert(node_id, node_entry);
197
198 self.needs_compile = true;
199
200 Ok(removed_edges)
201 }
202
203 pub fn node_info(&self, id: NodeID) -> Option<&NodeEntry> {
205 self.nodes.get(id.0)
206 }
207
208 pub fn node_state<T: 'static>(&self, id: NodeID) -> Option<&T> {
210 self.node_state_dyn(id).and_then(|s| s.downcast_ref())
211 }
212
213 pub fn node_state_dyn(&self, id: NodeID) -> Option<&dyn Any> {
215 self.nodes
216 .get(id.0)
217 .and_then(|node_entry| node_entry.info.custom_state.as_ref().map(|s| s.as_ref()))
218 }
219
220 pub fn node_state_mut<T: 'static>(&mut self, id: NodeID) -> Option<&mut T> {
222 self.node_state_dyn_mut(id).and_then(|s| s.downcast_mut())
223 }
224
225 pub fn node_state_dyn_mut(&mut self, id: NodeID) -> Option<&mut dyn Any> {
227 self.nodes
228 .get_mut(id.0)
229 .and_then(|node_entry| node_entry.info.custom_state.as_mut().map(|s| s.as_mut()))
230 }
231
232 pub fn nodes<'a>(&'a self) -> impl Iterator<Item = &'a NodeEntry> {
234 self.nodes.iter().map(|(_, n)| n)
235 }
236
237 pub fn edges<'a>(&'a self) -> impl Iterator<Item = &'a Edge> {
239 self.edges.iter().map(|(_, e)| e)
240 }
241
242 pub fn set_graph_channel_config(
246 &mut self,
247 channel_config: ChannelConfig,
248 ) -> SmallVec<[EdgeID; 4]> {
249 let mut removed_edges = SmallVec::new();
250
251 let graph_in_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
252 if channel_config.num_inputs != graph_in_node.info.channel_config.num_outputs {
253 let old_num_inputs = graph_in_node.info.channel_config.num_outputs;
254 graph_in_node.info.channel_config.num_outputs = channel_config.num_inputs;
255
256 if channel_config.num_inputs < old_num_inputs {
257 for port_idx in channel_config.num_inputs.get()..old_num_inputs.get() {
258 removed_edges.append(
259 &mut self.remove_edges_with_output_port(self.graph_in_id, port_idx),
260 );
261 }
262 }
263
264 self.needs_compile = true;
265 }
266
267 let graph_out_node = self.nodes.get_mut(self.graph_in_id.0).unwrap();
268
269 if channel_config.num_outputs != graph_out_node.info.channel_config.num_inputs {
270 let old_num_outputs = graph_out_node.info.channel_config.num_inputs;
271 graph_out_node.info.channel_config.num_inputs = channel_config.num_outputs;
272
273 if channel_config.num_outputs < old_num_outputs {
274 for port_idx in channel_config.num_outputs.get()..old_num_outputs.get() {
275 removed_edges.append(
276 &mut self.remove_edges_with_input_port(self.graph_out_id, port_idx),
277 );
278 }
279 }
280
281 self.needs_compile = true;
282 }
283
284 removed_edges
285 }
286
287 pub fn connect(
305 &mut self,
306 src_node: NodeID,
307 dst_node: NodeID,
308 ports_src_dst: &[(PortIdx, PortIdx)],
309 check_for_cycles: bool,
310 ) -> Result<SmallVec<[EdgeID; 4]>, AddEdgeError> {
311 let src_node_entry = self
312 .nodes
313 .get(src_node.0)
314 .ok_or(AddEdgeError::SrcNodeNotFound(src_node))?;
315 let dst_node_entry = self
316 .nodes
317 .get(dst_node.0)
318 .ok_or(AddEdgeError::DstNodeNotFound(dst_node))?;
319
320 if src_node.0 == dst_node.0 {
321 return Err(AddEdgeError::CycleDetected);
322 }
323
324 for (src_port, dst_port) in ports_src_dst.iter().copied() {
325 if src_port >= src_node_entry.info.channel_config.num_outputs.get() {
326 return Err(AddEdgeError::OutPortOutOfRange {
327 node: src_node,
328 port_idx: src_port,
329 num_out_ports: src_node_entry.info.channel_config.num_outputs,
330 });
331 }
332 if dst_port >= dst_node_entry.info.channel_config.num_inputs.get() {
333 return Err(AddEdgeError::InPortOutOfRange {
334 node: dst_node,
335 port_idx: dst_port,
336 num_in_ports: dst_node_entry.info.channel_config.num_inputs,
337 });
338 }
339 }
340
341 let mut edge_ids = SmallVec::new();
342
343 for (src_port, dst_port) in ports_src_dst.iter().copied() {
344 if let Some(id) = self.existing_edges.get(&EdgeHash {
345 src_node,
346 src_port,
347 dst_node,
348 dst_port,
349 }) {
350 edge_ids.push(*id);
352 continue;
353 }
354
355 let new_edge_id = EdgeID(self.edges.insert(Edge {
356 id: EdgeID(thunderdome::Index::DANGLING),
357 src_node,
358 src_port,
359 dst_node,
360 dst_port,
361 }));
362 self.edges[new_edge_id.0].id = new_edge_id;
363 self.existing_edges.insert(
364 EdgeHash {
365 src_node,
366 src_port,
367 dst_node,
368 dst_port,
369 },
370 new_edge_id,
371 );
372
373 edge_ids.push(new_edge_id);
374 }
375
376 if check_for_cycles {
377 if self.cycle_detected() {
378 self.disconnect(src_node, dst_node, ports_src_dst);
379
380 return Err(AddEdgeError::CycleDetected);
381 }
382 }
383
384 self.needs_compile = true;
385
386 Ok(edge_ids)
387 }
388
389 pub fn disconnect(
400 &mut self,
401 src_node: NodeID,
402 dst_node: NodeID,
403 ports_src_dst: &[(PortIdx, PortIdx)],
404 ) -> bool {
405 let mut any_removed = false;
406
407 for (src_port, dst_port) in ports_src_dst.iter().copied() {
408 if let Some(edge_id) = self.existing_edges.remove(&EdgeHash {
409 src_node,
410 src_port: src_port.into(),
411 dst_node,
412 dst_port: dst_port.into(),
413 }) {
414 self.disconnect_by_edge_id(edge_id);
415 any_removed = true;
416 }
417 }
418
419 any_removed
420 }
421
422 pub fn disconnect_all_between(
427 &mut self,
428 src_node: NodeID,
429 dst_node: NodeID,
430 ) -> SmallVec<[EdgeID; 4]> {
431 let mut removed_edges = SmallVec::new();
432
433 if !self.nodes.contains(src_node.0) || !self.nodes.contains(dst_node.0) {
434 return removed_edges;
435 };
436
437 for (edge_id, edge) in self.edges.iter() {
438 if edge.src_node == src_node && edge.dst_node == dst_node {
439 removed_edges.push(EdgeID(edge_id));
440 }
441 }
442
443 for &edge_id in removed_edges.iter() {
444 self.disconnect_by_edge_id(edge_id);
445 }
446
447 removed_edges
448 }
449
450 pub fn disconnect_by_edge_id(&mut self, edge_id: EdgeID) -> bool {
454 if let Some(edge) = self.edges.remove(edge_id.0) {
455 self.existing_edges.remove(&EdgeHash {
456 src_node: edge.src_node,
457 src_port: edge.src_port,
458 dst_node: edge.dst_node,
459 dst_port: edge.dst_port,
460 });
461
462 self.needs_compile = true;
463
464 true
465 } else {
466 false
467 }
468 }
469
470 pub fn edge(&self, edge_id: EdgeID) -> Option<&Edge> {
472 self.edges.get(edge_id.0)
473 }
474
475 fn remove_edges_with_input_port(
476 &mut self,
477 node_id: NodeID,
478 port_idx: PortIdx,
479 ) -> SmallVec<[EdgeID; 4]> {
480 let mut edges_to_remove = SmallVec::new();
481
482 for (edge_id, edge) in self.edges.iter() {
484 if edge.dst_node == node_id && edge.dst_port == port_idx {
485 edges_to_remove.push(EdgeID(edge_id));
486 }
487 }
488
489 for edge_id in edges_to_remove.iter() {
490 self.disconnect_by_edge_id(*edge_id);
491 }
492
493 edges_to_remove
494 }
495
496 fn remove_edges_with_output_port(
497 &mut self,
498 node_id: NodeID,
499 port_idx: PortIdx,
500 ) -> SmallVec<[EdgeID; 4]> {
501 let mut edges_to_remove = SmallVec::new();
502
503 for (edge_id, edge) in self.edges.iter() {
505 if edge.src_node == node_id && edge.src_port == port_idx {
506 edges_to_remove.push(EdgeID(edge_id));
507 }
508 }
509
510 for edge_id in edges_to_remove.iter() {
511 self.disconnect_by_edge_id(*edge_id);
512 }
513
514 edges_to_remove
515 }
516
517 pub fn cycle_detected(&mut self) -> bool {
518 compiler::cycle_detected(
519 &mut self.nodes,
520 &mut self.edges,
521 self.graph_in_id,
522 self.graph_out_id,
523 )
524 }
525
526 pub(crate) fn needs_compile(&self) -> bool {
527 self.needs_compile
528 }
529
530 pub(crate) fn on_schedule_send_failed(&mut self, failed_schedule: Box<ScheduleHeapData>) {
531 self.needs_compile = true;
532
533 for node in failed_schedule.new_node_processors.iter() {
534 if let Some(node_entry) = &mut self.nodes.get_mut(node.id.0) {
535 node_entry.processor_constructed = false;
536 }
537 }
538 }
539
540 pub(crate) fn deactivate(&mut self) {
541 self.needs_compile = true;
542 }
543
544 pub(crate) fn compile(
545 &mut self,
546 stream_info: &StreamInfo,
547 ) -> Result<Box<ScheduleHeapData>, CompileGraphError> {
548 let schedule = self.compile_internal(stream_info.max_block_frames.get() as usize)?;
549
550 let mut new_node_processors = Vec::new();
551 for (_, entry) in self.nodes.iter_mut() {
552 if !entry.processor_constructed {
553 entry.processor_constructed = true;
554
555 let cx = ConstructProcessorContext::new(
556 entry.id,
557 stream_info,
558 &mut entry.info.custom_state,
559 );
560
561 new_node_processors.push(NodeHeapData {
562 id: entry.id,
563 processor: entry.dyn_node.construct_processor(cx),
564 });
565 }
566 }
567
568 let mut nodes_to_remove = Vec::new();
569 core::mem::swap(
570 &mut self.nodes_to_remove_from_schedule,
571 &mut nodes_to_remove,
572 );
573
574 let new_arena = if self.nodes.capacity() > self.prev_node_arena_capacity {
575 Some(Arena::with_capacity(self.nodes.capacity()))
576 } else {
577 None
578 };
579 self.prev_node_arena_capacity = self.nodes.capacity();
580
581 let schedule_data = Box::new(ScheduleHeapData::new(
582 schedule,
583 nodes_to_remove,
584 new_node_processors,
585 new_arena,
586 ));
587
588 self.needs_compile = false;
589
590 log::debug!("compiled new audio graph: {:?}", &schedule_data);
591
592 Ok(schedule_data)
593 }
594
595 fn compile_internal(
596 &mut self,
597 max_block_frames: usize,
598 ) -> Result<CompiledSchedule, CompileGraphError> {
599 assert!(max_block_frames > 0);
600
601 compiler::compile(
602 &mut self.nodes,
603 &mut self.edges,
604 self.graph_in_id,
605 self.graph_out_id,
606 max_block_frames,
607 )
608 }
609
610 pub(crate) fn update(
611 &mut self,
612 stream_info: Option<&StreamInfo>,
613 event_queue: &mut Vec<NodeEvent>,
614 ) {
615 let mut cull_list = false;
616 for node_id in self.nodes_to_call_update_method.iter() {
617 if let Some(node_entry) = self.nodes.get_mut(node_id.0) {
618 node_entry.dyn_node.update(UpdateContext::new(
619 *node_id,
620 stream_info,
621 &mut node_entry.info.custom_state,
622 event_queue,
623 ));
624 } else {
625 cull_list = true;
626 }
627 }
628
629 if cull_list {
630 self.nodes_to_call_update_method
631 .retain(|node_id| self.nodes.contains(node_id.0));
632 }
633 }
634}