use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;
use acp::Client as _;
use agent_client_protocol as acp;
use tokio::sync::mpsc;
use bitrouter_core::agents::event::{
AgentEvent, PermissionOutcome, PermissionResponse, StopReason,
};
use bitrouter_core::agents::provider::AgentProvider;
use bitrouter_core::auth::claims::BitrouterClaims;
use bitrouter_core::auth::token as jwt_token;
use bitrouter_core::observe::{
AgentObserveCallback, AgentRequestContext, AgentTurnFailureEvent, AgentTurnSuccessEvent,
CallerContext,
};
use super::provider::AcpAgentProvider;
/// Static configuration for a [`ProxyAgent`].
pub struct ProxyConfig {
    /// Name reported to the observer as the agent that handled each turn.
    pub agent_name: String,
    /// Token that pre-authenticates the consumer. When `Some`, the proxy
    /// starts authenticated and does not advertise an auth method.
    pub pre_auth_token: Option<String>,
    /// Expected `iss` claim for JWTs presented through `authenticate`;
    /// `None` accepts any issuer.
    pub operator_caip10: Option<String>,
    /// Optional callback told about the outcome and latency of relayed turns.
    pub observer: Option<Arc<dyn AgentObserveCallback>>,
}

/// Late-initialized, shared slot for the consumer-side ACP connection.
///
/// It starts out `None`: the connection is created from the [`ProxyAgent`],
/// so the agent gets the (empty) slot first and the stdio runner fills it in
/// afterwards, clearing it again once the connection's I/O finishes.
type ConsumerConn = Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>>;

