use std::ffi::CString;
use std::os::raw::{c_int, c_long, c_uint, c_void};
use std::path::Path;
use std::ptr;
use std::sync::OnceLock;
use crate::error::{last_error, pem_eof_or_err, Error, Result};
use crate::ffi::SslCtx;
use super::cipher_suite::{self, CipherSuite};
use super::named_group::{self, NamedGroup};
use super::{encode_alpn_wire, iter_alpn_wire, ProtocolVersion};
/// A fully configured server-side TLS configuration backed by an AWS-LC `SSL_CTX`.
///
/// Create one with [`ServerConfig::builder`] followed by one of the builder's `with_*`
/// methods.
#[derive(Debug)]
pub struct ServerConfig {
    /// Owning handle to the configured `SSL_CTX`.
    pub(crate) ctx: SslCtx,
    /// `true` when [`ServerConfigBuilder::disable_ktls`] was called on the builder.
    pub(crate) ktls_disabled: bool,
}

// SAFETY: the `SSL_CTX` is fully configured before a `ServerConfig` is handed out, and a
// configured AWS-LC context may be used concurrently to create connections.
// NOTE(review): this relies on crate code never reconfiguring the context through
// `ctx_ptr` after construction — confirm at the call sites.
unsafe impl Send for ServerConfig {}
// SAFETY: see the `Send` justification directly above.
unsafe impl Sync for ServerConfig {}

impl ServerConfig {
    /// Returns a builder populated with default settings.
    #[must_use]
    pub fn builder() -> ServerConfigBuilder {
        <ServerConfigBuilder as Default>::default()
    }

    /// Raw pointer to the underlying `SSL_CTX`; ownership stays with `self`.
    pub(crate) fn ctx_ptr(&self) -> *mut aws_lc_sys::SSL_CTX {
        let ctx = &self.ctx;
        ctx.as_ptr()
    }
}
/// Builder for [`ServerConfig`]; obtain one with [`ServerConfig::builder`].
///
/// Setters consume and return the builder. One of the `with_*` constructors supplies the
/// certificate and private key and produces the final configuration.
#[derive(Default, Debug)]
pub struct ServerConfigBuilder {
    /// ALPN protocol identifiers in server preference order; empty leaves ALPN unset.
    alpn_protocols: Vec<Vec<u8>>,
    /// Lowest accepted protocol version; `None` means TLS 1.2.
    min_version: Option<ProtocolVersion>,
    /// Highest accepted protocol version; `None` means TLS 1.3.
    max_version: Option<ProtocolVersion>,
    /// Restrict the cipher list to ECDHE with AES-GCM or ChaCha20.
    ktls_aead_only: bool,
    /// Copied verbatim into `ServerConfig::ktls_disabled`.
    ktls_disabled: bool,
    /// Explicit cipher-suite selection; `None` applies nothing.
    cipher_suites: Option<Vec<&'static CipherSuite>>,
    /// Explicit key-exchange groups; `None` applies nothing.
    named_groups: Option<Vec<NamedGroup>>,
    /// Client-certificate policy.
    client_auth: ClientAuthMode,
    /// PEM bundle of CAs used to verify client certificates.
    client_auth_roots_pem: Option<Vec<u8>>,
}
/// Whether, and how strictly, the server asks connecting clients for a certificate.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum ClientAuthMode {
    /// Never request a client certificate (the default).
    #[default]
    None,
    /// Request a client certificate and verify it when presented; clients that send
    /// none are still accepted.
    Optional,
    /// Require a client certificate that verifies; the handshake fails without one.
    Required,
}
impl ServerConfigBuilder {
    /// Sets the ALPN protocol identifiers this server may select, most preferred first.
    ///
    /// When a client offers ALPN, the first entry of this list that the client also
    /// offered is selected. An empty list (the default) leaves ALPN unconfigured. The list
    /// is encoded when the configuration is built; an unencodable list is reported as
    /// [`Error::Init`] by the `with_*` constructors.
    #[must_use]
    pub fn alpn_protocols(mut self, protos: &[&[u8]]) -> Self {
        self.alpn_protocols = protos.iter().map(|p| p.to_vec()).collect();
        self
    }

    /// Sets the lowest accepted protocol version (defaults to [`ProtocolVersion::Tls12`]).
    #[must_use]
    pub fn min_protocol_version(mut self, v: ProtocolVersion) -> Self {
        self.min_version = Some(v);
        self
    }

