orca_control/
ws_handler.rs1use std::sync::Arc;
8
9use axum::extract::ws::{Message, WebSocket};
10use axum::extract::{Query, State, WebSocketUpgrade};
11use axum::response::IntoResponse;
12use futures_util::{SinkExt, StreamExt};
13use serde::Deserialize;
14use tokio::sync::mpsc;
15use tracing::{error, info, warn};
16
17use orca_core::ws_types::{AgentMessage, MasterMessage};
18
19use crate::state::AppState;
20
21#[derive(Deserialize)]
23pub struct WsQuery {
24 token: String,
25 node_id: u64,
26}
27
28pub type AgentSender = mpsc::Sender<MasterMessage>;
30
31pub async fn ws_agent_handler(
36 ws: WebSocketUpgrade,
37 State(state): State<Arc<AppState>>,
38 Query(query): Query<WsQuery>,
39) -> impl IntoResponse {
40 let valid = state.api_tokens.iter().any(|t| t == &query.token)
42 || state
43 .cluster_config
44 .token
45 .iter()
46 .any(|t| t.value == query.token);
47
48 if !valid {
49 return (axum::http::StatusCode::UNAUTHORIZED, "invalid token").into_response();
50 }
51
52 let node_id = query.node_id;
53 info!("WebSocket upgrade accepted for node {node_id}");
54
55 ws.on_upgrade(move |socket| handle_agent_ws(socket, state, node_id))
56 .into_response()
57}
58
59async fn handle_agent_ws(socket: WebSocket, state: Arc<AppState>, node_id: u64) {
61 let (mut ws_tx, mut ws_rx) = socket.split();
62
63 let (tx, mut rx) = mpsc::channel::<MasterMessage>(64);
65
66 {
68 let mut senders = state.ws_agents.write().await;
69 senders.insert(node_id, tx.clone());
70 }
71
72 info!("Agent {node_id} connected via WebSocket");
73
74 let ack = MasterMessage::Ack { node_id };
76 if let Ok(json) = serde_json::to_string(&ack) {
77 let _ = ws_tx.send(Message::Text(json.into())).await;
78 }
79
80 drain_pending_commands(&state, node_id, &tx).await;
82
83 send_reconcile(&state, node_id, &tx).await;
86
87 let send_task = tokio::spawn(async move {
89 while let Some(msg) = rx.recv().await {
90 let json = match serde_json::to_string(&msg) {
91 Ok(j) => j,
92 Err(e) => {
93 error!("Failed to serialize MasterMessage: {e}");
94 continue;
95 }
96 };
97 if ws_tx.send(Message::Text(json.into())).await.is_err() {
98 break; }
100 }
101 });
102
103 while let Some(Ok(msg)) = ws_rx.next().await {
105 match msg {
106 Message::Text(text) => {
107 if let Err(e) = handle_agent_message(&state, node_id, &text, &tx).await {
108 warn!("Error handling agent message from {node_id}: {e}");
109 }
110 }
111 Message::Close(_) => break,
112 _ => {} }
114 }
115
116 send_task.abort();
118 {
119 let mut senders = state.ws_agents.write().await;
120 senders.remove(&node_id);
121 }
122 info!("Agent {node_id} WebSocket disconnected");
123}
124
125async fn handle_agent_message(
127 state: &AppState,
128 node_id: u64,
129 text: &str,
130 _tx: &mpsc::Sender<MasterMessage>,
131) -> anyhow::Result<()> {
132 let msg: AgentMessage = serde_json::from_str(text)?;
133
134 match msg {
135 AgentMessage::Heartbeat {
136 node_id: reported_id,
137 workloads,
138 stats,
139 } => {
140 handle_ws_heartbeat(state, reported_id, &workloads, &stats).await;
141 }
142 AgentMessage::DomainDiscovered {
143 service_name,
144 domain,
145 host_port,
146 } => {
147 info!(
148 "Node {node_id} discovered domain {domain} for {service_name} (port {host_port})"
149 );
150 let mut services = state.services.write().await;
153 if let Some(svc) = services.get_mut(&service_name) {
154 svc.config.domain = Some(domain);
155 }
156 }
157 AgentMessage::DeployResult {
158 service_name,
159 success,
160 error,
161 } => {
162 if success {
163 info!("Node {node_id}: deploy of {service_name} succeeded");
164 } else {
165 error!(
166 "Node {node_id}: deploy of {service_name} failed: {}",
167 error.as_deref().unwrap_or("unknown")
168 );
169 }
170 }
171 AgentMessage::LogChunk {
172 request_id,
173 service_name: _,
174 data,
175 done,
176 } => {
177 let listeners = state.log_listeners.read().await;
179 if let Some(listener_tx) = listeners.get(&request_id) {
180 let _ = listener_tx.send((data, done)).await;
181 }
182 }
183 }
184
185 Ok(())
186}
187
188async fn handle_ws_heartbeat(
190 state: &AppState,
191 node_id: u64,
192 workloads: &[orca_core::ws_types::WorkloadReport],
193 stats: &orca_core::ws_types::HostStats,
194) {
195 let mut nodes = state.registered_nodes.write().await;
196 if let Some(node) = nodes.get_mut(&node_id) {
197 node.last_heartbeat = chrono::Utc::now();
198 node.cpu_percent = stats.cpu_percent;
199 node.memory_bytes = stats.memory_bytes;
200 node.memory_total = stats.memory_total;
201 node.disk_used = stats.disk_used;
202 node.disk_total = stats.disk_total;
203 node.net_rx = stats.net_rx;
204 node.net_tx = stats.net_tx;
205 }
206 drop(nodes);
207
208 if !workloads.is_empty() {
210 let mut services = state.services.write().await;
211 for report in workloads {
212 if let Some(svc) = services.get_mut(&report.service_name) {
213 let status = match report.status.as_str() {
214 "running" => orca_core::types::WorkloadStatus::Running,
215 "stopped" => orca_core::types::WorkloadStatus::Stopped,
216 "failed" => orca_core::types::WorkloadStatus::Failed,
217 _ => orca_core::types::WorkloadStatus::Stopped,
218 };
219 for instance in &mut svc.instances {
220 instance.status = status;
221 }
222 }
223 }
224 }
225}
226
227async fn drain_pending_commands(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
229 let commands = {
230 let mut pending = state.pending_commands.write().await;
231 pending.remove(&node_id).unwrap_or_default()
232 };
233 for cmd in commands {
234 if let Some(action) = cmd.get("action").and_then(|a| a.as_str()) {
235 match action {
236 "deploy" => {
237 if let Some(spec) = cmd.get("spec")
238 && let Ok(spec) = serde_json::from_value(spec.clone())
239 {
240 let _ = tx
241 .send(MasterMessage::Deploy {
242 spec: Box::new(spec),
243 })
244 .await;
245 }
246 }
247 "stop" => {
248 if let Some(name) = cmd.get("service_name").and_then(|n| n.as_str()) {
249 let _ = tx
250 .send(MasterMessage::Stop {
251 service_name: name.to_string(),
252 })
253 .await;
254 }
255 }
256 _ => {}
257 }
258 }
259 }
260}
261
262async fn send_reconcile(state: &AppState, node_id: u64, tx: &mpsc::Sender<MasterMessage>) {
265 let node_address = {
267 let nodes = state.registered_nodes.read().await;
268 nodes.get(&node_id).map(|n| n.address.clone())
269 };
270 let Some(node_addr) = node_address else {
271 return;
272 };
273
274 let services = state.services.read().await;
276 let expected: Vec<Box<orca_core::types::WorkloadSpec>> = services
277 .values()
278 .filter(|svc| {
279 svc.config
280 .placement
281 .as_ref()
282 .and_then(|p| p.node.as_ref())
283 .is_some_and(|target| {
284 node_addr.contains(target.as_str()) || target == &node_id.to_string() || {
285 let nodes_guard =
286 futures_util::FutureExt::now_or_never(state.registered_nodes.read());
287 nodes_guard
288 .and_then(|nodes| {
289 nodes
290 .get(&node_id)
291 .and_then(|n| n.labels.get("hostname").map(|h| h == target))
292 })
293 .unwrap_or(false)
294 }
295 })
296 })
297 .filter_map(|svc| {
298 crate::routes::service_config_to_spec(&svc.config)
299 .ok()
300 .map(Box::new)
301 })
302 .collect();
303
304 if expected.is_empty() {
305 return;
306 }
307
308 info!(
309 "Sending Reconcile to node {node_id} with {} expected services",
310 expected.len()
311 );
312 let _ = tx.send(MasterMessage::Reconcile { expected }).await;
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Both fields are read by name from the incoming parameters.
    #[test]
    fn ws_query_deserializes() {
        let q: WsQuery = serde_json::from_str(r#"{"token":"abc123","node_id":42}"#).unwrap();
        assert_eq!(q.token, "abc123");
        assert_eq!(q.node_id, 42);
    }

    /// Neither field is optional: a request missing either one must be rejected
    /// at deserialization time rather than reaching the handler.
    #[test]
    fn ws_query_requires_both_fields() {
        assert!(serde_json::from_str::<WsQuery>(r#"{"token":"abc123"}"#).is_err());
        assert!(serde_json::from_str::<WsQuery>(r#"{"node_id":42}"#).is_err());
    }
}