remote_mcp_kernel/handlers/
sse_handler.rs1use axum::{Router, middleware};
8use rmcp::transport::sse_server::{SseServer, SseServerConfig};
9use tokio_util::sync::CancellationToken;
10
11use crate::{error::AppResult, handlers::McpServerHandler};
12use oauth_provider_rs::http_integration::middleware::simple_auth_middleware;
13
14#[derive(Clone)]
20pub struct SseHandler<M: McpServerHandler> {
21 mcp_server: M,
23}
24
/// Settings that control how the SSE router is built.
#[derive(Debug, Clone)]
pub struct SseHandlerConfig {
    /// Route for the SSE event stream (handed to rmcp as `sse_path`).
    pub sse_path: String,
    /// Route that receives client messages (handed to rmcp as `post_path`).
    pub message_path: String,
    /// Keep-alive period for the SSE stream, expressed in whole seconds.
    pub keep_alive_seconds: u64,
    /// When `true`, the router is wrapped in `simple_auth_middleware`.
    pub require_auth: bool,
}

impl Default for SseHandlerConfig {
    /// Routes `/mcp/sse` and `/mcp/message`, a 15-second keep-alive, and
    /// authentication switched on.
    fn default() -> Self {
        let sse_path = String::from("/mcp/sse");
        let message_path = String::from("/mcp/message");
        Self {
            require_auth: true,
            keep_alive_seconds: 15,
            message_path,
            sse_path,
        }
    }
}

impl SseHandlerConfig {
    /// Convenience alias for [`Default::default`].
    pub fn default_config() -> SseHandlerConfig {
        <Self as Default>::default()
    }

    /// Default settings, but with caller-chosen stream and message routes.
    pub fn config_with_paths(
        sse_path: impl Into<String>,
        message_path: impl Into<String>,
    ) -> SseHandlerConfig {
        let mut config = Self::default();
        config.sse_path = sse_path.into();
        config.message_path = message_path.into();
        config
    }

    /// Default settings, but with authentication switched off.
    pub fn config_without_auth() -> SseHandlerConfig {
        let mut config = Self::default();
        config.require_auth = false;
        config
    }
}
71
72impl<M: McpServerHandler> SseHandler<M> {
73 pub fn new(mcp_server: M) -> Self {
75 Self { mcp_server }
76 }
77
78 pub fn router(&self, config: SseHandlerConfig) -> AppResult<Router> {
83 let sse_config = SseServerConfig {
84 bind: "0.0.0.0:0".parse().unwrap(), sse_path: config.sse_path,
86 post_path: config.message_path,
87 ct: CancellationToken::new(),
88 sse_keep_alive: Some(std::time::Duration::from_secs(config.keep_alive_seconds)),
89 };
90
91 let (sse_server, sse_router) = SseServer::new(sse_config);
92
93 let mcp_server = self.mcp_server.clone();
95 let _service_token = sse_server.with_service(move || mcp_server.clone());
96
97 let router = if config.require_auth {
99 sse_router.layer(middleware::from_fn(simple_auth_middleware))
100 } else {
101 sse_router
102 };
103
104 Ok(router)
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::McpServer;

    #[test]
    fn test_sse_handler_creation() {
        let _sse_handler = SseHandler::new(McpServer::new());
    }

    #[test]
    fn test_sse_config_defaults() {
        let config = SseHandlerConfig::default();

        assert_eq!(config.sse_path, "/mcp/sse");
        assert_eq!(config.message_path, "/mcp/message");
        assert_eq!(config.keep_alive_seconds, 15);
        assert!(config.require_auth);
    }

    #[test]
    fn test_sse_config_custom_paths() {
        let config = SseHandlerConfig::config_with_paths("/custom/sse", "/custom/message");

        assert_eq!(config.sse_path, "/custom/sse");
        assert_eq!(config.message_path, "/custom/message");
        assert!(config.require_auth);
    }

    #[test]
    fn test_sse_config_without_auth() {
        let config = SseHandlerConfig::config_without_auth();

        assert!(!config.require_auth);
        assert_eq!(config.sse_path, "/mcp/sse");
    }

    // NOTE(review): router construction appears to need a Tokio runtime
    // (rmcp spawns its session task), hence `#[tokio::test]` below.
    #[tokio::test]
    async fn test_sse_router_creation() {
        let sse_handler = SseHandler::new(McpServer::new());
        let config = SseHandlerConfig::default();

        let router_result = sse_handler.router(config);
        assert!(router_result.is_ok());
    }

    #[tokio::test]
    async fn test_sse_router_creation_without_auth() {
        let sse_handler = SseHandler::new(McpServer::new());

        let router_result = sse_handler.router(SseHandlerConfig::config_without_auth());
        assert!(router_result.is_ok());
    }

    #[tokio::test]
    async fn test_sse_router_creation_with_zero_keep_alive() {
        let sse_handler = SseHandler::new(McpServer::new());
        let config = SseHandlerConfig {
            keep_alive_seconds: 0,
            ..SseHandlerConfig::default()
        };

        let router_result = sse_handler.router(config);
        assert!(router_result.is_ok());
    }
}