use std::time::Duration;
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{Query, State, WebSocketUpgrade};
use axum::http::HeaderMap;
use axum::response::IntoResponse;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use tokio::time::interval;
use tracing::{debug, info, warn};
use crate::auth::require_bearer;
use crate::error::AppError;
use crate::jetstream::{looks_like_retention_exhausted, open_ws_message_stream, stream_first_seq};
use crate::state::AppState;
/// Subject used when the client sends no `subject` query parameter; the socket
/// handler treats this value as "no filter" when opening the JetStream consumer.
const DEFAULT_SUBJECT: &str = "cellos.events.>";
/// Interval between server-initiated WebSocket pings.
const HEARTBEAT: Duration = Duration::from_secs(25);
/// Cap (64 KiB) applied to both the message size and the frame size of the
/// upgraded socket.
const WS_MAX_FRAME_BYTES: usize = 65_536;
/// Upper bound on any single send to the client, whether a data frame or a ping.
const WS_SEND_TIMEOUT: Duration = Duration::from_secs(50);
/// Query-string parameters accepted by [`ws_events`].
#[derive(Debug, Deserialize)]
pub struct WsParams {
    /// Subject to subscribe to; [`ws_events`] substitutes [`DEFAULT_SUBJECT`]
    /// when this is absent.
    pub subject: Option<String>,
    /// Optional resume cursor forwarded to `open_ws_message_stream`.
    /// NOTE(review): presumably a JetStream stream sequence (the same value sent
    /// to clients as the envelope's `seq`) — confirm against that helper.
    pub since: Option<u64>,
}
pub async fn ws_events(
State(state): State<AppState>,
headers: HeaderMap,
Query(params): Query<WsParams>,
ws: WebSocketUpgrade,
) -> Result<impl IntoResponse, AppError> {
require_bearer(&headers, &state.api_token)?;
let subject = params.subject.unwrap_or_else(|| DEFAULT_SUBJECT.to_owned());
let since = params.since;
let ws = ws
.max_message_size(WS_MAX_FRAME_BYTES)
.max_frame_size(WS_MAX_FRAME_BYTES);
Ok(ws.on_upgrade(move |socket| handle_socket(socket, state, subject, since)))
}
async fn handle_socket(socket: WebSocket, state: AppState, subject: String, since: Option<u64>) {
let Some(ctx) = state.jetstream.clone() else {
warn!("ws connect with no JetStream context configured; closing");
let _ = socket
.send_close_with_reason("no upstream broker configured")
.await;
return;
};
let subject_filter = if subject == DEFAULT_SUBJECT {
None
} else {
Some(subject.as_str())
};
let mut messages = match open_ws_message_stream(&ctx, subject_filter, since).await {
Ok(s) => s,
Err(e) => {
warn!(error = %format!("{e:#}"), subject = %subject, since = ?since, "jetstream consumer create failed");
if since.is_some() && looks_like_retention_exhausted(&e) {
let oldest = stream_first_seq(&ctx).await;
close_retention_exhausted(socket, oldest).await;
} else {
let _ = socket.send_close_with_reason("subscribe failed").await;
}
return;
}
};
info!(
subject = %subject,
since = ?since,
"ws client connected, bridging JetStream messages",
);
let (mut tx, mut rx) = socket.split();
let mut heartbeat = interval(HEARTBEAT);
heartbeat.tick().await;
loop {
tokio::select! {
biased;
incoming = rx.next() => {
match incoming {
Some(Ok(Message::Close(_))) | None => {
debug!("ws client closed");
break;
}
Some(Err(e)) => {
warn!(error = %e, "ws recv error");
break;
}
Some(Ok(_)) => {}
}
}
_ = heartbeat.tick() => {
match tokio::time::timeout(WS_SEND_TIMEOUT, tx.send(Message::Ping(Vec::new()))).await {
Ok(Ok(())) => {}
Ok(Err(_)) => {
debug!("ws heartbeat send failed; client gone");
break;
}
Err(_) => {
warn!("ws heartbeat send timed out after {:?}; closing", WS_SEND_TIMEOUT);
break;
}
}
}
msg = messages.next() => {
match msg {
Some(Ok(m)) => {
let seq = match m.info() {
Ok(info) => info.stream_sequence,
Err(e) => {
warn!(error = %e, "ws msg missing stream info; skipping");
continue;
}
};
let payload = match build_envelope(seq, &m.payload) {
Ok(s) => s,
Err(EnvelopeError::NotUtf8) => {
warn!(subject = %subject, "dropping non-utf8 jetstream payload");
continue;
}
Err(EnvelopeError::NotJson(e)) => {
warn!(
subject = %subject,
error = %e,
"dropping non-json jetstream payload",
);
continue;
}
};
state.bump_cursor(seq);
match tokio::time::timeout(WS_SEND_TIMEOUT, tx.send(Message::Text(payload))).await {
Ok(Ok(())) => {}
Ok(Err(_)) => {
debug!("ws send failed; client gone");
break;
}
Err(_) => {
warn!(seq, "ws send timed out after {:?}; closing", WS_SEND_TIMEOUT);
break;
}
}
if let Err(e) = m.ack().await {
debug!(seq, error = %e, "jetstream ack failed (AckPolicy::None)");
}
}
Some(Err(e)) => {
warn!(error = %e, "jetstream message error; closing ws");
break;
}
None => {
debug!("jetstream message stream ended");
break;
}
}
}
}
}
info!(subject = %subject, "ws client disconnected");
}
/// Tells the client its `since` cursor is older than the stream's retained data,
/// then closes the socket.
///
/// Sends a problem-style JSON text frame (its `oldest_seq` is `null` when
/// `oldest_seq` is `None`), followed by a Close frame with application code 4410
/// and reason `"retention-exhausted"`. Both sends are attempted in order and
/// their errors are ignored.
async fn close_retention_exhausted(mut socket: WebSocket, oldest_seq: Option<u64>) {
    use axum::extract::ws::CloseFrame;

    let problem = serde_json::json!({
        "type": "/problems/ws/retention-exhausted",
        "title": "Cursor older than stream retention",
        "oldest_seq": oldest_seq,
    })
    .to_string();
    let close = CloseFrame {
        code: 4410,
        reason: "retention-exhausted".into(),
    };

    for frame in [Message::Text(problem), Message::Close(Some(close))] {
        let _ = socket.send(frame).await;
    }
}
/// Wraps a raw JetStream payload in the WebSocket frame envelope
/// `{"seq": <seq>, "event": <payload JSON>}` and returns it serialized.
///
/// # Errors
/// - [`EnvelopeError::NotUtf8`] if `payload` is not valid UTF-8.
/// - [`EnvelopeError::NotJson`] if `payload` is UTF-8 but not valid JSON.
pub(crate) fn build_envelope(seq: u64, payload: &[u8]) -> Result<String, EnvelopeError> {
    let Ok(text) = std::str::from_utf8(payload) else {
        return Err(EnvelopeError::NotUtf8);
    };
    // Parsing both validates the payload and embeds it as a structured JSON value
    // rather than as a quoted string.
    let event = serde_json::from_str::<serde_json::Value>(text).map_err(EnvelopeError::NotJson)?;
    let wrapped = serde_json::json!({ "seq": seq, "event": event });
    Ok(wrapped.to_string())
}
/// Reason a JetStream payload could not be wrapped by [`build_envelope`].
#[derive(Debug)]
pub(crate) enum EnvelopeError {
    /// The payload bytes are not valid UTF-8.
    NotUtf8,
    /// The payload is UTF-8 but does not parse as JSON; carries the parser error.
    NotJson(serde_json::Error),
}

