use bytes::Bytes;
use futures::StreamExt;
use reqwest::{Client as HttpClient, header};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, RwLock, mpsc};
use tracing::{debug, error, info, warn};
use turbomcp_protocol::MessageId;
use turbomcp_transport_traits::{
LimitsConfig, TlsConfig, TlsVersion, Transport, TransportCapabilities, TransportError,
TransportEventEmitter, TransportMessage, TransportMetrics, TransportResult, TransportState,
TransportType, validate_request_size, validate_response_size,
};
/// Reconnection/backoff strategy for the streamable HTTP client.
#[derive(Clone, Debug)]
pub enum RetryPolicy {
    /// Retry at a constant `interval`, optionally bounded by `max_attempts`.
    Fixed {
        interval: Duration,
        max_attempts: Option<u32>,
    },
    /// Exponential backoff (`base * 2^attempt`) with deterministic jitter,
    /// capped at `max_delay` and optionally bounded by `max_attempts`.
    Exponential {
        base: Duration,
        max_delay: Duration,
        max_attempts: Option<u32>,
    },
    /// Never retry.
    Never,
}

impl Default for RetryPolicy {
    /// Exponential backoff: 1s base, 60s cap, up to 10 attempts.
    fn default() -> Self {
        Self::Exponential {
            base: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            max_attempts: Some(10),
        }
    }
}