/// ACP agent that relays a consumer's session to an upstream agent provider.
pub struct ProxyAgent {
    // Upstream side and consumer side of the relay.
    provider: Arc<AcpAgentProvider>,
    consumer_conn: ConsumerConn,
    config: ProxyConfig,
    // Mutable per-connection state; `RefCell` because the agent is used from
    // a single thread (`?Send` trait impl).
    authenticated: RefCell<bool>,
    caller_context: RefCell<CallerContext>,
    upstream_session_id: RefCell<Option<String>>,
}
impl ProxyAgent {
pub fn new(
provider: Arc<AcpAgentProvider>,
consumer_conn: ConsumerConn,
config: ProxyConfig,
) -> Self {
let authenticated = config.pre_auth_token.is_some();
let caller_context = if let Some(ref token) = config.pre_auth_token {
claims_to_caller_context(jwt_token::verify(token).ok().as_ref())
} else {
CallerContext::default()
};
Self {
provider,
consumer_conn,
config,
upstream_session_id: RefCell::new(None),
authenticated: RefCell::new(authenticated),
caller_context: RefCell::new(caller_context),
}
}
async fn notify_consumer(
&self,
session_id: acp::SessionId,
update: acp::SessionUpdate,
) -> acp::Result<()> {
let conn = self.consumer_conn.borrow().clone();
let Some(conn) = conn else {
return Err(acp::Error::internal_error());
};
conn.session_notification(acp::SessionNotification::new(session_id, update))
.await
}
async fn forward_permission_to_consumer(
&self,
req: acp::RequestPermissionRequest,
) -> acp::Result<acp::RequestPermissionResponse> {
let conn = self.consumer_conn.borrow().clone();
let Some(conn) = conn else {
return Err(acp::Error::internal_error());
};
conn.request_permission(req).await
}
async fn relay_events(
&self,
consumer_session_id: acp::SessionId,
mut rx: mpsc::Receiver<AgentEvent>,
) -> acp::Result<acp::PromptResponse> {
while let Some(event) = rx.recv().await {
match event {
AgentEvent::TurnDone { stop_reason } => {
return Ok(acp::PromptResponse::new(to_acp_stop_reason(stop_reason)));
}
AgentEvent::Error { message } => {
return Err(acp::Error::new(
i32::from(acp::ErrorCode::InternalError),
message,
));
}
AgentEvent::Disconnected => {
return Err(acp::Error::new(
i32::from(acp::ErrorCode::InternalError),
"upstream agent disconnected",
));
}
AgentEvent::MessageChunk { text } => {
let update = acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
acp::ContentBlock::Text(acp::TextContent::new(text)),
));
self.notify_consumer(consumer_session_id.clone(), update)
.await?;
}
AgentEvent::ThoughtChunk { text } => {
let update = acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
acp::ContentBlock::Text(acp::TextContent::new(text)),
));
self.notify_consumer(consumer_session_id.clone(), update)
.await?;
}
AgentEvent::NonTextContent { description } => {
let update = acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
acp::ContentBlock::Text(acp::TextContent::new(description)),
));
self.notify_consumer(consumer_session_id.clone(), update)
.await?;
}
AgentEvent::ToolCall {
tool_call_id,
title,
status,
} => {
let update = acp::SessionUpdate::ToolCall(
acp::ToolCall::new(tool_call_id, title)
.status(convert_tool_call_status_to_acp(status)),
);
self.notify_consumer(consumer_session_id.clone(), update)
.await?;
}
AgentEvent::ToolCallUpdate {
tool_call_id,
title,
status,
} => {
let mut fields = acp::ToolCallUpdateFields::new();
if let Some(title) = title {
fields = fields.title(title);
}
if let Some(status) = status {
fields = fields.status(convert_tool_call_status_to_acp(status));
}
let tc_update = acp::ToolCallUpdate::new(tool_call_id, fields);
let update = acp::SessionUpdate::ToolCallUpdate(tc_update);
self.notify_consumer(consumer_session_id.clone(), update)
.await?;
}
AgentEvent::PermissionRequest { id, request } => {
let options: Vec<acp::PermissionOption> = request
.options
.iter()
.map(|opt| {
acp::PermissionOption::new(
opt.id.clone(),
opt.title.clone(),
acp::PermissionOptionKind::AllowOnce,
)
})
.collect();
let fields = acp::ToolCallUpdateFields::new()
.title(request.title.clone())
.status(acp::ToolCallStatus::Pending);
let tool_call_update = acp::ToolCallUpdate::new("permission", fields);
let acp_req = acp::RequestPermissionRequest::new(
consumer_session_id.clone(),
tool_call_update,
options,
);
let acp_resp = self.forward_permission_to_consumer(acp_req).await?;
let response = match acp_resp.outcome {
acp::RequestPermissionOutcome::Selected(sel) => PermissionResponse {
outcome: PermissionOutcome::Allowed {
selected_option: sel.option_id.to_string(),
},
},
_ => PermissionResponse {
outcome: PermissionOutcome::Denied,
},
};
let session_id = self.upstream_session_id.borrow().clone();
if let Some(ref session_id) = session_id {
let _ = self
.provider
.respond_permission(session_id, id, response)
.await;
}
}
AgentEvent::HistoryReplayDone => {
}
}
}
Err(acp::Error::new(
i32::from(acp::ErrorCode::InternalError),
"upstream event stream ended unexpectedly",
))
}
}
#[async_trait::async_trait(?Send)]
impl acp::Agent for ProxyAgent {
    /// Reports protocol V1 and BitRouter implementation info.
    ///
    /// The `bitrouter-jwt` agent auth method is advertised only while the
    /// proxy is not yet authenticated.
    async fn initialize(
        &self,
        _args: acp::InitializeRequest,
    ) -> acp::Result<acp::InitializeResponse> {
        let mut auth_methods = Vec::new();
        if !*self.authenticated.borrow() {
            auth_methods.push(acp::AuthMethod::Agent(acp::AuthMethodAgent::new(
                "bitrouter-jwt",
                "BitRouter JWT",
            )));
        }
        Ok(acp::InitializeResponse::new(acp::ProtocolVersion::V1)
            .agent_capabilities(acp::AgentCapabilities::default())
            .auth_methods(auth_methods)
            .agent_info(
                acp::Implementation::new("bitrouter", env!("CARGO_PKG_VERSION"))
                    .title("BitRouter Agent Proxy"),
            ))
    }

