use crossbeam_channel::Sender;
#[cfg(feature = "tracing")]
use crate::tracing::{Trace, TraceHolder};
use crate::{
ClientIdentifier, ReplyMessage, RequestMessage, Result, RpcOnWire,
ServerIdentifier,
};
pub struct Client {
pub(crate) client: ClientIdentifier,
pub(crate) server: ServerIdentifier,
pub(crate) request_bus: Sender<Option<RpcOnWire>>,
}
impl Client {
pub async fn call_rpc(
&self,
service_method: String,
request: RequestMessage,
) -> Result<ReplyMessage> {
#[cfg(feature = "tracing")]
{
let trace = TraceHolder::make();
self.trace_and_call_rpc(service_method, request, trace)
.await
}
#[cfg(not(feature = "tracing"))]
self.trace_and_call_rpc(service_method, request).await
}
async fn trace_and_call_rpc(
&self,
service_method: String,
request: RequestMessage,
#[cfg(feature = "tracing")] trace: TraceHolder,
) -> Result<ReplyMessage> {
#[cfg(feature = "tracing")]
let local_trace = trace.clone();
let (tx, rx) = futures::channel::oneshot::channel();
let rpc = RpcOnWire {
client: self.client.clone(),
server: self.server.clone(),
service_method,
request,
reply_channel: tx,
#[cfg(feature = "tracing")]
trace,
};
mark_trace!(local_trace, assemble);
self.request_bus.send(Some(rpc)).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::NotConnected,
format!(
"Cannot send rpc, client {} is disconnected. {}",
self.client, e
),
)
})?;
mark_trace!(local_trace, enqueue);
#[allow(clippy::let_and_return)]
let ret = rx.await.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
format!("Network request is dropped: {}", e),
)
})?;
mark_trace!(local_trace, response);
ret
}
#[cfg(feature = "tracing")]
pub async fn trace_rpc(
&self,
service_method: String,
request: RequestMessage,
) -> (Result<ReplyMessage>, Trace) {
let trace = TraceHolder::make();
let local_trace = trace.clone();
let response = self
.trace_and_call_rpc(service_method, request, trace)
.await;
(response, local_trace.extract())
}
}
#[cfg(test)]
mod tests {
use crossbeam_channel::{unbounded, Sender};
use super::*;
fn make_rpc_call(tx: Sender<Option<RpcOnWire>>) -> Result<ReplyMessage> {
let client = Client {
client: "C".into(),
server: "S".into(),
request_bus: tx,
};
let request = RequestMessage::from_static(&[0x17, 0x20]);
futures::executor::block_on(client.call_rpc("hello".into(), request))
}
fn make_rpc_call_and_reply(
reply: Result<ReplyMessage>,
) -> Result<ReplyMessage> {
let (tx, rx) = unbounded();
let handle = std::thread::spawn(move || make_rpc_call(tx));
let rpc = rx
.recv()
.expect("The request message should arrive")
.expect("The request message should not be null");
assert_eq!("C", &rpc.client);
assert_eq!("S", &rpc.server);
assert_eq!("hello", &rpc.service_method);
assert_eq!(&[0x17, 0x20], rpc.request.as_ref());
rpc.reply_channel
.send(reply)
.expect("The reply channel should not be closed");
handle.join().expect("Rpc sending thread should succeed")
}
#[test]
fn test_call_rpc() -> Result<()> {
let data = &[0x11, 0x99];
let reply =
make_rpc_call_and_reply(Ok(ReplyMessage::from_static(data)))?;
assert_eq!(data, reply.as_ref());
Ok(())
}
#[test]
fn test_call_rpc_remote_error() -> Result<()> {
let reply = make_rpc_call_and_reply(Err(std::io::Error::new(
std::io::ErrorKind::AddrInUse,
"",
)));
if let Err(e) = reply {
assert_eq!(std::io::ErrorKind::AddrInUse, e.kind());
} else {
panic!("Client should propagate remote error.")
}
Ok(())
}
#[test]
fn test_call_rpc_remote_dropped() -> Result<()> {
let (tx, rx) = unbounded();
let handle = std::thread::spawn(move || make_rpc_call(tx));
let rpc = rx
.recv()
.expect("The request message should arrive")
.expect("The request message should not be null");
drop(rpc.reply_channel);
let reply = handle.join().expect("Rpc sending thread should succeed");
if let Err(e) = reply {
assert_eq!(std::io::ErrorKind::ConnectionAborted, e.kind());
} else {
panic!(
"Client should return error. Reply channel has been dropped."
)
}
Ok(())
}
#[test]
fn test_call_rpc_not_connected() -> Result<()> {
let (tx, rx) = unbounded();
{
drop(rx);
}
let handle = std::thread::spawn(move || make_rpc_call(tx));
let reply = handle.join().expect("Rpc sending thread should succeed");
if let Err(e) = reply {
assert_eq!(std::io::ErrorKind::NotConnected, e.kind());
} else {
panic!("Client should return error. request_bus has been dropped.")
}
Ok(())
}
async fn make_rpc(client: Client) -> Result<ReplyMessage> {
let request = RequestMessage::from_static(&[0x17, 0x20]);
client.call_rpc("hello".into(), request).await
}
#[test]
fn test_call_across_threads() -> Result<()> {
let (tx, rx) = unbounded();
let rpc_future = {
let client = Client {
client: "C".into(),
server: "S".into(),
request_bus: tx,
};
make_rpc(client)
};
std::thread::spawn(move || {
let _ = futures::executor::block_on(rpc_future);
});
let rpc = rx
.recv()
.expect("The request message should arrive")
.expect("The request message should not be null");
rpc.reply_channel
.send(Ok(Default::default()))
.expect("The reply channel should not be closed");
Ok(())
}
}