cs_mwc_web3/transports/
ws.rs

1//! WebSocket Transport
2
3use self::compat::{TcpStream, TlsStream};
4use crate::{api::SubscriptionId, error, helpers, rpc, BatchTransport, DuplexTransport, Error, RequestId, Transport};
5use futures::{
6    channel::{mpsc, oneshot},
7    task::{Context, Poll},
8    AsyncRead, AsyncWrite, Future, FutureExt, Stream, StreamExt,
9};
10use soketto::{
11    connection,
12    handshake::{Client, ServerResponse},
13};
14use std::{
15    collections::BTreeMap,
16    fmt,
17    marker::Unpin,
18    pin::Pin,
19    sync::{atomic, Arc},
20};
21use url::Url;
22
23impl From<soketto::handshake::Error> for Error {
24    fn from(err: soketto::handshake::Error) -> Self {
25        Error::Transport(format!("Handshake Error: {:?}", err))
26    }
27}
28
29impl From<connection::Error> for Error {
30    fn from(err: connection::Error) -> Self {
31        Error::Transport(format!("Connection Error: {:?}", err))
32    }
33}
34
35type SingleResult = error::Result<rpc::Value>;
36type BatchResult = error::Result<Vec<SingleResult>>;
37type Pending = oneshot::Sender<BatchResult>;
38type Subscription = mpsc::UnboundedSender<rpc::Value>;
39
/// Stream, either plain TCP or TLS.
///
/// `P` is the unencrypted stream type and `T` the encrypted one.
enum MaybeTlsStream<P, T> {
    /// Unencrypted socket stream.
    Plain(P),
    /// Encrypted socket stream.
    // Only constructed when a TLS feature is compiled in, so it may otherwise be unused.
    #[allow(dead_code)]
    Tls(T),
}
48
49impl<P, T> AsyncRead for MaybeTlsStream<P, T>
50where
51    P: AsyncRead + AsyncWrite + Unpin,
52    T: AsyncRead + AsyncWrite + Unpin,
53{
54    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, std::io::Error>> {
55        match self.get_mut() {
56            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
57            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
58        }
59    }
60}
61
62impl<P, T> AsyncWrite for MaybeTlsStream<P, T>
63where
64    P: AsyncRead + AsyncWrite + Unpin,
65    T: AsyncRead + AsyncWrite + Unpin,
66{
67    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
68        match self.get_mut() {
69            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
70            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
71        }
72    }
73
74    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
75        match self.get_mut() {
76            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
77            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
78        }
79    }
80
81    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
82        match self.get_mut() {
83            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_close(cx),
84            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_close(cx),
85        }
86    }
87}
88
89struct WsServerTask {
90    pending: BTreeMap<RequestId, Pending>,
91    subscriptions: BTreeMap<SubscriptionId, Subscription>,
92    sender: connection::Sender<MaybeTlsStream<TcpStream, TlsStream>>,
93    receiver: connection::Receiver<MaybeTlsStream<TcpStream, TlsStream>>,
94}
95
96impl WsServerTask {
97    /// Create new WebSocket transport.
98    pub async fn new(url: &str) -> error::Result<Self> {
99        let url = Url::parse(url)?;
100
101        let scheme = match url.scheme() {
102            s if s == "ws" || s == "wss" => s,
103            s => return Err(error::Error::Transport(format!("Wrong scheme: {}", s))),
104        };
105        let host = match url.host_str() {
106            Some(s) => s,
107            None => return Err(error::Error::Transport("Wrong host name".to_string())),
108        };
109        let port = url.port().unwrap_or(if scheme == "ws" { 80 } else { 443 });
110        let addrs = format!("{}:{}", host, port);
111
112        let stream = compat::raw_tcp_stream(addrs).await?;
113        stream.set_nodelay(true)?;
114        let socket = if scheme == "wss" {
115            #[cfg(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std"))]
116            {
117                let stream = async_native_tls::connect(host, stream).await?;
118                MaybeTlsStream::Tls(compat::compat(stream))
119            }
120            #[cfg(not(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std")))]
121            panic!("The library was compiled without TLS support. Enable ws-tls-tokio or ws-tls-async-std feature.");
122        } else {
123            let stream = compat::compat(stream);
124            MaybeTlsStream::Plain(stream)
125        };
126
127        let mut client = Client::new(socket, host, url.path());
128        let handshake = client.handshake();
129        let (sender, receiver) = match handshake.await? {
130            ServerResponse::Accepted { .. } => client.into_builder().finish(),
131            ServerResponse::Redirect { status_code, location } => {
132                return Err(error::Error::Transport(format!(
133                    "(code: {}) Unable to follow redirects: {}",
134                    status_code, location
135                )))
136            }
137            ServerResponse::Rejected { status_code } => {
138                return Err(error::Error::Transport(format!(
139                    "(code: {}) Connection rejected.",
140                    status_code
141                )))
142            }
143        };
144
145        Ok(Self {
146            pending: Default::default(),
147            subscriptions: Default::default(),
148            sender,
149            receiver,
150        })
151    }
152
153    async fn into_task(self, requests: mpsc::UnboundedReceiver<TransportMessage>) {
154        let Self {
155            receiver,
156            mut sender,
157            mut pending,
158            mut subscriptions,
159        } = self;
160
161        let receiver = as_data_stream(receiver).fuse();
162        let requests = requests.fuse();
163        pin_mut!(receiver);
164        pin_mut!(requests);
165        loop {
166            select! {
167                msg = requests.next() => match msg {
168                    Some(TransportMessage::Request { id, request, sender: tx }) => {
169                        if pending.insert(id.clone(), tx).is_some() {
170                            log::warn!("Replacing a pending request with id {:?}", id);
171                        }
172                        let res = sender.send_text(request).await;
173                        let res2 = sender.flush().await;
174                        if let Err(e) = res.and(res2) {
175                            // TODO [ToDr] Re-connect.
176                            log::error!("WS connection error: {:?}", e);
177                            pending.remove(&id);
178                        }
179                    }
180                    Some(TransportMessage::Subscribe { id, sink }) => {
181                        if subscriptions.insert(id.clone(), sink).is_some() {
182                            log::warn!("Replacing already-registered subscription with id {:?}", id);
183                        }
184                    }
185                    Some(TransportMessage::Unsubscribe { id }) => {
186                        if subscriptions.remove(&id).is_none() {
187                            log::warn!("Unsubscribing from non-existent subscription with id {:?}", id);
188                        }
189                    }
190                    None => {}
191                },
192                res = receiver.next() => match res {
193                    Some(Ok(data)) => {
194                        handle_message(&data, &subscriptions, &mut pending);
195                    },
196                    Some(Err(e)) => {
197                        log::error!("WS connection error: {:?}", e);
198                        break;
199                    },
200                    None => break,
201                },
202                complete => break,
203            }
204        }
205    }
206}
207
208fn as_data_stream<T: Unpin + futures::AsyncRead + futures::AsyncWrite>(
209    receiver: soketto::connection::Receiver<T>,
210) -> impl Stream<Item = Result<Vec<u8>, soketto::connection::Error>> {
211    futures::stream::unfold(receiver, |mut receiver| async move {
212        let mut data = Vec::new();
213        Some(match receiver.receive_data(&mut data).await {
214            Ok(_) => (Ok(data), receiver),
215            Err(e) => (Err(e), receiver),
216        })
217    })
218}
219
220fn handle_message(
221    data: &[u8],
222    subscriptions: &BTreeMap<SubscriptionId, Subscription>,
223    pending: &mut BTreeMap<RequestId, Pending>,
224) {
225    log::trace!("Message received: {:?}", data);
226    if let Ok(notification) = helpers::to_notification_from_slice(data) {
227        if let rpc::Params::Map(params) = notification.params {
228            let id = params.get("subscription");
229            let result = params.get("result");
230
231            if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
232                let id: SubscriptionId = id.clone().into();
233                if let Some(stream) = subscriptions.get(&id) {
234                    if let Err(e) = stream.unbounded_send(result.clone()) {
235                        log::error!("Error sending notification: {:?} (id: {:?}", e, id);
236                    }
237                } else {
238                    log::warn!("Got notification for unknown subscription (id: {:?})", id);
239                }
240            } else {
241                log::error!("Got unsupported notification (id: {:?})", id);
242            }
243        }
244    } else {
245        let response = helpers::to_response_from_slice(data);
246        let outputs = match response {
247            Ok(rpc::Response::Single(output)) => vec![output],
248            Ok(rpc::Response::Batch(outputs)) => outputs,
249            _ => vec![],
250        };
251
252        let id = match outputs.get(0) {
253            Some(&rpc::Output::Success(ref success)) => success.id.clone(),
254            Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(),
255            None => rpc::Id::Num(0),
256        };
257
258        if let rpc::Id::Num(num) = id {
259            if let Some(request) = pending.remove(&(num as usize)) {
260                log::trace!("Responding to (id: {:?}) with {:?}", num, outputs);
261                if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) {
262                    log::warn!("Sending a response to deallocated channel: {:?}", err);
263                }
264            } else {
265                log::warn!("Got response for unknown request (id: {:?})", num);
266            }
267        } else {
268            log::warn!("Got unsupported response (id: {:?})", id);
269        }
270    }
271}
272
273enum TransportMessage {
274    Request {
275        id: RequestId,
276        request: String,
277        sender: oneshot::Sender<BatchResult>,
278    },
279    Subscribe {
280        id: SubscriptionId,
281        sink: mpsc::UnboundedSender<rpc::Value>,
282    },
283    Unsubscribe {
284        id: SubscriptionId,
285    },
286}
287
288/// WebSocket transport
289#[derive(Clone)]
290pub struct WebSocket {
291    id: Arc<atomic::AtomicUsize>,
292    requests: mpsc::UnboundedSender<TransportMessage>,
293}
294
295impl fmt::Debug for WebSocket {
296    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
297        fmt.debug_struct("WebSocket").field("id", &self.id).finish()
298    }
299}
300
301impl WebSocket {
302    /// Create new WebSocket transport.
303    pub async fn new(url: &str) -> error::Result<Self> {
304        let id = Arc::new(atomic::AtomicUsize::new(1));
305        let task = WsServerTask::new(url).await?;
306        // TODO [ToDr] Not unbounded?
307        let (sink, stream) = mpsc::unbounded();
308        // Spawn background task for the transport.
309        #[cfg(feature = "ws-tokio")]
310        tokio::spawn(task.into_task(stream));
311        #[cfg(feature = "ws-async-std")]
312        async_std::task::spawn(task.into_task(stream));
313
314        Ok(Self { id, requests: sink })
315    }
316
317    fn send(&self, msg: TransportMessage) -> error::Result {
318        self.requests.unbounded_send(msg).map_err(dropped_err)
319    }
320
321    fn send_request(&self, id: RequestId, request: rpc::Request) -> error::Result<oneshot::Receiver<BatchResult>> {
322        let request = helpers::to_string(&request);
323        log::debug!("[{}] Calling: {}", id, request);
324        let (sender, receiver) = oneshot::channel();
325        self.send(TransportMessage::Request { id, request, sender })?;
326        Ok(receiver)
327    }
328}
329
330fn dropped_err<T>(_: T) -> error::Error {
331    Error::Transport("Cannot send request. Internal task finished.".into())
332}
333
334fn batch_to_single(response: BatchResult) -> SingleResult {
335    match response?.into_iter().next() {
336        Some(res) => res,
337        None => Err(Error::InvalidResponse("Expected single, got batch.".into())),
338    }
339}
340
341fn batch_to_batch(res: BatchResult) -> BatchResult {
342    res
343}
344
345enum ResponseState {
346    Receiver(Option<error::Result<oneshot::Receiver<BatchResult>>>),
347    Waiting(oneshot::Receiver<BatchResult>),
348}
349
350/// A WS resonse wrapper.
351pub struct Response<R, T> {
352    extract: T,
353    state: ResponseState,
354    _data: std::marker::PhantomData<R>,
355}
356
357impl<R, T> Response<R, T> {
358    fn new(response: error::Result<oneshot::Receiver<BatchResult>>, extract: T) -> Self {
359        Self {
360            extract,
361            state: ResponseState::Receiver(Some(response)),
362            _data: Default::default(),
363        }
364    }
365}
366
367impl<R, T> Future for Response<R, T>
368where
369    R: Unpin + 'static,
370    T: Fn(BatchResult) -> error::Result<R> + Unpin + 'static,
371{
372    type Output = error::Result<R>;
373    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
374        loop {
375            match self.state {
376                ResponseState::Receiver(ref mut res) => {
377                    let receiver = res.take().expect("Receiver state is active only once; qed")?;
378                    self.state = ResponseState::Waiting(receiver)
379                }
380                ResponseState::Waiting(ref mut future) => {
381                    let response = ready!(future.poll_unpin(cx)).map_err(dropped_err)?;
382                    return Poll::Ready((self.extract)(response));
383                }
384            }
385        }
386    }
387}
388
389impl Transport for WebSocket {
390    type Out = Response<rpc::Value, fn(BatchResult) -> SingleResult>;
391
392    fn prepare(&self, method: &str, params: Vec<rpc::Value>) -> (RequestId, rpc::Call) {
393        let id = self.id.fetch_add(1, atomic::Ordering::AcqRel);
394        let request = helpers::build_request(id, method, params);
395
396        (id, request)
397    }
398
399    fn send(&self, id: RequestId, request: rpc::Call) -> Self::Out {
400        let response = self.send_request(id, rpc::Request::Single(request));
401        Response::new(response, batch_to_single)
402    }
403}
404
405impl BatchTransport for WebSocket {
406    type Batch = Response<Vec<SingleResult>, fn(BatchResult) -> BatchResult>;
407
408    fn send_batch<T>(&self, requests: T) -> Self::Batch
409    where
410        T: IntoIterator<Item = (RequestId, rpc::Call)>,
411    {
412        let mut it = requests.into_iter();
413        let (id, first) = it.next().map(|x| (x.0, Some(x.1))).unwrap_or_else(|| (0, None));
414        let requests = first.into_iter().chain(it.map(|x| x.1)).collect();
415        let response = self.send_request(id, rpc::Request::Batch(requests));
416        Response::new(response, batch_to_batch)
417    }
418}
419
420impl DuplexTransport for WebSocket {
421    type NotificationStream = mpsc::UnboundedReceiver<rpc::Value>;
422
423    fn subscribe(&self, id: SubscriptionId) -> error::Result<Self::NotificationStream> {
424        // TODO [ToDr] Not unbounded?
425        let (sink, stream) = mpsc::unbounded();
426        self.send(TransportMessage::Subscribe { id, sink })?;
427        Ok(stream)
428    }
429
430    fn unsubscribe(&self, id: SubscriptionId) -> error::Result {
431        self.send(TransportMessage::Unsubscribe { id })
432    }
433}
434
/// Compatibility layer between async-std and tokio
#[cfg(feature = "ws-async-std")]
#[doc(hidden)]
pub mod compat {
    pub use async_std::net::{TcpListener, TcpStream};

