use agent_client_protocol::{BoxFuture, Channel, ConnectTo, jsonrpcmsg::Message, role::mcp};
use axum::{
Router,
extract::State,
http::StatusCode,
response::{IntoResponse, Response, Sse},
routing::post,
};
use futures::{SinkExt, StreamExt as _, channel::mpsc, future::Either, stream::Stream};
use futures_concurrency::future::FutureExt as _;
use futures_concurrency::stream::StreamExt as _;
use rustc_hash::FxHashMap;
use std::{
collections::{HashMap, VecDeque},
pin::pin,
sync::Arc,
};
use tokio::net::TcpListener;
use crate::conductor::{
ConductorMessage,
mcp_bridge::{McpBridgeConnection, McpBridgeConnectionActor},
};
pub async fn run_http_listener(
tcp_listener: TcpListener,
acp_url: String,
mut conductor_tx: mpsc::Sender<ConductorMessage>,
) -> Result<(), agent_client_protocol::Error> {
let (to_mcp_client_tx, to_mcp_client_rx) = mpsc::channel(128);
conductor_tx
.send(ConductorMessage::McpConnectionReceived {
acp_url,
actor: McpBridgeConnectionActor::new(
HttpMcpBridge::new(tcp_listener),
conductor_tx.clone(),
to_mcp_client_rx,
),
connection: McpBridgeConnection::new(to_mcp_client_tx),
})
.await
.map_err(|_| agent_client_protocol::Error::internal_error())?;
Ok(())
}
/// MCP transport that serves the MCP server side over HTTP on an already-bound
/// TCP listener.
///
/// Construction performs no I/O; the listener is only served once the bridge is
/// connected through its `ConnectTo` implementation.
struct HttpMcpBridge {
    /// Listener that the HTTP server will accept connections on.
    listener: tokio::net::TcpListener,
}

impl HttpMcpBridge {
    /// Wraps `listener` without touching it.
    fn new(listener: tokio::net::TcpListener) -> Self {
        HttpMcpBridge { listener }
    }
}
impl ConnectTo<mcp::Client> for HttpMcpBridge {
    /// Runs the HTTP server and `client`'s connection concurrently.
    ///
    /// Whichever future finishes first determines the result. The other one is
    /// dropped without being driven further.
    async fn connect_to(
        self,
        client: impl ConnectTo<mcp::Server>,
    ) -> Result<(), agent_client_protocol::Error> {
        let (channel, server_future) = self.into_channel_and_future();
        let client_future = pin!(client.connect_to(channel));

        // Both arms carry the same output type, so one irrefutable or-pattern
        // extracts the winner's result.
        let (Either::Left((outcome, _)) | Either::Right((outcome, _))) =
            futures::future::select(client_future, server_future).await;
        outcome
    }

