use std::path::PathBuf;
#[cfg(any(feature = "emulation", feature = "btls-backend"))]
use crate::browser_emulation::BoringTlsFingerprint;
use crate::error::{Error, ErrorKind, Result};
use crate::request::ProtocolPolicy;
use sha2::{Digest, Sha256};
/// Source of trust anchors used to validate server certificates.
///
/// In the rustls backend, a `TlsConfig` with no explicit root store uses
/// [`RootStore::WebPki`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RootStore {
    /// Certificates from the platform's native store (loaded via `rustls_native_certs` in the
    /// rustls backend).
    System,
    /// The bundled `webpki_roots` trust anchors.
    WebPki,
    /// A PEM bundle that is read from this path when the client configuration is built.
    PemFile(PathBuf),
    /// An in-memory PEM bundle.
    Pem(String),
}
/// Transport-level TLS settings, configured through the consuming builder methods.
///
/// Every field defaults to "unset" (`None`, `false`, or empty); unset fields fall back to the
/// backend's defaults when a client configuration is built.
#[derive(Clone, Debug, Default)]
pub struct TlsConfig {
    /// When `true`, certificate chain validation is skipped (see `danger_accept_invalid_certs`).
    pub(crate) accept_invalid_certs: bool,
    /// Explicitly chosen TLS implementation; `ensure_emulation_backend` fills in `Boring`
    /// when this is `None`.
    pub(crate) backend: Option<TlsBackend>,
    /// ALPN override. `None` derives the list from the protocol policy; `Some(vec![])` is an
    /// explicit empty list (see `disable_alpn`).
    pub(crate) alpn_protocols: Option<Vec<String>>,
    /// Cipher suite names; the rustls backend resolves them with `cipher_from_str`.
    pub(crate) cipher_suites: Option<Vec<String>>,
    /// Minimum protocol version as a dotted string such as `"1.2"`.
    pub(crate) min_tls_version: Option<String>,
    /// Maximum protocol version as a dotted string such as `"1.3"`.
    pub(crate) max_tls_version: Option<String>,
    /// Browser-emulation fingerprint; presumably consumed by the BoringSSL backend, which is
    /// not visible here — TODO confirm.
    #[cfg(any(feature = "emulation", feature = "btls-backend"))]
    pub(crate) boring_tls_fingerprint: Option<BoringTlsFingerprint>,
    /// Trust anchors; `None` means `RootStore::WebPki` in the rustls backend.
    pub(crate) root_store: Option<RootStore>,
    /// SHA-256 pins compared against the server's leaf certificate.
    pub(crate) pinned_certificates: Vec<PinnedCertificate>,
}
/// A SHA-256 pin for the leaf certificate presented by one host.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PinnedCertificate {
    /// Host the pin applies to; `PinnedCertificate::new` trims and ASCII-lowercases it.
    pub(crate) domain: String,
    /// 32-byte SHA-256 digest of the DER-encoded certificate.
    pub(crate) fingerprint: Vec<u8>,
}
impl TlsConfig {
    /// Copies each item into an owned `String`, preserving order.
    fn owned_string_list<I, S>(items: I) -> Vec<String>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        items
            .into_iter()
            .map(|item| String::from(item.as_ref()))
            .collect()
    }

    /// Enables or disables acceptance of certificates that fail validation.
    ///
    /// In the rustls backend this replaces the WebPKI verifier with one that skips chain and
    /// signature checks while still comparing any configured certificate pins. Testing only.
    pub fn danger_accept_invalid_certs(self, enabled: bool) -> Self {
        Self {
            accept_invalid_certs: enabled,
            ..self
        }
    }

    /// Requests a specific TLS implementation.
    pub fn backend(self, backend: TlsBackend) -> Self {
        Self {
            backend: Some(backend),
            ..self
        }
    }

    /// Replaces the ALPN protocol list offered in the handshake; order is preserved.
    ///
    /// Whether the list is usable is decided later by the transport-specific
    /// `validate_*_alpn` checks.
    pub fn alpn_protocols<I, S>(self, protocols: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            alpn_protocols: Some(Self::owned_string_list(protocols)),
            ..self
        }
    }

    /// Restricts the offered cipher suites to the given names.
    ///
    /// The rustls backend resolves each name with `cipher_from_str`, skips names it does not
    /// recognize, and fails only when no usable suite remains.
    pub fn cipher_suites<I, S>(self, suites: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            cipher_suites: Some(Self::owned_string_list(suites)),
            ..self
        }
    }

    /// Sets the lowest acceptable protocol version, as a dotted string such as `"1.2"`.
    pub fn min_tls_version(self, version: impl AsRef<str>) -> Self {
        Self {
            min_tls_version: Some(String::from(version.as_ref())),
            ..self
        }
    }

    /// Sets the highest acceptable protocol version, as a dotted string such as `"1.3"`.
    pub fn max_tls_version(self, version: impl AsRef<str>) -> Self {
        Self {
            max_tls_version: Some(String::from(version.as_ref())),
            ..self
        }
    }

    /// Attaches a browser-emulation TLS fingerprint.
    #[cfg(feature = "emulation")]
    pub(crate) fn with_boring_tls_fingerprint(self, fingerprint: BoringTlsFingerprint) -> Self {
        Self {
            boring_tls_fingerprint: Some(fingerprint),
            ..self
        }
    }

    /// Returns the attached browser-emulation fingerprint, if any.
    #[cfg(any(feature = "emulation", feature = "btls-backend"))]
    pub(crate) fn boring_tls_fingerprint(&self) -> Option<&BoringTlsFingerprint> {
        self.boring_tls_fingerprint.as_ref()
    }

    /// Configures an explicit empty ALPN list, as opposed to leaving ALPN unset (which
    /// falls back to the protocol-policy default).
    pub fn disable_alpn(self) -> Self {
        Self {
            alpn_protocols: Some(Vec::new()),
            ..self
        }
    }

    /// Chooses where trusted root certificates come from.
    pub fn root_store(self, root_store: RootStore) -> Self {
        Self {
            root_store: Some(root_store),
            ..self
        }
    }

    /// Adds a SHA-256 certificate pin for `domain`. Several pins may be added for one domain.
    ///
    /// # Errors
    /// Propagates the error from [`PinnedCertificate::new`] for a blank domain or a malformed
    /// fingerprint.
    pub fn pin_certificate(
        mut self,
        domain: impl AsRef<str>,
        fingerprint: impl AsRef<str>,
    ) -> Result<Self> {
        let pin = PinnedCertificate::new(domain, fingerprint)?;
        self.pinned_certificates.push(pin);
        Ok(self)
    }

    /// Makes sure the configuration targets the BoringSSL backend required for emulation.
    ///
    /// An unset backend becomes `Boring`. Any other explicit choice, or a build without the
    /// `btls-backend` feature, is a `Transport` error.
    #[cfg(feature = "emulation")]
    pub(crate) fn ensure_emulation_backend(mut self) -> Result<Self> {
        #[cfg(feature = "btls-backend")]
        {
            let chosen = *self.backend.get_or_insert(TlsBackend::Boring);
            if matches!(chosen, TlsBackend::Boring) {
                Ok(self)
            } else {
                Err(Error::new(
                    ErrorKind::Transport,
                    "browser emulation requires the BoringSSL backend",
                ))
            }
        }
        #[cfg(not(feature = "btls-backend"))]
        {
            // Keeps `mut self` used in this configuration so it does not trigger `unused_mut`.
            let _ = &mut self;
            Err(Error::new(
                ErrorKind::Transport,
                "browser emulation requires the btls-backend feature",
            ))
        }
    }

    /// Returns the ALPN list to offer.
    ///
    /// An explicit list always wins. Otherwise `Http3Only` yields `["h3"]` and every other
    /// policy yields `["http/1.1"]`.
    pub(crate) fn effective_alpn_protocols(&self, protocol_policy: ProtocolPolicy) -> Vec<String> {
        match (&self.alpn_protocols, protocol_policy) {
            (Some(explicit), _) => explicit.clone(),
            (None, ProtocolPolicy::Http3Only) => vec!["h3".to_owned()],
            (None, _) => vec!["http/1.1".to_owned()],
        }
    }

    /// Returns the effective ALPN list if the HTTP/1 transport can use it.
    ///
    /// The list must be empty or contain only `http/1.1` entries.
    pub(crate) fn validate_http1_alpn(
        &self,
        protocol_policy: ProtocolPolicy,
    ) -> std::result::Result<Vec<String>, &'static str> {
        let protocols = self.effective_alpn_protocols(protocol_policy);
        // `all` is vacuously true for an empty list, which covers the "ALPN disabled" case.
        let http1_only = protocols.iter().all(|protocol| protocol == "http/1.1");
        if http1_only {
            Ok(protocols)
        } else {
            Err("custom ALPN protocols are not supported by the current HTTP/1 TLS transport")
        }
    }

    /// Returns the explicit ALPN list (or `[required]` when unset) if it contains `required`;
    /// otherwise returns `message`.
    #[cfg(any(feature = "h2", feature = "h3"))]
    fn alpn_including(
        &self,
        required: &str,
        message: &'static str,
    ) -> std::result::Result<Vec<String>, &'static str> {
        let protocols = self
            .alpn_protocols
            .clone()
            .unwrap_or_else(|| vec![required.to_owned()]);
        if protocols.iter().any(|protocol| protocol == required) {
            Ok(protocols)
        } else {
            Err(message)
        }
    }

    /// Returns the ALPN list for HTTP/2, defaulting to `["h2"]`; it must include `h2`.
    #[cfg(feature = "h2")]
    pub(crate) fn validate_h2_alpn(&self) -> std::result::Result<Vec<String>, &'static str> {
        self.alpn_including("h2", "HTTP/2 TLS transport requires ALPN to include h2")
    }

    /// Returns the ALPN list for HTTP/3, defaulting to `["h3"]`; it must include `h3`.
    #[cfg(feature = "h3")]
    pub(crate) fn validate_h3_alpn(&self) -> std::result::Result<Vec<String>, &'static str> {
        self.alpn_including("h3", "HTTP/3 transport requires ALPN to include h3")
    }
}
/// Resolves a cipher suite name to a rustls suite.
///
/// The trimmed name is first matched against rustls's own suite names. IANA-style TLS 1.3
/// names (`TLS_AES_128_GCM_SHA256`) are also accepted by retrying with the `TLS13_` prefix
/// rustls uses. TLS 1.2 names, which contain `_WITH_`, are not retried.
#[cfg(feature = "rustls")]
fn cipher_from_str(name: &str) -> Option<rustls::SupportedCipherSuite> {
    fn by_rustls_name(wanted: &str) -> Option<rustls::SupportedCipherSuite> {
        rustls::ALL_CIPHER_SUITES
            .iter()
            .find(|suite| suite.suite().as_str() == Some(wanted))
            .cloned()
    }

    let name = name.trim();
    if let Some(suite) = by_rustls_name(name) {
        return Some(suite);
    }
    match name.strip_prefix("TLS_") {
        Some(rest) if !name.contains("_WITH_") => by_rustls_name(&format!("TLS13_{rest}")),
        _ => None,
    }
}
/// Returns `true` when `ver` is at least the dotted version `target` (`"1.0"`..=`"1.3"`).
///
/// Surrounding whitespace in `target` is ignored. Unrecognized targets, including versions
/// newer than 1.3, return `false` for every `ver`. An invalid minimum therefore leaves no
/// versions and is reported by the caller, instead of silently selecting TLS 1.3 only.
#[cfg(feature = "rustls")]
fn version_gte(ver: &rustls::SupportedProtocolVersion, target: &str) -> bool {
    let floor = match target.trim() {
        "1.0" => 0u8,
        "1.1" => 1,
        "1.2" => 2,
        "1.3" => 3,
        _ => return false,
    };
    let actual = match ver.version {
        rustls::ProtocolVersion::TLSv1_2 => 2u8,
        rustls::ProtocolVersion::TLSv1_3 => 3,
        _ => return false,
    };
    actual >= floor
}
/// Returns `true` when `ver` is at most the dotted version `target` (`"1.2"` or `"1.3"`).
///
/// Surrounding whitespace in `target` is ignored. Any other target returns `false`, because
/// rustls offers nothing older than TLS 1.2.
#[cfg(feature = "rustls")]
fn version_lte(ver: &rustls::SupportedProtocolVersion, target: &str) -> bool {
    let ceiling = match target.trim() {
        "1.2" => 2u8,
        "1.3" => 3,
        _ => return false,
    };
    let actual = match ver.version {
        rustls::ProtocolVersion::TLSv1_2 => 2u8,
        rustls::ProtocolVersion::TLSv1_3 => 3,
        _ => return false,
    };
    actual <= ceiling
}
impl RootStore {
pub fn system() -> Self {
Self::System
}
pub fn webpki() -> Self {
Self::WebPki
}
pub fn pem_file(path: impl Into<PathBuf>) -> Self {
Self::PemFile(path.into())
}
pub fn pem(pem: impl Into<String>) -> Self {
Self::Pem(pem.into())
}
pub(crate) fn pem_bytes(&self) -> Result<Option<Vec<u8>>> {
match self {
Self::PemFile(path) => std::fs::read(path).map(Some).map_err(|err| {
Error::with_source(
ErrorKind::Transport,
format!("failed to read PEM root store from {}", path.display()),
err,
)
}),
Self::Pem(pem) => Ok(Some(pem.as_bytes().to_vec())),
Self::System | Self::WebPki => Ok(None),
}
}
}
impl PinnedCertificate {
    /// Creates a pin for `domain` from a hex SHA-256 certificate fingerprint.
    ///
    /// The domain is trimmed and ASCII-lowercased. The fingerprint may carry a `sha256/`
    /// prefix and `:`, `-`, or space separators between hex digits.
    ///
    /// # Errors
    /// `ErrorKind::InvalidUrl` for a blank domain; otherwise any fingerprint parse error.
    pub fn new(domain: impl AsRef<str>, fingerprint: impl AsRef<str>) -> Result<Self> {
        let normalized_domain = domain.as_ref().trim().to_ascii_lowercase();
        if normalized_domain.is_empty() {
            return Err(Error::new(
                ErrorKind::InvalidUrl,
                "pinned certificate domain cannot be empty",
            ));
        }
        parse_fingerprint(fingerprint.as_ref()).map(|fingerprint| Self {
            domain: normalized_domain,
            fingerprint,
        })
    }
}
/// Checks the DER certificate `cert_der` against every pin registered for `host`.
///
/// Host comparison is ASCII case-insensitive and ignores one trailing `.` on either side, so
/// `Example.COM.` is covered by a pin for `example.com`.
///
/// If no pin applies to `host`, the certificate is accepted. If one or more pins apply, the
/// SHA-256 digest of `cert_der` must equal at least one of them. This allows a backup pin to
/// be registered ahead of a certificate rotation.
///
/// # Errors
/// `ErrorKind::Transport` when pins exist for `host` but none matches the certificate.
pub(crate) fn verify_pinned_certificate(
    pins: &[PinnedCertificate],
    host: &str,
    cert_der: &[u8],
) -> Result<()> {
    /// Drops the DNS root label so `example.com.` and `example.com` compare equal.
    fn strip_root(name: &str) -> &str {
        name.strip_suffix('.').unwrap_or(name)
    }

    let host = host.to_ascii_lowercase();
    let wanted = strip_root(&host);
    let mut applicable = pins
        .iter()
        .filter(|pin| strip_root(&pin.domain).eq_ignore_ascii_case(wanted))
        .peekable();
    if applicable.peek().is_none() {
        return Ok(());
    }
    let actual = Sha256::digest(cert_der);
    if applicable.any(|pin| pin.fingerprint.as_slice() == actual.as_slice()) {
        return Ok(());
    }
    Err(Error::new(
        ErrorKind::Transport,
        format!("pinned certificate fingerprint mismatch for {host}"),
    ))
}
/// Parses a hex SHA-256 fingerprint into its 32 raw bytes.
///
/// Surrounding whitespace and an optional leading `sha256/` are ignored, as are `:`, `-`,
/// and space separators. `AB:CD:…`, `ab cd …` and `sha256/abcd…` are therefore all accepted.
///
/// # Errors
/// `ErrorKind::Transport` unless exactly 64 hex digits remain after normalization.
fn parse_fingerprint(input: &str) -> Result<Vec<u8>> {
    let trimmed = input.trim();
    let digits: String = trimmed
        .strip_prefix("sha256/")
        .unwrap_or(trimmed)
        .chars()
        .filter(|ch| !matches!(ch, ':' | ' ' | '-'))
        .collect();
    let is_sha256_hex = digits.len() == 64 && digits.bytes().all(|b| b.is_ascii_hexdigit());
    if !is_sha256_hex {
        return Err(Error::new(
            ErrorKind::Transport,
            "certificate fingerprint must be 32-byte SHA-256 hex",
        ));
    }
    digits
        .as_bytes()
        .chunks_exact(2)
        .map(|pair| -> Result<u8> { Ok((decode_hex(pair[0])? << 4) | decode_hex(pair[1])?) })
        .collect()
}
fn decode_hex(byte: u8) -> Result<u8> {
match byte {
b'0'..=b'9' => Ok(byte - b'0'),
b'a'..=b'f' => Ok(byte - b'a' + 10),
b'A'..=b'F' => Ok(byte - b'A' + 10),
_ => Err(Error::new(
ErrorKind::Transport,
"certificate fingerprint must be valid hex",
)),
}
}
/// TLS implementation selectable through `TlsConfig::backend`.
///
/// Each variant exists only when its Cargo feature is enabled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TlsBackend {
    /// The rustls backend (feature `rustls`).
    #[cfg(feature = "rustls")]
    Rustls,
    /// The backend enabled by the `native-tls` feature.
    #[cfg(feature = "native-tls")]
    Native,
    /// The BoringSSL backend (feature `btls-backend`); browser emulation requires it.
    #[cfg(feature = "btls-backend")]
    Boring,
}
/// rustls-specific construction of client configurations and certificate verification.
#[cfg(feature = "rustls")]
mod rustls_backend {
    use std::io::BufReader;
    use std::sync::Arc;
    use std::time::SystemTime;

