1use std::borrow::Cow;
4use std::fmt::{Debug, Display};
5use std::marker::PhantomData;
6
7use serde::de::Visitor;
8use serde::ser::SerializeSeq;
9use serde::{Deserialize, Serialize};
10
/// An owned Phoenix channel message.
///
/// The fields appear in the same order as the five positional elements of a
/// channel frame: `[join_reference, message_reference, topic_name, event_name, payload]`.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message<P> {
    /// The join reference of the message, or `None` when the frame carried none.
    pub join_reference: Option<String>,
    /// The reference identifying this particular message.
    pub message_reference: String,
    /// The channel topic the message belongs to.
    pub topic_name: String,
    /// The name of the event carried by the message.
    pub event_name: String,
    /// The message body.
    pub payload: P,
}
36
37impl<'a, P> From<ChannelMsg<'a, P>> for Message<P> {
38 fn from(value: ChannelMsg<'a, P>) -> Self {
39 Self {
40 join_reference: value.join_reference.map(Cow::into),
41 message_reference: value.message_reference.into(),
42 topic_name: value.topic_name.into(),
43 event_name: value.event_name.into(),
44 payload: value.payload,
45 }
46 }
47}
48
49impl<P> Display for Message<P>
50where
51 P: Serialize + Debug,
52{
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "[")?;
55 ser_or_debug(&self.join_reference, f)?;
56 write!(f, ", ")?;
57 ser_or_debug(&self.message_reference, f)?;
58 write!(f, ", ")?;
59 ser_or_debug(&self.topic_name, f)?;
60 write!(f, ", ")?;
61 ser_or_debug(&self.event_name, f)?;
62 write!(f, ", ")?;
63 ser_or_debug(&self.payload, f)?;
64 write!(f, "]")
65 }
66}
67
68fn ser_or_debug<T>(v: &T, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
69where
70 T: Serialize + Debug,
71{
72 if let Ok(s) = serde_json::to_string(v) {
73 write!(f, "{s}")
74 } else {
75 write!(f, "{v:?}")
76 }
77}
78
/// A Phoenix channel message whose text fields may either borrow from some
/// input (lifetime `'a`) or own their data.
///
/// Its `Serialize`/`Deserialize` impls treat it as a 5-element sequence in
/// field order: `[join_reference, message_reference, topic_name, event_name, payload]`.
#[derive(Debug)]
pub(crate) struct ChannelMsg<'a, P> {
    /// The join reference, or `None` when the frame carried none.
    pub(crate) join_reference: Option<Cow<'a, str>>,
    /// The reference identifying this particular message.
    pub(crate) message_reference: Cow<'a, str>,
    /// The channel topic.
    pub(crate) topic_name: Cow<'a, str>,
    /// The event name.
    pub(crate) event_name: Cow<'a, str>,
    /// The message body.
    pub(crate) payload: P,
}
87
88impl<P> ChannelMsg<'_, P> {
89 pub(crate) fn into_err(self) -> Message<()> {
90 Message {
91 join_reference: self.join_reference.map(Cow::into),
92 message_reference: self.message_reference.into(),
93 topic_name: self.topic_name.into(),
94 event_name: self.event_name.into(),
95 payload: (),
96 }
97 }
98}
99
100impl<P> Serialize for ChannelMsg<'_, P>
101where
102 P: Serialize,
103{
104 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105 where
106 S: serde::Serializer,
107 {
108 let mut s = serializer.serialize_seq(Some(5))?;
109 s.serialize_element(&self.join_reference)?;
110 s.serialize_element(&self.message_reference)?;
111 s.serialize_element(&self.topic_name)?;
112 s.serialize_element(&self.event_name)?;
113 s.serialize_element(&self.payload)?;
114 s.end()
115 }
116}
117
118impl<'de, 'a, P> Deserialize<'de> for ChannelMsg<'a, P>
119where
120 P: Deserialize<'de>,
121{
122 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123 where
124 D: serde::Deserializer<'de>,
125 {
126 use serde::de::Error;
127
128 #[derive(Debug)]
129 struct ChannelMsgVisitor<'a, P> {
130 _marker: PhantomData<(Cow<'a, str>, P)>,
131 }
132
133 impl<'de, 'a, P> Visitor<'de> for ChannelMsgVisitor<'a, P>
134 where
135 P: Deserialize<'de>,
136 {
137 type Value = ChannelMsg<'a, P>;
138
139 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
140 write!(
141 formatter,
142 "a sequence of 5 elements for a valid Phoenix channel"
143 )
144 }
145
146 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
147 where
148 A: serde::de::SeqAccess<'de>,
149 {
150 if let Some(len) = seq.size_hint() {
151 if len != 5 {
152 return Err(A::Error::invalid_length(len, &"5"));
153 }
154 }
155
156 let Some(join_reference) = seq.next_element()? else {
157 return Err(A::Error::invalid_length(0, &"5"));
158 };
159 let Some(message_reference) = seq.next_element()? else {
160 return Err(A::Error::invalid_length(1, &"5"));
161 };
162 let Some(topic_name) = seq.next_element()? else {
163 return Err(A::Error::invalid_length(2, &"5"));
164 };
165 let Some(event_name) = seq.next_element()? else {
166 return Err(A::Error::invalid_length(3, &"5"));
167 };
168 let Some(payload) = seq.next_element()? else {
169 return Err(A::Error::invalid_length(4, &"5"));
170 };
171
172 Ok(ChannelMsg::<P> {
173 join_reference,
174 message_reference,
175 topic_name,
176 event_name,
177 payload,
178 })
179 }
180 }
181
182 deserializer.deserialize_seq(ChannelMsgVisitor::<'a, P> {
183 _marker: PhantomData,
184 })
185 }
186}
187
188impl<P> Display for ChannelMsg<'_, P>
189where
190 P: Serialize + Debug,
191{
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 let Ok(s) = serde_json::to_string(self) else {
194 return write!(
195 f,
196 "[{:?}, {:?}, {:?}, {:?}, {:?}]",
197 self.join_reference,
198 self.message_reference,
199 self.topic_name,
200 self.event_name,
201 self.payload,
202 );
203 };
204
205 write!(f, "{s}")
206 }
207}