use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
use super::protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
use super::transport::McpTransport;
/// Name of the HTTP header carrying the MCP session identifier; it is sent on
/// requests when a session id is known and read back from responses.
const SESSION_ID_HEADER: &str = "Mcp-Session-Id";
/// Name of the HTTP request header carrying the protocol version, sent once a
/// version has been set on the transport.
const PROTOCOL_VERSION_HEADER: &str = "MCP-Protocol-Version";
/// Transport-neutral view of an HTTP response produced by an `HttpPoster`.
#[derive(Clone, Debug)]
pub struct HttpReply {
    /// Media type of the body, e.g. `application/json` or `text/event-stream`.
    pub content_type: String,
    /// Response body as text.
    pub body: String,
    /// Session identifier reported alongside the response, if any.
    pub session_id: Option<String>,
}

impl HttpReply {
    /// Shared constructor: a reply with the given media type and no session id.
    fn from_parts(content_type: &str, body: String) -> Self {
        Self {
            content_type: content_type.to_owned(),
            body,
            session_id: None,
        }
    }

    /// Builds a reply whose body is tagged as `application/json`.
    #[must_use]
    pub fn json(body: impl Into<String>) -> Self {
        Self::from_parts("application/json", body.into())
    }

    /// Builds a reply whose body is tagged as `text/event-stream`.
    #[must_use]
    pub fn event_stream(body: impl Into<String>) -> Self {
        Self::from_parts("text/event-stream", body.into())
    }

    /// Returns the reply with its session id set to `session_id`, replacing
    /// any previous value.
    #[must_use]
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }
}
/// Everything an [`HttpPoster`] needs to issue one MCP HTTP POST.
///
/// `Debug` is implemented by hand so that the `authorization` value (which
/// normally carries a bearer token) never ends up in logs; all other fields
/// are printed as-is.
#[derive(Clone)]
pub struct HttpRequest {
    /// Serialized JSON-RPC payload to send as the request body.
    pub body: String,
    /// Full `Authorization` header value, if authentication is configured.
    pub authorization: Option<String>,
    /// Session identifier to send with the request, if one is known.
    pub session_id: Option<String>,
    /// Protocol version to advertise with the request, if one has been set.
    pub protocol_version: Option<String>,
    /// Additional `(name, value)` headers to attach to the request.
    pub extra_headers: Vec<(String, String)>,
}

impl std::fmt::Debug for HttpRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep presence/absence visible for debugging, but never the secret itself.
        let authorization = self.authorization.as_ref().map(|_| "<redacted>");
        f.debug_struct("HttpRequest")
            .field("body", &self.body)
            .field("authorization", &authorization)
            .field("session_id", &self.session_id)
            .field("protocol_version", &self.protocol_version)
            .field("extra_headers", &self.extra_headers)
            .finish()
    }
}
/// Abstraction over the HTTP POST used by the transport, so the network
/// backend can be swapped (the transport accepts any `Arc<dyn HttpPoster>`).
#[async_trait]
pub trait HttpPoster: Send + Sync {
/// Sends `request` and returns the resulting reply, or an error if the
/// exchange failed.
async fn post(&self, request: HttpRequest) -> Result<HttpReply>;
}
/// Authentication scheme applied to outgoing MCP HTTP requests.
///
/// `Debug` is implemented by hand so that a bearer token is never written to
/// logs or panic messages; only the variant name is shown.
#[derive(Clone, Default)]
pub enum McpAuth {
    /// Send no `Authorization` header.
    #[default]
    None,
    /// Send `Authorization: Bearer <token>` using the wrapped token.
    Bearer(String),
}

impl std::fmt::Debug for McpAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            // The token is a credential: show that one is present, not what it is.
            Self::Bearer(_) => f.write_str("Bearer(<redacted>)"),
        }
    }
}

