#[cfg(feature = "std")]
use std::borrow::Cow;
#[cfg(feature = "std")]
use std::sync::Arc;
use core::slice::from_raw_parts;
use mbedtls_sys::types::raw_types::*;
use mbedtls_sys::types::size_t;
use mbedtls_sys::*;
use crate::alloc::List as MbedtlsList;
#[cfg(not(feature = "std"))]
use crate::alloc_prelude::*;
use crate::error::{IntoResult, Result};
use crate::pk::dhparam::Dhm;
use crate::pk::Pk;
use crate::private::UnsafeFrom;
use crate::rng::RngCallback;
use crate::ssl::context::HandshakeContext;
use crate::ssl::cookie::CookieCallback;
use crate::ssl::ticket::TicketCallback;
use crate::x509::{self, Certificate, Crl, Profile, VerifyCallback};
/// Protocol versions that can be selected as the minimum or maximum version
/// of a [`Config`].
///
/// Variants are declared from oldest to newest, so the derived ordering
/// compares versions chronologically.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Version {
    /// Mapped to minor version 0 by the version setters.
    Ssl3,
    /// Mapped to minor version 1 by the version setters.
    Tls1_0,
    /// Mapped to minor version 2 by the version setters.
    Tls1_1,
    /// Mapped to minor version 3 by the version setters.
    Tls1_2,
    /// Placeholder variant; the version setters reject it with an error.
    #[doc(hidden)]
    __NonExhaustive,
}
// Which side of the connection a configuration is for. Declared through the
// crate's `define!` macro as a `c_int`-backed enum whose variants take the
// values of the `SSL_IS_CLIENT` / `SSL_IS_SERVER` constants.
define!(
#[c_ty(c_int)]
enum Endpoint {
Client = SSL_IS_CLIENT,
Server = SSL_IS_SERVER,
}
);
// Transport type for a configuration: stream or datagram. `c_int`-backed enum
// using the `SSL_TRANSPORT_*` constants; additionally derives `PartialEq`/`Eq`
// so values can be compared.
define!(
#[c_ty(c_int)]
#[derive(PartialEq, Eq)]
enum Transport {
Stream = SSL_TRANSPORT_STREAM,
Datagram = SSL_TRANSPORT_DATAGRAM,
}
);
// Defaults preset passed to `ssl_config_defaults` by `Config::new`.
// `c_int`-backed enum using the `SSL_PRESET_*` constants.
define!(
#[c_ty(c_int)]
enum Preset {
Default = SSL_PRESET_DEFAULT,
SuiteB = SSL_PRESET_SUITEB,
}
);
// Certificate verification mode accepted by `Config::set_authmode`.
// `c_int`-backed enum using the `SSL_VERIFY_*` constants.
define!(
#[c_ty(c_int)]
enum AuthMode {
None = SSL_VERIFY_NONE,
Optional = SSL_VERIFY_OPTIONAL,
Required = SSL_VERIFY_REQUIRED,
}
);
// Switch accepted by `Config::set_session_tickets`. `c_int`-backed enum using
// the `SSL_SESSION_TICKETS_*` constants.
define!(
#[c_ty(c_int)]
enum UseSessionTickets {
Enabled = SSL_SESSION_TICKETS_ENABLED,
Disabled = SSL_SESSION_TICKETS_DISABLED,
}
);
// Switch accepted by `Config::set_renegotiation`. `c_int`-backed enum using
// the `SSL_RENEGOTIATION_*` constants.
define!(
#[c_ty(c_int)]
enum Renegotiation {
Enabled = SSL_RENEGOTIATION_ENABLED,
Disabled = SSL_RENEGOTIATION_DISABLED,
}
);
// Callback traits declared through the crate's `callback!` macro. Each names
// the closure signature accepted by the matching `Config::set_*_callback`.
//
// DbgCallback (std only): invoked by `Config::set_dbg_callback`'s trampoline
// with (level, file, line, message).
#[cfg(feature = "std")]
callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ());
// SniCallback: invoked with the handshake context and the raw name bytes;
// an `Err` is reported to the C caller as -1.
callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>);
// CaCallback: invoked with the certificate being verified and returns the
// list of candidate CA certificates.
callback!(CaCallback: Fn(&MbedtlsList<Certificate>) -> Result<MbedtlsList<Certificate>>);
/// Owned array of C string pointers terminated by a null pointer, in the
/// layout expected by C APIs that take a `const char **` list (used here for
/// the ALPN protocol list).
///
/// Every non-null entry was produced by `CString::into_raw` in `new` and is
/// reclaimed by the `Drop` implementation.
#[repr(transparent)]
pub struct NullTerminatedStrList {
c: Vec<*mut c_char>,
}
// NOTE(review): raw pointers are not `Send`/`Sync` by default. These impls
// rely on the list exclusively owning its strings and exposing no mutation
// through `&self` — confirm if new methods are added.
unsafe impl Send for NullTerminatedStrList {}
unsafe impl Sync for NullTerminatedStrList {}
impl NullTerminatedStrList {
    /// Builds a null-terminated list of C strings from `list`.
    ///
    /// # Errors
    ///
    /// Returns `SslBadInputData` if any entry contains an interior NUL byte.
    /// Entries converted before the failing one are released when the
    /// partially built list is dropped.
    #[cfg(feature = "std")]
    pub fn new(list: &[&str]) -> Result<Self> {
        // One extra slot for the terminating null pointer.
        let mut strings = NullTerminatedStrList {
            c: Vec::with_capacity(list.len() + 1),
        };

        for &entry in list {
            let owned = ::std::ffi::CString::new(entry).map_err(|_| crate::error::codes::SslBadInputData)?;
            // Ownership moves into the raw pointer; `Drop` takes it back.
            strings.c.push(owned.into_raw());
        }

        strings.c.push(core::ptr::null_mut());
        Ok(strings)
    }

