use std::{collections::HashMap, process::Stdio, sync::Arc};
use anyhow::{Context, Result};
use async_trait::async_trait;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
process::{Child, Command},
sync::{Mutex, broadcast, mpsc},
time::{Duration, timeout},
};
use crate::acp::{methods, notification::*, types::*};
/// A 60-second timeout for ACP calls. The timed RPC helper in this file uses [`LONG_TIMEOUT`].
pub const ACP_TIMEOUT: Duration = Duration::from_secs(60);
/// Five-minute limit applied while waiting for the response to a timed ACP request.
pub const LONG_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Callbacks that answer the requests an ACP agent sends to this client.
///
/// `handle_agent_request` dispatches each incoming agent request to one of these methods and
/// writes the returned value back to the agent as the JSON-RPC `result`. An `Err` is sent back
/// as a JSON-RPC error (code `-32603`) carrying the error's message.
#[async_trait]
pub trait AcpCallbackHandler: Send + Sync {
    /// Chooses an outcome for the permission request of tool call `tool_call_id`,
    /// picking among the agent-supplied `options` (`methods::SESSION_REQUEST_PERMISSION`).
    async fn handle_request_permission(
        &self,
        session_id: &SessionId,
        tool_call_id: &str,
        options: Vec<PermissionOption>,
    ) -> RequestPermissionOutcome;
    /// Returns the text contents of the file at `path` (`methods::FS_READ_TEXT_FILE`).
    async fn handle_read_text_file(&self, session_id: &SessionId, path: &str) -> Result<String>;
    /// Writes `contents` to the file at `path` (`methods::FS_WRITE_TEXT_FILE`).
    async fn handle_write_text_file(
        &self,
        session_id: &SessionId,
        path: &str,
        contents: &str,
    ) -> Result<()>;
    /// Starts a terminal running `command` with `args` and returns its terminal id
    /// (`methods::TERMINAL_CREATE`). Both may be absent in the agent's request.
    async fn handle_terminal_create(
        &self,
        session_id: &SessionId,
        command: Option<&str>,
        args: Option<Vec<String>>,
    ) -> Result<String>;
    /// Returns output and exit information for `terminal_id` (`methods::TERMINAL_OUTPUT`).
    async fn handle_terminal_output(
        &self,
        session_id: &SessionId,
        terminal_id: &str,
    ) -> Result<TerminalOutputResponse>;
    /// Kills the command running in `terminal_id` (`methods::TERMINAL_KILL`).
    async fn handle_terminal_kill(&self, session_id: &SessionId, terminal_id: &str) -> Result<()>;
    /// Releases `terminal_id` and its resources (`methods::TERMINAL_RELEASE`).
    async fn handle_terminal_release(
        &self,
        session_id: &SessionId,
        terminal_id: &str,
    ) -> Result<()>;
    /// Waits for the command in `terminal_id` to exit and returns its exit code
    /// (`methods::TERMINAL_WAIT_FOR_EXIT`). `None` means no exit code is available;
    /// NOTE(review): exact meaning depends on the implementation (e.g. unknown terminal or
    /// signal termination) — confirm per implementor.
    async fn handle_terminal_wait_for_exit(
        &self,
        session_id: &SessionId,
        terminal_id: &str,
    ) -> Result<Option<i32>>;
}
pub struct DefaultAcpHandler;
#[async_trait]
impl AcpCallbackHandler for DefaultAcpHandler {
async fn handle_request_permission(
&self,
_session_id: &SessionId,
_tool_call_id: &str,
options: Vec<PermissionOption>,
) -> RequestPermissionOutcome {
for opt in options {
if matches!(
opt.kind,
PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways
) {
return RequestPermissionOutcome::Selected {
option_id: opt.option_id,
};
}
}
RequestPermissionOutcome::Cancelled
}
async fn handle_read_text_file(&self, _session_id: &SessionId, path: &str) -> Result<String> {
tokio::fs::read_to_string(path)
.await
.context("Failed to read file")
}
async fn handle_write_text_file(
&self,
_session_id: &SessionId,
path: &str,
contents: &str,
) -> Result<()> {
tokio::fs::write(path, contents)
.await
.context("Failed to write file")
}
async fn handle_terminal_create(
&self,
_session_id: &SessionId,
_command: Option<&str>,
_args: Option<Vec<String>>,
) -> Result<String> {
Err(anyhow::anyhow!(
"Terminal operations not implemented in DefaultAcpHandler. Implement custom AcpCallbackHandler to enable terminal support."
))
}
async fn handle_terminal_output(
&self,
_session_id: &SessionId,
_terminal_id: &str,
) -> Result<TerminalOutputResponse> {
Err(anyhow::anyhow!(
"Terminal operations not implemented in DefaultAcpHandler"
))
}
async fn handle_terminal_kill(
&self,
_session_id: &SessionId,
_terminal_id: &str,
) -> Result<()> {
Err(anyhow::anyhow!(
"Terminal operations not implemented in DefaultAcpHandler"
))
}
async fn handle_terminal_release(
&self,
_session_id: &SessionId,
_terminal_id: &str,
) -> Result<()> {
Err(anyhow::anyhow!(
"Terminal operations not implemented in DefaultAcpHandler"
))
}
async fn handle_terminal_wait_for_exit(
&self,
_session_id: &SessionId,
_terminal_id: &str,
) -> Result<Option<i32>> {
Err(anyhow::anyhow!(
"Terminal operations not implemented in DefaultAcpHandler"
))
}
}
/// Callback handler that auto-approves permissions, serves local file I/O, and backs ACP
/// terminals with real child processes.
///
/// Only the `terminals` map of the embedded [`AcpState`] is used by this handler; the other
/// fields are placeholders required to construct the state.
pub struct DefaultAcpHandlerWithTerminal {
    state: Arc<Mutex<AcpState>>,
}

