1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use serde::de::DeserializeOwned;
use serde::Serialize;
use snafu::{ensure, Backtrace, ResultExt, Snafu};
use std::ops::Add;
pub use persia_rpc_macro::service;
#[derive(Snafu, Debug)]
#[snafu(visibility = "pub")]
pub enum PersiaRpcError {
#[snafu(display("serialization error"))]
SerializationFailure {
source: bincode::Error,
backtrace: Option<Backtrace>,
},
#[snafu(display("server addr parse error from {}: {}", server_addr, source))]
ServerAddrParseFailure {
server_addr: String,
source: url::ParseError,
backtrace: Option<Backtrace>,
},
#[snafu(display("transport error {}: {}", msg, source))]
TransportError {
msg: String,
source: hyper::Error,
backtrace: Option<Backtrace>,
},
#[snafu(display("transport server side error {}", msg))]
TransportServerSideError {
msg: String,
backtrace: Option<Backtrace>,
},
}
/// Client for calling bincode-over-HTTP endpoints on a single server.
///
/// Cloning is cheap: clones of a `hyper::Client` share the same connection pool, so a
/// clone can be handed to another task instead of building a new client.
#[derive(Debug, Clone)]
pub struct RpcClient {
    /// hyper client over a plain (non-TLS) `HttpConnector`.
    client: hyper::Client<hyper::client::HttpConnector>,
    /// Base URL of the server; endpoint names are resolved against it with `Url::join`.
    server_addr: url::Url,
}
/// Converts a `url::Url` into the `hyper::Uri` that hyper's request builder expects,
/// by re-parsing the URL's serialized string.
///
/// # Panics
///
/// Panics if `hyper::Uri` rejects the string produced by `url::Url`. Within this file it is
/// only called with URLs obtained from `url::Url::join`, where that is not expected to occur.
fn expect_uri(url: url::Url) -> hyper::Uri {
    let serialized: &str = url.as_str();
    let parsed = serialized.parse::<hyper::Uri>();
    parsed.expect("a parsed Url should always be a valid Uri")
}
impl RpcClient {
    /// Creates a client for the server at `server_addr`.
    ///
    /// `server_addr` is given without a scheme (e.g. `"127.0.0.1:8080"`); `http://` is
    /// prepended before parsing. The underlying hyper client is built with
    /// `http2_only(true)` over a plain HTTP connector.
    ///
    /// # Errors
    ///
    /// Returns [`PersiaRpcError::ServerAddrParseFailure`] if `http://<server_addr>` is not a
    /// valid URL.
    pub fn new(server_addr: &str) -> Result<Self, PersiaRpcError> {
        let server_addr = url::Url::parse("http://".to_string().add(server_addr).as_str())
            .context(ServerAddrParseFailure {
                server_addr: server_addr.to_string(),
            })?;
        Ok(Self {
            client: hyper::Client::builder().http2_only(true).build_http(),
            server_addr,
        })
    }

    /// Calls `endpoint_name` on the server.
    ///
    /// The method bincode-encodes `input`, POSTs it, and bincode-decodes the `200 OK`
    /// response body into `R`.
    ///
    /// `endpoint_name` is resolved with `url::Url::join` against the base address, so
    /// relative-URL rules apply. For example, a leading `/` replaces the base path.
    ///
    /// # Errors
    ///
    /// * [`PersiaRpcError::ServerAddrParseFailure`] — `endpoint_name` cannot be joined onto
    ///   the base address.
    /// * [`PersiaRpcError::SerializationFailure`] — encoding `input` or decoding the reply
    ///   fails.
    /// * [`PersiaRpcError::TransportError`] — sending the request fails, or reading the body
    ///   of a `200 OK` reply fails.
    /// * [`PersiaRpcError::TransportServerSideError`] — the server replies with any other
    ///   status. The message contains that status and the response body, decoded as lossy
    ///   UTF-8. If the body itself could not be read, a placeholder describing the read
    ///   error is used instead.
    ///
    /// # Panics
    ///
    /// Panics if the joined URL is rejected by `hyper::Uri`'s parser, or if
    /// `hyper::Request::builder` rejects the method/URI/body. Neither is expected for a URL
    /// that `url::Url` produced.
    pub async fn call_async<T: Serialize + Send + 'static, R: DeserializeOwned + Send + 'static>(
        &self,
        endpoint_name: &str,
        input: T,
    ) -> Result<R, PersiaRpcError> {
        let server_addr = self
            .server_addr
            .join(endpoint_name)
            .context(ServerAddrParseFailure {
                server_addr: endpoint_name.to_string(),
            })?;
        // bincode work runs on smol's blocking pool so large payloads don't stall the
        // async executor.
        let data = smol::unblock(move || bincode::serialize(&input))
            .await
            .context(SerializationFailure {})?;
        let req = hyper::Request::builder()
            .method("POST")
            .uri(expect_uri(server_addr))
            .body(hyper::Body::from(data))
            .expect("request builder");
        let response = self.client.request(req).await.context(TransportError {
            msg: format!("call {} error", endpoint_name),
        })?;

        let status = response.status();
        // Read the body before checking the status. On a non-OK reply, the body holds the
        // server's explanation; formatting the `hyper::Body` value itself only shows its
        // Debug form, not its contents. A read failure is reported as `TransportError` only
        // for an OK status, so a non-OK reply always maps to `TransportServerSideError`.
        let body = hyper::body::to_bytes(response.into_body()).await;
        ensure!(
            status == hyper::http::StatusCode::OK,
            TransportServerSideError {
                msg: format!(
                    "call {} server side error: {}: {}",
                    endpoint_name,
                    status,
                    match &body {
                        Ok(bytes) => String::from_utf8_lossy(bytes).into_owned(),
                        Err(e) => format!("<failed to read response body: {}>", e),
                    }
                ),
            }
        );
        let resp_bytes = body.context(TransportError {
            msg: format!("call {} recv bytes error", endpoint_name),
        })?;

        let resp: R = smol::unblock(move || bincode::deserialize(resp_bytes.as_ref()))
            .await
            .context(SerializationFailure {})?;
        Ok(resp)
    }
}