    /// Returns a pointer to the first element of the null-terminated array.
    ///
    /// The pointer is only valid while `self` is alive.
    pub fn as_ptr(&self) -> *const *const c_char {
        self.c.as_ptr() as *const *const c_char
    }
}
#[cfg(feature = "std")]
impl Drop for NullTerminatedStrList {
    /// Reclaims every string handed out through `CString::into_raw`.
    fn drop(&mut self) {
        for &raw in &self.c {
            // The terminating entry is null and owns nothing.
            if raw.is_null() {
                continue;
            }
            // SAFETY: non-null entries are only ever pushed by `new`, which
            // obtains them from `CString::into_raw`.
            unsafe {
                drop(::std::ffi::CString::from_raw(raw));
            }
        }
    }
}
// SSL configuration wrapper around the C `ssl_config` structure, declared
// through the crate's `define!` macro (which also supplies the `inner` field
// initialised in `Config::new`, wires `ssl_config_free` as the destructor and
// provides the pointer conversions used as `self.into()`).
//
// The `Arc`/`Vec` fields below hold the Rust-side objects whose raw pointers
// are handed to the C configuration by the setters; presumably they exist to
// keep those objects alive for as long as the configuration does.
define!(
#[c_ty(ssl_config)]
#[repr(C)]
struct Config {
// Certificate chains and matching private keys added by `push_cert`.
own_cert: Vec<Arc<MbedtlsList<Certificate>>>,
own_pk: Vec<Arc<Pk>>,
// Trusted CA chain and optional CRL set by `set_ca_list`.
ca_cert: Option<Arc<MbedtlsList<Certificate>>>,
crl: Option<Arc<Crl>>,
rng: Option<Arc<dyn RngCallback + 'static>>,
// Every list passed to `set_ciphersuites` / `set_ciphersuites_for_version`.
ciphersuites: Vec<Arc<Vec<c_int>>>,
curves: Option<Arc<Vec<ecp_group_id>>>,
protocols: Option<Arc<NullTerminatedStrList>>,
verify_callback: Option<Arc<dyn VerifyCallback + 'static>>,
#[cfg(feature = "std")]
dbg_callback: Option<Arc<dyn DbgCallback + 'static>>,
sni_callback: Option<Arc<dyn SniCallback + 'static>>,
ticket_callback: Option<Arc<dyn TicketCallback + 'static>>,
ca_callback: Option<Arc<dyn CaCallback + 'static>>,
dtls_cookies: Option<Arc<dyn CookieCallback + 'static>>,
};
const drop: fn(&mut Self) = ssl_config_free;
impl<'a> Into<ptr> {}
);
// NOTE(review): `Sync` is asserted manually; every mutating method visible in
// this file takes `&mut self` — confirm no `&self` path mutates `inner`.
unsafe impl Sync for Config {}
impl Config {
/// Creates a configuration for the given endpoint role, transport and
/// defaults preset.
///
/// All Rust-side resources (certificates, callbacks, lists) start out empty.
///
/// NOTE(review): the status returned by `ssl_config_defaults` is discarded
/// here, as the signature offers no way to report it.
pub fn new(e: Endpoint, t: Transport, p: Preset) -> Self {
    let mut inner = ssl_config::default();
    let (endpoint, transport, preset) = (e as c_int, t as c_int, p as c_int);
    unsafe {
        ssl_config_init(&mut inner);
        ssl_config_defaults(&mut inner, endpoint, transport, preset);
    }

    Config {
        inner,
        own_cert: Vec::new(),
        own_pk: Vec::new(),
        ca_cert: None,
        crl: None,
        rng: None,
        ciphersuites: Vec::new(),
        curves: None,
        protocols: None,
        verify_callback: None,
        #[cfg(feature = "std")]
        dbg_callback: None,
        sni_callback: None,
        ticket_callback: None,
        ca_callback: None,
        dtls_cookies: None,
    }
}
// Thin setters generated by the crate's `setter!` macro; each forwards its
// argument to the named C configuration function.
// set_endpoint: selects the client or server role.
setter!(set_endpoint(e: Endpoint) = ssl_conf_endpoint);
// set_transport: selects stream or datagram transport.
setter!(set_transport(t: Transport) = ssl_conf_transport);
// set_authmode: selects the certificate verification mode.
setter!(set_authmode(am: AuthMode) = ssl_conf_authmode);
// read_timeout: reads the `read_timeout` field of the underlying C config.
// NOTE(review): the unit is not visible in this file — confirm against the
// C API documentation.
getter!(read_timeout() -> u32 = .read_timeout);
// set_read_timeout: forwards the new timeout to `ssl_conf_read_timeout`.
setter!(set_read_timeout(t: u32) = ssl_conf_read_timeout);
/// Asserts that `list` ends with the type's default value, i.e. the zero
/// terminator the C API relies on to find the end of the array.
///
/// # Panics
///
/// Panics if `list` is empty or its last element is not `T::default()`.
fn check_c_list<T>(list: &[T])
where
    T: Default + Eq,
{
    assert!(list.last() == Some(&T::default()));
}
/// Installs `list` as the ciphersuite list of this configuration.
///
/// The list is retained by the configuration because the C side is given a
/// raw pointer into it.
///
/// # Panics
///
/// Panics if `list` is not terminated by `0`.
pub fn set_ciphersuites(&mut self, list: Arc<Vec<c_int>>) {
    Self::check_c_list(list.as_slice());

    let suites = list.as_ptr();
    unsafe { ssl_conf_ciphersuites(self.into(), suites) }
    self.ciphersuites.push(list);
}
/// Sets the ALPN protocol list offered or accepted by this configuration.
///
/// The list is stored in the configuration only when the C call succeeds.
///
/// # Errors
///
/// Returns the error reported by `ssl_conf_alpn_protocols`.
pub fn set_alpn_protocols(&mut self, protocols: Arc<NullTerminatedStrList>) -> Result<()> {
    let status = unsafe { ssl_conf_alpn_protocols(&mut self.inner, protocols.as_ptr() as *mut _) };
    status.into_result()?;

    self.protocols = Some(protocols);
    Ok(())
}
/// Installs `list` as the ciphersuite list for one protocol version,
/// identified by its `major`/`minor` numbers.
///
/// The list is retained by the configuration because the C side is given a
/// raw pointer into it.
///
/// # Panics
///
/// Panics if `list` is not terminated by `0`.
pub fn set_ciphersuites_for_version(&mut self, list: Arc<Vec<c_int>>, major: c_int, minor: c_int) {
    Self::check_c_list(list.as_slice());

    let suites = list.as_ptr();
    unsafe { ssl_conf_ciphersuites_for_version(self.into(), suites, major, minor) }
    self.ciphersuites.push(list);
}
/// Sets the list of elliptic curve group ids for this configuration,
/// replacing any list stored earlier.
///
/// # Panics
///
/// Panics if `list` does not end with the default (terminator) group id.
pub fn set_curves(&mut self, list: Arc<Vec<ecp_group_id>>) {
    Self::check_c_list(list.as_slice());

    let groups = list.as_ptr();
    unsafe { ssl_conf_curves(self.into(), groups) }
    self.curves = Some(list);
}
/// Registers `rng` as the random number generator of this configuration.
///
/// The C side receives `T::call` together with the generator's data pointer;
/// the `Arc` is stored so that pointer stays valid.
pub fn set_rng<T: RngCallback + 'static>(&mut self, rng: Arc<T>) {
    unsafe {
        let state = rng.data_ptr();
        ssl_conf_rng(self.into(), Some(T::call), state);
    }
    self.rng = Some(rng);
}
pub fn set_min_version(&mut self, version: Version) -> Result<()> {
let minor = match version {
Version::Ssl3 => 0,
Version::Tls1_0 => 1,
Version::Tls1_1 => 2,
Version::Tls1_2 => 3,
_ => {
return Err(crate::error::codes::SslBadHsProtocolVersion.into());
}
};
unsafe { ssl_conf_min_version(self.into(), 3, minor) };
Ok(())
}
pub fn set_max_version(&mut self, version: Version) -> Result<()> {
let minor = match version {
Version::Ssl3 => 0,
Version::Tls1_0 => 1,
Version::Tls1_1 => 2,
Version::Tls1_2 => 3,
_ => {
return Err(crate::error::codes::SslBadHsProtocolVersion.into());
}
};
unsafe { ssl_conf_max_version(self.into(), 3, minor) };
Ok(())
}
// set_cert_profile: forwards the certificate profile to
// `ssl_conf_cert_profile`. The `'static` bound means the profile must outlive
// any configuration that refers to it.
setter!(set_cert_profile(p: &'static Profile) = ssl_conf_cert_profile);
/// Sets the Diffie-Hellman parameters of this configuration from `dhm`.
///
/// NOTE(review): `dhm` is only borrowed and not retained — presumably the C
/// call copies the parameters; confirm against the C API documentation.
///
/// # Errors
///
/// Returns the error reported by `ssl_conf_dh_param_ctx`.
pub fn set_dh_params(&mut self, dhm: &Dhm) -> Result<()> {
    let status = unsafe { ssl_conf_dh_param_ctx(self.into(), dhm.inner_ffi_mut()) };
    status.into_result()?;
    Ok(())
}
/// Sets the trusted CA chain and, optionally, a certificate revocation list.
///
/// Both objects are stored in the configuration, replacing any previously
/// stored chain/CRL. When `crl` is `None` a null pointer is passed to C.
pub fn set_ca_list(&mut self, ca_cert: Arc<MbedtlsList<Certificate>>, crl: Option<Arc<Crl>>) {
    unsafe {
        let crl_ptr = match crl.as_ref() {
            Some(list) => list.inner_ffi_mut(),
            None => ::core::ptr::null_mut(),
        };
        ssl_conf_ca_chain(self.into(), ca_cert.inner_ffi_mut(), crl_ptr);
    }

    self.ca_cert = Some(ca_cert);
    self.crl = crl;
}
/// Adds a certificate chain and its private key to this configuration.
///
/// Both objects are stored in the configuration before their raw pointers
/// are handed to the C side.
///
/// # Errors
///
/// Returns `SslBadInputData` if `own_cert` is empty, otherwise the error
/// reported by `ssl_conf_own_cert`.
pub fn push_cert(&mut self, own_cert: Arc<MbedtlsList<Certificate>>, own_pk: Arc<Pk>) -> Result<()> {
    if own_cert.is_empty() {
        return Err(crate::error::codes::SslBadInputData.into());
    }

    // Retain shared handles first so the pointers passed below stay valid.
    self.own_cert.push(Arc::clone(&own_cert));
    self.own_pk.push(Arc::clone(&own_pk));

    let status = unsafe { ssl_conf_own_cert(self.into(), own_cert.inner_ffi_mut(), own_pk.inner_ffi_mut()) };
    status.into_result().map(|_| ())
}
/// Registers the session ticket write/parse callbacks provided by `cb`.
///
/// The C side receives `T::call_write`, `T::call_parse` and the callback's
/// data pointer; the `Arc` is stored so that pointer stays valid.
pub fn set_session_tickets_callback<T: TicketCallback + 'static>(&mut self, cb: Arc<T>) {
    unsafe {
        let state = cb.data_ptr();
        ssl_conf_session_tickets_cb(self.into(), Some(T::call_write), Some(T::call_parse), state);
    }
    self.ticket_callback = Some(cb);
}
// set_session_tickets: enables or disables session tickets via
// `ssl_conf_session_tickets`.
setter!(
set_session_tickets(u: UseSessionTickets) = ssl_conf_session_tickets
);
// set_renegotiation: enables or disables renegotiation via
// `ssl_conf_renegotiation`.
setter!(set_renegotiation(u: Renegotiation) = ssl_conf_renegotiation);
// set_ffdh_min_bitlen: forwards a minimum bit length to
// `ssl_conf_dhm_min_bitlen`.
setter!(
set_ffdh_min_bitlen(bitlen: c_uint) = ssl_conf_dhm_min_bitlen
);
/// Registers `cb` as the Server Name Indication callback.
///
/// The closure is stored in the configuration and invoked through a C
/// trampoline with the handshake context and the raw server name bytes.
/// Returning `Err` from the closure is reported to the C caller as `-1`.
pub fn set_sni_callback<F>(&mut self, cb: F)
where
    F: SniCallback + 'static,
{
    /// C-compatible trampoline that forwards to the Rust closure.
    unsafe extern "C" fn sni_callback<F>(
        closure: *mut c_void,
        ctx: *mut ssl_context,
        name: *const c_uchar,
        name_len: size_t,
    ) -> c_int
    where
        F: Fn(&mut HandshakeContext, &[u8]) -> Result<()> + 'static,
    {
        if closure.is_null() {
            return -1;
        }
        // The closure lives behind a shared `Arc` and is only `Fn`, so it
        // must be borrowed immutably; a `&mut F` here would alias.
        let cb = &*(closure as *const F);

        // Never panic across the FFI boundary: report failure instead.
        let ctx: &mut HandshakeContext = match UnsafeFrom::from(ctx) {
            Some(ctx) => ctx,
            None => return -1,
        };

        // `from_raw_parts` requires a non-null pointer even for length 0.
        let name: &[u8] = if name.is_null() {
            &[]
        } else {
            from_raw_parts(name, name_len)
        };

        match cb(ctx, name) {
            Ok(()) => 0,
            Err(_) => -1,
        }
    }

    let cb = Arc::new(cb);
    // Address of the closure inside the `Arc`; stable for the Arc's lifetime.
    let closure = &*cb as *const F as *mut c_void;
    self.sni_callback = Some(cb);
    unsafe { ssl_conf_sni(self.into(), Some(sni_callback::<F>), closure) }
}
/// Registers `cb` as the certificate verification callback.
///
/// The closure is stored in the configuration and its address is passed to
/// the C side together with the `x509::verify_callback` trampoline.
pub fn set_verify_callback<F>(&mut self, cb: F)
where
    F: VerifyCallback + 'static,
{
    let cb = Arc::new(cb);
    // Address of the closure inside the `Arc`; stable for the Arc's lifetime.
    let closure = &*cb as *const F as *mut c_void;
    self.verify_callback = Some(cb);
    unsafe { ssl_conf_verify(self.into(), Some(x509::verify_callback::<F>), closure) }
}
/// Registers `cb` as the trusted-CA lookup callback.
///
/// The closure is stored in the configuration and invoked through a C
/// trampoline with the certificate being verified. On success the returned
/// list is handed to the C caller as a raw pointer; on failure the error's
/// integer code is returned.
pub fn set_ca_callback<F>(&mut self, cb: F)
where
    F: CaCallback + 'static,
{
    /// C-compatible trampoline that forwards to the Rust closure.
    unsafe extern "C" fn ca_callback<F>(
        closure: *mut c_void,
        child: *const x509_crt,
        candidate_cas: *mut *mut x509_crt,
    ) -> c_int
    where
        F: CaCallback + 'static,
    {
        if child.is_null() || closure.is_null() || candidate_cas.is_null() {
            return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA;
        }

        // The closure lives behind a shared `Arc` and is only `Fn`, so it
        // must be borrowed immutably; a `&mut F` here would alias.
        let cb = &*(closure as *const F);

        // Never panic across the FFI boundary: report bad input instead.
        let crt: &MbedtlsList<Certificate> = match UnsafeFrom::from(&child as *const *const x509_crt) {
            Some(crt) => crt,
            None => return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA,
        };

        match cb(crt) {
            Ok(list) => {
                // NOTE(review): ownership of the list passes to the C caller
                // here — presumably it frees the certificates; confirm.
                *candidate_cas = list.into_raw();
                0
            }
            Err(e) => e.to_int(),
        }
    }

    let cb = Arc::new(cb);
    // Address of the closure inside the `Arc`; stable for the Arc's lifetime.
    let closure = &*cb as *const F as *mut c_void;
    self.ca_callback = Some(cb);
    unsafe { ssl_conf_ca_cb(self.into(), Some(ca_callback::<F>), closure) }
}
/// Registers `cb` as the debug message callback.
///
/// The closure is stored in the configuration and invoked through a C
/// trampoline with `(level, file, line, message)`. Null `file`/`message`
/// pointers are presented as empty strings, and invalid UTF-8 is replaced
/// lossily.
#[cfg(feature = "std")]
pub fn set_dbg_callback<F>(&mut self, cb: F)
where
    F: DbgCallback + 'static,
{
    /// C-compatible trampoline that forwards to the Rust closure.
    #[allow(dead_code)]
    unsafe extern "C" fn dbg_callback<F>(
        closure: *mut c_void,
        level: c_int,
        file: *const c_char,
        line: c_int,
        message: *const c_char,
    ) -> ()
    where
        F: DbgCallback + 'static,
    {
        /// Decodes a possibly-null C string, mapping null to `""`.
        unsafe fn lossy<'a>(ptr: *const c_char) -> Cow<'a, str> {
            if ptr.is_null() {
                Cow::from("")
            } else {
                std::ffi::CStr::from_ptr(ptr).to_string_lossy()
            }
        }

        if closure.is_null() {
            return;
        }
        // The closure lives behind a shared `Arc` and is only `Fn`, so it
        // must be borrowed immutably; a `&mut F` here would alias.
        let cb = &*(closure as *const F);

        let file = lossy(file);
        let message = lossy(message);
        cb(level, file, line, message);
    }

    let cb = Arc::new(cb);
    // Address of the closure inside the `Arc`; stable for the Arc's lifetime.
    let closure = &*cb as *const F as *mut c_void;
    self.dbg_callback = Some(cb);
    unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), closure) }
}
/// Sets the pre-shared key and its identity.
///
/// NOTE(review): neither slice is retained by the configuration —
/// presumably the C call copies them; confirm against the C API docs.
///
/// # Errors
///
/// Returns the error reported by `ssl_conf_psk`.
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
    let status = unsafe {
        ssl_conf_psk(
            self.into(),
            psk.as_ptr(),
            psk.len(),
            psk_identity.as_ptr(),
            psk_identity.len(),
        )
    };
    status.into_result().map(|_| ())
}
/// Registers the DTLS cookie write/check callbacks provided by
/// `dtls_cookies`.
///
/// The C side receives `T::cookie_write`, `T::cookie_check` and the
/// callback's data pointer; the `Arc` is stored so that pointer stays valid.
pub fn set_dtls_cookies<T: CookieCallback + 'static>(&mut self, dtls_cookies: Arc<T>) {
    unsafe {
        let state = dtls_cookies.data_ptr();
        ssl_conf_dtls_cookies(self.into(), Some(T::cookie_write), Some(T::cookie_check), state);
    }
    self.dtls_cookies = Some(dtls_cookies);
}
}