use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use axum::Router;
use axum::body::Body;
use axum::extract::{ConnectInfo, State};
use axum::http::StatusCode;
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use bytes::Bytes;
use dashmap::DashMap;
use tokio::net::TcpListener;
use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use super::{
ConnectionContext, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RawResponseWriter, Result,
Transport, TransportType,
};
use crate::error::TransportError;
/// Configuration consumed by [`HttpTransport::bind`].
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Address handed to `TcpListener::bind` (e.g. `"127.0.0.1:0"`).
    /// [`parse_bind_addr`] can normalize shorthand such as `":8080"` or `"8080"`.
    pub bind_addr: String,
    /// Maximum accepted POST body size in bytes. It is installed as the router's
    /// `DefaultBodyLimit` and re-checked inside the POST handler.
    pub max_message_size: usize,
}
/// A parsed JSON-RPC message passed from an HTTP handler to the transport consumer.
struct IncomingRequest {
    message: JsonRpcMessage,
    /// Sender feeding the streamed HTTP response body. It is `None` for
    /// notifications, which the handler acknowledges with `202 Accepted`
    /// without a body.
    response_tx: Option<mpsc::Sender<std::result::Result<Bytes, io::Error>>>,
    /// Identifier taken from `HttpSharedState::next_connection_id`, one per POST.
    connection_id: u64,
    remote_addr: SocketAddr,
    /// Moment the HTTP handler accepted the POST.
    connected_at: Instant,
}
/// Bookkeeping entry kept in the shared `connections` map while a request is served.
/// Entries are removed again when the matching `ConnectionGuard` is dropped.
pub struct ConnectionState {
    pub remote_addr: SocketAddr,
    pub connected_at: Instant,
    /// Initialized to 1 on insert; nothing in this file increments it.
    pub request_count: AtomicU64,
}
/// Upper bound on simultaneously open `GET /sse` / `GET /mcp` event streams.
const MAX_SSE_CONNECTIONS: usize = 16;
/// State shared between the axum handlers and the [`HttpTransport`] consumer.
struct HttpSharedState {
    /// Producer side of the queue drained by `receive_request` / `receive_message`.
    incoming_tx: mpsc::Sender<IncomingRequest>,
    /// Fan-out channel for server-initiated messages delivered to SSE
    /// subscribers when no per-request response stream is active.
    sse_tx: broadcast::Sender<String>,
    /// Live request bookkeeping keyed by connection id.
    connections: Arc<DashMap<u64, ConnectionState>>,
    /// Source of connection ids. It starts at 1, so ids never collide with the
    /// placeholder id 0 used before the first request.
    next_connection_id: AtomicU64,
    max_message_size: usize,
    /// Number of currently open SSE streams, capped at `MAX_SSE_CONNECTIONS`.
    sse_connections: AtomicUsize,
    cancel: CancellationToken,
    /// Value returned in the `mcp-session-id` header for `initialize` requests.
    session_id: String,
    /// Waiters for client replies to server-initiated requests, keyed by the
    /// JSON-serialized request id.
    pending_server_requests: tokio::sync::Mutex<
        std::collections::HashMap<String, tokio::sync::oneshot::Sender<JsonRpcMessage>>,
    >,
}
/// Scope guard that deletes one entry from the shared `connections` map when dropped.
///
/// The guard never inserts the entry itself; callers insert first, then wrap.
struct ConnectionGuard {
    connections: Arc<DashMap<u64, ConnectionState>>,
    connection_id: u64,
}
impl ConnectionGuard {
    /// Wraps `connection_id` so that its map entry disappears with this guard.
    const fn new(connections: Arc<DashMap<u64, ConnectionState>>, connection_id: u64) -> Self {
        Self {
            connection_id,
            connections,
        }
    }
}
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let id = self.connection_id;
        // Removing an id that is already absent is a harmless no-op.
        let _removed = self.connections.remove(&id);
    }
}
/// HTTP transport. An axum server parses POSTed JSON-RPC messages and queues them
/// for a single consumer; replies stream back as `text/event-stream` bodies.
pub struct HttpTransport {
    shared: Arc<HttpSharedState>,
    /// Consumer side of the incoming queue; the async mutex serializes receivers.
    incoming_rx: tokio::sync::Mutex<mpsc::Receiver<IncomingRequest>>,
    /// Response sender of the request most recently returned by `receive_message`.
    current_response:
        tokio::sync::Mutex<Option<mpsc::Sender<std::result::Result<Bytes, io::Error>>>>,
    /// Context of the request most recently returned by `receive_message`
    /// (connection id 0 before the first request).
    current_context: std::sync::Mutex<ConnectionContext>,
    /// Guard keeping the current request's `connections` entry alive.
    current_guard: std::sync::Mutex<Option<ConnectionGuard>>,
    /// Guards of earlier requests; `receive_message` retains at most two.
    previous_guards: std::sync::Mutex<Vec<ConnectionGuard>>,
    /// Handle of the spawned axum server task; it is never awaited in this file.
    _server_handle: JoinHandle<()>,
}
impl HttpTransport {
    /// Binds a TCP listener at `config.bind_addr` and spawns the axum server on it.
    ///
    /// Returns the transport together with the actually bound address (useful
    /// when binding port 0). The server task runs until `cancel` is cancelled,
    /// then shuts down through axum's graceful-shutdown hook.
    ///
    /// # Errors
    ///
    /// [`TransportError::ConnectionFailed`] if binding fails or the local address
    /// cannot be read back.
    pub async fn bind(config: HttpConfig, cancel: CancellationToken) -> Result<(Self, SocketAddr)> {
        let (incoming_tx, incoming_rx) = mpsc::channel::<IncomingRequest>(32);
        let (sse_tx, _) = broadcast::channel::<String>(256);
        let listener = TcpListener::bind(&config.bind_addr)
            .await
            .map_err(|e| TransportError::ConnectionFailed(format!("bind failed: {e}")))?;
        let bound_addr = listener
            .local_addr()
            .map_err(|e| TransportError::ConnectionFailed(format!("local_addr failed: {e}")))?;
        let shared = Arc::new(HttpSharedState {
            incoming_tx,
            sse_tx,
            connections: Arc::new(DashMap::new()),
            next_connection_id: AtomicU64::new(1),
            max_message_size: config.max_message_size,
            sse_connections: AtomicUsize::new(0),
            cancel: cancel.clone(),
            session_id: uuid::Uuid::new_v4().to_string(),
            pending_server_requests: tokio::sync::Mutex::new(std::collections::HashMap::new()),
        });
        let router = build_router(Arc::clone(&shared));
        let service = router.into_make_service_with_connect_info::<SocketAddr>();
        let server_cancel = cancel.clone();
        let server_handle = tokio::spawn(async move {
            info!(%bound_addr, "HTTP transport started");
            axum::serve(listener, service)
                .with_graceful_shutdown(async move {
                    server_cancel.cancelled().await;
                })
                .await
                .ok();
            debug!("HTTP transport shut down");
        });
        let transport = Self {
            shared,
            incoming_rx: tokio::sync::Mutex::new(incoming_rx),
            current_response: tokio::sync::Mutex::new(None),
            current_context: std::sync::Mutex::new(ConnectionContext {
                connection_id: 0,
                remote_addr: None,
                is_exclusive: false,
                connected_at: Instant::now(),
            }),
            current_guard: std::sync::Mutex::new(None),
            previous_guards: std::sync::Mutex::new(Vec::new()),
            _server_handle: server_handle,
        };
        Ok((transport, bound_addr))
    }

