// sockudo_protocol/wire.rs

use ahash::AHashMap;
use prost::Message;
use serde::{Deserialize, Serialize};
use sonic_rs::Value;
use std::collections::{BTreeMap, HashMap};

use crate::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};

9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
10#[serde(rename_all = "lowercase")]
11pub enum WireFormat {
12    #[default]
13    Json,
14    MessagePack,
15    Protobuf,
16}
17
18impl WireFormat {
19    pub fn from_query_param(value: Option<&str>) -> Self {
20        Self::parse_query_param(value).unwrap_or(Self::Json)
21    }
22
23    pub fn parse_query_param(value: Option<&str>) -> Result<Self, String> {
24        match value.map(|v| v.trim().to_ascii_lowercase()) {
25            None => Ok(Self::Json),
26            Some(v) if v.is_empty() || v == "json" => Ok(Self::Json),
27            Some(v) if v == "msgpack" || v == "messagepack" => Ok(Self::MessagePack),
28            Some(v) if v == "protobuf" || v == "proto" => Ok(Self::Protobuf),
29            Some(v) => Err(format!("unsupported wire format '{v}'")),
30        }
31    }
32
33    pub const fn is_binary(self) -> bool {
34        !matches!(self, Self::Json)
35    }
36}
37
38pub fn serialize_message(message: &PusherMessage, format: WireFormat) -> Result<Vec<u8>, String> {
39    match format {
40        WireFormat::Json => {
41            sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
42        }
43        WireFormat::MessagePack => rmp_serde::to_vec(&MsgpackPusherMessage::from(message.clone()))
44            .map_err(|e| format!("MessagePack serialization failed: {e}")),
45        WireFormat::Protobuf => {
46            let proto = ProtoPusherMessage::from(message.clone());
47            let mut buf = Vec::with_capacity(proto.encoded_len());
48            proto
49                .encode(&mut buf)
50                .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
51            Ok(buf)
52        }
53    }
54}
55
56pub fn deserialize_message(bytes: &[u8], format: WireFormat) -> Result<PusherMessage, String> {
57    match format {
58        WireFormat::Json => {
59            sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
60        }
61        WireFormat::MessagePack => {
62            let msg: MsgpackPusherMessage = rmp_serde::from_slice(bytes)
63                .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
64            Ok(msg.into())
65        }
66        WireFormat::Protobuf => {
67            let proto = ProtoPusherMessage::decode(bytes)
68                .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
69            Ok(proto.into())
70        }
71    }
72}
73
74#[derive(Clone, PartialEq, Message)]
75struct ProtoPusherMessage {
76    #[prost(string, optional, tag = "1")]
77    event: Option<String>,
78    #[prost(string, optional, tag = "2")]
79    channel: Option<String>,
80    #[prost(message, optional, tag = "3")]
81    data: Option<ProtoMessageData>,
82    #[prost(string, optional, tag = "4")]
83    name: Option<String>,
84    #[prost(string, optional, tag = "5")]
85    user_id: Option<String>,
86    #[prost(map = "string, string", tag = "6")]
87    tags: HashMap<String, String>,
88    #[prost(uint64, optional, tag = "7")]
89    sequence: Option<u64>,
90    #[prost(string, optional, tag = "8")]
91    conflation_key: Option<String>,
92    #[prost(string, optional, tag = "9")]
93    message_id: Option<String>,
94    #[prost(uint64, optional, tag = "10")]
95    serial: Option<u64>,
96    #[prost(string, optional, tag = "11")]
97    idempotency_key: Option<String>,
98    #[prost(message, optional, tag = "12")]
99    extras: Option<ProtoMessageExtras>,
100    #[prost(uint64, optional, tag = "13")]
101    delta_sequence: Option<u64>,
102    #[prost(string, optional, tag = "14")]
103    delta_conflation_key: Option<String>,
104}
105
106#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
107struct MsgpackPusherMessage {
108    event: Option<String>,
109    channel: Option<String>,
110    data: Option<MsgpackMessageData>,
111    name: Option<String>,
112    user_id: Option<String>,
113    tags: Option<BTreeMap<String, String>>,
114    sequence: Option<u64>,
115    conflation_key: Option<String>,
116    message_id: Option<String>,
117    serial: Option<u64>,
118    idempotency_key: Option<String>,
119    extras: Option<MsgpackMessageExtras>,
120    delta_sequence: Option<u64>,
121    delta_conflation_key: Option<String>,
122}
123
124#[derive(Clone, PartialEq, Message)]
125struct ProtoMessageData {
126    #[prost(oneof = "proto_message_data::Kind", tags = "1, 2, 3")]
127    kind: Option<proto_message_data::Kind>,
128}
129
130#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
132enum MsgpackMessageData {
133    String(String),
134    Structured(MsgpackStructuredData),
135    Json(String),
136}
137
138mod proto_message_data {
139    use super::ProtoStructuredData;
140    use prost::Oneof;
141
142    #[derive(Clone, PartialEq, Oneof)]
143    pub enum Kind {
144        #[prost(string, tag = "1")]
145        String(String),
146        #[prost(message, tag = "2")]
147        Structured(ProtoStructuredData),
148        #[prost(string, tag = "3")]
149        Json(String),
150    }
151}
152
153#[derive(Clone, PartialEq, Message)]
154struct ProtoStructuredData {
155    #[prost(string, optional, tag = "1")]
156    channel_data: Option<String>,
157    #[prost(string, optional, tag = "2")]
158    channel: Option<String>,
159    #[prost(string, optional, tag = "3")]
160    user_data: Option<String>,
161    #[prost(map = "string, string", tag = "4")]
162    extra: HashMap<String, String>,
163}
164
165#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
166struct MsgpackStructuredData {
167    channel_data: Option<String>,
168    channel: Option<String>,
169    user_data: Option<String>,
170    extra: HashMap<String, String>,
171}
172
173#[derive(Clone, PartialEq, Message)]
174struct ProtoMessageExtras {
175    #[prost(map = "string, message", tag = "1")]
176    headers: HashMap<String, ProtoExtrasValue>,
177    #[prost(bool, optional, tag = "2")]
178    ephemeral: Option<bool>,
179    #[prost(string, optional, tag = "3")]
180    idempotency_key: Option<String>,
181    #[prost(bool, optional, tag = "4")]
182    echo: Option<bool>,
183}
184
185#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
186struct MsgpackMessageExtras {
187    headers: Option<HashMap<String, MsgpackExtrasValue>>,
188    ephemeral: Option<bool>,
189    idempotency_key: Option<String>,
190    echo: Option<bool>,
191}
192
193#[derive(Clone, PartialEq, Message)]
194struct ProtoExtrasValue {
195    #[prost(oneof = "proto_extras_value::Kind", tags = "1, 2, 3")]
196    kind: Option<proto_extras_value::Kind>,
197}
198
199#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
200#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
201enum MsgpackExtrasValue {
202    String(String),
203    Number(f64),
204    Bool(bool),
205}
206
207mod proto_extras_value {
208    use prost::Oneof;
209
210    #[derive(Clone, PartialEq, Oneof)]
211    pub enum Kind {
212        #[prost(string, tag = "1")]
213        String(String),
214        #[prost(double, tag = "2")]
215        Number(f64),
216        #[prost(bool, tag = "3")]
217        Bool(bool),
218    }
219}
220
221impl From<PusherMessage> for ProtoPusherMessage {
222    fn from(value: PusherMessage) -> Self {
223        Self {
224            event: value.event,
225            channel: value.channel,
226            data: value.data.map(Into::into),
227            name: value.name,
228            user_id: value.user_id,
229            tags: value
230                .tags
231                .map(|m| m.into_iter().collect())
232                .unwrap_or_default(),
233            sequence: value.sequence,
234            conflation_key: value.conflation_key,
235            message_id: value.message_id,
236            serial: value.serial,
237            idempotency_key: value.idempotency_key,
238            extras: value.extras.map(Into::into),
239            delta_sequence: value.delta_sequence,
240            delta_conflation_key: value.delta_conflation_key,
241        }
242    }
243}
244
245impl From<PusherMessage> for MsgpackPusherMessage {
246    fn from(value: PusherMessage) -> Self {
247        Self {
248            event: value.event,
249            channel: value.channel,
250            data: value.data.map(Into::into),
251            name: value.name,
252            user_id: value.user_id,
253            tags: value.tags,
254            sequence: value.sequence,
255            conflation_key: value.conflation_key,
256            message_id: value.message_id,
257            serial: value.serial,
258            idempotency_key: value.idempotency_key,
259            extras: value.extras.map(Into::into),
260            delta_sequence: value.delta_sequence,
261            delta_conflation_key: value.delta_conflation_key,
262        }
263    }
264}
265
266impl From<ProtoPusherMessage> for PusherMessage {
267    fn from(value: ProtoPusherMessage) -> Self {
268        Self {
269            event: value.event,
270            channel: value.channel,
271            data: value.data.map(Into::into),
272            name: value.name,
273            user_id: value.user_id,
274            tags: (!value.tags.is_empty())
275                .then_some(value.tags.into_iter().collect::<BTreeMap<_, _>>()),
276            sequence: value.sequence,
277            conflation_key: value.conflation_key,
278            message_id: value.message_id,
279            serial: value.serial,
280            idempotency_key: value.idempotency_key,
281            extras: value.extras.map(Into::into),
282            delta_sequence: value.delta_sequence,
283            delta_conflation_key: value.delta_conflation_key,
284        }
285    }
286}
287
288impl From<MsgpackPusherMessage> for PusherMessage {
289    fn from(value: MsgpackPusherMessage) -> Self {
290        Self {
291            event: value.event,
292            channel: value.channel,
293            data: value.data.map(Into::into),
294            name: value.name,
295            user_id: value.user_id,
296            tags: value.tags,
297            sequence: value.sequence,
298            conflation_key: value.conflation_key,
299            message_id: value.message_id,
300            serial: value.serial,
301            idempotency_key: value.idempotency_key,
302            extras: value.extras.map(Into::into),
303            delta_sequence: value.delta_sequence,
304            delta_conflation_key: value.delta_conflation_key,
305        }
306    }
307}
308
309impl From<MessageData> for ProtoMessageData {
310    fn from(value: MessageData) -> Self {
311        let kind = match value {
312            MessageData::String(s) => Some(proto_message_data::Kind::String(s)),
313            MessageData::Structured {
314                channel_data,
315                channel,
316                user_data,
317                extra,
318            } => Some(proto_message_data::Kind::Structured(ProtoStructuredData {
319                channel_data,
320                channel,
321                user_data,
322                extra: extra
323                    .into_iter()
324                    .map(|(k, v)| {
325                        (
326                            k,
327                            sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
328                        )
329                    })
330                    .collect(),
331            })),
332            MessageData::Json(v) => Some(proto_message_data::Kind::Json(
333                sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
334            )),
335        };
336
337        Self { kind }
338    }
339}
340
341impl From<MessageData> for MsgpackMessageData {
342    fn from(value: MessageData) -> Self {
343        match value {
344            MessageData::String(s) => Self::String(s),
345            MessageData::Structured {
346                channel_data,
347                channel,
348                user_data,
349                extra,
350            } => Self::Structured(MsgpackStructuredData {
351                channel_data,
352                channel,
353                user_data,
354                extra: extra
355                    .into_iter()
356                    .map(|(k, v)| {
357                        (
358                            k,
359                            sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
360                        )
361                    })
362                    .collect(),
363            }),
364            MessageData::Json(v) => {
365                Self::Json(sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()))
366            }
367        }
368    }
369}
370
371impl From<ProtoMessageData> for MessageData {
372    fn from(value: ProtoMessageData) -> Self {
373        match value.kind {
374            Some(proto_message_data::Kind::String(s)) => MessageData::String(s),
375            Some(proto_message_data::Kind::Structured(s)) => MessageData::Structured {
376                channel_data: s.channel_data,
377                channel: s.channel,
378                user_data: s.user_data,
379                extra: s
380                    .extra
381                    .into_iter()
382                    .map(|(k, v)| {
383                        let parsed =
384                            sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
385                        (k, parsed)
386                    })
387                    .collect::<AHashMap<_, _>>(),
388            },
389            Some(proto_message_data::Kind::Json(v)) => MessageData::Json(
390                sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
391            ),
392            None => MessageData::Json(Value::new_null()),
393        }
394    }
395}
396
397impl From<MsgpackMessageData> for MessageData {
398    fn from(value: MsgpackMessageData) -> Self {
399        match value {
400            MsgpackMessageData::String(s) => MessageData::String(s),
401            MsgpackMessageData::Structured(s) => MessageData::Structured {
402                channel_data: s.channel_data,
403                channel: s.channel,
404                user_data: s.user_data,
405                extra: s
406                    .extra
407                    .into_iter()
408                    .map(|(k, v)| {
409                        let parsed =
410                            sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
411                        (k, parsed)
412                    })
413                    .collect::<AHashMap<_, _>>(),
414            },
415            MsgpackMessageData::Json(v) => MessageData::Json(
416                sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
417            ),
418        }
419    }
420}
421
422impl From<MessageExtras> for ProtoMessageExtras {
423    fn from(value: MessageExtras) -> Self {
424        Self {
425            headers: value
426                .headers
427                .unwrap_or_default()
428                .into_iter()
429                .map(|(k, v)| (k, v.into()))
430                .collect(),
431            ephemeral: value.ephemeral,
432            idempotency_key: value.idempotency_key,
433            echo: value.echo,
434        }
435    }
436}
437
438impl From<MessageExtras> for MsgpackMessageExtras {
439    fn from(value: MessageExtras) -> Self {
440        Self {
441            headers: value
442                .headers
443                .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
444            ephemeral: value.ephemeral,
445            idempotency_key: value.idempotency_key,
446            echo: value.echo,
447        }
448    }
449}
450
451impl From<ProtoMessageExtras> for MessageExtras {
452    fn from(value: ProtoMessageExtras) -> Self {
453        Self {
454            headers: (!value.headers.is_empty()).then_some(
455                value
456                    .headers
457                    .into_iter()
458                    .map(|(k, v)| (k, v.into()))
459                    .collect(),
460            ),
461            ephemeral: value.ephemeral,
462            idempotency_key: value.idempotency_key,
463            echo: value.echo,
464        }
465    }
466}
467
468impl From<MsgpackMessageExtras> for MessageExtras {
469    fn from(value: MsgpackMessageExtras) -> Self {
470        Self {
471            headers: value
472                .headers
473                .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
474            ephemeral: value.ephemeral,
475            idempotency_key: value.idempotency_key,
476            echo: value.echo,
477        }
478    }
479}
480
481impl From<ExtrasValue> for ProtoExtrasValue {
482    fn from(value: ExtrasValue) -> Self {
483        let kind = match value {
484            ExtrasValue::String(s) => Some(proto_extras_value::Kind::String(s)),
485            ExtrasValue::Number(n) => Some(proto_extras_value::Kind::Number(n)),
486            ExtrasValue::Bool(b) => Some(proto_extras_value::Kind::Bool(b)),
487        };
488        Self { kind }
489    }
490}
491
492impl From<ExtrasValue> for MsgpackExtrasValue {
493    fn from(value: ExtrasValue) -> Self {
494        match value {
495            ExtrasValue::String(s) => Self::String(s),
496            ExtrasValue::Number(n) => Self::Number(n),
497            ExtrasValue::Bool(b) => Self::Bool(b),
498        }
499    }
500}
501
502impl From<ProtoExtrasValue> for ExtrasValue {
503    fn from(value: ProtoExtrasValue) -> Self {
504        match value.kind {
505            Some(proto_extras_value::Kind::String(s)) => ExtrasValue::String(s),
506            Some(proto_extras_value::Kind::Number(n)) => ExtrasValue::Number(n),
507            Some(proto_extras_value::Kind::Bool(b)) => ExtrasValue::Bool(b),
508            None => ExtrasValue::String(String::new()),
509        }
510    }
511}
512
513impl From<MsgpackExtrasValue> for ExtrasValue {
514    fn from(value: MsgpackExtrasValue) -> Self {
515        match value {
516            MsgpackExtrasValue::String(s) => ExtrasValue::String(s),
517            MsgpackExtrasValue::Number(n) => ExtrasValue::Number(n),
518            MsgpackExtrasValue::Bool(b) => ExtrasValue::Bool(b),
519        }
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with every optional field except `name` populated.
    fn sample_message() -> PusherMessage {
        PusherMessage {
            event: Some("sockudo:test".to_string()),
            channel: Some("chat:room-1".to_string()),
            data: Some(MessageData::Json(sonic_rs::json!({
                "hello": "world",
                "count": 3,
                "nested": { "ok": true },
                "items": [1, 2, 3]
            }))),
            name: None,
            user_id: Some("user-1".to_string()),
            tags: Some(BTreeMap::from([
                ("region".to_string(), "eu".to_string()),
                ("tier".to_string(), "gold".to_string()),
            ])),
            sequence: Some(7),
            conflation_key: Some("room".to_string()),
            message_id: Some("mid-1".to_string()),
            serial: Some(9),
            idempotency_key: Some("idem-1".to_string()),
            extras: Some(MessageExtras {
                headers: Some(HashMap::from([
                    (
                        "priority".to_string(),
                        ExtrasValue::String("high".to_string()),
                    ),
                    ("ttl".to_string(), ExtrasValue::Number(5.0)),
                ])),
                ephemeral: Some(true),
                idempotency_key: Some("extra-idem".to_string()),
                echo: Some(false),
            }),
            delta_sequence: Some(11),
            delta_conflation_key: Some("btc".to_string()),
        }
    }

    #[test]
    fn round_trip_messagepack() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_eq!(decoded.event, msg.event);
        assert_eq!(decoded.channel, msg.channel);
        assert_eq!(decoded.message_id, msg.message_id);
        assert_eq!(decoded.tags, msg.tags);
        assert_eq!(decoded.delta_sequence, msg.delta_sequence);
        assert_eq!(decoded.delta_conflation_key, msg.delta_conflation_key);
    }

    #[test]
    fn round_trip_protobuf() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.event, msg.event);
        assert_eq!(decoded.channel, msg.channel);
        assert_eq!(decoded.message_id, msg.message_id);
        assert_eq!(decoded.tags, msg.tags);
        assert_eq!(decoded.delta_sequence, msg.delta_sequence);
        assert_eq!(decoded.delta_conflation_key, msg.delta_conflation_key);
    }

    #[test]
    fn parse_query_param_accepts_known_values() {
        assert_eq!(
            WireFormat::parse_query_param(None).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("json")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("messagepack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("msgpack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("protobuf")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("proto")).unwrap(),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn parse_query_param_ignores_case_and_whitespace() {
        assert_eq!(
            WireFormat::parse_query_param(Some("  MsgPack ")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("")).unwrap(),
            WireFormat::Json
        );
    }

    #[test]
    fn parse_query_param_rejects_unknown_value() {
        assert!(WireFormat::parse_query_param(Some("avro")).is_err());
        assert_eq!(WireFormat::from_query_param(Some("avro")), WireFormat::Json);
    }
}