use crate::error::{Error, ErrorType, Fallible};
use s2n_tls_sys::*;
use std::{
any::Any,
ffi::c_void,
marker::PhantomData,
ptr::{self, NonNull},
sync::Arc,
};
/// Wrapper around a raw `s2n_cert_chain_and_key` pointer that either owns or
/// merely references the underlying s2n object.
///
/// Handles created by `CertificateChainHandle::allocate` have `is_owned == true`
/// and free the s2n object (and any attached `Context`) in `Drop`. Handles created
/// by `CertificateChainHandle::from_reference` have `is_owned == false`, and
/// dropping them leaves the s2n object untouched.
#[derive(Debug)]
pub(crate) struct CertificateChainHandle<'a> {
    /// Non-null pointer to the underlying s2n certificate chain and key.
    pub(crate) cert: NonNull<s2n_cert_chain_and_key>,
    /// Whether this handle is responsible for freeing `cert` when dropped.
    is_owned: bool,
    /// Ties the handle to the lifetime `'a` without storing a reference.
    _lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
}
// SAFETY: NOTE(review): these impls assert that the raw s2n object may be moved to
// and shared between threads. Within this file, mutation of the chain only happens
// through `Builder`, which owns its handle exclusively. Soundness presumably also
// depends on s2n treating a built chain as read-only; confirm against s2n-tls
// thread-safety guarantees.
unsafe impl Send for CertificateChainHandle<'_> {}
unsafe impl Sync for CertificateChainHandle<'_> {}
impl CertificateChainHandle<'_> {
    /// Allocates a new, empty s2n certificate chain owned by the returned handle.
    ///
    /// The s2n library is initialized first. The returned handle frees the chain
    /// when it is dropped.
    ///
    /// # Errors
    /// Returns an error if s2n fails to allocate the chain.
    pub(crate) fn allocate() -> Result<CertificateChainHandle<'static>, crate::error::Error> {
        crate::init::init();
        let cert = unsafe { s2n_cert_chain_and_key_new().into_result() }?;
        Ok(CertificateChainHandle {
            cert,
            is_owned: true,
            _lifetime: PhantomData,
        })
    }

    /// Wraps a chain owned elsewhere. Dropping the returned handle does not free it.
    fn from_reference(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
        Self {
            is_owned: false,
            cert,
            _lifetime: PhantomData,
        }
    }

    /// Reads the context pointer stored on the s2n chain, which may be null.
    fn raw_context(&self) -> *mut Context {
        let ctx = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
        ctx.cast::<Context>()
    }

    /// Mutable access to the attached `Context`, or `None` if none was set.
    fn context_mut(&mut self) -> Option<&mut Context> {
        let ptr = NonNull::new(self.raw_context())?;
        // SAFETY: a non-null context pointer is only ever produced by
        // `Box::into_raw` on a `Context` in this module.
        Some(unsafe { &mut *ptr.as_ptr() })
    }

    /// Shared access to the attached `Context`, or `None` if none was set.
    fn context(&self) -> Option<&Context> {
        // SAFETY: see `context_mut`.
        NonNull::new(self.raw_context()).map(|ptr| unsafe { &*ptr.as_ptr() })
    }
}
impl Drop for CertificateChainHandle<'_> {
fn drop(&mut self) {
if self.is_owned {
if let Some(internal_context) = self.context_mut() {
drop(unsafe { Box::from_raw(internal_context) });
}
unsafe {
let _ = s2n_cert_chain_and_key_set_ctx(self.cert.as_ptr(), std::ptr::null_mut())
.into_result();
let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result();
}
}
}
}
/// Heap-allocated data attached to an owned s2n certificate chain.
///
/// Stored on the C object as a raw pointer via `s2n_cert_chain_and_key_set_ctx`
/// and reclaimed with `Box::from_raw` when the owning `CertificateChainHandle` is
/// dropped.
struct Context {
    /// The value passed to `Builder::set_application_context`, type-erased so it
    /// can later be recovered with `downcast_ref`.
    application_context: Box<dyn Any + Send + Sync>,
}
/// Configures a new certificate chain before it is shared.
///
/// Create one with `Builder::new`, then load certificate/key material and any
/// optional extras. Finally, call `Builder::build` to obtain a `CertificateChain`.
#[derive(Debug)]
pub struct Builder {
    /// Owned handle to the chain being configured.
    cert_handle: CertificateChainHandle<'static>,
}
impl Builder {
pub fn new() -> Result<Self, Error> {
Ok(Self {
cert_handle: CertificateChainHandle::allocate()?,
})
}
pub fn load_pem(&mut self, chain: &[u8], key: &[u8]) -> Result<&mut Self, Error> {
unsafe {
s2n_cert_chain_and_key_load_pem_bytes(
self.cert_handle.cert.as_ptr(),
chain.as_ptr() as *mut _,
chain.len() as u32,
key.as_ptr() as *mut _,
key.len() as u32,
)
.into_result()
}?;
Ok(self)
}
pub fn load_public_pem(&mut self, chain: &[u8]) -> Result<&mut Self, Error> {
unsafe {
s2n_cert_chain_and_key_load_public_pem_bytes(
self.cert_handle.cert.as_ptr(),
chain.as_ptr() as *mut _,
chain.len() as u32,
)
.into_result()
}?;
Ok(self)
}
pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
unsafe {
s2n_cert_chain_and_key_set_ocsp_data(
self.cert_handle.cert.as_ptr(),
data.as_ptr(),
data.len() as u32,
)
.into_result()
}?;
Ok(self)
}
pub fn set_application_context<T: Send + Sync + 'static>(
&mut self,
app_context: T,
) -> Result<&mut Self, Error> {
match self.cert_handle.context_mut() {
Some(_) => Err(Error::bindings(
ErrorType::UsageError,
"cert builder error",
"set_application_context can only be called once",
)),
None => {
let app_context = Box::new(app_context);
let internal_context = Box::new(Context {
application_context: app_context,
});
unsafe {
s2n_cert_chain_and_key_set_ctx(
self.cert_handle.cert.as_ptr(),
Box::into_raw(internal_context) as *mut c_void,
)
.into_result()
}?;
Ok(self)
}
}
}
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
Ok(CertificateChain::from_allocated(self.cert_handle))
}
}
/// A reference-counted s2n certificate chain.
///
/// Cloning is cheap: every clone shares the same `CertificateChainHandle` through
/// an `Arc`, and that handle is dropped when the last clone goes away.
#[derive(Clone)]
pub struct CertificateChain<'a> {
    /// Shared handle to the underlying s2n chain.
    cert_handle: Arc<CertificateChainHandle<'a>>,
}
impl CertificateChain<'_> {
    /// Wraps a freshly allocated, owned handle in a new reference count.
    pub(crate) fn from_allocated(
        handle: CertificateChainHandle<'static>,
    ) -> CertificateChain<'static> {
        let cert_handle = Arc::new(handle);
        CertificateChain { cert_handle }
    }

    /// Wraps a chain owned elsewhere without taking ownership of it.
    ///
    /// # Safety
    /// NOTE(review): the caller chooses `'a` freely. `ptr` presumably must point
    /// to a valid `s2n_cert_chain_and_key` for at least that long; confirm at call
    /// sites.
    pub(crate) unsafe fn from_ptr_reference<'a>(
        ptr: NonNull<s2n_cert_chain_and_key>,
    ) -> CertificateChain<'a> {
        CertificateChain {
            cert_handle: Arc::new(CertificateChainHandle::from_reference(ptr)),
        }
    }

    /// Returns an iterator over the individual certificates in this chain.
    pub fn iter(&self) -> CertificateChainIter<'_> {
        let len = self.len();
        CertificateChainIter {
            chain: self,
            idx: 0,
            len,
        }
    }

    /// Returns the value attached with `Builder::set_application_context`.
    ///
    /// Returns `None` if no value was attached or if it is not of type `T`.
    pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.cert_handle
            .context()
            .and_then(|ctx| ctx.application_context.downcast_ref::<T>())
    }

    /// Number of certificates in the chain. Returns 0 if s2n reports an error.
    pub fn len(&self) -> usize {
        let mut count: u32 = 0;
        let status =
            unsafe { s2n_cert_chain_get_length(self.as_ptr(), &mut count).into_result() };
        match status {
            Ok(_) => usize::try_from(count).unwrap(),
            Err(_) => 0,
        }
    }

    /// True when `len()` is zero, including when the length cannot be read.
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Raw const pointer to the underlying s2n chain, for FFI calls.
    pub(crate) fn as_ptr(&self) -> *const s2n_cert_chain_and_key {
        let raw: *mut s2n_cert_chain_and_key = self.cert_handle.cert.as_ptr();
        raw as *const s2n_cert_chain_and_key
    }
}
/// Iterator over the certificates of a `CertificateChain`, created by
/// `CertificateChain::iter`.
pub struct CertificateChainIter<'a> {
    /// Index of the next certificate to fetch.
    idx: u32,
    /// Chain length, captured once when the iterator was created.
    len: usize,
    /// The chain being iterated.
    chain: &'a CertificateChain<'a>,
}
impl<'a> Iterator for CertificateChainIter<'a> {
    type Item = Result<Certificate<'a>, Error>;

