// pixeluvw_supabase/realtime.rs

1use crate::core::SupabaseClient;
2use crate::error::{Result, SupaError};
3use futures_util::{SinkExt, StreamExt};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_json::{json, Value};
6
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::time::sleep;
12use tokio_stream::Stream;
13use tokio_tungstenite::tungstenite::Message;
14
15// ============================================================================
16//  RealtimeClient
17// ============================================================================
18
19#[derive(Clone)]
20pub struct RealtimeClient {
21    pub(crate) client: SupabaseClient,
22}
23
24impl RealtimeClient {
25    pub(crate) fn new(client: SupabaseClient) -> Self {
26        Self { client }
27    }
28
29    /// Create a channel builder to configure a new subscription.
30    pub fn channel(&self, topic: &str) -> RealtimeChannelBuilder {
31        RealtimeChannelBuilder::new(self.client.clone(), topic)
32    }
33}
34
35// ============================================================================
36//  Realtime Types
37// ============================================================================
38
/// Kind of Postgres row change a channel listens for via
/// [`RealtimeChannelBuilder::on_postgres_changes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PostgresEvent {
    /// Row insertions (wire name `INSERT`).
    Insert,
    /// Row updates (wire name `UPDATE`).
    Update,
    /// Row deletions (wire name `DELETE`).
    Delete,
    /// Every change kind (wire name `*`).
    All,
}

impl PostgresEvent {
    /// Wire name of the event, as sent in the `postgres_changes` join config.
    pub fn as_str(&self) -> &'static str {
        match self {
            PostgresEvent::Insert => "INSERT",
            PostgresEvent::Update => "UPDATE",
            PostgresEvent::Delete => "DELETE",
            PostgresEvent::All => "*",
        }
    }
}

