use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::sync::atomic;
use serde;
use serde_json;
use serde_json::value::RawValue;
use super::{Request, Response};
use crate::error::Error;
use crate::util::HashableValue;
use async_trait::async_trait;
#[async_trait]
pub trait Transport: Send + Sync + 'static {
async fn send_request(&self, r: Request<'_>) -> Result<Response, Error>;
async fn send_batch(&self, rs: &[Request<'_>]) -> Result<Vec<Response>, Error>;
fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result;
}
pub struct Client {
pub(crate) transport: Box<dyn Transport>,
nonce: atomic::AtomicUsize,
}
impl Client {
    /// Creates a client that sends all of its requests over `transport`.
    ///
    /// The first request built by the new client gets the id `1`.
    pub fn with_transport<T: Transport>(transport: T) -> Client {
        Client {
            transport: Box::new(transport),
            nonce: atomic::AtomicUsize::new(1),
        }
    }

    /// Builds a JSON-RPC 2.0 request for `method` with the given raw `params`.
    ///
    /// Each call consumes one value from the client's nonce counter. Successive
    /// requests therefore carry distinct numeric `id`s until the counter wraps around.
    pub fn build_request<'a>(&self, method: &'a str, params: &'a [Box<RawValue>]) -> Request<'a> {
        // Relaxed is enough: the counter only has to hand out unique values and
        // does not order any other memory accesses.
        let nonce = self.nonce.fetch_add(1, atomic::Ordering::Relaxed);
        Request {
            method,
            params,
            id: serde_json::Value::from(nonce),
            jsonrpc: Some("2.0"),
        }
    }

    /// Sends a single request through the transport and returns its response unchecked.
    ///
    /// Unlike [`Client::call`], this does not verify the response's version or id.
    pub async fn send_request(&self, request: Request<'_>) -> Result<Response, Error> {
        self.transport.send_request(request).await
    }

    /// Sends `requests` as one batch and matches the responses back to them by `id`.
    ///
    /// The returned vector has exactly one entry per request, in request order. An
    /// entry is `None` when no response carried that request's `id`.
    ///
    /// # Errors
    ///
    /// - [`Error::EmptyBatch`] if `requests` is empty.
    /// - [`Error::WrongBatchResponseSize`] if more responses than requests come back.
    /// - [`Error::BatchDuplicateResponseId`] if two responses share an `id`.
    /// - [`Error::WrongBatchResponseId`] if a response `id` matches no request.
    /// - Any error returned by the transport.
    pub async fn send_batch(
        &self,
        requests: &[Request<'_>],
    ) -> Result<Vec<Option<Response>>, Error> {
        if requests.is_empty() {
            return Err(Error::EmptyBatch);
        }
        let responses = self.transport.send_batch(requests).await?;
        if responses.len() > requests.len() {
            return Err(Error::WrongBatchResponseSize);
        }

        // Index responses by id so they can be matched regardless of the order the
        // transport returned them in; reject duplicate ids up front.
        let mut by_id = HashMap::with_capacity(responses.len());
        for resp in responses {
            let id = HashableValue(Cow::Owned(resp.id.clone()));
            if let Some(dup) = by_id.insert(id, resp) {
                return Err(Error::BatchDuplicateResponseId(dup.id));
            }
        }

        // Pull each request's response out of the index, preserving request order.
        let results = requests
            .iter()
            .map(|r| by_id.remove(&HashableValue(Cow::Borrowed(&r.id))))
            .collect();

        // Anything still indexed answers an id we never sent; report the first one.
        if let Some((id, _)) = by_id.into_iter().next() {
            return Err(Error::WrongBatchResponseId(id.0.into_owned()));
        }
        Ok(results)
    }

    /// Calls `method` with the raw `args` and deserializes the result into `R`.
    ///
    /// # Errors
    ///
    /// - [`Error::VersionMismatch`] if the response declares a `jsonrpc` version other
    ///   than `"2.0"`. A response without a `jsonrpc` field is accepted.
    /// - [`Error::NonceMismatch`] if the response `id` differs from the request `id`.
    /// - Any transport error, or any error produced by `Response::result`.
    pub async fn call<R: for<'a> serde::de::Deserialize<'a>>(
        &self,
        method: &str,
        args: &[Box<RawValue>],
    ) -> Result<R, Error> {
        let request = self.build_request(method, args);
        // Keep a copy of the id: `request` is moved into `send_request`.
        let id = request.id.clone();
        let response = self.send_request(request).await?;
        if response.jsonrpc.is_some() && response.jsonrpc != Some(From::from("2.0")) {
            return Err(Error::VersionMismatch);
        }
        if response.id != id {
            return Err(Error::NonceMismatch);
        }
        Ok(response.result()?)
    }
}
impl fmt::Debug for Client {
    /// Renders as `jsonrpc::Client(<target>)`, where the target text comes from the transport.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("jsonrpc::Client(")?;
        self.transport.fmt_target(f)?;
        f.write_str(")")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync;

    /// Transport stub: single requests always fail and batches come back empty.
    struct DummyTransport;

    #[async_trait]
    impl Transport for DummyTransport {
        async fn send_request(&self, _: Request<'_>) -> Result<Response, Error> {
            Err(Error::NonceMismatch)
        }

        async fn send_batch(&self, _: &[Request<'_>]) -> Result<Vec<Response>, Error> {
            Ok(Vec::new())
        }

        fn fmt_target(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result {
            Ok(())
        }
    }

    /// Each built request advances the nonce by one and gets a distinct id.
    #[test]
    fn sanity() {
        let client = Client::with_transport(DummyTransport);
        let current_nonce = || client.nonce.load(sync::atomic::Ordering::Relaxed);

        assert_eq!(current_nonce(), 1);
        let first = client.build_request("test", &[]);
        assert_eq!(current_nonce(), 2);
        let second = client.build_request("test", &[]);
        assert_eq!(current_nonce(), 3);
        assert_ne!(first.id, second.id);
    }
}