use std::collections::HashMap;
use std::sync::{Arc, Mutex, PoisonError};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::agent::{Agent, AgentOptions};
use crate::error::AgentError;
use crate::handle::AgentStatus;
use crate::task_core::{TaskCore, resolve_status};
use crate::types::{AgentMessage, AgentResult, ContentBlock, LlmMessage, UserMessage};
use crate::util::now_timestamp;
/// Shared, thread-safe factory closure that returns an `AgentOptions` each time it is called.
/// It is invoked once when an agent task starts and again for every supervisor-driven restart.
type OptionsFactoryArc = Arc<dyn Fn() -> AgentOptions + Send + Sync>;
/// A unit of work sent to a running agent task: the messages to prompt with,
/// plus the one-shot channel on which the outcome is delivered back.
pub struct AgentRequest {
/// Messages handed to the agent's prompt call for this request.
pub messages: Vec<AgentMessage>,
/// Receives the prompt result, or the error that ended the request.
pub reply: oneshot::Sender<Result<AgentResult, AgentError>>,
}
/// Decision returned by a [`SupervisorPolicy`] after an agent prompt fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupervisorAction {
/// Replace the agent with a freshly built one and keep serving requests.
/// The agent loop only honors this while its restart counter is below the
/// configured maximum; once exhausted it is handled like `Stop`.
Restart,
/// End the agent loop; the task finishes with a "stopped by supervisor" error.
Stop,
/// Return the error to the requester and keep the current agent running.
/// This is also the action the agent loop uses when no supervisor is configured.
Escalate,
}
/// Strategy consulted by the agent loop whenever a prompt returns an error.
///
/// Implementations must be `Send + Sync` because the policy is shared behind
/// an `Arc` with every spawned agent task.
pub trait SupervisorPolicy: Send + Sync {
/// Chooses how the loop should react to `error` raised by the agent called `name`.
fn on_agent_error(&self, name: &str, error: &AgentError) -> SupervisorAction;
}
/// Built-in supervisor policy carrying a restart limit.
///
/// The limit is stored and exposed through [`DefaultSupervisor::max_restarts`];
/// the inherent methods here do nothing else with it.
#[derive(Debug, Clone)]
pub struct DefaultSupervisor {
    /// Restart limit this supervisor was configured with.
    max_restarts: u32,
}

impl DefaultSupervisor {
    /// Limit used by the [`Default`] implementation.
    const DEFAULT_MAX_RESTARTS: u32 = 3;

    /// Builds a supervisor with the given restart limit.
    #[must_use]
    pub const fn new(max_restarts: u32) -> Self {
        Self { max_restarts }
    }

    /// Returns the restart limit supplied at construction.
    #[must_use]
    pub const fn max_restarts(&self) -> u32 {
        self.max_restarts
    }
}

impl Default for DefaultSupervisor {
    /// Equivalent to `DefaultSupervisor::new(3)`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_RESTARTS)
    }
}
impl SupervisorPolicy for DefaultSupervisor {
    /// Maps an error to an action purely by its retryability: retryable
    /// errors ask for a restart, everything else asks for a stop.
    ///
    /// The agent name is ignored, and this method does not read
    /// `max_restarts`; it never returns `SupervisorAction::Escalate`.
    fn on_agent_error(&self, _name: &str, error: &AgentError) -> SupervisorAction {
        if !error.is_retryable() {
            return SupervisorAction::Stop;
        }
        SupervisorAction::Restart
    }
}
/// Registration record the orchestrator keeps for each named agent.
struct AgentEntry {
/// Builds the options for a new agent instance (initial start and restarts).
options_factory: OptionsFactoryArc,
/// Name of the parent agent, or `None` for a top-level agent.
parent: Option<String>,
/// Names of the agents registered as children of this one, in registration order.
children: Vec<String>,
/// Restart limit passed to the agent loop when this agent is spawned.
max_restarts: u32,
}
/// Handle to an agent task started by the orchestrator.
///
/// Requests travel to the task over an mpsc channel; lifecycle queries and
/// cancellation are delegated to the wrapped `TaskCore`.
pub struct OrchestratedHandle {
    /// Name the agent was registered and spawned under.
    name: String,
    /// Sending half of the agent's request channel.
    request_tx: mpsc::Sender<AgentRequest>,
    /// Owns the task's join handle, cancellation token and shared status.
    core: TaskCore,
}

