euphony_graph/
lib.rs

1// #![no_std]
2
3extern crate alloc;
4
5use alloc::{
6    boxed::Box,
7    collections::{BTreeMap, BTreeSet, VecDeque},
8    vec,
9    vec::Vec,
10};
11use core::{cell::UnsafeCell, fmt, ops};
12use slotmap::SlotMap;
13
14slotmap::new_key_type! { struct Key; }
15
/// Errors reported by `Graph` operations.
///
/// `Parameter` is the configuration's parameter type; a rejected parameter is handed back to
/// the caller inside [`Error::InvalidParameter`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error<Parameter> {
    /// No node is registered under the given id.
    MissingNode(u64),
    /// The processor of the node with the given id rejected the parameter.
    InvalidParameter(u64, Parameter),
    /// The requested connections would form (or already form) a cycle.
    CycleDetected,
}
22
23type NodeMap<C> = SlotMap<Key, Node<C>>;
24
/// Bundles the associated types that parameterise a `Graph`.
///
/// All of them must be `'static + Send`. The context is also `Sync` because `Graph::process`
/// hands the same `&Context` to every node, possibly from several threads.
pub trait Config: 'static {
    /// Value produced by each processor and read by its children.
    type Output: 'static + Send;
    /// Names one input of a processor (used by `Graph::set` / `Graph::connect`).
    type Parameter: 'static + Send;
    /// Constant that can be bound to a parameter with `Graph::set`.
    type Value: 'static + Send;
    /// State shared by reference with every processor during `Graph::process`.
    type Context: 'static + Send + Sync;
}
31
32pub trait Processor<C: Config>: 'static + Send {
33    fn set(
34        &mut self,
35        parameter: C::Parameter,
36        key: Input<C::Value>,
37    ) -> Result<Input<C::Value>, C::Parameter>;
38
39    fn remove(&mut self, key: NodeKey);
40
41    fn output(&self) -> &C::Output;
42
43    fn output_mut(&mut self) -> &mut C::Output;
44
45    fn process(&mut self, inputs: Inputs<C>, context: &C::Context);
46}
47
48#[derive(Debug)]
49pub struct Graph<C: Config> {
50    nodes: NodeMap<C>,
51    ids: BTreeMap<u64, Key>,
52    levels: Vec<BTreeSet<Key>>,
53    dirty: BTreeMap<Key, DirtyState>,
54    stack: VecDeque<Key>,
55}
56
57impl<C: Config> Default for Graph<C> {
58    #[inline]
59    fn default() -> Self {
60        Self {
61            nodes: Default::default(),
62            ids: Default::default(),
63            levels: vec![Default::default()],
64            dirty: Default::default(),
65            stack: Default::default(),
66        }
67    }
68}
69
/// Progress of a node while `Graph::update` recomputes levels.
#[derive(Clone, Copy, Debug)]
enum DirtyState {
    /// Marked for recomputation, not visited yet.
    Initial,
    /// Visited, but waiting for dirty parents to be resolved first; meeting such a node again
    /// as a parent means the connections form a cycle.
    Pending,
    /// The node's level has been computed in the current pass.
    Done(u16),
}

