use crate::{
callbacks::{SessionTicketCallback, VerifyHostNameCallback},
cert_chain::{self, CertificateChain},
config::{self, *},
connection,
enums::{self, Blinding},
error, security,
};
use alloc::{collections::VecDeque, sync::Arc};
use core::{
sync::atomic::{AtomicUsize, Ordering},
task::Poll,
};
use libc::{c_int, c_void};
use std::{
cell::RefCell,
io::{Read, Write},
pin::Pin,
sync::Mutex,
};
pub mod client_hello;
pub mod resumption;
pub mod s2n_tls;
/// Boxed, type-erased error used by the fallible helpers in this module.
type Error = Box<dyn std::error::Error>;
/// `Result` alias whose error type defaults to this module's boxed [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
pub fn test_error(msg: &str) -> crate::error::Error {
crate::error::Error::application(msg.into())
}
/// Asserts that `input` is an application error whose message equals `expected_message`.
///
/// # Panics
/// Panics if `input` is not an application error, or if its message differs
/// from `expected_message`.
pub fn assert_test_error(input: crate::error::Error, expected_message: &str) {
    let error_msg = input
        .application_error()
        .expect("unexpected error type")
        .to_string();
    // `error_msg` is already an owned `String`; compare it directly instead of copying it again.
    assert_eq!(
        expected_message, error_msg,
        "application error message mismatch"
    );
}
/// A cloneable event counter backed by a shared atomic.
///
/// Every clone points at the same count, so one handle can be moved into a
/// callback while the test keeps another to observe how often it ran.
#[derive(Clone)]
pub struct Counter(Arc<AtomicUsize>);

impl Counter {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        Self(Arc::default())
    }

    /// Returns the current value of the shared count.
    pub fn count(&self) -> usize {
        let Self(shared) = self;
        shared.load(Ordering::Relaxed)
    }

    /// Adds one to the shared count.
    pub fn increment(&self) {
        let _previous = self.0.fetch_add(1, Ordering::Relaxed);
    }
}

impl Default for Counter {
    /// Equivalent to [`Counter::new`]: a fresh counter at zero.
    fn default() -> Self {
        Counter::new()
    }
}
/// Certificate/key pairs stored under `tests/pems/sni/` for SNI tests.
// The previous `#[allow(non_camel_case_types)]` was dropped: every variant is
// already UpperCamelCase, so it suppressed nothing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SniTestCerts {
    AlligatorRsa,
    AlligatorEcdsa,
    BeaverRsa,
    WildcardInsectRsa,
}

impl SniTestCerts {
    /// Loads this variant's `sni/<prefix>cert.pem` and `sni/<prefix>key.pem`.
    ///
    /// The chain file is also used as the CA certificate.
    ///
    /// # Panics
    /// Panics if either file cannot be read (see `CertKeyPair::from_path`).
    pub fn get(&self) -> CertKeyPair {
        let prefix = format!("sni/{}", self.file_prefix());
        CertKeyPair::from_path(&prefix, "cert", "key", "cert")
    }

    /// File-name prefix shared by this pair's PEM files.
    fn file_prefix(&self) -> &'static str {
        match self {
            SniTestCerts::AlligatorRsa => "alligator_",
            SniTestCerts::AlligatorEcdsa => "alligator_ecdsa_",
            SniTestCerts::BeaverRsa => "beaver_",
            SniTestCerts::WildcardInsectRsa => "wildcard_insect_rsa_",
        }
    }
}
/// A certificate chain, its private key, and a CA certificate loaded from the
/// shared `tests/pems/` directory.
///
/// Both the file paths and the PEM bytes are kept, so tests can hand either
/// form to the API under test.
#[derive(Clone)]
pub struct CertKeyPair {
    cert_path: String,
    key_path: String,
    ca_path: String,
    cert: Vec<u8>,
    key: Vec<u8>,
    ca_cert: Vec<u8>,
}

impl Default for CertKeyPair {
    /// Loads `rsa_4096_sha512_client_cert.pem` and `rsa_4096_sha512_client_key.pem`.
    ///
    /// The chain file also serves as the CA certificate.
    fn default() -> Self {
        Self::from_path("rsa_4096_sha512_client_", "cert", "key", "cert")
    }
}

impl CertKeyPair {
    /// Directory holding the test PEM files, located relative to this crate's manifest.
    const TEST_PEMS_PATH: &'static str =
        concat!(env!("CARGO_MANIFEST_DIR"), "/../../../../tests/pems/");

