directed/
error.rs

1//! Errors and the graph trace system
2use crate::{AnyNode, Graph, Registry, registry::NodeReflection};
3use std::fmt::{self, Display, Formatter, Write};
4
5/// Wrapper error type, wraps errors from this crate and stores a graph information with them.
6#[derive(thiserror::Error, Debug)]
7pub struct ErrorWithTrace<T: std::error::Error> {
8    #[source]
9    pub error: T,
10    pub graph_trace: Option<GraphTrace>,
11}
12
13#[derive(thiserror::Error, Debug)]
14pub enum InjectionError {
15    #[error("Output '{0:?}' not found")]
16    OutputNotFound(Option<&'static str>),
17    #[error("Output '{0:?}' type mismatch")]
18    OutputTypeMismatch(Option<&'static str>),
19    #[error("Input '{0:?}' not found")]
20    InputNotFound(Option<&'static str>),
21    #[error("Input '{0:?}' type mismatch")]
22    InputTypeMismatch(Option<&'static str>),
23}
24
25#[derive(thiserror::Error, Debug)]
26pub enum NodeExecutionError {
27    #[error(transparent)]
28    NodesNotFoundInRegistry(#[from] NodesNotFoundError),
29    #[error(transparent)]
30    NodeNotFoundInGraph(#[from] NodeIndexNotFoundInGraphError),
31    #[error(transparent)]
32    EdgeNotFoundInGraph(#[from] EdgeNotFoundInGraphError),
33    #[error(transparent)]
34    InputInjection(#[from] InjectionError),
35    #[cfg(feature = "tokio")]
36    #[error(transparent)]
37    JoinError(#[from] tokio::task::JoinError),
38}
39
40#[derive(thiserror::Error, Debug)]
41pub enum RegistryError {
42    #[error(transparent)]
43    NodesNotFoundInRegistry(#[from] NodesNotFoundError),
44    #[error(transparent)]
45    NodeTypeMismatch(#[from] NodeTypeMismatchError),
46}
47
48#[derive(thiserror::Error, Debug)]
49pub enum EdgeCreationError {
50    #[error(transparent)]
51    NodesNotFound(#[from] NodesNotFoundInGraphError),
52    #[error(transparent)]
53    CycleError(daggy::WouldCycle<crate::EdgeInfo>),
54}
55
56#[derive(thiserror::Error, Debug)]
57#[error("Invalid node type: (id:{got:?}). Expected: (id:{expected:?})")]
58pub struct NodeTypeMismatchError {
59    pub got: std::any::TypeId,
60    pub expected: std::any::TypeId,
61}
62
63#[derive(thiserror::Error, Debug)]
64#[error("Nodes with id `{0:?}` not found")]
65pub struct NodesNotFoundError(Vec<NodeReflection>);
66
67impl From<&[NodeReflection]> for NodesNotFoundError {
68    fn from(value: &[NodeReflection]) -> Self {
69        Self(Vec::from(value))
70    }
71}
72
73#[derive(thiserror::Error, Debug)]
74#[error("Nodes `{0:?}` not found in graph")]
75pub struct NodesNotFoundInGraphError(Vec<NodeReflection>);
76
77impl From<&[NodeReflection]> for NodesNotFoundInGraphError {
78    fn from(value: &[NodeReflection]) -> Self {
79        Self(Vec::from(value))
80    }
81}
82
83#[derive(thiserror::Error, Debug)]
84#[error("Node with index `{0:?}` not found in graph")]
85pub struct NodeIndexNotFoundInGraphError(daggy::NodeIndex);
86
87impl From<daggy::NodeIndex> for NodeIndexNotFoundInGraphError {
88    fn from(value: daggy::NodeIndex) -> Self {
89        Self(value)
90    }
91}
92
93#[derive(thiserror::Error, Debug)]
94#[error("Edge with index `{0:?}` not found in graph")]
95pub struct EdgeNotFoundInGraphError(daggy::EdgeIndex);
96
97impl From<daggy::EdgeIndex> for EdgeNotFoundInGraphError {
98    fn from(value: daggy::EdgeIndex) -> Self {
99        Self(value)
100    }
101}
102
103impl<T: std::error::Error> Display for ErrorWithTrace<T> {
104    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105        writeln!(f, "{}", self.error)?;
106        if let Some(graph_trace) = &self.graph_trace {
107            writeln!(f, "{}", graph_trace.create_mermaid_graph())?;
108        }
109        Ok(())
110    }
111}
112
113impl<T: std::error::Error> From<T> for ErrorWithTrace<T> {
114    fn from(error: T) -> Self {
115        Self {
116            error,
117            graph_trace: None,
118        }
119    }
120}
121
122impl<T: std::error::Error> ErrorWithTrace<T> {
123    pub fn with_trace(self, trace: GraphTrace) -> Self {
124        Self {
125            error: self.error,
126            graph_trace: Some(trace),
127        }
128    }
129}
130
131/// A trace of a graph, containing information about nodes and connections.
132#[derive(Clone)]
133pub struct GraphTrace {
134    /// Information about each node in the graph.
135    pub nodes: Vec<NodeInfo>,
136    /// Information about each connection in the graph.
137    pub connections: Vec<ConnectionInfo>,
138}
139
140impl std::fmt::Debug for GraphTrace {
141    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
142        writeln!(f, "{}", self.create_mermaid_graph())
143    }
144}
145
146/// Information about a node in the graph.
147#[derive(Debug, Clone)]
148pub struct NodeInfo {
149    /// The unique ID of the node.
150    pub id: NodeReflection,
151    /// The name of the node.
152    pub name: &'static str,
153    /// The input fields of the node.
154    pub inputs: &'static [&'static str],
155    /// The output fields of the node.
156    pub outputs: &'static [&'static str],
157    /// Used for debugging purposes
158    pub highlighted: bool,
159}
160
161/// Information about a connection in the graph.
162#[derive(Debug, Clone, PartialEq, Eq)]
163pub struct ConnectionInfo {
164    /// The ID of the source node.
165    pub source_id: NodeReflection,
166    /// The output label of the source node.
167    pub source_output: Option<&'static str>,
168    /// The ID of the target node.
169    pub target_id: NodeReflection,
170    /// The input label of the target node.
171    pub target_input: Option<&'static str>,
172    /// Used for debugging purposes
173    pub highlighted: bool,
174}
175
176// Extension to Registry to allow access to nodes by ID
177impl Registry {
178    /// Gets a node by its ID.
179    pub fn get_node_by_id(&self, id: NodeReflection) -> Option<&Box<dyn AnyNode>> {
180        self.0.get(id.id).map(|node| node.as_ref()).flatten()
181    }
182}
183
184impl Graph {
185    /// Generates a trace of the graph.
186    pub fn generate_trace(&self, registry: &Registry) -> GraphTrace {
187        let mut nodes = Vec::new();
188        let mut connections = Vec::new();
189
190        // Add node information
191        for id in self.node_indices.iter().filter_map(|(id, _)| Some(*id)) {
192            if let Some(node) = registry.get_node_by_id(id) {
193                let stage_shape = node.stage_shape();
194                let node_info = NodeInfo {
195                    id,
196                    name: stage_shape.stage_name,
197                    inputs: stage_shape.inputs,
198                    outputs: stage_shape.outputs,
199                    highlighted: false,
200                };
201                nodes.push(node_info);
202            }
203        }
204
205        // Add connection information
206        for edge in self.dag.raw_edges() {
207            let source_idx = edge.source();
208            let target_idx = edge.target();
209
210            // Find the node IDs corresponding to the indices
211            let source_id = self
212                .node_indices
213                .iter()
214                .find(|(_, idx)| **idx == source_idx)
215                .map(|(id, _)| Some(*id))
216                .flatten();
217
218            let target_id = self
219                .node_indices
220                .iter()
221                .find(|(_, idx)| **idx == target_idx)
222                .map(|(id, _)| Some(*id))
223                .flatten();
224
225            if let (Some(source_id), Some(target_id)) = (source_id, target_id) {
226                let source_output = edge.weight.source_output;
227                let target_input = edge.weight.target_input;
228                let connection_info = ConnectionInfo {
229                    source_id,
230                    source_output,
231                    target_id,
232                    target_input,
233                    highlighted: false,
234                };
235                connections.push(connection_info);
236            }
237        }
238
239        GraphTrace { nodes, connections }
240    }
241}
242
243impl GraphTrace {
244    /// Emphasizes a node in the trace
245    pub fn highlight_node(&mut self, node: NodeReflection) {
246        if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node) {
247            node.highlighted = true;
248        }
249    }
250
251    /// Emphasizes a connection in the trace
252    pub fn highlight_connection(
253        &mut self,
254        source_node: NodeReflection,
255        source_output: Option<&'static str>,
256        target_node: NodeReflection,
257        target_input: Option<&'static str>,
258    ) {
259        if let Some(conn) = self.connections.iter_mut().find(|conn| {
260            conn.source_id == source_node
261                && conn.source_output == source_output
262                && conn.target_id == target_node
263                && conn.target_input == target_input
264        }) {
265            conn.highlighted = true;
266        }
267    }
268
269    /// Creates a mermaid graph representing the graph.
270    pub fn create_mermaid_graph(&self) -> String {
271        const EMPHASIS_STYLE: &str = "stroke:yellow,stroke-width:3;";
272        const SANITIZER: &str = " |-|.|:|/|\\";
273        let mut result = String::new();
274
275        // Note the unwraps in this function are fine. If they were to actually
276        // panic there are deeper problems going on.
277
278        // Start the Mermaid flowchart definition
279        writeln!(&mut result, "```mermaid").unwrap();
280        writeln!(&mut result, "flowchart TB").unwrap();
281
282        // Create subgraphs for each node with its inputs and outputs
283        for node in &self.nodes {
284            // Create a subgraph for the node
285            write!(&mut result, "    subgraph Node_{}_", node.id.id).unwrap();
286            write!(&mut result, "[\"Node {} ({})\"]", node.id.id, node.name).unwrap();
287            writeln!(&mut result, "").unwrap();
288
289            // Define a node for each input port
290            for input in node.inputs.iter() {
291                let field_name = input;
292                // TODO: Would really help to have type information here
293                writeln!(
294                    &mut result,
295                    "        {}_in_{}[/\"{}\"\\]",
296                    node.id.id,
297                    field_name.replace(SANITIZER, "_"),
298                    field_name
299                )
300                .unwrap();
301            }
302
303            // Define a node for each output port, unless this is a plain node.
304            for output in node.outputs.iter() {
305                let field_name = output;
306                write!(
307                    &mut result,
308                    "        {}_out_{}[\\\"",
309                    node.id.id,
310                    field_name.replace(SANITIZER, "_")
311                )
312                .unwrap();
313                write!(&mut result, "{}", field_name).unwrap();
314                // TODO: Would really help to have type information here
315                writeln!(&mut result, "\"/]").unwrap();
316            }
317
318            writeln!(&mut result, "    end").unwrap();
319            if node.highlighted {
320                writeln!(
321                    &mut result,
322                    "    style Node_{}_ {EMPHASIS_STYLE}",
323                    node.id.id
324                )
325                .unwrap();
326            }
327        }
328
329        // Create the connections between nodes
330        for (i, conn) in self.connections.iter().enumerate() {
331            let source_name = conn.source_output.unwrap_or("_");
332            let target_name = conn.target_input.unwrap_or("_");
333
334            write!(
335                &mut result,
336                "    {}_out_{} ",
337                conn.source_id.id,
338                source_name.replace(SANITIZER, "_")
339            )
340            .unwrap();
341            write!(&mut result, "--> ").unwrap();
342            writeln!(
343                &mut result,
344                "{}_in_{}",
345                conn.target_id.id,
346                target_name.replace(SANITIZER, "_")
347            )
348            .unwrap();
349
350            if conn.highlighted {
351                writeln!(&mut result, "    linkStyle {i} {EMPHASIS_STYLE}").unwrap();
352            }
353        }
354
355        // End the Mermaid diagram
356        writeln!(&mut result, "```").unwrap();
357
358        result
359    }
360}