use std::sync::{Arc, Mutex};
use std::time::Duration;
use serde_json::{json, Value};
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use crate::core::error::AgentError;
use crate::core::types::{AgentEvent, CliTool};
use crate::rpc::id::IdGen;
use crate::rpc::message::{RpcNotification, RpcRequest};
use crate::rpc::pending::PendingRequests;
use super::host::{AcpHostAdapter, AcpHostHandler, FilesystemAcpHandler};
use super::protocol::{
extract_token_usage, AgentCapabilities, ClientCapabilities, ClientInfo, ContentBlock,
FsCapabilities, InitializeParams, McpServerConfig, SessionCancelParams, SessionLoadParams,
SessionLoadResult, SessionNewParams, SessionPromptParams,
};
use super::reader::acp_reader_loop;
use super::spawn::AcpProcess;
#[derive(Debug, thiserror::Error)]
pub enum AcpError {
#[error("Process spawn failed: {source}")]
Spawn {
#[source]
source: std::io::Error,
},
#[error("Stdin write failed: {source}")]
Write {
#[source]
source: std::io::Error,
},
#[error("JSON error: {source}")]
Json {
#[source]
source: serde_json::Error,
},
#[error("Handshake timed out (step={step})")]
HandshakeTimeout { step: &'static str },
#[error("Handshake failed: {message}")]
HandshakeFailed { message: String },
#[error("Agent returned RPC error: {0}")]
Agent(#[from] crate::rpc::message::RpcError),
#[error("Request timed out (method={method})")]
Timeout { method: String },
#[error("Session not initialized — call session_new() first")]
NoSession,
#[error("Session closed while awaiting response")]
SessionClosed,
}
/// Configuration accepted by `AcpSession::spawn`.
pub struct AcpSessionOptions {
    /// Handler for agent-initiated host requests. When `None`, `spawn` falls
    /// back to a `FilesystemAcpHandler` with `allowed_roots: None`.
    pub host_handler: Option<Box<dyn AcpHostHandler>>,
    /// Capacity of the broadcast channel that carries `AgentEvent`s.
    pub channel_capacity: usize,
    /// Timeout applied to each handshake request (`initialize`, `session/new`).
    pub handshake_timeout: Duration,
    /// Timeout applied to `session/prompt` and `session/load` requests.
    pub prompt_timeout: Duration,
    /// MCP server configurations forwarded to the agent in `session/new`.
    pub mcp_servers: Vec<McpServerConfig>,
}
impl Default for AcpSessionOptions {
    /// 256-slot event channel, 30 s handshake timeout, 120 s prompt timeout,
    /// the fallback host handler and no MCP servers.
    fn default() -> Self {
        let handshake_timeout = Duration::from_secs(30);
        let prompt_timeout = Duration::from_secs(120);
        AcpSessionOptions {
            host_handler: None,
            channel_capacity: 256,
            handshake_timeout,
            prompt_timeout,
            mcp_servers: Vec::new(),
        }
    }
}
/// A live ACP session with a spawned agent CLI process.
///
/// Instances are created by `AcpSession::spawn`, which also runs the
/// `initialize` / `session/new` handshake. Events are published on a broadcast
/// channel obtained through `AcpSession::subscribe`.
pub struct AcpSession {
/// Identifier generated locally by `generate_session_id` at spawn time.
local_session_id: String,
/// Session id assigned by the agent (via `session/new` or `session/load`);
/// `None` until recorded. Behind a mutex so it can be updated through `&self`.
acp_session_id: Arc<tokio::sync::Mutex<Option<String>>>,
tool: CliTool,
/// Sender side of the event stream; a clone is handed to the reader task.
tx: broadcast::Sender<AgentEvent>,
/// Agent process, shared between the blocking reader task and the writers.
process: Arc<Mutex<AcpProcess>>,
/// Outstanding JSON-RPC requests awaiting a response; shared with the reader task.
pending: PendingRequests,
id_gen: Arc<IdGen>,
/// Handle to the `spawn_blocking` task running `acp_reader_loop`.
reader_task: JoinHandle<()>,
/// Timeout used for `session/prompt` and `session/load` requests.
prompt_timeout: Duration,
/// Capabilities decoded from the agent's `initialize` response.
agent_caps: AgentCapabilities,
}
impl AcpSession {
pub async fn spawn(
tool: CliTool,
working_dir: &std::path::Path,
options: AcpSessionOptions,
) -> Result<Self, AcpError> {
let mcp_servers = options.mcp_servers;
let options = AcpSessionOptions { mcp_servers: vec![], ..options };
let proc = AcpProcess::spawn(tool, working_dir, &[])
.map_err(|e| AcpError::Spawn { source: e })?;
let local_session_id = generate_session_id();
let (tx, _) = broadcast::channel::<AgentEvent>(options.channel_capacity);
let _ = tx.send(AgentEvent::Started {
session_id: local_session_id.clone(),
});
let process = Arc::new(Mutex::new(proc));
let handler: Arc<dyn crate::rpc::handler::HostHandler> = match options.host_handler {
Some(h) => Arc::new(AcpHostAdapter(Arc::from(h))),
None => Arc::new(AcpHostAdapter(Arc::new(FilesystemAcpHandler { allowed_roots: None }))),
};
let pending = PendingRequests::new();
let id_gen = Arc::new(IdGen::new());
let reader_process = Arc::clone(&process);
let reader_tx = tx.clone();
let reader_pending = pending.clone();
let reader_task = tokio::task::spawn_blocking(move || {
acp_reader_loop(reader_process, reader_tx, reader_pending, handler);
});
let acp_session_id = Arc::new(tokio::sync::Mutex::new(None::<String>));
let mut session = Self {
local_session_id: local_session_id.clone(),
acp_session_id: Arc::clone(&acp_session_id),
tool,
tx: tx.clone(),
process,
pending,
id_gen,
reader_task,
prompt_timeout: options.prompt_timeout,
agent_caps: AgentCapabilities::default(),
};
let init_params = InitializeParams {
protocol_version: 1,
client_capabilities: ClientCapabilities {
fs: FsCapabilities { read_text_file: true, write_text_file: true },
terminal: true,
},
client_info: ClientInfo {
name: "gate4agent",
title: Some("Gate4Agent"),
version: env!("CARGO_PKG_VERSION"),
},
};
let caps: AgentCapabilities = session
.rpc_call_typed("initialize", json!(init_params), options.handshake_timeout, true)
.await
.map_err(|e| match e {
AcpError::Timeout { .. } => AcpError::HandshakeTimeout { step: "initialize" },
AcpError::Agent(rpc_err) => AcpError::HandshakeFailed {
message: rpc_err.to_string(),
},
other => other,
})?;
session.agent_caps = caps;
let new_params = SessionNewParams {
cwd: working_dir.to_str().unwrap_or(".").to_string(),
mcp_servers,
};
let new_result = session
.rpc_call("session/new", Some(json!(new_params)), options.handshake_timeout)
.await
.map_err(|e| match e {
AcpError::Timeout { .. } => AcpError::HandshakeTimeout { step: "session/new" },
AcpError::Agent(rpc_err) => AcpError::HandshakeFailed {
message: rpc_err.to_string(),
},
other => other,
})?;
let acp_sid = new_result
.get("sessionId")
.and_then(|v| v.as_str())
.unwrap_or(&local_session_id)
.to_owned();
{
let mut guard = acp_session_id.lock().await;
*guard = Some(acp_sid.clone());
}
let _ = tx.send(AgentEvent::SessionStart {
session_id: acp_sid,
model: "".to_string(),
tools: vec![],
});
Ok(session)
}
pub async fn prompt(&self, text: &str) -> Result<(), AcpError> {
let session_id = {
let guard = self.acp_session_id.lock().await;
guard.clone().ok_or(AcpError::NoSession)?
};
let params = SessionPromptParams {
session_id,
prompt: vec![ContentBlock::Text { text: text.to_owned() }],
};
let result = self
.rpc_call("session/prompt", Some(json!(params)), self.prompt_timeout)
.await?;
let stop_reason = result
.get("stopReason")
.and_then(|v| v.as_str())
.unwrap_or("end_turn")
.to_owned();
let (input_tokens, output_tokens) = extract_token_usage(&result);
let _ = self.tx.send(AgentEvent::TurnComplete { input_tokens, output_tokens });
let _ = self.tx.send(AgentEvent::SessionEnd {
result: stop_reason,
cost_usd: None,
is_error: false,
});
Ok(())
}
/// Asks the agent to cancel the current turn by sending the `session/cancel`
/// notification. No response is awaited.
///
/// # Errors
///
/// [`AcpError::NoSession`] if no ACP session id is recorded; otherwise any
/// serialization or write error from sending the notification.
pub async fn cancel(&self) -> Result<(), AcpError> {
    let current = self.acp_session_id.lock().await.clone();
    match current {
        Some(session_id) => {
            let cancel_params = SessionCancelParams { session_id };
            self.notify("session/cancel", Some(json!(cancel_params))).await
        }
        None => Err(AcpError::NoSession),
    }
}
/// Returns a new receiver for this session's event stream.
///
/// A broadcast receiver only observes events sent after it was created.
pub fn subscribe(&self) -> broadcast::Receiver<AgentEvent> {
    let sender = &self.tx;
    sender.subscribe()
}
/// The locally generated session identifier. This is not the agent-assigned
/// id; use `acp_session_id` for that.
pub fn session_id(&self) -> &str {
    self.local_session_id.as_str()
}
/// The CLI tool this session was spawned for.
pub fn tool(&self) -> CliTool {
    self.tool
}
/// A copy of the agent-assigned ACP session id, if one has been recorded.
pub async fn acp_session_id(&self) -> Option<String> {
    let current = self.acp_session_id.lock().await;
    (*current).clone()
}
/// Terminates the agent process.
///
/// The reader task handle is aborted first. That task was started with
/// `spawn_blocking`, so tokio cannot interrupt it once it is running. The
/// abort only prevents it from starting if it has not begun yet.
/// NOTE(review): presumably the reader loop exits once the process dies (EOF);
/// confirm in `acp_reader_loop`.
///
/// # Errors
///
/// [`AgentError::Pty`] if the process mutex is poisoned or the blocking task
/// fails; [`AgentError::Spawn`] wrapping the error returned by
/// `AcpProcess::kill`.
pub async fn kill(&self) -> Result<(), AgentError> {
    self.reader_task.abort();
    let process = Arc::clone(&self.process);
    let join_result = tokio::task::spawn_blocking(move || {
        let mut proc_guard = match process.lock() {
            Ok(g) => g,
            Err(_) => return Err(AgentError::Pty("acp process mutex poisoned".into())),
        };
        proc_guard.kill().map_err(|source| AgentError::Spawn { source })
    })
    .await;
    join_result.unwrap_or_else(|_| Err(AgentError::Pty("spawn_blocking panicked".into())))
}
/// Whether the capabilities returned by `initialize` advertise `loadSession`.
pub fn supports_load_session(&self) -> bool {
    let advertised = &self.agent_caps.agent_capabilities;
    advertised.load_session
}
pub async fn load_session(&self, prior_session_id: &str) -> Result<(), AcpError> {
if !self.supports_load_session() {
return Err(AcpError::HandshakeFailed {
message: "agent does not support loadSession".to_string(),
});
}
let params = SessionLoadParams { session_id: prior_session_id.to_owned() };
let result: SessionLoadResult = self
.rpc_call_typed("session/load", json!(params), self.prompt_timeout, false)
.await?;
let new_sid = if result.session_id.is_empty() {
prior_session_id.to_owned()
} else {
result.session_id
};
{
let mut guard = self.acp_session_id.lock().await;
*guard = Some(new_sid);
}
Ok(())
}
/// Issues `method` with `params` via `rpc_call` and decodes the result as `T`.
///
/// `_id_zero` is accepted for call-site compatibility but has no effect:
/// `rpc_call` always takes the request id from `self.id_gen`.
///
/// # Errors
///
/// Any error from `rpc_call`, or [`AcpError::Json`] if the result does not
/// decode as `T`.
async fn rpc_call_typed<T: serde::de::DeserializeOwned>(
    &self,
    method: &str,
    params: Value,
    timeout: Duration,
    _id_zero: bool,
) -> Result<T, AcpError> {
    self.rpc_call(method, Some(params), timeout)
        .await
        .and_then(|value| {
            serde_json::from_value::<T>(value).map_err(|source| AcpError::Json { source })
        })
}
/// Sends a JSON-RPC request and waits up to `timeout` for its response.
///
/// The request is registered in `self.pending` after it has been serialized
/// but before it is written. A response therefore always finds a waiting
/// receiver, and a serialization failure leaves no stale entry behind.
///
/// # Errors
///
/// - [`AcpError::Json`] if the request cannot be serialized.
/// - [`AcpError::Write`] if writing to the agent fails.
/// - [`AcpError::Timeout`] if no response arrives within `timeout`.
/// - [`AcpError::SessionClosed`] if the pending entry is dropped without a response.
/// - [`AcpError::Agent`] if the agent answers with a JSON-RPC error.
async fn rpc_call(
    &self,
    method: &str,
    params: Option<Value>,
    timeout: Duration,
) -> Result<Value, AcpError> {
    let id = self.id_gen.next();
    let request = RpcRequest::new(id.clone(), method, params);
    let line = serde_json::to_string(&request).map_err(|e| AcpError::Json { source: e })?;
    let rx = self.pending.register(id);
    // NOTE(review): on write failure or timeout the entry registered above is
    // not removed here. Confirm whether `PendingRequests` cleans up entries
    // whose receiver has been dropped.
    self.write_line(line).await?;
    tokio::time::timeout(timeout, rx)
        .await
        .map_err(|_| AcpError::Timeout {
            method: method.to_owned(),
        })?
        .map_err(|_| AcpError::SessionClosed)?
        .map_err(AcpError::Agent)
}
async fn notify(&self, method: &str, params: Option<Value>) -> Result<(), AcpError> {
let notif = RpcNotification {
jsonrpc: "2.0".into(),
method: method.into(),
params,
};
let line = serde_json::to_string(¬if).map_err(|e| AcpError::Json { source: e })?;
self.write_line(line).await
}
/// Writes `line` to the agent process via `AcpProcess::write_line`.
///
/// The process sits behind a `std::sync::Mutex`, so both the lock and the
/// write run inside `spawn_blocking` rather than on the async executor.
///
/// # Errors
///
/// [`AcpError::Write`] if the mutex is poisoned, the write fails, or the
/// blocking task fails to complete.
async fn write_line(&self, line: String) -> Result<(), AcpError> {
    let shared = Arc::clone(&self.process);
    let joined = tokio::task::spawn_blocking(move || {
        let mut proc_guard = match shared.lock() {
            Ok(g) => g,
            Err(_) => {
                return Err(AcpError::Write {
                    source: std::io::Error::new(std::io::ErrorKind::Other, "mutex poisoned"),
                })
            }
        };
        proc_guard
            .write_line(&line)
            .map_err(|source| AcpError::Write { source })
    })
    .await;
    joined.unwrap_or_else(|_| {
        Err(AcpError::Write {
            source: std::io::Error::new(std::io::ErrorKind::Other, "spawn_blocking panicked"),
        })
    })
}
}
/// Generates a process-local session identifier of the form
/// `acp-<nanos>-<seq>`, with both parts in lowercase hex.
///
/// `<nanos>` is the wall-clock time since the Unix epoch, or 0 if the clock
/// reads earlier than the epoch. `<seq>` is a per-process counter. The counter
/// keeps ids unique within the process even when two calls observe the same
/// clock reading, e.g. on platforms with coarse clock resolution.
fn generate_session_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Only atomicity of the increment matters for uniqueness; no other memory
    // is synchronized through this counter, so `Relaxed` is sufficient.
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("acp-{:x}-{:x}", nanos, seq)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Context-carrying variants include their context in `Display`; unit
    /// variants still render a non-empty message.
    #[test]
    fn acp_error_display_messages() {
        let handshake = AcpError::HandshakeTimeout { step: "initialize" }.to_string();
        assert!(handshake.contains("initialize"));

        let timeout = AcpError::Timeout { method: "session/prompt".into() }.to_string();
        assert!(timeout.contains("session/prompt"));

        for unit_variant in [AcpError::NoSession, AcpError::SessionClosed] {
            assert!(!unit_variant.to_string().is_empty());
        }
    }

    /// Pins every default value of `AcpSessionOptions`.
    #[test]
    fn acp_session_options_default_compiles() {
        let AcpSessionOptions {
            host_handler,
            channel_capacity,
            handshake_timeout,
            prompt_timeout,
            mcp_servers,
        } = AcpSessionOptions::default();
        assert_eq!(channel_capacity, 256);
        assert_eq!(handshake_timeout, Duration::from_secs(30));
        assert_eq!(prompt_timeout, Duration::from_secs(120));
        assert!(host_handler.is_none());
        assert!(mcp_servers.is_empty(), "default mcp_servers must be empty");
    }
}