oxify_model/
visualization.rs

1//! Workflow Visualization Export
2//!
3//! This module provides functionality to export workflows to various
4//! diagram formats for visualization and documentation purposes.
5//!
6//! Supported formats:
7//! - **Mermaid**: Popular markdown-based diagramming (flowchart syntax)
8//! - **Graphviz DOT**: Industry standard graph visualization language
9//! - **PlantUML**: UML and diagram generation tool
10//!
11//! # Example
12//!
13//! ```rust
14//! use oxify_model::{Workflow, WorkflowBuilder, LlmConfig, visualization::WorkflowVisualizer};
15//!
16//! let llm_config = LlmConfig {
17//!     provider: "openai".to_string(),
18//!     model: "gpt-4".to_string(),
19//!     system_prompt: None,
20//!     prompt_template: "{{input}}".to_string(),
21//!     temperature: Some(0.7),
22//!     max_tokens: Some(100),
23//!     tools: vec![],
24//!     images: vec![],
25//!     extra_params: serde_json::json!({}),
26//! };
27//!
28//! let workflow = WorkflowBuilder::new("example")
29//!     .description("Example workflow")
30//!     .start("Start")
31//!     .llm("Generate text", llm_config)
32//!     .end("End")
33//!     .build();
34//!
35//! let visualizer = WorkflowVisualizer::new(&workflow);
36//! let mermaid = visualizer.to_mermaid();
37//! println!("{}", mermaid);
38//! ```
39
40use crate::{Edge, Node, NodeKind, Workflow};
41use serde::{Deserialize, Serialize};
42use std::collections::{HashMap, HashSet};
43
44/// Visualization format options
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46pub enum VisualizationFormat {
47    /// Mermaid flowchart format
48    Mermaid,
49    /// Graphviz DOT format
50    Graphviz,
51    /// PlantUML activity diagram format
52    PlantUML,
53}
54
55/// Visual styling options for workflow diagrams
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct VisualizationStyle {
58    /// Show node IDs in addition to names
59    pub show_node_ids: bool,
60
61    /// Show edge labels (condition expressions)
62    pub show_edge_labels: bool,
63
64    /// Use colors to differentiate node types
65    pub use_colors: bool,
66
67    /// Include node descriptions as tooltips/notes
68    pub include_descriptions: bool,
69
70    /// Diagram orientation (TB, LR, BT, RL)
71    pub orientation: DiagramOrientation,
72
73    /// Group nodes by type
74    pub group_by_type: bool,
75}
76
77impl Default for VisualizationStyle {
78    fn default() -> Self {
79        Self {
80            show_node_ids: false,
81            show_edge_labels: true,
82            use_colors: true,
83            include_descriptions: false,
84            orientation: DiagramOrientation::TopBottom,
85            group_by_type: false,
86        }
87    }
88}
89
90/// Diagram layout orientation
91#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
92pub enum DiagramOrientation {
93    /// Top to bottom
94    TopBottom,
95    /// Left to right
96    LeftRight,
97    /// Bottom to top
98    BottomTop,
99    /// Right to left
100    RightLeft,
101}
102
103impl DiagramOrientation {
104    /// Convert to Mermaid orientation code
105    fn to_mermaid(self) -> &'static str {
106        match self {
107            DiagramOrientation::TopBottom => "TB",
108            DiagramOrientation::LeftRight => "LR",
109            DiagramOrientation::BottomTop => "BT",
110            DiagramOrientation::RightLeft => "RL",
111        }
112    }
113
114    /// Convert to Graphviz rankdir
115    fn to_graphviz(self) -> &'static str {
116        match self {
117            DiagramOrientation::TopBottom => "TB",
118            DiagramOrientation::LeftRight => "LR",
119            DiagramOrientation::BottomTop => "BT",
120            DiagramOrientation::RightLeft => "RL",
121        }
122    }
123}
124
125/// Workflow visualizer for generating diagrams
126pub struct WorkflowVisualizer<'a> {
127    workflow: &'a Workflow,
128    style: VisualizationStyle,
129}
130
131impl<'a> WorkflowVisualizer<'a> {
132    /// Create a new visualizer for a workflow
133    pub fn new(workflow: &'a Workflow) -> Self {
134        Self {
135            workflow,
136            style: VisualizationStyle::default(),
137        }
138    }
139
140    /// Create a visualizer with custom styling
141    pub fn with_style(workflow: &'a Workflow, style: VisualizationStyle) -> Self {
142        Self { workflow, style }
143    }
144
145    /// Export to Mermaid flowchart format
146    pub fn to_mermaid(&self) -> String {
147        let mut output = String::new();
148
149        // Header
150        output.push_str(&format!(
151            "flowchart {}\n",
152            self.style.orientation.to_mermaid()
153        ));
154
155        // Add title if present
156        if let Some(desc) = &self.workflow.metadata.description {
157            output.push_str("    %%{ init: {'theme':'base', 'themeVariables': { 'primaryColor':'#ff9900'}}}%%\n");
158            output.push_str(&format!("    %% {}\n", desc));
159        }
160
161        // Add nodes
162        for node in &self.workflow.nodes {
163            let node_def = self.mermaid_node_definition(node);
164            output.push_str(&format!("    {}\n", node_def));
165        }
166
167        output.push('\n');
168
169        // Add edges
170        for edge in &self.workflow.edges {
171            let edge_def = self.mermaid_edge_definition(edge);
172            output.push_str(&format!("    {}\n", edge_def));
173        }
174
175        // Add styling if enabled
176        if self.style.use_colors {
177            output.push('\n');
178            output.push_str(&self.mermaid_styling());
179        }
180
181        output
182    }
183
184    /// Generate Mermaid node definition
185    fn mermaid_node_definition(&self, node: &Node) -> String {
186        let node_id = self.sanitize_id(&node.id.to_string());
187        let label = self.node_label(node);
188
189        // Choose shape based on node type
190        let (open, close) = match node.kind {
191            NodeKind::Start => ("[", "]"),
192            NodeKind::End => ("[", "]"),
193            NodeKind::IfElse(_) => ("{", "}"),
194            NodeKind::Switch(_) => ("{", "}"),
195            NodeKind::Parallel(_) => ("[[", "]]"),
196            NodeKind::Loop(_) => ("{{", "}}"),
197            _ => ("(", ")"),
198        };
199
200        format!("{}{}\"{}\"{}", node_id, open, label, close)
201    }
202
203    /// Generate Mermaid edge definition
204    fn mermaid_edge_definition(&self, edge: &Edge) -> String {
205        let from_id = self.sanitize_id(&edge.from.to_string());
206        let to_id = self.sanitize_id(&edge.to.to_string());
207
208        if self.style.show_edge_labels {
209            if let Some(label) = &edge.label {
210                return format!("{} -->|\"{}\"| {}", from_id, label, to_id);
211            }
212        }
213
214        format!("{} --> {}", from_id, to_id)
215    }
216
217    /// Generate Mermaid styling classes
218    fn mermaid_styling(&self) -> String {
219        let mut styling = String::new();
220
221        // Define style classes for different node types
222        styling.push_str("    classDef startEnd fill:#90EE90,stroke:#228B22,stroke-width:2px\n");
223        styling.push_str("    classDef llm fill:#87CEEB,stroke:#4682B4,stroke-width:2px\n");
224        styling.push_str("    classDef code fill:#FFB6C1,stroke:#C71585,stroke-width:2px\n");
225        styling.push_str("    classDef decision fill:#FFD700,stroke:#FF8C00,stroke-width:2px\n");
226        styling.push_str("    classDef loop fill:#DDA0DD,stroke:#8B008B,stroke-width:2px\n");
227        styling.push_str("    classDef parallel fill:#F0E68C,stroke:#BDB76B,stroke-width:2px\n");
228
229        // Apply classes to nodes
230        for node in &self.workflow.nodes {
231            let node_id = self.sanitize_id(&node.id.to_string());
232            let class_name = match node.kind {
233                NodeKind::Start | NodeKind::End => "startEnd",
234                NodeKind::LLM(_) => "llm",
235                NodeKind::Code(_) => "code",
236                NodeKind::IfElse(_) | NodeKind::Switch(_) => "decision",
237                NodeKind::Loop(_) => "loop",
238                NodeKind::Parallel(_) => "parallel",
239                _ => continue,
240            };
241            styling.push_str(&format!("    class {} {}\n", node_id, class_name));
242        }
243
244        styling
245    }
246
247    /// Export to Graphviz DOT format
248    pub fn to_graphviz(&self) -> String {
249        let mut output = String::new();
250
251        // Header
252        output.push_str("digraph workflow {\n");
253        output.push_str(&format!(
254            "    rankdir={};\n",
255            self.style.orientation.to_graphviz()
256        ));
257        output.push_str("    node [shape=box, style=\"rounded,filled\"];\n");
258        output.push_str("    edge [fontsize=10];\n\n");
259
260        // Add workflow metadata as graph label
261        if let Some(desc) = &self.workflow.metadata.description {
262            output.push_str("    labelloc=\"t\";\n");
263            output.push_str(&format!(
264                "    label=\"{}\";\n\n",
265                self.escape_graphviz(desc)
266            ));
267        }
268
269        // Add nodes
270        for node in &self.workflow.nodes {
271            let node_def = self.graphviz_node_definition(node);
272            output.push_str(&format!("    {};\n", node_def));
273        }
274
275        output.push('\n');
276
277        // Add edges
278        for edge in &self.workflow.edges {
279            let edge_def = self.graphviz_edge_definition(edge);
280            output.push_str(&format!("    {};\n", edge_def));
281        }
282
283        output.push_str("}\n");
284        output
285    }
286
287    /// Generate Graphviz node definition
288    fn graphviz_node_definition(&self, node: &Node) -> String {
289        let node_id = self.sanitize_id(&node.id.to_string());
290        let label = self.escape_graphviz(&self.node_label(node));
291
292        let (shape, color) = match node.kind {
293            NodeKind::Start => ("ellipse", "#90EE90"),
294            NodeKind::End => ("ellipse", "#FFB6C1"),
295            NodeKind::LLM(_) => ("box", "#87CEEB"),
296            NodeKind::Code(_) => ("box", "#FFB6C1"),
297            NodeKind::IfElse(_) | NodeKind::Switch(_) => ("diamond", "#FFD700"),
298            NodeKind::Loop(_) => ("hexagon", "#DDA0DD"),
299            NodeKind::Parallel(_) => ("parallelogram", "#F0E68C"),
300            _ => ("box", "#E0E0E0"),
301        };
302
303        if self.style.use_colors {
304            format!(
305                "{} [label=\"{}\", shape={}, fillcolor=\"{}\"]",
306                node_id, label, shape, color
307            )
308        } else {
309            format!("{} [label=\"{}\", shape={}]", node_id, label, shape)
310        }
311    }
312
313    /// Generate Graphviz edge definition
314    fn graphviz_edge_definition(&self, edge: &Edge) -> String {
315        let from_id = self.sanitize_id(&edge.from.to_string());
316        let to_id = self.sanitize_id(&edge.to.to_string());
317
318        if self.style.show_edge_labels {
319            if let Some(label) = &edge.label {
320                let escaped_label = self.escape_graphviz(label);
321                return format!("{} -> {} [label=\"{}\"]", from_id, to_id, escaped_label);
322            }
323        }
324
325        format!("{} -> {}", from_id, to_id)
326    }
327
328    /// Export to PlantUML activity diagram format
329    pub fn to_plantuml(&self) -> String {
330        let mut output = String::new();
331
332        // Header
333        output.push_str("@startuml\n");
334
335        if let Some(desc) = &self.workflow.metadata.description {
336            output.push_str(&format!("title {}\n", desc));
337        }
338
339        output.push_str("start\n\n");
340
341        // Build execution order using topological sort
342        let execution_order = self.topological_sort();
343
344        // Track visited nodes to handle branching
345        let mut visited = HashSet::new();
346
347        for node_id in execution_order {
348            if visited.contains(&node_id) {
349                continue;
350            }
351            visited.insert(node_id);
352
353            if let Some(node) = self.workflow.nodes.iter().find(|n| n.id == node_id) {
354                let node_def = self.plantuml_node_definition(node);
355                output.push_str(&format!("{}\n", node_def));
356            }
357        }
358
359        output.push_str("\nstop\n");
360        output.push_str("@enduml\n");
361        output
362    }
363
364    /// Generate PlantUML node definition
365    fn plantuml_node_definition(&self, node: &Node) -> String {
366        let label = self.node_label(node);
367
368        match node.kind {
369            NodeKind::Start => "start".to_string(),
370            NodeKind::End => "stop".to_string(),
371            NodeKind::IfElse(_) => format!("if ({}) then (yes)\n  :proceed;\nelse (no)\n  :alternative;\nendif", label),
372            NodeKind::Switch(_) => format!("switch ({})\ncase (option 1)\n  :handle option 1;\ncase (option 2)\n  :handle option 2;\nendswitch", label),
373            NodeKind::Loop(_) => format!("while ({})\n  :process;\nendwhile", label),
374            _ => format!(":{};", label),
375        }
376    }
377
378    /// Perform topological sort on workflow nodes
379    fn topological_sort(&self) -> Vec<uuid::Uuid> {
380        let mut result = Vec::new();
381        let mut visited = HashSet::new();
382        let mut temp_mark = HashSet::new();
383
384        // Build adjacency list
385        let mut adj: HashMap<uuid::Uuid, Vec<uuid::Uuid>> = HashMap::new();
386        for edge in &self.workflow.edges {
387            adj.entry(edge.from).or_default().push(edge.to);
388        }
389
390        // Find start nodes
391        let start_nodes: Vec<_> = self
392            .workflow
393            .nodes
394            .iter()
395            .filter(|n| matches!(n.kind, NodeKind::Start))
396            .map(|n| n.id)
397            .collect();
398
399        fn visit(
400            node: uuid::Uuid,
401            adj: &HashMap<uuid::Uuid, Vec<uuid::Uuid>>,
402            visited: &mut HashSet<uuid::Uuid>,
403            temp_mark: &mut HashSet<uuid::Uuid>,
404            result: &mut Vec<uuid::Uuid>,
405        ) {
406            if visited.contains(&node) {
407                return;
408            }
409
410            if temp_mark.contains(&node) {
411                // Cycle detected, skip
412                return;
413            }
414
415            temp_mark.insert(node);
416
417            if let Some(neighbors) = adj.get(&node) {
418                for &neighbor in neighbors {
419                    visit(neighbor, adj, visited, temp_mark, result);
420                }
421            }
422
423            temp_mark.remove(&node);
424            visited.insert(node);
425            result.push(node);
426        }
427
428        for start in start_nodes {
429            visit(start, &adj, &mut visited, &mut temp_mark, &mut result);
430        }
431
432        result.reverse();
433        result
434    }
435
436    /// Generate node label with optional ID
437    fn node_label(&self, node: &Node) -> String {
438        if self.style.show_node_ids {
439            format!("{}\n({})", node.name, &node.id.to_string()[..8])
440        } else {
441            node.name.clone()
442        }
443    }
444
445    /// Sanitize ID for use in diagram formats
446    fn sanitize_id(&self, id: &str) -> String {
447        id.replace('-', "_").chars().take(8).collect::<String>()
448    }
449
450    /// Escape special characters for Graphviz
451    fn escape_graphviz(&self, s: &str) -> String {
452        s.replace('"', "\\\"").replace('\n', "\\n")
453    }
454
455    /// Export to specified format
456    pub fn export(&self, format: VisualizationFormat) -> String {
457        match format {
458            VisualizationFormat::Mermaid => self.to_mermaid(),
459            VisualizationFormat::Graphviz => self.to_graphviz(),
460            VisualizationFormat::PlantUML => self.to_plantuml(),
461        }
462    }
463}
464
465/// Helper function to generate Mermaid diagram from workflow
466pub fn workflow_to_mermaid(workflow: &Workflow) -> String {
467    WorkflowVisualizer::new(workflow).to_mermaid()
468}
469
470/// Helper function to generate Graphviz DOT from workflow
471pub fn workflow_to_graphviz(workflow: &Workflow) -> String {
472    WorkflowVisualizer::new(workflow).to_graphviz()
473}
474
475/// Helper function to generate PlantUML from workflow
476pub fn workflow_to_plantuml(workflow: &Workflow) -> String {
477    WorkflowVisualizer::new(workflow).to_plantuml()
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmConfig, ScriptConfig, WorkflowBuilder};

    /// Minimal OpenAI-style LLM configuration shared by the fixtures.
    fn create_llm_config() -> LlmConfig {
        LlmConfig {
            provider: String::from("openai"),
            model: String::from("gpt-4"),
            system_prompt: None,
            prompt_template: String::from("test"),
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
            extra_params: serde_json::json!({}),
        }
    }

    /// Trivial Rust script configuration for the code-node fixture.
    fn create_script_config() -> ScriptConfig {
        ScriptConfig {
            runtime: String::from("rust"),
            code: String::from("fn main() {}"),
            inputs: Vec::new(),
            output: String::from("result"),
        }
    }

    /// `Start -> LLM(step) -> End`, named "test", without a description.
    fn linear_workflow(step: &str) -> Workflow {
        WorkflowBuilder::new("test")
            .start("Start")
            .llm(step, create_llm_config())
            .end("End")
            .build()
    }

    #[test]
    fn test_mermaid_export() {
        let workflow = WorkflowBuilder::new("test")
            .description("Test workflow")
            .start("Start")
            .llm("Generate", create_llm_config())
            .end("End")
            .build();

        let mermaid = workflow_to_mermaid(&workflow);
        for needle in ["flowchart TB", "Generate"] {
            assert!(mermaid.contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn test_graphviz_export() {
        let dot = workflow_to_graphviz(&linear_workflow("Process"));
        for needle in ["digraph workflow", "Process", "->"] {
            assert!(dot.contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn test_plantuml_export() {
        let plantuml = workflow_to_plantuml(&linear_workflow("Action"));
        for needle in ["@startuml", "@enduml", "Action"] {
            assert!(plantuml.contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn test_visualization_with_custom_style() {
        let workflow = linear_workflow("Task");
        let style = VisualizationStyle {
            show_node_ids: true,
            use_colors: false,
            orientation: DiagramOrientation::LeftRight,
            ..VisualizationStyle::default()
        };

        let mermaid = WorkflowVisualizer::with_style(&workflow, style).to_mermaid();
        assert!(mermaid.contains("flowchart LR"));
    }

    #[test]
    fn test_mermaid_with_colors() {
        let workflow = linear_workflow("LLM");
        let mermaid = WorkflowVisualizer::new(&workflow).to_mermaid();
        assert!(mermaid.contains("classDef"));
        assert!(mermaid.contains("class"));
    }

    #[test]
    fn test_export_all_formats() {
        let workflow = linear_workflow("Process");
        let visualizer = WorkflowVisualizer::new(&workflow);

        let expectations = [
            (VisualizationFormat::Mermaid, "flowchart"),
            (VisualizationFormat::Graphviz, "digraph"),
            (VisualizationFormat::PlantUML, "@startuml"),
        ];
        for (format, marker) in expectations {
            assert!(visualizer.export(format).contains(marker), "{format:?}");
        }
    }

    #[test]
    fn test_diagram_orientations() {
        let cases = [
            (DiagramOrientation::TopBottom, "TB"),
            (DiagramOrientation::LeftRight, "LR"),
            (DiagramOrientation::BottomTop, "BT"),
            (DiagramOrientation::RightLeft, "RL"),
        ];
        for (orientation, code) in cases {
            assert_eq!(orientation.to_mermaid(), code);
        }
    }

    #[test]
    fn test_node_shapes_in_mermaid() {
        let mermaid = workflow_to_mermaid(&linear_workflow("LLM"));
        // Start/End nodes are drawn with square brackets.
        assert!(mermaid.contains('[') && mermaid.contains(']'));
    }

    #[test]
    fn test_edge_labels() {
        let mut workflow = linear_workflow("Process");
        if let Some(first) = workflow.edges.get_mut(0) {
            first.label = Some(String::from("success"));
        }

        assert!(workflow_to_mermaid(&workflow).contains("success"));
    }

    #[test]
    fn test_graphviz_colors() {
        let workflow = WorkflowBuilder::new("test")
            .start("Start")
            .llm("LLM", create_llm_config())
            .code("Code", create_script_config())
            .end("End")
            .build();

        let dot = workflow_to_graphviz(&workflow);
        assert!(dot.contains("fillcolor"));
        assert!(dot.contains("#87CEEB"), "LLM node should use its sky-blue fill");
    }
}