    /// Cancels the shared token, which triggers graceful shutdown of the server task.
    pub fn shutdown(&self) {
        self.shared.cancel.cancel();
    }

    /// Waits for the next POSTed message and returns it with a handle owning its reply stream.
    ///
    /// The request is recorded in `connections` for as long as the returned
    /// [`ResponseHandle`] lives. Returns `None` once the incoming queue is closed.
    pub async fn receive_request(&self) -> Option<(JsonRpcMessage, ResponseHandle)> {
        let mut rx = self.incoming_rx.lock().await;
        let incoming = rx.recv().await?;
        drop(rx);
        self.shared.connections.insert(
            incoming.connection_id,
            ConnectionState {
                remote_addr: incoming.remote_addr,
                connected_at: incoming.connected_at,
                request_count: AtomicU64::new(1),
            },
        );
        let context = ConnectionContext {
            connection_id: incoming.connection_id,
            remote_addr: Some(incoming.remote_addr),
            is_exclusive: false,
            connected_at: incoming.connected_at,
        };
        let guard =
            ConnectionGuard::new(Arc::clone(&self.shared.connections), incoming.connection_id);
        let handle = ResponseHandle {
            response_tx: incoming.response_tx,
            context,
            _guard: guard,
        };
        Some((incoming.message, handle))
    }