    use rustls::client::{
        HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier, WebPkiVerifier,
    };
    use rustls::{
        Certificate, ClientConfig, DigitallySignedStruct, Error, RootCertStore, ServerName,
        SignatureScheme,
    };

    use super::{RootStore, TlsConfig, verify_pinned_certificate};

    impl TlsConfig {
        /// Builds the rustls client configuration for the HTTP/1 transport.
        ///
        /// # Errors
        /// A static message if the ALPN list fails `validate_http1_alpn`, or if the root
        /// store, protocol versions, or cipher suites cannot be set up.
        pub(crate) fn build_client_config(
            &self,
            protocol_policy: crate::request::ProtocolPolicy,
        ) -> std::result::Result<Arc<ClientConfig>, &'static str> {
            self.build_client_config_with_alpn(self.validate_http1_alpn(protocol_policy)?)
        }

        /// Builds the rustls client configuration for the HTTP/2 transport.
        ///
        /// # Errors
        /// As `build_client_config`, except the ALPN list must pass `validate_h2_alpn`.
        #[cfg(feature = "h2")]
        pub(crate) fn build_h2_client_config(
            &self,
        ) -> std::result::Result<Arc<ClientConfig>, &'static str> {
            self.build_client_config_with_alpn(self.validate_h2_alpn()?)
        }

        /// Shared builder. Applies roots, version bounds, cipher suites, the optional custom
        /// verifier, and the already-validated `alpn_protocols`.
        fn build_client_config_with_alpn(
            &self,
            alpn_protocols: Vec<String>,
        ) -> std::result::Result<Arc<ClientConfig>, &'static str> {
            let root_store = build_root_store(self)?;
            let mut versions: Vec<&'static rustls::SupportedProtocolVersion> =
                rustls::ALL_VERSIONS.iter().copied().collect();
            if let Some(min) = &self.min_tls_version {
                versions.retain(|v| super::version_gte(v, min));
            }
            if let Some(max) = &self.max_tls_version {
                versions.retain(|v| super::version_lte(v, max));
            }
            if versions.is_empty() {
                return Err("no TLS versions remain after filtering");
            }
            // Unrecognized names are skipped; a list with nothing usable is caught below.
            let mut cipher_suites: Vec<rustls::SupportedCipherSuite> =
                if let Some(suites) = &self.cipher_suites {
                    suites
                        .iter()
                        .filter_map(|n| super::cipher_from_str(n))
                        .collect()
                } else {
                    rustls::ALL_CIPHER_SUITES.iter().cloned().collect()
                };
            cipher_suites.retain(|s| versions.contains(&s.version()));
            if cipher_suites.is_empty() {
                return Err("no TLS cipher suites remain after filtering");
            }
            // Only a custom verifier needs its own copy of the roots, so clone in that case only.
            let custom_verifier =
                if self.accept_invalid_certs || !self.pinned_certificates.is_empty() {
                    Some(build_certificate_verifier(self, root_store.clone())?)
                } else {
                    None
                };
            let mut config = rustls::ClientConfig::builder()
                .with_cipher_suites(&cipher_suites)
                .with_kx_groups(&rustls::ALL_KX_GROUPS)
                .with_protocol_versions(&versions)
                .map_err(|_| "invalid tls versions")?
                .with_root_certificates(root_store)
                .with_no_client_auth();
            if let Some(verifier) = custom_verifier {
                config.dangerous().set_certificate_verifier(verifier);
            }
            config.alpn_protocols = alpn_protocols
                .into_iter()
                .map(|protocol| protocol.into_bytes())
                .collect();
            Ok(Arc::new(config))
        }
    }

    /// Loads the trust anchors selected by `config.root_store`; `None` means WebPKI roots.
    fn build_root_store(config: &TlsConfig) -> std::result::Result<RootCertStore, &'static str> {
        let default_store = RootStore::WebPki;
        let store = config.root_store.as_ref().unwrap_or(&default_store);
        match store {
            RootStore::WebPki => {
                let mut root_store = RootCertStore::empty();
                root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
                    rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                        ta.subject,
                        ta.spki,
                        ta.name_constraints,
                    )
                }));
                Ok(root_store)
            }
            RootStore::System => {
                let mut root_store = RootCertStore::empty();
                // Per-certificate load errors are tolerated on purpose: the store is usable as
                // long as at least one native root parses.
                let certs = rustls_native_certs::load_native_certs();
                let (added, _) = root_store.add_parsable_certificates(&certs.certs);
                if added == 0 {
                    return Err("failed to load any system root certificates");
                }
                Ok(root_store)
            }
            RootStore::PemFile(_) | RootStore::Pem(_) => {
                let pem = store
                    .pem_bytes()
                    .ok()
                    .flatten()
                    .ok_or("failed to load PEM root store")?;
                let mut root_store = RootCertStore::empty();
                let mut reader = BufReader::new(pem.as_slice());
                let certs = rustls_pemfile::certs(&mut reader)
                    .map_err(|_| "failed to parse PEM root certificates")?;
                let (added, _) = root_store.add_parsable_certificates(&certs);
                if added == 0 {
                    return Err("failed to add any PEM root certificates");
                }
                Ok(root_store)
            }
        }
    }

    /// Builds the pinning verifier. WebPKI validation is wrapped unless invalid certificates
    /// are explicitly accepted.
    fn build_certificate_verifier(
        config: &TlsConfig,
        root_store: RootCertStore,
    ) -> std::result::Result<Arc<dyn ServerCertVerifier>, &'static str> {
        let inner = if config.accept_invalid_certs {
            None
        } else {
            Some(WebPkiVerifier::new(root_store, None))
        };
        Ok(Arc::new(PinnedVerifier {
            inner,
            pins: config.pinned_certificates.clone(),
        }))
    }

    /// Optional WebPKI validation followed by certificate-pin enforcement.
    struct PinnedVerifier {
        /// `None` when `danger_accept_invalid_certs` is enabled: chain and signature checks
        /// are skipped.
        inner: Option<WebPkiVerifier>,
        pins: Vec<super::PinnedCertificate>,
    }

    impl PinnedVerifier {
        /// Applies the pins registered for `host` to the leaf certificate.
        fn check_pins(
            &self,
            host: &str,
            end_entity: &Certificate,
        ) -> std::result::Result<(), Error> {
            verify_pinned_certificate(&self.pins, host, end_entity.0.as_ref())
                .map_err(|err| Error::General(err.to_string()))
        }
    }

    impl std::fmt::Debug for PinnedVerifier {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("PinnedVerifier")
                .field("has_inner", &self.inner.is_some())
                .field("pins", &self.pins)
                .finish()
        }
    }

    impl ServerCertVerifier for PinnedVerifier {
        fn verify_server_cert(
            &self,
            end_entity: &Certificate,
            intermediates: &[Certificate],
            server_name: &ServerName,
            scts: &mut dyn Iterator<Item = &[u8]>,
            ocsp_response: &[u8],
            now: SystemTime,
        ) -> std::result::Result<ServerCertVerified, Error> {
            if let Some(inner) = &self.inner {
                inner.verify_server_cert(
                    end_entity,
                    intermediates,
                    server_name,
                    scts,
                    ocsp_response,
                    now,
                )?;
            }
            // Pins apply to IP-address targets as well as DNS names. IPs are compared in
            // `IpAddr`'s `Display` form, e.g. `192.0.2.1` or `::1`.
            match server_name {
                ServerName::DnsName(name) => self.check_pins(name.as_ref(), end_entity)?,
                ServerName::IpAddress(ip) => self.check_pins(&ip.to_string(), end_entity)?,
                _ => {}
            }
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            message: &[u8],
            cert: &Certificate,
            dss: &DigitallySignedStruct,
        ) -> std::result::Result<HandshakeSignatureValid, Error> {
            if let Some(inner) = &self.inner {
                return inner.verify_tls12_signature(message, cert, dss);
            }
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            message: &[u8],
            cert: &Certificate,
            dss: &DigitallySignedStruct,
        ) -> std::result::Result<HandshakeSignatureValid, Error> {
            if let Some(inner) = &self.inner {
                return inner.verify_tls13_signature(message, cert, dss);
            }
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            if let Some(inner) = &self.inner {
                return inner.supported_verify_schemes();
            }
            // Signatures are not checked in this mode, so advertise every common scheme.
            // Otherwise servers with P-384 or SHA-384/512-only keys fail to negotiate.
            vec![
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::ED25519,
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA512,
            ]
        }
    }
}
#[cfg(test)]
mod tests {
    use super::TlsConfig;
    use crate::ProtocolPolicy;

