use crate::bio::Bio;
use crate::error::ErrorStack;
use crate::pkey::{HasPrivate, Pkey};
use crate::x509::X509;
use native_ossl_sys as sys;
use std::ffi::CStr;
/// TLS protocol versions that can be used as bounds with
/// [`SslCtx::set_min_proto_version`] / [`SslCtx::set_max_proto_version`].
///
/// Each discriminant is the numeric version value that is passed to
/// OpenSSL's protocol-version controls unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVersion {
    /// TLS 1.2, numeric version `0x0303`.
    Tls12 = 0x0303,
    /// TLS 1.3, numeric version `0x0304`.
    Tls13 = 0x0304,
}
/// Peer-verification flags passed to [`SslCtx::set_verify`].
///
/// Wraps OpenSSL's `SSL_VERIFY_*` bit mask. Combine flags with
/// [`SslVerifyMode::or`] or the `|` operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SslVerifyMode(i32);

impl SslVerifyMode {
    /// `SSL_VERIFY_NONE`: no peer-certificate verification.
    pub const NONE: Self = SslVerifyMode(0x00);
    /// `SSL_VERIFY_PEER`: verify the peer's certificate.
    pub const PEER: Self = SslVerifyMode(0x01);
    /// `SSL_VERIFY_FAIL_IF_NO_PEER_CERT`: (server) abort the handshake if the
    /// client presents no certificate. Only meaningful together with
    /// [`SslVerifyMode::PEER`].
    pub const FAIL_IF_NO_PEER_CERT: Self = SslVerifyMode(0x02);
    /// `SSL_VERIFY_CLIENT_ONCE`: (server) request a client certificate only
    /// on the initial handshake. Use together with [`SslVerifyMode::PEER`].
    pub const CLIENT_ONCE: Self = SslVerifyMode(0x04);
    /// `SSL_VERIFY_POST_HANDSHAKE`: (TLS 1.3 server) defer the client
    /// certificate request to post-handshake authentication. Use together
    /// with [`SslVerifyMode::PEER`].
    pub const POST_HANDSHAKE: Self = SslVerifyMode(0x08);

    /// Returns the union of `self` and `other`.
    ///
    /// This is a `const fn`, so combined modes can be defined as constants.
    #[must_use]
    pub const fn or(self, other: Self) -> Self {
        SslVerifyMode(self.0 | other.0)
    }

    /// Returns the raw `SSL_VERIFY_*` bit mask.
    #[must_use]
    pub const fn bits(self) -> i32 {
        self.0
    }
}

impl std::ops::BitOr for SslVerifyMode {
    type Output = Self;

    /// Same as [`SslVerifyMode::or`].
    fn bitor(self, rhs: Self) -> Self {
        self.or(rhs)
    }
}
/// Failure outcome of a handshake or I/O call on an [`Ssl`] connection.
///
/// `Ssl` produces these by classifying the return code with
/// `SSL_get_error`.
#[derive(Debug)]
pub enum SslIoError {
    /// More inbound data is needed before the call can make progress.
    WantRead,
    /// Outbound data must be flushed before the call can make progress.
    WantWrite,
    /// The peer closed the TLS connection.
    ZeroReturn,
    /// A system-call/transport level failure, with any queued OpenSSL errors.
    Syscall(ErrorStack),
    /// A library/protocol failure, with the queued OpenSSL errors.
    Ssl(ErrorStack),
    /// An error code not mapped to another variant, used when no OpenSSL
    /// errors were queued.
    Other(i32),
}

impl std::fmt::Display for SslIoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Payload-carrying variants format directly; the rest share one
        // static-message path.
        let message = match self {
            Self::WantRead => "SSL want read",
            Self::WantWrite => "SSL want write",
            Self::ZeroReturn => "SSL zero return (peer closed)",
            Self::Syscall(stack) => return write!(f, "SSL syscall error: {stack}"),
            Self::Ssl(stack) => return write!(f, "SSL error: {stack}"),
            Self::Other(code) => return write!(f, "SSL error code {code}"),
        };
        f.write_str(message)
    }
}

impl std::error::Error for SslIoError {}
/// Successful outcome of [`Ssl::shutdown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownResult {
    /// `SSL_shutdown` returned 0: our close_notify went out, but the peer's
    /// has not been received yet.
    Sent,
    /// `SSL_shutdown` returned 1: the bidirectional shutdown is finished.
    Complete,
}
/// Reference-counted handle to an OpenSSL `SSL_SESSION`.
///
/// Obtained from [`Ssl::get1_session`] and reusable for resumption via
/// [`Ssl::set_session`]. Each handle owns exactly one reference.
pub struct SslSession {
    ptr: *mut sys::SSL_SESSION,
}