    /// Loads `{prefix}{chain}.pem`, `{prefix}{key}.pem` and `{prefix}{ca}.pem`
    /// from the test PEM directory.
    ///
    /// # Panics
    /// Panics if any of the three files cannot be read. The message names the
    /// kind of file that failed, its path, and the underlying I/O error.
    pub fn from_path(prefix: &str, chain: &str, key: &str, ca: &str) -> Self {
        let cert_path = Self::pem_path(prefix, chain);
        let key_path = Self::pem_path(prefix, key);
        let ca_path = Self::pem_path(prefix, ca);
        let cert = Self::read_pem("cert", &cert_path);
        let key = Self::read_pem("key", &key_path);
        // Labelled separately from the chain so a missing CA file is reported as such.
        let ca_cert = Self::read_pem("CA cert", &ca_path);
        CertKeyPair {
            cert_path,
            key_path,
            ca_path,
            cert,
            key,
            ca_cert,
        }
    }

    /// Full path of `{prefix}{name}.pem` inside the test PEM directory.
    fn pem_path(prefix: &str, name: &str) -> String {
        format!("{}{prefix}{name}.pem", Self::TEST_PEMS_PATH)
    }

    /// Reads `path`, panicking with a message that identifies the file by `kind`.
    fn read_pem(kind: &str, path: &str) -> Vec<u8> {
        std::fs::read(path).unwrap_or_else(|err| panic!("Failed to read {kind} at {path}: {err}"))
    }

    /// Builds a [`CertificateChain`] from this pair's chain and key PEMs.
    ///
    /// Despite the `into_` prefix this only borrows `self`; the PEM bytes are
    /// passed to `load_pem` by reference.
    ///
    /// # Panics
    /// Panics if the builder cannot be created, the PEMs fail to load, or the
    /// final build fails.
    pub fn into_certificate_chain(&self) -> CertificateChain<'static> {
        let mut chain = cert_chain::Builder::new().unwrap();
        chain.load_pem(&self.cert, &self.key).unwrap();
        chain.build().unwrap()
    }

    /// Path of the certificate-chain PEM file.
    pub fn cert_path(&self) -> &str {
        &self.cert_path
    }

    /// Path of the private-key PEM file.
    pub fn key_path(&self) -> &str {
        &self.key_path
    }

    /// Path of the CA-certificate PEM file.
    pub fn ca_path(&self) -> &str {
        &self.ca_path
    }

    /// Raw bytes of the certificate-chain PEM.
    pub fn cert(&self) -> &[u8] {
        &self.cert
    }

    /// Raw bytes of the private-key PEM.
    pub fn key(&self) -> &[u8] {
        &self.key
    }

    /// Raw bytes of the CA-certificate PEM.
    pub fn ca_cert(&self) -> &[u8] {
        &self.ca_cert
    }
}
/// Host-name verifier that approves every name it is asked about.
///
/// Insecure by design; only suitable for tests that are not exercising
/// host-name validation.
pub struct InsecureAcceptAllCertificatesHandler {}

impl VerifyHostNameCallback for InsecureAcceptAllCertificatesHandler {
    fn verify_host_name(&self, _hostname: &str) -> bool {
        // Unconditional approval.
        true
    }
}

/// Host-name verifier that refuses every name it is asked about.
///
/// Useful for tests that expect host-name verification to fail.
pub struct RejectAllCertificatesHandler {}

impl VerifyHostNameCallback for RejectAllCertificatesHandler {
    fn verify_host_name(&self, _hostname: &str) -> bool {
        // Unconditional refusal.
        false
    }
}
/// Builds a finished test [`Config`](crate::config::Config) from [`config_builder`].
///
/// # Errors
/// Propagates any error returned by [`config_builder`].
///
/// # Panics
/// Panics if the final `build` step fails.
pub fn build_config(cipher_prefs: &security::Policy) -> Result<crate::config::Config, Error> {
    let config = config_builder(cipher_prefs)?
        .build()
        .expect("Unable to build server config");
    Ok(config)
}
/// Returns a config builder preloaded for tests with:
/// - `cipher_prefs` as the security policy,
/// - the default [`CertKeyPair`] as the local certificate chain and key,
/// - that pair's CA certificate as the only trust anchor (system certs disabled),
/// - a host-name verifier that accepts every name.
///
/// # Panics
/// Panics if any of those settings is rejected by the builder; in practice it
/// never returns `Err`.
pub fn config_builder(cipher_prefs: &security::Policy) -> Result<crate::config::Builder, Error> {
    let mut builder = Builder::new();
    let keypair = CertKeyPair::default();
    builder
        .set_security_policy(cipher_prefs)
        .expect("Unable to set config cipher preferences");
    builder
        .load_pem(keypair.cert(), keypair.key())
        .expect("Unable to load cert/pem");
    builder
        .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})
        .expect("Unable to set a host verify callback.");
    builder
        .with_system_certs(false)
        .expect("Unable to disable the system trust store");
    // Trust the pair's designated CA certificate rather than assuming it equals the chain.
    builder.trust_pem(keypair.ca_cert()).expect("load cert pem");
    Ok(builder)
}
/// In-memory byte queue standing in for one direction of a network link.
type LocalDataBuffer = RefCell<VecDeque<u8>>;
/// A heap-allocated, pinned [`LocalDataBuffer`] whose address never changes.
type PinnedBuffer = Pin<Box<LocalDataBuffer>>;

