alloy_transport_ws/
native.rs1use crate::WsBackend;
2use alloy_pubsub::PubSubConnect;
3use alloy_transport::{utils::Spawnable, Authorization, TransportErrorKind, TransportResult};
4use futures::{SinkExt, StreamExt};
5use serde_json::value::RawValue;
6use std::time::Duration;
7pub use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
8use tokio_tungstenite::{
9 tungstenite::{self, client::IntoClientRequest, Message},
10 MaybeTlsStream, WebSocketStream,
11};
12
13#[cfg(target_arch = "wasm32")]
14use wasmtimer::tokio::sleep;
15
16#[cfg(not(target_arch = "wasm32"))]
17use tokio::time::sleep;
18
19type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
20
/// Keepalive interval, in seconds, used by the backend task's ping timer.
const KEEPALIVE: u64 = 10;
22
23#[derive(Clone, Debug)]
25pub struct WsConnect {
26 pub url: String,
28 pub auth: Option<Authorization>,
30 pub config: Option<WebSocketConfig>,
32}
33
34impl WsConnect {
35 pub fn new<S: Into<String>>(url: S) -> Self {
37 Self { url: url.into(), auth: None, config: None }
38 }
39
40 pub fn with_auth(mut self, auth: Authorization) -> Self {
42 self.auth = Some(auth);
43 self
44 }
45
46 pub const fn with_config(mut self, config: WebSocketConfig) -> Self {
48 self.config = Some(config);
49 self
50 }
51}
52
53impl IntoClientRequest for WsConnect {
54 fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
55 let mut request: http::Request<()> = self.url.into_client_request()?;
56 if let Some(auth) = self.auth {
57 let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
58 auth_value.set_sensitive(true);
59
60 request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
61 }
62
63 request.into_client_request()
64 }
65}
66
67impl PubSubConnect for WsConnect {
68 fn is_local(&self) -> bool {
69 alloy_transport::utils::guess_local_url(&self.url)
70 }
71
72 async fn connect(&self) -> TransportResult<alloy_pubsub::ConnectionHandle> {
73 let request = self.clone().into_client_request();
74 let req = request.map_err(TransportErrorKind::custom)?;
75 let (socket, _) = tokio_tungstenite::connect_async_with_config(req, self.config, false)
76 .await
77 .map_err(TransportErrorKind::custom)?;
78
79 let (handle, interface) = alloy_pubsub::ConnectionHandle::new();
80 let backend = WsBackend { socket, interface };
81
82 backend.spawn();
83
84 Ok(handle)
85 }
86}
87
88impl WsBackend<TungsteniteStream> {
89 #[allow(clippy::result_unit_err)]
91 pub fn handle(&mut self, msg: Message) -> Result<(), ()> {
92 match msg {
93 Message::Text(text) => self.handle_text(&text),
94 Message::Close(frame) => {
95 if frame.is_some() {
96 error!(?frame, "Received close frame with data");
97 } else {
98 error!("WS server has gone away");
99 }
100 Err(())
101 }
102 Message::Binary(_) => {
103 error!("Received binary message, expected text");
104 Err(())
105 }
106 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(()),
107 }
108 }
109
110 pub async fn send(&mut self, msg: Box<RawValue>) -> Result<(), tungstenite::Error> {
112 self.socket.send(Message::Text(msg.get().to_owned().into())).await
113 }
114
115 pub fn spawn(mut self) {
117 let fut = async move {
118 let mut errored = false;
119 let mut expecting_pong = false;
120 let keepalive = sleep(Duration::from_secs(KEEPALIVE));
121 tokio::pin!(keepalive);
122 loop {
123 tokio::select! {
133 biased;
134 inst = self.interface.recv_from_frontend() => {
138 match inst {
139 Some(msg) => {
140 keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
142 if let Err(err) = self.send(msg).await {
143 error!(%err, "WS connection error");
144 errored = true;
145 break
146 }
147 },
148 None => {
150 break
151 },
152 }
153 },
154 _ = &mut keepalive => {
157 if expecting_pong {
160 error!("WS server missed a pong");
161 errored = true;
162 break
163 }
164 keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
166 if let Err(err) = self.socket.send(Message::Ping(Default::default())).await {
167 error!(%err, "WS connection error");
168 errored = true;
169 break
170 }
171 expecting_pong = true;
174 }
175 resp = self.socket.next() => {
176 match resp {
177 Some(Ok(item)) => {
178 if item.is_pong() {
179 expecting_pong = false;
180 }
181 errored = self.handle(item).is_err();
182 if errored { break }
183 },
184 Some(Err(err)) => {
185 error!(%err, "WS connection error");
186 errored = true;
187 break
188 }
189 None => {
190 error!("WS server has gone away");
191 errored = true;
192 break
193 },
194 }
195 }
196 }
197 }
198 if errored {
199 self.interface.close_with_error();
200 }
201 };
202 fut.spawn_task()
203 }
204}