1use futures_signals::signal::Mutable;
2use futures_util::{future::select_all, SinkExt, StreamExt};
3use log::{debug, error, info, warn};
4use std::time::Duration;
5use tokio_tungstenite::connect_async;
6use tokio_util::sync::CancellationToken;
7use tungstenite::Message;
8use url::Url;
9
10#[derive(Clone, Debug)]
11pub enum SocketConnectionStatus {
12 Disconnected,
13 Connecting(String, CancellationToken),
14 Connected(String, CancellationToken),
15}
16
17pub struct AutoReconnectSocket {
18 pub status: Mutable<SocketConnectionStatus>,
19 pub incoming: tokio::sync::broadcast::Sender<Message>,
20 pub outgoing: tokio::sync::broadcast::Sender<Message>,
21}
22
23impl Default for AutoReconnectSocket {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl AutoReconnectSocket {
30 pub fn new() -> Self {
31 Self {
32 status: Mutable::new(SocketConnectionStatus::Disconnected),
33 incoming: tokio::sync::broadcast::channel(1000).0,
34 outgoing: tokio::sync::broadcast::channel(1000).0,
35 }
36 }
37
38 pub fn set_addr(&self, addr: Option<String>) {
39 let lock = self.status.lock_ref();
40 let s = lock.clone();
41 drop(lock);
42
43 match s {
44 SocketConnectionStatus::Connected(current_addr, teardown)
45 | SocketConnectionStatus::Connecting(current_addr, teardown) => {
46 if Some(current_addr) == addr {
47 return;
48 }
49
50 teardown.cancel();
51
52 self.status.set(SocketConnectionStatus::Disconnected);
53 if let Some(addr) = addr {
54 self.build(addr);
55 }
56 }
57
58 SocketConnectionStatus::Disconnected => {
59 if let Some(addr) = addr {
60 self.build(addr);
61 }
62 }
63 }
64 }
65
66 pub fn build(&self, addr: String) {
67 info!("Building Connection to {}", addr);
68
69 let lock = self.status.lock_ref();
70 let s = lock.clone();
71 drop(lock);
72
73 match s {
74 SocketConnectionStatus::Connected(_, _token) => {
75 unreachable!("Should not be building when already connected");
76 }
77 SocketConnectionStatus::Connecting(_, _token) => {
78 unreachable!("Should not be building when already connected");
79 }
80 SocketConnectionStatus::Disconnected => (),
81 }
82
83 let teardown = CancellationToken::new();
84
85 let send = self.outgoing.clone();
86 let recv = self.incoming.clone();
87
88 let status = self.status.clone();
89
90 tokio::spawn(async move {
91 loop {
92 info!("Connecting to {}", addr);
93
94 let parsed = Url::parse(addr.as_str());
95
96 let mut parsed = match parsed {
97 Ok(c) => c,
98 Err(e) => {
99 error!("Could not parse url: {:?}", e);
100
101 let add_ws = format!("ws://{}", addr);
102
103 match Url::parse(add_ws.as_str()) {
104 Ok(c) => c,
105 Err(_e) => {
106 error!("Could not Parse Url: {_e} {add_ws}");
107 return;
108 }
109 }
110 }
111 };
112
113 if parsed.scheme() != "ws" {
114 let _ = parsed.set_scheme("ws");
115 }
116
117 let ws_stream = match connect_async(&parsed.to_string()).await {
118 Ok((ws_stream, _)) => ws_stream,
119 Err(_) => {
120 tokio::time::sleep(Duration::from_secs(1)).await;
121 continue;
122 }
123 };
124
125 info!("Connected to {}", parsed);
126
127 let (mut write, mut read) = ws_stream.split();
128 let interior_cancel = CancellationToken::new();
129
130 let int_send_cancel = interior_cancel.clone();
131 let rec_send_cancel = teardown.clone();
132
133 if teardown.is_cancelled() {
134 break;
135 }
136
137 let mut local_send = send.subscribe();
138 let write_handle = tokio::spawn(async move {
139 loop {
140 if int_send_cancel.is_cancelled() || rec_send_cancel.is_cancelled() {
141 debug!("Exiting Write Loop");
142 break;
143 }
144
145 let msg = match local_send.recv().await {
146 Ok(msg) => msg,
147 Err(e) => {
148 error!("Error receiving message to send: {:?}", e);
149 continue;
150 }
151 };
152
153 match write.send(msg).await {
154 Ok(_) => {}
155 Err(e) => {
156 int_send_cancel.cancel();
157 error!("Websocket write failed: {:?}", e);
158 }
159 }
160 }
161 debug!("Websocket Write Loop Exited");
162 });
163
164 let rec_read_cancel = teardown.clone();
165 let int_read_cancel = interior_cancel.clone();
166
167 let local_recv = recv.clone();
168
169 let read_handle = tokio::spawn(async move {
170 while let (Some(Ok(msg)), false, false) = (
171 read.next().await,
172 rec_read_cancel.is_cancelled(),
173 int_read_cancel.is_cancelled(),
174 ) {
175 match local_recv.send(msg) {
176 Ok(_num) => {
177 }
179 Err(_e) => {
180 debug!("No Downstream Listeners");
181 }
182 }
183 }
184
185 error!("Websocket Read Failed");
186 int_read_cancel.cancel();
187 });
188
189 let s = SocketConnectionStatus::Connected(addr.clone(), teardown.clone());
190
191 status.set(s);
192
193 let _ = select_all(vec![write_handle, read_handle]).await;
194
195 warn!("Read and/or Write Exited - Reconnecting in 1s");
196
197 let s = SocketConnectionStatus::Disconnected;
198
199 status.set(s);
200
201 tokio::time::sleep(Duration::from_secs(1)).await;
202 }
203 });
204 }
205}