use anyhow::Result;
use flowbuilder_context::SharedContext;
use flowbuilder_core::{
ActionSpec, ExecutionNode, ExecutionPhase, ExecutionPlan, Executor,
ExecutorStatus, PhaseExecutionMode, RetryStrategy,
};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// Executes an [`ExecutionPlan`] phase by phase against a shared context.
///
/// Nodes of parallel phases are throttled by an internal semaphore sized from
/// [`ExecutorConfig::max_concurrent_tasks`]; counters about finished plan
/// executions are kept in [`ExecutionStats`].
#[derive(Debug)]
pub struct EnhancedTaskExecutor {
    /// Settings fixed at construction time.
    config: ExecutorConfig,
    /// Current lifecycle state, reported through `Executor::status`.
    status: ExecutorStatus,
    /// Bounds how many nodes of a parallel phase run at the same time.
    semaphore: Arc<Semaphore>,
    /// Counters updated after each plan execution.
    stats: ExecutionStats,
}
/// Tunables for an `EnhancedTaskExecutor`.
#[derive(Clone, Debug)]
pub struct ExecutorConfig {
    /// Upper bound on nodes of a parallel phase running at once.
    pub max_concurrent_tasks: usize,
    /// Per-node action timeout in milliseconds, applied when a node carries
    /// no timeout configuration of its own.
    pub default_timeout: u64,
    /// Performance-monitoring switch.
    ///
    /// NOTE(review): nothing in this file reads this flag — confirm whether it
    /// is consumed elsewhere.
    pub enable_performance_monitoring: bool,
    /// When set, progress messages are printed to stdout.
    pub enable_detailed_logging: bool,
}

impl Default for ExecutorConfig {
    /// Ten concurrent tasks, a 30 s (30 000 ms) timeout, monitoring on and
    /// detailed logging off.
    fn default() -> Self {
        let max_concurrent_tasks = 10;
        let default_timeout = 30_000;
        Self {
            max_concurrent_tasks,
            default_timeout,
            enable_performance_monitoring: true,
            enable_detailed_logging: false,
        }
    }
}
/// Counters that an `EnhancedTaskExecutor` maintains about the plans it ran.
///
/// All fields start at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct ExecutionStats {
    /// Number of node results recorded.
    pub total_tasks: usize,
    /// Recorded node results whose `success` flag was set.
    pub successful_tasks: usize,
    /// Recorded node results whose `success` flag was clear.
    pub failed_tasks: usize,
    /// Skipped-node counter.
    ///
    /// NOTE(review): never incremented anywhere in this file.
    pub skipped_tasks: usize,
    /// Plan wall-clock time as recorded when the stats are updated.
    pub total_execution_time: Duration,
    /// `total_execution_time` divided by `total_tasks`.
    pub average_execution_time: Duration,
}
impl Default for EnhancedTaskExecutor {
fn default() -> Self {
Self::new()
}
}
impl EnhancedTaskExecutor {
    /// Creates an executor using [`ExecutorConfig::default`].
    pub fn new() -> Self {
        Self::with_config(ExecutorConfig::default())
    }

    /// Creates an executor with the given configuration.
    ///
    /// The parallel-phase concurrency limit is taken from
    /// `config.max_concurrent_tasks` here and does not change afterwards. A
    /// limit of 0 is raised to 1: a semaphore without permits would make every
    /// parallel phase wait forever.
    pub fn with_config(config: ExecutorConfig) -> Self {
        let permits = config.max_concurrent_tasks.max(1);
        Self {
            semaphore: Arc::new(Semaphore::new(permits)),
            config,
            status: ExecutorStatus::Idle,
            stats: ExecutionStats::default(),
        }
    }

