use super::engine::{RhaiScriptEngine, ScriptContext, ScriptEngineConfig, ScriptResult};
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Label describing the role a script-backed node plays in a workflow.
///
/// The variant is stored on `ScriptNodeConfig::node_type`. Nothing in this file
/// changes execution behavior based on the variant, so the per-variant notes
/// below describe intent inferred from the names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScriptNodeType {
/// Generic processing step (the kind produced by `task_script`).
Task,
/// Node whose script result is meant to drive branching (the kind produced by `condition_script`).
Condition,
/// Node presumably mapping its input to a new value (the kind produced by `transform_script`).
Transform,
/// Node presumably checking its input (the kind produced by `validator_script`).
Validator,
/// Presumably combines several values into one; no helper in this file builds it — TODO confirm intended use.
Aggregator,
/// Presumably decides whether a loop continues; no helper in this file builds it — TODO confirm intended use.
LoopCondition,
}
/// Serializable description of a single script-backed workflow node.
///
/// At least one of `script_source` or `script_path` must be set for the node's
/// script to be loadable; `script_source` wins when both are present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptNodeConfig {
/// Node identifier. Used as the key for the node inside a workflow and to
/// derive the cached script id (`node_<id>`).
pub id: String,
/// Human-readable name; passed to the engine when the script is compiled and cached.
pub name: String,
/// Role label for the node; not consulted by the execution code in this file.
pub node_type: ScriptNodeType,
/// Inline script text. Takes precedence over `script_path`.
pub script_source: Option<String>,
/// Path of a script file, read asynchronously when no inline source is set.
pub script_path: Option<String>,
/// Name of a function to call inside the compiled script, with the node
/// input as its only argument. Only consulted when the script has been
/// compiled and cached (see `enable_cache`); otherwise the whole script runs.
pub entry_function: Option<String>,
/// When true the script is compiled and cached while the node is constructed.
pub enable_cache: bool,
/// Timeout in milliseconds (default 30000).
/// NOTE(review): not read by any code in this file — confirm where, or whether, it is enforced.
pub timeout_ms: u64,
/// Number of additional attempts after the first failed one. Attempts are
/// separated by a delay that starts at 100 ms and doubles each time.
pub max_retries: u32,
/// Free-form key/value pairs copied into the script context's metadata before execution.
pub metadata: HashMap<String, String>,
}
impl Default for ScriptNodeConfig {
fn default() -> Self {
Self {
id: String::new(),
name: String::new(),
node_type: ScriptNodeType::Task,
script_source: None,
script_path: None,
entry_function: None,
enable_cache: true,
timeout_ms: 30000,
max_retries: 0,
metadata: HashMap::new(),
}
}
}
impl ScriptNodeConfig {
    /// Starts a configuration with the given identity and node type; every
    /// other field takes its `Default` value.
    pub fn new(id: &str, name: &str, node_type: ScriptNodeType) -> Self {
        let defaults = Self::default();
        Self {
            id: id.to_owned(),
            name: name.to_owned(),
            node_type,
            ..defaults
        }
    }

    /// Attaches inline script text.
    pub fn with_source(self, source: &str) -> Self {
        Self {
            script_source: Some(source.to_owned()),
            ..self
        }
    }

    /// Attaches the path of a script file.
    pub fn with_path(self, path: &str) -> Self {
        Self {
            script_path: Some(path.to_owned()),
            ..self
        }
    }

    /// Selects the function to call inside the script.
    pub fn with_entry(self, function: &str) -> Self {
        Self {
            entry_function: Some(function.to_owned()),
            ..self
        }
    }