    /// TLS stream type for async-std runtime.
    #[cfg(feature = "ws-tls-async-std")]
    pub type TlsStream = async_native_tls::TlsStream<TcpStream>;
    /// Dummy TLS stream type.
    #[cfg(not(feature = "ws-tls-async-std"))]
    pub type TlsStream = TcpStream;

    /// Create new TcpStream object connected to the `host:port` string `addrs`.
    pub async fn raw_tcp_stream(addrs: String) -> std::io::Result<TcpStream> {
        let stream = TcpStream::connect(addrs).await?;
        Ok(stream)
    }

    /// Wrap given argument into compatibility layer.
    ///
    /// For async-std this is the identity: its streams are used directly.
    #[inline(always)]
    pub fn compat<T>(t: T) -> T {
        t
    }
}
458
/// Compatibility layer between async-std and tokio
#[cfg(feature = "ws-tokio")]
pub mod compat {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    /// async-std compatible TcpStream.
    pub type TcpStream = Compat<tokio::net::TcpStream>;
    /// async-std compatible TcpListener.
    pub type TcpListener = tokio::net::TcpListener;
    /// TLS stream type for tokio runtime.
    #[cfg(feature = "ws-tls-tokio")]
    pub type TlsStream = Compat<async_native_tls::TlsStream<tokio::net::TcpStream>>;
    /// Dummy TLS stream type.
    #[cfg(not(feature = "ws-tls-tokio"))]
    pub type TlsStream = TcpStream;

