use std::{
collections::HashSet,
task::Poll,
time::{Duration, Instant},
};
use futures::{
channel::{mpsc, oneshot},
sink::SinkMapErr,
FutureExt, SinkExt, StreamExt,
};
use tower::load_shed::error::Overloaded;
use tracing::Span;
use zebra_chain::serialization::SerializationError;
use zebra_test::mock_service::{MockService, PanicAssertion};
use crate::{
constants::{MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY, REQUEST_TIMEOUT},
peer::{
connection::{overload_drop_connection_probability, Connection, State},
ClientRequest, ErrorSlot,
},
protocol::external::Message,
types::Nonce,
PeerError, Request, Response,
};
/// Checks that a freshly created connection run loop stays pending and idle
/// when nothing is sent to it, and that the outbound message stream ends once
/// the last handle to the run loop future is dropped.
#[tokio::test]
async fn connection_run_loop_ok() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, request_sender, mut inbound, mut outbound_rx, error_slot) = new_test_connection();

    // `shared()` lets us keep a clone alive, so polling once with
    // `now_or_never` does not drop the run loop future.
    let run_loop = conn.run(peer_receiver).shared();
    let keep_alive = run_loop.clone();

    assert_eq!(run_loop.now_or_never(), None);

    let error = error_slot.try_get_error();
    assert!(error.is_none(), "unexpected error: {error:?}");
    assert!(!request_sender.is_closed());
    assert!(!peer_sender.is_closed());

    inbound.expect_no_requests().await;

    // Dropping the last handle drops the run loop, ending the outbound stream.
    drop(keep_alive);
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that a spawned, idle connection run loop keeps running, and that
/// aborting its task ends the outbound message stream.
#[tokio::test]
async fn connection_run_loop_spawn_ok() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, request_sender, mut inbound, mut outbound_rx, error_slot) = new_test_connection();

    let mut run_loop_task = tokio::spawn(conn.run(peer_receiver));

    let error = error_slot.try_get_error();
    assert!(error.is_none(), "unexpected error: {error:?}");
    assert!(!request_sender.is_closed());
    assert!(!peer_sender.is_closed());

    inbound.expect_no_requests().await;

    // Poll the task handle once by hand: it must still be running.
    let connection_result = futures::poll!(&mut run_loop_task);
    assert!(
        matches!(connection_result, Poll::Pending),
        "unexpected run loop termination: {connection_result:?}",
    );

    run_loop_task.abort();
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks a full request/response round trip through a spawned run loop.
///
/// A `Peers` request is sent to the peer as `GetAddr`. The peer's empty `Addr`
/// reply is then returned to the client as an empty `Peers` response, and the
/// connection stays open.
#[tokio::test]
async fn connection_run_loop_message_ok() {
    let _init_guard = zebra_test::init();

    tokio::time::pause();

    let (mut peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, mut request_sender, mut inbound, mut outbound_rx, error_slot) =
        new_test_connection();

    let mut run_loop_task = tokio::spawn(conn.run(peer_receiver));

    let (response_sender, mut response_receiver) = oneshot::channel();
    request_sender
        .try_send(ClientRequest {
            request: Request::Peers,
            tx: response_sender,
            inv_collector: None,
            transient_addr: None,
            span: Span::current(),
        })
        .expect("internal request channel is valid");

    assert_eq!(outbound_rx.next().await, Some(Message::GetAddr));

    peer_sender
        .try_send(Ok(Message::Addr(Vec::new())))
        .expect("peer inbound response channel is valid");

    // Give the run loop a chance to process the peer's reply.
    tokio::task::yield_now().await;

    let response = response_receiver
        .try_recv()
        .expect("peer internal response channel is valid")
        .expect("response is present")
        .expect("response is a message (not an error)");
    assert_eq!(response, Response::Peers(Vec::new()));

    let error = error_slot.try_get_error();
    assert!(error.is_none(), "unexpected error: {error:?}");
    assert!(!request_sender.is_closed());
    assert!(!peer_sender.is_closed());

    inbound.expect_no_requests().await;

    let connection_result = futures::poll!(&mut run_loop_task);
    assert!(
        matches!(connection_result, Poll::Pending),
        "unexpected run loop termination: {connection_result:?}",
    );

    run_loop_task.abort();
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that dropping the run loop future after a single poll records
/// `ConnectionDropped` in the error slot and closes both the client and peer
/// channels.
#[tokio::test]
async fn connection_run_loop_future_drop() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, request_sender, mut inbound, mut outbound_rx, error_slot) = new_test_connection();

    // `now_or_never` polls once, then drops the still-pending run loop future.
    assert_eq!(conn.run(peer_receiver).now_or_never(), None);

    let dropped_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(dropped_error.inner_debug(), "ConnectionDropped");

    assert!(request_sender.is_closed());
    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that closing the client request channel makes the run loop finish
