use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use crossbeam_channel::{Receiver, RecvError, Sender};
use parking_lot::Mutex;
use rand::{thread_rng, Rng};
use crate::{
Client, ClientIdentifier, Result, RpcOnWire, Server, ServerIdentifier,
};
/// A simulated in-process network that proxies RPCs between clients and
/// servers, optionally injecting delays, drops and re-ordering of replies.
pub struct Network {
    // When false, each RPC suffers a small random delay and has a chance of
    // being dropped both before and after the server handles it.
    reliable: bool,
    // When true, RPCs that fail to dispatch are delayed for up to
    // MAX_LONG_DELAY_MILLIS (instead of MAX_SHORT_DELAY_MILLIS) before the
    // failure is reported back to the client.
    long_delays: bool,
    // When true, some successful replies are held back by an extra random
    // delay, forcing out-of-order delivery.
    long_reordering: bool,
    // client id -> (enabled, target server id). RPCs from a disabled client
    // fail with BrokenPipe.
    clients: HashMap<ClientIdentifier, (bool, ServerIdentifier)>,
    // server id -> server instance.
    servers: HashMap<ServerIdentifier, Arc<Server>>,
    // Sending half of the request queue drained by the daemon thread.
    // `None` is the shutdown sentinel.
    request_bus: Sender<Option<RpcOnWire>>,
    // Thread pool handed to servers for running their RPC handlers.
    server_pool: tokio::runtime::Runtime,
    // Cleared by `stop()`; the daemon thread exits after observing false.
    keep_running: bool,
    // Set to true by the daemon thread once it has fully shut down.
    stopped: AtomicBool,
    // Total number of RPCs seen by the network. A plain Cell suffices
    // because it is only mutated through `increase_rpc_count`, which is
    // invoked while the enclosing mutex is held.
    rpc_count: std::cell::Cell<usize>,
}
impl Network {
    /// Toggles whether RPCs are delivered without artificial delays/drops.
    pub fn set_reliable(&mut self, yes: bool) {
        self.reliable = yes;
    }

    /// Toggles random long re-ordering of successful replies.
    pub fn set_long_reordering(&mut self, yes: bool) {
        self.long_reordering = yes;
    }

    /// Toggles the longer delay range applied to failed dispatches.
    pub fn set_long_delays(&mut self, yes: bool) {
        self.long_delays = yes;
    }

    /// Requests shutdown: clears the running flag and pushes the `None`
    /// sentinel onto the request bus so the daemon thread wakes up.
    pub fn stop(&mut self) {
        self.keep_running = false;
        self.request_bus
            .send(None)
            .expect("Sending RPCs should never fail.");
    }

    /// True once the daemon thread has completely shut down.
    pub fn stopped(&self) -> bool {
        self.stopped.load(Ordering::Acquire)
    }

    /// Registers a client wired to `server` (initially enabled) and returns
    /// a `Client` handle sharing this network's request bus.
    pub fn make_client<C: Into<ClientIdentifier>, S: Into<ServerIdentifier>>(
        &mut self,
        client: C,
        server: S,
    ) -> Client {
        let client_id = client.into();
        let server_id = server.into();
        self.clients
            .insert(client_id.clone(), (true, server_id.clone()));
        Client {
            client: client_id,
            server: server_id,
            request_bus: self.request_bus.clone(),
        }
    }

    /// Enables or disables a client. RPCs from a disabled client fail with
    /// `BrokenPipe`. Unknown clients are silently ignored.
    pub fn set_enable_client<C: AsRef<str>>(&mut self, client: C, yes: bool) {
        if let Some(entry) = self.clients.get_mut(client.as_ref()) {
            entry.0 = yes;
        }
    }

    /// Adds a server under `server_name`, pointing it at the shared server
    /// thread pool.
    pub fn add_server<S: Into<ServerIdentifier>>(
        &mut self,
        server_name: S,
        mut server: Server,
    ) {
        server.use_pool(self.server_pool.handle().clone());
        self.servers.insert(server_name.into(), Arc::new(server));
    }

    /// Removes a server and interrupts any RPCs it is still serving.
    pub fn remove_server<S: AsRef<str>>(&mut self, server_name: S) {
        if let Some(server) = self.servers.remove(server_name.as_ref()) {
            server.interrupt();
        }
    }

