1use crate::{
47 CostEstimate, CostEstimator, ExecutionState, Node, NodeId, NodeKind, TimeEstimate,
48 TimePredictor, Workflow,
49};
50use serde::{Deserialize, Serialize};
51use serde_json::Value;
52use std::collections::{HashMap, HashSet};
53
/// Aggregated outcome of a single workflow simulation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimulationResult {
    /// `true` when the run recorded no [`SimulationError`]s.
    pub success: bool,

    /// `Completed` on success; otherwise `Failed` carrying all error
    /// messages joined with "; ".
    pub final_state: ExecutionState,

    /// Step-by-step record of the nodes that executed.
    pub trace: ExecutionTrace,

    /// The variable map as it stood when the simulation finished.
    pub final_context: HashMap<String, Value>,

    /// Present only when cost estimation was enabled on the simulator.
    pub cost_estimate: Option<CostEstimate>,

    /// Present only when time estimation was enabled on the simulator.
    pub time_estimate: Option<TimeEstimate>,

    /// Node/branch coverage statistics for the run.
    pub coverage: CoverageInfo,

    /// Errors collected during simulation.
    pub errors: Vec<SimulationError>,

    /// Non-fatal notes (e.g. a loop simulated as a single iteration).
    pub warnings: Vec<String>,
}
84
/// Record of which nodes ran during a simulation and in what order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionTrace {
    /// Node ids in the order they were executed.
    pub executed_nodes: Vec<NodeId>,

    /// Per-node execution details, keyed by node id.
    pub node_details: HashMap<NodeId, NodeExecutionDetail>,

    /// Sum of the simulated per-node execution times, in milliseconds.
    pub total_time_ms: u64,

    /// Number of node executions recorded.
    pub node_count: usize,

    /// Branch labels taken at each branching node (e.g. "true"/"false"
    /// for if/else nodes).
    pub branches_taken: HashMap<NodeId, Vec<String>>,
}
103
/// Details captured for one simulated node execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecutionDetail {
    /// Id of the executed node.
    pub node_id: NodeId,

    /// Human-readable node name.
    pub node_name: String,

    /// Debug-formatted node kind (e.g. "LLM(..)").
    pub node_type: String,

    /// Simulated execution time for this node, in milliseconds.
    pub execution_time_ms: u64,

    /// Snapshot of the variable map at the time the node ran.
    pub input_context: HashMap<String, Value>,

    /// The node's simulated (or mocked) output value.
    pub output: Value,

    /// Whether the output came from a registered mock response.
    pub mocked: bool,

    /// Number of retries performed; always 0 in this simulator.
    pub retry_count: u32,
}
131
/// Node and branch coverage statistics for a simulation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoverageInfo {
    /// Total number of nodes in the workflow.
    pub total_nodes: usize,

    /// Number of distinct nodes that executed.
    pub executed_nodes: usize,

    /// executed / total as a percentage (0.0 when the workflow is empty).
    pub coverage_percent: f64,

    /// Ids of nodes that never executed.
    pub unexecuted_nodes: Vec<NodeId>,

    /// Branch labels taken at each branching node.
    pub branches_taken: HashMap<NodeId, Vec<String>>,

    /// Branch labels NOT taken. NOTE(review): currently always left empty
    /// by `calculate_coverage` — confirm whether this is intended.
    pub branches_not_taken: HashMap<NodeId, Vec<String>>,
}
153
/// An error recorded against a specific node during simulation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimulationError {
    /// Node at which the error occurred.
    pub node_id: NodeId,

    /// Human-readable error description.
    pub message: String,

    /// Category label for the error.
    pub error_type: String,

    /// Presumably marks errors that were anticipated by the scenario;
    /// never set in this file — TODO confirm semantics against callers.
    pub expected: bool,
}
169
/// Dry-run executor for workflows: walks the graph without performing real
/// LLM/tool/code calls, producing canned or mocked outputs plus coverage,
/// cost, and time estimates.
pub struct WorkflowSimulator {
    /// Mock outputs keyed by node *name*; takes precedence over the
    /// built-in canned outputs.
    mock_responses: HashMap<String, Value>,

