#![warn(missing_docs)]
#![doc(html_root_url = "https://docs.rs/hyper-openssl/0.8")]
use antidote::Mutex;
use bytes::{Buf, BufMut};
use hyper::client::connect::{Connect, Connected, Destination};
#[cfg(feature = "runtime")]
use hyper::client::HttpConnector;
use openssl::error::ErrorStack;
use openssl::ex_data::Index;
use openssl::ssl::{
ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslSessionCacheMode,
};
#[cfg(feature = "runtime")]
use openssl::ssl::SslMethod;
use std::error::Error;
use std::fmt::Debug;
use std::io;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_openssl::SslStream;
use cache::{SessionCache, SessionKey};
use std::future::Future;
use std::pin::Pin;
use std::task::{Poll, Context};
use once_cell::sync::OnceCell;
mod cache;
#[cfg(test)]
mod test;
/// Returns the `Ssl` ex-data index under which each connection's [`SessionKey`] is stored.
///
/// The index is allocated on first use and cached in a process-wide cell. If
/// allocation fails, the OpenSSL error is returned and the cell stays empty, so a
/// later call tries again.
fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
    static KEY_INDEX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();

    let index = KEY_INDEX.get_or_try_init(Ssl::new_ex_index)?;
    Ok(*index)
}
/// TLS state shared by every clone of an [`HttpsConnector`].
#[derive(Clone)]
struct Inner {
    /// Connector from which each connection's TLS configuration is derived.
    ssl: SslConnector,
    /// Client session cache. `HttpsConnector::with_connector` installs OpenSSL's
    /// new/remove session callbacks, and they update this same cache.
    cache: Arc<Mutex<SessionCache>>,
    /// Optional user hook. It is applied to every connection's configuration
    /// before any cached session is looked up (see `HttpsConnector::set_callback`).
    callback: Option<
        Arc<dyn Fn(&mut ConnectConfiguration, &Destination) -> Result<(), ErrorStack> + Sync + Send>,
    >,
}