    /// How many RPCs the named server has handled, or `None` when no such
    /// server is registered.
    pub fn get_rpc_count<S: AsRef<str>>(
        &self,
        server_name: S,
    ) -> Option<usize> {
        let server = self.servers.get(server_name.as_ref())?;
        Some(server.rpc_count())
    }

    /// Resolves the server that `client` is connected to. Errors:
    /// `PermissionDenied` for unknown clients, `BrokenPipe` for disabled
    /// clients, `NotFound` when the target server has been removed.
    #[allow(clippy::ptr_arg)]
    fn dispatch(&self, client: &ClientIdentifier) -> Result<Arc<Server>> {
        let (enabled, server_name) = match self.clients.get(client) {
            Some(pair) => pair,
            None => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::PermissionDenied,
                    format!("Client {} is not connected.", client),
                ))
            }
        };
        if !enabled {
            return Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Client {} is disabled.", client),
            ));
        }
        match self.servers.get(server_name) {
            Some(server) => Ok(server.clone()),
            None => Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!(
                    "Cannot connect {} to server {}: server not found.",
                    client, server_name,
                ),
            )),
        }
    }

    /// Total number of RPCs that have passed through this network.
    pub fn get_total_rpc_count(&self) -> usize {
        self.rpc_count.get()
    }
}
impl Network {
    // Upper bound (exclusive) of the small random delay injected before an
    // unreliable RPC is served or dropped.
    const MAX_MINOR_DELAY_MILLIS: u64 = 27;
    // Delay bounds applied to failed dispatches, depending on `long_delays`.
    const MAX_SHORT_DELAY_MILLIS: u64 = 100;
    const MAX_LONG_DELAY_MILLIS: u64 = 7000;
    // An unreliable network drops each RPC with probability 100/1000 = 10%,
    // checked once before and once after the server handles it.
    const DROP_RATE: (u32, u32) = (100, 1000);
    // With long reordering on, 600/900 of successful replies are held back
    // by an extra random delay before delivery.
    const LONG_REORDERING_RATE: (u32, u32) = (600u32, 900u32);
    const LONG_REORDERING_BASE_DELAY_MILLIS: u64 = 200;
    const LONG_REORDERING_RANDOM_DELAY_BOUND_MILLIS: u64 = 2000;
    // How often the daemon re-checks `keep_running`, and how long the worker
    // pool gets to wind down at shutdown.
    const SHUTDOWN_DELAY: Duration = Duration::from_micros(100);

    /// Sleeps for the given number of milliseconds on the tokio timer.
    async fn delay_for_millis(milli_seconds: u64) {
        tokio::time::sleep(Duration::from_millis(milli_seconds)).await;
    }

