1use std::sync::Arc;
8
9use axum::extract::ws::{Message, WebSocket};
10use axum::extract::{Query, State, WebSocketUpgrade};
11use axum::response::IntoResponse;
12use futures_util::{SinkExt, StreamExt};
13use serde::Deserialize;
14use tokio::sync::mpsc;
15use tracing::{error, info, warn};
16
17use orca_core::ws_types::{AgentMessage, MasterMessage};
18
19use crate::state::AppState;
20
21#[derive(Deserialize)]
23pub struct WsQuery {
24 token: String,
25 node_id: u64,
26 #[serde(default)]
28 address: Option<String>,
29}
30
31pub type AgentSender = mpsc::Sender<MasterMessage>;
33
34pub async fn ws_agent_handler(
39 ws: WebSocketUpgrade,
40 State(state): State<Arc<AppState>>,
41 Query(query): Query<WsQuery>,
42) -> impl IntoResponse {
43 let valid = state.api_tokens.iter().any(|t| t == &query.token)
45 || state
46 .cluster_config
47 .token
48 .iter()
49 .any(|t| t.value == query.token);
50
51 if !valid {
52 return (axum::http::StatusCode::UNAUTHORIZED, "invalid token").into_response();
53 }
54
55 let node_id = query.node_id;
56 let address = query.address;
57 info!("WebSocket upgrade accepted for node {node_id}");
58
59 ws.on_upgrade(move |socket| handle_agent_ws(socket, state, node_id, address))
60 .into_response()
61}
62
63async fn handle_agent_ws(
65 socket: WebSocket,
66 state: Arc<AppState>,
67 node_id: u64,
68 agent_address: Option<String>,
69) {
70 let (mut ws_tx, mut ws_rx) = socket.split();
71
72 let (tx, mut rx) = mpsc::channel::<MasterMessage>(64);
74
75 {
77 let mut senders = state.ws_agents.write().await;
78 senders.insert(node_id, tx.clone());
79 }
80
81 info!("Agent {node_id} connected via WebSocket");
82
83 {
86 let addr = agent_address.unwrap_or_else(|| format!("ws-agent-{node_id}"));
87 let mut nodes = state.registered_nodes.write().await;
88 let node = nodes
89 .entry(node_id)
90 .or_insert_with(|| crate::state::RegisteredNode {
91 node_id,
92 address: addr.clone(),
93 labels: std::collections::HashMap::new(),
94 last_heartbeat: chrono::Utc::now(),
95 drain: false,
96 cpu_percent: 0.0,
97 memory_bytes: 0,
98 memory_total: 0,
99 disk_used: 0,
100 disk_total: 0,
101 net_rx: 0,
102 net_tx: 0,
103 });
104 node.last_heartbeat = chrono::Utc::now();
105 node.address = addr;
106 info!("Node {node_id} registered at {}", node.address);
107 }
108
109 let ack = MasterMessage::Ack { node_id };
111 if let Ok(json) = serde_json::to_string(&ack) {
112 let _ = ws_tx.send(Message::Text(json.into())).await;
113 }
114
115 drain_pending_commands(&state, node_id, &tx).await;
117
118 send_reconcile(&state, node_id, &tx).await;
121
122 let send_task = tokio::spawn(async move {
124 while let Some(msg) = rx.recv().await {
125 let json = match serde_json::to_string(&msg) {
126 Ok(j) => j,
127 Err(e) => {
128 error!("Failed to serialize MasterMessage: {e}");
129 continue;
130 }
131 };
132 if ws_tx.send(Message::Text(json.into())).await.is_err() {
133 break; }
135 }
136 });
137
138 while let Some(Ok(msg)) = ws_rx.next().await {
140 match msg {
141 Message::Text(text) => {
142 if let Err(e) = handle_agent_message(&state, node_id, &text, &tx).await {
143 warn!("Error handling agent message from {node_id}: {e}");
144 }
145 }
146 Message::Close(_) => break,
147 _ => {} }
149 }
150
151 send_task.abort();
153 {
154 let mut senders = state.ws_agents.write().await;
155 senders.remove(&node_id);
156 }
157 info!("Agent {node_id} WebSocket disconnected");
158}
159
160async fn handle_agent_message(
162 state: &AppState,
163 node_id: u64,
164 text: &str,
165 _tx: &mpsc::Sender<MasterMessage>,
166) -> anyhow::Result<()> {
167 let msg: AgentMessage = serde_json::from_str(text)?;
168
169 match msg {
170 AgentMessage::Heartbeat {
171 node_id: reported_id,
172 workloads,
173 stats,
174 } => {
175 handle_ws_heartbeat(state, reported_id, &workloads, &stats).await;
176 }
177 AgentMessage::DomainDiscovered {
178 service_name,
179 domain,
180 host_port,
181 } => {
182 info!(
183 "Node {node_id} discovered domain {domain} for {service_name} (port {host_port})"
184 );
185 let mut services = state.services.write().await;
188 if let Some(svc) = services.get_mut(&service_name) {
189 svc.config.domain = Some(domain);
190 }
191 }
192 AgentMessage::DeployResult {
193 service_name,
194 success,
195 error,
196 } => {
197 if success {
198 info!("Node {node_id}: deploy of {service_name} succeeded");
199 } else {
200 error!(
201 "Node {node_id}: deploy of {service_name} failed: {}",
202 error.as_deref().unwrap_or("unknown")
203 );
204 }
205 }
206 AgentMessage::LogChunk {
207 request_id,
208 service_name: _,
209 data,
210 done,
211 } => {
212 let listeners = state.log_listeners.read().await;
214 if let Some(listener_tx) = listeners.get(&request_id) {
215 let _ = listener_tx.send((data, done)).await;
216 }
217 }
218 }
219
220 Ok(())
221}
222
223async fn handle_ws_heartbeat(
225 state: &AppState,
226 node_id: u64,
227 workloads: &[orca_core::ws_types::WorkloadReport],
228 stats: &orca_core::ws_types::HostStats,
229) {
230 let mut nodes = state.registered_nodes.write().await;
231 if let Some(node) = nodes.get_mut(&node_id) {
232 node.last_heartbeat = chrono::Utc::now();
233 node.cpu_percent = stats.cpu_percent;
234 node.memory_bytes = stats.memory_bytes;
235 node.memory_total = stats.memory_total;
236 node.disk_used = stats.disk_used;
237 node.disk_total = stats.disk_total;
238 node.net_rx = stats.net_rx;
239 node.net_tx = stats.net_tx;
240 }
241 drop(nodes);
242
243 if !workloads.is_empty() {
245 let mut services = state.services.write().await;
246 let mut stats_cache = state.container_stats.write().await;
247
248 for report in workloads {
249 if let Some(svc) = services.get_mut(&report.service_name) {
250 let status = match report.status.as_str() {
251 "running" => orca_core::types::WorkloadStatus::Running,
252 "stopped" => orca_core::types::WorkloadStatus::Stopped,
253 "failed" => orca_core::types::WorkloadStatus::Failed,
254 _ => orca_core::types::WorkloadStatus::Stopped,
255 };
256 for instance in &mut svc.instances {
257 instance.status = status;
258 }
259 }
260
261 if report.memory_bytes > 0 || report.cpu_percent > 0.0 {
263 stats_cache.insert(
264 report.service_name.clone(),
265 crate::stats::ContainerStats {
266 memory_usage: crate::stats::format_bytes(report.memory_bytes),
267 cpu_percent: report.cpu_percent,
268 },
269 );
270 }
271 }
272 }
273}
274
275async fn drain_pending_commands(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
277 let commands = {
278 let mut pending = state.pending_commands.write().await;
279 pending.remove(&node_id).unwrap_or_default()
280 };
281 for cmd in commands {
282 if let Some(action) = cmd.get("action").and_then(|a| a.as_str()) {
283 match action {
284 "deploy" => {
285 if let Some(spec) = cmd.get("spec")
286 && let Ok(spec) = serde_json::from_value(spec.clone())
287 {
288 let _ = tx
289 .send(MasterMessage::Deploy {
290 spec: Box::new(spec),
291 })
292 .await;
293 }
294 }
295 "stop" => {
296 if let Some(name) = cmd.get("service_name").and_then(|n| n.as_str()) {
297 let _ = tx
298 .send(MasterMessage::Stop {
299 service_name: name.to_string(),
300 })
301 .await;
302 }
303 }
304 _ => {}
305 }
306 }
307 }
308}
309
310async fn send_reconcile(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
313 let node_address = {
315 let nodes = state.registered_nodes.read().await;
316 nodes.get(&node_id).map(|n| n.address.clone())
317 };
318 let Some(node_addr) = node_address else {
319 return;
320 };
321
322 let services = state.services.read().await;
324 let expected: Vec<Box<orca_core::types::WorkloadSpec>> = services
325 .values()
326 .filter(|svc| {
327 svc.config
328 .placement
329 .as_ref()
330 .and_then(|p| p.node.as_ref())
331 .is_some_and(|target| {
332 node_addr.contains(target.as_str()) || target == &node_id.to_string() || {
333 let nodes_guard =
334 futures_util::FutureExt::now_or_never(state.registered_nodes.read());
335 nodes_guard
336 .and_then(|nodes| {
337 nodes
338 .get(&node_id)
339 .and_then(|n| n.labels.get("hostname").map(|h| h == target))
340 })
341 .unwrap_or(false)
342 }
343 })
344 })
345 .filter_map(|svc| {
346 crate::routes::service_config_to_spec(&svc.config)
347 .ok()
348 .map(Box::new)
349 })
350 .collect();
351
352 if expected.is_empty() {
353 return;
354 }
355
356 info!(
357 "Sending Reconcile to node {node_id} with {} expected services",
358 expected.len()
359 );
360 let _ = tx.send(MasterMessage::Reconcile { expected }).await;
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Required fields parse; the optional `address` defaults to `None` when omitted.
    #[test]
    fn ws_query_deserializes() {
        let q: WsQuery = serde_json::from_str(r#"{"token":"abc123","node_id":42}"#).unwrap();
        assert_eq!(q.token, "abc123");
        assert_eq!(q.node_id, 42);
        assert_eq!(q.address, None);
    }

    /// An advertised `address` is carried through verbatim.
    #[test]
    fn ws_query_deserializes_address() {
        let q: WsQuery = serde_json::from_str(
            r#"{"token":"abc123","node_id":7,"address":"10.0.0.5:6880"}"#,
        )
        .unwrap();
        assert_eq!(q.node_id, 7);
        assert_eq!(q.address.as_deref(), Some("10.0.0.5:6880"));
    }
}