use super::*;
use crate::config::GlobalExecutor;
use futures::stream::StreamExt;
use std::future::Future;
use tokio::sync::mpsc;
use tokio::time::{Duration, sleep, timeout};
// Handshake-event stream stand-in that never yields an item: `poll_next`
// always returns `Pending`. NOTE(review): it never registers the waker, so
// it can never wake the task by itself — acceptable here because the tests
// never expect handshake events.
struct MockHandshakeStream;
impl Stream for MockHandshakeStream {
    type Item = crate::node::network_bridge::handshake::Event;
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Pending
    }
}
// Handshake-event stream stand-in that is already terminated: `poll_next`
// always returns `Ready(None)`, i.e. a closed/exhausted stream.
struct ClosedHandshakeStream;
impl Stream for ClosedHandshakeStream {
    type Item = crate::node::network_bridge::handshake::Event;
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}
// Convenience constructor for the always-pending handshake stream used by
// most tests in this module.
fn create_mock_handshake_stream() -> MockHandshakeStream {
    MockHandshakeStream
}
// Client-transaction stream stand-in that never yields (always `Pending`,
// no waker registration — see note on `MockHandshakeStream`).
struct MockClientStream;
impl Stream for MockClientStream {
    type Item = (ClientId, WaitingTransaction);
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Pending
    }
}
// Executor-transaction stream stand-in that never yields (always `Pending`,
// no waker registration — see note on `MockHandshakeStream`).
struct MockExecutorStream;
impl Stream for MockExecutorStream {
    type Item = Transaction;
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Pending
    }
}
// Adapts an mpsc receiver of client transactions into a `Stream` by
// delegating to `Receiver::poll_recv` (which does register the waker).
struct MockClientReceiverStream {
    rx: mpsc::Receiver<(ClientId, WaitingTransaction)>,
}
impl Stream for MockClientReceiverStream {
    type Item = (ClientId, WaitingTransaction);
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.rx).poll_recv(cx)
    }
}
// Adapts an mpsc receiver of executor transactions into a `Stream` by
// delegating to `Receiver::poll_recv` (which does register the waker).
struct MockExecutorReceiverStream {
    rx: mpsc::Receiver<Transaction>,
}
impl Stream for MockExecutorReceiverStream {
    type Item = Transaction;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.rx).poll_recv(cx)
    }
}
// Verifies that `PrioritySelectStream` wakes up when a notification arrives
// *after* the stream has already been polled and returned `Pending`.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_future_wakeup() {
    let (notif_tx, notif_rx) = mpsc::channel(10);
    let (_op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (_bridge_tx, bridge_rx) = mpsc::channel(10);
    let (_node_tx, node_rx) = mpsc::channel(10);
    let notif_tx_clone = notif_tx.clone();
    // Delayed sender: the stream will be polled (and go Pending) before this
    // message is sent, so receiving it proves waker registration works.
    GlobalExecutor::spawn(async move {
        sleep(Duration::from_millis(50)).await;
        let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
            crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
        ));
        notif_tx_clone.send(Either::Left(test_msg)).await.unwrap();
    });
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    // 200 ms timeout comfortably covers the 50 ms send delay.
    let result = timeout(Duration::from_millis(200), stream.next()).await;
    assert!(
        result.is_ok(),
        "Select stream should wake up when notification arrives"
    );
    let select_result = result.unwrap().expect("Stream should yield value");
    // Exhaustive match: only a real notification is acceptable.
    match select_result {
        SelectResult::Notification(Some(_)) => {}
        SelectResult::Notification(None) => panic!("Got Notification(None)"),
        SelectResult::OpExecution(_) => panic!("Got OpExecution"),
        SelectResult::PeerConnection(_) => panic!("Got PeerConnection"),
        SelectResult::ConnBridge(_) => panic!("Got ConnBridge"),
        SelectResult::Handshake(_) => panic!("Got Handshake"),
        SelectResult::NodeController(_) => panic!("Got NodeController"),
        SelectResult::ClientTransaction(_) => panic!("Got ClientTransaction"),
        SelectResult::ExecutorTransaction(_) => panic!("Got ExecutorTransaction"),
    }
}
// Buffers one message in each of the op-execution, bridge, and notification
// channels *before* the first poll, then asserts the notification (highest
// priority) is yielded first.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_future_priority_ordering() {
    let (notif_tx, notif_rx) = mpsc::channel(10);
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (bridge_tx, bridge_rx) = mpsc::channel(10);
    // Sender dropped immediately: the node-controller channel starts closed.
    let (_, node_rx) = mpsc::channel(10);
    let (callback_tx, _) = mpsc::channel(1);
    let dummy_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
        crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
    ));
    // Lower-priority messages are enqueued first...
    op_tx
        .send((callback_tx, dummy_msg.clone(), None))
        .await
        .unwrap();
    bridge_tx
        .send(P2pBridgeEvent::NodeAction(NodeEvent::Disconnect {
            cause: None,
        }))
        .await
        .unwrap();
    // ...and the highest-priority notification last, so arrival order alone
    // cannot explain the notification winning.
    let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
        crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
    ));
    notif_tx.send(Either::Left(test_msg)).await.unwrap();
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let result = timeout(Duration::from_millis(100), stream.next()).await;
    assert!(result.is_ok());
    match result.unwrap().expect("Stream should yield value") {
        SelectResult::Notification(_) => {}
        SelectResult::OpExecution(_)
        | SelectResult::PeerConnection(_)
        | SelectResult::ConnBridge(_)
        | SelectResult::Handshake(_)
        | SelectResult::NodeController(_)
        | SelectResult::ClientTransaction(_)
        | SelectResult::ExecutorTransaction(_) => {
            panic!("Notification should be received first due to priority")
        }
    }
}
// Buffers 15 notifications before the first poll and checks that the first
// item yielded is a real (non-None) notification.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_future_concurrent_messages() {
    let (notif_tx, notif_rx) = mpsc::channel(100);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    for _ in 0..15 {
        let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
            crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
        ));
        notif_tx.send(Either::Left(test_msg)).await.unwrap();
    }
    // Senders dropped immediately: these three channels start closed.
    let (_, op_rx) = mpsc::channel(10);
    let (_, bridge_rx) = mpsc::channel(10);
    let (_, node_rx) = mpsc::channel(10);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let result = timeout(Duration::from_millis(100), stream.next()).await;
    assert!(result.is_ok(), "Should receive first message");
    match result.unwrap().expect("Stream should yield value") {
        SelectResult::Notification(Some(_)) => {}
        SelectResult::Notification(_)
        | SelectResult::OpExecution(_)
        | SelectResult::PeerConnection(_)
        | SelectResult::ConnBridge(_)
        | SelectResult::Handshake(_)
        | SelectResult::NodeController(_)
        | SelectResult::ClientTransaction(_)
        | SelectResult::ExecutorTransaction(_) => panic!("Expected notification"),
    }
}
// A message buffered before the stream is created must be yielded on the
// very first poll (no wakeup needed).
#[tokio::test]
#[test_log::test]
async fn test_priority_select_future_buffered_messages() {
    let (notif_tx, notif_rx) = mpsc::channel(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
        crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
    ));
    notif_tx.send(Either::Left(test_msg)).await.unwrap();
    // Senders dropped immediately: these three channels start closed.
    let (_, op_rx) = mpsc::channel(10);
    let (_, bridge_rx) = mpsc::channel(10);
    let (_, node_rx) = mpsc::channel(10);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let result = timeout(Duration::from_millis(100), stream.next()).await;
    assert!(
        result.is_ok(),
        "Should receive buffered message immediately"
    );
    match result.unwrap().expect("Stream should yield value") {
        SelectResult::Notification(Some(_)) => {}
        SelectResult::Notification(_)
        | SelectResult::OpExecution(_)
        | SelectResult::PeerConnection(_)
        | SelectResult::ConnBridge(_)
        | SelectResult::Handshake(_)
        | SelectResult::NodeController(_)
        | SelectResult::ClientTransaction(_)
        | SelectResult::ExecutorTransaction(_) => panic!("Expected notification"),
    }
}
// Polls the stream with very short (5 ms) timeouts so that many polls are
// cancelled mid-flight; all 10 buffered notifications must still come out,
// i.e. a cancelled poll must never drop an already-received message.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_future_rapid_cancellations() {
    // NOTE: `StreamExt` is already imported at module level; the previous
    // function-local `use futures::StreamExt;` was redundant and removed.
    let (notif_tx, notif_rx) = mpsc::channel(100);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    // Buffer 10 notifications before the stream is ever polled.
    for _ in 0..10 {
        let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
            crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
        ));
        notif_tx.send(Either::Left(test_msg)).await.unwrap();
    }
    // Senders dropped immediately: these three channels start closed.
    let (_, op_rx) = mpsc::channel(10);
    let (_, bridge_rx) = mpsc::channel(10);
    let (_, node_rx) = mpsc::channel(10);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let mut received = 0;
    // Up to 30 attempts: some polls may time out ("cancellation") or yield
    // channel-closed events from the dropped senders above.
    for _ in 0..30 {
        if let Ok(Some(SelectResult::Notification(Some(_)))) =
            timeout(Duration::from_millis(5), stream.as_mut().next()).await
        {
            received += 1;
        }
        if received >= 10 {
            break;
        }
    }
    assert_eq!(
        received, 10,
        "Should receive all messages despite rapid cancellations"
    );
}
// Simulates a realistic event loop: a background task sends 3 notifications,
// 2 op-executions, 2 bridge events and 1 node-controller event (in that
// order, on different channels), and the test asserts both the per-channel
// counts and the strict priority ordering of the 8 received events.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_event_loop_simulation() {
    // NOTE: `StreamExt` is already imported at module level; the previous
    // function-local `use futures::StreamExt;` was redundant and removed.
    let (notif_tx, notif_rx) = mpsc::channel::<Either<NetMessage, NodeEvent>>(10);
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (bridge_tx, bridge_rx) = mpsc::channel::<P2pBridgeEvent>(10);
    let (node_tx, node_rx) = mpsc::channel::<NodeEvent>(10);
    let notif_tx_clone = notif_tx.clone();
    let op_tx_clone = op_tx.clone();
    let bridge_tx_clone = bridge_tx.clone();
    let node_tx_clone = node_tx.clone();
    // Sender task: small initial delay so the stream registers wakers first.
    GlobalExecutor::spawn(async move {
        sleep(Duration::from_millis(10)).await;
        for i in 0..3 {
            let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
                crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
            ));
            tracing::info!("Sending notification {}", i);
            notif_tx_clone.send(Either::Left(test_msg)).await.unwrap();
        }
        for i in 0..2 {
            let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
                crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
            ));
            let (callback_tx, _) = mpsc::channel(1);
            tracing::info!("Sending op_execution {}", i);
            op_tx_clone
                .send((callback_tx, test_msg, None))
                .await
                .unwrap();
        }
        for i in 0..2 {
            tracing::info!("Sending bridge event {}", i);
            bridge_tx_clone
                .send(P2pBridgeEvent::NodeAction(NodeEvent::Disconnect {
                    cause: None,
                }))
                .await
                .unwrap();
        }
        tracing::info!("Sending node controller event");
        node_tx_clone
            .send(NodeEvent::Disconnect { cause: None })
            .await
            .unwrap();
    });
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let mut received_events = Vec::new();
    let expected_count = 8;
    // Drain exactly the expected number of events, tagging each by source.
    for iteration in 0..expected_count {
        tracing::info!("Event loop iteration {}", iteration);
        let result = timeout(Duration::from_millis(50), stream.as_mut().next()).await;
        assert!(result.is_ok(), "Iteration {} should complete", iteration);
        let event = result.unwrap().expect("Stream should yield value");
        match &event {
            SelectResult::Notification(_) => received_events.push("notification"),
            SelectResult::OpExecution(_) => received_events.push("op_execution"),
            SelectResult::ConnBridge(_) => received_events.push("conn_bridge"),
            SelectResult::Handshake(_) => received_events.push("handshake"),
            SelectResult::NodeController(_) => received_events.push("node_controller"),
            SelectResult::PeerConnection(_)
            | SelectResult::ClientTransaction(_)
            | SelectResult::ExecutorTransaction(_) => received_events.push("other"),
        }
        tracing::info!(
            "Event loop iteration {} received: {:?}",
            iteration,
            received_events.last()
        );
    }
    assert_eq!(
        received_events.len(),
        expected_count,
        "Should receive all {} messages",
        expected_count
    );
    // Per-source counts must match what the sender task produced.
    let notif_count = received_events
        .iter()
        .filter(|&e| *e == "notification")
        .count();
    let op_count = received_events
        .iter()
        .filter(|&e| *e == "op_execution")
        .count();
    let bridge_count = received_events
        .iter()
        .filter(|&e| *e == "conn_bridge")
        .count();
    let node_count = received_events
        .iter()
        .filter(|&e| *e == "node_controller")
        .count();
    tracing::info!(
        "Received counts - notifications: {}, op_execution: {}, bridge: {}, node_controller: {}",
        notif_count,
        op_count,
        bridge_count,
        node_count
    );
    assert_eq!(notif_count, 3, "Should receive 3 notifications");
    assert_eq!(op_count, 2, "Should receive 2 op_execution messages");
    assert_eq!(bridge_count, 2, "Should receive 2 bridge messages");
    assert_eq!(node_count, 1, "Should receive 1 node_controller message");
    // Ordering: all notifications first (indices 0..=2), then op_executions
    // (3..=4), then bridge events — the priority order of the stream.
    let first_notif_idx = received_events.iter().position(|e| *e == "notification");
    let last_notif_idx = received_events.iter().rposition(|e| *e == "notification");
    let first_op_idx = received_events.iter().position(|e| *e == "op_execution");
    let last_op_idx = received_events.iter().rposition(|e| *e == "op_execution");
    let first_bridge_idx = received_events.iter().position(|e| *e == "conn_bridge");
    assert_eq!(
        first_notif_idx,
        Some(0),
        "First notification should be at index 0"
    );
    assert_eq!(
        last_notif_idx,
        Some(2),
        "Last notification should be at index 2"
    );
    assert!(
        first_op_idx.unwrap() > last_notif_idx.unwrap(),
        "Op execution should come after all notifications"
    );
    assert_eq!(
        first_op_idx,
        Some(3),
        "First op_execution should be at index 3"
    );
    assert_eq!(
        last_op_idx,
        Some(4),
        "Last op_execution should be at index 4"
    );
    assert!(
        first_bridge_idx.unwrap() > last_op_idx.unwrap(),
        "Bridge events should come after all op_executions"
    );
    tracing::info!(
        "✓ All {} messages received in correct priority order: {:?}",
        expected_count,
        received_events
    );
    // Explicitly drop the local senders (clones in the task were moved).
    drop(notif_tx);
    drop(op_tx);
    drop(bridge_tx);
    drop(node_tx);
}
// Runs the randomized stress scenario under several fixed seeds so that any
// failure is reproducible by re-running with the offending seed.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_concurrent_random_stress() {
    for seed in [42u64, 123, 999, 7777, 31415] {
        test_with_seed(seed).await;
    }
}
async fn test_with_seed(seed: u64) {
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::StdRng;
tracing::info!("=== Stress test with seed {} ===", seed);
const NOTIF_COUNT: usize = 500;
const OP_COUNT: usize = 400;
const BRIDGE_COUNT: usize = 300;
const NODE_COUNT: usize = 200;
const CLIENT_COUNT: usize = 200;
const EXECUTOR_COUNT: usize = 100;
const TOTAL_MESSAGES: usize =
NOTIF_COUNT + OP_COUNT + BRIDGE_COUNT + NODE_COUNT + CLIENT_COUNT + EXECUTOR_COUNT;
let mut rng = StdRng::seed_from_u64(seed);
let make_delays = |count: usize, rng: &mut StdRng| -> Vec<u64> {
(0..count)
.map(|_| {
if rng.random_range(0..10) == 0 {
rng.random_range(1000..5000) } else {
rng.random_range(50..500) }
})
.collect()
};
let notif_delays = make_delays(NOTIF_COUNT, &mut rng);
let op_delays = make_delays(OP_COUNT, &mut rng);
let bridge_delays = make_delays(BRIDGE_COUNT, &mut rng);
let node_delays = make_delays(NODE_COUNT, &mut rng);
let client_delays = make_delays(CLIENT_COUNT, &mut rng);
let executor_delays = make_delays(EXECUTOR_COUNT, &mut rng);
let (notif_tx, notif_rx) = mpsc::channel::<Either<NetMessage, NodeEvent>>(100);
let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(100);
let (_conn_event_tx, conn_event_rx) = mpsc::channel(100);
let (bridge_tx, bridge_rx) = mpsc::channel::<P2pBridgeEvent>(100);
let (node_tx, node_rx) = mpsc::channel::<NodeEvent>(100);
let (client_tx, client_rx) = mpsc::channel::<(
crate::client_events::ClientId,
crate::contract::WaitingTransaction,
)>(100);
let (executor_tx, executor_rx) = mpsc::channel::<Transaction>(100);
tracing::info!(
"Starting stress test with {} total messages from 6 concurrent tasks",
TOTAL_MESSAGES
);
let notif_handle = GlobalExecutor::spawn(async move {
for (i, &delay_us) in notif_delays.iter().enumerate() {
sleep(Duration::from_micros(delay_us)).await;
let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
));
tracing::debug!(
"Notification task sending message {} after {}us delay",
i,
delay_us
);
notif_tx.send(Either::Left(test_msg)).await.unwrap();
}
tracing::info!("Notification task sent all {} messages", NOTIF_COUNT);
NOTIF_COUNT
});
let op_handle = GlobalExecutor::spawn(async move {
for (i, &delay_us) in op_delays.iter().enumerate() {
sleep(Duration::from_micros(delay_us)).await;
let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
));
let (callback_tx, _) = mpsc::channel(1);
tracing::debug!(
"OpExecution task sending message {} after {}us delay",
i,
delay_us
);
op_tx.send((callback_tx, test_msg, None)).await.unwrap();
}
tracing::info!("OpExecution task sent all {} messages", OP_COUNT);
OP_COUNT
});
let bridge_handle = GlobalExecutor::spawn(async move {
for (i, &delay_us) in bridge_delays.iter().enumerate() {
sleep(Duration::from_micros(delay_us)).await;
tracing::debug!(
"Bridge task sending message {} after {}us delay",
i,
delay_us
);
bridge_tx
.send(P2pBridgeEvent::NodeAction(NodeEvent::Disconnect {
cause: None,
}))
.await
.unwrap();
}
tracing::info!("Bridge task sent all {} messages", BRIDGE_COUNT);
BRIDGE_COUNT
});
let node_handle = GlobalExecutor::spawn(async move {
for (i, &delay_us) in node_delays.iter().enumerate() {
sleep(Duration::from_micros(delay_us)).await;
tracing::debug!(
"NodeController task sending message {} after {}us delay",
i,
delay_us
);
node_tx
.send(NodeEvent::Disconnect { cause: None })
.await
.unwrap();
}
tracing::info!("NodeController task sent all {} messages", NODE_COUNT);
NODE_COUNT
});
let client_handle = GlobalExecutor::spawn(async move {
for (i, &delay_us) in client_delays.iter().enumerate() {
sleep(Duration::from_micros(delay_us)).await;
let client_id = crate::client_events::ClientId::next();
let waiting_tx = crate::contract::WaitingTransaction::Transaction(Transaction::new::<
crate::operations::put::PutMsg,
>());
tracing::debug!(
"Client task sending message {} after {}us delay",
i,
delay_us
);
client_tx.send((client_id, waiting_tx)).await.unwrap();
}
tracing::info!("Client task sent all {} messages", CLIENT_COUNT);
CLIENT_COUNT
});
let executor_handle = GlobalExecutor::spawn(async move {
for (i, &delay_us) in executor_delays.iter().enumerate() {
sleep(Duration::from_micros(delay_us)).await;
tracing::debug!(
"Executor task sending message {} after {}us delay",
i,
delay_us
);
executor_tx
.send(Transaction::new::<crate::operations::put::PutMsg>())
.await
.unwrap();
}
tracing::info!("Executor task sent all {} messages", EXECUTOR_COUNT);
EXECUTOR_COUNT
});
sleep(Duration::from_micros(100)).await;
let stream = PrioritySelectStream::new(
notif_rx,
op_rx,
bridge_rx,
create_mock_handshake_stream(),
node_rx,
MockClientReceiverStream { rx: client_rx },
MockExecutorReceiverStream { rx: executor_rx },
conn_event_rx,
);
tokio::pin!(stream);
let mut received_events = Vec::new();
let mut iteration = 0;
use futures::StreamExt;
while received_events.len() < TOTAL_MESSAGES {
let result = timeout(Duration::from_millis(100), stream.as_mut().next()).await;
assert!(result.is_ok(), "Iteration {} timed out", iteration);
let Some(event) = result.unwrap() else {
tracing::debug!("Stream ended (all channels closed)");
break;
};
let (event_name, is_real_message) = match &event {
SelectResult::Notification(msg) => {
if msg.is_some() {
tracing::debug!("Received Notification message");
("notification", true)
} else {
tracing::debug!("Notification channel closed");
("notification", false)
}
}
SelectResult::OpExecution(msg) => {
if msg.is_some() {
tracing::debug!("Received OpExecution message");
("op_execution", true)
} else {
tracing::debug!("OpExecution channel closed");
("op_execution", false)
}
}
SelectResult::PeerConnection(msg) => ("peer_connection", msg.is_some()),
SelectResult::ConnBridge(msg) => {
if msg.is_some() {
tracing::debug!("Received ConnBridge message");
("conn_bridge", true)
} else {
tracing::debug!("ConnBridge channel closed");
("conn_bridge", false)
}
}
SelectResult::Handshake(_) => {
("handshake", false) }
SelectResult::NodeController(msg) => {
if msg.is_some() {
tracing::debug!("Received NodeController message");
("node_controller", true)
} else {
tracing::debug!("NodeController channel closed");
("node_controller", false)
}
}
SelectResult::ClientTransaction(result) => {
if result.is_ok() {
tracing::debug!("Received ClientTransaction message");
("client_transaction", true)
} else {
tracing::debug!("ClientTransaction channel closed or error");
("client_transaction", false)
}
}
SelectResult::ExecutorTransaction(result) => {
if result.is_ok() {
tracing::debug!("Received ExecutorTransaction message");
("executor_transaction", true)
} else {
tracing::debug!("ExecutorTransaction channel closed or error");
("executor_transaction", false)
}
}
};
if is_real_message {
received_events.push(event_name);
if received_events.len() % 100 == 0 {
tracing::info!(
"Received {} of {} real messages",
received_events.len(),
TOTAL_MESSAGES
);
}
} else {
tracing::debug!(
"Iteration {}: Received channel close from {}",
iteration,
event_name
);
}
iteration += 1;
if iteration > TOTAL_MESSAGES * 3 {
tracing::error!(
"Receiver loop exceeded maximum iterations. Received {} of {} messages after {} iterations",
received_events.len(),
TOTAL_MESSAGES,
iteration
);
panic!("Receiver loop exceeded maximum iterations - possible deadlock");
}
}
let sent_notif_count = notif_handle.await.unwrap();
let sent_op_count = op_handle.await.unwrap();
let sent_bridge_count = bridge_handle.await.unwrap();
let sent_node_count = node_handle.await.unwrap();
let sent_client_count = client_handle.await.unwrap();
let sent_executor_count = executor_handle.await.unwrap();
let total_sent = sent_notif_count
+ sent_op_count
+ sent_bridge_count
+ sent_node_count
+ sent_client_count
+ sent_executor_count;
tracing::info!("All sender tasks completed. Total sent: {}", total_sent);
tracing::info!(
"Receiver completed. Total received: {}",
received_events.len()
);
assert_eq!(
received_events.len(),
total_sent,
"Should receive all {} sent messages",
total_sent
);
assert_eq!(
received_events.len(),
TOTAL_MESSAGES,
"Total received should match expected total"
);
let recv_notif_count = received_events
.iter()
.filter(|&e| *e == "notification")
.count();
let recv_op_count = received_events
.iter()
.filter(|&e| *e == "op_execution")
.count();
let recv_bridge_count = received_events
.iter()
.filter(|&e| *e == "conn_bridge")
.count();
let recv_node_count = received_events
.iter()
.filter(|&e| *e == "node_controller")
.count();
let recv_client_count = received_events
.iter()
.filter(|&e| *e == "client_transaction")
.count();
let recv_executor_count = received_events
.iter()
.filter(|&e| *e == "executor_transaction")
.count();
tracing::info!("Sent vs Received:");
tracing::info!(
" notifications: sent={}, received={}",
sent_notif_count,
recv_notif_count
);
tracing::info!(
" op_execution: sent={}, received={}",
sent_op_count,
recv_op_count
);
tracing::info!(
" bridge: sent={}, received={}",
sent_bridge_count,
recv_bridge_count
);
tracing::info!(
" node_controller: sent={}, received={}",
sent_node_count,
recv_node_count
);
tracing::info!(
" client: sent={}, received={}",
sent_client_count,
recv_client_count
);
tracing::info!(
" executor: sent={}, received={}",
sent_executor_count,
recv_executor_count
);
assert_eq!(
recv_notif_count, sent_notif_count,
"Notification count mismatch"
);
assert_eq!(recv_op_count, sent_op_count, "OpExecution count mismatch");
assert_eq!(
recv_bridge_count, sent_bridge_count,
"Bridge count mismatch"
);
assert_eq!(
recv_node_count, sent_node_count,
"NodeController count mismatch"
);
assert_eq!(
recv_client_count, sent_client_count,
"Client count mismatch"
);
assert_eq!(
recv_executor_count, sent_executor_count,
"Executor count mismatch"
);
tracing::info!("✓ STRESS TEST PASSED for seed {}!", seed);
tracing::info!(
" All {} messages received correctly from 6 concurrent senders with random delays",
TOTAL_MESSAGES
);
tracing::info!(" Received events: {:?}", received_events);
tracing::info!(
" Priority ordering respected: when multiple messages buffered, highest priority selected first"
);
}
// All channels start empty, so the first poll must register wakers on every
// source and return Pending. A background task then fills the channels in
// reverse priority order (lowest first); the stream must still yield the
// notification (highest priority) first.
#[tokio::test]
#[test_log::test]
async fn test_priority_select_all_pending_waker_registration() {
    // NOTE: `StreamExt` is already imported at module level; the previous
    // function-local `use futures::StreamExt;` was redundant and removed.
    let (notif_tx, notif_rx) = mpsc::channel::<Either<NetMessage, NodeEvent>>(10);
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (bridge_tx, bridge_rx) = mpsc::channel::<P2pBridgeEvent>(10);
    let (node_tx, node_rx) = mpsc::channel::<NodeEvent>(10);
    let (client_tx, client_rx) = mpsc::channel::<(ClientId, WaitingTransaction)>(10);
    let (executor_tx, executor_rx) = mpsc::channel::<Transaction>(10);
    tracing::info!("Creating PrioritySelectStream with all channels empty");
    GlobalExecutor::spawn(async move {
        sleep(Duration::from_millis(10)).await;
        tracing::info!("All wakers should now be registered, sending messages");
        // Fill lowest-priority sources first so arrival order cannot explain
        // the notification being yielded first.
        tracing::info!("Sending to executor channel (lowest priority)");
        executor_tx
            .send(Transaction::new::<crate::operations::put::PutMsg>())
            .await
            .unwrap();
        tracing::info!("Sending to client channel");
        let client_id = crate::client_events::ClientId::next();
        let waiting_tx = crate::contract::WaitingTransaction::Transaction(Transaction::new::<
            crate::operations::put::PutMsg,
        >());
        client_tx.send((client_id, waiting_tx)).await.unwrap();
        tracing::info!("Sending to node controller channel");
        node_tx
            .send(NodeEvent::Disconnect { cause: None })
            .await
            .unwrap();
        tracing::info!("Sending to bridge channel");
        bridge_tx
            .send(P2pBridgeEvent::NodeAction(NodeEvent::Disconnect {
                cause: None,
            }))
            .await
            .unwrap();
        tracing::info!("Sending to op execution channel (second priority)");
        let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
            crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
        ));
        let (callback_tx, _) = mpsc::channel(1);
        op_tx
            .send((callback_tx, test_msg.clone(), None))
            .await
            .unwrap();
        tracing::info!("Sending to notification channel (highest priority)");
        notif_tx.send(Either::Left(test_msg)).await.unwrap();
    });
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorReceiverStream { rx: executor_rx },
        conn_event_rx,
    );
    tokio::pin!(stream);
    tracing::info!("PrioritySelectStream started, should poll all channels and go Pending");
    let result = timeout(Duration::from_millis(100), stream.next()).await;
    assert!(
        result.is_ok(),
        "Select should wake up when any message arrives"
    );
    let select_result = result.unwrap().expect("Stream should yield value");
    match select_result {
        SelectResult::Notification(_) => {
            tracing::info!(
                "✓ Correctly received Notification despite 5 other channels having messages"
            );
        }
        SelectResult::OpExecution(_) => {
            panic!("Should prioritize Notification over OpExecution")
        }
        SelectResult::ConnBridge(_) => panic!("Should prioritize Notification over ConnBridge"),
        SelectResult::NodeController(_) => {
            panic!("Should prioritize Notification over NodeController")
        }
        SelectResult::ClientTransaction(_) => {
            panic!("Should prioritize Notification over ClientTransaction")
        }
        SelectResult::ExecutorTransaction(_) => {
            panic!("Should prioritize Notification over ExecutorTransaction")
        }
        SelectResult::PeerConnection(_) | SelectResult::Handshake(_) => panic!("Unexpected result"),
    }
}
// Regression test for issue #1932: messages arriving sparsely (200 ms apart)
// must not be lost across repeated polls of the *same* long-lived stream.
#[tokio::test]
#[test_log::test]
async fn test_sparse_messages_reproduce_race() {
    tracing::info!(
        "=== Testing sparse messages with PrioritySelectStream (verifying fix for #1932) ==="
    );
    let (notif_tx, notif_rx) = mpsc::channel::<Either<NetMessage, NodeEvent>>(10);
    // Senders dropped immediately: these channels start closed.
    let (_, op_rx) = mpsc::channel(1);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    // Sender: 5 messages, each preceded by a 200 ms gap, so the receiver goes
    // idle (Pending) between messages and must be woken each time.
    let sender = GlobalExecutor::spawn(async move {
        for i in 0..5 {
            sleep(Duration::from_millis(200)).await;
            tracing::info!(
                "Sender: Sending message {} at {:?}",
                i,
                tokio::time::Instant::now()
            );
            let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
                crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
            ));
            match notif_tx.send(Either::Left(test_msg)).await {
                Ok(_) => tracing::info!("Sender: Message {} sent successfully", i),
                Err(e) => {
                    tracing::error!("Sender: Failed to send message {}: {:?}", i, e);
                    break;
                }
            }
        }
        tracing::info!("Sender: Finished sending all messages");
    });
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let mut received = 0;
    let mut iteration = 0;
    // Re-poll the same pinned stream (not a fresh one per message) — that is
    // precisely the scenario in which the old implementation lost wakeups.
    while received < 5 && iteration < 20 {
        iteration += 1;
        tracing::info!(
            "Iteration {}: Polling PrioritySelectStream (reusing same stream)",
            iteration
        );
        match timeout(Duration::from_millis(300), stream.as_mut().next()).await {
            Ok(Some(SelectResult::Notification(Some(_)))) => {
                received += 1;
                tracing::info!(
                    "✅ Iteration {}: Received message {} of 5",
                    iteration,
                    received
                );
            }
            Ok(Some(_)) => {
                // e.g. channel-closed events from the dropped senders above.
                tracing::debug!("Iteration {}: Got other event", iteration);
            }
            Ok(None) => {
                tracing::error!("Stream ended unexpectedly");
                break;
            }
            Err(_) => {
                tracing::warn!("Iteration {}: Timeout waiting for message", iteration);
            }
        }
    }
    sender.await.unwrap();
    tracing::info!("Sender task completed, received {} messages", received);
    assert_eq!(
        received, 5,
        "❌ FAIL: PrioritySelectStream still lost messages! Expected 5 but received {} in {} iterations.\n\
        The fix should prevent lost wakeups by keeping the stream alive.",
        received, iteration
    );
    tracing::info!("✅ PASS: All 5 messages received without loss using PrioritySelectStream!");
}
// Baseline for the sparse-arrival scenario: a plain `ReceiverStream` wrapped
// in a trivial pass-through `Stream` receives all 5 sparse messages, showing
// that waker registration survives across `.next().await` calls when the
// stream object itself stays alive.
#[tokio::test]
#[test_log::test]
async fn test_stream_no_lost_messages_sparse_arrivals() {
    use tokio_stream::wrappers::ReceiverStream;
    tracing::info!("=== Testing stream approach doesn't lose messages (sparse arrivals) ===");
    let (tx, rx) = mpsc::channel::<String>(10);
    let stream = ReceiverStream::new(rx);
    // Minimal wrapper that just forwards `poll_next` to the inner stream.
    struct MessageStream<S> {
        inner: S,
    }
    impl<S: Stream + Unpin> Stream for MessageStream<S> {
        type Item = S::Item;
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Pin::new(&mut self.inner).poll_next(cx)
        }
    }
    let mut message_stream = MessageStream { inner: stream };
    // Sender: 5 messages, 200 ms apart (receiver goes idle between them).
    let sender = GlobalExecutor::spawn(async move {
        for i in 0..5 {
            sleep(Duration::from_millis(200)).await;
            tracing::info!(
                "Sender: Sending message {} at {:?}",
                i,
                tokio::time::Instant::now()
            );
            tx.send(format!("msg{}", i)).await.unwrap();
            tracing::info!("Sender: Message {} sent successfully", i);
        }
    });
    let mut received = 0;
    for iteration in 1..=20 {
        tracing::info!("Iteration {}: Calling stream.next().await", iteration);
        let msg = timeout(Duration::from_millis(300), message_stream.next()).await;
        match msg {
            Ok(Some(msg)) => {
                received += 1;
                tracing::info!("✓ Received: {} (total: {})", msg, received);
            }
            Ok(None) => {
                tracing::info!("Stream ended");
                break;
            }
            Err(_) => {
                tracing::info!(
                    "Timeout on iteration {} (received {} so far)",
                    iteration,
                    received
                );
                if received >= 5 {
                    break;
                }
            }
        }
    }
    sender.await.unwrap();
    tracing::info!("Sender task completed, received {} messages", received);
    assert_eq!(
        received, 5,
        "Stream approach should receive ALL messages! Expected 5 but got {}.\n\
        The stream maintains waker registration across .next().await calls.",
        received
    );
    tracing::info!(
        "✓ SUCCESS: Stream-based approach received all 5 messages with sparse arrivals!"
    );
    tracing::info!("✓ Waker registration was maintained across stream.next().await iterations!");
}
// Demonstrates that creating a *fresh* future on every `poll_next` call (as
// some handlers do) still delivers all messages: the waker is registered on
// the underlying `mpsc::Receiver`, which lives inside the stream struct, so
// the short-lived futures do not matter.
#[tokio::test]
#[test_log::test]
async fn test_recreating_futures_maintains_waker() {
    tracing::info!("=== Testing that recreating futures on each poll maintains waker ===");
    // Receiver wrapper that counts successfully received messages.
    struct MockSpecial {
        counter: std::sync::Arc<std::sync::Mutex<usize>>,
        rx: tokio::sync::mpsc::Receiver<String>,
    }
    impl MockSpecial {
        async fn wait_for_event(&mut self) -> Option<String> {
            tracing::info!("MockSpecial::wait_for_event called");
            let msg = self.rx.recv().await?;
            let mut counter = self.counter.lock().unwrap();
            *counter += 1;
            tracing::info!("MockSpecial: received '{}', counter now {}", msg, *counter);
            Some(msg)
        }
    }
    struct TestStream {
        special: MockSpecial,
    }
    impl Stream for TestStream {
        type Item = String;
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Deliberately build a brand-new future on every poll; it is
            // pinned on the stack and dropped at the end of this call.
            let fut = self.special.wait_for_event();
            tokio::pin!(fut);
            match fut.poll(cx) {
                Poll::Ready(Some(msg)) => Poll::Ready(Some(msg)),
                Poll::Ready(None) => Poll::Ready(None),
                Poll::Pending => Poll::Pending,
            }
        }
    }
    let counter = std::sync::Arc::new(std::sync::Mutex::new(0));
    let (tx, rx) = mpsc::channel::<String>(10);
    let mut test_stream = TestStream {
        special: MockSpecial {
            counter: counter.clone(),
            rx,
        },
    };
    // Sender: 5 messages, 200 ms apart (receiver goes idle between them).
    let sender = GlobalExecutor::spawn(async move {
        for i in 0..5 {
            sleep(Duration::from_millis(200)).await;
            tracing::info!("Sender: Sending message {}", i);
            tx.send(format!("msg{}", i)).await.unwrap();
        }
    });
    let mut received = 0;
    for iteration in 1..=20 {
        tracing::info!("Iteration {}: Calling stream.next().await", iteration);
        let msg = timeout(Duration::from_millis(300), test_stream.next()).await;
        match msg {
            Ok(Some(msg)) => {
                received += 1;
                tracing::info!("✓ Received: {} (total: {})", msg, received);
            }
            Ok(None) => {
                tracing::info!("Stream ended");
                break;
            }
            Err(_) => {
                tracing::info!(
                    "Timeout on iteration {} (received {} so far)",
                    iteration,
                    received
                );
                if received >= 5 {
                    break;
                }
            }
        }
    }
    sender.await.unwrap();
    assert_eq!(
        received, 5,
        "Recreating futures on each poll should STILL receive all messages! Got {}",
        received
    );
    let final_counter = *counter.lock().unwrap();
    assert_eq!(final_counter, 5, "Counter should be 5");
    tracing::info!("✓ SUCCESS: Recreating futures on each poll MAINTAINS waker registration!");
    tracing::info!(
        "✓ The stream struct staying alive is what matters, not the individual futures!"
    );
}
/// Same recreate-future-per-poll scenario as the previous test, but the inner
/// future uses a nested `tokio::select!` over two receivers (mirroring the
/// shape of the HandshakeHandler event loop). All three pre-queued messages
/// must be delivered after both senders are dropped.
#[tokio::test]
#[test_log::test]
async fn test_recreating_futures_with_nested_select() {
    use futures::StreamExt;
    tracing::info!("=== Testing stream with NESTED select (like HandshakeHandler) ===");
    struct MockWithNestedSelect {
        rx1: tokio::sync::mpsc::Receiver<String>,
        rx2: tokio::sync::mpsc::Receiver<String>,
        counter: std::sync::Arc<std::sync::Mutex<usize>>,
        // Closed flags gate the select arms so a closed channel is not polled again.
        rx1_closed: bool,
        rx2_closed: bool,
    }
    impl MockWithNestedSelect {
        // Returns the next message from either channel, tagging its origin;
        // channel closure is reported once, and "both_closed" when neither
        // arm is enabled any more.
        async fn wait_for_event(&mut self) -> String {
            loop {
                tokio::select! {
                    msg1 = self.rx1.recv(), if !self.rx1_closed => {
                        match msg1 {
                            Some(msg) => {
                                let mut counter = self.counter.lock().unwrap();
                                *counter += 1;
                                tracing::info!("Nested select: rx1 received '{}', counter {}", msg, *counter);
                                return format!("rx1:{}", msg);
                            }
                            None => {
                                self.rx1_closed = true;
                                if self.rx2_closed {
                                    return "rx1:closed".to_string();
                                }
                                continue;
                            }
                        }
                    }
                    msg2 = self.rx2.recv(), if !self.rx2_closed => {
                        match msg2 {
                            Some(msg) => {
                                let mut counter = self.counter.lock().unwrap();
                                *counter += 1;
                                tracing::info!("Nested select: rx2 received '{}', counter {}", msg, *counter);
                                return format!("rx2:{}", msg);
                            }
                            None => {
                                self.rx2_closed = true;
                                if self.rx1_closed {
                                    return "rx2:closed".to_string();
                                }
                                continue;
                            }
                        }
                    }
                    else => {
                        return "both_closed".to_string();
                    }
                }
            }
        }
    }
    struct TestStream {
        special: MockWithNestedSelect,
    }
    impl Stream for TestStream {
        type Item = String;
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Fresh future every poll, as in the test above.
            let fut = self.special.wait_for_event();
            tokio::pin!(fut);
            match fut.poll(cx) {
                Poll::Ready(msg) => Poll::Ready(Some(msg)),
                Poll::Pending => Poll::Pending,
            }
        }
    }
    let counter = std::sync::Arc::new(std::sync::Mutex::new(0));
    let (tx1, rx1) = mpsc::channel::<String>(10);
    let (tx2, rx2) = mpsc::channel::<String>(10);
    // Pre-queue three messages alternating between the two channels, then close both.
    for i in 0..3 {
        if i % 2 == 0 {
            tracing::info!("Sending to rx1: msg{}", i);
            tx1.send(format!("msg{}", i)).await.unwrap();
        } else {
            tracing::info!("Sending to rx2: msg{}", i);
            tx2.send(format!("msg{}", i)).await.unwrap();
        }
    }
    tracing::info!("All 3 messages sent, now dropping senders");
    drop(tx1);
    drop(tx2);
    let test_stream = TestStream {
        special: MockWithNestedSelect {
            rx1,
            rx2,
            counter: counter.clone(),
            rx1_closed: false,
            rx2_closed: false,
        },
    };
    tokio::pin!(test_stream);
    let mut received = Vec::new();
    for iteration in 1..=10 {
        tracing::info!("Iteration {}: Calling stream.next().await", iteration);
        let msg = timeout(Duration::from_millis(100), test_stream.as_mut().next()).await;
        match msg {
            Ok(Some(msg)) => {
                // Closure notifications are not payload; keep polling.
                if msg.contains("closed") {
                    tracing::info!("Channel closed: {}", msg);
                    continue;
                }
                received.push(msg.clone());
                tracing::info!("✓ Received: {} (total: {})", msg, received.len());
                if received.len() >= 3 {
                    break;
                }
            }
            Ok(None) => {
                tracing::info!("Stream ended");
                break;
            }
            Err(_) => {
                tracing::info!(
                    "Timeout on iteration {} (received {} so far)",
                    iteration,
                    received.len()
                );
                break;
            }
        }
    }
    assert_eq!(
        received.len(),
        3,
        "Stream with NESTED select should receive all messages! Got {} messages: {:?}",
        received.len(),
        received
    );
    let final_counter = *counter.lock().unwrap();
    assert_eq!(final_counter, 3, "Counter should be 3");
    tracing::info!(
        "✅ SUCCESS: Stream with NESTED select (like HandshakeHandler) maintains waker registration!"
    );
    tracing::info!("✅ Received all messages: {:?}", received);
}
/// Stress variant of the nested-select test: 1000 messages arrive rapidly
/// (10µs apart, alternating between two channels) while the stream is polled
/// with per-iteration timeouts. Every message must be delivered — none may be
/// dropped by a lost wakeup under load.
#[tokio::test]
#[test_log::test]
async fn test_nested_select_concurrent_arrivals() {
    use futures::StreamExt;
    tracing::info!("=== Testing nested select with rapid concurrent arrivals ===");
    struct MockWithNestedSelect {
        rx1: tokio::sync::mpsc::Receiver<String>,
        rx2: tokio::sync::mpsc::Receiver<String>,
        // Closed flags gate the select arms so a closed channel is not polled again.
        rx1_closed: bool,
        rx2_closed: bool,
    }
    impl MockWithNestedSelect {
        // Returns the next message from either channel, tagging its origin.
        async fn wait_for_event(&mut self) -> String {
            loop {
                tokio::select! {
                    msg1 = self.rx1.recv(), if !self.rx1_closed => {
                        match msg1 {
                            Some(msg) => {
                                tracing::info!("Nested select: rx1 received '{}'", msg);
                                return format!("rx1:{}", msg);
                            }
                            None => {
                                self.rx1_closed = true;
                                if self.rx2_closed {
                                    return "rx1:closed".to_string();
                                }
                                continue;
                            }
                        }
                    }
                    msg2 = self.rx2.recv(), if !self.rx2_closed => {
                        match msg2 {
                            Some(msg) => {
                                tracing::info!("Nested select: rx2 received '{}'", msg);
                                return format!("rx2:{}", msg);
                            }
                            None => {
                                self.rx2_closed = true;
                                if self.rx1_closed {
                                    return "rx2:closed".to_string();
                                }
                                continue;
                            }
                        }
                    }
                    else => {
                        return "both_closed".to_string();
                    }
                }
            }
        }
    }
    struct TestStream {
        special: MockWithNestedSelect,
    }
    impl Stream for TestStream {
        type Item = String;
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Fresh future per poll; waker state lives in the receivers.
            let fut = self.special.wait_for_event();
            tokio::pin!(fut);
            match fut.poll(cx) {
                Poll::Ready(msg) => Poll::Ready(Some(msg)),
                Poll::Pending => Poll::Pending,
            }
        }
    }
    let (tx1, rx1) = mpsc::channel::<String>(10);
    let (tx2, rx2) = mpsc::channel::<String>(10);
    let test_stream = TestStream {
        special: MockWithNestedSelect {
            rx1,
            rx2,
            rx1_closed: false,
            rx2_closed: false,
        },
    };
    tokio::pin!(test_stream);
    const MESSAGE_COUNT: usize = 1000;
    // Sender runs concurrently with the receive loop below; the capacity-10
    // channels make it backpressure against the consumer.
    GlobalExecutor::spawn(async move {
        for i in 0..MESSAGE_COUNT {
            if i % 2 == 0 {
                if i % 100 == 0 {
                    tracing::info!("Sending msg{} to rx1 ({} sent)", i, i);
                }
                tx1.send(format!("msg{}", i)).await.unwrap();
            } else {
                if i % 100 == 0 {
                    tracing::info!("Sending msg{} to rx2 ({} sent)", i, i);
                }
                tx2.send(format!("msg{}", i)).await.unwrap();
            }
            sleep(Duration::from_micros(10)).await;
        }
        tracing::info!("Sender finished: sent all {} messages", MESSAGE_COUNT);
    });
    let mut received = Vec::new();
    // Allow some slack iterations beyond MESSAGE_COUNT for closure events.
    for iteration in 0..(MESSAGE_COUNT + 100) {
        match timeout(Duration::from_millis(100), test_stream.as_mut().next()).await {
            Ok(Some(msg)) => {
                if !msg.contains("closed") {
                    received.push(msg);
                    if received.len() % 100 == 0 {
                        tracing::info!("Received {} of {} messages", received.len(), MESSAGE_COUNT);
                    }
                }
                if received.len() >= MESSAGE_COUNT {
                    break;
                }
            }
            Ok(None) => break,
            Err(_) => {
                tracing::info!(
                    "Timeout on iteration {} after receiving {} messages",
                    iteration,
                    received.len()
                );
                break;
            }
        }
    }
    assert_eq!(
        received.len(),
        MESSAGE_COUNT,
        "Should receive all {} messages even with rapid arrivals! Got {}. First 10: {:?}, Last 10: {:?}",
        MESSAGE_COUNT,
        received.len(),
        &received[..received.len().min(10)],
        &received[received.len().saturating_sub(10)..]
    );
    tracing::info!("✅ SUCCESS: All {} rapid messages received!", MESSAGE_COUNT);
    tracing::info!(
        "✅ Nested select with stream maintains waker registration under high concurrent load!"
    );
}
/// After an initial poll that returns `Pending` (nothing queued yet), a
/// notification sent 50ms later must wake the `PrioritySelectStream`. A
/// timeout here means waker registration was lost between polls.
#[tokio::test]
#[test_log::test]
async fn test_waker_registration_after_pending_poll() {
    let (tx, rx) = mpsc::channel::<Either<NetMessage, NodeEvent>>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (bridge_tx, bridge_rx) = mpsc::channel(10);
    let (node_tx, node_rx) = mpsc::channel(10);
    let stream = PrioritySelectStream::new(
        rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let tx_clone = tx.clone();
    // Delay the send so the first poll of the stream observes Pending.
    let sender = GlobalExecutor::spawn(async move {
        sleep(Duration::from_millis(50)).await;
        let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
            crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
        ));
        tx_clone.send(Either::Left(test_msg)).await.unwrap();
        tracing::info!("Message sent after delay");
    });
    let result = timeout(Duration::from_millis(200), stream.next()).await;
    assert!(
        result.is_ok(),
        "Stream should wake up when message arrives after initial Pending poll. \
        If this times out, waker registration is being lost (futures recreated in poll)"
    );
    match result.unwrap() {
        Some(SelectResult::Notification(Some(_))) => {
            tracing::info!("✅ Waker registration maintained - message received correctly");
        }
        other => panic!(
            "Expected Notification(Some(_)), got {:?}. \
            Waker registration may be broken.",
            other
        ),
    }
    sender.await.unwrap();
    // Senders are kept alive to this point so no channel closes mid-test.
    drop(tx);
    drop(op_tx);
    drop(bridge_tx);
    drop(node_tx);
}
/// Polls the stream to `Pending` five times in a row (each via a short
/// timeout), then sends one message: the stream must still wake and deliver
/// it, proving waker registration survives repeated Pending poll cycles.
#[tokio::test]
#[test_log::test]
async fn test_waker_survives_multiple_pending_polls() {
    let (tx, rx) = mpsc::channel::<Either<NetMessage, NodeEvent>>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (bridge_tx, bridge_rx) = mpsc::channel(10);
    let (node_tx, node_rx) = mpsc::channel(10);
    let stream = PrioritySelectStream::new(
        rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    // Each of these polls must time out: nothing has been sent yet.
    for i in 0..5 {
        let poll_result = timeout(Duration::from_millis(5), stream.as_mut().next()).await;
        assert!(
            poll_result.is_err(),
            "Poll {} should timeout (no messages yet)",
            i
        );
    }
    let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
        crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
    ));
    tx.send(Either::Left(test_msg)).await.unwrap();
    let result = timeout(Duration::from_millis(100), stream.as_mut().next()).await;
    assert!(
        result.is_ok(),
        "Stream should receive message even after multiple Pending polls. \
        If this fails, waker registration is being lost across poll cycles."
    );
    match result.unwrap() {
        Some(SelectResult::Notification(Some(_))) => {
            tracing::info!(
                "✅ Waker registration maintained across {} Pending polls",
                5
            );
        }
        other => panic!("Expected Notification(Some(_)), got {:?}", other),
    }
    drop(op_tx);
    drop(bridge_tx);
    drop(node_tx);
}
/// A handshake stream that ends (`Poll::Ready(None)`) must yield exactly one
/// `Handshake(None)` closure event, then leave the select Pending — it must
/// neither spin on the closed stream nor block delivery of later events.
#[tokio::test]
#[test_log::test]
async fn test_closed_handshake_stream_no_spin() {
    let (notif_tx, notif_rx) = mpsc::channel(10);
    let (_op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(10);
    let (_conn_event_tx, conn_event_rx) = mpsc::channel(10);
    let (_bridge_tx, bridge_rx) = mpsc::channel(10);
    let (_node_tx, node_rx) = mpsc::channel(10);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        ClosedHandshakeStream,
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let result = timeout(Duration::from_millis(100), stream.as_mut().next()).await;
    assert!(
        result.is_ok(),
        "Should get closure notification immediately"
    );
    match result.unwrap() {
        Some(SelectResult::Handshake(None)) => {}
        other => panic!("Expected Handshake(None), got {:?}", other),
    }
    // The closure must be reported once only; a second poll should hang.
    let result = timeout(Duration::from_millis(50), stream.as_mut().next()).await;
    assert!(
        result.is_err(),
        "Should not get another event — stream should be pending, not spinning"
    );
    // Other sources must remain functional after the handshake stream closed.
    let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
        crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
    ));
    notif_tx.send(Either::Left(test_msg)).await.unwrap();
    let result = timeout(Duration::from_millis(100), stream.as_mut().next()).await;
    assert!(result.is_ok(), "Notification should still arrive");
    match result.unwrap() {
        Some(SelectResult::Notification(Some(_))) => {}
        other => panic!("Expected Notification(Some(_)), got {:?}", other),
    }
}
/// Builds a throwaway P1 notification payload: an `Aborted` network message
/// wrapping a fresh PUT transaction id, wrapped in `Either::Left`.
fn dummy_notif_msg() -> Either<NetMessage, NodeEvent> {
    let tx = crate::message::Transaction::new::<crate::operations::put::PutMsg>();
    let msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(tx));
    Either::Left(msg)
}
/// Builds a throwaway P7 client-transaction pair: the next available client
/// id together with a waiting PUT transaction.
fn dummy_client_tx() -> (ClientId, WaitingTransaction) {
    let tx = Transaction::new::<crate::operations::put::PutMsg>();
    let waiting = crate::contract::WaitingTransaction::Transaction(tx);
    (crate::client_events::ClientId::next(), waiting)
}
/// Pulls up to `max_items` results from the stream, labelling each delivered
/// event with a static name. Stops on a 50ms timeout or stream end; results
/// whose inner value signals channel closure (`None`/`Err`) are skipped
/// without producing a label.
async fn drain_stream<H, C, E>(
    stream: &mut Pin<&mut PrioritySelectStream<H, C, E>>,
    max_items: usize,
) -> Vec<&'static str>
where
    H: Stream<Item = crate::node::network_bridge::handshake::Event> + Unpin,
    C: Stream<Item = (ClientId, WaitingTransaction)> + Unpin,
    E: Stream<Item = Transaction> + Unpin,
{
    use futures::StreamExt;
    let mut labels = Vec::new();
    for _ in 0..max_items {
        let item = match timeout(Duration::from_millis(50), stream.as_mut().next()).await {
            Ok(Some(item)) => item,
            // Timeout or stream exhaustion: nothing more to drain.
            _ => break,
        };
        let label = match &item {
            SelectResult::Notification(Some(_)) => "notification",
            SelectResult::OpExecution(Some(_)) => "op_execution",
            SelectResult::PeerConnection(Some(_)) => "peer_connection",
            SelectResult::ConnBridge(Some(_)) => "conn_bridge",
            SelectResult::Handshake(Some(_)) => "handshake",
            SelectResult::NodeController(Some(_)) => "node_controller",
            SelectResult::ClientTransaction(Ok(_)) => "client_transaction",
            SelectResult::ExecutorTransaction(Ok(_)) => "executor_transaction",
            // Closure/error variants carry no event; skip them. Spelled out
            // exhaustively (no `_`) so a new variant forces a compile error here.
            SelectResult::Notification(_)
            | SelectResult::OpExecution(_)
            | SelectResult::PeerConnection(_)
            | SelectResult::ConnBridge(_)
            | SelectResult::Handshake(_)
            | SelectResult::NodeController(_)
            | SelectResult::ClientTransaction(_)
            | SelectResult::ExecutorTransaction(_) => continue,
        };
        labels.push(label);
    }
    labels
}
/// Under a flood of 100 P1 notifications with 10 P7 client transactions
/// queued, the anti-starvation mechanism must deliver the first P7 item
/// within `MAX_HIGH_PRIORITY_BURST` polls and interleave the rest, rather
/// than deferring all P7 work until the P1 backlog is exhausted.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_p7_serviced_under_load() {
    const P1_COUNT: usize = 100;
    const P7_COUNT: usize = 10;
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    let (notif_tx, notif_rx) = mpsc::channel(P1_COUNT + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (client_tx, client_rx) = mpsc::channel(P7_COUNT + 10);
    let (_, executor_rx) = mpsc::channel::<Transaction>(1);
    for _ in 0..P1_COUNT {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    for _ in 0..P7_COUNT {
        client_tx.send(dummy_client_tx()).await.unwrap();
    }
    // Close the senders so the drain terminates on channel exhaustion.
    drop(notif_tx);
    drop(client_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorReceiverStream { rx: executor_rx },
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = drain_stream(&mut stream, P1_COUNT + P7_COUNT + 20).await;
    let total_notif = events.iter().filter(|&&e| e == "notification").count();
    let total_client = events
        .iter()
        .filter(|&&e| e == "client_transaction")
        .count();
    assert_eq!(total_notif, P1_COUNT, "All P1 messages received");
    assert_eq!(total_client, P7_COUNT, "All P7 messages received");
    let first_p7_idx = events
        .iter()
        .position(|&e| e == "client_transaction")
        .expect("P7 messages must exist");
    assert!(
        first_p7_idx <= burst,
        "First P7 message at index {} but should be at most {} (MAX_HIGH_PRIORITY_BURST)",
        first_p7_idx,
        burst
    );
    let last_p7_idx = events
        .iter()
        .rposition(|&e| e == "client_transaction")
        .unwrap();
    let last_notif_idx = events.iter().rposition(|&e| e == "notification").unwrap();
    // Interleaving check: at least one P1 item must come after the first P7 item.
    assert!(
        first_p7_idx < last_notif_idx,
        "P7 messages should be interleaved with P1, not all at the end"
    );
    tracing::info!(
        "Anti-starvation: first P7 at index {}, last P7 at {}, last P1 at {}",
        first_p7_idx,
        last_p7_idx,
        last_notif_idx
    );
}
/// With exactly `burst + 1` P1 items and a single P7 item queued, the burst
/// counter must force the P7 item out at index `burst` — i.e. right after the
/// maximum allowed run of consecutive high-priority deliveries.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_counter_reset_on_tier2() {
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    let p1_count = burst + 1;
    let (notif_tx, notif_rx) = mpsc::channel(p1_count + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (client_tx, client_rx) = mpsc::channel(10);
    let (_, executor_rx) = mpsc::channel::<Transaction>(1);
    for _ in 0..p1_count {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    client_tx.send(dummy_client_tx()).await.unwrap();
    // Close senders so the drain terminates on exhaustion.
    drop(notif_tx);
    drop(client_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorReceiverStream { rx: executor_rx },
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = drain_stream(&mut stream, p1_count + 10).await;
    assert_eq!(
        events.get(burst),
        Some(&"client_transaction"),
        "P7 should appear at index {} (right after {} consecutive P1 items), got events: {:?}",
        burst,
        burst,
        &events[..std::cmp::min(events.len(), burst + 3)]
    );
}
/// When the P1 backlog (20) stays below `MAX_HIGH_PRIORITY_BURST`, strict
/// priority order must hold: every P1 item is delivered before any P7 item —
/// anti-starvation must not reorder traffic it doesn't need to.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_preserves_priority_under_burst_limit() {
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    let p1_count = 20;
    let p7_count = 5;
    // Precondition for the scenario: the P1 backlog must fit inside one burst.
    assert!(p1_count < burst, "Test requires P1 count < burst limit");
    let (notif_tx, notif_rx) = mpsc::channel(p1_count + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (client_tx, client_rx) = mpsc::channel(p7_count + 10);
    let (_, executor_rx) = mpsc::channel::<Transaction>(1);
    for _ in 0..p1_count {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    for _ in 0..p7_count {
        client_tx.send(dummy_client_tx()).await.unwrap();
    }
    drop(notif_tx);
    drop(client_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorReceiverStream { rx: executor_rx },
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = drain_stream(&mut stream, p1_count + p7_count + 10).await;
    let total_notif = events.iter().filter(|&&e| e == "notification").count();
    let total_client = events
        .iter()
        .filter(|&&e| e == "client_transaction")
        .count();
    assert_eq!(total_notif, p1_count);
    assert_eq!(total_client, p7_count);
    let last_notif = events.iter().rposition(|&e| e == "notification").unwrap();
    let first_client = events
        .iter()
        .position(|&e| e == "client_transaction")
        .unwrap();
    assert!(
        last_notif < first_client,
        "Under burst limit, all P1 (last at {}) should precede all P7 (first at {})",
        last_notif,
        first_client
    );
}
/// When the burst limit triggers a forced poll of tier-2 sources that are
/// permanently `Pending`, the stream must fall through and keep delivering
/// P1 traffic instead of blocking on the idle tier-2 sources.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_force_poll_tier2_pending_falls_through() {
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    // Enough P1 traffic to trigger the force-poll path at least once.
    let p1_count = burst * 2;
    let (notif_tx, notif_rx) = mpsc::channel(p1_count + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    for _ in 0..p1_count {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    drop(notif_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        // Tier-2 sources that never yield: always Poll::Pending.
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = timeout(Duration::from_secs(5), async {
        drain_stream(&mut stream, p1_count + 10).await
    })
    .await
    .expect("drain_stream should complete within 5s — force-poll must not block on Pending Tier-2");
    let total_notif = events.iter().filter(|&&e| e == "notification").count();
    assert_eq!(
        total_notif, p1_count,
        "All {} P1 messages should be received even when P7/P8 are always Pending",
        p1_count
    );
}
/// Mixed tier-2 states during a forced poll: the P7 (client) channel is
/// closed up front while P8 (executor) stays permanently `Pending`. The
/// force-poll path must handle both without blocking and still deliver the
/// full P1 backlog.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_mixed_p7_closed_p8_pending() {
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    // Enough P1 traffic to trigger the force-poll path at least once.
    let p1_count = burst * 2;
    let (notif_tx, notif_rx) = mpsc::channel(p1_count + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (client_tx, client_rx) = mpsc::channel::<(ClientId, WaitingTransaction)>(1);
    // Close the client channel immediately: P7 reports closure from the start.
    drop(client_tx);
    for _ in 0..p1_count {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    drop(notif_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = timeout(Duration::from_secs(5), async {
        drain_stream(&mut stream, p1_count + 10).await
    })
    .await
    .expect("drain_stream should complete — mixed closed/pending Tier-2 must not block");
    let total_notif = events.iter().filter(|&&e| e == "notification").count();
    assert_eq!(
        total_notif, p1_count,
        "All {} P1 messages should be received with P7 closed and P8 Pending",
        p1_count
    );
}
/// Runs the fairness check across several load mixes, from heavily tier-1
/// dominated to tiny tier-2 counts.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_high_load_fairness() {
    // Each tuple: (n_notif, n_op, n_bridge, n_node, n_client, n_executor).
    let configs = [
        (500, 200, 150, 100, 200, 50),
        (100, 100, 100, 100, 150, 30),
        (800, 0, 0, 0, 300, 10),
        (200, 200, 100, 100, 5, 2),
    ];
    for (n_notif, n_op, n_bridge, n_node, n_client, n_executor) in configs {
        test_high_load_fairness_config(n_notif, n_op, n_bridge, n_node, n_client, n_executor)
            .await;
    }
}
/// Drives one fairness scenario through `run_prefilled_stream` and checks:
/// all queued items are delivered, the first tier-2 item appears within the
/// burst limit, and no tier-1 run exceeds `MAX_HIGH_PRIORITY_BURST` while
/// tier-2 work is still pending.
async fn test_high_load_fairness_config(
    n_notif: usize,
    n_op: usize,
    n_bridge: usize,
    n_node: usize,
    n_client: usize,
    n_executor: usize,
) {
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    let total_tier1 = n_notif + n_op + n_bridge + n_node;
    let total_tier2 = n_client + n_executor;
    let total = total_tier1 + total_tier2;
    let tiers = run_prefilled_stream(n_notif, n_op, n_bridge, n_node, n_client, n_executor).await;
    assert_eq!(
        tiers.len(),
        total,
        "config=({n_notif},{n_op},{n_bridge},{n_node},{n_client},{n_executor}): expected {total} items, got {}",
        tiers.len()
    );
    // `false` marks a tier-1 delivery, `true` a tier-2 delivery.
    let recv_tier1 = tiers.iter().filter(|&&t| !t).count();
    let recv_tier2 = tiers.iter().filter(|&&t| t).count();
    assert_eq!(recv_tier1, total_tier1, "tier-1 count mismatch");
    assert_eq!(recv_tier2, total_tier2, "tier-2 count mismatch");
    // Fairness is only measurable when both tiers carry traffic.
    if total_tier2 == 0 || total_tier1 == 0 {
        return;
    }
    let first_t2 = tiers.iter().position(|&t| t).unwrap();
    assert!(
        first_t2 <= burst,
        "config=({n_notif},{n_op},{n_bridge},{n_node},{n_client},{n_executor}): first tier-2 at index {first_t2}, limit {burst}"
    );
    // Longest run of consecutive tier-1 deliveries observed while tier-2 items
    // were still waiting (same logic as prop_tests' helper); runs after the
    // last tier-2 delivery don't count — nothing can be starved then.
    let mut max_run = 0usize;
    let mut current_run = 0usize;
    let mut tier2_remaining = total_tier2;
    for &is_t2 in &tiers {
        if is_t2 {
            tier2_remaining -= 1;
            current_run = 0;
        } else if tier2_remaining > 0 {
            current_run += 1;
            max_run = max_run.max(current_run);
        } else {
            current_run = 0;
        }
    }
    assert!(
        max_run <= burst,
        "config=({n_notif},{n_op},{n_bridge},{n_node},{n_client},{n_executor}): max tier-1 run {max_run} exceeds burst {burst}"
    );
    tracing::info!(
        "config=({},{},{},{},{},{}): fairness OK — {} events, max tier-1 run={}",
        n_notif,
        n_op,
        n_bridge,
        n_node,
        n_client,
        n_executor,
        total,
        max_run
    );
}
/// Minimal `PeerConnectionApi` stub used to build synthetic handshake events;
/// every operation completes immediately and successfully, with no real I/O.
struct MockPeerConnection {
    // Reported as the remote address of this fake connection.
    addr: std::net::SocketAddr,
}
impl crate::transport::PeerConnectionApi for MockPeerConnection {
    fn remote_addr(&self) -> std::net::SocketAddr {
        self.addr
    }
    // Pretends the message was sent.
    fn send_message(
        &mut self,
        _msg: crate::message::NetMessage,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<(), crate::transport::TransportError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async { Ok(()) })
    }
    // Always yields an empty payload.
    fn recv(
        &mut self,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<Vec<u8>, crate::transport::TransportError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async { Ok(vec![]) })
    }
    // No-op: the mock never produces orphan streams.
    fn set_orphan_stream_registry(
        &mut self,
        _registry: std::sync::Arc<crate::operations::orphan_streams::OrphanStreamRegistry>,
    ) {
    }
    // Signals completion immediately (if a completion channel was supplied)
    // and reports success.
    fn send_stream_data(
        &mut self,
        _stream_id: crate::transport::peer_connection::StreamId,
        _data: bytes::Bytes,
        _metadata: Option<bytes::Bytes>,
        _completion_tx: Option<tokio::sync::oneshot::Sender<()>>,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<(), crate::transport::TransportError>>
                + Send
                + '_,
        >,
    > {
        if let Some(tx) = _completion_tx {
            // The receiver may already be gone; ignore the send error.
            let _ignored = tx.send(());
        }
        Box::pin(async { Ok(()) })
    }
    // Pretends the piping completed successfully.
    fn pipe_stream_data(
        &mut self,
        _outbound_stream_id: crate::transport::peer_connection::StreamId,
        _inbound_handle: crate::transport::peer_connection::streaming::StreamHandle,
        _metadata: Option<bytes::Bytes>,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<(), crate::transport::TransportError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async { Ok(()) })
    }
}
/// Adapter exposing an mpsc receiver of handshake events as a `Stream`, so
/// tests can inject real handshake events into `PrioritySelectStream`.
struct MockHandshakeReceiverStream {
    rx: mpsc::Receiver<crate::node::network_bridge::handshake::Event>,
}
impl Stream for MockHandshakeReceiverStream {
    type Item = crate::node::network_bridge::handshake::Event;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegate directly to the receiver; it handles waker registration.
        Pin::new(&mut self.rx).poll_recv(cx)
    }
}
/// Builds a synthetic inbound-connection handshake event backed by a
/// `MockPeerConnection` on a fixed loopback address.
fn dummy_handshake_event() -> crate::node::network_bridge::handshake::Event {
    let connection = Box::new(MockPeerConnection {
        addr: "127.0.0.1:9999".parse().unwrap(),
    });
    crate::node::network_bridge::handshake::Event::InboundConnection {
        transaction: None,
        peer: None,
        connection,
        transient: false,
    }
}
/// A single handshake (P2) event queued behind 100 P1 notifications must be
/// delivered exactly when the anti-starvation burst limit triggers (index
/// `MAX_HIGH_PRIORITY_BURST`), proving handshakes are covered by force-polls.
#[tokio::test]
#[test_log::test]
async fn test_handshake_not_starved_under_notification_load() {
    const P1_COUNT: usize = 100;
    const HS_COUNT: usize = 1;
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    let (notif_tx, notif_rx) = mpsc::channel(P1_COUNT + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (hs_tx, hs_rx) = mpsc::channel(HS_COUNT + 1);
    for _ in 0..P1_COUNT {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    hs_tx.send(dummy_handshake_event()).await.unwrap();
    // Close senders so the drain terminates on exhaustion.
    drop(notif_tx);
    drop(hs_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        MockHandshakeReceiverStream { rx: hs_rx },
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = drain_stream(&mut stream, P1_COUNT + HS_COUNT + 20).await;
    let total_notif = events.iter().filter(|&&e| e == "notification").count();
    let total_hs = events.iter().filter(|&&e| e == "handshake").count();
    assert_eq!(total_notif, P1_COUNT, "All P1 messages received");
    assert_eq!(total_hs, HS_COUNT, "All handshake events received");
    let first_hs_idx = events
        .iter()
        .position(|&e| e == "handshake")
        .expect("Handshake event must exist");
    // NOTE(review): this `<=` check is implied by the `==` check below; it is
    // kept for its more descriptive starvation failure message.
    assert!(
        first_hs_idx <= burst,
        "Handshake event at index {} but should appear within {} polls (was starved!)",
        first_hs_idx,
        burst
    );
    assert_eq!(
        first_hs_idx, burst,
        "Handshake should appear at exactly index {} (anti-starvation trigger), got {}",
        burst, first_hs_idx
    );
    tracing::info!(
        "Handshake starvation test: handshake at index {}, {} notifications total",
        first_hs_idx,
        total_notif
    );
}
/// With one handshake (P2) event and one op-execution (P3) payload queued
/// simultaneously, priority ordering requires the handshake to be delivered
/// first.
#[tokio::test]
#[test_log::test]
async fn test_handshake_p2_before_op_execution_p3() {
    let (_, notif_rx) = mpsc::channel(1);
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (hs_tx, hs_rx) = mpsc::channel(1);
    hs_tx.send(dummy_handshake_event()).await.unwrap();
    let dummy_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
        crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
    ));
    // The callback receiver is dropped immediately; only delivery order matters here.
    let (callback_tx, _) = mpsc::channel(1);
    op_tx.send((callback_tx, dummy_msg, None)).await.unwrap();
    // Close both senders so the drain terminates on exhaustion.
    drop(hs_tx);
    drop(op_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        MockHandshakeReceiverStream { rx: hs_rx },
        node_rx,
        MockClientStream,
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = drain_stream(&mut stream, 5).await;
    assert_eq!(events.len(), 2, "Should receive both events");
    assert_eq!(
        events[0], "handshake",
        "Handshake (P2) should be returned before op_execution (P3)"
    );
    assert_eq!(events[1], "op_execution");
}
/// During an anti-starvation force-poll, the handshake source (P2) must be
/// polled before the client-transaction source (P7): with one event queued on
/// each behind a large P1 backlog, the handshake must be delivered first.
#[tokio::test]
#[test_log::test]
async fn test_anti_starvation_includes_handshake() {
    let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
    // Enough P1 traffic to trigger multiple force-poll rounds.
    let p1_count = burst * 3;
    let (notif_tx, notif_rx) = mpsc::channel(p1_count + 10);
    let (_, op_rx) = mpsc::channel(1);
    let (_, conn_event_rx) = mpsc::channel(1);
    let (_, bridge_rx) = mpsc::channel(1);
    let (_, node_rx) = mpsc::channel(1);
    let (hs_tx, hs_rx) = mpsc::channel(5);
    let (client_tx, client_rx) = mpsc::channel(5);
    for _ in 0..p1_count {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    hs_tx.send(dummy_handshake_event()).await.unwrap();
    client_tx.send(dummy_client_tx()).await.unwrap();
    // Close all senders so the drain terminates on exhaustion.
    drop(notif_tx);
    drop(hs_tx);
    drop(client_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        MockHandshakeReceiverStream { rx: hs_rx },
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorStream,
        conn_event_rx,
    );
    tokio::pin!(stream);
    let events = drain_stream(&mut stream, p1_count + 10).await;
    let total_hs = events.iter().filter(|&&e| e == "handshake").count();
    let total_client = events
        .iter()
        .filter(|&&e| e == "client_transaction")
        .count();
    assert_eq!(total_hs, 1, "Handshake event must be delivered");
    assert_eq!(total_client, 1, "Client transaction must be delivered");
    let hs_idx = events.iter().position(|&e| e == "handshake").unwrap();
    let client_idx = events
        .iter()
        .position(|&e| e == "client_transaction")
        .unwrap();
    assert!(
        hs_idx < client_idx,
        "Handshake (idx={}) should appear before client tx (idx={}) in force-poll order",
        hs_idx,
        client_idx
    );
}
/// Pre-fills every priority source with the requested number of messages,
/// closes all senders, then drains the `PrioritySelectStream` to exhaustion.
///
/// Returns one flag per delivered item, in delivery order: `false` for a
/// tier-1 event (notification / op-execution / bridge / node-controller),
/// `true` for a tier-2 event (client / executor transaction).
async fn run_prefilled_stream(
    n_notif: usize,
    n_op: usize,
    n_bridge: usize,
    n_node: usize,
    n_client: usize,
    n_executor: usize,
) -> Vec<bool> {
    let total = n_notif + n_op + n_bridge + n_node + n_client + n_executor;
    if total == 0 {
        return vec![];
    }
    // `.max(1)`: mpsc::channel requires a non-zero capacity even for unused sources.
    let (notif_tx, notif_rx) = mpsc::channel(n_notif.max(1));
    let (op_tx, op_rx) = mpsc::channel::<OpExecutionPayload>(n_op.max(1));
    let (_, conn_event_rx) = mpsc::channel(1);
    let (bridge_tx, bridge_rx) = mpsc::channel(n_bridge.max(1));
    let (node_tx, node_rx) = mpsc::channel(n_node.max(1));
    let (client_tx, client_rx) = mpsc::channel(n_client.max(1));
    let (executor_tx, executor_rx) = mpsc::channel(n_executor.max(1));
    for _ in 0..n_notif {
        notif_tx.send(dummy_notif_msg()).await.unwrap();
    }
    for _ in 0..n_op {
        let msg = NetMessage::V1(crate::message::NetMessageV1::Aborted(
            crate::message::Transaction::new::<crate::operations::put::PutMsg>(),
        ));
        // Callback receiver is dropped immediately; only delivery counts matter.
        let (cb, _) = mpsc::channel(1);
        op_tx.send((cb, msg, None)).await.unwrap();
    }
    for _ in 0..n_bridge {
        bridge_tx
            .send(P2pBridgeEvent::NodeAction(NodeEvent::Disconnect {
                cause: None,
            }))
            .await
            .unwrap();
    }
    for _ in 0..n_node {
        node_tx
            .send(NodeEvent::Disconnect { cause: None })
            .await
            .unwrap();
    }
    for _ in 0..n_client {
        client_tx.send(dummy_client_tx()).await.unwrap();
    }
    for _ in 0..n_executor {
        executor_tx
            .send(Transaction::new::<crate::operations::put::PutMsg>())
            .await
            .unwrap();
    }
    // Close every sender so the drain loop below terminates on exhaustion.
    drop(notif_tx);
    drop(op_tx);
    drop(bridge_tx);
    drop(node_tx);
    drop(client_tx);
    drop(executor_tx);
    let stream = PrioritySelectStream::new(
        notif_rx,
        op_rx,
        bridge_rx,
        create_mock_handshake_stream(),
        node_rx,
        MockClientReceiverStream { rx: client_rx },
        MockExecutorReceiverStream { rx: executor_rx },
        conn_event_rx,
    );
    tokio::pin!(stream);
    let mut tiers = Vec::with_capacity(total);
    use futures::StreamExt;
    // Upper bound of 3x the expected count guards against an infinite loop if
    // the stream misbehaves; closure events consume iterations without pushing.
    for _ in 0..(total * 3) {
        match timeout(Duration::from_millis(50), stream.as_mut().next()).await {
            Ok(Some(ref r)) => {
                let tier = match r {
                    SelectResult::Notification(Some(_)) => Some(false),
                    SelectResult::OpExecution(Some(_)) => Some(false),
                    SelectResult::PeerConnection(Some(_)) => Some(false),
                    SelectResult::ConnBridge(Some(_)) => Some(false),
                    SelectResult::Handshake(Some(_)) => Some(false),
                    SelectResult::NodeController(Some(_)) => Some(false),
                    SelectResult::ClientTransaction(Ok(_)) => Some(true),
                    SelectResult::ExecutorTransaction(Ok(_)) => Some(true),
                    // Closure/error variants carry no payload.
                    SelectResult::Notification(_)
                    | SelectResult::OpExecution(_)
                    | SelectResult::PeerConnection(_)
                    | SelectResult::ConnBridge(_)
                    | SelectResult::Handshake(_)
                    | SelectResult::NodeController(_)
                    | SelectResult::ClientTransaction(_)
                    | SelectResult::ExecutorTransaction(_) => None,
                };
                if let Some(t) = tier {
                    tiers.push(t);
                }
            }
            // Timeout or stream end: done draining.
            _ => break,
        }
    }
    tiers
}
mod prop_tests {
    use super::*;
    use proptest::prelude::*;

    /// Length of the longest consecutive stretch of tier-1 deliveries that
    /// happened while at least one tier-2 message was still queued.
    ///
    /// `tiers` holds one entry per delivery (`true` = tier-2), `total_tier2`
    /// is how many tier-2 messages were enqueued in total. Returns 0 when no
    /// tier-2 traffic existed, since starvation is then impossible.
    fn max_tier1_run_while_tier2_pending(tiers: &[bool], total_tier2: usize) -> usize {
        if total_tier2 == 0 {
            return 0;
        }
        let mut pending_tier2 = total_tier2;
        let mut longest = 0;
        let mut run = 0;
        for &is_tier2 in tiers {
            if is_tier2 {
                // A tier-2 delivery drains the backlog and ends any tier-1 run.
                pending_tier2 -= 1;
                run = 0;
            } else if pending_tier2 > 0 {
                run += 1;
                longest = longest.max(run);
            } else {
                // No tier-2 left waiting: tier-1 runs here cannot starve anyone.
                run = 0;
            }
        }
        longest
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        /// Under any mix of tier-1 and tier-2 load, every message is delivered
        /// and tier-2 traffic is never starved: the first tier-2 delivery, and
        /// any tier-1 run while tier-2 is still pending, must stay within the
        /// MAX_HIGH_PRIORITY_BURST window.
        #[test]
        fn proptest_no_starvation_under_any_load(
            n_notif in 0usize..200,
            n_op in 0usize..200,
            n_bridge in 0usize..200,
            n_node in 0usize..200,
            n_client in 0usize..100,
            n_executor in 0usize..100,
        ) {
            let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
            let total_tier1 = n_notif + n_op + n_bridge + n_node;
            let total_tier2 = n_client + n_executor;
            let total = total_tier1 + total_tier2;
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let tiers = rt.block_on(run_prefilled_stream(
                n_notif, n_op, n_bridge, n_node, n_client, n_executor,
            ));
            // Nothing may be dropped: every enqueued message must come out.
            prop_assert_eq!(tiers.len(), total,
                "Expected {} total messages, got {}", total, tiers.len());
            let recv_tier2 = tiers.iter().filter(|&&is_t2| is_t2).count();
            let recv_tier1 = tiers.len() - recv_tier2;
            prop_assert_eq!(recv_tier1, total_tier1);
            prop_assert_eq!(recv_tier2, total_tier2);
            if total_tier2 > 0 {
                // The very first tier-2 delivery must arrive within one burst.
                let first_t2 = tiers.iter().position(|&is_t2| is_t2).unwrap();
                prop_assert!(first_t2 <= burst,
                    "First tier-2 at index {} exceeds burst limit {}", first_t2, burst);
                // And no later tier-1 run may exceed the burst either.
                let max_run = max_tier1_run_while_tier2_pending(&tiers, total_tier2);
                prop_assert!(max_run <= burst,
                    "Max tier-1 run {} exceeds burst limit {}", max_run, burst);
            }
        }

        /// When the whole tier-1 load fits inside a single burst window, strict
        /// priority must hold: all tier-1 messages drain before any tier-2 one.
        #[test]
        fn proptest_priority_preserved_under_burst_limit(
            n_notif in 0usize..8,
            n_op in 0usize..8,
            n_bridge in 0usize..8,
            n_node in 0usize..8,
            n_client in 0usize..50,
            n_executor in 0usize..50,
        ) {
            let burst = PrioritySelectStream::<MockHandshakeStream, MockClientStream, MockExecutorStream>::MAX_HIGH_PRIORITY_BURST as usize;
            let total_tier1 = n_notif + n_op + n_bridge + n_node;
            let total_tier2 = n_client + n_executor;
            // Only meaningful when both tiers carry traffic and tier-1 is
            // small enough that the burst limit never forces a tier-2 yield.
            prop_assume!(total_tier1 < burst && total_tier1 > 0 && total_tier2 > 0);
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let tiers = rt.block_on(run_prefilled_stream(
                n_notif, n_op, n_bridge, n_node, n_client, n_executor,
            ));
            prop_assert_eq!(tiers.len(), total_tier1 + total_tier2);
            let last_tier1 = tiers.iter().rposition(|&is_t2| !is_t2);
            let first_tier2 = tiers.iter().position(|&is_t2| is_t2);
            if let (Some(last_t1), Some(first_t2)) = (last_tier1, first_tier2) {
                prop_assert!(last_t1 < first_t2,
                    "Under burst limit: last tier-1 at {} should precede first tier-2 at {}",
                    last_t1, first_t2);
            }
        }
    }
}