remote_mcp_kernel/handlers/
sse_handler.rs1use axum::{Router, middleware};
8use rmcp::transport::sse_server::{SseServer, SseServerConfig};
9use tokio_util::sync::CancellationToken;
10
11use crate::{error::AppResult, handlers::McpServerHandler};
12use oauth_provider_rs::http_integration::middleware::simple_auth_middleware;
13
14#[derive(Clone)]
20pub struct SseHandler<M: McpServerHandler> {
21 mcp_server: M,
23}
24
/// Settings that control how [`SseHandler::router`] builds its router.
///
/// Implements `PartialEq`/`Eq` so whole configurations can be compared
/// directly (for example in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseHandlerConfig {
    /// Route path forwarded to the SSE transport as its `sse_path`.
    pub sse_path: String,
    /// Route path forwarded to the SSE transport as its `post_path`.
    pub message_path: String,
    /// Keep-alive interval, in seconds, forwarded as `sse_keep_alive`.
    pub keep_alive_seconds: u64,
    /// When `true`, the router is wrapped in `simple_auth_middleware`.
    pub require_auth: bool,
}
33
34impl Default for SseHandlerConfig {
35 fn default() -> Self {
36 Self {
37 sse_path: "/mcp/sse".to_string(),
38 message_path: "/mcp/message".to_string(),
39 keep_alive_seconds: 15,
40 require_auth: true,
41 }
42 }
43}
44
45impl SseHandlerConfig {
46 pub fn default_config() -> SseHandlerConfig {
48 SseHandlerConfig::default()
49 }
50
51 pub fn config_with_paths(
53 sse_path: impl Into<String>,
54 message_path: impl Into<String>,
55 ) -> SseHandlerConfig {
56 SseHandlerConfig {
57 sse_path: sse_path.into(),
58 message_path: message_path.into(),
59 ..Default::default()
60 }
61 }
62
63 pub fn config_without_auth() -> SseHandlerConfig {
65 SseHandlerConfig {
66 require_auth: false,
67 ..Default::default()
68 }
69 }
70}
71
72impl<M: McpServerHandler> SseHandler<M> {
73 pub fn new(mcp_server: M) -> Self {
75 Self { mcp_server }
76 }
77
78 pub fn router(&self, config: SseHandlerConfig) -> AppResult<Router> {
83 let sse_config = SseServerConfig {
84 bind: "0.0.0.0:0".parse().unwrap(), sse_path: config.sse_path,
86 post_path: config.message_path,
87 ct: CancellationToken::new(),
88 sse_keep_alive: Some(std::time::Duration::from_secs(config.keep_alive_seconds)),
89 };
90
91 let (sse_server, sse_router) = SseServer::new(sse_config);
92
93 let mcp_server = self.mcp_server.clone();
95 let _service_token = sse_server.with_service(move || mcp_server.clone());
96
97 let router = if config.require_auth {
99 sse_router.layer(middleware::from_fn(simple_auth_middleware))
100 } else {
101 sse_router
102 };
103
104 Ok(router)
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::McpServer;
    use oauth_provider_rs::{GitHubOAuthConfig, GitHubOAuthProvider, OAuthProvider};

    /// Builds the placeholder GitHub OAuth configuration shared by the tests.
    fn test_github_config() -> GitHubOAuthConfig {
        GitHubOAuthConfig {
            client_id: "test_client_id".to_string(),
            client_secret: "test_client_secret".to_string(),
            redirect_uri: "http://localhost:8080/oauth/callback".to_string(),
            scope: "read:user".to_string(),
            provider_name: "github".to_string(),
        }
    }

    /// Constructing a handler from an MCP server must not panic.
    #[test]
    fn test_sse_handler_creation() {
        // The provider is not consumed by `SseHandler`; it is still built so
        // the original test setup is preserved, and the leading underscore
        // silences the unused-variable warning.
        let _oauth_provider =
            OAuthProvider::new(GitHubOAuthProvider::new_github(test_github_config()));
        let mcp_server = McpServer::new();
        let _sse_handler = SseHandler::new(mcp_server);
    }

    /// `Default` yields the documented paths, keep-alive and auth flag.
    #[test]
    fn test_sse_config_defaults() {
        let config = SseHandlerConfig::default();

        assert_eq!(config.sse_path, "/mcp/sse");
        assert_eq!(config.message_path, "/mcp/message");
        assert_eq!(config.keep_alive_seconds, 15);
        assert!(config.require_auth);
    }

    /// Overriding the paths leaves every other field at its default.
    #[test]
    fn test_sse_config_custom_paths() {
        let config = SseHandlerConfig::config_with_paths("/custom/sse", "/custom/message");

        assert_eq!(config.sse_path, "/custom/sse");
        assert_eq!(config.message_path, "/custom/message");
        assert_eq!(config.keep_alive_seconds, 15);
        assert!(config.require_auth);
    }

    /// Disabling auth leaves every other field at its default.
    #[test]
    fn test_sse_config_without_auth() {
        let config = SseHandlerConfig::config_without_auth();

        assert!(!config.require_auth);
        assert_eq!(config.sse_path, "/mcp/sse");
        assert_eq!(config.message_path, "/mcp/message");
        assert_eq!(config.keep_alive_seconds, 15);
    }

    /// Building the router with the default (auth-enabled) config succeeds.
    #[tokio::test]
    async fn test_sse_router_creation() {
        // Built only to preserve the original setup; see
        // `test_sse_handler_creation`.
        let _oauth_provider =
            OAuthProvider::new(GitHubOAuthProvider::new_github(test_github_config()));
        let mcp_server = McpServer::new();
        let sse_handler = SseHandler::new(mcp_server);
        let config = SseHandlerConfig::default();

        let router_result = sse_handler.router(config);
        assert!(router_result.is_ok());
    }

    /// Building the router with authentication disabled also succeeds.
    #[tokio::test]
    async fn test_sse_router_creation_without_auth() {
        let sse_handler = SseHandler::new(McpServer::new());

        let router_result = sse_handler.router(SseHandlerConfig::config_without_auth());
        assert!(router_result.is_ok());
    }
}