Skip to main content

supabase_client_realtime/
client.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use futures_util::future::{select, Either};
7use serde_json::{json, Value};
8use supabase_client_core::platform;
9use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
10use tracing::{debug, info, trace, warn};
11
12use crate::callback::Binding;
13use crate::channel::{ChannelBuilder, RealtimeChannel};
14use crate::error::RealtimeError;
15use crate::presence;
16use crate::protocol::{self, RefCounter};
17use crate::transport::{self, WsMessage, WsRead, WsSink};
18use crate::types::{
19    BroadcastConfig, ChannelState, JoinPayload, PhoenixMessage, PostgresChangePayload,
20    PostgresChangesEvent, PresenceDiff, RealtimeConfig, SubscriptionStatus,
21};
22
// ── ClientSender ──────────────────────────────────────────────────────────────
24
25/// Handle passed to channels for sending messages through the client's WebSocket.
26#[derive(Clone)]
27pub struct ClientSender {
28    inner: Arc<RealtimeClientInner>,
29}
30
31impl ClientSender {
32    /// Register a channel and send phx_join, waiting for acknowledgement.
33    pub(crate) async fn subscribe_channel(
34        &self,
35        channel: RealtimeChannel,
36        join_payload: JoinPayload,
37        timeout_dur: Duration,
38    ) -> Result<(), RealtimeError> {
39        let topic = channel.topic().to_string();
40
41        // Check if channel already exists
42        {
43            let channels = self.inner.channels.read().await;
44            if channels.contains_key(&topic) {
45                return Err(RealtimeError::ChannelAlreadyExists(topic));
46            }
47        }
48
49        // Build join message
50        let join_msg = protocol::build_join(&topic, &join_payload, &self.inner.ref_counter);
51        let join_ref = join_msg.join_ref.clone().unwrap();
52
53        // Set join_ref on the channel
54        {
55            let mut ch_join_ref = channel.inner.join_ref.write().await;
56            *ch_join_ref = Some(join_ref.clone());
57        }
58
59        // Register pending reply
60        let (reply_tx, reply_rx) = oneshot::channel();
61        {
62            let mut pending = self.inner.pending_replies.lock().await;
63            pending.insert(join_ref.clone(), reply_tx);
64        }
65
66        // Register channel
67        {
68            let mut channels = self.inner.channels.write().await;
69            channels.insert(topic.clone(), channel.clone());
70        }
71
72        // Send join message
73        self.send_message(join_msg).await?;
74
75        // Wait for reply with timeout
76        let result = platform::timeout(timeout_dur, reply_rx).await;
77
78        match result {
79            Ok(Ok(reply)) => {
80                let status = reply
81                    .payload
82                    .get("status")
83                    .and_then(|s| s.as_str())
84                    .unwrap_or("");
85                if status == "ok" {
86                    // Extract server-assigned postgres_changes IDs
87                    if let Some(pg_changes) = reply
88                        .payload
89                        .get("response")
90                        .and_then(|r| r.get("postgres_changes"))
91                        .and_then(|pc| pc.as_array())
92                    {
93                        let mut id_map = channel.inner.pg_change_id_map.write().await;
94                        for (index, entry) in pg_changes.iter().enumerate() {
95                            if let Some(server_id) = entry.get("id").and_then(|id| id.as_u64()) {
96                                id_map.insert(server_id, index);
97                            }
98                        }
99                    }
100
101                    *channel.inner.state.write().await = ChannelState::Joined;
102                    // Notify status callback
103                    let status_cb = channel.inner.registry.status_callback.read().await;
104                    if let Some(cb) = status_cb.as_ref() {
105                        cb(SubscriptionStatus::Subscribed, None);
106                    }
107                    Ok(())
108                } else {
109                    *channel.inner.state.write().await = ChannelState::Errored;
110                    let reason = reply
111                        .payload
112                        .get("response")
113                        .and_then(|r| r.get("reason"))
114                        .and_then(|r| r.as_str())
115                        .unwrap_or("unknown")
116                        .to_string();
117                    // Remove channel on failure
118                    self.inner.channels.write().await.remove(&topic);
119                    // Notify status callback
120                    let status_cb = channel.inner.registry.status_callback.read().await;
121                    if let Some(cb) = status_cb.as_ref() {
122                        cb(
123                            SubscriptionStatus::ChannelError,
124                            Some(RealtimeError::ServerError(reason.clone())),
125                        );
126                    }
127                    Err(RealtimeError::ServerError(reason))
128                }
129            }
130            Ok(Err(_)) => {
131                *channel.inner.state.write().await = ChannelState::Errored;
132                self.inner.channels.write().await.remove(&topic);
133                Err(RealtimeError::ConnectionClosed)
134            }
135            Err(_) => {
136                *channel.inner.state.write().await = ChannelState::Errored;
137                self.inner.channels.write().await.remove(&topic);
138                // Clean up pending reply
139                self.inner.pending_replies.lock().await.remove(&join_ref);
140                let status_cb = channel.inner.registry.status_callback.read().await;
141                if let Some(cb) = status_cb.as_ref() {
142                    cb(SubscriptionStatus::TimedOut, None);
143                }
144                Err(RealtimeError::SubscribeTimeout(timeout_dur))
145            }
146        }
147    }
148
149    pub(crate) async fn send_broadcast(
150        &self,
151        topic: &str,
152        event: &str,
153        payload: Value,
154        join_ref: &str,
155    ) -> Result<(), RealtimeError> {
156        let msg =
157            protocol::build_broadcast(topic, event, payload, join_ref, &self.inner.ref_counter);
158        self.send_message(msg).await
159    }
160
161    pub(crate) async fn send_presence_track(
162        &self,
163        topic: &str,
164        payload: Value,
165        join_ref: &str,
166    ) -> Result<(), RealtimeError> {
167        let msg =
168            protocol::build_presence_track(topic, payload, join_ref, &self.inner.ref_counter);
169        self.send_message(msg).await
170    }
171
172    pub(crate) async fn send_presence_untrack(
173        &self,
174        topic: &str,
175        join_ref: &str,
176    ) -> Result<(), RealtimeError> {
177        let msg = protocol::build_presence_untrack(topic, join_ref, &self.inner.ref_counter);
178        self.send_message(msg).await
179    }
180
181    pub(crate) async fn send_leave(
182        &self,
183        topic: &str,
184        join_ref: &str,
185    ) -> Result<(), RealtimeError> {
186        let msg = protocol::build_leave(topic, join_ref, &self.inner.ref_counter);
187        self.send_message(msg).await
188    }
189
190    pub(crate) async fn send_access_token(
191        &self,
192        topic: &str,
193        token: &str,
194        join_ref: &str,
195    ) -> Result<(), RealtimeError> {
196        let msg =
197            protocol::build_access_token(topic, token, join_ref, &self.inner.ref_counter);
198        self.send_message(msg).await
199    }
200
201    async fn send_message(&self, msg: PhoenixMessage) -> Result<(), RealtimeError> {
202        let text = serde_json::to_string(&msg)?;
203        let mut ws = self.inner.ws_write.lock().await;
204        let sink = ws
205            .as_mut()
206            .ok_or(RealtimeError::ConnectionClosed)?;
207        trace!(topic = %msg.topic, event = %msg.event, "Sending WS message");
208        transport::send_text(sink, text).await
209    }
210}
211
// ── RealtimeClient ────────────────────────────────────────────────────────────
213
214struct RealtimeClientInner {
215    config: RealtimeConfig,
216    ws_write: Mutex<Option<WsSink>>,
217    channels: RwLock<HashMap<String, RealtimeChannel>>,
218    ref_counter: RefCounter,
219    pending_replies: Mutex<HashMap<String, oneshot::Sender<PhoenixMessage>>>,
220    connected: AtomicBool,
221    intentional_disconnect: AtomicBool,
222    shutdown_tx: broadcast::Sender<()>,
223}
224
225/// Client for Supabase Realtime WebSocket connections.
226///
227/// Wraps `Arc<Inner>` — cheaply cloneable, `Send + Sync`.
228#[derive(Clone)]
229pub struct RealtimeClient {
230    inner: Arc<RealtimeClientInner>,
231}
232
233impl RealtimeClient {
234    /// Create a new RealtimeClient from a Supabase URL and API key.
235    pub fn new(
236        url: impl Into<String>,
237        api_key: impl Into<String>,
238    ) -> Result<Self, RealtimeError> {
239        let config = RealtimeConfig::new(url, api_key);
240        Self::with_config(config)
241    }
242
243    /// Create a new RealtimeClient with full configuration.
244    pub fn with_config(config: RealtimeConfig) -> Result<Self, RealtimeError> {
245        if config.url.is_empty() {
246            return Err(RealtimeError::InvalidConfig(
247                "URL must not be empty".to_string(),
248            ));
249        }
250        if config.api_key.is_empty() {
251            return Err(RealtimeError::InvalidConfig(
252                "API key must not be empty".to_string(),
253            ));
254        }
255
256        let (shutdown_tx, _) = broadcast::channel(1);
257
258        Ok(Self {
259            inner: Arc::new(RealtimeClientInner {
260                config,
261                ws_write: Mutex::new(None),
262                channels: RwLock::new(HashMap::new()),
263                ref_counter: RefCounter::new(),
264                pending_replies: Mutex::new(HashMap::new()),
265                connected: AtomicBool::new(false),
266                intentional_disconnect: AtomicBool::new(false),
267                shutdown_tx,
268            }),
269        })
270    }
271
272    /// Connect to the Supabase Realtime server via WebSocket.
273    ///
274    /// Establishes the WebSocket connection and starts background reader,
275    /// heartbeat, and auto-reconnect tasks.
276    pub async fn connect(&self) -> Result<(), RealtimeError> {
277        self.inner.intentional_disconnect.store(false, Ordering::SeqCst);
278
279        let ws_url = build_ws_url(&self.inner.config.url, &self.inner.config.api_key)?;
280        debug!(url = %ws_url, "Connecting to Supabase Realtime");
281
282        let (write, read) = transport::connect_ws(&self.inner.config, &ws_url).await?;
283        *self.inner.ws_write.lock().await = Some(write);
284        self.inner.connected.store(true, Ordering::SeqCst);
285
286        // Spawn the reconnection-aware reader loop
287        let inner = Arc::clone(&self.inner);
288        let ws_url_owned = ws_url;
289        platform::spawn(async move {
290            run_reader_loop(inner, read, ws_url_owned).await;
291        });
292
293        // Start heartbeat task
294        spawn_heartbeat(Arc::clone(&self.inner));
295
296        debug!("Connected to Supabase Realtime");
297        Ok(())
298    }
299
300    /// Disconnect from the Realtime server.
301    pub async fn disconnect(&self) -> Result<(), RealtimeError> {
302        debug!("Disconnecting from Supabase Realtime");
303        self.inner.intentional_disconnect.store(true, Ordering::SeqCst);
304        // Signal background tasks to stop
305        let _ = self.inner.shutdown_tx.send(());
306        self.inner.connected.store(false, Ordering::SeqCst);
307
308        // Close WebSocket
309        {
310            let mut ws = self.inner.ws_write.lock().await;
311            if let Some(sink) = ws.as_mut() {
312                let _ = transport::send_close(sink).await;
313            }
314            *ws = None;
315        }
316
317        // Clear pending replies
318        {
319            let mut pending = self.inner.pending_replies.lock().await;
320            pending.clear();
321        }
322
323        Ok(())
324    }
325
326    /// Create a ChannelBuilder for the given name.
327    ///
328    /// The topic will be `"realtime:<name>"`.
329    pub fn channel(&self, name: &str) -> ChannelBuilder {
330        let topic = format!("realtime:{}", name);
331        ChannelBuilder {
332            name: name.to_string(),
333            topic,
334            broadcast_config: BroadcastConfig::default(),
335            presence_key: String::new(),
336            presence_enabled: false,
337            postgres_changes: Vec::new(),
338            bindings: Vec::new(),
339            is_private: false,
340            subscribe_timeout: self.inner.config.subscribe_timeout,
341            access_token: Some(self.inner.config.api_key.clone()),
342            client_sender: ClientSender {
343                inner: Arc::clone(&self.inner),
344            },
345        }
346    }
347
348    /// Remove a channel (unsubscribe and forget).
349    pub async fn remove_channel(
350        &self,
351        channel: &RealtimeChannel,
352    ) -> Result<(), RealtimeError> {
353        let topic = channel.topic().to_string();
354        // Send leave if joined
355        let state = *channel.inner.state.read().await;
356        if state == ChannelState::Joined || state == ChannelState::Joining {
357            let _ = channel.unsubscribe().await;
358        }
359        *channel.inner.state.write().await = ChannelState::Closed;
360        self.inner.channels.write().await.remove(&topic);
361        Ok(())
362    }
363
364    /// Remove all channels.
365    pub async fn remove_all_channels(&self) -> Result<(), RealtimeError> {
366        let channels: Vec<RealtimeChannel> = {
367            self.inner.channels.read().await.values().cloned().collect()
368        };
369        for ch in channels {
370            self.remove_channel(&ch).await?;
371        }
372        Ok(())
373    }
374
375    /// Get a list of all active channels.
376    pub fn channels(&self) -> Vec<RealtimeChannel> {
377        // Use try_read to avoid blocking; if locked, return empty
378        match self.inner.channels.try_read() {
379            Ok(channels) => channels.values().cloned().collect(),
380            Err(_) => Vec::new(),
381        }
382    }
383
384    /// Update the auth token for the realtime connection.
385    ///
386    /// If connected, pushes the new access token to the server for each subscribed channel.
387    ///
388    /// Mirrors `supabase.realtime.setAuth(token)`.
389    pub async fn set_auth(&self, token: &str) -> Result<(), RealtimeError> {
390        if !self.is_connected() {
391            return Err(RealtimeError::ConnectionClosed);
392        }
393
394        let channels: Vec<RealtimeChannel> = {
395            self.inner.channels.read().await.values().cloned().collect()
396        };
397
398        let sender = ClientSender {
399            inner: Arc::clone(&self.inner),
400        };
401
402        for channel in &channels {
403            let state = *channel.inner.state.read().await;
404            if state == ChannelState::Joined {
405                let join_ref = channel.inner.join_ref.read().await;
406                if let Some(ref jr) = *join_ref {
407                    sender
408                        .send_access_token(channel.topic(), token, jr)
409                        .await?;
410                }
411            }
412        }
413
414        Ok(())
415    }
416
417    /// Check if the client is currently connected.
418    pub fn is_connected(&self) -> bool {
419        self.inner.connected.load(Ordering::SeqCst)
420    }
421}
422
// ── WebSocket URL Construction ────────────────────────────────────────────────
424
425/// Convert a Supabase HTTP URL to a WebSocket URL for the Realtime endpoint.
426pub(crate) fn build_ws_url(base_url: &str, api_key: &str) -> Result<String, RealtimeError> {
427    let mut parsed = url::Url::parse(base_url)?;
428
429    // Convert scheme: http→ws, https→wss
430    let ws_scheme = match parsed.scheme() {
431        "http" | "ws" => "ws",
432        "https" | "wss" => "wss",
433        other => {
434            return Err(RealtimeError::InvalidConfig(format!(
435                "Unsupported URL scheme: {}",
436                other
437            )));
438        }
439    };
440    parsed
441        .set_scheme(ws_scheme)
442        .map_err(|_| RealtimeError::InvalidConfig("Failed to set WS scheme".to_string()))?;
443
444    // Append realtime path
445    {
446        let mut path = parsed.path().to_string();
447        if !path.ends_with('/') {
448            path.push('/');
449        }
450        path.push_str("realtime/v1/websocket");
451        parsed.set_path(&path);
452    }
453
454    // Add query params
455    parsed
456        .query_pairs_mut()
457        .append_pair("apikey", api_key)
458        .append_pair("vsn", "1.0.0");
459
460    Ok(parsed.to_string())
461}
462
// ── Connection Helpers ────────────────────────────────────────────────────────
464
465/// Run the reader loop with auto-reconnect support.
466///
467/// Reads from the WebSocket, handling messages. On disconnect (if not intentional),
468/// attempts reconnection with backoff and rejoins channels on success.
469async fn run_reader_loop(
470    inner: Arc<RealtimeClientInner>,
471    initial_read: WsRead,
472    ws_url: String,
473) {
474    let mut read = initial_read;
475    let mut shutdown_rx = inner.shutdown_tx.subscribe();
476
477    loop {
478        // Read messages until disconnect
479        let disconnected_by_shutdown = read_until_disconnect(&inner, &mut read, &mut shutdown_rx).await;
480
481        if disconnected_by_shutdown || inner.intentional_disconnect.load(Ordering::SeqCst) {
482            break;
483        }
484
485        // Attempt auto-reconnect with backoff
486        match attempt_reconnect(&inner, &ws_url).await {
487            Some(new_read) => {
488                read = new_read;
489                // Spawn a new heartbeat for the new connection
490                spawn_heartbeat(Arc::clone(&inner));
491                // Rejoin channels
492                if let Err(e) = rejoin_channels(&inner).await {
493                    warn!(error = %e, "Failed to rejoin channels after reconnect");
494                }
495            }
496            None => {
497                // All reconnect attempts failed
498                notify_all_channels_closed(&inner).await;
499                break;
500            }
501        }
502    }
503}
504
505/// Read messages from the WebSocket until it disconnects or shutdown is signaled.
506/// Returns `true` if shutdown was requested, `false` if the connection was lost.
507async fn read_until_disconnect(
508    inner: &RealtimeClientInner,
509    read: &mut WsRead,
510    shutdown_rx: &mut broadcast::Receiver<()>,
511) -> bool {
512    loop {
513        // Use futures_util::select instead of tokio::select! for WASM compatibility
514        let recv_fut = transport::recv_message(read);
515        let shutdown_fut = shutdown_rx.recv();
516
517        futures_util::pin_mut!(recv_fut);
518        futures_util::pin_mut!(shutdown_fut);
519
520        match select(recv_fut, shutdown_fut).await {
521            Either::Left((msg, _)) => {
522                match msg {
523                    Some(Ok(WsMessage::Text(text))) => {
524                        handle_message(inner, &text).await;
525                    }
526                    Some(Ok(WsMessage::Close)) => {
527                        debug!("WebSocket closed by server");
528                        inner.connected.store(false, Ordering::SeqCst);
529                        return false;
530                    }
531                    Some(Ok(WsMessage::Ping(data))) => {
532                        let mut ws = inner.ws_write.lock().await;
533                        if let Some(sink) = ws.as_mut() {
534                            let _ = transport::send_pong(sink, data).await;
535                        }
536                    }
537                    Some(Err(e)) => {
538                        warn!(error = %e, "WebSocket read error");
539                        inner.connected.store(false, Ordering::SeqCst);
540                        return false;
541                    }
542                    None => {
543                        debug!("WebSocket stream ended");
544                        inner.connected.store(false, Ordering::SeqCst);
545                        return false;
546                    }
547                }
548            }
549            Either::Right(_) => {
550                debug!("Reader task shutting down");
551                return true;
552            }
553        }
554    }
555}
556
557/// Spawn a heartbeat task that sends periodic heartbeats over the WebSocket.
558fn spawn_heartbeat(inner: Arc<RealtimeClientInner>) {
559    let mut shutdown_rx = inner.shutdown_tx.subscribe();
560    let heartbeat_interval = inner.config.heartbeat_interval;
561    platform::spawn(async move {
562        loop {
563            // Sleep for heartbeat interval, racing against shutdown
564            let sleep_fut = platform::sleep(heartbeat_interval);
565            let shutdown_fut = shutdown_rx.recv();
566            futures_util::pin_mut!(sleep_fut);
567            futures_util::pin_mut!(shutdown_fut);
568
569            match select(sleep_fut, shutdown_fut).await {
570                Either::Left(_) => {
571                    // Time to send heartbeat
572                    if !inner.connected.load(Ordering::SeqCst) {
573                        break;
574                    }
575                    let heartbeat = protocol::build_heartbeat(&inner.ref_counter);
576                    let text = match serde_json::to_string(&heartbeat) {
577                        Ok(t) => t,
578                        Err(_) => continue,
579                    };
580                    let mut ws = inner.ws_write.lock().await;
581                    if let Some(sink) = ws.as_mut() {
582                        if let Err(e) = transport::send_text(sink, text).await {
583                            warn!(error = %e, "Heartbeat send failed");
584                            inner.connected.store(false, Ordering::SeqCst);
585                            break;
586                        }
587                        trace!("Heartbeat sent");
588                    }
589                }
590                Either::Right(_) => {
591                    debug!("Heartbeat task shutting down");
592                    break;
593                }
594            }
595        }
596    });
597}
598
599/// Attempt to reconnect with backoff intervals from config.
600/// Returns the new read half on success, or None if all attempts failed.
601async fn attempt_reconnect(
602    inner: &Arc<RealtimeClientInner>,
603    ws_url: &str,
604) -> Option<WsRead> {
605    let config = &inner.config;
606
607    // Build iterator: configured intervals, then repeat fallback
608    let intervals = config.reconnect.intervals.iter().copied()
609        .chain(std::iter::repeat(config.reconnect.fallback));
610
611    let max_attempts = config.reconnect.intervals.len() + 3;
612
613    for (attempt, delay) in intervals.enumerate().take(max_attempts) {
614        if inner.intentional_disconnect.load(Ordering::SeqCst) {
615            return None;
616        }
617
618        info!(attempt = attempt + 1, delay_ms = delay.as_millis(), "Attempting reconnect");
619        platform::sleep(delay).await;
620
621        if inner.intentional_disconnect.load(Ordering::SeqCst) {
622            return None;
623        }
624
625        match transport::connect_ws(&config, ws_url).await {
626            Ok((write, read)) => {
627                *inner.ws_write.lock().await = Some(write);
628                inner.connected.store(true, Ordering::SeqCst);
629                info!("Reconnected successfully");
630                return Some(read);
631            }
632            Err(e) => {
633                warn!(error = %e, attempt = attempt + 1, "Reconnect attempt failed");
634            }
635        }
636    }
637
638    warn!("All reconnect attempts exhausted");
639    None
640}
641
642/// Rejoin all active channels after a successful reconnect.
643async fn rejoin_channels(inner: &RealtimeClientInner) -> Result<(), RealtimeError> {
644    let channels = inner.channels.read().await;
645    for (topic, channel) in channels.iter() {
646        let state = *channel.inner.state.read().await;
647        if state == ChannelState::Joined || state == ChannelState::Joining {
648            debug!(topic = %topic, "Rejoining channel after reconnect");
649            let join_ref = inner.ref_counter.next();
650            let msg_ref = inner.ref_counter.next();
651            // Use the stored join payload from the original subscribe
652            let join_payload = channel.inner.join_payload.read().await.clone();
653            let phoenix_msg = PhoenixMessage {
654                event: "phx_join".to_string(),
655                topic: topic.clone(),
656                payload: serde_json::to_value(&join_payload).unwrap_or(json!({})),
657                msg_ref: Some(msg_ref),
658                join_ref: Some(join_ref),
659            };
660            let text = serde_json::to_string(&phoenix_msg)
661                .map_err(|e| RealtimeError::ServerError(format!("JSON error: {}", e)))?;
662            let mut ws = inner.ws_write.lock().await;
663            if let Some(sink) = ws.as_mut() {
664                transport::send_text(sink, text).await?;
665            }
666            *channel.inner.state.write().await = ChannelState::Joining;
667        }
668    }
669    Ok(())
670}
671
// ── Message Routing ───────────────────────────────────────────────────────────
673
674async fn handle_message(inner: &RealtimeClientInner, text: &str) {
675    let msg: PhoenixMessage = match serde_json::from_str(text) {
676        Ok(m) => m,
677        Err(e) => {
678            warn!(error = %e, "Failed to parse Phoenix message");
679            return;
680        }
681    };
682
683    trace!(
684        topic = %msg.topic,
685        event = %msg.event,
686        "Received WS message"
687    );
688
689    match msg.event.as_str() {
690        "phx_reply" => handle_phx_reply(inner, msg).await,
691        "postgres_changes" => handle_postgres_changes(inner, msg).await,
692        "broadcast" => handle_broadcast(inner, msg).await,
693        "presence_state" => handle_presence_state(inner, msg).await,
694        "presence_diff" => handle_presence_diff(inner, msg).await,
695        "phx_close" => handle_phx_close(inner, msg).await,
696        "phx_error" => handle_phx_error(inner, msg).await,
697        "system" => handle_system(inner, msg).await,
698        _ => {
699            trace!(event = %msg.event, "Unhandled event type");
700        }
701    }
702}
703
704async fn handle_phx_reply(inner: &RealtimeClientInner, msg: PhoenixMessage) {
705    // Check if this is a reply to a join (ref matches join_ref)
706    if let Some(ref ref_id) = msg.msg_ref {
707        let mut pending = inner.pending_replies.lock().await;
708        if let Some(tx) = pending.remove(ref_id) {
709            let _ = tx.send(msg);
710            return;
711        }
712    }
713    // Check if it's a reply by join_ref
714    if let Some(ref join_ref) = msg.join_ref {
715        let mut pending = inner.pending_replies.lock().await;
716        if let Some(tx) = pending.remove(join_ref) {
717            let _ = tx.send(msg);
718            return;
719        }
720    }
721}
722
723async fn handle_postgres_changes(inner: &RealtimeClientInner, msg: PhoenixMessage) {
724    let channels = inner.channels.read().await;
725    let channel = match channels.get(&msg.topic) {
726        Some(ch) => ch,
727        None => return,
728    };
729
730    // Parse the payload — the actual data is nested under the message
731    let data = &msg.payload;
732
733    // Extract ids from the payload to match with filter_index
734    let ids_val = data.get("ids").and_then(|v| v.as_array());
735
736    // Parse the postgres change payload from the "data" field
737    let change_data = match data.get("data") {
738        Some(d) => d,
739        None => {
740            // Sometimes the payload IS the data directly
741            data
742        }
743    };
744
745    let payload: PostgresChangePayload = match serde_json::from_value(change_data.clone()) {
746        Ok(p) => p,
747        Err(e) => {
748            warn!(error = %e, "Failed to parse postgres change payload");
749            return;
750        }
751    };
752
753    // Resolve server IDs to filter indices
754    let id_map = channel.inner.pg_change_id_map.read().await;
755    let matched_indices: Vec<usize> = match ids_val {
756        Some(ids) => ids
757            .iter()
758            .filter_map(|id| id.as_u64())
759            .filter_map(|server_id| id_map.get(&server_id).copied())
760            .collect(),
761        None => Vec::new(),
762    };
763    drop(id_map);
764
765    // Dispatch to matching bindings
766    let bindings = channel.inner.registry.bindings.read().await;
767    for binding in bindings.iter() {
768        if let Binding::PostgresChanges {
769            filter_index,
770            event,
771            callback,
772        } = binding
773        {
774            // Check if this binding's filter_index matches
775            let matches_id = matched_indices.is_empty() || matched_indices.contains(filter_index);
776
777            // Check event type matches
778            let event_matches = match event {
779                PostgresChangesEvent::All => true,
780                PostgresChangesEvent::Insert => payload.change_type == "INSERT",
781                PostgresChangesEvent::Update => payload.change_type == "UPDATE",
782                PostgresChangesEvent::Delete => payload.change_type == "DELETE",
783            };
784
785            if matches_id && event_matches {
786                callback(payload.clone());
787            }
788        }
789    }
790}
791
792async fn handle_broadcast(inner: &RealtimeClientInner, msg: PhoenixMessage) {
793    let channels = inner.channels.read().await;
794    let channel = match channels.get(&msg.topic) {
795        Some(ch) => ch,
796        None => return,
797    };
798
799    let event = msg
800        .payload
801        .get("event")
802        .and_then(|e| e.as_str())
803        .unwrap_or("");
804    let payload = msg
805        .payload
806        .get("payload")
807        .cloned()
808        .unwrap_or(json!({}));
809
810    let bindings = channel.inner.registry.bindings.read().await;
811    for binding in bindings.iter() {
812        if let Binding::Broadcast {
813            event: bind_event,
814            callback,
815        } = binding
816        {
817            if bind_event == event {
818                callback(payload.clone());
819            }
820        }
821    }
822}
823
824async fn handle_presence_state(inner: &RealtimeClientInner, msg: PhoenixMessage) {
825    let channels = inner.channels.read().await;
826    let channel = match channels.get(&msg.topic) {
827        Some(ch) => ch,
828        None => return,
829    };
830
831    let new_state = presence::apply_state(msg.payload);
832    *channel.inner.presence_state.write().await = new_state.clone();
833
834    // Dispatch sync callbacks
835    let bindings = channel.inner.registry.bindings.read().await;
836    for binding in bindings.iter() {
837        if let Binding::PresenceSync(callback) = binding {
838            callback(&new_state);
839        }
840    }
841}
842
843async fn handle_presence_diff(inner: &RealtimeClientInner, msg: PhoenixMessage) {
844    let channels = inner.channels.read().await;
845    let channel = match channels.get(&msg.topic) {
846        Some(ch) => ch,
847        None => return,
848    };
849
850    let diff: PresenceDiff = match serde_json::from_value(msg.payload) {
851        Ok(d) => d,
852        Err(e) => {
853            warn!(error = %e, "Failed to parse presence diff");
854            return;
855        }
856    };
857
858    let (joins, leaves) = {
859        let mut state = channel.inner.presence_state.write().await;
860        presence::apply_diff(&mut state, diff)
861    };
862
863    let state = channel.inner.presence_state.read().await;
864
865    // Dispatch callbacks
866    let bindings = channel.inner.registry.bindings.read().await;
867    for binding in bindings.iter() {
868        match binding {
869            Binding::PresenceJoin(callback) => {
870                for (key, metas) in &joins {
871                    callback(key.clone(), metas.clone());
872                }
873            }
874            Binding::PresenceLeave(callback) => {
875                for (key, metas) in &leaves {
876                    callback(key.clone(), metas.clone());
877                }
878            }
879            Binding::PresenceSync(callback) => {
880                callback(&state);
881            }
882            _ => {}
883        }
884    }
885}
886
887async fn handle_phx_close(inner: &RealtimeClientInner, msg: PhoenixMessage) {
888    let channels = inner.channels.read().await;
889    if let Some(channel) = channels.get(&msg.topic) {
890        *channel.inner.state.write().await = ChannelState::Closed;
891        let status_cb = channel.inner.registry.status_callback.read().await;
892        if let Some(cb) = status_cb.as_ref() {
893            cb(SubscriptionStatus::Closed, None);
894        }
895    }
896}
897
898async fn handle_phx_error(inner: &RealtimeClientInner, msg: PhoenixMessage) {
899    let channels = inner.channels.read().await;
900    if let Some(channel) = channels.get(&msg.topic) {
901        *channel.inner.state.write().await = ChannelState::Errored;
902        let reason = msg
903            .payload
904            .get("reason")
905            .and_then(|r| r.as_str())
906            .unwrap_or("unknown")
907            .to_string();
908        let status_cb = channel.inner.registry.status_callback.read().await;
909        if let Some(cb) = status_cb.as_ref() {
910            cb(
911                SubscriptionStatus::ChannelError,
912                Some(RealtimeError::ServerError(reason)),
913            );
914        }
915    }
916}
917
918async fn handle_system(_inner: &RealtimeClientInner, msg: PhoenixMessage) {
919    // System messages can include subscription confirmations, extensions info, etc.
920    debug!(
921        topic = %msg.topic,
922        payload = %msg.payload,
923        "System message received"
924    );
925}
926
927async fn notify_all_channels_closed(inner: &RealtimeClientInner) {
928    let channels = inner.channels.read().await;
929    for channel in channels.values() {
930        let current = *channel.inner.state.read().await;
931        if current == ChannelState::Joined || current == ChannelState::Joining {
932            *channel.inner.state.write().await = ChannelState::Closed;
933            let status_cb = channel.inner.registry.status_callback.read().await;
934            if let Some(cb) = status_cb.as_ref() {
935                cb(SubscriptionStatus::Closed, None);
936            }
937        }
938    }
939}
940
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ReconnectConfig;

    /// An `http` base URL maps to a `ws` endpoint with the key and protocol version.
    #[test]
    fn test_build_ws_url_http() {
        let expected = "ws://localhost:54321/realtime/v1/websocket?apikey=test-key&vsn=1.0.0";
        let actual = build_ws_url("http://localhost:54321", "test-key").unwrap();
        assert_eq!(actual, expected);
    }

    /// An `https` base URL maps to a secure `wss` endpoint.
    #[test]
    fn test_build_ws_url_https() {
        let expected = "wss://example.supabase.co/realtime/v1/websocket?apikey=anon-key&vsn=1.0.0";
        let actual = build_ws_url("https://example.supabase.co", "anon-key").unwrap();
        assert_eq!(actual, expected);
    }

    /// A trailing slash on the base URL still produces the websocket path.
    #[test]
    fn test_build_ws_url_with_path() {
        let built = build_ws_url("http://localhost:54321/", "key").unwrap();
        let prefix = "ws://localhost:54321/realtime/v1/websocket";
        assert!(built.starts_with(prefix));
    }

    /// Schemes other than http/https are rejected.
    #[test]
    fn test_build_ws_url_invalid_scheme() {
        assert!(build_ws_url("ftp://localhost", "key").is_err());
    }

    /// `set_auth` on a client that never connected reports an error.
    #[test]
    fn test_set_auth_requires_connection() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let client = RealtimeClient::new("http://localhost:54321", "test-key").unwrap();
        let outcome = runtime.block_on(client.set_auth("new-token"));
        assert!(outcome.is_err());
    }

    /// Headers passed to `with_headers` are kept on the config.
    #[test]
    fn test_custom_headers_stored() {
        let headers: HashMap<String, String> = HashMap::from([(
            "X-Custom-Header".to_string(),
            "custom-value".to_string(),
        )]);
        let config =
            RealtimeConfig::new("http://localhost:54321", "test-key").with_headers(headers);

        assert_eq!(config.headers.len(), 1);
        let stored = config.headers.get("X-Custom-Header").unwrap();
        assert_eq!(stored, "custom-value");
    }

    /// A freshly built config carries no custom headers.
    #[test]
    fn test_custom_headers_default_empty() {
        let config = RealtimeConfig::new("http://localhost:54321", "test-key");
        assert!(config.headers.is_empty());
    }

    /// A new client does not start flagged as intentionally disconnected.
    #[test]
    fn test_intentional_disconnect_flag() {
        let client = RealtimeClient::new("http://localhost:54321", "test-key").unwrap();
        let flagged = client.inner.intentional_disconnect.load(Ordering::SeqCst);
        assert!(!flagged);
    }

    /// The default reconnect schedule has four steps from 1s to 10s and a 10s fallback.
    #[test]
    fn test_reconnect_config_intervals() {
        let config = ReconnectConfig::default();
        let intervals = &config.intervals;
        assert_eq!(intervals.len(), 4);
        assert_eq!(intervals[0], Duration::from_secs(1));
        assert_eq!(intervals[3], Duration::from_secs(10));
        assert_eq!(config.fallback, Duration::from_secs(10));
    }
}