#![cfg(feature = "unstable")]
use futures::StreamExt;
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
mpsc::channel,
},
thread,
time::Duration,
};
use tentacle::{
SubstreamReadPart,
builder::{MetaBuilder, ServiceBuilder},
context::SessionContext,
multiaddr::Multiaddr,
secio::SecioKeyPair,
service::{ProtocolMeta, Service, ServiceAsyncControl, TargetProtocol},
traits::{ProtocolSpawn, ServiceHandle},
};
fn create_service<F>(secio: bool, meta: ProtocolMeta, shandle: F) -> Service<F, SecioKeyPair>
where
F: ServiceHandle + Unpin + 'static,
{
let builder = ServiceBuilder::default().insert_protocol(meta);
if secio {
builder
.handshake_type(SecioKeyPair::secp256k1_generated().into())
.build(shandle)
} else {
builder.build(shandle)
}
}
fn run_pair(
secio: bool,
listener_meta: ProtocolMeta,
dialer_meta: ProtocolMeta,
timeout: Duration,
) {
let (addr_sender, addr_receiver) = channel::<Multiaddr>();
let listener_thread = thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let mut service = create_service(secio, listener_meta, ());
rt.block_on(async move {
let listen_addr = service
.listen("/ip4/127.0.0.1/tcp/0".parse().unwrap())
.await
.unwrap();
addr_sender.send(listen_addr).unwrap();
let _ignore = tokio::time::timeout(timeout, service.run()).await;
});
});
let dialer_thread = thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let mut service = create_service(secio, dialer_meta, ());
rt.block_on(async move {
let listen_addr = addr_receiver.recv().unwrap();
service
.dial(listen_addr, TargetProtocol::Single(1.into()))
.await
.unwrap();
let _ignore = tokio::time::timeout(timeout, service.run()).await;
});
});
listener_thread.join().unwrap();
dialer_thread.join().unwrap();
}
#[derive(Clone)]
struct FirstMessageProbe {
received: Arc<AtomicUsize>,
}
impl ProtocolSpawn for FirstMessageProbe {
fn spawn(
&self,
context: Arc<SessionContext>,
control: &ServiceAsyncControl,
mut read_part: SubstreamReadPart,
) {
let session_id = context.id;
let proto_id = read_part.protocol_id();
let send_control = control.clone();
tokio::spawn(async move {
let _ignore = send_control
.send_message_to(session_id, proto_id, b"init".to_vec().into())
.await;
});
let received = self.received.clone();
let shutdown_control = control.clone();
tokio::spawn(async move {
let _ignore = tokio::time::timeout(Duration::from_secs(10), async {
while let Some(Ok(data)) = read_part.next().await {
if data.as_ref() == b"init" {
received.fetch_add(1, Ordering::SeqCst);
break;
}
}
})
.await;
tokio::time::sleep(Duration::from_secs(1)).await;
let _ignore = shutdown_control.shutdown().await;
});
}
}
fn run_first_message_test(secio: bool, iterations: usize) {
for _i in 0..iterations {
let listener_received = Arc::new(AtomicUsize::new(0));
let dialer_received = Arc::new(AtomicUsize::new(0));
let meta_listener = MetaBuilder::new()
.id(1.into())
.protocol_spawn(FirstMessageProbe {
received: listener_received.clone(),
})
.build();
let meta_dialer = MetaBuilder::new()
.id(1.into())
.protocol_spawn(FirstMessageProbe {
received: dialer_received.clone(),
})
.build();
run_pair(secio, meta_listener, meta_dialer, Duration::from_secs(15));
let lr = listener_received.load(Ordering::SeqCst);
let dr = dialer_received.load(Ordering::SeqCst);
assert_eq!(lr, 1, "listener did not receive first message");
assert_eq!(dr, 1, "dialer did not receive first message");
}
}
#[test]
fn test_spawn_first_message_with_no_secio() {
run_first_message_test(false, 3);
}
#[test]
fn test_spawn_first_message_with_secio() {
run_first_message_test(true, 3);
}
const MULTI_MSG_COUNT: usize = 100;
#[derive(Clone)]
struct MultiMessageProbe {
received: Arc<AtomicUsize>,
total_done: Arc<AtomicUsize>,
}
impl ProtocolSpawn for MultiMessageProbe {
fn spawn(
&self,
context: Arc<SessionContext>,
control: &ServiceAsyncControl,
mut read_part: SubstreamReadPart,
) {
let session_id = context.id;
let proto_id = read_part.protocol_id();
let send_control = control.clone();
tokio::spawn(async move {
for i in 0..MULTI_MSG_COUNT {
let msg = format!("msg-{i}");
if let Err(_e) = send_control
.send_message_to(session_id, proto_id, msg.into_bytes().into())
.await
{
break;
}
}
});
let received = self.received.clone();
let total_done = self.total_done.clone();
let shutdown_control = control.clone();
tokio::spawn(async move {
let _ignore = tokio::time::timeout(Duration::from_secs(30), async {
let mut count = 0usize;
while let Some(Ok(_data)) = read_part.next().await {
count += 1;
received.fetch_add(1, Ordering::SeqCst);
if count >= MULTI_MSG_COUNT {
break;
}
}
})
.await;
total_done.fetch_add(1, Ordering::SeqCst);
for _ in 0..200 {
if total_done.load(Ordering::SeqCst) >= 2 {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
let _ignore = shutdown_control.shutdown().await;
});
}
}
fn run_multi_message_test(secio: bool, iterations: usize) {
for _i in 0..iterations {
let listener_received = Arc::new(AtomicUsize::new(0));
let dialer_received = Arc::new(AtomicUsize::new(0));
let total_done = Arc::new(AtomicUsize::new(0));
let meta_listener = MetaBuilder::new()
.id(1.into())
.protocol_spawn(MultiMessageProbe {
received: listener_received.clone(),
total_done: total_done.clone(),
})
.build();
let meta_dialer = MetaBuilder::new()
.id(1.into())
.protocol_spawn(MultiMessageProbe {
received: dialer_received.clone(),
total_done: total_done.clone(),
})
.build();
run_pair(secio, meta_listener, meta_dialer, Duration::from_secs(60));
let lr = listener_received.load(Ordering::SeqCst);
let dr = dialer_received.load(Ordering::SeqCst);
assert_eq!(
lr, MULTI_MSG_COUNT,
"listener received {lr}/{MULTI_MSG_COUNT} messages",
);
assert_eq!(
dr, MULTI_MSG_COUNT,
"dialer received {dr}/{MULTI_MSG_COUNT} messages",
);
}
}
#[test]
fn test_spawn_multi_message_with_no_secio() {
run_multi_message_test(false, 3);
}
#[test]
fn test_spawn_multi_message_with_secio() {
run_multi_message_test(true, 3);
}