lemon_graph/execution/
step.rs1use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction};
2use thiserror::Error;
3
4use crate::{nodes::NodeError, Graph, GraphEdge, GraphNode};
5
6pub struct ExecutionStep(pub NodeIndex);
7
8#[derive(Debug, Error)]
9pub enum ExecutionStepError {
10 #[error("No weight")]
11 NoWeight,
12 #[error("Invalid weight")]
13 InvalidWeight,
14 #[error(transparent)]
15 NodeError(#[from] NodeError),
16}
17
18impl ExecutionStep {
19 pub async fn execute<'a>(
20 &self,
21 graph: &'a mut Graph,
22 ) -> Result<impl Iterator<Item = ExecutionStep> + 'a, ExecutionStepError> {
23 let inputs = graph
25 .edges_directed(self.0, Direction::Incoming)
26 .filter_map(|edge| match edge.weight() {
27 GraphEdge::DataMap(data_idx) => Some((*data_idx, edge.source())),
28 _ => None,
29 })
30 .collect::<Vec<_>>();
31
32 let mut inputs = inputs
33 .into_iter()
34 .map(|(data_idx, source_idx)| -> Result<_, ExecutionStepError> {
35 let mut new_value = None;
37
38 for edge in graph.edges_directed(source_idx, Direction::Incoming) {
39 if let GraphEdge::DataFlow = edge.weight() {
40 let source = edge.source();
41
42 let source_weight = graph
43 .node_weight(source)
44 .ok_or(ExecutionStepError::NoWeight)?;
45
46 let value = match source_weight {
47 GraphNode::Store(value) => value,
48 _ => return Err(ExecutionStepError::InvalidWeight),
49 };
50
51 new_value = Some(value.clone());
52 }
53 }
54
55 if let Some(value) = new_value {
56 graph[source_idx] = GraphNode::Store(value.clone());
57 return Ok((data_idx, value));
58 }
59
60 let source_weight = graph
61 .node_weight(source_idx)
62 .ok_or(ExecutionStepError::NoWeight)?;
63
64 match source_weight {
65 GraphNode::Store(value) => Ok((data_idx, value.clone())),
66 _ => Err(ExecutionStepError::InvalidWeight),
67 }
68 })
69 .collect::<Result<Vec<_>, _>>()?;
70
71 inputs.sort_by_key(|(idx, _)| *idx);
72
73 let inputs = inputs.into_iter().map(|(_, value)| value).collect();
74
75 let node = graph
77 .node_weight(self.0)
78 .ok_or(ExecutionStepError::NoWeight)?;
79
80 let res = match node {
81 GraphNode::AsyncNode(node) => node.run(inputs).await?,
82 GraphNode::SyncNode(node) => node.run(inputs)?,
83 _ => return Err(ExecutionStepError::InvalidWeight),
84 };
85
86 let outputs = graph
88 .edges_directed(self.0, Direction::Outgoing)
89 .filter_map(|edge| match edge.weight() {
90 GraphEdge::DataMap(data_idx) => Some((edge.target(), *data_idx)),
91 _ => None,
92 })
93 .collect::<Vec<_>>();
94
95 for (i, value) in res.into_iter().enumerate() {
96 let (store_idx, _) = match outputs.iter().find(|(_, idx)| *idx == i) {
97 Some(output) => output,
98 None => continue,
99 };
100
101 graph[*store_idx] = GraphNode::Store(value);
102 }
103
104 Ok(graph
106 .edges_directed(self.0, Direction::Outgoing)
107 .filter_map(|edge| match edge.weight() {
108 GraphEdge::ExecutionFlow => Some(ExecutionStep(edge.target())),
109 _ => None,
110 }))
111 }
112}
113
#[cfg(test)]
mod tests {
    use crate::{
        nodes::{AsyncNode, SyncNode},
        Value,
    };

    use super::*;

    /// Synchronous node that returns its inputs unchanged.
    struct TestSync;

    impl SyncNode for TestSync {
        fn run(&self, inputs: Vec<Value>) -> Result<Vec<Value>, NodeError> {
            Ok(inputs)
        }
    }

    /// Asynchronous node that returns its inputs unchanged.
    struct TestAsync;

    impl AsyncNode for TestAsync {
        fn run(
            &self,
            inputs: Vec<Value>,
        ) -> Box<dyn std::future::Future<Output = Result<Vec<Value>, NodeError>> + Unpin> {
            Box::new(Box::pin(async move { Ok(inputs) }))
        }
    }

    /// Builds `input store -> node -> output store` (both links are
    /// `DataMap(0)`), executes the node once, and checks that the input value
    /// ends up in the output store with no follow-up steps.
    async fn assert_echoes_input(weight: GraphNode) {
        let greeting = Value::String("Hello, world!".to_string());

        let mut graph = Graph::new();
        let input = graph.add_node(GraphNode::Store(greeting.clone()));
        let node = graph.add_node(weight);
        let output = graph.add_node(GraphNode::Store(Value::String(Default::default())));
        graph.add_edge(input, node, GraphEdge::DataMap(0));
        graph.add_edge(node, output, GraphEdge::DataMap(0));

        let step = ExecutionStep(node);
        let next_steps = step.execute(&mut graph).await.unwrap().collect::<Vec<_>>();
        assert!(next_steps.is_empty());

        match graph.node_weight(output).unwrap() {
            GraphNode::Store(value) => assert_eq!(value, &greeting),
            _ => panic!(),
        }
    }

    #[tokio::test]
    async fn test_sync_execution() {
        assert_echoes_input(GraphNode::SyncNode(Box::new(TestSync))).await;
    }

    #[tokio::test]
    async fn test_async_execution() {
        assert_echoes_input(GraphNode::AsyncNode(Box::new(TestAsync))).await;
    }
}