use ahash::AHashMap;
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sonic_rs::prelude::*;
use sonic_rs::{Value, json};
use std::collections::{BTreeMap, HashMap};
use std::time::Duration;
use crate::protocol_version::ProtocolVersion;
/// Scalar value stored under a key of `MessageExtras::headers`.
///
/// The enum is `untagged`, so a value is written as a bare JSON string,
/// number, or boolean with no variant wrapper around it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ExtrasValue {
/// A string value.
String(String),
/// A numeric value, held as `f64`.
Number(f64),
/// A boolean value.
Bool(bool),
}
/// Optional per-message metadata carried in the `extras` field of a message.
///
/// Field names are camelCase on the wire (e.g. `idempotencyKey`), and any
/// field that is `None` is omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct MessageExtras {
/// Map of header name to scalar value; exposed through
/// `PusherMessage::filter_headers`.
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, ExtrasValue>>,
/// Ephemeral flag; `PusherMessage::is_ephemeral` treats `None` as `false`.
#[serde(skip_serializing_if = "Option::is_none")]
pub ephemeral: Option<bool>,
/// Idempotency key supplied inside `extras`; exposed through
/// `PusherMessage::extras_idempotency_key`.
#[serde(skip_serializing_if = "Option::is_none")]
pub idempotency_key: Option<String>,
/// Free-form JSON payload under `push`.
// NOTE(review): presumably push-notification options — not interpreted in this file; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub push: Option<Value>,
/// Per-message echo override; `PusherMessage::should_echo` falls back to the
/// connection default when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub echo: Option<bool>,
}
impl MessageExtras {
pub fn validate_headers_from_json(raw: &Value) -> Result<(), String> {
if let Some(extras) = raw.get("extras")
&& let Some(headers) = extras.get("headers")
&& let Some(obj) = headers.as_object()
{
for (key, val) in obj.iter() {
if val.is_object() || val.is_array() {
return Err(format!(
"extras.headers must be a flat object — nested objects and arrays are not allowed (key: '{key}')"
));
}
}
}
Ok(())
}
}
/// Returns a freshly generated version-4 UUID rendered as a `String`.
pub fn generate_message_id() -> String {
    let id = uuid::Uuid::new_v4();
    id.to_string()
}
/// Event name `sockudo_internal:annotation`.
// NOTE(review): presumably the event under which `AnnotationEventData` is delivered — confirm at use sites.
pub const ANNOTATION_EVENT_NAME: &str = "sockudo_internal:annotation";
/// Event name `sockudo_internal:message`.
// NOTE(review): presumably the event under which `MessageSummaryData` is delivered — confirm at use sites.
pub const MESSAGE_SUMMARY_EVENT_NAME: &str = "sockudo_internal:message";
/// Mode string `ANNOTATION_SUBSCRIBE`.
// NOTE(review): looks like a channel/subscription mode flag — not used in this file; confirm.
pub const ANNOTATION_SUBSCRIBE_MODE: &str = "ANNOTATION_SUBSCRIBE";
/// Action carried by an annotation event, serialized as a dotted string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AnnotationEventAction {
/// Serialized as `"annotation.create"`.
#[serde(rename = "annotation.create")]
Create,
/// Serialized as `"annotation.delete"`.
#[serde(rename = "annotation.delete")]
Delete,
}
/// Payload describing a single annotation create/delete event.
///
/// Field names are camelCase on the wire (`messageSerial`, `clientId`), the
/// `annotation_type` field is written as `type`, and `None` fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct AnnotationEventData {
/// Whether the annotation was created or deleted.
pub action: AnnotationEventAction,
/// Optional annotation identifier.
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Serial of the annotation itself.
pub serial: String,
/// Serial of the message the annotation refers to (`messageSerial` on the wire).
pub message_serial: String,
/// Annotation type string; serialized under the key `type`.
#[serde(rename = "type")]
pub annotation_type: String,
/// Optional annotation name.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Optional client identifier (`clientId` on the wire).
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
/// Optional count.
#[serde(skip_serializing_if = "Option::is_none")]
pub count: Option<u64>,
/// Optional free-form JSON payload.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
/// Optional encoding label for `data`.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding: Option<String>,
/// Event timestamp.
// NOTE(review): the unit test uses 1_700_000_000_000, which suggests Unix milliseconds — confirm.
pub timestamp: i64,
}
/// Wrapper object holding annotation summaries under a `summary` key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AnnotationSummaryEnvelope {
/// Summary JSON keyed by string; `BTreeMap` keeps keys in sorted order.
// NOTE(review): the unit test keys this by annotation type (e.g. "reactions:distinct.v1") — confirm that is the contract.
pub summary: BTreeMap<String, Value>,
}
/// Payload describing the annotation summary of one message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct MessageSummaryData {
/// Action string (the unit test uses `"message.summary"`).
pub action: String,
/// Serial of the summarized message.
pub serial: String,
/// Summary envelope, serialized as `{"summary": {...}}`.
pub annotations: AnnotationSummaryEnvelope,
}
/// Presence snapshot embedded by `PusherMessage::subscription_succeeded`
/// under the `presence` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresenceData {
/// List of identifiers.
// NOTE(review): presumably the user ids of the present members — confirm against the caller.
pub ids: Vec<String>,
/// Map from string key to optional JSON value.
// NOTE(review): presumably user id -> user info — confirm against the caller.
pub hash: AHashMap<String, Option<Value>>,
/// Member count as reported by the caller; not derived from `ids` here.
pub count: usize,
}
/// Payload of a `PusherMessage`.
///
/// `Serialize` is derived and `untagged`, so each variant is written without a
/// variant wrapper. `Deserialize` is implemented by hand further down in this
/// file rather than derived.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageData {
/// A plain string payload (typically a JSON document encoded as a string by
/// the constructors in this file).
String(String),
/// An object payload with well-known string fields plus arbitrary extras.
Structured {
/// Value of the `channel_data` key, when present.
#[serde(skip_serializing_if = "Option::is_none")]
channel_data: Option<String>,
/// Value of the `channel` key, when present.
#[serde(skip_serializing_if = "Option::is_none")]
channel: Option<String>,
/// Value of the `user_data` key, when present.
#[serde(skip_serializing_if = "Option::is_none")]
user_data: Option<String>,
/// Every other key of the object; flattened into the same JSON object on
/// serialization.
#[serde(flatten)]
extra: AHashMap<String, Value>,
},
/// Any other JSON value, kept as-is.
Json(Value),
}
impl<'de> Deserialize<'de> for MessageData {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let v = JsonValue::deserialize(deserializer)?;
if let Some(s) = v.as_str() {
return Ok(MessageData::String(s.to_string()));
}
if let Some(obj) = v.as_object() {
let channel_data = obj
.get("channel_data")
.and_then(|x| x.as_str())
.map(ToString::to_string);
let channel = obj
.get("channel")
.and_then(|x| x.as_str())
.map(ToString::to_string);
let user_data = obj
.get("user_data")
.and_then(|x| x.as_str())
.map(ToString::to_string);
if channel_data.is_some() || channel.is_some() || user_data.is_some() {
let mut extra = AHashMap::new();
for (k, val) in obj.iter() {
if k != "channel_data" && k != "channel" && k != "user_data" {
extra.insert(
k.to_string(),
serde_json_value_to_sonic(val.clone()).map_err(D::Error::custom)?,
);
}
}
return Ok(MessageData::Structured {
channel_data,
channel,
user_data,
extra,
});
}
}
Ok(MessageData::Json(
serde_json_value_to_sonic(v).map_err(D::Error::custom)?,
))
}
}
/// Converts a `serde_json::Value` into a `sonic_rs::Value` by encoding it to a
/// JSON string and parsing that string again.
///
/// # Errors
///
/// Returns a descriptive message if either the encode or the decode step fails.
fn serde_json_value_to_sonic(value: JsonValue) -> Result<Value, String> {
    let text = match serde_json::to_string(&value) {
        Ok(text) => text,
        Err(err) => {
            return Err(format!(
                "failed to encode json value for MessageData: {err}"
            ));
        }
    };

    match sonic_rs::from_str::<Value>(&text) {
        Ok(converted) => Ok(converted),
        Err(err) => Err(format!(
            "failed to decode json value for MessageData: {err}"
        )),
    }
}
/// Error payload consisting of an optional numeric code and a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorData {
/// Optional error code; serialized as `null` when absent (no skip attribute).
pub code: Option<u16>,
/// Human-readable error message.
pub message: String,
}
/// A protocol message as exchanged with clients.
///
/// Every field is optional and omitted from the serialized form when `None`.
/// Field names are snake_case on the wire except for the two delta fields,
/// which are renamed to `__delta_seq` and `__conflation_key`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PusherMessage {
/// Event name, e.g. `pusher:ping`.
#[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>,
/// Channel the message belongs to.
#[serde(skip_serializing_if = "Option::is_none")]
pub channel: Option<String>,
/// Message payload.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<MessageData>,
/// Optional `name` field; not set by any constructor in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Optional `user_id` field; not set by any constructor in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub user_id: Option<String>,
/// Optional string-to-string tags, kept in sorted key order.
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<BTreeMap<String, String>>,
/// Optional sequence number.
#[serde(skip_serializing_if = "Option::is_none")]
pub sequence: Option<u64>,
/// Optional conflation key.
#[serde(skip_serializing_if = "Option::is_none")]
pub conflation_key: Option<String>,
/// Optional message identifier.
#[serde(skip_serializing_if = "Option::is_none")]
pub message_id: Option<String>,
/// Optional stream identifier.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_id: Option<String>,
/// Optional serial number.
#[serde(skip_serializing_if = "Option::is_none")]
pub serial: Option<u64>,
/// Optional top-level idempotency key (distinct from `extras.idempotencyKey`).
#[serde(skip_serializing_if = "Option::is_none")]
pub idempotency_key: Option<String>,
/// Optional message extras.
#[serde(skip_serializing_if = "Option::is_none")]
pub extras: Option<MessageExtras>,
/// Delta sequence number; `__delta_seq` on the wire.
#[serde(rename = "__delta_seq", skip_serializing_if = "Option::is_none")]
pub delta_sequence: Option<u64>,
/// Delta conflation key; `__conflation_key` on the wire.
#[serde(rename = "__conflation_key", skip_serializing_if = "Option::is_none")]
pub delta_conflation_key: Option<String>,
}
/// A single message in the shape accepted/produced by the HTTP API.
///
/// Every field is optional and omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PusherApiMessage {
/// Event name.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Payload, either a string or arbitrary JSON.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<ApiMessageData>,
/// Single target channel.
#[serde(skip_serializing_if = "Option::is_none")]
pub channel: Option<String>,
/// List of target channels.
#[serde(skip_serializing_if = "Option::is_none")]
pub channels: Option<Vec<String>>,
/// Optional socket id.
// NOTE(review): presumably the connection to exclude from delivery — confirm against the handler.
#[serde(skip_serializing_if = "Option::is_none")]
pub socket_id: Option<String>,
/// Optional `info` string.
// NOTE(review): presumably a comma-separated attribute list as parsed by `InfoQueryParser` — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub info: Option<String>,
/// Optional string-to-string tags.
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<AHashMap<String, String>>,
/// Optional delta flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub delta: Option<bool>,
/// Optional idempotency key.
#[serde(skip_serializing_if = "Option::is_none")]
pub idempotency_key: Option<String>,
/// Optional message extras.
#[serde(skip_serializing_if = "Option::is_none")]
pub extras: Option<MessageExtras>,
}
/// A batch of API messages wrapped in a `batch` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchPusherApiMessage {
/// The messages in the batch, in request order.
pub batch: Vec<PusherApiMessage>,
}
/// Payload of a `PusherApiMessage`.
///
/// The enum is `untagged`: a JSON string decodes as `String`, anything else
/// falls through to `Json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ApiMessageData {
/// A plain string payload.
String(String),
/// Any other JSON value.
Json(Value),
}
/// Reduced message shape containing only channel, event, and data.
///
/// Every field is optional and omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentPusherMessage {
/// Channel the message belongs to.
#[serde(skip_serializing_if = "Option::is_none")]
pub channel: Option<String>,
/// Event name.
#[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>,
/// Message payload.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<MessageData>,
}
impl MessageData {
pub fn as_string(&self) -> Option<&str> {
match self {
MessageData::String(s) => Some(s),
_ => None,
}
}
pub fn into_string(self) -> Option<String> {
match self {
MessageData::String(s) => Some(s),
_ => None,
}
}
pub fn as_value(&self) -> Option<&Value> {
match self {
MessageData::Structured { extra, .. } => extra.values().next(),
_ => None,
}
}
}
impl From<String> for MessageData {
fn from(s: String) -> Self {
MessageData::String(s)
}
}
impl From<Value> for MessageData {
fn from(v: Value) -> Self {
MessageData::Json(v)
}
}
impl PusherMessage {
    /// Builds a message that carries only `event`; every other field is `None`.
    ///
    /// All public constructors start from this value and override the fields
    /// they need with struct-update syntax, so a new field on `PusherMessage`
    /// only has to be initialised here.
    fn with_event_only(event: impl Into<String>) -> Self {
        Self {
            event: Some(event.into()),
            channel: None,
            data: None,
            name: None,
            user_id: None,
            tags: None,
            sequence: None,
            conflation_key: None,
            message_id: None,
            stream_id: None,
            serial: None,
            idempotency_key: None,
            extras: None,
            delta_sequence: None,
            delta_conflation_key: None,
        }
    }

