use s2n_tls_sys::*;
use crate::{
callbacks::with_context,
config,
connection::Connection,
enums::CallbackResult,
error::{Error, Fallible},
};
use std::{marker::PhantomData, ptr::NonNull};
/// A certificate request handed to
/// [`CertificateRequestCallback::on_certificate_request`].
///
/// Pairs the raw s2n request pointer with a mutable borrow of the connection
/// the request belongs to.
pub struct CertificateRequest<'a> {
    /// Non-null request pointer supplied by s2n to the C callback.
    request: NonNull<s2n_certificate_request>,
    /// The connection on which the request was received.
    conn: &'a mut Connection,
}
/// Iterator over the certificate authorities listed in a [`CertificateRequest`].
///
/// The iteration cursor lives in the underlying s2n list, not in this value;
/// use [`CertificateAuthorities::reset`] to walk the list again.
pub struct CertificateAuthorities<'a>(
    /// The s2n CA list, or `None` if s2n returned a null list pointer.
    Option<NonNull<s2n_certificate_authority_list>>,
    /// Ties borrowed CA bytes to the lifetime `'a` of the request borrow.
    PhantomData<&'a mut ()>,
);
/// Per-connection state used by the certificate request API
/// (reached through `Connection::cert_request_state`).
#[derive(Default)]
pub(crate) struct CertRequestState {
    /// Chain stored by `CertificateRequest::set_certificate`. s2n only receives
    /// a raw pointer to it, so the connection holds the owning value here;
    /// presumably this keeps the chain alive for as long as s2n may use it.
    chain: Option<crate::cert_chain::CertificateChain<'static>>,
}
impl CertificateRequest<'_> {
    /// Returns an iterator over the certificate authorities carried by this request.
    ///
    /// The iterator borrows `self`, so yielded entries cannot outlive this borrow.
    pub fn certificate_authorities(&mut self) -> CertificateAuthorities<'_> {
        // SAFETY: `self.request` is the non-null request pointer s2n passed to the callback.
        let list = unsafe { s2n_certificate_request_get_ca_list(self.request.as_ptr()) };
        CertificateAuthorities(NonNull::new(list), PhantomData)
    }

    /// Returns the connection on which this certificate request was received.
    pub fn connection(&mut self) -> &mut Connection {
        self.conn
    }

    /// Selects `cert_chain` as the certificate used to answer this request.
    ///
    /// On success the chain is stored in the connection's certificate request
    /// state, so it is owned by the connection rather than by the caller.
    ///
    /// # Errors
    ///
    /// Returns an error if s2n rejects the chain. In that case `cert_chain` is
    /// dropped and any chain selected by an earlier successful call is kept.
    pub fn set_certificate(
        &mut self,
        cert_chain: crate::cert_chain::CertificateChain<'static>,
    ) -> Result<(), Error> {
        let ptr = cert_chain.as_ptr();
        // SAFETY: `self.request` is valid for the duration of the callback, and
        // `cert_chain` stays alive (owned locally) throughout this call.
        unsafe {
            s2n_certificate_request_set_certificate(self.request.as_ptr(), ptr as *mut _)
                .into_result()?;
        }
        // Store only after s2n has accepted the chain. Storing first would replace
        // (and drop) a previously selected chain even when the call above fails,
        // potentially leaving s2n pointing at a chain the connection no longer keeps alive.
        self.conn.cert_request_state().chain = Some(cert_chain);
        Ok(())
    }
}
impl<'a> CertificateAuthorities<'a> {
    /// Rewinds the underlying s2n CA list so it can be iterated again from the start.
    ///
    /// Calls `s2n_certificate_authority_list_reread` when a list is present and
    /// does nothing when there is no list.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported by s2n.
    pub fn reset(&mut self) -> Result<(), Error> {
        let list = match self.0 {
            Some(non_null) => non_null.as_ptr(),
            None => return Ok(()),
        };
        // SAFETY: `list` is non-null and came from the certificate request this
        // value borrows from.
        let rc = unsafe { s2n_certificate_authority_list_reread(list) };
        rc.into_result()?;
        Ok(())
    }
}
/// A single certificate authority entry from a certificate request.
///
/// Borrows the encoded bytes from the underlying request for the lifetime `'a`.
#[derive(Debug, Clone, Copy)]
pub struct CertificateAuthority<'a>(&'a [u8]);

