use super::SslStream;
use crate::protocols::raw_connect::ProxyDigest;
use crate::protocols::{
GetProxyDigest, GetSocketDigest, GetTimingDigest, SocketDigest, TimingDigest, IO,
};
use crate::tls::{ssl, ssl::ConnectConfiguration, ssl_sys::X509_V_ERR_INVALID_CALL};
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use std::sync::Arc;
pub async fn handshake<S: IO>(
conn_config: ConnectConfiguration,
domain: &str,
io: S,
) -> Result<SslStream<S>> {
let ssl = conn_config
.into_ssl(domain)
.explain_err(TLSHandshakeFailure, |e| format!("ssl config error: {e}"))?;
let mut stream = SslStream::new(ssl, io)
.explain_err(TLSHandshakeFailure, |e| format!("ssl stream error: {e}"))?;
let handshake_result = stream.connect().await;
match handshake_result {
Ok(()) => Ok(stream),
Err(e) => {
let context = format!("TLS connect() failed: {e}, SNI: {domain}");
match e.code() {
ssl::ErrorCode::SSL => {
#[cfg(not(feature = "boringssl"))]
fn verify_result<S>(stream: SslStream<S>) -> Result<(), i32> {
match stream.ssl().verify_result().as_raw() {
crate::tls::ssl_sys::X509_V_OK => Ok(()),
e => Err(e),
}
}
#[cfg(feature = "boringssl")]
fn verify_result<S>(stream: SslStream<S>) -> Result<(), i32> {
stream.ssl().verify_result().map_err(|e| e.as_raw())
}
match verify_result(stream) {
Ok(()) => Error::e_explain(TLSHandshakeFailure, context),
Err(X509_V_ERR_INVALID_CALL) => {
Error::e_explain(TLSHandshakeFailure, context)
}
_ => Error::e_explain(InvalidCert, context),
}
}
_ => Error::e_explain(TLSHandshakeFailure, context),
}
}
}
}
/// Timing digests for a TLS stream: the wrapped stream's digests followed by
/// this stream's own `timing` entry.
impl<S: GetTimingDigest> GetTimingDigest for SslStream<S> {
    /// Returns the inner stream's timing digests with a clone of this
    /// stream's `timing` appended as the last element.
    fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
        let inner = self.get_ref();
        let mut digests = inner.get_timing_digest();
        let own_timing = self.timing.clone();
        digests.push(Some(own_timing));
        digests
    }
}
/// Proxy digest lookup for a TLS stream, delegated to the wrapped stream.
impl<S: GetProxyDigest> GetProxyDigest for SslStream<S> {
    /// Returns whatever proxy digest the inner stream reports.
    fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
        let inner = self.get_ref();
        inner.get_proxy_digest()
    }
}
/// Socket digest access for a TLS stream, delegated to the wrapped stream.
impl<S: GetSocketDigest> GetSocketDigest for SslStream<S> {
    /// Returns whatever socket digest the inner stream reports.
    fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
        let inner = self.get_ref();
        inner.get_socket_digest()
    }

    /// Forwards `socket_digest` to the inner stream's setter.
    fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
        let inner = self.get_mut();
        inner.set_socket_digest(socket_digest);
    }
}