pub mod pool;
#[cfg(test)]
mod tests;
#[cfg(feature = "metrics")]
mod metrics;
use std::{
borrow::Cow,
convert::TryFrom,
fmt,
future::Future,
marker::PhantomData,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::{Duration, Instant},
};
use bytes::Bytes;
use futures::{
FutureExt,
SinkExt,
StreamExt,
future,
future::{BoxFuture, Either},
task::{Context, Poll},
};
use log::*;
use prost::Message;
use tari_shutdown::{Shutdown, ShutdownSignal, oneshot_trigger::OneshotSignal};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{Mutex, mpsc, oneshot, watch},
time,
};
use tower::{Service, ServiceExt};
use tracing::{Instrument, Level, span};
use super::message::RpcMethod;
use crate::{
framing::CanonicalFraming,
message::MessageExt,
peer_manager::NodeId,
proto,
protocol::{
ProtocolId,
rpc,
rpc::{
Handshake,
NamedProtocolService,
Response,
RpcError,
RpcServerError,
RpcStatus,
body::ClientStreaming,
message::{BaseRequest, RpcMessageFlags},
},
},
stream_id,
stream_id::StreamId,
};
/// Log target for all output emitted by this RPC client module.
const LOG_TARGET: &str = "comms::rpc::client";
/// Handle to an active RPC client session. Cheap to clone: every clone shares the
/// same worker task (and therefore the same substream) via the inner connector.
#[derive(Clone)]
pub struct RpcClient {
    connector: ClientConnector,
}
impl RpcClient {
    /// Returns a client builder pre-loaded with `T`'s protocol name.
    pub fn builder<T>() -> RpcClientBuilder<T>
    where T: NamedProtocolService {
        RpcClientBuilder::new().with_protocol_id(T::PROTOCOL_NAME.into())
    }

    /// Establishes an RPC session over `framed`: spawns the worker task that owns
    /// the substream, then resolves once the client handshake has succeeded or failed.
    ///
    /// - `terminate_signal`: optional signal that stops the worker when it resolves
    ///   (used when the peer connection drops).
    /// - `session_state`: shared flag; the worker stores `false` into it on termination.
    pub async fn connect<TSubstream>(
        config: RpcClientConfig,
        node_id: NodeId,
        framed: CanonicalFraming<TSubstream>,
        protocol_name: ProtocolId,
        terminate_signal: Option<OneshotSignal<NodeId>>,
        session_state: Arc<AtomicBool>,
    ) -> Result<Self, RpcError>
    where
        TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId + 'static,
    {
        trace!(target: LOG_TARGET, "connect to {node_id:?} with {config:?}");
        // Capacity 1: the worker loop handles one request at a time, so extra
        // buffering would add nothing.
        let (request_tx, request_rx) = mpsc::channel(1);
        let shutdown = Shutdown::new();
        let shutdown_signal = shutdown.to_signal();
        // Watch channel publishing the most recent request-latency estimate.
        let (last_request_latency_tx, last_request_latency_rx) = watch::channel(None);
        let connector = ClientConnector::new(request_tx, last_request_latency_rx, shutdown);
        // Used exactly once by the worker to report the handshake outcome.
        let (ready_tx, ready_rx) = oneshot::channel();
        // Link the spawned worker's span to the caller's current trace, if any.
        let tracing_id = tracing::Span::current().id();
        tokio::spawn({
            let span = span!(Level::TRACE, "start_rpc_worker");
            span.follows_from(tracing_id);
            RpcClientWorker::new(
                config,
                node_id,
                request_rx,
                last_request_latency_tx,
                framed,
                ready_tx,
                protocol_name,
                shutdown_signal,
                terminate_signal,
                session_state,
            )
            .run()
            .instrument(span)
        });
        // The worker sends exactly one handshake result before dropping ready_tx,
        // so the expect() below documents an invariant rather than a recoverable error.
        ready_rx
            .await
            .expect("ready_rx oneshot is never dropped without a reply")?;
        Ok(Self { connector })
    }

    /// Sends a unary request and decodes the single response message.
    ///
    /// Returns `RpcError::ServerClosedRequest` when the response channel yields
    /// nothing, otherwise propagates the server's `RpcStatus` or a decode failure.
    pub async fn request_response<T, R, M>(&mut self, request: T, method: M) -> Result<R, RpcError>
    where
        T: prost::Message,
        R: prost::Message + Default + std::fmt::Debug,
        M: Into<RpcMethod>,
    {
        let req_bytes = request.to_encoded_bytes();
        let request = BaseRequest::new(method.into(), req_bytes.into());
        let mut resp = self.call_inner(request).await?;
        // First `?` handles the RpcStatus carried in the channel; the outer
        // `ok_or` handles the channel closing without any message.
        let resp = resp.recv().await.ok_or(RpcError::ServerClosedRequest)??;
        let resp = R::decode(resp.into_message())?;
        Ok(resp)
    }