impl<'a> CertificateAuthority<'a> {
    /// Returns the raw DER bytes of this entry (the tests decode them as an X.509 name).
    ///
    /// The slice borrows from the request (`'a`), not from `self`, so it remains
    /// usable after this wrapper is dropped.
    pub fn der(&self) -> &'a [u8] {
        self.0
    }
}
impl<'a> Iterator for CertificateAuthorities<'a> {
    type Item = Result<CertificateAuthority<'a>, Error>;

    /// Yields the next CA entry from the s2n list.
    ///
    /// Returns `None` when the list is exhausted or when there is no list at all.
    /// Returns `Some(Err(_))` if s2n fails to produce the entry or yields a null pointer.
    fn next(&mut self) -> Option<Self::Item> {
        // No list means nothing to iterate. Ending here avoids the previous behaviour
        // of falling through to the null-pointer check and yielding an endless stream
        // of errors (which made e.g. `.count()` loop forever).
        let list = self.0?.as_ptr();

        // SAFETY: `list` is non-null and came from the certificate request this
        // iterator borrows from.
        let has_next = unsafe { s2n_certificate_authority_list_has_next(list) };
        if !has_next {
            return None;
        }

        let mut ptr = std::ptr::null_mut::<u8>();
        let mut length: u16 = 0;
        // SAFETY: `list` is valid as above; `ptr` and `length` are valid out-parameters.
        let rc = unsafe { s2n_certificate_authority_list_next(list, &mut ptr, &mut length) };
        if let Err(e) = rc.into_result() {
            return Some(Err(e));
        }

        if ptr.is_null() {
            return Some(Err(crate::error::Error::INVALID_INPUT));
        }
        if length == 0 {
            return Some(Ok(CertificateAuthority(&[])));
        }
        // SAFETY: s2n reported success with a non-null pointer to `length` bytes that
        // belong to the request; `'a` is tied to the borrow of that request.
        let der = unsafe { std::slice::from_raw_parts(ptr, usize::from(length)) };
        Some(Ok(CertificateAuthority(der)))
    }
}
pub trait CertificateRequestCallback: 'static + Send + Sync {
fn on_certificate_request(&self, request: &mut CertificateRequest) -> Result<(), Error>;
}
impl config::Builder {
    /// Installs `handler` as the certificate request callback for this config.
    ///
    /// The handler is boxed into the config context, and a C trampoline is
    /// registered via `s2n_config_set_cert_request_callback`. The trampoline
    /// reports failure to s2n when the request pointer is null or the handler
    /// returns `Err`. It reports success when no handler is stored.
    ///
    /// # Errors
    ///
    /// Returns an error if s2n rejects the callback registration.
    pub fn set_certificate_request_callback<T: 'static + CertificateRequestCallback>(
        &mut self,
        handler: T,
    ) -> Result<&mut Self, Error> {
        // C-ABI trampoline: forwards s2n's certificate request to the stored handler.
        unsafe extern "C" fn cert_request_callback(
            conn_ptr: *mut s2n_connection,
            _context: *mut libc::c_void,
            request: *mut s2n_certificate_request,
        ) -> libc::c_int {
            let request = match NonNull::new(request) {
                None => return CallbackResult::Failure.into(),
                Some(non_null) => non_null,
            };
            with_context(conn_ptr, |conn, context| match context.cert_authorities.as_ref() {
                // Nothing registered: let the handshake proceed.
                None => CallbackResult::Success.into(),
                Some(callback) => {
                    let mut req = CertificateRequest { request, conn };
                    let verdict = match callback.on_certificate_request(&mut req) {
                        Ok(()) => CallbackResult::Success,
                        Err(_) => CallbackResult::Failure,
                    };
                    verdict.into()
                }
            })
        }

        // SAFETY: presumably sound because `&mut self` gives exclusive access to the
        // config while it is being built; confirm against `context_mut`'s contract.
        let context = unsafe { self.config.context_mut() };
        context.cert_authorities = Some(Box::new(handler));

        // SAFETY: `self.as_mut_ptr()` is this builder's config; the trampoline has the
        // signature s2n expects, and no user context pointer is needed.
        unsafe {
            s2n_config_set_cert_request_callback(
                self.as_mut_ptr(),
                Some(cert_request_callback),
                std::ptr::null_mut(),
            )
            .into_result()?;
        }
        Ok(self)
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        cert_chain::CertificateChain,
        enums::ClientAuthType,
        security,
        testing::{config_builder, CertKeyPair, TestPair},
    };
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex},
    };

    /// Callback that always answers a certificate request with a fixed chain.
    struct SetCallback {
        certificate: CertificateChain<'static>,
    }

    impl super::CertificateRequestCallback for SetCallback {
        fn on_certificate_request(
            &self,
            request: &mut super::CertificateRequest,
        ) -> Result<(), super::Error> {
            request.set_certificate(self.certificate.clone())?;
            Ok(())
        }
    }

    /// Callback that records the bytes of every offered CA into `cas`, checks that
    /// the list can be re-read with identical contents after `reset`, and then
    /// answers with a fixed chain.
    struct ExtractCallback {
        certificate: CertificateChain<'static>,
        cas: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl super::CertificateRequestCallback for ExtractCallback {
        fn on_certificate_request(
            &self,
            request: &mut super::CertificateRequest,
        ) -> Result<(), super::Error> {
            let mut cas = self.cas.lock().unwrap();
            cas.clear();
            // First pass: collect every CA entry.
            for ca in request.certificate_authorities() {
                cas.push(ca?.der().to_owned());
            }
            // Each `certificate_authorities()` call wraps the same s2n list, whose
            // cursor is now at the end, so rewind it before a second pass.
            request.certificate_authorities().reset()?;
            // Second pass: every entry must match what the first pass recorded.
            for (idx, ca) in request.certificate_authorities().enumerate() {
                assert_eq!(cas[idx], ca?.der());
            }
            request.set_certificate(self.certificate.clone())?;
            Ok(())
        }
    }

    /// With client auth required and the callback presenting the default keypair,
    /// both sides complete the handshake with a one-certificate chain, and the
    /// leaf certificates seen on server and client are asserted to be identical.
    #[test]
    fn basic() -> Result<(), Box<dyn std::error::Error>> {
        let keypair = CertKeyPair::default();
        let config = {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.trust_pem(keypair.cert())?;
            config.set_certificate_request_callback(SetCallback {
                certificate: keypair.into_certificate_chain(),
            })?;
            config.build()?
        };
        let mut pair = TestPair::from_config(&config);
        // Nothing is selected before the handshake runs.
        assert!(pair.server.selected_cert().is_none());
        assert!(pair.client.selected_cert().is_none());
        pair.handshake()?;
        for conn in [&pair.server, &pair.client] {
            let chain = conn.selected_cert().unwrap();
            assert_eq!(chain.len(), 1);
            for cert in chain.iter() {
                let cert = cert?;
                let cert = cert.der()?;
                assert!(!cert.is_empty());
            }
        }
        assert_eq!(
            pair.server
                .selected_cert()
                .unwrap()
                .iter()
                .next()
                .unwrap()?
                .der()?,
            pair.client
                .selected_cert()
                .unwrap()
                .iter()
                .next()
                .unwrap()?
                .der()?
        );
        Ok(())
    }

    /// The callback presents a different (rsa_4096_sha384 client) certificate:
    /// the client's selected cert must be that one, while the server's selected
    /// cert is expected to equal the default keypair's.
    #[test]
    fn change_cert() -> Result<(), Box<dyn std::error::Error>> {
        let keypair = CertKeyPair::from_path("rsa_4096_sha384_client_", "cert", "key", "cert");
        let config = {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.trust_pem(keypair.cert())?;
            config.set_certificate_request_callback(SetCallback {
                certificate: keypair.into_certificate_chain(),
            })?;
            config.build()?
        };
        let mut pair = TestPair::from_config(&config);
        assert!(pair.server.selected_cert().is_none());
        assert!(pair.client.selected_cert().is_none());
        pair.handshake()?;
        for conn in [&pair.server, &pair.client] {
            let chain = conn.selected_cert().unwrap();
            assert_eq!(chain.len(), 1);
            for cert in chain.iter() {
                let cert = cert?;
                let cert = cert.der()?;
                assert!(!cert.is_empty());
            }
        }
        assert_eq!(
            pair.server
                .selected_cert()
                .unwrap()
                .iter()
                .next()
                .unwrap()?
                .der()?,
            CertKeyPair::default()
                .into_certificate_chain()
                .iter()
                .next()
                .unwrap()?
                .der()?
        );
        assert_eq!(
            pair.client
                .selected_cert()
                .unwrap()
                .iter()
                .next()
                .unwrap()?
                .der()?,
            keypair
                .into_certificate_chain()
                .iter()
                .next()
                .unwrap()?
                .der()?
        );
        Ok(())
    }

    /// Without configuring certificate authorities, the callback sees an empty CA list.
    #[test]
    fn ca_list_empty() -> Result<(), Box<dyn std::error::Error>> {
        let keypair = CertKeyPair::default();
        let cas = Arc::new(Mutex::new(vec![]));
        let config = {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.set_certificate_request_callback(ExtractCallback {
                certificate: keypair.into_certificate_chain(),
                cas: cas.clone(),
            })?;
            config.build()?
        };
        let mut pair = TestPair::from_config(&config);
        assert!(pair.server.selected_cert().is_none());
        assert!(pair.client.selected_cert().is_none());
        pair.handshake()?;
        let cas = cas.lock().unwrap();
        assert_eq!(cas.len(), 0);
        Ok(())
    }

    /// With CAs taken from the trust store, the callback sees exactly one CA whose
    /// encoded name matches the subject name loaded from the default cert's PEM file.
    #[test]
    fn ca_list() -> Result<(), Box<dyn std::error::Error>> {
        let keypair = CertKeyPair::default();
        let cas = Arc::new(Mutex::new(vec![]));
        let config = {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.set_certificate_authorities_from_trust_store()?;
            config.set_certificate_request_callback(ExtractCallback {
                certificate: keypair.into_certificate_chain(),
                cas: cas.clone(),
            })?;
            config.with_system_certs(false)?;
            config.build()?
        };
        let mut pair = TestPair::from_config(&config);
        assert!(pair.server.selected_cert().is_none());
        assert!(pair.client.selected_cert().is_none());
        pair.handshake()?;
        let cas = cas.lock().unwrap();
        assert_eq!(cas.len(), 1);
        let ca = cas.iter().next().unwrap();
        let decoded = openssl::x509::X509Name::from_der(ca)?;
        let expected =
            openssl::x509::X509Name::load_client_ca_file(CertKeyPair::default().cert_path())?;
        assert_eq!(
            decoded.try_cmp(expected.get(0).unwrap())?,
            std::cmp::Ordering::Equal
        );
        Ok(())
    }

    /// Client hello callback that swaps the server's config based on the SNI name.
    /// Panics (via `expect`) if the name has no entry in the map.
    #[derive(Clone)]
    struct Pick(Arc<Mutex<HashMap<&'static str, crate::config::Config>>>);

    impl crate::callbacks::ClientHelloCallback for Pick {
        fn on_client_hello(
            &self,
            connection: &mut crate::connection::Connection,
        ) -> crate::callbacks::ConnectionFutureResult {
            let name = connection.server_name().unwrap();
            let this = self.0.lock().unwrap();
            let config = this.get(name).expect(name).clone();
            connection.set_config(config)?;
            connection.server_name_extension_used();
            Ok(None)
        }
    }

    /// Client callback that picks its certificate from the number of offered CAs:
    /// one CA selects the sha256 client cert, two select the sha384 client cert.
    struct DynamicSelect;

    impl super::CertificateRequestCallback for DynamicSelect {
        fn on_certificate_request(
            &self,
            request: &mut super::CertificateRequest,
        ) -> Result<(), super::Error> {
            let a = CertKeyPair::from_path("rsa_4096_sha256_client_", "cert", "key", "cert");
            let b = CertKeyPair::from_path("rsa_4096_sha384_client_", "cert", "key", "cert");
            let reply = match request.certificate_authorities().count() {
                1 => a,
                2 => b,
                _ => unreachable!(),
            };
            request.set_certificate(reply.into_certificate_chain())?;
            Ok(())
        }
    }

    /// The server chooses a config by SNI: "a.example.com" advertises one CA (a),
    /// "b.example.com" advertises two (b and c). The client must answer each with
    /// the certificate `DynamicSelect` maps to that CA count.
    #[test]
    fn dynamic_pick() -> Result<(), Box<dyn std::error::Error>> {
        let a = CertKeyPair::from_path("rsa_4096_sha256_client_", "cert", "key", "cert");
        let b = CertKeyPair::from_path("rsa_4096_sha384_client_", "cert", "key", "cert");
        let c = CertKeyPair::from_path("rsa_4096_sha512_client_", "cert", "key", "cert");
        let mut map = HashMap::new();
        map.insert("a.example.com", {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.wipe_trust_store()?;
            config.trust_pem(a.cert())?;
            config.set_certificate_authorities_from_trust_store()?;
            config.build()?
        });
        map.insert("b.example.com", {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.wipe_trust_store()?;
            config.trust_pem(b.cert())?;
            config.trust_pem(c.cert())?;
            config.set_certificate_authorities_from_trust_store()?;
            config.build()?
        });
        let pick = Pick(Arc::new(Mutex::new(map)));
        let server_config = {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.set_client_hello_callback(pick.clone())?;
            config.build()?
        };
        let client_config = {
            let mut config = config_builder(&security::DEFAULT_TLS13)?;
            config.set_client_auth_type(ClientAuthType::Required)?;
            config.set_certificate_request_callback(DynamicSelect)?;
            config.build()?
        };
        let mut pair = TestPair::from_configs(&client_config, &server_config);
        pair.client.set_server_name("a.example.com")?;
        // The server's client hello callback returns a ConnectionFutureResult;
        // presumably a waker must be installed for that path. A no-op one suffices here.
        pair.server
            .set_waker(Some(futures_test::task::noop_waker_ref()))?;
        pair.handshake()?;
        assert_eq!(
            pair.client
                .selected_cert()
                .unwrap()
                .iter()
                .next()
                .unwrap()?
                .der()?,
            a.into_certificate_chain().iter().next().unwrap()?.der()?
        );
        let mut pair = TestPair::from_configs(&client_config, &server_config);
        pair.client.set_server_name("b.example.com")?;
        pair.server
            .set_waker(Some(futures_test::task::noop_waker_ref()))?;
        pair.handshake()?;
        assert_eq!(
            pair.client
                .selected_cert()
                .unwrap()
                .iter()
                .next()
                .unwrap()?
                .der()?,
            b.into_certificate_chain().iter().next().unwrap()?.der()?
        );
        Ok(())
    }
}