use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::acl::ACL;
use crate::approval::ApprovalHandler;
use crate::config::Config;
use crate::context::Context;
use crate::errors::{ErrorCode, ModuleError};
use crate::middleware::manager::MiddlewareManager;
use crate::module::Module;
use crate::registry::registry::Registry;
use crate::utils::helpers::match_pattern;
/// One stage of an [`ExecutionStrategy`], run in order by [`PipelineEngine`].
///
/// Implementors must supply the identity/mutability accessors and
/// [`Step::execute`]. Every other hook has a default that opts the step out
/// of the corresponding engine feature.
#[async_trait]
pub trait Step: Send + Sync {
    /// Name of the step, used to locate it for insertion, removal,
    /// replacement and as a `skip_to` target.
    fn name(&self) -> &str;

    /// Human-readable summary, included in [`ExecutionStrategy::info`].
    fn description(&self) -> &str;

    /// Whether [`ExecutionStrategy::remove`] may drop this step.
    fn removable(&self) -> bool;

    /// Whether [`ExecutionStrategy::replace`] may swap this step out.
    fn replaceable(&self) -> bool;

    /// Runs the step against the shared pipeline context.
    ///
    /// The returned [`StepResult::action`] tells the engine what to do next.
    /// An `Err` fails the run unless [`Step::ignore_errors`] returns `true`.
    async fn execute(&self, ctx: &mut PipelineContext) -> Result<StepResult, ModuleError>;

    /// Module-id patterns this step applies to.
    ///
    /// `None` (the default) means every module. With `Some`, the engine skips
    /// the step unless at least one pattern matches `PipelineContext::module_id`.
    fn match_modules(&self) -> Option<&[String]> {
        None
    }

    /// When `true`, an error from `execute` (including a timeout) is logged
    /// and the pipeline moves on to the next step instead of failing.
    fn ignore_errors(&self) -> bool {
        false
    }

    /// Whether the step is safe to run during a dry run.
    ///
    /// Steps that are not pure are skipped while `PipelineContext::dry_run`
    /// is set.
    fn pure(&self) -> bool {
        false
    }

    /// Per-step time budget in milliseconds.
    ///
    /// `0` (the default) disables the limit. Otherwise exceeding it fails the
    /// step with `ErrorCode::ModuleTimeout`.
    fn timeout_ms(&self) -> u64 {
        0
    }

    /// Capabilities this step expects an earlier step to have provided.
    ///
    /// Only used to log warnings when the strategy is built or edited.
    fn requires(&self) -> &[&str] {
        &[]
    }

    /// Capabilities this step makes available to later steps.
    fn provides(&self) -> &[&str] {
        &[]
    }
}
/// Outcome reported by a [`Step`], telling the engine how to proceed.
///
/// The engine understands the actions `"continue"`, `"abort"` and
/// `"skip_to"`. Optional fields are left out of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Control-flow directive consumed by the engine.
    pub action: String,
    /// Name of the step to jump to when `action` is `"skip_to"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip_to: Option<String>,
    /// Free-form reason, e.g. why the run was aborted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub explanation: Option<String>,
    /// Optional score attached by the step; not interpreted by the engine.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f64>,
    /// Optional alternatives suggested by the step; not interpreted by the engine.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alternatives: Option<Vec<String>>,
}

impl Default for StepResult {
    /// A bare `"continue"` result with every optional field unset.
    fn default() -> Self {
        let action = String::from("continue");
        Self {
            action,
            skip_to: None,
            explanation: None,
            confidence: None,
            alternatives: None,
        }
    }
}

impl StepResult {
    /// Proceed to the next step; identical to [`StepResult::default`].
    #[must_use]
    pub fn continue_step() -> Self {
        Default::default()
    }

    /// Stop the pipeline, recording `explanation` as the reason.
    #[must_use]
    pub fn abort(explanation: &str) -> Self {
        let explanation = Some(explanation.to_owned());
        Self {
            action: String::from("abort"),
            explanation,
            ..Self::default()
        }
    }

