use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::{oneshot, Mutex};
use tokio::task::JoinSet;
use tracing::{debug, error, info, warn};
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::config::{AcpChannelConfig, AcpHttpConfig};
use crate::error::{Result, ZeptoError};
use super::acp_protocol::{
AgentCapabilities, AgentInfo, ContentBlock, InitializeResult, JsonRpcRequest, SessionInfo,
SessionListParams, SessionListResult, SessionNewResult, SessionPromptResult,
SessionUpdateParams, SessionUpdatePayload,
};
use super::{BaseChannelConfig, Channel};
pub const ACP_HTTP_CHANNEL_NAME: &str = "acp_http";
const ACP_HTTP_SENDER_ID: &str = "acp_client";
const MAX_REQUEST_BYTES: usize = 118_784; const MAX_PROMPT_BYTES: usize = 102_400;
const MAX_ACP_SESSIONS: usize = 1_000;
const MAX_CONCURRENT_CONNECTIONS: usize = 128;
const PROMPT_TIMEOUT_SECS: u64 = 300;
fn cors_line(open_cors: bool) -> &'static str {
if open_cors {
"Access-Control-Allow-Origin: *\r\n"
} else {
""
}
}
fn build_cors_preflight(open_cors: bool) -> String {
format!(
"HTTP/1.1 204 No Content\r\n{}Access-Control-Allow-Methods: POST, OPTIONS\r\nAccess-Control-Allow-Headers: Content-Type, Authorization\r\nContent-Length: 0\r\n\r\n",
cors_line(open_cors),
)
}
fn build_204_notification(open_cors: bool) -> String {
format!(
"HTTP/1.1 204 No Content\r\n{}Content-Length: 0\r\n\r\n",
cors_line(open_cors),
)
}
fn build_http_error(status_line: &str, body: &str, open_cors: bool) -> String {
format!(
"HTTP/1.1 {}\r\nContent-Type: application/json\r\n{}Content-Length: {}\r\nConnection: close\r\n\r\n{}",
status_line,
cors_line(open_cors),
body.len(),
body
)
}
struct PendingPrompt {
cancelled: bool,
}
type PromptMap = Arc<Mutex<HashMap<String, oneshot::Sender<(String, bool)>>>>;
struct SessionEntry {
cwd: String,
last_active: std::time::Instant,
}
struct AcpHttpState {
sessions: HashMap<String, SessionEntry>,
pending: HashMap<String, PendingPrompt>,
}
impl AcpHttpState {
fn new() -> Self {
Self {
sessions: HashMap::new(),
pending: HashMap::new(),
}
}
fn sweep_expired(&mut self, ttl: Duration) {
let now = std::time::Instant::now();
let expired: Vec<String> = self
.sessions
.iter()
.filter(|(_, entry)| now.duration_since(entry.last_active) >= ttl)
.map(|(id, _)| id.clone())
.collect();
for id in &expired {
self.sessions.remove(id);
self.pending.remove(id);
}
}
}
struct ParsedRequest {
method: String,
path: String,
headers: Vec<(String, String)>,
body: String,
}
pub struct AcpHttpChannel {
config: AcpChannelConfig,
http_config: AcpHttpConfig,
base_config: BaseChannelConfig,
bus: Arc<MessageBus>,
running: Arc<AtomicBool>,
state: Arc<Mutex<AcpHttpState>>,
pending_http: PromptMap,
accept_handle: Option<tokio::task::JoinHandle<()>>,
conn_tasks: Arc<Mutex<JoinSet<()>>>,
}
impl AcpHttpChannel {
pub fn new(
config: AcpChannelConfig,
http_config: AcpHttpConfig,
base_config: BaseChannelConfig,
bus: Arc<MessageBus>,
) -> Self {
Self {
config,
http_config,
base_config,
bus,
running: Arc::new(AtomicBool::new(false)),
state: Arc::new(Mutex::new(AcpHttpState::new())),
pending_http: Arc::new(Mutex::new(HashMap::new())),
accept_handle: None,
conn_tasks: Arc::new(Mutex::new(JoinSet::new())),
}
}
fn find_header_end(data: &[u8]) -> Option<usize> {
data.windows(4).position(|w| w == b"\r\n\r\n")
}
fn parse_request(raw: &[u8]) -> Option<ParsedRequest> {
let s = std::str::from_utf8(raw).ok()?;
let pos = s.find("\r\n\r\n")?;
let header_section = &s[..pos];
let mut lines = header_section.lines();
let request_line = lines.next()?;
let mut parts = request_line.split_whitespace();
let method = parts.next()?.to_uppercase();
let path = parts.next()?.to_string();
let mut headers = Vec::new();
for line in lines {
if let Some(colon) = line.find(':') {
headers.push((
line[..colon].trim().to_string(),
line[colon + 1..].trim().to_string(),
));
}
}
let content_len = Self::content_length(&headers);
let body_start = pos + 4;
let body_end = (body_start + content_len).min(s.len());
let body = s[body_start..body_end].to_string();
Some(ParsedRequest {
method,
path,
headers,
body,
})
}
fn content_length(headers: &[(String, String)]) -> usize {
headers
.iter()
.find(|(n, _)| n.eq_ignore_ascii_case("content-length"))
.and_then(|(_, v)| v.trim().parse::<usize>().ok())
.unwrap_or(0)
}
fn validate_auth(headers: &[(String, String)], token: &Option<String>) -> bool {
let required = match token {
Some(t) => t,
None => return true,
};
let expected = format!("Bearer {}", required);
headers.iter().any(|(n, v)| {
n.eq_ignore_ascii_case("authorization") && constant_time_eq(v.trim(), &expected)
})
}
fn json_rpc_result(id: Option<serde_json::Value>, result: serde_json::Value) -> String {
serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
}))
.unwrap_or_else(|_| {
r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"serialize error"}}"#.to_string()
})
}
fn json_rpc_error(id: Option<serde_json::Value>, code: i64, message: &str) -> String {
serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"error": { "code": code, "message": message }
}))
.unwrap_or_else(|_| {
r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"serialize error"}}"#.to_string()
})
}
fn http_200(body: &str, open_cors: bool) -> String {
format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n{}Content-Length: {}\r\nConnection: close\r\n\r\n{}",
cors_line(open_cors),
body.len(),
body
)
}
fn http_400(body: &str, open_cors: bool) -> String {
format!(
"HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\n{}Content-Length: {}\r\nConnection: close\r\n\r\n{}",
cors_line(open_cors),
body.len(),
body
)
}
fn sse_event(data: &str) -> String {
format!("data: {}\n\n", data)
}
async fn do_initialize(
_config: &AcpChannelConfig,
id: Option<serde_json::Value>,
params: Option<serde_json::Value>,
) -> String {
if let Some(init_params) = params
.and_then(|p| serde_json::from_value::<super::acp_protocol::InitializeParams>(p).ok())
{
if let Some(ref ci) = init_params.client_info {
info!(
client_name = ?ci.name,
client_version = ?ci.version,
"ACP-HTTP: client initialized"
);
}
}
let result = InitializeResult {
protocol_version: serde_json::json!("1"),
agent_capabilities: AgentCapabilities {
load_session: Some(false),
prompt_capabilities: Some(serde_json::json!({
"image": false, "audio": false, "embeddedContext": false
})),
mcp_capabilities: Some(serde_json::json!({ "http": false, "sse": false })),
session_capabilities: Some({
let mut m = HashMap::new();
m.insert("list".to_string(), serde_json::json!({}));
m
}),
},
agent_info: Some(AgentInfo {
name: "zeptoclaw".to_string(),
title: Some("ZeptoClaw".to_string()),
version: env!("CARGO_PKG_VERSION").to_string(),
}),
auth_methods: vec![],
};
match serde_json::to_value(result) {
Ok(v) => Self::json_rpc_result(id, v),
Err(e) => Self::json_rpc_error(id, -32603, &format!("serialize error: {}", e)),
}
}
async fn do_session_new(
state: &Arc<Mutex<AcpHttpState>>,
base_config: &BaseChannelConfig,
config: &AcpChannelConfig,
id: Option<serde_json::Value>,
params: Option<serde_json::Value>,
) -> String {
if !base_config.is_allowed(ACP_HTTP_SENDER_ID) {
return Self::json_rpc_error(id, -32000, "Unauthorized");
}
let cwd = params
.and_then(|p| serde_json::from_value::<super::acp_protocol::SessionNewParams>(p).ok())
.and_then(|p| p.cwd)
.filter(|c| !c.is_empty());
let cwd = match cwd {
Some(c) => c,
None => return Self::json_rpc_error(id, -32602, "session/new: cwd is required"),
};
if !cwd.starts_with('/') {
return Self::json_rpc_error(id, -32602, "session/new: cwd must be an absolute path");
}
if cwd.len() > 4096 {
return Self::json_rpc_error(id, -32602, "session/new: cwd exceeds 4096 bytes");
}
if cwd_contains_traversal(&cwd) {
return Self::json_rpc_error(
id,
-32602,
"session/new: cwd must not contain '..' path segments",
);
}
let session_id = format!("acph_{}", super::acp_protocol::new_id());
{
let mut st = state.lock().await;
if let Some(ttl) = config.session_ttl_secs {
st.sweep_expired(Duration::from_secs(ttl));
}
if st.sessions.len() >= MAX_ACP_SESSIONS {
return Self::json_rpc_error(
id,
-32000,
&format!("too many sessions (limit: {})", MAX_ACP_SESSIONS),
);
}
st.sessions.insert(
session_id.clone(),
SessionEntry {
cwd,
last_active: std::time::Instant::now(),
},
);
}
let result = SessionNewResult { session_id };
match serde_json::to_value(result) {
Ok(v) => Self::json_rpc_result(id, v),
Err(e) => Self::json_rpc_error(id, -32603, &format!("serialize error: {}", e)),
}
}
async fn do_session_cancel(
state: &Arc<Mutex<AcpHttpState>>,
base_config: &BaseChannelConfig,
id: Option<serde_json::Value>,
params: Option<serde_json::Value>,
) -> String {
if !base_config.is_allowed(ACP_HTTP_SENDER_ID) {
return Self::json_rpc_error(id, -32000, "Unauthorized");
}
if let Some(p) = params.and_then(|p| {
serde_json::from_value::<super::acp_protocol::SessionCancelParams>(p).ok()
}) {
let mut st = state.lock().await;
if let Some(pending) = st.pending.get_mut(&p.session_id) {
pending.cancelled = true;
debug!(session_id = %p.session_id, "ACP-HTTP: marked prompt as cancelled");
}
}
Self::json_rpc_result(id, serde_json::json!({}))
}
async fn do_session_list(
state: &Arc<Mutex<AcpHttpState>>,
base_config: &BaseChannelConfig,
id: Option<serde_json::Value>,
params: Option<serde_json::Value>,
) -> String {
if !base_config.is_allowed(ACP_HTTP_SENDER_ID) {
return Self::json_rpc_error(id, -32000, "Unauthorized");
}
let st = state.lock().await;
let list_params: Option<SessionListParams> = match params {
None => None,
Some(p) => match serde_json::from_value::<SessionListParams>(p) {
Ok(lp) => Some(lp),
Err(e) => {
return Self::json_rpc_error(id, -32602, &format!("Invalid params: {}", e));
}
},
};
let cwd_filter = list_params.and_then(|p| p.cwd);
let sessions: Vec<SessionInfo> = st
.sessions
.iter()
.filter(|(_, entry)| {
if let Some(ref filter) = cwd_filter {
entry.cwd.as_str() == filter.as_str()
} else {
true
}
})
.map(|(sid, entry)| SessionInfo {
session_id: sid.clone(),
cwd: entry.cwd.clone(),
title: None,
updated_at: None,
meta: Some(serde_json::json!({ "pending": st.pending.contains_key(sid) })),
})
.collect();
let result = SessionListResult {
sessions,
next_cursor: None,
};
match serde_json::to_value(result) {
Ok(v) => Self::json_rpc_result(id, v),
Err(e) => Self::json_rpc_error(id, -32603, &format!("serialize error: {}", e)),
}
}
async fn register_prompt(
state: &Arc<Mutex<AcpHttpState>>,
pending_http: &PromptMap,
base_config: &BaseChannelConfig,
bus: &Arc<MessageBus>,
id: Option<serde_json::Value>,
params: Option<serde_json::Value>,
) -> Result<std::result::Result<(String, oneshot::Receiver<(String, bool)>), String>> {
if !base_config.is_allowed(ACP_HTTP_SENDER_ID) {
return Ok(Err(Self::json_rpc_error(id, -32000, "Unauthorized")));
}
let params: super::acp_protocol::SessionPromptParams =
match params.and_then(|p| serde_json::from_value(p).ok()) {
Some(p) => p,
None => {
return Ok(Err(Self::json_rpc_error(
id,
-32602,
"session/prompt: missing or invalid params",
)));
}
};
let session_id = params.session_id.clone();
let content = super::acp_protocol::prompt_blocks_to_text(¶ms.prompt);
if content.is_empty() {
return Ok(Err(Self::json_rpc_error(
id,
-32602,
"session/prompt: prompt content is empty",
)));
}
if content.len() > MAX_PROMPT_BYTES {
return Ok(Err(Self::json_rpc_error(
id,
-32602,
&format!(
"session/prompt: content too large ({} bytes, limit {})",
content.len(),
MAX_PROMPT_BYTES
),
)));
}
{
let mut st = state.lock().await;
match st.sessions.get_mut(&session_id) {
None => {
return Ok(Err(Self::json_rpc_error(
id,
-32000,
&format!("ACP: unknown session {}", session_id),
)));
}
Some(entry) => {
entry.last_active = std::time::Instant::now();
}
}
if st.pending.contains_key(&session_id) {
return Ok(Err(Self::json_rpc_error(
id,
-32602,
"session/prompt: a prompt is already in flight for this session",
)));
}
st.pending
.insert(session_id.clone(), PendingPrompt { cancelled: false });
}
let (tx, rx) = oneshot::channel::<(String, bool)>();
{
pending_http.lock().await.insert(session_id.clone(), tx);
}
let inbound = InboundMessage::new(
ACP_HTTP_CHANNEL_NAME,
ACP_HTTP_SENDER_ID,
&session_id,
&content,
);
if let Err(e) = bus.publish_inbound(inbound).await {
state.lock().await.pending.remove(&session_id);
pending_http.lock().await.remove(&session_id);
return Err(ZeptoError::Channel(format!(
"ACP-HTTP: failed to publish inbound: {}",
e
)));
}
debug!(session_id = %session_id, "ACP-HTTP: published session/prompt to bus");
Ok(Ok((session_id, rx)))
}
async fn stream_prompt_response(
stream: &mut tokio::net::TcpStream,
session_id: &str,
id: Option<serde_json::Value>,
rx: oneshot::Receiver<(String, bool)>,
state: &Arc<Mutex<AcpHttpState>>,
pending_http: &PromptMap,
open_cors: bool,
) {
let sse_headers = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nCache-Control: no-cache\r\nConnection: keep-alive\r\n{}X-Accel-Buffering: no\r\n\r\n",
cors_line(open_cors),
);
if stream.write_all(sse_headers.as_bytes()).await.is_err() {
state.lock().await.pending.remove(session_id);
pending_http.lock().await.remove(session_id);
return;
}
let _ = stream.flush().await;
let (content, cancelled) =
match tokio::time::timeout(Duration::from_secs(PROMPT_TIMEOUT_SECS), rx).await {
Ok(Ok(payload)) => payload,
Ok(Err(_)) => {
state.lock().await.pending.remove(session_id);
pending_http.lock().await.remove(session_id);
let ev =
Self::sse_event(&Self::json_rpc_error(id, -32603, "agent session closed"));
let _ = stream.write_all(ev.as_bytes()).await;
let _ = stream.flush().await;
return;
}
Err(_) => {
state.lock().await.pending.remove(session_id);
pending_http.lock().await.remove(session_id);
let ev = Self::sse_event(&Self::json_rpc_error(
id,
-32603,
"session/prompt timed out",
));
let _ = stream.write_all(ev.as_bytes()).await;
let _ = stream.flush().await;
return;
}
};
let update = SessionUpdateParams {
session_id: session_id.to_string(),
update: SessionUpdatePayload {
session_update: "agent_message".to_string(),
content: Some(ContentBlock::text(&content)),
tool_call_id: None,
title: None,
kind: None,
status: None,
},
};
if let Ok(update_json) = serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"method": "session/update",
"params": serde_json::to_value(&update).unwrap_or(serde_json::Value::Null)
})) {
let ev = Self::sse_event(&update_json);
if stream.write_all(ev.as_bytes()).await.is_err() {
return;
}
}
let stop_reason = if cancelled { "cancelled" } else { "end_turn" };
let prompt_result = SessionPromptResult {
stop_reason: stop_reason.to_string(),
};
if let Ok(result_val) = serde_json::to_value(&prompt_result) {
let body = Self::json_rpc_result(id, result_val);
let ev = Self::sse_event(&body);
let _ = stream.write_all(ev.as_bytes()).await;
}
let _ = stream.flush().await;
}
#[allow(clippy::too_many_arguments)]
async fn handle_connection(
mut stream: tokio::net::TcpStream,
config: AcpChannelConfig,
http_config: AcpHttpConfig,
base_config: BaseChannelConfig,
bus: Arc<MessageBus>,
state: Arc<Mutex<AcpHttpState>>,
pending_http: PromptMap,
) {
let open_cors = http_config.auth_token.is_none();
let mut buf = vec![0u8; MAX_REQUEST_BYTES];
let mut total = 0usize;
let read_result = tokio::time::timeout(Duration::from_secs(30), async {
loop {
if total >= buf.len() {
return Err("payload too large");
}
match tokio::time::timeout(Duration::from_secs(10), stream.read(&mut buf[total..]))
.await
{
Ok(Ok(0)) => break,
Ok(Ok(n)) => {
total += n;
if let Some(hend) = Self::find_header_end(&buf[..total]) {
if let Some(req) = Self::parse_request(&buf[..total]) {
let body_received = total - hend - 4;
if body_received >= Self::content_length(&req.headers) {
break;
}
} else {
break;
}
}
}
Ok(Err(e)) => {
debug!("ACP-HTTP: read error: {}", e);
return Err("read error");
}
Err(_) => {
debug!("ACP-HTTP: read timeout");
return Err("read timeout");
}
}
}
Ok(())
})
.await;
match read_result {
Ok(Ok(())) => {}
Ok(Err("payload too large")) => {
let resp = build_http_error(
"413 Payload Too Large",
r#"{"error":"payload too large"}"#,
open_cors,
);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
Ok(Err(_)) => return,
Err(_) => {
debug!("ACP-HTTP: total request deadline exceeded (slow-loris?)");
return;
}
}
if total == 0 {
return;
}
let req = match Self::parse_request(&buf[..total]) {
Some(r) => r,
None => {
let _ = stream
.write_all(Self::http_400("{}", open_cors).as_bytes())
.await;
return;
}
};
if req.method == "OPTIONS" {
let resp = build_cors_preflight(open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
if req.path != "/acp" && req.path != "/" {
let resp = build_http_error("404 Not Found", r#"{"error":"not found"}"#, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
if req.method != "POST" {
let resp = build_http_error(
"405 Method Not Allowed",
r#"{"error":"method not allowed"}"#,
open_cors,
);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
if !Self::validate_auth(&req.headers, &http_config.auth_token) {
let resp =
build_http_error("401 Unauthorized", r#"{"error":"unauthorized"}"#, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
let rpc: JsonRpcRequest = match serde_json::from_str(&req.body) {
Ok(r) => r,
Err(e) => {
let body = Self::json_rpc_error(None, -32700, &format!("parse error: {}", e));
let resp = Self::http_400(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
};
if rpc.jsonrpc != "2.0" {
let body =
Self::json_rpc_error(rpc.id, -32600, "Invalid Request: jsonrpc must be \"2.0\"");
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
let id = rpc.id.clone();
let params = rpc.params.clone();
if id.is_none() {
if rpc.method.as_str() == "session/cancel" {
if let Some(p) = params.and_then(|p| {
serde_json::from_value::<super::acp_protocol::SessionCancelParams>(p).ok()
}) {
let mut st = state.lock().await;
if let Some(pending) = st.pending.get_mut(&p.session_id) {
pending.cancelled = true;
debug!(session_id = %p.session_id, "ACP-HTTP: marked prompt as cancelled (notification)");
}
}
}
let resp = build_204_notification(open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
return;
}
match rpc.method.as_str() {
"initialize" => {
let body = Self::do_initialize(&config, id, params).await;
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
"session/new" => {
let body = Self::do_session_new(&state, &base_config, &config, id, params).await;
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
"session/cancel" => {
let body = Self::do_session_cancel(&state, &base_config, id, params).await;
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
"session/list" => {
let body = Self::do_session_list(&state, &base_config, id, params).await;
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
"session/prompt" => {
match Self::register_prompt(
&state,
&pending_http,
&base_config,
&bus,
id.clone(),
params,
)
.await
{
Err(e) => {
error!("ACP-HTTP: session/prompt internal error: {}", e);
let body =
Self::json_rpc_error(id, -32603, &format!("internal error: {}", e));
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
Ok(Err(err_body)) => {
let resp = Self::http_200(&err_body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
Ok(Ok((session_id, rx))) => {
Self::stream_prompt_response(
&mut stream,
&session_id,
id,
rx,
&state,
&pending_http,
open_cors,
)
.await;
}
}
}
_ => {
let body =
Self::json_rpc_error(id, -32601, &format!("method not found: {}", rpc.method));
let resp = Self::http_200(&body, open_cors);
let _ = stream.write_all(resp.as_bytes()).await;
}
}
}
#[allow(clippy::too_many_arguments)]
async fn run_accept_loop(
listener: TcpListener,
config: AcpChannelConfig,
http_config: AcpHttpConfig,
base_config: BaseChannelConfig,
bus: Arc<MessageBus>,
state: Arc<Mutex<AcpHttpState>>,
pending_http: PromptMap,
running: Arc<AtomicBool>,
conn_tasks: Arc<Mutex<JoinSet<()>>>,
) {
info!(
"ACP-HTTP: listening on {}:{}",
http_config.bind, http_config.port
);
let conn_sem = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_CONNECTIONS));
while running.load(Ordering::SeqCst) {
match tokio::time::timeout(Duration::from_secs(1), listener.accept()).await {
Ok(Ok((stream, addr))) => {
debug!("ACP-HTTP: accepted connection from {}", addr);
let permit = match Arc::clone(&conn_sem).try_acquire_owned() {
Ok(p) => p,
Err(_) => {
debug!(
"ACP-HTTP: connection limit ({}) reached, dropping {}",
MAX_CONCURRENT_CONNECTIONS, addr
);
continue;
}
};
let config = config.clone();
let http_config = http_config.clone();
let base_config = base_config.clone();
let bus = Arc::clone(&bus);
let state = Arc::clone(&state);
let pending_http = Arc::clone(&pending_http);
{
let mut tasks = conn_tasks.lock().await;
while tasks.try_join_next().is_some() {}
tasks.spawn(async move {
let _permit = permit; Self::handle_connection(
stream,
config,
http_config,
base_config,
bus,
state,
pending_http,
)
.await;
});
}
}
Ok(Err(e)) => {
error!("ACP-HTTP: accept error: {}", e);
}
Err(_) => {
}
}
}
running.store(false, Ordering::SeqCst);
info!("ACP-HTTP: accept loop exited");
}
}
#[async_trait]
impl Channel for AcpHttpChannel {
fn name(&self) -> &str {
ACP_HTTP_CHANNEL_NAME
}
async fn start(&mut self) -> Result<()> {
if self.running.swap(true, Ordering::SeqCst) {
info!("ACP-HTTP channel already running");
return Ok(());
}
let addr = format!("{}:{}", self.http_config.bind, self.http_config.port);
let listener = TcpListener::bind(&addr).await.map_err(|e| {
self.running.store(false, Ordering::SeqCst);
ZeptoError::Channel(format!("ACP-HTTP: failed to bind {}: {}", addr, e))
})?;
let config = self.config.clone();
let http_config = self.http_config.clone();
let base_config = self.base_config.clone();
let bus = Arc::clone(&self.bus);
let state = Arc::clone(&self.state);
let pending_http = Arc::clone(&self.pending_http);
let running = Arc::clone(&self.running);
let conn_tasks = Arc::clone(&self.conn_tasks);
let handle = tokio::spawn(async move {
Self::run_accept_loop(
listener,
config,
http_config,
base_config,
bus,
state,
pending_http,
running,
conn_tasks,
)
.await;
});
self.accept_handle = Some(handle);
if self.http_config.auth_token.is_none() {
warn!(
"ACP-HTTP channel started without an auth_token on {}:{}. \
Combined with wildcard CORS, any webpage can reach this endpoint \
(DNS rebinding risk). Set acp.http.auth_token in your config.",
self.http_config.bind, self.http_config.port
);
} else {
info!(
"ACP-HTTP channel started on {}:{}",
self.http_config.bind, self.http_config.port
);
}
Ok(())
}
async fn stop(&mut self) -> Result<()> {
self.running.store(false, Ordering::SeqCst);
self.pending_http.lock().await.clear();
self.state.lock().await.pending.clear();
if let Some(handle) = self.accept_handle.take() {
handle.abort();
let _ = handle.await;
}
{
let mut tasks = self.conn_tasks.lock().await;
tasks.abort_all();
while tasks.join_next().await.is_some() {}
}
Ok(())
}
async fn send(&self, msg: OutboundMessage) -> Result<()> {
if msg.channel != ACP_HTTP_CHANNEL_NAME {
return Ok(());
}
let session_id = msg.chat_id.clone();
let (session_exists, cancelled) = {
let mut st = self.state.lock().await;
let exists = st.sessions.contains_key(&session_id);
let cancelled = st
.pending
.remove(&session_id)
.map(|p| p.cancelled)
.unwrap_or(false);
(exists, cancelled)
};
if !session_exists {
debug!(
session_id = %session_id,
"ACP-HTTP: outbound for unknown session, skipping"
);
return Ok(());
}
let sender = self.pending_http.lock().await.remove(&session_id);
if let Some(tx) = sender {
let _ = tx.send((msg.content, cancelled));
} else {
debug!(
session_id = %session_id,
"ACP-HTTP: proactive message with no waiting handler"
);
}
Ok(())
}
fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
fn is_allowed(&self, user_id: &str) -> bool {
self.base_config.is_allowed(user_id)
}
}
pub(super) fn cwd_contains_traversal(cwd: &str) -> bool {
std::path::Path::new(cwd)
.components()
.any(|c| matches!(c, std::path::Component::ParentDir))
}
fn constant_time_eq(a: &str, b: &str) -> bool {
let a = a.as_bytes();
let b = b.as_bytes();
let max_len = a.len().max(b.len());
let mut result = (a.len() ^ b.len()) as u8;
for i in 0..max_len {
result |= a.get(i).copied().unwrap_or(0) ^ b.get(i).copied().unwrap_or(0);
}
result == 0
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::MessageBus;
use crate::config::{AcpChannelConfig, AcpHttpConfig};
fn make_channel() -> AcpHttpChannel {
let bus = Arc::new(MessageBus::new());
AcpHttpChannel::new(
AcpChannelConfig::default(),
AcpHttpConfig::default(),
BaseChannelConfig::new(ACP_HTTP_CHANNEL_NAME),
bus,
)
}
#[test]
fn test_channel_name() {
let ch = make_channel();
assert_eq!(ch.name(), ACP_HTTP_CHANNEL_NAME);
}
#[test]
fn test_is_not_running_initially() {
assert!(!make_channel().is_running());
}
#[test]
fn test_prompt_blocks_to_text_text_only() {
use super::super::acp_protocol::PromptContentBlock;
let blocks = vec![
PromptContentBlock::Text {
text: "Hello".to_string(),
},
PromptContentBlock::Text {
text: "World".to_string(),
},
];
assert_eq!(
crate::channels::acp_protocol::prompt_blocks_to_text(&blocks),
"Hello\nWorld"
);
}
#[test]
fn test_prompt_blocks_to_text_skips_non_text() {
use super::super::acp_protocol::PromptContentBlock;
let blocks = vec![
PromptContentBlock::Text {
text: "only this".to_string(),
},
PromptContentBlock::Other,
];
assert_eq!(
crate::channels::acp_protocol::prompt_blocks_to_text(&blocks),
"only this"
);
}
#[tokio::test]
async fn test_send_ignores_wrong_channel() {
let ch = make_channel();
let session_id = "acph_test".to_string();
{
let mut st = ch.state.lock().await;
st.sessions.insert(
session_id.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.pending
.insert(session_id.clone(), PendingPrompt { cancelled: false });
}
let msg = OutboundMessage {
channel: "telegram".to_string(),
chat_id: session_id.clone(),
content: "hello".to_string(),
reply_to: None,
metadata: Default::default(),
};
assert!(ch.send(msg).await.is_ok());
let st = ch.state.lock().await;
assert!(
st.pending.contains_key(&session_id),
"wrong-channel send must not consume the pending entry"
);
}
#[tokio::test]
async fn test_send_skips_unknown_session() {
let ch = make_channel();
let msg = OutboundMessage {
channel: ACP_HTTP_CHANNEL_NAME.to_string(),
chat_id: "acph_ghost".to_string(),
content: "hello".to_string(),
reply_to: None,
metadata: Default::default(),
};
assert!(ch.send(msg).await.is_ok());
assert!(ch.state.lock().await.sessions.is_empty());
}
#[tokio::test]
async fn test_send_delivers_via_oneshot() {
let ch = make_channel();
let session_id = "acph_deliver".to_string();
let (tx, rx) = oneshot::channel::<(String, bool)>();
{
let mut st = ch.state.lock().await;
st.sessions.insert(
session_id.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.pending
.insert(session_id.clone(), PendingPrompt { cancelled: false });
ch.pending_http.lock().await.insert(session_id.clone(), tx);
}
let msg = OutboundMessage {
channel: ACP_HTTP_CHANNEL_NAME.to_string(),
chat_id: session_id.clone(),
content: "agent reply".to_string(),
reply_to: None,
metadata: Default::default(),
};
assert!(ch.send(msg).await.is_ok());
let (content, cancelled) = rx.await.expect("must receive payload");
assert_eq!(content, "agent reply");
assert!(!cancelled);
assert!(!ch.state.lock().await.pending.contains_key(&session_id));
}
#[tokio::test]
async fn test_send_marks_cancelled() {
let ch = make_channel();
let session_id = "acph_cancel".to_string();
let (tx, rx) = oneshot::channel::<(String, bool)>();
{
let mut st = ch.state.lock().await;
st.sessions.insert(
session_id.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.pending
.insert(session_id.clone(), PendingPrompt { cancelled: true });
ch.pending_http.lock().await.insert(session_id.clone(), tx);
}
let msg = OutboundMessage {
channel: ACP_HTTP_CHANNEL_NAME.to_string(),
chat_id: session_id.clone(),
content: "reply after cancel".to_string(),
reply_to: None,
metadata: Default::default(),
};
assert!(ch.send(msg).await.is_ok());
let (_content, cancelled) = rx.await.expect("must receive payload");
assert!(
cancelled,
"cancelled flag must be forwarded to HTTP handler"
);
}
#[tokio::test]
async fn test_deny_by_default_blocks_session_new() {
let bus = Arc::new(MessageBus::new());
let base = BaseChannelConfig {
name: ACP_HTTP_CHANNEL_NAME.to_string(),
allowlist: vec![],
deny_by_default: true,
};
let ch = AcpHttpChannel::new(
AcpChannelConfig {
deny_by_default: true,
..AcpChannelConfig::default()
},
AcpHttpConfig::default(),
base,
bus,
);
let result = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(1)),
Some(serde_json::json!({ "cwd": "/workspace" })),
)
.await;
assert!(
result.contains("Unauthorized"),
"deny_by_default must block session/new"
);
assert!(ch.state.lock().await.sessions.is_empty());
}
#[test]
fn test_constant_time_eq_same() {
assert!(constant_time_eq("hello", "hello"));
}
#[test]
fn test_constant_time_eq_different() {
assert!(!constant_time_eq("hello", "world"));
assert!(!constant_time_eq("hello", "hello!"));
}
#[test]
fn test_sse_event_format() {
let ev = AcpHttpChannel::sse_event(r#"{"foo":"bar"}"#);
assert_eq!(ev, "data: {\"foo\":\"bar\"}\n\n");
}
#[test]
fn test_http_200_content_length() {
let body = r#"{"result":"ok"}"#;
let resp = AcpHttpChannel::http_200(body, true);
assert!(resp.contains(&format!("Content-Length: {}", body.len())));
assert!(resp.ends_with(body));
}
#[tokio::test]
async fn test_session_list_no_sessions() {
let ch = make_channel();
let body = AcpHttpChannel::do_session_list(
&ch.state,
&ch.base_config,
Some(serde_json::json!(1)),
None,
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert_eq!(v["result"]["sessions"], serde_json::json!([]));
}
#[tokio::test]
async fn test_session_list_empty() {
let ch = make_channel();
let body = AcpHttpChannel::do_session_list(
&ch.state,
&ch.base_config,
Some(serde_json::json!(1)),
None,
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert_eq!(v["result"]["sessions"], serde_json::json!([]));
}
#[tokio::test]
async fn test_session_list_shows_sessions_with_pending_flag() {
let ch = make_channel();
let sid_a = "acph_list_a".to_string();
let sid_b = "acph_list_b".to_string();
{
let mut st = ch.state.lock().await;
st.sessions.insert(
sid_a.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.sessions.insert(
sid_b.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.pending
.insert(sid_a.clone(), PendingPrompt { cancelled: false });
}
let body = AcpHttpChannel::do_session_list(
&ch.state,
&ch.base_config,
Some(serde_json::json!(1)),
None,
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
let sessions = v["result"]["sessions"].as_array().unwrap();
assert_eq!(sessions.len(), 2);
let find = |sid: &str| sessions.iter().find(|s| s["sessionId"] == sid).cloned();
assert_eq!(find(&sid_a).unwrap()["_meta"]["pending"], true);
assert_eq!(find(&sid_b).unwrap()["_meta"]["pending"], false);
}
#[tokio::test]
async fn test_session_cancel_notification_returns_204_no_body() {
let ch = make_channel();
let session_id = "acph_notif_cancel".to_string();
{
let mut st = ch.state.lock().await;
st.sessions.insert(
session_id.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.pending
.insert(session_id.clone(), PendingPrompt { cancelled: false });
}
{
let params = serde_json::json!({ "sessionId": session_id });
if let Ok(p) = serde_json::from_value::<
crate::channels::acp_protocol::SessionCancelParams,
>(params.clone())
{
let mut st = ch.state.lock().await;
if let Some(pending) = st.pending.get_mut(&p.session_id) {
pending.cancelled = true;
}
}
}
assert!(
ch.state
.lock()
.await
.pending
.get(&session_id)
.unwrap()
.cancelled,
"cancel notification must set the cancelled flag"
);
}
#[tokio::test]
async fn test_session_cancel_request_returns_json_result() {
let ch = make_channel();
let session_id = "acph_req_cancel".to_string();
{
let mut st = ch.state.lock().await;
st.sessions.insert(
session_id.clone(),
SessionEntry {
cwd: "/test".to_string(),
last_active: std::time::Instant::now(),
},
);
st.pending
.insert(session_id.clone(), PendingPrompt { cancelled: false });
}
let body = AcpHttpChannel::do_session_cancel(
&ch.state,
&ch.base_config,
Some(serde_json::json!(42)),
Some(serde_json::json!({ "sessionId": session_id })),
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert_eq!(v["id"], 42, "response must echo the request id");
assert!(
v["result"].is_object(),
"result must be present for a request"
);
assert!(v.get("error").is_none(), "no error field expected");
}
#[test]
fn test_http_204_notification_has_no_body() {
let resp = build_204_notification(true);
assert!(resp.contains("204 No Content"), "must be a 204 status");
assert!(
resp.contains("Content-Length: 0"),
"Content-Length must be 0"
);
assert!(
resp.ends_with("\r\n\r\n"),
"response must end with the blank line and no trailing body"
);
}
#[tokio::test]
async fn test_initialize_returns_spec_fields() {
let ch = make_channel();
let body =
AcpHttpChannel::do_initialize(&ch.config, Some(serde_json::json!(1)), None).await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert!(
v["result"]["protocolVersion"].is_string(),
"protocolVersion must be a string"
);
assert!(
v["result"]["agentCapabilities"].is_object(),
"agentCapabilities must be present"
);
assert!(
v["result"].get("clientId").is_none(),
"clientId must not be in the response"
);
}
#[tokio::test]
async fn test_initialize_to_session_new_round_trip() {
let ch = make_channel();
AcpHttpChannel::do_initialize(&ch.config, Some(serde_json::json!(1)), None).await;
let new_body = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(2)),
Some(serde_json::json!({ "cwd": "/workspace" })),
)
.await;
let new_v: serde_json::Value = serde_json::from_str(&new_body).unwrap();
assert!(
new_v.get("error").is_none(),
"session/new must succeed: {new_body}"
);
assert!(
new_v["result"]["sessionId"].as_str().is_some(),
"session/new must return a sessionId: {new_body}"
);
}
#[tokio::test]
async fn test_initialize_to_session_new_to_session_list_round_trip() {
let ch = make_channel();
AcpHttpChannel::do_initialize(&ch.config, Some(serde_json::json!(1)), None).await;
let new_body = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(2)),
Some(serde_json::json!({ "cwd": "/workspace" })),
)
.await;
let session_id = serde_json::from_str::<serde_json::Value>(&new_body).unwrap()["result"]
["sessionId"]
.as_str()
.expect("session/new must succeed")
.to_string();
let list_body = AcpHttpChannel::do_session_list(
&ch.state,
&ch.base_config,
Some(serde_json::json!(3)),
None,
)
.await;
let list_v: serde_json::Value = serde_json::from_str(&list_body).unwrap();
assert!(
list_v.get("error").is_none(),
"session/list must succeed: {list_body}"
);
let sessions = list_v["result"]["sessions"].as_array().unwrap();
assert_eq!(sessions.len(), 1, "exactly one session must be listed");
assert_eq!(
sessions[0]["sessionId"].as_str().unwrap(),
session_id,
"listed sessionId must match the one returned by session/new"
);
}
#[tokio::test]
async fn test_session_new_rejects_missing_cwd() {
let ch = make_channel();
let body = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(1)),
None,
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert_eq!(
v["error"]["code"], -32602,
"missing cwd must give -32602: {body}"
);
assert!(ch.state.lock().await.sessions.is_empty());
}
#[tokio::test]
async fn test_session_new_rejects_relative_cwd() {
let ch = make_channel();
let body = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(1)),
Some(serde_json::json!({ "cwd": "relative/path" })),
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert_eq!(
v["error"]["code"], -32602,
"relative cwd must give -32602: {body}"
);
assert!(ch.state.lock().await.sessions.is_empty());
}
#[tokio::test]
async fn test_session_new_stores_absolute_cwd() {
let ch = make_channel();
let body = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(1)),
Some(serde_json::json!({ "cwd": "/home/user/project" })),
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert!(
v.get("error").is_none(),
"absolute cwd must succeed: {body}"
);
let sid = v["result"]["sessionId"].as_str().unwrap().to_string();
let st = ch.state.lock().await;
assert_eq!(
st.sessions.get(&sid).map(|e| e.cwd.as_str()),
Some("/home/user/project")
);
}
#[tokio::test]
async fn test_register_prompt_blocked_without_client_id() {
let ch = make_channel();
let result = AcpHttpChannel::register_prompt(
&ch.state,
&ch.pending_http,
&ch.base_config,
&ch.bus,
Some(serde_json::json!(2)),
Some(serde_json::json!({
"sessionId": "acph_nonexistent",
"prompt": [{ "type": "text", "text": "hello" }]
})),
)
.await
.expect("register_prompt must not return an Err variant here");
assert!(
result.is_err(),
"register_prompt must reject calls with an unknown session"
);
let err_body = result.unwrap_err();
assert!(
err_body.contains("-32600") || err_body.contains("-32000"),
"rejection must be an RPC error: {err_body}"
);
}
#[test]
fn test_http_200_no_cors_when_auth() {
let body = r#"{"result":"ok"}"#;
let resp = AcpHttpChannel::http_200(body, false);
assert!(
!resp.contains("Access-Control-Allow-Origin"),
"CORS header must be absent when open_cors=false"
);
assert!(resp.contains(&format!("Content-Length: {}", body.len())));
}
#[test]
fn test_cors_preflight_open() {
let resp = build_cors_preflight(true);
assert!(resp.contains("Access-Control-Allow-Origin: *"));
assert!(resp.contains("Access-Control-Allow-Methods"));
assert!(resp.contains("204 No Content"));
}
#[test]
fn test_cors_preflight_restricted() {
let resp = build_cors_preflight(false);
assert!(
!resp.contains("Access-Control-Allow-Origin"),
"CORS origin header must be absent when restricted"
);
assert!(resp.contains("Access-Control-Allow-Methods"));
assert!(resp.contains("204 No Content"));
}
#[test]
fn test_build_http_error_cors() {
let open = build_http_error("400 Bad Request", r#"{"error":"bad"}"#, true);
assert!(open.contains("Access-Control-Allow-Origin: *"));
let restricted = build_http_error("400 Bad Request", r#"{"error":"bad"}"#, false);
assert!(
!restricted.contains("Access-Control-Allow-Origin"),
"CORS header must be absent when open_cors=false"
);
}
#[test]
fn test_cwd_contains_traversal() {
assert!(cwd_contains_traversal("/home/user/../etc/passwd"));
assert!(cwd_contains_traversal("/home/user/.."));
assert!(cwd_contains_traversal("/../root"));
assert!(!cwd_contains_traversal("/home/user/project"));
assert!(!cwd_contains_traversal("/home/user/..hidden"));
assert!(!cwd_contains_traversal("/home/user/a..b"));
}
#[tokio::test]
async fn test_session_new_rejects_traversal_cwd() {
let ch = make_channel();
let body = AcpHttpChannel::do_session_new(
&ch.state,
&ch.base_config,
&ch.config,
Some(serde_json::json!(1)),
Some(serde_json::json!({ "cwd": "/home/user/../etc" })),
)
.await;
let v: serde_json::Value = serde_json::from_str(&body).unwrap();
assert_eq!(
v["error"]["code"], -32602,
"traversal cwd must give -32602: {body}"
);
assert!(
v["error"]["message"].as_str().unwrap().contains("'..'"),
"error message must mention '..': {body}"
);
assert!(ch.state.lock().await.sessions.is_empty());
}
}