Skip to main content

orca_control/ws_handler/
mod.rs

1//! WebSocket handler for agent↔master streaming communication.
2//!
3//! Agents connect to `GET /api/v1/ws/agent?token=<cluster_token>&node_id=<id>`.
4//! After the upgrade, messages flow bidirectionally using [`AgentMessage`] and
5//! [`MasterMessage`] JSON frames.
6
7mod heartbeat;
8mod placeholders;
9mod reconcile;
10
11use std::sync::Arc;
12
13use axum::extract::ws::{Message, WebSocket};
14use axum::extract::{Query, State, WebSocketUpgrade};
15use axum::response::IntoResponse;
16use futures_util::{SinkExt, StreamExt};
17use serde::Deserialize;
18use tokio::sync::mpsc;
19use tracing::{error, info, warn};
20
21use orca_core::ws_types::{AgentMessage, MasterMessage};
22
23use crate::state::AppState;
24
25use heartbeat::handle_ws_heartbeat;
26use placeholders::{remove_remote_placeholders, upsert_remote_placeholders};
27use reconcile::{drain_pending_commands, send_reconcile};
28
29/// Query params for the WS upgrade request.
30#[derive(Deserialize)]
31pub struct WsQuery {
32    token: String,
33    node_id: u64,
34    /// Agent's address (e.g. "10.0.0.5:6881") for node registration.
35    #[serde(default)]
36    address: Option<String>,
37}
38
39/// Per-node sender so the master can push messages to a connected agent.
40pub type AgentSender = mpsc::Sender<MasterMessage>;
41
42/// Handle the WebSocket upgrade request.
43///
44/// Authenticates via the `token` query param, then upgrades to a
45/// bidirectional WebSocket connection.
46pub async fn ws_agent_handler(
47    ws: WebSocketUpgrade,
48    State(state): State<Arc<AppState>>,
49    Query(query): Query<WsQuery>,
50) -> impl IntoResponse {
51    // Authenticate: check token against cluster tokens
52    let valid = state.api_tokens.iter().any(|t| t == &query.token)
53        || state
54            .cluster_config
55            .token
56            .iter()
57            .any(|t| t.value == query.token);
58
59    if !valid {
60        return (axum::http::StatusCode::UNAUTHORIZED, "invalid token").into_response();
61    }
62
63    let node_id = query.node_id;
64    let address = query.address;
65    info!("WebSocket upgrade accepted for node {node_id}");
66
67    ws.on_upgrade(move |socket| handle_agent_ws(socket, state, node_id, address))
68        .into_response()
69}
70
71/// Main WebSocket loop for a connected agent.
72async fn handle_agent_ws(
73    socket: WebSocket,
74    state: Arc<AppState>,
75    node_id: u64,
76    agent_address: Option<String>,
77) {
78    let (mut ws_tx, mut ws_rx) = socket.split();
79
80    // Channel for master → agent messages (deploy commands, log requests, etc.)
81    let (tx, mut rx) = mpsc::channel::<MasterMessage>(64);
82
83    // Register this agent's sender so other parts of the system can push messages.
84    {
85        let mut senders = state.ws_agents.write().await;
86        senders.insert(node_id, tx.clone());
87    }
88
89    info!("Agent {node_id} connected via WebSocket");
90
91    // Register/update the node in registered_nodes so the reconciler's
92    // find_target_node() can match placement constraints to this agent.
93    {
94        let addr = agent_address.unwrap_or_else(|| format!("ws-agent-{node_id}"));
95        let mut nodes = state.registered_nodes.write().await;
96        let node = nodes
97            .entry(node_id)
98            .or_insert_with(|| crate::state::RegisteredNode {
99                node_id,
100                address: addr.clone(),
101                labels: std::collections::HashMap::new(),
102                last_heartbeat: chrono::Utc::now(),
103                drain: false,
104                cpu_percent: 0.0,
105                memory_bytes: 0,
106                memory_total: 0,
107                disk_used: 0,
108                disk_total: 0,
109                net_rx: 0,
110                net_tx: 0,
111            });
112        node.last_heartbeat = chrono::Utc::now();
113        node.address = addr;
114        info!("Node {node_id} registered at {}", node.address);
115    }
116
117    // Send initial Ack
118    let ack = MasterMessage::Ack { node_id };
119    if let Ok(json) = serde_json::to_string(&ack) {
120        let _ = ws_tx.send(Message::Text(json.into())).await;
121    }
122
123    // Drain any pending commands that were queued before the WS connected.
124    drain_pending_commands(&state, node_id, &tx).await;
125
126    // Ensure a placeholder InstanceState exists for every service placed on this
127    // node so the heartbeat and DeployResult handlers have something to update,
128    // and the watchdog current < desired check never fires for remote services.
129    upsert_remote_placeholders(&state, node_id).await;
130
131    // Send Reconcile with all services expected on this node so the agent
132    // can self-heal after a restart (fixes #21: stale remote state).
133    send_reconcile(&state, node_id, &tx).await;
134
135    // Spawn task to forward master→agent messages from the channel to the WS.
136    let send_task = tokio::spawn(async move {
137        while let Some(msg) = rx.recv().await {
138            let json = match serde_json::to_string(&msg) {
139                Ok(j) => j,
140                Err(e) => {
141                    error!("Failed to serialize MasterMessage: {e}");
142                    continue;
143                }
144            };
145            if ws_tx.send(Message::Text(json.into())).await.is_err() {
146                break; // connection closed
147            }
148        }
149    });
150
151    // Periodic status sync: master pings agent every 30 s for a fresh heartbeat.
152    let ping_tx = tx.clone();
153    let ping_task = tokio::spawn(async move {
154        let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
155        interval.tick().await; // skip first tick (agent just connected and sent initial state)
156        loop {
157            interval.tick().await;
158            if ping_tx.send(MasterMessage::StatusPing).await.is_err() {
159                break;
160            }
161        }
162    });
163
164    // Process incoming agent messages.
165    while let Some(Ok(msg)) = ws_rx.next().await {
166        match msg {
167            Message::Text(text) => {
168                if let Err(e) = handle_agent_message(&state, node_id, &text, &tx).await {
169                    warn!("Error handling agent message from {node_id}: {e}");
170                }
171            }
172            Message::Close(_) => break,
173            _ => {} // ignore binary, ping, pong (axum handles pong auto)
174        }
175    }
176
177    // Cleanup on disconnect
178    send_task.abort();
179    ping_task.abort();
180    {
181        let mut senders = state.ws_agents.write().await;
182        senders.remove(&node_id);
183    }
184    remove_remote_placeholders(&state, node_id).await;
185    info!("Agent {node_id} WebSocket disconnected");
186}
187
188/// Process a single message from the agent.
189async fn handle_agent_message(
190    state: &AppState,
191    node_id: u64,
192    text: &str,
193    _tx: &mpsc::Sender<MasterMessage>,
194) -> anyhow::Result<()> {
195    let msg: AgentMessage = serde_json::from_str(text)?;
196
197    match msg {
198        AgentMessage::Heartbeat {
199            node_id: reported_id,
200            workloads,
201            stats,
202        } => {
203            handle_ws_heartbeat(state, reported_id, &workloads, &stats).await;
204        }
205        AgentMessage::DomainDiscovered {
206            service_name,
207            domain,
208            host_port,
209        } => {
210            info!(
211                "Node {node_id} discovered domain {domain} for {service_name} (port {host_port})"
212            );
213            // Update the service's domain tracking on the master so the
214            // TUI and status API reflect the domain correctly.
215            let mut services = state.services.write().await;
216            if let Some(svc) = services.get_mut(&service_name) {
217                svc.config.domain = Some(domain);
218            }
219        }
220        AgentMessage::DeployResult {
221            service_name,
222            success,
223            error,
224        } => {
225            if success {
226                info!("Node {node_id}: deploy of {service_name} succeeded");
227                let mut services = state.services.write().await;
228                if let Some(svc) = services.get_mut(&service_name) {
229                    let placeholder_id = format!("remote-{node_id}");
230                    if let Some(inst) = svc
231                        .instances
232                        .iter_mut()
233                        .find(|i| i.handle.runtime_id == placeholder_id)
234                    {
235                        inst.status = orca_core::types::WorkloadStatus::Running;
236                    }
237                }
238            } else {
239                error!(
240                    "Node {node_id}: deploy of {service_name} failed: {}",
241                    error.as_deref().unwrap_or("unknown")
242                );
243            }
244            let result = if success {
245                Ok(())
246            } else {
247                Err(error.unwrap_or_else(|| "deploy failed".to_string()))
248            };
249            if let Some(tx) = state.pending_deploys.write().await.remove(&service_name) {
250                let _ = tx.send(result);
251            }
252        }
253        AgentMessage::LogChunk {
254            request_id,
255            service_name: _,
256            data,
257            done,
258        } => {
259            // Forward to any pending log stream listener.
260            let listeners = state.log_listeners.read().await;
261            if let Some(listener_tx) = listeners.get(&request_id) {
262                let _ = listener_tx.send((data, done)).await;
263            }
264        }
265        AgentMessage::BackupResult {
266            node_id,
267            success,
268            message,
269        } => {
270            if success {
271                info!("Node {node_id}: backup complete — {message}");
272            } else {
273                error!("Node {node_id}: backup failed — {message}");
274            }
275            // Cache for the cluster-backups dashboard so it can surface the
276            // last-known failure without rescanning logs.
277            state.last_backup_results.write().await.insert(
278                node_id,
279                crate::state::LastBackupResult {
280                    success,
281                    message,
282                    recorded_at: chrono::Utc::now(),
283                },
284            );
285        }
286        AgentMessage::BackupStatusReport { request_id, data } => {
287            if let Some(tx) = state.backup_listeners.read().await.get(&request_id) {
288                let _ = tx.send(data).await;
289            }
290        }
291        AgentMessage::NetworkStatusReport { request_id, data } => {
292            if let Some(tx) = state.network_listeners.read().await.get(&request_id) {
293                let _ = tx.send(data).await;
294            }
295        }
296        AgentMessage::ExecOutput { session_id, data } => {
297            use base64::Engine as _;
298            let bytes = match base64::engine::general_purpose::STANDARD.decode(&data) {
299                Ok(b) => b,
300                Err(e) => {
301                    tracing::warn!("exec: bad base64 output for session {session_id}: {e}");
302                    return Ok(());
303                }
304            };
305            let sessions = state.exec_sessions.read().await;
306            if let Some(tx) = sessions.get(&session_id) {
307                let _ = tx.send(bytes).await;
308            }
309        }
310        AgentMessage::ExecDone {
311            session_id,
312            exit_code,
313        } => {
314            info!("Node {node_id}: exec session {session_id} done (exit {exit_code})");
315            state.exec_sessions.write().await.remove(&session_id);
316        }
317    }
318
319    Ok(())
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload carrying only the required fields (`token`, `node_id`)
    /// must deserialize into [`WsQuery`].
    #[test]
    fn ws_query_deserializes() {
        let raw = r#"{"token":"abc123","node_id":42}"#;
        let parsed: WsQuery = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.token, "abc123");
        assert_eq!(parsed.node_id, 42);
    }
}