realtime_rs/
realtime_channel.rs

1use crate::realtime_client::ClientManager;
2use crate::realtime_client::ClientManagerSync;
3use crate::realtime_presence::PresenceCallbackMap;
4use crate::realtime_presence::RealtimePresence;
5use crate::Responder;
6
7use serde_json::Value;
8use tokio::{
9    runtime::Runtime,
10    sync::{
11        mpsc::{self, error::SendError, UnboundedReceiver, UnboundedSender},
12        oneshot::{self, error::RecvError},
13        Mutex,
14    },
15    task::JoinHandle,
16};
17use uuid::Uuid;
18
19use crate::message::{
20    payload::{
21        AccessTokenPayload, BroadcastConfig, BroadcastPayload, JoinConfig, JoinPayload, Payload,
22        PayloadStatus, PostgresChange, PostgresChangesEvent, PostgresChangesPayload,
23        PresenceConfig,
24    },
25    presence::{PresenceEvent, PresenceState},
26    MessageEvent, PostgresChangeFilter, RealtimeMessage,
27};
28
29use std::fmt::Debug;
30use std::{collections::HashMap, sync::Arc};
31
32type CdcCallback = (
33    PostgresChangeFilter,
34    Box<dyn FnMut(&PostgresChangesPayload) + Send>,
35);
36type BroadcastCallback = Box<dyn FnMut(&HashMap<String, Value>) + Send>;
37pub(crate) type PresenceCallback = Box<dyn Fn(String, PresenceState, PresenceState) + Send>;
38
/// Lifecycle state of a realtime channel.
///
/// Derives `Eq` and `Hash` in addition to comparison/copy so states can be stored in sets
/// or used as map keys.
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]
pub enum ChannelState {
    /// Not joined; the state of a freshly built channel.
    Closed,
    /// The channel is in an error state.
    Errored,
    /// The server answered the join request with an OK status.
    Joined,
    /// A join request has been sent and no successful reply has been seen yet.
    Joining,
    /// A leave request has been sent; further sends on the channel are refused.
    Leaving,
}
48
49/// Error for channel send failures
50#[derive(Debug)]
51pub enum ChannelSendError {
52    NoChannel,
53    SendError(SendError<RealtimeMessage>),
54    ChannelError(ChannelState),
55}
56
57pub(crate) enum ChannelManagerMessage {
58    Subscribe,
59    Unsubscribe {
60        res: Responder<Result<ChannelState, ChannelSendError>>,
61    },
62    SubscribeBlocking {
63        res: Responder<()>,
64    },
65    Broadcast {
66        payload: BroadcastPayload,
67    },
68    ClientTx {
69        new_tx: UnboundedSender<RealtimeMessage>,
70        res: Responder<()>,
71    },
72    GetState {
73        res: Responder<ChannelState>,
74    },
75    GetTx {
76        res: Responder<UnboundedSender<RealtimeMessage>>,
77    },
78    GetTopic {
79        res: Responder<String>,
80    },
81    GetPresenceState {
82        res: Responder<PresenceState>,
83    },
84    PresenceTrack {
85        payload: HashMap<String, Value>,
86        res: Responder<()>,
87    },
88    PresenceUntrack {
89        res: Responder<()>,
90    },
91    ReAuth {
92        res: Responder<()>,
93    },
94}
95
96/// Manager struct for a [RealtimeChannel]
97///
98/// Returned by [RealtimeChannelBuilder::build()]
99// TODO code example showing creation
100#[derive(Clone, Debug)]
101pub struct ChannelManager {
102    pub(crate) tx: UnboundedSender<ChannelManagerMessage>,
103    rt: Arc<Runtime>,
104}
105
106impl ChannelManager {
107    /// Send a JoinMessage for the channel
108    pub fn subscribe(&self) {
109        let _ = self.send(ChannelManagerMessage::Subscribe);
110    }
111    /// Leave the channel and stop recieving messages
112    ///
113    /// Once unsubscribed this manager is useless and should be dropped
114    // TODO return a preconfigured ChannelBuilder here, like client does?
115    pub async fn unsubscribe(&self) -> Result<Result<ChannelState, ChannelSendError>, RecvError> {
116        let (tx, rx) = oneshot::channel();
117        let _ = self.send(ChannelManagerMessage::Unsubscribe { res: tx });
118        rx.await
119    }
120    /// Send a JoinMessage for the channel and wait until the server has responded
121    pub async fn subscribe_blocking(&self) -> Result<(), oneshot::error::RecvError> {
122        let (tx, rx) = oneshot::channel();
123        let _ = self.send(ChannelManagerMessage::SubscribeBlocking { res: tx });
124        rx.await
125    }
126    /// Send a broadcast message on the channel
127    pub fn broadcast(&self, payload: BroadcastPayload) {
128        let _ = self.send(ChannelManagerMessage::Broadcast { payload });
129    }
130    /// Track data in Presence
131    pub async fn track(&self, payload: HashMap<String, Value>) -> Result<(), RecvError> {
132        let (tx, rx) = oneshot::channel();
133        let _ = self.send(ChannelManagerMessage::PresenceTrack { payload, res: tx });
134        rx.await
135    }
136    /// Stop tracking with Presence
137    pub async fn untrack(&self) -> Result<(), RecvError> {
138        let (tx, rx) = oneshot::channel();
139        let _ = self.send(ChannelManagerMessage::PresenceUntrack { res: tx });
140        rx.await
141    }
142    /// Returns the [ChannelState] for the associated channel
143    pub async fn get_state(&self) -> Result<ChannelState, RecvError> {
144        let (tx, rx) = oneshot::channel();
145        let _ = self.send(ChannelManagerMessage::GetState { res: tx });
146        rx.await
147    }
148    /// Returns the associated channel's topic
149    pub async fn get_topic(&self) -> String {
150        let (tx, rx) = oneshot::channel();
151        let _ = self.send(ChannelManagerMessage::GetTopic { res: tx });
152        rx.await.unwrap()
153    }
154    /// Returns the current [PresenceState] of the associated channel
155    pub async fn get_presence_state(&self) -> PresenceState {
156        let (tx, rx) = oneshot::channel();
157        let _ = self.send(ChannelManagerMessage::GetPresenceState { res: tx });
158        rx.await.unwrap()
159    }
160    /// Return a sync wrapper [ChannelManagerSync] for this manager
161    pub fn to_sync(self) -> ChannelManagerSync {
162        ChannelManagerSync { inner: self }
163    }
164    pub(crate) fn send(
165        &self,
166        message: ChannelManagerMessage,
167    ) -> Result<(), SendError<ChannelManagerMessage>> {
168        self.tx.send(message)
169    }
170    pub(crate) async fn reauth(&self) -> Result<(), RecvError> {
171        let (tx, rx) = oneshot::channel();
172        let _ = self.send(ChannelManagerMessage::ReAuth { res: tx });
173        rx.await
174    }
175    pub(crate) async fn get_tx(&self) -> UnboundedSender<RealtimeMessage> {
176        let (tx, rx) = oneshot::channel();
177        let _ = self.send(ChannelManagerMessage::GetTx { res: tx });
178        rx.await.unwrap()
179    }
180}
181
182#[derive(Clone)]
183pub struct ChannelManagerSync {
184    inner: ChannelManager,
185}
186
187impl ChannelManagerSync {
188    pub fn subscribe(&self) {
189        self.inner.subscribe()
190    }
191    pub fn unsubscribe(&self) -> Result<Result<ChannelState, ChannelSendError>, RecvError> {
192        self.inner.rt.block_on(self.inner.unsubscribe())
193    }
194    pub fn subscribe_blocking(&self) -> Result<(), RecvError> {
195        self.inner.rt.block_on(self.inner.subscribe_blocking())
196    }
197    pub fn broadcast(&self, payload: BroadcastPayload) {
198        self.inner.broadcast(payload)
199    }
200    /// Returns the associated channel's topic
201    pub fn get_topic(&self) -> String {
202        self.inner.rt.block_on(self.inner.get_topic())
203    }
204    pub fn get_state(&self) -> Result<ChannelState, RecvError> {
205        self.inner.rt.block_on(self.inner.get_state())
206    }
207    /// Returns the current [PresenceState] of the associated channel
208    pub fn get_presence_state(&self) -> PresenceState {
209        self.inner.rt.block_on(self.inner.get_presence_state())
210    }
211    pub fn track(&self, payload: HashMap<String, Value>) -> Result<(), RecvError> {
212        self.inner.rt.block_on(self.inner.track(payload))
213    }
214    pub fn untrack(&self) -> Result<(), RecvError> {
215        self.inner.rt.block_on(self.inner.untrack())
216    }
217    /// Unwrap the inner [ChannelManager]. Consumes self.
218    pub fn to_async(self) -> ChannelManager {
219        self.inner
220    }
221}
222
223impl<'a> FromIterator<&'a mut ChannelManager> for Vec<ChannelManager> {
224    fn from_iter<T: IntoIterator<Item = &'a mut ChannelManager>>(iter: T) -> Self {
225        let mut vec = Vec::new();
226        for c in iter {
227            vec.push(c.clone());
228        }
229        vec
230    }
231}
232
233struct RealtimeChannel {
234    pub(crate) topic: String,
235    pub(crate) state: Arc<Mutex<ChannelState>>,
236    pub(crate) id: Uuid,
237    pub(crate) cdc_callbacks: Arc<Mutex<HashMap<PostgresChangesEvent, Vec<CdcCallback>>>>,
238    pub(crate) broadcast_callbacks: Arc<Mutex<HashMap<String, Vec<BroadcastCallback>>>>,
239    pub(crate) client_tx: mpsc::UnboundedSender<RealtimeMessage>,
240    join_payload: JoinPayload,
241    presence: Arc<Mutex<RealtimePresence>>,
242    pub(crate) tx: Option<UnboundedSender<RealtimeMessage>>,
243    pub(crate) manager_channel: (
244        UnboundedSender<ChannelManagerMessage>,
245        UnboundedReceiver<ChannelManagerMessage>,
246    ),
247    pub(crate) message_handle: Option<JoinHandle<()>>,
248    rt: Arc<Runtime>,
249    access_token: Arc<Mutex<String>>,
250}
251
252impl RealtimeChannel {
253    async fn manager_recv(&mut self) {
254        while let Some(control_message) = self.manager_channel.1.recv().await {
255            match control_message {
256                ChannelManagerMessage::Subscribe => {
257                    self.subscribe().await;
258                }
259                ChannelManagerMessage::Unsubscribe { res } => {
260                    res.send(self.unsubscribe().await).unwrap();
261                }
262                ChannelManagerMessage::SubscribeBlocking { res } => {
263                    self.subscribe_blocking(res).await;
264                }
265                ChannelManagerMessage::Broadcast { payload } => {
266                    self.broadcast(payload).await.unwrap();
267                }
268                ChannelManagerMessage::ClientTx { new_tx, res } => {
269                    self.client_tx = new_tx;
270                    res.send(()).unwrap();
271                }
272                ChannelManagerMessage::GetState { res } => {
273                    res.send(*self.state.lock().await).unwrap();
274                }
275                ChannelManagerMessage::GetTx { res } => {
276                    res.send(self.tx.clone().unwrap()).unwrap();
277                }
278                ChannelManagerMessage::GetTopic { res } => {
279                    res.send(self.topic.clone()).unwrap();
280                }
281                ChannelManagerMessage::PresenceTrack { payload, res } => {
282                    self.track(payload).await.unwrap();
283                    res.send(()).unwrap();
284                }
285                ChannelManagerMessage::PresenceUntrack { res } => {
286                    self.untrack().await.unwrap();
287                    res.send(()).unwrap()
288                }
289                ChannelManagerMessage::GetPresenceState { res } => {
290                    let presence = self.presence.lock().await;
291                    res.send(presence.state.clone()).unwrap();
292                }
293                ChannelManagerMessage::ReAuth { res } => {
294                    self.reauth().await.unwrap();
295                    res.send(()).unwrap();
296                }
297            }
298        }
299    }
300
301    /// Send a join request to the channel
302    async fn subscribe(&mut self) {
303        let join_message = RealtimeMessage {
304            event: MessageEvent::PhxJoin,
305            topic: self.topic.clone(),
306            payload: Payload::Join(self.join_payload.clone()),
307            message_ref: Some(self.id.into()),
308        };
309
310        let mut state = self.state.lock().await;
311        *state = ChannelState::Joining;
312        drop(state);
313
314        let _ = self.send(join_message).await;
315    }
316
317    async fn subscribe_blocking(&mut self, tx: Responder<()>) {
318        self.subscribe().await;
319
320        let state = self.state.clone();
321
322        self.rt.spawn(async move {
323            loop {
324                let state = state.lock().await;
325                if *state == ChannelState::Joined {
326                    break;
327                }
328            }
329            tx.send(()).unwrap();
330        });
331    }
332
333    fn client_recv(&mut self) {
334        let (channel_tx, mut channel_rx) = mpsc::unbounded_channel::<RealtimeMessage>();
335        self.tx = Some(channel_tx);
336        let task_state = self.state.clone();
337        let task_cdc_cbs = self.cdc_callbacks.clone();
338        let task_bc_cbs = self.broadcast_callbacks.clone();
339        let id = self.id;
340        let presence = self.presence.clone();
341
342        self.message_handle = Some(self.rt.spawn(async move {
343            while let Some(message) = channel_rx.recv().await {
344                // get locks
345                let mut broadcast_callbacks = task_bc_cbs.lock().await;
346                let mut cdc_callbacks = task_cdc_cbs.lock().await;
347
348                match message.payload {
349                    Payload::Broadcast(payload) => {
350                        if let Some(cb_vec) = broadcast_callbacks.get_mut(&payload.event) {
351                            for cb in cb_vec {
352                                cb(&payload.payload);
353                            }
354                        }
355                    }
356                    Payload::PostgresChanges(ref payload) => {
357                        if let Some(cb_vec) = cdc_callbacks.get_mut(&payload.data.change_type) {
358                            for cb in cb_vec {
359                                if !cb.0.check(&message) {
360                                    continue;
361                                }
362                                cb.1(payload);
363                            }
364                        }
365                        if let Some(cb_vec) = cdc_callbacks.get_mut(&PostgresChangesEvent::All) {
366                            for cb in cb_vec {
367                                if !cb.0.check(&message) {
368                                    continue;
369                                }
370                                cb.1(payload);
371                            }
372                        }
373                    }
374                    Payload::Response(join_response) => {
375                        let target_id = message.message_ref.clone().unwrap_or("".to_string());
376                        if target_id != id.to_string() {
377                            return;
378                        }
379                        if join_response.status == PayloadStatus::Ok {
380                            let mut channel_state = task_state.lock().await;
381                            *channel_state = ChannelState::Joined;
382                            drop(channel_state);
383                        }
384                    }
385                    Payload::PresenceDiff(diff) => {
386                        let mut presence = presence.lock().await;
387                        presence.sync_diff(diff.into());
388                    }
389                    Payload::PresenceState(state) => {
390                        let mut presence = presence.lock().await;
391                        presence.sync(state.into());
392                    }
393                    _ => {
394                        println!("Unmatched payload ;_;")
395                    }
396                }
397
398                drop(broadcast_callbacks);
399                drop(cdc_callbacks);
400            }
401        }));
402    }
403
404    /// Leave the channel
405    async fn unsubscribe(&mut self) -> Result<ChannelState, ChannelSendError> {
406        let state = self.state.clone();
407        {
408            let state = state.lock().await;
409            if *state == ChannelState::Closed || *state == ChannelState::Leaving {
410                return Ok(*state);
411            }
412        }
413
414        match self
415            .send(RealtimeMessage {
416                event: MessageEvent::PhxLeave,
417                topic: self.topic.clone(),
418                payload: Payload::Empty {},
419                message_ref: Some(format!("{}+leave", self.id)),
420            })
421            .await
422        {
423            Ok(()) => {
424                let mut state = state.lock().await;
425                *state = ChannelState::Leaving;
426                Ok(*state)
427            }
428            Err(ChannelSendError::ChannelError(status)) => Ok(status),
429            Err(e) => Err(e),
430        }
431    }
432
433    /// Track provided state in Realtime Presence
434    async fn track(&mut self, payload: HashMap<String, Value>) -> Result<(), ChannelSendError> {
435        self.send(RealtimeMessage {
436            event: MessageEvent::Presence,
437            topic: self.topic.clone(),
438            payload: Payload::PresenceTrack(payload.into()),
439            message_ref: None,
440        })
441        .await
442    }
443
444    /// Sends a message to stop tracking this channel's presence
445    async fn untrack(&mut self) -> Result<(), ChannelSendError> {
446        self.send(RealtimeMessage {
447            event: MessageEvent::Untrack,
448            topic: self.topic.clone(),
449            payload: Payload::Empty {},
450            message_ref: None,
451        })
452        .await
453    }
454
455    async fn send(&mut self, message: RealtimeMessage) -> Result<(), ChannelSendError> {
456        // inject channel topic to message here
457        let mut message = message.clone();
458        message.topic = self.topic.clone();
459
460        let state = self.state.lock().await;
461
462        if *state == ChannelState::Leaving {
463            return Err(ChannelSendError::ChannelError(*state));
464        }
465
466        match self.client_tx.send(message) {
467            Ok(()) => Ok(()),
468            Err(e) => Err(ChannelSendError::SendError(e)),
469        }
470    }
471
472    async fn broadcast(&mut self, payload: BroadcastPayload) -> Result<(), ChannelSendError> {
473        self.send(RealtimeMessage {
474            event: MessageEvent::Broadcast,
475            topic: "".into(),
476            payload: Payload::Broadcast(payload),
477            message_ref: None,
478        })
479        .await
480    }
481
482    async fn reauth(&mut self) -> Result<(), ChannelSendError> {
483        // TODO test this
484        let access_token = self.access_token.lock().await;
485
486        self.join_payload.access_token = access_token.clone();
487
488        let state = self.state.lock().await;
489
490        if *state != ChannelState::Joined {
491            return Ok(());
492        }
493
494        drop(state);
495
496        let access_token_message = RealtimeMessage {
497            event: MessageEvent::AccessToken,
498            topic: self.topic.clone(),
499            payload: Payload::AccessToken(AccessTokenPayload {
500                access_token: access_token.clone(),
501            }),
502            ..Default::default()
503        };
504
505        drop(access_token);
506
507        self.send(access_token_message).await
508    }
509}
510
511/// Builder struct for [RealtimeChannel]
512pub struct RealtimeChannelBuilder {
513    topic: String,
514    broadcast: BroadcastConfig,
515    presence: PresenceConfig,
516    id: Uuid,
517    postgres_changes: Vec<PostgresChange>,
518    cdc_callbacks: HashMap<PostgresChangesEvent, Vec<CdcCallback>>,
519    broadcast_callbacks: HashMap<String, Vec<BroadcastCallback>>,
520    presence_callbacks: PresenceCallbackMap,
521}
522
523impl RealtimeChannelBuilder {
524    /// Create a new channel builder
525    // TODO example code
526    pub fn new(topic: impl Into<String>) -> Self {
527        Self {
528            topic: format!("realtime:{}", topic.into()),
529            broadcast: Default::default(),
530            presence: Default::default(),
531            id: Uuid::new_v4(),
532            postgres_changes: Default::default(),
533            cdc_callbacks: Default::default(),
534            broadcast_callbacks: Default::default(),
535            presence_callbacks: Default::default(),
536        }
537    }
538
539    /// Set the topic of the channel
540    pub fn topic(mut self, topic: impl Into<String>) -> Self {
541        self.topic = format!("realtime:{}", topic.into());
542        self
543    }
544
545    /// Set the broadcast config for this channel
546    pub fn broadcast(mut self, broadcast_config: BroadcastConfig) -> Self {
547        self.broadcast = broadcast_config;
548        self
549    }
550
551    /// Set the presence config for this channel
552    pub fn presence(mut self, presence_config: PresenceConfig) -> Self {
553        self.presence = presence_config;
554        self
555    }
556
557    /// Add a postgres changes callback to this channel
558    pub fn on_postgres_change(
559        mut self,
560        event: PostgresChangesEvent,
561        filter: PostgresChangeFilter,
562        callback: impl FnMut(&PostgresChangesPayload) + 'static + Send,
563    ) -> Self {
564        self.postgres_changes.push(PostgresChange {
565            event: event.clone(),
566            schema: filter.schema.clone(),
567            table: filter.table.clone().unwrap_or("".into()),
568            filter: filter.filter.clone(),
569        });
570
571        if self.cdc_callbacks.get_mut(&event).is_none() {
572            self.cdc_callbacks.insert(event.clone(), vec![]);
573        }
574
575        self.cdc_callbacks
576            .get_mut(&event)
577            .unwrap_or(&mut vec![])
578            .push((filter, Box::new(callback)));
579
580        self
581    }
582
583    /// Add a presence callback to this channel
584    pub fn on_presence(
585        mut self,
586        event: PresenceEvent,
587        callback: impl Fn(String, PresenceState, PresenceState) + Send + 'static,
588    ) -> Self {
589        if self.presence_callbacks.get_mut(&event).is_none() {
590            self.presence_callbacks.insert(event.clone(), vec![]);
591        }
592
593        self.presence_callbacks
594            .get_mut(&event)
595            .unwrap_or(&mut vec![])
596            .push(Box::new(callback));
597
598        self
599    }
600
601    /// Add a broadcast callback to this channel
602    pub fn on_broadcast(
603        mut self,
604        event: impl Into<String>,
605        callback: impl FnMut(&HashMap<String, Value>) + 'static + Send,
606    ) -> Self {
607        let event: String = event.into();
608
609        if self.broadcast_callbacks.get_mut(&event).is_none() {
610            self.broadcast_callbacks.insert(event.clone(), vec![]);
611        }
612
613        self.broadcast_callbacks
614            .get_mut(&event)
615            .unwrap_or(&mut vec![])
616            .push(Box::new(callback));
617
618        self
619    }
620
621    fn build_common(
622        self,
623        client_tx: UnboundedSender<RealtimeMessage>,
624        access_token: String,
625        access_token_arc: Arc<Mutex<String>>,
626        rt: Arc<Runtime>,
627    ) -> ChannelManager {
628        let state = Arc::new(Mutex::new(ChannelState::Closed));
629        let cdc_callbacks = Arc::new(Mutex::new(self.cdc_callbacks));
630        let broadcast_callbacks = Arc::new(Mutex::new(self.broadcast_callbacks));
631        let (controller_tx, controller_rx) = mpsc::unbounded_channel::<ChannelManagerMessage>();
632
633        let mut channel = RealtimeChannel {
634            access_token: access_token_arc,
635            rt: rt.clone(),
636            tx: None,
637            topic: self.topic,
638            cdc_callbacks,
639            broadcast_callbacks,
640            client_tx,
641            state,
642            id: self.id,
643            join_payload: JoinPayload {
644                config: JoinConfig {
645                    broadcast: self.broadcast,
646                    presence: self.presence,
647                    postgres_changes: self.postgres_changes,
648                },
649                access_token,
650            },
651            presence: Arc::new(Mutex::new(RealtimePresence::from_channel_builder(
652                self.presence_callbacks,
653            ))),
654            manager_channel: (controller_tx, controller_rx),
655            message_handle: None,
656        };
657
658        channel.client_recv(); // Spawns task, sets channel.tx
659        let tx = channel.manager_channel.0.clone();
660
661        let _handle = rt.spawn(async move { channel.manager_recv().await });
662
663        ChannelManager { tx, rt }
664    }
665
666    // TODO unify the builds using a clientmanager trait. Need async-trait
667
668    /// Consume self and return a new [ChannelManagerSync] that controls the newly created channel
669    /// Automatically assigns the new channel in the client.
670    ///
671    /// For async applications you may want `self::build()`
672    pub fn build_sync(self, client: &ClientManagerSync) -> Result<ChannelManagerSync, RecvError> {
673        let client_tx = client.clone().get_ws_tx().unwrap();
674        let access_token = client.clone().get_access_token().unwrap();
675        let access_token_arc = client.clone().get_access_token_arc().unwrap();
676
677        let channel_manager =
678            self.build_common(client_tx, access_token, access_token_arc, client.get_rt());
679
680        client.add_channel(channel_manager.clone()).unwrap();
681
682        Ok(channel_manager.to_sync())
683    }
684
685    /// Consume self and return a new [ChannelManager] that controls the newly created channel
686    /// Automatically assigns the new channel in the client.
687    ///
688    /// For sync applications you may want `self::build_sync()`
689    pub async fn build(self, client: &ClientManager) -> Result<ChannelManager, RecvError> {
690        let client_tx = client.clone().get_ws_tx().await?;
691        let access_token = client.clone().get_access_token().await?;
692        let access_token_arc = client.clone().get_access_token_arc().await?;
693
694        let channel_manager =
695            self.build_common(client_tx, access_token, access_token_arc, client.get_rt());
696
697        client.add_channel(channel_manager.clone()).await.unwrap();
698
699        Ok(channel_manager)
700    }
701}