Skip to main content

pondsocket_client/
lib.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures_util::{SinkExt, StreamExt};
6use pondsocket_common::{
7    ChannelEvent, ChannelState, ClientAction, ClientMessage, EventName, JoinParams, PondMessage,
8    PondPresence, PresenceEventType, PresenceMessage, ServerAction, ServerMessage, uuid,
9};
10use serde_json::{Map, Value};
11use thiserror::Error;
12use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tokio_tungstenite::connect_async;
15use tokio_tungstenite::tungstenite::Message;
16use url::Url;
17
18pub mod typed;
19pub use typed::TypedChannel;
20
/// Lifecycle of the client's websocket connection, as reported by [`PondClient::state`]
/// and observable through [`PondClient::subscribe_state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// A `connect()` call is currently opening the socket.
    Connecting,
    /// The socket is open and the reader/writer tasks are running.
    Connected,
    /// No socket: the initial state, after `disconnect()`, or after the server closed the
    /// stream.
    Disconnected,
}
27
/// Tuning knobs for a [`PondClient`].
#[derive(Debug, Clone)]
pub struct ClientOptions {
    /// How long `connect()` waits for the websocket handshake before giving up.
    pub connection_timeout: Duration,
    /// Default wait used by `Channel::send_for_response` when the caller passes no timeout.
    pub response_timeout: Duration,
    /// Capacity of the outbound socket queue and of each channel's pending-message queue.
    pub max_queue_size: usize,
}

