use crate::{
callback::IOCallbacks,
error::{Error, Result},
ssl::{Session, SessionConfig},
CurveGroup, Method, NewSessionError, RootCertificate, Secret, SslVerifyMode,
};
use std::os::raw::c_int;
use std::ptr::NonNull;
use thiserror::Error;
/// Builder that configures a wolfSSL context before it is turned into a
/// [`Context`] via [`ContextBuilder::build`].
#[derive(Debug)]
pub struct ContextBuilder {
/// Non-null handle to the underlying `WOLFSSL_CTX`, as returned by
/// `wolfSSL_CTX_new` in [`ContextBuilder::new`].
ctx: NonNull<wolfssl_sys::WOLFSSL_CTX>,
/// The protocol method that was passed to [`ContextBuilder::new`].
method: Method,
}
/// Errors that [`ContextBuilder::new`] can return.
#[derive(Error, Debug)]
pub enum NewContextBuilderError {
/// `crate::wolf_init()` returned an error; the underlying error is carried
/// in the variant.
#[error("Failed to initialize WolfSSL: {0}")]
InitFailed(Error),
/// `Method::into_method_ptr` yielded no method pointer for the requested
/// method.
#[error("Failed to obtain WOLFSSL_METHOD")]
MethodFailed,
/// `wolfSSL_CTX_new` returned a null pointer.
#[error("Failed to allocate WolfSSL Context")]
CreateFailed,
}
impl ContextBuilder {
pub fn new(method: Method) -> std::result::Result<Self, NewContextBuilderError> {
crate::wolf_init().map_err(NewContextBuilderError::InitFailed)?;
let method_fn = method
.into_method_ptr()
.ok_or(NewContextBuilderError::MethodFailed)?;
let ctx = unsafe { wolfssl_sys::wolfSSL_CTX_new(method_fn.as_ptr()) };
let ctx = NonNull::new(ctx).ok_or(NewContextBuilderError::CreateFailed)?;
Ok(Self { ctx, method })
}
pub fn try_when<F>(self, cond: bool, func: F) -> Result<Self>
where
F: FnOnce(Self) -> Result<Self>,
{
if cond {
func(self)
} else {
Ok(self)
}
}
pub fn try_when_some<F, T>(self, maybe: Option<T>, func: F) -> Result<Self>
where
F: FnOnce(Self, T) -> Result<Self>,
{
if let Some(t) = maybe {
func(self, t)
} else {
Ok(self)
}
}
pub fn with_root_certificate(self, root: RootCertificate) -> Result<Self> {
use wolfssl_sys::{
wolfSSL_CTX_load_verify_buffer, wolfSSL_CTX_load_verify_locations,
WOLFSSL_FILETYPE_ASN1, WOLFSSL_FILETYPE_PEM,
};
let result = match root {
RootCertificate::Asn1Buffer(buf) => unsafe {
wolfSSL_CTX_load_verify_buffer(
self.ctx.as_ptr(),
buf.as_ptr(),
buf.len() as std::os::raw::c_long,
WOLFSSL_FILETYPE_ASN1 as c_int,
)
},
RootCertificate::PemBuffer(buf) => unsafe {
wolfSSL_CTX_load_verify_buffer(
self.ctx.as_ptr(),
buf.as_ptr(),
buf.len() as std::os::raw::c_long,
WOLFSSL_FILETYPE_PEM as c_int,
)
},
RootCertificate::PemFileOrDirectory(path) => {
let is_dir = path.is_dir();
let path = path.to_str().ok_or_else(|| {
Error::fatal(wolfssl_sys::wolfSSL_ErrorCodes_WOLFSSL_BAD_PATH)
})?;
let path = std::ffi::CString::new(path)
.map_err(|_| Error::fatal(wolfssl_sys::wolfSSL_ErrorCodes_WOLFSSL_BAD_PATH))?;
if is_dir {
unsafe {
wolfSSL_CTX_load_verify_locations(
self.ctx.as_ptr(),
std::ptr::null(),
path.as_c_str().as_ptr(),
)
}
} else {
unsafe {
wolfSSL_CTX_load_verify_locations(
self.ctx.as_ptr(),
path.as_c_str().as_ptr(),
std::ptr::null(),
)
}
}
}
};
if result == wolfssl_sys::WOLFSSL_SUCCESS as c_int {
Ok(self)
} else {
Err(Error::fatal(result))
}
}
pub fn with_cipher_list(self, cipher_list: &str) -> Result<Self> {
let cipher_list = std::ffi::CString::new(cipher_list)
.map_err(|_| Error::fatal(wolfssl_sys::WOLFSSL_FAILURE as c_int))?;
let result = unsafe {
wolfssl_sys::wolfSSL_CTX_set_cipher_list(
self.ctx.as_ptr(),
cipher_list.as_c_str().as_ptr(),
)
};
if result == wolfssl_sys::WOLFSSL_SUCCESS as c_int {
Ok(self)
} else {
Err(Error::fatal(result))
}
}
pub fn with_groups(self, groups: &[CurveGroup]) -> Result<Self> {
let mut ffi_curves = groups.iter().map(|g| g.as_ffi() as i32).collect::<Vec<_>>();
let result = unsafe {
wolfssl_sys::wolfSSL_CTX_set_groups(
self.ctx.as_ptr(),
ffi_curves.as_mut_ptr(),
ffi_curves.len() as i32,
)
};
if result == wolfssl_sys::WOLFSSL_SUCCESS as c_int {
Ok(self)
} else {
Err(Error::fatal(result))
}
}
pub fn with_certificate(self, secret: Secret) -> Result<Self> {
use wolfssl_sys::{
wolfSSL_CTX_use_certificate_buffer, wolfSSL_CTX_use_certificate_file,
WOLFSSL_FILETYPE_ASN1, WOLFSSL_FILETYPE_PEM,
};
let result = match secret {
Secret::Asn1Buffer(buf) => unsafe {
wolfSSL_CTX_use_certificate_buffer(
self.ctx.as_ptr(),
buf.as_ptr(),
buf.len() as std::os::raw::c_long,
WOLFSSL_FILETYPE_ASN1 as c_int,
)
},
Secret::Asn1File(path) => {
let path = path.to_str().ok_or_else(|| {
Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR)
})?;
let file = std::ffi::CString::new(path)
.map_err(|_| Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR))?;
unsafe {
wolfSSL_CTX_use_certificate_file(
self.ctx.as_ptr(),
file.as_c_str().as_ptr(),
WOLFSSL_FILETYPE_ASN1 as c_int,
)
}
}
Secret::PemBuffer(buf) => unsafe {
wolfSSL_CTX_use_certificate_buffer(
self.ctx.as_ptr(),
buf.as_ptr(),
buf.len() as std::os::raw::c_long,
WOLFSSL_FILETYPE_PEM as c_int,
)
},
Secret::PemFile(path) => {
let path = path.to_str().ok_or_else(|| {
Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR)
})?;
let file = std::ffi::CString::new(path)
.map_err(|_| Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR))?;
unsafe {
wolfSSL_CTX_use_certificate_file(
self.ctx.as_ptr(),
file.as_c_str().as_ptr(),
WOLFSSL_FILETYPE_PEM as c_int,
)
}
}
};
if result == wolfssl_sys::WOLFSSL_SUCCESS as c_int {
Ok(self)
} else {
Err(Error::fatal(result))
}
}
pub fn with_private_key(self, secret: Secret) -> Result<Self> {
use wolfssl_sys::{
wolfSSL_CTX_use_PrivateKey_buffer, wolfSSL_CTX_use_PrivateKey_file,
WOLFSSL_FILETYPE_ASN1, WOLFSSL_FILETYPE_PEM,
};
let result = match secret {
Secret::Asn1Buffer(buf) => unsafe {
wolfSSL_CTX_use_PrivateKey_buffer(
self.ctx.as_ptr(),
buf.as_ptr(),
buf.len() as std::os::raw::c_long,
WOLFSSL_FILETYPE_ASN1 as c_int,
)
},
Secret::Asn1File(path) => {
let path = path.to_str().ok_or_else(|| {
Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR)
})?;
let file = std::ffi::CString::new(path)
.map_err(|_| Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR))?;
unsafe {
wolfSSL_CTX_use_PrivateKey_file(
self.ctx.as_ptr(),
file.as_c_str().as_ptr(),
WOLFSSL_FILETYPE_ASN1 as c_int,
)
}
}
Secret::PemBuffer(buf) => unsafe {
wolfSSL_CTX_use_PrivateKey_buffer(
self.ctx.as_ptr(),
buf.as_ptr(),
buf.len() as std::os::raw::c_long,
WOLFSSL_FILETYPE_PEM as c_int,
)
},
Secret::PemFile(path) => {
let path = path.to_str().ok_or_else(|| {
Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR)
})?;
let file = std::ffi::CString::new(path)
.map_err(|_| Error::fatal(wolfssl_sys::wolfCrypt_ErrorCodes_BAD_PATH_ERROR))?;
unsafe {
wolfSSL_CTX_use_PrivateKey_file(
self.ctx.as_ptr(),
file.as_c_str().as_ptr(),
WOLFSSL_FILETYPE_PEM as c_int,
)
}
}
};
if result == wolfssl_sys::WOLFSSL_SUCCESS as c_int {
Ok(self)
} else {
Err(Error::fatal(result))
}
}
pub fn with_secure_renegotiation(self) -> Result<Self> {
let result = unsafe { wolfssl_sys::wolfSSL_CTX_UseSecureRenegotiation(self.ctx.as_ptr()) };
if result == wolfssl_sys::WOLFSSL_SUCCESS as c_int {
Ok(self)
} else {
Err(Error::fatal(result))
}
}
pub fn with_verify_method(self, mode: SslVerifyMode) -> Self {
unsafe { wolfssl_sys::wolfSSL_CTX_set_verify(self.ctx.as_ptr(), mode.into(), None) };
self
}
#[cfg(feature = "system_ca_certs")]
pub fn with_system_ca_certs(self) -> Self {
unsafe { wolfssl_sys::wolfSSL_CTX_load_system_CA_certs(self.ctx.as_ptr()) };
self
}
pub fn build(self) -> Context {
Context {
method: self.method,
ctx: ContextPointer(self.ctx),
}
}
}
/// Newtype around the raw `WOLFSSL_CTX` handle so that `Send`/`Sync` can be
/// asserted for it explicitly.
pub(crate) struct ContextPointer(NonNull<wolfssl_sys::WOLFSSL_CTX>);

