#[cfg(feature = "openssl-probe")]
use std::sync::Once;
use openssl::error::ErrorStack;
use openssl::pkey::PKey;
use openssl::ssl::{
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
SslVerifyMode,
};
use openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
use std::error;
use std::fmt;
use std::io;
use super::{Protocol, TlsConnectorBuilder};
use openssl::pkey::Private;
/// Applies the requested protocol version bounds to `ctx` through its
/// minimum/maximum protocol version setters.
///
/// A bound of `None` is forwarded to the corresponding setter as `None`.
/// An `ErrorStack` reported by either setter is returned to the caller.
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use openssl::ssl::SslVersion;

    /// Translates this crate's `Protocol` into the matching `SslVersion`.
    fn to_ssl_version(protocol: Protocol) -> SslVersion {
        match protocol {
            Protocol::Tlsv12 => SslVersion::TLS1_2,
            Protocol::Tlsv11 => SslVersion::TLS1_1,
            Protocol::Tlsv10 => SslVersion::TLS1,
            Protocol::Sslv3 => SslVersion::SSL3,
            // This variant is treated as impossible to reach.
            Protocol::__NonExhaustive => unreachable!(),
        }
    }

    // The minimum is converted and applied before the maximum is looked at.
    let lower_bound = min.map(to_ssl_version);
    ctx.set_min_proto_version(lower_bound)?;
    let upper_bound = max.map(to_ssl_version);
    ctx.set_max_proto_version(upper_bound)
}
/// Applies the requested protocol version bounds to `ctx` by toggling the
/// per-protocol `NO_*` option flags.
///
/// All five protocol flags are cleared first; then every protocol below
/// `min` and every protocol above `max` is disabled. A bound of `None`
/// disables nothing on that side. This variant always returns `Ok(())`.
#[cfg(not(have_min_max_version))]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use openssl::ssl::SslOptions;

    // Reset every protocol-disabling flag before applying the new bounds.
    let all_protocol_flags = SslOptions::NO_SSLV2
        | SslOptions::NO_SSLV3
        | SslOptions::NO_TLSV1
        | SslOptions::NO_TLSV1_1
        | SslOptions::NO_TLSV1_2;
    ctx.clear_options(all_protocol_flags);

    // Flags for the protocols that sit below the requested minimum.
    let below_min = match min {
        Some(Protocol::Tlsv12) => {
            SslOptions::NO_SSLV2
                | SslOptions::NO_SSLV3
                | SslOptions::NO_TLSV1
                | SslOptions::NO_TLSV1_1
        }
        Some(Protocol::Tlsv11) => {
            SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
        }
        Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
        Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
        None => SslOptions::empty(),
        Some(Protocol::__NonExhaustive) => unreachable!(),
    };

    // Flags for the protocols that sit above the requested maximum.
    let above_max = match max {
        Some(Protocol::Sslv3) => {
            SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
        }
        Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
        Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
        Some(Protocol::Tlsv12) | None => SslOptions::empty(),
        Some(Protocol::__NonExhaustive) => unreachable!(),
    };

    ctx.set_options(below_min | above_max);
    Ok(())
}
/// Runs `openssl_probe::init_ssl_cert_env_vars` at most once per process.
///
/// Repeated calls after the first are no-ops thanks to the function-local
/// `Once` guard.
#[cfg(feature = "openssl-probe")]
fn init_trust() {
    static PROBE_GUARD: Once = Once::new();
    PROBE_GUARD.call_once(|| {
        openssl_probe::init_ssl_cert_env_vars();
    });
}
/// Adds every PEM certificate readable from `/system/etc/security/cacerts`
/// to the certificate store of `connector`.
///
/// This is best-effort: a missing or unreadable directory, unreadable
/// entries, and files that do not parse as PEM are silently skipped, and a
/// failure to add a certificate is only logged. The function always
/// returns `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let entries = match fs::read_dir("/system/etc/security/cacerts") {
        Ok(entries) => entries,
        // Nothing to load when the directory cannot be listed.
        Err(_) => return Ok(()),
    };

    for entry in entries {
        // Each stage that fails collapses to `None` and skips the entry.
        let parsed = entry
            .ok()
            .and_then(|dir_entry| fs::read(dir_entry.path()).ok())
            .and_then(|bytes| X509::from_pem(&bytes).ok());

        let cert = match parsed {
            Some(cert) => cert,
            None => continue,
        };

        if let Err(err) = connector.cert_store_mut().add_cert(cert) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
/// Error type used by the connector and handshake code in this file.
#[derive(Debug)]
pub enum Error {
/// A plain OpenSSL `ErrorStack`; this is what `From<ErrorStack>` produces.
Normal(ErrorStack),
/// An `ssl::Error` paired with an `X509VerifyResult`. The handshake error
/// conversion in this file fills the second field from the failed
/// stream's `ssl().verify_result()`.
Ssl(ssl::Error, X509VerifyResult),
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match *self {
Error::Normal(ref e) => error::Error::source(e),
Error::Ssl(ref e, _) => error::Error::source(e),
}
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
}
}
}
impl From<ErrorStack> for Error {
fn from(err: ErrorStack) -> Error {
Error::Normal(err)
}
}
/// Key material a connector can present: a private key, its certificate,
/// and additional chain certificates.
#[derive(Clone)]
pub struct Identity {
/// Private key; `TlsConnector::new` passes it to `set_private_key`.
pkey: PKey<Private>,
/// Certificate; `TlsConnector::new` passes it to `set_certificate`.
cert: X509,
/// Extra certificates; `TlsConnector::new` walks this vector in reverse
/// and hands each entry to `add_extra_chain_cert`.
chain: Vec<X509>,
}
/// Newtype around an OpenSSL `X509` certificate.
#[derive(Clone)]
pub struct Certificate(X509);
/// Newtype around OpenSSL's `MidHandshakeSslStream<S>`; it is the payload of
/// `HandshakeError::WouldBlock`.
pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
/// Outcome of a handshake that did not complete successfully.
pub enum HandshakeError<S> {
/// The handshake (or its setup) failed with the contained `Error`.
Failure(Error),
/// Built from `ssl::HandshakeError::WouldBlock`; carries the
/// mid-handshake stream.
WouldBlock(MidHandshakeTlsStream<S>),
}
impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
match e {
ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
ssl::HandshakeError::Failure(e) => {
let v = e.ssl().verify_result();
HandshakeError::Failure(Error::Ssl(e.into_error(), v))
}
ssl::HandshakeError::WouldBlock(s) => {
HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
}
}
}
}
impl<S> From<ErrorStack> for HandshakeError<S> {
fn from(e: ErrorStack) -> HandshakeError<S> {
HandshakeError::Failure(e.into())
}
}
/// A configured `SslConnector` plus the per-connection flags that `connect`
/// applies.
#[derive(Clone)]
pub struct TlsConnector {
/// The built OpenSSL connector.
connector: SslConnector,
/// Passed to `use_server_name_indication` in `connect`.
use_sni: bool,
/// When true, `connect` calls `verify_hostname(false)`.
accept_invalid_hostnames: bool,
/// When true, `connect` sets the verify mode to `SslVerifyMode::NONE`.
accept_invalid_certs: bool,
}
impl TlsConnector {
pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
#[cfg(feature = "openssl-probe")]
init_trust();
let mut connector = SslConnector::builder(SslMethod::tls())?;
if let Some(ref identity) = builder.identity {
connector.set_certificate(&identity.0.cert)?;
connector.set_private_key(&identity.0.pkey)?;
for cert in identity.0.chain.iter().rev() {
connector.add_extra_chain_cert(cert.to_owned())?;
}
}
supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
if builder.disable_built_in_roots {
connector.set_cert_store(X509StoreBuilder::new()?.build());
}
for cert in &builder.root_certificates {
if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
debug!("add_cert error: {:?}", err);
}
}
#[cfg(target_os = "android")]
load_android_root_certs(&mut connector)?;
Ok(TlsConnector {
connector: connector.build(),
use_sni: builder.use_sni,
accept_invalid_hostnames: builder.accept_invalid_hostnames,
accept_invalid_certs: builder.accept_invalid_certs,
})
}
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut ssl = self
.connector
.configure()?
.use_server_name_indication(self.use_sni)
.verify_hostname(!self.accept_invalid_hostnames);
if self.accept_invalid_certs {
ssl.set_verify(SslVerifyMode::NONE);
}
let s = ssl.connect(domain, stream)?;
Ok(TlsStream(s))
}
}
impl fmt::Debug for TlsConnector {
    /// Prints the three boolean flags; the inner `connector` field is not
    /// included in the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("TlsConnector");
        out.field("use_sni", &self.use_sni);
        out.field("accept_invalid_hostnames", &self.accept_invalid_hostnames);
        out.field("accept_invalid_certs", &self.accept_invalid_certs);
        out.finish()
    }
}
/// Newtype around an OpenSSL `SslAcceptor`.
// NOTE(review): no methods for this type are visible in this part of the
// file; its construction and use are presumably defined elsewhere.
#[derive(Clone)]
pub struct TlsAcceptor(SslAcceptor);
/// Newtype around OpenSSL's `SslStream<S>`; returned by
/// `TlsConnector::connect` on success.
pub struct TlsStream<S>(ssl::SslStream<S>);
impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> TlsStream<S> {
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}