1#![no_std]
4#![warn(missing_docs)]
5
6extern crate alloc;
7
8mod graph;
9mod message;
10mod signal;
11mod signal_id;
12mod tracker;
13
14pub mod traits;
16
17pub type Input = signal_id::SignalId;
19
20pub type Output = signal_id::SignalId;
22
23pub use message::Message;
24pub use signal::Signal;
25pub use tracker::{BasicTracker, DynamicTracker};
26
27use alloc::boxed::Box;
28use alloc::vec;
29use alloc::vec::Vec;
30use graph::topological_sort;
31use rustc_hash::FxHashMap;
32use signal_id::SignalId;
33use traits::{Source, Tracker};
34
35pub struct Screech<MessageData = ()> {
37 pub sample_rate: usize,
39 outs: Vec<SignalId>,
40 id: usize,
41 sorted_cache: Option<Vec<usize>>,
42 tracker: Box<dyn Tracker<MessageData>>,
43}
44
45unsafe impl<T> Send for Screech<T> {}
46
/// Errors that can be returned while processing the signal graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScreechError {
    /// The dependency graph between sources contains a cycle, so no sampling
    /// order exists.
    ///
    /// NOTE(review): `Screech::sample` does not currently return this variant;
    /// confirm whether `topological_sort` is expected to report cycles.
    CyclicDependencies,
    /// The tracker has no output buffer registered for a main output.
    MissingOutput,
    /// The tracker has no input entry registered for a main output.
    MissingInput,
}

impl core::fmt::Display for ScreechError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            ScreechError::CyclicDependencies => "cyclic dependencies between sources",
            ScreechError::MissingOutput => "missing output signal",
            ScreechError::MissingInput => "missing input signal",
        };
        f.write_str(message)
    }
}
59
60impl<MessageData: 'static> Screech<MessageData> {
61 pub fn new(buffer_size: usize, sample_rate: usize) -> Self {
63 Self::with_tracker(Box::new(DynamicTracker::new(buffer_size)), sample_rate)
64 }
65
66 pub fn with_tracker(mut tracker: Box<dyn Tracker<MessageData>>, sample_rate: usize) -> Self {
75 Screech {
76 id: tracker.create_source_id(),
77 outs: vec![],
78 sample_rate,
79 sorted_cache: None,
80 tracker,
81 }
82 }
83
84 pub fn invalidate_cache(&mut self) {
86 self.sorted_cache = None;
87 }
88
89 pub fn create_source_id(&mut self) -> usize {
91 self.tracker.create_source_id()
92 }
93
94 pub fn create_main_out(&mut self, signal_id: &'static str) {
96 let out = SignalId::new(self.id, signal_id);
97 self.tracker.init_output(&out);
98 self.tracker.init_input(&out);
99 self.outs.push(out);
100 }
101
102 pub fn get_main_out(&self, signal_id: &'static str) -> Option<&Signal> {
104 self.outs
105 .iter()
106 .find(|s| s.get_signal_id() == signal_id)
107 .and_then(|out| self.tracker.get_output(&out))
108 }
109
110 pub fn init_input(&mut self, source_id: &usize, signal_id: &'static str) -> Input {
112 let input = Input::new(*source_id, signal_id);
113 self.tracker.init_input(&input);
114 self.invalidate_cache();
115 input
116 }
117
118 pub fn init_output(&mut self, source_id: &usize, signal_id: &'static str) -> Output {
120 let output = Output::new(*source_id, signal_id);
121 self.tracker.init_output(&output);
122 self.invalidate_cache();
123 output
124 }
125
126 pub fn connect_signal(&mut self, output: &Output, input: &Input) {
128 self.tracker.connect_signal(output, input);
129 self.invalidate_cache();
130 }
131
132 pub fn disconnect_signal(&mut self, output: &Output, input: &Input) {
134 self.tracker.clear_connection(output, input);
135 self.invalidate_cache();
136 }
137
138 pub fn connect_signal_to_main_out(&mut self, output: &Output, signal_id: &'static str) {
140 if let Some(input) = self.outs.iter().find(|s| s.get_signal_id() == signal_id) {
141 self.tracker.connect_signal(output, input);
142 self.invalidate_cache();
143 }
144 }
145
146 pub fn disconnect_signal_from_main_out(&mut self, output: &Output, signal_id: &'static str) {
148 if let Some(input) = self.outs.iter().find(|s| s.get_signal_id() == signal_id) {
149 self.tracker.clear_connection(output, input);
150 self.invalidate_cache();
151 }
152 }
153
154 pub fn sample(
157 &mut self,
158 unmapped_sources: &mut [&mut dyn Source<MessageData>],
159 ) -> Result<(), ScreechError> {
160 if let None = self.sorted_cache {
162 let mut graph = FxHashMap::<usize, Vec<usize>>::default();
164
165 for source in unmapped_sources.iter() {
166 let id = source.get_source_id();
167 let sources = self.tracker.get_sources(id);
168 graph.insert(*id, sources);
169 }
170
171 let sorted = topological_sort(graph);
172 self.sorted_cache = Some(sorted);
173 }
174
175 let sample_rate = self.sample_rate;
176
177 for source in unmapped_sources.iter_mut() {
178 for key in self.sorted_cache.as_ref().unwrap().iter() {
179 if key == source.get_source_id() {
180 source.sample(self.tracker.as_mut(), sample_rate);
181 }
182 }
183 }
184
185 self.tracker.clear_messages();
187
188 for out in &self.outs {
190 let length = *self.tracker.get_buffer_size();
191 let mut samples = vec![0.0; length];
192
193 let inputs = self
194 .tracker
195 .get_input(&out)
196 .ok_or(ScreechError::MissingInput)?;
197
198 for input in inputs {
199 if let Some(input) = self.tracker.get_output(&input) {
200 for i in 0..length {
201 samples[i] += input.samples[i];
202 }
203 }
204 }
205
206 let output_signal = self
207 .tracker
208 .get_mut_output(&out)
209 .ok_or(ScreechError::MissingOutput)?;
210
211 for i in 0..length {
212 output_signal.samples[i] = samples[i];
213 }
214
215 }
217
218 Ok(())
219 }
220}