1use ahash::AHashMap;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4use sonic_rs::Value;
5use std::collections::{BTreeMap, HashMap};
6
7use crate::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};
8
/// Encoding used to put a `PusherMessage` on the wire.
///
/// `Json` is the default variant. Under serde the variants are spelled in
/// lowercase (`rename_all = "lowercase"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WireFormat {
    /// JSON text; the only variant for which `is_binary` returns `false`.
    #[default]
    Json,
    /// MessagePack; selected by the query values `msgpack` or `messagepack`.
    MessagePack,
    /// Protocol Buffers; selected by the query values `protobuf` or `proto`.
    Protobuf,
}
17
18impl WireFormat {
19 pub fn from_query_param(value: Option<&str>) -> Self {
20 Self::parse_query_param(value).unwrap_or(Self::Json)
21 }
22
23 pub fn parse_query_param(value: Option<&str>) -> Result<Self, String> {
24 match value.map(|v| v.trim().to_ascii_lowercase()) {
25 None => Ok(Self::Json),
26 Some(v) if v.is_empty() || v == "json" => Ok(Self::Json),
27 Some(v) if v == "msgpack" || v == "messagepack" => Ok(Self::MessagePack),
28 Some(v) if v == "protobuf" || v == "proto" => Ok(Self::Protobuf),
29 Some(v) => Err(format!("unsupported wire format '{v}'")),
30 }
31 }
32
33 pub const fn is_binary(self) -> bool {
34 !matches!(self, Self::Json)
35 }
36}
37
38pub fn serialize_message(message: &PusherMessage, format: WireFormat) -> Result<Vec<u8>, String> {
39 match format {
40 WireFormat::Json => {
41 sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
42 }
43 WireFormat::MessagePack => rmp_serde::to_vec(&MsgpackPusherMessage::from(message.clone()))
44 .map_err(|e| format!("MessagePack serialization failed: {e}")),
45 WireFormat::Protobuf => {
46 let proto = ProtoPusherMessage::from(message.clone());
47 let mut buf = Vec::with_capacity(proto.encoded_len());
48 proto
49 .encode(&mut buf)
50 .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
51 Ok(buf)
52 }
53 }
54}
55
56pub fn deserialize_message(bytes: &[u8], format: WireFormat) -> Result<PusherMessage, String> {
57 match format {
58 WireFormat::Json => {
59 sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
60 }
61 WireFormat::MessagePack => {
62 let msg: MsgpackPusherMessage = rmp_serde::from_slice(bytes)
63 .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
64 Ok(msg.into())
65 }
66 WireFormat::Protobuf => {
67 let proto = ProtoPusherMessage::decode(bytes)
68 .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
69 Ok(proto.into())
70 }
71 }
72}
73
/// Protobuf mirror of `PusherMessage`, encoded and decoded through prost.
///
/// Used by `serialize_message` / `deserialize_message` for
/// `WireFormat::Protobuf`. The numeric `tag` on each field identifies it on
/// the wire.
#[derive(Clone, PartialEq, Message)]
struct ProtoPusherMessage {
    #[prost(string, optional, tag = "1")]
    event: Option<String>,
    #[prost(string, optional, tag = "2")]
    channel: Option<String>,
    // Payload wrapper; see `ProtoMessageData` for the three supported shapes.
    #[prost(message, optional, tag = "3")]
    data: Option<ProtoMessageData>,
    #[prost(string, optional, tag = "4")]
    name: Option<String>,
    #[prost(string, optional, tag = "5")]
    user_id: Option<String>,
    // A plain map rather than an `Option`: the conversions in this file turn
    // a missing map into an empty one and an empty one back into `None`.
    #[prost(map = "string, string", tag = "6")]
    tags: HashMap<String, String>,
    #[prost(uint64, optional, tag = "7")]
    sequence: Option<u64>,
    #[prost(string, optional, tag = "8")]
    conflation_key: Option<String>,
    #[prost(string, optional, tag = "9")]
    message_id: Option<String>,
    #[prost(string, optional, tag = "10")]
    stream_id: Option<String>,
    #[prost(uint64, optional, tag = "11")]
    serial: Option<u64>,
    #[prost(string, optional, tag = "12")]
    idempotency_key: Option<String>,
    #[prost(message, optional, tag = "13")]
    extras: Option<ProtoMessageExtras>,
    #[prost(uint64, optional, tag = "14")]
    delta_sequence: Option<u64>,
    #[prost(string, optional, tag = "15")]
    delta_conflation_key: Option<String>,
}
107
/// MessagePack mirror of `PusherMessage`, (de)serialized through serde with
/// `rmp_serde`.
///
/// Every field maps one-to-one onto the same-named `PusherMessage` field;
/// only `data` and `extras` go through a dedicated mirror type.
///
/// NOTE(review): `rmp_serde::to_vec` presumably encodes structs positionally,
/// which would make the field order part of the wire format. Confirm before
/// reordering or inserting fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackPusherMessage {
    event: Option<String>,
    channel: Option<String>,
    data: Option<MsgpackMessageData>,
    name: Option<String>,
    user_id: Option<String>,
    tags: Option<BTreeMap<String, String>>,
    sequence: Option<u64>,
    conflation_key: Option<String>,
    message_id: Option<String>,
    stream_id: Option<String>,
    serial: Option<u64>,
    idempotency_key: Option<String>,
    extras: Option<MsgpackMessageExtras>,
    delta_sequence: Option<u64>,
    delta_conflation_key: Option<String>,
}
126
/// Protobuf wrapper around the message payload.
///
/// Exactly one of the `proto_message_data::Kind` alternatives (tags 1–3) may
/// be set. When none is set, the conversion to `MessageData` yields a JSON
/// null value.
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageData {
    #[prost(oneof = "proto_message_data::Kind", tags = "1, 2, 3")]
    kind: Option<proto_message_data::Kind>,
}
132
/// MessagePack mirror of `MessageData`.
///
/// Serialized by serde as an adjacently tagged enum: the variant name (in
/// snake_case) is stored under `kind` and its payload under `value`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
enum MsgpackMessageData {
    /// Plain string payload, carried verbatim.
    String(String),
    /// Structured payload; see `MsgpackStructuredData`.
    Structured(MsgpackStructuredData),
    /// Arbitrary JSON value, carried as its JSON text.
    Json(String),
}
140
/// Holds the `oneof` enum referenced by `ProtoMessageData::kind`.
mod proto_message_data {
    use super::ProtoStructuredData;
    use prost::Oneof;

