1use crate::realtime_client::ClientManager;
2use crate::realtime_client::ClientManagerSync;
3use crate::realtime_presence::PresenceCallbackMap;
4use crate::realtime_presence::RealtimePresence;
5use crate::Responder;
6
7use serde_json::Value;
8use tokio::{
9 runtime::Runtime,
10 sync::{
11 mpsc::{self, error::SendError, UnboundedReceiver, UnboundedSender},
12 oneshot::{self, error::RecvError},
13 Mutex,
14 },
15 task::JoinHandle,
16};
17use uuid::Uuid;
18
19use crate::message::{
20 payload::{
21 AccessTokenPayload, BroadcastConfig, BroadcastPayload, JoinConfig, JoinPayload, Payload,
22 PayloadStatus, PostgresChange, PostgresChangesEvent, PostgresChangesPayload,
23 PresenceConfig,
24 },
25 presence::{PresenceEvent, PresenceState},
26 MessageEvent, PostgresChangeFilter, RealtimeMessage,
27};
28
29use std::fmt::Debug;
30use std::{collections::HashMap, sync::Arc};
31
/// A Postgres-changes subscription: the filter an incoming message must pass (`.0.check`)
/// and the user callback (`.1`) invoked with the change payload when it does.
32type CdcCallback = (
 33 PostgresChangeFilter,
 34 Box<dyn FnMut(&PostgresChangesPayload) + Send>,
 35);
