mod persistence;
mod planner;
mod router;
mod tick;
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use super::cascade::{CascadeConfig, CascadeDetector};
use super::dag;
use super::error::OrchestrationError;
use super::graph::{GraphStatus, TaskGraph, TaskId, TaskStatus};
use super::router::AgentRouter;
use super::topology::{TopologyAnalysis, TopologyClassifier};
pub(super) use super::verifier::inject_tasks as verifier_inject_tasks;
use zeph_config::OrchestrationConfig;
use zeph_sanitizer::{ContentIsolationConfig, ContentSanitizer};
/// An action the scheduler emits for its driver (the orchestrator loop) to carry out.
///
/// NOTE(review): variant semantics are inferred from names/fields here; confirm
/// against the consuming loop in the parent module.
#[derive(Debug)]
pub enum SchedulerAction {
    /// Spawn a sub-agent of definition `agent_def_name` to work on `task_id`
    /// with the given prompt.
    Spawn {
        task_id: TaskId,
        agent_def_name: String,
        prompt: String,
    },
    /// Cancel the agent identified by `agent_handle_id`.
    Cancel {
        agent_handle_id: String,
    },
    /// Run the task inline (without spawning a separate agent — presumably in
    /// the orchestrator's own context; confirm with the driver).
    RunInline {
        task_id: TaskId,
        prompt: String,
    },
    /// The graph reached a terminal state; no further actions will follow.
    Done {
        status: GraphStatus,
    },
    /// Verify a task's output against a structured predicate.
    VerifyPredicate {
        task_id: TaskId,
        predicate: super::verify_predicate::VerifyPredicate,
        output: String,
    },
    /// Verify a task's output for completeness.
    Verify {
        task_id: TaskId,
        output: String,
    },
}
/// Notification sent back to the scheduler when a running task finishes
/// (successfully or not). Producers obtain the sender via `event_sender()`.
#[derive(Debug)]
pub struct TaskEvent {
    /// The task this event refers to.
    pub task_id: TaskId,
    /// Handle id of the agent that was executing the task.
    pub agent_handle_id: String,
    /// How the task ended.
    pub outcome: TaskOutcome,
}
/// Terminal result of a single task execution.
#[derive(Debug)]
pub enum TaskOutcome {
    /// The task succeeded, producing textual output and zero or more artifact paths.
    Completed {
        output: String,
        artifacts: Vec<PathBuf>,
    },
    /// The task failed with the given error description.
    Failed {
        error: String,
    },
}
/// Bookkeeping for one in-flight task (an entry in `DagScheduler::running`).
pub(super) struct RunningTask {
    /// Handle id of the agent executing the task.
    pub(super) agent_handle_id: String,
    /// Name of the agent definition that was routed to this task.
    pub(super) agent_def_name: String,
    /// When execution started; compared against `task_timeout` for timeouts.
    pub(super) started_at: Instant,
}
/// Drives execution of a [`TaskGraph`]: routes ready tasks to agents via the
/// [`AgentRouter`], tracks in-flight work, and applies the verification,
/// replanning, and cascade policies configured in `OrchestrationConfig`.
#[allow(clippy::struct_excessive_bools)] // independent feature toggles mirrored from config
pub struct DagScheduler {
    /// Graph under execution; transitioned to `Running` at construction.
    pub(super) graph: TaskGraph,
    /// Effective concurrency cap, taken from the topology analysis.
    pub(super) max_parallel: usize,
    /// Raw `max_parallel` from config, kept alongside the topology-derived cap.
    pub(super) config_max_parallel: usize,
    /// In-flight tasks keyed by task id.
    pub(super) running: HashMap<TaskId, RunningTask>,
    /// Receiving side of the task-completion event channel.
    pub(super) event_rx: mpsc::Receiver<TaskEvent>,
    /// Prototype sender handed out via `event_sender()`.
    pub(super) event_tx: mpsc::Sender<TaskEvent>,
    /// Per-task wall-clock timeout (config value, or 10 minutes when unset).
    pub(super) task_timeout: Duration,
    /// Chooses an agent definition for each task.
    pub(super) router: Box<dyn AgentRouter>,
    /// Agent definitions the router can pick from.
    pub(super) available_agents: Vec<zeph_subagent::SubAgentDef>,
    /// Budget for dependency output included in prompts
    /// (units not visible here — presumably chars or tokens; confirm in tick/planner).
    pub(super) dependency_context_budget: usize,
    /// Events held back for later processing
    /// (NOTE(review): consumption happens in the `tick` module — confirm semantics there).
    pub(super) buffered_events: VecDeque<TaskEvent>,
    /// Sanitizes agent-produced content before reuse.
    pub(super) sanitizer: ContentSanitizer,
    /// Backoff applied when dispatch is deferred (`deferral_backoff_ms` config).
    pub(super) deferral_backoff: Duration,
    /// Consecutive spawn failures observed (starts at 0).
    pub(super) consecutive_spawn_failures: u32,
    /// Topology classification of the current graph.
    pub(super) topology: TopologyAnalysis,
    /// Marks the topology analysis as stale
    /// (NOTE(review): presumably set after graph mutation — confirm where toggled).
    pub(super) topology_dirty: bool,
    /// Current dispatch level; starts at 0
    /// (NOTE(review): presumably used by level-ordered dispatch — confirm in tick).
    pub(super) current_level: usize,
    /// Whether completeness verification is enabled (config).
    pub(super) verify_completeness: bool,
    /// Provider name for completeness verification (trimmed; empty = unset).
    pub(super) verify_provider: String,
    /// Replans performed per task.
    pub(super) task_replan_counts: HashMap<TaskId, u32>,
    /// Whole-plan replans performed so far.
    pub(super) global_replan_count: u32,
    /// Maximum allowed replans (config).
    pub(super) max_replans: u32,
    /// Completeness score threshold (config).
    pub(super) completeness_threshold_value: f32,
    /// Present only when cascade routing is active (requires topology selection).
    pub(super) cascade_detector: Option<CascadeDetector>,
    /// Tree-optimized dispatch toggle (config).
    pub(super) tree_optimized_dispatch: bool,
    /// True only when both `cascade_routing` and `topology_selection` are configured.
    pub(super) cascade_routing: bool,
    /// Error lineage chains per task, for cascade analysis.
    pub(super) lineage_chains: HashMap<TaskId, super::lineage::ErrorLineage>,
    /// Chain-length threshold for cascade handling (config).
    pub(super) cascade_chain_threshold: usize,
    /// Failure-rate threshold for aborting the run (config;
    /// NOTE(review): presumably 0.0 disables the abort — confirm in cascade logic).
    pub(super) cascade_failure_rate_abort_threshold: f32,
    /// Time-to-live for lineage entries, in seconds (config).
    pub(super) lineage_ttl_secs: u64,
    /// Whether predicate verification is enabled (config).
    pub(super) verify_predicate_enabled: bool,
    /// Provider name for predicate verification (trimmed; empty = unset).
    pub(super) predicate_provider: String,
    /// Cap on predicate-triggered replans (config).
    pub(super) max_predicate_replans: u32,
    /// Predicate-triggered replans used so far.
    pub(super) predicate_replans_used: u32,
    /// Per-task reason strings from predicate verification
    /// (NOTE(review): semantics inferred from the name — confirm where populated).
    pub(super) predicate_reasons: HashMap<TaskId, String>,
}
// Hand-written `Debug`: the struct cannot derive it (presumably because
// `Box<dyn AgentRouter>` and other fields lack `Debug` — confirm), so only a
// concise summary of scheduler state is printed.
impl std::fmt::Debug for DagScheduler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DagScheduler")
            .field("graph_id", &self.graph.id)
            .field("graph_status", &self.graph.status)
            .field("running_count", &self.running.len())
            .field("max_parallel", &self.max_parallel)
            .field("task_timeout_secs", &self.task_timeout.as_secs())
            .field("topology", &self.topology.topology)
            .field("strategy", &self.topology.strategy)
            .field("current_level", &self.current_level)
            .field("global_replan_count", &self.global_replan_count)
            .field("cascade_routing", &self.cascade_routing)
            .field("tree_optimized_dispatch", &self.tree_optimized_dispatch)
            // `..` marker signals that many fields are intentionally omitted.
            .finish_non_exhaustive()
    }
}
impl DagScheduler {
pub fn new(
mut graph: TaskGraph,
config: &OrchestrationConfig,
router: Box<dyn AgentRouter>,
available_agents: Vec<zeph_subagent::SubAgentDef>,
) -> Result<Self, OrchestrationError> {
if graph.status != GraphStatus::Created {
return Err(OrchestrationError::InvalidGraph(format!(
"graph must be in Created status, got {}",
graph.status
)));
}
dag::validate(&graph.tasks, config.max_tasks as usize)?;
graph.status = GraphStatus::Running;
for task in &mut graph.tasks {
if task.depends_on.is_empty() && task.status == TaskStatus::Pending {
task.status = TaskStatus::Ready;
}
}
let (event_tx, event_rx) = mpsc::channel(64);
let task_timeout = if config.task_timeout_secs > 0 {
Duration::from_secs(config.task_timeout_secs)
} else {
Duration::from_mins(10)
};
let topology = TopologyClassifier::analyze(&graph, config);
let max_parallel = topology.max_parallel;
let config_max_parallel = config.max_parallel as usize;
if config.topology_selection {
tracing::debug!(
topology = ?topology.topology,
strategy = ?topology.strategy,
max_parallel,
"topology-aware concurrency limit applied"
);
}
if config.cascade_routing && !config.topology_selection {
tracing::warn!(
"cascade_routing = true requires topology_selection = true; \
cascade routing is disabled (topology_selection is off)"
);
}
let cascade_detector = if config.cascade_routing && config.topology_selection {
Some(CascadeDetector::new(CascadeConfig {
failure_threshold: config.cascade_failure_threshold,
}))
} else {
None
};
Ok(Self {
graph,
max_parallel,
config_max_parallel,
running: HashMap::new(),
event_rx,
event_tx,
task_timeout,
router,
available_agents,
dependency_context_budget: config.dependency_context_budget,
buffered_events: VecDeque::new(),
sanitizer: ContentSanitizer::new(&ContentIsolationConfig::default()),
deferral_backoff: Duration::from_millis(config.deferral_backoff_ms),
consecutive_spawn_failures: 0,
topology,
topology_dirty: false,
current_level: 0,
verify_completeness: config.verify_completeness,
verify_provider: config.verify_provider.as_str().trim().to_owned(),
task_replan_counts: HashMap::new(),
global_replan_count: 0,
max_replans: config.max_replans,
completeness_threshold_value: config.completeness_threshold,
cascade_detector,
tree_optimized_dispatch: config.tree_optimized_dispatch,
cascade_routing: config.cascade_routing && config.topology_selection,
lineage_chains: HashMap::new(),
cascade_chain_threshold: config.cascade_chain_threshold,
cascade_failure_rate_abort_threshold: config.cascade_failure_rate_abort_threshold,
lineage_ttl_secs: config.lineage_ttl_secs,
verify_predicate_enabled: config.verify_predicate_enabled,
predicate_provider: config.predicate_provider.as_str().trim().to_owned(),
max_predicate_replans: config.max_predicate_replans,
predicate_replans_used: 0,
predicate_reasons: HashMap::new(),
})
}
pub fn resume_from(
mut graph: TaskGraph,
config: &OrchestrationConfig,
router: Box<dyn AgentRouter>,
available_agents: Vec<zeph_subagent::SubAgentDef>,
) -> Result<Self, OrchestrationError> {
if graph.status == GraphStatus::Completed || graph.status == GraphStatus::Canceled {
return Err(OrchestrationError::InvalidGraph(format!(
"cannot resume a {} graph; only Paused, Failed, or Running graphs are resumable",
graph.status
)));
}
graph.status = GraphStatus::Running;
let running: HashMap<TaskId, RunningTask> = graph
.tasks
.iter()
.filter(|t| t.status == TaskStatus::Running)
.filter_map(|t| {
let handle_id = t.assigned_agent.clone()?;
let def_name = t.agent_hint.clone().unwrap_or_default();
Some((
t.id,
RunningTask {
agent_handle_id: handle_id,
agent_def_name: def_name,
started_at: Instant::now(),
},
))
})
.collect();
let (event_tx, event_rx) = mpsc::channel(64);
let task_timeout = if config.task_timeout_secs > 0 {
Duration::from_secs(config.task_timeout_secs)
} else {
Duration::from_mins(10)
};
let topology = TopologyClassifier::analyze(&graph, config);
let max_parallel = topology.max_parallel;
let config_max_parallel = config.max_parallel as usize;
let cascade_detector = if config.cascade_routing && config.topology_selection {
Some(CascadeDetector::new(CascadeConfig {
failure_threshold: config.cascade_failure_threshold,
}))
} else {
None
};
Ok(Self {
graph,
max_parallel,
config_max_parallel,
running,
event_rx,
event_tx,
task_timeout,
router,
available_agents,
dependency_context_budget: config.dependency_context_budget,
buffered_events: VecDeque::new(),
sanitizer: ContentSanitizer::new(&ContentIsolationConfig::default()),
deferral_backoff: Duration::from_millis(config.deferral_backoff_ms),
consecutive_spawn_failures: 0,
topology,
topology_dirty: false,
current_level: 0,
verify_completeness: config.verify_completeness,
verify_provider: config.verify_provider.as_str().trim().to_owned(),
task_replan_counts: HashMap::new(),
global_replan_count: 0,
max_replans: config.max_replans,
completeness_threshold_value: config.completeness_threshold,
cascade_detector,
tree_optimized_dispatch: config.tree_optimized_dispatch,
cascade_routing: config.cascade_routing && config.topology_selection,
lineage_chains: HashMap::new(),
cascade_chain_threshold: config.cascade_chain_threshold,
cascade_failure_rate_abort_threshold: config.cascade_failure_rate_abort_threshold,
lineage_ttl_secs: config.lineage_ttl_secs,
verify_predicate_enabled: config.verify_predicate_enabled,
predicate_provider: config.predicate_provider.as_str().trim().to_owned(),
max_predicate_replans: config.max_predicate_replans,
predicate_replans_used: 0,
predicate_reasons: HashMap::new(),
})
}
pub fn validate_verify_config(
&self,
provider_names: &[&str],
) -> Result<(), OrchestrationError> {
if !self.verify_completeness {
return Ok(());
}
let name = self.verify_provider.as_str();
if name.is_empty() || provider_names.is_empty() {
return Ok(());
}
if !provider_names.contains(&name) {
return Err(OrchestrationError::InvalidConfig(format!(
"verify_provider \"{}\" not found in [[llm.providers]]; available: [{}]",
name,
provider_names.join(", ")
)));
}
Ok(())
}
#[must_use]
pub fn event_sender(&self) -> mpsc::Sender<TaskEvent> {
self.event_tx.clone()
}
#[must_use]
pub fn graph(&self) -> &TaskGraph {
&self.graph
}
#[must_use]
pub fn into_graph(&self) -> TaskGraph {
self.graph.clone()
}
#[must_use]
pub fn topology(&self) -> &TopologyAnalysis {
&self.topology
}
#[must_use]
pub fn completeness_threshold(&self) -> f32 {
self.completeness_threshold_value
}
#[must_use]
pub fn verify_provider_name(&self) -> &str {
&self.verify_provider
}
#[must_use]
pub fn predicate_provider_name(&self) -> &str {
&self.predicate_provider
}
#[must_use]
pub fn verify_predicate_enabled(&self) -> bool {
self.verify_predicate_enabled
}
#[must_use]
pub fn max_replans_remaining(&self) -> u32 {
self.max_replans.saturating_sub(self.global_replan_count)
}
pub fn record_whole_plan_replan(&mut self) {
self.global_replan_count = self.global_replan_count.saturating_add(1);
}
#[must_use]
pub fn has_running_tasks(&self) -> bool {
!self.running.is_empty()
}
}
// Dropping the scheduler does not synchronously cancel spawned agents, so warn
// loudly if work is still in flight at that point.
impl Drop for DagScheduler {
    fn drop(&mut self) {
        let in_flight = self.running.len();
        if in_flight == 0 {
            return;
        }
        tracing::warn!(
            running_tasks = in_flight,
            "DagScheduler dropped with running tasks; agents may continue until their \
            CancellationToken fires or they complete naturally"
        );
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::default_trait_access)]

    use super::*;
    use crate::graph::{GraphStatus, TaskGraph, TaskNode, TaskStatus};

    /// Builds a task node with the given id and dependency ids.
    pub(super) fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(
            id,
            format!("task-{id}"),
            format!("description for task {id}"),
        );
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wraps a set of nodes in a fresh graph (status left at its default).
    pub(super) fn graph_from_nodes(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test goal");
        g.tasks = nodes;
        g
    }

    /// Minimal sub-agent definition with the given name.
    pub(super) fn make_def(name: &str) -> zeph_subagent::SubAgentDef {
        use zeph_subagent::{SkillFilter, SubAgentPermissions, SubagentHooks, ToolPolicy};
        zeph_subagent::SubAgentDef {
            name: name.to_string(),
            description: format!("{name} agent"),
            model: None,
            tools: ToolPolicy::InheritAll,
            disallowed_tools: vec![],
            permissions: SubAgentPermissions::default(),
            skills: SkillFilter::default(),
            system_prompt: String::new(),
            hooks: SubagentHooks::default(),
            memory: None,
            source: None,
            file_path: None,
        }
    }

    /// Baseline orchestration config used by all scheduler tests
    /// (feature toggles off, small limits).
    pub(super) fn make_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            enabled: true,
            max_tasks: 20,
            max_parallel: 4,
            default_failure_strategy: "abort".to_string(),
            default_max_retries: 3,
            task_timeout_secs: 300,
            planner_provider: Default::default(),
            planner_max_tokens: 4096,
            dependency_context_budget: 16384,
            confirm_before_execute: true,
            aggregator_max_tokens: 4096,
            deferral_backoff_ms: 250,
            plan_cache: zeph_config::PlanCacheConfig::default(),
            topology_selection: false,
            verify_provider: Default::default(),
            verify_max_tokens: 1024,
            max_replans: 2,
            verify_completeness: false,
            completeness_threshold: 0.7,
            tool_provider: Default::default(),
            cascade_routing: false,
            cascade_failure_threshold: 0.5,
            tree_optimized_dispatch: false,
            adaptorch: Default::default(),
            cascade_chain_threshold: 3,
            cascade_failure_rate_abort_threshold: 0.0,
            lineage_ttl_secs: 300,
            verify_predicate_enabled: false,
            predicate_provider: Default::default(),
            max_predicate_replans: 2,
            predicate_timeout_secs: 30,
            persistence_enabled: true,
        }
    }

    /// Router that always picks the first available agent definition.
    pub(super) struct FirstRouter;

    impl AgentRouter for FirstRouter {
        fn route(
            &self,
            _task: &TaskNode,
            available: &[zeph_subagent::SubAgentDef],
        ) -> Option<String> {
            available.first().map(|d| d.name.clone())
        }
    }

    /// Router that never routes any task.
    pub(super) struct NoneRouter;

    impl AgentRouter for NoneRouter {
        fn route(
            &self,
            _task: &TaskNode,
            _available: &[zeph_subagent::SubAgentDef],
        ) -> Option<String> {
            None
        }
    }

    /// Builds a scheduler over `graph` with one "worker" agent and the given router.
    pub(super) fn make_scheduler_with_router(
        graph: TaskGraph,
        router: Box<dyn AgentRouter>,
    ) -> DagScheduler {
        let config = make_config();
        let defs = vec![make_def("worker")];
        DagScheduler::new(graph, &config, router, defs).unwrap()
    }

    /// Builds a scheduler over `graph` with the default `FirstRouter`.
    pub(super) fn make_scheduler(graph: TaskGraph) -> DagScheduler {
        let config = make_config();
        let defs = vec![make_def("worker")];
        DagScheduler::new(graph, &config, Box::new(FirstRouter), defs).unwrap()
    }

    #[test]
    fn test_new_validates_graph_status() {
        let mut graph = graph_from_nodes(vec![make_node(0, &[])]);
        // `new` only accepts graphs in `Created` status.
        graph.status = GraphStatus::Running;
        let config = make_config();
        let result = DagScheduler::new(graph, &config, Box::new(FirstRouter), vec![]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, OrchestrationError::InvalidGraph(_)));
    }

    #[test]
    fn test_new_marks_roots_ready() {
        let graph = graph_from_nodes(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0, 1]),
        ]);
        let scheduler = make_scheduler(graph);
        // Dependency-free tasks become Ready; dependent ones stay Pending.
        assert_eq!(scheduler.graph().tasks[0].status, TaskStatus::Ready);
        assert_eq!(scheduler.graph().tasks[1].status, TaskStatus::Ready);
        assert_eq!(scheduler.graph().tasks[2].status, TaskStatus::Pending);
        assert_eq!(scheduler.graph().status, GraphStatus::Running);
    }

    #[test]
    fn test_new_validates_empty_graph() {
        let graph = graph_from_nodes(vec![]);
        let config = make_config();
        let result = DagScheduler::new(graph, &config, Box::new(FirstRouter), vec![]);
        assert!(result.is_err());
    }
}