    /// Fetches the certificate at the current index from s2n.
    ///
    /// Yields `Some(Err(_))` if s2n fails or hands back a null certificate. The
    /// index has already advanced by then, so the next call moves on to the
    /// following certificate.
    fn next(&mut self) -> Option<Self::Item> {
        let idx = self.idx;
        if usize::try_from(idx).unwrap() >= self.len {
            return None;
        }
        self.idx += 1;

        let mut out = ptr::null_mut();
        let status =
            unsafe { s2n_cert_chain_get_cert(self.chain.as_ptr(), &mut out, idx).into_result() };
        if let Err(e) = status {
            return Some(Err(e));
        }

        let out = match NonNull::new(out) {
            Some(out) => out,
            None => return Some(Err(Error::INVALID_INPUT)),
        };
        Some(Ok(Certificate {
            chain: PhantomData,
            certificate: out,
        }))
    }

    /// Reports exactly how many items remain. The default `(0, None)` discards this
    /// information, which is known up front from the captured chain length.
    /// Adapters such as `collect` use it to preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let consumed = usize::try_from(self.idx).unwrap_or(usize::MAX);
        let remaining = self.len.saturating_sub(consumed);
        (remaining, Some(remaining))
    }
}
pub struct Certificate<'a> {
chain: PhantomData<&'a CertificateChain<'a>>,
certificate: NonNull<s2n_cert>,
}
impl Certificate<'_> {
pub fn der(&self) -> Result<&[u8], Error> {
unsafe {
let mut buffer = ptr::null();
let mut length = 0;
s2n_cert_get_der(self.certificate.as_ptr(), &mut buffer, &mut length).into_result()?;
let length = usize::try_from(length).map_err(|_| Error::INVALID_INPUT)?;
Ok(std::slice::from_raw_parts(buffer, length))
}
}
}
unsafe impl Send for Certificate<'_> {}
#[cfg(test)]
mod tests {
    use crate::{
        config,
        error::{Error as S2NError, ErrorSource, ErrorType},
        security::DEFAULT_TLS13,
        testing::{
            config_builder, CertKeyPair, InsecureAcceptAllCertificatesHandler, SniTestCerts,
            TestPair,
        },
    };

