Skip to main content

shaperail_runtime/ws/
session.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use actix_web::{web, HttpRequest, HttpResponse};
5use actix_ws::Message;
6use futures_util::StreamExt;
7use shaperail_core::{AuthRule, ChannelDefinition, WsClientMessage, WsServerMessage};
8use tokio::sync::mpsc;
9
10use crate::auth::jwt::{Claims, JwtConfig};
11
12use super::pubsub::{PubSubMessage, RedisPubSub};
13use super::room::RoomManager;
14
/// Period of the application-level `Ping` that `heartbeat_loop` pushes to every
/// connected client (30 seconds).
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(30_000);

/// Maximum silence tolerated from a client: two heartbeat periods (60 seconds).
///
/// `ws_session` closes the connection once no text, ping, or pong frame has been
/// received from the client for longer than this.
const CLIENT_TIMEOUT: Duration = Duration::from_secs(2 * HEARTBEAT_INTERVAL.as_secs());
20
21/// Configuration passed to each WebSocket session task.
22struct SessionConfig {
23    session_id: String,
24    room_manager: RoomManager,
25    pubsub: RedisPubSub,
26    channel_name: String,
27    rooms_enabled: bool,
28}
29
30/// Shared state for a WebSocket channel, stored in Actix app data.
31pub struct WsChannelState {
32    pub definition: ChannelDefinition,
33    pub room_manager: RoomManager,
34    pub pubsub: RedisPubSub,
35    pub jwt_config: Arc<JwtConfig>,
36}
37
38/// HTTP handler for WebSocket upgrade at `/ws/<channel>`.
39///
40/// Validates JWT from query parameter `?token=<jwt>` before upgrading.
41/// Returns 401 if auth fails (before WebSocket handshake completes).
42pub async fn ws_handler(
43    req: HttpRequest,
44    body: web::Payload,
45    state: web::Data<WsChannelState>,
46) -> Result<HttpResponse, actix_web::Error> {
47    // Extract JWT from query string
48    let token = extract_token(&req);
49
50    // Validate auth before upgrade
51    let claims = match validate_ws_auth(&state.definition, &state.jwt_config, token.as_deref()) {
52        Ok(c) => c,
53        Err(response) => return Ok(response),
54    };
55
56    let (response, session, stream) = actix_ws::handle(&req, body)?;
57
58    let session_id = uuid::Uuid::new_v4().to_string();
59    let room_manager = state.room_manager.clone();
60    let pubsub = state.pubsub.clone();
61    let channel_name = state.definition.channel.clone();
62    let rooms_enabled = state.definition.rooms;
63
64    tracing::info!(
65        session_id = %session_id,
66        channel = %channel_name,
67        user_id = %claims.as_ref().map(|c| c.sub.as_str()).unwrap_or("anonymous"),
68        "WebSocket connected"
69    );
70
71    let config = SessionConfig {
72        session_id,
73        room_manager,
74        pubsub,
75        channel_name,
76        rooms_enabled,
77    };
78
79    // Spawn the session task on the Actix runtime (not Send-bound)
80    actix_web::rt::spawn(ws_session(config, session, stream));
81
82    Ok(response)
83}
84
85/// Extracts the JWT token from query parameters.
86fn extract_token(req: &HttpRequest) -> Option<String> {
87    let query = req.query_string();
88    // Simple query string parsing without the `url` crate
89    for pair in query.split('&') {
90        if let Some(value) = pair.strip_prefix("token=") {
91            return Some(value.to_string());
92        }
93    }
94    None
95}
96
97/// Validates WebSocket auth before upgrading the connection.
98///
99/// Returns Ok(Some(claims)) for authenticated users, Ok(None) for public channels,
100/// or Err(HttpResponse) with 401 status for auth failures.
101fn validate_ws_auth(
102    definition: &ChannelDefinition,
103    jwt_config: &JwtConfig,
104    token: Option<&str>,
105) -> Result<Option<Claims>, HttpResponse> {
106    let auth = match &definition.auth {
107        Some(auth) => auth,
108        None => return Ok(None), // No auth required
109    };
110
111    if auth.is_public() {
112        return Ok(None);
113    }
114
115    let token = token.ok_or_else(|| {
116        HttpResponse::Unauthorized().json(serde_json::json!({
117            "error": {
118                "code": "UNAUTHORIZED",
119                "status": 401,
120                "message": "WebSocket connection requires authentication"
121            }
122        }))
123    })?;
124
125    let claims = jwt_config.decode(token).map_err(|_| {
126        HttpResponse::Unauthorized().json(serde_json::json!({
127            "error": {
128                "code": "UNAUTHORIZED",
129                "status": 401,
130                "message": "Invalid or expired token"
131            }
132        }))
133    })?;
134
135    // Check role authorization
136    if let AuthRule::Roles(roles) = auth {
137        if !roles.iter().any(|r| r == &claims.role || r == "owner") {
138            return Err(HttpResponse::Forbidden().json(serde_json::json!({
139                "error": {
140                    "code": "FORBIDDEN",
141                    "status": 403,
142                    "message": "Insufficient permissions for this channel"
143                }
144            })));
145        }
146    }
147
148    Ok(Some(claims))
149}
150
151/// Runs a single WebSocket session: heartbeat, message routing, cleanup.
152async fn ws_session(
153    config: SessionConfig,
154    mut session: actix_ws::Session,
155    mut stream: actix_ws::MessageStream,
156) {
157    let SessionConfig {
158        session_id,
159        room_manager,
160        pubsub,
161        channel_name,
162        rooms_enabled,
163    } = config;
164
165    // Register session with room manager
166    let (tx, mut rx) = mpsc::unbounded_channel::<String>();
167    room_manager.register_session(&session_id, tx).await;
168
169    let mut last_heartbeat = Instant::now();
170
171    // Spawn heartbeat ping sender on the Actix runtime
172    let heartbeat_session = session.clone();
173    let heartbeat_handle = actix_web::rt::spawn(heartbeat_loop(heartbeat_session));
174
175    loop {
176        tokio::select! {
177            // Outbound: messages from room manager → client
178            Some(text) = rx.recv() => {
179                if session.text(text).await.is_err() {
180                    break;
181                }
182            }
183
184            // Inbound: messages from client → server
185            frame = stream.next() => {
186                match frame {
187                    Some(Ok(Message::Text(text))) => {
188                        last_heartbeat = Instant::now();
189                        handle_text_message(
190                            &session_id,
191                            &text,
192                            &mut session,
193                            &room_manager,
194                            &pubsub,
195                            &channel_name,
196                            rooms_enabled,
197                        ).await;
198                    }
199                    Some(Ok(Message::Ping(bytes))) => {
200                        last_heartbeat = Instant::now();
201                        if session.pong(&bytes).await.is_err() {
202                            break;
203                        }
204                    }
205                    Some(Ok(Message::Pong(_))) => {
206                        last_heartbeat = Instant::now();
207                    }
208                    Some(Ok(Message::Close(reason))) => {
209                        tracing::info!(
210                            session_id = %session_id,
211                            "Client initiated close"
212                        );
213                        let _ = session.close(reason).await;
214                        break;
215                    }
216                    Some(Ok(Message::Continuation(_))) => {
217                        // Continuation frames not supported
218                    }
219                    Some(Ok(Message::Binary(_))) => {
220                        let err_msg = WsServerMessage::Error {
221                            message: "Binary messages not supported".to_string(),
222                        };
223                        if let Ok(json) = serde_json::to_string(&err_msg) {
224                            let _ = session.text(json).await;
225                        }
226                    }
227                    Some(Ok(Message::Nop)) => {}
228                    Some(Err(e)) => {
229                        tracing::warn!(
230                            session_id = %session_id,
231                            error = %e,
232                            "WebSocket protocol error"
233                        );
234                        break;
235                    }
236                    None => break,
237                }
238            }
239
240            // Heartbeat timeout check
241            _ = tokio::time::sleep(Duration::from_secs(5)) => {
242                if last_heartbeat.elapsed() > CLIENT_TIMEOUT {
243                    tracing::info!(
244                        session_id = %session_id,
245                        "Client heartbeat timeout, disconnecting"
246                    );
247                    let _ = session.close(None).await;
248                    break;
249                }
250            }
251        }
252    }
253
254    // Cleanup
255    heartbeat_handle.abort();
256    room_manager.remove_session(&session_id).await;
257    tracing::info!(session_id = %session_id, "WebSocket disconnected");
258}
259
260/// Sends periodic ping frames to the client.
261async fn heartbeat_loop(mut session: actix_ws::Session) {
262    let mut interval = tokio::time::interval(HEARTBEAT_INTERVAL);
263    loop {
264        interval.tick().await;
265        // Send application-level ping as JSON
266        let ping = WsServerMessage::Ping;
267        if let Ok(json) = serde_json::to_string(&ping) {
268            if session.text(json).await.is_err() {
269                break;
270            }
271        }
272    }
273}
274
275/// Processes an incoming text message from a client.
276async fn handle_text_message(
277    session_id: &str,
278    text: &str,
279    session: &mut actix_ws::Session,
280    room_manager: &RoomManager,
281    pubsub: &RedisPubSub,
282    channel_name: &str,
283    rooms_enabled: bool,
284) {
285    let msg: WsClientMessage = match serde_json::from_str(text) {
286        Ok(m) => m,
287        Err(e) => {
288            let err = WsServerMessage::Error {
289                message: format!("Invalid message format: {e}"),
290            };
291            if let Ok(json) = serde_json::to_string(&err) {
292                let _ = session.text(json).await;
293            }
294            return;
295        }
296    };
297
298    match msg {
299        WsClientMessage::Subscribe { room } => {
300            if !rooms_enabled {
301                let err = WsServerMessage::Error {
302                    message: "Room subscriptions not enabled for this channel".to_string(),
303                };
304                if let Ok(json) = serde_json::to_string(&err) {
305                    let _ = session.text(json).await;
306                }
307                return;
308            }
309            room_manager.subscribe(session_id, &room).await;
310            let ack = WsServerMessage::Subscribed { room };
311            if let Ok(json) = serde_json::to_string(&ack) {
312                let _ = session.text(json).await;
313            }
314        }
315        WsClientMessage::Unsubscribe { room } => {
316            room_manager.unsubscribe(session_id, &room).await;
317            let ack = WsServerMessage::Unsubscribed { room };
318            if let Ok(json) = serde_json::to_string(&ack) {
319                let _ = session.text(json).await;
320            }
321        }
322        WsClientMessage::Message { room, data } => {
323            if !rooms_enabled {
324                let err = WsServerMessage::Error {
325                    message: "Room messaging not enabled for this channel".to_string(),
326                };
327                if let Ok(json) = serde_json::to_string(&err) {
328                    let _ = session.text(json).await;
329                }
330                return;
331            }
332            // Publish via Redis so all instances receive it
333            let pub_msg = PubSubMessage {
334                channel: channel_name.to_string(),
335                room: room.clone(),
336                event: "message".to_string(),
337                data,
338            };
339            if let Err(e) = pubsub.publish(&pub_msg).await {
340                tracing::warn!(error = %e, "Failed to publish message via Redis");
341                // Fall back to local-only broadcast
342                let server_msg = WsServerMessage::Broadcast {
343                    room: room.clone(),
344                    event: "message".to_string(),
345                    data: pub_msg.data,
346                };
347                if let Ok(json) = serde_json::to_string(&server_msg) {
348                    room_manager.broadcast_to_room(&room, &json).await;
349                }
350            }
351        }
352        WsClientMessage::Pong => {
353            // Client pong — heartbeat already updated by caller
354        }
355    }
356}
357
358/// Registers WebSocket routes for a channel on the Actix service config.
359pub fn configure_ws_routes(
360    cfg: &mut web::ServiceConfig,
361    definition: ChannelDefinition,
362    room_manager: RoomManager,
363    pubsub: RedisPubSub,
364    jwt_config: Arc<JwtConfig>,
365) {
366    let channel_name = definition.channel.clone();
367    let state = web::Data::new(WsChannelState {
368        definition,
369        room_manager,
370        pubsub,
371        jwt_config,
372    });
373
374    cfg.app_data(state)
375        .route(&format!("/ws/{channel_name}"), web::get().to(ws_handler));
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    const SECRET: &str = "test-secret-key-at-least-32-bytes-long!";

    /// JWT config shared by every test (1 h access / 24 h refresh lifetimes).
    fn jwt() -> JwtConfig {
        JwtConfig::new(SECRET, 3600, 86400)
    }

    /// A room-less channel definition with the given auth rule.
    fn channel(name: &str, auth: Option<AuthRule>) -> ChannelDefinition {
        ChannelDefinition {
            channel: name.to_string(),
            auth,
            rooms: false,
            hooks: None,
        }
    }

    /// Auth rule restricting a channel to the `admin` role.
    fn admin_only() -> Option<AuthRule> {
        Some(AuthRule::Roles(vec!["admin".to_string()]))
    }

    /// Runs the real `extract_token` against a request with the given query string.
    fn token_from(query: &str) -> Option<String> {
        let req = actix_web::test::TestRequest::default()
            .uri(&format!("/ws/chat?{query}"))
            .to_http_request();
        extract_token(&req)
    }

    #[test]
    fn validate_public_channel() {
        let def = channel("public", Some(AuthRule::Public));
        let result = validate_ws_auth(&def, &jwt(), None);
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn validate_public_channel_ignores_invalid_token() {
        let def = channel("public", Some(AuthRule::Public));
        let result = validate_ws_auth(&def, &jwt(), Some("invalid.token.here"));
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn validate_no_auth_channel() {
        let def = channel("open", None);
        let result = validate_ws_auth(&def, &jwt(), None);
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn validate_auth_no_token_returns_401() {
        let def = channel("private", admin_only());
        let rejection = validate_ws_auth(&def, &jwt(), None)
            .err()
            .expect("missing token must be rejected");
        assert_eq!(rejection.status(), actix_web::http::StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn validate_auth_invalid_token_returns_401() {
        let def = channel("private", admin_only());
        let rejection = validate_ws_auth(&def, &jwt(), Some("invalid.token.here"))
            .err()
            .expect("invalid token must be rejected");
        assert_eq!(rejection.status(), actix_web::http::StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn validate_auth_valid_token_correct_role() {
        let jwt = jwt();
        let token = jwt.encode_access("user-1", "admin").unwrap();

        let def = channel("private", admin_only());
        let result = validate_ws_auth(&def, &jwt, Some(&token));
        assert!(result.is_ok());
        let claims = result.unwrap().unwrap();
        assert_eq!(claims.sub, "user-1");
        assert_eq!(claims.role, "admin");
    }

    #[test]
    fn validate_auth_valid_token_wrong_role_returns_403() {
        let jwt = jwt();
        let token = jwt.encode_access("user-1", "viewer").unwrap();

        let def = channel("private", admin_only());
        let rejection = validate_ws_auth(&def, &jwt, Some(&token))
            .err()
            .expect("wrong role must be rejected");
        assert_eq!(rejection.status(), actix_web::http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn extract_token_from_query() {
        // Exercises the real function through an actual `HttpRequest`.
        assert_eq!(token_from("token=abc123"), Some("abc123".to_string()));
        assert_eq!(token_from("foo=bar&token=xyz"), Some("xyz".to_string()));
        assert_eq!(token_from("token=first&token=second"), Some("first".to_string()));
        assert_eq!(token_from("foo=bar"), None);
        assert_eq!(token_from(""), None);
    }
}