use std::sync::Arc;
use wolfcrypt_sys::*;
use crate::certificate::{Certificate, PrivateKey, RootCertStore};
use crate::ensure_init;
use crate::error::{expect_wolfssl_success, Result, TlsError};
use crate::protocol::{self, ProtocolVersion};
/// Owning handle to a wolfSSL `WOLFSSL_CTX`.
///
/// The context is released with `wolfSSL_CTX_free` when this value is dropped, so a
/// `CtxInner` must be the only owner of the pointer it wraps.
pub(crate) struct CtxInner {
    /// Raw context pointer. `TlsClientConfigBuilder::build` only constructs this after
    /// checking the pointer returned by `wolfSSL_CTX_new` is non-null.
    pub(crate) ctx: *mut WOLFSSL_CTX,
}

// SAFETY: `CtxInner` is the sole owner of the context, so handing it to another thread
// transfers that ownership without aliasing.
unsafe impl Send for CtxInner {}

// SAFETY: NOTE(review) — assumption to confirm against the wolfSSL build options: once
// configured, the context is only read when shared (e.g. by `wolfSSL_new`), which wolfSSL
// supports across threads when it is not built with `SINGLE_THREADED`.
unsafe impl Sync for CtxInner {}

impl Drop for CtxInner {
    fn drop(&mut self) {
        // SAFETY: `self.ctx` was obtained from `wolfSSL_CTX_new`, and this drop is the
        // only release point in this file, so it is freed exactly once.
        unsafe { wolfSSL_CTX_free(self.ctx) };
    }
}
/// A finished TLS client configuration backed by a wolfSSL `WOLFSSL_CTX`.
///
/// The context lives behind an [`Arc`], so cloning is cheap. All clones share one
/// context, which is freed when the last clone is dropped.
#[derive(Clone)]
pub struct TlsClientConfig {
    pub(crate) inner: Arc<CtxInner>,
}

impl core::fmt::Debug for TlsClientConfig {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The wolfSSL context is opaque; its address is enough to tell configs apart.
        f.debug_struct("TlsClientConfig")
            .field("ctx", &self.inner.ctx)
            .finish()
    }
}
/// Builder for [`TlsClientConfig`], obtained from [`TlsClientConfig::builder`].
///
/// Every setter consumes the builder and returns it. Nothing is applied to wolfSSL until
/// `build()` is called, so dropping an intermediate builder silently loses the setting.
#[must_use = "builder methods return the updated builder; call `build()` to obtain a TlsClientConfig"]
pub struct TlsClientConfigBuilder {
    /// Allowed protocol versions. Passed as-is (`None` included) to `protocol::resolve_method`.
    protocol_versions: Option<Vec<ProtocolVersion>>,
    /// Trust anchors for verifying the server. `build()` fails if this was never set.
    root_store: Option<RootCertStore>,
    /// Client certificate for mutual TLS. Only loaded when `client_key` is also set.
    client_cert: Option<Certificate>,
    /// Private key matching `client_cert`.
    client_key: Option<PrivateKey>,
}
impl TlsClientConfig {
pub fn builder() -> TlsClientConfigBuilder {
TlsClientConfigBuilder {
protocol_versions: None,
root_store: None,
client_cert: None,
client_key: None,
}
}
pub unsafe fn as_raw_ctx(&self) -> *mut wolfcrypt_sys::WOLFSSL_CTX {
self.inner.ctx
}
pub unsafe fn new_ssl_with_io_callbacks(
&self,
server_name: &str,
recv_cb: wolfcrypt_sys::CallbackIORecv,
send_cb: wolfcrypt_sys::CallbackIOSend,
io_ctx: *mut core::ffi::c_void,
) -> crate::error::Result<*mut wolfcrypt_sys::WOLFSSL> {
use crate::error::TlsError;
use wolfcrypt_sys::*;
let ssl = wolfSSL_new(self.inner.ctx);
if ssl.is_null() {
return Err(TlsError::AllocFailed { func: "wolfSSL_new" });
}
let guard = crate::SslGuard(ssl);
wolfSSL_SSLSetIORecv(guard.as_ptr(), recv_cb);
wolfSSL_SSLSetIOSend(guard.as_ptr(), send_cb);
wolfSSL_SetIOReadCtx(guard.as_ptr(), io_ctx);
wolfSSL_SetIOWriteCtx(guard.as_ptr(), io_ctx);
if !server_name.is_empty() {
if server_name.len() > u16::MAX as usize {
return Err(TlsError::InvalidConfig("server name exceeds maximum SNI length"));
}
let ret = wolfSSL_UseSNI(
guard.as_ptr(),
WOLFSSL_SNI_HOST_NAME as core::ffi::c_uchar,
server_name.as_ptr() as *const core::ffi::c_void,
server_name.len() as u16,
);
if ret != WOLFSSL_SUCCESS as core::ffi::c_int {
return Err(TlsError::Ffi { code: ret, func: "wolfSSL_UseSNI" });
}
}
Ok(guard.into_raw())
}
}
impl TlsClientConfigBuilder {
    /// Restricts the handshake to the given protocol versions.
    ///
    /// The slice is copied. It is resolved to a wolfSSL method in `build()`.
    pub fn with_protocol_versions(mut self, versions: &[ProtocolVersion]) -> Self {
        self.protocol_versions = Some(versions.to_vec());
        self
    }