    use super::*;

    /// Builds a client/server `TestPair` for certificate-selection tests.
    ///
    /// Server side:
    /// - loads every chain in `certs`;
    /// - if `defaults` is given, sets those chains as its default chains.
    ///
    /// Client side:
    /// - trusts the certificate of each entry in `types`;
    /// - installs `InsecureAcceptAllCertificatesHandler` as its host-verification
    ///   callback.
    ///
    /// Both sides use `DEFAULT_TLS13` with system certs disabled.
    fn sni_test_pair(
        certs: Vec<CertificateChain<'static>>,
        defaults: Option<Vec<CertificateChain<'static>>>,
        types: &[SniTestCerts],
    ) -> Result<TestPair, crate::error::Error> {
        let mut server_config = config::Builder::new();
        server_config
            .with_system_certs(false)?
            .set_security_policy(&DEFAULT_TLS13)?;
        for cert in certs.into_iter() {
            server_config.load_chain(cert)?;
        }
        if let Some(defaults) = defaults {
            server_config.set_default_chains(defaults)?;
        }
        let mut client_config = config::Builder::new();
        client_config
            .with_system_certs(false)?
            .set_security_policy(&DEFAULT_TLS13)?
            .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?;
        for t in types {
            client_config.trust_pem(t.get().cert())?;
        }
        Ok(TestPair::from_configs(
            &client_config.build()?,
            &server_config.build()?,
        ))
    }

