use axum::{Router, middleware};
use rmcp::transport::sse_server::{SseServer, SseServerConfig};
use tokio_util::sync::CancellationToken;
use crate::{error::AppResult, handlers::McpServerHandler};
use oauth_provider_rs::http_integration::middleware::simple_auth_middleware;
/// Exposes an MCP server over the rmcp SSE transport as an axum [`Router`].
///
/// Construct it with [`SseHandler::new`] and obtain the routes via
/// [`SseHandler::router`].
#[derive(Clone)]
pub struct SseHandler<M: McpServerHandler> {
    /// Server implementation; the service factory handed to rmcp's
    /// `SseServer::with_service` clones it each time it is invoked.
    mcp_server: M,
}
/// Settings controlling the routes and behaviour of the SSE router built by
/// `SseHandler::router`.
///
/// Derives `PartialEq`/`Eq` so whole configurations can be compared (e.g. in
/// tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseHandlerConfig {
    /// Route at which clients open the SSE event stream.
    pub sse_path: String,
    /// Route at which clients POST messages (passed to rmcp as `post_path`).
    pub message_path: String,
    /// Keep-alive interval for the SSE stream, in seconds.
    pub keep_alive_seconds: u64,
    /// When `true`, the router is wrapped in `simple_auth_middleware`.
    pub require_auth: bool,
}
impl Default for SseHandlerConfig {
    /// Default layout: stream at `/mcp/sse`, messages at `/mcp/message`,
    /// 15-second keep-alive, authentication required.
    fn default() -> Self {
        let sse_path = String::from("/mcp/sse");
        let message_path = String::from("/mcp/message");
        Self {
            sse_path,
            message_path,
            keep_alive_seconds: 15,
            require_auth: true,
        }
    }
}
impl SseHandlerConfig {
pub fn default_config() -> SseHandlerConfig {
SseHandlerConfig::default()
}
pub fn config_with_paths(
sse_path: impl Into<String>,
message_path: impl Into<String>,
) -> SseHandlerConfig {
SseHandlerConfig {
sse_path: sse_path.into(),
message_path: message_path.into(),
..Default::default()
}
}
pub fn config_without_auth() -> SseHandlerConfig {
SseHandlerConfig {
require_auth: false,
..Default::default()
}
}
}
impl<M: McpServerHandler> SseHandler<M> {
pub fn new(mcp_server: M) -> Self {
Self { mcp_server }
}
pub fn router(&self, config: SseHandlerConfig) -> AppResult<Router> {
let sse_config = SseServerConfig {
bind: "0.0.0.0:0".parse().unwrap(), sse_path: config.sse_path,
post_path: config.message_path,
ct: CancellationToken::new(),
sse_keep_alive: Some(std::time::Duration::from_secs(config.keep_alive_seconds)),
};
let (sse_server, sse_router) = SseServer::new(sse_config);
let mcp_server = self.mcp_server.clone();
let _service_token = sse_server.with_service(move || mcp_server.clone());
let router = if config.require_auth {
sse_router.layer(middleware::from_fn(simple_auth_middleware))
} else {
sse_router
};
Ok(router)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::McpServer;

    /// Constructing a handler must not require any OAuth setup.
    #[test]
    fn test_sse_handler_creation() {
        let mcp_server = McpServer::new();
        let _sse_handler = SseHandler::new(mcp_server);
    }

    #[test]
    fn test_sse_config_defaults() {
        let config = SseHandlerConfig::default();
        assert_eq!(config.sse_path, "/mcp/sse");
        assert_eq!(config.message_path, "/mcp/message");
        assert_eq!(config.keep_alive_seconds, 15);
        assert!(config.require_auth);
    }

    #[test]
    fn test_sse_default_config_matches_default() {
        let from_helper = SseHandlerConfig::default_config();
        let from_trait = SseHandlerConfig::default();
        assert_eq!(from_helper.sse_path, from_trait.sse_path);
        assert_eq!(from_helper.message_path, from_trait.message_path);
        assert_eq!(from_helper.keep_alive_seconds, from_trait.keep_alive_seconds);
        assert_eq!(from_helper.require_auth, from_trait.require_auth);
    }

    #[test]
    fn test_sse_config_custom_paths() {
        let config = SseHandlerConfig::config_with_paths("/custom/sse", "/custom/message");
        assert_eq!(config.sse_path, "/custom/sse");
        assert_eq!(config.message_path, "/custom/message");
        assert_eq!(config.keep_alive_seconds, 15);
        assert!(config.require_auth);
    }

    #[test]
    fn test_sse_config_without_auth() {
        let config = SseHandlerConfig::config_without_auth();
        assert!(!config.require_auth);
        assert_eq!(config.sse_path, "/mcp/sse");
        assert_eq!(config.message_path, "/mcp/message");
    }

    /// `router` starts the SSE service, so it needs a Tokio runtime.
    #[tokio::test]
    async fn test_sse_router_creation() {
        let sse_handler = SseHandler::new(McpServer::new());
        let router_result = sse_handler.router(SseHandlerConfig::default());
        assert!(router_result.is_ok());
    }

    /// Exercises the branch that skips the auth middleware layer.
    #[tokio::test]
    async fn test_sse_router_creation_without_auth() {
        let sse_handler = SseHandler::new(McpServer::new());
        let router_result = sse_handler.router(SseHandlerConfig::config_without_auth());
        assert!(router_result.is_ok());
    }
}