async_native_tls/
acceptor.rs1use std::fmt;
2use std::marker::Unpin;
3
4use crate::handshake::handshake;
5use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite};
6use crate::TlsStream;
7
8#[derive(Clone)]
41pub struct TlsAcceptor(native_tls::TlsAcceptor);
42
43#[derive(thiserror::Error, Debug)]
45pub enum Error {
46 #[error("NativeTls({})", 0)]
48 NativeTls(#[from] native_tls::Error),
49 #[error("Io({})", 0)]
51 Io(#[from] std::io::Error),
52}
53
54impl TlsAcceptor {
55 pub async fn new<R, S>(mut file: R, password: S) -> Result<Self, Error>
57 where
58 R: AsyncRead + Unpin,
59 S: AsRef<str>,
60 {
61 let mut identity = vec![];
62 file.read_to_end(&mut identity).await?;
63
64 let identity = native_tls::Identity::from_pkcs12(&identity, password.as_ref())?;
65 Ok(TlsAcceptor(native_tls::TlsAcceptor::new(identity)?))
66 }
67
68 pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, native_tls::Error>
79 where
80 S: AsyncRead + AsyncWrite + Unpin,
81 {
82 let stream = handshake(move |s| self.0.accept(s), stream).await?;
83 Ok(stream)
84 }
85}
86
87impl fmt::Debug for TlsAcceptor {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("TlsAcceptor").finish()
90 }
91}
92
93impl From<native_tls::TlsAcceptor> for TlsAcceptor {
94 fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor {
95 TlsAcceptor(inner)
96 }
97}
98
#[cfg(all(test, feature = "runtime-smol"))]
mod tests {
    use super::*;
    use crate::runtime::AsyncWriteExt;
    use crate::TlsConnector;
    use smol::fs::File;
    use smol::net::{TcpListener, TcpStream};
    use smol::stream::StreamExt;

    /// Round-trips a short payload through a `TlsAcceptor`-backed server and
    /// a `TlsConnector` client over loopback TCP.
    #[test]
    fn test_acceptor() {
        smol::block_on(async {
            let key = File::open("tests/identity.pfx").await.unwrap();
            let acceptor = TlsAcceptor::new(key, "hello").await.unwrap();

            // Port 0 lets the OS pick a free port, so the test does not fail
            // when a fixed port is already in use.
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            smol::spawn(async move {
                let mut incoming = listener.incoming();

                while let Some(stream) = incoming.next().await {
                    let acceptor = acceptor.clone();
                    let stream = stream.unwrap();
                    smol::spawn(async move {
                        let mut stream = acceptor.accept(stream).await.unwrap();
                        stream.write_all(b"hello").await.unwrap();
                    })
                    .detach();
                }
            })
            .detach();

            // Connect to the exact address the listener is bound to.
            let stream = TcpStream::connect(addr).await.unwrap();
            let connector = TlsConnector::new().danger_accept_invalid_certs(true);

            let mut stream = connector.connect("127.0.0.1", stream).await.unwrap();
            let mut res = Vec::new();
            stream.read_to_end(&mut res).await.unwrap();
            assert_eq!(res, b"hello");
        })
    }
}