use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::decompose::{DecompositionError, TaskDecomposer};
use crate::state::{StepState, TaskState};
/// Error type for task orchestration.
///
/// Each variant's `Display` text is defined by its `#[error(...)]` attribute.
/// `Decomposition` and `Graph` get `From` impls via `#[from]`, so `?` converts
/// those source errors automatically. The `String`-carrying variants hold a
/// free-form description supplied by whoever constructs them.
#[derive(Debug, Error)]
pub enum OrchestrateError {
    /// Wraps a [`DecompositionError`] from `crate::decompose`.
    #[error("Decomposition failed: {0}")]
    Decomposition(#[from] DecompositionError),

    /// Wraps a `GraphError` from `crate::graph`.
    #[error("Graph error: {0}")]
    Graph(#[from] crate::graph::GraphError),

    /// Sandbox failure, described by the contained message.
    #[error("Sandbox error: {0}")]
    Sandbox(String),

    /// Confirmation failure, described by the contained message.
    #[error("Confirmation error: {0}")]
    Confirmation(String),

    /// A budget limit was exceeded; the message gives details.
    #[error("Budget exceeded: {0}")]
    BudgetExceeded(String),

    /// Audit failure, described by the contained message.
    #[error("Audit error: {0}")]
    Audit(String),

    /// No task exists for the contained task id.
    #[error("Task not found: {0}")]
    TaskNotFound(String),

    /// The task was cancelled.
    #[error("Task cancelled")]
    Cancelled,
}
/// Task orchestrator state.
///
/// Holds a required task decomposer and a set of optional collaborators that
/// are attached through the `with_*` builder methods. It also holds two maps
/// keyed by task id: one for task state and one for cancellation tokens.
/// Field declaration order is kept as-is because it fixes drop order.
pub struct TaskOrchestrator {
    /// Task decomposer supplied to `TaskOrchestrator::new`.
    pub(crate) decomposer: Arc<dyn TaskDecomposer>,
    /// Audit trail; `None` until `with_audit` is called.
    pub(crate) audit: Option<Arc<dyn audit::AuditTrail>>,
    /// Confirmation engine; `None` until `with_confirmation` is called.
    pub(crate) confirm: Option<Arc<dyn confirm::ConfirmationEngine>>,
    /// Cost budget; `None` until `with_budget` is called.
    pub(crate) budget: Option<Arc<dyn budget::CostBudget>>,
    /// Sandbox executor; `None` until `with_sandbox` is called.
    pub(crate) sandbox: Option<Arc<dyn sandbox::SandboxExecutor>>,
    /// Agent registry; `None` until `with_agents` is called.
    pub(crate) agents: Option<Arc<delegate::AgentRegistry>>,
    /// LLM provider; `None` until `with_llm` is called.
    pub(crate) llm: Option<Arc<dyn cortex::LlmProvider>>,
    /// Channel dispatcher; `None` until `with_channel_dispatcher` is called.
    pub(crate) dispatcher: Option<Arc<channel::ChannelDispatcher>>,
    /// Episodic store; `None` until `with_episodic` is called.
    pub(crate) episodic: Option<Arc<hippocampus::EpisodicStore>>,
    /// Escalation policy; `EscalationPolicy::default()` unless overridden via
    /// `with_delegation_policy`.
    pub(crate) delegation_policy: delegate::EscalationPolicy,
    /// Tool names; empty unless set via `with_available_tools`.
    pub(crate) available_tools: Vec<String>,
    /// Per-task state, keyed by task id.
    pub(crate) tasks: RwLock<HashMap<String, TaskState>>,
    /// Observer; `None` until `with_observer` is called.
    pub(crate) observer: Option<Arc<dyn observe::Observer>>,
    /// SQLite pool; `None` until `with_state_pool` is called.
    pub(crate) state_pool: Option<storage::SqlitePool>,
    /// Cancellation tokens keyed by task id (looked up by `cancel_token_for`).
    pub(crate) cancel_tokens: RwLock<HashMap<String, CancellationToken>>,
}
/// Maximum number of re-planning attempts.
///
/// The code that enforces this limit is not visible here; presumably it
/// bounds how often a failed plan is regenerated — confirm at the use site.
pub(crate) const MAX_REPLAN_ATTEMPTS: u32 = 2;
impl TaskOrchestrator {
    /// Creates an orchestrator around `decomposer` with every optional
    /// collaborator unset.
    ///
    /// The task map and the cancellation-token map start empty,
    /// `available_tools` is empty, and `delegation_policy` is
    /// `EscalationPolicy::default()`. Attach further components with the
    /// `with_*` builder methods.
    #[must_use]
    pub fn new(decomposer: Arc<dyn TaskDecomposer>) -> Self {
        Self {
            decomposer,
            audit: None,
            confirm: None,
            budget: None,
            sandbox: None,
            agents: None,
            llm: None,
            dispatcher: None,
            episodic: None,
            delegation_policy: delegate::EscalationPolicy::default(),
            available_tools: Vec::new(),
            tasks: RwLock::new(HashMap::new()),
            observer: None,
            state_pool: None,
            cancel_tokens: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a clone of the cancellation token registered for `task_id`.
    ///
    /// If no token is registered, this returns a brand-new token. That token
    /// is *not* inserted into `cancel_tokens`, so cancelling through that map
    /// will never reach it.
    // NOTE(review): confirm callers intend an unregistered fallback for
    // unknown task ids rather than a token that is stored for later cancel.
    pub(crate) async fn cancel_token_for(&self, task_id: &str) -> CancellationToken {
        self.cancel_tokens
            .read()
            .await
            .get(task_id)
            .cloned()
            .unwrap_or_else(CancellationToken::new)
    }

    /// Sets step `step_id` of task `task_id` to `StepState::Cancelled`.
    ///
    /// Does nothing if `task_id` is not in the task map. Holds the `tasks`
    /// write lock while updating.
    pub(crate) async fn mark_step_cancelled(&self, task_id: &str, step_id: &str) {
        let mut tasks = self.tasks.write().await;
        if let Some(task) = tasks.get_mut(task_id) {
            task.set_step_state(step_id, StepState::Cancelled);
        }
    }

    /// Replaces the list of available tool names.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_available_tools(mut self, tools: Vec<String>) -> Self {
        self.available_tools = tools;
        self
    }

    /// Attaches an audit trail.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_audit(mut self, audit: Arc<dyn audit::AuditTrail>) -> Self {
        self.audit = Some(audit);
        self
    }

    /// Attaches a confirmation engine.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_confirmation(mut self, confirm: Arc<dyn confirm::ConfirmationEngine>) -> Self {
        self.confirm = Some(confirm);
        self
    }

    /// Attaches a cost budget.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_budget(mut self, budget: Arc<dyn budget::CostBudget>) -> Self {
        self.budget = Some(budget);
        self
    }