    /// Returns `true` when `event` is set and
    /// `ProtocolVersion::parse_any_protocol_event` reports it as `ping` or
    /// `pong`; returns `false` when `event` is `None`.
    pub fn is_protocol_ping_or_pong(&self) -> bool {
        let Some(event) = self.event.as_deref() else {
            return false;
        };
        matches!(
            ProtocolVersion::parse_any_protocol_event(event),
            Some(("ping", _)) | Some(("pong", _))
        )
    }

    /// Builds a `pusher:connection_established` message whose data is the
    /// JSON-encoded string `{"socket_id": ..., "activity_timeout": ...}`.
    pub fn connection_established(socket_id: String, activity_timeout: u64) -> Self {
        let payload = json!({
            "socket_id": socket_id,
            "activity_timeout": activity_timeout });
        Self {
            data: Some(MessageData::from(payload.to_string())),
            ..Self::with_event_only("pusher:connection_established")
        }
    }

    /// Builds a `pusher_internal:subscription_succeeded` message for `channel`.
    ///
    /// The data is a JSON-encoded string: `{"presence": {ids, hash, count}}`
    /// when `presence_data` is given, `{}` otherwise.
    pub fn subscription_succeeded(channel: String, presence_data: Option<PresenceData>) -> Self {
        let data_obj = match presence_data {
            Some(data) => json!({
                "presence": {
                    "ids": data.ids,
                    "hash": data.hash,
                    "count": data.count
                }
            }),
            None => json!({}),
        };
        Self {
            channel: Some(channel),
            data: Some(MessageData::String(data_obj.to_string())),
            ..Self::with_event_only("pusher_internal:subscription_succeeded")
        }
    }

