#[cfg(all(feature = "tokio", feature = "tls"))]
mod tokio_tests {
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::oneshot,
task
};
use protwrap::tokio::{
client::{Connector, TlsTcpConnInfo},
server::listener::{
async_trait, Acceptor, KillSwitch, Listener, SockAddr,
TlsTcpListenerInfo
},
ServerStream
};
/// Test acceptor state handed to `Listener::run`.
struct MyServer {
// One-shot sender used by `bound()` to report the listener's port to the
// client task.  Wrapped in an `Option` so `bound()` can `take()` it, since
// sending on a one-shot channel consumes the sender.
tx_port: Option<oneshot::Sender<u16>>,
// Kill switch that is cloned into the connection handler, which triggers
// it once the request/reply exchange has completed.
ks: KillSwitch
}
/// Listener callbacks for the test server.
#[async_trait]
impl Acceptor for MyServer {
  /// Called with the address the listener was bound to; forwards the port
  /// number to whoever holds the receiving end of the one-shot channel.
  ///
  /// # Panics
  /// Panics if the sender has already been taken, or if the receiving end
  /// has been dropped.  (`unwrap_std()` presumably panics as well if the
  /// address is not a std socket address -- confirm against protwrap.)
  async fn bound(&mut self, _listener: &Listener, sa: SockAddr) {
    let bound_port = sa.unwrap_std().port();
    let sender = self.tx_port.take().expect("Channel end-point missing");
    sender.send(bound_port).unwrap();
  }

  /// Nothing to clean up when the listener is unbound.
  async fn unbound(&mut self, _listener: &Listener) {}

  /// Runs the connection handler on a task of its own and waits for that
  /// task to finish, so this callback does not return until the exchange
  /// is done.
  ///
  /// # Panics
  /// Panics if the handler task did not run to completion (for instance
  /// because one of its assertions failed).
  async fn connected(&mut self, _sa: SockAddr, strm: ServerStream) {
    let killswitch = self.ks.clone();
    let handler = tokio::task::spawn(Self::handle_connection(strm, killswitch));
    handler.await.unwrap();
  }
}
impl MyServer {
  /// Server side of the exchange: expects the five bytes `hello` from the
  /// peer, answers with `world`, and then triggers the kill switch.
  ///
  /// # Panics
  /// Panics on any I/O error or if the received bytes are not `hello`; in
  /// that case the kill switch is never triggered by this function.
  async fn handle_connection(mut strm: ServerStream, ks: KillSwitch) {
    let mut request = [0u8; 5];
    let nread = strm.read_exact(&mut request).await.unwrap();
    assert_eq!(nread, 5);
    assert_eq!(&request, b"hello");

    strm.write_all(b"world").await.unwrap();
    strm.flush().await.unwrap();

    // The reply has been written and flushed; signal shutdown.
    ks.trigger();
  }
}
/// End-to-end TLS-over-TCP round trip: a listener is started on an
/// OS-assigned port, a client connects to it, sends `hello` and expects
/// `world` back, after which the server side triggers the kill switch.
#[tokio::test]
async fn client_server() {
  // Generate the CA and server identities off the async worker thread.
  let (ca, srv) = task::spawn_blocking(init_pki).await.unwrap();

  let (srv_key_pem, srv_cert_pem) = srv.to_pem().unwrap();
  let tlsinfo = TlsTcpListenerInfo {
    // Port 0 lets the OS choose a free port; the actual port is reported
    // back through `MyServer::bound()`.
    addr: "127.0.0.1:0".into(),
    srv_key_pem,
    srv_cert_pem
  };
  let listener = Listener::from(tlsinfo);

  let ks = KillSwitch::new();

  // Channel used by the acceptor to tell the client which port was bound.
  let (tx, rx) = oneshot::channel();
  let acceptor = MyServer {
    tx_port: Some(tx),
    ks: ks.clone()
  };

  let killswitch = ks.clone();
  let jh_server = tokio::task::spawn(async move {
    listener.run(killswitch, acceptor).await.unwrap();
  });

  // Only the certificate half of the CA identity is handed to the client.
  let (_, ca_cert_pem) = ca.to_pem().unwrap();
  let jh_client = tokio::task::spawn(async move {
    // Wait for the server to report its port before connecting.
    let port = rx.await.unwrap();
    let addr = format!("127.0.0.1:{port}");
    let conninfo = TlsTcpConnInfo {
      addr,
      // Same name that `init_pki()` passes to `quickcert::mk_server()`.
      host: "localhost".into(),
      ca_cert_pem
    };
    let c = Connector::from(conninfo);
    let mut strm = c.connect().await.unwrap();

    // `write()` is allowed to perform a short write, so use `write_all()`
    // to guarantee the entire request is sent (mirrors the server side).
    strm.write_all(b"hello").await.unwrap();
    strm.flush().await.unwrap();

    let mut buf = [0u8; 5];
    let n = strm.read_exact(&mut buf[..]).await.unwrap();
    assert_eq!(n, 5);
    assert_eq!(buf, *b"world");
  });

  // NOTE(review): if the client or the connection handler panics before
  // the kill switch is triggered, this wait presumably never completes and
  // the test hangs instead of failing -- confirm `KillSwitch::wait()`
  // semantics and consider wrapping it in a timeout.
  ks.wait().await;

  // Surface any panic from either task as a test failure.
  jh_client.await.unwrap();
  jh_server.await.unwrap();
}
fn init_pki() -> (quickcert::PkiIdent, quickcert::PkiIdent) {
let ca = quickcert::mk_ca().unwrap();
let srv = quickcert::mk_server(&ca, "server", ["localhost"]).unwrap();
(ca, srv)
}
}