// remote_mcp_kernel/handlers/sse_handler.rs
use axum::{Router, middleware};
use oauth_provider_rs::http_integration::middleware::simple_auth_middleware;
use rmcp::transport::sse_server::{SseServer, SseServerConfig};
use tokio_util::sync::CancellationToken;

use crate::{error::AppResult, handlers::McpServerHandler};

14#[derive(Clone)]
20pub struct SseHandler<M: McpServerHandler> {
21    mcp_server: M,
23}
24
/// Settings for the routes built by [`SseHandler::router`].
///
/// Start from [`Default`] (or one of the `config_*` constructors) and use
/// struct-update syntax to override individual fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseHandlerConfig {
    /// Route path of the SSE stream endpoint; forwarded to rmcp as `sse_path`.
    pub sse_path: String,
    /// Route path clients POST messages to; forwarded to rmcp as `post_path`.
    pub message_path: String,
    /// Keep-alive interval for the SSE stream, in seconds.
    pub keep_alive_seconds: u64,
    /// When `true`, `simple_auth_middleware` is layered onto the router.
    pub require_auth: bool,
}

impl Default for SseHandlerConfig {
    /// `/mcp/sse` and `/mcp/message` paths, a 15-second keep-alive, auth required.
    fn default() -> Self {
        Self {
            sse_path: "/mcp/sse".to_string(),
            message_path: "/mcp/message".to_string(),
            keep_alive_seconds: 15,
            require_auth: true,
        }
    }
}

impl SseHandlerConfig {
    /// Returns the same value as [`SseHandlerConfig::default`].
    #[must_use]
    pub fn default_config() -> SseHandlerConfig {
        Self::default()
    }

    /// Default settings with custom SSE and message paths.
    ///
    /// Keep-alive and authentication keep their [`Default`] values.
    #[must_use]
    pub fn config_with_paths(
        sse_path: impl Into<String>,
        message_path: impl Into<String>,
    ) -> SseHandlerConfig {
        Self {
            sse_path: sse_path.into(),
            message_path: message_path.into(),
            ..Self::default()
        }
    }

    /// Default settings with authentication disabled (`require_auth = false`).
    #[must_use]
    pub fn config_without_auth() -> SseHandlerConfig {
        Self {
            require_auth: false,
            ..Self::default()
        }
    }
}

72impl<M: McpServerHandler> SseHandler<M> {
73    pub fn new(mcp_server: M) -> Self {
75        Self { mcp_server }
76    }
77
78    pub fn router(&self, config: SseHandlerConfig) -> AppResult<Router> {
83        let sse_config = SseServerConfig {
84            bind: "0.0.0.0:0".parse().unwrap(), sse_path: config.sse_path,
86            post_path: config.message_path,
87            ct: CancellationToken::new(),
88            sse_keep_alive: Some(std::time::Duration::from_secs(config.keep_alive_seconds)),
89        };
90
91        let (sse_server, sse_router) = SseServer::new(sse_config);
92
93        let mcp_server = self.mcp_server.clone();
95        let _service_token = sse_server.with_service(move || mcp_server.clone());
96
97        let router = if config.require_auth {
99            sse_router.layer(middleware::from_fn(simple_auth_middleware))
100        } else {
101            sse_router
102        };
103
104        Ok(router)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::McpServer;

    #[test]
    fn test_sse_handler_creation() {
        let mcp_server = McpServer::new();
        let _sse_handler = SseHandler::new(mcp_server);
    }

    #[test]
    fn test_sse_config_defaults() {
        let config = SseHandlerConfig::default();

        assert_eq!(config.sse_path, "/mcp/sse");
        assert_eq!(config.message_path, "/mcp/message");
        assert_eq!(config.keep_alive_seconds, 15);
        assert!(config.require_auth);
    }

    #[test]
    fn test_sse_config_custom_paths() {
        let config = SseHandlerConfig::config_with_paths("/custom/sse", "/custom/message");

        assert_eq!(config.sse_path, "/custom/sse");
        assert_eq!(config.message_path, "/custom/message");
        assert!(config.require_auth);
    }

    #[test]
    fn test_sse_config_without_auth() {
        let config = SseHandlerConfig::config_without_auth();

        assert!(!config.require_auth);
        assert_eq!(config.sse_path, "/mcp/sse");
    }

    // Router construction hands the service to rmcp, so these run under a Tokio
    // runtime (`#[tokio::test]`).
    #[tokio::test]
    async fn test_sse_router_creation() {
        let sse_handler = SseHandler::new(McpServer::new());
        let config = SseHandlerConfig::default();

        let router_result = sse_handler.router(config);
        assert!(router_result.is_ok());
    }

    #[tokio::test]
    async fn test_sse_router_creation_without_auth() {
        let sse_handler = SseHandler::new(McpServer::new());

        let router_result = sse_handler.router(SseHandlerConfig::config_without_auth());
        assert!(router_result.is_ok());
    }

    #[tokio::test]
    async fn test_sse_router_creation_with_zero_keep_alive() {
        let sse_handler = SseHandler::new(McpServer::new());
        let config = SseHandlerConfig {
            keep_alive_seconds: 0,
            ..SseHandlerConfig::default()
        };

        let router_result = sse_handler.router(config);
        assert!(router_result.is_ok());
    }
}