use crate::error::{Error, Result, TransportError};
use crate::shared::http_constants::{
ACCEPT, ACCEPT_STREAMABLE, APPLICATION_JSON, CONTENT_TYPE, LAST_EVENT_ID, MCP_PROTOCOL_VERSION,
MCP_SESSION_ID, TEXT_EVENT_STREAM,
};
use crate::shared::sse_parser::SseParser;
use crate::shared::{Transport, TransportMessage};
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::{Method, Request, Response as HyperResponse, StatusCode};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use parking_lot::RwLock;
use std::fmt::Debug;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::mpsc;
use url::Url;
/// Per-call options accepted by `StreamableHttpTransport::send_with_options`.
#[derive(Debug, Clone, Default)]
pub struct SendOptions {
/// Identifier of a related request. Not read by any code in this file.
pub related_request_id: Option<String>,
/// When set, `send_with_options` does not POST the message; it (re)starts the
/// SSE stream via `start_sse` and sends this value in the `LAST_EVENT_ID` header.
pub resumption_token: Option<String>,
}
/// Configuration consumed by `StreamableHttpTransport`.
#[derive(Clone)]
pub struct StreamableHttpTransportConfig {
/// Endpoint targeted by every request the transport issues (`GET`, `POST`, `DELETE`).
pub url: Url,
/// Header name/value pairs added to each outgoing request.
pub extra_headers: Vec<(String, String)>,
/// Optional token source; when set, each built request carries `Authorization: Bearer <token>`.
pub auth_provider: Option<Arc<dyn AuthProvider>>,
/// Session id sent in the `MCP_SESSION_ID` header; replaced whenever a response carries that header.
pub session_id: Option<String>,
/// Flag stored by the builder. Not read by the transport code in this file.
pub enable_json_response: bool,
/// Callback invoked with the id of each parsed SSE event that has one.
pub on_resumption_token: Option<Arc<dyn Fn(String) + Send + Sync>>,
/// Optional middleware chain applied to outgoing requests and to buffered response bodies.
pub http_middleware_chain: Option<Arc<crate::client::http_middleware::HttpMiddlewareChain>>,
}
/// Hand-written `Debug`: the auth provider, resumption callback and middleware
/// chain are reported only as presence flags (`true`/`false`).
impl Debug for StreamableHttpTransportConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_auth_provider = self.auth_provider.is_some();
        let has_resumption_callback = self.on_resumption_token.is_some();
        let has_middleware = self.http_middleware_chain.is_some();

        let mut out = f.debug_struct("StreamableHttpTransportConfig");
        out.field("url", &self.url);
        out.field("extra_headers", &self.extra_headers);
        out.field("auth_provider", &has_auth_provider);
        out.field("session_id", &self.session_id);
        out.field("enable_json_response", &self.enable_json_response);
        out.field("on_resumption_token", &has_resumption_callback);
        out.field("http_middleware_chain", &has_middleware);
        out.finish()
    }
}
/// Builder for `StreamableHttpTransportConfig`; each field mirrors the
/// configuration field of the same name and is copied over by `build`.
pub struct StreamableHttpTransportConfigBuilder {
// Target endpoint (the only required value).
url: Url,
// Extra headers accumulated by `with_header`.
extra_headers: Vec<(String, String)>,
// Set by `with_auth_provider`.
auth_provider: Option<Arc<dyn AuthProvider>>,
// Set by `with_session_id`.
session_id: Option<String>,
// Switched on by `enable_json_response`; starts as `false`.
enable_json_response: bool,
// Set by `on_resumption_token`.
on_resumption_token: Option<Arc<dyn Fn(String) + Send + Sync>>,
// Set by `with_http_middleware`.
http_middleware_chain: Option<Arc<crate::client::http_middleware::HttpMiddlewareChain>>,
}
/// Hand-written `Debug`: the auth provider, resumption callback and middleware
/// chain are reported only as presence flags (`true`/`false`).
impl Debug for StreamableHttpTransportConfigBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_auth_provider = self.auth_provider.is_some();
        let has_resumption_callback = self.on_resumption_token.is_some();
        let has_middleware = self.http_middleware_chain.is_some();

        let mut out = f.debug_struct("StreamableHttpTransportConfigBuilder");
        out.field("url", &self.url);
        out.field("extra_headers", &self.extra_headers);
        out.field("auth_provider", &has_auth_provider);
        out.field("session_id", &self.session_id);
        out.field("enable_json_response", &self.enable_json_response);
        out.field("on_resumption_token", &has_resumption_callback);
        out.field("http_middleware_chain", &has_middleware);
        out.finish()
    }
}
impl StreamableHttpTransportConfigBuilder {
    /// Starts a builder for `url` with no extra headers, no auth provider, no
    /// session id, JSON responses disabled, and no callback or middleware.
    pub fn new(url: Url) -> Self {
        Self {
            url,
            extra_headers: Vec::new(),
            auth_provider: None,
            session_id: None,
            enable_json_response: false,
            on_resumption_token: None,
            http_middleware_chain: None,
        }
    }

