1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures_util::{SinkExt, StreamExt};
6use pondsocket_common::{
7 ChannelEvent, ChannelState, ClientAction, ClientMessage, EventName, JoinParams, PondMessage,
8 PondPresence, PresenceEventType, PresenceMessage, ServerAction, ServerMessage, uuid,
9};
10use serde_json::{Map, Value};
11use thiserror::Error;
12use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tokio_tungstenite::connect_async;
15use tokio_tungstenite::tungstenite::Message;
16use url::Url;
17
18pub mod typed;
19pub use typed::TypedChannel;
20
/// Lifecycle of the client's WebSocket connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// A handshake is in progress.
    Connecting,
    /// The socket is open and the reader/writer tasks are running.
    Connected,
    /// No socket: the initial state, and the state after a disconnect or a read failure.
    Disconnected,
}
27
/// Tunables for a `PondClient`.
#[derive(Debug, Clone)]
pub struct ClientOptions {
    /// Upper bound on how long `PondClient::connect` waits for the WebSocket handshake.
    pub connection_timeout: Duration,
    /// Wait used by `Channel::send_for_response` when the caller passes no timeout.
    pub response_timeout: Duration,
    /// Capacity of the outbound frame buffer and of each channel's offline message queue.
    pub max_queue_size: usize,
}

