1use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
17#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum WsEvent {
21 ActionPerformed {
23 action_type: String,
24 target: String,
25 content: String,
26 timestamp: String,
27 },
28 ApprovalQueued {
30 id: i64,
31 action_type: String,
32 content: String,
33 #[serde(default)]
34 media_paths: Vec<String>,
35 },
36 ApprovalUpdated {
38 id: i64,
39 status: String,
40 action_type: String,
41 },
42 FollowerUpdate { count: i64, change: i64 },
44 RuntimeStatus {
46 running: bool,
47 active_loops: Vec<String>,
48 },
49 TweetDiscovered {
51 tweet_id: String,
52 author: String,
53 score: f64,
54 timestamp: String,
55 },
56 ActionSkipped {
58 action_type: String,
59 reason: String,
60 timestamp: String,
61 },
62 ContentScheduled {
64 id: i64,
65 content_type: String,
66 scheduled_for: Option<String>,
67 },
68 Error { message: String },
70}
71
72#[derive(Deserialize)]
74pub struct WsQuery {
75 pub token: String,
77}
78
79pub async fn ws_handler(
81 ws: WebSocketUpgrade,
82 State(state): State<Arc<AppState>>,
83 Query(params): Query<WsQuery>,
84) -> Response {
85 if params.token != state.api_token {
87 return (
88 StatusCode::UNAUTHORIZED,
89 axum::Json(json!({"error": "unauthorized"})),
90 )
91 .into_response();
92 }
93
94 ws.on_upgrade(move |socket| handle_ws(socket, state))
95}
96
97async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
101 let mut rx = state.event_tx.subscribe();
102
103 loop {
104 match rx.recv().await {
105 Ok(event) => {
106 let json = match serde_json::to_string(&event) {
107 Ok(j) => j,
108 Err(e) => {
109 tracing::error!(error = %e, "Failed to serialize WsEvent");
110 continue;
111 }
112 };
113 if socket.send(Message::Text(json.into())).await.is_err() {
114 break;
116 }
117 }
118 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
119 tracing::warn!(count, "WebSocket client lagged, events dropped");
120 let error_event = WsEvent::Error {
121 message: format!("{count} events dropped due to slow consumer"),
122 };
123 if let Ok(json) = serde_json::to_string(&error_event) {
124 if socket.send(Message::Text(json.into())).await.is_err() {
125 break;
126 }
127 }
128 }
129 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
130 break;
131 }
132 }
133 }
134}