    /// Overrides the timeout, in milliseconds.
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self { timeout_ms, ..self }
    }
}
/// Outcome of running one workflow node, including retry bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptNodeResult {
/// Identifier of the node that produced this result.
pub node_id: String,
/// True when an attempt finished successfully.
pub success: bool,
/// Value produced by the script on success; `Null` when every attempt failed.
pub output: serde_json::Value,
/// Error text from the last failed attempt; `None` on success.
pub error: Option<String>,
/// Wall-clock time from the start of execution until the result was built,
/// in milliseconds. Covers all attempts and the delays between them.
pub execution_time_ms: u64,
/// Number of retries performed after the first attempt.
pub retry_count: u32,
/// Log lines taken from the engine's script result, when available.
pub logs: Vec<String>,
}
/// A runnable workflow node: a configuration bound to a shared script engine.
pub struct ScriptWorkflowNode {
/// Static description of the node (script location, retries, metadata, ...).
config: ScriptNodeConfig,
/// Engine used to compile and run the script; shared between nodes via `Arc`.
engine: Arc<RhaiScriptEngine>,
/// Id under which the script was compiled and cached in the engine, or
/// `None` when the script has not been pre-compiled.
cached_script_id: Option<String>,
}
impl ScriptWorkflowNode {
    /// Builds a node from `config`, running its script on the shared `engine`.
    ///
    /// When `config.enable_cache` is set, the script is loaded and compiled
    /// immediately so that later executions reuse the cached compilation.
    ///
    /// # Errors
    /// Fails when caching is enabled and the script cannot be loaded or the
    /// engine rejects it during compilation.
    pub async fn new(config: ScriptNodeConfig, engine: Arc<RhaiScriptEngine>) -> Result<Self> {
        let mut node = Self {
            config,
            engine,
            cached_script_id: None,
        };
        if node.config.enable_cache {
            node.compile_script().await?;
        }
        Ok(node)
    }

    /// Compiles the node's script and remembers the id (`node_<id>`) under
    /// which the engine cached it.
    async fn compile_script(&mut self) -> Result<()> {
        let source = self.get_script_source().await?;
        let script_id = format!("node_{}", self.config.id);
        self.engine
            .compile_and_cache(&script_id, &self.config.name, &source)
            .await?;
        self.cached_script_id = Some(script_id);
        Ok(())
    }

    /// Returns the script text: the inline source when present, otherwise the
    /// contents of the configured file.
    ///
    /// # Errors
    /// Fails when the file cannot be read or when neither an inline source nor
    /// a path is configured.
    async fn get_script_source(&self) -> Result<String> {
        if let Some(source) = &self.config.script_source {
            return Ok(source.clone());
        }
        match &self.config.script_path {
            Some(path) => tokio::fs::read_to_string(path)
                .await
                .map_err(|e| anyhow!("Failed to read script file: {}", e)),
            None => Err(anyhow!("No script source or path specified")),
        }
    }

