async_mcp/sse/
http_server.rs

1use actix_web::middleware::Logger;
2use actix_web::web::Payload;
3use actix_web::web::Query;
4use actix_web::HttpMessage;
5use actix_web::{web, App, HttpResponse, HttpServer};
6use anyhow::Result;
7use futures::StreamExt;
8use uuid::Uuid;
9
10use crate::server::Server;
11use crate::sse::middleware::{AuthConfig, JwtAuth};
12use crate::transport::ServerHttpTransport;
13use crate::transport::{handle_ws_connection, Message, ServerSseTransport, ServerWsTransport};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::sync::{Arc, Mutex};
18use tokio::sync::broadcast;
19use tracing::{debug, error, info};
20
/// Serializable claim set consisting of two numeric fields, `exp` and `iat`.
///
/// NOTE(review): the field names match the registered JWT claims "expiration
/// time" and "issued at", so this is presumably the payload checked by the JWT
/// auth middleware; nothing in this file reads it, so confirm against the
/// middleware before relying on that.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Presumably the expiry instant as seconds since the Unix epoch (TODO confirm).
    pub exp: usize,
    /// Presumably the issue instant as seconds since the Unix epoch (TODO confirm).
    pub iat: usize,
}
28
/// Base URL override stored in the request extensions.
///
/// `sse_handler` looks this type up on the incoming request; when it is
/// present, the wrapped string replaces `SessionState::endpoint` as the prefix
/// of the message URL announced in the initial `endpoint` SSE event.
#[derive(Clone)]
pub struct Endpoint(pub String);
31
32#[derive(Deserialize)]
33pub struct MessageQuery {
34    #[serde(rename = "sessionId")]
35    session_id: Option<String>,
36}
37
38#[derive(Clone)]
39pub struct SessionState {
40    sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
41    build_server: Arc<
42        dyn Fn(
43                ServerHttpTransport,
44                Option<serde_json::Value>,
45                String,
46            )
47                -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
48            + Send
49            + Sync,
50    >,
51    endpoint: String,
52}
53
54impl SessionState {
55    /// Create a new SessionState instance with configurable parameters
56    pub fn new(
57        endpoint: String,
58        build_server: Arc<
59            dyn Fn(
60                    ServerHttpTransport,
61                    Option<serde_json::Value>,
62                    String,
63                )
64                    -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
65                + Send
66                + Sync,
67        >,
68        sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
69    ) -> Self {
70        Self {
71            sessions,
72            build_server,
73            endpoint,
74        }
75    }
76}
77
78/// Run a server instance with the specified transport
79pub async fn run_http_server<F, Fut>(
80    port: u16,
81    jwt_secret: Option<String>,
82    build_server: F,
83) -> Result<()>
84where
85    F: Fn(ServerHttpTransport, Option<serde_json::Value>, String) -> Fut + Send + Sync + 'static,
86    Fut: futures::Future<Output = Result<Server<ServerHttpTransport>>> + Send + 'static,
87{
88    info!("Starting server on http://0.0.0.0:{}", port);
89    info!("WebSocket endpoint: ws://0.0.0.0:{}/ws", port);
90    info!("SSE endpoint: http://0.0.0.0:{}/sse", port);
91
92    let sessions = Arc::new(Mutex::new(HashMap::new()));
93
94    // Box the future when creating the Arc
95    let build_server = Arc::new(move |t, o, session_id| {
96        Box::pin(build_server(t, o, session_id)) as futures::future::BoxFuture<_>
97    });
98
99    let auth_config = jwt_secret.map(|jwt_secret| AuthConfig { jwt_secret });
100    let http_server = http_server(port, sessions, auth_config, build_server);
101
102    http_server.await?;
103    Ok(())
104}
105
106pub async fn http_server(
107    port: u16,
108    sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
109    auth_config: Option<AuthConfig>,
110    build_server: Arc<
111        dyn Fn(
112                ServerHttpTransport,
113                Option<serde_json::Value>,
114                String,
115            )
116                -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
117            + Send
118            + Sync,
119    >,
120) -> std::result::Result<(), std::io::Error> {
121    let session_state = SessionState {
122        sessions,
123        build_server,
124        endpoint: format!("http://0.0.0.0:{}", port),
125    };
126
127    let server = HttpServer::new(move || {
128        let session_state = session_state.clone();
129        App::new()
130            .wrap(Logger::default())
131            .wrap(JwtAuth::new(auth_config.clone()))
132            .app_data(web::Data::new(session_state))
133            .route("/sse", web::get().to(sse_handler))
134            .route("/message", web::post().to(message_handler))
135            .route("/ws", web::get().to(ws_handler))
136    })
137    .bind(("0.0.0.0", port))?
138    .run();
139
140    server.await
141}
142
143pub async fn sse_handler(
144    req: actix_web::HttpRequest,
145    session_state: web::Data<SessionState>,
146) -> HttpResponse {
147    let endpoint = req.extensions().get::<Endpoint>().cloned();
148    let session_metadata = req.extensions().get::<serde_json::Value>().cloned();
149    let client_ip = req
150        .peer_addr()
151        .map(|addr| addr.ip().to_string())
152        .unwrap_or_else(|| "unknown".to_string());
153
154    debug!("New SSE connection request from {}", client_ip);
155
156    // Create new session
157    let session_id = Uuid::new_v4().to_string();
158
159    // Create channel for SSE messages
160    let (sse_tx, sse_rx) = broadcast::channel(100);
161
162    // Create new transport for this session
163    let transport = ServerHttpTransport::Sse(ServerSseTransport::new(sse_tx.clone()));
164
165    // Store transport in sessions map
166    session_state
167        .sessions
168        .lock()
169        .unwrap()
170        .insert(session_id.clone(), transport.clone());
171
172    debug!(
173        "SSE connection established for {} with session_id {}",
174        client_ip, session_id
175    );
176    let endpoint = endpoint.map_or(session_state.endpoint.clone(), |e| e.0);
177    // Create initial endpoint info event
178    let endpoint_info =
179        format!("event: endpoint\ndata: {endpoint}/message?sessionId={session_id}\n\n",);
180
181    let stream = futures::stream::once(async move {
182        Ok::<_, std::convert::Infallible>(web::Bytes::from(endpoint_info))
183    })
184    .chain(futures::stream::unfold(sse_rx, move |mut rx| {
185        let client_ip = client_ip.clone();
186        async move {
187            match rx.recv().await {
188                Ok(msg) => {
189                    // Show first and last 500 characters for debugging
190                    let json = serde_json::to_string(&msg).unwrap();
191                    if json.len() > 1000 {
192                        let first = &json[..500];
193                        let last = &json[json.len() - 500..];
194                        debug!("Sending SSE message to {}: {}...{}", client_ip, first, last);
195                    } else {
196                        debug!("Sending SSE message to {}: {}", client_ip, json);
197                    }
198                    let sse_data = format!("data: {}\n\n", json);
199                    Some((
200                        Ok::<_, std::convert::Infallible>(web::Bytes::from(sse_data)),
201                        rx,
202                    ))
203                }
204                _ => None,
205            }
206        }
207    }));
208
209    // Create and start server instance for this session
210    let transport_clone = transport.clone();
211    let build_server = session_state.build_server.clone();
212    let session_metadata = session_metadata.clone();
213    let ses_id = session_id.clone();
214    tokio::spawn(async move {
215        match build_server(transport_clone, session_metadata, ses_id.clone()).await {
216            Ok(server) => {
217                if let Err(e) = server.listen().await {
218                    error!("Server error: {:?}", e);
219                }
220            }
221            Err(e) => {
222                error!("Failed to build server: {:?}", e);
223            }
224        }
225    });
226
227    HttpResponse::Ok()
228        .append_header(("X-Session-Id", session_id))
229        .content_type("text/event-stream")
230        .streaming(stream)
231}
232
233pub async fn message_handler(
234    query: Query<MessageQuery>,
235    message: web::Json<Message>,
236    session_state: web::Data<SessionState>,
237) -> HttpResponse {
238    if let Some(session_id) = &query.session_id {
239        let sessions = session_state.sessions.lock().unwrap();
240        if let Some(transport) = sessions.get(session_id) {
241            match transport {
242                ServerHttpTransport::Sse(sse) => match sse.send_message(message.into_inner()).await
243                {
244                    Ok(_) => {
245                        debug!("Successfully sent message to session {}", session_id);
246                        HttpResponse::Accepted().finish()
247                    }
248                    Err(e) => {
249                        error!("Failed to send message to session {}: {:?}", session_id, e);
250                        HttpResponse::InternalServerError().finish()
251                    }
252                },
253                ServerHttpTransport::Ws(_) => HttpResponse::BadRequest()
254                    .body("Cannot send message to WebSocket connection through HTTP endpoint"),
255            }
256        } else {
257            HttpResponse::NotFound().body(format!("Session {} not found", session_id))
258        }
259    } else {
260        HttpResponse::BadRequest().body("Session ID not specified")
261    }
262}
263
264pub async fn ws_handler(
265    req: actix_web::HttpRequest,
266    body: Payload,
267    session_state: web::Data<SessionState>,
268) -> Result<HttpResponse, actix_web::Error> {
269    let session_metadata = req.extensions().get::<serde_json::Value>().cloned();
270
271    let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
272
273    let client_ip = req
274        .peer_addr()
275        .map(|addr| addr.ip().to_string())
276        .unwrap_or_else(|| "unknown".to_string());
277
278    info!("New WebSocket connection from {}", client_ip);
279
280    // Create channels for message passing
281    let (tx, rx) = broadcast::channel(100);
282    let transport =
283        ServerHttpTransport::Ws(ServerWsTransport::new(session.clone(), rx.resubscribe()));
284
285    // Store transport in sessions map
286    let session_id = Uuid::new_v4().to_string();
287    session_state
288        .sessions
289        .lock()
290        .unwrap()
291        .insert(session_id.clone(), transport.clone());
292
293    // Start WebSocket handling in the background
294    actix_web::rt::spawn(async move {
295        let _ = handle_ws_connection(session, msg_stream, tx.clone(), rx.resubscribe()).await;
296    });
297
298    // Spawn server instance
299    let build_server = session_state.build_server.clone();
300    let session_metadata = session_metadata.clone();
301    actix_web::rt::spawn(async move {
302        if let Ok(server) = build_server(transport, session_metadata, session_id.clone()).await {
303            let _ = server.listen().await;
304        }
305    });
306
307    Ok(response)
308}