use agent_client_protocol::{BoxFuture, Channel, ConnectTo, jsonrpcmsg::Message, role::mcp};
use axum::{
Router,
extract::State,
http::StatusCode,
response::{IntoResponse, Response, Sse},
routing::post,
};
use futures::{SinkExt, StreamExt as _, channel::mpsc, future::Either, stream::Stream};
use futures_concurrency::future::FutureExt as _;
use futures_concurrency::stream::StreamExt as _;
use rustc_hash::FxHashMap;
use std::{
collections::{HashMap, VecDeque},
pin::pin,
sync::Arc,
};
use tokio::net::TcpListener;
use super::{BridgeConnection, BridgeMessage, actor::BridgeConnectionActor};
/// Registers an HTTP-backed MCP bridge connection with the bridge loop.
///
/// Wraps `tcp_listener` in an [`HttpMcpBridge`], creates a bounded channel
/// (capacity 128) whose receiving half is given to the
/// [`BridgeConnectionActor`] and whose sending half is wrapped in a
/// [`BridgeConnection`], and announces both on `bridge_tx` as a
/// `BridgeMessage::ConnectionReceived` tagged with `acp_id`.
///
/// This function returns as soon as the message has been sent; it does not
/// serve HTTP itself. The listener is only served once the bridge is
/// connected (see the `ConnectTo` impl of [`HttpMcpBridge`]).
///
/// # Errors
///
/// Returns an internal error carrying the send failure's description when
/// the message cannot be delivered on `bridge_tx`.
pub async fn run_http_listener(
    tcp_listener: TcpListener,
    acp_id: String,
    mut bridge_tx: mpsc::Sender<BridgeMessage>,
) -> Result<(), agent_client_protocol::Error> {
    let (to_mcp_client_tx, to_mcp_client_rx) = mpsc::channel(128);

    let actor = BridgeConnectionActor::new(
        HttpMcpBridge::new(tcp_listener),
        bridge_tx.clone(),
        to_mcp_client_rx,
    );
    let connection = BridgeConnection::new(to_mcp_client_tx);

    bridge_tx
        .send(BridgeMessage::ConnectionReceived {
            acp_id,
            actor,
            connection,
        })
        .await
        // Keep the cause of the failure instead of replacing it with a bare
        // internal error; this matches how the rest of this file maps errors.
        .map_err(agent_client_protocol::util::internal_error)
}
/// Bridge endpoint that serves MCP traffic over HTTP on an already-bound TCP listener.
struct HttpMcpBridge {
/// Listener that is handed to the HTTP server once the bridge is run.
listener: tokio::net::TcpListener,
}
impl HttpMcpBridge {
fn new(listener: tokio::net::TcpListener) -> Self {
Self { listener }
}
}
impl ConnectTo<mcp::Client> for HttpMcpBridge {
    /// Connects `client` to this bridge and drives both sides until either
    /// one finishes; the result of whichever side completes first is
    /// returned and the other side is dropped.
    async fn connect_to(
        self,
        client: impl ConnectTo<mcp::Server>,
    ) -> Result<(), agent_client_protocol::Error> {
        let (channel, serve_self) = self.into_channel_and_future();
        let client_side = pin!(client.connect_to(channel));

        // Both arms carry the same result type, so a single irrefutable
        // or-pattern extracts the outcome of the side that finished first.
        let (Either::Left((outcome, _)) | Either::Right((outcome, _))) =
            futures::future::select(client_side, serve_self).await;
        outcome
    }

