use std::fs::File;
use std::io::{self, BufReader, Read, Write};
use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::{ring as crypto, CryptoProvider, WebPkiSupportedAlgorithms};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, SignatureScheme};
use zeroize::Zeroize;
use super::client_auth;
use super::common::ProtocolVersion;
use crate::error::{Error, Result};
pub use rustls::RootCertStore;
/// Well-known locations of a system-wide PEM CA bundle on common Unix systems.
///
/// `load_system_roots` probes these in order and uses the first file that
/// exists, so the order of this list is significant.
const SYSTEM_CA_PATHS: &[&str] = &[
"/etc/ssl/certs/ca-certificates.crt",
"/etc/pki/tls/certs/ca-bundle.crt",
"/etc/ssl/cert.pem",
"/etc/ssl/ca-bundle.pem",
"/etc/ca-certificates/extracted/tls-ca-bundle.pem",
];
/// Options controlling a TLS client handshake on the rustls backend
/// (consumed by `connect_over_tls`).
///
/// Some options are rejected (`crl_pem`, `cipher_suites`) or ignored
/// (`client_key_pass`) by this backend; see the per-field notes. Secret
/// fields are zeroized when the value is dropped.
#[derive(Clone)]
pub struct TlsOpts {
/// ALPN protocol identifiers to offer; moved into the rustls `ClientConfig`.
pub alpn: Vec<Vec<u8>>,
/// Whether to verify the server certificate. When `false`, a verifier that
/// accepts any certificate is installed instead.
pub verify: bool,
/// Trust anchors for verification; `None` means the system CA bundle found
/// by `load_system_roots` is used.
pub roots: Option<RootCertStore>,
/// Lowest acceptable TLS version (`None` = no lower bound). rustls only
/// offers TLS 1.2 and 1.3; any value other than TLS 1.3 is treated as 1.2.
pub min_version: Option<ProtocolVersion>,
/// Highest acceptable TLS version (`None` = no upper bound); same mapping as
/// `min_version`.
pub max_version: Option<ProtocolVersion>,
/// Client certificate bytes: PEM (one or more CERTIFICATE blocks) unless
/// `cert_is_der`, in which case a single DER certificate.
pub client_cert: Option<Vec<u8>>,
/// Client private key bytes: PEM unless `key_is_der`. If `None` with a PEM
/// certificate, the key is searched for inside `client_cert`. Zeroized on drop.
pub client_key: Option<Vec<u8>>,
/// Passphrase for the client key. Currently not used by this backend
/// (encrypted keys are unsupported). Zeroized on drop.
pub client_key_pass: Option<String>,
/// Interpret `client_cert` as DER instead of PEM.
pub cert_is_der: bool,
/// Interpret `client_key` as DER instead of PEM.
pub key_is_der: bool,
/// Public-key pins checked after the handshake; when non-empty, the server's
/// leaf certificate must match one via `client_auth::spki_pin_matches`.
/// Per the name these are SHA-256 digests of the SubjectPublicKeyInfo.
pub pinned_spki_sha256: Vec<[u8; 32]>,
/// CRL data in PEM. Not supported by this backend: setting it makes
/// `connect_over_tls` return an error.
pub crl_pem: Option<Vec<u8>>,
/// Explicit cipher-suite list (presumably IANA codes — confirm). Not supported
/// by this backend: a non-empty list makes `connect_over_tls` return an error.
pub cipher_suites: Vec<u16>,
}
impl TlsOpts {
pub fn verifying() -> Self {
TlsOpts {
alpn: Vec::new(),
verify: true,
roots: None,
min_version: None,
max_version: None,
client_cert: None,
client_key: None,
client_key_pass: None,
cert_is_der: false,
key_is_der: false,
pinned_spki_sha256: Vec::new(),
crl_pem: None,
cipher_suites: Vec::new(),
}
}
}
impl Default for TlsOpts {
fn default() -> Self {
TlsOpts::verifying()
}
}
impl Drop for TlsOpts {
    /// Scrubs the client private key and its passphrase before the memory
    /// backing them is released.
    fn drop(&mut self) {
        let Self {
            client_key,
            client_key_pass,
            ..
        } = self;
        client_key.zeroize();
        client_key_pass.zeroize();
    }
}
pub fn load_system_roots() -> Result<RootCertStore> {
for path in SYSTEM_CA_PATHS {
let file = match File::open(path) {
Ok(f) => f,
Err(e) if e.kind() == io::ErrorKind::NotFound => continue,
Err(e) => return Err(Error::Io(e)),
};
return parse_roots(BufReader::new(file), path);
}
Err(Error::BadResponse(
"no system CA bundle found; tried common Unix paths".into(),
))
}
pub fn load_roots_from_file(path: &str) -> Result<RootCertStore> {
let file = File::open(path).map_err(Error::Io)?;
parse_roots(BufReader::new(file), path)
}
pub fn load_roots_from_dir(base: Option<RootCertStore>, dir: &str) -> Result<RootCertStore> {
let mut roots = match base {
Some(r) => r,
None => load_system_roots()?,
};
let mut added = 0usize;
for entry in std::fs::read_dir(dir).map_err(Error::Io)? {
let entry = entry.map_err(Error::Io)?;
let path = entry.path();
if !path.is_file() {
continue;
}
let Ok(file) = File::open(&path) else {
continue;
};
let mut reader = BufReader::new(file);
let Ok(certs) =
rustls_pemfile::certs(&mut reader).collect::<std::result::Result<Vec<_>, _>>()
else {
continue; };
let (n, _ignored) = roots.add_parsable_certificates(certs);
added += n;
}
if added == 0 {
return Err(Error::BadResponse(format!(
"--capath {dir}: no usable CA certificates found"
)));
}
Ok(roots)
}
fn parse_roots<R: io::BufRead>(mut reader: R, path: &str) -> Result<RootCertStore> {
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::BadResponse(format!("PEM parse error in {path}: {e}")))?;
let mut roots = RootCertStore::empty();
let (added, _ignored) = roots.add_parsable_certificates(certs);
if added == 0 {
return Err(Error::BadResponse(format!(
"no usable CA certificates parsed from {path}"
)));
}
Ok(roots)
}
/// Client-side TLS session layered over a generic `Read + Write` transport.
///
/// Application data is exchanged through the `Read`/`Write` impls; handshake
/// results are captured once the handshake completes.
pub struct TlsStream<S: Read + Write> {
/// rustls client connection state.
conn: ClientConnection,
/// Underlying transport that carries the TLS records.
sock: S,
/// Negotiated protocol version, captured after the handshake.
version: Option<ProtocolVersion>,
/// ALPN protocol selected by the server, captured after the handshake.
alpn: Option<Vec<u8>>,
/// Server certificates (DER) as reported by rustls, captured after the handshake.
peer_certs_der: Vec<Vec<u8>>,
/// Set when the transport hit EOF without a TLS close_notify from the peer.
dirty_eof: bool,
}
/// Performs a verifying TLS handshake over `transport` with default options.
///
/// Shorthand for `connect_over_tls(transport, sni, TlsOpts::verifying())`.
pub fn connect_over<S: Read + Write>(transport: S, sni: &str) -> Result<TlsStream<S>> {
    let opts = TlsOpts::verifying();
    connect_over_tls(transport, sni, opts)
}
/// Performs a verifying TLS handshake over `transport`, offering the given
/// ALPN protocol identifiers (in order).
pub fn connect_over_with_alpn<S: Read + Write>(
    transport: S,
    sni: &str,
    alpn: &[&[u8]],
) -> Result<TlsStream<S>> {
    let offered: Vec<Vec<u8>> = alpn.iter().copied().map(<[u8]>::to_vec).collect();
    let mut opts = TlsOpts::verifying();
    opts.alpn = offered;
    connect_over_tls(transport, sni, opts)
}
/// Runs a TLS client handshake over an already-connected `transport`.
///
/// `sni` is sent as the server name and passed to rustls as the name of the
/// server. When `opts.verify` is false, any server certificate is accepted (see
/// `NoVerify`) and no trust anchors are loaded. After the handshake, a
/// non-empty `opts.pinned_spki_sha256` requires the leaf certificate to match
/// one of the pins.
///
/// # Errors
///
/// - `Error::BadResponse` if `opts` requests something this backend cannot do
///   (`crl_pem`, `cipher_suites`), if the min/max version window contains no
///   version rustls offers, if `sni` is not a valid server name, on rustls
///   configuration/handshake failures, or on an SPKI pin mismatch.
/// - Errors from `load_system_roots` (only when verifying without
///   `opts.roots`) and from `build_identity`.
/// - `Error::Io` / `Error::UnexpectedEof` from the handshake I/O.
pub fn connect_over_tls<S: Read + Write>(
    transport: S,
    sni: &str,
    mut opts: TlsOpts,
) -> Result<TlsStream<S>> {
    if opts.crl_pem.is_some() {
        return Err(Error::BadResponse(
            "--crlfile is not supported by the rustls-tls backend; \
             build with the default purecrypto-tls backend for CRL checking"
                .into(),
        ));
    }
    if !opts.cipher_suites.is_empty() {
        return Err(Error::BadResponse(
            "--ciphers/--tls13-ciphers is not supported by the rustls-tls backend; \
             build with the default purecrypto-tls backend"
                .into(),
        ));
    }
    // Must be validated before building the config: rustls'
    // `builder_with_protocol_versions` unwraps internally and panics when no
    // cipher suite matches the list (always the case for an empty list).
    let versions = select_protocol_versions(opts.min_version, opts.max_version)?;
    let identity = match &opts.client_cert {
        Some(cert_bytes) => Some(build_identity(
            cert_bytes,
            opts.client_key.as_deref(),
            opts.client_key_pass.as_deref(),
            opts.cert_is_der,
            opts.key_is_der,
        )?),
        None => None,
    };
    let builder = ClientConfig::builder_with_protocol_versions(&versions);
    let builder = if opts.verify {
        let roots = match opts.roots.take() {
            Some(r) => r,
            None => load_system_roots()?,
        };
        builder.with_root_certificates(roots)
    } else {
        // Trust anchors are irrelevant without verification; not loading the
        // system bundle lets unverified connections work on hosts without one.
        builder
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerify::new()))
    };
    let mut config = match identity {
        Some((chain, key)) => builder
            .with_client_auth_cert(chain, key)
            .map_err(rustls_err)?,
        None => builder.with_no_client_auth(),
    };
    config.alpn_protocols = std::mem::take(&mut opts.alpn);
    let server_name: ServerName<'static> = ServerName::try_from(sni.to_string())
        .map_err(|e| Error::BadResponse(format!("invalid SNI {sni:?}: {e}")))?;
    let conn = ClientConnection::new(Arc::new(config), server_name).map_err(rustls_err)?;
    let mut stream = TlsStream {
        conn,
        sock: transport,
        version: None,
        alpn: None,
        peer_certs_der: Vec::new(),
        dirty_eof: false,
    };
    stream.run_handshake()?;
    stream.snapshot_post_handshake();
    if !opts.pinned_spki_sha256.is_empty() {
        let leaf = stream.peer_certificates().first().map(Vec::as_slice);
        match leaf {
            Some(der) if client_auth::spki_pin_matches(der, &opts.pinned_spki_sha256) => {}
            _ => {
                return Err(Error::BadResponse(
                    "pinned public key does not match server certificate".into(),
                ))
            }
        }
    }
    Ok(stream)
}

