mod certificate;
mod client;
mod config;
mod error;
mod protocol;
mod server;
pub use certificate::{Certificate, PrivateKey, RootCertStore};
pub use client::TlsClient;
pub use config::{TlsClientConfig, TlsClientConfigBuilder};
pub use error::{Result, TlsError};
pub use protocol::ProtocolVersion;
pub use server::{TlsAcceptor, TlsServer, TlsServerConfig, TlsServerConfigBuilder};
use std::sync::Once;
use wolfcrypt_sys::*;
/// A transport that can expose its native OS socket handle as a C `int`
/// for use by the TLS layer.
///
/// Blanket implementations exist for every `AsRawFd` type on Unix and every
/// `AsRawSocket` type on Windows.
pub trait TlsSocket {
    /// Returns the underlying OS socket handle as a C `int`.
    fn tls_raw_fd(&self) -> std::os::raw::c_int;
}
/// Unix: every type that exposes a raw file descriptor is a [`TlsSocket`];
/// the descriptor is handed back unchanged.
#[cfg(unix)]
impl<T> TlsSocket for T
where
    T: std::os::unix::io::AsRawFd,
{
    fn tls_raw_fd(&self) -> core::ffi::c_int {
        std::os::unix::io::AsRawFd::as_raw_fd(self)
    }
}
/// Windows: every type that exposes a raw `SOCKET` is a [`TlsSocket`].
///
/// # Panics
///
/// Panics if the socket handle is larger than `c_int::MAX` and therefore
/// cannot be represented in the C `int` this trait returns.
#[cfg(windows)]
impl<T> TlsSocket for T
where
    T: std::os::windows::io::AsRawSocket,
{
    fn tls_raw_fd(&self) -> core::ffi::c_int {
        let handle = self.as_raw_socket();
        // Refuse to silently truncate a handle that does not fit in an `int`.
        if handle > core::ffi::c_int::MAX as u64 {
            panic!("socket handle {handle:#x} exceeds c_int range");
        }
        handle as core::ffi::c_int
    }
}
/// Implements [`std::io::Read`], [`std::io::Write`] and [`Drop`] for a TLS
/// stream wrapper named `$ty`.
///
/// The target must be a struct generic over one stream parameter `S` with a
/// field `ssl: *mut wolfcrypt_sys::WOLFSSL` that it owns. All I/O is performed
/// by wolfSSL on `self.ssl`; `S`'s own `Read`/`Write` impls are never called
/// here. On drop, `wolfSSL_shutdown` is attempted (its result is ignored) and
/// the `WOLFSSL` object is freed.
///
/// Non-fatal "want read"/"want write" conditions are surfaced as
/// `ErrorKind::WouldBlock`. Any other failure becomes an `ErrorKind::Other`
/// error carrying wolfSSL's error string and numeric code.
macro_rules! impl_tls_io {
    // Internal rule: turn a failed wolfSSL_read/wolfSSL_write return value into
    // an `io::Error`. Shared by the Read and Write impls generated below so the
    // classification lives in exactly one place.
    (@io_error $op:literal, $ssl:expr, $ret:expr) => {{
        // SAFETY: `$ssl` is the same WOLFSSL object the failing call was made
        // on; it is only freed in the generated `Drop` impl.
        let err = unsafe { wolfcrypt_sys::wolfSSL_get_error($ssl, $ret) };
        // wolfSSL_get_error() converts its internal WANT_READ / WANT_WRITE
        // conditions into the positive, OpenSSL-compatible codes
        // WOLFSSL_ERROR_WANT_READ (2) and WOLFSSL_ERROR_WANT_WRITE (3). These
        // are not the negative `wolfSSL_ErrorCodes` `_E` variants. Both
        // spellings are matched so that a non-blocking socket reports
        // WouldBlock instead of a hard error.
        const WOLFSSL_ERROR_WANT_READ: core::ffi::c_int = 2;
        const WOLFSSL_ERROR_WANT_WRITE: core::ffi::c_int = 3;
        match err {
            WOLFSSL_ERROR_WANT_READ
            | WOLFSSL_ERROR_WANT_WRITE
            | wolfcrypt_sys::wolfSSL_ErrorCodes_WOLFSSL_ERROR_WANT_READ_E
            | wolfcrypt_sys::wolfSSL_ErrorCodes_WOLFSSL_ERROR_WANT_WRITE_E => {
                std::io::Error::from(std::io::ErrorKind::WouldBlock)
            }
            _ => std::io::Error::new(
                std::io::ErrorKind::Other,
                format!(
                    "{}: {} (error {})",
                    $op,
                    $crate::error::error_string(err),
                    err
                ),
            ),
        }
    }};
    ($ty:ident) => {
        impl<S: std::io::Read + std::io::Write + $crate::TlsSocket> std::io::Read for $ty<S> {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                if buf.is_empty() {
                    return Ok(0);
                }
                // wolfSSL takes an `int` length: clamp oversized buffers, which
                // simply yields a short read.
                let len =
                    std::cmp::min(buf.len(), core::ffi::c_int::MAX as usize) as core::ffi::c_int;
                // SAFETY: `self.ssl` is owned by this wrapper (freed only in
                // Drop) and `buf` is valid for writes of `len <= buf.len()` bytes.
                let ret = unsafe {
                    wolfcrypt_sys::wolfSSL_read(
                        self.ssl,
                        buf.as_mut_ptr() as *mut core::ffi::c_void,
                        len,
                    )
                };
                match ret {
                    n if n > 0 => Ok(n as usize),
                    // wolfSSL reports a closed connection (close_notify or the
                    // peer closing the socket) as 0; surface it as EOF.
                    0 => Ok(0),
                    _ => Err($crate::impl_tls_io!(@io_error "wolfSSL_read", self.ssl, ret)),
                }
            }
        }
        impl<S: std::io::Read + std::io::Write + $crate::TlsSocket> std::io::Write for $ty<S> {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                if buf.is_empty() {
                    return Ok(0);
                }
                // wolfSSL takes an `int` length: clamp oversized buffers, which
                // simply yields a short write.
                let len =
                    std::cmp::min(buf.len(), core::ffi::c_int::MAX as usize) as core::ffi::c_int;
                // SAFETY: `self.ssl` is owned by this wrapper (freed only in
                // Drop) and `buf` is valid for reads of `len <= buf.len()` bytes.
                let ret = unsafe {
                    wolfcrypt_sys::wolfSSL_write(
                        self.ssl,
                        buf.as_ptr() as *const core::ffi::c_void,
                        len,
                    )
                };
                // wolfSSL_write signals failure with 0 or a negative value;
                // never report `Ok(0)` for a non-empty buffer.
                if ret > 0 {
                    Ok(ret as usize)
                } else {
                    Err($crate::impl_tls_io!(@io_error "wolfSSL_write", self.ssl, ret))
                }
            }
            fn flush(&mut self) -> std::io::Result<()> {
                // This wrapper keeps no buffer of its own, so there is nothing
                // to flush here.
                Ok(())
            }
        }
        impl<S> Drop for $ty<S> {
            fn drop(&mut self) {
                // SAFETY: `self.ssl` is owned by this wrapper and is not used
                // after this point. The shutdown is best-effort, so its result
                // is deliberately ignored.
                unsafe {
                    let _ = wolfcrypt_sys::wolfSSL_shutdown(self.ssl);
                    wolfcrypt_sys::wolfSSL_free(self.ssl);
                }
            }
        }
    };
}
// Makes the macro path-addressable as `crate::impl_tls_io!` for sibling
// modules, and for the internal `@io_error` rule above.
pub(crate) use impl_tls_io;
/// Owning wrapper around a raw `WOLFSSL*`.
///
/// The pointer is passed to `wolfSSL_free` when the guard is dropped, unless
/// ownership is released first with [`SslGuard::into_raw`].
pub(crate) struct SslGuard(pub(crate) *mut wolfcrypt_sys::WOLFSSL);