impl std::ops::Deref for ContextPointer {
    type Target = NonNull<wolfssl_sys::WOLFSSL_CTX>;

    /// Exposes the wrapped `NonNull` handle.
    fn deref(&self) -> &Self::Target {
        let Self(handle) = self;
        handle
    }
}

// NOTE(review): these impls assert that a `WOLFSSL_CTX` may be moved to and
// shared between threads; that relies on wolfSSL's own thread-safety
// guarantees for contexts, which are not visible in this file — confirm.
unsafe impl Send for ContextPointer {}
unsafe impl Sync for ContextPointer {}
/// Newtype around the raw `WOLFSSL` session handle so that `Send` can be
/// asserted for it explicitly.
pub(crate) struct WolfsslPointer(NonNull<wolfssl_sys::WOLFSSL>);

impl WolfsslPointer {
    /// Returns the raw session pointer. Requires `&mut self`, so the pointer
    /// can only be obtained through exclusive access to the wrapper.
    pub(crate) fn as_ptr(&mut self) -> *mut wolfssl_sys::WOLFSSL {
        let Self(handle) = self;
        handle.as_ptr()
    }
}

// NOTE(review): asserts that a `WOLFSSL` session may be moved between
// threads; this depends on wolfSSL behaviour not visible here — confirm.
unsafe impl Send for WolfsslPointer {}
/// A finalised wolfSSL context, produced by `ContextBuilder::build`, from
/// which sessions are created.
pub struct Context {
/// The protocol method the context was created with.
method: Method,
/// Handle to the underlying `WOLFSSL_CTX`; freed when the `Context` is
/// dropped.
ctx: ContextPointer,
}
impl Context {
    /// Returns the protocol method this context was built with.
    pub fn method(&self) -> Method {
        let Self { method, .. } = self;
        *method
    }