impl DefaultAcpHandlerWithTerminal {
    /// Creates a handler with no open terminals.
    pub fn new() -> Self {
        Self {
            state: Arc::new(Mutex::new(AcpState {
                next_id: 1,
                session_id: None,
                pending_requests: HashMap::new(),
                handler: Arc::new(DefaultAcpHandler),
                capabilities: None,
                agent_info: None,
                config_options: Vec::new(),
                models: None,
                terminals: HashMap::new(),
            })),
        }
    }
}

impl Default for DefaultAcpHandlerWithTerminal {
    /// Equivalent to [`DefaultAcpHandlerWithTerminal::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl AcpCallbackHandler for DefaultAcpHandlerWithTerminal {
async fn handle_request_permission(
&self,
_session_id: &SessionId,
_tool_call_id: &str,
options: Vec<PermissionOption>,
) -> RequestPermissionOutcome {
for opt in options {
if matches!(
opt.kind,
PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways
) {
return RequestPermissionOutcome::Selected {
option_id: opt.option_id,
};
}
}
RequestPermissionOutcome::Cancelled
}
async fn handle_read_text_file(&self, _session_id: &SessionId, path: &str) -> Result<String> {
tokio::fs::read_to_string(path)
.await
.context("Failed to read file")
}
async fn handle_write_text_file(
&self,
_session_id: &SessionId,
path: &str,
contents: &str,
) -> Result<()> {
tokio::fs::write(path, contents)
.await
.context("Failed to write file")
}
async fn handle_terminal_create(
&self,
_session_id: &SessionId,
command: Option<&str>,
_args: Option<Vec<String>>,
) -> Result<String> {
let shell = command.unwrap_or("powershell.exe");
let child = Command::new(shell)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.context("Failed to spawn terminal process")?;
let terminal_id = format!("terminal-{}", uuid::Uuid::new_v4());
let mut state = self.state.lock().await;
state.terminals.insert(terminal_id.clone(), child);
tracing::info!(terminal_id = %terminal_id, "terminal created");
Ok(terminal_id)
}
async fn handle_terminal_output(
&self,
_session_id: &SessionId,
terminal_id: &str,
) -> Result<TerminalOutputResponse> {
let mut state = self.state.lock().await;
let child = state
.terminals
.get_mut(terminal_id)
.ok_or_else(|| anyhow::anyhow!("terminal not found: {}", terminal_id))?;
let mut stdout_buf = String::new();
let mut stderr_buf = String::new();
if let Some(stdout) = child.stdout.as_mut() {
let mut reader = BufReader::new(stdout);
reader.read_line(&mut stdout_buf).await.ok();
}
if let Some(stderr) = child.stderr.as_mut() {
let mut reader = BufReader::new(stderr);
reader.read_line(&mut stderr_buf).await.ok();
}
let exit = match child.try_wait()? {
Some(status) => status.code(),
None => None,
};
tracing::debug!(terminal_id = %terminal_id, "terminal output read");
Ok(TerminalOutputResponse {
exit,
stdout: stdout_buf,
stderr: stderr_buf,
})
}
async fn handle_terminal_kill(&self, _session_id: &SessionId, terminal_id: &str) -> Result<()> {
let mut state = self.state.lock().await;
if let Some(mut child) = state.terminals.remove(terminal_id) {
child.kill().await.ok();
tracing::info!(terminal_id = %terminal_id, "terminal killed");
}
Ok(())
}
async fn handle_terminal_release(
&self,
_session_id: &SessionId,
terminal_id: &str,
) -> Result<()> {
let mut state = self.state.lock().await;
if let Some(mut child) = state.terminals.remove(terminal_id) {
child.wait().await.ok();
tracing::info!(terminal_id = %terminal_id, "terminal released");
}
Ok(())
}
async fn handle_terminal_wait_for_exit(
&self,
_session_id: &SessionId,
terminal_id: &str,
) -> Result<Option<i32>> {
let mut state = self.state.lock().await;
if let Some(child) = state.terminals.get_mut(terminal_id) {
let status = child.wait().await.context("Failed to wait for terminal")?;
let code = status.code();
tracing::info!(terminal_id = %terminal_id, exit_code = ?code, "terminal exited");
return Ok(code);
}
Ok(None)
}
}
/// Session-level events published to `AcpClient::subscribe_events` subscribers.
///
/// Produced by `handle_session_update` from the agent's `session/update` notifications.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// Text of an `agent_message_chunk` update.
    AgentMessageChunk { content: String },
    /// Text of an `agent_thought_chunk` update.
    AgentThoughtChunk { content: String },
    /// A tool call reported with status `pending`.
    ToolCallStarted {
        tool_call_id: String,
        title: Option<String>,
        kind: ToolKind,
    },
    /// A tool call reported with status `in_progress`.
    ToolCallInProgress { tool_call_id: String },
    /// A tool call reported with status `completed`; `result` is the update's `result.text`.
    ToolCallCompleted {
        tool_call_id: String,
        result: Option<String>,
    },
    /// A tool call reported with status `failed`; `error` defaults to "Unknown error".
    ToolCallFailed { tool_call_id: String, error: String },
    /// The session's mode changed to `mode_id`.
    ModeChanged { mode_id: String },
    /// The agent sent a new set of session configuration options.
    ConfigOptionUpdated { options: Vec<SessionConfigOption> },
    /// Session title / last-updated timestamp as reported by the agent.
    SessionInfoUpdated {
        title: Option<String>,
        updated_at: Option<String>,
    },
    /// Usage counters from a `usage_update` (missing values are reported as 0).
    UsageUpdated { used: u32, size: u32 },
    /// The agent's current list of available commands.
    AvailableCommandsUpdated { commands: Vec<AvailableCommand> },
}
/// Mutable client-side state shared behind `Arc<Mutex<_>>`.
struct AcpState {
    /// Next JSON-RPC request id; starts at 1 and is incremented for every request sent.
    next_id: i64,
    /// Active session, set by `create_session` / `load_session`.
    session_id: Option<SessionId>,
    // Not read or written anywhere in this file, hence the `dead_code` allowance.
    // NOTE(review): presumably reserved for request multiplexing — confirm or remove.
    #[allow(dead_code)]
    pending_requests: HashMap<i64, mpsc::Sender<Result<serde_json::Value>>>,
    // Stored for reference only; the subprocess task works with its own clone of the handler,
    // so this field is never read (hence `dead_code`).
    #[allow(dead_code)]
    handler: Arc<dyn AcpCallbackHandler>,
    /// Agent capabilities from the `initialize` response.
    capabilities: Option<AgentCapabilities>,
    /// Agent implementation info from the `initialize` response.
    agent_info: Option<Implementation>,
    /// Configuration options from `session/new` or the latest `set_model` response.
    config_options: Vec<SessionConfigOption>,
    /// Model information from the `session/new` response.
    models: Option<SessionModels>,
    /// Terminal id -> child process; only `DefaultAcpHandlerWithTerminal` uses this map.
    terminals: HashMap<String, Child>,
}
/// Commands sent from `AcpClient` to the task that owns the agent subprocess.
#[derive(Debug)]
enum SubprocessCmd {
    /// Write `request` (one serialized JSON-RPC message, without the trailing newline) to the
    /// agent. The matching response, or an error, is delivered on `response_tx`.
    /// For notifications the receiving half is already dropped by the sender.
    SendRequest {
        request: String,
        response_tx: mpsc::Sender<Result<serde_json::Value>>,
    },
    /// Stop the subprocess loop.
    Shutdown,
}
/// Events the subprocess task could report back to the client.
// Never constructed in this file, hence `dead_code`.
// NOTE(review): presumably kept for an event-driven redesign of `run_subprocess` — confirm or remove.
#[derive(Debug)]
#[allow(dead_code)]
enum SubprocessEvent {
    /// Response to the client request with JSON-RPC id `id`.
    Response {
        id: i64,
        result: Result<serde_json::Value>,
    },
    /// A parsed `session/update` for `session_id`.
    SessionUpdate {
        session_id: SessionId,
        event: SessionEvent,
    },
    /// A raw request sent by the agent to the client.
    AgentRequest {
        request: serde_json::Value,
    },
}
/// Handle to an ACP agent subprocess. All fields are shared, so clones talk to the same agent.
#[derive(Clone)]
pub struct AcpClient {
    /// Channel to the subprocess task; `None` means the task is considered gone.
    cmd_tx: Arc<Mutex<Option<mpsc::Sender<SubprocessCmd>>>>,
    /// Request ids, session and capability information.
    state: Arc<Mutex<AcpState>>,
    /// Agent message text accumulated since the last request was sent.
    collected_content: Arc<Mutex<String>>,
    /// Broadcast channel for [`SessionEvent`]s.
    event_tx: broadcast::Sender<SessionEvent>,
    /// Receives user-facing notifications derived from session updates.
    notification_manager: Arc<Mutex<NotificationManager>>,
}
impl AcpClient {
pub async fn spawn(command: &str, args: &[&str]) -> Result<Self> {
Self::spawn_with_handler(
command,
args,
Arc::new(DefaultAcpHandler),
Arc::new(Mutex::new(NotificationManager::new())),
)
.await
}
pub async fn spawn_with_handler(
command: &str,
args: &[&str],
handler: Arc<dyn AcpCallbackHandler>,
notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<Self> {
let command_owned = command.to_string();
let args_owned: Vec<String> = args.iter().map(|s| s.to_string()).collect();
let (cmd_tx, cmd_rx) = mpsc::channel(32);
let cmd_tx_clone = cmd_tx.clone();
let collected = Arc::new(Mutex::new(String::new()));
let collected_clone = collected.clone();
let handler_clone = handler.clone();
let notification_manager_clone = notification_manager.clone();
let (event_tx, _) = broadcast::channel(256);
let event_tx_clone = event_tx.clone();
tokio::spawn(async move {
run_subprocess(
&command_owned,
&args_owned,
cmd_rx,
collected_clone,
handler_clone,
event_tx_clone,
notification_manager_clone,
)
.await;
});
tokio::time::sleep(Duration::from_millis(200)).await;
Ok(Self {
cmd_tx: Arc::new(Mutex::new(Some(cmd_tx_clone))),
state: Arc::new(Mutex::new(AcpState {
next_id: 1,
session_id: None,
pending_requests: HashMap::new(),
handler,
capabilities: None,
agent_info: None,
config_options: Vec::new(),
models: None,
terminals: HashMap::new(),
})),
collected_content: collected,
event_tx,
notification_manager,
})
}
pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
self.event_tx.subscribe()
}
pub fn add_notification_sink(&self, sink: Arc<dyn NotificationSink>) {
if let Ok(mut guard) = self.notification_manager.try_lock() {
guard.add_sink(sink);
}
}
pub async fn get_collected_content(&self) -> String {
self.collected_content.lock().await.clone()
}
pub async fn clear_collected_content(&self) {
self.collected_content.lock().await.clear();
}
pub async fn session_id(&self) -> Option<SessionId> {
self.state.lock().await.session_id.clone()
}
pub async fn capabilities(&self) -> Option<AgentCapabilities> {
self.state.lock().await.capabilities.clone()
}
pub async fn agent_info(&self) -> Option<Implementation> {
self.state.lock().await.agent_info.clone()
}
pub async fn config_options(&self) -> Vec<SessionConfigOption> {
self.state.lock().await.config_options.clone()
}
pub async fn models(&self) -> Option<SessionModels> {
self.state.lock().await.models.clone()
}
pub async fn supports_load_session(&self) -> bool {
self.state
.lock()
.await
.capabilities
.as_ref()
.and_then(|c| c.load_session)
.unwrap_or(false)
}
pub async fn supports_images(&self) -> bool {
self.state
.lock()
.await
.capabilities
.as_ref()
.and_then(|c| c.prompt_capabilities.as_ref())
.map(|p| p.image)
.unwrap_or(false)
}
pub async fn supports_audio(&self) -> bool {
self.state
.lock()
.await
.capabilities
.as_ref()
.and_then(|c| c.prompt_capabilities.as_ref())
.map(|p| p.audio)
.unwrap_or(false)
}
pub async fn supports_embedded_context(&self) -> bool {
self.state
.lock()
.await
.capabilities
.as_ref()
.and_then(|c| c.prompt_capabilities.as_ref())
.map(|p| p.embedded_context)
.unwrap_or(false)
}
pub async fn initialize(
&self,
client_name: &str,
client_version: &str,
) -> Result<InitializeResponse> {
let params = serde_json::json!({
"protocolVersion": PROTOCOL_VERSION,
"clientInfo": {"name": client_name, "version": client_version},
"clientCapabilities": {
"fs": { "readTextFile": true, "writeTextFile": true },
"terminal": true
}
});
let resp = self.rpc(methods::INITIALIZE, params).await?;
tracing::info!(response = ?resp, "ACP initialize response");
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
let init_resp: InitializeResponse =
serde_json::from_value(result).context("Failed to parse initialize response")?;
let mut state = self.state.lock().await;
state.capabilities = Some(init_resp.agent_capabilities.clone());
state.agent_info = Some(init_resp.agent_info.clone());
Ok(init_resp)
}
pub async fn create_session(
&self,
cwd: &str,
model: Option<&str>,
mcp_servers: Option<Vec<McpServerConfig>>,
) -> Result<NewSessionResponse> {
let mut params = serde_json::json!({
"cwd": cwd,
"mcpServers": mcp_servers.unwrap_or_default()
});
if let Some(m) = model {
params["modelId"] = serde_json::json!(m);
tracing::info!(model = %m, "Adding modelId to session/new request");
}
let resp = self.rpc(methods::SESSION_NEW, params).await?;
tracing::info!(response = ?resp, "ACP session/new response");
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
let session_resp: NewSessionResponse =
serde_json::from_value(result).context("Failed to parse session/new response")?;
let mut state = self.state.lock().await;
state.session_id = Some(session_resp.session_id.clone());
state.config_options = session_resp.config_options.clone().unwrap_or_default();
state.models = session_resp.models.clone();
if let Some(ref models) = session_resp.models {
tracing::info!("Available models: {:?}", models.available_models);
}
if let Some(ref opts) = session_resp.config_options {
tracing::info!("Config options available: {}", opts.len());
for opt in opts {
tracing::info!(
" - {}: {} (current: {})",
opt.id,
opt.name,
opt.current_value
);
}
}
Ok(session_resp)
}
pub async fn load_session(
&self,
session_id: &SessionId,
cwd: Option<&str>,
mcp_servers: Option<Vec<McpServerConfig>>,
) -> Result<LoadSessionResponse> {
let mut params = serde_json::json!({
"sessionId": session_id,
});
if let Some(cwd) = cwd {
params["cwd"] = serde_json::json!(cwd);
}
params["mcpServers"] = serde_json::json!(mcp_servers.unwrap_or_default());
let resp = self.rpc(methods::SESSION_LOAD, params).await?;
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
let load_resp: LoadSessionResponse =
serde_json::from_value(result).context("Failed to parse session/load response")?;
self.state.lock().await.session_id = Some(session_id.to_string());
Ok(load_resp)
}
pub async fn send_prompt(&self, prompt: &str) -> Result<PromptResponse> {
let session_id = self.session_id().await.context("No active session")?;
let params = serde_json::json!({
"sessionId": session_id,
"prompt": [{"type": "text", "text": prompt}]
});
let resp = self
.rpc_no_timeout(methods::SESSION_PROMPT, params)
.await
.map_err(|e| {
if e.to_string().contains("timeout")
|| e.to_string().contains("Subprocess task died")
|| e.to_string().contains("Channel closed")
{
anyhow::anyhow!(
"OpenCode 执行超时或进程异常退出。请检查 OpenCode 是否正常运行后重试。"
)
} else {
e
}
})?;
tracing::info!("=== send_prompt raw response ===");
tracing::info!(
"Full response: {}",
serde_json::to_string(&resp).unwrap_or_default()
);
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
tracing::info!("=== send_prompt result ===");
tracing::info!(
"Result: {}",
serde_json::to_string(&result).unwrap_or_default()
);
let prompt_resp: PromptResponse =
serde_json::from_value(result.clone()).context("Failed to parse prompt response")?;
tracing::info!("=== send_prompt parsed ===");
tracing::info!("stop_reason: {:?}", prompt_resp.stop_reason);
tracing::info!("usage: {:?}", prompt_resp.usage);
if let Some(ref r) = prompt_resp.result {
tracing::info!("content blocks: {}", r.content.len());
for (i, block) in r.content.iter().enumerate() {
match block {
crate::acp::types::ContentBlock::Text { text } => {
tracing::info!(" [{}] Text: {}", i, text);
}
crate::acp::types::ContentBlock::Image { .. } => {
tracing::info!(" [{}] Image", i);
}
crate::acp::types::ContentBlock::Resource { .. } => {
tracing::info!(" [{}] Resource", i);
}
crate::acp::types::ContentBlock::ResourceLink { .. } => {
tracing::info!(" [{}] ResourceLink", i);
}
}
}
if let Some(ref calls) = r.tool_calls {
tracing::info!("tool_calls: {} calls", calls.len());
for (i, call) in calls.iter().enumerate() {
tracing::info!(" [{}] tool_call: id={}, name={}", i, call.id, call.name);
}
}
}
Ok(prompt_resp)
}
pub async fn send_prompt_with_content(
&self,
prompt: Vec<ContentBlock>,
) -> Result<PromptResponse> {
let session_id = self.session_id().await.context("No active session")?;
let params = serde_json::json!({
"sessionId": session_id,
"prompt": prompt
});
let resp = self.rpc_no_timeout(methods::SESSION_PROMPT, params).await?;
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
serde_json::from_value(result).context("Failed to parse prompt response")
}
pub async fn cancel_session(&self) -> Result<()> {
let session_id = self.session_id().await.context("No active session")?;
let params = serde_json::json!({
"sessionId": session_id
});
self.send_notification(methods::SESSION_CANCEL, params)
.await?;
Ok(())
}
pub async fn list_sessions(&self, cwd: Option<&str>) -> Result<ListSessionsResponse> {
let params = if let Some(cwd) = cwd {
serde_json::json!({ "cwd": cwd })
} else {
serde_json::json!({})
};
let resp = self.rpc(methods::SESSION_LIST, params).await?;
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
serde_json::from_value(result).context("Failed to parse session/list response")
}
pub async fn set_mode(&self, mode_id: &str) -> Result<()> {
let session_id = self.session_id().await.context("No active session")?;
let params = serde_json::json!({
"sessionId": session_id,
"modeId": mode_id
});
let _resp = self.rpc(methods::SESSION_SET_MODE, params).await?;
Ok(())
}
pub async fn set_model(&self, model_id: &str) -> Result<Vec<SessionConfigOption>> {
let session_id = self.session_id().await.context("No active session")?;
let params = serde_json::json!({
"sessionId": session_id,
"configId": "model",
"value": model_id
});
tracing::info!(model_id = %model_id, "Calling session/set_config_option");
let resp = self.rpc(methods::SESSION_SET_CONFIG_OPTION, params).await?;
tracing::info!(response = ?resp, "set_model response");
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
let config_resp: SetSessionConfigOptionResponse =
serde_json::from_value(result).context("Failed to parse config option response")?;
let options = config_resp.config_options.unwrap_or_default();
self.state.lock().await.config_options = options.clone();
Ok(options)
}
pub async fn set_config_option(
&self,
config_id: &str,
value: &str,
) -> Result<Vec<SessionConfigOption>> {
let session_id = self.session_id().await.context("No active session")?;
let params = serde_json::json!({
"sessionId": session_id,
"configId": config_id,
"value": value
});
let resp = self.rpc(methods::SESSION_SET_CONFIG_OPTION, params).await?;
let result = resp
.get("result")
.cloned()
.unwrap_or(serde_json::Value::Null);
let config_resp: SetSessionConfigOptionResponse =
serde_json::from_value(result).context("Failed to parse config option response")?;
Ok(config_resp.config_options.unwrap_or_default())
}
pub async fn authenticate(
&self,
method_id: &str,
credentials: Option<serde_json::Value>,
) -> Result<()> {
let params = serde_json::json!({
"methodId": method_id,
"credentials": credentials
});
let _resp = self.rpc(methods::AUTHENTICATE, params).await?;
Ok(())
}
pub async fn shutdown(self) -> Result<()> {
let guard = self.cmd_tx.lock().await;
if let Some(tx) = guard.as_ref() {
let _ = tx.send(SubprocessCmd::Shutdown).await;
}
Ok(())
}
async fn rpc(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
let id = {
let mut state = self.state.lock().await;
let id = state.next_id;
state.next_id += 1;
id
};
let guard = self.cmd_tx.lock().await;
let tx = guard.as_ref().context("Subprocess task died")?;
let (resp_tx, mut resp_rx) = mpsc::channel(1);
let request = serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params
}))?;
tracing::debug!(method, id, request = %request, "ACP sending request");
tx.send(SubprocessCmd::SendRequest {
request,
response_tx: resp_tx,
})
.await
.context("Failed to send request")?;
let resp = timeout(LONG_TIMEOUT, resp_rx.recv())
.await
.context("RPC timeout")?
.context("Channel closed")?;
tracing::debug!(method, id, response = ?resp, "ACP received response");
resp
}
async fn rpc_no_timeout(
&self,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value> {
let id = {
let mut state = self.state.lock().await;
let id = state.next_id;
state.next_id += 1;
id
};
let guard = self.cmd_tx.lock().await;
let tx = guard.as_ref().context("Subprocess task died")?;
let (resp_tx, mut resp_rx) = mpsc::channel(1);
let request = serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params
}))?;
tracing::debug!(method, id, request = %request, "ACP sending request (no timeout)");
tx.send(SubprocessCmd::SendRequest {
request,
response_tx: resp_tx,
})
.await
.context("Failed to send request")?;
let resp = resp_rx
.recv()
.await
.context("Channel closed - subprocess died")?;
tracing::debug!(method, id, response = ?resp, "ACP received response");
resp
}
async fn send_notification(&self, method: &str, params: serde_json::Value) -> Result<()> {
let guard = self.cmd_tx.lock().await;
let tx = guard.as_ref().context("Subprocess task died")?;
let notification = serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params
}))?;
let (resp_tx, _) = mpsc::channel(1);
tx.send(SubprocessCmd::SendRequest {
request: notification,
response_tx: resp_tx,
})
.await
.context("Failed to send notification")?;
Ok(())
}
}
async fn run_subprocess(
command: &str,
args: &[String],
mut cmd_rx: mpsc::Receiver<SubprocessCmd>,
collected_content: Arc<Mutex<String>>,
handler: Arc<dyn AcpCallbackHandler>,
event_tx: broadcast::Sender<SessionEvent>,
notification_manager: Arc<Mutex<NotificationManager>>,
) {
let mut child = match Command::new(command)
.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()
{
Ok(c) => c,
Err(e) => {
tracing::error!("Failed to spawn ACP subprocess: {}", e);
return;
}
};
let mut stdin = match child.stdin.take() {
Some(s) => s,
None => return,
};
let stdout = match child.stdout.take() {
Some(s) => s,
None => return,
};
let mut reader = BufReader::new(stdout);
tracing::info!("ACP subprocess started: {} {:?}", command, args);
loop {
tokio::select! {
cmd = cmd_rx.recv() => {
match cmd {
Some(SubprocessCmd::SendRequest { request, response_tx }) => {
collected_content.lock().await.clear();
let request_id = serde_json::from_str::<serde_json::Value>(&request)
.ok()
.and_then(|v| v.get("id").and_then(|i| i.as_i64()));
let mut combined = request.as_bytes().to_vec();
combined.push(b'\n');
if stdin.write_all(&combined).await.is_err() {
let _ = response_tx.send(Err(anyhow::anyhow!("Write error"))).await;
break;
}
if stdin.flush().await.is_err() {
let _ = response_tx.send(Err(anyhow::anyhow!("Flush error"))).await;
break;
}
tracing::debug!("ACP request sent: {}", request);
let mut line_buf = String::new();
loop {
line_buf.clear();
tokio::select! {
result = reader.read_line(&mut line_buf) => {
match result {
Ok(0) => {
let _ = response_tx.send(Err(anyhow::anyhow!("EOF"))).await;
break;
}
Ok(_) => {
let line = line_buf.trim();
if line.is_empty() {
continue;
}
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(line) {
let method_field = msg.get("method").and_then(|m| m.as_str());
if let Some(method) = method_field {
tracing::info!("ACP incoming method: {} | msg: {}", method, line);
} else if msg.get("id").is_some() {
tracing::debug!("ACP response: {}", line);
} else {
tracing::debug!("ACP message: {}", line);
}
if method_field == Some(methods::SESSION_UPDATE) {
handle_session_update(&msg, &collected_content, &event_tx, ¬ification_manager).await;
continue;
}
if let Some(method) = method_field {
tracing::info!("Handling agent request: {}", method);
if handle_agent_request(&mut stdin, &msg, method, &handler).await {
continue;
}
}
let resp_id = msg.get("id").and_then(|i| i.as_i64());
if resp_id == request_id {
let _ = response_tx.send(Ok(msg)).await;
break;
}
}
}
Err(e) => {
let _ = response_tx.send(Err(anyhow::anyhow!("{}", e))).await;
break;
}
}
}
_ = tokio::time::sleep(Duration::from_millis(50)) => {
}
}
}
}
Some(SubprocessCmd::Shutdown) => {
tracing::info!("ACP subprocess shutting down");
break;
}
None => {
break;
}
}
}
_ = tokio::time::sleep(Duration::from_millis(10)) => {}
}
}
}
async fn handle_session_update(
msg: &serde_json::Value,
collected_content: &Arc<Mutex<String>>,
event_tx: &broadcast::Sender<SessionEvent>,
notification_manager: &Arc<Mutex<NotificationManager>>,
) {
let params = msg.get("params");
let session_id = params
.and_then(|p| p.get("sessionId"))
.and_then(|s| s.as_str())
.map(String::from);
let update = params.and_then(|p| p.get("update"));
if let Some(update) = update {
let session_update = update.get("sessionUpdate").and_then(|s| s.as_str());
match session_update {
Some("plan") => {
tracing::info!("ACP plan received");
if let Some(entries) = update.get("entries").and_then(|e| e.as_array()) {
for entry in entries {
if let Some(content) = entry.get("content").and_then(|c| c.as_str()) {
tracing::info!("Plan entry: {}", content);
}
}
}
}
Some("user_message") | Some("user_message_chunk") => {
tracing::debug!("ACP user_message received");
}
Some("agent_message") => {
if let Some(content) = update.get("content") {
extract_text_content(content, collected_content).await;
}
}
Some("agent_message_chunk") => {
if let Some(content) = update.get("content") {
extract_text_content(content, collected_content).await;
if let Some(text) = content.get("text").and_then(|t| t.as_str()) {
let _ = event_tx.send(SessionEvent::AgentMessageChunk {
content: text.to_string(),
});
}
}
}
Some("agent_thought_chunk") => {
if let Some(content) = update.get("content") {
if let Some(text) = content.get("text").and_then(|t| t.as_str()) {
tracing::debug!("ACP thought: {}", text);
let _ = event_tx.send(SessionEvent::AgentThoughtChunk {
content: text.to_string(),
});
}
}
}
Some("tool_call") => {
let tool_call_id = update
.get("toolCallId")
.and_then(|t| t.as_str())
.unwrap_or("?")
.to_string();
let title = update
.get("title")
.and_then(|t| t.as_str())
.map(String::from);
let kind = parse_tool_kind(update.get("kind").and_then(|k| k.as_str()));
let status = update.get("status").and_then(|s| s.as_str());
tracing::info!(
"ACP tool_call: {} - {:?} ({:?})",
tool_call_id,
title,
status
);
match status {
Some("pending") => {
let _ = event_tx.send(SessionEvent::ToolCallStarted {
tool_call_id: tool_call_id.clone(),
title: title.clone(),
kind: kind.clone(),
});
let notif = Notification::new(
NotificationPriority::Medium,
"工具开始",
&format!("执行工具: {}", title.clone().unwrap_or_default()),
);
if let Ok(nm) = notification_manager.try_lock() {
nm.send(¬if.with_session_id(session_id.clone().unwrap_or_default()))
.await;
}
}
Some("in_progress") => {
let _ = event_tx.send(SessionEvent::ToolCallInProgress { tool_call_id });
}
Some("completed") => {
let result = update
.get("result")
.and_then(|r| r.get("text"))
.and_then(|t| t.as_str())
.map(String::from);
let _ = event_tx.send(SessionEvent::ToolCallCompleted {
tool_call_id: tool_call_id.clone(),
result: result.clone(),
});
let notif = Notification::new(
NotificationPriority::Medium,
"工具完成",
&format!("工具执行完成: {}", title.clone().unwrap_or_default()),
);
if let Ok(nm) = notification_manager.try_lock() {
nm.send(¬if.with_session_id(session_id.clone().unwrap_or_default()))
.await;
}
}
Some("failed") => {
let error = update
.get("error")
.and_then(|e| e.as_str())
.unwrap_or("Unknown error")
.to_string();
let _ = event_tx.send(SessionEvent::ToolCallFailed {
tool_call_id: tool_call_id.clone(),
error: error.clone(),
});
let notif = Notification::new(
NotificationPriority::High,
"工具失败",
&format!(
"工具执行失败: {} - {}",
title.clone().unwrap_or_default(),
error
),
)
.with_burn_after_read();
if let Ok(nm) = notification_manager.try_lock() {
nm.send(¬if.with_session_id(session_id.clone().unwrap_or_default()))
.await;
}
}
_ => {}
}
}
Some("mode_change") => {
if let Some(mode_id) = update.get("modeId").and_then(|m| m.as_str()) {
tracing::info!("ACP mode_change: {}", mode_id);
let _ = event_tx.send(SessionEvent::ModeChanged {
mode_id: mode_id.to_string(),
});
}
}
Some("config_option_update") => {
if let Some(options) = update.get("configOptions").and_then(|o| o.as_array()) {
let config_options: Vec<SessionConfigOption> = options
.iter()
.filter_map(|v| serde_json::from_value(v.clone()).ok())
.collect();
let _ = event_tx.send(SessionEvent::ConfigOptionUpdated {
options: config_options,
});
}
}
Some("session_info_update") => {
let title = update
.get("title")
.and_then(|t| t.as_str())
.map(String::from);
let updated_at = update
.get("updatedAt")
.and_then(|t| t.as_str())
.map(String::from);
let _ = event_tx.send(SessionEvent::SessionInfoUpdated {
title: title.clone(),
updated_at: updated_at.clone(),
});
let notif = Notification::new(
NotificationPriority::High,
"会话已创建",
&format!(
"Session ID: {}\n标题: {}",
session_id.clone().unwrap_or_default(),
title.clone().unwrap_or_default()
),
)
.with_burn_after_read();
if let Ok(nm) = notification_manager.try_lock() {
nm.send(¬if).await;
}
}
Some("usage_update") => {
let used = update.get("used").and_then(|u| u.as_u64()).unwrap_or(0) as u32;
let size = update.get("size").and_then(|s| s.as_u64()).unwrap_or(0) as u32;
let _ = event_tx.send(SessionEvent::UsageUpdated { used, size });
}
Some("available_commands_update") => {
if let Some(commands) = update.get("availableCommands").and_then(|c| c.as_array()) {
let cmds: Vec<AvailableCommand> = commands
.iter()
.filter_map(|v| serde_json::from_value(v.clone()).ok())
.collect();
let _ =
event_tx.send(SessionEvent::AvailableCommandsUpdated { commands: cmds });
}
}
_ => {
tracing::debug!("ACP session_update: {:?}", session_update);
}
}
}
}
/// Appends the text carried by `content` to the shared transcript buffer.
///
/// `content` may be a single block with a string `text` field or an array of such blocks.
/// Entries without a string `text` are skipped. The buffer is only locked when there is
/// text to append.
async fn extract_text_content(content: &serde_json::Value, collected_content: &Arc<Mutex<String>>) {
    let single = content.get("text").and_then(|t| t.as_str());
    let listed = content
        .as_array()
        .into_iter()
        .flatten()
        .filter_map(|item| item.get("text").and_then(|t| t.as_str()));
    let pieces: Vec<&str> = single.into_iter().chain(listed).collect();
    if pieces.is_empty() {
        return;
    }
    let mut buffer = collected_content.lock().await;
    for piece in pieces {
        buffer.push_str(piece);
    }
}
/// Maps an ACP tool `kind` string onto [`ToolKind`]; missing or unrecognised kinds become
/// [`ToolKind::Other`].
fn parse_tool_kind(kind: Option<&str>) -> ToolKind {
    let Some(name) = kind else {
        return ToolKind::Other;
    };
    match name {
        "read" => ToolKind::Read,
        "edit" => ToolKind::Edit,
        "delete" => ToolKind::Delete,
        "move" => ToolKind::Move,
        "search" => ToolKind::Search,
        "execute" => ToolKind::Execute,
        "think" => ToolKind::Think,
        "fetch" => ToolKind::Fetch,
        _ => ToolKind::Other,
    }
}
/// Answers a JSON-RPC request that the agent sent to this client.
///
/// Known methods are dispatched to `handler`. When the message carries an `id`, the result is
/// written back to the agent on `stdin`; a handler error is sent as JSON-RPC error `-32603`
/// with the full error chain as message. An unknown method that carries an `id` is answered
/// with `-32601` (method not found) so the agent does not wait forever.
///
/// Returns `true` if the message was consumed here. Returns `false` only for an unknown
/// method without an `id` (an unhandled notification).
async fn handle_agent_request(
    stdin: &mut tokio::process::ChildStdin,
    msg: &serde_json::Value,
    method: &str,
    handler: &Arc<dyn AcpCallbackHandler>,
) -> bool {
    // JSON-RPC ids may be numbers or strings; echo back exactly what the agent sent.
    let request_id = msg.get("id").filter(|id| !id.is_null()).cloned();
    let params = msg
        .get("params")
        .cloned()
        .unwrap_or(serde_json::Value::Null);
    let session_id = str_param(&params, "sessionId").to_string();

    let result: Result<serde_json::Value> = match method {
        methods::SESSION_REQUEST_PERMISSION => {
            // ACP's ToolCallUpdate uses `toolCallId`; `id` is accepted as a fallback.
            let tool_call_id = params
                .get("toolCall")
                .and_then(|t| t.get("toolCallId").or_else(|| t.get("id")))
                .and_then(|i| i.as_str())
                .unwrap_or("");
            let options: Vec<PermissionOption> = params
                .get("options")
                .and_then(|o| o.as_array())
                .map(|arr| {
                    arr.iter()
                        .filter_map(|v| serde_json::from_value(v.clone()).ok())
                        .collect()
                })
                .unwrap_or_default();
            let outcome = handler
                .handle_request_permission(&session_id, tool_call_id, options)
                .await;
            serde_json::to_value(RequestPermissionResponse { outcome }).map_err(Into::into)
        }
        methods::FS_READ_TEXT_FILE => handler
            .handle_read_text_file(&session_id, str_param(&params, "path"))
            .await
            .and_then(|contents| {
                serde_json::to_value(ReadTextFileResponse { contents }).map_err(Into::into)
            }),
        methods::FS_WRITE_TEXT_FILE => handler
            .handle_write_text_file(
                &session_id,
                str_param(&params, "path"),
                str_param(&params, "contents"),
            )
            .await
            .and_then(|_| serde_json::to_value(WriteTextFileResponse {}).map_err(Into::into)),
        methods::TERMINAL_CREATE => {
            let command = params.get("command").and_then(|c| c.as_str());
            let args = params.get("args").and_then(|a| a.as_array()).map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            });
            handler
                .handle_terminal_create(&session_id, command, args)
                .await
                .and_then(|terminal_id| {
                    serde_json::to_value(CreateTerminalResponse { terminal_id })
                        .map_err(Into::into)
                })
        }
        methods::TERMINAL_OUTPUT => handler
            .handle_terminal_output(&session_id, str_param(&params, "terminalId"))
            .await
            .and_then(|resp| serde_json::to_value(resp).map_err(Into::into)),
        methods::TERMINAL_KILL => handler
            .handle_terminal_kill(&session_id, str_param(&params, "terminalId"))
            .await
            .and_then(|_| serde_json::to_value(KillTerminalResponse {}).map_err(Into::into)),
        methods::TERMINAL_RELEASE => handler
            .handle_terminal_release(&session_id, str_param(&params, "terminalId"))
            .await
            .and_then(|_| serde_json::to_value(ReleaseTerminalResponse {}).map_err(Into::into)),
        methods::TERMINAL_WAIT_FOR_EXIT => handler
            .handle_terminal_wait_for_exit(&session_id, str_param(&params, "terminalId"))
            .await
            .and_then(|exit| {
                serde_json::to_value(WaitForTerminalExitResponse { exit }).map_err(Into::into)
            }),
        _ => {
            let Some(id) = request_id else {
                return false;
            };
            tracing::warn!("Unsupported agent request: {}", method);
            let reply = serde_json::json!({
                "jsonrpc": "2.0",
                "id": id,
                "error": { "code": -32601, "message": format!("Method not found: {}", method) }
            });
            write_json_line(stdin, &reply).await;
            return true;
        }
    };

    if let Some(id) = request_id {
        let response = match result {
            Ok(value) => serde_json::json!({
                "jsonrpc": "2.0",
                "id": id,
                "result": value
            }),
            Err(e) => serde_json::json!({
                "jsonrpc": "2.0",
                "id": id,
                "error": { "code": -32603, "message": format!("{:#}", e) }
            }),
        };
        write_json_line(stdin, &response).await;
    }
    true
}

