use crate::error::{Error, Result, TransportError};
use crate::shared::http_constants::{
ACCEPT, ACCEPT_STREAMABLE, APPLICATION_JSON, CONTENT_TYPE, LAST_EVENT_ID, MCP_PROTOCOL_VERSION,
MCP_SESSION_ID, TEXT_EVENT_STREAM,
};
use crate::shared::sse_parser::SseParser;
use crate::shared::{Transport, TransportMessage};
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::{Method, Request, Response as HyperResponse, StatusCode};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use parking_lot::RwLock;
use std::fmt::Debug;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::mpsc;
use url::Url;
/// Per-call options accepted by `StreamableHttpTransport::send_with_options`.
#[derive(Default, Clone, Debug)]
pub struct SendOptions {
    /// Id of the request this message relates to.
    ///
    /// NOTE(review): not read by `send_with_options` in this file — confirm
    /// whether any caller depends on it.
    pub related_request_id: Option<String>,
    /// SSE event id to resume from. When set, `send_with_options` reopens the
    /// SSE stream with this id as the last-event-id header instead of POSTing
    /// the message.
    pub resumption_token: Option<String>,
}
/// Settings for a `StreamableHttpTransport`, normally produced by
/// `StreamableHttpTransportConfigBuilder`.
#[derive(Clone)]
pub struct StreamableHttpTransportConfig {
    /// Endpoint that receives POSTed messages and serves the GET SSE stream.
    pub url: Url,
    /// Additional `(name, value)` headers attached to every request.
    pub extra_headers: Vec<(String, String)>,
    /// Source of bearer tokens; when present every request carries an
    /// `Authorization: Bearer <token>` header.
    pub auth_provider: Option<Arc<dyn AuthProvider>>,
    /// Current session id: sent on requests and overwritten from response headers.
    pub session_id: Option<String>,
    /// NOTE(review): not read by the transport code in this file — confirm its consumer.
    pub enable_json_response: bool,
    /// Called with the id of every SSE event received, so callers can persist
    /// it as a resumption token.
    pub on_resumption_token: Option<Arc<dyn Fn(String) + Send + Sync>>,
    /// Optional middleware run over outgoing requests and incoming response bodies.
    pub http_middleware_chain: Option<Arc<crate::client::http_middleware::HttpMiddlewareChain>>,
}
impl Debug for StreamableHttpTransportConfig {
    /// Formats the config for diagnostics without exposing secrets.
    ///
    /// Extra-header values are replaced with `<redacted>` because they commonly
    /// carry credentials (API keys, tokens) and this output is embedded in the
    /// transport's own `Debug` representation, which tends to end up in logs.
    /// The auth provider and callbacks are reported only as present/absent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let headers: Vec<(&str, &str)> = self
            .extra_headers
            .iter()
            .map(|(name, _)| (name.as_str(), "<redacted>"))
            .collect();
        f.debug_struct("StreamableHttpTransportConfig")
            .field("url", &self.url)
            .field("extra_headers", &headers)
            .field("auth_provider", &self.auth_provider.is_some())
            .field("session_id", &self.session_id)
            .field("enable_json_response", &self.enable_json_response)
            .field("on_resumption_token", &self.on_resumption_token.is_some())
            .field(
                "http_middleware_chain",
                &self.http_middleware_chain.is_some(),
            )
            .finish()
    }
}
/// Incremental builder for `StreamableHttpTransportConfig`.
///
/// Every setting except the URL starts unset; `build` moves the accumulated
/// values into the final config.
pub struct StreamableHttpTransportConfigBuilder {
    /// Target endpoint.
    url: Url,
    /// Headers added via `with_header`, in insertion order.
    extra_headers: Vec<(String, String)>,
    /// Bearer-token source, if configured.
    auth_provider: Option<Arc<dyn AuthProvider>>,
    /// Pre-existing session id to reuse, if any.
    session_id: Option<String>,
    /// Flag copied verbatim into the built config.
    enable_json_response: bool,
    /// Callback receiving SSE event ids.
    on_resumption_token: Option<Arc<dyn Fn(String) + Send + Sync>>,
    /// Request/response middleware, if configured.
    http_middleware_chain: Option<Arc<crate::client::http_middleware::HttpMiddlewareChain>>,
}
impl Debug for StreamableHttpTransportConfigBuilder {
    /// Formats the builder for diagnostics without exposing secrets.
    ///
    /// Extra-header values are shown as `<redacted>` because they commonly hold
    /// credentials; the auth provider and callbacks are reported only as
    /// present/absent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let headers: Vec<(&str, &str)> = self
            .extra_headers
            .iter()
            .map(|(name, _)| (name.as_str(), "<redacted>"))
            .collect();
        f.debug_struct("StreamableHttpTransportConfigBuilder")
            .field("url", &self.url)
            .field("extra_headers", &headers)
            .field("auth_provider", &self.auth_provider.is_some())
            .field("session_id", &self.session_id)
            .field("enable_json_response", &self.enable_json_response)
            .field("on_resumption_token", &self.on_resumption_token.is_some())
            .field(
                "http_middleware_chain",
                &self.http_middleware_chain.is_some(),
            )
            .finish()
    }
}
impl StreamableHttpTransportConfigBuilder {
    /// Starts a builder for `url` with no headers, auth, session, callback or middleware.
    pub fn new(url: Url) -> Self {
        Self {
            url,
            extra_headers: Vec::new(),
            auth_provider: None,
            session_id: None,
            enable_json_response: false,
            on_resumption_token: None,
            http_middleware_chain: None,
        }
    }

