#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use std::any::Any;
use std::{error, fmt, io, result};
#[cfg_attr(target_vendor = "apple", path = "imp/security_framework.rs")]
#[cfg_attr(target_os = "windows", path = "imp/schannel.rs")]
#[cfg_attr(
not(any(target_vendor = "apple", target_os = "windows")),
path = "imp/openssl.rs"
)]
mod imp;
#[cfg(test)]
mod test;
pub type Result<T> = result::Result<T, Error>;
pub struct Error(imp::Error);
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
error::Error::source(&self.0)
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, fmt)
}
}
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl From<imp::Error> for Error {
fn from(err: imp::Error) -> Error {
Error(err)
}
}
#[derive(Clone)]
pub struct Identity(imp::Identity);
impl Identity {
pub fn from_pkcs12(der: &[u8], password: &str) -> Result<Identity> {
let identity = imp::Identity::from_pkcs12(der, password)?;
Ok(Identity(identity))
}
#[cfg_attr(all(target_vendor = "apple", not(target_os = "macos")), deprecated(note = "Not available on iOS"))]
pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result<Identity> {
let identity = imp::Identity::from_pkcs8(pem, key)?;
Ok(Identity(identity))
}
}
#[derive(Clone)]
pub struct Certificate(imp::Certificate);
impl Certificate {
pub fn from_der(der: &[u8]) -> Result<Certificate> {
let cert = imp::Certificate::from_der(der)?;
Ok(Certificate(cert))
}
#[cfg_attr(all(target_vendor = "apple", not(target_os = "macos")), deprecated(note = "Not available on iOS"))]
pub fn from_pem(pem: &[u8]) -> Result<Certificate> {
let cert = imp::Certificate::from_pem(pem)?;
Ok(Certificate(cert))
}
#[cfg_attr(all(target_vendor = "apple", not(target_os = "macos")), deprecated(note = "Not available on iOS"))]
pub fn stack_from_pem(buf: &[u8]) -> Result<Vec<Certificate>> {
let certs = imp::Certificate::stack_from_pem(buf)?;
Ok(certs.into_iter().map(Certificate).collect())
}
pub fn to_der(&self) -> Result<Vec<u8>> {
let der = self.0.to_der()?;
Ok(der)
}
}
pub struct MidHandshakeTlsStream<S>(imp::MidHandshakeTlsStream<S>);
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> MidHandshakeTlsStream<S> {
#[must_use]
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
#[must_use]
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
{
pub fn handshake(self) -> result::Result<TlsStream<S>, HandshakeError<S>> {
match self.0.handshake() {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
/// The error returned when a TLS handshake does not complete.
///
/// It is produced by these three operations:
///
/// - [`TlsConnector::connect`]
/// - [`TlsAcceptor::accept`]
/// - [`MidHandshakeTlsStream::handshake`]
#[derive(Debug)]
pub enum HandshakeError<S> {
    /// The handshake failed with a backend error.
    Failure(Error),

    /// The handshake was interrupted before completing.
    ///
    /// Presumably this happens when a non-blocking stream is not ready. Resume the
    /// handshake with [`MidHandshakeTlsStream::handshake`] on the contained value.
    WouldBlock(MidHandshakeTlsStream<S>),
}

impl<S> error::Error for HandshakeError<S>
where
    S: Any + fmt::Debug,
{
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            HandshakeError::Failure(e) => Some(e),
            HandshakeError::WouldBlock(_) => None,
        }
    }
}

// Formatting never inspects the wrapped stream, so no bounds on `S` are needed.
// The previous `S: Any + fmt::Debug` bounds prevented displaying errors over
// non-`'static` streams (e.g. `&mut TcpStream`) or streams without `Debug`.
impl<S> fmt::Display for HandshakeError<S> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HandshakeError::Failure(e) => fmt::Display::fmt(e, fmt),
            HandshakeError::WouldBlock(_) => fmt.write_str("the handshake process was interrupted"),
        }
    }
}

