screech/
lib.rs

1//! A collection of helpers for handling audio data in real time
2
3#![no_std]
4#![warn(missing_docs)]
5
6extern crate alloc;
7
8mod graph;
9mod message;
10mod signal;
11mod signal_id;
12mod tracker;
13
14/// common traits used throughout the library
15pub mod traits;
16
17/// Identifier used for keeping track of signals
18pub type Input = signal_id::SignalId;
19
20/// Identifier used for keeping track of signals
21pub type Output = signal_id::SignalId;
22
23pub use message::Message;
24pub use signal::Signal;
25pub use tracker::{BasicTracker, DynamicTracker};
26
27use alloc::boxed::Box;
28use alloc::vec;
29use alloc::vec::Vec;
30use graph::topological_sort;
31use rustc_hash::FxHashMap;
32use signal_id::SignalId;
33use traits::{Source, Tracker};
34
/// Main helper struct to render and manage relations between [`crate::traits::Source`] types.
///
/// Owns a boxed [`Tracker`](crate::traits::Tracker) through which all signal buffers,
/// connections and messages are accessed, the identifiers of the main outputs created with
/// [`Screech::create_main_out`], and a cached sampling order that [`Screech::sample`]
/// rebuilds whenever it has been invalidated.
pub struct Screech<MessageData = ()> {
    /// sample rate field used for sampling
    pub sample_rate: usize,
    // identifiers of the main outputs registered through `create_main_out`
    outs: Vec<SignalId>,
    // source id reserved from the tracker at construction; main outputs are registered under it
    id: usize,
    // cached order of source ids to sample in; `None` means `sample` must rebuild it
    sorted_cache: Option<Vec<usize>>,
    // backing store queried for outputs/inputs/connections and cleared of messages each sample
    tracker: Box<dyn Tracker<MessageData>>,
}

// SAFETY: NOTE(review): this asserts `Send` for every `T` with no bounds, and the boxed
// `dyn Tracker<MessageData>` carries no `Send` bound either, so the compiler cannot check it.
// Moving a `Screech` to another thread is only sound if the concrete tracker (and any
// `MessageData` it may hold) is actually safe to send — TODO confirm; consider
// `unsafe impl<T: Send> Send` together with a `Send` bound on the boxed tracker.
unsafe impl<T> Send for Screech<T> {}
46
/// Error type for failure to execute [`Screech::sample`]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScreechError {
    /// The dependency graph contains a cyclic dependency.
    ///
    /// For example: track A -> track B -> track A
    CyclicDependencies,
    /// An output buffer is missing although it is assigned to an input,
    /// e.g. a main output without an output buffer on the tracker.
    MissingOutput,
    /// No input was found for a main output created with [`Screech::create_main_out`].
    MissingInput,
}

