1#![allow(dead_code)]
2
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::thread::sleep;
6use std::time::SystemTime;
7use std::{
8 collections::HashMap,
9 io,
10 net::{TcpStream, ToSocketAddrs},
11 time::Duration,
12};
13
14use bevy::ecs::system::SystemId;
15use bevy::log::{debug, info};
16use bevy::prelude::*;
17use bevy_crossbeam_event::CrossbeamEventSender;
18use crossbeam::channel::{unbounded, Receiver, SendError, Sender, TryRecvError};
19use native_tls::TlsConnector;
20use tungstenite::{client, Message};
21use tungstenite::{
22 client::{uri_mode, IntoClientRequest},
23 handshake::MidHandshake,
24 http::{HeaderMap, HeaderValue, Response as HttpResponse, StatusCode, Uri},
25 stream::{MaybeTlsStream, Mode, NoDelay},
26 ClientHandshake, Error as TungsteniteError, HandshakeError, WebSocket as WebSocketWrapper,
27};
28use uuid::Uuid;
29
30use super::channel::{ChannelState, RealtimeChannel};
31use crate::message::payload::Payload;
32use crate::message::realtime_message::RealtimeMessage;
33
34use super::channel::ChannelBuilder;
35
36pub type Response = HttpResponse<Option<Vec<u8>>>;
37pub type WebSocket = WebSocketWrapper<MaybeTlsStream<TcpStream>>;
38
39#[derive(PartialEq, Debug, Default, Clone, Copy, Event)]
41pub enum ConnectionState {
42 Reconnect,
44 Reconnecting,
46 Connecting,
47 Open,
48 Closing,
49 #[default]
50 Closed,
51}
52
53#[derive(PartialEq, Debug)]
56pub enum NextMessageError {
57 WouldBlock,
58 TryRecvError(TryRecvError),
59 NoChannel,
60 ChannelClosed,
61 ClientClosed,
62 SocketError(SocketError),
63 MonitorError(MonitorError),
64}
65
66impl Error for NextMessageError {}
67impl Display for NextMessageError {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.write_str(&format!("{:?}", self))
70 }
71}
72
/// Outcomes of the reconnect monitor (`Client::run_monitor`) other than success.
#[derive(Debug, PartialEq)]
pub enum MonitorError {
    /// A reconnect was attempted and `connect` failed.
    ReconnectError,
    /// `reconnect_attempts` has reached `reconnect_max_attempts`.
    MaxReconnects,
    /// No reconnect was performed on this call (no signal, or it was deferred).
    WouldBlock,
    /// The monitor signal channel is disconnected.
    Disconnected,
}
81
/// Websocket-level failures surfaced by the client's read/write/handshake helpers.
#[derive(Debug, PartialEq)]
pub enum SocketError {
    /// No socket has been established.
    NoSocket,
    /// The socket reports it can no longer be read from.
    NoRead,
    /// The socket reports it can no longer be written to.
    NoWrite,
    /// The connection was closed (or is waiting to be re-established).
    Disconnected,
    /// Nothing could be done right now; retry on a later step.
    WouldBlock,
    /// The websocket handshake was interrupted more often than the retry budget allows.
    TooManyRetries,
    /// The websocket handshake failed.
    HandshakeError,
}
93
/// Reasons `Client::connect` can fail.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ConnectError {
    /// The endpoint/token could not be turned into a valid websocket URI or request.
    BadUri,
    /// The websocket URI has no host.
    BadHost,
    /// The host could not be resolved to a socket address.
    BadAddrs,
    /// Creating or configuring the underlying stream failed.
    StreamError,
    /// Enabling `TCP_NODELAY` on the stream failed.
    NoDelayError,
    /// Switching the stream to non-blocking mode failed.
    NonblockingError,
    /// The TLS or websocket handshake failed.
    HandshakeError,
    /// The retry budget (`reconnect_max_attempts`) was exhausted.
    MaxRetries,
    /// The server did not answer with `101 Switching Protocols`.
    WrongProtocol,
}
107
108pub(crate) struct MessageChannel((Sender<RealtimeMessage>, Receiver<RealtimeMessage>));
109
110impl Default for MessageChannel {
111 fn default() -> Self {
112 Self(crossbeam::channel::unbounded())
113 }
114}
115
116struct MonitorChannel((Sender<MonitorSignal>, Receiver<MonitorSignal>));
117
118impl Default for MonitorChannel {
119 fn default() -> Self {
120 Self(crossbeam::channel::unbounded())
121 }
122}
123
/// Requests understood by the reconnect monitor.
#[derive(Debug, PartialEq)]
enum MonitorSignal {
    /// Ask the monitor to re-establish the connection.
    Reconnect,
}
128
129#[derive(Clone)]
130pub struct ClientManager {
131 tx: Sender<ClientManagerMessage>,
132}
133
134pub enum ClientManagerMessage {
135 Channel {
136 callback: SystemId<In<ChannelBuilder>>,
137 },
138 AddChannel {
139 channel: RealtimeChannel,
140 },
141 SetAccessToken {
142 token: String,
143 },
144 ConnectionState {
145 sender: CrossbeamEventSender<ConnectionState>,
146 },
147 Connect {
148 callback: SystemId<In<Result<(), ConnectError>>>,
149 },
150}
151
152impl ClientManager {
153 pub fn new(client: &Client) -> Self {
154 Self {
155 tx: client.manager_tx.clone(),
156 }
157 }
158
159 pub fn connect(
160 &self,
161 callback: SystemId<In<Result<(), ConnectError>>>,
162 ) -> Result<(), SendError<ClientManagerMessage>> {
163 self.tx.send(ClientManagerMessage::Connect { callback })
164 }
165
166 pub fn channel(
167 &self,
168 callback: SystemId<In<ChannelBuilder>>,
169 ) -> Result<(), SendError<ClientManagerMessage>> {
170 self.tx.send(ClientManagerMessage::Channel { callback })
171 }
172
173 pub fn add_channel(
174 &self,
175 channel: RealtimeChannel,
176 ) -> Result<(), SendError<ClientManagerMessage>> {
177 self.tx.send(ClientManagerMessage::AddChannel { channel })
178 }
179
180 pub fn set_access_token(&self, token: String) -> Result<(), SendError<ClientManagerMessage>> {
181 self.tx.send(ClientManagerMessage::SetAccessToken { token })
182 }
183
184 pub fn connection_state(
185 &self,
186 sender: CrossbeamEventSender<ConnectionState>,
187 ) -> Result<(), SendError<ClientManagerMessage>> {
188 self.tx
189 .send(ClientManagerMessage::ConnectionState { sender })
190 }
191}
192
193pub struct Client {
195 pub(crate) access_token: String,
196 connection_state: ConnectionState,
197 socket: Option<WebSocket>,
198 channels: HashMap<Uuid, RealtimeChannel>,
199 messages_this_second: Vec<SystemTime>,
200 next_ref: Uuid,
201 pub(crate) outbound_channel: MessageChannel,
203 inbound_channel: MessageChannel,
204 monitor_channel: MonitorChannel,
205 middleware: HashMap<Uuid, Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
206 reconnect_now: Option<SystemTime>,
208 reconnect_delay: Duration,
209 reconnect_attempts: usize,
210 heartbeat_now: Option<SystemTime>,
211 headers: HeaderMap,
213 params: Option<HashMap<String, String>>,
214 heartbeat_interval: Duration,
215 encode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
216 decode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
217 reconnect_interval: ReconnectFn,
218 reconnect_max_attempts: usize,
219 connection_timeout: Duration,
220 auth_url: Option<String>,
221 endpoint: String,
222 max_events_per_second: usize,
223 manager_rx: Receiver<ClientManagerMessage>,
225 manager_tx: Sender<ClientManagerMessage>,
226 channel_callback_event_sender: CrossbeamEventSender<ChannelCallbackEvent>,
227 connect_result_callback_event_sender: CrossbeamEventSender<ConnectResultCallbackEvent>,
228}
229
230#[derive(Event, Clone)]
231pub struct ChannelCallbackEvent(pub (SystemId<In<ChannelBuilder>>, ChannelBuilder));
232
233#[derive(Event, Clone)]
234pub struct ConnectResultCallbackEvent(
235 pub (
236 SystemId<In<Result<(), ConnectError>>>,
237 Result<(), ConnectError>,
238 ),
239);
240
241impl Client {
242 pub fn manager_recv(&mut self) -> Result<(), Box<dyn Error>> {
243 while let Ok(message) = self.manager_rx.try_recv() {
244 match message {
245 ClientManagerMessage::Channel { callback } => {
246 let c = self.channel();
247
248 debug!("got channel, sending to callback...");
249
250 self.channel_callback_event_sender
251 .send(ChannelCallbackEvent((callback, c)));
252 }
253 ClientManagerMessage::AddChannel { channel } => self.add_channel(channel),
254 ClientManagerMessage::SetAccessToken { token } => {
255 self.access_token = token;
256 }
257 ClientManagerMessage::ConnectionState { sender } => {
258 sender.send(self.connection_state);
259 }
260 ClientManagerMessage::Connect { callback } => {
261 let result = self.connect();
262 self.connect_result_callback_event_sender
263 .send(ConnectResultCallbackEvent((callback, result)))
264 }
265 }
266 }
267
268 Ok(())
269 }
270 pub fn builder(endpoint: impl Into<String>, access_token: impl Into<String>) -> ClientBuilder {
272 ClientBuilder::new(endpoint, access_token)
273 }
274
275 pub fn get_status(&self) -> ConnectionState {
277 self.connection_state
278 }
279
280 pub fn channel(&mut self) -> ChannelBuilder {
282 ChannelBuilder::new(self)
283 }
284
285 pub fn connect(&mut self) -> Result<(), ConnectError> {
287 info!("connecting...");
288 self.connection_state = ConnectionState::Connecting;
289
290 let _ = self.manager_recv();
291
292 let uri: Uri = match format!(
293 "{}/websocket?apikey={}&vsn=1.0.0",
294 self.endpoint, self.access_token
295 )
296 .parse()
297 {
298 Ok(uri) => uri,
299 Err(_e) => return Err(ConnectError::BadUri),
300 };
301
302 let ws_scheme = match uri.scheme_str() {
304 Some(scheme) => {
305 if scheme == "http" {
306 "ws"
307 } else {
308 "wss"
309 }
310 }
311 None => "ws",
312 };
313
314 let mut add_params = String::new();
315 if let Some(params) = &self.params {
316 for (field, value) in params {
317 add_params = format!("{add_params}&{field}={value}");
318 }
319 }
320
321 let mut p_q = uri.path_and_query().unwrap().to_string();
322
323 if !add_params.is_empty() {
324 p_q = format!("{p_q}{add_params}");
325 }
326
327 let uri = Uri::builder()
328 .scheme(ws_scheme)
329 .authority(uri.authority().unwrap().clone())
330 .path_and_query(p_q)
331 .build()
332 .unwrap();
333
334 let Ok(mut request) = uri.clone().into_client_request() else {
335 return Err(ConnectError::BadUri);
336 };
337
338 let headers = request.headers_mut();
339
340 let auth: HeaderValue = format!("Bearer {}", self.access_token)
341 .parse()
342 .expect("malformed access token?");
343 headers.insert("Authorization", auth);
344
345 let xci: HeaderValue = "realtime-rs/0.1.0".to_string().parse().unwrap();
347 headers.insert("X-Client-Info", xci);
348
349 debug!("Connecting... Req: {:?}\n", request);
350
351 let uri = request.uri();
352
353 let Ok(mode) = uri_mode(uri) else {
354 return Err(ConnectError::BadUri);
355 };
356
357 let Some(host) = uri.host() else {
358 return Err(ConnectError::BadHost);
359 };
360
361 let port = uri.port_u16().unwrap_or(match mode {
362 Mode::Plain => 80,
363 Mode::Tls => 443,
364 });
365
366 let Ok(mut addrs) = (host, port).to_socket_addrs() else {
367 return Err(ConnectError::BadAddrs);
368 };
369
370 let mut stream = match TcpStream::connect_timeout(
371 &addrs.next().expect("uhoh no addr"),
372 self.connection_timeout,
373 ) {
374 Ok(stream) => {
375 self.reconnect_attempts = 0;
376 stream
377 }
378 Err(_e) => {
379 return self.retry_connect();
381 }
382 };
383
384 let Ok(()) = NoDelay::set_nodelay(&mut stream, true) else {
385 return Err(ConnectError::NoDelayError);
386 };
387
388 let maybe_tls = match mode {
389 Mode::Tls => {
390 let connector = TlsConnector::new().expect("No TLS tings");
391
392 let connected_stream = connector
393 .connect(host, stream.try_clone().expect("noclone"))
394 .unwrap();
395
396 stream
397 .set_nonblocking(true)
398 .expect("blocking mode oh nooooo");
399
400 MaybeTlsStream::NativeTls(connected_stream)
401 }
402 Mode::Plain => {
403 stream
404 .set_nonblocking(true)
405 .expect("blocking mode oh nooooo");
406
407 MaybeTlsStream::Plain(stream)
408 }
409 };
410
411 let conn: Result<(WebSocket, Response), TungsteniteError> = match client(request, maybe_tls)
412 {
413 Ok(stream) => {
414 self.reconnect_attempts = 0;
415 Ok(stream)
416 }
417 Err(err) => match err {
418 HandshakeError::Failure(_err) => {
419 return self.retry_connect();
421 }
422 HandshakeError::Interrupted(mid_hs) => match self.retry_handshake(mid_hs) {
423 Ok(stream) => Ok(stream),
424 Err(_err) => {
425 return self.retry_connect();
427 }
428 },
429 },
430 };
431
432 let (socket, res) = conn.expect("Handshake fail");
433
434 if res.status() != StatusCode::SWITCHING_PROTOCOLS {
435 return Err(ConnectError::WrongProtocol);
436 }
437
438 self.socket = Some(socket);
439
440 self.connection_state = ConnectionState::Open;
441 info!("connected");
442
443 Ok(())
444 }
445
446 fn retry_connect(&mut self) -> Result<(), ConnectError> {
447 debug!(
448 "Retry count {}/{}",
449 self.reconnect_attempts, self.reconnect_max_attempts
450 );
451 debug!(
452 "Waiting {}s...",
453 self.reconnect_interval.0(self.reconnect_attempts).as_secs()
454 );
455
456 if self.reconnect_attempts < self.reconnect_max_attempts {
457 self.reconnect_attempts += 1;
458 let backoff = &self.reconnect_interval.0;
459 sleep(backoff(self.reconnect_attempts));
460 return self.connect();
461 }
462
463 Err(ConnectError::MaxRetries)
464 }
465
466 pub fn disconnect(&mut self) {
468 if self.connection_state == ConnectionState::Closed {
469 return;
470 }
471
472 self.remove_all_channels();
473
474 self.connection_state = ConnectionState::Closed;
475
476 let Some(ref mut socket) = self.socket else {
477 debug!("Already disconnected. {:?}", self.connection_state);
478 return;
479 };
480
481 let _ = socket.close(None);
482 debug!("Client disconnected. {:?}", self.connection_state);
483 }
484
485 pub fn send(&mut self, msg: RealtimeMessage) -> Result<(), SendError<RealtimeMessage>> {
487 self.outbound_channel.0 .0.send(msg)
488 }
489
490 pub fn get_channel_mut(&mut self, channel_id: Uuid) -> Option<&mut RealtimeChannel> {
493 self.channels.get_mut(&channel_id)
494 }
495
496 pub fn get_channel(&self, channel_id: Uuid) -> Option<&RealtimeChannel> {
499 self.channels.get(&channel_id)
500 }
501
502 pub fn get_channels(&self) -> &HashMap<Uuid, RealtimeChannel> {
504 &self.channels
505 }
506
507 pub fn remove_channel(&mut self, channel_id: Uuid) -> Option<RealtimeChannel> {
510 if let Some(mut channel) = self.channels.remove(&channel_id) {
511 let _ = channel.unsubscribe();
512
513 if self.channels.is_empty() {
514 self.disconnect();
515 }
516
517 return Some(channel);
518 }
519
520 None
521 }
522
523 pub fn block_until_subscribed(&mut self, channel_id: Uuid) -> Result<Uuid, ChannelState> {
525 let channel = self.channels.get_mut(&channel_id);
529
530 let channel = channel.unwrap();
531
532 if channel.connection_state == ChannelState::Joined {
533 return Ok(channel.id);
534 }
535
536 if channel.connection_state != ChannelState::Joining {
537 self.channels
538 .get_mut(&channel_id)
539 .unwrap()
540 .subscribe()
541 .unwrap();
542 }
543
544 loop {
545 match self.step() {
546 Ok(channel_ids) => {
547 debug!(
548 "[Blocking Subscribe] Message forwarded to {:?}",
549 channel_ids
550 )
551 }
552 Err(NextMessageError::WouldBlock) => {}
553 Err(_e) => {
554 }
556 }
557
558 let channel = self.channels.get_mut(&channel_id).unwrap();
559
560 match channel.connection_state {
561 ChannelState::Joined => {
562 break;
563 }
564 ChannelState::Closed => return Err(ChannelState::Closed),
565 _ => {}
566 }
567 }
568
569 Ok(channel_id)
570 }
571
572 pub fn set_auth(&mut self, access_token: String) {
574 self.access_token.clone_from(&access_token);
575
576 for channel in self.channels.values_mut() {
577 let _ = channel.set_auth(access_token.clone()); }
580 }
581
582 pub fn add_middleware(
585 &mut self,
586 middleware: Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>,
587 ) -> Uuid {
588 let uuid = Uuid::new_v4();
590 self.middleware.insert(uuid, middleware);
591 uuid
592 }
593
594 pub fn remove_middleware(&mut self, uuid: Uuid) -> &mut Client {
596 self.middleware.remove(&uuid);
597 self
598 }
599
600 pub fn step(&mut self) -> Result<Vec<Uuid>, NextMessageError> {
602 match self.manager_recv() {
604 Ok(()) => {}
605 Err(e) => debug!("client manager_recv error: {}", e),
606 }
607
608 for channel in self.channels.values_mut() {
609 match channel.manager_recv() {
610 Ok(()) => {}
611 Err(e) => debug!("channel manager_recv error: {}", e),
612 }
613 }
614
615 match self.connection_state {
616 ConnectionState::Closed => {
617 return Err(NextMessageError::ClientClosed);
618 }
619 ConnectionState::Reconnect => {
620 let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
621 return Err(NextMessageError::SocketError(SocketError::Disconnected));
622 }
623 ConnectionState::Reconnecting => {
624 return Err(NextMessageError::WouldBlock);
625 }
626 ConnectionState::Connecting => {}
627 ConnectionState::Closing => {}
628 ConnectionState::Open => {}
629 }
630
631 match self.run_monitor() {
632 Ok(_) => {}
633 Err(MonitorError::WouldBlock) => {}
634 Err(MonitorError::MaxReconnects) => {
635 self.disconnect();
636 return Err(NextMessageError::MonitorError(MonitorError::MaxReconnects));
637 }
638 Err(e) => {
639 return Err(NextMessageError::MonitorError(e));
640 }
641 }
642
643 self.run_heartbeat();
644
645 match self.write_socket() {
646 Ok(()) => {}
647 Err(SocketError::WouldBlock) => {}
648 Err(e) => {
649 self.reconnect();
650 return Err(NextMessageError::SocketError(e));
651 }
652 }
653
654 match self.read_socket() {
655 Ok(()) => {}
656 Err(e) => {
657 self.reconnect();
658 return Err(NextMessageError::SocketError(e));
659 }
660 }
661
662 match self.inbound_channel.0 .1.try_recv() {
663 Ok(mut message) => {
664 let mut ids = vec![];
665 message = self.run_middleware(message);
669
670 for (id, channel) in &mut self.channels {
672 if channel.topic == message.topic {
673 channel.recieve(message.clone());
674 ids.push(*id);
675 }
676 }
677
678 Ok(ids)
679 }
680 Err(TryRecvError::Empty) => Err(NextMessageError::WouldBlock),
681 Err(e) => Err(NextMessageError::TryRecvError(e)),
682 }
683 }
684
685 pub(crate) fn add_channel(&mut self, channel: RealtimeChannel) {
686 self.channels.insert(channel.id, channel);
687 }
688
689 pub(crate) fn get_channel_tx(&self) -> Sender<RealtimeMessage> {
690 self.outbound_channel.0 .0.clone()
691 }
692
693 fn run_middleware(&self, mut message: RealtimeMessage) -> RealtimeMessage {
694 for middleware in self.middleware.values() {
695 message = middleware(message)
696 }
697 message
698 }
699
700 fn remove_all_channels(&mut self) {
701 if self.connection_state == ConnectionState::Closing
702 || self.connection_state == ConnectionState::Closed
703 {
704 return;
705 }
706
707 self.connection_state = ConnectionState::Closing;
708
709 loop {
711 let recv = self.step();
712
713 if Err(NextMessageError::WouldBlock) == recv {
714 break;
715 }
716 }
717
718 loop {
719 let _ = self.step();
720
721 let mut all_channels_closed = true;
722
723 for channel in self.channels.values_mut() {
724 let channel_state = channel.unsubscribe();
725
726 match channel_state {
727 Ok(state) => {
728 if state != ChannelState::Closed {
729 all_channels_closed = false;
730 }
731 }
732 Err(e) => {
733 debug!("Unsubscribe error: {:?}", e);
735 }
736 }
737 }
738
739 if all_channels_closed {
740 debug!("All channels closed!");
741 break;
742 }
743 }
744
745 self.channels.clear();
746 }
747
748 fn retry_handshake(
749 &mut self,
750 mid_hs: MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>,
751 ) -> Result<(WebSocket, Response), SocketError> {
752 match mid_hs.handshake() {
753 Ok(stream) => Ok(stream),
754 Err(e) => match e {
755 HandshakeError::Interrupted(mid_hs) => {
756 if self.reconnect_attempts < self.reconnect_max_attempts {
758 self.reconnect_attempts += 1;
759 let backoff = &self.reconnect_interval.0;
760 sleep(backoff(self.reconnect_attempts));
761 return self.retry_handshake(mid_hs);
762 }
763
764 Err(SocketError::TooManyRetries)
765 }
766 HandshakeError::Failure(_err) => {
767 Err(SocketError::HandshakeError)
769 }
770 },
771 }
772 }
773
774 fn run_monitor(&mut self) -> Result<(), MonitorError> {
775 if self.reconnect_now.is_none() {
776 self.reconnect_now = Some(SystemTime::now());
777
778 self.reconnect_delay = self.reconnect_interval.0(self.reconnect_attempts);
779 }
780
781 match self.monitor_channel.0 .1.try_recv() {
782 Ok(signal) => match signal {
783 MonitorSignal::Reconnect => {
784 if self.connection_state == ConnectionState::Open
785 || self.connection_state == ConnectionState::Reconnecting
786 || SystemTime::now() < self.reconnect_now.unwrap() + self.reconnect_delay
787 {
788 return Err(MonitorError::WouldBlock);
789 }
790
791 if self.reconnect_attempts >= self.reconnect_max_attempts {
792 return Err(MonitorError::MaxReconnects);
793 }
794
795 self.connection_state = ConnectionState::Reconnecting;
796 self.reconnect_attempts += 1;
797 self.reconnect_now.take();
798
799 match self.connect() {
800 Ok(_) => {
801 for channel in self.channels.values_mut() {
802 channel.subscribe().unwrap();
803 }
804
805 Ok(())
806 }
807 Err(e) => {
808 debug!("reconnect error: {:?}", e);
809 self.connection_state = ConnectionState::Reconnect;
810 Err(MonitorError::ReconnectError)
811 }
812 }
813 }
814 },
815 Err(TryRecvError::Empty) => Err(MonitorError::WouldBlock),
816 Err(TryRecvError::Disconnected) => Err(MonitorError::Disconnected),
817 }
818 }
819
820 fn run_heartbeat(&mut self) {
821 if self.heartbeat_now.is_none() {
822 self.heartbeat_now = Some(SystemTime::now());
823 }
824
825 if self.heartbeat_now.unwrap() + self.heartbeat_interval > SystemTime::now() {
826 return;
827 }
828
829 self.heartbeat_now.take();
830
831 let _ = self.send(RealtimeMessage::heartbeat());
832 }
833
834 fn read_socket(&mut self) -> Result<(), SocketError> {
835 let Some(ref mut socket) = self.socket else {
836 return Err(SocketError::NoSocket);
837 };
838
839 if !socket.can_read() {
840 return Err(SocketError::NoRead);
841 }
842
843 match socket.read() {
844 Ok(raw_message) => match raw_message {
845 Message::Text(string_message) => {
846 let mut message: RealtimeMessage =
848 serde_json::from_str(&string_message).expect("Deserialization error: ");
849
850 debug!("[RECV] {:?}", message);
851
852 if let Some(decode) = &self.decode {
853 message = decode(message);
854 }
855
856 if let Payload::Empty {} = message.payload {
857 debug!("Possibly malformed payload: {:?}", string_message)
858 }
859
860 let _ = self.inbound_channel.0 .0.send(message);
861 Ok(())
862 }
863 Message::Close(_close_message) => {
864 self.disconnect();
865 Err(SocketError::Disconnected)
866 }
867 _ => {
868 Err(SocketError::WouldBlock)
870 }
871 },
872 Err(TungsteniteError::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
873 Ok(())
875 }
876 Err(err) => {
877 debug!("Socket read error: {:?}", err);
878 self.connection_state = ConnectionState::Reconnect;
879 let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
880 Err(SocketError::WouldBlock)
881 }
882 }
883 }
884
885 fn write_socket(&mut self) -> Result<(), SocketError> {
886 let Some(ref mut socket) = self.socket else {
887 return Err(SocketError::NoSocket);
888 };
889
890 if !socket.can_write() {
891 return Err(SocketError::NoWrite);
892 }
893
894 let now = SystemTime::now();
896
897 self.messages_this_second = self
898 .messages_this_second
899 .clone() .into_iter()
901 .filter(|st| now.duration_since(*st).unwrap_or_default() < Duration::from_secs(1))
902 .collect();
903
904 if self.messages_this_second.len() >= self.max_events_per_second {
905 return Err(SocketError::WouldBlock);
906 }
907
908 let message = self.outbound_channel.0 .1.try_recv();
912
913 match message {
916 Ok(mut message) => {
917 if message.message_ref.is_none() {
918 message.message_ref = Some(self.next_ref.into());
919 self.next_ref = Uuid::new_v4();
920 }
921
922 if let Some(encode) = &self.encode {
923 message = encode(message);
924 }
925
926 let raw = serde_json::to_string(&message);
927 debug!("[SEND] {:?}", raw);
928
929 let _ = socket.send(message.into());
930 self.messages_this_second.push(now);
931 Ok(())
932 }
933 Err(TryRecvError::Empty) => {
934 Ok(())
936 }
937 Err(e) => {
938 debug!("outbound error: {:?}", e);
939 self.connection_state = ConnectionState::Reconnect;
940 let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
941 Err(SocketError::WouldBlock)
942 }
943 }
944 }
945
946 fn reconnect(&mut self) {
947 self.connection_state = ConnectionState::Reconnect;
948 let _ = self.monitor_channel.0 .0.send(MonitorSignal::Reconnect);
949 }
950}
951
952pub struct ReconnectFn(pub Box<dyn Fn(usize) -> Duration + Send + Sync>);
956
957impl ReconnectFn {
958 pub fn new(f: impl Fn(usize) -> Duration + 'static + Sync + Send) -> Self {
959 Self(Box::new(f))
960 }
961}
962
963impl Default for ReconnectFn {
964 fn default() -> Self {
965 Self(Box::new(backoff))
966 }
967}
968
969impl Debug for ReconnectFn {
970 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
971 f.write_str("TODO reconnect fn debug")
972 }
973}
974
975pub struct ClientBuilder {
977 headers: HeaderMap,
978 params: Option<HashMap<String, String>>,
979 heartbeat_interval: Duration,
980 encode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
981 decode: Option<Box<dyn Fn(RealtimeMessage) -> RealtimeMessage + Send + Sync>>,
982 reconnect_interval: ReconnectFn,
983 reconnect_max_attempts: usize,
984 connection_timeout: Duration,
985 auth_url: Option<String>,
986 endpoint: String,
987 access_token: String,
988 max_events_per_second: usize,
989}
990
991impl ClientBuilder {
992 pub fn new(endpoint: impl Into<String>, access_token: impl Into<String>) -> Self {
993 let mut headers = HeaderMap::new();
994 headers.insert("X-Client-Info", "realtime-rs/0.1.0".parse().unwrap());
995
996 Self {
997 headers,
998 params: Default::default(),
999 heartbeat_interval: Duration::from_secs(29),
1000 encode: Default::default(),
1001 decode: Default::default(),
1002 reconnect_interval: ReconnectFn(Box::new(backoff)),
1003 reconnect_max_attempts: usize::MAX,
1004 connection_timeout: Duration::from_secs(10),
1005 auth_url: Default::default(),
1006 endpoint: endpoint.into(),
1007 access_token: access_token.into(),
1008 max_events_per_second: 10,
1009 }
1010 }
1011
1012 pub fn set_headers(&mut self, set_headers: HeaderMap) -> &mut Self {
1014 let mut headers = HeaderMap::new();
1015 headers.insert("X-Client-Info", "realtime-rs/0.1.0".parse().unwrap());
1016 headers.extend(set_headers);
1017
1018 self.headers = headers;
1019
1020 self
1021 }
1022
1023 pub fn add_headers(&mut self, headers: HeaderMap) -> &mut Self {
1025 self.headers.extend(headers);
1026 self
1027 }
1028
1029 pub fn params(&mut self, params: HashMap<String, String>) -> &mut Self {
1031 self.params = Some(params);
1032 self
1033 }
1034
1035 pub fn heartbeat_interval(&mut self, heartbeat_interval: Duration) -> &mut Self {
1037 self.heartbeat_interval = heartbeat_interval;
1038 self
1039 }
1040
1041 pub fn reconnect_interval(&mut self, reconnect_interval: ReconnectFn) -> &mut Self {
1050 self.reconnect_interval = reconnect_interval;
1053 self
1054 }
1055
1056 pub fn reconnect_max_attempts(&mut self, max_attempts: usize) -> &mut Self {
1058 self.reconnect_max_attempts = max_attempts;
1059 self
1060 }
1061
1062 pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
1066 let timeout = if timeout < Duration::from_secs(1) {
1068 Duration::from_secs(1)
1069 } else {
1070 timeout
1071 };
1072
1073 self.connection_timeout = timeout;
1074 self
1075 }
1076
1077 pub fn auth_url(&mut self, auth_url: impl Into<String>) -> &mut Self {
1081 self.auth_url = Some(auth_url.into());
1082 self
1083 }
1084
1085 pub fn max_events_per_second(&mut self, count: usize) -> &mut Self {
1088 self.max_events_per_second = count;
1089 self
1090 }
1091
1092 pub fn encode(
1093 &mut self,
1094 encode: impl Fn(RealtimeMessage) -> RealtimeMessage + 'static + Send + Sync,
1095 ) -> &mut Self {
1096 self.encode = Some(Box::new(encode));
1097 self
1098 }
1099
1100 pub fn decode(
1101 &mut self,
1102 decode: impl Fn(RealtimeMessage) -> RealtimeMessage + 'static + Send + Sync,
1103 ) -> &mut Self {
1104 self.decode = Some(Box::new(decode));
1105 self
1106 }
1107
1108 pub fn build(
1110 self,
1111 channel_callback_event_sender: CrossbeamEventSender<ChannelCallbackEvent>,
1112 connect_result_callback_event_sender: CrossbeamEventSender<ConnectResultCallbackEvent>,
1113 ) -> Client {
1114 let (manager_tx, manager_rx) = unbounded();
1115 Client {
1116 headers: self.headers,
1117 params: self.params,
1118 heartbeat_interval: self.heartbeat_interval,
1119 encode: self.encode,
1120 decode: self.decode,
1121 reconnect_interval: self.reconnect_interval,
1122 reconnect_max_attempts: self.reconnect_max_attempts,
1123 connection_timeout: self.connection_timeout,
1124 auth_url: self.auth_url,
1125 endpoint: self.endpoint,
1126 access_token: self.access_token,
1127 max_events_per_second: self.max_events_per_second,
1128 next_ref: Uuid::new_v4(),
1129 connection_state: Default::default(),
1130 socket: Default::default(),
1131 channels: Default::default(),
1132 messages_this_second: Default::default(),
1133 outbound_channel: Default::default(),
1134 inbound_channel: Default::default(),
1135 monitor_channel: Default::default(),
1136 middleware: Default::default(),
1137 reconnect_now: Default::default(),
1138 reconnect_delay: Default::default(),
1139 reconnect_attempts: Default::default(),
1140 heartbeat_now: Default::default(),
1141 manager_rx,
1142 manager_tx,
1143 channel_callback_event_sender,
1144 connect_result_callback_event_sender,
1145 }
1146 }
1147}
1148
/// Default reconnect backoff: 0s, 1s, 2s, 5s, then 10s for every later attempt.
///
/// Uses a constant table instead of allocating a `Vec` on every call.
fn backoff(attempts: usize) -> Duration {
    const SCHEDULE_SECS: [u64; 5] = [0, 1, 2, 5, 10];

    // Attempts beyond the table reuse its last (largest) entry.
    let step = attempts.min(SCHEDULE_SECS.len() - 1);
    Duration::from_secs(SCHEDULE_SECS[step])
}