bevy_realtime/
channel.rs

1use bevy::{
2    ecs::{event::Event, system::SystemId},
3    log::debug,
4    prelude::In,
5};
6use bevy_crossbeam_event::CrossbeamEventSender;
7use crossbeam::channel::{unbounded, Receiver, SendError, Sender};
8use serde_json::Value;
9use uuid::Uuid;
10
11use super::client::ClientManager;
12use crate::{
13    message::{
14        payload::{
15            AccessTokenPayload, BroadcastConfig, BroadcastPayload, JoinConfig, JoinPayload,
16            Payload, PayloadStatus, PostgresChange, PostgresChangesEvent, PostgresChangesPayload,
17            PresenceConfig,
18        },
19        postgres_change_filter::PostgresChangeFilter,
20        realtime_message::{MessageEvent, RealtimeMessage},
21    },
22    presence::PresenceCallbackEvent,
23};
24
25use super::client::Client;
26use crate::presence::{Presence, PresenceCallback, PresenceEvent, PresenceState};
27use std::fmt::Debug;
28use std::{collections::HashMap, error::Error};
29
30#[derive(Clone)]
31struct BroadcastCallback(SystemId<In<HashMap<String, Value>>>);
32
33#[derive(Event, Clone)]
34pub struct BroadcastCallbackEvent(
35    pub (SystemId<In<HashMap<String, Value>>>, HashMap<String, Value>),
36);
37
38#[derive(Clone)]
39struct PostgresChangesCallback(PostgresChangeFilter, SystemId<In<PostgresChangesPayload>>);
40
41#[derive(Event, Clone)]
42pub struct PostgresChangesCallbackEvent(
43    pub (SystemId<In<PostgresChangesPayload>>, PostgresChangesPayload),
44);
45
/// Lifecycle state of a [`RealtimeChannel`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ChannelState {
    /// Not joined: the state of a freshly built channel and after a close/leave completes.
    Closed,
    /// Error state. NOTE(review): nothing in this module sets it; presumably set elsewhere.
    Errored,
    /// A successful reply to this channel's join request has been processed.
    Joined,
    /// A join request has been sent but no successful reply has been processed yet.
    Joining,
    /// A leave request has been sent; further outgoing messages are refused.
    Leaving,
}
55
56#[derive(Clone)]
57pub struct ChannelManager {
58    pub tx: Sender<ChannelManagerMessage>,
59}
60
61pub enum ChannelManagerMessage {
62    Broadcast {
63        payload: BroadcastPayload,
64    },
65    Subscribe,
66    Track {
67        payload: HashMap<String, Value>,
68    },
69    Untrack,
70    PresenceState {
71        callback: SystemId<In<PresenceState>>,
72    },
73    ChannelState {
74        callback: SystemId<In<ChannelState>>,
75    },
76}
77
78impl ChannelManager {
79    pub fn broadcast(
80        &self,
81        payload: BroadcastPayload,
82    ) -> Result<(), SendError<ChannelManagerMessage>> {
83        self.tx.send(ChannelManagerMessage::Broadcast { payload })
84    }
85
86    pub fn subscribe(&self) -> Result<(), SendError<ChannelManagerMessage>> {
87        self.tx.send(ChannelManagerMessage::Subscribe)
88    }
89
90    pub fn track(
91        &self,
92        payload: HashMap<String, Value>,
93    ) -> Result<(), SendError<ChannelManagerMessage>> {
94        self.tx.send(ChannelManagerMessage::Track { payload })
95    }
96
97    pub fn untrack(&self) -> Result<(), SendError<ChannelManagerMessage>> {
98        self.tx.send(ChannelManagerMessage::Untrack)
99    }
100
101    pub fn presence_state(
102        &self,
103        callback: SystemId<In<PresenceState>>,
104    ) -> Result<(), SendError<ChannelManagerMessage>> {
105        self.tx
106            .send(ChannelManagerMessage::PresenceState { callback })
107    }
108
109    pub fn channel_state(
110        &self,
111        callback: SystemId<In<ChannelState>>,
112    ) -> Result<(), SendError<ChannelManagerMessage>> {
113        self.tx
114            .send(ChannelManagerMessage::ChannelState { callback })
115    }
116}
117
118#[derive(Event, Clone)]
119pub struct PresenceStateCallbackEvent(pub (SystemId<In<PresenceState>>, PresenceState));
120
121#[derive(Event, Clone)]
122pub struct ChannelStateCallbackEvent(pub (SystemId<In<ChannelState>>, ChannelState));
123
124/// Channel structure
125pub struct RealtimeChannel {
126    pub(crate) topic: String,
127    pub(crate) connection_state: ChannelState,
128    pub(crate) id: Uuid,
129    postgres_changes_callbacks: HashMap<PostgresChangesEvent, Vec<PostgresChangesCallback>>,
130    broadcast_callbacks: HashMap<String, Vec<BroadcastCallback>>,
131    join_payload: JoinPayload,
132    presence: Presence,
133    // sync bridge
134    tx: Sender<RealtimeMessage>,
135    manager_rx: Receiver<ChannelManagerMessage>,
136    presence_state_callback_event_sender: CrossbeamEventSender<PresenceStateCallbackEvent>,
137    channel_state_callback_event_sender: CrossbeamEventSender<ChannelStateCallbackEvent>,
138    broadcast_callback_event_sender: CrossbeamEventSender<BroadcastCallbackEvent>,
139    postgres_changes_callback_event_sender: CrossbeamEventSender<PostgresChangesCallbackEvent>,
140}
141
142// TODO channel options with broadcast + presence settings
143
144impl RealtimeChannel {
145    pub(crate) fn manager_recv(&mut self) -> Result<(), Box<dyn Error>> {
146        while let Ok(message) = self.manager_rx.try_recv() {
147            match message {
148                ChannelManagerMessage::Broadcast { payload } => self.broadcast(payload)?,
149                ChannelManagerMessage::Subscribe => self.subscribe()?,
150                ChannelManagerMessage::Track { payload } => self.track(payload)?,
151                ChannelManagerMessage::Untrack => self.untrack()?,
152                ChannelManagerMessage::PresenceState { callback } => self
153                    .presence_state_callback_event_sender
154                    .send(PresenceStateCallbackEvent((
155                        callback,
156                        self.presence_state(),
157                    ))),
158                ChannelManagerMessage::ChannelState { callback } => self
159                    .channel_state_callback_event_sender
160                    .send(ChannelStateCallbackEvent((callback, self.channel_state()))),
161            }
162        }
163
164        Ok(())
165    }
166    /// Returns the channel's connection state
167    fn channel_state(&self) -> ChannelState {
168        self.connection_state
169    }
170
171    /// Send a join request to the channel
172    /// Does not block, for blocking behaviour use [RealtimeClient::block_until_subscribed()]
173    pub(crate) fn subscribe(&mut self) -> Result<(), SendError<RealtimeMessage>> {
174        let join_message = RealtimeMessage {
175            event: MessageEvent::PhxJoin,
176            topic: self.topic.clone(),
177            payload: Payload::Join(self.join_payload.clone()),
178            message_ref: Some(self.id.into()),
179        };
180
181        self.connection_state = ChannelState::Joining;
182
183        self.tx.send(join_message)
184    }
185
186    /// Leave the channel
187    pub(crate) fn unsubscribe(&mut self) -> Result<ChannelState, SendError<RealtimeMessage>> {
188        if self.connection_state == ChannelState::Closed
189            || self.connection_state == ChannelState::Leaving
190        {
191            return Ok(self.connection_state);
192        }
193
194        let message = RealtimeMessage {
195            event: MessageEvent::PhxLeave,
196            topic: self.topic.clone(),
197            payload: Payload::Empty {},
198            message_ref: Some(format!("{}+leave", self.id)),
199        };
200
201        match self.send(message) {
202            Ok(()) => {
203                self.connection_state = ChannelState::Leaving;
204                Ok(self.connection_state)
205            }
206            Err(e) => Err(e),
207        }
208    }
209
210    /// Returns the current [PresenceState] of the channel
211    fn presence_state(&self) -> PresenceState {
212        self.presence.state.clone()
213    }
214
215    /// Track provided state in Realtime Presence
216    fn track(&mut self, payload: HashMap<String, Value>) -> Result<(), SendError<RealtimeMessage>> {
217        self.send(RealtimeMessage {
218            event: MessageEvent::Presence,
219            topic: self.topic.clone(),
220            payload: Payload::PresenceTrack(payload.into()),
221            message_ref: None,
222        })
223    }
224
225    /// Sends a message to stop tracking this channel's presence
226    fn untrack(&mut self) -> Result<(), SendError<RealtimeMessage>> {
227        self.send(RealtimeMessage {
228            event: MessageEvent::Untrack,
229            topic: self.topic.clone(),
230            payload: Payload::Empty {},
231            message_ref: None,
232        })
233    }
234
235    /// Send a [RealtimeMessage] on this channel
236    fn send(&mut self, message: RealtimeMessage) -> Result<(), SendError<RealtimeMessage>> {
237        // inject channel topic to message here
238        let mut message = message.clone();
239        message.topic.clone_from(&self.topic);
240
241        if self.connection_state == ChannelState::Leaving {
242            return Err(SendError(message));
243        }
244
245        self.tx.send(message)
246    }
247
248    /// Helper function for sending broadcast messages
249    fn broadcast(&mut self, payload: BroadcastPayload) -> Result<(), SendError<RealtimeMessage>> {
250        self.send(RealtimeMessage {
251            event: MessageEvent::Broadcast,
252            topic: "".into(),
253            payload: Payload::Broadcast(payload),
254            message_ref: None,
255        })
256    }
257
258    pub(crate) fn set_auth(
259        &mut self,
260        access_token: String,
261    ) -> Result<(), SendError<RealtimeMessage>> {
262        self.join_payload.access_token.clone_from(&access_token);
263
264        if self.connection_state != ChannelState::Joined {
265            return Ok(());
266        }
267
268        let access_token_message = RealtimeMessage {
269            event: MessageEvent::AccessToken,
270            topic: self.topic.clone(),
271            payload: Payload::AccessToken(AccessTokenPayload { access_token }),
272            ..Default::default()
273        };
274
275        self.send(access_token_message)
276    }
277
278    pub(crate) fn recieve(&mut self, message: RealtimeMessage) {
279        match &message.payload {
280            Payload::Response(join_response) => {
281                let target_id = message.message_ref.clone().unwrap_or("".to_string());
282                if target_id != self.id.to_string() {
283                    return;
284                }
285                if join_response.status == PayloadStatus::Ok {
286                    self.connection_state = ChannelState::Joined;
287                }
288            }
289            Payload::PresenceState(state) => self.presence.sync(state.clone().into()),
290            Payload::PresenceDiff(raw_diff) => {
291                self.presence.sync_diff(raw_diff.clone().into());
292            }
293            Payload::PostgresChanges(payload) => {
294                let event = &payload.data.change_type;
295
296                for callback in self
297                    .postgres_changes_callbacks
298                    .get_mut(event)
299                    .unwrap_or(&mut vec![])
300                {
301                    let filter = &callback.0;
302
303                    // TODO REFAC pointless message clones when not using result; filter.check
304                    // should borrow and return bool/result
305                    if let Some(_message) = filter.check(message.clone()) {
306                        self.postgres_changes_callback_event_sender
307                            .send(PostgresChangesCallbackEvent((callback.1, payload.clone())));
308                    }
309                }
310
311                for callback in self
312                    .postgres_changes_callbacks
313                    .get_mut(&PostgresChangesEvent::All)
314                    .unwrap_or(&mut vec![])
315                {
316                    let filter = &callback.0;
317
318                    if let Some(_message) = filter.check(message.clone()) {
319                        self.postgres_changes_callback_event_sender
320                            .send(PostgresChangesCallbackEvent((callback.1, payload.clone())));
321                    }
322                }
323            }
324            Payload::Broadcast(payload) => {
325                if let Some(callbacks) = self.broadcast_callbacks.get_mut(&payload.event) {
326                    for cb in callbacks {
327                        self.broadcast_callback_event_sender
328                            .send(BroadcastCallbackEvent((cb.0, payload.payload.clone())));
329                    }
330                }
331            }
332            _ => {}
333        }
334
335        match &message.event {
336            MessageEvent::PhxClose => {
337                if let Some(message_ref) = message.message_ref {
338                    if message_ref == self.id.to_string() {
339                        self.connection_state = ChannelState::Closed;
340                        debug!("Channel Closed! {:?}", self.id);
341                    }
342                }
343            }
344            MessageEvent::PhxReply => {
345                if message.message_ref.clone().unwrap_or("#NOREF".to_string())
346                    == format!("{}+leave", self.id)
347                {
348                    self.connection_state = ChannelState::Closed;
349                    debug!("Channel Closed! {:?}", self.id);
350                }
351            }
352            _ => {}
353        }
354    }
355}
356
357impl Debug for RealtimeChannel {
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        f.write_str(&format!(
360            "RealtimeChannel {{ name: {:?}, callbacks: [TODO DEBUG]}}",
361            self.topic
362        ))
363    }
364}
365
366/// Builder struct for [RealtimeChannel]
367///
368/// Get access to this through [RealtimeClient::channel()]
369#[derive(Event, Clone)]
370pub struct ChannelBuilder {
371    topic: String,
372    access_token: String,
373    broadcast: BroadcastConfig,
374    presence: PresenceConfig,
375    id: Uuid,
376    postgres_changes: Vec<PostgresChange>,
377    cdc_callbacks: HashMap<PostgresChangesEvent, Vec<PostgresChangesCallback>>,
378    broadcast_callbacks: HashMap<String, Vec<BroadcastCallback>>,
379    presence_callbacks: HashMap<PresenceEvent, Vec<PresenceCallback>>,
380    tx: Sender<RealtimeMessage>,
381}
382
383impl ChannelBuilder {
384    pub(crate) fn new(client: &mut Client) -> Self {
385        Self {
386            topic: "no_topic".into(),
387            access_token: client.access_token.clone(),
388            broadcast: Default::default(),
389            presence: Default::default(),
390            id: Uuid::new_v4(),
391            postgres_changes: Default::default(),
392            cdc_callbacks: Default::default(),
393            broadcast_callbacks: Default::default(),
394            presence_callbacks: Default::default(),
395            tx: client.get_channel_tx(),
396        }
397    }
398
399    /// Set the topic of the channel
400    pub fn topic(&mut self, topic: impl Into<String>) -> &mut Self {
401        self.topic = format!("realtime:{}", topic.into());
402        self
403    }
404
405    /// Set the broadcast config for this channel
406    pub fn set_broadcast_config(&mut self, broadcast_config: BroadcastConfig) -> &mut Self {
407        self.broadcast = broadcast_config;
408        self
409    }
410
411    /// Set the presence config for this channel
412    pub fn set_presence_config(&mut self, presence_config: PresenceConfig) -> &mut Self {
413        self.presence = presence_config;
414        self
415    }
416
417    /// Add a postgres changes callback to this channel
418    pub fn on_postgres_change(
419        &mut self,
420        event: PostgresChangesEvent,
421        filter: PostgresChangeFilter,
422        callback: SystemId<In<PostgresChangesPayload>>,
423    ) -> &mut Self {
424        self.postgres_changes.push(PostgresChange {
425            event: event.clone(),
426            schema: filter.schema.clone(),
427            table: filter.table.clone().unwrap_or("".into()),
428            filter: filter.filter.clone(),
429        });
430
431        if self.cdc_callbacks.get_mut(&event).is_none() {
432            self.cdc_callbacks.insert(event.clone(), vec![]);
433        }
434
435        self.cdc_callbacks
436            .get_mut(&event)
437            .unwrap_or(&mut vec![])
438            .push(PostgresChangesCallback(filter, callback));
439
440        self
441    }
442
443    /// Add a presence callback to this channel
444    ///```
445    pub fn on_presence(
446        &mut self,
447        event: PresenceEvent,
448        callback: SystemId<In<(String, PresenceState, PresenceState)>>,
449    ) -> &mut Self {
450        if self.presence_callbacks.get_mut(&event).is_none() {
451            self.presence_callbacks.insert(event.clone(), vec![]);
452        }
453
454        self.presence_callbacks
455            .get_mut(&event)
456            .unwrap_or(&mut vec![])
457            .push(PresenceCallback(callback));
458
459        self
460    }
461
462    /// Add a broadcast callback to this channel
463    pub fn on_broadcast(
464        &mut self,
465        event: impl Into<String>,
466        callback: SystemId<In<HashMap<String, Value>>>,
467    ) -> &mut Self {
468        let event: String = event.into();
469
470        if self.broadcast_callbacks.get_mut(&event).is_none() {
471            self.broadcast_callbacks.insert(event.clone(), vec![]);
472        }
473
474        self.broadcast_callbacks
475            .get_mut(&event)
476            .unwrap_or(&mut vec![])
477            .push(BroadcastCallback(callback));
478
479        self
480    }
481
482    // TODO on_message handler for sys messages
483
484    /// Create the channel and pass ownership to provided [RealtimeClient], returning the channel
485    /// id for later access through the client
486    pub fn build(
487        &self,
488        client: &ClientManager,
489        presence_state_callback_event_sender: CrossbeamEventSender<PresenceStateCallbackEvent>,
490        channel_state_callback_event_sender: CrossbeamEventSender<ChannelStateCallbackEvent>,
491        broadcast_callback_event_sender: CrossbeamEventSender<BroadcastCallbackEvent>,
492        presence_callback_event_sender: CrossbeamEventSender<PresenceCallbackEvent>,
493        postgres_changes_callback_event_sender: CrossbeamEventSender<PostgresChangesCallbackEvent>,
494    ) -> ChannelManager {
495        let manager_channel = unbounded();
496
497        client
498            .add_channel(RealtimeChannel {
499                topic: self.topic.clone(),
500                postgres_changes_callbacks: self.cdc_callbacks.clone(),
501                broadcast_callbacks: self.broadcast_callbacks.clone(),
502                tx: self.tx.clone(),
503                manager_rx: manager_channel.1,
504                connection_state: ChannelState::Closed,
505                id: self.id,
506                join_payload: JoinPayload {
507                    config: JoinConfig {
508                        broadcast: self.broadcast.clone(),
509                        presence: self.presence.clone(),
510                        postgres_changes: self.postgres_changes.clone(),
511                    },
512                    access_token: self.access_token.clone(),
513                },
514                presence: Presence::from_channel_builder(
515                    self.presence_callbacks.clone(),
516                    presence_callback_event_sender,
517                ),
518                presence_state_callback_event_sender,
519                channel_state_callback_event_sender,
520                broadcast_callback_event_sender,
521                postgres_changes_callback_event_sender,
522            })
523            .unwrap();
524
525        ChannelManager {
526            tx: manager_channel.0,
527        }
528    }
529}