Skip to main content

tuitbot_server/
ws.rs

1//! WebSocket hub for real-time event streaming.
2//!
3//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
4//! via a `tokio::sync::broadcast` channel.
5
6use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
17/// Events pushed to WebSocket clients.
18#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum WsEvent {
21    /// An automation action was performed (reply, tweet, thread, etc.).
22    ActionPerformed {
23        action_type: String,
24        target: String,
25        content: String,
26        timestamp: String,
27    },
28    /// A new item was queued for approval.
29    ApprovalQueued {
30        id: i64,
31        action_type: String,
32        content: String,
33    },
34    /// Follower count changed.
35    FollowerUpdate { count: i64, change: i64 },
36    /// Automation runtime status changed.
37    RuntimeStatus {
38        running: bool,
39        active_loops: Vec<String>,
40    },
41    /// A tweet was discovered and scored by the discovery loop.
42    TweetDiscovered {
43        tweet_id: String,
44        author: String,
45        score: f64,
46        timestamp: String,
47    },
48    /// An action was skipped (rate limited, below threshold, safety filter).
49    ActionSkipped {
50        action_type: String,
51        reason: String,
52        timestamp: String,
53    },
54    /// An error occurred.
55    Error { message: String },
56}
57
58/// Query parameters for WebSocket authentication.
59#[derive(Deserialize)]
60pub struct WsQuery {
61    /// API token passed as a query parameter.
62    pub token: String,
63}
64
65/// `GET /api/ws?token=...` — WebSocket upgrade with token auth.
66pub async fn ws_handler(
67    ws: WebSocketUpgrade,
68    State(state): State<Arc<AppState>>,
69    Query(params): Query<WsQuery>,
70) -> Response {
71    // Authenticate via query parameter.
72    if params.token != state.api_token {
73        return (
74            StatusCode::UNAUTHORIZED,
75            axum::Json(json!({"error": "unauthorized"})),
76        )
77            .into_response();
78    }
79
80    ws.on_upgrade(move |socket| handle_ws(socket, state))
81}
82
83/// Handle a single WebSocket connection.
84///
85/// Subscribes to the broadcast channel and forwards events as JSON text frames.
86async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
87    let mut rx = state.event_tx.subscribe();
88
89    loop {
90        match rx.recv().await {
91            Ok(event) => {
92                let json = match serde_json::to_string(&event) {
93                    Ok(j) => j,
94                    Err(e) => {
95                        tracing::error!(error = %e, "Failed to serialize WsEvent");
96                        continue;
97                    }
98                };
99                if socket.send(Message::Text(json.into())).await.is_err() {
100                    // Client disconnected.
101                    break;
102                }
103            }
104            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
105                tracing::warn!(count, "WebSocket client lagged, events dropped");
106                let error_event = WsEvent::Error {
107                    message: format!("{count} events dropped due to slow consumer"),
108                };
109                if let Ok(json) = serde_json::to_string(&error_event) {
110                    if socket.send(Message::Text(json.into())).await.is_err() {
111                        break;
112                    }
113                }
114            }
115            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
116                break;
117            }
118        }
119    }
120}