1#[cfg(test)]
2mod tests;
3
4use std::time::Duration;
5
6use futures_util::sink::SinkExt;
7use futures_util::stream::{SplitSink, SplitStream, StreamExt};
8use kanal::{AsyncReceiver, AsyncSender};
9
10use tokio_tungstenite::WebSocketStream;
11
12use tokio_tungstenite::tungstenite::protocol::Message;
13use url::Url;
14
/// Optional TLS-related settings for the WebSocket connection.
///
/// Interpreted by `initialize` when the `tls` feature is enabled; with neither field set the
/// default connection path is used.
#[derive(Clone, Debug, Default)]
pub struct Wsconfig {
    /// Requests a connection that skips server certificate verification. Intended for testing
    /// against trusted servers only.
    pub insecure: Option<bool>,
    /// PEM-encoded certificate(s) to trust as the only root store for the connection.
    pub private_chain_bytes: Option<Vec<u8>>,
}
20
/// Certificate verifier that trusts every server unconditionally.
///
/// Every verification hook returns its success assertion without looking at its inputs, so
/// certificate chains, server names, OCSP data and handshake signatures are never checked.
/// Used by `initialize_insecure_tls`; it offers no protection against impersonation.
#[cfg(feature = "tls")]
#[derive(Clone, Debug)]
struct NoVerifier;

