use std::process::Stdio;
use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use scc::HashMap as SccHashMap;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{broadcast, mpsc, oneshot};
use agent_os_sidecar::protocol::{
self, EventPayload, NativeFrameCodec, NativePayloadCodec, OwnershipScope, ProtocolFrame,
RequestFrame, RequestPayload, ResponsePayload, SidecarRequestFrame, SidecarRequestPayload,
SidecarResponseFrame, SidecarResponsePayload, DEFAULT_MAX_FRAME_BYTES,
};
use crate::error::ClientError;
/// Capacity of the broadcast channel that fans inbound sidecar events out to subscribers.
const EVENT_CHANNEL_CAPACITY: usize = 4_096;
/// Environment variable naming the sidecar executable; `agent-os-sidecar` is used when unset.
const SIDECAR_BIN_ENV: &str = "AGENT_OS_SIDECAR_BIN";
/// Host-side handler for a sidecar-initiated request.
///
/// A handler receives the request payload together with its ownership scope and returns a
/// boxed `'static` future that resolves to the response payload or a [`ClientError`].
/// Handlers are shared behind an `Arc`, so they must be `Send + Sync`.
pub(crate) type SidecarCallback = Arc<
    dyn Fn(SidecarRequestPayload, OwnershipScope)
            -> futures::future::BoxFuture<'static, Result<SidecarResponsePayload, ClientError>>
        + Send
        + Sync,
