Skip to main content

envoy/http/
ws.rs

1use axum::extract::ws::{Message, WebSocket};
2use axum::extract::{Path, State, WebSocketUpgrade};
3use axum::response::IntoResponse;
4use tokio::sync::broadcast;
5
6use crate::circuit;
7use crate::error::Result;
8use crate::http::state::SharedState;
9
10pub(crate) async fn ws_handler(
11    ws: WebSocketUpgrade,
12    State(state): State<SharedState>,
13    Path(agent_id): Path<String>,
14) -> Result<impl IntoResponse> {
15    {
16        let state_fb = state.clone();
17        let agent_id = agent_id.clone();
18        let _ = tokio::task::spawn_blocking(move || {
19            let engine = state_fb.engine.lock();
20            state_fb.agent_registry.heartbeat(
21                engine.graph(),
22                &agent_id,
23                crate::status::AgentStatusSnapshot {
24                    state: crate::status::AgentState::Working,
25                    task_id: None,
26                    blocked_reason: None,
27                    waiting_on_agent: None,
28                    checkpoint: Some("ws_connected".into()),
29                    working_on: "connected via WS".into(),
30                },
31            )
32        })
33        .await;
34    }
35
36    Ok(ws.on_upgrade(move |socket| handle_ws(socket, state, agent_id)))
37}
38
39async fn handle_ws(mut socket: WebSocket, state: SharedState, agent_id: String) {
40    let mut rx = state.ws_registry.register(&agent_id);
41
42    // Catch-up: undelivered messages for this agent
43    {
44        let state_fb = state.clone();
45        let agent_id_fb = agent_id.clone();
46        let pending = tokio::task::spawn_blocking(move || {
47            let engine = state_fb.engine.lock();
48            state_fb
49                .message_store
50                .poll(engine.graph(), &agent_id_fb, 0, 100, true)
51        })
52        .await
53        .unwrap_or(Ok(Vec::new()))
54        .unwrap_or_default();
55        for msg in &pending {
56            let event = serde_json::json!({"event": "message", "data": msg});
57            if socket
58                .send(Message::Text(event.to_string().into()))
59                .await
60                .is_err()
61            {
62                state.ws_registry.unregister(&agent_id);
63                return;
64            }
65        }
66    }
67
68    // Catch-up: undelivered events for subscribed projects (dead-letter replay)
69    let catchup_events: Vec<serde_json::Value> = {
70        let state_fb = state.clone();
71        let agent_id_fb = agent_id.clone();
72        tokio::task::spawn_blocking(move || {
73            let engine = state_fb.engine.lock();
74            let projects = state_fb
75                .subscription_store
76                .list(engine.graph(), &agent_id_fb)
77                .unwrap_or_default();
78            let mut payloads = Vec::new();
79            for project in &projects {
80                if let Ok(events) = state_fb.delivery_tracker.get_undelivered(
81                    engine.graph(),
82                    &agent_id_fb,
83                    project,
84                    Some(50),
85                ) {
86                    for evt in &events {
87                        if let Ok(payload) = serde_json::to_value(evt) {
88                            payloads.push(
89                                serde_json::json!({"event": "event_catchup", "data": payload}),
90                            );
91                        }
92                    }
93                }
94            }
95            payloads
96        })
97        .await
98        .unwrap_or_default()
99    };
100    for msg in &catchup_events {
101        if socket
102            .send(Message::Text(msg.to_string().into()))
103            .await
104            .is_err()
105        {
106            state.ws_registry.unregister(&agent_id);
107            return;
108        }
109    }
110    // Mark catch-up events as delivered
111    if !catchup_events.is_empty() {
112        let state_fb = state.clone();
113        let agent_id_fb = agent_id.clone();
114        let _ = tokio::task::spawn_blocking(move || {
115            let engine = state_fb.engine.lock();
116            for msg in &catchup_events {
117                if let Some(eid) = msg
118                    .get("data")
119                    .and_then(|d| d.get("id"))
120                    .and_then(|v| v.as_str())
121                {
122                    let _ = state_fb.delivery_tracker.record_delivery(
123                        engine.graph(),
124                        &agent_id_fb,
125                        eid,
126                    );
127                }
128            }
129        })
130        .await;
131    }
132
133    // Connected event
134    let connected = serde_json::json!({
135        "event": "agent_connected",
136        "data": { "agent_id": &agent_id }
137    });
138    let _ = socket
139        .send(Message::Text(connected.to_string().into()))
140        .await;
141
142    loop {
143        tokio::select! {
144            result = rx.recv() => {
145                match result {
146                    Ok(event_str) => {
147                        if socket.send(Message::Text(event_str.into())).await.is_err() {
148                            break;
149                        }
150                    }
151                    // Channel overflowed — replay missed messages from store
152                    Err(broadcast::error::RecvError::Lagged(n)) => {
153                        let _ = socket.send(Message::Text(
154                            serde_json::json!({
155                                "event": "channel_lagged",
156                                "data": { "skipped": n }
157                            }).to_string().into()
158                        )).await;
159
160                        // Replay unACKed messages from persistent store
161                        let state_fb = state.clone();
162                        let agent_id_fb = agent_id.clone();
163                        let replay = tokio::task::spawn_blocking(move || {
164                            let engine = state_fb.engine.lock();
165                            state_fb.message_store.poll(engine.graph(), &agent_id_fb, 0, 100, false)
166                        })
167                        .await
168                        .unwrap_or(Ok(Vec::new()))
169                        .unwrap_or_default();
170
171                        for msg in &replay {
172                            let event = serde_json::json!({"event": "message", "data": msg});
173                            if socket.send(Message::Text(event.to_string().into())).await.is_err() {
174                                state.ws_registry.unregister(&agent_id);
175                                return;
176                            }
177                        }
178
179                        rx = state.ws_registry.register(&agent_id);
180                    }
181                    Err(_) => break, // channel closed
182                }
183            }
184            msg = socket.recv() => {
185                match msg {
186                    Some(Ok(Message::Text(text))) => {
187                        if let Ok(hb) = serde_json::from_str::<serde_json::Value>(&text) {
188                            match hb.get("type").and_then(|v| v.as_str()) {
189                                Some("heartbeat") => {
190                                    let mut status: Option<crate::status::AgentStatusSnapshot> = None;
191                                    if let Some(data) = hb.get("data") {
192                                        status = serde_json::from_value::<crate::status::AgentStatusSnapshot>(data.clone()).ok();
193                                    }
194                                    let state_fb = state.clone();
195                                    let agent_id_fb = agent_id.clone();
196                                    let accepted = tokio::task::spawn_blocking(move || {
197                                        let engine = state_fb.engine.lock();
198                                        if let Some(ref st) = status {
199                                            state_fb.agent_registry.heartbeat(engine.graph(), &agent_id_fb, st.clone()).is_ok()
200                                        } else {
201                                            state_fb.agent_registry.heartbeat(engine.graph(), &agent_id_fb,
202                                                crate::status::AgentStatusSnapshot::default()).is_ok()
203                                        }
204                                    })
205                                    .await
206                                    .unwrap_or(false);
207                                    let _ = socket.send(Message::Text(
208                                        serde_json::json!({
209                                            "type": "heartbeat_ack",
210                                            "data": {
211                                                "accepted": accepted,
212                                                "timestamp": chrono::Utc::now().to_rfc3339(),
213                                            }
214                                        }).to_string().into()
215                                    )).await;
216                                    continue;
217                                }
218                                Some("ping") => {
219                                    let _ = socket.send(Message::Text(
220                                        serde_json::json!({"type": "pong"}).to_string().into()
221                                    )).await;
222                                    continue;
223                                }
224                                _ => {}
225                            }
226                        }
227                    }
228                    Some(Ok(Message::Close(_))) | None => break,
229                    _ => {}
230                }
231            }
232        }
233    }
234
235    state.ws_registry.unregister(&agent_id);
236}
237
238/// Broadcast an event to all agents subscribed to a project.
239pub(crate) async fn broadcast_to_project(
240    state: &SharedState,
241    project: &str,
242    event_type: &str,
243    data: &serde_json::Value,
244) {
245    // Fetch subscribers via blocking pool
246    let state_c = state.clone();
247    let project_owned = project.to_string();
248    let subs = match tokio::task::spawn_blocking(move || {
249        let engine = state_c.engine.lock();
250        state_c
251            .subscription_store
252            .subscribers(engine.graph(), &project_owned)
253            .unwrap_or_default()
254    })
255    .await
256    {
257        Ok(s) => s,
258        Err(_) => return,
259    };
260
261    let event_id = data.get("id").and_then(|v| v.as_str());
262    let mut delivery_pairs: Vec<(String, String)> = Vec::new();
263    let mut offline_agents: Vec<String> = Vec::new();
264
265    // WS sends are in-memory — safe on async runtime
266    for agent_id in &subs {
267        match state.circuit_breaker.check(agent_id) {
268            circuit::CanDeliver::No => continue,
269            circuit::CanDeliver::Yes | circuit::CanDeliver::Probe => {}
270        }
271        let delivered = state.ws_registry.send_json(agent_id, event_type, data);
272        if delivered {
273            state.circuit_breaker.record_success(agent_id);
274            let state_fb = state.clone();
275            let agent_id_fb = agent_id.clone();
276            let _ = tokio::task::spawn_blocking(move || {
277                let engine = state_fb.engine.lock();
278                state_fb
279                    .audit_store
280                    .log_circuit_closed(engine.graph(), &agent_id_fb)
281            })
282            .await;
283            if let Some(eid) = event_id {
284                delivery_pairs.push((agent_id.clone(), eid.to_string()));
285            }
286        } else {
287            state.circuit_breaker.record_failure(agent_id);
288            let status = state.circuit_breaker.get_state(agent_id);
289            if status.state == "open" {
290                let state_fb = state.clone();
291                let agent_id_fb = agent_id.clone();
292                let failures = status.failures;
293                let _ = tokio::task::spawn_blocking(move || {
294                    let engine = state_fb.engine.lock();
295                    state_fb
296                        .audit_store
297                        .log_circuit_opened(engine.graph(), &agent_id_fb, failures)
298                })
299                .await;
300            }
301            offline_agents.push(agent_id.clone());
302        }
303    }
304
305    // Record deliveries via blocking pool
306    if !delivery_pairs.is_empty() {
307        let state_c = state.clone();
308        let _ = tokio::task::spawn_blocking(move || {
309            let engine = state_c.engine.lock();
310            for (agent_id, eid) in &delivery_pairs {
311                let _ = state_c
312                    .delivery_tracker
313                    .record_delivery(engine.graph(), agent_id, eid);
314            }
315        })
316        .await;
317    }
318
319    // Store notifications for offline agents so they pick them up on poll/reconnect
320    if !offline_agents.is_empty() {
321        let state_c = state.clone();
322        let event_type_owned = event_type.to_string();
323        let data_clone = data.clone();
324        let _ = tokio::task::spawn_blocking(move || {
325            let engine = state_c.engine.lock();
326            for agent_id in &offline_agents {
327                let _ = state_c.message_store.store_notification(
328                    engine.graph(),
329                    agent_id,
330                    &event_type_owned,
331                    &data_clone,
332                );
333            }
334        })
335        .await;
336    }
337}