// SAFETY: the handle only touches the session through OpenSSL's
// reference-counting and lookup APIs. NOTE(review): relies on OpenSSL's
// refcount being thread-safe — confirm for the linked OpenSSL build.
unsafe impl Send for SslSession {}
unsafe impl Sync for SslSession {}

impl Clone for SslSession {
    /// Produces another handle to the same session by taking an extra reference.
    fn clone(&self) -> Self {
        let shared = self.ptr;
        // SAFETY: `shared` stays live while `self` holds its reference.
        unsafe { sys::SSL_SESSION_up_ref(shared) };
        Self { ptr: shared }
    }
}

impl Drop for SslSession {
    /// Releases the single reference owned by this handle.
    fn drop(&mut self) {
        // SAFETY: `self.ptr` carries exactly one reference owned by `self`.
        unsafe { sys::SSL_SESSION_free(self.ptr) };
    }
}
/// Reference-counted handle to an OpenSSL `SSL_CTX`.
///
/// Cloning shares the same underlying context. Each handle owns one
/// reference, which is released on drop.
pub struct SslCtx {
    ptr: *mut sys::SSL_CTX,
}

// SAFETY: the raw pointer is only used through OpenSSL calls. NOTE(review):
// `Sync` combined with the `&self` configuration setters assumes callers do
// not reconfigure a context concurrently — confirm.
unsafe impl Send for SslCtx {}
unsafe impl Sync for SslCtx {}

impl Clone for SslCtx {
    /// Returns another handle to the same context, bumping its refcount.
    fn clone(&self) -> Self {
        let shared = self.ptr;
        // SAFETY: `shared` is live while `self` holds its reference.
        unsafe { sys::SSL_CTX_up_ref(shared) };
        Self { ptr: shared }
    }
}

impl Drop for SslCtx {
    /// Releases the one reference held by this handle.
    fn drop(&mut self) {
        // SAFETY: `self.ptr` carries exactly one reference owned by `self`.
        unsafe { sys::SSL_CTX_free(self.ptr) };
    }
}
/// Construction and configuration of TLS contexts.
///
/// NOTE(review): the setters take `&self` while `SslCtx` is `Sync` and
/// `Clone` shares one `SSL_CTX`. Configure a context fully before sharing
/// it across threads; OpenSSL is not expected to synchronise concurrent
/// reconfiguration.
impl SslCtx {
    // `SSL_CTX_ctrl` commands/arguments from OpenSSL's `ssl.h`. The C helpers
    // that wrap them (`SSL_CTX_set_min_proto_version`,
    // `SSL_CTX_set_session_cache_mode`, ...) are preprocessor macros, so they
    // are issued through `SSL_CTX_ctrl` directly.
    const SSL_CTRL_SET_SESS_CACHE_MODE: std::ffi::c_int = 44;
    const SSL_CTRL_SET_MIN_PROTO_VERSION: std::ffi::c_int = 123;
    const SSL_CTRL_SET_MAX_PROTO_VERSION: std::ffi::c_int = 124;
    const SSL_SESS_CACHE_OFF: std::ffi::c_long = 0;

    /// Creates a context with the version-flexible `TLS_method`.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if `SSL_CTX_new` fails.
    pub fn new() -> Result<Self, ErrorStack> {
        // SAFETY: `TLS_method` returns a static method table; `SSL_CTX_new`
        // returns an owned context or null.
        Self::wrap_new(unsafe { sys::SSL_CTX_new(sys::TLS_method()) })
    }

    /// Creates a client-side context (`TLS_client_method`).
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if `SSL_CTX_new` fails.
    pub fn new_client() -> Result<Self, ErrorStack> {
        // SAFETY: see `new`.
        Self::wrap_new(unsafe { sys::SSL_CTX_new(sys::TLS_client_method()) })
    }

    /// Creates a server-side context (`TLS_server_method`).
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if `SSL_CTX_new` fails.
    pub fn new_server() -> Result<Self, ErrorStack> {
        // SAFETY: see `new`.
        Self::wrap_new(unsafe { sys::SSL_CTX_new(sys::TLS_server_method()) })
    }

    /// Sets the lowest protocol version this context will negotiate.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if the control call does not return 1.
    pub fn set_min_proto_version(&self, ver: TlsVersion) -> Result<(), ErrorStack> {
        self.ctrl_proto_version(Self::SSL_CTRL_SET_MIN_PROTO_VERSION, ver)
    }

