1use axum::extract::ws::{Message, WebSocket};
2use axum::extract::{Path, State, WebSocketUpgrade};
3use axum::response::IntoResponse;
4use tokio::sync::broadcast;
5
6use crate::circuit;
7use crate::error::Result;
8use crate::http::state::SharedState;
9
10pub(crate) async fn ws_handler(
11 ws: WebSocketUpgrade,
12 State(state): State<SharedState>,
13 Path(agent_id): Path<String>,
14) -> Result<impl IntoResponse> {
15 {
16 let state_fb = state.clone();
17 let agent_id = agent_id.clone();
18 let _ = tokio::task::spawn_blocking(move || {
19 let engine = state_fb.engine.lock();
20 state_fb.agent_registry.heartbeat(
21 engine.graph(),
22 &agent_id,
23 crate::status::AgentStatusSnapshot {
24 state: crate::status::AgentState::Working,
25 task_id: None,
26 blocked_reason: None,
27 waiting_on_agent: None,
28 checkpoint: Some("ws_connected".into()),
29 working_on: "connected via WS".into(),
30 },
31 )
32 })
33 .await;
34 }
35
36 Ok(ws.on_upgrade(move |socket| handle_ws(socket, state, agent_id)))
37}
38
39async fn handle_ws(mut socket: WebSocket, state: SharedState, agent_id: String) {
40 let mut rx = state.ws_registry.register(&agent_id);
41
42 {
44 let state_fb = state.clone();
45 let agent_id_fb = agent_id.clone();
46 let pending = tokio::task::spawn_blocking(move || {
47 let engine = state_fb.engine.lock();
48 state_fb
49 .message_store
50 .poll(engine.graph(), &agent_id_fb, 0, 100, true)
51 })
52 .await
53 .unwrap_or(Ok(Vec::new()))
54 .unwrap_or_default();
55 for msg in &pending {
56 let event = serde_json::json!({"event": "message", "data": msg});
57 if socket
58 .send(Message::Text(event.to_string().into()))
59 .await
60 .is_err()
61 {
62 state.ws_registry.unregister(&agent_id);
63 return;
64 }
65 }
66 }
67
68 let catchup_events: Vec<serde_json::Value> = {
70 let state_fb = state.clone();
71 let agent_id_fb = agent_id.clone();
72 tokio::task::spawn_blocking(move || {
73 let engine = state_fb.engine.lock();
74 let projects = state_fb
75 .subscription_store
76 .list(engine.graph(), &agent_id_fb)
77 .unwrap_or_default();
78 let mut payloads = Vec::new();
79 for project in &projects {
80 if let Ok(events) = state_fb.delivery_tracker.get_undelivered(
81 engine.graph(),
82 &agent_id_fb,
83 project,
84 Some(50),
85 ) {
86 for evt in &events {
87 if let Ok(payload) = serde_json::to_value(evt) {
88 payloads.push(
89 serde_json::json!({"event": "event_catchup", "data": payload}),
90 );
91 }
92 }
93 }
94 }
95 payloads
96 })
97 .await
98 .unwrap_or_default()
99 };
100 for msg in &catchup_events {
101 if socket
102 .send(Message::Text(msg.to_string().into()))
103 .await
104 .is_err()
105 {
106 state.ws_registry.unregister(&agent_id);
107 return;
108 }
109 }
110 if !catchup_events.is_empty() {
112 let state_fb = state.clone();
113 let agent_id_fb = agent_id.clone();
114 let _ = tokio::task::spawn_blocking(move || {
115 let engine = state_fb.engine.lock();
116 for msg in &catchup_events {
117 if let Some(eid) = msg
118 .get("data")
119 .and_then(|d| d.get("id"))
120 .and_then(|v| v.as_str())
121 {
122 let _ = state_fb.delivery_tracker.record_delivery(
123 engine.graph(),
124 &agent_id_fb,
125 eid,
126 );
127 }
128 }
129 })
130 .await;
131 }
132
133 let connected = serde_json::json!({
135 "event": "agent_connected",
136 "data": { "agent_id": &agent_id }
137 });
138 let _ = socket
139 .send(Message::Text(connected.to_string().into()))
140 .await;
141
142 loop {
143 tokio::select! {
144 result = rx.recv() => {
145 match result {
146 Ok(event_str) => {
147 if socket.send(Message::Text(event_str.into())).await.is_err() {
148 break;
149 }
150 }
151 Err(broadcast::error::RecvError::Lagged(n)) => {
153 let _ = socket.send(Message::Text(
154 serde_json::json!({
155 "event": "channel_lagged",
156 "data": { "skipped": n }
157 }).to_string().into()
158 )).await;
159
160 let state_fb = state.clone();
162 let agent_id_fb = agent_id.clone();
163 let replay = tokio::task::spawn_blocking(move || {
164 let engine = state_fb.engine.lock();
165 state_fb.message_store.poll(engine.graph(), &agent_id_fb, 0, 100, false)
166 })
167 .await
168 .unwrap_or(Ok(Vec::new()))
169 .unwrap_or_default();
170
171 for msg in &replay {
172 let event = serde_json::json!({"event": "message", "data": msg});
173 if socket.send(Message::Text(event.to_string().into())).await.is_err() {
174 state.ws_registry.unregister(&agent_id);
175 return;
176 }
177 }
178
179 rx = state.ws_registry.register(&agent_id);
180 }
181 Err(_) => break, }
183 }
184 msg = socket.recv() => {
185 match msg {
186 Some(Ok(Message::Text(text))) => {
187 if let Ok(hb) = serde_json::from_str::<serde_json::Value>(&text) {
188 match hb.get("type").and_then(|v| v.as_str()) {
189 Some("heartbeat") => {
190 let mut status: Option<crate::status::AgentStatusSnapshot> = None;
191 if let Some(data) = hb.get("data") {
192 status = serde_json::from_value::<crate::status::AgentStatusSnapshot>(data.clone()).ok();
193 }
194 let state_fb = state.clone();
195 let agent_id_fb = agent_id.clone();
196 let accepted = tokio::task::spawn_blocking(move || {
197 let engine = state_fb.engine.lock();
198 if let Some(ref st) = status {
199 state_fb.agent_registry.heartbeat(engine.graph(), &agent_id_fb, st.clone()).is_ok()
200 } else {
201 state_fb.agent_registry.heartbeat(engine.graph(), &agent_id_fb,
202 crate::status::AgentStatusSnapshot::default()).is_ok()
203 }
204 })
205 .await
206 .unwrap_or(false);
207 let _ = socket.send(Message::Text(
208 serde_json::json!({
209 "type": "heartbeat_ack",
210 "data": {
211 "accepted": accepted,
212 "timestamp": chrono::Utc::now().to_rfc3339(),
213 }
214 }).to_string().into()
215 )).await;
216 continue;
217 }
218 Some("ping") => {
219 let _ = socket.send(Message::Text(
220 serde_json::json!({"type": "pong"}).to_string().into()
221 )).await;
222 continue;
223 }
224 _ => {}
225 }
226 }
227 }
228 Some(Ok(Message::Close(_))) | None => break,
229 _ => {}
230 }
231 }
232 }
233 }
234
235 state.ws_registry.unregister(&agent_id);
236}
237
238pub(crate) async fn broadcast_to_project(
240 state: &SharedState,
241 project: &str,
242 event_type: &str,
243 data: &serde_json::Value,
244) {
245 let state_c = state.clone();
247 let project_owned = project.to_string();
248 let subs = match tokio::task::spawn_blocking(move || {
249 let engine = state_c.engine.lock();
250 state_c
251 .subscription_store
252 .subscribers(engine.graph(), &project_owned)
253 .unwrap_or_default()
254 })
255 .await
256 {
257 Ok(s) => s,
258 Err(_) => return,
259 };
260
261 let event_id = data.get("id").and_then(|v| v.as_str());
262 let mut delivery_pairs: Vec<(String, String)> = Vec::new();
263 let mut offline_agents: Vec<String> = Vec::new();
264
265 for agent_id in &subs {
267 match state.circuit_breaker.check(agent_id) {
268 circuit::CanDeliver::No => continue,
269 circuit::CanDeliver::Yes | circuit::CanDeliver::Probe => {}
270 }
271 let delivered = state.ws_registry.send_json(agent_id, event_type, data);
272 if delivered {
273 state.circuit_breaker.record_success(agent_id);
274 let state_fb = state.clone();
275 let agent_id_fb = agent_id.clone();
276 let _ = tokio::task::spawn_blocking(move || {
277 let engine = state_fb.engine.lock();
278 state_fb
279 .audit_store
280 .log_circuit_closed(engine.graph(), &agent_id_fb)
281 })
282 .await;
283 if let Some(eid) = event_id {
284 delivery_pairs.push((agent_id.clone(), eid.to_string()));
285 }
286 } else {
287 state.circuit_breaker.record_failure(agent_id);
288 let status = state.circuit_breaker.get_state(agent_id);
289 if status.state == "open" {
290 let state_fb = state.clone();
291 let agent_id_fb = agent_id.clone();
292 let failures = status.failures;
293 let _ = tokio::task::spawn_blocking(move || {
294 let engine = state_fb.engine.lock();
295 state_fb
296 .audit_store
297 .log_circuit_opened(engine.graph(), &agent_id_fb, failures)
298 })
299 .await;
300 }
301 offline_agents.push(agent_id.clone());
302 }
303 }
304
305 if !delivery_pairs.is_empty() {
307 let state_c = state.clone();
308 let _ = tokio::task::spawn_blocking(move || {
309 let engine = state_c.engine.lock();
310 for (agent_id, eid) in &delivery_pairs {
311 let _ = state_c
312 .delivery_tracker
313 .record_delivery(engine.graph(), agent_id, eid);
314 }
315 })
316 .await;
317 }
318
319 if !offline_agents.is_empty() {
321 let state_c = state.clone();
322 let event_type_owned = event_type.to_string();
323 let data_clone = data.clone();
324 let _ = tokio::task::spawn_blocking(move || {
325 let engine = state_c.engine.lock();
326 for agent_id in &offline_agents {
327 let _ = state_c.message_store.store_notification(
328 engine.graph(),
329 agent_id,
330 &event_type_owned,
331 &data_clone,
332 );
333 }
334 })
335 .await;
336 }
337}