/// User callback for a broadcast event; invoked with the broadcast's inner payload map.
36type BroadcastCallback = Box<dyn FnMut(&HashMap<String, Value>) + Send>;
/// User callback for a presence event, stored per event by the channel builder.
/// NOTE(review): the meaning of the `String` and the two `PresenceState` arguments is defined
/// by `RealtimePresence` (not shown here) — confirm there.
37pub(crate) type PresenceCallback = Box<dyn Fn(String, PresenceState, PresenceState) + Send>;
38
/// Lifecycle state of a realtime channel.
///
/// Transitions performed in this module: a channel starts `Closed`, becomes `Joining` when a
/// join message is sent, `Joined` when the server answers the join with an `Ok` status, and
/// `Leaving` once a leave message has been sent.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so states can be used as map/set keys.
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]
pub enum ChannelState {
    /// Not joined; the initial state of a freshly built channel.
    Closed,
    /// NOTE(review): never assigned in this module — presumably a failed join/connection.
    Errored,
    /// The server acknowledged the join request.
    Joined,
    /// A join request was sent and its acknowledgement has not arrived yet.
    Joining,
    /// A leave request was sent; further sends on the channel are rejected.
    Leaving,
}
48
49#[derive(Debug)]
51pub enum ChannelSendError {
52 NoChannel,
53 SendError(SendError<RealtimeMessage>),
54 ChannelError(ChannelState),
55}
56
/// Control requests sent from a `ChannelManager` handle to the channel's control loop
/// (`RealtimeChannel::manager_recv`). Variants with a `res` field are answered through that
/// oneshot responder.
57pub(crate) enum ChannelManagerMessage {
    /// Send a join request; no reply.
58 Subscribe,
    /// Send a leave request; replies with the resulting state or the send error.
59 Unsubscribe {
60 res: Responder<Result<ChannelState, ChannelSendError>>,
61 },
    /// Send a join request; replies once the channel state becomes `Joined`.
62 SubscribeBlocking {
63 res: Responder<()>,
64 },
    /// Send a broadcast message; no reply.
65 Broadcast {
66 payload: BroadcastPayload,
67 },
    /// Replace the sender used for the channel's outgoing messages; replies `()`.
68 ClientTx {
69 new_tx: UnboundedSender<RealtimeMessage>,
70 res: Responder<()>,
71 },
    /// Reply with the current channel state.
72 GetState {
73 res: Responder<ChannelState>,
74 },
    /// Reply with the sender that feeds this channel's inbound message task.
75 GetTx {
76 res: Responder<UnboundedSender<RealtimeMessage>>,
77 },
    /// Reply with the channel's full topic string.
78 GetTopic {
79 res: Responder<String>,
80 },
    /// Reply with a copy of the current presence state.
81 GetPresenceState {
82 res: Responder<PresenceState>,
83 },
    /// Send a presence-track message with `payload`; replies `()`.
84 PresenceTrack {
85 payload: HashMap<String, Value>,
86 res: Responder<()>,
87 },
    /// Send a presence-untrack message; replies `()`.
88 PresenceUntrack {
89 res: Responder<()>,
90 },
    /// Re-read the shared access token into the join payload and, when joined, send it to
    /// the server; replies `()`.
91 ReAuth {
92 res: Responder<()>,
93 },
94}
95
/// Cloneable async handle to a realtime channel.
///
/// Each method sends a `ChannelManagerMessage` to the channel's control task over `tx`;
/// methods that return data await the reply on a oneshot channel.
96#[derive(Clone, Debug)]
101pub struct ChannelManager {
    // Sender half of the channel's control queue (drained by `RealtimeChannel::manager_recv`).
102 pub(crate) tx: UnboundedSender<ChannelManagerMessage>,
    // Runtime the channel's tasks run on; `ChannelManagerSync` blocks on it.
103 rt: Arc<Runtime>,
104}
105
106impl ChannelManager {
107 pub fn subscribe(&self) {
109 let _ = self.send(ChannelManagerMessage::Subscribe);
110 }
111 pub async fn unsubscribe(&self) -> Result<Result<ChannelState, ChannelSendError>, RecvError> {
116 let (tx, rx) = oneshot::channel();
117 let _ = self.send(ChannelManagerMessage::Unsubscribe { res: tx });
118 rx.await
119 }
120 pub async fn subscribe_blocking(&self) -> Result<(), oneshot::error::RecvError> {
122 let (tx, rx) = oneshot::channel();
123 let _ = self.send(ChannelManagerMessage::SubscribeBlocking { res: tx });
124 rx.await
125 }
126 pub fn broadcast(&self, payload: BroadcastPayload) {
128 let _ = self.send(ChannelManagerMessage::Broadcast { payload });
129 }
130 pub async fn track(&self, payload: HashMap<String, Value>) -> Result<(), RecvError> {
132 let (tx, rx) = oneshot::channel();
133 let _ = self.send(ChannelManagerMessage::PresenceTrack { payload, res: tx });
134 rx.await
135 }
136 pub async fn untrack(&self) -> Result<(), RecvError> {
138 let (tx, rx) = oneshot::channel();
139 let _ = self.send(ChannelManagerMessage::PresenceUntrack { res: tx });
140 rx.await
141 }
142 pub async fn get_state(&self) -> Result<ChannelState, RecvError> {
144 let (tx, rx) = oneshot::channel();
145 let _ = self.send(ChannelManagerMessage::GetState { res: tx });
146 rx.await
147 }
148 pub async fn get_topic(&self) -> String {
150 let (tx, rx) = oneshot::channel();
151 let _ = self.send(ChannelManagerMessage::GetTopic { res: tx });
152 rx.await.unwrap()
153 }
154 pub async fn get_presence_state(&self) -> PresenceState {
156 let (tx, rx) = oneshot::channel();
157 let _ = self.send(ChannelManagerMessage::GetPresenceState { res: tx });
158 rx.await.unwrap()
159 }
160 pub fn to_sync(self) -> ChannelManagerSync {
162 ChannelManagerSync { inner: self }
163 }
164 pub(crate) fn send(
165 &self,
166 message: ChannelManagerMessage,
167 ) -> Result<(), SendError<ChannelManagerMessage>> {
168 self.tx.send(message)
169 }
170 pub(crate) async fn reauth(&self) -> Result<(), RecvError> {
171 let (tx, rx) = oneshot::channel();
172 let _ = self.send(ChannelManagerMessage::ReAuth { res: tx });
173 rx.await
174 }
175 pub(crate) async fn get_tx(&self) -> UnboundedSender<RealtimeMessage> {
176 let (tx, rx) = oneshot::channel();
177 let _ = self.send(ChannelManagerMessage::GetTx { res: tx });
178 rx.await.unwrap()
179 }
180}
181
/// Blocking wrapper around `ChannelManager`: each async call is driven to completion on the
/// wrapped manager's runtime via `Runtime::block_on`.
182#[derive(Clone)]
183pub struct ChannelManagerSync {
    // The async handle every call is delegated to.
184 inner: ChannelManager,
185}
186
187impl ChannelManagerSync {
188 pub fn subscribe(&self) {
189 self.inner.subscribe()
190 }
191 pub fn unsubscribe(&self) -> Result<Result<ChannelState, ChannelSendError>, RecvError> {
192 self.inner.rt.block_on(self.inner.unsubscribe())
193 }
194 pub fn subscribe_blocking(&self) -> Result<(), RecvError> {
195 self.inner.rt.block_on(self.inner.subscribe_blocking())
196 }
197 pub fn broadcast(&self, payload: BroadcastPayload) {
198 self.inner.broadcast(payload)
199 }
200 pub fn get_topic(&self) -> String {
202 self.inner.rt.block_on(self.inner.get_topic())
203 }
204 pub fn get_state(&self) -> Result<ChannelState, RecvError> {
205 self.inner.rt.block_on(self.inner.get_state())
206 }
207 pub fn get_presence_state(&self) -> PresenceState {
209 self.inner.rt.block_on(self.inner.get_presence_state())
210 }
211 pub fn track(&self, payload: HashMap<String, Value>) -> Result<(), RecvError> {
212 self.inner.rt.block_on(self.inner.track(payload))
213 }
214 pub fn untrack(&self) -> Result<(), RecvError> {
215 self.inner.rt.block_on(self.inner.untrack())
216 }
217 pub fn to_async(self) -> ChannelManager {
219 self.inner
220 }
221}
222
223impl<'a> FromIterator<&'a mut ChannelManager> for Vec<ChannelManager> {
224 fn from_iter<T: IntoIterator<Item = &'a mut ChannelManager>>(iter: T) -> Self {
225 let mut vec = Vec::new();
226 for c in iter {
227 vec.push(c.clone());
228 }
229 vec
230 }
231}
232
/// Internal state of one realtime channel. It is owned by its control task
/// (`manager_recv`) and shares `state`, the callback maps and `presence` with its inbound
/// message task (`client_recv`).
233struct RealtimeChannel {
    // Full topic, including the `realtime:` prefix added by the builder.
234 pub(crate) topic: String,
    // Lifecycle state, shared with the inbound message task.
235 pub(crate) state: Arc<Mutex<ChannelState>>,
    // Used as the join `message_ref`; replies are matched against it.
236 pub(crate) id: Uuid,
    // Postgres-change callbacks keyed by change event, each with its filter.
237 pub(crate) cdc_callbacks: Arc<Mutex<HashMap<PostgresChangesEvent, Vec<CdcCallback>>>>,
    // Broadcast callbacks keyed by broadcast event name.
238 pub(crate) broadcast_callbacks: Arc<Mutex<HashMap<String, Vec<BroadcastCallback>>>>,
    // Outgoing messages are queued here; replaceable via `ChannelManagerMessage::ClientTx`.
239 pub(crate) client_tx: mpsc::UnboundedSender<RealtimeMessage>,
    // Sent with every `phx_join`; its access token is refreshed by `reauth`.
240 join_payload: JoinPayload,
241 presence: Arc<Mutex<RealtimePresence>>,
    // Sender feeding the inbound message task; `None` until `client_recv` runs.
242 pub(crate) tx: Option<UnboundedSender<RealtimeMessage>>,
    // Control queue: the sender is cloned into `ChannelManager`, the receiver is drained by
    // `manager_recv`.
243 pub(crate) manager_channel: (
244 UnboundedSender<ChannelManagerMessage>,
245 UnboundedReceiver<ChannelManagerMessage>,
246 ),
    // Handle of the inbound message task spawned by `client_recv`.
247 pub(crate) message_handle: Option<JoinHandle<()>>,
248 rt: Arc<Runtime>,
    // Token shared with the client manager (presumably updated there — confirm); read by `reauth`.
249 access_token: Arc<Mutex<String>>,
250}
251
252impl RealtimeChannel {
253 async fn manager_recv(&mut self) {
254 while let Some(control_message) = self.manager_channel.1.recv().await {
255 match control_message {
256 ChannelManagerMessage::Subscribe => {
257 self.subscribe().await;
258 }
259 ChannelManagerMessage::Unsubscribe { res } => {
260 res.send(self.unsubscribe().await).unwrap();
261 }
262 ChannelManagerMessage::SubscribeBlocking { res } => {
263 self.subscribe_blocking(res).await;
264 }
265 ChannelManagerMessage::Broadcast { payload } => {
266 self.broadcast(payload).await.unwrap();
267 }
268 ChannelManagerMessage::ClientTx { new_tx, res } => {
269 self.client_tx = new_tx;
270 res.send(()).unwrap();
271 }
272 ChannelManagerMessage::GetState { res } => {
273 res.send(*self.state.lock().await).unwrap();
274 }
275 ChannelManagerMessage::GetTx { res } => {
276 res.send(self.tx.clone().unwrap()).unwrap();
277 }
278 ChannelManagerMessage::GetTopic { res } => {
279 res.send(self.topic.clone()).unwrap();
280 }
281 ChannelManagerMessage::PresenceTrack { payload, res } => {
282 self.track(payload).await.unwrap();
283 res.send(()).unwrap();
284 }
285 ChannelManagerMessage::PresenceUntrack { res } => {
286 self.untrack().await.unwrap();
287 res.send(()).unwrap()
288 }
289 ChannelManagerMessage::GetPresenceState { res } => {
290 let presence = self.presence.lock().await;
291 res.send(presence.state.clone()).unwrap();
292 }
293 ChannelManagerMessage::ReAuth { res } => {
294 self.reauth().await.unwrap();
295 res.send(()).unwrap();
296 }
297 }
298 }
299 }
300
301 async fn subscribe(&mut self) {
303 let join_message = RealtimeMessage {
304 event: MessageEvent::PhxJoin,
305 topic: self.topic.clone(),
306 payload: Payload::Join(self.join_payload.clone()),
307 message_ref: Some(self.id.into()),
308 };
309
310 let mut state = self.state.lock().await;
311 *state = ChannelState::Joining;
312 drop(state);
313
314 let _ = self.send(join_message).await;
315 }
316
317 async fn subscribe_blocking(&mut self, tx: Responder<()>) {
318 self.subscribe().await;
319
320 let state = self.state.clone();
321
322 self.rt.spawn(async move {
323 loop {
324 let state = state.lock().await;
325 if *state == ChannelState::Joined {
326 break;
327 }
328 }
329 tx.send(()).unwrap();
330 });
331 }
332
333 fn client_recv(&mut self) {
334 let (channel_tx, mut channel_rx) = mpsc::unbounded_channel::<RealtimeMessage>();
335 self.tx = Some(channel_tx);
336 let task_state = self.state.clone();
337 let task_cdc_cbs = self.cdc_callbacks.clone();
338 let task_bc_cbs = self.broadcast_callbacks.clone();
339 let id = self.id;
340 let presence = self.presence.clone();
341
342 self.message_handle = Some(self.rt.spawn(async move {
343 while let Some(message) = channel_rx.recv().await {
344 let mut broadcast_callbacks = task_bc_cbs.lock().await;
346 let mut cdc_callbacks = task_cdc_cbs.lock().await;
347
348 match message.payload {
349 Payload::Broadcast(payload) => {
350 if let Some(cb_vec) = broadcast_callbacks.get_mut(&payload.event) {
351 for cb in cb_vec {
352 cb(&payload.payload);
353 }
354 }
355 }
356 Payload::PostgresChanges(ref payload) => {
357 if let Some(cb_vec) = cdc_callbacks.get_mut(&payload.data.change_type) {
358 for cb in cb_vec {
359 if !cb.0.check(&message) {
360 continue;
361 }
362 cb.1(payload);
363 }
364 }
365 if let Some(cb_vec) = cdc_callbacks.get_mut(&PostgresChangesEvent::All) {
366 for cb in cb_vec {
367 if !cb.0.check(&message) {
368 continue;
369 }
370 cb.1(payload);
371 }
372 }
373 }
374 Payload::Response(join_response) => {
375 let target_id = message.message_ref.clone().unwrap_or("".to_string());
376 if target_id != id.to_string() {
377 return;
378 }
379 if join_response.status == PayloadStatus::Ok {
380 let mut channel_state = task_state.lock().await;
381 *channel_state = ChannelState::Joined;
382 drop(channel_state);
383 }
384 }
385 Payload::PresenceDiff(diff) => {
386 let mut presence = presence.lock().await;
387 presence.sync_diff(diff.into());
388 }
389 Payload::PresenceState(state) => {
390 let mut presence = presence.lock().await;
391 presence.sync(state.into());
392 }
393 _ => {
394 println!("Unmatched payload ;_;")
395 }
396 }
397
398 drop(broadcast_callbacks);
399 drop(cdc_callbacks);
400 }
401 }));
402 }
403
404 async fn unsubscribe(&mut self) -> Result<ChannelState, ChannelSendError> {
406 let state = self.state.clone();
407 {
408 let state = state.lock().await;
409 if *state == ChannelState::Closed || *state == ChannelState::Leaving {
410 return Ok(*state);
411 }
412 }
413
414 match self
415 .send(RealtimeMessage {
416 event: MessageEvent::PhxLeave,
417 topic: self.topic.clone(),
418 payload: Payload::Empty {},
419 message_ref: Some(format!("{}+leave", self.id)),
420 })
421 .await
422 {
423 Ok(()) => {
424 let mut state = state.lock().await;
425 *state = ChannelState::Leaving;
426 Ok(*state)
427 }
428 Err(ChannelSendError::ChannelError(status)) => Ok(status),
429 Err(e) => Err(e),
430 }
431 }
432
433 async fn track(&mut self, payload: HashMap<String, Value>) -> Result<(), ChannelSendError> {
435 self.send(RealtimeMessage {
436 event: MessageEvent::Presence,
437 topic: self.topic.clone(),
438 payload: Payload::PresenceTrack(payload.into()),
439 message_ref: None,
440 })
441 .await
442 }
443
444 async fn untrack(&mut self) -> Result<(), ChannelSendError> {
446 self.send(RealtimeMessage {
447 event: MessageEvent::Untrack,
448 topic: self.topic.clone(),
449 payload: Payload::Empty {},
450 message_ref: None,
451 })
452 .await
453 }
454
455 async fn send(&mut self, message: RealtimeMessage) -> Result<(), ChannelSendError> {
456 let mut message = message.clone();
458 message.topic = self.topic.clone();
459
460 let state = self.state.lock().await;
461
462 if *state == ChannelState::Leaving {
463 return Err(ChannelSendError::ChannelError(*state));
464 }
465
466 match self.client_tx.send(message) {
467 Ok(()) => Ok(()),
468 Err(e) => Err(ChannelSendError::SendError(e)),
469 }
470 }
471
472 async fn broadcast(&mut self, payload: BroadcastPayload) -> Result<(), ChannelSendError> {
473 self.send(RealtimeMessage {
474 event: MessageEvent::Broadcast,
475 topic: "".into(),
476 payload: Payload::Broadcast(payload),
477 message_ref: None,
478 })
479 .await
480 }
481
482 async fn reauth(&mut self) -> Result<(), ChannelSendError> {
483 let access_token = self.access_token.lock().await;
485
486 self.join_payload.access_token = access_token.clone();
487
488 let state = self.state.lock().await;
489
490 if *state != ChannelState::Joined {
491 return Ok(());
492 }
493
494 drop(state);
495
496 let access_token_message = RealtimeMessage {
497 event: MessageEvent::AccessToken,
498 topic: self.topic.clone(),
499 payload: Payload::AccessToken(AccessTokenPayload {
500 access_token: access_token.clone(),
501 }),
502 ..Default::default()
503 };
504
505 drop(access_token);
506
507 self.send(access_token_message).await
508 }
509}
510
/// Builder for a realtime channel. It collects the topic, join configuration and event
/// callbacks, then spawns the channel's tasks in `build` / `build_sync`.
511pub struct RealtimeChannelBuilder {
    // Full topic, always stored with the `realtime:` prefix.
513 topic: String,
514 broadcast: BroadcastConfig,
515 presence: PresenceConfig,
    // Random id assigned in `new`; becomes the channel's join `message_ref`.
516 id: Uuid,
    // Postgres-change subscriptions sent in the join payload's config.
517 postgres_changes: Vec<PostgresChange>,
518 cdc_callbacks: HashMap<PostgresChangesEvent, Vec<CdcCallback>>,
519 broadcast_callbacks: HashMap<String, Vec<BroadcastCallback>>,
520 presence_callbacks: PresenceCallbackMap,
521}
522
523impl RealtimeChannelBuilder {
524 pub fn new(topic: impl Into<String>) -> Self {
527 Self {
528 topic: format!("realtime:{}", topic.into()),
529 broadcast: Default::default(),
530 presence: Default::default(),
531 id: Uuid::new_v4(),
532 postgres_changes: Default::default(),
533 cdc_callbacks: Default::default(),
534 broadcast_callbacks: Default::default(),
535 presence_callbacks: Default::default(),
536 }
537 }
538
539 pub fn topic(mut self, topic: impl Into<String>) -> Self {
541 self.topic = format!("realtime:{}", topic.into());
542 self
543 }
544
545 pub fn broadcast(mut self, broadcast_config: BroadcastConfig) -> Self {
547 self.broadcast = broadcast_config;
548 self
549 }
550
551 pub fn presence(mut self, presence_config: PresenceConfig) -> Self {
553 self.presence = presence_config;
554 self
555 }
556
557 pub fn on_postgres_change(
559 mut self,
560 event: PostgresChangesEvent,
561 filter: PostgresChangeFilter,
562 callback: impl FnMut(&PostgresChangesPayload) + 'static + Send,
563 ) -> Self {
564 self.postgres_changes.push(PostgresChange {
565 event: event.clone(),
566 schema: filter.schema.clone(),
567 table: filter.table.clone().unwrap_or("".into()),
568 filter: filter.filter.clone(),
569 });
570
571 if self.cdc_callbacks.get_mut(&event).is_none() {
572 self.cdc_callbacks.insert(event.clone(), vec![]);
573 }
574
575 self.cdc_callbacks
576 .get_mut(&event)
577 .unwrap_or(&mut vec![])
578 .push((filter, Box::new(callback)));
579
580 self
581 }
582
583 pub fn on_presence(
585 mut self,
586 event: PresenceEvent,
587 callback: impl Fn(String, PresenceState, PresenceState) + Send + 'static,
588 ) -> Self {
589 if self.presence_callbacks.get_mut(&event).is_none() {
590 self.presence_callbacks.insert(event.clone(), vec![]);
591 }
592
593 self.presence_callbacks
594 .get_mut(&event)
595 .unwrap_or(&mut vec![])
596 .push(Box::new(callback));
597
598 self
599 }
600
601 pub fn on_broadcast(
603 mut self,
604 event: impl Into<String>,
605 callback: impl FnMut(&HashMap<String, Value>) + 'static + Send,
606 ) -> Self {
607 let event: String = event.into();
608
609 if self.broadcast_callbacks.get_mut(&event).is_none() {
610 self.broadcast_callbacks.insert(event.clone(), vec![]);
611 }
612
613 self.broadcast_callbacks
614 .get_mut(&event)
615 .unwrap_or(&mut vec![])
616 .push(Box::new(callback));
617
618 self
619 }
620
621 fn build_common(
622 self,
623 client_tx: UnboundedSender<RealtimeMessage>,
624 access_token: String,
625 access_token_arc: Arc<Mutex<String>>,
626 rt: Arc<Runtime>,
627 ) -> ChannelManager {
628 let state = Arc::new(Mutex::new(ChannelState::Closed));
629 let cdc_callbacks = Arc::new(Mutex::new(self.cdc_callbacks));
630 let broadcast_callbacks = Arc::new(Mutex::new(self.broadcast_callbacks));
631 let (controller_tx, controller_rx) = mpsc::unbounded_channel::<ChannelManagerMessage>();
632
633 let mut channel = RealtimeChannel {
634 access_token: access_token_arc,
635 rt: rt.clone(),
636 tx: None,
637 topic: self.topic,
638 cdc_callbacks,
639 broadcast_callbacks,
640 client_tx,
641 state,
642 id: self.id,
643 join_payload: JoinPayload {
644 config: JoinConfig {
645 broadcast: self.broadcast,
646 presence: self.presence,
647 postgres_changes: self.postgres_changes,
648 },
649 access_token,
650 },
651 presence: Arc::new(Mutex::new(RealtimePresence::from_channel_builder(
652 self.presence_callbacks,
653 ))),
654 manager_channel: (controller_tx, controller_rx),
655 message_handle: None,
656 };
657
658 channel.client_recv(); let tx = channel.manager_channel.0.clone();
660
661 let _handle = rt.spawn(async move { channel.manager_recv().await });
662
663 ChannelManager { tx, rt }
664 }
665
666 pub fn build_sync(self, client: &ClientManagerSync) -> Result<ChannelManagerSync, RecvError> {
673 let client_tx = client.clone().get_ws_tx().unwrap();
674 let access_token = client.clone().get_access_token().unwrap();
675 let access_token_arc = client.clone().get_access_token_arc().unwrap();
676
677 let channel_manager =
678 self.build_common(client_tx, access_token, access_token_arc, client.get_rt());
679
680 client.add_channel(channel_manager.clone()).unwrap();
681
682 Ok(channel_manager.to_sync())
683 }
684
685 pub async fn build(self, client: &ClientManager) -> Result<ChannelManager, RecvError> {
690 let client_tx = client.clone().get_ws_tx().await?;
691 let access_token = client.clone().get_access_token().await?;
692 let access_token_arc = client.clone().get_access_token_arc().await?;
693
694 let channel_manager =
695 self.build_common(client_tx, access_token, access_token_arc, client.get_rt());
696
697 client.add_channel(channel_manager.clone()).await.unwrap();
698
699 Ok(channel_manager)
700 }
701}