    /// Builds a `pusher:error` message whose data is the JSON object
    /// `{"code": ..., "message": ...}` (not string-encoded), optionally scoped
    /// to `channel`.
    pub fn error(code: u16, message: String, channel: Option<String>) -> Self {
        Self {
            data: Some(MessageData::Json(json!({
                "code": code,
                "message": message
            }))),
            channel,
            ..Self::with_event_only("pusher:error")
        }
    }

    /// Builds a `pusher:ping` message with no data.
    pub fn ping() -> Self {
        Self::with_event_only("pusher:ping")
    }

    /// Builds a message for `event` on `channel` whose data is `data`
    /// rendered to a JSON string.
    pub fn channel_event<S: Into<String>>(event: S, channel: S, data: Value) -> Self {
        Self {
            channel: Some(channel.into()),
            data: Some(MessageData::String(data.to_string())),
            ..Self::with_event_only(event)
        }
    }

    /// Builds a `pusher_internal:member_added` message for `channel`.
    ///
    /// The data is the JSON-encoded string `{"user_id": ..., "user_info": ...}`;
    /// a missing `user_info` is sent as `{}`.
    pub fn member_added(channel: String, user_id: String, user_info: Option<Value>) -> Self {
        let payload = json!({
            "user_id": user_id,
            "user_info": user_info.unwrap_or_else(|| json!({}))
        });
        Self {
            channel: Some(channel),
            data: Some(MessageData::String(payload.to_string())),
            ..Self::with_event_only("pusher_internal:member_added")
        }
    }

