// async_mcp/sse/http_server.rs

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use actix_web::middleware::Logger;
use actix_web::web::Payload;
use actix_web::web::Query;
use actix_web::{web, App, HttpResponse, HttpServer};
use anyhow::Result;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tracing::{debug, error, info, warn};
use uuid::Uuid;

use crate::server::Server;
use crate::sse::middleware::{AuthConfig, JwtAuth};
use crate::transport::ServerHttpTransport;
use crate::transport::{handle_ws_connection, Message, ServerSseTransport, ServerWsTransport};
18
19/// Server-side SSE transport that handles HTTP POST requests for incoming messages
20/// and sends responses via SSE
21#[derive(Debug, Serialize, Deserialize)]
22pub struct Claims {
23    pub exp: usize,
24    pub iat: usize,
25}
26
27#[derive(Deserialize)]
28pub struct MessageQuery {
29    #[serde(rename = "sessionId")]
30    session_id: Option<String>,
31}
32
33#[derive(Clone)]
34pub struct SessionState {
35    sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
36    port: u16,
37    build_server: Arc<
38        dyn Fn(
39                ServerHttpTransport,
40            )
41                -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
42            + Send
43            + Sync,
44    >,
45}
46
47/// Run a server instance with the specified transport
48pub async fn run_http_server<F, Fut>(
49    port: u16,
50    jwt_secret: Option<String>,
51    build_server: F,
52) -> Result<()>
53where
54    F: Fn(ServerHttpTransport) -> Fut + Send + Sync + 'static,
55    Fut: futures::Future<Output = Result<Server<ServerHttpTransport>>> + Send + 'static,
56{
57    info!("Starting server on http://127.0.0.1:{}", port);
58    info!("WebSocket endpoint: ws://127.0.0.1:{}/ws", port);
59    info!("SSE endpoint: http://127.0.0.1:{}/sse", port);
60
61    let sessions = Arc::new(Mutex::new(HashMap::new()));
62
63    // Box the future when creating the Arc
64    let build_server =
65        Arc::new(move |t| Box::pin(build_server(t)) as futures::future::BoxFuture<_>);
66
67    let auth_config = jwt_secret.map(|jwt_secret| AuthConfig { jwt_secret });
68    let http_server = http_server(port, sessions, auth_config, build_server);
69
70    http_server.await?;
71    Ok(())
72}
73
74pub async fn http_server(
75    port: u16,
76    sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
77    auth_config: Option<AuthConfig>,
78    build_server: Arc<
79        dyn Fn(
80                ServerHttpTransport,
81            )
82                -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
83            + Send
84            + Sync,
85    >,
86) -> std::result::Result<(), std::io::Error> {
87    let session_state = SessionState {
88        sessions,
89        build_server,
90        port,
91    };
92
93    let server = HttpServer::new(move || {
94        let session_state = session_state.clone();
95        App::new()
96            .wrap(Logger::default())
97            .wrap(JwtAuth::new(auth_config.clone()))
98            .app_data(web::Data::new(session_state))
99            .route("/sse", web::get().to(sse_handler))
100            .route("/message", web::post().to(message_handler))
101            .route("/ws", web::get().to(ws_handler))
102    })
103    .bind(("127.0.0.1", port))?
104    .run();
105
106    server.await
107}
108
109pub async fn sse_handler(
110    req: actix_web::HttpRequest,
111    session_state: web::Data<SessionState>,
112) -> HttpResponse {
113    let client_ip = req
114        .peer_addr()
115        .map(|addr| addr.ip().to_string())
116        .unwrap_or_else(|| "unknown".to_string());
117
118    info!("New SSE connection request from {}", client_ip);
119
120    // Create new session
121    let session_id = Uuid::new_v4().to_string();
122
123    // Create channel for SSE messages
124    let (sse_tx, sse_rx) = broadcast::channel(100);
125
126    // Create new transport for this session
127    let transport = ServerHttpTransport::Sse(ServerSseTransport::new(sse_tx.clone()));
128
129    // Store transport in sessions map
130    session_state
131        .sessions
132        .lock()
133        .unwrap()
134        .insert(session_id.clone(), transport.clone());
135
136    info!(
137        "SSE connection established for {} with session_id {}",
138        client_ip, session_id
139    );
140    let port = session_state.port;
141    // Create initial endpoint info event
142    let endpoint_info = format!(
143        "event: endpoint\ndata: http://127.0.0.1:{port}/message?sessionId={session_id}\n\n",
144    );
145
146    let stream = futures::stream::once(async move {
147        Ok::<_, std::convert::Infallible>(web::Bytes::from(endpoint_info))
148    })
149    .chain(futures::stream::unfold(sse_rx, move |mut rx| {
150        let client_ip = client_ip.clone();
151        async move {
152            match rx.recv().await {
153                Ok(msg) => {
154                    debug!("Sending SSE message to {}: {:?}", client_ip, msg);
155                    let json = serde_json::to_string(&msg).unwrap();
156                    let sse_data = format!("data: {}\n\n", json);
157                    Some((
158                        Ok::<_, std::convert::Infallible>(web::Bytes::from(sse_data)),
159                        rx,
160                    ))
161                }
162                _ => None,
163            }
164        }
165    }));
166
167    // Create and start server instance for this session
168    let transport_clone = transport.clone();
169    let build_server = session_state.build_server.clone();
170    tokio::spawn(async move {
171        match build_server(transport_clone).await {
172            Ok(server) => {
173                if let Err(e) = server.listen().await {
174                    error!("Server error: {:?}", e);
175                }
176            }
177            Err(e) => {
178                error!("Failed to build server: {:?}", e);
179            }
180        }
181    });
182
183    HttpResponse::Ok()
184        .append_header(("X-Session-Id", session_id))
185        .content_type("text/event-stream")
186        .streaming(stream)
187}
188
189async fn message_handler(
190    query: Query<MessageQuery>,
191    message: web::Json<Message>,
192    session_state: web::Data<SessionState>,
193) -> HttpResponse {
194    if let Some(session_id) = &query.session_id {
195        let sessions = session_state.sessions.lock().unwrap();
196        if let Some(transport) = sessions.get(session_id) {
197            match transport {
198                ServerHttpTransport::Sse(sse) => match sse.send_message(message.into_inner()).await
199                {
200                    Ok(_) => {
201                        debug!("Successfully sent message to session {}", session_id);
202                        HttpResponse::Accepted().finish()
203                    }
204                    Err(e) => {
205                        error!("Failed to send message to session {}: {:?}", session_id, e);
206                        HttpResponse::InternalServerError().finish()
207                    }
208                },
209                ServerHttpTransport::Ws(_) => HttpResponse::BadRequest()
210                    .body("Cannot send message to WebSocket connection through HTTP endpoint"),
211            }
212        } else {
213            HttpResponse::NotFound().body(format!("Session {} not found", session_id))
214        }
215    } else {
216        HttpResponse::BadRequest().body("Session ID not specified")
217    }
218}
219
220async fn ws_handler(
221    req: actix_web::HttpRequest,
222    body: Payload,
223    session_state: web::Data<SessionState>,
224) -> Result<HttpResponse, actix_web::Error> {
225    let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
226
227    let client_ip = req
228        .peer_addr()
229        .map(|addr| addr.ip().to_string())
230        .unwrap_or_else(|| "unknown".to_string());
231
232    info!("New WebSocket connection from {}", client_ip);
233
234    // Create channels for message passing
235    let (tx, rx) = broadcast::channel(100);
236    let transport =
237        ServerHttpTransport::Ws(ServerWsTransport::new(session.clone(), rx.resubscribe()));
238
239    // Store transport in sessions map
240    let session_id = Uuid::new_v4().to_string();
241    session_state
242        .sessions
243        .lock()
244        .unwrap()
245        .insert(session_id, transport.clone());
246
247    // Start WebSocket handling in the background
248    actix_web::rt::spawn(async move {
249        let _ = handle_ws_connection(session, msg_stream, tx.clone(), rx.resubscribe()).await;
250    });
251
252    // Spawn server instance
253    let build_server = session_state.build_server.clone();
254    actix_web::rt::spawn(async move {
255        if let Ok(server) = build_server(transport).await {
256            let _ = server.listen().await;
257        }
258    });
259
260    Ok(response)
261}