use std::fmt;
use std::marker::Unpin;
use crate::handshake::handshake;
use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite};
use crate::TlsStream;
/// Asynchronous wrapper around a [`native_tls::TlsAcceptor`].
///
/// Instances are created with [`TlsAcceptor::new`] (from a PKCS #12 archive) or converted
/// from an existing `native_tls::TlsAcceptor` via `From`. Connections are accepted with
/// [`TlsAcceptor::accept`] on streams implementing the crate's `AsyncRead` + `AsyncWrite`.
#[derive(Clone)]
pub struct TlsAcceptor(native_tls::TlsAcceptor);
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("NativeTls({})", 0)]
NativeTls(#[from] native_tls::Error),
#[error("Io({})", 0)]
Io(#[from] std::io::Error),
}
impl TlsAcceptor {
    /// Builds an acceptor from a PKCS #12 archive.
    ///
    /// The whole of `file` is read into memory and decoded with `password`. The resulting
    /// identity is then used to construct a `native_tls::TlsAcceptor`.
    ///
    /// # Errors
    ///
    /// * [`Error::Io`] if reading `file` fails.
    /// * [`Error::NativeTls`] if the archive cannot be decoded or the acceptor cannot be built.
    pub async fn new<R, S>(mut file: R, password: S) -> Result<Self, Error>
    where
        R: AsyncRead + Unpin,
        S: AsRef<str>,
    {
        let mut pkcs12_bytes = Vec::new();
        file.read_to_end(&mut pkcs12_bytes).await?;

        let identity = native_tls::Identity::from_pkcs12(&pkcs12_bytes, password.as_ref())?;
        let native_acceptor = native_tls::TlsAcceptor::new(identity)?;
        Ok(Self(native_acceptor))
    }

    /// Accepts a TLS connection on `stream`.
    ///
    /// The crate's `handshake` helper drives `native_tls::TlsAcceptor::accept` over the stream.
    ///
    /// # Errors
    ///
    /// Returns the `native_tls::Error` produced if the handshake fails.
    pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, native_tls::Error>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let native_acceptor = &self.0;
        let tls_stream = handshake(move |s| native_acceptor.accept(s), stream).await?;
        Ok(tls_stream)
    }
}
/// Prints only the type name, `TlsAcceptor`; the wrapped native acceptor is not shown.
// NOTE(review): presumably hand-written because `native_tls::TlsAcceptor` has no `Debug`
// impl to derive from — confirm.
impl fmt::Debug for TlsAcceptor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as an empty `debug_struct("TlsAcceptor").finish()`.
        f.write_str("TlsAcceptor")
    }
}
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor {
TlsAcceptor(inner)
}
}
#[cfg(all(test, feature = "runtime-async-std"))]
mod tests {
    use super::*;
    use crate::runtime::AsyncWriteExt;
    use crate::TlsConnector;
    use async_std::fs::File;
    use async_std::net::{TcpListener, TcpStream};
    use async_std::stream::StreamExt;

    /// End-to-end round trip through a TLS connection.
    ///
    /// A background task accepts connections using `tests/identity.pfx` and writes `b"hello"`.
    /// The client connects (accepting the self-signed certificate), reads to EOF, and
    /// expects exactly that payload.
    #[async_std::test]
    async fn test_acceptor() {
        let key = File::open("tests/identity.pfx").await.unwrap();
        let acceptor = TlsAcceptor::new(key, "hello").await.unwrap();

        // Port 0 lets the OS pick a free port. A fixed port (previously 8443) fails when the
        // port is already in use or when tests run concurrently.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        async_std::task::spawn(async move {
            let mut incoming = listener.incoming();
            while let Some(stream) = incoming.next().await {
                let acceptor = acceptor.clone();
                let stream = stream.unwrap();
                async_std::task::spawn(async move {
                    let mut stream = acceptor.accept(stream).await.unwrap();
                    stream.write_all(b"hello").await.unwrap();
                });
            }
        });

        // Connect to the exact address bound above. The old literal "127.0.01" was a typo that
        // only resolved on platforms accepting the three-part inet_aton address form.
        let stream = TcpStream::connect(addr).await.unwrap();
        let connector = TlsConnector::new().danger_accept_invalid_certs(true);
        let mut stream = connector.connect("127.0.0.1", stream).await.unwrap();
        let mut res = Vec::new();
        stream.read_to_end(&mut res).await.unwrap();
        assert_eq!(res, b"hello");
    }
}