    /// Whether to simulate latencies. NOTE(review): stored by the builder
    /// but not read anywhere in this file — confirm intended use.
    simulate_latencies: bool,

    /// Whether to attach a `CostEstimate` to the result.
    estimate_costs: bool,

    /// Whether to attach a `TimeEstimate` to the result.
    estimate_times: bool,

    /// Maximum node executions before the run aborts with an error.
    max_steps: usize,

    /// RNG seed for determinism. NOTE(review): stored but not consumed in
    /// this file — confirm intended use.
    seed: Option<u64>,
}
190
191impl WorkflowSimulator {
192 pub fn new() -> Self {
194 Self {
195 mock_responses: HashMap::new(),
196 simulate_latencies: true,
197 estimate_costs: true,
198 estimate_times: true,
199 max_steps: 10000,
200 seed: None,
201 }
202 }
203
204 pub fn with_mock_responses(mut self, responses: Vec<(String, Value)>) -> Self {
206 self.mock_responses = responses.into_iter().collect();
207 self
208 }
209
210 pub fn simulate_latencies(mut self, enabled: bool) -> Self {
212 self.simulate_latencies = enabled;
213 self
214 }
215
216 pub fn estimate_costs(mut self, enabled: bool) -> Self {
218 self.estimate_costs = enabled;
219 self
220 }
221
222 pub fn estimate_times(mut self, enabled: bool) -> Self {
224 self.estimate_times = enabled;
225 self
226 }
227
228 pub fn max_steps(mut self, steps: usize) -> Self {
230 self.max_steps = steps;
231 self
232 }
233
234 pub fn with_seed(mut self, seed: u64) -> Self {
236 self.seed = Some(seed);
237 self
238 }
239
240 pub fn simulate(
242 &self,
243 workflow: &Workflow,
244 initial_context: HashMap<String, Value>,
245 ) -> Result<SimulationResult, String> {
246 let mut context = SimulationContext::new(workflow, initial_context, self.max_steps);
247
248 let start_node = workflow
250 .nodes
251 .iter()
252 .find(|n| matches!(n.kind, NodeKind::Start))
253 .ok_or("No start node found")?;
254
255 self.execute_node(&mut context, workflow, &start_node.id)?;
257
258 let coverage = self.calculate_coverage(workflow, &context);
260
261 let cost_estimate = if self.estimate_costs {
262 Some(CostEstimator::estimate(workflow))
263 } else {
264 None
265 };
266
267 let time_estimate = if self.estimate_times {
268 let predictor = TimePredictor::new();
269 Some(predictor.predict(workflow))
270 } else {
271 None
272 };
273
274 Ok(SimulationResult {
275 success: context.errors.is_empty(),
276 final_state: if context.errors.is_empty() {
277 ExecutionState::Completed
278 } else {
279 let error_msg = context
280 .errors
281 .iter()
282 .map(|e| e.message.as_str())
283 .collect::<Vec<_>>()
284 .join("; ");
285 ExecutionState::Failed(error_msg)
286 },
287 trace: context.build_trace(),
288 final_context: context.variables,
289 cost_estimate,
290 time_estimate,
291 coverage,
292 errors: context.errors,
293 warnings: context.warnings,
294 })
295 }
296
297 fn execute_node(
299 &self,
300 context: &mut SimulationContext,
301 workflow: &Workflow,
302 node_id: &NodeId,
303 ) -> Result<(), String> {
304 if context.step_count >= self.max_steps {
305 return Err("Maximum simulation steps exceeded".to_string());
306 }
307
308 context.step_count += 1;
309
310 let node = workflow
311 .nodes
312 .iter()
313 .find(|n| &n.id == node_id)
314 .ok_or("Node not found")?;
315
316 if context.executed_nodes.contains(node_id) {
318 return Ok(());
319 }
320
321 context.executed_nodes.insert(*node_id);
322
323 let output = self.simulate_node_execution(context, node)?;
325
326 context.record_execution(node, output.clone(), false);
328
329 match &node.kind {
331 NodeKind::Start => {
332 self.execute_next_nodes(context, workflow, node_id)?;
334 }
335 NodeKind::End => {
336 context.completed = true;
338 }
339 NodeKind::IfElse(condition_cfg) => {
340 let branch_taken = self.evaluate_condition(&condition_cfg.expression, context);
342 let branch_name = if branch_taken { "true" } else { "false" };
343 context
344 .branches_taken
345 .entry(*node_id)
346 .or_default()
347 .push(branch_name.to_string());
348
349 let next_node = if branch_taken {
351 &condition_cfg.true_branch
352 } else {
353 &condition_cfg.false_branch
354 };
355 self.execute_node(context, workflow, next_node)?;
356 }
357 NodeKind::Switch(switch_cfg) => {
358 let value = self.evaluate_expression(&switch_cfg.switch_on, context);
360 let matched_value = match &value {
361 Value::String(s) => s.clone(),
362 _ => "unknown".to_string(),
363 };
364
365 context
366 .branches_taken
367 .entry(*node_id)
368 .or_default()
369 .push(matched_value.clone());
370
371 self.execute_next_nodes(context, workflow, node_id)?;
373 }
374 NodeKind::Loop(_loop_cfg) => {
375 context.warnings.push(format!(
377 "Loop node '{}' simulated with single iteration",
378 node.name
379 ));
380 self.execute_next_nodes(context, workflow, node_id)?;
381 }
382 _ => {
383 self.execute_next_nodes(context, workflow, node_id)?;
385 }
386 }
387
388 Ok(())
389 }
390
391 fn simulate_node_execution(
393 &self,
394 _context: &SimulationContext,
395 node: &Node,
396 ) -> Result<Value, String> {
397 if let Some(mock) = self.mock_responses.get(&node.name) {
399 return Ok(mock.clone());
400 }
401
402 let output = match &node.kind {
404 NodeKind::Start => Value::Null,
405 NodeKind::End => Value::Null,
406 NodeKind::LLM(_) => Value::String("Simulated LLM response".to_string()),
407 NodeKind::Code(_) => Value::String("Simulated code execution".to_string()),
408 NodeKind::Retriever(_) => Value::Array(vec![
409 Value::String("Simulated document 1".to_string()),
410 Value::String("Simulated document 2".to_string()),
411 ]),
412 NodeKind::Tool(_) => Value::String("Simulated tool result".to_string()),
413 _ => Value::Null,
414 };
415
416 Ok(output)
417 }
418
419 fn execute_next_nodes(
421 &self,
422 context: &mut SimulationContext,
423 workflow: &Workflow,
424 current_node_id: &NodeId,
425 ) -> Result<(), String> {
426 let next_edges: Vec<_> = workflow
427 .edges
428 .iter()
429 .filter(|e| &e.from == current_node_id)
430 .collect();
431
432 for edge in next_edges {
433 self.execute_node(context, workflow, &edge.to)?;
434 }
435
436 Ok(())
437 }
438
439 fn evaluate_condition(&self, _condition: &str, _context: &SimulationContext) -> bool {
441 true
444 }
445
446 fn evaluate_expression(&self, _expression: &str, _context: &SimulationContext) -> Value {
448 Value::String("simulated".to_string())
451 }
452
453 #[allow(dead_code)]
455 fn matches_case(&self, _value: &Value, _match_value: &str) -> bool {
456 true
458 }
459
460 fn calculate_coverage(&self, workflow: &Workflow, context: &SimulationContext) -> CoverageInfo {
462 let total_nodes = workflow.nodes.len();
463 let executed_nodes = context.executed_nodes.len();
464 let coverage_percent = if total_nodes > 0 {
465 (executed_nodes as f64 / total_nodes as f64) * 100.0
466 } else {
467 0.0
468 };
469
470 let unexecuted_nodes: Vec<NodeId> = workflow
471 .nodes
472 .iter()
473 .filter(|n| !context.executed_nodes.contains(&n.id))
474 .map(|n| n.id)
475 .collect();
476
477 CoverageInfo {
478 total_nodes,
479 executed_nodes,
480 coverage_percent,
481 unexecuted_nodes,
482 branches_taken: context.branches_taken.clone(),
483 branches_not_taken: HashMap::new(),
484 }
485 }
486}
487
488impl Default for WorkflowSimulator {
489 fn default() -> Self {
490 Self::new()
491 }
492}
493
/// Mutable state threaded through a single simulation run.
#[allow(dead_code)]
struct SimulationContext {
    /// Workflow variables, seeded from the caller's initial context.
    variables: HashMap<String, Value>,

