pub mod client;
pub mod digest;
pub mod server;
use crate::protocols::digest::TimingDigest;
use crate::protocols::{Ssl, UniqueID};
use crate::tls::{self, ssl, tokio_ssl::SslStream as InnerSsl};
use log::warn;
use pingora_error::{ErrorType::*, OrErr, Result};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::SystemTime;
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
pub use digest::SslDigest;
/// A TLS stream over an inner transport `T`, plus per-connection metadata.
///
/// It wraps the `SslStream` from `crate::tls::tokio_ssl`. It also carries a TLS
/// session digest and timing information, which `connect`/`accept` fill in
/// after a successful handshake.
#[derive(Debug)]
pub struct SslStream<T> {
    /// The wrapped TLS stream (`crate::tls::tokio_ssl::SslStream`).
    ssl: InnerSsl<T>,
    /// Digest of the TLS session. `new` sets it to `None`; `connect`/`accept`
    /// replace it once their handshake succeeds.
    digest: Option<Arc<SslDigest>>,
    /// Timing info. `connect`/`accept` write `established_ts` after the handshake.
    timing: TimingDigest,
}
impl<T> SslStream<T>
where
T: AsyncRead + AsyncWrite + std::marker::Unpin,
{
pub fn new(ssl: ssl::Ssl, stream: T) -> Result<Self> {
let ssl = InnerSsl::new(ssl, stream)
.explain_err(TLSHandshakeFailure, |e| format!("ssl stream error: {e}"))?;
Ok(SslStream {
ssl,
digest: None,
timing: Default::default(),
})
}
pub async fn connect(&mut self) -> Result<(), ssl::Error> {
Self::clear_error();
Pin::new(&mut self.ssl).connect().await?;
self.timing.established_ts = SystemTime::now();
self.digest = Some(Arc::new(SslDigest::from_ssl(self.ssl())));
Ok(())
}
pub async fn accept(&mut self) -> Result<(), ssl::Error> {
Self::clear_error();
Pin::new(&mut self.ssl).accept().await?;
self.timing.established_ts = SystemTime::now();
self.digest = Some(Arc::new(SslDigest::from_ssl(self.ssl())));
Ok(())
}
#[inline]
fn clear_error() {
let errs = tls::error::ErrorStack::get();
if !errs.errors().is_empty() {
warn!("Clearing dirty TLS error stack: {}", errs);
}
}
}
impl<T> SslStream<T> {
    /// Returns a shared handle to the TLS session digest.
    ///
    /// Returns `None` if no digest has been recorded, e.g. before
    /// `connect`/`accept` has succeeded.
    pub fn ssl_digest(&self) -> Option<Arc<SslDigest>> {
        self.digest.as_ref().map(Arc::clone)
    }
}
use std::ops::{Deref, DerefMut};
impl<T> Deref for SslStream<T> {
    type Target = InnerSsl<T>;

    /// Gives shared access to the wrapped TLS stream.
    fn deref(&self) -> &InnerSsl<T> {
        let Self { ssl, .. } = self;
        ssl
    }
}
impl<T> DerefMut for SslStream<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.ssl
}
}
impl<T> AsyncRead for SslStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads through the wrapped TLS stream after flushing any stale TLS errors.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Self::clear_error();
        let this = self.get_mut();
        Pin::new(&mut this.ssl).poll_read(cx, buf)
    }
}
impl<T> AsyncWrite for SslStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes through the wrapped TLS stream after flushing any stale TLS errors.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Self::clear_error();
        let this = self.get_mut();
        Pin::new(&mut this.ssl).poll_write(cx, buf)
    }

    /// Flushes the wrapped TLS stream after flushing any stale TLS errors.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Self::clear_error();
        let this = self.get_mut();
        Pin::new(&mut this.ssl).poll_flush(cx)
    }

    /// Shuts down the wrapped TLS stream after flushing any stale TLS errors.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Self::clear_error();
        let this = self.get_mut();
        Pin::new(&mut this.ssl).poll_shutdown(cx)
    }

    /// Forwards a vectored write to the wrapped TLS stream after flushing any
    /// stale TLS errors.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Self::clear_error();
        let this = self.get_mut();
        Pin::new(&mut this.ssl).poll_write_vectored(cx, bufs)
    }

    /// Always reports vectored writes as supported.
    ///
    /// NOTE(review): this is unconditional and does not consult the inner
    /// stream. Confirm that the inner stream's `poll_write_vectored` really
    /// handles more than the first non-empty slice.
    fn is_write_vectored(&self) -> bool {
        true
    }
}
impl<T> UniqueID for SslStream<T>
where
    T: UniqueID,
{
    /// Uses the identifier of the underlying transport beneath the TLS layer.
    fn id(&self) -> i32 {
        let transport: &T = self.ssl.get_ref();
        transport.id()
    }
}
impl<T> Ssl for SslStream<T> {
    /// Exposes the session reference of the wrapped TLS stream. Always `Some`.
    fn get_ssl(&self) -> Option<&ssl::SslRef> {
        let session: &ssl::SslRef = self.ssl();
        Some(session)
    }