impl Default for ClientOptions {
    /// A 10 second connection timeout, a 5 second response timeout and room for 100 queued
    /// messages.
    fn default() -> Self {
        let connection_timeout = Duration::from_secs(10);
        let response_timeout = Duration::from_secs(5);
        Self {
            max_queue_size: 100,
            response_timeout,
            connection_timeout,
        }
    }
}
44
/// Errors returned by the websocket client.
#[derive(Debug, Error)]
pub enum ClientError {
    /// The endpoint string could not be parsed as a URL.
    #[error("invalid websocket URL: {0}")]
    Url(#[from] url::ParseError),
    /// The endpoint's scheme is not one of `http`, `https`, `ws` or `wss`.
    #[error("unsupported URL scheme: {0}")]
    UnsupportedScheme(String),
    /// The websocket layer (tungstenite) reported an error, e.g. a failed handshake.
    #[error("websocket error: {0}")]
    WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
    /// A value could not be converted to or from JSON.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// `connect()` did not finish within `ClientOptions::connection_timeout`.
    #[error("connection timed out")]
    ConnectionTimeout,
    /// There is no live connection to hand an outgoing message to.
    #[error("client is not connected")]
    NotConnected,
    /// The channel has been closed and no longer accepts requests.
    #[error("channel is closed")]
    ChannelClosed,
    /// No reply to a request arrived within its timeout.
    #[error("response timed out")]
    ResponseTimeout,
}

/// Crate-local result alias using [`ClientError`].
type Result<T> = std::result::Result<T, ClientError>;
66
/// Handle to a PondSocket websocket client.
///
/// Cloning is cheap: all clones share the same connection and channel registry.
#[derive(Clone)]
pub struct PondClient {
    inner: Arc<ClientInner>,
}

/// Shared state behind every [`PondClient`] clone and every [`Channel`] it created.
struct ClientInner {
    // Fully resolved `ws`/`wss` URL, including the join params in its query string.
    url: String,
    options: ClientOptions,
    // Current connection state; receivers come from `PondClient::subscribe_state`.
    state: watch::Sender<ConnectionState>,
    // Channels created by this client, keyed by channel name; used to route server events.
    channels: Mutex<HashMap<String, Channel>>,
    // Sender feeding the writer task; `None` while there is no live connection.
    outbound: Mutex<Option<mpsc::Sender<ClientMessage>>>,
    // Socket reader/writer task handles, kept so `disconnect` can stop them.
    read_task: Mutex<Option<JoinHandle<()>>>,
    write_task: Mutex<Option<JoinHandle<()>>>,
}
81
/// Handle to one named channel on a [`PondClient`] connection.
///
/// Cloning is cheap; clones share the same state, queues and subscriptions.
#[derive(Clone)]
pub struct Channel {
    inner: Arc<ChannelInner>,
}

/// Shared state behind every clone of a [`Channel`].
struct ChannelInner {
    // Channel name: registry key on the client and `channel_name` of outgoing messages.
    name: String,
    // Payload of this channel's join request.
    params: JoinParams,
    // Back-reference to the owning client, used to publish messages and read options.
    client: Arc<ClientInner>,
    // Current channel state; receivers come from `Channel::subscribe_state`.
    state: watch::Sender<ChannelState>,
    // Fan-out of presence updates and received messages to `subscribe_events` receivers.
    events: broadcast::Sender<ChannelEvent>,
    // Latest full presence list received from the server.
    presence: Mutex<Vec<PondPresence>>,
    // Messages held back until the channel can send them; bounded by `max_queue_size`,
    // with the oldest entry evicted first.
    queue: Mutex<VecDeque<ClientMessage>>,
    // Waiters of `send_for_response`, keyed by the request id they sent.
    pending: Mutex<HashMap<String, oneshot::Sender<PondMessage>>>,
    // Set by `force_close`; afterwards the channel ignores events and refuses to send.
    closed: Mutex<bool>,
}
98
99impl PondClient {
100    pub fn new(endpoint: impl AsRef<str>, params: Option<JoinParams>) -> Result<Self> {
101        Self::with_options(endpoint, params, ClientOptions::default())
102    }
103
104    pub fn with_options(
105        endpoint: impl AsRef<str>,
106        params: Option<JoinParams>,
107        options: ClientOptions,
108    ) -> Result<Self> {
109        let url = resolve_url(endpoint.as_ref(), params.as_ref())?;
110        let (state, _) = watch::channel(ConnectionState::Disconnected);
111
112        Ok(Self {
113            inner: Arc::new(ClientInner {
114                url,
115                options,
116                state,
117                channels: Mutex::new(HashMap::new()),
118                outbound: Mutex::new(None),
119                read_task: Mutex::new(None),
120                write_task: Mutex::new(None),
121            }),
122        })
123    }
124
125    pub fn state(&self) -> ConnectionState {
126        *self.inner.state.borrow()
127    }
128
129    pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
130        self.inner.state.subscribe()
131    }
132
133    pub async fn create_channel(
134        &self,
135        name: impl Into<String>,
136        params: Option<JoinParams>,
137    ) -> Channel {
138        let name = name.into();
139        let mut channels = self.inner.channels.lock().await;
140        if let Some(channel) = channels.get(&name) {
141            if channel.state() != ChannelState::Closed && channel.state() != ChannelState::Declined
142            {
143                return channel.clone();
144            }
145        }
146
147        let (state, _) = watch::channel(ChannelState::Idle);
148        let (events, _) = broadcast::channel(100);
149        let channel = Channel {
150            inner: Arc::new(ChannelInner {
151                name: name.clone(),
152                params: params.unwrap_or_default(),
153                client: Arc::clone(&self.inner),
154                state,
155                events,
156                presence: Mutex::new(Vec::new()),
157                queue: Mutex::new(VecDeque::new()),
158                pending: Mutex::new(HashMap::new()),
159                closed: Mutex::new(false),
160            }),
161        };
162        channels.insert(name, channel.clone());
163        channel
164    }
165
166    pub async fn connect(&self) -> Result<()> {
167        if self.state() != ConnectionState::Disconnected {
168            return Ok(());
169        }
170        self.inner.state.send_replace(ConnectionState::Connecting);
171        let connect = connect_async(&self.inner.url);
172        let (socket, _) = tokio::time::timeout(self.inner.options.connection_timeout, connect)
173            .await
174            .map_err(|_| ClientError::ConnectionTimeout)??;
175        let (mut writer, mut reader) = socket.split();
176        let (tx, mut rx) = mpsc::channel::<ClientMessage>(self.inner.options.max_queue_size);
177        *self.inner.outbound.lock().await = Some(tx);
178
179        let write_task = tokio::spawn(async move {
180            while let Some(message) = rx.recv().await {
181                let Ok(text) = serde_json::to_string(&message) else {
182                    continue;
183                };
184                if writer.send(Message::Text(text.into())).await.is_err() {
185                    break;
186                }
187            }
188            let _ = writer.close().await;
189        });
190
191        let inner = Arc::clone(&self.inner);
192        let read_task = tokio::spawn(async move {
193            while let Some(frame) = reader.next().await {
194                let text = match frame {
195                    Ok(Message::Text(text)) => text.to_string(),
196                    Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
197                        Ok(text) => text,
198                        Err(_) => continue,
199                    },
200                    Ok(Message::Close(_)) => break,
201                    Ok(_) => continue,
202                    Err(_) => break,
203                };
204                let Ok(event) = pondsocket_common::parse_channel_event(&text) else {
205                    continue;
206                };
207                inner.route_event(event).await;
208            }
209            inner.state.send_replace(ConnectionState::Disconnected);
210            *inner.outbound.lock().await = None;
211        });
212
213        *self.inner.read_task.lock().await = Some(read_task);
214        *self.inner.write_task.lock().await = Some(write_task);
215        self.inner.state.send_replace(ConnectionState::Connected);
216        self.inner.rejoin_stalled_channels().await;
217        Ok(())
218    }
219
220    pub async fn disconnect(&self) {
221        if let Some(task) = self.inner.read_task.lock().await.take() {
222            task.abort();
223        }
224        if let Some(task) = self.inner.write_task.lock().await.take() {
225            task.abort();
226        }
227        *self.inner.outbound.lock().await = None;
228        self.inner.state.send_replace(ConnectionState::Disconnected);
229        let channels: Vec<Channel> = self.inner.channels.lock().await.values().cloned().collect();
230        for channel in channels {
231            channel.force_close().await;
232        }
233        self.inner.channels.lock().await.clear();
234    }
235}
236
237impl ClientInner {
238    async fn publish(&self, message: ClientMessage) -> Result<()> {
239        let tx = self
240            .outbound
241            .lock()
242            .await
243            .clone()
244            .ok_or(ClientError::NotConnected)?;
245        tx.send(message)
246            .await
247            .map_err(|_| ClientError::NotConnected)
248    }
249
250    async fn route_event(&self, event: ChannelEvent) {
251        let channel_name = match &event {
252            ChannelEvent::Message(message) => &message.channel_name,
253            ChannelEvent::Presence(message) => &message.channel_name,
254        };
255        let channel = self.channels.lock().await.get(channel_name).cloned();
256        if let Some(channel) = channel {
257            channel.handle_event(event).await;
258        }
259    }
260
261    async fn rejoin_stalled_channels(&self) {
262        let channels: Vec<Channel> = self.channels.lock().await.values().cloned().collect();
263        for channel in channels {
264            let state = channel.state();
265            if state == ChannelState::Joining
266                || state == ChannelState::Joined
267                || state == ChannelState::Stalled
268            {
269                channel.join().await;
270            }
271        }
272    }
273}
274
275impl Channel {
276    pub fn name(&self) -> &str {
277        &self.inner.name
278    }
279
280    pub fn state(&self) -> ChannelState {
281        *self.inner.state.borrow()
282    }
283
284    pub fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
285        self.inner.state.subscribe()
286    }
287
288    pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
289        self.inner.events.subscribe()
290    }
291
292    pub async fn presence(&self) -> Vec<PondPresence> {
293        self.inner.presence.lock().await.clone()
294    }
295
296    pub async fn join(&self) {
297        if *self.inner.closed.lock().await {
298            return;
299        }
300        if matches!(
301            self.state(),
302            ChannelState::Joining | ChannelState::Joined | ChannelState::Declined
303        ) {
304            return;
305        }
306        self.inner.state.send_replace(ChannelState::Joining);
307        self.enqueue_or_send(self.join_message()).await;
308    }
309
310    pub async fn leave(&self) {
311        if *self.inner.closed.lock().await {
312            return;
313        }
314        let message = ClientMessage {
315            action: ClientAction::LeaveChannel,
316            event: "LEAVE_CHANNEL".to_owned(),
317            payload: Map::new(),
318            channel_name: self.inner.name.clone(),
319            request_id: uuid(),
320        };
321        let _ = self.inner.client.publish(message).await;
322        self.force_close().await;
323    }
324
325    pub async fn send_message(&self, event: impl Into<String>, payload: Option<PondMessage>) {
326        if *self.inner.closed.lock().await {
327            return;
328        }
329        let message = ClientMessage {
330            action: ClientAction::Broadcast,
331            event: event.into(),
332            payload: payload.unwrap_or_default(),
333            channel_name: self.inner.name.clone(),
334            request_id: uuid(),
335        };
336        self.enqueue_or_send(message).await;
337    }
338
339    pub async fn send_for_response(
340        &self,
341        event: impl Into<String>,
342        payload: Option<PondMessage>,
343        timeout: Option<Duration>,
344    ) -> Result<PondMessage> {
345        if *self.inner.closed.lock().await {
346            return Err(ClientError::ChannelClosed);
347        }
348        let request_id = uuid();
349        let (tx, rx) = oneshot::channel();
350        self.inner
351            .pending
352            .lock()
353            .await
354            .insert(request_id.clone(), tx);
355        let message = ClientMessage {
356            action: ClientAction::Broadcast,
357            event: event.into(),
358            payload: payload.unwrap_or_default(),
359            channel_name: self.inner.name.clone(),
360            request_id: request_id.clone(),
361        };
362        self.enqueue_or_send(message).await;
363        let timeout = timeout.unwrap_or(self.inner.client.options.response_timeout);
364        let result = tokio::time::timeout(timeout, rx).await;
365        self.inner.pending.lock().await.remove(&request_id);
366        match result {
367            Ok(Ok(payload)) => Ok(payload),
368            _ => Err(ClientError::ResponseTimeout),
369        }
370    }
371
372    async fn enqueue_or_send(&self, message: ClientMessage) {
373        let connected = *self.inner.client.state.borrow() == ConnectionState::Connected;
374        let joined = self.state() == ChannelState::Joined;
375        let is_join = message.action == ClientAction::JoinChannel;
376        if connected && (joined || is_join) {
377            if self.inner.client.publish(message.clone()).await.is_ok() {
378                return;
379            }
380        }
381        let mut queue = self.inner.queue.lock().await;
382        if queue.len() == self.inner.client.options.max_queue_size {
383            queue.pop_front();
384        }
385        queue.push_back(message);
386    }
387
388    async fn handle_event(&self, event: ChannelEvent) {
389        if *self.inner.closed.lock().await {
390            return;
391        }
392        match event {
393            ChannelEvent::Presence(message) => self.handle_presence(message).await,
394            ChannelEvent::Message(message) => self.handle_message(message).await,
395        }
396    }
397
398    async fn handle_presence(&self, message: PresenceMessage) {
399        *self.inner.presence.lock().await = message.payload.presence.clone();
400        let event = ChannelEvent::Presence(message.clone());
401        let _ = self.inner.events.send(event);
402    }
403
404    async fn handle_message(&self, message: ServerMessage) {
405        if message.action == ServerAction::System
406            && message.event == event_name(EventName::Acknowledge)
407        {
408            self.acknowledge().await;
409            return;
410        }
411        if message.action == ServerAction::System
412            && message.event == event_name(EventName::Unauthorized)
413        {
414            self.decline().await;
415            return;
416        }
417        if let Some(tx) = self.inner.pending.lock().await.remove(&message.request_id) {
418            let _ = tx.send(message.payload);
419            return;
420        }
421        if self.state() == ChannelState::Joined {
422            let _ = self.inner.events.send(ChannelEvent::Message(message));
423        }
424    }
425
426    async fn acknowledge(&self) {
427        if self.state() != ChannelState::Joined {
428            self.inner.state.send_replace(ChannelState::Joined);
429        }
430        let mut queue = self.inner.queue.lock().await;
431        let pending: Vec<ClientMessage> = queue.drain(..).collect();
432        drop(queue);
433        for message in pending {
434            let _ = self.inner.client.publish(message).await;
435        }
436    }
437
438    async fn decline(&self) {
439        self.inner.state.send_replace(ChannelState::Declined);
440        self.inner.queue.lock().await.clear();
441        self.inner.pending.lock().await.clear();
442    }
443
444    async fn force_close(&self) {
445        *self.inner.closed.lock().await = true;
446        self.inner.state.send_replace(ChannelState::Closed);
447        self.inner.queue.lock().await.clear();
448        self.inner.pending.lock().await.clear();
449    }
450
451    fn join_message(&self) -> ClientMessage {
452        ClientMessage {
453            action: ClientAction::JoinChannel,
454            event: "JOIN_CHANNEL".to_owned(),
455            payload: self.inner.params.clone(),
456            channel_name: self.inner.name.clone(),
457            request_id: uuid(),
458        }
459    }
460}
461
462fn resolve_url(endpoint: &str, params: Option<&JoinParams>) -> Result<String> {
463    let mut url = Url::parse(endpoint)?;
464    match url.scheme() {
465        "http" => url
466            .set_scheme("ws")
467            .map_err(|_| ClientError::UnsupportedScheme("http".to_owned()))?,
468        "https" => url
469            .set_scheme("wss")
470            .map_err(|_| ClientError::UnsupportedScheme("https".to_owned()))?,
471        "ws" | "wss" => {}
472        scheme => return Err(ClientError::UnsupportedScheme(scheme.to_owned())),
473    }
474    if let Some(params) = params {
475        let mut pairs = url.query_pairs_mut();
476        for (key, value) in params {
477            let value = match value {
478                Value::String(value) => value.clone(),
479                other => other.to_string(),
480            };
481            pairs.append_pair(key, &value);
482        }
483    }
484    Ok(url.to_string())
485}
486
487fn event_name(event: EventName) -> String {
488    serde_json::to_string(&event)
489        .unwrap_or_default()
490        .trim_matches('"')
491        .to_owned()
492}
493
494#[allow(dead_code)]
495fn presence_event_name(event: PresenceEventType) -> String {
496    serde_json::to_string(&event)
497        .unwrap_or_default()
498        .trim_matches('"')
499        .to_owned()
500}
501
#[cfg(test)]
mod tests {
    // Unit tests for URL resolution, offline queueing and the typed channel wrapper. None of
    // them opens a real socket.
    use super::*;
    use pondsocket_common::{PondEvent, PondSchema, PresencePayload};
    use serde::{Deserialize, Serialize};