impl McpAuth {
    /// Returns the complete `Authorization` header value for this scheme, or
    /// `None` when no authentication is configured.
    #[must_use]
    fn header_value(&self) -> Option<String> {
        match self {
            Self::None => None,
            Self::Bearer(token) => Some(format!("Bearer {token}")),
        }
    }
}
/// MCP transport that sends each JSON-RPC message as an HTTP POST through an
/// [`HttpPoster`] and accepts either a JSON or an event-stream reply.
pub struct StreamableHttpTransport {
/// Backend that performs the actual HTTP POST.
poster: Arc<dyn HttpPoster>,
/// Authentication applied to every outgoing request.
auth: McpAuth,
/// Additional `(name, value)` headers attached to every outgoing request.
extra_headers: Vec<(String, String)>,
/// Counter supplying the numeric id of the next outgoing message.
next_id: AtomicU64,
/// Most recent session id reported by a reply; echoed on later requests.
session_id: RwLock<Option<String>>,
/// Protocol version to advertise on requests once it has been set.
protocol_version: RwLock<Option<String>>,
}
impl StreamableHttpTransport {
    /// Creates a transport that posts to `endpoint` through a [`ReqwestPoster`].
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be constructed.
    pub fn new(endpoint: impl Into<String>, auth: McpAuth) -> Result<Arc<Self>> {
        ReqwestPoster::new(endpoint).map(|poster| Self::with_poster(Arc::new(poster), auth))
    }

    /// Creates a transport on top of an arbitrary [`HttpPoster`].
    ///
    /// The transport starts with no extra headers, no session id, no protocol
    /// version, and a request-id counter beginning at 1.
    #[must_use]
    pub fn with_poster(poster: Arc<dyn HttpPoster>, auth: McpAuth) -> Arc<Self> {
        let transport = Self {
            poster,
            auth,
            extra_headers: Vec::new(),
            next_id: AtomicU64::new(1),
            session_id: RwLock::new(None),
            protocol_version: RwLock::new(None),
        };
        Arc::new(transport)
    }

    /// Returns the transport with one more `(name, value)` header that will be
    /// attached to every request.
    ///
    /// NOTE(review): both constructors above hand out `Arc<Self>`, whereas this
    /// builder consumes `self` by value, so it cannot be called directly on
    /// what they return — confirm the intended call pattern.
    #[must_use]
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let header = (name.into(), value.into());
        self.extra_headers.push(header);
        self
    }

    /// Hands out the next numeric request id (post-increment of the counter).
    fn next_request_id(&self) -> u64 {
        self.next_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Wraps a serialized payload together with the auth header, the current
    /// session id / protocol version snapshots, and the extra headers.
    async fn build_http_request(&self, body: String) -> HttpRequest {
        let session_id = self.session_id.read().await.clone();
        let protocol_version = self.protocol_version.read().await.clone();
        HttpRequest {
            body,
            authorization: self.auth.header_value(),
            session_id,
            protocol_version,
            extra_headers: self.extra_headers.clone(),
        }
    }

    /// Remembers the session id carried by `reply`, if it has one.
    ///
    /// A reply without a session id leaves the stored value untouched.
    async fn capture_session_id(&self, reply: &HttpReply) {
        let incoming = match reply.session_id.as_deref() {
            Some(value) => value,
            None => return,
        };
        let mut stored = self.session_id.write().await;
        // Skip the allocation when the server keeps repeating the same id.
        if stored.as_deref() != Some(incoming) {
            *stored = Some(incoming.to_owned());
        }
    }
}
/// Decodes the JSON-RPC response carried by `reply`.
///
/// A reply whose content type mentions `text/event-stream` is decoded as a
/// server-sent-event stream and searched for the response to `id`; any other
/// reply is decoded as a single JSON document.
///
/// # Errors
/// Fails if the body cannot be decoded into a JSON-RPC response.
fn parse_reply(reply: &HttpReply, id: &RequestId) -> Result<JsonRpcResponse> {
    // Media types are case-insensitive, and not every `HttpPoster` normalises
    // the header before building the reply, so compare without regard to case.
    let is_event_stream = reply
        .content_type
        .to_ascii_lowercase()
        .contains("text/event-stream");

    if is_event_stream {
        parse_sse_response(&reply.body, id)
    } else {
        serde_json::from_str::<JsonRpcResponse>(reply.body.trim())
            .context("failed to parse JSON MCP response body")
    }
}
/// Extracts a JSON-RPC response from a server-sent-event stream body.
///
/// Consecutive `data:` lines are joined with `\n` into one event, and an empty
/// line dispatches the event. The first event that decodes to a response whose
/// id equals `id` is returned immediately. If no event matches, the last event
/// that decoded successfully (with any id) is returned instead. Events that
/// fail to decode and non-`data:` lines are ignored.
///
/// # Errors
/// Fails when the stream contains no event that decodes as a JSON-RPC response.
fn parse_sse_response(body: &str, id: &RequestId) -> Result<JsonRpcResponse> {
    let mut pending = String::new();
    let mut fallback: Option<JsonRpcResponse> = None;

    // The synthetic trailing blank line dispatches an event that the stream
    // left unterminated.
    for raw_line in body.lines().chain(std::iter::once("")) {
        let line = raw_line.trim_end_matches('\r');

        if !line.is_empty() {
            if let Some(value) = line.strip_prefix("data:") {
                // A single space after the colon is not part of the payload.
                let value = value.strip_prefix(' ').unwrap_or(value);
                if !pending.is_empty() {
                    pending.push('\n');
                }
                pending.push_str(value);
            }
            continue;
        }

        // Blank line: dispatch whatever data has accumulated, if any.
        if pending.is_empty() {
            continue;
        }
        let payload = std::mem::take(&mut pending);
        match serde_json::from_str::<JsonRpcResponse>(payload.trim()) {
            Ok(candidate) if &candidate.id == id => return Ok(candidate),
            Ok(candidate) => fallback = Some(candidate),
            Err(_) => {}
        }
    }

    fallback.context("SSE stream contained no JSON-RPC response matching the request id")
}
#[async_trait]
impl McpTransport for StreamableHttpTransport {
    /// Sends `request` under a freshly allocated numeric id and returns the
    /// decoded response.
    ///
    /// The caller-supplied id is overwritten. Any session id on the reply is
    /// recorded before the body is decoded.
    ///
    /// # Errors
    /// Fails on serialization, transport, or decoding errors, and when the
    /// response carries a JSON-RPC error object.
    async fn send(&self, mut request: JsonRpcRequest) -> Result<JsonRpcResponse> {
        let request_id = RequestId::Number(self.next_request_id());
        request.id = request_id.clone();

        let payload =
            serde_json::to_string(&request).context("failed to serialize MCP request")?;
        let outgoing = self.build_http_request(payload).await;
        let reply = self.poster.post(outgoing).await?;
        self.capture_session_id(&reply).await;

        let response = parse_reply(&reply, &request_id)?;
        if let Some(error) = &response.error {
            bail!("JSON-RPC error {}: {}", error.code, error.message);
        }
        Ok(response)
    }