    /// Executes every phase of `plan`, in order, against `context`.
    ///
    /// The plan's `env_vars` / `flow_vars` are first written into the context
    /// as `env.<key>` / `flow.<key>`. Phases then run one after another; the
    /// first phase that does not succeed marks the whole result as failed,
    /// copies its error message, and stops the remaining phases. Statistics
    /// are updated and the status returns to `Idle` once the plan finishes.
    ///
    /// # Errors
    /// Returns `Err` only if seeding the context fails. Phase and node
    /// failures are reported through [`ExecutionResult::success`] and
    /// [`ExecutionResult::error_message`].
    pub async fn execute_plan(
        &mut self,
        plan: ExecutionPlan,
        context: SharedContext,
    ) -> Result<ExecutionResult> {
        self.status = ExecutorStatus::Running;
        let start_time = Instant::now();
        if self.config.enable_detailed_logging {
            println!("开始执行计划: {}", plan.metadata.workflow_name);
            println!("总阶段数: {}", plan.phases.len());
            println!("总节点数: {}", plan.metadata.total_nodes);
        }
        let mut result = ExecutionResult {
            plan_id: plan.metadata.plan_id.clone(),
            start_time,
            end_time: None,
            phase_results: Vec::new(),
            total_duration: Duration::default(),
            success: true,
            error_message: None,
        };
        // Restore the idle status before propagating; a bare `?` would leave
        // the executor reporting `Running` after it has given up.
        if let Err(e) = self.setup_context(&plan, context.clone()).await {
            self.status = ExecutorStatus::Idle;
            return Err(e);
        }
        for (index, phase) in plan.phases.iter().enumerate() {
            if self.config.enable_detailed_logging {
                println!(
                    "执行阶段 {}: {} ({:?})",
                    index + 1,
                    phase.name,
                    phase.execution_mode
                );
            }
            let phase_start = Instant::now();
            let phase_result = match self.execute_phase(phase, context.clone()).await {
                Ok(r) => r,
                Err(e) => PhaseResult {
                    phase_id: phase.id.clone(),
                    phase_name: phase.name.clone(),
                    start_time: phase_start,
                    end_time: Some(Instant::now()),
                    duration: phase_start.elapsed(),
                    success: false,
                    error_message: Some(e.to_string()),
                    node_results: Vec::new(),
                },
            };
            // Any unsuccessful phase (not only one that returned `Err`) fails
            // the plan and stops the remaining phases.
            let phase_failed = !phase_result.success;
            if phase_failed {
                result.success = false;
                result.error_message = phase_result.error_message.clone();
            }
            result.phase_results.push(phase_result);
            if phase_failed {
                break;
            }
        }
        result.end_time = Some(Instant::now());
        result.total_duration = start_time.elapsed();
        self.update_stats(&result);
        self.status = ExecutorStatus::Idle;
        if self.config.enable_detailed_logging {
            println!("执行计划完成,总用时: {:?}", result.total_duration);
        }
        Ok(result)
    }

    /// Executes one phase according to its [`PhaseExecutionMode`].
    ///
    /// The phase succeeds only if every node it ran succeeded. Sequential and
    /// conditional phases stop after the first failed node. Parallel phases
    /// await every spawned node, so no task is left running detached, and
    /// keep all node results that were produced.
    async fn execute_phase(
        &self,
        phase: &ExecutionPhase,
        context: SharedContext,
    ) -> Result<PhaseResult> {
        let start_time = Instant::now();
        let mut phase_result = PhaseResult {
            phase_id: phase.id.clone(),
            phase_name: phase.name.clone(),
            start_time,
            end_time: None,
            duration: Duration::default(),
            success: true,
            error_message: None,
            node_results: Vec::new(),
        };
        // NOTE(review): the phase-level condition is only logged, never
        // evaluated, so the phase always runs.
        if let Some(condition) = &phase.condition {
            if self.config.enable_detailed_logging {
                println!(" 检查阶段条件: {condition}");
            }
        }
        match phase.execution_mode {
            PhaseExecutionMode::Sequential => {
                self.execute_nodes_sequentially(phase, &context, &mut phase_result)
                    .await?;
            }
            PhaseExecutionMode::Parallel => {
                self.execute_nodes_in_parallel(phase, &context, &mut phase_result)
                    .await;
            }
            PhaseExecutionMode::Conditional { ref condition } => {
                if self.config.enable_detailed_logging {
                    println!(" 检查条件: {condition}");
                }
                // NOTE(review): placeholder — the condition is not evaluated,
                // so the skip branch below is currently unreachable.
                let condition_met = true;
                if condition_met {
                    self.execute_nodes_sequentially(phase, &context, &mut phase_result)
                        .await?;
                } else if self.config.enable_detailed_logging {
                    println!(" 跳过阶段 {} (条件不满足)", phase.name);
                }
            }
        }
        // Node failures are reported via `NodeResult::success`, not `Err`, so
        // the phase outcome has to be derived from the collected results.
        let first_failure = phase_result
            .node_results
            .iter()
            .find(|node_result| !node_result.success)
            .map(|node_result| {
                format!(
                    "节点 {} 执行失败: {}",
                    node_result.node_name,
                    node_result.error_message.as_deref().unwrap_or_default()
                )
            });
        if let Some(message) = first_failure {
            Self::record_phase_failure(&mut phase_result, message);
        }
        phase_result.end_time = Some(Instant::now());
        phase_result.duration = start_time.elapsed();
        Ok(phase_result)
    }

