Skip to main content

vox_rtc_server/
socket.rs

1use crate::error::{Result, VoxRtcError};
2use crate::types::{ChannelState, ConnectionState, EventData};
3use pondsocket_client::{
4    Channel as PondChannel, ClientError, ClientOptions, ConnectionState as PondConnectionState,
5    PondClient,
6};
7use pondsocket_common::{ChannelEvent, ChannelState as PondChannelState};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Duration;
11use tokio::sync::{broadcast, watch};
12
/// Delay slept before the first reconnect attempt of a reconnect cycle.
///
/// After every failed attempt the supervisor grows the delay through
/// `next_reconnect_delay`, bounded by the client's `max_reconnect_delay`.
const INITIAL_RECONNECT_DELAY: Duration = Duration::from_millis(200);
14
/// Thin wrapper around a `PondClient` that republishes its connection state as the
/// crate's own `ConnectionState` and owns the flags driving the reconnect supervisor.
///
/// Cloning yields another handle onto the same `Arc`-backed flags.
#[derive(Clone)]
pub(crate) struct RawSocketClient {
    /// Underlying PondSocket client that performs the actual connection work.
    client: PondClient,
    /// Connection parameters given to `new`; a copy was handed to `client`.
    params: EventData,
    /// Publishes the mapped connection state to `subscribe_state` receivers.
    state_tx: watch::Sender<ConnectionState>,
    /// Set by `connect`, cleared by `disconnect`; the supervisor only retries while set.
    active: Arc<AtomicBool>,
    /// Ensures the reconnect supervisor task is spawned at most once.
    supervisor_started: Arc<AtomicBool>,
    /// Upper bound for the supervisor's reconnect backoff.
    max_reconnect_delay: Duration,
}
24
/// Wrapper around a `PondChannel` that mirrors the channel's state and message
/// events onto crate-level channels.
#[derive(Clone)]
pub(crate) struct RawSocketChannel {
    /// Underlying PondSocket channel used for join/leave/send.
    channel: PondChannel,
    /// Publishes the channel state mapped through `map_channel_state`.
    state_tx: watch::Sender<ChannelState>,
    /// Republishes message events as `(event name, payload)` pairs.
    message_tx: broadcast::Sender<(String, EventData)>,
}
31
32impl RawSocketClient {
33    pub(crate) fn new(
34        endpoint: &str,
35        params: EventData,
36        connection_timeout: Duration,
37        max_reconnect_delay: Duration,
38    ) -> Result<Self> {
39        let options = ClientOptions {
40            connection_timeout,
41            ..ClientOptions::default()
42        };
43        let client = PondClient::with_options(endpoint, Some(params.clone()), options)?;
44        let (state_tx, _) = watch::channel(map_connection_state(client.state()));
45
46        Ok(Self {
47            client,
48            params,
49            state_tx,
50            active: Arc::new(AtomicBool::new(false)),
51            supervisor_started: Arc::new(AtomicBool::new(false)),
52            max_reconnect_delay,
53        })
54    }
55
56    fn ensure_supervisor(&self) {
57        if self.supervisor_started.swap(true, Ordering::SeqCst) {
58            return;
59        }
60        spawn_reconnect_supervisor(
61            self.client.clone(),
62            self.state_tx.clone(),
63            self.active.clone(),
64            self.max_reconnect_delay,
65        );
66    }
67
68    pub(crate) fn state(&self) -> ConnectionState {
69        map_connection_state(self.client.state())
70    }
71
72    pub(crate) fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
73        self.state_tx.subscribe()
74    }
75
76    pub(crate) async fn connect(&self) -> Result<()> {
77        self.active.store(true, Ordering::SeqCst);
78        self.ensure_supervisor();
79        self.state_tx
80            .send_replace(map_connection_state(self.client.state()));
81        self.client.connect().await?;
82        self.state_tx
83            .send_replace(map_connection_state(self.client.state()));
84        Ok(())
85    }
86
87    pub(crate) async fn disconnect(&self) {
88        self.active.store(false, Ordering::SeqCst);
89        self.client.disconnect().await;
90        self.state_tx
91            .send_replace(map_connection_state(self.client.state()));
92    }
93
94    pub(crate) async fn create_channel(
95        &self,
96        name: impl Into<String>,
97        params: EventData,
98    ) -> RawSocketChannel {
99        let channel = self.client.create_channel(name, Some(params)).await;
100        RawSocketChannel::new(channel)
101    }
102
103    #[allow(dead_code)]
104    pub(crate) fn params(&self) -> &EventData {
105        &self.params
106    }
107}
108
109fn spawn_reconnect_supervisor(
110    client: PondClient,
111    state_tx: watch::Sender<ConnectionState>,
112    active: Arc<AtomicBool>,
113    max_reconnect_delay: Duration,
114) {
115    let mut states = client.subscribe_state();
116    tokio::spawn(async move {
117        loop {
118            if states.changed().await.is_err() {
119                break;
120            }
121            let current = *states.borrow_and_update();
122            state_tx.send_replace(map_connection_state(current));
123            if current != PondConnectionState::Disconnected || !active.load(Ordering::SeqCst) {
124                continue;
125            }
126            let mut delay = INITIAL_RECONNECT_DELAY;
127            while active.load(Ordering::SeqCst)
128                && client.state() == PondConnectionState::Disconnected
129            {
130                tokio::time::sleep(delay).await;
131                if !active.load(Ordering::SeqCst) {
132                    break;
133                }
134                if client.connect().await.is_ok() {
135                    state_tx.send_replace(map_connection_state(client.state()));
136                    break;
137                }
138                delay = next_reconnect_delay(delay, max_reconnect_delay);
139            }
140        }
141    });
142}
143
/// Returns the delay to use after `current`: twice `current` (saturating on
/// overflow), but never more than `max`.
fn next_reconnect_delay(current: Duration, max: Duration) -> Duration {
    current.saturating_mul(2).min(max)
}
148
149impl RawSocketChannel {
150    fn new(channel: PondChannel) -> Self {
151        let (state_tx, _) = watch::channel(map_channel_state(channel.state()));
152        let (message_tx, _) = broadcast::channel(1024);
153
154        let mut pond_states = channel.subscribe_state();
155        let mirror_state_tx = state_tx.clone();
156        tokio::spawn(async move {
157            loop {
158                mirror_state_tx.send_replace(map_channel_state(*pond_states.borrow_and_update()));
159                if pond_states.changed().await.is_err() {
160                    break;
161                }
162            }
163        });
164
165        let mut pond_events = channel.subscribe_events();
166        let mirror_message_tx = message_tx.clone();
167        tokio::spawn(async move {
168            while let Ok(event) = pond_events.recv().await {
169                if let Some((event, payload)) = map_channel_event(event) {
170                    let _ = mirror_message_tx.send((event, payload));
171                }
172            }
173        });
174
175        Self {
176            channel,
177            state_tx,
178            message_tx,
179        }
180    }
181
182    pub(crate) fn name(&self) -> &str {
183        self.channel.name()
184    }
185
186    pub(crate) fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
187        self.state_tx.subscribe()
188    }
189
190    pub(crate) fn subscribe_messages(&self) -> broadcast::Receiver<(String, EventData)> {
191        self.message_tx.subscribe()
192    }
193
194    fn closed_error(&self) -> Option<VoxRtcError> {
195        match self.channel.state() {
196            PondChannelState::Closed | PondChannelState::Declined => {
197                Some(VoxRtcError::ChannelClosed)
198            }
199            _ => None,
200        }
201    }
202
203    pub(crate) async fn join(&self) -> Result<()> {
204        if let Some(error) = self.closed_error() {
205            return Err(error);
206        }
207        self.channel.join().await;
208        Ok(())
209    }
210
211    pub(crate) async fn leave(&self) -> Result<()> {
212        if let Some(error) = self.closed_error() {
213            return Err(error);
214        }
215        self.channel.leave().await;
216        Ok(())
217    }
218
219    pub(crate) async fn send_message(&self, event: &str, payload: EventData) -> Result<()> {
220        if let Some(error) = self.closed_error() {
221            return Err(error);
222        }
223        self.channel.send_message(event, Some(payload)).await;
224        Ok(())
225    }
226}
227
228fn map_connection_state(state: PondConnectionState) -> ConnectionState {
229    match state {
230        PondConnectionState::Connecting => ConnectionState::Connecting,
231        PondConnectionState::Connected => ConnectionState::Connected,
232        PondConnectionState::Disconnected => ConnectionState::Disconnected,
233    }
234}
235
236fn map_channel_state(state: PondChannelState) -> ChannelState {
237    match state {
238        PondChannelState::Idle => ChannelState::Idle,
239        PondChannelState::Joining => ChannelState::Joining,
240        PondChannelState::Joined => ChannelState::Joined,
241        PondChannelState::Closed => ChannelState::Closed,
242        PondChannelState::Declined => ChannelState::Declined,
243        PondChannelState::Stalled => ChannelState::Joining,
244    }
245}
246
247fn map_channel_event(event: ChannelEvent) -> Option<(String, EventData)> {
248    match event {
249        ChannelEvent::Message(message) => Some((message.event, message.payload)),
250        ChannelEvent::Presence(_) => None,
251    }
252}
253
254impl From<ClientError> for VoxRtcError {
255    fn from(value: ClientError) -> Self {
256        match value {
257            ClientError::Url(err) => Self::InvalidUrl(err),
258            ClientError::Serialization(err) => Self::Json(err),
259            ClientError::WebSocket(err) => Self::PondSocketClient(err.to_string()),
260            ClientError::NotConnected => Self::NotConnected,
261            ClientError::ChannelClosed => Self::ChannelClosed,
262            other => Self::PondSocketClient(other.to_string()),
263        }
264    }
265}
266
/// Test helper: builds a wrapped channel on a never-connected client and also returns
/// a clone of its message sender so tests can inject `(event, payload)` pairs.
#[cfg(test)]
pub(crate) async fn test_channel() -> (RawSocketChannel, broadcast::Sender<(String, EventData)>) {
    let pond = PondClient::new("ws://localhost/socket", None).expect("valid test url");
    let raw = RawSocketChannel::new(pond.create_channel("/rtc/test", None).await);
    let injector = raw.message_tx.clone();
    (raw, injector)
}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// `NotConnected` and `ChannelClosed` must each keep their own crate variant.
    #[test]
    fn distinguishes_not_connected_from_channel_closed() {
        let not_connected = VoxRtcError::from(ClientError::NotConnected);
        assert!(matches!(not_connected, VoxRtcError::NotConnected));

        let channel_closed = VoxRtcError::from(ClientError::ChannelClosed);
        assert!(matches!(channel_closed, VoxRtcError::ChannelClosed));
    }

    /// The backoff doubles until it reaches the cap and then stays there.
    #[test]
    fn reconnect_delay_doubles_then_caps() {
        let max = Duration::from_secs(5);
        let cases = [
            (Duration::from_millis(200), Duration::from_millis(400)),
            (Duration::from_secs(4), Duration::from_secs(5)),
            (max, max),
        ];
        for (current, expected) in cases {
            assert_eq!(next_reconnect_delay(current, max), expected);
        }
    }

    /// Sending on a channel that has been left is rejected with `ChannelClosed`.
    #[tokio::test]
    async fn send_message_errors_when_channel_closed() {
        let (channel, _sender) = test_channel().await;
        channel.leave().await.expect("first leave closes channel");

        let outcome = channel
            .send_message("response.start", EventData::new())
            .await;
        let error = outcome.expect_err("closed channel must reject sends");
        assert!(matches!(error, VoxRtcError::ChannelClosed));
    }

    /// Once left, both `join` and a second `leave` are rejected with `ChannelClosed`.
    #[tokio::test]
    async fn join_and_leave_error_when_channel_closed() {
        let (channel, _sender) = test_channel().await;
        channel.leave().await.expect("first leave closes channel");

        let join_error = channel
            .join()
            .await
            .expect_err("cannot join a closed channel");
        assert!(matches!(join_error, VoxRtcError::ChannelClosed));

        let leave_error = channel
            .leave()
            .await
            .expect_err("cannot leave an already-closed channel");
        assert!(matches!(leave_error, VoxRtcError::ChannelClosed));
    }

    /// A broadcast receiver that overflowed reports the lag once and can then keep
    /// reading the retained messages.
    #[tokio::test]
    async fn lagged_broadcast_does_not_stop_consumption() {
        let (tx, mut rx) = broadcast::channel::<(String, EventData)>(2);
        (0..5u32).for_each(|index| {
            let _ = tx.send((format!("event-{index}"), EventData::new()));
        });

        let mut lagged = false;
        let mut delivered = Vec::new();
        loop {
            match rx.try_recv() {
                Ok((name, _payload)) => delivered.push(name),
                Err(broadcast::error::TryRecvError::Lagged(_)) => lagged = true,
                // Empty or closed: nothing further to drain.
                Err(_) => break,
            }
        }

        assert!(lagged, "small buffer overflow must surface a lag");
        assert!(
            delivered.contains(&"event-4".to_owned()),
            "consumer must keep reading past the lag: {delivered:?}"
        );
    }
}
353            delivered.contains(&"event-4".to_owned()),
354            "consumer must keep reading past the lag: {delivered:?}"
355        );
356    }
357}