// directed/stage.rs

1use crate::DynFields;
2use crate::{
3    InjectionError,
4    node::{AnyNode, Node},
5};
6use std::collections::HashMap;
7
/// How a value is handed to a consuming stage.
///
/// NOTE(review): the variant meanings below are inferred from their names;
/// confirm against the code that consumes `RefType`.
///
/// Derives `Eq` alongside `Hash` so the type can be used as a `HashMap` /
/// `HashSet` key, matching the other reflection types in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RefType {
    /// Presumably: the value is passed by ownership.
    Owned,
    /// Presumably: the value is passed as a shared reference.
    Borrowed,
    /// Presumably: the value is passed as a mutable reference.
    BorrowedMut,
}
14
/// Reflection data describing a stage's graph I/O: the stage's name together
/// with the names of its input and output slots.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct StageShape {
    /// Name of the stage (for example `"()"` for the empty stage).
    pub stage_name: &'static str,
    /// Names of the stage's inputs, in order.
    pub inputs: &'static [&'static str],
    /// Names of the stage's outputs, in order.
    pub outputs: &'static [&'static str],
}
22
23/// Defines all the information about how a stage is handled.
24#[cfg_attr(feature = "tokio", async_trait::async_trait)]
25pub trait Stage: Clone + 'static {
26    /// Used for reflection
27    const SHAPE: StageShape;
28    /// Internal state only, no special rules apply to this. This is stored
29    /// as a tuple of all state parameters in order.
30    /// TODO: Should be possible to relax Send+Sync bounds in sync contexts
31    type State: Send + Sync;
32    /// The input of this stage
33    /// TODO: Should be possible to relax Send+Sync bounds in sync contexts
34    type Input: Send + Sync + Default + DynFields;
35    /// The output of this stage
36    /// TODO: Should be possible to relax Send+Sync bounds in sync contexts
37    type Output: Send + Sync + Default + DynFields;
38
39    /// Evaluate the stage with the given input and state
40    fn evaluate(
41        &self,
42        state: &mut Self::State,
43        inputs: &mut Self::Input,
44        cache: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
45    ) -> Result<Self::Output, InjectionError>;
46    /// async version of evaluate
47    #[cfg(feature = "tokio")]
48    async fn evaluate_async(
49        &self,
50        state: &mut Self::State,
51        inputs: &mut Self::Input,
52        cache: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
53    ) -> Result<Self::Output, InjectionError>;
54
55    fn reeval_rule(&self) -> ReevaluationRule {
56        ReevaluationRule::Move
57    }
58
59    /// Stage-level connection processing logic. See [Node::flow_data] for more
60    /// information.  
61    fn inject_input(
62        &self,
63        node: &mut Node<Self>,
64        parent: &mut Box<dyn AnyNode>,
65        output: Option<&'static str>,
66        input: Option<&'static str>,
67    ) -> Result<(), InjectionError>;
68}
69
/// Controls whether a stage is re-evaluated or answered from its cache.
///
/// This is the value returned by [`Stage::reeval_rule`], whose default
/// implementation returns [`ReevaluationRule::Move`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReevaluationRule {
    /// Always move outputs and re-evaluate every time. If the receiving node
    /// takes a reference, the output is passed in and then dropped after that
    /// node evaluates.
    Move,
    /// If all inputs equal the previous inputs, don't evaluate; just return a
    /// clone of the cached output.
    CacheLast,
    /// If all inputs equal ANY previous input combination, don't evaluate;
    /// just return a clone of the cached output associated with that exact
    /// set of inputs.
    CacheAll,
}
85
86/// Empty version of a stage. Does nothing
87#[cfg_attr(feature = "tokio", async_trait::async_trait)]
88impl Stage for () {
89    const SHAPE: StageShape = StageShape {
90        stage_name: "()",
91        inputs: &[],
92        outputs: &[],
93    };
94    type State = ();
95    type Input = ();
96    type Output = ();
97    fn evaluate(
98        &self,
99        _: &mut Self::State,
100        _: &mut Self::Input,
101        _: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
102    ) -> Result<Self::Output, InjectionError> {
103        Ok(())
104    }
105    #[cfg(feature = "tokio")]
106    async fn evaluate_async(
107        &self,
108        _: &mut Self::State,
109        _: &mut Self::Input,
110        _: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
111    ) -> Result<Self::Output, InjectionError> {
112        Ok(())
113    }
114    fn inject_input(
115        &self,
116        _: &mut Node<Self>,
117        _: &mut Box<dyn AnyNode>,
118        _: Option<&'static str>,
119        _: Option<&'static str>,
120    ) -> Result<(), InjectionError> {
121        Ok(())
122    }
123}
124
/// A stage that just returns a value.
///
/// The stage itself stores nothing but a type marker. The value is supplied
/// through the stage's state (a [`ValueWrapper`]), and each evaluation returns
/// a clone of it. Currently this is somewhat naive and will always clone the
/// value.
pub struct ValueStage<T: Send + Sync + Clone + 'static>(std::marker::PhantomData<T>);

