Skip to main content

pondsocket_client/
lib.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures_util::{SinkExt, StreamExt};
6use pondsocket_common::{
7    ChannelEvent, ChannelState, ClientAction, ClientMessage, EventName, JoinParams, PondMessage,
8    PondPresence, PresenceEventType, PresenceMessage, ServerAction, ServerMessage, uuid,
9};
10use serde_json::{Map, Value};
11use thiserror::Error;
12use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tokio_tungstenite::connect_async;
15use tokio_tungstenite::tungstenite::Message;
16use url::Url;
17
18pub mod typed;
19pub use typed::TypedChannel;
20
/// Lifecycle of the client's websocket connection, as reported by [`PondClient::state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// A handshake is in progress.
    Connecting,
    /// The websocket is open.
    Connected,
    /// No websocket is open (the initial state).
    Disconnected,
}
27
/// Tunables for a [`PondClient`].
#[derive(Clone, Debug)]
pub struct ClientOptions {
    /// How long `connect` waits for the websocket handshake.
    pub connection_timeout: Duration,
    /// Default wait for a reply in `Channel::send_for_response`.
    pub response_timeout: Duration,
    /// Bound on each channel's offline message queue; also used as the capacity of the
    /// outbound buffer feeding the socket writer.
    pub max_queue_size: usize,
}