/// with `ClientDropped`, and that the peer channel is closed too.
#[tokio::test]
async fn connection_run_loop_client_close() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, mut request_sender, mut inbound, mut outbound_rx, error_slot) =
        new_test_connection();

    let run_loop = conn.run(peer_receiver);
    request_sender.close_channel();

    // Keep a clone of the shared future so the single poll below doesn't drop it.
    let run_loop = run_loop.shared();
    let keep_alive = run_loop.clone();

    assert_eq!(run_loop.now_or_never(), Some(()));

    let client_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(client_error.inner_debug(), "ClientDropped");

    assert!(request_sender.is_closed());
    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    drop(keep_alive);
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that dropping the client request sender makes the run loop finish
/// with `ClientDropped`, and that the peer channel is closed too.
#[tokio::test]
async fn connection_run_loop_client_drop() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, request_sender, mut inbound, mut outbound_rx, error_slot) = new_test_connection();

    let run_loop = conn.run(peer_receiver);
    drop(request_sender);

    // Keep a clone of the shared future so the single poll below doesn't drop it.
    let run_loop = run_loop.shared();
    let keep_alive = run_loop.clone();

    assert_eq!(run_loop.now_or_never(), Some(()));

    let client_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(client_error.inner_debug(), "ClientDropped");

    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    drop(keep_alive);
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that closing the peer's inbound message channel makes the run loop
/// finish with `ConnectionClosed`, and that the client channel is closed too.
#[tokio::test]
async fn connection_run_loop_inbound_close() {
    let _init_guard = zebra_test::init();

    let (mut peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, request_sender, mut inbound, mut outbound_rx, error_slot) = new_test_connection();

    let run_loop = conn.run(peer_receiver);
    peer_sender.close_channel();

    // Keep a clone of the shared future so the single poll below doesn't drop it.
    let run_loop = run_loop.shared();
    let keep_alive = run_loop.clone();

    assert_eq!(run_loop.now_or_never(), Some(()));

    let closed_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(closed_error.inner_debug(), "ConnectionClosed");

    assert!(request_sender.is_closed());
    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    drop(keep_alive);
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that dropping the peer's inbound message sender makes the run loop
/// finish with `ConnectionClosed`, and that the client channel is closed too.
#[tokio::test]
async fn connection_run_loop_inbound_drop() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, request_sender, mut inbound, mut outbound_rx, error_slot) = new_test_connection();

    let run_loop = conn.run(peer_receiver);
    drop(peer_sender);

    // Keep a clone of the shared future so the single poll below doesn't drop it.
    let run_loop = run_loop.shared();
    let keep_alive = run_loop.clone();

    assert_eq!(run_loop.now_or_never(), Some(()));

    let closed_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(closed_error.inner_debug(), "ConnectionClosed");

    assert!(request_sender.is_closed());

    inbound.expect_no_requests().await;

    drop(keep_alive);
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks that a connection already in `State::Failed` finishes its run loop
/// immediately, keeps the pre-existing error, and closes both channels.
#[tokio::test]
async fn connection_run_loop_failed() {
    let _init_guard = zebra_test::init();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (mut conn, request_sender, mut inbound, mut outbound_rx, error_slot) =
        new_test_connection();

    // Put the connection into the failed state, with a matching recorded error.
    conn.state = State::Failed;
    error_slot
        .try_update_error(PeerError::Overloaded.into())
        .expect("unexpected previous error in tests");

    // Keep a clone of the shared future so the single poll below doesn't drop it.
    let run_loop = conn.run(peer_receiver).shared();
    let keep_alive = run_loop.clone();

    assert_eq!(run_loop.now_or_never(), Some(()));

    let recorded_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(recorded_error.inner_debug(), "Overloaded");

    assert!(request_sender.is_closed());
    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    drop(keep_alive);
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks the send timeout for a request that expects no response.
///
/// The test sends an `AdvertiseTransactionIds` request and leaves its
/// outbound `Inv` message unread until after `REQUEST_TIMEOUT`. The connection
/// must then fail with `ConnectionSendTimeout`, report that error to the
/// client, and terminate its run loop.
#[tokio::test]
async fn connection_run_loop_send_timeout_nil_response() {
    let _init_guard = zebra_test::init();

    tokio::time::pause();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, mut request_sender, mut inbound, mut outbound_rx, error_slot) =
        new_test_connection();

    let mut run_loop_task = tokio::spawn(conn.run(peer_receiver));

    let (response_sender, mut response_receiver) = oneshot::channel();
    request_sender
        .try_send(ClientRequest {
            request: Request::AdvertiseTransactionIds(HashSet::new()),
            tx: response_sender,
            inv_collector: None,
            transient_addr: None,
            span: Span::current(),
        })
        .expect("channel is valid");

    // Don't read the outbound message until well past the request timeout.
    tokio::time::sleep(REQUEST_TIMEOUT + Duration::from_secs(1)).await;

    let slot_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(slot_error.inner_debug(), "ConnectionSendTimeout");

    // The unread message can still be taken from the outbound channel.
    assert_eq!(outbound_rx.next().await, Some(Message::Inv(Vec::new())));

    let client_error = response_receiver
        .try_recv()
        .expect("peer internal response channel is valid")
        .expect("response is present")
        .expect_err("response is an error (not a message)");
    assert_eq!(client_error.inner_debug(), "ConnectionSendTimeout");

    assert!(request_sender.is_closed());
    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    let connection_result = futures::poll!(&mut run_loop_task);
    assert!(
        matches!(connection_result, Poll::Ready(Ok(()))),
        "expected run loop termination, but run loop continued: {connection_result:?}",
    );

    assert_eq!(outbound_rx.next().await, None);
}
/// Checks the send timeout for a request that expects a response.
///
/// The test sends a `Peers` request and leaves its outbound `GetAddr` message
/// unread until after `REQUEST_TIMEOUT`. The connection must then fail with
/// `ConnectionSendTimeout`, report that error to the client, and terminate
/// its run loop.
#[tokio::test]
async fn connection_run_loop_send_timeout_expect_response() {
    let _init_guard = zebra_test::init();

    tokio::time::pause();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, mut request_sender, mut inbound, mut outbound_rx, error_slot) =
        new_test_connection();

    let mut run_loop_task = tokio::spawn(conn.run(peer_receiver));

    let (response_sender, mut response_receiver) = oneshot::channel();
    request_sender
        .try_send(ClientRequest {
            request: Request::Peers,
            tx: response_sender,
            inv_collector: None,
            transient_addr: None,
            span: Span::current(),
        })
        .expect("channel is valid");

    // Don't read the outbound message until well past the request timeout.
    tokio::time::sleep(REQUEST_TIMEOUT + Duration::from_secs(1)).await;

    let slot_error = error_slot
        .try_get_error()
        .expect("missing expected error");
    assert_eq!(slot_error.inner_debug(), "ConnectionSendTimeout");

    // The unread message can still be taken from the outbound channel.
    assert_eq!(outbound_rx.next().await, Some(Message::GetAddr));

    let client_error = response_receiver
        .try_recv()
        .expect("peer internal response channel is valid")
        .expect("response is present")
        .expect_err("response is an error (not a message)");
    assert_eq!(client_error.inner_debug(), "ConnectionSendTimeout");

    assert!(request_sender.is_closed());
    assert!(peer_sender.is_closed());

    inbound.expect_no_requests().await;

    let connection_result = futures::poll!(&mut run_loop_task);
    assert!(
        matches!(connection_result, Poll::Ready(Ok(()))),
        "expected run loop termination, but run loop continued: {connection_result:?}",
    );

    assert_eq!(outbound_rx.next().await, None);
}
/// Checks the receive timeout for a request that gets no reply.
///
/// The test sends a `Peers` request and reads the outbound `GetAddr`
/// promptly, but the peer never replies. The client must get
/// `ConnectionReceiveTimeout`, while the connection itself stays open with no
/// recorded error and a still-running run loop.
#[tokio::test]
async fn connection_run_loop_receive_timeout() {
    let _init_guard = zebra_test::init();

    tokio::time::pause();

    let (peer_sender, peer_receiver) = mpsc::channel(1);
    let (conn, mut request_sender, mut inbound, mut outbound_rx, error_slot) =
        new_test_connection();

    let mut run_loop_task = tokio::spawn(conn.run(peer_receiver));

    let (response_sender, mut response_receiver) = oneshot::channel();
    request_sender
        .try_send(ClientRequest {
            request: Request::Peers,
            tx: response_sender,
            inv_collector: None,
            transient_addr: None,
            span: Span::current(),
        })
        .expect("channel is valid");

    assert_eq!(outbound_rx.next().await, Some(Message::GetAddr));

    // The peer never answers, so wait until past the request timeout.
    tokio::time::sleep(REQUEST_TIMEOUT + Duration::from_secs(1)).await;

    let error = error_slot.try_get_error();
    assert!(error.is_none(), "unexpected error: {error:?}");
    assert!(!request_sender.is_closed());
    assert!(!peer_sender.is_closed());

    let client_error = response_receiver
        .try_recv()
        .expect("peer internal response channel is valid")
        .expect("response is present")
        .expect_err("response is an error (not a message)");
    assert_eq!(client_error.inner_debug(), "ConnectionReceiveTimeout");

    inbound.expect_no_requests().await;

    let connection_result = futures::poll!(&mut run_loop_task);
    assert!(
        matches!(connection_result, Poll::Pending),
        "unexpected run loop termination: {connection_result:?}",
    );

    run_loop_task.abort();
    assert_eq!(outbound_rx.next().await, None);
}
/// Checks the overload drop probability as a function of time.
///
/// Back-to-back (or clock-skewed) overloads must get the maximum probability.
/// Each longer gap since the previous overload must lower it, down to the
/// minimum for distant overloads or when there is no previous overload.
#[test]
fn overload_probability_reduces_over_time() {
    let now = Instant::now();

    // Drop probability when the previous overload happened `gap` before `now`.
    let after_gap = |gap: Duration| overload_drop_connection_probability(now, Some(now - gap));

    // A previous overload "in the future" can only come from a misbehaving clock.
    assert_eq!(
        overload_drop_connection_probability(now, Some(now + Duration::from_secs(1))),
        MAX_OVERLOAD_DROP_PROBABILITY,
        "if the overload time is in the future (OS bugs?), it should have maximum drop probability",
    );

    assert_eq!(
        overload_drop_connection_probability(now, Some(now)),
        MAX_OVERLOAD_DROP_PROBABILITY,
        "if the overload times are the same, overloads should have maximum drop probability",
    );

    let drop_probability = after_gap(Duration::from_micros(1));
    assert!(
        drop_probability <= MAX_OVERLOAD_DROP_PROBABILITY,
        "if the overloads are very close together, drops can optionally decrease: {drop_probability} <= {MAX_OVERLOAD_DROP_PROBABILITY}",
    );
    assert!(
        MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
        "if the overloads are very close together, drops can only decrease slightly: {drop_probability}",
    );

    // Each longer (but still small) gap must strictly lower the probability,
    // while staying within `max_decrease` of the maximum.
    let mut last_probability = drop_probability;
    for (gap, max_decrease) in [
        (Duration::from_millis(1), 0.001),
        (Duration::from_millis(10), 0.001),
        (Duration::from_millis(100), 0.01),
    ] {
        let drop_probability = after_gap(gap);
        assert!(
            drop_probability < last_probability,
            "if the overloads decrease, drops should decrease: {drop_probability} < {last_probability}",
        );
        assert!(
            MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < max_decrease,
            "if the overloads are very close together, drops can only decrease slightly: {drop_probability}",
        );
        last_probability = drop_probability;
    }

    let drop_probability = after_gap(Duration::from_secs(1));
    assert!(
        drop_probability < last_probability,
        "if the overloads decrease, drops should decrease: {drop_probability} < {last_probability}",
    );
    assert!(
        MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.4,
        "if the overloads are distant, drops should decrease a lot: {drop_probability}",
    );
    last_probability = drop_probability;

    let drop_probability = after_gap(Duration::from_secs(5));
    assert!(
        drop_probability < last_probability,
        "if the overloads decrease, drops should decrease: {drop_probability} < {last_probability}",
    );
    assert_eq!(
        drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
        "if overloads are far apart, drops should have minimum drop probability: {drop_probability}",
    );

    let drop_probability = after_gap(Duration::from_secs(10));
    assert_eq!(
        drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
        "if overloads are far apart, drops should have minimum drop probability: {drop_probability}",
    );

    let drop_probability = overload_drop_connection_probability(now, None);
    assert_eq!(
        drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
        "if there is no previous overload time, overloads should have minimum drop probability: {drop_probability}",
    );
}
#[tokio::test(flavor = "multi_thread")]
async fn connection_is_randomly_disconnected_on_overload() {
let _init_guard = zebra_test::init();
const TEST_RUNS: usize = 220;
const TESTS_BEFORE_FAILURE: f32 = 50_000.0;
let test_runs = TEST_RUNS.try_into().expect("constant fits in i32");
assert!(
1.0 / MIN_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE,
"not enough test runs: failures must be frequent enough to happen in almost all tests"
);
assert!(
1.0 / MAX_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE,
"not enough test runs: successes must be frequent enough to happen in almost all tests"
);
let mut connection_continues = 0;
let mut connection_closes = 0;
for _ in 0..TEST_RUNS {
let (mut peer_tx, peer_rx) = mpsc::channel(1);
let (
connection,
_client_tx,
mut inbound_service,
mut peer_outbound_messages,
shared_error_slot,
) = new_test_connection();
let error = shared_error_slot.try_get_error();
assert!(
error.is_none(),
"unexpected error before starting the connection event loop: {error:?}",
);
let connection_handle = tokio::spawn(connection.run(peer_rx));
tokio::time::sleep(Duration::from_millis(1)).await;
let error = shared_error_slot.try_get_error();
assert!(
error.is_none(),
"unexpected error before sending messages to the connection event loop: {error:?}",
);
let inbound_req = Message::GetAddr;
peer_tx
.send(Ok(inbound_req))
.await
.expect("send to channel always succeeds");
tokio::time::sleep(Duration::from_millis(1)).await;
let error = shared_error_slot.try_get_error();
assert!(
error.is_none(),
"unexpected error before sending responses to the connection event loop: {error:?}",
);
inbound_service
.expect_request(Request::Peers)
.await
.respond_error(Overloaded::new().into());
tokio::time::sleep(Duration::from_millis(1)).await;
let outbound_result = peer_outbound_messages.try_recv();
assert!(
outbound_result.is_err(),
"unexpected outbound message after Overloaded error:\n\
{outbound_result:?}\n\
note: Err(TryRecvError::Empty) means no messages, Err(TryRecvError::Closed) means the channel is closed"
);
let error = shared_error_slot.try_get_error();
if error.is_some() {
connection_closes += 1;
} else {
connection_continues += 1;
}
connection_handle.abort();
}
assert!(
connection_closes > 0,
"some overloaded connections must be closed at random"
);
assert!(
connection_continues > 0,
"some overloaded errors must be ignored at random"
);
}
/// Checks the `Ping`/`Pong` round trip.
///
/// A `Ping` request must be forwarded to the peer with the same nonce. The
/// matching `Pong` reply must then produce a `Pong` response whose measured
/// round-trip time is at least the simulated delay.
#[tokio::test]
async fn connection_ping_pong_round_trip() {
    let _init_guard = zebra_test::init();

    let (peer_tx, peer_rx) = mpsc::channel(1);
    let (conn, mut request_sender, _inbound_service, mut outbound_rx, _shared_error_slot) =
        new_test_connection();

    let run_loop_task = tokio::spawn(conn.run(peer_rx));

    let (response_tx, response_rx) = oneshot::channel();
    let nonce = Nonce::default();

    let ping_request = ClientRequest {
        request: Request::Ping(nonce),
        tx: response_tx,
        inv_collector: None,
        transient_addr: None,
        span: Span::none(),
    };
    request_sender
        .send(ping_request)
        .await
        .expect("send to connection should succeed");

    let ping_nonce = match outbound_rx
        .next()
        .await
        .expect("expected outbound Ping message")
    {
        Message::Ping(sent_nonce) => sent_nonce,
        msg => panic!("expected Ping message, but got: {msg:?}"),
    };
    assert_eq!(
        nonce, ping_nonce,
        "Ping nonce in request must match message sent to peer"
    );

    // Simulate network latency before the peer replies.
    let pong_rtt = Duration::from_millis(42);
    tokio::time::sleep(pong_rtt).await;

    peer_tx
        .clone()
        .send(Ok(Message::Pong(ping_nonce)))
        .await
        .expect("sending Pong to connection should succeed");

    let response = response_rx.await.expect("response channel must succeed");
    match response {
        Ok(Response::Pong(rtt)) => assert!(
            rtt >= pong_rtt,
            "measured RTT {rtt:?} must be >= simulated RTT {pong_rtt:?}"
        ),
        Ok(resp) => panic!("unexpected response: {resp:?}"),
        Err(err) => panic!("unexpected error: {err:?}"),
    }

    // Closing the peer channel lets the run loop finish.
    drop(peer_tx);
    let _ = run_loop_task.await;
}
/// Creates a new test [`Connection`] by delegating to the parent module's
/// `new_test_connection`, restating its concrete return type for the tests in
/// this file.
///
/// Returns, in order:
/// - the connection, with a mock inbound service and an outbound sink that
///   maps `mpsc::SendError` into `SerializationError`,
/// - the sender used to submit `ClientRequest`s to the connection,
/// - the mock inbound service, used by tests to check requests the connection makes,
/// - the receiver for `Message`s the connection sends to the peer,
/// - the connection's `ErrorSlot`.
fn new_test_connection() -> (
Connection<
MockService<Request, Response, PanicAssertion>,
SinkMapErr<mpsc::Sender<Message>, fn(mpsc::SendError) -> SerializationError>,
>,
mpsc::Sender<ClientRequest>,
MockService<Request, Response, PanicAssertion>,
mpsc::Receiver<Message>,
ErrorSlot,
) {
super::new_test_connection()
}