bevy_realtime/
client.rs

1#![allow(dead_code)]
2
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::thread::sleep;
6use std::time::SystemTime;
7use std::{
8    collections::HashMap,
9    io,
10    net::{TcpStream, ToSocketAddrs},
11    time::Duration,
12};
13
14use bevy::ecs::system::SystemId;
15use bevy::log::{debug, info};
16use bevy::prelude::*;
17use bevy_crossbeam_event::CrossbeamEventSender;
18use crossbeam::channel::{unbounded, Receiver, SendError, Sender, TryRecvError};
19use native_tls::TlsConnector;
20use tungstenite::{client, Message};
21use tungstenite::{
22    client::{uri_mode, IntoClientRequest},
23    handshake::MidHandshake,
24    http::{HeaderMap, HeaderValue, Response as HttpResponse, StatusCode, Uri},
25    stream::{MaybeTlsStream, Mode, NoDelay},
26    ClientHandshake, Error as TungsteniteError, HandshakeError, WebSocket as WebSocketWrapper,
27};
28use uuid::Uuid;
29
30use super::channel::{ChannelState, RealtimeChannel};
31use crate::message::payload::Payload;
32use crate::message::realtime_message::RealtimeMessage;
33
34use super::channel::ChannelBuilder;
35
36pub type Response = HttpResponse<Option<Vec<u8>>>;
37pub type WebSocket = WebSocketWrapper<MaybeTlsStream<TcpStream>>;
38
39/// Connection state of [RealtimeClient]
40#[derive(PartialEq, Debug, Default, Clone, Copy, Event)]
41pub enum ConnectionState {
42    /// Client wants to reconnect
43    Reconnect,
44    /// Client is mid-reconnect
45    Reconnecting,
46    Connecting,
47    Open,
48    Closing,
49    #[default]
50    Closed,
51}
52
53/// Error returned by [RealtimeClient::next_message()].
54/// Can be WouldBlock
55#[derive(PartialEq, Debug)]
56pub enum NextMessageError {
57    WouldBlock,
58    TryRecvError(TryRecvError),
59    NoChannel,
60    ChannelClosed,
61    ClientClosed,
62    SocketError(SocketError),
63    MonitorError(MonitorError),
64}
65
66impl Error for NextMessageError {}
67impl Display for NextMessageError {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.write_str(&format!("{:?}", self))
70    }
71}
72
/// Failure reported by the reconnect monitor that runs inside [Client::step].
#[derive(Debug, PartialEq)]
pub enum MonitorError {
    /// A reconnect was attempted, but [Client::connect] failed.
    ReconnectError,
    /// `reconnect_max_attempts` has been reached.
    MaxReconnects,
    /// Nothing to do right now: no pending signal, or the backoff has not elapsed yet.
    WouldBlock,
    /// The internal monitor signal channel is disconnected.
    Disconnected,
}
81
/// Internal socket-level error produced while reading or writing in [Client].
#[derive(PartialEq, Debug)]
pub enum SocketError {
    /// No websocket has been established.
    NoSocket,
    /// The websocket reports that it cannot be read from.
    NoRead,
    /// The websocket reports that it cannot be written to.
    NoWrite,
    /// The connection was closed.
    Disconnected,
    /// Nothing could be done right now (e.g. writes are throttled or a non-text frame arrived).
    WouldBlock,
    /// Handshake retries used up `reconnect_max_attempts`.
    TooManyRetries,
    /// The websocket handshake failed.
    HandshakeError,
}
93
/// Error returned by [Client::connect].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ConnectError {
    /// The websocket URI or the client request could not be built.
    BadUri,
    /// The URI has no host.
    BadHost,
    /// The host/port pair could not be resolved to socket addresses.
    BadAddrs,
    /// Setting up the underlying stream failed (NOTE(review): meaning inferred from the name).
    StreamError,
    /// Enabling `TCP_NODELAY` on the stream failed.
    NoDelayError,
    /// Switching the stream to non-blocking mode failed (NOTE(review): inferred from the name).
    NonblockingError,
    /// A handshake failed (NOTE(review): meaning inferred from the name).
    HandshakeError,
    /// Connection retries used up `reconnect_max_attempts`.
    MaxRetries,
    /// The server answered with a status other than `101 Switching Protocols`.
    WrongProtocol,
}
107
108pub(crate) struct MessageChannel((Sender<RealtimeMessage>, Receiver<RealtimeMessage>));
109
110impl Default for MessageChannel {
111    fn default() -> Self {
112        Self(crossbeam::channel::unbounded())
113    }
114}
115
116struct MonitorChannel((Sender<MonitorSignal>, Receiver<MonitorSignal>));
117
118impl Default for MonitorChannel {
119    fn default() -> Self {
120        Self(crossbeam::channel::unbounded())
121    }
122}
123
/// Signal consumed by the reconnect monitor in [Client::step].
#[derive(PartialEq, Debug)]
enum MonitorSignal {
    /// Ask the monitor to reconnect once its backoff window has elapsed.
    Reconnect,
}
128
129#[derive(Clone)]
130pub struct ClientManager {
131    tx: Sender<ClientManagerMessage>,
132}
133
134pub enum ClientManagerMessage {
135    Channel {
136        callback: SystemId<In<ChannelBuilder>>,
137    },
138    AddChannel {
139        channel: RealtimeChannel,
140    },
141    SetAccessToken {
142        token: String,
143    },
144    ConnectionState {
145        sender: CrossbeamEventSender<ConnectionState>,
146    },
147    Connect {
148        callback: SystemId<In<Result<(), ConnectError>>>,
149    },
150}
151
152impl ClientManager {
153    pub fn new(client: &Client) -> Self {
154        Self {
155            tx: client.manager_tx.clone(),
156        }
157    }
158
159    pub fn connect(
160        &self,
161        callback: SystemId<In<Result<(), ConnectError>>>,
162    ) -> Result<(), SendError<ClientManagerMessage>> {
163        self.tx.send(ClientManagerMessage::Connect { callback })
164    }
165
166    pub fn channel(
167        &self,
168        callback: SystemId<In<ChannelBuilder>>,
169    ) -> Result<(), SendError<ClientManagerMessage>> {
170        self.tx.send(ClientManagerMessage::Channel { callback })
171    }
172
173    pub fn add_channel(
174        &self,
175        channel: RealtimeChannel,
176    ) -> Result<(), SendError<ClientManagerMessage>> {
177        self.tx.send(ClientManagerMessage::AddChannel { channel })
178    }
179
180    pub fn set_access_token(&self, token: String) -> Result<(), SendError<ClientManagerMessage>> {
181        self.tx.send(ClientManagerMessage::SetAccessToken { token })
182    }
183
184    pub fn connection_state(
185        &self,
186        sender: CrossbeamEventSender<ConnectionState>,
187    ) -> Result<(), SendError<ClientManagerMessage>> {
188        self.tx
189            .send(ClientManagerMessage::ConnectionState { sender })
190    }
191}
192
193/// Synchronous websocket client that interfaces with Supabase Realtime
194pub struct Client {
195    pub(crate) access_token: String,
196    connection_state: ConnectionState,
197    socket: Option<WebSocket>,
198    channels: HashMap<Uuid, RealtimeChannel>,
199    messages_this_second: Vec<SystemTime>,
200    next_ref: Uuid,
201    // mpsc
202    pub(crate) outbound_channel: MessageChannel,
203    inbound_channel: MessageChannel,
204    monitor_channel: MonitorChannel,
205    middleware: HashMap<Uuid, Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
206    // timers
207    reconnect_now: Option<SystemTime>,
208    reconnect_delay: Duration,
209    reconnect_attempts: usize,
210    heartbeat_now: Option<SystemTime>,
211    // builder options
212    headers: HeaderMap,
213    params: Option<HashMap<String, String>>,
214    heartbeat_interval: Duration,
215    encode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
216    decode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
217    reconnect_interval: ReconnectFn,
218    reconnect_max_attempts: usize,
219    connection_timeout: Duration,
220    auth_url: Option<String>,
221    endpoint: String,
222    max_events_per_second: usize,
223    // sync bridge
224    manager_rx: Receiver<ClientManagerMessage>,
225    manager_tx: Sender<ClientManagerMessage>,
226    channel_callback_event_sender: CrossbeamEventSender<ChannelCallbackEvent>,
227    connect_result_callback_event_sender: CrossbeamEventSender<ConnectResultCallbackEvent>,
228}
229
230#[derive(Event, Clone)]
231pub struct ChannelCallbackEvent(pub (SystemId<In<ChannelBuilder>>, ChannelBuilder));
232
233#[derive(Event, Clone)]
234pub struct ConnectResultCallbackEvent(
235    pub  (
236        SystemId<In<Result<(), ConnectError>>>,
237        Result<(), ConnectError>,
238    ),
239);
240
241impl Client {
242    pub fn manager_recv(&mut self) -> Result<(), Box<dyn Error>> {
243        while let Ok(message) = self.manager_rx.try_recv() {
244            match message {
245                ClientManagerMessage::Channel { callback } => {
246                    let c = self.channel();
247
248                    debug!("got channel, sending to callback...");
249
250                    self.channel_callback_event_sender
251                        .send(ChannelCallbackEvent((callback, c)));
252                }
253                ClientManagerMessage::AddChannel { channel } => self.add_channel(channel),
254                ClientManagerMessage::SetAccessToken { token } => {
255                    self.access_token = token;
256                }
257                ClientManagerMessage::ConnectionState { sender } => {
258                    sender.send(self.connection_state);
259                }
260                ClientManagerMessage::Connect { callback } => {
261                    let result = self.connect();
262                    self.connect_result_callback_event_sender
263                        .send(ConnectResultCallbackEvent((callback, result)))
264                }
265            }
266        }
267
268        Ok(())
269    }
270    /// Returns a new [RealtimeClientBuilder] with provided `endpoint` and `access_token`
271    pub fn builder(endpoint: impl Into<String>, access_token: impl Into<String>) -> ClientBuilder {
272        ClientBuilder::new(endpoint, access_token)
273    }
274
275    /// Returns this client's [ConnectionState]
276    pub fn get_status(&self) -> ConnectionState {
277        self.connection_state
278    }
279
280    /// Returns a new [RealtimeChannelBuilder] instantiated with the provided `topic`
281    pub fn channel(&mut self) -> ChannelBuilder {
282        ChannelBuilder::new(self)
283    }
284
285    /// Attempt to create a websocket connection with the server
286    pub fn connect(&mut self) -> Result<(), ConnectError> {
287        info!("connecting...");
288        self.connection_state = ConnectionState::Connecting;
289
290        let _ = self.manager_recv();
291
292        let uri: Uri = match format!(
293            "{}/websocket?apikey={}&vsn=1.0.0",
294            self.endpoint, self.access_token
295        )
296        .parse()
297        {
298            Ok(uri) => uri,
299            Err(_e) => return Err(ConnectError::BadUri),
300        };
301
302        // TODO REFAC tidy
303        let ws_scheme = match uri.scheme_str() {
304            Some(scheme) => {
305                if scheme == "http" {
306                    "ws"
307                } else {
308                    "wss"
309                }
310            }
311            None => "ws",
312        };
313
314        let mut add_params = String::new();
315        if let Some(params) = &self.params {
316            for (field, value) in params {
317                add_params = format!("{add_params}&{field}={value}");
318            }
319        }
320
321        let mut p_q = uri.path_and_query().unwrap().to_string();
322
323        if !add_params.is_empty() {
324            p_q = format!("{p_q}{add_params}");
325        }
326
327        let uri = Uri::builder()
328            .scheme(ws_scheme)
329            .authority(uri.authority().unwrap().clone())
330            .path_and_query(p_q)
331            .build()
332            .unwrap();
333
334        let Ok(mut request) = uri.clone().into_client_request() else {
335            return Err(ConnectError::BadUri);
336        };
337
338        let headers = request.headers_mut();
339
340        let auth: HeaderValue = format!("Bearer {}", self.access_token)
341            .parse()
342            .expect("malformed access token?");
343        headers.insert("Authorization", auth);
344
345        // unwrap: shouldn't fail
346        let xci: HeaderValue = "realtime-rs/0.1.0".to_string().parse().unwrap();
347        headers.insert("X-Client-Info", xci);
348
349        debug!("Connecting... Req: {:?}\n", request);
350
351        let uri = request.uri();
352
353        let Ok(mode) = uri_mode(uri) else {
354            return Err(ConnectError::BadUri);
355        };
356
357        let Some(host) = uri.host() else {
358            return Err(ConnectError::BadHost);
359        };
360
361        let port = uri.port_u16().unwrap_or(match mode {
362            Mode::Plain => 80,
363            Mode::Tls => 443,
364        });
365
366        let Ok(mut addrs) = (host, port).to_socket_addrs() else {
367            return Err(ConnectError::BadAddrs);
368        };
369
370        let mut stream = match TcpStream::connect_timeout(
371            &addrs.next().expect("uhoh no addr"),
372            self.connection_timeout,
373        ) {
374            Ok(stream) => {
375                self.reconnect_attempts = 0;
376                stream
377            }
378            Err(_e) => {
379                // TODO err data
380                return self.retry_connect();
381            }
382        };
383
384        let Ok(()) = NoDelay::set_nodelay(&mut stream, true) else {
385            return Err(ConnectError::NoDelayError);
386        };
387
388        let maybe_tls = match mode {
389            Mode::Tls => {
390                let connector = TlsConnector::new().expect("No TLS tings");
391
392                let connected_stream = connector
393                    .connect(host, stream.try_clone().expect("noclone"))
394                    .unwrap();
395
396                stream
397                    .set_nonblocking(true)
398                    .expect("blocking mode oh nooooo");
399
400                MaybeTlsStream::NativeTls(connected_stream)
401            }
402            Mode::Plain => {
403                stream
404                    .set_nonblocking(true)
405                    .expect("blocking mode oh nooooo");
406
407                MaybeTlsStream::Plain(stream)
408            }
409        };
410
411        let conn: Result<(WebSocket, Response), TungsteniteError> = match client(request, maybe_tls)
412        {
413            Ok(stream) => {
414                self.reconnect_attempts = 0;
415                Ok(stream)
416            }
417            Err(err) => match err {
418                HandshakeError::Failure(_err) => {
419                    // TODO err data
420                    return self.retry_connect();
421                }
422                HandshakeError::Interrupted(mid_hs) => match self.retry_handshake(mid_hs) {
423                    Ok(stream) => Ok(stream),
424                    Err(_err) => {
425                        // TODO err data
426                        return self.retry_connect();
427                    }
428                },
429            },
430        };
431
432        let (socket, res) = conn.expect("Handshake fail");
433
434        if res.status() != StatusCode::SWITCHING_PROTOCOLS {
435            return Err(ConnectError::WrongProtocol);
436        }
437
438        self.socket = Some(socket);
439
440        self.connection_state = ConnectionState::Open;
441        info!("connected");
442
443        Ok(())
444    }
445
446    fn retry_connect(&mut self) -> Result<(), ConnectError> {
447        debug!(
448            "Retry count {}/{}",
449            self.reconnect_attempts, self.reconnect_max_attempts
450        );
451        debug!(
452            "Waiting {}s...",
453            self.reconnect_interval.0(self.reconnect_attempts).as_secs()
454        );
455
456        if self.reconnect_attempts < self.reconnect_max_attempts {
457            self.reconnect_attempts += 1;
458            let backoff = &self.reconnect_interval.0;
459            sleep(backoff(self.reconnect_attempts));
460            return self.connect();
461        }
462
463        Err(ConnectError::MaxRetries)
464    }
465
466    /// Disconnect the client
467    pub fn disconnect(&mut self) {
468        if self.connection_state == ConnectionState::Closed {
469            return;
470        }
471
472        self.remove_all_channels();
473
474        self.connection_state = ConnectionState::Closed;
475
476        let Some(ref mut socket) = self.socket else {
477            debug!("Already disconnected. {:?}", self.connection_state);
478            return;
479        };
480
481        let _ = socket.close(None);
482        debug!("Client disconnected. {:?}", self.connection_state);
483    }
484
485    /// Queues a [RealtimeMessage] for sending to the server
486    pub fn send(&mut self, msg: RealtimeMessage) -> Result<(), SendError<RealtimeMessage>> {
487        self.outbound_channel.0 .0.send(msg)
488    }
489
490    /// Returns an optional mutable reference to the [RealtimeChannel] with the provided [Uuid].
491    /// If `channel_id` is not found returns [None]
492    pub fn get_channel_mut(&mut self, channel_id: Uuid) -> Option<&mut RealtimeChannel> {
493        self.channels.get_mut(&channel_id)
494    }
495
496    /// Returns an optional reference to the [RealtimeChannel] with the provided [Uuid].
497    /// If `channel_id` is not found returns [None]
498    pub fn get_channel(&self, channel_id: Uuid) -> Option<&RealtimeChannel> {
499        self.channels.get(&channel_id)
500    }
501
502    /// Returns a reference to this client's HashMap of channels
503    pub fn get_channels(&self) -> &HashMap<Uuid, RealtimeChannel> {
504        &self.channels
505    }
506
507    /// Returns [Some(RealtimeChannel)] if channel was successfully removed, [None] if the channel
508    /// was not found.
509    pub fn remove_channel(&mut self, channel_id: Uuid) -> Option<RealtimeChannel> {
510        if let Some(mut channel) = self.channels.remove(&channel_id) {
511            let _ = channel.unsubscribe();
512
513            if self.channels.is_empty() {
514                self.disconnect();
515            }
516
517            return Some(channel);
518        }
519
520        None
521    }
522
523    /// Blocks the current thread until the channel with the provided `channel_id` has subscribed.
524    pub fn block_until_subscribed(&mut self, channel_id: Uuid) -> Result<Uuid, ChannelState> {
525        // TODO ergonomically this would fit better as a function on RealtimeChannel but borrow
526        // checker
527
528        let channel = self.channels.get_mut(&channel_id);
529
530        let channel = channel.unwrap();
531
532        if channel.connection_state == ChannelState::Joined {
533            return Ok(channel.id);
534        }
535
536        if channel.connection_state != ChannelState::Joining {
537            self.channels
538                .get_mut(&channel_id)
539                .unwrap()
540                .subscribe()
541                .unwrap();
542        }
543
544        loop {
545            match self.step() {
546                Ok(channel_ids) => {
547                    debug!(
548                        "[Blocking Subscribe] Message forwarded to {:?}",
549                        channel_ids
550                    )
551                }
552                Err(NextMessageError::WouldBlock) => {}
553                Err(_e) => {
554                    //println!("NextMessageError: {:?}", e)
555                }
556            }
557
558            let channel = self.channels.get_mut(&channel_id).unwrap();
559
560            match channel.connection_state {
561                ChannelState::Joined => {
562                    break;
563                }
564                ChannelState::Closed => return Err(ChannelState::Closed),
565                _ => {}
566            }
567        }
568
569        Ok(channel_id)
570    }
571
572    /// Use provided JWT to authorize future requests from this client and all channels
573    pub fn set_auth(&mut self, access_token: String) {
574        self.access_token.clone_from(&access_token);
575
576        for channel in self.channels.values_mut() {
577            // TODO single source of data for access token
578            let _ = channel.set_auth(access_token.clone()); // TODO error handling
579        }
580    }
581
582    /// Add a callback to run mutably on recieved [RealtimeMessage]s before any other registered
583    /// callbacks.
584    pub fn add_middleware(
585        &mut self,
586        middleware: Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>,
587    ) -> Uuid {
588        // TODO user defined middleware ordering
589        let uuid = Uuid::new_v4();
590        self.middleware.insert(uuid, middleware);
591        uuid
592    }
593
594    /// Remove middleware by it's [Uuid]
595    pub fn remove_middleware(&mut self, uuid: Uuid) -> &mut Client {
596        self.middleware.remove(&uuid);
597        self
598    }
599
600    /// The main step function for driving the [RealtimeClient]
601    pub fn step(&mut self) -> Result<Vec<Uuid>, NextMessageError> {
602        // TODO run manager_recv fns until channels drained
603        match self.manager_recv() {
604            Ok(()) => {}
605            Err(e) => debug!("client manager_recv error: {}", e),
606        }
607
608        for channel in self.channels.values_mut() {
609            match channel.manager_recv() {
610                Ok(()) => {}
611                Err(e) => debug!("channel manager_recv error: {}", e),
612            }
613        }
614
615        match self.connection_state {
616            ConnectionState::Closed => {
617                return Err(NextMessageError::ClientClosed);
618            }
619            ConnectionState::Reconnect => {
620                let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
621                return Err(NextMessageError::SocketError(SocketError::Disconnected));
622            }
623            ConnectionState::Reconnecting => {
624                return Err(NextMessageError::WouldBlock);
625            }
626            ConnectionState::Connecting => {}
627            ConnectionState::Closing => {}
628            ConnectionState::Open => {}
629        }
630
631        match self.run_monitor() {
632            Ok(_) => {}
633            Err(MonitorError::WouldBlock) => {}
634            Err(MonitorError::MaxReconnects) => {
635                self.disconnect();
636                return Err(NextMessageError::MonitorError(MonitorError::MaxReconnects));
637            }
638            Err(e) => {
639                return Err(NextMessageError::MonitorError(e));
640            }
641        }
642
643        self.run_heartbeat();
644
645        match self.write_socket() {
646            Ok(()) => {}
647            Err(SocketError::WouldBlock) => {}
648            Err(e) => {
649                self.reconnect();
650                return Err(NextMessageError::SocketError(e));
651            }
652        }
653
654        match self.read_socket() {
655            Ok(()) => {}
656            Err(e) => {
657                self.reconnect();
658                return Err(NextMessageError::SocketError(e));
659            }
660        }
661
662        match self.inbound_channel.0 .1.try_recv() {
663            Ok(mut message) => {
664                let mut ids = vec![];
665                // TODO filter & route system messages and the like
666
667                // Run middleware
668                message = self.run_middleware(message);
669
670                // Send message to channel
671                for (id, channel) in &mut self.channels {
672                    if channel.topic == message.topic {
673                        channel.recieve(message.clone());
674                        ids.push(*id);
675                    }
676                }
677
678                Ok(ids)
679            }
680            Err(TryRecvError::Empty) => Err(NextMessageError::WouldBlock),
681            Err(e) => Err(NextMessageError::TryRecvError(e)),
682        }
683    }
684
685    pub(crate) fn add_channel(&mut self, channel: RealtimeChannel) {
686        self.channels.insert(channel.id, channel);
687    }
688
689    pub(crate) fn get_channel_tx(&self) -> Sender<RealtimeMessage> {
690        self.outbound_channel.0 .0.clone()
691    }
692
693    fn run_middleware(&self, mut message: RealtimeMessage) -> RealtimeMessage {
694        for middleware in self.middleware.values() {
695            message = middleware(message)
696        }
697        message
698    }
699
700    fn remove_all_channels(&mut self) {
701        if self.connection_state == ConnectionState::Closing
702            || self.connection_state == ConnectionState::Closed
703        {
704            return;
705        }
706
707        self.connection_state = ConnectionState::Closing;
708
709        // wait until inbound_rx is drained
710        loop {
711            let recv = self.step();
712
713            if Err(NextMessageError::WouldBlock) == recv {
714                break;
715            }
716        }
717
718        loop {
719            let _ = self.step();
720
721            let mut all_channels_closed = true;
722
723            for channel in self.channels.values_mut() {
724                let channel_state = channel.unsubscribe();
725
726                match channel_state {
727                    Ok(state) => {
728                        if state != ChannelState::Closed {
729                            all_channels_closed = false;
730                        }
731                    }
732                    Err(e) => {
733                        // TODO error handling
734                        debug!("Unsubscribe error: {:?}", e);
735                    }
736                }
737            }
738
739            if all_channels_closed {
740                debug!("All channels closed!");
741                break;
742            }
743        }
744
745        self.channels.clear();
746    }
747
748    fn retry_handshake(
749        &mut self,
750        mid_hs: MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>,
751    ) -> Result<(WebSocket, Response), SocketError> {
752        match mid_hs.handshake() {
753            Ok(stream) => Ok(stream),
754            Err(e) => match e {
755                HandshakeError::Interrupted(mid_hs) => {
756                    // TODO sleeping main thread bad
757                    if self.reconnect_attempts < self.reconnect_max_attempts {
758                        self.reconnect_attempts += 1;
759                        let backoff = &self.reconnect_interval.0;
760                        sleep(backoff(self.reconnect_attempts));
761                        return self.retry_handshake(mid_hs);
762                    }
763
764                    Err(SocketError::TooManyRetries)
765                }
766                HandshakeError::Failure(_err) => {
767                    // TODO pass error data
768                    Err(SocketError::HandshakeError)
769                }
770            },
771        }
772    }
773
774    fn run_monitor(&mut self) -> Result<(), MonitorError> {
775        if self.reconnect_now.is_none() {
776            self.reconnect_now = Some(SystemTime::now());
777
778            self.reconnect_delay = self.reconnect_interval.0(self.reconnect_attempts);
779        }
780
781        match self.monitor_channel.0 .1.try_recv() {
782            Ok(signal) => match signal {
783                MonitorSignal::Reconnect => {
784                    if self.connection_state == ConnectionState::Open
785                        || self.connection_state == ConnectionState::Reconnecting
786                        || SystemTime::now() < self.reconnect_now.unwrap() + self.reconnect_delay
787                    {
788                        return Err(MonitorError::WouldBlock);
789                    }
790
791                    if self.reconnect_attempts >= self.reconnect_max_attempts {
792                        return Err(MonitorError::MaxReconnects);
793                    }
794
795                    self.connection_state = ConnectionState::Reconnecting;
796                    self.reconnect_attempts += 1;
797                    self.reconnect_now.take();
798
799                    match self.connect() {
800                        Ok(_) => {
801                            for channel in self.channels.values_mut() {
802                                channel.subscribe().unwrap();
803                            }
804
805                            Ok(())
806                        }
807                        Err(e) => {
808                            debug!("reconnect error: {:?}", e);
809                            self.connection_state = ConnectionState::Reconnect;
810                            Err(MonitorError::ReconnectError)
811                        }
812                    }
813                }
814            },
815            Err(TryRecvError::Empty) => Err(MonitorError::WouldBlock),
816            Err(TryRecvError::Disconnected) => Err(MonitorError::Disconnected),
817        }
818    }
819
820    fn run_heartbeat(&mut self) {
821        if self.heartbeat_now.is_none() {
822            self.heartbeat_now = Some(SystemTime::now());
823        }
824
825        if self.heartbeat_now.unwrap() + self.heartbeat_interval > SystemTime::now() {
826            return;
827        }
828
829        self.heartbeat_now.take();
830
831        let _ = self.send(RealtimeMessage::heartbeat());
832    }
833
834    fn read_socket(&mut self) -> Result<(), SocketError> {
835        let Some(ref mut socket) = self.socket else {
836            return Err(SocketError::NoSocket);
837        };
838
839        if !socket.can_read() {
840            return Err(SocketError::NoRead);
841        }
842
843        match socket.read() {
844            Ok(raw_message) => match raw_message {
845                Message::Text(string_message) => {
846                    // TODO recoverable error
847                    let mut message: RealtimeMessage =
848                        serde_json::from_str(&string_message).expect("Deserialization error: ");
849
850                    debug!("[RECV] {:?}", message);
851
852                    if let Some(decode) = &self.decode {
853                        message = decode(message);
854                    }
855
856                    if let Payload::Empty {} = message.payload {
857                        debug!("Possibly malformed payload: {:?}", string_message)
858                    }
859
860                    let _ = self.inbound_channel.0 .0.send(message);
861                    Ok(())
862                }
863                Message::Close(_close_message) => {
864                    self.disconnect();
865                    Err(SocketError::Disconnected)
866                }
867                _ => {
868                    // do nothing on ping, pong, binary messages
869                    Err(SocketError::WouldBlock)
870                }
871            },
872            Err(TungsteniteError::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
873                // do nothing here :)
874                Ok(())
875            }
876            Err(err) => {
877                debug!("Socket read error: {:?}", err);
878                self.connection_state = ConnectionState::Reconnect;
879                let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
880                Err(SocketError::WouldBlock)
881            }
882        }
883    }
884
885    fn write_socket(&mut self) -> Result<(), SocketError> {
886        let Some(ref mut socket) = self.socket else {
887            return Err(SocketError::NoSocket);
888        };
889
890        if !socket.can_write() {
891            return Err(SocketError::NoWrite);
892        }
893
894        // Throttling
895        let now = SystemTime::now();
896
897        self.messages_this_second = self
898            .messages_this_second
899            .clone() // TODO do i need this clone? can i mutate in-place?
900            .into_iter()
901            .filter(|st| now.duration_since(*st).unwrap_or_default() < Duration::from_secs(1))
902            .collect();
903
904        if self.messages_this_second.len() >= self.max_events_per_second {
905            return Err(SocketError::WouldBlock);
906        }
907
908        // Send to server
909        // TODO should drain outbound_channel
910        // TODO drain should respect throttling
911        let message = self.outbound_channel.0 .1.try_recv();
912
913        // So should be a while try recv loop
914
915        match message {
916            Ok(mut message) => {
917                if message.message_ref.is_none() {
918                    message.message_ref = Some(self.next_ref.into());
919                    self.next_ref = Uuid::new_v4();
920                }
921
922                if let Some(encode) = &self.encode {
923                    message = encode(message);
924                }
925
926                let raw = serde_json::to_string(&message);
927                debug!("[SEND] {:?}", raw);
928
929                let _ = socket.send(message.into());
930                self.messages_this_second.push(now);
931                Ok(())
932            }
933            Err(TryRecvError::Empty) => {
934                // do nothing
935                Ok(())
936            }
937            Err(e) => {
938                debug!("outbound error: {:?}", e);
939                self.connection_state = ConnectionState::Reconnect;
940                let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
941                Err(SocketError::WouldBlock)
942            }
943        }
944    }
945
946    fn reconnect(&mut self) {
947        self.connection_state = ConnectionState::Reconnect;
948        let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
949    }
950}
951
/// Reconnect-delay strategy for the client.
///
/// Wraps a thread-safe `Box<dyn Fn(usize) -> Duration + Send + Sync>`. The function is given
/// a count of reconnect attempts and returns the [Duration] to wait before the next attempt
/// is made. Build one with [ReconnectFn::new], or use [Default] for the stepped backoff.
pub struct ReconnectFn(pub Box<dyn Fn(usize) -> Duration + Send + Sync>);
956
957impl ReconnectFn {
958    pub fn new(f: impl Fn(usize) -> Duration + 'static + Sync + Send) -> Self {
959        Self(Box::new(f))
960    }
961}
962
963impl Default for ReconnectFn {
964    fn default() -> Self {
965        Self(Box::new(backoff))
966    }
967}
968
969impl Debug for ReconnectFn {
970    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
971        f.write_str("TODO reconnect fn debug")
972    }
973}
974
/// Builder for the realtime [Client].
///
/// Create one with [ClientBuilder::new], adjust settings with the `&mut Self` setters, then
/// consume it with [ClientBuilder::build].
pub struct ClientBuilder {
    /// Headers sent with the connection; seeded with `X-Client-Info` by [ClientBuilder::new].
    headers: HeaderMap,
    /// Optional endpoint URL params; `None` unless set via [ClientBuilder::params].
    params: Option<HashMap<String, String>>,
    /// Time between heartbeat packets (default 29 seconds).
    heartbeat_interval: Duration,
    /// Optional hook handed to the client as its `encode` function for outgoing messages.
    encode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
    /// Optional hook handed to the client as its `decode` function.
    decode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
    /// Maps a count of reconnect attempts to the delay before the next attempt.
    reconnect_interval: ReconnectFn,
    /// Maximum number of reconnect attempts (default `usize::MAX`).
    reconnect_max_attempts: usize,
    /// How long to wait for a connection to succeed (default 10 s; the setter enforces 1 s minimum).
    connection_timeout: Duration,
    /// Optional auth server base URL; `None` unless set via [ClientBuilder::auth_url].
    auth_url: Option<String>,
    /// Realtime server endpoint URL.
    endpoint: String,
    /// Access token supplied to [ClientBuilder::new].
    access_token: String,
    /// Maximum number of messages sent per second (default 10).
    max_events_per_second: usize,
}
990
991impl ClientBuilder {
992    pub fn new(endpoint: impl Into<String>, access_token: impl Into<String>) -> Self {
993        let mut headers = HeaderMap::new();
994        headers.insert("X-Client-Info", "realtime-rs/0.1.0".parse().unwrap());
995
996        Self {
997            headers,
998            params: Default::default(),
999            heartbeat_interval: Duration::from_secs(29),
1000            encode: Default::default(),
1001            decode: Default::default(),
1002            reconnect_interval: ReconnectFn(Box::new(backoff)),
1003            reconnect_max_attempts: usize::MAX,
1004            connection_timeout: Duration::from_secs(10),
1005            auth_url: Default::default(),
1006            endpoint: endpoint.into(),
1007            access_token: access_token.into(),
1008            max_events_per_second: 10,
1009        }
1010    }
1011
1012    /// Sets the client headers. Headers always contain "X-Client-Info: realtime-rs/{version}".
1013    pub fn set_headers(&mut self, set_headers: HeaderMap) -> &mut Self {
1014        let mut headers = HeaderMap::new();
1015        headers.insert("X-Client-Info", "realtime-rs/0.1.0".parse().unwrap());
1016        headers.extend(set_headers);
1017
1018        self.headers = headers;
1019
1020        self
1021    }
1022
1023    /// Merges provided [HeaderMap] with currently held headers
1024    pub fn add_headers(&mut self, headers: HeaderMap) -> &mut Self {
1025        self.headers.extend(headers);
1026        self
1027    }
1028
1029    /// Set endpoint URL params
1030    pub fn params(&mut self, params: HashMap<String, String>) -> &mut Self {
1031        self.params = Some(params);
1032        self
1033    }
1034
1035    /// Set [Duration] between heartbeat packets. Default 29 seconds.
1036    pub fn heartbeat_interval(&mut self, heartbeat_interval: Duration) -> &mut Self {
1037        self.heartbeat_interval = heartbeat_interval;
1038        self
1039    }
1040
1041    /// Set the function to provide time between reconnection attempts
1042    /// The provided function should take a count of reconnect attempts and return a [Duration] to wait
1043    /// until the next attempt is made.
1044    ///
1045    /// Don't implement an untested timing function here in prod or you might make a few too many
1046    /// requests.
1047    ///
1048    /// Defaults to stepped backoff
1049    pub fn reconnect_interval(&mut self, reconnect_interval: ReconnectFn) -> &mut Self {
1050        // TODO minimum interval to prevent 10000000 requests in seconds
1051        // then again it takes a bit of work to make that mistake?
1052        self.reconnect_interval = reconnect_interval;
1053        self
1054    }
1055
1056    /// Configure the number of recconect attempts to be made before erroring
1057    pub fn reconnect_max_attempts(&mut self, max_attempts: usize) -> &mut Self {
1058        self.reconnect_max_attempts = max_attempts;
1059        self
1060    }
1061
1062    /// Configure the duration to wait for a connection to succeed.
1063    /// Default: 10 seconds
1064    /// Minimum: 1 second
1065    pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
1066        // 1 sec min timeout
1067        let timeout = if timeout < Duration::from_secs(1) {
1068            Duration::from_secs(1)
1069        } else {
1070            timeout
1071        };
1072
1073        self.connection_timeout = timeout;
1074        self
1075    }
1076
1077    /// Set the base URL for the auth server
1078    /// In live supabase deployments this is the same as the endpoint URL, and defaults as such.
1079    /// In local deployments this may need to be set manually
1080    pub fn auth_url(&mut self, auth_url: impl Into<String>) -> &mut Self {
1081        self.auth_url = Some(auth_url.into());
1082        self
1083    }
1084
1085    /// Sets the max messages we can send in a second.
1086    /// Default: 10
1087    pub fn max_events_per_second(&mut self, count: usize) -> &mut Self {
1088        self.max_events_per_second = count;
1089        self
1090    }
1091
1092    pub fn encode(
1093        &mut self,
1094        encode: impl Fn(RealtimeMessage) -> RealtimeMessage + 'static + Send + Sync,
1095    ) -> &mut Self {
1096        self.encode = Some(Box::new(encode));
1097        self
1098    }
1099
1100    pub fn decode(
1101        &mut self,
1102        decode: impl Fn(RealtimeMessage) -> RealtimeMessage + 'static + Send + Sync,
1103    ) -> &mut Self {
1104        self.decode = Some(Box::new(decode));
1105        self
1106    }
1107
1108    /// Consume the [Self] and return a configured [RealtimeClient]
1109    pub fn build(
1110        self,
1111        channel_callback_event_sender: CrossbeamEventSender<ChannelCallbackEvent>,
1112        connect_result_callback_event_sender: CrossbeamEventSender<ConnectResultCallbackEvent>,
1113    ) -> Client {
1114        let (manager_tx, manager_rx) = unbounded();
1115        Client {
1116            headers: self.headers,
1117            params: self.params,
1118            heartbeat_interval: self.heartbeat_interval,
1119            encode: self.encode,
1120            decode: self.decode,
1121            reconnect_interval: self.reconnect_interval,
1122            reconnect_max_attempts: self.reconnect_max_attempts,
1123            connection_timeout: self.connection_timeout,
1124            auth_url: self.auth_url,
1125            endpoint: self.endpoint,
1126            access_token: self.access_token,
1127            max_events_per_second: self.max_events_per_second,
1128            next_ref: Uuid::new_v4(),
1129            connection_state: Default::default(),
1130            socket: Default::default(),
1131            channels: Default::default(),
1132            messages_this_second: Default::default(),
1133            outbound_channel: Default::default(),
1134            inbound_channel: Default::default(),
1135            monitor_channel: Default::default(),
1136            middleware: Default::default(),
1137            reconnect_now: Default::default(),
1138            reconnect_delay: Default::default(),
1139            reconnect_attempts: Default::default(),
1140            heartbeat_now: Default::default(),
1141            manager_rx,
1142            manager_tx,
1143            channel_callback_event_sender,
1144            connect_result_callback_event_sender,
1145        }
1146    }
1147}
1148
/// Default reconnect schedule: waits 0, 1, 2, 5 and then 10 seconds.
///
/// `attempts` indexes into the schedule. Any count past the last step keeps returning the
/// final (10 second) delay, so large counts never panic.
fn backoff(attempts: usize) -> Duration {
    // A const array avoids the heap allocation a `vec!` would make on every call.
    const STEPS_SECS: [u64; 5] = [0, 1, 2, 5, 10];

    let step = attempts.min(STEPS_SECS.len() - 1);
    Duration::from_secs(STEPS_SECS[step])
}