use moltendb_auth as auth;
use moltendb_core::engine;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
extract::ws::Utf8Bytes,
Extension,
};
use futures::{sink::SinkExt, stream::StreamExt};
use tokio::time::{interval, Duration};
use tracing::warn;
pub async fn ws_handler(
ws: WebSocketUpgrade,
State((db, _, _max_body_size, _, _)): State<(engine::Db, auth::UserStore, usize, usize, String)>,
Extension(revocation_store): Extension<auth::RevocationStore>,
) -> impl axum::response::IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, db, revocation_store))
}
async fn handle_socket(mut socket: WebSocket, db: engine::Db, revocation_store: auth::RevocationStore) {
enum AuthResult {
Ok(auth::Claims),
Err(&'static str),
}
let auth_result = match socket.next().await {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<serde_json::Value>(&text) {
Err(_) => AuthResult::Err(
r#"{"error":"invalid_message","detail":"Could not parse JSON. Expected {\"action\":\"AUTH\",\"token\":\"<jwt>\"}"}"#,
),
Ok(payload) => {
if payload["action"].as_str() != Some("AUTH") {
AuthResult::Err(
r#"{"error":"invalid_action","detail":"First message must have \"action\":\"AUTH\". Use HTTP endpoints for CRUD operations."}"#,
)
} else if let Some(token) = payload["token"].as_str() {
match auth::verify_token(token) {
Err(_) => AuthResult::Err(
r#"{"error":"invalid_token","detail":"JWT verification failed. The token may be expired, malformed, or signed with the wrong secret."}"#,
),
Ok(c) => {
if revocation_store.is_revoked(&c.jti) {
warn!("🔒 Rejected WebSocket connection: token JTI '{}' is revoked.", c.jti);
AuthResult::Err(
r#"{"error":"token_revoked","detail":"This token has been revoked. Mint a new token via POST /auth/tokens."}"#,
)
} else {
AuthResult::Ok(c)
}
}
}
} else {
AuthResult::Err(
r#"{"error":"missing_token","detail":"AUTH message is missing the \"token\" field. Expected {\"action\":\"AUTH\",\"token\":\"<jwt>\"}"}"#,
)
}
}
}
}
_ => AuthResult::Err(
r#"{"error":"invalid_message","detail":"First message must be a text frame containing {\"action\":\"AUTH\",\"token\":\"<jwt>\"}"}"#,
),
};
let claims = match auth_result {
AuthResult::Ok(c) => c,
AuthResult::Err(msg) => {
let _ = socket.send(Message::Text(Utf8Bytes::from(msg))).await;
let _ = socket.close().await;
warn!("🔒 Rejected WebSocket connection: {}", msg);
return;
}
};
let _ = socket
.send(Message::Text(Utf8Bytes::from(
r#"{"status":"authenticated","message":"Connected to MoltenDB real-time feed. Use HTTP endpoints for CRUD. Send {\"action\":\"SUBSCRIBE\",\"collection\":\"<name>\"} to register interest."}"#,
)))
.await;
let (mut sender, mut receiver) = socket.split();
let mut rx = db.subscribe();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(_text))) = receiver.next().await {
}
});
let mut send_task = tokio::spawn(async move {
let mut revocation_check = interval(Duration::from_secs(30));
revocation_check.tick().await; loop {
tokio::select! {
_ = revocation_check.tick() => {
if revocation_store.is_revoked(&claims.jti) {
warn!("🔒 Closing WebSocket: token JTI '{}' was revoked after connection was established.", claims.jti);
let _ = sender.send(Message::Text(Utf8Bytes::from(
r#"{"error":"token_revoked","detail":"Your token has been revoked. The connection is being closed."}"#,
))).await;
break;
}
}
recv_result = rx.recv() => {
match recv_result {
Ok(msg) => {
let allowed = if let Ok(event) = serde_json::from_str::<serde_json::Value>(&msg) {
if let Some(collection) = event.get("collection").and_then(|v| v.as_str()) {
claims.has_collection_access("read", collection)
} else {
false
}
} else {
false
};
if allowed {
if sender.send(Message::Text(Utf8Bytes::from(msg))).await.is_err() {
break; }
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
warn!("⚠️ WebSocket send task lagged: {} events dropped for this client.", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
});
tokio::select! {
_ = (&mut recv_task) => send_task.abort(),
_ = (&mut send_task) => recv_task.abort(),
};
}