use self::mcp_session::McpSession;
use crate::{
error::{Error, ErrorCode},
transport::http::{ClientRuntimeContext, MCP_SESSION_ID, get_mcp_session_id},
types::Message,
};
use futures_util::{StreamExt, TryStreamExt};
use reqwest::{
RequestBuilder,
header::{ACCEPT, CACHE_CONTROL, CONTENT_TYPE, HeaderName},
};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "client-tls")]
use tls_config::ClientTlsConfig;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
pub(super) mod mcp_session;
#[cfg(feature = "client-tls")]
pub(crate) mod tls_config;
/// Name of the `last-event-id` request header. It is attached to the SSE `GET`
/// request whenever the session has a recorded last event id.
const LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id");
/// How long to wait before re-opening the SSE stream after it has ended.
const SSE_RECONNECT_DELAY: Duration = Duration::from_secs(3);
pub(super) async fn connect(rt: ClientRuntimeContext, token: CancellationToken) {
let session = Arc::new(McpSession::new(rt.url, token));
let access_token: Option<Arc<[u8]>> = rt.access_token.map(|t| t.into());
tokio::join!(
handle_connection(
session.clone(),
rt.rx,
rt.tx.clone(),
access_token.clone(),
#[cfg(feature = "client-tls")]
rt.tls_config.clone()
),
start_sse_connection(
session.clone(),
rt.tx.clone(),
access_token.clone(),
#[cfg(feature = "client-tls")]
rt.tls_config.clone()
)
);
}
/// Pumps outgoing messages from `sender_rx` to the server as HTTP `POST`s.
///
/// An HTTP client is built once (honouring the TLS configuration when the
/// `client-tls` feature is enabled). If that fails the error is logged and the
/// function returns. Afterwards every received [`Message`] is turned into a
/// request carrying:
///
/// * `Content-Type: application/json` and
///   `Accept: application/json, text/event-stream`;
/// * the MCP session id header, when the session already has one;
/// * a bearer `Authorization` header, when an access token was supplied.
///
/// Each request is handed to `send_request` on a separately spawned task, so
/// a slow response does not block the next outgoing message.
///
/// The loop ends when the session's cancellation token fires or when the
/// sending half of `sender_rx` is closed.
async fn handle_connection(
    session: Arc<McpSession>,
    mut sender_rx: mpsc::Receiver<Message>,
    recv_tx: mpsc::Sender<Result<Message, Error>>,
    access_token: Option<Arc<[u8]>>,
    #[cfg(feature = "client-tls")] tls_config: Option<ClientTlsConfig>,
) {
    #[cfg(not(feature = "client-tls"))]
    let client = create_client();
    #[cfg(feature = "client-tls")]
    let client = create_client(tls_config);
    let client = match client {
        Ok(client) => client,
        Err(_err) => {
            #[cfg(feature = "tracing")]
            tracing::error!(logger = "neva", "HTTP client error: {_err:#}");
            return;
        }
    };

    let token = session.cancellation_token();
    loop {
        tokio::select! {
            // Cancellation is checked first so shutdown wins over pending work.
            biased;
            _ = token.cancelled() => return,
            req = sender_rx.recv() => {
                let Some(req) = req else {
                    #[cfg(feature = "tracing")]
                    tracing::error!(logger = "neva", "Unexpected messaging error");
                    break;
                };

                // The headers are set *before* `.json()`: `RequestBuilder::json`
                // only inserts `Content-Type` when it is absent, whereas
                // `RequestBuilder::header` appends. The reverse order would
                // therefore emit the `Content-Type` header twice.
                let mut request = client
                    .post(session.url().as_str().as_ref())
                    .header(CONTENT_TYPE, "application/json")
                    .header(ACCEPT, "application/json, text/event-stream")
                    .json(&req);

                if let Some(session_id) = session.session_id() {
                    request = request.header(MCP_SESSION_ID, session_id.to_string());
                }
                if let Some(access_token) = &access_token {
                    request = request.bearer_auth(String::from_utf8_lossy(access_token));
                }

                crate::spawn_fair!(send_request(
                    session.clone(),
                    request,
                    req,
                    recv_tx.clone()
                ));
            }
        }
    }
}
async fn send_request(
session: Arc<McpSession>,
resp: RequestBuilder,
req: Message,
resp_tx: mpsc::Sender<Result<Message, Error>>,
) {
let resp = match resp.send().await {
Ok(resp) => resp,
Err(_err) => {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "Failed to send HTTP request: {}", _err);
return;
}
};
if let Message::Notification(_) = &req {
return;
}
if let Message::Batch(ref batch) = req
&& !batch.has_requests()
{
return;
}
if !session.has_session_id()
&& let Some(session_id) = get_mcp_session_id(resp.headers())
{
session.set_session_id(session_id);
}
if let Message::Request(r) = req
&& r.method == crate::commands::INIT
{
let token = session.cancellation_token();
session.notify_session_initialized();
tokio::select! {
biased;
_ = token.cancelled() => return,
_ = session.sse_ready() => {},
}
}
let resp = resp.json::<Message>().await;
if let Err(_err) = resp_tx.send(resp.map_err(Error::from)).await {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "Failed to send response: {}", _err);
}
}
async fn start_sse_connection(
session: Arc<McpSession>,
resp_tx: mpsc::Sender<Result<Message, Error>>,
access_token: Option<Arc<[u8]>>,
#[cfg(feature = "client-tls")] tls_config: Option<ClientTlsConfig>,
) {
let token = session.cancellation_token();
tokio::select! {
biased;
_ = token.cancelled() => (),
_ = session.initialized() => {
tokio::spawn(handle_sse_connection(
session.clone(),
resp_tx,
access_token,
#[cfg(feature = "client-tls")]
tls_config
));
}
}
}
/// Maintains the server-to-client SSE stream, reconnecting when it ends.
///
/// Each iteration of the outer loop issues a `GET` to the session URL with
/// `Accept: application/json, text/event-stream` and `Cache-Control: no-cache`.
/// It adds the bearer token, the MCP session id and the `last-event-id` header
/// whenever those values are available. Received events are processed one at a
/// time through `handle_event`.
///
/// When the stream ends or yields an error, the function sleeps for
/// `SSE_RECONNECT_DELAY` and connects again.
///
/// Fatal conditions cancel the session's token before returning, so tasks
/// waiting on the SSE readiness signal are released rather than left hanging:
///
/// * the HTTP client cannot be constructed;
/// * the SSE request cannot be sent;
/// * the server answers with a non-success status.
async fn handle_sse_connection(
    session: Arc<McpSession>,
    resp_tx: mpsc::Sender<Result<Message, Error>>,
    access_token: Option<Arc<[u8]>>,
    #[cfg(feature = "client-tls")] tls_config: Option<ClientTlsConfig>,
) {
    #[cfg(not(feature = "client-tls"))]
    let client = create_client();
    #[cfg(feature = "client-tls")]
    let client = create_client(tls_config);
    let client = match client {
        Ok(client) => client,
        Err(_err) => {
            #[cfg(feature = "tracing")]
            tracing::error!(logger = "neva", "SSE client error: {_err:#}");
            // Without a client the stream can never become ready. Cancel the
            // session, as the other fatal paths below do, so nothing keeps
            // waiting for the readiness signal.
            session.cancellation_token().cancel();
            return;
        }
    };

    let token = session.cancellation_token();
    loop {
        let mut req = client
            .get(session.url().as_str().as_ref())
            .header(ACCEPT, "application/json, text/event-stream")
            .header(CACHE_CONTROL, "no-cache");
        if let Some(ref access_token) = access_token {
            req = req.bearer_auth(String::from_utf8_lossy(access_token));
        }
        if let Some(session_id) = session.session_id() {
            req = req.header(MCP_SESSION_ID, session_id.to_string());
        }
        // Tell the server the last event id recorded by this session.
        if let Some(last_id) = session.last_event_id() {
            req = req.header(LAST_EVENT_ID, last_id);
        }

        let resp = match req.send().await {
            Ok(resp) => resp,
            Err(_err) => {
                #[cfg(feature = "tracing")]
                tracing::error!(logger = "neva", "Failed to send SSE request: {}", _err);
                session.cancellation_token().cancel();
                return;
            }
        };
        if !resp.status().is_success() {
            #[cfg(feature = "tracing")]
            tracing::error!(
                logger = "neva",
                "SSE request failed with status: {}",
                resp.status()
            );
            session.cancellation_token().cancel();
            return;
        }

        // Each successfully parsed event becomes a future that delivers it;
        // stream errors are logged by `handle_error`.
        let mut stream = sse_stream::SseStream::from_byte_stream(resp.bytes_stream())
            .fuse()
            .map_ok(|event| handle_event(event, &session, &resp_tx))
            .map_err(handle_error);

        session.notify_sse_initialized();

        loop {
            tokio::select! {
                biased;
                _ = token.cancelled() => return,
                fut = stream.next() => {
                    // Both end-of-stream and a stream error trigger a reconnect.
                    let Some(Ok(fut)) = fut else {
                        #[cfg(feature = "tracing")]
                        tracing::info!(logger = "neva", "SSE stream ended, reconnecting");
                        break;
                    };
                    fut.await;
                }
            }
        }

        // Back off before reconnecting, but stay responsive to cancellation.
        tokio::select! {
            biased;
            _ = token.cancelled() => return,
            _ = tokio::time::sleep(SSE_RECONNECT_DELAY) => {}
        }
    }
}
/// Processes one SSE event and, when appropriate, records its id on the
/// session.
///
/// Message events are passed to `handle_msg`; any other event is only traced
/// and counts as handled. The event id (if present) is stored via
/// `set_last_event_id` only when the event was handled successfully.
async fn handle_event(
    event: sse_stream::Sse,
    session: &Arc<McpSession>,
    resp_tx: &mpsc::Sender<Result<Message, Error>>,
) {
    // Capture the id up front: the event itself is consumed below.
    let event_id = event.id.clone();

    let handled = match event.is_message() {
        true => handle_msg(event, resp_tx).await,
        false => {
            #[cfg(feature = "tracing")]
            tracing::debug!(logger = "neva", event = ?event);
            true
        }
    };
    if !handled {
        return;
    }

    if let Some(event_id) = event_id {
        session.set_last_event_id(event_id);
    }
}
/// Logs an SSE stream error when the `tracing` feature is enabled and
/// discards it; without that feature this is a no-op.
#[inline]
fn handle_error(_err: sse_stream::Error) {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "SSE Error: {}", _err);
}
async fn handle_msg(
event: sse_stream::Sse,
resp_tx: &mpsc::Sender<Result<Message, Error>>,
) -> bool {
let Some(data) = event.data else {
return false;
};
let msg = serde_json::from_str::<Message>(&data);
let parsed_ok = msg.is_ok();
if let Err(_err) = resp_tx.send(msg.map_err(Error::from)).await {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "Failed to send server request: {}", _err);
return false;
}
parsed_ok
}
/// Builds a `reqwest` client with default settings (no custom TLS setup).
///
/// A builder failure is converted into the crate's [`Error`].
#[inline]
#[cfg(not(feature = "client-tls"))]
fn create_client() -> Result<reqwest::Client, Error> {
    let client = reqwest::Client::builder().build()?;
    Ok(client)
}
/// Builds a `reqwest` client, applying the optional TLS configuration.
///
/// When a configuration is supplied:
///
/// * its CA certificate, if any, is added as a trusted root;
/// * its client identity, if any, is installed;
/// * invalid server certificates are accepted when `certs_verification` is
///   `false`.
///
/// A builder failure is converted into the crate's [`Error`].
#[inline]
#[cfg(feature = "client-tls")]
fn create_client(mut tls_config: Option<ClientTlsConfig>) -> Result<reqwest::Client, Error> {
    let mut builder = reqwest::ClientBuilder::new();

    if let Some(tls) = tls_config.as_mut() {
        if let Some(root_cert) = tls.ca.take() {
            builder = builder.add_root_certificate(root_cert);
        }
        if let Some(client_identity) = tls.identity.take() {
            builder = builder.identity(client_identity);
        }
        if !tls.certs_verification {
            builder = builder.danger_accept_invalid_certs(true);
        }
    }

    let client = builder.build()?;
    Ok(client)
}
impl From<reqwest::Error> for Error {
#[inline]
fn from(err: reqwest::Error) -> Self {
Error::new(ErrorCode::ParseError, err.to_string())
}
}
#[cfg(test)]
mod tests {
    //! Tests for how `handle_event` advances the session's last event id.

