use crate::error::QrustyClientError;
use crate::priority::Priority;
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::task::JoinHandle;
use tokio_tungstenite::{
connect_async, tungstenite::Message as TMsg, MaybeTlsStream, WebSocketStream,
};
/// The underlying tungstenite WebSocket stream (plain or TLS over TCP).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of the WebSocket, used to send outbound frames.
type WsSink = SplitSink<WsStream, TMsg>;
/// In-flight requests keyed by `req_id`; the router task completes each
/// oneshot with the matching response frame (or an error).
type PendingMap = Arc<Mutex<HashMap<String, oneshot::Sender<Result<Value, QrustyClientError>>>>>;
/// Active subscriptions: queue name -> delivery channel sender.
type SubMap =
    Arc<Mutex<HashMap<String, mpsc::Sender<Result<DeliveredMessage, QrustyClientError>>>>>;
/// (queue name, parsed delivery) pair handed from the router to the delivery worker.
type DeliveryEnvelope = (String, Result<DeliveredMessage, QrustyClientError>);
/// Channel for server log frames; `None` while not subscribed to logs.
type LogSender = Arc<Mutex<Option<mpsc::Sender<LogEntry>>>>;
/// A single server-side log line received via a `log` frame.
#[derive(Debug, Clone)]
pub struct LogEntry {
    // Server-reported timestamp, passed through verbatim.
    pub timestamp: String,
    // Log level string as reported by the server (e.g. "INFO", "WARN").
    pub level: String,
    // The log message text.
    pub message: String,
}
/// A message delivered to a subscriber, parsed from a `deliver` frame.
#[derive(Debug, Clone)]
pub struct DeliveredMessage {
    // Name of the queue the message was delivered from.
    pub queue: String,
    // Server-assigned message id (used for ack/nack/renew).
    pub id: String,
    // Message payload, treated as an opaque string.
    pub payload: String,
    // Priority deserialized from the frame's `priority` field.
    pub priority: Priority,
    // Server-reported creation timestamp, passed through verbatim.
    pub created_at: String,
}
impl TryFrom<&Value> for DeliveredMessage {
    type Error = QrustyClientError;
    /// Parses a `deliver` frame into a [`DeliveredMessage`].
    ///
    /// Fails with an invalid-response error when any required string field
    /// is missing or the `priority` value does not deserialize. Fields are
    /// checked in the same order as the struct declares them.
    fn try_from(v: &Value) -> Result<Self, Self::Error> {
        let queue = str_field(v, "queue")?;
        let id = str_field(v, "id")?;
        let payload = str_field(v, "payload")?;
        let raw_priority = v
            .get("priority")
            .cloned()
            .ok_or_else(|| invalid("missing priority"))?;
        let priority =
            serde_json::from_value(raw_priority).map_err(|_| invalid("invalid priority value"))?;
        let created_at = str_field(v, "created_at")?;
        Ok(DeliveredMessage {
            queue,
            id,
            payload,
            priority,
            created_at,
        })
    }
}
/// A client session over a single WebSocket connection to the server.
///
/// Owns the write half of the socket plus two background tasks: a router
/// that dispatches every inbound frame, and a delivery worker that fans
/// deliveries out to per-queue subscriber channels.
pub struct WsSession {
    // Write half of the WebSocket, shared so concurrent callers can send.
    sink: Arc<Mutex<WsSink>>,
    // Monotonic counter used to mint unique request ids.
    req_counter: Arc<AtomicU64>,
    // In-flight req_id -> response oneshot map, completed by the router.
    pending: PendingMap,
    // queue name -> delivery channel for active subscriptions.
    subscribers: SubMap,
    // Log-frame channel; `Some` only while subscribed to server logs.
    log_tx: LogSender,
    // Router task handle (the task detaches if the handle is dropped).
    _router: JoinHandle<()>,
    // Delivery-worker task handle (likewise detaches on drop).
    _delivery_worker: JoinHandle<()>,
    // Maximum time to wait for a response to any request frame.
    request_timeout: Duration,
}
impl WsSession {
pub async fn connect(addr: &str) -> Result<Self, QrustyClientError> {
Self::connect_with_timeout(addr, Duration::from_secs(30)).await
}
pub async fn connect_with_timeout(
addr: &str,
request_timeout: Duration,
) -> Result<Self, QrustyClientError> {
let url = format!("{}/ws", addr);
let (ws, _) = connect_async(&url)
.await
.map_err(|e| QrustyClientError::Other(format!("WebSocket connect failed: {}", e)))?;
let (sink, source) = ws.split();
let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
let subscribers: SubMap = Arc::new(Mutex::new(HashMap::new()));
let (delivery_tx, delivery_rx) = mpsc::unbounded_channel();
let log_tx: LogSender = Arc::new(Mutex::new(None));
let router = tokio::spawn(router_task(
source,
Arc::clone(&pending),
Arc::clone(&subscribers),
delivery_tx,
Arc::clone(&log_tx),
));
let delivery_worker = tokio::spawn(delivery_task(delivery_rx, Arc::clone(&subscribers)));
Ok(WsSession {
sink: Arc::new(Mutex::new(sink)),
req_counter: Arc::new(AtomicU64::new(0)),
pending,
subscribers,
log_tx,
_router: router,
_delivery_worker: delivery_worker,
request_timeout,
})
}
fn next_req_id(&self) -> String {
format!("req-{}", self.req_counter.fetch_add(1, Ordering::Relaxed))
}
pub async fn send_frame(&self, frame: Value) -> Result<(), QrustyClientError> {
self.sink
.lock()
.await
.send(TMsg::Text(frame.to_string().into()))
.await
.map_err(|e| QrustyClientError::Other(format!("send error: {}", e)))
}
    /// Tags `frame` with a fresh `req_id`, sends it, and awaits the matching
    /// response from the router, bounded by `self.request_timeout`.
    ///
    /// The pending entry is registered *before* the frame is sent so a fast
    /// response cannot race past the waiter; it is removed again on send
    /// failure or timeout so the map does not leak entries.
    async fn request(&self, mut frame: Value) -> Result<Value, QrustyClientError> {
        let req_id = self.next_req_id();
        // Every internal caller builds `frame` with json!({...}), so it is
        // always a JSON object here.
        frame
            .as_object_mut()
            .unwrap()
            .insert("req_id".to_owned(), Value::String(req_id.clone()));
        let (tx, rx) = oneshot::channel();
        self.pending.lock().await.insert(req_id.clone(), tx);
        if let Err(e) = self.send_frame(frame).await {
            // The frame never went out; undo the pending registration.
            self.pending.lock().await.remove(&req_id);
            return Err(e);
        }
        match tokio::time::timeout(self.request_timeout, rx).await {
            Ok(Ok(result)) => result,
            // Sender dropped: the router exited and drained `pending`.
            Ok(Err(_)) => Err(QrustyClientError::Other(
                "connection closed before response".into(),
            )),
            Err(_) => {
                // Timed out: remove the entry so a late reply is ignored.
                self.pending.lock().await.remove(&req_id);
                Err(QrustyClientError::RequestTimeout(self.request_timeout))
            }
        }
    }
pub async fn publish(
&self,
queue: &str,
payload: &str,
priority: Option<Priority>,
) -> Result<String, QrustyClientError> {
let priority = priority.unwrap_or_default();
let resp = self
.request(json!({
"type": "publish",
"queue": queue,
"payload": payload,
"priority": priority,
}))
.await?;
resp["id"]
.as_str()
.map(str::to_owned)
.ok_or_else(|| invalid("missing id in publish response"))
}
pub async fn subscribe(
&self,
queue: &str,
) -> Result<mpsc::Receiver<Result<DeliveredMessage, QrustyClientError>>, QrustyClientError>
{
self.subscribe_with_credits(queue, None).await
}
pub async fn subscribe_with_credits(
&self,
queue: &str,
credits: Option<u64>,
) -> Result<mpsc::Receiver<Result<DeliveredMessage, QrustyClientError>>, QrustyClientError>
{
let (tx, rx) = mpsc::channel::<Result<DeliveredMessage, QrustyClientError>>(256);
self.subscribers.lock().await.insert(queue.to_owned(), tx);
let mut frame = json!({"type": "subscribe", "queue": queue});
if let Some(c) = credits {
frame
.as_object_mut()
.unwrap()
.insert("credits".to_owned(), json!(c));
}
if let Err(e) = self.request(frame).await {
self.subscribers.lock().await.remove(queue);
return Err(e);
}
Ok(rx)
}
pub async fn unsubscribe(&self, queue: &str) -> Result<(), QrustyClientError> {
self.subscribers.lock().await.remove(queue);
self.request(json!({"type": "unsubscribe", "queue": queue}))
.await
.map(|_| ())
}
pub async fn ack(&self, queue: &str, id: &str) -> Result<(), QrustyClientError> {
self.request(json!({"type": "ack", "queue": queue, "id": id}))
.await
.map(|_| ())
}
pub async fn nack(&self, queue: &str, id: &str) -> Result<(), QrustyClientError> {
self.request(json!({"type": "nack", "queue": queue, "id": id}))
.await
.map(|_| ())
}
pub async fn ack_noreply(&self, queue: &str, id: &str) -> Result<(), QrustyClientError> {
self.send_frame(json!({
"type": "ack",
"queue": queue,
"id": id,
"no_reply": true,
}))
.await
}
pub async fn nack_noreply(&self, queue: &str, id: &str) -> Result<(), QrustyClientError> {
self.send_frame(json!({
"type": "nack",
"queue": queue,
"id": id,
"no_reply": true,
}))
.await
}
pub async fn batch_ack(&self, queue: &str, ids: &[&str]) -> Result<usize, QrustyClientError> {
let resp = self
.request(json!({"type": "batch-ack", "queue": queue, "ids": ids}))
.await?;
resp["acked"]
.as_u64()
.map(|n| n as usize)
.ok_or_else(|| invalid("missing 'acked' field"))
}
pub async fn batch_nack(
&self,
queue: &str,
ids: &[&str],
) -> Result<(usize, usize), QrustyClientError> {
let resp = self
.request(json!({"type": "batch-nack", "queue": queue, "ids": ids}))
.await?;
let unlocked = resp["unlocked"].as_u64().unwrap_or(0) as usize;
let dropped = resp["dropped"].as_u64().unwrap_or(0) as usize;
Ok((unlocked, dropped))
}
pub async fn grant_credits(&self, queue: &str, credits: u64) -> Result<(), QrustyClientError> {
self.request(json!({
"type": "credit",
"queue": queue,
"credits": credits,
}))
.await
.map(|_| ())
}
pub async fn subscribe_logs(&self) -> Result<mpsc::Receiver<LogEntry>, QrustyClientError> {
let (tx, rx) = mpsc::channel::<LogEntry>(256);
*self.log_tx.lock().await = Some(tx);
if let Err(e) = self.request(json!({"type": "subscribe-logs"})).await {
*self.log_tx.lock().await = None;
return Err(e);
}
Ok(rx)
}
pub async fn unsubscribe_logs(&self) -> Result<(), QrustyClientError> {
self.request(json!({"type": "unsubscribe-logs"}))
.await
.map(|_| ())?;
*self.log_tx.lock().await = None;
Ok(())
}
pub async fn renew(&self, queue: &str, id: &str) -> Result<(), QrustyClientError> {
self.request(json!({"type": "renew", "queue": queue, "id": id}))
.await
.map(|_| ())
}
pub async fn close(self) -> Result<(), QrustyClientError> {
self.sink
.lock()
.await
.send(TMsg::Close(None))
.await
.map_err(|e| QrustyClientError::Other(format!("close error: {}", e)))
}
}
/// Background task: reads every frame from the socket and dispatches it.
///
/// Routing rules, in order:
/// 1. frames carrying a `req_id` complete the matching pending request
///    (`type == "error"` frames become an `Err`);
/// 2. `deliver` frames are handed to the delivery worker via the unbounded
///    channel, so this loop never blocks on a slow subscriber;
/// 3. `log` frames go to the log channel if one is installed (dropped
///    otherwise, and on overflow, via `try_send`).
/// Non-text frames and unparseable JSON are skipped.
///
/// When the stream ends (or yields an error), every pending request is
/// failed and the subscriber map is cleared so receivers observe closure.
async fn router_task<S>(
    mut source: S,
    pending: PendingMap,
    subscribers: SubMap,
    delivery_tx: mpsc::UnboundedSender<DeliveryEnvelope>,
    log_tx: LogSender,
) where
    S: futures_util::Stream<Item = Result<TMsg, tokio_tungstenite::tungstenite::Error>> + Unpin,
{
    while let Some(Ok(msg)) = source.next().await {
        let text = match msg {
            TMsg::Text(t) => t,
            // Binary/ping/pong/close frames carry no protocol payload here.
            _ => continue,
        };
        let frame: Value = match serde_json::from_str(&text) {
            Ok(v) => v,
            Err(_) => continue,
        };
        if let Some(req_id) = frame.get("req_id").and_then(|v| v.as_str()) {
            // Response to an in-flight request: complete (and remove) the
            // matching oneshot. Unknown req_ids — e.g. a reply arriving
            // after a timeout removed the entry — are silently ignored.
            if let Some(tx) = pending.lock().await.remove(req_id) {
                let result = if frame["type"] == "error" {
                    Err(QrustyClientError::Other(format!(
                        "server error {}: {}",
                        frame["code"].as_str().unwrap_or("?"),
                        frame["message"].as_str().unwrap_or("?"),
                    )))
                } else {
                    Ok(frame)
                };
                // The receiver may have given up (timeout); that's fine.
                let _ = tx.send(result);
            }
            continue;
        }
        if frame["type"] == "deliver" {
            if let Some(queue) = frame["queue"].as_str().map(str::to_owned) {
                let msg = DeliveredMessage::try_from(&frame);
                // Hand off to the delivery worker; the channel is
                // unbounded so the router cannot stall here.
                let _ = delivery_tx.send((queue, msg));
            }
            continue;
        }
        if frame["type"] == "log" {
            let entry = LogEntry {
                timestamp: frame["timestamp"].as_str().unwrap_or("").to_owned(),
                level: frame["level"].as_str().unwrap_or("").to_owned(),
                message: frame["message"].as_str().unwrap_or("").to_owned(),
            };
            if let Some(tx) = log_tx.lock().await.as_ref() {
                // Best effort: drop the entry if the log channel is full.
                let _ = tx.try_send(entry);
            }
        }
    }
    // Connection closed: wake every waiter with an error and drop all
    // subscriber senders so their receivers see end-of-stream.
    let mut map = pending.lock().await;
    for (_, tx) in map.drain() {
        let _ = tx.send(Err(QrustyClientError::Other("connection closed".into())));
    }
    subscribers.lock().await.clear();
}
async fn delivery_task(
mut delivery_rx: mpsc::UnboundedReceiver<DeliveryEnvelope>,
subscribers: SubMap,
) {
while let Some((queue, msg)) = delivery_rx.recv().await {
let maybe_tx = subscribers.lock().await.get(&queue).cloned();
if let Some(tx) = maybe_tx {
match tx.try_send(msg) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
log::warn!(
"subscriber channel full for queue '{}'; dropping delivery frame",
queue,
);
}
Err(mpsc::error::TrySendError::Closed(_)) => {
subscribers.lock().await.remove(&queue);
}
}
}
}
}
fn str_field(v: &Value, field: &str) -> Result<String, QrustyClientError> {
v.get(field)
.and_then(|f| f.as_str())
.map(str::to_owned)
.ok_or_else(|| invalid(&format!("missing field '{}'", field)))
}
/// Shorthand for constructing a `QrustyClientError::InvalidResponse`.
fn invalid(msg: &str) -> QrustyClientError {
    QrustyClientError::InvalidResponse(String::from(msg))
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_tungstenite::tungstenite::protocol::Role;
async fn in_memory_ws_pair() -> (
futures_util::stream::SplitStream<
tokio_tungstenite::WebSocketStream<tokio::io::DuplexStream>,
>,
tokio_tungstenite::WebSocketStream<tokio::io::DuplexStream>,
) {
let (server_io, client_io) = tokio::io::duplex(65_536);
let server_ws =
tokio_tungstenite::WebSocketStream::from_raw_socket(server_io, Role::Server, None)
.await;
let client_ws =
tokio_tungstenite::WebSocketStream::from_raw_socket(client_io, Role::Client, None)
.await;
let (_, client_source) = client_ws.split();
(client_source, server_ws)
}
    /// When the server side drops, the router must fail pending requests,
    /// close subscriber channels, and leave both maps empty.
    #[tokio::test]
    async fn router_close_wakes_pending_and_clears_subscribers() {
        let (client_source, server_ws) = in_memory_ws_pair().await;
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let subscribers: SubMap = Arc::new(Mutex::new(HashMap::new()));
        let (delivery_tx, _delivery_rx) = mpsc::unbounded_channel();
        // One in-flight request and one active subscription before closing.
        let (tx, rx) = oneshot::channel();
        pending.lock().await.insert("req-0".to_owned(), tx);
        let (sub_tx, mut sub_rx) = mpsc::channel(8);
        subscribers.lock().await.insert("orders".to_owned(), sub_tx);
        let log_tx: LogSender = Arc::new(Mutex::new(None));
        let task = tokio::spawn(router_task(
            client_source,
            Arc::clone(&pending),
            Arc::clone(&subscribers),
            delivery_tx,
            Arc::clone(&log_tx),
        ));
        // Dropping the server end closes the duplex pipe, ending the
        // router's read loop.
        drop(server_ws);
        task.await.unwrap();
        // The pending oneshot must have been completed with an error.
        let result = rx.await.expect("oneshot was not sent");
        assert!(result.is_err(), "expected Err, got Ok");
        // Clearing the map drops the sender, so the receiver yields None.
        assert!(
            sub_rx.recv().await.is_none(),
            "subscriber channel should be closed after connection close"
        );
        assert!(pending.lock().await.is_empty());
        assert!(subscribers.lock().await.is_empty());
    }
    /// Floods a capacity-1 subscriber channel with deliveries, then checks
    /// that a later request/response round trip still completes: the router
    /// must not block on a full subscriber channel (overflow is absorbed by
    /// the delivery task's try_send path).
    #[tokio::test]
    async fn router_does_not_block_on_full_subscriber_channel() {
        let (client_source, mut server_ws) = in_memory_ws_pair().await;
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let subscribers: SubMap = Arc::new(Mutex::new(HashMap::new()));
        let (delivery_tx, delivery_rx) = mpsc::unbounded_channel();
        // Capacity-1 channel whose receiver is intentionally never drained.
        let (sub_tx, _sub_rx) = mpsc::channel(1);
        subscribers.lock().await.insert("q".to_owned(), sub_tx);
        let (req_tx, req_rx) = oneshot::channel();
        pending.lock().await.insert("req-1".to_owned(), req_tx);
        let log_tx: LogSender = Arc::new(Mutex::new(None));
        let router_handle = tokio::spawn(router_task(
            client_source,
            Arc::clone(&pending),
            Arc::clone(&subscribers),
            delivery_tx,
            Arc::clone(&log_tx),
        ));
        let delivery_handle = tokio::spawn(delivery_task(delivery_rx, Arc::clone(&subscribers)));
        // Five deliveries against a 1-slot channel: most must overflow.
        for i in 0..5 {
            let frame = json!({
                "type": "deliver",
                "queue": "q",
                "id": format!("msg-{}", i),
                "payload": "data",
                "priority": 0,
                "created_at": "2026-01-01T00:00:00Z",
            });
            server_ws
                .send(TMsg::Text(frame.to_string().into()))
                .await
                .unwrap();
        }
        // If the router were blocked on the full channel, this response
        // would never be routed and the timeout below would fire.
        let resp_frame = json!({"req_id": "req-1", "type": "ok"});
        server_ws
            .send(TMsg::Text(resp_frame.to_string().into()))
            .await
            .unwrap();
        let result = tokio::time::timeout(Duration::from_secs(2), req_rx)
            .await
            .expect("timed out waiting for req-1 response")
            .expect("oneshot recv error");
        assert!(result.is_ok());
        drop(server_ws);
        let _ = router_handle.await;
        let _ = delivery_handle.await;
    }
    /// Exercises the timeout-then-remove sequence that `WsSession::request`
    /// performs when no response arrives in time.
    // NOTE(review): this drives the pattern by hand rather than through
    // `request` itself, so it mostly pins tokio timeout semantics —
    // consider covering `request` end-to-end via the in-memory pair.
    #[tokio::test]
    async fn request_timeout_cleans_up_pending() {
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let (tx, rx) = oneshot::channel::<Result<Value, QrustyClientError>>();
        pending.lock().await.insert("req-timeout".to_owned(), tx);
        let timeout_dur = Duration::from_millis(50);
        // No one ever sends on `tx`, so the timeout must fire.
        let result = tokio::time::timeout(timeout_dur, rx).await;
        assert!(result.is_err(), "should have timed out");
        pending.lock().await.remove("req-timeout");
        assert!(pending.lock().await.is_empty());
    }
    /// A `log` frame must be parsed into a `LogEntry` and forwarded on the
    /// installed log channel with all fields intact.
    #[tokio::test]
    async fn router_routes_log_frames_to_log_channel() {
        let (client_source, mut server_ws) = in_memory_ws_pair().await;
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let subscribers: SubMap = Arc::new(Mutex::new(HashMap::new()));
        let (delivery_tx, _delivery_rx) = mpsc::unbounded_channel();
        // Pre-install a log channel, as subscribe_logs would.
        let (log_tx_inner, mut log_rx) = mpsc::channel::<LogEntry>(256);
        let log_tx: LogSender = Arc::new(Mutex::new(Some(log_tx_inner)));
        let router_handle = tokio::spawn(router_task(
            client_source,
            Arc::clone(&pending),
            Arc::clone(&subscribers),
            delivery_tx,
            Arc::clone(&log_tx),
        ));
        let log_frame = json!({
            "type": "log",
            "timestamp": "2026-03-10T00:00:00Z",
            "level": "INFO",
            "message": "test log message",
        });
        server_ws
            .send(TMsg::Text(log_frame.to_string().into()))
            .await
            .unwrap();
        let entry = tokio::time::timeout(Duration::from_secs(2), log_rx.recv())
            .await
            .expect("timed out waiting for log entry")
            .expect("log channel closed");
        assert_eq!(entry.timestamp, "2026-03-10T00:00:00Z");
        assert_eq!(entry.level, "INFO");
        assert_eq!(entry.message, "test log message");
        drop(server_ws);
        let _ = router_handle.await;
    }
    /// With no log channel installed, `log` frames are silently discarded
    /// and the router keeps routing: the follow-up request still gets its
    /// reply.
    #[tokio::test]
    async fn router_drops_log_frames_when_no_subscriber() {
        let (client_source, mut server_ws) = in_memory_ws_pair().await;
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let subscribers: SubMap = Arc::new(Mutex::new(HashMap::new()));
        let (delivery_tx, _delivery_rx) = mpsc::unbounded_channel();
        // No log subscriber: log frames should be dropped, not crash.
        let log_tx: LogSender = Arc::new(Mutex::new(None));
        let (req_tx, req_rx) = oneshot::channel();
        pending.lock().await.insert("req-1".to_owned(), req_tx);
        let router_handle = tokio::spawn(router_task(
            client_source,
            Arc::clone(&pending),
            Arc::clone(&subscribers),
            delivery_tx,
            Arc::clone(&log_tx),
        ));
        let log_frame = json!({
            "type": "log",
            "timestamp": "2026-03-10T00:00:00Z",
            "level": "WARN",
            "message": "dropped",
        });
        server_ws
            .send(TMsg::Text(log_frame.to_string().into()))
            .await
            .unwrap();
        // A response after the log frame proves the router kept running.
        let resp_frame = json!({"req_id": "req-1", "type": "ok"});
        server_ws
            .send(TMsg::Text(resp_frame.to_string().into()))
            .await
            .unwrap();
        let result = tokio::time::timeout(Duration::from_secs(2), req_rx)
            .await
            .expect("timed out waiting for req-1")
            .expect("oneshot recv error");
        assert!(result.is_ok());
        drop(server_ws);
        let _ = router_handle.await;
    }
}