Skip to main content

tuitbot_server/
ws.rs

1//! WebSocket hub for real-time event streaming.
2//!
3//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
4//! via a `tokio::sync::broadcast` channel.
5//!
6//! Supports two authentication methods:
7//! - Query parameter: `?token=<api_token>` (Tauri/API clients)
8//! - Session cookie: `tuitbot_session=<token>` (web/LAN clients)
9
10use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22/// Wrapper that tags every [`WsEvent`] with the originating account.
23///
24/// Serializes flat thanks to `#[serde(flatten)]`, so the JSON looks like:
25/// `{ "account_id": "...", "type": "ApprovalQueued", ... }`
26#[derive(Clone, Debug, Serialize, Deserialize)]
27pub struct AccountWsEvent {
28    pub account_id: String,
29    #[serde(flatten)]
30    pub event: WsEvent,
31}
32
33/// Events pushed to WebSocket clients.
34#[derive(Clone, Debug, Serialize, Deserialize)]
35#[serde(tag = "type")]
36pub enum WsEvent {
37    /// An automation action was performed (reply, tweet, thread, etc.).
38    ActionPerformed {
39        action_type: String,
40        target: String,
41        content: String,
42        timestamp: String,
43    },
44    /// A new item was queued for approval.
45    ApprovalQueued {
46        id: i64,
47        action_type: String,
48        content: String,
49        #[serde(default)]
50        media_paths: Vec<String>,
51    },
52    /// An approval item's status was updated (approved, rejected, edited).
53    ApprovalUpdated {
54        id: i64,
55        status: String,
56        action_type: String,
57        #[serde(skip_serializing_if = "Option::is_none")]
58        actor: Option<String>,
59    },
60    /// Follower count changed.
61    FollowerUpdate { count: i64, change: i64 },
62    /// Automation runtime status changed.
63    RuntimeStatus {
64        running: bool,
65        active_loops: Vec<String>,
66    },
67    /// A tweet was discovered and scored by the discovery loop.
68    TweetDiscovered {
69        tweet_id: String,
70        author: String,
71        score: f64,
72        timestamp: String,
73    },
74    /// An action was skipped (rate limited, below threshold, safety filter).
75    ActionSkipped {
76        action_type: String,
77        reason: String,
78        timestamp: String,
79    },
80    /// A new content item was scheduled via the composer.
81    ContentScheduled {
82        id: i64,
83        content_type: String,
84        scheduled_for: Option<String>,
85    },
86    /// Circuit breaker state changed.
87    CircuitBreakerTripped {
88        state: String,
89        error_count: u32,
90        cooldown_remaining_seconds: u64,
91        timestamp: String,
92    },
93    /// A Ghostwriter selection was received from the Obsidian plugin.
94    SelectionReceived { session_id: String },
95    /// An error occurred.
96    Error { message: String },
97}
98
99/// Query parameters for WebSocket authentication.
100#[derive(Deserialize)]
101pub struct WsQuery {
102    /// API token passed as a query parameter (optional — cookie auth is fallback).
103    pub token: Option<String>,
104}
105
106/// Extract the session cookie value from headers (exported for tests).
107fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
108    headers
109        .get("cookie")
110        .and_then(|v| v.to_str().ok())
111        .and_then(|cookies| {
112            cookies.split(';').find_map(|c| {
113                let c = c.trim();
114                c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
115            })
116        })
117}
118
119/// `GET /api/ws` — WebSocket upgrade with token or cookie auth.
120pub async fn ws_handler(
121    ws: WebSocketUpgrade,
122    State(state): State<Arc<AppState>>,
123    headers: HeaderMap,
124    Query(params): Query<WsQuery>,
125) -> Response {
126    // Strategy 1: Bearer token via query parameter
127    if let Some(ref token) = params.token {
128        if token == &state.api_token {
129            return ws.on_upgrade(move |socket| handle_ws(socket, state));
130        }
131    }
132
133    // Strategy 2: Session cookie
134    if let Some(session_token) = extract_session_cookie(&headers) {
135        if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
136            return ws.on_upgrade(move |socket| handle_ws(socket, state));
137        }
138    }
139
140    (
141        StatusCode::UNAUTHORIZED,
142        axum::Json(json!({"error": "unauthorized"})),
143    )
144        .into_response()
145}
146
147/// Handle a single WebSocket connection.
148///
149/// Subscribes to the broadcast channel and forwards events as JSON text frames.
150async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
151    let mut rx = state.event_tx.subscribe();
152
153    loop {
154        match rx.recv().await {
155            Ok(event) => {
156                let json = match serde_json::to_string(&event) {
157                    Ok(j) => j,
158                    Err(e) => {
159                        tracing::error!(error = %e, "Failed to serialize WsEvent");
160                        continue;
161                    }
162                };
163                if socket.send(Message::Text(json.into())).await.is_err() {
164                    // Client disconnected.
165                    break;
166                }
167            }
168            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
169                tracing::warn!(count, "WebSocket client lagged, events dropped");
170                let error_event = AccountWsEvent {
171                    account_id: String::new(),
172                    event: WsEvent::Error {
173                        message: format!("{count} events dropped due to slow consumer"),
174                    },
175                };
176                if let Ok(json) = serde_json::to_string(&error_event) {
177                    if socket.send(Message::Text(json.into())).await.is_err() {
178                        break;
179                    }
180                }
181            }
182            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
183                break;
184            }
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    // --- WsEvent serialization (tagged enum) ---

    #[test]
    fn action_performed_serializes_with_type_tag() {
        let event = WsEvent::ActionPerformed {
            action_type: "reply".into(),
            target: "@user".into(),
            content: "Hello!".into(),
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ActionPerformed");
        assert_eq!(json["action_type"], "reply");
        assert_eq!(json["target"], "@user");
    }

    #[test]
    fn approval_queued_serializes() {
        let event = WsEvent::ApprovalQueued {
            id: 42,
            action_type: "tweet".into(),
            content: "Draft tweet".into(),
            media_paths: vec!["img.png".into()],
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ApprovalQueued");
        assert_eq!(json["id"], 42);
        assert_eq!(json["media_paths"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn approval_queued_deserializes_without_media_paths() {
        // `media_paths` is `#[serde(default)]`, so producers may omit it.
        let json = r#"{"type":"ApprovalQueued","id":3,"action_type":"tweet","content":"c"}"#;
        let event: WsEvent = serde_json::from_str(json).unwrap();
        match event {
            WsEvent::ApprovalQueued {
                id, media_paths, ..
            } => {
                assert_eq!(id, 3);
                assert!(media_paths.is_empty());
            }
            other => panic!("expected ApprovalQueued, got {other:?}"),
        }
    }

    #[test]
    fn approval_updated_serializes_with_optional_actor() {
        let event = WsEvent::ApprovalUpdated {
            id: 1,
            status: "approved".into(),
            action_type: "tweet".into(),
            actor: Some("admin".into()),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["actor"], "admin");

        let event_no_actor = WsEvent::ApprovalUpdated {
            id: 1,
            status: "rejected".into(),
            action_type: "tweet".into(),
            actor: None,
        };
        let json2 = serde_json::to_value(&event_no_actor).unwrap();
        assert!(
            json2.get("actor").is_none(),
            "actor should be skipped when None"
        );
    }

    #[test]
    fn follower_update_serializes() {
        let event = WsEvent::FollowerUpdate {
            count: 1500,
            change: 25,
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "FollowerUpdate");
        assert_eq!(json["count"], 1500);
        assert_eq!(json["change"], 25);
    }

    #[test]
    fn runtime_status_serializes() {
        let event = WsEvent::RuntimeStatus {
            running: true,
            active_loops: vec!["mentions".into(), "discovery".into()],
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "RuntimeStatus");
        assert_eq!(json["running"], true);
        assert_eq!(json["active_loops"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn tweet_discovered_serializes() {
        let event = WsEvent::TweetDiscovered {
            tweet_id: "123456".into(),
            author: "user1".into(),
            score: 0.95,
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "TweetDiscovered");
        assert_eq!(json["score"], 0.95);
    }

    #[test]
    fn action_skipped_serializes() {
        let event = WsEvent::ActionSkipped {
            action_type: "reply".into(),
            reason: "rate limited".into(),
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ActionSkipped");
        assert_eq!(json["reason"], "rate limited");
    }

    #[test]
    fn content_scheduled_serializes() {
        let event = WsEvent::ContentScheduled {
            id: 7,
            content_type: "tweet".into(),
            scheduled_for: Some("2026-03-16T09:00:00Z".into()),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "ContentScheduled");
        assert_eq!(json["id"], 7);
        assert!(json["scheduled_for"].is_string());
    }

    #[test]
    fn content_scheduled_without_time_serializes_null() {
        // Unlike `ApprovalUpdated.actor`, `scheduled_for` is kept as an explicit null.
        let event = WsEvent::ContentScheduled {
            id: 8,
            content_type: "thread".into(),
            scheduled_for: None,
        };
        let json = serde_json::to_value(&event).unwrap();
        assert!(json.as_object().unwrap().contains_key("scheduled_for"));
        assert!(json["scheduled_for"].is_null());
    }

    #[test]
    fn circuit_breaker_tripped_serializes() {
        let event = WsEvent::CircuitBreakerTripped {
            state: "open".into(),
            error_count: 5,
            cooldown_remaining_seconds: 120,
            timestamp: "2026-03-15T12:00:00Z".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "CircuitBreakerTripped");
        assert_eq!(json["error_count"], 5);
        assert_eq!(json["cooldown_remaining_seconds"], 120);
    }

    #[test]
    fn selection_received_serializes() {
        let event = WsEvent::SelectionReceived {
            session_id: "sess-9".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "SelectionReceived");
        assert_eq!(json["session_id"], "sess-9");
    }

    #[test]
    fn error_event_serializes() {
        let event = WsEvent::Error {
            message: "something broke".into(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "Error");
        assert_eq!(json["message"], "something broke");
    }

    // --- AccountWsEvent flattening ---

    #[test]
    fn account_ws_event_flattens_correctly() {
        let event = AccountWsEvent {
            account_id: "acct-123".into(),
            event: WsEvent::FollowerUpdate {
                count: 100,
                change: 5,
            },
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["account_id"], "acct-123");
        assert_eq!(json["type"], "FollowerUpdate");
        assert_eq!(json["count"], 100);
    }

    #[test]
    fn account_ws_event_roundtrip() {
        let original = AccountWsEvent {
            account_id: "acct-456".into(),
            event: WsEvent::Error {
                message: "test error".into(),
            },
        };
        let json_str = serde_json::to_string(&original).unwrap();
        let deserialized: AccountWsEvent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.account_id, "acct-456");
        match deserialized.event {
            WsEvent::Error { message } => assert_eq!(message, "test error"),
            _ => panic!("expected Error variant"),
        }
    }

    #[test]
    fn account_ws_event_deserializes_flat_json() {
        let json = r#"{"account_id":"acct-789","type":"FollowerUpdate","count":3,"change":-1}"#;
        let event: AccountWsEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.account_id, "acct-789");
        match event.event {
            WsEvent::FollowerUpdate { count, change } => {
                assert_eq!(count, 3);
                assert_eq!(change, -1);
            }
            other => panic!("expected FollowerUpdate, got {other:?}"),
        }
    }

    // --- extract_session_cookie ---

    #[test]
    fn extract_session_cookie_present() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "cookie",
            "other=foo; tuitbot_session=abc123; another=bar"
                .parse()
                .unwrap(),
        );
        let result = extract_session_cookie(&headers);
        assert_eq!(result.as_deref(), Some("abc123"));
    }

    #[test]
    fn extract_session_cookie_not_present() {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", "other=foo; another=bar".parse().unwrap());
        let result = extract_session_cookie(&headers);
        assert!(result.is_none());
    }

    #[test]
    fn extract_session_cookie_no_cookie_header() {
        let headers = HeaderMap::new();
        let result = extract_session_cookie(&headers);
        assert!(result.is_none());
    }

    #[test]
    fn extract_session_cookie_ignores_similarly_named_cookies() {
        // Neither a prefixed nor a suffixed name may be mistaken for the session cookie.
        let mut headers = HeaderMap::new();
        headers.insert(
            "cookie",
            "old_tuitbot_session=x; tuitbot_session_v2=y".parse().unwrap(),
        );
        assert!(extract_session_cookie(&headers).is_none());
    }

    #[test]
    fn extract_session_cookie_without_spaces() {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", "a=1;tuitbot_session=xyz".parse().unwrap());
        assert_eq!(extract_session_cookie(&headers).as_deref(), Some("xyz"));
    }

    // --- WsEvent deserialization ---

    #[test]
    fn ws_event_deserializes_from_json() {
        let json = r#"{"type":"Error","message":"test"}"#;
        let event: WsEvent = serde_json::from_str(json).unwrap();
        match event {
            WsEvent::Error { message } => assert_eq!(message, "test"),
            _ => panic!("expected Error variant"),
        }
    }

    #[test]
    fn all_event_variants_serialize_without_panic() {
        let events: Vec<WsEvent> = vec![
            WsEvent::ActionPerformed {
                action_type: "reply".into(),
                target: "t".into(),
                content: "c".into(),
                timestamp: "ts".into(),
            },
            WsEvent::ApprovalQueued {
                id: 1,
                action_type: "tweet".into(),
                content: "c".into(),
                media_paths: vec![],
            },
            WsEvent::ApprovalUpdated {
                id: 1,
                status: "s".into(),
                action_type: "a".into(),
                actor: None,
            },
            WsEvent::FollowerUpdate {
                count: 0,
                change: 0,
            },
            WsEvent::RuntimeStatus {
                running: false,
                active_loops: vec![],
            },
            WsEvent::TweetDiscovered {
                tweet_id: "t".into(),
                author: "a".into(),
                score: 0.0,
                timestamp: "ts".into(),
            },
            WsEvent::ActionSkipped {
                action_type: "a".into(),
                reason: "r".into(),
                timestamp: "ts".into(),
            },
            WsEvent::ContentScheduled {
                id: 1,
                content_type: "tweet".into(),
                scheduled_for: None,
            },
            WsEvent::CircuitBreakerTripped {
                state: "open".into(),
                error_count: 0,
                cooldown_remaining_seconds: 0,
                timestamp: "ts".into(),
            },
            WsEvent::SelectionReceived {
                session_id: "sess-1".into(),
            },
            WsEvent::Error {
                message: "err".into(),
            },
        ];
        for event in events {
            let json = serde_json::to_string(&event).unwrap();
            assert!(!json.is_empty());
            // Round-trip
            let _: WsEvent = serde_json::from_str(&json).unwrap();
        }
    }
}