    /// Appends a header sent on every request; duplicates are kept.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let header = (name.into(), value.into());
        self.extra_headers.push(header);
        self
    }

    /// Sets the bearer-token provider, replacing any previous one.
    pub fn with_auth_provider(self, provider: Arc<dyn AuthProvider>) -> Self {
        Self {
            auth_provider: Some(provider),
            ..self
        }
    }

    /// Seeds the transport with an existing session id.
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Turns on the `enable_json_response` flag of the built config.
    pub fn enable_json_response(self) -> Self {
        Self {
            enable_json_response: true,
            ..self
        }
    }

    /// Registers a callback that receives every SSE event id.
    pub fn on_resumption_token(self, callback: Arc<dyn Fn(String) + Send + Sync>) -> Self {
        Self {
            on_resumption_token: Some(callback),
            ..self
        }
    }

    /// Installs the HTTP middleware chain, replacing any previous one.
    pub fn with_http_middleware(
        self,
        chain: Arc<crate::client::http_middleware::HttpMiddlewareChain>,
    ) -> Self {
        Self {
            http_middleware_chain: Some(chain),
            ..self
        }
    }

    /// Consumes the builder and returns the finished configuration.
    pub fn build(self) -> StreamableHttpTransportConfig {
        let Self {
            url,
            extra_headers,
            auth_provider,
            session_id,
            enable_json_response,
            on_resumption_token,
            http_middleware_chain,
        } = self;
        StreamableHttpTransportConfig {
            url,
            extra_headers,
            auth_provider,
            session_id,
            enable_json_response,
            on_resumption_token,
            http_middleware_chain,
        }
    }
}
/// Client transport for the MCP Streamable HTTP protocol.
///
/// Messages are POSTed to the configured endpoint. Replies that arrive as JSON
/// or SSE bodies, and events from the optional standalone GET stream, are
/// queued on an internal channel and handed out by `receive`. Every field is
/// either reference-counted or a cheaply clonable handle, so clones share all
/// state.
#[derive(Clone)]
pub struct StreamableHttpTransport {
    /// Shared configuration; the session id inside it is updated from responses.
    config: Arc<RwLock<StreamableHttpTransportConfig>>,
    /// Pooled HTTP(S) client used for every request.
    client: Client<
        hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
        Full<Bytes>,
    >,
    /// Consumer end of the inbound message queue.
    receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<TransportMessage>>>,
    /// Producer end of the inbound message queue.
    sender: mpsc::UnboundedSender<TransportMessage>,
    /// Protocol version learned from response headers and echoed on requests.
    protocol_version: Arc<RwLock<Option<String>>>,
    /// Background task reading the standalone SSE stream, if one was started.
    abort_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Id of the most recent SSE event observed.
    last_event_id: Arc<RwLock<Option<String>>>,
}
impl Debug for StreamableHttpTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The HTTP client, channel ends and task handle have no useful
        // diagnostic representation, so only the shared state is shown.
        let mut out = f.debug_struct("StreamableHttpTransport");
        out.field("config", &self.config);
        out.field("protocol_version", &self.protocol_version);
        out.field("last_event_id", &self.last_event_id);
        out.finish()
    }
}
impl StreamableHttpTransport {
pub fn new(config: StreamableHttpTransportConfig) -> Self {
Self::new_internal(config, false)
}
pub fn new_with_http2(config: StreamableHttpTransportConfig) -> Self {
Self::new_internal(config, true)
}
fn new_internal(config: StreamableHttpTransportConfig, enable_http2: bool) -> Self {
let _ = rustls::crypto::ring::default_provider().install_default();
let https = if enable_http2 {
tracing::debug!("Creating HTTPS connector with HTTP/1.1 and HTTP/2 support");
hyper_rustls::HttpsConnectorBuilder::new()
.with_native_roots()
.expect("Failed to load native root certificates")
.https_or_http()
.enable_http1()
.enable_http2()
.build()
} else {
tracing::debug!("Creating HTTPS connector with HTTP/1.1 only");
hyper_rustls::HttpsConnectorBuilder::new()
.with_native_roots()
.expect("Failed to load native root certificates")
.https_or_http()
.enable_http1()
.build()
};
let client = Client::builder(TokioExecutor::new())
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(10)
.build(https);
let (sender, receiver) = mpsc::unbounded_channel();
Self {
config: Arc::new(RwLock::new(config)),
client,
receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
sender,
protocol_version: Arc::new(RwLock::new(None)),
abort_handle: Arc::new(RwLock::new(None)),
last_event_id: Arc::new(RwLock::new(None)),
}
}
pub fn session_id(&self) -> Option<String> {
self.config.read().session_id.clone()
}
pub fn set_session_id(&self, session_id: Option<String>) {
self.config.write().session_id = session_id;
}
pub fn protocol_version(&self) -> Option<String> {
self.protocol_version.read().clone()
}
pub fn set_protocol_version(&self, version: Option<String>) {
*self.protocol_version.write() = version;
}
pub fn last_event_id(&self) -> Option<String> {
self.last_event_id.read().clone()
}
pub async fn start_sse(&self, resumption_token: Option<String>) -> Result<()> {
let handle = self.abort_handle.write().take();
if let Some(handle) = handle {
handle.abort();
}
let url = self.config.read().url.clone();
let mut request = self
.build_request_with_middleware(
Method::GET,
url.as_str(),
vec![], )
.await?;
request.headers_mut().insert(
ACCEPT,
TEXT_EVENT_STREAM.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
if let Some(token) = &resumption_token {
request.headers_mut().insert(
LAST_EVENT_ID,
token.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
}
let response = self
.client
.request(request)
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?;
if response.status() == StatusCode::METHOD_NOT_ALLOWED {
return Ok(());
}
if !response.status().is_success() {
return Err(Error::Transport(TransportError::Request(format!(
"SSE request failed with status: {}",
response.status()
))));
}
self.process_response_headers(&response);
let body_bytes = response
.collect()
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?
.to_bytes();
let modified_body = if self.config.read().http_middleware_chain.is_some() {
let temp_response = HyperResponse::builder()
.status(200)
.body(Full::new(Bytes::new()))
.unwrap();
self.apply_response_middleware("GET", url.as_str(), &temp_response, body_bytes.to_vec())
.await?
} else {
body_bytes.to_vec()
};
let sender = self.sender.clone();
let on_resumption = self.config.read().on_resumption_token.clone();
let last_event_id = self.last_event_id.clone();
let handle = tokio::spawn(async move {
let mut sse_parser = SseParser::new();
let body = String::from_utf8_lossy(&modified_body);
let events = sse_parser.feed(&body);
for event in events {
if let Some(id) = &event.id {
*last_event_id.write() = Some(id.clone());
if let Some(callback) = &on_resumption {
callback(id.clone());
}
}
if event.event.as_deref() == Some("message") || event.event.is_none() {
if let Ok(msg) =
crate::shared::StdioTransport::parse_message(event.data.as_bytes())
{
let _ = sender.send(msg);
}
}
}
});
*self.abort_handle.write() = Some(handle);
Ok(())
}
async fn build_request_with_middleware(
&self,
method: Method,
url: &str,
body: Vec<u8>,
) -> Result<Request<Full<Bytes>>> {
use crate::client::http_middleware::{HttpMiddlewareContext, HttpRequest};
let (extra_headers, auth_provider, session_id, middleware_chain) = {
let config = self.config.read();
(
config.extra_headers.clone(),
config.auth_provider.clone(),
config.session_id.clone(),
config.http_middleware_chain.clone(),
)
};
let mut request_builder = Request::builder().method(method.clone()).uri(url);
for (key, value) in &extra_headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
let has_auth = if let Some(auth_provider) = auth_provider {
let token = auth_provider.get_access_token().await?;
request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
true
} else {
false
};
if let Some(session_id) = &session_id {
request_builder = request_builder.header(MCP_SESSION_ID, session_id.as_str());
}
if let Some(protocol_version) = self.protocol_version.read().as_ref() {
request_builder =
request_builder.header(MCP_PROTOCOL_VERSION, protocol_version.as_str());
}
let temp_req = request_builder
.body(Full::new(Bytes::from(body.clone())))
.map_err(|e| Error::Transport(TransportError::InvalidMessage(e.to_string())))?;
let headers = temp_req.headers();
if let Some(chain) = middleware_chain {
let mut http_req = HttpRequest::new(method.as_str().to_string(), url.to_string(), body);
for (key, value) in headers {
if let Ok(value_str) = value.to_str() {
http_req.add_header(key.as_str(), value_str);
}
}
let context = HttpMiddlewareContext::new(url.to_string(), method.as_str().to_string());
if has_auth {
context.set_metadata("auth_already_set".to_string(), "true".to_string());
}
if let Err(e) = chain.process_request(&mut http_req, &context).await {
chain.handle_transport_error(&e, &context).await;
return Err(e);
}
let mut final_builder = Request::builder().method(method).uri(url);
for (key, value) in &http_req.headers {
final_builder = final_builder.header(key, value);
}
final_builder
.body(Full::new(Bytes::from(http_req.body)))
.map_err(|e| Error::Transport(TransportError::InvalidMessage(e.to_string())))
} else {
Ok(temp_req)
}
}
#[allow(clippy::future_not_send)]
async fn apply_response_middleware(
&self,
method: &str,
url: &str,
response: &HyperResponse<impl hyper::body::Body>,
body: Vec<u8>,
) -> Result<Vec<u8>> {
use crate::client::http_middleware::{HttpMiddlewareContext, HttpResponse};
let middleware_chain = self.config.read().http_middleware_chain.clone();
if let Some(chain) = middleware_chain {
let header_map = response.headers().clone();
let mut http_resp =
HttpResponse::with_headers(response.status().as_u16(), header_map, body);
let context = HttpMiddlewareContext::new(url.to_string(), method.to_string());
if let Err(e) = chain.process_response(&mut http_resp, &context).await {
chain.handle_transport_error(&e, &context).await;
return Err(e);
}
Ok(http_resp.body)
} else {
Ok(body)
}
}
fn process_response_headers(&self, response: &HyperResponse<impl hyper::body::Body>) {
if let Some(session_id) = response.headers().get(MCP_SESSION_ID) {
if let Ok(session_id_str) = session_id.to_str() {
self.config.write().session_id = Some(session_id_str.to_string());
}
}
if let Some(protocol_version) = response.headers().get(MCP_PROTOCOL_VERSION) {
if let Ok(protocol_version_str) = protocol_version.to_str() {
*self.protocol_version.write() = Some(protocol_version_str.to_string());
}
}
}
pub async fn send_with_options(
&mut self,
message: TransportMessage,
options: SendOptions,
) -> Result<()> {
if let Some(token) = options.resumption_token {
self.start_sse(Some(token)).await?;
return Ok(());
}
let body_bytes = crate::shared::StdioTransport::serialize_message(&message)?;
let url = self.config.read().url.clone();
let mut request = self
.build_request_with_middleware(Method::POST, url.as_str(), body_bytes)
.await?;
request.headers_mut().insert(
CONTENT_TYPE,
APPLICATION_JSON.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
request.headers_mut().insert(
ACCEPT,
ACCEPT_STREAMABLE.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
let response = self
.client
.request(request)
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?;
self.process_response_headers(&response);
if !response.status().is_success() {
if response.status() == StatusCode::ACCEPTED {
if matches!(message, TransportMessage::Notification { .. }) {
let _ = self.start_sse(None).await;
}
return Ok(());
}
return Err(Error::Transport(TransportError::Request(format!(
"Request failed with status: {}",
response.status()
))));
}
let status_code = response.status();
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let content_length = response
.headers()
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<usize>().ok());
let body_bytes = response
.collect()
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?
.to_bytes();
tracing::debug!(
status = %status_code,
content_type = %content_type,
content_length = ?content_length,
body_len = body_bytes.len(),
"HTTP response received"
);
let modified_body = if self.config.read().http_middleware_chain.is_some() {
let temp_response = HyperResponse::builder()
.status(status_code)
.body(Full::new(Bytes::new()))
.unwrap();
self.apply_response_middleware(
"POST",
url.as_str(),
&temp_response,
body_bytes.to_vec(),
)
.await?
} else {
body_bytes.to_vec()
};
if status_code == StatusCode::OK && (content_length == Some(0) || content_type.is_empty()) {
if modified_body.is_empty() {
return Ok(());
}
if content_type.is_empty() {
return Err(Error::Transport(TransportError::Request(
"Response has body but no Content-Type header".to_string(),
)));
}
if let Ok(batch) = serde_json::from_slice::<Vec<serde_json::Value>>(&modified_body) {
for json_msg in batch {
let json_str = serde_json::to_string(&json_msg).map_err(|e| {
Error::Transport(TransportError::Deserialization(e.to_string()))
})?;
let msg = crate::shared::StdioTransport::parse_message(json_str.as_bytes())?;
self.sender
.send(msg)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
} else {
let msg_parsed = crate::shared::StdioTransport::parse_message(&modified_body)?;
self.sender
.send(msg_parsed)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
return Ok(());
}
if content_type.contains(APPLICATION_JSON) {
if modified_body.is_empty() {
if status_code == StatusCode::ACCEPTED {
tracing::debug!(
status = %status_code,
"Notification acknowledged with 202 Accepted"
);
return Ok(());
}
tracing::warn!(
status = %status_code,
content_type = %content_type,
"Server returned empty body with application/json content type"
);
return Err(Error::Transport(TransportError::Request(
"Server returned empty response body with Content-Type: application/json. \
This may indicate a server error or network issue."
.to_string(),
)));
}
if let Ok(batch) = serde_json::from_slice::<Vec<serde_json::Value>>(&modified_body) {
for json_msg in batch {
let json_str = serde_json::to_string(&json_msg).map_err(|e| {
Error::Transport(TransportError::Deserialization(e.to_string()))
})?;
let msg = crate::shared::StdioTransport::parse_message(json_str.as_bytes())?;
self.sender
.send(msg)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
} else {
let msg_parsed = crate::shared::StdioTransport::parse_message(&modified_body)?;
self.sender
.send(msg_parsed)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
} else if content_type.contains(TEXT_EVENT_STREAM) {
let sender = self.sender.clone();
let on_resumption = self.config.read().on_resumption_token.clone();
let last_event_id = self.last_event_id.clone();
tokio::spawn(async move {
let mut sse_parser = SseParser::new();
let body = String::from_utf8_lossy(&modified_body);
let events = sse_parser.feed(&body);
for event in events {
if let Some(id) = &event.id {
*last_event_id.write() = Some(id.clone());
if let Some(callback) = &on_resumption {
callback(id.clone());
}
}
if event.event.as_deref() == Some("message") || event.event.is_none() {
if let Ok(msg) =
crate::shared::StdioTransport::parse_message(event.data.as_bytes())
{
let _ = sender.send(msg);
}
}
}
});
} else if status_code == StatusCode::ACCEPTED {
return Ok(());
} else {
return Err(Error::Transport(TransportError::Request(format!(
"Unsupported content type: {}",
content_type
))));
}
Ok(())
}
}
#[async_trait]
impl Transport for StreamableHttpTransport {
    /// Sends `message` with default `SendOptions`.
    async fn send(&mut self, message: TransportMessage) -> Result<()> {
        self.send_with_options(message, SendOptions::default())
            .await
    }

    /// Waits for the next queued inbound message.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::ConnectionClosed` if the inbound channel is closed.
    async fn receive(&mut self) -> Result<TransportMessage> {
        let mut receiver = self.receiver.lock().await;
        receiver
            .recv()
            .await
            .ok_or_else(|| Error::Transport(TransportError::ConnectionClosed))
    }

    /// Stops the standalone SSE reader and, if a session is active, asks the
    /// server to end it with an HTTP DELETE.
    ///
    /// The DELETE is best-effort: network failures and unexpected statuses are
    /// logged, and a 405 (server does not support explicit termination) is
    /// accepted silently. The local session id is cleared in every case.
    ///
    /// # Errors
    ///
    /// Returns an error only if the DELETE request cannot be built (e.g. the
    /// auth provider or request middleware fails).
    async fn close(&mut self) -> Result<()> {
        let sse_task = self.abort_handle.write().take();
        if let Some(task) = sse_task {
            task.abort();
        }
        if self.session_id().is_none() {
            return Ok(());
        }

        let url = self.config.read().url.clone();
        let request = match self
            .build_request_with_middleware(Method::DELETE, url.as_str(), vec![])
            .await
        {
            Ok(request) => request,
            Err(e) => {
                // Forget the session anyway so a later reconnect starts fresh.
                self.config.write().session_id = None;
                return Err(e);
            }
        };
        match self.client.request(request).await {
            Ok(resp) => {
                let status = resp.status();
                if !status.is_success() && status != StatusCode::METHOD_NOT_ALLOWED {
                    tracing::warn!("Failed to terminate session: {}", status);
                }
            }
            Err(e) => tracing::warn!("Failed to terminate session: {}", e),
        }
        self.config.write().session_id = None;
        Ok(())
    }

    /// Always reports `true`; this transport does not track connection state.
    fn is_connected(&self) -> bool {
        true
    }
}
/// Supplies bearer tokens for outgoing requests.
///
/// When configured, the transport asks for a token each time it builds a
/// request and sends it as `Authorization: Bearer <token>`.
#[async_trait]
pub trait AuthProvider: Send + Sync + Debug {
    /// Returns the access token for the next request.
    ///
    /// # Errors
    ///
    /// Any error aborts building the request that needed the token and is
    /// returned to that request's caller.
    async fn get_access_token(&self) -> Result<String>;
}