    /// Returns a shared handle to the recorded TLS digest, if any.
    fn get_ssl_digest(&self) -> Option<Arc<SslDigest>> {
        self.digest.clone()
    }
}
/// The HTTP versions to advertise via TLS ALPN (Application-Layer Protocol
/// Negotiation).
///
/// `PartialEq`/`Eq` are derived alongside `Hash` so that values can be compared
/// and used as hash-map keys.
#[derive(Hash, Clone, Debug, PartialEq, Eq)]
pub enum ALPN {
    /// Advertise only `http/1.1`.
    H1,
    /// Advertise only `h2`.
    H2,
    /// Advertise `h2` first, then `http/1.1`.
    H2H1,
}

impl std::fmt::Display for ALPN {
    /// Writes the variant name (`H1`, `H2` or `H2H1`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ALPN::H1 => "H1",
            ALPN::H2 => "H2",
            ALPN::H2H1 => "H2H1",
        })
    }
}

impl ALPN {
    /// Builds an [`ALPN`] from the highest (`max`) and lowest (`min`) HTTP major
    /// versions allowed.
    ///
    /// - `max == 1` yields [`ALPN::H1`], regardless of `min`.
    /// - Otherwise `min == 2` yields [`ALPN::H2`].
    /// - Any other combination yields [`ALPN::H2H1`].
    pub fn new(max: u8, min: u8) -> Self {
        if max == 1 {
            ALPN::H1
        } else if min == 2 {
            ALPN::H2
        } else {
            ALPN::H2H1
        }
    }

    /// Highest HTTP major version this preference allows (1 or 2).
    pub fn get_max_http_version(&self) -> u8 {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            ALPN::H1 => 1,
            ALPN::H2 | ALPN::H2H1 => 2,
        }
    }

    /// Lowest HTTP major version this preference allows (1 or 2).
    pub fn get_min_http_version(&self) -> u8 {
        match self {
            ALPN::H2 => 2,
            ALPN::H1 | ALPN::H2H1 => 1,
        }
    }

    /// ALPN protocol list in wire format, in preference order.
    ///
    /// Each protocol name is preceded by a one-byte length prefix. The data is
    /// a static literal, so the returned slice is `'static` and does not borrow
    /// `self`.
    pub(crate) fn to_wire_preference(&self) -> &'static [u8] {
        match self {
            Self::H1 => b"\x08http/1.1",
            Self::H2 => b"\x02h2",
            Self::H2H1 => b"\x02h2\x08http/1.1",
        }
    }

    /// Parses one protocol identifier, without a length prefix. Presumably this
    /// is the protocol selected during negotiation.
    ///
    /// Returns `None` for identifiers other than `http/1.1` and `h2`. It never
    /// returns [`ALPN::H2H1`], because a single selected protocol is one version.
    pub(crate) fn from_wire_selected(raw: &[u8]) -> Option<Self> {
        match raw {
            b"http/1.1" => Some(Self::H1),
            b"h2" => Some(Self::H2),
            _ => None,
        }
    }
}