impl SslGuard {
    /// Returns the raw pointer while keeping ownership inside the guard.
    pub(crate) fn as_ptr(&self) -> *mut wolfcrypt_sys::WOLFSSL {
        self.0
    }

    /// Releases ownership and returns the raw pointer without freeing it.
    ///
    /// The caller becomes responsible for eventually freeing the pointer.
    pub(crate) fn into_raw(self) -> *mut wolfcrypt_sys::WOLFSSL {
        // Wrapping in ManuallyDrop suppresses the guard's Drop impl, so the
        // pointer is handed out exactly once and never freed here.
        std::mem::ManuallyDrop::new(self).0
    }
}

impl Drop for SslGuard {
    fn drop(&mut self) {
        // SAFETY: the guard still owns the pointer (`into_raw` bypasses this
        // Drop), so it is freed exactly once.
        unsafe {
            wolfcrypt_sys::wolfSSL_free(self.0);
        }
    }
}
/// Guards the one-time, process-wide `wolfSSL_Init` call.
static INIT: Once = Once::new();

/// Initialises the wolfSSL library exactly once for this process.
///
/// It is safe to call this any number of times from any thread; only the
/// first call runs `wolfSSL_Init`.
///
/// # Panics
///
/// Panics if `wolfSSL_Init` returns a negative code. Because the panic happens
/// inside [`Once::call_once`], the `Once` becomes poisoned and every later
/// call panics as well.
pub fn ensure_init() {
    INIT.call_once(|| {
        // SAFETY: `wolfSSL_Init` takes no arguments, and `Once` guarantees it
        // is not entered concurrently from this crate.
        let code = unsafe { wolfSSL_Init() };
        if code < 0 {
            panic!("wolfSSL_Init failed with code {code}");
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wolfssl_init_is_idempotent() {
        // `ensure_init` is `Once`-guarded, so repeated calls must be harmless.
        ensure_init();
        ensure_init();
    }

    #[test]
    fn want_read_write_error_codes_are_negative() {
        // Pins the negative `_E` want-read/want-write codes. They are the
        // negated counterparts of the positive, OpenSSL-compatible
        // WOLFSSL_ERROR_WANT_READ (2) / WOLFSSL_ERROR_WANT_WRITE (3) codes.
        // wolfSSL_get_error() returns the *positive* codes, so I/O error
        // classification must not rely on these negative constants alone.
        let want_read = wolfcrypt_sys::wolfSSL_ErrorCodes_WOLFSSL_ERROR_WANT_READ_E;
        let want_write = wolfcrypt_sys::wolfSSL_ErrorCodes_WOLFSSL_ERROR_WANT_WRITE_E;
        assert_eq!(want_read, -2, "WANT_READ_E should be -2");
        assert_eq!(want_write, -3, "WANT_WRITE_E should be -3");
    }

    #[test]
    fn tls_types_implement_debug() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<TlsClient<std::net::TcpStream>>();
        assert_debug::<server::TlsServer<std::net::TcpStream>>();
    }

    #[test]
    fn resolve_method_returns_non_null_for_all_valid_inputs() {
        use crate::protocol::{resolve_method, ProtocolVersion, Side};
        ensure_init();
        let version_sets: &[Option<&[ProtocolVersion]>] = &[
            None,
            Some(&[ProtocolVersion::Tls12]),
            Some(&[ProtocolVersion::Tls13]),
            Some(&[ProtocolVersion::Tls12, ProtocolVersion::Tls13]),
            Some(&[ProtocolVersion::Tls13, ProtocolVersion::Tls12]),
        ];
        for side in [Side::Client, Side::Server] {
            for &versions in version_sets {
                // SAFETY: wolfSSL has been initialised above via `ensure_init`.
                let ptr = unsafe { resolve_method(side, versions) }.unwrap_or_else(|e| {
                    panic!("resolve_method({side:?}, {versions:?}) should succeed: {e:?}")
                });
                assert!(
                    !ptr.is_null(),
                    "resolve_method({side:?}, {versions:?}) returned null"
                );
            }
        }
    }

    #[test]
    fn cint_max_is_positive_and_usable_as_len() {
        // Mirrors the length clamping done by the generated Read/Write impls.
        let max = core::ffi::c_int::MAX as usize;
        assert!(max >= 32767, "c_int::MAX too small: {max}");
        let clamped = std::cmp::min(usize::MAX, max);
        assert_eq!(clamped, max);
        assert_eq!(clamped as core::ffi::c_int, core::ffi::c_int::MAX);
    }
}