use crate::{
config::Config,
error::{LastError, Result},
*,
};
#[cfg(libressl_3_1_0)]
use std::convert::TryFrom;
use std::{
convert::TryInto,
ffi::{CStr, CString},
io,
net::ToSocketAddrs,
os::raw::c_void,
os::unix::io::{AsRawFd, RawFd},
slice,
time::SystemTime,
};
/// Converts the `Result<isize, _>` returned by one of the `Tls::tls_*` calls into
/// an `io::Result<usize>`.
///
/// * `Err(e)` becomes an `io::ErrorKind::Other` error wrapping `e`.
/// * `Ok(TLS_WANT_POLLIN)` / `Ok(TLS_WANT_POLLOUT)` become
///   `io::ErrorKind::WouldBlock`, so the caller can poll and retry.
/// * Any other `Ok(n)` becomes `Ok(n as usize)`.
///
/// The first argument is accepted for call-site compatibility and is not used.
/// `std::io` is referenced by absolute path so invoking modules need not import it;
/// `TLS_WANT_POLLIN` / `TLS_WANT_POLLOUT` are still resolved at the call site.
#[macro_export]
macro_rules! try_tls {
    ($self: expr, $call: expr) => {
        match $call {
            Err(err) => Err(::std::io::Error::new(::std::io::ErrorKind::Other, err)),
            Ok(size) => {
                if size == TLS_WANT_POLLIN as isize || size == TLS_WANT_POLLOUT as isize {
                    Err(::std::io::Error::new(
                        ::std::io::ErrorKind::WouldBlock,
                        ::std::io::Error::last_os_error(),
                    ))
                } else {
                    Ok(size as usize)
                }
            }
        }
    };
}
/// Owned handle to a libtls context; the context is closed and freed on drop.
#[derive(Debug)]
pub struct Tls(
    /// Raw libtls context pointer.
    *mut libtls_sys::tls,
    /// Raw file descriptor stored alongside the context; `-1` when none is stored.
    RawFd,
);
impl Tls {
    /// Allocates a new libtls context with `f` (`tls_client` or `tls_server`).
    ///
    /// # Errors
    ///
    /// Returns the last OS error if `f` returns a null pointer.
    fn new(f: unsafe extern "C" fn() -> *mut libtls_sys::tls) -> io::Result<Self> {
        let tls = unsafe { f() };
        if tls.is_null() {
            Err(io::Error::last_os_error())
        } else {
            Ok(Tls(tls, -1))
        }
    }

    /// Creates a new client context (`tls_client`).
    pub fn client() -> io::Result<Self> {
        Self::new(libtls_sys::tls_client)
    }

    /// Creates a new server context (`tls_server`).
    pub fn server() -> io::Result<Self> {
        Self::new(libtls_sys::tls_server)
    }

    /// Applies `config` to this context (`tls_configure`).
    pub fn configure(&mut self, config: &Config) -> Result<()> {
        cvt(self, unsafe { libtls_sys::tls_configure(self.0, config.0) })
    }

    /// Takes ownership of a raw libtls context pointer.
    ///
    /// The returned value's stored raw fd is `-1`.
    ///
    /// # Safety
    ///
    /// `tls` must be a valid libtls context not owned by anything else: it is
    /// closed and freed when the returned value is dropped.
    ///
    /// # Panics
    ///
    /// Panics with the last OS error if `tls` is null.
    pub unsafe fn from_sys(tls: *mut libtls_sys::tls) -> Self {
        if tls.is_null() {
            // `panic!` needs a format string; a bare non-string payload is
            // rejected by the 2021 edition.
            panic!("{}", io::Error::last_os_error())
        }
        Tls(tls, -1)
    }

    /// Resets this context via `tls_reset`.
    pub fn reset(&mut self) {
        unsafe { libtls_sys::tls_reset(self.0) };
    }

    /// Accepts a connection on this server context using separate read and
    /// write descriptors (`tls_accept_fds`).
    ///
    /// Returns a new context for the connection; its stored raw fd is `-1`.
    pub fn accept_fds(&mut self, fd_read: RawFd, fd_write: RawFd) -> Result<Tls> {
        // libtls writes the new connection context through this out-pointer.
        // Pre-allocating one (e.g. with `tls_client`) would just be overwritten
        // and leaked, so start from null and only wrap it after success.
        let mut cctx: *mut libtls_sys::tls = std::ptr::null_mut();
        cvt(self, unsafe {
            libtls_sys::tls_accept_fds(self.0, &mut cctx, fd_read, fd_write)
        })?;
        // SAFETY: on success libtls handed back a connection context that
        // nothing else owns.
        Ok(unsafe { Tls::from_sys(cctx) })
    }