    // An `https` endpoint becomes `wss`, and params are appended after the existing query.
    #[test]
    fn resolves_http_url_to_ws_with_params() {
        let mut params = JoinParams::new();
        params.insert("token".to_owned(), Value::String("abc".to_owned()));
        let url = resolve_url("https://example.com/socket?room=one", Some(&params)).unwrap();
        assert_eq!(url, "wss://example.com/socket?room=one&token=abc");
    }

    // Joining before `connect()` moves the channel to `Joining` and queues exactly the join
    // request, since there is no connection to publish on.
    #[tokio::test]
    async fn queues_join_message_before_connect() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client.create_channel("room", None).await;
        channel.join().await;
        assert_eq!(channel.state(), ChannelState::Joining);
        assert_eq!(channel.inner.queue.lock().await.len(), 1);
    }

    // Schema and event types for the typed-channel test below.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct ChatPayload {
        text: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AckPayload {
        ok: bool,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Presence {
        user_id: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Assigns {
        role: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Join {
        token: String,
    }

    struct Chat;
    struct ChatSchema;

    impl PondEvent for Chat {
        type Payload = ChatPayload;
        type Response = AckPayload;

        const NAME: &'static str = "chat";
    }

    impl PondSchema for ChatSchema {
        type Presence = Presence;
        type Assigns = Assigns;
        type JoinParams = Join;
    }

    // Exercises the `typed` module (not shown here): `create_typed_channel`, `send`, `raw`,
    // `decode_message` and `decode_presence`. While offline, both the join and the chat
    // message land in the raw channel's queue (join first), which is what is inspected.
    #[tokio::test]
    async fn typed_channel_sends_and_decodes_schema_values() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let params = Join {
            token: "secret".to_owned(),
        };
        let channel = client
            .create_typed_channel::<ChatSchema>("room", Some(&params))
            .await
            .unwrap();

        channel.join().await;
        channel
            .send::<Chat>(&ChatPayload {
                text: "hello".to_owned(),
            })
            .await
            .unwrap();

        let queued = channel.raw().inner.queue.lock().await;
        assert_eq!(queued[0].payload["token"], "secret");
        assert_eq!(queued[1].event, "chat");
        assert_eq!(queued[1].payload["text"], "hello");
        // Release the queue lock before the rest of the test.
        drop(queued);

        let message = ServerMessage {
            action: ServerAction::Broadcast,
            event: "chat".to_owned(),
            channel_name: "room".to_owned(),
            request_id: "r1".to_owned(),
            payload: serde_json::from_value(serde_json::json!({ "text": "from server" })).unwrap(),
        };
        assert_eq!(
            channel.decode_message::<Chat>(&message).unwrap(),
            Some(ChatPayload {
                text: "from server".to_owned()
            })
        );

        let presence = PresenceMessage {
            action: pondsocket_common::PresenceAction::Presence,
            event: PresenceEventType::Join,
            channel_name: "room".to_owned(),
            request_id: "p1".to_owned(),
            payload: PresencePayload {
                changed: serde_json::from_value(serde_json::json!({ "user_id": "u1" })).unwrap(),
                presence: vec![
                    serde_json::from_value(serde_json::json!({ "user_id": "u1" })).unwrap(),
                ],
            },
        };
        let (changed, users) = channel.decode_presence(&presence).unwrap();
        assert_eq!(
            changed,
            Presence {
                user_id: "u1".to_owned()
            }
        );
        assert_eq!(users, vec![changed]);
    }
}