glimesh_protocol/socket.rs

1use bimap::BiMap;
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use serde_json::Value;
4use serde_tuple::{Deserialize_tuple, Serialize_tuple};
5use snafu::{ResultExt, Snafu};
6use std::{collections::HashMap, hash::Hash};
7use uuid::Uuid;
8
/// Phoenix transport topic; heartbeats (`SocketSession::ping`) are sent here.
const TOPIC_PHOENIX: &str = "phoenix";
/// Absinthe control channel topic, used for the join, `doc` (subscribe) and
/// `unsubscribe` events.
const TOPIC_ABSINTHE_CONTROL: &str = "__absinthe__:control";
11
/// Outgoing Phoenix channel frame.
///
/// `Serialize_tuple` encodes the fields positionally (a JSON array with
/// serde_json) as `[join_ref, msg_ref, topic, event, payload]`, so the field
/// order below is the wire format and must not be changed.
#[derive(Debug, Clone, Serialize_tuple)]
struct SendPhoenixMessage<T: Serialize> {
    /// The session's join reference; every frame from one session carries the same value.
    join_ref: Uuid,
    /// Per-frame reference; replies are matched back to requests by this value.
    msg_ref: Uuid,
    /// Channel topic, one of the `TOPIC_*` constants.
    topic: String,
    /// Phoenix event name.
    event: SendEvent,
    /// Event-specific body.
    payload: T,
}
20
/// Incoming Phoenix channel frame, decoded positionally from
/// `[join_ref, msg_ref, topic, event, payload]` by `Deserialize_tuple`.
///
/// Field order is the wire format and must not be changed.
#[derive(Debug, Clone, Deserialize_tuple)]
struct ReceivePhoenixMessage<T: DeserializeOwned> {
    // Decoded only to keep the positional layout aligned; never read, so the
    // dead-code lint is silenced (the value still shows up in `Debug` logs).
    #[allow(unused)]
    join_ref: Option<Uuid>,
    /// Reference of the request this frame replies to. It may be `null`;
    /// presumably for server-initiated pushes such as `subscription:data`.
    msg_ref: Option<Uuid>,
    // Never read (see `join_ref` above); kept for layout and logging.
    #[allow(unused)]
    topic: String,
    /// Phoenix event name.
    event: ReceiveEvent,
    /// Event-specific body; decoded further once `event` is known.
    payload: T,
}
31
/// Payload for frames that carry no data (the join and the heartbeat).
///
/// NOTE(review): as a unit struct this serializes to JSON `null` with
/// serde_json, not `{}`. Confirm the server accepts that.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Empty;
34
/// Payload of a `phx_reply` frame: a request-specific `response` body plus
/// the server's `status` string.
///
/// NOTE(review): `status` is currently never inspected; a non-ok reply is
/// only noticed if `response` fails to decode.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PhxReply<T> {
    response: T,
    status: String,
}
40
/// The `response` expected in the reply to a `doc` (subscribe) request.
///
/// On the wire the field is `subscriptionId` (camelCase). The id is later used
/// to route `subscription:data` pushes and to unsubscribe.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct DocumentSubscribeResponse {
    subscription_id: String,
}
46
/// Payload of a `subscription:data` push.
///
/// Carries the `result` JSON for the subscription named by `subscriptionId`
/// (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SubscriptionEvent {
    result: serde_json::Value,
    subscription_id: String,
}
53
/// Payload of an `unsubscribe` event, naming the server-side subscription
/// (`subscriptionId` on the wire) to cancel.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UnsubscribePayload {
    subscription_id: String,
}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61#[serde(rename_all = "snake_case")]
62enum SendEvent {
63    PhxJoin,
64    Heartbeat,
65    Doc,
66    Unsubscribe,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "snake_case")]
71enum ReceiveEvent {
72    PhxReply,
73    PhxError,
74
75    #[serde(rename = "subscription:data")]
76    SubscriptionData,
77}
78
/// Kind of request awaiting a reply, remembered so the reply can be interpreted.
///
/// `T` is the caller's subscription reference.
enum Request<T> {
    /// Join of the Absinthe control channel.
    Join,
    /// Heartbeat.
    Ping,
    /// Subscribe; the reply's subscription id is bound to this reference.
    Subscribe(T),
    /// Unsubscribe of the subscription bound to this reference.
    Unsubscribe(T),
}
85
86pub enum Event<T> {
87    Joined,
88    Pong,
89    Document(T, Value),
90}
91
92pub enum SessionResult<T> {
93    Event(Event<T>),
94    Message(String),
95}
96
97impl<T> From<Event<T>> for SessionResult<T> {
98    fn from(evt: Event<T>) -> Self {
99        Self::Event(evt)
100    }
101}
102
/// A request issued before the session joined, held until the join reply arrives.
///
/// When released, `msg_ref` and `request` move into the in-flight table and
/// `message` is handed to the caller to send.
struct PendingRequest<T> {
    /// Reference the eventual reply will carry.
    msg_ref: Uuid,
    /// The already-serialized frame.
    message: String,
    /// What the reply should be interpreted as.
    request: Request<T>,
}
108
/// Protocol state for an Absinthe-over-Phoenix websocket session.
///
/// The session performs no I/O itself. Its methods return
/// [`SessionResult::Message`] frames for the caller to send, and
/// `handle_message` consumes frames the caller received.
///
/// `T` is the caller's subscription reference: it must be `Hash + Eq` to key
/// the subscription map and `Clone` so it can be handed out in
/// [`Event::Document`].
pub struct SocketSession<T>
where
    T: Clone + Hash + Eq,
{
    /// Sent as `join_ref` on every outgoing frame.
    join_ref: Uuid,
    /// Requests issued before `joined` became true, in issue order.
    pending_requests: Vec<PendingRequest<T>>,
    /// Sent requests awaiting a `phx_reply`, keyed by `msg_ref`.
    in_flight_requests: HashMap<Uuid, Request<T>>,
    /// Server subscription id <-> caller reference.
    subscriptions: BiMap<String, T>,
    /// Set once the control-channel join reply has been received.
    joined: bool,
}
119
120impl<T> SocketSession<T>
121where
122    T: Clone + Hash + Eq,
123{
124    pub fn new() -> Self {
125        Self {
126            join_ref: Uuid::new_v4(),
127            pending_requests: Vec::new(),
128            in_flight_requests: Default::default(),
129            subscriptions: Default::default(),
130            joined: false,
131        }
132    }
133
134    fn send_request<V>(
135        &mut self,
136        request: Request<T>,
137        message: SendPhoenixMessage<V>,
138    ) -> Vec<SessionResult<T>>
139    where
140        V: Serialize,
141    {
142        let message_str = serde_json::to_string(&message).unwrap();
143
144        if self.joined || matches!(request, Request::Join) {
145            self.in_flight_requests.insert(message.msg_ref, request);
146            vec![SessionResult::Message(message_str)]
147        } else {
148            self.pending_requests.push(PendingRequest {
149                msg_ref: message.msg_ref,
150                message: message_str,
151                request,
152            });
153            vec![]
154        }
155    }
156
157    pub fn join(&mut self) -> Vec<SessionResult<T>> {
158        self.send_request(
159            Request::Join,
160            SendPhoenixMessage {
161                join_ref: self.join_ref,
162                msg_ref: Uuid::new_v4(),
163                topic: TOPIC_ABSINTHE_CONTROL.into(),
164                event: SendEvent::PhxJoin,
165                payload: Empty,
166            },
167        )
168    }
169
170    pub fn ping(&mut self) -> Vec<SessionResult<T>> {
171        self.send_request(
172            Request::Ping,
173            SendPhoenixMessage {
174                join_ref: self.join_ref,
175                msg_ref: Uuid::new_v4(),
176                topic: TOPIC_PHOENIX.into(),
177                event: SendEvent::Heartbeat,
178                payload: Empty,
179            },
180        )
181    }
182
183    pub fn subscribe<B>(&mut self, reference: T, body: B) -> Vec<SessionResult<T>>
184    where
185        B: Serialize,
186    {
187        self.send_request(
188            Request::Subscribe(reference),
189            SendPhoenixMessage {
190                join_ref: self.join_ref,
191                msg_ref: Uuid::new_v4(),
192                topic: TOPIC_ABSINTHE_CONTROL.into(),
193                event: SendEvent::Doc,
194                payload: body,
195            },
196        )
197    }
198
199    pub fn unsubscribe(&mut self, reference: T) -> Vec<SessionResult<T>> {
200        match self.subscriptions.remove_by_right(&reference) {
201            Some((subscription_id, _)) => self.send_request(
202                Request::Unsubscribe(reference),
203                SendPhoenixMessage {
204                    join_ref: self.join_ref,
205                    msg_ref: Uuid::new_v4(),
206                    topic: TOPIC_ABSINTHE_CONTROL.into(),
207                    event: SendEvent::Unsubscribe,
208                    payload: UnsubscribePayload { subscription_id },
209                },
210            ),
211            None => {
212                vec![]
213            }
214        }
215    }
216
217    pub fn handle_message(
218        &mut self,
219        msg: &str,
220    ) -> Result<Vec<SessionResult<T>>, HandleMessageError> {
221        let message: ReceivePhoenixMessage<Value> =
222            serde_json::from_str(msg).context(DeserializeSnafu {})?;
223
224        let results = match message.event {
225            ReceiveEvent::PhxReply => match message.msg_ref {
226                Some(msg_ref) => match self.in_flight_requests.remove(&msg_ref) {
227                    Some(Request::Join) => {
228                        self.joined = true;
229                        let mut results = Vec::with_capacity(self.pending_requests.len() + 1);
230                        results.push(Event::Joined.into());
231                        for pending in self.pending_requests.drain(..) {
232                            self.in_flight_requests
233                                .insert(pending.msg_ref, pending.request);
234                            results.push(SessionResult::Message(pending.message));
235                        }
236                        results
237                    }
238                    Some(Request::Ping) => {
239                        vec![Event::Pong.into()]
240                    }
241                    Some(Request::Subscribe(reference)) => {
242                        let reply: PhxReply<DocumentSubscribeResponse> =
243                            serde_json::from_value(message.payload).context(DeserializeSnafu {})?;
244                        self.subscriptions
245                            .insert(reply.response.subscription_id, reference);
246                        vec![]
247                    }
248                    Some(Request::Unsubscribe(_)) => {
249                        vec![]
250                    }
251                    None => {
252                        tracing::warn!(?message, "received a reply to a request we didn't make");
253                        vec![]
254                    }
255                },
256                None => {
257                    tracing::warn!(
258                        ?message,
259                        "received a reply that cannot be matched to a request"
260                    );
261                    vec![]
262                }
263            },
264            ReceiveEvent::PhxError => {
265                tracing::error!(?message, "got phx_error");
266                // TODO
267                vec![]
268            }
269            ReceiveEvent::SubscriptionData => {
270                let data: SubscriptionEvent =
271                    serde_json::from_value(message.payload).context(DeserializeSnafu {})?;
272                match self.subscriptions.get_by_left(&data.subscription_id) {
273                    Some(reference) => {
274                        vec![Event::Document(reference.clone(), data.result).into()]
275                    }
276                    None => {
277                        tracing::warn!(
278                            ?data,
279                            "received subscription data for a subscription we are not tracking"
280                        );
281                        vec![]
282                    }
283                }
284            }
285        };
286
287        Ok(results)
288    }
289}
290
291impl<T> Default for SocketSession<T>
292where
293    T: Clone + Hash + Eq,
294{
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
// Error returned by `SocketSession::handle_message`.
//
// Plain `//` comments are used on purpose: snafu can use `///` doc comments
// on variants as the `Display` message, which would change the error text.
#[derive(Debug, Snafu)]
pub enum HandleMessageError {
    // An incoming frame, or one of its payloads, could not be decoded into the
    // expected shape.
    DeserializeError { source: serde_json::Error },
}