use std::ffi::c_void;
use std::fmt::{Debug, Formatter};
use std::slice;
use std::sync::Arc;
use libc::size_t;
use rustls::crypto::CryptoProvider;
use rustls::server::danger::ClientCertVerifier;
use rustls::server::{
ClientHello, ResolvesServerCert, ServerConfig, ServerConnection, StoresServerSessions,
WebPkiClientVerifier,
};
use rustls::sign::CertifiedKey;
use rustls::{KeyLog, KeyLogFile, ProtocolVersion, SignatureScheme, SupportedProtocolVersion};
use crate::certificate::rustls_certified_key;
use crate::connection::{Connection, rustls_connection};
use crate::crypto_provider::{self, rustls_crypto_provider};
use crate::error::{map_error, rustls_result};
use crate::ffi::{
Castable, OwnershipArc, OwnershipBox, OwnershipRef, free_arc, free_box, set_arc_mut_ptr,
set_boxed_mut_ptr, to_boxed_mut_ptr, try_box_from_ptr, try_clone_arc, try_mut_from_ptr,
try_mut_from_ptr_ptr, try_ref_from_ptr, try_ref_from_ptr_ptr, try_slice,
};
use crate::keylog::{CallbackKeyLog, rustls_keylog_log_callback, rustls_keylog_will_log_callback};
use crate::panic::ffi_panic_boundary;
use crate::rslice::{rustls_slice_bytes, rustls_slice_slice_bytes, rustls_slice_u16, rustls_str};
use crate::session::{
SessionStoreBroker, rustls_session_store_get_callback, rustls_session_store_put_callback,
};
use crate::userdata::userdata_get;
use crate::verifier::rustls_client_cert_verifier;
/// Opaque FFI handle for a server configuration builder.
///
/// The zero-sized private field means the type cannot be constructed outside
/// this module. The `Castable` impl in this file maps pointers of this type to
/// a `ServerConfigBuilder`.
pub struct rustls_server_config_builder {
_private: [u8; 0],
}
/// Declares how `rustls_server_config_builder` pointers are cast: they use
/// the `OwnershipBox` ownership marker and refer to a `ServerConfigBuilder`.
impl Castable for rustls_server_config_builder {
type Ownership = OwnershipBox;
type RustType = ServerConfigBuilder;
}
/// Accumulates the settings that `rustls_server_config_builder_build` turns
/// into a `ServerConfig`.
pub(crate) struct ServerConfigBuilder {
/// Crypto provider to build with; `build` returns `NoDefaultCryptoProvider`
/// when this is `None`.
provider: Option<Arc<CryptoProvider>>,
/// Protocol versions handed to `with_protocol_versions` in `build`.
versions: Vec<&'static SupportedProtocolVersion>,
/// Client certificate verifier; both constructors start with
/// `WebPkiClientVerifier::no_client_auth()`.
verifier: Arc<dyn ClientCertVerifier>,
/// Server certificate resolver; `build` returns `NoCertResolver` when this
/// is `None`.
cert_resolver: Option<Arc<dyn ResolvesServerCert>>,
/// When `Some`, replaces the built config's `session_storage`.
session_storage: Option<Arc<dyn StoresServerSessions + Send + Sync>>,
/// Copied into the built config's `alpn_protocols` unconditionally.
alpn_protocols: Vec<Vec<u8>>,
/// When `Some`, replaces the built config's `ignore_client_order`.
ignore_client_order: Option<bool>,
/// When `Some`, replaces the built config's `key_log`.
key_log: Option<Arc<dyn KeyLog>>,
}
/// Opaque FFI handle for a finished server configuration.
///
/// The zero-sized private field means the type cannot be constructed outside
/// this module. The `Castable` impl in this file maps pointers of this type to
/// a rustls `ServerConfig`.
pub struct rustls_server_config {
_private: [u8; 0],
}
/// Declares how `rustls_server_config` pointers are cast: they use the
/// `OwnershipArc` ownership marker and refer to a rustls `ServerConfig`.
impl Castable for rustls_server_config {
type Ownership = OwnershipArc;
type RustType = ServerConfig;
}
impl rustls_server_config_builder {
/// Creates a builder populated with default settings and returns the raw
/// pointer produced by `to_boxed_mut_ptr`.
///
/// Defaults: the provider reported by
/// `crypto_provider::get_default_or_install_from_crate_features()`,
/// `rustls::DEFAULT_VERSIONS`, no client authentication, no certificate
/// resolver, an empty ALPN list and no optional overrides.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_new() -> *mut rustls_server_config_builder {
    ffi_panic_boundary! {
        to_boxed_mut_ptr(ServerConfigBuilder {
            provider: crypto_provider::get_default_or_install_from_crate_features(),
            versions: rustls::DEFAULT_VERSIONS.to_vec(),
            verifier: WebPkiClientVerifier::no_client_auth(),
            cert_resolver: None,
            session_storage: None,
            alpn_protocols: Vec::new(),
            ignore_client_order: None,
            key_log: None,
        })
    }
}
/// Creates a builder that uses the given crypto provider and an explicit list
/// of TLS protocol versions, writing the new builder through `builder_out`.
///
/// `tls_versions` / `tls_versions_len` describe an array of protocol version
/// numbers. Entries equal to the TLS 1.2 or TLS 1.3 version are kept (in input
/// order); any other value is skipped silently.
///
/// Returns `rustls_result::Ok` once the builder has been stored. The pointer
/// arguments are validated by the `try_*` macros defined elsewhere
/// (presumably returning a null-parameter result early — confirm in
/// `crate::ffi`).
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_new_custom(
    provider: *const rustls_crypto_provider,
    tls_versions: *const u16,
    tls_versions_len: size_t,
    builder_out: *mut *mut rustls_server_config_builder,
) -> rustls_result {
    ffi_panic_boundary! {
        let provider = try_clone_arc!(provider);
        let requested = try_slice!(tls_versions, tls_versions_len);

        // Translate raw version numbers into the rustls statics, dropping
        // anything that is not TLS 1.2 or TLS 1.3.
        let versions: Vec<_> = requested
            .iter()
            .filter_map(|&number| {
                let proto = ProtocolVersion::from(number);
                if proto == rustls::version::TLS12.version {
                    Some(&rustls::version::TLS12)
                } else if proto == rustls::version::TLS13.version {
                    Some(&rustls::version::TLS13)
                } else {
                    None
                }
            })
            .collect();

        let out_slot = try_mut_from_ptr_ptr!(builder_out);
        set_boxed_mut_ptr(
            out_slot,
            ServerConfigBuilder {
                provider: Some(provider),
                versions,
                verifier: WebPkiClientVerifier::no_client_auth(),
                cert_resolver: None,
                session_storage: None,
                alpn_protocols: Vec::new(),
                ignore_client_order: None,
                key_log: None,
            },
        );
        rustls_result::Ok
    }
}
/// Replaces the builder's client certificate verifier with a clone of the
/// one behind `verifier`.
///
/// The function returns nothing, so a rejected pointer is not reported to the
/// caller (the `try_*` macros presumably return early on null — confirm in
/// `crate::ffi`).
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_set_client_verifier(
    builder: *mut rustls_server_config_builder,
    verifier: *const rustls_client_cert_verifier,
) {
    ffi_panic_boundary! {
        let target = try_mut_from_ptr!(builder);
        let replacement = try_ref_from_ptr!(verifier);
        target.verifier = replacement.clone();
    }
}
/// Installs a rustls `KeyLogFile` as the builder's key log.
///
/// The `KeyLogFile` is only constructed after the builder pointer has been
/// accepted. Returns `rustls_result::Ok` once the key log is stored.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_set_key_log_file(
    builder: *mut rustls_server_config_builder,
) -> rustls_result {
    ffi_panic_boundary! {
        let target = try_mut_from_ptr!(builder);
        let file_log: Arc<dyn KeyLog> = Arc::new(KeyLogFile::new());
        target.key_log = Some(file_log);
        rustls_result::Ok
    }
}
/// Installs a callback-driven key log on the builder.
///
/// `log_cb` is mandatory: when it is `None` the function returns
/// `rustls_result::NullParameter` and leaves the builder untouched.
/// `will_log_cb` is stored as given (it may be `None`).
///
/// Returns `rustls_result::Ok` once the `CallbackKeyLog` has been stored.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_set_key_log(
    builder: *mut rustls_server_config_builder,
    log_cb: rustls_keylog_log_callback,
    will_log_cb: rustls_keylog_will_log_callback,
) -> rustls_result {
    ffi_panic_boundary! {
        let target = try_mut_from_ptr!(builder);
        let Some(log_cb) = log_cb else {
            return rustls_result::NullParameter;
        };
        target.key_log = Some(Arc::new(CallbackKeyLog { log_cb, will_log_cb }));
        rustls_result::Ok
    }
}
/// Releases a builder by handing its pointer to `free_box`.
///
/// NOTE(review): `rustls_server_config_builder_build` takes the builder via
/// `try_box_from_ptr!`, so a builder that has been passed to `build`
/// presumably must not be freed again here — confirm against `crate::ffi`.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_free(config: *mut rustls_server_config_builder) {
ffi_panic_boundary! {
free_box(config);
}
}
/// Records an explicit `ignore_client_order` setting on the builder.
///
/// The value is stored as `Some(ignore)`; `build` copies it onto the
/// resulting config only when it has been set. Returns `rustls_result::Ok`
/// once stored.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_set_ignore_client_order(
    builder: *mut rustls_server_config_builder,
    ignore: bool,
) -> rustls_result {
    ffi_panic_boundary! {
        try_mut_from_ptr!(builder).ignore_client_order = Some(ignore);
        rustls_result::Ok
    }
}
/// Replaces the builder's ALPN protocol list with copies of the given byte
/// strings.
///
/// `protocols` / `len` describe an array of `rustls_slice_bytes`; each entry's
/// bytes are copied into an owned buffer. The builder is only updated after
/// every entry has been read, so an entry rejected by `try_slice!` leaves the
/// previous list in place.
///
/// Returns `rustls_result::Ok` once the new list is stored.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_set_alpn_protocols(
    builder: *mut rustls_server_config_builder,
    protocols: *const rustls_slice_bytes,
    len: size_t,
) -> rustls_result {
    ffi_panic_boundary! {
        let target = try_mut_from_ptr!(builder);
        let entries = try_slice!(protocols, len);

        let mut copied: Vec<Vec<u8>> = Vec::with_capacity(entries.len());
        for entry in entries {
            let bytes = try_slice!(entry.data, entry.len);
            copied.push(bytes.to_owned());
        }

        target.alpn_protocols = copied;
        rustls_result::Ok
    }
}
/// Installs a certificate resolver that chooses among the given certified
/// keys.
///
/// `certified_keys` / `certified_keys_len` describe an array of key pointers.
/// Each one is cloned with `try_clone_arc!`, and the clones are wrapped in a
/// `ResolvesServerCertFromChoices`, which replaces any previously configured
/// resolver.
///
/// Returns `rustls_result::Ok` once the resolver is stored.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_set_certified_keys(
    builder: *mut rustls_server_config_builder,
    certified_keys: *const *const rustls_certified_key,
    certified_keys_len: size_t,
) -> rustls_result {
    ffi_panic_boundary! {
        let target = try_mut_from_ptr!(builder);
        let key_ptrs = try_slice!(certified_keys, certified_keys_len);

        let mut choices = Vec::with_capacity(key_ptrs.len());
        for key_ptr in key_ptrs.iter().copied() {
            choices.push(try_clone_arc!(key_ptr));
        }

        let resolver = ResolvesServerCertFromChoices::new(&choices);
        target.cert_resolver = Some(Arc::new(resolver));
        rustls_result::Ok
    }
}
/// Turns the builder into a `ServerConfig` and writes it through
/// `config_out`.
///
/// The builder is taken with `try_box_from_ptr!` before anything else is
/// checked, so it is no longer the caller's to use afterwards, whatever the
/// result (NOTE(review): confirm the macro's exact ownership semantics in
/// `crate::ffi`).
///
/// Returns:
/// - `rustls_result::NoDefaultCryptoProvider` when no provider is set;
/// - the `map_error` translation of the rustls error when the configured
///   protocol versions are rejected;
/// - `rustls_result::NoCertResolver` when no certificate resolver is set;
/// - `rustls_result::Ok` after the config has been stored in `config_out`.
#[no_mangle]
pub extern "C" fn rustls_server_config_builder_build(
    builder: *mut rustls_server_config_builder,
    config_out: *mut *const rustls_server_config,
) -> rustls_result {
    ffi_panic_boundary! {
        let settings = try_box_from_ptr!(builder);
        let out_slot = try_ref_from_ptr_ptr!(config_out);

        let Some(provider) = settings.provider else {
            return rustls_result::NoDefaultCryptoProvider;
        };

        let versioned = ServerConfig::builder_with_provider(provider)
            .with_protocol_versions(&settings.versions);
        let with_verifier = match versioned {
            Ok(stage) => stage.with_client_cert_verifier(settings.verifier),
            Err(err) => return map_error(err),
        };

        // A resolver is mandatory; there is no fallback to substitute here.
        let Some(resolver) = settings.cert_resolver else {
            return rustls_result::NoCertResolver;
        };
        let mut config = with_verifier.with_cert_resolver(resolver);

        // Optional overrides are applied only when they were explicitly set;
        // the ALPN list is always copied over.
        if let Some(storage) = settings.session_storage {
            config.session_storage = storage;
        }
        config.alpn_protocols = settings.alpn_protocols;
        if let Some(ignore) = settings.ignore_client_order {
            config.ignore_client_order = ignore;
        }
        if let Some(key_log) = settings.key_log {
            config.key_log = key_log;
        }

        set_arc_mut_ptr(out_slot, config);
        rustls_result::Ok
    }
}
}
impl rustls_server_config {
/// Reports the result of `ServerConfig::fips()` for the given config.
#[no_mangle]
pub extern "C" fn rustls_server_config_fips(config: *const rustls_server_config) -> bool {
    ffi_panic_boundary! {
        let server_config = try_ref_from_ptr!(config);
        server_config.fips()
    }
}
/// Releases the caller's reference to a server config by handing the pointer
/// to `free_arc`.
///
/// NOTE(review): `rustls_server_connection_new` clones the config with
/// `try_clone_arc!`, so connections created from it presumably keep their own
/// reference after this call — confirm against `crate::ffi`.
#[no_mangle]
pub extern "C" fn rustls_server_config_free(config: *const rustls_server_config) {
ffi_panic_boundary! {
free_arc(config);
}
}
/// Creates a server connection from `config` and writes it through
/// `conn_out`.
///
/// Returns `rustls_result::NullParameter` when `conn_out` is null, the
/// `map_error` translation of the rustls error when `ServerConnection::new`
/// fails, and `rustls_result::Ok` after the connection has been stored.
#[no_mangle]
pub extern "C" fn rustls_server_connection_new(
    config: *const rustls_server_config,
    conn_out: *mut *mut rustls_connection,
) -> rustls_result {
    ffi_panic_boundary! {
        if conn_out.is_null() {
            return rustls_result::NullParameter;
        }
        let config = try_clone_arc!(config);
        let out_slot = try_mut_from_ptr_ptr!(conn_out);

        match ServerConnection::new(config) {
            Ok(server) => {
                // Wrap the rustls connection and hand it to the caller.
                set_boxed_mut_ptr(out_slot, Connection::from_server(server));
                rustls_result::Ok
            }
            Err(err) => map_error(err),
        }
    }
}
}
/// Returns the server name recorded on a server connection.
///
/// Yields `rustls_str::default()` when `conn` is not a server connection,
/// when the connection has no server name, or when the name cannot be
/// converted into a `rustls_str`.
#[no_mangle]
pub extern "C" fn rustls_server_connection_get_server_name(
    conn: *const rustls_connection,
) -> rustls_str<'static> {
    ffi_panic_boundary! {
        let connection = try_ref_from_ptr!(conn);
        match connection.as_server().and_then(|server| server.server_name()) {
            Some(hostname) => {
                let name = rustls_str::try_from(hostname).unwrap_or_default();
                // NOTE(review): the string borrows from the connection and its
                // lifetime is erased here; callers presumably must not use it
                // once the connection is mutated or freed — confirm.
                unsafe { name.into_static() }
            }
            None => rustls_str::default(),
        }
    }
}
/// Certificate resolver backed by a fixed list of certified keys.
///
/// Its `ResolvesServerCert` impl returns the first key in `choices` whose
/// signing key supports one of the client's signature schemes.
#[derive(Debug)]
struct ResolvesServerCertFromChoices {
/// Candidate keys, tried in order.
choices: Vec<Arc<CertifiedKey>>,
}
impl ResolvesServerCertFromChoices {
pub fn new(choices: &[Arc<CertifiedKey>]) -> Self {
ResolvesServerCertFromChoices {
choices: Vec::from(choices),
}
}
}
impl ResolvesServerCert for ResolvesServerCertFromChoices {
    /// Returns the first configured key whose signing key can pick a scheme
    /// from the client's offered signature schemes, or `None` when no key
    /// matches.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        let offered = client_hello.signature_schemes();
        self.choices
            .iter()
            .find(|candidate| candidate.key.choose_scheme(offered).is_some())
            .cloned()
    }
}
/// C-layout view of a TLS ClientHello that is passed to a
/// `rustls_client_hello_callback`.
///
/// `ClientHelloResolver::resolve` builds this from locals that live only for
/// the duration of the callback invocation, so the callback must not keep
/// pointers into it after returning.
#[repr(C)]
pub struct rustls_client_hello<'a> {
/// Server name from the ClientHello; the default (empty) string when the
/// client sent none.
server_name: rustls_str<'a>,
/// The client's signature schemes, each converted to its `u16` value.
signature_schemes: rustls_slice_u16<'a>,
/// The client's named groups, each converted to its `u16` value; empty when
/// the ClientHello provides none.
named_groups: rustls_slice_u16<'a>,
/// The client's ALPN protocols; an empty collection when none were offered.
alpn: *const rustls_slice_slice_bytes<'a>,
}
/// Declares how `rustls_client_hello` pointers are cast: they use the
/// `OwnershipRef` ownership marker and the Rust type is the struct itself.
impl<'a> Castable for rustls_client_hello<'a> {
type Ownership = OwnershipRef;
type RustType = rustls_client_hello<'a>;
}
/// Opaque user data pointer passed as the first argument of a client hello
/// callback; `ClientHelloResolver::resolve` obtains it from `userdata_get()`.
pub type rustls_client_hello_userdata = *mut c_void;
/// Nullable client hello callback as it appears in the C API.
///
/// The callback receives the user data pointer and a `rustls_client_hello`
/// and returns the certified key to use. `None` is rejected with
/// `NullParameter` by `rustls_server_config_builder_set_hello_callback`.
pub type rustls_client_hello_callback = Option<
unsafe extern "C" fn(
userdata: rustls_client_hello_userdata,
hello: *const rustls_client_hello,
) -> *const rustls_certified_key,
>;
/// Non-optional form of `rustls_client_hello_callback`, i.e. the function
/// pointer type once the `Option` has been unwrapped.
type ClientHelloCallback = unsafe extern "C" fn(
userdata: rustls_client_hello_userdata,
hello: *const rustls_client_hello,
) -> *const rustls_certified_key;
/// Certificate resolver that delegates the choice of certified key to a
/// C callback.
struct ClientHelloResolver {
/// Callback invoked from `resolve` with the user data and a
/// `rustls_client_hello` view of the ClientHello.
pub callback: ClientHelloCallback,
}
impl ClientHelloResolver {
pub fn new(callback: ClientHelloCallback) -> ClientHelloResolver {
ClientHelloResolver { callback }
}
}
impl ResolvesServerCert for ClientHelloResolver {
fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
let server_name = client_hello.server_name().unwrap_or_default();
let server_name = match server_name.try_into() {
Ok(r) => r,
Err(_) => return None,
};
let mapped_sigs: Vec<u16> = client_hello
.signature_schemes()
.iter()
.map(|s| u16::from(*s))
.collect();
let mapped_groups = match client_hello.named_groups() {
Some(groups) => groups.iter().map(|g| u16::from(*g)).collect(),
None => Vec::new(),
};
let alpn = match client_hello.alpn() {
Some(iter) => iter.collect(),
None => vec![],
};
let alpn = rustls_slice_slice_bytes { inner: &alpn };
let signature_schemes = (&*mapped_sigs).into();
let named_groups = (&*mapped_groups).into();
let hello = rustls_client_hello {
server_name,
signature_schemes,
named_groups,
alpn: &alpn,
};
let cb = self.callback;
let userdata = match userdata_get() {
Ok(u) => u,
Err(_) => return None,
};
let key_ptr = unsafe { cb(userdata, &hello) };
let certified_key = try_ref_from_ptr!(key_ptr);
Some(Arc::new(certified_key.clone()))
}
}
// NOTE(review): the only field is a function pointer, which is already
// `Send + Sync`, so these impls look redundant for the data itself. What they
// effectively assert is that the C callback may be invoked from any thread,
// possibly concurrently — confirm this is part of the documented callback
// contract.
unsafe impl Sync for ClientHelloResolver {}
unsafe impl Send for ClientHelloResolver {}
impl Debug for ClientHelloResolver {
    /// Prints only the type name; the callback pointer is not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("ClientHelloResolver")
    }
}
impl rustls_server_config_builder {
    /// Installs a certificate resolver that defers to a client hello
    /// callback, replacing any previously configured resolver.
    ///
    /// Returns `rustls_result::NullParameter` when `callback` is `None` (this
    /// is checked before the builder pointer), and `rustls_result::Ok` once
    /// the resolver is stored.
    #[no_mangle]
    pub extern "C" fn rustls_server_config_builder_set_hello_callback(
        builder: *mut rustls_server_config_builder,
        callback: rustls_client_hello_callback,
    ) -> rustls_result {
        ffi_panic_boundary! {
            let Some(callback) = callback else {
                return rustls_result::NullParameter;
            };
            let target = try_mut_from_ptr!(builder);
            let resolver = ClientHelloResolver::new(callback);
            target.cert_resolver = Some(Arc::new(resolver));
            rustls_result::Ok
        }
    }
}
fn sigschemes(input: &[u16]) -> Vec<SignatureScheme> {
input.iter().copied().map(Into::into).collect()
}
/// Picks the first certified key that is compatible with the signature
/// schemes listed in `hello` and stores its pointer in `out_key`.
///
/// The stored pointer is the very pointer from `certified_keys`; no new
/// reference is created by this function.
///
/// Returns `rustls_result::NullParameter` when `out_key` is null,
/// `rustls_result::Ok` when a key was selected, and
/// `rustls_result::NotFound` when no key matches (including an empty list).
#[no_mangle]
pub extern "C" fn rustls_client_hello_select_certified_key(
    hello: *const rustls_client_hello,
    certified_keys: *const *const rustls_certified_key,
    certified_keys_len: size_t,
    out_key: *mut *const rustls_certified_key,
) -> rustls_result {
    ffi_panic_boundary! {
        let hello = try_ref_from_ptr!(hello);
        let offered = sigschemes(try_slice!(
            hello.signature_schemes.data,
            hello.signature_schemes.len
        ));
        if out_key.is_null() {
            return rustls_result::NullParameter;
        }

        let candidates = try_slice!(certified_keys, certified_keys_len);
        for &candidate_ptr in candidates {
            let candidate = try_ref_from_ptr!(candidate_ptr);
            if candidate.key.choose_scheme(&offered).is_none() {
                continue;
            }
            // SAFETY: `out_key` was checked to be non-null above; the caller
            // is responsible for it being valid and aligned for a write.
            unsafe { out_key.write(candidate_ptr) };
            return rustls_result::Ok;
        }
        rustls_result::NotFound
    }
}
impl rustls_server_config_builder {
    /// Installs session storage backed by the given get/put callbacks.
    ///
    /// When either callback is `None` the function returns without changing
    /// the builder. Because there is no return value, that case is not
    /// reported to the caller.
    #[no_mangle]
    pub extern "C" fn rustls_server_config_builder_set_persistence(
        builder: *mut rustls_server_config_builder,
        get_cb: rustls_session_store_get_callback,
        put_cb: rustls_session_store_put_callback,
    ) {
        ffi_panic_boundary! {
            let (Some(get_cb), Some(put_cb)) = (get_cb, put_cb) else {
                return;
            };
            let store: Arc<dyn StoresServerSessions + Send + Sync> =
                Arc::new(SessionStoreBroker::new(get_cb, put_cb));
            try_mut_from_ptr!(builder).session_storage = Some(store);
        }
    }
}
#[cfg(all(test, any(feature = "ring", feature = "aws-lc-rs")))]
mod tests {
use std::ptr::{null, null_mut};
use super::*;
/// Builds a config with two ALPN protocols and one certified key, then checks
/// that the ALPN list made it onto the finished `ServerConfig`.
#[test]
#[cfg_attr(miri, ignore)]
fn test_config_builder() {
    let builder = rustls_server_config_builder::rustls_server_config_builder_new();

    let h1 = "http/1.1".as_bytes();
    let h2 = "h2".as_bytes();
    let alpn = [h1.into(), h2.into()];
    rustls_server_config_builder::rustls_server_config_builder_set_alpn_protocols(
        builder,
        alpn.as_ptr(),
        alpn.len(),
    );

    let cert_pem = include_str!("../testdata/localhost/cert.pem").as_bytes();
    let key_pem = include_str!("../testdata/localhost/key.pem").as_bytes();
    let mut certified_key = null();
    let result = rustls_certified_key::rustls_certified_key_build(
        cert_pem.as_ptr(),
        cert_pem.len(),
        key_pem.as_ptr(),
        key_pem.len(),
        &mut certified_key,
    );
    assert!(
        matches!(result, rustls_result::Ok),
        "expected RUSTLS_RESULT_OK from rustls_certified_key_build, got {result:?}"
    );
    rustls_server_config_builder::rustls_server_config_builder_set_certified_keys(
        builder,
        &certified_key,
        1,
    );

    let mut config = null();
    let result =
        rustls_server_config_builder::rustls_server_config_builder_build(builder, &mut config);
    assert_eq!(result, rustls_result::Ok);
    assert!(!config.is_null());
    {
        // Scoped so the borrow of the config ends before it is freed.
        let built = try_ref_from_ptr!(config);
        assert_eq!(built.alpn_protocols, vec![h1, h2]);
    }
    rustls_server_config::rustls_server_config_free(config);
}
/// A builder without a certificate resolver must fail to build with
/// `NoCertResolver` and leave the output pointer null.
#[test]
fn test_server_config_builder_new_empty() {
    let empty_builder = rustls_server_config_builder::rustls_server_config_builder_new();
    let mut built = null();
    let outcome = rustls_server_config_builder::rustls_server_config_builder_build(
        empty_builder,
        &mut built,
    );
    assert_eq!(outcome, rustls_result::NoCertResolver);
    assert!(built.is_null());
}
/// Creates a server connection from a freshly built config and checks the
/// state reported by the connection accessors before any handshake data has
/// been exchanged.
#[test]
#[cfg_attr(miri, ignore)]
fn test_server_connection_new() {
    let builder = rustls_server_config_builder::rustls_server_config_builder_new();
    let cert_pem = include_str!("../testdata/localhost/cert.pem").as_bytes();
    let key_pem = include_str!("../testdata/localhost/key.pem").as_bytes();
    let mut certified_key = null();
    let result = rustls_certified_key::rustls_certified_key_build(
        cert_pem.as_ptr(),
        cert_pem.len(),
        key_pem.as_ptr(),
        key_pem.len(),
        &mut certified_key,
    );
    if !matches!(result, rustls_result::Ok) {
        panic!("expected RUSTLS_RESULT_OK from rustls_certified_key_build, got {result:?}");
    }
    rustls_server_config_builder::rustls_server_config_builder_set_certified_keys(
        builder,
        &certified_key,
        1,
    );

    let mut config = null();
    let result =
        rustls_server_config_builder::rustls_server_config_builder_build(builder, &mut config);
    assert_eq!(result, rustls_result::Ok);
    assert!(!config.is_null());

    let mut conn = null_mut();
    let result = rustls_server_config::rustls_server_connection_new(config, &mut conn);
    if !matches!(result, rustls_result::Ok) {
        panic!("expected RUSTLS_RESULT_OK, got {result:?}");
    }
    assert!(rustls_connection::rustls_connection_wants_read(conn));
    assert!(!rustls_connection::rustls_connection_wants_write(conn));
    assert!(rustls_connection::rustls_connection_is_handshaking(conn));

    // Seed the out-parameters with non-empty values so the assertions below
    // prove they were overwritten.
    let some_byte = 42u8;
    let mut alpn_protocol: *const u8 = &some_byte;
    let mut alpn_protocol_len = 1;
    rustls_connection::rustls_connection_get_alpn_protocol(
        conn,
        &mut alpn_protocol,
        &mut alpn_protocol_len,
    );
    assert_eq!(alpn_protocol, null());
    assert_eq!(alpn_protocol_len, 0);

    assert_eq!(
        rustls_connection::rustls_connection_get_negotiated_ciphersuite(conn),
        0
    );
    let cs_name = rustls_connection::rustls_connection_get_negotiated_ciphersuite_name(conn);
    assert_eq!(unsafe { cs_name.to_str() }, "");
    assert_eq!(
        rustls_connection::rustls_connection_get_peer_certificate(conn, 0),
        null()
    );
    assert_eq!(
        rustls_connection::rustls_connection_get_protocol_version(conn),
        0
    );

    rustls_connection::rustls_connection_free(conn);
    // Release the config reference obtained from `build`, as
    // `test_config_builder` does; it was previously leaked here.
    rustls_server_config::rustls_server_config_free(config);
}
}