    /// Sends a server-initiated request to the client and waits up to 30 s for its reply.
    ///
    /// If a per-request response stream is active, the request is written to it as an
    /// SSE `message` event. Otherwise it is broadcast to SSE subscribers. The client's
    /// reply arrives as a POSTed JSON-RPC response, which `handle_post_message` routes
    /// back through `pending_server_requests`.
    ///
    /// # Errors
    ///
    /// * Serialization errors from `serde_json`.
    /// * `ConnectionClosed` if the response stream is closed or, on the broadcast
    ///   path, no SSE client is subscribed (the request could never be answered).
    /// * `ConnectionClosed` if the waiter is dropped without a reply.
    /// * `InternalError` on timeout or if the reply is not a JSON-RPC response.
    pub async fn send_server_request(&self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
        const RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

        // Serialize before registering the waiter so a serialization failure
        // cannot leave a stale entry in `pending_server_requests`.
        let serialized = serde_json::to_string(&JsonRpcMessage::Request(request.clone()))?;
        let request_id =
            serde_json::to_string(&request.id).unwrap_or_else(|_| request.id.to_string());
        let (tx, rx) = tokio::sync::oneshot::channel();
        self.shared
            .pending_server_requests
            .lock()
            .await
            .insert(request_id.clone(), tx);

        let response_tx = {
            let guard = self.current_response.lock().await;
            guard.as_ref().cloned()
        };
        let delivered = match response_tx {
            Some(ch) => {
                let sse_data = format!("event: message\ndata: {serialized}\n\n");
                ch.send(Ok(Bytes::from(sse_data)))
                    .await
                    .map_err(|_| TransportError::ConnectionClosed("response channel closed".into()))
            }
            // `broadcast::Sender::send` fails only when nobody is subscribed. The
            // request is then lost and no reply can arrive, so fail now instead of
            // waiting out the full timeout.
            None => self.shared.sse_tx.send(serialized).map(|_| ()).map_err(|_| {
                TransportError::ConnectionClosed(
                    "no SSE subscribers to deliver server request".into(),
                )
            }),
        };
        if let Err(e) = delivered {
            self.forget_pending(&request_id).await;
            return Err(e);
        }

        match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
            Ok(Ok(JsonRpcMessage::Response(resp))) => Ok(resp),
            Ok(Ok(_)) => Err(TransportError::InternalError(
                "unexpected non-response message for server request".into(),
            )),
            Ok(Err(_)) => Err(TransportError::ConnectionClosed(
                "server request response channel dropped".into(),
            )),
            Err(_) => {
                self.forget_pending(&request_id).await;
                Err(TransportError::InternalError(
                    "server request timed out after 30s".into(),
                ))
            }
        }
    }

    /// Removes the waiter registered for `request_id`, if it is still present.
    async fn forget_pending(&self, request_id: &str) {
        self.shared
            .pending_server_requests
            .lock()
            .await
            .remove(request_id);
    }
}
/// Owns the reply stream of one POSTed request, plus the guard that keeps its
/// `connections` entry alive until the handle is dropped.
pub struct ResponseHandle {
    /// Sender feeding the HTTP response body; `None` when the request was a notification.
    response_tx: Option<mpsc::Sender<std::result::Result<Bytes, io::Error>>>,
    context: ConnectionContext,
    _guard: ConnectionGuard,
}
impl ResponseHandle {
    /// Appends `bytes` verbatim to the response body, waiting for channel capacity.
    ///
    /// # Errors
    ///
    /// `ConnectionClosed` when there is no response sender or the body receiver is gone.
    pub async fn send_raw(&self, bytes: &[u8]) -> Result<()> {
        let Some(sender) = self.response_tx.as_ref() else {
            return Err(TransportError::ConnectionClosed(
                "response already finalized".into(),
            ));
        };
        let chunk = Bytes::copy_from_slice(bytes);
        match sender.send(Ok(chunk)).await {
            Ok(()) => Ok(()),
            Err(_) => Err(TransportError::ConnectionClosed(
                "response channel closed".into(),
            )),
        }
    }

    /// Serializes `message` and writes it as an SSE `message` event.
    ///
    /// # Errors
    ///
    /// Serialization errors, or any error from [`ResponseHandle::send_raw`].
    pub async fn send_message(&self, message: &JsonRpcMessage) -> Result<()> {
        let frame = format!(
            "event: message\ndata: {}\n\n",
            serde_json::to_string(message)?
        );
        self.send_raw(frame.as_bytes()).await
    }

    /// Consumes the handle. It releases the response sender first, then the
    /// connection guard as `self` goes out of scope.
    pub fn finalize(mut self) {
        self.response_tx = None;
    }

    /// Connection metadata captured when the request was received.
    #[must_use]
    pub const fn connection_context(&self) -> &ConnectionContext {
        &self.context
    }
}
pub struct ResponseHandleAdapter {
handle: ResponseHandle,
}
impl ResponseHandleAdapter {
#[must_use]
pub const fn new(handle: ResponseHandle) -> Self {
Self { handle }
}
pub fn finalize(self) {
self.handle.finalize();
}
#[must_use]
pub const fn connection_context(&self) -> &ConnectionContext {
self.handle.connection_context()
}
}
/// `Transport` view of a single request's reply stream. It is write-only:
/// inbound traffic is never delivered through per-request handles.
#[async_trait::async_trait]
impl Transport for ResponseHandleAdapter {
    async fn send_message(&self, message: &JsonRpcMessage) -> Result<()> {
        let handle = &self.handle;
        handle.send_message(message).await
    }

    async fn send_raw(&self, bytes: &[u8]) -> Result<()> {
        let handle = &self.handle;
        handle.send_raw(bytes).await
    }

    /// Always fails: a reply handle has no inbound side.
    async fn receive_message(&self) -> Result<Option<JsonRpcMessage>> {
        let reason = "ResponseHandleAdapter does not support receive_message";
        Err(TransportError::InternalError(reason.into()))
    }

    fn transport_type(&self) -> TransportType {
        TransportType::Http
    }

    /// No-op. Closing the stream requires taking the sender out of the handle,
    /// which `&self` cannot do. Call [`ResponseHandleAdapter::finalize`] or drop
    /// the adapter to end the response.
    async fn finalize_response(&self) -> Result<()> {
        Ok(())
    }