impl core::fmt::Display for ScreechError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            ScreechError::CyclicDependencies => "dependency graph contains a cycle",
            ScreechError::MissingOutput => "missing output buffer",
            ScreechError::MissingInput => "missing input for main output",
        };
        f.write_str(message)
    }
}
59
60impl<MessageData: 'static> Screech<MessageData> {
61    /// Create new Screech instance with a default tracker
62    pub fn new(buffer_size: usize, sample_rate: usize) -> Self {
63        Self::with_tracker(Box::new(DynamicTracker::new(buffer_size)), sample_rate)
64    }
65
66    /// Create a new Screech instance with a supplied tracker
67    ///
68    /// ```
69    /// use screech::{Screech, BasicTracker};
70    ///
71    /// let tracker = BasicTracker::<256>::new(8);
72    /// let screech = Screech::with_tracker(Box::new(tracker), 48_000);
73    /// ```
74    pub fn with_tracker(mut tracker: Box<dyn Tracker<MessageData>>, sample_rate: usize) -> Self {
75        Screech {
76            id: tracker.create_source_id(),
77            outs: vec![],
78            sample_rate,
79            sorted_cache: None,
80            tracker,
81        }
82    }
83
84    /// invalidate connections cache
85    pub fn invalidate_cache(&mut self) {
86        self.sorted_cache = None;
87    }
88
89    /// create new unique identifier
90    pub fn create_source_id(&mut self) -> usize {
91        self.tracker.create_source_id()
92    }
93
94    /// create new main output based on `&'static str` identifier
95    pub fn create_main_out(&mut self, signal_id: &'static str) {
96        let out = SignalId::new(self.id, signal_id);
97        self.tracker.init_output(&out);
98        self.tracker.init_input(&out);
99        self.outs.push(out);
100    }
101
102    /// return output [`Signal`] based on `&'static str` identifier
103    pub fn get_main_out(&self, signal_id: &'static str) -> Option<&Signal> {
104        self.outs
105            .iter()
106            .find(|s| s.get_signal_id() == signal_id)
107            .and_then(|out| self.tracker.get_output(&out))
108    }
109
110    /// create and initialize a new input
111    pub fn init_input(&mut self, source_id: &usize, signal_id: &'static str) -> Input {
112        let input = Input::new(*source_id, signal_id);
113        self.tracker.init_input(&input);
114        self.invalidate_cache();
115        input
116    }
117
118    /// create and initialize a new output
119    pub fn init_output(&mut self, source_id: &usize, signal_id: &'static str) -> Output {
120        let output = Output::new(*source_id, signal_id);
121        self.tracker.init_output(&output);
122        self.invalidate_cache();
123        output
124    }
125
126    /// connect an [`Output`] to an [`Input`]
127    pub fn connect_signal(&mut self, output: &Output, input: &Input) {
128        self.tracker.connect_signal(output, input);
129        self.invalidate_cache();
130    }
131
132    /// disconnect an [`Output`] to an [`Input`]
133    pub fn disconnect_signal(&mut self, output: &Output, input: &Input) {
134        self.tracker.clear_connection(output, input);
135        self.invalidate_cache();
136    }
137
138    /// connect an [`Output`] to a main output buffer
139    pub fn connect_signal_to_main_out(&mut self, output: &Output, signal_id: &'static str) {
140        if let Some(input) = self.outs.iter().find(|s| s.get_signal_id() == signal_id) {
141            self.tracker.connect_signal(output, input);
142            self.invalidate_cache();
143        }
144    }
145
146    /// disconnect an [`Output`] from a main output buffer
147    pub fn disconnect_signal_from_main_out(&mut self, output: &Output, signal_id: &'static str) {
148        if let Some(input) = self.outs.iter().find(|s| s.get_signal_id() == signal_id) {
149            self.tracker.clear_connection(output, input);
150            self.invalidate_cache();
151        }
152    }
153
154    /// Sample multiple sources based on their dependencies into [`Signal`]s stored in a
155    /// [`traits::Tracker`]
156    pub fn sample(
157        &mut self,
158        unmapped_sources: &mut [&mut dyn Source<MessageData>],
159    ) -> Result<(), ScreechError> {
160        // update dependency graph if needed
161        if let None = self.sorted_cache {
162            // @TODO: move this allocation outside of the `sample()` method
163            let mut graph = FxHashMap::<usize, Vec<usize>>::default();
164
165            for source in unmapped_sources.iter() {
166                let id = source.get_source_id();
167                let sources = self.tracker.get_sources(id);
168                graph.insert(*id, sources);
169            }
170
171            let sorted = topological_sort(graph);
172            self.sorted_cache = Some(sorted);
173        }
174
175        let sample_rate = self.sample_rate;
176
177        for source in unmapped_sources.iter_mut() {
178            for key in self.sorted_cache.as_ref().unwrap().iter() {
179                if key == source.get_source_id() {
180                    source.sample(self.tracker.as_mut(), sample_rate);
181                }
182            }
183        }
184
185        // clear message queue
186        self.tracker.clear_messages();
187
188        // generate output
189        for out in &self.outs {
190            let length = *self.tracker.get_buffer_size();
191            let mut samples = vec![0.0; length];
192
193            let inputs = self
194                .tracker
195                .get_input(&out)
196                .ok_or(ScreechError::MissingInput)?;
197
198            for input in inputs {
199                if let Some(input) = self.tracker.get_output(&input) {
200                    for i in 0..length {
201                        samples[i] += input.samples[i];
202                    }
203                }
204            }
205
206            let output_signal = self
207                .tracker
208                .get_mut_output(&out)
209                .ok_or(ScreechError::MissingOutput)?;
210
211            for i in 0..length {
212                output_signal.samples[i] = samples[i];
213            }
214
215            // samples.clear();
216        }
217
218        Ok(())
219    }
220}