#![allow(clippy::indexing_slicing)]
use std::{collections::HashSet, convert::identity, hash::Hash, time::Duration};
use bytes::Bytes;
use futures::stream::FuturesUnordered;
use tari_common_sqlite::connection::DbConnection;
use tari_shutdown::{Shutdown, ShutdownSignal};
use tari_test_utils::{collect_recv, collect_stream, unpack_enum};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{broadcast, mpsc, oneshot},
task,
};
use crate::{
CommsNode,
backoff::ConstantBackoff,
builder::CommsBuilder,
connection_manager::ConnectionManagerEvent,
memsocket,
message::{InboundMessage, OutboundMessage},
multiaddr::{Multiaddr, Protocol},
multiplexing::Substream,
net_address::{MultiaddressesWithStats, PeerAddressSource},
peer_manager::{
Peer,
PeerFeatures,
database::{MIGRATIONS, PeerDatabaseSql},
},
pipeline,
pipeline::SinkService,
protocol::{
ProtocolEvent,
ProtocolId,
Protocols,
messaging::{MessagingEvent, MessagingEventSender, MessagingProtocolExtension},
},
test_utils::node_identity::build_node_identity,
transports::MemoryTransport,
};
/// Builds and spawns a comms node listening on a freshly acquired `/memory/<port>` address.
///
/// The node is given the caller's `protocols` plus a messaging protocol extension registered under the
/// `test/msg` protocol id, with the message-received event enabled and at most one concurrent inbound
/// pipeline task.
///
/// Returns, in order:
/// 1. the spawned [`CommsNode`],
/// 2. the receiver fed by the inbound messaging pipeline,
/// 3. the sender that feeds the outbound messaging pipeline,
/// 4. the broadcast sender for messaging events (call `subscribe()` on it to listen).
///
/// # Panics
/// Every fallible setup step is unwrapped, so any failure aborts the calling test.
async fn spawn_node(
    protocols: Protocols<Substream>,
    shutdown_sig: ShutdownSignal,
) -> (
    CommsNode,
    mpsc::Receiver<InboundMessage>,
    mpsc::UnboundedSender<OutboundMessage>,
    MessagingEventSender,
) {
    // Each node gets its own memory-socket port so several nodes can coexist in one test.
    let port = memsocket::acquire_next_memsocket_port();
    let listen_addr: Multiaddr = format!("/memory/{}", port).parse().unwrap();

    let this_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    this_identity.add_public_address(listen_addr.clone());

    // Channel ends handed back to the caller: inbound is bounded (10), outbound is unbounded.
    let (inbound_tx, inbound_rx) = mpsc::channel(10);
    let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();

    // Peer storage lives in a temporary, freshly migrated database.
    let storage = PeerDatabaseSql::new(
        DbConnection::connect_temp_file_and_migrate(MIGRATIONS).unwrap(),
        &this_identity.to_peer(),
    )
    .unwrap();

    let unspawned = CommsBuilder::new()
        .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(500)))
        .with_shutdown_signal(shutdown_sig)
        .with_listener_address(listen_addr)
        .with_peer_storage(storage)
        .with_node_identity(this_identity)
        .build()
        .unwrap();

    // The initial broadcast receiver is discarded; callers subscribe through the returned sender.
    let (events_tx, _) = broadcast::channel(100);

    let messaging_pipeline = pipeline::Builder::new()
        .with_outbound_pipeline(outbound_rx, identity)
        .max_concurrent_inbound_tasks(1)
        .with_inbound_pipeline(SinkService::new(inbound_tx))
        .build();
    let messaging = MessagingProtocolExtension::new(
        ProtocolId::from_static(b"test/msg"),
        events_tx.clone(),
        messaging_pipeline,
    )
    .enable_message_received_event();

    let mut node = unspawned
        .add_protocol_extensions(protocols.into())
        .add_protocol_extension(messaging)
        .spawn_with_transport(MemoryTransport)
        .await
        .unwrap();

    // Do not hand the node back until its listener is up.
    let listening = node
        .connection_manager_requester()
        .wait_until_listening()
        .await
        .unwrap();
    // Sanity check: the first component of the bound address should be a memory protocol.
    unpack_enum!(Protocol::Memory(_port) = listening.bind_address().iter().next().unwrap());

    (node, inbound_rx, outbound_tx, events_tx)
}
/// Exercises custom protocol negotiation between two in-memory nodes, in both directions.
///
/// Node 1 dials node 2. Node 1 then opens a substream on `TEST_PROTOCOL` and node 2 opens one on
/// `ANOTHER_TEST_PROTOCOL`; each writes a message. The test checks that the receiving side is notified
/// with the expected protocol and remote node id, and reads back exactly the bytes that were written.
#[tokio::test]
async fn peer_to_peer_custom_protocols() {
    static TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test");
    static ANOTHER_TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test-again");
    const TEST_MSG: &[u8] = b"Hello Tari";
    const ANOTHER_TEST_MSG: &[u8] = b"Comms is running smoothly";

    // Node 1 registers both protocols but only reads notifications for ANOTHER_TEST_PROTOCOL.
    // The unread receiver is still bound to a name so it is not dropped before the test ends.
    let (test_sender, _test_protocol_rx1) = mpsc::channel(10);
    let (another_test_sender, mut another_test_protocol_rx1) = mpsc::channel(10);
    let mut protocols1 = Protocols::new();
    protocols1
        .add([TEST_PROTOCOL.clone()], &test_sender)
        .add([ANOTHER_TEST_PROTOCOL.clone()], &another_test_sender);

    // Node 2 mirrors that setup and only reads notifications for TEST_PROTOCOL.
    let (test_sender, mut test_protocol_rx2) = mpsc::channel(10);
    let (another_test_sender, _another_test_protocol_rx2) = mpsc::channel(10);
    let mut protocols2 = Protocols::new();
    protocols2
        .add([TEST_PROTOCOL.clone()], &test_sender)
        .add([ANOTHER_TEST_PROTOCOL.clone()], &another_test_sender);

    let mut shutdown = Shutdown::new();
    let (comms_node1, _, _, _) = spawn_node(protocols1, shutdown.to_signal()).await;
    let (comms_node2, _, _, _) = spawn_node(protocols2, shutdown.to_signal()).await;

    let node_identity1 = comms_node1.node_identity();
    let node_identity2 = comms_node2.node_identity();

    // Tell node 1 about node 2 (address and supported protocols) so it can dial it.
    comms_node1
        .peer_manager()
        .add_or_update_peer(Peer::new(
            node_identity2.public_key().clone(),
            node_identity2.node_id().clone(),
            MultiaddressesWithStats::from_addresses_with_source(
                node_identity2.public_addresses(),
                &PeerAddressSource::Config,
            ),
            Default::default(),
            Default::default(),
            vec![TEST_PROTOCOL.clone(), ANOTHER_TEST_PROTOCOL.clone()],
            Default::default(),
        ))
        .await
        .unwrap();

    // Subscribe on both sides before dialing so the connection events are not missed.
    let mut conn_man_events1 = comms_node1.subscribe_connection_manager_events();
    let connectivity1 = comms_node1.connectivity();
    let mut conn_man_events2 = comms_node2.subscribe_connection_manager_events();

    let mut conn1 = connectivity1
        .dial_peer(node_identity2.node_id().clone())
        .await
        .unwrap();

    // The next event on each side must be PeerConnected; node 2's event carries its view of the connection.
    let next_event = conn_man_events2.recv().await.unwrap();
    unpack_enum!(ConnectionManagerEvent::PeerConnected(conn2) = &*next_event);
    let next_event = conn_man_events1.recv().await.unwrap();
    unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn) = &*next_event);

    // Node 1 -> node 2 over TEST_PROTOCOL.
    let mut negotiated_substream1 = conn1.open_substream(&TEST_PROTOCOL).await.unwrap();
    assert_eq!(negotiated_substream1.protocol, TEST_PROTOCOL);
    negotiated_substream1.stream.write_all(TEST_MSG).await.unwrap();

    // Node 2 -> node 1 over ANOTHER_TEST_PROTOCOL. `conn2` is borrowed from the event, hence the clone.
    let mut negotiated_substream2 = conn2.clone().open_substream(&ANOTHER_TEST_PROTOCOL).await.unwrap();
    assert_eq!(negotiated_substream2.protocol, ANOTHER_TEST_PROTOCOL);
    negotiated_substream2.stream.write_all(ANOTHER_TEST_MSG).await.unwrap();

    // Node 2 is notified of the inbound TEST_PROTOCOL substream from node 1 and reads TEST_MSG.
    let negotiation = test_protocol_rx2.recv().await.unwrap();
    assert_eq!(negotiation.protocol, TEST_PROTOCOL);
    unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event);
    assert_eq!(&node_id, node_identity1.node_id());
    let mut buf = [0u8; TEST_MSG.len()];
    substream.read_exact(&mut buf).await.unwrap();
    assert_eq!(buf, TEST_MSG);

    // Node 1 is notified of the inbound ANOTHER_TEST_PROTOCOL substream from node 2 and reads ANOTHER_TEST_MSG.
    let negotiation = another_test_protocol_rx1.recv().await.unwrap();
    assert_eq!(negotiation.protocol, ANOTHER_TEST_PROTOCOL);
    unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event);
    assert_eq!(&node_id, node_identity2.node_id());
    let mut buf = [0u8; ANOTHER_TEST_MSG.len()];
    substream.read_exact(&mut buf).await.unwrap();
    assert_eq!(buf, ANOTHER_TEST_MSG);

    shutdown.trigger();
    comms_node1.wait_until_shutdown().await;
    comms_node2.wait_until_shutdown().await;
}
/// Sends `NUM_MSGS` messages from node 1 to node 2, then `NUM_MSGS` back from node 2 to node 1.
///
/// For the first direction the test also checks that every send reply resolves successfully and that
/// node 2 emits a `MessageReceived` event per message. For both directions it checks that the full batch
/// arrives and that the i-th received body starts with the i-th zero-padded `#iii` prefix.
#[tokio::test]
async fn peer_to_peer_messaging() {
    const NUM_MSGS: usize = 100;

    /// Asserts that the first four bytes of each body match the `#iii` prefix for its position.
    fn check_messages(msgs: Vec<InboundMessage>) {
        for (i, msg) in msgs.iter().enumerate() {
            let expected_msg_prefix = format!("#{i:0>3}");
            assert_eq!(&msg.body[0..4], expected_msg_prefix.as_bytes());
        }
    }

    let shutdown = Shutdown::new();

    let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await;
    let (comms_node2, mut inbound_rx2, outbound_tx2, events_sender2) =
        spawn_node(Protocols::new(), shutdown.to_signal()).await;
    let mut events_rx2 = events_sender2.subscribe();

    let node_identity1 = comms_node1.node_identity();
    let node_identity2 = comms_node2.node_identity();

    // Register node 2 with node 1, marking each of its addresses as last seen now before storing it.
    let mut peer2 = Peer::new(
        node_identity2.public_key().clone(),
        node_identity2.node_id().clone(),
        MultiaddressesWithStats::from_addresses_with_source(
            node_identity2.public_addresses(),
            &PeerAddressSource::Config,
        ),
        Default::default(),
        PeerFeatures::COMMUNICATION_NODE,
        Default::default(),
        Default::default(),
    );
    // Collected first so the immutable iteration does not overlap the mutable update below.
    let known_addrs: Vec<_> = peer2.addresses.address_iter().cloned().collect();
    for addr in &known_addrs {
        peer2.addresses.mark_last_seen_now(addr);
    }
    comms_node1.peer_manager().add_or_update_peer(peer2).await.unwrap();

    // Node 1 -> node 2, each message carrying a reply channel for its send result.
    let mut replies = FuturesUnordered::new();
    for i in 0..NUM_MSGS {
        let (reply_tx, reply_rx) = oneshot::channel();
        replies.push(reply_rx);
        outbound_tx1
            .send(OutboundMessage::with_reply(
                node_identity2.node_id().clone(),
                format!("#{i:0>3} - comms messaging is so hot right now!").into(),
                reply_tx.into(),
            ))
            .unwrap();
    }

    let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10));

    // Every reply must have been delivered, and must report a successful send.
    let send_results = collect_stream!(replies, take = NUM_MSGS, timeout = Duration::from_secs(10));
    for r in send_results {
        r.unwrap().unwrap();
    }

    // Node 2 must have published one MessageReceived event per message.
    let events = collect_recv!(events_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10));
    for m in events {
        unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &m);
    }

    // Node 2 -> node 1, without reply channels.
    for i in 0..NUM_MSGS {
        outbound_tx2
            .send(OutboundMessage::new(
                node_identity1.node_id().clone(),
                format!("#{i:0>3} - comms messaging is so hot right now!").into(),
            ))
            .unwrap();
    }

    let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10));

    assert_eq!(messages1_to_2.len(), NUM_MSGS);
    check_messages(messages1_to_2);

    assert_eq!(messages2_to_1.len(), NUM_MSGS);
    check_messages(messages2_to_1);

    drop(shutdown);
    comms_node1.wait_until_shutdown().await;
    comms_node2.wait_until_shutdown().await;
}
/// Has two connected nodes send `NUM_MSGS` messages to each other from two concurrently spawned tasks.
///
/// Both nodes are told about each other, node 1 dials node 2, and then each side pushes its batch onto
/// its outbound channel from its own task. The test asserts that each side receives the full batch and
/// that no message body is received twice.
#[tokio::test]
async fn peer_to_peer_messaging_simultaneous() {
    const NUM_MSGS: usize = 100;
    let shutdown = Shutdown::new();

    let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await;
    let (comms_node2, mut inbound_rx2, outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await;

    log::info!(
        "Peer1 = `{}`, Peer2 = `{}`",
        comms_node1.node_identity().node_id().short_str(),
        comms_node2.node_identity().node_id().short_str()
    );

    // The original senders are moved into the spawned tasks below and dropped when those tasks finish.
    // These clones are held until after the messages are collected.
    // NOTE(review): presumably this keeps the outbound pipelines open while receiving - confirm.
    let o1 = outbound_tx1.clone();
    let o2 = outbound_tx2.clone();

    let node_identity1 = comms_node1.node_identity().clone();
    let node_identity2 = comms_node2.node_identity().clone();

    // Make each node aware of the other.
    comms_node1
        .peer_manager()
        .add_or_update_peer(Peer::new(
            node_identity2.public_key().clone(),
            node_identity2.node_id().clone(),
            MultiaddressesWithStats::from_addresses_with_source(
                node_identity2.public_addresses(),
                &PeerAddressSource::Config,
            ),
            Default::default(),
            Default::default(),
            Default::default(),
            Default::default(),
        ))
        .await
        .unwrap();
    comms_node2
        .peer_manager()
        .add_or_update_peer(Peer::new(
            node_identity1.public_key().clone(),
            node_identity1.node_id().clone(),
            MultiaddressesWithStats::from_addresses_with_source(
                node_identity1.public_addresses(),
                &PeerAddressSource::Config,
            ),
            Default::default(),
            Default::default(),
            Default::default(),
            Default::default(),
        ))
        .await
        .unwrap();

    // Establish the connection up front, before any message is queued.
    comms_node1
        .connectivity()
        .dial_peer(comms_node2.node_identity().node_id().clone())
        .await
        .unwrap();

    // Queue both batches from separate tasks so the two directions are sent concurrently.
    let handle1 = task::spawn(async move {
        for i in 0..NUM_MSGS {
            let outbound_msg = OutboundMessage::new(
                node_identity2.node_id().clone(),
                format!("#{i:0>3} - comms messaging is so hot right now!").into(),
            );
            outbound_tx1.send(outbound_msg).unwrap();
        }
    });

    let handle2 = task::spawn(async move {
        for i in 0..NUM_MSGS {
            let outbound_msg = OutboundMessage::new(
                node_identity1.node_id().clone(),
                format!("#{i:0>3} - comms messaging is so hot right now!").into(),
            );
            outbound_tx2.send(outbound_msg).unwrap();
        }
    });

    handle1.await.unwrap();
    handle2.await.unwrap();

    let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10));
    let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10));

    // Uniqueness alone would pass on a short (or empty) batch, so require the full count first.
    assert_eq!(messages1_to_2.len(), NUM_MSGS);
    assert_eq!(messages2_to_1.len(), NUM_MSGS);

    assert!(has_unique_elements(messages1_to_2.into_iter().map(|m| m.body)));
    assert!(has_unique_elements(messages2_to_1.into_iter().map(|m| m.body)));

    drop(o1);
    drop(o2);

    drop(shutdown);
    comms_node1.wait_until_shutdown().await;
    comms_node2.wait_until_shutdown().await;
}
/// Returns `true` when no two items produced by `iter` are equal.
///
/// Consumption stops at the first repeated item. An empty input counts as unique.
fn has_unique_elements<T>(iter: T) -> bool
where
    T: IntoIterator,
    T::Item: Eq + Hash,
{
    let mut seen = HashSet::new();
    for item in iter {
        // `insert` reports `false` when an equal value is already in the set.
        if !seen.insert(item) {
            return false;
        }
    }
    true
}