    /// Sets the highest accepted protocol version (defaults to [`ProtocolVersion::Tls13`]).
    #[must_use]
    pub fn max_protocol_version(mut self, v: ProtocolVersion) -> Self {
        self.max_version = Some(v);
        self
    }

    /// When `on`, restricts the cipher list to `"ECDHE+AESGCM:ECDHE+CHACHA20"`.
    ///
    /// The restriction is applied before any explicit
    /// [`cipher_suites`](Self::cipher_suites) selection, which is configured afterwards.
    #[must_use]
    pub fn ktls_aead_only(mut self, on: bool) -> Self {
        self.ktls_aead_only = on;
        self
    }

    /// Records on the resulting [`ServerConfig`] that kTLS must not be used.
    ///
    /// The flag is only stored here; it is presumably honoured by the connection code
    /// elsewhere in the crate.
    #[must_use]
    pub fn disable_ktls(mut self) -> Self {
        self.ktls_disabled = true;
        self
    }

    /// Restricts negotiation to `suites`; an empty slice clears any earlier selection.
    #[must_use]
    pub fn cipher_suites(mut self, suites: &[&'static CipherSuite]) -> Self {
        self.cipher_suites = if suites.is_empty() {
            None
        } else {
            Some(suites.to_vec())
        };
        self
    }

    /// Restricts key exchange to `groups`; an empty slice clears any earlier selection.
    #[must_use]
    pub fn named_groups(mut self, groups: &[NamedGroup]) -> Self {
        self.named_groups = if groups.is_empty() {
            None
        } else {
            Some(groups.to_vec())
        };
        self
    }

    /// Configures client-certificate authentication.
    ///
    /// `roots_pem` holds PEM certificates trusted to issue client certificates. It is
    /// ignored when `mode` is [`ClientAuthMode::None`]. Otherwise it must contain at least
    /// one certificate, or building the configuration fails.
    #[must_use]
    pub fn client_auth(mut self, mode: ClientAuthMode, roots_pem: &[u8]) -> Self {
        self.client_auth = mode;
        self.client_auth_roots_pem = Some(roots_pem.to_vec());
        self
    }

    /// Builds a configuration from a PEM certificate-chain file and a PEM private-key file.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Init`] in any of these cases:
    /// - a path cannot be turned into a C string;
    /// - either file fails to load;
    /// - the key does not match the certificate;
    /// - a builder option is rejected.
    pub fn with_pem_files(self, cert: &Path, key: &Path) -> Result<ServerConfig> {
        let ctx = new_server_ctx()?;
        let cert_c = path_to_cstring(cert)?;
        let key_c = path_to_cstring(key)?;
        // SAFETY: `ctx` is a live context and `cert_c` is a NUL-terminated path that
        // outlives the call.
        let ok = unsafe {
            aws_lc_sys::SSL_CTX_use_certificate_chain_file(ctx.as_ptr(), cert_c.as_ptr())
        };
        if ok != 1 {
            return Err(Error::Init(format!(
                "loading certificate chain from {}: {}",
                cert.display(),
                last_error()
            )));
        }
        // SAFETY: as above, for `key_c`.
        let ok = unsafe {
            aws_lc_sys::SSL_CTX_use_PrivateKey_file(
                ctx.as_ptr(),
                key_c.as_ptr(),
                aws_lc_sys::SSL_FILETYPE_PEM as c_int,
            )
        };
        if ok != 1 {
            return Err(Error::Init(format!(
                "loading private key from {}: {}",
                key.display(),
                last_error()
            )));
        }
        self.finish(ctx)
    }

    /// Builds a configuration from an in-memory PEM chain (leaf first) and PEM private key.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Init`] in any of these cases:
    /// - the PEM cannot be parsed;
    /// - the key does not match the certificate;
    /// - a builder option is rejected.
    pub fn with_pem_bytes(self, cert: &[u8], key: &[u8]) -> Result<ServerConfig> {
        let ctx = new_server_ctx()?;
        load_cert_chain_pem(&ctx, cert)?;
        load_private_key_pem(&ctx, key)?;
        self.finish(ctx)
    }

    /// Builds a configuration from DER certificates and a DER private key.
    ///
    /// `cert_chain[0]` is the leaf; remaining entries are sent as intermediates.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Init`] in any of these cases:
    /// - the chain is empty;
    /// - a certificate or the key fails to parse;
    /// - the key does not match the certificate;
    /// - a builder option is rejected.
    pub fn with_der_bytes(self, cert_chain: &[&[u8]], key: &[u8]) -> Result<ServerConfig> {
        let ctx = new_server_ctx()?;
        load_cert_chain_der(&ctx, cert_chain)?;
        load_private_key_der(&ctx, key)?;
        self.finish(ctx)
    }

    /// Verifies the key pair and applies every builder option to `ctx`.
    fn finish(self, ctx: SslCtx) -> Result<ServerConfig> {
        // SAFETY: `ctx` owns a live SSL_CTX for the whole call.
        let ok = unsafe { aws_lc_sys::SSL_CTX_check_private_key(ctx.as_ptr()) };
        if ok != 1 {
            return Err(Error::Init(format!(
                "certificate and private key do not match: {}",
                last_error()
            )));
        }

        let min_v = self.min_version.unwrap_or(ProtocolVersion::Tls12).raw();
        let max_v = self.max_version.unwrap_or(ProtocolVersion::Tls13).raw();
        // AWS-LC stores each bound independently. An inverted range would only surface
        // later, as a failure on every handshake, so reject it here where the cause is
        // obvious.
        if min_v > max_v {
            return Err(Error::Init(format!(
                "min_protocol_version (0x{min_v:04x}) is newer than \
                 max_protocol_version (0x{max_v:04x})"
            )));
        }
        // SAFETY: `ctx` is live; the setters only store the bounds.
        unsafe {
            if aws_lc_sys::SSL_CTX_set_min_proto_version(ctx.as_ptr(), min_v) != 1 {
                return Err(Error::Init(format!(
                    "SSL_CTX_set_min_proto_version: {}",
                    last_error()
                )));
            }
            if aws_lc_sys::SSL_CTX_set_max_proto_version(ctx.as_ptr(), max_v) != 1 {
                return Err(Error::Init(format!(
                    "SSL_CTX_set_max_proto_version: {}",
                    last_error()
                )));
            }
        }

        // Disable session tickets: the option covers TLS 1.2 tickets, and
        // `set_num_tickets(0)` stops TLS 1.3 ticket issuance.
        // The option constant is a small positive bit value, so the cast is lossless.
        #[allow(clippy::cast_sign_loss)]
        // SAFETY: `ctx` is live.
        unsafe {
            aws_lc_sys::SSL_CTX_set_options(ctx.as_ptr(), aws_lc_sys::SSL_OP_NO_TICKET as u32);
            aws_lc_sys::SSL_CTX_set_num_tickets(ctx.as_ptr(), 0);
        }

        if self.ktls_aead_only {
            let list = c"ECDHE+AESGCM:ECDHE+CHACHA20";
            // SAFETY: `ctx` is live and `list` is a static NUL-terminated string.
            let ok = unsafe { aws_lc_sys::SSL_CTX_set_cipher_list(ctx.as_ptr(), list.as_ptr()) };
            if ok != 1 {
                return Err(Error::Init(format!(
                    "SSL_CTX_set_cipher_list (AEAD-only): {}",
                    last_error()
                )));
            }
        }
        if let Some(suites) = &self.cipher_suites {
            cipher_suite::apply_to_ctx(&ctx, suites)?;
        }
        if let Some(groups) = &self.named_groups {
            named_group::apply_to_ctx(&ctx, groups)?;
        }
        // Pass fields rather than `&self`: `min_version`/`max_version` may have been moved
        // out above.
        Self::apply_client_auth(&ctx, self.client_auth, self.client_auth_roots_pem.as_deref())?;
        Self::apply_alpn(&ctx, &self.alpn_protocols)?;

        Ok(ServerConfig {
            ctx,
            ktls_disabled: self.ktls_disabled,
        })
    }

    /// Installs the CA roots and verification flags for `mode`; does nothing for `None`.
    fn apply_client_auth(
        ctx: &SslCtx,
        mode: ClientAuthMode,
        roots_pem: Option<&[u8]>,
    ) -> Result<()> {
        // The verify flags are small positive bit values, so the conversion is lossless.
        #[allow(clippy::cast_sign_loss)]
        let flags = match mode {
            ClientAuthMode::None => return Ok(()),
            ClientAuthMode::Optional => aws_lc_sys::SSL_VERIFY_PEER as c_int,
            ClientAuthMode::Required => {
                (aws_lc_sys::SSL_VERIFY_PEER | aws_lc_sys::SSL_VERIFY_FAIL_IF_NO_PEER_CERT)
                    as c_int
            }
        };
        let roots_pem = roots_pem.ok_or_else(|| {
            Error::Init("client_auth requires a non-empty CA bundle (roots_pem)".into())
        })?;
        load_client_ca_roots_pem(ctx, roots_pem)?;
        // SAFETY: `ctx` is live; a `None` callback keeps AWS-LC's built-in verification.
        unsafe {
            aws_lc_sys::SSL_CTX_set_verify(ctx.as_ptr(), flags, None);
        }
        Ok(())
    }

    /// Stores the encoded ALPN list on `ctx` and installs the selection callback.
    ///
    /// Does nothing when `protocols` is empty.
    fn apply_alpn(ctx: &SslCtx, protocols: &[Vec<u8>]) -> Result<()> {
        if protocols.is_empty() {
            return Ok(());
        }
        let refs: Vec<&[u8]> = protocols.iter().map(Vec::as_slice).collect();
        let wire = encode_alpn_wire(&refs)
            .map_err(|e| Error::Init(format!("encoding ALPN protocol list: {e}")))?;
        install_alpn_wire(ctx, wire)?;
        // SAFETY: `ctx` is live; the callback reads only the ex data installed just above.
        unsafe {
            aws_lc_sys::SSL_CTX_set_alpn_select_cb(
                ctx.as_ptr(),
                Some(alpn_select_cb),
                ptr::null_mut(),
            );
        }
        Ok(())
    }
}
fn new_server_ctx() -> Result<SslCtx> {
let raw = unsafe { aws_lc_sys::SSL_CTX_new(aws_lc_sys::TLS_server_method()) };
unsafe { SslCtx::from_raw(raw) }
.ok_or_else(|| Error::Init(format!("SSL_CTX_new: {}", last_error())))
}
/// Converts a filesystem path into the NUL-terminated string AWS-LC's `*_file` loaders
/// expect.
///
/// On Unix, paths are arbitrary byte strings, so the raw bytes are passed through
/// unchanged. This accepts names that are not valid UTF-8. On other platforms there is no
/// lossless byte view of an `OsStr`, so the path must be valid UTF-8.
///
/// # Errors
///
/// Returns [`Error::Init`] if the path contains an interior NUL byte. On non-Unix
/// platforms it is also returned when the path is not valid UTF-8.
fn path_to_cstring(p: &Path) -> Result<CString> {
    #[cfg(unix)]
    let bytes: &[u8] = {
        use std::os::unix::ffi::OsStrExt;
        p.as_os_str().as_bytes()
    };
    #[cfg(not(unix))]
    let bytes: &[u8] = p
        .to_str()
        .ok_or_else(|| {
            Error::Init(format!(
                "path {} is not valid UTF-8 (required to pass it to AWS-LC on this platform)",
                p.display()
            ))
        })?
        .as_bytes();
    // C strings cannot carry an interior NUL; the loader would silently truncate the path.
    CString::new(bytes).map_err(|_| {
        Error::Init(format!(
            "path {} contains an embedded NUL byte",
            p.display()
        ))
    })
}
/// Installs a PEM certificate chain on `ctx`: the first certificate is the leaf, and
/// every following certificate becomes a chain (intermediate) certificate.
///
/// Any chain certificates previously set on `ctx` are cleared first.
///
/// # Errors
///
/// Returns [`Error::Init`] in any of these cases:
/// - the BIO cannot be created;
/// - there is no leaf certificate;
/// - AWS-LC refuses a certificate;
/// - a malformed block follows the leaf (per `pem_eof_or_err`).
fn load_cert_chain_pem(ctx: &SslCtx, pem: &[u8]) -> Result<()> {
    // A slice length never exceeds isize::MAX, so the cast cannot wrap.
    #[allow(clippy::cast_possible_wrap)]
    // SAFETY: the memory BIO only reads `pem`, which outlives `reader` below.
    let raw_bio = unsafe { aws_lc_sys::BIO_new_mem_buf(pem.as_ptr().cast(), pem.len() as isize) };
    if raw_bio.is_null() {
        return Err(Error::Init(format!(
            "BIO_new_mem_buf for cert chain: {}",
            last_error()
        )));
    }
    let reader = BioGuard(raw_bio);

    // SAFETY: `reader.0` is a live BIO; null callback/userdata means "no password".
    let leaf_cert = unsafe {
        aws_lc_sys::PEM_read_bio_X509_AUX(reader.0, ptr::null_mut(), None, ptr::null_mut())
    };
    if leaf_cert.is_null() {
        return Err(Error::Init(format!(
            "PEM_read_bio_X509_AUX (leaf): {}",
            last_error()
        )));
    }
    // SAFETY: `ctx` is live and `leaf_cert` is a valid X509 we own. Our reference is
    // released whatever the outcome; the context keeps its own.
    let installed = unsafe {
        let status = aws_lc_sys::SSL_CTX_use_certificate(ctx.as_ptr(), leaf_cert);
        aws_lc_sys::X509_free(leaf_cert);
        status
    };
    if installed != 1 {
        return Err(Error::Init(format!(
            "SSL_CTX_use_certificate: {}",
            last_error()
        )));
    }

    // SAFETY: `ctx` is live.
    unsafe { aws_lc_sys::SSL_CTX_clear_chain_certs(ctx.as_ptr()) };

    loop {
        // SAFETY: as for the leaf read above.
        let intermediate = unsafe {
            aws_lc_sys::PEM_read_bio_X509(reader.0, ptr::null_mut(), None, ptr::null_mut())
        };
        if intermediate.is_null() {
            // A clean end of input ends the chain; any other failure is reported.
            pem_eof_or_err("PEM_read_bio_X509 (server chain)")?;
            return Ok(());
        }
        // SAFETY: `add0` takes ownership only on success, so we free on failure.
        let added = unsafe { aws_lc_sys::SSL_CTX_add0_chain_cert(ctx.as_ptr(), intermediate) };
        if added != 1 {
            unsafe { aws_lc_sys::X509_free(intermediate) };
            return Err(Error::Init(format!(
                "SSL_CTX_add0_chain_cert: {}",
                last_error()
            )));
        }
    }
}
fn load_private_key_pem(ctx: &SslCtx, pem: &[u8]) -> Result<()> {
#[allow(clippy::cast_possible_wrap)]
let bio = unsafe { aws_lc_sys::BIO_new_mem_buf(pem.as_ptr().cast(), pem.len() as isize) };
if bio.is_null() {
return Err(Error::Init(format!(
"BIO_new_mem_buf for private key: {}",
last_error()
)));
}
let bio = BioGuard(bio);
let key = unsafe {
aws_lc_sys::PEM_read_bio_PrivateKey(bio.0, ptr::null_mut(), None, ptr::null_mut())
};
if key.is_null() {
return Err(Error::Init(format!(
"PEM_read_bio_PrivateKey: {}",
last_error()
)));
}
let ok = unsafe { aws_lc_sys::SSL_CTX_use_PrivateKey(ctx.as_ptr(), key) };
unsafe { aws_lc_sys::EVP_PKEY_free(key) };
if ok != 1 {
return Err(Error::Init(format!(
"SSL_CTX_use_PrivateKey: {}",
last_error()
)));
}
Ok(())
}
fn load_cert_chain_der(ctx: &SslCtx, certs: &[&[u8]]) -> Result<()> {
let (leaf_der, rest) = certs.split_first().ok_or_else(|| {
Error::Init("DER certificate chain must contain at least the leaf certificate".into())
})?;
let leaf = super::der::parse_x509(leaf_der)?;
let ok = unsafe { aws_lc_sys::SSL_CTX_use_certificate(ctx.as_ptr(), leaf) };
unsafe { aws_lc_sys::X509_free(leaf) };
if ok != 1 {
return Err(Error::Init(format!(
"SSL_CTX_use_certificate (DER leaf): {}",
last_error()
)));
}
unsafe {
aws_lc_sys::SSL_CTX_clear_chain_certs(ctx.as_ptr());
}
for extra_der in rest {
let extra = super::der::parse_x509(extra_der)?;
let ok = unsafe { aws_lc_sys::SSL_CTX_add0_chain_cert(ctx.as_ptr(), extra) };
if ok != 1 {
unsafe { aws_lc_sys::X509_free(extra) };
return Err(Error::Init(format!(
"SSL_CTX_add0_chain_cert (DER): {}",
last_error()
)));
}
}
Ok(())
}
fn load_private_key_der(ctx: &SslCtx, der: &[u8]) -> Result<()> {
let key = super::der::parse_private_key(der)?;
let ok = unsafe { aws_lc_sys::SSL_CTX_use_PrivateKey(ctx.as_ptr(), key) };
unsafe { aws_lc_sys::EVP_PKEY_free(key) };
if ok != 1 {
return Err(Error::Init(format!(
"SSL_CTX_use_PrivateKey (DER): {}",
last_error()
)));
}
Ok(())
}
/// Owns a `BIO*` and frees it on drop.
///
/// Every construction site in this file checks the pointer for null first.
struct BioGuard(*mut aws_lc_sys::BIO);

impl Drop for BioGuard {
    fn drop(&mut self) {
        let bio = self.0;
        // SAFETY: `bio` came from `BIO_new_mem_buf`, is non-null, and is owned solely by
        // this guard.
        unsafe { aws_lc_sys::BIO_free(bio) };
    }
}
fn load_client_ca_roots_pem(ctx: &SslCtx, pem: &[u8]) -> Result<()> {
#[allow(clippy::cast_possible_wrap)]
let bio = unsafe { aws_lc_sys::BIO_new_mem_buf(pem.as_ptr().cast(), pem.len() as isize) };
if bio.is_null() {
return Err(Error::Init(format!(
"BIO_new_mem_buf for client CA roots: {}",
last_error()
)));
}
let bio = BioGuard(bio);
let store = unsafe { aws_lc_sys::SSL_CTX_get_cert_store(ctx.as_ptr()) };
if store.is_null() {
return Err(Error::Init("SSL_CTX_get_cert_store returned null".into()));
}
let mut added = 0usize;
loop {
let cert =
unsafe { aws_lc_sys::PEM_read_bio_X509(bio.0, ptr::null_mut(), None, ptr::null_mut()) };
if cert.is_null() {
pem_eof_or_err("PEM_read_bio_X509 (client CA roots)")?;
break;
}
let ok = unsafe { aws_lc_sys::X509_STORE_add_cert(store, cert) };
unsafe { aws_lc_sys::X509_free(cert) };
if ok != 1 {
return Err(Error::Init(format!(
"X509_STORE_add_cert (client CA): {}",
last_error()
)));
}
added += 1;
}
if added == 0 {
return Err(Error::Init(
"no certificates found in client CA roots PEM".into(),
));
}
Ok(())
}
/// Returns the process-wide `SSL_CTX` ex-data index that stores the server's ALPN list.
///
/// The index is allocated on first use and registered with [`alpn_ex_free`] as its free
/// callback.
///
/// # Panics
///
/// Panics if AWS-LC cannot allocate an ex-data index.
fn alpn_ex_index() -> c_int {
    static ALPN_EX_INDEX: OnceLock<c_int> = OnceLock::new();
    let index = ALPN_EX_INDEX.get_or_init(|| {
        // SAFETY: argl/argp are unused placeholders, and `alpn_ex_free` matches the
        // `CRYPTO_EX_free` signature.
        let allocated = unsafe {
            aws_lc_sys::SSL_CTX_get_ex_new_index(
                0,
                ptr::null_mut(),
                ptr::null_mut(),
                None,
                Some(alpn_ex_free),
            )
        };
        assert!(allocated >= 0, "SSL_CTX_get_ex_new_index failed");
        allocated
    });
    *index
}
fn install_alpn_wire(ctx: &SslCtx, wire: Vec<u8>) -> Result<()> {
let raw = Box::into_raw(Box::new(wire)).cast::<c_void>();
let ok = unsafe { aws_lc_sys::SSL_CTX_set_ex_data(ctx.as_ptr(), alpn_ex_index(), raw) };
if ok != 1 {
drop(unsafe { Box::from_raw(raw.cast::<Vec<u8>>()) });
return Err(Error::Init(format!(
"SSL_CTX_set_ex_data (ALPN): {}",
last_error()
)));
}
Ok(())
}
/// `CRYPTO_EX_free` callback for the ALPN ex-data slot.
///
/// AWS-LC invokes it when releasing the slot's value. It reclaims the boxed `Vec<u8>`
/// stored by [`install_alpn_wire`]; null values are ignored.
///
/// # Safety
///
/// `ptr` must be null or a pointer produced by `Box::<Vec<u8>>::into_raw` that has not
/// been freed yet.
unsafe extern "C" fn alpn_ex_free(
    _parent: *mut c_void,
    ptr: *mut c_void,
    _ad: *mut aws_lc_sys::CRYPTO_EX_DATA,
    _index: c_int,
    _argl: c_long,
    _argp: *mut c_void,
) {
    if !ptr.is_null() {
        // SAFETY: non-null values in this slot only come from `install_alpn_wire`.
        let wire = unsafe { Box::from_raw(ptr.cast::<Vec<u8>>()) };
        drop(wire);
    }
}
/// ALPN selection callback installed on server contexts that configured ALPN protocols.
///
/// Walks the server's protocol list (ex data on the connection's `SSL_CTX`) in order and
/// selects the first entry the client also offered, so the server's preference wins.
/// Returns `SSL_TLSEXT_ERR_ALERT_FATAL` in any of these cases:
/// - there is no overlap;
/// - the client list is missing or empty;
/// - the context or its server list cannot be found.
///
/// # Safety
///
/// Must only be invoked by AWS-LC as an `SSL_CTX_set_alpn_select_cb` callback. `ssl` is a
/// live connection, `out`/`out_len` are writable, and `in_` points to `in_len` readable
/// bytes.
unsafe extern "C" fn alpn_select_cb(
    ssl: *mut aws_lc_sys::SSL,
    out: *mut *const u8,
    out_len: *mut u8,
    in_: *const u8,
    in_len: c_uint,
    _arg: *mut c_void,
) -> c_int {
    // `slice::from_raw_parts` requires a non-null pointer even for a zero length. Guard
    // before building the client slice; an empty client list cannot overlap anyway.
    if in_.is_null() || in_len == 0 {
        return aws_lc_sys::SSL_TLSEXT_ERR_ALERT_FATAL;
    }
    // SAFETY: AWS-LC passes the live connection that triggered this callback.
    let ctx = unsafe { aws_lc_sys::SSL_get_SSL_CTX(ssl) };
    if ctx.is_null() {
        return aws_lc_sys::SSL_TLSEXT_ERR_ALERT_FATAL;
    }
    // SAFETY: `ctx` is non-null; the slot holds either null or a `Box<Vec<u8>>` pointer
    // stored by `install_alpn_wire`.
    let raw = unsafe { aws_lc_sys::SSL_CTX_get_ex_data(ctx, alpn_ex_index()) };
    if raw.is_null() {
        return aws_lc_sys::SSL_TLSEXT_ERR_ALERT_FATAL;
    }
    // SAFETY: the Vec is owned by the context and released only by `alpn_ex_free` when
    // the context is destroyed. That cannot happen while one of its connections is
    // mid-handshake.
    let server_wire: &Vec<u8> = unsafe { &*raw.cast::<Vec<u8>>() };
    // SAFETY: `in_` was checked non-null above; AWS-LC provides `in_len` readable bytes.
    let client_wire: &[u8] = unsafe { std::slice::from_raw_parts(in_, in_len as usize) };
    for server_proto in iter_alpn_wire(server_wire) {
        for client_proto in iter_alpn_wire(client_wire) {
            if server_proto == client_proto {
                // SAFETY: `out`/`out_len` are writable per the callback contract. The
                // selected pointer refers into the context-owned server list rather than
                // the client's transient buffer.
                unsafe {
                    *out = server_proto.as_ptr();
                    // ALPN wire entries carry a one-byte length prefix (RFC 7301), so the
                    // length always fits in a u8.
                    #[allow(clippy::cast_possible_truncation)]
                    {
                        *out_len = server_proto.len() as u8;
                    }
                }
                return aws_lc_sys::SSL_TLSEXT_ERR_OK;
            }
        }
    }
    aws_lc_sys::SSL_TLSEXT_ERR_ALERT_FATAL
}
#[cfg(test)]
mod tests {
    use super::*;

