Skip to main content

orca_control/
ws_handler.rs

1//! WebSocket handler for agent↔master streaming communication.
2//!
3//! Agents connect to `GET /api/v1/ws/agent?token=<cluster_token>&node_id=<id>`.
4//! After the upgrade, messages flow bidirectionally using [`AgentMessage`] and
5//! [`MasterMessage`] JSON frames.
6
7use std::sync::Arc;
8
9use axum::extract::ws::{Message, WebSocket};
10use axum::extract::{Query, State, WebSocketUpgrade};
11use axum::response::IntoResponse;
12use futures_util::{SinkExt, StreamExt};
13use serde::Deserialize;
14use tokio::sync::mpsc;
15use tracing::{error, info, warn};
16
17use orca_core::ws_types::{AgentMessage, MasterMessage};
18
19use crate::state::AppState;
20
/// Query params for the WS upgrade request.
///
/// Extracted from the request's query string by the `Query` extractor in
/// [`ws_agent_handler`].
#[derive(Deserialize)]
pub struct WsQuery {
    /// Credential presented by the agent. [`ws_agent_handler`] accepts it when
    /// it equals one of the entries in `state.api_tokens` or the `value` of an
    /// entry in `state.cluster_config.token`.
    token: String,
    /// Node id the agent connects as. It is the key under which the
    /// connection's sender is stored in `state.ws_agents`.
    node_id: u64,
}
27
/// Per-node sender so the master can push messages to a connected agent.
///
/// This is the sending half of the bounded mpsc channel that
/// `handle_agent_ws` creates for each connection; a clone is stored in
/// `state.ws_agents` under the agent's node id for as long as it is connected.
pub type AgentSender = mpsc::Sender<MasterMessage>;
30
31/// Handle the WebSocket upgrade request.
32///
33/// Authenticates via the `token` query param, then upgrades to a
34/// bidirectional WebSocket connection.
35pub async fn ws_agent_handler(
36    ws: WebSocketUpgrade,
37    State(state): State<Arc<AppState>>,
38    Query(query): Query<WsQuery>,
39) -> impl IntoResponse {
40    // Authenticate: check token against cluster tokens
41    let valid = state.api_tokens.iter().any(|t| t == &query.token)
42        || state
43            .cluster_config
44            .token
45            .iter()
46            .any(|t| t.value == query.token);
47
48    if !valid {
49        return (axum::http::StatusCode::UNAUTHORIZED, "invalid token").into_response();
50    }
51
52    let node_id = query.node_id;
53    info!("WebSocket upgrade accepted for node {node_id}");
54
55    ws.on_upgrade(move |socket| handle_agent_ws(socket, state, node_id))
56        .into_response()
57}
58
59/// Main WebSocket loop for a connected agent.
60async fn handle_agent_ws(socket: WebSocket, state: Arc<AppState>, node_id: u64) {
61    let (mut ws_tx, mut ws_rx) = socket.split();
62
63    // Channel for master → agent messages (deploy commands, log requests, etc.)
64    let (tx, mut rx) = mpsc::channel::<MasterMessage>(64);
65
66    // Register this agent's sender so other parts of the system can push messages.
67    {
68        let mut senders = state.ws_agents.write().await;
69        senders.insert(node_id, tx.clone());
70    }
71
72    info!("Agent {node_id} connected via WebSocket");
73
74    // Send initial Ack
75    let ack = MasterMessage::Ack { node_id };
76    if let Ok(json) = serde_json::to_string(&ack) {
77        let _ = ws_tx.send(Message::Text(json.into())).await;
78    }
79
80    // Drain any pending commands that were queued before the WS connected.
81    drain_pending_commands(&state, node_id, &tx).await;
82
83    // Send Reconcile with all services expected on this node so the agent
84    // can self-heal after a restart (fixes #21: stale remote state).
85    send_reconcile(&state, node_id, &tx).await;
86
87    // Spawn task to forward master→agent messages from the channel to the WS.
88    let send_task = tokio::spawn(async move {
89        while let Some(msg) = rx.recv().await {
90            let json = match serde_json::to_string(&msg) {
91                Ok(j) => j,
92                Err(e) => {
93                    error!("Failed to serialize MasterMessage: {e}");
94                    continue;
95                }
96            };
97            if ws_tx.send(Message::Text(json.into())).await.is_err() {
98                break; // connection closed
99            }
100        }
101    });
102
103    // Process incoming agent messages.
104    while let Some(Ok(msg)) = ws_rx.next().await {
105        match msg {
106            Message::Text(text) => {
107                if let Err(e) = handle_agent_message(&state, node_id, &text, &tx).await {
108                    warn!("Error handling agent message from {node_id}: {e}");
109                }
110            }
111            Message::Close(_) => break,
112            _ => {} // ignore binary, ping, pong (axum handles pong auto)
113        }
114    }
115
116    // Cleanup on disconnect
117    send_task.abort();
118    {
119        let mut senders = state.ws_agents.write().await;
120        senders.remove(&node_id);
121    }
122    info!("Agent {node_id} WebSocket disconnected");
123}
124
125/// Process a single message from the agent.
126async fn handle_agent_message(
127    state: &AppState,
128    node_id: u64,
129    text: &str,
130    _tx: &mpsc::Sender<MasterMessage>,
131) -> anyhow::Result<()> {
132    let msg: AgentMessage = serde_json::from_str(text)?;
133
134    match msg {
135        AgentMessage::Heartbeat {
136            node_id: reported_id,
137            workloads,
138            stats,
139        } => {
140            handle_ws_heartbeat(state, reported_id, &workloads, &stats).await;
141        }
142        AgentMessage::DomainDiscovered {
143            service_name,
144            domain,
145            host_port,
146        } => {
147            info!(
148                "Node {node_id} discovered domain {domain} for {service_name} (port {host_port})"
149            );
150            // Update the service's domain tracking on the master so the
151            // TUI and status API reflect the domain correctly.
152            let mut services = state.services.write().await;
153            if let Some(svc) = services.get_mut(&service_name) {
154                svc.config.domain = Some(domain);
155            }
156        }
157        AgentMessage::DeployResult {
158            service_name,
159            success,
160            error,
161        } => {
162            if success {
163                info!("Node {node_id}: deploy of {service_name} succeeded");
164            } else {
165                error!(
166                    "Node {node_id}: deploy of {service_name} failed: {}",
167                    error.as_deref().unwrap_or("unknown")
168                );
169            }
170        }
171        AgentMessage::LogChunk {
172            request_id,
173            service_name: _,
174            data,
175            done,
176        } => {
177            // Forward to any pending log stream listener.
178            let listeners = state.log_listeners.read().await;
179            if let Some(listener_tx) = listeners.get(&request_id) {
180                let _ = listener_tx.send((data, done)).await;
181            }
182        }
183    }
184
185    Ok(())
186}
187
188/// Process a heartbeat received over WebSocket (same logic as HTTP handler).
189async fn handle_ws_heartbeat(
190    state: &AppState,
191    node_id: u64,
192    workloads: &[orca_core::ws_types::WorkloadReport],
193    stats: &orca_core::ws_types::HostStats,
194) {
195    let mut nodes = state.registered_nodes.write().await;
196    if let Some(node) = nodes.get_mut(&node_id) {
197        node.last_heartbeat = chrono::Utc::now();
198        node.cpu_percent = stats.cpu_percent;
199        node.memory_bytes = stats.memory_bytes;
200        node.memory_total = stats.memory_total;
201        node.disk_used = stats.disk_used;
202        node.disk_total = stats.disk_total;
203        node.net_rx = stats.net_rx;
204        node.net_tx = stats.net_tx;
205    }
206    drop(nodes);
207
208    // Update service statuses
209    if !workloads.is_empty() {
210        let mut services = state.services.write().await;
211        for report in workloads {
212            if let Some(svc) = services.get_mut(&report.service_name) {
213                let status = match report.status.as_str() {
214                    "running" => orca_core::types::WorkloadStatus::Running,
215                    "stopped" => orca_core::types::WorkloadStatus::Stopped,
216                    "failed" => orca_core::types::WorkloadStatus::Failed,
217                    _ => orca_core::types::WorkloadStatus::Stopped,
218                };
219                for instance in &mut svc.instances {
220                    instance.status = status;
221                }
222            }
223        }
224    }
225}
226
227/// Drain pending commands from the HTTP queue and send them over WS.
228async fn drain_pending_commands(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
229    let commands = {
230        let mut pending = state.pending_commands.write().await;
231        pending.remove(&node_id).unwrap_or_default()
232    };
233    for cmd in commands {
234        if let Some(action) = cmd.get("action").and_then(|a| a.as_str()) {
235            match action {
236                "deploy" => {
237                    if let Some(spec) = cmd.get("spec")
238                        && let Ok(spec) = serde_json::from_value(spec.clone())
239                    {
240                        let _ = tx
241                            .send(MasterMessage::Deploy {
242                                spec: Box::new(spec),
243                            })
244                            .await;
245                    }
246                }
247                "stop" => {
248                    if let Some(name) = cmd.get("service_name").and_then(|n| n.as_str()) {
249                        let _ = tx
250                            .send(MasterMessage::Stop {
251                                service_name: name.to_string(),
252                            })
253                            .await;
254                    }
255                }
256                _ => {}
257            }
258        }
259    }
260}
261
262/// Send the list of services expected on this agent node so it can
263/// reconcile (redeploy missing containers, stop unexpected ones).
264async fn send_reconcile(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
265    // Find the node's address/hostname for placement matching
266    let node_address = {
267        let nodes = state.registered_nodes.read().await;
268        nodes.get(&node_id).map(|n| n.address.clone())
269    };
270    let Some(node_addr) = node_address else {
271        return;
272    };
273
274    // Collect all services whose placement targets this node
275    let services = state.services.read().await;
276    let expected: Vec<Box<orca_core::types::WorkloadSpec>> = services
277        .values()
278        .filter(|svc| {
279            svc.config
280                .placement
281                .as_ref()
282                .and_then(|p| p.node.as_ref())
283                .is_some_and(|target| {
284                    node_addr.contains(target.as_str()) || target == &node_id.to_string() || {
285                        let nodes_guard =
286                            futures_util::FutureExt::now_or_never(state.registered_nodes.read());
287                        nodes_guard
288                            .and_then(|nodes| {
289                                nodes
290                                    .get(&node_id)
291                                    .and_then(|n| n.labels.get("hostname").map(|h| h == target))
292                            })
293                            .unwrap_or(false)
294                    }
295                })
296        })
297        .filter_map(|svc| {
298            crate::routes::service_config_to_spec(&svc.config)
299                .ok()
300                .map(Box::new)
301        })
302        .collect();
303
304    if expected.is_empty() {
305        return;
306    }
307
308    info!(
309        "Sending Reconcile to node {node_id} with {} expected services",
310        expected.len()
311    );
312    let _ = tx.send(MasterMessage::Reconcile { expected }).await;
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// `WsQuery` accepts a string `token` together with a numeric `node_id`.
    #[test]
    fn ws_query_deserializes() {
        let raw = r#"{"token":"abc123","node_id":42}"#;
        let WsQuery { token, node_id } = serde_json::from_str(raw).unwrap();
        assert_eq!(token, "abc123");
        assert_eq!(node_id, 42);
    }
}
325}