    /// Sets the highest protocol version this context will negotiate.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if the control call does not return 1.
    pub fn set_max_proto_version(&self, ver: TlsVersion) -> Result<(), ErrorStack> {
        self.ctrl_proto_version(Self::SSL_CTRL_SET_MAX_PROTO_VERSION, ver)
    }

    /// Sets the peer-verification mode, with no custom verify callback.
    pub fn set_verify(&self, mode: SslVerifyMode) {
        // SAFETY: `self.ptr` is a live context; `None` selects the default callback.
        unsafe { sys::SSL_CTX_set_verify(self.ptr, mode.0, None) };
    }

    /// Sets the cipher list used for TLS 1.2 and earlier.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if OpenSSL rejects `list`.
    pub fn set_cipher_list(&self, list: &CStr) -> Result<(), ErrorStack> {
        // SAFETY: `list` is a valid NUL-terminated string for the call.
        let rc = unsafe { sys::SSL_CTX_set_cipher_list(self.ptr, list.as_ptr()) };
        Self::ok_or_drain(rc == 1)
    }

    /// Sets the TLS 1.3 ciphersuites.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if OpenSSL rejects `list`.
    pub fn set_ciphersuites(&self, list: &CStr) -> Result<(), ErrorStack> {
        // SAFETY: `list` is a valid NUL-terminated string for the call.
        let rc = unsafe { sys::SSL_CTX_set_ciphersuites(self.ptr, list.as_ptr()) };
        Self::ok_or_drain(rc == 1)
    }

    /// Installs `cert` as this context's certificate.
    ///
    /// OpenSSL takes its own reference, so `cert` stays owned by the caller.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors on failure.
    pub fn use_certificate(&self, cert: &X509) -> Result<(), ErrorStack> {
        // SAFETY: both pointers are live for the duration of the call.
        let rc = unsafe { sys::SSL_CTX_use_certificate(self.ptr, cert.as_ptr()) };
        Self::ok_or_drain(rc == 1)
    }

    /// Installs `key` as this context's private key.
    ///
    /// OpenSSL takes its own reference, so `key` stays owned by the caller.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors on failure.
    pub fn use_private_key<T: HasPrivate>(&self, key: &Pkey<T>) -> Result<(), ErrorStack> {
        // SAFETY: both pointers are live for the duration of the call.
        let rc = unsafe { sys::SSL_CTX_use_PrivateKey(self.ptr, key.as_ptr()) };
        Self::ok_or_drain(rc == 1)
    }

    /// Checks that the installed private key matches the installed certificate.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if they do not match or either is missing.
    pub fn check_private_key(&self) -> Result<(), ErrorStack> {
        // SAFETY: `self.ptr` is a live context.
        let rc = unsafe { sys::SSL_CTX_check_private_key(self.ptr) };
        Self::ok_or_drain(rc == 1)
    }

    /// Loads OpenSSL's default CA certificate locations for verification.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors on failure.
    pub fn set_default_verify_paths(&self) -> Result<(), ErrorStack> {
        // SAFETY: `self.ptr` is a live context.
        let rc = unsafe { sys::SSL_CTX_set_default_verify_paths(self.ptr) };
        Self::ok_or_drain(rc == 1)
    }

    /// Turns off the session cache (`SSL_SESS_CACHE_OFF`).
    pub fn disable_session_cache(&self) {
        // SAFETY: `self.ptr` is a live context; this control takes no pointer
        // argument. The return value (previous cache mode) is not needed.
        unsafe {
            sys::SSL_CTX_ctrl(
                self.ptr,
                Self::SSL_CTRL_SET_SESS_CACHE_MODE,
                Self::SSL_SESS_CACHE_OFF,
                std::ptr::null_mut(),
            )
        };
    }

    /// Creates a new connection object that uses this context's settings.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if `SSL_new` fails.
    pub fn new_ssl(&self) -> Result<Ssl, ErrorStack> {
        // SAFETY: `self.ptr` is a live context; `SSL_new` returns an owned
        // SSL object or null.
        let ptr = unsafe { sys::SSL_new(self.ptr) };
        if ptr.is_null() {
            return Err(ErrorStack::drain());
        }
        Ok(Ssl { ptr })
    }

    /// Takes ownership of a pointer just returned by `SSL_CTX_new`, turning
    /// null into the queued OpenSSL errors.
    fn wrap_new(ptr: *mut sys::SSL_CTX) -> Result<Self, ErrorStack> {
        if ptr.is_null() {
            Err(ErrorStack::drain())
        } else {
            Ok(SslCtx { ptr })
        }
    }

