1use crate::{Edge, Node, NodeKind, Workflow};
41use serde::{Deserialize, Serialize};
42use std::collections::{HashMap, HashSet};
43
/// Diagram syntaxes that [`WorkflowVisualizer::export`] can produce.
///
/// Each variant maps to exactly one `to_*` method on [`WorkflowVisualizer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VisualizationFormat {
    /// Mermaid `flowchart` text, produced by [`WorkflowVisualizer::to_mermaid`].
    Mermaid,
    /// Graphviz DOT `digraph` text, produced by [`WorkflowVisualizer::to_graphviz`].
    Graphviz,
    /// PlantUML text (`@startuml` … `@enduml`), produced by
    /// [`WorkflowVisualizer::to_plantuml`].
    PlantUML,
}
54
/// Rendering options consumed by [`WorkflowVisualizer`].
///
/// See the `Default` implementation for the values used by
/// [`WorkflowVisualizer::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationStyle {
    /// When `true`, node labels get a second line holding the first eight
    /// characters of the node id in parentheses.
    pub show_node_ids: bool,

    /// When `true`, Mermaid and Graphviz edges carry their label if the edge
    /// has one. Edges without a label are always drawn bare.
    pub show_edge_labels: bool,

    /// When `true`, Mermaid output gains a `classDef`/`class` section and
    /// Graphviz nodes gain a `fillcolor` attribute.
    pub use_colors: bool,

    /// NOTE(review): not read by any exporter in this file; presumably
    /// reserved for emitting node descriptions — confirm before relying on it.
    pub include_descriptions: bool,

    /// Flow direction for Mermaid (`flowchart <dir>`) and Graphviz
    /// (`rankdir=<dir>`). The PlantUML exporter does not consult it.
    pub orientation: DiagramOrientation,

    /// NOTE(review): not read by any exporter in this file; presumably
    /// reserved for clustering nodes by kind — confirm before relying on it.
    pub group_by_type: bool,
}
76
77impl Default for VisualizationStyle {
78 fn default() -> Self {
79 Self {
80 show_node_ids: false,
81 show_edge_labels: true,
82 use_colors: true,
83 include_descriptions: false,
84 orientation: DiagramOrientation::TopBottom,
85 group_by_type: false,
86 }
87 }
88}
89
/// Direction in which a diagram flows.
///
/// The Mermaid and Graphviz exporters translate each variant to the
/// two-letter code shown below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DiagramOrientation {
    /// Top to bottom (`TB`).
    TopBottom,
    /// Left to right (`LR`).
    LeftRight,
    /// Bottom to top (`BT`).
    BottomTop,
    /// Right to left (`RL`).
    RightLeft,
}
102
103impl DiagramOrientation {
104 fn to_mermaid(self) -> &'static str {
106 match self {
107 DiagramOrientation::TopBottom => "TB",
108 DiagramOrientation::LeftRight => "LR",
109 DiagramOrientation::BottomTop => "BT",
110 DiagramOrientation::RightLeft => "RL",
111 }
112 }
113
114 fn to_graphviz(self) -> &'static str {
116 match self {
117 DiagramOrientation::TopBottom => "TB",
118 DiagramOrientation::LeftRight => "LR",
119 DiagramOrientation::BottomTop => "BT",
120 DiagramOrientation::RightLeft => "RL",
121 }
122 }
123}
124
/// Renders a borrowed [`Workflow`] as diagram source text.
///
/// The visualizer never mutates the workflow; it only reads its nodes, edges
/// and metadata description while building output strings.
pub struct WorkflowVisualizer<'a> {
    /// Workflow being rendered; borrowed for the visualizer's lifetime.
    workflow: &'a Workflow,
    /// Rendering options applied by every exporter.
    style: VisualizationStyle,
}
130
131impl<'a> WorkflowVisualizer<'a> {
132 pub fn new(workflow: &'a Workflow) -> Self {
134 Self {
135 workflow,
136 style: VisualizationStyle::default(),
137 }
138 }
139
140 pub fn with_style(workflow: &'a Workflow, style: VisualizationStyle) -> Self {
142 Self { workflow, style }
143 }
144
145 pub fn to_mermaid(&self) -> String {
147 let mut output = String::new();
148
149 output.push_str(&format!(
151 "flowchart {}\n",
152 self.style.orientation.to_mermaid()
153 ));
154
155 if let Some(desc) = &self.workflow.metadata.description {
157 output.push_str(" %%{ init: {'theme':'base', 'themeVariables': { 'primaryColor':'#ff9900'}}}%%\n");
158 output.push_str(&format!(" %% {}\n", desc));
159 }
160
161 for node in &self.workflow.nodes {
163 let node_def = self.mermaid_node_definition(node);
164 output.push_str(&format!(" {}\n", node_def));
165 }
166
167 output.push('\n');
168
169 for edge in &self.workflow.edges {
171 let edge_def = self.mermaid_edge_definition(edge);
172 output.push_str(&format!(" {}\n", edge_def));
173 }
174
175 if self.style.use_colors {
177 output.push('\n');
178 output.push_str(&self.mermaid_styling());
179 }
180
181 output
182 }
183
184 fn mermaid_node_definition(&self, node: &Node) -> String {
186 let node_id = self.sanitize_id(&node.id.to_string());
187 let label = self.node_label(node);
188
189 let (open, close) = match node.kind {
191 NodeKind::Start => ("[", "]"),
192 NodeKind::End => ("[", "]"),
193 NodeKind::IfElse(_) => ("{", "}"),
194 NodeKind::Switch(_) => ("{", "}"),
195 NodeKind::Parallel(_) => ("[[", "]]"),
196 NodeKind::Loop(_) => ("{{", "}}"),
197 _ => ("(", ")"),
198 };
199
200 format!("{}{}\"{}\"{}", node_id, open, label, close)
201 }
202
203 fn mermaid_edge_definition(&self, edge: &Edge) -> String {
205 let from_id = self.sanitize_id(&edge.from.to_string());
206 let to_id = self.sanitize_id(&edge.to.to_string());
207
208 if self.style.show_edge_labels {
209 if let Some(label) = &edge.label {
210 return format!("{} -->|\"{}\"| {}", from_id, label, to_id);
211 }
212 }
213
214 format!("{} --> {}", from_id, to_id)
215 }
216
217 fn mermaid_styling(&self) -> String {
219 let mut styling = String::new();
220
221 styling.push_str(" classDef startEnd fill:#90EE90,stroke:#228B22,stroke-width:2px\n");
223 styling.push_str(" classDef llm fill:#87CEEB,stroke:#4682B4,stroke-width:2px\n");
224 styling.push_str(" classDef code fill:#FFB6C1,stroke:#C71585,stroke-width:2px\n");
225 styling.push_str(" classDef decision fill:#FFD700,stroke:#FF8C00,stroke-width:2px\n");
226 styling.push_str(" classDef loop fill:#DDA0DD,stroke:#8B008B,stroke-width:2px\n");
227 styling.push_str(" classDef parallel fill:#F0E68C,stroke:#BDB76B,stroke-width:2px\n");
228
229 for node in &self.workflow.nodes {
231 let node_id = self.sanitize_id(&node.id.to_string());
232 let class_name = match node.kind {
233 NodeKind::Start | NodeKind::End => "startEnd",
234 NodeKind::LLM(_) => "llm",
235 NodeKind::Code(_) => "code",
236 NodeKind::IfElse(_) | NodeKind::Switch(_) => "decision",
237 NodeKind::Loop(_) => "loop",
238 NodeKind::Parallel(_) => "parallel",
239 _ => continue,
240 };
241 styling.push_str(&format!(" class {} {}\n", node_id, class_name));
242 }
243
244 styling
245 }
246
247 pub fn to_graphviz(&self) -> String {
249 let mut output = String::new();
250
251 output.push_str("digraph workflow {\n");
253 output.push_str(&format!(
254 " rankdir={};\n",
255 self.style.orientation.to_graphviz()
256 ));
257 output.push_str(" node [shape=box, style=\"rounded,filled\"];\n");
258 output.push_str(" edge [fontsize=10];\n\n");
259
260 if let Some(desc) = &self.workflow.metadata.description {
262 output.push_str(" labelloc=\"t\";\n");
263 output.push_str(&format!(
264 " label=\"{}\";\n\n",
265 self.escape_graphviz(desc)
266 ));
267 }
268
269 for node in &self.workflow.nodes {
271 let node_def = self.graphviz_node_definition(node);
272 output.push_str(&format!(" {};\n", node_def));
273 }
274
275 output.push('\n');
276
277 for edge in &self.workflow.edges {
279 let edge_def = self.graphviz_edge_definition(edge);
280 output.push_str(&format!(" {};\n", edge_def));
281 }
282
283 output.push_str("}\n");
284 output
285 }
286
287 fn graphviz_node_definition(&self, node: &Node) -> String {
289 let node_id = self.sanitize_id(&node.id.to_string());
290 let label = self.escape_graphviz(&self.node_label(node));
291
292 let (shape, color) = match node.kind {
293 NodeKind::Start => ("ellipse", "#90EE90"),
294 NodeKind::End => ("ellipse", "#FFB6C1"),
295 NodeKind::LLM(_) => ("box", "#87CEEB"),
296 NodeKind::Code(_) => ("box", "#FFB6C1"),
297 NodeKind::IfElse(_) | NodeKind::Switch(_) => ("diamond", "#FFD700"),
298 NodeKind::Loop(_) => ("hexagon", "#DDA0DD"),
299 NodeKind::Parallel(_) => ("parallelogram", "#F0E68C"),
300 _ => ("box", "#E0E0E0"),
301 };
302
303 if self.style.use_colors {
304 format!(
305 "{} [label=\"{}\", shape={}, fillcolor=\"{}\"]",
306 node_id, label, shape, color
307 )
308 } else {
309 format!("{} [label=\"{}\", shape={}]", node_id, label, shape)
310 }
311 }
312
313 fn graphviz_edge_definition(&self, edge: &Edge) -> String {
315 let from_id = self.sanitize_id(&edge.from.to_string());
316 let to_id = self.sanitize_id(&edge.to.to_string());
317
318 if self.style.show_edge_labels {
319 if let Some(label) = &edge.label {
320 let escaped_label = self.escape_graphviz(label);
321 return format!("{} -> {} [label=\"{}\"]", from_id, to_id, escaped_label);
322 }
323 }
324
325 format!("{} -> {}", from_id, to_id)
326 }
327
328 pub fn to_plantuml(&self) -> String {
330 let mut output = String::new();
331
332 output.push_str("@startuml\n");
334
335 if let Some(desc) = &self.workflow.metadata.description {
336 output.push_str(&format!("title {}\n", desc));
337 }
338
339 output.push_str("start\n\n");
340
341 let execution_order = self.topological_sort();
343
344 let mut visited = HashSet::new();
346
347 for node_id in execution_order {
348 if visited.contains(&node_id) {
349 continue;
350 }
351 visited.insert(node_id);
352
353 if let Some(node) = self.workflow.nodes.iter().find(|n| n.id == node_id) {
354 let node_def = self.plantuml_node_definition(node);
355 output.push_str(&format!("{}\n", node_def));
356 }
357 }
358
359 output.push_str("\nstop\n");
360 output.push_str("@enduml\n");
361 output
362 }
363
364 fn plantuml_node_definition(&self, node: &Node) -> String {
366 let label = self.node_label(node);
367
368 match node.kind {
369 NodeKind::Start => "start".to_string(),
370 NodeKind::End => "stop".to_string(),
371 NodeKind::IfElse(_) => format!("if ({}) then (yes)\n :proceed;\nelse (no)\n :alternative;\nendif", label),
372 NodeKind::Switch(_) => format!("switch ({})\ncase (option 1)\n :handle option 1;\ncase (option 2)\n :handle option 2;\nendswitch", label),
373 NodeKind::Loop(_) => format!("while ({})\n :process;\nendwhile", label),
374 _ => format!(":{};", label),
375 }
376 }
377
378 fn topological_sort(&self) -> Vec<uuid::Uuid> {
380 let mut result = Vec::new();
381 let mut visited = HashSet::new();
382 let mut temp_mark = HashSet::new();
383
384 let mut adj: HashMap<uuid::Uuid, Vec<uuid::Uuid>> = HashMap::new();
386 for edge in &self.workflow.edges {
387 adj.entry(edge.from).or_default().push(edge.to);
388 }
389
390 let start_nodes: Vec<_> = self
392 .workflow
393 .nodes
394 .iter()
395 .filter(|n| matches!(n.kind, NodeKind::Start))
396 .map(|n| n.id)
397 .collect();
398
399 fn visit(
400 node: uuid::Uuid,
401 adj: &HashMap<uuid::Uuid, Vec<uuid::Uuid>>,
402 visited: &mut HashSet<uuid::Uuid>,
403 temp_mark: &mut HashSet<uuid::Uuid>,
404 result: &mut Vec<uuid::Uuid>,
405 ) {
406 if visited.contains(&node) {
407 return;
408 }
409
410 if temp_mark.contains(&node) {
411 return;
413 }
414
415 temp_mark.insert(node);
416
417 if let Some(neighbors) = adj.get(&node) {
418 for &neighbor in neighbors {
419 visit(neighbor, adj, visited, temp_mark, result);
420 }
421 }
422
423 temp_mark.remove(&node);
424 visited.insert(node);
425 result.push(node);
426 }
427
428 for start in start_nodes {
429 visit(start, &adj, &mut visited, &mut temp_mark, &mut result);
430 }
431
432 result.reverse();
433 result
434 }
435
436 fn node_label(&self, node: &Node) -> String {
438 if self.style.show_node_ids {
439 format!("{}\n({})", node.name, &node.id.to_string()[..8])
440 } else {
441 node.name.clone()
442 }
443 }
444
445 fn sanitize_id(&self, id: &str) -> String {
447 id.replace('-', "_").chars().take(8).collect::<String>()
448 }
449
450 fn escape_graphviz(&self, s: &str) -> String {
452 s.replace('"', "\\\"").replace('\n', "\\n")
453 }
454
455 pub fn export(&self, format: VisualizationFormat) -> String {
457 match format {
458 VisualizationFormat::Mermaid => self.to_mermaid(),
459 VisualizationFormat::Graphviz => self.to_graphviz(),
460 VisualizationFormat::PlantUML => self.to_plantuml(),
461 }
462 }
463}
464
465pub fn workflow_to_mermaid(workflow: &Workflow) -> String {
467 WorkflowVisualizer::new(workflow).to_mermaid()
468}
469
470pub fn workflow_to_graphviz(workflow: &Workflow) -> String {
472 WorkflowVisualizer::new(workflow).to_graphviz()
473}
474
475pub fn workflow_to_plantuml(workflow: &Workflow) -> String {
477 WorkflowVisualizer::new(workflow).to_plantuml()
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmConfig, ScriptConfig, WorkflowBuilder};

    /// LLM node configuration with fixed values; the exporters under test
    /// never look inside it.
    fn create_llm_config() -> LlmConfig {
        LlmConfig {
            provider: String::from("openai"),
            model: String::from("gpt-4"),
            system_prompt: None,
            prompt_template: String::from("test"),
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
            extra_params: serde_json::json!({}),
        }
    }

    /// Script node configuration with fixed values.
    fn create_script_config() -> ScriptConfig {
        ScriptConfig {
            runtime: String::from("rust"),
            code: String::from("fn main() {}"),
            inputs: Vec::new(),
            output: String::from("result"),
        }
    }

    /// Workflow named "test" built from a start node, one LLM node called
    /// `step_name`, and an end node.
    fn linear_workflow(step_name: &'static str) -> Workflow {
        WorkflowBuilder::new("test")
            .start("Start")
            .llm(step_name, create_llm_config())
            .end("End")
            .build()
    }

    #[test]
    fn test_mermaid_export() {
        let workflow = WorkflowBuilder::new("test")
            .description("Test workflow")
            .start("Start")
            .llm("Generate", create_llm_config())
            .end("End")
            .build();

        let rendered = workflow_to_mermaid(&workflow);
        assert!(rendered.contains("flowchart TB"));
        assert!(rendered.contains("Generate"));
    }

    #[test]
    fn test_graphviz_export() {
        let rendered = workflow_to_graphviz(&linear_workflow("Process"));
        assert!(rendered.contains("digraph workflow"));
        assert!(rendered.contains("Process"));
        assert!(rendered.contains("->"));
    }

    #[test]
    fn test_plantuml_export() {
        let rendered = workflow_to_plantuml(&linear_workflow("Action"));
        assert!(rendered.contains("@startuml"));
        assert!(rendered.contains("@enduml"));
        assert!(rendered.contains("Action"));
    }

    #[test]
    fn test_visualization_with_custom_style() {
        let workflow = linear_workflow("Task");

        // Only the fields that differ from the defaults are spelled out.
        let style = VisualizationStyle {
            show_node_ids: true,
            use_colors: false,
            orientation: DiagramOrientation::LeftRight,
            ..VisualizationStyle::default()
        };

        let rendered = WorkflowVisualizer::with_style(&workflow, style).to_mermaid();
        assert!(rendered.contains("flowchart LR"));
    }

    #[test]
    fn test_mermaid_with_colors() {
        let workflow = linear_workflow("LLM");

        let rendered = WorkflowVisualizer::new(&workflow).to_mermaid();
        assert!(rendered.contains("classDef"));
        assert!(rendered.contains("class"));
    }

    #[test]
    fn test_export_all_formats() {
        let workflow = linear_workflow("Process");
        let visualizer = WorkflowVisualizer::new(&workflow);

        assert!(visualizer
            .export(VisualizationFormat::Mermaid)
            .contains("flowchart"));
        assert!(visualizer
            .export(VisualizationFormat::Graphviz)
            .contains("digraph"));
        assert!(visualizer
            .export(VisualizationFormat::PlantUML)
            .contains("@startuml"));
    }

    #[test]
    fn test_diagram_orientations() {
        let expected = [
            (DiagramOrientation::TopBottom, "TB"),
            (DiagramOrientation::LeftRight, "LR"),
            (DiagramOrientation::BottomTop, "BT"),
            (DiagramOrientation::RightLeft, "RL"),
        ];

        for &(orientation, code) in &expected {
            assert_eq!(orientation.to_mermaid(), code);
        }
    }

    #[test]
    fn test_node_shapes_in_mermaid() {
        let rendered = workflow_to_mermaid(&linear_workflow("LLM"));
        assert!(rendered.contains('[') && rendered.contains(']'));
    }

    #[test]
    fn test_edge_labels() {
        let mut workflow = linear_workflow("Process");

        // Label the first edge, if the builder produced one.
        if let Some(first_edge) = workflow.edges.get_mut(0) {
            first_edge.label = Some(String::from("success"));
        }

        let rendered = workflow_to_mermaid(&workflow);
        assert!(rendered.contains("success"));
    }

    #[test]
    fn test_graphviz_colors() {
        let workflow = WorkflowBuilder::new("test")
            .start("Start")
            .llm("LLM", create_llm_config())
            .code("Code", create_script_config())
            .end("End")
            .build();

        let rendered = workflow_to_graphviz(&workflow);
        assert!(rendered.contains("fillcolor"));
        // Fill color assigned to LLM nodes.
        assert!(rendered.contains("#87CEEB"));
    }
}