use super::error::TlsError;
use super::stream::TlsStream;
use super::types::{CertificateChain, PrivateKey, RootCertStore};
use crate::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use rustls::ServerConfig;
#[cfg(feature = "tls")]
use rustls::ServerConnection;
#[cfg(feature = "tls")]
use std::future::poll_fn;
use std::path::Path;
use std::sync::Arc;
#[derive(Clone)]
pub struct TlsAcceptor {
#[cfg(feature = "tls")]
config: Arc<ServerConfig>,
handshake_timeout: Option<std::time::Duration>,
alpn_required: bool,
#[cfg(not(feature = "tls"))]
_marker: std::marker::PhantomData<()>,
}
impl TlsAcceptor {
#[cfg(feature = "tls")]
pub fn new(config: ServerConfig) -> Self {
Self {
config: Arc::new(config),
handshake_timeout: None,
alpn_required: false,
}
}
pub fn builder(chain: CertificateChain, key: PrivateKey) -> TlsAcceptorBuilder {
TlsAcceptorBuilder::new(chain, key)
}
pub fn builder_from_pem(
cert_path: impl AsRef<Path>,
key_path: impl AsRef<Path>,
) -> Result<TlsAcceptorBuilder, TlsError> {
TlsAcceptorBuilder::from_pem_files(cert_path, key_path)
}
#[cfg(feature = "tls")]
pub fn config(&self) -> &Arc<ServerConfig> {
&self.config
}
#[must_use]
pub fn handshake_timeout(&self) -> Option<std::time::Duration> {
self.handshake_timeout
}
#[cfg(feature = "tls")]
pub async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, TlsError>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
let conn = ServerConnection::new(Arc::clone(&self.config))
.map_err(|e| TlsError::Configuration(e.to_string()))?;
let mut stream = TlsStream::new_server(io, conn);
if let Some(timeout) = self.handshake_timeout {
match crate::time::timeout(
super::timeout_now(),
timeout,
poll_fn(|cx| stream.poll_handshake(cx)),
)
.await
{
Ok(result) => result?,
Err(_) => return Err(TlsError::Timeout(timeout)),
}
} else {
poll_fn(|cx| stream.poll_handshake(cx)).await?;
}
if self.alpn_required {
let expected = self.config.alpn_protocols.clone();
let negotiated = stream.alpn_protocol().map(<[u8]>::to_vec);
let ok = negotiated
.as_deref()
.is_some_and(|p| expected.iter().any(|e| e.as_slice() == p));
if !ok {
return Err(TlsError::AlpnNegotiationFailed {
expected,
negotiated,
});
}
}
Ok(stream)
}
#[cfg(not(feature = "tls"))]
pub async fn accept<IO>(&self, _io: IO) -> Result<TlsStream<IO>, TlsError>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
Err(TlsError::Configuration("tls feature not enabled".into()))
}
}
impl std::fmt::Debug for TlsAcceptor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TlsAcceptor").finish_non_exhaustive()
}
}
#[derive(Debug, Clone, Default)]
pub enum ClientAuth {
#[default]
None,
Optional(RootCertStore),
Required(RootCertStore),
}
#[derive(Debug)]
pub struct TlsAcceptorBuilder {
cert_chain: CertificateChain,
key: PrivateKey,
client_auth: ClientAuth,
alpn_protocols: Vec<Vec<u8>>,
alpn_required: bool,
max_fragment_size: Option<usize>,
handshake_timeout: Option<std::time::Duration>,
}
impl TlsAcceptorBuilder {
pub fn new(chain: CertificateChain, key: PrivateKey) -> Self {
Self {
cert_chain: chain,
key,
client_auth: ClientAuth::None,
alpn_protocols: Vec::new(),
alpn_required: false,
max_fragment_size: None,
handshake_timeout: None,
}
}
pub fn from_pem_files(
cert_path: impl AsRef<Path>,
key_path: impl AsRef<Path>,
) -> Result<Self, TlsError> {
let chain = CertificateChain::from_pem_file(cert_path)?;
let key = PrivateKey::from_pem_file(key_path)?;
Ok(Self::new(chain, key))
}
pub fn client_auth(mut self, auth: ClientAuth) -> Self {
self.client_auth = auth;
self
}
pub fn require_client_auth(self, root_certs: RootCertStore) -> Self {
self.client_auth(ClientAuth::Required(root_certs))
}
pub fn optional_client_auth(self, root_certs: RootCertStore) -> Self {
self.client_auth(ClientAuth::Optional(root_certs))
}
pub fn alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
self.alpn_protocols = protocols;
self
}
pub fn require_alpn(mut self) -> Self {
self.alpn_required = true;
self
}
pub fn alpn_protocols_required(self, protocols: Vec<Vec<u8>>) -> Self {
self.alpn_protocols(protocols).require_alpn()
}
pub fn alpn_h2(self) -> Self {
self.alpn_protocols_required(vec![b"h2".to_vec()])
}
pub fn alpn_grpc(self) -> Self {
self.alpn_h2()
}
pub fn alpn_http(self) -> Self {
self.alpn_protocols(vec![b"h2".to_vec(), b"http/1.1".to_vec()])
}
pub fn max_fragment_size(mut self, size: usize) -> Self {
self.max_fragment_size = Some(size);
self
}
pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
self.handshake_timeout = Some(timeout);
self
}
#[cfg(feature = "tls")]
pub fn build(self) -> Result<TlsAcceptor, TlsError> {
use rustls::crypto::ring::default_provider;
use rustls::server::WebPkiClientVerifier;
if self.alpn_required && self.alpn_protocols.is_empty() {
return Err(TlsError::Configuration(
"require_alpn set but no ALPN protocols configured".into(),
));
}
let builder = ServerConfig::builder_with_provider(Arc::new(default_provider()))
.with_safe_default_protocol_versions()
.map_err(|e| TlsError::Configuration(e.to_string()))?;
let builder = match self.client_auth {
ClientAuth::None => builder.with_no_client_auth(),
ClientAuth::Optional(roots) => {
let verifier = WebPkiClientVerifier::builder(Arc::new(roots.into_inner()))
.allow_unauthenticated()
.build()
.map_err(|e| TlsError::Configuration(e.to_string()))?;
builder.with_client_cert_verifier(verifier)
}
ClientAuth::Required(roots) => {
let verifier = WebPkiClientVerifier::builder(Arc::new(roots.into_inner()))
.build()
.map_err(|e| TlsError::Configuration(e.to_string()))?;
builder.with_client_cert_verifier(verifier)
}
};
let mut config = builder
.with_single_cert(self.cert_chain.into_inner(), self.key.clone_inner())
.map_err(|e| TlsError::Configuration(e.to_string()))?;
if !self.alpn_protocols.is_empty() {
config.alpn_protocols = self.alpn_protocols;
}
if let Some(size) = self.max_fragment_size {
config.max_fragment_size = Some(size);
}
#[cfg(feature = "tracing-integration")]
tracing::debug!(
alpn = ?config.alpn_protocols,
"TlsAcceptor built"
);
Ok(TlsAcceptor {
config: Arc::new(config),
handshake_timeout: self.handshake_timeout,
alpn_required: self.alpn_required,
})
}
#[cfg(not(feature = "tls"))]
pub fn build(self) -> Result<TlsAcceptor, TlsError> {
Err(TlsError::Configuration("tls feature not enabled".into()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tls::Certificate;
const TEST_CERT_PEM: &[u8] = br"-----BEGIN CERTIFICATE-----
MIIDGjCCAgKgAwIBAgIUEOa/xZnL2Xclme2QSueCrHSMLnEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDIyNjIyMjk1MloXDTM2MDIy
NDIyMjk1MlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEAx1JqCHpDIHPR4H1LDrb3gHVCzoKujANyHdOKw7CTLKdz
JbDybwJYqZ8vZpq0xwhYKpHdGO4yv7yLT7a2kThq3MrxohfXp9tv1Dop7siTQiWT
7uGYJzh1bOhw7ElLJc8bW/mBf7ksMyqkX8/8mRXRWqqDv3dKe5CrSt2Pqti9tYH0
DcT2fftUGT14VvL/Fq1kWPM16ebTRCFp/4ki/Th7SzFvTN99L45MAilHZFefRSzc
9xN1qQZNm7lT6oo0zD3wmOy70iiasqpLrmG51TRdbnBnGH6CIHvUIl3rCDteUuj1
pB9lh67qt5kipCn4+8zceXmUaO/nmRawC7Vz+6AsTwIDAQABo2QwYjALBgNVHQ8E
BAMCBLAwEwYDVR0lBAwwCgYIKwYBBQUHAwEwFAYDVR0RBA0wC4IJbG9jYWxob3N0
MAkGA1UdEwQCMAAwHQYDVR0OBBYEFEGZkeJqxBWpc24NHkE8k5PM8gTyMA0GCSqG
SIb3DQEBCwUAA4IBAQAzfQ4na2v1VhK/dyhC89rMHPN/8OX7CGWwrpWlEOYtpMds
OyQKTZjdz8aFSFl9rvnyGRHrdo4J1RoMGNR5wt1XQ7+k3l/iEWRlSRw+JU6+jqsx
xfjik55Dji36pN7ARGW4ADBpc3yTOHFhaH41GpSZ6s/2KdGG2gifo7UGNdkdgL60
nxRt1tfapaNtzpi90TfDx2w6MQmkNMKVOowbYX/zUY7kklJLP8KWTwXO7eovtIpr
FPAy+SbPl3+sqPbes5IqAQO9jhjb0w0/5RlSTPtiKetb6gAA7Yqw+yZWkBN0WDye
Lru15URJw9pE1Uae8IuzyzHiF1fnn45swnvW3Szb
-----END CERTIFICATE-----";
const TEST_KEY_PEM: &[u8] = br"-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDHUmoIekMgc9Hg
fUsOtveAdULOgq6MA3Id04rDsJMsp3MlsPJvAlipny9mmrTHCFgqkd0Y7jK/vItP
traROGrcyvGiF9en22/UOinuyJNCJZPu4ZgnOHVs6HDsSUslzxtb+YF/uSwzKqRf
z/yZFdFaqoO/d0p7kKtK3Y+q2L21gfQNxPZ9+1QZPXhW8v8WrWRY8zXp5tNEIWn/
iSL9OHtLMW9M330vjkwCKUdkV59FLNz3E3WpBk2buVPqijTMPfCY7LvSKJqyqkuu
YbnVNF1ucGcYfoIge9QiXesIO15S6PWkH2WHruq3mSKkKfj7zNx5eZRo7+eZFrAL
tXP7oCxPAgMBAAECggEAOwgH+jnHfql+m4dP/uwmUgeogQPIERSGLBo2Ky208NEo
8507t6/QtW+9OJyR9K5eekEX46XMJuf+tF2PJWQ5lemO9awtBPwi2w5c0+jYYAtE
DEgI6Xi5okcXBovQc0KqvisfdMXRNtgmtW+iRm5lQf5lJYP9baoTaQlEXttxF/t+
g7RLjaPaJNvE/Yq+4FJUuL1fWSTXfH99If6rR8Zy+FXtFRpCVbNdpruUaOmIgjuT
TlRaXf/VfnIocRNVsEWTlfCJq8Ra4qLAFM4KYuEBoPaRxpOH9of4nZftzOHwiJ0m
8+GwXqNhySVKO3SPw194LCVSoje1+PEaA/tPlE1RZQKBgQDoJpCQ0SmKOCG/c0lD
QebhqSruFoqQqeEV6poZCO+HZMvszhIiUkvk3/uoZnFQmb3w4YwbRH05YQd6iXFk
048lbqPzfGQGepMpLAY9DWhnbDy+mbuOZp+04gZ/QUen+qKBOc3mNUGhCZNyAUl3
YXeGgPNtknRQ6ebNgO1PFLaoewKBgQDbzHjknGMAFcZXr4/MPOc03I8mQiLECfxa
5PJYhjq85ygCMePiH08xJC4RT6ld3EC4GxliPFubzLMXJhqGBgboSzXGcDZbAOdw
YqleUF/jBChl2oyawzf280FepJqFG6d5qFwISi4hnCZKC7PdIbaKjjRGU7flDBej
AfGjIuzlPQKBgETAjxXkbAn8P7pkWTErBkaUhBtI37aiKQAFn6eEZvPRHTe/e81g
VAuvbedcl3iIX6FEGutEaFWi78URiVyT7xPl5XZJw5HLoWOTHzHbk6z1eDP2cX5l
1CyMt+HeImuUJaZhySHBafNYU6tyyCAr5GsYK3+q3PnNm8YGxcEi4EmbAoGAYbvA
wb58Euybvh+1bBZkpE+yY0ujE9Jw4KXO0OgWtCqA0sEGWGSdnPc+eLoYUEEAkhyS
o+i8v0E9HPz3bEK/zYirx6nbsYlsX7+vGd3ZVSNjJy8PuD035Fnz5jaA8tECHglr
qs/5RT6ek+wyNRCpj2B+BAtzyKgg1n2lyWldNu0CgYEA4Ux9QV5s99W39vJlzGHD
ilKqHWetmrehbe0nIeCe2bJWqb08oSrQD8Q7om/MGAKjhFqNyYqqoJXcmbAvLygu
kMtbiQcfyyxjefyCA0OvdWEXrvnRZYNEBosyX/ko7Bl2IRBFP6ahQhj7jHqm2+/J
SrXuVI5uunTgPWuOtJOP+KM=
-----END PRIVATE KEY-----";
#[test]
fn test_builder_new() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let builder = TlsAcceptorBuilder::new(chain, key);
assert!(builder.alpn_protocols.is_empty());
}
#[test]
fn test_builder_alpn_http() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let builder = TlsAcceptorBuilder::new(chain, key).alpn_http();
assert_eq!(
builder.alpn_protocols,
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
}
#[test]
fn test_builder_alpn_h2() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let builder = TlsAcceptorBuilder::new(chain, key).alpn_h2();
assert_eq!(builder.alpn_protocols, vec![b"h2".to_vec()]);
assert!(builder.alpn_required);
}
#[test]
fn test_builder_alpn_grpc() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let builder = TlsAcceptorBuilder::new(chain, key).alpn_grpc();
assert_eq!(builder.alpn_protocols, vec![b"h2".to_vec()]);
assert!(builder.alpn_required);
}
#[test]
fn test_client_auth_default() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let builder = TlsAcceptorBuilder::new(chain, key);
assert!(matches!(builder.client_auth, ClientAuth::None));
}
#[test]
fn test_certificate_from_pem() {
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
assert_eq!(certs.len(), 1);
}
#[test]
fn test_private_key_from_pem() {
let _key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
}
#[cfg(feature = "tls")]
#[test]
fn test_build_acceptor() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key)
.alpn_http()
.build()
.unwrap();
assert_eq!(
acceptor.config().alpn_protocols,
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
}
#[cfg(feature = "tls")]
#[test]
fn test_acceptor_clone_is_cheap() {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key).build().unwrap();
let start = std::time::Instant::now();
for _ in 0..10000 {
let _clone = acceptor.clone();
}
let elapsed = start.elapsed();
assert!(elapsed.as_millis() < 100);
}
#[cfg(feature = "tls")]
#[test]
fn test_connect_accept_handshake() {
use crate::conformance::{ConformanceTarget, LabRuntimeTarget, TestConfig};
use crate::cx::Cx;
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::init_test_logging;
use futures_lite::future::zip;
init_test_logging();
let config = TestConfig::new()
.with_seed(0x715A_CCE7)
.with_tracing(true)
.with_max_steps(20_000);
let mut runtime = LabRuntimeTarget::create_runtime(config);
let (
client_ready,
server_ready,
client_protocol,
server_protocol,
client_alpn,
server_alpn,
checkpoints,
) = LabRuntimeTarget::block_on(&mut runtime, async move {
let _cx = Cx::current().expect("lab runtime should install a current Cx");
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key)
.alpn_http()
.build()
.unwrap();
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.alpn_http()
.build()
.unwrap();
let (client_io, server_io) = VirtualTcpStream::pair(
"127.0.0.1:5000".parse().unwrap(),
"127.0.0.1:5001".parse().unwrap(),
);
let checkpoints = vec![serde_json::json!({
"phase": "virtual_stream_pair_created",
"client_addr": "127.0.0.1:5000",
"server_addr": "127.0.0.1:5001",
})];
for checkpoint in &checkpoints {
tracing::info!(event = %checkpoint, "tls_lab_checkpoint");
}
let (client_res, server_res) = zip(
connector.connect("localhost", client_io),
acceptor.accept(server_io),
)
.await;
let client = client_res.expect("client handshake should succeed");
let server = server_res.expect("server handshake should succeed");
let client_ready = client.is_ready();
let server_ready = server.is_ready();
let client_protocol = client.protocol_version().is_some();
let server_protocol = server.protocol_version().is_some();
let client_alpn = client.alpn_protocol().map(|protocol| protocol.to_vec());
let server_alpn = server.alpn_protocol().map(|protocol| protocol.to_vec());
let mut checkpoints = checkpoints;
checkpoints.push(serde_json::json!({
"phase": "handshake_completed",
"client_ready": client_ready,
"server_ready": server_ready,
"client_protocol_present": client_protocol,
"server_protocol_present": server_protocol,
"client_alpn": client_alpn.as_ref().map(|protocol| String::from_utf8_lossy(protocol).to_string()),
"server_alpn": server_alpn.as_ref().map(|protocol| String::from_utf8_lossy(protocol).to_string()),
}));
for checkpoint in checkpoints.iter().skip(1) {
tracing::info!(event = %checkpoint, "tls_lab_checkpoint");
}
(
client_ready,
server_ready,
client_protocol,
server_protocol,
client_alpn,
server_alpn,
checkpoints,
)
});
assert!(client_ready);
assert!(server_ready);
assert!(client_protocol);
assert!(server_protocol);
assert_eq!(client_alpn.as_deref(), Some(b"h2".as_slice()));
assert_eq!(server_alpn.as_deref(), Some(b"h2".as_slice()));
assert_eq!(checkpoints.len(), 2);
assert!(runtime.is_quiescent());
}
#[cfg(feature = "tls")]
#[test]
fn test_alpn_server_preference_ordering() {
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::run_test_with_cx;
use futures_lite::future::zip;
run_test_with_cx(|_cx| async move {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key)
.alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()])
.build()
.unwrap();
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.alpn_http()
.build()
.unwrap();
let (client_io, server_io) = VirtualTcpStream::pair(
"127.0.0.1:5100".parse().unwrap(),
"127.0.0.1:5101".parse().unwrap(),
);
let (client_res, server_res) = zip(
connector.connect("localhost", client_io),
acceptor.accept(server_io),
)
.await;
let client = client_res.unwrap();
let server = server_res.unwrap();
assert_eq!(client.alpn_protocol(), Some(b"http/1.1".as_slice()));
assert_eq!(server.alpn_protocol(), Some(b"http/1.1".as_slice()));
});
}
#[cfg(feature = "tls")]
#[test]
fn test_alpn_fallback_to_http11_when_server_h2_not_supported() {
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::run_test_with_cx;
use futures_lite::future::zip;
run_test_with_cx(|_cx| async move {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key)
.alpn_protocols(vec![b"http/1.1".to_vec()])
.build()
.unwrap();
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.alpn_http()
.build()
.unwrap();
let (client_io, server_io) = VirtualTcpStream::pair(
"127.0.0.1:5110".parse().unwrap(),
"127.0.0.1:5111".parse().unwrap(),
);
let (client_res, server_res) = zip(
connector.connect("localhost", client_io),
acceptor.accept(server_io),
)
.await;
let client = client_res.unwrap();
let server = server_res.unwrap();
assert_eq!(client.alpn_protocol(), Some(b"http/1.1".as_slice()));
assert_eq!(server.alpn_protocol(), Some(b"http/1.1".as_slice()));
});
}
#[cfg(feature = "tls")]
#[test]
fn test_alpn_none_when_server_has_no_alpn() {
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::run_test_with_cx;
use futures_lite::future::zip;
run_test_with_cx(|_cx| async move {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key).build().unwrap();
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.alpn_http()
.build()
.unwrap();
let (client_io, server_io) = VirtualTcpStream::pair(
"127.0.0.1:5120".parse().unwrap(),
"127.0.0.1:5121".parse().unwrap(),
);
let (client_res, server_res) = zip(
connector.connect("localhost", client_io),
acceptor.accept(server_io),
)
.await;
let client = client_res.unwrap();
let server = server_res.unwrap();
assert!(client.alpn_protocol().is_none());
assert!(server.alpn_protocol().is_none());
});
}
#[cfg(feature = "tls")]
#[test]
fn test_alpn_required_client_errors_on_no_overlap() {
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::run_test_with_cx;
use futures_lite::future::zip;
run_test_with_cx(|_cx| async move {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key)
.alpn_protocols(vec![b"http/1.1".to_vec()])
.build()
.unwrap();
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.alpn_h2()
.build()
.unwrap();
let (client_io, server_io) = VirtualTcpStream::pair(
"127.0.0.1:5130".parse().unwrap(),
"127.0.0.1:5131".parse().unwrap(),
);
let (client_res, server_res) = zip(
connector.connect("localhost", client_io),
acceptor.accept(server_io),
)
.await;
let client_err = client_res.unwrap_err();
assert!(matches!(client_err, TlsError::Handshake(_)));
let server_err = server_res.unwrap_err();
assert!(matches!(server_err, TlsError::Handshake(_)));
});
}
#[cfg(feature = "tls")]
#[test]
fn test_alpn_required_server_errors_when_client_offers_none() {
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::run_test_with_cx;
use futures_lite::future::zip;
run_test_with_cx(|_cx| async move {
let chain = CertificateChain::from_pem(TEST_CERT_PEM).unwrap();
let key = PrivateKey::from_pem(TEST_KEY_PEM).unwrap();
let acceptor = TlsAcceptorBuilder::new(chain, key)
.alpn_h2()
.build()
.unwrap();
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.build()
.unwrap();
let (client_io, server_io) = VirtualTcpStream::pair(
"127.0.0.1:5140".parse().unwrap(),
"127.0.0.1:5141".parse().unwrap(),
);
let (client_res, server_res) = zip(
connector.connect("localhost", client_io),
acceptor.accept(server_io),
)
.await;
let client = client_res.unwrap();
assert!(client.alpn_protocol().is_none());
let server_err = server_res.unwrap_err();
assert!(matches!(server_err, TlsError::AlpnNegotiationFailed { .. }));
});
}
#[cfg(feature = "tls")]
#[test]
fn test_connect_timeout() {
use crate::net::tcp::VirtualTcpStream;
use crate::test_utils::run_test_with_cx;
run_test_with_cx(|_cx| async move {
let certs = Certificate::from_pem(TEST_CERT_PEM).unwrap();
let connector = crate::tls::TlsConnectorBuilder::new()
.add_root_certificates(certs)
.handshake_timeout(std::time::Duration::from_millis(5))
.build()
.unwrap();
let (client_io, _server_io) = VirtualTcpStream::pair(
"127.0.0.1:5002".parse().unwrap(),
"127.0.0.1:5003".parse().unwrap(),
);
let err = connector.connect("localhost", client_io).await.unwrap_err();
assert!(matches!(err, TlsError::Timeout(_)));
});
}
#[cfg(not(feature = "tls"))]
#[test]
fn test_build_without_tls_feature() {
let chain = CertificateChain::new();
let key = PrivateKey::from_pkcs8_der(vec![]);
let result = TlsAcceptorBuilder::new(chain, key).build();
assert!(result.is_err());
}
}