    // Matching certificate/key fixtures in both PEM and DER encodings.
    const TEST_CERT_PEM: &[u8] = include_bytes!("../../tests/data/cert.pem");
    const TEST_KEY_PEM: &[u8] = include_bytes!("../../tests/data/key.pem");
    const TEST_CERT_DER: &[u8] = include_bytes!("../../tests/data/cert.der");
    const TEST_KEY_DER: &[u8] = include_bytes!("../../tests/data/key.der");

    /// Fails the test unless `err` is an [`Error::Init`].
    fn assert_init_error(err: &Error) {
        assert!(matches!(err, Error::Init(_)), "got: {err:?}");
    }

    #[test]
    fn builds_from_valid_pem_bytes() {
        let config = ServerConfig::builder()
            .with_pem_bytes(TEST_CERT_PEM, TEST_KEY_PEM)
            .expect("config should build");
        assert!(!config.ctx_ptr().is_null());
    }

    #[test]
    fn builds_from_valid_der_bytes() {
        let chain: [&[u8]; 1] = [TEST_CERT_DER];
        let config = ServerConfig::builder()
            .with_der_bytes(&chain, TEST_KEY_DER)
            .expect("config should build from DER cert + key");
        assert!(!config.ctx_ptr().is_null());
    }

    #[test]
    fn der_empty_chain_rejected() {
        let err = ServerConfig::builder()
            .with_der_bytes(&[], TEST_KEY_DER)
            .expect_err("empty chain should fail");
        assert_init_error(&err);
    }

