use std::collections::HashMap;
use std::time::Instant;
use tracing::{error, info, warn};
use crate::error::Result;
use super::super::PipelineOptions;
use super::super::stages::IndexStage;
use super::context::{IndexContext, IndexInput, IndexResult, StageResult};
use super::policy::FailurePolicy;
/// A registered stage together with the metadata used to schedule it.
struct StageEntry {
    /// The stage implementation that will be executed.
    stage: Box<dyn IndexStage>,
    /// Ordering hint: among stages whose dependencies are all satisfied,
    /// lower values are scheduled first.
    priority: i32,
    /// Names of the stages that must run before this one.
    depends_on: Vec<String>,
}

impl std::fmt::Debug for StageEntry {
    /// Renders the stage by its name rather than by value, followed by its
    /// priority and dependency list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            stage,
            priority,
            depends_on,
        } = self;
        let mut out = f.debug_struct("StageEntry");
        out.field("name", &stage.name());
        out.field("priority", priority);
        out.field("depends_on", depends_on);
        out.finish()
    }
}
/// A set of stages that sit at the same dependency depth.
///
/// A stage's depth is one more than the deepest of its dependencies (0 when it
/// has none), so no member of a group depends on another member of the same
/// group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionGroup {
    /// Indices of the member stages (in the orchestrator's registration list),
    /// kept in resolved execution order.
    pub stage_indices: Vec<usize>,
    /// `true` when the group has more than one member, i.e. its stages could in
    /// principle run concurrently. This flag is informational only; the
    /// orchestrator still runs group members one after another.
    pub parallel: bool,
}
/// Builds and runs an ordered pipeline of [`IndexStage`]s.
///
/// Stages are ordered so that every stage runs after the stages it depends on.
/// Among stages that are ready at the same time, lower priorities run first,
/// with ties broken by registration order.
#[derive(Debug)]
pub struct PipelineOrchestrator {
    /// Registered stages, in registration order.
    stages: Vec<StageEntry>,
}
impl Default for PipelineOrchestrator {
fn default() -> Self {
Self::new()
}
}
impl PipelineOrchestrator {
pub fn new() -> Self {
Self { stages: Vec::new() }
}
pub fn stage<S>(mut self, stage: S) -> Self
where
S: IndexStage + 'static,
{
let deps = stage.depends_on();
self.stages.push(StageEntry {
stage: Box::new(stage),
priority: 100,
depends_on: deps.into_iter().map(|s| s.to_string()).collect(),
});
self
}
pub fn stage_with_priority<S>(mut self, stage: S, priority: i32) -> Self
where
S: IndexStage + 'static,
{
let deps = stage.depends_on();
self.stages.push(StageEntry {
stage: Box::new(stage),
priority,
depends_on: deps.into_iter().map(|s| s.to_string()).collect(),
});
self
}
pub fn stage_with_deps<S>(
mut self,
stage: S,
priority: i32,
explicit_depends_on: &[&str],
) -> Self
where
S: IndexStage + 'static,
{
let trait_deps = stage.depends_on();
let mut all_deps: Vec<String> = trait_deps.into_iter().map(|s| s.to_string()).collect();
for dep in explicit_depends_on {
if !all_deps.iter().any(|d| d == dep) {
all_deps.push(dep.to_string());
}
}
self.stages.push(StageEntry {
stage: Box::new(stage),
priority,
depends_on: all_deps,
});
self
}
pub fn remove_stage(mut self, name: &str) -> Self {
self.stages.retain(|entry| entry.stage.name() != name);
self
}
pub fn has_stage(&self, name: &str) -> bool {
self.stages.iter().any(|entry| entry.stage.name() == name)
}
pub fn stage_count(&self) -> usize {
self.stages.len()
}
fn resolve_order(&self) -> Result<Vec<usize>> {
let name_to_idx: HashMap<&str, usize> = self
.stages
.iter()
.enumerate()
.map(|(i, entry)| (entry.stage.name(), i))
.collect();
for entry in &self.stages {
for dep in &entry.depends_on {
if !name_to_idx.contains_key(dep.as_str()) {
return Err(crate::error::Error::Config(format!(
"Stage '{}' depends on non-existent stage '{}'",
entry.stage.name(),
dep
)));
}
}
}
let n = self.stages.len();
let mut in_degree: Vec<usize> = vec![0; n];
let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
for (i, entry) in self.stages.iter().enumerate() {
for dep in &entry.depends_on {
if let Some(&dep_idx) = name_to_idx.get(dep.as_str()) {
adjacency.entry(dep_idx).or_default().push(i);
in_degree[i] += 1;
}
}
}
let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
ready.sort_by_key(|&i| (self.stages[i].priority, i));
let mut result: Vec<usize> = Vec::new();
while let Some(idx) = ready.first().cloned() {
ready.remove(0);
result.push(idx);
if let Some(neighbors) = adjacency.get(&idx) {
for &neighbor in neighbors {
in_degree[neighbor] -= 1;
if in_degree[neighbor] == 0 {
let entry = &self.stages[neighbor];
let pos = ready
.binary_search_by_key(&(entry.priority, neighbor), |&i| {
(self.stages[i].priority, i)
})
.unwrap_or_else(|e| e);
ready.insert(pos, neighbor);
}
}
}
}
if result.len() != n {
let remaining: Vec<&str> = result
.iter()
.filter(|&&i| !result.contains(&i))
.map(|&i| self.stages[i].stage.name())
.collect();
return Err(crate::error::Error::Config(format!(
"Circular dependency detected involving stages: {:?}",
remaining
)));
}
Ok(result)
}
fn compute_execution_groups(&self, order: &[usize]) -> Vec<ExecutionGroup> {
if order.is_empty() {
return Vec::new();
}
let name_to_idx: HashMap<&str, usize> = self
.stages
.iter()
.enumerate()
.map(|(i, entry)| (entry.stage.name(), i))
.collect();
let mut levels: HashMap<usize, usize> = HashMap::new();
for &idx in order {
let entry = &self.stages[idx];
let level = if entry.depends_on.is_empty() {
0
} else {
entry
.depends_on
.iter()
.filter_map(|dep| {
name_to_idx
.get(dep.as_str())
.and_then(|&dep_idx| levels.get(&dep_idx))
})
.max()
.map(|&l| l + 1)
.unwrap_or(0)
};
levels.insert(idx, level);
}
let mut level_groups: HashMap<usize, Vec<usize>> = HashMap::new();
for &idx in order {
let level = levels[&idx];
level_groups.entry(level).or_default().push(idx);
}
let max_level = *levels.values().max().unwrap_or(&0);
(0..=max_level)
.filter_map(|level| {
level_groups.get(&level).map(|indices| ExecutionGroup {
stage_indices: indices.clone(),
parallel: indices.len() > 1,
})
})
.collect()
}
async fn execute_stage_with_policy(
stage: &mut Box<dyn IndexStage>,
ctx: &mut IndexContext,
) -> Result<StageResult> {
let policy = stage.failure_policy();
let stage_name = stage.name().to_string();
match policy {
FailurePolicy::Fail => {
stage.execute(ctx).await
}
FailurePolicy::Skip => {
match stage.execute(ctx).await {
Ok(result) => Ok(result),
Err(e) => {
warn!("Stage {} failed, skipping: {}", stage_name, e);
Ok(StageResult::failure(&stage_name, &e.to_string()))
}
}
}
FailurePolicy::Retry(config) => {
let mut attempts = 0;
loop {
attempts += 1;
match stage.execute(ctx).await {
Ok(result) => {
if attempts > 1 {
info!("Stage {} succeeded on attempt {}", stage_name, attempts);
}
return Ok(result);
}
Err(e) => {
if attempts >= config.max_attempts {
warn!(
"Stage {} failed after {} attempts: {}",
stage_name, attempts, e
);
return Err(e);
}
let delay = config.delay_for_attempt(attempts - 1);
warn!(
"Stage {} failed on attempt {}, retrying in {:?}: {}",
stage_name, attempts, delay, e
);
tokio::time::sleep(delay).await;
}
}
}
}
}
}
pub async fn execute(
&mut self,
input: IndexInput,
options: PipelineOptions,
) -> Result<IndexResult> {
let total_start = Instant::now();
info!(
"Starting orchestrated pipeline with {} stages",
self.stages.len()
);
let order = self.resolve_order()?;
let stage_names: Vec<&str> = order.iter().map(|&i| self.stages[i].stage.name()).collect();
info!("Execution order: {:?}", stage_names);
let groups = self.compute_execution_groups(&order);
info!(
"Execution groups: {} ({} parallelizable)",
groups.len(),
groups.iter().filter(|g| g.parallel).count()
);
let mut ctx = IndexContext::new(input, options);
for (group_idx, group) in groups.iter().enumerate() {
if group.parallel {
info!(
"Executing parallel group {} with {} stages: {:?}",
group_idx,
group.stage_indices.len(),
group
.stage_indices
.iter()
.map(|&i| self.stages[i].stage.name())
.collect::<Vec<_>>()
);
}
for &idx in &group.stage_indices {
let entry = &mut self.stages[idx];
let stage_name = entry.stage.name().to_string();
let policy = entry.stage.failure_policy();
info!(
"Executing stage: {} (priority {})",
stage_name, entry.priority
);
match Self::execute_stage_with_policy(&mut entry.stage, &mut ctx).await {
Ok(result) => {
ctx.stage_results.insert(stage_name.clone(), result);
}
Err(e) => {
if policy.allows_continuation() {
warn!(
"Stage {} failed but policy allows continuation: {}",
stage_name, e
);
ctx.stage_results.insert(
stage_name.clone(),
StageResult::failure(&stage_name, &e.to_string()),
);
} else {
error!("Stage {} failed, stopping pipeline: {}", stage_name, e);
return Err(e);
}
}
}
}
}
let total_duration = total_start.elapsed().as_millis() as u64;
info!(
"Orchestrated pipeline completed in {}ms for document {}",
total_duration, ctx.name
);
Ok(ctx.finalize())
}
pub fn stage_names(&self) -> Result<Vec<&str>> {
let order = self.resolve_order()?;
Ok(order.iter().map(|&i| self.stages[i].stage.name()).collect())
}
pub fn get_execution_groups(&self) -> Result<Vec<ExecutionGroup>> {
let order = self.resolve_order()?;
Ok(self.compute_execution_groups(&order))
}
}
/// Builder that collects the configuration of a custom stage: its name,
/// priority, dependencies, and whether it is optional.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CustomStageBuilder {
    name: String,
    priority: i32,
    depends_on: Vec<String>,
    optional: bool,
}