    /// Splits the bridge into a channel endpoint for the caller and a boxed future
    /// that serves HTTP on the listener using the opposite endpoint.
    ///
    /// Nothing is served until the returned future is polled.
    fn into_channel_and_future(
        self,
    ) -> (
        Channel,
        BoxFuture<'static, Result<(), agent_client_protocol::Error>>,
    )
    where
        Self: Sized,
    {
        let (caller_end, server_end) = Channel::duplex();
        let serve = run(self.listener, server_end);
        (caller_end, Box::pin(serve))
    }
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct HttpError(#[from] agent_client_protocol::Error);
impl From<axum::Error> for HttpError {
fn from(error: axum::Error) -> Self {
HttpError(agent_client_protocol::util::internal_error(error))
}
}
impl IntoResponse for HttpError {
fn into_response(self) -> Response {
let message = format!("Error: {}", self.0);
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
}
}
/// Serves the MCP HTTP endpoint on `listener` and pumps messages to and from
/// `channel`, returning when the first of the two finishes.
///
/// `POST /` and `GET /` are handled by [`handle_post`] and [`handle_get`]. Those
/// handlers report to the [`RunningServer`] loop through an unbounded registration
/// channel.
///
/// # Errors
///
/// Returns the error of whichever side completes first. A failure of the HTTP
/// server itself is mapped to an internal error.
async fn run(listener: TcpListener, channel: Channel) -> Result<(), agent_client_protocol::Error> {
    let (registration_tx, registration_rx) = mpsc::unbounded();
    let shared = Arc::new(BridgeState { registration_tx });

    let router = Router::new()
        .route("/", post(handle_post).get(handle_get))
        .with_state(shared);

    let http_server = async move {
        axum::serve(listener, router)
            .await
            .map_err(agent_client_protocol::util::internal_error)
    };
    let message_pump = RunningServer::new().run(channel, registration_rx);

    http_server.race(message_pump).await
}
/// State shared by the axum handlers (wrapped in an `Arc` by `run`).
struct BridgeState {
    /// Sender through which the handlers hand [`HttpMessage`]s to the
    /// `RunningServer` loop.
    registration_tx: mpsc::UnboundedSender<HttpMessage>,
}
/// A message from an axum handler to the `RunningServer` loop.
///
/// In every variant, `http_request_id` is a random id generated per HTTP request. In
/// this file it is only used to correlate log lines.
#[derive(Debug)]
enum HttpMessage {
    /// A JSON-RPC request posted by the client.
    ///
    /// `response_tx` feeds the SSE stream returned for that POST. `handle_post` only
    /// sends this variant for requests that carry an id; the loop would register an
    /// id-less one as a general session.
    Request {
        http_request_id: uuid::Uuid,
        request: agent_client_protocol::jsonrpcmsg::Request,
        response_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>,
    },
    /// A JSON-RPC request without an id (a notification) posted by the client.
    Notification {
        http_request_id: uuid::Uuid,
        request: agent_client_protocol::jsonrpcmsg::Request,
    },
    /// A JSON-RPC response posted by the client.
    Response {
        http_request_id: uuid::Uuid,
        response: agent_client_protocol::jsonrpcmsg::Response,
    },
    /// A `GET` opening a standalone SSE stream; `response_tx` feeds that stream.
    Get {
        http_request_id: uuid::Uuid,
        response_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>,
    },
}
/// Hashable key form of a JSON-RPC id.
///
/// Used as the key of `RunningServer::waiting_sessions`, so a response can be matched
/// to the POST that is waiting for it.
#[derive(Eq, PartialEq, PartialOrd, Ord, Hash, Debug, Clone)]
enum JsonRpcId {
    String(String),
    Number(u64),
    Null,
}

impl From<agent_client_protocol::jsonrpcmsg::Id> for JsonRpcId {
    /// Maps each wire id variant to the corresponding key variant.
    fn from(id: agent_client_protocol::jsonrpcmsg::Id) -> Self {
        use agent_client_protocol::jsonrpcmsg::Id;

        match id {
            Id::String(text) => Self::String(text),
            Id::Number(number) => Self::Number(number),
            Id::Null => Self::Null,
        }
    }
}
struct RunningServer {
waiting_sessions: FxHashMap<JsonRpcId, RegisteredSession>,
general_sessions: Vec<RegisteredSession>,
message_deque: VecDeque<agent_client_protocol::jsonrpcmsg::Message>,
}
impl RunningServer {
fn new() -> Self {
RunningServer {
waiting_sessions: HashMap::default(),
general_sessions: Vec::default(),
message_deque: VecDeque::with_capacity(32),
}
}
async fn run(
mut self,
mut channel: Channel,
http_rx: mpsc::UnboundedReceiver<HttpMessage>,
) -> Result<(), agent_client_protocol::Error> {
#[derive(Debug)]
enum MultiplexMessage {
FromHttpToChannel(HttpMessage),
FromChannelToHttp(
Result<agent_client_protocol::jsonrpcmsg::Message, agent_client_protocol::Error>,
),
}
let mut merged_stream = http_rx
.map(MultiplexMessage::FromHttpToChannel)
.merge(channel.rx.map(MultiplexMessage::FromChannelToHttp));
while let Some(message) = merged_stream.next().await {
tracing::trace!(?message, "received message");
match message {
MultiplexMessage::FromHttpToChannel(http_message) => {
self.handle_http_message(http_message, &mut channel.tx)?;
}
MultiplexMessage::FromChannelToHttp(message) => {
let message = message.unwrap_or_else(|err| {
agent_client_protocol::jsonrpcmsg::Message::Response(
agent_client_protocol::jsonrpcmsg::Response::error(
agent_client_protocol::util::into_jsonrpc_error(err),
None,
),
)
});
tracing::debug!(
queue_len = self.message_deque.len() + 1,
?message,
"enqueuing outgoing message"
);
self.message_deque.push_back(message);
}
}
self.drain_jsonrpc_messages();
}
tracing::trace!("http connection terminating");
Ok(())
}
fn handle_http_message(
&mut self,
message: HttpMessage,
channel_tx: &mut mpsc::UnboundedSender<
Result<agent_client_protocol::jsonrpcmsg::Message, agent_client_protocol::Error>,
>,
) -> Result<(), agent_client_protocol::Error> {
match message {
HttpMessage::Request {
http_request_id,
request,
response_tx,
} => {
tracing::debug!(%http_request_id, ?request, "handling request");
let request_id = request.id.clone().map(JsonRpcId::from);
channel_tx
.unbounded_send(Ok(Message::Request(request)))
.map_err(agent_client_protocol::util::internal_error)?;
let session = RegisteredSession::new(response_tx);
if let Some(id) = request_id {
tracing::debug!(%http_request_id, session_id = %session.id, ?id, "registering waiting session");
self.waiting_sessions.insert(id, session);
} else {
tracing::debug!(%http_request_id, session_id = %session.id, "registering general session (request without id)");
self.general_sessions.push(session);
}
}
HttpMessage::Notification {
http_request_id,
request,
} => {
tracing::debug!(%http_request_id, ?request, "handling notification");
channel_tx
.unbounded_send(Ok(Message::Request(request)))
.map_err(agent_client_protocol::util::internal_error)?;
}
HttpMessage::Response {
http_request_id,
response,
} => {
tracing::debug!(%http_request_id, ?response, "handling response");
channel_tx
.unbounded_send(Ok(Message::Response(response)))
.map_err(agent_client_protocol::util::internal_error)?;
}
HttpMessage::Get {
http_request_id,
response_tx,
} => {
let session = RegisteredSession::new(response_tx);
tracing::debug!(
%http_request_id,
session_id = %session.id,
queued_messages = self.message_deque.len(),
"handling GET (opening SSE stream)"
);
self.general_sessions.push(session);
}
}
self.purge_closed_sessions();
Ok(())
}
fn drain_jsonrpc_messages(&mut self) {
if !self.message_deque.is_empty() {
tracing::debug!(
queue_len = self.message_deque.len(),
general_sessions = self.general_sessions.len(),
waiting_sessions = self.waiting_sessions.len(),
"draining message queue"
);
}
while let Some(message) = self.message_deque.pop_front() {
match self.try_dispatch_jsonrpc_message(message) {
None => {
tracing::debug!(
remaining = self.message_deque.len(),
"message dispatched successfully"
);
}
Some(message) => {
tracing::debug!(
remaining = self.message_deque.len() + 1,
"no available session, re-enqueuing message"
);
self.message_deque.push_front(message);
break;
}
}
}
}
fn try_dispatch_jsonrpc_message(
&mut self,
mut message: agent_client_protocol::jsonrpcmsg::Message,
) -> Option<agent_client_protocol::jsonrpcmsg::Message> {
let message_id = match &message {
Message::Response(response) => response.id.as_ref().map(|v| v.clone().into()),
Message::Request(_) => None,
};
tracing::debug!(?message_id, "attempting to dispatch JSON-RPC message");
if let Some(ref message_id) = message_id
&& let Some(session) = self.waiting_sessions.remove(message_id)
{
tracing::debug!(session_id = %session.id, "found waiting session, attempting send");
match session.outgoing_tx.unbounded_send(message) {
Ok(()) => {
tracing::debug!(session_id = %session.id, "sent to waiting session");
return None;
}
Err(m) => {
tracing::debug!(session_id = %session.id, "waiting session disconnected");
assert!(m.is_disconnected());
message = m.into_inner();
}
}
}
self.purge_closed_sessions();
tracing::debug!(
general_sessions = self.general_sessions.len(),
waiting_sessions = self.waiting_sessions.len(),
"trying to find any active session"
);
let all_sessions = self
.general_sessions
.iter_mut()
.chain(self.waiting_sessions.values_mut());
for session in all_sessions {
tracing::trace!(session_id = %session.id, "trying session");
match session.outgoing_tx.unbounded_send(message) {
Ok(()) => {
tracing::debug!(session_id = %session.id, "sent to session");
return None;
}
Err(m) => {
tracing::debug!(session_id = %session.id, "session disconnected, trying next");
assert!(m.is_disconnected());
message = m.into_inner();
}
}
}
Some(message)
}
fn purge_closed_sessions(&mut self) {
self.general_sessions
.retain(|session| !session.outgoing_tx.is_closed());
self.waiting_sessions
.retain(|_, session| !session.outgoing_tx.is_closed());
}
}
/// An open HTTP response stream (a POST's or a GET's SSE stream) to which the
/// routing loop can push outgoing JSON-RPC messages.
struct RegisteredSession {
    /// Random identifier; in this file it only appears in log fields.
    id: uuid::Uuid,
    /// Sending half of the channel whose receiver feeds the SSE stream.
    outgoing_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>,
}

impl RegisteredSession {
    /// Wraps `outgoing_tx` in a session carrying a freshly generated random id.
    fn new(outgoing_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>) -> Self {
        let id = uuid::Uuid::new_v4();
        RegisteredSession { id, outgoing_tx }
    }
}
async fn handle_post(
State(state): State<Arc<BridgeState>>,
body: String,
) -> Result<Response, HttpError> {
let http_request_id = uuid::Uuid::new_v4();
let message: agent_client_protocol::jsonrpcmsg::Message =
serde_json::from_str(&body).map_err(agent_client_protocol::util::parse_error)?;
match message {
Message::Request(request) if request.id.is_some() => {
tracing::debug!(%http_request_id, method = %request.method, "POST request received");
let (tx, mut rx) = mpsc::unbounded();
state
.registration_tx
.unbounded_send(HttpMessage::Request {
http_request_id,
request,
response_tx: tx,
})
.map_err(agent_client_protocol::util::internal_error)?;
let stream = async_stream::stream! {
while let Some(message) = rx.next().await {
tracing::debug!(%http_request_id, "sending SSE event");
match axum::response::sse::Event::default().json_data(message) {
Ok(v) => yield Ok(v),
Err(e) => yield Err(HttpError::from(e)),
}
}
tracing::debug!(%http_request_id, "SSE stream completed");
};
Ok(Sse::new(stream).into_response())
}
Message::Request(request) => {
tracing::debug!(%http_request_id, method = %request.method, "POST notification received");
state
.registration_tx
.unbounded_send(HttpMessage::Notification {
http_request_id,
request,
})
.map_err(agent_client_protocol::util::internal_error)?;
Ok(StatusCode::ACCEPTED.into_response())
}
Message::Response(response) => {
tracing::debug!(%http_request_id, "POST response received");
state
.registration_tx
.unbounded_send(HttpMessage::Response {
http_request_id,
response,
})
.map_err(agent_client_protocol::util::internal_error)?;
Ok(StatusCode::ACCEPTED.into_response())
}
}
}
/// Handles `GET /` by opening a standalone SSE stream for messages not tied to a
/// particular POST.
///
/// A fresh unbounded channel is created. Its sender is registered with the routing
/// loop via [`HttpMessage::Get`] (as a general session), and its receiver becomes the
/// SSE body. Each message pushed to the session becomes one event whose data is the
/// message as JSON. The stream ends once the sender is dropped.
///
/// # Errors
///
/// Returns an [`HttpError`] if the registration channel to the routing loop is closed.
async fn handle_get(
    State(state): State<Arc<BridgeState>>,
) -> Result<Sse<impl Stream<Item = Result<axum::response::sse::Event, HttpError>>>, HttpError> {
    let http_request_id = uuid::Uuid::new_v4();
    tracing::debug!(%http_request_id, "GET request received");

    let (session_tx, session_rx) = mpsc::unbounded();
    let registration = HttpMessage::Get {
        http_request_id,
        response_tx: session_tx,
    };
    if let Err(err) = state.registration_tx.unbounded_send(registration) {
        return Err(agent_client_protocol::util::internal_error(err).into());
    }

    let events = async_stream::stream! {
        for await message in session_rx {
            tracing::debug!(%http_request_id, "sending SSE event");
            yield axum::response::sse::Event::default()
                .json_data(message)
                .map_err(HttpError::from);
        }
        tracing::debug!(%http_request_id, "SSE stream completed");
    };
    Ok(Sse::new(events))
}