/// The two in-memory links connecting a `TestPair`'s client and server.
///
/// Each buffer is boxed and pinned so its address stays fixed after it is
/// moved into this struct; raw pointers to the buffers are registered with the
/// connections as their I/O contexts.
#[derive(Debug)]
// NOTE(review): presumably silences warnings in builds where these fields are
// never read directly — confirm before removing.
#[allow(dead_code)]
pub struct TestPairIO {
    /// Bytes written by the server, waiting to be read by the client.
    pub server_tx_stream: PinnedBuffer,
    /// Bytes written by the client, waiting to be read by the server.
    pub client_tx_stream: PinnedBuffer,
}
/// A client and server connection wired to each other through in-memory buffers.
///
/// Field order matters: Rust drops struct fields in declaration order, so both
/// connections are dropped before `io`, whose buffers they reference through
/// raw context pointers.
pub struct TestPair {
    pub server: connection::Connection,
    pub client: connection::Connection,
    pub io: TestPairIO,
}

impl TestPair {
    /// Builds a pair in which both sides use `config`, then runs the handshake.
    ///
    /// # Errors
    /// Returns the first negotiation error reported by either side.
    pub fn handshake_with_config(config: &config::Config) -> Result<(), error::Error> {
        Self::from_configs(config, config).handshake()
    }

    /// Builds a pair in which the client and the server share `config`.
    pub fn from_config(config: &config::Config) -> Self {
        Self::from_configs(config, config)
    }

    /// Builds a client from `client_config` and a server from `server_config`,
    /// then wires them together.
    ///
    /// # Panics
    /// Panics if either connection cannot be built or registered.
    pub fn from_configs(client_config: &config::Config, server_config: &config::Config) -> Self {
        // Imported so `build_connection` resolves (assumed to be a trait method of `Builder`).
        use crate::connection::Builder;
        let client = client_config.build_connection(enums::Mode::Client).unwrap();
        let server = server_config.build_connection(enums::Mode::Server).unwrap();
        Self::from_connections(client, server)
    }

    /// Wires `client` and `server` together through two fresh in-memory buffers.
    ///
    /// Each connection sends into its own buffer and receives from its peer's.
    ///
    /// # Panics
    /// Panics if either connection rejects the I/O callback registration.
    pub fn from_connections(
        mut client: connection::Connection,
        mut server: connection::Connection,
    ) -> Self {
        let client_tx_stream = Box::pin(Default::default());
        let server_tx_stream = Box::pin(Default::default());
        Self::register_connection(&mut client, &client_tx_stream, &server_tx_stream).unwrap();
        Self::register_connection(&mut server, &server_tx_stream, &client_tx_stream).unwrap();
        let io = TestPairIO {
            server_tx_stream,
            client_tx_stream,
        };
        Self { server, client, io }
    }

    /// Installs self-service blinding and the in-memory send/receive callbacks on `conn`.
    ///
    /// `conn` writes into `send_ctx` and reads from `recv_ctx`.
    fn register_connection(
        conn: &mut connection::Connection,
        send_ctx: &Pin<Box<LocalDataBuffer>>,
        recv_ctx: &Pin<Box<LocalDataBuffer>>,
    ) -> Result<(), error::Error> {
        conn.set_blinding(Blinding::SelfService)?
            .set_send_callback(Some(Self::send_cb))?
            .set_receive_callback(Some(Self::recv_cb))?;
        // SAFETY: each context points at a `LocalDataBuffer` inside a pinned box, so its
        // address stays valid when the box is later moved into `TestPairIO`. The only
        // caller, `from_connections`, stores the buffers in the same `TestPair` as the
        // connection, and `TestPair`'s field order drops the connection first. `send_cb`
        // and `recv_cb` cast the context back to exactly `LocalDataBuffer`.
        unsafe {
            conn.set_send_context(Self::context_ptr(send_ctx))?
                .set_receive_context(Self::context_ptr(recv_ctx))?;
        }
        Ok(())
    }

