Skip to main content

realtime/client/
api.rs

1use std::{
2    collections::HashMap,
3    sync::{
4        Arc,
5        atomic::{AtomicU64, Ordering},
6    },
7    time::Duration,
8};
9
10use futures_util::{
11    SinkExt, StreamExt,
12    stream::{SplitSink, SplitStream},
13};
14use serde_json::Value;
15use tokio::net::TcpStream;
16use tokio::{
17    sync::{Mutex, mpsc, oneshot},
18    time::{interval, timeout},
19};
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
21use uuid::Uuid;
22
23use crate::protocol::{ClientFrame, DEFAULT_EVENT, ErrorPayload, ServerFrame};
24
25use super::ClientConfig;
26
27pub type ClientResult<T> = std::result::Result<T, String>;
28pub type SubscriptionId = u64;
29type ChannelHandler = Arc<dyn Fn(Value) + Send + Sync>;
30type GlobalHandler = Arc<dyn Fn(String, Value) + Send + Sync>;
31type ChannelEventHandler = Arc<dyn Fn(String, Value) + Send + Sync>;
32type GlobalEventHandler = Arc<dyn Fn(String, String, Value) + Send + Sync>;
33type PendingAcks = Arc<Mutex<HashMap<String, oneshot::Sender<ClientResult<()>>>>>;
34type ChannelHandlers =
35    Arc<std::sync::Mutex<HashMap<String, HashMap<SubscriptionId, ChannelHandler>>>>;
36type GlobalHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalHandler>>>;
37type ChannelEventHandlers =
38    Arc<std::sync::Mutex<HashMap<String, HashMap<SubscriptionId, ChannelEventHandler>>>>;
39type GlobalEventHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalEventHandler>>>;
40type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
41type WsWriter = SplitSink<WsStream, Message>;
42type WsReader = SplitStream<WsStream>;
43
44#[derive(Clone)]
45pub struct RealtimeClient {
46    outbound_tx: mpsc::Sender<ClientFrame>,
47    pending_acks: PendingAcks,
48    channel_handlers: ChannelHandlers,
49    global_handlers: GlobalHandlers,
50    channel_event_handlers: ChannelEventHandlers,
51    global_event_handlers: GlobalEventHandlers,
52    next_subscription_id: Arc<AtomicU64>,
53    cfg: ClientConfig,
54}
55
56impl RealtimeClient {
57    pub async fn connect(base_url: &str, token: &str) -> ClientResult<Self> {
58        Self::connect_with_config(base_url, token, ClientConfig::default()).await
59    }
60
61    pub async fn connect_with_config(
62        base_url: &str,
63        token: &str,
64        cfg: ClientConfig,
65    ) -> ClientResult<Self> {
66        let ws = Self::open_socket(base_url, token).await?;
67        let (write, read) = ws.split();
68        let (outbound_tx, outbound_rx) = mpsc::channel::<ClientFrame>(cfg.outbound_buffer);
69
70        let pending_acks: PendingAcks = Arc::new(Mutex::new(HashMap::new()));
71        let channel_handlers: ChannelHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
72        let global_handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
73        let channel_event_handlers: ChannelEventHandlers =
74            Arc::new(std::sync::Mutex::new(HashMap::new()));
75        let global_event_handlers: GlobalEventHandlers =
76            Arc::new(std::sync::Mutex::new(HashMap::new()));
77
78        Self::spawn_writer_task(write, outbound_rx);
79        Self::spawn_reader_task(
80            read,
81            Arc::clone(&pending_acks),
82            Arc::clone(&channel_handlers),
83            Arc::clone(&global_handlers),
84            Arc::clone(&channel_event_handlers),
85            Arc::clone(&global_event_handlers),
86        );
87        Self::spawn_ping_task(outbound_tx.clone(), cfg.ping_interval);
88
89        Ok(Self {
90            outbound_tx,
91            pending_acks,
92            channel_handlers,
93            global_handlers,
94            channel_event_handlers,
95            global_event_handlers,
96            next_subscription_id: Arc::new(AtomicU64::new(1)),
97            cfg,
98        })
99    }
100
101    pub async fn join(&self, channel: &str) -> ClientResult<()> {
102        self.request_ack(
103            ClientFrame::ChannelJoin {
104                id: Uuid::new_v4().to_string(),
105                channel: channel.to_string(),
106                ts: None,
107            },
108            self.cfg.request_timeout,
109        )
110        .await
111    }
112
113    pub async fn leave(&self, channel: &str) -> ClientResult<()> {
114        self.request_ack(
115            ClientFrame::ChannelLeave {
116                id: Uuid::new_v4().to_string(),
117                channel: channel.to_string(),
118                ts: None,
119            },
120            self.cfg.request_timeout,
121        )
122        .await
123    }
124
125    pub async fn send(&self, channel: &str, message: Value) -> ClientResult<()> {
126        self.send_event(channel, DEFAULT_EVENT, message).await
127    }
128
129    pub async fn send_event(&self, channel: &str, event: &str, message: Value) -> ClientResult<()> {
130        self.request_ack(
131            ClientFrame::ChannelEmit {
132                id: Uuid::new_v4().to_string(),
133                channel: channel.to_string(),
134                event: event.to_string(),
135                data: message,
136                ts: None,
137            },
138            self.cfg.request_timeout,
139        )
140        .await
141    }
142
143    pub fn on_message<F>(&self, channel: &str, handler: F) -> SubscriptionId
144    where
145        F: Fn(Value) + Send + Sync + 'static,
146    {
147        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
148        let mut guard = self
149            .channel_handlers
150            .lock()
151            .expect("channel handler mutex poisoned");
152        guard
153            .entry(channel.to_string())
154            .or_default()
155            .insert(id, Arc::new(handler));
156        id
157    }
158
159    pub fn on_messages<F>(&self, handler: F) -> SubscriptionId
160    where
161        F: Fn(String, Value) + Send + Sync + 'static,
162    {
163        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
164        self.global_handlers
165            .lock()
166            .expect("global handler mutex poisoned")
167            .insert(id, Arc::new(handler));
168        id
169    }
170
171    pub fn on_channel_event<F>(&self, channel: &str, handler: F) -> SubscriptionId
172    where
173        F: Fn(String, Value) + Send + Sync + 'static,
174    {
175        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
176        let mut guard = self
177            .channel_event_handlers
178            .lock()
179            .expect("channel event handler mutex poisoned");
180        guard
181            .entry(channel.to_string())
182            .or_default()
183            .insert(id, Arc::new(handler));
184        id
185    }
186
187    pub fn on_events<F>(&self, handler: F) -> SubscriptionId
188    where
189        F: Fn(String, String, Value) + Send + Sync + 'static,
190    {
191        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
192        self.global_event_handlers
193            .lock()
194            .expect("global event handler mutex poisoned")
195            .insert(id, Arc::new(handler));
196        id
197    }
198
199    pub fn off(&self, id: SubscriptionId) -> bool {
200        let mut removed = false;
201
202        let mut global = self
203            .global_handlers
204            .lock()
205            .expect("global handler mutex poisoned");
206        if global.remove(&id).is_some() {
207            removed = true;
208        }
209        drop(global);
210
211        let mut channels = self
212            .channel_handlers
213            .lock()
214            .expect("channel handler mutex poisoned");
215        for handlers in channels.values_mut() {
216            if handlers.remove(&id).is_some() {
217                removed = true;
218            }
219        }
220
221        let mut global_events = self
222            .global_event_handlers
223            .lock()
224            .expect("global event handler mutex poisoned");
225        if global_events.remove(&id).is_some() {
226            removed = true;
227        }
228        drop(global_events);
229
230        let mut channel_events = self
231            .channel_event_handlers
232            .lock()
233            .expect("channel event handler mutex poisoned");
234        for handlers in channel_events.values_mut() {
235            if handlers.remove(&id).is_some() {
236                removed = true;
237            }
238        }
239
240        removed
241    }
242
243    async fn open_socket(base_url: &str, token: &str) -> ClientResult<WsStream> {
244        let url = with_query_token(base_url, token);
245        let (ws, _) = connect_async(&url)
246            .await
247            .map_err(|err| format!("failed to connect to {url}: {err}"))?;
248        Ok(ws)
249    }
250
251    fn spawn_writer_task(mut write: WsWriter, mut outbound_rx: mpsc::Receiver<ClientFrame>) {
252        tokio::spawn(async move {
253            while let Some(frame) = outbound_rx.recv().await {
254                let text = match serde_json::to_string(&frame) {
255                    Ok(text) => text,
256                    Err(err) => {
257                        eprintln!("failed to serialize outbound frame: {err}");
258                        continue;
259                    }
260                };
261
262                if write.send(Message::Text(text.into())).await.is_err() {
263                    break;
264                }
265            }
266        });
267    }
268
269    fn spawn_reader_task(
270        mut read: WsReader,
271        pending_acks: PendingAcks,
272        channel_handlers: ChannelHandlers,
273        global_handlers: GlobalHandlers,
274        channel_event_handlers: ChannelEventHandlers,
275        global_event_handlers: GlobalEventHandlers,
276    ) {
277        tokio::spawn(async move {
278            while let Some(next) = read.next().await {
279                let msg = match next {
280                    Ok(msg) => msg,
281                    Err(err) => {
282                        eprintln!("websocket read error: {err}");
283                        break;
284                    }
285                };
286
287                let keep_reading = Self::handle_incoming_message(
288                    msg,
289                    &pending_acks,
290                    &channel_handlers,
291                    &global_handlers,
292                    &channel_event_handlers,
293                    &global_event_handlers,
294                )
295                .await;
296                if !keep_reading {
297                    break;
298                }
299            }
300
301            Self::fail_pending_acks(&pending_acks).await;
302        });
303    }
304
305    async fn handle_incoming_message(
306        msg: Message,
307        pending_acks: &PendingAcks,
308        channel_handlers: &ChannelHandlers,
309        global_handlers: &GlobalHandlers,
310        channel_event_handlers: &ChannelEventHandlers,
311        global_event_handlers: &GlobalEventHandlers,
312    ) -> bool {
313        let text = match msg {
314            Message::Text(text) => text,
315            Message::Close(_) => return false,
316            _ => return true,
317        };
318
319        let frame = match serde_json::from_str::<ServerFrame>(&text) {
320            Ok(frame) => frame,
321            Err(err) => {
322                eprintln!("invalid server frame: {err}");
323                return true;
324            }
325        };
326
327        Self::handle_server_frame(
328            frame,
329            pending_acks,
330            channel_handlers,
331            global_handlers,
332            channel_event_handlers,
333            global_event_handlers,
334        )
335        .await;
336        true
337    }
338
339    async fn handle_server_frame(
340        frame: ServerFrame,
341        pending_acks: &PendingAcks,
342        channel_handlers: &ChannelHandlers,
343        global_handlers: &GlobalHandlers,
344        channel_event_handlers: &ChannelEventHandlers,
345        global_event_handlers: &GlobalEventHandlers,
346    ) {
347        match frame {
348            ServerFrame::Connected {
349                conn_id, user_id, ..
350            } => {
351                println!("connected: conn_id={conn_id} user_id={user_id}");
352            }
353            ServerFrame::Joined { channel, .. } => {
354                println!("joined channel={channel}");
355            }
356            ServerFrame::Left { channel, .. } => {
357                println!("left channel={channel}");
358            }
359            ServerFrame::Event {
360                channel,
361                event,
362                data,
363                ..
364            } => {
365                dispatch_channel_handlers(channel_handlers, &channel, &data);
366                dispatch_global_handlers(global_handlers, &channel, &data);
367                dispatch_channel_event_handlers(channel_event_handlers, &channel, &event, &data);
368                dispatch_global_event_handlers(global_event_handlers, &channel, &event, &data);
369            }
370            ServerFrame::Ack {
371                for_id, ok, error, ..
372            } => {
373                Self::resolve_ack(pending_acks, for_id, ok, error).await;
374            }
375            ServerFrame::Pong { id, .. } => {
376                println!("pong id={id}");
377            }
378            ServerFrame::Error { error, .. } => {
379                eprintln!("server error {}: {}", error.code, error.message);
380            }
381        }
382    }
383
384    async fn resolve_ack(
385        pending_acks: &PendingAcks,
386        for_id: String,
387        ok: bool,
388        error: Option<ErrorPayload>,
389    ) {
390        let Some(tx) = pending_acks.lock().await.remove(&for_id) else {
391            return;
392        };
393
394        let result = if ok {
395            Ok(())
396        } else {
397            let message = error
398                .map(|e| format!("{}: {}", e.code, e.message))
399                .unwrap_or_else(|| "request rejected".to_string());
400            Err(message)
401        };
402        let _ = tx.send(result);
403    }
404
405    async fn fail_pending_acks(pending_acks: &PendingAcks) {
406        let mut pending = pending_acks.lock().await;
407        for (_, tx) in pending.drain() {
408            let _ = tx.send(Err("websocket connection closed".to_string()));
409        }
410    }
411
412    fn spawn_ping_task(outbound_tx: mpsc::Sender<ClientFrame>, ping_interval: Duration) {
413        tokio::spawn(async move {
414            let mut ticker = interval(ping_interval);
415            loop {
416                ticker.tick().await;
417                if outbound_tx
418                    .send(ClientFrame::Ping {
419                        id: Uuid::new_v4().to_string(),
420                        ts: None,
421                    })
422                    .await
423                    .is_err()
424                {
425                    break;
426                }
427            }
428        });
429    }
430
431    async fn request_ack(&self, frame: ClientFrame, timeout_dur: Duration) -> ClientResult<()> {
432        let req_id = frame_id(&frame).to_string();
433        let (tx, rx) = oneshot::channel();
434        self.pending_acks.lock().await.insert(req_id.clone(), tx);
435
436        if let Err(err) = self.outbound_tx.send(frame).await {
437            self.pending_acks.lock().await.remove(&req_id);
438            return Err(format!("failed to send request: {err}"));
439        }
440
441        match timeout(timeout_dur, rx).await {
442            Ok(Ok(result)) => result,
443            Ok(Err(_)) => Err("ack wait channel dropped".to_string()),
444            Err(_) => {
445                self.pending_acks.lock().await.remove(&req_id);
446                Err(format!("ack timeout for request {req_id}"))
447            }
448        }
449    }
450}
451
452fn dispatch_channel_handlers(handlers: &ChannelHandlers, channel: &str, message: &Value) {
453    let callbacks: Vec<ChannelHandler> = {
454        let guard = handlers.lock().expect("channel handler mutex poisoned");
455        guard
456            .get(channel)
457            .map(|entries| entries.values().cloned().collect())
458            .unwrap_or_default()
459    };
460
461    for callback in callbacks {
462        callback(message.clone());
463    }
464}
465
466fn dispatch_global_handlers(handlers: &GlobalHandlers, channel: &str, message: &Value) {
467    let callbacks: Vec<GlobalHandler> = {
468        let guard = handlers.lock().expect("global handler mutex poisoned");
469        guard.values().cloned().collect()
470    };
471
472    for callback in callbacks {
473        callback(channel.to_string(), message.clone());
474    }
475}
476
477fn dispatch_channel_event_handlers(
478    handlers: &ChannelEventHandlers,
479    channel: &str,
480    event: &str,
481    message: &Value,
482) {
483    let callbacks: Vec<ChannelEventHandler> = {
484        let guard = handlers
485            .lock()
486            .expect("channel event handler mutex poisoned");
487        guard
488            .get(channel)
489            .map(|entries| entries.values().cloned().collect())
490            .unwrap_or_default()
491    };
492
493    for callback in callbacks {
494        callback(event.to_string(), message.clone());
495    }
496}
497
498fn dispatch_global_event_handlers(
499    handlers: &GlobalEventHandlers,
500    channel: &str,
501    event: &str,
502    message: &Value,
503) {
504    let callbacks: Vec<GlobalEventHandler> = {
505        let guard = handlers
506            .lock()
507            .expect("global event handler mutex poisoned");
508        guard.values().cloned().collect()
509    };
510
511    for callback in callbacks {
512        callback(channel.to_string(), event.to_string(), message.clone());
513    }
514}
515
516fn frame_id(frame: &ClientFrame) -> &str {
517    match frame {
518        ClientFrame::ChannelJoin { id, .. } => id,
519        ClientFrame::ChannelLeave { id, .. } => id,
520        ClientFrame::ChannelEmit { id, .. } => id,
521        ClientFrame::Ping { id, .. } => id,
522    }
523}
524
/// Appends `token` to `base_url` as a percent-encoded `token` query parameter.
///
/// * Uses `?` or `&` depending on whether `base_url` already has a query. No
///   separator is added when the URL already ends with `?` or `&`.
/// * Keeps any `#fragment` at the end, so the token lands in the query instead
///   of being swallowed by the fragment.
/// * Percent-encodes every byte outside RFC 3986's unreserved set, so tokens
///   containing `+`, `&`, `=`, `/`, spaces or non-ASCII survive intact.
///   Tokens made only of unreserved characters (e.g. base64url/JWT) are
///   emitted unchanged.
fn with_query_token(base_url: &str, token: &str) -> String {
    let (base, fragment) = match base_url.split_once('#') {
        Some((base, fragment)) => (base, Some(fragment)),
        None => (base_url, None),
    };
    let separator = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        ""
    } else {
        "&"
    };

    let mut url =
        String::with_capacity(base_url.len() + separator.len() + "token=".len() + 3 * token.len());
    url.push_str(base);
    url.push_str(separator);
    url.push_str("token=");
    push_percent_encoded(&mut url, token);
    if let Some(fragment) = fragment {
        url.push('#');
        url.push_str(fragment);
    }
    url
}

/// Appends `raw` to `out`, percent-encoding every byte that is not unreserved.
///
/// Encoding uses uppercase hex. The unreserved set (RFC 3986) is
/// `A-Z a-z 0-9 - . _ ~`.
fn push_percent_encoded(out: &mut String, raw: &str) {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
}