use super::SslStream;
use crate::protocols::{Shutdown, IO};
use crate::tls::ext;
use crate::tls::ext::ssl_from_acceptor;
use crate::tls::ssl;
use crate::tls::ssl::{SslAcceptor, SslRef};
use async_trait::async_trait;
use log::warn;
use pingora_error::{ErrorType::*, OrErr, Result};
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
pub fn prepare_tls_stream<S: IO>(ssl_acceptor: &SslAcceptor, io: S) -> Result<SslStream<S>> {
let ssl = ssl_from_acceptor(ssl_acceptor)
.explain_err(TLSHandshakeFailure, |e| format!("ssl_acceptor error: {e}"))?;
SslStream::new(ssl, io).explain_err(TLSHandshakeFailure, |e| format!("ssl stream error: {e}"))
}
/// Performs a complete server-side TLS handshake over `io` using `ssl_acceptor`.
///
/// # Errors
///
/// Returns a `TLSHandshakeFailure` error if the stream cannot be prepared or if
/// `accept()` fails.
pub async fn handshake<S: IO>(ssl_acceptor: &SslAcceptor, io: S) -> Result<SslStream<S>> {
    let mut tls_stream = prepare_tls_stream(ssl_acceptor, io)?;
    let accept_result = tls_stream.accept().await;
    accept_result.explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?;
    Ok(tls_stream)
}
/// Performs a server-side TLS handshake that may pause so `callbacks` can supply a certificate.
///
/// The handshake is started with [`ResumableAccept::start_accept`]. If it does not finish in
/// one go (it paused for a certificate), [`TlsAccept::certificate_callback`] is awaited with
/// mutable access to the SSL object. The handshake is then resumed with
/// [`ResumableAccept::resume_accept`].
///
/// # Errors
///
/// Returns a `TLSHandshakeFailure` error if the stream cannot be prepared, or if either the
/// initial or the resumed accept fails.
pub async fn handshake_with_callback<S: IO>(
    ssl_acceptor: &SslAcceptor,
    io: S,
    callbacks: &TlsAcceptCallbacks,
) -> Result<SslStream<S>> {
    let mut tls_stream = prepare_tls_stream(ssl_acceptor, io)?;
    let done = Pin::new(&mut tls_stream)
        .start_accept()
        .await
        .explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?;

    if !done {
        // SAFETY: `tls_stream` is a local owned exclusively by this function and the handshake
        // is paused, so nothing else touches its SSL object while the callback holds this
        // mutable reference. The reference is no longer used once the handshake is resumed
        // below.
        let ssl_mut = unsafe { ext::ssl_mut(tls_stream.ssl()) };
        callbacks.certificate_callback(ssl_mut).await;
        Pin::new(&mut tls_stream)
            .resume_accept()
            .await
            .explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?;
    }

    Ok(tls_stream)
}
/// Hooks invoked during a TLS handshake driven by [`handshake_with_callback`].
#[async_trait]
pub trait TlsAccept {
    /// Called when the handshake has paused because a certificate is needed.
    ///
    /// Implementations may use the SSL object to install a certificate and private key before
    /// the handshake is resumed. The default implementation does nothing.
    async fn certificate_callback(&self, _ssl: &mut SslRef) {}
}

/// A boxed [`TlsAccept`] implementation that is `Send + Sync`, as taken by
/// [`handshake_with_callback`].
pub type TlsAcceptCallbacks = Box<dyn TlsAccept + Send + Sync>;
#[async_trait]
impl<S> Shutdown for SslStream<S>
where
    S: AsyncRead + AsyncWrite + Sync + Unpin + Send,
{
    /// Shuts down the TLS stream on a best-effort basis.
    ///
    /// Failures are logged as warnings and otherwise ignored.
    async fn shutdown(&mut self) {
        // Fully qualified so this calls `AsyncWriteExt::shutdown`, not this method.
        let outcome = <Self as AsyncWriteExt>::shutdown(self).await;
        if let Err(e) = outcome {
            warn!("TLS shutdown failed, {e}");
        }
    }
}
/// A server-side TLS handshake that can pause when a certificate is required and be resumed.
#[async_trait]
pub trait ResumableAccept {
    /// Starts the handshake.
    ///
    /// Returns `Ok(true)` when the handshake finished. Returns `Ok(false)` when it paused so a
    /// certificate can be supplied; in that case call [`ResumableAccept::resume_accept`] to
    /// continue.
    async fn start_accept(self: Pin<&mut Self>) -> Result<bool, ssl::Error>;

