use super::{InboundHandler, McpTransport};
use crate::mcp::auth::{Credential, McpCredentialProvider};
use crate::mcp::types::RawJsonRpcMessage;
use anyhow::{Context, Result};
use futures::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, oneshot};
/// Default timeout, in milliseconds (30 seconds).
///
/// Compare `StreamableHttpTransport::new`, where a `timeout_ms` of `0` selects an
/// effectively unbounded timeout instead.
pub const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// MCP client transport for the "Streamable HTTP" flavour: every JSON-RPC message is
/// POSTed to a single endpoint, and the reply arrives either as a JSON body or as a
/// `text/event-stream` (SSE) response.
pub struct StreamableHttpTransport {
/// URL that every message is POSTed to (and that `close` sends a DELETE to).
endpoint: String,
/// Server name handed to the credential provider together with `endpoint`.
server_name: String,
/// HTTP client built in `new` with an `oxi-mcp/<crate version>` user agent.
client: Client,
/// Last `Mcp-Session-Id` header value seen on a successful response; attached to
/// subsequent requests by `build_headers`.
session_id: Mutex<Option<String>>,
/// Receives server-to-client messages that are not the response currently awaited.
/// This is a synchronous (parking_lot) mutex because it is locked from non-async
/// methods (`dispatch_inbound`, `set_inbound_handler`).
inbound_handler: parking_lot::Mutex<Option<InboundHandler>>,
/// NOTE(review): `pending` is only initialised in `new` and never used elsewhere in
/// this file, which is why `dead_code` is allowed. It is presumably reserved for
/// request/response correlation; confirm or remove.
#[allow(dead_code)]
pending: Mutex<HashMap<u64, oneshot::Sender<RawJsonRpcMessage>>>,
/// Optional source of bearer tokens. When a request gets a 401/403 response, the
/// provider is asked to refresh and the request is retried once.
credential_provider: Option<Arc<dyn McpCredentialProvider>>,
/// Time limit used when waiting on the server; `new` maps `timeout_ms == 0` to one year.
timeout: Duration,
/// Set to `true` by `close`; `request` and `notify` refuse to run afterwards.
closed: Mutex<bool>,
}
impl std::fmt::Debug for StreamableHttpTransport {
    /// Shows the endpoint, server name and connection state only. Credentials, the
    /// session id and the inbound handler are not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamableHttpTransport");
        builder
            .field("endpoint", &self.endpoint)
            .field("server_name", &self.server_name)
            .field("connected", &self.is_connected());
        builder.finish()
    }
}
impl StreamableHttpTransport {
pub fn new(
server_name: &str,
endpoint: &str,
credential_provider: Option<Arc<dyn McpCredentialProvider>>,
timeout_ms: u64,
) -> Result<Self> {
let client = Client::builder()
.user_agent(concat!("oxi-mcp/", env!("CARGO_PKG_VERSION")))
.build()
.context("Failed to build reqwest client for MCP Streamable HTTP")?;
Ok(Self {
endpoint: endpoint.to_string(),
server_name: server_name.to_string(),
client,
session_id: Mutex::new(None),
inbound_handler: parking_lot::Mutex::new(None),
pending: Mutex::new(HashMap::new()),
credential_provider,
timeout: if timeout_ms == 0 {
Duration::from_secs(60 * 60 * 24 * 365)
} else {
Duration::from_millis(timeout_ms)
},
closed: Mutex::new(false),
})
}
async fn build_headers(&self, credential: Option<&Credential>) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Accept",
"application/json, text/event-stream"
.parse()
.expect("static Accept header is valid"),
);
if let Some(sid) = self.session_id.lock().await.as_deref() {
if let Ok(v) = sid.parse() {
headers.insert("Mcp-Session-Id", v);
}
}
if let Some(cred) = credential {
if let Ok(v) = format!("Bearer {}", cred.access_token).parse() {
headers.insert(reqwest::header::AUTHORIZATION, v);
}
}
headers
}
async fn capture_session_id(&self, resp: &reqwest::Response) {
if let Some(v) = resp.headers().get("Mcp-Session-Id") {
if let Ok(s) = v.to_str() {
*self.session_id.lock().await = Some(s.to_string());
}
}
}
async fn post_once(
&self,
json: &str,
credential: Option<&Credential>,
) -> Result<reqwest::Response> {
let headers = self.build_headers(credential).await;
self.client
.post(&self.endpoint)
.headers(headers)
.header("Content-Type", "application/json")
.body(json.to_string())
.send()
.await
.context("MCP Streamable HTTP POST failed")
}
fn dispatch_inbound(&self, msg: RawJsonRpcMessage) {
if let Some(h) = self.inbound_handler.lock().as_mut() {
h(msg);
}
}
}
#[async_trait::async_trait]
impl McpTransport for StreamableHttpTransport {
async fn request(&mut self, id: u64, json: &str) -> Result<RawJsonRpcMessage> {
if *self.closed.lock().await {
anyhow::bail!("MCP HTTP transport closed");
}
let credential = match self.credential_provider.as_ref() {
Some(p) => p.access_token(&self.server_name, &self.endpoint).await,
None => None,
};
let resp = self.post_once(json, credential.as_ref()).await?;
let status = resp.status();
if status.is_success() {
self.capture_session_id(&resp).await;
}
if (status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN)
&& self.credential_provider.is_some()
{
drop(resp);
let provider = self
.credential_provider
.as_ref()
.expect("checked is_some above");
let refreshed = provider.refresh(&self.server_name, &self.endpoint).await;
let credential2 = match refreshed {
Some(_) => {
provider
.access_token(&self.server_name, &self.endpoint)
.await
}
None => credential,
};
let resp2 = self.post_once(json, credential2.as_ref()).await?;
let status2 = resp2.status();
if status2.is_success() {
self.capture_session_id(&resp2).await;
}
if !status2.is_success() {
anyhow::bail!(
"MCP HTTP request failed after credential refresh: {} {}",
status2.as_u16(),
status2.canonical_reason().unwrap_or("")
);
}
return self.handle_response(resp2, id).await;
}
if !status.is_success() {
anyhow::bail!(
"MCP HTTP error {}: {}",
status.as_u16(),
status.canonical_reason().unwrap_or("")
);
}
self.handle_response(resp, id).await
}
async fn notify(&mut self, json: &str) -> Result<()> {
if *self.closed.lock().await {
anyhow::bail!("MCP HTTP transport closed");
}
let credential = match self.credential_provider.as_ref() {
Some(p) => p.access_token(&self.server_name, &self.endpoint).await,
None => None,
};
let resp = self.post_once(json, credential.as_ref()).await?;
let status = resp.status();
if status.is_success() {
self.capture_session_id(&resp).await;
}
if !status.is_success() {
anyhow::bail!(
"MCP HTTP notify failed: {} {}",
status.as_u16(),
status.canonical_reason().unwrap_or("")
);
}
Ok(())
}
fn set_inbound_handler(&mut self, handler: InboundHandler) {
*self.inbound_handler.lock() = Some(handler);
}
async fn close(&mut self) -> Result<()> {
*self.closed.lock().await = true;
let sid = self.session_id.lock().await.clone();
if let Some(sid) = sid {
if let Ok(v) = sid.parse::<reqwest::header::HeaderValue>() {
let _ = self
.client
.delete(&self.endpoint)
.header("Mcp-Session-Id", v)
.send()
.await;
}
}
Ok(())
}
fn is_connected(&self) -> bool {
!*self.closed.blocking_lock()
}
}
impl StreamableHttpTransport {
async fn handle_response(&self, resp: reqwest::Response, id: u64) -> Result<RawJsonRpcMessage> {
let ct = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
if ct.starts_with("text/event-stream") {
let mut stream = resp.bytes_stream();
let mut buffer: Vec<u8> = Vec::new();
let timeout = self.timeout;
let resolved = tokio::time::timeout(timeout, async {
while let Some(chunk) = stream.next().await {
let chunk = chunk.context("MCP HTTP SSE chunk read failed")?;
buffer.extend_from_slice(&chunk);
while let Some((event, consumed)) = parse_sse_event(&buffer) {
let rest = buffer.split_off(consumed);
buffer = rest;
let Some(data) = event else { continue };
let msg: RawJsonRpcMessage = match serde_json::from_slice(&data) {
Ok(m) => m,
Err(_) => continue,
};
if msg.id == Some(id) {
return Ok(msg);
}
self.dispatch_inbound(msg);
}
}
Err::<RawJsonRpcMessage, _>(anyhow::anyhow!(
"MCP HTTP SSE response ended without matching id"
))
})
.await
.map_err(|_| {
anyhow::anyhow!("MCP HTTP request timed out after {:?}", self.timeout)
})??;
return Ok(resolved);
}
let body = resp
.bytes()
.await
.context("Failed to read MCP HTTP response body")?;
let msg: RawJsonRpcMessage =
serde_json::from_slice(&body).context("Failed to parse MCP HTTP response JSON")?;
if msg.id != Some(id) {
self.dispatch_inbound(msg);
anyhow::bail!("MCP HTTP returned response with non-matching id");
}
Ok(msg)
}
}
/// Tries to parse one complete Server-Sent Event from the front of `buffer`.
///
/// Returns `None` while the buffer does not yet hold a full event, i.e. no blank line
/// has terminated it. Otherwise returns `(data, consumed)`:
/// * `data` is the concatenation of the event's `data` field values, joined with `\n`,
///   or `None` if the event had no `data` field (e.g. comments or keep-alives);
/// * `consumed` is the number of bytes, including the terminating blank line, that
///   the caller should remove from the front of `buffer`.
///
/// Lines may end in `\n`, `\r\n` or `\r`, as the SSE format allows. After a field
/// name's colon, exactly one leading space is stripped from the value, and lines
/// starting with `:` are comments.
fn parse_sse_event(buffer: &[u8]) -> Option<(Option<Vec<u8>>, usize)> {
    let (event_end, consumed) = find_sse_delim(buffer)?;
    let mut data: Option<Vec<u8>> = None;
    // Splitting on either terminator byte only adds empty pieces (the gap inside a
    // `\r\n`), and empty pieces are skipped below.
    for line in buffer[..event_end].split(|&b| b == b'\n' || b == b'\r') {
        if line.is_empty() {
            continue;
        }
        let (field, value) = match line.iter().position(|&b| b == b':') {
            // A leading colon marks a comment line.
            Some(0) => continue,
            Some(colon) => {
                let value = &line[colon + 1..];
                (&line[..colon], value.strip_prefix(b" ").unwrap_or(value))
            }
            // A line without a colon is a field name with an empty value.
            None => (line, &[][..]),
        };
        if field == b"data" {
            match data.as_mut() {
                Some(entry) => {
                    entry.push(b'\n');
                    entry.extend_from_slice(value);
                }
                None => data = Some(value.to_vec()),
            }
        }
    }
    Some((data, consumed))
}

/// Locates the blank line that terminates the first SSE event in `buffer`.
///
/// Returns `(event_end, consumed)`. The event's lines occupy `buffer[..event_end]`,
/// and `consumed` is the offset just past the blank line's terminator. Returns `None`
/// if no blank line is present yet.
///
/// A `\r` that is the very last byte is treated as a complete terminator. If it was
/// really the first half of a `\r\n` split across network chunks, the leftover `\n`
/// is later parsed as an empty event with no data, which callers skip.
fn find_sse_delim(buffer: &[u8]) -> Option<(usize, usize)> {
    let mut line_start = 0;
    let mut i = 0;
    while i < buffer.len() {
        let terminator_len = match buffer[i] {
            b'\n' => 1,
            b'\r' if buffer.get(i + 1) == Some(&b'\n') => 2,
            b'\r' => 1,
            _ => {
                i += 1;
                continue;
            }
        };
        if i == line_start {
            // A terminator at the start of a line means an empty line, i.e. end of event.
            return Some((i, i + terminator_len));
        }
        i += terminator_len;
        line_start = i;
    }
    None
}