impl CustomStageBuilder {
    /// Starts a builder named `name` with priority 100, no dependencies and
    /// `optional` unset.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            priority: 100,
            depends_on: Vec::new(),
            optional: false,
        }
    }

    /// Sets the priority, replacing any previous value.
    #[must_use]
    pub fn priority(mut self, priority: i32) -> Self {
        self.priority = priority;
        self
    }

    /// Appends `stage` to the dependency list. Names are kept in call order
    /// and duplicates are not filtered.
    #[must_use]
    pub fn depends_on(mut self, stage: impl Into<String>) -> Self {
        self.depends_on.push(stage.into());
        self
    }

    /// Marks the stage as optional.
    #[must_use]
    pub fn optional(mut self) -> Self {
        self.optional = true;
        self
    }

    /// Returns the configured stage name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the configured priority.
    pub fn get_priority(&self) -> i32 {
        self.priority
    }

    /// Returns the dependency names in the order they were added.
    pub fn get_deps(&self) -> &[String] {
        &self.depends_on
    }

    /// Returns `true` if [`optional`](Self::optional) was called.
    pub fn is_optional(&self) -> bool {
        self.optional
    }
}
#[cfg(test)]
mod tests {
    use super::super::context::StageResult;
    use super::*;

    #[test]
    fn test_orchestrator_creation() {
        let orchestrator = PipelineOrchestrator::new();
        assert_eq!(orchestrator.stage_count(), 0);
    }

