mlua_swarm_server/operator_ws/
login.rs1use axum::{
38 extract::{
39 ws::{Message, WebSocket, WebSocketUpgrade},
40 Path, State,
41 },
42 http::{HeaderMap, StatusCode},
43 response::{IntoResponse, Response},
44 Json,
45};
46use futures_util::{sink::SinkExt, stream::StreamExt};
47use mlua_swarm::{Operator, SeniorBridge, SpawnHook};
48use serde::{Deserialize, Serialize};
49use serde_json::json;
50use std::sync::Arc;
51use tokio::sync::{mpsc, Mutex};
52
53use super::protocol::{ClientMsg, PendingReply, ServerMsg};
54use super::session::WSOperatorSession;
55use crate::AppState;
56
57pub struct OperatorSessionEntry {
63 pub sid: String,
65 pub token: String,
67 pub roles: Vec<String>,
69 pub ws_session: Mutex<Option<Arc<WSOperatorSession>>>,
71}
72
73#[derive(Debug, Deserialize, Default)]
77pub struct OperatorsCreateReq {
78 #[serde(default)]
80 pub roles: Vec<String>,
81}
82
83#[derive(Debug, Serialize)]
85pub struct OperatorsCreateResp {
86 pub sid: String,
88 pub token: String,
90 pub roles: Vec<String>,
92}
93
94pub async fn operators_create(
103 State(state): State<AppState>,
104 Json(req): Json<OperatorsCreateReq>,
105) -> Response {
106 let roles = req.roles;
107 let sid = format!("op-{}", uuid::Uuid::new_v4());
108 let token = mlua_swarm::types::secure_hex(5);
109
110 {
111 let mut map = state.roles_to_sid.lock().await;
112 let conflicts: Vec<String> = roles
113 .iter()
114 .filter(|r| map.contains_key(r.as_str()))
115 .cloned()
116 .collect();
117 if !conflicts.is_empty() {
118 return (
119 StatusCode::CONFLICT,
120 Json(json!({"error": "roles conflict", "conflicts": conflicts})),
121 )
122 .into_response();
123 }
124 for r in &roles {
125 map.insert(r.clone(), sid.clone());
126 }
127 }
128
129 let entry = Arc::new(OperatorSessionEntry {
130 sid: sid.clone(),
131 token: token.clone(),
132 roles: roles.clone(),
133 ws_session: Mutex::new(None),
134 });
135 state
136 .operator_sessions
137 .lock()
138 .await
139 .insert(sid.clone(), entry);
140
141 (
142 StatusCode::OK,
143 Json(OperatorsCreateResp { sid, token, roles }),
144 )
145 .into_response()
146}
147
148fn extract_bearer_token_required(headers: &HeaderMap) -> Result<String, Box<Response>> {
154 let token = headers
155 .get(axum::http::header::AUTHORIZATION)
156 .and_then(|v| v.to_str().ok())
157 .and_then(|s| s.strip_prefix("Bearer "))
158 .map(|s| s.trim().to_string())
159 .filter(|s| !s.is_empty());
160 token.ok_or_else(|| {
161 Box::new((StatusCode::UNAUTHORIZED, "missing or empty Bearer token").into_response())
162 })
163}
164
165pub async fn operators_ws_connect(
171 State(state): State<AppState>,
172 Path(sid): Path<String>,
173 headers: HeaderMap,
174 ws: WebSocketUpgrade,
175) -> Response {
176 let bearer = match extract_bearer_token_required(&headers) {
177 Ok(t) => t,
178 Err(resp) => return *resp,
179 };
180
181 let entry = {
182 let map = state.operator_sessions.lock().await;
183 map.get(&sid).cloned()
184 };
185 let entry = match entry {
186 Some(e) => e,
187 None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
188 };
189 if entry.token != bearer {
190 return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
191 }
192
193 ws.on_upgrade(move |socket| handle_operator_socket(socket, state, entry))
194}
195
196async fn handle_operator_socket(
200 socket: WebSocket,
201 state: AppState,
202 entry: Arc<OperatorSessionEntry>,
203) {
204 let (tx, mut rx) = mpsc::unbounded_channel::<ServerMsg>();
205
206 let existing_ws = entry.ws_session.lock().await.clone();
207 let session = match existing_ws {
208 Some(ws_session) => {
209 ws_session.replace_tx(tx.clone()).await;
211 ws_session
212 }
213 None => {
214 let ws_session = Arc::new(WSOperatorSession::new(entry.sid.clone(), tx.clone()));
215 state
216 .engine
217 .register_senior_bridge(
218 entry.sid.clone(),
219 ws_session.clone() as Arc<dyn SeniorBridge>,
220 )
221 .await;
222 state
223 .engine
224 .register_spawn_hook(entry.sid.clone(), ws_session.clone() as Arc<dyn SpawnHook>)
225 .await;
226 state
227 .engine
228 .register_operator(entry.sid.clone(), ws_session.clone() as Arc<dyn Operator>)
229 .await;
230 if let Some(factory) = &state.ws_operator_factory {
231 factory
232 .register_operator(entry.sid.clone(), ws_session.clone() as Arc<dyn Operator>);
233 }
234 for role in &entry.roles {
239 if let Some(factory) = &state.ws_operator_factory {
240 factory
241 .register_operator(role.clone(), ws_session.clone() as Arc<dyn Operator>);
242 }
243 state
244 .engine
245 .register_operator(role.clone(), ws_session.clone() as Arc<dyn Operator>)
246 .await;
247 }
248 *entry.ws_session.lock().await = Some(ws_session.clone());
249 ws_session
250 }
251 };
252
253 let (mut ws_sink, mut ws_stream) = socket.split();
254
255 let write_task = tokio::spawn(async move {
257 while let Some(msg) = rx.recv().await {
258 let txt = match serde_json::to_string(&msg) {
259 Ok(s) => s,
260 Err(_) => continue,
261 };
262 if ws_sink.send(Message::Text(txt)).await.is_err() {
263 break;
264 }
265 }
266 let _ = ws_sink.close().await;
267 });
268
269 let session_for_read = session.clone();
271 let read_result: Result<(), String> = async {
272 while let Some(item) = ws_stream.next().await {
273 match item {
274 Ok(Message::Text(t)) => {
275 let parsed: ClientMsg = match serde_json::from_str(&t) {
276 Ok(p) => p,
277 Err(_) => continue,
278 };
279 match parsed {
280 ClientMsg::Answer { req_id, value } => {
281 session_for_read
282 .resolve_pending(&req_id, PendingReply::Answer(value))
283 .await;
284 }
285 ClientMsg::HookAck { req_id, ok, reason } => {
286 session_for_read
287 .resolve_pending(&req_id, PendingReply::HookAck { ok, reason })
288 .await;
289 }
290 ClientMsg::SpawnAck {
291 req_id,
292 value,
293 ok,
294 error,
295 } => {
296 session_for_read
297 .resolve_pending(
298 &req_id,
299 PendingReply::SpawnAck { value, ok, error },
300 )
301 .await;
302 }
303 }
304 }
305 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
306 Ok(Message::Close(_)) | Err(_) => break,
307 _ => {}
308 }
309 }
310 Ok(())
311 }
312 .await;
313
314 session.clear_tx().await;
318 write_task.abort();
319 let _ = read_result;
320}
321
322pub async fn operators_delete(
330 State(state): State<AppState>,
331 Path(sid): Path<String>,
332 headers: HeaderMap,
333) -> Response {
334 let bearer = match extract_bearer_token_required(&headers) {
335 Ok(t) => t,
336 Err(resp) => return *resp,
337 };
338
339 let entry = {
340 let map = state.operator_sessions.lock().await;
341 map.get(&sid).cloned()
342 };
343 let entry = match entry {
344 Some(e) => e,
345 None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
346 };
347 if entry.token != bearer {
348 return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
349 }
350
351 state.engine.unregister_senior_bridge(&sid).await;
352 state.engine.unregister_spawn_hook(&sid).await;
353 state.engine.unregister_operator(&sid).await;
354 if let Some(factory) = &state.ws_operator_factory {
355 factory.unregister_operator(&sid);
356 }
357 for role in &entry.roles {
358 state.engine.unregister_operator(role).await;
359 if let Some(factory) = &state.ws_operator_factory {
360 factory.unregister_operator(role);
361 }
362 }
363
364 if let Some(session) = entry.ws_session.lock().await.take() {
365 session.clear_tx().await;
366 }
367
368 state.operator_sessions.lock().await.remove(&sid);
369
370 {
371 let mut map = state.roles_to_sid.lock().await;
372 for role in &entry.roles {
373 if map.get(role).map(String::as_str) == Some(sid.as_str()) {
374 map.remove(role);
375 }
376 }
377 }
378
379 StatusCode::NO_CONTENT.into_response()
380}
381
382#[derive(Debug, Serialize)]
386pub struct OperatorsInfoResp {
387 pub sid: String,
389 pub roles: Vec<String>,
391 pub connected: bool,
393}
394
395pub async fn operators_info(
399 State(state): State<AppState>,
400 Path(sid): Path<String>,
401 headers: HeaderMap,
402) -> Response {
403 let bearer = match extract_bearer_token_required(&headers) {
404 Ok(t) => t,
405 Err(resp) => return *resp,
406 };
407
408 let entry = {
409 let map = state.operator_sessions.lock().await;
410 map.get(&sid).cloned()
411 };
412 let entry = match entry {
413 Some(e) => e,
414 None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
415 };
416 if entry.token != bearer {
417 return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
418 }
419
420 let connected = entry.ws_session.lock().await.is_some();
421 (
422 StatusCode::OK,
423 Json(OperatorsInfoResp {
424 sid: entry.sid.clone(),
425 roles: entry.roles.clone(),
426 connected,
427 }),
428 )
429 .into_response()
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a header map whose `Authorization` header is exactly `value`.
    fn with_authorization(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        let value = HeaderValue::from_str(value).unwrap();
        headers.insert(axum::http::header::AUTHORIZATION, value);
        headers
    }

    /// Builds a header map carrying `Authorization: Bearer <token>`.
    fn headers_with_bearer(token: &str) -> HeaderMap {
        with_authorization(&format!("Bearer {token}"))
    }

    #[test]
    fn extract_bearer_token_required_accepts_valid() {
        let headers = headers_with_bearer("abc123");
        let token = extract_bearer_token_required(&headers).unwrap();
        assert_eq!(token, "abc123");
    }

    #[test]
    fn extract_bearer_token_required_rejects_missing_header() {
        let headers = HeaderMap::new();
        assert!(extract_bearer_token_required(&headers).is_err());
    }

    #[test]
    fn extract_bearer_token_required_rejects_empty_token() {
        let headers = headers_with_bearer("");
        assert!(extract_bearer_token_required(&headers).is_err());
    }

    #[test]
    fn extract_bearer_token_required_rejects_wrong_scheme() {
        let headers = with_authorization("Basic dXNlcjpwYXNz");
        assert!(extract_bearer_token_required(&headers).is_err());
    }
}
473}