// lemon_graph/execution/step.rs

1use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction};
2use thiserror::Error;
3
4use crate::{nodes::NodeError, Graph, GraphEdge, GraphNode};
5
6pub struct ExecutionStep(pub NodeIndex);
7
8#[derive(Debug, Error)]
9pub enum ExecutionStepError {
10    #[error("No weight")]
11    NoWeight,
12    #[error("Invalid weight")]
13    InvalidWeight,
14    #[error(transparent)]
15    NodeError(#[from] NodeError),
16}
17
18impl ExecutionStep {
19    pub async fn execute<'a>(
20        &self,
21        graph: &'a mut Graph,
22    ) -> Result<impl Iterator<Item = ExecutionStep> + 'a, ExecutionStepError> {
23        // Read inputs
24        let inputs = graph
25            .edges_directed(self.0, Direction::Incoming)
26            .filter_map(|edge| match edge.weight() {
27                GraphEdge::DataMap(data_idx) => Some((*data_idx, edge.source())),
28                _ => None,
29            })
30            .collect::<Vec<_>>();
31
32        let mut inputs = inputs
33            .into_iter()
34            .map(|(data_idx, source_idx)| -> Result<_, ExecutionStepError> {
35                // Update source from incoming DataFlow edges.
36                let mut new_value = None;
37
38                for edge in graph.edges_directed(source_idx, Direction::Incoming) {
39                    if let GraphEdge::DataFlow = edge.weight() {
40                        let source = edge.source();
41
42                        let source_weight = graph
43                            .node_weight(source)
44                            .ok_or(ExecutionStepError::NoWeight)?;
45
46                        let value = match source_weight {
47                            GraphNode::Store(value) => value,
48                            _ => return Err(ExecutionStepError::InvalidWeight),
49                        };
50
51                        new_value = Some(value.clone());
52                    }
53                }
54
55                if let Some(value) = new_value {
56                    graph[source_idx] = GraphNode::Store(value.clone());
57                    return Ok((data_idx, value));
58                }
59
60                let source_weight = graph
61                    .node_weight(source_idx)
62                    .ok_or(ExecutionStepError::NoWeight)?;
63
64                match source_weight {
65                    GraphNode::Store(value) => Ok((data_idx, value.clone())),
66                    _ => Err(ExecutionStepError::InvalidWeight),
67                }
68            })
69            .collect::<Result<Vec<_>, _>>()?;
70
71        inputs.sort_by_key(|(idx, _)| *idx);
72
73        let inputs = inputs.into_iter().map(|(_, value)| value).collect();
74
75        // Execute node
76        let node = graph
77            .node_weight(self.0)
78            .ok_or(ExecutionStepError::NoWeight)?;
79
80        let res = match node {
81            GraphNode::AsyncNode(node) => node.run(inputs).await?,
82            GraphNode::SyncNode(node) => node.run(inputs)?,
83            _ => return Err(ExecutionStepError::InvalidWeight),
84        };
85
86        // Write outputs
87        let outputs = graph
88            .edges_directed(self.0, Direction::Outgoing)
89            .filter_map(|edge| match edge.weight() {
90                GraphEdge::DataMap(data_idx) => Some((edge.target(), *data_idx)),
91                _ => None,
92            })
93            .collect::<Vec<_>>();
94
95        for (i, value) in res.into_iter().enumerate() {
96            let (store_idx, _) = match outputs.iter().find(|(_, idx)| *idx == i) {
97                Some(output) => output,
98                None => continue,
99            };
100
101            graph[*store_idx] = GraphNode::Store(value);
102        }
103
104        // Get next steps
105        Ok(graph
106            .edges_directed(self.0, Direction::Outgoing)
107            .filter_map(|edge| match edge.weight() {
108                GraphEdge::ExecutionFlow => Some(ExecutionStep(edge.target())),
109                _ => None,
110            }))
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        nodes::{AsyncNode, SyncNode},
        Value,
    };

    /// Synchronous node that returns its inputs unchanged as its outputs.
    struct TestSync;

    impl SyncNode for TestSync {
        fn run(&self, inputs: Vec<Value>) -> Result<Vec<Value>, NodeError> {
            Ok(inputs)
        }
    }

    /// Asynchronous node that returns its inputs unchanged as its outputs.
    struct TestAsync;

    impl AsyncNode for TestAsync {
        fn run(
            &self,
            inputs: Vec<Value>,
        ) -> Box<dyn std::future::Future<Output = Result<Vec<Value>, NodeError>> + Unpin> {
            Box::new(Box::pin(async move { Ok(inputs) }))
        }
    }

    /// Wires `input store -> node -> output store` with `DataMap(0)` edges,
    /// executes the node once, and checks that it scheduled no successors and
    /// copied the greeting into the output store.
    async fn assert_echoes_greeting(node: GraphNode) {
        let greeting = || Value::String("Hello, world!".to_string());

        let mut graph = Graph::new();
        let source = graph.add_node(GraphNode::Store(greeting()));
        let runner = graph.add_node(node);
        let sink = graph.add_node(GraphNode::Store(Value::String(Default::default())));
        graph.add_edge(source, runner, GraphEdge::DataMap(0));
        graph.add_edge(runner, sink, GraphEdge::DataMap(0));

        let step = ExecutionStep(runner);
        let successors: Vec<ExecutionStep> = step.execute(&mut graph).await.unwrap().collect();
        assert!(successors.is_empty());

        match graph.node_weight(sink) {
            Some(GraphNode::Store(stored)) => assert_eq!(stored, &greeting()),
            _ => panic!(),
        }
    }

    #[tokio::test]
    async fn test_sync_execution() {
        assert_echoes_greeting(GraphNode::SyncNode(Box::new(TestSync))).await;
    }

    #[tokio::test]
    async fn test_async_execution() {
        assert_echoes_greeting(GraphNode::AsyncNode(Box::new(TestAsync))).await;
    }
}