1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures_util::{SinkExt, StreamExt};
6use pondsocket_common::{
7 ChannelEvent, ChannelState, ClientAction, ClientMessage, EventName, JoinParams, PondMessage,
8 PondPresence, PresenceEventType, PresenceMessage, ServerAction, ServerMessage, uuid,
9};
10use serde_json::{Map, Value};
11use thiserror::Error;
12use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tokio_tungstenite::connect_async;
15use tokio_tungstenite::tungstenite::Message;
16use url::Url;
17
/// Lifecycle of the client's websocket connection, as published on the
/// client's state watch channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// A connection attempt has been started and has not finished yet.
    Connecting,
    /// The socket is open and the reader/writer tasks have been started.
    Connected,
    /// No socket is open; this is the state a new client starts in.
    Disconnected,
}
24
/// Tunable limits applied by the client and its channels.
#[derive(Clone, Debug)]
pub struct ClientOptions {
    /// Longest time the websocket handshake may take before `connect` gives up.
    pub connection_timeout: Duration,
    /// Wait applied by `send_for_response` when the caller passes no timeout.
    pub response_timeout: Duration,
    /// Capacity of the outbound message channel and upper bound on the number
    /// of messages a channel keeps queued while it cannot send.
    pub max_queue_size: usize,
}
31
32impl Default for ClientOptions {
33 fn default() -> Self {
34 Self {
35 connection_timeout: Duration::from_secs(10),
36 response_timeout: Duration::from_secs(5),
37 max_queue_size: 100,
38 }
39 }
40}
41
42#[derive(Debug, Error)]
43pub enum ClientError {
44 #[error("invalid websocket URL: {0}")]
45 Url(#[from] url::ParseError),
46 #[error("unsupported URL scheme: {0}")]
47 UnsupportedScheme(String),
48 #[error("websocket error: {0}")]
49 WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
50 #[error("serialization error: {0}")]
51 Serialization(#[from] serde_json::Error),
52 #[error("connection timed out")]
53 ConnectionTimeout,
54 #[error("client is not connected")]
55 NotConnected,
56 #[error("channel is closed")]
57 ChannelClosed,
58 #[error("response timed out")]
59 ResponseTimeout,
60}
61
62type Result<T> = std::result::Result<T, ClientError>;
63
64#[derive(Clone)]
65pub struct PondClient {
66 inner: Arc<ClientInner>,
67}
68
69struct ClientInner {
70 url: String,
71 options: ClientOptions,
72 state: watch::Sender<ConnectionState>,
73 channels: Mutex<HashMap<String, Channel>>,
74 outbound: Mutex<Option<mpsc::Sender<ClientMessage>>>,
75 read_task: Mutex<Option<JoinHandle<()>>>,
76 write_task: Mutex<Option<JoinHandle<()>>>,
77}
78
79#[derive(Clone)]
80pub struct Channel {
81 inner: Arc<ChannelInner>,
82}
83
84struct ChannelInner {
85 name: String,
86 params: JoinParams,
87 client: Arc<ClientInner>,
88 state: watch::Sender<ChannelState>,
89 events: broadcast::Sender<ChannelEvent>,
90 presence: Mutex<Vec<PondPresence>>,
91 queue: Mutex<VecDeque<ClientMessage>>,
92 pending: Mutex<HashMap<String, oneshot::Sender<PondMessage>>>,
93 closed: Mutex<bool>,
94}
95
96impl PondClient {
97 pub fn new(endpoint: impl AsRef<str>, params: Option<JoinParams>) -> Result<Self> {
98 Self::with_options(endpoint, params, ClientOptions::default())
99 }
100
101 pub fn with_options(
102 endpoint: impl AsRef<str>,
103 params: Option<JoinParams>,
104 options: ClientOptions,
105 ) -> Result<Self> {
106 let url = resolve_url(endpoint.as_ref(), params.as_ref())?;
107 let (state, _) = watch::channel(ConnectionState::Disconnected);
108
109 Ok(Self {
110 inner: Arc::new(ClientInner {
111 url,
112 options,
113 state,
114 channels: Mutex::new(HashMap::new()),
115 outbound: Mutex::new(None),
116 read_task: Mutex::new(None),
117 write_task: Mutex::new(None),
118 }),
119 })
120 }
121
122 pub fn state(&self) -> ConnectionState {
123 *self.inner.state.borrow()
124 }
125
126 pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
127 self.inner.state.subscribe()
128 }
129
130 pub async fn create_channel(
131 &self,
132 name: impl Into<String>,
133 params: Option<JoinParams>,
134 ) -> Channel {
135 let name = name.into();
136 let mut channels = self.inner.channels.lock().await;
137 if let Some(channel) = channels.get(&name) {
138 if channel.state() != ChannelState::Closed && channel.state() != ChannelState::Declined
139 {
140 return channel.clone();
141 }
142 }
143
144 let (state, _) = watch::channel(ChannelState::Idle);
145 let (events, _) = broadcast::channel(100);
146 let channel = Channel {
147 inner: Arc::new(ChannelInner {
148 name: name.clone(),
149 params: params.unwrap_or_default(),
150 client: Arc::clone(&self.inner),
151 state,
152 events,
153 presence: Mutex::new(Vec::new()),
154 queue: Mutex::new(VecDeque::new()),
155 pending: Mutex::new(HashMap::new()),
156 closed: Mutex::new(false),
157 }),
158 };
159 channels.insert(name, channel.clone());
160 channel
161 }
162
163 pub async fn connect(&self) -> Result<()> {
164 if self.state() != ConnectionState::Disconnected {
165 return Ok(());
166 }
167 self.inner.state.send_replace(ConnectionState::Connecting);
168 let connect = connect_async(&self.inner.url);
169 let (socket, _) = tokio::time::timeout(self.inner.options.connection_timeout, connect)
170 .await
171 .map_err(|_| ClientError::ConnectionTimeout)??;
172 let (mut writer, mut reader) = socket.split();
173 let (tx, mut rx) = mpsc::channel::<ClientMessage>(self.inner.options.max_queue_size);
174 *self.inner.outbound.lock().await = Some(tx);
175
176 let write_task = tokio::spawn(async move {
177 while let Some(message) = rx.recv().await {
178 let Ok(text) = serde_json::to_string(&message) else {
179 continue;
180 };
181 if writer.send(Message::Text(text.into())).await.is_err() {
182 break;
183 }
184 }
185 let _ = writer.close().await;
186 });
187
188 let inner = Arc::clone(&self.inner);
189 let read_task = tokio::spawn(async move {
190 while let Some(frame) = reader.next().await {
191 let text = match frame {
192 Ok(Message::Text(text)) => text.to_string(),
193 Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
194 Ok(text) => text,
195 Err(_) => continue,
196 },
197 Ok(Message::Close(_)) => break,
198 Ok(_) => continue,
199 Err(_) => break,
200 };
201 let Ok(event) = pondsocket_common::parse_channel_event(&text) else {
202 continue;
203 };
204 inner.route_event(event).await;
205 }
206 inner.state.send_replace(ConnectionState::Disconnected);
207 *inner.outbound.lock().await = None;
208 });
209
210 *self.inner.read_task.lock().await = Some(read_task);
211 *self.inner.write_task.lock().await = Some(write_task);
212 self.inner.state.send_replace(ConnectionState::Connected);
213 self.inner.rejoin_stalled_channels().await;
214 Ok(())
215 }
216
217 pub async fn disconnect(&self) {
218 if let Some(task) = self.inner.read_task.lock().await.take() {
219 task.abort();
220 }
221 if let Some(task) = self.inner.write_task.lock().await.take() {
222 task.abort();
223 }
224 *self.inner.outbound.lock().await = None;
225 self.inner.state.send_replace(ConnectionState::Disconnected);
226 let channels: Vec<Channel> = self.inner.channels.lock().await.values().cloned().collect();
227 for channel in channels {
228 channel.force_close().await;
229 }
230 self.inner.channels.lock().await.clear();
231 }
232}
233
234impl ClientInner {
235 async fn publish(&self, message: ClientMessage) -> Result<()> {
236 let tx = self
237 .outbound
238 .lock()
239 .await
240 .clone()
241 .ok_or(ClientError::NotConnected)?;
242 tx.send(message)
243 .await
244 .map_err(|_| ClientError::NotConnected)
245 }
246
247 async fn route_event(&self, event: ChannelEvent) {
248 let channel_name = match &event {
249 ChannelEvent::Message(message) => &message.channel_name,
250 ChannelEvent::Presence(message) => &message.channel_name,
251 };
252 let channel = self.channels.lock().await.get(channel_name).cloned();
253 if let Some(channel) = channel {
254 channel.handle_event(event).await;
255 }
256 }
257
258 async fn rejoin_stalled_channels(&self) {
259 let channels: Vec<Channel> = self.channels.lock().await.values().cloned().collect();
260 for channel in channels {
261 let state = channel.state();
262 if state == ChannelState::Joining
263 || state == ChannelState::Joined
264 || state == ChannelState::Stalled
265 {
266 channel.join().await;
267 }
268 }
269 }
270}
271
272impl Channel {
273 pub fn name(&self) -> &str {
274 &self.inner.name
275 }
276
277 pub fn state(&self) -> ChannelState {
278 *self.inner.state.borrow()
279 }
280
281 pub fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
282 self.inner.state.subscribe()
283 }
284
285 pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
286 self.inner.events.subscribe()
287 }
288
289 pub async fn presence(&self) -> Vec<PondPresence> {
290 self.inner.presence.lock().await.clone()
291 }
292
293 pub async fn join(&self) {
294 if *self.inner.closed.lock().await {
295 return;
296 }
297 if matches!(
298 self.state(),
299 ChannelState::Joining | ChannelState::Joined | ChannelState::Declined
300 ) {
301 return;
302 }
303 self.inner.state.send_replace(ChannelState::Joining);
304 self.enqueue_or_send(self.join_message()).await;
305 }
306
307 pub async fn leave(&self) {
308 if *self.inner.closed.lock().await {
309 return;
310 }
311 let message = ClientMessage {
312 action: ClientAction::LeaveChannel,
313 event: "LEAVE_CHANNEL".to_owned(),
314 payload: Map::new(),
315 channel_name: self.inner.name.clone(),
316 request_id: uuid(),
317 };
318 let _ = self.inner.client.publish(message).await;
319 self.force_close().await;
320 }
321
322 pub async fn send_message(&self, event: impl Into<String>, payload: Option<PondMessage>) {
323 if *self.inner.closed.lock().await {
324 return;
325 }
326 let message = ClientMessage {
327 action: ClientAction::Broadcast,
328 event: event.into(),
329 payload: payload.unwrap_or_default(),
330 channel_name: self.inner.name.clone(),
331 request_id: uuid(),
332 };
333 self.enqueue_or_send(message).await;
334 }
335
336 pub async fn send_for_response(
337 &self,
338 event: impl Into<String>,
339 payload: Option<PondMessage>,
340 timeout: Option<Duration>,
341 ) -> Result<PondMessage> {
342 if *self.inner.closed.lock().await {
343 return Err(ClientError::ChannelClosed);
344 }
345 let request_id = uuid();
346 let (tx, rx) = oneshot::channel();
347 self.inner
348 .pending
349 .lock()
350 .await
351 .insert(request_id.clone(), tx);
352 let message = ClientMessage {
353 action: ClientAction::Broadcast,
354 event: event.into(),
355 payload: payload.unwrap_or_default(),
356 channel_name: self.inner.name.clone(),
357 request_id: request_id.clone(),
358 };
359 self.enqueue_or_send(message).await;
360 let timeout = timeout.unwrap_or(self.inner.client.options.response_timeout);
361 let result = tokio::time::timeout(timeout, rx).await;
362 self.inner.pending.lock().await.remove(&request_id);
363 match result {
364 Ok(Ok(payload)) => Ok(payload),
365 _ => Err(ClientError::ResponseTimeout),
366 }
367 }
368
369 async fn enqueue_or_send(&self, message: ClientMessage) {
370 let connected = *self.inner.client.state.borrow() == ConnectionState::Connected;
371 let joined = self.state() == ChannelState::Joined;
372 let is_join = message.action == ClientAction::JoinChannel;
373 if connected && (joined || is_join) {
374 if self.inner.client.publish(message.clone()).await.is_ok() {
375 return;
376 }
377 }
378 let mut queue = self.inner.queue.lock().await;
379 if queue.len() == self.inner.client.options.max_queue_size {
380 queue.pop_front();
381 }
382 queue.push_back(message);
383 }
384
385 async fn handle_event(&self, event: ChannelEvent) {
386 if *self.inner.closed.lock().await {
387 return;
388 }
389 match event {
390 ChannelEvent::Presence(message) => self.handle_presence(message).await,
391 ChannelEvent::Message(message) => self.handle_message(message).await,
392 }
393 }
394
395 async fn handle_presence(&self, message: PresenceMessage) {
396 *self.inner.presence.lock().await = message.payload.presence.clone();
397 let event = ChannelEvent::Presence(message.clone());
398 let _ = self.inner.events.send(event);
399 }
400
401 async fn handle_message(&self, message: ServerMessage) {
402 if message.action == ServerAction::System
403 && message.event == event_name(EventName::Acknowledge)
404 {
405 self.acknowledge().await;
406 return;
407 }
408 if message.action == ServerAction::System
409 && message.event == event_name(EventName::Unauthorized)
410 {
411 self.decline().await;
412 return;
413 }
414 if let Some(tx) = self.inner.pending.lock().await.remove(&message.request_id) {
415 let _ = tx.send(message.payload);
416 return;
417 }
418 if self.state() == ChannelState::Joined {
419 let _ = self.inner.events.send(ChannelEvent::Message(message));
420 }
421 }
422
423 async fn acknowledge(&self) {
424 if self.state() != ChannelState::Joined {
425 self.inner.state.send_replace(ChannelState::Joined);
426 }
427 let mut queue = self.inner.queue.lock().await;
428 let pending: Vec<ClientMessage> = queue.drain(..).collect();
429 drop(queue);
430 for message in pending {
431 let _ = self.inner.client.publish(message).await;
432 }
433 }
434
435 async fn decline(&self) {
436 self.inner.state.send_replace(ChannelState::Declined);
437 self.inner.queue.lock().await.clear();
438 self.inner.pending.lock().await.clear();
439 }
440
441 async fn force_close(&self) {
442 *self.inner.closed.lock().await = true;
443 self.inner.state.send_replace(ChannelState::Closed);
444 self.inner.queue.lock().await.clear();
445 self.inner.pending.lock().await.clear();
446 }
447
448 fn join_message(&self) -> ClientMessage {
449 ClientMessage {
450 action: ClientAction::JoinChannel,
451 event: "JOIN_CHANNEL".to_owned(),
452 payload: self.inner.params.clone(),
453 channel_name: self.inner.name.clone(),
454 request_id: uuid(),
455 }
456 }
457}
458
459fn resolve_url(endpoint: &str, params: Option<&JoinParams>) -> Result<String> {
460 let mut url = Url::parse(endpoint)?;
461 match url.scheme() {
462 "http" => url
463 .set_scheme("ws")
464 .map_err(|_| ClientError::UnsupportedScheme("http".to_owned()))?,
465 "https" => url
466 .set_scheme("wss")
467 .map_err(|_| ClientError::UnsupportedScheme("https".to_owned()))?,
468 "ws" | "wss" => {}
469 scheme => return Err(ClientError::UnsupportedScheme(scheme.to_owned())),
470 }
471 if let Some(params) = params {
472 let mut pairs = url.query_pairs_mut();
473 for (key, value) in params {
474 let value = match value {
475 Value::String(value) => value.clone(),
476 other => other.to_string(),
477 };
478 pairs.append_pair(key, &value);
479 }
480 }
481 Ok(url.to_string())
482}
483
484fn event_name(event: EventName) -> String {
485 serde_json::to_string(&event)
486 .unwrap_or_default()
487 .trim_matches('"')
488 .to_owned()
489}
490
491#[allow(dead_code)]
492fn presence_event_name(event: PresenceEventType) -> String {
493 serde_json::to_string(&event)
494 .unwrap_or_default()
495 .trim_matches('"')
496 .to_owned()
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// An `https` endpoint becomes `wss`, and the parameters are appended
    /// after the query the endpoint already had.
    #[test]
    fn resolves_http_url_to_ws_with_params() {
        let mut params = JoinParams::new();
        params.insert("token".to_owned(), Value::String("abc".to_owned()));
        let url = resolve_url("https://example.com/socket?room=one", Some(&params)).unwrap();
        assert_eq!(url, "wss://example.com/socket?room=one&token=abc");
    }

    /// Joining before the client is connected moves the channel to `Joining`
    /// and parks the join request in the channel's queue.
    #[tokio::test]
    async fn queues_join_message_before_connect() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client.create_channel("room", None).await;
        channel.join().await;
        assert_eq!(channel.state(), ChannelState::Joining);
        assert_eq!(channel.inner.queue.lock().await.len(), 1);
    }
}