use std::fmt;
use std::marker::Unpin;
use crate::{
async_io::{
handshake::handshake,
runtime::{AsyncRead, AsyncReadExt, AsyncWrite},
TlsStream,
},
sync_io, Identity,
};
/// Asynchronous, server-side TLS acceptor.
///
/// A thin newtype around [`sync_io::TlsAcceptor`]: the wrapped value does the
/// TLS work, and this type adapts it to async streams via the crate's
/// `handshake` helper.
#[derive(Clone)]
pub struct TlsAcceptor(sync_io::TlsAcceptor);

impl TlsAcceptor {
    /// Builds an acceptor from a PKCS#12 identity read from `file`.
    ///
    /// The entire contents of `file` are read into memory and then decoded
    /// with `password` via [`Identity::from_pkcs12`].
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following fails:
    ///
    /// - reading `file`;
    /// - decoding the bytes as a PKCS#12 identity with `password`;
    /// - building the underlying synchronous acceptor from that identity.
    pub async fn new<R, S>(mut file: R, password: S) -> crate::Result<Self>
    where
        R: AsyncRead + Unpin,
        S: AsRef<str>,
    {
        let mut pkcs12_bytes = Vec::new();
        file.read_to_end(&mut pkcs12_bytes).await?;

        let pass: &str = password.as_ref();
        let identity = Identity::from_pkcs12(&pkcs12_bytes, pass)?;
        let inner = sync_io::TlsAcceptor::new(identity)?;
        Ok(Self(inner))
    }

    /// Runs the server side of a TLS handshake over `stream`.
    ///
    /// The crate's async `handshake` helper drives the handshake. It is given
    /// a closure that calls the wrapped synchronous acceptor's `accept`. On
    /// success, the returned [`TlsStream`] wraps `stream`.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the handshake.
    pub async fn accept<S>(&self, stream: S) -> crate::Result<TlsStream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let inner = &self.0;
        let tls = handshake(move |s| inner.accept(s), stream).await?;
        Ok(tls)
    }
}

impl fmt::Debug for TlsAcceptor {
    /// Writes only the type name; the wrapped acceptor is deliberately not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("TlsAcceptor")
    }
}

impl From<sync_io::TlsAcceptor> for TlsAcceptor {
    /// Wraps an already-configured synchronous acceptor.
    fn from(acceptor: sync_io::TlsAcceptor) -> Self {
        Self(acceptor)
    }
}
#[cfg(all(test, feature = "io-async-std"))]
mod tests {
    use super::*;
    use crate::async_io::{runtime::AsyncWriteExt, TlsConnector};
    use async_std::{
        fs::File,
        net::{TcpListener, TcpStream},
        stream::StreamExt,
    };

    /// End-to-end check: an acceptor built from `tests/identity.pfx` accepts
    /// a TLS connection and sends `b"hello"`, which the client reads back.
    #[async_std::test]
    async fn test_acceptor() {
        let key = File::open("tests/identity.pfx").await.unwrap();
        let acceptor = TlsAcceptor::new(key, "hello").await.unwrap();

        // Let the OS pick a free port instead of hard-coding one, so the test
        // cannot collide with another process (or a parallel test run), then
        // connect to exactly the address that was bound.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        async_std::task::spawn(async move {
            let mut incoming = listener.incoming();
            while let Some(stream) = incoming.next().await {
                let acceptor = acceptor.clone();
                let stream = stream.unwrap();
                async_std::task::spawn(async move {
                    let mut stream = acceptor.accept(stream).await.unwrap();
                    stream.write_all(b"hello").await.unwrap();
                });
            }
        });

        let stream = TcpStream::connect(addr).await.unwrap();
        let connector = TlsConnector::new().danger_accept_invalid_certs(true);
        let mut stream = connector.connect("127.0.0.1", stream).await.unwrap();
        let mut res = Vec::new();
        stream.read_to_end(&mut res).await.unwrap();
        assert_eq!(res, b"hello");
    }
}