use std::ffi::c_void;
use std::io::{Read, Write};
use wolfcrypt_sys::*;
use crate::callback::{io_recv_shim, io_send_shim, IOCallbacks};
use crate::config::TlsClientConfig;
use crate::error::{Result, TlsError};
use crate::SslGuard;
/// A client-side TLS session driven by wolfSSL over caller-supplied I/O callbacks.
///
/// Built by [`TlsClient::new`], which runs the handshake; application data is then
/// exchanged through the [`std::io::Read`] and [`std::io::Write`] impls.
pub struct TlsClient<IOCB: IOCallbacks> {
    /// The wolfSSL session; released with `wolfSSL_free` in `Drop`.
    ssl: *mut WOLFSSL,
    // Never read directly, but it must stay alive: a raw pointer into this box is
    // registered with wolfSSL as the I/O read/write context. The box keeps that
    // address stable for as long as `ssl` exists.
    #[allow(dead_code)]
    io: Box<IOCB>,
    // Never read after construction. NOTE(review): presumably held so the
    // WOLFSSL_CTX behind `config.inner.ctx` (used by `wolfSSL_new`) outlives
    // `ssl` — confirm against TlsClientConfig's Drop.
    #[allow(dead_code)]
    config: TlsClientConfig,
}
impl<IOCB: IOCallbacks> std::fmt::Debug for TlsClient<IOCB> {
    /// Shows only the raw session pointer; the I/O callbacks and config are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TlsClient");
        builder.field("ssl", &self.ssl);
        builder.finish()
    }
}
// SAFETY: raw pointers are `!Send`, so the `ssl` field alone would keep
// `TlsClient` from being sent. The session is created in `TlsClient::new` and
// freed only by this value's `Drop`. The `Read`/`Write` impls use it only through
// `&mut self`, so moving the whole client to another thread does not create
// shared access. `IOCB: Send` is required because the boxed callbacks move along
// with the session that points at them.
// NOTE(review): this also assumes a wolfSSL session has no thread affinity, and
// that `TlsClientConfig` (and the WOLFSSL_CTX behind it) is safe to move across
// threads — confirm both.
unsafe impl<IOCB: IOCallbacks + Send> Send for TlsClient<IOCB> {}
impl<IOCB: IOCallbacks> TlsClient<IOCB> {
pub fn new(config: TlsClientConfig, server_name: &str, io: IOCB) -> Result<Self> {
if server_name.len() > 253 {
return Err(TlsError::InvalidConfig(
"server name exceeds maximum DNS hostname length (253 bytes)",
));
}
let ssl = unsafe { wolfSSL_new(config.inner.ctx) };
if ssl.is_null() {
return Err(TlsError::AllocFailed { func: "wolfSSL_new" });
}
let guard = SslGuard(ssl);
if !server_name.is_empty() {
let ret = unsafe {
wolfSSL_UseSNI(
guard.as_ptr(),
WOLFSSL_SNI_HOST_NAME as core::ffi::c_uchar,
server_name.as_ptr() as *const core::ffi::c_void,
server_name.len() as u16,
)
};
if ret != WOLFSSL_SUCCESS as core::ffi::c_int {
return Err(TlsError::Ffi {
code: ret,
func: "wolfSSL_UseSNI",
});
}
}
let mut io = Box::new(io);
unsafe {
wolfSSL_SSLSetIORecv(guard.as_ptr(), Some(io_recv_shim::<IOCB>));
wolfSSL_SSLSetIOSend(guard.as_ptr(), Some(io_send_shim::<IOCB>));
let ctx = &mut *io as *mut IOCB as *mut c_void;
wolfSSL_SetIOReadCtx(guard.as_ptr(), ctx);
wolfSSL_SetIOWriteCtx(guard.as_ptr(), ctx);
}
let ret = unsafe { wolfSSL_connect(guard.as_ptr()) };
if ret != WOLFSSL_SUCCESS as core::ffi::c_int {
let (err, verify_result) = unsafe {
let e = wolfSSL_get_error(guard.as_ptr(), ret);
let v = wolfSSL_get_verify_result(guard.as_ptr());
(e, v)
};
drop(guard);
if verify_result != X509_V_OK as core::ffi::c_long {
let reason = crate::error::verify_error_string(verify_result);
return Err(TlsError::CertificateVerification(format!(
"{reason} (X509 error {verify_result})"
)));
}
return Err(TlsError::Ffi {
code: err,
func: "wolfSSL_connect",
});
}
Ok(TlsClient {
ssl: guard.into_raw(),
io,
config,
})
}
pub unsafe fn as_raw_ssl(&self) -> *mut WOLFSSL {
self.ssl
}
}
/// `WOLFSSL_ERROR_WANT_READ` narrowed to `c_int`, so it can be used as a `match`
/// pattern against the value returned by `wolfSSL_get_error`.
const WANT_READ: std::ffi::c_int = WOLFSSL_ERROR_WANT_READ as _;
/// `WOLFSSL_ERROR_WANT_WRITE` narrowed to `c_int`, for the same `match` use.
const WANT_WRITE: std::ffi::c_int = WOLFSSL_ERROR_WANT_WRITE as _;
impl<IOCB: IOCallbacks> Read for TlsClient<IOCB> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
let len = buf.len().min(core::ffi::c_int::MAX as usize) as core::ffi::c_int;
let ret = unsafe {
wolfSSL_read(self.ssl, buf.as_mut_ptr() as *mut core::ffi::c_void, len)
};
if ret > 0 {
Ok(ret as usize)
} else if ret == 0 {
Ok(0)
} else {
let err = unsafe { wolfSSL_get_error(self.ssl, ret) };
match err {
WANT_READ | WANT_WRITE => {
Err(std::io::Error::from(std::io::ErrorKind::WouldBlock))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!(
"wolfSSL_read: {} (error {err})",
crate::error::error_string(err)
),
)),
}
}
}
}
impl<IOCB: IOCallbacks> Write for TlsClient<IOCB> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
let len = buf.len().min(core::ffi::c_int::MAX as usize) as core::ffi::c_int;
let ret = unsafe {
wolfSSL_write(self.ssl, buf.as_ptr() as *const core::ffi::c_void, len)
};
if ret > 0 {
Ok(ret as usize)
} else if ret == 0 {
Err(std::io::Error::from(std::io::ErrorKind::WouldBlock))
} else {
let err = unsafe { wolfSSL_get_error(self.ssl, ret) };
match err {
WANT_READ | WANT_WRITE => {
Err(std::io::Error::from(std::io::ErrorKind::WouldBlock))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!(
"wolfSSL_write: {} (error {err})",
crate::error::error_string(err)
),
)),
}
}
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
impl<IOCB: IOCallbacks> Drop for TlsClient<IOCB> {
    /// Attempts a TLS shutdown, then frees the wolfSSL session.
    ///
    /// The shutdown result is discarded because `drop` cannot report errors. This
    /// body runs before the struct's fields are dropped, so the boxed I/O callbacks
    /// are still alive for any I/O `wolfSSL_shutdown` performs.
    fn drop(&mut self) {
        let session = self.ssl;
        // SAFETY: `session` came from `wolfSSL_new` in `TlsClient::new`, is owned by
        // `self`, and is freed exactly once, here.
        unsafe {
            wolfSSL_shutdown(session);
            wolfSSL_free(session);
        }
    }
}