use crate::{certificate, client, server};
use core::{
sync::atomic::{AtomicBool, AtomicU8, Ordering},
task::Poll,
};
use openssl::{ec::EcKey, ecdsa::EcdsaSig};
use pin_project::pin_project;
use s2n_quic_core::{
crypto::tls::{
self,
testing::certificates::{CERT_PEM, KEY_PEM, UNTRUSTED_CERT_PEM, UNTRUSTED_KEY_PEM},
Endpoint,
},
transport,
};
#[cfg(any(test, feature = "unstable_client_hello"))]
use s2n_tls::callbacks::ClientHelloCallback;
#[cfg(any(test, feature = "unstable_client_hello"))]
use s2n_tls::callbacks::{PrivateKeyCallback, PrivateKeyOperation};
use s2n_tls::{
callbacks::{ConnectionFuture, VerifyHostNameCallback},
connection::Connection,
error::Error,
};
use std::{sync::Arc, time::SystemTime};
pub struct MyCallbackHandler {
done: Arc<AtomicBool>,
wait_counter: Arc<AtomicU8>,
}
impl MyCallbackHandler {
fn new(wait_counter: u8) -> Self {
MyCallbackHandler {
done: Arc::new(AtomicBool::new(false)),
wait_counter: Arc::new(AtomicU8::new(wait_counter)),
}
}
}
#[cfg(any(test, feature = "unstable_client_hello"))]
impl ClientHelloCallback for MyCallbackHandler {
fn on_client_hello(
&self,
_connection: &mut Connection,
) -> Result<Option<std::pin::Pin<Box<dyn s2n_tls::callbacks::ConnectionFuture>>>, Error> {
let fut = MyConnectionFuture {
done: self.done.clone(),
wait_counter: self.wait_counter.clone(),
};
Ok(Some(Box::pin(fut)))
}
}
struct MyConnectionFuture {
done: Arc<AtomicBool>,
wait_counter: Arc<AtomicU8>,
}
impl ConnectionFuture for MyConnectionFuture {
fn poll(
self: std::pin::Pin<&mut Self>,
_connection: &mut Connection,
_ctx: &mut core::task::Context,
) -> Poll<Result<(), Error>> {
if self.wait_counter.fetch_sub(1, Ordering::SeqCst) == 0 {
self.done.store(true, Ordering::SeqCst);
return Poll::Ready(Ok(()));
}
Poll::Pending
}
}
pub static TICKET_KEY: [u8; 16] = [0; 16];
pub static TICKET_KEY_NAME: &[u8] = "keyname".as_bytes();
struct ResumptionConfig;
impl ResumptionConfig {
fn build() -> Result<s2n_tls::config::Config, s2n_tls::error::Error> {
let mut config_builder = s2n_tls::config::Builder::new();
config_builder
.enable_session_tickets(true)?
.add_session_ticket_key(TICKET_KEY_NAME, &TICKET_KEY, SystemTime::now())?
.load_pem(CERT_PEM.as_bytes(), KEY_PEM.as_bytes())?
.set_security_policy(&s2n_tls::security::DEFAULT_TLS13)?
.enable_quic()?
.set_application_protocol_preference([b"h3"])?;
config_builder.build()
}
}
impl crate::ConfigLoader for ResumptionConfig {
fn load(&mut self, _cx: crate::ConnectionContext) -> s2n_tls::config::Config {
Self::build().expect("Config builder failed")
}
}
#[cfg(any(test, feature = "unstable_private_key"))]
impl PrivateKeyCallback for MyCallbackHandler {
fn handle_operation(
&self,
_connection: &mut Connection,
op: PrivateKeyOperation,
) -> Result<Option<std::pin::Pin<Box<dyn s2n_tls::callbacks::ConnectionFuture>>>, Error> {
let future = MyConnectionFuture {
done: self.done.clone(),
wait_counter: self.wait_counter.clone(),
};
let op = Some(op);
let pkey_future = MyPrivateKeyFuture { future, op };
Ok(Some(Box::pin(pkey_future)))
}
}
#[pin_project]
struct MyPrivateKeyFuture {
#[pin]
future: MyConnectionFuture,
#[pin]
op: Option<PrivateKeyOperation>,
}
impl ConnectionFuture for MyPrivateKeyFuture {
fn poll(
self: std::pin::Pin<&mut Self>,
conn: &mut Connection,
ctx: &mut core::task::Context,
) -> Poll<Result<(), Error>> {
let mut this = self.project();
if this.future.poll(conn, ctx).is_pending() {
return Poll::Pending;
}
let op = this.op.take().expect("Missing pkey operation");
let in_buf_size = op.input_size()?;
let mut in_buf = vec![0; in_buf_size];
op.input(&mut in_buf)?;
let key = EcKey::private_key_from_pem(KEY_PEM.as_bytes())
.expect("Failed to create EcKey from pem");
let sig = EcdsaSig::sign(&in_buf, &key).expect("Failed to sign input");
let out = sig.to_der().expect("Failed to convert signature to der");
op.set_output(conn, &out)?;
Poll::Ready(Ok(()))
}
}
pub struct VerifyHostNameClientCertVerifier {
host_name: String,
}
impl VerifyHostNameCallback for VerifyHostNameClientCertVerifier {
fn verify_host_name(&self, host_name: &str) -> bool {
self.host_name == host_name
}
}
impl VerifyHostNameClientCertVerifier {
pub fn new(host_name: impl ToString) -> VerifyHostNameClientCertVerifier {
VerifyHostNameClientCertVerifier {
host_name: host_name.to_string(),
}
}
}
#[derive(Default)]
pub struct RejectAllClientCertificatesHandler {}
impl VerifyHostNameCallback for RejectAllClientCertificatesHandler {
fn verify_host_name(&self, _host_name: &str) -> bool {
false
}
}
fn s2n_client() -> client::Client {
client::Builder::default()
.with_certificate(CERT_PEM)
.unwrap()
.build()
.unwrap()
}
fn s2n_client_with_client_auth() -> Result<client::Client, Error> {
client::Builder::default()
.with_empty_trust_store()?
.with_certificate(CERT_PEM)?
.with_client_identity(CERT_PEM, KEY_PEM)?
.build()
}
fn s2n_client_with_untrusted_client_auth() -> Result<client::Client, Error> {
client::Builder::default()
.with_empty_trust_store()?
.with_certificate(CERT_PEM)?
.with_client_identity(UNTRUSTED_CERT_PEM, UNTRUSTED_KEY_PEM)?
.build()
}
fn s2n_client_with_fixed_hostname_auth(host_name: &str) -> Result<client::Client, Error> {
client::Builder::default()
.with_certificate(CERT_PEM)
.unwrap()
.with_verify_host_name_callback(VerifyHostNameClientCertVerifier::new(host_name))
.unwrap()
.build()
}
fn s2n_client_with_resumption() -> Result<client::Client, Error> {
let mut builder = client::Builder::default().with_certificate(CERT_PEM)?;
struct TicketCallback;
impl s2n_tls::callbacks::SessionTicketCallback for TicketCallback {
fn on_session_ticket(
&self,
_connection: &mut Connection,
_session_ticket: &s2n_tls::callbacks::SessionTicket,
) {
}
}
let cb = TicketCallback;
builder.config_mut().set_session_ticket_callback(cb)?;
builder.config_mut().enable_session_tickets(true)?;
builder.build()
}
fn s2n_server() -> server::Server {
server::Builder::default()
.with_certificate(CERT_PEM, KEY_PEM)
.unwrap()
.build()
.unwrap()
}
fn s2n_server_with_client_auth() -> Result<server::Server, Error> {
server::Builder::default()
.with_empty_trust_store()?
.with_client_authentication()?
.with_verify_host_name_callback(VerifyHostNameClientCertVerifier::new("qlaws.qlaws"))?
.with_certificate(CERT_PEM, KEY_PEM)?
.with_trusted_certificate(CERT_PEM)?
.build()
}
fn s2n_server_with_resumption() -> server::Server<ResumptionConfig> {
server::Server::from_loader(ResumptionConfig)
}
fn s2n_server_with_client_auth_verifier_rejects_client_certs() -> Result<server::Server, Error> {
server::Builder::default()
.with_empty_trust_store()?
.with_client_authentication()?
.with_verify_host_name_callback(RejectAllClientCertificatesHandler::default())?
.with_certificate(CERT_PEM, KEY_PEM)?
.with_trusted_certificate(CERT_PEM)?
.build()
}
#[cfg(any(test, feature = "unstable_client_hello"))]
fn s2n_server_with_client_hello_callback(wait_counter: u8) -> (server::Server, Arc<AtomicBool>) {
let handle = MyCallbackHandler::new(wait_counter);
let done = handle.done.clone();
let tls = server::Builder::default()
.with_certificate(CERT_PEM, KEY_PEM)
.unwrap()
.with_client_hello_handler(handle)
.unwrap()
.build()
.unwrap();
(tls, done)
}
#[cfg(any(test, feature = "unstable_private_key"))]
fn s2n_server_with_private_key_callback(wait_counter: u8) -> (server::Server, Arc<AtomicBool>) {
let handle = MyCallbackHandler::new(wait_counter);
let done = handle.done.clone();
let tls = server::Builder::default()
.with_certificate(CERT_PEM, certificate::OFFLOAD_PRIVATE_KEY)
.unwrap()
.with_private_key_handler(handle)
.unwrap()
.build()
.unwrap();
(tls, done)
}
fn rustls_server() -> s2n_quic_rustls::server::Server {
s2n_quic_rustls::server::Builder::default()
.with_certificate(CERT_PEM, KEY_PEM)
.unwrap()
.build()
.unwrap()
}
fn rustls_client() -> s2n_quic_rustls::client::Client {
s2n_quic_rustls::client::Builder::default()
.with_certificate(CERT_PEM)
.unwrap()
.build()
.unwrap()
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_s2n_server_ch_callback_test() {
for wait_counter in 0..=10 {
let mut client_endpoint = s2n_client();
let (mut server_endpoint, done) = s2n_server_with_client_hello_callback(wait_counter);
run(&mut server_endpoint, &mut client_endpoint, Some(done));
}
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_s2n_server_pkey_callback_test() {
for wait_counter in 0..=10 {
let mut client_endpoint = s2n_client();
let (mut server_endpoint, done) = s2n_server_with_private_key_callback(wait_counter);
run(&mut server_endpoint, &mut client_endpoint, Some(done));
}
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_s2n_server_test() {
let mut client_endpoint = s2n_client();
let mut server_endpoint = s2n_server();
run(&mut server_endpoint, &mut client_endpoint, None);
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_s2n_server_resumption_test() {
let mut client_endpoint = s2n_client_with_resumption().unwrap();
let mut server_endpoint = s2n_server_with_resumption();
let pair = run_result(&mut server_endpoint, &mut client_endpoint, None).unwrap();
assert!(
!pair.client.context.application.rx.is_empty(),
"expected session ticket message in RX"
);
}
#[test]
#[cfg_attr(miri, ignore)]
fn rustls_client_s2n_server_resumption_test() {
let mut client_endpoint = rustls_client();
let mut server_endpoint = s2n_server_with_resumption();
let pair = run_result(&mut server_endpoint, &mut client_endpoint, None).unwrap();
assert!(
!pair.client.context.application.rx.is_empty(),
"expected session ticket message in RX"
);
}
#[test]
#[cfg_attr(miri, ignore)]
fn rustls_client_s2n_server_test() {
let mut client_endpoint = rustls_client();
let mut server_endpoint = s2n_server();
run(&mut server_endpoint, &mut client_endpoint, None);
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_rustls_server_test() {
let mut client_endpoint = s2n_client();
let mut server_endpoint = rustls_server();
run(&mut server_endpoint, &mut client_endpoint, None);
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_s2n_server_client_auth_test() {
let mut client_endpoint = s2n_client_with_client_auth().unwrap();
let mut server_endpoint = s2n_server_with_client_auth().unwrap();
run(&mut server_endpoint, &mut client_endpoint, None);
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_no_client_auth_s2n_server_requires_client_auth_test() {
let mut client_endpoint = s2n_client();
let mut server_endpoint = s2n_server_with_client_auth().unwrap();
let test_result = run_result(&mut server_endpoint, &mut client_endpoint, None);
assert!(test_result.is_err());
let e = test_result.unwrap_err();
assert_eq!(e.description().unwrap(), "CERTIFICATE_REQUIRED");
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_with_client_auth_s2n_server_does_not_require_client_auth_test() {
let mut client_endpoint = s2n_client_with_client_auth().unwrap();
let mut server_endpoint = s2n_server();
let test_result = run_result(&mut server_endpoint, &mut client_endpoint, None);
assert!(test_result.is_err());
let e = test_result.unwrap_err();
assert_eq!(e.description().unwrap(), "UNEXPECTED_MESSAGE");
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_with_client_auth_s2n_server_does_not_trust_client_certificate() {
let mut client_endpoint = s2n_client_with_client_auth().unwrap();
let mut server_endpoint = s2n_server_with_client_auth_verifier_rejects_client_certs().unwrap();
let test_result = run_result(&mut server_endpoint, &mut client_endpoint, None);
assert!(test_result.is_err());
let e = test_result.unwrap_err();
assert_eq!(e.description().unwrap(), "HANDSHAKE_FAILURE");
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_with_client_auth_s2n_server_does_not_trust_issuer() {
let mut client_endpoint = s2n_client_with_untrusted_client_auth().unwrap();
let mut server_endpoint = s2n_server_with_client_auth().unwrap();
let test_result = run_result(&mut server_endpoint, &mut client_endpoint, None);
assert!(test_result.is_err());
let e = test_result.unwrap_err();
assert_eq!(e.description().unwrap(), "HANDSHAKE_FAILURE");
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_with_custom_hostname_auth_rejects_server_name() {
let mut client_endpoint = s2n_client_with_fixed_hostname_auth("not-localhost").unwrap();
let mut server_endpoint = s2n_server();
let test_result = run_result(&mut server_endpoint, &mut client_endpoint, None);
assert!(test_result.is_err());
let e = test_result.unwrap_err();
assert_eq!(e.description().unwrap(), "HANDSHAKE_FAILURE");
}
#[test]
#[cfg_attr(miri, ignore)]
fn s2n_client_with_custom_hostname_auth_accepts_server_name() {
let mut client_endpoint = s2n_client_with_fixed_hostname_auth("localhost").unwrap();
let mut server_endpoint = s2n_server();
run_result(&mut server_endpoint, &mut client_endpoint, None).unwrap();
}
fn run_result<S: Endpoint, C: Endpoint>(
server: &mut S,
client: &mut C,
client_hello_cb_done: Option<Arc<AtomicBool>>,
) -> Result<tls::testing::Pair<S::Session, C::Session>, transport::Error> {
let mut pair = tls::testing::Pair::new(server, client, "localhost".into());
while pair.is_handshaking() {
pair.poll(client_hello_cb_done.as_ref())?;
}
pair.finish();
Ok(pair)
}
fn run<S: Endpoint, C: Endpoint>(
server: &mut S,
client: &mut C,
client_hello_cb_done: Option<Arc<AtomicBool>>,
) {
run_result(server, client, client_hello_cb_done).unwrap();
}
#[test]
fn config_loader() {
use crate::{ConfigLoader, Server};
let server = Server::default();
let server: Server<Server> = Server::from_loader(server);
let server: Box<dyn ConfigLoader> = Box::new(server);
let mut server: Server<Box<dyn ConfigLoader>> = Server::from_loader(server);
let _ = server.new_server_session(&1);
}