    /// Authenticates the consumer with a BitRouter JWT passed in `_meta.token`.
    ///
    /// This is a no-op when the proxy is already authenticated. On success the
    /// token's claims become the caller context reported to the observer.
    ///
    /// # Errors
    ///
    /// Returns `InvalidParams` in these cases:
    /// - the token is missing;
    /// - the token fails verification;
    /// - the token has expired;
    /// - `operator_caip10` is configured and the token's `iss` differs from it.
    async fn authenticate(
        &self,
        args: acp::AuthenticateRequest,
    ) -> acp::Result<acp::AuthenticateResponse> {
        if *self.authenticated.borrow() {
            return Ok(acp::AuthenticateResponse::new());
        }
        let token = args
            .meta
            .as_ref()
            .and_then(|m| m.get("token"))
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                acp::Error::new(
                    i32::from(acp::ErrorCode::InvalidParams),
                    "missing token in _meta.token",
                )
            })?;
        let claims = jwt_token::verify(token).map_err(|e| {
            acp::Error::new(
                i32::from(acp::ErrorCode::InvalidParams),
                format!("invalid JWT: {e}"),
            )
        })?;
        jwt_token::check_expiration(&claims).map_err(|_| {
            acp::Error::new(i32::from(acp::ErrorCode::InvalidParams), "JWT expired")
        })?;
        if let Some(ref expected) = self.config.operator_caip10
            && claims.iss != *expected
        {
            return Err(acp::Error::new(
                i32::from(acp::ErrorCode::InvalidParams),
                "JWT issuer does not match operator wallet",
            ));
        }
        *self.caller_context.borrow_mut() = claims_to_caller_context(Some(&claims));
        *self.authenticated.borrow_mut() = true;
        Ok(acp::AuthenticateResponse::new())
    }

    /// Opens an upstream session in `args.cwd` and returns its id to the consumer.
    ///
    /// The consumer is handed the upstream session id unchanged. Any previously
    /// stored upstream session id is replaced.
    ///
    /// # Errors
    ///
    /// `InvalidRequest` before authentication; `InternalError` if the upstream
    /// connect fails.
    async fn new_session(
        &self,
        args: acp::NewSessionRequest,
    ) -> acp::Result<acp::NewSessionResponse> {
        if !*self.authenticated.borrow() {
            return Err(acp::Error::new(
                i32::from(acp::ErrorCode::InvalidRequest),
                "not authenticated",
            ));
        }
        let session_info = self.provider.connect(&args.cwd).await.map_err(|e| {
            acp::Error::new(
                i32::from(acp::ErrorCode::InternalError),
                format!("upstream connect failed: {e}"),
            )
        })?;
        let upstream_id = session_info.session_id.clone();
        *self.upstream_session_id.borrow_mut() = Some(upstream_id);
        Ok(acp::NewSessionResponse::new(session_info.session_id))
    }

    /// Relays one prompt turn to the upstream session and streams its events
    /// back to the consumer.
    ///
    /// The prompt's content blocks are flattened to text with
    /// `extract_prompt_text` before submission. When an observer is configured,
    /// it receives a success or failure event for the relayed turn, including
    /// the elapsed time in milliseconds.
    ///
    /// # Errors
    ///
    /// `InvalidRequest` when no session exists yet. `InternalError` when
    /// submission fails. Otherwise, whatever relaying the turn returns.
    async fn prompt(&self, args: acp::PromptRequest) -> acp::Result<acp::PromptResponse> {
        let upstream_session_id = self.upstream_session_id.borrow().clone().ok_or_else(|| {
            acp::Error::new(
                i32::from(acp::ErrorCode::InvalidRequest),
                "no session — call new_session first",
            )
        })?;
        // Attribute the turn to the caller and session it started with.
        // Re-reading the cells after the awaits below could observe values
        // that another request has replaced in the meantime.
        let caller = self.caller_context.borrow().clone();
        let text = extract_prompt_text(&args.prompt);
        let turn_start = Instant::now();
        let rx = self
            .provider
            .submit(&upstream_session_id, text)
            .await
            .map_err(|e| {
                acp::Error::new(
                    i32::from(acp::ErrorCode::InternalError),
                    format!("upstream submit failed: {e}"),
                )
            })?;
        let result = self.relay_events(args.session_id, rx).await;
        // `as_millis` yields u128; saturate instead of silently truncating.
        let latency_ms = u64::try_from(turn_start.elapsed().as_millis()).unwrap_or(u64::MAX);
        if let Some(observer) = &self.config.observer {
            let ctx = AgentRequestContext {
                agent_name: self.config.agent_name.clone(),
                protocol: "acp".to_owned(),
                session_id: Some(upstream_session_id),
                caller,
                latency_ms,
            };
            match &result {
                Ok(_) => {
                    observer
                        .on_agent_turn_success(AgentTurnSuccessEvent { ctx })
                        .await;
                }
                Err(e) => {
                    observer
                        .on_agent_turn_failure(AgentTurnFailureEvent {
                            ctx,
                            error: e.message.clone(),
                        })
                        .await;
                }
            }
        }
        result
    }

    /// Accepts a cancel notification and reports success.
    ///
    /// NOTE(review): cancellation is not forwarded upstream, so an in-flight
    /// turn keeps running until upstream finishes it.
    async fn cancel(&self, _args: acp::CancelNotification) -> acp::Result<()> {
        Ok(())
    }
}
/// Flattens ACP prompt content blocks into one newline-separated string.
///
/// Text blocks contribute their text verbatim and resource links become a
/// markdown-style `[name](uri)` link. Every other block kind is replaced by a
/// bracketed placeholder such as `[image]`. An empty slice yields `""`.
fn extract_prompt_text(blocks: &[acp::ContentBlock]) -> String {
    blocks
        .iter()
        .map(|block| match block {
            acp::ContentBlock::Text(content) => content.text.clone(),
            acp::ContentBlock::Image(_) => String::from("[image]"),
            acp::ContentBlock::Audio(_) => String::from("[audio]"),
            acp::ContentBlock::ResourceLink(link) => format!("[{}]({})", link.name, link.uri),
            acp::ContentBlock::Resource(_) => String::from("[resource]"),
            _ => String::from("[unknown content]"),
        })
        .collect::<Vec<String>>()
        .join("\n")
}
/// Maps a BitRouter stop reason onto the closest ACP stop reason.
///
/// Only `MaxTokens` has a distinct ACP counterpart. Every other reason is
/// reported as a normal end of turn.
fn to_acp_stop_reason(reason: StopReason) -> acp::StopReason {
    match reason {
        StopReason::MaxTokens => acp::StopReason::MaxTokens,
        StopReason::EndTurn
        | StopReason::StopSequence
        | StopReason::ToolUse
        | StopReason::Other(_) => acp::StopReason::EndTurn,
    }
}
/// Converts a BitRouter tool-call status to the equivalent ACP status (1:1).
fn convert_tool_call_status_to_acp(
    status: bitrouter_core::agents::event::ToolCallStatus,
) -> acp::ToolCallStatus {
    use bitrouter_core::agents::event::ToolCallStatus as CoreStatus;

    match status {
        CoreStatus::Pending => acp::ToolCallStatus::Pending,
        CoreStatus::InProgress => acp::ToolCallStatus::InProgress,
        CoreStatus::Completed => acp::ToolCallStatus::Completed,
        CoreStatus::Failed => acp::ToolCallStatus::Failed,
    }
}
/// Builds the observer's caller context from JWT claims.
///
/// `None` yields `CallerContext::default()`. Otherwise, the issuer becomes the
/// account id and the remaining claim fields are copied across. `chain` is
/// always left `None`.
fn claims_to_caller_context(claims: Option<&BitrouterClaims>) -> CallerContext {
    match claims {
        None => CallerContext::default(),
        Some(c) => CallerContext {
            account_id: Some(c.iss.clone()),
            key_id: c.id.clone(),
            models: c.mdl.clone(),
            budget: c.bgt,
            budget_scope: c.bsc,
            issued_at: c.iat,
            key: c.key.clone(),
            chain: None,
            policy_id: c.pol.clone(),
        },
    }
}
/// Runs the proxy over this process's stdin/stdout until the connection's I/O ends.
///
/// `ProxyAgent` keeps its state in `Rc`/`RefCell`, so the proxy runs on a
/// current-thread Tokio runtime inside a `LocalSet`.
///
/// # Errors
///
/// Returns a message if the runtime cannot be built or if the proxy's I/O fails.
pub fn run_stdio_proxy(provider: Arc<AcpAgentProvider>, config: ProxyConfig) -> Result<(), String> {
    let runtime = match tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
    {
        Ok(runtime) => runtime,
        Err(e) => return Err(format!("failed to create runtime: {e}")),
    };
    let local_set = tokio::task::LocalSet::new();
    let proxy = run_proxy_local(provider, config);
    runtime.block_on(local_set.run_until(proxy))
}
/// Serves the ACP agent side over stdin/stdout on the current `LocalSet`.
///
/// The agent is given an empty connection slot, because the connection can
/// only be built from the agent. The slot is filled once the connection
/// exists. It is cleared after the I/O future finishes, so the stored
/// connection is released rather than kept alive through the agent's handle
/// to the slot. NOTE(review): this presumably breaks an `Rc` cycle — confirm.
async fn run_proxy_local(
    provider: Arc<AcpAgentProvider>,
    config: ProxyConfig,
) -> Result<(), String> {
    use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

    let incoming = tokio::io::stdin().compat();
    let outgoing = tokio::io::stdout().compat_write();

    let conn_slot: ConsumerConn = Rc::new(RefCell::new(None));
    let agent = ProxyAgent::new(provider, Rc::clone(&conn_slot), config);
    let (connection, io_task) = acp::AgentSideConnection::new(agent, outgoing, incoming, |fut| {
        tokio::task::spawn_local(fut);
    });
    *conn_slot.borrow_mut() = Some(Rc::new(connection));

    let outcome = io_task.await;
    drop(conn_slot.borrow_mut().take());
    outcome.map_err(|e| format!("proxy I/O error: {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_text_from_content_blocks() {
        let blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Hello")),
            acp::ContentBlock::Text(acp::TextContent::new("World")),
        ];
        assert_eq!(extract_prompt_text(&blocks), "Hello\nWorld");
    }

    #[test]
    fn extract_text_empty_blocks() {
        let blocks: Vec<acp::ContentBlock> = Vec::new();
        assert_eq!(extract_prompt_text(&blocks), "");
    }

    /// Every BitRouter stop reason maps to the expected ACP stop reason.
    #[test]
    fn stop_reason_mapping() {
        assert!(matches!(
            to_acp_stop_reason(StopReason::EndTurn),
            acp::StopReason::EndTurn
        ));
        assert!(matches!(
            to_acp_stop_reason(StopReason::MaxTokens),
            acp::StopReason::MaxTokens
        ));
        assert!(matches!(
            to_acp_stop_reason(StopReason::Other("custom".into())),
            acp::StopReason::EndTurn
        ));
        // These reasons have no distinct ACP counterpart and collapse to EndTurn.
        assert!(matches!(
            to_acp_stop_reason(StopReason::StopSequence),
            acp::StopReason::EndTurn
        ));
        assert!(matches!(
            to_acp_stop_reason(StopReason::ToolUse),
            acp::StopReason::EndTurn
        ));
    }

    /// Tool-call statuses map one-to-one onto ACP statuses.
    #[test]
    fn tool_call_status_mapping() {
        use bitrouter_core::agents::event::ToolCallStatus as CoreStatus;

        assert!(matches!(
            convert_tool_call_status_to_acp(CoreStatus::Pending),
            acp::ToolCallStatus::Pending
        ));
        assert!(matches!(
            convert_tool_call_status_to_acp(CoreStatus::InProgress),
            acp::ToolCallStatus::InProgress
        ));
        assert!(matches!(
            convert_tool_call_status_to_acp(CoreStatus::Completed),
            acp::ToolCallStatus::Completed
        ));
        assert!(matches!(
            convert_tool_call_status_to_acp(CoreStatus::Failed),
            acp::ToolCallStatus::Failed
        ));
    }

    #[test]
    fn claims_to_context_with_none() {
        let ctx = claims_to_caller_context(None);
        assert!(ctx.account_id.is_none());
        assert!(ctx.key_id.is_none());
    }
}