impl<S> From<imp::HandshakeError<S>> for HandshakeError<S> {
    /// Wraps the backend's handshake outcome in this crate's public types.
    fn from(e: imp::HandshakeError<S>) -> HandshakeError<S> {
        match e {
            imp::HandshakeError::Failure(e) => Self::Failure(Error(e)),
            imp::HandshakeError::WouldBlock(s) => Self::WouldBlock(MidHandshakeTlsStream(s)),
        }
    }
}
/// An SSL/TLS protocol version.
///
/// It is used to bound the negotiated version through these methods:
///
/// - [`TlsConnectorBuilder::min_protocol_version`] and
///   [`TlsConnectorBuilder::max_protocol_version`]
/// - [`TlsAcceptorBuilder::min_protocol_version`] and
///   [`TlsAcceptorBuilder::max_protocol_version`]
///
/// Variants are listed from oldest to newest. The enum is `#[non_exhaustive]`, so
/// `match` expressions outside this crate need a wildcard arm.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Protocol {
    /// SSL 3.0.
    Sslv3,
    /// TLS 1.0.
    Tlsv10,
    /// TLS 1.1.
    Tlsv11,
    /// TLS 1.2.
    Tlsv12,
    /// TLS 1.3.
    Tlsv13,
}
/// A builder for [`TlsConnector`]s.
///
/// Obtain one from [`TlsConnector::builder`]. Each setter mutates the builder in place
/// and returns it, so calls can be chained. [`TlsConnectorBuilder::build`] borrows the
/// builder and may therefore be called more than once.
// Each `bool` below is an independent on/off switch read when the connector is built.
// Grouping them into an enum or bitflags type would not make the configuration
// clearer, so this pedantic lint is silenced deliberately.
#[allow(clippy::struct_excessive_bools)]
pub struct TlsConnectorBuilder {
    identity: Option<Identity>,
    min_protocol: Option<Protocol>,
    max_protocol: Option<Protocol>,
    root_certificates: Vec<Certificate>,
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    use_sni: bool,
    disable_built_in_roots: bool,
    #[cfg(feature = "alpn")]
    alpn: Vec<Box<str>>,
}

impl TlsConnectorBuilder {
    /// Sets the client identity to present to the server. The default is no identity.
    pub fn identity(&mut self, identity: Identity) -> &mut Self {
        self.identity = Some(identity);
        self
    }

    /// Sets the minimum protocol version to accept.
    ///
    /// Passing `None` clears the minimum configured here. [`TlsConnector::builder`]
    /// starts with `Some(Protocol::Tlsv12)`.
    pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut Self {
        self.min_protocol = protocol;
        self
    }

    /// Sets the maximum protocol version to accept.
    ///
    /// Passing `None` clears the maximum configured here. This is also the default.
    pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut Self {
        self.max_protocol = protocol;
        self
    }

    /// Adds a certificate to the extra trust roots used to verify the server.
    ///
    /// Certificates accumulate across calls.
    pub fn add_root_certificate(&mut self, cert: Certificate) -> &mut Self {
        self.root_certificates.push(cert);
        self
    }

    /// Controls whether the backend's built-in trust roots are disabled. The default is `false`.
    ///
    /// When `true`, certificates added with
    /// [`TlsConnectorBuilder::add_root_certificate`] are presumably the only trust
    /// anchors left.
    pub fn disable_built_in_roots(&mut self, disable: bool) -> &mut Self {
        self.disable_built_in_roots = disable;
        self
    }

    /// Sets the ALPN protocol names to offer during the handshake.
    ///
    /// This replaces any previously requested list.
    #[cfg(feature = "alpn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "alpn")))]
    pub fn request_alpns(&mut self, protocols: &[&str]) -> &mut Self {
        self.alpn = protocols.iter().copied().map(Box::from).collect();
        self
    }

    /// Controls whether invalid server certificates are accepted. The default is `false`.
    ///
    /// # Warning
    ///
    /// As the `danger_` prefix signals, enabling this relaxes certificate
    /// validation. Any certificate, including one forged by an attacker, may then be
    /// trusted. Do not enable it outside of testing.
    pub fn danger_accept_invalid_certs(&mut self, accept_invalid_certs: bool) -> &mut Self {
        self.accept_invalid_certs = accept_invalid_certs;
        self
    }

