use crate::{
callbacks::*,
enums::*,
error::{Error, Fallible},
security,
};
use core::{convert::TryInto, ptr::NonNull};
use s2n_tls_sys::*;
use std::{
ffi::{c_void, CString},
path::Path,
sync::atomic::{AtomicUsize, Ordering},
};
/// Handle to a reference-counted `s2n_config`.
///
/// Cloning a `Config` yields another handle to the same underlying config and
/// shared callback context; the last handle to be dropped frees both.
#[derive(PartialEq, Debug)]
pub struct Config(NonNull<s2n_config>);

// SAFETY: NOTE(review) — these impls assume s2n-tls allows a config to be used from
// multiple threads, and that the atomic refcount is the only shared state mutated
// through `&Config`. Confirm against the s2n-tls threading documentation.
unsafe impl Send for Config {}
unsafe impl Sync for Config {}
impl Config {
pub fn new() -> Self {
Self::default()
}
pub fn builder() -> Builder {
Builder::default()
}
pub(crate) unsafe fn from_raw(config: NonNull<s2n_config>) -> Self {
let config = Self(config);
config.context();
config
}
pub(crate) fn as_mut_ptr(&mut self) -> *mut s2n_config {
self.0.as_ptr()
}
pub(crate) fn context(&self) -> &Context {
let mut ctx = core::ptr::null_mut();
unsafe {
s2n_config_get_ctx(self.0.as_ptr(), &mut ctx)
.into_result()
.unwrap();
&*(ctx as *const Context)
}
}
pub(crate) fn context_mut(&mut self) -> &mut Context {
let mut ctx = core::ptr::null_mut();
unsafe {
s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx)
.into_result()
.unwrap();
&mut *(ctx as *mut Context)
}
}
#[cfg(test)]
pub fn test_get_refcount(&self) -> Result<usize, Error> {
let context = self.context();
Ok(context.refcount.load(Ordering::SeqCst))
}
}
impl Default for Config {
    /// Builds a configuration without applying any builder options.
    ///
    /// This goes through `Builder::new` directly. `Config::builder` delegates to
    /// `Builder::default`, which itself calls this function.
    fn default() -> Self {
        let builder = Builder::new();
        builder.build().unwrap()
    }
}
impl Clone for Config {
fn clone(&self) -> Self {
let context = self.context();
let _count = context.refcount.fetch_add(1, Ordering::Relaxed);
Self(self.0)
}
}
impl Drop for Config {
fn drop(&mut self) {
let context = self.context_mut();
let count = context.refcount.fetch_sub(1, Ordering::Release);
debug_assert!(count > 0, "refcount should not drop below 1 instance");
if count != 1 {
return;
}
std::sync::atomic::fence(Ordering::Acquire);
unsafe {
let context = Box::from_raw(context);
drop(context);
let _ = s2n_config_free(self.0.as_ptr()).into_result();
}
}
}
#[derive(Default)]
pub struct Builder(Config);
impl Builder {
    /// Allocates a fresh `s2n_config` and attaches an empty callback `Context` to it.
    ///
    /// Client hello callbacks are put in `NONBLOCKING` mode
    /// (see [`Builder::set_client_hello_callback`]).
    ///
    /// # Panics
    ///
    /// Panics if s2n-tls fails to allocate the config, attach the context, or set
    /// the client hello callback mode.
    pub fn new() -> Self {
        crate::init::init();
        let config = unsafe { s2n_config_new().into_result() }.unwrap();
        // Ownership of the boxed context passes to the config. `Config::drop`
        // reclaims it with `Box::from_raw` when the last handle goes away.
        let context = Box::<Context>::default();
        let context = Box::into_raw(context) as *mut c_void;
        unsafe {
            s2n_config_set_ctx(config.as_ptr(), context)
                .into_result()
                .unwrap();
            s2n_config_set_client_hello_cb_mode(
                config.as_ptr(),
                s2n_client_hello_cb_mode::NONBLOCKING,
            )
            .into_result()
            .unwrap();
        }
        Self(Config(config))
    }