    /// The payload alternatives of `ProtoMessageData`.
    #[derive(Clone, PartialEq, Oneof)]
    pub enum Kind {
        /// Plain string payload, carried verbatim.
        #[prost(string, tag = "1")]
        String(String),
        /// Structured payload with the known channel/user fields.
        #[prost(message, tag = "2")]
        Structured(ProtoStructuredData),
        /// Arbitrary JSON value, carried as its JSON text.
        #[prost(string, tag = "3")]
        Json(String),
    }
}
155
/// Protobuf mirror of the `MessageData::Structured` variant.
#[derive(Clone, PartialEq, Message)]
struct ProtoStructuredData {
    #[prost(string, optional, tag = "1")]
    channel_data: Option<String>,
    #[prost(string, optional, tag = "2")]
    channel: Option<String>,
    #[prost(string, optional, tag = "3")]
    user_data: Option<String>,
    // Each value is the JSON text of the corresponding entry in the
    // original `extra` map; it is parsed back when decoding.
    #[prost(map = "string, string", tag = "4")]
    extra: HashMap<String, String>,
}
167
/// MessagePack mirror of the `MessageData::Structured` variant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackStructuredData {
    channel_data: Option<String>,
    channel: Option<String>,
    user_data: Option<String>,
    // Each value is the JSON text of the corresponding entry in the
    // original `extra` map; it is parsed back when decoding.
    extra: HashMap<String, String>,
}
175
/// Protobuf mirror of `MessageExtras`.
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageExtras {
    // A plain map rather than an `Option`: a missing header map is encoded as
    // an empty map, and an empty map decodes back to `None`.
    #[prost(map = "string, message", tag = "1")]
    headers: HashMap<String, ProtoExtrasValue>,
    #[prost(bool, optional, tag = "2")]
    ephemeral: Option<bool>,
    #[prost(string, optional, tag = "3")]
    idempotency_key: Option<String>,
    #[prost(bool, optional, tag = "4")]
    echo: Option<bool>,
}
187
/// MessagePack mirror of `MessageExtras`.
///
/// Unlike the Protobuf mirror, `headers` stays optional, so `None` and an
/// empty map remain distinct values in this representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackMessageExtras {
    headers: Option<HashMap<String, MsgpackExtrasValue>>,
    ephemeral: Option<bool>,
    idempotency_key: Option<String>,
    echo: Option<bool>,
}
195
/// Protobuf wrapper around a single header value.
///
/// Exactly one of the `proto_extras_value::Kind` alternatives (tags 1–3) may
/// be set. When none is set, the conversion to `ExtrasValue` yields an empty
/// string.
#[derive(Clone, PartialEq, Message)]
struct ProtoExtrasValue {
    #[prost(oneof = "proto_extras_value::Kind", tags = "1, 2, 3")]
    kind: Option<proto_extras_value::Kind>,
}
201
/// MessagePack mirror of `ExtrasValue`.
///
/// Serialized by serde as an adjacently tagged enum: the variant name (in
/// snake_case) is stored under `kind` and its payload under `value`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
enum MsgpackExtrasValue {
    /// Text header value.
    String(String),
    /// Numeric header value.
    Number(f64),
    /// Boolean header value.
    Bool(bool),
}
209
/// Holds the `oneof` enum referenced by `ProtoExtrasValue::kind`.
mod proto_extras_value {
    use prost::Oneof;