    /// Converts OpenSSL's "1 means success" convention into a `Result`,
    /// draining the error queue on failure.
    fn ok_or_drain(success: bool) -> Result<(), ErrorStack> {
        if success {
            Ok(())
        } else {
            Err(ErrorStack::drain())
        }
    }

    /// Issues a min/max protocol-version control command.
    fn ctrl_proto_version(&self, cmd: std::ffi::c_int, ver: TlsVersion) -> Result<(), ErrorStack> {
        // Cast to the FFI `c_long` rather than `i64`: `long` is 32-bit on
        // some targets (e.g. Windows), where an `i64` argument fails to
        // compile.
        // SAFETY: `self.ptr` is a live context; this control takes no pointer argument.
        let rc = unsafe {
            sys::SSL_CTX_ctrl(self.ptr, cmd, ver as std::ffi::c_long, std::ptr::null_mut())
        };
        Self::ok_or_drain(rc == 1)
    }
}
/// A single TLS connection, created by [`SslCtx::new_ssl`].
///
/// Owns one OpenSSL `SSL` object, which is freed on drop. The type
/// implements `Send` but not `Sync`.
pub struct Ssl {
    ptr: *mut sys::SSL,
}

// SAFETY: the `SSL` object is exclusively owned by this handle, so moving it
// to another thread moves all access with it.
unsafe impl Send for Ssl {}

impl Drop for Ssl {
    /// Frees the owned `SSL` object.
    fn drop(&mut self) {
        let owned = self.ptr;
        // SAFETY: `owned` was returned by `SSL_new` and is freed only here.
        unsafe { sys::SSL_free(owned) };
    }
}
/// Per-connection operations: BIO wiring, handshake, I/O and shutdown.
impl Ssl {
    // `SSL_ctrl` command/argument behind OpenSSL's `SSL_set_tlsext_host_name`
    // macro (`SSL_CTRL_SET_TLSEXT_HOSTNAME`, `TLSEXT_NAMETYPE_host_name`).
    const SSL_CTRL_SET_TLSEXT_HOSTNAME: std::ffi::c_int = 55;
    const TLSEXT_NAMETYPE_HOST_NAME: std::ffi::c_long = 0;
    // `SSL_get_error` result codes from OpenSSL's `ssl.h`.
    const SSL_ERROR_WANT_READ: i32 = 2;
    const SSL_ERROR_WANT_WRITE: i32 = 3;
    const SSL_ERROR_SYSCALL: i32 = 5;
    const SSL_ERROR_ZERO_RETURN: i32 = 6;

    /// Attaches `bio` as both the read and the write BIO.
    ///
    /// Ownership of `bio` passes to the `SSL` object, which releases it when
    /// this `Ssl` is freed.
    pub fn set_bio_duplex(&mut self, bio: Bio) {
        let ptr = bio.as_ptr();
        // The SSL object now owns the reference `bio` held; do not free it here.
        std::mem::forget(bio);
        // SAFETY: `ptr` is live; with rbio == wbio OpenSSL consumes exactly
        // one reference, the one transferred above.
        unsafe { sys::SSL_set_bio(self.ptr, ptr, ptr) };
    }

    /// Attaches separate read and write BIOs, transferring ownership of both.
    ///
    /// If both handles refer to the same underlying BIO, this behaves like
    /// [`Ssl::set_bio_duplex`]. OpenSSL then consumes only one reference, so
    /// the surplus handle is released normally instead of being leaked.
    pub fn set_bio(&mut self, rbio: Bio, wbio: Bio) {
        if rbio.as_ptr() == wbio.as_ptr() {
            drop(wbio);
            self.set_bio_duplex(rbio);
            return;
        }
        let rbio_ptr = rbio.as_ptr();
        let wbio_ptr = wbio.as_ptr();
        // One reference from each handle is consumed by `SSL_set_bio`.
        std::mem::forget(rbio);
        std::mem::forget(wbio);
        // SAFETY: both pointers are live and distinct; each handle's
        // reference is transferred to the SSL object.
        unsafe { sys::SSL_set_bio(self.ptr, rbio_ptr, wbio_ptr) };
    }

    /// Sets the SNI host name to send in the ClientHello.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors if OpenSSL rejects the name.
    pub fn set_hostname(&mut self, hostname: &CStr) -> Result<(), ErrorStack> {
        // SAFETY: `hostname` is NUL-terminated and outlives the call; OpenSSL
        // copies it, and the `*mut` cast only satisfies the generic `parg`
        // parameter — the string is not written.
        let rc = unsafe {
            sys::SSL_ctrl(
                self.ptr,
                Self::SSL_CTRL_SET_TLSEXT_HOSTNAME,
                Self::TLSEXT_NAMETYPE_HOST_NAME,
                hostname.as_ptr() as *mut std::os::raw::c_void,
            )
        };
        if rc != 1 {
            return Err(ErrorStack::drain());
        }
        Ok(())
    }

