use std::collections::HashMap;
use std::time::Instant;
use tracing::{error, info, warn};
use crate::error::Result;
use super::super::PipelineOptions;
use super::super::stages::IndexStage;
use super::context::{IndexContext, IndexInput, PipelineResult, StageResult};
use super::policy::FailurePolicy;
/// A registered pipeline stage together with the metadata used to schedule it.
struct StageEntry {
/// The boxed stage implementation that is executed.
stage: Box<dyn IndexStage>,
/// Scheduling priority; among stages that are ready, lower values are ordered first.
priority: i32,
/// Names of the stages that must be ordered before this one.
depends_on: Vec<String>,
}
/// Hand-written `Debug` output for [`StageEntry`]: the stage is represented
/// by the value of its `name()` alongside the scheduling metadata.
impl std::fmt::Debug for StageEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stage_name = self.stage.name();
        let mut builder = f.debug_struct("StageEntry");
        builder.field("name", &stage_name);
        builder.field("priority", &self.priority);
        builder.field("depends_on", &self.depends_on);
        builder.finish()
    }
}
/// A set of stages that sit at the same dependency level of the pipeline.
#[derive(Debug, Clone)]
pub struct ExecutionGroup {
/// Indices into the orchestrator's stage list, kept in resolved execution order.
pub stage_indices: Vec<usize>,
/// `true` when the group holds more than one stage.
pub parallel: bool,
}
/// Collects index stages, orders them by declared dependencies and priority,
/// and executes them against a shared `IndexContext`.
pub struct PipelineOrchestrator {
/// Registered stages in insertion order; ordering and grouping refer to them by index.
stages: Vec<StageEntry>,
}
impl Default for PipelineOrchestrator {
fn default() -> Self {
Self::new()
}
}
impl PipelineOrchestrator {
/// Creates an orchestrator with no stages registered.
pub fn new() -> Self {
    let stages = Vec::new();
    Self { stages }
}
/// Registers `stage` with the default priority of `100`.
///
/// The dependencies reported by the stage's own `depends_on()` are recorded.
/// Returns the orchestrator so calls can be chained.
pub fn stage<S>(self, stage: S) -> Self
where
    S: IndexStage + 'static,
{
    self.stage_with_priority(stage, 100)
}
/// Registers `stage` with an explicit `priority`.
///
/// The dependencies reported by the stage's own `depends_on()` are recorded.
/// Returns the orchestrator so calls can be chained.
pub fn stage_with_priority<S>(mut self, stage: S, priority: i32) -> Self
where
    S: IndexStage + 'static,
{
    let depends_on: Vec<String> = stage
        .depends_on()
        .into_iter()
        .map(|dep| dep.to_string())
        .collect();
    let entry = StageEntry {
        stage: Box::new(stage),
        priority,
        depends_on,
    };
    self.stages.push(entry);
    self
}
/// Registers `stage` with an explicit `priority` and additional dependencies.
///
/// The recorded dependency list starts with whatever the stage's own
/// `depends_on()` reports. Each name in `explicit_depends_on` is then
/// appended unless it is already present. Returns the orchestrator so calls
/// can be chained.
pub fn stage_with_deps<S>(
    mut self,
    stage: S,
    priority: i32,
    explicit_depends_on: &[&str],
) -> Self
where
    S: IndexStage + 'static,
{
    let mut merged: Vec<String> = stage
        .depends_on()
        .into_iter()
        .map(|dep| dep.to_string())
        .collect();
    for &extra in explicit_depends_on {
        let already_listed = merged.iter().any(|known| known.as_str() == extra);
        if !already_listed {
            merged.push(extra.to_string());
        }
    }
    let entry = StageEntry {
        stage: Box::new(stage),
        priority,
        depends_on: merged,
    };
    self.stages.push(entry);
    self
}
pub fn remove_stage(mut self, name: &str) -> Self {
self.stages.retain(|entry| entry.stage.name() != name);
self
}
/// Returns `true` when at least one registered stage is called `name`.
pub fn has_stage(&self, name: &str) -> bool {
    self.stages
        .iter()
        .map(|entry| entry.stage.name())
        .any(|registered| registered == name)
}
pub fn stage_count(&self) -> usize {
self.stages.len()
}
/// Computes a linear execution order that honours every declared dependency.
///
/// This is a topological sort: a stage becomes *ready* once all of its
/// dependencies have been emitted, and among ready stages the one with the
/// smallest `(priority, registration index)` pair is emitted next.
///
/// Returns indices into `self.stages`.
///
/// # Errors
///
/// Returns `Error::Config` when a stage names a dependency that is not
/// registered, or when the dependencies form a cycle (a stage depending on
/// itself counts as a cycle).
fn resolve_order(&self) -> Result<Vec<usize>> {
    let n = self.stages.len();
    let name_to_idx: HashMap<&str, usize> = self
        .stages
        .iter()
        .enumerate()
        .map(|(i, entry)| (entry.stage.name(), i))
        .collect();

    // Build the graph and validate dependency names in a single pass.
    // `dependents[d]` lists the stages waiting on stage `d`;
    // `in_degree[i]` counts the dependencies of stage `i` not yet emitted.
    let mut in_degree = vec![0usize; n];
    let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (i, entry) in self.stages.iter().enumerate() {
        for dep in &entry.depends_on {
            match name_to_idx.get(dep.as_str()) {
                Some(&dep_idx) => {
                    dependents[dep_idx].push(i);
                    in_degree[i] += 1;
                }
                None => {
                    return Err(crate::error::Error::Config(format!(
                        "Stage '{}' depends on non-existent stage '{}'",
                        entry.stage.name(),
                        dep
                    )));
                }
            }
        }
    }

    // Min-heap keyed by (priority, index). `Reverse` turns the std max-heap
    // into a min-heap; the index makes every key unique, so the order is
    // fully deterministic.
    let mut ready: std::collections::BinaryHeap<std::cmp::Reverse<(i32, usize)>> = (0..n)
        .filter(|&i| in_degree[i] == 0)
        .map(|i| std::cmp::Reverse((self.stages[i].priority, i)))
        .collect();

    let mut result: Vec<usize> = Vec::with_capacity(n);
    while let Some(std::cmp::Reverse((_, idx))) = ready.pop() {
        result.push(idx);
        for &neighbor in &dependents[idx] {
            in_degree[neighbor] -= 1;
            if in_degree[neighbor] == 0 {
                ready.push(std::cmp::Reverse((self.stages[neighbor].priority, neighbor)));
            }
        }
    }

    if result.len() != n {
        // A stage is emitted exactly when its in-degree reaches zero, so the
        // stages left over are precisely those still waiting on a dependency.
        let remaining: Vec<&str> = (0..n)
            .filter(|&i| in_degree[i] > 0)
            .map(|i| self.stages[i].stage.name())
            .collect();
        return Err(crate::error::Error::Config(format!(
            "Circular dependency detected involving stages: {:?}",
            remaining
        )));
    }
    Ok(result)
}
/// Partitions an already-resolved `order` into dependency levels.
///
/// A stage's level is one more than the highest level among its
/// dependencies that already have a level assigned, or `0` when there are
/// none. Stages that share a level form one [`ExecutionGroup`], listed in
/// the order they appear in `order`. Groups are returned by ascending level,
/// and a group is flagged `parallel` when it has more than one stage.
fn compute_execution_groups(&self, order: &[usize]) -> Vec<ExecutionGroup> {
    let index_of: HashMap<&str, usize> = self
        .stages
        .iter()
        .enumerate()
        .map(|(idx, entry)| (entry.stage.name(), idx))
        .collect();

    // Pass 1: assign a level to every stage, walking `order` so that
    // dependencies are seen before the stages that need them.
    let mut depth_of: HashMap<usize, usize> = HashMap::new();
    for &stage_idx in order {
        let depth = self.stages[stage_idx]
            .depends_on
            .iter()
            .filter_map(|dep| index_of.get(dep.as_str()))
            .filter_map(|dep_idx| depth_of.get(dep_idx))
            .map(|&dep_depth| dep_depth + 1)
            .max()
            .unwrap_or(0);
        depth_of.insert(stage_idx, depth);
    }

    // Pass 2: bucket the stages by level, preserving their relative order.
    let mut buckets: Vec<Vec<usize>> = Vec::new();
    for &stage_idx in order {
        let depth = depth_of[&stage_idx];
        if buckets.len() <= depth {
            buckets.resize_with(depth + 1, Vec::new);
        }
        buckets[depth].push(stage_idx);
    }

    buckets
        .into_iter()
        .filter(|members| !members.is_empty())
        .map(|stage_indices| ExecutionGroup {
            parallel: stage_indices.len() > 1,
            stage_indices,
        })
        .collect()
}
/// Runs one stage, applying the failure policy the stage itself reports.
///
/// * `Fail`  - the stage's result is returned unchanged.
/// * `Skip`  - an error is logged and converted into a successful return
///   carrying a `StageResult::failure`.
/// * `Retry` - the stage is re-executed, sleeping for
///   `config.delay_for_attempt(..)` between tries, until it succeeds or the
///   attempt count reaches `config.max_attempts`; the last error is then
///   returned.
async fn execute_stage_with_policy(
    stage: &mut Box<dyn IndexStage>,
    ctx: &mut IndexContext,
) -> Result<StageResult> {
    let policy = stage.failure_policy();
    let stage_name = stage.name().to_string();

    match policy {
        FailurePolicy::Fail => stage.execute(ctx).await,
        FailurePolicy::Skip => match stage.execute(ctx).await {
            Err(err) => {
                warn!("Stage {} failed, skipping: {}", stage_name, err);
                Ok(StageResult::failure(&stage_name, &err.to_string()))
            }
            ok => ok,
        },
        FailurePolicy::Retry(config) => {
            // `attempt` is the 1-based number of the try currently running.
            let mut attempt = 1;
            loop {
                let err = match stage.execute(ctx).await {
                    Ok(result) => {
                        if attempt > 1 {
                            info!("Stage {} succeeded on attempt {}", stage_name, attempt);
                        }
                        return Ok(result);
                    }
                    Err(err) => err,
                };

                if attempt >= config.max_attempts {
                    warn!(
                        "Stage {} failed after {} attempts: {}",
                        stage_name, attempt, err
                    );
                    return Err(err);
                }

                let delay = config.delay_for_attempt(attempt - 1);
                warn!(
                    "Stage {} failed on attempt {}, retrying in {:?}: {}",
                    stage_name, attempt, delay, err
                );
                tokio::time::sleep(delay).await;
                attempt += 1;
            }
        }
    }
}
fn handle_stage_result(
result: Result<StageResult>,
stage_name: &str,
policy: &FailurePolicy,
ctx: &mut IndexContext,
) -> Result<()> {
match result {
Ok(result) => {
ctx.stage_results.insert(stage_name.to_string(), result);
Ok(())
}
Err(e) => {
if policy.allows_continuation() {
warn!(
"Stage {} failed but policy allows continuation: {}",
stage_name, e
);
ctx.stage_results.insert(
stage_name.to_string(),
StageResult::failure(stage_name, &e.to_string()),
);
Ok(())
} else {
error!("Stage {} failed, stopping pipeline: {}", stage_name, e);
Err(e)
}
}
}
}
/// Runs every registered stage against `input` and returns the finalized
/// pipeline result.
///
/// Stages are ordered with `resolve_order` and partitioned into dependency
/// levels with `compute_execution_groups`. A group of exactly two stages is
/// run concurrently with `tokio::join!`: one stage works on the shared
/// context and the other on a separate context seeded with clones of the
/// tree, options and document identity. Afterwards only the second stage's
/// reasoning index / description (when its access pattern says it writes
/// them) and its metrics are merged back.
///
/// Because the second stage's tree is never merged back, a pair in which
/// *both* stages write the tree is run sequentially instead; otherwise one
/// stage's tree changes would be silently lost. Groups of any other size
/// are also run sequentially.
///
/// # Errors
///
/// Returns an error when the stage order cannot be resolved, or when a stage
/// fails and its failure policy does not allow continuation.
pub async fn execute(
    &mut self,
    input: IndexInput,
    options: PipelineOptions,
) -> Result<PipelineResult> {
    let total_start = Instant::now();
    info!(
        "Starting orchestrated pipeline with {} stages",
        self.stages.len()
    );

    let order = self.resolve_order()?;
    let stage_names: Vec<&str> = order.iter().map(|&i| self.stages[i].stage.name()).collect();
    info!("Execution order: {:?}", stage_names);

    let groups = self.compute_execution_groups(&order);
    info!(
        "Execution groups: {} ({} parallelizable)",
        groups.len(),
        groups.iter().filter(|g| g.parallel).count()
    );

    // The existing tree is taken out of the options and attached to the
    // context through `with_existing_tree`.
    let mut opts = options;
    let existing_tree = opts.existing_tree.take();
    let mut ctx = IndexContext::new(input, opts);
    if let Some(tree) = existing_tree {
        ctx = ctx.with_existing_tree(tree);
    }

    for (group_idx, group) in groups.iter().enumerate() {
        if group.parallel {
            info!(
                "Executing parallel group {} with {} stages: {:?}",
                group_idx,
                group.stage_indices.len(),
                group
                    .stage_indices
                    .iter()
                    .map(|&i| self.stages[i].stage.name())
                    .collect::<Vec<_>>()
            );
        }

        // Decide whether this group can take the two-stage concurrent path
        // and, if so, which stage gets the shared context ("writer") and
        // which gets the snapshot ("reader").
        let parallel_pair = if group.parallel && group.stage_indices.len() == 2 {
            let idx_a = group.stage_indices[0];
            let idx_b = group.stage_indices[1];
            let a_writes_tree = self.stages[idx_a].stage.access_pattern().writes_tree;
            let b_writes_tree = self.stages[idx_b].stage.access_pattern().writes_tree;
            if a_writes_tree && b_writes_tree {
                // The snapshot's tree is discarded after the join, so two
                // tree writers must not be split across contexts.
                warn!(
                    "Stages {} and {} both write the tree; executing group {} sequentially",
                    self.stages[idx_a].stage.name(),
                    self.stages[idx_b].stage.name(),
                    group_idx
                );
                None
            } else if b_writes_tree {
                Some((idx_b, idx_a))
            } else {
                Some((idx_a, idx_b))
            }
        } else {
            None
        };

        if let Some((writer_idx, reader_idx)) = parallel_pair {
            let tree_snapshot = ctx.tree.clone();
            let options_snapshot = ctx.options.clone();
            let existing_tree_snapshot = ctx.existing_tree.clone();

            // Move both stages out of `self.stages` (leaving placeholders)
            // so each can be borrowed mutably while the two futures run.
            let mut stage_writer =
                std::mem::replace(&mut self.stages[writer_idx].stage, Box::new(NopStage));
            let mut stage_reader =
                std::mem::replace(&mut self.stages[reader_idx].stage, Box::new(NopStage));
            let writer_name = stage_writer.name().to_string();
            let reader_name = stage_reader.name().to_string();
            let writer_policy = stage_writer.failure_policy();
            let reader_policy = stage_reader.failure_policy();
            info!("Parallel: executing {} ∥ {}", writer_name, reader_name);

            // The reader works on its own context seeded from the snapshots.
            let mut reader_ctx = IndexContext::new(IndexInput::content(""), options_snapshot);
            reader_ctx.tree = tree_snapshot;
            reader_ctx.existing_tree = existing_tree_snapshot;
            reader_ctx.doc_id = ctx.doc_id.clone();
            reader_ctx.name = ctx.name.clone();
            reader_ctx.format = ctx.format;
            reader_ctx.source_path = ctx.source_path.clone();

            let (writer_result, reader_result) = tokio::join!(
                Self::execute_stage_with_policy(&mut stage_writer, &mut ctx),
                Self::execute_stage_with_policy(&mut stage_reader, &mut reader_ctx),
            );

            // Put the real stages back before anything below can return
            // early, so the orchestrator is never left holding placeholders.
            self.stages[writer_idx].stage = stage_writer;
            self.stages[reader_idx].stage = stage_reader;

            Self::handle_stage_result(writer_result, &writer_name, &writer_policy, &mut ctx)?;
            Self::handle_stage_result(reader_result, &reader_name, &reader_policy, &mut ctx)?;

            // Merge back only what the reader declares it writes.
            let reader_ap = self.stages[reader_idx].stage.access_pattern();
            if reader_ap.writes_reasoning_index {
                ctx.reasoning_index = reader_ctx.reasoning_index;
            }
            if reader_ap.writes_description {
                ctx.description = reader_ctx.description;
            }

            // The reader context started from fresh metrics, so its counters
            // are pure deltas that can be added to the shared totals.
            ctx.metrics.llm_calls += reader_ctx.metrics.llm_calls;
            ctx.metrics.summaries_generated += reader_ctx.metrics.summaries_generated;
            ctx.metrics.total_tokens_generated += reader_ctx.metrics.total_tokens_generated;
            ctx.metrics.nodes_processed += reader_ctx.metrics.nodes_processed;
            if reader_ctx.metrics.reasoning_index_time_ms > 0 {
                ctx.metrics.record_reasoning_index(
                    reader_ctx.metrics.reasoning_index_time_ms,
                    reader_ctx.metrics.topics_indexed,
                    reader_ctx.metrics.keywords_indexed,
                );
            }
            if reader_ctx.metrics.optimize_time_ms > 0 {
                ctx.metrics
                    .record_optimize(reader_ctx.metrics.optimize_time_ms);
            }
            ctx.metrics.nodes_merged += reader_ctx.metrics.nodes_merged;
            ctx.metrics.nodes_skipped += reader_ctx.metrics.nodes_skipped;
        } else {
            // Sequential path: every stage of the group runs on the shared
            // context, in resolved order.
            for &idx in &group.stage_indices {
                let entry = &mut self.stages[idx];
                let stage_name = entry.stage.name().to_string();
                let policy = entry.stage.failure_policy();
                info!(
                    "Executing stage: {} (priority {})",
                    stage_name, entry.priority
                );
                let outcome = Self::execute_stage_with_policy(&mut entry.stage, &mut ctx).await;
                Self::handle_stage_result(outcome, &stage_name, &policy, &mut ctx)?;
            }
        }
    }

    let total_duration = total_start.elapsed().as_millis() as u64;
    info!(
        "Orchestrated pipeline completed in {}ms for document {}",
        total_duration, ctx.name
    );
    Ok(ctx.finalize())
}
/// Returns the stage names in resolved execution order.
///
/// # Errors
///
/// Fails when the dependency order cannot be resolved.
pub fn stage_names(&self) -> Result<Vec<&str>> {
    let names = self
        .resolve_order()?
        .into_iter()
        .map(|idx| self.stages[idx].stage.name())
        .collect();
    Ok(names)
}
/// Returns the execution groups computed for the current set of stages.
///
/// # Errors
///
/// Fails when the dependency order cannot be resolved.
pub fn get_execution_groups(&self) -> Result<Vec<ExecutionGroup>> {
    self.resolve_order()
        .map(|order| self.compute_execution_groups(&order))
}
}
/// Placeholder stage that does nothing; it is swapped into a slot while the
/// real stage is temporarily moved out of the orchestrator.
struct NopStage;

