use crate::helpers::{
build_unsubscribe_message, process_batch_response, process_error_response, process_single_response,
process_subscription_response, stop_subscription,
};
use crate::traits::{Client, SubscriptionClient};
use crate::transport::{parse_url, Receiver as WsReceiver, Sender as WsSender, WsTransportClientBuilder};
use crate::v2::error::JsonRpcErrorAlloc;
use crate::v2::params::{Id, JsonRpcParams};
use crate::v2::request::{JsonRpcCallSer, JsonRpcNotificationSer};
use crate::v2::response::{JsonRpcNotifResponse, JsonRpcResponse};
use crate::TEN_MB_SIZE_BYTES;
use crate::{
manager::RequestManager, BatchMessage, Error, FrontToBack, RequestMessage, Subscription, SubscriptionMessage,
};
use async_std::sync::Mutex;
use async_trait::async_trait;
use futures::{
channel::{mpsc, oneshot},
future::Either,
prelude::*,
sink::SinkExt,
};
use serde::de::DeserializeOwned;
use std::{
borrow::Cow,
marker::PhantomData,
sync::atomic::{AtomicU64, AtomicUsize, Ordering},
time::Duration,
};
/// Error reported by the background task to the frontend.
///
/// The error can only be received once from the oneshot channel; after the
/// first read the message is cached as a `String` so later callers can still
/// observe the same failure reason.
#[derive(Debug)]
enum ErrorFromBack {
    // The error has already been read and its message cached.
    Read(String),
    // The error has not been read yet; the oneshot receiver is still pending.
    Unread(oneshot::Receiver<Error>),
}
impl ErrorFromBack {
async fn read_error(self) -> (Self, Error) {
match self {
Self::Unread(rx) => {
let msg = match rx.await {
Ok(msg) => msg.to_string(),
Err(_) => "Error reason could not be found. This is a bug. Please open an issue.".to_string(),
};
let err = Error::RestartNeeded(msg.clone());
(Self::Read(msg), err)
}
Self::Read(msg) => (Self::Read(msg.clone()), Error::RestartNeeded(msg)),
}
}
}
/// WebSocket JSON-RPC client handle.
#[derive(Debug)]
pub struct WsClient {
    // Channel used to send messages to the background task.
    to_back: mpsc::Sender<FrontToBack>,
    // Terminal error reported by the background task, read lazily.
    error: Mutex<ErrorFromBack>,
    // Optional per-request timeout; `None` waits indefinitely.
    request_timeout: Option<Duration>,
    // Mints request IDs and caps the number of concurrent in-flight requests.
    id_guard: RequestIdGuard,
}
/// Hands out request IDs while enforcing a limit on the number of
/// concurrently pending requests.
#[derive(Debug)]
struct RequestIdGuard {
    // Number of requests currently holding a slot.
    current_pending: AtomicUsize,
    // Maximum number of slots that may be held at once.
    max_concurrent_requests: usize,
    // Monotonically increasing counter used to mint request IDs.
    current_id: AtomicU64,
}
impl RequestIdGuard {
fn new(limit: usize) -> Self {
Self { current_pending: AtomicUsize::new(0), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
}
fn get_slot(&self) -> Result<(), Error> {
self.current_pending
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
if val >= self.max_concurrent_requests {
None
} else {
Some(val + 1)
}
})
.map(|_| ())
.map_err(|_| Error::MaxSlotsExceeded)
}
fn next_request_id(&self) -> Result<u64, Error> {
self.get_slot()?;
let id = self.current_id.fetch_add(1, Ordering::SeqCst);
Ok(id)
}
fn next_request_ids(&self, len: usize) -> Result<Vec<u64>, Error> {
self.get_slot()?;
let mut batch = Vec::with_capacity(len);
for _ in 0..len {
batch.push(self.current_id.fetch_add(1, Ordering::SeqCst));
}
Ok(batch)
}
fn reclaim_request_id(&self) {
let _ = self.current_pending.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
if val > 0 {
Some(val - 1)
} else {
None
}
});
}
}
/// Builder for [`WsClient`].
#[derive(Clone, Debug)]
pub struct WsClientBuilder<'a> {
    // Maximum size (in bytes) of a request body.
    max_request_body_size: u32,
    // Optional per-request timeout; `None` means no timeout.
    request_timeout: Option<Duration>,
    // Timeout for establishing the WebSocket connection.
    connection_timeout: Duration,
    // Optional origin passed to the transport builder for the handshake.
    origin: Option<Cow<'a, str>>,
    // URL path used for the WebSocket handshake.
    handshake_url: Cow<'a, str>,
    // Maximum number of concurrent in-flight requests.
    max_concurrent_requests: usize,
    // Capacity of each subscription's notification buffer.
    max_notifs_per_subscription: usize,
}
impl<'a> Default for WsClientBuilder<'a> {
fn default() -> Self {
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
request_timeout: None,
connection_timeout: Duration::from_secs(10),
origin: None,
handshake_url: From::from("/"),
max_concurrent_requests: 256,
max_notifs_per_subscription: 4,
}
}
}
impl<'a> WsClientBuilder<'a> {
    /// Set the maximum size (in bytes) of a request body.
    pub fn max_request_body_size(mut self, size: u32) -> Self {
        self.max_request_body_size = size;
        self
    }

