1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use actix_web::{web, HttpRequest, HttpResponse};
5use actix_ws::Message;
6use futures_util::StreamExt;
7use shaperail_core::{AuthRule, ChannelDefinition, WsClientMessage, WsServerMessage};
8use tokio::sync::mpsc;
9
10use crate::auth::jwt::{Claims, JwtConfig};
11
12use super::pubsub::{PubSubMessage, RedisPubSub};
13use super::room::RoomManager;
14
/// Period of the server-side heartbeat: `heartbeat_loop` sends one ping message
/// per tick of an interval with this duration (30 seconds).
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);

/// Longest allowed gap (60 seconds) since the last text, ping, or pong frame
/// received from the client before `ws_session` closes the connection.
const CLIENT_TIMEOUT: Duration = Duration::from_secs(60);
20
/// Per-connection values handed from `ws_handler` to the spawned `ws_session` task.
struct SessionConfig {
    /// Identifier of this connection; `ws_handler` fills it with a UUID v4 string.
    session_id: String,
    /// Handle used to register the session and manage its room subscriptions.
    room_manager: RoomManager,
    /// Publisher through which client messages are relayed.
    pubsub: RedisPubSub,
    /// Name of the channel this connection was opened on.
    channel_name: String,
    /// Whether subscribe/message requests for rooms are accepted on this channel.
    rooms_enabled: bool,
}
29
/// State registered as actix `app_data` for one WebSocket channel route and
/// read by `ws_handler` on every connection attempt.
pub struct WsChannelState {
    /// Channel definition; supplies the channel name, auth rule, and rooms flag.
    pub definition: ChannelDefinition,
    /// Room manager cloned into each session.
    pub room_manager: RoomManager,
    /// Pub/sub handle cloned into each session.
    pub pubsub: RedisPubSub,
    /// JWT configuration used to decode the token presented at connect time.
    pub jwt_config: Arc<JwtConfig>,
}
37
38pub async fn ws_handler(
43 req: HttpRequest,
44 body: web::Payload,
45 state: web::Data<WsChannelState>,
46) -> Result<HttpResponse, actix_web::Error> {
47 let token = extract_token(&req);
49
50 let claims = match validate_ws_auth(&state.definition, &state.jwt_config, token.as_deref()) {
52 Ok(c) => c,
53 Err(response) => return Ok(response),
54 };
55
56 let (response, session, stream) = actix_ws::handle(&req, body)?;
57
58 let session_id = uuid::Uuid::new_v4().to_string();
59 let room_manager = state.room_manager.clone();
60 let pubsub = state.pubsub.clone();
61 let channel_name = state.definition.channel.clone();
62 let rooms_enabled = state.definition.rooms;
63
64 tracing::info!(
65 session_id = %session_id,
66 channel = %channel_name,
67 user_id = %claims.as_ref().map(|c| c.sub.as_str()).unwrap_or("anonymous"),
68 "WebSocket connected"
69 );
70
71 let config = SessionConfig {
72 session_id,
73 room_manager,
74 pubsub,
75 channel_name,
76 rooms_enabled,
77 };
78
79 actix_web::rt::spawn(ws_session(config, session, stream));
81
82 Ok(response)
83}
84
85fn extract_token(req: &HttpRequest) -> Option<String> {
87 let query = req.query_string();
88 for pair in query.split('&') {
90 if let Some(value) = pair.strip_prefix("token=") {
91 return Some(value.to_string());
92 }
93 }
94 None
95}
96
97fn validate_ws_auth(
102 definition: &ChannelDefinition,
103 jwt_config: &JwtConfig,
104 token: Option<&str>,
105) -> Result<Option<Claims>, HttpResponse> {
106 let auth = match &definition.auth {
107 Some(auth) => auth,
108 None => return Ok(None), };
110
111 if auth.is_public() {
112 return Ok(None);
113 }
114
115 let token = token.ok_or_else(|| {
116 HttpResponse::Unauthorized().json(serde_json::json!({
117 "error": {
118 "code": "UNAUTHORIZED",
119 "status": 401,
120 "message": "WebSocket connection requires authentication"
121 }
122 }))
123 })?;
124
125 let claims = jwt_config.decode(token).map_err(|_| {
126 HttpResponse::Unauthorized().json(serde_json::json!({
127 "error": {
128 "code": "UNAUTHORIZED",
129 "status": 401,
130 "message": "Invalid or expired token"
131 }
132 }))
133 })?;
134
135 if let AuthRule::Roles(roles) = auth {
137 if !roles.iter().any(|r| r == &claims.role || r == "owner") {
138 return Err(HttpResponse::Forbidden().json(serde_json::json!({
139 "error": {
140 "code": "FORBIDDEN",
141 "status": 403,
142 "message": "Insufficient permissions for this channel"
143 }
144 })));
145 }
146 }
147
148 Ok(Some(claims))
149}
150
151async fn ws_session(
153 config: SessionConfig,
154 mut session: actix_ws::Session,
155 mut stream: actix_ws::MessageStream,
156) {
157 let SessionConfig {
158 session_id,
159 room_manager,
160 pubsub,
161 channel_name,
162 rooms_enabled,
163 } = config;
164
165 let (tx, mut rx) = mpsc::unbounded_channel::<String>();
167 room_manager.register_session(&session_id, tx).await;
168
169 let mut last_heartbeat = Instant::now();
170
171 let heartbeat_session = session.clone();
173 let heartbeat_handle = actix_web::rt::spawn(heartbeat_loop(heartbeat_session));
174
175 loop {
176 tokio::select! {
177 Some(text) = rx.recv() => {
179 if session.text(text).await.is_err() {
180 break;
181 }
182 }
183
184 frame = stream.next() => {
186 match frame {
187 Some(Ok(Message::Text(text))) => {
188 last_heartbeat = Instant::now();
189 handle_text_message(
190 &session_id,
191 &text,
192 &mut session,
193 &room_manager,
194 &pubsub,
195 &channel_name,
196 rooms_enabled,
197 ).await;
198 }
199 Some(Ok(Message::Ping(bytes))) => {
200 last_heartbeat = Instant::now();
201 if session.pong(&bytes).await.is_err() {
202 break;
203 }
204 }
205 Some(Ok(Message::Pong(_))) => {
206 last_heartbeat = Instant::now();
207 }
208 Some(Ok(Message::Close(reason))) => {
209 tracing::info!(
210 session_id = %session_id,
211 "Client initiated close"
212 );
213 let _ = session.close(reason).await;
214 break;
215 }
216 Some(Ok(Message::Continuation(_))) => {
217 }
219 Some(Ok(Message::Binary(_))) => {
220 let err_msg = WsServerMessage::Error {
221 message: "Binary messages not supported".to_string(),
222 };
223 if let Ok(json) = serde_json::to_string(&err_msg) {
224 let _ = session.text(json).await;
225 }
226 }
227 Some(Ok(Message::Nop)) => {}
228 Some(Err(e)) => {
229 tracing::warn!(
230 session_id = %session_id,
231 error = %e,
232 "WebSocket protocol error"
233 );
234 break;
235 }
236 None => break,
237 }
238 }
239
240 _ = tokio::time::sleep(Duration::from_secs(5)) => {
242 if last_heartbeat.elapsed() > CLIENT_TIMEOUT {
243 tracing::info!(
244 session_id = %session_id,
245 "Client heartbeat timeout, disconnecting"
246 );
247 let _ = session.close(None).await;
248 break;
249 }
250 }
251 }
252 }
253
254 heartbeat_handle.abort();
256 room_manager.remove_session(&session_id).await;
257 tracing::info!(session_id = %session_id, "WebSocket disconnected");
258}
259
260async fn heartbeat_loop(mut session: actix_ws::Session) {
262 let mut interval = tokio::time::interval(HEARTBEAT_INTERVAL);
263 loop {
264 interval.tick().await;
265 let ping = WsServerMessage::Ping;
267 if let Ok(json) = serde_json::to_string(&ping) {
268 if session.text(json).await.is_err() {
269 break;
270 }
271 }
272 }
273}
274
275async fn handle_text_message(
277 session_id: &str,
278 text: &str,
279 session: &mut actix_ws::Session,
280 room_manager: &RoomManager,
281 pubsub: &RedisPubSub,
282 channel_name: &str,
283 rooms_enabled: bool,
284) {
285 let msg: WsClientMessage = match serde_json::from_str(text) {
286 Ok(m) => m,
287 Err(e) => {
288 let err = WsServerMessage::Error {
289 message: format!("Invalid message format: {e}"),
290 };
291 if let Ok(json) = serde_json::to_string(&err) {
292 let _ = session.text(json).await;
293 }
294 return;
295 }
296 };
297
298 match msg {
299 WsClientMessage::Subscribe { room } => {
300 if !rooms_enabled {
301 let err = WsServerMessage::Error {
302 message: "Room subscriptions not enabled for this channel".to_string(),
303 };
304 if let Ok(json) = serde_json::to_string(&err) {
305 let _ = session.text(json).await;
306 }
307 return;
308 }
309 room_manager.subscribe(session_id, &room).await;
310 let ack = WsServerMessage::Subscribed { room };
311 if let Ok(json) = serde_json::to_string(&ack) {
312 let _ = session.text(json).await;
313 }
314 }
315 WsClientMessage::Unsubscribe { room } => {
316 room_manager.unsubscribe(session_id, &room).await;
317 let ack = WsServerMessage::Unsubscribed { room };
318 if let Ok(json) = serde_json::to_string(&ack) {
319 let _ = session.text(json).await;
320 }
321 }
322 WsClientMessage::Message { room, data } => {
323 if !rooms_enabled {
324 let err = WsServerMessage::Error {
325 message: "Room messaging not enabled for this channel".to_string(),
326 };
327 if let Ok(json) = serde_json::to_string(&err) {
328 let _ = session.text(json).await;
329 }
330 return;
331 }
332 let pub_msg = PubSubMessage {
334 channel: channel_name.to_string(),
335 room: room.clone(),
336 event: "message".to_string(),
337 data,
338 };
339 if let Err(e) = pubsub.publish(&pub_msg).await {
340 tracing::warn!(error = %e, "Failed to publish message via Redis");
341 let server_msg = WsServerMessage::Broadcast {
343 room: room.clone(),
344 event: "message".to_string(),
345 data: pub_msg.data,
346 };
347 if let Ok(json) = serde_json::to_string(&server_msg) {
348 room_manager.broadcast_to_room(&room, &json).await;
349 }
350 }
351 }
352 WsClientMessage::Pong => {
353 }
355 }
356}
357
358pub fn configure_ws_routes(
360 cfg: &mut web::ServiceConfig,
361 definition: ChannelDefinition,
362 room_manager: RoomManager,
363 pubsub: RedisPubSub,
364 jwt_config: Arc<JwtConfig>,
365) {
366 let channel_name = definition.channel.clone();
367 let state = web::Data::new(WsChannelState {
368 definition,
369 room_manager,
370 pubsub,
371 jwt_config,
372 });
373
374 cfg.app_data(state)
375 .route(&format!("/ws/{channel_name}"), web::get().to(ws_handler));
376}
377
#[cfg(test)]
mod tests {
    //! Unit tests for connection authentication and token extraction.

