#[path = "../openssl/ssl_data_31.rs"]
mod ssl_data;
use crate::socket::{SelectKind, timeout_error_msg};
use crate::vm::VirtualMachine;
use alloc::sync::Arc;
use parking_lot::RwLock as ParkingRwLock;
use rustls::RootCertStore;
use rustls::client::ClientConfig;
use rustls::client::ClientConnection;
use rustls::crypto::SupportedKxGroup;
use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer};
use rustls::server::ResolvesServerCert;
use rustls::server::ServerConfig;
use rustls::server::ServerConnection;
use rustls::sign::CertifiedKey;
use rustpython_vm::builtins::{PyBaseException, PyBaseExceptionRef};
use rustpython_vm::convert::IntoPyException;
use rustpython_vm::function::ArgBytesLike;
use rustpython_vm::{AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject};
use std::io::Read;
use std::sync::Once;
use super::_ssl::PySSLSocket;
use super::error::{
PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_syscall_error,
create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error,
};
/// `ssl.VERIFY_X509_STRICT` flag bit (OpenSSL `X509_V_FLAG_X509_STRICT`).
pub const VERIFY_X509_STRICT: i32 = 0x20;
/// `ssl.VERIFY_X509_PARTIAL_CHAIN` flag bit (OpenSSL `X509_V_FLAG_PARTIAL_CHAIN`).
pub const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000;
/// Guards the one-time installation of the process-wide rustls crypto provider.
static INIT_PROVIDER: Once = Once::new();
/// Installs aws-lc-rs as rustls' process-default `CryptoProvider`, at most once per process.
///
/// `install_default` fails only when some default provider is already installed; in that case
/// the existing provider is left in place and the error is discarded.
fn ensure_default_provider() {
    INIT_PROVIDER.call_once(|| {
        let aws_lc = rustls::crypto::aws_lc_rs::default_provider();
        let _already_installed = rustls::crypto::CryptoProvider::install_default(aws_lc);
    });
}
/// Maximum TLS plaintext record length (OpenSSL `SSL3_RT_MAX_PLAIN_LENGTH`); also used in
/// this module as the size of a single socket peek/recv.
const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384;
/// OpenSSL library code of the SSL library (`ERR_LIB_SSL`).
const ERR_LIB_SSL: i32 = 20;
/// OpenSSL reason code `SSL_R_NO_SHARED_CIPHER`.
const SSL_R_NO_SHARED_CIPHER: i32 = 193;
/// OpenSSL `X509_V_FLAG_CRL_CHECK`: check the leaf certificate against loaded CRLs.
const X509_V_FLAG_CRL_CHECK: i32 = 4;
pub use x509::{
X509_V_ERR_CERT_HAS_EXPIRED, X509_V_ERR_CERT_NOT_YET_VALID, X509_V_ERR_CERT_REVOKED,
X509_V_ERR_HOSTNAME_MISMATCH, X509_V_ERR_INVALID_PURPOSE, X509_V_ERR_IP_ADDRESS_MISMATCH,
X509_V_ERR_UNABLE_TO_GET_CRL, X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY,
X509_V_ERR_UNSPECIFIED,
};
/// OpenSSL `X509_V_*` certificate-verification result codes, mirrored so that
/// `SSLCertVerificationError.verify_code` can carry OpenSSL-compatible values.
// Only a handful are referenced (re-exported above); the full table is kept for parity,
// hence `dead_code` is allowed.
#[allow(dead_code)]
mod x509 {
    pub const X509_V_OK: i32 = 0;
    pub const X509_V_ERR_UNSPECIFIED: i32 = 1;
    pub const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: i32 = 2;
    pub const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3;
    pub const X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: i32 = 4;
    pub const X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: i32 = 5;
    pub const X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: i32 = 6;
    pub const X509_V_ERR_CERT_SIGNATURE_FAILURE: i32 = 7;
    pub const X509_V_ERR_CRL_SIGNATURE_FAILURE: i32 = 8;
    pub const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9;
    pub const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10;
    pub const X509_V_ERR_CRL_NOT_YET_VALID: i32 = 11;
    pub const X509_V_ERR_CRL_HAS_EXPIRED: i32 = 12;
    pub const X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: i32 = 13;
    pub const X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: i32 = 14;
    pub const X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: i32 = 15;
    pub const X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: i32 = 16;
    pub const X509_V_ERR_OUT_OF_MEM: i32 = 17;
    pub const X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: i32 = 18;
    pub const X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: i32 = 19;
    pub const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20;
    pub const X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: i32 = 21;
    pub const X509_V_ERR_CERT_CHAIN_TOO_LONG: i32 = 22;
    pub const X509_V_ERR_CERT_REVOKED: i32 = 23;
    pub const X509_V_ERR_INVALID_CA: i32 = 24;
    pub const X509_V_ERR_PATH_LENGTH_EXCEEDED: i32 = 25;
    pub const X509_V_ERR_INVALID_PURPOSE: i32 = 26;
    pub const X509_V_ERR_CERT_UNTRUSTED: i32 = 27;
    pub const X509_V_ERR_CERT_REJECTED: i32 = 28;
    pub const X509_V_ERR_SUBJECT_ISSUER_MISMATCH: i32 = 29;
    pub const X509_V_ERR_AKID_SKID_MISMATCH: i32 = 30;
    pub const X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: i32 = 31;
    pub const X509_V_ERR_KEYUSAGE_NO_CERTSIGN: i32 = 32;
    pub const X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER: i32 = 33;
    pub const X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION: i32 = 34;
    pub const X509_V_ERR_KEYUSAGE_NO_CRL_SIGN: i32 = 35;
    pub const X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION: i32 = 36;
    pub const X509_V_ERR_INVALID_NON_CA: i32 = 37;
    pub const X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED: i32 = 38;
    pub const X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE: i32 = 39;
    pub const X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED: i32 = 40;
    pub const X509_V_ERR_INVALID_EXTENSION: i32 = 41;
    pub const X509_V_ERR_INVALID_POLICY_EXTENSION: i32 = 42;
    pub const X509_V_ERR_NO_EXPLICIT_POLICY: i32 = 43;
    pub const X509_V_ERR_DIFFERENT_CRL_SCOPE: i32 = 44;
    pub const X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE: i32 = 45;
    pub const X509_V_ERR_UNNESTED_RESOURCE: i32 = 46;
    pub const X509_V_ERR_PERMITTED_VIOLATION: i32 = 47;
    pub const X509_V_ERR_EXCLUDED_VIOLATION: i32 = 48;
    pub const X509_V_ERR_SUBTREE_MINMAX: i32 = 49;
    pub const X509_V_ERR_APPLICATION_VERIFICATION: i32 = 50;
    pub const X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE: i32 = 51;
    pub const X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX: i32 = 52;
    pub const X509_V_ERR_UNSUPPORTED_NAME_SYNTAX: i32 = 53;
    pub const X509_V_ERR_CRL_PATH_VALIDATION_ERROR: i32 = 54;
    // Codes 55..=61 are not mirrored here.
    pub const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62;
    pub const X509_V_ERR_EMAIL_MISMATCH: i32 = 63;
    pub const X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64;
}
/// Maps a rustls certificate error onto an OpenSSL-style `(verify_code, verify_message)` pair,
/// as exposed by CPython's `SSLCertVerificationError`.
///
/// Hostname/IP mismatches reported by custom verifiers through `CertificateError::Other` are
/// recognised by their debug text. Anything unrecognised maps to `X509_V_ERR_UNSPECIFIED`.
fn rustls_cert_error_to_verify_info(cert_err: &rustls::CertificateError) -> (i32, &'static str) {
    use rustls::CertificateError;
    match cert_err {
        CertificateError::UnknownIssuer => (
            X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY,
            "unable to get local issuer certificate",
        ),
        CertificateError::Expired => (X509_V_ERR_CERT_HAS_EXPIRED, "certificate has expired"),
        CertificateError::NotValidYet => (
            X509_V_ERR_CERT_NOT_YET_VALID,
            "certificate is not yet valid",
        ),
        CertificateError::Revoked => (X509_V_ERR_CERT_REVOKED, "certificate revoked"),
        CertificateError::UnknownRevocationStatus => (
            X509_V_ERR_UNABLE_TO_GET_CRL,
            "unable to get certificate CRL",
        ),
        CertificateError::InvalidPurpose => (
            X509_V_ERR_INVALID_PURPOSE,
            "unsupported certificate purpose",
        ),
        // rustls' WebPKI verifier reports a subject-name mismatch with this dedicated variant.
        // Previously it fell through to the catch-all and lost the OpenSSL hostname code.
        // NOTE(review): rustls appears to use the same variant for IP-address mismatches, so
        // those also surface as a hostname mismatch here.
        CertificateError::NotValidForName => (
            X509_V_ERR_HOSTNAME_MISMATCH,
            "Hostname mismatch, certificate is not valid for",
        ),
        CertificateError::Other(other_err) => {
            // Custom verifiers encode mismatches as free text; classify by substring.
            let err_msg = format!("{other_err:?}");
            if err_msg.contains("Hostname mismatch") || err_msg.contains("not valid for") {
                (
                    X509_V_ERR_HOSTNAME_MISMATCH,
                    "Hostname mismatch, certificate is not valid for",
                )
            } else if err_msg.contains("IP address mismatch") {
                (
                    X509_V_ERR_IP_ADDRESS_MISMATCH,
                    "IP address mismatch, certificate is not valid for",
                )
            } else {
                (X509_V_ERR_UNSPECIFIED, "certificate verification failed")
            }
        }
        // `CertificateError` is non-exhaustive; unknown variants get the generic code.
        _ => (X509_V_ERR_UNSPECIFIED, "certificate verification failed"),
    }
}
pub(super) fn create_ssl_cert_verification_error(
vm: &VirtualMachine,
cert_err: &rustls::CertificateError,
) -> PyResult<PyBaseExceptionRef> {
let (verify_code, verify_message) = rustls_cert_error_to_verify_info(cert_err);
let msg =
format!("[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: {verify_message}",);
let exc = vm.new_os_subtype_error(
PySSLCertVerificationError::class(&vm.ctx).to_owned(),
None,
msg,
);
exc.as_object().set_attr(
"verify_code",
vm.ctx.new_int(verify_code).as_object().to_owned(),
vm,
)?;
exc.as_object().set_attr(
"verify_message",
vm.ctx.new_str(verify_message).as_object().to_owned(),
vm,
)?;
exc.as_object()
.set_attr("library", vm.ctx.new_str("SSL").as_object().to_owned(), vm)?;
exc.as_object().set_attr(
"reason",
vm.ctx
.new_str("CERTIFICATE_VERIFY_FAILED")
.as_object()
.to_owned(),
vm,
)?;
Ok(exc.upcast())
}
/// A rustls connection in either role, so the socket code can drive both through one API.
#[derive(Debug)]
pub(super) enum TlsConnection {
    /// Client-side rustls connection.
    Client(ClientConnection),
    /// Server-side rustls connection.
    Server(ServerConnection),
}
impl TlsConnection {
pub fn is_handshaking(&self) -> bool {
match self {
TlsConnection::Client(conn) => conn.is_handshaking(),
TlsConnection::Server(conn) => conn.is_handshaking(),
}
}
pub fn wants_read(&self) -> bool {
match self {
TlsConnection::Client(conn) => conn.wants_read(),
TlsConnection::Server(conn) => conn.wants_read(),
}
}
pub fn wants_write(&self) -> bool {
match self {
TlsConnection::Client(conn) => conn.wants_write(),
TlsConnection::Server(conn) => conn.wants_write(),
}
}
pub fn read_tls(&mut self, reader: &mut dyn std::io::Read) -> std::io::Result<usize> {
match self {
TlsConnection::Client(conn) => conn.read_tls(reader),
TlsConnection::Server(conn) => conn.read_tls(reader),
}
}
pub fn write_tls(&mut self, writer: &mut dyn std::io::Write) -> std::io::Result<usize> {
match self {
TlsConnection::Client(conn) => conn.write_tls(writer),
TlsConnection::Server(conn) => conn.write_tls(writer),
}
}
pub fn process_new_packets(&mut self) -> Result<rustls::IoState, rustls::Error> {
match self {
TlsConnection::Client(conn) => conn.process_new_packets(),
TlsConnection::Server(conn) => conn.process_new_packets(),
}
}
pub fn reader(&mut self) -> rustls::Reader<'_> {
match self {
TlsConnection::Client(conn) => conn.reader(),
TlsConnection::Server(conn) => conn.reader(),
}
}
pub fn writer(&mut self) -> rustls::Writer<'_> {
match self {
TlsConnection::Client(conn) => conn.writer(),
TlsConnection::Server(conn) => conn.writer(),
}
}
pub fn is_session_resumed(&self) -> bool {
use rustls::HandshakeKind;
match self {
TlsConnection::Client(conn) => {
matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed))
}
TlsConnection::Server(conn) => {
matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed))
}
}
}
pub fn send_close_notify(&mut self) {
match self {
TlsConnection::Client(conn) => conn.send_close_notify(),
TlsConnection::Server(conn) => conn.send_close_notify(),
}
}
pub fn alpn_protocol(&self) -> Option<&[u8]> {
match self {
TlsConnection::Client(conn) => conn.alpn_protocol(),
TlsConnection::Server(conn) => conn.alpn_protocol(),
}
}
pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
match self {
TlsConnection::Client(conn) => conn.negotiated_cipher_suite(),
TlsConnection::Server(conn) => conn.negotiated_cipher_suite(),
}
}
pub fn peer_certificates(&self) -> Option<&[rustls::pki_types::CertificateDer<'static>]> {
match self {
TlsConnection::Client(conn) => conn.peer_certificates(),
TlsConnection::Server(conn) => conn.peer_certificates(),
}
}
}
/// Internal error type of the rustls-backed SSL layer; turned into a Python exception by
/// `SslError::into_py_err`.
#[derive(Debug)]
pub(super) enum SslError {
    /// More incoming TLS data is needed (becomes `SSLWantReadError`).
    WantRead,
    /// Outgoing TLS data could not be fully sent (becomes `SSLWantWriteError`).
    WantWrite,
    /// Transport-level failure with a message (becomes `SSLSyscallError`).
    Syscall(String),
    /// Generic TLS failure (becomes `SSLError("SSL error: ...")`).
    Ssl(String),
    /// The peer closed the TLS session cleanly (becomes `SSLZeroReturnError`).
    ZeroReturn,
    /// Unexpected end of stream (becomes `SSLEOFError`).
    Eof,
    /// Non-TLS data arrived before/during the handshake.
    PreauthData,
    /// Certificate verification failed (becomes `SSLCertVerificationError`).
    CertVerification(rustls::CertificateError),
    /// Raw I/O error from rustls.
    Io(std::io::Error),
    /// Socket deadline exceeded; carries the message.
    Timeout(String),
    /// Handshake must restart after the SNI callback; must never reach Python.
    SniCallbackRestart,
    /// An already-built Python exception.
    Py(PyBaseExceptionRef),
    /// Peer sent a fatal alert; `lib`/`reason` are OpenSSL-style codes.
    AlertReceived { lib: i32, reason: i32 },
    /// No cipher suite in common with the peer.
    NoCipherSuites,
}
impl SslError {
    /// Encodes a TLS alert as an OpenSSL-style reason code: `1000 + alert description`
    /// (OpenSSL's `SSL_AD_REASON_OFFSET` convention).
    fn alert_to_openssl_reason(alert: rustls::AlertDescription) -> i32 {
        1000 + (u8::from(alert) as i32)
    }