    /// Puts the connection in client mode (for use with `do_handshake`).
    pub fn set_connect_state(&mut self) {
        // SAFETY: `self.ptr` is a live SSL object.
        unsafe { sys::SSL_set_connect_state(self.ptr) };
    }

    /// Puts the connection in server mode (for use with `do_handshake`).
    pub fn set_accept_state(&mut self) {
        // SAFETY: `self.ptr` is a live SSL object.
        unsafe { sys::SSL_set_accept_state(self.ptr) };
    }

    /// Runs or resumes the client handshake.
    ///
    /// # Errors
    /// `WantRead`/`WantWrite` if the transport must make progress first;
    /// otherwise the classified failure.
    pub fn connect(&mut self) -> Result<(), SslIoError> {
        Self::clear_error_queue();
        // SAFETY: `self.ptr` is a live SSL object.
        let rc = unsafe { sys::SSL_connect(self.ptr) };
        self.check_io(rc)
    }

    /// Runs or resumes the server handshake.
    ///
    /// # Errors
    /// `WantRead`/`WantWrite` if the transport must make progress first;
    /// otherwise the classified failure.
    pub fn accept(&mut self) -> Result<(), SslIoError> {
        Self::clear_error_queue();
        // SAFETY: `self.ptr` is a live SSL object.
        let rc = unsafe { sys::SSL_accept(self.ptr) };
        self.check_io(rc)
    }

    /// Runs or resumes the handshake in whichever mode was configured.
    ///
    /// # Errors
    /// `WantRead`/`WantWrite` if the transport must make progress first;
    /// otherwise the classified failure.
    pub fn do_handshake(&mut self) -> Result<(), SslIoError> {
        Self::clear_error_queue();
        // SAFETY: `self.ptr` is a live SSL object.
        let rc = unsafe { sys::SSL_do_handshake(self.ptr) };
        self.check_io(rc)
    }