    /// Continues a handshake that [`ResumableAccept::start_accept`] left paused.
    async fn resume_accept(self: Pin<&mut Self>) -> Result<(), ssl::Error>;
}
#[async_trait]
impl<S: AsyncRead + AsyncWrite + Send + Unpin> ResumableAccept for SslStream<S> {
    /// Marks the SSL object to suspend when a certificate is needed, then runs `accept()`.
    ///
    /// An error recognised by `ext::is_suspended_for_cert` is reported as `Ok(false)` (the
    /// handshake paused). Any other error is returned unchanged.
    async fn start_accept(mut self: Pin<&mut Self>) -> Result<bool, ssl::Error> {
        // SAFETY: `self` is an exclusive `Pin<&mut Self>`, so no other reference to this
        // stream's SSL object is live. The mutable reference is only used for the call
        // directly below, before `accept()` borrows the stream again.
        let ssl_mut = unsafe { ext::ssl_mut(self.ssl()) };
        ext::suspend_when_need_ssl_cert(ssl_mut);

        match self.accept().await {
            Ok(()) => Ok(true),
            // Paused waiting for a certificate; this is not a failure.
            Err(e) if ext::is_suspended_for_cert(&e) => Ok(false),
            Err(e) => Err(e),
        }
    }

    /// Clears the certificate suspension set by `start_accept` and runs `accept()` again.
    async fn resume_accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
        // SAFETY: same reasoning as in `start_accept`. We hold the only (pinned, mutable)
        // reference to the stream, and the SSL reference is not used after the next statement.
        let ssl_mut = unsafe { ext::ssl_mut(self.ssl()) };
        ext::unblock_ssl_cert(ssl_mut);
        self.accept().await
    }
}
/// End-to-end check that `handshake_with_callback` pauses for a certificate, runs the
/// callback (which sees the client's SNI), and completes the handshake.
#[tokio::test]
async fn test_async_cert() {
    use tokio::io::AsyncReadExt;

    let acceptor = ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls())
        .unwrap()
        .build();

    // Callback that checks the requested server name and installs the test key pair.
    struct Callback;

    #[async_trait]
    impl TlsAccept for Callback {
        async fn certificate_callback(&self, ssl: &mut SslRef) {
            assert_eq!(
                ssl.servername(ssl::NameType::HOST_NAME).unwrap(),
                "pingora.org"
            );
            let cert = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR"));
            let key = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR"));
            let cert_bytes = std::fs::read(cert).unwrap();
            let cert = crate::tls::x509::X509::from_pem(&cert_bytes).unwrap();
            let key_bytes = std::fs::read(key).unwrap();
            let key = crate::tls::pkey::PKey::private_key_from_pem(&key_bytes).unwrap();
            ext::ssl_use_certificate(ssl, &cert).unwrap();
            ext::ssl_use_private_key(ssl, &key).unwrap();
        }
    }

    let cb: TlsAcceptCallbacks = Box::new(Callback);
    let (client, server) = tokio::io::duplex(1024);

    let client_task = tokio::spawn(async move {
        let ssl_context = ssl::SslContext::builder(ssl::SslMethod::tls())
            .unwrap()
            .build();
        let mut ssl = ssl::Ssl::new(&ssl_context).unwrap();
        ssl.set_hostname("pingora.org").unwrap();
        // Verification is disabled; this test only exercises the server-side callback flow.
        ssl.set_verify(ssl::SslVerifyMode::NONE);
        let mut stream = SslStream::new(ssl, client).unwrap();
        Pin::new(&mut stream).connect().await.unwrap();
        // Wait for data or EOF from the server; the result itself is irrelevant.
        let mut buf = [0; 1];
        let _ = stream.read(&mut buf).await;
    });

    let server_stream = handshake_with_callback(&acceptor, server, &cb)
        .await
        .unwrap();
    // Dropping the server stream closes its end of the duplex pipe, so the client's pending
    // read returns and the task can finish.
    drop(server_stream);
    // Propagate any panic from the client task (e.g. a failed `connect`) instead of losing it.
    client_task.await.unwrap();
}