#![allow(clippy::indexing_slicing)]
use std::{fmt, io, sync::Arc};
use futures::{Sink, SinkExt, Stream, StreamExt};
use tari_common::configuration::Network;
use tari_comms::{
Bytes,
BytesMut,
connectivity::ConnectivityEvent,
framing,
memsocket::MemorySocket,
message::MessageExt,
peer_manager::PeerFeatures,
protocol::{ProtocolEvent, ProtocolNotification, ProtocolNotificationTx},
test_utils::{
mocks::{ConnectivityManagerMockState, create_connectivity_mock, create_peer_connection_mock_pair},
node_identity::build_node_identity,
},
};
use tari_transaction_components::{
key_manager::KeyManager,
tari_amount::uT,
test_helpers::create_tx,
transaction_components::Transaction,
};
use tari_utilities::ByteArray;
use tokio::{
sync::{broadcast, mpsc},
task,
};
use crate::{
consensus::BaseNodeConsensusManager,
mempool::{
Mempool,
proto,
sync_protocol::{MAX_FRAME_SIZE, MEMPOOL_SYNC_PROTOCOL, MempoolPeerProtocol, MempoolSyncProtocol},
},
validation::mocks::MockValidator,
};
/// Builds `n` test transactions with `create_tx`, all sharing one `KeyManager::new_random()`
/// instance.
///
/// Every transaction is created with the same fixed `create_tx` arguments.
///
/// # Panics
/// Panics if the key manager cannot be created or if any transaction fails to build.
pub fn create_transactions(n: usize) -> Vec<Transaction> {
    let key_manager = KeyManager::new_random().unwrap();
    (0..n)
        .map(|_| {
            let (tx, _, _) = create_tx(5000 * uT, 3 * uT, 1, 2, 1, 3, Default::default(), &key_manager)
                .expect("Failed to get transaction");
            tx
        })
        .collect()
}
/// Creates a `LocalNet` mempool backed by `MockValidator::new(true)` and inserts `n` freshly
/// created transactions into it.
///
/// Returns the mempool together with the inserted transactions, in insertion order.
///
/// # Panics
/// Panics if the consensus manager cannot be built or if any insert fails.
async fn new_mempool_with_transactions(n: usize) -> (Mempool, Vec<Transaction>) {
    let consensus = BaseNodeConsensusManager::builder(Network::LocalNet).build().unwrap();
    let mempool = Mempool::new(Default::default(), consensus, Box::new(MockValidator::new(true)));

    let txs = create_transactions(n);
    for tx in txs.iter().cloned() {
        mempool.insert(Arc::new(tx)).await.unwrap();
    }

    (mempool, txs)
}
/// Spawns a `MempoolSyncProtocol` over a mempool pre-populated with `num_txns` transactions.
///
/// Returns, in order:
/// * the sender used to deliver protocol notifications (e.g. inbound substreams) to the protocol,
/// * the connectivity mock state, used by tests to publish connectivity events,
/// * the protocol's mempool (a clone sharing state with the one handed to the protocol),
/// * the transactions that were inserted into that mempool.
///
/// The block-event broadcast sender is only kept alive until this function returns; the protocol
/// receives just a subscribed receiver.
async fn setup(
    num_txns: usize,
) -> (
    ProtocolNotificationTx<MemorySocket>,
    ConnectivityManagerMockState,
    Mempool,
    Vec<Transaction>,
) {
    let (notif_tx, notif_rx) = mpsc::channel(1);
    let (mempool, txs) = new_mempool_with_transactions(num_txns).await;
    let (connectivity, mock_service) = create_connectivity_mock();
    let mock_state = mock_service.spawn();

    // No block events are published in these tests; the protocol only needs a receiver.
    let (block_event_tx, _) = broadcast::channel(1);
    let block_event_rx = block_event_tx.subscribe();

    let sync_protocol = MempoolSyncProtocol::new(
        Default::default(),
        notif_rx,
        mempool.clone(),
        connectivity,
        block_event_rx,
    );
    task::spawn(sync_protocol.run());

    // Wait until the mock reports its event receivers are ready before handing control back, so
    // events published by the caller are not missed.
    mock_state.wait_until_event_receivers_ready().await;

    (notif_tx, mock_state, mempool, txs)
}
/// Synchronising two empty mempools must leave both of them empty.
///
/// After `PeerConnected` is published, the spawned protocol's substream is taken from
/// `first_mock`, and `mempool2` answers it as the responder.
#[tokio::test]
async fn empty_set() {
    let (_, connectivity_state, mempool1, _) = setup(0).await;
    let first_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let second_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (_first_conn, first_mock, second_conn, _) =
        create_peer_connection_mock_pair(first_node.to_peer(), second_node.to_peer()).await;
    connectivity_state.publish_event(ConnectivityEvent::PeerConnected(second_conn.into()));

    let inbound = first_mock.next_incoming_substream().await.unwrap();
    let framed = framing::canonical(inbound, MAX_FRAME_SIZE);
    let (mempool2, _) = new_mempool_with_transactions(0).await;
    MempoolPeerProtocol::new(Default::default(), framed, second_node.node_id().clone(), mempool2.clone())
        .start_responder()
        .await
        .unwrap();

    // Responder side first, then the spawned protocol's mempool.
    for mempool in [&mempool2, &mempool1] {
        assert_eq!(mempool.snapshot().await.unwrap().len(), 0);
    }
}
/// Two mempools with disjoint contents (5 and 3 transactions) must both end up holding all
/// 8 transactions after a sync.
#[tokio::test]
async fn synchronise() {
    let (_, connectivity_state, mempool1, local_txs) = setup(5).await;
    let first_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let second_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (_first_conn, first_mock, second_conn, _) =
        create_peer_connection_mock_pair(first_node.to_peer(), second_node.to_peer()).await;
    connectivity_state.publish_event(ConnectivityEvent::PeerConnected(second_conn.into()));

    let inbound = first_mock.next_incoming_substream().await.unwrap();
    let framed = framing::canonical(inbound, MAX_FRAME_SIZE);
    let (mempool2, remote_txs) = new_mempool_with_transactions(3).await;
    MempoolPeerProtocol::new(Default::default(), framed, second_node.node_id().clone(), mempool2.clone())
        .start_responder()
        .await
        .unwrap();

    // Responder side first, then the spawned protocol's mempool.
    for mempool in [&mempool2, &mempool1] {
        let snapshot = get_snapshot(mempool).await;
        assert_eq!(snapshot.len(), 8);
        assert!(local_txs.iter().all(|tx| snapshot.contains(tx)));
        assert!(remote_txs.iter().all(|tx| snapshot.contains(tx)));
    }
}
/// A transaction present in both mempools must not be duplicated.
///
/// mempool1 holds 2 transactions. mempool2 holds 1 of its own plus one of mempool1's, so both
/// sides should finish with exactly 3 distinct transactions.
#[tokio::test]
async fn duplicate_set() {
    let (_, connectivity_state, mempool1, local_txs) = setup(2).await;
    let first_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let second_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (_first_conn, first_mock, second_conn, _) =
        create_peer_connection_mock_pair(first_node.to_peer(), second_node.to_peer()).await;
    connectivity_state.publish_event(ConnectivityEvent::PeerConnected(second_conn.into()));

    let inbound = first_mock.next_incoming_substream().await.unwrap();
    let framed = framing::canonical(inbound, MAX_FRAME_SIZE);
    let (mempool2, remote_txs) = new_mempool_with_transactions(1).await;
    // Share one transaction between the two mempools.
    let shared = Arc::new(local_txs[0].clone());
    mempool2.insert(shared).await.unwrap();
    MempoolPeerProtocol::new(Default::default(), framed, second_node.node_id().clone(), mempool2.clone())
        .start_responder()
        .await
        .unwrap();

    // Responder side first, then the spawned protocol's mempool.
    for mempool in [&mempool2, &mempool1] {
        let snapshot = get_snapshot(mempool).await;
        assert_eq!(snapshot.len(), 3);
        assert!(local_txs.iter().all(|tx| snapshot.contains(tx)));
        assert!(remote_txs.iter().all(|tx| snapshot.contains(tx)));
    }
}
/// Exercises the spawned protocol as the responder.
///
/// The spawned protocol is handed one end of an in-memory socket as a `NewInboundSubstream`.
/// `mempool2` drives the other end as the initiator, and afterwards must hold the union of both
/// mempools (3 distinct transactions, one of which was shared).
#[tokio::test]
async fn responder() {
    let (notif_tx, _, _, local_txs) = setup(2).await;
    let first_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let second_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (inbound_sock, outbound_sock) = MemorySocket::new_pair();

    let notification = ProtocolNotification::new(
        MEMPOOL_SYNC_PROTOCOL.clone(),
        ProtocolEvent::NewInboundSubstream(first_node.node_id().clone(), inbound_sock),
    );
    notif_tx.send(notification).await.unwrap();

    let (mempool2, remote_txs) = new_mempool_with_transactions(1).await;
    let shared = Arc::new(local_txs[0].clone());
    mempool2.insert(shared).await.unwrap();

    let framed = framing::canonical(outbound_sock, MAX_FRAME_SIZE);
    MempoolPeerProtocol::new(Default::default(), framed, second_node.node_id().clone(), mempool2.clone())
        .start_initiator()
        .await
        .unwrap();

    let snapshot = get_snapshot(&mempool2).await;
    assert_eq!(snapshot.len(), 3);
    assert!(local_txs.iter().all(|tx| snapshot.contains(tx)));
    assert!(remote_txs.iter().all(|tx| snapshot.contains(tx)));
}
/// Hand-crafts the initiator's side of the wire protocol against the spawned protocol, which
/// receives the socket as an inbound substream.
///
/// Expected exchange, as asserted below:
/// 1. the test sends an inventory of 3 kernel signatures: 2 unknown to the protocol plus 1 it holds;
/// 2. the protocol replies with one `TransactionItem` carrying a transaction, then an empty one
///    (the stop marker);
/// 3. the protocol sends `InventoryIndexes` `[0, 1]`, listing the inventory entries it lacks.
#[tokio::test]
async fn initiator_messages() {
    let (notif_tx, _, _, local_txs) = setup(2).await;
    let first_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (inbound_sock, outbound_sock) = MemorySocket::new_pair();

    let notification = ProtocolNotification::new(
        MEMPOOL_SYNC_PROTOCOL.clone(),
        ProtocolEvent::NewInboundSubstream(first_node.node_id().clone(), inbound_sock),
    );
    notif_tx.send(notification).await.unwrap();

    let mut advertised = create_transactions(2);
    advertised.push(local_txs[0].clone());

    let mut framed = framing::canonical(outbound_sock, MAX_FRAME_SIZE);
    let inventory = proto::TransactionInventory {
        items: advertised
            .iter()
            .map(|tx| tx.first_kernel_excess_sig().unwrap().get_signature().to_vec())
            .collect(),
    };
    write_message(&mut framed, inventory).await;

    let first_item: proto::TransactionItem = read_message(&mut framed).await;
    assert!(first_item.transaction.is_some());

    let stop_marker: proto::TransactionItem = read_message(&mut framed).await;
    assert!(stop_marker.transaction.is_none());

    let missing: proto::InventoryIndexes = read_message(&mut framed).await;
    assert_eq!(missing.indexes, [0, 1]);
}
/// Hand-crafts the responder's side of the wire protocol against the spawned protocol, which
/// opens the substream after `PeerConnected` is published.
///
/// Expected exchange, as asserted below:
/// 1. the protocol sends an inventory with its single transaction;
/// 2. the test replies with an empty `TransactionItem` (nothing to send), then requests index 0;
/// 3. the protocol sends the requested transaction, then an empty stop item, then closes the stream.
#[tokio::test]
async fn responder_messages() {
    let (_, connectivity_state, _, local_txs) = setup(1).await;
    let first_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let second_node = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
    let (_first_conn, first_mock, second_conn, _) =
        create_peer_connection_mock_pair(first_node.to_peer(), second_node.to_peer()).await;
    connectivity_state.publish_event(ConnectivityEvent::PeerConnected(second_conn.into()));

    let inbound = first_mock.next_incoming_substream().await.unwrap();
    let mut framed = framing::canonical(inbound, MAX_FRAME_SIZE);

    let advertised: proto::TransactionInventory = read_message(&mut framed).await;
    assert_eq!(advertised.items.len(), 1);

    write_message(&mut framed, proto::TransactionItem::empty()).await;
    write_message(&mut framed, proto::InventoryIndexes { indexes: vec![0] }).await;

    let delivered: proto::TransactionItem = read_message(&mut framed).await;
    let received_sig = delivered
        .transaction
        .unwrap()
        .body
        .unwrap()
        .kernels
        .remove(0)
        .excess_sig
        .unwrap()
        .signature;
    let expected_sig = local_txs[0]
        .first_kernel_excess_sig()
        .unwrap()
        .get_signature()
        .to_vec();
    assert_eq!(received_sig, expected_sig);

    let stop_marker: proto::TransactionItem = read_message(&mut framed).await;
    assert!(stop_marker.transaction.is_none());

    // The protocol must close the substream once the exchange is complete.
    assert!(framed.next().await.is_none());
}
/// Returns owned copies of every transaction currently in `mempool`.
///
/// # Panics
/// Panics if taking the snapshot fails.
async fn get_snapshot(mempool: &Mempool) -> Vec<Transaction> {
    let snapshot = mempool.snapshot().await.unwrap();
    // Each entry is a shared pointer; deref through it and clone the transaction itself.
    snapshot.iter().map(|tx| (**tx).clone()).collect()
}
/// Reads the next frame from `reader` and decodes it as a protobuf message of type `T`.
///
/// # Panics
/// Panics with a distinct message in each failure case:
/// * the stream ended before a frame arrived;
/// * the transport yielded an I/O error;
/// * the frame could not be decoded as `T`.
async fn read_message<S, T>(reader: &mut S) -> T
where
    S: Stream<Item = io::Result<BytesMut>> + Unpin,
    T: prost::Message + Default,
{
    let frame = reader
        .next()
        .await
        .expect("stream closed before the expected message was received")
        .expect("I/O error while reading a message frame");
    T::decode(&mut frame.freeze()).expect("failed to decode message frame")
}
/// Encodes `message` as protobuf and sends it on `writer` as a single frame.
///
/// # Panics
/// Panics if the sink rejects the frame.
async fn write_message<S, T>(writer: &mut S, message: T)
where
    S: Sink<Bytes> + Unpin,
    S::Error: fmt::Debug,
    T: prost::Message,
{
    let frame: Bytes = message.to_encoded_bytes().into();
    writer.send(frame).await.unwrap();
}