/// Picks the rustls protocol versions allowed by an optional `[min, max]` window.
///
/// With no bounds, every version rustls was built with is allowed. rustls only
/// speaks TLS 1.2 and 1.3, so any requested version other than TLS 1.3 is
/// treated as TLS 1.2.
///
/// # Errors
/// `Error::BadResponse` when the window excludes both TLS 1.2 and TLS 1.3
/// (e.g. min = 1.3, max = 1.2).
fn select_protocol_versions(
    min: Option<ProtocolVersion>,
    max: Option<ProtocolVersion>,
) -> Result<Vec<&'static rustls::SupportedProtocolVersion>> {
    if min.is_none() && max.is_none() {
        return Ok(rustls::ALL_VERSIONS.to_vec());
    }
    let rank = |v: ProtocolVersion| match v {
        ProtocolVersion::TLSv1_3 => 3u8,
        _ => 2u8,
    };
    let lo = min.map(rank).unwrap_or(0);
    let hi = max.map(rank).unwrap_or(u8::MAX);
    let versions: Vec<&'static rustls::SupportedProtocolVersion> =
        [&rustls::version::TLS12, &rustls::version::TLS13]
            .into_iter()
            .filter(|v| {
                let r = match v.version {
                    rustls::ProtocolVersion::TLSv1_3 => 3u8,
                    _ => 2u8,
                };
                (lo..=hi).contains(&r)
            })
            .collect();
    if versions.is_empty() {
        return Err(Error::BadResponse(
            "no TLS version satisfies the requested min/max range \
             (the rustls-tls backend supports TLS 1.2 and TLS 1.3)"
                .into(),
        ));
    }
    Ok(versions)
}
/// Builds the client certificate chain and private key for mutual TLS.
///
/// * `cert_bytes` – PEM (one or more CERTIFICATE blocks) or, when
///   `cert_is_der`, a single DER certificate.
/// * `key_bytes` – private key, PEM or (when `key_is_der`) DER. If `None` and
///   the key is PEM, the key is searched for inside `cert_bytes`.
/// * `pass` – key passphrase; unused, since encrypted keys are not supported
///   by this backend.
///
/// # Errors
/// `Error::BadResponse` on malformed PEM/DER, a PEM certificate input with no
/// certificates, a missing key, or a DER certificate without a separate key.
fn build_identity(
    cert_bytes: &[u8],
    key_bytes: Option<&[u8]>,
    pass: Option<&str>,
    cert_is_der: bool,
    key_is_der: bool,
) -> Result<(
    Vec<CertificateDer<'static>>,
    rustls::pki_types::PrivateKeyDer<'static>,
)> {
    use rustls::pki_types::PrivateKeyDer;
    // `&[u8]` is already `BufRead`, so the inputs are read in place instead of
    // through a `BufReader`: this avoids copying key material into an internal
    // heap buffer that is never zeroized.
    let chain: Vec<CertificateDer<'static>> = if cert_is_der {
        vec![CertificateDer::from(cert_bytes.to_vec())]
    } else {
        let mut reader: &[u8] = cert_bytes;
        let certs = rustls_pemfile::certs(&mut reader)
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(|e| Error::BadResponse(format!("client cert: PEM parse error: {e}")))?;
        if certs.is_empty() {
            return Err(Error::BadResponse(
                "client cert: file contains no CERTIFICATE blocks".into(),
            ));
        }
        certs
    };
    let key: PrivateKeyDer<'static> = if key_is_der {
        let kb = key_bytes.ok_or_else(|| {
            Error::BadResponse("client cert: a DER key needs --key (no embedded key)".into())
        })?;
        PrivateKeyDer::try_from(kb.to_vec())
            .map_err(|e| Error::BadResponse(format!("client key (DER): {e}")))?
    } else {
        let src = match key_bytes {
            Some(k) => k,
            // A DER certificate is a single binary certificate; it cannot
            // carry a PEM private key next to it.
            None if cert_is_der => {
                return Err(Error::BadResponse(
                    "client cert: a DER certificate has no embedded key; pass --key".into(),
                ))
            }
            None => cert_bytes,
        };
        let mut reader: &[u8] = src;
        match rustls_pemfile::private_key(&mut reader) {
            Ok(Some(k)) => k,
            Ok(None) => {
                return Err(Error::BadResponse(
                    "client key: no private key found in the PEM \
                     (encrypted keys are not supported by the rustls backend)"
                        .into(),
                ))
            }
            Err(e) => {
                return Err(Error::BadResponse(format!(
                    "client key: PEM parse error: {e}"
                )))
            }
        }
    };
    // Encrypted keys are unsupported here, so the passphrase has no use.
    let _ = pass;
    Ok((chain, key))
}
impl<S: Read + Write> TlsStream<S> {
pub fn negotiated_version(&self) -> Option<ProtocolVersion> {
self.version
}
pub fn alpn_selected(&self) -> Option<&[u8]> {
self.alpn.as_deref()
}
pub fn peer_certificates(&self) -> &[Vec<u8>] {
&self.peer_certs_der
}
pub fn was_truncated(&self) -> bool {
self.dirty_eof
}
fn run_handshake(&mut self) -> Result<()> {
while self.conn.is_handshaking() {
let mut did_something = false;
if self.conn.wants_write() {
self.conn.write_tls(&mut self.sock).map_err(Error::Io)?;
did_something = true;
}
if self.conn.is_handshaking() && self.conn.wants_read() {
let n = self.conn.read_tls(&mut self.sock).map_err(Error::Io)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
self.conn.process_new_packets().map_err(rustls_err)?;
did_something = true;
}
if !did_something {
self.conn.process_new_packets().map_err(rustls_err)?;
}
}
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock).map_err(Error::Io)?;
}
Ok(())
}
fn snapshot_post_handshake(&mut self) {
self.version = self.conn.protocol_version().map(map_rustls_version);
self.alpn = self.conn.alpn_protocol().map(|p| p.to_vec());
self.peer_certs_der = self
.conn
.peer_certificates()
.map(|certs| certs.iter().map(|c| c.to_vec()).collect())
.unwrap_or_default();
}
}
impl<S: Read + Write> Write for TlsStream<S> {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
let n = self.conn.writer().write(data)?;
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock)?;
}
Ok(n)
}
fn flush(&mut self) -> io::Result<()> {
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock)?;
}
self.sock.flush()
}
}
impl<S: Read + Write> Read for TlsStream<S> {
    /// Reads decrypted application data into `dst`, pulling TLS records from
    /// the transport as needed.
    ///
    /// Returns `Ok(0)` for an empty `dst`, at a clean TLS close, and at
    /// transport EOF. When the EOF arrives without a `close_notify`, the stream
    /// is additionally marked as truncated (`was_truncated`). TLS protocol
    /// errors surface as `io::Error` of kind `Other`.
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        if dst.is_empty() {
            return Ok(0);
        }
        loop {
            let buffered = self.conn.reader().read(dst);
            if let Some(outcome) = plaintext_outcome(buffered, &mut self.dirty_eof) {
                return outcome;
            }
            // Send anything rustls queued (e.g. post-handshake replies) and push
            // it out of a buffering transport before blocking on a read.
            if self.conn.wants_write() {
                while self.conn.wants_write() {
                    self.conn.write_tls(&mut self.sock)?;
                }
                self.sock.flush()?;
            }
            if !self.conn.wants_read() {
                return Ok(0);
            }
            if self.conn.read_tls(&mut self.sock)? == 0 {
                let buffered = self.conn.reader().read(dst);
                return plaintext_outcome(buffered, &mut self.dirty_eof).unwrap_or(Ok(0));
            }
            if let Err(e) = self.conn.process_new_packets() {
                // Best-effort alert to the peer; the protocol error is what the
                // caller sees, so failures here are ignored.
                let _ = self.conn.write_tls(&mut self.sock);
                let _ = self.sock.flush();
                return Err(io::Error::other(format!("tls: {e}")));
            }
        }
    }
}