    /// Accepts a connection on this server context over `socket`
    /// (`tls_accept_socket`).
    ///
    /// The returned connection context records `socket` as its raw fd.
    pub fn accept_socket(&mut self, socket: RawFd) -> Result<Tls> {
        // See `accept_fds`: libtls allocates the connection context itself.
        let mut cctx: *mut libtls_sys::tls = std::ptr::null_mut();
        cvt(self, unsafe {
            libtls_sys::tls_accept_socket(self.0, &mut cctx, socket)
        })?;
        // SAFETY: on success libtls handed back a connection context that
        // nothing else owns.
        let mut conn = unsafe { Tls::from_sys(cctx) };
        // The accepted socket belongs to the new connection, not to the
        // listening server context.
        conn.1 = socket;
        Ok(conn)
    }

    /// Accepts a connection over any value that exposes a raw descriptor.
    pub fn accept_raw_fd<T>(&mut self, raw_fd: &T) -> Result<Tls>
    where
        T: AsRawFd,
    {
        self.accept_socket(raw_fd.as_raw_fd())
    }

    /// Accepts a connection whose I/O goes through `read_cb` / `write_cb`
    /// (`tls_accept_cbs`).
    ///
    /// `cb_arg` defaults to a null pointer when `None`.
    ///
    /// # Safety
    ///
    /// libtls is given `read_cb`, `write_cb` and `cb_arg`; the caller must ensure
    /// the callbacks are sound to invoke with `cb_arg` for as long as the
    /// returned context is in use.
    pub unsafe fn accept_cbs(
        &mut self,
        read_cb: TlsReadCb,
        write_cb: TlsWriteCb,
        cb_arg: Option<*mut c_void>,
    ) -> Result<Tls> {
        let cb_arg = cb_arg.unwrap_or(std::ptr::null_mut());
        // See `accept_fds`: libtls allocates the connection context itself.
        let mut cctx: *mut libtls_sys::tls = std::ptr::null_mut();
        cvt(
            self,
            libtls_sys::tls_accept_cbs(self.0, &mut cctx, read_cb, write_cb, cb_arg),
        )?;
        Ok(Tls::from_sys(cctx))
    }

    /// Connects to `host`, optionally on `port` (`tls_connect`).
    ///
    /// When `port` is `None`, a null port pointer is passed to libtls.
    pub fn connect(&mut self, host: &str, port: Option<&str>) -> Result<()> {
        let c_host = CString::new(host)?;
        // Keep the optional CString alive until after the FFI call.
        let c_port = port.map(CString::new).transpose()?;
        let port_ptr = c_port
            .as_ref()
            .map_or(std::ptr::null(), |p| p.as_ptr());
        cvt(self, unsafe {
            libtls_sys::tls_connect(self.0, c_host.as_ptr(), port_ptr)
        })
    }

    /// Starts a client connection over separate read/write descriptors
    /// (`tls_connect_fds`).
    pub fn connect_fds(&mut self, fd_read: RawFd, fd_write: RawFd, servername: &str) -> Result<()> {
        let c_servername = CString::new(servername)?;
        cvt(self, unsafe {
            libtls_sys::tls_connect_fds(self.0, fd_read, fd_write, c_servername.as_ptr())
        })
    }

    /// Resolves `host` and tries `tls_connect_servername` on each address in turn,
    /// using `servername` as the TLS server name.
    ///
    /// # Errors
    ///
    /// Returns the error from the last attempted address, or a context error
    /// ("no address to connect to") when `host` resolves to no addresses.
    pub fn connect_servername<A: ToSocketAddrs>(
        &mut self,
        host: A,
        servername: &str,
    ) -> Result<()> {
        let addrs = host.to_socket_addrs()?;
        // Encode the server name once rather than once per address.
        let c_servername = CString::new(servername)?;
        let mut last_error = Self::to_error("no address to connect to".to_owned());
        for addr in addrs {
            // `SocketAddr` formats as "host:port" (IPv6 as "[addr]:port"), and
            // the port pointer below is null.
            let c_host = CString::new(addr.to_string())?;
            let attempt = cvt(self, unsafe {
                libtls_sys::tls_connect_servername(
                    self.0,
                    c_host.as_ptr(),
                    std::ptr::null(),
                    c_servername.as_ptr(),
                )
            });
            match attempt {
                Ok(()) => return Ok(()),
                Err(err) => last_error = Err(err),
            }
        }
        last_error
    }

