use std::{borrow::Cow, collections::HashMap, sync::Arc};
use futures::{StreamExt, stream::BoxStream};
use http::{HeaderName, HeaderValue, header::WWW_AUTHENTICATE};
use reqwest::header::ACCEPT;
use sse_stream::{Sse, SseStream};
use crate::{
model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
transport::{
common::http_header::{
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
HEADER_SESSION_ID, JSON_MIME_TYPE,
},
streamable_http_client::*,
},
};
impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
fn from(e: reqwest::Error) -> Self {
StreamableHttpError::Client(e)
}
}
/// Header names that the transport sets itself and that `custom_headers` may not
/// supply (compared ASCII case-insensitively by [`apply_custom_headers`]).
const RESERVED_HEADERS: &[&str] = &[
    "accept",
    HEADER_SESSION_ID,
    HEADER_MCP_PROTOCOL_VERSION,
    HEADER_LAST_EVENT_ID,
];

/// Adds every entry of `custom_headers` to `builder`, guarding the reserved names.
///
/// `HEADER_MCP_PROTOCOL_VERSION` is listed in [`RESERVED_HEADERS`] but is still let
/// through. Every other reserved name is rejected.
///
/// # Errors
/// Returns `StreamableHttpError::ReservedHeaderConflict` carrying the header name for
/// the first reserved, non-pass-through header met while iterating `custom_headers`.
fn apply_custom_headers(
    builder: reqwest::RequestBuilder,
    custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<reqwest::RequestBuilder, StreamableHttpError<reqwest::Error>> {
    custom_headers
        .into_iter()
        .try_fold(builder, |acc, (name, value)| {
            let key = name.as_str();
            let is_reserved = RESERVED_HEADERS
                .iter()
                .any(|reserved| key.eq_ignore_ascii_case(reserved));
            let is_passthrough = key.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION);
            if is_reserved && !is_passthrough {
                return Err(StreamableHttpError::ReservedHeaderConflict(
                    name.to_string(),
                ));
            }
            Ok(acc.header(name, value))
        })
}
impl StreamableHttpClient for reqwest::Client {
type Error = reqwest::Error;
async fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<BoxStream<'static, Result<Sse, SseError>>, StreamableHttpError<Self::Error>> {
let mut request_builder = self
.get(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "))
.header(HEADER_SESSION_ID, session_id.as_ref());
if let Some(last_event_id) = last_event_id {
request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id);
}
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
request_builder = apply_custom_headers(request_builder, custom_headers)?;
let response = request_builder.send().await?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
return Err(StreamableHttpError::ServerDoesNotSupportSse);
}
let response = response.error_for_status()?;
match response.headers().get(reqwest::header::CONTENT_TYPE) {
Some(ct) => {
if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes())
&& !ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes())
{
return Err(StreamableHttpError::UnexpectedContentType(Some(
String::from_utf8_lossy(ct.as_bytes()).to_string(),
)));
}
}
None => {
return Err(StreamableHttpError::UnexpectedContentType(None));
}
}
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
Ok(event_stream)
}
async fn delete_session(
&self,
uri: Arc<str>,
session: Arc<str>,
auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<(), StreamableHttpError<Self::Error>> {
let mut request_builder = self.delete(uri.as_ref());
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
request_builder = request_builder.header(HEADER_SESSION_ID, session.as_ref());
request_builder = apply_custom_headers(request_builder, custom_headers)?;
let response = request_builder.send().await?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
tracing::debug!("this server doesn't support deleting session");
return Ok(());
}
let _response = response.error_for_status()?;
Ok(())
}
async fn post_message(
&self,
uri: Arc<str>,
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let mut request = self
.post(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
if let Some(auth_header) = auth_token {
request = request.bearer_auth(auth_header);
}
request = apply_custom_headers(request, custom_headers)?;
if let Some(session_id) = session_id {
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
}
let response = request.json(&message).send().await?;
if response.status() == reqwest::StatusCode::UNAUTHORIZED {
if let Some(header) = response.headers().get(WWW_AUTHENTICATE) {
let header = header
.to_str()
.map_err(|_| {
StreamableHttpError::UnexpectedServerResponse(Cow::from(
"invalid www-authenticate header value",
))
})?
.to_string();
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
www_authenticate_header: header,
}));
}
}
if response.status() == reqwest::StatusCode::FORBIDDEN {
if let Some(header) = response.headers().get(WWW_AUTHENTICATE) {
let header_str = header.to_str().map_err(|_| {
StreamableHttpError::UnexpectedServerResponse(Cow::from(
"invalid www-authenticate header value",
))
})?;
let scope = extract_scope_from_header(header_str);
return Err(StreamableHttpError::InsufficientScope(
InsufficientScopeError {
www_authenticate_header: header_str.to_string(),
required_scope: scope,
},
));
}
}
let status = response.status();
if matches!(
status,
reqwest::StatusCode::ACCEPTED | reqwest::StatusCode::NO_CONTENT
) {
return Ok(StreamableHttpPostResponse::Accepted);
}
if !status.is_success() {
let body = response
.text()
.await
.unwrap_or_else(|_| "<failed to read response body>".to_owned());
return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned(
format!("HTTP {status}: {body}"),
)));
}
let content_type = response.headers().get(reqwest::header::CONTENT_TYPE);
let session_id = response.headers().get(HEADER_SESSION_ID);
let session_id = session_id
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
match content_type {
Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => {
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
Ok(StreamableHttpPostResponse::Sse(event_stream, session_id))
}
Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {
match response.json::<ServerJsonRpcMessage>().await {
Ok(message) => Ok(StreamableHttpPostResponse::Json(message, session_id)),
Err(e) => {
tracing::warn!(
"could not parse JSON response as ServerJsonRpcMessage, treating as accepted: {e}"
);
Ok(StreamableHttpPostResponse::Accepted)
}
}
}
_ => {
tracing::error!("unexpected content type: {:?}", content_type);
Err(StreamableHttpError::UnexpectedContentType(
content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()),
))
}
}
}
}
impl StreamableHttpClientTransport<reqwest::Client> {
    /// Builds a transport for `uri` backed by a default [`reqwest::Client`].
    ///
    /// All other settings come from `StreamableHttpClientTransportConfig::default()`,
    /// with `auth_header` explicitly set to `None`.
    pub fn from_uri(uri: impl Into<Arc<str>>) -> Self {
        let config = StreamableHttpClientTransportConfig {
            uri: uri.into(),
            auth_header: None,
            ..Default::default()
        };
        Self::from_config(config)
    }

