use crate::prelude::*;
use alloc::{collections::BTreeMap, sync::Arc};
use async_mutex::Mutex;
use async_std::task;
use async_trait::async_trait;
use async_tungstenite::tungstenite::{Error as WsError, Message};
use futures_channel::oneshot;
use futures_util::{
sink::{Sink, SinkExt},
stream::SplitSink,
Stream, StreamExt,
};
use jsonrpc::{
error::{result_to_response, standard_error, StandardError},
serde_json,
};
use crate::{
rpc::{self, Rpc, RpcResult},
Error,
};
/// Identifier attached to each outgoing JSON-RPC request so that its response
/// can be matched back to the waiting caller.
type Id = u8;

/// Response channels for every request id that is still awaiting an answer.
type Pending = BTreeMap<Id, oneshot::Sender<rpc::Response>>;

/// JSON-RPC client backend that writes requests into the message sink `Tx`.
pub struct Backend<Tx> {
    /// Outgoing half of the connection; the mutex serializes writes.
    tx: Mutex<Tx>,
    /// Requests awaiting a response, shared with the task that reads incoming
    /// messages.
    messages: Arc<Mutex<Pending>>,
}
#[async_trait]
impl<Tx> Rpc for Backend<Tx>
where
Tx: Sink<Message, Error = Error> + Unpin + Send,
{
async fn rpc(&self, method: &str, params: &[&str]) -> RpcResult {
let id = self.next_id().await;
log::info!("RPC `{}` (ID={})", method, id);
let (sender, recv) = oneshot::channel::<rpc::Response>();
let messages = self.messages.clone();
messages.lock().await.insert(id, sender);
let msg = serde_json::to_string(&rpc::Request {
id: id.into(),
jsonrpc: Some("2.0"),
method,
params: &Self::convert_params(params),
})
.expect("Request is serializable");
log::debug!("RPC Request {} ...", &msg[..50]);
let _ = self.tx.lock().await.send(Message::Text(msg)).await;
let res = recv
.await
.map_err(|_| standard_error(StandardError::InternalError, None))?
.result::<String>()?;
log::debug!("RPC Response: {}...", &res[..res.len().min(20)]);
let res = hex::decode(&res[2..])
.map_err(|_| standard_error(StandardError::InternalError, None))?;
Ok(res)
}
}
impl<Tx> Backend<Tx> {
async fn next_id(&self) -> Id {
self.messages.lock().await.keys().last().unwrap_or(&0) + 1
}
}
/// Byte stream carried underneath the WebSocket protocol: plain TCP.
#[cfg(not(feature = "wss"))]
type WsTransport = async_std::net::TcpStream;

/// Byte stream carried underneath the WebSocket protocol: either a plain TCP
/// stream or a TLS stream over TCP, as combined by
/// `async_tungstenite::stream::Stream`.
#[cfg(feature = "wss")]
type WsTransport = async_tungstenite::stream::Stream<
    async_std::net::TcpStream,
    async_tls::client::TlsStream<async_std::net::TcpStream>,
>;

/// Write half of a WebSocket connection, with sink errors converted into the
/// crate's `Error` type.
pub type WS2 = futures_util::sink::SinkErrInto<
    SplitSink<async_tungstenite::WebSocketStream<WsTransport>, Message>,
    Message,
    Error,
>;
impl Backend<WS2> {
pub async fn new_ws2(url: &str) -> core::result::Result<Self, WsError> {
log::trace!("WS connecting to {}", url);
let (stream, _) = async_tungstenite::async_std::connect_async(url).await?;
let (tx, rx) = stream.split();
let backend = Backend {
tx: Mutex::new(tx.sink_err_into()),
messages: Arc::new(Mutex::new(BTreeMap::new())),
};
backend.process_incoming_messages(rx);
Ok(backend)
}
fn process_incoming_messages<Rx>(&self, mut rx: Rx)
where
Rx: Stream<Item = core::result::Result<Message, WsError>> + Unpin + Send + 'static,
{
let messages = self.messages.clone();
task::spawn(async move {
while let Some(msg) = rx.next().await {
match msg {
Ok(msg) => {
log::trace!("Got WS message {}", msg);
if let Ok(msg) = msg.to_text() {
let res: rpc::Response =
serde_json::from_str(msg).unwrap_or_else(|_| {
result_to_response(
Err(standard_error(StandardError::ParseError, None)),
().into(),
)
});
if res.id.is_u64() {
let id = res.id.as_u64().unwrap() as Id;
log::trace!("Answering request {}", id);
let mut messages = messages.lock().await;
if let Some(channel) = messages.remove(&id) {
channel.send(res).expect("receiver waiting");
log::debug!("Answered request id: {}", id);
}
}
}
}
Err(err) => {
log::warn!("WS Error: {}", err);
}
}
}
log::warn!("WS connection closed");
});
}
}