nimiq_jsonrpc_client/
websocket.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    str::FromStr,
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc,
8    },
9};
10
11use async_trait::async_trait;
12use base64::Engine;
13use futures::{
14    sink::SinkExt,
15    stream::{BoxStream, SplitSink, StreamExt},
16};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use thiserror::Error;
20use tokio::{
21    net::TcpStream,
22    sync::{mpsc, oneshot, RwLock},
23};
24use tokio_tungstenite::tungstenite::{
25    client::IntoClientRequest,
26    protocol::{frame::coding::CloseCode, CloseFrame},
27    Message,
28};
29use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
30use url::Url;
31
32use nimiq_jsonrpc_core::{
33    Request, RequestOrResponse, Response, SubscriptionId, SubscriptionMessage,
34};
35
36use crate::{Client, Credentials};
37
38/// Error type returned by websocket client.
39#[derive(Debug, Error)]
40pub enum Error {
41    /// HTTP error
42    #[error("HTTP protocol error: {0}")]
43    HTTP(#[from] http::Error),
44
45    /// Websocket error
46    #[error("Websocket protocol error: {0}")]
47    Websocket(#[from] tokio_tungstenite::tungstenite::Error),
48
49    /// JSON-RPC protocol error
50    #[error("JSON-RPC protocol error: {0}")]
51    JsonRpc(#[from] nimiq_jsonrpc_core::Error),
52
53    /// JSON error
54    #[error("JSON error: {0}")]
55    Json(#[from] serde_json::Error),
56
57    /// Error in the internal oneshot channel.
58    #[error("{0}")]
59    OneshotRecv(#[from] oneshot::error::RecvError),
60
61    /// Error in the internal MPSC channel.
62    #[error("{0}")]
63    MpscSend(#[from] mpsc::error::SendError<SubscriptionMessage<Value>>),
64}
65
66type StreamsMap = HashMap<SubscriptionId, mpsc::Sender<SubscriptionMessage<Value>>>;
67type RequestsMap = HashMap<u64, oneshot::Sender<Response>>;
68
69/// A websocket JSON-RPC client.
70///
71pub struct WebsocketClient {
72    streams: Arc<RwLock<StreamsMap>>,
73    requests: Arc<RwLock<RequestsMap>>,
74    sender: RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
75    next_id: AtomicU64,
76}
77
78impl WebsocketClient {
79    /// Creates a new JSON-RPC websocket client.
80    ///
81    /// # Arguments
82    ///
83    ///  - `url`: The URL of the websocket endpoint (.e.g `ws://localhost:8000/ws`)
84    ///  - `basic_auth`: Credentials for HTTP basic auth.
85    ///
86    pub async fn new(url: Url, basic_auth: Option<Credentials>) -> Result<Self, Error> {
87        let request = {
88            let uri: http::Uri = url.to_string().parse().unwrap();
89            let mut request = uri.into_client_request()?;
90
91            if let Some(basic_auth) = basic_auth {
92                let header_value = format!(
93                    "Basic {}",
94                    base64::prelude::BASE64_STANDARD
95                        .encode(format!("{}:{}", basic_auth.username, basic_auth.password.0))
96                );
97                request.headers_mut().append(
98                    "Authorization",
99                    header_value
100                        .parse()
101                        .map_err(|e| Error::HTTP(http::Error::from(e)))?,
102                );
103            }
104
105            request
106        };
107
108        log::debug!("HTTP request: {:?}", request);
109
110        let (ws_stream, _) = connect_async(request).await?;
111
112        let (ws_tx, mut ws_rx) = ws_stream.split();
113
114        let streams = Arc::new(RwLock::new(HashMap::new()));
115        let requests = Arc::new(RwLock::new(HashMap::new()));
116
117        {
118            let streams = Arc::clone(&streams);
119            let requests = Arc::clone(&requests);
120
121            tokio::spawn(async move {
122                while let Some(message_result) = ws_rx.next().await {
123                    match message_result {
124                        Ok(message) => {
125                            if let Err(e) =
126                                Self::handle_websocket_message(&streams, &requests, message).await
127                            {
128                                log::error!("{}", e);
129                            }
130                        }
131                        Err(e) => {
132                            log::error!("{}", e);
133                        }
134                    }
135                }
136            });
137        }
138
139        Ok(Self {
140            next_id: AtomicU64::new(1),
141            sender: RwLock::new(ws_tx),
142            streams,
143            requests,
144        })
145    }
146
147    /// Creates a new JSON-RPC websocket client.
148    ///
149    /// # Arguments
150    ///
151    ///  - `url`: The URL of the websocket endpoint (.e.g `ws://localhost:8000/ws`)
152    ///
153    pub async fn with_url(url: Url) -> Result<Self, Error> {
154        Self::new(url, None).await
155    }
156
157    async fn handle_websocket_message(
158        streams: &Arc<RwLock<StreamsMap>>,
159        requests: &Arc<RwLock<RequestsMap>>,
160        message: Message,
161    ) -> Result<(), Error> {
162        // FIXME: This will also accept pings
163        let data = message.into_text()?;
164
165        log::trace!("Received message: {:?}", data);
166
167        let message = RequestOrResponse::from_str(&data)?;
168
169        match message {
170            RequestOrResponse::Request(request) => {
171                if request.id.is_some() {
172                    log::error!("Received unexpected request, which is not a notification.");
173                } else if let Some(params) = request.params {
174                    let message: SubscriptionMessage<Value> = serde_json::from_value(params)
175                        .expect("Failed to deserialize request parameters");
176
177                    let mut streams = streams.write().await;
178                    if let Some(tx) = streams.get_mut(&message.subscription) {
179                        tx.send(message).await?;
180                    } else {
181                        log::error!(
182                            "Notification for unknown stream ID: {}",
183                            message.subscription
184                        );
185                    }
186                } else {
187                    log::error!("No 'params' field in notification.");
188                }
189            }
190            RequestOrResponse::Response(response) => {
191                let mut requests = requests.write().await;
192
193                if let Some(tx) = response.id.as_u64().and_then(|id| requests.remove(&id)) {
194                    drop(requests);
195                    tx.send(response).ok();
196                } else {
197                    log::error!("Response for unknown request ID: {}", response.id);
198                }
199            }
200        }
201
202        Ok(())
203    }
204}
205
206#[async_trait]
207impl Client for WebsocketClient {
208    type Error = Error;
209
210    async fn send_request<P, R>(&self, method: &str, params: &P) -> Result<R, Self::Error>
211    where
212        P: Serialize + Debug + Send + Sync,
213        R: for<'de> Deserialize<'de> + Debug + Send + Sync,
214    {
215        let request_id = self.next_id.fetch_add(1, Ordering::SeqCst);
216        let request = Request::build(method.to_owned(), Some(params), Some(&request_id))
217            .expect("Failed to serialize JSON-RPC request.");
218
219        log::debug!("Sending request: {:?}", request);
220
221        self.sender
222            .write()
223            .await
224            .send(Message::binary(serde_json::to_vec(&request)?))
225            .await?;
226
227        let (tx, rx) = oneshot::channel();
228
229        let mut requests = self.requests.write().await;
230        requests.insert(request_id, tx);
231        drop(requests);
232
233        let response = rx.await?;
234        log::debug!("Received response: {:?}", response);
235
236        Ok(response.into_result()?)
237    }
238
239    async fn connect_stream<T: Unpin + 'static>(&self, id: SubscriptionId) -> BoxStream<'static, T>
240    where
241        T: for<'de> Deserialize<'de> + Debug + Send + Sync,
242    {
243        let (tx, mut rx) = mpsc::channel(16);
244
245        self.streams.write().await.insert(id, tx);
246
247        let stream = async_stream::stream! {
248            while let Some(message) = rx.recv().await {
249                yield serde_json::from_value(message.result).unwrap();
250            }
251        };
252
253        stream.boxed()
254    }
255
256    async fn disconnect_stream(&self, id: SubscriptionId) -> Result<(), Self::Error> {
257        if let Some(tx) = self.streams.write().await.remove(&id) {
258            log::debug!("Closing stream of subscription ID: {}", id);
259            drop(tx);
260        } else {
261            log::error!("Unknown subscription ID: {}", id);
262        }
263
264        Ok(())
265    }
266
267    /// Close the websocket connection
268    async fn close(&self) {
269        // Try to send the close message
270        // We don't do anything if it fails
271        let _ = self
272            .sender
273            .write()
274            .await
275            .send(Message::Close(Some(CloseFrame {
276                code: CloseCode::Normal,
277                reason: "".into(),
278            })))
279            .await;
280    }
281}