1use bimap::BiMap;
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use serde_json::Value;
4use serde_tuple::{Deserialize_tuple, Serialize_tuple};
5use snafu::{ResultExt, Snafu};
6use std::{collections::HashMap, hash::Hash};
7use uuid::Uuid;
8
/// Topic of the Absinthe control channel; joins, `doc` (subscribe) and
/// `unsubscribe` requests are addressed to it.
const TOPIC_ABSINTHE_CONTROL: &str = "__absinthe__:control";
/// Phoenix's socket-level topic, used here for heartbeats.
const TOPIC_PHOENIX: &str = "phoenix";
11
12#[derive(Debug, Clone, Serialize_tuple)]
13struct SendPhoenixMessage<T: Serialize> {
14 join_ref: Uuid,
15 msg_ref: Uuid,
16 topic: String,
17 event: SendEvent,
18 payload: T,
19}
20
21#[derive(Debug, Clone, Deserialize_tuple)]
22struct ReceivePhoenixMessage<T: DeserializeOwned> {
23 #[allow(unused)]
24 join_ref: Option<Uuid>,
25 msg_ref: Option<Uuid>,
26 #[allow(unused)]
27 topic: String,
28 event: ReceiveEvent,
29 payload: T,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33struct Empty;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36struct PhxReply<T> {
37 response: T,
38 status: String,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42#[serde(rename_all = "camelCase")]
43struct DocumentSubscribeResponse {
44 subscription_id: String,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49struct SubscriptionEvent {
50 result: serde_json::Value,
51 subscription_id: String,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56struct UnsubscribePayload {
57 subscription_id: String,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61#[serde(rename_all = "snake_case")]
62enum SendEvent {
63 PhxJoin,
64 Heartbeat,
65 Doc,
66 Unsubscribe,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "snake_case")]
71enum ReceiveEvent {
72 PhxReply,
73 PhxError,
74
75 #[serde(rename = "subscription:data")]
76 SubscriptionData,
77}
78
/// What a request sent to the server was for, so that its reply can be interpreted.
///
/// `T` is the caller's subscription reference type.
enum Request<T> {
    /// `phx_join` on the control topic.
    Join,
    /// Heartbeat.
    Ping,
    /// `doc` request that creates the subscription named by the reference.
    Subscribe(T),
    /// `unsubscribe` request for the subscription named by the reference.
    Unsubscribe(T),
}
85
86pub enum Event<T> {
87 Joined,
88 Pong,
89 Document(T, Value),
90}
91
92pub enum SessionResult<T> {
93 Event(Event<T>),
94 Message(String),
95}
96
97impl<T> From<Event<T>> for SessionResult<T> {
98 fn from(evt: Event<T>) -> Self {
99 Self::Event(evt)
100 }
101}
102
103struct PendingRequest<T> {
104 msg_ref: Uuid,
105 message: String,
106 request: Request<T>,
107}
108
109pub struct SocketSession<T>
110where
111 T: Clone + Hash + Eq,
112{
113 join_ref: Uuid,
114 pending_requests: Vec<PendingRequest<T>>,
115 in_flight_requests: HashMap<Uuid, Request<T>>,
116 subscriptions: BiMap<String, T>,
117 joined: bool,
118}
119
120impl<T> SocketSession<T>
121where
122 T: Clone + Hash + Eq,
123{
124 pub fn new() -> Self {
125 Self {
126 join_ref: Uuid::new_v4(),
127 pending_requests: Vec::new(),
128 in_flight_requests: Default::default(),
129 subscriptions: Default::default(),
130 joined: false,
131 }
132 }
133
134 fn send_request<V>(
135 &mut self,
136 request: Request<T>,
137 message: SendPhoenixMessage<V>,
138 ) -> Vec<SessionResult<T>>
139 where
140 V: Serialize,
141 {
142 let message_str = serde_json::to_string(&message).unwrap();
143
144 if self.joined || matches!(request, Request::Join) {
145 self.in_flight_requests.insert(message.msg_ref, request);
146 vec![SessionResult::Message(message_str)]
147 } else {
148 self.pending_requests.push(PendingRequest {
149 msg_ref: message.msg_ref,
150 message: message_str,
151 request,
152 });
153 vec![]
154 }
155 }
156
157 pub fn join(&mut self) -> Vec<SessionResult<T>> {
158 self.send_request(
159 Request::Join,
160 SendPhoenixMessage {
161 join_ref: self.join_ref,
162 msg_ref: Uuid::new_v4(),
163 topic: TOPIC_ABSINTHE_CONTROL.into(),
164 event: SendEvent::PhxJoin,
165 payload: Empty,
166 },
167 )
168 }
169
170 pub fn ping(&mut self) -> Vec<SessionResult<T>> {
171 self.send_request(
172 Request::Ping,
173 SendPhoenixMessage {
174 join_ref: self.join_ref,
175 msg_ref: Uuid::new_v4(),
176 topic: TOPIC_PHOENIX.into(),
177 event: SendEvent::Heartbeat,
178 payload: Empty,
179 },
180 )
181 }
182
183 pub fn subscribe<B>(&mut self, reference: T, body: B) -> Vec<SessionResult<T>>
184 where
185 B: Serialize,
186 {
187 self.send_request(
188 Request::Subscribe(reference),
189 SendPhoenixMessage {
190 join_ref: self.join_ref,
191 msg_ref: Uuid::new_v4(),
192 topic: TOPIC_ABSINTHE_CONTROL.into(),
193 event: SendEvent::Doc,
194 payload: body,
195 },
196 )
197 }
198
199 pub fn unsubscribe(&mut self, reference: T) -> Vec<SessionResult<T>> {
200 match self.subscriptions.remove_by_right(&reference) {
201 Some((subscription_id, _)) => self.send_request(
202 Request::Unsubscribe(reference),
203 SendPhoenixMessage {
204 join_ref: self.join_ref,
205 msg_ref: Uuid::new_v4(),
206 topic: TOPIC_ABSINTHE_CONTROL.into(),
207 event: SendEvent::Unsubscribe,
208 payload: UnsubscribePayload { subscription_id },
209 },
210 ),
211 None => {
212 vec![]
213 }
214 }
215 }
216
217 pub fn handle_message(
218 &mut self,
219 msg: &str,
220 ) -> Result<Vec<SessionResult<T>>, HandleMessageError> {
221 let message: ReceivePhoenixMessage<Value> =
222 serde_json::from_str(msg).context(DeserializeSnafu {})?;
223
224 let results = match message.event {
225 ReceiveEvent::PhxReply => match message.msg_ref {
226 Some(msg_ref) => match self.in_flight_requests.remove(&msg_ref) {
227 Some(Request::Join) => {
228 self.joined = true;
229 let mut results = Vec::with_capacity(self.pending_requests.len() + 1);
230 results.push(Event::Joined.into());
231 for pending in self.pending_requests.drain(..) {
232 self.in_flight_requests
233 .insert(pending.msg_ref, pending.request);
234 results.push(SessionResult::Message(pending.message));
235 }
236 results
237 }
238 Some(Request::Ping) => {
239 vec![Event::Pong.into()]
240 }
241 Some(Request::Subscribe(reference)) => {
242 let reply: PhxReply<DocumentSubscribeResponse> =
243 serde_json::from_value(message.payload).context(DeserializeSnafu {})?;
244 self.subscriptions
245 .insert(reply.response.subscription_id, reference);
246 vec![]
247 }
248 Some(Request::Unsubscribe(_)) => {
249 vec![]
250 }
251 None => {
252 tracing::warn!(?message, "received a reply to a request we didn't make");
253 vec![]
254 }
255 },
256 None => {
257 tracing::warn!(
258 ?message,
259 "received a reply that cannot be matched to a request"
260 );
261 vec![]
262 }
263 },
264 ReceiveEvent::PhxError => {
265 tracing::error!(?message, "got phx_error");
266 vec![]
268 }
269 ReceiveEvent::SubscriptionData => {
270 let data: SubscriptionEvent =
271 serde_json::from_value(message.payload).context(DeserializeSnafu {})?;
272 match self.subscriptions.get_by_left(&data.subscription_id) {
273 Some(reference) => {
274 vec![Event::Document(reference.clone(), data.result).into()]
275 }
276 None => {
277 tracing::warn!(
278 ?data,
279 "received subscription data for a subscription we are not tracking"
280 );
281 vec![]
282 }
283 }
284 }
285 };
286
287 Ok(results)
288 }
289}
290
291impl<T> Default for SocketSession<T>
292where
293 T: Clone + Hash + Eq,
294{
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
300#[derive(Debug, Snafu)]
301pub enum HandleMessageError {
302 DeserializeError { source: serde_json::Error },
303}