    #[test]
    fn der_garbage_cert_rejected() {
        let err = ServerConfig::builder()
            .with_der_bytes(&[b"not a real DER cert"], TEST_KEY_DER)
            .expect_err("garbage DER should fail");
        assert_init_error(&err);
    }

    #[test]
    fn der_garbage_key_rejected() {
        let err = ServerConfig::builder()
            .with_der_bytes(&[TEST_CERT_DER], b"not a real DER key")
            .expect_err("garbage DER key should fail");
        assert_init_error(&err);
    }

    #[test]
    fn mismatched_cert_and_key_rejected() {
        let err = ServerConfig::builder()
            .with_pem_bytes(TEST_CERT_PEM, b"not a real key")
            .expect_err("should fail on garbage key");
        assert_init_error(&err);
    }

    #[test]
    fn builder_chains_setters() {
        let config = ServerConfig::builder()
            .alpn_protocols(&[b"h2", b"http/1.1"])
            .min_protocol_version(ProtocolVersion::Tls12)
            .max_protocol_version(ProtocolVersion::Tls13)
            .ktls_aead_only(true)
            .with_pem_bytes(TEST_CERT_PEM, TEST_KEY_PEM)
            .expect("config should build with full setter chain");
        assert!(!config.ctx_ptr().is_null());
    }

