1use self::compat::{TcpStream, TlsStream};
4use crate::{api::SubscriptionId, error, helpers, rpc, BatchTransport, DuplexTransport, Error, RequestId, Transport};
5use futures::{
6 channel::{mpsc, oneshot},
7 task::{Context, Poll},
8 AsyncRead, AsyncWrite, Future, FutureExt, Stream, StreamExt,
9};
10use soketto::{
11 connection,
12 handshake::{Client, ServerResponse},
13};
14use std::{
15 collections::BTreeMap,
16 fmt,
17 marker::Unpin,
18 pin::Pin,
19 sync::{atomic, Arc},
20};
21use url::Url;
22
23impl From<soketto::handshake::Error> for Error {
24 fn from(err: soketto::handshake::Error) -> Self {
25 Error::Transport(format!("Handshake Error: {:?}", err))
26 }
27}
28
29impl From<connection::Error> for Error {
30 fn from(err: connection::Error) -> Self {
31 Error::Transport(format!("Connection Error: {:?}", err))
32 }
33}
34
/// Outcome of a single JSON-RPC call.
type SingleResult = error::Result<rpc::Value>;
/// Outcome of one request message: either a failure of the whole message or
/// one `SingleResult` per output contained in the response.
type BatchResult = error::Result<Vec<SingleResult>>;
/// Sending half of the channel on which a request's result is delivered.
type Pending = oneshot::Sender<BatchResult>;
/// Sending half of the channel on which subscription notifications are delivered.
type Subscription = mpsc::UnboundedSender<rpc::Value>;
39
/// Stream that is either a plain connection (`P`) or a TLS-wrapped one (`T`).
///
/// Both variants are driven through the same `AsyncRead`/`AsyncWrite` impls below.
enum MaybeTlsStream<P, T> {
    /// Unencrypted stream.
    Plain(P),
    /// Encrypted stream.
    // This variant is only constructed when one of the TLS features is enabled,
    // so without them it would trigger a dead-code warning.
    #[allow(dead_code)]
    Tls(T),
}
48
49impl<P, T> AsyncRead for MaybeTlsStream<P, T>
50where
51 P: AsyncRead + AsyncWrite + Unpin,
52 T: AsyncRead + AsyncWrite + Unpin,
53{
54 fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, std::io::Error>> {
55 match self.get_mut() {
56 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
57 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
58 }
59 }
60}
61
62impl<P, T> AsyncWrite for MaybeTlsStream<P, T>
63where
64 P: AsyncRead + AsyncWrite + Unpin,
65 T: AsyncRead + AsyncWrite + Unpin,
66{
67 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
68 match self.get_mut() {
69 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
70 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
71 }
72 }
73
74 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
75 match self.get_mut() {
76 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
77 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
78 }
79 }
80
81 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
82 match self.get_mut() {
83 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_close(cx),
84 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_close(cx),
85 }
86 }
87}
88
/// State of the background task that owns the WebSocket connection.
struct WsServerTask {
    /// Requests that were sent and are still waiting for a response, keyed by request id.
    pending: BTreeMap<RequestId, Pending>,
    /// Registered notification sinks, keyed by subscription id.
    subscriptions: BTreeMap<SubscriptionId, Subscription>,
    /// Writing half of the WebSocket connection.
    sender: connection::Sender<MaybeTlsStream<TcpStream, TlsStream>>,
    /// Reading half of the WebSocket connection.
    receiver: connection::Receiver<MaybeTlsStream<TcpStream, TlsStream>>,
}
95
96impl WsServerTask {
97 pub async fn new(url: &str) -> error::Result<Self> {
99 let url = Url::parse(url)?;
100
101 let scheme = match url.scheme() {
102 s if s == "ws" || s == "wss" => s,
103 s => return Err(error::Error::Transport(format!("Wrong scheme: {}", s))),
104 };
105 let host = match url.host_str() {
106 Some(s) => s,
107 None => return Err(error::Error::Transport("Wrong host name".to_string())),
108 };
109 let port = url.port().unwrap_or(if scheme == "ws" { 80 } else { 443 });
110 let addrs = format!("{}:{}", host, port);
111
112 let stream = compat::raw_tcp_stream(addrs).await?;
113 stream.set_nodelay(true)?;
114 let socket = if scheme == "wss" {
115 #[cfg(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std"))]
116 {
117 let stream = async_native_tls::connect(host, stream).await?;
118 MaybeTlsStream::Tls(compat::compat(stream))
119 }
120 #[cfg(not(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std")))]
121 panic!("The library was compiled without TLS support. Enable ws-tls-tokio or ws-tls-async-std feature.");
122 } else {
123 let stream = compat::compat(stream);
124 MaybeTlsStream::Plain(stream)
125 };
126
127 let mut client = Client::new(socket, host, url.path());
128 let handshake = client.handshake();
129 let (sender, receiver) = match handshake.await? {
130 ServerResponse::Accepted { .. } => client.into_builder().finish(),
131 ServerResponse::Redirect { status_code, location } => {
132 return Err(error::Error::Transport(format!(
133 "(code: {}) Unable to follow redirects: {}",
134 status_code, location
135 )))
136 }
137 ServerResponse::Rejected { status_code } => {
138 return Err(error::Error::Transport(format!(
139 "(code: {}) Connection rejected.",
140 status_code
141 )))
142 }
143 };
144
145 Ok(Self {
146 pending: Default::default(),
147 subscriptions: Default::default(),
148 sender,
149 receiver,
150 })
151 }
152
153 async fn into_task(self, requests: mpsc::UnboundedReceiver<TransportMessage>) {
154 let Self {
155 receiver,
156 mut sender,
157 mut pending,
158 mut subscriptions,
159 } = self;
160
161 let receiver = as_data_stream(receiver).fuse();
162 let requests = requests.fuse();
163 pin_mut!(receiver);
164 pin_mut!(requests);
165 loop {
166 select! {
167 msg = requests.next() => match msg {
168 Some(TransportMessage::Request { id, request, sender: tx }) => {
169 if pending.insert(id.clone(), tx).is_some() {
170 log::warn!("Replacing a pending request with id {:?}", id);
171 }
172 let res = sender.send_text(request).await;
173 let res2 = sender.flush().await;
174 if let Err(e) = res.and(res2) {
175 log::error!("WS connection error: {:?}", e);
177 pending.remove(&id);
178 }
179 }
180 Some(TransportMessage::Subscribe { id, sink }) => {
181 if subscriptions.insert(id.clone(), sink).is_some() {
182 log::warn!("Replacing already-registered subscription with id {:?}", id);
183 }
184 }
185 Some(TransportMessage::Unsubscribe { id }) => {
186 if subscriptions.remove(&id).is_none() {
187 log::warn!("Unsubscribing from non-existent subscription with id {:?}", id);
188 }
189 }
190 None => {}
191 },
192 res = receiver.next() => match res {
193 Some(Ok(data)) => {
194 handle_message(&data, &subscriptions, &mut pending);
195 },
196 Some(Err(e)) => {
197 log::error!("WS connection error: {:?}", e);
198 break;
199 },
200 None => break,
201 },
202 complete => break,
203 }
204 }
205 }
206}
207
208fn as_data_stream<T: Unpin + futures::AsyncRead + futures::AsyncWrite>(
209 receiver: soketto::connection::Receiver<T>,
210) -> impl Stream<Item = Result<Vec<u8>, soketto::connection::Error>> {
211 futures::stream::unfold(receiver, |mut receiver| async move {
212 let mut data = Vec::new();
213 Some(match receiver.receive_data(&mut data).await {
214 Ok(_) => (Ok(data), receiver),
215 Err(e) => (Err(e), receiver),
216 })
217 })
218}
219
220fn handle_message(
221 data: &[u8],
222 subscriptions: &BTreeMap<SubscriptionId, Subscription>,
223 pending: &mut BTreeMap<RequestId, Pending>,
224) {
225 log::trace!("Message received: {:?}", data);
226 if let Ok(notification) = helpers::to_notification_from_slice(data) {
227 if let rpc::Params::Map(params) = notification.params {
228 let id = params.get("subscription");
229 let result = params.get("result");
230
231 if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
232 let id: SubscriptionId = id.clone().into();
233 if let Some(stream) = subscriptions.get(&id) {
234 if let Err(e) = stream.unbounded_send(result.clone()) {
235 log::error!("Error sending notification: {:?} (id: {:?}", e, id);
236 }
237 } else {
238 log::warn!("Got notification for unknown subscription (id: {:?})", id);
239 }
240 } else {
241 log::error!("Got unsupported notification (id: {:?})", id);
242 }
243 }
244 } else {
245 let response = helpers::to_response_from_slice(data);
246 let outputs = match response {
247 Ok(rpc::Response::Single(output)) => vec![output],
248 Ok(rpc::Response::Batch(outputs)) => outputs,
249 _ => vec![],
250 };
251
252 let id = match outputs.get(0) {
253 Some(&rpc::Output::Success(ref success)) => success.id.clone(),
254 Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(),
255 None => rpc::Id::Num(0),
256 };
257
258 if let rpc::Id::Num(num) = id {
259 if let Some(request) = pending.remove(&(num as usize)) {
260 log::trace!("Responding to (id: {:?}) with {:?}", num, outputs);
261 if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) {
262 log::warn!("Sending a response to deallocated channel: {:?}", err);
263 }
264 } else {
265 log::warn!("Got response for unknown request (id: {:?})", num);
266 }
267 } else {
268 log::warn!("Got unsupported response (id: {:?})", id);
269 }
270 }
271}
272
273enum TransportMessage {
274 Request {
275 id: RequestId,
276 request: String,
277 sender: oneshot::Sender<BatchResult>,
278 },
279 Subscribe {
280 id: SubscriptionId,
281 sink: mpsc::UnboundedSender<rpc::Value>,
282 },
283 Unsubscribe {
284 id: SubscriptionId,
285 },
286}
287
/// WebSocket transport handle.
///
/// Clones share the same request-id counter and the same channel to the
/// background connection task.
#[derive(Clone)]
pub struct WebSocket {
    /// Counter handing out request ids; starts at 1.
    id: Arc<atomic::AtomicUsize>,
    /// Channel used to pass messages to the background connection task.
    requests: mpsc::UnboundedSender<TransportMessage>,
}
294
295impl fmt::Debug for WebSocket {
296 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
297 fmt.debug_struct("WebSocket").field("id", &self.id).finish()
298 }
299}
300
301impl WebSocket {
302 pub async fn new(url: &str) -> error::Result<Self> {
304 let id = Arc::new(atomic::AtomicUsize::new(1));
305 let task = WsServerTask::new(url).await?;
306 let (sink, stream) = mpsc::unbounded();
308 #[cfg(feature = "ws-tokio")]
310 tokio::spawn(task.into_task(stream));
311 #[cfg(feature = "ws-async-std")]
312 async_std::task::spawn(task.into_task(stream));
313
314 Ok(Self { id, requests: sink })
315 }
316
317 fn send(&self, msg: TransportMessage) -> error::Result {
318 self.requests.unbounded_send(msg).map_err(dropped_err)
319 }
320
321 fn send_request(&self, id: RequestId, request: rpc::Request) -> error::Result<oneshot::Receiver<BatchResult>> {
322 let request = helpers::to_string(&request);
323 log::debug!("[{}] Calling: {}", id, request);
324 let (sender, receiver) = oneshot::channel();
325 self.send(TransportMessage::Request { id, request, sender })?;
326 Ok(receiver)
327 }
328}
329
330fn dropped_err<T>(_: T) -> error::Error {
331 Error::Transport("Cannot send request. Internal task finished.".into())
332}
333
334fn batch_to_single(response: BatchResult) -> SingleResult {
335 match response?.into_iter().next() {
336 Some(res) => res,
337 None => Err(Error::InvalidResponse("Expected single, got batch.".into())),
338 }
339}
340
341fn batch_to_batch(res: BatchResult) -> BatchResult {
342 res
343}
344
/// Progress of a `Response` future.
enum ResponseState {
    /// Initial state: holds the outcome of queuing the request. The option is
    /// taken on the first poll.
    Receiver(Option<error::Result<oneshot::Receiver<BatchResult>>>),
    /// The request was queued; waiting for the background task to deliver the result.
    Waiting(oneshot::Receiver<BatchResult>),
}
349
/// Future resolving to the result of a request sent over the WebSocket transport.
///
/// `R` is the produced value type and `T` the function converting the raw
/// batch result into it.
pub struct Response<R, T> {
    /// Converts the raw batch result into the final output.
    extract: T,
    /// Current progress of the request.
    state: ResponseState,
    /// Ties the otherwise unused `R` parameter to the type.
    _data: std::marker::PhantomData<R>,
}
356
357impl<R, T> Response<R, T> {
358 fn new(response: error::Result<oneshot::Receiver<BatchResult>>, extract: T) -> Self {
359 Self {
360 extract,
361 state: ResponseState::Receiver(Some(response)),
362 _data: Default::default(),
363 }
364 }
365}
366
367impl<R, T> Future for Response<R, T>
368where
369 R: Unpin + 'static,
370 T: Fn(BatchResult) -> error::Result<R> + Unpin + 'static,
371{
372 type Output = error::Result<R>;
373 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
374 loop {
375 match self.state {
376 ResponseState::Receiver(ref mut res) => {
377 let receiver = res.take().expect("Receiver state is active only once; qed")?;
378 self.state = ResponseState::Waiting(receiver)
379 }
380 ResponseState::Waiting(ref mut future) => {
381 let response = ready!(future.poll_unpin(cx)).map_err(dropped_err)?;
382 return Poll::Ready((self.extract)(response));
383 }
384 }
385 }
386 }
387}
388
389impl Transport for WebSocket {
390 type Out = Response<rpc::Value, fn(BatchResult) -> SingleResult>;
391
392 fn prepare(&self, method: &str, params: Vec<rpc::Value>) -> (RequestId, rpc::Call) {
393 let id = self.id.fetch_add(1, atomic::Ordering::AcqRel);
394 let request = helpers::build_request(id, method, params);
395
396 (id, request)
397 }
398
399 fn send(&self, id: RequestId, request: rpc::Call) -> Self::Out {
400 let response = self.send_request(id, rpc::Request::Single(request));
401 Response::new(response, batch_to_single)
402 }
403}
404
405impl BatchTransport for WebSocket {
406 type Batch = Response<Vec<SingleResult>, fn(BatchResult) -> BatchResult>;
407
408 fn send_batch<T>(&self, requests: T) -> Self::Batch
409 where
410 T: IntoIterator<Item = (RequestId, rpc::Call)>,
411 {
412 let mut it = requests.into_iter();
413 let (id, first) = it.next().map(|x| (x.0, Some(x.1))).unwrap_or_else(|| (0, None));
414 let requests = first.into_iter().chain(it.map(|x| x.1)).collect();
415 let response = self.send_request(id, rpc::Request::Batch(requests));
416 Response::new(response, batch_to_batch)
417 }
418}
419
420impl DuplexTransport for WebSocket {
421 type NotificationStream = mpsc::UnboundedReceiver<rpc::Value>;
422
423 fn subscribe(&self, id: SubscriptionId) -> error::Result<Self::NotificationStream> {
424 let (sink, stream) = mpsc::unbounded();
426 self.send(TransportMessage::Subscribe { id, sink })?;
427 Ok(stream)
428 }
429
430 fn unsubscribe(&self, id: SubscriptionId) -> error::Result {
431 self.send(TransportMessage::Unsubscribe { id })
432 }
433}
434
#[cfg(feature = "ws-async-std")]
/// Runtime glue used when the transport is built on async-std.
#[doc(hidden)]
pub mod compat {
    pub use async_std::net::{TcpListener, TcpStream};
    /// TLS stream type used for `wss` connections.
    #[cfg(feature = "ws-tls-async-std")]
    pub type TlsStream = async_native_tls::TlsStream<TcpStream>;
    /// Stand-in for the TLS stream type when TLS support is compiled out.
    #[cfg(not(feature = "ws-tls-async-std"))]
    pub type TlsStream = TcpStream;