    /// Jump forward to the step named `target`.
    #[must_use]
    pub fn skip_to(target: &str) -> Self {
        let skip_to = Some(target.to_owned());
        Self {
            action: String::from("skip_to"),
            skip_to,
            ..Self::default()
        }
    }
}
/// Mutable state threaded through every [`Step`] of a single pipeline run.
///
/// Steps read and write these fields. [`PipelineEngine`] itself only consults
/// `module_id`, `dry_run`, `output` and `trace`.
pub struct PipelineContext {
    /// Identifier of the module being run; matched against each step's
    /// `match_modules()` patterns.
    pub module_id: String,
    /// Inputs handed to [`PipelineContext::new`].
    pub inputs: serde_json::Value,
    /// Call context handed to [`PipelineContext::new`].
    pub context: Context<serde_json::Value>,
    /// Module resolved for this run, if a step has set one; `None` initially.
    pub module: Option<Arc<dyn Module>>,
    /// Validated form of `inputs`, if a step has produced one; `None` initially.
    pub validated_inputs: Option<serde_json::Value>,
    /// Output of the run; returned to the caller by the engine.
    pub output: Option<serde_json::Value>,
    /// Validated form of `output`, if a step has produced one; `None` initially.
    pub validated_output: Option<serde_json::Value>,
    /// When `true`, the engine skips every step that is not `pure()`.
    pub dry_run: bool,
    /// Optional version hint; not interpreted in this module.
    pub version_hint: Option<String>,
    /// NOTE(review): presumably indices of middlewares already run — confirm
    /// against the steps that write it.
    pub executed_middlewares: Vec<usize>,
    /// Optional service handle available to steps; `None` by default.
    pub registry: Option<Arc<Registry>>,
    /// Optional configuration available to steps; `None` by default.
    pub config: Option<Arc<Config>>,
    /// Optional access-control handle available to steps; `None` by default.
    pub acl: Option<Arc<ACL>>,
    /// Optional approval handler available to steps; `None` by default.
    pub approval_handler: Option<Arc<dyn ApprovalHandler>>,
    /// Optional middleware manager available to steps; `None` by default.
    pub middleware_manager: Option<Arc<MiddlewareManager>>,
    /// Name of the strategy being run; also recorded in `trace`.
    pub strategy_name: String,
    /// Execution trace appended to by the engine as steps are visited.
    pub trace: PipelineTrace,
}

impl PipelineContext {
    /// Creates a fresh context for one run.
    ///
    /// Every optional field is `None`, `dry_run` is off, and the trace is
    /// empty and labelled with `module_id` and `strategy_name`.
    pub fn new(
        module_id: impl Into<String>,
        inputs: serde_json::Value,
        context: Context<serde_json::Value>,
        strategy_name: impl Into<String>,
    ) -> Self {
        let module_id: String = module_id.into();
        let strategy_name: String = strategy_name.into();
        let trace = PipelineTrace::new(module_id.clone(), strategy_name.clone());
        Self {
            module_id,
            inputs,
            context,
            strategy_name,
            trace,
            dry_run: false,
            executed_middlewares: Vec::new(),
            module: None,
            validated_inputs: None,
            output: None,
            validated_output: None,
            version_hint: None,
            registry: None,
            config: None,
            acl: None,
            approval_handler: None,
            middleware_manager: None,
        }
    }
}
/// Record of how a single step was handled during a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepTrace {
    /// Name of the step this entry describes.
    pub name: String,
    /// Wall-clock execution time in milliseconds; `0.0` for steps that did
    /// not run.
    pub duration_ms: f64,
    /// The step's own result, or the one the engine synthesized for it
    /// (e.g. an abort carrying the error text).
    pub result: StepResult,
    /// `true` when the step was passed over without executing.
    pub skipped: bool,
    /// Flag reserved for decision steps; [`PipelineEngine`] always writes `false`.
    pub decision_point: bool,
    /// Why the step was passed over or its error tolerated, when recorded.
    /// The engine uses `"no_match"`, `"dry_run"` and `"error_ignored"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip_reason: Option<String>,
}