    /// Creates a new session on this context using `wolfSSL_new` and hands it
    /// to `Session::new_from_wolfssl_pointer` together with `config`.
    ///
    /// # Errors
    ///
    /// Returns [`NewSessionError::CreateFailed`] if `wolfSSL_new` returned
    /// null, or whatever error `Session::new_from_wolfssl_pointer` reports.
    pub fn new_session<IOCB: IOCallbacks>(
        &self,
        config: SessionConfig<IOCB>,
    ) -> std::result::Result<Session<IOCB>, NewSessionError> {
        // SAFETY: `self.ctx` is a live context owned by this `Context`.
        let raw = unsafe { wolfssl_sys::wolfSSL_new(self.ctx.as_ptr()) };

        match NonNull::new(raw) {
            Some(handle) => Session::new_from_wolfssl_pointer(WolfsslPointer(handle), config),
            None => Err(NewSessionError::CreateFailed),
        }
    }
}
/// Frees the underlying `WOLFSSL_CTX` when the `Context` goes out of scope.
impl Drop for Context {
    fn drop(&mut self) {
        let raw = self.ctx.as_ptr();
        // SAFETY: `raw` is the non-null context handle owned by this `Context`.
        unsafe {
            wolfssl_sys::wolfSSL_CTX_free(raw);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    // Every supported method must yield a method pointer once the library
    // has been initialised.
    #[test_case(Method::DtlsClient)]
    #[test_case(Method::DtlsClientV1_2)]
    #[test_case(Method::DtlsServer)]
    #[test_case(Method::DtlsServerV1_2)]
    #[test_case(Method::TlsClient)]
    #[test_case(Method::TlsClientV1_2)]
    #[test_case(Method::TlsClientV1_3)]
    #[test_case(Method::TlsServer)]
    #[test_case(Method::TlsServerV1_2)]
    #[test_case(Method::TlsServerV1_3)]
    fn wolfssl_context_new(method: Method) {
        crate::wolf_init().unwrap();
        let method_ptr = method.into_method_ptr();
        let _ = method_ptr.unwrap();
    }

    // A builder can be created for a DTLS client.
    #[test]
    fn new() {
        let created = ContextBuilder::new(Method::DtlsClient);
        created.unwrap();
    }

    // `try_when` runs the closure only when the condition holds and
    // propagates the closure's error.
    #[test_case(true, true => true)]
    #[test_case(true, false => panics "Fatal(Other { what:")]
    #[test_case(false, false => false)]
    #[test_case(false, true => false)]
    fn try_when(whether: bool, ok: bool) -> bool {
        let mut invoked = false;
        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let outcome = builder.try_when(whether, |b| {
            invoked = true;
            match ok {
                true => Ok(b),
                false => Err(Error::fatal(wolfssl_sys::WOLFSSL_FAILURE as c_int)),
            }
        });
        let _ = outcome.unwrap();
        invoked
    }

    // `try_when_some` runs the closure only for `Some` and propagates the
    // closure's error.
    #[test_case(Some(true) => true)]
    #[test_case(Some(false) => panics "Fatal(Other { what:")]
    #[test_case(None => false)]
    fn try_some(whether: Option<bool>) -> bool {
        let mut invoked = false;
        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let outcome = builder.try_when_some(whether, |b, ok| {
            invoked = true;
            match ok {
                true => Ok(b),
                false => Err(Error::fatal(wolfssl_sys::WOLFSSL_FAILURE as c_int)),
            }
        });
        let _ = outcome.unwrap();
        invoked
    }

    // A DER-encoded CA certificate held in memory can be loaded as a root.
    #[test]
    fn root_certificate_buffer() {
        const CA_CERT: &[u8] = &include!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/ca_cert_der_2048"
        ));

        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let root = RootCertificate::Asn1Buffer(CA_CERT);
        let _ = builder.with_root_certificate(root).unwrap();
    }