impl NopStage {
    /// Name reported by the placeholder and used in its result.
    const NAME: &'static str = "_nop";
}

#[async_trait::async_trait]
impl IndexStage for NopStage {
    fn name(&self) -> &'static str {
        NopStage::NAME
    }

    /// Always succeeds without touching the context.
    async fn execute(&mut self, _ctx: &mut IndexContext) -> Result<StageResult> {
        Ok(StageResult::success(NopStage::NAME))
    }
}
/// Fluent description of a custom stage: its name, priority, dependencies
/// and whether it is optional.
pub struct CustomStageBuilder {
    /// Name of the stage being described.
    name: String,
    /// Scheduling priority; `100` unless overridden.
    priority: i32,
    /// Names of stages this stage depends on, in the order they were added.
    depends_on: Vec<String>,
    /// Whether the stage has been marked optional.
    optional: bool,
}

impl CustomStageBuilder {
    /// Starts a description for a stage called `name`, with priority `100`,
    /// no dependencies, and not optional.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            priority: 100,
            depends_on: Vec::new(),
            optional: false,
        }
    }

    /// Replaces the priority.
    pub fn priority(self, priority: i32) -> Self {
        Self { priority, ..self }
    }

    /// Appends `stage` to the dependency list.
    pub fn depends_on(self, stage: impl Into<String>) -> Self {
        let Self {
            name,
            priority,
            mut depends_on,
            optional,
        } = self;
        depends_on.push(stage.into());
        Self {
            name,
            priority,
            depends_on,
            optional,
        }
    }

    /// Marks the stage as optional.
    pub fn optional(self) -> Self {
        Self {
            optional: true,
            ..self
        }
    }

    /// Name of the described stage.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Priority of the described stage.
    pub fn get_priority(&self) -> i32 {
        self.priority
    }

    /// Dependencies added so far, in insertion order.
    pub fn get_deps(&self) -> &[String] {
        self.depends_on.as_slice()
    }

    /// Whether the stage was marked optional.
    pub fn is_optional(&self) -> bool {
        self.optional
    }
}
#[cfg(test)]
mod tests {
    use super::super::context::StageResult;
    use super::*;