impl Default for ClientOptions {
    /// 10 s handshake timeout, 5 s response timeout, 100 queued messages.
    fn default() -> Self {
        let connection_timeout = Duration::from_millis(10_000);
        let response_timeout = Duration::from_millis(5_000);
        Self {
            connection_timeout,
            response_timeout,
            max_queue_size: 100,
        }
    }
}
44
45#[derive(Debug, Error)]
46pub enum ClientError {
47    #[error("invalid websocket URL: {0}")]
48    Url(#[from] url::ParseError),
49    #[error("unsupported URL scheme: {0}")]
50    UnsupportedScheme(String),
51    #[error("websocket error: {0}")]
52    WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
53    #[error("serialization error: {0}")]
54    Serialization(#[from] serde_json::Error),
55    #[error("connection timed out")]
56    ConnectionTimeout,
57    #[error("client is not connected")]
58    NotConnected,
59    #[error("channel is closed")]
60    ChannelClosed,
61    #[error("response timed out")]
62    ResponseTimeout,
63}
64
65type Result<T> = std::result::Result<T, ClientError>;
66
/// Handle to a PondSocket websocket client.
///
/// Clones are cheap and share the same connection, state and channel registry (`Arc`).
#[derive(Clone)]
pub struct PondClient {
    inner: Arc<ClientInner>,
}

/// Shared state behind every [`PondClient`] clone.
///
/// NOTE(review): `channels` stores `Channel`s whose `ChannelInner::client` is an `Arc` back
/// to this struct, so the two keep each other alive until the entries are removed from the
/// map (e.g. `PondClient::disconnect` clears it). Consider a `Weak` back-reference.
struct ClientInner {
    // Websocket URL produced by `resolve_url`, dialed by `connect`.
    url: String,
    options: ClientOptions,
    // Current connection state; observable through `PondClient::subscribe_state`.
    state: watch::Sender<ConnectionState>,
    // Connection-state transitions, consumed by each channel's watcher task.
    conn_events: broadcast::Sender<ConnectionState>,
    // Registered channels by name; used to route incoming server events.
    channels: Mutex<HashMap<String, Channel>>,
    // Sender feeding the live connection's write task; `None` while disconnected.
    outbound: Mutex<Option<mpsc::Sender<ClientMessage>>>,
    // Socket reader/writer tasks of the current connection (aborted by `disconnect`).
    read_task: Mutex<Option<JoinHandle<()>>>,
    write_task: Mutex<Option<JoinHandle<()>>>,
}
82
/// Handle to one named channel of a [`PondClient`].
///
/// Clones are cheap and share the same channel state (`Arc`).
#[derive(Clone)]
pub struct Channel {
    inner: Arc<ChannelInner>,
}

/// Shared state behind every [`Channel`] clone.
struct ChannelInner {
    // Channel name; matched against `channel_name` on incoming events.
    name: String,
    // Sent as the payload of every join request.
    params: JoinParams,
    // Owning client; used to publish messages and read connection state/options.
    client: Arc<ClientInner>,
    // Current channel state; observable through `Channel::subscribe_state`.
    state: watch::Sender<ChannelState>,
    // Presence events and (while joined) broadcast messages for subscribers.
    events: broadcast::Sender<ChannelEvent>,
    // Latest presence list received from the server.
    presence: Mutex<Vec<PondPresence>>,
    // Messages waiting to be sent once the channel can publish.
    queue: Mutex<VecDeque<ClientMessage>>,
    // Reply slots for `send_for_response`, keyed by request id.
    pending: Mutex<HashMap<String, oneshot::Sender<PondMessage>>>,
    // Watcher task that forwards connection-state changes to this channel.
    conn_task: Mutex<Option<JoinHandle<()>>>,
    // Set by `force_close`; once true, the public operations become no-ops.
    closed: Mutex<bool>,
}
100
101impl PondClient {
102    pub fn new(endpoint: impl AsRef<str>, params: Option<JoinParams>) -> Result<Self> {
103        Self::with_options(endpoint, params, ClientOptions::default())
104    }
105
106    pub fn with_options(
107        endpoint: impl AsRef<str>,
108        params: Option<JoinParams>,
109        options: ClientOptions,
110    ) -> Result<Self> {
111        let url = resolve_url(endpoint.as_ref(), params.as_ref())?;
112        let (state, _) = watch::channel(ConnectionState::Disconnected);
113        let (conn_events, _) = broadcast::channel(16);
114
115        Ok(Self {
116            inner: Arc::new(ClientInner {
117                url,
118                options,
119                state,
120                conn_events,
121                channels: Mutex::new(HashMap::new()),
122                outbound: Mutex::new(None),
123                read_task: Mutex::new(None),
124                write_task: Mutex::new(None),
125            }),
126        })
127    }
128
129    pub fn state(&self) -> ConnectionState {
130        *self.inner.state.borrow()
131    }
132
133    pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
134        self.inner.state.subscribe()
135    }
136
137    pub async fn create_channel(
138        &self,
139        name: impl Into<String>,
140        params: Option<JoinParams>,
141    ) -> Channel {
142        let name = name.into();
143        let mut channels = self.inner.channels.lock().await;
144        if let Some(channel) = channels.get(&name) {
145            if channel.state() != ChannelState::Closed && channel.state() != ChannelState::Declined
146            {
147                return channel.clone();
148            }
149        }
150
151        let (state, _) = watch::channel(ChannelState::Idle);
152        let (events, _) = broadcast::channel(100);
153        let channel = Channel {
154            inner: Arc::new(ChannelInner {
155                name: name.clone(),
156                params: params.unwrap_or_default(),
157                client: Arc::clone(&self.inner),
158                state,
159                events,
160                presence: Mutex::new(Vec::new()),
161                queue: Mutex::new(VecDeque::new()),
162                pending: Mutex::new(HashMap::new()),
163                conn_task: Mutex::new(None),
164                closed: Mutex::new(false),
165            }),
166        };
167
168        let watcher = channel.clone();
169        let mut conn_rx = self.inner.conn_events.subscribe();
170        let handle = tokio::spawn(async move {
171            loop {
172                match conn_rx.recv().await {
173                    Ok(state) => watcher.on_connection_change(state).await,
174                    Err(broadcast::error::RecvError::Lagged(_)) => {}
175                    Err(broadcast::error::RecvError::Closed) => break,
176                }
177                if *watcher.inner.closed.lock().await {
178                    break;
179                }
180            }
181        });
182        *channel.inner.conn_task.lock().await = Some(handle);
183
184        channels.insert(name, channel.clone());
185        channel
186    }
187
188    pub async fn connect(&self) -> Result<()> {
189        if self.state() != ConnectionState::Disconnected {
190            return Ok(());
191        }
192        self.inner.set_connection_state(ConnectionState::Connecting);
193        let connect = connect_async(&self.inner.url);
194        let (socket, _) = tokio::time::timeout(self.inner.options.connection_timeout, connect)
195            .await
196            .map_err(|_| ClientError::ConnectionTimeout)??;
197        let (mut writer, mut reader) = socket.split();
198        let (tx, mut rx) = mpsc::channel::<ClientMessage>(self.inner.options.max_queue_size);
199        *self.inner.outbound.lock().await = Some(tx);
200
201        let write_task = tokio::spawn(async move {
202            while let Some(message) = rx.recv().await {
203                let Ok(text) = serde_json::to_string(&message) else {
204                    continue;
205                };
206                if writer.send(Message::Text(text.into())).await.is_err() {
207                    break;
208                }
209            }
210            let _ = writer.close().await;
211        });
212
213        let inner = Arc::clone(&self.inner);
214        let read_task = tokio::spawn(async move {
215            while let Some(frame) = reader.next().await {
216                let text = match frame {
217                    Ok(Message::Text(text)) => text.to_string(),
218                    Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
219                        Ok(text) => text,
220                        Err(_) => continue,
221                    },
222                    Ok(Message::Close(_)) => break,
223                    Ok(_) => continue,
224                    Err(_) => break,
225                };
226                let Ok(event) = pondsocket_common::parse_channel_event(&text) else {
227                    continue;
228                };
229                inner.route_event(event).await;
230            }
231            inner.set_connection_state(ConnectionState::Disconnected);
232            *inner.outbound.lock().await = None;
233        });
234
235        *self.inner.read_task.lock().await = Some(read_task);
236        *self.inner.write_task.lock().await = Some(write_task);
237        self.inner.set_connection_state(ConnectionState::Connected);
238        Ok(())
239    }
240
241    pub async fn disconnect(&self) {
242        if let Some(task) = self.inner.read_task.lock().await.take() {
243            task.abort();
244        }
245        if let Some(task) = self.inner.write_task.lock().await.take() {
246            task.abort();
247        }
248        *self.inner.outbound.lock().await = None;
249        self.inner.set_connection_state(ConnectionState::Disconnected);
250        let channels: Vec<Channel> = self.inner.channels.lock().await.values().cloned().collect();
251        for channel in channels {
252            channel.force_close().await;
253        }
254        self.inner.channels.lock().await.clear();
255    }
256}
257
258impl ClientInner {
259    async fn publish(&self, message: ClientMessage) -> Result<()> {
260        let tx = self
261            .outbound
262            .lock()
263            .await
264            .clone()
265            .ok_or(ClientError::NotConnected)?;
266        tx.send(message)
267            .await
268            .map_err(|_| ClientError::NotConnected)
269    }
270
271    async fn route_event(&self, event: ChannelEvent) {
272        let channel_name = match &event {
273            ChannelEvent::Message(message) => &message.channel_name,
274            ChannelEvent::Presence(message) => &message.channel_name,
275        };
276        let channel = self.channels.lock().await.get(channel_name).cloned();
277        if let Some(channel) = channel {
278            channel.handle_event(event).await;
279        }
280    }
281
282    fn set_connection_state(&self, state: ConnectionState) {
283        let changed = *self.state.borrow() != state;
284        self.state.send_replace(state);
285        if changed {
286            let _ = self.conn_events.send(state);
287        }
288    }
289}
290
291impl Channel {
292    pub fn name(&self) -> &str {
293        &self.inner.name
294    }
295
296    pub fn state(&self) -> ChannelState {
297        *self.inner.state.borrow()
298    }
299
300    pub fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
301        self.inner.state.subscribe()
302    }
303
304    pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
305        self.inner.events.subscribe()
306    }
307
308    pub async fn presence(&self) -> Vec<PondPresence> {
309        self.inner.presence.lock().await.clone()
310    }
311
312    pub async fn join(&self) {
313        if *self.inner.closed.lock().await {
314            return;
315        }
316        if matches!(
317            self.state(),
318            ChannelState::Joining | ChannelState::Joined | ChannelState::Declined
319        ) {
320            return;
321        }
322        self.inner.state.send_replace(ChannelState::Joining);
323        self.enqueue_or_send(self.join_message()).await;
324    }
325
326    pub async fn leave(&self) {
327        if *self.inner.closed.lock().await {
328            return;
329        }
330        let message = ClientMessage {
331            action: ClientAction::LeaveChannel,
332            event: "LEAVE_CHANNEL".to_owned(),
333            payload: Map::new(),
334            channel_name: self.inner.name.clone(),
335            request_id: uuid(),
336        };
337        let _ = self.inner.client.publish(message).await;
338        self.force_close().await;
339    }
340
341    pub async fn send_message(&self, event: impl Into<String>, payload: Option<PondMessage>) {
342        if *self.inner.closed.lock().await {
343            return;
344        }
345        let message = ClientMessage {
346            action: ClientAction::Broadcast,
347            event: event.into(),
348            payload: payload.unwrap_or_default(),
349            channel_name: self.inner.name.clone(),
350            request_id: uuid(),
351        };
352        self.enqueue_or_send(message).await;
353    }
354
355    pub async fn send_for_response(
356        &self,
357        event: impl Into<String>,
358        payload: Option<PondMessage>,
359        timeout: Option<Duration>,
360    ) -> Result<PondMessage> {
361        if *self.inner.closed.lock().await {
362            return Err(ClientError::ChannelClosed);
363        }
364        let request_id = uuid();
365        let (tx, rx) = oneshot::channel();
366        self.inner
367            .pending
368            .lock()
369            .await
370            .insert(request_id.clone(), tx);
371        let message = ClientMessage {
372            action: ClientAction::Broadcast,
373            event: event.into(),
374            payload: payload.unwrap_or_default(),
375            channel_name: self.inner.name.clone(),
376            request_id: request_id.clone(),
377        };
378        self.enqueue_or_send(message).await;
379        let timeout = timeout.unwrap_or(self.inner.client.options.response_timeout);
380        let result = tokio::time::timeout(timeout, rx).await;
381        self.inner.pending.lock().await.remove(&request_id);
382        match result {
383            Ok(Ok(payload)) => Ok(payload),
384            _ => Err(ClientError::ResponseTimeout),
385        }
386    }
387
388    async fn enqueue_or_send(&self, message: ClientMessage) {
389        let connected = *self.inner.client.state.borrow() == ConnectionState::Connected;
390        let joined = self.state() == ChannelState::Joined;
391        let is_join = message.action == ClientAction::JoinChannel;
392        if connected && (joined || is_join) {
393            if self.inner.client.publish(message.clone()).await.is_ok() {
394                return;
395            }
396        }
397        let mut queue = self.inner.queue.lock().await;
398        if queue.len() == self.inner.client.options.max_queue_size {
399            queue.pop_front();
400        }
401        queue.push_back(message);
402    }
403
404    async fn on_connection_change(&self, state: ConnectionState) {
405        if *self.inner.closed.lock().await {
406            return;
407        }
408        match state {
409            ConnectionState::Disconnected => {
410                if self.state() == ChannelState::Joined {
411                    self.inner.state.send_replace(ChannelState::Stalled);
412                }
413            }
414            ConnectionState::Connected => {
415                if self.state() == ChannelState::Stalled {
416                    self.join().await;
417                }
418            }
419            ConnectionState::Connecting => {}
420        }
421    }
422
423    async fn handle_event(&self, event: ChannelEvent) {
424        if *self.inner.closed.lock().await {
425            return;
426        }
427        match event {
428            ChannelEvent::Presence(message) => self.handle_presence(message).await,
429            ChannelEvent::Message(message) => self.handle_message(message).await,
430        }
431    }
432
433    async fn handle_presence(&self, message: PresenceMessage) {
434        *self.inner.presence.lock().await = message.payload.presence.clone();
435        let event = ChannelEvent::Presence(message.clone());
436        let _ = self.inner.events.send(event);
437    }
438
439    async fn handle_message(&self, message: ServerMessage) {
440        if message.action == ServerAction::System
441            && message.event == event_name(EventName::Acknowledge)
442        {
443            self.acknowledge().await;
444            return;
445        }
446        if message.action == ServerAction::System
447            && message.event == event_name(EventName::Unauthorized)
448        {
449            self.decline().await;
450            return;
451        }
452        if let Some(tx) = self.inner.pending.lock().await.remove(&message.request_id) {
453            let _ = tx.send(message.payload);
454            return;
455        }
456        if self.state() == ChannelState::Joined {
457            let _ = self.inner.events.send(ChannelEvent::Message(message));
458        }
459    }
460
461    async fn acknowledge(&self) {
462        if self.state() != ChannelState::Joined {
463            self.inner.state.send_replace(ChannelState::Joined);
464        }
465        let mut queue = self.inner.queue.lock().await;
466        let pending: Vec<ClientMessage> = queue.drain(..).collect();
467        drop(queue);
468        for message in pending {
469            let _ = self.inner.client.publish(message).await;
470        }
471    }
472
473    async fn decline(&self) {
474        self.inner.state.send_replace(ChannelState::Declined);
475        self.inner.queue.lock().await.clear();
476        self.inner.pending.lock().await.clear();
477    }
478
479    async fn force_close(&self) {
480        *self.inner.closed.lock().await = true;
481        self.inner.state.send_replace(ChannelState::Closed);
482        self.inner.queue.lock().await.clear();
483        self.inner.pending.lock().await.clear();
484        if let Some(handle) = self.inner.conn_task.lock().await.take() {
485            handle.abort();
486        }
487    }
488
489    fn join_message(&self) -> ClientMessage {
490        ClientMessage {
491            action: ClientAction::JoinChannel,
492            event: "JOIN_CHANNEL".to_owned(),
493            payload: self.inner.params.clone(),
494            channel_name: self.inner.name.clone(),
495            request_id: uuid(),
496        }
497    }
498}
499
500fn resolve_url(endpoint: &str, params: Option<&JoinParams>) -> Result<String> {
501    let mut url = Url::parse(endpoint)?;
502    match url.scheme() {
503        "http" => url
504            .set_scheme("ws")
505            .map_err(|_| ClientError::UnsupportedScheme("http".to_owned()))?,
506        "https" => url
507            .set_scheme("wss")
508            .map_err(|_| ClientError::UnsupportedScheme("https".to_owned()))?,
509        "ws" | "wss" => {}
510        scheme => return Err(ClientError::UnsupportedScheme(scheme.to_owned())),
511    }
512    if let Some(params) = params {
513        let mut pairs = url.query_pairs_mut();
514        for (key, value) in params {
515            let value = match value {
516                Value::String(value) => value.clone(),
517                other => other.to_string(),
518            };
519            pairs.append_pair(key, &value);
520        }
521    }
522    Ok(url.to_string())
523}
524
525fn event_name(event: EventName) -> String {
526    serde_json::to_string(&event)
527        .unwrap_or_default()
528        .trim_matches('"')
529        .to_owned()
530}
531
532#[allow(dead_code)]
533fn presence_event_name(event: PresenceEventType) -> String {
534    serde_json::to_string(&event)
535        .unwrap_or_default()
536        .trim_matches('"')
537        .to_owned()
538}
539
// Tests for URL resolution, offline queueing, the typed-channel wrapper, and reconnection
// against an in-process websocket server.
#[cfg(test)]
mod tests {
    use super::*;
    use pondsocket_common::{PondEvent, PondSchema, PresencePayload};
    use serde::{Deserialize, Serialize};