    /// Builds a `pusher_internal:member_removed` message for `channel` whose
    /// data is the JSON-encoded string `{"user_id": ...}`.
    pub fn member_removed(channel: String, user_id: String) -> Self {
        let payload = json!({
            "user_id": user_id
        });
        Self {
            channel: Some(channel),
            data: Some(MessageData::String(payload.to_string())),
            ..Self::with_event_only("pusher_internal:member_removed")
        }
    }

    /// Builds a `pusher:pong` message with no data.
    pub fn pong() -> Self {
        Self::with_event_only("pusher:pong")
    }

    /// Builds the JSON body describing one channel.
    ///
    /// Always contains `occupied`; adds `subscription_count`, `user_count`, and
    /// `cache` (`{"data": ..., "ttl": <whole seconds>}`) only when the matching
    /// argument is `Some`.
    pub fn channel_info(
        occupied: bool,
        subscription_count: Option<u64>,
        user_count: Option<u64>,
        cache_data: Option<(String, Duration)>,
    ) -> Value {
        let mut response = json!({
            "occupied": occupied
        });
        if let Some(count) = subscription_count {
            response["subscription_count"] = json!(count);
        }
        if let Some(count) = user_count {
            response["user_count"] = json!(count);
        }
        if let Some((data, ttl)) = cache_data {
            response["cache"] = json!({
                "data": data,
                "ttl": ttl.as_secs()
            });
        }
        response
    }

