use crate::api::{QueueRateTracker, StorageApi};
use crate::message::Priority;
use axum::extract::ws::{CloseFrame, Message as WsMessage, WebSocket};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicI64, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, Notify};
use uuid::Uuid;
/// Returns the interval between server-initiated WebSocket pings.
///
/// The value is read from the `WS_PING_INTERVAL_SECS` environment variable on
/// every call. The 30-second default is used when the variable is unset, does
/// not parse as a `u64`, or is `0`.
///
/// Zero is rejected because the result is used as the period of
/// `tokio::time::interval`, which panics when given a zero duration.
fn ping_interval() -> Duration {
    const DEFAULT_SECS: u64 = 30;
    let secs = std::env::var("WS_PING_INTERVAL_SECS")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        // A zero period would panic the ticker; treat it as misconfiguration.
        .filter(|&secs| secs > 0)
        .unwrap_or(DEFAULT_SECS);
    Duration::from_secs(secs)
}
/// Returns the extra grace period allowed for a pong to arrive.
///
/// Read from the `WS_PING_TIMEOUT_SECS` environment variable on every call;
/// defaults to 30 seconds when the variable is unset or is not a valid `u64`.
fn ping_timeout() -> Duration {
    let configured = match std::env::var("WS_PING_TIMEOUT_SECS") {
        Ok(raw) => raw.parse::<u64>().ok(),
        Err(_) => None,
    };
    Duration::from_secs(configured.unwrap_or(30))
}
/// Capacity of the per-connection channel that feeds the socket writer task.
/// Frames offered while the channel is full are dropped and logged.
const OUTBOUND_CHANNEL_CAPACITY: usize = 512;
/// A message handed from a subscription task to the connection loop, to be
/// sent to the client as a `"deliver"` frame.
struct Delivery {
/// Queue name copied from the stored message.
queue: String,
/// Identifier of the stored message.
id: String,
/// Message payload.
payload: String,
/// Message priority, serialized as-is into the outgoing frame.
priority: Priority,
/// Creation timestamp of the message, formatted as RFC 3339.
created_at: String,
}
impl Delivery {
fn to_frame(&self) -> WsMessage {
let v = json!({
"type": "deliver",
"queue": self.queue,
"id": self.id,
"payload": self.payload,
"priority": self.priority,
"created_at": self.created_at,
});
WsMessage::Text(v.to_string().into())
}
}
/// Bookkeeping for one active queue subscription on a connection.
struct SubscriptionState {
/// Handle used to abort the background delivery task.
abort_handle: tokio::task::AbortHandle,
/// Remaining delivery credits; a negative value means deliveries are not
/// credit-limited.
credits: Arc<AtomicI64>,
/// Wakes the delivery task when it is parked waiting for credits.
credits_notify: Arc<Notify>,
/// Wakes the delivery task early when it is idling after an empty pop.
message_notify: Arc<Notify>,
}
/// The set of queue subscriptions owned by a single connection.
struct SubscriptionRegistry {
/// Active subscriptions keyed by queue name.
handles: HashMap<String, SubscriptionState>,
}
impl SubscriptionRegistry {
fn new() -> Self {
Self {
handles: HashMap::new(),
}
}
fn subscribe(
&mut self,
queue: String,
storage: Arc<dyn StorageApi>,
conn_id: String,
tx: mpsc::Sender<Delivery>,
credits: Option<u64>,
lock_timeout_secs: u64,
) {
if self.handles.contains_key(&queue) {
return; }
let credit_count = Arc::new(AtomicI64::new(
credits.map(|c| c as i64).unwrap_or(-1), ));
let credits_notify = Arc::new(Notify::new());
let message_notify = Arc::new(Notify::new());
let cc = Arc::clone(&credit_count);
let cn = Arc::clone(&credits_notify);
let mn = Arc::clone(&message_notify);
let q = queue.clone();
let handle = tokio::spawn(async move {
loop {
let remaining = cc.load(AtomicOrdering::Acquire);
if remaining == 0 {
cn.notified().await;
continue;
}
match storage.pop(&q, &conn_id, lock_timeout_secs).await {
Ok(Some(msg)) => {
let delivery = Delivery {
queue: msg.queue.clone(),
id: msg.id.clone(),
payload: msg.payload.clone(),
priority: msg.priority,
created_at: msg.created_at.to_rfc3339(),
};
if tx.send(delivery).await.is_err() {
break; }
if remaining > 0 {
cc.fetch_sub(1, AtomicOrdering::Release);
}
}
Ok(None) => {
tokio::time::timeout(Duration::from_millis(500), mn.notified())
.await
.ok();
}
Err(_) => {
tokio::time::sleep(Duration::from_millis(500)).await;
}
}
}
})
.abort_handle();
self.handles.insert(
queue,
SubscriptionState {
abort_handle: handle,
credits: credit_count,
credits_notify,
message_notify,
},
);
}
fn unsubscribe(&mut self, queue: &str) {
if let Some(state) = self.handles.remove(queue) {
state.abort_handle.abort();
}
}
fn add_credits(&self, queue: &str, n: u64) -> bool {
if let Some(state) = self.handles.get(queue) {
let current = state.credits.load(AtomicOrdering::Acquire);
if current >= 0 {
state.credits.fetch_add(n as i64, AtomicOrdering::Release);
state.credits_notify.notify_one();
}
true
} else {
false
}
}
fn notify_message_available(&self, queue: &str) {
if let Some(state) = self.handles.get(queue) {
state.message_notify.notify_one();
}
}
}
impl Drop for SubscriptionRegistry {
    /// Aborts every delivery task still registered so none outlives the
    /// registry.
    fn drop(&mut self) {
        self.handles
            .values()
            .for_each(|subscription| subscription.abort_handle.abort());
    }
}
/// Drives a single WebSocket connection from open to close.
///
/// The socket is split into a reader, polled by the main loop below, and a
/// writer owned by a dedicated task fed through a bounded channel. The main
/// loop multiplexes four event sources:
///
/// * inbound frames from the client (JSON text requests, pings, pongs, close),
/// * deliveries produced by queue subscription tasks,
/// * log entries produced by the optional log-streaming task,
/// * a periodic ping ticker that also enforces the pong timeout.
///
/// Outbound frames are enqueued without blocking; when the outbound channel
/// is full the frame is dropped and a warning is logged.
///
/// On exit the subscription tasks are aborted first, then the outbound
/// channel is closed, the log task is aborted, and the writer task is awaited
/// so already-queued frames (such as a Close frame) get a chance to be sent.
pub(crate) async fn handle_connection(
    ws: WebSocket,
    storage: Arc<dyn StorageApi>,
    rate_tracker: Arc<QueueRateTracker>,
    log_buffer: Option<crate::log_buffer::LogBuffer>,
    memory_pressure: Arc<std::sync::atomic::AtomicBool>,
) {
    let conn_id = Uuid::new_v4().to_string();
    tracing::info!(conn_id = %conn_id, "ws: connection opened");
    let conn_started = Instant::now();

    let (ws_sender, mut receiver) = ws.split();
    let (outbound_tx, mut outbound_rx) = mpsc::channel::<WsMessage>(OUTBOUND_CHANNEL_CAPACITY);

    // Writer task: the only owner of the socket's write half.
    let sender_conn_id = conn_id.clone();
    let sender_handle = tokio::spawn(async move {
        let mut sender = ws_sender;
        let mut frames_sent: u64 = 0;
        while let Some(frame) = outbound_rx.recv().await {
            if let Err(e) = sender.send(frame).await {
                tracing::warn!(
                    conn_id = %sender_conn_id,
                    frames_sent,
                    "ws: outbound sender: socket write error, exiting: {e}"
                );
                break;
            }
            frames_sent += 1;
        }
        if let Err(e) = sender.flush().await {
            tracing::debug!(
                conn_id = %sender_conn_id,
                "ws: outbound sender: flush on exit failed: {e}"
            );
        }
    });

    // Non-blocking enqueue of an outbound frame; logs instead of failing.
    macro_rules! send_frame {
        ($frame:expr) => {
            if let Err(__send_err) = outbound_tx.try_send($frame) {
                use tokio::sync::mpsc::error::TrySendError;
                match __send_err {
                    TrySendError::Full(__dropped) => {
                        let __kind = match &__dropped {
                            WsMessage::Text(_) => "text",
                            WsMessage::Binary(_) => "binary",
                            WsMessage::Ping(_) => "ping",
                            WsMessage::Pong(_) => "pong",
                            WsMessage::Close(_) => "close",
                        };
                        tracing::warn!(
                            conn_id = %conn_id,
                            frame_type = __kind,
                            capacity = OUTBOUND_CHANNEL_CAPACITY,
                            "ws: outbound channel full, dropped frame"
                        );
                    }
                    TrySendError::Closed(_) => {
                        tracing::debug!(
                            conn_id = %conn_id,
                            "ws: outbound channel closed, frame not sent"
                        );
                    }
                }
            }
        };
    }

    let (deliver_tx, mut deliver_rx) = mpsc::channel::<Delivery>(256);
    let (log_tx, mut log_rx) = mpsc::channel::<crate::log_buffer::LogEntry>(256);

    let ping_iv = ping_interval();
    let ping_to = ping_timeout();
    let mut ping_ticker = tokio::time::interval(ping_iv);
    ping_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    // The first tick of an interval completes immediately; consume it so the
    // first ping goes out one full interval after the connection opens.
    ping_ticker.tick().await;
    let mut last_pong = Instant::now();

    let mut subscriptions = SubscriptionRegistry::new();
    let mut log_task: Option<tokio::task::JoinHandle<()>> = None;

    loop {
        tokio::select! {
            msg = receiver.next() => {
                let ws_msg = match msg {
                    Some(Ok(m)) => m,
                    Some(Err(e)) => {
                        tracing::warn!(
                            conn_id = %conn_id,
                            uptime_secs = conn_started.elapsed().as_secs_f64(),
                            "ws: receiver error, closing: {e}"
                        );
                        break;
                    }
                    None => {
                        tracing::info!(
                            conn_id = %conn_id,
                            uptime_secs = conn_started.elapsed().as_secs_f64(),
                            "ws: receiver stream ended (EOF)"
                        );
                        break;
                    }
                };
                match ws_msg {
                    WsMessage::Binary(_) => {
                        tracing::warn!(
                            conn_id = %conn_id,
                            "ws: received binary frame, closing with 1003"
                        );
                        send_frame!(WsMessage::Close(Some(CloseFrame {
                            code: axum::extract::ws::close_code::UNSUPPORTED,
                            reason: "binary frames not supported".into(),
                        })));
                        break;
                    }
                    WsMessage::Text(text) => {
                        let frame: Value = match serde_json::from_str(&text) {
                            Ok(v) => v,
                            Err(e) => {
                                tracing::warn!(
                                    conn_id = %conn_id,
                                    "ws: invalid JSON, closing with 1007: {e}"
                                );
                                send_frame!(WsMessage::Close(Some(CloseFrame {
                                    code: axum::extract::ws::close_code::INVALID,
                                    reason: "invalid JSON".into(),
                                })));
                                break;
                            }
                        };
                        let req_id = frame.get("req_id").and_then(|v| v.as_str()).map(str::to_owned);
                        let no_reply = frame.get("no_reply").and_then(|v| v.as_bool()).unwrap_or(false);
                        let frame_type = frame.get("type").and_then(|v| v.as_str()).unwrap_or("");
                        // Log streaming needs connection-local state, so it is
                        // handled here; everything else goes through `dispatch`.
                        let out = match frame_type {
                            "subscribe-logs" => {
                                handle_subscribe_logs(&log_buffer, &mut log_task, &log_tx).await
                            }
                            "unsubscribe-logs" => {
                                handle_unsubscribe_logs(&mut log_task)
                            }
                            _ => {
                                dispatch(
                                    &frame,
                                    &storage,
                                    &rate_tracker,
                                    &mut subscriptions,
                                    &conn_id,
                                    &deliver_tx,
                                    &memory_pressure,
                                )
                                .await
                            }
                        };
                        if !no_reply {
                            let out = attach_req_id(out, req_id);
                            send_frame!(WsMessage::Text(out.to_string().into()));
                        }
                    }
                    WsMessage::Ping(data) => {
                        send_frame!(WsMessage::Pong(data));
                    }
                    WsMessage::Pong(_) => {
                        last_pong = Instant::now();
                    }
                    WsMessage::Close(cf) => {
                        match cf {
                            Some(cf) => tracing::info!(
                                conn_id = %conn_id,
                                uptime_secs = conn_started.elapsed().as_secs_f64(),
                                "ws: client Close frame: code={} reason='{}'",
                                cf.code,
                                cf.reason
                            ),
                            None => tracing::info!(
                                conn_id = %conn_id,
                                uptime_secs = conn_started.elapsed().as_secs_f64(),
                                "ws: client Close frame (no payload)"
                            ),
                        }
                        break;
                    }
                }
            }
            Some(delivery) = deliver_rx.recv() => {
                send_frame!(delivery.to_frame());
            }
            Some(entry) = log_rx.recv() => {
                let frame = json!({
                    "type": "log",
                    "timestamp": entry.timestamp,
                    "level": entry.level,
                    "message": entry.message,
                });
                send_frame!(WsMessage::Text(frame.to_string().into()));
            }
            _ = ping_ticker.tick() => {
                // A pong is overdue once a full interval plus the timeout has
                // passed since the last one.
                if last_pong.elapsed() > ping_iv + ping_to {
                    tracing::warn!(
                        conn_id = %conn_id,
                        uptime_secs = conn_started.elapsed().as_secs_f64(),
                        last_pong_secs_ago = last_pong.elapsed().as_secs_f64(),
                        "ws: ping timeout, closing with 1001"
                    );
                    send_frame!(WsMessage::Close(Some(CloseFrame {
                        code: axum::extract::ws::close_code::AWAY,
                        reason: "ping timeout".into(),
                    })));
                    break;
                }
                tracing::debug!(
                    conn_id = %conn_id,
                    last_pong_secs_ago = last_pong.elapsed().as_secs_f64(),
                    "ws: sending ping"
                );
                send_frame!(WsMessage::Ping(vec![].into()));
            }
        }
    }

    let subscription_count = subscriptions.handles.len();
    tracing::info!(
        conn_id = %conn_id,
        uptime_secs = conn_started.elapsed().as_secs_f64(),
        subscriptions = subscription_count,
        "ws: connection main loop exited, starting shutdown"
    );

    // Abort the delivery tasks before waiting on the writer: nobody reads
    // their output any more, and while they run they keep popping (and
    // locking) messages on behalf of a connection that is already closing.
    drop(subscriptions);

    // Closing the channel lets the writer task drain what is queued and exit.
    drop(outbound_tx);
    if let Some(task) = log_task.take() {
        task.abort();
    }
    let _ = sender_handle.await;

    tracing::info!(
        conn_id = %conn_id,
        total_uptime_secs = conn_started.elapsed().as_secs_f64(),
        "ws: connection closed"
    );
}
async fn dispatch(
frame: &Value,
storage: &Arc<dyn StorageApi>,
rate_tracker: &Arc<QueueRateTracker>,
subs: &mut SubscriptionRegistry,
conn_id: &str,
deliver_tx: &mpsc::Sender<Delivery>,
memory_pressure: &std::sync::atomic::AtomicBool,
) -> Value {
let frame_type = match frame.get("type").and_then(|v| v.as_str()) {
Some(t) => t,
None => {
return error_frame("invalid_request", "missing 'type' field");
}
};
match frame_type {
"publish" => handle_publish(frame, storage, rate_tracker, memory_pressure, subs).await,
"subscribe" => handle_subscribe(frame, storage, subs, conn_id, deliver_tx).await,
"unsubscribe" => handle_unsubscribe(frame, subs),
"ack" => handle_ack(frame, storage, rate_tracker, conn_id).await,
"nack" => handle_nack(frame, storage, rate_tracker, conn_id).await,
"nack-with-delay" => handle_nack_with_delay(frame, storage, rate_tracker, conn_id).await,
"batch-ack" => handle_batch_ack(frame, storage, rate_tracker, conn_id).await,
"batch-nack" => handle_batch_nack(frame, storage, rate_tracker, conn_id).await,
"renew" => handle_renew(frame, storage, conn_id).await,
"credit" => handle_credit(frame, subs),
other => error_frame("unknown_type", &format!("unknown frame type '{}'", other)),
}
}
async fn handle_publish(
frame: &Value,
storage: &Arc<dyn StorageApi>,
rate_tracker: &Arc<QueueRateTracker>,
memory_pressure: &std::sync::atomic::AtomicBool,
subs: &SubscriptionRegistry,
) -> Value {
if memory_pressure.load(std::sync::atomic::Ordering::Relaxed) {
return error_frame(
"memory_pressure",
"Server is under memory pressure. Try again later.",
);
}
let queue = match frame.get("queue").and_then(|v| v.as_str()) {
Some(q) => q.to_owned(),
None => return error_frame("invalid_request", "publish requires 'queue' field"),
};
let payload = match frame.get("payload").and_then(|v| v.as_str()) {
Some(p) => p.to_owned(),
None => return error_frame("invalid_request", "publish requires 'payload' field"),
};
let priority = match frame.get("priority") {
Some(serde_json::Value::String(s)) => Priority::Text(s.clone()),
Some(v) => Priority::Numeric(v.as_u64().unwrap_or(0)),
None => Priority::Numeric(0),
};
let max_retries = frame
.get("max_retries")
.and_then(|v| v.as_u64())
.map(|v| v as u32);
match storage.queue_exists(&queue).await {
Ok(true) => {}
Ok(false) => {
return error_frame("queue_not_found", &format!("queue '{}' not found", queue))
}
Err(e) => return error_frame("storage_error", &e.to_string()),
}
let msg = crate::message::Message {
id: Uuid::new_v4().to_string(),
queue: queue.clone(),
priority,
payload,
created_at: chrono::Utc::now(),
locked_until: None,
locked_by: None,
retry_count: 0,
max_retries: max_retries.unwrap_or(3),
payload_ref: None,
payload_hash: None,
};
match storage.push(msg).await {
Ok(id) => {
rate_tracker.record_publish(&queue);
subs.notify_message_available(&queue);
json!({"type": "ok", "id": id})
}
Err(e) => error_frame("publish_failed", &e.to_string()),
}
}
async fn handle_subscribe(
frame: &Value,
storage: &Arc<dyn StorageApi>,
subs: &mut SubscriptionRegistry,
conn_id: &str,
deliver_tx: &mpsc::Sender<Delivery>,
) -> Value {
let queue = match frame.get("queue").and_then(|v| v.as_str()) {
Some(q) => q.to_owned(),
None => return error_frame("invalid_request", "subscribe requires 'queue' field"),
};
let credits = frame.get("credits").and_then(|v| v.as_u64());
let lock_timeout_secs = frame
.get("lock_timeout_secs")
.and_then(|v| v.as_u64())
.unwrap_or(30);
match storage.queue_exists(&queue).await {
Ok(true) => {}
Ok(false) => {
return error_frame("queue_not_found", &format!("queue '{}' not found", queue))
}
Err(e) => return error_frame("storage_error", &e.to_string()),
}
subs.subscribe(
queue,
storage.clone(),
conn_id.to_owned(),
deliver_tx.clone(),
credits,
lock_timeout_secs,
);
json!({"type": "ok"})
}
fn handle_unsubscribe(frame: &Value, subs: &mut SubscriptionRegistry) -> Value {
let queue = match frame.get("queue").and_then(|v| v.as_str()) {
Some(q) => q,
None => return error_frame("invalid_request", "unsubscribe requires 'queue' field"),
};
subs.unsubscribe(queue);
json!({"type": "ok"})
}
/// Handles an `ack` frame for a single message.
///
/// Requires `queue` and `id`. Replies `{"type": "ok"}` whether or not storage
/// reports that the message was acknowledged; the ack is recorded in the rate
/// tracker only when it was. A storage failure yields `ack_failed`.
async fn handle_ack(
    frame: &Value,
    storage: &Arc<dyn StorageApi>,
    rate_tracker: &Arc<QueueRateTracker>,
    conn_id: &str,
) -> Value {
    let (queue, id) = match required_queue_id(frame) {
        Ok(pair) => pair,
        Err(reply) => return reply,
    };
    let acked = match storage.ack(&queue, &id, conn_id).await {
        Ok(acked) => acked,
        Err(e) => return error_frame("ack_failed", &e.to_string()),
    };
    if acked {
        rate_tracker.record_ack(&queue);
    }
    json!({"type": "ok"})
}
/// Handles a `nack` frame for a single message.
///
/// Requires `queue` and `id`. Replies `{"type": "ok"}` whether or not storage
/// reports that the message was nacked; the nack is recorded in the rate
/// tracker only when it was. A storage failure yields `nack_failed`.
async fn handle_nack(
    frame: &Value,
    storage: &Arc<dyn StorageApi>,
    rate_tracker: &Arc<QueueRateTracker>,
    conn_id: &str,
) -> Value {
    let (queue, id) = match required_queue_id(frame) {
        Ok(pair) => pair,
        Err(reply) => return reply,
    };
    let nacked = match storage.nack(&queue, &id, conn_id).await {
        Ok(nacked) => nacked,
        Err(e) => return error_frame("nack_failed", &e.to_string()),
    };
    if nacked {
        rate_tracker.record_nack(&queue);
    }
    json!({"type": "ok"})
}
async fn handle_nack_with_delay(
frame: &Value,
storage: &Arc<dyn StorageApi>,
rate_tracker: &Arc<QueueRateTracker>,
conn_id: &str,
) -> Value {
let (queue, id) = match required_queue_id(frame) {
Ok(v) => v,
Err(e) => return e,
};
let delay_secs = match frame.get("delay_secs").and_then(|v| v.as_u64()) {
Some(n) => n,
None => {
return error_frame(
"invalid_request",
"nack-with-delay requires non-negative integer 'delay_secs' field",
);
}
};
match storage
.nack_with_delay(&queue, &id, conn_id, delay_secs)
.await
{
Ok(true) => {
rate_tracker.record_nack(&queue);
json!({"type": "ok"})
}
Ok(false) => json!({"type": "ok"}),
Err(e) => error_frame("nack_with_delay_failed", &e.to_string()),
}
}
async fn handle_batch_ack(
frame: &Value,
storage: &Arc<dyn StorageApi>,
rate_tracker: &Arc<QueueRateTracker>,
conn_id: &str,
) -> Value {
let queue = match frame.get("queue").and_then(|v| v.as_str()) {
Some(q) => q.to_owned(),
None => return error_frame("invalid_request", "batch-ack requires 'queue' field"),
};
let ids: Vec<String> = match frame.get("ids").and_then(|v| v.as_array()) {
Some(arr) => arr
.iter()
.filter_map(|v| v.as_str().map(str::to_owned))
.collect(),
None => return error_frame("invalid_request", "batch-ack requires 'ids' array"),
};
match storage.batch_ack(&queue, conn_id, &ids).await {
Ok(result) => {
for _ in &result.acked {
rate_tracker.record_ack(&queue);
}
json!({"type": "ok", "acked": result.acked.len()})
}
Err(e) => error_frame("batch_ack_failed", &e.to_string()),
}
}
async fn handle_batch_nack(
frame: &Value,
storage: &Arc<dyn StorageApi>,
rate_tracker: &Arc<QueueRateTracker>,
conn_id: &str,
) -> Value {
let queue = match frame.get("queue").and_then(|v| v.as_str()) {
Some(q) => q.to_owned(),
None => return error_frame("invalid_request", "batch-nack requires 'queue' field"),
};
let ids: Vec<String> = match frame.get("ids").and_then(|v| v.as_array()) {
Some(arr) => arr
.iter()
.filter_map(|v| v.as_str().map(str::to_owned))
.collect(),
None => return error_frame("invalid_request", "batch-nack requires 'ids' array"),
};
match storage.batch_nack(&queue, conn_id, &ids).await {
Ok(result) => {
let nacked = result.unlocked.len() + result.dead_lettered.len() + result.dropped.len();
for _ in 0..nacked {
rate_tracker.record_nack(&queue);
}
json!({
"type": "ok",
"unlocked": result.unlocked.len(),
"dropped": result.dead_lettered.len() + result.dropped.len(),
})
}
Err(e) => error_frame("batch_nack_failed", &e.to_string()),
}
}
async fn handle_renew(frame: &Value, storage: &Arc<dyn StorageApi>, conn_id: &str) -> Value {
let (queue, id) = match required_queue_id(frame) {
Ok(v) => v,
Err(e) => return e,
};
let lock_timeout_secs = frame
.get("lock_timeout_secs")
.and_then(|v| v.as_u64())
.unwrap_or(30);
match storage.renew(&queue, &id, conn_id, lock_timeout_secs).await {
Ok(true) => json!({"type": "ok"}),
Ok(false) => error_frame("not_locked", "message not locked by this consumer"),
Err(e) => error_frame("renew_failed", &e.to_string()),
}
}
fn handle_credit(frame: &Value, subs: &SubscriptionRegistry) -> Value {
let queue = match frame.get("queue").and_then(|v| v.as_str()) {
Some(q) => q,
None => return error_frame("invalid_request", "credit requires 'queue' field"),
};
let credits = match frame.get("credits").and_then(|v| v.as_u64()) {
Some(c) if c > 0 => c,
_ => {
return error_frame(
"invalid_request",
"credit requires positive 'credits' field",
)
}
};
if subs.add_credits(queue, credits) {
json!({"type": "ok"})
} else {
error_frame(
"not_subscribed",
&format!("not subscribed to queue '{}'", queue),
)
}
}
fn error_frame(code: &str, message: &str) -> Value {
json!({"type": "error", "code": code, "message": message})
}
/// Copies the client's `req_id`, when one was supplied, into the reply frame.
///
/// The frame is returned unchanged if `req_id` is `None` or the frame is not
/// a JSON object.
fn attach_req_id(mut frame: Value, req_id: Option<String>) -> Value {
    if let (Some(rid), Some(fields)) = (req_id, frame.as_object_mut()) {
        fields.insert(String::from("req_id"), Value::String(rid));
    }
    frame
}
async fn handle_subscribe_logs(
log_buffer: &Option<crate::log_buffer::LogBuffer>,
log_task: &mut Option<tokio::task::JoinHandle<()>>,
log_tx: &mpsc::Sender<crate::log_buffer::LogEntry>,
) -> Value {
let buf = match log_buffer {
Some(b) => b.clone(),
None => return error_frame("not_available", "log streaming not configured"),
};
if let Some(task) = log_task.take() {
task.abort();
}
let tx = log_tx.clone();
let mut rx = buf.subscribe();
*log_task = Some(tokio::spawn(async move {
let history = buf.history_sync();
for entry in history {
if tx.send(entry).await.is_err() {
return; }
}
loop {
match rx.recv().await {
Ok(entry) => {
if tx.send(entry).await.is_err() {
break; }
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
}));
json!({"type": "ok"})
}
/// Handles an `unsubscribe-logs` frame by aborting the log forwarding task,
/// if one is running. Always replies `{"type": "ok"}`.
fn handle_unsubscribe_logs(log_task: &mut Option<tokio::task::JoinHandle<()>>) -> Value {
    if let Some(running) = std::mem::take(log_task) {
        running.abort();
    }
    json!({"type": "ok"})
}
fn required_queue_id(frame: &Value) -> Result<(String, String), Value> {
let queue = frame
.get("queue")
.and_then(|v| v.as_str())
.ok_or_else(|| error_frame("invalid_request", "requires 'queue' field"))?
.to_owned();
let id = frame
.get("id")
.and_then(|v| v.as_str())
.ok_or_else(|| error_frame("invalid_request", "requires 'id' field"))?
.to_owned();
Ok((queue, id))
}