    #[test]
    fn test_add_stages() {
        let orchestrator = PipelineOrchestrator::new()
            .stage_with_priority(MockStage::new("a"), 10)
            .stage_with_priority(MockStage::new("b"), 20)
            .stage_with_priority(MockStage::new("c"), 5);
        assert_eq!(orchestrator.stage_count(), 3);
        // Without dependencies, stages run in ascending priority order.
        let names = orchestrator.stage_names().unwrap();
        assert_eq!(names, vec!["c", "a", "b"]);
    }

    #[test]
    fn test_equal_priority_keeps_registration_order() {
        let orchestrator = PipelineOrchestrator::new()
            .stage(MockStage::new("x"))
            .stage(MockStage::new("y"))
            .stage(MockStage::new("z"));
        assert_eq!(orchestrator.stage_names().unwrap(), vec!["x", "y", "z"]);
    }

    #[test]
    fn test_dependency_resolution() {
        // Dependencies take precedence over priority: `c` has the lowest
        // priority value but must still wait for `b`, which waits for `a`.
        let orchestrator = PipelineOrchestrator::new()
            .stage_with_priority(MockStage::new("a"), 10)
            .stage_with_deps(MockStage::new("b"), 5, &["a"])
            .stage_with_deps(MockStage::new("c"), 1, &["b"]);
        let names = orchestrator.stage_names().unwrap();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_missing_dependency() {
        let orchestrator =
            PipelineOrchestrator::new().stage_with_deps(MockStage::new("a"), 10, &["nonexistent"]);
        let result = orchestrator.stage_names();
        assert!(result.is_err());
    }

    #[test]
    fn test_circular_dependency() {
        let orchestrator = PipelineOrchestrator::new()
            .stage_with_deps(MockStage::new("a"), 10, &["b"])
            .stage_with_deps(MockStage::new("b"), 10, &["a"]);
        assert!(orchestrator.stage_names().is_err());
        assert!(orchestrator.get_execution_groups().is_err());
    }

    #[test]
    fn test_execution_groups() {
        // `b` and `c` both depend only on `a`, so they share a level;
        // `d` needs both and lands one level deeper.
        let orchestrator = PipelineOrchestrator::new()
            .stage_with_priority(MockStage::new("a"), 10)
            .stage_with_deps(MockStage::new("b"), 5, &["a"])
            .stage_with_deps(MockStage::new("c"), 1, &["a"])
            .stage_with_deps(MockStage::new("d"), 1, &["b", "c"]);
        let groups = orchestrator.get_execution_groups().unwrap();
        let indices: Vec<Vec<usize>> = groups.iter().map(|g| g.stage_indices.clone()).collect();
        assert_eq!(indices, vec![vec![0], vec![2, 1], vec![3]]);
        let parallel: Vec<bool> = groups.iter().map(|g| g.parallel).collect();
        assert_eq!(parallel, vec![false, true, false]);
    }

    #[test]
    fn test_remove_stage() {
        let orchestrator = PipelineOrchestrator::new()
            .stage(MockStage::new("a"))
            .stage(MockStage::new("b"))
            .remove_stage("a");
        assert_eq!(orchestrator.stage_count(), 1);
        assert!(!orchestrator.has_stage("a"));
        assert!(orchestrator.has_stage("b"));
    }

    #[test]
    fn test_custom_stage_builder() {
        let builder = CustomStageBuilder::new("my_stage")
            .priority(50)
            .depends_on("parse")
            .optional();
        assert_eq!(builder.name(), "my_stage");
        assert_eq!(builder.get_priority(), 50);
        assert_eq!(builder.get_deps(), &["parse".to_string()]);
        assert!(builder.is_optional());
    }

    /// Minimal stage that always succeeds and reports no dependencies of its own.
    struct MockStage {
        name: String,
    }

    impl MockStage {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }
    }

    #[async_trait::async_trait]
    impl IndexStage for MockStage {
        fn name(&self) -> &str {
            &self.name
        }

        async fn execute(&mut self, _ctx: &mut IndexContext) -> Result<StageResult> {
            Ok(StageResult::success(&self.name))
        }
    }
}