    /// Sets the alert behavior via `s2n_config_set_alert_behavior`.
    ///
    /// # Errors
    ///
    /// Returns the s2n-tls error if the call fails.
    pub fn set_alert_behavior(&mut self, value: AlertBehavior) -> Result<&mut Self, Error> {
        unsafe { s2n_config_set_alert_behavior(self.as_mut_ptr(), value.into()).into_result() }?;
        Ok(self)
    }

    /// Applies a named security policy via `s2n_config_set_cipher_preferences`.
    ///
    /// # Errors
    ///
    /// Returns the s2n-tls error if the policy is rejected.
    pub fn set_security_policy(&mut self, policy: &security::Policy) -> Result<&mut Self, Error> {
        unsafe {
            s2n_config_set_cipher_preferences(self.as_mut_ptr(), policy.as_cstr().as_ptr())
                .into_result()
        }?;
        Ok(self)
    }

    /// Replaces the application protocol (ALPN) preference list.
    ///
    /// Any existing list is cleared, then each entry of `protocols` is appended in
    /// iteration order.
    ///
    /// # Errors
    ///
    /// Fails if clearing fails or if any protocol is rejected by
    /// [`Builder::append_application_protocol_preference`].
    pub fn set_application_protocol_preference<P: IntoIterator<Item = I>, I: AsRef<[u8]>>(
        &mut self,
        protocols: P,
    ) -> Result<&mut Self, Error> {
        unsafe {
            s2n_config_set_protocol_preferences(self.as_mut_ptr(), core::ptr::null(), 0)
                .into_result()
        }?;
        for protocol in protocols {
            self.append_application_protocol_preference(protocol.as_ref())?;
        }
        Ok(self)
    }

    /// Appends one protocol to the application protocol preference list.
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` if `protocol`'s length does not fit the length
    /// type expected by s2n-tls. Otherwise returns the s2n-tls error if it
    /// rejects the entry.
    pub fn append_application_protocol_preference(
        &mut self,
        protocol: &[u8],
    ) -> Result<&mut Self, Error> {
        unsafe {
            s2n_config_append_protocol_preference(
                self.as_mut_ptr(),
                protocol.as_ptr(),
                protocol
                    .len()
                    .try_into()
                    .map_err(|_| Error::INVALID_INPUT)?,
            )
            .into_result()
        }?;
        Ok(self)
    }

    /// Turns off X.509 certificate verification for connections using this config.
    ///
    /// # Safety
    ///
    /// Without certificate verification the peer is not authenticated. The caller
    /// must ensure that is acceptable for every connection using this config.
    pub unsafe fn disable_x509_verification(&mut self) -> Result<&mut Self, Error> {
        s2n_config_disable_x509_verification(self.as_mut_ptr()).into_result()?;
        Ok(self)
    }

    /// Loads PEM-encoded Diffie-Hellman parameters.
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` if `pem` contains an interior NUL byte, or the
    /// s2n-tls error if the parameters are rejected.
    pub fn add_dhparams(&mut self, pem: &[u8]) -> Result<&mut Self, Error> {
        let cstring = CString::new(pem).map_err(|_| Error::INVALID_INPUT)?;
        unsafe { s2n_config_add_dhparams(self.as_mut_ptr(), cstring.as_ptr()).into_result() }?;
        Ok(self)
    }

    /// Adds a PEM-encoded certificate chain and its private key.
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` if either input contains an interior NUL byte,
    /// or the s2n-tls error if the pair is rejected.
    pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) -> Result<&mut Self, Error> {
        let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?;
        let private_key = CString::new(private_key).map_err(|_| Error::INVALID_INPUT)?;
        unsafe {
            s2n_config_add_cert_chain_and_key(
                self.as_mut_ptr(),
                certificate.as_ptr(),
                private_key.as_ptr(),
            )
            .into_result()
        }?;
        Ok(self)
    }

    /// Adds a PEM-encoded certificate to the trust store.
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` on an interior NUL byte, or the s2n-tls error
    /// if the certificate is rejected.
    pub fn trust_pem(&mut self, certificate: &[u8]) -> Result<&mut Self, Error> {
        let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?;
        unsafe {
            s2n_config_add_pem_to_trust_store(self.as_mut_ptr(), certificate.as_ptr()).into_result()
        }?;
        Ok(self)
    }

