use async_trait::async_trait;
use crate::{ExecutionMetrics, ExecutionResult, FlowPattern, RunError, RunId, RunStatus};
use super::{build_capability_output, execute_worker, PatternContext, PatternExecutor};
/// Flow-pattern executor for `FlowPattern::Escalation`.
///
/// It runs the primary worker first. Whenever a level does not complete, it
/// moves on to the next worker in the escalation chain, stopping at the first
/// one that completes.
///
/// The executor holds no state (it is a unit struct), so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct EscalationExecutor;
impl EscalationExecutor {
pub fn new() -> Self {
EscalationExecutor
}
}
impl Default for EscalationExecutor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PatternExecutor for EscalationExecutor {
fn name(&self) -> &'static str {
"escalation"
}
async fn execute(
&self,
ctx: &PatternContext,
runtime: &dyn crate::RuntimeAdapter,
cancel: &crate::CancellationToken,
) -> Result<ExecutionResult, RunError> {
let (primary, chain) = match &ctx.swarm.flow {
FlowPattern::Escalation { primary, chain } => (primary.clone(), chain.clone()),
_ => {
return Err(RunError::PatternError {
pattern: "escalation".into(),
step: "flow".into(),
message: "EscalationExecutor requires Escalation pattern in flow".into(),
})
}
};
let escalation_sequence: Vec<&str> = std::iter::once(primary.as_str())
.chain(chain.iter().map(|s| s.as_str()))
.collect();
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts: vec![],
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics: ExecutionMetrics::default(),
output: None,
});
}
let mut last_error: Option<RunError> = None;
let mut metrics = ExecutionMetrics::default();
for (level, worker_name) in escalation_sequence.iter().enumerate() {
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts: vec![],
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics,
output: None,
});
}
let worker = ctx
.get_worker(worker_name)
.ok_or_else(|| RunError::PatternError {
pattern: "escalation".into(),
step: worker_name.to_string(),
message: format!(
"Worker '{}' not found in swarm at escalation level {}",
worker_name, level
),
})?;
let result =
execute_worker(worker, runtime, &ctx.runtime_ctx, &ctx.scope, cancel).await?;
match result.status {
RunStatus::Completed => {
let mut final_scope = ctx.scope.clone();
if let Some(output) = &result.output {
final_scope.add_step_output(worker_name.to_string(), output.clone());
}
let exec_result = ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts: result.artifacts,
error: None,
metrics: ExecutionMetrics {
wall_time_ms: metrics.wall_time_ms + result.metrics.wall_time_ms,
cpu_time_ms: result.metrics.cpu_time_ms,
peak_memory_bytes: result.metrics.peak_memory_bytes,
retries: level as u32,
selection_trace: None,
},
output: None,
};
return Ok(build_capability_output(
exec_result,
&ctx.swarm,
&final_scope,
));
}
RunStatus::Failed => {
last_error = result.error.clone();
metrics.wall_time_ms += result.metrics.wall_time_ms;
metrics.retries += 1;
}
RunStatus::Cancelled => {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts: vec![],
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics,
output: None,
});
}
_ => {}
}
}
Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts: vec![],
error: last_error.or_else(|| {
Some(RunError::RuntimeError {
message: "All escalation levels failed".into(),
})
}),
metrics,
output: None,
})
}
async fn on_failure(
&self,
ctx: &mut PatternContext,
_runtime: &dyn crate::RuntimeAdapter,
failed_worker: &str,
_error: &RunError,
) -> Result<bool, RunError> {
ctx.state.failed.push(failed_worker.to_string());
let chain = match &ctx.swarm.flow {
FlowPattern::Escalation { chain, .. } => chain,
_ => return Ok(false),
};
Ok(ctx.state.failed.len() < chain.len() + 1)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CancellationToken, ExecutionContext, FlowPattern, RuntimeKind, SwarmFile};

    #[test]
    fn test_escalation_executor_name() {
        let executor = EscalationExecutor::new();
        assert_eq!(executor.name(), "escalation");
    }

    #[test]
    fn test_escalation_executor_default_matches_new() {
        assert_eq!(
            EscalationExecutor::default().name(),
            EscalationExecutor::new().name()
        );
    }

    /// A swarm whose flow is not `Escalation` must be rejected with a
    /// `PatternError`, not just with any error.
    #[tokio::test]
    async fn test_escalation_executor_wrong_pattern() {
        let executor = EscalationExecutor::new();
        let swarm = SwarmFile::new("test", FlowPattern::Sequence { steps: vec![] });
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let result = executor
            .execute(&ctx, &crate::LocalRuntime::new(), &cancel)
            .await;
        assert!(
            matches!(result, Err(RunError::PatternError { .. })),
            "expected RunError::PatternError for a non-escalation flow"
        );
    }
}