    /// Sends a request and returns a typed stream of decoded response messages.
    pub async fn server_streaming<T, M, R>(&mut self, request: T, method: M) -> Result<ClientStreaming<R>, RpcError>
    where
        T: prost::Message,
        R: prost::Message + Default,
        M: Into<RpcMethod>,
    {
        let req_bytes = request.to_encoded_bytes();
        let request = BaseRequest::new(method.into(), req_bytes.into());
        let resp = self.call_inner(request).await?;
        Ok(ClientStreaming::new(resp))
    }

    /// Triggers shutdown of this session's worker task.
    pub async fn close(&mut self) {
        self.connector.close().await;
    }

    /// True while the worker is still accepting requests.
    pub fn is_connected(&self) -> bool {
        self.connector.is_connected()
    }

    /// Most recent latency estimate, if any request (or the handshake) has completed.
    pub fn get_last_request_latency(&mut self) -> Option<Duration> {
        self.connector.get_last_request_latency()
    }

    /// Sends a ping; resolves with the measured round-trip time.
    pub fn ping(&mut self) -> impl Future<Output = Result<Duration, RpcError>> + '_ {
        self.connector.send_ping()
    }

    /// Routes a raw request through the connector `Service`, returning the raw
    /// response channel shared by the unary and streaming call paths.
    async fn call_inner(
        &mut self,
        request: BaseRequest<Bytes>,
    ) -> Result<mpsc::Receiver<Result<Response<Bytes>, RpcStatus>>, RpcError> {
        let svc = self.connector.ready().await?;
        let resp = svc.call(request).await?;
        Ok(resp)
    }
}
impl fmt::Debug for RpcClient {
    /// Opaque debug form: the inner connector holds channels with no useful `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RpcClient { inner: ... }")
    }
}
/// Builder for `RpcClient` and the typed client wrappers derived from it.
#[derive(Debug, Clone)]
pub struct RpcClientBuilder<TClient> {
    config: RpcClientConfig,
    // Protocol to negotiate; falls back to `TClient::PROTOCOL_NAME` when unset.
    protocol_id: Option<ProtocolId>,
    node_id: Option<NodeId>,
    // When this signal resolves, the worker terminates (see RpcClientWorker).
    terminate_signal: Option<OneshotSignal<NodeId>>,
    // Shared flag stored as `false` by the worker when the session ends.
    session_state: Option<Arc<AtomicBool>>,
    // Ties the builder to the typed client it produces; stores nothing at runtime.
    _client: PhantomData<TClient>,
}
impl<TClient> Default for RpcClientBuilder<TClient> {
fn default() -> Self {
Self {
config: Default::default(),
protocol_id: None,
node_id: None,
terminate_signal: None,
session_state: None,
_client: PhantomData,
}
}
}
impl<TClient> RpcClientBuilder<TClient> {
    /// Returns a builder with all options unset and default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the per-request deadline sent to the server.
    pub fn with_deadline(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.config.deadline = Some(timeout);
        builder
    }

    /// Sets the extra time allowed beyond the deadline before giving up on a reply.
    pub fn with_deadline_grace_period(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.config.deadline_grace_period = timeout;
        builder
    }

    /// Sets the maximum time allowed for the session handshake.
    pub fn with_handshake_timeout(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.config.handshake_timeout = timeout;
        builder
    }

    /// Overrides the protocol id to negotiate.
    pub fn with_protocol_id(self, protocol_id: ProtocolId) -> Self {
        let mut builder = self;
        builder.protocol_id = Some(protocol_id);
        builder
    }

    /// Sets the peer node id associated with this session.
    pub fn with_node_id(self, node_id: NodeId) -> Self {
        let mut builder = self;
        builder.node_id = Some(node_id);
        builder
    }

    /// Supplies a signal that terminates the worker when it resolves.
    pub fn with_terminate_signal(self, terminate_signal: OneshotSignal<NodeId>) -> Self {
        let mut builder = self;
        builder.terminate_signal = Some(terminate_signal);
        builder
    }