    /// Classifies a rustls error into the matching [`SslError`] variant.
    ///
    /// The mapping is as follows:
    /// * Certificate errors keep their detail.
    /// * A received `close_notify` becomes `ZeroReturn`; other alerts become `AlertReceived`.
    /// * Malformed messages and most peer incompatibilities are reported as `Eof`.
    /// * "No cipher suites in common" becomes `NoCipherSuites`.
    /// * Everything else becomes a generic `Ssl` message.
    pub fn from_rustls(err: rustls::Error) -> Self {
        match err {
            rustls::Error::InvalidCertificate(cert_err) => SslError::CertVerification(cert_err),
            rustls::Error::AlertReceived(alert_desc) => {
                match alert_desc {
                    // Orderly shutdown by the peer.
                    rustls::AlertDescription::CloseNotify => {
                        SslError::ZeroReturn
                    }
                    _ => {
                        SslError::AlertReceived {
                            lib: ERR_LIB_SSL,
                            reason: Self::alert_to_openssl_reason(alert_desc),
                        }
                    }
                }
            }
            rustls::Error::InvalidMessage(_) => {
                SslError::Eof
            }
            rustls::Error::PeerIncompatible(peer_err) => {
                use rustls::PeerIncompatible;
                match peer_err {
                    PeerIncompatible::NoCipherSuitesInCommon => {
                        SslError::NoCipherSuites
                    }
                    _ => {
                        SslError::Eof
                    }
                }
            }
            _ => SslError::Ssl(format!("{err}")),
        }
    }