    /// Opens a TCP connection to `addrs`.
    pub async fn raw_tcp_stream(addrs: String) -> std::io::Result<TcpStream> {
        TcpStream::connect(addrs).await
    }

    /// Returns its argument unchanged: on this runtime no wrapper is applied.
    #[inline(always)]
    pub fn compat<T>(t: T) -> T {
        t
    }
}
458
#[cfg(feature = "ws-tokio")]
/// Runtime glue used when the transport is built on tokio: wraps tokio I/O
/// types so that they also implement the `futures` I/O traits.
pub mod compat {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    /// TCP stream type used by the transport.
    pub type TcpStream = Compat<tokio::net::TcpStream>;
    /// TCP listener type.
    pub type TcpListener = tokio::net::TcpListener;
    /// TLS stream type used for `wss` connections.
    #[cfg(feature = "ws-tls-tokio")]
    pub type TlsStream = Compat<async_native_tls::TlsStream<tokio::net::TcpStream>>;
    /// Stand-in for the TLS stream type when TLS support is compiled out.
    #[cfg(not(feature = "ws-tls-tokio"))]
    pub type TlsStream = TcpStream;

    /// Opens a TCP connection to `addrs`.
    pub async fn raw_tcp_stream(addrs: String) -> io::Result<tokio::net::TcpStream> {
        tokio::net::TcpStream::connect(addrs).await
    }