    /// Starts a client connection over `socket` (`tls_connect_socket`).
    ///
    /// On success, `socket` is recorded as this context's raw fd.
    pub fn connect_socket(&mut self, socket: RawFd, servername: &str) -> Result<()> {
        let c_servername = CString::new(servername)?;
        cvt(self, unsafe {
            libtls_sys::tls_connect_socket(self.0, socket, c_servername.as_ptr())
        })?;
        self.1 = socket;
        Ok(())
    }

    /// Starts a client connection over any value that exposes a raw descriptor.
    pub fn connect_raw_fd<T>(&mut self, raw_fd: &T, servername: &str) -> Result<()>
    where
        T: AsRawFd,
    {
        self.connect_socket(raw_fd.as_raw_fd(), servername)
    }

    /// Starts a client connection whose I/O goes through `read_cb` / `write_cb`
    /// (`tls_connect_cbs`).
    ///
    /// `cb_arg` defaults to a null pointer when `None`.
    ///
    /// # Safety
    ///
    /// libtls is given `read_cb`, `write_cb` and `cb_arg`; the caller must ensure
    /// the callbacks are sound to invoke with `cb_arg` for as long as this
    /// context is in use.
    pub unsafe fn connect_cbs(
        &mut self,
        read_cb: TlsReadCb,
        write_cb: TlsWriteCb,
        cb_arg: Option<*mut c_void>,
        servername: &str,
    ) -> Result<()> {
        let c_servername = CString::new(servername)?;
        let cb_arg = cb_arg.unwrap_or(std::ptr::null_mut());
        cvt(
            self,
            libtls_sys::tls_connect_cbs(self.0, read_cb, write_cb, cb_arg, c_servername.as_ptr()),
        )
    }

    /// Runs the TLS handshake (`tls_handshake`), returning the result after
    /// `cvt_err`.
    pub fn tls_handshake(&mut self) -> error::Result<isize> {
        cvt_err(self, unsafe { libtls_sys::tls_handshake(self.0) as isize })
    }

    /// Reads into `buf` (`tls_read`), returning the result after `cvt_err`.
    ///
    /// The value may be a `TLS_WANT_POLLIN` / `TLS_WANT_POLLOUT` code; the
    /// `io::Read` impl maps those to `WouldBlock` via `try_tls!`.
    pub fn tls_read(&mut self, buf: &mut [u8]) -> error::Result<isize> {
        cvt_err(self, unsafe {
            libtls_sys::tls_read(
                self.0,
                buf.as_mut_ptr() as *mut c_void,
                buf.len().try_into()?,
            )
            .try_into()?
        })
    }

    /// Writes `buf` (`tls_write`), returning the result after `cvt_err`.
    ///
    /// The value may be a `TLS_WANT_POLLIN` / `TLS_WANT_POLLOUT` code; the
    /// `io::Write` impl maps those to `WouldBlock` via `try_tls!`.
    pub fn tls_write(&mut self, buf: &[u8]) -> error::Result<isize> {
        cvt_err(self, unsafe {
            libtls_sys::tls_write(self.0, buf.as_ptr() as *const c_void, buf.len().try_into()?)
                .try_into()?
        })
    }

    /// Closes the connection (`tls_close`), returning the result after `cvt_err`.
    pub fn tls_close(&mut self) -> error::Result<isize> {
        cvt_err(self, unsafe { libtls_sys::tls_close(self.0) as isize })
    }

    /// Closes the connection, mapping `TLS_WANT_*` results to `WouldBlock`.
    pub fn close(&mut self) -> io::Result<()> {
        try_tls!(self, self.tls_close()).map(|_| ())
    }

    /// Returns whether `tls_peer_cert_provided` reports a peer certificate.
    pub fn peer_cert_provided(&mut self) -> bool {
        unsafe { libtls_sys::tls_peer_cert_provided(self.0) != 0 }
    }

    /// Returns whether `tls_peer_cert_contains_name` reports a match for `name`.
    ///
    /// # Errors
    ///
    /// Fails if `name` contains an interior NUL byte.
    pub fn peer_cert_contains_name(&mut self, name: &str) -> Result<bool> {
        let c_name = CString::new(name)?;
        Ok(unsafe { libtls_sys::tls_peer_cert_contains_name(self.0, c_name.as_ptr()) } != 0)
    }

