// myko_rs/websocket.rs

use futures_signals::signal::Mutable;
use futures_util::{
    future::{select, select_all, Either},
    SinkExt, StreamExt,
};
use log::{debug, error, info, warn};
use std::time::Duration;
use tokio::sync::broadcast::error::RecvError;
use tokio_tungstenite::connect_async;
use tokio_util::sync::CancellationToken;
use tungstenite::Message;
use url::Url;

10#[derive(Clone, Debug)]
11pub enum SocketConnectionStatus {
12    Disconnected,
13    Connecting(String, CancellationToken),
14    Connected(String, CancellationToken),
15}
16
17pub struct AutoReconnectSocket {
18    pub status: Mutable<SocketConnectionStatus>,
19    pub incoming: tokio::sync::broadcast::Sender<Message>,
20    pub outgoing: tokio::sync::broadcast::Sender<Message>,
21}
22
23impl Default for AutoReconnectSocket {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl AutoReconnectSocket {
30    pub fn new() -> Self {
31        Self {
32            status: Mutable::new(SocketConnectionStatus::Disconnected),
33            incoming: tokio::sync::broadcast::channel(1000).0,
34            outgoing: tokio::sync::broadcast::channel(1000).0,
35        }
36    }
37
38    pub fn set_addr(&self, addr: Option<String>) {
39        let lock = self.status.lock_ref();
40        let s = lock.clone();
41        drop(lock);
42
43        match s {
44            SocketConnectionStatus::Connected(current_addr, teardown)
45            | SocketConnectionStatus::Connecting(current_addr, teardown) => {
46                if Some(current_addr) == addr {
47                    return;
48                }
49
50                teardown.cancel();
51
52                self.status.set(SocketConnectionStatus::Disconnected);
53                if let Some(addr) = addr {
54                    self.build(addr);
55                }
56            }
57
58            SocketConnectionStatus::Disconnected => {
59                if let Some(addr) = addr {
60                    self.build(addr);
61                }
62            }
63        }
64    }
65
66    pub fn build(&self, addr: String) {
67        info!("Building Connection to {}", addr);
68
69        let lock = self.status.lock_ref();
70        let s = lock.clone();
71        drop(lock);
72
73        match s {
74            SocketConnectionStatus::Connected(_, _token) => {
75                unreachable!("Should not be building when already connected");
76            }
77            SocketConnectionStatus::Connecting(_, _token) => {
78                unreachable!("Should not be building when already connected");
79            }
80            SocketConnectionStatus::Disconnected => (),
81        }
82
83        let teardown = CancellationToken::new();
84
85        let send = self.outgoing.clone();
86        let recv = self.incoming.clone();
87
88        let status = self.status.clone();
89
90        tokio::spawn(async move {
91            loop {
92                info!("Connecting to {}", addr);
93
94                let parsed = Url::parse(addr.as_str());
95
96                let mut parsed = match parsed {
97                    Ok(c) => c,
98                    Err(e) => {
99                        error!("Could not parse url: {:?}", e);
100
101                        let add_ws = format!("ws://{}", addr);
102
103                        match Url::parse(add_ws.as_str()) {
104                            Ok(c) => c,
105                            Err(_e) => {
106                                error!("Could not Parse Url: {_e} {add_ws}");
107                                return;
108                            }
109                        }
110                    }
111                };
112
113                if parsed.scheme() != "ws" {
114                    let _ = parsed.set_scheme("ws");
115                }
116
117                let ws_stream = match connect_async(&parsed.to_string()).await {
118                    Ok((ws_stream, _)) => ws_stream,
119                    Err(_) => {
120                        tokio::time::sleep(Duration::from_secs(1)).await;
121                        continue;
122                    }
123                };
124
125                info!("Connected to {}", parsed);
126
127                let (mut write, mut read) = ws_stream.split();
128                let interior_cancel = CancellationToken::new();
129
130                let int_send_cancel = interior_cancel.clone();
131                let rec_send_cancel = teardown.clone();
132
133                if teardown.is_cancelled() {
134                    break;
135                }
136
137                let mut local_send = send.subscribe();
138                let write_handle = tokio::spawn(async move {
139                    loop {
140                        if int_send_cancel.is_cancelled() || rec_send_cancel.is_cancelled() {
141                            debug!("Exiting Write Loop");
142                            break;
143                        }
144
145                        let msg = match local_send.recv().await {
146                            Ok(msg) => msg,
147                            Err(e) => {
148                                error!("Error receiving message to send: {:?}", e);
149                                continue;
150                            }
151                        };
152
153                        match write.send(msg).await {
154                            Ok(_) => {}
155                            Err(e) => {
156                                int_send_cancel.cancel();
157                                error!("Websocket write failed: {:?}", e);
158                            }
159                        }
160                    }
161                    debug!("Websocket Write Loop Exited");
162                });
163
164                let rec_read_cancel = teardown.clone();
165                let int_read_cancel = interior_cancel.clone();
166
167                let local_recv = recv.clone();
168
169                let read_handle = tokio::spawn(async move {
170                    while let (Some(Ok(msg)), false, false) = (
171                        read.next().await,
172                        rec_read_cancel.is_cancelled(),
173                        int_read_cancel.is_cancelled(),
174                    ) {
175                        match local_recv.send(msg) {
176                            Ok(_num) => {
177                                // debug!("Sent Message Downstream to {} Listeners", num);
178                            }
179                            Err(_e) => {
180                                debug!("No Downstream Listeners");
181                            }
182                        }
183                    }
184
185                    error!("Websocket Read Failed");
186                    int_read_cancel.cancel();
187                });
188
189                let s = SocketConnectionStatus::Connected(addr.clone(), teardown.clone());
190
191                status.set(s);
192
193                let _ = select_all(vec![write_handle, read_handle]).await;
194
195                warn!("Read and/or Write Exited - Reconnecting in 1s");
196
197                let s = SocketConnectionStatus::Disconnected;
198
199                status.set(s);
200
201                tokio::time::sleep(Duration::from_secs(1)).await;
202            }
203        });
204    }
205}