use std::{sync::mpsc::channel, thread};
use tentacle::{
async_trait,
builder::{MetaBuilder, ServiceBuilder},
context::{ProtocolContext, ProtocolContextMutRef, ServiceContext},
multiaddr::Multiaddr,
secio::SecioKeyPair,
service::{ProtocolHandle, ProtocolMeta, Service, ServiceEvent, TargetProtocol},
traits::{ServiceHandle, ServiceProtocol, SessionProtocol},
};
/// Session-level protocol handle that, in `connected`, disconnects the
/// session if it is inbound; outbound sessions are left untouched.
#[derive(Clone)]
struct PHandle;
#[async_trait]
impl SessionProtocol for PHandle {
    async fn connected(&mut self, context: ProtocolContextMutRef<'_>, _version: &str) {
        // Only the accepting side hangs up.
        if !context.session.ty.is_inbound() {
            return;
        }
        let session_id = context.session.id;
        // The outcome of the disconnect request is intentionally not checked.
        let _outcome = context.disconnect(session_id).await;
    }
}
/// Service-level protocol handle with no behavior; in this file it is
/// registered as protocol 0, the protocol the services dial with.
struct Dummy;
#[async_trait]
impl ServiceProtocol for Dummy {
    async fn init(&mut self, _ctx: &mut ProtocolContext) {
        // Nothing to initialize.
    }
}
/// Service handle that opens protocol 1 on every outbound session. It counts
/// outbound session closures and redials the last-seen address until a fixed
/// number of closures is reached, then shuts the service down.
struct SHandle {
    /// Number of outbound sessions that have closed so far.
    count: usize,
    /// Address of the most recently opened session; used as the redial target.
    addr: Option<Multiaddr>,
}
#[async_trait]
impl ServiceHandle for SHandle {
    async fn handle_event(&mut self, control: &mut ServiceContext, event: ServiceEvent) {
        // Outbound closures after which the service shuts itself down.
        const CLOSE_LIMIT: usize = 10;

        match event {
            ServiceEvent::SessionOpen { session_context } => {
                self.addr = Some(session_context.address.clone());
                // Only the dialing side opens protocol 1. `PHandle` on the
                // listening side then disconnects the inbound session.
                if session_context.ty.is_outbound() {
                    control
                        .open_protocol(session_context.id, 1.into())
                        .await
                        .unwrap();
                }
            }
            ServiceEvent::SessionClose { session_context } => {
                // Inbound closures (listener side) are not counted.
                if !session_context.ty.is_outbound() {
                    return;
                }
                self.count += 1;
                if self.count >= CLOSE_LIMIT {
                    control.shutdown().await.unwrap();
                } else {
                    let addr = self
                        .addr
                        .clone()
                        .expect("a SessionOpen must precede every SessionClose");
                    // NOTE(review): the dial result is not checked; if the dial
                    // fails, no further SessionClose arrives and the test hangs.
                    let _res = control
                        .dial(addr, TargetProtocol::Single(0.into()))
                        .await;
                }
            }
            _ => {}
        }
    }
}
/// Builds a service with `forever(true)` and registers every meta from
/// `metas` in iteration order.
///
/// When `secio` is true, a freshly generated secp256k1 key pair is installed
/// as the handshake type; otherwise the builder is used without one.
pub fn create<F>(
    secio: bool,
    metas: impl Iterator<Item = ProtocolMeta>,
    shandle: F,
) -> Service<F, SecioKeyPair>
where
    F: ServiceHandle + Unpin + 'static,
{
    let builder = metas.fold(ServiceBuilder::default().forever(true), |acc, meta| {
        acc.insert_protocol(meta)
    });
    if !secio {
        return builder.build(shandle);
    }
    let key_pair = SecioKeyPair::secp256k1_generated();
    builder.handshake_type(key_pair.into()).build(shandle)
}
fn test_session_handle_open(secio: bool) {
let p_handle_1 = PHandle;
let s_handle_1 = SHandle {
count: 0,
addr: None,
};
let p_handle_2 = PHandle;
let s_handle_2 = SHandle {
count: 0,
addr: None,
};
let meta_dummy_1 = MetaBuilder::new()
.id(0.into())
.service_handle(move || {
let handle = Box::new(Dummy);
ProtocolHandle::Callback(handle)
})
.build();
let meta_dummy_2 = MetaBuilder::new()
.id(0.into())
.service_handle(move || {
let handle = Box::new(Dummy);
ProtocolHandle::Callback(handle)
})
.build();
let meta_1 = MetaBuilder::new()
.id(1.into())
.session_handle(move || {
let handle = Box::new(p_handle_1.clone());
ProtocolHandle::Callback(handle)
})
.build();
let meta_2 = MetaBuilder::new()
.id(1.into())
.session_handle(move || {
let handle = Box::new(p_handle_2.clone());
ProtocolHandle::Callback(handle)
})
.build();
let mut service_1 = create(secio, vec![meta_dummy_1, meta_1].into_iter(), s_handle_1);
let mut service_2 = create(secio, vec![meta_dummy_2, meta_2].into_iter(), s_handle_2);
let (addr_sender, addr_receiver) = channel::<Multiaddr>();
thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
let listen_addr = service_2
.listen("/ip4/127.0.0.1/tcp/0".parse().unwrap())
.await
.unwrap();
addr_sender.send(listen_addr).unwrap();
service_2.run().await
});
});
let listen_addr = addr_receiver.recv().unwrap();
let handle = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
service_1
.dial(listen_addr, TargetProtocol::Single(0.into()))
.await
.unwrap();
service_1.run().await
});
});
handle.join().unwrap();
}
/// Runs the reconnect scenario with the secio handshake enabled.
#[test]
fn test_session_handle_with_secio() {
    let use_secio = true;
    test_session_handle_open(use_secio);
}
/// Runs the reconnect scenario without a secio handshake.
#[test]
fn test_session_handle_with_no_secio() {
    let use_secio = false;
    test_session_handle_open(use_secio);
}