use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use actix_web::middleware::Logger;
use actix_web::web::Payload;
use actix_web::web::Query;
use actix_web::{web, App, HttpResponse, HttpServer};
use anyhow::Result;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tracing::{debug, error, info, warn};
use uuid::Uuid;

use crate::server::Server;
use crate::sse::middleware::{AuthConfig, JwtAuth};
use crate::transport::ServerHttpTransport;
use crate::transport::{handle_ws_connection, Message, ServerSseTransport, ServerWsTransport};
18
19#[derive(Debug, Serialize, Deserialize)]
22pub struct Claims {
23 pub exp: usize,
24 pub iat: usize,
25}
26
27#[derive(Deserialize)]
28pub struct MessageQuery {
29 #[serde(rename = "sessionId")]
30 session_id: Option<String>,
31}
32
33#[derive(Clone)]
34pub struct SessionState {
35 sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
36 port: u16,
37 build_server: Arc<
38 dyn Fn(
39 ServerHttpTransport,
40 )
41 -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
42 + Send
43 + Sync,
44 >,
45}
46
47pub async fn run_http_server<F, Fut>(
49 port: u16,
50 jwt_secret: Option<String>,
51 build_server: F,
52) -> Result<()>
53where
54 F: Fn(ServerHttpTransport) -> Fut + Send + Sync + 'static,
55 Fut: futures::Future<Output = Result<Server<ServerHttpTransport>>> + Send + 'static,
56{
57 info!("Starting server on http://127.0.0.1:{}", port);
58 info!("WebSocket endpoint: ws://127.0.0.1:{}/ws", port);
59 info!("SSE endpoint: http://127.0.0.1:{}/sse", port);
60
61 let sessions = Arc::new(Mutex::new(HashMap::new()));
62
63 let build_server =
65 Arc::new(move |t| Box::pin(build_server(t)) as futures::future::BoxFuture<_>);
66
67 let auth_config = jwt_secret.map(|jwt_secret| AuthConfig { jwt_secret });
68 let http_server = http_server(port, sessions, auth_config, build_server);
69
70 http_server.await?;
71 Ok(())
72}
73
74pub async fn http_server(
75 port: u16,
76 sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
77 auth_config: Option<AuthConfig>,
78 build_server: Arc<
79 dyn Fn(
80 ServerHttpTransport,
81 )
82 -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
83 + Send
84 + Sync,
85 >,
86) -> std::result::Result<(), std::io::Error> {
87 let session_state = SessionState {
88 sessions,
89 build_server,
90 port,
91 };
92
93 let server = HttpServer::new(move || {
94 let session_state = session_state.clone();
95 App::new()
96 .wrap(Logger::default())
97 .wrap(JwtAuth::new(auth_config.clone()))
98 .app_data(web::Data::new(session_state))
99 .route("/sse", web::get().to(sse_handler))
100 .route("/message", web::post().to(message_handler))
101 .route("/ws", web::get().to(ws_handler))
102 })
103 .bind(("127.0.0.1", port))?
104 .run();
105
106 server.await
107}
108
109pub async fn sse_handler(
110 req: actix_web::HttpRequest,
111 session_state: web::Data<SessionState>,
112) -> HttpResponse {
113 let client_ip = req
114 .peer_addr()
115 .map(|addr| addr.ip().to_string())
116 .unwrap_or_else(|| "unknown".to_string());
117
118 info!("New SSE connection request from {}", client_ip);
119
120 let session_id = Uuid::new_v4().to_string();
122
123 let (sse_tx, sse_rx) = broadcast::channel(100);
125
126 let transport = ServerHttpTransport::Sse(ServerSseTransport::new(sse_tx.clone()));
128
129 session_state
131 .sessions
132 .lock()
133 .unwrap()
134 .insert(session_id.clone(), transport.clone());
135
136 info!(
137 "SSE connection established for {} with session_id {}",
138 client_ip, session_id
139 );
140 let port = session_state.port;
141 let endpoint_info = format!(
143 "event: endpoint\ndata: http://127.0.0.1:{port}/message?sessionId={session_id}\n\n",
144 );
145
146 let stream = futures::stream::once(async move {
147 Ok::<_, std::convert::Infallible>(web::Bytes::from(endpoint_info))
148 })
149 .chain(futures::stream::unfold(sse_rx, move |mut rx| {
150 let client_ip = client_ip.clone();
151 async move {
152 match rx.recv().await {
153 Ok(msg) => {
154 debug!("Sending SSE message to {}: {:?}", client_ip, msg);
155 let json = serde_json::to_string(&msg).unwrap();
156 let sse_data = format!("data: {}\n\n", json);
157 Some((
158 Ok::<_, std::convert::Infallible>(web::Bytes::from(sse_data)),
159 rx,
160 ))
161 }
162 _ => None,
163 }
164 }
165 }));
166
167 let transport_clone = transport.clone();
169 let build_server = session_state.build_server.clone();
170 tokio::spawn(async move {
171 match build_server(transport_clone).await {
172 Ok(server) => {
173 if let Err(e) = server.listen().await {
174 error!("Server error: {:?}", e);
175 }
176 }
177 Err(e) => {
178 error!("Failed to build server: {:?}", e);
179 }
180 }
181 });
182
183 HttpResponse::Ok()
184 .append_header(("X-Session-Id", session_id))
185 .content_type("text/event-stream")
186 .streaming(stream)
187}
188
189async fn message_handler(
190 query: Query<MessageQuery>,
191 message: web::Json<Message>,
192 session_state: web::Data<SessionState>,
193) -> HttpResponse {
194 if let Some(session_id) = &query.session_id {
195 let sessions = session_state.sessions.lock().unwrap();
196 if let Some(transport) = sessions.get(session_id) {
197 match transport {
198 ServerHttpTransport::Sse(sse) => match sse.send_message(message.into_inner()).await
199 {
200 Ok(_) => {
201 debug!("Successfully sent message to session {}", session_id);
202 HttpResponse::Accepted().finish()
203 }
204 Err(e) => {
205 error!("Failed to send message to session {}: {:?}", session_id, e);
206 HttpResponse::InternalServerError().finish()
207 }
208 },
209 ServerHttpTransport::Ws(_) => HttpResponse::BadRequest()
210 .body("Cannot send message to WebSocket connection through HTTP endpoint"),
211 }
212 } else {
213 HttpResponse::NotFound().body(format!("Session {} not found", session_id))
214 }
215 } else {
216 HttpResponse::BadRequest().body("Session ID not specified")
217 }
218}
219
220async fn ws_handler(
221 req: actix_web::HttpRequest,
222 body: Payload,
223 session_state: web::Data<SessionState>,
224) -> Result<HttpResponse, actix_web::Error> {
225 let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
226
227 let client_ip = req
228 .peer_addr()
229 .map(|addr| addr.ip().to_string())
230 .unwrap_or_else(|| "unknown".to_string());
231
232 info!("New WebSocket connection from {}", client_ip);
233
234 let (tx, rx) = broadcast::channel(100);
236 let transport =
237 ServerHttpTransport::Ws(ServerWsTransport::new(session.clone(), rx.resubscribe()));
238
239 let session_id = Uuid::new_v4().to_string();
241 session_state
242 .sessions
243 .lock()
244 .unwrap()
245 .insert(session_id, transport.clone());
246
247 actix_web::rt::spawn(async move {
249 let _ = handle_ws_connection(session, msg_stream, tx.clone(), rx.resubscribe()).await;
250 });
251
252 let build_server = session_state.build_server.clone();
254 actix_web::rt::spawn(async move {
255 if let Ok(server) = build_server(transport).await {
256 let _ = server.listen().await;
257 }
258 });
259
260 Ok(response)
261}