    /// Sends `request` without waiting for or decoding a response body.
    ///
    /// Any session id on the reply is still recorded.
    ///
    /// NOTE(review): a numeric id is assigned here just as in `send`; JSON-RPC
    /// notifications normally carry no id — confirm how `JsonRpcRequest`
    /// serialises this.
    ///
    /// # Errors
    /// Fails on serialization or transport errors.
    async fn send_notification(&self, mut request: JsonRpcRequest) -> Result<()> {
        request.id = RequestId::Number(self.next_request_id());

        let payload =
            serde_json::to_string(&request).context("failed to serialize MCP request")?;
        let outgoing = self.build_http_request(payload).await;
        let reply = self.poster.post(outgoing).await?;
        self.capture_session_id(&reply).await;
        Ok(())
    }

    /// Stores `version` so that subsequent requests advertise it.
    async fn set_protocol_version(&self, version: &str) {
        *self.protocol_version.write().await = Some(version.to_owned());
    }

    /// No-op: this transport performs no work on close.
    async fn close(&self) -> Result<()> {
        Ok(())
    }
}
/// [`HttpPoster`] implementation backed by a `reqwest::Client`.
pub struct ReqwestPoster {
/// HTTP client used to issue requests.
client: reqwest::Client,
/// URL every request is posted to.
endpoint: String,
}
impl ReqwestPoster {
    /// Creates a poster for `endpoint` using a client built with reqwest's
    /// default builder settings.
    ///
    /// # Errors
    /// Fails if the HTTP client cannot be built.
    pub fn new(endpoint: impl Into<String>) -> Result<Self> {
        reqwest::Client::builder()
            .build()
            .context("failed to build MCP HTTP client")
            .map(|client| Self::with_client(client, endpoint))
    }