impl OrchestratedHandle {
    /// Name of the agent this handle controls.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Wraps `text` in a single user message (one text block, current
    /// timestamp, no cache hint) and sends it via [`Self::send_messages`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::send_messages`].
    pub async fn send_message(&self, text: impl Into<String>) -> Result<AgentResult, AgentError> {
        let text_block = ContentBlock::Text { text: text.into() };
        let user_message = UserMessage {
            content: vec![text_block],
            timestamp: now_timestamp(),
            cache_hint: None,
        };
        let outgoing = vec![AgentMessage::Llm(LlmMessage::User(user_message))];
        self.send_messages(outgoing).await
    }

    /// Queues `messages` for the agent and waits for its reply.
    ///
    /// # Errors
    ///
    /// Returns an "orchestrator" plugin error when the request channel is
    /// closed or the reply sender is dropped without answering; otherwise
    /// forwards whatever result the agent task replied with.
    pub async fn send_messages(
        &self,
        messages: Vec<AgentMessage>,
    ) -> Result<AgentResult, AgentError> {
        let (reply, reply_rx) = oneshot::channel();
        let request = AgentRequest { messages, reply };

        if self.request_tx.send(request).await.is_err() {
            return Err(AgentError::plugin(
                "orchestrator",
                std::io::Error::other("agent channel closed"),
            ));
        }

        match reply_rx.await {
            Ok(outcome) => outcome,
            Err(_) => Err(AgentError::plugin(
                "orchestrator",
                std::io::Error::other("agent reply dropped"),
            )),
        }
    }

    /// Consumes the handle and waits for the agent task's final result.
    ///
    /// # Errors
    ///
    /// Returns whatever error the task core reports for the finished task.
    pub async fn await_result(self) -> Result<AgentResult, AgentError> {
        let Self {
            request_tx, core, ..
        } = self;
        // Release this handle's sender before waiting: the agent loop shuts
        // down once its request channel reports closed.
        drop(request_tx);
        core.result().await
    }

    /// Asks the task core to cancel the agent task.
    pub fn cancel(&self) {
        self.core.cancel();
    }

    /// Current status as reported by the task core.
    pub fn status(&self) -> AgentStatus {
        self.core.status()
    }