    /// Builds a transport from an explicit `config`, backed by a default [`reqwest::Client`].
    pub fn from_config(config: StreamableHttpClientTransportConfig) -> Self {
        Self::with_client(reqwest::Client::default(), config)
    }
}
/// Extracts the `scope` auth-param from a `WWW-Authenticate` header value.
///
/// The value is scanned as a list of challenges and auth-params (RFC 9110 §11.6.1):
/// - Bare tokens, such as the `Bearer` scheme name, are skipped.
/// - Each `name = value` pair is read, allowing optional whitespace around `=`.
/// - A value is either a quoted-string, with `\` escapes resolved, or an unquoted run
///   ending at `,`, `;` or whitespace.
///
/// Only a parameter whose *name* is `scope` (ASCII case-insensitive) matches. So
/// `required_scope=...`, or a `scope=` inside another parameter's quoted value, is not
/// mistaken for it.
///
/// Returns the first `scope` value found. Returns `None` when there is no `scope`
/// parameter, its value is empty, or a quoted-string is unterminated.
fn extract_scope_from_header(header: &str) -> Option<String> {
    let bytes = header.as_bytes();
    let mut pos = 0;
    while pos < bytes.len() {
        // Skip list separators and whitespace between items.
        while pos < bytes.len()
            && (bytes[pos].is_ascii_whitespace() || matches!(bytes[pos], b',' | b';'))
        {
            pos += 1;
        }
        // Read a token: either an auth-scheme name or an auth-param name. Every delimiter
        // is ASCII, so `pos` always stops on a UTF-8 char boundary.
        let name_start = pos;
        while pos < bytes.len()
            && !bytes[pos].is_ascii_whitespace()
            && !matches!(bytes[pos], b'=' | b',' | b';' | b'"')
        {
            pos += 1;
        }
        let name = &header[name_start..pos];
        // The token names a parameter only if `=` follows, after optional whitespace.
        // Otherwise it is a scheme name such as `Bearer`.
        let mut eq = pos;
        while eq < bytes.len() && bytes[eq].is_ascii_whitespace() {
            eq += 1;
        }
        if bytes.get(eq) != Some(&b'=') {
            if pos == name_start {
                // Stray delimiter (e.g. an unmatched `"`): step over it so the scan progresses.
                pos += 1;
            }
            continue;
        }
        pos = eq + 1;
        while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        let value = if bytes.get(pos) == Some(&b'"') {
            let (value, consumed) = read_quoted_string(&header[pos + 1..])?;
            pos += 1 + consumed;
            value
        } else {
            let rest = &header[pos..];
            let end = rest
                .find(|c: char| c == ',' || c == ';' || c.is_whitespace())
                .unwrap_or(rest.len());
            pos += end;
            rest[..end].to_string()
        };
        if name.eq_ignore_ascii_case("scope") {
            return (!value.is_empty()).then_some(value);
        }
    }
    None
}