impl Inner {
    /// Builds the TLS configuration for a single connection to `destination`.
    ///
    /// The steps, in order:
    /// 1. Run the user callback, if one is set.
    /// 2. Resume a session cached for the destination's host and port, if present.
    /// 3. Attach the destination's [`SessionKey`] to the configuration as ex-data.
    ///    The new-session callback uses it to file any negotiated session under
    ///    the right key.
    ///
    /// # Errors
    ///
    /// Returns any `ErrorStack` produced by OpenSSL or by the user callback.
    fn setup_ssl(&self, destination: &Destination) -> Result<ConnectConfiguration, ErrorStack> {
        let mut conf = self.ssl.configure()?;

        if let Some(ref callback) = self.callback {
            callback(&mut conf, destination)?;
        }

        let key = SessionKey {
            host: destination.host().to_string(),
            // Within this file, this method is only used for `https` destinations,
            // whose default port is 443.
            port: destination.port().unwrap_or(443),
        };

        // Look the session up in its own statement. The cache lock is then released
        // before calling into OpenSSL, rather than held for the whole `if let`.
        let cached = self.cache.lock().get(&key);
        if let Some(session) = cached {
            // SAFETY: `set_session` requires the session to belong to the same
            // `SslContext` as this connection. The cache is only filled by the
            // new-session callback installed on the builder that produced `self.ssl`.
            // This assumes the user callback did not swap this connection's
            // context; see `SslRef::set_ssl_context`.
            unsafe {
                conf.set_session(&session)?;
            }
        }

        let idx = key_index()?;
        conf.set_ex_data(idx, key);
        Ok(conf)
    }
}
/// A connector that uses OpenSSL to add TLS to the connections of an inner connector.
///
/// Destinations with the `https` scheme are wrapped in a TLS session. All other
/// destinations are passed through as plain connections from the inner connector.
/// TLS sessions are cached by host and port and resumed on later connections.
#[derive(Clone)]
pub struct HttpsConnector<T> {
    /// Connector that opens the underlying transport.
    http: T,
    /// TLS configuration, session cache and user callback, shared across clones.
    inner: Inner,
}
#[cfg(feature = "runtime")]
impl HttpsConnector<HttpConnector> {
    /// Creates a connector backed by hyper's default `HttpConnector`.
    ///
    /// The inner connector is configured with `enforce_http(false)` so that it also
    /// opens connections for `https` destinations. When compiled with the
    /// `ossl102` cfg (OpenSSL 1.0.2+, the first release with ALPN), ALPN offers
    /// `h2` and `http/1.1`.
    ///
    /// # Errors
    ///
    /// Returns an error if OpenSSL fails to create or configure the connector.
    pub fn new() -> Result<HttpsConnector<HttpConnector>, ErrorStack> {
        let mut http = HttpConnector::new();
        http.enforce_http(false);

        // `ssl` is only mutated when ALPN support is compiled in. Silence the lint
        // for the other configuration, instead of a `ssl = ssl` self-assignment
        // (which clippy rejects as `self_assignment`).
        #[cfg_attr(not(ossl102), allow(unused_mut))]
        let mut ssl = SslConnector::builder(SslMethod::tls())?;

        #[cfg(ossl102)]
        ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?;

        HttpsConnector::with_connector(http, ssl)
    }
}
impl<T> HttpsConnector<T>
where
    T: Connect,
    T::Transport: Debug + Sync + Send,
{
    /// Creates a connector from an inner connector and an OpenSSL connector builder.
    ///
    /// This enables client-side session caching on `ssl` and installs new/remove
    /// session callbacks that keep this connector's session cache in sync with
    /// OpenSSL. Any session cache mode or session callbacks already set on `ssl`
    /// are overwritten.
    ///
    /// The inner connector must be willing to open connections for `https`
    /// destinations. For example, hyper's `HttpConnector` needs
    /// `enforce_http(false)`.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error.
    pub fn with_connector(
        http: T,
        mut ssl: SslConnectorBuilder,
    ) -> Result<HttpsConnector<T>, ErrorStack> {
        let cache = Arc::new(Mutex::new(SessionCache::new()));

        ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT);

        // Record each newly negotiated session under the key that `setup_ssl`
        // attached to the connection. Sessions on connections without a key are ignored.
        ssl.set_new_session_callback({
            let cache = cache.clone();
            move |ssl, session| {
                if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
                    cache.lock().insert(key.clone(), session);
                }
            }
        });

        // Drop sessions from our cache when OpenSSL removes them from its own.
        ssl.set_remove_session_callback({
            let cache = cache.clone();
            move |_, session| cache.lock().remove(session)
        });

        Ok(HttpsConnector {
            http,
            inner: Inner {
                ssl: ssl.build(),
                cache,
                callback: None,
            },
        })
    }

    /// Registers a callback that customizes each `https` connection's TLS configuration.
    ///
    /// The callback receives the connection's configuration and destination. It
    /// runs before any cached session is applied. If it returns an error, that
    /// connection attempt fails with the error. Calling this again replaces any
    /// previously registered callback.
    pub fn set_callback<F>(&mut self, callback: F)
    where
        F: Fn(&mut ConnectConfiguration, &Destination) -> Result<(), ErrorStack>
            + 'static
            + Sync
            + Send,
    {
        self.inner.callback = Some(Arc::new(callback));
    }
}
impl<T> Connect for HttpsConnector<T>
where
    T: Connect,
    T::Transport: Debug + Sync,
    T::Future: 'static,
{
    type Transport = MaybeHttpsStream<T::Transport>;
    type Error = Box<dyn Error + Sync + Send>;
    type Future =
        Pin<Box<dyn Future<Output = Result<(Self::Transport, Connected), Self::Error>> + Send>>;

    /// Connects to `destination` and performs a TLS handshake if its scheme is `https`.
    ///
    /// Non-`https` destinations yield `MaybeHttpsStream::Http` together with the
    /// inner connector's `Connected` metadata, unchanged. `https` destinations yield
    /// `MaybeHttpsStream::Https`. When compiled with `ossl102` and ALPN selects
    /// `h2`, the metadata is marked as HTTP/2.
    ///
    /// # Errors
    ///
    /// The future fails if any of these fail:
    /// - the inner connector;
    /// - TLS configuration (including the user callback);
    /// - the TLS handshake.
    fn connect(&self, destination: Destination) -> Self::Future {
        // Capture what the TLS path needs before `destination` moves into the
        // inner connector.
        let tls_setup = if destination.scheme() == "https" {
            Some((self.inner.clone(), destination.clone()))
        } else {
            None
        };
        let connect = self.http.connect(destination);

        // `move` makes explicit that the future owns its captures, as the
        // boxed `'static` future type requires.
        let f = async move {
            // `connected` is only reassigned when ALPN support is compiled in.
            #[cfg_attr(not(ossl102), allow(unused_mut))]
            let (conn, mut connected) = connect.await.map_err(Into::into)?;

            let (inner, destination) = match tls_setup {
                Some(setup) => setup,
                None => return Ok((MaybeHttpsStream::Http(conn), connected)),
            };

            let config = inner.setup_ssl(&destination)?;
            let stream = tokio_openssl::connect(config, destination.host(), conn).await?;

            #[cfg(ossl102)]
            {
                if let Some(b"h2") = stream.ssl().selected_alpn_protocol() {
                    connected = connected.negotiated_h2();
                }
            }

            Ok((MaybeHttpsStream::Https(stream), connected))
        };

        Box::pin(f)
    }
}
/// A connection that is either a plain transport or a TLS session over one.
///
/// Produced by [`HttpsConnector`]. Reads and writes go to whichever variant is active.
pub enum MaybeHttpsStream<T> {
    /// A plain connection, used for destinations whose scheme is not `https`.
    Http(T),
    /// A TLS-wrapped connection, used for `https` destinations.
    Https(SslStream<T>),
}
/// Delegates every read operation to the active stream, plain or TLS.
impl<T> AsyncRead for MaybeHttpsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
        match self {
            Self::Http(plain) => plain.prepare_uninitialized_buffer(buf),
            Self::Https(tls) => tls.prepare_uninitialized_buffer(buf),
        }
    }

    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Both variants are `Unpin`, so the pin can be unwrapped and re-pinned per variant.
        match self.get_mut() {
            Self::Http(plain) => Pin::new(plain).poll_read(ctx, buf),
            Self::Https(tls) => Pin::new(tls).poll_read(ctx, buf),
        }
    }

    fn poll_read_buf<B>(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut B,
    ) -> Poll<io::Result<usize>>
    where
        B: BufMut,
    {
        match self.get_mut() {
            Self::Http(plain) => Pin::new(plain).poll_read_buf(ctx, buf),
            Self::Https(tls) => Pin::new(tls).poll_read_buf(ctx, buf),
        }
    }
}
/// Delegates every write operation to the active stream, plain or TLS.
impl<T> AsyncWrite for MaybeHttpsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Both variants are `Unpin`, so the pin can be unwrapped and re-pinned per variant.
        match self.get_mut() {
            Self::Http(plain) => Pin::new(plain).poll_write(ctx, buf),
            Self::Https(tls) => Pin::new(tls).poll_write(ctx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Http(plain) => Pin::new(plain).poll_flush(ctx),
            Self::Https(tls) => Pin::new(tls).poll_flush(ctx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Http(plain) => Pin::new(plain).poll_shutdown(ctx),
            Self::Https(tls) => Pin::new(tls).poll_shutdown(ctx),
        }
    }

    fn poll_write_buf<B>(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut B,
    ) -> Poll<io::Result<usize>>
    where
        B: Buf,
    {
        match self.get_mut() {
            Self::Http(plain) => Pin::new(plain).poll_write_buf(ctx, buf),
            Self::Https(tls) => Pin::new(tls).poll_write_buf(ctx, buf),
        }
    }
}