    /// Creates a poster for `endpoint` that reuses a caller-supplied client.
    #[must_use]
    pub fn with_client(client: reqwest::Client, endpoint: impl Into<String>) -> Self {
        let endpoint = endpoint.into();
        Self { client, endpoint }
    }
}
#[async_trait]
impl HttpPoster for ReqwestPoster {
async fn post(&self, request: HttpRequest) -> Result<HttpReply> {
let mut builder = self
.client
.post(&self.endpoint)
.header(
reqwest::header::ACCEPT,
"application/json, text/event-stream",
)
.header(reqwest::header::CONTENT_TYPE, "application/json")
.body(request.body);
if let Some(auth) = request.authorization {
builder = builder.header(reqwest::header::AUTHORIZATION, auth);
}
if let Some(sid) = request.session_id {
builder = builder.header(SESSION_ID_HEADER, sid);
}
if let Some(version) = request.protocol_version {
builder = builder.header(PROTOCOL_VERSION_HEADER, version);
}
for (name, value) in request.extra_headers {
builder = builder.header(name, value);
}
let response = builder
.send()
.await
.context("MCP HTTP request failed to send")?;
let status = response.status();
let session_id = response
.headers()
.get(SESSION_ID_HEADER)
.and_then(|v| v.to_str().ok())
.map(ToString::to_string);
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map_or_else(
|| "application/json".to_string(),
|s| s.split(';').next().unwrap_or(s).trim().to_lowercase(),
);
let body = response
.text()
.await
.context("failed to read MCP HTTP response body")?;
if !status.is_success() {
bail!("MCP HTTP request returned status {status}: {body}");
}
Ok(HttpReply {
content_type,
body,
session_id,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a successful JSON-RPC 2.0 response with the given id and result.
    fn ok_response(id: u64, result: &serde_json::Value) -> String {
        let envelope = serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": result,
        });
        envelope.to_string()
    }

    /// A plain JSON reply decodes into a non-error response with a result.
    #[test]
    fn parse_json_body() {
        let payload = ok_response(1, &serde_json::json!({"ok": true}));
        let resp = parse_reply(&HttpReply::json(payload), &RequestId::Number(1)).expect("parse");
        assert!(!resp.is_error());
        assert!(resp.result().is_some());
    }

    /// A stream with a single event yields that event's response.
    #[test]
    fn parse_sse_single_event() {
        let wanted = RequestId::Number(2);
        let stream = format!(
            "event: message\ndata: {}\n\n",
            ok_response(2, &serde_json::json!({}))
        );
        let resp = parse_reply(&HttpReply::event_stream(stream), &wanted).expect("parse");
        assert_eq!(resp.id, wanted);
    }

    /// An event for another id is passed over when a matching one follows.
    #[test]
    fn parse_sse_skips_non_matching_then_matches() {
        let wanted = RequestId::Number(3);
        let unrelated = ok_response(99, &serde_json::json!({"unrelated": true}));
        let matching = ok_response(3, &serde_json::json!({"answer": 42}));
        let stream = format!("data: {}\n\ndata: {}\n\n", unrelated, matching);
        let resp = parse_reply(&HttpReply::event_stream(stream), &wanted).expect("parse");
        assert_eq!(resp.id, wanted);
    }

    /// Several `data:` lines of one event are joined before decoding.
    #[test]
    fn parse_sse_multiline_data() {
        let wanted = RequestId::Number(4);
        let stream = "data: {\"jsonrpc\":\"2.0\",\ndata: \"id\":4,\ndata: \"result\":{}}\n\n";
        let resp = parse_reply(&HttpReply::event_stream(stream), &wanted).expect("parse");
        assert_eq!(resp.id, wanted);
    }

    /// `header_value` is absent for `None` and `Bearer <token>` for a bearer token.
    #[test]
    fn bearer_auth_header_value() {
        let anonymous = McpAuth::None;
        assert_eq!(anonymous.header_value(), None);

        let bearer = McpAuth::Bearer("tok".to_string());
        assert_eq!(bearer.header_value().as_deref(), Some("Bearer tok"));
    }
}