    /// Controls whether Server Name Indication is used. The default is `true`.
    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
        self.use_sni = use_sni;
        self
    }

    /// Controls whether a certificate whose hostname does not match is accepted. The
    /// default is `false`.
    ///
    /// # Warning
    ///
    /// As the `danger_` prefix signals, enabling this lets a valid certificate for
    /// one host be accepted for another. Do not enable it outside of testing.
    pub fn danger_accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Self {
        self.accept_invalid_hostnames = accept_invalid_hostnames;
        self
    }

    /// Creates a [`TlsConnector`] from the current configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the platform backend rejects the configuration.
    pub fn build(&self) -> Result<TlsConnector> {
        let connector = imp::TlsConnector::new(self)?;
        Ok(TlsConnector(connector))
    }
}
#[derive(Clone, Debug)]
pub struct TlsConnector(imp::TlsConnector);
impl TlsConnector {
pub fn new() -> Result<TlsConnector> {
TlsConnector::builder().build()
}
#[must_use]
pub fn builder() -> TlsConnectorBuilder {
TlsConnectorBuilder {
identity: None,
min_protocol: Some(Protocol::Tlsv12),
max_protocol: None,
root_certificates: vec![],
use_sni: true,
accept_invalid_certs: false,
accept_invalid_hostnames: false,
disable_built_in_roots: false,
#[cfg(feature = "alpn")]
alpn: vec![],
}
}
pub fn connect<S>(
&self,
domain: &str,
stream: S,
) -> result::Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let s = self.0.connect(domain, stream)?;
Ok(TlsStream(s))
}
}
pub struct TlsAcceptorBuilder {
identity: Identity,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
#[cfg(feature = "alpn-accept")]
accept_alpn: Vec<Box<str>>,
}
impl TlsAcceptorBuilder {
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut TlsAcceptorBuilder {
self.min_protocol = protocol;
self
}
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut TlsAcceptorBuilder {
self.max_protocol = protocol;
self
}
#[cfg(feature = "alpn-accept")]
#[cfg_attr(docsrs, doc(cfg(feature = "alpn-accept")))]
pub fn accept_alpn(&mut self, protocols: &[impl AsRef<str>]) -> &mut TlsAcceptorBuilder {
self.accept_alpn = protocols.iter().map(|s| Box::from(s.as_ref())).collect();
self
}
pub fn build(&self) -> Result<TlsAcceptor> {
let acceptor = imp::TlsAcceptor::new(self)?;
Ok(TlsAcceptor(acceptor))
}
}
/// Accepts server-side TLS connections over arbitrary `Read + Write` streams.
#[derive(Clone)]
pub struct TlsAcceptor(imp::TlsAcceptor);

impl TlsAcceptor {
    /// Returns an acceptor for `identity` with the default configuration.
    ///
    /// This is shorthand for `TlsAcceptor::builder(identity).build()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the platform backend fails to create the acceptor.
    pub fn new(identity: Identity) -> Result<TlsAcceptor> {
        Self::builder(identity).build()
    }

    /// Returns a [`TlsAcceptorBuilder`] that serves `identity`.
    ///
    /// Defaults: minimum protocol TLS 1.2 and no maximum.
    #[must_use]
    pub fn builder(identity: Identity) -> TlsAcceptorBuilder {
        TlsAcceptorBuilder {
            min_protocol: Some(Protocol::Tlsv12),
            max_protocol: None,
            identity,
            #[cfg(feature = "alpn-accept")]
            accept_alpn: Vec::new(),
        }
    }

    /// Starts a server handshake over `stream`.
    ///
    /// # Errors
    ///
    /// There are two error cases:
    ///
    /// - [`HandshakeError::Failure`] if the handshake fails.
    /// - [`HandshakeError::WouldBlock`] if it is interrupted and must be resumed.
    pub fn accept<S>(&self, stream: S) -> result::Result<TlsStream<S>, HandshakeError<S>>
    where
        S: io::Read + io::Write,
    {
        let established = self.0.accept(stream)?;
        Ok(TlsStream(established))
    }
}
pub struct TlsStream<S>(imp::TlsStream<S>);
impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> TlsStream<S> {
#[must_use]
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
#[must_use]
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S: io::Read + io::Write> TlsStream<S> {
pub fn buffered_read_size(&self) -> Result<usize> {
Ok(self.0.buffered_read_size()?)
}
pub fn peer_certificate(&self) -> Result<Option<Certificate>> {
Ok(self.0.peer_certificate()?.map(Certificate))
}
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>> {
Ok(self.0.tls_server_end_point()?)
}
#[cfg(feature = "alpn")]
#[cfg_attr(docsrs, doc(cfg(feature = "alpn")))]
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>> {
Ok(self.0.negotiated_alpn()?)
}
pub fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown()?;
Ok(())
}
}
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
fn _check_kinds() {
use std::net::TcpStream;
fn is_sync<T: Sync>() {}
fn is_send<T: Send>() {}
is_sync::<Error>();
is_send::<Error>();
is_sync::<TlsConnectorBuilder>();
is_send::<TlsConnectorBuilder>();
is_sync::<TlsConnector>();
is_send::<TlsConnector>();
is_sync::<TlsAcceptorBuilder>();
is_send::<TlsAcceptorBuilder>();
is_sync::<TlsAcceptor>();
is_send::<TlsAcceptor>();
is_sync::<TlsStream<TcpStream>>();
is_send::<TlsStream<TcpStream>>();
is_sync::<MidHandshakeTlsStream<TcpStream>>();
is_send::<MidHandshakeTlsStream<TcpStream>>();
}