    /// Appends one header pair; earlier pairs are kept.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let header = (name.into(), value.into());
        self.extra_headers.push(header);
        self
    }

    /// Sets the auth provider, replacing any previous one.
    pub fn with_auth_provider(self, provider: Arc<dyn AuthProvider>) -> Self {
        Self {
            auth_provider: Some(provider),
            ..self
        }
    }

    /// Sets the initial session id.
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Turns the `enable_json_response` flag on.
    pub fn enable_json_response(self) -> Self {
        Self {
            enable_json_response: true,
            ..self
        }
    }

    /// Registers the callback that receives SSE event ids.
    pub fn on_resumption_token(self, callback: Arc<dyn Fn(String) + Send + Sync>) -> Self {
        Self {
            on_resumption_token: Some(callback),
            ..self
        }
    }

    /// Installs the HTTP middleware chain.
    pub fn with_http_middleware(
        self,
        chain: Arc<crate::client::http_middleware::HttpMiddlewareChain>,
    ) -> Self {
        Self {
            http_middleware_chain: Some(chain),
            ..self
        }
    }

    /// Consumes the builder and produces the configuration, field for field.
    pub fn build(self) -> StreamableHttpTransportConfig {
        let Self {
            url,
            extra_headers,
            auth_provider,
            session_id,
            enable_json_response,
            on_resumption_token,
            http_middleware_chain,
        } = self;
        StreamableHttpTransportConfig {
            url,
            extra_headers,
            auth_provider,
            session_id,
            enable_json_response,
            on_resumption_token,
            http_middleware_chain,
        }
    }
}
/// HTTP client transport that POSTs messages to a single endpoint and queues
/// the messages found in JSON or SSE responses for `receive`.
///
/// `Clone` is derived; the `Arc`-wrapped fields (configuration, receiver,
/// protocol version, SSE task handle, last event id) are shared between clones.
#[derive(Clone)]
pub struct StreamableHttpTransport {
// Mutable configuration; the session id inside is updated from response headers.
config: Arc<RwLock<StreamableHttpTransportConfig>>,
// hyper client used for every request.
client: Client<
hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
Full<Bytes>,
>,
// Consumer side of the inbound message queue, drained by `receive`.
receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<TransportMessage>>>,
// Producer side of the inbound message queue, fed by the response handlers.
sender: mpsc::UnboundedSender<TransportMessage>,
// Value sent in the `MCP_PROTOCOL_VERSION` header; updated from response headers.
protocol_version: Arc<RwLock<Option<String>>>,
// Handle of the task spawned by the latest `start_sse`; aborted by the next `start_sse` or by `close`.
abort_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
// Id of the most recent SSE event that carried one.
last_event_id: Arc<RwLock<Option<String>>>,
}
/// `Debug` output covers the configuration, protocol version and last event
/// id; the client, channel ends and task handle are left out.
impl Debug for StreamableHttpTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StreamableHttpTransport");
        out.field("config", &self.config);
        out.field("protocol_version", &self.protocol_version);
        out.field("last_event_id", &self.last_event_id);
        out.finish()
    }
}
impl StreamableHttpTransport {
/// Creates a transport whose connector has HTTP/1.1 enabled only.
///
/// # Panics
/// Panics if the native root certificates cannot be loaded (see `new_internal`).
pub fn new(config: StreamableHttpTransportConfig) -> Self {
Self::new_internal(config, false)
}
/// Creates a transport whose connector has both HTTP/1.1 and HTTP/2 enabled.
///
/// # Panics
/// Panics if the native root certificates cannot be loaded (see `new_internal`).
pub fn new_with_http2(config: StreamableHttpTransportConfig) -> Self {
Self::new_internal(config, true)
}
/// Shared constructor: builds the HTTPS connector, the pooled hyper client and
/// the inbound message channel.
///
/// # Panics
/// Panics with "Failed to load native root certificates" when the native root
/// store cannot be loaded.
fn new_internal(config: StreamableHttpTransportConfig, enable_http2: bool) -> Self {
    // The result is deliberately discarded, so a failed install (for example
    // when a provider is already registered) is ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();

    if enable_http2 {
        tracing::debug!("Creating HTTPS connector with HTTP/1.1 and HTTP/2 support");
    } else {
        tracing::debug!("Creating HTTPS connector with HTTP/1.1 only");
    }

    // Both variants share everything up to HTTP/1.1; HTTP/2 is layered on top.
    let http1_builder = hyper_rustls::HttpsConnectorBuilder::new()
        .with_native_roots()
        .expect("Failed to load native root certificates")
        .https_or_http()
        .enable_http1();
    let connector = if enable_http2 {
        http1_builder.enable_http2().build()
    } else {
        http1_builder.build()
    };

    let idle_timeout = std::time::Duration::from_secs(90);
    let mut client_builder = Client::builder(TokioExecutor::new());
    client_builder
        .pool_idle_timeout(idle_timeout)
        .pool_max_idle_per_host(10);
    let client = client_builder.build(connector);

    let (sender, inbound) = mpsc::unbounded_channel();
    Self {
        config: Arc::new(RwLock::new(config)),
        client,
        receiver: Arc::new(tokio::sync::Mutex::new(inbound)),
        sender,
        protocol_version: Arc::new(RwLock::new(None)),
        abort_handle: Arc::new(RwLock::new(None)),
        last_event_id: Arc::new(RwLock::new(None)),
    }
}
/// Returns a copy of the session id currently stored in the configuration.
pub fn session_id(&self) -> Option<String> {
    let config = self.config.read();
    config.session_id.clone()
}
/// Replaces the stored session id (`None` clears it).
pub fn set_session_id(&self, session_id: Option<String>) {
    let mut config = self.config.write();
    config.session_id = session_id;
}
/// Returns a copy of the protocol version that is sent with requests, if any.
pub fn protocol_version(&self) -> Option<String> {
    let current = self.protocol_version.read();
    (*current).clone()
}
/// Replaces the stored protocol version (`None` clears it).
pub fn set_protocol_version(&self, version: Option<String>) {
    let mut slot = self.protocol_version.write();
    *slot = version;
}
/// Returns a copy of the id of the most recent SSE event that carried one.
pub fn last_event_id(&self) -> Option<String> {
    let latest = self.last_event_id.read();
    (*latest).clone()
}
/// Issues a `GET` for the SSE stream and queues the messages it contains.
///
/// Any task recorded by a previous call is aborted first. When
/// `resumption_token` is set it is sent in the `LAST_EVENT_ID` header.
/// A `405 Method Not Allowed` answer is treated as "no SSE support" and
/// returns `Ok(())` without spawning anything.
///
/// NOTE(review): the whole response body is collected before any event is
/// parsed, so this call only returns once the server has finished the response.
///
/// # Errors
/// Returns a transport error when a header value is invalid, the request
/// fails, the status is neither 405 nor a success, the body cannot be read,
/// or the response middleware rejects the response.
pub async fn start_sse(&self, resumption_token: Option<String>) -> Result<()> {
    // Only one SSE task is tracked at a time: cancel the previous one.
    let previous = self.abort_handle.write().take();
    if let Some(task) = previous {
        task.abort();
    }

    let url = self.config.read().url.clone();
    let mut request = self
        .build_request_with_middleware(Method::GET, url.as_str(), vec![])
        .await?;
    request.headers_mut().insert(
        ACCEPT,
        TEXT_EVENT_STREAM.parse().map_err(|e| {
            Error::Transport(TransportError::InvalidMessage(format!(
                "Invalid header: {}",
                e
            )))
        })?,
    );
    if let Some(token) = &resumption_token {
        request.headers_mut().insert(
            LAST_EVENT_ID,
            token.parse().map_err(|e| {
                Error::Transport(TransportError::InvalidMessage(format!(
                    "Invalid header: {}",
                    e
                )))
            })?,
        );
    }

    let response = self
        .client
        .request(request)
        .await
        .map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?;
    let status = response.status();
    if status == StatusCode::METHOD_NOT_ALLOWED {
        return Ok(());
    }
    if !status.is_success() {
        return Err(Error::Transport(TransportError::Request(format!(
            "SSE request failed with status: {}",
            status
        ))));
    }
    self.process_response_headers(&response);

    // Captured before the body is consumed so the middleware can see them.
    let response_headers = response.headers().clone();
    let body_bytes = response
        .collect()
        .await
        .map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?
        .to_bytes();

    let modified_body = if self.config.read().http_middleware_chain.is_some() {
        // The middleware gets the real status and headers of the response,
        // paired with an empty placeholder body (the bytes travel separately).
        let mut middleware_view = HyperResponse::new(Full::new(Bytes::new()));
        *middleware_view.status_mut() = status;
        *middleware_view.headers_mut() = response_headers;
        self.apply_response_middleware(
            "GET",
            url.as_str(),
            &middleware_view,
            body_bytes.to_vec(),
        )
        .await?
    } else {
        body_bytes.to_vec()
    };

    let sender = self.sender.clone();
    let on_resumption = self.config.read().on_resumption_token.clone();
    let last_event_id = self.last_event_id.clone();
    let handle = tokio::spawn(async move {
        let mut sse_parser = SseParser::new();
        let body = String::from_utf8_lossy(&modified_body);
        for event in sse_parser.feed(&body) {
            if let Some(id) = &event.id {
                *last_event_id.write() = Some(id.clone());
                if let Some(callback) = &on_resumption {
                    callback(id.clone());
                }
            }
            // Events without a type are handled like `message` events; events
            // that fail to parse or to enqueue are silently skipped.
            if event.event.as_deref() == Some("message") || event.event.is_none() {
                if let Ok(msg) =
                    crate::shared::StdioTransport::parse_message(event.data.as_bytes())
                {
                    let _ = sender.send(msg);
                }
            }
        }
    });
    *self.abort_handle.write() = Some(handle);
    Ok(())
}
async fn build_request_with_middleware(
&self,
method: Method,
url: &str,
body: Vec<u8>,
) -> Result<Request<Full<Bytes>>> {
use crate::client::http_middleware::{HttpMiddlewareContext, HttpRequest};
let (extra_headers, auth_provider, session_id, middleware_chain) = {
let config = self.config.read();
(
config.extra_headers.clone(),
config.auth_provider.clone(),
config.session_id.clone(),
config.http_middleware_chain.clone(),
)
};
let mut request_builder = Request::builder().method(method.clone()).uri(url);
for (key, value) in &extra_headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
let has_auth = if let Some(auth_provider) = auth_provider {
let token = auth_provider.get_access_token().await?;
request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
true
} else {
false
};
if let Some(session_id) = &session_id {
request_builder = request_builder.header(MCP_SESSION_ID, session_id.as_str());
}
if let Some(protocol_version) = self.protocol_version.read().as_ref() {
request_builder =
request_builder.header(MCP_PROTOCOL_VERSION, protocol_version.as_str());
}
let temp_req = request_builder
.body(Full::new(Bytes::from(body.clone())))
.map_err(|e| Error::Transport(TransportError::InvalidMessage(e.to_string())))?;
let headers = temp_req.headers();
if let Some(chain) = middleware_chain {
let mut http_req = HttpRequest::new(method.as_str().to_string(), url.to_string(), body);
for (key, value) in headers {
if let Ok(value_str) = value.to_str() {
http_req.add_header(key.as_str(), value_str);
}
}
let context = HttpMiddlewareContext::new(url.to_string(), method.as_str().to_string());
if has_auth {
context.set_metadata("auth_already_set".to_string(), "true".to_string());
}
if let Err(e) = chain.process_request(&mut http_req, &context).await {
chain.handle_transport_error(&e, &context).await;
return Err(e);
}
let mut final_builder = Request::builder().method(method).uri(url);
for (key, value) in &http_req.headers {
final_builder = final_builder.header(key, value);
}
final_builder
.body(Full::new(Bytes::from(http_req.body)))
.map_err(|e| Error::Transport(TransportError::InvalidMessage(e.to_string())))
} else {
Ok(temp_req)
}
}
/// Runs the configured middleware chain over a buffered response body.
///
/// `response` supplies the status and headers shown to the middleware and
/// `body` the payload. Returns the body as left by the chain, or `body`
/// unchanged when no chain is configured.
///
/// # Errors
/// Returns the chain's error after notifying `handle_transport_error`.
#[allow(clippy::future_not_send)]
async fn apply_response_middleware(
    &self,
    method: &str,
    url: &str,
    response: &HyperResponse<impl hyper::body::Body>,
    body: Vec<u8>,
) -> Result<Vec<u8>> {
    use crate::client::http_middleware::{HttpMiddlewareContext, HttpResponse};

    let configured = self.config.read().http_middleware_chain.clone();
    let chain = match configured {
        Some(chain) => chain,
        None => return Ok(body),
    };

    let status = response.status().as_u16();
    let headers = response.headers().clone();
    let mut http_resp = HttpResponse::with_headers(status, headers, body);
    let context = HttpMiddlewareContext::new(url.to_string(), method.to_string());

    let outcome = chain.process_response(&mut http_resp, &context).await;
    if let Err(e) = outcome {
        chain.handle_transport_error(&e, &context).await;
        return Err(e);
    }
    Ok(http_resp.body)
}
/// Copies the session id and protocol version from `response` headers into
/// the transport state. Absent or non-textual header values leave the stored
/// value untouched.
fn process_response_headers(&self, response: &HyperResponse<impl hyper::body::Body>) {
    let headers = response.headers();

    let session = headers
        .get(MCP_SESSION_ID)
        .and_then(|value| value.to_str().ok());
    if let Some(session) = session {
        self.config.write().session_id = Some(session.to_string());
    }

    let version = headers
        .get(MCP_PROTOCOL_VERSION)
        .and_then(|value| value.to_str().ok());
    if let Some(version) = version {
        *self.protocol_version.write() = Some(version.to_string());
    }
}
pub async fn send_with_options(
&mut self,
message: TransportMessage,
options: SendOptions,
) -> Result<()> {
if let Some(token) = options.resumption_token {
self.start_sse(Some(token)).await?;
return Ok(());
}
let body_bytes = crate::shared::StdioTransport::serialize_message(&message)?;
let body_bytes_snapshot = body_bytes.clone();
let url = self.config.read().url.clone();
let mut request = self
.build_request_with_middleware(Method::POST, url.as_str(), body_bytes)
.await?;
request.headers_mut().insert(
CONTENT_TYPE,
APPLICATION_JSON.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
request.headers_mut().insert(
ACCEPT,
ACCEPT_STREAMABLE.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
let response = self
.client
.request(request)
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?;
let retried_unauthorized = response.status() == StatusCode::UNAUTHORIZED;
let response = if retried_unauthorized {
let auth_provider = self.config.read().auth_provider.clone();
if let Some(provider) = auth_provider {
provider.on_unauthorized().await?;
let mut retry_request = self
.build_request_with_middleware(Method::POST, url.as_str(), body_bytes_snapshot)
.await?;
retry_request.headers_mut().insert(
CONTENT_TYPE,
APPLICATION_JSON.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
retry_request.headers_mut().insert(
ACCEPT,
ACCEPT_STREAMABLE.parse().map_err(|e| {
Error::Transport(TransportError::InvalidMessage(format!(
"Invalid header: {}",
e
)))
})?,
);
self.client
.request(retry_request)
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?
} else {
response
}
} else {
response
};
self.process_response_headers(&response);
if !response.status().is_success() {
if response.status() == StatusCode::ACCEPTED {
if matches!(message, TransportMessage::Notification { .. }) {
let _ = self.start_sse(None).await;
}
return Ok(());
}
return Err(Error::Transport(TransportError::Request(format!(
"Request failed with status: {}",
response.status()
))));
}
let status_code = response.status();
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let content_length = response
.headers()
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<usize>().ok());
let body_bytes = response
.collect()
.await
.map_err(|e| Error::Transport(TransportError::Request(e.to_string())))?
.to_bytes();
tracing::debug!(
status = %status_code,
content_type = %content_type,
content_length = ?content_length,
body_len = body_bytes.len(),
"HTTP response received"
);
let modified_body = if self.config.read().http_middleware_chain.is_some() {
let temp_response = HyperResponse::builder()
.status(status_code)
.body(Full::new(Bytes::new()))
.unwrap();
self.apply_response_middleware(
"POST",
url.as_str(),
&temp_response,
body_bytes.to_vec(),
)
.await?
} else {
body_bytes.to_vec()
};
if status_code == StatusCode::OK && (content_length == Some(0) || content_type.is_empty()) {
if modified_body.is_empty() {
return Ok(());
}
if content_type.is_empty() {
return Err(Error::Transport(TransportError::Request(
"Response has body but no Content-Type header".to_string(),
)));
}
if let Ok(batch) = serde_json::from_slice::<Vec<serde_json::Value>>(&modified_body) {
for json_msg in batch {
let json_str = serde_json::to_string(&json_msg).map_err(|e| {
Error::Transport(TransportError::Deserialization(e.to_string()))
})?;
let msg = crate::shared::StdioTransport::parse_message(json_str.as_bytes())?;
self.sender
.send(msg)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
} else {
let msg_parsed = crate::shared::StdioTransport::parse_message(&modified_body)?;
self.sender
.send(msg_parsed)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
return Ok(());
}
if content_type.contains(APPLICATION_JSON) {
if modified_body.is_empty() {
if status_code == StatusCode::ACCEPTED {
tracing::debug!(
status = %status_code,
"Notification acknowledged with 202 Accepted"
);
return Ok(());
}
tracing::warn!(
status = %status_code,
content_type = %content_type,
"Server returned empty body with application/json content type"
);
return Err(Error::Transport(TransportError::Request(
"Server returned empty response body with Content-Type: application/json. \
This may indicate a server error or network issue."
.to_string(),
)));
}
if let Ok(batch) = serde_json::from_slice::<Vec<serde_json::Value>>(&modified_body) {
for json_msg in batch {
let json_str = serde_json::to_string(&json_msg).map_err(|e| {
Error::Transport(TransportError::Deserialization(e.to_string()))
})?;
let msg = crate::shared::StdioTransport::parse_message(json_str.as_bytes())?;
self.sender
.send(msg)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
} else {
let msg_parsed = crate::shared::StdioTransport::parse_message(&modified_body)?;
self.sender
.send(msg_parsed)
.map_err(|e| Error::Transport(TransportError::Send(e.to_string())))?;
}
} else if content_type.contains(TEXT_EVENT_STREAM) {
let sender = self.sender.clone();
let on_resumption = self.config.read().on_resumption_token.clone();
let last_event_id = self.last_event_id.clone();
tokio::spawn(async move {
let mut sse_parser = SseParser::new();
let body = String::from_utf8_lossy(&modified_body);
let events = sse_parser.feed(&body);
for event in events {
if let Some(id) = &event.id {
*last_event_id.write() = Some(id.clone());
if let Some(callback) = &on_resumption {
callback(id.clone());
}
}
if event.event.as_deref() == Some("message") || event.event.is_none() {
if let Ok(msg) =
crate::shared::StdioTransport::parse_message(event.data.as_bytes())
{
let _ = sender.send(msg);
}
}
}
});
} else if status_code == StatusCode::ACCEPTED {
return Ok(());
} else {
return Err(Error::Transport(TransportError::Request(format!(
"Unsupported content type: {}",
content_type
))));
}
Ok(())
}
}
#[async_trait]
impl Transport for StreamableHttpTransport {
async fn send(&mut self, message: TransportMessage) -> Result<()> {
self.send_with_options(message, SendOptions::default())
.await
}
async fn receive(&mut self) -> Result<TransportMessage> {
let mut receiver = self.receiver.lock().await;
receiver
.recv()
.await
.ok_or_else(|| Error::Transport(TransportError::ConnectionClosed))
}
async fn close(&mut self) -> Result<()> {
let handle = self.abort_handle.write().take();
if let Some(handle) = handle {
handle.abort();
}
if let Some(_session_id) = self.session_id() {
let url = self.config.read().url.clone();
let request = self
.build_request_with_middleware(Method::DELETE, url.as_str(), vec![])
.await?;
let response = self.client.request(request).await;
if let Ok(resp) = response {
if !resp.status().is_success() && resp.status() != StatusCode::METHOD_NOT_ALLOWED {
tracing::warn!("Failed to terminate session: {}", resp.status());
}
}
self.config.write().session_id = None;
}
Ok(())
}
fn is_connected(&self) -> bool {
true
}
}
/// Source of bearer tokens for `StreamableHttpTransport`.
#[async_trait]
pub trait AuthProvider: Send + Sync + Debug {
/// Returns the token placed in the `Authorization: Bearer` header.
/// Called each time the transport builds a request.
async fn get_access_token(&self) -> Result<String>;
/// Hook invoked when a POST is answered with `401 Unauthorized`, before the
/// transport rebuilds the request for its single retry. An `Err` aborts the
/// send. The default implementation does nothing and returns `Ok(())`.
async fn on_unauthorized(&self) -> Result<()> {
Ok(())
}
}
#[cfg(all(test, not(target_arch = "wasm32"), feature = "streamable-http"))]
mod tests {
use super::*;
use crate::shared::TransportMessage;
use mockito::Server as MockServer;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex as StdMutex;
use url::Url;
/// Test double for `AuthProvider` that counts calls and can record their order.
#[derive(Debug)]
struct CountingProvider {
// Token returned by every `get_access_token` call.
token: String,
// Number of `get_access_token` calls.
get_count: AtomicUsize,
// Number of `on_unauthorized` calls.
unauthorized_count: AtomicUsize,
// When `Some`, the name of each called method is appended in call order.
call_order: Option<StdMutex<Vec<&'static str>>>,
}
impl CountingProvider {
    /// Provider with zeroed counters and no call-order log.
    fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self {
            token,
            get_count: AtomicUsize::new(0),
            unauthorized_count: AtomicUsize::new(0),
            call_order: None,
        }
    }

    /// Same as `new`, but with an (initially empty) call-order log enabled.
    fn with_order_tracking(token: impl Into<String>) -> Self {
        Self {
            call_order: Some(StdMutex::new(Vec::new())),
            ..Self::new(token)
        }
    }
}
#[async_trait]
impl AuthProvider for CountingProvider {
    /// Counts the call, logs it when order tracking is on, and returns the fixed token.
    async fn get_access_token(&self) -> Result<String> {
        self.get_count.fetch_add(1, Ordering::SeqCst);
        if let Some(order) = self.call_order.as_ref() {
            let mut log = order.lock().unwrap();
            log.push("get_access_token");
        }
        Ok(self.token.clone())
    }

    /// Counts the call and logs it when order tracking is on.
    async fn on_unauthorized(&self) -> Result<()> {
        self.unauthorized_count.fetch_add(1, Ordering::SeqCst);
        if let Some(order) = self.call_order.as_ref() {
            let mut log = order.lock().unwrap();
            log.push("on_unauthorized");
        }
        Ok(())
    }
}
/// Builds an HTTP/1.1 transport for `url`, attaching `provider` when given.
fn make_transport(
    url: Url,
    provider: Option<Arc<dyn AuthProvider>>,
) -> StreamableHttpTransport {
    let base = StreamableHttpTransportConfigBuilder::new(url);
    let builder = match provider {
        Some(p) => base.with_auth_provider(p),
        None => base,
    };
    StreamableHttpTransport::new(builder.build())
}
/// Returns the notification the tests send as a simple outbound message.
/// NOTE(review): despite the name, this builds a client `Initialized`
/// notification rather than a ping.
fn ping_message() -> TransportMessage {
use crate::types::{ClientNotification, Notification};
TransportMessage::Notification(Notification::Client(ClientNotification::Initialized))
}
/// Returns a `ListTools` client request with id 42 and no cursor.
fn list_tools_message() -> TransportMessage {
use crate::types::{ClientRequest, ListToolsRequest, Request, RequestId};
TransportMessage::Request {
id: RequestId::from(42i64),
request: Request::Client(Box::new(ClientRequest::ListTools(ListToolsRequest {
cursor: None,
}))),
}
}
/// A provider that implements only `get_access_token` must compile and
/// inherit an `on_unauthorized` that returns `Ok(())`.
#[tokio::test]
async fn test_on_unauthorized_default_noop_compiles_and_succeeds() {
    #[derive(Debug)]
    struct MinimalProvider;

    #[async_trait]
    impl AuthProvider for MinimalProvider {
        async fn get_access_token(&self) -> Result<String> {
            Ok("token".to_string())
        }
    }

    let outcome = MinimalProvider.on_unauthorized().await;
    assert!(
        outcome.is_ok(),
        "default on_unauthorized should return Ok(())"
    );
}
/// Against a server that always answers 401, the transport must retry exactly
/// once: two POSTs on the wire, one `on_unauthorized` call, two token fetches,
/// and an error returned to the caller.
#[tokio::test]
async fn test_max_one_retry_on_401() {
    let mut server = MockServer::new_async().await;
    let mock = server
        .mock("POST", "/")
        .with_status(401)
        .with_header("content-type", "application/json")
        .with_body(r#"{"error":"unauthorized"}"#)
        .expect(2)
        .create_async()
        .await;
    let url = Url::parse(&server.url()).unwrap();
    let provider = Arc::new(CountingProvider::new("initial-token"));
    let mut transport = make_transport(url, Some(provider.clone() as Arc<dyn AuthProvider>));

    let result = transport
        .send_with_options(ping_message(), SendOptions::default())
        .await;

    assert!(
        result.is_err(),
        "a second 401 must be reported as an error"
    );
    assert_eq!(
        provider.unauthorized_count.load(Ordering::SeqCst),
        1,
        "on_unauthorized should be called exactly once"
    );
    assert_eq!(
        provider.get_count.load(Ordering::SeqCst),
        2,
        "get_access_token should be called twice (once per attempt)"
    );
    // `expect(2)` is only enforced when the mock is asserted explicitly.
    mock.assert_async().await;
}
/// `on_unauthorized` must stay untouched for responses other than 401; checked
/// for a 200 and a 500, each against its own mock server and provider.
#[tokio::test]
async fn test_on_unauthorized_not_called_for_non_401() {
    // (status, response body, assertion message)
    let cases: [(usize, &str, &str); 2] = [
        (
            200,
            r#"{"jsonrpc":"2.0","id":1,"result":{}}"#,
            "on_unauthorized must NOT be called on 200",
        ),
        (
            500,
            r#"{"error":"server error"}"#,
            "on_unauthorized must NOT be called on 500",
        ),
    ];

    for &(status, body, failure_message) in cases.iter() {
        let mut server = MockServer::new_async().await;
        let _mock = server
            .mock("POST", "/")
            .with_status(status)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await;
        let url = Url::parse(&server.url()).unwrap();
        let provider = Arc::new(CountingProvider::new("token"));
        let mut transport =
            make_transport(url, Some(provider.clone() as Arc<dyn AuthProvider>));

        // The send result itself is irrelevant here; only the hook count matters.
        let _ = transport
            .send_with_options(ping_message(), SendOptions::default())
            .await;

        assert_eq!(
            provider.unauthorized_count.load(Ordering::SeqCst),
            0,
            "{}",
            failure_message
        );
    }
}
/// Verifies that the 401 retry resends the same method and body while the
/// `Authorization` header carries the token from the second provider call.
///
/// Uses a hand-rolled hyper server (instead of a mock) so each request's
/// method, body bytes and `Authorization` header can be captured.
#[tokio::test]
async fn test_retry_body_and_headers_are_byte_identical() {
use hyper::service::service_fn;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder as ServerBuilder;
use std::sync::Mutex as StdMutex;
use tokio::net::TcpListener;
// One `(method, body, authorization)` tuple per request, in arrival order.
#[derive(Debug, Default)]
struct Captured {
requests: Vec<(String, Vec<u8>, String)>,
}
// Hands out a different token on the first call than on later calls and
// keeps the default no-op `on_unauthorized`.
#[derive(Debug)]
struct DualTokenProvider {
call_count: AtomicUsize,
}
#[async_trait]
impl AuthProvider for DualTokenProvider {
async fn get_access_token(&self) -> Result<String> {
let n = self.call_count.fetch_add(1, Ordering::SeqCst);
if n == 0 {
Ok("token-attempt-1".to_string())
} else {
Ok("token-attempt-2".to_string())
}
}
}
let captured = Arc::new(StdMutex::new(Captured::default()));
let captured_clone = captured.clone();
// Port 0: bind to any free local port.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let cap = captured_clone.clone();
// Accept loop: serves each connection on its own task and stops accepting
// after two connections.
tokio::spawn(async move {
let mut attempt = 0u8;
loop {
let (stream, _) = listener.accept().await.unwrap();
let cap = cap.clone();
let io = hyper_util::rt::TokioIo::new(stream);
tokio::spawn(async move {
let _ = ServerBuilder::new(TokioExecutor::new())
.serve_connection(
io,
service_fn(move |req: Request<hyper::body::Incoming>| {
let cap = cap.clone();
async move {
let method = req.method().to_string();
let auth = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let body_bytes = req
.collect()
.await
.map(|b| b.to_bytes().to_vec())
.unwrap_or_default();
cap.lock()
.unwrap()
.requests
.push((method, body_bytes, auth));
// The first recorded request is answered with 401, every later one with 200.
let status = {
let len = cap.lock().unwrap().requests.len();
if len == 1 {
401u16
} else {
200u16
}
};
Ok::<_, hyper::Error>(
HyperResponse::builder()
.status(status)
.header("content-type", "application/json")
.body(Full::new(Bytes::from(if status == 200 {
r#"{"jsonrpc":"2.0","id":1,"result":{}}"#
} else {
r#"{"error":"unauthorized"}"#
})))
.unwrap(),
)
}
}),
)
.await;
});
attempt += 1;
if attempt >= 2 {
break;
}
}
});
let provider = Arc::new(DualTokenProvider {
call_count: AtomicUsize::new(0),
});
let url = Url::parse(&format!("http://127.0.0.1:{}", addr.port())).unwrap();
let mut transport = make_transport(url, Some(provider as Arc<dyn AuthProvider>));
// The send result is not inspected; the assertions below look at what the server saw.
let _ = transport
.send_with_options(list_tools_message(), SendOptions::default())
.await;
let cap = captured.lock().unwrap();
assert_eq!(
cap.requests.len(),
2,
"expected exactly 2 requests (original + retry)"
);
let (method1, body1, auth1) = &cap.requests[0];
let (method2, body2, auth2) = &cap.requests[1];
assert_eq!(
method1, method2,
"method must be byte-identical across retry"
);
assert_eq!(body1, body2, "body must be byte-identical across retry");
assert_ne!(
auth1, auth2,
"Authorization header should differ (new token)"
);
assert!(auth1.contains("token-attempt-1"), "first auth: {}", auth1);
assert!(auth2.contains("token-attempt-2"), "retry auth: {}", auth2);
}
/// On a 401 the provider must be notified through `on_unauthorized` before
/// the retry asks it for a token again.
#[tokio::test]
async fn test_on_unauthorized_called_before_get_access_token_on_retry() {
    let mut server = MockServer::new_async().await;
    let _m = server
        .mock("POST", "/")
        .with_status(401)
        .with_header("content-type", "application/json")
        .with_body(r#"{"error":"unauthorized"}"#)
        .expect(2)
        .create_async()
        .await;
    let url = Url::parse(&server.url()).unwrap();
    let provider = Arc::new(CountingProvider::with_order_tracking("token"));
    let mut transport = make_transport(url, Some(provider.clone() as Arc<dyn AuthProvider>));

    let _ = transport
        .send_with_options(ping_message(), SendOptions::default())
        .await;

    // Snapshot the log so the mutex is not held while asserting.
    let order: Vec<&'static str> = {
        let log = provider.call_order.as_ref().unwrap().lock().unwrap();
        log.clone()
    };
    assert!(
        order.len() >= 3,
        "expected at least 3 calls, got {:?}",
        order
    );

    let unauth_pos = order
        .iter()
        .position(|&s| s == "on_unauthorized")
        .expect("on_unauthorized must appear in call order");
    let token_fetched_afterwards = order[unauth_pos + 1..]
        .iter()
        .any(|&s| s == "get_access_token");
    assert!(
        token_fetched_afterwards,
        "get_access_token must be called AFTER on_unauthorized; order = {:?}",
        order
    );
}
}