use futures::StreamExt;
use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderValue};
use std::collections::HashMap;
use std::time::Duration;
use tracing::debug;
pub async fn is_sse(url: &str) -> bool {
is_sse_with_headers(url, None).await
}
pub async fn is_sse_with_headers(
url: &str,
custom_headers: Option<&HashMap<String, String>>,
) -> bool {
let client = match reqwest::Client::builder()
.connect_timeout(Duration::from_secs(5))
.build()
{
Ok(c) => c,
Err(_) => return false,
};
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
if let Some(custom) = custom_headers {
for (key, value) in custom {
if let (Ok(name), Ok(val)) = (
reqwest::header::HeaderName::try_from(key.as_str()),
HeaderValue::from_str(value),
) {
headers.insert(name, val);
}
}
}
let trimmed = url.trim_end_matches('/');
let candidates: Vec<String> = if trimmed.ends_with("/sse") {
vec![url.to_string()]
} else {
vec![format!("{}/sse", trimmed), url.to_string()]
};
for probe_url in &candidates {
debug!("SSE probe: try {}", probe_url);
match tokio::time::timeout(
Duration::from_secs(5),
probe_sse_endpoint(&client, probe_url, &headers),
)
.await
{
Ok(true) => {
debug!(
"SSE probe: Confirm {} is MCP SSE protocol (discover endpoint event)",
probe_url
);
return true;
}
Ok(false) => {
debug!("SSE probe: {} is not MCP SSE protocol", probe_url);
}
Err(_) => {
debug!("SSE probe: {} timeout", probe_url);
}
}
}
false
}
async fn probe_sse_endpoint(client: &reqwest::Client, url: &str, headers: &HeaderMap) -> bool {
let response = match client.get(url).headers(headers.clone()).send().await {
Ok(r) => r,
Err(_) => return false,
};
if !response.status().is_success() {
return false;
}
let is_event_stream = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|ct| ct.contains("text/event-stream"));
if !is_event_stream {
return false;
}
read_sse_for_endpoint_event(response).await
}
/// Reads the beginning of an SSE response body and reports whether it
/// announces an `endpoint` event.
///
/// Returns `true` as soon as a complete `event: endpoint` (or `event:endpoint`)
/// line has been received. Returns `false` on a transport error, when the
/// stream ends, or once more than `MAX_BYTES` bytes have been read without a
/// match.
///
/// Bytes are accumulated raw rather than decoded chunk by chunk. A network
/// chunk can end in the middle of a multi-byte UTF-8 character, and decoding
/// each chunk separately would silently discard it.
async fn read_sse_for_endpoint_event(response: reqwest::Response) -> bool {
    const MAX_BYTES: usize = 4096;
    let mut stream = response.bytes_stream();
    let mut buffer: Vec<u8> = Vec::new();
    while let Some(chunk) = stream.next().await {
        let Ok(bytes) = chunk else {
            return false;
        };
        buffer.extend_from_slice(&bytes);
        if has_endpoint_event_line(&buffer) {
            return true;
        }
        if buffer.len() > MAX_BYTES {
            return false;
        }
    }
    false
}

/// Returns `true` if some fully terminated line of `buf` is an SSE `event`
/// field whose value is exactly `endpoint`.
///
/// Lines may end in `\n`, `\r\n` or `\r`, and a single leading UTF-8 BOM is
/// ignored, as in SSE stream parsing.
fn has_endpoint_event_line(buf: &[u8]) -> bool {
    let buf = buf.strip_prefix(b"\xEF\xBB\xBF").unwrap_or(buf);
    // Only text before the last terminator is complete. The trailing segment may
    // still grow (e.g. `event: endpoint` could become `event: endpoints`).
    let Some(end) = buf.iter().rposition(|&b| b == b'\n' || b == b'\r') else {
        return false;
    };
    buf[..end]
        .split(|&b| b == b'\n' || b == b'\r')
        .any(|line| line == b"event: endpoint" || line == b"event:endpoint")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;

    /// Starts a throwaway HTTP server on an ephemeral localhost port and returns
    /// the port. The server answers every request with `200 OK`, the given
    /// `Content-Type` and `body`, then closes the connection.
    fn spawn_mock_server(content_type: &'static str, body: &'static str) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind mock server");
        let port = listener.local_addr().expect("mock server address").port();
        std::thread::spawn(move || {
            for stream in listener.incoming() {
                let Ok(mut stream) = stream else { continue };
                // Consume the request head first. Closing a socket with unread
                // input can make the kernel send RST and drop the response.
                let mut request = Vec::new();
                let mut chunk = [0u8; 1024];
                while !request.windows(4).any(|w| w == b"\r\n\r\n") {
                    match stream.read(&mut chunk) {
                        Ok(0) | Err(_) => break,
                        Ok(n) => request.extend_from_slice(&chunk[..n]),
                    }
                }
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: {content_type}\r\nConnection: close\r\n\r\n{body}"
                );
                let _ = stream.write_all(response.as_bytes());
            }
        });
        port
    }

    // Port 99999 is outside the valid TCP port range, so the four tests below
    // exercise the "URL rejected before any connection" failure path.
    #[tokio::test]
    async fn test_is_sse_nonexistent_server() {
        let result = is_sse("http://localhost:99999/mcp").await;
        assert!(!result);
    }

    #[tokio::test]
    async fn test_is_sse_with_headers_no_panic() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer test-token".to_string());
        let result = is_sse_with_headers("http://localhost:99999/mcp", Some(&headers)).await;
        assert!(!result);
    }

    #[tokio::test]
    async fn test_candidate_urls_with_sse_suffix() {
        let result = is_sse("http://localhost:99999/sse").await;
        assert!(!result);
    }

    #[tokio::test]
    async fn test_candidate_urls_with_trailing_slash() {
        let result = is_sse("http://localhost:99999/mcp/").await;
        assert!(!result);
    }

    #[tokio::test]
    async fn test_is_sse_connection_refused() {
        // Grab a free port, then release it so that connecting is refused.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
            listener.local_addr().expect("ephemeral address").port()
        };
        assert!(!is_sse(&format!("http://127.0.0.1:{port}/mcp")).await);
    }

    #[tokio::test]
    async fn test_is_sse_detects_endpoint_event() {
        let port = spawn_mock_server(
            "text/event-stream",
            "event: endpoint\ndata: /messages?session_id=1\n\n",
        );
        assert!(is_sse(&format!("http://127.0.0.1:{port}/mcp")).await);
    }

    #[tokio::test]
    async fn test_is_sse_rejects_stream_without_endpoint_event() {
        let port = spawn_mock_server("text/event-stream", "event: message\ndata: hello\n\n");
        assert!(!is_sse(&format!("http://127.0.0.1:{port}/mcp")).await);
    }

    #[tokio::test]
    async fn test_is_sse_rejects_non_event_stream() {
        let port = spawn_mock_server("application/json", "{}");
        assert!(!is_sse(&format!("http://127.0.0.1:{port}/mcp")).await);
    }
}