    fn connection_context(&self) -> ConnectionContext {
        self.handle.connection_context().clone()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Debug output shows only the number of tracked connections.
impl std::fmt::Debug for HttpTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let live_connections = self.shared.connections.len();
        let mut out = f.debug_struct("HttpTransport");
        out.field("connections", &live_connections);
        out.finish_non_exhaustive()
    }
}
#[async_trait::async_trait]
impl Transport for HttpTransport {
    /// Waits for the next POSTed message and makes it the "current" request.
    ///
    /// Stores the request's response sender, connection context and a
    /// `ConnectionGuard`, so that `send_message`, `send_raw`, `capture_raw_writer`,
    /// `connection_context` and `finalize_response` act on it. The guard of the
    /// previous request moves to `previous_guards`, which keeps at most two.
    /// Returns `Ok(None)` once the incoming queue is closed.
    ///
    /// # Errors
    ///
    /// `InternalError` if one of the internal std mutexes is poisoned.
    async fn receive_message(&self) -> Result<Option<JsonRpcMessage>> {
        const MAX_PREVIOUS_GUARDS: usize = 2;

        let incoming = {
            let mut rx = self.incoming_rx.lock().await;
            rx.recv().await
        };
        let Some(req) = incoming else {
            return Ok(None);
        };
        {
            let mut current = self.current_response.lock().await;
            (*current).clone_from(&req.response_tx);
        }
        self.shared.connections.insert(
            req.connection_id,
            ConnectionState {
                remote_addr: req.remote_addr,
                connected_at: req.connected_at,
                request_count: AtomicU64::new(1),
            },
        );
        // Create the guard right away. If a poisoned mutex makes us bail out
        // below, dropping it still removes the entry inserted above.
        let guard = ConnectionGuard::new(Arc::clone(&self.shared.connections), req.connection_id);
        {
            let mut ctx = self
                .current_context
                .lock()
                .map_err(|_| TransportError::InternalError("context mutex poisoned".into()))?;
            *ctx = ConnectionContext {
                connection_id: req.connection_id,
                remote_addr: Some(req.remote_addr),
                is_exclusive: false,
                connected_at: req.connected_at,
            };
        }
        let replaced = self
            .current_guard
            .lock()
            .map_err(|_| TransportError::InternalError("guard mutex poisoned".into()))?
            .replace(guard);
        if let Some(old) = replaced {
            let mut prev = self.previous_guards.lock().map_err(|_| {
                TransportError::InternalError("previous_guards mutex poisoned".into())
            })?;
            prev.push(old);
            // Dropping the oldest retained guards removes their `connections` entries.
            let excess = prev.len().saturating_sub(MAX_PREVIOUS_GUARDS);
            prev.drain(..excess).for_each(drop);
        }
        Ok(Some(req.message))
    }

