Skip to main content

mlua_swarm_server/operator_ws/
login.rs

1//! REST-like Operator session resource.
2//!
3//! Provides the `POST/GET/DELETE /v1/operators` + `WS /v1/operators/:sid/ws`
4//! route family — the sole WS Operator session route. `session.rs` /
5//! `protocol.rs` are unchanged by this module.
6//!
7//! ## Login flow
8//!
9//! ```text
10//! POST /v1/operators { roles?: ["main-ai"] }
11//!   → 409 if any role already owns a live entry (roles alias exclusivity,
12//!     v1.md §Auth session flow)
13//!   → { sid: "op-<uuid>", token: "<10-hex>", roles: [...] }
14//!
15//! WS /v1/operators/:sid/ws
16//!   Authorization: Bearer <token>   (mandatory — no empty-string default)
17//!   → 401 missing/empty Bearer, 404 unknown sid, 401 token mismatch
18//!   → registers a `WSOperatorSession` into the engine's 3 registries
19//!     (senior_bridge / spawn_hook / operator) + role aliases, same pattern
20//!     as `handler::handle_socket`. Reconnect (same sid, matching token)
21//!     reuses the existing `WSOperatorSession` via `replace_tx`.
22//!
23//! DELETE /v1/operators/:sid   (Bearer required)
24//!   → unregisters the 3 registries + role aliases + `operator_sessions`
25//!     entry + releases `roles_to_sid` ownership.
26//!
27//! GET /v1/operators/:sid   (Bearer required)
28//!   → { sid, roles, connected }
29//! ```
30//!
31//! `OperatorSessionEntry` is the login-flow record (`AppState.operator_sessions`),
32//! distinct from `mlua_swarm::OperatorSession` (the engine-side
33//! `attach`/session-token record) and from `WSOperatorSession` (the 3-trait WS
34//! session, `session.rs`) — this module owns the mapping `sid → (token, roles,
35//! Option<WSOperatorSession>)` that the login flow is built on.
36
37use axum::{
38    extract::{
39        ws::{Message, WebSocket, WebSocketUpgrade},
40        Path, State,
41    },
42    http::{HeaderMap, StatusCode},
43    response::{IntoResponse, Response},
44    Json,
45};
46use futures_util::{sink::SinkExt, stream::StreamExt};
47use mlua_swarm::{Operator, SeniorBridge, SpawnHook};
48use serde::{Deserialize, Serialize};
49use serde_json::json;
50use std::sync::Arc;
51use tokio::sync::{mpsc, Mutex};
52
53use super::protocol::{ClientMsg, PendingReply, ServerMsg};
54use super::session::WSOperatorSession;
55use crate::AppState;
56
57/// Login-flow record for a minted Operator session. Held in
58/// `AppState.operator_sessions`, keyed by `sid`. `ws_session` starts `None`
59/// (login only mints sid+token) and is set on first successful WS connect;
60/// on reconnect the same `WSOperatorSession` is reused (`replace_tx`) rather
61/// than re-registered.
62pub struct OperatorSessionEntry {
63    /// Server-minted session id (`op-<uuid>`).
64    pub sid: String,
65    /// Bearer auth token (10-hex-char) required on the WS upgrade and admin routes.
66    pub token: String,
67    /// Role aliases claimed by this session (roles-exclusivity set).
68    pub roles: Vec<String>,
69    /// The live 3-trait session object once a WS has connected; `None` before first connect.
70    pub ws_session: Mutex<Option<Arc<WSOperatorSession>>>,
71}
72
73// ─── POST /v1/operators (mint) ──────────────────────────────────────────────
74
75/// Body for `POST /v1/operators`.
76#[derive(Debug, Deserialize, Default)]
77pub struct OperatorsCreateReq {
78    /// Role aliases to claim exclusively (empty = no exclusivity claimed).
79    #[serde(default)]
80    pub roles: Vec<String>,
81}
82
83/// Response for `POST /v1/operators`.
84#[derive(Debug, Serialize)]
85pub struct OperatorsCreateResp {
86    /// Newly minted session id (`op-<uuid>`).
87    pub sid: String,
88    /// Bearer auth token required on the WS upgrade and admin routes.
89    pub token: String,
90    /// Echoes the granted role aliases.
91    pub roles: Vec<String>,
92}
93
94/// `POST /v1/operators`. Mints `sid` (`op-<uuid>`) + a 10-hex-char token
95/// (`mlua_swarm::types::secure_hex(5)` — OS-RNG hex, unguessable across
96/// calls and restarts, which is the point: this token is the sole bearer
97/// secret on the short-handle path). When `roles` is non-empty, checks
98/// `AppState.roles_to_sid` for conflicts under a single lock (check + insert
99/// atomic w.r.t. concurrent mints) and returns `409 CONFLICT` with the
100/// conflicting role names on collision. Empty `roles` never conflicts (= no
101/// exclusivity is claimed).
102pub async fn operators_create(
103    State(state): State<AppState>,
104    Json(req): Json<OperatorsCreateReq>,
105) -> Response {
106    let roles = req.roles;
107    let sid = format!("op-{}", uuid::Uuid::new_v4());
108    let token = mlua_swarm::types::secure_hex(5);
109
110    {
111        let mut map = state.roles_to_sid.lock().await;
112        let conflicts: Vec<String> = roles
113            .iter()
114            .filter(|r| map.contains_key(r.as_str()))
115            .cloned()
116            .collect();
117        if !conflicts.is_empty() {
118            return (
119                StatusCode::CONFLICT,
120                Json(json!({"error": "roles conflict", "conflicts": conflicts})),
121            )
122                .into_response();
123        }
124        for r in &roles {
125            map.insert(r.clone(), sid.clone());
126        }
127    }
128
129    let entry = Arc::new(OperatorSessionEntry {
130        sid: sid.clone(),
131        token: token.clone(),
132        roles: roles.clone(),
133        ws_session: Mutex::new(None),
134    });
135    state
136        .operator_sessions
137        .lock()
138        .await
139        .insert(sid.clone(), entry);
140
141    (
142        StatusCode::OK,
143        Json(OperatorsCreateResp { sid, token, roles }),
144    )
145        .into_response()
146}
147
148// ─── WS /v1/operators/:sid/ws (Bearer required) ─────────────────────────────
149
150/// Extracts `Authorization: Bearer <token>`; missing header, wrong scheme, or
151/// an empty token all resolve to a `401` response. `Authorization` is
152/// mandatory on the WS path — there is no empty-string default.
153fn extract_bearer_token_required(headers: &HeaderMap) -> Result<String, Box<Response>> {
154    let token = headers
155        .get(axum::http::header::AUTHORIZATION)
156        .and_then(|v| v.to_str().ok())
157        .and_then(|s| s.strip_prefix("Bearer "))
158        .map(|s| s.trim().to_string())
159        .filter(|s| !s.is_empty());
160    token.ok_or_else(|| {
161        Box::new((StatusCode::UNAUTHORIZED, "missing or empty Bearer token").into_response())
162    })
163}
164
165/// `GET /v1/operators/:sid/ws` (WS upgrade). Bearer mandatory. `404` on
166/// unknown sid, `401` on token mismatch. On successful upgrade, registers (or
167/// reuses, on reconnect) a `WSOperatorSession` under `sid` — same 3-registry
168/// pattern as `handler::handle_socket`, plus role-alias registration for
169/// every role minted alongside this sid.
170pub async fn operators_ws_connect(
171    State(state): State<AppState>,
172    Path(sid): Path<String>,
173    headers: HeaderMap,
174    ws: WebSocketUpgrade,
175) -> Response {
176    let bearer = match extract_bearer_token_required(&headers) {
177        Ok(t) => t,
178        Err(resp) => return *resp,
179    };
180
181    let entry = {
182        let map = state.operator_sessions.lock().await;
183        map.get(&sid).cloned()
184    };
185    let entry = match entry {
186        Some(e) => e,
187        None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
188    };
189    if entry.token != bearer {
190        return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
191    }
192
193    ws.on_upgrade(move |socket| handle_operator_socket(socket, state, entry))
194}
195
196/// Bidirectional pump for a single WS connection, bound to an
197/// `OperatorSessionEntry`. Owns the full wire protocol pump (write task /
198/// read task / `ClientMsg` dispatch / disconnect) for this session.
199async fn handle_operator_socket(
200    socket: WebSocket,
201    state: AppState,
202    entry: Arc<OperatorSessionEntry>,
203) {
204    let (tx, mut rx) = mpsc::unbounded_channel::<ServerMsg>();
205
206    let existing_ws = entry.ws_session.lock().await.clone();
207    let session = match existing_ws {
208        Some(ws_session) => {
209            // Reconnect: reuse the existing WSOperatorSession on this entry; only swap out `tx`.
210            ws_session.replace_tx(tx.clone()).await;
211            ws_session
212        }
213        None => {
214            let ws_session = Arc::new(WSOperatorSession::new(entry.sid.clone(), tx.clone()));
215            state
216                .engine
217                .register_senior_bridge(
218                    entry.sid.clone(),
219                    ws_session.clone() as Arc<dyn SeniorBridge>,
220                )
221                .await;
222            state
223                .engine
224                .register_spawn_hook(entry.sid.clone(), ws_session.clone() as Arc<dyn SpawnHook>)
225                .await;
226            state
227                .engine
228                .register_operator(entry.sid.clone(), ws_session.clone() as Arc<dyn Operator>)
229                .await;
230            if let Some(factory) = &state.ws_operator_factory {
231                factory
232                    .register_operator(entry.sid.clone(), ws_session.clone() as Arc<dyn Operator>);
233            }
234            // Role exclusivity was already resolved at login (POST) time. Here
235            // we just bind the same session into the three registries + factory
236            // under its role aliases (same shape as handler::handle_socket's
237            // ?roles= path).
238            for role in &entry.roles {
239                if let Some(factory) = &state.ws_operator_factory {
240                    factory
241                        .register_operator(role.clone(), ws_session.clone() as Arc<dyn Operator>);
242                }
243                state
244                    .engine
245                    .register_operator(role.clone(), ws_session.clone() as Arc<dyn Operator>)
246                    .await;
247            }
248            *entry.ws_session.lock().await = Some(ws_session.clone());
249            ws_session
250        }
251    };
252
253    let (mut ws_sink, mut ws_stream) = socket.split();
254
255    // write task: mpsc → WebSocket
256    let write_task = tokio::spawn(async move {
257        while let Some(msg) = rx.recv().await {
258            let txt = match serde_json::to_string(&msg) {
259                Ok(s) => s,
260                Err(_) => continue,
261            };
262            if ws_sink.send(Message::Text(txt)).await.is_err() {
263                break;
264            }
265        }
266        let _ = ws_sink.close().await;
267    });
268
269    // read task: WS message → ClientMsg parse → session.resolve_pending
270    let session_for_read = session.clone();
271    let read_result: Result<(), String> = async {
272        while let Some(item) = ws_stream.next().await {
273            match item {
274                Ok(Message::Text(t)) => {
275                    let parsed: ClientMsg = match serde_json::from_str(&t) {
276                        Ok(p) => p,
277                        Err(_) => continue,
278                    };
279                    match parsed {
280                        ClientMsg::Answer { req_id, value } => {
281                            session_for_read
282                                .resolve_pending(&req_id, PendingReply::Answer(value))
283                                .await;
284                        }
285                        ClientMsg::HookAck { req_id, ok, reason } => {
286                            session_for_read
287                                .resolve_pending(&req_id, PendingReply::HookAck { ok, reason })
288                                .await;
289                        }
290                        ClientMsg::SpawnAck {
291                            req_id,
292                            value,
293                            ok,
294                            error,
295                        } => {
296                            session_for_read
297                                .resolve_pending(
298                                    &req_id,
299                                    PendingReply::SpawnAck { value, ok, error },
300                                )
301                                .await;
302                        }
303                    }
304                }
305                Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
306                Ok(Message::Close(_)) | Err(_) => break,
307                _ => {}
308            }
309        }
310        Ok(())
311    }
312    .await;
313
314    // Disconnect: tx → None (the session itself stays in operator_sessions
315    // and the three registries, waiting for a reconnect; teardown happens
316    // only through DELETE).
317    session.clear_tx().await;
318    write_task.abort();
319    let _ = read_result;
320}
321
322// ─── DELETE /v1/operators/:sid (Bearer required) ────────────────────────────
323
324/// `DELETE /v1/operators/:sid`. Bearer mandatory. `404` on unknown sid, `401`
325/// on token mismatch. Drops the 3 engine registries + role aliases +
326/// `ws_operator_factory` bindings + `operator_sessions` entry, and releases
327/// this sid's ownership in `roles_to_sid` (re-opening the role names for a
328/// future mint).
329pub async fn operators_delete(
330    State(state): State<AppState>,
331    Path(sid): Path<String>,
332    headers: HeaderMap,
333) -> Response {
334    let bearer = match extract_bearer_token_required(&headers) {
335        Ok(t) => t,
336        Err(resp) => return *resp,
337    };
338
339    let entry = {
340        let map = state.operator_sessions.lock().await;
341        map.get(&sid).cloned()
342    };
343    let entry = match entry {
344        Some(e) => e,
345        None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
346    };
347    if entry.token != bearer {
348        return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
349    }
350
351    state.engine.unregister_senior_bridge(&sid).await;
352    state.engine.unregister_spawn_hook(&sid).await;
353    state.engine.unregister_operator(&sid).await;
354    if let Some(factory) = &state.ws_operator_factory {
355        factory.unregister_operator(&sid);
356    }
357    for role in &entry.roles {
358        state.engine.unregister_operator(role).await;
359        if let Some(factory) = &state.ws_operator_factory {
360            factory.unregister_operator(role);
361        }
362    }
363
364    if let Some(session) = entry.ws_session.lock().await.take() {
365        session.clear_tx().await;
366    }
367
368    state.operator_sessions.lock().await.remove(&sid);
369
370    {
371        let mut map = state.roles_to_sid.lock().await;
372        for role in &entry.roles {
373            if map.get(role).map(String::as_str) == Some(sid.as_str()) {
374                map.remove(role);
375            }
376        }
377    }
378
379    StatusCode::NO_CONTENT.into_response()
380}
381
382// ─── GET /v1/operators/:sid (Bearer required) ───────────────────────────────
383
384/// Response for `GET /v1/operators/:sid`.
385#[derive(Debug, Serialize)]
386pub struct OperatorsInfoResp {
387    /// Echoes the requested session id.
388    pub sid: String,
389    /// Role aliases held by this session.
390    pub roles: Vec<String>,
391    /// Whether a WS is currently attached (not merely that the session ever connected).
392    pub connected: bool,
393}
394
395/// `GET /v1/operators/:sid`. Bearer mandatory. `404` on unknown sid, `401` on
396/// token mismatch. `connected` reflects whether `ws_session` is currently
397/// `Some` (= a WS is live, not merely that the session was ever connected).
398pub async fn operators_info(
399    State(state): State<AppState>,
400    Path(sid): Path<String>,
401    headers: HeaderMap,
402) -> Response {
403    let bearer = match extract_bearer_token_required(&headers) {
404        Ok(t) => t,
405        Err(resp) => return *resp,
406    };
407
408    let entry = {
409        let map = state.operator_sessions.lock().await;
410        map.get(&sid).cloned()
411    };
412    let entry = match entry {
413        Some(e) => e,
414        None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
415    };
416    if entry.token != bearer {
417        return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
418    }
419
420    let connected = entry.ws_session.lock().await.is_some();
421    (
422        StatusCode::OK,
423        Json(OperatorsInfoResp {
424            sid: entry.sid.clone(),
425            roles: entry.roles.clone(),
426            connected,
427        }),
428    )
429        .into_response()
430}
431
432#[cfg(test)]
433mod tests {
434    use super::*;
435    use axum::http::HeaderValue;
436
437    fn headers_with_bearer(token: &str) -> HeaderMap {
438        let mut h = HeaderMap::new();
439        h.insert(
440            axum::http::header::AUTHORIZATION,
441            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
442        );
443        h
444    }
445
446    #[test]
447    fn extract_bearer_token_required_accepts_valid() {
448        let h = headers_with_bearer("abc123");
449        assert_eq!(extract_bearer_token_required(&h).unwrap(), "abc123");
450    }
451
452    #[test]
453    fn extract_bearer_token_required_rejects_missing_header() {
454        let h = HeaderMap::new();
455        assert!(extract_bearer_token_required(&h).is_err());
456    }
457
458    #[test]
459    fn extract_bearer_token_required_rejects_empty_token() {
460        let h = headers_with_bearer("");
461        assert!(extract_bearer_token_required(&h).is_err());
462    }
463
464    #[test]
465    fn extract_bearer_token_required_rejects_wrong_scheme() {
466        let mut h = HeaderMap::new();
467        h.insert(
468            axum::http::header::AUTHORIZATION,
469            HeaderValue::from_static("Basic dXNlcjpwYXNz"),
470        );
471        assert!(extract_bearer_token_required(&h).is_err());
472    }
473}