    /// Runs the phase's nodes one at a time, stopping after the first node
    /// that does not succeed.
    async fn execute_nodes_sequentially(
        &self,
        phase: &ExecutionPhase,
        context: &SharedContext,
        phase_result: &mut PhaseResult,
    ) -> Result<()> {
        for node in &phase.nodes {
            let node_result = self.execute_node(node, context.clone()).await?;
            let node_failed = !node_result.success;
            phase_result.node_results.push(node_result);
            if node_failed {
                break;
            }
        }
        Ok(())
    }

    /// Spawns every node of the phase as a Tokio task, bounded by the
    /// executor's semaphore, then awaits all of them in node order.
    ///
    /// Task errors and panics are recorded on `phase_result` instead of
    /// returning early. An early return would drop the remaining handles
    /// (detaching those tasks) and lose the results already gathered.
    async fn execute_nodes_in_parallel(
        &self,
        phase: &ExecutionPhase,
        context: &SharedContext,
        phase_result: &mut PhaseResult,
    ) {
        let handles: Vec<_> = phase
            .nodes
            .iter()
            .map(|node| {
                let node = node.clone();
                let context = context.clone();
                let semaphore = Arc::clone(&self.semaphore);
                let config = self.config.clone();
                tokio::spawn(async move {
                    // The semaphore belongs to the executor and is never
                    // closed, so acquiring a permit cannot fail.
                    let _permit = semaphore
                        .acquire()
                        .await
                        .expect("executor semaphore is never closed");
                    Self::execute_node_static(&node, context, &config).await
                })
            })
            .collect();
        for handle in handles {
            match handle.await {
                Ok(Ok(node_result)) => phase_result.node_results.push(node_result),
                Ok(Err(e)) => Self::record_phase_failure(phase_result, e.to_string()),
                Err(join_error) => Self::record_phase_failure(
                    phase_result,
                    format!("任务执行失败: {}", join_error),
                ),
            }
        }
    }

    /// Marks `phase_result` as failed, keeping the first recorded message.
    fn record_phase_failure(phase_result: &mut PhaseResult, message: String) {
        phase_result.success = false;
        if phase_result.error_message.is_none() {
            phase_result.error_message = Some(message);
        }
    }

    /// Executes one node using this executor's configuration.
    async fn execute_node(
        &self,
        node: &ExecutionNode,
        context: SharedContext,
    ) -> Result<NodeResult> {
        Self::execute_node_static(node, context, &self.config).await
    }

    /// Executes one node, retrying its action according to `retry_config`.
    ///
    /// Takes no `self` so it can run inside spawned tasks. A failed or
    /// timed-out attempt is retried up to `retry_config.max_retries` times,
    /// sleeping between attempts according to the configured
    /// [`RetryStrategy`]. If every attempt fails, the returned result has
    /// `success == false` and the last error message.
    ///
    /// # Errors
    /// Node failures are reported inside the returned [`NodeResult`]; this
    /// function currently never returns `Err`.
    async fn execute_node_static(
        node: &ExecutionNode,
        context: SharedContext,
        config: &ExecutorConfig,
    ) -> Result<NodeResult> {
        let start_time = Instant::now();
        let mut result = NodeResult {
            node_id: node.id.clone(),
            node_name: node.name.clone(),
            start_time,
            end_time: None,
            duration: Duration::default(),
            success: true,
            error_message: None,
            retry_count: 0,
        };
        if config.enable_detailed_logging {
            println!(" 执行节点: {} - {}", node.id, node.name);
        }
        if let Some(condition) = &node.condition {
            if config.enable_detailed_logging {
                println!(" 检查节点条件: {condition}");
            }
            // NOTE(review): placeholder — node conditions are not evaluated,
            // so the skip branch below is currently unreachable.
            let condition_met = true;
            if !condition_met {
                if config.enable_detailed_logging {
                    println!(" 跳过节点 {} (条件不满足)", node.name);
                }
                result.end_time = Some(Instant::now());
                result.duration = start_time.elapsed();
                return Ok(result);
            }
        }
        let retry_config = node.retry_config.as_ref();
        let max_retries = retry_config.map_or(0, |c| c.max_retries);
        let mut retries = 0;
        loop {
            match Self::execute_node_action(node, context.clone(), config).await {
                Ok(()) => {
                    result.success = true;
                    break;
                }
                Err(_) if retries < max_retries => {
                    retries += 1;
                    result.retry_count = retries;
                    if config.enable_detailed_logging {
                        println!(
                            " 重试节点 {} ({}/{})",
                            node.name, retries, max_retries
                        );
                    }
                    if let Some(retry_config) = retry_config {
                        let delay = match retry_config.strategy {
                            RetryStrategy::Fixed => retry_config.delay,
                            // The float-to-int `as` cast saturates, so a huge
                            // backoff becomes u64::MAX ms instead of wrapping.
                            RetryStrategy::Exponential { multiplier } => {
                                (retry_config.delay as f64
                                    * multiplier.powi(retries as i32))
                                    as u64
                            }
                            // Saturate instead of overflowing (a debug-build panic).
                            RetryStrategy::Linear { increment } => retry_config
                                .delay
                                .saturating_add(increment.saturating_mul(u64::from(retries))),
                        };
                        tokio::time::sleep(Duration::from_millis(delay)).await;
                    }
                }
                Err(e) => {
                    result.success = false;
                    result.error_message = Some(e.to_string());
                    break;
                }
            }
        }
        result.end_time = Some(Instant::now());
        result.duration = start_time.elapsed();
        Ok(result)
    }