    /// Whether the task core reports the task as finished.
    pub fn is_done(&self) -> bool {
        self.core.is_done()
    }
}
impl std::fmt::Debug for OrchestratedHandle {
    /// Prints the agent name and current status; the channel and task core
    /// are left out and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current_status = self.status();
        let mut out = f.debug_struct("OrchestratedHandle");
        out.field("name", &self.name);
        out.field("status", &current_status);
        out.finish_non_exhaustive()
    }
}
pub struct AgentOrchestrator {
entries: HashMap<String, AgentEntry>,
supervisor: Option<Arc<dyn SupervisorPolicy>>,
channel_buffer: usize,
default_max_restarts: u32,
}
impl AgentOrchestrator {
#[must_use]
pub fn new() -> Self {
Self {
entries: HashMap::new(),
supervisor: None,
channel_buffer: 32,
default_max_restarts: 3,
}
}
#[must_use]
pub fn with_supervisor(mut self, policy: impl SupervisorPolicy + 'static) -> Self {
self.supervisor = Some(Arc::new(policy));
self
}
#[must_use]
pub const fn with_channel_buffer(mut self, size: usize) -> Self {
self.channel_buffer = size;
self
}
#[must_use]
pub const fn with_max_restarts(mut self, max: u32) -> Self {
self.default_max_restarts = max;
self
}
pub fn add_agent(
&mut self,
name: impl Into<String>,
options_factory: impl Fn() -> AgentOptions + Send + Sync + 'static,
) {
let name = name.into();
assert!(
!self.entries.contains_key(&name),
"agent '{name}' already registered"
);
self.entries.insert(
name,
AgentEntry {
options_factory: Arc::new(options_factory),
parent: None,
children: Vec::new(),
max_restarts: self.default_max_restarts,
},
);
}
pub fn add_child(
&mut self,
name: impl Into<String>,
parent: impl Into<String>,
options_factory: impl Fn() -> AgentOptions + Send + Sync + 'static,
) {
let name = name.into();
let parent = parent.into();
assert!(
self.entries.contains_key(&parent),
"parent agent '{parent}' not registered"
);
assert!(
!self.entries.contains_key(&name),
"agent '{name}' already registered"
);
self.entries
.get_mut(&parent)
.expect("parent checked above")
.children
.push(name.clone());
self.entries.insert(
name,
AgentEntry {
options_factory: Arc::new(options_factory),
parent: Some(parent),
children: Vec::new(),
max_restarts: self.default_max_restarts,
},
);
}
#[must_use]
pub fn parent_of(&self, name: &str) -> Option<&str> {
self.entries.get(name).and_then(|e| e.parent.as_deref())
}
#[must_use]
pub fn children_of(&self, name: &str) -> Option<&[String]> {
self.entries.get(name).map(|e| e.children.as_slice())
}
#[must_use]
pub fn names(&self) -> Vec<&str> {
self.entries.keys().map(String::as_str).collect()
}
#[must_use]
pub fn contains(&self, name: &str) -> bool {
self.entries.contains_key(name)
}
pub fn spawn(&self, name: &str) -> Result<OrchestratedHandle, AgentError> {
let entry = self.entries.get(name).ok_or_else(|| {
AgentError::plugin(
"orchestrator",
std::io::Error::other(format!("agent not registered: {name}")),
)
})?;
let factory = Arc::clone(&entry.options_factory);
let max_restarts = entry.max_restarts;
let agent_name = name.to_owned();
let supervisor = self.supervisor.clone();
let (request_tx, request_rx) = mpsc::channel::<AgentRequest>(self.channel_buffer);
let cancellation_token = CancellationToken::new();
let status = Arc::new(Mutex::new(AgentStatus::Running));
let status_clone = Arc::clone(&status);
let token_clone = cancellation_token.clone();
let join_handle = tokio::spawn(run_agent_loop(
agent_name,
factory,
request_rx,
token_clone,
status_clone,
supervisor,
max_restarts,
));
Ok(OrchestratedHandle {
name: name.to_owned(),
request_tx,
core: TaskCore::new(join_handle, cancellation_token, status),
})
}
}
async fn run_agent_loop(
agent_name: String,
factory: OptionsFactoryArc,
mut request_rx: mpsc::Receiver<AgentRequest>,
cancellation_token: CancellationToken,
status: Arc<Mutex<AgentStatus>>,
supervisor: Option<Arc<dyn SupervisorPolicy>>,
max_restarts: u32,
) -> Result<AgentResult, AgentError> {
let mut agent = Agent::new(factory());
let mut restarts: u32 = 0;
let final_result = loop {
tokio::select! {
biased;
() = cancellation_token.cancelled() => {
agent.abort();
break Err(AgentError::Aborted);
}
maybe_req = request_rx.recv() => {
if let Some(req) = maybe_req {
let result = tokio::select! {
biased;
() = cancellation_token.cancelled() => {
agent.abort();
let _ = req.reply.send(Err(AgentError::Aborted));
break Err(AgentError::Aborted);
}
r = agent.prompt_async(req.messages) => r,
};
match result {
Ok(r) => {
let _ = req.reply.send(Ok(r));
restarts = 0;
}
Err(err) => {
let action = supervisor
.as_ref()
.map_or(SupervisorAction::Escalate, |s| {
s.on_agent_error(&agent_name, &err)
});
match action {
SupervisorAction::Restart if restarts < max_restarts => {
warn!(
agent = %agent_name,
restart = restarts + 1,
max = max_restarts,
"supervisor restarting agent"
);
restarts += 1;
let _ = req.reply.send(Err(err));
agent = Agent::new(factory());
}
SupervisorAction::Escalate => {
let _ = req.reply.send(Err(err));
}
_ => {
let _ = req.reply.send(Err(err));
break Err(AgentError::plugin(
"orchestrator",
std::io::Error::other(format!(
"agent '{agent_name}' stopped by supervisor"
)),
));
}
}
}
}
} else {
info!(agent = %agent_name, "request channel closed, shutting down");
break Ok(AgentResult {
messages: Vec::new(),
stop_reason: crate::types::StopReason::Stop,
usage: crate::types::Usage::default(),
cost: crate::types::Cost::default(),
error: None,
transfer_signal: None,
});
}
}
}
};
*status.lock().unwrap_or_else(PoisonError::into_inner) = resolve_status(&final_result);
final_result
}
/// `Default` simply delegates to [`AgentOrchestrator::new`].
impl Default for AgentOrchestrator {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for AgentOrchestrator {
    /// Prints the registered agent names, whether a supervisor is set (as the
    /// string "Some" or "None", since the policy itself is not `Debug`), and
    /// the channel buffer size. The output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let agent_names = self.entries.keys().collect::<Vec<_>>();
        let supervisor_state = match &self.supervisor {
            Some(_) => "Some",
            None => "None",
        };

