use std::{
str::FromStr,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use tokio::{
runtime::Handle,
sync::{Mutex, mpsc},
};
use tokio_stream::StreamExt;
use crate::{
connection_manager::{
ConnectionDirection,
NegotiatedSubstream,
PeerConnection,
PeerConnectionError,
PeerConnectionRequest,
},
multiaddr::Multiaddr,
multiplexing,
multiplexing::{IncomingSubstreams, Substream, Yamux, YamuxControlError},
peer_manager::{NodeId, Peer, PeerFeatures},
test_utils::{node_identity::build_node_identity, transport},
utils::atomic_ref_counter::AtomicRefCounter,
};
/// Source of connection ids for the mock connections built by
/// `create_peer_connection_mock_pair`; each call to `fetch_add` hands out the next id, so
/// mock connections created there get distinct ids. `create_dummy_peer_connection` does not
/// use this counter and always assigns id `1`.
static ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Builds a `PeerConnection` for `node_id` that has no multiplexer or mock task behind it.
///
/// The connection always has id `1`, `PeerFeatures::COMMUNICATION_NODE`, the address
/// `/ip4/23.23.23.23/tcp/80`, an inbound direction and a brand-new substream counter.
/// The receiving end of its request channel (capacity 1) is returned alongside it. Nothing
/// services that receiver, so the caller decides whether to read or ignore the requests.
pub fn create_dummy_peer_connection(
    node_id: NodeId,
) -> (PeerConnection, mpsc::Receiver<PeerConnectionRequest>) {
    let address = Multiaddr::from_str("/ip4/23.23.23.23/tcp/80").unwrap();
    let (request_tx, request_rx) = mpsc::channel(1);

    let connection = PeerConnection::new(
        1,
        request_tx,
        node_id,
        PeerFeatures::COMMUNICATION_NODE,
        address,
        ConnectionDirection::Inbound,
        AtomicRefCounter::new(),
    );

    (connection, request_rx)
}
pub async fn create_peer_connection_mock_pair(
peer1: Peer,
peer2: Peer,
) -> (
PeerConnection,
PeerConnectionMockState,
PeerConnection,
PeerConnectionMockState,
) {
let rt_handle = Handle::current();
let (tx1, rx1) = mpsc::channel(1);
let (tx2, rx2) = mpsc::channel(1);
let (listen_addr, muxer_in, muxer_out) = transport::build_multiplexed_connections().await;
let mock = PeerConnectionMock::new(rx1, muxer_in);
let mock_state_in = mock.get_shared_state();
rt_handle.spawn(mock.run());
let mock = PeerConnectionMock::new(rx2, muxer_out);
let mock_state_out = mock.get_shared_state();
rt_handle.spawn(mock.run());
(
PeerConnection::new(
ID_COUNTER.fetch_add(1, Ordering::SeqCst),
tx1,
peer2.node_id,
peer2.features,
listen_addr.clone(),
ConnectionDirection::Inbound,
mock_state_in.substream_counter(),
),
mock_state_in,
PeerConnection::new(
ID_COUNTER.fetch_add(1, Ordering::SeqCst),
tx2,
peer1.node_id,
peer1.features,
listen_addr,
ConnectionDirection::Outbound,
mock_state_out.substream_counter(),
),
mock_state_out,
)
}
/// Convenience wrapper around `create_peer_connection_mock_pair`.
///
/// It uses two peers derived from newly built `COMMUNICATION_NODE` node identities, passed in
/// build order as `peer1` and `peer2`.
pub async fn new_peer_connection_mock_pair() -> (
    PeerConnection,
    PeerConnectionMockState,
    PeerConnection,
    PeerConnectionMockState,
) {
    let make_peer = || build_node_identity(PeerFeatures::COMMUNICATION_NODE).to_peer();
    create_peer_connection_mock_pair(make_peer(), make_peer()).await
}
/// Cloneable state of a mocked peer connection, shared between a `PeerConnectionMock` and tests.
///
/// `call_count`, `mux_control` and `mux_incoming` are behind `Arc`s, so every clone observes and
/// drives the same counter and multiplexer handles. `substream_counter` is presumably shared
/// across clones as well; the same counter is handed to `PeerConnection::new`. Confirm this
/// against `AtomicRefCounter`.
#[derive(Clone)]
pub struct PeerConnectionMockState {
// Number of requests handled; incremented once per request by `PeerConnectionMock`.
call_count: Arc<AtomicUsize>,
// Yamux control handle, used by `open_substream` and `disconnect`.
mux_control: Arc<Mutex<multiplexing::Control>>,
// Incoming substreams (from `Yamux::into_incoming`), consumed by `next_incoming_substream`.
mux_incoming: Arc<Mutex<IncomingSubstreams>>,
// Counter taken from the muxer; read by `num_open_substreams`.
substream_counter: AtomicRefCounter,
}
impl PeerConnectionMockState {
    /// Builds the shared state from `muxer`.
    ///
    /// The muxer is split into its control handle, its substream counter and its stream of
    /// incoming substreams. The call count starts at zero.
    pub fn new(muxer: Yamux) -> Self {
        let control = muxer.get_yamux_control();
        let substream_counter = muxer.substream_counter();
        Self {
            call_count: Arc::new(AtomicUsize::new(0)),
            mux_control: Arc::new(Mutex::new(control)),
            mux_incoming: Arc::new(Mutex::new(muxer.into_incoming())),
            substream_counter,
        }
    }

    /// Increments the shared request counter by one.
    pub fn inc_call_count(&self) {
        self.call_count.fetch_add(1, Ordering::SeqCst);
    }

    /// Returns how many times `inc_call_count` has been called, across all clones.
    pub fn call_count(&self) -> usize {
        self.call_count.load(Ordering::SeqCst)
    }

    /// Opens a new substream through the yamux control handle.
    ///
    /// The control lock is held until the open completes.
    ///
    /// # Errors
    /// Any yamux control error is converted into a `PeerConnectionError`.
    pub async fn open_substream(&self) -> Result<Substream, PeerConnectionError> {
        self.mux_control.lock().await.open_stream().await.map_err(Into::into)
    }

    /// Returns a clone of the muxer's substream counter.
    pub fn substream_counter(&self) -> AtomicRefCounter {
        self.substream_counter.clone()
    }

    /// Returns the current value of the substream counter.
    pub fn num_open_substreams(&self) -> usize {
        self.substream_counter.get()
    }

    /// Waits for the next incoming substream. Returns `None` once the incoming stream has ended.
    pub async fn next_incoming_substream(&self) -> Option<Substream> {
        self.mux_incoming.lock().await.next().await
    }

    /// Closes the yamux connection.
    ///
    /// # Errors
    /// Closing a connection that is already closed (`YamuxControlError::ConnectionClosed`) is
    /// treated as success. Any other yamux control error is converted into a
    /// `PeerConnectionError` and returned, rather than being silently discarded.
    pub async fn disconnect(&self) -> Result<(), PeerConnectionError> {
        match self.mux_control.lock().await.close().await {
            Ok(_) | Err(YamuxControlError::ConnectionClosed) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }
}
/// Stand-in for the task that serves a `PeerConnection`.
///
/// It receives `PeerConnectionRequest`s and services them through the yamux muxer held in its
/// `PeerConnectionMockState`.
pub struct PeerConnectionMock {
// Requests sent by the `PeerConnection` handle paired with this mock.
receiver: mpsc::Receiver<PeerConnectionRequest>,
// State shared with tests through `get_shared_state`.
state: PeerConnectionMockState,
}
impl PeerConnectionMock {
    /// Creates a mock that serves requests arriving on `receiver` using `muxer`.
    pub fn new(receiver: mpsc::Receiver<PeerConnectionRequest>, muxer: Yamux) -> Self {
        Self {
            receiver,
            state: PeerConnectionMockState::new(muxer),
        }
    }

    /// Returns a clone of this mock's state. The clone shares the call counter and muxer handles.
    pub fn get_shared_state(&self) -> PeerConnectionMockState {
        self.state.clone()
    }

    /// Handles requests until the request channel is closed and drained.
    ///
    /// The channel closes when every sender is dropped or after a `Disconnect` request. Requests
    /// already buffered at that point are still handled.
    pub async fn run(mut self) {
        while let Some(req) = self.receiver.recv().await {
            self.handle_request(req).await;
        }
    }

    /// Handles a single request, incrementing the call count first.
    ///
    /// The result is sent back on the request's oneshot reply channel. If the requester has
    /// already dropped its receiver, the reply is discarded. Panicking instead would abort
    /// `run` and stop this mock from serving any later request on the connection.
    async fn handle_request(&mut self, req: PeerConnectionRequest) {
        use PeerConnectionRequest::{Disconnect, OpenSubstream};
        self.state.inc_call_count();
        match req {
            OpenSubstream { protocol_id, reply_tx } => {
                let result = self.state.open_substream().await.map(|stream| NegotiatedSubstream {
                    protocol: protocol_id,
                    stream,
                });
                // The requester may have stopped waiting; there is nobody left to inform.
                let _ = reply_tx.send(result);
            },
            Disconnect(_, reply_tx, _minimized, _requester) => {
                // Reject new requests; requests already queued are still drained by `run`.
                self.receiver.close();
                let _ = reply_tx.send(self.state.disconnect().await);
            },
        }
    }
}