// remote_mcp_kernel/handlers/sse_handler.rs
use std::{
    net::{Ipv4Addr, SocketAddr},
    time::Duration,
};

use axum::{Router, middleware};
use oauth_provider_rs::{
    DefaultClientManager, OAuthProviderTrait, OAuthStorage,
    http_integration::middleware::simple_auth_middleware,
};
use rmcp::transport::sse_server::{SseServer, SseServerConfig};
use tokio_util::sync::CancellationToken;

use crate::{error::AppResult, handlers::McpServer};

14#[derive(Clone)]
20pub struct SseHandler<
21 P: OAuthProviderTrait<S, DefaultClientManager<S>> + 'static,
22 S: OAuthStorage + Clone + 'static,
23> {
24 mcp_server: McpServer<P, S>,
26}
27
/// Settings used by `SseHandler::router` to build the SSE routes.
///
/// See the `Default` impl for the standard values. `PartialEq`/`Eq` are derived so
/// whole configurations can be compared (e.g. `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseHandlerConfig {
    /// Path of the SSE stream endpoint (passed to rmcp as `SseServerConfig::sse_path`).
    pub sse_path: String,
    /// Path of the message endpoint (passed to rmcp as `SseServerConfig::post_path`).
    pub message_path: String,
    /// Interval, in seconds, between SSE keep-alive messages.
    pub keep_alive_seconds: u64,
    /// When `true`, `simple_auth_middleware` is layered over the SSE routes.
    pub require_auth: bool,
}

37impl Default for SseHandlerConfig {
38 fn default() -> Self {
39 Self {
40 sse_path: "/mcp/sse".to_string(),
41 message_path: "/mcp/message".to_string(),
42 keep_alive_seconds: 15,
43 require_auth: true,
44 }
45 }
46}
47
48impl SseHandlerConfig {
49 pub fn default_config() -> SseHandlerConfig {
51 SseHandlerConfig::default()
52 }
53
54 pub fn config_with_paths(
56 sse_path: impl Into<String>,
57 message_path: impl Into<String>,
58 ) -> SseHandlerConfig {
59 SseHandlerConfig {
60 sse_path: sse_path.into(),
61 message_path: message_path.into(),
62 ..Default::default()
63 }
64 }
65
66 pub fn config_without_auth() -> SseHandlerConfig {
68 SseHandlerConfig {
69 require_auth: false,
70 ..Default::default()
71 }
72 }
73}
74
75impl<P: OAuthProviderTrait<S, DefaultClientManager<S>> + 'static, S: OAuthStorage + Clone + 'static>
76 SseHandler<P, S>
77{
78 pub fn new(mcp_server: McpServer<P, S>) -> Self {
80 Self { mcp_server }
81 }
82
83 pub fn router(&self, config: SseHandlerConfig) -> AppResult<Router> {
88 let sse_config = SseServerConfig {
89 bind: "0.0.0.0:0".parse().unwrap(), sse_path: config.sse_path,
91 post_path: config.message_path,
92 ct: CancellationToken::new(),
93 sse_keep_alive: Some(std::time::Duration::from_secs(config.keep_alive_seconds)),
94 };
95
96 let (sse_server, sse_router) = SseServer::new(sse_config);
97
98 let mcp_server = self.mcp_server.clone();
100 let _service_token = sse_server.with_service(move || mcp_server.clone());
101
102 let router = if config.require_auth {
104 sse_router.layer(middleware::from_fn(simple_auth_middleware))
105 } else {
106 sse_router
107 };
108
109 Ok(router)
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use oauth_provider_rs::{GitHubOAuthConfig, GitHubOAuthProvider, OAuthProvider};

    /// GitHub OAuth settings shared by the tests that need a real `McpServer`.
    fn github_test_config() -> GitHubOAuthConfig {
        GitHubOAuthConfig {
            client_id: "test_client_id".to_string(),
            client_secret: "test_client_secret".to_string(),
            redirect_uri: "http://localhost:8080/oauth/callback".to_string(),
            scope: "read:user".to_string(),
            provider_name: "github".to_string(),
        }
    }

    #[test]
    fn test_sse_handler_creation() {
        let provider = OAuthProvider::new(GitHubOAuthProvider::new_github(github_test_config()));
        let server = McpServer::new(provider);
        let _handler = SseHandler::new(server);
    }

    #[test]
    fn test_sse_config_defaults() {
        let cfg = SseHandlerConfig::default();

        assert_eq!(cfg.sse_path, "/mcp/sse");
        assert_eq!(cfg.message_path, "/mcp/message");
        assert_eq!(cfg.keep_alive_seconds, 15);
        assert!(cfg.require_auth);
    }

    #[test]
    fn test_sse_config_custom_paths() {
        let cfg = SseHandlerConfig::config_with_paths("/custom/sse", "/custom/message");

        assert_eq!(cfg.sse_path, "/custom/sse");
        assert_eq!(cfg.message_path, "/custom/message");
        assert!(cfg.require_auth);
    }

    #[test]
    fn test_sse_config_without_auth() {
        let cfg = SseHandlerConfig::config_without_auth();

        assert!(!cfg.require_auth);
        assert_eq!(cfg.sse_path, "/mcp/sse");
    }

    #[tokio::test]
    async fn test_sse_router_creation() {
        let provider = OAuthProvider::new(GitHubOAuthProvider::new_github(github_test_config()));
        let server = McpServer::new(provider);
        let handler = SseHandler::new(server);

        let outcome = handler.router(SseHandlerConfig::default());
        assert!(outcome.is_ok());
    }
}