    /// Builds an `SSLError` (errno 1) carrying `library` and `reason` attributes.
    ///
    /// `library` is set to `None` when not given. Failures while setting the attributes are
    /// ignored, so the exception is always returned.
    pub(super) fn create_ssl_error_with_reason(
        vm: &VirtualMachine,
        library: Option<&str>,
        reason: &str,
        message: impl Into<String>,
    ) -> PyBaseExceptionRef {
        let msg = message.into();
        let exc = vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), Some(1), msg);
        let library_obj = match library {
            Some(lib) => vm.ctx.new_str(lib).as_object().to_owned(),
            None => vm.ctx.none(),
        };
        let _ = exc.as_object().set_attr("library", library_obj, vm);
        let _ =
            exc.as_object()
                .set_attr("reason", vm.ctx.new_str(reason).as_object().to_owned(), vm);
        exc.upcast()
    }

    /// Builds an `SSLError` from OpenSSL numeric library/reason codes.
    ///
    /// Names are looked up in the `ssl_data` tables, falling back to `"UNKNOWN"` and
    /// `"unknown error"` when a code is not listed.
    fn create_ssl_error_from_codes(
        vm: &VirtualMachine,
        lib: i32,
        reason: i32,
    ) -> PyBaseExceptionRef {
        let key = ssl_data::encode_error_key(lib, reason);
        let reason_str = ssl_data::ERROR_CODES
            .get(&key)
            .copied()
            .unwrap_or("unknown error");
        let lib_str = ssl_data::LIBRARY_CODES
            .get(&(lib as u32))
            .copied()
            .unwrap_or("UNKNOWN");
        Self::create_ssl_error_with_reason(
            vm,
            Some(lib_str),
            reason_str,
            format!("[SSL] {reason_str}"),
        )
    }

    /// Converts this error into the Python exception raised to user code.
    ///
    /// # Panics
    /// * On `SniCallbackRestart`, which callers must handle before reaching Python.
    /// * If setting attributes on the `SSLCertVerificationError` fails.
    pub fn into_py_err(self, vm: &VirtualMachine) -> PyBaseExceptionRef {
        match self {
            SslError::WantRead => create_ssl_want_read_error(vm).upcast(),
            SslError::WantWrite => create_ssl_want_write_error(vm).upcast(),
            SslError::Timeout(msg) => timeout_error_msg(vm, msg).upcast(),
            SslError::Syscall(msg) => {
                create_ssl_syscall_error(vm, msg).upcast()
            }
            SslError::Ssl(msg) => vm
                .new_os_subtype_error(
                    PySSLError::class(&vm.ctx).to_owned(),
                    None,
                    format!("SSL error: {msg}"),
                )
                .upcast(),
            SslError::ZeroReturn => create_ssl_zero_return_error(vm).upcast(),
            SslError::Eof => create_ssl_eof_error(vm).upcast(),
            SslError::PreauthData => {
                Self::create_ssl_error_with_reason(
                    vm,
                    None,
                    "before TLS handshake with data",
                    "before TLS handshake with data",
                )
            }
            SslError::CertVerification(cert_err) => {
                create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen")
            }
            SslError::Io(err) => err.into_pyexception(vm),
            SslError::SniCallbackRestart => {
                unreachable!("SniCallbackRestart should not reach Python layer")
            }
            SslError::Py(exc) => exc,
            SslError::AlertReceived { lib, reason } => {
                Self::create_ssl_error_from_codes(vm, lib, reason)
            }
            SslError::NoCipherSuites => {
                Self::create_ssl_error_from_codes(vm, ERR_LIB_SSL, SSL_R_NO_SHARED_CIPHER)
            }
        }
    }
}
/// Result type used throughout the rustls-backed SSL layer.
pub type SslResult<T> = Result<T, SslError>;
/// Protocol-level knobs shared by client and server configuration.
#[derive(Debug)]
pub struct ProtocolSettings {
    /// TLS versions to enable (passed to `with_protocol_versions`).
    pub versions: &'static [&'static rustls::SupportedProtocolVersion],
    /// Key-exchange groups; `None` means every aws-lc-rs group.
    pub kx_groups: Option<Vec<&'static dyn rustls::crypto::SupportedKxGroup>>,
    /// Cipher suites; `None` means every aws-lc-rs suite.
    pub cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
    /// ALPN protocol IDs; when empty, rustls' ALPN list is left untouched.
    pub alpn_protocols: Vec<Vec<u8>>,
}
/// Inputs for `create_server_config`: everything needed to build a server-side rustls config.
#[derive(Debug)]
pub struct ServerConfigOptions {
    /// Versions, cipher suites, key-exchange groups and ALPN list.
    pub protocol_settings: ProtocolSettings,
    /// Server certificate chain; used only when `cert_resolver` is `None`.
    pub cert_chain: Vec<CertificateDer<'static>>,
    /// Private key for `cert_chain`; used only when `cert_resolver` is `None`.
    pub private_key: PrivateKeyDer<'static>,
    /// Trust anchors for client certificates.
    /// Client auth is configured only when this is `Some` and `request_client_cert` is set.
    pub root_store: Option<RootCertStore>,
    /// Whether to verify client certificates (requires `root_store`).
    pub request_client_cert: bool,
    /// Wrap the client verifier in `DeferredClientCertVerifier` (requires `deferred_cert_error`).
    pub use_deferred_validation: bool,
    /// Dynamic certificate selection; takes precedence over `cert_chain`/`private_key`.
    pub cert_resolver: Option<Arc<dyn ResolvesServerCert>>,
    /// Shared slot handed to `DeferredClientCertVerifier`.
    /// Presumably where it stores a deferred verification error message; confirm in `ssl::cert`.
    pub deferred_cert_error: Option<Arc<ParkingRwLock<Option<String>>>>,
    /// Replaces rustls' default server session storage when set.
    pub session_storage: Option<Arc<rustls::server::ServerSessionMemoryCache>>,
    /// Replaces rustls' default session-ticket producer when set.
    pub ticketer: Option<Arc<dyn rustls::server::ProducesTickets>>,
}
/// Inputs for `create_client_config`: everything needed to build a client-side rustls config.
#[derive(Debug)]
pub struct ClientConfigOptions {
    /// Versions, cipher suites, key-exchange groups and ALPN list.
    pub protocol_settings: ProtocolSettings,
    /// Trust anchors; required when `verify_server_cert` is set.
    pub root_store: Option<RootCertStore>,
    /// DER-encoded CA certificates, handed to `PartialChainVerifier`.
    pub ca_certs_der: Vec<Vec<u8>>,
    /// Client certificate chain; used only together with `private_key`.
    pub cert_chain: Option<Vec<CertificateDer<'static>>>,
    /// Private key for `cert_chain`.
    pub private_key: Option<PrivateKeyDer<'static>>,
    /// `false` disables server certificate verification entirely (`NoVerifier`).
    pub verify_server_cert: bool,
    /// `false` wraps the verifier in `HostnameIgnoringVerifier`.
    pub check_hostname: bool,
    /// OpenSSL-style `X509_V_FLAG_*` bits.
    /// The CRL-check, partial-chain and X509-strict bits are consulted.
    pub verify_flags: i32,
    /// Session store used for resumption, when set.
    pub session_store: Option<Arc<dyn rustls::client::ClientSessionStore>>,
    /// Certificate revocation lists for the WebPKI verifier.
    pub crls: Vec<CertificateRevocationListDer<'static>>,
}
/// Builds an aws-lc-rs based `CryptoProvider` restricted to the requested algorithms.
///
/// `None` for `cipher_suites` or `kx_groups` selects every suite/group aws-lc-rs offers.
/// Signature verification, randomness and key loading are taken from aws-lc-rs' default
/// provider unchanged.
fn create_custom_crypto_provider(
    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
    kx_groups: Option<Vec<&'static dyn rustls::crypto::SupportedKxGroup>>,
) -> Arc<rustls::crypto::CryptoProvider> {
    use rustls::crypto::aws_lc_rs;
    let rustls::crypto::CryptoProvider {
        signature_verification_algorithms,
        secure_random,
        key_provider,
        ..
    } = aws_lc_rs::default_provider();
    let provider = rustls::crypto::CryptoProvider {
        cipher_suites: cipher_suites.unwrap_or_else(|| aws_lc_rs::ALL_CIPHER_SUITES.to_vec()),
        kx_groups: kx_groups.unwrap_or_else(|| aws_lc_rs::ALL_KX_GROUPS.to_vec()),
        signature_verification_algorithms,
        secure_random,
        key_provider,
    };
    Arc::new(provider)
}
/// Builds a rustls [`ServerConfig`] from the collected `SSLContext` options.
///
/// Client-certificate verification is enabled only when a root store is supplied *and*
/// `request_client_cert` is set. If `use_deferred_validation` is also set and a
/// `deferred_cert_error` slot is present, the WebPKI verifier is wrapped in
/// `DeferredClientCertVerifier`. A `cert_resolver` takes precedence over
/// `cert_chain`/`private_key`.
///
/// # Errors
/// Returns a human-readable message if the client verifier, protocol versions or server
/// certificate cannot be configured.
pub(super) fn create_server_config(options: ServerConfigOptions) -> Result<ServerConfig, String> {
    use rustls::server::WebPkiClientVerifier;
    ensure_default_provider();
    let provider = create_custom_crypto_provider(
        options.protocol_settings.cipher_suites.clone(),
        options.protocol_settings.kx_groups.clone(),
    );

    let client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>> =
        match options.root_store {
            Some(roots) if options.request_client_cert => {
                let webpki = WebPkiClientVerifier::builder(Arc::new(roots))
                    .build()
                    .map_err(|e| format!("Failed to create client verifier: {e}"))?;
                match options.deferred_cert_error {
                    Some(slot) if options.use_deferred_validation => {
                        use crate::ssl::cert::DeferredClientCertVerifier;
                        Some(Arc::new(DeferredClientCertVerifier::new(webpki, slot)))
                    }
                    _ => Some(webpki),
                }
            }
            _ => None,
        };

    let builder = ServerConfig::builder_with_provider(provider)
        .with_protocol_versions(options.protocol_settings.versions)
        .map_err(|e| format!("Failed to create server config builder: {e}"))?;
    let builder = match client_cert_verifier {
        Some(verifier) => builder.with_client_cert_verifier(verifier),
        None => builder.with_no_client_auth(),
    };
    let mut config = match options.cert_resolver {
        Some(resolver) => builder.with_cert_resolver(resolver),
        None => builder
            .with_single_cert(options.cert_chain, options.private_key)
            .map_err(|e| format!("Failed to set server certificate: {e}"))?,
    };

    apply_alpn_with_fallback(
        &mut config.alpn_protocols,
        &options.protocol_settings.alpn_protocols,
    );
    if let Some(storage) = options.session_storage {
        config.session_storage = storage;
    }
    if let Some(ticketer) = options.ticketer {
        config.ticketer = ticketer;
    }
    Ok(config)
}
/// Builds a WebPKI server-certificate verifier over `root_store`, wiring in CRLs.
///
/// CRLs are attached when any are supplied or when `X509_V_FLAG_CRL_CHECK` is set.
/// With CRL_CHECK alone, only the end-entity certificate's revocation status is checked,
/// matching OpenSSL's `VERIFY_CRL_CHECK_LEAF`. If `X509_V_FLAG_CRL_CHECK_ALL` is also set
/// (CPython's `VERIFY_CRL_CHECK_CHAIN`), the whole chain is checked. Previously the
/// CRL_CHECK_ALL bit was ignored, so CHECK_CHAIN silently degraded to a leaf-only check.
///
/// # Errors
/// Returns a message if rustls rejects the verifier configuration.
fn build_webpki_verifier_with_crls(
    root_store: Arc<RootCertStore>,
    crls: Vec<CertificateRevocationListDer<'static>>,
    verify_flags: i32,
) -> Result<Arc<dyn rustls::client::danger::ServerCertVerifier>, String> {
    use rustls::client::WebPkiServerVerifier;
    // OpenSSL `X509_V_FLAG_CRL_CHECK_ALL`: extend CRL checking to the entire chain.
    const X509_V_FLAG_CRL_CHECK_ALL: i32 = 0x8;

    let crl_check_requested = verify_flags & X509_V_FLAG_CRL_CHECK != 0;
    let check_whole_chain = verify_flags & X509_V_FLAG_CRL_CHECK_ALL != 0;

    let mut verifier_builder = WebPkiServerVerifier::builder(root_store);
    if !crls.is_empty() || crl_check_requested {
        verifier_builder = verifier_builder.with_crls(crls);
        if crl_check_requested && !check_whole_chain {
            verifier_builder = verifier_builder.only_check_end_entity_revocation();
        }
    }
    let webpki_verifier = verifier_builder
        .build()
        .map_err(|e| format!("Failed to build WebPkiServerVerifier: {e}"))?;
    Ok(webpki_verifier as Arc<dyn rustls::client::danger::ServerCertVerifier>)
}
/// Layers the optional verification wrappers around `verifier`, innermost first.
///
/// The layers are:
/// 1. `CRLCheckVerifier` when `X509_V_FLAG_CRL_CHECK` is set.
/// 2. `PartialChainVerifier` when any CA certificates were supplied.
/// 3. `StrictCertVerifier` when `VERIFY_X509_STRICT` is set.
fn apply_verifier_wrappers(
    verifier: Arc<dyn rustls::client::danger::ServerCertVerifier>,
    verify_flags: i32,
    has_crls: bool,
    ca_certs_der: Vec<Vec<u8>>,
) -> Arc<dyn rustls::client::danger::ServerCertVerifier> {
    use crate::ssl::cert::{CRLCheckVerifier, PartialChainVerifier};

    let mut wrapped = verifier;
    let crl_check_requested = verify_flags & X509_V_FLAG_CRL_CHECK != 0;
    if crl_check_requested {
        wrapped = Arc::new(CRLCheckVerifier::new(
            wrapped,
            has_crls,
            crl_check_requested,
        ));
    }
    if !ca_certs_der.is_empty() {
        wrapped = Arc::new(PartialChainVerifier::new(
            wrapped,
            ca_certs_der,
            verify_flags,
        ));
    }
    if verify_flags & VERIFY_X509_STRICT != 0 {
        wrapped = Arc::new(super::cert::StrictCertVerifier::new(wrapped, verify_flags));
    }
    wrapped
}
/// Replaces `config_alpn` with `alpn_protocols` followed by a one-byte `\0` entry.
///
/// An empty `alpn_protocols` leaves `config_alpn` untouched.
/// NOTE(review): the purpose of the trailing `\0` entry is not visible here; the function
/// name suggests a fallback protocol ID. Confirm before changing.
fn apply_alpn_with_fallback(config_alpn: &mut Vec<Vec<u8>>, alpn_protocols: &[Vec<u8>]) {
    if alpn_protocols.is_empty() {
        return;
    }
    config_alpn.clear();
    config_alpn.extend(alpn_protocols.iter().cloned());
    config_alpn.push(vec![0]);
}
pub(super) fn create_client_config(options: ClientConfigOptions) -> Result<ClientConfig, String> {
ensure_default_provider();
let custom_provider = create_custom_crypto_provider(
options.protocol_settings.cipher_suites.clone(),
options.protocol_settings.kx_groups.clone(),
);
let verifier: Arc<dyn rustls::client::danger::ServerCertVerifier> = if options
.verify_server_cert
{
let root_store = options
.root_store
.ok_or("Root store required for server verification")?;
let root_store_arc = Arc::new(root_store);
if root_store_arc.is_empty() {
use crate::ssl::cert::EmptyRootStoreVerifier;
Arc::new(EmptyRootStoreVerifier)
} else {
let has_crls = !options.crls.is_empty();
if options.check_hostname {
let base_verifier = build_webpki_verifier_with_crls(
root_store_arc.clone(),
options.crls,
options.verify_flags,
)?;
apply_verifier_wrappers(
base_verifier,
options.verify_flags,
has_crls,
options.ca_certs_der.clone(),
)
} else {
use crate::ssl::cert::HostnameIgnoringVerifier;
let webpki_verifier = build_webpki_verifier_with_crls(
root_store_arc.clone(),
options.crls,
options.verify_flags,
)?;
let crl_check_requested = options.verify_flags & X509_V_FLAG_CRL_CHECK != 0;
let verifier = if crl_check_requested {
use crate::ssl::cert::CRLCheckVerifier;
Arc::new(CRLCheckVerifier::new(
webpki_verifier,
has_crls,
crl_check_requested,
)) as Arc<dyn rustls::client::danger::ServerCertVerifier>
} else {
webpki_verifier
};
const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000;
let verifier = if options.verify_flags & VERIFY_X509_PARTIAL_CHAIN != 0 {
use crate::ssl::cert::PartialChainVerifier;
Arc::new(PartialChainVerifier::new(
verifier,
options.ca_certs_der.clone(),
options.verify_flags,
)) as Arc<dyn rustls::client::danger::ServerCertVerifier>
} else {
verifier
};
let hostname_ignoring_verifier: Arc<
dyn rustls::client::danger::ServerCertVerifier,
> = Arc::new(HostnameIgnoringVerifier::new_with_verifier(verifier));
if options.verify_flags & VERIFY_X509_STRICT != 0 {
Arc::new(crate::ssl::cert::StrictCertVerifier::new(
hostname_ignoring_verifier,
options.verify_flags,
))
} else {
hostname_ignoring_verifier
}
}
}
} else {
use crate::ssl::cert::NoVerifier;
Arc::new(NoVerifier)
};
let builder = ClientConfig::builder_with_provider(custom_provider.clone())
.with_protocol_versions(options.protocol_settings.versions)
.map_err(|e| format!("Failed to create client config builder: {e}"))?
.dangerous()
.with_custom_certificate_verifier(verifier);
let mut config =
if let (Some(cert_chain), Some(private_key)) = (options.cert_chain, options.private_key) {
builder
.with_client_auth_cert(cert_chain, private_key)
.map_err(|e| format!("Failed to set client certificate: {e}"))?
} else {
builder.with_no_client_auth()
};
apply_alpn_with_fallback(
&mut config.alpn_protocols,
&options.protocol_settings.alpn_protocols,
);
if let Some(session_store) = options.session_store {
use rustls::client::Resumption;
config.resumption = Resumption::store(session_store);
}
Ok(config)
}
/// Returns `true` when `err` is a Python `BlockingIOError` (a non-blocking socket would block).
pub(super) fn is_blocking_io_error(err: &Py<PyBaseException>, vm: &VirtualMachine) -> bool {
    let blocking_io_error = vm.ctx.exceptions.blocking_io_error;
    err.fast_isinstance(blocking_io_error)
}
fn send_all_bytes(
socket: &PySSLSocket,
buf: Vec<u8>,
vm: &VirtualMachine,
deadline: Option<std::time::Instant>,
) -> SslResult<()> {
socket
.flush_pending_tls_output(vm, deadline)
.map_err(SslError::Py)?;
if buf.is_empty() {
return Ok(());
}
let mut sent_total = 0;
while sent_total < buf.len() {
if let Some(dl) = deadline
&& std::time::Instant::now() >= dl
{
socket
.pending_tls_output
.lock()
.extend_from_slice(&buf[sent_total..]);
return Err(SslError::Timeout("The operation timed out".to_string()));
}
let timed_out = if let Some(dl) = deadline {
let now = std::time::Instant::now();
if now >= dl {
socket
.pending_tls_output
.lock()
.extend_from_slice(&buf[sent_total..]);
return Err(SslError::Timeout(
"The write operation timed out".to_string(),
));
}
socket
.sock_wait_for_io_with_timeout(SelectKind::Write, Some(dl - now), vm)
.map_err(SslError::Py)?
} else {
socket
.sock_wait_for_io_impl(SelectKind::Write, vm)
.map_err(SslError::Py)?
};
if timed_out {
socket
.pending_tls_output
.lock()
.extend_from_slice(&buf[sent_total..]);
return Err(SslError::Timeout(
"The write operation timed out".to_string(),
));
}
match socket.sock_send(&buf[sent_total..], vm) {
Ok(result) => {
let sent: usize = result
.try_to_value::<isize>(vm)
.map_err(SslError::Py)?
.try_into()
.map_err(|_| SslError::Syscall("Invalid send return value".to_string()))?;
if sent == 0 {
socket
.pending_tls_output
.lock()
.extend_from_slice(&buf[sent_total..]);
return Err(SslError::WantWrite);
}
sent_total += sent;
}
Err(e) => {
if is_blocking_io_error(&e, vm) {
socket
.pending_tls_output
.lock()
.extend_from_slice(&buf[sent_total..]);
return Err(SslError::WantWrite);
}
socket
.pending_tls_output
.lock()
.extend_from_slice(&buf[sent_total..]);
return Err(SslError::Py(e));
}
}
}
Ok(())
}
/// Drains every TLS record rustls has queued and sends it to the peer.
///
/// Any previously buffered output is flushed first. Returns `true` if at least one batch of
/// records was sent.
///
/// NOTE(review): `force_initial_write` never causes an extra pass. The loop always stops as
/// soon as `conn.wants_write()` is false, whatever the flag's value. It is kept only for
/// interface compatibility; confirm whether a forced write was intended.
fn handshake_write_loop(
    conn: &mut TlsConnection,
    socket: &PySSLSocket,
    force_initial_write: bool,
    vm: &VirtualMachine,
) -> SslResult<bool> {
    socket
        .flush_pending_tls_output(vm, None)
        .map_err(SslError::Py)?;
    let _ = force_initial_write;
    let mut sent_any = false;
    while conn.wants_write() {
        let mut records = Vec::new();
        let written = conn
            .write_tls(&mut records as &mut dyn std::io::Write)
            .map_err(SslError::Io)?;
        if written == 0 {
            break;
        }
        if !records.is_empty() {
            send_all_bytes(socket, records, vm, None)?;
            sent_any = true;
        }
    }
    Ok(sent_any)
}
const TLS_RECORD_HEADER_SIZE: usize = 5;
fn recv_one_tls_record(socket: &PySSLSocket, vm: &VirtualMachine) -> SslResult<PyObjectRef> {
let peeked_obj = match socket.sock_peek(SSL3_RT_MAX_PLAIN_LENGTH, vm) {
Ok(d) => d,
Err(e) => {
if is_blocking_io_error(&e, vm) {
return Err(SslError::WantRead);
}
return Err(SslError::Py(e));
}
};
let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)
.map_err(|_| SslError::Syscall("Expected bytes-like object from peek".to_string()))?;
let peeked_bytes = peeked.borrow_buf();
if peeked_bytes.is_empty() {
return Err(SslError::Eof);
}
if peeked_bytes.len() < TLS_RECORD_HEADER_SIZE {
return socket.sock_recv(peeked_bytes.len(), vm).map_err(|e| {
if is_blocking_io_error(&e, vm) {
SslError::WantRead
} else {
SslError::Py(e)
}
});
}
let record_body_len = u16::from_be_bytes([peeked_bytes[3], peeked_bytes[4]]) as usize;
let total_record_size = TLS_RECORD_HEADER_SIZE + record_body_len;
let recv_size = if peeked_bytes.len() >= total_record_size {
total_record_size
} else {
peeked_bytes.len()
};
drop(peeked_bytes);
drop(peeked);
socket.sock_recv(recv_size, vm).map_err(|e| {
if is_blocking_io_error(&e, vm) {
SslError::WantRead
} else {
SslError::Py(e)
}
})
}
/// Like `recv_one_tls_record`, but tuned for the application-data read path.
///
/// On `Eof` or a Python socket error, rustls first processes whatever it has already
/// buffered, so a pending TLS-level error (e.g. a received alert) takes precedence over the
/// transport error. After that:
/// * `Eof` becomes an empty bytes object.
/// * A connection-closed error becomes `Eof`.
/// * Any other Python error is passed through.
fn recv_one_tls_record_for_data(
    conn: &mut TlsConnection,
    socket: &PySSLSocket,
    vm: &VirtualMachine,
) -> SslResult<PyObjectRef> {
    let transport_err = match recv_one_tls_record(socket, vm) {
        Ok(data) => return Ok(data),
        Err(err @ (SslError::Eof | SslError::Py(_))) => err,
        Err(other) => return Err(other),
    };
    if let Err(rustls_err) = conn.process_new_packets() {
        return Err(SslError::from_rustls(rustls_err));
    }
    match transport_err {
        SslError::Py(e) if is_connection_closed_error(&e, vm) => Err(SslError::Eof),
        SslError::Py(e) => Err(SslError::Py(e)),
        _ => Ok(vm.ctx.new_bytes(vec![]).into()),
    }
}
/// Reads one batch of handshake bytes from the transport and feeds it to rustls.
///
/// Returns `(made_progress, is_first_sni_read)`:
/// * `(false, false)` when rustls does not want input, or when a BIO-mode read fails with a
///   timeout error while rustls has nothing to write.
/// * Otherwise `(true, is_first_sni_read)`. The second flag is set for a server whose
///   `socket.is_first_sni_read()` holds. That read's raw bytes are also passed to
///   `socket.save_client_hello_from_bytes`.
///
/// # Errors
/// * `Timeout("timed out")` if waiting for readability times out (socket mode).
/// * `WantRead` on `BlockingIOError` (BIO mode).
/// * `Py` for other transport errors.
/// * Any error from `recv_one_tls_record` / `ssl_read_tls_records`.
fn handshake_read_data(
    conn: &mut TlsConnection,
    socket: &PySSLSocket,
    is_bio: bool,
    is_server: bool,
    vm: &VirtualMachine,
) -> SslResult<(bool, bool)> {
    if !conn.wants_read() {
        return Ok((false, false));
    }
    // Evaluated before reading, so it reflects the state prior to this batch.
    let is_first_sni_read = is_server && socket.is_first_sni_read();
    if !is_bio {
        let timed_out = socket
            .sock_wait_for_io_impl(SelectKind::Read, vm)
            .map_err(SslError::Py)?;
        if timed_out {
            return Err(SslError::Timeout("timed out".to_string()));
        }
    }
    let data_obj = if !is_bio {
        // Socket mode: consume exactly one record so later records stay in the socket.
        recv_one_tls_record(socket, vm)?
    } else {
        match socket.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) {
            Ok(d) => d,
            Err(e) => {
                if is_blocking_io_error(&e, vm) {
                    return Err(SslError::WantRead);
                }
                // A timeout with nothing to write is treated as "no progress", not an error.
                if !conn.wants_write() && e.fast_isinstance(vm.ctx.exceptions.timeout_error) {
                    return Ok((false, false));
                }
                return Err(SslError::Py(e));
            }
        }
    };
    if is_first_sni_read {
        use rustpython_vm::builtins::PyBytes;
        if let Some(bytes_obj) = data_obj.downcast_ref::<PyBytes>() {
            socket.save_client_hello_from_bytes(bytes_obj.as_bytes());
        }
    }
    ssl_read_tls_records(conn, data_obj, is_bio, vm)?;
    Ok((true, is_first_sni_read))
}
/// Finishes the handshake once rustls reports it complete.
///
/// Returns `Ok(false)` while the handshake is still running. Otherwise it behaves as follows
/// and then returns `Ok(true)`:
/// * BIO mode: one batch of pending records is sent, with errors propagated.
/// * Socket mode: every pending record is sent. A `WantWrite` is resolved with a blocking
///   flush, and the remaining output is flushed at the end.
fn handle_handshake_complete(
    conn: &mut TlsConnection,
    socket: &PySSLSocket,
    _is_server: bool,
    vm: &VirtualMachine,
) -> SslResult<bool> {
    if conn.is_handshaking() {
        return Ok(false);
    }

    if socket.is_bio_mode() {
        if conn.wants_write() {
            let pending = ssl_write_tls_records(conn)?;
            if !pending.is_empty() {
                send_all_bytes(socket, pending, vm, None)?;
            }
        }
        return Ok(true);
    }

    while conn.wants_write() {
        let pending = ssl_write_tls_records(conn)?;
        if pending.is_empty() {
            break;
        }
        match send_all_bytes(socket, pending, vm, None) {
            Ok(()) => {}
            Err(SslError::WantWrite) => {
                socket
                    .blocking_flush_all_pending(vm)
                    .map_err(SslError::Py)?;
            }
            Err(other) => return Err(other),
        }
    }
    socket
        .blocking_flush_all_pending(vm)
        .map_err(SslError::Py)?;
    Ok(true)
}
/// Copies already-decrypted application data from rustls into `buf`.
///
/// Returns `Some(n)` with the byte count; `Some(0)` is what rustls' reader reports for a
/// closed stream. Returns `None` when no plaintext is available yet (`WouldBlock`).
///
/// # Errors
/// Any other I/O error from the reader, wrapped as `SslError::Io`.
fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult<Option<usize>> {
    match conn.reader().read(buf) {
        Ok(n) => Ok(Some(n)),
        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
        Err(e) => Err(SslError::Io(e)),
    }
}
/// Drives the rustls handshake until it completes or has to yield to the caller.
///
/// Each iteration does three things:
/// 1. Flushes queued handshake records (`handshake_write_loop`).
/// 2. Reads one batch of peer data (`handshake_read_data`).
/// 3. Lets rustls process that data.
///
/// In BIO mode the function returns `WantWrite`/`WantRead` as soon as it needs the caller to
/// move data. In socket mode it keeps looping while rustls wants I/O or progress was made,
/// giving up after roughly 1000 iterations.
///
/// # Errors
/// * `PreauthData` when rustls reports `InvalidMessage`. Other rustls errors go through
///   `SslError::from_rustls`. In socket mode, a close_notify is first queued and pending
///   records are sent on a best-effort basis.
/// * `SniCallbackRestart` after a server's first read when an SNI callback is registered.
/// * `WantRead`/`WantWrite` while the handshake is still pending.
/// * `Syscall` if the handshake stalled with no pending I/O.
pub(super) fn ssl_do_handshake(
    conn: &mut TlsConnection,
    socket: &PySSLSocket,
    vm: &VirtualMachine,
) -> SslResult<()> {
    if !conn.is_handshaking() {
        return Ok(());
    }
    let is_bio = socket.is_bio_mode();
    let is_server = matches!(conn, TlsConnection::Server(_));
    let mut first_iteration = true;
    let mut iteration_count = 0;
    loop {
        iteration_count += 1;
        let mut made_progress = false;
        let force_initial_write = is_bio && first_iteration;
        let write_progress = handshake_write_loop(conn, socket, force_initial_write, vm)?;
        made_progress |= write_progress;
        let (read_progress, is_first_sni_read) =
            handshake_read_data(conn, socket, is_bio, is_server, vm)?;
        made_progress |= read_progress;
        if let Err(e) = conn.process_new_packets() {
            if !is_bio {
                // Best effort: queue close_notify and push any pending records to the peer
                // before reporting the failure; errors while doing so are ignored.
                conn.send_close_notify();
                let _ = socket.flush_pending_tls_output(vm, None);
                if let Ok(alert_data) = ssl_write_tls_records(conn)
                    && !alert_data.is_empty()
                {
                    let _ = send_all_bytes(socket, alert_data, vm, None);
                }
            }
            // Non-TLS bytes during the handshake are reported as pre-handshake data.
            if matches!(e, rustls::Error::InvalidMessage(_)) {
                return Err(SslError::PreauthData);
            }
            return Err(SslError::from_rustls(e));
        }
        // The caller must run the SNI callback and then restart the handshake.
        if is_server && is_first_sni_read && socket.has_sni_callback() {
            return Err(SslError::SniCallbackRestart);
        }
        if handle_handshake_complete(conn, socket, is_server, vm)? {
            return Ok(());
        }
        if is_bio {
            // BIO mode: hand every queued record to the outgoing BIO, then yield.
            if conn.wants_write() {
                loop {
                    let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH];
                    let n = match conn.write_tls(&mut buf.as_mut_slice()) {
                        Ok(n) => n,
                        Err(_) => break,
                    };
                    if n == 0 {
                        break;
                    }
                    send_all_bytes(socket, buf[..n].to_vec(), vm, None)?;
                    if !conn.wants_write() {
                        break;
                    }
                }
                if conn.wants_write() {
                    return Err(SslError::WantWrite);
                }
            }
            if conn.wants_read() {
                return Err(SslError::WantRead);
            }
            break;
        }
        first_iteration = false;
        let should_continue = conn.wants_read() || conn.wants_write() || made_progress;
        if !should_continue {
            break;
        }
        // Safety valve against a handshake that never converges.
        if iteration_count > 1000 {
            break;
        }
    }
    if conn.is_handshaking() {
        if conn.wants_write() {
            return Err(SslError::WantWrite);
        }
        if conn.wants_read() {
            return Err(SslError::WantRead);
        }
        Err(SslError::Syscall(format!(
            "SSL handshake failed: incomplete after {iteration_count} iterations",
        )))
    } else {
        Ok(())
    }
}
/// Reads decrypted application data into `buf`, pulling TLS records from the transport as
/// needed.
///
/// Returns the number of plaintext bytes copied (at least 1). In socket mode with a non-zero
/// timeout, an absolute deadline is derived from that timeout and checked on every loop
/// iteration.
///
/// # Errors
/// * `ZeroReturn` when the peer closed the TLS session cleanly.
/// * `Eof` on an unexpected end of stream (or an EOF'd incoming BIO).
/// * `WantRead` when no data is available without blocking (BIO mode or a zero timeout).
/// * `Timeout("The read operation timed out")` once the deadline passes.
/// * Errors from the transport or rustls.
pub(super) fn ssl_read(
    conn: &mut TlsConnection,
    buf: &mut [u8],
    socket: &PySSLSocket,
    vm: &VirtualMachine,
) -> SslResult<usize> {
    let is_bio = socket.is_bio_mode();
    // Only a socket with a non-zero timeout gets a deadline; a zero or absent timeout, and BIO
    // mode, get none.
    let deadline = if !is_bio {
        match socket.get_socket_timeout(vm).map_err(SslError::Py)? {
            Some(timeout) if !timeout.is_zero() => Some(std::time::Instant::now() + timeout),
            _ => None,
        }
    } else {
        None
    };
    if !is_bio {
        socket
            .flush_pending_tls_output(vm, deadline)
            .map_err(SslError::Py)?;
    }
    loop {
        if let Some(deadline) = deadline
            && std::time::Instant::now() >= deadline
        {
            return Err(SslError::Timeout(
                "The read operation timed out".to_string(),
            ));
        }
        // Sampled before the plaintext read; selects which refill strategy is used below.
        let needs_more_tls = conn.wants_read();
        if let Some(n) = try_read_plaintext(conn, buf)? {
            if n == 0 {
                return Err(SslError::ZeroReturn);
            }
            return Ok(n);
        }
        if !needs_more_tls {
            // rustls has outgoing records (e.g. post-handshake messages): send them first.
            // WantWrite and Timeout while sending are tolerated here; the loop re-checks the
            // deadline.
            if conn.wants_write() && !is_bio {
                if let Some(deadline) = deadline
                    && std::time::Instant::now() >= deadline
                {
                    return Err(SslError::Timeout(
                        "The read operation timed out".to_string(),
                    ));
                }
                let tls_data = ssl_write_tls_records(conn)?;
                if !tls_data.is_empty() {
                    match send_all_bytes(socket, tls_data, vm, deadline) {
                        Ok(()) => {}
                        Err(SslError::WantWrite) => {
                        }
                        Err(SslError::Timeout(_)) => {
                        }
                        Err(e) => return Err(e),
                    }
                }
                if let Some(deadline) = deadline
                    && std::time::Instant::now() >= deadline
                {
                    return Err(SslError::Timeout(
                        "The read operation timed out".to_string(),
                    ));
                }
                continue;
            }
            // BIO mode: an incoming BIO marked EOF means no more data will ever arrive.
            if is_bio && let Some(bio_obj) = socket.incoming_bio() {
                let is_eof = bio_obj
                    .get_attr("eof", vm)
                    .and_then(|v| v.try_into_value::<bool>(vm))
                    .unwrap_or(false);
                if is_eof {
                    return Err(SslError::Eof);
                }
            }
            if !is_bio {
                let timeout = socket.get_socket_timeout(vm).map_err(SslError::Py)?;
                // Non-blocking socket: don't wait; report clean close or WantRead.
                if let Some(t) = timeout
                    && t.is_zero()
                {
                    let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?;
                    if io_state.peer_has_closed() {
                        return Err(SslError::ZeroReturn);
                    }
                    return Err(SslError::WantRead);
                }
                let data = recv_one_tls_record_for_data(conn, socket, vm)?;
                // NOTE(review): non-`PyBytes` results are counted as 0 bytes (treated as EOF).
                let bytes_read = data
                    .clone()
                    .try_into_value::<rustpython_vm::builtins::PyBytes>(vm)
                    .map(|b| b.as_bytes().len())
                    .unwrap_or(0);
                if bytes_read == 0 {
                    let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?;
                    if io_state.peer_has_closed() {
                        return Err(SslError::ZeroReturn);
                    }
                    return Err(SslError::Eof);
                }
                ssl_read_tls_records(conn, data, false, vm)?;
                conn.process_new_packets().map_err(SslError::from_rustls)?;
                continue;
            }
            return Err(SslError::WantRead);
        }
        match ssl_ensure_data_available(conn, socket, vm) {
            Ok(_bytes_read) => {
            }
            // rustls' receive buffer is full: loop back and drain plaintext first.
            // NOTE(review): detected by message text; could spin if no plaintext is drained.
            Err(SslError::Io(ref io_err)) if io_err.to_string().contains("message buffer full") => {
                continue;
            }
            // On failure, still return any plaintext that became available before the error.
            Err(e) => {
                match try_read_plaintext(conn, buf)? {
                    Some(n) if n > 0 => {
                        return Ok(n);
                    }
                    _ => {
                        return Err(e);
                    }
                }
            }
        }
    }
}
pub(super) fn ssl_write(
conn: &mut TlsConnection,
data: &[u8],
socket: &PySSLSocket,
vm: &VirtualMachine,
) -> SslResult<usize> {
if data.is_empty() {
return Ok(0);
}
let is_bio = socket.is_bio_mode();
let deadline = if !is_bio {
match socket.get_socket_timeout(vm).map_err(SslError::Py)? {
Some(timeout) if !timeout.is_zero() => Some(std::time::Instant::now() + timeout),
_ => None,
}
} else {
None
};
if !is_bio {
socket
.flush_pending_tls_output(vm, deadline)
.map_err(SslError::Py)?;
}
let already_buffered = *socket.write_buffered_len.lock();
let mut bytes_written_to_rustls = 0usize;
if already_buffered == 0 {
bytes_written_to_rustls = {
let mut writer = conn.writer();
use std::io::Write;
match writer.write(data) {
Ok(0) if !data.is_empty() => {
if is_bio {
return Err(SslError::WantWrite);
}
return Err(SslError::Syscall("Write failed: buffer full".to_string()));
}
Ok(n) => n,
Err(e) => {
if is_bio {
return Err(SslError::WantWrite);
}
return Err(SslError::Syscall(format!("Write failed: {e}")));
}
}
};
*socket.write_buffered_len.lock() = bytes_written_to_rustls;
} else if already_buffered != data.len() {
*socket.write_buffered_len.lock() = 0;
return Err(SslError::Ssl("bad write retry".to_string()));
}
loop {
if let Some(dl) = deadline
&& std::time::Instant::now() >= dl
{
return Err(SslError::Timeout(
"The write operation timed out".to_string(),
));
}
if !conn.wants_write() {
break;
}
let tls_data = ssl_write_tls_records(conn)?;
if tls_data.is_empty() {
break;
}
match send_all_bytes(socket, tls_data, vm, deadline) {
Ok(()) => {
}
Err(SslError::WantWrite) => {
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
*socket.write_buffered_len.lock() = 0;
return Ok(bytes_written_to_rustls);
}
return Err(SslError::WantWrite);
}
Err(SslError::WantRead) => {
if is_bio {
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
*socket.write_buffered_len.lock() = 0;
return Ok(bytes_written_to_rustls);
}
return Err(SslError::WantRead);
}
let recv_result = socket.sock_recv(4096, vm).map_err(SslError::Py)?;
ssl_read_tls_records(conn, recv_result, false, vm)?;
conn.process_new_packets().map_err(SslError::from_rustls)?;
}
Err(e @ SslError::Timeout(_)) => {
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
*socket.write_buffered_len.lock() = 0;
return Ok(bytes_written_to_rustls);
}
return Err(e);
}
Err(e) => {
*socket.write_buffered_len.lock() = 0;
return Err(e);
}
}
}
if !is_bio {
socket
.flush_pending_tls_output(vm, deadline)
.map_err(SslError::Py)?;
}
let actual_written = if bytes_written_to_rustls > 0 {
bytes_written_to_rustls
} else if already_buffered > 0 {
already_buffered
} else {
data.len()
};
*socket.write_buffered_len.lock() = 0;
Ok(actual_written)
}
/// Moves TLS records queued in rustls into a byte vector ready to send.
///
/// Returns an empty vector when rustls produced no output.
fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult<Vec<u8>> {
    let mut outgoing: Vec<u8> = Vec::new();
    match conn.write_tls(&mut outgoing as &mut dyn std::io::Write) {
        Ok(0) => Ok(Vec::new()),
        Ok(_) => Ok(outgoing),
        Err(io_err) => Err(SslError::Io(io_err)),
    }
}
/// Feeds the raw TLS bytes in `data` (any bytes-like object) into `conn`.
///
/// An empty `data` means no more input is available:
/// - in BIO mode the result is `WantRead`;
/// - otherwise pending records are processed, and the result is `ZeroReturn`
///   if rustls reports the peer has closed, or `Eof` if not.
///
/// rustls may refuse input with a "buffer full" error. The buffered records
/// are then processed to make room and the read is retried. If input is
/// refused again with no bytes consumed in between, processing could not free
/// any space. One example is a full decrypted-plaintext buffer, which only the
/// application reader can drain. In that case the I/O error is returned
/// instead of looping forever.
///
/// Input left over after rustls stops accepting bytes (a zero-length read even
/// after processing) is discarded.
///
/// # Errors
/// - `Syscall` if `data` is not bytes-like.
/// - `Io` for `read_tls` failures.
/// - rustls errors from `process_new_packets`.
/// - The empty-input results described above.
fn ssl_read_tls_records(
    conn: &mut TlsConnection,
    data: PyObjectRef,
    is_bio: bool,
    vm: &VirtualMachine,
) -> SslResult<()> {
    let bytes = ArgBytesLike::try_from_object(vm, data)
        .map_err(|_| SslError::Syscall("Expected bytes-like object".to_string()))?;
    let bytes_data = bytes.borrow_buf();
    if bytes_data.is_empty() {
        if is_bio {
            return Err(SslError::WantRead);
        }
        let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?;
        return Err(if io_state.peer_has_closed() {
            SslError::ZeroReturn
        } else {
            SslError::Eof
        });
    }

    let mut offset = 0;
    // Set after a "buffer full" error has triggered `process_new_packets()`.
    // Cleared whenever rustls consumes input again.
    let mut awaiting_progress = false;
    while offset < bytes_data.len() {
        let remaining = &bytes_data[offset..];
        match conn.read_tls(&mut std::io::Cursor::new(remaining)) {
            Ok(0) => {
                conn.process_new_packets().map_err(SslError::from_rustls)?;
                match conn.read_tls(&mut std::io::Cursor::new(remaining)) {
                    Ok(0) => break,
                    Ok(n) => {
                        offset += n;
                        awaiting_progress = false;
                    }
                    Err(e) => return Err(SslError::Io(e)),
                }
            }
            Ok(n) => {
                offset += n;
                awaiting_progress = false;
            }
            Err(e) if !awaiting_progress && e.to_string().contains("buffer full") => {
                conn.process_new_packets().map_err(SslError::from_rustls)?;
                awaiting_progress = true;
            }
            // Includes a repeated "buffer full" with no progress: retrying would spin forever.
            Err(e) => return Err(SslError::Io(e)),
        }
    }
    Ok(())
}
/// Reports whether `exc` means the connection was aborted or reset.
///
/// That is the case for a `ConnectionAbortedError` or `ConnectionResetError`
/// instance, or for an `OSError` whose `errno` is `ECONNABORTED` or
/// `ECONNRESET`.
fn is_connection_closed_error(exc: &Py<PyBaseException>, vm: &VirtualMachine) -> bool {
    use rustpython_vm::stdlib::errno::errors;
    let exceptions = &vm.ctx.exceptions;
    if exc.fast_isinstance(exceptions.connection_aborted_error)
        || exc.fast_isinstance(exceptions.connection_reset_error)
    {
        return true;
    }
    if !exc.fast_isinstance(exceptions.os_error) {
        return false;
    }
    // A missing or non-integer `errno` simply means "not a closed connection".
    exc.as_object()
        .get_attr("errno", vm)
        .ok()
        .and_then(|errno| errno.try_int(vm).ok())
        .and_then(|code| code.try_to_primitive::<i32>(vm).ok())
        .is_some_and(|code| code == errors::ECONNABORTED || code == errors::ECONNRESET)
}
/// Feeds more ciphertext into `conn`, but only when rustls wants to read.
///
/// Returns `Ok(0)` immediately when `conn.wants_read()` is false. Otherwise one
/// chunk of input is fetched, passed to `ssl_read_tls_records`, and processed.
/// The return value is the chunk's length when it is a `bytes` object, and 0
/// for any other type.
///
/// Socket mode:
/// - If the socket has a non-zero timeout, first waits for readability, and
///   fails with `Timeout("The read operation timed out")` if the wait times out.
/// - Then reads via `recv_one_tls_record_for_data`.
///
/// BIO mode reads up to 2048 bytes via `sock_recv`:
/// - A blocking-I/O error becomes `WantRead`.
/// - Any other error first gives rustls a chance to report its own error.
///   Then a connection abort/reset maps to `Eof`, and everything else passes
///   through as `Py`.
/// - An empty read from a BIO whose `eof` attribute is true yields `Eof`.
fn ssl_ensure_data_available(
    conn: &mut TlsConnection,
    socket: &PySSLSocket,
    vm: &VirtualMachine,
) -> SslResult<usize> {
    if !conn.wants_read() {
        return Ok(0);
    }
    let is_bio = socket.is_bio_mode();
    let incoming = if is_bio {
        match socket.sock_recv(2048, vm) {
            Ok(chunk) => chunk,
            Err(err) if is_blocking_io_error(&err, vm) => return Err(SslError::WantRead),
            Err(err) => {
                // A pending rustls failure explains the problem better than the transport error.
                conn.process_new_packets().map_err(SslError::from_rustls)?;
                return Err(if is_connection_closed_error(&err, vm) {
                    SslError::Eof
                } else {
                    SslError::Py(err)
                });
            }
        }
    } else {
        let timeout = socket.get_socket_timeout(vm).map_err(SslError::Py)?;
        let must_wait = timeout.is_some_and(|t| !t.is_zero());
        if must_wait
            && socket
                .sock_wait_for_io_impl(SelectKind::Read, vm)
                .map_err(SslError::Py)?
        {
            return Err(SslError::Timeout(
                "The read operation timed out".to_string(),
            ));
        }
        recv_one_tls_record_for_data(conn, socket, vm)?
    };

    let bytes_read = incoming
        .clone()
        .try_into_value::<rustpython_vm::builtins::PyBytes>(vm)
        .map_or(0, |b| b.as_bytes().len());
    let bio_at_eof = is_bio
        && socket.incoming_bio().is_some_and(|bio| {
            bio.get_attr("eof", vm)
                .and_then(|v| v.try_into_value::<bool>(vm))
                .unwrap_or(false)
        });
    if bio_at_eof && bytes_read == 0 {
        return Err(SslError::Eof);
    }

    ssl_read_tls_records(conn, incoming, is_bio, vm)?;
    conn.process_new_packets().map_err(SslError::from_rustls)?;
    Ok(bytes_read)
}
/// Server certificate resolver holding several certificate/key pairs.
///
/// It picks one pair for each `ClientHello`.
#[derive(Debug)]
pub(super) struct MultiCertResolver {
    cert_keys: Vec<Arc<CertifiedKey>>,
}