    /// Wraps a channel-name -> info map as `{"channels": {...}}`.
    pub fn channels_list(channels_info: AHashMap<String, Value>) -> Value {
        json!({
            "channels": channels_info
        })
    }

    /// Builds `{"users": [{"id": ...}, ...]}` preserving the order of `user_ids`.
    pub fn user_list(user_ids: Vec<String>) -> Value {
        let users = user_ids
            .into_iter()
            .map(|id| json!({ "id": id }))
            .collect::<Vec<_>>();
        json!({ "users": users })
    }

    /// Wraps per-message results as `{"batch": [...]}`.
    pub fn batch_response(batch_info: Vec<Value>) -> Value {
        json!({ "batch": batch_info })
    }

    /// Returns the fixed JSON body `{"ok": true}`.
    pub fn success_response() -> Value {
        json!({ "ok": true })
    }

    /// Builds an `online` message whose data is the JSON object
    /// `{"user_ids": [...]}`.
    pub fn watchlist_online_event(user_ids: Vec<String>) -> Self {
        Self {
            data: Some(MessageData::Json(json!({
                "user_ids": user_ids
            }))),
            ..Self::with_event_only("online")
        }
    }

    /// Builds an `offline` message whose data is the JSON object
    /// `{"user_ids": [...]}`.
    pub fn watchlist_offline_event(user_ids: Vec<String>) -> Self {
        Self {
            data: Some(MessageData::Json(json!({
                "user_ids": user_ids
            }))),
            ..Self::with_event_only("offline")
        }
    }

    /// Builds a `pusher:cache_miss` message for `channel` with the string
    /// payload `{}`.
    pub fn cache_miss_event(channel: String) -> Self {
        Self {
            channel: Some(channel),
            data: Some(MessageData::String("{}".to_string())),
            ..Self::with_event_only("pusher:cache_miss")
        }
    }

    /// Builds a `pusher:signin_success` message whose data is the JSON object
    /// `{"user_data": ...}`.
    pub fn signin_success(user_data: String) -> Self {
        Self {
            data: Some(MessageData::Json(json!({
                "user_data": user_data
            }))),
            ..Self::with_event_only("pusher:signin_success")
        }
    }