    /// Returns the peer certificate hash (`tls_peer_cert_hash`).
    pub fn peer_cert_hash(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_peer_cert_hash(self.0)) }
    }

    /// Returns the peer certificate issuer (`tls_peer_cert_issuer`).
    pub fn peer_cert_issuer(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_peer_cert_issuer(self.0)) }
    }

    /// Returns the peer certificate subject (`tls_peer_cert_subject`).
    pub fn peer_cert_subject(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_peer_cert_subject(self.0)) }
    }

    /// Returns the peer certificate's not-before time (`tls_peer_cert_notbefore`).
    pub fn peer_cert_notbefore(&mut self) -> error::Result<SystemTime> {
        cvt_time(self, unsafe { libtls_sys::tls_peer_cert_notbefore(self.0) })
    }

    /// Returns the peer certificate's not-after time (`tls_peer_cert_notafter`).
    pub fn peer_cert_notafter(&mut self) -> error::Result<SystemTime> {
        cvt_time(self, unsafe { libtls_sys::tls_peer_cert_notafter(self.0) })
    }

    /// Returns a copy of the peer certificate chain in PEM form
    /// (`tls_peer_cert_chain_pem`).
    ///
    /// # Errors
    ///
    /// When libtls returns a null pointer, yields a context error carrying the
    /// context's last error message, or "no error" if none is available.
    pub fn peer_cert_chain_pem(&mut self) -> error::Result<Vec<u8>> {
        let mut size = 0;
        let ptr = unsafe { libtls_sys::tls_peer_cert_chain_pem(self.0, &mut size) };
        if ptr.is_null() {
            let errstr = self.last_error().unwrap_or_else(|_| "no error".to_string());
            return Self::to_error(errstr);
        }
        let len = size.try_into()?;
        // SAFETY: libtls returned a non-null buffer and reported its length
        // through `size`; it is copied out immediately.
        let data = unsafe { slice::from_raw_parts(ptr, len) };
        Ok(data.to_vec())
    }

    /// Returns the negotiated ALPN protocol (`tls_conn_alpn_selected`), or `None`
    /// when libtls returns a null pointer. Invalid UTF-8 is replaced lossily.
    pub fn conn_alpn_selected(&mut self) -> Option<String> {
        let ptr = unsafe { libtls_sys::tls_conn_alpn_selected(self.0) };
        if ptr.is_null() {
            return None;
        }
        // Convert straight from the borrowed CStr; no intermediate CString copy
        // is needed.
        let name = unsafe { CStr::from_ptr(ptr) };
        Some(name.to_string_lossy().into_owned())
    }

    /// Returns the negotiated cipher (`tls_conn_cipher`).
    pub fn conn_cipher(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_conn_cipher(self.0)) }
    }

    /// Returns the negotiated cipher strength (`tls_conn_cipher_strength`).
    ///
    /// # Errors
    ///
    /// Fails if `cvt_err` reports an error or the value does not fit in `usize`.
    #[cfg(libressl_3_1_0)]
    pub fn conn_cipher_strength(&mut self) -> error::Result<usize> {
        cvt_err(self, unsafe {
            libtls_sys::tls_conn_cipher_strength(self.0) as isize
        })
        .and_then(|retval| usize::try_from(retval).map_err(Into::into))
    }

    /// Returns the connection's server name (`tls_conn_servername`).
    pub fn conn_servername(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_conn_servername(self.0)) }
    }

    /// Returns whether `tls_conn_session_resumed` reports a resumed session.
    pub fn conn_session_resumed(&mut self) -> bool {
        unsafe { libtls_sys::tls_conn_session_resumed(self.0) != 0 }
    }

    /// Returns the negotiated protocol version (`tls_conn_version`).
    pub fn conn_version(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_conn_version(self.0)) }
    }

    /// Passes a raw OCSP `response` to libtls (`tls_ocsp_process_response`).
    pub fn ocsp_process_response(&mut self, response: &[u8]) -> error::Result<()> {
        cvt(self, unsafe {
            libtls_sys::tls_ocsp_process_response(
                self.0,
                response.as_ptr(),
                response.len().try_into()?,
            )
        })
    }

    /// Returns the OCSP certificate status (`tls_peer_ocsp_cert_status`).
    pub fn peer_ocsp_cert_status(&mut self) -> error::Result<isize> {
        cvt_err(self, unsafe {
            libtls_sys::tls_peer_ocsp_cert_status(self.0) as isize
        })
    }

    /// Returns the OCSP CRL reason (`tls_peer_ocsp_crl_reason`).
    pub fn peer_ocsp_crl_reason(&mut self) -> error::Result<isize> {
        cvt_err(self, unsafe {
            libtls_sys::tls_peer_ocsp_crl_reason(self.0) as isize
        })
    }

    /// Returns the OCSP next-update time (`tls_peer_ocsp_next_update`).
    pub fn peer_ocsp_next_update(&mut self) -> error::Result<SystemTime> {
        cvt_time(self, unsafe {
            libtls_sys::tls_peer_ocsp_next_update(self.0)
        })
    }

    /// Returns the OCSP response status (`tls_peer_ocsp_response_status`).
    pub fn peer_ocsp_response_status(&mut self) -> error::Result<isize> {
        cvt_err(self, unsafe {
            libtls_sys::tls_peer_ocsp_response_status(self.0) as isize
        })
    }

    /// Returns the OCSP result string (`tls_peer_ocsp_result`).
    pub fn peer_ocsp_result(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_peer_ocsp_result(self.0)) }
    }

    /// Returns the OCSP revocation time (`tls_peer_ocsp_revocation_time`).
    pub fn peer_ocsp_revocation_time(&mut self) -> error::Result<SystemTime> {
        cvt_time(self, unsafe {
            libtls_sys::tls_peer_ocsp_revocation_time(self.0)
        })
    }

    /// Returns the OCSP this-update time (`tls_peer_ocsp_this_update`).
    pub fn peer_ocsp_this_update(&mut self) -> error::Result<SystemTime> {
        cvt_time(self, unsafe {
            libtls_sys::tls_peer_ocsp_this_update(self.0)
        })
    }

    /// Returns the OCSP responder URL (`tls_peer_ocsp_url`).
    pub fn peer_ocsp_url(&mut self) -> error::Result<String> {
        unsafe { cvt_string(self, libtls_sys::tls_peer_ocsp_url(self.0)) }
    }
}
impl LastError for Tls {
fn last_error(&self) -> error::Result<String> {
unsafe { cvt_no_error(libtls_sys::tls_error(self.0)) }
}
fn to_error<T>(errstr: String) -> error::Result<T> {
Err(error::Error::CtxError(errstr))
}
}
impl AsRawFd for Tls {
fn as_raw_fd(&self) -> RawFd {
self.1
}
}
impl Drop for Tls {
    /// Closes the connection and then frees the libtls context.
    ///
    /// `tls_close` is called repeatedly for as long as it returns
    /// `TLS_WANT_POLLIN` or `TLS_WANT_POLLOUT`.
    fn drop(&mut self) {
        // A null context has nothing to close or free. Never hand NULL to
        // libtls; such a value could previously arise when an out-parameter
        // call stored NULL into the pointer.
        if self.0.is_null() {
            return;
        }
        unsafe {
            // NOTE(review): this retries without polling, so on a non-blocking
            // socket it spins until libtls stops asking for more I/O.
            loop {
                let ret = libtls_sys::tls_close(self.0);
                if !(ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT) {
                    break;
                }
            }
            libtls_sys::tls_free(self.0);
        };
    }
}
impl io::Read for Tls {
    /// Reads via `Tls::tls_read`.
    ///
    /// `try_tls!` maps `TLS_WANT_POLLIN` / `TLS_WANT_POLLOUT` to
    /// `io::ErrorKind::WouldBlock` and other failures to `io::ErrorKind::Other`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let outcome = self.tls_read(buf);
        try_tls!(self, outcome)
    }
}
impl io::Write for Tls {
    /// Writes via `Tls::tls_write`.
    ///
    /// `try_tls!` maps `TLS_WANT_POLLIN` / `TLS_WANT_POLLOUT` to
    /// `io::ErrorKind::WouldBlock` and other failures to `io::ErrorKind::Other`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let outcome = self.tls_write(buf);
        try_tls!(self, outcome)
    }

    /// No-op: this wrapper keeps no write buffer of its own.
    fn flush(&mut self) -> io::Result<()> {
        let nothing_buffered = ();
        Ok(nothing_buffered)
    }
}
/// Read callback type passed to `Tls::accept_cbs` / `Tls::connect_cbs`
/// (alias of `libtls_sys::tls_read_cb`).
pub type TlsReadCb = libtls_sys::tls_read_cb;
/// Write callback type passed to `Tls::accept_cbs` / `Tls::connect_cbs`
/// (alias of `libtls_sys::tls_write_cb`).
pub type TlsWriteCb = libtls_sys::tls_write_cb;

// SAFETY: the context pointer is owned by this handle (it is freed in `Drop`).
// NOTE(review): assumes libtls contexts have no thread affinity — confirm.
unsafe impl Send for Tls {}
// SAFETY: in this file, methods that call into libtls on the context take
// `&mut self`; the `&self` paths are `last_error` (`tls_error`) and
// `as_raw_fd`. NOTE(review): assumes `tls_error` only reads the context.
unsafe impl Sync for Tls {}