impl std::fmt::Display for EnvelopeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The underlying parser error is exposed through `source()`, so it is not
        // repeated here (avoids duplicated text in error-chain reports).
        match self {
            Self::NotUtf8 => f.write_str("payload is not valid UTF-8"),
            Self::NotJson(_) => f.write_str("payload is not valid JSON"),
        }
    }
}

impl std::error::Error for EnvelopeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NotUtf8 => None,
            Self::NotJson(e) => Some(e),
        }
    }
}
/// Helper for ending a connection with a single Close frame.
trait CloseExt {
    /// Sends a Close frame carrying `reason`, consuming the socket.
    async fn send_close_with_reason(self, reason: &'static str) -> Result<(), axum::Error>;
}

impl CloseExt for WebSocket {
    /// Every reason is sent with close code `close_code::POLICY`.
    async fn send_close_with_reason(mut self, reason: &'static str) -> Result<(), axum::Error> {
        use axum::extract::ws::{close_code, CloseFrame};

        let frame = CloseFrame {
            code: close_code::POLICY,
            reason: reason.into(),
        };
        self.send(Message::Close(Some(frame))).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed CloudEvent is wrapped with its cursor `seq`, and the event is
    /// kept as structured JSON rather than a string.
    #[test]
    fn ws_envelope_carries_seq() {
        let payload = serde_json::to_vec(&serde_json::json!({
            "specversion": "1.0",
            "type": "io.cellos.formation.v1.created",
            "source": "/formations/abc",
            "id": "evt-1",
            "data": { "name": "demo" }
        }))
        .unwrap();

        let frame = build_envelope(42, &payload).expect("envelope build");
        let parsed: serde_json::Value = serde_json::from_str(&frame).unwrap();
        let seq = &parsed["seq"];
        let event = &parsed["event"];

        assert_eq!(
            seq.as_u64(),
            Some(42),
            "envelope must carry the seq as the cursor field; got {}",
            seq
        );
        assert!(
            event.is_object(),
            "event must be a structured JSON object, not a string-of-JSON; got {}",
            event
        );
        assert_eq!(event["type"], "io.cellos.formation.v1.created");
        assert_eq!(event["data"]["name"], "demo");
    }

    /// Bytes that are not UTF-8 are rejected before any JSON parsing.
    #[test]
    fn ws_envelope_rejects_non_utf8_payload() {
        let result = build_envelope(1, &[0xff, 0xfe, 0xfd]);
        assert!(
            matches!(result, Err(EnvelopeError::NotUtf8)),
            "expected NotUtf8, got {result:?}"
        );
    }

    /// Valid UTF-8 that is not JSON is rejected with the parser error attached.
    #[test]
    fn ws_envelope_rejects_non_json_payload() {
        let result = build_envelope(1, b"hello, world");
        assert!(
            matches!(result, Err(EnvelopeError::NotJson(_))),
            "expected NotJson, got {result:?}"
        );
    }
}