impl Default for DirtyState {
    /// Newly dirtied nodes start out unvisited.
    #[inline]
    fn default() -> Self {
        DirtyState::Initial
    }
}
83
84impl<C: Config> Graph<C> {
85    #[inline]
86    pub fn process(&mut self, context: &C::Context) {
87        debug_assert!(
88            self.dirty.is_empty(),
89            "need to call `update` before `process`"
90        );
91
92        for level in &self.levels {
93            let nodes = &self.nodes;
94
95            #[cfg(any(test, feature = "rayon"))]
96            {
97                use rayon::prelude::*;
98                level.par_iter().for_each(|key| {
99                    nodes[*key].render(nodes, context);
100                });
101            }
102
103            #[cfg(not(any(test, feature = "rayon")))]
104            {
105                level.iter().for_each(|key| {
106                    nodes[*key].render(nodes, context);
107                });
108            }
109        }
110    }
111
112    #[inline]
113    pub fn insert(&mut self, id: u64, processor: Box<dyn Processor<C>>) {
114        let node = Node::new(id, processor);
115        let key = self.nodes.insert(node);
116        self.ids.insert(id, key);
117        self.levels[0].insert(key);
118
119        self.ensure_consistency();
120    }
121
122    #[inline]
123    pub fn set(
124        &mut self,
125        target: u64,
126        param: C::Parameter,
127        value: C::Value,
128    ) -> Result<(), Error<C::Parameter>> {
129        let idx = *self.ids.get(&target).ok_or(Error::MissingNode(target))?;
130
131        let node = unsafe { self.nodes.get_unchecked_mut(idx) };
132
133        let prev = node
134            .set(param, Input::Value(value))
135            .map_err(|param| Error::InvalidParameter(target, param))?;
136
137        if let Input::Node(prev) = prev {
138            // if we went from a node input to a constant, we need to recalc
139            self.dirty.insert(idx, Default::default());
140            node.parents.remove(prev.0);
141
142            // tell the parent we are no longer a child
143            let prev = unsafe { self.nodes.get_unchecked_mut(prev.0) };
144            prev.children.remove(idx);
145        }
146
147        self.ensure_consistency();
148
149        Ok(())
150    }
151
152    #[inline]
153    pub fn connect(
154        &mut self,
155        target: u64,
156        param: C::Parameter,
157        source: u64,
158    ) -> Result<(), Error<C::Parameter>> {
159        if target == source {
160            return Err(Error::CycleDetected);
161        }
162
163        let idx = *self.ids.get(&target).ok_or(Error::MissingNode(target))?;
164
165        let source_key = *self.ids.get(&source).ok_or(Error::MissingNode(source))?;
166        let source = unsafe { self.nodes.get_unchecked_mut(source_key) };
167        source.children.insert(idx);
168        let source_level = source.level;
169
170        let node = unsafe { self.nodes.get_unchecked_mut(idx) };
171        let prev = node
172            .set(param, Input::Node(NodeKey(source_key)))
173            .map_err(|param| Error::InvalidParameter(target, param))?;
174        node.parents.insert(source_key);
175
176        if let Input::Node(prev) = prev {
177            node.parents.remove(prev.0);
178
179            let prev = unsafe { self.nodes.get_unchecked_mut(prev.0) };
180            prev.children.remove(idx);
181            let prev_level = prev.level;
182
183            // the node is only dirty if the levels have changed
184            if source_level != prev_level {
185                self.dirty.insert(idx, Default::default());
186            }
187        } else {
188            // going from a constant to a node will require recalc
189            self.dirty.insert(idx, Default::default());
190        }
191
192        self.ensure_consistency();
193
194        Ok(())
195    }
196
197    #[inline]
198    pub fn remove(&mut self, id: u64) -> Result<Box<dyn Processor<C>>, Error<C::Parameter>> {
199        let key = self.ids.remove(&id).ok_or(Error::MissingNode(id))?;
200        let node = self.nodes.remove(key).unwrap();
201
202        // the node is no longer part of the levels
203        self.levels[node.level as usize].remove(&key);
204        self.dirty.remove(&key);
205
206        // notify our children that we're finished
207        for child_key in node.children.iter() {
208            let child = unsafe { self.nodes.get_unchecked_mut(child_key) };
209            child.clear_parent(key);
210
211            // if the child's level matches the node's, it needs to be recalculated
212            if child.level == node.level + 1 {
213                self.dirty.insert(child_key, Default::default());
214            }
215        }
216
217        // notify our parents that we're finished
218        for parent_key in node.parents.iter() {
219            let parent = unsafe { self.nodes.get_unchecked_mut(parent_key) };
220            parent.children.clear(key);
221        }
222
223        self.ensure_consistency();
224
225        Ok(node.processor.into_inner())
226    }
227
228    #[inline]
229    pub fn get(&self, id: u64) -> Result<&C::Output, Error<C::Parameter>> {
230        let key = self.ids.get(&id).ok_or(Error::MissingNode(id))?;
231        let node = unsafe { self.nodes.get_unchecked(*key) };
232        let output = node.output();
233        Ok(output)
234    }
235
236    #[inline]
237    pub fn get_mut(&mut self, id: u64) -> Result<&mut C::Output, Error<C::Parameter>> {
238        let key = self.ids.get(&id).ok_or(Error::MissingNode(id))?;
239        let node = unsafe { self.nodes.get_unchecked_mut(*key) };
240        let output = node.output_mut();
241        Ok(output)
242    }
243
244    #[inline]
245    pub fn update(&mut self) -> Result<(), Error<C::Parameter>> {
246        if self.dirty.is_empty() {
247            return Ok(());
248        }
249
250        // queue up all of the updates
251        self.stack.extend(self.dirty.keys().copied());
252
253        while let Some(key) = self.stack.pop_front() {
254            let node = unsafe { self.nodes.get_unchecked(key) };
255            let mut was_repushed = false;
256            let mut new_level = 0u16;
257
258            for parent in node.parents.iter() {
259                if let Some(parent_state) = self.dirty.get(&parent).copied() {
260                    match parent_state {
261                        DirtyState::Initial => {
262                            if !core::mem::replace(&mut was_repushed, true) {
263                                self.dirty.insert(key, DirtyState::Pending);
264                                self.stack.push_front(key);
265                            }
266
267                            self.stack.push_front(parent);
268                        }
269                        DirtyState::Pending => {
270                            return Err(Error::CycleDetected);
271                        }
272                        DirtyState::Done(parent_level) => {
273                            new_level = new_level.max(parent_level + 1);
274                        }
275                    }
276                } else if !was_repushed {
277                    let parent = unsafe { self.nodes.get_unchecked(parent) };
278                    new_level = new_level.max(parent.level + 1);
279                }
280            }
281
282            if was_repushed {
283                continue;
284            }
285
286            if let Some(DirtyState::Done(prev_level)) =
287                self.dirty.insert(key, DirtyState::Done(new_level))
288            {
289                if prev_level != new_level {
290                    return Err(Error::CycleDetected);
291                }
292
293                continue;
294            }
295
296            if node.level == new_level {
297                continue;
298            }
299
300            self.levels[node.level as usize].remove(&key);
301
302            // the children need to be updated now
303            for child in node.children.iter() {
304                self.stack.push_back(child);
305            }
306
307            let node = unsafe { self.nodes.get_unchecked_mut(key) };
308            node.level = new_level;
309
310            let new_level = new_level as usize;
311            if self.levels.len() <= new_level {
312                self.levels.resize_with(new_level + 1, Default::default);
313            }
314            self.levels[new_level].insert(key);
315        }
316
317        self.dirty.clear();
318
319        self.ensure_consistency();
320
321        Ok(())
322    }
323
324    #[inline(always)]
325    #[cfg(not(debug_assertions))]
326    fn ensure_consistency(&self) {}
327
328    #[inline]
329    #[cfg(debug_assertions)]
330    fn ensure_consistency(&self) {
331        // ensure the ids aren't referencing a freed node
332        for (id, key) in self.ids.iter() {
333            let node = self.nodes.get(*key).unwrap();
334            assert_eq!(*id, node.id);
335        }
336
337        // ensure the nodes match the expected id
338        for (key, node) in self.nodes.iter() {
339            let actual = *self.ids.get(&node.id).unwrap();
340            assert_eq!(actual, key);
341        }
342
343        // ensure the levels don't have freed nodes
344        for level in &self.levels {
345            for key in level {
346                assert!(self.nodes.contains_key(*key));
347            }
348        }
349
350        for key in self.nodes.keys() {
351            let node = &self.nodes[key];
352
353            for child_key in node.children.iter() {
354                let child = &self.nodes[child_key];
355                assert!(child.parents.0.contains_key(&key));
356            }
357
358            for parent_key in node.parents.iter() {
359                let parent = &self.nodes[parent_key];
360                assert!(parent.children.0.contains_key(&key));
361            }
362
363            assert!(self.levels[node.level as usize].contains(&key));
364
365            // the following checks require the node to be clean
366            if self.dirty.contains_key(&key) {
367                continue;
368            }
369
370            let mut expected = 0;
371
372            for parent in node.parents.iter() {
373                let parent = self.nodes[parent].level;
374                expected = expected.max(parent + 1);
375            }
376
377            assert_eq!(node.level, expected, "level mismatch");
378        }
379    }
380}
381
382#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
383pub struct NodeKey(Key);
384
385#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
386pub enum Input<Value> {
387    Value(Value),
388    Node(NodeKey),
389}
390
391pub struct Inputs<'a, C: Config> {
392    nodes: &'a NodeMap<C>,
393    #[cfg(debug_assertions)]
394    parents: &'a Relationship,
395}
396
397impl<'a, C: Config> ops::Index<NodeKey> for Inputs<'a, C> {
398    type Output = C::Output;
399
400    #[inline]
401    fn index(&self, key: NodeKey) -> &Self::Output {
402        debug_assert!(self.nodes.contains_key(key.0));
403
404        #[cfg(debug_assertions)]
405        {
406            assert!(
407                self.parents.0.contains_key(&key.0),
408                "node should only access its configured parents"
409            );
410        }
411
412        unsafe { self.nodes.get_unchecked(key.0).output() }
413    }
414}
415
416struct Node<C: Config> {
417    #[cfg(debug_assertions)]
418    id: u64,
419    processor: UnsafeCell<Box<dyn Processor<C>>>,
420    level: u16,
421    parents: Relationship,
422    children: Relationship,
423}
424
425impl<C: Config> fmt::Debug for Node<C> {
426    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427        let mut s = f.debug_struct("Node");
428
429        #[cfg(debug_assertions)]
430        s.field("id", &self.id);
431
432        s.field("level", &self.level)
433            .field("parents", &self.parents)
434            .field("children", &self.children)
435            .finish()
436    }
437}
438
439/// Safety: Mutual exclusion is ensured by level organization
440unsafe impl<C: Config> Sync for Node<C> {}
441
442impl<C: Config> Node<C> {
443    #[inline]
444    fn new(id: u64, processor: Box<dyn Processor<C>>) -> Self {
445        let _ = id;
446        Self {
447            #[cfg(debug_assertions)]
448            id,
449            processor: UnsafeCell::new(processor),
450            level: 0,
451            parents: Default::default(),
452            children: Default::default(),
453        }
454    }
455
456    #[inline]
457    fn set(
458        &mut self,
459        param: C::Parameter,
460        value: Input<C::Value>,
461    ) -> Result<Input<C::Value>, C::Parameter> {
462        let processor = unsafe { &mut *self.processor.get() };
463        processor.set(param, value)
464    }
465
466    #[inline]
467    fn clear_parent(&mut self, key: Key) {
468        let processor = unsafe { &mut *self.processor.get() };
469        processor.remove(NodeKey(key));
470        self.parents.clear(key);
471    }
472
473    #[inline]
474    fn render(&self, nodes: &NodeMap<C>, context: &C::Context) {
475        let inputs = Inputs {
476            nodes,
477            #[cfg(debug_assertions)]
478            parents: &self.parents,
479        };
480
481        let processor = unsafe { &mut *self.processor.get() };
482
483        processor.process(inputs, context);
484    }
485
486    #[inline]
487    fn output(&self) -> &C::Output {
488        let processor = unsafe { &*self.processor.get() };
489        processor.output()
490    }
491
492    #[inline]
493    fn output_mut(&mut self) -> &mut C::Output {
494        let processor = unsafe { &mut *self.processor.get() };
495        processor.output_mut()
496    }
497}
498
499#[derive(Clone, Debug, Default)]
500struct Relationship(BTreeMap<Key, u16>);
501
502impl Relationship {
503    #[inline]
504    pub fn insert(&mut self, key: Key) {
505        *self.0.entry(key).or_default() += 1;
506    }
507
508    #[inline]
509    pub fn remove(&mut self, key: Key) {
510        let new_count = self.0.remove(&key).unwrap_or(1) - 1;
511        if new_count != 0 {
512            self.0.insert(key, new_count);
513        }
514    }
515
516    #[inline]
517    pub fn clear(&mut self, key: Key) {
518        self.0.remove(&key);
519    }
520
521    #[inline]
522    pub fn iter(&self) -> impl Iterator<Item = Key> + '_ {
523        self.0.keys().copied()
524    }
525}
526
527#[cfg(test)]
528mod tests;