use futures::stream::BoxStream;
use futures::StreamExt;
use http::header::WWW_AUTHENTICATE;
use reqwest::header::ACCEPT;
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
use rmcp::transport::common::http_header::{
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
};
use rmcp::transport::streamable_http_client::{
AuthRequiredError, StreamableHttpClient, StreamableHttpError, StreamableHttpPostResponse,
};
use sse_stream::{Error as SseError, Sse, SseStream};
use std::sync::Arc;
use systemprompt_models::RequestContext;
use systemprompt_traits::ContextPropagation;
/// HTTP client wrapper that carries a [`RequestContext`] alongside a
/// `reqwest::Client`, so that context-derived headers can be attached to the
/// requests it sends.
#[derive(Debug, Clone)]
pub struct HttpClientWithContext {
    /// Underlying HTTP client used to issue every request.
    client: reqwest::Client,
    /// Context whose propagated headers and auth token are added to requests.
    context: RequestContext,
}
impl HttpClientWithContext {
    /// Creates a client bound to `context`.
    ///
    /// The underlying `reqwest::Client` is configured with a 600 s request
    /// timeout, a 30 s connect timeout and a 60 s TCP keepalive. If the
    /// builder fails, the default `reqwest::Client` is used instead.
    pub fn new(context: RequestContext) -> Self {
        use std::time::Duration;

        // NOTE(review): a builder error is discarded here, and the fallback
        // client does not carry the timeouts configured above.
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(600))
            .connect_timeout(Duration::from_secs(30))
            .tcp_keepalive(Some(Duration::from_secs(60)))
            .build()
            .unwrap_or_default();

        Self { client, context }
    }

    /// Returns `builder` with every header produced by the context added,
    /// plus an `Authorization: Bearer …` header when the context's auth token
    /// is non-empty.
    fn add_context_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        let context_headers = self.context.to_headers();
        let builder = (&context_headers)
            .into_iter()
            .fold(builder, |acc, (key, value)| acc.header(key, value));

        let token = self.context.auth_token();
        if token.as_str().is_empty() {
            return builder;
        }

        // NOTE(review): callers may also apply their own bearer token after
        // this call; confirm whether two `Authorization` values can result.
        builder.header("Authorization", format!("Bearer {}", token.as_str()))
    }
}
impl StreamableHttpClient for HttpClientWithContext {
type Error = reqwest::Error;
async fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_token: Option<String>,
) -> Result<BoxStream<'static, Result<Sse, SseError>>, StreamableHttpError<Self::Error>> {
let mut request_builder = self
.client
.get(uri.as_ref())
.header(ACCEPT, EVENT_STREAM_MIME_TYPE)
.header(HEADER_SESSION_ID, session_id.as_ref());
request_builder = self.add_context_headers(request_builder);
if let Some(last_event_id) = last_event_id {
request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id);
}
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
let response = request_builder.send().await?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
return Err(StreamableHttpError::ServerDoesNotSupportSse);
}
let response = response.error_for_status()?;
match response.headers().get(reqwest::header::CONTENT_TYPE) {
Some(ct) => {
if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) {
return Err(StreamableHttpError::UnexpectedContentType(Some(
String::from_utf8_lossy(ct.as_bytes()).to_string(),
)));
}
},
None => {
return Err(StreamableHttpError::UnexpectedContentType(None));
},
}
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
Ok(event_stream)
}
async fn delete_session(
&self,
uri: Arc<str>,
session: Arc<str>,
auth_token: Option<String>,
) -> Result<(), StreamableHttpError<Self::Error>> {
let mut request_builder = self.client.delete(uri.as_ref());
request_builder = self.add_context_headers(request_builder);
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
let response = request_builder
.header(HEADER_SESSION_ID, session.as_ref())
.send()
.await?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
return Ok(());
}
let _response = response.error_for_status()?;
Ok(())
}
async fn post_message(
&self,
uri: Arc<str>,
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_token: Option<String>,
) -> Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let mut request = self
.client
.post(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
request = self.add_context_headers(request);
if let Some(auth_header) = auth_token {
request = request.bearer_auth(auth_header);
}
if let Some(ref session_id) = session_id {
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
}
let response = request.json(&message).send().await?;
if response.status() == reqwest::StatusCode::UNAUTHORIZED {
if let Some(header) = response.headers().get(WWW_AUTHENTICATE) {
let header = header
.to_str()
.map_err(|_| {
StreamableHttpError::UnexpectedServerResponse(std::borrow::Cow::from(
"invalid www-authenticate header value",
))
})?
.to_string();
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
www_authenticate_header: header,
}));
}
}
let response = response.error_for_status()?;
if response.status() == reqwest::StatusCode::ACCEPTED {
return Ok(StreamableHttpPostResponse::Accepted);
}
let content_type = response.headers().get(reqwest::header::CONTENT_TYPE);
let session_id = response.headers().get(HEADER_SESSION_ID);
let session_id = session_id
.and_then(|v| v.to_str().ok())
.map(ToString::to_string);
match content_type {
Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => {
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
Ok(StreamableHttpPostResponse::Sse(event_stream, session_id))
},
Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {
let message: ServerJsonRpcMessage = response.json().await?;
Ok(StreamableHttpPostResponse::Json(message, session_id))
},
_ => Err(StreamableHttpError::UnexpectedContentType(
content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()),
)),
}
}
}