1use actix_web::middleware::Logger;
2use actix_web::web::Payload;
3use actix_web::web::Query;
4use actix_web::HttpMessage;
5use actix_web::{web, App, HttpResponse, HttpServer};
6use anyhow::Result;
7use futures::StreamExt;
8use uuid::Uuid;
9
10use crate::server::Server;
11use crate::sse::middleware::{AuthConfig, JwtAuth};
12use crate::transport::ServerHttpTransport;
13use crate::transport::{handle_ws_connection, Message, ServerSseTransport, ServerWsTransport};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::sync::{Arc, Mutex};
18use tokio::sync::broadcast;
19use tracing::{debug, error, info};
20
21#[derive(Debug, Serialize, Deserialize)]
24pub struct Claims {
25 pub exp: usize,
26 pub iat: usize,
27}
28
/// Base URL advertised to SSE clients in the initial `endpoint` event.
///
/// When a request carries this value as an extension, `sse_handler` uses it instead of the
/// endpoint configured on `SessionState`, appending `/message?sessionId=<id>` to it.
#[derive(Clone, Debug)]
pub struct Endpoint(pub String);
31
32#[derive(Deserialize)]
33pub struct MessageQuery {
34 #[serde(rename = "sessionId")]
35 session_id: Option<String>,
36}
37
38#[derive(Clone)]
39pub struct SessionState {
40 sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
41 build_server: Arc<
42 dyn Fn(
43 ServerHttpTransport,
44 Option<serde_json::Value>,
45 String,
46 )
47 -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
48 + Send
49 + Sync,
50 >,
51 endpoint: String,
52}
53
54impl SessionState {
55 pub fn new(
57 endpoint: String,
58 build_server: Arc<
59 dyn Fn(
60 ServerHttpTransport,
61 Option<serde_json::Value>,
62 String,
63 )
64 -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
65 + Send
66 + Sync,
67 >,
68 sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
69 ) -> Self {
70 Self {
71 sessions,
72 build_server,
73 endpoint,
74 }
75 }
76}
77
78pub async fn run_http_server<F, Fut>(
80 port: u16,
81 jwt_secret: Option<String>,
82 build_server: F,
83) -> Result<()>
84where
85 F: Fn(ServerHttpTransport, Option<serde_json::Value>, String) -> Fut + Send + Sync + 'static,
86 Fut: futures::Future<Output = Result<Server<ServerHttpTransport>>> + Send + 'static,
87{
88 info!("Starting server on http://0.0.0.0:{}", port);
89 info!("WebSocket endpoint: ws://0.0.0.0:{}/ws", port);
90 info!("SSE endpoint: http://0.0.0.0:{}/sse", port);
91
92 let sessions = Arc::new(Mutex::new(HashMap::new()));
93
94 let build_server = Arc::new(move |t, o, session_id| {
96 Box::pin(build_server(t, o, session_id)) as futures::future::BoxFuture<_>
97 });
98
99 let auth_config = jwt_secret.map(|jwt_secret| AuthConfig { jwt_secret });
100 let http_server = http_server(port, sessions, auth_config, build_server);
101
102 http_server.await?;
103 Ok(())
104}
105
106pub async fn http_server(
107 port: u16,
108 sessions: Arc<Mutex<HashMap<String, ServerHttpTransport>>>,
109 auth_config: Option<AuthConfig>,
110 build_server: Arc<
111 dyn Fn(
112 ServerHttpTransport,
113 Option<serde_json::Value>,
114 String,
115 )
116 -> futures::future::BoxFuture<'static, Result<Server<ServerHttpTransport>>>
117 + Send
118 + Sync,
119 >,
120) -> std::result::Result<(), std::io::Error> {
121 let session_state = SessionState {
122 sessions,
123 build_server,
124 endpoint: format!("http://0.0.0.0:{}", port),
125 };
126
127 let server = HttpServer::new(move || {
128 let session_state = session_state.clone();
129 App::new()
130 .wrap(Logger::default())
131 .wrap(JwtAuth::new(auth_config.clone()))
132 .app_data(web::Data::new(session_state))
133 .route("/sse", web::get().to(sse_handler))
134 .route("/message", web::post().to(message_handler))
135 .route("/ws", web::get().to(ws_handler))
136 })
137 .bind(("0.0.0.0", port))?
138 .run();
139
140 server.await
141}
142
143pub async fn sse_handler(
144 req: actix_web::HttpRequest,
145 session_state: web::Data<SessionState>,
146) -> HttpResponse {
147 let endpoint = req.extensions().get::<Endpoint>().cloned();
148 let session_metadata = req.extensions().get::<serde_json::Value>().cloned();
149 let client_ip = req
150 .peer_addr()
151 .map(|addr| addr.ip().to_string())
152 .unwrap_or_else(|| "unknown".to_string());
153
154 debug!("New SSE connection request from {}", client_ip);
155
156 let session_id = Uuid::new_v4().to_string();
158
159 let (sse_tx, sse_rx) = broadcast::channel(100);
161
162 let transport = ServerHttpTransport::Sse(ServerSseTransport::new(sse_tx.clone()));
164
165 session_state
167 .sessions
168 .lock()
169 .unwrap()
170 .insert(session_id.clone(), transport.clone());
171
172 debug!(
173 "SSE connection established for {} with session_id {}",
174 client_ip, session_id
175 );
176 let endpoint = endpoint.map_or(session_state.endpoint.clone(), |e| e.0);
177 let endpoint_info =
179 format!("event: endpoint\ndata: {endpoint}/message?sessionId={session_id}\n\n",);
180
181 let stream = futures::stream::once(async move {
182 Ok::<_, std::convert::Infallible>(web::Bytes::from(endpoint_info))
183 })
184 .chain(futures::stream::unfold(sse_rx, move |mut rx| {
185 let client_ip = client_ip.clone();
186 async move {
187 match rx.recv().await {
188 Ok(msg) => {
189 let json = serde_json::to_string(&msg).unwrap();
191 if json.len() > 1000 {
192 let first = &json[..500];
193 let last = &json[json.len() - 500..];
194 debug!("Sending SSE message to {}: {}...{}", client_ip, first, last);
195 } else {
196 debug!("Sending SSE message to {}: {}", client_ip, json);
197 }
198 let sse_data = format!("data: {}\n\n", json);
199 Some((
200 Ok::<_, std::convert::Infallible>(web::Bytes::from(sse_data)),
201 rx,
202 ))
203 }
204 _ => None,
205 }
206 }
207 }));
208
209 let transport_clone = transport.clone();
211 let build_server = session_state.build_server.clone();
212 let session_metadata = session_metadata.clone();
213 let ses_id = session_id.clone();
214 tokio::spawn(async move {
215 match build_server(transport_clone, session_metadata, ses_id.clone()).await {
216 Ok(server) => {
217 if let Err(e) = server.listen().await {
218 error!("Server error: {:?}", e);
219 }
220 }
221 Err(e) => {
222 error!("Failed to build server: {:?}", e);
223 }
224 }
225 });
226
227 HttpResponse::Ok()
228 .append_header(("X-Session-Id", session_id))
229 .content_type("text/event-stream")
230 .streaming(stream)
231}
232
233pub async fn message_handler(
234 query: Query<MessageQuery>,
235 message: web::Json<Message>,
236 session_state: web::Data<SessionState>,
237) -> HttpResponse {
238 if let Some(session_id) = &query.session_id {
239 let sessions = session_state.sessions.lock().unwrap();
240 if let Some(transport) = sessions.get(session_id) {
241 match transport {
242 ServerHttpTransport::Sse(sse) => match sse.send_message(message.into_inner()).await
243 {
244 Ok(_) => {
245 debug!("Successfully sent message to session {}", session_id);
246 HttpResponse::Accepted().finish()
247 }
248 Err(e) => {
249 error!("Failed to send message to session {}: {:?}", session_id, e);
250 HttpResponse::InternalServerError().finish()
251 }
252 },
253 ServerHttpTransport::Ws(_) => HttpResponse::BadRequest()
254 .body("Cannot send message to WebSocket connection through HTTP endpoint"),
255 }
256 } else {
257 HttpResponse::NotFound().body(format!("Session {} not found", session_id))
258 }
259 } else {
260 HttpResponse::BadRequest().body("Session ID not specified")
261 }
262}
263
264pub async fn ws_handler(
265 req: actix_web::HttpRequest,
266 body: Payload,
267 session_state: web::Data<SessionState>,
268) -> Result<HttpResponse, actix_web::Error> {
269 let session_metadata = req.extensions().get::<serde_json::Value>().cloned();
270
271 let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
272
273 let client_ip = req
274 .peer_addr()
275 .map(|addr| addr.ip().to_string())
276 .unwrap_or_else(|| "unknown".to_string());
277
278 info!("New WebSocket connection from {}", client_ip);
279
280 let (tx, rx) = broadcast::channel(100);
282 let transport =
283 ServerHttpTransport::Ws(ServerWsTransport::new(session.clone(), rx.resubscribe()));
284
285 let session_id = Uuid::new_v4().to_string();
287 session_state
288 .sessions
289 .lock()
290 .unwrap()
291 .insert(session_id.clone(), transport.clone());
292
293 actix_web::rt::spawn(async move {
295 let _ = handle_ws_connection(session, msg_stream, tx.clone(), rx.resubscribe()).await;
296 });
297
298 let build_server = session_state.build_server.clone();
300 let session_metadata = session_metadata.clone();
301 actix_web::rt::spawn(async move {
302 if let Ok(server) = build_server(transport, session_metadata, session_id.clone()).await {
303 let _ = server.listen().await;
304 }
305 });
306
307 Ok(response)
308}