    /// Wraps `t` in the `Compat` adapter.
    pub fn compat<T>(t: T) -> Compat<T> {
        Compat(t)
    }

    /// Adapter forwarding both tokio and `futures` I/O traits to the wrapped tokio type.
    pub struct Compat<T>(T);

    impl<T: Unpin> Compat<T> {
        /// Projects the pin onto the wrapped value.
        fn inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            Pin::new(&mut self.get_mut().0)
        }
    }

    impl<T: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for Compat<T> {
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
            tokio::io::AsyncWrite::poll_write(self.inner(), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            tokio::io::AsyncWrite::poll_flush(self.inner(), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            tokio::io::AsyncWrite::poll_shutdown(self.inner(), cx)
        }
    }

    impl<T: tokio::io::AsyncWrite + Unpin> futures::AsyncWrite for Compat<T> {
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
            tokio::io::AsyncWrite::poll_write(self.inner(), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            tokio::io::AsyncWrite::poll_flush(self.inner(), cx)
        }

        // Closing in `futures` terms maps onto tokio's shutdown.
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            tokio::io::AsyncWrite::poll_shutdown(self.inner(), cx)
        }
    }

    impl<T: tokio::io::AsyncRead + Unpin> futures::AsyncRead for Compat<T> {
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
            tokio::io::AsyncRead::poll_read(self.inner(), cx, buf)
        }
    }

    impl<T: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for Compat<T> {
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
            tokio::io::AsyncRead::poll_read(self.inner(), cx, buf)
        }
    }
}
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rpc, Transport};
    use futures::{
        io::{BufReader, BufWriter},
        StreamExt,
    };
    use soketto::handshake;

    /// Compile-time check that the stream types satisfy the `futures` I/O traits.
    #[test]
    fn bounds_matching() {
        fn async_rw<T: AsyncRead + AsyncWrite>() {}

        async_rw::<TcpStream>();
        async_rw::<MaybeTlsStream<TcpStream, TlsStream>>();
    }

    /// Round trip against a local server: one request is sent and its response is checked.
    #[tokio::test]
    async fn should_send_a_request() {
        let _ = env_logger::try_init();
        // given: a server listening on a fixed local port
        let addr = "127.0.0.1:3000";
        let listener = futures::executor::block_on(compat::TcpListener::bind(addr)).expect("Failed to bind");
        println!("Starting the server.");
        tokio::spawn(server(listener, addr));

        let endpoint = "ws://127.0.0.1:3000";
        let ws = WebSocket::new(endpoint).await.unwrap();

        // when: a single request is executed
        let res = ws.execute("eth_accounts", vec![rpc::Value::String("1".into())]);

        // then: the canned response of the server is returned
        assert_eq!(res.await, Ok(rpc::Value::String("x".into())));
    }

    /// Minimal WebSocket server: accepts every connection, asserts that each text
    /// message is the expected `eth_accounts` request and answers with a fixed result.
    async fn server(mut listener: compat::TcpListener, addr: &str) {
        let mut incoming = listener.incoming();
        println!("Listening on: {}", addr);
        while let Some(Ok(socket)) = incoming.next().await {
            let socket = compat::compat(socket);
            let mut server = handshake::Server::new(BufReader::new(BufWriter::new(socket)));
            let key = {
                let req = server.receive_request().await.unwrap();
                req.into_key()
            };
            let accept = handshake::server::Response::Accept {
                key: &key,
                protocol: None,
            };
            server.send_response(&accept).await.unwrap();
            let (mut sender, mut receiver) = server.into_builder().finish();
            loop {
                let mut data = Vec::new();
                match receiver.receive_data(&mut data).await {
                    Ok(data_type) if data_type.is_text() => {
                        assert_eq!(
                            std::str::from_utf8(&data),
                            Ok(r#"{"jsonrpc":"2.0","method":"eth_accounts","params":["1"],"id":1}"#)
                        );
                        sender
                            .send_text(r#"{"jsonrpc":"2.0","id":1,"result":"x"}"#)
                            .await
                            .unwrap();
                        sender.flush().await.unwrap();
                    }
                    // The client closed the connection: go back to accepting.
                    Err(soketto::connection::Error::Closed) => break,
                    e => panic!("Unexpected data: {:?}", e),
                }
            }
        }
    }
}
605}