    /// Converts a pinned buffer into the untyped context pointer expected by the I/O callbacks.
    fn context_ptr(buffer: &Pin<Box<LocalDataBuffer>>) -> *mut c_void {
        // Deref through `Pin<Box<_>>` to the buffer itself, so the pointer targets the heap value.
        let buffer: &LocalDataBuffer = buffer;
        buffer as *const LocalDataBuffer as *mut c_void
    }

    /// Polls the client and then the server, repeatedly, until both finish negotiating.
    ///
    /// # Errors
    /// Returns the first error observed. If both sides fail in the same round,
    /// the server's error is returned.
    ///
    /// NOTE(review): this loops forever if neither side ever completes or fails.
    pub fn handshake(&mut self) -> Result<(), error::Error> {
        loop {
            match (self.client.poll_negotiate(), self.server.poll_negotiate()) {
                (Poll::Ready(Ok(_)), Poll::Ready(Ok(_))) => return Ok(()),
                (_, Poll::Ready(Err(e))) => return Err(e),
                (Poll::Ready(Err(e)), _) => return Err(e),
                _ => {}
            }
        }
    }

    /// Send callback: appends the `len` bytes at `data` to the buffer behind `context`.
    ///
    /// Returns the number of bytes queued.
    ///
    /// # Safety
    /// `context` must point to a live `LocalDataBuffer` that is not currently
    /// borrowed, and `data` must be valid for reads of `len` bytes.
    pub(crate) unsafe extern "C" fn send_cb(
        context: *mut c_void,
        data: *const u8,
        len: u32,
    ) -> c_int {
        let context = &*(context as *const LocalDataBuffer);
        let data = core::slice::from_raw_parts(data, len as _);
        // `Write` for `VecDeque<u8>` appends the whole slice and does not fail.
        let bytes_written = context.borrow_mut().write(data).unwrap();
        bytes_written as c_int
    }

    /// Receive callback: moves up to `len` queued bytes from the buffer behind
    /// `context` into `data`.
    ///
    /// Returns the number of bytes copied. If the buffer is empty, it returns `-1`
    /// with `errno` set to `EWOULDBLOCK`.
    ///
    /// # Safety
    /// `context` must point to a live `LocalDataBuffer` that is not currently
    /// borrowed, and `data` must be valid for writes of `len` bytes.
    pub(crate) unsafe extern "C" fn recv_cb(
        context: *mut c_void,
        data: *mut u8,
        len: u32,
    ) -> c_int {
        let context = &*(context as *const LocalDataBuffer);
        let data = core::slice::from_raw_parts_mut(data, len as _);
        match context.borrow_mut().read(data) {
            // Nothing queued yet: report the non-blocking "try again" condition.
            Ok(0) => {
                errno::set_errno(errno::Errno(libc::EWOULDBLOCK));
                -1
            }
            Ok(read) => read as c_int,
            Err(err) => panic!("{err:?}"),
        }
    }
}
/// Serialized session state copied out of a session ticket.
type SessionState = Vec<u8>;

/// Session-resumption helper that stores received tickets and replays the most
/// recent one on the next connection it initializes (last in, first out).
///
/// Clones share the same ticket stack through the `Arc`, so one clone can be
/// installed as the ticket callback and another as the connection initializer.
#[derive(Debug, Clone, Default)]
pub struct LIFOSessionResumption {
    pub ticket: Arc<Mutex<Vec<SessionState>>>,
}

impl SessionTicketCallback for LIFOSessionResumption {
    /// Copies the ticket's bytes and pushes them onto the shared stack.
    ///
    /// # Panics
    /// Panics if the ticket's length or data cannot be read, or if the mutex is poisoned.
    fn on_session_ticket(
        &self,
        _connection: &mut connection::Connection,
        session_ticket: &crate::callbacks::SessionTicket,
    ) {
        // Query the length once and reuse it for both the log line and the buffer size.
        let len = session_ticket.len().unwrap();
        println!("got a session ticket of {len}");
        let mut ticket_buffer = vec![0; len];
        session_ticket.data(&mut ticket_buffer).unwrap();
        self.ticket.lock().unwrap().push(ticket_buffer);
    }
}

impl ConnectionInitializer for LIFOSessionResumption {
    /// Pops the most recently stored ticket, if any, and sets it on `connection`.
    ///
    /// # Errors
    /// Returns the error from `set_session_ticket` if the connection rejects the ticket.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn initialize_connection(
        &self,
        connection: &mut crate::connection::Connection,
    ) -> crate::callbacks::ConnectionFutureResult {
        // The lock guard is a temporary of this statement, so it is released before
        // the connection is touched below.
        let latest_ticket = self.ticket.lock().unwrap().pop();
        if let Some(ticket) = latest_ticket {
            connection.set_session_ticket(&ticket)?;
        }
        Ok(None)
    }
}