impl MultiCertResolver {
    /// Wraps `cert_keys`; their order is the order in which they are tried.
    pub fn new(cert_keys: Vec<Arc<CertifiedKey>>) -> Self {
        Self { cert_keys }
    }
}

impl ResolvesServerCert for MultiCertResolver {
    /// Returns the first pair whose key can sign with one of the client's
    /// offered signature schemes.
    ///
    /// Otherwise falls back to the first configured pair, or returns `None`
    /// when no pairs are configured.
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let offered = client_hello.signature_schemes();
        self.cert_keys
            .iter()
            .find(|candidate| candidate.key.choose_scheme(offered).is_some())
            .or_else(|| self.cert_keys.first())
            .cloned()
    }
}
/// Rewrites a rustls-style suite name into a dash-separated form.
///
/// The steps are:
/// - drop a leading `TLS_`;
/// - remove the `_WITH_` separator;
/// - turn `_` into `-`;
/// - join the AES key size onto the cipher (`AES-256` -> `AES256`,
///   `AES-128` -> `AES128`).
pub(super) fn normalize_cipher_name(rustls_name: &str) -> String {
    const AES_KEY_SIZES: [(&str, &str); 2] = [("AES-256", "AES256"), ("AES-128", "AES128")];
    let without_prefix = rustls_name.strip_prefix("TLS_").unwrap_or(rustls_name);
    let dashed = without_prefix.replace("_WITH_", "_").replace('_', "-");
    AES_KEY_SIZES
        .iter()
        .fold(dashed, |name, &(from, to)| name.replace(from, to))
}
/// Estimates the symmetric key size, in bits, from a cipher suite name.
///
/// Accepts both `_`-separated names (`TLS_AES_128_GCM_SHA256`) and
/// `-`-separated names (`ECDHE-RSA-AES128-GCM-SHA256`).
///
/// Returns 256 for a `256` key size or CHACHA20, 128 for a `128` key size,
/// and 0 otherwise.
///
/// Digest tokens (`SHA256`, `SHA384`, ...) are ignored. They carry a hash
/// size, not a key size, and would otherwise make every `..._SHA256` suite
/// report 256 bits.
pub(super) fn get_cipher_key_bits(cipher_name: &str) -> i32 {
    let cipher_tokens = || {
        cipher_name
            .split(['_', '-'])
            .filter(|token| !token.starts_with("SHA"))
    };
    if cipher_tokens().any(|t| t.contains("256") || t.contains("CHACHA20")) {
        256
    } else if cipher_tokens().any(|t| t.contains("128")) {
        128
    } else {
        0
    }
}
/// Describes the bulk cipher named in a dash-style suite name.
///
/// Examples are `AESGCM(256)` for `...AES256...` and
/// `CHACHA20-POLY1305(256)` for `...CHACHA20...`. Returns "Unknown" when no
/// known cipher appears.
pub(super) fn get_cipher_encryption_desc(cipher_name: &str) -> &'static str {
    const DESCRIPTIONS: [(&str, &str); 3] = [
        ("AES256", "AESGCM(256)"),
        ("AES128", "AESGCM(128)"),
        ("CHACHA20", "CHACHA20-POLY1305(256)"),
    ];
    DESCRIPTIONS
        .iter()
        .find(|&&(needle, _)| cipher_name.contains(needle))
        .map_or("Unknown", |&(_, desc)| desc)
}
/// Rewrites a leading `TLS13_` (rustls' TLS 1.3 suite naming) to `TLS_`.
///
/// For example, `TLS13_AES_128_GCM_SHA256` becomes `TLS_AES_128_GCM_SHA256`.
/// Any other name is returned unchanged. Only the prefix is rewritten;
/// `TLS13_` appearing later in the name is kept.
pub(super) fn normalize_rustls_cipher_name(rustls_name: &str) -> String {
    match rustls_name.strip_prefix("TLS13_") {
        Some(rest) => format!("TLS_{rest}"),
        None => rustls_name.to_owned(),
    }
}
/// Returns the version label for a rustls protocol version.
///
/// The labels are "TLSv1.3" and "TLSv1.2"; any other version yields "Unknown".
pub(super) fn get_protocol_version_str(version: &rustls::SupportedProtocolVersion) -> &'static str {
    use rustls::ProtocolVersion::{TLSv1_2, TLSv1_3};
    match version.version {
        TLSv1_3 => "TLSv1.3",
        TLSv1_2 => "TLSv1.2",
        _ => "Unknown",
    }
}
/// Summary of a negotiated cipher suite.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct CipherInfo {
    /// Suite name.
    pub name: String,
    /// Protocol version label, e.g. "TLSv1.3".
    pub protocol: &'static str,
    /// Symmetric key size in bits (0 when unknown).
    pub bits: i32,
}
/// Builds the [`CipherInfo`] for a rustls cipher suite.
///
/// The name is the suite's `Debug` rendering passed through
/// `normalize_rustls_cipher_name`. The protocol label comes from the suite's
/// version, and the key size is derived from the normalized name.
pub(super) fn extract_cipher_info(suite: &rustls::SupportedCipherSuite) -> CipherInfo {
    let debug_name = format!("{:?}", suite.suite());
    let name = normalize_rustls_cipher_name(&debug_name);
    CipherInfo {
        bits: get_cipher_key_bits(&name),
        protocol: get_protocol_version_str(suite.version()),
        name,
    }
}
/// Resolves a curve name to the matching key-exchange group of the aws-lc-rs
/// default crypto provider, returned as a one-element list.
///
/// Accepted spellings:
/// - `prime256v1` / `secp256r1`
/// - `secp384r1` / `prime384v1`
/// - `X25519` / `x25519`
/// - `prime521v1` / `secp521r1`
/// - `X448` / `x448`
///
/// # Errors
/// - `"unknown curve name '<curve>'"` for any other name.
/// - `"<group> not supported by crypto provider"` when the provider does not
///   offer the group.
pub(super) fn curve_name_to_kx_group(
    curve: &str,
) -> Result<Vec<&'static dyn SupportedKxGroup>, String> {
    let (wanted, label) = match curve {
        "prime256v1" | "secp256r1" => (rustls::NamedGroup::secp256r1, "secp256r1"),
        "secp384r1" | "prime384v1" => (rustls::NamedGroup::secp384r1, "secp384r1"),
        "X25519" | "x25519" => (rustls::NamedGroup::X25519, "X25519"),
        "prime521v1" | "secp521r1" => (rustls::NamedGroup::secp521r1, "secp521r1"),
        "X448" | "x448" => (rustls::NamedGroup::X448, "X448"),
        _ => return Err(format!("unknown curve name '{curve}'")),
    };
    let provider = rustls::crypto::aws_lc_rs::default_provider();
    provider
        .kx_groups
        .iter()
        .copied()
        .find(|group| group.name() == wanted)
        .map(|group| vec![group])
        .ok_or_else(|| format!("{label} not supported by crypto provider"))
}