    use super::*;
    use crate::transport::http::ServiceUrl;

    /// Payload the tests expect to deserialize into a `Message`.
    const VALID_MSG: &str = r#"{"jsonrpc":"2.0","method":"ping"}"#;

    /// Creates a fresh session with a default URL and its own cancellation token.
    fn make_session() -> Arc<McpSession> {
        let url = ServiceUrl::default();
        let token = CancellationToken::new();
        Arc::new(McpSession::new(url, token))
    }

    #[tokio::test]
    async fn it_advances_last_event_id_on_successful_delivery() {
        let (sender, mut receiver) = mpsc::channel(1);
        let session = make_session();
        let sse = sse_stream::Sse::default().id("evt-1").data(VALID_MSG);

        handle_event(sse, &session, &sender).await;

        assert_eq!(session.last_event_id(), Some(String::from("evt-1")));
        assert!(
            receiver.try_recv().is_ok(),
            "message should have been delivered"
        );
    }

    #[tokio::test]
    async fn it_does_not_advance_last_event_id_on_parse_failure() {
        // The receiver is kept alive so the send itself succeeds.
        let (sender, _receiver) = mpsc::channel(1);
        let session = make_session();
        let sse = sse_stream::Sse::default()
            .id("evt-bad")
            .data("not { valid json");

        handle_event(sse, &session, &sender).await;

        assert!(session.last_event_id().is_none());
    }

    #[tokio::test]
    async fn it_does_not_advance_last_event_id_when_channel_closed() {
        let (sender, receiver) = mpsc::channel(1);
        // Dropping the receiver closes the channel, so the send must fail.
        drop(receiver);
        let session = make_session();
        let sse = sse_stream::Sse::default().id("evt-dropped").data(VALID_MSG);

        handle_event(sse, &session, &sender).await;

        assert!(session.last_event_id().is_none());
    }

    #[tokio::test]
    async fn it_advances_last_event_id_for_non_message_event() {
        let (sender, _receiver) = mpsc::channel(1);
        let session = make_session();
        let sse = sse_stream::Sse::default()
            .id("evt-keepalive")
            .event("keepalive");

        handle_event(sse, &session, &sender).await;

        assert_eq!(
            session.last_event_id(),
            Some(String::from("evt-keepalive"))
        );
    }

    #[tokio::test]
    async fn it_does_not_advance_last_event_id_when_data_is_absent() {
        let (sender, _receiver) = mpsc::channel(1);
        let session = make_session();
        let sse = sse_stream::Sse::default().id("evt-empty");

        handle_event(sse, &session, &sender).await;

        assert!(session.last_event_id().is_none());
    }
}