    /// Builds a `pusher:delta` message for `channel`.
    ///
    /// The data is a JSON-encoded string with the keys `channel`, `event`,
    /// `delta`, `base_seq`, `target_seq`, and `algorithm`.
    pub fn delta_message(
        channel: String,
        event: String,
        delta_base64: String,
        base_sequence: u32,
        target_sequence: u32,
        algorithm: &str,
    ) -> Self {
        // The channel appears both on the envelope and inside the payload.
        let envelope_channel = channel.clone();
        let payload = json!({
            "channel": channel,
            "event": event,
            "delta": delta_base64,
            "base_seq": base_sequence,
            "target_seq": target_sequence,
            "algorithm": algorithm,
        });
        Self {
            channel: Some(envelope_channel),
            data: Some(MessageData::String(payload.to_string())),
            ..Self::with_event_only("pusher:delta")
        }
    }

    /// Replaces `event`, when set, with the result of
    /// `version.rewrite_event_prefix`; does nothing when `event` is `None`.
    pub fn rewrite_prefix(&mut self, version: ProtocolVersion) {
        if let Some(event) = &self.event {
            self.event = Some(version.rewrite_event_prefix(event));
        }
    }

    /// Returns `extras.ephemeral`, treating a missing `extras` or a missing
    /// flag as `false`.
    pub fn is_ephemeral(&self) -> bool {
        self.extras
            .as_ref()
            .and_then(|e| e.ephemeral)
            .unwrap_or(false)
    }

    /// Borrows `extras.idempotency_key`, if any.
    pub fn extras_idempotency_key(&self) -> Option<&str> {
        self.extras
            .as_ref()
            .and_then(|e| e.idempotency_key.as_deref())
    }

    /// Returns `extras.echo`, falling back to `connection_default` when the
    /// message does not specify it.
    pub fn should_echo(&self, connection_default: bool) -> bool {
        self.extras
            .as_ref()
            .and_then(|e| e.echo)
            .unwrap_or(connection_default)
    }

    /// Borrows `extras.headers`, if any.
    pub fn filter_headers(&self) -> Option<&HashMap<String, ExtrasValue>> {
        self.extras.as_ref().and_then(|e| e.headers.as_ref())
    }

    /// Returns `true` only for `ProtocolVersion::V2`.
    pub fn should_include_extras(protocol: &ProtocolVersion) -> bool {
        matches!(protocol, ProtocolVersion::V2)
    }

    /// Inserts `__delta_base_seq` into a string payload that parses as a JSON
    /// object, then re-encodes it.
    ///
    /// The message is returned unchanged when the data is missing, is not the
    /// `String` variant, does not parse as JSON, or is not a JSON object.
    pub fn add_base_sequence(mut self, base_sequence: u32) -> Self {
        if let Some(MessageData::String(ref data_str)) = self.data
            && let Ok(mut data_obj) = sonic_rs::from_str::<Value>(data_str)
            && let Some(obj) = data_obj.as_object_mut()
        {
            obj.insert("__delta_base_seq", json!(base_sequence));
            self.data = Some(MessageData::String(data_obj.to_string()));
        }
        self
    }

    /// Builds a `pusher:delta_compression_enabled` message whose data is the
    /// JSON object `{"enabled": true, "default_algorithm": ...}`.
    pub fn delta_compression_enabled(default_algorithm: &str) -> Self {
        Self {
            data: Some(MessageData::Json(json!({
                "enabled": true,
                "default_algorithm": default_algorithm,
            }))),
            ..Self::with_event_only("pusher:delta_compression_enabled")
        }
    }
}
/// Helpers for reading a comma-separated `info` query parameter.
pub trait InfoQueryParser {
    /// Returns the requested attribute names, in input order.
    fn parse_info(&self) -> Vec<&str>;
    /// Returns `true` when `user_count` was requested.
    fn wants_user_count(&self) -> bool;
    /// Returns `true` when `subscription_count` was requested.
    fn wants_subscription_count(&self) -> bool;
    /// Returns `true` when `cache` was requested.
    fn wants_cache(&self) -> bool;
}

