use crate::engine::error::{DataflowError, ErrorInfo, Result};
use crate::engine::executor::InternalExecutor;
use crate::engine::message::{AuditTrail, Change, Message};
use crate::engine::task_executor::TaskExecutor;
use crate::engine::workflow::Workflow;
use chrono::Utc;
use log::{debug, error, info, warn};
use serde_json::json;
use std::sync::Arc;
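/// Drives execution of a single [`Workflow`] against a [`Message`].
///
/// The workflow-level condition gates the whole run; each task is gated by
/// its own condition and governed by its `continue_on_error` policy.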
pub struct WorkflowExecutor {
task_executor: Arc<TaskExecutor>,
internal_executor: Arc<InternalExecutor>,
}
impl WorkflowExecutor {
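    /// Creates an executor backed by the shared task and internal (condition) executors.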
pub fn new(task_executor: Arc<TaskExecutor>, internal_executor: Arc<InternalExecutor>) -> Self {
Self {
task_executor,
internal_executor,
}
}
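    /// Executes `workflow` against `message`.
    ///
    /// Returns `Ok(false)` when the workflow condition is not met (the
    /// workflow is skipped), `Ok(true)` when its tasks ran (possibly with
    /// recorded errors), and `Err` when a task failed and the workflow does
    /// not have `continue_on_error` set.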
pub async fn execute(&self, workflow: &Workflow, message: &mut Message) -> Result<bool> {
let context_arc = message.get_context_arc();
let should_execute = self
.internal_executor
.evaluate_condition(workflow.condition_index, context_arc)?;
if !should_execute {
debug!("Skipping workflow {} - condition not met", workflow.id);
return Ok(false);
}
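        // Failures either propagate or are downgraded to recorded errors,
        // depending on the workflow's continue_on_error flag.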
match self.execute_tasks(workflow, message).await {
Ok(_) => {
info!("Successfully completed workflow: {}", workflow.id);
Ok(true)
}
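            // Record the error on the message but treat the workflow as executed.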
Err(e) if workflow.continue_on_error => {
warn!(
"Workflow {} encountered error but continuing: {:?}",
workflow.id, e
);
message.errors.push(
ErrorInfo::builder(
"WORKFLOW_ERROR",
format!("Workflow {} error: {}", workflow.id, e),
)
.workflow_id(&workflow.id)
.build(),
);
Ok(true)
}
Err(e) => {
error!("Workflow {} failed: {:?}", workflow.id, e);
Err(e)
}
}
}
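    /// Runs the workflow's tasks in order, skipping any whose condition is not met.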
async fn execute_tasks(&self, workflow: &Workflow, message: &mut Message) -> Result<()> {
for task in &workflow.tasks {
let context_arc = message.get_context_arc();
let should_execute = self
.internal_executor
.evaluate_condition(task.condition_index, context_arc)?;
if !should_execute {
debug!("Skipping task {} - condition not met", task.id);
continue;
}
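            // Execute the task, then fold its outcome into the message
            // (audit trail, progress metadata, errors) and decide whether
            // to keep going.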
let result = self.task_executor.execute(task, message).await;
self.handle_task_result(
result,
&workflow.id,
&task.id,
task.continue_on_error,
message,
)?;
}
Ok(())
}
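    /// Folds a task result into the message.
    ///
    /// Successes append an [`AuditTrail`] entry and update the progress
    /// metadata; failures are recorded with a synthetic 500 status. Returns
    /// `Err` only when the task does not allow `continue_on_error`.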
fn handle_task_result(
&self,
result: Result<(usize, Vec<Change>)>,
workflow_id: &str,
task_id: &str,
continue_on_error: bool,
message: &mut Message,
) -> Result<()> {
match result {
Ok((status, changes)) => {
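                // Append the outcome to the audit trail before touching the context.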
message.audit_trail.push(AuditTrail {
timestamp: Utc::now(),
workflow_id: Arc::from(workflow_id),
task_id: Arc::from(task_id),
status,
changes,
});
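                // Mirror the latest task status into context.metadata.progress
                // so downstream conditions and mappings can observe it.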
                if let Some(metadata) = message.context["metadata"].as_object_mut() {
                    // Create the progress object on first use, then update it in place.
                    let progress = metadata.entry("progress").or_insert_with(|| json!({}));
                    if let Some(progress_obj) = progress.as_object_mut() {
                        progress_obj.insert("workflow_id".to_string(), json!(workflow_id));
                        progress_obj.insert("task_id".to_string(), json!(task_id));
                        progress_obj.insert("status_code".to_string(), json!(status));
                    }
                }
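                // The context was mutated in place, so the cached evaluation
                // context must be rebuilt on next access.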
message.invalidate_context_cache();
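                // Treat the status like an HTTP code: 4xx warns but never
                // aborts, 5xx aborts unless the task opts into continue_on_error.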
if (400..500).contains(&status) {
warn!("Task {} returned client error status: {}", task_id, status);
} else if status >= 500 {
error!("Task {} returned server error status: {}", task_id, status);
if !continue_on_error {
return Err(DataflowError::Task(format!(
"Task {} failed with status {}",
task_id, status
)));
}
}
Ok(())
}
Err(e) => {
error!("Task {} failed: {:?}", task_id, e);
message.audit_trail.push(AuditTrail {
timestamp: Utc::now(),
workflow_id: Arc::from(workflow_id),
task_id: Arc::from(task_id),
status: 500,
changes: vec![],
});
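                // Surface the error on the message for downstream consumers.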
message.errors.push(
ErrorInfo::builder("TASK_ERROR", format!("Task {} error: {}", task_id, e))
.workflow_id(workflow_id)
.task_id(task_id)
.build(),
);
if !continue_on_error { Err(e) } else { Ok(()) }
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::compiler::LogicCompiler;
use serde_json::json;
use std::collections::HashMap;
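    // A workflow whose condition is false must be skipped: execute() reports
    // false and nothing is appended to the audit trail.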
#[tokio::test]
async fn test_workflow_executor_skip_condition() {
let workflow_json = r#"{
"id": "test_workflow",
"name": "Test Workflow",
"condition": false,
"tasks": [{
"id": "dummy_task",
"name": "Dummy Task",
"function": {
"name": "map",
"input": {"mappings": []}
}
}]
}"#;
let mut compiler = LogicCompiler::new();
let mut workflow = Workflow::from_json(workflow_json).unwrap();
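        // Compile the workflow so its condition_index points at compiled
        // logic in the shared cache.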
let workflows = compiler.compile_workflows(vec![workflow.clone()]);
if let Some(compiled_workflow) = workflows.get("test_workflow") {
workflow = compiled_workflow.clone();
}
let (datalogic, logic_cache) = compiler.into_parts();
let internal_executor = Arc::new(InternalExecutor::new(datalogic.clone(), logic_cache));
let task_executor = Arc::new(TaskExecutor::new(
Arc::new(HashMap::new()),
internal_executor.clone(),
datalogic,
));
let workflow_executor = WorkflowExecutor::new(task_executor, internal_executor);
let mut message = Message::from_value(&json!({}));
let executed = workflow_executor
.execute(&workflow, &mut message)
.await
.unwrap();
assert!(!executed);
assert_eq!(message.audit_trail.len(), 0);
}
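    // A workflow whose condition is true runs its (no-op) map task and
    // reports that it executed.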
#[tokio::test]
async fn test_workflow_executor_execute_success() {
let workflow_json = r#"{
"id": "test_workflow",
"name": "Test Workflow",
"condition": true,
"tasks": [{
"id": "dummy_task",
"name": "Dummy Task",
"function": {
"name": "map",
"input": {"mappings": []}
}
}]
}"#;
let mut compiler = LogicCompiler::new();
let mut workflow = Workflow::from_json(workflow_json).unwrap();
let workflows = compiler.compile_workflows(vec![workflow.clone()]);
if let Some(compiled_workflow) = workflows.get("test_workflow") {
workflow = compiled_workflow.clone();
}
let (datalogic, logic_cache) = compiler.into_parts();
let internal_executor = Arc::new(InternalExecutor::new(datalogic.clone(), logic_cache));
let task_executor = Arc::new(TaskExecutor::new(
Arc::new(HashMap::new()),
internal_executor.clone(),
datalogic,
));
let workflow_executor = WorkflowExecutor::new(task_executor, internal_executor);
let mut message = Message::from_value(&json!({}));
let executed = workflow_executor
.execute(&workflow, &mut message)
.await
.unwrap();
assert!(executed);
}
}