1use crate::core::SupabaseClient;
2use crate::error::{Result, SupaError};
3use futures_util::{SinkExt, StreamExt};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_json::{json, Value};
6
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::time::sleep;
12use tokio_stream::Stream;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Clone)]
20pub struct RealtimeClient {
21 pub(crate) client: SupabaseClient,
22}
23
24impl RealtimeClient {
25 pub(crate) fn new(client: SupabaseClient) -> Self {
26 Self { client }
27 }
28
29 pub fn channel(&self, topic: &str) -> RealtimeChannelBuilder {
31 RealtimeChannelBuilder::new(self.client.clone(), topic)
32 }
33}
34
/// Which row-level change a postgres-changes listener subscribes to.
///
/// The wire form (`"INSERT"`, `"UPDATE"`, `"DELETE"`, `"*"`) comes from
/// [`PostgresEvent::as_str`] / the `Display` impl and is sent as the `event`
/// field of the listener config.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PostgresEvent {
    /// Row inserts (`"INSERT"`).
    Insert,
    /// Row updates (`"UPDATE"`).
    Update,
    /// Row deletions (`"DELETE"`).
    Delete,
    /// Every change type (`"*"`).
    All,
}

impl PostgresEvent {
    /// The protocol string for this event, without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            PostgresEvent::Insert => "INSERT",
            PostgresEvent::Update => "UPDATE",
            PostgresEvent::Delete => "DELETE",
            PostgresEvent::All => "*",
        }
    }
}