    /// Serves one RPC end to end: dispatches it to its target server,
    /// simulates unreliable-network drops and delays, and delivers the reply
    /// (or an error) through the RPC's reply channel.
    async fn serve_rpc(network: Arc<Mutex<Self>>, rpc: RpcOnWire) {
        // Snapshot the dispatch result and the network settings while
        // holding the lock, and release it before any `.await`.
        let (server_result, reliable, long_reordering, long_delays) = {
            let network = network.lock();
            network.increase_rpc_count();
            (
                network.dispatch(&rpc.client),
                network.reliable,
                network.long_reordering,
                network.long_delays,
            )
        };
        mark_trace!(rpc.trace, dispatched);
        if !reliable {
            // Unreliable network: small random delay, then a DROP_RATE
            // chance to drop the request before it reaches the server.
            let minor_delay =
                thread_rng().gen_range(0..Self::MAX_MINOR_DELAY_MILLIS);
            Self::delay_for_millis(minor_delay).await;
            if thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1) {
                Self::delay_for_millis(Self::MAX_MINOR_DELAY_MILLIS).await;
                // The receiving end may already be gone; ignore send errors.
                let _ = rpc.reply_channel.send(Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "Remote server did not respond in time.",
                )));
                mark_trace!(rpc.trace, served);
                return;
            }
        }
        let reply = match server_result {
            Ok(server) => {
                let data = rpc.request.clone();
                mark_trace!(rpc.trace, before_serving);
                #[cfg(not(feature = "tracing"))]
                {
                    server.dispatch(rpc.service_method, data).await
                }
                #[cfg(feature = "tracing")]
                {
                    let res = server
                        .dispatch(rpc.service_method, data, rpc.trace.clone())
                        .await;
                    mark_trace!(rpc.trace, after_server_response);
                    res
                }
            }
            Err(e) => {
                // Failed dispatch: simulate a (possibly very long) delay
                // before reporting the failure back to the client.
                let long_delay_upper = if long_delays {
                    Self::MAX_LONG_DELAY_MILLIS
                } else {
                    Self::MAX_SHORT_DELAY_MILLIS
                };
                let long_delay =
                    rand::thread_rng().gen_range(0..long_delay_upper);
                Self::delay_for_millis(long_delay).await;
                Err(e)
            }
        };
        mark_trace!(rpc.trace, after_serving);
        // Re-check the client's connection: if it was disabled or its server
        // removed while the RPC was in flight, fail the reply instead.
        let client = &rpc.client;
        let reply = reply.and_then(|reply| {
            network.lock().dispatch(client).map(|_| reply)
        });
        if reply.is_ok() {
            if !reliable
                && thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1)
            {
                // Drop the reply on its way back to the client.
                // (Fixed message: previously read "did not send respond".)
                let _ = rpc.reply_channel.send(Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "The network did not send the response in time.",
                )));
                mark_trace!(rpc.trace, served);
                return;
            } else if long_reordering {
                let should_reorder = thread_rng().gen_ratio(
                    Self::LONG_REORDERING_RATE.0,
                    Self::LONG_REORDERING_RATE.1,
                );
                if should_reorder {
                    // Hold the reply back for the base delay plus a random
                    // extra delay to force out-of-order delivery.
                    let long_delay_bound = thread_rng().gen_range(
                        0..Self::LONG_REORDERING_RANDOM_DELAY_BOUND_MILLIS,
                    );
                    let long_delay = Self::LONG_REORDERING_BASE_DELAY_MILLIS
                        + thread_rng().gen_range(0..1 + long_delay_bound);
                    Self::delay_for_millis(long_delay).await;
                }
            }
        }
        // The client may have dropped the receiver; not an error for us.
        if let Err(_e) = rpc.reply_channel.send(reply) {
        }
        mark_trace!(rpc.trace, served);
    }

    const WORKER_THREADS: usize = 16;

    /// Creates a network and spawns the daemon thread that pulls RPCs off
    /// the request bus and serves each one on a dedicated thread pool.
    /// Returns the shared handle used by clients and tests.
    pub fn run_daemon() -> Arc<Mutex<Network>> {
        let (network, rx) = Network::new();
        let network = Arc::new(Mutex::new(network));

        let thread_pool = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(Self::WORKER_THREADS)
            .thread_name("network")
            .enable_time()
            .build()
            .expect("Creating network thread pool should not fail");

        let other = network.clone();
        std::thread::spawn(move || {
            let network = other;
            let mut stop_timer = Instant::now();
            loop {
                // Periodically re-check `keep_running` so an externally
                // cleared flag is noticed even without the `None` sentinel.
                if stop_timer.elapsed() >= Self::SHUTDOWN_DELAY {
                    if !network.lock().keep_running {
                        break;
                    }
                    stop_timer = Instant::now();
                }
                match rx.recv() {
                    Ok(Some(rpc)) => {
                        mark_trace!(rpc.trace, dequeue);
                        thread_pool
                            .spawn(Self::serve_rpc(network.clone(), rpc));
                    }
                    // `None` sentinel or a closed channel both mean stop.
                    Ok(None) => break,
                    Err(RecvError) => break,
                }
            }
            thread_pool.shutdown_timeout(Self::SHUTDOWN_DELAY);
            // Dropping the receiver closes the bus so later sends fail.
            drop(rx);
            network.lock().stopped.store(true, Ordering::Release);
        });

        network
    }
}
impl Network {
fn increase_rpc_count(&self) {
self.rpc_count.set(self.rpc_count.get() + 1);
}
fn new() -> (Self, Receiver<Option<RpcOnWire>>) {
let server_pool = tokio::runtime::Builder::new_multi_thread()
.enable_time()
.thread_name("server-pool")
.build()
.expect("Creating server thread pool should not fail");
let (tx, rx) = crossbeam_channel::unbounded();
let network = Network {
reliable: true,
long_delays: false,
long_reordering: false,
clients: Default::default(),
servers: Default::default(),
request_bus: tx,
server_pool,
keep_running: true,
stopped: Default::default(),
rpc_count: std::cell::Cell::new(0),
};
(network, rx)
}
}
#[cfg(test)]
mod tests {
    use std::sync::Barrier;