    /// Sends `message` on the current response stream as an SSE `message` event.
    ///
    /// Without a current stream, notifications and requests are broadcast to SSE
    /// subscribers on a best-effort basis, and responses are rejected.
    ///
    /// # Errors
    ///
    /// Serialization errors; `ConnectionClosed` if the stream is closed or a
    /// response has no stream to go to.
    async fn send_message(&self, message: &JsonRpcMessage) -> Result<()> {
        let serialized = serde_json::to_string(message)?;
        let tx = {
            let guard = self.current_response.lock().await;
            guard.as_ref().cloned()
        };
        if let Some(tx) = tx {
            let sse_data = format!("event: message\ndata: {serialized}\n\n");
            tx.send(Ok(Bytes::from(sse_data)))
                .await
                .map_err(|_| TransportError::ConnectionClosed("response channel closed".into()))?;
        } else {
            match message {
                JsonRpcMessage::Notification(_) | JsonRpcMessage::Request(_) => {
                    // Best effort: an error here only means no SSE client is subscribed.
                    let _ = self.shared.sse_tx.send(serialized);
                }
                JsonRpcMessage::Response(_) => {
                    return Err(TransportError::ConnectionClosed(
                        "no active response channel (send_message called before receive_message)"
                            .into(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Writes `bytes` verbatim to the current response stream.
    ///
    /// # Errors
    ///
    /// `ConnectionClosed` if there is no current stream or it has been closed.
    async fn send_raw(&self, bytes: &[u8]) -> Result<()> {
        let tx = {
            let guard = self.current_response.lock().await;
            guard.as_ref().cloned()
        };
        let Some(tx) = tx else {
            return Err(TransportError::ConnectionClosed(
                "no active response channel (send_raw called before receive_message)".into(),
            ));
        };
        tx.send(Ok(Bytes::copy_from_slice(bytes)))
            .await
            .map_err(|_| TransportError::ConnectionClosed("response channel closed".into()))?;
        Ok(())
    }

    fn transport_type(&self) -> TransportType {
        TransportType::Http
    }

    /// Releases the current response sender and the current connection guard.
    ///
    /// The HTTP body ends once no other clones of the sender remain (for example,
    /// ones handed out by `capture_raw_writer`).
    ///
    /// # Errors
    ///
    /// `InternalError` if the guard mutex is poisoned.
    async fn finalize_response(&self) -> Result<()> {
        let sender = {
            let mut guard = self.current_response.lock().await;
            guard.take()
        };
        drop(sender);
        let guard = {
            let mut g = self
                .current_guard
                .lock()
                .map_err(|_| TransportError::InternalError("guard mutex poisoned".into()))?;
            g.take()
        };
        drop(guard);
        Ok(())
    }

    /// Returns a writer sharing the current response sender, if a request is active.
    async fn capture_raw_writer(&self) -> Result<Option<RawResponseWriter>> {
        let guard = self.current_response.lock().await;
        Ok(guard.as_ref().map(|tx| RawResponseWriter::new(tx.clone())))
    }

    /// Context of the current request. If the mutex is poisoned, the last stored
    /// value is returned after logging an error.
    fn connection_context(&self) -> ConnectionContext {
        self.current_context
            .lock()
            .unwrap_or_else(|e| {
                tracing::error!("context mutex poisoned, using default");
                e.into_inner()
            })
            .clone()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Assembles the axum router.
///
/// * Legacy SSE flavour: `GET /sse` opens the stream, `POST /message` carries requests.
/// * Streamable flavour: `GET /mcp` and `POST /mcp`.
///
/// All routes share a body-size limit taken from `shared.max_message_size`.
fn build_router(shared: Arc<HttpSharedState>) -> Router {
    let size_limit = axum::extract::DefaultBodyLimit::max(shared.max_message_size);
    let legacy = Router::new()
        .route("/sse", get(handle_sse))
        .route("/message", post(handle_post_message));
    legacy
        .route("/mcp", get(handle_sse_streamable).post(handle_post_message))
        .layer(size_limit)
        .with_state(shared)
}
async fn handle_post_message(
State(shared): State<Arc<HttpSharedState>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
_headers: axum::http::HeaderMap,
body: axum::body::Bytes,
) -> Response {
if body.is_empty() {
return (StatusCode::BAD_REQUEST, "empty request body").into_response();
}
if body.len() > shared.max_message_size {
return (
StatusCode::PAYLOAD_TOO_LARGE,
format!(
"message too large: {} bytes (limit: {})",
body.len(),
shared.max_message_size
),
)
.into_response();
}
let message: JsonRpcMessage = match serde_json::from_slice(&body) {
Ok(msg) => msg,
Err(e) => {
return (StatusCode::BAD_REQUEST, format!("invalid JSON-RPC: {e}")).into_response();
}
};
let is_initialize =
matches!(&message, JsonRpcMessage::Request(req) if req.method == "initialize");
match &message {
JsonRpcMessage::Response(resp) => {
let key = serde_json::to_string(&resp.id).unwrap_or_else(|_| resp.id.to_string());
let sender = {
let mut pending = shared.pending_server_requests.lock().await;
pending.remove(&key)
};
if let Some(tx) = sender {
let _ = tx.send(message);
} else {
tracing::debug!(id = ?resp.id, "no pending server request for response");
}
return StatusCode::ACCEPTED.into_response();
}
JsonRpcMessage::Notification(_) => {
let connection_id = shared.next_connection_id.fetch_add(1, Ordering::SeqCst);
let incoming = IncomingRequest {
message,
response_tx: None,
connection_id,
remote_addr: addr,
connected_at: Instant::now(),
};
if shared.incoming_tx.send(incoming).await.is_err() {
return (StatusCode::SERVICE_UNAVAILABLE, "server shutting down").into_response();
}
return StatusCode::ACCEPTED.into_response();
}
JsonRpcMessage::Request(_) => {
}
}
let connection_id = shared.next_connection_id.fetch_add(1, Ordering::SeqCst);
let connected_at = Instant::now();
let (response_tx, response_rx) = mpsc::channel::<std::result::Result<Bytes, io::Error>>(64);
let incoming = IncomingRequest {
message,
response_tx: Some(response_tx),
connection_id,
remote_addr: addr,
connected_at,
};
if shared.incoming_tx.send(incoming).await.is_err() {
return (StatusCode::SERVICE_UNAVAILABLE, "server shutting down").into_response();
}
let stream = ReceiverStream::new(response_rx);
let body = Body::from_stream(stream);
let mut builder = Response::builder().header("content-type", "text/event-stream");
if is_initialize {
builder = builder.header("mcp-session-id", &shared.session_id);
}
builder.body(body).unwrap_or_else(|e| {
tracing::error!(error = %e, "failed to build HTTP response");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})
}
/// `GET /sse`: legacy SSE stream whose `endpoint` event points clients at `POST /message`.
async fn handle_sse(State(shared): State<Arc<HttpSharedState>>) -> Response {
    const POST_ENDPOINT: &str = "/message";
    handle_sse_inner(&shared, POST_ENDPOINT)
}
/// `GET /mcp`: streamable-HTTP SSE stream whose `endpoint` event points at `/mcp` itself.
async fn handle_sse_streamable(State(shared): State<Arc<HttpSharedState>>) -> Response {
    const POST_ENDPOINT: &str = "/mcp";
    handle_sse_inner(&shared, POST_ENDPOINT)
}
/// Opens an SSE stream carrying broadcast server messages.
///
/// The stream first emits an `endpoint` event naming `endpoint_path`, then one
/// `message` event per broadcast. Lagged subscribers log a warning and skip the
/// missed messages. At most `MAX_SSE_CONNECTIONS` streams may be open at once;
/// beyond that the handler returns `503`. The stream ends when the transport's
/// cancellation token fires, even if no further messages arrive.
fn handle_sse_inner(shared: &Arc<HttpSharedState>, endpoint_path: &'static str) -> Response {
    let current = shared.sse_connections.fetch_add(1, Ordering::SeqCst);
    if current >= MAX_SSE_CONNECTIONS {
        shared.sse_connections.fetch_sub(1, Ordering::SeqCst);
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            format!("too many SSE connections (limit: {MAX_SSE_CONNECTIONS})"),
        )
            .into_response();
    }
    let rx = shared.sse_tx.subscribe();
    let cancel = shared.cancel.clone();
    let endpoint_event: std::result::Result<SseEvent, std::convert::Infallible> =
        Ok(SseEvent::default().event("endpoint").data(endpoint_path));
    let broadcast_stream = tokio_stream::wrappers::BroadcastStream::new(rx)
        // Stops a busy stream on the first message seen after cancellation.
        .take_while(move |_| !cancel.is_cancelled())
        .filter_map(|result: std::result::Result<String, _>| match result {
            Ok(data) => {
                let event: std::result::Result<SseEvent, std::convert::Infallible> =
                    Ok(SseEvent::default().event("message").data(data));
                Some(event)
            }
            Err(e) => {
                tracing::warn!(error = %e, "SSE subscriber lagged, dropping missed messages");
                None
            }
        });
    let stream = tokio_stream::once(endpoint_event).chain(broadcast_stream);
    let shutdown = shared.cancel.clone();
    let stream = SseCountedStream {
        inner: Box::pin(stream),
        shared: Arc::clone(shared),
        cancelled: Box::pin(async move { shutdown.cancelled().await }),
        done: false,
    };
    Sse::new(stream)
        .keep_alive(KeepAlive::default())
        .into_response()
}

/// SSE body stream that occupies one `sse_connections` slot, released on drop.
/// It terminates as soon as the shared cancellation token fires.
struct SseCountedStream<S> {
    inner: std::pin::Pin<Box<S>>,
    shared: Arc<HttpSharedState>,
    /// Resolves once the shared cancellation token is cancelled.
    cancelled: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
    /// Set after the stream has yielded `None`, so the completed `cancelled`
    /// future is never polled again.
    done: bool,
}
impl<S> tokio_stream::Stream for SseCountedStream<S>
where
    S: tokio_stream::Stream,
{
    type Item = S::Item;
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        // Deliver whatever the inner stream already has (including the initial
        // `endpoint` event) before looking at cancellation.
        let polled = tokio_stream::Stream::poll_next(self.inner.as_mut(), cx);
        match polled {
            Poll::Pending => {}
            Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
            Poll::Ready(None) => {
                self.done = true;
                return Poll::Ready(None);
            }
        }
        // The inner stream is idle. `take_while` only sees cancellation when a
        // message arrives, and the broadcast channel never closes on its own
        // (`shared` keeps `sse_tx` alive). So watch the token here; otherwise an
        // idle subscriber would outlive shutdown.
        if std::future::Future::poll(self.cancelled.as_mut(), cx).is_ready() {
            self.done = true;
            return Poll::Ready(None);
        }
        Poll::Pending
    }
}
impl<S> Drop for SseCountedStream<S> {
    fn drop(&mut self) {
        self.shared.sse_connections.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Normalizes a user-supplied bind address into `host:port` form.
///
/// * `":8080"` becomes `"0.0.0.0:8080"`.
/// * `"8080"` (a bare port) becomes `"0.0.0.0:8080"`.
/// * Anything else is kept as-is.
///
/// # Errors
///
/// [`TransportError::ConnectionFailed`] if the normalized string is not a literal
/// `SocketAddr`; host names such as `localhost:3000` are rejected.
pub fn parse_bind_addr(input: &str) -> std::result::Result<String, TransportError> {
    let normalized = match input {
        s if s.starts_with(':') => format!("0.0.0.0{s}"),
        s if s.parse::<u16>().is_ok() => format!("0.0.0.0:{s}"),
        s => s.to_owned(),
    };
    if let Err(e) = normalized.parse::<SocketAddr>() {
        return Err(TransportError::ConnectionFailed(format!(
            "invalid bind address \"{input}\": {e}"
        )));
    }
    Ok(normalized)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::DEFAULT_MAX_MESSAGE_SIZE;
    use axum::body::Body;
    use axum::extract::connect_info::MockConnectInfo;
    use axum::http::Request;
    use tower::util::ServiceExt;

    /// Body the fake consumer streams back for every request.
    const CANNED_RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{},"id":1}"#;
    /// An `initialize` request, which also triggers the `mcp-session-id` header.
    const INITIALIZE_BODY: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    /// A plain, non-initialize request.
    const TEST_REQUEST_BODY: &str = r#"{"jsonrpc":"2.0","method":"test","params":{},"id":1}"#;

    /// Peer address injected through `MockConnectInfo`.
    fn loopback() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 9999))
    }

    /// Builds shared state around `incoming_tx`; every other field starts fresh.
    fn shared_state(
        incoming_tx: mpsc::Sender<IncomingRequest>,
        max_message_size: usize,
        cancel: CancellationToken,
        session_id: &str,
    ) -> Arc<HttpSharedState> {
        let (sse_tx, _) = broadcast::channel(256);
        Arc::new(HttpSharedState {
            incoming_tx,
            sse_tx,
            connections: Arc::new(DashMap::new()),
            next_connection_id: AtomicU64::new(1),
            max_message_size,
            sse_connections: AtomicUsize::new(0),
            cancel,
            session_id: session_id.to_string(),
            pending_server_requests: tokio::sync::Mutex::new(std::collections::HashMap::new()),
        })
    }

    /// Default shared state whose incoming receiver has already been dropped.
    fn test_shared_state() -> Arc<HttpSharedState> {
        let (incoming_tx, _incoming_rx) = mpsc::channel(32);
        shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            CancellationToken::new(),
            "test-session",
        )
    }

    fn test_router(shared: Arc<HttpSharedState>) -> Router {
        build_router(shared).layer(MockConnectInfo(loopback()))
    }

    /// Fake consumer: answers every queued request with `CANNED_RESPONSE`.
    fn spawn_canned_responder(mut incoming_rx: mpsc::Receiver<IncomingRequest>) {
        tokio::spawn(async move {
            while let Some(incoming) = incoming_rx.recv().await {
                if let Some(tx) = incoming.response_tx {
                    tx.send(Ok(Bytes::from(CANNED_RESPONSE))).await.ok();
                }
            }
        });
    }

    fn post_json(uri: &str, body: &'static str) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .header("host", "localhost:3000")
            .body(Body::from(body))
            .unwrap()
    }

    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .header("host", "localhost:3000")
            .body(Body::empty())
            .unwrap()
    }

    fn content_type(resp: &Response) -> &str {
        resp.headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
    }

    async fn bind_local_transport() -> HttpTransport {
        let config = HttpConfig {
            bind_addr: "127.0.0.1:0".to_string(),
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        };
        let (transport, _addr) = HttpTransport::bind(config, CancellationToken::new())
            .await
            .unwrap();
        transport
    }

    fn tracked_connection() -> ConnectionState {
        ConnectionState {
            remote_addr: loopback(),
            connected_at: Instant::now(),
            request_count: AtomicU64::new(1),
        }
    }

    #[test]
    fn parse_bind_addr_colon_port() {
        assert_eq!(parse_bind_addr(":8080").unwrap(), "0.0.0.0:8080");
    }

    #[test]
    fn parse_bind_addr_port_only() {
        assert_eq!(parse_bind_addr("8080").unwrap(), "0.0.0.0:8080");
    }

    #[test]
    fn parse_bind_addr_full() {
        assert_eq!(parse_bind_addr("1.2.3.4:8080").unwrap(), "1.2.3.4:8080");
    }

    #[test]
    fn parse_bind_addr_localhost() {
        assert_eq!(parse_bind_addr("127.0.0.1:3000").unwrap(), "127.0.0.1:3000");
    }

    #[test]
    fn parse_bind_addr_invalid() {
        assert!(parse_bind_addr("not-an-address").is_err());
    }

    #[tokio::test]
    async fn post_empty_body_returns_400() {
        let app = test_router(test_shared_state());
        let req = Request::builder()
            .method("POST")
            .uri("/message")
            .header("host", "localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn post_invalid_json_returns_400() {
        let app = test_router(test_shared_state());
        let resp = app.oneshot(post_json("/message", "not json")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn post_oversized_body_returns_413() {
        let (incoming_tx, _rx) = mpsc::channel(32);
        let shared = shared_state(incoming_tx, 10, CancellationToken::new(), "test-session");
        let app = test_router(shared);
        let req = post_json("/message", r#"{"jsonrpc":"2.0","method":"test","id":1}"#);
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn post_valid_message_returns_200() {
        let (incoming_tx, incoming_rx) = mpsc::channel(32);
        let shared = shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            CancellationToken::new(),
            "test-session",
        );
        let app = test_router(shared);
        spawn_canned_responder(incoming_rx);
        let resp = app
            .oneshot(post_json("/message", INITIALIZE_BODY))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn sse_endpoint_returns_200() {
        let app = test_router(test_shared_state());
        let resp = app.oneshot(get_request("/sse")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn connection_tracking() {
        let transport = bind_local_transport().await;
        assert_eq!(transport.shared.connections.len(), 0);
        transport.finalize_response().await.unwrap();
        assert_eq!(transport.shared.connections.len(), 0);
        transport.shutdown();
    }

    #[tokio::test]
    async fn transport_type_is_http() {
        let transport = bind_local_transport().await;
        assert_eq!(transport.transport_type(), TransportType::Http);
        transport.shutdown();
    }

    #[tokio::test]
    async fn debug_format() {
        let transport = bind_local_transport().await;
        let debug = format!("{transport:?}");
        assert!(debug.contains("HttpTransport"));
        transport.shutdown();
    }

    #[tokio::test]
    async fn default_connection_context_is_stdio() {
        let transport = bind_local_transport().await;
        let ctx = transport.connection_context();
        assert_eq!(ctx.connection_id, 0);
        transport.shutdown();
    }

    #[tokio::test]
    async fn response_handle_send_and_finalize() {
        let (response_tx, mut response_rx) = mpsc::channel(64);
        let connections: Arc<DashMap<u64, ConnectionState>> = Arc::new(DashMap::new());
        connections.insert(42, tracked_connection());
        let handle = ResponseHandle {
            response_tx: Some(response_tx),
            context: ConnectionContext {
                connection_id: 42,
                remote_addr: Some(loopback()),
                is_exclusive: false,
                connected_at: Instant::now(),
            },
            _guard: ConnectionGuard::new(Arc::clone(&connections), 42),
        };
        handle.send_raw(b"hello").await.unwrap();
        let received = response_rx.recv().await.unwrap().unwrap();
        assert_eq!(&received[..], b"hello");
        handle.finalize();
        assert!(response_rx.recv().await.is_none());
        assert_eq!(connections.len(), 0);
    }

    #[tokio::test]
    async fn response_handle_adapter_implements_transport() {
        let (response_tx, mut response_rx) = mpsc::channel(64);
        let connections: Arc<DashMap<u64, ConnectionState>> = Arc::new(DashMap::new());
        let handle = ResponseHandle {
            response_tx: Some(response_tx),
            context: ConnectionContext {
                connection_id: 1,
                remote_addr: None,
                is_exclusive: false,
                connected_at: Instant::now(),
            },
            _guard: ConnectionGuard::new(connections, 1),
        };
        let adapter = ResponseHandleAdapter::new(handle);
        assert_eq!(adapter.transport_type(), TransportType::Http);
        adapter.send_raw(b"raw bytes").await.unwrap();
        let received = response_rx.recv().await.unwrap().unwrap();
        assert_eq!(&received[..], b"raw bytes");
        assert!(adapter.receive_message().await.is_err());
        adapter.finalize();
    }

    #[tokio::test]
    async fn sse_stream_content_type() {
        let app = test_router(test_shared_state());
        let resp = app.oneshot(get_request("/sse")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let ct = content_type(&resp);
        assert!(
            ct.contains("text/event-stream"),
            "Expected text/event-stream, got: {ct}"
        );
    }

    #[tokio::test]
    async fn concurrent_posts_all_succeed() {
        let (incoming_tx, incoming_rx) = mpsc::channel(32);
        let shared = shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            CancellationToken::new(),
            "test-session",
        );
        let app = test_router(shared);
        spawn_canned_responder(incoming_rx);
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let app = app.clone();
                let req = post_json("/message", TEST_REQUEST_BODY);
                tokio::spawn(async move { app.oneshot(req).await.unwrap().status() })
            })
            .collect();
        for handle in handles {
            assert_eq!(handle.await.unwrap(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn session_id_header_on_initialize() {
        let (incoming_tx, incoming_rx) = mpsc::channel(32);
        let shared = shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            CancellationToken::new(),
            "test-session-42",
        );
        let app = test_router(shared);
        spawn_canned_responder(incoming_rx);
        let resp = app
            .oneshot(post_json("/message", INITIALIZE_BODY))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let session_id = resp
            .headers()
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok());
        assert_eq!(
            session_id,
            Some("test-session-42"),
            "Expected mcp-session-id header"
        );
    }

    #[tokio::test]
    async fn post_after_cancel_rejected() {
        let cancel = CancellationToken::new();
        let (incoming_tx, _rx) = mpsc::channel(32);
        let shared = shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            cancel.clone(),
            "test-session",
        );
        let app = test_router(shared);
        cancel.cancel();
        let resp = app
            .oneshot(post_json("/message", TEST_REQUEST_BODY))
            .await
            .unwrap();
        // NOTE: this only checks that the handler still produces a well-formed
        // status after cancellation; it does not pin which one.
        let status = resp.status();
        assert!(
            status.is_client_error() || status.is_server_error() || status.is_success(),
            "Expected a valid HTTP status, got: {status}"
        );
    }

    #[test]
    fn connection_guard_cleanup_on_drop() {
        let connections: Arc<DashMap<u64, ConnectionState>> = Arc::new(DashMap::new());
        connections.insert(42, tracked_connection());
        assert_eq!(connections.len(), 1);
        {
            let _guard = ConnectionGuard::new(Arc::clone(&connections), 42);
        }
        assert_eq!(connections.len(), 0);
    }

    #[tokio::test]
    async fn mcp_post_returns_sse_response() {
        let (incoming_tx, incoming_rx) = mpsc::channel(32);
        let shared = shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            CancellationToken::new(),
            "test-session",
        );
        let app = test_router(shared);
        spawn_canned_responder(incoming_rx);
        let resp = app.oneshot(post_json("/mcp", INITIALIZE_BODY)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let ct = content_type(&resp);
        assert!(
            ct.contains("text/event-stream"),
            "Expected text/event-stream on POST /mcp, got: {ct}"
        );
    }

    #[tokio::test]
    async fn mcp_get_returns_200() {
        let app = test_router(test_shared_state());
        let resp = app.oneshot(get_request("/mcp")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// Opens the SSE stream at `uri`, ends it, and checks the leading `endpoint` event.
    async fn assert_endpoint_event(uri: &str, expected_data: &str) {
        let cancel = CancellationToken::new();
        let (incoming_tx, _incoming_rx) = mpsc::channel(32);
        let shared = shared_state(
            incoming_tx,
            DEFAULT_MAX_MESSAGE_SIZE,
            cancel.clone(),
            "test-session",
        );
        let app = test_router(Arc::clone(&shared));
        let resp = app.oneshot(get_request(uri)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        // Cancel and publish one message so the stream terminates even if it only
        // observes cancellation on message arrival.
        cancel.cancel();
        let _ = shared.sse_tx.send(String::new());
        let body = axum::body::to_bytes(resp.into_body(), 64 * 1024)
            .await
            .unwrap();
        let text = String::from_utf8_lossy(&body);
        let expected = format!("event: endpoint\ndata: {expected_data}");
        assert!(
            text.contains(&expected),
            "Expected '{expected}' in SSE body, got:\n{text}"
        );
    }

    #[tokio::test]
    async fn mcp_get_endpoint_event_points_to_mcp() {
        assert_endpoint_event("/mcp", "/mcp").await;
    }

    #[tokio::test]
    async fn sse_get_endpoint_event_points_to_message() {
        assert_endpoint_event("/sse", "/message").await;
    }

    #[tokio::test]
    async fn sse_connection_counter_tracks() {
        let shared = test_shared_state();
        assert_eq!(shared.sse_connections.load(Ordering::SeqCst), 0);
        let app = test_router(Arc::clone(&shared));
        let resp = app.oneshot(get_request("/sse")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let count = shared.sse_connections.load(Ordering::SeqCst);
        assert!(count <= 1, "Unexpected SSE connection count: {count}");
    }
}