        let mut out = f.debug_struct("AgentOrchestrator");
        out.field("agents", &agent_names);
        out.field("supervisor", &supervisor_state);
        out.field("channel_buffer", &self.channel_buffer);
        out.finish_non_exhaustive()
    }
}
// Unit tests for registration, hierarchy bookkeeping, builder options and the
// supervisor types. None of them spawn an agent task, so every options factory
// is a closure that panics if it is ever invoked.
#[cfg(test)]
mod tests {
use std::panic::AssertUnwindSafe;
use super::*;
// Registered names are all reported (sorted here because map order is unspecified).
#[test]
fn add_agent_and_names() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("alpha", || panic!("not called"));
orch.add_agent("beta", || panic!("not called"));
let mut names = orch.names();
names.sort_unstable();
assert_eq!(names, vec!["alpha", "beta"]);
}
// `contains` is true only for registered names.
#[test]
fn contains_registered() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("a", || panic!("not called"));
assert!(orch.contains("a"));
assert!(!orch.contains("b"));
}
// Children record their parent, and the parent lists children in insertion order.
#[test]
fn parent_child_hierarchy() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("parent", || panic!("not called"));
orch.add_child("child1", "parent", || panic!("not called"));
orch.add_child("child2", "parent", || panic!("not called"));
assert_eq!(orch.parent_of("child1"), Some("parent"));
assert_eq!(orch.parent_of("child2"), Some("parent"));
assert_eq!(orch.parent_of("parent"), None);
let children = orch.children_of("parent").unwrap();
assert_eq!(children, &["child1", "child2"]);
assert!(orch.children_of("child1").unwrap().is_empty());
}
// Adding a child under an unknown parent panics with the documented message.
#[test]
#[should_panic(expected = "parent agent 'missing' not registered")]
fn add_child_missing_parent_panics() {
let mut orch = AgentOrchestrator::new();
orch.add_child("child", "missing", || panic!("not called"));
}
// Registering the same top-level name twice panics.
#[test]
#[should_panic(expected = "agent 'alpha' already registered")]
fn add_agent_duplicate_name_panics() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("alpha", || panic!("not called"));
orch.add_agent("alpha", || panic!("not called"));
}
// A rejected duplicate `add_child` must not touch either parent's child list.
#[test]
fn duplicate_child_registration_preserves_existing_hierarchy() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("parent1", || panic!("not called"));
orch.add_agent("parent2", || panic!("not called"));
orch.add_child("child", "parent1", || panic!("not called"));
// The panic is caught so the orchestrator's state can be inspected afterwards.
let duplicate = std::panic::catch_unwind(AssertUnwindSafe(|| {
orch.add_child("child", "parent2", || panic!("not called"));
}));
assert!(duplicate.is_err());
assert_eq!(orch.parent_of("child"), Some("parent1"));
assert_eq!(orch.children_of("parent1").unwrap(), &["child"]);
assert!(orch.children_of("parent2").unwrap().is_empty());
}
// A rejected `add_agent` reusing a child's name must not detach that child.
#[test]
fn duplicate_top_level_registration_preserves_child_link() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("parent", || panic!("not called"));
orch.add_child("child", "parent", || panic!("not called"));
let duplicate = std::panic::catch_unwind(AssertUnwindSafe(|| {
orch.add_agent("child", || panic!("not called"));
}));
assert!(duplicate.is_err());
assert_eq!(orch.parent_of("child"), Some("parent"));
assert_eq!(orch.children_of("parent").unwrap(), &["child"]);
}
// Spawning an unknown name fails with an error mentioning "orchestrator".
#[test]
fn spawn_unregistered_agent_errors() {
let orch = AgentOrchestrator::new();
let result = orch.spawn("nonexistent");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(format!("{err}").contains("orchestrator"));
}
// The default supervisor restarts on retryable errors and stops otherwise.
#[test]
fn default_supervisor_retryable_restarts() {
let supervisor = DefaultSupervisor::default();
assert_eq!(supervisor.max_restarts(), 3);
let retryable = AgentError::ModelThrottled;
assert_eq!(
supervisor.on_agent_error("test", &retryable),
SupervisorAction::Restart
);
let non_retryable = AgentError::Aborted;
assert_eq!(
supervisor.on_agent_error("test", &non_retryable),
SupervisorAction::Stop
);
}
// Debug output of each action is its variant name.
#[test]
fn supervisor_action_variants() {
assert_eq!(format!("{:?}", SupervisorAction::Restart), "Restart");
assert_eq!(format!("{:?}", SupervisorAction::Stop), "Stop");
assert_eq!(format!("{:?}", SupervisorAction::Escalate), "Escalate");
}
// Debug output names the type and includes the channel buffer field.
#[test]
fn orchestrator_debug_format() {
let orch = AgentOrchestrator::new();
let debug = format!("{orch:?}");
assert!(debug.contains("AgentOrchestrator"));
assert!(debug.contains("channel_buffer"));
}
// `with_supervisor` stores the policy.
#[test]
fn with_supervisor_sets_policy() {
let orch = AgentOrchestrator::new().with_supervisor(DefaultSupervisor::default());
assert!(orch.supervisor.is_some());
}
// `with_channel_buffer` stores the requested size unchanged.
#[test]
fn with_channel_buffer_sets_size() {
let orch = AgentOrchestrator::new().with_channel_buffer(64);
assert_eq!(orch.channel_buffer, 64);
}
// The default restart limit is copied into entries registered afterwards.
#[test]
fn with_max_restarts_sets_default() {
let mut orch = AgentOrchestrator::new().with_max_restarts(5);
orch.add_agent("a", || panic!("not called"));
assert_eq!(orch.entries["a"].max_restarts, 5);
}
// `Default` yields an empty orchestrator without a supervisor.
#[test]
fn default_impl() {
let orch = AgentOrchestrator::default();
assert!(orch.entries.is_empty());
assert!(orch.supervisor.is_none());
}
// A user-defined policy can return any action regardless of the error.
#[test]
fn custom_supervisor_policy() {
struct AlwaysEscalate;
impl SupervisorPolicy for AlwaysEscalate {
fn on_agent_error(&self, _name: &str, _error: &AgentError) -> SupervisorAction {
SupervisorAction::Escalate
}
}
let supervisor = AlwaysEscalate;
assert_eq!(
supervisor.on_agent_error("x", &AgentError::ModelThrottled),
SupervisorAction::Escalate
);
}
// Hierarchies may be more than one level deep.
#[test]
fn grandchild_hierarchy() {
let mut orch = AgentOrchestrator::new();
orch.add_agent("root", || panic!("not called"));
orch.add_child("mid", "root", || panic!("not called"));
orch.add_child("leaf", "mid", || panic!("not called"));
assert_eq!(orch.parent_of("leaf"), Some("mid"));
assert_eq!(orch.parent_of("mid"), Some("root"));
assert_eq!(orch.parent_of("root"), None);
assert_eq!(orch.children_of("root").unwrap(), &["mid"]);
assert_eq!(orch.children_of("mid").unwrap(), &["leaf"]);
assert!(orch.children_of("leaf").unwrap().is_empty());
}
}