orca_control/ws_handler/
mod.rs1mod heartbeat;
8mod placeholders;
9mod reconcile;
10
11use std::sync::Arc;
12
13use axum::extract::ws::{Message, WebSocket};
14use axum::extract::{Query, State, WebSocketUpgrade};
15use axum::response::IntoResponse;
16use futures_util::{SinkExt, StreamExt};
17use serde::Deserialize;
18use tokio::sync::mpsc;
19use tracing::{error, info, warn};
20
21use orca_core::ws_types::{AgentMessage, MasterMessage};
22
23use crate::state::AppState;
24
25use heartbeat::handle_ws_heartbeat;
26use placeholders::{remove_remote_placeholders, upsert_remote_placeholders};
27use reconcile::{drain_pending_commands, send_reconcile};
28
29#[derive(Deserialize)]
31pub struct WsQuery {
32 token: String,
33 node_id: u64,
34 #[serde(default)]
36 address: Option<String>,
37}
38
39pub type AgentSender = mpsc::Sender<MasterMessage>;
41
42pub async fn ws_agent_handler(
47 ws: WebSocketUpgrade,
48 State(state): State<Arc<AppState>>,
49 Query(query): Query<WsQuery>,
50) -> impl IntoResponse {
51 let valid = state.api_tokens.iter().any(|t| t == &query.token)
53 || state
54 .cluster_config
55 .token
56 .iter()
57 .any(|t| t.value == query.token);
58
59 if !valid {
60 return (axum::http::StatusCode::UNAUTHORIZED, "invalid token").into_response();
61 }
62
63 let node_id = query.node_id;
64 let address = query.address;
65 info!("WebSocket upgrade accepted for node {node_id}");
66
67 ws.on_upgrade(move |socket| handle_agent_ws(socket, state, node_id, address))
68 .into_response()
69}
70
71async fn handle_agent_ws(
73 socket: WebSocket,
74 state: Arc<AppState>,
75 node_id: u64,
76 agent_address: Option<String>,
77) {
78 let (mut ws_tx, mut ws_rx) = socket.split();
79
80 let (tx, mut rx) = mpsc::channel::<MasterMessage>(64);
82
83 {
85 let mut senders = state.ws_agents.write().await;
86 senders.insert(node_id, tx.clone());
87 }
88
89 info!("Agent {node_id} connected via WebSocket");
90
91 {
94 let addr = agent_address.unwrap_or_else(|| format!("ws-agent-{node_id}"));
95 let mut nodes = state.registered_nodes.write().await;
96 let node = nodes
97 .entry(node_id)
98 .or_insert_with(|| crate::state::RegisteredNode {
99 node_id,
100 address: addr.clone(),
101 labels: std::collections::HashMap::new(),
102 last_heartbeat: chrono::Utc::now(),
103 drain: false,
104 cpu_percent: 0.0,
105 memory_bytes: 0,
106 memory_total: 0,
107 disk_used: 0,
108 disk_total: 0,
109 net_rx: 0,
110 net_tx: 0,
111 });
112 node.last_heartbeat = chrono::Utc::now();
113 node.address = addr;
114 info!("Node {node_id} registered at {}", node.address);
115 }
116
117 let ack = MasterMessage::Ack { node_id };
119 if let Ok(json) = serde_json::to_string(&ack) {
120 let _ = ws_tx.send(Message::Text(json.into())).await;
121 }
122
123 drain_pending_commands(&state, node_id, &tx).await;
125
126 upsert_remote_placeholders(&state, node_id).await;
130
131 send_reconcile(&state, node_id, &tx).await;
134
135 let send_task = tokio::spawn(async move {
137 while let Some(msg) = rx.recv().await {
138 let json = match serde_json::to_string(&msg) {
139 Ok(j) => j,
140 Err(e) => {
141 error!("Failed to serialize MasterMessage: {e}");
142 continue;
143 }
144 };
145 if ws_tx.send(Message::Text(json.into())).await.is_err() {
146 break; }
148 }
149 });
150
151 let ping_tx = tx.clone();
153 let ping_task = tokio::spawn(async move {
154 let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
155 interval.tick().await; loop {
157 interval.tick().await;
158 if ping_tx.send(MasterMessage::StatusPing).await.is_err() {
159 break;
160 }
161 }
162 });
163
164 while let Some(Ok(msg)) = ws_rx.next().await {
166 match msg {
167 Message::Text(text) => {
168 if let Err(e) = handle_agent_message(&state, node_id, &text, &tx).await {
169 warn!("Error handling agent message from {node_id}: {e}");
170 }
171 }
172 Message::Close(_) => break,
173 _ => {} }
175 }
176
177 send_task.abort();
179 ping_task.abort();
180 {
181 let mut senders = state.ws_agents.write().await;
182 senders.remove(&node_id);
183 }
184 remove_remote_placeholders(&state, node_id).await;
185 info!("Agent {node_id} WebSocket disconnected");
186}
187
188async fn handle_agent_message(
190 state: &AppState,
191 node_id: u64,
192 text: &str,
193 _tx: &mpsc::Sender<MasterMessage>,
194) -> anyhow::Result<()> {
195 let msg: AgentMessage = serde_json::from_str(text)?;
196
197 match msg {
198 AgentMessage::Heartbeat {
199 node_id: reported_id,
200 workloads,
201 stats,
202 } => {
203 handle_ws_heartbeat(state, reported_id, &workloads, &stats).await;
204 }
205 AgentMessage::DomainDiscovered {
206 service_name,
207 domain,
208 host_port,
209 } => {
210 info!(
211 "Node {node_id} discovered domain {domain} for {service_name} (port {host_port})"
212 );
213 let mut services = state.services.write().await;
216 if let Some(svc) = services.get_mut(&service_name) {
217 svc.config.domain = Some(domain);
218 }
219 }
220 AgentMessage::DeployResult {
221 service_name,
222 success,
223 error,
224 } => {
225 if success {
226 info!("Node {node_id}: deploy of {service_name} succeeded");
227 let mut services = state.services.write().await;
228 if let Some(svc) = services.get_mut(&service_name) {
229 let placeholder_id = format!("remote-{node_id}");
230 if let Some(inst) = svc
231 .instances
232 .iter_mut()
233 .find(|i| i.handle.runtime_id == placeholder_id)
234 {
235 inst.status = orca_core::types::WorkloadStatus::Running;
236 }
237 }
238 } else {
239 error!(
240 "Node {node_id}: deploy of {service_name} failed: {}",
241 error.as_deref().unwrap_or("unknown")
242 );
243 }
244 let result = if success {
245 Ok(())
246 } else {
247 Err(error.unwrap_or_else(|| "deploy failed".to_string()))
248 };
249 if let Some(tx) = state.pending_deploys.write().await.remove(&service_name) {
250 let _ = tx.send(result);
251 }
252 }
253 AgentMessage::LogChunk {
254 request_id,
255 service_name: _,
256 data,
257 done,
258 } => {
259 let listeners = state.log_listeners.read().await;
261 if let Some(listener_tx) = listeners.get(&request_id) {
262 let _ = listener_tx.send((data, done)).await;
263 }
264 }
265 AgentMessage::BackupResult {
266 node_id,
267 success,
268 message,
269 } => {
270 if success {
271 info!("Node {node_id}: backup complete — {message}");
272 } else {
273 error!("Node {node_id}: backup failed — {message}");
274 }
275 }
276 AgentMessage::ExecOutput { session_id, data } => {
277 use base64::Engine as _;
278 let bytes = match base64::engine::general_purpose::STANDARD.decode(&data) {
279 Ok(b) => b,
280 Err(e) => {
281 tracing::warn!("exec: bad base64 output for session {session_id}: {e}");
282 return Ok(());
283 }
284 };
285 let sessions = state.exec_sessions.read().await;
286 if let Some(tx) = sessions.get(&session_id) {
287 let _ = tx.send(bytes).await;
288 }
289 }
290 AgentMessage::ExecDone {
291 session_id,
292 exit_code,
293 } => {
294 info!("Node {node_id}: exec session {session_id} done (exit {exit_code})");
295 state.exec_sessions.write().await.remove(&session_id);
296 }
297 }
298
299 Ok(())
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    // `WsQuery` arrives as a query string in production; JSON is used here as a
    // convenient serde front-end to exercise the same derived `Deserialize` impl.

    #[test]
    fn ws_query_deserializes() {
        let q: WsQuery = serde_json::from_str(r#"{"token":"abc123","node_id":42}"#).unwrap();
        assert_eq!(q.token, "abc123");
        assert_eq!(q.node_id, 42);
        assert_eq!(q.address, None);
    }

    #[test]
    fn ws_query_deserializes_with_address() {
        let q: WsQuery =
            serde_json::from_str(r#"{"token":"t","node_id":7,"address":"10.0.0.5:9000"}"#)
                .unwrap();
        assert_eq!(q.node_id, 7);
        assert_eq!(q.address.as_deref(), Some("10.0.0.5:9000"));
    }

    #[test]
    fn ws_query_requires_node_id() {
        assert!(serde_json::from_str::<WsQuery>(r#"{"token":"t"}"#).is_err());
    }
}
312}