use super::ConnectionHandler;
use super::annotations::clipped_contributor_count;
use super::types::*;
use crate::channel_manager::ChannelManager;
use crate::channel_manager::JoinResponse;
use sockudo_core::annotations::AnnotationProjectionsForChannelRequest;
use sockudo_core::app::App;
use sockudo_core::channel::{ChannelType, PresenceMemberInfo};
use sockudo_core::error::Result;
use sockudo_core::history::{HistoryDirection, HistoryItem, HistoryReadRequest, now_ms};
#[cfg(feature = "delta")]
use sockudo_delta::DeltaCompressionManager;
use ahash::AHashMap;
use sockudo_core::utils::is_cache_channel;
use sockudo_core::websocket::SocketId;
use sockudo_protocol::ProtocolVersion;
use sockudo_protocol::messages::{
AnnotationSummaryEnvelope, MESSAGE_SUMMARY_EVENT_NAME, MessageData, MessageExtras,
MessageSummaryData, PresenceData, PusherMessage,
};
use sonic_rs::Value;
use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
/// Outcome of `ConnectionHandler::execute_subscription`.
///
/// Copied from the join response produced either by the channel manager or by
/// the local adapter's fast-path join (which only fills `success` and
/// `channel_connections`).
#[derive(Debug, Clone)]
pub struct SubscriptionResult {
    /// Whether the socket joined the channel.
    pub success: bool,
    /// Authentication error reported by the join, if any.
    pub auth_error: Option<String>,
    /// Presence member for presence-channel joins, copied from the join response.
    pub member: Option<PresenceMember>,
    /// Connection count reported by the join. Code in this file treats
    /// `Some(1)` as "first connection on the channel" (active-channel metrics,
    /// `channel_occupied` webhook); presumably the count after this join — confirm.
    pub channel_connections: Option<usize>,
}
/// A presence-channel member attached to a successful subscription.
#[derive(Debug, Clone)]
pub struct PresenceMember {
    /// The member's user id (converted with `to_string()` from the join response).
    pub user_id: String,
    /// The member's user info; used as this member's entry in the presence
    /// `hash` sent with `subscription_succeeded`.
    pub user_info: Value,
}
impl ConnectionHandler {
pub async fn execute_subscription(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
is_authenticated: bool,
) -> Result<SubscriptionResult> {
let t_start = std::time::Instant::now();
let t_before_msg_create = t_start.elapsed().as_micros();
let temp_message = PusherMessage {
channel: Some(request.channel.clone()),
event: Some("pusher:subscribe".to_string()),
data: Some(MessageData::Json(sonic_rs::json!({
"channel": request.channel,
"auth": request.auth,
"channel_data": request.channel_data
}))),
name: None,
user_id: None,
tags: None,
sequence: None,
conflation_key: None,
message_id: None,
stream_id: None,
serial: None,
idempotency_key: None,
extras: None,
delta_sequence: None,
delta_conflation_key: None,
};
let t_after_msg_create = t_start.elapsed().as_micros();
let subscription_result = if let Some(ref local_adapter) = self.local_adapter {
let t_before_channel_type = t_start.elapsed().as_micros();
let channel_type = ChannelType::from_name(&request.channel);
let t_after_channel_type = t_start.elapsed().as_micros();
let can_use_fast_path = channel_type != ChannelType::Presence
&& (!channel_type.requires_authentication() || is_authenticated);
if can_use_fast_path {
let t_before_fast = t_start.elapsed().as_micros();
let mut fast_result =
local_adapter.join_channel_fast(&app_config.id, &request.channel, socket_id);
let mut retry_count = 0;
while fast_result.is_none() && retry_count < 3 {
tokio::task::yield_now().await;
fast_result = local_adapter.join_channel_fast(
&app_config.id,
&request.channel,
socket_id,
);
retry_count += 1;
}
match fast_result {
Some(channel_connections) => {
let t_after_fast = t_start.elapsed().as_micros();
tracing::debug!(
"PERF[FAST_PATH] socket_id={} channel={} total={}μs channel_type={}μs",
socket_id,
request.channel,
t_after_fast - t_before_fast,
t_after_channel_type - t_before_channel_type
);
JoinResponse {
success: true,
channel_connections: Some(channel_connections),
member: None,
auth_error: None,
error_message: None,
error_code: None,
_type: None,
}
}
None => {
let t_before_fallback = t_start.elapsed().as_micros();
tracing::debug!(
"PERF[FAST_PATH_FALLBACK] socket_id={} channel={} fallback_at={}μs",
socket_id,
request.channel,
t_before_fallback
);
let result = ChannelManager::subscribe(
&self.connection_manager,
&socket_id.to_string(),
&temp_message,
&request.channel,
is_authenticated,
&app_config.id,
)
.await?;
let t_after_fallback = t_start.elapsed().as_micros();
tracing::debug!(
"PERF[FALLBACK_DONE] socket_id={} channel={} fallback_time={}μs",
socket_id,
request.channel,
t_after_fallback - t_before_fallback
);
result
}
}
} else {
let t_before_normal = t_start.elapsed().as_micros();
let result = ChannelManager::subscribe(
&self.connection_manager,
&socket_id.to_string(),
&temp_message,
&request.channel,
is_authenticated,
&app_config.id,
)
.await?;
let t_after_normal = t_start.elapsed().as_micros();
tracing::debug!(
"PERF[NORMAL_PATH] socket_id={} channel={} reason=presence total={}μs",
socket_id,
request.channel,
t_after_normal - t_before_normal
);
result
}
} else {
let t_before_redis = t_start.elapsed().as_micros();
let result = ChannelManager::subscribe(
&self.connection_manager,
&socket_id.to_string(),
&temp_message,
&request.channel,
is_authenticated,
&app_config.id,
)
.await?;
let t_after_redis = t_start.elapsed().as_micros();
tracing::debug!(
"PERF[REDIS_PATH] socket_id={} channel={} total={}μs",
socket_id,
request.channel,
t_after_redis - t_before_redis
);
result
};
let t_after_channel_mgr = t_start.elapsed().as_micros();
let t_before_metrics = t_start.elapsed().as_micros();
if subscription_result.success
&& let Some(ref metrics) = self.metrics
{
let channel_type = ChannelType::from_name(&request.channel);
let channel_type_str = channel_type.as_str();
{
metrics.mark_channel_subscription(&app_config.id, channel_type_str);
}
if subscription_result.channel_connections == Some(1) {
super::sync_active_channel_count(
&self.connection_manager,
metrics,
&app_config.id,
channel_type_str,
)
.await;
}
}
let t_after_metrics = t_start.elapsed().as_micros();
if subscription_result.success
&& ChannelType::from_name(&request.channel).requires_authentication()
&& let Err(e) = self
.clear_user_authentication_timeout(&app_config.id, socket_id)
.await
{
tracing::warn!(
"Failed to clear user auth timeout for socket {}: {}",
socket_id,
e
);
}
let total = t_start.elapsed().as_micros();
tracing::debug!(
"PERF[EXECUTE_SUB] socket_id={} channel={} total={}μs msg_create={}μs channel_mgr={}μs metrics={}μs",
socket_id,
request.channel,
total,
t_after_msg_create - t_before_msg_create,
t_after_channel_mgr - t_after_msg_create,
t_after_metrics - t_before_metrics
);
Ok(SubscriptionResult {
success: subscription_result.success,
auth_error: subscription_result.auth_error,
member: subscription_result.member.map(|m| PresenceMember {
user_id: m.user_id.to_string(),
user_info: m.user_info,
}),
channel_connections: subscription_result.channel_connections,
})
}
pub async fn handle_post_subscription(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
subscription_result: &SubscriptionResult,
) -> Result<()> {
let channel_type = ChannelType::from_name(&request.channel);
if request.rewind.is_some() {
self.update_connection_subscription_state(
socket_id,
app_config,
request,
subscription_result,
)
.await?;
match channel_type {
ChannelType::Presence => {
self.handle_presence_subscription_success(
socket_id,
app_config,
request,
subscription_result,
)
.await?;
}
_ => {
self.send_subscription_succeeded(socket_id, app_config, &request.channel, None)
.await?;
}
}
self.rewind_subscription(socket_id, app_config, request)
.await?;
} else {
match channel_type {
ChannelType::Presence => {
self.handle_presence_subscription_success(
socket_id,
app_config,
request,
subscription_result,
)
.await?;
}
_ => {
self.send_subscription_succeeded(socket_id, app_config, &request.channel, None)
.await?;
}
}
self.update_connection_subscription_state(
socket_id,
app_config,
request,
subscription_result,
)
.await?;
}
#[cfg(feature = "delta")]
{
self.apply_subscription_delta_settings(socket_id, &request.channel, &request.delta)
.await;
tracing::debug!(
"About to call send_delta_cache_sync_if_needed for socket {} channel {}",
socket_id,
request.channel
);
self.send_delta_cache_sync_if_needed(socket_id, app_config, &request.channel)
.await?;
tracing::debug!(
"send_delta_cache_sync_if_needed completed successfully for socket {} channel {}",
socket_id,
request.channel
);
}
if subscription_result.channel_connections == Some(1)
&& let Some(webhook_integration) = self.webhook_integration.clone()
{
let app_config = app_config.clone();
let channel = request.channel.clone();
tokio::spawn(async move {
if let Err(e) = webhook_integration
.send_channel_occupied(&app_config, &channel)
.await
{
tracing::warn!(
"Failed to send channel_occupied webhook for {}: {}",
channel,
e
);
}
});
}
if !sockudo_core::utils::is_meta_channel(&request.channel) {
let current_count = self
.connection_manager
.get_channel_socket_count(&app_config.id, &request.channel)
.await;
if current_count == 1 {
self.broadcast_metachannel_event(
app_config,
&request.channel,
"channel_occupied",
sonic_rs::json!({
"channel": request.channel,
"subscription_count": current_count,
}),
)
.await
.ok();
}
self.broadcast_metachannel_event(
app_config,
&request.channel,
"subscription_count",
sonic_rs::json!({
"channel": request.channel,
"subscription_count": current_count,
}),
)
.await
.ok();
}
if !request.channel.starts_with("presence-")
&& !sockudo_core::utils::is_meta_channel(&request.channel)
&& let Some(webhook_integration) = &self.webhook_integration
{
let current_count = self
.connection_manager
.get_channel_socket_count(&app_config.id, &request.channel)
.await;
webhook_integration
.send_subscription_count_changed(app_config, &request.channel, current_count)
.await
.ok();
}
if is_cache_channel(&request.channel) {
self.send_missed_cache_if_exists(&app_config.id, socket_id, &request.channel)
.await?;
}
Ok(())
}
async fn rewind_subscription(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
) -> Result<()> {
let Some(rewind) = request.rewind.as_ref() else {
return Ok(());
};
let connection = self
.connection_manager
.get_connection(socket_id, &app_config.id)
.await
.ok_or(sockudo_core::error::Error::ConnectionNotFound)?;
let history_policy =
app_config.resolved_history(&request.channel, &self.server_options().history);
if !history_policy.rewind_allowed() {
return Err(sockudo_core::error::Error::Channel(format!(
"Channel rewind is disabled by policy for channel '{}'",
request.channel
)));
}
let max_page_size = history_policy.max_page_size;
let page = self
.history_store()
.read_page(build_rewind_history_read_request(
&app_config.id,
&request.channel,
rewind,
max_page_size,
))
.await?;
let items = normalize_rewind_items_for_delivery(rewind, page.items);
let history_head_serial = items
.last()
.map(|item| item.serial)
.or(page.retained.newest_serial);
let delivered_message_ids = items
.iter()
.filter_map(|item| item.message_id.clone())
.collect::<HashSet<_>>();
self.send_rewind_history_items(socket_id, app_config, &items)
.await?;
let buffered = connection.finish_rewind_gate(&request.channel).await;
let live_messages =
filter_buffered_rewind_messages(buffered, history_head_serial, &delivered_message_ids);
let live_count = live_messages.len();
for message in live_messages {
self.send_message_to_socket(&app_config.id, socket_id, message)
.await?;
}
let truncated_by_limit = match rewind {
SubscriptionRewind::Count(count) => *count > max_page_size,
SubscriptionRewind::Seconds(_) => page.has_more,
};
let rewind_complete = !page.truncated_by_retention && !truncated_by_limit;
self.send_rewind_complete(
socket_id,
app_config,
&request.channel,
items.len(),
live_count,
rewind_complete,
page.truncated_by_retention,
truncated_by_limit,
)
.await?;
Ok(())
}
async fn send_rewind_history_items(
&self,
socket_id: &SocketId,
app_config: &App,
items: &[HistoryItem],
) -> Result<()> {
for item in items {
let message: PusherMessage = sonic_rs::from_slice(item.payload_bytes.as_ref())
.map_err(|e| {
sockudo_core::error::Error::InvalidMessageFormat(format!(
"Invalid stored history payload: {e}"
))
})?;
self.send_message_to_socket(&app_config.id, socket_id, message)
.await?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn send_rewind_complete(
&self,
socket_id: &SocketId,
app_config: &App,
channel: &str,
historical_count: usize,
live_count: usize,
complete: bool,
truncated_by_retention: bool,
truncated_by_limit: bool,
) -> Result<()> {
let message = PusherMessage {
event: Some("sockudo:rewind_complete".to_string()),
channel: Some(channel.to_string()),
data: Some(MessageData::Json(sonic_rs::json!({
"historical_count": historical_count,
"live_count": live_count,
"complete": complete,
"truncated_by_retention": truncated_by_retention,
"truncated_by_limit": truncated_by_limit,
}))),
name: None,
user_id: None,
tags: None,
sequence: None,
conflation_key: None,
message_id: None,
stream_id: None,
serial: None,
idempotency_key: None,
extras: None,
delta_sequence: None,
delta_conflation_key: None,
};
self.send_message_to_socket(&app_config.id, socket_id, message)
.await
}
async fn update_connection_subscription_state(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
subscription_result: &SubscriptionResult,
) -> Result<()> {
let connection_manager = &self.connection_manager;
if let Some(conn_arc) = connection_manager
.get_connection(socket_id, &app_config.id)
.await
{
{
#[cfg(feature = "tag-filtering")]
let tag_filter = request.tags_filter.clone();
#[cfg(not(feature = "tag-filtering"))]
let tag_filter = None;
conn_arc
.subscribe_to_channel_with_filters(
request.channel.clone(),
tag_filter,
request.event_name_filter.clone(),
request.annotation_subscribe,
)
.await;
}
#[cfg(feature = "tag-filtering")]
if let Some(ref local_adapter) = self.local_adapter {
let filter_index = local_adapter.get_filter_index();
let filter_node = conn_arc.get_channel_filter_sync(&request.channel);
filter_index.add_socket_filter(
&request.channel,
*socket_id,
filter_node.as_deref(),
);
}
let mut conn_locked = conn_arc.inner.lock().await;
if let Some(ref member) = subscription_result.member {
conn_locked.state.user_id = Some(member.user_id.clone());
let presence_info = PresenceMemberInfo {
user_id: member.user_id.clone(),
user_info: Some(member.user_info.clone()),
};
conn_locked.add_presence_info(request.channel.clone(), presence_info);
drop(conn_locked);
self.connection_manager.add_user(conn_arc.clone()).await?;
} else {
drop(conn_locked);
}
}
self.send_annotation_summary_snapshots(socket_id, app_config, &request.channel)
.await?;
Ok(())
}
async fn handle_presence_subscription_success(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
subscription_result: &SubscriptionResult,
) -> Result<()> {
if let Some(ref presence_member) = subscription_result.member {
let presence_history_policy = app_config.resolved_presence_history(
&request.channel,
&self.server_options().presence_history,
);
self.presence_manager
.handle_member_added(
Arc::clone(&self.connection_manager),
Arc::clone(self.presence_history_store()),
presence_history_policy.enabled,
self.webhook_integration.as_ref(),
self.metrics.as_ref(),
app_config,
&request.channel,
&presence_member.user_id,
Some(&presence_member.user_info),
Some(socket_id),
Some(presence_history_policy.retention()),
)
.await?;
let mut members_map = self
.connection_manager
.get_local_channel_members(&app_config.id, &request.channel)
.await?;
if !members_map.contains_key(&presence_member.user_id) {
members_map.insert(
presence_member.user_id.clone(),
PresenceMemberInfo {
user_id: presence_member.user_id.clone(),
user_info: Some(presence_member.user_info.clone()),
},
);
}
let presence_data = PresenceData {
ids: members_map.keys().cloned().collect::<Vec<String>>(),
hash: members_map
.iter()
.map(|(k, v)| (k.clone(), v.user_info.clone()))
.collect::<AHashMap<String, Option<Value>>>(),
count: members_map.len(),
};
self.send_subscription_succeeded(
socket_id,
app_config,
&request.channel,
Some(presence_data),
)
.await?;
}
Ok(())
}
async fn send_subscription_succeeded(
&self,
socket_id: &SocketId,
app_config: &App,
channel: &str,
data: Option<PresenceData>,
) -> Result<()> {
let response_msg = PusherMessage::subscription_succeeded(channel.to_string(), data);
if let Some(ref local_adapter) = self.local_adapter {
let adapter_ref: &dyn crate::connection_manager::ConnectionManager =
local_adapter.as_ref();
return adapter_ref
.send_message(&app_config.id, socket_id, response_msg)
.await;
}
self.connection_manager
.send_message(&app_config.id, socket_id, response_msg)
.await
}
async fn send_annotation_summary_snapshots(
&self,
socket_id: &SocketId,
app_config: &App,
channel: &str,
) -> Result<()> {
let Some(connection) = self
.connection_manager
.get_connection(socket_id, &app_config.id)
.await
else {
return Ok(());
};
if connection.protocol_version != ProtocolVersion::V2 {
return Ok(());
}
let rebuild_started = Instant::now();
let (projections, rebuild_count) = self
.annotation_store()
.list_projections_for_channel_with_rebuild_count(
AnnotationProjectionsForChannelRequest {
app_id: app_config.id.clone(),
channel_id: channel.to_string(),
},
)
.await?;
if rebuild_count > 0
&& let Some(metrics) = self.metrics()
{
for _ in 0..rebuild_count {
metrics.mark_annotation_projection_rebuild(channel);
}
metrics.track_annotation_projection_rebuild_duration(
channel,
rebuild_started.elapsed().as_secs_f64(),
);
}
for projection in projections {
if let Some(contributor_count) = clipped_contributor_count(&projection.summary) {
tracing::warn!(
channel = %channel,
message_serial = %projection.message_serial.as_str(),
annotation_type = %projection.annotation_type.as_str(),
contributor_count,
"annotation summary clipped"
);
if let Some(metrics) = self.metrics() {
metrics.mark_annotation_summary_clipped(
channel,
projection.annotation_type.as_str(),
);
}
}
let mut summary_by_type = BTreeMap::new();
summary_by_type.insert(
projection.annotation_type.as_str().to_string(),
sonic_rs::to_value(&projection.summary).map_err(|err| {
sockudo_core::error::Error::InvalidMessageFormat(format!(
"Failed to encode annotation summary: {err}"
))
})?,
);
let message = PusherMessage {
event: Some(MESSAGE_SUMMARY_EVENT_NAME.to_string()),
channel: Some(channel.to_string()),
data: Some(MessageData::Json(
sonic_rs::to_value(&MessageSummaryData {
action: "message.summary".to_string(),
serial: projection.message_serial.as_str().to_string(),
annotations: AnnotationSummaryEnvelope {
summary: summary_by_type,
},
})
.map_err(|err| {
sockudo_core::error::Error::InvalidMessageFormat(format!(
"Failed to encode annotation summary message: {err}"
))
})?,
)),
name: None,
user_id: None,
tags: None,
sequence: None,
conflation_key: None,
message_id: None,
stream_id: None,
serial: None,
idempotency_key: None,
extras: Some(MessageExtras {
ephemeral: Some(true),
..Default::default()
}),
delta_sequence: None,
delta_conflation_key: None,
};
connection.send_message(&message).await?;
if let Some(metrics) = self.metrics() {
metrics.mark_annotation_summary_delivery(channel);
}
}
Ok(())
}
#[cfg(feature = "delta")]
async fn apply_subscription_delta_settings(
&self,
socket_id: &SocketId,
channel: &str,
delta_settings: &Option<SubscriptionDeltaSettings>,
) {
let Some(settings) = delta_settings else {
return;
};
if DeltaCompressionManager::is_encrypted_channel(channel) {
tracing::debug!(
"Ignoring per-subscription delta settings for encrypted channel '{}' - delta compression has no benefit",
channel
);
return;
}
self.delta_compression.set_channel_delta_settings(
socket_id,
channel,
settings.should_enable(),
settings.preferred_algorithm(),
);
tracing::info!(
"Applied per-subscription delta settings for socket {} channel {}: enabled={:?}, algorithm={:?}",
socket_id,
channel,
settings.should_enable(),
settings.preferred_algorithm()
);
if settings.should_enable() == Some(true)
&& !self.delta_compression.is_enabled_for_socket(socket_id)
{
let algorithm = settings
.preferred_algorithm()
.unwrap_or(self.delta_compression.get_algorithm());
let algorithm_str = match algorithm {
sockudo_delta::DeltaAlgorithm::Fossil => "fossil",
sockudo_delta::DeltaAlgorithm::Xdelta3 => "xdelta3",
};
let confirmation = PusherMessage {
event: Some("pusher:delta_compression_enabled".to_string()),
channel: Some(channel.to_string()),
data: Some(sockudo_protocol::messages::MessageData::Json(
sonic_rs::json!({
"enabled": true,
"algorithm": algorithm_str,
"channel": channel
}),
)),
name: None,
user_id: None,
tags: None,
sequence: None,
conflation_key: None,
message_id: None,
stream_id: None,
serial: None,
idempotency_key: None,
extras: None,
delta_sequence: None,
delta_conflation_key: None,
};
if let Ok(apps) = self.app_manager.get_apps().await {
for app in apps {
if self
.connection_manager
.get_connection(socket_id, &app.id)
.await
.is_some()
{
if let Err(e) = self
.send_message_to_socket(&app.id, socket_id, confirmation.clone())
.await
{
tracing::warn!(
"Failed to send per-channel delta confirmation for socket {} channel {}: {}",
socket_id,
channel,
e
);
}
break;
}
}
}
}
}
#[cfg(feature = "delta")]
async fn send_delta_cache_sync_if_needed(
&self,
socket_id: &SocketId,
app_config: &App,
channel: &str,
) -> Result<()> {
tracing::debug!(
"send_delta_cache_sync_if_needed: START for socket {} channel {}",
socket_id,
channel
);
let is_enabled = self.delta_compression.is_enabled_for_socket(socket_id);
tracing::debug!(
"send_delta_cache_sync_if_needed: delta compression enabled for socket {}: {}",
socket_id,
is_enabled
);
if !is_enabled {
tracing::debug!("send_delta_cache_sync_if_needed: returning early - not enabled");
return Ok(());
}
let channel_settings = app_config
.channel_delta_compression_ref()
.and_then(|map| map.get(channel))
.and_then(|config| {
use sockudo_delta::ChannelDeltaConfig;
match config {
ChannelDeltaConfig::Full(settings) => Some(settings.clone()),
_ => None,
}
});
let conflation_key = channel_settings
.as_ref()
.and_then(|s| s.conflation_key.as_deref());
tracing::debug!(
"send_delta_cache_sync_if_needed: conflation_key = {:?}",
conflation_key
);
if conflation_key.is_none() {
tracing::debug!("send_delta_cache_sync_if_needed: returning early - no conflation key");
return Ok(());
}
tracing::debug!(
"send_delta_cache_sync_if_needed: skipping cache sync (causes sequence mismatch on resubscribe)"
);
Ok(())
}
}
/// Selects which live messages buffered during a rewind still need delivery.
///
/// Messages are first ordered by serial. The sort is stable, and messages
/// without a serial sort last. A message is kept when both hold:
/// - it is newer than `history_head_serial`, or either serial is unknown;
/// - its `message_id`, if it has one, is not in `delivered_message_ids`.
fn filter_buffered_rewind_messages(
    mut buffered: Vec<sockudo_core::websocket::BufferedRewindMessage>,
    history_head_serial: Option<u64>,
    delivered_message_ids: &HashSet<String>,
) -> Vec<PusherMessage> {
    buffered.sort_by_key(|entry| entry.serial.unwrap_or(u64::MAX));
    let mut pending = Vec::new();
    for entry in buffered {
        let newer_than_history = match (history_head_serial, entry.serial) {
            (Some(head), Some(serial)) => serial > head,
            _ => true,
        };
        let already_delivered = entry
            .message_id
            .as_ref()
            .is_some_and(|id| delivered_message_ids.contains(id));
        if newer_than_history && !already_delivered {
            pending.push(entry.message);
        }
    }
    pending
}
/// Builds the history query for a rewind on `channel`.
///
/// - Count rewinds read newest-first; `normalize_rewind_items_for_delivery`
///   reverses that page for delivery.
/// - Time-window rewinds read oldest-first.
///
/// The limit is the rewind's own limit capped at `max_page_size`. The time
/// bounds come from `rewind.to_history_bounds` evaluated at the current time.
fn build_rewind_history_read_request(
    app_id: &str,
    channel: &str,
    rewind: &SubscriptionRewind,
    max_page_size: usize,
) -> HistoryReadRequest {
    let direction = match rewind {
        SubscriptionRewind::Count(_) => HistoryDirection::NewestFirst,
        SubscriptionRewind::Seconds(_) => HistoryDirection::OldestFirst,
    };
    let limit = rewind.limit().min(max_page_size);
    let bounds = rewind.to_history_bounds(now_ms());
    HistoryReadRequest {
        app_id: app_id.to_owned(),
        channel: channel.to_owned(),
        direction,
        limit,
        cursor: None,
        bounds,
    }
}
/// Puts a rewind history page into delivery order (oldest first).
///
/// Count rewinds are read newest-first, so their page is reversed. Other
/// rewinds are returned unchanged.
fn normalize_rewind_items_for_delivery(
    rewind: &SubscriptionRewind,
    items: Vec<HistoryItem>,
) -> Vec<HistoryItem> {
    let mut ordered = items;
    if let SubscriptionRewind::Count(_) = rewind {
        ordered.reverse();
    }
    ordered
}
#[cfg(test)]
mod rewind_tests {
    use super::*;
    use bytes::Bytes;

    /// Builds a message with only `event` set, so deliveries can be identified.
    fn test_message(event: &str) -> PusherMessage {
        PusherMessage {
            event: Some(event.to_string()),
            channel: None,
            data: None,
            name: None,
            user_id: None,
            tags: None,
            sequence: None,
            conflation_key: None,
            message_id: None,
            stream_id: None,
            serial: None,
            idempotency_key: None,
            extras: None,
            delta_sequence: None,
            delta_conflation_key: None,
        }
    }

    /// Builds a buffered live message carrying a serial and a message id.
    fn buffered_message(
        serial: u64,
        message_id: &str,
        event: &str,
    ) -> sockudo_core::websocket::BufferedRewindMessage {
        sockudo_core::websocket::BufferedRewindMessage {
            serial: Some(serial),
            message_id: Some(message_id.to_string()),
            message: test_message(event),
        }
    }

    #[test]
    fn buffered_rewind_messages_drop_duplicates_and_preserve_gap_free_order() {
        // History already delivered msg-3 (serial 3), so only 4 and 5 remain.
        let delivered: HashSet<String> = ["msg-3".to_string()].into_iter().collect();
        let pending = vec![
            buffered_message(3, "msg-3", "three"),
            buffered_message(4, "msg-4", "four"),
            buffered_message(5, "msg-5", "five"),
        ];
        let filtered = filter_buffered_rewind_messages(pending, Some(3), &delivered);
        let events: Vec<Option<&str>> = filtered.iter().map(|m| m.event.as_deref()).collect();
        assert_eq!(events, vec![Some("four"), Some("five")]);
    }

    #[test]
    fn rewind_count_builds_newest_first_request() {
        let request =
            build_rewind_history_read_request("app", "chat", &SubscriptionRewind::Count(10), 100);
        assert_eq!(request.direction, HistoryDirection::NewestFirst);
        assert_eq!(request.limit, 10);
        assert_eq!(request.bounds.start_time_ms, None);
    }

    #[test]
    fn rewind_duration_builds_time_bounded_request() {
        let request =
            build_rewind_history_read_request("app", "chat", &SubscriptionRewind::Seconds(30), 100);
        assert_eq!(request.direction, HistoryDirection::OldestFirst);
        assert_eq!(request.limit, 100);
        assert!(request.bounds.start_time_ms.is_some());
    }

    #[test]
    fn rewind_count_items_are_reordered_oldest_to_newest_for_delivery() {
        // Newest-first input, as a Count rewind reads it from the store.
        let newest_first = vec![
            HistoryItem {
                stream_id: "stream".to_string(),
                serial: 5,
                published_at_ms: 5,
                message_id: None,
                event_name: None,
                operation_kind: "append".to_string(),
                payload_size_bytes: 0,
                payload_bytes: Bytes::new(),
            },
            HistoryItem {
                stream_id: "stream".to_string(),
                serial: 4,
                published_at_ms: 4,
                message_id: None,
                event_name: None,
                operation_kind: "append".to_string(),
                payload_size_bytes: 0,
                payload_bytes: Bytes::new(),
            },
        ];
        let reordered =
            normalize_rewind_items_for_delivery(&SubscriptionRewind::Count(2), newest_first);
        let serials: Vec<_> = reordered.iter().map(|item| item.serial).collect();
        assert_eq!(serials, vec![4, 5]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reproduces the member-merge step used when acknowledging a presence
    /// subscription. It adds `member` only if its user id is absent, then
    /// builds the `PresenceData` payload.
    fn merged_presence_data(
        mut members_map: AHashMap<String, PresenceMemberInfo>,
        member: &PresenceMember,
    ) -> PresenceData {
        members_map
            .entry(member.user_id.clone())
            .or_insert_with(|| PresenceMemberInfo {
                user_id: member.user_id.clone(),
                user_info: Some(member.user_info.clone()),
            });
        PresenceData {
            ids: members_map.keys().cloned().collect::<Vec<String>>(),
            hash: members_map
                .iter()
                .map(|(id, info)| (id.clone(), info.user_info.clone()))
                .collect::<AHashMap<String, Option<Value>>>(),
            count: members_map.len(),
        }
    }

    fn alice(name: &str) -> PresenceMember {
        PresenceMember {
            user_id: "user-123".to_string(),
            user_info: sonic_rs::json!({ "name": name }),
        }
    }

    #[test]
    fn test_presence_subscription_includes_self() {
        let presence_data = merged_presence_data(AHashMap::new(), &alice("Alice"));
        assert_eq!(presence_data.count, 1);
        assert!(presence_data.ids.contains(&"user-123".to_string()));
        assert!(presence_data.hash.contains_key("user-123"));
    }

    #[test]
    fn test_presence_subscription_no_double_counting() {
        let mut existing: AHashMap<String, PresenceMemberInfo> = AHashMap::new();
        existing.insert(
            "user-123".to_string(),
            PresenceMemberInfo {
                user_id: "user-123".to_string(),
                user_info: Some(sonic_rs::json!({"name": "Alice"})),
            },
        );
        let presence_data = merged_presence_data(existing, &alice("Alice Updated"));
        assert_eq!(presence_data.count, 1);
        assert_eq!(presence_data.ids.len(), 1);
    }
}