/// Interprets one read from the rustls plaintext reader.
///
/// Returns `None` when no plaintext is available yet (`WouldBlock`);
/// otherwise returns the value `Read::read` should return. An `UnexpectedEof`
/// (transport closed without `close_notify`) becomes `Ok(0)` and sets
/// `*dirty_eof`.
fn plaintext_outcome(res: io::Result<usize>, dirty_eof: &mut bool) -> Option<io::Result<usize>> {
    match res {
        Ok(n) => Some(Ok(n)),
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
            *dirty_eof = true;
            Some(Ok(0))
        }
        Err(e) => Some(Err(e)),
    }
}
/// Server-certificate verifier that accepts any certificate chain.
///
/// Handshake signatures are still checked against the presented certificate
/// using `sig_algs`; only chain/name validation is skipped.
#[derive(Debug)]
struct NoVerify {
/// Signature verification algorithms taken from the ring crypto provider.
sig_algs: WebPkiSupportedAlgorithms,
}
impl NoVerify {
fn new() -> Self {
let provider: CryptoProvider = crypto::default_provider();
Self {
sig_algs: provider.signature_verification_algorithms,
}
}
}
impl ServerCertVerifier for NoVerify {
    /// Accepts every server certificate without looking at the chain, name,
    /// OCSP response or time.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _name: &ServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature with the configured algorithms.
    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        cert: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
        let algs = &self.sig_algs;
        rustls::crypto::verify_tls12_signature(msg, cert, signed, algs)
    }

    /// Checks a TLS 1.3 handshake signature with the configured algorithms.
    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        cert: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
        let algs = &self.sig_algs;
        rustls::crypto::verify_tls13_signature(msg, cert, signed, algs)
    }

    /// Signature schemes this verifier can check.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algs = &self.sig_algs;
        algs.supported_schemes()
    }
}
/// Converts a rustls protocol version into this crate's `ProtocolVersion`;
/// anything other than TLS 1.2/1.3 is carried as its raw wire value.
fn map_rustls_version(v: rustls::ProtocolVersion) -> ProtocolVersion {
    match v {
        rustls::ProtocolVersion::TLSv1_3 => ProtocolVersion::TLSv1_3,
        rustls::ProtocolVersion::TLSv1_2 => ProtocolVersion::TLSv1_2,
        unknown => ProtocolVersion::Other(unknown.into()),
    }
}
fn rustls_err(e: rustls::Error) -> Error {
Error::BadResponse(format!("tls: {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the `Default` impl and the explicit constructor must verify.
    #[test]
    fn default_enables_verification() {
        let defaulted = TlsOpts::default();
        assert!(defaulted.verify);
        let explicit = TlsOpts::verifying();
        assert!(explicit.verify);
    }
}