use std::fs::File;
use std::io::{self, BufReader, Read, Write};
use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::{ring as crypto, CryptoProvider, WebPkiSupportedAlgorithms};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, SignatureScheme};
use super::common::ProtocolVersion;
use crate::error::{Error, Result};
pub use rustls::RootCertStore;
/// Candidate locations of a PEM CA bundle, probed in order by `load_system_roots`.
///
/// The first file that exists wins. New entries are appended at the end so the
/// lookup order on systems already covered by earlier entries does not change.
const SYSTEM_CA_PATHS: &[&str] = &[
    "/etc/ssl/certs/ca-certificates.crt",
    "/etc/pki/tls/certs/ca-bundle.crt",
    "/etc/ssl/cert.pem",
    "/etc/ssl/ca-bundle.pem",
    "/etc/ca-certificates/extracted/tls-ca-bundle.pem",
    // Additional locations also probed by Go's crypto/x509 root discovery.
    "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // RHEL/CentOS 7+
    "/etc/pki/tls/cacert.pem",                           // OpenELEC
    "/usr/local/etc/ssl/cert.pem",                       // FreeBSD
    "/usr/local/share/certs/ca-root-nss.crt",            // FreeBSD/DragonFly
    "/etc/openssl/certs/ca-certificates.crt",            // NetBSD
    "/etc/certs/ca-certificates.crt",                    // Solaris 11.2+
];
/// Options accepted by `connect_over_tls`.
///
/// WARNING: the derived `Default` sets `verify: false`, which turns off server
/// certificate validation. Start from [`TlsOpts::verifying`] unless that is
/// really what you want.
#[derive(Default, Clone)]
pub struct TlsOpts {
    /// ALPN protocol identifiers offered to the server (copied into rustls's
    /// `ClientConfig::alpn_protocols`).
    pub alpn: Vec<Vec<u8>>,
    /// Validate the server certificate chain. When `false`, any certificate is
    /// accepted (handshake signatures are still checked).
    pub verify: bool,
    /// Trust anchors to validate against; `None` falls back to the system CA
    /// bundle found by `load_system_roots`.
    pub roots: Option<RootCertStore>,
}

