use crate::ConnectionManager;
#[cfg(feature = "tag-filtering")]
use crate::filter_index::FilterIndex;
use ahash::AHashMap as HashMap;
use async_trait::async_trait;
use bytes::Bytes;
use dashmap::DashMap;
use sockudo_core::app::AppManager;
use sockudo_core::channel::PresenceMemberInfo;
use sockudo_core::error::{Error, Result};
use sockudo_core::namespace::{Namespace, SocketInitOptions};
use sockudo_core::websocket::{SocketId, WebSocketRef};
use sockudo_protocol::messages::{
ANNOTATION_EVENT_NAME, MESSAGE_SUMMARY_EVENT_NAME, PusherMessage, generate_message_id,
};
use sockudo_protocol::versioned_messages::extract_runtime_action;
use sockudo_ws::axum_integration::WebSocketWriter;
use std::any::Any;
use std::sync::Arc;
#[cfg(feature = "delta")]
use std::sync::OnceLock;
#[cfg(feature = "tag-filtering")]
use std::sync::atomic::AtomicBool;
#[cfg(any(feature = "tag-filtering", feature = "delta"))]
use std::sync::atomic::Ordering;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
fn filter_annotation_subscribers_in_place(
channel: &str,
message: &PusherMessage,
socket_refs: &mut Vec<WebSocketRef>,
) {
if message.event.as_deref() != Some(ANNOTATION_EVENT_NAME) {
return;
}
socket_refs.retain(|socket_ref| socket_ref.allows_annotation_events_sync(channel));
}
fn is_v2_annotation_protocol_event(message: &PusherMessage) -> bool {
matches!(
message.event.as_deref(),
Some(ANNOTATION_EVENT_NAME) | Some(MESSAGE_SUMMARY_EVENT_NAME)
)
}
pub struct LocalAdapter {
pub namespaces: Arc<DashMap<String, Arc<Namespace>>>,
pub buffer_multiplier_per_cpu: usize,
pub max_concurrent: usize,
broadcast_semaphore: Arc<Semaphore>,
#[cfg(feature = "delta")]
delta_compression: Arc<OnceLock<Arc<sockudo_delta::DeltaCompressionManager>>>,
#[cfg(feature = "delta")]
app_manager: Arc<OnceLock<Arc<dyn AppManager + Send + Sync>>>,
#[cfg(feature = "tag-filtering")]
tag_filtering_enabled: Arc<AtomicBool>,
#[cfg(feature = "tag-filtering")]
enable_tags_globally: Arc<AtomicBool>,
#[cfg(feature = "tag-filtering")]
filter_index: Arc<FilterIndex>,
}
impl Clone for LocalAdapter {
fn clone(&self) -> Self {
Self {
namespaces: Arc::clone(&self.namespaces),
buffer_multiplier_per_cpu: self.buffer_multiplier_per_cpu,
max_concurrent: self.max_concurrent,
broadcast_semaphore: Arc::clone(&self.broadcast_semaphore),
#[cfg(feature = "delta")]
delta_compression: Arc::clone(&self.delta_compression),
#[cfg(feature = "delta")]
app_manager: Arc::clone(&self.app_manager),
#[cfg(feature = "tag-filtering")]
tag_filtering_enabled: Arc::clone(&self.tag_filtering_enabled),
#[cfg(feature = "tag-filtering")]
enable_tags_globally: Arc::clone(&self.enable_tags_globally),
#[cfg(feature = "tag-filtering")]
filter_index: Arc::clone(&self.filter_index),
}
}
}
impl Default for LocalAdapter {
fn default() -> Self {
Self::new()
}
}
impl LocalAdapter {
pub fn new() -> Self {
Self::new_with_buffer_multiplier(64)
}
#[cfg(feature = "tag-filtering")]
pub fn set_tag_filtering_enabled(&self, enabled: bool) {
self.tag_filtering_enabled.store(enabled, Ordering::Release);
}
#[cfg(feature = "tag-filtering")]
pub fn is_tag_filtering_enabled(&self) -> bool {
self.tag_filtering_enabled.load(Ordering::Acquire)
}
#[cfg(feature = "tag-filtering")]
pub fn set_enable_tags_globally(&self, enabled: bool) {
self.enable_tags_globally.store(enabled, Ordering::Release);
}
#[cfg(feature = "tag-filtering")]
pub fn get_enable_tags_globally(&self) -> bool {
self.enable_tags_globally.load(Ordering::Acquire)
}
pub fn new_with_buffer_multiplier(multiplier: usize) -> Self {
let cpu_cores = num_cpus::get();
let max_concurrent = cpu_cores * multiplier;
info!(
"LocalAdapter initialized with {} CPU cores, buffer multiplier {}, max concurrent {}",
cpu_cores, multiplier, max_concurrent
);
Self {
namespaces: Arc::new(DashMap::new()),
buffer_multiplier_per_cpu: multiplier,
max_concurrent,
broadcast_semaphore: Arc::new(Semaphore::new(max_concurrent * 8)),
#[cfg(feature = "delta")]
delta_compression: Arc::new(OnceLock::new()),
#[cfg(feature = "delta")]
app_manager: Arc::new(OnceLock::new()),
#[cfg(feature = "tag-filtering")]
tag_filtering_enabled: Arc::new(AtomicBool::new(false)),
#[cfg(feature = "tag-filtering")]
enable_tags_globally: Arc::new(AtomicBool::new(true)), #[cfg(feature = "tag-filtering")]
filter_index: Arc::new(FilterIndex::new()),
}
}
#[cfg(feature = "tag-filtering")]
pub fn get_filter_index(&self) -> Arc<FilterIndex> {
Arc::clone(&self.filter_index)
}
#[cfg(feature = "delta")]
pub async fn set_delta_compression(
&self,
delta_compression: Arc<sockudo_delta::DeltaCompressionManager>,
app_manager: Arc<dyn AppManager + Send + Sync>,
) {
let _ = self.delta_compression.set(delta_compression);
let _ = self.app_manager.set(app_manager);
}
#[cfg(feature = "delta")]
#[inline]
pub fn get_delta_compression(&self) -> Option<&Arc<sockudo_delta::DeltaCompressionManager>> {
self.delta_compression.get()
}
#[cfg(feature = "delta")]
#[inline]
pub fn get_app_manager(&self) -> Option<&Arc<dyn AppManager + Send + Sync>> {
self.app_manager.get()
}
}
#[cfg(feature = "delta")]
#[allow(dead_code)]
struct SocketMessageParams<'a> {
socket_ref: WebSocketRef,
base_message: PusherMessage,
base_message_bytes: Vec<u8>,
channel: &'a str,
event_name: &'a str,
delta_compression: Arc<sockudo_delta::DeltaCompressionManager>,
channel_settings: Option<&'a sockudo_delta::ChannelDeltaSettings>,
tag_filtering_enabled: bool,
}
#[cfg(feature = "delta")]
#[allow(dead_code)]
struct PrecomputedDeltaParams<'a> {
socket_ref: WebSocketRef,
base_message: PusherMessage,
base_message_bytes: Vec<u8>,
channel: &'a str,
event_name: &'a str,
delta_compression: Arc<sockudo_delta::DeltaCompressionManager>,
channel_settings: Option<&'a sockudo_delta::ChannelDeltaSettings>,
tag_filtering_enabled: bool,
precomputed_delta: Option<(Arc<Vec<u8>>, u32)>,
conflation_key: String,
base_hash: u64,
}
impl LocalAdapter {
async fn send_protocol_messages_concurrent(
&self,
target_socket_refs: Vec<WebSocketRef>,
message: PusherMessage,
) -> Vec<Result<()>> {
use futures::stream::{self, StreamExt};
let socket_count = target_socket_refs.len();
let target_chunks = socket_count.div_ceil(self.max_concurrent).clamp(1, 8);
let socket_chunk_size = (socket_count / target_chunks)
.min(self.max_concurrent)
.max(1);
let mut results = Vec::with_capacity(socket_count);
for socket_chunk in target_socket_refs.chunks(socket_chunk_size) {
let chunk_size = socket_chunk.len();
match self
.broadcast_semaphore
.acquire_many(chunk_size as u32)
.await
{
Ok(_permits) => {
let chunk_vec: Vec<_> = socket_chunk.to_vec();
let chunk_results: Vec<Result<()>> = stream::iter(chunk_vec)
.map(|socket_ref| {
let msg = message.clone();
async move { socket_ref.send_message(&msg).await }
})
.buffer_unordered(chunk_size)
.collect()
.await;
results.extend(chunk_results);
}
Err(_) => {
for _ in 0..chunk_size {
results.push(Err(Error::Connection(
"Broadcast semaphore unavailable".to_string(),
)));
}
}
}
}
results
}
async fn send_messages_concurrent(
&self,
target_socket_refs: Vec<WebSocketRef>,
message_bytes: Bytes,
) -> Vec<Result<()>> {
use futures::stream::{self, StreamExt};
let socket_count = target_socket_refs.len();
let target_chunks = socket_count.div_ceil(self.max_concurrent).clamp(1, 8);
let socket_chunk_size = (socket_count / target_chunks)
.min(self.max_concurrent)
.max(1);
let mut results = Vec::with_capacity(socket_count);
for socket_chunk in target_socket_refs.chunks(socket_chunk_size) {
let chunk_size = socket_chunk.len();
match self
.broadcast_semaphore
.acquire_many(chunk_size as u32)
.await
{
Ok(_permits) => {
let chunk_vec: Vec<_> = socket_chunk.to_vec();
let chunk_results: Vec<Result<()>> = stream::iter(chunk_vec)
.map(|socket_ref| {
let bytes = message_bytes.clone();
async move { socket_ref.send_broadcast(bytes) }
})
.buffer_unordered(chunk_size)
.collect()
.await;
results.extend(chunk_results);
}
Err(_) => {
for _ in 0..chunk_size {
results.push(Err(Error::Connection(
"Broadcast semaphore unavailable".to_string(),
)));
}
}
}
}
results
}
#[cfg(feature = "delta")]
async fn send_messages_with_compression(
&self,
target_socket_refs: Vec<WebSocketRef>,
base_message: PusherMessage,
base_message_bytes: Vec<u8>,
channel: &str,
event_name: &str,
compression: crate::connection_manager::CompressionParams<'_>,
) -> Vec<Result<()>> {
let delta_compression = compression.delta_compression;
let channel_settings = compression.channel_settings;
use futures::stream::{self, StreamExt};
let socket_count = target_socket_refs.len();
let tag_filtering = self.tag_filtering_enabled.load(Ordering::Acquire);
let mut socket_groups: HashMap<(String, u64), Vec<WebSocketRef>> = HashMap::new();
let mut non_delta_sockets = Vec::new();
for socket_ref in target_socket_refs {
let socket_id = socket_ref.get_socket_id_sync();
if delta_compression.is_enabled_for_socket_channel(socket_id, channel) {
debug!(
"Socket {} has delta compression enabled for channel {}",
socket_id, channel
);
let conflation_key_path = channel_settings
.and_then(|s| s.conflation_key.as_ref())
.or(delta_compression.get_conflation_key_path());
debug!(
"Conflation key path for socket {}: {:?}",
socket_id, conflation_key_path
);
let conflation_key = if let Some(path) = conflation_key_path {
let extracted = delta_compression
.extract_conflation_key_from_path(&base_message_bytes, path);
debug!(
"Extracted conflation key from path '{}': '{}' (base_message_bytes len={})",
path,
extracted,
base_message_bytes.len()
);
extracted
} else {
debug!("No conflation key path configured");
String::new()
};
let cache_key = if conflation_key.is_empty() {
event_name.to_string()
} else {
format!("{}:{}", event_name, conflation_key)
};
debug!(
"Looking for base message: socket={}, channel={}, cache_key='{}', event_name='{}'",
socket_id, channel, cache_key, event_name
);
let base_msg =
delta_compression.get_last_message_for_socket(socket_id, channel, &cache_key);
let base_msg_opt = base_msg.await;
debug!(
"Base message lookup result: socket={}, found={}",
socket_id,
base_msg_opt.is_some()
);
let base_hash = if let Some(base) = base_msg_opt {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
base.as_ref().hash(&mut hasher);
hasher.finish()
} else {
0 };
debug!(
"Socket {} grouped with conflation_key='{}', base_hash={} (has_base={})",
socket_id,
conflation_key,
base_hash,
base_hash != 0
);
socket_groups
.entry((conflation_key, base_hash))
.or_default()
.push(socket_ref);
} else {
debug!(
"Socket {} does NOT have delta compression enabled",
socket_id
);
non_delta_sockets.push(socket_ref);
}
}
type PrecomputedDelta = Option<(Arc<Vec<u8>>, u32)>;
let mut precomputed_deltas: HashMap<(String, u64), PrecomputedDelta> = HashMap::new();
debug!(
"Pre-computing deltas for {} socket groups ({} non-delta sockets)",
socket_groups.len(),
non_delta_sockets.len()
);
for ((conflation_key, base_hash), group_sockets) in &socket_groups {
if *base_hash == 0 {
debug!(
"Group (conflation_key='{}', base_hash=0) has {} sockets - will send FULL messages (first message)",
conflation_key,
group_sockets.len()
);
precomputed_deltas.insert((conflation_key.clone(), *base_hash), None);
continue;
}
if let Some(first_socket) = group_sockets.first() {
let socket_id = first_socket.get_socket_id_sync();
let cache_key = if conflation_key.is_empty() {
event_name.to_string()
} else {
format!("{}:{}", event_name, conflation_key)
};
let should_send_full =
delta_compression.should_send_full_message(socket_id, channel, &cache_key);
if should_send_full {
debug!(
"Group (conflation_key='{}', base_hash={}) has {} sockets - sending FULL message due to interval",
conflation_key,
base_hash,
group_sockets.len()
);
precomputed_deltas.insert((conflation_key.clone(), *base_hash), None);
continue;
}
if let Some((base_msg, base_sequence)) = delta_compression
.get_last_message_with_sequence(socket_id, channel, &cache_key)
.await
{
match delta_compression
.compute_delta_for_broadcast(&base_msg, &base_message_bytes)
{
Ok(delta) => {
if delta.len() < base_message_bytes.len() {
debug!(
"Group (conflation_key='{}', base_hash={}) has {} sockets - computed delta: {} bytes (vs {} bytes original, {:.1}% savings), base_seq={}",
conflation_key,
base_hash,
group_sockets.len(),
delta.len(),
base_message_bytes.len(),
(1.0 - (delta.len() as f64 / base_message_bytes.len() as f64))
* 100.0,
base_sequence
);
precomputed_deltas.insert(
(conflation_key.clone(), *base_hash),
Some((Arc::new(delta), base_sequence)),
);
} else {
debug!(
"Group (conflation_key='{}', base_hash={}) has {} sockets - delta NOT beneficial ({} >= {} bytes), sending FULL",
conflation_key,
base_hash,
group_sockets.len(),
delta.len(),
base_message_bytes.len()
);
precomputed_deltas
.insert((conflation_key.clone(), *base_hash), None);
}
}
Err(e) => {
warn!(
"Group (conflation_key='{}', base_hash={}) delta computation FAILED: {}, sending FULL",
conflation_key, base_hash, e
);
precomputed_deltas.insert((conflation_key.clone(), *base_hash), None);
}
}
} else {
precomputed_deltas.insert((conflation_key.clone(), *base_hash), None);
}
}
}
let target_chunks = socket_count.div_ceil(self.max_concurrent).clamp(1, 8);
let socket_chunk_size = (socket_count / target_chunks)
.min(self.max_concurrent)
.max(1);
let mut results = Vec::with_capacity(socket_count);
for ((conflation_key, base_hash), group_sockets) in socket_groups {
let precomputed_delta = precomputed_deltas
.get(&(conflation_key.clone(), base_hash))
.and_then(|d| d.as_ref())
.cloned();
for socket_chunk in group_sockets.chunks(socket_chunk_size) {
let chunk_size = socket_chunk.len();
match self
.broadcast_semaphore
.acquire_many(chunk_size as u32)
.await
{
Ok(_permits) => {
let chunk_vec: Vec<_> = socket_chunk.to_vec();
let base_msg = base_message.clone();
let base_bytes = base_message_bytes.clone();
let channel_str = channel.to_string();
let event_str = event_name.to_string();
let delta_comp = Arc::clone(&delta_compression);
let ch_settings = channel_settings.cloned();
let precomp_delta = precomputed_delta.clone();
let conf_key = conflation_key.clone();
let chunk_results: Vec<Result<()>> = stream::iter(chunk_vec)
.map(|socket_ref| {
let msg = base_msg.clone();
let bytes = base_bytes.clone();
let ch = channel_str.clone();
let ev = event_str.clone();
let dc = Arc::clone(&delta_comp);
let settings = ch_settings.clone();
let delta = precomp_delta.clone();
let ck = conf_key.clone();
async move {
Self::send_to_socket_with_precomputed_delta(
PrecomputedDeltaParams {
socket_ref,
base_message: msg,
base_message_bytes: bytes,
channel: &ch,
event_name: &ev,
delta_compression: dc,
channel_settings: settings.as_ref(),
tag_filtering_enabled: tag_filtering,
precomputed_delta: delta,
conflation_key: ck,
base_hash,
},
)
.await
}
})
.buffer_unordered(chunk_size)
.collect()
.await;
results.extend(chunk_results);
}
Err(_) => {
for _ in 0..chunk_size {
results.push(Err(Error::Connection(
"Broadcast semaphore unavailable".to_string(),
)));
}
}
}
}
}
for socket_chunk in non_delta_sockets.chunks(socket_chunk_size) {
let chunk_size = socket_chunk.len();
match self
.broadcast_semaphore
.acquire_many(chunk_size as u32)
.await
{
Ok(_permits) => {
let chunk_vec: Vec<_> = socket_chunk.to_vec();
let base_msg = base_message.clone();
let base_bytes = base_message_bytes.clone();
let channel_str = channel.to_string();
let event_str = event_name.to_string();
let delta_comp = Arc::clone(&delta_compression);
let ch_settings = channel_settings.cloned();
let chunk_results: Vec<Result<()>> = stream::iter(chunk_vec)
.map(|socket_ref| {
let msg = base_msg.clone();
let bytes = base_bytes.clone();
let ch = channel_str.clone();
let ev = event_str.clone();
let dc = Arc::clone(&delta_comp);
let settings = ch_settings.clone();
async move {
Self::send_to_socket_with_compression(SocketMessageParams {
socket_ref,
base_message: msg,
base_message_bytes: bytes,
channel: &ch,
event_name: &ev,
delta_compression: dc,
channel_settings: settings.as_ref(),
tag_filtering_enabled: tag_filtering,
})
.await
}
})
.buffer_unordered(chunk_size)
.collect()
.await;
results.extend(chunk_results);
}
Err(_) => {
for _ in 0..chunk_size {
results.push(Err(Error::Connection(
"Broadcast semaphore unavailable".to_string(),
)));
}
}
}
}
results
}
#[cfg(feature = "delta")]
async fn send_to_socket_with_compression(params: SocketMessageParams<'_>) -> Result<()> {
use sockudo_delta::CompressionResult;
let SocketMessageParams {
socket_ref,
base_message,
base_message_bytes,
channel,
event_name,
delta_compression,
channel_settings,
tag_filtering_enabled: _,
} = params;
let socket_id = socket_ref.get_socket_id_sync();
let (compression_result, message_with_sequence) = if !delta_compression
.is_enabled_for_socket_channel(socket_id, channel)
{
(CompressionResult::Uncompressed, base_message_bytes.clone())
} else {
let conflation_key_path = channel_settings
.and_then(|s| s.conflation_key.as_ref())
.or(delta_compression.get_conflation_key_path());
let conflation_key = if let Some(path) = conflation_key_path {
delta_compression.extract_conflation_key_from_path(&base_message_bytes, path)
} else {
String::new()
};
let cache_key = if conflation_key.is_empty() {
event_name.to_string()
} else {
format!("{}:{}", event_name, conflation_key)
};
let result = delta_compression
.compress_message(
socket_id,
channel,
event_name,
&base_message_bytes,
channel_settings,
)
.await?;
let next_sequence = delta_compression.get_next_sequence(socket_id, channel, &cache_key);
let msg_with_seq = if let Ok(base_str) = std::str::from_utf8(&base_message_bytes) {
if let Some(last_brace) = base_str.rfind('}') {
let mut result = String::with_capacity(base_str.len() + 100);
result.push_str(&base_str[..last_brace]);
if last_brace > 1
&& base_str[..last_brace]
.trim_end()
.ends_with(|c| c != '{' && c != ',')
{
result.push(',');
}
result.push_str(&format!("\"__delta_seq\":{}", next_sequence));
if !conflation_key.is_empty() {
result.push_str(&format!(",\"__conflation_key\":\"{}\"", conflation_key));
}
result.push('}');
result.into_bytes()
} else {
base_message_bytes.clone()
}
} else {
base_message_bytes.clone()
};
(result, msg_with_seq)
};
match compression_result {
CompressionResult::Uncompressed => socket_ref.send_message(&base_message).await,
CompressionResult::FullMessage {
sequence,
conflation_key,
} => {
info!(
"🔵 STORING FULL base message: seq={}, conflation_key={:?}, len={}, last50='{}'",
sequence,
conflation_key,
base_message_bytes.len(),
String::from_utf8_lossy(
&base_message_bytes[base_message_bytes.len().saturating_sub(50)..]
)
);
info!(
"🔵 SENDING FULL message: len={}, last100='{}'",
message_with_sequence.len(),
String::from_utf8_lossy(
&message_with_sequence[message_with_sequence.len().saturating_sub(100)..]
)
);
if let Err(e) = delta_compression
.store_sent_message(
socket_id,
channel,
event_name,
base_message_bytes.clone(),
true,
channel_settings,
)
.await
{
warn!("Failed to store full message for delta state: {e}");
}
let mut full_message = base_message;
full_message.delta_sequence = Some(sequence.into());
full_message.delta_conflation_key = conflation_key;
socket_ref.send_message(&full_message).await
}
CompressionResult::Delta {
delta,
sequence,
algorithm,
conflation_key,
base_index,
} => {
info!(
"🔴 SENDING DELTA: seq={}, base_index={:?}, conflation_key={:?}, delta_len={}, new_msg_len={}, new_msg_last50='{}'",
sequence,
base_index,
conflation_key,
delta.len(),
base_message_bytes.len(),
String::from_utf8_lossy(
&base_message_bytes[base_message_bytes.len().saturating_sub(50)..]
)
);
let algorithm_str = match algorithm {
sockudo_delta::DeltaAlgorithm::Fossil => "fossil",
sockudo_delta::DeltaAlgorithm::Xdelta3 => "xdelta3",
};
let mut delta_data = sonic_rs::json!({
"event": event_name,
"delta": base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
&delta,
),
"seq": sequence,
"algorithm": algorithm_str,
});
if let Some(key) = conflation_key.clone() {
delta_data["conflation_key"] = sonic_rs::Value::from(&key);
}
if let Some(idx) = base_index {
delta_data["base_index"] = sonic_rs::Value::from(idx as u64);
}
let mut pusher_msg = PusherMessage {
event: Some("pusher:delta".to_string()),
channel: Some(channel.to_string()),
data: Some(sockudo_protocol::messages::MessageData::Json(delta_data)),
name: None,
user_id: None,
tags: None,
sequence: None,
conflation_key: None,
message_id: None,
stream_id: None,
serial: None,
idempotency_key: None,
extras: None,
delta_sequence: None,
delta_conflation_key: None,
};
if socket_ref.protocol_version == sockudo_protocol::ProtocolVersion::V2 {
pusher_msg.rewrite_prefix(sockudo_protocol::ProtocolVersion::V2);
}
if let Err(e) = delta_compression
.store_sent_message(
socket_id,
channel,
event_name,
base_message_bytes.clone(),
false,
channel_settings,
)
.await
{
warn!("Failed to store delta base message state: {e}");
}
socket_ref.send_message(&pusher_msg).await
}
}
}
#[cfg(feature = "delta")]
async fn send_to_socket_with_precomputed_delta(
params: PrecomputedDeltaParams<'_>,
) -> Result<()> {
let PrecomputedDeltaParams {
socket_ref,
base_message,
base_message_bytes,
channel,
event_name,
delta_compression,
channel_settings,
tag_filtering_enabled: _,
precomputed_delta,
conflation_key,
base_hash,
} = params;
let socket_id = socket_ref.get_socket_id_sync();
let cache_key = if conflation_key.is_empty() {
event_name.to_string()
} else {
format!("{}:{}", event_name, conflation_key)
};
let sequence = delta_compression.get_next_sequence(socket_id, channel, &cache_key);
let (can_use_precomputed, actual_base_sequence) = if let Some((socket_base, base_seq)) =
delta_compression
.get_last_message_with_sequence(socket_id, channel, &cache_key)
.await
{
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
socket_base.as_ref().hash(&mut hasher);
let socket_base_hash = hasher.finish();
(socket_base_hash == base_hash, Some(base_seq))
} else {
(false, None)
};
if let Some((delta_bytes, _precomputed_base_seq)) = precomputed_delta.filter(|_| {
sequence >= 2
&& can_use_precomputed
&& !delta_compression.should_send_full_message(socket_id, channel, &cache_key)
}) {
let base_sequence =
actual_base_sequence.unwrap_or(if sequence > 0 { sequence - 1 } else { 0 });
debug!(
"Sending DELTA message to socket {} on channel {} (seq={}, base_seq={}, delta_size={} bytes)",
socket_id,
channel,
sequence,
base_sequence,
delta_bytes.len()
);
let algorithm = delta_compression.get_algorithm();
let omit_algorithm = delta_compression.config.omit_delta_algorithm;
let algorithm_str = match algorithm {
sockudo_delta::DeltaAlgorithm::Fossil => "fossil",
sockudo_delta::DeltaAlgorithm::Xdelta3 => "xdelta3",
};
let base_index = base_sequence as usize;
let mut delta_data = sonic_rs::json!({
"event": event_name,
"delta": base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
delta_bytes.as_ref(),
),
"seq": sequence,
});
if !omit_algorithm {
delta_data["algorithm"] = sonic_rs::json!(algorithm_str);
}
if !conflation_key.is_empty() {
delta_data["conflation_key"] = sonic_rs::Value::from(&conflation_key);
}
delta_data["base_index"] = sonic_rs::Value::from(base_index as u64);
let mut pusher_msg = PusherMessage {
event: Some("pusher:delta".to_string()),
channel: Some(channel.to_string()),
data: Some(sockudo_protocol::messages::MessageData::Json(delta_data)),
name: None,
user_id: None,
tags: None, sequence: None,
conflation_key: None,
message_id: None,
stream_id: None,
serial: None,
idempotency_key: None,
extras: None,
delta_sequence: None,
delta_conflation_key: None,
};
if socket_ref.protocol_version == sockudo_protocol::ProtocolVersion::V2 {
pusher_msg.rewrite_prefix(sockudo_protocol::ProtocolVersion::V2);
}
let _ = delta_compression
.store_sent_message(
socket_id,
channel,
event_name,
base_message_bytes.clone(),
false,
channel_settings,
)
.await;
socket_ref.send_message(&pusher_msg).await
} else {
debug!(
"Sending FULL message to socket {} on channel {} (seq={}, size={} bytes)",
socket_id,
channel,
sequence,
base_message_bytes.len()
);
let _msg_with_seq = if let Ok(base_str) = std::str::from_utf8(&base_message_bytes) {
if let Some(last_brace) = base_str.rfind('}') {
let mut result = String::with_capacity(base_str.len() + 100);
result.push_str(&base_str[..last_brace]);
if last_brace > 1
&& base_str[..last_brace]
.trim_end()
.ends_with(|c| c != '{' && c != ',')
{
result.push(',');
}
result.push_str(&format!("\"__delta_seq\":{}", sequence));
if !conflation_key.is_empty() {
result.push_str(&format!(",\"__conflation_key\":\"{}\"", conflation_key));
}
result.push('}');
result.into_bytes()
} else {
base_message_bytes.clone()
}
} else {
base_message_bytes.clone()
};
match delta_compression
.store_sent_message(
socket_id,
channel,
event_name,
base_message_bytes.clone(),
true,
channel_settings,
)
.await
{
Ok(_) => {
debug!(
"Successfully stored FULL message for socket {} on channel {} (seq={}, will be base for future deltas)",
socket_id, channel, sequence
);
}
Err(e) => {
warn!(
"Failed to store FULL message for socket {} on channel {}: {}",
socket_id, channel, e
);
}
}
let mut full_message = base_message;
full_message.delta_sequence = Some(sequence.into());
if !conflation_key.is_empty() {
full_message.delta_conflation_key = Some(conflation_key);
}
socket_ref.send_message(&full_message).await
}
}
async fn get_or_create_namespace(&self, app_id: &str) -> Arc<Namespace> {
self.namespaces
.entry(app_id.to_string())
.or_insert_with(|| Arc::new(Namespace::new(app_id.to_string())))
.clone()
}
fn existing_namespace(&self, app_id: &str) -> Option<Arc<Namespace>> {
self.namespaces
.get(app_id)
.map(|namespace| namespace.value().clone())
}
fn should_assign_v2_message_id(message: &PusherMessage) -> bool {
message.message_id.is_none() && !message.is_protocol_ping_or_pong()
}
fn v1_compatible_message(message: &PusherMessage) -> Option<PusherMessage> {
if extract_runtime_action(message).is_some() {
return None;
}
let mut v1_message = message.clone();
v1_message.serial = None;
v1_message.message_id = None;
v1_message.stream_id = None;
v1_message.tags = None;
v1_message.idempotency_key = None;
v1_message.extras = None;
Some(v1_message)
}
async fn send_to_v1_sockets(
&self,
sockets: Vec<WebSocketRef>,
message: &PusherMessage,
) -> Result<()> {
if sockets.is_empty() {
return Ok(());
}
let Some(v1_message) = Self::v1_compatible_message(message) else {
return Ok(());
};
let v1_bytes = Bytes::from(
sonic_rs::to_vec(&v1_message)
.map_err(|e| Error::InvalidMessageFormat(format!("Serialization failed: {e}")))?,
);
Self::log_send_errors(self.send_messages_concurrent(sockets, v1_bytes).await);
Ok(())
}
fn log_send_errors(results: Vec<Result<()>>) {
for r in results {
if let Err(e) = r {
match &e {
Error::ConnectionClosed(_) => debug!("Send error: {}", e),
_ => warn!("Send error: {}", e),
}
}
}
}
fn filter_v2_sockets_in_place(
&self,
channel: &str,
message: &PusherMessage,
v2_sockets: &mut Vec<WebSocketRef>,
except: Option<&SocketId>,
namespace: &Namespace,
) {
#[cfg(feature = "tag-filtering")]
{
crate::v2_broadcast::apply_tag_filter_in_place(
&self.filter_index,
self.tag_filtering_enabled.load(Ordering::Acquire),
channel,
message,
v2_sockets,
except,
namespace,
)
}
#[cfg(not(feature = "tag-filtering"))]
{
let _ = (channel, message, v2_sockets, except, namespace);
}
}
#[cfg(feature = "delta")]
fn filter_v2_sockets_strict_in_place(
&self,
channel: &str,
message: &PusherMessage,
v2_sockets: &mut Vec<WebSocketRef>,
except: Option<&SocketId>,
namespace: &Namespace,
) {
#[cfg(feature = "tag-filtering")]
{
crate::v2_broadcast::apply_tag_filter_v2_only_in_place(
&self.filter_index,
self.tag_filtering_enabled.load(Ordering::Acquire),
channel,
message,
v2_sockets,
except,
namespace,
)
}
#[cfg(not(feature = "tag-filtering"))]
{
let _ = (channel, message, v2_sockets, except, namespace);
}
}
#[cfg(feature = "delta")]
fn maybe_strip_tags(
&self,
message: PusherMessage,
_channel_settings: Option<&sockudo_delta::ChannelDeltaSettings>,
) -> PusherMessage {
#[cfg(feature = "tag-filtering")]
{
let global_enable_tags = self.enable_tags_globally.load(Ordering::Acquire);
let enable_tags =
crate::v2_broadcast::should_enable_tags(_channel_settings, global_enable_tags);
crate::v2_broadcast::strip_tags_if_disabled(message, enable_tags)
}
#[cfg(not(feature = "tag-filtering"))]
{
message
}
}
async fn split_rewind_gated_sockets_in_place(
&self,
channel: &str,
message: &PusherMessage,
sockets: &mut Vec<WebSocketRef>,
) {
let mut index = 0;
while index < sockets.len() {
if sockets[index].buffer_rewind_message(channel, message).await {
sockets.swap_remove(index);
} else {
index += 1;
}
}
}
pub async fn get_all_connections(&self, app_id: &str) -> Vec<SocketId> {
self.existing_namespace(app_id)
.map(|namespace| namespace.sockets.iter().map(|entry| *entry.key()).collect())
.unwrap_or_default()
}
pub fn join_channel_fast(
&self,
app_id: &str,
channel: &str,
socket_id: &SocketId,
) -> Option<usize> {
let t_start = std::time::Instant::now();
let t_before_ns_get = t_start.elapsed().as_nanos();
let namespace = self.namespaces.get(app_id)?;
let t_after_ns_get = t_start.elapsed().as_nanos();
let t_before_socket_check = t_start.elapsed().as_nanos();
if !namespace.sockets.contains_key(socket_id) {
debug!(
"PERF[FAST_PATH_FAIL] channel={} socket={} reason=socket_not_found at={}ns",
channel, socket_id, t_before_socket_check
);
return None;
}
let t_after_socket_check = t_start.elapsed().as_nanos();
let t_before_chan_check = t_start.elapsed().as_nanos();
if namespace.is_in_channel(channel, socket_id) {
let t_before_count = t_start.elapsed().as_nanos();
let count = namespace.get_channel_socket_count(channel);
let t_after_count = t_start.elapsed().as_nanos();
debug!(
"PERF[FAST_PATH_ALREADY] channel={} socket={} total={}ns ns_get={}ns socket_check={}ns chan_check={}ns count={}ns",
channel,
socket_id,
t_after_count,
t_after_ns_get - t_before_ns_get,
t_after_socket_check - t_before_socket_check,
t_before_count - t_before_chan_check,
t_after_count - t_before_count
);
return Some(count);
}
let t_after_chan_check = t_start.elapsed().as_nanos();
let t_before_add = t_start.elapsed().as_nanos();
namespace.add_channel_to_socket(channel, socket_id);
let t_after_add = t_start.elapsed().as_nanos();
let t_before_count = t_start.elapsed().as_nanos();
let count = namespace.get_channel_socket_count(channel);
let t_after_count = t_start.elapsed().as_nanos();
debug!(
"PERF[FAST_PATH_NEW] channel={} socket={} total={}ns ns_get={}ns socket_check={}ns chan_check={}ns add={}ns count={}ns",
channel,
socket_id,
t_after_count,
t_after_ns_get - t_before_ns_get,
t_after_socket_check - t_before_socket_check,
t_after_chan_check - t_before_chan_check,
t_after_add - t_before_add,
t_after_count - t_before_count
);
Some(count)
}
}
#[async_trait]
impl ConnectionManager for LocalAdapter {
async fn init(&self) {
info!("Initializing local adapter");
}
async fn get_namespace(&self, app_id: &str) -> Option<Arc<Namespace>> {
self.existing_namespace(app_id)
}
async fn add_socket(
&self,
socket_id: SocketId,
socket: WebSocketWriter,
app_id: &str,
app_manager: Arc<dyn AppManager + Send + Sync>,
buffer_config: sockudo_core::websocket::WebSocketBufferConfig,
protocol_version: sockudo_protocol::ProtocolVersion,
wire_format: sockudo_protocol::WireFormat,
echo_messages: bool,
) -> Result<()> {
debug!(
"LocalAdapter::add_socket: adding socket {} for app {}",
&socket_id, app_id
);
let namespace = self.get_or_create_namespace(app_id).await;
let socket_id_clone = socket_id;
namespace
.add_socket(
socket_id,
socket,
app_manager,
SocketInitOptions {
buffer_config,
protocol_version,
wire_format,
echo_messages,
},
)
.await?;
debug!(
"LocalAdapter::add_socket: successfully added socket {} for app {}",
socket_id_clone, app_id
);
Ok(())
}
async fn get_connection(&self, socket_id: &SocketId, app_id: &str) -> Option<WebSocketRef> {
debug!(
"LocalAdapter::get_connection: looking for socket {} in app {}",
socket_id, app_id
);
let result = self
.existing_namespace(app_id)
.and_then(|namespace| namespace.get_connection(socket_id));
debug!(
"LocalAdapter::get_connection: socket {} in app {} found: {}",
socket_id,
app_id,
result.is_some()
);
result
}
async fn remove_connection(&self, socket_id: &SocketId, app_id: &str) -> Result<()> {
if let Some(namespace) = self.namespaces.get(app_id) {
namespace.remove_connection(socket_id);
Ok(())
} else {
Err(Error::Connection("Namespace not found".to_string()))
}
}
async fn send_message(
&self,
app_id: &str,
socket_id: &SocketId,
message: PusherMessage,
) -> Result<()> {
let connection = self
.get_connection(socket_id, app_id)
.await
.ok_or_else(|| Error::Connection("Connection not found".to_string()))?;
match connection.protocol_version {
sockudo_protocol::ProtocolVersion::V2 => {
let mut rewritten = message;
if Self::should_assign_v2_message_id(&rewritten) {
rewritten.message_id = Some(generate_message_id());
}
rewritten.rewrite_prefix(sockudo_protocol::ProtocolVersion::V2);
connection.send_message(&rewritten).await
}
sockudo_protocol::ProtocolVersion::V1 => match Self::v1_compatible_message(&message) {
Some(v1_msg) => connection.send_message(&v1_msg).await,
None => Ok(()),
},
}
}
async fn send(
&self,
channel: &str,
message: PusherMessage,
except: Option<&SocketId>,
app_id: &str,
_start_time_ms: Option<f64>,
) -> Result<()> {
debug!("Sending message to channel: {}", channel);
debug!("Message: {:?}", message);
#[cfg(feature = "delta")]
{
let delta_compression = self.delta_compression.get();
let app_manager = self.app_manager.get();
if let (Some(delta_compression), Some(app_manager)) = (delta_compression, app_manager) {
if let Ok(Some(app)) = app_manager.find_by_id(app_id).await {
let channel_settings = app
.channel_delta_compression_ref()
.and_then(|map| map.get(channel))
.and_then(|config| {
use sockudo_delta::ChannelDeltaConfig;
match config {
ChannelDeltaConfig::Full(settings) => Some(settings.clone()),
_ => None,
}
});
if channel_settings
.as_ref()
.and_then(|s| s.conflation_key.as_ref())
.is_some()
{
return self
.send_with_compression(
channel,
message,
except,
app_id,
_start_time_ms,
crate::connection_manager::CompressionParams {
delta_compression: Arc::clone(delta_compression),
channel_settings: channel_settings.as_ref(),
},
)
.await;
}
}
}
}
let Some(namespace) = self.get_namespace(app_id).await else {
return Ok(());
};
let (v1_all_sockets, v2_target_sockets) = if channel.starts_with("#server-to-user-") {
let user_id = channel.trim_start_matches("#server-to-user-");
let socket_refs = namespace.get_user_sockets(user_id).await?;
let mut target_refs = Vec::new();
for socket_ref in socket_refs.iter() {
let socket_id = socket_ref.get_socket_id_sync();
if except != Some(socket_id) {
target_refs.push(socket_ref.clone());
}
}
target_refs
.into_iter()
.partition(|s| s.protocol_version == sockudo_protocol::ProtocolVersion::V1)
} else {
namespace.get_matching_channel_socket_refs_partitioned_except(channel, except)
};
if !is_v2_annotation_protocol_event(&message) {
self.send_to_v1_sockets(v1_all_sockets, &message).await?;
}
let mut filtered_socket_refs = v2_target_sockets;
self.filter_v2_sockets_in_place(
channel,
&message,
&mut filtered_socket_refs,
except,
&namespace,
);
filter_annotation_subscribers_in_place(channel, &message, &mut filtered_socket_refs);
crate::v2_broadcast::apply_event_name_filter_in_place(
channel,
&message,
&mut filtered_socket_refs,
);
self.split_rewind_gated_sockets_in_place(channel, &message, &mut filtered_socket_refs)
.await;
if !filtered_socket_refs.is_empty() {
let (v2_message, _v2_bytes) = crate::v2_broadcast::prepare_v2_message(message)?;
Self::log_send_errors(
self.send_protocol_messages_concurrent(filtered_socket_refs, v2_message)
.await,
);
}
Ok(())
}
#[cfg(feature = "delta")]
async fn send_with_compression(
&self,
channel: &str,
message: PusherMessage,
except: Option<&SocketId>,
app_id: &str,
_start_time_ms: Option<f64>,
compression: crate::connection_manager::CompressionParams<'_>,
) -> Result<()> {
let delta_compression = compression.delta_compression;
let channel_settings = compression.channel_settings;
debug!(
"Sending message to channel with compression support: {}",
channel
);
debug!("Message: {:?}", message);
let Some(namespace) = self.get_namespace(app_id).await else {
return Ok(());
};
let (v1_all_sockets, v2_target_sockets) = if channel.starts_with("#server-to-user-") {
let user_id = channel.trim_start_matches("#server-to-user-");
let socket_refs = namespace.get_user_sockets(user_id).await?;
let mut target_refs = Vec::new();
for socket_ref in socket_refs.iter() {
let socket_id = socket_ref.get_socket_id_sync();
if except != Some(socket_id) {
target_refs.push(socket_ref.clone());
}
}
target_refs
.into_iter()
.partition(|s| s.protocol_version == sockudo_protocol::ProtocolVersion::V1)
} else {
namespace.get_matching_channel_socket_refs_partitioned_except(channel, except)
};
if !is_v2_annotation_protocol_event(&message) {
self.send_to_v1_sockets(v1_all_sockets, &message).await?;
}
let mut filtered_socket_refs = v2_target_sockets;
self.filter_v2_sockets_strict_in_place(
channel,
&message,
&mut filtered_socket_refs,
except,
&namespace,
);
filter_annotation_subscribers_in_place(channel, &message, &mut filtered_socket_refs);
crate::v2_broadcast::apply_event_name_filter_in_place(
channel,
&message,
&mut filtered_socket_refs,
);
self.split_rewind_gated_sockets_in_place(channel, &message, &mut filtered_socket_refs)
.await;
let message = self.maybe_strip_tags(message, channel_settings);
if !filtered_socket_refs.is_empty() {
let mut v2_message = message;
v2_message.rewrite_prefix(sockudo_protocol::ProtocolVersion::V2);
v2_message.idempotency_key = None;
let v2_event_name = v2_message.event.as_deref().unwrap_or("").to_string();
let v2_bytes = sonic_rs::to_vec(&v2_message)
.map_err(|e| Error::InvalidMessageFormat(format!("Serialization failed: {e}")))?;
Self::log_send_errors(
self.send_messages_with_compression(
filtered_socket_refs,
v2_message,
v2_bytes,
channel,
&v2_event_name,
crate::connection_manager::CompressionParams {
delta_compression,
channel_settings,
},
)
.await,
);
}
Ok(())
}
async fn get_channel_members(
&self,
app_id: &str,
channel: &str,
) -> Result<HashMap<String, PresenceMemberInfo>> {
match self.existing_namespace(app_id) {
Some(namespace) => namespace.get_channel_members(channel).await,
None => Ok(HashMap::new()),
}
}
async fn get_channel_sockets(&self, app_id: &str, channel: &str) -> Result<Vec<SocketId>> {
Ok(self
.existing_namespace(app_id)
.map(|namespace| namespace.get_channel_sockets(channel))
.unwrap_or_default())
}
async fn remove_channel(&self, app_id: &str, channel: &str) {
#[cfg(feature = "tag-filtering")]
self.filter_index.clear_channel(channel);
let namespace = self.get_or_create_namespace(app_id).await;
namespace.remove_channel(channel);
if namespace.sockets.is_empty()
&& namespace.channels.is_empty()
&& namespace.users.is_empty()
{
self.namespaces.remove(app_id);
debug!(
"Removed empty namespace for app_id: {} after channel removal",
app_id
);
}
}
async fn is_in_channel(
&self,
app_id: &str,
channel: &str,
socket_id: &SocketId,
) -> Result<bool> {
Ok(self
.existing_namespace(app_id)
.is_some_and(|namespace| namespace.is_in_channel(channel, socket_id)))
}
async fn get_user_sockets(&self, user_id: &str, app_id: &str) -> Result<Vec<WebSocketRef>> {
match self.existing_namespace(app_id) {
Some(namespace) => namespace.get_user_sockets(user_id).await,
None => Ok(Vec::new()),
}
}
async fn cleanup_connection(&self, app_id: &str, ws: WebSocketRef) {
let namespace = self.get_or_create_namespace(app_id).await;
namespace.cleanup_connection(ws).await;
if namespace.sockets.is_empty()
&& namespace.channels.is_empty()
&& namespace.users.is_empty()
{
self.namespaces.remove(app_id);
debug!("Removed empty namespace for app_id: {}", app_id);
}
}
async fn terminate_connection(&self, app_id: &str, user_id: &str) -> Result<()> {
let namespace = self.get_or_create_namespace(app_id).await;
if let Err(e) = namespace.terminate_user_connections(user_id).await {
error!("Failed to terminate adapter: {}", e);
}
Ok(())
}
async fn add_channel_to_sockets(&self, app_id: &str, channel: &str, socket_id: &SocketId) {
let namespace = self.get_or_create_namespace(app_id).await;
namespace.add_channel_to_socket(channel, socket_id);
}
async fn get_channel_socket_count_info(
&self,
app_id: &str,
channel: &str,
) -> crate::connection_manager::ChannelSocketCount {
crate::connection_manager::ChannelSocketCount {
count: self.get_channel_socket_count(app_id, channel).await,
complete: true,
}
}
async fn get_channel_socket_count(&self, app_id: &str, channel: &str) -> usize {
self.existing_namespace(app_id)
.map(|namespace| namespace.get_channel_socket_count(channel))
.unwrap_or(0)
}
async fn add_to_channel(
&self,
app_id: &str,
channel: &str,
socket_id: &SocketId,
) -> Result<bool> {
let t_start = std::time::Instant::now();
let t_before_ns = t_start.elapsed().as_micros();
let namespace = self.get_or_create_namespace(app_id).await;
let t_after_ns = t_start.elapsed().as_micros();
let t_before_add = t_start.elapsed().as_micros();
let result = namespace.add_channel_to_socket(channel, socket_id);
let t_after_add = t_start.elapsed().as_micros();
debug!(
"PERF[LOCAL_ADD_CHAN] channel={} socket={} total={}μs get_ns={}μs add={}μs",
channel,
socket_id,
t_after_add,
t_after_ns - t_before_ns,
t_after_add - t_before_add
);
Ok(result)
}
async fn remove_from_channel(
&self,
app_id: &str,
channel: &str,
socket_id: &SocketId,
) -> Result<bool> {
let namespace = self.get_or_create_namespace(app_id).await;
#[cfg(feature = "tag-filtering")]
if let Some(socket_ref) = namespace.sockets.get(socket_id) {
let filter_node = socket_ref.get_channel_filter_sync(channel);
self.filter_index
.remove_socket_filter(channel, *socket_id, filter_node.as_deref());
socket_ref.channel_filters.remove(channel);
}
Ok(namespace.remove_channel_from_socket(channel, socket_id))
}
async fn get_presence_member(
&self,
app_id: &str,
channel: &str,
socket_id: &SocketId,
) -> Option<PresenceMemberInfo> {
let namespace = self.get_or_create_namespace(app_id).await;
namespace.get_presence_member(channel, socket_id).await
}
async fn terminate_user_connections(&self, app_id: &str, user_id: &str) -> Result<()> {
let namespace = self.get_or_create_namespace(app_id).await;
if let Err(e) = namespace.terminate_user_connections(user_id).await {
error!("Failed to terminate user connections: {}", e);
}
Ok(())
}
async fn add_user(&self, ws_ref: WebSocketRef) -> Result<()> {
let app_id = {
let ws_guard = ws_ref.inner.lock().await;
ws_guard.state.get_app_id()
};
let namespace = self.get_or_create_namespace(&app_id).await;
namespace.add_user(ws_ref).await
}
async fn remove_user(&self, ws_ref: WebSocketRef) -> Result<()> {
let app_id = {
let ws_guard = ws_ref.inner.lock().await;
ws_guard.state.get_app_id()
};
match self.existing_namespace(&app_id) {
Some(namespace) => namespace.remove_user(ws_ref).await,
None => Ok(()),
}
}
async fn remove_user_socket(
&self,
user_id: &str,
socket_id: &SocketId,
app_id: &str,
) -> Result<()> {
match self.existing_namespace(app_id) {
Some(namespace) => namespace.remove_user_socket(user_id, socket_id).await,
None => Ok(()),
}
}
async fn count_user_connections_in_channel(
&self,
user_id: &str,
app_id: &str,
channel: &str,
excluding_socket: Option<&SocketId>,
) -> Result<usize> {
match self.existing_namespace(app_id) {
Some(namespace) => {
namespace
.count_user_connections_in_channel(user_id, channel, excluding_socket)
.await
}
None => Ok(0),
}
}
async fn get_channels_with_socket_count(&self, app_id: &str) -> Result<HashMap<String, usize>> {
match self.existing_namespace(app_id) {
Some(namespace) => namespace.get_channels_with_socket_count().await,
None => Ok(HashMap::new()),
}
}
async fn get_sockets_count(&self, app_id: &str) -> Result<usize> {
if let Some(namespace) = self.namespaces.get(app_id) {
let count = namespace.sockets.len();
Ok(count)
} else {
Ok(0) }
}
async fn get_namespaces(&self) -> Result<Vec<(String, Arc<Namespace>)>> {
Ok(self
.namespaces
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect())
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
async fn check_health(&self) -> Result<()> {
Ok(())
}
fn get_node_id(&self) -> String {
"local".to_string()
}
fn as_horizontal_adapter(
&self,
) -> Option<&dyn crate::connection_manager::HorizontalAdapterInterface> {
None
}
}
#[cfg(test)]
mod tests {
use super::LocalAdapter;
use crate::ConnectionManager;
use sockudo_protocol::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};
use std::collections::HashMap;
#[test]
fn v1_compatible_message_strips_v2_only_fields_for_plain_messages() {
let message = PusherMessage {
event: Some("chat.message".to_string()),
channel: Some("room".to_string()),
data: Some(MessageData::String("hello".to_string())),
name: Some("chat.message".to_string()),
user_id: None,
tags: None,
sequence: None,
conflation_key: None,
message_id: Some("mid-1".to_string()),
stream_id: Some("stream-1".to_string()),
serial: Some(9),
idempotency_key: Some("idem-1".to_string()),
extras: Some(MessageExtras {
headers: Some(HashMap::from([(
"note".to_string(),
ExtrasValue::String("ok".to_string()),
)])),
ephemeral: Some(true),
idempotency_key: Some("extra-idem".to_string()),
push: None,
echo: Some(false),
}),
delta_sequence: None,
delta_conflation_key: None,
};
let v1 = LocalAdapter::v1_compatible_message(&message).unwrap();
assert_eq!(v1.event.as_deref(), Some("chat.message"));
assert!(v1.serial.is_none());
assert!(v1.message_id.is_none());
assert!(v1.stream_id.is_none());
assert!(v1.idempotency_key.is_none());
assert!(v1.extras.is_none());
}
#[test]
fn v1_compatible_message_drops_versioned_mutation_events() {
let mut message = PusherMessage::channel_event(
"sockudo:message.update",
"room",
sonic_rs::json!({"text": "patched"}),
);
message.extras = Some(MessageExtras {
headers: Some(HashMap::from([
(
"sockudo_action".to_string(),
ExtrasValue::String("message.update".to_string()),
),
(
"sockudo_message_serial".to_string(),
ExtrasValue::String("msg:1".to_string()),
),
])),
..Default::default()
});
message.serial = Some(11);
message.stream_id = Some("stream-1".to_string());
assert!(LocalAdapter::v1_compatible_message(&message).is_none());
}
#[test]
fn v2_runtime_message_id_skips_protocol_heartbeats() {
assert!(!LocalAdapter::should_assign_v2_message_id(
&PusherMessage::ping()
));
assert!(!LocalAdapter::should_assign_v2_message_id(
&PusherMessage::pong()
));
}
#[test]
fn v2_runtime_message_id_still_assigns_regular_messages() {
let message = PusherMessage::channel_event(
"chat.message",
"room",
sonic_rs::json!({"text": "hello"}),
);
assert!(LocalAdapter::should_assign_v2_message_id(&message));
}
#[tokio::test]
async fn read_only_queries_do_not_create_empty_namespaces() {
let adapter = LocalAdapter::new();
assert_eq!(
ConnectionManager::get_channel_socket_count(&adapter, "missing-app", "room").await,
0
);
assert!(
ConnectionManager::get_channel_sockets(&adapter, "missing-app", "room")
.await
.unwrap()
.is_empty()
);
assert_eq!(adapter.namespaces.len(), 0);
}
}