>;
/// Host-side connection to a spawned sidecar process speaking the native framed protocol.
///
/// Outbound frames are queued on `writer_tx` and written to the child's stdin by a background
/// writer task. Inbound frames are read from the child's stdout by a reader task and routed to
/// pending requests, event subscribers, or registered callbacks.
pub struct SidecarTransport {
    /// The sidecar child process (`spawn` builds it with `kill_on_drop(true)`).
    pub(crate) child: parking_lot::Mutex<Option<Child>>,
    /// One-shot response senders for in-flight host requests, keyed by request id.
    pub(crate) pending: SccHashMap<protocol::RequestId, oneshot::Sender<ResponsePayload>>,
    /// Counter behind `next_request_id`; initialized to 1 in `spawn` and incremented per id.
    pub(crate) request_counter: AtomicI64,
    /// Counter behind `next_sidecar_request_id`; initialized to -1 in `spawn` and decremented
    /// per id, so its ids stay on the opposite side of zero from `request_counter`'s.
    pub(crate) sidecar_request_counter: AtomicI64,
    /// Frame-size limit handed to the codec for both encoding and decoding.
    pub(crate) max_frame_bytes: AtomicUsize,
    /// Broadcasts each inbound event, paired with its ownership scope, to all subscribers.
    pub(crate) event_tx: broadcast::Sender<(OwnershipScope, EventPayload)>,
    /// Handlers for sidecar-initiated requests, keyed by the strings from `sidecar_request_key`.
    pub(crate) callbacks: SccHashMap<&'static str, SidecarCallback>,
    /// Queue of encoded frames consumed by the stdin writer task.
    pub(crate) writer_tx: mpsc::UnboundedSender<Vec<u8>>,
}
impl SidecarTransport {
pub(crate) async fn spawn() -> Result<Arc<Self>, ClientError> {
let bin = std::env::var(SIDECAR_BIN_ENV).unwrap_or_else(|_| "agent-os-sidecar".to_string());
let mut child = Command::new(&bin)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.kill_on_drop(true)
.spawn()
.map_err(|error| {
ClientError::Sidecar(format!("failed to spawn sidecar '{bin}': {error}"))
})?;
let stdin = child
.stdin
.take()
.ok_or_else(|| ClientError::Sidecar("sidecar stdin was not piped".to_string()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| ClientError::Sidecar("sidecar stdout was not piped".to_string()))?;
let (writer_tx, writer_rx) = mpsc::unbounded_channel();
let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
let transport = Arc::new(Self {
child: parking_lot::Mutex::new(Some(child)),
pending: SccHashMap::new(),
request_counter: AtomicI64::new(1),
sidecar_request_counter: AtomicI64::new(-1),
max_frame_bytes: AtomicUsize::new(DEFAULT_MAX_FRAME_BYTES),
event_tx,
callbacks: SccHashMap::new(),
writer_tx,
});
tokio::spawn(run_writer(stdin, writer_rx));
tokio::spawn(run_reader(Arc::downgrade(&transport), stdout));
Ok(transport)
}
pub(crate) fn next_request_id(&self) -> protocol::RequestId {
self.request_counter.fetch_add(1, Ordering::SeqCst)
}
pub(crate) fn next_sidecar_request_id(&self) -> protocol::RequestId {
self.sidecar_request_counter.fetch_sub(1, Ordering::SeqCst)
}
pub(crate) async fn request(
&self,
ownership: OwnershipScope,
payload: RequestPayload,
) -> Result<ResponsePayload, ClientError> {
let request_id = self.next_request_id();
let frame = ProtocolFrame::Request(RequestFrame::new(request_id, ownership, payload));
let bytes = self.encode_frame(&frame)?;
let (tx, rx) = oneshot::channel();
let _ = self.pending.insert(request_id, tx);
if self.writer_tx.send(bytes).is_err() {
self.pending.remove(&request_id);
return Err(ClientError::Sidecar("sidecar transport closed".to_string()));
}
rx.await
.map_err(|_| ClientError::Sidecar("sidecar transport disconnected".to_string()))
}
pub(crate) fn subscribe_events(&self) -> broadcast::Receiver<(OwnershipScope, EventPayload)> {
self.event_tx.subscribe()
}
pub(crate) fn register_callback(&self, key: &'static str, callback: SidecarCallback) {
let _ = self.callbacks.insert(key, callback);
}
fn encode_frame(&self, frame: &ProtocolFrame) -> Result<Vec<u8>, ClientError> {
let codec = NativeFrameCodec::with_payload_codec(
self.max_frame_bytes.load(Ordering::Relaxed),
NativePayloadCodec::Bare,
);
Ok(codec.encode(frame)?)
}
async fn handle_frame(&self, frame: ProtocolFrame) {
match frame {
ProtocolFrame::Response(response) => {
match self.pending.remove(&response.request_id) {
Some((_, tx)) => {
let _ = tx.send(response.payload);
}
None => {
tracing::warn!(request_id = response.request_id, "response for unknown request id")
}
}
}
ProtocolFrame::Event(event) => {
let _ = self.event_tx.send((event.ownership, event.payload));
}
ProtocolFrame::SidecarRequest(request) => self.dispatch_sidecar_request(request).await,
ProtocolFrame::SidecarResponse(_) | ProtocolFrame::Request(_) => {
tracing::warn!("unexpected inbound frame on host transport")
}
}
}
async fn dispatch_sidecar_request(&self, frame: SidecarRequestFrame) {
let key = sidecar_request_key(&frame.payload);
let callback = self.callbacks.read(&key, |_, value| value.clone());
match callback {
Some(callback) => match callback(frame.payload, frame.ownership.clone()).await {
Ok(payload) => {
let response = ProtocolFrame::SidecarResponse(SidecarResponseFrame::new(
frame.request_id,
frame.ownership,
payload,
));
if let Ok(bytes) = self.encode_frame(&response) {
let _ = self.writer_tx.send(bytes);
}
}
Err(error) => tracing::warn!(?error, key, "sidecar callback failed"),
},
None => tracing::warn!(key, "no callback registered for sidecar request"),
}
}
fn fail_all_pending(&self) {
self.pending.clear();
}
}
fn sidecar_request_key(payload: &SidecarRequestPayload) -> &'static str {
match payload {
SidecarRequestPayload::ToolInvocation(_) => "tool_invocation",
SidecarRequestPayload::PermissionRequest(_) => "permission_request",
SidecarRequestPayload::AcpRequest(_) => "acp_request",
SidecarRequestPayload::JsBridgeCall(_) => "js_bridge_call",
}
}
/// Forwards encoded frames from the transport's queue to the sidecar's stdin.
///
/// Each buffer is written in full and then flushed before the next one is taken. The task
/// ends when every sender for the queue has been dropped, or when a write or flush fails.
async fn run_writer(mut stdin: ChildStdin, mut writer_rx: mpsc::UnboundedReceiver<Vec<u8>>) {
    loop {
        let Some(frame) = writer_rx.recv().await else {
            return;
        };
        // `&&` short-circuits, so the flush is only attempted after a successful write.
        let delivered = stdin.write_all(&frame).await.is_ok() && stdin.flush().await.is_ok();
        if !delivered {
            return;
        }
    }
}
async fn run_reader(transport: Weak<SidecarTransport>, mut stdout: ChildStdout) {
loop {
let mut length_buf = [0u8; 4];
if stdout.read_exact(&mut length_buf).await.is_err() {
break;
}
let length = u32::from_be_bytes(length_buf) as usize;
let mut frame_bytes = vec![0u8; 4 + length];
frame_bytes[..4].copy_from_slice(&length_buf);
if stdout.read_exact(&mut frame_bytes[4..]).await.is_err() {
break;
}
let Some(transport) = transport.upgrade() else {
break;
};
let codec = NativeFrameCodec::with_payload_codec(
transport.max_frame_bytes.load(Ordering::Relaxed),
NativePayloadCodec::Bare,
);
match codec.decode(&frame_bytes) {
Ok(frame) => transport.handle_frame(frame).await,
Err(error) => tracing::warn!(?error, "failed to decode sidecar frame"),
}
}
if let Some(transport) = transport.upgrade() {
transport.fail_all_pending();
}
}