use crate::hyper_servers::error::TransportServerResult;
use crate::mcp_http::{McpAppState, McpHttpHandler};
use axum::{extract::State, response::IntoResponse, routing::get, Extension, Router};
use http::{HeaderMap, Method, Uri};
use std::sync::Arc;
/// Path string for the SSE message endpoint, wrapped in a newtype so it can be
/// injected into [`handle_sse`] through an axum `Extension` layer (see [`routes`]).
///
/// NOTE(review): presumably the path clients post messages to for an SSE session —
/// confirm against `McpHttpHandler::handle_sse_connection`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseMessageEndpoint(pub String);
/// Builds the router that serves [`handle_sse`] for `GET` requests on `sse_endpoint`.
///
/// The `sse_message_endpoint` path is wrapped in [`SseMessageEndpoint`] and attached to
/// that single route as an `Extension` layer, which is where [`handle_sse`] extracts it from.
pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router<Arc<McpAppState>> {
    let endpoint_ext = Extension(SseMessageEndpoint(sse_message_endpoint.to_owned()));
    let sse_get = get(handle_sse).layer(endpoint_ext);

    Router::new().route(sse_endpoint, sse_get)
}
/// Axum handler for `GET` requests on the SSE endpoint.
///
/// Rebuilds a `GET` request from the incoming `uri` and `headers` via
/// `McpHttpHandler::create_request` (passing `None` as the last argument — presumably
/// "no body"; TODO confirm). It then passes that request, the app state and the configured
/// message endpoint to `handle_sse_connection`. Finally it converts the returned response
/// into an axum `Response` by reusing its parts and wrapping its body in `axum::body::Body`.
///
/// # Errors
///
/// Any error from `handle_sse_connection` is propagated with `?` as the error of the
/// returned [`TransportServerResult`].
pub async fn handle_sse(
    headers: HeaderMap,
    uri: Uri,
    Extension(SseMessageEndpoint(sse_message_endpoint)): Extension<SseMessageEndpoint>,
    Extension(http_handler): Extension<Arc<McpHttpHandler>>,
    State(state): State<Arc<McpAppState>>,
) -> TransportServerResult<impl IntoResponse> {
    let request = McpHttpHandler::create_request(Method::GET, uri, headers, None);

    // `state` is not used after this call, so the `Arc` is moved rather than cloned.
    let generic_response = http_handler
        .handle_sse_connection(request, state, Some(&sse_message_endpoint))
        .await?;

    let (parts, body) = generic_response.into_parts();
    Ok(axum::response::Response::from_parts(
        parts,
        axum::body::Body::new(body),
    ))
}