    /// Runs the node's action once, bounded by a timeout.
    ///
    /// The timeout is the node's `timeout_config.duration` in milliseconds,
    /// or `config.default_timeout` when the node has none.
    async fn execute_node_action(
        node: &ExecutionNode,
        context: SharedContext,
        config: &ExecutorConfig,
    ) -> Result<()> {
        let timeout_ms = node
            .timeout_config
            .as_ref()
            .map_or(config.default_timeout, |c| c.duration);
        let action_future = Self::execute_action_by_type(&node.action_spec, context);
        match tokio::time::timeout(Duration::from_millis(timeout_ms), action_future).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => {
                if config.enable_detailed_logging {
                    println!(" 节点 {} 执行超时", node.name);
                }
                Err(anyhow::anyhow!("节点 {} 执行超时", node.name))
            }
        }
    }

    /// Performs the (currently simulated) action described by `action_spec`.
    ///
    /// Each supported type sleeps for a fixed time and prints a marker line.
    /// The marker is printed regardless of `enable_detailed_logging`. Every
    /// entry of `outputs` is then written into the context.
    ///
    /// # Errors
    /// Returns `Err` for an unsupported `action_type`.
    async fn execute_action_by_type(
        action_spec: &ActionSpec,
        context: SharedContext,
    ) -> Result<()> {
        let (simulated_ms, marker) = match action_spec.action_type.as_str() {
            "builtin" => (100, " 执行内置动作"),
            "cmd" => (200, " 执行命令动作"),
            "http" => (300, " 执行HTTP动作"),
            "wasm" => (150, " 执行WASM动作"),
            other => return Err(anyhow::anyhow!("不支持的动作类型: {}", other)),
        };
        tokio::time::sleep(Duration::from_millis(simulated_ms)).await;
        println!("{marker}");
        if !action_spec.outputs.is_empty() {
            // Take the lock once for all outputs rather than once per entry.
            let mut guard = context.lock().await;
            for (key, value) in &action_spec.outputs {
                // NOTE(review): stores the value's `Debug` rendering — confirm
                // this (e.g. quoted strings) is the intended format.
                guard.set_variable(key.clone(), format!("{value:?}"));
            }
        }
        Ok(())
    }

    /// Seeds `context` with the plan's variables before any phase runs.
    ///
    /// Environment variables are stored as `env.<key>` and flow variables as
    /// `flow.<key>`, each value rendered with its `Debug` formatting.
    async fn setup_context(
        &self,
        plan: &ExecutionPlan,
        context: SharedContext,
    ) -> Result<()> {
        let mut guard = context.lock().await;
        for (key, value) in &plan.env_vars {
            guard.set_variable(format!("env.{key}"), format!("{value:?}"));
        }
        for (key, value) in &plan.flow_vars {
            guard.set_variable(format!("flow.{key}"), format!("{value:?}"));
        }
        Ok(())
    }

    /// Folds one finished plan execution into the cumulative statistics.
    ///
    /// Every recorded node result counts as one task. The average is the
    /// accumulated plan wall-clock time divided by the accumulated task count.
    fn update_stats(&mut self, result: &ExecutionResult) {
        // Accumulate (not overwrite) so the time total covers the same
        // executions as the cumulative task counters it is averaged against.
        self.stats.total_execution_time += result.total_duration;
        let node_results = result
            .phase_results
            .iter()
            .flat_map(|phase| &phase.node_results);
        for node_result in node_results {
            self.stats.total_tasks += 1;
            if node_result.success {
                self.stats.successful_tasks += 1;
            } else {
                self.stats.failed_tasks += 1;
            }
        }
        if self.stats.total_tasks > 0 {
            // Divide in u128 nanoseconds so neither the total nor the task
            // count is truncated before the division.
            let average_nanos = self.stats.total_execution_time.as_nanos()
                / self.stats.total_tasks as u128;
            self.stats.average_execution_time =
                Duration::from_nanos(u64::try_from(average_nanos).unwrap_or(u64::MAX));
        }
    }

    /// Returns the statistics gathered so far.
    pub fn get_stats(&self) -> &ExecutionStats {
        &self.stats
    }
}
impl Executor for EnhancedTaskExecutor {
    type Input = (ExecutionPlan, SharedContext);
    type Output = ExecutionResult;
    type Error = anyhow::Error;

