1extern crate alloc;
4
5use alloc::{
6 boxed::Box,
7 collections::{BTreeMap, BTreeSet, VecDeque},
8 vec,
9 vec::Vec,
10};
11use core::{cell::UnsafeCell, fmt, ops};
12use slotmap::SlotMap;
13
// Slot-map key type: identifies a node inside a graph's `NodeMap`.
slotmap::new_key_type! { struct Key; }
15
/// Failures reported by `Graph` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error<P> {
    /// No node is registered under the given id.
    MissingNode(u64),
    /// The node with the given id rejected the parameter, which is handed back.
    InvalidParameter(u64, P),
    /// The connections form a cycle (including a node connected to itself).
    CycleDetected,
}
22
23type NodeMap<C> = SlotMap<Key, Node<C>>;
24
/// Bundle of the types a graph and its processors are parameterized over.
pub trait Config: 'static {
    /// What each processor produces; children read it through `Inputs`.
    type Output: Send + 'static;
    /// Identifies an input slot of a processor.
    type Parameter: Send + 'static;
    /// A constant that can be bound to a parameter instead of a node output.
    type Value: Send + 'static;
    /// Data passed by shared reference to every `Processor::process` call.
    type Context: Send + Sync + 'static;
}
31
32pub trait Processor<C: Config>: 'static + Send {
33 fn set(
34 &mut self,
35 parameter: C::Parameter,
36 key: Input<C::Value>,
37 ) -> Result<Input<C::Value>, C::Parameter>;
38
39 fn remove(&mut self, key: NodeKey);
40
41 fn output(&self) -> &C::Output;
42
43 fn output_mut(&mut self) -> &mut C::Output;
44
45 fn process(&mut self, inputs: Inputs<C>, context: &C::Context);
46}
47
48#[derive(Debug)]
49pub struct Graph<C: Config> {
50 nodes: NodeMap<C>,
51 ids: BTreeMap<u64, Key>,
52 levels: Vec<BTreeSet<Key>>,
53 dirty: BTreeMap<Key, DirtyState>,
54 stack: VecDeque<Key>,
55}
56
57impl<C: Config> Default for Graph<C> {
58 #[inline]
59 fn default() -> Self {
60 Self {
61 nodes: Default::default(),
62 ids: Default::default(),
63 levels: vec![Default::default()],
64 dirty: Default::default(),
65 stack: Default::default(),
66 }
67 }
68}
69
/// Per-node progress marker used while `Graph::update` recomputes levels.
#[derive(Debug, Clone, Copy)]
enum DirtyState {
    /// The node's level has not been computed yet.
    Initial,
    /// The node is waiting for its dirty parents to be resolved.
    Pending,
    /// The node's level has been computed as the contained value.
    Done(u16),
}

