use std::collections::HashMap;
use std::process::Stdio;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use serde_json::{json, Value};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{mpsc, oneshot, Mutex};
use crate::config::Config;
/// Handle to a running agent subprocess that speaks newline-delimited
/// JSON-RPC 2.0 over its stdin/stdout.
///
/// Field order matters for teardown: after any `Drop::drop` runs, Rust drops
/// fields in declaration order, so the `stdin` pipe is closed before `child`
/// is dropped.
pub(crate) struct Agent {
/// Write end of the agent's stdin; the mutex serializes whole messages so
/// concurrent writers cannot interleave bytes.
stdin: Mutex<ChildStdin>,
/// Next JSON-RPC request id to hand out.
next_id: AtomicI64,
/// Outstanding requests keyed by JSON-RPC id. Shared with the stdout reader
/// task, which completes the matching sender when a response arrives.
pending: Arc<Mutex<HashMap<i64, oneshot::Sender<Value>>>>,
/// The spawned child process, kept so it can be waited on during shutdown.
child: Mutex<Child>,
/// Process-group id used to signal the whole group (`-pgid`); 0 means the
/// pid was unavailable and no group signal will be sent.
#[cfg(unix)]
pgid: i32,
}
impl Agent {
    /// Serializes `msg` as one line of compact JSON and writes it to the
    /// agent's stdin.
    ///
    /// The stdin lock is held across both the write and the flush, so
    /// messages from concurrent callers are never interleaved on the wire.
    ///
    /// # Errors
    /// Returns any I/O error from writing to or flushing the pipe.
    async fn write_message(&self, msg: Value) -> Result<()> {
        let line = format!("{msg}\n");
        let mut stdin = self.stdin.lock().await;
        stdin.write_all(line.as_bytes()).await?;
        stdin.flush().await?;
        Ok(())
    }

    /// Sends a JSON-RPC request and waits for the response with the same id.
    ///
    /// Returns the response's `result` member, or `Value::Null` if it has none.
    ///
    /// # Errors
    /// - the request could not be written to the agent;
    /// - the reply sender was dropped before a response arrived
    ///   ("Agent closed before replying");
    /// - the response carries an `error` member ("Agent error: ...").
    pub(crate) async fn request(&self, method: &str, params: Value) -> Result<Value> {
        // Relaxed is enough: the counter only has to produce unique ids and
        // does not order any other memory accesses.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = oneshot::channel();
        // Register before writing so a very fast reply always finds its waiter.
        self.pending.lock().await.insert(id, tx);

        let sent = self
            .write_message(json!({
                "jsonrpc": "2.0",
                "id": id,
                "method": method,
                "params": params,
            }))
            .await;
        if let Err(err) = sent {
            // The request never reached the agent, so no reply will ever
            // consume this entry; remove it instead of leaking the sender.
            self.pending.lock().await.remove(&id);
            return Err(err);
        }

        let resp = rx.await.context("Agent closed before replying")?;
        if let Some(err) = resp.get("error") {
            bail!("Agent error: {err}");
        }
        Ok(resp.get("result").cloned().unwrap_or(Value::Null))
    }

