use std::collections::HashMap;
use std::process::Stdio;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use serde_json::{json, Value};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{mpsc, oneshot, Mutex};
use crate::config::Config;
/// Handle to a spawned agent subprocess that exchanges newline-delimited
/// JSON-RPC 2.0 messages over the child's stdin/stdout.
///
/// Constructed by `spawn_agent`, which also starts the background tasks that
/// read the child's stdout and stderr.
pub(crate) struct Agent {
/// Child's stdin; every outgoing message locks this before writing, so
/// lines from concurrent callers are never interleaved.
stdin: Mutex<ChildStdin>,
/// Source of the numeric `id` attached to each outgoing request (starts at 1).
next_id: AtomicI64,
/// In-flight requests keyed by id; the stdout reader task removes the entry
/// and sends the full response message through it when a reply arrives.
pending: Arc<Mutex<HashMap<i64, oneshot::Sender<Value>>>>,
/// Receiver for every parsed stdout message that was not matched to a
/// pending request. `Some` until `take_updates` hands it out, then `None`.
updates_rx: Mutex<Option<mpsc::UnboundedReceiver<Value>>>,
/// The child process handle, awaited by `shutdown`.
child: Mutex<Child>
}
impl Agent {
    /// Hands out the receiver carrying every agent message that was not
    /// matched to an in-flight [`Agent::request`].
    ///
    /// This may be called only once per agent.
    ///
    /// # Panics
    ///
    /// Panics if the receiver was already taken, or if the internal lock is
    /// held by another caller at this exact moment.
    pub(crate) fn take_updates(&self) -> mpsc::UnboundedReceiver<Value> {
        self.updates_rx
            .try_lock()
            .expect("updates_rx already locked")
            .take()
            .expect("updates_rx already taken")
    }

    /// Writes `msg` to the agent's stdin as a single newline-terminated line
    /// and flushes it.
    ///
    /// The stdin lock is held for the whole write so that messages sent from
    /// concurrent tasks cannot interleave.
    async fn send(&self, msg: &Value) -> Result<()> {
        // Serialize before taking the lock to keep the critical section short.
        let line = format!("{msg}\n");
        let mut stdin = self.stdin.lock().await;
        stdin.write_all(line.as_bytes()).await?;
        stdin.flush().await?;
        Ok(())
    }

    /// Sends a JSON-RPC request and waits for the matching response.
    ///
    /// Returns the response's `result` member (`Value::Null` if absent).
    ///
    /// # Errors
    ///
    /// * the request could not be written to the agent's stdin;
    /// * the reply channel was dropped before a response arrived
    ///   (`"Agent closed before replying"`);
    /// * the response carried an `error` member (`"Agent error: ..."`).
    pub(crate) async fn request(&self, method: &str, params: Value) -> Result<Value> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = oneshot::channel();
        // Register the waiter *before* writing so a fast reply cannot arrive
        // while nobody is listening for this id.
        self.pending.lock().await.insert(id, tx);