impl RetryPolicy {
    /// Returns the delay to wait before retry number `attempt` (0-based), or
    /// `None` when retries are exhausted (or the policy is [`Self::Never`]).
    ///
    /// Exponential delays carry ±25% deterministic jitter derived from the
    /// attempt number, so a fleet of clients does not reconnect in lockstep.
    pub(crate) fn delay(&self, attempt: u32) -> Option<Duration> {
        match self {
            Self::Fixed {
                interval,
                max_attempts,
            } => {
                if max_attempts.is_some_and(|max| attempt >= max) {
                    return None;
                }
                Some(*interval)
            }
            Self::Exponential {
                base,
                max_delay,
                max_attempts,
            } => {
                if max_attempts.is_some_and(|max| attempt >= max) {
                    return None;
                }
                // Saturating arithmetic: `2u64.pow(attempt)` overflows (and
                // panics in debug builds) once `attempt >= 64`, which is
                // reachable when `max_attempts` is `None`. Saturate and let
                // the `max_delay` cap take over instead of panicking.
                let factor = 2u64.checked_pow(attempt).unwrap_or(u64::MAX);
                let base_delay = (base.as_millis() as u64).saturating_mul(factor);
                let max_delay_ms = max_delay.as_millis() as u64;
                let capped = base_delay.min(max_delay_ms);
                // Deterministic jitter in [capped - range, capped + range)
                // using an LCG-style hash of the attempt number (no RNG dep).
                let jitter_range = capped / 4;
                let jitter_offset = if jitter_range > 0 {
                    let hash = (attempt as u64)
                        .wrapping_mul(6364136223846793005)
                        .wrapping_add(1442695040888963407);
                    hash % (jitter_range * 2)
                } else {
                    0
                };
                let final_delay = capped
                    .saturating_sub(jitter_range)
                    .saturating_add(jitter_offset);
                Some(Duration::from_millis(final_delay))
            }
            Self::Never => None,
        }
    }
}
/// Configuration for [`StreamableHttpClientTransport`].
#[derive(Clone, Debug)]
pub struct StreamableHttpClientConfig {
    /// Server origin, e.g. `http://localhost:8080` (no trailing slash expected;
    /// `endpoint_path` is concatenated directly onto it).
    pub base_url: String,
    /// Path of the MCP endpoint appended to `base_url`, e.g. `/mcp`.
    pub endpoint_path: String,
    /// Per-request timeout applied to the underlying HTTP client.
    pub timeout: Duration,
    /// Reconnect/backoff policy for the background SSE stream.
    pub retry_policy: RetryPolicy,
    /// Optional bearer token sent as `Authorization: Bearer <token>`.
    pub auth_token: Option<String>,
    /// Extra headers added to every request (entries with invalid header
    /// names/values are silently skipped).
    pub headers: HashMap<String, String>,
    /// Optional `User-Agent` value for the HTTP client.
    pub user_agent: Option<String>,
    /// Value sent in the `MCP-Protocol-Version` header.
    pub protocol_version: String,
    /// Request/response size limits enforced on send/receive.
    pub limits: LimitsConfig,
    /// TLS settings (minimum version, certificate validation, custom CAs).
    pub tls: TlsConfig,
}
impl Default for StreamableHttpClientConfig {
fn default() -> Self {
Self {
base_url: "http://localhost:8080".to_string(),
endpoint_path: "/mcp".to_string(),
timeout: Duration::from_secs(30),
retry_policy: RetryPolicy::default(),
auth_token: None,
headers: HashMap::new(),
user_agent: Some(format!("TurboMCP-Client/{}", env!("CARGO_PKG_VERSION"))),
protocol_version: "2025-11-25".to_string(),
limits: LimitsConfig::default(),
tls: TlsConfig::default(),
}
}
}
/// MCP streamable-HTTP client transport: POSTs requests to the server and
/// listens for server-initiated messages on a long-lived SSE GET stream.
pub struct StreamableHttpClientTransport {
    config: StreamableHttpClientConfig,
    http_client: HttpClient,
    // Connection lifecycle state, shared with the background SSE task.
    state: Arc<RwLock<TransportState>>,
    capabilities: TransportCapabilities,
    metrics: Arc<RwLock<TransportMetrics>>,
    _event_emitter: TransportEventEmitter,
    // Message endpoint discovered via an SSE `endpoint` event (absolute URL,
    // host-relative `/path`, or bare-relative); `None` until discovered.
    message_endpoint: Arc<RwLock<Option<String>>>,
    // Session ID assigned by the server via the `Mcp-Session-Id` header.
    session_id: Arc<RwLock<Option<String>>>,
    // Last SSE event id seen; replayed via `Last-Event-ID` on reconnect.
    last_event_id: Arc<RwLock<Option<String>>>,
    // Messages pushed by the server over the long-lived GET SSE stream.
    sse_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
    sse_sender: mpsc::Sender<TransportMessage>,
    // Responses to POSTed requests (JSON body or POST-scoped SSE events).
    response_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
    response_sender: mpsc::Sender<TransportMessage>,
    // Handle to the background SSE task; aborted on disconnect.
    sse_task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
// Manual Debug impl: the channel halves and task handle are not Debug (and
// would be noise anyway), so only the connection target is shown.
impl std::fmt::Debug for StreamableHttpClientTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamableHttpClientTransport")
            .field("base_url", &self.config.base_url)
            .field("endpoint", &self.config.endpoint_path)
            .finish()
    }
}
impl StreamableHttpClientTransport {
    /// Builds the transport and its underlying HTTP client from `config`.
    ///
    /// # Panics
    /// Panics if the reqwest client cannot be built (e.g. invalid TLS setup).
    pub fn new(config: StreamableHttpClientConfig) -> Self {
        // Bounded channels: SSE pushes and POST responses are buffered here
        // until the consumer drains them via `receive()`.
        let (sse_tx, sse_rx) = mpsc::channel(1000);
        let (response_tx, response_rx) = mpsc::channel(100);
        let (event_emitter, _) = TransportEventEmitter::new();
        if config.tls.is_insecure() {
            warn!(
                "Certificate validation is disabled. This is insecure and should only be used \
                 for testing or in secure mTLS mesh environments. \
                 See https://turbomcp.org/docs/security/tls#certificate-validation"
            );
        }
        // rustls backend, with the configured per-request timeout.
        let mut client_builder = HttpClient::builder()
            .use_rustls_tls()
            .timeout(config.timeout);
        if let Some(ref user_agent) = config.user_agent {
            client_builder = client_builder.user_agent(user_agent);
        }
        // Enforce the configured minimum TLS version (only 1.3 is modeled).
        client_builder = match config.tls.min_version {
            TlsVersion::Tls13 => client_builder.min_tls_version(reqwest::tls::Version::TLS_1_3),
        };
        if !config.tls.validate_certificates {
            // Disabling validation additionally requires an explicit opt-in
            // environment variable; without it, validation stays enabled
            // (the builder flag below is simply never set).
            const INSECURE_TLS_ENV_VAR: &str = "TURBOMCP_ALLOW_INSECURE_TLS";
            if std::env::var(INSECURE_TLS_ENV_VAR).is_err() {
                error!(
                    "SECURITY: Certificate validation disabled but {} not set. \
                     Overriding to validate_certificates=true for safety. \
                     Set {}=1 to allow insecure TLS.",
                    INSECURE_TLS_ENV_VAR, INSECURE_TLS_ENV_VAR
                );
            } else {
                warn!(
                    "SECURITY WARNING: TLS certificate validation is DISABLED. \
                     This configuration is INSECURE and should ONLY be used: \
                     (1) In development/testing environments, or \
                     (2) In secure mTLS mesh where validation happens elsewhere. \
                     NEVER use in production connecting to untrusted servers."
                );
                client_builder = client_builder.danger_accept_invalid_certs(true);
            }
        }
        // Custom CA certificates: try PEM first, fall back to DER; skip (but
        // log) entries that parse as neither.
        if let Some(ca_certs) = &config.tls.custom_ca_certs {
            let mut loaded = 0usize;
            let total = ca_certs.len();
            for cert_bytes in ca_certs {
                if let Ok(cert) = reqwest::Certificate::from_pem(cert_bytes) {
                    client_builder = client_builder.add_root_certificate(cert);
                    loaded += 1;
                } else if let Ok(cert) = reqwest::Certificate::from_der(cert_bytes) {
                    client_builder = client_builder.add_root_certificate(cert);
                    loaded += 1;
                } else {
                    warn!(
                        "Failed to parse custom CA certificate ({}/{}), skipping",
                        loaded + 1,
                        total
                    );
                }
            }
            if loaded == 0 && total > 0 {
                error!("All {} custom CA certificates failed to parse", total);
            }
            if loaded > 0 {
                info!("Loaded {}/{} custom CA certificates", loaded, total);
            }
        }
        let http_client = client_builder.build().expect("Failed to build HTTP client");
        Self {
            config,
            http_client,
            state: Arc::new(RwLock::new(TransportState::Disconnected)),
            capabilities: TransportCapabilities {
                max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
                supports_compression: false,
                supports_streaming: true,
                supports_bidirectional: true,
                supports_multiplexing: false,
                compression_algorithms: Vec::new(),
                custom: HashMap::new(),
            },
            metrics: Arc::new(RwLock::new(TransportMetrics::default())),
            _event_emitter: event_emitter,
            message_endpoint: Arc::new(RwLock::new(None)),
            session_id: Arc::new(RwLock::new(None)),
            last_event_id: Arc::new(RwLock::new(None)),
            sse_receiver: Arc::new(Mutex::new(sse_rx)),
            sse_sender: sse_tx,
            response_receiver: Arc::new(Mutex::new(response_rx)),
            response_sender: response_tx,
            sse_task_handle: Arc::new(Mutex::new(None)),
        }
    }

