phoenix_chan/
message.rs

1//! Messages sent from and to the phoenix channel.
2
3use std::borrow::Cow;
4use std::fmt::{Debug, Display};
5use std::marker::PhantomData;
6
7use serde::de::Visitor;
8use serde::ser::SerializeSeq;
9use serde::{Deserialize, Serialize};
10
/// Message received from the channel, with every string field owned.
///
/// Its `Display` form is the five-element array
/// `[join_reference, message_reference, topic_name, event_name, payload]`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Message<P> {
    /// Reference chosen by the client when joining a topic; like `message_reference`, it should
    /// be a unique value.
    ///
    /// It is only required for a `phx_join` event and may be `None` for any other message. The
    /// server uses it to tag push messages that are not replies to a specific client message,
    /// e.g. "a new user just joined the chat room".
    pub join_reference: Option<String>,
    /// Unique reference chosen by the client.
    ///
    /// The server echoes it back in its reply, which lets the client match the reply to the
    /// message it answers.
    pub message_reference: String,
    /// Topic of the message. It must be a topic known to the socket endpoint, and the client has
    /// to join it before sending any messages on it.
    pub topic_name: String,
    /// Name of the event. It must match the first argument of a `handle_in` function in the
    /// server's channel module.
    pub event_name: String,
    /// Event payload. It should be a map and is handed to that `handle_in` function as its
    /// second argument.
    pub payload: P,
}
36
37impl<'a, P> From<ChannelMsg<'a, P>> for Message<P> {
38    fn from(value: ChannelMsg<'a, P>) -> Self {
39        Self {
40            join_reference: value.join_reference.map(Cow::into),
41            message_reference: value.message_reference.into(),
42            topic_name: value.topic_name.into(),
43            event_name: value.event_name.into(),
44            payload: value.payload,
45        }
46    }
47}
48
49impl<P> Display for Message<P>
50where
51    P: Serialize + Debug,
52{
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "[")?;
55        ser_or_debug(&self.join_reference, f)?;
56        write!(f, ", ")?;
57        ser_or_debug(&self.message_reference, f)?;
58        write!(f, ", ")?;
59        ser_or_debug(&self.topic_name, f)?;
60        write!(f, ", ")?;
61        ser_or_debug(&self.event_name, f)?;
62        write!(f, ", ")?;
63        ser_or_debug(&self.payload, f)?;
64        write!(f, "]")
65    }
66}
67
68fn ser_or_debug<T>(v: &T, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
69where
70    T: Serialize + Debug,
71{
72    if let Ok(s) = serde_json::to_string(v) {
73        write!(f, "{s}")
74    } else {
75        write!(f, "{v:?}")
76    }
77}
78
/// Wire-level representation of a Phoenix channel message.
///
/// Serialized as the five-element sequence
/// `[join_reference, message_reference, topic_name, event_name, payload]`. The string fields
/// are `Cow`s, so a message can either borrow or own its strings.
#[derive(Debug)]
pub(crate) struct ChannelMsg<'a, P> {
    /// Join reference; `None` when the message carries none.
    pub(crate) join_reference: Option<Cow<'a, str>>,
    /// Client-chosen reference used to match server replies to requests.
    pub(crate) message_reference: Cow<'a, str>,
    /// Channel topic the message belongs to.
    pub(crate) topic_name: Cow<'a, str>,
    /// Name of the event.
    pub(crate) event_name: Cow<'a, str>,
    /// Event payload.
    pub(crate) payload: P,
}
87
88impl<P> ChannelMsg<'_, P> {
89    pub(crate) fn into_err(self) -> Message<()> {
90        Message {
91            join_reference: self.join_reference.map(Cow::into),
92            message_reference: self.message_reference.into(),
93            topic_name: self.topic_name.into(),
94            event_name: self.event_name.into(),
95            payload: (),
96        }
97    }
98}
99
100impl<P> Serialize for ChannelMsg<'_, P>
101where
102    P: Serialize,
103{
104    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105    where
106        S: serde::Serializer,
107    {
108        let mut s = serializer.serialize_seq(Some(5))?;
109        s.serialize_element(&self.join_reference)?;
110        s.serialize_element(&self.message_reference)?;
111        s.serialize_element(&self.topic_name)?;
112        s.serialize_element(&self.event_name)?;
113        s.serialize_element(&self.payload)?;
114        s.end()
115    }
116}
117
118impl<'de, 'a, P> Deserialize<'de> for ChannelMsg<'a, P>
119where
120    P: Deserialize<'de>,
121{
122    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123    where
124        D: serde::Deserializer<'de>,
125    {
126        use serde::de::Error;
127
128        #[derive(Debug)]
129        struct ChannelMsgVisitor<'a, P> {
130            _marker: PhantomData<(Cow<'a, str>, P)>,
131        }
132
133        impl<'de, 'a, P> Visitor<'de> for ChannelMsgVisitor<'a, P>
134        where
135            P: Deserialize<'de>,
136        {
137            type Value = ChannelMsg<'a, P>;
138
139            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
140                write!(
141                    formatter,
142                    "a sequence of 5 elements for a valid Phoenix channel"
143                )
144            }
145
146            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
147            where
148                A: serde::de::SeqAccess<'de>,
149            {
150                if let Some(len) = seq.size_hint() {
151                    if len != 5 {
152                        return Err(A::Error::invalid_length(len, &"5"));
153                    }
154                }
155
156                let Some(join_reference) = seq.next_element()? else {
157                    return Err(A::Error::invalid_length(0, &"5"));
158                };
159                let Some(message_reference) = seq.next_element()? else {
160                    return Err(A::Error::invalid_length(1, &"5"));
161                };
162                let Some(topic_name) = seq.next_element()? else {
163                    return Err(A::Error::invalid_length(2, &"5"));
164                };
165                let Some(event_name) = seq.next_element()? else {
166                    return Err(A::Error::invalid_length(3, &"5"));
167                };
168                let Some(payload) = seq.next_element()? else {
169                    return Err(A::Error::invalid_length(4, &"5"));
170                };
171
172                Ok(ChannelMsg::<P> {
173                    join_reference,
174                    message_reference,
175                    topic_name,
176                    event_name,
177                    payload,
178                })
179            }
180        }
181
182        deserializer.deserialize_seq(ChannelMsgVisitor::<'a, P> {
183            _marker: PhantomData,
184        })
185    }
186}
187
188impl<P> Display for ChannelMsg<'_, P>
189where
190    P: Serialize + Debug,
191{
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        let Ok(s) = serde_json::to_string(self) else {
194            return write!(
195                f,
196                "[{:?}, {:?}, {:?}, {:?}, {:?}]",
197                self.join_reference,
198                self.message_reference,
199                self.topic_name,
200                self.event_name,
201                self.payload,
202            );
203        };
204
205        write!(f, "{s}")
206    }
207}