1use super::engine::{RhaiScriptEngine, ScriptContext, ScriptEngineConfig, ScriptResult};
10use anyhow::{Result, anyhow};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum ScriptNodeType {
23 Task,
25 Condition,
27 Transform,
29 Validator,
31 Aggregator,
33 LoopCondition,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ScriptNodeConfig {
40 pub id: String,
42 pub name: String,
44 pub node_type: ScriptNodeType,
46 pub script_source: Option<String>,
48 pub script_path: Option<String>,
50 pub entry_function: Option<String>,
52 pub enable_cache: bool,
54 pub timeout_ms: u64,
56 pub max_retries: u32,
58 pub metadata: HashMap<String, String>,
60}
61
62impl Default for ScriptNodeConfig {
63 fn default() -> Self {
64 Self {
65 id: String::new(),
66 name: String::new(),
67 node_type: ScriptNodeType::Task,
68 script_source: None,
69 script_path: None,
70 entry_function: None,
71 enable_cache: true,
72 timeout_ms: 30000,
73 max_retries: 0,
74 metadata: HashMap::new(),
75 }
76 }
77}
78
79impl ScriptNodeConfig {
80 pub fn new(id: &str, name: &str, node_type: ScriptNodeType) -> Self {
81 Self {
82 id: id.to_string(),
83 name: name.to_string(),
84 node_type,
85 ..Default::default()
86 }
87 }
88
89 pub fn with_source(mut self, source: &str) -> Self {
90 self.script_source = Some(source.to_string());
91 self
92 }
93
94 pub fn with_path(mut self, path: &str) -> Self {
95 self.script_path = Some(path.to_string());
96 self
97 }
98
99 pub fn with_entry(mut self, function: &str) -> Self {
100 self.entry_function = Some(function.to_string());
101 self
102 }
103
104 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
105 self.timeout_ms = timeout_ms;
106 self
107 }
108}
109
/// Outcome of running a [`ScriptWorkflowNode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptNodeResult {
    /// Id of the node that produced this result.
    pub node_id: String,
    /// Whether one of the attempts succeeded.
    pub success: bool,
    /// Script output on success; `Null` when every attempt failed.
    pub output: serde_json::Value,
    /// Error message of the last failed attempt; `None` on success.
    pub error: Option<String>,
    /// Wall-clock time across all attempts (including backoff sleeps), in milliseconds.
    pub execution_time_ms: u64,
    /// Zero-based index of the successful attempt, or the number of retries made on failure.
    pub retry_count: u32,
    /// Log lines reported by the script engine.
    pub logs: Vec<String>,
}
132
/// A workflow node that runs its configured script on a shared script engine.
pub struct ScriptWorkflowNode {
    /// Static configuration of this node.
    config: ScriptNodeConfig,
    /// Engine used to compile and run the node's script (shared with other nodes).
    engine: Arc<RhaiScriptEngine>,
    /// Id of the compiled script cached in `engine`, when it was compiled up
    /// front; `None` means the source is loaded and evaluated on every run.
    cached_script_id: Option<String>,
}
142
143impl ScriptWorkflowNode {
144 pub async fn new(config: ScriptNodeConfig, engine: Arc<RhaiScriptEngine>) -> Result<Self> {
146 let mut node = Self {
147 config,
148 engine,
149 cached_script_id: None,
150 };
151
152 if node.config.enable_cache {
154 node.compile_script().await?;
155 }
156
157 Ok(node)
158 }
159
160 async fn compile_script(&mut self) -> Result<()> {
162 let source = self.get_script_source().await?;
163 let script_id = format!("node_{}", self.config.id);
164
165 self.engine
166 .compile_and_cache(&script_id, &self.config.name, &source)
167 .await?;
168
169 self.cached_script_id = Some(script_id);
170 Ok(())
171 }
172
173 async fn get_script_source(&self) -> Result<String> {
175 if let Some(ref source) = self.config.script_source {
176 Ok(source.clone())
177 } else if let Some(ref path) = self.config.script_path {
178 tokio::fs::read_to_string(path)
179 .await
180 .map_err(|e| anyhow!("Failed to read script file: {}", e))
181 } else {
182 Err(anyhow!("No script source or path specified"))
183 }
184 }
185
186 pub async fn execute(&self, input: serde_json::Value) -> Result<ScriptNodeResult> {
188 let start_time = std::time::Instant::now();
189 let mut last_error = None;
190 let mut retry_count = 0;
191
192 let mut context = ScriptContext::new()
194 .with_node(&self.config.id)
195 .with_variable("input", input.clone())?;
196
197 for (k, v) in &self.config.metadata {
199 context.metadata.insert(k.clone(), v.clone());
200 }
201
202 while retry_count <= self.config.max_retries {
204 let result = self.execute_once(&context).await;
205
206 match result {
207 Ok(script_result) if script_result.success => {
208 return Ok(ScriptNodeResult {
209 node_id: self.config.id.clone(),
210 success: true,
211 output: script_result.value,
212 error: None,
213 execution_time_ms: start_time.elapsed().as_millis() as u64,
214 retry_count,
215 logs: script_result.logs,
216 });
217 }
218 Ok(script_result) => {
219 last_error = script_result.error;
220 }
221 Err(e) => {
222 last_error = Some(e.to_string());
223 }
224 }
225
226 if retry_count < self.config.max_retries {
227 let delay = std::time::Duration::from_millis(100 * 2u64.pow(retry_count));
229 tokio::time::sleep(delay).await;
230 }
231 retry_count += 1;
232 }
233
234 Ok(ScriptNodeResult {
235 node_id: self.config.id.clone(),
236 success: false,
237 output: serde_json::Value::Null,
238 error: last_error,
239 execution_time_ms: start_time.elapsed().as_millis() as u64,
240 retry_count: retry_count.saturating_sub(1),
241 logs: Vec::new(),
242 })
243 }
244
245 async fn execute_once(&self, context: &ScriptContext) -> Result<ScriptResult> {
247 if let Some(ref script_id) = self.cached_script_id {
249 if let Some(ref entry) = self.config.entry_function {
251 let input = context
252 .get_variable::<serde_json::Value>("input")
253 .unwrap_or(serde_json::Value::Null);
254
255 let result: serde_json::Value = self
256 .engine
257 .call_function(script_id, entry, vec![input], context)
258 .await?;
259
260 Ok(ScriptResult::success(result, 0))
261 } else {
262 self.engine.execute_compiled(script_id, context).await
263 }
264 } else {
265 let source = self.get_script_source().await?;
266 self.engine.execute(&source, context).await
267 }
268 }
269
270 pub async fn execute_as_condition(&self, input: serde_json::Value) -> Result<bool> {
272 let result = self.execute(input).await?;
273
274 if !result.success {
275 return Err(anyhow!(
276 result
277 .error
278 .unwrap_or_else(|| "Condition execution failed".into())
279 ));
280 }
281
282 match &result.output {
284 serde_json::Value::Bool(b) => Ok(*b),
285 serde_json::Value::Number(n) => Ok(n.as_i64().unwrap_or(0) != 0),
286 serde_json::Value::String(s) => Ok(!s.is_empty() && s != "false" && s != "0"),
287 serde_json::Value::Array(arr) => Ok(!arr.is_empty()),
288 serde_json::Value::Object(obj) => Ok(!obj.is_empty()),
289 serde_json::Value::Null => Ok(false),
290 }
291 }
292
293 pub fn config(&self) -> &ScriptNodeConfig {
295 &self.config
296 }
297
298 pub fn id(&self) -> &str {
300 &self.config.id
301 }
302
303 pub fn name(&self) -> &str {
305 &self.config.name
306 }
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct ScriptWorkflowDefinition {
316 pub id: String,
318 pub name: String,
320 pub description: String,
322 pub nodes: Vec<ScriptNodeConfig>,
324 pub edges: Vec<(String, String, Option<String>)>,
326 pub start_node: String,
328 pub end_nodes: Vec<String>,
330 pub global_variables: HashMap<String, serde_json::Value>,
332}
333
334impl ScriptWorkflowDefinition {
335 pub fn new(id: &str, name: &str) -> Self {
336 Self {
337 id: id.to_string(),
338 name: name.to_string(),
339 description: String::new(),
340 nodes: Vec::new(),
341 edges: Vec::new(),
342 start_node: String::new(),
343 end_nodes: Vec::new(),
344 global_variables: HashMap::new(),
345 }
346 }
347
348 pub async fn from_yaml(path: &str) -> Result<Self> {
350 let content = tokio::fs::read_to_string(path).await?;
351 serde_yaml::from_str(&content).map_err(|e| anyhow!("Failed to parse YAML: {}", e))
352 }
353
354 pub async fn from_json(path: &str) -> Result<Self> {
356 let content = tokio::fs::read_to_string(path).await?;
357 serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse JSON: {}", e))
358 }
359
360 pub fn add_node(&mut self, config: ScriptNodeConfig) -> &mut Self {
362 self.nodes.push(config);
363 self
364 }
365
366 pub fn add_edge(&mut self, from: &str, to: &str) -> &mut Self {
368 self.edges.push((from.to_string(), to.to_string(), None));
369 self
370 }
371
372 pub fn add_conditional_edge(&mut self, from: &str, to: &str, condition: &str) -> &mut Self {
374 self.edges.push((
375 from.to_string(),
376 to.to_string(),
377 Some(condition.to_string()),
378 ));
379 self
380 }
381
382 pub fn set_start(&mut self, node_id: &str) -> &mut Self {
384 self.start_node = node_id.to_string();
385 self
386 }
387
388 pub fn add_end(&mut self, node_id: &str) -> &mut Self {
390 self.end_nodes.push(node_id.to_string());
391 self
392 }
393
394 pub fn validate(&self) -> Result<Vec<String>> {
396 let mut errors = Vec::new();
397
398 if self.id.is_empty() {
399 errors.push("Workflow ID is required".to_string());
400 }
401
402 if self.start_node.is_empty() {
403 errors.push("Start node is not specified".to_string());
404 }
405
406 if self.end_nodes.is_empty() {
407 errors.push("At least one end node is required".to_string());
408 }
409
410 let node_ids: std::collections::HashSet<_> = self.nodes.iter().map(|n| &n.id).collect();
412
413 if !node_ids.contains(&self.start_node) {
414 errors.push(format!("Start node '{}' not found", self.start_node));
415 }
416
417 for end_node in &self.end_nodes {
418 if !node_ids.contains(end_node) {
419 errors.push(format!("End node '{}' not found", end_node));
420 }
421 }
422
423 for (from, to, _) in &self.edges {
424 if !node_ids.contains(from) {
425 errors.push(format!("Edge source node '{}' not found", from));
426 }
427 if !node_ids.contains(to) {
428 errors.push(format!("Edge target node '{}' not found", to));
429 }
430 }
431
432 Ok(errors)
433 }
434}
435
/// Runs a [`ScriptWorkflowDefinition`], starting at its start node and
/// following edges chosen from each node's output towards an end node.
pub struct ScriptWorkflowExecutor {
    /// Engine every node was created with.
    // Not read after construction: each node holds its own `Arc` clone of it.
    // Presumably kept so the executor owns the engine for its whole lifetime.
    #[allow(dead_code)]
    engine: Arc<RhaiScriptEngine>,
    /// Nodes keyed by node id.
    nodes: HashMap<String, ScriptWorkflowNode>,
    /// The workflow being executed.
    definition: ScriptWorkflowDefinition,
    /// State of the current or most recent run, behind an async read/write lock.
    state: Arc<RwLock<WorkflowExecutionState>>,
}
452
/// Progress and results of a workflow run, as tracked by the executor.
#[derive(Debug, Clone, Default)]
pub struct WorkflowExecutionState {
    /// Node currently being (or last) executed; `None` when routing found no next node.
    pub current_node: Option<String>,
    /// Output of each successfully executed node, keyed by node id.
    pub node_outputs: HashMap<String, serde_json::Value>,
    /// Workflow variables: the definition's global variables plus the run's `input`.
    pub variables: HashMap<String, serde_json::Value>,
    /// Node ids appended as the executor visits them (only cleared by `reset`).
    pub execution_history: Vec<String>,
    /// Set once an end node has finished successfully.
    pub completed: bool,
    /// Output of the end node once the workflow has completed.
    pub final_result: Option<serde_json::Value>,
    /// Error message recorded when a run fails.
    pub error: Option<String>,
}
471
472impl ScriptWorkflowExecutor {
473 pub async fn new(
475 definition: ScriptWorkflowDefinition,
476 engine_config: ScriptEngineConfig,
477 ) -> Result<Self> {
478 let engine = Arc::new(RhaiScriptEngine::new(engine_config)?);
479 let mut nodes = HashMap::new();
480
481 for node_config in &definition.nodes {
483 let node = ScriptWorkflowNode::new(node_config.clone(), engine.clone()).await?;
484 nodes.insert(node_config.id.clone(), node);
485 }
486
487 let mut state = WorkflowExecutionState::default();
489 state.variables = definition.global_variables.clone();
490
491 Ok(Self {
492 engine,
493 nodes,
494 definition,
495 state: Arc::new(RwLock::new(state)),
496 })
497 }
498
499 pub async fn execute(&self, input: serde_json::Value) -> Result<serde_json::Value> {
501 let mut state = self.state.write().await;
502 state.current_node = Some(self.definition.start_node.clone());
503 state.variables.insert("input".to_string(), input.clone());
504
505 let mut current_value = input;
506
507 while let Some(ref node_id) = state.current_node.clone() {
508 let node = self
510 .nodes
511 .get(node_id)
512 .ok_or_else(|| anyhow!("Node not found: {}", node_id))?;
513
514 if self.definition.end_nodes.contains(node_id) {
516 let result = node.execute(current_value.clone()).await?;
518
519 if !result.success {
520 state.error = result.error;
521 return Err(anyhow!("Node {} execution failed", node_id));
522 }
523
524 state
526 .node_outputs
527 .insert(node_id.clone(), result.output.clone());
528
529 state.completed = true;
530 state.final_result = Some(result.output.clone());
531 break;
532 }
533
534 state.execution_history.push(node_id.clone());
536
537 let result = node.execute(current_value.clone()).await?;
539
540 if !result.success {
541 let error = result.error.clone(); state.error = error.clone();
543 let error_detail = error.unwrap_or_else(|| "unknown error".to_string());
544 return Err(anyhow!(
545 "Node {} execution failed: {}",
546 node_id,
547 error_detail
548 ));
549 }
550
551 state
553 .node_outputs
554 .insert(node_id.clone(), result.output.clone());
555 current_value = result.output;
556
557 let next_node = self.determine_next_node(node_id, ¤t_value).await?;
559 state.current_node = next_node;
560 }
561
562 Ok(state
563 .final_result
564 .clone()
565 .unwrap_or(serde_json::Value::Null))
566 }
567
568 async fn determine_next_node(
570 &self,
571 current_node_id: &str,
572 output: &serde_json::Value,
573 ) -> Result<Option<String>> {
574 let candidate_edges: Vec<_> = self
576 .definition
577 .edges
578 .iter()
579 .filter(|(from, _, _)| from == current_node_id)
580 .collect();
581
582 if candidate_edges.is_empty() {
583 return Ok(None);
584 }
585
586 if candidate_edges.len() == 1 && candidate_edges[0].2.is_none() {
588 return Ok(Some(candidate_edges[0].1.clone()));
589 }
590
591 for (_, to, condition) in &candidate_edges {
593 if let Some(cond) = condition {
594 let condition_value = {
597 if cond.contains("==") {
599 let parts: Vec<_> = cond
600 .split("==")
601 .map(|s| s.trim().replace("\"", ""))
602 .collect();
603 if parts.len() == 2 {
604 let field = parts[0].clone();
605 let value = parts[1].clone();
606
607 match output {
609 serde_json::Value::Object(obj) => {
610 if let Some(serde_json::Value::String(v)) = obj.get(&field) {
611 *v == value
612 } else if let Some(serde_json::Value::Number(n)) =
613 obj.get(&field)
614 {
615 n.to_string() == value
616 } else {
617 false
618 }
619 }
620 _ => false,
621 }
622 } else {
623 match output {
625 serde_json::Value::String(s) => s == cond,
626 serde_json::Value::Bool(b) => {
627 (*b && cond == "true") || (!*b && cond == "false")
628 }
629 _ => false,
630 }
631 }
632 } else {
633 match output {
635 serde_json::Value::String(s) => s == cond,
636 serde_json::Value::Bool(b) => {
637 (*b && cond == "true") || (!*b && cond == "false")
638 }
639 _ => false,
640 }
641 }
642 };
643
644 if condition_value {
645 return Ok(Some(to.clone()));
646 }
647 }
648 }
649
650 for (_, to, condition) in &candidate_edges {
652 if condition.is_none() {
653 return Ok(Some(to.clone()));
654 }
655 }
656
657 Ok(None)
658 }
659
660 pub async fn state(&self) -> WorkflowExecutionState {
662 self.state.read().await.clone()
663 }
664
665 pub async fn reset(&self) {
667 let mut state = self.state.write().await;
668 *state = WorkflowExecutionState::default();
669 state.variables = self.definition.global_variables.clone();
670 }
671}
672
673pub fn task_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
679 ScriptNodeConfig::new(id, name, ScriptNodeType::Task).with_source(script)
680}
681
682pub fn condition_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
684 ScriptNodeConfig::new(id, name, ScriptNodeType::Condition).with_source(script)
685}
686
687pub fn transform_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
689 ScriptNodeConfig::new(id, name, ScriptNodeType::Transform).with_source(script)
690}
691
692pub fn validator_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
694 ScriptNodeConfig::new(id, name, ScriptNodeType::Validator).with_source(script)
695}
696
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A script engine with the default configuration.
    fn default_engine() -> Arc<RhaiScriptEngine> {
        let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
        Arc::new(engine)
    }

    /// An executor for `workflow` using the default engine configuration.
    async fn executor_for(workflow: ScriptWorkflowDefinition) -> ScriptWorkflowExecutor {
        ScriptWorkflowExecutor::new(workflow, ScriptEngineConfig::default())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_script_node_execution() {
        let doubling = task_script(
            "double_node",
            "Double Value",
            r#"
            let result = input * 2;
            result
        "#,
        );
        let node = ScriptWorkflowNode::new(doubling, default_engine())
            .await
            .unwrap();

        let outcome = node.execute(json!(21)).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.output, json!(42));
    }

    #[tokio::test]
    async fn test_condition_node() {
        let is_positive = condition_script("check_positive", "Check Positive", "input > 0");
        let node = ScriptWorkflowNode::new(is_positive, default_engine())
            .await
            .unwrap();

        let on_positive = node.execute_as_condition(json!(10)).await.unwrap();
        assert!(on_positive);

        let on_negative = node.execute_as_condition(json!(-5)).await.unwrap();
        assert!(!on_negative);
    }

    #[tokio::test]
    async fn test_workflow_definition() {
        let mut definition = ScriptWorkflowDefinition::new("test_wf", "Test Workflow");
        definition
            .add_node(task_script("start", "Start", "input"))
            .add_node(task_script("process", "Process", "input * 2"))
            .add_node(task_script("end", "End", "input"))
            .add_edge("start", "process")
            .add_edge("process", "end")
            .set_start("start")
            .add_end("end");

        let problems = definition.validate().unwrap();
        assert!(problems.is_empty(), "Validation errors: {:?}", problems);
    }

    #[tokio::test]
    async fn test_simple_workflow_execution() {
        let mut definition = ScriptWorkflowDefinition::new("calc_wf", "Calculator Workflow");
        definition
            .add_node(task_script("double", "Double", "input * 2"))
            .add_node(task_script("add_ten", "Add Ten", "input + 10"))
            .add_node(task_script("done", "Done", "input"))
            .add_edge("double", "add_ten")
            .add_edge("add_ten", "done")
            .set_start("double")
            .add_end("done");

        let executor = executor_for(definition).await;

        // (5 * 2) + 10
        let outcome = executor.execute(json!(5)).await.unwrap();
        assert_eq!(outcome, json!(20));
    }

    #[tokio::test]
    async fn test_conditional_workflow() {
        let mut definition = ScriptWorkflowDefinition::new("cond_wf", "Conditional Workflow");
        definition
            .add_node(condition_script(
                "check",
                "Check Value",
                r#"if input > 10 { "high" } else { "low" }"#,
            ))
            .add_node(task_script("high_path", "High Path", r#""HIGH: " + input"#))
            .add_node(task_script("low_path", "Low Path", r#""LOW: " + input"#))
            .add_node(task_script("end", "End", "input"))
            .add_conditional_edge("check", "high_path", "high")
            .add_conditional_edge("check", "low_path", "low")
            .add_edge("high_path", "end")
            .add_edge("low_path", "end")
            .set_start("check")
            .add_end("end");

        let executor = executor_for(definition).await;

        let high = executor.execute(json!(20)).await.unwrap();
        assert!(high.as_str().unwrap().starts_with("HIGH:"));

        executor.reset().await;

        let low = executor.execute(json!(5)).await.unwrap();
        assert!(low.as_str().unwrap().starts_with("LOW:"));
    }
}