    /// Points certificate verification at a CA file and/or CA directory.
    ///
    /// A `None` argument is passed to s2n-tls as a null pointer.
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` if a path is not valid UTF-8 or contains a NUL
    /// byte, or the s2n-tls error if the location is rejected.
    pub fn trust_location(
        &mut self,
        file: Option<&Path>,
        dir: Option<&Path>,
    ) -> Result<&mut Self, Error> {
        // Converts an optional path into an owned C string.
        fn to_cstr(input: Option<&Path>) -> Result<Option<CString>, Error> {
            Ok(match input {
                Some(input) => {
                    let string = input.to_str().ok_or(Error::INVALID_INPUT)?;
                    let cstring = CString::new(string).map_err(|_| Error::INVALID_INPUT)?;
                    Some(cstring)
                }
                None => None,
            })
        }
        // The CStrings must stay alive until after the FFI call below.
        let file_cstr = to_cstr(file)?;
        let file_ptr = file_cstr
            .as_ref()
            .map(|f| f.as_ptr())
            .unwrap_or(core::ptr::null());
        let dir_cstr = to_cstr(dir)?;
        let dir_ptr = dir_cstr
            .as_ref()
            .map(|f| f.as_ptr())
            .unwrap_or(core::ptr::null());
        unsafe {
            s2n_config_set_verification_ca_location(self.as_mut_ptr(), file_ptr, dir_ptr)
                .into_result()
        }?;
        Ok(self)
    }

    /// Clears the trust store via `s2n_config_wipe_trust_store`.
    pub fn wipe_trust_store(&mut self) -> Result<&mut Self, Error> {
        unsafe { s2n_config_wipe_trust_store(self.as_mut_ptr()).into_result()? };
        Ok(self)
    }

    /// Sets the client authentication type via `s2n_config_set_client_auth_type`.
    pub fn set_client_auth_type(&mut self, auth_type: ClientAuthType) -> Result<&mut Self, Error> {
        unsafe {
            s2n_config_set_client_auth_type(self.as_mut_ptr(), auth_type.into()).into_result()
        }?;
        Ok(self)
    }

    /// Sets the status request type to OCSP.
    pub fn enable_ocsp(&mut self) -> Result<&mut Self, Error> {
        unsafe {
            s2n_config_set_status_request_type(self.as_mut_ptr(), s2n_status_request_type::OCSP)
                .into_result()
        }?;
        Ok(self)
    }

    /// Stores `data` as the OCSP stapling extension data, then calls
    /// [`Builder::enable_ocsp`].
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` if `data` is longer than `u32::MAX` bytes, or
    /// the s2n-tls error from either call.
    pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
        let size: u32 = data.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
        unsafe {
            s2n_config_set_extension_data(
                self.as_mut_ptr(),
                s2n_tls_extension_type::OCSP_STAPLING,
                data.as_ptr(),
                size,
            )
            .into_result()
        }?;
        self.enable_ocsp()
    }

    /// Installs `handler` to decide whether a peer host name is accepted.
    ///
    /// The handler is stored in the shared context, and a C trampoline is registered
    /// with the context as its argument. Non-UTF-8 host names are rejected
    /// without calling the handler.
    pub fn set_verify_host_callback<T: 'static + VerifyHostNameCallback>(
        &mut self,
        handler: T,
    ) -> Result<&mut Self, Error> {
        unsafe extern "C" fn verify_host_cb(
            host_name: *const ::libc::c_char,
            host_name_len: usize,
            context: *mut ::libc::c_void,
        ) -> u8 {
            // `slice::from_raw_parts` requires a non-null pointer even for an empty
            // slice, so the empty case never touches the pointer. A null pointer
            // with a nonzero length is rejected.
            let host_name: &[u8] = if host_name_len == 0 {
                &[]
            } else if host_name.is_null() {
                return 0;
            } else {
                core::slice::from_raw_parts(host_name as *const u8, host_name_len)
            };
            let host_name_str = match core::str::from_utf8(host_name) {
                Ok(name) => name,
                Err(_) => return 0,
            };
            let context = &mut *(context as *mut Context);
            // Reject instead of panicking across the FFI boundary if no handler is
            // installed.
            match context.verify_host_callback.as_mut() {
                Some(handler) => handler.verify_host_name(host_name_str) as u8,
                None => 0,
            }
        }
        let handler = Box::new(handler);
        let context = self.0.context_mut();
        context.verify_host_callback = Some(handler);
        unsafe {
            s2n_config_set_verify_host_callback(
                self.as_mut_ptr(),
                Some(verify_host_cb),
                self.0.context_mut() as *mut _ as *mut c_void,
            )
            .into_result()?;
        }
        Ok(self)
    }

    /// Registers a raw key-log callback via `s2n_config_set_key_log_cb`.
    ///
    /// # Safety
    ///
    /// `context` is handed to `callback` untouched. It must remain valid, and be
    /// safe for `callback` to use, for as long as the config may invoke it.
    pub unsafe fn set_key_log_callback(
        &mut self,
        callback: s2n_key_log_fn,
        context: *mut core::ffi::c_void,
    ) -> Result<&mut Self, Error> {
        s2n_config_set_key_log_cb(self.as_mut_ptr(), callback, context).into_result()?;
        Ok(self)
    }

    /// Sets the maximum certificate chain depth via `s2n_config_set_max_cert_chain_depth`.
    pub fn set_max_cert_chain_depth(&mut self, depth: u16) -> Result<&mut Self, Error> {
        unsafe { s2n_config_set_max_cert_chain_depth(self.as_mut_ptr(), depth).into_result() }?;
        Ok(self)
    }

    /// Sets the send buffer size via `s2n_config_set_send_buffer_size`.
    pub fn set_send_buffer_size(&mut self, size: u32) -> Result<&mut Self, Error> {
        unsafe { s2n_config_set_send_buffer_size(self.as_mut_ptr(), size).into_result() }?;
        Ok(self)
    }

    /// Installs `handler` as the client hello callback.
    ///
    /// The handler is stored in the shared context. The registered C trampoline
    /// ignores its context argument and dispatches through the connection via
    /// `trigger_async_client_hello_callback`.
    pub fn set_client_hello_callback<T: 'static + ClientHelloCallback>(
        &mut self,
        handler: T,
    ) -> Result<&mut Self, Error> {
        unsafe extern "C" fn client_hello_cb(
            connection_ptr: *mut s2n_connection,
            _context: *mut core::ffi::c_void,
        ) -> libc::c_int {
            with_connection(connection_ptr, |conn| {
                trigger_async_client_hello_callback(conn).into()
            })
        }
        let handler = Box::new(handler);
        let context = self.0.context_mut();
        context.client_hello_callback = Some(handler);
        unsafe {
            s2n_config_set_client_hello_cb(
                self.as_mut_ptr(),
                Some(client_hello_cb),
                core::ptr::null_mut(),
            )
            .into_result()?;
        }
        Ok(self)
    }

    /// Installs `handler` as the wall clock source.
    ///
    /// The trampoline reports failure if no handler is set, or if the time since
    /// the epoch does not fit in a `u64` count of nanoseconds.
    pub fn set_wall_clock<T: 'static + WallClock>(
        &mut self,
        handler: T,
    ) -> Result<&mut Self, Error> {
        unsafe extern "C" fn clock_cb(
            context: *mut ::libc::c_void,
            time_in_nanos: *mut u64,
        ) -> libc::c_int {
            let context = &mut *(context as *mut Context);
            if let Some(handler) = context.wall_clock.as_mut() {
                if let Ok(nanos) = handler.get_time_since_epoch().as_nanos().try_into() {
                    *time_in_nanos = nanos;
                    return CallbackResult::Success.into();
                }
            }
            CallbackResult::Failure.into()
        }
        let handler = Box::new(handler);
        let context = self.0.context_mut();
        context.wall_clock = Some(handler);
        unsafe {
            s2n_config_set_wall_clock(
                self.as_mut_ptr(),
                Some(clock_cb),
                self.0.context_mut() as *mut _ as *mut c_void,
            )
            .into_result()?;
        }
        Ok(self)
    }

    /// Installs `handler` as the monotonic clock source.
    ///
    /// The trampoline reports failure if no handler is set, or if the time does
    /// not fit in a `u64` count of nanoseconds.
    pub fn set_monotonic_clock<T: 'static + MonotonicClock>(
        &mut self,
        handler: T,
    ) -> Result<&mut Self, Error> {
        unsafe extern "C" fn clock_cb(
            context: *mut ::libc::c_void,
            time_in_nanos: *mut u64,
        ) -> libc::c_int {
            let context = &mut *(context as *mut Context);
            if let Some(handler) = context.monotonic_clock.as_mut() {
                if let Ok(nanos) = handler.get_time().as_nanos().try_into() {
                    *time_in_nanos = nanos;
                    return CallbackResult::Success.into();
                }
            }
            CallbackResult::Failure.into()
        }
        let handler = Box::new(handler);
        let context = self.0.context_mut();
        context.monotonic_clock = Some(handler);
        unsafe {
            s2n_config_set_monotonic_clock(
                self.as_mut_ptr(),
                Some(clock_cb),
                self.0.context_mut() as *mut _ as *mut c_void,
            )
            .into_result()?;
        }
        Ok(self)
    }

    /// Finishes building and returns the configured [`Config`]. Currently never fails.
    pub fn build(self) -> Result<Config, Error> {
        Ok(self.0)
    }

    /// Raw pointer to the `s2n_config` being built.
    fn as_mut_ptr(&mut self) -> *mut s2n_config {
        self.0.as_mut_ptr()
    }
}
#[cfg(feature = "quic")]
impl Builder {
    /// Calls `s2n_config_enable_quic` on the config being built.
    ///
    /// # Errors
    ///
    /// Returns the s2n-tls error if the call fails.
    pub fn enable_quic(&mut self) -> Result<&mut Self, Error> {
        let config = self.as_mut_ptr();
        unsafe {
            s2n_tls_sys::s2n_config_enable_quic(config).into_result()?;
        }
        Ok(self)
    }
}
/// Per-config state shared by every clone of a `Config`.
///
/// It is heap-allocated in `Builder::new`, stored in the `s2n_config` ctx pointer,
/// and freed by the last `Config` handle to be dropped.
///
/// Fields are dropped in declaration order, so reordering them changes the order
/// in which installed handlers are dropped.
pub(crate) struct Context {
/// Number of live `Config` handles sharing this context. It starts at 1 and is
/// decremented on drop; the handle that brings it to zero frees the context.
refcount: AtomicUsize,
/// Handler set by `Builder::set_client_hello_callback`. Presumably read by
/// `trigger_async_client_hello_callback` (defined outside this view; confirm).
pub(crate) client_hello_callback: Option<Box<dyn ClientHelloCallback>>,
/// Handler set by `Builder::set_verify_host_callback` and invoked from its C trampoline.
pub(crate) verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
/// Clock set by `Builder::set_wall_clock`.
pub(crate) wall_clock: Option<Box<dyn WallClock>>,
/// Clock set by `Builder::set_monotonic_clock`.
pub(crate) monotonic_clock: Option<Box<dyn MonotonicClock>>,
}
impl Default for Context {
fn default() -> Self {
let refcount = AtomicUsize::new(1);
Self {
refcount,
client_hello_callback: None,
verify_host_callback: None,
wall_clock: None,
monotonic_clock: None,
}
}
}