1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use futures_util::future::{select, Either};
7use serde_json::{json, Value};
8use supabase_client_core::platform;
9use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
10use tracing::{debug, info, trace, warn};
11
12use crate::callback::Binding;
13use crate::channel::{ChannelBuilder, RealtimeChannel};
14use crate::error::RealtimeError;
15use crate::presence;
16use crate::protocol::{self, RefCounter};
17use crate::transport::{self, WsMessage, WsRead, WsSink};
18use crate::types::{
19 BroadcastConfig, ChannelState, JoinPayload, PhoenixMessage, PostgresChangePayload,
20 PostgresChangesEvent, PresenceDiff, RealtimeConfig, SubscriptionStatus,
21};
22
23#[derive(Clone)]
27pub struct ClientSender {
28 inner: Arc<RealtimeClientInner>,
29}
30
31impl ClientSender {
32 pub(crate) async fn subscribe_channel(
34 &self,
35 channel: RealtimeChannel,
36 join_payload: JoinPayload,
37 timeout_dur: Duration,
38 ) -> Result<(), RealtimeError> {
39 let topic = channel.topic().to_string();
40
41 {
43 let channels = self.inner.channels.read().await;
44 if channels.contains_key(&topic) {
45 return Err(RealtimeError::ChannelAlreadyExists(topic));
46 }
47 }
48
49 let join_msg = protocol::build_join(&topic, &join_payload, &self.inner.ref_counter);
51 let join_ref = join_msg.join_ref.clone().unwrap();
52
53 {
55 let mut ch_join_ref = channel.inner.join_ref.write().await;
56 *ch_join_ref = Some(join_ref.clone());
57 }
58
59 let (reply_tx, reply_rx) = oneshot::channel();
61 {
62 let mut pending = self.inner.pending_replies.lock().await;
63 pending.insert(join_ref.clone(), reply_tx);
64 }
65
66 {
68 let mut channels = self.inner.channels.write().await;
69 channels.insert(topic.clone(), channel.clone());
70 }
71
72 self.send_message(join_msg).await?;
74
75 let result = platform::timeout(timeout_dur, reply_rx).await;
77
78 match result {
79 Ok(Ok(reply)) => {
80 let status = reply
81 .payload
82 .get("status")
83 .and_then(|s| s.as_str())
84 .unwrap_or("");
85 if status == "ok" {
86 if let Some(pg_changes) = reply
88 .payload
89 .get("response")
90 .and_then(|r| r.get("postgres_changes"))
91 .and_then(|pc| pc.as_array())
92 {
93 let mut id_map = channel.inner.pg_change_id_map.write().await;
94 for (index, entry) in pg_changes.iter().enumerate() {
95 if let Some(server_id) = entry.get("id").and_then(|id| id.as_u64()) {
96 id_map.insert(server_id, index);
97 }
98 }
99 }
100
101 *channel.inner.state.write().await = ChannelState::Joined;
102 let status_cb = channel.inner.registry.status_callback.read().await;
104 if let Some(cb) = status_cb.as_ref() {
105 cb(SubscriptionStatus::Subscribed, None);
106 }
107 Ok(())
108 } else {
109 *channel.inner.state.write().await = ChannelState::Errored;
110 let reason = reply
111 .payload
112 .get("response")
113 .and_then(|r| r.get("reason"))
114 .and_then(|r| r.as_str())
115 .unwrap_or("unknown")
116 .to_string();
117 self.inner.channels.write().await.remove(&topic);
119 let status_cb = channel.inner.registry.status_callback.read().await;
121 if let Some(cb) = status_cb.as_ref() {
122 cb(
123 SubscriptionStatus::ChannelError,
124 Some(RealtimeError::ServerError(reason.clone())),
125 );
126 }
127 Err(RealtimeError::ServerError(reason))
128 }
129 }
130 Ok(Err(_)) => {
131 *channel.inner.state.write().await = ChannelState::Errored;
132 self.inner.channels.write().await.remove(&topic);
133 Err(RealtimeError::ConnectionClosed)
134 }
135 Err(_) => {
136 *channel.inner.state.write().await = ChannelState::Errored;
137 self.inner.channels.write().await.remove(&topic);
138 self.inner.pending_replies.lock().await.remove(&join_ref);
140 let status_cb = channel.inner.registry.status_callback.read().await;
141 if let Some(cb) = status_cb.as_ref() {
142 cb(SubscriptionStatus::TimedOut, None);
143 }
144 Err(RealtimeError::SubscribeTimeout(timeout_dur))
145 }
146 }
147 }
148
149 pub(crate) async fn send_broadcast(
150 &self,
151 topic: &str,
152 event: &str,
153 payload: Value,
154 join_ref: &str,
155 ) -> Result<(), RealtimeError> {
156 let msg =
157 protocol::build_broadcast(topic, event, payload, join_ref, &self.inner.ref_counter);
158 self.send_message(msg).await
159 }
160
161 pub(crate) async fn send_presence_track(
162 &self,
163 topic: &str,
164 payload: Value,
165 join_ref: &str,
166 ) -> Result<(), RealtimeError> {
167 let msg =
168 protocol::build_presence_track(topic, payload, join_ref, &self.inner.ref_counter);
169 self.send_message(msg).await
170 }
171
172 pub(crate) async fn send_presence_untrack(
173 &self,
174 topic: &str,
175 join_ref: &str,
176 ) -> Result<(), RealtimeError> {
177 let msg = protocol::build_presence_untrack(topic, join_ref, &self.inner.ref_counter);
178 self.send_message(msg).await
179 }
180
181 pub(crate) async fn send_leave(
182 &self,
183 topic: &str,
184 join_ref: &str,
185 ) -> Result<(), RealtimeError> {
186 let msg = protocol::build_leave(topic, join_ref, &self.inner.ref_counter);
187 self.send_message(msg).await
188 }
189
190 pub(crate) async fn send_access_token(
191 &self,
192 topic: &str,
193 token: &str,
194 join_ref: &str,
195 ) -> Result<(), RealtimeError> {
196 let msg =
197 protocol::build_access_token(topic, token, join_ref, &self.inner.ref_counter);
198 self.send_message(msg).await
199 }
200
201 async fn send_message(&self, msg: PhoenixMessage) -> Result<(), RealtimeError> {
202 let text = serde_json::to_string(&msg)?;
203 let mut ws = self.inner.ws_write.lock().await;
204 let sink = ws
205 .as_mut()
206 .ok_or(RealtimeError::ConnectionClosed)?;
207 trace!(topic = %msg.topic, event = %msg.event, "Sending WS message");
208 transport::send_text(sink, text).await
209 }
210}
211
212struct RealtimeClientInner {
215 config: RealtimeConfig,
216 ws_write: Mutex<Option<WsSink>>,
217 channels: RwLock<HashMap<String, RealtimeChannel>>,
218 ref_counter: RefCounter,
219 pending_replies: Mutex<HashMap<String, oneshot::Sender<PhoenixMessage>>>,
220 connected: AtomicBool,
221 intentional_disconnect: AtomicBool,
222 shutdown_tx: broadcast::Sender<()>,
223}
224
225#[derive(Clone)]
229pub struct RealtimeClient {
230 inner: Arc<RealtimeClientInner>,
231}
232
233impl RealtimeClient {
234 pub fn new(
236 url: impl Into<String>,
237 api_key: impl Into<String>,
238 ) -> Result<Self, RealtimeError> {
239 let config = RealtimeConfig::new(url, api_key);
240 Self::with_config(config)
241 }
242
243 pub fn with_config(config: RealtimeConfig) -> Result<Self, RealtimeError> {
245 if config.url.is_empty() {
246 return Err(RealtimeError::InvalidConfig(
247 "URL must not be empty".to_string(),
248 ));
249 }
250 if config.api_key.is_empty() {
251 return Err(RealtimeError::InvalidConfig(
252 "API key must not be empty".to_string(),
253 ));
254 }
255
256 let (shutdown_tx, _) = broadcast::channel(1);
257
258 Ok(Self {
259 inner: Arc::new(RealtimeClientInner {
260 config,
261 ws_write: Mutex::new(None),
262 channels: RwLock::new(HashMap::new()),
263 ref_counter: RefCounter::new(),
264 pending_replies: Mutex::new(HashMap::new()),
265 connected: AtomicBool::new(false),
266 intentional_disconnect: AtomicBool::new(false),
267 shutdown_tx,
268 }),
269 })
270 }
271
272 pub async fn connect(&self) -> Result<(), RealtimeError> {
277 self.inner.intentional_disconnect.store(false, Ordering::SeqCst);
278
279 let ws_url = build_ws_url(&self.inner.config.url, &self.inner.config.api_key)?;
280 debug!(url = %ws_url, "Connecting to Supabase Realtime");
281
282 let (write, read) = transport::connect_ws(&self.inner.config, &ws_url).await?;
283 *self.inner.ws_write.lock().await = Some(write);
284 self.inner.connected.store(true, Ordering::SeqCst);
285
286 let inner = Arc::clone(&self.inner);
288 let ws_url_owned = ws_url;
289 platform::spawn(async move {
290 run_reader_loop(inner, read, ws_url_owned).await;
291 });
292
293 spawn_heartbeat(Arc::clone(&self.inner));
295
296 debug!("Connected to Supabase Realtime");
297 Ok(())
298 }
299
300 pub async fn disconnect(&self) -> Result<(), RealtimeError> {
302 debug!("Disconnecting from Supabase Realtime");
303 self.inner.intentional_disconnect.store(true, Ordering::SeqCst);
304 let _ = self.inner.shutdown_tx.send(());
306 self.inner.connected.store(false, Ordering::SeqCst);
307
308 {
310 let mut ws = self.inner.ws_write.lock().await;
311 if let Some(sink) = ws.as_mut() {
312 let _ = transport::send_close(sink).await;
313 }
314 *ws = None;
315 }
316
317 {
319 let mut pending = self.inner.pending_replies.lock().await;
320 pending.clear();
321 }
322
323 Ok(())
324 }
325
326 pub fn channel(&self, name: &str) -> ChannelBuilder {
330 let topic = format!("realtime:{}", name);
331 ChannelBuilder {
332 name: name.to_string(),
333 topic,
334 broadcast_config: BroadcastConfig::default(),
335 presence_key: String::new(),
336 presence_enabled: false,
337 postgres_changes: Vec::new(),
338 bindings: Vec::new(),
339 is_private: false,
340 subscribe_timeout: self.inner.config.subscribe_timeout,
341 access_token: Some(self.inner.config.api_key.clone()),
342 client_sender: ClientSender {
343 inner: Arc::clone(&self.inner),
344 },
345 }
346 }
347
348 pub async fn remove_channel(
350 &self,
351 channel: &RealtimeChannel,
352 ) -> Result<(), RealtimeError> {
353 let topic = channel.topic().to_string();
354 let state = *channel.inner.state.read().await;
356 if state == ChannelState::Joined || state == ChannelState::Joining {
357 let _ = channel.unsubscribe().await;
358 }
359 *channel.inner.state.write().await = ChannelState::Closed;
360 self.inner.channels.write().await.remove(&topic);
361 Ok(())
362 }
363
364 pub async fn remove_all_channels(&self) -> Result<(), RealtimeError> {
366 let channels: Vec<RealtimeChannel> = {
367 self.inner.channels.read().await.values().cloned().collect()
368 };
369 for ch in channels {
370 self.remove_channel(&ch).await?;
371 }
372 Ok(())
373 }
374
375 pub fn channels(&self) -> Vec<RealtimeChannel> {
377 match self.inner.channels.try_read() {
379 Ok(channels) => channels.values().cloned().collect(),
380 Err(_) => Vec::new(),
381 }
382 }
383
384 pub async fn set_auth(&self, token: &str) -> Result<(), RealtimeError> {
390 if !self.is_connected() {
391 return Err(RealtimeError::ConnectionClosed);
392 }
393
394 let channels: Vec<RealtimeChannel> = {
395 self.inner.channels.read().await.values().cloned().collect()
396 };
397
398 let sender = ClientSender {
399 inner: Arc::clone(&self.inner),
400 };
401
402 for channel in &channels {
403 let state = *channel.inner.state.read().await;
404 if state == ChannelState::Joined {
405 let join_ref = channel.inner.join_ref.read().await;
406 if let Some(ref jr) = *join_ref {
407 sender
408 .send_access_token(channel.topic(), token, jr)
409 .await?;
410 }
411 }
412 }
413
414 Ok(())
415 }
416
417 pub fn is_connected(&self) -> bool {
419 self.inner.connected.load(Ordering::SeqCst)
420 }
421}
422
423pub(crate) fn build_ws_url(base_url: &str, api_key: &str) -> Result<String, RealtimeError> {
427 let mut parsed = url::Url::parse(base_url)?;
428
429 let ws_scheme = match parsed.scheme() {
431 "http" | "ws" => "ws",
432 "https" | "wss" => "wss",
433 other => {
434 return Err(RealtimeError::InvalidConfig(format!(
435 "Unsupported URL scheme: {}",
436 other
437 )));
438 }
439 };
440 parsed
441 .set_scheme(ws_scheme)
442 .map_err(|_| RealtimeError::InvalidConfig("Failed to set WS scheme".to_string()))?;
443
444 {
446 let mut path = parsed.path().to_string();
447 if !path.ends_with('/') {
448 path.push('/');
449 }
450 path.push_str("realtime/v1/websocket");
451 parsed.set_path(&path);
452 }
453
454 parsed
456 .query_pairs_mut()
457 .append_pair("apikey", api_key)
458 .append_pair("vsn", "1.0.0");
459
460 Ok(parsed.to_string())
461}
462
463async fn run_reader_loop(
470 inner: Arc<RealtimeClientInner>,
471 initial_read: WsRead,
472 ws_url: String,
473) {
474 let mut read = initial_read;
475 let mut shutdown_rx = inner.shutdown_tx.subscribe();
476
477 loop {
478 let disconnected_by_shutdown = read_until_disconnect(&inner, &mut read, &mut shutdown_rx).await;
480
481 if disconnected_by_shutdown || inner.intentional_disconnect.load(Ordering::SeqCst) {
482 break;
483 }
484
485 match attempt_reconnect(&inner, &ws_url).await {
487 Some(new_read) => {
488 read = new_read;
489 spawn_heartbeat(Arc::clone(&inner));
491 if let Err(e) = rejoin_channels(&inner).await {
493 warn!(error = %e, "Failed to rejoin channels after reconnect");
494 }
495 }
496 None => {
497 notify_all_channels_closed(&inner).await;
499 break;
500 }
501 }
502 }
503}
504
505async fn read_until_disconnect(
508 inner: &RealtimeClientInner,
509 read: &mut WsRead,
510 shutdown_rx: &mut broadcast::Receiver<()>,
511) -> bool {
512 loop {
513 let recv_fut = transport::recv_message(read);
515 let shutdown_fut = shutdown_rx.recv();
516
517 futures_util::pin_mut!(recv_fut);
518 futures_util::pin_mut!(shutdown_fut);
519
520 match select(recv_fut, shutdown_fut).await {
521 Either::Left((msg, _)) => {
522 match msg {
523 Some(Ok(WsMessage::Text(text))) => {
524 handle_message(inner, &text).await;
525 }
526 Some(Ok(WsMessage::Close)) => {
527 debug!("WebSocket closed by server");
528 inner.connected.store(false, Ordering::SeqCst);
529 return false;
530 }
531 Some(Ok(WsMessage::Ping(data))) => {
532 let mut ws = inner.ws_write.lock().await;
533 if let Some(sink) = ws.as_mut() {
534 let _ = transport::send_pong(sink, data).await;
535 }
536 }
537 Some(Err(e)) => {
538 warn!(error = %e, "WebSocket read error");
539 inner.connected.store(false, Ordering::SeqCst);
540 return false;
541 }
542 None => {
543 debug!("WebSocket stream ended");
544 inner.connected.store(false, Ordering::SeqCst);
545 return false;
546 }
547 }
548 }
549 Either::Right(_) => {
550 debug!("Reader task shutting down");
551 return true;
552 }
553 }
554 }
555}
556
557fn spawn_heartbeat(inner: Arc<RealtimeClientInner>) {
559 let mut shutdown_rx = inner.shutdown_tx.subscribe();
560 let heartbeat_interval = inner.config.heartbeat_interval;
561 platform::spawn(async move {
562 loop {
563 let sleep_fut = platform::sleep(heartbeat_interval);
565 let shutdown_fut = shutdown_rx.recv();
566 futures_util::pin_mut!(sleep_fut);
567 futures_util::pin_mut!(shutdown_fut);
568
569 match select(sleep_fut, shutdown_fut).await {
570 Either::Left(_) => {
571 if !inner.connected.load(Ordering::SeqCst) {
573 break;
574 }
575 let heartbeat = protocol::build_heartbeat(&inner.ref_counter);
576 let text = match serde_json::to_string(&heartbeat) {
577 Ok(t) => t,
578 Err(_) => continue,
579 };
580 let mut ws = inner.ws_write.lock().await;
581 if let Some(sink) = ws.as_mut() {
582 if let Err(e) = transport::send_text(sink, text).await {
583 warn!(error = %e, "Heartbeat send failed");
584 inner.connected.store(false, Ordering::SeqCst);
585 break;
586 }
587 trace!("Heartbeat sent");
588 }
589 }
590 Either::Right(_) => {
591 debug!("Heartbeat task shutting down");
592 break;
593 }
594 }
595 }
596 });
597}
598
599async fn attempt_reconnect(
602 inner: &Arc<RealtimeClientInner>,
603 ws_url: &str,
604) -> Option<WsRead> {
605 let config = &inner.config;
606
607 let intervals = config.reconnect.intervals.iter().copied()
609 .chain(std::iter::repeat(config.reconnect.fallback));
610
611 let max_attempts = config.reconnect.intervals.len() + 3;
612
613 for (attempt, delay) in intervals.enumerate().take(max_attempts) {
614 if inner.intentional_disconnect.load(Ordering::SeqCst) {
615 return None;
616 }
617
618 info!(attempt = attempt + 1, delay_ms = delay.as_millis(), "Attempting reconnect");
619 platform::sleep(delay).await;
620
621 if inner.intentional_disconnect.load(Ordering::SeqCst) {
622 return None;
623 }
624
625 match transport::connect_ws(&config, ws_url).await {
626 Ok((write, read)) => {
627 *inner.ws_write.lock().await = Some(write);
628 inner.connected.store(true, Ordering::SeqCst);
629 info!("Reconnected successfully");
630 return Some(read);
631 }
632 Err(e) => {
633 warn!(error = %e, attempt = attempt + 1, "Reconnect attempt failed");
634 }
635 }
636 }
637
638 warn!("All reconnect attempts exhausted");
639 None
640}
641
642async fn rejoin_channels(inner: &RealtimeClientInner) -> Result<(), RealtimeError> {
644 let channels = inner.channels.read().await;
645 for (topic, channel) in channels.iter() {
646 let state = *channel.inner.state.read().await;
647 if state == ChannelState::Joined || state == ChannelState::Joining {
648 debug!(topic = %topic, "Rejoining channel after reconnect");
649 let join_ref = inner.ref_counter.next();
650 let msg_ref = inner.ref_counter.next();
651 let join_payload = channel.inner.join_payload.read().await.clone();
653 let phoenix_msg = PhoenixMessage {
654 event: "phx_join".to_string(),
655 topic: topic.clone(),
656 payload: serde_json::to_value(&join_payload).unwrap_or(json!({})),
657 msg_ref: Some(msg_ref),
658 join_ref: Some(join_ref),
659 };
660 let text = serde_json::to_string(&phoenix_msg)
661 .map_err(|e| RealtimeError::ServerError(format!("JSON error: {}", e)))?;
662 let mut ws = inner.ws_write.lock().await;
663 if let Some(sink) = ws.as_mut() {
664 transport::send_text(sink, text).await?;
665 }
666 *channel.inner.state.write().await = ChannelState::Joining;
667 }
668 }
669 Ok(())
670}
671
672async fn handle_message(inner: &RealtimeClientInner, text: &str) {
675 let msg: PhoenixMessage = match serde_json::from_str(text) {
676 Ok(m) => m,
677 Err(e) => {
678 warn!(error = %e, "Failed to parse Phoenix message");
679 return;
680 }
681 };
682
683 trace!(
684 topic = %msg.topic,
685 event = %msg.event,
686 "Received WS message"
687 );
688
689 match msg.event.as_str() {
690 "phx_reply" => handle_phx_reply(inner, msg).await,
691 "postgres_changes" => handle_postgres_changes(inner, msg).await,
692 "broadcast" => handle_broadcast(inner, msg).await,
693 "presence_state" => handle_presence_state(inner, msg).await,
694 "presence_diff" => handle_presence_diff(inner, msg).await,
695 "phx_close" => handle_phx_close(inner, msg).await,
696 "phx_error" => handle_phx_error(inner, msg).await,
697 "system" => handle_system(inner, msg).await,
698 _ => {
699 trace!(event = %msg.event, "Unhandled event type");
700 }
701 }
702}
703
704async fn handle_phx_reply(inner: &RealtimeClientInner, msg: PhoenixMessage) {
705 if let Some(ref ref_id) = msg.msg_ref {
707 let mut pending = inner.pending_replies.lock().await;
708 if let Some(tx) = pending.remove(ref_id) {
709 let _ = tx.send(msg);
710 return;
711 }
712 }
713 if let Some(ref join_ref) = msg.join_ref {
715 let mut pending = inner.pending_replies.lock().await;
716 if let Some(tx) = pending.remove(join_ref) {
717 let _ = tx.send(msg);
718 return;
719 }
720 }
721}
722
723async fn handle_postgres_changes(inner: &RealtimeClientInner, msg: PhoenixMessage) {
724 let channels = inner.channels.read().await;
725 let channel = match channels.get(&msg.topic) {
726 Some(ch) => ch,
727 None => return,
728 };
729
730 let data = &msg.payload;
732
733 let ids_val = data.get("ids").and_then(|v| v.as_array());
735
736 let change_data = match data.get("data") {
738 Some(d) => d,
739 None => {
740 data
742 }
743 };
744
745 let payload: PostgresChangePayload = match serde_json::from_value(change_data.clone()) {
746 Ok(p) => p,
747 Err(e) => {
748 warn!(error = %e, "Failed to parse postgres change payload");
749 return;
750 }
751 };
752
753 let id_map = channel.inner.pg_change_id_map.read().await;
755 let matched_indices: Vec<usize> = match ids_val {
756 Some(ids) => ids
757 .iter()
758 .filter_map(|id| id.as_u64())
759 .filter_map(|server_id| id_map.get(&server_id).copied())
760 .collect(),
761 None => Vec::new(),
762 };
763 drop(id_map);
764
765 let bindings = channel.inner.registry.bindings.read().await;
767 for binding in bindings.iter() {
768 if let Binding::PostgresChanges {
769 filter_index,
770 event,
771 callback,
772 } = binding
773 {
774 let matches_id = matched_indices.is_empty() || matched_indices.contains(filter_index);
776
777 let event_matches = match event {
779 PostgresChangesEvent::All => true,
780 PostgresChangesEvent::Insert => payload.change_type == "INSERT",
781 PostgresChangesEvent::Update => payload.change_type == "UPDATE",
782 PostgresChangesEvent::Delete => payload.change_type == "DELETE",
783 };
784
785 if matches_id && event_matches {
786 callback(payload.clone());
787 }
788 }
789 }
790}
791
792async fn handle_broadcast(inner: &RealtimeClientInner, msg: PhoenixMessage) {
793 let channels = inner.channels.read().await;
794 let channel = match channels.get(&msg.topic) {
795 Some(ch) => ch,
796 None => return,
797 };
798
799 let event = msg
800 .payload
801 .get("event")
802 .and_then(|e| e.as_str())
803 .unwrap_or("");
804 let payload = msg
805 .payload
806 .get("payload")
807 .cloned()
808 .unwrap_or(json!({}));
809
810 let bindings = channel.inner.registry.bindings.read().await;
811 for binding in bindings.iter() {
812 if let Binding::Broadcast {
813 event: bind_event,
814 callback,
815 } = binding
816 {
817 if bind_event == event {
818 callback(payload.clone());
819 }
820 }
821 }
822}
823
824async fn handle_presence_state(inner: &RealtimeClientInner, msg: PhoenixMessage) {
825 let channels = inner.channels.read().await;
826 let channel = match channels.get(&msg.topic) {
827 Some(ch) => ch,
828 None => return,
829 };
830
831 let new_state = presence::apply_state(msg.payload);
832 *channel.inner.presence_state.write().await = new_state.clone();
833
834 let bindings = channel.inner.registry.bindings.read().await;
836 for binding in bindings.iter() {
837 if let Binding::PresenceSync(callback) = binding {
838 callback(&new_state);
839 }
840 }
841}
842
843async fn handle_presence_diff(inner: &RealtimeClientInner, msg: PhoenixMessage) {
844 let channels = inner.channels.read().await;
845 let channel = match channels.get(&msg.topic) {
846 Some(ch) => ch,
847 None => return,
848 };
849
850 let diff: PresenceDiff = match serde_json::from_value(msg.payload) {
851 Ok(d) => d,
852 Err(e) => {
853 warn!(error = %e, "Failed to parse presence diff");
854 return;
855 }
856 };
857
858 let (joins, leaves) = {
859 let mut state = channel.inner.presence_state.write().await;
860 presence::apply_diff(&mut state, diff)
861 };
862
863 let state = channel.inner.presence_state.read().await;
864
865 let bindings = channel.inner.registry.bindings.read().await;
867 for binding in bindings.iter() {
868 match binding {
869 Binding::PresenceJoin(callback) => {
870 for (key, metas) in &joins {
871 callback(key.clone(), metas.clone());
872 }
873 }
874 Binding::PresenceLeave(callback) => {
875 for (key, metas) in &leaves {
876 callback(key.clone(), metas.clone());
877 }
878 }
879 Binding::PresenceSync(callback) => {
880 callback(&state);
881 }
882 _ => {}
883 }
884 }
885}
886
887async fn handle_phx_close(inner: &RealtimeClientInner, msg: PhoenixMessage) {
888 let channels = inner.channels.read().await;
889 if let Some(channel) = channels.get(&msg.topic) {
890 *channel.inner.state.write().await = ChannelState::Closed;
891 let status_cb = channel.inner.registry.status_callback.read().await;
892 if let Some(cb) = status_cb.as_ref() {
893 cb(SubscriptionStatus::Closed, None);
894 }
895 }
896}
897
898async fn handle_phx_error(inner: &RealtimeClientInner, msg: PhoenixMessage) {
899 let channels = inner.channels.read().await;
900 if let Some(channel) = channels.get(&msg.topic) {
901 *channel.inner.state.write().await = ChannelState::Errored;
902 let reason = msg
903 .payload
904 .get("reason")
905 .and_then(|r| r.as_str())
906 .unwrap_or("unknown")
907 .to_string();
908 let status_cb = channel.inner.registry.status_callback.read().await;
909 if let Some(cb) = status_cb.as_ref() {
910 cb(
911 SubscriptionStatus::ChannelError,
912 Some(RealtimeError::ServerError(reason)),
913 );
914 }
915 }
916}
917
918async fn handle_system(_inner: &RealtimeClientInner, msg: PhoenixMessage) {
919 debug!(
921 topic = %msg.topic,
922 payload = %msg.payload,
923 "System message received"
924 );
925}
926
927async fn notify_all_channels_closed(inner: &RealtimeClientInner) {
928 let channels = inner.channels.read().await;
929 for channel in channels.values() {
930 let current = *channel.inner.state.read().await;
931 if current == ChannelState::Joined || current == ChannelState::Joining {
932 *channel.inner.state.write().await = ChannelState::Closed;
933 let status_cb = channel.inner.registry.status_callback.read().await;
934 if let Some(cb) = status_cb.as_ref() {
935 cb(SubscriptionStatus::Closed, None);
936 }
937 }
938 }
939}
940
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ReconnectConfig;

    /// Base URL shared by the tests; no test connects to it.
    const LOCAL_URL: &str = "http://localhost:54321";
    /// API key shared by the tests.
    const TEST_KEY: &str = "test-key";

    /// `http` is mapped to `ws` and the endpoint path and query are appended.
    #[test]
    fn test_build_ws_url_http() {
        let built = build_ws_url(LOCAL_URL, TEST_KEY).unwrap();
        assert_eq!(
            built,
            "ws://localhost:54321/realtime/v1/websocket?apikey=test-key&vsn=1.0.0"
        );
    }

    /// `https` is mapped to `wss`.
    #[test]
    fn test_build_ws_url_https() {
        let built = build_ws_url("https://example.supabase.co", "anon-key").unwrap();
        assert_eq!(
            built,
            "wss://example.supabase.co/realtime/v1/websocket?apikey=anon-key&vsn=1.0.0"
        );
    }

    /// A trailing slash on the base URL does not produce a double slash.
    #[test]
    fn test_build_ws_url_with_path() {
        let built = build_ws_url("http://localhost:54321/", "key").unwrap();
        assert!(built.starts_with("ws://localhost:54321/realtime/v1/websocket"));
    }

    /// Schemes other than http(s)/ws(s) are rejected.
    #[test]
    fn test_build_ws_url_invalid_scheme() {
        assert!(build_ws_url("ftp://localhost", "key").is_err());
    }

    /// `set_auth` fails on a client that was never connected.
    #[test]
    fn test_set_auth_requires_connection() {
        let client = RealtimeClient::new(LOCAL_URL, TEST_KEY).unwrap();
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let outcome = runtime.block_on(client.set_auth("new-token"));
        assert!(outcome.is_err());
    }

    /// Headers passed to `with_headers` are kept on the config.
    #[test]
    fn test_custom_headers_stored() {
        let mut headers = HashMap::new();
        headers.insert("X-Custom-Header".to_string(), "custom-value".to_string());

        let config = RealtimeConfig::new(LOCAL_URL, TEST_KEY).with_headers(headers);

        assert_eq!(config.headers.len(), 1);
        assert_eq!(config.headers.get("X-Custom-Header").unwrap(), "custom-value");
    }

    /// A fresh config has no custom headers.
    #[test]
    fn test_custom_headers_default_empty() {
        let config = RealtimeConfig::new(LOCAL_URL, TEST_KEY);
        assert!(config.headers.is_empty());
    }

    /// A new client is not flagged as intentionally disconnected.
    #[test]
    fn test_intentional_disconnect_flag() {
        let client = RealtimeClient::new(LOCAL_URL, TEST_KEY).unwrap();
        let flagged = client.inner.intentional_disconnect.load(Ordering::SeqCst);
        assert!(!flagged);
    }

    /// Pins the default reconnect schedule.
    #[test]
    fn test_reconnect_config_intervals() {
        let defaults = ReconnectConfig::default();
        assert_eq!(defaults.intervals.len(), 4);
        assert_eq!(defaults.intervals[0], Duration::from_secs(1));
        assert_eq!(defaults.intervals[3], Duration::from_secs(10));
        assert_eq!(defaults.fallback, Duration::from_secs(10));
    }
}