    /// Milliseconds elapsed since `start`, saturating instead of truncating.
    fn elapsed_ms(start: std::time::Instant) -> u64 {
        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Runs the node with `input`, retrying failed attempts.
    ///
    /// The script sees `input` as the context variable `"input"`, and the
    /// node's metadata is copied into the context. Up to `max_retries`
    /// additional attempts are made after the first one; attempts are
    /// separated by a delay of `100 ms * 2^attempt`.
    ///
    /// A script failure is not an `Err`: it yields `Ok` with `success == false`,
    /// the last error message and the logs of the last failed attempt (when the
    /// engine returned a result for it).
    ///
    /// # Errors
    /// Fails only when the script context cannot be built from `input`.
    pub async fn execute(&self, input: serde_json::Value) -> Result<ScriptNodeResult> {
        let start_time = std::time::Instant::now();
        let mut context = ScriptContext::new()
            .with_node(&self.config.id)
            .with_variable("input", input)?;
        for (k, v) in &self.config.metadata {
            context.metadata.insert(k.clone(), v.clone());
        }

        let max_retries = self.config.max_retries;
        let mut last_error: Option<String> = None;
        let mut last_logs: Vec<String> = Vec::new();

        // An inclusive range terminates correctly even for `u32::MAX`, where a
        // hand-maintained counter would overflow on its final increment.
        for attempt in 0..=max_retries {
            match self.execute_once(&context).await {
                Ok(script_result) if script_result.success => {
                    return Ok(ScriptNodeResult {
                        node_id: self.config.id.clone(),
                        success: true,
                        output: script_result.value,
                        error: None,
                        execution_time_ms: Self::elapsed_ms(start_time),
                        retry_count: attempt,
                        logs: script_result.logs,
                    });
                }
                Ok(script_result) => {
                    last_error = script_result.error;
                    // Keep the logs: they are usually the best clue to the failure.
                    last_logs = script_result.logs;
                }
                Err(e) => {
                    last_error = Some(e.to_string());
                    last_logs = Vec::new();
                }
            }

            if attempt < max_retries {
                // Exponential backoff; saturate rather than overflow for large
                // attempt numbers.
                let delay_ms = 2u64
                    .checked_pow(attempt)
                    .map_or(u64::MAX, |factor| factor.saturating_mul(100));
                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
            }
        }

        Ok(ScriptNodeResult {
            node_id: self.config.id.clone(),
            success: false,
            output: serde_json::Value::Null,
            error: last_error,
            execution_time_ms: Self::elapsed_ms(start_time),
            retry_count: max_retries,
            logs: last_logs,
        })
    }

    /// Performs a single attempt.
    ///
    /// With a cached compilation, either the configured entry function is
    /// called with the `"input"` variable as its only argument, or the whole
    /// compiled script is run. Without a cached compilation the source is
    /// loaded and executed directly; the entry function is not consulted on
    /// that path.
    async fn execute_once(&self, context: &ScriptContext) -> Result<ScriptResult> {
        let Some(script_id) = &self.cached_script_id else {
            let source = self.get_script_source().await?;
            return self.engine.execute(&source, context).await;
        };

        match &self.config.entry_function {
            Some(entry) => {
                let input = context
                    .get_variable::<serde_json::Value>("input")
                    .unwrap_or(serde_json::Value::Null);
                let result: serde_json::Value = self
                    .engine
                    .call_function(script_id, entry, vec![input], context)
                    .await?;
                // NOTE(review): the second argument is presumably the execution
                // time, which is not measured on this path — confirm.
                Ok(ScriptResult::success(result, 0))
            }
            None => self.engine.execute_compiled(script_id, context).await,
        }
    }

    /// Runs the node and interprets its output as a boolean.
    ///
    /// Truthiness rules:
    /// * booleans are used as-is;
    /// * numbers are true when non-zero (this includes fractional values);
    /// * strings are true unless empty, `"false"` or `"0"`;
    /// * arrays and objects are true when non-empty;
    /// * `null` is false.
    ///
    /// # Errors
    /// Fails when execution fails, carrying the node's error message (or
    /// `"Condition execution failed"` when none is available).
    pub async fn execute_as_condition(&self, input: serde_json::Value) -> Result<bool> {
        let result = self.execute(input).await?;
        if !result.success {
            return Err(anyhow!(
                result
                    .error
                    .unwrap_or_else(|| "Condition execution failed".into())
            ));
        }
        let truthy = match &result.output {
            serde_json::Value::Bool(b) => *b,
            // Compare as f64: `as_i64` yields `None` for fractional values and
            // for integers above `i64::MAX`, which would make them all false.
            serde_json::Value::Number(n) => n.as_f64().is_some_and(|f| f != 0.0),
            serde_json::Value::String(s) => !s.is_empty() && s != "false" && s != "0",
            serde_json::Value::Array(arr) => !arr.is_empty(),
            serde_json::Value::Object(obj) => !obj.is_empty(),
            serde_json::Value::Null => false,
        };
        Ok(truthy)
    }

    /// The configuration this node was built from.
    pub fn config(&self) -> &ScriptNodeConfig {
        &self.config
    }

    /// The node's identifier.
    pub fn id(&self) -> &str {
        &self.config.id
    }

    /// The node's human-readable name.
    pub fn name(&self) -> &str {
        &self.config.name
    }
}
/// Serializable description of a workflow: its nodes, the edges between them
/// and the entry/exit points.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptWorkflowDefinition {
/// Workflow identifier; must be non-empty to pass `validate`.
pub id: String,
/// Human-readable workflow name.
pub name: String,
/// Free-form description; empty for definitions created with `new`.
pub description: String,
/// Configurations of all nodes in the workflow.
pub nodes: Vec<ScriptNodeConfig>,
/// Directed edges as `(from node id, to node id, optional condition)`.
/// An edge without a condition is unconditional.
pub edges: Vec<(String, String, Option<String>)>,
/// Id of the node where execution begins.
pub start_node: String,
/// Ids of the nodes at which execution finishes.
pub end_nodes: Vec<String>,
/// Variables used to seed the executor's state before a run.
pub global_variables: HashMap<String, serde_json::Value>,
}
impl ScriptWorkflowDefinition {
    /// Creates an empty workflow with the given id and name. Nodes, edges and
    /// the start/end nodes are added afterwards with the builder-style methods.
    pub fn new(id: &str, name: &str) -> Self {
        Self {
            id: id.to_string(),
            name: name.to_string(),
            description: String::new(),
            nodes: Vec::new(),
            edges: Vec::new(),
            start_node: String::new(),
            end_nodes: Vec::new(),
            global_variables: HashMap::new(),
        }
    }