    // Setting a cipher list succeeds.
    #[test]
    fn set_cipher_list() {
        let builder = ContextBuilder::new(Method::DtlsClient).unwrap();
        let _ = builder
            .with_cipher_list("TLS13-CHACHA20-POLY1305-SHA256")
            .unwrap();
    }

    // A DER-encoded certificate held in memory can be loaded.
    #[test]
    fn set_certificate_buffer() {
        const SERVER_CERT: &[u8] = &include!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/server_cert_der_2048"
        ));

        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let certificate = Secret::Asn1Buffer(SERVER_CERT);
        let _ = builder.with_certificate(certificate).unwrap();
    }

    // A DER-encoded private key held in memory can be loaded.
    #[test]
    fn set_private_key_buffer() {
        const SERVER_KEY: &[u8] = &include!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/server_key_der_2048"
        ));

        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let private_key = Secret::Asn1Buffer(SERVER_KEY);
        let _ = builder.with_private_key(private_key).unwrap();
    }

    // Secure renegotiation can be enabled on a TLS client context.
    #[test]
    fn set_secure_renegotiation() {
        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let _ = builder.with_secure_renegotiation().unwrap();
    }

    // Every verification mode is accepted by the builder.
    #[test_case(SslVerifyMode::SslVerifyNone)]
    #[test_case(SslVerifyMode::SslVerifyPeer)]
    #[test_case(SslVerifyMode::SslVerifyFailIfNoPeerCert)]
    #[test_case(SslVerifyMode::SslVerifyFailExceptPsk)]
    fn set_verify_method(mode: SslVerifyMode) {
        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let _ = builder.with_verify_method(mode);
    }

    // A builder can be finalised into a `Context`.
    #[test]
    fn register_io_callbacks() {
        let builder = ContextBuilder::new(Method::TlsClient).unwrap();
        let _ = builder.build();
    }
}