impl TlsOpts {
    /// Options that verify the server certificate, offer no ALPN protocols and
    /// use the system trust store.
    pub fn verifying() -> Self {
        Self {
            verify: true,
            ..Self::default()
        }
    }
}
pub fn load_system_roots() -> Result<RootCertStore> {
for path in SYSTEM_CA_PATHS {
let file = match File::open(path) {
Ok(f) => f,
Err(e) if e.kind() == io::ErrorKind::NotFound => continue,
Err(e) => return Err(Error::Io(e)),
};
return parse_roots(BufReader::new(file), path);
}
Err(Error::BadResponse(
"no system CA bundle found; tried common Unix paths".into(),
))
}
pub fn load_roots_from_file(path: &str) -> Result<RootCertStore> {
let file = File::open(path).map_err(Error::Io)?;
parse_roots(BufReader::new(file), path)
}
fn parse_roots<R: io::BufRead>(mut reader: R, path: &str) -> Result<RootCertStore> {
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::BadResponse(format!("PEM parse error in {path}: {e}")))?;
let mut roots = RootCertStore::empty();
let (added, _ignored) = roots.add_parsable_certificates(certs);
if added == 0 {
return Err(Error::BadResponse(format!(
"no usable CA certificates parsed from {path}"
)));
}
Ok(roots)
}
/// A rustls client session layered over a `Read + Write` transport `S`.
///
/// Built by the `connect_over*` functions, which finish the handshake before
/// returning. Application data flows through the `Read`/`Write` impls.
pub struct TlsStream<S>
where
    S: Read + Write,
{
    /// rustls client state machine.
    conn: ClientConnection,
    /// Transport carrying the encrypted TLS records.
    sock: S,
    /// Protocol version captured right after the handshake.
    version: Option<ProtocolVersion>,
    /// ALPN protocol captured right after the handshake.
    alpn: Option<Vec<u8>>,
    /// DER-encoded peer certificates captured right after the handshake.
    peer_certs_der: Vec<Vec<u8>>,
}
/// Performs a verifying TLS handshake (system roots, no ALPN) over `transport`.
///
/// Shorthand for `connect_over_tls(transport, sni, TlsOpts::verifying())`.
pub fn connect_over<S>(transport: S, sni: &str) -> Result<TlsStream<S>>
where
    S: Read + Write,
{
    let opts = TlsOpts::verifying();
    connect_over_tls(transport, sni, opts)
}
pub fn connect_over_with_alpn<S: Read + Write>(
transport: S,
sni: &str,
alpn: &[&[u8]],
) -> Result<TlsStream<S>> {
let mut opts = TlsOpts::verifying();
opts.alpn = alpn.iter().map(|p| p.to_vec()).collect();
connect_over_tls(transport, sni, opts)
}
pub fn connect_over_tls<S: Read + Write>(
transport: S,
sni: &str,
opts: TlsOpts,
) -> Result<TlsStream<S>> {
let roots = match opts.roots {
Some(r) => r,
None => load_system_roots()?,
};
let mut config = if opts.verify {
ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth()
} else {
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerify::new()))
.with_no_client_auth()
};
config.alpn_protocols = opts.alpn;
let server_name: ServerName<'static> = ServerName::try_from(sni.to_string())
.map_err(|e| Error::BadResponse(format!("invalid SNI {sni:?}: {e}")))?;
let conn = ClientConnection::new(Arc::new(config), server_name).map_err(rustls_err)?;
let mut s = TlsStream {
conn,
sock: transport,
version: None,
alpn: None,
peer_certs_der: Vec::new(),
};
s.run_handshake()?;
s.snapshot_post_handshake();
Ok(s)
}
impl<S: Read + Write> TlsStream<S> {
pub fn negotiated_version(&self) -> Option<ProtocolVersion> {
self.version
}
pub fn alpn_selected(&self) -> Option<&[u8]> {
self.alpn.as_deref()
}
pub fn peer_certificates(&self) -> &[Vec<u8>] {
&self.peer_certs_der
}
fn run_handshake(&mut self) -> Result<()> {
while self.conn.is_handshaking() {
let mut did_something = false;
if self.conn.wants_write() {
self.conn.write_tls(&mut self.sock).map_err(Error::Io)?;
did_something = true;
}
if self.conn.is_handshaking() && self.conn.wants_read() {
let n = self.conn.read_tls(&mut self.sock).map_err(Error::Io)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
self.conn.process_new_packets().map_err(rustls_err)?;
did_something = true;
}
if !did_something {
self.conn.process_new_packets().map_err(rustls_err)?;
}
}
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock).map_err(Error::Io)?;
}
Ok(())
}
fn snapshot_post_handshake(&mut self) {
self.version = self.conn.protocol_version().map(map_rustls_version);
self.alpn = self.conn.alpn_protocol().map(|p| p.to_vec());
self.peer_certs_der = self
.conn
.peer_certificates()
.map(|certs| certs.iter().map(|c| c.to_vec()).collect())
.unwrap_or_default();
}
}
impl<S: Read + Write> Write for TlsStream<S> {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
let n = self.conn.writer().write(data)?;
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock)?;
}
Ok(n)
}
fn flush(&mut self) -> io::Result<()> {
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock)?;
}
self.sock.flush()
}
}
impl<S: Read + Write> Read for TlsStream<S> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if dst.is_empty() {
return Ok(0);
}
loop {
match self.conn.reader().read(dst) {
Ok(0) => return Ok(0), Ok(n) => return Ok(n),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(0),
Err(e) => return Err(e),
}
while self.conn.wants_write() {
self.conn.write_tls(&mut self.sock)?;
}
if !self.conn.wants_read() {
return Ok(0);
}
let n = self.conn.read_tls(&mut self.sock)?;
if n == 0 {
return match self.conn.reader().read(dst) {
Ok(n) => Ok(n),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
Err(e) => Err(e),
};
}
self.conn
.process_new_packets()
.map_err(|e| io::Error::other(format!("tls: {e}")))?;
}
}
}
/// Server certificate verifier that accepts any certificate chain; installed by
/// `connect_over_tls` when `TlsOpts::verify` is `false`.
///
/// Only chain/identity validation is skipped. TLS 1.2/1.3 handshake signatures
/// are still verified, using the algorithms stored in `sig_algs`.
#[derive(Debug)]
struct NoVerify {
    // Signature-verification algorithms taken from the ring crypto provider.
    sig_algs: WebPkiSupportedAlgorithms,
}
impl NoVerify {
fn new() -> Self {
let provider: CryptoProvider = crypto::default_provider();
Self {
sig_algs: provider.signature_verification_algorithms,
}
}
}
impl ServerCertVerifier for NoVerify {
    /// Accepts every certificate unconditionally. The chain, server name,
    /// OCSP response and time are all ignored.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _name: &ServerName<'_>,
        _ocsp: &[u8],
        _at: UnixTime,
    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature with the stored algorithms.
    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        signer: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
        let algs = &self.sig_algs;
        rustls::crypto::verify_tls12_signature(msg, signer, signed, algs)
    }

    /// Checks a TLS 1.3 handshake signature with the stored algorithms.
    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        signer: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
        let algs = &self.sig_algs;
        rustls::crypto::verify_tls13_signature(msg, signer, signed, algs)
    }

    /// Signature schemes advertised to the server: those the stored
    /// algorithms support.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algs = &self.sig_algs;
        algs.supported_schemes()
    }
}
/// Converts rustls's protocol version into this crate's `ProtocolVersion`.
///
/// TLS 1.2 and 1.3 get dedicated variants; anything else is carried through
/// as its raw 16-bit wire value.
fn map_rustls_version(v: rustls::ProtocolVersion) -> ProtocolVersion {
    match v {
        rustls::ProtocolVersion::TLSv1_2 => ProtocolVersion::TLSv1_2,
        rustls::ProtocolVersion::TLSv1_3 => ProtocolVersion::TLSv1_3,
        unknown => ProtocolVersion::Other(unknown.into()),
    }
}
fn rustls_err(e: rustls::Error) -> Error {
Error::BadResponse(format!("tls: {e}"))
}