Skip to main content

orca_control/
ws_handler.rs

//! WebSocket handler for agent↔master streaming communication.
//!
//! Agents connect to `GET /api/v1/ws/agent?token=<cluster_token>&node_id=<id>`.
//! After the upgrade, messages flow bidirectionally using [`AgentMessage`] and
//! [`MasterMessage`] JSON frames.
6
7use std::sync::Arc;
8
9use axum::extract::ws::{Message, WebSocket};
10use axum::extract::{Query, State, WebSocketUpgrade};
11use axum::response::IntoResponse;
12use futures_util::{SinkExt, StreamExt};
13use serde::Deserialize;
14use tokio::sync::mpsc;
15use tracing::{error, info, warn};
16
17use orca_core::ws_types::{AgentMessage, MasterMessage};
18
19use crate::state::AppState;
20
21/// Query params for the WS upgrade request.
22#[derive(Deserialize)]
23pub struct WsQuery {
24    token: String,
25    node_id: u64,
26    /// Agent's address (e.g. "10.0.0.5:6881") for node registration.
27    #[serde(default)]
28    address: Option<String>,
29}
30
31/// Per-node sender so the master can push messages to a connected agent.
32pub type AgentSender = mpsc::Sender<MasterMessage>;
33
34/// Handle the WebSocket upgrade request.
35///
36/// Authenticates via the `token` query param, then upgrades to a
37/// bidirectional WebSocket connection.
38pub async fn ws_agent_handler(
39    ws: WebSocketUpgrade,
40    State(state): State<Arc<AppState>>,
41    Query(query): Query<WsQuery>,
42) -> impl IntoResponse {
43    // Authenticate: check token against cluster tokens
44    let valid = state.api_tokens.iter().any(|t| t == &query.token)
45        || state
46            .cluster_config
47            .token
48            .iter()
49            .any(|t| t.value == query.token);
50
51    if !valid {
52        return (axum::http::StatusCode::UNAUTHORIZED, "invalid token").into_response();
53    }
54
55    let node_id = query.node_id;
56    let address = query.address;
57    info!("WebSocket upgrade accepted for node {node_id}");
58
59    ws.on_upgrade(move |socket| handle_agent_ws(socket, state, node_id, address))
60        .into_response()
61}
62
63/// Main WebSocket loop for a connected agent.
64async fn handle_agent_ws(
65    socket: WebSocket,
66    state: Arc<AppState>,
67    node_id: u64,
68    agent_address: Option<String>,
69) {
70    let (mut ws_tx, mut ws_rx) = socket.split();
71
72    // Channel for master → agent messages (deploy commands, log requests, etc.)
73    let (tx, mut rx) = mpsc::channel::<MasterMessage>(64);
74
75    // Register this agent's sender so other parts of the system can push messages.
76    {
77        let mut senders = state.ws_agents.write().await;
78        senders.insert(node_id, tx.clone());
79    }
80
81    info!("Agent {node_id} connected via WebSocket");
82
83    // Register/update the node in registered_nodes so the reconciler's
84    // find_target_node() can match placement constraints to this agent.
85    {
86        let addr = agent_address.unwrap_or_else(|| format!("ws-agent-{node_id}"));
87        let mut nodes = state.registered_nodes.write().await;
88        let node = nodes
89            .entry(node_id)
90            .or_insert_with(|| crate::state::RegisteredNode {
91                node_id,
92                address: addr.clone(),
93                labels: std::collections::HashMap::new(),
94                last_heartbeat: chrono::Utc::now(),
95                drain: false,
96                cpu_percent: 0.0,
97                memory_bytes: 0,
98                memory_total: 0,
99                disk_used: 0,
100                disk_total: 0,
101                net_rx: 0,
102                net_tx: 0,
103            });
104        node.last_heartbeat = chrono::Utc::now();
105        node.address = addr;
106        info!("Node {node_id} registered at {}", node.address);
107    }
108
109    // Send initial Ack
110    let ack = MasterMessage::Ack { node_id };
111    if let Ok(json) = serde_json::to_string(&ack) {
112        let _ = ws_tx.send(Message::Text(json.into())).await;
113    }
114
115    // Drain any pending commands that were queued before the WS connected.
116    drain_pending_commands(&state, node_id, &tx).await;
117
118    // Send Reconcile with all services expected on this node so the agent
119    // can self-heal after a restart (fixes #21: stale remote state).
120    send_reconcile(&state, node_id, &tx).await;
121
122    // Spawn task to forward master→agent messages from the channel to the WS.
123    let send_task = tokio::spawn(async move {
124        while let Some(msg) = rx.recv().await {
125            let json = match serde_json::to_string(&msg) {
126                Ok(j) => j,
127                Err(e) => {
128                    error!("Failed to serialize MasterMessage: {e}");
129                    continue;
130                }
131            };
132            if ws_tx.send(Message::Text(json.into())).await.is_err() {
133                break; // connection closed
134            }
135        }
136    });
137
138    // Process incoming agent messages.
139    while let Some(Ok(msg)) = ws_rx.next().await {
140        match msg {
141            Message::Text(text) => {
142                if let Err(e) = handle_agent_message(&state, node_id, &text, &tx).await {
143                    warn!("Error handling agent message from {node_id}: {e}");
144                }
145            }
146            Message::Close(_) => break,
147            _ => {} // ignore binary, ping, pong (axum handles pong auto)
148        }
149    }
150
151    // Cleanup on disconnect
152    send_task.abort();
153    {
154        let mut senders = state.ws_agents.write().await;
155        senders.remove(&node_id);
156    }
157    info!("Agent {node_id} WebSocket disconnected");
158}
159
160/// Process a single message from the agent.
161async fn handle_agent_message(
162    state: &AppState,
163    node_id: u64,
164    text: &str,
165    _tx: &mpsc::Sender<MasterMessage>,
166) -> anyhow::Result<()> {
167    let msg: AgentMessage = serde_json::from_str(text)?;
168
169    match msg {
170        AgentMessage::Heartbeat {
171            node_id: reported_id,
172            workloads,
173            stats,
174        } => {
175            handle_ws_heartbeat(state, reported_id, &workloads, &stats).await;
176        }
177        AgentMessage::DomainDiscovered {
178            service_name,
179            domain,
180            host_port,
181        } => {
182            info!(
183                "Node {node_id} discovered domain {domain} for {service_name} (port {host_port})"
184            );
185            // Update the service's domain tracking on the master so the
186            // TUI and status API reflect the domain correctly.
187            let mut services = state.services.write().await;
188            if let Some(svc) = services.get_mut(&service_name) {
189                svc.config.domain = Some(domain);
190            }
191        }
192        AgentMessage::DeployResult {
193            service_name,
194            success,
195            error,
196        } => {
197            if success {
198                info!("Node {node_id}: deploy of {service_name} succeeded");
199            } else {
200                error!(
201                    "Node {node_id}: deploy of {service_name} failed: {}",
202                    error.as_deref().unwrap_or("unknown")
203                );
204            }
205        }
206        AgentMessage::LogChunk {
207            request_id,
208            service_name: _,
209            data,
210            done,
211        } => {
212            // Forward to any pending log stream listener.
213            let listeners = state.log_listeners.read().await;
214            if let Some(listener_tx) = listeners.get(&request_id) {
215                let _ = listener_tx.send((data, done)).await;
216            }
217        }
218    }
219
220    Ok(())
221}
222
223/// Process a heartbeat received over WebSocket (same logic as HTTP handler).
224async fn handle_ws_heartbeat(
225    state: &AppState,
226    node_id: u64,
227    workloads: &[orca_core::ws_types::WorkloadReport],
228    stats: &orca_core::ws_types::HostStats,
229) {
230    let mut nodes = state.registered_nodes.write().await;
231    if let Some(node) = nodes.get_mut(&node_id) {
232        node.last_heartbeat = chrono::Utc::now();
233        node.cpu_percent = stats.cpu_percent;
234        node.memory_bytes = stats.memory_bytes;
235        node.memory_total = stats.memory_total;
236        node.disk_used = stats.disk_used;
237        node.disk_total = stats.disk_total;
238        node.net_rx = stats.net_rx;
239        node.net_tx = stats.net_tx;
240    }
241    drop(nodes);
242
243    // Update service statuses + per-container stats
244    if !workloads.is_empty() {
245        let mut services = state.services.write().await;
246        let mut stats_cache = state.container_stats.write().await;
247
248        for report in workloads {
249            if let Some(svc) = services.get_mut(&report.service_name) {
250                let status = match report.status.as_str() {
251                    "running" => orca_core::types::WorkloadStatus::Running,
252                    "stopped" => orca_core::types::WorkloadStatus::Stopped,
253                    "failed" => orca_core::types::WorkloadStatus::Failed,
254                    _ => orca_core::types::WorkloadStatus::Stopped,
255                };
256                for instance in &mut svc.instances {
257                    instance.status = status;
258                }
259            }
260
261            // Cache per-container stats from remote agents
262            if report.memory_bytes > 0 || report.cpu_percent > 0.0 {
263                stats_cache.insert(
264                    report.service_name.clone(),
265                    crate::stats::ContainerStats {
266                        memory_usage: crate::stats::format_bytes(report.memory_bytes),
267                        cpu_percent: report.cpu_percent,
268                    },
269                );
270            }
271        }
272    }
273}
274
275/// Drain pending commands from the HTTP queue and send them over WS.
276async fn drain_pending_commands(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
277    let commands = {
278        let mut pending = state.pending_commands.write().await;
279        pending.remove(&node_id).unwrap_or_default()
280    };
281    for cmd in commands {
282        if let Some(action) = cmd.get("action").and_then(|a| a.as_str()) {
283            match action {
284                "deploy" => {
285                    if let Some(spec) = cmd.get("spec")
286                        && let Ok(spec) = serde_json::from_value(spec.clone())
287                    {
288                        let _ = tx
289                            .send(MasterMessage::Deploy {
290                                spec: Box::new(spec),
291                            })
292                            .await;
293                    }
294                }
295                "stop" => {
296                    if let Some(name) = cmd.get("service_name").and_then(|n| n.as_str()) {
297                        let _ = tx
298                            .send(MasterMessage::Stop {
299                                service_name: name.to_string(),
300                            })
301                            .await;
302                    }
303                }
304                _ => {}
305            }
306        }
307    }
308}
309
310/// Send the list of services expected on this agent node so it can
311/// reconcile (redeploy missing containers, stop unexpected ones).
312async fn send_reconcile(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
313    // Find the node's address/hostname for placement matching
314    let node_address = {
315        let nodes = state.registered_nodes.read().await;
316        nodes.get(&node_id).map(|n| n.address.clone())
317    };
318    let Some(node_addr) = node_address else {
319        return;
320    };
321
322    // Collect all services whose placement targets this node
323    let services = state.services.read().await;
324    let expected: Vec<Box<orca_core::types::WorkloadSpec>> = services
325        .values()
326        .filter(|svc| {
327            svc.config
328                .placement
329                .as_ref()
330                .and_then(|p| p.node.as_ref())
331                .is_some_and(|target| {
332                    node_addr.contains(target.as_str()) || target == &node_id.to_string() || {
333                        let nodes_guard =
334                            futures_util::FutureExt::now_or_never(state.registered_nodes.read());
335                        nodes_guard
336                            .and_then(|nodes| {
337                                nodes
338                                    .get(&node_id)
339                                    .and_then(|n| n.labels.get("hostname").map(|h| h == target))
340                            })
341                            .unwrap_or(false)
342                    }
343                })
344        })
345        .filter_map(|svc| {
346            crate::routes::service_config_to_spec(&svc.config)
347                .ok()
348                .map(Box::new)
349        })
350        .collect();
351
352    if expected.is_empty() {
353        return;
354    }
355
356    info!(
357        "Sending Reconcile to node {node_id} with {} expected services",
358        expected.len()
359    );
360    let _ = tx.send(MasterMessage::Reconcile { expected }).await;
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Required fields parse and the optional `address` defaults to `None`.
    #[test]
    fn ws_query_deserializes() {
        let q: WsQuery = serde_json::from_str(r#"{"token":"abc123","node_id":42}"#).unwrap();
        assert_eq!(q.token, "abc123");
        assert_eq!(q.node_id, 42);
        assert!(q.address.is_none());
    }

    /// A supplied `address` is carried through unchanged.
    #[test]
    fn ws_query_deserializes_with_address() {
        let q: WsQuery =
            serde_json::from_str(r#"{"token":"abc123","node_id":42,"address":"10.0.0.5:6881"}"#)
                .unwrap();
        assert_eq!(q.address.as_deref(), Some("10.0.0.5:6881"));
    }

    /// `token` and `node_id` are mandatory.
    #[test]
    fn ws_query_rejects_missing_required_fields() {
        assert!(serde_json::from_str::<WsQuery>(r#"{"token":"abc123"}"#).is_err());
        assert!(serde_json::from_str::<WsQuery>(r#"{"node_id":42}"#).is_err());
    }
}