    /// Loads a definition from a YAML file.
    ///
    /// # Errors
    /// Fails when the file cannot be read or its content is not a valid definition.
    pub async fn from_yaml(path: &str) -> Result<Self> {
        let content = tokio::fs::read_to_string(path).await?;
        serde_yaml::from_str(&content).map_err(|e| anyhow!("Failed to parse YAML: {}", e))
    }

    /// Loads a definition from a JSON file.
    ///
    /// # Errors
    /// Fails when the file cannot be read or its content is not a valid definition.
    pub async fn from_json(path: &str) -> Result<Self> {
        let content = tokio::fs::read_to_string(path).await?;
        serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse JSON: {}", e))
    }

    /// Appends a node configuration.
    pub fn add_node(&mut self, config: ScriptNodeConfig) -> &mut Self {
        self.nodes.push(config);
        self
    }

    /// Adds an unconditional edge from `from` to `to`.
    pub fn add_edge(&mut self, from: &str, to: &str) -> &mut Self {
        self.edges.push((from.to_string(), to.to_string(), None));
        self
    }

    /// Adds an edge from `from` to `to` that carries the given condition text.
    pub fn add_conditional_edge(&mut self, from: &str, to: &str, condition: &str) -> &mut Self {
        self.edges.push((
            from.to_string(),
            to.to_string(),
            Some(condition.to_string()),
        ));
        self
    }

    /// Sets the node where execution begins.
    pub fn set_start(&mut self, node_id: &str) -> &mut Self {
        self.start_node = node_id.to_string();
        self
    }

    /// Marks a node as an end node.
    pub fn add_end(&mut self, node_id: &str) -> &mut Self {
        self.end_nodes.push(node_id.to_string());
        self
    }

