use crate::listeners::TlsAcceptCallbacks;
use crate::protocols::tls::rustls::TlsStream;
use crate::protocols::IO;
use crate::{listeners::tls::Acceptor, protocols::Shutdown};
use async_trait::async_trait;
use log::warn;
use pingora_error::{ErrorType::*, OrErr, Result};
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
impl<S: AsyncRead + AsyncWrite + Send + Unpin> TlsStream<S> {
    /// Starts the server-side handshake by calling `accept()` once.
    ///
    /// Returns `Ok(true)` when `accept()` succeeds and `Ok(false)` when it
    /// fails with an error whose type is `TLSWantX509Lookup`. Any other error
    /// is handed back to the caller untouched.
    async fn start_accept(mut self: Pin<&mut Self>) -> Result<bool> {
        match self.accept().await {
            Ok(()) => Ok(true),
            // Not treated as a failure: the caller is told the handshake is not done yet.
            Err(e) if e.etype == TLSWantX509Lookup => Ok(false),
            Err(e) => Err(e),
        }
    }

    /// Calls `accept()` again and returns its result as-is.
    ///
    /// Used by `handshake_with_callback` after `start_accept` reported
    /// `Ok(false)`.
    async fn resume_accept(mut self: Pin<&mut Self>) -> Result<()> {
        self.accept().await
    }
}
async fn prepare_tls_stream<S: IO>(acceptor: &Acceptor, io: S) -> Result<TlsStream<S>> {
TlsStream::from_acceptor(acceptor, io)
.await
.explain_err(TLSHandshakeFailure, |e| format!("tls stream error: {e}"))
}
/// Performs the server-side TLS handshake on `io` and returns the resulting stream.
///
/// # Errors
///
/// Returns `TLSHandshakeFailure` when the stream cannot be prepared or when
/// `accept()` fails.
pub async fn handshake<S: IO>(acceptor: &Acceptor, io: S) -> Result<TlsStream<S>> {
    let mut tls_stream = prepare_tls_stream(acceptor, io).await?;

    let accepted = tls_stream.accept().await;
    accepted.explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?;

    Ok(tls_stream)
}
/// Performs the server-side TLS handshake on `io`, accepting a set of callbacks.
///
/// `_callbacks` is never read by this implementation. If the first accept
/// attempt reports that it is not done (see `start_accept`), a warning is
/// logged and the handshake is resumed without invoking any callback.
///
/// # Errors
///
/// Returns `TLSHandshakeFailure` when the stream cannot be prepared or when
/// the resumed accept fails. An error from the first accept attempt is
/// propagated unchanged.
pub async fn handshake_with_callback<S: IO>(
    acceptor: &Acceptor,
    io: S,
    _callbacks: &TlsAcceptCallbacks,
) -> Result<TlsStream<S>> {
    let mut tls_stream = prepare_tls_stream(acceptor, io).await?;
    let done = Pin::new(&mut tls_stream).start_accept().await?;
    if !done {
        // This is the point where callbacks would run; they are skipped here.
        warn!("Callbacks are not supported with feature \"rustls\".");
        Pin::new(&mut tls_stream)
            .resume_accept()
            .await
            .explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?;
    }
    Ok(tls_stream)
}
#[async_trait]
impl<S> Shutdown for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Sync + Unpin + Send,
{
    /// Shuts the stream down through `AsyncWriteExt::shutdown`.
    ///
    /// A failure is only logged as a warning; it is not returned to the caller.
    async fn shutdown(&mut self) {
        // Fully qualified so the call resolves to `AsyncWriteExt::shutdown`
        // rather than recursing into this trait method.
        if let Err(e) = <Self as AsyncWriteExt>::shutdown(self).await {
            warn!("TLS shutdown failed, {e}");
        }
    }
}
// Placeholder for an async-certificate (callback) test under the rustls feature.
// Marked `#[ignore]` because the body is only a `todo!`, which would panic if run.
#[ignore]
#[tokio::test]
async fn test_async_cert() {
todo!("callback support and test for Rustls")
}