use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, State,
},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use futures_util::{sink::SinkExt, stream::StreamExt};
use mlua_swarm::{Operator, SeniorBridge, SpawnHook};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use super::protocol::{ClientMsg, PendingReply, ServerMsg};
use super::session::WSOperatorSession;
use crate::AppState;
/// Server-side record for one operator session created by `operators_create`.
///
/// Entries are shared as `Arc<OperatorSessionEntry>` through the
/// `operator_sessions` map on `AppState`.
pub struct OperatorSessionEntry {
/// Session id; `operators_create` generates it as `op-<uuid v4>`.
pub sid: String,
/// Bearer token the handlers in this file compare against the request's
/// `Authorization` header before acting on the session.
pub token: String,
/// Role names claimed for this operator when it was created.
pub roles: Vec<String>,
/// WebSocket session attached to this entry. Starts as `None`;
/// `handle_operator_socket` stores a session here and `operators_delete`
/// takes it back out.
pub ws_session: Mutex<Option<Arc<WSOperatorSession>>>,
}
/// JSON request body accepted by `operators_create`.
#[derive(Debug, Deserialize, Default)]
pub struct OperatorsCreateReq {
/// Roles to reserve for the new operator. Because of `#[serde(default)]`
/// an omitted field deserializes to an empty list.
#[serde(default)]
pub roles: Vec<String>,
}
/// JSON response body returned by `operators_create` on success.
#[derive(Debug, Serialize)]
pub struct OperatorsCreateResp {
/// Id of the newly created operator session.
pub sid: String,
/// Bearer token the client must present on the other operator endpoints.
pub token: String,
/// The roles that were reserved, echoed back from the request.
pub roles: Vec<String>,
}
/// Creates a new operator session and reserves the requested roles for it.
///
/// A fresh `op-<uuid>` session id and a random bearer token are generated.
/// Every requested role must be unclaimed in `roles_to_sid`; the check and the
/// reservation happen under a single lock so they cannot interleave with
/// another create request.
///
/// # Responses
/// * `200 OK` with an [`OperatorsCreateResp`] body on success.
/// * `409 Conflict` with `{"error": "roles conflict", "conflicts": [...]}` when
///   at least one requested role already belongs to another session. Nothing
///   is reserved in that case.
pub async fn operators_create(
    State(state): State<AppState>,
    Json(req): Json<OperatorsCreateReq>,
) -> Response {
    let roles = req.roles;
    let sid = format!("op-{}", uuid::Uuid::new_v4());
    // The token is the only credential guarding the WebSocket, info and delete
    // endpoints, so it has to resist online guessing. The previous length
    // argument of 5 yielded at most 40 bits of entropy.
    // NOTE(review): assumes `secure_hex`'s argument is a length (bytes or hex
    // digits) — confirm against the helper's definition.
    let token = mlua_swarm::types::secure_hex(32);

    {
        // Check and reserve under one lock so concurrent creates cannot both
        // claim the same role.
        let mut map = state.roles_to_sid.lock().await;
        let conflicts: Vec<String> = roles
            .iter()
            .filter(|r| map.contains_key(r.as_str()))
            .cloned()
            .collect();
        if !conflicts.is_empty() {
            return (
                StatusCode::CONFLICT,
                Json(json!({"error": "roles conflict", "conflicts": conflicts})),
            )
                .into_response();
        }
        for r in &roles {
            map.insert(r.clone(), sid.clone());
        }
    }

    let entry = Arc::new(OperatorSessionEntry {
        sid: sid.clone(),
        token: token.clone(),
        roles: roles.clone(),
        // No WebSocket is attached until the client connects.
        ws_session: Mutex::new(None),
    });
    state
        .operator_sessions
        .lock()
        .await
        .insert(sid.clone(), entry);

    (
        StatusCode::OK,
        Json(OperatorsCreateResp { sid, token, roles }),
    )
        .into_response()
}
/// Extracts the credential from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched ASCII case-insensitively, because HTTP
/// authentication schemes are case-insensitive (RFC 9110 §11.1); `bearer abc`
/// and `BEARER abc` are therefore accepted alongside `Bearer abc`. Whitespace
/// around the token is trimmed.
///
/// # Errors
/// Returns a boxed `401 Unauthorized` response when the header is absent,
/// cannot be read as a string, uses a different scheme, or carries an empty
/// token.
fn extract_bearer_token_required(headers: &HeaderMap) -> Result<String, Box<Response>> {
    const SCHEME_PREFIX: &str = "Bearer ";

    headers
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| {
            // `str::get` yields `None` when the header is shorter than the
            // prefix (or the cut is not on a char boundary), so the slice
            // below can never panic.
            let scheme = raw.get(..SCHEME_PREFIX.len())?;
            scheme
                .eq_ignore_ascii_case(SCHEME_PREFIX)
                .then(|| raw[SCHEME_PREFIX.len()..].trim())
        })
        .filter(|token| !token.is_empty())
        .map(str::to_owned)
        .ok_or_else(|| {
            Box::new((StatusCode::UNAUTHORIZED, "missing or empty Bearer token").into_response())
        })
}
/// Upgrades an authenticated request to the operator WebSocket for `sid`.
///
/// Checks run in this order:
/// 1. A non-empty bearer token must be present, otherwise the `401` produced
///    by `extract_bearer_token_required` is returned.
/// 2. `sid` must exist in `operator_sessions`, otherwise `404 unknown sid`.
/// 3. The presented token must equal the entry's token, otherwise
///    `401 token mismatch`.
///
/// On success the connection is upgraded and handed to
/// `handle_operator_socket`.
pub async fn operators_ws_connect(
    State(state): State<AppState>,
    Path(sid): Path<String>,
    headers: HeaderMap,
    ws: WebSocketUpgrade,
) -> Response {
    let presented = match extract_bearer_token_required(&headers) {
        Ok(token) => token,
        Err(rejection) => return *rejection,
    };

    // The map lock is only held for the duration of this lookup.
    let looked_up = state.operator_sessions.lock().await.get(&sid).cloned();

    match looked_up {
        None => (StatusCode::NOT_FOUND, "unknown sid").into_response(),
        Some(entry) if entry.token != presented => {
            (StatusCode::UNAUTHORIZED, "token mismatch").into_response()
        }
        Some(entry) => {
            ws.on_upgrade(move |socket| handle_operator_socket(socket, state, entry))
        }
    }
}
async fn handle_operator_socket(
socket: WebSocket,
state: AppState,
entry: Arc<OperatorSessionEntry>,
) {
let (tx, mut rx) = mpsc::unbounded_channel::<ServerMsg>();
let existing_ws = entry.ws_session.lock().await.clone();
let session = match existing_ws {
Some(ws_session) => {
ws_session.replace_tx(tx.clone()).await;
ws_session
}
None => {
let ws_session = Arc::new(WSOperatorSession::new(entry.sid.clone(), tx.clone()));
state
.engine
.register_senior_bridge(
entry.sid.clone(),
ws_session.clone() as Arc<dyn SeniorBridge>,
)
.await;
state
.engine
.register_spawn_hook(entry.sid.clone(), ws_session.clone() as Arc<dyn SpawnHook>)
.await;
state
.engine
.register_operator(entry.sid.clone(), ws_session.clone() as Arc<dyn Operator>)
.await;
if let Some(factory) = &state.ws_operator_factory {
factory
.register_operator(entry.sid.clone(), ws_session.clone() as Arc<dyn Operator>);
}
for role in &entry.roles {
if let Some(factory) = &state.ws_operator_factory {
factory
.register_operator(role.clone(), ws_session.clone() as Arc<dyn Operator>);
}
state
.engine
.register_operator(role.clone(), ws_session.clone() as Arc<dyn Operator>)
.await;
}
*entry.ws_session.lock().await = Some(ws_session.clone());
ws_session
}
};
let (mut ws_sink, mut ws_stream) = socket.split();
let write_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
let txt = match serde_json::to_string(&msg) {
Ok(s) => s,
Err(_) => continue,
};
if ws_sink.send(Message::Text(txt)).await.is_err() {
break;
}
}
let _ = ws_sink.close().await;
});
let session_for_read = session.clone();
let read_result: Result<(), String> = async {
while let Some(item) = ws_stream.next().await {
match item {
Ok(Message::Text(t)) => {
let parsed: ClientMsg = match serde_json::from_str(&t) {
Ok(p) => p,
Err(_) => continue,
};
match parsed {
ClientMsg::Answer { req_id, value } => {
session_for_read
.resolve_pending(&req_id, PendingReply::Answer(value))
.await;
}
ClientMsg::HookAck { req_id, ok, reason } => {
session_for_read
.resolve_pending(&req_id, PendingReply::HookAck { ok, reason })
.await;
}
ClientMsg::SpawnAck {
req_id,
value,
ok,
error,
} => {
session_for_read
.resolve_pending(
&req_id,
PendingReply::SpawnAck { value, ok, error },
)
.await;
}
}
}
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
Ok(Message::Close(_)) | Err(_) => break,
_ => {}
}
}
Ok(())
}
.await;
session.clear_tx().await;
write_task.abort();
let _ = read_result;
}
/// Tears down the operator session `sid` and releases its roles.
///
/// Authentication mirrors the other operator endpoints: a missing or empty
/// bearer token yields the `401` from `extract_bearer_token_required`, an
/// unknown sid yields `404 unknown sid`, and a wrong token yields
/// `401 token mismatch`.
///
/// On success the session is unregistered from the engine (senior bridge,
/// spawn hook, operator) and from the optional operator factory, both under
/// the sid and under every role. Any attached WebSocket session has its
/// sender cleared, the entry is dropped from `operator_sessions`, and the
/// roles still mapped to this sid are removed from `roles_to_sid`.
/// Responds with `204 No Content`.
pub async fn operators_delete(
    State(state): State<AppState>,
    Path(sid): Path<String>,
    headers: HeaderMap,
) -> Response {
    let presented = match extract_bearer_token_required(&headers) {
        Ok(token) => token,
        Err(rejection) => return *rejection,
    };

    // The map lock is only held for the duration of this lookup.
    let looked_up = state.operator_sessions.lock().await.get(&sid).cloned();
    let entry = match looked_up {
        None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
        Some(found) => found,
    };
    if entry.token != presented {
        return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
    }

    // Registrations keyed by the sid.
    state.engine.unregister_senior_bridge(&sid).await;
    state.engine.unregister_spawn_hook(&sid).await;
    state.engine.unregister_operator(&sid).await;
    if let Some(factory) = &state.ws_operator_factory {
        factory.unregister_operator(&sid);
    }

    // Registrations keyed by each role.
    for role in entry.roles.iter() {
        state.engine.unregister_operator(role).await;
        if let Some(factory) = &state.ws_operator_factory {
            factory.unregister_operator(role);
        }
    }

    // Detach the WebSocket session, if one was ever attached.
    if let Some(attached) = entry.ws_session.lock().await.take() {
        attached.clear_tx().await;
    }

    state.operator_sessions.lock().await.remove(&sid);

    {
        // Only release roles that still point at this sid.
        let mut owners = state.roles_to_sid.lock().await;
        for role in entry.roles.iter() {
            let owned_by_us = matches!(owners.get(role), Some(owner) if *owner == sid);
            if owned_by_us {
                owners.remove(role);
            }
        }
    }

    StatusCode::NO_CONTENT.into_response()
}
/// JSON response body returned by `operators_info`.
#[derive(Debug, Serialize)]
pub struct OperatorsInfoResp {
/// Id of the operator session being described.
pub sid: String,
/// Roles reserved for that session.
pub roles: Vec<String>,
/// `true` when a WebSocket session object is stored on the session entry
/// at the time of the request.
pub connected: bool,
}
/// Reports the roles and connection state of the operator session `sid`.
///
/// Authentication mirrors the other operator endpoints: a missing or empty
/// bearer token yields the `401` from `extract_bearer_token_required`, an
/// unknown sid yields `404 unknown sid`, and a wrong token yields
/// `401 token mismatch`. On success responds `200 OK` with an
/// [`OperatorsInfoResp`] body.
pub async fn operators_info(
    State(state): State<AppState>,
    Path(sid): Path<String>,
    headers: HeaderMap,
) -> Response {
    let presented = match extract_bearer_token_required(&headers) {
        Ok(token) => token,
        Err(rejection) => return *rejection,
    };

    // The map lock is only held for the duration of this lookup.
    let looked_up = state.operator_sessions.lock().await.get(&sid).cloned();
    let entry = match looked_up {
        None => return (StatusCode::NOT_FOUND, "unknown sid").into_response(),
        Some(found) => found,
    };
    if entry.token != presented {
        return (StatusCode::UNAUTHORIZED, "token mismatch").into_response();
    }

    // NOTE(review): within this file the slot is filled on the first
    // WebSocket upgrade and emptied only by `operators_delete`, so this may
    // stay `true` after the socket has closed — confirm that is intended.
    let connected = entry.ws_session.lock().await.is_some();

    let body = OperatorsInfoResp {
        sid: entry.sid.clone(),
        roles: entry.roles.clone(),
        connected,
    };
    (StatusCode::OK, Json(body)).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a header map whose only entry is `Authorization: Bearer <token>`.
    fn headers_with_bearer(token: &str) -> HeaderMap {
        let value = HeaderValue::from_str(&format!("Bearer {token}")).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert(axum::http::header::AUTHORIZATION, value);
        headers
    }

    /// A well-formed bearer header yields exactly the token text.
    #[test]
    fn extract_bearer_token_required_accepts_valid() {
        let extracted = extract_bearer_token_required(&headers_with_bearer("abc123"));
        assert_eq!(extracted.unwrap(), "abc123");
    }

    /// A request without any `Authorization` header is rejected.
    #[test]
    fn extract_bearer_token_required_rejects_missing_header() {
        let no_headers = HeaderMap::new();
        assert!(extract_bearer_token_required(&no_headers).is_err());
    }

    /// A bearer header with nothing after the scheme is rejected.
    #[test]
    fn extract_bearer_token_required_rejects_empty_token() {
        let outcome = extract_bearer_token_required(&headers_with_bearer(""));
        assert!(outcome.is_err());
    }

    /// A non-bearer scheme such as `Basic` is rejected.
    #[test]
    fn extract_bearer_token_required_rejects_wrong_scheme() {
        let basic = HeaderValue::from_static("Basic dXNlcjpwYXNz");
        let mut headers = HeaderMap::new();
        headers.insert(axum::http::header::AUTHORIZATION, basic);
        assert!(extract_bearer_token_required(&headers).is_err());
    }
}