// Implementing `Display` rather than `ToString` directly: `event.to_string()`
// keeps working through the std blanket impl, and `format!("{}", event)` now works too.
impl std::fmt::Display for PostgresEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
57
/// Lifecycle states of a realtime connection.
///
/// NOTE(review): no code in this file constructs or reads this type, and the
/// variant meanings below are inferred from their names. Confirm against
/// callers before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// A connection attempt is in progress (inferred from the name).
    Connecting,
    /// The socket is open (inferred from the name).
    Connected,
    /// Waiting to retry after a disconnect (inferred from the name).
    Reconnecting,
    /// No further connection attempts will be made (inferred from the name).
    Closed,
}
70
/// Commands queued by a [`RealtimeChannel`] handle for the background
/// connection task, which turns each one into an outgoing Phoenix message.
enum ChannelCommand {
    /// Sent as a `"broadcast"` event whose payload carries `event` and `payload`.
    Broadcast {
        event: String,
        payload: Value,
    },
    /// Sent as a `"presence"` event with type `"track"` carrying `payload`.
    Track {
        payload: Value,
    },
    /// Sent as a `"presence"` event with type `"untrack"`.
    Untrack,
    /// Sends `phx_leave`; `connect_and_listen` then returns `Ok(())`.
    Close,
}
84
85pub struct RealtimeChannel {
93 topic: String,
94 rx: mpsc::UnboundedReceiver<Result<RealtimeMessage>>,
95 cmd_tx: mpsc::UnboundedSender<ChannelCommand>,
96}
97
98impl Stream for RealtimeChannel {
99 type Item = Result<RealtimeMessage>;
100
101 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
102 self.rx.poll_recv(cx)
103 }
104}
105
106impl RealtimeChannel {
107 pub fn topic(&self) -> &str {
109 &self.topic
110 }
111
112 pub fn broadcast(&self, event: &str, payload: Value) -> Result<()> {
114 self.cmd_tx
115 .send(ChannelCommand::Broadcast {
116 event: event.to_string(),
117 payload,
118 })
119 .map_err(|_| SupaError::RealtimeError {
120 message: "Channel closed".to_string(),
121 })
122 }
123
124 pub fn track(&self, payload: Value) -> Result<()> {
126 self.cmd_tx
127 .send(ChannelCommand::Track { payload })
128 .map_err(|_| SupaError::RealtimeError {
129 message: "Channel closed".to_string(),
130 })
131 }
132
133 pub fn untrack(&self) -> Result<()> {
135 self.cmd_tx
136 .send(ChannelCommand::Untrack)
137 .map_err(|_| SupaError::RealtimeError {
138 message: "Channel closed".to_string(),
139 })
140 }
141
142 pub fn close(&self) -> Result<()> {
147 self.cmd_tx
148 .send(ChannelCommand::Close)
149 .map_err(|_| SupaError::RealtimeError {
150 message: "Channel already closed".to_string(),
151 })
152 }
153}
154
155pub struct RealtimeChannelBuilder {
160 client: SupabaseClient,
161 topic: String,
162 postgres_changes: Vec<Value>,
163}
164
165impl RealtimeChannelBuilder {
166 pub fn new(client: SupabaseClient, topic: &str) -> Self {
167 Self {
168 client,
169 topic: topic.to_string(),
170 postgres_changes: Vec::new(),
171 }
172 }
173
174 pub fn on_postgres_changes<S1, S2, S3>(
176 mut self,
177 event: PostgresEvent,
178 schema: S1,
179 table: Option<S2>,
180 filter: Option<S3>,
181 ) -> Self
182 where
183 S1: Into<String>,
184 S2: Into<String>,
185 S3: Into<String>,
186 {
187 let mut config = json!({
188 "event": event.to_string(),
189 "schema": schema.into(),
190 });
191
192 if let Some(t) = table {
193 config
194 .as_object_mut()
195 .unwrap()
196 .insert("table".to_string(), json!(t.into()));
197 }
198 if let Some(f) = filter {
199 config
200 .as_object_mut()
201 .unwrap()
202 .insert("filter".to_string(), json!(f.into()));
203 }
204
205 self.postgres_changes.push(config);
206 self
207 }
208
209 pub async fn subscribe(self) -> Result<RealtimeChannel> {
212 let (tx, rx) = mpsc::unbounded_channel();
213 let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
214
215 let client = self.client.clone();
216 let topic = self.topic.clone();
217
218 let mut postgres_changes_config = Vec::new();
220 for cfg in &self.postgres_changes {
221 postgres_changes_config.push(json!({
222 "event": cfg["event"],
223 "schema": cfg["schema"],
224 "table": cfg.get("table"),
225 "filter": cfg.get("filter")
226 }));
227 }
228
229 let mut config = json!({});
230 if !postgres_changes_config.is_empty() {
231 config.as_object_mut().unwrap().insert(
232 "postgres_changes".to_string(),
233 json!(postgres_changes_config),
234 );
235 }
236
237 config.as_object_mut().unwrap().insert(
238 "broadcast".to_string(),
239 json!({ "ack": false, "self": false }),
240 );
241 config
242 .as_object_mut()
243 .unwrap()
244 .insert("presence".to_string(), json!({ "key": "" }));
245
246 let config_clone = config.clone();
247
248 tokio::spawn(async move {
249 let mut retry_count = 0;
250 let base_delay = client.inner.config.retry_base_delay_ms;
251
252 loop {
253 match connect_and_listen(&client, &topic, &config_clone, &tx, &mut cmd_rx).await {
256 Ok(_) => {
257 retry_count = 0;
258 }
259 Err(e) => {
260 let _ = tx.send(Err(SupaError::RealtimeError {
261 message: format!("Realtime disconnected: {}. Reconnecting...", e),
262 }));
263 }
264 }
265
266 retry_count += 1;
267 let delay = base_delay * 2u64.pow(retry_count.min(9) as u32);
268 sleep(Duration::from_millis(delay)).await;
269 }
270 });
271
272 Ok(RealtimeChannel {
273 topic: self.topic,
274 rx,
275 cmd_tx,
276 })
277 }
278}
279
280async fn connect_and_listen(
285 client: &SupabaseClient,
286 topic: &str,
287 config: &Value,
288 tx: &mpsc::UnboundedSender<Result<RealtimeMessage>>,
289 user_cmd_rx: &mut mpsc::UnboundedReceiver<ChannelCommand>,
290) -> Result<()> {
291 let url = client.inner.url.clone();
293 let scheme = match url.scheme() {
294 "https" => "wss",
295 "http" => "ws",
296 _ => "wss",
297 };
298 let host = url.host_str().unwrap_or_default();
299 let port = url.port_or_known_default().unwrap_or(443);
300
301 let ws_url = format!(
302 "{}://{}:{}/realtime/v1/websocket?apikey={}&vsn=1.0.0",
303 scheme, host, port, client.inner.key
304 );
305
306 let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url)
307 .await
308 .map_err(|e| SupaError::RealtimeError {
309 message: format!("Connection failed: {}", e),
310 })?;
311
312 let (mut write, mut read) = ws_stream.split();
313 let (internal_cmd_tx, mut internal_cmd_rx) = mpsc::channel::<Message>(10);
314
315 let writer_handle = tokio::spawn(async move {
317 while let Some(msg) = internal_cmd_rx.recv().await {
318 if let Err(_) = write.send(msg).await {
319 break;
320 }
321 }
322 });
323
324 let join_ref = format!("{}", rand::random::<u64>());
326 let access_token = {
327 let lock = client.inner.session.read().unwrap();
328 lock.as_ref()
333 .map(|s| s.access_token.clone())
334 .unwrap_or_else(|| client.inner.key.clone())
335 };
336
337 let join_msg = json!({
338 "topic": topic,
339 "event": "phx_join",
340 "payload": {
341 "config": config,
342 "access_token": access_token
343 },
344 "ref": join_ref
345 });
346
347 internal_cmd_tx
348 .send(Message::Text(join_msg.to_string()))
349 .await
350 .map_err(|e| SupaError::RealtimeError {
351 message: format!("Failed to send join: {}", e),
352 })?;
353
354 let hb_cmd_tx = internal_cmd_tx.clone();
356 let hb_handle = tokio::spawn(async move {
357 loop {
358 sleep(Duration::from_secs(30)).await;
359 let msg = json!({
360 "topic": "phoenix",
361 "event": "heartbeat",
362 "payload": {},
363 "ref": format!("{}", rand::random::<u64>())
364 });
365 if hb_cmd_tx
366 .send(Message::Text(msg.to_string()))
367 .await
368 .is_err()
369 {
370 break;
371 }
372 }
373 });
374
375 loop {
377 tokio::select! {
378 msg_res = read.next() => {
380 match msg_res {
381 Some(Ok(msg)) => {
382 match msg {
383 Message::Text(text) => {
384 if let Ok(parsed) = serde_json::from_str::<RealtimeMessage>(&text) {
385 if parsed.event == "phx_reply" {
386 continue;
388 }
389 if parsed.event == "phx_close" {
390 break;
392 }
393 if parsed.event == "phx_error" {
394 break;
396 }
397 if tx.send(Ok(parsed)).is_err() {
398 break;
399 }
400 }
401 }
402 Message::Close(_) => break,
403 _ => {}
404 }
405 }
406 Some(Err(_)) => break, None => break, }
409 }
410
411 cmd = user_cmd_rx.recv() => {
413 match cmd {
414 Some(ChannelCommand::Broadcast { event, payload }) => {
415 let msg = json!({
416 "topic": topic,
417 "event": "broadcast",
418 "payload": {
419 "event": event,
420 "payload": payload
421 },
422 "ref": format!("{}", rand::random::<u64>())
423 });
424 if internal_cmd_tx.send(Message::Text(msg.to_string())).await.is_err() {
425 break;
426 }
427 }
428 Some(ChannelCommand::Track { payload }) => {
429 let msg = json!({
430 "topic": topic,
431 "event": "presence",
432 "payload": {
433 "type": "track",
434 "event": "track",
435 "payload": payload
436 },
437 "ref": format!("{}", rand::random::<u64>())
438 });
439 if internal_cmd_tx.send(Message::Text(msg.to_string())).await.is_err() {
440 break;
441 }
442 }
443 Some(ChannelCommand::Untrack) => {
444 let msg = json!({
445 "topic": topic,
446 "event": "presence",
447 "payload": {
448 "type": "untrack",
449 "event": "untrack"
450 },
451 "ref": format!("{}", rand::random::<u64>())
452 });
453 if internal_cmd_tx.send(Message::Text(msg.to_string())).await.is_err() {
454 break;
455 }
456 }
457 Some(ChannelCommand::Close) => {
458 let leave_msg = json!({
460 "topic": topic,
461 "event": "phx_leave",
462 "payload": {},
463 "ref": format!("{}", rand::random::<u64>())
464 });
465 let _ = internal_cmd_tx.send(Message::Text(leave_msg.to_string())).await;
466 return Ok(());
468 }
469 None => break }
471 }
472 }
473 }
474
475 hb_handle.abort();
477 writer_handle.abort();
478
479 Err(SupaError::RealtimeError {
480 message: "Connection ended".into(),
481 })
482}
483
484#[derive(Debug, Serialize, Deserialize)]
489pub struct RealtimeMessage {
490 pub topic: String,
491 pub event: String,
492 pub payload: Value,
493 #[serde(rename = "ref")]
494 pub ref_: Option<String>,
495}
496
497impl RealtimeMessage {
498 pub fn is_postgres_change(&self) -> bool {
500 self.event == "postgres_changes"
501 || self.event == "INSERT"
502 || self.event == "UPDATE"
503 || self.event == "DELETE"
504 }
505
506 pub fn is_presence(&self) -> bool {
508 self.event == "presence_state" || self.event == "presence_diff"
509 }
510
511 pub fn is_broadcast(&self) -> bool {
513 self.event == "broadcast"
514 }
515
516 pub fn as_insert<T: DeserializeOwned>(&self) -> Result<T> {
518 self.extract_record("INSERT")
519 }
520
521 pub fn as_update<T: DeserializeOwned>(&self) -> Result<T> {
522 self.extract_record("UPDATE")
523 }
524
525 pub fn as_delete<T: DeserializeOwned>(&self) -> Result<T> {
526 self.extract_record("DELETE")
527 }
528
529 fn extract_record<T: DeserializeOwned>(&self, expected_type: &str) -> Result<T> {
531 let type_ = self
535 .payload
536 .get("type")
537 .and_then(|v| v.as_str())
538 .unwrap_or_default();
539
540 if !type_.is_empty() && type_ != expected_type {
543 return Err(SupaError::RealtimeError {
544 message: format!("Expected type {}, got {}", expected_type, type_),
545 });
546 }
547
548 let record_key = if expected_type == "DELETE" {
549 "old_record"
550 } else {
551 "record"
552 };
553 let record = self.payload.get(record_key);
554
555 match record {
556 Some(val) if !val.is_null() => {
557 serde_json::from_value(val.clone()).map_err(|e| SupaError::RealtimeError {
558 message: format!("Deserialization failed: {}", e),
559 })
560 }
561 _ => {
562 let fallback = self
564 .payload
565 .get("record")
566 .or_else(|| self.payload.get("old_record"));
567 if let Some(val) = fallback {
568 if !val.is_null() {
569 return serde_json::from_value(val.clone()).map_err(|e| {
570 SupaError::RealtimeError {
571 message: format!("Deserialization failed (fallback): {}", e),
572 }
573 });
574 }
575 }
576 Err(SupaError::RealtimeError {
577 message: format!("No {} found in payload", record_key),
578 })
579 }
580 }
581 }
582}