    #[test]
    fn with_pem_files_loads_from_disk() {
        use std::fs;
        use std::path::PathBuf;

        /// Removes the wrapped file when dropped, whether the test passes or fails.
        struct RemoveOnDrop(PathBuf);
        impl Drop for RemoveOnDrop {
            fn drop(&mut self) {
                let _ = fs::remove_file(&self.0);
            }
        }

        let pid = std::process::id();
        let tmp = std::env::temp_dir();
        let cert = RemoveOnDrop(tmp.join(format!("tokio-aws-lc-{pid}-cert.pem")));
        let key = RemoveOnDrop(tmp.join(format!("tokio-aws-lc-{pid}-key.pem")));
        fs::write(&cert.0, TEST_CERT_PEM).expect("write cert");
        fs::write(&key.0, TEST_KEY_PEM).expect("write key");

        let config = ServerConfig::builder()
            .with_pem_files(&cert.0, &key.0)
            .expect("config should build from on-disk PEM files");
        assert!(!config.ctx_ptr().is_null());
    }

    #[test]
    fn with_pem_files_missing_path_errors() {
        use std::path::PathBuf;
        let missing = PathBuf::from("/definitely/does/not/exist/cert.pem");
        let msg = ServerConfig::builder()
            .with_pem_files(&missing, &missing)
            .expect_err("missing files should fail")
            .to_string();
        assert!(
            msg.contains("loading certificate chain"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn path_to_cstring_rejects_embedded_nul() {
        use std::path::PathBuf;
        #[cfg(unix)]
        {
            use std::os::unix::ffi::OsStringExt;
            let with_nul =
                PathBuf::from(std::ffi::OsString::from_vec(b"foo\0bar.pem".to_vec()));
            let msg = path_to_cstring(&with_nul)
                .expect_err("embedded NUL must be rejected")
                .to_string();
            assert!(
                msg.contains("embedded NUL byte"),
                "unexpected error message: {msg}"
            );
        }
        let _ = path_to_cstring(&PathBuf::from("normal.pem")).expect("plain path");
    }

    #[test]
    fn pem_chain_with_trailing_garbage_is_rejected() {
        let bogus_block: &[u8] =
            b"\n-----BEGIN CERTIFICATE-----\nnot-base64-at-all\n-----END CERTIFICATE-----\n";
        let polluted: Vec<u8> = [TEST_CERT_PEM, bogus_block].concat();
        let err = ServerConfig::builder()
            .with_pem_bytes(&polluted, TEST_KEY_PEM)
            .expect_err("trailing garbage in cert PEM must be rejected");
        assert_init_error(&err);
    }
}