    /// Compares two chains by the DER bytes of each certificate, in order.
    /// Panics if any certificate or its DER encoding cannot be read.
    fn cert_chains_are_equal(this: &CertificateChain<'_>, that: &CertificateChain<'_>) -> bool {
        let this: Vec<Vec<u8>> = this
            .iter()
            .map(|cert| cert.unwrap().der().unwrap().to_owned())
            .collect();
        let that: Vec<Vec<u8>> = that
            .iter()
            .map(|cert| cert.unwrap().der().unwrap().to_owned())
            .collect();
        this == that
    }

    /// Loading a chain into a config builder holds one extra `Arc` reference,
    /// which is released when the builder goes out of scope.
    #[test]
    fn reference_count_increment() -> Result<(), crate::error::Error> {
        let cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();
        assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
        {
            let mut server = config::Builder::new();
            server.load_chain(cert.clone())?;
            assert_eq!(Arc::strong_count(&cert.cert_handle), 2);
        }
        assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
        Ok(())
    }

    /// Dropping the only `CertificateChain` drops the shared handle.
    #[test]
    fn cert_is_dropped() {
        let weak_ref = {
            let cert = SniTestCerts::AlligatorEcdsa.get().into_certificate_chain();
            assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
            Arc::downgrade(&cert.cert_handle)
        };
        assert_eq!(weak_ref.strong_count(), 0);
        assert!(weak_ref.upgrade().is_none());
    }

    /// One chain can back two independent test pairs. Each pair holds one
    /// reference until it is dropped, and handshakes do not change the count.
    #[test]
    fn shared_certs() -> Result<(), crate::error::Error> {
        let test_key_pair = SniTestCerts::AlligatorRsa.get();
        let cert = test_key_pair.into_certificate_chain();
        let mut test_pair_1 =
            sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?;
        let mut test_pair_2 =
            sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?;
        assert_eq!(Arc::strong_count(&cert.cert_handle), 3);
        assert!(test_pair_1.handshake().is_ok());
        assert!(test_pair_2.handshake().is_ok());
        assert_eq!(Arc::strong_count(&cert.cert_handle), 3);
        drop(test_pair_1);
        assert_eq!(Arc::strong_count(&cert.cert_handle), 2);
        drop(test_pair_2);
        assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
        Ok(())
    }