impl Default for ClientOptions {
    /// 10 s handshake timeout, 5 s response timeout, at most 100 queued messages.
    fn default() -> Self {
        let connection_timeout = Duration::from_secs(10);
        let response_timeout = Duration::from_secs(5);
        let max_queue_size = 100;
        Self {
            connection_timeout,
            response_timeout,
            max_queue_size,
        }
    }
}
44
45#[derive(Debug, Error)]
46pub enum ClientError {
47 #[error("invalid websocket URL: {0}")]
48 Url(#[from] url::ParseError),
49 #[error("unsupported URL scheme: {0}")]
50 UnsupportedScheme(String),
51 #[error("websocket error: {0}")]
52 WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
53 #[error("serialization error: {0}")]
54 Serialization(#[from] serde_json::Error),
55 #[error("connection timed out")]
56 ConnectionTimeout,
57 #[error("client is not connected")]
58 NotConnected,
59 #[error("channel is closed")]
60 ChannelClosed,
61 #[error("response timed out")]
62 ResponseTimeout,
63}
64
65type Result<T> = std::result::Result<T, ClientError>;
66
67#[derive(Clone)]
68pub struct PondClient {
69 inner: Arc<ClientInner>,
70}
71
72struct ClientInner {
73 url: String,
74 options: ClientOptions,
75 state: watch::Sender<ConnectionState>,
76 channels: Mutex<HashMap<String, Channel>>,
77 outbound: Mutex<Option<mpsc::Sender<ClientMessage>>>,
78 read_task: Mutex<Option<JoinHandle<()>>>,
79 write_task: Mutex<Option<JoinHandle<()>>>,
80}
81
82#[derive(Clone)]
83pub struct Channel {
84 inner: Arc<ChannelInner>,
85}
86
87struct ChannelInner {
88 name: String,
89 params: JoinParams,
90 client: Arc<ClientInner>,
91 state: watch::Sender<ChannelState>,
92 events: broadcast::Sender<ChannelEvent>,
93 presence: Mutex<Vec<PondPresence>>,
94 queue: Mutex<VecDeque<ClientMessage>>,
95 pending: Mutex<HashMap<String, oneshot::Sender<PondMessage>>>,
96 closed: Mutex<bool>,
97}
98
99impl PondClient {
100 pub fn new(endpoint: impl AsRef<str>, params: Option<JoinParams>) -> Result<Self> {
101 Self::with_options(endpoint, params, ClientOptions::default())
102 }
103
104 pub fn with_options(
105 endpoint: impl AsRef<str>,
106 params: Option<JoinParams>,
107 options: ClientOptions,
108 ) -> Result<Self> {
109 let url = resolve_url(endpoint.as_ref(), params.as_ref())?;
110 let (state, _) = watch::channel(ConnectionState::Disconnected);
111
112 Ok(Self {
113 inner: Arc::new(ClientInner {
114 url,
115 options,
116 state,
117 channels: Mutex::new(HashMap::new()),
118 outbound: Mutex::new(None),
119 read_task: Mutex::new(None),
120 write_task: Mutex::new(None),
121 }),
122 })
123 }
124
125 pub fn state(&self) -> ConnectionState {
126 *self.inner.state.borrow()
127 }
128
129 pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
130 self.inner.state.subscribe()
131 }
132
133 pub async fn create_channel(
134 &self,
135 name: impl Into<String>,
136 params: Option<JoinParams>,
137 ) -> Channel {
138 let name = name.into();
139 let mut channels = self.inner.channels.lock().await;
140 if let Some(channel) = channels.get(&name) {
141 if channel.state() != ChannelState::Closed && channel.state() != ChannelState::Declined
142 {
143 return channel.clone();
144 }
145 }
146
147 let (state, _) = watch::channel(ChannelState::Idle);
148 let (events, _) = broadcast::channel(100);
149 let channel = Channel {
150 inner: Arc::new(ChannelInner {
151 name: name.clone(),
152 params: params.unwrap_or_default(),
153 client: Arc::clone(&self.inner),
154 state,
155 events,
156 presence: Mutex::new(Vec::new()),
157 queue: Mutex::new(VecDeque::new()),
158 pending: Mutex::new(HashMap::new()),
159 closed: Mutex::new(false),
160 }),
161 };
162 channels.insert(name, channel.clone());
163 channel
164 }
165
166 pub async fn connect(&self) -> Result<()> {
167 if self.state() != ConnectionState::Disconnected {
168 return Ok(());
169 }
170 self.inner.state.send_replace(ConnectionState::Connecting);
171 let connect = connect_async(&self.inner.url);
172 let (socket, _) = tokio::time::timeout(self.inner.options.connection_timeout, connect)
173 .await
174 .map_err(|_| ClientError::ConnectionTimeout)??;
175 let (mut writer, mut reader) = socket.split();
176 let (tx, mut rx) = mpsc::channel::<ClientMessage>(self.inner.options.max_queue_size);
177 *self.inner.outbound.lock().await = Some(tx);
178
179 let write_task = tokio::spawn(async move {
180 while let Some(message) = rx.recv().await {
181 let Ok(text) = serde_json::to_string(&message) else {
182 continue;
183 };
184 if writer.send(Message::Text(text.into())).await.is_err() {
185 break;
186 }
187 }
188 let _ = writer.close().await;
189 });
190
191 let inner = Arc::clone(&self.inner);
192 let read_task = tokio::spawn(async move {
193 while let Some(frame) = reader.next().await {
194 let text = match frame {
195 Ok(Message::Text(text)) => text.to_string(),
196 Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
197 Ok(text) => text,
198 Err(_) => continue,
199 },
200 Ok(Message::Close(_)) => break,
201 Ok(_) => continue,
202 Err(_) => break,
203 };
204 let Ok(event) = pondsocket_common::parse_channel_event(&text) else {
205 continue;
206 };
207 inner.route_event(event).await;
208 }
209 inner.state.send_replace(ConnectionState::Disconnected);
210 *inner.outbound.lock().await = None;
211 });
212
213 *self.inner.read_task.lock().await = Some(read_task);
214 *self.inner.write_task.lock().await = Some(write_task);
215 self.inner.state.send_replace(ConnectionState::Connected);
216 self.inner.rejoin_stalled_channels().await;
217 Ok(())
218 }
219
220 pub async fn disconnect(&self) {
221 if let Some(task) = self.inner.read_task.lock().await.take() {
222 task.abort();
223 }
224 if let Some(task) = self.inner.write_task.lock().await.take() {
225 task.abort();
226 }
227 *self.inner.outbound.lock().await = None;
228 self.inner.state.send_replace(ConnectionState::Disconnected);
229 let channels: Vec<Channel> = self.inner.channels.lock().await.values().cloned().collect();
230 for channel in channels {
231 channel.force_close().await;
232 }
233 self.inner.channels.lock().await.clear();
234 }
235}
236
237impl ClientInner {
238 async fn publish(&self, message: ClientMessage) -> Result<()> {
239 let tx = self
240 .outbound
241 .lock()
242 .await
243 .clone()
244 .ok_or(ClientError::NotConnected)?;
245 tx.send(message)
246 .await
247 .map_err(|_| ClientError::NotConnected)
248 }
249
250 async fn route_event(&self, event: ChannelEvent) {
251 let channel_name = match &event {
252 ChannelEvent::Message(message) => &message.channel_name,
253 ChannelEvent::Presence(message) => &message.channel_name,
254 };
255 let channel = self.channels.lock().await.get(channel_name).cloned();
256 if let Some(channel) = channel {
257 channel.handle_event(event).await;
258 }
259 }
260
261 async fn rejoin_stalled_channels(&self) {
262 let channels: Vec<Channel> = self.channels.lock().await.values().cloned().collect();
263 for channel in channels {
264 let state = channel.state();
265 if state == ChannelState::Joining
266 || state == ChannelState::Joined
267 || state == ChannelState::Stalled
268 {
269 channel.join().await;
270 }
271 }
272 }
273}
274
275impl Channel {
276 pub fn name(&self) -> &str {
277 &self.inner.name
278 }
279
280 pub fn state(&self) -> ChannelState {
281 *self.inner.state.borrow()
282 }
283
284 pub fn subscribe_state(&self) -> watch::Receiver<ChannelState> {
285 self.inner.state.subscribe()
286 }
287
288 pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
289 self.inner.events.subscribe()
290 }
291
292 pub async fn presence(&self) -> Vec<PondPresence> {
293 self.inner.presence.lock().await.clone()
294 }
295
296 pub async fn join(&self) {
297 if *self.inner.closed.lock().await {
298 return;
299 }
300 if matches!(
301 self.state(),
302 ChannelState::Joining | ChannelState::Joined | ChannelState::Declined
303 ) {
304 return;
305 }
306 self.inner.state.send_replace(ChannelState::Joining);
307 self.enqueue_or_send(self.join_message()).await;
308 }
309
310 pub async fn leave(&self) {
311 if *self.inner.closed.lock().await {
312 return;
313 }
314 let message = ClientMessage {
315 action: ClientAction::LeaveChannel,
316 event: "LEAVE_CHANNEL".to_owned(),
317 payload: Map::new(),
318 channel_name: self.inner.name.clone(),
319 request_id: uuid(),
320 };
321 let _ = self.inner.client.publish(message).await;
322 self.force_close().await;
323 }
324
325 pub async fn send_message(&self, event: impl Into<String>, payload: Option<PondMessage>) {
326 if *self.inner.closed.lock().await {
327 return;
328 }
329 let message = ClientMessage {
330 action: ClientAction::Broadcast,
331 event: event.into(),
332 payload: payload.unwrap_or_default(),
333 channel_name: self.inner.name.clone(),
334 request_id: uuid(),
335 };
336 self.enqueue_or_send(message).await;
337 }
338
339 pub async fn send_for_response(
340 &self,
341 event: impl Into<String>,
342 payload: Option<PondMessage>,
343 timeout: Option<Duration>,
344 ) -> Result<PondMessage> {
345 if *self.inner.closed.lock().await {
346 return Err(ClientError::ChannelClosed);
347 }
348 let request_id = uuid();
349 let (tx, rx) = oneshot::channel();
350 self.inner
351 .pending
352 .lock()
353 .await
354 .insert(request_id.clone(), tx);
355 let message = ClientMessage {
356 action: ClientAction::Broadcast,
357 event: event.into(),
358 payload: payload.unwrap_or_default(),
359 channel_name: self.inner.name.clone(),
360 request_id: request_id.clone(),
361 };
362 self.enqueue_or_send(message).await;
363 let timeout = timeout.unwrap_or(self.inner.client.options.response_timeout);
364 let result = tokio::time::timeout(timeout, rx).await;
365 self.inner.pending.lock().await.remove(&request_id);
366 match result {
367 Ok(Ok(payload)) => Ok(payload),
368 _ => Err(ClientError::ResponseTimeout),
369 }
370 }
371
372 async fn enqueue_or_send(&self, message: ClientMessage) {
373 let connected = *self.inner.client.state.borrow() == ConnectionState::Connected;
374 let joined = self.state() == ChannelState::Joined;
375 let is_join = message.action == ClientAction::JoinChannel;
376 if connected && (joined || is_join) {
377 if self.inner.client.publish(message.clone()).await.is_ok() {
378 return;
379 }
380 }
381 let mut queue = self.inner.queue.lock().await;
382 if queue.len() == self.inner.client.options.max_queue_size {
383 queue.pop_front();
384 }
385 queue.push_back(message);
386 }
387
388 async fn handle_event(&self, event: ChannelEvent) {
389 if *self.inner.closed.lock().await {
390 return;
391 }
392 match event {
393 ChannelEvent::Presence(message) => self.handle_presence(message).await,
394 ChannelEvent::Message(message) => self.handle_message(message).await,
395 }
396 }
397
398 async fn handle_presence(&self, message: PresenceMessage) {
399 *self.inner.presence.lock().await = message.payload.presence.clone();
400 let event = ChannelEvent::Presence(message.clone());
401 let _ = self.inner.events.send(event);
402 }
403
404 async fn handle_message(&self, message: ServerMessage) {
405 if message.action == ServerAction::System
406 && message.event == event_name(EventName::Acknowledge)
407 {
408 self.acknowledge().await;
409 return;
410 }
411 if message.action == ServerAction::System
412 && message.event == event_name(EventName::Unauthorized)
413 {
414 self.decline().await;
415 return;
416 }
417 if let Some(tx) = self.inner.pending.lock().await.remove(&message.request_id) {
418 let _ = tx.send(message.payload);
419 return;
420 }
421 if self.state() == ChannelState::Joined {
422 let _ = self.inner.events.send(ChannelEvent::Message(message));
423 }
424 }
425
426 async fn acknowledge(&self) {
427 if self.state() != ChannelState::Joined {
428 self.inner.state.send_replace(ChannelState::Joined);
429 }
430 let mut queue = self.inner.queue.lock().await;
431 let pending: Vec<ClientMessage> = queue.drain(..).collect();
432 drop(queue);
433 for message in pending {
434 let _ = self.inner.client.publish(message).await;
435 }
436 }
437
438 async fn decline(&self) {
439 self.inner.state.send_replace(ChannelState::Declined);
440 self.inner.queue.lock().await.clear();
441 self.inner.pending.lock().await.clear();
442 }
443
444 async fn force_close(&self) {
445 *self.inner.closed.lock().await = true;
446 self.inner.state.send_replace(ChannelState::Closed);
447 self.inner.queue.lock().await.clear();
448 self.inner.pending.lock().await.clear();
449 }
450
451 fn join_message(&self) -> ClientMessage {
452 ClientMessage {
453 action: ClientAction::JoinChannel,
454 event: "JOIN_CHANNEL".to_owned(),
455 payload: self.inner.params.clone(),
456 channel_name: self.inner.name.clone(),
457 request_id: uuid(),
458 }
459 }
460}
461
462fn resolve_url(endpoint: &str, params: Option<&JoinParams>) -> Result<String> {
463 let mut url = Url::parse(endpoint)?;
464 match url.scheme() {
465 "http" => url
466 .set_scheme("ws")
467 .map_err(|_| ClientError::UnsupportedScheme("http".to_owned()))?,
468 "https" => url
469 .set_scheme("wss")
470 .map_err(|_| ClientError::UnsupportedScheme("https".to_owned()))?,
471 "ws" | "wss" => {}
472 scheme => return Err(ClientError::UnsupportedScheme(scheme.to_owned())),
473 }
474 if let Some(params) = params {
475 let mut pairs = url.query_pairs_mut();
476 for (key, value) in params {
477 let value = match value {
478 Value::String(value) => value.clone(),
479 other => other.to_string(),
480 };
481 pairs.append_pair(key, &value);
482 }
483 }
484 Ok(url.to_string())
485}
486
487fn event_name(event: EventName) -> String {
488 serde_json::to_string(&event)
489 .unwrap_or_default()
490 .trim_matches('"')
491 .to_owned()
492}
493
494#[allow(dead_code)]
495fn presence_event_name(event: PresenceEventType) -> String {
496 serde_json::to_string(&event)
497 .unwrap_or_default()
498 .trim_matches('"')
499 .to_owned()
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use pondsocket_common::{PondEvent, PondSchema, PresencePayload};
    use serde::{Deserialize, Serialize};

    #[test]
    fn resolves_http_url_to_ws_with_params() {
        let mut params = JoinParams::new();
        params.insert("token".to_owned(), Value::String("abc".to_owned()));
        let url = resolve_url("https://example.com/socket?room=one", Some(&params)).unwrap();
        assert_eq!(url, "wss://example.com/socket?room=one&token=abc");
    }

    #[test]
    fn rejects_unsupported_scheme() {
        let err = resolve_url("ftp://example.com/socket", None).unwrap_err();
        assert!(matches!(err, ClientError::UnsupportedScheme(ref scheme) if scheme == "ftp"));
    }

    #[tokio::test]
    async fn queues_join_message_before_connect() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client.create_channel("room", None).await;
        channel.join().await;
        assert_eq!(channel.state(), ChannelState::Joining);
        assert_eq!(channel.inner.queue.lock().await.len(), 1);
    }

    #[tokio::test]
    async fn create_channel_reuses_open_channel() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let first = client.create_channel("room", None).await;
        let second = client.create_channel("room", None).await;
        assert!(Arc::ptr_eq(&first.inner, &second.inner));
    }

    #[tokio::test]
    async fn leave_closes_channel_and_rejects_requests() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client.create_channel("room", None).await;
        channel.leave().await;
        assert_eq!(channel.state(), ChannelState::Closed);
        channel.join().await;
        assert_eq!(channel.state(), ChannelState::Closed);
        let result = channel.send_for_response("ping", None, None).await;
        assert!(matches!(result, Err(ClientError::ChannelClosed)));
    }

    #[tokio::test]
    async fn disconnect_closes_registered_channels() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client.create_channel("room", None).await;
        client.disconnect().await;
        assert_eq!(client.state(), ConnectionState::Disconnected);
        assert_eq!(channel.state(), ChannelState::Closed);
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct ChatPayload {
        text: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AckPayload {
        ok: bool,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Presence {
        user_id: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Assigns {
        role: String,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Join {
        token: String,
    }

    struct Chat;
    struct ChatSchema;

    impl PondEvent for Chat {
        type Payload = ChatPayload;
        type Response = AckPayload;

        const NAME: &'static str = "chat";
    }

    impl PondSchema for ChatSchema {
        type Presence = Presence;
        type Assigns = Assigns;
        type JoinParams = Join;
    }

    #[tokio::test]
    async fn typed_channel_sends_and_decodes_schema_values() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let params = Join {
            token: "secret".to_owned(),
        };
        let channel = client
            .create_typed_channel::<ChatSchema>("room", Some(&params))
            .await
            .unwrap();

        channel.join().await;
        channel
            .send::<Chat>(&ChatPayload {
                text: "hello".to_owned(),
            })
            .await
            .unwrap();

        let queued = channel.raw().inner.queue.lock().await;
        assert_eq!(queued[0].payload["token"], "secret");
        assert_eq!(queued[1].event, "chat");
        assert_eq!(queued[1].payload["text"], "hello");
        drop(queued);

        let message = ServerMessage {
            action: ServerAction::Broadcast,
            event: "chat".to_owned(),
            channel_name: "room".to_owned(),
            request_id: "r1".to_owned(),
            payload: serde_json::from_value(serde_json::json!({ "text": "from server" })).unwrap(),
        };
        assert_eq!(
            channel.decode_message::<Chat>(&message).unwrap(),
            Some(ChatPayload {
                text: "from server".to_owned()
            })
        );

        let presence = PresenceMessage {
            action: pondsocket_common::PresenceAction::Presence,
            event: PresenceEventType::Join,
            channel_name: "room".to_owned(),
            request_id: "p1".to_owned(),
            payload: PresencePayload {
                changed: serde_json::from_value(serde_json::json!({ "user_id": "u1" })).unwrap(),
                presence: vec![
                    serde_json::from_value(serde_json::json!({ "user_id": "u1" })).unwrap(),
                ],
            },
        };
        let (changed, users) = channel.decode_presence(&presence).unwrap();
        assert_eq!(
            changed,
            Presence {
                user_id: "u1".to_owned()
            }
        );
        assert_eq!(users, vec![changed]);
    }
}