    /// Ids of nodes that have already executed (each runs at most once).
    executed_nodes: HashSet<NodeId>,

    /// Per-node execution records, in execution order.
    execution_details: Vec<NodeExecutionDetail>,

    /// Branch labels taken at each branching node.
    branches_taken: HashMap<NodeId, Vec<String>>,

    /// Errors collected during the run.
    errors: Vec<SimulationError>,

    /// Non-fatal notes collected during the run.
    warnings: Vec<String>,

    /// Number of node executions attempted so far.
    step_count: usize,

    /// Step budget. NOTE(review): the budget is enforced against
    /// `WorkflowSimulator::max_steps`, not this field — confirm whether
    /// this copy is still needed.
    max_steps: usize,

    /// Set when an `End` node is reached.
    completed: bool,

    /// Accumulated simulated execution time, in milliseconds.
    total_time_ms: u64,
}
527
528impl SimulationContext {
529 fn new(
530 _workflow: &Workflow,
531 initial_context: HashMap<String, Value>,
532 max_steps: usize,
533 ) -> Self {
534 Self {
535 variables: initial_context,
536 executed_nodes: HashSet::new(),
537 execution_details: Vec::new(),
538 branches_taken: HashMap::new(),
539 errors: Vec::new(),
540 warnings: Vec::new(),
541 step_count: 0,
542 max_steps,
543 completed: false,
544 total_time_ms: 0,
545 }
546 }
547
548 fn record_execution(&mut self, node: &Node, output: Value, mocked: bool) {
549 let execution_time_ms = self.estimate_node_time(node);
550 self.total_time_ms += execution_time_ms;
551
552 self.execution_details.push(NodeExecutionDetail {
553 node_id: node.id,
554 node_name: node.name.clone(),
555 node_type: format!("{:?}", node.kind),
556 execution_time_ms,
557 input_context: self.variables.clone(),
558 output,
559 mocked,
560 retry_count: 0,
561 });
562 }
563
564 fn estimate_node_time(&self, node: &Node) -> u64 {
565 match &node.kind {
567 NodeKind::Start | NodeKind::End => 0,
568 NodeKind::LLM(_) => 1000,
569 NodeKind::Code(_) => 100,
570 NodeKind::Retriever(_) => 500,
571 NodeKind::Tool(_) => 200,
572 _ => 50,
573 }
574 }
575
576 fn build_trace(&self) -> ExecutionTrace {
577 let executed_nodes: Vec<NodeId> =
578 self.execution_details.iter().map(|d| d.node_id).collect();
579
580 let node_details: HashMap<NodeId, NodeExecutionDetail> = self
581 .execution_details
582 .iter()
583 .map(|d| (d.node_id, d.clone()))
584 .collect();
585
586 ExecutionTrace {
587 executed_nodes,
588 node_details,
589 total_time_ms: self.total_time_ms,
590 node_count: self.execution_details.len(),
591 branches_taken: self.branches_taken.clone(),
592 }
593 }
594}
595
#[cfg(test)]
mod tests {
    //! Integration-style tests for [`WorkflowSimulator`] against small
    //! workflows built with `WorkflowBuilder` (start → [llm] → end).
    use super::*;
    use crate::{LlmConfig, WorkflowBuilder};