/// Full record of one pipeline run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineTrace {
    /// Module the run was for.
    pub module_id: String,
    /// Strategy the run used.
    pub strategy_name: String,
    /// One entry per visited step, in visiting order.
    pub steps: Vec<StepTrace>,
    /// Total run time in milliseconds, filled in when the run ends.
    pub total_duration_ms: f64,
    /// Starts `false`; set to `true` when the engine completes a run normally.
    pub success: bool,
}

impl PipelineTrace {
    /// Creates an empty, not-yet-successful trace for `module_id` under
    /// `strategy_name`.
    #[must_use]
    pub fn new(module_id: String, strategy_name: String) -> Self {
        Self {
            steps: vec![],
            total_duration_ms: 0.0,
            success: false,
            module_id,
            strategy_name,
        }
    }
}
/// Snapshot of a strategy's shape, produced by [`ExecutionStrategy::info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategyInfo {
    /// Strategy name.
    pub name: String,
    /// Number of steps in the strategy.
    pub step_count: usize,
    /// Step names in execution order.
    pub step_names: Vec<String>,
    /// Per-step `name: description` pairs joined with `"; "`.
    pub description: String,
}

impl std::fmt::Display for StrategyInfo {
    /// Renders the step count followed by the step names joined with arrows,
    /// e.g. `3-step pipeline: a → b → c`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let chain = self.step_names.join(" \u{2192} ");
        write!(f, "{}-step pipeline: {}", self.step_count, chain)
    }
}
/// A named, ordered list of [`Step`]s that [`PipelineEngine`] runs front to back.
///
/// `new`, `insert_after` and `insert_before` reject duplicate step names.
pub struct ExecutionStrategy {
    /// Strategy label, used in error messages and [`StrategyInfo`].
    name: String,
    /// Steps in execution order.
    steps: Vec<Box<dyn Step>>,
}

impl std::fmt::Debug for ExecutionStrategy {
    /// `Step` has no `Debug` bound, so only the step names and count are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let step_names = self.step_names();
        let step_count = self.steps.len();
        let mut out = f.debug_struct("ExecutionStrategy");
        out.field("name", &self.name);
        out.field("step_names", &step_names);
        out.field("step_count", &step_count);
        out.finish()
    }
}
/// Inert stand-in that occupies a step's slot while
/// [`ExecutionStrategy::replace_with`] passes the original step to the
/// caller's wrapper. Executing it simply continues.
struct PlaceholderStep;

#[async_trait]
impl Step for PlaceholderStep {
    async fn execute(&self, _ctx: &mut PipelineContext) -> Result<StepResult, ModuleError> {
        let outcome = StepResult::continue_step();
        Ok(outcome)
    }