    /// Splits the bridge into one end of a duplex [`Channel`] and a boxed
    /// future that runs the HTTP server on the other end.
    fn into_channel_and_future(
        self,
    ) -> (
        Channel,
        BoxFuture<'static, Result<(), agent_client_protocol::Error>>,
    )
    where
        Self: Sized,
    {
        let (near_end, far_end) = Channel::duplex();
        let server = Box::pin(run(self.listener, far_end));
        (near_end, server)
    }
}
/// Error type returned by the HTTP handlers in this file.
///
/// Thin wrapper around `agent_client_protocol::Error`; `#[error(transparent)]`
/// forwards its `Display` output and error source to the wrapped error, and
/// `#[from]` lets `?` convert an `agent_client_protocol::Error` into it.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct HttpError(#[from] agent_client_protocol::Error);
impl From<axum::Error> for HttpError {
fn from(error: axum::Error) -> Self {
HttpError(agent_client_protocol::util::internal_error(error))
}
}
impl IntoResponse for HttpError {
    /// Renders the error as a `500 Internal Server Error` response whose
    /// body is `"Error: "` followed by the wrapped error's `Display` text.
    fn into_response(self) -> Response {
        let Self(inner) = self;
        let body = format!("Error: {inner}");
        (StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
    }
}
/// Serves the HTTP endpoint on `listener` while pumping messages between the
/// HTTP handlers and `channel`.
///
/// The axum server and the [`RunningServer`] loop are raced against each
/// other: whichever finishes first determines the result and the other one
/// is dropped.
///
/// # Errors
///
/// Returns an internal error if `axum::serve` fails, or whatever error the
/// [`RunningServer`] loop ends with.
async fn run(listener: TcpListener, channel: Channel) -> Result<(), agent_client_protocol::Error> {
    // Handlers push every incoming HTTP message onto this queue; the
    // `RunningServer` loop is its only consumer.
    let (registration_tx, registration_rx) = mpsc::unbounded();
    let shared_state = Arc::new(BridgeState { registration_tx });

    let router = Router::new()
        .route("/", post(handle_post).get(handle_get))
        .with_state(shared_state);

    let serve_http = async move {
        axum::serve(listener, router)
            .await
            .map_err(agent_client_protocol::util::internal_error)
    };
    let pump_messages = RunningServer::new().run(channel, registration_rx);

    serve_http.race(pump_messages).await
}
/// State shared (behind an `Arc`) with every axum handler of the bridge.
struct BridgeState {
/// Queue on which handlers hand each incoming HTTP message to the `RunningServer` loop.
registration_tx: mpsc::UnboundedSender<HttpMessage>,
}
/// Message sent from an HTTP handler to the `RunningServer` loop.
///
/// Every variant carries the `http_request_id` generated by the handler for
/// the HTTP request that produced it.
#[derive(Debug)]
// NOTE(review): presumably required because some fields (e.g. `http_request_id`
// in most variants) are only ever observed through the derived `Debug`
// output — confirm before removing.
#[allow(dead_code)]
enum HttpMessage {
/// A POSTed JSON-RPC request that carries an id; messages for the caller
/// are delivered through `response_tx` and streamed back as SSE events.
Request {
http_request_id: uuid::Uuid,
request: agent_client_protocol::jsonrpcmsg::Request,
response_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>,
},
/// A POSTed JSON-RPC request without an id; it is forwarded and no reply
/// stream is registered for it.
Notification {
http_request_id: uuid::Uuid,
request: agent_client_protocol::jsonrpcmsg::Request,
},
/// A POSTed JSON-RPC response, forwarded to the channel as-is.
Response {
http_request_id: uuid::Uuid,
response: agent_client_protocol::jsonrpcmsg::Response,
},
/// A GET request that opens an SSE stream; messages sent on `response_tx`
/// are emitted on that stream.
Get {
http_request_id: uuid::Uuid,
response_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>,
},
}
/// Local mirror of `agent_client_protocol::jsonrpcmsg::Id` that derives
/// `Eq`/`Hash`, so it can be used as the key of the waiting-session map.
// NOTE(review): presumably exists because the upstream `Id` type does not
// implement `Hash`/`Eq` itself — confirm.
#[derive(Eq, PartialEq, PartialOrd, Ord, Hash, Debug, Clone)]
enum JsonRpcId {
/// String-valued id.
String(String),
/// Numeric id.
Number(u64),
/// Explicit `null` id.
Null,
}
impl From<agent_client_protocol::jsonrpcmsg::Id> for JsonRpcId {
    /// Maps each upstream id variant onto the local variant of the same name.
    fn from(id: agent_client_protocol::jsonrpcmsg::Id) -> Self {
        use agent_client_protocol::jsonrpcmsg::Id;

        match id {
            Id::Null => Self::Null,
            Id::Number(number) => Self::Number(number),
            Id::String(text) => Self::String(text),
        }
    }
}
/// Bookkeeping for the loop that routes messages between HTTP sessions and the channel.
struct RunningServer {
/// Sessions opened by POSTed requests, keyed by the JSON-RPC id of the request they wait on.
waiting_sessions: FxHashMap<JsonRpcId, RegisteredSession>,
/// Sessions not tied to a request id (GET streams, and requests that carried no id).
general_sessions: Vec<RegisteredSession>,
/// Messages received from the channel that have not yet been delivered to a session.
message_deque: VecDeque<agent_client_protocol::jsonrpcmsg::Message>,
}
impl RunningServer {
fn new() -> Self {
RunningServer {
waiting_sessions: HashMap::default(),
general_sessions: Vec::default(),
message_deque: VecDeque::with_capacity(32),
}
}
async fn run(
mut self,
mut channel: Channel,
http_rx: mpsc::UnboundedReceiver<HttpMessage>,
) -> Result<(), agent_client_protocol::Error> {
#[derive(Debug)]
enum MultiplexMessage {
FromHttpToChannel(HttpMessage),
FromChannelToHttp(
Result<agent_client_protocol::jsonrpcmsg::Message, agent_client_protocol::Error>,
),
}
let mut merged_stream = http_rx
.map(MultiplexMessage::FromHttpToChannel)
.merge(channel.rx.map(MultiplexMessage::FromChannelToHttp));
while let Some(message) = merged_stream.next().await {
tracing::trace!(?message, "received message");
match message {
MultiplexMessage::FromHttpToChannel(http_message) => {
self.handle_http_message(http_message, &mut channel.tx)?;
}
MultiplexMessage::FromChannelToHttp(message) => {
let message = message.unwrap_or_else(|err| {
agent_client_protocol::jsonrpcmsg::Message::Response(
agent_client_protocol::jsonrpcmsg::Response::error(
agent_client_protocol::util::into_jsonrpc_error(err),
None,
),
)
});
self.message_deque.push_back(message);
}
}
self.drain_jsonrpc_messages();
}
Ok(())
}
fn handle_http_message(
&mut self,
message: HttpMessage,
channel_tx: &mut mpsc::UnboundedSender<
Result<agent_client_protocol::jsonrpcmsg::Message, agent_client_protocol::Error>,
>,
) -> Result<(), agent_client_protocol::Error> {
match message {
HttpMessage::Request {
http_request_id,
request,
response_tx,
} => {
tracing::debug!(%http_request_id, ?request, "handling request");
let request_id = request.id.clone().map(JsonRpcId::from);
channel_tx
.unbounded_send(Ok(Message::Request(request)))
.map_err(agent_client_protocol::util::internal_error)?;
let session = RegisteredSession::new(response_tx);
if let Some(id) = request_id {
self.waiting_sessions.insert(id, session);
} else {
self.general_sessions.push(session);
}
}
HttpMessage::Notification {
http_request_id: _,
request,
} => {
channel_tx
.unbounded_send(Ok(Message::Request(request)))
.map_err(agent_client_protocol::util::internal_error)?;
}
HttpMessage::Response {
http_request_id: _,
response,
} => {
channel_tx
.unbounded_send(Ok(Message::Response(response)))
.map_err(agent_client_protocol::util::internal_error)?;
}
HttpMessage::Get {
http_request_id: _,
response_tx,
} => {
self.general_sessions
.push(RegisteredSession::new(response_tx));
}
}
self.purge_closed_sessions();
Ok(())
}
fn drain_jsonrpc_messages(&mut self) {
while let Some(message) = self.message_deque.pop_front() {
if let Some(message) = self.try_dispatch_jsonrpc_message(message) {
self.message_deque.push_front(message);
break;
}
}
}
fn try_dispatch_jsonrpc_message(
&mut self,
mut message: agent_client_protocol::jsonrpcmsg::Message,
) -> Option<agent_client_protocol::jsonrpcmsg::Message> {
let message_id = match &message {
Message::Response(response) => response.id.as_ref().map(|v| v.clone().into()),
Message::Request(_) => None,
};
if let Some(ref message_id) = message_id
&& let Some(session) = self.waiting_sessions.remove(message_id)
{
match session.outgoing_tx.unbounded_send(message) {
Ok(()) => return None,
Err(m) => {
assert!(m.is_disconnected());
message = m.into_inner();
}
}
}
self.purge_closed_sessions();
let all_sessions = self
.general_sessions
.iter_mut()
.chain(self.waiting_sessions.values_mut());
for session in all_sessions {
match session.outgoing_tx.unbounded_send(message) {
Ok(()) => return None,
Err(m) => {
assert!(m.is_disconnected());
message = m.into_inner();
}
}
}
Some(message)
}
fn purge_closed_sessions(&mut self) {
self.general_sessions
.retain(|session| !session.outgoing_tx.is_closed());
self.waiting_sessions
.retain(|_, session| !session.outgoing_tx.is_closed());
}
}
/// One open HTTP stream to which JSON-RPC messages can be delivered.
struct RegisteredSession {
// NOTE(review): `id` is not read anywhere in the visible code, hence the
// lint allowance; presumably kept for diagnostics — confirm.
#[allow(dead_code)]
/// Random identifier assigned when the session is created.
id: uuid::Uuid,
/// Sender feeding the SSE stream of the HTTP request that opened this session.
outgoing_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>,
}
impl RegisteredSession {
fn new(outgoing_tx: mpsc::UnboundedSender<agent_client_protocol::jsonrpcmsg::Message>) -> Self {
Self {
id: uuid::Uuid::new_v4(),
outgoing_tx,
}
}
}
async fn handle_post(
State(state): State<Arc<BridgeState>>,
body: String,
) -> Result<Response, HttpError> {
let http_request_id = uuid::Uuid::new_v4();
let message: agent_client_protocol::jsonrpcmsg::Message =
serde_json::from_str(&body).map_err(agent_client_protocol::util::parse_error)?;
match message {
Message::Request(request) if request.id.is_some() => {
let (tx, mut rx) = mpsc::unbounded();
state
.registration_tx
.unbounded_send(HttpMessage::Request {
http_request_id,
request,
response_tx: tx,
})
.map_err(agent_client_protocol::util::internal_error)?;
let stream = async_stream::stream! {
while let Some(message) = rx.next().await {
match axum::response::sse::Event::default().json_data(message) {
Ok(v) => yield Ok(v),
Err(e) => yield Err(HttpError::from(e)),
}
}
};
Ok(Sse::new(stream).into_response())
}
Message::Request(request) => {
state
.registration_tx
.unbounded_send(HttpMessage::Notification {
http_request_id,
request,
})
.map_err(agent_client_protocol::util::internal_error)?;
Ok(StatusCode::ACCEPTED.into_response())
}
Message::Response(response) => {
state
.registration_tx
.unbounded_send(HttpMessage::Response {
http_request_id,
response,
})
.map_err(agent_client_protocol::util::internal_error)?;
Ok(StatusCode::ACCEPTED.into_response())
}
}
}
async fn handle_get(
State(state): State<Arc<BridgeState>>,
) -> Result<Sse<impl Stream<Item = Result<axum::response::sse::Event, HttpError>>>, HttpError> {
let http_request_id = uuid::Uuid::new_v4();
let (tx, mut rx) = mpsc::unbounded();
state
.registration_tx
.unbounded_send(HttpMessage::Get {
http_request_id,
response_tx: tx,
})
.map_err(agent_client_protocol::util::internal_error)?;
let stream = async_stream::stream! {
while let Some(message) = rx.next().await {
match axum::response::sse::Event::default().json_data(message) {
Ok(v) => yield Ok(v),
Err(e) => yield Err(HttpError::from(e)),
}
}
};
Ok(Sse::new(stream))
}