1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 str::FromStr,
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc,
8 },
9};
10
11use async_trait::async_trait;
12use base64::Engine;
13use futures::{
14 sink::SinkExt,
15 stream::{BoxStream, SplitSink, StreamExt},
16};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use thiserror::Error;
20use tokio::{
21 net::TcpStream,
22 sync::{mpsc, oneshot, RwLock},
23};
24use tokio_tungstenite::tungstenite::{
25 client::IntoClientRequest,
26 protocol::{frame::coding::CloseCode, CloseFrame},
27 Message,
28};
29use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
30use url::Url;
31
32use nimiq_jsonrpc_core::{
33 Request, RequestOrResponse, Response, SubscriptionId, SubscriptionMessage,
34};
35
36use crate::{Client, Credentials};
37
38#[derive(Debug, Error)]
40pub enum Error {
41 #[error("HTTP protocol error: {0}")]
43 HTTP(#[from] http::Error),
44
45 #[error("Websocket protocol error: {0}")]
47 Websocket(#[from] tokio_tungstenite::tungstenite::Error),
48
49 #[error("JSON-RPC protocol error: {0}")]
51 JsonRpc(#[from] nimiq_jsonrpc_core::Error),
52
53 #[error("JSON error: {0}")]
55 Json(#[from] serde_json::Error),
56
57 #[error("{0}")]
59 OneshotRecv(#[from] oneshot::error::RecvError),
60
61 #[error("{0}")]
63 MpscSend(#[from] mpsc::error::SendError<SubscriptionMessage<Value>>),
64}
65
66type StreamsMap = HashMap<SubscriptionId, mpsc::Sender<SubscriptionMessage<Value>>>;
67type RequestsMap = HashMap<u64, oneshot::Sender<Response>>;
68
/// JSON-RPC client that talks to a server over a single websocket connection.
///
/// Outgoing messages are written through `sender`. A background task spawned in
/// [`WebsocketClient::new`] reads incoming messages and routes responses through
/// `requests` and subscription notifications through `streams`.
pub struct WebsocketClient {
    /// Subscription ID -> sender feeding that subscription's stream.
    /// Shared (via `Arc`) with the background reader task.
    streams: Arc<RwLock<StreamsMap>>,
    /// Request ID -> channel waiting for that request's response.
    /// Shared (via `Arc`) with the background reader task.
    requests: Arc<RwLock<RequestsMap>>,
    /// Write half of the websocket. The lock serializes outgoing messages;
    /// NOTE(review): it is only ever write-locked, so a `Mutex` would express intent better.
    sender: RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    /// Source of JSON-RPC request IDs (initialized to 1 in `new`).
    next_id: AtomicU64,
}
77
78impl WebsocketClient {
79 pub async fn new(url: Url, basic_auth: Option<Credentials>) -> Result<Self, Error> {
87 let request = {
88 let uri: http::Uri = url.to_string().parse().unwrap();
89 let mut request = uri.into_client_request()?;
90
91 if let Some(basic_auth) = basic_auth {
92 let header_value = format!(
93 "Basic {}",
94 base64::prelude::BASE64_STANDARD
95 .encode(format!("{}:{}", basic_auth.username, basic_auth.password.0))
96 );
97 request.headers_mut().append(
98 "Authorization",
99 header_value
100 .parse()
101 .map_err(|e| Error::HTTP(http::Error::from(e)))?,
102 );
103 }
104
105 request
106 };
107
108 log::debug!("HTTP request: {:?}", request);
109
110 let (ws_stream, _) = connect_async(request).await?;
111
112 let (ws_tx, mut ws_rx) = ws_stream.split();
113
114 let streams = Arc::new(RwLock::new(HashMap::new()));
115 let requests = Arc::new(RwLock::new(HashMap::new()));
116
117 {
118 let streams = Arc::clone(&streams);
119 let requests = Arc::clone(&requests);
120
121 tokio::spawn(async move {
122 while let Some(message_result) = ws_rx.next().await {
123 match message_result {
124 Ok(message) => {
125 if let Err(e) =
126 Self::handle_websocket_message(&streams, &requests, message).await
127 {
128 log::error!("{}", e);
129 }
130 }
131 Err(e) => {
132 log::error!("{}", e);
133 }
134 }
135 }
136 });
137 }
138
139 Ok(Self {
140 next_id: AtomicU64::new(1),
141 sender: RwLock::new(ws_tx),
142 streams,
143 requests,
144 })
145 }
146
147 pub async fn with_url(url: Url) -> Result<Self, Error> {
154 Self::new(url, None).await
155 }
156
157 async fn handle_websocket_message(
158 streams: &Arc<RwLock<StreamsMap>>,
159 requests: &Arc<RwLock<RequestsMap>>,
160 message: Message,
161 ) -> Result<(), Error> {
162 let data = message.into_text()?;
164
165 log::trace!("Received message: {:?}", data);
166
167 let message = RequestOrResponse::from_str(&data)?;
168
169 match message {
170 RequestOrResponse::Request(request) => {
171 if request.id.is_some() {
172 log::error!("Received unexpected request, which is not a notification.");
173 } else if let Some(params) = request.params {
174 let message: SubscriptionMessage<Value> = serde_json::from_value(params)
175 .expect("Failed to deserialize request parameters");
176
177 let mut streams = streams.write().await;
178 if let Some(tx) = streams.get_mut(&message.subscription) {
179 tx.send(message).await?;
180 } else {
181 log::error!(
182 "Notification for unknown stream ID: {}",
183 message.subscription
184 );
185 }
186 } else {
187 log::error!("No 'params' field in notification.");
188 }
189 }
190 RequestOrResponse::Response(response) => {
191 let mut requests = requests.write().await;
192
193 if let Some(tx) = response.id.as_u64().and_then(|id| requests.remove(&id)) {
194 drop(requests);
195 tx.send(response).ok();
196 } else {
197 log::error!("Response for unknown request ID: {}", response.id);
198 }
199 }
200 }
201
202 Ok(())
203 }
204}
205
206#[async_trait]
207impl Client for WebsocketClient {
208 type Error = Error;
209
210 async fn send_request<P, R>(&self, method: &str, params: &P) -> Result<R, Self::Error>
211 where
212 P: Serialize + Debug + Send + Sync,
213 R: for<'de> Deserialize<'de> + Debug + Send + Sync,
214 {
215 let request_id = self.next_id.fetch_add(1, Ordering::SeqCst);
216 let request = Request::build(method.to_owned(), Some(params), Some(&request_id))
217 .expect("Failed to serialize JSON-RPC request.");
218
219 log::debug!("Sending request: {:?}", request);
220
221 self.sender
222 .write()
223 .await
224 .send(Message::binary(serde_json::to_vec(&request)?))
225 .await?;
226
227 let (tx, rx) = oneshot::channel();
228
229 let mut requests = self.requests.write().await;
230 requests.insert(request_id, tx);
231 drop(requests);
232
233 let response = rx.await?;
234 log::debug!("Received response: {:?}", response);
235
236 Ok(response.into_result()?)
237 }
238
239 async fn connect_stream<T: Unpin + 'static>(&self, id: SubscriptionId) -> BoxStream<'static, T>
240 where
241 T: for<'de> Deserialize<'de> + Debug + Send + Sync,
242 {
243 let (tx, mut rx) = mpsc::channel(16);
244
245 self.streams.write().await.insert(id, tx);
246
247 let stream = async_stream::stream! {
248 while let Some(message) = rx.recv().await {
249 yield serde_json::from_value(message.result).unwrap();
250 }
251 };
252
253 stream.boxed()
254 }
255
256 async fn disconnect_stream(&self, id: SubscriptionId) -> Result<(), Self::Error> {
257 if let Some(tx) = self.streams.write().await.remove(&id) {
258 log::debug!("Closing stream of subscription ID: {}", id);
259 drop(tx);
260 } else {
261 log::error!("Unknown subscription ID: {}", id);
262 }
263
264 Ok(())
265 }
266
267 async fn close(&self) {
269 let _ = self
272 .sender
273 .write()
274 .await
275 .send(Message::Close(Some(CloseFrame {
276 code: CloseCode::Normal,
277 reason: "".into(),
278 })))
279 .await;
280 }
281}