    /// Supplies a shared flag the worker clears when the session ends.
    pub fn with_session_state(self, session_state: Arc<AtomicBool>) -> Self {
        let mut builder = self;
        builder.session_state = Some(session_state);
        builder
    }
}
impl<TClient> RpcClientBuilder<TClient>
where TClient: From<RpcClient> + NamedProtocolService
{
    /// Consumes the builder, performs the RPC handshake over `framed` and returns
    /// the typed client on success.
    ///
    /// Defaults applied when options are unset: `TClient::PROTOCOL_NAME` for the
    /// protocol id, a default `NodeId`, and a fresh "session live" flag.
    pub async fn connect<TSubstream>(self, framed: CanonicalFraming<TSubstream>) -> Result<TClient, RpcError>
    where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId + 'static {
        RpcClient::connect(
            self.config,
            self.node_id.unwrap_or_default(),
            framed,
            // `self` is consumed, so the protocol id can be moved out directly —
            // the previous `.as_ref().cloned()` copy was redundant.
            self.protocol_id
                .unwrap_or_else(|| ProtocolId::from_static(TClient::PROTOCOL_NAME)),
            self.terminate_signal,
            // `unwrap_or_else` defers the Arc allocation to the case where no
            // shared session state was provided (clippy: or_fun_call).
            self.session_state
                .unwrap_or_else(|| Arc::new(AtomicBool::new(true))),
        )
        .await
        .map(Into::into)
    }
}
/// Timing configuration for client RPC sessions.
#[derive(Debug, Clone, Copy)]
pub struct RpcClientConfig {
    /// Per-request deadline sent to the server; `None` disables reply timeouts.
    pub deadline: Option<Duration>,
    /// Extra time allowed beyond `deadline` before the client gives up on a reply.
    pub deadline_grace_period: Duration,
    /// Maximum time allowed for the initial session handshake.
    pub handshake_timeout: Duration,
}
impl RpcClientConfig {
    /// Total time to wait for a reply: deadline plus grace period, or `None`
    /// when no deadline is configured.
    pub fn timeout_with_grace_period(&self) -> Option<Duration> {
        let deadline = self.deadline?;
        Some(deadline + self.deadline_grace_period)
    }

    /// Maximum time allowed for the initial session handshake.
    pub fn handshake_timeout(&self) -> Duration {
        self.handshake_timeout
    }
}
impl Default for RpcClientConfig {
    /// Defaults: 120 s deadline, 60 s grace period, 90 s handshake timeout.
    fn default() -> Self {
        Self {
            handshake_timeout: Duration::from_secs(90),
            deadline_grace_period: Duration::from_secs(60),
            deadline: Some(Duration::from_secs(120)),
        }
    }
}
/// Clonable handle used to submit requests to the `RpcClientWorker` task.
#[derive(Clone)]
pub struct ClientConnector {
    inner: mpsc::Sender<ClientRequest>,
    last_request_latency_rx: watch::Receiver<Option<Duration>>,
    // Shared so that any clone can trigger shutdown of the single worker.
    shutdown: Arc<Mutex<Shutdown>>,
}
impl ClientConnector {
    /// Wraps the worker-facing channels into a clonable connector handle.
    pub(self) fn new(
        sender: mpsc::Sender<ClientRequest>,
        last_request_latency_rx: watch::Receiver<Option<Duration>>,
        shutdown: Shutdown,
    ) -> Self {
        Self {
            inner: sender,
            last_request_latency_rx,
            shutdown: Arc::new(Mutex::new(shutdown)),
        }
    }

    /// Triggers shutdown of the worker task for this session.
    pub async fn close(&mut self) {
        self.shutdown.lock().await.trigger();
    }

    /// Most recently published request-latency estimate, if any.
    pub fn get_last_request_latency(&mut self) -> Option<Duration> {
        let latest = self.last_request_latency_rx.borrow();
        *latest
    }

    /// Sends a ping request to the worker and waits for the measured round trip.
    pub async fn send_ping(&mut self) -> Result<Duration, RpcError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.inner
            .send(ClientRequest::SendPing(reply_tx))
            .await
            .map_err(|_| RpcError::ClientClosed)?;
        // Outer error: worker dropped the reply sender. Inner error: RpcStatus
        // from the ping itself, converted into RpcError by `?`.
        let outcome = reply_rx.await.map_err(|_| RpcError::RequestCancelled)?;
        Ok(outcome?)
    }

    /// True while the worker is still receiving requests.
    pub fn is_connected(&self) -> bool {
        !self.inner.is_closed()
    }
}
impl fmt::Debug for ClientConnector {
    /// Opaque debug form: the channel internals have no useful `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ClientConnector { inner: ... }")
    }
}
impl Service<BaseRequest<Bytes>> for ClientConnector {
    type Error = RpcError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
    type Response = mpsc::Receiver<Result<Response<Bytes>, RpcStatus>>;

