simplews/
lib.rs

1#[cfg(test)]
2mod tests;
3
4use std::time::Duration;
5
6use futures_util::sink::SinkExt;
7use futures_util::stream::{SplitSink, SplitStream, StreamExt};
8use kanal::{AsyncReceiver, AsyncSender};
9
10use tokio_tungstenite::WebSocketStream;
11
12use tokio_tungstenite::tungstenite::protocol::Message;
13use url::Url;
14
/// Optional TLS settings for a WebSocket connection.
///
/// Only consulted by the `tls`-feature variants of `initialize` /
/// `websocket_handler` / `start_websocket`. `Default` yields "no overrides",
/// i.e. tokio-tungstenite's default connection behaviour.
#[derive(Clone, Debug, Default)]
pub struct Wsconfig {
    /// `Some(true)` connects with server certificate verification disabled.
    /// Dangerous: the server's identity is not checked at all. Takes
    /// precedence over `private_chain_bytes`.
    pub insecure: Option<bool>,
    /// PEM-encoded certificate(s) used as the *only* trust anchors for the TLS
    /// connection (e.g. a private CA). Ignored when `insecure` is `Some(true)`.
    pub private_chain_bytes: Option<Vec<u8>>,
}
20
/// A `ServerCertVerifier` that approves everything it is shown.
///
/// DANGER: server certificates, host names and handshake signatures are never
/// checked, so any party able to intercept the connection can impersonate the
/// server. Only used by `initialize_insecure_tls`.
#[cfg(feature = "tls")]
#[derive(Clone, Debug)]
struct NoVerifier;