    /// Create new TcpStream object connected to the `host:port` string `addrs`.
    pub async fn raw_tcp_stream(addrs: String) -> io::Result<tokio::net::TcpStream> {
        tokio::net::TcpStream::connect(addrs).await
    }

    /// Wrap given argument into compatibility layer.
    pub fn compat<T>(t: T) -> Compat<T> {
        Compat(t)
    }

    /// Compatibility layer: exposes a tokio I/O object through both the tokio and the
    /// `futures` I/O traits by delegating to the wrapped value.
    pub struct Compat<T>(T);

    impl<T: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for Compat<T> {
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncWrite::poll_write(Pin::new(inner), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncWrite::poll_flush(Pin::new(inner), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncWrite::poll_shutdown(Pin::new(inner), cx)
        }
    }

    impl<T: tokio::io::AsyncWrite + Unpin> futures::AsyncWrite for Compat<T> {
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncWrite::poll_write(Pin::new(inner), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncWrite::poll_flush(Pin::new(inner), cx)
        }

        // `futures` calls it "close", tokio calls it "shutdown".
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncWrite::poll_shutdown(Pin::new(inner), cx)
        }
    }

    impl<T: tokio::io::AsyncRead + Unpin> futures::AsyncRead for Compat<T> {
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf)
        }
    }

    impl<T: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for Compat<T> {
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
            let inner = &mut self.get_mut().0;
            tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf)
        }
    }
}
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rpc, Transport};
    use futures::{
        io::{BufReader, BufWriter},
        StreamExt,
    };
    use soketto::handshake;

    /// Compile-time check that the stream types satisfy the futures I/O bounds.
    #[test]
    fn bounds_matching() {
        fn async_rw<T: AsyncRead + AsyncWrite>() {}

        async_rw::<TcpStream>();
        async_rw::<MaybeTlsStream<TcpStream, TlsStream>>();
    }

    #[tokio::test]
    async fn should_send_a_request() {
        let _ = env_logger::try_init();
        // given
        let addr = "127.0.0.1:3000";
        // Await the bind on the test runtime instead of `futures::executor::block_on`, which
        // would block the runtime thread from inside an async context.
        let listener = compat::TcpListener::bind(addr).await.expect("Failed to bind");
        println!("Starting the server.");
        tokio::spawn(server(listener, addr));

        let endpoint = "ws://127.0.0.1:3000";
        let ws = WebSocket::new(endpoint).await.unwrap();

        // when
        let res = ws.execute("eth_accounts", vec![rpc::Value::String("1".into())]);

        // then
        assert_eq!(res.await, Ok(rpc::Value::String("x".into())));
    }

    /// Minimal WebSocket server. It accepts each connection, asserts the expected
    /// `eth_accounts` request and replies with a fixed result until the client closes.
    async fn server(mut listener: compat::TcpListener, addr: &str) {
        let mut incoming = listener.incoming();
        println!("Listening on: {}", addr);
        while let Some(Ok(socket)) = incoming.next().await {
            let socket = compat::compat(socket);
            let mut server = handshake::Server::new(BufReader::new(BufWriter::new(socket)));
            let key = {
                let req = server.receive_request().await.unwrap();
                req.into_key()
            };
            let accept = handshake::server::Response::Accept {
                key: &key,
                protocol: None,
            };
            server.send_response(&accept).await.unwrap();
            let (mut sender, mut receiver) = server.into_builder().finish();
            loop {
                let mut data = Vec::new();
                match receiver.receive_data(&mut data).await {
                    Ok(data_type) if data_type.is_text() => {
                        assert_eq!(
                            std::str::from_utf8(&data),
                            Ok(r#"{"jsonrpc":"2.0","method":"eth_accounts","params":["1"],"id":1}"#)
                        );
                        sender
                            .send_text(r#"{"jsonrpc":"2.0","id":1,"result":"x"}"#)
                            .await
                            .unwrap();
                        sender.flush().await.unwrap();
                    }
                    Err(soketto::connection::Error::Closed) => break,
                    e => panic!("Unexpected data: {:?}", e),
                }
            }
        }
    }
}