    use parking_lot::MutexGuard;

    use crate::test_utils::{
        junk_server::{
            make_test_server, JunkRpcs, NON_CLIENT, NON_SERVER, TEST_CLIENT,
            TEST_SERVER,
        },
        make_aborting_rpc, make_echo_rpc,
    };
    use crate::{ReplyMessage, RequestMessage, Result};

    use super::*;

    /// The RPC counter starts at zero and increments by one per call.
    #[test]
    fn test_rpc_count_works() {
        let (network, _) = Network::new();
        assert_eq!(0, network.get_total_rpc_count());

        network.increase_rpc_count();
        assert_eq!(1, network.get_total_rpc_count());
    }

    /// Convenience wrapper: acquires the lock on a shared network.
    fn unlock<T>(network: &Arc<Mutex<T>>) -> MutexGuard<T> {
        network.lock()
    }

    /// After `stop()`, the daemon eventually drops the receiving end of the
    /// request bus, so new RPCs can no longer be enqueued.
    #[test]
    fn test_network_shutdown() {
        let network = Network::run_daemon();
        let sender = {
            let mut network = unlock(&network);
            network.stop();
            network.request_bus.clone()
        };
        // Wait until the daemon thread reports it has fully stopped.
        while !unlock(&network).stopped() {
            std::thread::sleep(Network::SHUTDOWN_DELAY)
        }
        let (rpc, _) = make_echo_rpc("client", "server", &[]);
        let result = sender.send(Some(rpc));
        assert!(
            result.is_err(),
            "Network is shutdown, requests should not be processed."
        );
    }

    /// Starts a network with one test server and one client entry
    /// (`enabled` controls whether the client may talk), sends `rpc`, and
    /// blocks on `rx` for the reply.
    fn send_rpc<C: Into<String>, S: Into<String>>(
        rpc: RpcOnWire,
        rx: futures::channel::oneshot::Receiver<Result<ReplyMessage>>,
        client: C,
        server: S,
        enabled: bool,
    ) -> Result<ReplyMessage> {
        let network = Network::run_daemon();
        let sender = {
            let mut network = unlock(&network);
            network
                .clients
                .insert(client.into(), (enabled, server.into()));
            network
                .servers
                .insert(TEST_SERVER.into(), Arc::new(make_test_server()));
            network.request_bus.clone()
        };
        let result = sender.send(Some(rpc));
        assert!(
            result.is_ok(),
            "Network is running, requests should be processed."
        );
        match futures::executor::block_on(rx) {
            Ok(reply) => reply,
            Err(e) => panic!("Future execution should not fail: {}", e),
        }
    }

    /// A well-formed echo RPC is proxied to the server and back.
    #[test]
    fn test_proxy_rpc() -> Result<()> {
        let (rpc, rx) =
            make_echo_rpc(TEST_CLIENT, TEST_SERVER, &[0x09u8, 0x00u8]);
        let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
        match reply {
            Ok(reply) => assert_eq!(reply.as_ref(), &[0x00u8, 0x09u8]),
            Err(e) => panic!("Expecting echo message, got {}", e),
        }
        Ok(())
    }