    /// Sends a successful JSON-RPC response to a request the agent made.
    ///
    /// `id` is echoed back verbatim, so ids of any JSON type round-trip.
    ///
    /// # Errors
    /// Returns any I/O error from writing the message.
    pub(crate) async fn respond(&self, id: Value, result: Value) -> Result<()> {
        self.write_message(json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": result,
        }))
        .await
    }

    /// Sends a JSON-RPC notification (a message with no `id`, which expects no reply).
    ///
    /// # Errors
    /// Returns any I/O error from writing the message.
    pub(crate) async fn notify(&self, method: &str, params: Value) -> Result<()> {
        self.write_message(json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        }))
        .await
    }

    /// Best-effort shutdown of the agent. Every step ignores errors.
    ///
    /// 1. If `session_id` is given, sends a `session/cancel` notification.
    /// 2. Shuts down stdin so the agent sees EOF.
    /// 3. Waits up to 500 ms for the child to exit.
    /// 4. Kills the child's process group (Unix only; no-op elsewhere).
    pub(crate) async fn shutdown(&self, session_id: Option<&str>) {
        if let Some(sid) = session_id {
            let _ = self
                .notify("session/cancel", json!({ "sessionId": sid }))
                .await;
        }
        {
            // Scoped so the stdin lock is released before waiting on the child.
            let mut stdin = self.stdin.lock().await;
            let _ = stdin.shutdown().await;
        }
        let _ = tokio::time::timeout(std::time::Duration::from_millis(500), async {
            let mut child = self.child.lock().await;
            let _ = child.wait().await;
        })
        .await;
        // Also catches any descendants that outlived the direct child.
        self.kill_process_group();
    }

    /// Sends SIGKILL (signal 9) to the agent's whole process group.
    ///
    /// Does nothing when the pgid is unknown (0).
    #[cfg(unix)]
    fn kill_process_group(&self) {
        if self.pgid > 0 {
            crate::unix::send_signal(-self.pgid, 9);
        }
    }

    /// Process groups are not tracked off Unix, so there is nothing to kill here.
    #[cfg(not(unix))]
    fn kill_process_group(&self) {}
}
#[cfg(unix)]
impl Drop for Agent {
fn drop(&mut self) {
if self.pgid > 0 {
crate::unix::send_signal(-self.pgid, 9);
}
}
}
/// Spawns the agent command from `cfg` and wires up its stdio.
///
/// - **stdin** is owned by the returned [`Agent`] for sending messages.
/// - **stderr** is forwarded line by line to our stderr with an `[agent] ` prefix.
/// - **stdout** is read as newline-delimited JSON:
///   - Responses (messages with `result` or `error`) whose integer `id`
///     matches an outstanding request are delivered to that request.
///   - Every other parseable message goes to the returned receiver. This
///     covers notifications, agent-initiated requests and unmatched responses.
///   - Lines that are not valid JSON are skipped.
///
/// If `MEZAME_DEBUG_ACP` is set in the environment, every stdout line is
/// also echoed to stderr with an `[acp<-] ` prefix.
///
/// When stdout reaches EOF or fails to read, all outstanding requests fail with
/// "Agent closed before replying", and the returned receiver is closed.
///
/// # Errors
/// Fails if the process cannot be spawned.
///
/// # Panics
/// Only if a piped stdio handle is missing. That cannot happen, because all
/// three are configured with `Stdio::piped()` below.
pub(crate) async fn spawn_agent(cfg: &Config) -> Result<(Agent, mpsc::UnboundedReceiver<Value>)> {
    let mut cmd = Command::new(&cfg.agent_cmd);
    cmd.args(&cfg.agent_args)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .kill_on_drop(true);

    // SAFETY: the closure runs in the forked child before exec, so it must
    // only do async-signal-safe work. It calls `crate::unix::new_session`
    // (presumably a thin wrapper over setsid(2) — TODO confirm) and reads errno.
    #[cfg(unix)]
    unsafe {
        cmd.pre_exec(|| {
            if crate::unix::new_session() == -1 {
                return Err(std::io::Error::last_os_error());
            }
            Ok(())
        });
    }

    let mut child = cmd
        .spawn()
        .with_context(|| format!("Failed to spawn `{}`", cfg.agent_cmd))?;

    // The child's pid doubles as the process-group id used to signal the
    // group later. This assumes `new_session` made the child a group leader.
    // 0 means "unknown", and the group kill is then skipped.
    #[cfg(unix)]
    let pgid = child.id().map(|id| id as i32).unwrap_or(0);

    let stdin = child.stdin.take().expect("stdin");
    let stdout = child.stdout.take().expect("stdout");
    let stderr = child.stderr.take().expect("stderr");

    // Forward the agent's diagnostics to our own stderr.
    tokio::spawn(async move {
        let mut lines = BufReader::new(stderr).lines();
        while let Ok(Some(line)) = lines.next_line().await {
            eprintln!("[agent] {line}");
        }
    });

    let pending: Arc<Mutex<HashMap<i64, oneshot::Sender<Value>>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let (updates_tx, updates_rx) = mpsc::unbounded_channel();
    let debug_acp = std::env::var_os("MEZAME_DEBUG_ACP").is_some();

    // Route stdout messages to waiting requests or to the updates channel.
    let pending_reader = pending.clone();
    tokio::spawn(async move {
        let mut lines = BufReader::new(stdout).lines();
        while let Ok(Some(line)) = lines.next_line().await {
            if debug_acp {
                eprintln!("[acp<-] {line}");
            }
            let msg: Value = match serde_json::from_str(&line) {
                Ok(v) => v,
                Err(_) => continue,
            };
            let is_response = msg.get("result").is_some() || msg.get("error").is_some();
            if is_response {
                if let Some(id) = msg.get("id").and_then(Value::as_i64) {
                    if let Some(tx) = pending_reader.lock().await.remove(&id) {
                        let _ = tx.send(msg);
                        continue;
                    }
                }
            }
            let _ = updates_tx.send(msg);
        }
        // No more replies can arrive. `Agent` still holds its own Arc to the
        // map, so the senders would otherwise live on and every in-flight
        // `request` would await forever. Dropping them wakes each waiter
        // with a receive error instead.
        pending_reader.lock().await.clear();
    });

    Ok((
        Agent {
            stdin: Mutex::new(stdin),
            next_id: AtomicI64::new(1),
            pending,
            child: Mutex::new(child),
            #[cfg(unix)]
            pgid,
        },
        updates_rx,
    ))
}