    /// The value alternatives of `ProtoExtrasValue`.
    #[derive(Clone, PartialEq, Oneof)]
    pub enum Kind {
        /// Text header value.
        #[prost(string, tag = "1")]
        String(String),
        /// Numeric header value, encoded as a protobuf `double`.
        #[prost(double, tag = "2")]
        Number(f64),
        /// Boolean header value.
        #[prost(bool, tag = "3")]
        Bool(bool),
    }
}
223
224impl From<PusherMessage> for ProtoPusherMessage {
225 fn from(value: PusherMessage) -> Self {
226 Self {
227 event: value.event,
228 channel: value.channel,
229 data: value.data.map(Into::into),
230 name: value.name,
231 user_id: value.user_id,
232 tags: value
233 .tags
234 .map(|m| m.into_iter().collect())
235 .unwrap_or_default(),
236 sequence: value.sequence,
237 conflation_key: value.conflation_key,
238 message_id: value.message_id,
239 stream_id: value.stream_id,
240 serial: value.serial,
241 idempotency_key: value.idempotency_key,
242 extras: value.extras.map(Into::into),
243 delta_sequence: value.delta_sequence,
244 delta_conflation_key: value.delta_conflation_key,
245 }
246 }
247}
248
249impl From<PusherMessage> for MsgpackPusherMessage {
250 fn from(value: PusherMessage) -> Self {
251 Self {
252 event: value.event,
253 channel: value.channel,
254 data: value.data.map(Into::into),
255 name: value.name,
256 user_id: value.user_id,
257 tags: value.tags,
258 sequence: value.sequence,
259 conflation_key: value.conflation_key,
260 message_id: value.message_id,
261 stream_id: value.stream_id,
262 serial: value.serial,
263 idempotency_key: value.idempotency_key,
264 extras: value.extras.map(Into::into),
265 delta_sequence: value.delta_sequence,
266 delta_conflation_key: value.delta_conflation_key,
267 }
268 }
269}
270
271impl From<ProtoPusherMessage> for PusherMessage {
272 fn from(value: ProtoPusherMessage) -> Self {
273 Self {
274 event: value.event,
275 channel: value.channel,
276 data: value.data.map(Into::into),
277 name: value.name,
278 user_id: value.user_id,
279 tags: (!value.tags.is_empty())
280 .then_some(value.tags.into_iter().collect::<BTreeMap<_, _>>()),
281 sequence: value.sequence,
282 conflation_key: value.conflation_key,
283 message_id: value.message_id,
284 stream_id: value.stream_id,
285 serial: value.serial,
286 idempotency_key: value.idempotency_key,
287 extras: value.extras.map(Into::into),
288 delta_sequence: value.delta_sequence,
289 delta_conflation_key: value.delta_conflation_key,
290 }
291 }
292}
293
294impl From<MsgpackPusherMessage> for PusherMessage {
295 fn from(value: MsgpackPusherMessage) -> Self {
296 Self {
297 event: value.event,
298 channel: value.channel,
299 data: value.data.map(Into::into),
300 name: value.name,
301 user_id: value.user_id,
302 tags: value.tags,
303 sequence: value.sequence,
304 conflation_key: value.conflation_key,
305 message_id: value.message_id,
306 stream_id: value.stream_id,
307 serial: value.serial,
308 idempotency_key: value.idempotency_key,
309 extras: value.extras.map(Into::into),
310 delta_sequence: value.delta_sequence,
311 delta_conflation_key: value.delta_conflation_key,
312 }
313 }
314}
315
316impl From<MessageData> for ProtoMessageData {
317 fn from(value: MessageData) -> Self {
318 let kind = match value {
319 MessageData::String(s) => Some(proto_message_data::Kind::String(s)),
320 MessageData::Structured {
321 channel_data,
322 channel,
323 user_data,
324 extra,
325 } => Some(proto_message_data::Kind::Structured(ProtoStructuredData {
326 channel_data,
327 channel,
328 user_data,
329 extra: extra
330 .into_iter()
331 .map(|(k, v)| {
332 (
333 k,
334 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
335 )
336 })
337 .collect(),
338 })),
339 MessageData::Json(v) => Some(proto_message_data::Kind::Json(
340 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
341 )),
342 };
343
344 Self { kind }
345 }
346}
347
348impl From<MessageData> for MsgpackMessageData {
349 fn from(value: MessageData) -> Self {
350 match value {
351 MessageData::String(s) => Self::String(s),
352 MessageData::Structured {
353 channel_data,
354 channel,
355 user_data,
356 extra,
357 } => Self::Structured(MsgpackStructuredData {
358 channel_data,
359 channel,
360 user_data,
361 extra: extra
362 .into_iter()
363 .map(|(k, v)| {
364 (
365 k,
366 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
367 )
368 })
369 .collect(),
370 }),
371 MessageData::Json(v) => {
372 Self::Json(sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()))
373 }
374 }
375 }
376}
377
378impl From<ProtoMessageData> for MessageData {
379 fn from(value: ProtoMessageData) -> Self {
380 match value.kind {
381 Some(proto_message_data::Kind::String(s)) => MessageData::String(s),
382 Some(proto_message_data::Kind::Structured(s)) => MessageData::Structured {
383 channel_data: s.channel_data,
384 channel: s.channel,
385 user_data: s.user_data,
386 extra: s
387 .extra
388 .into_iter()
389 .map(|(k, v)| {
390 let parsed =
391 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
392 (k, parsed)
393 })
394 .collect::<AHashMap<_, _>>(),
395 },
396 Some(proto_message_data::Kind::Json(v)) => MessageData::Json(
397 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
398 ),
399 None => MessageData::Json(Value::new_null()),
400 }
401 }
402}
403
404impl From<MsgpackMessageData> for MessageData {
405 fn from(value: MsgpackMessageData) -> Self {
406 match value {
407 MsgpackMessageData::String(s) => MessageData::String(s),
408 MsgpackMessageData::Structured(s) => MessageData::Structured {
409 channel_data: s.channel_data,
410 channel: s.channel,
411 user_data: s.user_data,
412 extra: s
413 .extra
414 .into_iter()
415 .map(|(k, v)| {
416 let parsed =
417 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
418 (k, parsed)
419 })
420 .collect::<AHashMap<_, _>>(),
421 },
422 MsgpackMessageData::Json(v) => MessageData::Json(
423 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
424 ),
425 }
426 }
427}
428
429impl From<MessageExtras> for ProtoMessageExtras {
430 fn from(value: MessageExtras) -> Self {
431 Self {
432 headers: value
433 .headers
434 .unwrap_or_default()
435 .into_iter()
436 .map(|(k, v)| (k, v.into()))
437 .collect(),
438 ephemeral: value.ephemeral,
439 idempotency_key: value.idempotency_key,
440 echo: value.echo,
441 }
442 }
443}
444
445impl From<MessageExtras> for MsgpackMessageExtras {
446 fn from(value: MessageExtras) -> Self {
447 Self {
448 headers: value
449 .headers
450 .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
451 ephemeral: value.ephemeral,
452 idempotency_key: value.idempotency_key,
453 echo: value.echo,
454 }
455 }
456}
457
458impl From<ProtoMessageExtras> for MessageExtras {
459 fn from(value: ProtoMessageExtras) -> Self {
460 Self {
461 headers: (!value.headers.is_empty()).then_some(
462 value
463 .headers
464 .into_iter()
465 .map(|(k, v)| (k, v.into()))
466 .collect(),
467 ),
468 ephemeral: value.ephemeral,
469 idempotency_key: value.idempotency_key,
470 echo: value.echo,
471 }
472 }
473}
474
475impl From<MsgpackMessageExtras> for MessageExtras {
476 fn from(value: MsgpackMessageExtras) -> Self {
477 Self {
478 headers: value
479 .headers
480 .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
481 ephemeral: value.ephemeral,
482 idempotency_key: value.idempotency_key,
483 echo: value.echo,
484 }
485 }
486}
487
488impl From<ExtrasValue> for ProtoExtrasValue {
489 fn from(value: ExtrasValue) -> Self {
490 let kind = match value {
491 ExtrasValue::String(s) => Some(proto_extras_value::Kind::String(s)),
492 ExtrasValue::Number(n) => Some(proto_extras_value::Kind::Number(n)),
493 ExtrasValue::Bool(b) => Some(proto_extras_value::Kind::Bool(b)),
494 };
495 Self { kind }
496 }
497}
498
499impl From<ExtrasValue> for MsgpackExtrasValue {
500 fn from(value: ExtrasValue) -> Self {
501 match value {
502 ExtrasValue::String(s) => Self::String(s),
503 ExtrasValue::Number(n) => Self::Number(n),
504 ExtrasValue::Bool(b) => Self::Bool(b),
505 }
506 }
507}
508
509impl From<ProtoExtrasValue> for ExtrasValue {
510 fn from(value: ProtoExtrasValue) -> Self {
511 match value.kind {
512 Some(proto_extras_value::Kind::String(s)) => ExtrasValue::String(s),
513 Some(proto_extras_value::Kind::Number(n)) => ExtrasValue::Number(n),
514 Some(proto_extras_value::Kind::Bool(b)) => ExtrasValue::Bool(b),
515 None => ExtrasValue::String(String::new()),
516 }
517 }
518}
519
520impl From<MsgpackExtrasValue> for ExtrasValue {
521 fn from(value: MsgpackExtrasValue) -> Self {
522 match value {
523 MsgpackExtrasValue::String(s) => ExtrasValue::String(s),
524 MsgpackExtrasValue::Number(n) => ExtrasValue::Number(n),
525 MsgpackExtrasValue::Bool(b) => ExtrasValue::Bool(b),
526 }
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with every field populated except `name`, so the
    /// round-trip tests cover both present and absent optional fields.
    fn sample_message() -> PusherMessage {
        PusherMessage {
            event: Some("sockudo:test".to_string()),
            channel: Some("chat:room-1".to_string()),
            data: Some(MessageData::Json(sonic_rs::json!({
                "hello": "world",
                "count": 3,
                "nested": { "ok": true },
                "items": [1, 2, 3]
            }))),
            name: None,
            user_id: Some("user-1".to_string()),
            tags: Some(BTreeMap::from([
                ("region".to_string(), "eu".to_string()),
                ("tier".to_string(), "gold".to_string()),
            ])),
            sequence: Some(7),
            conflation_key: Some("room".to_string()),
            message_id: Some("mid-1".to_string()),
            stream_id: Some("stream-1".to_string()),
            serial: Some(9),
            idempotency_key: Some("idem-1".to_string()),
            extras: Some(MessageExtras {
                headers: Some(HashMap::from([
                    (
                        "priority".to_string(),
                        ExtrasValue::String("high".to_string()),
                    ),
                    ("ttl".to_string(), ExtrasValue::Number(5.0)),
                ])),
                ephemeral: Some(true),
                idempotency_key: Some("extra-idem".to_string()),
                echo: Some(false),
            }),
            delta_sequence: Some(11),
            delta_conflation_key: Some("btc".to_string()),
        }
    }

    /// Asserts that every field comparable without extra trait bounds on the
    /// project types survived a serialize/deserialize round trip.
    fn assert_fields_preserved(decoded: &PusherMessage, original: &PusherMessage) {
        assert_eq!(decoded.event, original.event);
        assert_eq!(decoded.channel, original.channel);
        assert_eq!(decoded.name, original.name);
        assert_eq!(decoded.user_id, original.user_id);
        assert_eq!(decoded.tags, original.tags);
        assert_eq!(decoded.sequence, original.sequence);
        assert_eq!(decoded.conflation_key, original.conflation_key);
        assert_eq!(decoded.message_id, original.message_id);
        assert_eq!(decoded.stream_id, original.stream_id);
        assert_eq!(decoded.serial, original.serial);
        assert_eq!(decoded.idempotency_key, original.idempotency_key);
        assert_eq!(decoded.delta_sequence, original.delta_sequence);
        assert_eq!(decoded.delta_conflation_key, original.delta_conflation_key);

        // The payload must keep its variant; its JSON content is not compared
        // because that would require `PartialEq` on the payload type.
        assert!(matches!(decoded.data, Some(MessageData::Json(_))));

        let decoded_extras = decoded
            .extras
            .as_ref()
            .expect("extras should survive the round trip");
        let original_extras = original
            .extras
            .as_ref()
            .expect("sample message carries extras");
        assert_eq!(decoded_extras.ephemeral, original_extras.ephemeral);
        assert_eq!(
            decoded_extras.idempotency_key,
            original_extras.idempotency_key
        );
        assert_eq!(decoded_extras.echo, original_extras.echo);
        assert_eq!(
            decoded_extras.headers.as_ref().map(|h| h.len()),
            original_extras.headers.as_ref().map(|h| h.len())
        );
    }

    #[test]
    fn round_trip_messagepack() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_fields_preserved(&decoded, &msg);
    }

    #[test]
    fn round_trip_protobuf() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_fields_preserved(&decoded, &msg);
    }

    #[test]
    fn deserialize_rejects_malformed_input() {
        assert!(deserialize_message(b"not json", WireFormat::Json).is_err());
        // Empty input cannot hold a MessagePack value.
        assert!(deserialize_message(&[], WireFormat::MessagePack).is_err());
        // A lone 0xff is a truncated varint, so the field key cannot be read.
        assert!(deserialize_message(&[0xff], WireFormat::Protobuf).is_err());
    }

    #[test]
    fn parse_query_param_accepts_known_values() {
        assert_eq!(
            WireFormat::parse_query_param(None).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("json")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("messagepack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("msgpack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("protobuf")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("proto")).unwrap(),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn parse_query_param_normalizes_case_and_whitespace() {
        assert_eq!(
            WireFormat::parse_query_param(Some("  MsgPack ")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("PROTOBUF")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("   ")).unwrap(),
            WireFormat::Json
        );
    }

    #[test]
    fn parse_query_param_rejects_unknown_value() {
        assert!(WireFormat::parse_query_param(Some("avro")).is_err());
        // The error names the normalized value, not the raw input.
        assert_eq!(
            WireFormat::parse_query_param(Some(" Avro ")).unwrap_err(),
            "unsupported wire format 'avro'"
        );
    }

    #[test]
    fn from_query_param_falls_back_to_json() {
        assert_eq!(WireFormat::from_query_param(Some("avro")), WireFormat::Json);
        assert_eq!(WireFormat::from_query_param(None), WireFormat::Json);
        assert_eq!(
            WireFormat::from_query_param(Some("msgpack")),
            WireFormat::MessagePack
        );
    }

    #[test]
    fn is_binary_is_false_only_for_json() {
        assert!(!WireFormat::Json.is_binary());
        assert!(WireFormat::MessagePack.is_binary());
        assert!(WireFormat::Protobuf.is_binary());
    }
}