    /// Set the per-request timeout; `None` disables the timeout.
    pub fn request_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Set the timeout for establishing the WebSocket connection.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Set the origin used during the WebSocket handshake.
    pub fn origin_header(mut self, origin: Option<Cow<'a, str>>) -> Self {
        self.origin = origin;
        self
    }

    /// Set the URL path used for the WebSocket handshake.
    pub fn handshake_url(mut self, url: Cow<'a, str>) -> Self {
        self.handshake_url = url;
        self
    }

    /// Set the maximum number of concurrent in-flight requests.
    pub fn max_concurrent_requests(mut self, max: usize) -> Self {
        self.max_concurrent_requests = max;
        self
    }

    /// Set the capacity of each subscription's notification buffer.
    pub fn max_notifs_per_subscription(mut self, max: usize) -> Self {
        self.max_notifs_per_subscription = max;
        self
    }

    /// Connect to `url` and spawn the background task that drives the
    /// connection, returning a ready-to-use [`WsClient`].
    ///
    /// # Errors
    ///
    /// Returns `Error::TransportError` if the URL cannot be parsed or the
    /// WebSocket handshake fails.
    pub async fn build(self, url: &'a str) -> Result<WsClient, Error> {
        let max_capacity_per_subscription = self.max_notifs_per_subscription;
        let max_concurrent_requests = self.max_concurrent_requests;
        let request_timeout = self.request_timeout;
        let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests);
        let (err_tx, err_rx) = oneshot::channel();
        let (sockaddrs, host, mode) = parse_url(url).map_err(|e| Error::TransportError(Box::new(e)))?;
        let builder = WsTransportClientBuilder {
            sockaddrs,
            mode,
            host,
            handshake_url: self.handshake_url,
            timeout: self.connection_timeout,
            // Fix: forward the configured origin instead of hard-coding `None`,
            // which silently discarded any value set via `origin_header`.
            origin: self.origin,
            max_request_body_size: self.max_request_body_size,
        };
        let (sender, receiver) = builder.build().await.map_err(|e| Error::TransportError(Box::new(e)))?;
        // The background task owns the socket halves; it terminates when the
        // frontend channel closes or a fatal transport error occurs.
        async_std::task::spawn(async move {
            background_task(sender, receiver, from_front, err_tx, max_capacity_per_subscription).await;
        });
        Ok(WsClient {
            to_back,
            request_timeout,
            error: Mutex::new(ErrorFromBack::Unread(err_rx)),
            id_guard: RequestIdGuard::new(max_concurrent_requests),
        })
    }
}
impl WsClient {
    /// Whether the background task is still alive and accepting messages.
    pub fn is_connected(&self) -> bool {
        let backend_gone = self.to_back.is_closed();
        !backend_gone
    }