impl InfoQueryParser for Option<&String> {
    /// Splits the parameter on `,`, trims surrounding whitespace from each
    /// attribute, and drops empty entries. `None` yields an empty list.
    ///
    /// Trimming means `"user_count, subscription_count"` is read the same as
    /// `"user_count,subscription_count"`.
    fn parse_info(&self) -> Vec<&str> {
        self.map(|raw| {
            raw.split(',')
                .map(str::trim)
                .filter(|attribute| !attribute.is_empty())
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
    }

    // The `wants_*` checks scan the raw string directly so that a simple
    // membership test does not allocate a `Vec`.

    fn wants_user_count(&self) -> bool {
        self.is_some_and(|raw| raw.split(',').any(|attribute| attribute.trim() == "user_count"))
    }

    fn wants_subscription_count(&self) -> bool {
        self.is_some_and(|raw| {
            raw.split(',')
                .any(|attribute| attribute.trim() == "subscription_count")
        })
    }

    fn wants_cache(&self) -> bool {
        self.is_some_and(|raw| raw.split(',').any(|attribute| attribute.trim() == "cache"))
    }
}
#[cfg(test)]
mod tests {
    use super::{
        AnnotationEventAction, AnnotationEventData, AnnotationSummaryEnvelope, MessageSummaryData,
        PusherMessage,
    };
    use crate::protocol_version::ProtocolVersion;
    use sonic_rs::JsonValueTrait;
    use std::collections::BTreeMap;

    /// Ping and pong must be recognised both as built and after their event
    /// prefix has been rewritten for `ProtocolVersion::V2`.
    #[test]
    fn protocol_heartbeat_detection_matches_both_prefix_families() {
        for mut heartbeat in [PusherMessage::ping(), PusherMessage::pong()] {
            assert!(heartbeat.is_protocol_ping_or_pong());
            heartbeat.rewrite_prefix(ProtocolVersion::V2);
            assert!(heartbeat.is_protocol_ping_or_pong());
        }
    }

    /// An ordinary channel event is not a heartbeat.
    #[test]
    fn protocol_heartbeat_detection_ignores_regular_messages() {
        let payload = sonic_rs::json!({"text": "hello"});
        let regular = PusherMessage::channel_event("chat.message", "room", payload);
        assert!(!regular.is_protocol_ping_or_pong());
    }

    /// Annotation events use the dotted action string, camelCase keys, and
    /// `type` for the annotation type.
    #[test]
    fn annotation_create_serializes_camel_case_contract() {
        let event = AnnotationEventData {
            action: AnnotationEventAction::Create,
            id: Some("ann-1".to_string()),
            serial: "ann:1".to_string(),
            message_serial: "msg:1".to_string(),
            annotation_type: "reactions:distinct.v1".to_string(),
            name: Some("thumbsup".to_string()),
            client_id: Some("user-123".to_string()),
            count: Some(1),
            data: Some(sonic_rs::json!({"raw": true})),
            encoding: None,
            timestamp: 1_700_000_000_000,
        };

        let encoded = sonic_rs::to_value(&event).unwrap();

        let expected = [
            ("action", "annotation.create"),
            ("messageSerial", "msg:1"),
            ("type", "reactions:distinct.v1"),
            ("clientId", "user-123"),
        ];
        for (key, want) in expected {
            assert_eq!(encoded[key].as_str(), Some(want));
        }
    }

    /// Message summaries nest their per-type data under `annotations.summary`.
    #[test]
    fn message_summary_serializes_summary_envelope() {
        let summary = BTreeMap::from([(
            "reactions:distinct.v1".to_string(),
            sonic_rs::json!({"thumbsup": {"total": 5, "clientIds": ["a"], "clipped": false}}),
        )]);
        let message = MessageSummaryData {
            action: "message.summary".to_string(),
            serial: "msg:1".to_string(),
            annotations: AnnotationSummaryEnvelope { summary },
        };

        let encoded = sonic_rs::to_value(&message).unwrap();

        assert_eq!(encoded["action"].as_str(), Some("message.summary"));
        assert_eq!(encoded["serial"].as_str(), Some("msg:1"));
        let total = &encoded["annotations"]["summary"]["reactions:distinct.v1"]["thumbsup"]["total"];
        assert_eq!(total.as_u64(), Some(5));
    }
}