// Implementing `Display` (instead of `ToString` directly) still provides
// `to_string()` through the standard blanket impl, and additionally makes the
// type usable with `{}` formatting.
impl std::fmt::Display for PostgresEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
57
/// Lifecycle stage of a realtime channel's connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// The channel is establishing its connection to the server.
    Connecting,
    /// The channel is connected and messages are flowing.
    Connected,
    /// The connection dropped and the channel is trying to re-establish it.
    Reconnecting,
    /// The channel has been shut down.
    Closed,
}
70
71/// Commands sent from the user handle to the connection loop.
72enum ChannelCommand {
73    Broadcast {
74        event: String,
75        payload: Value,
76    },
77    Track {
78        payload: Value,
79    },
80    Untrack,
81    /// Close the channel gracefully.
82    Close,
83}
84
85// ============================================================================
86//  RealtimeChannel (Handle)
87// ============================================================================
88
89/// A handle to a subscribed Realtime channel.
90///
91/// Implements `Stream` to receive messages, and provides methods to send broadcasts or track presence.
92pub struct RealtimeChannel {
93    topic: String,
94    rx: mpsc::UnboundedReceiver<Result<RealtimeMessage>>,
95    cmd_tx: mpsc::UnboundedSender<ChannelCommand>,
96}
97
98impl Stream for RealtimeChannel {
99    type Item = Result<RealtimeMessage>;
100
101    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
102        self.rx.poll_recv(cx)
103    }
104}
105
106impl RealtimeChannel {
107    /// Get the topic name of this channel.
108    pub fn topic(&self) -> &str {
109        &self.topic
110    }
111
112    /// Send a broadcast message to other clients in the channel.
113    pub fn broadcast(&self, event: &str, payload: Value) -> Result<()> {
114        self.cmd_tx
115            .send(ChannelCommand::Broadcast {
116                event: event.to_string(),
117                payload,
118            })
119            .map_err(|_| SupaError::RealtimeError {
120                message: "Channel closed".to_string(),
121            })
122    }
123
124    /// Track user presence.
125    pub fn track(&self, payload: Value) -> Result<()> {
126        self.cmd_tx
127            .send(ChannelCommand::Track { payload })
128            .map_err(|_| SupaError::RealtimeError {
129                message: "Channel closed".to_string(),
130            })
131    }
132
133    /// Untrack user presence.
134    pub fn untrack(&self) -> Result<()> {
135        self.cmd_tx
136            .send(ChannelCommand::Untrack)
137            .map_err(|_| SupaError::RealtimeError {
138                message: "Channel closed".to_string(),
139            })
140    }
141
142    /// Close the channel gracefully.
143    ///
144    /// This will disconnect the WebSocket connection and stop receiving messages.
145    /// The channel cannot be reused after closing.
146    pub fn close(&self) -> Result<()> {
147        self.cmd_tx
148            .send(ChannelCommand::Close)
149            .map_err(|_| SupaError::RealtimeError {
150                message: "Channel already closed".to_string(),
151            })
152    }
153}
154
155// ============================================================================
156//  RealtimeChannelBuilder
157// ============================================================================
158
159pub struct RealtimeChannelBuilder {
160    client: SupabaseClient,
161    topic: String,
162    postgres_changes: Vec<Value>,
163}
164
165impl RealtimeChannelBuilder {
166    pub fn new(client: SupabaseClient, topic: &str) -> Self {
167        Self {
168            client,
169            topic: topic.to_string(),
170            postgres_changes: Vec::new(),
171        }
172    }
173
174    /// Listen to Postgres Changes
175    pub fn on_postgres_changes<S1, S2, S3>(
176        mut self,
177        event: PostgresEvent,
178        schema: S1,
179        table: Option<S2>,
180        filter: Option<S3>,
181    ) -> Self
182    where
183        S1: Into<String>,
184        S2: Into<String>,
185        S3: Into<String>,
186    {
187        let mut config = json!({
188            "event": event.to_string(),
189            "schema": schema.into(),
190        });
191
192        if let Some(t) = table {
193            config
194                .as_object_mut()
195                .unwrap()
196                .insert("table".to_string(), json!(t.into()));
197        }
198        if let Some(f) = filter {
199            config
200                .as_object_mut()
201                .unwrap()
202                .insert("filter".to_string(), json!(f.into()));
203        }
204
205        self.postgres_changes.push(config);
206        self
207    }
208
209    /// Subscribe to the channel.
210    /// Returns a `RealtimeChannel` handle.
211    pub async fn subscribe(self) -> Result<RealtimeChannel> {
212        let (tx, rx) = mpsc::unbounded_channel();
213        let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
214
215        let client = self.client.clone();
216        let topic = self.topic.clone();
217
218        // Build the config payload for phx_join
219        let mut postgres_changes_config = Vec::new();
220        for cfg in &self.postgres_changes {
221            postgres_changes_config.push(json!({
222                "event": cfg["event"],
223                "schema": cfg["schema"],
224                "table": cfg.get("table"),
225                "filter": cfg.get("filter")
226            }));
227        }
228
229        let mut config = json!({});
230        if !postgres_changes_config.is_empty() {
231            config.as_object_mut().unwrap().insert(
232                "postgres_changes".to_string(),
233                json!(postgres_changes_config),
234            );
235        }
236
237        config.as_object_mut().unwrap().insert(
238            "broadcast".to_string(),
239            json!({ "ack": false, "self": false }),
240        );
241        config
242            .as_object_mut()
243            .unwrap()
244            .insert("presence".to_string(), json!({ "key": "" }));
245
246        let config_clone = config.clone();
247
248        tokio::spawn(async move {
249            let mut retry_count = 0;
250            let base_delay = client.inner.config.retry_base_delay_ms;
251
252            loop {
253                // Connection Loop
254                // We pass &mut cmd_rx to the connection function
255                match connect_and_listen(&client, &topic, &config_clone, &tx, &mut cmd_rx).await {
256                    Ok(_) => {
257                        retry_count = 0;
258                    }
259                    Err(e) => {
260                        let _ = tx.send(Err(SupaError::RealtimeError {
261                            message: format!("Realtime disconnected: {}. Reconnecting...", e),
262                        }));
263                    }
264                }
265
266                retry_count += 1;
267                let delay = base_delay * 2u64.pow(retry_count.min(9) as u32);
268                sleep(Duration::from_millis(delay)).await;
269            }
270        });
271
272        Ok(RealtimeChannel {
273            topic: self.topic,
274            rx,
275            cmd_tx,
276        })
277    }
278}
279
280// ============================================================================
281//  Internal Connection Logic
282// ============================================================================
283
284async fn connect_and_listen(
285    client: &SupabaseClient,
286    topic: &str,
287    config: &Value,
288    tx: &mpsc::UnboundedSender<Result<RealtimeMessage>>,
289    user_cmd_rx: &mut mpsc::UnboundedReceiver<ChannelCommand>,
290) -> Result<()> {
291    // 1. WebSocket Handshake
292    let url = client.inner.url.clone();
293    let scheme = match url.scheme() {
294        "https" => "wss",
295        "http" => "ws",
296        _ => "wss",
297    };
298    let host = url.host_str().unwrap_or_default();
299    let port = url.port_or_known_default().unwrap_or(443);
300
301    let ws_url = format!(
302        "{}://{}:{}/realtime/v1/websocket?apikey={}&vsn=1.0.0",
303        scheme, host, port, client.inner.key
304    );
305
306    let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url)
307        .await
308        .map_err(|e| SupaError::RealtimeError {
309            message: format!("Connection failed: {}", e),
310        })?;
311
312    let (mut write, mut read) = ws_stream.split();
313    let (internal_cmd_tx, mut internal_cmd_rx) = mpsc::channel::<Message>(10);
314
315    // 2. Writer Task (Proxies messages from internal loop to WebSocket)
316    let writer_handle = tokio::spawn(async move {
317        while let Some(msg) = internal_cmd_rx.recv().await {
318            if let Err(_) = write.send(msg).await {
319                break;
320            }
321        }
322    });
323
324    // 3. Join Channel (phx_join)
325    let join_ref = format!("{}", rand::random::<u64>());
326    let access_token = {
327        let lock = client.inner.session.read().unwrap();
328        // Fallback to anon key if no session (Realtime usually uses public key or user token, logic varies)
329        // If we want RLS, we need user token. If public, anon key.
330        // Let's use get_access_token (internal clone) logic? No it's async and we are in async context but wrapped.
331        // Just read from session or fallback to key.
332        lock.as_ref()
333            .map(|s| s.access_token.clone())
334            .unwrap_or_else(|| client.inner.key.clone())
335    };
336
337    let join_msg = json!({
338        "topic": topic,
339        "event": "phx_join",
340        "payload": {
341            "config": config,
342            "access_token": access_token
343        },
344        "ref": join_ref
345    });
346
347    internal_cmd_tx
348        .send(Message::Text(join_msg.to_string()))
349        .await
350        .map_err(|e| SupaError::RealtimeError {
351            message: format!("Failed to send join: {}", e),
352        })?;
353
354    // 4. Heartbeat Task
355    let hb_cmd_tx = internal_cmd_tx.clone();
356    let hb_handle = tokio::spawn(async move {
357        loop {
358            sleep(Duration::from_secs(30)).await;
359            let msg = json!({
360                "topic": "phoenix",
361                "event": "heartbeat",
362                "payload": {},
363                "ref": format!("{}", rand::random::<u64>())
364            });
365            if hb_cmd_tx
366                .send(Message::Text(msg.to_string()))
367                .await
368                .is_err()
369            {
370                break;
371            }
372        }
373    });
374
375    // 5. Main Select Loop
376    loop {
377        tokio::select! {
378            // Incoming WebSocket Message
379            msg_res = read.next() => {
380                match msg_res {
381                    Some(Ok(msg)) => {
382                        match msg {
383                            Message::Text(text) => {
384                                if let Ok(parsed) = serde_json::from_str::<RealtimeMessage>(&text) {
385                                    if parsed.event == "phx_reply" {
386                                        // Heartbeat reply or join reply, ignore for now
387                                        continue;
388                                    }
389                                    if parsed.event == "phx_close" {
390                                        // Server closed channel
391                                        break;
392                                    }
393                                    if parsed.event == "phx_error" {
394                                         // Error
395                                         break;
396                                    }
397                                    if tx.send(Ok(parsed)).is_err() {
398                                        break;
399                                    }
400                                }
401                            }
402                            Message::Close(_) => break,
403                            _ => {}
404                        }
405                    }
406                    Some(Err(_)) => break, // WS Error
407                    None => break, // Stream ended
408                }
409            }
410
411            // Outgoing User Command (Broadcast / Presence)
412            cmd = user_cmd_rx.recv() => {
413                match cmd {
414                    Some(ChannelCommand::Broadcast { event, payload }) => {
415                        let msg = json!({
416                            "topic": topic,
417                            "event": "broadcast",
418                            "payload": {
419                                "event": event,
420                                "payload": payload
421                            },
422                            "ref": format!("{}", rand::random::<u64>())
423                        });
424                        if internal_cmd_tx.send(Message::Text(msg.to_string())).await.is_err() {
425                             break;
426                        }
427                    }
428                    Some(ChannelCommand::Track { payload }) => {
429                         let msg = json!({
430                            "topic": topic,
431                            "event": "presence",
432                            "payload": {
433                                "type": "track",
434                                "event": "track",
435                                "payload": payload
436                            },
437                             "ref": format!("{}", rand::random::<u64>())
438                        });
439                        if internal_cmd_tx.send(Message::Text(msg.to_string())).await.is_err() {
440                             break;
441                        }
442                    }
443                     Some(ChannelCommand::Untrack) => {
444                         let msg = json!({
445                            "topic": topic,
446                            "event": "presence",
447                            "payload": {
448                                "type": "untrack",
449                                "event": "untrack"
450                            },
451                             "ref": format!("{}", rand::random::<u64>())
452                        });
453                        if internal_cmd_tx.send(Message::Text(msg.to_string())).await.is_err() {
454                             break;
455                        }
456                    }
457                    Some(ChannelCommand::Close) => {
458                        // Send phx_leave to cleanly disconnect
459                        let leave_msg = json!({
460                            "topic": topic,
461                            "event": "phx_leave",
462                            "payload": {},
463                            "ref": format!("{}", rand::random::<u64>())
464                        });
465                        let _ = internal_cmd_tx.send(Message::Text(leave_msg.to_string())).await;
466                        // Return Ok to signal intentional close (not error)
467                        return Ok(());
468                    }
469                    None => break // User dropped channel handle
470                }
471            }
472        }
473    }
474
475    // Cleanup
476    hb_handle.abort();
477    writer_handle.abort();
478
479    Err(SupaError::RealtimeError {
480        message: "Connection ended".into(),
481    })
482}
483
484// ============================================================================
485//  Message Types
486// ============================================================================
487
488#[derive(Debug, Serialize, Deserialize)]
489pub struct RealtimeMessage {
490    pub topic: String,
491    pub event: String,
492    pub payload: Value,
493    #[serde(rename = "ref")]
494    pub ref_: Option<String>,
495}
496
497impl RealtimeMessage {
498    /// Check if this is a Postgres Change event
499    pub fn is_postgres_change(&self) -> bool {
500        self.event == "postgres_changes"
501            || self.event == "INSERT"
502            || self.event == "UPDATE"
503            || self.event == "DELETE"
504    }
505
506    /// Check if this is a Presence event
507    pub fn is_presence(&self) -> bool {
508        self.event == "presence_state" || self.event == "presence_diff"
509    }
510
511    /// Check if this is a Broadcast event
512    pub fn is_broadcast(&self) -> bool {
513        self.event == "broadcast"
514    }
515
516    /// Parse as Postgres change record
517    pub fn as_insert<T: DeserializeOwned>(&self) -> Result<T> {
518        self.extract_record("INSERT")
519    }
520
521    pub fn as_update<T: DeserializeOwned>(&self) -> Result<T> {
522        self.extract_record("UPDATE")
523    }
524
525    pub fn as_delete<T: DeserializeOwned>(&self) -> Result<T> {
526        self.extract_record("DELETE")
527    }
528
529    // Helper to extract record from payload, handling Supabase's wrapper structure
530    fn extract_record<T: DeserializeOwned>(&self, expected_type: &str) -> Result<T> {
531        // Payload for postgres_changes usually looks like:
532        // { "type": "INSERT", "table": "users", "schema": "public", "record": { ... }, "old_record": null }
533
534        let type_ = self
535            .payload
536            .get("type")
537            .and_then(|v| v.as_str())
538            .unwrap_or_default();
539
540        // Note: Sometimes strict checking fails if event name differs from internal type.
541        // We match loosely if type is empty or matches.
542        if !type_.is_empty() && type_ != expected_type {
543            return Err(SupaError::RealtimeError {
544                message: format!("Expected type {}, got {}", expected_type, type_),
545            });
546        }
547
548        let record_key = if expected_type == "DELETE" {
549            "old_record"
550        } else {
551            "record"
552        };
553        let record = self.payload.get(record_key);
554
555        match record {
556            Some(val) if !val.is_null() => {
557                serde_json::from_value(val.clone()).map_err(|e| SupaError::RealtimeError {
558                    message: format!("Deserialization failed: {}", e),
559                })
560            }
561            _ => {
562                // Fallback: maybe it's the other key
563                let fallback = self
564                    .payload
565                    .get("record")
566                    .or_else(|| self.payload.get("old_record"));
567                if let Some(val) = fallback {
568                    if !val.is_null() {
569                        return serde_json::from_value(val.clone()).map_err(|e| {
570                            SupaError::RealtimeError {
571                                message: format!("Deserialization failed (fallback): {}", e),
572                            }
573                        });
574                    }
575                }
576                Err(SupaError::RealtimeError {
577                    message: format!("No {} found in payload", record_key),
578                })
579            }
580        }
581    }
582}