use std::{sync::Arc, time::Duration};
use bytes::Bytes;
use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
use rand::rngs::OsRng;
use tari_common_sqlite::connection::DbConnection;
use tari_shutdown::Shutdown;
use tari_test_utils::{collect_stream, unpack_enum};
use tokio::{
sync::{broadcast, mpsc, oneshot},
time,
};
use super::protocol::{MessagingEventReceiver, MessagingProtocol};
use crate::{
message::{InboundMessage, MessageTag, MessagingReplyRx, OutboundMessage},
multiplexing::Substream,
net_address::MultiaddressesWithStats,
peer_manager::{
NodeId,
NodeIdentity,
Peer,
PeerFeatures,
PeerFlags,
PeerManager,
create_test_peer,
database::{MIGRATIONS, PeerDatabaseSql},
},
protocol::{
ProtocolEvent,
ProtocolId,
ProtocolNotification,
messaging::{MessagingEvent, SendFailReason},
},
test_utils::{
mocks::{ConnectivityManagerMockState, create_connectivity_mock, create_peer_connection_mock_pair},
node_id,
node_identity::build_node_identity,
},
types::{CommsPublicKey, TransportProtocol},
};
/// Payloads sent through the messaging protocol by the tests below.
static TEST_MSG1: Bytes = Bytes::from_static(b"TEST_MSG1");
static TEST_MSG2: Bytes = Bytes::from_static(b"TEST_MSG2");
/// Protocol id the test `MessagingProtocol` is constructed with; inbound
/// `ProtocolNotification`s sent by the tests are tagged with the same id.
static MESSAGING_PROTOCOL_ID: ProtocolId = ProtocolId::from_static(b"test/msg");
/// Builds a shared `PeerManager` backed by a freshly migrated temporary SQLite database.
///
/// The SQL peer database is created with a communication-node test peer from
/// `create_test_peer`, and the manager is configured with every transport protocol returned
/// by `TransportProtocol::get_all()`.
///
/// # Panics
/// Panics if the database connection, migration, or either constructor fails.
fn create_peer_manager() -> Arc<PeerManager> {
    let connection = DbConnection::connect_temp_file_and_migrate(MIGRATIONS).unwrap();
    let test_peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE);
    let database = PeerDatabaseSql::new(connection, &test_peer).unwrap();
    let manager = PeerManager::new(database, TransportProtocol::get_all()).unwrap();
    Arc::new(manager)
}
/// Spawns a `MessagingProtocol` task wired to a mock connectivity manager and returns the
/// handles a test needs to drive and observe it.
///
/// The returned tuple is, in order:
/// - a peer manager and a client node identity (neither is passed to the protocol here),
/// - the shared state of the spawned connectivity mock,
/// - the sender for protocol notifications (e.g. new inbound substreams),
/// - the sender for outbound message requests,
/// - the receiver for inbound messages delivered by the protocol,
/// - the receiver for protocol events (`MessageReceived` events are enabled),
/// - the `Shutdown` whose signal the protocol was given. Presumably dropping it ends the
///   protocol, so callers keep it bound (TODO confirm).
async fn spawn_messaging_protocol() -> (
    Arc<PeerManager>,
    Arc<NodeIdentity>,
    ConnectivityManagerMockState,
    mpsc::Sender<ProtocolNotification<Substream>>,
    mpsc::UnboundedSender<OutboundMessage>,
    mpsc::Receiver<InboundMessage>,
    MessagingEventReceiver,
    Shutdown,
) {
    let shutdown = Shutdown::new();

    let (connectivity, connectivity_mock) = create_connectivity_mock();
    let connectivity_state = connectivity_mock.get_shared_state();
    connectivity_mock.spawn();

    let peer_manager = create_peer_manager();
    let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT);

    // Test-side ends are returned; protocol-side ends are handed to `MessagingProtocol::new`.
    let (notification_tx, notification_rx) = mpsc::channel(10);
    let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
    let (inbound_tx, inbound_rx) = mpsc::channel(100);
    let (event_tx, event_rx) = broadcast::channel(100);

    let protocol = MessagingProtocol::new(
        MESSAGING_PROTOCOL_ID.clone(),
        connectivity,
        notification_rx,
        outbound_rx,
        event_tx,
        inbound_tx,
        shutdown.to_signal(),
    )
    .set_message_received_event_enabled(true);
    tokio::spawn(protocol.run());

    (
        peer_manager,
        node_identity,
        connectivity_state,
        notification_tx,
        outbound_tx,
        inbound_rx,
        event_rx,
        shutdown,
    )
}
#[tokio::test]
async fn new_inbound_substream_handling() {
let (peer_manager, _, conn_man_mock, proto_tx, outbound_msg_tx, mut inbound_msg_rx, mut events_rx, _shutdown) =
spawn_messaging_protocol().await;
let expected_node_id = node_id::random();
let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
let peer1 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);
peer_manager.add_or_update_peer(peer1.clone()).await.unwrap();
let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
let peer2 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);
let (_, conn1_state, conn2, _conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;
conn_man_mock.add_active_connection(conn2).await;
let (reply_tx, _reply_rx) = oneshot::channel();
let out_msg = OutboundMessage {
tag: MessageTag::new(),
reply: reply_tx.into(),
peer_node_id: peer1.node_id.clone(),
body: TEST_MSG1.clone(),
};
outbound_msg_tx.send(out_msg).unwrap();
let stream_theirs = conn1_state.next_incoming_substream().await.unwrap();
proto_tx
.send(ProtocolNotification::new(
MESSAGING_PROTOCOL_ID.clone(),
ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_theirs),
))
.await
.unwrap();
let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv())
.await
.unwrap()
.unwrap();
assert_eq!(in_msg.source_peer, expected_node_id);
assert_eq!(in_msg.body, TEST_MSG1);
let expected_tag = in_msg.tag;
let event = time::timeout(Duration::from_secs(5), events_rx.recv())
.await
.unwrap()
.unwrap();
unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &event);
assert_eq!(tag, &expected_tag);
assert_eq!(*node_id, expected_node_id);
}
/// An outbound request to a peer with an active connection should be written to a single
/// substream opened on that connection.
#[tokio::test]
async fn send_message_request() {
    let (_, local_identity, connectivity, _, outbound_tx, _, _, _shutdown) = spawn_messaging_protocol().await;
    let remote_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE);

    let (local_conn, local_conn_state, _, remote_conn_state) =
        create_peer_connection_mock_pair(local_identity.to_peer(), remote_identity.to_peer()).await;
    connectivity.add_active_connection(local_conn).await;

    outbound_tx
        .send(OutboundMessage::new(remote_identity.node_id().clone(), TEST_MSG1.clone()))
        .unwrap();

    // The message arrives framed on the substream the protocol opened towards the remote side.
    let substream = remote_conn_state.next_incoming_substream().await.unwrap();
    let mut reader = MessagingProtocol::framed(substream);
    let payload = reader.next().await.unwrap().unwrap();

    assert_eq!(payload, TEST_MSG1);
    assert_eq!(local_conn_state.call_count(), 1);
}
#[tokio::test]
async fn send_message_dial_failed() {
let (_, _, conn_manager_mock, _, request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await;
let node_id = node_id::random();
let (reply_tx, reply_rx) = oneshot::channel();
let out_msg = OutboundMessage::with_reply(node_id, TEST_MSG1.clone(), reply_tx.into());
request_tx.send(out_msg).unwrap();
let event = event_tx.recv().await.unwrap();
unpack_enum!(MessagingEvent::OutboundProtocolExited(_node_id) = &event);
let reply = reply_rx.await.unwrap().unwrap_err();
unpack_enum!(SendFailReason::PeerDialFailed = reply);
let calls = conn_manager_mock.take_calls().await;
assert_eq!(calls.len(), 2);
assert!(calls.iter().all(|evt| evt.starts_with("DialPeer")));
}
/// Disconnects the remote side of an established connection, then queues a burst of messages.
///
/// Every reply should resolve with either success or `SendFailReason::PeerDialFailed`, and at
/// least one must fail. The protocol should then emit `OutboundProtocolExited` for that peer.
#[tokio::test]
async fn send_message_substream_bulk_failure() {
    const NUM_MSGS: usize = 10;

    /// Queues `TEST_MSG1` for `node_id` and returns the receiver for its send result.
    ///
    /// `UnboundedSender::send` only needs `&self` and never blocks, so this helper needs
    /// neither `&mut` nor `async`.
    fn send_msg(request_tx: &mpsc::UnboundedSender<OutboundMessage>, node_id: NodeId) -> MessagingReplyRx {
        let (reply_tx, reply_rx) = oneshot::channel();
        let out_msg = OutboundMessage::with_reply(node_id, TEST_MSG1.clone(), reply_tx.into());
        request_tx.send(out_msg).unwrap();
        reply_rx
    }

    let (_, node_identity, conn_manager_mock, _, request_tx, _, mut events_rx, _shutdown) =
        spawn_messaging_protocol().await;
    let peer_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (conn1, _, _, peer_conn_mock2) =
        create_peer_connection_mock_pair(node_identity.to_peer(), peer_node_identity.to_peer()).await;
    let peer_node_id = peer_node_identity.node_id();
    conn_manager_mock.add_active_connection(conn1).await;

    let mut replies = Vec::with_capacity(NUM_MSGS);
    // After the first message, wait for the substream it causes to appear on the peer's side.
    replies.push(send_msg(&request_tx, peer_node_id.clone()));
    let _substream = peer_conn_mock2.next_incoming_substream().await.unwrap();

    // Tear the connection down from the peer's side, then queue the remaining messages.
    peer_conn_mock2.disconnect().await.unwrap();
    drop(peer_conn_mock2);
    replies.extend((1..NUM_MSGS).map(|_| send_msg(&request_tx, peer_node_id.clone())));

    let mut num_sent = 0usize;
    let mut num_failed = 0usize;
    for reply in replies {
        match reply.await.unwrap() {
            Ok(_) => num_sent += 1,
            Err(SendFailReason::PeerDialFailed) => num_failed += 1,
            Err(err) => unreachable!("Unexpected error {}", err),
        }
    }
    assert!(num_failed > 0);
    assert_eq!(num_sent + num_failed, NUM_MSGS);

    let event = time::timeout(Duration::from_secs(10), events_rx.recv())
        .await
        .unwrap()
        .unwrap();
    unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &event);
    assert_eq!(node_id, peer_node_id);
}
/// Queues many messages to one connected peer before any is sent.
///
/// All of them should arrive on one substream, every reply should report success, and only
/// one call should be made on the peer connection mock.
#[tokio::test]
async fn many_concurrent_send_message_requests() {
    const NUM_MSGS: usize = 100;
    let (_, _, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await;
    let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (conn1, peer_conn_mock1, _, peer_conn_mock2) =
        create_peer_connection_mock_pair(node_identity1.to_peer(), node_identity2.to_peer()).await;
    let node_id2 = node_identity2.node_id();
    conn_man_mock.add_active_connection(conn1).await;

    // Only the reply receivers are needed; the message tags were previously collected but never read.
    let mut reply_rxs = Vec::with_capacity(NUM_MSGS);
    for _ in 0..NUM_MSGS {
        let (reply_tx, reply_rx) = oneshot::channel();
        let out_msg = OutboundMessage {
            tag: MessageTag::new(),
            reply: reply_tx.into(),
            peer_node_id: node_id2.clone(),
            body: TEST_MSG1.clone(),
        };
        reply_rxs.push(reply_rx);
        request_tx.send(out_msg).unwrap();
    }

    let stream = peer_conn_mock2.next_incoming_substream().await.unwrap();
    let mut framed = MessagingProtocol::framed(stream);
    let messages = collect_stream!(framed, take = NUM_MSGS, timeout = Duration::from_secs(10));
    assert_eq!(messages.len(), NUM_MSGS);

    // Await every reply in completion order; each oneshot must deliver, and none may be an error.
    let results = reply_rxs.into_iter().collect::<FuturesUnordered<_>>().collect::<Vec<_>>().await;
    let num_failed = results.into_iter().map(Result::unwrap).filter(Result::is_err).count();
    assert_eq!(num_failed, 0);
    assert_eq!(peer_conn_mock1.call_count(), 1);
}
/// Queues many messages to a node with no active connection (a random node id).
///
/// Every reply should resolve to an error.
#[tokio::test]
async fn many_concurrent_send_message_requests_that_fail() {
    const NUM_MSGS: usize = 100;
    let (_, _, _, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await;
    let node_id2 = node_id::random();

    // Only the reply receivers are needed; the message tags were previously collected but never read.
    let mut reply_rxs = Vec::with_capacity(NUM_MSGS);
    for _ in 0..NUM_MSGS {
        let (reply_tx, reply_rx) = oneshot::channel();
        let out_msg = OutboundMessage {
            tag: MessageTag::new(),
            reply: reply_tx.into(),
            peer_node_id: node_id2.clone(),
            body: TEST_MSG1.clone(),
        };
        reply_rxs.push(reply_rx);
        request_tx.send(out_msg).unwrap();
    }

    let results = reply_rxs.into_iter().collect::<FuturesUnordered<_>>().collect::<Vec<_>>().await;
    // Count instead of `all(..)` so a failure reports how many replies actually errored.
    let num_failed = results.into_iter().map(|r| r.unwrap()).filter(|r| r.is_err()).count();
    assert_eq!(num_failed, NUM_MSGS);
}
/// Checks that only one inbound messaging session from the same node id is serviced at a time.
///
/// Flow exercised:
/// 1. A first substream attributed to `expected_node_id` delivers `TEST_MSG1`.
/// 2. A second substream, opened while the first is still open, is expected to be closed by
///    the protocol. The test keeps sending on it until `send` fails with "connection is closed".
/// 3. The first substream still delivers messages after the second one was rejected.
/// 4. Once the first substream is closed, a newly opened substream is accepted again.
#[tokio::test]
async fn new_inbound_substream_only_single_session_permitted() {
    let (peer_manager, node_identity_1, conn_man_mock, proto_tx, _, mut inbound_msg_rx, _, _shutdown) =
        spawn_messaging_protocol().await;
    // Node id that every forwarded inbound substream is attributed to (see the spawned task below).
    let expected_node_id = node_id::random();
    // Local side of the mock connection, built from the identity returned by the spawn helper.
    let peer1 = node_identity_1.to_peer();
    // Remote side: a client peer with a freshly generated public key and `expected_node_id`.
    let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
    let peer2 = Peer::new(
        pk.clone(),
        expected_node_id.clone(),
        MultiaddressesWithStats::default(),
        PeerFlags::empty(),
        PeerFeatures::COMMUNICATION_CLIENT,
        Default::default(),
        Default::default(),
    );
    peer_manager.add_or_update_peer(peer2.clone()).await.unwrap();
    let (conn1, conn1_state, _, conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;
    conn_man_mock.add_active_connection(conn1).await;
    // Forward every substream that arrives on `conn2_state` to the protocol as a
    // `NewInboundSubstream` from `expected_node_id`, until the mock stops yielding substreams.
    tokio::spawn({
        let expected_node_id = expected_node_id.clone();
        async move {
            while let Some(stream_theirs) = conn2_state.next_incoming_substream().await {
                proto_tx
                    .send(ProtocolNotification::new(
                        MESSAGING_PROTOCOL_ID.clone(),
                        ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_theirs),
                    ))
                    .await
                    .unwrap();
            }
        }
    });
    // First session: expected to be accepted and its message delivered.
    let stream_ours = conn1_state.open_substream().await.unwrap();
    let mut framed_ours = MessagingProtocol::framed(stream_ours);
    framed_ours.send(TEST_MSG1.clone()).await.unwrap();
    let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(in_msg.source_peer, expected_node_id);
    assert_eq!(in_msg.body, TEST_MSG1);
    // Second, concurrent session: expected to be rejected. Keep sending until a write fails, then
    // check that the error text after the first ':' is "connection is closed".
    // NOTE(review): this loop has no timeout. If the second session is never closed, the test
    // never finishes instead of failing.
    let stream_ours2 = conn1_state.open_substream().await.unwrap();
    let mut framed_ours2 = MessagingProtocol::framed(stream_ours2);
    loop {
        if let Err(e) = framed_ours2.send(TEST_MSG2.clone()).await {
            assert_eq!(
                e.to_string().split(':').nth(1).map(|s| s.trim()),
                Some("connection is closed"),
                "Expected connection to be closed but got '{e}'"
            );
            break;
        }
    }
    // The original session is unaffected by the rejected one.
    framed_ours.send(TEST_MSG1.clone()).await.unwrap();
    let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(in_msg.source_peer, expected_node_id);
    assert_eq!(in_msg.body, TEST_MSG1);
    // Close the first session; a new session should then be accepted.
    framed_ours.close().await.unwrap();
    let stream_ours = conn1_state.open_substream().await.unwrap();
    let mut framed_ours = MessagingProtocol::framed(stream_ours);
    framed_ours.send(TEST_MSG1.clone()).await.unwrap();
    let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(in_msg.source_peer, expected_node_id);
    assert_eq!(in_msg.body, TEST_MSG1);
}