    /// Passing `FAILING_NUMBER` (6) default chains is rejected with a bindings
    /// `UsageError`, and the rejected clones are not retained.
    #[test]
    fn too_many_certs_in_default() -> Result<(), crate::error::Error> {
        const FAILING_NUMBER: usize = 6;
        let certs = vec![SniTestCerts::AlligatorRsa.get().into_certificate_chain(); FAILING_NUMBER];
        assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER);
        let mut config = config::Builder::new();
        let err = config.set_default_chains(certs.clone()).err().unwrap();
        assert_eq!(err.kind(), ErrorType::UsageError);
        assert_eq!(err.source(), ErrorSource::Bindings);
        assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER);
        Ok(())
    }

    /// Checks which chain the server presents.
    ///
    /// Without explicit defaults, the first loaded chain (alligator) is presented.
    /// With explicit defaults, the default chain (beaver) is presented. This holds
    /// even when beaver is only a default and was never loaded with `load_chain`.
    /// Reference counts reflect each config's holdings.
    #[test]
    fn default_selection() -> Result<(), crate::error::Error> {
        let alligator_cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();
        let beaver_cert = SniTestCerts::BeaverRsa.get().into_certificate_chain();
        // Both loaded, no explicit default: alligator is presented.
        {
            let mut test_pair = sni_test_pair(
                vec![alligator_cert.clone(), beaver_cert.clone()],
                None,
                &[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa],
            )?;
            assert!(test_pair.handshake().is_ok());
            assert!(cert_chains_are_equal(
                &alligator_cert,
                &test_pair.client.peer_cert_chain().unwrap()
            ));
            assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
            assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2);
        }
        // Both loaded, beaver set as default: beaver is presented, and beaver
        // gains a second config reference.
        {
            let mut test_pair = sni_test_pair(
                vec![alligator_cert.clone(), beaver_cert.clone()],
                Some(vec![beaver_cert.clone()]),
                &[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa],
            )?;
            assert!(test_pair.handshake().is_ok());
            assert!(cert_chains_are_equal(
                &beaver_cert,
                &test_pair.client.peer_cert_chain().unwrap()
            ));
            assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
            assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 3);
        }
        // Only alligator loaded, beaver only as default: beaver is presented.
        {
            let mut test_pair = sni_test_pair(
                vec![alligator_cert.clone()],
                Some(vec![beaver_cert.clone()]),
                &[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa],
            )?;
            assert!(test_pair.handshake().is_ok());
            assert!(cert_chains_are_equal(
                &beaver_cert,
                &test_pair.client.peer_cert_chain().unwrap()
            ));
            assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
            assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2);
        }
        Ok(())
    }

    /// A config that has loaded an application-owned chain rejects `load_pem` with
    /// `S2N_ERR_CERT_OWNERSHIP`.
    #[test]
    fn cert_ownership_error() -> Result<(), crate::error::Error> {
        let application_owned_cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();
        let cert_for_lib = SniTestCerts::BeaverRsa.get();
        let mut config = config::Builder::new();
        config.load_chain(application_owned_cert)?;
        let err = config
            .load_pem(cert_for_lib.cert(), cert_for_lib.key())
            .err()
            .unwrap();
        assert_eq!(err.kind(), ErrorType::UsageError);
        assert_eq!(err.name(), "S2N_ERR_CERT_OWNERSHIP");
        Ok(())
    }

    /// Compile-time check that `CertificateChain<'static>` is `Send + Sync`.
    #[test]
    fn certificate_send_sync_test() {
        fn assert_send_sync<T: 'static + Send + Sync>() {}
        assert_send_sync::<CertificateChain<'static>>();
    }

    /// An application context is retrievable by its exact type and not by a
    /// different type. It is not cloned on retrieval and is dropped with the chain.
    #[test]
    fn application_context_workflow() -> Result<(), S2NError> {
        let context: Arc<u64> = Arc::new(0xC0FFEE);
        let handle = Arc::clone(&context);
        assert_eq!(Arc::strong_count(&handle), 2);
        let default = CertKeyPair::default();
        let mut chain = Builder::new()?;
        chain.load_pem(default.cert(), default.key())?;
        chain.set_application_context(context)?;
        let chain = chain.build()?;
        // Stored type is `Arc<u64>`, so asking for `u64` must fail.
        let invalid_type_get = chain.application_context::<u64>();
        assert!(invalid_type_get.is_none());
        let retrieved_context = chain.application_context::<Arc<u64>>().unwrap();
        assert_eq!(*retrieved_context.as_ref(), 0xC0FFEE);
        assert_eq!(Arc::strong_count(&handle), 2);
        drop(chain);
        assert_eq!(Arc::strong_count(&handle), 1);
        Ok(())
    }

    /// A second `set_application_context` call on the same builder is a
    /// `UsageError`.
    #[test]
    fn application_context_override() -> Result<(), S2NError> {
        let initial: Arc<u64> = Arc::new(0xC0FFEE);
        let overridden: Arc<[u8; 6]> = Arc::new(*b"coffee");
        let mut builder = Builder::new()?;
        builder.set_application_context(initial)?;
        let err = builder.set_application_context(overridden).unwrap_err();
        assert_eq!(err.kind(), ErrorType::UsageError);
        Ok(())
    }

    /// After a handshake, the server's selected certificate still exposes the
    /// application context set on its builder.
    #[test]
    fn application_context_from_selected_cert() -> Result<(), S2NError> {
        let default = CertKeyPair::default();
        let mut chain = Builder::new()?;
        chain.load_pem(default.cert(), default.key())?;
        chain.set_application_context(0xC0FFEE_u64)?;
        let mut server_config = config::Builder::new();
        server_config.load_chain(chain.build()?)?;
        let client_config = config_builder(&crate::security::DEFAULT).unwrap();
        let mut test_pair =
            TestPair::from_configs(&client_config.build()?, &server_config.build()?);
        test_pair.handshake()?;
        let selected_cert = test_pair.server.selected_cert().unwrap();
        let context = selected_cert.application_context::<u64>();
        assert_eq!(context, Some(&0xC0FFEE_u64));
        Ok(())
    }
}