    /// Base SSE/GET endpoint: `base_url` + `endpoint_path` (simple concat;
    /// assumes `base_url` has no trailing slash).
    fn get_endpoint_url(&self) -> String {
        format!("{}{}", self.config.base_url, self.config.endpoint_path)
    }

    /// URL to POST messages to: the server-discovered endpoint if one was
    /// announced via an SSE `endpoint` event (resolved against `base_url`
    /// when relative), otherwise the static endpoint URL.
    async fn get_message_endpoint_url(&self) -> String {
        let discovered = self.message_endpoint.read().await;
        if let Some(endpoint) = discovered.as_ref() {
            if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
                // Already absolute.
                endpoint.clone()
            } else if endpoint.starts_with('/') {
                // Host-relative path.
                format!("{}{}", self.config.base_url, endpoint)
            } else {
                // Bare-relative path.
                format!("{}/{}", self.config.base_url, endpoint)
            }
        } else {
            self.get_endpoint_url()
        }
    }

    /// Builds the common request headers: Accept, protocol version, session
    /// id, Last-Event-ID, bearer auth, then any user-configured extras.
    /// Invalid header names/values are silently skipped.
    async fn build_headers(&self, accept: &str) -> header::HeaderMap {
        let mut headers = header::HeaderMap::new();
        if let Ok(accept_value) = header::HeaderValue::from_str(accept) {
            headers.insert(header::ACCEPT, accept_value);
        }
        if let Ok(protocol_value) = header::HeaderValue::from_str(&self.config.protocol_version) {
            headers.insert("MCP-Protocol-Version", protocol_value);
        }
        if let Some(session_id) = self.session_id.read().await.as_ref()
            && let Ok(session_value) = header::HeaderValue::from_str(session_id)
        {
            headers.insert("Mcp-Session-Id", session_value);
        }
        if let Some(last_event_id) = self.last_event_id.read().await.as_ref()
            && let Ok(event_value) = header::HeaderValue::from_str(last_event_id)
        {
            headers.insert("Last-Event-ID", event_value);
        }
        if let Some(token) = &self.config.auth_token
            && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
        {
            headers.insert(header::AUTHORIZATION, auth_value);
        }
        // User-configured headers last, so they can override the above.
        for (key, value) in &self.config.headers {
            if let (Ok(k), Ok(v)) = (
                header::HeaderName::from_bytes(key.as_bytes()),
                header::HeaderValue::from_str(value),
            ) {
                headers.insert(k, v);
            }
        }
        headers
    }

    /// Spawns the background task that maintains the long-lived GET SSE
    /// stream, storing its handle so `disconnect()` can abort it.
    async fn start_sse_connection(&self) -> TransportResult<()> {
        info!("Starting SSE connection to {}", self.get_endpoint_url());
        let endpoint_url = self.get_endpoint_url();
        let config = self.config.clone();
        let http_client = self.http_client.clone();
        let state = Arc::clone(&self.state);
        let sse_sender = self.sse_sender.clone();
        let session_id = Arc::clone(&self.session_id);
        let last_event_id = Arc::clone(&self.last_event_id);
        let message_endpoint = Arc::clone(&self.message_endpoint);
        let task = tokio::spawn(async move {
            Self::sse_connection_task(
                endpoint_url,
                config,
                http_client,
                state,
                sse_sender,
                session_id,
                last_event_id,
                message_endpoint,
            )
            .await;
        });
        *self.sse_task_handle.lock().await = Some(task);
        Ok(())
    }

    /// Background loop: connects to the SSE endpoint, parses the stream into
    /// events, and reconnects per the retry policy when the stream drops.
    /// Runs until the retry policy gives up or the task is aborted.
    #[allow(clippy::too_many_arguments)]
    async fn sse_connection_task(
        endpoint_url: String,
        config: StreamableHttpClientConfig,
        http_client: HttpClient,
        state: Arc<RwLock<TransportState>>,
        sse_sender: mpsc::Sender<TransportMessage>,
        session_id: Arc<RwLock<Option<String>>>,
        last_event_id: Arc<RwLock<Option<String>>>,
        message_endpoint: Arc<RwLock<Option<String>>>,
    ) {
        let mut attempt = 0u32;
        loop {
            // Retry policy decides whether/how long to wait before the next
            // connection attempt; `None` means give up permanently.
            if let Some(delay) = config.retry_policy.delay(attempt) {
                if attempt > 0 {
                    warn!("Reconnecting in {:?} (attempt {})", delay, attempt + 1);
                    tokio::time::sleep(delay).await;
                }
            } else {
                error!("Max retry attempts reached, giving up");
                *state.write().await = TransportState::Disconnected;
                break;
            }
            // Headers are rebuilt each attempt so the current session id and
            // Last-Event-ID (for resumption) are always included.
            let mut headers = header::HeaderMap::new();
            headers.insert(
                header::ACCEPT,
                header::HeaderValue::from_static("text/event-stream"),
            );
            if let Ok(protocol_value) = header::HeaderValue::from_str(&config.protocol_version) {
                headers.insert("MCP-Protocol-Version", protocol_value);
            }
            if let Some(sid) = session_id.read().await.as_ref()
                && let Ok(session_value) = header::HeaderValue::from_str(sid)
            {
                headers.insert("Mcp-Session-Id", session_value);
            }
            if let Some(last_id) = last_event_id.read().await.as_ref()
                && let Ok(event_value) = header::HeaderValue::from_str(last_id)
            {
                headers.insert("Last-Event-ID", event_value);
            }
            if let Some(token) = &config.auth_token
                && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
            {
                headers.insert(header::AUTHORIZATION, auth_value);
            }
            match http_client.get(&endpoint_url).headers(headers).send().await {
                Ok(response) => {
                    if !response.status().is_success() {
                        error!("SSE connection failed: {}", response.status());
                        attempt += 1;
                        continue;
                    }
                    // Capture the server-assigned session id, if any.
                    if let Some(sid) = response
                        .headers()
                        .get("Mcp-Session-Id")
                        .and_then(|v| v.to_str().ok())
                    {
                        *session_id.write().await = Some(sid.to_string());
                        info!("Received session ID: {}", sid);
                    }
                    info!("SSE connection established");
                    *state.write().await = TransportState::Connected;
                    // Successful connection resets the backoff counter.
                    attempt = 0;
                    let mut stream = response.bytes_stream();
                    // Accumulate chunks and split on the SSE event delimiter
                    // ("\n\n"); events may span chunk boundaries.
                    let mut buffer = String::new();
                    while let Some(chunk_result) = stream.next().await {
                        match chunk_result {
                            Ok(chunk) => {
                                let chunk_str = String::from_utf8_lossy(&chunk);
                                buffer.push_str(&chunk_str);
                                while let Some(pos) = buffer.find("\n\n") {
                                    let event_str = buffer[..pos].to_string();
                                    buffer = buffer[pos + 2..].to_string();
                                    if let Err(e) = Self::process_sse_event(
                                        &event_str,
                                        &sse_sender,
                                        &last_event_id,
                                        &message_endpoint,
                                    )
                                    .await
                                    {
                                        warn!("Failed to process SSE event: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                error!("Error reading SSE stream: {}", e);
                                break;
                            }
                        }
                    }
                    // Stream ended (server closed or read error); loop around
                    // to reconnect.
                    warn!("SSE stream ended");
                    *state.write().await = TransportState::Disconnected;
                }
                Err(e) => {
                    error!("Failed to connect: {}", e);
                    attempt += 1;
                }
            }
        }
    }

    /// Parses one SSE event block from the GET stream and dispatches it:
    /// `endpoint` events update the discovered message endpoint (JSON
    /// `{"uri": ...}` or plain-string forms), `message`/untyped events are
    /// forwarded to the SSE channel, other types are ignored. Any `id:` field
    /// updates `last_event_id` for reconnection.
    async fn process_sse_event(
        event_str: &str,
        sse_sender: &mpsc::Sender<TransportMessage>,
        last_event_id: &Arc<RwLock<Option<String>>>,
        message_endpoint: &Arc<RwLock<Option<String>>>,
    ) -> TransportResult<()> {
        let lines: Vec<&str> = event_str.lines().collect();
        let mut event_type: Option<String> = None;
        let mut event_data: Vec<String> = Vec::new();
        let mut event_id: Option<String> = None;
        for line in lines {
            if line.is_empty() {
                continue;
            }
            // NOTE(review): lines without a colon are dropped here; the SSE
            // spec treats them as a field with an empty value — confirm
            // servers in use never emit that form.
            if let Some(colon_pos) = line.find(':') {
                let field = &line[..colon_pos];
                let value = line[colon_pos + 1..].trim_start();
                match field {
                    "event" => event_type = Some(value.to_string()),
                    "data" => event_data.push(value.to_string()),
                    "id" => event_id = Some(value.to_string()),
                    _ => {}
                }
            }
        }
        if let Some(id) = event_id {
            *last_event_id.write().await = Some(id);
        }
        if event_data.is_empty() {
            return Ok(());
        }
        // Multiple `data:` lines join with newlines, per the SSE format.
        let data_str = event_data.join("\n");
        match event_type.as_deref() {
            Some("endpoint") => {
                // Two accepted shapes: a JSON object with a `uri` field, or a
                // bare URI string.
                let endpoint_uri = if data_str.trim().starts_with('{') {
                    let endpoint_json: serde_json::Value = serde_json::from_str(&data_str)
                        .map_err(|e| {
                            TransportError::SerializationFailed(format!(
                                "Invalid endpoint JSON: {}",
                                e
                            ))
                        })?;
                    endpoint_json["uri"]
                        .as_str()
                        .ok_or_else(|| {
                            TransportError::SerializationFailed(
                                "Endpoint event missing 'uri' field".to_string(),
                            )
                        })?
                        .to_string()
                } else {
                    data_str.clone()
                };
                info!("Discovered message endpoint: {}", endpoint_uri);
                *message_endpoint.write().await = Some(endpoint_uri);
                Ok(())
            }
            Some("message") | None => {
                if data_str.trim().is_empty() {
                    debug!("Skipping empty SSE event");
                    return Ok(());
                }
                // Parse then re-serialize: validates the payload is JSON
                // before forwarding it to the consumer.
                let json_value: serde_json::Value =
                    serde_json::from_str(&data_str).map_err(|e| {
                        TransportError::SerializationFailed(format!("Invalid JSON: {}", e))
                    })?;
                let message = TransportMessage::new(
                    MessageId::from("sse-message".to_string()),
                    Bytes::from(
                        serde_json::to_vec(&json_value)
                            .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
                    ),
                );
                sse_sender
                    .send(message)
                    .await
                    .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
                debug!("Received SSE message");
                Ok(())
            }
            Some(other) => {
                debug!("Ignoring unknown event type: {}", other);
                Ok(())
            }
        }
    }

    /// Parses one SSE event from a POST-response stream and forwards its JSON
    /// payload to the response channel. Unlike the GET-stream variant, the
    /// event type is ignored — every data-bearing event is treated as a
    /// response. Any `id:` field still updates `last_event_id`.
    async fn process_post_sse_event(
        event_str: &str,
        response_sender: &mpsc::Sender<TransportMessage>,
        last_event_id: &Arc<RwLock<Option<String>>>,
    ) -> TransportResult<()> {
        let lines: Vec<&str> = event_str.lines().collect();
        let mut event_data: Vec<String> = Vec::new();
        let mut event_id: Option<String> = None;
        for line in lines {
            if line.is_empty() {
                continue;
            }
            if let Some(colon_pos) = line.find(':') {
                let field = &line[..colon_pos];
                let value = line[colon_pos + 1..].trim_start();
                match field {
                    "data" => event_data.push(value.to_string()),
                    "id" => event_id = Some(value.to_string()),
                    "event" => {
                        // Event type deliberately ignored for POST streams.
                    }
                    _ => {}
                }
            }
        }
        if let Some(id) = event_id {
            *last_event_id.write().await = Some(id);
        }
        if event_data.is_empty() {
            return Ok(());
        }
        let data_str = event_data.join("\n");
        // Validate the payload is JSON before forwarding.
        let json_value: serde_json::Value = serde_json::from_str(&data_str).map_err(|e| {
            TransportError::SerializationFailed(format!("Invalid JSON in POST SSE: {}", e))
        })?;
        let message = TransportMessage::new(
            MessageId::from("post-sse-response".to_string()),
            Bytes::from(
                serde_json::to_vec(&json_value)
                    .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
            ),
        );
        response_sender
            .send(message.clone())
            .await
            .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
        debug!(
            "Queued message from POST SSE stream: {}",
            String::from_utf8_lossy(&message.payload)
        );
        Ok(())
    }
}
impl Transport for StreamableHttpClientTransport {
    /// POSTs `message` to the (possibly discovered) message endpoint.
    ///
    /// Response handling depends on the server: a 202 means no body; a JSON
    /// body is queued on the response channel; a `text/event-stream` body is
    /// consumed to completion here, with each event queued individually.
    fn send(
        &self,
        message: TransportMessage,
    ) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
        Box::pin(async move {
            debug!("Sending message via HTTP POST");
            validate_request_size(message.payload.len(), &self.config.limits)?;
            let url = self.get_message_endpoint_url().await;
            // Accept both response shapes the server may produce.
            let headers = self
                .build_headers("application/json, text/event-stream")
                .await;
            let response = self
                .http_client
                .post(&url)
                .headers(headers)
                .header(header::CONTENT_TYPE, "application/json")
                .body(message.payload.to_vec())
                .send()
                .await
                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
            if !response.status().is_success() {
                return Err(TransportError::ConnectionFailed(format!(
                    "POST failed: {}",
                    response.status()
                )));
            }
            // The server may (re)assign a session id on any response.
            if let Some(session_id) = response
                .headers()
                .get("Mcp-Session-Id")
                .and_then(|v| v.to_str().ok())
            {
                *self.session_id.write().await = Some(session_id.to_string());
            }
            if response.status() == reqwest::StatusCode::ACCEPTED {
                debug!("Received HTTP 202 Accepted (no response body expected)");
                {
                    let mut metrics = self.metrics.write().await;
                    metrics.messages_sent += 1;
                    metrics.bytes_sent += message.payload.len() as u64;
                }
                return Ok(());
            }
            let content_type = response
                .headers()
                .get(header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok())
                .unwrap_or("");
            if content_type.contains("application/json") {
                debug!("Received JSON response from POST");
                let response_bytes = response
                    .bytes()
                    .await
                    .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
                validate_response_size(response_bytes.len(), &self.config.limits)?;
                let response_message = TransportMessage::new(
                    MessageId::from("http-response".to_string()),
                    response_bytes,
                );
                self.response_sender
                    .send(response_message)
                    .await
                    .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
                debug!("JSON response queued successfully");
            } else if content_type.contains("text/event-stream") {
                // NOTE(review): this blocks `send` until the POST-scoped SSE
                // stream ends; long-lived streams delay the caller.
                debug!("Received SSE stream response from POST, processing events");
                let response_sender = self.response_sender.clone();
                let last_event_id = Arc::clone(&self.last_event_id);
                let mut stream = response.bytes_stream();
                // Same chunk-buffering/"\n\n"-splitting as the GET stream.
                let mut buffer = String::new();
                while let Some(chunk_result) = stream.next().await {
                    match chunk_result {
                        Ok(chunk) => {
                            let chunk_str = String::from_utf8_lossy(&chunk);
                            buffer.push_str(&chunk_str);
                            while let Some(pos) = buffer.find("\n\n") {
                                let event_str = buffer[..pos].to_string();
                                buffer = buffer[pos + 2..].to_string();
                                if let Err(e) = Self::process_post_sse_event(
                                    &event_str,
                                    &response_sender,
                                    &last_event_id,
                                )
                                .await
                                {
                                    warn!("Failed to process POST SSE event: {}", e);
                                }
                            }
                        }
                        Err(e) => {
                            warn!("Error reading POST SSE stream: {}", e);
                            break;
                        }
                    }
                }
                debug!("POST SSE stream processing completed");
            }
            {
                let mut metrics = self.metrics.write().await;
                metrics.messages_sent += 1;
                metrics.bytes_sent += message.payload.len() as u64;
            }
            debug!("Message sent successfully");
            Ok(())
        })
    }

    /// Non-blocking receive: drains the POST-response queue first, then the
    /// SSE push queue. Returns `Ok(None)` when both are empty (poll again).
    fn receive(
        &self,
    ) -> Pin<Box<dyn Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_>> {
        Box::pin(async move {
            // Responses to our own requests take priority over pushes.
            {
                let mut response_receiver = self.response_receiver.lock().await;
                match response_receiver.try_recv() {
                    Ok(message) => {
                        debug!("Received queued JSON response");
                        {
                            let mut metrics = self.metrics.write().await;
                            metrics.messages_received += 1;
                            metrics.bytes_received += message.payload.len() as u64;
                        }
                        return Ok(Some(message));
                    }
                    Err(mpsc::error::TryRecvError::Empty) => {
                        // Fall through to the SSE queue.
                    }
                    Err(mpsc::error::TryRecvError::Disconnected) => {
                        return Err(TransportError::ConnectionLost(
                            "Response channel disconnected".to_string(),
                        ));
                    }
                }
            }
            let mut sse_receiver = self.sse_receiver.lock().await;
            match sse_receiver.try_recv() {
                Ok(message) => {
                    debug!("Received SSE message");
                    {
                        let mut metrics = self.metrics.write().await;
                        metrics.messages_received += 1;
                        metrics.bytes_received += message.payload.len() as u64;
                    }
                    Ok(Some(message))
                }
                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
                Err(mpsc::error::TryRecvError::Disconnected) => Err(
                    TransportError::ConnectionLost("SSE channel disconnected".to_string()),
                ),
            }
        })
    }

    fn capabilities(&self) -> &TransportCapabilities {
        &self.capabilities
    }

    fn state(&self) -> Pin<Box<dyn Future<Output = TransportState> + Send + '_>> {
        Box::pin(async move { self.state.read().await.clone() })
    }

    fn transport_type(&self) -> TransportType {
        TransportType::Http
    }

    fn metrics(&self) -> Pin<Box<dyn Future<Output = TransportMetrics> + Send + '_>> {
        Box::pin(async move { self.metrics.read().await.clone() })
    }

    /// Starts the background SSE task and marks the transport connected.
    ///
    /// NOTE(review): the 500ms sleep is a grace period, not a handshake —
    /// `Connected` is set unconditionally even if the SSE connection failed;
    /// the background task will keep retrying per the retry policy.
    fn connect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
        Box::pin(async move {
            info!("Connecting to {}", self.get_endpoint_url());
            *self.state.write().await = TransportState::Connecting;
            self.start_sse_connection().await?;
            tokio::time::sleep(Duration::from_millis(500)).await;
            *self.state.write().await = TransportState::Connected;
            info!("Connected successfully");
            Ok(())
        })
    }

    /// Aborts the SSE task and best-effort DELETEs the server session
    /// (errors from the DELETE are deliberately ignored).
    fn disconnect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
        Box::pin(async move {
            info!("Disconnecting");
            *self.state.write().await = TransportState::Disconnecting;
            if let Some(handle) = self.sse_task_handle.lock().await.take() {
                handle.abort();
            }
            if let Some(session_id) = self.session_id.read().await.as_ref() {
                let url = self.get_endpoint_url();
                let mut headers = header::HeaderMap::new();
                if let Ok(session_value) = header::HeaderValue::from_str(session_id) {
                    headers.insert("Mcp-Session-Id", session_value);
                }
                let _ = self.http_client.delete(&url).headers(headers).send().await;
            }
            *self.state.write().await = TransportState::Disconnected;
            info!("Disconnected");
            Ok(())
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed policy returns a constant interval until `max_attempts`.
    #[test]
    fn test_retry_policy_fixed() {
        let policy = RetryPolicy::Fixed {
            interval: Duration::from_secs(5),
            max_attempts: Some(3),
        };
        assert_eq!(policy.delay(0), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(1), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(2), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(3), None);
    }

    /// Exponential policy doubles per attempt; each bound below is the
    /// nominal delay ±25% jitter, capped at `max_delay` for attempt 10.
    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::Exponential {
            base: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            max_attempts: None,
        };
        let delay0 = policy.delay(0).unwrap();
        assert!(delay0 >= Duration::from_millis(750) && delay0 <= Duration::from_millis(1250));
        let delay1 = policy.delay(1).unwrap();
        assert!(delay1 >= Duration::from_millis(1500) && delay1 <= Duration::from_millis(2500));
        let delay2 = policy.delay(2).unwrap();
        assert!(delay2 >= Duration::from_millis(3000) && delay2 <= Duration::from_millis(5000));
        let delay3 = policy.delay(3).unwrap();
        assert!(delay3 >= Duration::from_millis(6000) && delay3 <= Duration::from_millis(10000));
        let delay10 = policy.delay(10).unwrap();
        assert!(delay10 >= Duration::from_millis(45000) && delay10 <= Duration::from_millis(75000));
    }

    /// Construction succeeds with defaults and advertises the expected
    /// streaming/bidirectional capabilities.
    #[tokio::test]
    async fn test_client_creation() {
        let config = StreamableHttpClientConfig::default();
        let client = StreamableHttpClientTransport::new(config);
        assert_eq!(client.transport_type(), TransportType::Http);
        assert!(client.capabilities().supports_streaming);
        assert!(client.capabilities().supports_bidirectional);
    }

    /// Mirrors the `endpoint` SSE-event parsing logic for the JSON
    /// (`{"uri": ...}`) form and checks the stored endpoint is a valid URL.
    #[tokio::test]
    async fn test_endpoint_event_json_parsing() {
        use std::sync::Arc;
        use tokio::sync::RwLock;
        let message_endpoint = Arc::new(RwLock::new(None::<String>));
        let event_data = [r#"{"uri":"http://127.0.0.1:8080/mcp"}"#.to_string()];
        let data_str = event_data.join("\n");
        let endpoint_uri = if data_str.trim().starts_with('{') {
            let endpoint_json: serde_json::Value =
                serde_json::from_str(&data_str).expect("Failed to parse endpoint JSON");
            endpoint_json["uri"]
                .as_str()
                .expect("Missing uri field")
                .to_string()
        } else {
            data_str.clone()
        };
        *message_endpoint.write().await = Some(endpoint_uri.clone());
        let stored = message_endpoint.read().await;
        assert_eq!(stored.as_ref().unwrap(), "http://127.0.0.1:8080/mcp");
        assert!(stored.as_ref().unwrap().starts_with("http://"));
        assert!(stored.as_ref().unwrap().parse::<url::Url>().is_ok());
    }

    /// Mirrors the `endpoint` SSE-event parsing logic for the bare-string
    /// (non-JSON) form.
    #[tokio::test]
    async fn test_endpoint_event_plain_string_parsing() {
        use std::sync::Arc;
        use tokio::sync::RwLock;
        let message_endpoint = Arc::new(RwLock::new(None::<String>));
        let event_data = ["http://127.0.0.1:8080/mcp".to_string()];
        let data_str = event_data.join("\n");
        let endpoint_uri = if data_str.trim().starts_with('{') {
            let endpoint_json: serde_json::Value =
                serde_json::from_str(&data_str).expect("Failed to parse endpoint JSON");
            endpoint_json["uri"]
                .as_str()
                .expect("Missing uri field")
                .to_string()
        } else {
            data_str.clone()
        };
        *message_endpoint.write().await = Some(endpoint_uri.clone());
        let stored = message_endpoint.read().await;
        assert_eq!(stored.as_ref().unwrap(), "http://127.0.0.1:8080/mcp");
        assert!(stored.as_ref().unwrap().starts_with("http://"));
    }
}