impl<T: Send + Sync + Clone + 'static> ValueStage<T> {
    /// Creates a new value stage.
    #[must_use]
    pub fn new() -> Self {
        Self(std::marker::PhantomData)
    }
}

// `Clone`, `Copy`, `Default` and `Debug` are written by hand instead of
// derived: `#[derive]` would add `T: Copy` / `T: Default` / `T: Debug` bounds
// even though only a `PhantomData<T>` is stored. For example, a derived `Copy`
// would make `ValueStage<String>` non-`Copy`.
impl<T: Send + Sync + Clone + 'static> Clone for ValueStage<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: Send + Sync + Clone + 'static> Copy for ValueStage<T> {}

impl<T: Send + Sync + Clone + 'static> Default for ValueStage<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Send + Sync + Clone + 'static> std::fmt::Debug for ValueStage<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ValueStage<{}>", std::any::type_name::<T>())
    }
}
133
/// Optional-value container used by [`ValueStage`] as both its state and its
/// output.
///
/// `Default` is implemented by hand so an empty (`None`) wrapper can be built
/// without requiring `T: Default`.
#[derive(Clone)]
pub struct ValueWrapper<T: Send + Sync + Clone + 'static>(pub Option<T>);

impl<T> Default for ValueWrapper<T>
where
    T: Send + Sync + Clone + 'static,
{
    fn default() -> Self {
        ValueWrapper(Option::None)
    }
}
142
143impl<T: Send + Sync + Clone + 'static> DynFields for ValueWrapper<T> {
144    fn field<'a>(&'a self, _: Option<&'static str>) -> Option<&'a (dyn std::any::Any + 'static)> {
145        self.0.as_ref().map(|t| t as &dyn std::any::Any)
146    }
147
148    fn field_mut<'a>(
149        &'a mut self,
150        _: Option<&'static str>,
151    ) -> Option<&'a mut (dyn std::any::Any + 'static)> {
152        self.0.as_mut().map(|t| t as &mut dyn std::any::Any)
153    }
154
155    fn take_field(&mut self, _: Option<&'static str>) -> Option<Box<dyn std::any::Any>> {
156        self.0.take().map(|t| Box::new(t) as Box<dyn std::any::Any>)
157    }
158
159    fn replace(&mut self, other: Box<dyn std::any::Any>) -> Box<dyn DynFields> {
160        if let Ok(other) = other.downcast() {
161            Box::new(std::mem::replace(self, *other))
162        } else {
163            panic!("Attempted to replace value with wrong type")
164        }
165    }
166
167    fn clear(&mut self) {
168        self.0 = None;
169    }
170}
171
172#[cfg_attr(feature = "tokio", async_trait::async_trait)]
173impl<T: Send + Sync + Clone + 'static> Stage for ValueStage<T> {
174    const SHAPE: StageShape = StageShape {
175        stage_name: "_",
176        inputs: &[],
177        outputs: &["_"],
178    };
179    type State = ValueWrapper<T>;
180    type Input = ();
181    type Output = ValueWrapper<T>;
182    fn evaluate(
183        &self,
184        state: &mut Self::State,
185        _: &mut Self::Input,
186        _: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
187    ) -> Result<Self::Output, InjectionError> {
188        Ok(state.clone())
189    }
190    #[cfg(feature = "tokio")]
191    async fn evaluate_async(
192        &self,
193        state: &mut Self::State,
194        _: &mut Self::Input,
195        _: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
196    ) -> Result<Self::Output, InjectionError> {
197        Ok(state.clone())
198    }
199    fn inject_input(
200        &self,
201        _: &mut Node<Self>,
202        _: &mut Box<dyn AnyNode>,
203        _: Option<&'static str>,
204        _: Option<&'static str>,
205    ) -> Result<(), InjectionError> {
206        Ok(())
207    }
208}