use std::sync::{Arc, Mutex};
use std::time::Duration;
use agent_client_protocol::{
ActiveSession, Agent, ConnectionTo, SessionMessage,
schema::{
CancelNotification, ContentBlock, ContentChunk, InitializeRequest, ProtocolVersion,
SessionId, SessionNotification, SessionUpdate,
},
};
use futures::StreamExt;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use crate::client::{AcpClientError, HandshakeStep, RunOutcome, SubagentConfig};
/// Outcome the driver reports once: the new session's id, or the handshake error.
type ReadyResult = Result<SessionId, AcpClientError>;
/// Shared, take-once slot holding the sender for [`ReadyResult`]. Whoever
/// reports first `take()`s the sender, so at most one outcome is delivered.
type ReadySlot = Arc<Mutex<Option<oneshot::Sender<ReadyResult>>>>;
/// A request delivered to the driver task (`run_driver`) over its command
/// channel. Every variant carries a oneshot sender on which the driver answers.
pub(crate) enum SubagentCommand {
    /// Send `text` to the active session via `send_prompt`; `reply` carries the
    /// result of that call.
    Prompt {
        text: String,
        reply: oneshot::Sender<Result<(), AcpClientError>>,
    },
    /// Read the next session message. While the read is pending, other
    /// commands (except `Cancel` and `Close`) are rejected with `DriverBusy`.
    ReadUpdate {
        reply: oneshot::Sender<Result<SessionMessage, AcpClientError>>,
    },
    /// Read until the agent reports a stop reason, concatenating agent text
    /// chunks into a `RunOutcome`.
    ReadToString {
        reply: oneshot::Sender<Result<RunOutcome, AcpClientError>>,
    },
    /// Send a `CancelNotification` for the active session; `reply` carries the
    /// result of sending it. Honoured even while a read is in progress.
    Cancel {
        reply: oneshot::Sender<Result<(), AcpClientError>>,
    },
    /// Stop serving commands. `ack` is signalled before the driver stops; an
    /// in-progress read then fails with `AcpClientError::Closed`.
    Close {
        ack: oneshot::Sender<()>,
    },
}
pub(crate) async fn run_driver(
cx: ConnectionTo<Agent>,
cmd_rx: futures::channel::mpsc::UnboundedReceiver<SubagentCommand>,
ready_slot: ReadySlot,
cfg: SubagentConfig,
mut child: tokio::process::Child,
stderr_task: JoinHandle<()>,
) -> Result<(), agent_client_protocol::Error> {
let cx_clone = cx.clone();
let init_result = tokio::time::timeout(
Duration::from_secs(cfg.handshake_timeout_secs),
cx.send_request(InitializeRequest::new(ProtocolVersion::V1))
.block_task(),
)
.await;
match init_result {
Ok(Ok(_)) => {}
Ok(Err(e)) => {
fire_handshake_err(&ready_slot, HandshakeStep::Initialize, e);
cleanup_child(&mut child, stderr_task).await;
return Ok(());
}
Err(_) => {
fire_handshake_err(
&ready_slot,
HandshakeStep::Initialize,
agent_client_protocol::Error::internal_error().data("initialize timed out"),
);
cleanup_child(&mut child, stderr_task).await;
return Ok(());
}
}
let session_cwd = cfg.effective_session_cwd();
let ready_for_body = ready_slot.clone();
let cmd_rx = Arc::new(Mutex::new(Some(cmd_rx)));
let cmd_rx_for_body = cmd_rx.clone();
let run_result = cx_clone
.build_session(session_cwd)
.block_task()
.run_until(async move |mut session: ActiveSession<'_, Agent>| {
if let Some(tx) = ready_for_body
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
{
let _ = tx.send(Ok(session.session_id().clone()));
}
let Some(mut cmd_rx) = cmd_rx_for_body
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
else {
return Ok(());
};
while let Some(cmd) = cmd_rx.next().await {
match cmd {
SubagentCommand::Prompt { text, reply } => {
let r = session.send_prompt(text).map_err(AcpClientError::Sdk);
let _ = reply.send(r);
}
SubagentCommand::ReadUpdate { reply } => {
let r = read_one_with_preemption(&mut session, &mut cmd_rx).await;
let _ = reply.send(r);
}
SubagentCommand::ReadToString { reply } => {
let r = drain_until_stop(&mut session, &mut cmd_rx).await;
let _ = reply.send(r);
}
SubagentCommand::Cancel { reply } => {
let r = session
.connection()
.send_notification(CancelNotification::new(
session.session_id().clone(),
))
.map_err(AcpClientError::Sdk);
let _ = reply.send(r);
}
SubagentCommand::Close { ack } => {
let _ = ack.send(());
break;
}
}
}
Ok(())
})
.await;
if let (Err(e), Some(tx)) = (
run_result.as_ref(),
ready_slot
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take(),
) {
let _ = tx.send(Err(AcpClientError::Handshake {
step: HandshakeStep::NewSession,
source: e.clone(),
}));
}
cleanup_child(&mut child, stderr_task).await;
run_result
}
/// Reads the next session message while keeping the command channel live.
///
/// Commands that arrive while waiting are handled as follows:
/// - `Cancel`: a `CancelNotification` is sent for the session and `reply`
///   carries the send result; the function then keeps waiting for the next
///   update.
/// - `Close`: `ack` is signalled and `AcpClientError::Closed` is returned.
/// - anything else: rejected with `AcpClientError::DriverBusy` (`send_busy`).
///
/// # Errors
///
/// `AcpClientError::Closed` on `Close` or when the command channel is closed;
/// `AcpClientError::Sdk` if reading the update fails.
pub(crate) async fn read_one_with_preemption(
    session: &mut ActiveSession<'_, Agent>,
    cmd_rx: &mut futures::channel::mpsc::UnboundedReceiver<SubagentCommand>,
) -> Result<SessionMessage, AcpClientError> {
    loop {
        tokio::select! {
            // Poll commands first so Cancel/Close win over a ready update.
            biased;
            maybe = cmd_rx.next() => match maybe {
                Some(SubagentCommand::Cancel { reply }) => {
                    let r = session
                        .connection()
                        .send_notification(CancelNotification::new(session.session_id().clone()))
                        .map_err(AcpClientError::Sdk);
                    let _ = reply.send(r);
                    // Go back to the select rather than awaiting `read_update`
                    // directly. If the agent is slow or hung after the cancel,
                    // a following `Close` must still be acked and honoured.
                }
                Some(SubagentCommand::Close { ack }) => {
                    let _ = ack.send(());
                    return Err(AcpClientError::Closed);
                }
                Some(other) => {
                    send_busy(other);
                }
                None => return Err(AcpClientError::Closed),
            },
            update = session.read_update() => {
                return update.map_err(AcpClientError::Sdk);
            }
        }
    }
}
/// Reads session messages until the agent reports a stop reason, returning
/// the concatenated agent text together with that stop reason.
///
/// Only `AgentMessageChunk` updates with text content are appended. Thought
/// chunks, tool calls, plans and other update variants are logged at
/// trace/debug level and ignored. Non-notification dispatches are passed to
/// `otherwise_ignore()` (presumably discarded; confirm against the SDK).
///
/// While waiting, the command channel stays live:
/// - `Cancel`: a `CancelNotification` is sent and `reply` carries the send
///   result; draining continues (presumably the agent then ends the turn with
///   a stop reason).
/// - `Close`: `ack` is signalled and `AcpClientError::Closed` is returned.
/// - anything else: rejected with `AcpClientError::DriverBusy` (`send_busy`).
///
/// # Errors
///
/// `AcpClientError::Closed` on `Close` or when the command channel is closed;
/// `AcpClientError::Sdk` if reading or dispatching an update fails.
pub(crate) async fn drain_until_stop(
    session: &mut ActiveSession<'_, Agent>,
    cmd_rx: &mut futures::channel::mpsc::UnboundedReceiver<SubagentCommand>,
) -> Result<RunOutcome, AcpClientError> {
    use agent_client_protocol::util::MatchDispatch;
    let mut text = String::new();
    loop {
        // Wait for the next update while servicing commands.
        let update = loop {
            tokio::select! {
                // Poll commands first so Cancel/Close win over a ready update.
                biased;
                maybe = cmd_rx.next() => match maybe {
                    Some(SubagentCommand::Cancel { reply }) => {
                        let r = session
                            .connection()
                            .send_notification(CancelNotification::new(session.session_id().clone()))
                            .map_err(AcpClientError::Sdk);
                        let _ = reply.send(r);
                        // Keep selecting instead of blocking on `read_update`,
                        // so a `Close` sent while the agent winds down (or
                        // hangs) is still acked and honoured.
                    }
                    Some(SubagentCommand::Close { ack }) => {
                        let _ = ack.send(());
                        return Err(AcpClientError::Closed);
                    }
                    Some(other) => {
                        send_busy(other);
                    }
                    None => return Err(AcpClientError::Closed),
                },
                upd = session.read_update() => {
                    break upd.map_err(AcpClientError::Sdk)?;
                }
            }
        };
        match update {
            SessionMessage::SessionMessage(dispatch) => {
                MatchDispatch::new(dispatch)
                    .if_notification(async |notif: SessionNotification| {
                        match notif.update {
                            SessionUpdate::AgentMessageChunk(ContentChunk {
                                content: ContentBlock::Text(t),
                                ..
                            }) => {
                                text.push_str(&t.text);
                            }
                            SessionUpdate::AgentThoughtChunk(_) => {
                                tracing::trace!(target: "acp.client.drain", "thought_chunk ignored");
                            }
                            SessionUpdate::ToolCall(ref tc) => {
                                tracing::trace!(
                                    target: "acp.client.drain",
                                    tool_call_id = ?tc.tool_call_id,
                                    "tool_call ignored"
                                );
                            }
                            SessionUpdate::Plan(_) => {
                                tracing::trace!(target: "acp.client.drain", "plan ignored");
                            }
                            _ => {
                                tracing::debug!(target: "acp.client.drain", "unknown SessionUpdate variant ignored");
                            }
                        }
                        Ok(())
                    })
                    .await
                    .otherwise_ignore()
                    .map_err(AcpClientError::Sdk)?;
            }
            SessionMessage::StopReason(reason) => {
                return Ok(RunOutcome {
                    text,
                    stop_reason: reason,
                });
            }
            _ => {
                tracing::debug!(target: "acp.client.drain", "unknown SessionMessage variant ignored");
            }
        }
    }
}
/// Reports a handshake failure at `step` through `slot`, if nothing has been
/// reported yet.
///
/// The sender is taken out of the slot, so a later report becomes a no-op. A
/// poisoned mutex is recovered rather than propagated. A send error (the
/// waiter is gone) is ignored.
fn fire_handshake_err(slot: &ReadySlot, step: HandshakeStep, source: agent_client_protocol::Error) {
    // Take the sender under the lock; the guard is released at the end of this
    // statement, before anything is sent.
    let pending = slot
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .take();
    // `None` means an outcome was already delivered, so this failure is dropped.
    if let Some(sender) = pending {
        let _ = sender.send(Err(AcpClientError::Handshake { step, source }));
    }
}
/// Tears down the subagent: requests a kill of `child`, waits for the process
/// to be reaped, then aborts `stderr_task`.
///
/// Best effort: errors from `start_kill` and from `wait` are deliberately
/// ignored.
async fn cleanup_child(child: &mut tokio::process::Child, stderr_task: JoinHandle<()>) {
    let _kill_requested = child.start_kill();
    let _exit_status = child.wait().await;
    stderr_task.abort();
}
fn send_busy(cmd: SubagentCommand) {
match cmd {
SubagentCommand::Prompt { reply, .. } => {
let _ = reply.send(Err(AcpClientError::DriverBusy));
}
SubagentCommand::ReadUpdate { reply } => {
let _ = reply.send(Err(AcpClientError::DriverBusy));
}
SubagentCommand::ReadToString { reply } => {
let _ = reply.send(Err(AcpClientError::DriverBusy));
}
SubagentCommand::Cancel { .. } => {
unreachable!("Cancel must be handled by the biased select arm, not send_busy");
}
SubagentCommand::Close { ack } => {
let _ = ack.send(());
}
}
}