        let msg = json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params });
        if let Err(err) = self.send(&msg).await {
            // The request never reached the agent, so no reply will ever come:
            // drop the registration instead of leaking it in the map.
            self.pending.lock().await.remove(&id);
            return Err(err);
        }

        let resp = rx.await.context("Agent closed before replying")?;
        if let Some(err) = resp.get("error") {
            bail!("Agent error: {err}");
        }
        Ok(resp.get("result").cloned().unwrap_or(Value::Null))
    }

    /// Sends a JSON-RPC success response for a request the agent made to us.
    ///
    /// `id` is echoed back verbatim.
    ///
    /// # Errors
    ///
    /// Fails if writing to or flushing the agent's stdin fails.
    pub(crate) async fn respond(&self, id: Value, result: Value) -> Result<()> {
        self.send(&json!({ "jsonrpc": "2.0", "id": id, "result": result }))
            .await
    }

    /// Sends a JSON-RPC notification (a message without an `id`, for which no
    /// reply is awaited).
    ///
    /// # Errors
    ///
    /// Fails if writing to or flushing the agent's stdin fails.
    pub(crate) async fn notify(&self, method: &str, params: Value) -> Result<()> {
        self.send(&json!({ "jsonrpc": "2.0", "method": method, "params": params }))
            .await
    }

    /// Best-effort shutdown of the agent; every failure is ignored.
    ///
    /// 1. If `session_id` is given, sends a `session/cancel` notification.
    /// 2. Shuts down the agent's stdin.
    /// 3. Waits up to 500 ms for the child process to exit.
    ///
    /// The child is not killed here if it outlives the grace period.
    pub(crate) async fn shutdown(&self, session_id: Option<&str>) {
        if let Some(sid) = session_id {
            let _ = self
                .notify("session/cancel", json!({ "sessionId": sid }))
                .await;
        }
        {
            // Scoped so the stdin lock is released before waiting on the child.
            let mut stdin = self.stdin.lock().await;
            let _ = stdin.shutdown().await;
        }
        let _ = tokio::time::timeout(std::time::Duration::from_millis(500), async {
            let mut child = self.child.lock().await;
            let _ = child.wait().await;
        })
        .await;
    }
}
/// Spawns the agent process described by `cfg` and wires up its stdio.
///
/// The child is started with piped stdin/stdout/stderr and `kill_on_drop`,
/// and two background tasks are launched:
///
/// * a stderr task that echoes each line to our stderr prefixed with `[agent]`;
/// * a stdout task that parses each line as JSON and routes it. A message
///   with a `result` or `error` member whose integer `id` matches a pending
///   request is delivered to that request. Every other parsed message is
///   forwarded to the updates channel. Lines that are not valid JSON are
///   skipped.
///
/// When the environment variable `MEZAME_DEBUG_ACP` is set, every stdout line
/// is also echoed to stderr prefixed with `[acp<-]`.
///
/// # Errors
///
/// Fails if the process cannot be spawned.
pub(crate) async fn spawn_agent(cfg: &Config) -> Result<Agent> {
    let mut cmd = Command::new(&cfg.agent_cmd);
    cmd.args(&cfg.agent_args)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .kill_on_drop(true);
    let mut child = cmd
        .spawn()
        .with_context(|| format!("Failed to spawn `{}`", cfg.agent_cmd))?;

    // All three handles were requested as piped above, so they are present.
    let stdin = child.stdin.take().expect("stdin");
    let stdout = child.stdout.take().expect("stdout");
    let stderr = child.stderr.take().expect("stderr");

    // Relay the agent's stderr so its diagnostics are visible.
    tokio::spawn(async move {
        let mut lines = BufReader::new(stderr).lines();
        while let Ok(Some(line)) = lines.next_line().await {
            eprintln!("[agent] {line}");
        }
    });

    let pending: Arc<Mutex<HashMap<i64, oneshot::Sender<Value>>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let (updates_tx, updates_rx) = mpsc::unbounded_channel();
    let debug_acp = std::env::var_os("MEZAME_DEBUG_ACP").is_some();
    let pending_reader = pending.clone();

    // Route everything the agent prints on stdout.
    tokio::spawn(async move {
        let mut lines = BufReader::new(stdout).lines();
        while let Ok(Some(line)) = lines.next_line().await {
            if debug_acp {
                eprintln!("[acp<-] {line}");
            }
            let msg: Value = match serde_json::from_str(&line) {
                Ok(v) => v,
                // Not JSON: ignore the line.
                Err(_) => continue,
            };
            let is_response = msg.get("result").is_some() || msg.get("error").is_some();
            if is_response {
                if let Some(id) = msg.get("id").and_then(Value::as_i64) {
                    if let Some(tx) = pending_reader.lock().await.remove(&id) {
                        // The requester may have gone away; that is fine.
                        let _ = tx.send(msg);
                        continue;
                    }
                }
            }
            // Anything not claimed by a pending request goes to the updates
            // consumer; a closed receiver is ignored.
            let _ = updates_tx.send(msg);
        }
        // stdout reached EOF or failed: no further replies can arrive. The
        // `Agent` keeps its own handle to this map, so the senders would
        // otherwise stay alive and every in-flight request would wait forever.
        // Dropping them wakes each waiter with an error instead.
        // NOTE(review): a request registered after this point can still hang;
        // closing that gap needs a "closed" flag shared with `Agent`.
        pending_reader.lock().await.clear();
    });

    Ok(Agent {
        stdin: Mutex::new(stdin),
        next_id: AtomicI64::new(1),
        pending,
        updates_rx: Mutex::new(Some(updates_rx)),
        child: Mutex::new(child),
    })
}