/// Returns the string field `key` of `params`, or `""` when it is missing or not a string.
fn str_param<'a>(params: &'a serde_json::Value, key: &str) -> &'a str {
    params.get(key).and_then(|v| v.as_str()).unwrap_or("")
}

/// Writes `message` to the agent as one newline-terminated JSON line.
///
/// Failures are logged rather than propagated: the agent read loop keeps running and the
/// next stdin write will surface a broken pipe.
async fn write_json_line(stdin: &mut tokio::process::ChildStdin, message: &serde_json::Value) {
    let mut line = match serde_json::to_vec(message) {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::warn!("Failed to serialize ACP response: {}", e);
            return;
        }
    };
    line.push(b'\n');
    if let Err(e) = stdin.write_all(&line).await {
        tracing::warn!("Failed to write ACP response: {}", e);
    } else if let Err(e) = stdin.flush().await {
        tracing::warn!("Failed to flush ACP response: {}", e);
    } else {
        tracing::debug!(
            "ACP response sent: {}",
            String::from_utf8_lossy(&line).trim_end()
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a permission option with no label.
    fn option(id: &str, kind: PermissionOptionKind) -> PermissionOption {
        PermissionOption {
            option_id: id.to_string(),
            kind,
            label: None,
        }
    }

    /// The default handler must pick the first allow option, skipping rejections.
    #[tokio::test]
    async fn test_default_handler_permission() {
        let handler = DefaultAcpHandler;
        let session = SessionId::from("session-1");
        let outcome = handler
            .handle_request_permission(
                &session,
                "call_1",
                vec![
                    option("deny", PermissionOptionKind::RejectOnce),
                    option("allow", PermissionOptionKind::AllowOnce),
                    option("always", PermissionOptionKind::AllowAlways),
                ],
            )
            .await;
        assert!(matches!(
            outcome,
            RequestPermissionOutcome::Selected { ref option_id } if option_id == "allow"
        ));
    }

    /// Without any allow option (or with no options at all) the request is cancelled.
    #[tokio::test]
    async fn test_default_handler_permission_cancels_without_allow() {
        let handler = DefaultAcpHandler;
        let session = SessionId::from("session-1");
        let rejected = handler
            .handle_request_permission(
                &session,
                "call_1",
                vec![option("deny", PermissionOptionKind::RejectOnce)],
            )
            .await;
        assert!(matches!(rejected, RequestPermissionOutcome::Cancelled));
        let empty = handler
            .handle_request_permission(&session, "call_1", Vec::new())
            .await;
        assert!(matches!(empty, RequestPermissionOutcome::Cancelled));
    }

    #[test]
    fn test_parse_tool_kind() {
        assert!(matches!(parse_tool_kind(Some("read")), ToolKind::Read));
        assert!(matches!(parse_tool_kind(Some("execute")), ToolKind::Execute));
        assert!(matches!(parse_tool_kind(Some("unknown")), ToolKind::Other));
        assert!(matches!(parse_tool_kind(None), ToolKind::Other));
    }

    #[test]
    fn test_session_event_variants() {
        let event = SessionEvent::AgentMessageChunk {
            content: "test".to_string(),
        };
        assert!(matches!(event, SessionEvent::AgentMessageChunk { .. }));
        let event = SessionEvent::ToolCallStarted {
            tool_call_id: "call_1".to_string(),
            title: Some("Reading file".to_string()),
            kind: ToolKind::Read,
        };
        assert!(matches!(event, SessionEvent::ToolCallStarted { .. }));
    }
}