use crate::{config::Context, connection::Connection};
use core::{mem::ManuallyDrop, ptr::NonNull, time::Duration};
use s2n_tls_sys::s2n_connection;
mod async_cb;
pub use async_cb::*;
mod client_hello;
pub use client_hello::*;
mod session_ticket;
pub use session_ticket::*;
mod pkey;
pub use pkey::*;
#[cfg(feature = "unstable-crl")]
mod cert_validation;
#[cfg(feature = "unstable-crl")]
pub use cert_validation::*;
/// Runs `action` with the connection behind `conn_ptr` and the [`Context`]
/// of that connection's config, returning whatever `action` returns.
///
/// The connection handle created here is wrapped in [`ManuallyDrop`] *before*
/// `action` runs, so it is not dropped when this function returns, nor when
/// `action` panics and the stack unwinds.
///
/// # Panics
///
/// Panics if `conn_ptr` is null or if the connection reports no config.
///
/// # Safety
///
/// `conn_ptr` must satisfy whatever `Connection::from_raw` requires of its
/// argument (NOTE(review): presumably a live s2n connection — confirm against
/// `Connection::from_raw`).
pub(crate) unsafe fn with_context<F, T>(conn_ptr: *mut s2n_connection, action: F) -> T
where
    F: FnOnce(&mut Connection, &Context) -> T,
{
    // Build the non-owning handle first; everything after this point may
    // unwind without running the connection's destructor.
    let mut borrowed_conn = ManuallyDrop::new(Connection::from_raw(
        NonNull::new(conn_ptr).expect("connection should not be null"),
    ));

    let owning_config = borrowed_conn
        .config()
        .expect("config should not be null");
    let ctx = owning_config.context();

    action(&mut *borrowed_conn, ctx)
}
/// Callback that decides whether a host name passes verification.
///
/// `verify_host` in this module hands the host name to this callback once it
/// has been validated as UTF-8, and converts the returned `bool` into a `u8`.
///
/// Implementors must be `'static + Send + Sync`.
pub trait VerifyHostNameCallback: 'static + Send + Sync {
/// Returns whether `host_name` passes verification.
fn verify_host_name(&self, host_name: &str) -> bool;
}
/// A user-suppliable wall-clock time source.
///
/// Implementors must be `'static + Send + Sync`.
pub trait WallClock: 'static + Send + Sync {
/// Returns the current time as a [`Duration`] since the epoch.
///
/// NOTE(review): presumably the Unix epoch — confirm against the caller.
fn get_time_since_epoch(&self) -> Duration;
}
/// A user-suppliable monotonic time source.
///
/// Implementors must be `'static + Send + Sync`.
pub trait MonotonicClock: 'static + Send + Sync {
/// Returns the clock's current reading as a [`Duration`].
///
/// NOTE(review): the reference point of the reading is not visible here;
/// presumably only differences between readings are meaningful — confirm.
fn get_time(&self) -> Duration;
}
/// Adapts a raw `(pointer, length)` host name to a [`VerifyHostNameCallback`].
///
/// Returns `1` when `handler` accepts the host name and `0` otherwise. The
/// function fails closed: a null `host_name` pointer or a host name that is
/// not valid UTF-8 yields `0` without invoking `handler`.
///
/// # Safety
///
/// If `host_name` is non-null it must point to at least `host_name_len`
/// readable bytes that remain valid and unmodified for the duration of the
/// call.
pub(crate) unsafe fn verify_host(
    host_name: *const ::libc::c_char,
    host_name_len: usize,
    handler: &mut Box<dyn VerifyHostNameCallback>,
) -> u8 {
    // `slice::from_raw_parts` requires a non-null pointer even for a
    // zero-length slice, so reject null before building the slice.
    if host_name.is_null() {
        return 0;
    }

    let name_bytes = core::slice::from_raw_parts(host_name.cast::<u8>(), host_name_len);

    match core::str::from_utf8(name_bytes) {
        Ok(name) => u8::from(handler.verify_host_name(name)),
        // A host name that cannot be parsed cannot be verified: fail closed.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use crate::{callbacks::with_context, config::Config, connection::Builder, enums::Mode};

    /// A panic inside the `with_context` action must leave the connection
    /// alive: the config refcount only drops once the test itself drops the
    /// connection.
    #[test]
    fn panic_does_not_free_connection() -> Result<(), crate::error::Error> {
        let config = Config::new();
        let mut connection = config.build_connection(Mode::Server)?;
        // Baseline refcount with the config and one connection built from it.
        assert_eq!(config.test_get_refcount()?, 2);

        let raw_conn = connection.as_ptr();
        let outcome = std::panic::catch_unwind(|| {
            unsafe {
                with_context(raw_conn, |_conn, _context| {
                    panic!("force unwind");
                })
            };
        });

        // The panic must have propagated out of `with_context`...
        assert!(outcome.is_err());
        // ...without releasing the connection's reference to the config.
        assert_eq!(config.test_get_refcount()?, 2);

        // Dropping the real owner is what releases that reference.
        drop(connection);
        assert_eq!(config.test_get_refcount()?, 1);

        Ok(())
    }
}