    /// Errors produced by the server are forwarded to the client.
    #[test]
    fn test_proxy_rpc_server_error() -> Result<()> {
        let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, TEST_SERVER);
        let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
        let err = reply.expect_err("Network should proxy server errors");
        assert_eq!(std::io::ErrorKind::ConnectionReset, err.kind());
        Ok(())
    }

    /// Dispatch to an unregistered server fails with NotFound.
    #[test]
    fn test_proxy_rpc_server_not_found() -> Result<()> {
        let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, NON_SERVER);
        let reply = send_rpc(rpc, rx, TEST_CLIENT, NON_SERVER, true);
        let err = reply.expect_err("Network should check server in memory");
        assert_eq!(std::io::ErrorKind::NotFound, err.kind());
        Ok(())
    }

    /// A disabled client's RPCs fail with BrokenPipe.
    #[test]
    fn test_proxy_rpc_client_disabled() -> Result<()> {
        let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, TEST_SERVER);
        let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, false);
        let err =
            reply.expect_err("Network should check if client is disabled");
        assert_eq!(std::io::ErrorKind::BrokenPipe, err.kind());
        Ok(())
    }

    /// An unregistered client's RPCs fail with PermissionDenied.
    #[test]
    fn test_proxy_rpc_no_such_client() -> Result<()> {
        let (rpc, rx) = make_aborting_rpc(NON_CLIENT, TEST_SERVER);
        let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
        let err = reply.expect_err("Network should check client names");
        assert_eq!(std::io::ErrorKind::PermissionDenied, err.kind());
        Ok(())
    }

    /// Removing a server while it is blocked inside an RPC handler causes
    /// the in-flight RPC to fail (interrupted or server-not-found).
    #[test]
    fn test_server_killed() -> Result<()> {
        let start_barrier = Arc::new(Barrier::new(2));
        let end_barrier = Arc::new(Barrier::new(2));
        let network = Network::run_daemon();
        let mut server = Server::make_server(TEST_SERVER);
        let start_barrier_clone = start_barrier.clone();
        let end_barrier_clone = end_barrier.clone();
        // Handler that parks until the test tells it to continue, so the
        // server can be removed while the RPC is in flight.
        server.register_rpc_handler(
            "blocking".to_owned(),
            Box::new(move |args| {
                start_barrier_clone.wait();
                end_barrier_clone.wait();
                args
            }),
        )?;
        unlock(&network).add_server(TEST_SERVER, server);
        let client = unlock(&network).make_client(TEST_CLIENT, TEST_SERVER);
        std::thread::spawn(move || {
            start_barrier.wait();
            unlock(&network).remove_server(TEST_SERVER);
            end_barrier.wait();
        });
        let reply = futures::executor::block_on(
            client.call_rpc("blocking".to_owned(), Default::default()),
        );
        let err = reply
            .expect_err("Client should receive error after server is killed");
        assert!(
            std::io::ErrorKind::Interrupted == err.kind()
                || std::io::ErrorKind::NotFound == err.kind()
        );
        Ok(())
    }

    /// Common fixture: a running network with one test server and client.
    fn make_network_and_client() -> (Arc<Mutex<Network>>, Client) {
        let network = Network::run_daemon();
        let server = make_test_server();
        unlock(&network).add_server(TEST_SERVER, server);
        let client = unlock(&network).make_client(TEST_CLIENT, TEST_SERVER);
        (network, client)
    }

    /// Exercises the happy path, client disabling, server removal and
    /// shutdown, checking the RPC counters along the way.
    #[test]
    fn test_basic_functions() -> Result<()> {
        let (network, client) = make_network_and_client();
        assert_eq!(0, unlock(&network).get_total_rpc_count());

        let request = RequestMessage::from_static(&[0x17, 0x20]);
        let reply_data = &[0x20, 0x17];
        let reply = futures::executor::block_on(
            client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
        )?;
        assert_eq!(reply_data, reply.as_ref());
        assert_eq!(1, unlock(&network).get_total_rpc_count());

        unlock(&network).set_enable_client(TEST_CLIENT, false);
        let reply = futures::executor::block_on(
            client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
        );
        reply.expect_err("Client is blocked");
        // Failed RPCs still count towards the network total, but not
        // towards the server's own counter.
        assert_eq!(2, unlock(&network).get_total_rpc_count());
        assert_eq!(Some(1), unlock(&network).get_rpc_count(TEST_SERVER));
        assert_eq!(None, unlock(&network).get_rpc_count(NON_SERVER));

        unlock(&network).set_enable_client(TEST_CLIENT, true);
        unlock(&network).remove_server(&TEST_SERVER);
        let reply = futures::executor::block_on(
            client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
        );
        // Fixed expect message: the failure here is the removed server, not
        // a blocked client.
        reply.expect_err("Server was removed");
        assert_eq!(3, unlock(&network).get_total_rpc_count());

        unlock(&network).stop();
        while !unlock(&network).stopped() {
            std::thread::sleep(Duration::from_millis(10));
        }
        let reply = futures::executor::block_on(
            client.call_rpc(JunkRpcs::Echo.name(), request),
        );
        reply.expect_err("Network is shutdown");
        // RPCs sent after shutdown never reach the network, so the counter
        // stays unchanged.
        assert_eq!(3, unlock(&network).get_total_rpc_count());
        Ok(())
    }

    /// On an unreliable network each RPC is dropped with 10% probability
    /// both before and after serving, so ~81% should succeed. Assert the
    /// observed success rate falls in (75%, 85%).
    #[test]
    fn test_unreliable() -> Result<()> {
        let (network, _) = make_network_and_client();
        network.lock().set_reliable(false);
        const RPC_COUNT: usize = 1000;
        let mut handles = vec![];
        for i in 0..RPC_COUNT {
            let client = network
                .lock()
                .make_client(format!("{}-{}", TEST_CLIENT, i), TEST_SERVER);
            let handle = std::thread::spawn(move || {
                let reply = client.call_rpc(
                    JunkRpcs::Echo.name(),
                    RequestMessage::from_static(&[0x20, 0x17]),
                );
                futures::executor::block_on(reply)
            });
            handles.push(handle);
        }
        let mut success = 0;
        for handle in handles {
            let result = handle.join().expect("Thread join should not fail");
            if result.is_ok() {
                success += 1;
            }
        }
        let success = success as f64;
        // Fixed assertion messages: they previously described the wrong
        // percentages for these bounds.
        assert!(
            success > RPC_COUNT as f64 * 0.75,
            "More than 25% RPC failed"
        );
        assert!(
            success < RPC_COUNT as f64 * 0.85,
            "Less than 15% RPC failed"
        );
        Ok(())
    }

    /// Stress test: many threads each fire a large batch of RPCs; all of
    /// them must succeed on a reliable network.
    #[test]
    #[ignore = "Large tests with many threads"]
    fn test_many_requests() {
        const THREAD_COUNT: usize = 200;
        const RPC_COUNT: usize = 10000;
        let now = Instant::now();
        let (network, _) = make_network_and_client();
        let barrier = Arc::new(Barrier::new(THREAD_COUNT + 1));
        let mut handles = vec![];
        for i in 0..THREAD_COUNT {
            let network_ref = network.clone();
            let barrier_ref = barrier.clone();
            let handle = std::thread::spawn(move || {
                let client = unlock(&network_ref)
                    .make_client(format!("{}-{}", TEST_CLIENT, i), TEST_SERVER);
                barrier_ref.wait();
                let mut results = vec![];
                for _ in 0..RPC_COUNT {
                    let reply = client.call_rpc(
                        JunkRpcs::Echo.name(),
                        RequestMessage::from_static(&[0x20, 0x17]),
                    );
                    results.push(reply);
                }
                for result in results {
                    futures::executor::block_on(result)
                        .expect("All futures should succeed");
                }
            });
            handles.push(handle);
        }
        barrier.wait();
        for handle in handles {
            handle.join().expect("All threads should succeed");
        }
        eprintln!("Many requests test took {:?}", now.elapsed());
        // NOTE(review): this divides by RPC_COUNT only (per-thread RPC
        // count), treating the THREAD_COUNT threads as fully concurrent —
        // confirm that is the intended latency estimate.
        eprintln!(
            "The average RPC latency is {:?}",
            now.elapsed() / (RPC_COUNT as u32)
        );
    }

    /// Trace timestamps must be monotonically ordered along the RPC path.
    #[cfg(feature = "tracing")]
    #[test]
    fn test_tracing() -> Result<()> {
        let (_, client) = make_network_and_client();
        let (response, trace) = futures::executor::block_on(
            client.trace_rpc(JunkRpcs::Echo.name(), RequestMessage::new()),
        );
        assert!(response.is_ok());
        assert!(trace.assemble > Duration::from_secs(0));
        assert!(trace.enqueue > trace.assemble);
        assert!(trace.dispatched > trace.dequeue);
        assert!(trace.before_serving > trace.dispatched);
        assert!(trace.after_serving > trace.before_serving);
        assert!(trace.served >= trace.after_serving, "{:?}", trace);
        Ok(())
    }

    /// Measures average dispatch delay and client latency over many RPCs.
    #[cfg(feature = "tracing")]
    #[test]
    #[ignore = "Large tests that runs for seconds"]
    fn test_dispatch_delay() -> Result<()> {
        let (_, client) = make_network_and_client();
        const ROUND: u32 = 100_000;
        let start = Instant::now();
        let mut dispatch_delay: Duration = Default::default();
        for _ in 0..ROUND {
            let (response, trace) = futures::executor::block_on(
                client.trace_rpc(JunkRpcs::Echo.name(), RequestMessage::new()),
            );
            assert!(response.is_ok());
            dispatch_delay += trace.response - trace.assemble;
        }
        eprintln!("The average RPC delay is {:?} ", dispatch_delay / ROUND);
        eprintln!(
            "The average client latency is {:?}",
            start.elapsed() / ROUND
        );
        Ok(())
    }
}