    // `https` becomes `wss`; the existing `room` pair is kept and `token` is appended.
    #[test]
    fn resolves_http_url_to_ws_with_params() {
        let mut params = JoinParams::new();
        params.insert("token".to_owned(), Value::String("abc".to_owned()));
        let url = resolve_url("https://example.com/socket?room=one", Some(&params)).unwrap();
        assert_eq!(url, "wss://example.com/socket?room=one&token=abc");
    }

    // Without a connection, `join` moves to `Joining` and queues the join request.
    #[tokio::test]
    async fn queues_join_message_before_connect() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client.create_channel("room", None).await;
        channel.join().await;
        assert_eq!(channel.state(), ChannelState::Joining);
        assert_eq!(channel.inner.queue.lock().await.len(), 1);
    }

    // Fixture types for the typed-channel test: one event (`Chat`) and a schema
    // (`ChatSchema`) describing presence, assigns and join params.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct ChatPayload {
        text: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AckPayload {
        ok: bool,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Presence {
        user_id: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Assigns {
        role: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Join {
        token: String,
    }

    struct Chat;
    struct ChatSchema;

    impl PondEvent for Chat {
        type Payload = ChatPayload;
        type Response = AckPayload;

        const NAME: &'static str = "chat";
    }

    impl PondSchema for ChatSchema {
        type Presence = Presence;
        type Assigns = Assigns;
        type JoinParams = Join;
    }

    // Offline round trip through the typed wrapper: encoded join params and event payloads
    // land in the raw queue, and server messages/presence decode into the schema types.
    #[tokio::test]
    async fn typed_channel_sends_and_decodes_schema_values() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let params = Join {
            token: "secret".to_owned(),
        };
        let channel = client
            .create_typed_channel::<ChatSchema>("room", Some(&params))
            .await
            .unwrap();

        channel.join().await;
        channel
            .send::<Chat>(&ChatPayload {
                text: "hello".to_owned(),
            })
            .await
            .unwrap();

        // Not connected: the join request (index 0) and the chat message (index 1) are queued.
        let queued = channel.raw().inner.queue.lock().await;
        assert_eq!(queued[0].payload["token"], "secret");
        assert_eq!(queued[1].event, "chat");
        assert_eq!(queued[1].payload["text"], "hello");
        drop(queued);

        let message = ServerMessage {
            action: ServerAction::Broadcast,
            event: "chat".to_owned(),
            channel_name: "room".to_owned(),
            request_id: "r1".to_owned(),
            payload: serde_json::from_value(serde_json::json!({ "text": "from server" })).unwrap(),
        };
        assert_eq!(
            channel.decode_message::<Chat>(&message).unwrap(),
            Some(ChatPayload {
                text: "from server".to_owned()
            })
        );

        let presence = PresenceMessage {
            action: pondsocket_common::PresenceAction::Presence,
            event: PresenceEventType::Join,
            channel_name: "room".to_owned(),
            request_id: "p1".to_owned(),
            payload: PresencePayload {
                changed: serde_json::from_value(serde_json::json!({ "user_id": "u1" })).unwrap(),
                presence: vec![
                    serde_json::from_value(serde_json::json!({ "user_id": "u1" })).unwrap(),
                ],
            },
        };
        let (changed, users) = channel.decode_presence(&presence).unwrap();
        assert_eq!(
            changed,
            Presence {
                user_id: "u1".to_owned()
            }
        );
        assert_eq!(users, vec![changed]);
    }

    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_tungstenite::accept_async;

    // Minimal server session. Each JOIN_CHANNEL request increments `joins` and is answered
    // with a System acknowledge, followed by a `greeting` broadcast whose `conn` payload is
    // `conn_id`; all other requests are ignored. The session closes the socket and returns
    // when `kill` yields, and returns when the client stream ends or errors.
    async fn serve_session(
        stream: TcpStream,
        joins: Arc<AtomicUsize>,
        conn_id: u64,
        mut kill: broadcast::Receiver<()>,
    ) {
        let Ok(ws) = accept_async(stream).await else {
            return;
        };
        let (mut writer, mut reader) = ws.split();
        loop {
            tokio::select! {
                _ = kill.recv() => {
                    let _ = writer.close().await;
                    return;
                }
                frame = reader.next() => {
                    let message = match frame {
                        Some(Ok(Message::Text(text))) => text.to_string(),
                        Some(Ok(_)) => continue,
                        _ => return,
                    };
                    let Ok(request) = serde_json::from_str::<ClientMessage>(&message) else {
                        continue;
                    };
                    if request.action != ClientAction::JoinChannel {
                        continue;
                    }
                    joins.fetch_add(1, Ordering::SeqCst);
                    let ack = ServerMessage {
                        action: ServerAction::System,
                        event: event_name(EventName::Acknowledge),
                        channel_name: request.channel_name.clone(),
                        request_id: request.request_id.clone(),
                        payload: Map::new(),
                    };
                    let _ = writer
                        .send(Message::Text(serde_json::to_string(&ack).unwrap().into()))
                        .await;
                    let mut payload = Map::new();
                    payload.insert("conn".to_owned(), Value::from(conn_id));
                    let greeting = ServerMessage {
                        action: ServerAction::Broadcast,
                        event: "greeting".to_owned(),
                        channel_name: request.channel_name.clone(),
                        request_id: uuid(),
                        payload,
                    };
                    let _ = writer
                        .send(Message::Text(serde_json::to_string(&greeting).unwrap().into()))
                        .await;
                }
            }
        }
    }

    // Polls the channel state every 5 ms; panics if `target` is not reached within 2 s.
    async fn wait_for_state(channel: &Channel, target: ChannelState) {
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while channel.state() != target {
            if tokio::time::Instant::now() > deadline {
                panic!("timed out waiting for {target:?}, still {:?}", channel.state());
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    }

    // Skips non-greeting events and returns the `conn` id of the next `greeting` message.
    // Panics on a 2 s timeout or a receive error.
    async fn recv_greeting(events: &mut broadcast::Receiver<ChannelEvent>) -> u64 {
        loop {
            match tokio::time::timeout(Duration::from_secs(2), events.recv()).await {
                Ok(Ok(ChannelEvent::Message(message))) if message.event == "greeting" => {
                    return message.payload.get("conn").and_then(Value::as_u64).unwrap();
                }
                Ok(Ok(_)) => {}
                other => panic!("expected greeting event, got {other:?}"),
            }
        }
    }

    // End to end: join on connection 1, let the server drop the socket, reconnect, and
    // check that the channel rejoins (a second JOIN_CHANNEL) and receives connection 2's
    // greeting.
    #[tokio::test]
    async fn rejoins_channel_after_socket_drop_and_keeps_receiving_events() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let joins = Arc::new(AtomicUsize::new(0));
        let (kill, _) = broadcast::channel::<()>(4);

        // Accept loop: each connection gets an increasing id (starting at 1) and its own
        // `kill` receiver, subscribed at accept time.
        let server_joins = Arc::clone(&joins);
        let server_kill = kill.clone();
        tokio::spawn(async move {
            let mut conn_id = 0;
            while let Ok((stream, _)) = listener.accept().await {
                conn_id += 1;
                tokio::spawn(serve_session(
                    stream,
                    Arc::clone(&server_joins),
                    conn_id,
                    server_kill.subscribe(),
                ));
            }
        });

        let url = format!("ws://{addr}/socket");
        let client = PondClient::new(&url, None).unwrap();
        client.connect().await.unwrap();
        let channel = client.create_channel("room", None).await;
        let mut events = channel.subscribe_events();
        channel.join().await;

        wait_for_state(&channel, ChannelState::Joined).await;
        assert_eq!(recv_greeting(&mut events).await, 1);
        assert_eq!(joins.load(Ordering::SeqCst), 1);

        // Only the session subscribed so far (connection 1) sees this signal; the session
        // for the next connection subscribes later.
        kill.send(()).unwrap();
        wait_for_state(&channel, ChannelState::Stalled).await;

        client.connect().await.unwrap();
        wait_for_state(&channel, ChannelState::Joined).await;
        assert_eq!(joins.load(Ordering::SeqCst), 2);

        assert_eq!(recv_greeting(&mut events).await, 2);
    }
}