use crate::error::Result;
use crate::orchestrator::{
agent::SubAgentEventStream, ControlSignal, OrchestratorEvent, SubAgentActivity, SubAgentConfig,
SubAgentState,
};
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::RwLock;
/// Cloneable handle used to observe and control a single sub-agent.
///
/// Every clone shares the same underlying channels and `Arc`-wrapped state,
/// so any clone can send control signals or read the current state/activity.
#[derive(Clone)]
pub struct SubAgentHandle {
/// Identifier of the sub-agent; also used as the `filter_id` of the event
/// streams created by `events()`.
pub id: String,
/// Configuration the handle was created with (exposed via `config()`).
pub(crate) config: SubAgentConfig,
/// Milliseconds since the Unix epoch, captured when the handle was constructed.
pub(crate) created_at: u64,
/// Sending half of the channel that carries `ControlSignal`s (see `send_control`).
control_tx: tokio::sync::mpsc::Sender<ControlSignal>,
/// Broadcast sender that `events()` subscribes to for live events.
subagent_event_tx: broadcast::Sender<OrchestratorEvent>,
/// Shared buffer of past events; `events()` hands a snapshot copy to each new stream.
event_history: Arc<RwLock<std::collections::VecDeque<OrchestratorEvent>>>,
/// Shared, lock-protected state read by `state()` / `state_async()`.
state: Arc<RwLock<SubAgentState>>,
/// Shared, lock-protected activity record read by `activity()`.
pub(crate) activity: Arc<RwLock<SubAgentActivity>>,
/// Join handle of the spawned task, wrapped in `Arc` because `JoinHandle`
/// is not `Clone` while this struct derives `Clone`.
// NOTE(review): the `dead_code` allowance suggests the handle is held rather
// than awaited — confirm whether the attribute is still required.
#[allow(dead_code)]
task_handle: Arc<tokio::task::JoinHandle<Result<String>>>,
}
/// Bundle of everything `SubAgentHandle::new` needs to build a handle.
///
/// Grouping the pieces in one struct keeps the constructor to a single
/// argument; each field is moved into the matching `SubAgentHandle` field.
pub(crate) struct SubAgentHandleParts {
/// Identifier of the sub-agent.
pub id: String,
/// Configuration stored on the handle.
pub config: SubAgentConfig,
/// Sending half of the control-signal channel.
pub control_tx: tokio::sync::mpsc::Sender<ControlSignal>,
/// Broadcast sender used to subscribe to live events.
pub subagent_event_tx: tokio::sync::broadcast::Sender<OrchestratorEvent>,
/// Shared buffer of previously recorded events.
pub event_history: Arc<RwLock<std::collections::VecDeque<OrchestratorEvent>>>,
/// Shared, lock-protected sub-agent state.
pub state: Arc<RwLock<SubAgentState>>,
/// Shared, lock-protected activity record.
pub activity: Arc<RwLock<SubAgentActivity>>,
/// Join handle of the spawned task; `SubAgentHandle::new` wraps it in an `Arc`.
pub task_handle: tokio::task::JoinHandle<Result<String>>,
}
impl SubAgentHandle {
pub(crate) fn new(parts: SubAgentHandleParts) -> Self {
Self {
id: parts.id,
config: parts.config,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
control_tx: parts.control_tx,
subagent_event_tx: parts.subagent_event_tx,
event_history: parts.event_history,
state: parts.state,
activity: parts.activity,
task_handle: Arc::new(parts.task_handle),
}
}
pub fn state(&self) -> SubAgentState {
self.state
.try_read()
.map(|guard| guard.clone())
.unwrap_or(SubAgentState::Initializing)
}
pub async fn state_async(&self) -> SubAgentState {
self.state.read().await.clone()
}
pub async fn activity(&self) -> SubAgentActivity {
self.activity.read().await.clone()
}
pub fn config(&self) -> &SubAgentConfig {
&self.config
}
pub fn created_at(&self) -> u64 {
self.created_at
}
pub async fn send_control(&self, signal: ControlSignal) -> Result<()> {
self.control_tx
.send(signal)
.await
.map_err(|_| anyhow::anyhow!("Failed to send control signal: channel closed"))?;
Ok(())
}
pub async fn pause(&self) -> Result<()> {
self.send_control(ControlSignal::Pause).await
}
pub async fn resume(&self) -> Result<()> {
self.send_control(ControlSignal::Resume).await
}
pub async fn cancel(&self) -> Result<()> {
self.send_control(ControlSignal::Cancel).await
}
pub async fn adjust_params(
&self,
max_steps: Option<usize>,
timeout_ms: Option<u64>,
) -> Result<()> {
self.send_control(ControlSignal::AdjustParams {
max_steps,
timeout_ms,
})
.await
}
pub async fn inject_prompt(&self, prompt: impl Into<String>) -> Result<()> {
self.send_control(ControlSignal::InjectPrompt {
prompt: prompt.into(),
})
.await
}
pub async fn wait(&self) -> Result<String> {
loop {
let state = self.state_async().await;
if state.is_terminal() {
match state {
SubAgentState::Completed { output, .. } => return Ok(output),
SubAgentState::Cancelled => {
return Err(anyhow::anyhow!("SubAgent was cancelled").into())
}
SubAgentState::Error { message } => {
return Err(anyhow::anyhow!("SubAgent error: {}", message).into())
}
_ => unreachable!(),
}
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
pub fn is_done(&self) -> bool {
self.state().is_terminal()
}
pub fn is_running(&self) -> bool {
self.state().is_running()
}
pub fn is_paused(&self) -> bool {
self.state().is_paused()
}
pub fn events(&self) -> SubAgentEventStream {
let rx = self.subagent_event_tx.subscribe();
let history = self
.event_history
.try_read()
.map(|events| events.clone())
.unwrap_or_default();
SubAgentEventStream {
history,
rx,
filter_id: self.id.clone(),
}
}
}
/// Compact debug representation: only the identifier and a best-effort state
/// snapshot are printed; channels, history and the task handle are omitted.
impl std::fmt::Debug for SubAgentHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take the snapshot first so the builder only borrows a local value.
        let current_state = self.state();

        let mut builder = f.debug_struct("SubAgentHandle");
        builder.field("id", &self.id);
        builder.field("state", &current_state);
        builder.finish()
    }
}