    /// Sets the trust anchors used to verify the server's certificate chain.
    ///
    /// This is required: `build()` fails if it was never called.
    pub fn with_root_certificates(mut self, store: RootCertStore) -> Self {
        self.root_store = Some(store);
        self
    }

    /// Disables client authentication.
    ///
    /// Any certificate and key set earlier with `with_client_auth` are discarded.
    pub fn with_no_client_auth(mut self) -> Self {
        self.client_cert = None;
        self.client_key = None;
        self
    }

    /// Enables client (mutual TLS) authentication with `cert` and its private `key`.
    pub fn with_client_auth(mut self, cert: Certificate, key: PrivateKey) -> Self {
        self.client_cert = Some(cert);
        self.client_key = Some(key);
        self
    }

    /// Creates the wolfSSL context and applies every collected setting.
    ///
    /// The steps are:
    ///
    /// 1. Resolve the client protocol method.
    /// 2. Create the `WOLFSSL_CTX`.
    /// 3. Load every root certificate.
    /// 4. Require peer verification.
    /// 5. If client auth was configured, load the client certificate and private key.
    ///
    /// # Errors
    ///
    /// - [`TlsError::InvalidConfig`] if no root certificates were set, or if a certificate
    ///   or key buffer is too large to describe as a C `long`.
    /// - [`TlsError::AllocFailed`] if `wolfSSL_CTX_new` returns null.
    /// - Whatever `protocol::resolve_method` or `expect_wolfssl_success` reports for a
    ///   failing step.
    pub fn build(self) -> Result<TlsClientConfig> {
        ensure_init();
        let root_store = self
            .root_store
            .ok_or(TlsError::InvalidConfig("root certificates are required"))?;
        // SAFETY: NOTE(review) — `resolve_method`'s contract is not visible here. The library
        // has been initialised by `ensure_init()` above, which it presumably requires.
        let method = unsafe {
            protocol::resolve_method(protocol::Side::Client, self.protocol_versions.as_deref())?
        };
        // SAFETY: `method` was produced by `resolve_method`; a null result is handled below.
        let ctx = unsafe { wolfSSL_CTX_new(method) };
        if ctx.is_null() {
            return Err(TlsError::AllocFailed {
                func: "wolfSSL_CTX_new",
            });
        }
        // Take ownership right away so `CtxInner`'s Drop frees the context on every early
        // return below.
        let inner = Arc::new(CtxInner { ctx });

        for (cert_data, format) in root_store.iter() {
            let len = Self::ffi_len(cert_data.len())?;
            // SAFETY: `inner.ctx` is a live context, and `cert_data` is borrowed for the call.
            let ret = unsafe {
                wolfSSL_CTX_load_verify_buffer(inner.ctx, cert_data.as_ptr(), len, format.as_c_int())
            };
            expect_wolfssl_success(ret, "wolfSSL_CTX_load_verify_buffer")?;
        }

        // SAFETY: `inner.ctx` is a live context; no custom verify callback is installed.
        unsafe {
            wolfSSL_CTX_set_verify(inner.ctx, WOLFSSL_VERIFY_PEER as core::ffi::c_int, None);
        }

        if let (Some(cert), Some(key)) = (self.client_cert.as_ref(), self.client_key.as_ref()) {
            let cert_len = Self::ffi_len(cert.data().len())?;
            // SAFETY: `inner.ctx` is live, and the certificate bytes are borrowed for the call.
            let ret = unsafe {
                wolfSSL_CTX_use_certificate_buffer(
                    inner.ctx,
                    cert.data().as_ptr(),
                    cert_len,
                    cert.format().as_c_int(),
                )
            };
            expect_wolfssl_success(ret, "wolfSSL_CTX_use_certificate_buffer")?;

            let key_len = Self::ffi_len(key.data().len())?;
            // SAFETY: `inner.ctx` is live, and the key bytes are borrowed for the call.
            let ret = unsafe {
                wolfSSL_CTX_use_PrivateKey_buffer(
                    inner.ctx,
                    key.data().as_ptr(),
                    key_len,
                    key.format().as_c_int(),
                )
            };
            expect_wolfssl_success(ret, "wolfSSL_CTX_use_PrivateKey_buffer")?;
        }

        Ok(TlsClientConfig { inner })
    }

    /// Converts a buffer length to the `c_long` that wolfSSL's `*_buffer` loaders expect.
    ///
    /// `c_long` is only 32 bits on some 64-bit targets (e.g. Windows), where an `as` cast
    /// would silently truncate. Oversized buffers are rejected instead.
    fn ffi_len(len: usize) -> Result<core::ffi::c_long> {
        // Explicit import so this also compiles on editions whose prelude lacks `TryFrom`.
        use core::convert::TryFrom;
        core::ffi::c_long::try_from(len)
            .map_err(|_| TlsError::InvalidConfig("certificate or key buffer is too large"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_without_root_certs_fails() {
        let result = TlsClientConfig::builder().with_no_client_auth().build();
        // Pin the specific error, so that an unrelated failure (e.g. during wolfSSL context
        // setup) cannot make this test pass by accident.
        assert!(matches!(
            result,
            Err(TlsError::InvalidConfig("root certificates are required"))
        ));
    }
}