use super::buffer::{
BufferedRewindMessage, ByteCounter, MessageSenderHandle, RewindGate, SizedMessage,
SizedMessageSenderHandle, WebSocketBufferConfig,
};
use super::capabilities::ConnectionCapabilities;
use super::connection::WebSocket;
use super::socket_id::SocketId;
use crate::capability_token::TokenAuthContext;
use crate::error::{Error, Result};
use bytes::Bytes;
use crossfire::TrySendError;
use dashmap::DashMap;
use sockudo_filter::FilterNode;
use sockudo_protocol::messages::PusherMessage;
use sockudo_protocol::{ProtocolVersion, WireFormat};
use sonic_rs::Value;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::warn;
/// Cloneable handle to a single WebSocket connection.
///
/// Values needed on the send paths are stored directly on the handle (copied or cloned
/// out of the wrapped `WebSocket` by `WebSocketRef::new`); the connection itself sits
/// behind the async mutex in `inner`. Clones share the same `Arc`-backed maps and the
/// same `inner`, and equality/hashing are based on the identity of `inner`.
#[derive(Clone)]
pub struct WebSocketRef {
/// Sender that `send_broadcast` uses to queue `SizedMessage`s without blocking.
pub broadcast_tx: SizedMessageSenderHandle,
/// Sender that `send_message` uses to queue individual protocol frames.
pub message_sender: MessageSenderHandle,
/// Per-channel optional tag filter; seeded from the connection's subscribed channels
/// at construction and kept in sync by the subscribe/unsubscribe methods.
pub channel_filters: Arc<DashMap<String, Option<Arc<FilterNode>>>>,
/// Per-channel optional list of event names (`None` when no list was supplied).
pub event_name_filters: Arc<DashMap<String, Option<Vec<String>>>>,
/// Per-channel flag read by `allows_annotation_events_sync`.
pub annotation_subscriptions: Arc<DashMap<String, bool>>,
/// Per-channel serial stored by `set_attach_serial`; cleared on (re)subscribe and
/// on unsubscribe.
pub attach_serials: Arc<DashMap<String, u64>>,
/// Per-channel rewind gates, present between `start_rewind_gate` and
/// `finish_rewind_gate`.
pub rewind_gates: Arc<DashMap<String, Arc<Mutex<RewindGate>>>>,
/// Socket id copied from the wrapped connection.
pub socket_id: SocketId,
/// Buffer limits and the `disconnect_on_full` policy for this connection.
pub buffer_config: WebSocketBufferConfig,
/// Optional counter of pending bytes, consulted for the byte limit in `send_broadcast`.
pub byte_counter: Option<Arc<ByteCounter>>,
/// Token cancelled by `close` and `shutdown`.
pub shutdown_token: CancellationToken,
// Set by `close`; makes `send_message` reject new messages.
closing: Arc<std::sync::atomic::AtomicBool>,
/// The wrapped connection, guarded by an async mutex.
pub inner: Arc<Mutex<WebSocket>>,
/// Protocol version copied from the connection state at construction.
pub protocol_version: ProtocolVersion,
/// Wire format copied from the connection state at construction; selects binary or
/// text frames in `send_message`.
pub wire_format: WireFormat,
/// `echo_messages` flag copied from the connection state at construction.
pub echo_messages: bool,
}
impl WebSocketRef {
/// Wraps `websocket` in a shareable handle.
///
/// The per-channel filter map is seeded from `websocket.state.subscribed_channels`;
/// the event-name, annotation, attach-serial and rewind-gate maps start out empty.
/// Everything else is copied or cloned from the connection before it is moved behind
/// the mutex in `inner`.
pub fn new(websocket: WebSocket) -> Self {
    // Mirror the filters of channels the connection is already subscribed to.
    let channel_filters = Arc::new(DashMap::new());
    for (name, node) in &websocket.state.subscribed_channels {
        channel_filters.insert(name.clone(), node.clone().map(Arc::new));
    }

    Self {
        broadcast_tx: websocket.broadcast_tx.clone(),
        message_sender: websocket.message_sender.sender_handle(),
        channel_filters,
        event_name_filters: Arc::new(DashMap::new()),
        annotation_subscriptions: Arc::new(DashMap::new()),
        attach_serials: Arc::new(DashMap::new()),
        rewind_gates: Arc::new(DashMap::new()),
        socket_id: *websocket.get_socket_id(),
        buffer_config: websocket.buffer_config,
        byte_counter: websocket.byte_counter.clone(),
        shutdown_token: websocket.shutdown_token.clone(),
        closing: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        protocol_version: websocket.state.protocol_version,
        wire_format: websocket.state.wire_format,
        echo_messages: websocket.state.echo_messages,
        // Must stay last: this moves `websocket` into the mutex.
        inner: Arc::new(Mutex::new(websocket)),
    }
}
/// Queues a pre-serialized broadcast payload for this connection without blocking.
///
/// When both a byte counter and a byte limit are configured, the payload is rejected
/// up front if `would_exceed` reports that it does not fit. Otherwise it is handed to
/// `broadcast_tx.try_send`, and on success its length is added to the counter.
///
/// # Errors
///
/// * `Error::BufferFull` when a limit is hit and `disconnect_on_full` is set. When it
///   is not set, the payload is dropped with a warning and `Ok(())` is returned.
/// * `Error::ConnectionClosed` when the broadcast channel is disconnected.
#[inline]
pub fn send_broadcast(&self, bytes: Bytes) -> Result<()> {
    let payload_len = bytes.len();
    let tracker = self.byte_counter.as_ref();

    if let Some(counter) = tracker
        && let Some(max_bytes) = self.buffer_config.limit.byte_limit()
        && counter.would_exceed(payload_len, max_bytes)
    {
        return self.handle_buffer_full("byte limit", max_bytes, Some(payload_len));
    }

    match self.broadcast_tx.try_send(SizedMessage::new(bytes)) {
        Err(TrySendError::Disconnected(_)) => Err(Error::ConnectionClosed(
            "Broadcast channel closed".to_string(),
        )),
        Err(TrySendError::Full(_)) => {
            // A missing message limit is reported as 0.
            let max_messages = self.buffer_config.limit.message_limit().unwrap_or(0);
            self.handle_buffer_full("message limit", max_messages, None)
        }
        Ok(()) => {
            // NOTE(review): the counter is bumped only after the message is queued;
            // confirm the consumer side tolerates seeing the message first.
            if let Some(counter) = tracker {
                counter.add(payload_len);
            }
            Ok(())
        }
    }
}
/// Applies the configured slow-consumer policy after a buffer limit was hit.
///
/// `limit_type` and `limit_value` describe the limit that was reached; `msg_size`, when
/// given, is included in the error text.
///
/// Returns `Err(Error::BufferFull)` when `disconnect_on_full` is set. Otherwise it logs
/// a warning and returns `Ok(())`, meaning the message was dropped.
#[inline]
fn handle_buffer_full(
    &self,
    limit_type: &str,
    limit_value: usize,
    msg_size: Option<usize>,
) -> Result<()> {
    if !self.buffer_config.disconnect_on_full {
        warn!(
            socket_id = %self.socket_id,
            limit_type = limit_type,
            limit_value = limit_value,
            "Dropping message for slow consumer (buffer full)"
        );
        return Ok(());
    }

    let size_info = match msg_size {
        Some(bytes) => format!(", message size: {bytes} bytes"),
        None => String::new(),
    };
    Err(Error::BufferFull(format!(
        "Client buffer full ({limit_type}: {limit_value}{size_info}), disconnecting slow consumer"
    )))
}
/// Returns a snapshot of the buffer usage and configured limits.
///
/// `pending_bytes` is `None` when the connection has no byte counter.
pub fn buffer_stats(&self) -> BufferStats {
    let pending_bytes = self.byte_counter.as_ref().map(|counter| counter.get());
    let byte_limit = self.buffer_config.limit.byte_limit();
    let message_limit = self.buffer_config.limit.message_limit();
    BufferStats {
        pending_bytes,
        byte_limit,
        message_limit,
    }
}
pub fn send_message(&self, message: &PusherMessage) -> Result<()> {
if self.closing.load(Ordering::Acquire) || self.shutdown_token.is_cancelled() {
return Err(Error::ConnectionClosed("Connection shutting down".into()));
}
let payload = sockudo_protocol::wire::serialize_message(message, self.wire_format)
.map_err(|e| Error::InvalidMessageFormat(format!("Serialization failed: {e}")))?;
if self.wire_format.is_binary() {
self.message_sender
.try_send(sockudo_ws::Message::Binary(Bytes::from(payload)))
.map_err(|e| match e {
TrySendError::Full(_) => Error::BufferFull("Message buffer full".into()),
TrySendError::Disconnected(_) => {
Error::ConnectionClosed("Channel closed".into())
}
})
} else {
let text = String::from_utf8(payload).map_err(|e| {
Error::InvalidMessageFormat(format!("JSON payload is not UTF-8: {e}"))
})?;
self.message_sender
.try_send(sockudo_ws::Message::text(text))
.map_err(|e| match e {
TrySendError::Full(_) => Error::BufferFull("Message buffer full".into()),
TrySendError::Disconnected(_) => {
Error::ConnectionClosed("Channel closed".into())
}
})
}
}
/// Closes the underlying connection with the given close `code` and `reason`.
///
/// The handle is flagged as closing first, so `send_message` starts rejecting new
/// messages. The shutdown token is cancelled after the close attempt whether or not it
/// succeeded, and the result of the close attempt is returned.
pub async fn close(&self, code: u16, reason: String) -> Result<()> {
    self.closing.store(true, Ordering::Release);

    let mut connection = self.inner.lock().await;
    let outcome = connection.close(code, reason).await;
    // Release the connection lock before signalling shutdown.
    drop(connection);

    self.shutdown_token.cancel();
    outcome
}
/// Cancels the connection's shutdown token without touching the connection itself.
pub fn shutdown(&self) {
self.shutdown_token.cancel();
}
/// Returns a clone of the connection's shutdown token.
pub fn cancellation_token(&self) -> CancellationToken {
self.shutdown_token.clone()
}
/// Returns a reference to the socket id stored on the handle (no locking involved).
pub fn get_socket_id_sync(&self) -> &SocketId {
&self.socket_id
}
/// Async variant of `get_socket_id_sync`; returns the socket id by value.
///
/// It never awaits: the id is stored directly on the handle.
pub async fn get_socket_id(&self) -> SocketId {
    *self.get_socket_id_sync()
}
/// Asks the underlying connection (under its lock) whether it is subscribed to `channel`.
pub async fn is_subscribed_to(&self, channel: &str) -> bool {
    self.inner.lock().await.is_subscribed_to(channel)
}
/// Returns a copy of the `user_id` stored in the connection state, if any.
pub async fn get_user_id(&self) -> Option<String> {
    self.inner.lock().await.state.user_id.clone()
}
/// Returns a copy of the `connection_capabilities` stored in the connection state, if any.
pub async fn get_connection_capabilities(&self) -> Option<ConnectionCapabilities> {
    self.inner.lock().await.state.connection_capabilities.clone()
}
/// Returns a copy of the `token_auth_context` stored in the connection state, if any.
pub async fn get_token_auth_context(&self) -> Option<TokenAuthContext> {
    self.inner.lock().await.state.token_auth_context.clone()
}
/// Hands `context` to the underlying connection's `set_token_auth_context` under its lock.
pub async fn set_token_auth_context(&self, context: TokenAuthContext) {
    self.inner.lock().await.set_token_auth_context(context);
}
/// Returns a copy of the `connection_meta` value stored in the connection state, if any.
pub async fn get_connection_meta(&self) -> Option<Value> {
    self.inner.lock().await.state.connection_meta.clone()
}
/// Delegates to the underlying connection's `update_activity` under its lock.
pub async fn update_activity(&self) {
    self.inner.lock().await.update_activity();
}
/// Subscribes the connection to `channel` with no filters.
///
/// Under the connection lock, the subscription is registered on the underlying
/// connection and the handle's per-channel maps are reset: no tag filter, no event-name
/// filter, annotation events off, and any stored attach serial removed.
pub async fn subscribe_to_channel(&self, channel: String) {
    // The guard is held until the end of the function so the maps are updated
    // while the connection is still locked.
    let mut connection = self.inner.lock().await;
    connection.subscribe_to_channel(channel.clone());

    let key = channel.as_str();
    self.channel_filters.insert(key.to_owned(), None);
    self.event_name_filters.insert(key.to_owned(), None);
    self.annotation_subscriptions.insert(key.to_owned(), false);
    self.attach_serials.remove(key);
}
pub async fn subscribe_to_channel_with_filter(
&self,
channel: String,
mut filter: Option<FilterNode>,
) {
if let Some(ref mut f) = filter {
f.optimize();
}
let mut ws = self.inner.lock().await;
ws.subscribe_to_channel_with_filter(channel.clone(), filter.clone());
self.channel_filters
.insert(channel.clone(), filter.map(Arc::new));
self.event_name_filters.insert(channel.clone(), None);
self.annotation_subscriptions.insert(channel.clone(), false);
self.attach_serials.remove(&channel);
}
pub async fn subscribe_to_channel_with_filters(
&self,
channel: String,
mut tag_filter: Option<FilterNode>,
event_name_filter: Option<Vec<String>>,
annotation_subscribe: bool,
) {
if let Some(ref mut f) = tag_filter {
f.optimize();
}
let mut ws = self.inner.lock().await;
ws.subscribe_to_channel_with_filter(channel.clone(), tag_filter.clone());
self.channel_filters
.insert(channel.clone(), tag_filter.map(Arc::new));
self.event_name_filters
.insert(channel.clone(), event_name_filter);
self.annotation_subscriptions
.insert(channel.clone(), annotation_subscribe);
self.attach_serials.remove(&channel);
}
/// Unsubscribes the connection from `channel`.
///
/// Under the connection lock, the subscription is removed from the underlying
/// connection and every per-channel entry on the handle (tag filter, event-name filter,
/// annotation flag, attach serial) is dropped. Returns whatever the underlying
/// connection's `unsubscribe_from_channel` reports.
pub async fn unsubscribe_from_channel(&self, channel: &str) -> bool {
    let mut connection = self.inner.lock().await;
    let outcome = connection.unsubscribe_from_channel(channel);

    // Clear the handle-side entries while the connection is still locked.
    self.channel_filters.remove(channel);
    self.event_name_filters.remove(channel);
    self.annotation_subscriptions.remove(channel);
    self.attach_serials.remove(channel);

    outcome
}
/// Async variant of `get_channel_filter_sync`.
///
/// It never awaits: the filter is read from the handle's map, not from the locked
/// connection.
pub async fn get_channel_filter(&self, channel: &str) -> Option<Arc<FilterNode>> {
    self.get_channel_filter_sync(channel)
}
/// Returns the tag filter stored for `channel`.
///
/// `None` means either that the channel has no entry or that its entry has no filter.
pub fn get_channel_filter_sync(&self, channel: &str) -> Option<Arc<FilterNode>> {
    let entry = self.channel_filters.get(channel)?;
    entry.value().clone()
}
/// Returns a copy of the event-name list stored for `channel`.
///
/// `None` means either that the channel has no entry or that its entry has no list.
pub fn get_event_name_filter_sync(&self, channel: &str) -> Option<Vec<String>> {
    match self.event_name_filters.get(channel) {
        Some(entry) => entry.value().clone(),
        None => None,
    }
}
/// Returns the annotation flag stored for `channel`, or `false` when the channel has
/// no entry.
pub fn allows_annotation_events_sync(&self, channel: &str) -> bool {
    match self.annotation_subscriptions.get(channel) {
        Some(entry) => *entry.value(),
        None => false,
    }
}
/// Stores `serial` as the attach serial for `channel`, replacing any previous value.
pub fn set_attach_serial(&self, channel: String, serial: u64) {
self.attach_serials.insert(channel, serial);
}
/// Returns the attach serial stored for `channel`, if any.
pub fn attach_serial(&self, channel: &str) -> Option<u64> {
    let entry = self.attach_serials.get(channel)?;
    Some(*entry.value())
}
/// Installs a fresh, default-initialized rewind gate for `channel`.
///
/// An existing gate for the same channel is replaced.
/// NOTE(review): anything buffered in a replaced gate is discarded with it; confirm
/// callers never start a gate twice without finishing the first.
pub fn start_rewind_gate(&self, channel: String) {
    let gate = Arc::new(Mutex::new(RewindGate::default()));
    self.rewind_gates.insert(channel, gate);
}
pub async fn buffer_rewind_message(&self, channel: &str, message: &PusherMessage) -> bool {
let Some(gate) = self
.rewind_gates
.get(channel)
.map(|entry| entry.value().clone())
else {
return false;
};
let mut gate = gate.lock().await;
gate.buffered.push(BufferedRewindMessage {
serial: message.serial,
message_id: message.message_id.clone(),
message: message.clone(),
});
true
}
/// Removes the rewind gate of `channel` and returns the messages it had buffered.
///
/// Returns an empty vector when no gate is registered for the channel.
pub async fn finish_rewind_gate(&self, channel: &str) -> Vec<BufferedRewindMessage> {
    match self.rewind_gates.remove(channel) {
        Some((_channel, gate)) => {
            let mut state = gate.lock().await;
            std::mem::take(&mut state.buffered)
        }
        None => Vec::new(),
    }
}
}
impl Hash for WebSocketRef {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
let ptr = Arc::as_ptr(&self.inner) as *const () as usize;
ptr.hash(state);
}
}
/// Two handles are equal when they point at the same `inner` allocation.
impl PartialEq for WebSocketRef {
    fn eq(&self, other: &Self) -> bool {
        let mine = Arc::as_ptr(&self.inner);
        let theirs = Arc::as_ptr(&other.inner);
        std::ptr::eq(mine, theirs)
    }
}
// Equality is pointer identity of `inner`, which is a full equivalence relation.
impl Eq for WebSocketRef {}
/// Debug output shows only the address of the shared `inner` allocation.
impl Debug for WebSocketRef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ptr = Arc::as_ptr(&self.inner);
        let mut out = f.debug_struct("WebSocketRef");
        out.field("ptr", &ptr);
        out.finish()
    }
}
/// Convenience senders for common outbound protocol messages.
pub trait WebSocketExt {
/// Sends an arbitrary `PusherMessage` to the connection.
async fn send_pusher_message(&self, message: PusherMessage) -> Result<()>;
/// Sends an error message with the given `code` and `message`, optionally tied to a
/// `channel`.
async fn send_error(&self, code: u16, message: String, channel: Option<String>) -> Result<()>;
/// Sends a pong message to the connection.
async fn send_pong(&self) -> Result<()>;
}
impl WebSocketExt for WebSocketRef {
async fn send_pusher_message(&self, message: PusherMessage) -> Result<()> {
self.send_message(&message)
}
async fn send_error(&self, code: u16, message: String, channel: Option<String>) -> Result<()> {
let error_msg = PusherMessage::error(u32::from(code), message, channel);
self.send_message(&error_msg)
}
async fn send_pong(&self) -> Result<()> {
let pong_msg = PusherMessage::pong();
self.send_message(&pong_msg)
}
}
/// Snapshot of a connection's buffer usage and configured limits, as produced by
/// `WebSocketRef::buffer_stats`.
#[derive(Debug, Clone)]
pub struct BufferStats {
/// Current value of the byte counter; `None` when the connection has no byte counter.
pub pending_bytes: Option<usize>,
/// Configured byte limit, if any.
pub byte_limit: Option<usize>,
/// Configured message limit, if any.
pub message_limit: Option<usize>,
}