pub mod config;
pub mod error;
pub(crate) mod driver;
pub(crate) mod transport;
pub use config::{AcpSubagentsConfig, SubagentConfig, SubagentPresetConfig};
pub use error::{AcpClientError, HandshakeStep};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use agent_client_protocol::{
Agent, Client, SessionMessage, on_receive_notification, on_receive_request,
schema::{
RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse,
SelectedPermissionOutcome, SessionId, SessionNotification, StopReason,
},
};
use futures::channel::mpsc;
use tokio::sync::oneshot;
use tracing::Instrument;
use driver::SubagentCommand;
/// Outcome of one run, as returned by [`SubagentHandle::read_to_string`] and
/// [`run_session`].
#[derive(Debug, Clone)]
pub struct RunOutcome {
/// Text produced for the run. It is assembled by the driver; the exact
/// composition is not visible from this module.
pub text: String,
/// Stop reason delivered alongside the text.
pub stop_reason: StopReason,
}
/// Handle to a running subagent session.
///
/// Every operation is forwarded as a [`SubagentCommand`] to a background
/// driver task; the handle itself only holds the command channel and
/// bookkeeping state.
pub struct SubagentHandle {
/// Sending half of the command channel consumed by the driver task.
cmd_tx: mpsc::UnboundedSender<SubagentCommand>,
/// Handle of the spawned driver task; aborted by `close` and on drop.
join_handle: tokio::task::JoinHandle<()>,
/// Session id this handle was created with.
session_id: SessionId,
/// Set by `close`; once `true`, the command methods return `AcpClientError::Closed`.
closed: bool,
/// Maximum time `read_to_string` waits for the driver's reply.
prompt_timeout: Duration,
}
impl SubagentHandle {
#[must_use]
pub fn session_id(&self) -> &SessionId {
&self.session_id
}
#[cfg(test)]
pub(crate) fn new_for_test(
cmd_tx: mpsc::UnboundedSender<SubagentCommand>,
join_handle: tokio::task::JoinHandle<()>,
session_id: SessionId,
) -> Self {
Self {
cmd_tx,
join_handle,
session_id,
closed: false,
prompt_timeout: Duration::from_secs(30),
}
}
pub async fn send_prompt(&mut self, text: impl Into<String>) -> Result<(), AcpClientError> {
if self.closed {
return Err(AcpClientError::Closed);
}
let span = tracing::info_span!("acp.client.prompt");
async {
let (tx, rx) = oneshot::channel();
self.cmd_tx
.unbounded_send(SubagentCommand::Prompt {
text: text.into(),
reply: tx,
})
.map_err(|_| AcpClientError::DriverDied)?;
rx.await.map_err(|_| AcpClientError::DriverDied)?
}
.instrument(span)
.await
}
pub async fn read_update(&mut self) -> Result<SessionMessage, AcpClientError> {
if self.closed {
return Err(AcpClientError::Closed);
}
let span = tracing::info_span!("acp.client.read_update");
async {
let (tx, rx) = oneshot::channel();
self.cmd_tx
.unbounded_send(SubagentCommand::ReadUpdate { reply: tx })
.map_err(|_| AcpClientError::DriverDied)?;
rx.await.map_err(|_| AcpClientError::DriverDied)?
}
.instrument(span)
.await
}
pub async fn read_to_string(&mut self) -> Result<RunOutcome, AcpClientError> {
if self.closed {
return Err(AcpClientError::Closed);
}
let timeout = self.prompt_timeout;
let span = tracing::info_span!("acp.client.read_to_string");
async {
let (tx, rx) = oneshot::channel();
self.cmd_tx
.unbounded_send(SubagentCommand::ReadToString { reply: tx })
.map_err(|_| AcpClientError::DriverDied)?;
tokio::time::timeout(timeout, rx)
.await
.map_err(|_| AcpClientError::Timeout)?
.map_err(|_| AcpClientError::DriverDied)?
}
.instrument(span)
.await
}
pub async fn send_cancel(&mut self) -> Result<(), AcpClientError> {
if self.closed {
return Err(AcpClientError::Closed);
}
let span = tracing::info_span!("acp.client.cancel");
async {
let (tx, rx) = oneshot::channel();
self.cmd_tx
.unbounded_send(SubagentCommand::Cancel { reply: tx })
.map_err(|_| AcpClientError::DriverDied)?;
rx.await.map_err(|_| AcpClientError::DriverDied)?
}
.instrument(span)
.await
}
pub async fn close(&mut self) -> Result<(), AcpClientError> {
if self.closed {
return Err(AcpClientError::Closed);
}
self.closed = true;
let span = tracing::info_span!("acp.client.close");
async {
let (tx, rx) = oneshot::channel();
let _ = self
.cmd_tx
.unbounded_send(SubagentCommand::Close { ack: tx });
let _ = tokio::time::timeout(Duration::from_secs(5), rx).await;
self.join_handle.abort();
Ok(())
}
.instrument(span)
.await
}
}
/// Aborts the driver task when the handle is dropped, so a handle that was
/// never explicitly closed does not leave the task running.
impl Drop for SubagentHandle {
fn drop(&mut self) {
// No `Close` command is sent here (drop cannot await an acknowledgement);
// the task is simply aborted. `close` already aborts the same task, so
// after an explicit close this is a repeat call.
self.join_handle.abort();
}
}
/// Starts a subagent described by `cfg` and returns a handle to it.
///
/// The whole operation runs inside the `acp.client.connect` tracing span.
///
/// # Errors
///
/// Propagates whatever error the spawn/handshake sequence reports.
pub async fn spawn_subagent(cfg: SubagentConfig) -> Result<SubagentHandle, AcpClientError> {
    let connect_span = tracing::info_span!("acp.client.connect");
    let connecting = spawn_subagent_inner(cfg);
    connecting.instrument(connect_span).await
}
/// Spawns the child process, starts the driver task on top of its stdio and
/// waits for the driver to report the session id.
///
/// # Errors
///
/// * Any error from `transport::spawn_child`.
/// * Any error the driver reports through the ready channel.
/// * [`AcpClientError::DriverDied`] if the driver task goes away without
///   reporting a result.
/// * [`AcpClientError::Timeout`] if nothing is reported within
///   `cfg.handshake_timeout_secs`.
///
/// On every failure after the task was spawned, the task is aborted before
/// returning.
async fn spawn_subagent_inner(cfg: SubagentConfig) -> Result<SubagentHandle, AcpClientError> {
    let spawned = transport::spawn_child(&cfg)?;
    let (cmd_tx, cmd_rx) = mpsc::unbounded::<SubagentCommand>();
    let (ready_tx, ready_rx) = oneshot::channel::<Result<SessionId, AcpClientError>>();
    // The slot is moved into the driver task below and deliberately NOT kept
    // alive here: if the task ends without reporting, dropping the slot drops
    // the sender, which wakes `ready_rx` with a receive error so the failure
    // is reported as `DriverDied` right away. Holding a second `Arc` in this
    // function would keep the sender alive and turn that case into a full
    // handshake-timeout wait.
    let ready_slot = Arc::new(Mutex::new(Some(ready_tx)));
    let byte_streams = transport::make_byte_streams(spawned.stdin, spawned.stdout);
    let auto_approve = cfg.auto_approve_permissions;
    let handshake_timeout = Duration::from_secs(cfg.handshake_timeout_secs);
    let prompt_timeout = Duration::from_secs(cfg.prompt_timeout_secs);
    let child = spawned.child;
    let stderr_task = transport::spawn_stderr_drain(spawned.stderr, "pending".to_owned());

    let join_handle = tokio::spawn(async move {
        let result = Client
            .builder()
            .on_receive_notification(
                // Session notifications are accepted and ignored at this layer.
                async move |_notif: SessionNotification, _cx| Ok(()),
                on_receive_notification!(),
            )
            .on_receive_request(
                async move |req: RequestPermissionRequest,
                            responder: agent_client_protocol::Responder<
                    RequestPermissionResponse,
                >,
                            _cx: agent_client_protocol::ConnectionTo<Agent>| {
                    // With auto-approve on, the first offered option is selected;
                    // otherwise (or when no option is offered) the request is
                    // answered as cancelled.
                    // NOTE(review): the first option is taken as-is — confirm that
                    // agents always list an "allow" option first.
                    let outcome = match req.options.first() {
                        Some(opt) if auto_approve => RequestPermissionOutcome::Selected(
                            SelectedPermissionOutcome::new(opt.option_id.clone()),
                        ),
                        _ => RequestPermissionOutcome::Cancelled,
                    };
                    // Best effort: a failure to deliver the response is ignored.
                    let _ = responder.respond(RequestPermissionResponse::new(outcome));
                    Ok(())
                },
                on_receive_request!(),
            )
            .connect_with(
                byte_streams,
                move |cx: agent_client_protocol::ConnectionTo<Agent>| async move {
                    driver::run_driver(cx, cmd_rx, ready_slot, cfg, child, stderr_task).await
                },
            )
            .await;
        if let Err(e) = result {
            tracing::debug!(error = %e, "acp.client.connect: transport closed");
        }
    });

    // Collapse the three layers (timeout / channel / driver result) into one
    // result so the cleanup below is written once.
    let handshake = match tokio::time::timeout(handshake_timeout, ready_rx).await {
        Ok(Ok(reported)) => reported,
        Ok(Err(_)) => Err(AcpClientError::DriverDied),
        Err(_) => Err(AcpClientError::Timeout),
    };
    let session_id = match handshake {
        Ok(id) => id,
        Err(e) => {
            join_handle.abort();
            return Err(e);
        }
    };

    Ok(SubagentHandle {
        cmd_tx,
        join_handle,
        session_id,
        closed: false,
        prompt_timeout,
    })
}
/// Runs a single prompt against a freshly spawned subagent and returns the
/// outcome.
///
/// The whole operation runs inside the `acp.client.session.run` tracing span.
///
/// # Errors
///
/// Propagates whatever error the underlying session run reports.
pub async fn run_session(
    cfg: SubagentConfig,
    prompt: impl Into<String>,
) -> Result<RunOutcome, AcpClientError> {
    let session_span = tracing::info_span!("acp.client.session.run");
    let prompt_text: String = prompt.into();
    let session = run_session_inner(cfg, prompt_text);
    session.instrument(session_span).await
}
/// Spawns a subagent, sends `prompt`, collects the outcome and closes the
/// subagent again.
///
/// The prompt/read exchange is bounded by `cfg.session_timeout_secs`; spawning
/// the subagent is not part of that budget.
///
/// # Errors
///
/// * Any error from `spawn_subagent`.
/// * Any error from `send_prompt` or `read_to_string`.
/// * [`AcpClientError::Timeout`] if the exchange exceeds the session timeout.
async fn run_session_inner(
    cfg: SubagentConfig,
    prompt: String,
) -> Result<RunOutcome, AcpClientError> {
    let session_timeout = Duration::from_secs(cfg.session_timeout_secs);
    let mut handle = spawn_subagent(cfg).await?;
    let exchange = tokio::time::timeout(session_timeout, async {
        handle.send_prompt(prompt).await?;
        handle.read_to_string().await
    })
    .await;
    // Close on every path, including timeout, before surfacing the result: the
    // timeout case is exactly when the subagent is most likely still busy.
    // `close` is best effort and bounds its own wait, so its result is ignored.
    let _ = handle.close().await;
    match exchange {
        Ok(outcome) => outcome,
        Err(_elapsed) => Err(AcpClientError::Timeout),
    }
}
// Unit tests live in `tests.rs` and are spliced in textually via `include!`;
// the module is compiled only for test builds.
#[cfg(test)]
mod tests {
include!("tests.rs");
}