    use super::*;
    use actix_web::http::StatusCode;
    use actix_web::test::TestRequest;

    const TEST_SECRET: &str = "test-secret-key-at-least-32-bytes-long!";

    /// JWT configuration shared by every test.
    fn test_jwt() -> JwtConfig {
        JwtConfig::new(TEST_SECRET, 3600, 86400)
    }

    /// Builds a channel definition with rooms and hooks disabled.
    fn channel(name: &str, auth: Option<AuthRule>) -> ChannelDefinition {
        ChannelDefinition {
            channel: name.to_string(),
            auth,
            rooms: false,
            hooks: None,
        }
    }

    /// Channel restricted to the `admin` role.
    fn admin_only_channel() -> ChannelDefinition {
        channel("private", Some(AuthRule::Roles(vec!["admin".to_string()])))
    }

    /// Returns the HTTP status of a rejection, failing the test if the
    /// connection was accepted instead.
    fn rejection_status(result: Result<Option<Claims>, HttpResponse>) -> StatusCode {
        match result {
            Err(response) => response.status(),
            Ok(_) => panic!("expected the connection to be rejected"),
        }
    }

    #[test]
    fn validate_public_channel() {
        let def = channel("public", Some(AuthRule::Public));
        let result = validate_ws_auth(&def, &test_jwt(), None);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn validate_no_auth_channel() {
        let def = channel("open", None);
        let result = validate_ws_auth(&def, &test_jwt(), None);
        assert!(result.is_ok());
        // A channel without an auth rule yields no claims.
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn validate_auth_no_token_returns_401() {
        let result = validate_ws_auth(&admin_only_channel(), &test_jwt(), None);
        assert_eq!(rejection_status(result), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn validate_auth_invalid_token_returns_401() {
        let result = validate_ws_auth(
            &admin_only_channel(),
            &test_jwt(),
            Some("invalid.token.here"),
        );
        assert_eq!(rejection_status(result), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn validate_auth_valid_token_correct_role() {
        let jwt = test_jwt();
        let token = jwt.encode_access("user-1", "admin").unwrap();

        let result = validate_ws_auth(&admin_only_channel(), &jwt, Some(&token));
        assert!(result.is_ok());
        let claims = result.unwrap().unwrap();
        assert_eq!(claims.sub, "user-1");
        assert_eq!(claims.role, "admin");
    }

    #[test]
    fn validate_auth_valid_token_wrong_role() {
        let jwt = test_jwt();
        let token = jwt.encode_access("user-1", "viewer").unwrap();

        let result = validate_ws_auth(&admin_only_channel(), &jwt, Some(&token));
        // A valid token with the wrong role is forbidden, not unauthenticated.
        assert_eq!(rejection_status(result), StatusCode::FORBIDDEN);
    }

    #[test]
    fn extract_token_from_query() {
        // Exercise the real `extract_token` through a synthetic request
        // rather than a copy of its parsing logic.
        fn token_for(uri: &str) -> Option<String> {
            extract_token(&TestRequest::with_uri(uri).to_http_request())
        }

        assert_eq!(
            token_for("/ws/chat?token=abc123"),
            Some("abc123".to_string())
        );
        assert_eq!(
            token_for("/ws/chat?foo=bar&token=xyz"),
            Some("xyz".to_string())
        );
        assert_eq!(token_for("/ws/chat?foo=bar"), None);
        assert_eq!(token_for("/ws/chat"), None);
    }
}