#[cfg(feature = "tls")]
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    /// Accepts any end-entity certificate for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let accepted = rustls::client::danger::ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let accepted = rustls::client::danger::HandshakeSignatureValid::assertion();
        Ok(accepted)
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let accepted = rustls::client::danger::HandshakeSignatureValid::assertion();
        Ok(accepted)
    }

    /// Lists the signature schemes this verifier claims to support (none is actually checked).
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        use rustls::SignatureScheme as Scheme;

        Vec::from([
            Scheme::RSA_PKCS1_SHA256,
            Scheme::ECDSA_NISTP256_SHA256,
            Scheme::RSA_PKCS1_SHA384,
            Scheme::ECDSA_NISTP384_SHA384,
            Scheme::RSA_PKCS1_SHA512,
            Scheme::ECDSA_NISTP521_SHA512,
            Scheme::RSA_PSS_SHA256,
            Scheme::RSA_PSS_SHA384,
            Scheme::RSA_PSS_SHA512,
            Scheme::ED25519,
            Scheme::ED448,
        ])
    }
}
72pub async fn initialize_default_websocket_connection(
73 url: Url,
74) -> anyhow::Result<(
75 SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
76 SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
77)> {
78 println!(
79 "Connecting to the WebSocket server at {}...",
80 &url.to_string()
81 );
82
83 let (ws_stream, _) = tokio_tungstenite::connect_async(&url.to_string()).await?;
84 println!("Successfully connected to the WebSocket server.");
85
86 Ok(ws_stream.split())
87}
88
/// Opens a WebSocket connection to `url` with server certificate verification disabled.
///
/// The TLS client configuration installs `NoVerifier`, so any certificate presented by the
/// server is accepted. This removes protection against man-in-the-middle attacks; only use it
/// against trusted test servers.
///
/// # Errors
///
/// Returns an error if connecting or the WebSocket handshake fails.
#[cfg(feature = "tls")]
pub async fn initialize_insecure_tls(
    url: Url,
) -> anyhow::Result<(
    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
)> {
    println!("Connecting to the WebSocket server at {}...", url);

    // Install the accept-everything verifier directly through the dangerous builder instead of
    // building a root-store verifier that would be overwritten immediately afterwards.
    let config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
        .with_no_client_auth();

    let connector = tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config));

    let (ws_stream, _) =
        tokio_tungstenite::connect_async_tls_with_config(url, None, true, Some(connector)).await?;

    println!("Successfully connected to the WebSocket server.");

    Ok(ws_stream.split())
}
119
/// Opens a WebSocket connection to `url`, trusting only the certificates in
/// `private_chain_bytes`.
///
/// `private_chain_bytes` is parsed as PEM via `rustls_pemfile::certs`. Every certificate rustls
/// can use becomes a trust anchor of an otherwise empty root store.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the PEM data cannot be parsed;
/// - it yields no usable certificate;
/// - connecting or the WebSocket handshake fails.
#[cfg(feature = "tls")]
pub async fn initialize_private_tls(
    url: Url,
    private_chain_bytes: &[u8],
) -> anyhow::Result<(
    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
)> {
    println!("Connecting to the WebSocket server at {}...", url);

    let mut cert_cursor = std::io::Cursor::new(private_chain_bytes);
    let cert_chain = rustls_pemfile::certs(&mut cert_cursor)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| anyhow::anyhow!("Error parsing certificate: {:?}", e))?;

    let mut root_cert_store = rustls::RootCertStore::empty();
    // `add_parsable_certificates` silently skips certificates it cannot use. An empty store
    // would otherwise only show up later as an opaque handshake failure.
    let (added, ignored) = root_cert_store.add_parsable_certificates(cert_chain);
    if added == 0 {
        return Err(anyhow::anyhow!(
            "No usable certificate found in private chain ({} ignored)",
            ignored
        ));
    }

    let config = rustls::ClientConfig::builder()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();

    let connector = tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config));

    // `url` is already a parsed `Url`; re-parsing its own serialization is redundant.
    let (ws_stream, _) =
        tokio_tungstenite::connect_async_tls_with_config(url, None, true, Some(connector)).await?;

    println!("Successfully connected to the WebSocket server.");

    Ok(ws_stream.split())
}
159
/// Chooses a connection strategy for `uri` from `ws_config` and connects.
///
/// - `insecure == Some(true)`: certificate verification is disabled (`initialize_insecure_tls`).
/// - otherwise, if `private_chain_bytes` is set: only those certificates are trusted
///   (`initialize_private_tls`).
/// - otherwise: the default connection (`initialize_default_websocket_connection`).
///
/// If no config is given at all and the scheme is plain `ws`, a notice is printed first.
///
/// # Errors
///
/// Propagates any error from the selected connection routine.
#[cfg(feature = "tls")]
pub async fn initialize(
    uri: Url,
    ws_config: Option<Wsconfig>,
) -> anyhow::Result<(
    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
)> {
    match ws_config {
        // Only an explicit `Some(true)` may disable certificate verification; the previous
        // `is_some()` check also did so for `Some(false)`.
        Some(Wsconfig {
            insecure: Some(true),
            ..
        }) => initialize_insecure_tls(uri).await,
        Some(Wsconfig {
            private_chain_bytes: Some(chain),
            ..
        }) => initialize_private_tls(uri, &chain).await,
        Some(_) => initialize_default_websocket_connection(uri).await,
        None => {
            if uri.scheme() == "ws" {
                println!("Connecting to the OPEN WebSocket server at {}...", uri);
            }
            initialize_default_websocket_connection(uri).await
        }
    }
}
189
190#[cfg(not(feature = "tls"))]
191pub async fn websocket_handler(
192 uri: Url,
193 ws_channel_receiver: AsyncReceiver<String>,
194 events_channel_sender: AsyncSender<String>,
195) -> anyhow::Result<()> {
196 let (mut ws_sink, mut ws_stream) = initialize_default_websocket_connection(uri).await?;
197
198 let tx_loop = tokio::spawn(async move {
199 while let Ok(msg) = ws_channel_receiver.recv().await {
200 ws_sink.send(Message::Text(msg)).await?;
201 }
202 Ok::<(), anyhow::Error>(())
203 });
204
205 let rx_loop = tokio::spawn(async move {
206 while let Some(msg) = ws_stream.next().await {
207 match msg {
208 Ok(Message::Text(text)) => {
209 events_channel_sender.send(text).await?;
210 }
211 Ok(_) => {}
212 Err(e) => {
213 return Err(anyhow::anyhow!("Error receiving message: {}", e));
214 }
215 }
216 }
217 Ok::<(), anyhow::Error>(())
218 });
219
220 _ = tokio::try_join!(tx_loop, rx_loop)?;
221 Err(anyhow::anyhow!("WebSocket handler exited!"))
222}
223
/// Connects to `uri` (strategy chosen by `initialize` from `ws_config`) and relays messages in
/// both directions until either direction stops.
///
/// - Outgoing: each `String` received on `ws_channel_receiver` is sent as a text frame.
/// - Incoming: each text frame is forwarded to `events_channel_sender`; other frames are
///   ignored.
///
/// Each direction runs in its own spawned task. As soon as one finishes, both are aborted, so
/// no stale task keeps consuming `ws_channel_receiver` after the connection is gone.
///
/// # Errors
///
/// Always returns an error, one of:
/// - the connection error;
/// - the error that stopped a direction;
/// - a task panic or cancellation (`JoinError`);
/// - `"WebSocket handler exited!"` when a direction ended cleanly (e.g. the server closed the
///   stream).
#[cfg(feature = "tls")]
pub async fn websocket_handler(
    uri: Url,
    ws_channel_receiver: AsyncReceiver<String>,
    events_channel_sender: AsyncSender<String>,
    ws_config: Option<Wsconfig>,
) -> anyhow::Result<()> {
    let (mut ws_sink, mut ws_stream) = initialize(uri, ws_config).await?;

    let mut tx_loop = tokio::spawn(async move {
        while let Ok(msg) = ws_channel_receiver.recv().await {
            ws_sink.send(Message::Text(msg)).await?;
        }
        Ok::<(), anyhow::Error>(())
    });

    let mut rx_loop = tokio::spawn(async move {
        while let Some(msg) = ws_stream.next().await {
            match msg {
                Ok(Message::Text(text)) => {
                    events_channel_sender.send(text).await?;
                }
                Ok(_) => {}
                Err(e) => {
                    return Err(anyhow::anyhow!("Error receiving message: {}", e));
                }
            }
        }
        Ok::<(), anyhow::Error>(())
    });

    // Waiting for *both* tasks stalls reconnection: after the server closes the stream, the
    // sender task keeps blocking on `recv()` until the next outgoing message. Stop at the
    // first task that finishes instead.
    let finished = tokio::select! {
        res = &mut tx_loop => res,
        res = &mut rx_loop => res,
    };
    // Aborting the already-finished task is a no-op; the other one must not linger.
    tx_loop.abort();
    rx_loop.abort();

    // Outer `?`: the task panicked or was cancelled. Inner `?`: the loop's own error.
    finished??;
    Err(anyhow::anyhow!("WebSocket handler exited!"))
}
258
259pub fn create_channel() -> (AsyncSender<String>, AsyncReceiver<String>) {
260 let (ws_channel_sender, ws_channel_receiver) = kanal::unbounded_async();
261 (ws_channel_sender, ws_channel_receiver)
262}
263
/// Keeps a WebSocket connection to `uri` going for as long as the returned future is polled.
///
/// Runs `websocket_handler` in a loop. In this file that handler always ends with an error.
/// Each error is printed to stderr, followed by a fixed 60-second pause and a new attempt.
/// The same (cloned) channel halves and `ws_config` are reused for every attempt.
///
/// The loop never exits, so the `anyhow::Result<()>` is never actually produced.
#[cfg(feature = "tls")]
pub async fn start_websocket(
    uri: Url,
    ws_channel_receiver: AsyncReceiver<String>,
    events_channel_sender: AsyncSender<String>,
    ws_config: Option<Wsconfig>,
) -> anyhow::Result<()> {
    // Fixed pause between reconnection attempts.
    const RECONNECT_DELAY_SECS: u64 = 60;
    println!("start websocket routine");

    loop {
        let outcome = websocket_handler(
            uri.clone(),
            ws_channel_receiver.clone(),
            events_channel_sender.clone(),
            ws_config.clone(),
        )
        .await;

        if let Err(e) = outcome {
            eprintln!("websocket error {:?}", e);
        }

        println!(
            "restarting websocket routine in {} seconds",
            RECONNECT_DELAY_SECS
        );
        tokio::time::sleep(Duration::from_secs(RECONNECT_DELAY_SECS)).await;
    }
}
295
296#[cfg(not(feature = "tls"))]
297pub async fn start_websocket(
298 uri: Url,
299 ws_channel_receiver: AsyncReceiver<String>,
300 events_channel_sender: AsyncSender<String>,
301) -> anyhow::Result<()> {
302 let timeout_in_seconds = 60;
303 println!("start websocket routine");
304
305 loop {
306 let t = websocket_handler(
307 uri.clone(),
308 ws_channel_receiver.clone(),
309 events_channel_sender.clone(),
310 )
311 .await;
312
313 if t.is_err() {
314 let msg = format!("websocket error {:?}", t.unwrap_err());
315 eprintln!("{}", msg);
316 }
317
318 println!(
319 "restarting websocket routine in {} seconds",
320 timeout_in_seconds
321 );
322 tokio::time::sleep(Duration::from_secs(timeout_in_seconds)).await;
323 }
324}