/// Reads the body of a quoted-string whose opening `"` has already been consumed.
///
/// Returns the unescaped contents, plus the number of bytes of `s` consumed including the
/// closing `"`. Returns `None` if `s` ends before the closing quote or right after a
/// dangling `\`.
fn read_quoted_string(s: &str) -> Option<(String, usize)> {
    let mut out = String::new();
    let mut chars = s.char_indices();
    while let Some((idx, c)) = chars.next() {
        match c {
            '"' => return Some((out, idx + 1)),
            // quoted-pair: a backslash escapes the following character.
            '\\' => out.push(chars.next()?.1),
            _ => out.push(c),
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::extract_scope_from_header;
    use crate::transport::streamable_http_client::InsufficientScopeError;

    /// A quoted `scope` value is returned verbatim, inner spaces included.
    #[test]
    fn extract_scope_quoted() {
        let scope = extract_scope_from_header(
            r#"Bearer error="insufficient_scope", scope="files:read files:write""#,
        );
        assert_eq!(scope.as_deref(), Some("files:read files:write"));
    }

    /// An unquoted `scope` value ends at the following comma.
    #[test]
    fn extract_scope_unquoted() {
        let scope =
            extract_scope_from_header(r#"Bearer scope=read:data, error="insufficient_scope""#);
        assert_eq!(scope.as_deref(), Some("read:data"));
    }

    /// A challenge without any `scope` parameter yields `None`.
    #[test]
    fn extract_scope_missing() {
        assert!(extract_scope_from_header(r#"Bearer error="invalid_token""#).is_none());
    }

    /// A bare scheme with no parameters yields `None`.
    #[test]
    fn extract_scope_empty_header() {
        assert!(extract_scope_from_header("Bearer").is_none());
    }

    /// `can_upgrade` and `get_required_scope` follow whether `required_scope` is set.
    #[test]
    fn insufficient_scope_error_can_upgrade() {
        let scoped = InsufficientScopeError {
            www_authenticate_header: "Bearer scope=\"admin\"".to_string(),
            required_scope: Some("admin".to_string()),
        };
        assert!(scoped.can_upgrade());
        assert_eq!(scoped.get_required_scope(), Some("admin"));

        let unscoped = InsufficientScopeError {
            www_authenticate_header: "Bearer error=\"insufficient_scope\"".to_string(),
            required_scope: None,
        };
        assert!(!unscoped.can_upgrade());
        assert_eq!(unscoped.get_required_scope(), None);
    }
}