    /// Reads decrypted application data into `buf`.
    ///
    /// Returns the number of bytes read.
    ///
    /// # Errors
    /// `ZeroReturn` once the peer has closed the connection; `WantRead`/
    /// `WantWrite` if the transport must make progress; otherwise the
    /// classified failure.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, SslIoError> {
        let mut readbytes: usize = 0;
        Self::clear_error_queue();
        // SAFETY: `buf` is valid for `buf.len()` writable bytes and
        // `readbytes` outlives the call.
        let rc = unsafe {
            sys::SSL_read_ex(
                self.ptr,
                buf.as_mut_ptr().cast(),
                buf.len(),
                std::ptr::addr_of_mut!(readbytes),
            )
        };
        self.check_io(rc).map(|()| readbytes)
    }

    /// Encrypts and sends `buf`.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Errors
    /// `WantRead`/`WantWrite` if the transport must make progress; otherwise
    /// the classified failure.
    pub fn write(&mut self, buf: &[u8]) -> Result<usize, SslIoError> {
        let mut written: usize = 0;
        Self::clear_error_queue();
        // SAFETY: `buf` is valid for `buf.len()` readable bytes and
        // `written` outlives the call.
        let rc = unsafe {
            sys::SSL_write_ex(
                self.ptr,
                buf.as_ptr().cast(),
                buf.len(),
                std::ptr::addr_of_mut!(written),
            )
        };
        self.check_io(rc).map(|()| written)
    }

    /// Sends (and, on a repeated call, awaits) close_notify.
    ///
    /// Returns [`ShutdownResult::Sent`] when `SSL_shutdown` returns 0 and
    /// [`ShutdownResult::Complete`] when it returns 1.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors for any negative return.
    pub fn shutdown(&mut self) -> Result<ShutdownResult, ErrorStack> {
        Self::clear_error_queue();
        // SAFETY: `self.ptr` is a live SSL object.
        match unsafe { sys::SSL_shutdown(self.ptr) } {
            1 => Ok(ShutdownResult::Complete),
            0 => Ok(ShutdownResult::Sent),
            _ => Err(ErrorStack::drain()),
        }
    }

    /// Returns the peer's certificate, if one was presented.
    ///
    /// The returned `X509` holds its own reference, independent of this
    /// connection.
    #[must_use]
    pub fn peer_certificate(&self) -> Option<X509> {
        // SAFETY: `self.ptr` is live; the result is a borrowed pointer or null.
        let ptr = unsafe { sys::SSL_get0_peer_certificate(self.ptr) };
        if ptr.is_null() {
            return None;
        }
        // SAFETY: take a reference of our own before wrapping, because the
        // `get0` pointer is only borrowed.
        unsafe { sys::X509_up_ref(ptr) };
        Some(unsafe { X509::from_ptr(ptr) })
    }

    /// Returns a new reference to the connection's current session, if any.
    #[must_use]
    pub fn get1_session(&self) -> Option<SslSession> {
        // SAFETY: `self.ptr` is live; `SSL_get1_session` returns an owned
        // reference or null.
        let ptr = unsafe { sys::SSL_get1_session(self.ptr) };
        if ptr.is_null() {
            None
        } else {
            Some(SslSession { ptr })
        }
    }

    /// Offers `session` for resumption on the next handshake.
    ///
    /// OpenSSL takes its own reference, so `session` stays owned by the caller.
    ///
    /// # Errors
    /// Returns the queued OpenSSL errors on failure.
    pub fn set_session(&mut self, session: &SslSession) -> Result<(), ErrorStack> {
        // SAFETY: both pointers are live for the duration of the call.
        let rc = unsafe { sys::SSL_set_session(self.ptr, session.ptr) };
        if rc != 1 {
            return Err(ErrorStack::drain());
        }
        Ok(())
    }

    /// Discards stale entries from this thread's OpenSSL error queue.
    ///
    /// `SSL_get_error` consults the queue, so it must be empty before each
    /// handshake/I/O call. Otherwise a leftover error from an unrelated
    /// earlier call can be reported as this call's failure.
    fn clear_error_queue() {
        let _ = ErrorStack::drain();
    }

    /// Maps the "1 on success" return of a handshake/I/O call to a `Result`.
    fn check_io(&self, rc: i32) -> Result<(), SslIoError> {
        if rc == 1 {
            Ok(())
        } else {
            Err(self.ssl_io_error(rc))
        }
    }

    /// Classifies the failed return value `ret` using `SSL_get_error`.
    ///
    /// The error queue is drained for every variant that carries an error
    /// payload.
    fn ssl_io_error(&self, ret: i32) -> SslIoError {
        // SAFETY: `self.ptr` is a live SSL object.
        let err = unsafe { sys::SSL_get_error(self.ptr, ret) };
        match err {
            Self::SSL_ERROR_WANT_READ => SslIoError::WantRead,
            Self::SSL_ERROR_WANT_WRITE => SslIoError::WantWrite,
            Self::SSL_ERROR_SYSCALL => SslIoError::Syscall(ErrorStack::drain()),
            Self::SSL_ERROR_ZERO_RETURN => SslIoError::ZeroReturn,
            _ => {
                let stack = ErrorStack::drain();
                if stack.errors().next().is_none() {
                    SslIoError::Other(err)
                } else {
                    SslIoError::Ssl(stack)
                }
            }
        }
    }
}
// Unit tests plus in-process TLS loopback tests: both endpoints run on the
// current thread and exchange records through an in-memory BIO pair.
#[cfg(test)]
mod tests {
use super::*;
use crate::pkey::{KeygenCtx, Pkey, Private, Public};
use crate::x509::{X509Builder, X509NameOwned};
/// Generates a fresh Ed25519 key pair; the public half is derived from a
/// clone of the private key.
fn make_ed25519_key() -> (Pkey<Private>, Pkey<Public>) {
let mut kgen = KeygenCtx::new(c"ED25519").unwrap();
let priv_key = kgen.generate().unwrap();
let pub_key = Pkey::<Public>::from(priv_key.clone());
(priv_key, pub_key)
}
/// Builds a self-signed certificate for `pub_key` (subject == issuer ==
/// `CN=test`, serial 1), signed with `priv_key`.
fn make_self_signed_cert(priv_key: &Pkey<Private>, pub_key: &Pkey<Public>) -> X509 {
let mut name = X509NameOwned::new().unwrap();
name.add_entry_by_txt(c"CN", b"test").unwrap();
X509Builder::new()
.unwrap()
// Presumably the zero-based X.509 version field, i.e. v3 — confirm
// against X509Builder.
.set_version(2)
.unwrap()
.set_serial_number(1)
.unwrap()
// Validity offsets; presumably seconds from now, i.e. valid for one
// day — confirm against X509Builder.
.set_not_before_offset(0)
.unwrap()
.set_not_after_offset(86400)
.unwrap()
.set_subject_name(&name)
.unwrap()
.set_issuer_name(&name)
.unwrap()
.set_public_key(pub_key)
.unwrap()
// No digest, presumably because Ed25519 signs without a separate one.
.sign(priv_key, None)
.unwrap()
.build()
}
/// Wires `client` and `server` together through an in-memory BIO pair and
/// drives both handshakes on the current thread.
///
/// Each round calls `connect` on the client, then `accept` on the server.
/// `WantRead`/`WantWrite` only mean the other side must run first. Any other
/// error is returned immediately, and `SslIoError::Other(-1)` is returned
/// if both sides are not done after 20 rounds.
fn do_handshake_pair(mut client: Ssl, mut server: Ssl) -> Result<(Ssl, Ssl), SslIoError> {
let mut client_bio: *mut sys::BIO = std::ptr::null_mut();
let mut server_bio: *mut sys::BIO = std::ptr::null_mut();
// Buffer sizes of 0 ask OpenSSL for its default BIO-pair buffer size.
let rc = unsafe {
sys::BIO_new_bio_pair(
std::ptr::addr_of_mut!(client_bio),
0,
std::ptr::addr_of_mut!(server_bio),
0,
)
};
assert_eq!(rc, 1, "BIO_new_bio_pair failed");
// Wrap each end so ownership can be handed to its `Ssl` below.
let client_bio_obj = unsafe { Bio::from_ptr_owned(client_bio) };
let server_bio_obj = unsafe { Bio::from_ptr_owned(server_bio) };
client.set_bio_duplex(client_bio_obj);
server.set_bio_duplex(server_bio_obj);
let mut client_done = false;
let mut server_done = false;
for _ in 0..20 {
if !client_done {
match client.connect() {
Ok(()) => client_done = true,
Err(SslIoError::WantRead | SslIoError::WantWrite) => {}
Err(e) => return Err(e),
}
}
if !server_done {
match server.accept() {
Ok(()) => server_done = true,
Err(SslIoError::WantRead | SslIoError::WantWrite) => {}
Err(e) => return Err(e),
}
}
if client_done && server_done {
return Ok((client, server));
}
}
// Round budget exhausted without both sides finishing.
Err(SslIoError::Other(-1))
}
#[test]
fn ctx_new_variants() {
SslCtx::new().unwrap();
SslCtx::new_client().unwrap();
SslCtx::new_server().unwrap();
}
#[test]
fn ctx_clone() {
let ctx = SslCtx::new().unwrap();
let _clone = ctx.clone();
}
#[test]
fn ctx_proto_version() {
let ctx = SslCtx::new().unwrap();
ctx.set_min_proto_version(TlsVersion::Tls12).unwrap();
ctx.set_max_proto_version(TlsVersion::Tls13).unwrap();
}
#[test]
fn ctx_verify_mode() {
let ctx = SslCtx::new().unwrap();
ctx.set_verify(SslVerifyMode::NONE);
ctx.set_verify(SslVerifyMode::PEER);
ctx.set_verify(SslVerifyMode::PEER.or(SslVerifyMode::FAIL_IF_NO_PEER_CERT));
}
#[test]
fn ctx_cipher_list() {
let ctx = SslCtx::new().unwrap();
ctx.set_cipher_list(c"HIGH:!aNULL").unwrap();
}
#[test]
fn ctx_load_cert_and_key() {
let (priv_key, pub_key) = make_ed25519_key();
let cert = make_self_signed_cert(&priv_key, &pub_key);
let ctx = SslCtx::new_server().unwrap();
ctx.use_certificate(&cert).unwrap();
ctx.use_private_key(&priv_key).unwrap();
ctx.check_private_key().unwrap();
}
// Both contexts are pinned to TLS 1.3; the client skips verification
// because the server certificate is self-signed.
#[test]
fn tls13_handshake_ed25519() {
let (priv_key, pub_key) = make_ed25519_key();
let cert = make_self_signed_cert(&priv_key, &pub_key);
let server_ctx = SslCtx::new_server().unwrap();
server_ctx.set_min_proto_version(TlsVersion::Tls13).unwrap();
server_ctx.set_max_proto_version(TlsVersion::Tls13).unwrap();
server_ctx.use_certificate(&cert).unwrap();
server_ctx.use_private_key(&priv_key).unwrap();
server_ctx.check_private_key().unwrap();
server_ctx.disable_session_cache();
let client_ctx = SslCtx::new_client().unwrap();
client_ctx.set_min_proto_version(TlsVersion::Tls13).unwrap();
client_ctx.set_max_proto_version(TlsVersion::Tls13).unwrap();
client_ctx.set_verify(SslVerifyMode::NONE);
client_ctx.disable_session_cache();
let client_ssl = client_ctx.new_ssl().unwrap();
let server_ssl = server_ctx.new_ssl().unwrap();
let (mut client, mut server) =
do_handshake_pair(client_ssl, server_ssl).expect("TLS 1.3 handshake failed");
// Round-trip application data in both directions.
let msg = b"hello TLS 1.3";
let n = client.write(msg).unwrap();
assert_eq!(n, msg.len());
let mut buf = [0u8; 64];
let n = server.read(&mut buf).unwrap();
assert_eq!(&buf[..n], msg);
let reply = b"world";
server.write(reply).unwrap();
let n = client.read(&mut buf).unwrap();
assert_eq!(&buf[..n], reply);
}
// Same as the TLS 1.3 test but pinned to TLS 1.2, one direction only.
#[test]
fn tls12_handshake() {
let (priv_key, pub_key) = make_ed25519_key();
let cert = make_self_signed_cert(&priv_key, &pub_key);
let server_ctx = SslCtx::new_server().unwrap();
server_ctx.set_min_proto_version(TlsVersion::Tls12).unwrap();
server_ctx.set_max_proto_version(TlsVersion::Tls12).unwrap();
server_ctx.use_certificate(&cert).unwrap();
server_ctx.use_private_key(&priv_key).unwrap();
server_ctx.check_private_key().unwrap();
server_ctx.disable_session_cache();
let client_ctx = SslCtx::new_client().unwrap();
client_ctx.set_min_proto_version(TlsVersion::Tls12).unwrap();
client_ctx.set_max_proto_version(TlsVersion::Tls12).unwrap();
client_ctx.set_verify(SslVerifyMode::NONE);
client_ctx.disable_session_cache();
let client_ssl = client_ctx.new_ssl().unwrap();
let server_ssl = server_ctx.new_ssl().unwrap();
let (mut client, mut server) =
do_handshake_pair(client_ssl, server_ssl).expect("TLS 1.2 handshake failed");
client.write(b"tls12").unwrap();
let mut buf = [0u8; 16];
let n = server.read(&mut buf).unwrap();
assert_eq!(&buf[..n], b"tls12");
}
// The client must see exactly the certificate the server was configured
// with, compared by DER encoding.
#[test]
fn peer_certificate_after_handshake() {
let (priv_key, pub_key) = make_ed25519_key();
let cert = make_self_signed_cert(&priv_key, &pub_key);
let cert_der = cert.to_der().unwrap();
let server_ctx = SslCtx::new_server().unwrap();
server_ctx.use_certificate(&cert).unwrap();
server_ctx.use_private_key(&priv_key).unwrap();
server_ctx.disable_session_cache();
let client_ctx = SslCtx::new_client().unwrap();
client_ctx.set_verify(SslVerifyMode::NONE);
client_ctx.disable_session_cache();
let (client, _server) =
do_handshake_pair(client_ctx.new_ssl().unwrap(), server_ctx.new_ssl().unwrap())
.unwrap();
let peer_cert = client.peer_certificate().expect("no peer certificate");
let peer_der = peer_cert.to_der().unwrap();
assert_eq!(peer_der, cert_der, "peer cert DER mismatch");
}
// Bidirectional close_notify exchange. The call order matters: each side's
// first `shutdown` only sends its alert, and the client's second call
// consumes the server's alert.
#[test]
fn shutdown_sequence() {
let (priv_key, pub_key) = make_ed25519_key();
let cert = make_self_signed_cert(&priv_key, &pub_key);
let server_ctx = SslCtx::new_server().unwrap();
server_ctx.use_certificate(&cert).unwrap();
server_ctx.use_private_key(&priv_key).unwrap();
server_ctx.disable_session_cache();
let client_ctx = SslCtx::new_client().unwrap();
client_ctx.set_verify(SslVerifyMode::NONE);
client_ctx.disable_session_cache();
let (mut client, mut server) =
do_handshake_pair(client_ctx.new_ssl().unwrap(), server_ctx.new_ssl().unwrap())
.unwrap();
let r = client.shutdown().unwrap();
assert_eq!(r, ShutdownResult::Sent);
let r = server.shutdown().unwrap();
assert_eq!(r, ShutdownResult::Sent);
let r = client.shutdown().unwrap();
assert_eq!(r, ShutdownResult::Complete);
}
// PEER (0x01) | FAIL_IF_NO_PEER_CERT (0x02) == 0x03.
#[test]
fn verify_mode_bits() {
let both = SslVerifyMode::PEER.or(SslVerifyMode::FAIL_IF_NO_PEER_CERT);
assert_eq!(both.0, 0x03);
}
}