1use crate::error::{Result, VoxRtcError};
2use crate::types::{ChannelState, ConnectionState, EventData};
3use pondsocket_client::{
4 Channel as PondChannel, ClientError, ClientOptions, ConnectionState as PondConnectionState,
5 PondClient,
6};
7use pondsocket_common::{ChannelEvent, ChannelState as PondChannelState};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Duration;
11use tokio::sync::{broadcast, watch};
12
/// Delay before the first automatic reconnect attempt once the client drops to
/// `Disconnected` while it is still marked active; later attempts double it
/// (see `next_reconnect_delay`) up to the client's configured maximum.
const INITIAL_RECONNECT_DELAY: Duration = Duration::from_millis(200);
14
/// Wrapper over [`PondClient`] that republishes connection state as the crate's
/// [`ConnectionState`] and reconnects automatically while it is active.
///
/// Clones share the `active` / `supervisor_started` flags (they are `Arc`s) and
/// the same state watch channel. NOTE(review): whether clones also share one
/// underlying connection depends on `PondClient::clone` — presumably yes, confirm.
#[derive(Clone)]
pub(crate) struct RawSocketClient {
    // Underlying pondsocket client that owns the actual socket.
    client: PondClient,
    // Connection parameters handed to the client at construction time.
    params: EventData,
    // Publishes the mapped connection state to `subscribe_state()` receivers.
    state_tx: watch::Sender<ConnectionState>,
    // Set by `connect()`, cleared by `disconnect()`; gates automatic reconnection.
    active: Arc<AtomicBool>,
    // Ensures the reconnect supervisor task is spawned at most once.
    supervisor_started: Arc<AtomicBool>,
    // Upper bound for the reconnect backoff delay.
    max_reconnect_delay: Duration,
}
24
/// Wrapper over a pondsocket [`PondChannel`] that exposes its state as the
/// crate's [`ChannelState`] and its message events as `(event, payload)` pairs.
#[derive(Clone)]
pub(crate) struct RawSocketChannel {
    // Underlying channel handle; all join/leave/send calls go through it.
    channel: PondChannel,
    // Mirrors the channel's (mapped) state; fed by a background task from `new`.
    state_tx: watch::Sender<ChannelState>,
    // Fan-out of message events; presence events are not forwarded.
    message_tx: broadcast::Sender<(String, EventData)>,
}
31
32impl RawSocketClient {
33 pub(crate) fn new(
34 endpoint: &str,
35 params: EventData,
36 connection_timeout: Duration,
37 max_reconnect_delay: Duration,
38 ) -> Result<Self> {
39 let options = ClientOptions {
40 connection_timeout,
41 ..ClientOptions::default()
42 };
43 let client = PondClient::with_options(endpoint, Some(params.clone()), options)?;
44 let (state_tx, _) = watch::channel(map_connection_state(client.state()));
45
46 Ok(Self {
47 client,
48 params,
49 state_tx,
50 active: Arc::new(AtomicBool::new(false)),
51 supervisor_started: Arc::new(AtomicBool::new(false)),
52 max_reconnect_delay,
53 })
54 }
55
56 fn ensure_supervisor(&self) {
57 if self.supervisor_started.swap(true, Ordering::SeqCst) {
58 return;
59 }
60 spawn_reconnect_supervisor(
61 self.client.clone(),
62 self.state_tx.clone(),
63 self.active.clone(),
64 self.max_reconnect_delay,
65 );
66 }
67
68 pub(crate) fn state(&self) -> ConnectionState {
69 map_connection_state(self.client.state())
70 }
71
72 pub(crate) fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
73 self.state_tx.subscribe()
74 }
75
76 pub(crate) async fn connect(&self) -> Result<()> {
77 self.active.store(true, Ordering::SeqCst);
78 self.ensure_supervisor();
79 self.state_tx
80 .send_replace(map_connection_state(self.client.state()));
81 self.client.connect().await?;
82 self.state_tx
83 .send_replace(map_connection_state(self.client.state()));
84 Ok(())
85 }
86
87 pub(crate) async fn disconnect(&self) {
88 self.active.store(false, Ordering::SeqCst);
89 self.client.disconnect().await;
90 self.state_tx
91 .send_replace(map_connection_state(self.client.state()));
92 }
93
94 pub(crate) async fn create_channel(
95 &self,
96 name: impl Into<String>,
97 params: EventData,
98 ) -> RawSocketChannel {
99 let channel = self.client.create_channel(name, Some(params)).await;
100 RawSocketChannel::new(channel)
101 }
102
103 #[allow(dead_code)]
104 pub(crate) fn params(&self) -> &EventData {
105 &self.params
106 }
107}
108
109fn spawn_reconnect_supervisor(
110 client: PondClient,
111 state_tx: watch::Sender<ConnectionState>,
112 active: Arc<AtomicBool>,
113 max_reconnect_delay: Duration,
114) {
115 let mut states = client.subscribe_state();
116 tokio::spawn(async move {
117 loop {
118 if states.changed().await.is_err() {
119 break;
120 }
121 let current = *states.borrow_and_update();
122 state_tx.send_replace(map_connection_state(current));
123 if current != PondConnectionState::Disconnected || !active.load(Ordering::SeqCst) {
124 continue;
125 }
126 let mut delay = INITIAL_RECONNECT_DELAY;
127 while active.load(Ordering::SeqCst)
128 && client.state() == PondConnectionState::Disconnected
129 {
130 tokio::time::sleep(delay).await;
131 if !active.load(Ordering::SeqCst) {
132 break;
133 }
134 if client.connect().await.is_ok() {
135 state_tx.send_replace(map_connection_state(client.state()));
136 break;
137 }
138 delay = next_reconnect_delay(delay, max_reconnect_delay);
139 }
140 }
141 });
142}
143
/// Backoff step: twice `current` (saturating on overflow), but never above `max`.
fn next_reconnect_delay(current: Duration, max: Duration) -> Duration {
    current.saturating_mul(2).min(max)
}
148
149impl RawSocketChannel {
150 fn new(channel: PondChannel) -> Self {
151 let (state_tx, _) = watch::channel(map_channel_state(channel.state()));
152 let (message_tx, _) = broadcast::channel(1024);
153
154 let mut pond_states = channel.subscribe_state();
155 let mirror_state_tx = state_tx.clone();
156 tokio::spawn(async move {
157 loop {
158 mirror_state_tx.send_replace(map_channel_state(*pond_states.borrow_and_update()));
159 if pond_states.changed().await.is_err() {
160 break;
161 }
162 }
163 });
164
165 let mut pond_events = channel.subscribe_events();
166 let mirror_message_tx = message_tx.clone();
167 tokio::spawn(async move {
168 while let Ok(event) = pond_events.recv().await {
169 if let Some((event, payload)) = map_channel_event(event) {
170 let _ = mirror_message_tx.send((event, payload));
171 }
172 }
173 });
174
175 Self {
176 channel,
177 state_tx,
178 message_tx,
179 }
180 }
181
182 pub(crate) fn name(&self) -> &str {
183 self.channel.name()
184 }
185
186 pub(crate) fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
187 self.state_tx.subscribe()
188 }
189
190 pub(crate) fn subscribe_messages(&self) -> broadcast::Receiver<(String, EventData)> {
191 self.message_tx.subscribe()
192 }
193
194 fn closed_error(&self) -> Option<VoxRtcError> {
195 match self.channel.state() {
196 PondChannelState::Closed | PondChannelState::Declined => {
197 Some(VoxRtcError::ChannelClosed)
198 }
199 _ => None,
200 }
201 }
202
203 pub(crate) async fn join(&self) -> Result<()> {
204 if let Some(error) = self.closed_error() {
205 return Err(error);
206 }
207 self.channel.join().await;
208 Ok(())
209 }
210
211 pub(crate) async fn leave(&self) -> Result<()> {
212 if let Some(error) = self.closed_error() {
213 return Err(error);
214 }
215 self.channel.leave().await;
216 Ok(())
217 }
218
219 pub(crate) async fn send_message(&self, event: &str, payload: EventData) -> Result<()> {
220 if let Some(error) = self.closed_error() {
221 return Err(error);
222 }
223 self.channel.send_message(event, Some(payload)).await;
224 Ok(())
225 }
226}
227
228fn map_connection_state(state: PondConnectionState) -> ConnectionState {
229 match state {
230 PondConnectionState::Connecting => ConnectionState::Connecting,
231 PondConnectionState::Connected => ConnectionState::Connected,
232 PondConnectionState::Disconnected => ConnectionState::Disconnected,
233 }
234}
235
236fn map_channel_state(state: PondChannelState) -> ChannelState {
237 match state {
238 PondChannelState::Idle => ChannelState::Idle,
239 PondChannelState::Joining => ChannelState::Joining,
240 PondChannelState::Joined => ChannelState::Joined,
241 PondChannelState::Closed => ChannelState::Closed,
242 PondChannelState::Declined => ChannelState::Declined,
243 PondChannelState::Stalled => ChannelState::Joining,
244 }
245}
246
247fn map_channel_event(event: ChannelEvent) -> Option<(String, EventData)> {
248 match event {
249 ChannelEvent::Message(message) => Some((message.event, message.payload)),
250 ChannelEvent::Presence(_) => None,
251 }
252}
253
254impl From<ClientError> for VoxRtcError {
255 fn from(value: ClientError) -> Self {
256 match value {
257 ClientError::Url(err) => Self::InvalidUrl(err),
258 ClientError::Serialization(err) => Self::Json(err),
259 ClientError::WebSocket(err) => Self::PondSocketClient(err.to_string()),
260 ClientError::NotConnected => Self::NotConnected,
261 ClientError::ChannelClosed => Self::ChannelClosed,
262 other => Self::PondSocketClient(other.to_string()),
263 }
264 }
265}
266
/// Test helper: wraps a channel from a never-connected client.
///
/// Also returns a clone of the wrapper's internal message sender, so tests can
/// inject `(event, payload)` pairs that `subscribe_messages()` receivers observe.
#[cfg(test)]
pub(crate) async fn test_channel() -> (RawSocketChannel, broadcast::Sender<(String, EventData)>) {
    let pond_client = PondClient::new("ws://localhost/socket", None).expect("valid test url");
    let raw = RawSocketChannel::new(pond_client.create_channel("/rtc/test", None).await);
    let injector = raw.message_tx.clone();
    (raw, injector)
}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Errors with dedicated `VoxRtcError` variants must not collapse into one another.
    #[test]
    fn distinguishes_not_connected_from_channel_closed() {
        let not_connected = VoxRtcError::from(ClientError::NotConnected);
        assert!(matches!(not_connected, VoxRtcError::NotConnected));

        let channel_closed = VoxRtcError::from(ClientError::ChannelClosed);
        assert!(matches!(channel_closed, VoxRtcError::ChannelClosed));
    }

    /// Backoff doubles until it reaches the cap, then stays there.
    #[test]
    fn reconnect_delay_doubles_then_caps() {
        let max = Duration::from_secs(5);
        let cases = [
            (Duration::from_millis(200), Duration::from_millis(400)),
            (Duration::from_secs(4), Duration::from_secs(5)),
            (max, max),
        ];
        for (current, expected) in cases {
            assert_eq!(next_reconnect_delay(current, max), expected);
        }
    }

    /// After the channel is closed, sends are rejected with `ChannelClosed`.
    #[tokio::test]
    async fn send_message_errors_when_channel_closed() {
        let (channel, _sender) = test_channel().await;
        channel.leave().await.expect("first leave closes channel");

        let outcome = channel.send_message("response.start", EventData::new()).await;
        let error = outcome.expect_err("closed channel must reject sends");
        assert!(matches!(error, VoxRtcError::ChannelClosed));
    }

    /// After the channel is closed, both join and a second leave are rejected.
    #[tokio::test]
    async fn join_and_leave_error_when_channel_closed() {
        let (channel, _sender) = test_channel().await;
        channel.leave().await.expect("first leave closes channel");

        let join_error = channel.join().await.expect_err("cannot join a closed channel");
        assert!(matches!(join_error, VoxRtcError::ChannelClosed));

        let leave_error = channel
            .leave()
            .await
            .expect_err("cannot leave an already-closed channel");
        assert!(matches!(leave_error, VoxRtcError::ChannelClosed));
    }

    /// A tokio broadcast receiver that overflows reports `Lagged` once and can
    /// then keep reading the newest messages.
    #[tokio::test]
    async fn lagged_broadcast_does_not_stop_consumption() {
        let (tx, mut rx) = broadcast::channel::<(String, EventData)>(2);
        (0..5u32).for_each(|index| {
            let _ = tx.send((format!("event-{index}"), EventData::new()));
        });

        let mut lagged = false;
        let mut delivered = Vec::new();
        loop {
            let message = match rx.try_recv() {
                Ok(message) => message,
                Err(broadcast::error::TryRecvError::Lagged(_)) => {
                    lagged = true;
                    continue;
                }
                Err(_) => break,
            };
            delivered.push(message.0);
        }

        assert!(lagged, "small buffer overflow must surface a lag");
        assert!(
            delivered.iter().any(|name| name == "event-4"),
            "consumer must keep reading past the lag: {delivered:?}"
        );
    }
}