    /// Unpacks the `(plan, context)` pair and delegates to `execute_plan`.
    async fn execute(
        &mut self,
        input: Self::Input,
    ) -> Result<Self::Output, Self::Error> {
        let (execution_plan, shared_context) = input;
        self.execute_plan(execution_plan, shared_context).await
    }

    /// Reports a copy of the executor's current lifecycle state.
    fn status(&self) -> ExecutorStatus {
        self.status.clone()
    }

    /// Sets the reported status to `Stopped`; always succeeds.
    ///
    /// NOTE(review): nothing in this file reads the `Stopped` state, so this
    /// does not influence plan execution.
    async fn stop(&mut self) -> Result<(), Self::Error> {
        self.status = ExecutorStatus::Stopped;
        Ok(())
    }
}
/// Outcome of running a whole execution plan.
#[derive(Clone, Debug)]
pub struct ExecutionResult {
    /// Identifier of the plan that was executed.
    pub plan_id: String,
    /// Moment execution of the plan began.
    pub start_time: Instant,
    /// Moment execution finished; `None` while still running.
    pub end_time: Option<Instant>,
    /// One entry per phase the executor attempted, in execution order.
    pub phase_results: Vec<PhaseResult>,
    /// Wall-clock time spent on the plan.
    pub total_duration: Duration,
    /// Whether the plan completed successfully.
    pub success: bool,
    /// Failure description when `success` is `false`.
    pub error_message: Option<String>,
}
/// Outcome of running one phase of a plan.
#[derive(Clone, Debug)]
pub struct PhaseResult {
    /// Identifier of the phase.
    pub phase_id: String,
    /// Human-readable phase name.
    pub phase_name: String,
    /// Moment the phase began.
    pub start_time: Instant,
    /// Moment the phase finished; `None` while still running.
    pub end_time: Option<Instant>,
    /// Wall-clock time spent on the phase.
    pub duration: Duration,
    /// Whether the phase completed successfully.
    pub success: bool,
    /// Failure description when `success` is `false`.
    pub error_message: Option<String>,
    /// Results of the nodes that ran in this phase.
    pub node_results: Vec<NodeResult>,
}
/// Outcome of running a single execution node.
#[derive(Clone, Debug)]
pub struct NodeResult {
    /// Identifier of the node.
    pub node_id: String,
    /// Human-readable node name.
    pub node_name: String,
    /// Moment the node began.
    pub start_time: Instant,
    /// Moment the node finished; `None` while still running.
    pub end_time: Option<Instant>,
    /// Wall-clock time spent on the node, including retries.
    pub duration: Duration,
    /// Whether the node's action eventually succeeded.
    pub success: bool,
    /// Last error message when `success` is `false`.
    pub error_message: Option<String>,
    /// Number of retries performed after the first attempt.
    pub retry_count: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a node that runs the simulated `builtin` action with no
    /// parameters and no outputs.
    fn builtin_node(id: &str, name: &str) -> ExecutionNode {
        let action = ActionSpec {
            action_type: "builtin".to_string(),
            parameters: HashMap::new(),
            outputs: HashMap::new(),
        };
        ExecutionNode::new(id.to_string(), name.to_string(), action)
    }

    /// A freshly constructed executor reports `Idle`.
    #[tokio::test]
    async fn test_executor_creation() {
        let executor = EnhancedTaskExecutor::new();
        assert_eq!(executor.status(), ExecutorStatus::Idle);
    }

    /// A single built-in node succeeds and carries its id into the result.
    #[tokio::test]
    async fn test_node_execution() {
        let config = ExecutorConfig::default();
        let context = Arc::new(tokio::sync::Mutex::new(
            flowbuilder_context::FlowContext::default(),
        ));
        let node = builtin_node("test_node", "Test Node");
        let node_result =
            EnhancedTaskExecutor::execute_node_static(&node, context, &config)
                .await
                .expect("built-in node should execute without error");
        assert!(node_result.success);
        assert_eq!(node_result.node_id, "test_node");
    }
}