1use crate::{AnyNode, Graph, Registry, registry::NodeReflection};
3use std::fmt::{self, Display, Formatter, Write};
4
/// An error `T` paired with an optional [`GraphTrace`] snapshot of the graph.
///
/// `#[source]` makes `source()` return the wrapped `error`. `Display` is not derived here
/// (there is no `#[error(...)]` attribute); it is implemented by hand. Create one through
/// `From<T>` (no trace) and attach a trace with `with_trace`.
#[derive(thiserror::Error, Debug)]
pub struct ErrorWithTrace<T: std::error::Error> {
    /// The underlying error; also reported as this error's `source()`.
    #[source]
    pub error: T,
    /// Graph snapshot attached to the error, if any.
    pub graph_trace: Option<GraphTrace>,
}
12
/// Errors about a node's named input or output slot: the slot is missing, or its type does not
/// match.
///
/// The payload is the slot name. `None` presumably denotes an unnamed slot (TODO confirm).
/// NOTE(review): the messages format the payload with `{:?}`, so a named slot renders as e.g.
/// `Output 'Some("x")' not found`; consider formatting the inner name instead.
#[derive(thiserror::Error, Debug)]
pub enum InjectionError {
    /// No output slot with this name.
    #[error("Output '{0:?}' not found")]
    OutputNotFound(Option<&'static str>),
    /// The output slot exists but its type differs from the requested one.
    #[error("Output '{0:?}' type mismatch")]
    OutputTypeMismatch(Option<&'static str>),
    /// No input slot with this name.
    #[error("Input '{0:?}' not found")]
    InputNotFound(Option<&'static str>),
    /// The input slot exists but its type differs from the requested one.
    #[error("Input '{0:?}' type mismatch")]
    InputTypeMismatch(Option<&'static str>),
}
24
/// Union of the error kinds reported while executing nodes.
///
/// Every variant is `#[error(transparent)]`: its `Display` and `source()` forward to the
/// wrapped error, and `#[from]` lets `?` convert each wrapped type automatically.
#[derive(thiserror::Error, Debug)]
pub enum NodeExecutionError {
    /// One or more node ids are missing from the registry.
    #[error(transparent)]
    NodesNotFoundInRegistry(#[from] NodesNotFoundError),
    /// A node index has no corresponding node in the graph.
    #[error(transparent)]
    NodeNotFoundInGraph(#[from] NodeIndexNotFoundInGraphError),
    /// An edge index has no corresponding edge in the graph.
    #[error(transparent)]
    EdgeNotFoundInGraph(#[from] EdgeNotFoundInGraphError),
    /// Resolving a node input/output failed.
    #[error(transparent)]
    InputInjection(#[from] InjectionError),
    /// A spawned tokio task failed to join; only present with the `tokio` feature.
    #[cfg(feature = "tokio")]
    #[error(transparent)]
    JoinError(#[from] tokio::task::JoinError),
}
39
/// Errors from registry lookups: missing node ids, or a node whose concrete type differs from
/// the expected one. Both variants forward `Display`/`source()` to the wrapped error.
#[derive(thiserror::Error, Debug)]
pub enum RegistryError {
    /// One or more node ids are missing from the registry.
    #[error(transparent)]
    NodesNotFoundInRegistry(#[from] NodesNotFoundError),
    /// The registered node's `TypeId` differs from the expected one.
    #[error(transparent)]
    NodeTypeMismatch(#[from] NodeTypeMismatchError),
}
47
/// Errors from creating an edge between two nodes.
#[derive(thiserror::Error, Debug)]
pub enum EdgeCreationError {
    /// One or both endpoint nodes are not present in the graph.
    #[error(transparent)]
    NodesNotFound(#[from] NodesNotFoundInGraphError),
    /// Adding the edge would introduce a cycle, as reported by `daggy`.
    /// This variant has no `#[from]`, so callers must wrap `WouldCycle` explicitly.
    #[error(transparent)]
    CycleError(daggy::WouldCycle<crate::EdgeInfo>),
}
55
/// A node's concrete type did not match the type the caller expected.
///
/// Both fields are `TypeId`s, so the message shows opaque ids rather than type names.
#[derive(thiserror::Error, Debug)]
#[error("Invalid node type: (id:{got:?}). Expected: (id:{expected:?})")]
pub struct NodeTypeMismatchError {
    /// `TypeId` of the node actually found.
    pub got: std::any::TypeId,
    /// `TypeId` the caller asked for.
    pub expected: std::any::TypeId,
}
62
63#[derive(thiserror::Error, Debug)]
64#[error("Nodes with id `{0:?}` not found")]
65pub struct NodesNotFoundError(Vec<NodeReflection>);
66
67impl From<&[NodeReflection]> for NodesNotFoundError {
68 fn from(value: &[NodeReflection]) -> Self {
69 Self(Vec::from(value))
70 }
71}
72
73#[derive(thiserror::Error, Debug)]
74#[error("Nodes `{0:?}` not found in graph")]
75pub struct NodesNotFoundInGraphError(Vec<NodeReflection>);
76
77impl From<&[NodeReflection]> for NodesNotFoundInGraphError {
78 fn from(value: &[NodeReflection]) -> Self {
79 Self(Vec::from(value))
80 }
81}
82
83#[derive(thiserror::Error, Debug)]
84#[error("Node with index `{0:?}` not found in graph")]
85pub struct NodeIndexNotFoundInGraphError(daggy::NodeIndex);
86
87impl From<daggy::NodeIndex> for NodeIndexNotFoundInGraphError {
88 fn from(value: daggy::NodeIndex) -> Self {
89 Self(value)
90 }
91}
92
93#[derive(thiserror::Error, Debug)]
94#[error("Edge with index `{0:?}` not found in graph")]
95pub struct EdgeNotFoundInGraphError(daggy::EdgeIndex);
96
97impl From<daggy::EdgeIndex> for EdgeNotFoundInGraphError {
98 fn from(value: daggy::EdgeIndex) -> Self {
99 Self(value)
100 }
101}
102
103impl<T: std::error::Error> Display for ErrorWithTrace<T> {
104 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105 writeln!(f, "{}", self.error)?;
106 if let Some(graph_trace) = &self.graph_trace {
107 writeln!(f, "{}", graph_trace.create_mermaid_graph())?;
108 }
109 Ok(())
110 }
111}
112
113impl<T: std::error::Error> From<T> for ErrorWithTrace<T> {
114 fn from(error: T) -> Self {
115 Self {
116 error,
117 graph_trace: None,
118 }
119 }
120}
121
122impl<T: std::error::Error> ErrorWithTrace<T> {
123 pub fn with_trace(self, trace: GraphTrace) -> Self {
124 Self {
125 error: self.error,
126 graph_trace: Some(trace),
127 }
128 }
129}
130
/// Snapshot of a graph's nodes and connections, rendered as a mermaid diagram for error reports.
///
/// `Debug` is implemented by hand (it prints the mermaid diagram), which is why it is not
/// derived here.
#[derive(Clone)]
pub struct GraphTrace {
    /// One entry per node included in the snapshot.
    pub nodes: Vec<NodeInfo>,
    /// One entry per edge between two included nodes.
    pub connections: Vec<ConnectionInfo>,
}
139
140impl std::fmt::Debug for GraphTrace {
141 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
142 writeln!(f, "{}", self.create_mermaid_graph())
143 }
144}
145
/// One node of a [`GraphTrace`], taken from the node's stage shape.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Registry id of the node.
    pub id: NodeReflection,
    /// Stage name, used in the diagram's subgraph label.
    pub name: &'static str,
    /// Names of the node's input slots.
    pub inputs: &'static [&'static str],
    /// Names of the node's output slots.
    pub outputs: &'static [&'static str],
    /// Whether the node is emphasized in the rendered diagram.
    pub highlighted: bool,
}
160
/// One directed connection of a [`GraphTrace`], from an output slot to an input slot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Id of the node the connection starts at.
    pub source_id: NodeReflection,
    /// Output slot on the source node; `None` is rendered as `_` in the diagram.
    pub source_output: Option<&'static str>,
    /// Id of the node the connection ends at.
    pub target_id: NodeReflection,
    /// Input slot on the target node; `None` is rendered as `_` in the diagram.
    pub target_input: Option<&'static str>,
    /// Whether the connection is emphasized in the rendered diagram.
    pub highlighted: bool,
}
175
176impl Registry {
178 pub fn get_node_by_id(&self, id: NodeReflection) -> Option<&Box<dyn AnyNode>> {
180 self.0.get(id.id).map(|node| node.as_ref()).flatten()
181 }
182}
183
184impl Graph {
185 pub fn generate_trace(&self, registry: &Registry) -> GraphTrace {
187 let mut nodes = Vec::new();
188 let mut connections = Vec::new();
189
190 for id in self.node_indices.iter().filter_map(|(id, _)| Some(*id)) {
192 if let Some(node) = registry.get_node_by_id(id) {
193 let stage_shape = node.stage_shape();
194 let node_info = NodeInfo {
195 id,
196 name: stage_shape.stage_name,
197 inputs: stage_shape.inputs,
198 outputs: stage_shape.outputs,
199 highlighted: false,
200 };
201 nodes.push(node_info);
202 }
203 }
204
205 for edge in self.dag.raw_edges() {
207 let source_idx = edge.source();
208 let target_idx = edge.target();
209
210 let source_id = self
212 .node_indices
213 .iter()
214 .find(|(_, idx)| **idx == source_idx)
215 .map(|(id, _)| Some(*id))
216 .flatten();
217
218 let target_id = self
219 .node_indices
220 .iter()
221 .find(|(_, idx)| **idx == target_idx)
222 .map(|(id, _)| Some(*id))
223 .flatten();
224
225 if let (Some(source_id), Some(target_id)) = (source_id, target_id) {
226 let source_output = edge.weight.source_output;
227 let target_input = edge.weight.target_input;
228 let connection_info = ConnectionInfo {
229 source_id,
230 source_output,
231 target_id,
232 target_input,
233 highlighted: false,
234 };
235 connections.push(connection_info);
236 }
237 }
238
239 GraphTrace { nodes, connections }
240 }
241}
242
243impl GraphTrace {
244 pub fn highlight_node(&mut self, node: NodeReflection) {
246 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node) {
247 node.highlighted = true;
248 }
249 }
250
251 pub fn highlight_connection(
253 &mut self,
254 source_node: NodeReflection,
255 source_output: Option<&'static str>,
256 target_node: NodeReflection,
257 target_input: Option<&'static str>,
258 ) {
259 if let Some(conn) = self.connections.iter_mut().find(|conn| {
260 conn.source_id == source_node
261 && conn.source_output == source_output
262 && conn.target_id == target_node
263 && conn.target_input == target_input
264 }) {
265 conn.highlighted = true;
266 }
267 }
268
269 pub fn create_mermaid_graph(&self) -> String {
271 const EMPHASIS_STYLE: &str = "stroke:yellow,stroke-width:3;";
272 const SANITIZER: &str = " |-|.|:|/|\\";
273 let mut result = String::new();
274
275 writeln!(&mut result, "```mermaid").unwrap();
280 writeln!(&mut result, "flowchart TB").unwrap();
281
282 for node in &self.nodes {
284 write!(&mut result, " subgraph Node_{}_", node.id.id).unwrap();
286 write!(&mut result, "[\"Node {} ({})\"]", node.id.id, node.name).unwrap();
287 writeln!(&mut result, "").unwrap();
288
289 for input in node.inputs.iter() {
291 let field_name = input;
292 writeln!(
294 &mut result,
295 " {}_in_{}[/\"{}\"\\]",
296 node.id.id,
297 field_name.replace(SANITIZER, "_"),
298 field_name
299 )
300 .unwrap();
301 }
302
303 for output in node.outputs.iter() {
305 let field_name = output;
306 write!(
307 &mut result,
308 " {}_out_{}[\\\"",
309 node.id.id,
310 field_name.replace(SANITIZER, "_")
311 )
312 .unwrap();
313 write!(&mut result, "{}", field_name).unwrap();
314 writeln!(&mut result, "\"/]").unwrap();
316 }
317
318 writeln!(&mut result, " end").unwrap();
319 if node.highlighted {
320 writeln!(
321 &mut result,
322 " style Node_{}_ {EMPHASIS_STYLE}",
323 node.id.id
324 )
325 .unwrap();
326 }
327 }
328
329 for (i, conn) in self.connections.iter().enumerate() {
331 let source_name = conn.source_output.unwrap_or("_");
332 let target_name = conn.target_input.unwrap_or("_");
333
334 write!(
335 &mut result,
336 " {}_out_{} ",
337 conn.source_id.id,
338 source_name.replace(SANITIZER, "_")
339 )
340 .unwrap();
341 write!(&mut result, "--> ").unwrap();
342 writeln!(
343 &mut result,
344 "{}_in_{}",
345 conn.target_id.id,
346 target_name.replace(SANITIZER, "_")
347 )
348 .unwrap();
349
350 if conn.highlighted {
351 writeln!(&mut result, " linkStyle {i} {EMPHASIS_STYLE}").unwrap();
352 }
353 }
354
355 writeln!(&mut result, "```").unwrap();
357
358 result
359 }
360}