    fn name(&self) -> &'static str {
        "__placeholder__"
    }

    fn description(&self) -> &'static str {
        ""
    }

    // The placeholder imposes no restrictions of its own on strategy edits.
    fn removable(&self) -> bool {
        true
    }

    fn replaceable(&self) -> bool {
        true
    }
}
impl ExecutionStrategy {
pub fn new(name: impl Into<String>, steps: Vec<Box<dyn Step>>) -> Result<Self, ModuleError> {
let name = name.into();
let mut seen = std::collections::HashSet::new();
for step in &steps {
if !seen.insert(step.name().to_string()) {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!(
"Duplicate step name '{}' in strategy '{}'",
step.name(),
name,
),
));
}
}
let strategy = Self { name, steps };
strategy.validate_dependencies();
Ok(strategy)
}
fn validate_dependencies(&self) {
let mut provided = std::collections::HashSet::new();
for step in &self.steps {
for req in step.requires() {
if !provided.contains(*req) {
tracing::warn!(
step = step.name(),
requires = *req,
"Step requires '{}', but no preceding step provides it. \
This may cause runtime errors.",
req,
);
}
}
for p in step.provides() {
provided.insert(*p);
}
}
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
pub fn set_name(&mut self, name: impl Into<String>) {
self.name = name.into();
}
#[must_use]
pub fn step_names(&self) -> Vec<String> {
self.steps.iter().map(|s| s.name().to_string()).collect()
}
#[must_use]
pub fn steps(&self) -> &[Box<dyn Step>] {
&self.steps
}
pub fn insert_after(&mut self, anchor: &str, step: Box<dyn Step>) -> Result<(), ModuleError> {
self.validate_no_duplicate(step.name())?;
let idx = self.find_step_index(anchor)?;
self.steps.insert(idx + 1, step);
self.validate_dependencies();
Ok(())
}
pub fn insert_before(&mut self, anchor: &str, step: Box<dyn Step>) -> Result<(), ModuleError> {
self.validate_no_duplicate(step.name())?;
let idx = self.find_step_index(anchor)?;
self.steps.insert(idx, step);
self.validate_dependencies();
Ok(())
}
pub fn remove(&mut self, step_name: &str) -> Result<(), ModuleError> {
let idx = self.find_step_index(step_name)?;
if !self.steps[idx].removable() {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!("Step '{step_name}' is not removable"),
));
}
self.steps.remove(idx);
Ok(())
}
pub fn replace(&mut self, step_name: &str, new_step: Box<dyn Step>) -> Result<(), ModuleError> {
let idx = self.find_step_index(step_name)?;
if !self.steps[idx].replaceable() {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!("Step '{step_name}' is not replaceable"),
));
}
self.steps[idx] = new_step;
Ok(())
}
pub fn replace_with<F>(&mut self, step_name: &str, wrapper: F) -> Result<(), ModuleError>
where
F: FnOnce(Box<dyn Step>) -> Box<dyn Step>,
{
let idx = self.find_step_index(step_name)?;
let old = std::mem::replace(&mut self.steps[idx], Box::new(PlaceholderStep));
self.steps[idx] = wrapper(old);
Ok(())
}
#[must_use]
pub fn info(&self) -> StrategyInfo {
let step_names = self.step_names();
let description = self
.steps
.iter()
.map(|s| format!("{}: {}", s.name(), s.description()))
.collect::<Vec<_>>()
.join("; ");
StrategyInfo {
name: self.name.clone(),
step_count: self.steps.len(),
step_names,
description,
}
}
fn find_step_index(&self, step_name: &str) -> Result<usize, ModuleError> {
self.steps
.iter()
.position(|s| s.name() == step_name)
.ok_or_else(|| {
ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!("Step '{}' not found in strategy '{}'", step_name, self.name),
)
})
}
fn validate_no_duplicate(&self, name: &str) -> Result<(), ModuleError> {
if self.steps.iter().any(|s| s.name() == name) {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!(
"Step name '{}' already exists in strategy '{}'",
name, self.name,
),
));
}
Ok(())
}
}
pub struct PipelineEngine;
impl PipelineEngine {
pub async fn run(
strategy: &ExecutionStrategy,
ctx: &mut PipelineContext,
) -> Result<(Option<serde_json::Value>, PipelineTrace), ModuleError> {
Self::run_inner(strategy, ctx, None).await
}
pub async fn run_until(
strategy: &ExecutionStrategy,
ctx: &mut PipelineContext,
stop_before_step: &str,
) -> Result<(Option<serde_json::Value>, PipelineTrace), ModuleError> {
Self::run_inner(strategy, ctx, Some(stop_before_step)).await
}
#[allow(clippy::too_many_lines)] async fn run_inner(
strategy: &ExecutionStrategy,
ctx: &mut PipelineContext,
stop_before_step: Option<&str>,
) -> Result<(Option<serde_json::Value>, PipelineTrace), ModuleError> {
let pipeline_start = std::time::Instant::now();
let steps = strategy.steps();
let mut idx: usize = 0;
while idx < steps.len() {
let step = &steps[idx];
if let Some(stop_name) = stop_before_step {
if step.name() == stop_name {
break;
}
}
let step_match_modules = step.match_modules();
let step_ignore_errors = step.ignore_errors();
let step_pure = step.pure();
let step_timeout_ms = step.timeout_ms();
if let Some(patterns) = step_match_modules {
let matched = patterns
.iter()
.any(|pattern| match_pattern(pattern, &ctx.module_id));
if !matched {
ctx.trace.steps.push(StepTrace {
name: step.name().to_string(),
duration_ms: 0.0,
result: StepResult::continue_step(),
skipped: true,
decision_point: false,
skip_reason: Some("no_match".to_string()),
});
idx += 1;
continue;
}
}
if ctx.dry_run && !step_pure {
ctx.trace.steps.push(StepTrace {
name: step.name().to_string(),
duration_ms: 0.0,
result: StepResult::continue_step(),
skipped: true,
decision_point: false,
skip_reason: Some("dry_run".to_string()),
});
idx += 1;
continue;
}
let step_start = std::time::Instant::now();
let exec_result = if step_timeout_ms > 0 {
match tokio::time::timeout(
std::time::Duration::from_millis(step_timeout_ms),
step.execute(ctx),
)
.await
{
Ok(r) => r,
Err(_elapsed) => Err(ModuleError::new(
ErrorCode::ModuleTimeout,
format!(
"Step '{}' timed out after {}ms",
step.name(),
step_timeout_ms
),
)),
}
} else {
step.execute(ctx).await
};
let duration_ms = step_start.elapsed().as_secs_f64() * 1000.0;
let result = match exec_result {
Ok(r) => r,
Err(err) => {
if step_ignore_errors {
tracing::warn!(
step = step.name(),
error = %err,
"Step failed (ignored)"
);
ctx.trace.steps.push(StepTrace {
name: step.name().to_string(),
duration_ms,
result: StepResult {
action: "continue".into(),
explanation: Some(err.to_string()),
..Default::default()
},
skipped: false,
decision_point: false,
skip_reason: Some("error_ignored".to_string()),
});
idx += 1;
continue;
}
ctx.trace.steps.push(StepTrace {
name: step.name().to_string(),
duration_ms,
result: StepResult::abort(&err.to_string()),
skipped: false,
decision_point: false,
skip_reason: None,
});
ctx.trace.total_duration_ms = pipeline_start.elapsed().as_secs_f64() * 1000.0;
return Err(err);
}
};
let action = result.action.clone();
let skip_target = result.skip_to.clone();
ctx.trace.steps.push(StepTrace {
name: step.name().to_string(),
duration_ms,
result,
skipped: false,
decision_point: false,
skip_reason: None,
});
match action.as_str() {
"continue" => {
idx += 1;
}
"abort" => {
ctx.trace.total_duration_ms = pipeline_start.elapsed().as_secs_f64() * 1000.0;
ctx.trace.success = false;
return Ok((ctx.output.clone(), ctx.trace.clone()));
}
"skip_to" => {
let target = skip_target.as_deref().unwrap_or("");
let found = steps
.iter()
.enumerate()
.position(|(i, s)| i > idx && s.name() == target);
match found {
Some(target_idx) => {
for step in steps.iter().take(target_idx).skip(idx + 1) {
ctx.trace.steps.push(StepTrace {
name: step.name().to_string(),
duration_ms: 0.0,
result: StepResult::continue_step(),
skipped: true,
decision_point: false,
skip_reason: None,
});
}
idx = target_idx;
}
None => {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!(
"skip_to target '{}' not found after step '{}'",
target,
step.name(),
),
));
}
}
}
other => {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!("Unknown step action: '{other}'"),
));
}
}
}
ctx.trace.total_duration_ms = pipeline_start.elapsed().as_secs_f64() * 1000.0;
ctx.trace.success = true;
Ok((ctx.output.clone(), ctx.trace.clone()))
}
}
// Unit tests for `StepResult`, `ExecutionStrategy` editing, trace/context
// construction, and `PipelineEngine::run_until`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Step` with a configurable name and mutability flags; its
    /// `execute` always returns `continue`.
    struct FakeStep {
        name: String,
        description: String,
        removable: bool,
        replaceable: bool,
    }
    impl FakeStep {
        /// Builds a fake step whose description is derived from its name.
        fn new(name: &str, removable: bool, replaceable: bool) -> Self {
            Self {
                name: name.to_string(),
                description: format!("Fake step: {name}"),
                removable,
                replaceable,
            }
        }
        /// Same as `FakeStep::new`, boxed as a `dyn Step` for strategy construction.
        fn boxed(name: &str, removable: bool, replaceable: bool) -> Box<dyn Step> {
            Box::new(Self::new(name, removable, replaceable))
        }
    }
    // Plain accessors over the configured fields; execution is a no-op continue.
    #[async_trait]
    impl Step for FakeStep {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            &self.description
        }
        fn removable(&self) -> bool {
            self.removable
        }
        fn replaceable(&self) -> bool {
            self.replaceable
        }
        async fn execute(&self, _ctx: &mut PipelineContext) -> Result<StepResult, ModuleError> {
            Ok(StepResult::continue_step())
        }
    }

    /// `continue_step` yields a bare "continue" with no target or explanation.
    #[test]
    fn test_step_result_continue() {
        let r = StepResult::continue_step();
        assert_eq!(r.action, "continue");
        assert!(r.skip_to.is_none());
        assert!(r.explanation.is_none());
    }

    /// `abort` records the action and the explanation text.
    #[test]
    fn test_step_result_abort() {
        let r = StepResult::abort("bad input");
        assert_eq!(r.action, "abort");
        assert_eq!(r.explanation.as_deref(), Some("bad input"));
    }

    /// `skip_to` records the action and the target step name.
    #[test]
    fn test_step_result_skip_to() {
        let r = StepResult::skip_to("execute");
        assert_eq!(r.action, "skip_to");
        assert_eq!(r.skip_to.as_deref(), Some("execute"));
    }

    /// Optional fields survive a JSON serialize/deserialize round trip.
    #[test]
    fn test_step_result_serde_round_trip() {
        let r = StepResult {
            action: "abort".into(),
            explanation: Some("denied".into()),
            confidence: Some(0.95),
            alternatives: Some(vec!["retry".into()]),
            ..Default::default()
        };
        let json = serde_json::to_string(&r).expect("serialize");
        let r2: StepResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(r2.action, "abort");
        assert_eq!(r2.confidence, Some(0.95));
    }

    /// Constructing a strategy with two same-named steps fails.
    #[test]
    fn test_strategy_new_rejects_duplicate_names() {
        let steps: Vec<Box<dyn Step>> = vec![
            FakeStep::boxed("a", true, true),
            FakeStep::boxed("a", true, true),
        ];
        let result = ExecutionStrategy::new("test", steps);
        assert!(result.is_err());
    }

    /// `step_names` preserves construction order.
    #[test]
    fn test_strategy_step_names() {
        let strategy = ExecutionStrategy::new(
            "default",
            vec![
                FakeStep::boxed("one", true, true),
                FakeStep::boxed("two", true, true),
                FakeStep::boxed("three", true, true),
            ],
        )
        .expect("create strategy");
        assert_eq!(strategy.step_names(), vec!["one", "two", "three"]);
    }

    /// `insert_after` places the new step directly after the anchor.
    #[test]
    fn test_strategy_insert_after() {
        let mut strategy = ExecutionStrategy::new(
            "s",
            vec![
                FakeStep::boxed("a", true, true),
                FakeStep::boxed("c", true, true),
            ],
        )
        .unwrap();
        strategy
            .insert_after("a", FakeStep::boxed("b", true, true))
            .unwrap();
        assert_eq!(strategy.step_names(), vec!["a", "b", "c"]);
    }

    /// `insert_before` places the new step directly before the anchor.
    #[test]
    fn test_strategy_insert_before() {
        let mut strategy = ExecutionStrategy::new(
            "s",
            vec![
                FakeStep::boxed("a", true, true),
                FakeStep::boxed("c", true, true),
            ],
        )
        .unwrap();
        strategy
            .insert_before("c", FakeStep::boxed("b", true, true))
            .unwrap();
        assert_eq!(strategy.step_names(), vec!["a", "b", "c"]);
    }

    /// Inserting a step whose name already exists is rejected.
    #[test]
    fn test_strategy_insert_rejects_duplicate() {
        let mut strategy =
            ExecutionStrategy::new("s", vec![FakeStep::boxed("a", true, true)]).unwrap();
        let result = strategy.insert_after("a", FakeStep::boxed("a", true, true));
        assert!(result.is_err());
    }

    /// Inserting relative to a missing anchor is rejected.
    #[test]
    fn test_strategy_insert_rejects_unknown_anchor() {
        let mut strategy =
            ExecutionStrategy::new("s", vec![FakeStep::boxed("a", true, true)]).unwrap();
        let result = strategy.insert_after("missing", FakeStep::boxed("b", true, true));
        assert!(result.is_err());
    }

    /// A removable step can be removed.
    #[test]
    fn test_strategy_remove() {
        let mut strategy = ExecutionStrategy::new(
            "s",
            vec![
                FakeStep::boxed("a", true, true),
                FakeStep::boxed("b", true, true),
            ],
        )
        .unwrap();
        strategy.remove("a").unwrap();
        assert_eq!(strategy.step_names(), vec!["b"]);
    }

    /// Removing a step flagged non-removable fails.
    #[test]
    fn test_strategy_remove_non_removable() {
        let mut strategy =
            ExecutionStrategy::new("s", vec![FakeStep::boxed("core", false, false)]).unwrap();
        let result = strategy.remove("core");
        assert!(result.is_err());
    }

    /// A replaceable step can be swapped for a same-named one.
    #[test]
    fn test_strategy_replace() {
        let mut strategy =
            ExecutionStrategy::new("s", vec![FakeStep::boxed("a", true, true)]).unwrap();
        strategy
            .replace("a", FakeStep::boxed("a", true, true))
            .unwrap();
        assert_eq!(strategy.step_names(), vec!["a"]);
    }

    /// Replacing a step flagged non-replaceable fails.
    #[test]
    fn test_strategy_replace_non_replaceable() {
        let mut strategy =
            ExecutionStrategy::new("s", vec![FakeStep::boxed("a", true, false)]).unwrap();
        let result = strategy.replace("a", FakeStep::boxed("a", true, true));
        assert!(result.is_err());
    }

    /// `info` reports name, count, ordered names, and a description that
    /// mentions every step.
    #[test]
    fn test_strategy_info() {
        let strategy = ExecutionStrategy::new(
            "default",
            vec![
                FakeStep::boxed("one", true, true),
                FakeStep::boxed("two", false, true),
            ],
        )
        .unwrap();
        let info = strategy.info();
        assert_eq!(info.name, "default");
        assert_eq!(info.step_count, 2);
        assert_eq!(info.step_names, vec!["one", "two"]);
        assert!(info.description.contains("one"));
        assert!(info.description.contains("two"));
    }

    /// A fresh trace is labelled, empty, and not yet successful.
    #[test]
    fn test_pipeline_trace_new() {
        let trace = PipelineTrace::new("my_module".into(), "default".into());
        assert_eq!(trace.module_id, "my_module");
        assert_eq!(trace.strategy_name, "default");
        assert!(trace.steps.is_empty());
        assert!(!trace.success);
    }

    /// A fresh context carries its ids and leaves the step-populated fields unset.
    #[test]
    fn test_pipeline_context_new() {
        let ctx_inner = Context::<serde_json::Value>::anonymous();
        let pctx = PipelineContext::new(
            "test_module",
            serde_json::json!({"key": "value"}),
            ctx_inner,
            "default",
        );
        assert_eq!(pctx.module_id, "test_module");
        assert_eq!(pctx.strategy_name, "default");
        assert!(pctx.module.is_none());
        assert!(pctx.validated_inputs.is_none());
        assert!(pctx.output.is_none());
        assert!(pctx.validated_output.is_none());
        assert_eq!(pctx.trace.module_id, "test_module");
    }

    /// Step that counts its executions in a shared atomic counter, with
    /// optional `match_modules` patterns.
    struct CountingStep {
        name: String,
        invocations: Arc<std::sync::atomic::AtomicUsize>,
        match_modules: Option<Vec<String>>,
    }
    // Each `execute` call increments the shared counter and continues.
    #[async_trait]
    impl Step for CountingStep {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            &self.name
        }
        fn removable(&self) -> bool {
            true
        }
        fn replaceable(&self) -> bool {
            true
        }
        fn match_modules(&self) -> Option<&[String]> {
            self.match_modules.as_deref()
        }
        async fn execute(&self, _ctx: &mut PipelineContext) -> Result<StepResult, ModuleError> {
            self.invocations
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(StepResult::continue_step())
        }
    }

    /// `run_until` executes only the steps before the named stop step.
    #[tokio::test]
    async fn run_until_stops_before_named_step() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let pre_count = Arc::new(AtomicUsize::new(0));
        let execute_count = Arc::new(AtomicUsize::new(0));
        let post_count = Arc::new(AtomicUsize::new(0));
        let steps: Vec<Box<dyn Step>> = vec![
            Box::new(CountingStep {
                name: "pre".into(),
                invocations: Arc::clone(&pre_count),
                match_modules: None,
            }),
            Box::new(CountingStep {
                name: "execute".into(),
                invocations: Arc::clone(&execute_count),
                match_modules: None,
            }),
            Box::new(CountingStep {
                name: "post".into(),
                invocations: Arc::clone(&post_count),
                match_modules: None,
            }),
        ];
        let strategy = ExecutionStrategy::new("test", steps).unwrap();
        let mut pctx = PipelineContext::new(
            "mod.x",
            serde_json::json!({}),
            Context::<serde_json::Value>::anonymous(),
            "test",
        );
        let (_, trace) = PipelineEngine::run_until(&strategy, &mut pctx, "execute")
            .await
            .unwrap();
        assert_eq!(pre_count.load(Ordering::SeqCst), 1, "'pre' runs once");
        assert_eq!(
            execute_count.load(Ordering::SeqCst),
            0,
            "'execute' must NOT run — run_until stops before it"
        );
        assert_eq!(post_count.load(Ordering::SeqCst), 0, "'post' must not run");
        assert!(trace.success);
    }

    /// `run_until` honours `match_modules` filtering just like `run`.
    #[tokio::test]
    async fn run_until_applies_match_modules_filtering() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let filtered_count = Arc::new(AtomicUsize::new(0));
        let steps: Vec<Box<dyn Step>> = vec![
            Box::new(CountingStep {
                name: "filtered_step".into(),
                invocations: Arc::clone(&filtered_count),
                match_modules: Some(vec!["api.*".into()]),
            }),
            Box::new(CountingStep {
                name: "execute".into(),
                invocations: Arc::new(AtomicUsize::new(0)),
                match_modules: None,
            }),
        ];
        let strategy = ExecutionStrategy::new("test", steps).unwrap();
        let mut pctx = PipelineContext::new(
            "other.mod",
            serde_json::json!({}),
            Context::<serde_json::Value>::anonymous(),
            "test",
        );
        PipelineEngine::run_until(&strategy, &mut pctx, "execute")
            .await
            .unwrap();
        assert_eq!(
            filtered_count.load(Ordering::SeqCst),
            0,
            "step with match_modules=['api.*'] must be skipped when module_id='other.mod' \
             — confirms streaming and non-streaming paths share dispatch semantics"
        );
    }
}