#[cfg(feature = "tls")]
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    /// Unconditionally accepts the presented certificate chain.
    fn verify_server_cert(
        &self,
        _leaf: &rustls::pki_types::CertificateDer<'_>,
        _chain: &[rustls::pki_types::CertificateDer<'_>],
        _name: &rustls::pki_types::ServerName<'_>,
        _ocsp: &[u8],
        _time: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Unconditionally accepts a TLS 1.2 handshake signature.
    fn verify_tls12_signature(
        &self,
        _msg: &[u8],
        _certificate: &rustls::pki_types::CertificateDer<'_>,
        _signature: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Unconditionally accepts a TLS 1.3 handshake signature.
    fn verify_tls13_signature(
        &self,
        _msg: &[u8],
        _certificate: &rustls::pki_types::CertificateDer<'_>,
        _signature: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes this verifier claims to support; rustls uses this list
    /// during handshake negotiation (the signatures themselves are never checked).
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        use rustls::SignatureScheme as Scheme;

        Vec::from([
            Scheme::RSA_PKCS1_SHA256,
            Scheme::ECDSA_NISTP256_SHA256,
            Scheme::RSA_PKCS1_SHA384,
            Scheme::ECDSA_NISTP384_SHA384,
            Scheme::RSA_PKCS1_SHA512,
            Scheme::ECDSA_NISTP521_SHA512,
            Scheme::RSA_PSS_SHA256,
            Scheme::RSA_PSS_SHA384,
            Scheme::RSA_PSS_SHA512,
            Scheme::ED25519,
            Scheme::ED448,
        ])
    }
}
72pub async fn initialize_default_websocket_connection(
73    url: Url,
74) -> anyhow::Result<(
75    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
76    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
77)> {
78    println!(
79        "Connecting to the WebSocket server at {}...",
80        &url.to_string()
81    );
82
83    let (ws_stream, _) = tokio_tungstenite::connect_async(&url.to_string()).await?;
84    println!("Successfully connected to the WebSocket server.");
85
86    Ok(ws_stream.split())
87}
88
/// Connects to `url` **without verifying the server certificate** and splits
/// the resulting WebSocket into write and read halves.
///
/// Every certificate and handshake signature is accepted (see `NoVerifier`),
/// so the server's identity is not authenticated. Only use this against
/// servers you control (e.g. self-signed test deployments).
///
/// # Errors
/// Returns an error if connecting or the WebSocket handshake fails.
#[cfg(feature = "tls")]
pub async fn initialize_insecure_tls(
    url: Url,
) -> anyhow::Result<(
    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
)> {
    println!("Connecting to the WebSocket server at {}...", url);

    // Plug the accept-everything verifier straight into the builder instead of
    // building a verifier around an empty root store and swapping it out later.
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
        .with_no_client_auth();
    let connector = tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(tls_config));

    let (socket, _handshake_response) =
        tokio_tungstenite::connect_async_tls_with_config(url, None, true, Some(connector)).await?;
    println!("Successfully connected to the WebSocket server.");

    let (sink, stream) = socket.split();
    Ok((sink, stream))
}
119
/// Connects to `url` over TLS, trusting **only** the certificates contained in
/// `private_chain_bytes` (e.g. a private CA), and splits the resulting
/// WebSocket into write and read halves.
///
/// # Arguments
/// * `url` - WebSocket endpoint to connect to.
/// * `private_chain_bytes` - PEM data; every `CERTIFICATE` block becomes a
///   trust anchor.
///
/// # Errors
/// * The PEM data cannot be parsed.
/// * No certificate in it is usable as a trust anchor. Failing here gives a
///   clear message instead of an opaque handshake rejection caused by an empty
///   root store.
/// * Connecting or the WebSocket handshake fails.
#[cfg(feature = "tls")]
pub async fn initialize_private_tls(
    url: Url,
    private_chain_bytes: &[u8],
) -> anyhow::Result<(
    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
)> {
    println!("Connecting to the WebSocket server at {}...", url);

    let mut cert_cursor = std::io::Cursor::new(private_chain_bytes);
    let cert_chain = rustls_pemfile::certs(&mut cert_cursor)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| anyhow::anyhow!("Error parsing certificate: {:?}", e))?;

    let mut root_cert_store = rustls::RootCertStore::empty();
    // `add_parsable_certificates` skips anything it cannot use. An empty store
    // would make every handshake fail, so report that up front.
    let (added, ignored) = root_cert_store.add_parsable_certificates(cert_chain);
    if added == 0 {
        return Err(anyhow::anyhow!(
            "No usable trust anchors in private certificate chain ({} certificate(s) rejected)",
            ignored
        ));
    }

    let config = rustls::ClientConfig::builder()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    let connector = tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config));

    let (ws_stream, _) =
        tokio_tungstenite::connect_async_tls_with_config(url, None, true, Some(connector)).await?;

    println!("Successfully connected to the WebSocket server.");

    Ok(ws_stream.split())
}
159
/// Opens a WebSocket connection to `uri`, choosing the TLS setup from
/// `ws_config`, and returns its write and read halves.
///
/// Selection order:
/// 1. `insecure == Some(true)` -> `initialize_insecure_tls` (no certificate checks).
/// 2. `private_chain_bytes == Some(..)` -> `initialize_private_tls`.
/// 3. Otherwise -> `initialize_default_websocket_connection`.
///
/// With no config at all, an extra notice is printed for unencrypted `ws://` URLs.
///
/// Only an explicit `Some(true)` disables verification. Previously any
/// `Some(_)` did, so `insecure: Some(false)` silently turned security off.
///
/// # Errors
/// Propagates the error of whichever connection routine is selected.
#[cfg(feature = "tls")]
pub async fn initialize(
    uri: Url,
    ws_config: Option<Wsconfig>,
) -> anyhow::Result<(
    SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>,
    SplitStream<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
)> {
    match ws_config {
        Some(Wsconfig {
            insecure: Some(true),
            ..
        }) => initialize_insecure_tls(uri).await,
        Some(Wsconfig {
            private_chain_bytes: Some(chain),
            ..
        }) => initialize_private_tls(uri, &chain).await,
        Some(_) => initialize_default_websocket_connection(uri).await,
        None => {
            if uri.scheme() == "ws" {
                println!(
                    "Connecting to the OPEN WebSocket server at {}...",
                    uri
                );
            }
            initialize_default_websocket_connection(uri).await
        }
    }
}
189
190#[cfg(not(feature = "tls"))]
191pub async fn websocket_handler(
192    uri: Url,
193    ws_channel_receiver: AsyncReceiver<String>,
194    events_channel_sender: AsyncSender<String>,
195) -> anyhow::Result<()> {
196    let (mut ws_sink, mut ws_stream) = initialize_default_websocket_connection(uri).await?;
197
198    let tx_loop = tokio::spawn(async move {
199        while let Ok(msg) = ws_channel_receiver.recv().await {
200            ws_sink.send(Message::Text(msg)).await?;
201        }
202        Ok::<(), anyhow::Error>(())
203    });
204
205    let rx_loop = tokio::spawn(async move {
206        while let Some(msg) = ws_stream.next().await {
207            match msg {
208                Ok(Message::Text(text)) => {
209                    events_channel_sender.send(text).await?;
210                }
211                Ok(_) => {}
212                Err(e) => {
213                    return Err(anyhow::anyhow!("Error receiving message: {}", e));
214                }
215            }
216        }
217        Ok::<(), anyhow::Error>(())
218    });
219
220    _ = tokio::try_join!(tx_loop, rx_loop)?;
221    Err(anyhow::anyhow!("WebSocket handler exited!"))
222}
223
/// Runs one WebSocket session against `uri` (connection set up by `initialize`
/// according to `ws_config`), pumping messages in both directions until the
/// connection stops being usable.
///
/// * Outbound: every `String` received from `ws_channel_receiver` is sent as a
///   text frame.
/// * Inbound: every text frame is forwarded to `events_channel_sender`; other
///   frame types are ignored.
///
/// The session ends as soon as the inbound pump stops (stream ended or read
/// error) or the outbound pump fails (send error). If the outbound channel is
/// merely closed, inbound events keep flowing until the connection ends. Both
/// pump tasks are aborted before returning so neither outlives the session.
///
/// # Errors
/// Always returns `Err`. This is the connection error, or the pump error
/// wrapped with "WebSocket handler exited!", or just that message when the
/// session ended cleanly.
#[cfg(feature = "tls")]
pub async fn websocket_handler(
    uri: Url,
    ws_channel_receiver: AsyncReceiver<String>,
    events_channel_sender: AsyncSender<String>,
    ws_config: Option<Wsconfig>,
) -> anyhow::Result<()> {
    let (mut ws_sink, mut ws_stream) = initialize(uri, ws_config).await?;

    // Outbound pump: channel -> socket.
    let mut tx_loop = tokio::spawn(async move {
        while let Ok(msg) = ws_channel_receiver.recv().await {
            ws_sink.send(Message::Text(msg)).await?;
        }
        Ok::<(), anyhow::Error>(())
    });

    // Inbound pump: socket -> events channel.
    let mut rx_loop = tokio::spawn(async move {
        while let Some(msg) = ws_stream.next().await {
            match msg {
                Ok(Message::Text(text)) => {
                    events_channel_sender.send(text).await?;
                }
                Ok(_) => {}
                Err(e) => {
                    return Err(anyhow::anyhow!("Error receiving message: {}", e));
                }
            }
        }
        Ok::<(), anyhow::Error>(())
    });

    // React to whichever pump ends first. Joining both would leave the handler
    // parked in the outbound pump's `recv()` long after the socket died,
    // delaying reconnection until the next outbound message arrives.
    let outcome = tokio::select! {
        res = &mut rx_loop => res,
        res = &mut tx_loop => match res {
            // Outbound channel closed: nothing more to send, but keep
            // delivering inbound events until the connection itself ends.
            Ok(Ok(())) => (&mut rx_loop).await,
            failed => failed,
        },
    };

    // Aborting an already-finished task is a no-op.
    tx_loop.abort();
    rx_loop.abort();

    let session = match outcome {
        Ok(pump_result) => pump_result,
        Err(join_error) => Err(anyhow::Error::from(join_error)),
    };
    match session {
        Ok(()) => Err(anyhow::anyhow!("WebSocket handler exited!")),
        Err(e) => Err(e.context("WebSocket handler exited!")),
    }
}
258
259pub fn create_channel() -> (AsyncSender<String>, AsyncReceiver<String>) {
260    let (ws_channel_sender, ws_channel_receiver) = kanal::unbounded_async();
261    (ws_channel_sender, ws_channel_receiver)
262}
263
/// Keeps a WebSocket session to `uri` alive forever.
///
/// Repeatedly runs `websocket_handler` (TLS setup chosen by `ws_config`). Each
/// time a session ends, its error goes to stderr and the routine sleeps for a
/// fixed 60 seconds before reconnecting. That delay also applies when a
/// connection attempt fails.
///
/// The receiver, sender and config are cloned for every session, so the
/// caller's channel endpoints remain valid across reconnects.
///
/// # Errors
/// Never returns. The loop has no exit, and the `anyhow::Result` return type
/// is kept for API compatibility.
#[cfg(feature = "tls")]
pub async fn start_websocket(
    uri: Url,
    ws_channel_receiver: AsyncReceiver<String>,
    events_channel_sender: AsyncSender<String>,
    ws_config: Option<Wsconfig>,
) -> anyhow::Result<()> {
    // Pause between the end of one session and the next connection attempt.
    const RECONNECT_DELAY: Duration = Duration::from_secs(60);

    println!("start websocket routine");

    loop {
        if let Err(e) = websocket_handler(
            uri.clone(),
            ws_channel_receiver.clone(),
            events_channel_sender.clone(),
            ws_config.clone(),
        )
        .await
        {
            eprintln!("websocket error {:?}", e);
        }

        println!(
            "restarting websocket routine in {} seconds",
            RECONNECT_DELAY.as_secs()
        );
        tokio::time::sleep(RECONNECT_DELAY).await;
    }
}
295
296#[cfg(not(feature = "tls"))]
297pub async fn start_websocket(
298    uri: Url,
299    ws_channel_receiver: AsyncReceiver<String>,
300    events_channel_sender: AsyncSender<String>,
301) -> anyhow::Result<()> {
302    let timeout_in_seconds = 60;
303    println!("start websocket routine");
304
305    loop {
306        let t = websocket_handler(
307            uri.clone(),
308            ws_channel_receiver.clone(),
309            events_channel_sender.clone(),
310        )
311        .await;
312
313        if t.is_err() {
314            let msg = format!("websocket error {:?}", t.unwrap_err());
315            eprintln!("{}", msg);
316        }
317
318        println!(
319            "restarting websocket routine in {} seconds",
320            timeout_in_seconds
321        );
322        tokio::time::sleep(Duration::from_secs(timeout_in_seconds)).await;
323    }
324}