Skip to main content

sockudo_protocol/
messages.rs

1use ahash::AHashMap;
2use serde::{Deserialize, Serialize};
3use sonic_rs::prelude::*;
4use sonic_rs::{Value, json};
5use std::collections::{BTreeMap, HashMap};
6use std::time::Duration;
7
8use crate::protocol_version::ProtocolVersion;
9
10/// Allowed value types for extras.headers.
11/// Flat only — no Object or Array variant so nesting is structurally impossible.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13#[serde(untagged)]
14pub enum ExtrasValue {
15    String(String),
16    Number(f64),
17    Bool(bool),
18}
19
20/// Structured metadata envelope for V2-specific message features.
21///
22/// Present on the wire for V2 connections only. V1 connections receive messages
23/// with extras stripped entirely. Pusher SDKs ignore unknown fields so the
24/// field is safe to carry through internal pipelines even when the publisher
25/// is V1.
26#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
27#[serde(rename_all = "camelCase")]
28pub struct MessageExtras {
29    /// Flat metadata for server-side event name filtering.
30    /// Must be a flat object — no nested objects, no arrays.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub headers: Option<HashMap<String, ExtrasValue>>,
33
34    /// If true: skip connection recovery buffer and webhook forwarding.
35    /// Deliver to currently connected V2 subscribers only.
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub ephemeral: Option<bool>,
38
39    /// Server-side deduplication key. If the same key arrives again within
40    /// the app's idempotency TTL window, the message is silently dropped.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub idempotency_key: Option<String>,
43
44    /// Per-message echo control. Overrides the connection-level echo setting
45    /// when explicitly set.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub echo: Option<bool>,
48}
49
50impl MessageExtras {
51    /// Validate that headers (if present) contain only flat scalar values.
52    /// This is structurally guaranteed by `ExtrasValue` having no Object/Array
53    /// variants, but this method provides an explicit check with a clear error
54    /// when validating raw JSON before deserialization.
55    pub fn validate_headers_from_json(raw: &Value) -> Result<(), String> {
56        if let Some(extras) = raw.get("extras")
57            && let Some(headers) = extras.get("headers")
58            && let Some(obj) = headers.as_object()
59        {
60            for (key, val) in obj.iter() {
61                if val.is_object() || val.is_array() {
62                    return Err(format!(
63                        "extras.headers must be a flat object — nested objects and arrays are not allowed (key: '{key}')"
64                    ));
65                }
66            }
67        }
68        Ok(())
69    }
70}
71
72/// Generate a unique message ID (UUIDv4) for client-side deduplication.
73pub fn generate_message_id() -> String {
74    uuid::Uuid::new_v4().to_string()
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct PresenceData {
79    pub ids: Vec<String>,
80    pub hash: AHashMap<String, Option<Value>>,
81    pub count: usize,
82}
83
84#[derive(Debug, Clone, Serialize)]
85#[serde(untagged)]
86pub enum MessageData {
87    String(String),
88    Structured {
89        #[serde(skip_serializing_if = "Option::is_none")]
90        channel_data: Option<String>,
91        #[serde(skip_serializing_if = "Option::is_none")]
92        channel: Option<String>,
93        #[serde(skip_serializing_if = "Option::is_none")]
94        user_data: Option<String>,
95        #[serde(flatten)]
96        extra: AHashMap<String, Value>,
97    },
98    Json(Value),
99}
100
101impl<'de> Deserialize<'de> for MessageData {
102    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103    where
104        D: serde::Deserializer<'de>,
105    {
106        let v = Value::deserialize(deserializer)?;
107        if let Some(s) = v.as_str() {
108            return Ok(MessageData::String(s.to_string()));
109        }
110        if let Some(obj) = v.as_object() {
111            // Flatten workaround for sonic-rs issue #114:
112            // manually split known structured keys and keep remaining keys in `extra`.
113            let channel_data = obj
114                .get(&"channel_data")
115                .and_then(|x| x.as_str())
116                .map(ToString::to_string);
117            let channel = obj
118                .get(&"channel")
119                .and_then(|x| x.as_str())
120                .map(ToString::to_string);
121            let user_data = obj
122                .get(&"user_data")
123                .and_then(|x| x.as_str())
124                .map(ToString::to_string);
125
126            if channel_data.is_some() || channel.is_some() || user_data.is_some() {
127                let mut extra = AHashMap::new();
128                for (k, val) in obj.iter() {
129                    if k != "channel_data" && k != "channel" && k != "user_data" {
130                        extra.insert(k.to_string(), val.clone());
131                    }
132                }
133                return Ok(MessageData::Structured {
134                    channel_data,
135                    channel,
136                    user_data,
137                    extra,
138                });
139            }
140        }
141        Ok(MessageData::Json(v))
142    }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ErrorData {
147    pub code: Option<u16>,
148    pub message: String,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct PusherMessage {
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub event: Option<String>,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub channel: Option<String>,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub data: Option<MessageData>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub name: Option<String>,
161    #[serde(skip_serializing_if = "Option::is_none")]
162    pub user_id: Option<String>,
163    /// Tags for filtering - uses BTreeMap for deterministic serialization order
164    /// which is required for delta compression to work correctly
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub tags: Option<BTreeMap<String, String>>,
167    /// Delta compression sequence number for full messages
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub sequence: Option<u64>,
170    /// Delta compression conflation key for message grouping
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub conflation_key: Option<String>,
173    /// Unique message ID for client-side deduplication
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub message_id: Option<String>,
176    /// Opaque per-channel continuity token for durable history and recovery.
177    /// Changes only when the server can no longer prove continuity for the channel stream.
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub stream_id: Option<String>,
180    /// Monotonically increasing serial for connection recovery.
181    /// Assigned per-channel at broadcast time when connection recovery is enabled.
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub serial: Option<u64>,
184    /// Idempotency key for cross-region deduplication.
185    /// Threaded from the HTTP publish request through the broadcast pipeline
186    /// so that receiving nodes can register it in their local cache.
187    /// Never sent to WebSocket clients.
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub idempotency_key: Option<String>,
190    /// V2 message extras envelope. Carries ephemeral flag, per-message echo
191    /// control, header-based filtering metadata, and extras-level idempotency.
192    /// Stripped from V1 deliveries; included in V2 wire format.
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub extras: Option<MessageExtras>,
195    /// Delta sequence marker for full messages in V2 delta streams.
196    #[serde(rename = "__delta_seq", skip_serializing_if = "Option::is_none")]
197    pub delta_sequence: Option<u64>,
198    /// Delta conflation key marker for full messages in V2 delta streams.
199    #[serde(rename = "__conflation_key", skip_serializing_if = "Option::is_none")]
200    pub delta_conflation_key: Option<String>,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct PusherApiMessage {
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub name: Option<String>,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub data: Option<ApiMessageData>,
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub channel: Option<String>,
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub channels: Option<Vec<String>>,
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub socket_id: Option<String>,
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub info: Option<String>,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub tags: Option<AHashMap<String, String>>,
219    /// Per-publish delta compression control.
220    /// - `Some(true)`: Force delta compression for this message (if client supports it)
221    /// - `Some(false)`: Force full message (skip delta compression)
222    /// - `None`: Use default behavior based on channel/global configuration
223    #[serde(skip_serializing_if = "Option::is_none")]
224    pub delta: Option<bool>,
225    /// Idempotency key for deduplicating publish requests.
226    /// If the same key is seen within the TTL window, the server returns the
227    /// cached response without re-broadcasting.
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub idempotency_key: Option<String>,
230    /// V2 extras envelope. Passed through to PusherMessage for V2 delivery.
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub extras: Option<MessageExtras>,
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct BatchPusherApiMessage {
237    pub batch: Vec<PusherApiMessage>,
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
241#[serde(untagged)]
242pub enum ApiMessageData {
243    String(String),
244    Json(Value),
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct SentPusherMessage {
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub channel: Option<String>,
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub event: Option<String>,
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub data: Option<MessageData>,
255}
256
257// Helper implementations
258impl MessageData {
259    pub fn as_string(&self) -> Option<&str> {
260        match self {
261            MessageData::String(s) => Some(s),
262            _ => None,
263        }
264    }
265
266    pub fn into_string(self) -> Option<String> {
267        match self {
268            MessageData::String(s) => Some(s),
269            _ => None,
270        }
271    }
272
273    pub fn as_value(&self) -> Option<&Value> {
274        match self {
275            MessageData::Structured { extra, .. } => extra.values().next(),
276            _ => None,
277        }
278    }
279}
280
281impl From<String> for MessageData {
282    fn from(s: String) -> Self {
283        MessageData::String(s)
284    }
285}
286
287impl From<Value> for MessageData {
288    fn from(v: Value) -> Self {
289        MessageData::Json(v)
290    }
291}
292
293impl PusherMessage {
294    pub fn connection_established(socket_id: String, activity_timeout: u64) -> Self {
295        Self {
296            event: Some("pusher:connection_established".to_string()),
297            data: Some(MessageData::from(
298                json!({
299                    "socket_id": socket_id,
300                    "activity_timeout": activity_timeout  // Now configurable
301                })
302                .to_string(),
303            )),
304            channel: None,
305            name: None,
306            user_id: None,
307            sequence: None,
308            conflation_key: None,
309            tags: None,
310            message_id: None,
311            stream_id: None,
312            serial: None,
313            idempotency_key: None,
314            extras: None,
315            delta_sequence: None,
316            delta_conflation_key: None,
317        }
318    }
319    pub fn subscription_succeeded(channel: String, presence_data: Option<PresenceData>) -> Self {
320        let data_obj = if let Some(data) = presence_data {
321            json!({
322                "presence": {
323                    "ids": data.ids,
324                    "hash": data.hash,
325                    "count": data.count
326                }
327            })
328        } else {
329            json!({})
330        };
331
332        Self {
333            event: Some("pusher_internal:subscription_succeeded".to_string()),
334            channel: Some(channel),
335            data: Some(MessageData::String(data_obj.to_string())),
336            name: None,
337            user_id: None,
338            sequence: None,
339            conflation_key: None,
340            tags: None,
341            message_id: None,
342            stream_id: None,
343            serial: None,
344            idempotency_key: None,
345            extras: None,
346            delta_sequence: None,
347            delta_conflation_key: None,
348        }
349    }
350
351    pub fn error(code: u16, message: String, channel: Option<String>) -> Self {
352        Self {
353            event: Some("pusher:error".to_string()),
354            data: Some(MessageData::Json(json!({
355                "code": code,
356                "message": message
357            }))),
358            channel,
359            name: None,
360            user_id: None,
361            sequence: None,
362            conflation_key: None,
363            tags: None,
364            message_id: None,
365            stream_id: None,
366            serial: None,
367            idempotency_key: None,
368            extras: None,
369            delta_sequence: None,
370            delta_conflation_key: None,
371        }
372    }
373
374    pub fn ping() -> Self {
375        Self {
376            event: Some("pusher:ping".to_string()),
377            data: None,
378            channel: None,
379            name: None,
380            user_id: None,
381            sequence: None,
382            conflation_key: None,
383            tags: None,
384            message_id: None,
385            stream_id: None,
386            serial: None,
387            idempotency_key: None,
388            extras: None,
389            delta_sequence: None,
390            delta_conflation_key: None,
391        }
392    }
393    pub fn channel_event<S: Into<String>>(event: S, channel: S, data: Value) -> Self {
394        Self {
395            event: Some(event.into()),
396            channel: Some(channel.into()),
397            data: Some(MessageData::String(data.to_string())),
398            name: None,
399            user_id: None,
400            sequence: None,
401            conflation_key: None,
402            tags: None,
403            message_id: None,
404            stream_id: None,
405            serial: None,
406            idempotency_key: None,
407            extras: None,
408            delta_sequence: None,
409            delta_conflation_key: None,
410        }
411    }
412
413    pub fn member_added(channel: String, user_id: String, user_info: Option<Value>) -> Self {
414        Self {
415            event: Some("pusher_internal:member_added".to_string()),
416            channel: Some(channel),
417            // FIX: Use MessageData::String with JSON-encoded string instead of MessageData::Json
418            data: Some(MessageData::String(
419                json!({
420                    "user_id": user_id,
421                    "user_info": user_info.unwrap_or_else(|| json!({}))
422                })
423                .to_string(),
424            )),
425            name: None,
426            user_id: None,
427            sequence: None,
428            conflation_key: None,
429            tags: None,
430            message_id: None,
431            stream_id: None,
432            serial: None,
433            idempotency_key: None,
434            extras: None,
435            delta_sequence: None,
436            delta_conflation_key: None,
437        }
438    }
439
440    pub fn member_removed(channel: String, user_id: String) -> Self {
441        Self {
442            event: Some("pusher_internal:member_removed".to_string()),
443            channel: Some(channel),
444            // FIX: Also apply same fix to member_removed for consistency
445            data: Some(MessageData::String(
446                json!({
447                    "user_id": user_id
448                })
449                .to_string(),
450            )),
451            name: None,
452            user_id: None,
453            sequence: None,
454            conflation_key: None,
455            tags: None,
456            message_id: None,
457            stream_id: None,
458            serial: None,
459            idempotency_key: None,
460            extras: None,
461            delta_sequence: None,
462            delta_conflation_key: None,
463        }
464    }
465
466    // New helper method for pong response
467    pub fn pong() -> Self {
468        Self {
469            event: Some("pusher:pong".to_string()),
470            data: None,
471            channel: None,
472            name: None,
473            user_id: None,
474            sequence: None,
475            conflation_key: None,
476            tags: None,
477            message_id: None,
478            stream_id: None,
479            serial: None,
480            idempotency_key: None,
481            extras: None,
482            delta_sequence: None,
483            delta_conflation_key: None,
484        }
485    }
486
487    // Helper for creating channel info response
488    pub fn channel_info(
489        occupied: bool,
490        subscription_count: Option<u64>,
491        user_count: Option<u64>,
492        cache_data: Option<(String, Duration)>,
493    ) -> Value {
494        let mut response = json!({
495            "occupied": occupied
496        });
497
498        if let Some(count) = subscription_count {
499            response["subscription_count"] = json!(count);
500        }
501
502        if let Some(count) = user_count {
503            response["user_count"] = json!(count);
504        }
505
506        if let Some((data, ttl)) = cache_data {
507            response["cache"] = json!({
508                "data": data,
509                "ttl": ttl.as_secs()
510            });
511        }
512
513        response
514    }
515
516    // Helper for creating channels list response
517    pub fn channels_list(channels_info: AHashMap<String, Value>) -> Value {
518        json!({
519            "channels": channels_info
520        })
521    }
522
523    // Helper for creating user list response
524    pub fn user_list(user_ids: Vec<String>) -> Value {
525        let users = user_ids
526            .into_iter()
527            .map(|id| json!({ "id": id }))
528            .collect::<Vec<_>>();
529
530        json!({ "users": users })
531    }
532
533    // Helper for batch events response
534    pub fn batch_response(batch_info: Vec<Value>) -> Value {
535        json!({ "batch": batch_info })
536    }
537
538    // Helper for simple success response
539    pub fn success_response() -> Value {
540        json!({ "ok": true })
541    }
542
543    pub fn watchlist_online_event(user_ids: Vec<String>) -> Self {
544        Self {
545            event: Some("online".to_string()),
546            channel: None, // Watchlist events don't use channels
547            name: None,
548            data: Some(MessageData::Json(json!({
549                "user_ids": user_ids
550            }))),
551            user_id: None,
552            sequence: None,
553            conflation_key: None,
554            tags: None,
555            message_id: None,
556            stream_id: None,
557            serial: None,
558            idempotency_key: None,
559            extras: None,
560            delta_sequence: None,
561            delta_conflation_key: None,
562        }
563    }
564
565    pub fn watchlist_offline_event(user_ids: Vec<String>) -> Self {
566        Self {
567            event: Some("offline".to_string()),
568            channel: None,
569            name: None,
570            data: Some(MessageData::Json(json!({
571                "user_ids": user_ids
572            }))),
573            user_id: None,
574            sequence: None,
575            conflation_key: None,
576            tags: None,
577            message_id: None,
578            stream_id: None,
579            serial: None,
580            idempotency_key: None,
581            extras: None,
582            delta_sequence: None,
583            delta_conflation_key: None,
584        }
585    }
586
587    pub fn cache_miss_event(channel: String) -> Self {
588        Self {
589            event: Some("pusher:cache_miss".to_string()),
590            channel: Some(channel),
591            data: Some(MessageData::String("{}".to_string())),
592            name: None,
593            user_id: None,
594            sequence: None,
595            conflation_key: None,
596            tags: None,
597            message_id: None,
598            stream_id: None,
599            serial: None,
600            idempotency_key: None,
601            extras: None,
602            delta_sequence: None,
603            delta_conflation_key: None,
604        }
605    }
606
607    pub fn signin_success(user_data: String) -> Self {
608        Self {
609            event: Some("pusher:signin_success".to_string()),
610            data: Some(MessageData::Json(json!({
611                "user_data": user_data
612            }))),
613            channel: None,
614            name: None,
615            user_id: None,
616            sequence: None,
617            conflation_key: None,
618            tags: None,
619            message_id: None,
620            stream_id: None,
621            serial: None,
622            idempotency_key: None,
623            extras: None,
624            delta_sequence: None,
625            delta_conflation_key: None,
626        }
627    }
628
629    /// Create a delta-compressed message
630    pub fn delta_message(
631        channel: String,
632        event: String,
633        delta_base64: String,
634        base_sequence: u32,
635        target_sequence: u32,
636        algorithm: &str,
637    ) -> Self {
638        Self {
639            event: Some("pusher:delta".to_string()),
640            channel: Some(channel.clone()),
641            data: Some(MessageData::String(
642                json!({
643                    "channel": channel,
644                    "event": event,
645                    "delta": delta_base64,
646                    "base_seq": base_sequence,
647                    "target_seq": target_sequence,
648                    "algorithm": algorithm,
649                })
650                .to_string(),
651            )),
652            name: None,
653            user_id: None,
654            sequence: None,
655            conflation_key: None,
656            tags: None,
657            message_id: None,
658            stream_id: None,
659            serial: None,
660            idempotency_key: None,
661            extras: None,
662            delta_sequence: None,
663            delta_conflation_key: None,
664        }
665    }
666
667    /// Rewrite the event name prefix to match the given protocol version.
668    /// This is the single translation point between V1 (`pusher:`) and V2 (`sockudo:`) wire formats.
669    pub fn rewrite_prefix(&mut self, version: ProtocolVersion) {
670        if let Some(ref event) = self.event {
671            self.event = Some(version.rewrite_event_prefix(event));
672        }
673    }
674
675    /// Returns true if this message is ephemeral (skip recovery buffer and webhooks).
676    pub fn is_ephemeral(&self) -> bool {
677        self.extras
678            .as_ref()
679            .and_then(|e| e.ephemeral)
680            .unwrap_or(false)
681    }
682
683    /// Returns the extras-level idempotency key, if set.
684    pub fn extras_idempotency_key(&self) -> Option<&str> {
685        self.extras
686            .as_ref()
687            .and_then(|e| e.idempotency_key.as_deref())
688    }
689
690    /// Resolve whether this message should be echoed back to the publishing socket.
691    /// Message-level `extras.echo` takes precedence over the connection default.
692    pub fn should_echo(&self, connection_default: bool) -> bool {
693        self.extras
694            .as_ref()
695            .and_then(|e| e.echo)
696            .unwrap_or(connection_default)
697    }
698
699    /// Returns the extras headers for server-side filtering, if present.
700    pub fn filter_headers(&self) -> Option<&HashMap<String, ExtrasValue>> {
701        self.extras.as_ref().and_then(|e| e.headers.as_ref())
702    }
703
704    /// Returns true if the given protocol version should receive extras in delivered messages.
705    pub fn should_include_extras(protocol: &ProtocolVersion) -> bool {
706        matches!(protocol, ProtocolVersion::V2)
707    }
708
709    /// Add base sequence marker to a full message for delta tracking
710    pub fn add_base_sequence(mut self, base_sequence: u32) -> Self {
711        if let Some(MessageData::String(ref data_str)) = self.data
712            && let Ok(mut data_obj) = sonic_rs::from_str::<Value>(data_str)
713            && let Some(obj) = data_obj.as_object_mut()
714        {
715            obj.insert("__delta_base_seq", json!(base_sequence));
716            self.data = Some(MessageData::String(data_obj.to_string()));
717        }
718        self
719    }
720
721    /// Create delta compression enabled confirmation
722    pub fn delta_compression_enabled(default_algorithm: &str) -> Self {
723        Self {
724            event: Some("pusher:delta_compression_enabled".to_string()),
725            data: Some(MessageData::Json(json!({
726                "enabled": true,
727                "default_algorithm": default_algorithm,
728            }))),
729            channel: None,
730            name: None,
731            user_id: None,
732            sequence: None,
733            conflation_key: None,
734            tags: None,
735            message_id: None,
736            stream_id: None,
737            serial: None,
738            idempotency_key: None,
739            extras: None,
740            delta_sequence: None,
741            delta_conflation_key: None,
742        }
743    }
744}
745
/// Helpers for interpreting a comma-separated `info` parameter such as
/// `"user_count,subscription_count"`.
pub trait InfoQueryParser {
    /// Split the parameter into individual attribute names.
    fn parse_info(&self) -> Vec<&str>;
    /// Whether `user_count` was requested.
    fn wants_user_count(&self) -> bool;
    /// Whether `subscription_count` was requested.
    fn wants_subscription_count(&self) -> bool;
    /// Whether `cache` was requested.
    fn wants_cache(&self) -> bool;
}

/// True when `info` lists `attribute` as one of its comma-separated tokens,
/// ignoring surrounding whitespace. Scans the string directly, with no
/// allocation.
fn info_requests(info: Option<&String>, attribute: &str) -> bool {
    info.is_some_and(|raw| raw.split(',').any(|token| token.trim() == attribute))
}

impl InfoQueryParser for Option<&String> {
    /// Tokens are trimmed and empty tokens are dropped, so
    /// `"user_count, cache,"` yields `["user_count", "cache"]`. `None` yields
    /// an empty vector.
    fn parse_info(&self) -> Vec<&str> {
        self.map(|raw| {
            raw.split(',')
                .map(str::trim)
                .filter(|token| !token.is_empty())
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
    }

    fn wants_user_count(&self) -> bool {
        info_requests(*self, "user_count")
    }

    fn wants_subscription_count(&self) -> bool {
        info_requests(*self, "subscription_count")
    }

    fn wants_cache(&self) -> bool {
        info_requests(*self, "cache")
    }
}