use async_trait::async_trait;
use std::time::Duration;
use crate::{
ExecutionMetrics, ExecutionResult, FlowPattern, RecoveryPolicy, RestartPolicy, RunError,
RunId, RunStatus,
};
use super::{build_capability_output, execute_worker, PatternContext, PatternExecutor};
/// Stage of the supervisor's escalating recovery strategy for a failing worker.
///
/// `SupervisorExecutor::next_phase` never moves backwards through these
/// stages: `Retry` -> `Replan` -> `Decompose` -> `Exhausted`, skipping any
/// stage the recovery policy does not configure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryPhase {
    /// Run the original worker (again).
    Retry,
    /// Run an alternative worker selected from the policy's replan settings.
    Replan,
    /// Delegate to the nested swarm named by the policy.
    Decompose,
    /// Terminal stage: no recovery option is left.
    Exhausted,
}

impl RecoveryPhase {
    /// Returns the lowercase identifier of this phase
    /// (`"retry"`, `"replan"`, `"decompose"` or `"exhausted"`).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Retry => "retry",
            Self::Replan => "replan",
            Self::Decompose => "decompose",
            Self::Exhausted => "exhausted",
        }
    }
}
/// Executor for the supervisor flow pattern: runs the configured workers and
/// restarts or recovers them when they fail.
///
/// `Debug` and `Clone` are derived so the executor can be logged and
/// duplicated like any other plain configuration value.
#[derive(Debug, Clone)]
pub struct SupervisorExecutor {
    /// Restart budget consulted by the restart-policy path (the one used when
    /// the flow has no recovery policy).
    max_restarts: u32,
    /// Pause inserted before a worker is restarted or the next recovery step runs.
    restart_delay: Duration,
}
impl SupervisorExecutor {
/// Creates an executor with the built-in defaults: `max_restarts` of 3 and a
/// 100 ms `restart_delay`.
pub fn new() -> Self {
    let max_restarts = 3;
    let restart_delay = Duration::from_millis(100);
    SupervisorExecutor {
        max_restarts,
        restart_delay,
    }
}
pub fn with_max_restarts(mut self, max: u32) -> Self {
self.max_restarts = max;
self
}
pub fn with_restart_delay(mut self, delay: Duration) -> Self {
self.restart_delay = delay;
self
}
#[allow(clippy::too_many_arguments)]
async fn execute_with_recovery(
&self,
ctx: &PatternContext,
runtime: &dyn crate::RuntimeAdapter,
cancel: &crate::CancellationToken,
worker_name: &str,
worker: &crate::Worker,
recovery_policy: &RecoveryPolicy,
final_scope: &mut crate::template::Scope,
) -> Result<(ExecutionResult, u32), RunError> {
let mut metrics = ExecutionMetrics::default();
let mut artifacts = vec![];
let mut current_phase = RecoveryPhase::Retry;
let mut retry_count = 0;
loop {
if cancel.is_cancelled().await {
return Ok((
ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts,
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics,
output: None,
},
retry_count,
));
}
let result = match current_phase {
RecoveryPhase::Retry => {
let exec_result =
execute_worker(worker, runtime, &ctx.runtime_ctx, &ctx.scope, cancel)
.await?;
retry_count += 1;
exec_result
}
RecoveryPhase::Replan => {
self.execute_replan(ctx, runtime, cancel, recovery_policy, final_scope)
.await?
}
RecoveryPhase::Decompose => {
self.execute_decompose(ctx, runtime, cancel, recovery_policy, final_scope)
.await?
}
RecoveryPhase::Exhausted => {
return Ok((
ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts,
error: Some(RunError::RuntimeError {
message: format!(
"All recovery phases exhausted for worker '{}'",
worker_name
),
}),
metrics,
output: None,
},
retry_count,
));
}
};
match result.status {
RunStatus::Completed => {
artifacts.extend(result.artifacts);
metrics.wall_time_ms += result.metrics.wall_time_ms;
metrics.retries += retry_count;
if let Some(output) = &result.output {
final_scope.add_step_output(worker_name.to_string(), output.clone());
}
return Ok((
ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts,
error: None,
metrics,
output: result.output,
},
retry_count,
));
}
RunStatus::Failed => {
current_phase = self.next_phase(current_phase, retry_count, recovery_policy);
if current_phase == RecoveryPhase::Exhausted {
metrics.retries += retry_count;
return Ok((
ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts,
error: Some(RunError::RuntimeError {
message: format!(
"All recovery phases exhausted for worker '{}'",
worker_name
),
}),
metrics,
output: None,
},
retry_count,
));
}
tokio::time::sleep(self.restart_delay).await;
}
RunStatus::Cancelled => {
return Ok((
ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts,
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics,
output: None,
},
retry_count,
));
}
_ => {}
}
}
}
/// Decides which recovery phase follows a failure in `current`.
///
/// * `Retry` repeats while `retry_count` is below `policy.retry_attempts`,
///   then escalates to `Replan` (if a replan expression or fallback is set),
///   otherwise to `Decompose` (if a decompose swarm is set), otherwise to
///   `Exhausted`.
/// * `Replan` escalates to `Decompose` when a decompose swarm is set,
///   otherwise to `Exhausted`.
/// * `Decompose` and `Exhausted` always yield `Exhausted`.
fn next_phase(
    &self,
    current: RecoveryPhase,
    retry_count: u32,
    policy: &RecoveryPolicy,
) -> RecoveryPhase {
    let can_replan = policy.replan_expr.is_some() || policy.replan_fallback.is_some();
    let can_decompose = policy.decompose_swarm.is_some();

    match current {
        RecoveryPhase::Retry if retry_count < policy.retry_attempts => RecoveryPhase::Retry,
        RecoveryPhase::Retry if can_replan => RecoveryPhase::Replan,
        RecoveryPhase::Retry | RecoveryPhase::Replan if can_decompose => {
            RecoveryPhase::Decompose
        }
        RecoveryPhase::Retry
        | RecoveryPhase::Replan
        | RecoveryPhase::Decompose
        | RecoveryPhase::Exhausted => RecoveryPhase::Exhausted,
    }
}
/// Runs the alternative ("replan") worker for a failed worker.
///
/// The worker name comes from `policy.replan_expr`, resolved against `scope`,
/// when that resolves to the name of a worker declared in the swarm;
/// otherwise `policy.replan_fallback` is used.
///
/// # Errors
///
/// Returns `RunError::PatternError` when the expression cannot be resolved,
/// when no worker name is available, or when the chosen worker is not found
/// by `ctx.get_worker`. Errors from `execute_worker` are passed through.
async fn execute_replan(
    &self,
    ctx: &PatternContext,
    runtime: &dyn crate::RuntimeAdapter,
    cancel: &crate::CancellationToken,
    policy: &RecoveryPolicy,
    scope: &mut crate::template::Scope,
) -> Result<ExecutionResult, RunError> {
    use crate::template::{ExpressionResolver, HandlebarsResolver};

    // All errors raised here share the same pattern label.
    let replan_error = |step: &str, message: String| RunError::PatternError {
        pattern: "supervisor".into(),
        step: step.into(),
        message,
    };

    let resolver = HandlebarsResolver::new();

    // A resolved expression only counts when it names a declared worker.
    let from_expr = match &policy.replan_expr {
        None => None,
        Some(expr) => {
            let resolved = resolver.resolve(expr, scope).map_err(|e| {
                replan_error(
                    "replan_expr",
                    format!("Failed to resolve replan_expr '{}': {}", expr, e),
                )
            })?;
            let is_declared = ctx
                .swarm
                .workers
                .iter()
                .any(|w| w.name == resolved.as_str());
            if is_declared {
                Some(resolved)
            } else {
                None
            }
        }
    };
    let alt_worker_name = from_expr.or_else(|| policy.replan_fallback.clone());

    let worker_name = alt_worker_name
        .ok_or_else(|| replan_error("replan", "No replan worker available".into()))?;
    let worker = ctx.get_worker(&worker_name).ok_or_else(|| {
        replan_error(
            "replan",
            format!("Replan worker '{}' not found", worker_name),
        )
    })?;

    execute_worker(worker, runtime, &ctx.runtime_ctx, scope, cancel).await
}
/// Delegates a failed worker's job to the nested swarm named by
/// `policy.decompose_swarm`.
///
/// Steps, in order:
/// 1. Reject the call when the `nesting_depth` recorded in
///    `ctx.state.custom` has reached `crate::MAX_NESTING_DEPTH`.
/// 2. Load the swarm file, resolved relative to the directory of the current
///    swarm file (or `.` when that is unknown).
/// 3. Resolve `policy.decompose_input` against `scope` to build the nested
///    input; the nested scope also inherits `scope.steps` and `scope.env`.
/// 4. Run the nested swarm with the executor chosen by `super::get_executor`.
/// 5. Store the nested output in `scope` under the step name `decompose`,
///    mapped through `policy.decompose_output` when that mapping is non-empty.
///
/// # Errors
///
/// Returns `RunError::PatternError` for a missing swarm path, an exceeded
/// nesting depth, or an input/output value that cannot be resolved. Errors
/// from loading the swarm file or running the nested swarm are propagated.
async fn execute_decompose(
    &self,
    ctx: &PatternContext,
    runtime: &dyn crate::RuntimeAdapter,
    cancel: &crate::CancellationToken,
    policy: &RecoveryPolicy,
    scope: &mut crate::template::Scope,
) -> Result<ExecutionResult, RunError> {
    use crate::template::{ExpressionResolver, HandlebarsResolver};

    // All errors raised here share the same pattern label.
    let decompose_error = |step: &str, message: String| RunError::PatternError {
        pattern: "supervisor".into(),
        step: step.into(),
        message,
    };

    let swarm_path = policy.decompose_swarm.as_ref().ok_or_else(|| {
        decompose_error("decompose", "No decompose swarm specified".into())
    })?;

    // The depth travels between nested runs as a string in the custom state map.
    let current_depth = ctx
        .state
        .custom
        .get("nesting_depth")
        .and_then(|d| d.parse::<u32>().ok())
        .unwrap_or(0);
    if current_depth >= crate::MAX_NESTING_DEPTH {
        return Err(decompose_error(
            "decompose",
            format!(
                "Nesting depth {} exceeds maximum {}",
                current_depth,
                crate::MAX_NESTING_DEPTH
            ),
        ));
    }

    let base_dir = ctx
        .swarm
        .file_path
        .as_ref()
        .and_then(|p| p.parent())
        .unwrap_or(std::path::Path::new("."));
    let delegate_path = base_dir.join(swarm_path);
    let sub_swarm = crate::SwarmFile::from_yaml_file(&delegate_path)?;

    let resolver = HandlebarsResolver::new();

    // Resolve every configured input; the first failure aborts the whole call.
    let nested_input = (&policy.decompose_input)
        .into_iter()
        .map(|(key, value)| {
            resolver
                .resolve_value(value, scope)
                .map(|resolved| (key.clone(), resolved))
                .map_err(|e| {
                    decompose_error(
                        "decompose_input",
                        format!("Failed to resolve input '{}': {}", key, e),
                    )
                })
        })
        .collect::<Result<serde_json::Map<String, serde_json::Value>, RunError>>()?;

    let mut nested_scope =
        crate::template::Scope::with_input(serde_json::Value::Object(nested_input));
    nested_scope.steps = scope.steps.clone();
    nested_scope.env = scope.env.clone();

    let delegate_ctx = crate::ExecutionContext::new(
        format!("decompose-{}", ctx.runtime_ctx.id),
        runtime.kind(),
    );
    let mut nested_ctx = PatternContext::new(sub_swarm, delegate_ctx);
    nested_ctx.scope = nested_scope;
    nested_ctx
        .state
        .custom
        .insert("nesting_depth".to_string(), (current_depth + 1).to_string());

    let executor = super::get_executor(&nested_ctx.swarm.flow);
    let result = executor.execute(&nested_ctx, runtime, cancel).await?;

    match (&result.output, policy.decompose_output.is_empty()) {
        (Some(output), false) => {
            // The mapping expressions see the raw nested output as step `decompose`.
            let mut output_scope = scope.clone();
            output_scope.add_step_output("decompose".to_string(), output.clone());

            let mapped_output = (&policy.decompose_output)
                .into_iter()
                .map(|(key, value)| {
                    resolver
                        .resolve_value(value, &output_scope)
                        .map(|resolved| (key.clone(), resolved))
                        .map_err(|e| {
                            decompose_error(
                                "decompose_output",
                                format!("Failed to resolve output '{}': {}", key, e),
                            )
                        })
                })
                .collect::<Result<serde_json::Map<String, serde_json::Value>, RunError>>()?;

            scope.add_step_output(
                "decompose".to_string(),
                serde_json::Value::Object(mapped_output),
            );
        }
        (Some(output), true) => {
            scope.add_step_output("decompose".to_string(), output.clone());
        }
        (None, _) => {}
    }

    Ok(result)
}
}
impl Default for SupervisorExecutor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PatternExecutor for SupervisorExecutor {
/// Returns the identifier of this pattern executor: `"supervisor"`.
fn name(&self) -> &'static str {
    const NAME: &str = "supervisor";
    NAME
}
async fn execute(
&self,
ctx: &PatternContext,
runtime: &dyn crate::RuntimeAdapter,
cancel: &crate::CancellationToken,
) -> Result<ExecutionResult, RunError> {
let (workers, restart_policy, recovery_policy) = match &ctx.swarm.flow {
FlowPattern::Supervisor {
workers,
restart_policy,
recovery_policy,
} => (workers.clone(), *restart_policy, recovery_policy.clone()),
_ => {
return Err(RunError::PatternError {
pattern: "supervisor".into(),
step: "flow".into(),
message: "SupervisorExecutor requires Supervisor pattern in flow".into(),
})
}
};
if workers.is_empty() {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts: vec![],
error: None,
metrics: ExecutionMetrics::default(),
output: None,
});
}
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts: vec![],
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics: ExecutionMetrics::default(),
output: None,
});
}
let mut artifacts = vec![];
let mut metrics = ExecutionMetrics::default();
let mut final_scope = ctx.scope.clone();
for worker_name in workers.iter() {
let worker = ctx
.get_worker(worker_name)
.ok_or_else(|| RunError::PatternError {
pattern: "supervisor".into(),
step: worker_name.clone(),
message: format!("Worker '{}' not found in swarm", worker_name),
})?;
if let Some(policy) = &recovery_policy {
let (result, retries) = self
.execute_with_recovery(
ctx,
runtime,
cancel,
worker_name,
worker,
policy,
&mut final_scope,
)
.await?;
metrics.retries += retries;
metrics.wall_time_ms += result.metrics.wall_time_ms;
artifacts.extend(result.artifacts);
if result.status != RunStatus::Completed {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: result.status,
artifacts,
error: result.error,
metrics,
output: None,
});
}
} else {
let mut attempt = 0;
loop {
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts,
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics,
output: None,
});
}
let result =
execute_worker(worker, runtime, &ctx.runtime_ctx, &ctx.scope, cancel)
.await?;
match result.status {
RunStatus::Completed => {
artifacts.extend(result.artifacts);
metrics.wall_time_ms += result.metrics.wall_time_ms;
metrics.retries += attempt;
if let Some(output) = &result.output {
final_scope.add_step_output(worker_name.to_string(), output.clone());
}
break; }
RunStatus::Failed => {
match restart_policy {
RestartPolicy::Never => {
metrics.retries += attempt;
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts,
error: result.error,
metrics,
output: None,
});
}
RestartPolicy::OnFailure => {
attempt += 1;
if attempt >= self.max_restarts {
metrics.retries += attempt;
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts,
error: Some(RunError::RuntimeError {
message: format!(
"Max restarts ({}) exceeded",
self.max_restarts
),
}),
metrics,
output: None,
});
}
tokio::time::sleep(self.restart_delay).await;
}
RestartPolicy::Always => {
attempt += 1;
if attempt >= self.max_restarts {
metrics.retries += attempt;
break;
}
tokio::time::sleep(self.restart_delay).await;
}
}
}
RunStatus::Cancelled => {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts,
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics,
output: None,
});
}
_ => {}
}
}
}
}
let result = ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts,
error: None,
metrics,
output: None,
};
Ok(build_capability_output(result, &ctx.swarm, &final_scope))
}
/// Records `failed_worker` in `ctx.state.failed` and reports whether the flow
/// is configured to recover from failures.
///
/// Returns `Ok(true)` when the flow is a supervisor flow that has a recovery
/// policy or a restart policy other than `RestartPolicy::Never`; `Ok(false)`
/// otherwise, including for non-supervisor flows.
async fn on_failure(
    &self,
    ctx: &mut PatternContext,
    _runtime: &dyn crate::RuntimeAdapter,
    failed_worker: &str,
    _error: &RunError,
) -> Result<bool, RunError> {
    ctx.state.failed.push(failed_worker.to_string());

    let recoverable = match &ctx.swarm.flow {
        FlowPattern::Supervisor {
            restart_policy,
            recovery_policy,
            ..
        } => recovery_policy.is_some() || *restart_policy != RestartPolicy::Never,
        _ => false,
    };
    Ok(recoverable)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::template::Scope;
    use crate::{CancellationToken, ExecutionContext, FlowPattern, RuntimeKind, SwarmFile};
    use serde_json::json;

    /// The executor reports the name `supervisor`.
    #[test]
    fn test_supervisor_executor_name() {
        assert_eq!(SupervisorExecutor::new().name(), "supervisor");
    }

    /// Running against a non-supervisor flow is an error.
    #[tokio::test]
    async fn test_supervisor_executor_wrong_pattern() {
        let swarm = SwarmFile::new("test", FlowPattern::Sequence { steps: vec![] });
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::LocalRuntime::new();

        let outcome = SupervisorExecutor::new()
            .execute(&ctx, &runtime, &cancel)
            .await;

        assert!(outcome.is_err());
    }

    /// A step output written to the scope is visible under `steps.<name>.output`.
    #[test]
    fn test_supervisor_scope_output_write() {
        let mut scope = Scope::with_input(json!({}));
        scope.add_step_output(
            "my_worker".to_string(),
            json!({ "status": "ok", "data": { "count": 7 } }),
        );

        let data = scope.to_json();
        let stored = &data["steps"]["my_worker"]["output"];
        assert_eq!(stored["status"], json!("ok"));
        assert_eq!(stored["data"]["count"], json!(7));
    }

    /// Templates can reach into nested fields of an earlier step's output.
    #[test]
    fn test_nested_field_access_template() {
        use crate::template::resolve_worker_input;
        use std::collections::HashMap;

        let mut scope = Scope::with_input(json!({}));
        scope.add_step_output(
            "worker_a".to_string(),
            json!({
                "data": { "count": 5, "labels": ["x", "y"] }
            }),
        );

        let input: HashMap<String, serde_json::Value> = std::iter::once((
            "cnt".to_string(),
            json!("{{steps.worker_a.output.data.count}}"),
        ))
        .collect();

        let resolved = resolve_worker_input(&input, &scope).unwrap();
        assert_eq!(resolved["cnt"], json!(5));
    }

    /// A fresh policy retries three times and has no replan/decompose settings.
    #[test]
    fn test_recovery_policy_creation() {
        let policy = RecoveryPolicy::new();
        assert!(policy.replan_expr.is_none());
        assert!(policy.replan_fallback.is_none());
        assert!(policy.decompose_swarm.is_none());
        assert_eq!(policy.retry_attempts, 3);
    }

    #[test]
    fn test_recovery_policy_with_retry_attempts() {
        assert_eq!(
            RecoveryPolicy::new().with_retry_attempts(5).retry_attempts,
            5
        );
    }

    #[test]
    fn test_recovery_policy_with_replan() {
        let policy = RecoveryPolicy::new()
            .with_replan("{{input.alternative}}")
            .with_replan_fallback("default_worker");
        assert_eq!(policy.replan_expr.as_deref(), Some("{{input.alternative}}"));
        assert_eq!(policy.replan_fallback.as_deref(), Some("default_worker"));
    }

    #[test]
    fn test_recovery_policy_with_decompose() {
        let policy = RecoveryPolicy::new().with_decompose("sub-swarm.yaml");
        assert_eq!(
            policy.decompose_swarm.as_deref(),
            Some(std::path::Path::new("sub-swarm.yaml"))
        );
    }

    /// Walks the phase transition table for the three policy shapes.
    #[test]
    fn test_recovery_phase_next_phase() {
        let executor = SupervisorExecutor::new();
        let retry_only = RecoveryPolicy::new().with_retry_attempts(3);
        let with_replan = RecoveryPolicy::new()
            .with_retry_attempts(2)
            .with_replan_fallback("alt");
        let with_decompose = RecoveryPolicy::new()
            .with_retry_attempts(2)
            .with_decompose("sub.yaml");

        // (current phase, retry count, policy, expected next phase)
        let cases = [
            (RecoveryPhase::Retry, 0, &retry_only, RecoveryPhase::Retry),
            (RecoveryPhase::Retry, 3, &retry_only, RecoveryPhase::Exhausted),
            (RecoveryPhase::Retry, 2, &with_replan, RecoveryPhase::Replan),
            (RecoveryPhase::Retry, 2, &with_decompose, RecoveryPhase::Decompose),
            (RecoveryPhase::Replan, 3, &with_decompose, RecoveryPhase::Decompose),
            (RecoveryPhase::Replan, 3, &with_replan, RecoveryPhase::Exhausted),
        ];
        for (current, retries, policy, expected) in cases.iter() {
            assert_eq!(executor.next_phase(*current, *retries, policy), *expected);
        }
    }

    #[test]
    fn test_recovery_phase_name() {
        let expected = [
            (RecoveryPhase::Retry, "retry"),
            (RecoveryPhase::Replan, "replan"),
            (RecoveryPhase::Decompose, "decompose"),
            (RecoveryPhase::Exhausted, "exhausted"),
        ];
        for (phase, name) in expected.iter() {
            assert_eq!(phase.name(), *name);
        }
    }

    /// A recovery policy attached to a supervisor flow is stored unchanged.
    #[test]
    fn test_supervisor_flow_with_recovery_policy() {
        let swarm = SwarmFile::new(
            "test",
            FlowPattern::Supervisor {
                workers: vec!["main".into()],
                restart_policy: RestartPolicy::OnFailure,
                recovery_policy: Some(
                    RecoveryPolicy::new()
                        .with_retry_attempts(5)
                        .with_replan_fallback("backup"),
                ),
            },
        );

        match &swarm.flow {
            FlowPattern::Supervisor {
                recovery_policy: Some(rp),
                ..
            } => {
                assert_eq!(rp.retry_attempts, 5);
                assert_eq!(rp.replan_fallback, Some("backup".to_string()));
            }
            _ => panic!("Expected Supervisor pattern with recovery_policy"),
        }
    }
}