    /// Fetch the terminal error reported by the background task.
    ///
    /// A placeholder is parked in the mutex while awaiting the real error
    /// state; the (now cached) state is stored back afterwards so later
    /// callers see the same message.
    async fn read_error_from_backend(&self) -> Error {
        let mut guard = self.error.lock().await;
        let placeholder = ErrorFromBack::Read(String::new());
        let current = std::mem::replace(&mut *guard, placeholder);
        let (cached, err) = current.read_error().await;
        *guard = cached;
        err
    }
}
#[async_trait]
impl Client for WsClient {
    /// Send a JSON-RPC notification (a call without an ID; no response expected).
    async fn notification<'a>(&self, method: &'a str, params: JsonRpcParams<'a>) -> Result<(), Error> {
        // A slot is reserved even though notifications carry no ID, so that
        // notifications also count towards `max_concurrent_requests`.
        let _req_id = self.id_guard.next_request_id()?;
        let notif = JsonRpcNotificationSer::new(method, params);
        // Fix: the argument had been corrupted to `¬if` (an HTML-entity
        // mangling of `&notif`), which does not compile.
        let raw = serde_json::to_string(&notif).map_err(|e| {
            self.id_guard.reclaim_request_id();
            Error::ParseError(e)
        })?;
        log::trace!("[frontend]: send notification: {:?}", raw);
        let res = self.to_back.clone().send(FrontToBack::Notification(raw)).await;
        // The slot is released whether or not the send succeeded.
        self.id_guard.reclaim_request_id();
        match res {
            Ok(()) => Ok(()),
            // Backend gone: surface its terminal error instead.
            Err(_) => Err(self.read_error_from_backend().await),
        }
    }

    /// Perform a JSON-RPC method call and deserialize the result into `R`.
    async fn request<'a, R>(&self, method: &'a str, params: JsonRpcParams<'a>) -> Result<R, Error>
    where
        R: DeserializeOwned,
    {
        let (send_back_tx, send_back_rx) = oneshot::channel();
        let req_id = self.id_guard.next_request_id()?;
        let raw = serde_json::to_string(&JsonRpcCallSer::new(Id::Number(req_id), method, params)).map_err(|e| {
            // Release the slot on serialization failure.
            self.id_guard.reclaim_request_id();
            Error::ParseError(e)
        })?;
        log::trace!("[frontend]: send request: {:?}", raw);
        if self
            .to_back
            .clone()
            .send(FrontToBack::Request(RequestMessage { raw, id: req_id, send_back: Some(send_back_tx) }))
            .await
            .is_err()
        {
            self.id_guard.reclaim_request_id();
            return Err(self.read_error_from_backend().await);
        }
        // Await the response, racing it against the optional request timeout.
        let send_back_rx_out = if let Some(duration) = self.request_timeout {
            let timeout = async_std::task::sleep(duration);
            futures::pin_mut!(send_back_rx, timeout);
            match future::select(send_back_rx, timeout).await {
                future::Either::Left((send_back_rx_out, _)) => send_back_rx_out,
                future::Either::Right((_, _)) => Ok(Err(Error::RequestTimeout)),
            }
        } else {
            send_back_rx.await
        };
        self.id_guard.reclaim_request_id();
        let json_value = match send_back_rx_out {
            Ok(Ok(v)) => v,
            Ok(Err(err)) => return Err(err),
            // Channel dropped: the background task died; fetch its error.
            Err(_) => return Err(self.read_error_from_backend().await),
        };
        serde_json::from_value(json_value).map_err(Error::ParseError)
    }

    /// Perform a batch of JSON-RPC calls; results come back in request order.
    async fn batch_request<'a, R>(&self, batch: Vec<(&'a str, JsonRpcParams<'a>)>) -> Result<Vec<R>, Error>
    where
        R: DeserializeOwned + Default + Clone,
    {
        // A batch consumes one concurrency slot but mints one ID per call.
        let batch_ids = self.id_guard.next_request_ids(batch.len())?;
        let mut batches = Vec::with_capacity(batch.len());
        for (idx, (method, params)) in batch.into_iter().enumerate() {
            batches.push(JsonRpcCallSer::new(Id::Number(batch_ids[idx]), method, params));
        }
        let (send_back_tx, send_back_rx) = oneshot::channel();
        let raw = serde_json::to_string(&batches).map_err(|e| {
            self.id_guard.reclaim_request_id();
            Error::ParseError(e)
        })?;
        log::trace!("[frontend]: send batch request: {:?}", raw);
        if self
            .to_back
            .clone()
            .send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids, send_back: send_back_tx }))
            .await
            .is_err()
        {
            self.id_guard.reclaim_request_id();
            return Err(self.read_error_from_backend().await);
        }
        let res = send_back_rx.await;
        self.id_guard.reclaim_request_id();
        let json_values = match res {
            Ok(Ok(v)) => v,
            Ok(Err(err)) => return Err(err),
            Err(_) => return Err(self.read_error_from_backend().await),
        };
        // Deserialize each element, short-circuiting on the first failure.
        let values: Result<_, _> =
            json_values.into_iter().map(|val| serde_json::from_value(val).map_err(Error::ParseError)).collect();
        Ok(values?)
    }
}
#[async_trait]
impl SubscriptionClient for WsClient {
    /// Start a subscription by calling `subscribe_method`, returning a
    /// [`Subscription`] stream of notifications deserialized into `N`.
    ///
    /// `unsubscribe_method` is forwarded to the background task so the
    /// subscription can be cancelled later.
    ///
    /// # Errors
    ///
    /// Returns `Error::SubscriptionNameConflict` when the subscribe and
    /// unsubscribe method names are identical.
    async fn subscribe<'a, N>(
        &self,
        subscribe_method: &'a str,
        params: JsonRpcParams<'a>,
        unsubscribe_method: &'a str,
    ) -> Result<Subscription<N>, Error>
    where
        N: DeserializeOwned,
    {
        log::trace!("[frontend]: subscribe: {:?}, unsubscribe: {:?}", subscribe_method, unsubscribe_method);
        if subscribe_method == unsubscribe_method {
            return Err(Error::SubscriptionNameConflict(unsubscribe_method.to_owned()));
        }
        // Two IDs (one slot): ids[0] for subscribe, ids[1] for unsubscribe.
        let ids = self.id_guard.next_request_ids(2)?;
        let raw =
            serde_json::to_string(&JsonRpcCallSer::new(Id::Number(ids[0]), subscribe_method, params)).map_err(|e| {
                // Release the slot on serialization failure.
                self.id_guard.reclaim_request_id();
                Error::ParseError(e)
            })?;
        let (send_back_tx, send_back_rx) = oneshot::channel();
        if self
            .to_back
            .clone()
            .send(FrontToBack::Subscribe(SubscriptionMessage {
                raw,
                subscribe_id: ids[0],
                unsubscribe_id: ids[1],
                unsubscribe_method: unsubscribe_method.to_owned(),
                send_back: send_back_tx,
            }))
            .await
            .is_err()
        {
            self.id_guard.reclaim_request_id();
            return Err(self.read_error_from_backend().await);
        }
        let res = send_back_rx.await;
        // Release the slot once the subscription attempt has resolved.
        self.id_guard.reclaim_request_id();
        let (notifs_rx, id) = match res {
            Ok(Ok(val)) => val,
            Ok(Err(err)) => return Err(err),
            // Channel dropped: the background task died; fetch its error.
            Err(_) => return Err(self.read_error_from_backend().await),
        };
        Ok(Subscription { to_back: self.to_back.clone(), notifs_rx, marker: PhantomData, id })
    }
}
/// Background task that drives the WebSocket connection.
///
/// Multiplexes messages arriving from the frontend (`frontend`) with raw
/// responses read from the socket (`receiver`). Terminates when the frontend
/// channel closes or a fatal error occurs, in which case the error is sent to
/// the frontend via `front_error`.
async fn background_task(
    mut sender: WsSender,
    receiver: WsReceiver,
    mut frontend: mpsc::Receiver<FrontToBack>,
    front_error: oneshot::Sender<Error>,
    max_notifs_per_subscription: usize,
) {
    let mut manager = RequestManager::new();
    // Turn the receiver into a stream of raw responses.
    let backend_event = futures::stream::unfold(receiver, |mut receiver| async {
        let res = receiver.next_response().await;
        Some((res, receiver))
    });
    futures::pin_mut!(backend_event);
    loop {
        let next_frontend = frontend.next();
        let next_backend = backend_event.next();
        futures::pin_mut!(next_frontend, next_backend);
        match future::select(next_frontend, next_backend).await {
            // Frontend channel closed: every client handle was dropped.
            Either::Left((None, _)) => {
                log::trace!("[backend]: frontend dropped; terminate client");
                return;
            }
            // Batch request from the frontend: register the pending batch,
            // then write it to the socket.
            Either::Left((Some(FrontToBack::Batch(batch)), _)) => {
                log::trace!("[backend]: client prepares to send batch request: {:?}", batch.raw);
                if let Err(send_back) = manager.insert_pending_batch(batch.ids.clone(), batch.send_back) {
                    log::warn!("[backend]: batch request: {:?} already pending", batch.ids);
                    let _ = send_back.send(Err(Error::InvalidRequestId));
                    continue;
                }
                if let Err(e) = sender.send(batch.raw).await {
                    log::warn!("[backend]: client batch request failed: {:?}", e);
                    // Undo the registration so the IDs are not left dangling.
                    manager.complete_pending_batch(batch.ids);
                }
            }
            // Notification from the frontend: fire-and-forget write.
            Either::Left((Some(FrontToBack::Notification(notif)), _)) => {
                log::trace!("[backend]: client prepares to send notification: {:?}", notif);
                if let Err(e) = sender.send(notif).await {
                    log::warn!("[backend]: client notif failed: {:?}", e);
                }
            }
            // Single method call: write first, register as pending on success,
            // otherwise report the transport error back to the caller.
            Either::Left((Some(FrontToBack::Request(request)), _)) => {
                log::trace!("[backend]: client prepares to send request={:?}", request);
                match sender.send(request.raw).await {
                    Ok(_) => manager
                        .insert_pending_call(request.id, request.send_back)
                        .expect("ID unused checked above; qed"),
                    Err(e) => {
                        log::warn!("[backend]: client request failed: {:?}", e);
                        let _ = request.send_back.map(|s| s.send(Err(Error::TransportError(Box::new(e)))));
                    }
                }
            }
            // Subscription request: write first, then register both the
            // subscribe and unsubscribe IDs as a pending subscription.
            Either::Left((Some(FrontToBack::Subscribe(sub)), _)) => match sender.send(sub.raw).await {
                Ok(_) => manager
                    .insert_pending_subscription(
                        sub.subscribe_id,
                        sub.unsubscribe_id,
                        sub.send_back,
                        sub.unsubscribe_method,
                    )
                    .expect("Request ID unused checked above; qed"),
                Err(e) => {
                    log::warn!("[backend]: client subscription failed: {:?}", e);
                    let _ = sub.send_back.send(Err(Error::TransportError(Box::new(e))));
                }
            },
            // The frontend dropped a `Subscription` handle: build and send the
            // unsubscribe call if the subscription is still registered.
            Either::Left((Some(FrontToBack::SubscriptionClosed(sub_id)), _)) => {
                log::trace!("Closing subscription: {:?}", sub_id);
                if let Some(unsub) = manager
                    .get_request_id_by_subscription_id(&sub_id)
                    .and_then(|req_id| build_unsubscribe_message(&mut manager, req_id, sub_id))
                {
                    stop_subscription(&mut sender, &mut manager, unsub).await;
                }
            }
            // Raw message from the socket: try to decode it as, in order, a
            // single response, a subscription notification, a batch response,
            // or an error response.
            Either::Right((Some(Ok(raw)), _)) => {
                if let Ok(single) = serde_json::from_slice::<JsonRpcResponse<_>>(&raw) {
                    log::debug!("[backend]: recv method_call {:?}", single);
                    match process_single_response(&mut manager, single, max_notifs_per_subscription) {
                        // The response completed a subscription that was
                        // already cancelled; send the unsubscribe call.
                        Ok(Some(unsub)) => {
                            stop_subscription(&mut sender, &mut manager, unsub).await;
                        }
                        Ok(None) => (),
                        // Fatal: report to the frontend and shut down.
                        Err(err) => {
                            let _ = front_error.send(err);
                            return;
                        }
                    }
                }
                else if let Ok(notif) = serde_json::from_slice::<JsonRpcNotifResponse<_>>(&raw) {
                    log::debug!("[backend]: recv subscription {:?}", notif);
                    // An `Err(Some(..))` means the subscription should be shut down.
                    if let Err(Some(unsub)) = process_subscription_response(&mut manager, notif) {
                        let _ = stop_subscription(&mut sender, &mut manager, unsub).await;
                    }
                }
                else if let Ok(batch) = serde_json::from_slice::<Vec<JsonRpcResponse<_>>>(&raw) {
                    log::debug!("[backend]: recv batch {:?}", batch);
                    if let Err(e) = process_batch_response(&mut manager, batch) {
                        let _ = front_error.send(e);
                        // `break` exits the loop and thus ends the task
                        // (equivalent to `return` here).
                        break;
                    }
                }
                else if let Ok(err) = serde_json::from_slice::<JsonRpcErrorAlloc>(&raw) {
                    log::debug!("[backend]: recv error response {:?}", err);
                    if let Err(e) = process_error_response(&mut manager, err) {
                        let _ = front_error.send(e);
                        break;
                    }
                }
                // Unrecognized payload: treat as fatal and terminate.
                else {
                    log::debug!(
                        "[backend]: recv unparseable message: {:?}",
                        serde_json::from_slice::<serde_json::Value>(&raw)
                    );
                    let _ = front_error.send(Error::Custom("Unparsable response".into()));
                    return;
                }
            }
            // Transport-level read error: fatal.
            Either::Right((Some(Err(e)), _)) => {
                log::error!("Error: {:?} terminating client", e);
                let _ = front_error.send(Error::TransportError(Box::new(e)));
                return;
            }
            // Socket stream ended: fatal.
            Either::Right((None, _)) => {
                log::error!("[backend]: WebSocket receiver dropped; terminate client");
                let _ = front_error.send(Error::Custom("WebSocket receiver dropped".into()));
                return;
            }
        }
    }
}