Skip to main content

tuitbot_server/
ws.rs

1//! WebSocket hub for real-time event streaming.
2//!
3//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
4//! via a `tokio::sync::broadcast` channel.
5//!
6//! Supports two authentication methods:
7//! - Query parameter: `?token=<api_token>` (Tauri/API clients)
8//! - Session cookie: `tuitbot_session=<token>` (web/LAN clients)
9
10use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22/// Wrapper that tags every [`WsEvent`] with the originating account.
23///
24/// Serializes flat thanks to `#[serde(flatten)]`, so the JSON looks like:
25/// `{ "account_id": "...", "type": "ApprovalQueued", ... }`
26#[derive(Clone, Debug, Serialize, Deserialize)]
27pub struct AccountWsEvent {
28    pub account_id: String,
29    #[serde(flatten)]
30    pub event: WsEvent,
31}
32
33/// Events pushed to WebSocket clients.
34#[derive(Clone, Debug, Serialize, Deserialize)]
35#[serde(tag = "type")]
36pub enum WsEvent {
37    /// An automation action was performed (reply, tweet, thread, etc.).
38    ActionPerformed {
39        action_type: String,
40        target: String,
41        content: String,
42        timestamp: String,
43    },
44    /// A new item was queued for approval.
45    ApprovalQueued {
46        id: i64,
47        action_type: String,
48        content: String,
49        #[serde(default)]
50        media_paths: Vec<String>,
51    },
52    /// An approval item's status was updated (approved, rejected, edited).
53    ApprovalUpdated {
54        id: i64,
55        status: String,
56        action_type: String,
57        #[serde(skip_serializing_if = "Option::is_none")]
58        actor: Option<String>,
59    },
60    /// Follower count changed.
61    FollowerUpdate { count: i64, change: i64 },
62    /// Automation runtime status changed.
63    RuntimeStatus {
64        running: bool,
65        active_loops: Vec<String>,
66    },
67    /// A tweet was discovered and scored by the discovery loop.
68    TweetDiscovered {
69        tweet_id: String,
70        author: String,
71        score: f64,
72        timestamp: String,
73    },
74    /// An action was skipped (rate limited, below threshold, safety filter).
75    ActionSkipped {
76        action_type: String,
77        reason: String,
78        timestamp: String,
79    },
80    /// A new content item was scheduled via the composer.
81    ContentScheduled {
82        id: i64,
83        content_type: String,
84        scheduled_for: Option<String>,
85    },
86    /// Circuit breaker state changed.
87    CircuitBreakerTripped {
88        state: String,
89        error_count: u32,
90        cooldown_remaining_seconds: u64,
91        timestamp: String,
92    },
93    /// An error occurred.
94    Error { message: String },
95}
96
97/// Query parameters for WebSocket authentication.
98#[derive(Deserialize)]
99pub struct WsQuery {
100    /// API token passed as a query parameter (optional — cookie auth is fallback).
101    pub token: Option<String>,
102}
103
104/// Extract the session cookie value from headers (exported for tests).
105fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
106    headers
107        .get("cookie")
108        .and_then(|v| v.to_str().ok())
109        .and_then(|cookies| {
110            cookies.split(';').find_map(|c| {
111                let c = c.trim();
112                c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
113            })
114        })
115}
116
117/// `GET /api/ws` — WebSocket upgrade with token or cookie auth.
118pub async fn ws_handler(
119    ws: WebSocketUpgrade,
120    State(state): State<Arc<AppState>>,
121    headers: HeaderMap,
122    Query(params): Query<WsQuery>,
123) -> Response {
124    // Strategy 1: Bearer token via query parameter
125    if let Some(ref token) = params.token {
126        if token == &state.api_token {
127            return ws.on_upgrade(move |socket| handle_ws(socket, state));
128        }
129    }
130
131    // Strategy 2: Session cookie
132    if let Some(session_token) = extract_session_cookie(&headers) {
133        if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
134            return ws.on_upgrade(move |socket| handle_ws(socket, state));
135        }
136    }
137
138    (
139        StatusCode::UNAUTHORIZED,
140        axum::Json(json!({"error": "unauthorized"})),
141    )
142        .into_response()
143}
144
145/// Handle a single WebSocket connection.
146///
147/// Subscribes to the broadcast channel and forwards events as JSON text frames.
148async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
149    let mut rx = state.event_tx.subscribe();
150
151    loop {
152        match rx.recv().await {
153            Ok(event) => {
154                let json = match serde_json::to_string(&event) {
155                    Ok(j) => j,
156                    Err(e) => {
157                        tracing::error!(error = %e, "Failed to serialize WsEvent");
158                        continue;
159                    }
160                };
161                if socket.send(Message::Text(json.into())).await.is_err() {
162                    // Client disconnected.
163                    break;
164                }
165            }
166            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
167                tracing::warn!(count, "WebSocket client lagged, events dropped");
168                let error_event = AccountWsEvent {
169                    account_id: String::new(),
170                    event: WsEvent::Error {
171                        message: format!("{count} events dropped due to slow consumer"),
172                    },
173                };
174                if let Ok(json) = serde_json::to_string(&error_event) {
175                    if socket.send(Message::Text(json.into())).await.is_err() {
176                        break;
177                    }
178                }
179            }
180            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
181                break;
182            }
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    //! Unit tests for event serialization (wire format seen by dashboard
    //! clients) and session-cookie parsing.

    use super::*;

    // --- WsEvent serialization (tagged enum) ---

    #[test]
    fn action_performed_serializes_with_type_tag() {
        let event = WsEvent::ActionPerformed {
            action_type: "reply".into(),
            target: "@user".into(),
            content: "Hello!".into(),
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ActionPerformed");
        assert_eq!(json["action_type"], "reply");
        assert_eq!(json["target"], "@user");
    }

    #[test]
    fn approval_queued_serializes() {
        let event = WsEvent::ApprovalQueued {
            id: 42,
            action_type: "tweet".into(),
            content: "Draft tweet".into(),
            media_paths: vec!["img.png".into()],
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ApprovalQueued");
        assert_eq!(json["id"], 42);
        assert_eq!(json["media_paths"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn approval_queued_defaults_missing_media_paths() {
        // `media_paths` carries `#[serde(default)]`, so older producers may omit it.
        let json = r#"{"type":"ApprovalQueued","id":3,"action_type":"reply","content":"hi"}"#;
        let event: WsEvent = serde_json::from_str(json).unwrap();
        match event {
            WsEvent::ApprovalQueued { id, media_paths, .. } => {
                assert_eq!(id, 3);
                assert!(media_paths.is_empty());
            }
            other => panic!("expected ApprovalQueued, got {other:?}"),
        }
    }

    #[test]
    fn approval_updated_serializes_with_optional_actor() {
        let event = WsEvent::ApprovalUpdated {
            id: 1,
            status: "approved".into(),
            action_type: "tweet".into(),
            actor: Some("admin".into()),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["actor"], "admin");

        let event_no_actor = WsEvent::ApprovalUpdated {
            id: 1,
            status: "rejected".into(),
            action_type: "tweet".into(),
            actor: None,
        };
        let json2 = serde_json::to_value(&event_no_actor).unwrap();
        assert!(
            json2.get("actor").is_none(),
            "actor should be skipped when None"
        );
    }

    #[test]
    fn follower_update_serializes() {
        let event = WsEvent::FollowerUpdate {
            count: 1500,
            change: 25,
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "FollowerUpdate");
        assert_eq!(json["count"], 1500);
        assert_eq!(json["change"], 25);
    }

    #[test]
    fn runtime_status_serializes() {
        let event = WsEvent::RuntimeStatus {
            running: true,
            active_loops: vec!["mentions".into(), "discovery".into()],
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "RuntimeStatus");
        assert_eq!(json["running"], true);
        assert_eq!(json["active_loops"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn tweet_discovered_serializes() {
        let event = WsEvent::TweetDiscovered {
            tweet_id: "123456".into(),
            author: "user1".into(),
            score: 0.95,
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "TweetDiscovered");
        assert_eq!(json["score"], 0.95);
    }

    #[test]
    fn action_skipped_serializes() {
        let event = WsEvent::ActionSkipped {
            action_type: "reply".into(),
            reason: "rate limited".into(),
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ActionSkipped");
        assert_eq!(json["reason"], "rate limited");
    }

    #[test]
    fn content_scheduled_serializes() {
        let event = WsEvent::ContentScheduled {
            id: 7,
            content_type: "tweet".into(),
            scheduled_for: Some("2026-03-16T09:00:00Z".into()),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ContentScheduled");
        assert_eq!(json["id"], 7);
        assert!(json["scheduled_for"].is_string());
    }

    #[test]
    fn content_scheduled_serializes_none_as_null() {
        // Unlike `ApprovalUpdated::actor`, this field has no `skip_serializing_if`,
        // so clients see an explicit `null` rather than a missing key.
        let event = WsEvent::ContentScheduled {
            id: 8,
            content_type: "thread".into(),
            scheduled_for: None,
        };
        let json = serde_json::to_value(&event).unwrap();
        assert!(json.get("scheduled_for").is_some(), "key should be present");
        assert!(json["scheduled_for"].is_null());
    }

    #[test]
    fn circuit_breaker_tripped_serializes() {
        let event = WsEvent::CircuitBreakerTripped {
            state: "open".into(),
            error_count: 5,
            cooldown_remaining_seconds: 120,
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "CircuitBreakerTripped");
        assert_eq!(json["error_count"], 5);
        assert_eq!(json["cooldown_remaining_seconds"], 120);
    }

    #[test]
    fn error_event_serializes() {
        let event = WsEvent::Error {
            message: "something broke".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "Error");
        assert_eq!(json["message"], "something broke");
    }

    // --- AccountWsEvent flattening ---

    #[test]
    fn account_ws_event_flattens_correctly() {
        let event = AccountWsEvent {
            account_id: "acct-123".into(),
            event: WsEvent::FollowerUpdate {
                count: 100,
                change: 5,
            },
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["account_id"], "acct-123");
        assert_eq!(json["type"], "FollowerUpdate");
        assert_eq!(json["count"], 100);
    }

    #[test]
    fn account_ws_event_roundtrip() {
        let original = AccountWsEvent {
            account_id: "acct-456".into(),
            event: WsEvent::Error {
                message: "test error".into(),
            },
        };
        let json_str = serde_json::to_string(&original).unwrap();
        let deserialized: AccountWsEvent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.account_id, "acct-456");
        match deserialized.event {
            WsEvent::Error { message } => assert_eq!(message, "test error"),
            _ => panic!("expected Error variant"),
        }
    }

    #[test]
    fn unknown_event_type_is_rejected() {
        let json = r#"{"account_id":"a","type":"NotARealEvent"}"#;
        assert!(serde_json::from_str::<AccountWsEvent>(json).is_err());
    }

    // --- extract_session_cookie ---

    #[test]
    fn extract_session_cookie_present() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "cookie",
            "other=foo; tuitbot_session=abc123; another=bar"
                .parse()
                .unwrap(),
        );
        let result = extract_session_cookie(&headers);
        assert_eq!(result.as_deref(), Some("abc123"));
    }

    #[test]
    fn extract_session_cookie_without_spaces_between_pairs() {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", "tuitbot_session=xyz;other=1".parse().unwrap());
        assert_eq!(extract_session_cookie(&headers).as_deref(), Some("xyz"));
    }

    #[test]
    fn extract_session_cookie_ignores_names_that_merely_contain_it() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "cookie",
            "my_tuitbot_session=evil; tuitbot_session=good".parse().unwrap(),
        );
        assert_eq!(extract_session_cookie(&headers).as_deref(), Some("good"));
    }

    #[test]
    fn extract_session_cookie_not_present() {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", "other=foo; another=bar".parse().unwrap());
        let result = extract_session_cookie(&headers);
        assert!(result.is_none());
    }

    #[test]
    fn extract_session_cookie_no_cookie_header() {
        let headers = HeaderMap::new();
        let result = extract_session_cookie(&headers);
        assert!(result.is_none());
    }

    // --- WsEvent deserialization ---

    #[test]
    fn ws_event_deserializes_from_json() {
        let json = r#"{"type":"Error","message":"test"}"#;
        let event: WsEvent = serde_json::from_str(json).unwrap();
        match event {
            WsEvent::Error { message } => assert_eq!(message, "test"),
            _ => panic!("expected Error variant"),
        }
    }

    #[test]
    fn all_event_variants_serialize_without_panic() {
        let events: Vec<WsEvent> = vec![
            WsEvent::ActionPerformed {
                action_type: "reply".into(),
                target: "t".into(),
                content: "c".into(),
                timestamp: "ts".into(),
            },
            WsEvent::ApprovalQueued {
                id: 1,
                action_type: "tweet".into(),
                content: "c".into(),
                media_paths: vec![],
            },
            WsEvent::ApprovalUpdated {
                id: 1,
                status: "s".into(),
                action_type: "a".into(),
                actor: None,
            },
            WsEvent::FollowerUpdate {
                count: 0,
                change: 0,
            },
            WsEvent::RuntimeStatus {
                running: false,
                active_loops: vec![],
            },
            WsEvent::TweetDiscovered {
                tweet_id: "t".into(),
                author: "a".into(),
                score: 0.0,
                timestamp: "ts".into(),
            },
            WsEvent::ActionSkipped {
                action_type: "a".into(),
                reason: "r".into(),
                timestamp: "ts".into(),
            },
            WsEvent::ContentScheduled {
                id: 1,
                content_type: "tweet".into(),
                scheduled_for: None,
            },
            WsEvent::CircuitBreakerTripped {
                state: "open".into(),
                error_count: 0,
                cooldown_remaining_seconds: 0,
                timestamp: "ts".into(),
            },
            WsEvent::Error {
                message: "err".into(),
            },
        ];
        for event in events {
            let json = serde_json::to_string(&event).unwrap();
            assert!(!json.is_empty());
            // Round-trip
            let _: WsEvent = serde_json::from_str(&json).unwrap();
        }
    }
}