    /// Checks the definition for structural problems.
    ///
    /// Returns the list of problems found; an empty list means the definition
    /// is structurally sound. The checks are:
    /// * the workflow id, the start node and at least one end node are set;
    /// * node ids are unique;
    /// * the start node, every end node and both ends of every edge refer to
    ///   an existing node.
    ///
    /// The `Result` wrapper is kept for interface compatibility; this function
    /// currently always returns `Ok`.
    pub fn validate(&self) -> Result<Vec<String>> {
        let mut errors = Vec::new();
        if self.id.is_empty() {
            errors.push("Workflow ID is required".to_string());
        }
        if self.start_node.is_empty() {
            errors.push("Start node is not specified".to_string());
        }
        if self.end_nodes.is_empty() {
            errors.push("At least one end node is required".to_string());
        }

        // Duplicate ids matter because nodes are later keyed by id, so a second
        // node with the same id would silently replace the first.
        let mut node_ids: std::collections::HashSet<&str> =
            std::collections::HashSet::with_capacity(self.nodes.len());
        for node in &self.nodes {
            if !node_ids.insert(node.id.as_str()) {
                errors.push(format!("Duplicate node ID '{}'", node.id));
            }
        }

        // An unset start node is already reported above; skip the redundant
        // "not found" message for the empty id.
        if !self.start_node.is_empty() && !node_ids.contains(self.start_node.as_str()) {
            errors.push(format!("Start node '{}' not found", self.start_node));
        }
        for end_node in &self.end_nodes {
            if !node_ids.contains(end_node.as_str()) {
                errors.push(format!("End node '{}' not found", end_node));
            }
        }
        for (from, to, _) in &self.edges {
            if !node_ids.contains(from.as_str()) {
                errors.push(format!("Edge source node '{}' not found", from));
            }
            if !node_ids.contains(to.as_str()) {
                errors.push(format!("Edge target node '{}' not found", to));
            }
        }
        Ok(errors)
    }
}
/// Executes a `ScriptWorkflowDefinition` by running its nodes one after
/// another and following the edges between them.
pub struct ScriptWorkflowExecutor {
// The executor's own handle to the engine that every node shares. The field
// is never read after construction in this file, hence the lint allowance.
#[allow(dead_code)]
engine: Arc<RhaiScriptEngine>,
/// Runnable nodes keyed by node id.
nodes: HashMap<String, ScriptWorkflowNode>,
/// The workflow being executed (edges, start node, end nodes, globals).
definition: ScriptWorkflowDefinition,
/// Mutable run state behind an async read/write lock.
state: Arc<RwLock<WorkflowExecutionState>>,
}
/// Snapshot of a workflow run as tracked by `ScriptWorkflowExecutor`.
#[derive(Debug, Clone, Default)]
pub struct WorkflowExecutionState {
/// Id of the node being (or about to be) executed; `None` before a run
/// starts or when no next node could be determined.
pub current_node: Option<String>,
/// Output of each successfully executed node, keyed by node id.
pub node_outputs: HashMap<String, serde_json::Value>,
/// Workflow variables: seeded from the definition's global variables, plus
/// the run input stored under `"input"`.
pub variables: HashMap<String, serde_json::Value>,
/// Ids of the nodes that were started, in order. The executor in this file
/// does not record end nodes here.
pub execution_history: Vec<String>,
/// True once an end node has finished successfully.
pub completed: bool,
/// Output of the end node that completed the run.
pub final_result: Option<serde_json::Value>,
/// Error reported by a node whose execution failed.
pub error: Option<String>,
}
impl ScriptWorkflowExecutor {
pub async fn new(
definition: ScriptWorkflowDefinition,
engine_config: ScriptEngineConfig,
) -> Result<Self> {
let engine = Arc::new(RhaiScriptEngine::new(engine_config)?);
let mut nodes = HashMap::new();
for node_config in &definition.nodes {
let node = ScriptWorkflowNode::new(node_config.clone(), engine.clone()).await?;
nodes.insert(node_config.id.clone(), node);
}
let mut state = WorkflowExecutionState::default();
state.variables = definition.global_variables.clone();
Ok(Self {
engine,
nodes,
definition,
state: Arc::new(RwLock::new(state)),
})
}
pub async fn execute(&self, input: serde_json::Value) -> Result<serde_json::Value> {
let mut state = self.state.write().await;
state.current_node = Some(self.definition.start_node.clone());
state.variables.insert("input".to_string(), input.clone());
let mut current_value = input;
while let Some(ref node_id) = state.current_node.clone() {
let node = self
.nodes
.get(node_id)
.ok_or_else(|| anyhow!("Node not found: {}", node_id))?;
if self.definition.end_nodes.contains(node_id) {
let result = node.execute(current_value.clone()).await?;
if !result.success {
state.error = result.error;
return Err(anyhow!("Node {} execution failed", node_id));
}
state
.node_outputs
.insert(node_id.clone(), result.output.clone());
state.completed = true;
state.final_result = Some(result.output.clone());
break;
}
state.execution_history.push(node_id.clone());
let result = node.execute(current_value.clone()).await?;
if !result.success {
let error = result.error.clone(); state.error = error.clone();
let error_detail = error.unwrap_or_else(|| "unknown error".to_string());
return Err(anyhow!(
"Node {} execution failed: {}",
node_id,
error_detail
));
}
state
.node_outputs
.insert(node_id.clone(), result.output.clone());
current_value = result.output;
let next_node = self.determine_next_node(node_id, ¤t_value).await?;
state.current_node = next_node;
}
Ok(state
.final_result
.clone()
.unwrap_or(serde_json::Value::Null))
}
async fn determine_next_node(
&self,
current_node_id: &str,
output: &serde_json::Value,
) -> Result<Option<String>> {
let candidate_edges: Vec<_> = self
.definition
.edges
.iter()
.filter(|(from, _, _)| from == current_node_id)
.collect();
if candidate_edges.is_empty() {
return Ok(None);
}
if candidate_edges.len() == 1 && candidate_edges[0].2.is_none() {
return Ok(Some(candidate_edges[0].1.clone()));
}
for (_, to, condition) in &candidate_edges {
if let Some(cond) = condition {
let condition_value = {
if cond.contains("==") {
let parts: Vec<_> = cond
.split("==")
.map(|s| s.trim().replace("\"", ""))
.collect();
if parts.len() == 2 {
let field = parts[0].clone();
let value = parts[1].clone();
match output {
serde_json::Value::Object(obj) => {
if let Some(serde_json::Value::String(v)) = obj.get(&field) {
*v == value
} else if let Some(serde_json::Value::Number(n)) =
obj.get(&field)
{
n.to_string() == value
} else {
false
}
}
_ => false,
}
} else {
match output {
serde_json::Value::String(s) => s == cond,
serde_json::Value::Bool(b) => {
(*b && cond == "true") || (!*b && cond == "false")
}
_ => false,
}
}
} else {
match output {
serde_json::Value::String(s) => s == cond,
serde_json::Value::Bool(b) => {
(*b && cond == "true") || (!*b && cond == "false")
}
_ => false,
}
}
};
if condition_value {
return Ok(Some(to.clone()));
}
}
}
for (_, to, condition) in &candidate_edges {
if condition.is_none() {
return Ok(Some(to.clone()));
}
}
Ok(None)
}
pub async fn state(&self) -> WorkflowExecutionState {
self.state.read().await.clone()
}
pub async fn reset(&self) {
let mut state = self.state.write().await;
*state = WorkflowExecutionState::default();
state.variables = self.definition.global_variables.clone();
}
}
/// Shorthand for a `Task` node configuration with an inline script.
pub fn task_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
    let config = ScriptNodeConfig::new(id, name, ScriptNodeType::Task);
    config.with_source(script)
}
/// Shorthand for a `Condition` node configuration with an inline script.
pub fn condition_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
    let config = ScriptNodeConfig::new(id, name, ScriptNodeType::Condition);
    config.with_source(script)
}
/// Shorthand for a `Transform` node configuration with an inline script.
pub fn transform_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
    let config = ScriptNodeConfig::new(id, name, ScriptNodeType::Transform);
    config.with_source(script)
}
/// Shorthand for a `Validator` node configuration with an inline script.
pub fn validator_script(id: &str, name: &str, script: &str) -> ScriptNodeConfig {
    let config = ScriptNodeConfig::new(id, name, ScriptNodeType::Validator);
    config.with_source(script)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Script engine with the default configuration, wrapped for sharing.
    fn default_engine() -> Arc<RhaiScriptEngine> {
        let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
        Arc::new(engine)
    }

    /// Executor for `definition` running on a default-configured engine.
    async fn executor_for(definition: ScriptWorkflowDefinition) -> ScriptWorkflowExecutor {
        ScriptWorkflowExecutor::new(definition, ScriptEngineConfig::default())
            .await
            .unwrap()
    }

    /// A single task node doubles its numeric input.
    #[tokio::test]
    async fn test_script_node_execution() {
        let config = task_script(
            "double_node",
            "Double Value",
            r#"
let result = input * 2;
result
"#,
        );
        let node = ScriptWorkflowNode::new(config, default_engine())
            .await
            .unwrap();

        let outcome = node.execute(serde_json::json!(21)).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.output, serde_json::json!(42));
    }

    /// A condition node reports true for positive input and false otherwise.
    #[tokio::test]
    async fn test_condition_node() {
        let config = condition_script("check_positive", "Check Positive", "input > 0");
        let node = ScriptWorkflowNode::new(config, default_engine())
            .await
            .unwrap();

        let for_positive = node
            .execute_as_condition(serde_json::json!(10))
            .await
            .unwrap();
        assert!(for_positive);

        let for_negative = node
            .execute_as_condition(serde_json::json!(-5))
            .await
            .unwrap();
        assert!(!for_negative);
    }

    /// A well-formed linear workflow passes validation without errors.
    #[tokio::test]
    async fn test_workflow_definition() {
        let mut definition = ScriptWorkflowDefinition::new("test_wf", "Test Workflow");
        definition.add_node(task_script("start", "Start", "input"));
        definition.add_node(task_script("process", "Process", "input * 2"));
        definition.add_node(task_script("end", "End", "input"));
        definition.add_edge("start", "process");
        definition.add_edge("process", "end");
        definition.set_start("start");
        definition.add_end("end");

        let errors = definition.validate().unwrap();

        assert!(errors.is_empty(), "Validation errors: {:?}", errors);
    }

    /// Values flow through a linear workflow: (5 * 2) + 10 = 20.
    #[tokio::test]
    async fn test_simple_workflow_execution() {
        let mut definition = ScriptWorkflowDefinition::new("calc_wf", "Calculator Workflow");
        definition.add_node(task_script("double", "Double", "input * 2"));
        definition.add_node(task_script("add_ten", "Add Ten", "input + 10"));
        definition.add_node(task_script("done", "Done", "input"));
        definition.add_edge("double", "add_ten");
        definition.add_edge("add_ten", "done");
        definition.set_start("double");
        definition.add_end("done");

        let executor = executor_for(definition).await;
        let output = executor.execute(serde_json::json!(5)).await.unwrap();

        assert_eq!(output, serde_json::json!(20));
    }

    /// Conditional edges route to the branch named by the check node's output.
    #[tokio::test]
    async fn test_conditional_workflow() {
        let mut definition = ScriptWorkflowDefinition::new("cond_wf", "Conditional Workflow");
        definition.add_node(condition_script(
            "check",
            "Check Value",
            r#"if input > 10 { "high" } else { "low" }"#,
        ));
        definition.add_node(task_script("high_path", "High Path", r#""HIGH: " + input"#));
        definition.add_node(task_script("low_path", "Low Path", r#""LOW: " + input"#));
        definition.add_node(task_script("end", "End", "input"));
        definition.add_conditional_edge("check", "high_path", "high");
        definition.add_conditional_edge("check", "low_path", "low");
        definition.add_edge("high_path", "end");
        definition.add_edge("low_path", "end");
        definition.set_start("check");
        definition.add_end("end");

        let executor = executor_for(definition).await;

        let high = executor.execute(serde_json::json!(20)).await.unwrap();
        assert!(high.as_str().unwrap().starts_with("HIGH:"));

        executor.reset().await;

        let low = executor.execute(serde_json::json!(5)).await.unwrap();
        assert!(low.as_str().unwrap().starts_with("LOW:"));
    }
}