    // A two-node workflow (start, end) simulates successfully with full
    // node coverage.
    #[test]
    fn test_simulate_simple_workflow() {
        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .end("end")
            .build();

        let context = HashMap::new();
        let simulator = WorkflowSimulator::new();
        let result = simulator.simulate(&workflow, context);

        assert!(result.is_ok());
        let sim_result = result.unwrap();
        assert!(sim_result.success);
        assert_eq!(sim_result.coverage.executed_nodes, 2);
    }

    // An LLM node between start and end executes like any other node.
    #[test]
    fn test_simulate_with_llm_node() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "Generate: {{input}}".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let mut context = HashMap::new();
        context.insert("input".to_string(), Value::String("test".to_string()));

        let simulator = WorkflowSimulator::new();
        let result = simulator.simulate(&workflow, context);

        assert!(result.is_ok());
        let sim_result = result.unwrap();
        assert!(sim_result.success);
        assert_eq!(sim_result.coverage.executed_nodes, 3);
    }

    // A mock registered under a node's name replaces that node's output.
    #[test]
    fn test_simulate_with_mock_response() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "Generate: {{input}}".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let context = HashMap::new();
        let mock_response = Value::String("Mocked LLM response".to_string());

        let simulator = WorkflowSimulator::new()
            .with_mock_responses(vec![("gen".to_string(), mock_response.clone())]);

        let result = simulator.simulate(&workflow, context);

        assert!(result.is_ok());
        let sim_result = result.unwrap();
        assert!(sim_result.success);

        // Look up the mocked node's record and check its output.
        let gen_detail = sim_result
            .trace
            .node_details
            .values()
            .find(|d| d.node_name == "gen")
            .unwrap();
        assert_eq!(gen_detail.output, mock_response);
    }

    // A fully-executed workflow reports 100% coverage and no unexecuted
    // nodes.
    #[test]
    fn test_coverage_calculation() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let context = HashMap::new();
        let simulator = WorkflowSimulator::new();
        let result = simulator.simulate(&workflow, context).unwrap();

        assert_eq!(result.coverage.total_nodes, 3);
        assert_eq!(result.coverage.executed_nodes, 3);
        assert_eq!(result.coverage.coverage_percent, 100.0);
        assert!(result.coverage.unexecuted_nodes.is_empty());
    }

    // The trace records every node and a nonzero total simulated time
    // (the LLM node contributes 1000 ms).
    #[test]
    fn test_execution_trace() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let context = HashMap::new();
        let simulator = WorkflowSimulator::new();
        let result = simulator.simulate(&workflow, context).unwrap();

        assert_eq!(result.trace.node_count, 3);
        assert_eq!(result.trace.executed_nodes.len(), 3);
        assert!(result.trace.total_time_ms > 0);
    }

    // Cost estimation enabled => cost_estimate is populated.
    #[test]
    fn test_cost_estimation() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let context = HashMap::new();
        let simulator = WorkflowSimulator::new().estimate_costs(true);
        let result = simulator.simulate(&workflow, context).unwrap();

        assert!(result.cost_estimate.is_some());
    }

    // Time estimation enabled => time_estimate is populated.
    #[test]
    fn test_time_estimation() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let context = HashMap::new();
        let simulator = WorkflowSimulator::new().estimate_times(true);
        let result = simulator.simulate(&workflow, context).unwrap();

        assert!(result.time_estimate.is_some());
    }

    // A 1-step budget cannot cover a 3-node workflow, so simulate() errors.
    #[test]
    fn test_max_steps_limit() {
        let config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: Some(0.7),
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let workflow = WorkflowBuilder::new("Test")
            .start("start")
            .llm("gen", config)
            .end("end")
            .build();

        let context = HashMap::new();
        let simulator = WorkflowSimulator::new().max_steps(1);
        let result = simulator.simulate(&workflow, context);

        assert!(result.is_err());
    }

    // Builder setters store their values as configured.
    #[test]
    fn test_simulator_builder_pattern() {
        let simulator = WorkflowSimulator::new()
            .simulate_latencies(false)
            .estimate_costs(true)
            .estimate_times(true)
            .max_steps(5000)
            .with_seed(42);

        assert!(!simulator.simulate_latencies);
        assert!(simulator.estimate_costs);
        assert!(simulator.estimate_times);
        assert_eq!(simulator.max_steps, 5000);
        assert_eq!(simulator.seed, Some(42));
    }
}