    /// Always ready: readiness is not tracked here, the bounded request channel
    /// awaited in `call` provides the backpressure.
    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Queues the request with the worker; resolves with the channel on which
    /// the response message(s) will arrive.
    fn call(&mut self, request: BaseRequest<Bytes>) -> Self::Future {
        let (reply, reply_rx) = oneshot::channel();
        let request_tx = self.inner.clone();
        Box::pin(async move {
            if request_tx.send(ClientRequest::SendRequest { request, reply }).await.is_err() {
                return Err(RpcError::ClientClosed);
            }
            reply_rx.await.map_err(|_| RpcError::RequestCancelled)
        })
    }
}
/// Task that owns the framed substream and services `ClientRequest`s sequentially.
struct RpcClientWorker<TSubstream> {
    config: RpcClientConfig,
    node_id: NodeId,
    request_rx: mpsc::Receiver<ClientRequest>,
    last_request_latency_tx: watch::Sender<Option<Duration>>,
    framed: CanonicalFraming<TSubstream>,
    // Wrapping counter; 0 is skipped because pings use request id 0.
    next_request_id: u16,
    // Consumed exactly once to report the handshake outcome.
    ready_tx: Option<oneshot::Sender<Result<(), RpcError>>>,
    protocol_id: ProtocolId,
    shutdown_signal: ShutdownSignal,
    terminate_signal: Option<OneshotSignal<NodeId>>,
    // Stored as `false` when the worker terminates.
    session_state: Arc<AtomicBool>,
}
impl<TSubstream> RpcClientWorker<TSubstream>
where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
{
    /// Creates a worker for one RPC session. `ready_tx` is used exactly once to
    /// report the handshake outcome back to `RpcClient::connect`.
    pub(self) fn new(
        config: RpcClientConfig,
        node_id: NodeId,
        request_rx: mpsc::Receiver<ClientRequest>,
        last_request_latency_tx: watch::Sender<Option<Duration>>,
        framed: CanonicalFraming<TSubstream>,
        ready_tx: oneshot::Sender<Result<(), RpcError>>,
        protocol_id: ProtocolId,
        shutdown_signal: ShutdownSignal,
        terminate_signal: Option<OneshotSignal<NodeId>>,
        session_state: Arc<AtomicBool>,
    ) -> Self {
        Self {
            config,
            node_id,
            request_rx,
            framed,
            next_request_id: 0,
            ready_tx: Some(ready_tx),
            last_request_latency_tx,
            protocol_id,
            shutdown_signal,
            terminate_signal,
            session_state,
        }
    }

    /// Lossy UTF-8 rendering of the protocol id, for log output.
    fn protocol_name(&self) -> Cow<'_, str> {
        String::from_utf8_lossy(&self.protocol_id)
    }

    /// Identifier of the underlying substream, used to correlate log lines.
    fn stream_id(&self) -> stream_id::Id {
        self.framed.stream_id()
    }

    /// Main loop: performs the client handshake (reporting the result on `ready_tx`),
    /// then serves requests until shutdown/termination, the request channel closing,
    /// or an unrecoverable error. Always clears `session_state` and closes the
    /// substream on the way out.
    #[allow(clippy::too_many_lines)]
    async fn run(mut self) {
        debug!(
            target: LOG_TARGET,
            "(stream={}) Performing client handshake for '{}'",
            self.stream_id(),
            self.protocol_name()
        );
        let start = Instant::now();
        let mut handshake = Handshake::new(&mut self.framed).with_timeout(self.config.handshake_timeout());
        match handshake.perform_client_handshake().await {
            Ok(_) => {
                // The handshake round trip doubles as the initial latency estimate.
                let latency = start.elapsed();
                debug!(
                    target: LOG_TARGET,
                    "(stream={}) RPC Session ({}) negotiation completed. Latency: {:.0?}",
                    self.stream_id(),
                    self.protocol_name(),
                    latency
                );
                let _ = self.last_request_latency_tx.send(Some(latency));
                if let Some(r) = self.ready_tx.take() {
                    let _result = r.send(Ok(()));
                }
                #[cfg(feature = "metrics")]
                metrics::handshake_counter(&self.protocol_id).inc();
            },
            Err(err) => {
                #[cfg(feature = "metrics")]
                metrics::handshake_errors(&self.protocol_id).inc();
                if let Some(r) = self.ready_tx.take() {
                    let _result = r.send(Err(err.into()));
                }
                // No session without a successful handshake.
                return;
            },
        }
        // When no terminate signal was supplied, substitute a never-resolving
        // future so the select! arms below have a uniform shape.
        let mut terminate_signal = self
            .terminate_signal
            .take()
            .map(|f| f.boxed())
            .unwrap_or_else(|| future::pending::<Option<NodeId>>().boxed());
        #[cfg(feature = "metrics")]
        metrics::num_sessions(&self.protocol_id).inc();
        loop {
            tokio::select! {
                // `biased` polls arms top-down: shutdown and termination take
                // priority over accepting another request.
                biased;
                _ = &mut self.shutdown_signal => {
                    break;
                }
                node_id = &mut terminate_signal => {
                    debug!(
                        target: LOG_TARGET, "(stream={}) Peer '{}' connection has dropped. Worker is terminating.",
                        self.stream_id(), node_id.unwrap_or_default()
                    );
                    break;
                }
                req = self.request_rx.recv() => {
                    match req {
                        Some(req) => {
                            // Requests are serviced one at a time; any error
                            // returned here is fatal for the whole session.
                            if let Err(err) = self.handle_request(req).await {
                                #[cfg(feature = "metrics")]
                                metrics::client_errors(&self.protocol_id).inc();
                                info!(
                                    target: LOG_TARGET,
                                    "(stream={}) Unexpected error: {}. Worker is terminating.",
                                    self.stream_id(), err
                                );
                                break;
                            }
                        }
                        None => {
                            // All ClientConnector handles have been dropped.
                            debug!(
                                target: LOG_TARGET,
                                "(stream={}) Request channel closed. Worker is terminating.",
                                self.stream_id()
                            );
                            break
                        },
                    }
                }
            }
        }
        #[cfg(feature = "metrics")]
        metrics::num_sessions(&self.protocol_id).dec();
        // Mark the session as no longer live for whoever shares this flag.
        let session_state = self.session_state.as_ref();
        session_state.store(false, Ordering::Relaxed);
        if let Err(err) = self.framed.close().await {
            debug!(
                target: LOG_TARGET,
                "(stream: {}, peer: {}) IO Error when closing substream: {}",
                self.stream_id(),
                self.node_id,
                err
            );
        }
        debug!(
            target: LOG_TARGET,
            "(stream: {}, peer: {}) RpcClientWorker ({}) terminated.",
            self.stream_id(),
            self.node_id,
            self.protocol_name()
        );
    }

    /// Dispatches a single queued client request to the appropriate handler.
    async fn handle_request(&mut self, req: ClientRequest) -> Result<(), RpcError> {
        use ClientRequest::{SendPing, SendRequest};
        match req {
            SendRequest { request, reply } => {
                self.do_request_response(request, reply).await?;
            },
            SendPing(reply) => {
                self.do_ping_pong(reply).await?;
            },
        }
        Ok(())
    }

    /// Sends an ACK-flagged request (a ping) and waits for the matching ACK
    /// response, reporting the round-trip time (or an `RpcStatus` error) on `reply`.
    async fn do_ping_pong(&mut self, reply: oneshot::Sender<Result<Duration, RpcStatus>>) -> Result<(), RpcError> {
        let ack = proto::rpc::RpcRequest {
            flags: u32::from(RpcMessageFlags::ACK.bits()),
            // 0 when no deadline is configured — presumably "no deadline" on the
            // wire; TODO confirm server-side interpretation.
            deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
            ..Default::default()
        };
        let start = Instant::now();
        self.framed.send(ack.to_encoded_bytes().into()).await?;
        trace!(
            target: LOG_TARGET,
            "(stream={}) Ping (protocol {}) sent in {:.2?}",
            self.stream_id(),
            self.protocol_name(),
            start.elapsed()
        );
        // Pings use request id 0, which next_request_id() never issues.
        let mut reader = RpcResponseReader::new(&mut self.framed, self.config, 0);
        let resp = match reader.read_ack().await {
            Ok(resp) => resp,
            Err(RpcError::ReplyTimeout) => {
                debug!(
                    target: LOG_TARGET,
                    "(stream={}) Ping timed out after {:.0?}",
                    self.stream_id(),
                    start.elapsed()
                );
                #[cfg(feature = "metrics")]
                metrics::client_timeouts(&self.protocol_id).inc();
                let _result = reply.send(Err(RpcStatus::timed_out("Response timed out")));
                // A timed-out ping is reported to the caller but does not kill
                // the session.
                return Ok(());
            },
            Err(err) => return Err(err),
        };
        let status = RpcStatus::from(&resp);
        if !status.is_ok() {
            let _result = reply.send(Err(status.clone()));
            return Err(status.into());
        }
        let resp_flags =
            RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
                RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
            })?)
            .ok_or(RpcStatus::protocol_error(&format!(
                "invalid message flag, does not match any flags ({})",
                resp.flags
            )))?;
        // A ping reply must itself carry the ACK flag.
        if !resp_flags.contains(RpcMessageFlags::ACK) {
            warn!(
                target: LOG_TARGET,
                "(stream={}) Invalid ping response {:?}",
                self.stream_id(),
                resp
            );
            let _result = reply.send(Err(RpcStatus::protocol_error(&format!(
                "Received invalid ping response on protocol '{}'",
                self.protocol_name()
            ))));
            return Err(RpcError::InvalidPingResponse);
        }
        let _result = reply.send(Ok(start.elapsed()));
        Ok(())
    }

    /// Sends one request and forwards its (possibly streamed) responses onto a
    /// channel handed back through `reply`. Returns `Err` only for errors fatal
    /// to the session; per-request failures go down the response channel instead.
    #[allow(clippy::too_many_lines)]
    async fn do_request_response(
        &mut self,
        request: BaseRequest<Bytes>,
        reply: oneshot::Sender<mpsc::Receiver<Result<Response<Bytes>, RpcStatus>>>,
    ) -> Result<(), RpcError> {
        #[cfg(feature = "metrics")]
        metrics::outbound_request_bytes(&self.protocol_id).observe(request.get_ref().len() as f64);
        let request_id = self.next_request_id();
        let method = request.method.into();
        let req = proto::rpc::RpcRequest {
            request_id: u32::from(request_id),
            method,
            deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
            flags: 0,
            payload: request.message.to_vec(),
        };
        trace!(target: LOG_TARGET, "Sending request: {req}");
        // Advisory early check only; the definitive cancellation check is the
        // reply.send() below.
        if reply.is_closed() {
            warn!(
                target: LOG_TARGET,
                "Client request was cancelled before request was sent"
            );
        }
        let (response_tx, response_rx) = mpsc::channel(5);
        if let Err(mut rx) = reply.send(response_rx) {
            // NOTE(review): the wire request has NOT been sent at this point —
            // the log text below appears to predate the current control flow; confirm.
            warn!(
                target: LOG_TARGET,
                "Client request was cancelled after request was sent. This means that you are making an RPC request \
                 and then immediately dropping the response! (protocol = {})",
                self.protocol_name(),
            );
            rx.close();
            // Caller is gone: skip sending the request entirely.
            return Ok(());
        }
        #[cfg(feature = "metrics")]
        let latency = metrics::request_response_latency(&self.protocol_id);
        #[cfg(feature = "metrics")]
        let mut metrics_timer = Some(latency.start_timer());
        let timer = Instant::now();
        if let Err(err) = self.send_request(req).await {
            warn!(target: LOG_TARGET, "{err}");
            #[cfg(feature = "metrics")]
            metrics::client_errors(&self.protocol_id).inc();
            // Surface the send failure to the caller via the response channel.
            let _result = response_tx.send(Err(err.into())).await;
            return Ok(());
        }
        // Time spent sending; added to time-to-first-message to form the latency
        // estimate published below.
        let partial_latency = timer.elapsed();
        loop {
            if self.shutdown_signal.is_triggered() {
                debug!(
                    target: LOG_TARGET,
                    "[peer: {}, protocol: {}, stream_id: {}, req_id: {}] Client connector closed. Quitting stream \
                     early",
                    self.node_id,
                    self.protocol_name(),
                    self.stream_id(),
                    request_id
                );
                break;
            }
            // Race the next response frame against the caller dropping the receiver.
            let resp_result = {
                let resp_fut = self.read_response(request_id);
                tokio::pin!(resp_fut);
                let closed_fut = response_tx.closed();
                tokio::pin!(closed_fut);
                match future::select(resp_fut, closed_fut).await {
                    Either::Left((r, _)) => Some(r),
                    Either::Right(_) => None,
                }
            };
            let resp_result = match resp_result {
                Some(r) => r,
                None => {
                    // Receiver dropped mid-stream: ask the server to stop (FIN).
                    self.premature_close(request_id, method).await?;
                    break;
                },
            };
            let resp = match resp_result {
                Ok((resp, time_to_first_msg)) => {
                    if let Some(t) = time_to_first_msg {
                        let _ = self.last_request_latency_tx.send(Some(partial_latency + t));
                    }
                    trace!(
                        target: LOG_TARGET,
                        "Received response ({} byte(s)) from request #{} (protocol = {}, method={})",
                        resp.payload.len(),
                        request_id,
                        self.protocol_name(),
                        method,
                    );
                    #[cfg(feature = "metrics")]
                    if let Some(t) = metrics_timer.take() {
                        t.observe_duration();
                    }
                    resp
                },
                Err(RpcError::ReplyTimeout) => {
                    debug!(
                        target: LOG_TARGET,
                        "Request {request_id} (method={method}) timed out"
                    );
                    #[cfg(feature = "metrics")]
                    metrics::client_timeouts(&self.protocol_id).inc();
                    if response_tx.is_closed() {
                        self.premature_close(request_id, method).await?;
                    } else {
                        let _result = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await;
                    }
                    break;
                },
                Err(RpcError::ClientClosed) => {
                    debug!(
                        target: LOG_TARGET,
                        "Request {request_id} (method={method}) was closed (read_reply)"
                    );
                    // Stop accepting new requests; run() observes the closed channel.
                    self.request_rx.close();
                    break;
                },
                Err(err) => {
                    return Err(err);
                },
            };
            match Self::convert_to_result(resp) {
                Ok(Ok(resp)) => {
                    let is_finished = resp.is_finished();
                    if response_tx.is_closed() {
                        self.premature_close(request_id, method).await?;
                        break;
                    } else {
                        let _result = response_tx.send(Ok(resp)).await;
                    }
                    // Finished flag: server has completed this (possibly streamed) response.
                    if is_finished {
                        break;
                    }
                },
                Ok(Err(err)) => {
                    debug!(target: LOG_TARGET, "Remote service returned error: {err}");
                    if !response_tx.is_closed() {
                        let _result = response_tx.send(Err(err)).await;
                    }
                    break;
                },
                // Stale/mismatched frames are logged and skipped rather than
                // failing the whole session.
                Err(err @ RpcError::ResponseIdDidNotMatchRequest { .. }) |
                Err(err @ RpcError::UnexpectedAckResponse) => {
                    warn!(target: LOG_TARGET, "{err}");
                    continue;
                },
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }

    /// Best-effort notification that the caller abandoned the stream: sends a
    /// FIN-flagged request for `request_id`, waiting at most 2 seconds.
    async fn premature_close(&mut self, request_id: u16, method: u32) -> Result<(), RpcError> {
        info!(
            target: LOG_TARGET,
            "(stream={}) Response receiver was dropped before the response/stream could complete for protocol {}, \
             interrupting the stream. ",
            self.stream_id(),
            self.protocol_name()
        );
        let req = proto::rpc::RpcRequest {
            request_id: u32::from(request_id),
            method,
            flags: RpcMessageFlags::FIN.bits().into(),
            deadline: self.config.deadline.map(|d| d.as_secs()).unwrap_or(0),
            ..Default::default()
        };
        // If the send itself times out the interrupt is silently abandoned;
        // an actual send error is still propagated.
        if let Ok(res) = time::timeout(Duration::from_secs(2), self.send_request(req)).await {
            res?;
        }
        Ok(())
    }

    /// Encodes and sends one request frame, enforcing the maximum request size.
    async fn send_request(&mut self, req: proto::rpc::RpcRequest) -> Result<(), RpcError> {
        let payload = req.to_encoded_bytes();
        if payload.len() > rpc::max_request_size() {
            return Err(RpcError::MaxRequestSizeExceeded {
                got: payload.len(),
                expected: rpc::max_request_size(),
            });
        }
        self.framed.send(payload.into()).await?;
        Ok(())
    }

    /// Reads the next response for `request_id`, returning it with the
    /// time-to-first-message. Frames carrying the immediately-previous request id
    /// are treated as delayed responses and skipped, up to a fixed limit.
    async fn read_response(
        &mut self,
        request_id: u16,
    ) -> Result<(proto::rpc::RpcResponse, Option<Duration>), RpcError> {
        let stream_id = self.stream_id();
        let protocol_name = self.protocol_name().to_string();
        let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id);
        let mut num_ignored = 0;
        let resp = loop {
            match reader.read_response().await {
                Ok(resp) => {
                    trace!(
                        target: LOG_TARGET,
                        "(stream: {}, {}) Received body len = {}",
                        stream_id,
                        protocol_name,
                        reader.bytes_read()
                    );
                    #[cfg(feature = "metrics")]
                    metrics::inbound_response_bytes(&self.protocol_id).observe(reader.bytes_read() as f64);
                    let time_to_first_msg = reader.time_to_first_msg();
                    break (resp, time_to_first_msg);
                },
                // `actual + 1 == request_id` means the frame belongs to the
                // previous request on this session (e.g. a late streamed chunk).
                Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected })
                    if actual.wrapping_add(1) == request_id =>
                {
                    warn!(
                        target: LOG_TARGET,
                        "Possible delayed response received for previous request {actual}"
                    );
                    num_ignored += 1;
                    // Bail out rather than spinning forever on a misbehaving server.
                    const MAX_ALLOWED_IGNORED: usize = 20;
                    if num_ignored > MAX_ALLOWED_IGNORED {
                        return Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected });
                    }
                    continue;
                },
                Err(err) => return Err(err),
            }
        };
        Ok(resp)
    }

    /// Returns the next request id, wrapping at `u16::MAX` and never yielding 0
    /// (id 0 is reserved for ping/ACK exchanges — see `do_ping_pong`).
    fn next_request_id(&mut self) -> u16 {
        let mut next_id = self.next_request_id;
        self.next_request_id = self.next_request_id.wrapping_add(1);
        if next_id == 0 {
            next_id += 1;
            self.next_request_id += 1;
        }
        next_id
    }

    /// Splits a wire response into: session-fatal errors (outer `Err`),
    /// per-request status failures (inner `Err`), or a successful payload.
    fn convert_to_result(resp: proto::rpc::RpcResponse) -> Result<Result<Response<Bytes>, RpcStatus>, RpcError> {
        let status = RpcStatus::from(&resp);
        if !status.is_ok() {
            return Ok(Err(status));
        }
        let flags = match resp.flags() {
            Ok(flags) => flags,
            // Invalid flags are reported as a per-request status error rather
            // than tearing down the session.
            Err(e) => return Ok(Err(RpcError::ServerError(RpcServerError::ProtocolError(e)).into())),
        };
        let resp = Response {
            flags,
            payload: resp.payload.into(),
        };
        Ok(Ok(resp))
    }
}
/// Messages sent from `ClientConnector` handles to the worker task.
pub enum ClientRequest {
    /// Dispatch an RPC request; the worker replies with a receiver yielding one
    /// or more response chunks (streaming) or a status error.
    SendRequest {
        request: BaseRequest<Bytes>,
        reply: oneshot::Sender<mpsc::Receiver<Result<Response<Bytes>, RpcStatus>>>,
    },
    /// Measure round-trip latency; the worker replies with the elapsed time.
    SendPing(oneshot::Sender<Result<Duration, RpcStatus>>),
}
/// Reads and validates response frames for a single request id from the substream.
struct RpcResponseReader<'a, TSubstream> {
    framed: &'a mut CanonicalFraming<TSubstream>,
    config: RpcClientConfig,
    // Expected request id; mismatching frames are rejected (0 is used for pings).
    request_id: u16,
    // Payload size of the most recently read response.
    bytes_read: usize,
    // Elapsed time until the first frame was received, once one has been read.
    time_to_first_msg: Option<Duration>,
}
impl<'a, TSubstream> RpcResponseReader<'a, TSubstream>
where TSubstream: AsyncRead + AsyncWrite + Unpin
{
    /// Creates a reader that validates frames against `request_id` and applies
    /// the deadline (plus grace period) from `config`.
    pub fn new(framed: &'a mut CanonicalFraming<TSubstream>, config: RpcClientConfig, request_id: u16) -> Self {
        Self {
            framed,
            config,
            request_id,
            bytes_read: 0,
            time_to_first_msg: None,
        }
    }

    /// Payload size (bytes) of the most recently read response.
    pub fn bytes_read(&self) -> usize {
        self.bytes_read
    }

    /// Time from calling `read_response` until its frame arrived, if one was read.
    pub fn time_to_first_msg(&self) -> Option<Duration> {
        self.time_to_first_msg
    }

    /// Reads and validates the next response frame for this reader's request id.
    pub async fn read_response(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
        let timer = Instant::now();
        let resp = self.next().await?;
        self.time_to_first_msg = Some(timer.elapsed());
        self.check_response(&resp)?;
        self.bytes_read = resp.payload.len();
        trace!(
            target: LOG_TARGET,
            "Received {} bytes in {:.2?}",
            resp.payload.len(),
            self.time_to_first_msg.unwrap_or_default()
        );
        Ok(resp)
    }

    /// Reads the next frame without id/flag validation — used for ping (ACK)
    /// replies, which carry request id 0 and the ACK flag.
    pub async fn read_ack(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
        let resp = self.next().await?;
        Ok(resp)
    }

    /// Validates that a response carries parseable flags, is not an unexpected
    /// ACK, and matches this reader's request id.
    fn check_response(&self, resp: &proto::rpc::RpcResponse) -> Result<(), RpcError> {
        let resp_id = u16::try_from(resp.request_id)
            .map_err(|_| RpcStatus::protocol_error(&format!("invalid request_id: must be less than {}", u16::MAX)))?;
        let flags =
            RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
                RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
            })?)
            .ok_or(RpcStatus::protocol_error(&format!(
                "invalid message flag, does not match any flags ({})",
                resp.flags
            )))?;
        // ACK frames are only valid in the ping path (read_ack).
        if flags.contains(RpcMessageFlags::ACK) {
            return Err(RpcError::UnexpectedAckResponse);
        }
        if resp_id != self.request_id {
            return Err(RpcError::ResponseIdDidNotMatchRequest {
                expected: self.request_id,
                actual: u16::try_from(resp.request_id).map_err(|_| {
                    RpcStatus::protocol_error(&format!("invalid request_id: must be less than {}", u16::MAX))
                })?,
            });
        }
        Ok(())
    }

    /// Awaits the next frame, applying the configured deadline plus grace period.
    /// With no deadline configured, waits indefinitely.
    async fn next(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
        let next_msg_fut = match self.config.timeout_with_grace_period() {
            Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())),
            // Wrap in Ok so both arms produce the same Result shape.
            None => Either::Right(self.framed.next().map(Ok)),
        };
        match next_msg_fut.await {
            Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?),
            Ok(Some(Err(err))) => Err(err.into()),
            // Stream ended before a response arrived.
            Ok(None) => Err(RpcError::ServerClosedRequest),
            Err(_) => Err(RpcError::ReplyTimeout),
        }
    }
}