    /// Attaches a sandbox executor.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_sandbox(mut self, sandbox: Arc<dyn sandbox::SandboxExecutor>) -> Self {
        self.sandbox = Some(sandbox);
        self
    }

    /// Attaches an agent registry.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_agents(mut self, agents: Arc<delegate::AgentRegistry>) -> Self {
        self.agents = Some(agents);
        self
    }

    /// Attaches an LLM provider.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_llm(mut self, llm: Arc<dyn cortex::LlmProvider>) -> Self {
        self.llm = Some(llm);
        self
    }

    /// Attaches a channel dispatcher.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_channel_dispatcher(mut self, dispatcher: Arc<channel::ChannelDispatcher>) -> Self {
        self.dispatcher = Some(dispatcher);
        self
    }

    /// Attaches an episodic store.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_episodic(mut self, store: Arc<hippocampus::EpisodicStore>) -> Self {
        self.episodic = Some(store);
        self
    }

    /// Attaches an observer.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_observer(mut self, observer: Arc<dyn observe::Observer>) -> Self {
        self.observer = Some(observer);
        self
    }

    /// Attaches a SQLite pool.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_state_pool(mut self, pool: storage::SqlitePool) -> Self {
        self.state_pool = Some(pool);
        self
    }

    /// Replaces the default escalation policy.
    #[must_use = "builder methods return the updated orchestrator"]
    pub fn with_delegation_policy(mut self, policy: delegate::EscalationPolicy) -> Self {
        self.delegation_policy = policy;
        self
    }
}
#[cfg(test)]
mod tests;