impl Default for DirtyState {
    /// Freshly dirtied nodes start out as `Initial`.
    #[inline]
    fn default() -> Self {
        DirtyState::Initial
    }
}
83
84impl<C: Config> Graph<C> {
85 #[inline]
86 pub fn process(&mut self, context: &C::Context) {
87 debug_assert!(
88 self.dirty.is_empty(),
89 "need to call `update` before `process`"
90 );
91
92 for level in &self.levels {
93 let nodes = &self.nodes;
94
95 #[cfg(any(test, feature = "rayon"))]
96 {
97 use rayon::prelude::*;
98 level.par_iter().for_each(|key| {
99 nodes[*key].render(nodes, context);
100 });
101 }
102
103 #[cfg(not(any(test, feature = "rayon")))]
104 {
105 level.iter().for_each(|key| {
106 nodes[*key].render(nodes, context);
107 });
108 }
109 }
110 }
111
112 #[inline]
113 pub fn insert(&mut self, id: u64, processor: Box<dyn Processor<C>>) {
114 let node = Node::new(id, processor);
115 let key = self.nodes.insert(node);
116 self.ids.insert(id, key);
117 self.levels[0].insert(key);
118
119 self.ensure_consistency();
120 }
121
122 #[inline]
123 pub fn set(
124 &mut self,
125 target: u64,
126 param: C::Parameter,
127 value: C::Value,
128 ) -> Result<(), Error<C::Parameter>> {
129 let idx = *self.ids.get(&target).ok_or(Error::MissingNode(target))?;
130
131 let node = unsafe { self.nodes.get_unchecked_mut(idx) };
132
133 let prev = node
134 .set(param, Input::Value(value))
135 .map_err(|param| Error::InvalidParameter(target, param))?;
136
137 if let Input::Node(prev) = prev {
138 self.dirty.insert(idx, Default::default());
140 node.parents.remove(prev.0);
141
142 let prev = unsafe { self.nodes.get_unchecked_mut(prev.0) };
144 prev.children.remove(idx);
145 }
146
147 self.ensure_consistency();
148
149 Ok(())
150 }
151
152 #[inline]
153 pub fn connect(
154 &mut self,
155 target: u64,
156 param: C::Parameter,
157 source: u64,
158 ) -> Result<(), Error<C::Parameter>> {
159 if target == source {
160 return Err(Error::CycleDetected);
161 }
162
163 let idx = *self.ids.get(&target).ok_or(Error::MissingNode(target))?;
164
165 let source_key = *self.ids.get(&source).ok_or(Error::MissingNode(source))?;
166 let source = unsafe { self.nodes.get_unchecked_mut(source_key) };
167 source.children.insert(idx);
168 let source_level = source.level;
169
170 let node = unsafe { self.nodes.get_unchecked_mut(idx) };
171 let prev = node
172 .set(param, Input::Node(NodeKey(source_key)))
173 .map_err(|param| Error::InvalidParameter(target, param))?;
174 node.parents.insert(source_key);
175
176 if let Input::Node(prev) = prev {
177 node.parents.remove(prev.0);
178
179 let prev = unsafe { self.nodes.get_unchecked_mut(prev.0) };
180 prev.children.remove(idx);
181 let prev_level = prev.level;
182
183 if source_level != prev_level {
185 self.dirty.insert(idx, Default::default());
186 }
187 } else {
188 self.dirty.insert(idx, Default::default());
190 }
191
192 self.ensure_consistency();
193
194 Ok(())
195 }
196
197 #[inline]
198 pub fn remove(&mut self, id: u64) -> Result<Box<dyn Processor<C>>, Error<C::Parameter>> {
199 let key = self.ids.remove(&id).ok_or(Error::MissingNode(id))?;
200 let node = self.nodes.remove(key).unwrap();
201
202 self.levels[node.level as usize].remove(&key);
204 self.dirty.remove(&key);
205
206 for child_key in node.children.iter() {
208 let child = unsafe { self.nodes.get_unchecked_mut(child_key) };
209 child.clear_parent(key);
210
211 if child.level == node.level + 1 {
213 self.dirty.insert(child_key, Default::default());
214 }
215 }
216
217 for parent_key in node.parents.iter() {
219 let parent = unsafe { self.nodes.get_unchecked_mut(parent_key) };
220 parent.children.clear(key);
221 }
222
223 self.ensure_consistency();
224
225 Ok(node.processor.into_inner())
226 }
227
228 #[inline]
229 pub fn get(&self, id: u64) -> Result<&C::Output, Error<C::Parameter>> {
230 let key = self.ids.get(&id).ok_or(Error::MissingNode(id))?;
231 let node = unsafe { self.nodes.get_unchecked(*key) };
232 let output = node.output();
233 Ok(output)
234 }
235
236 #[inline]
237 pub fn get_mut(&mut self, id: u64) -> Result<&mut C::Output, Error<C::Parameter>> {
238 let key = self.ids.get(&id).ok_or(Error::MissingNode(id))?;
239 let node = unsafe { self.nodes.get_unchecked_mut(*key) };
240 let output = node.output_mut();
241 Ok(output)
242 }
243
244 #[inline]
245 pub fn update(&mut self) -> Result<(), Error<C::Parameter>> {
246 if self.dirty.is_empty() {
247 return Ok(());
248 }
249
250 self.stack.extend(self.dirty.keys().copied());
252
253 while let Some(key) = self.stack.pop_front() {
254 let node = unsafe { self.nodes.get_unchecked(key) };
255 let mut was_repushed = false;
256 let mut new_level = 0u16;
257
258 for parent in node.parents.iter() {
259 if let Some(parent_state) = self.dirty.get(&parent).copied() {
260 match parent_state {
261 DirtyState::Initial => {
262 if !core::mem::replace(&mut was_repushed, true) {
263 self.dirty.insert(key, DirtyState::Pending);
264 self.stack.push_front(key);
265 }
266
267 self.stack.push_front(parent);
268 }
269 DirtyState::Pending => {
270 return Err(Error::CycleDetected);
271 }
272 DirtyState::Done(parent_level) => {
273 new_level = new_level.max(parent_level + 1);
274 }
275 }
276 } else if !was_repushed {
277 let parent = unsafe { self.nodes.get_unchecked(parent) };
278 new_level = new_level.max(parent.level + 1);
279 }
280 }
281
282 if was_repushed {
283 continue;
284 }
285
286 if let Some(DirtyState::Done(prev_level)) =
287 self.dirty.insert(key, DirtyState::Done(new_level))
288 {
289 if prev_level != new_level {
290 return Err(Error::CycleDetected);
291 }
292
293 continue;
294 }
295
296 if node.level == new_level {
297 continue;
298 }
299
300 self.levels[node.level as usize].remove(&key);
301
302 for child in node.children.iter() {
304 self.stack.push_back(child);
305 }
306
307 let node = unsafe { self.nodes.get_unchecked_mut(key) };
308 node.level = new_level;
309
310 let new_level = new_level as usize;
311 if self.levels.len() <= new_level {
312 self.levels.resize_with(new_level + 1, Default::default);
313 }
314 self.levels[new_level].insert(key);
315 }
316
317 self.dirty.clear();
318
319 self.ensure_consistency();
320
321 Ok(())
322 }
323
324 #[inline(always)]
325 #[cfg(not(debug_assertions))]
326 fn ensure_consistency(&self) {}
327
328 #[inline]
329 #[cfg(debug_assertions)]
330 fn ensure_consistency(&self) {
331 for (id, key) in self.ids.iter() {
333 let node = self.nodes.get(*key).unwrap();
334 assert_eq!(*id, node.id);
335 }
336
337 for (key, node) in self.nodes.iter() {
339 let actual = *self.ids.get(&node.id).unwrap();
340 assert_eq!(actual, key);
341 }
342
343 for level in &self.levels {
345 for key in level {
346 assert!(self.nodes.contains_key(*key));
347 }
348 }
349
350 for key in self.nodes.keys() {
351 let node = &self.nodes[key];
352
353 for child_key in node.children.iter() {
354 let child = &self.nodes[child_key];
355 assert!(child.parents.0.contains_key(&key));
356 }
357
358 for parent_key in node.parents.iter() {
359 let parent = &self.nodes[parent_key];
360 assert!(parent.children.0.contains_key(&key));
361 }
362
363 assert!(self.levels[node.level as usize].contains(&key));
364
365 if self.dirty.contains_key(&key) {
367 continue;
368 }
369
370 let mut expected = 0;
371
372 for parent in node.parents.iter() {
373 let parent = self.nodes[parent].level;
374 expected = expected.max(parent + 1);
375 }
376
377 assert_eq!(node.level, expected, "level mismatch");
378 }
379 }
380}
381
382#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
383pub struct NodeKey(Key);
384
385#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
386pub enum Input<Value> {
387 Value(Value),
388 Node(NodeKey),
389}
390
391pub struct Inputs<'a, C: Config> {
392 nodes: &'a NodeMap<C>,
393 #[cfg(debug_assertions)]
394 parents: &'a Relationship,
395}
396
397impl<'a, C: Config> ops::Index<NodeKey> for Inputs<'a, C> {
398 type Output = C::Output;
399
400 #[inline]
401 fn index(&self, key: NodeKey) -> &Self::Output {
402 debug_assert!(self.nodes.contains_key(key.0));
403
404 #[cfg(debug_assertions)]
405 {
406 assert!(
407 self.parents.0.contains_key(&key.0),
408 "node should only access its configured parents"
409 );
410 }
411
412 unsafe { self.nodes.get_unchecked(key.0).output() }
413 }
414}
415
416struct Node<C: Config> {
417 #[cfg(debug_assertions)]
418 id: u64,
419 processor: UnsafeCell<Box<dyn Processor<C>>>,
420 level: u16,
421 parents: Relationship,
422 children: Relationship,
423}
424
425impl<C: Config> fmt::Debug for Node<C> {
426 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427 let mut s = f.debug_struct("Node");
428
429 #[cfg(debug_assertions)]
430 s.field("id", &self.id);
431
432 s.field("level", &self.level)
433 .field("parents", &self.parents)
434 .field("children", &self.children)
435 .finish()
436 }
437}
438
439unsafe impl<C: Config> Sync for Node<C> {}
441
442impl<C: Config> Node<C> {
443 #[inline]
444 fn new(id: u64, processor: Box<dyn Processor<C>>) -> Self {
445 let _ = id;
446 Self {
447 #[cfg(debug_assertions)]
448 id,
449 processor: UnsafeCell::new(processor),
450 level: 0,
451 parents: Default::default(),
452 children: Default::default(),
453 }
454 }
455
456 #[inline]
457 fn set(
458 &mut self,
459 param: C::Parameter,
460 value: Input<C::Value>,
461 ) -> Result<Input<C::Value>, C::Parameter> {
462 let processor = unsafe { &mut *self.processor.get() };
463 processor.set(param, value)
464 }
465
466 #[inline]
467 fn clear_parent(&mut self, key: Key) {
468 let processor = unsafe { &mut *self.processor.get() };
469 processor.remove(NodeKey(key));
470 self.parents.clear(key);
471 }
472
473 #[inline]
474 fn render(&self, nodes: &NodeMap<C>, context: &C::Context) {
475 let inputs = Inputs {
476 nodes,
477 #[cfg(debug_assertions)]
478 parents: &self.parents,
479 };
480
481 let processor = unsafe { &mut *self.processor.get() };
482
483 processor.process(inputs, context);
484 }
485
486 #[inline]
487 fn output(&self) -> &C::Output {
488 let processor = unsafe { &*self.processor.get() };
489 processor.output()
490 }
491
492 #[inline]
493 fn output_mut(&mut self) -> &mut C::Output {
494 let processor = unsafe { &mut *self.processor.get() };
495 processor.output_mut()
496 }
497}
498
499#[derive(Clone, Debug, Default)]
500struct Relationship(BTreeMap<Key, u16>);
501
502impl Relationship {
503 #[inline]
504 pub fn insert(&mut self, key: Key) {
505 *self.0.entry(key).or_default() += 1;
506 }
507
508 #[inline]
509 pub fn remove(&mut self, key: Key) {
510 let new_count = self.0.remove(&key).unwrap_or(1) - 1;
511 if new_count != 0 {
512 self.0.insert(key, new_count);
513 }
514 }
515
516 #[inline]
517 pub fn clear(&mut self, key: Key) {
518 self.0.remove(&key);
519 }
520
521 #[inline]
522 pub fn iter(&self) -> impl Iterator<Item = Key> + '_ {
523 self.0.keys().copied()
524 }
525}
526
527#[cfg(test)]
528mod tests;