Skip to main content

pondsocket_client/
lib.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures_util::{SinkExt, StreamExt};
6use pondsocket_common::{
7    ChannelEvent, ChannelState, ClientAction, ClientMessage, EventName, JoinParams, PondMessage,
8    PondPresence, PresenceEventType, PresenceMessage, ServerAction, ServerMessage, uuid,
9};
10use serde_json::{Map, Value};
11use thiserror::Error;
12use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tokio_tungstenite::connect_async;
15use tokio_tungstenite::tungstenite::Message;
16use url::Url;
17
/// Lifecycle of the websocket connection owned by a `PondClient`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// A handshake is in progress.
    Connecting,
    /// The socket is open and the reader/writer tasks are running.
    Connected,
    /// No socket is open (initial state, after a failure, or after `disconnect`).
    Disconnected,
}
24
/// Tunables for a `PondClient`.
#[derive(Debug, Clone)]
pub struct ClientOptions {
    /// Upper bound on the websocket handshake performed by `PondClient::connect`.
    pub connection_timeout: Duration,
    /// Wait used by `Channel::send_for_response` when the caller passes no explicit timeout.
    pub response_timeout: Duration,
    /// Capacity of the outbound message channel and of each channel's offline queue.
    pub max_queue_size: usize,
}

impl Default for ClientOptions {
    /// 10 s connection timeout, 5 s response timeout, queues of up to 100 messages.
    fn default() -> Self {
        let connection_timeout = Duration::from_millis(10_000);
        let response_timeout = Duration::from_millis(5_000);
        ClientOptions {
            connection_timeout,
            response_timeout,
            max_queue_size: 100,
        }
    }
}
41
42#[derive(Debug, Error)]
43pub enum ClientError {
44    #[error("invalid websocket URL: {0}")]
45    Url(#[from] url::ParseError),
46    #[error("unsupported URL scheme: {0}")]
47    UnsupportedScheme(String),
48    #[error("websocket error: {0}")]
49    WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
50    #[error("serialization error: {0}")]
51    Serialization(#[from] serde_json::Error),
52    #[error("connection timed out")]
53    ConnectionTimeout,
54    #[error("client is not connected")]
55    NotConnected,
56    #[error("channel is closed")]
57    ChannelClosed,
58    #[error("response timed out")]
59    ResponseTimeout,
60}
61
62type Result<T> = std::result::Result<T, ClientError>;
63
64#[derive(Clone)]
65pub struct PondClient {
66    inner: Arc<ClientInner>,
67}
68
69struct ClientInner {
70    url: String,
71    options: ClientOptions,
72    state: watch::Sender<ConnectionState>,
73    channels: Mutex<HashMap<String, Channel>>,
74    outbound: Mutex<Option<mpsc::Sender<ClientMessage>>>,
75    read_task: Mutex<Option<JoinHandle<()>>>,
76    write_task: Mutex<Option<JoinHandle<()>>>,
77}
78
79#[derive(Clone)]
80pub struct Channel {
81    inner: Arc<ChannelInner>,
82}
83
84struct ChannelInner {
85    name: String,
86    params: JoinParams,
87    client: Arc<ClientInner>,
88    state: watch::Sender<ChannelState>,
89    events: broadcast::Sender<ChannelEvent>,
90    presence: Mutex<Vec<PondPresence>>,
91    queue: Mutex<VecDeque<ClientMessage>>,
92    pending: Mutex<HashMap<String, oneshot::Sender<PondMessage>>>,
93    closed: Mutex<bool>,
94}
95
96impl PondClient {
97    pub fn new(endpoint: impl AsRef<str>, params: Option<JoinParams>) -> Result<Self> {
98        Self::with_options(endpoint, params, ClientOptions::default())
99    }
100
101    pub fn with_options(
102        endpoint: impl AsRef<str>,
103        params: Option<JoinParams>,
104        options: ClientOptions,
105    ) -> Result<Self> {
106        let url = resolve_url(endpoint.as_ref(), params.as_ref())?;
107        let (state, _) = watch::channel(ConnectionState::Disconnected);
108
109        Ok(Self {
110            inner: Arc::new(ClientInner {
111                url,
112                options,
113                state,
114                channels: Mutex::new(HashMap::new()),
115                outbound: Mutex::new(None),
116                read_task: Mutex::new(None),
117                write_task: Mutex::new(None),
118            }),
119        })
120    }
121
122    pub fn state(&self) -> ConnectionState {
123        *self.inner.state.borrow()
124    }
125
126    pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
127        self.inner.state.subscribe()
128    }
129
130    pub async fn create_channel(
131        &self,
132        name: impl Into<String>,
133        params: Option<JoinParams>,
134    ) -> Channel {
135        let name = name.into();
136        let mut channels = self.inner.channels.lock().await;
137        if let Some(channel) = channels.get(&name) {
138            if channel.state() != ChannelState::Closed && channel.state() != ChannelState::Declined
139            {
140                return channel.clone();
141            }
142        }
143
144        let (state, _) = watch::channel(ChannelState::Idle);
145        let (events, _) = broadcast::channel(100);
146        let channel = Channel {
147            inner: Arc::new(ChannelInner {
148                name: name.clone(),
149                params: params.unwrap_or_default(),
150                client: Arc::clone(&self.inner),
151                state,
152                events,
153                presence: Mutex::new(Vec::new()),
154                queue: Mutex::new(VecDeque::new()),
155                pending: Mutex::new(HashMap::new()),
156                closed: Mutex::new(false),
157            }),
158        };
159        channels.insert(name, channel.clone());
160        channel
161    }
162
163    pub async fn connect(&self) -> Result<()> {
164        if self.state() != ConnectionState::Disconnected {
165            return Ok(());
166        }
167        self.inner.state.send_replace(ConnectionState::Connecting);
168        let connect = connect_async(&self.inner.url);
169        let (socket, _) = tokio::time::timeout(self.inner.options.connection_timeout, connect)
170            .await
171            .map_err(|_| ClientError::ConnectionTimeout)??;
172        let (mut writer, mut reader) = socket.split();
173        let (tx, mut rx) = mpsc::channel::<ClientMessage>(self.inner.options.max_queue_size);
174        *self.inner.outbound.lock().await = Some(tx);
175
176        let write_task = tokio::spawn(async move {
177            while let Some(message) = rx.recv().await {
178                let Ok(text) = serde_json::to_string(&message) else {
179                    continue;
180                };
181                if writer.send(Message::Text(text.into())).await.is_err() {
182                    break;
183                }
184            }
185            let _ = writer.close().await;
186        });
187
188        let inner = Arc::clone(&self.inner);
189        let read_task = tokio::spawn(async move {
190            while let Some(frame) = reader.next().await {
191                let text = match frame {
192                    Ok(Message::Text(text)) => text.to_string(),
193                    Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
194                        Ok(text) => text,
195                        Err(_) => continue,
196                    },
197                    Ok(Message::Close(_)) => break,
198                    Ok(_) => continue,
199                    Err(_) => break,
200                };
201                let Ok(event) = pondsocket_common::parse_channel_event(&text) else {
202                    continue;
203                };
204                inner.route_event(event).await;
205            }
206            inner.state.send_replace(ConnectionState::Disconnected);
207            *inner.outbound.lock().await = None;
208        });
209
210        *self.inner.read_task.lock().await = Some(read_task);
211        *self.inner.write_task.lock().await = Some(write_task);
212        self.inner.state.send_replace(ConnectionState::Connected);
213        self.inner.rejoin_stalled_channels().await;
214        Ok(())
215    }
216
217    pub async fn disconnect(&self) {
218        if let Some(task) = self.inner.read_task.lock().await.take() {
219            task.abort();
220        }
221        if let Some(task) = self.inner.write_task.lock().await.take() {
222            task.abort();
223        }
224        *self.inner.outbound.lock().await = None;
225        self.inner.state.send_replace(ConnectionState::Disconnected);
226        let channels: Vec<Channel> = self.inner.channels.lock().await.values().cloned().collect();
227        for channel in channels {
228            channel.force_close().await;
229        }
230        self.inner.channels.lock().await.clear();
231    }
232}
233
234impl ClientInner {
235    async fn publish(&self, message: ClientMessage) -> Result<()> {
236        let tx = self
237            .outbound
238            .lock()
239            .await
240            .clone()
241            .ok_or(ClientError::NotConnected)?;
242        tx.send(message)
243            .await
244            .map_err(|_| ClientError::NotConnected)
245    }
246
247    async fn route_event(&self, event: ChannelEvent) {
248        let channel_name = match &event {
249            ChannelEvent::Message(message) => &message.channel_name,
250            ChannelEvent::Presence(message) => &message.channel_name,
251        };
252        let channel = self.channels.lock().await.get(channel_name).cloned();
253        if let Some(channel) = channel {
254            channel.handle_event(event).await;
255        }
256    }
257
258    async fn rejoin_stalled_channels(&self) {
259        let channels: Vec<Channel> = self.channels.lock().await.values().cloned().collect();
260        for channel in channels {
261            let state = channel.state();
262            if state == ChannelState::Joining
263                || state == ChannelState::Joined
264                || state == ChannelState::Stalled
265            {
266                channel.join().await;
267            }
268        }
269    }
270}
271
272impl Channel {
273    pub fn name(&self) -> &str {
274        &self.inner.name
275    }
276
277    pub fn state(&self) -> ChannelState {
278        *self.inner.state.borrow()
279    }
280
281    pub fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
282        self.inner.state.subscribe()
283    }
284
285    pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
286        self.inner.events.subscribe()
287    }
288
289    pub async fn presence(&self) -> Vec<PondPresence> {
290        self.inner.presence.lock().await.clone()
291    }
292
293    pub async fn join(&self) {
294        if *self.inner.closed.lock().await {
295            return;
296        }
297        if matches!(
298            self.state(),
299            ChannelState::Joining | ChannelState::Joined | ChannelState::Declined
300        ) {
301            return;
302        }
303        self.inner.state.send_replace(ChannelState::Joining);
304        self.enqueue_or_send(self.join_message()).await;
305    }
306
307    pub async fn leave(&self) {
308        if *self.inner.closed.lock().await {
309            return;
310        }
311        let message = ClientMessage {
312            action: ClientAction::LeaveChannel,
313            event: "LEAVE_CHANNEL".to_owned(),
314            payload: Map::new(),
315            channel_name: self.inner.name.clone(),
316            request_id: uuid(),
317        };
318        let _ = self.inner.client.publish(message).await;
319        self.force_close().await;
320    }
321
322    pub async fn send_message(&self, event: impl Into<String>, payload: Option<PondMessage>) {
323        if *self.inner.closed.lock().await {
324            return;
325        }
326        let message = ClientMessage {
327            action: ClientAction::Broadcast,
328            event: event.into(),
329            payload: payload.unwrap_or_default(),
330            channel_name: self.inner.name.clone(),
331            request_id: uuid(),
332        };
333        self.enqueue_or_send(message).await;
334    }
335
336    pub async fn send_for_response(
337        &self,
338        event: impl Into<String>,
339        payload: Option<PondMessage>,
340        timeout: Option<Duration>,
341    ) -> Result<PondMessage> {
342        if *self.inner.closed.lock().await {
343            return Err(ClientError::ChannelClosed);
344        }
345        let request_id = uuid();
346        let (tx, rx) = oneshot::channel();
347        self.inner
348            .pending
349            .lock()
350            .await
351            .insert(request_id.clone(), tx);
352        let message = ClientMessage {
353            action: ClientAction::Broadcast,
354            event: event.into(),
355            payload: payload.unwrap_or_default(),
356            channel_name: self.inner.name.clone(),
357            request_id: request_id.clone(),
358        };
359        self.enqueue_or_send(message).await;
360        let timeout = timeout.unwrap_or(self.inner.client.options.response_timeout);
361        let result = tokio::time::timeout(timeout, rx).await;
362        self.inner.pending.lock().await.remove(&request_id);
363        match result {
364            Ok(Ok(payload)) => Ok(payload),
365            _ => Err(ClientError::ResponseTimeout),
366        }
367    }
368
369    async fn enqueue_or_send(&self, message: ClientMessage) {
370        let connected = *self.inner.client.state.borrow() == ConnectionState::Connected;
371        let joined = self.state() == ChannelState::Joined;
372        let is_join = message.action == ClientAction::JoinChannel;
373        if connected && (joined || is_join) {
374            if self.inner.client.publish(message.clone()).await.is_ok() {
375                return;
376            }
377        }
378        let mut queue = self.inner.queue.lock().await;
379        if queue.len() == self.inner.client.options.max_queue_size {
380            queue.pop_front();
381        }
382        queue.push_back(message);
383    }
384
385    async fn handle_event(&self, event: ChannelEvent) {
386        if *self.inner.closed.lock().await {
387            return;
388        }
389        match event {
390            ChannelEvent::Presence(message) => self.handle_presence(message).await,
391            ChannelEvent::Message(message) => self.handle_message(message).await,
392        }
393    }
394
395    async fn handle_presence(&self, message: PresenceMessage) {
396        *self.inner.presence.lock().await = message.payload.presence.clone();
397        let event = ChannelEvent::Presence(message.clone());
398        let _ = self.inner.events.send(event);
399    }
400
401    async fn handle_message(&self, message: ServerMessage) {
402        if message.action == ServerAction::System
403            && message.event == event_name(EventName::Acknowledge)
404        {
405            self.acknowledge().await;
406            return;
407        }
408        if message.action == ServerAction::System
409            && message.event == event_name(EventName::Unauthorized)
410        {
411            self.decline().await;
412            return;
413        }
414        if let Some(tx) = self.inner.pending.lock().await.remove(&message.request_id) {
415            let _ = tx.send(message.payload);
416            return;
417        }
418        if self.state() == ChannelState::Joined {
419            let _ = self.inner.events.send(ChannelEvent::Message(message));
420        }
421    }
422
423    async fn acknowledge(&self) {
424        if self.state() != ChannelState::Joined {
425            self.inner.state.send_replace(ChannelState::Joined);
426        }
427        let mut queue = self.inner.queue.lock().await;
428        let pending: Vec<ClientMessage> = queue.drain(..).collect();
429        drop(queue);
430        for message in pending {
431            let _ = self.inner.client.publish(message).await;
432        }
433    }
434
435    async fn decline(&self) {
436        self.inner.state.send_replace(ChannelState::Declined);
437        self.inner.queue.lock().await.clear();
438        self.inner.pending.lock().await.clear();
439    }
440
441    async fn force_close(&self) {
442        *self.inner.closed.lock().await = true;
443        self.inner.state.send_replace(ChannelState::Closed);
444        self.inner.queue.lock().await.clear();
445        self.inner.pending.lock().await.clear();
446    }
447
448    fn join_message(&self) -> ClientMessage {
449        ClientMessage {
450            action: ClientAction::JoinChannel,
451            event: "JOIN_CHANNEL".to_owned(),
452            payload: self.inner.params.clone(),
453            channel_name: self.inner.name.clone(),
454            request_id: uuid(),
455        }
456    }
457}
458
459fn resolve_url(endpoint: &str, params: Option<&JoinParams>) -> Result<String> {
460    let mut url = Url::parse(endpoint)?;
461    match url.scheme() {
462        "http" => url
463            .set_scheme("ws")
464            .map_err(|_| ClientError::UnsupportedScheme("http".to_owned()))?,
465        "https" => url
466            .set_scheme("wss")
467            .map_err(|_| ClientError::UnsupportedScheme("https".to_owned()))?,
468        "ws" | "wss" => {}
469        scheme => return Err(ClientError::UnsupportedScheme(scheme.to_owned())),
470    }
471    if let Some(params) = params {
472        let mut pairs = url.query_pairs_mut();
473        for (key, value) in params {
474            let value = match value {
475                Value::String(value) => value.clone(),
476                other => other.to_string(),
477            };
478            pairs.append_pair(key, &value);
479        }
480    }
481    Ok(url.to_string())
482}
483
484fn event_name(event: EventName) -> String {
485    serde_json::to_string(&event)
486        .unwrap_or_default()
487        .trim_matches('"')
488        .to_owned()
489}
490
491#[allow(dead_code)]
492fn presence_event_name(event: PresenceEventType) -> String {
493    serde_json::to_string(&event)
494        .unwrap_or_default()
495        .trim_matches('"')
496        .to_owned()
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// An `https` endpoint becomes `wss`, keeping its query and appending join params.
    #[test]
    fn resolves_http_url_to_ws_with_params() {
        let mut join_params = JoinParams::new();
        join_params.insert(String::from("token"), Value::String(String::from("abc")));
        let resolved = resolve_url("https://example.com/socket?room=one", Some(&join_params))
            .expect("valid https endpoint");
        assert_eq!(resolved, "wss://example.com/socket?room=one&token=abc");
    }

    /// Joining while disconnected parks the join request in the channel queue.
    #[tokio::test]
    async fn queues_join_message_before_connect() {
        let client =
            PondClient::new("ws://example.com/socket", None).expect("valid ws endpoint");
        let room = client.create_channel("room", None).await;
        room.join().await;
        assert_eq!(room.state(), ChannelState::Joining);
        let queued = room.inner.queue.lock().await.len();
        assert_eq!(queued, 1);
    }
}
519}