    /// A configuration restricted to TLS 1.3 that offers only `suite`.
    #[cfg(feature = "rustls")]
    fn tls13_only(suite: &str) -> TlsConfig {
        TlsConfig::default()
            .cipher_suites([suite])
            .min_tls_version("1.3")
            .max_tls_version("1.3")
    }

    #[test]
    fn tls_config_defaults_to_http1_alpn() {
        let config = TlsConfig::default();
        assert_eq!(
            config.effective_alpn_protocols(ProtocolPolicy::Auto),
            ["http/1.1"]
        );
    }

    #[test]
    fn tls_config_can_disable_alpn() {
        let config = TlsConfig::default().disable_alpn();
        let protocols = config.effective_alpn_protocols(ProtocolPolicy::Auto);
        assert!(
            protocols.is_empty(),
            "expected no ALPN protocols, got {protocols:?}"
        );
    }

    #[test]
    fn tls_config_rejects_unsupported_http1_alpn_override() {
        let outcome = TlsConfig::default()
            .alpn_protocols(["h2"])
            .validate_http1_alpn(ProtocolPolicy::Auto);
        assert_eq!(
            outcome,
            Err("custom ALPN protocols are not supported by the current HTTP/1 TLS transport")
        );
    }

    #[cfg(feature = "h2")]
    #[test]
    fn tls_config_allows_h2_alpn_for_h2_transport() {
        let protocols = TlsConfig::default()
            .alpn_protocols(["h2", "http/1.1"])
            .validate_h2_alpn()
            .expect("h2 is listed, so validation should pass");
        assert_eq!(protocols, ["h2", "http/1.1"]);
    }

    #[cfg(feature = "rustls")]
    #[test]
    fn tls_config_accepts_iana_tls13_cipher_suite_names() {
        let client_config = tls13_only("TLS_AES_128_GCM_SHA256")
            .build_client_config(ProtocolPolicy::Http1Only)
            .expect("iana cipher suite name should be accepted");
        assert_eq!(client_config.alpn_protocols, [b"http/1.1".to_vec()]);
    }

    #[cfg(feature = "rustls")]
    #[test]
    fn tls_config_rejects_unknown_cipher_suite_names() {
        let err = tls13_only("TLS_NOT_A_REAL_CIPHER_SUITE")
            .build_client_config(ProtocolPolicy::Http1Only)
            .unwrap_err();
        assert!(err.contains("cipher suites"), "unexpected error: {err}");
    }

    #[cfg(feature = "rustls")]
    #[test]
    fn tls_config_rejects_invalid_tls_version_range() {
        let err = TlsConfig::default()
            .min_tls_version("1.3")
            .max_tls_version("1.2")
            .build_client_config(ProtocolPolicy::Http1Only)
            .unwrap_err();
        assert!(err.contains("TLS versions"), "unexpected error: {err}");
    }
}