    /// Minimal stage used by the tests: it only carries a name and always
    /// succeeds.
    struct MockStage {
        name: String,
    }

    impl MockStage {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }
    }

    #[async_trait::async_trait]
    impl IndexStage for MockStage {
        fn name(&self) -> &str {
            &self.name
        }

        async fn execute(&mut self, _ctx: &mut IndexContext) -> Result<StageResult> {
            Ok(StageResult::success(&self.name))
        }
    }

    /// A fresh orchestrator has no stages.
    #[test]
    fn test_orchestrator_creation() {
        let pipeline = PipelineOrchestrator::new();
        assert_eq!(pipeline.stage_count(), 0);
    }

    /// Independent stages are ordered by ascending priority.
    #[test]
    fn test_add_stages() {
        let pipeline = PipelineOrchestrator::new()
            .stage_with_priority(MockStage::new("a"), 10)
            .stage_with_priority(MockStage::new("b"), 20)
            .stage_with_priority(MockStage::new("c"), 5);
        assert_eq!(pipeline.stage_count(), 3);

        let ordered = pipeline.stage_names().unwrap();
        assert_eq!(ordered, vec!["c", "a", "b"]);
    }

    /// Dependencies take precedence over priority: "c" has the lowest
    /// priority value but must still come after "b", which comes after "a".
    #[test]
    fn test_dependency_resolution() {
        let pipeline = PipelineOrchestrator::new()
            .stage_with_priority(MockStage::new("a"), 10)
            .stage_with_deps(MockStage::new("b"), 5, &["a"])
            .stage_with_deps(MockStage::new("c"), 1, &["b"]);

        let ordered = pipeline.stage_names().unwrap();
        assert_eq!(ordered, vec!["a", "b", "c"]);
    }

    /// Depending on an unregistered stage is reported as an error.
    #[test]
    fn test_missing_dependency() {
        let pipeline = PipelineOrchestrator::new().stage_with_deps(
            MockStage::new("a"),
            10,
            &["nonexistent"],
        );
        let outcome = pipeline.stage_names();
        assert!(outcome.is_err());
    }

    /// Removing a stage by name leaves the other stages registered.
    #[test]
    fn test_remove_stage() {
        let pipeline = PipelineOrchestrator::new()
            .stage(MockStage::new("a"))
            .stage(MockStage::new("b"))
            .remove_stage("a");
        assert_eq!(pipeline.stage_count(), 1);
        assert!(!pipeline.has_stage("a"));
        assert!(pipeline.has_stage("b"));
    }

    /// Every builder setter is reflected by the matching getter.
    #[test]
    fn test_custom_stage_builder() {
        let described = CustomStageBuilder::new("my_stage")
            .priority(50)
            .depends_on("parse")
            .optional();
        assert_eq!(described.name(), "my_stage");
        assert_eq!(described.get_priority(), 50);
        assert_eq!(described.get_deps(), &["parse".to_string()]);
        assert!(described.is_optional());
    }
}