use super::dtls::{
encode_dtls12_handshake_fragments, encode_dtls_record_packet, open_dtls13_aes128gcm_record,
parse_dtls_record_packet, reassemble_dtls12_handshake_fragments, seal_dtls13_aes128gcm_record,
DtlsEpochReplayTracker, DtlsFlightRetransmitTracker, DtlsRecordHeader, DtlsReplayWindow,
DtlsReplayWindowSnapshot,
};
use super::handshake::{encode_handshake_message, parse_handshake_message};
use super::kdf::{
finished_hmac_for_hash, hash_bytes_for_algorithm, hkdf_expand_for_hash, hkdf_extract_for_hash,
hkdf_extract_with_salt_for_hash, tls13_expand_label_for_hash, HashAlgorithm,
};
use super::keyshare::{
derive_deterministic_mlkem768_keypair, derive_deterministic_p256_private,
derive_deterministic_x25519_private, derive_tls13_mlkem768_shared_secret,
derive_tls13_p256_shared_secret, derive_tls13_x25519_shared_secret,
tls13_client_hello_offers_supported_key_exchange,
};
use super::psk::{ticket_age_matches_policy, ResumptionTicket, TicketStore, TicketUsagePolicy};
use super::record::{
build_record_nonce, decode_tls12_ciphertext_record, decode_tls13_ciphertext_record,
decode_tls13_inner_plaintext, encode_tls12_ciphertext_record, encode_tls13_ciphertext_record,
encode_tls13_inner_plaintext,
};
use super::state::{
AlertDescription, AlertLevel, CipherSuite, HandshakeState, RecordContentType, TlsVersion,
};
#[cfg(not(feature = "std"))]
use crate::internal_alloc::ToOwned;
use crate::internal_alloc::{String, Vec};
use noxtls_core::{Error, Result};
use noxtls_crypto::{
aes_gcm_decrypt, aes_gcm_encrypt, chacha20_poly1305_decrypt, chacha20_poly1305_encrypt,
ed25519_public_key_from_subject_public_key_info, ed25519_verify, hkdf_extract_sha256,
mldsa_verify, p256_ecdsa_verify_sha256, rsassa_pss_sha256_verify, rsassa_pss_sha384_verify,
sha256, tls12_prf_sha256, tls12_prf_sha384, AesCipher, HmacDrbgSha256, MlDsaPublicKey,
MlKemPrivateKey, P256PrivateKey, P256PublicKey, RsaPublicKey, TlsTranscriptSha256,
TlsTranscriptSha384, X25519PrivateKey, MLKEM_CIPHERTEXT_LEN,
};
use noxtls_x509::{
certificate_matches_hostname, parse_certificate, parse_der_node, parse_ecdsa_signature_der,
validate_certificate_chain,
};
/// Mutable state for a single TLS or DTLS connection endpoint.
///
/// Groups the protocol version and handshake state, the negotiated cipher
/// suite and transcript, TLS 1.3 key-schedule secrets, record keys and
/// sequence numbers, server-authentication settings (trust anchors, hostname,
/// OCSP, SNI, ALPN), TLS 1.3 early-data policy and telemetry, and DTLS replay,
/// retransmission and anti-amplification bookkeeping.
///
/// Create one with [`Connection::new`]; private fields are configured through
/// the setter methods on this type.
#[derive(Debug, Clone)]
pub struct Connection {
/// Protocol version this connection was created for.
pub version: TlsVersion,
/// Current position in the handshake state machine (starts at `Idle`).
pub state: HandshakeState,
// Negotiated suite (`None` until selected) and the handshake transcript, kept
// both as accumulated bytes and as a running hash.
selected_cipher_suite: Option<CipherSuite>,
transcript: Vec<u8>,
transcript_hash: TranscriptHashState,
// TLS 1.3 key-schedule secrets; each is `None` until derived.
handshake_secret: Option<Vec<u8>>,
tls13_master_secret: Option<Vec<u8>>,
tls13_client_handshake_traffic_secret: Option<Vec<u8>>,
tls13_server_handshake_traffic_secret: Option<Vec<u8>>,
tls13_finished_key: Option<Vec<u8>>,
tls13_client_application_traffic_secret: Option<Vec<u8>>,
tls13_server_application_traffic_secret: Option<Vec<u8>>,
tls13_exporter_master_secret: Option<Vec<u8>>,
tls13_resumption_master_secret: Option<Vec<u8>>,
// Client ephemeral key-share private keys and the resulting shared secret.
tls13_client_x25519_private: Option<X25519PrivateKey>,
tls13_client_p256_private: Option<P256PrivateKey>,
tls13_client_mlkem768_private: Option<MlKemPrivateKey>,
tls13_shared_secret: Option<[u8; 32]>,
// HelloRetryRequest bookkeeping: group the server asked for, and whether an HRR was seen.
tls13_hrr_requested_group: Option<u16>,
tls13_hrr_seen: bool,
// Record-protection keys/IVs and per-direction record sequence numbers.
client_write_key: Option<[u8; 32]>,
server_write_key: Option<[u8; 32]>,
client_write_iv: Option<[u8; 12]>,
server_write_iv: Option<[u8; 12]>,
client_sequence: u64,
server_sequence: u64,
// close_notify bookkeeping for each direction.
tls13_peer_close_notify_received: bool,
tls13_local_close_notify_sent: bool,
// Server certificate authentication configuration.
tls13_require_certificate_auth: bool,
tls13_server_trust_anchors_der: Vec<Vec<u8>>,
tls13_server_intermediates_der: Vec<Vec<u8>>,
tls13_server_validation_time: Option<String>,
tls13_server_expected_hostname: Option<String>,
// SNI, OCSP stapling and ALPN configuration, plus what the server returned.
tls13_client_server_name: Option<String>,
tls13_request_ocsp_stapling: bool,
tls13_require_ocsp_staple: bool,
tls13_ocsp_staple_verifier: Option<Tls13OcspStapleVerifier>,
tls13_server_ocsp_staple: Option<Vec<u8>>,
tls13_server_ocsp_staple_verified: bool,
tls13_require_server_name_ack: bool,
tls13_server_name_acknowledged: bool,
tls13_client_alpn_protocols: Vec<Vec<u8>>,
tls13_selected_alpn_protocol: Option<Vec<u8>>,
// Server leaf public key and the chain-validation outcome.
tls13_server_leaf_public_key_der: Option<Vec<u8>>,
tls13_server_certificate_chain_validated: bool,
// TLS 1.3 early data (0-RTT): policy, acceptance state, byte budget, replay
// window and telemetry counters.
tls13_early_data_require_acceptance: bool,
tls13_early_data_accepted_psk: Option<Vec<u8>>,
tls13_early_data_max_bytes: Option<u32>,
tls13_early_data_opened_bytes: u64,
tls13_early_data_offered_in_client_hello: bool,
tls13_early_data_accepted_in_encrypted_extensions: bool,
tls13_early_data_anti_replay_enabled: bool,
tls13_early_data_replay_window: DtlsReplayWindow,
tls13_early_data_telemetry: Tls13EarlyDataTelemetry,
// TLS 1.2 specific state.
tls12_change_cipher_spec_seen: bool,
tls12_session_id: Option<Vec<u8>>,
tls12_allow_legacy_record_versions: bool,
// DTLS 1.3 record keys/IVs, outbound epoch and sequence, inbound replay
// trackers, and the currently outstanding handshake flight.
dtls13_client_write_key: Option<[u8; 16]>,
dtls13_client_write_iv: Option<[u8; 12]>,
dtls13_server_write_key: Option<[u8; 16]>,
dtls13_server_write_iv: Option<[u8; 12]>,
dtls13_outbound_epoch: u16,
dtls13_outbound_sequence: u64,
dtls13_inbound_replay_tracker: DtlsEpochReplayTracker,
dtls13_client_inbound_replay_tracker: DtlsEpochReplayTracker,
dtls13_active_flight: Vec<(u16, u64)>,
dtls13_active_flight_started_at_ms: Option<u64>,
dtls13_active_flight_timeout_ms: u64,
dtls13_active_flight_failed: bool,
// DTLS retransmission tracker and its timer / attempt policy.
dtls_retransmit_tracker: DtlsFlightRetransmitTracker,
dtls_retransmit_initial_timeout_ms: u64,
dtls_max_retransmit_attempts: u8,
// DTLS 1.2 handshake phase, expected cookie and anti-amplification byte counters.
dtls12_handshake_phase: Dtls12HandshakePhase,
dtls12_expected_cookie: Option<Vec<u8>>,
dtls12_anti_amplification_enforced: bool,
dtls12_inbound_bytes: u64,
dtls12_outbound_bytes: u64,
// Maximum record plaintext length (1..=16384; see `set_max_record_plaintext_len`).
max_record_plaintext_len: usize,
}
/// An encrypted record payload together with its 16-byte authentication tag.
#[derive(Debug, Clone)]
pub struct ProtectedRecord {
/// Sequence number the record was protected under (presumably feeds the nonce — TODO confirm).
pub sequence: u64,
/// Encrypted payload, without the tag.
pub ciphertext: Vec<u8>,
/// 16-byte authentication tag.
pub tag: [u8; 16],
}
/// DTLS timer and retry tuning.
///
/// Read with `Connection::dtls_operational_policy` and installed with
/// `Connection::set_dtls_operational_policy`, which clamps every field to at least 1.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct DtlsOperationalPolicy {
/// Initial retransmission timeout, in milliseconds.
pub retransmit_initial_timeout_ms: u64,
/// Maximum number of retransmission attempts.
pub max_retransmit_attempts: u8,
/// Timeout for the active handshake flight, in milliseconds.
pub active_flight_timeout_ms: u64,
}
/// Preset DTLS policies applied by `Connection::apply_dtls_operational_profile`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DtlsOperationalProfile {
/// Module defaults: 1000 ms initial timeout, 4 attempts, 10000 ms flight timeout.
Conservative,
/// 250 ms initial timeout, 3 attempts, 3000 ms flight timeout.
LanLowLatency,
/// 1500 ms initial timeout, 6 attempts, 20000 ms flight timeout.
LossyNetwork,
}
/// TLS 1.3 early-data (0-RTT) handling knobs.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Tls13EarlyDataOperationalPolicy {
/// Whether early data requires prior acceptance before it is processed.
pub require_acceptance: bool,
/// Whether the early-data replay window is enforced.
pub anti_replay_enabled: bool,
}
/// Preset early-data policies.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Tls13EarlyDataOperationalProfile {
/// Both `require_acceptance` and `anti_replay_enabled` off.
Compatibility,
/// Both `require_acceptance` and `anti_replay_enabled` on.
Strict,
}
/// Counters for early-data records, split by outcome.
///
/// Presumably incremented by the early-data open path, which is outside this view.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
pub struct Tls13EarlyDataTelemetry {
pub accepted_records: u64,
pub rejected_missing_acceptance: u64,
pub rejected_psk_mismatch: u64,
pub rejected_replay_or_too_old: u64,
pub rejected_invalid_input: u64,
pub rejected_decrypt_or_policy: u64,
}
/// Portable copy of the early-data replay window.
///
/// Produced by `export_tls13_early_data_replay_state` and consumed by
/// `import_tls13_early_data_replay_state`; fields mirror `DtlsReplayWindowSnapshot`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
pub struct Tls13EarlyDataReplayState {
pub latest_sequence: u64,
pub bitmap: u64,
pub initialized: bool,
}
/// QUIC Initial secrets: the common initial secret and the per-direction secrets.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Tls13QuicInitialSecrets {
pub initial_secret: Vec<u8>,
pub client_initial_secret: Vec<u8>,
pub server_initial_secret: Vec<u8>,
}
/// QUIC packet-protection material for one direction: AEAD key, IV and header-protection key.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Tls13QuicPacketProtectionKeys {
pub key: Vec<u8>,
pub iv: Vec<u8>,
pub header_protection_key: Vec<u8>,
}
/// Handshake and application traffic secrets for both directions, for QUIC key installation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Tls13QuicTrafficSecretSnapshot {
pub client_handshake_secret: Vec<u8>,
pub server_handshake_secret: Vec<u8>,
pub client_application_secret: Vec<u8>,
pub server_application_secret: Vec<u8>,
}
/// Next-generation application traffic secrets (presumably for QUIC key update — TODO confirm).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Tls13QuicNextTrafficSecrets {
pub client_next_application_secret: Vec<u8>,
pub server_next_application_secret: Vec<u8>,
}
/// Exporter label for QUIC client 1-RTT secrets.
pub const TLS13_QUIC_EXPORTER_LABEL_CLIENT_1RTT: &[u8] = b"EXPORTER-QUIC client 1rtt";
/// Exporter label for QUIC server 1-RTT secrets.
pub const TLS13_QUIC_EXPORTER_LABEL_SERVER_1RTT: &[u8] = b"EXPORTER-QUIC server 1rtt";
// QUIC version 1 Initial salt (value matches RFC 9001 §5.2).
const TLS13_QUIC_V1_INITIAL_SALT: [u8; 20] = [
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
0xcc, 0xbb, 0x7f, 0x0a,
];
// HandshakeType codepoints. hello_verify_request, server_key_exchange,
// server_hello_done and client_key_exchange are (D)TLS 1.2 message types.
const HANDSHAKE_CLIENT_HELLO: u8 = 0x01;
const HANDSHAKE_SERVER_HELLO: u8 = 0x02;
const HANDSHAKE_HELLO_VERIFY_REQUEST: u8 = 0x03;
const HANDSHAKE_NEW_SESSION_TICKET: u8 = 0x04;
const HANDSHAKE_ENCRYPTED_EXTENSIONS: u8 = 0x08;
const HANDSHAKE_CERTIFICATE: u8 = 0x0B;
const HANDSHAKE_SERVER_KEY_EXCHANGE: u8 = 0x0C;
const HANDSHAKE_CERTIFICATE_REQUEST: u8 = 0x0D;
const HANDSHAKE_SERVER_HELLO_DONE: u8 = 0x0E;
const HANDSHAKE_CLIENT_KEY_EXCHANGE: u8 = 0x10;
const HANDSHAKE_CERTIFICATE_VERIFY: u8 = 0x0F;
const HANDSHAKE_FINISHED: u8 = 0x14;
const HANDSHAKE_KEY_UPDATE: u8 = 0x18;
// ExtensionType codepoints.
const EXT_SERVER_NAME: u16 = 0x0000;
const EXT_STATUS_REQUEST: u16 = 0x0005;
const EXT_ALPN: u16 = 0x0010;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002B;
const EXT_SIGNATURE_ALGORITHMS: u16 = 0x000D;
const EXT_KEY_SHARE: u16 = 0x0033;
const EXT_PSK_KEY_EXCHANGE_MODES: u16 = 0x002D;
const EXT_PRE_SHARED_KEY: u16 = 0x0029;
const EXT_EARLY_DATA: u16 = 0x002A;
// NamedGroup codepoints. The ML-KEM values come from IETF drafts — confirm
// against the current IANA registry before relying on them.
const TLS13_KEY_SHARE_GROUP_SECP256R1: u16 = 0x0017;
const TLS13_KEY_SHARE_GROUP_X25519: u16 = 0x001D;
const TLS13_KEY_SHARE_GROUP_MLKEM768: u16 = 0x0201;
const TLS13_KEY_SHARE_GROUP_X25519_MLKEM768_HYBRID: u16 = 0x11EC;
// PskKeyExchangeMode psk_dhe_ke.
const TLS13_PSK_KEY_EXCHANGE_MODE_PSK_DHE_KE: u8 = 0x01;
// SignatureScheme codepoints (the ML-DSA-65 value is draft-assigned — confirm against IANA).
const TLS13_SIGALG_ECDSA_SECP256R1_SHA256: u16 = 0x0403;
const TLS13_SIGALG_RSA_PSS_RSAE_SHA256: u16 = 0x0804;
const TLS13_SIGALG_RSA_PSS_RSAE_SHA384: u16 = 0x0805;
const TLS13_SIGALG_ED25519: u16 = 0x0807;
const TLS13_SIGALG_MLDSA65: u16 = 0x0905;
// Size limits: per-extension value cap and the 2^14-byte record plaintext maximum.
const TLS13_MAX_EXTENSION_VALUE_BYTES: usize = 16_384;
const TLS_MAX_RECORD_PLAINTEXT_LEN: usize = 16_384;
// DTLS retransmission defaults (also the `Conservative` operational profile).
const DTLS_RETRANSMIT_TRACKER_MAX_RECORDS: usize = 256;
const DTLS_RETRANSMIT_INITIAL_TIMEOUT_MS: u64 = 1_000;
const DTLS_MAX_RETRANSMIT_ATTEMPTS: u8 = 4;
const DTLS13_ACTIVE_FLIGHT_TIMEOUT_MS: u64 = 10_000;
// Largest value of a 48-bit DTLS record sequence number.
const DTLS13_MAX_SEQUENCE: u64 = (1_u64 << 48) - 1;
// Longest DTLS 1.2 cookie accepted (the cookie has a one-byte length prefix).
const DTLS12_MAX_COOKIE_LEN: usize = 255;
// Presumably: bytes sent to an unverified peer are capped at this multiple of
// bytes received — TODO confirm against the enforcement code.
const DTLS12_ANTI_AMPLIFICATION_FACTOR: u64 = 3;
// ServerHello.random value that marks a HelloRetryRequest (matches RFC 8446 §4.1.3).
const TLS13_HRR_RANDOM: [u8; 32] = [
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
];
/// Extension values extracted from a ClientHello (the parser is outside this view).
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct ClientHelloExtensions {
/// `supported_versions` entries.
pub supported_versions: Vec<u16>,
/// `signature_algorithms` SignatureScheme codepoints.
pub signature_algorithms: Vec<u16>,
/// NamedGroup codepoints of the offered key shares.
pub key_share_groups: Vec<u16>,
/// Host name from `server_name`, if present.
pub sni_server_name: Option<String>,
/// Protocol names from the ALPN extension.
pub alpn_protocols: Vec<Vec<u8>>,
/// Whether `status_request` asked for OCSP stapling.
pub status_request_ocsp: bool,
/// `psk_key_exchange_modes` values.
pub psk_key_exchange_modes: Vec<u8>,
/// Number of identities in `pre_shared_key`.
pub psk_identity_count: usize,
/// PSK identities, in offer order.
pub psk_identities: Vec<Vec<u8>>,
/// Obfuscated ticket age per identity (presumably same order as `psk_identities`).
pub psk_obfuscated_ticket_ages: Vec<u32>,
/// PSK binders (presumably one per identity, same order — confirm in the parser).
pub psk_binders: Vec<Vec<u8>>,
/// Whether the `early_data` extension was present.
pub early_data_offered: bool,
}
/// A ClientHello's offered cipher suites plus its decoded extensions.
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct ClientHelloInfo {
pub offered_cipher_suites: Vec<CipherSuite>,
pub extensions: ClientHelloExtensions,
}
/// One PSK identity to place in the ClientHello `pre_shared_key` extension.
struct PskIdentityOffer<'a> {
identity: &'a [u8],
obfuscated_ticket_age: u32,
}
/// Full client PSK offer: identities and their binders, paired by index.
struct PskClientOffer<'a> {
identities: Vec<PskIdentityOffer<'a>>,
binders: Vec<&'a [u8]>,
}
/// Client public key shares, one optional slot per supported group.
#[derive(Debug, Clone, Eq, PartialEq, Default)]
struct Tls13ClientPublicKeyShares {
x25519: Option<[u8; 32]>,
// 65-byte uncompressed P-256 point.
secp256r1_uncompressed: Option<[u8; 65]>,
mlkem768: Option<Vec<u8>>,
x25519_mlkem768_hybrid: Option<Vec<u8>>,
}
/// Server key share decoded from ServerHello, tagged by group.
#[derive(Debug, Clone, Eq, PartialEq)]
enum Tls13ServerKeyShareParsed {
X25519([u8; 32]),
Secp256r1([u8; 65]),
MlKem768(Vec<u8>),
X25519MlKem768Hybrid { x25519: [u8; 32], mlkem768: Vec<u8> },
}
/// Fields of interest decoded from a ServerHello (or HelloRetryRequest).
struct ParsedServerHello {
suite: CipherSuite,
key_share: Option<Tls13ServerKeyShareParsed>,
// True when the message is a HelloRetryRequest.
hello_retry_request: bool,
// Group the server asked the client to retry with, if any.
requested_group: Option<u16>,
}
/// Fields of interest decoded from EncryptedExtensions.
struct ParsedEncryptedExtensions {
selected_alpn_protocol: Option<Vec<u8>>,
server_name_acknowledged: bool,
early_data_accepted: bool,
}
/// Certificate chain from a TLS 1.3 Certificate message plus any OCSP staple on the leaf entry.
struct ParsedTls13CertificateBody {
certificates: Vec<Vec<u8>>,
leaf_ocsp_staple: Option<Vec<u8>>,
}
/// Outcome reported by an OCSP staple verifier.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Tls13OcspStapleVerification {
Good,
Expired,
Revoked,
}
/// Callback that inspects a raw OCSP staple and classifies it.
pub type Tls13OcspStapleVerifier = fn(&[u8]) -> Result<Tls13OcspStapleVerification>;
/// DTLS 1.2 handshake progress, starting at `AwaitingClientHello` (see `Connection::new`).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Dtls12HandshakePhase {
AwaitingClientHello,
// Cookie issued; waiting for the ClientHello that echoes it.
AwaitingClientHelloWithCookie,
AwaitingClientKeyExchange,
AwaitingFinished,
Connected,
}
/// Running handshake-transcript hash, backed by SHA-256 or SHA-384.
#[derive(Debug, Clone)]
enum TranscriptHashState {
Sha256(TlsTranscriptSha256),
Sha384(TlsTranscriptSha384),
}
impl TranscriptHashState {
    /// Chooses the transcript hash to use before a cipher suite is known:
    /// SHA-384 for TLS/DTLS 1.3, SHA-256 for every older version.
    fn for_version(version: TlsVersion) -> Self {
        let use_sha384 = match version {
            TlsVersion::Tls13 | TlsVersion::Dtls13 => true,
            TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => false,
        };
        if use_sha384 {
            Self::Sha384(TlsTranscriptSha384::new())
        } else {
            Self::Sha256(TlsTranscriptSha256::new())
        }
    }

    /// Absorbs `message` into whichever running hash is active.
    fn update(&mut self, message: &[u8]) {
        match self {
            Self::Sha384(running) => running.update(message),
            Self::Sha256(running) => running.update(message),
        }
    }

    /// Digest of everything absorbed so far; takes `&self`, so hashing can continue.
    fn snapshot_hash(&self) -> Vec<u8> {
        match self {
            Self::Sha384(running) => running.snapshot_hash().to_vec(),
            Self::Sha256(running) => running.snapshot_hash().to_vec(),
        }
    }

    /// Hash function backing this transcript.
    fn algorithm(&self) -> HashAlgorithm {
        if let Self::Sha384(_) = self {
            HashAlgorithm::Sha384
        } else {
            HashAlgorithm::Sha256
        }
    }
}
impl CipherSuite {
    /// Maps an IANA cipher-suite codepoint to a supported suite, or `None` if unsupported.
    fn from_u16(codepoint: u16) -> Option<Self> {
        let suite = match codepoint {
            0x1301 => Self::TlsAes128GcmSha256,
            0x1302 => Self::TlsAes256GcmSha384,
            0x1303 => Self::TlsChacha20Poly1305Sha256,
            0xC02F => Self::TlsEcdheRsaWithAes128GcmSha256,
            0xC030 => Self::TlsEcdheRsaWithAes256GcmSha384,
            _ => return None,
        };
        Some(suite)
    }

    /// Fresh transcript hasher for this suite, derived from `hash_algorithm` so
    /// the two mappings cannot drift apart.
    fn transcript_hash_state(self) -> TranscriptHashState {
        match self.hash_algorithm() {
            HashAlgorithm::Sha256 => TranscriptHashState::Sha256(TlsTranscriptSha256::new()),
            _ => TranscriptHashState::Sha384(TlsTranscriptSha384::new()),
        }
    }

    /// Hash function bound to this suite: SHA-384 for the AES-256 suites, SHA-256 otherwise.
    fn hash_algorithm(self) -> HashAlgorithm {
        match self {
            Self::TlsAes256GcmSha384 | Self::TlsEcdheRsaWithAes256GcmSha384 => {
                HashAlgorithm::Sha384
            }
            Self::TlsAes128GcmSha256
            | Self::TlsChacha20Poly1305Sha256
            | Self::TlsEcdheRsaWithAes128GcmSha256 => HashAlgorithm::Sha256,
        }
    }

    /// TLS 1.3 traffic-key length in bytes; `None` for the TLS 1.2 ECDHE suites.
    fn tls13_traffic_key_len(self) -> Option<usize> {
        let key_len = match self {
            Self::TlsEcdheRsaWithAes128GcmSha256 | Self::TlsEcdheRsaWithAes256GcmSha384 => {
                return None;
            }
            Self::TlsAes128GcmSha256 => 16,
            Self::TlsAes256GcmSha384 | Self::TlsChacha20Poly1305Sha256 => 32,
        };
        Some(key_len)
    }

    /// IANA codepoint for this suite; the inverse of `from_u16`.
    fn to_u16(self) -> u16 {
        match self {
            Self::TlsEcdheRsaWithAes256GcmSha384 => 0xC030,
            Self::TlsEcdheRsaWithAes128GcmSha256 => 0xC02F,
            Self::TlsChacha20Poly1305Sha256 => 0x1303,
            Self::TlsAes256GcmSha384 => 0x1302,
            Self::TlsAes128GcmSha256 => 0x1301,
        }
    }
}
impl Connection {
/// Creates a fresh connection for `version` in [`HandshakeState::Idle`].
///
/// No cipher suite, secrets, keys or peer-authentication material are set. Defaults:
/// - the transcript hash comes from `TranscriptHashState::for_version`;
/// - early-data anti-replay is enabled;
/// - DTLS 1.2 anti-amplification is enforced;
/// - DTLS retransmit and flight timers come from the module constants;
/// - the record plaintext limit is `TLS_MAX_RECORD_PLAINTEXT_LEN` (16384 bytes).
pub fn new(version: TlsVersion) -> Self {
Self {
version,
state: HandshakeState::Idle,
selected_cipher_suite: None,
transcript: Vec::new(),
transcript_hash: TranscriptHashState::for_version(version),
handshake_secret: None,
tls13_master_secret: None,
tls13_client_handshake_traffic_secret: None,
tls13_server_handshake_traffic_secret: None,
tls13_finished_key: None,
tls13_client_application_traffic_secret: None,
tls13_server_application_traffic_secret: None,
tls13_exporter_master_secret: None,
tls13_resumption_master_secret: None,
tls13_client_x25519_private: None,
tls13_client_p256_private: None,
tls13_client_mlkem768_private: None,
tls13_shared_secret: None,
tls13_hrr_requested_group: None,
tls13_hrr_seen: false,
client_write_key: None,
server_write_key: None,
client_write_iv: None,
server_write_iv: None,
client_sequence: 0,
server_sequence: 0,
tls13_peer_close_notify_received: false,
tls13_local_close_notify_sent: false,
tls13_require_certificate_auth: false,
tls13_server_trust_anchors_der: Vec::new(),
tls13_server_intermediates_der: Vec::new(),
tls13_server_validation_time: None,
tls13_server_expected_hostname: None,
tls13_client_server_name: None,
tls13_request_ocsp_stapling: false,
tls13_require_ocsp_staple: false,
tls13_ocsp_staple_verifier: None,
tls13_server_ocsp_staple: None,
tls13_server_ocsp_staple_verified: false,
tls13_require_server_name_ack: false,
tls13_server_name_acknowledged: false,
tls13_client_alpn_protocols: Vec::new(),
tls13_selected_alpn_protocol: None,
tls13_server_leaf_public_key_der: None,
tls13_server_certificate_chain_validated: false,
tls13_early_data_require_acceptance: false,
tls13_early_data_accepted_psk: None,
tls13_early_data_max_bytes: None,
tls13_early_data_opened_bytes: 0,
tls13_early_data_offered_in_client_hello: false,
tls13_early_data_accepted_in_encrypted_extensions: false,
// Anti-replay starts enabled; callers opt out via the early-data policy setters.
tls13_early_data_anti_replay_enabled: true,
tls13_early_data_replay_window: DtlsReplayWindow::new(),
tls13_early_data_telemetry: Tls13EarlyDataTelemetry::default(),
tls12_change_cipher_spec_seen: false,
tls12_session_id: None,
tls12_allow_legacy_record_versions: false,
dtls13_client_write_key: None,
dtls13_client_write_iv: None,
dtls13_server_write_key: None,
dtls13_server_write_iv: None,
dtls13_outbound_epoch: 0,
dtls13_outbound_sequence: 0,
dtls13_inbound_replay_tracker: DtlsEpochReplayTracker::new(),
dtls13_client_inbound_replay_tracker: DtlsEpochReplayTracker::new(),
dtls13_active_flight: Vec::new(),
dtls13_active_flight_started_at_ms: None,
dtls13_active_flight_timeout_ms: DTLS13_ACTIVE_FLIGHT_TIMEOUT_MS,
dtls13_active_flight_failed: false,
dtls_retransmit_tracker: DtlsFlightRetransmitTracker::new(
DTLS_RETRANSMIT_TRACKER_MAX_RECORDS,
),
dtls_retransmit_initial_timeout_ms: DTLS_RETRANSMIT_INITIAL_TIMEOUT_MS,
dtls_max_retransmit_attempts: DTLS_MAX_RETRANSMIT_ATTEMPTS,
dtls12_handshake_phase: Dtls12HandshakePhase::AwaitingClientHello,
dtls12_expected_cookie: None,
dtls12_anti_amplification_enforced: true,
dtls12_inbound_bytes: 0,
dtls12_outbound_bytes: 0,
max_record_plaintext_len: TLS_MAX_RECORD_PLAINTEXT_LEN,
}
}
/// Current DTLS retransmission and flight-timeout policy, or `None` when this
/// connection is not DTLS.
#[must_use]
pub fn dtls_operational_policy(&self) -> Option<DtlsOperationalPolicy> {
    self.version.is_dtls().then(|| DtlsOperationalPolicy {
        retransmit_initial_timeout_ms: self.dtls_retransmit_initial_timeout_ms,
        max_retransmit_attempts: self.dtls_max_retransmit_attempts,
        active_flight_timeout_ms: self.dtls13_active_flight_timeout_ms,
    })
}
pub fn set_dtls_operational_policy(
&mut self,
policy: DtlsOperationalPolicy,
) -> Result<DtlsOperationalPolicy> {
self.ensure_dtls12_mode()?;
let effective = DtlsOperationalPolicy {
retransmit_initial_timeout_ms: policy.retransmit_initial_timeout_ms.max(1),
max_retransmit_attempts: policy.max_retransmit_attempts.max(1),
active_flight_timeout_ms: policy.active_flight_timeout_ms.max(1),
};
self.dtls_retransmit_initial_timeout_ms = effective.retransmit_initial_timeout_ms;
self.dtls_max_retransmit_attempts = effective.max_retransmit_attempts;
self.dtls13_active_flight_timeout_ms = effective.active_flight_timeout_ms;
Ok(effective)
}
/// Applies a preset DTLS timer profile through
/// [`Connection::set_dtls_operational_policy`] and returns the effective policy.
///
/// # Errors
///
/// Propagates any error from `set_dtls_operational_policy`.
pub fn apply_dtls_operational_profile(
    &mut self,
    profile: DtlsOperationalProfile,
) -> Result<DtlsOperationalPolicy> {
    // (initial retransmit timeout ms, max retransmit attempts, active flight timeout ms)
    let (initial_timeout, attempts, flight_timeout) = match profile {
        DtlsOperationalProfile::LossyNetwork => (1_500, 6, 20_000),
        DtlsOperationalProfile::LanLowLatency => (250, 3, 3_000),
        DtlsOperationalProfile::Conservative => (
            DTLS_RETRANSMIT_INITIAL_TIMEOUT_MS,
            DTLS_MAX_RETRANSMIT_ATTEMPTS,
            DTLS13_ACTIVE_FLIGHT_TIMEOUT_MS,
        ),
    };
    self.set_dtls_operational_policy(DtlsOperationalPolicy {
        retransmit_initial_timeout_ms: initial_timeout,
        max_retransmit_attempts: attempts,
        active_flight_timeout_ms: flight_timeout,
    })
}
pub fn set_tls13_require_certificate_auth(&mut self, required: bool) {
self.tls13_require_certificate_auth = required;
}
pub fn configure_tls13_server_auth(
&mut self,
trust_anchors_der: &[Vec<u8>],
intermediates_der: &[Vec<u8>],
validation_time: &str,
) -> Result<()> {
if trust_anchors_der.is_empty() {
return Err(Error::InvalidLength(
"tls13 trust anchor list must not be empty",
));
}
if validation_time.is_empty() {
return Err(Error::InvalidLength(
"tls13 validation time must not be empty",
));
}
self.tls13_server_trust_anchors_der = trust_anchors_der.to_vec();
self.tls13_server_intermediates_der = intermediates_der.to_vec();
self.tls13_server_validation_time = Some(validation_time.to_owned());
Ok(())
}
pub fn set_tls13_server_expected_hostname(&mut self, hostname: Option<&str>) -> Result<()> {
match hostname {
Some(value) if value.is_empty() => Err(Error::InvalidLength(
"tls13 expected hostname must not be empty",
)),
Some(value) => {
self.tls13_server_expected_hostname = Some(value.to_owned());
Ok(())
}
None => {
self.tls13_server_expected_hostname = None;
Ok(())
}
}
}
pub fn set_tls12_session_id(&mut self, session_id: Option<&[u8]>) -> Result<()> {
match session_id {
Some(value) if value.is_empty() => Err(Error::InvalidLength(
"tls12 session id must not be empty when present",
)),
Some(value) if value.len() > 32 => Err(Error::InvalidLength(
"tls12 session id must not exceed 32 bytes",
)),
Some(value) => {
self.tls12_session_id = Some(value.to_vec());
Ok(())
}
None => {
self.tls12_session_id = None;
Ok(())
}
}
}
/// Configured TLS 1.2 session id, if any.
#[must_use]
pub fn tls12_session_id(&self) -> Option<&[u8]> {
self.tls12_session_id.as_deref()
}
/// Sets the `tls12_allow_legacy_record_versions` flag (how it is applied is outside this view).
pub fn set_tls12_allow_legacy_record_versions(&mut self, allow: bool) {
self.tls12_allow_legacy_record_versions = allow;
}
pub fn set_tls13_server_name(&mut self, server_name: Option<&str>) -> Result<()> {
match server_name {
Some(name) if name.is_empty() => {
Err(Error::InvalidLength("sni server_name must not be empty"))
}
Some(name) if name.len() > u16::MAX as usize => Err(Error::InvalidLength(
"sni server_name length must not exceed 65535 bytes",
)),
Some(name) => {
if !is_valid_sni_dns_name(name) {
return Err(Error::ParseFailure("invalid sni server_name"));
}
self.tls13_client_server_name = Some(name.to_owned());
self.tls13_server_name_acknowledged = false;
Ok(())
}
None => {
self.tls13_client_server_name = None;
self.tls13_server_name_acknowledged = false;
Ok(())
}
}
}
/// Controls whether ClientHello requests OCSP stapling (passed to the ClientHello encoder).
pub fn set_tls13_request_ocsp_stapling(&mut self, enabled: bool) {
self.tls13_request_ocsp_stapling = enabled;
}
/// Sets whether a server OCSP staple is mandatory (enforcement is outside this view).
pub fn set_tls13_require_ocsp_staple(&mut self, required: bool) {
self.tls13_require_ocsp_staple = required;
}
/// Installs, or clears with `None`, the callback used to classify OCSP staples.
pub fn set_tls13_ocsp_staple_verifier(&mut self, verifier: Option<Tls13OcspStapleVerifier>) {
self.tls13_ocsp_staple_verifier = verifier;
}
/// Raw OCSP staple received from the server, if any.
#[must_use]
pub fn tls13_server_ocsp_staple(&self) -> Option<&[u8]> {
self.tls13_server_ocsp_staple.as_deref()
}
/// Whether the server's OCSP staple has been marked verified.
#[must_use]
pub fn tls13_server_ocsp_staple_verified(&self) -> bool {
self.tls13_server_ocsp_staple_verified
}
/// Sets whether the server must acknowledge the SNI extension.
pub fn set_tls13_require_server_name_ack(&mut self, required: bool) {
self.tls13_require_server_name_ack = required;
}
/// Whether the server acknowledged SNI; `set_tls13_server_name` resets this to `false`.
#[must_use]
pub fn tls13_server_name_acknowledged(&self) -> bool {
self.tls13_server_name_acknowledged
}
/// Sets the ALPN protocol names to offer, in the given order, and forgets any
/// previously selected protocol.
///
/// # Errors
///
/// - [`Error::InvalidLength`] if a name is empty or longer than 255 bytes.
/// - [`Error::ParseFailure`] if a name appears more than once.
///
/// On error the previously configured list is left untouched.
pub fn set_tls13_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
    let mut accepted: Vec<Vec<u8>> = Vec::with_capacity(protocols.len());
    for name in protocols.iter().map(|protocol| protocol.as_bytes()) {
        if name.is_empty() {
            return Err(Error::InvalidLength("alpn protocol must not be empty"));
        }
        if name.len() > usize::from(u8::MAX) {
            return Err(Error::InvalidLength(
                "alpn protocol length must not exceed 255 bytes",
            ));
        }
        if accepted.iter().any(|seen| seen.as_slice() == name) {
            return Err(Error::ParseFailure("duplicate alpn protocol"));
        }
        accepted.push(name.to_vec());
    }
    self.tls13_client_alpn_protocols = accepted;
    self.tls13_selected_alpn_protocol = None;
    Ok(())
}
#[must_use]
pub fn tls13_selected_alpn_protocol(&self) -> Option<&[u8]> {
self.tls13_selected_alpn_protocol.as_deref()
}
pub fn set_max_record_plaintext_len(&mut self, max_len: usize) -> Result<()> {
if max_len == 0 || max_len > TLS_MAX_RECORD_PLAINTEXT_LEN {
return Err(Error::InvalidLength(
"record plaintext limit must be between 1 and 16384 bytes",
));
}
self.max_record_plaintext_len = max_len;
Ok(())
}
/// Enables or disables replay protection for TLS 1.3 early-data records.
///
/// The replay window is cleared only on a disabled -> enabled transition.
/// Re-asserting an already-enabled setting keeps the existing replay history.
/// Examples: applying the `Strict` profile twice, or applying a policy after
/// `import_tls13_early_data_replay_state`. Discarding the history there would
/// let previously seen 0-RTT records be accepted again.
pub fn set_tls13_early_data_anti_replay_enabled(&mut self, enabled: bool) {
    let was_enabled = self.tls13_early_data_anti_replay_enabled;
    self.tls13_early_data_anti_replay_enabled = enabled;
    if enabled && !was_enabled {
        self.tls13_early_data_replay_window = DtlsReplayWindow::new();
    }
}
/// Sets the `require_acceptance` early-data flag.
///
/// Also clears all recorded early-data acceptance state: accepted PSK, byte
/// budget, bytes opened so far, and the EncryptedExtensions acceptance flag.
pub fn set_tls13_require_early_data_acceptance(&mut self, required: bool) {
    self.tls13_early_data_accepted_in_encrypted_extensions = false;
    self.tls13_early_data_opened_bytes = 0;
    self.tls13_early_data_max_bytes = None;
    self.tls13_early_data_accepted_psk = None;
    self.tls13_early_data_require_acceptance = required;
}
/// Applies a preset early-data profile. `Strict` turns both knobs on and
/// `Compatibility` turns both off.
pub fn set_tls13_early_data_operational_profile(
    &mut self,
    profile: Tls13EarlyDataOperationalProfile,
) {
    let strict = match profile {
        Tls13EarlyDataOperationalProfile::Strict => true,
        Tls13EarlyDataOperationalProfile::Compatibility => false,
    };
    self.set_tls13_early_data_operational_policy(Tls13EarlyDataOperationalPolicy {
        require_acceptance: strict,
        anti_replay_enabled: strict,
    });
}
/// Applies both early-data knobs from `policy`: acceptance first, then anti-replay.
pub fn set_tls13_early_data_operational_policy(
    &mut self,
    policy: Tls13EarlyDataOperationalPolicy,
) {
    let Tls13EarlyDataOperationalPolicy {
        require_acceptance,
        anti_replay_enabled,
    } = policy;
    self.set_tls13_require_early_data_acceptance(require_acceptance);
    self.set_tls13_early_data_anti_replay_enabled(anti_replay_enabled);
}
/// Current early-data policy knobs.
#[must_use]
pub fn tls13_early_data_operational_policy(&self) -> Tls13EarlyDataOperationalPolicy {
    let require_acceptance = self.tls13_early_data_require_acceptance;
    let anti_replay_enabled = self.tls13_early_data_anti_replay_enabled;
    Tls13EarlyDataOperationalPolicy {
        require_acceptance,
        anti_replay_enabled,
    }
}
/// Copy of the early-data outcome counters.
#[must_use]
pub fn tls13_early_data_telemetry(&self) -> Tls13EarlyDataTelemetry {
    let counters = self.tls13_early_data_telemetry;
    counters
}
/// Zeroes every early-data outcome counter.
pub fn reset_tls13_early_data_telemetry(&mut self) {
    self.tls13_early_data_telemetry = Default::default();
}
/// Snapshot of the early-data replay window, for persistence or transfer.
#[must_use]
pub fn export_tls13_early_data_replay_state(&self) -> Tls13EarlyDataReplayState {
    let window = self.tls13_early_data_replay_window.snapshot();
    Tls13EarlyDataReplayState {
        initialized: window.initialized,
        bitmap: window.bitmap,
        latest_sequence: window.latest_sequence,
    }
}
pub fn import_tls13_early_data_replay_state(
&mut self,
state: Tls13EarlyDataReplayState,
) -> Result<()> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"tls13 early-data replay state requires TLS 1.3 connection",
));
}
self.tls13_early_data_replay_window
.restore_from_snapshot(DtlsReplayWindowSnapshot {
latest_sequence: state.latest_sequence,
bitmap: state.bitmap,
initialized: state.initialized,
});
Ok(())
}
/// Builds and records a ClientHello without any PSK offer.
///
/// `random` is the 32-byte ClientHello random. On success the encoded handshake
/// message is appended to the transcript, the state becomes `ClientHelloSent`,
/// and early data is recorded as not offered. The message is also returned.
///
/// # Errors
///
/// - `Error::StateError` if the connection is not `HandshakeState::Idle`.
/// - `Error::InvalidLength` if `random` is not 32 bytes.
/// - Any error from HRR group validation, key-share preparation or encoding.
pub fn send_client_hello(&mut self, random: &[u8]) -> Result<Vec<u8>> {
if self.state != HandshakeState::Idle {
return Err(Error::StateError("client hello can only be sent from idle"));
}
if random.len() != 32 {
return Err(Error::InvalidLength("client hello random must be 32 bytes"));
}
// NOTE(review): the transcript is reset before HRR group validation, so a
// validation failure leaves a reset transcript — confirm this is intended.
self.reset_transcript_for_new_handshake();
self.validate_tls13_hrr_retry_group_support()?;
self.reset_tls13_certificate_auth_state();
let key_shares = self.prepare_client_key_share(random)?;
// Arguments after the ALPN list: OCSP request flag, no early_data, no PSK offer, TLS 1.2 session id.
let client_hello_body = encode_client_hello_body(
self.version,
random,
&default_client_cipher_suites(self.version),
&key_shares,
self.tls13_client_server_name.as_deref(),
&self.tls13_client_alpn_protocols,
self.tls13_request_ocsp_stapling,
false,
None,
self.tls12_session_id.as_deref(),
)?;
let msg = encode_handshake_message(HANDSHAKE_CLIENT_HELLO, &client_hello_body);
self.append_transcript(&msg);
self.state = HandshakeState::ClientHelloSent;
self.tls13_early_data_offered_in_client_hello = false;
self.tls13_early_data_accepted_in_encrypted_extensions = false;
Ok(msg)
}
/// Sends a TLS 1.3 ClientHello offering a single external PSK, without early data.
///
/// Delegates to `send_client_hello_with_psk_internal` with `offer_early_data = false`;
/// see that method for the validation rules and errors.
pub fn send_client_hello_with_psk(
&mut self,
random: &[u8],
identity: &[u8],
obfuscated_ticket_age: u32,
psk: &[u8],
) -> Result<Vec<u8>> {
self.send_client_hello_with_psk_internal(
random,
identity,
obfuscated_ticket_age,
psk,
false,
)
}
/// Shared single-PSK ClientHello builder.
///
/// Called from `send_client_hello_with_psk` and `send_client_hello_with_resumption_ticket`.
///
/// The ClientHello is encoded twice:
/// 1. with a zero-filled placeholder binder of the negotiated hash length, from
///    which `compute_tls13_psk_binder` derives the real binder;
/// 2. with the real binder.
///
/// The final message is appended to the transcript and the state becomes
/// `ClientHelloSent`. `offer_early_data` controls both the `early_data`
/// extension and the recorded "offered" flag.
///
/// # Errors
///
/// - `Error::StateError` if the version does not use TLS 1.3 handshake
///   semantics, or the connection is not `Idle`.
/// - `Error::InvalidLength` if `random` is not 32 bytes, or `identity` or `psk` is empty.
/// - Any error from HRR validation, key-share preparation, binder computation or encoding.
fn send_client_hello_with_psk_internal(
&mut self,
random: &[u8],
identity: &[u8],
obfuscated_ticket_age: u32,
psk: &[u8],
offer_early_data: bool,
) -> Result<Vec<u8>> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"psk client hello is currently only modeled for TLS 1.3",
));
}
if self.state != HandshakeState::Idle {
return Err(Error::StateError("client hello can only be sent from idle"));
}
if random.len() != 32 {
return Err(Error::InvalidLength("client hello random must be 32 bytes"));
}
if identity.is_empty() {
return Err(Error::InvalidLength("psk identity must not be empty"));
}
if psk.is_empty() {
return Err(Error::InvalidLength("psk must not be empty"));
}
self.reset_transcript_for_new_handshake();
self.validate_tls13_hrr_retry_group_support()?;
self.reset_tls13_certificate_auth_state();
// Placeholder binder has the same length as the real one, so both encodings share a layout.
let binder_len = self.negotiated_hash_algorithm().output_len();
let placeholder = vec![0_u8; binder_len];
let placeholder_offer = PskClientOffer {
identities: vec![PskIdentityOffer {
identity,
obfuscated_ticket_age,
}],
binders: vec![placeholder.as_slice()],
};
let key_shares = self.prepare_client_key_share(random)?;
// First pass: encode with the placeholder binder.
let placeholder_body = encode_client_hello_body(
self.version,
random,
&default_client_cipher_suites(self.version),
&key_shares,
self.tls13_client_server_name.as_deref(),
&self.tls13_client_alpn_protocols,
self.tls13_request_ocsp_stapling,
offer_early_data,
Some(&placeholder_offer),
self.tls12_session_id.as_deref(),
)?;
let placeholder_msg = encode_handshake_message(HANDSHAKE_CLIENT_HELLO, &placeholder_body);
let binder = self.compute_tls13_psk_binder(psk, &placeholder_msg)?;
let final_offer = PskClientOffer {
identities: vec![PskIdentityOffer {
identity,
obfuscated_ticket_age,
}],
binders: vec![binder.as_slice()],
};
// Second pass: identical inputs except for the real binder.
let final_body = encode_client_hello_body(
self.version,
random,
&default_client_cipher_suites(self.version),
&key_shares,
self.tls13_client_server_name.as_deref(),
&self.tls13_client_alpn_protocols,
self.tls13_request_ocsp_stapling,
offer_early_data,
Some(&final_offer),
self.tls12_session_id.as_deref(),
)?;
let msg = encode_handshake_message(HANDSHAKE_CLIENT_HELLO, &final_body);
self.append_transcript(&msg);
self.state = HandshakeState::ClientHelloSent;
self.tls13_early_data_offered_in_client_hello = offer_early_data;
self.tls13_early_data_accepted_in_encrypted_extensions = false;
Ok(msg)
}
/// Sends a ClientHello offering every ticket in `tickets` as a PSK identity.
///
/// Each identity uses the ticket's stored `obfuscated_ticket_age`.
///
/// # Errors
///
/// Propagates errors from the shared multi-ticket ClientHello builder.
pub fn send_client_hello_with_resumption_tickets(
    &mut self,
    random: &[u8],
    tickets: &[ResumptionTicket],
) -> Result<Vec<u8>> {
    let obfuscated_ages: Vec<u32> = tickets
        .iter()
        .map(|ticket| ticket.obfuscated_ticket_age)
        .collect();
    self.send_client_hello_with_resumption_tickets_with_ages(random, tickets, &obfuscated_ages)
}
/// Sends a ClientHello offering every ticket, with ticket ages computed at `current_time_ms`.
///
/// Each obfuscated age is `age_add + elapsed_ms (mod 2^32)`. `elapsed_ms` is
/// saturated to the `u32` range before the wrapping addition.
///
/// # Errors
///
/// - `Error::StateError` if the version does not use TLS 1.3 handshake semantics
///   or the connection is not `Idle`.
/// - `Error::InvalidLength` if `random` is not 32 bytes or `tickets` is empty.
/// - Anything propagated by the shared multi-ticket ClientHello builder.
pub fn send_client_hello_with_resumption_tickets_at(
    &mut self,
    random: &[u8],
    tickets: &[ResumptionTicket],
    current_time_ms: u64,
) -> Result<Vec<u8>> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "psk client hello is currently only modeled for TLS 1.3",
        ));
    }
    if self.state != HandshakeState::Idle {
        return Err(Error::StateError("client hello can only be sent from idle"));
    }
    if random.len() != 32 {
        return Err(Error::InvalidLength("client hello random must be 32 bytes"));
    }
    if tickets.is_empty() {
        return Err(Error::InvalidLength("ticket list must not be empty"));
    }
    let obfuscated_ages: Vec<u32> = tickets
        .iter()
        .map(|ticket| {
            let age_ms = current_time_ms
                .saturating_sub(ticket.issued_at_ms)
                .min(u64::from(u32::MAX)) as u32;
            ticket.age_add.wrapping_add(age_ms)
        })
        .collect();
    self.send_client_hello_with_resumption_tickets_with_ages(random, tickets, &obfuscated_ages)
}
fn send_client_hello_with_resumption_tickets_with_ages(
&mut self,
random: &[u8],
tickets: &[ResumptionTicket],
obfuscated_ages: &[u32],
) -> Result<Vec<u8>> {
self.reset_transcript_for_new_handshake();
self.validate_tls13_hrr_retry_group_support()?;
self.reset_tls13_certificate_auth_state();
let hash_len = self.negotiated_hash_algorithm().output_len();
let mut psk_identities = Vec::with_capacity(tickets.len());
let mut psks = Vec::with_capacity(tickets.len());
for (ticket, obfuscated_age) in tickets.iter().zip(obfuscated_ages.iter().copied()) {
psk_identities.push(PskIdentityOffer {
identity: ticket.identity.as_slice(),
obfuscated_ticket_age: obfuscated_age,
});
psks.push(self.derive_tls13_resumption_psk(&ticket.ticket_nonce)?);
}
let zero_binders: Vec<Vec<u8>> = (0..tickets.len()).map(|_| vec![0_u8; hash_len]).collect();
let zero_binder_refs: Vec<&[u8]> = zero_binders.iter().map(Vec::as_slice).collect();
let placeholder_offer = PskClientOffer {
identities: psk_identities,
binders: zero_binder_refs,
};
let key_shares = self.prepare_client_key_share(random)?;
let placeholder_body = encode_client_hello_body(
self.version,
random,
&default_client_cipher_suites(self.version),
&key_shares,
self.tls13_client_server_name.as_deref(),
&self.tls13_client_alpn_protocols,
self.tls13_request_ocsp_stapling,
tickets.iter().any(|ticket| ticket.max_early_data_size > 0),
Some(&placeholder_offer),
self.tls12_session_id.as_deref(),
)?;
let placeholder_msg = encode_handshake_message(HANDSHAKE_CLIENT_HELLO, &placeholder_body);
let mut binders = Vec::with_capacity(psks.len());
for psk in &psks {
binders.push(self.compute_tls13_psk_binder(psk, &placeholder_msg)?);
}
let binder_refs: Vec<&[u8]> = binders.iter().map(Vec::as_slice).collect();
let final_offer = PskClientOffer {
identities: placeholder_offer.identities,
binders: binder_refs,
};
let final_body = encode_client_hello_body(
self.version,
random,
&default_client_cipher_suites(self.version),
&key_shares,
self.tls13_client_server_name.as_deref(),
&self.tls13_client_alpn_protocols,
self.tls13_request_ocsp_stapling,
tickets.iter().any(|ticket| ticket.max_early_data_size > 0),
Some(&final_offer),
self.tls12_session_id.as_deref(),
)?;
let msg = encode_handshake_message(HANDSHAKE_CLIENT_HELLO, &final_body);
self.append_transcript(&msg);
self.state = HandshakeState::ClientHelloSent;
self.tls13_early_data_offered_in_client_hello =
tickets.iter().any(|ticket| ticket.max_early_data_size > 0);
self.tls13_early_data_accepted_in_encrypted_extensions = false;
Ok(msg)
}
/// Sends a TLS 1.3 ClientHello that resumes a single ticket.
///
/// The ticket's stored `obfuscated_ticket_age` is sent unchanged. Early data
/// is advertised whenever the ticket permits a non-zero amount.
pub fn send_client_hello_with_resumption_ticket(
    &mut self,
    random: &[u8],
    ticket: &ResumptionTicket,
) -> Result<Vec<u8>> {
    let resumption_psk = self.derive_tls13_resumption_psk(&ticket.ticket_nonce)?;
    let stored_age = ticket.obfuscated_ticket_age;
    let allows_early_data = ticket.max_early_data_size > 0;
    self.send_client_hello_with_psk_internal(
        random,
        &ticket.identity,
        stored_age,
        &resumption_psk,
        allows_early_data,
    )
}
/// Sends a TLS 1.3 ClientHello resuming `ticket`, with the obfuscated ticket
/// age recomputed for `current_time_ms`.
///
/// The elapsed time since `issued_at_ms` saturates at zero, is clamped to
/// `u32::MAX` milliseconds, and is then added, wrapping, to `age_add`.
pub fn send_client_hello_with_resumption_ticket_at(
    &mut self,
    random: &[u8],
    ticket: &ResumptionTicket,
    current_time_ms: u64,
) -> Result<Vec<u8>> {
    let resumption_psk = self.derive_tls13_resumption_psk(&ticket.ticket_nonce)?;
    let elapsed = current_time_ms.saturating_sub(ticket.issued_at_ms);
    let ticket_age_ms = if elapsed > u64::from(u32::MAX) {
        u32::MAX
    } else {
        elapsed as u32
    };
    let wire_age = ticket.age_add.wrapping_add(ticket_age_ms);
    let allows_early_data = ticket.max_early_data_size > 0;
    self.send_client_hello_with_psk_internal(
        random,
        &ticket.identity,
        wire_age,
        &resumption_psk,
        allows_early_data,
    )
}
/// Sends a plain ClientHello whose 32-byte random is drawn from `drbg`
/// (personalization label `client_hello_random`).
pub fn send_client_hello_auto(&mut self, drbg: &mut HmacDrbgSha256) -> Result<Vec<u8>> {
    let client_random = drbg.generate(32, b"client_hello_random")?;
    self.send_client_hello(&client_random)
}
/// Like `send_client_hello_with_psk`, but draws the 32-byte ClientHello random
/// from `drbg` (label `client_hello_random`).
pub fn send_client_hello_with_psk_auto(
    &mut self,
    drbg: &mut HmacDrbgSha256,
    identity: &[u8],
    obfuscated_ticket_age: u32,
    psk: &[u8],
) -> Result<Vec<u8>> {
    let client_random = drbg.generate(32, b"client_hello_random")?;
    self.send_client_hello_with_psk(&client_random, identity, obfuscated_ticket_age, psk)
}
/// Like `send_client_hello_with_resumption_ticket`, but draws the 32-byte
/// ClientHello random from `drbg` (label `client_hello_random`).
pub fn send_client_hello_with_resumption_ticket_auto(
    &mut self,
    drbg: &mut HmacDrbgSha256,
    ticket: &ResumptionTicket,
) -> Result<Vec<u8>> {
    let client_random = drbg.generate(32, b"client_hello_random")?;
    self.send_client_hello_with_resumption_ticket(&client_random, ticket)
}
/// Like `send_client_hello_with_resumption_ticket_at`, but draws the 32-byte
/// ClientHello random from `drbg` (label `client_hello_random`).
pub fn send_client_hello_with_resumption_ticket_at_auto(
    &mut self,
    drbg: &mut HmacDrbgSha256,
    ticket: &ResumptionTicket,
    current_time_ms: u64,
) -> Result<Vec<u8>> {
    let client_random = drbg.generate(32, b"client_hello_random")?;
    self.send_client_hello_with_resumption_ticket_at(&client_random, ticket, current_time_ms)
}
/// Like `send_client_hello_with_resumption_tickets`, but draws the 32-byte
/// ClientHello random from `drbg` (label `client_hello_random`).
pub fn send_client_hello_with_resumption_tickets_auto(
    &mut self,
    drbg: &mut HmacDrbgSha256,
    tickets: &[ResumptionTicket],
) -> Result<Vec<u8>> {
    let client_random = drbg.generate(32, b"client_hello_random")?;
    self.send_client_hello_with_resumption_tickets(&client_random, tickets)
}
/// Like `send_client_hello_with_resumption_tickets_at`, but draws the 32-byte
/// ClientHello random from `drbg` (label `client_hello_random`).
pub fn send_client_hello_with_resumption_tickets_at_auto(
    &mut self,
    drbg: &mut HmacDrbgSha256,
    tickets: &[ResumptionTicket],
    current_time_ms: u64,
) -> Result<Vec<u8>> {
    let client_random = drbg.generate(32, b"client_hello_random")?;
    self.send_client_hello_with_resumption_tickets_at(&client_random, tickets, current_time_ms)
}
/// Processes the server's reply to our ClientHello: a ServerHello or a
/// HelloRetryRequest.
///
/// A HelloRetryRequest does the following:
/// * records the requested group;
/// * restarts the transcript via `reset_transcript_for_hrr` and appends the HRR;
/// * returns the handshake to `Idle` so a second ClientHello can be sent.
///
/// A second HRR in the same handshake is rejected.
///
/// A regular ServerHello does the following:
/// * derives the key-exchange shared secret from the server key share, if one
///   is present, using the matching client private key;
/// * clears the HRR bookkeeping;
/// * appends the message, records the selected suite, and rebuilds the
///   transcript hash for that suite;
/// * moves to `ServerHelloReceived`.
///
/// # Errors
///
/// * `Error::StateError` if no ClientHello is outstanding, or if the client
///   private key for the server's key-share group is missing.
/// * `Error::ParseFailure` for a duplicate HelloRetryRequest.
/// * Any error from parsing the message or deriving the shared secret.
pub fn recv_server_hello(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ClientHelloSent {
        return Err(Error::StateError(
            "server hello can only be processed after client hello",
        ));
    }
    let parsed = parse_server_hello(msg)?;
    if parsed.hello_retry_request {
        if self.tls13_hrr_seen {
            return Err(Error::ParseFailure("duplicate hello retry request"));
        }
        self.tls13_hrr_seen = true;
        self.tls13_hrr_requested_group = parsed.requested_group;
        self.reset_transcript_for_hrr();
        self.append_transcript(msg);
        self.state = HandshakeState::Idle;
        return Ok(());
    }
    let selected_suite = parsed.suite;
    // Derive the shared secret before touching any connection state. A
    // missing client key share or a rejected server share then leaves the
    // HRR bookkeeping exactly as it was when the error is returned.
    let shared_secret = match parsed.key_share {
        Some(share) => Some(match share {
            Tls13ServerKeyShareParsed::X25519(peer_key_share) => {
                let private = self
                    .tls13_client_x25519_private
                    .clone()
                    .ok_or(Error::StateError(
                        "client x25519 key share must be available before server x25519 key share",
                    ))?;
                derive_tls13_x25519_shared_secret(private, &peer_key_share)?
            }
            Tls13ServerKeyShareParsed::Secp256r1(peer_uncompressed) => {
                let private = self.tls13_client_p256_private.as_ref().ok_or(
                    Error::StateError(
                        "client secp256r1 key share must be available before server secp256r1 key share",
                    ),
                )?;
                derive_tls13_p256_shared_secret(private, &peer_uncompressed)?
            }
            Tls13ServerKeyShareParsed::MlKem768(peer_key_share) => {
                let private =
                    self.tls13_client_mlkem768_private
                        .as_ref()
                        .ok_or(Error::StateError(
                            "client mlkem768 key share must be available before server mlkem768 key share",
                        ))?;
                derive_tls13_mlkem768_shared_secret(private, &peer_key_share)?
            }
            Tls13ServerKeyShareParsed::X25519MlKem768Hybrid { x25519, mlkem768 } => {
                let x25519_private = self
                    .tls13_client_x25519_private
                    .clone()
                    .ok_or(Error::StateError(
                        "client x25519 key share must be available before server hybrid key share",
                    ))?;
                let x25519_shared = derive_tls13_x25519_shared_secret(x25519_private, &x25519)?;
                let mlkem_private =
                    self.tls13_client_mlkem768_private
                        .as_ref()
                        .ok_or(Error::StateError(
                            "client mlkem768 key share must be available before server hybrid key share",
                        ))?;
                let mlkem_shared =
                    derive_tls13_mlkem768_shared_secret(mlkem_private, &mlkem768)?;
                combine_tls13_hybrid_shared_secret(&x25519_shared, &mlkem_shared)
            }
        }),
        None => None,
    };
    self.tls13_hrr_seen = false;
    self.tls13_hrr_requested_group = None;
    if let Some(secret) = shared_secret {
        self.tls13_shared_secret = Some(secret);
    }
    self.append_transcript(msg);
    self.selected_cipher_suite = Some(selected_suite);
    self.rebuild_transcript_hash_from_selected_suite();
    self.state = HandshakeState::ServerHelloReceived;
    Ok(())
}
/// Encodes a TLS 1.3 HelloRetryRequest asking the client to retry with
/// `requested_group`.
///
/// The message is a ServerHello whose random is `TLS13_HRR_RANDOM`. It has an
/// empty legacy session id, null compression, and a single `key_share`
/// extension carrying the requested group.
///
/// NOTE(review): RFC 8446 §4.1.4 requires an HRR to also carry
/// `supported_versions`; confirm whether the peer/parser relies on it.
pub fn build_hello_retry_request(suite: CipherSuite, requested_group: u16) -> Result<Vec<u8>> {
    let mut extensions = Vec::new();
    push_extension(
        &mut extensions,
        EXT_KEY_SHARE,
        &requested_group.to_be_bytes(),
    );

    let mut body = Vec::new();
    body.extend_from_slice(&legacy_wire_version(TlsVersion::Tls13));
    body.extend_from_slice(&TLS13_HRR_RANDOM);
    // legacy_session_id_echo: zero-length.
    body.push(0x00);
    body.extend_from_slice(&suite.to_u16().to_be_bytes());
    // legacy_compression_method: null.
    body.push(0x00);
    body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
    body.extend_from_slice(&extensions);
    Ok(encode_handshake_message(HANDSHAKE_SERVER_HELLO, &body))
}
/// Processes the server's EncryptedExtensions message.
///
/// The message is validated against what the client offered in its hello:
/// * a server_name acknowledgement must correspond to a configured name, and
///   is mandatory when `tls13_require_server_name_ack` is set;
/// * early_data acceptance must correspond to an early_data offer;
/// * an ALPN selection requires that the client offered ALPN, and must be one
///   of the offered protocols.
///
/// Every check runs before any connection state is updated. A rejected
/// message therefore leaves the early-data, server-name and ALPN fields
/// untouched.
///
/// On success, declined early data is reset, the negotiated values are
/// stored, the message is appended to the transcript, and the state becomes
/// `ServerEncryptedExtensionsReceived`.
///
/// # Errors
///
/// * `Error::StateError` unless a ServerHello has been received.
/// * `Error::ParseFailure` for a wrong message type, an unsolicited or missing
///   extension response, or an ALPN protocol the client did not offer.
pub fn recv_encrypted_extensions(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerHelloReceived {
        return Err(Error::StateError(
            "encrypted extensions can only be processed after server hello",
        ));
    }
    let (handshake_type, body) = parse_handshake_message(msg)?;
    if handshake_type != HANDSHAKE_ENCRYPTED_EXTENSIONS {
        return Err(Error::ParseFailure("invalid encrypted extensions type"));
    }
    let encrypted_extensions = parse_encrypted_extensions_body(body)?;

    // --- Validation: nothing below mutates `self` until every check passed. ---
    if encrypted_extensions.server_name_acknowledged && self.tls13_client_server_name.is_none()
    {
        return Err(Error::ParseFailure(
            "encrypted extensions contains unsolicited server_name acknowledgement",
        ));
    }
    if self.tls13_require_server_name_ack
        && self.tls13_client_server_name.is_some()
        && !encrypted_extensions.server_name_acknowledged
    {
        return Err(Error::ParseFailure(
            "encrypted extensions missing required server_name acknowledgement",
        ));
    }
    if encrypted_extensions.early_data_accepted
        && !self.tls13_early_data_offered_in_client_hello
    {
        return Err(Error::ParseFailure(
            "encrypted extensions contains unsolicited early_data acceptance",
        ));
    }
    if let Some(selected_protocol) = encrypted_extensions.selected_alpn_protocol.as_ref() {
        // RFC 8446 §4.2 / RFC 7301 §3.2: a server must not answer an extension
        // the client never sent, so any selection is invalid when we offered
        // no ALPN list.
        if self.tls13_client_alpn_protocols.is_empty() {
            return Err(Error::ParseFailure(
                "encrypted extensions contains unsolicited alpn selection",
            ));
        }
        if !self
            .tls13_client_alpn_protocols
            .contains(selected_protocol)
        {
            return Err(Error::ParseFailure(
                "encrypted extensions selected unsupported alpn protocol",
            ));
        }
    }

    // --- Commit ---
    self.tls13_early_data_accepted_in_encrypted_extensions =
        encrypted_extensions.early_data_accepted;
    if self.tls13_early_data_offered_in_client_hello
        && !encrypted_extensions.early_data_accepted
    {
        // The server declined 0-RTT: forget any early-data bookkeeping.
        self.tls13_early_data_accepted_psk = None;
        self.tls13_early_data_max_bytes = None;
        self.tls13_early_data_opened_bytes = 0;
        self.tls13_early_data_replay_window = DtlsReplayWindow::new();
    }
    self.tls13_server_name_acknowledged = encrypted_extensions.server_name_acknowledged;
    self.tls13_selected_alpn_protocol = encrypted_extensions.selected_alpn_protocol;
    self.append_transcript(msg);
    self.state = HandshakeState::ServerEncryptedExtensionsReceived;
    Ok(())
}
/// Encodes a TLS 1.3 CertificateRequest with an empty request context and a
/// `signature_algorithms` extension.
///
/// The extension lists ECDSA P-256/SHA-256, RSA-PSS-RSAE SHA-256 and SHA-384,
/// Ed25519, and ML-DSA-65, in that order.
pub fn build_certificate_request_message() -> Vec<u8> {
    let requested_sigalgs = [
        TLS13_SIGALG_ECDSA_SECP256R1_SHA256,
        TLS13_SIGALG_RSA_PSS_RSAE_SHA256,
        TLS13_SIGALG_RSA_PSS_RSAE_SHA384,
        TLS13_SIGALG_ED25519,
        TLS13_SIGALG_MLDSA65,
    ];
    // The u16 list length prefix, then one 2-byte code point per scheme.
    let list_len = requested_sigalgs.len() * 2;
    let mut sigalg_list = Vec::with_capacity(2 + list_len);
    sigalg_list.extend_from_slice(&(list_len as u16).to_be_bytes());
    for scheme in &requested_sigalgs {
        sigalg_list.extend_from_slice(&scheme.to_be_bytes());
    }

    let mut extensions = Vec::new();
    push_extension(&mut extensions, EXT_SIGNATURE_ALGORITHMS, &sigalg_list);

    let mut body = Vec::with_capacity(1 + 2 + extensions.len());
    // certificate_request_context: zero-length.
    body.push(0x00);
    body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
    body.extend_from_slice(&extensions);
    encode_handshake_message(HANDSHAKE_CERTIFICATE_REQUEST, &body)
}
/// Accepts a TLS 1.3 CertificateRequest that follows EncryptedExtensions.
///
/// The body is validated with `parse_certificate_request_body`, but its
/// contents are not kept. The raw message is appended to the transcript and
/// the state becomes `ServerCertificateRequestReceived`.
pub fn recv_certificate_request(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerEncryptedExtensionsReceived {
        return Err(Error::StateError(
            "certificate request can only be processed after encrypted extensions",
        ));
    }
    let (message_type, payload) = parse_handshake_message(msg)?;
    if message_type != HANDSHAKE_CERTIFICATE_REQUEST {
        return Err(Error::ParseFailure("invalid certificate request type"));
    }
    let _request = parse_certificate_request_body(payload)?;
    self.append_transcript(msg);
    self.state = HandshakeState::ServerCertificateRequestReceived;
    Ok(())
}
/// Encodes an EncryptedExtensions message with no extensions at all.
///
/// # Panics
///
/// Only if the policy encoder fails for an input with no ALPN protocol. That
/// input has nothing to validate, so a failure would be an internal bug.
pub fn build_encrypted_extensions() -> Vec<u8> {
    let encoded = Self::build_encrypted_extensions_with_policy(None, false, false);
    encoded.expect("empty encrypted extensions must always encode")
}
/// Encodes EncryptedExtensions carrying only the optional ALPN selection.
/// There is no server_name acknowledgement and no early_data acceptance.
pub fn build_encrypted_extensions_with_alpn(selected_alpn: Option<&[u8]>) -> Result<Vec<u8>> {
    let acknowledge_server_name = false;
    let accept_early_data = false;
    Self::build_encrypted_extensions_with_policy(
        selected_alpn,
        acknowledge_server_name,
        accept_early_data,
    )
}
/// Encodes EncryptedExtensions with an optional ALPN selection and optional
/// early_data acceptance. No server_name acknowledgement is included.
pub fn build_encrypted_extensions_with_alpn_and_early_data(
    selected_alpn: Option<&[u8]>,
    accept_early_data: bool,
) -> Result<Vec<u8>> {
    let acknowledge_server_name = false;
    Self::build_encrypted_extensions_with_policy(
        selected_alpn,
        acknowledge_server_name,
        accept_early_data,
    )
}
pub fn build_encrypted_extensions_with_policy(
selected_alpn: Option<&[u8]>,
acknowledge_server_name: bool,
accept_early_data: bool,
) -> Result<Vec<u8>> {
let mut body = Vec::new();
let mut extensions = Vec::new();
if let Some(protocol) = selected_alpn {
if protocol.is_empty() {
return Err(Error::InvalidLength("alpn protocol must not be empty"));
}
if protocol.len() > u8::MAX as usize {
return Err(Error::InvalidLength(
"alpn protocol length must not exceed 255 bytes",
));
}
let protocols = vec![protocol.to_vec()];
let extension_data = encode_alpn_extension_data(&protocols)?;
push_extension(&mut extensions, EXT_ALPN, &extension_data);
}
if acknowledge_server_name {
push_extension(&mut extensions, EXT_SERVER_NAME, &[]);
}
if accept_early_data {
push_extension(&mut extensions, EXT_EARLY_DATA, &[]);
}
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
Ok(encode_handshake_message(
HANDSHAKE_ENCRYPTED_EXTENSIONS,
&body,
))
}
/// Processes the server's TLS 1.3 Certificate message.
///
/// The leaf OCSP staple, if present, is stored and, when a verifier is
/// installed, checked; a staple is mandatory when `tls13_require_ocsp_staple`
/// is set. The chain is validated only when `tls13_require_certificate_auth`
/// is set. On success the message joins the transcript and the state becomes
/// `ServerCertificateReceived`.
///
/// # Errors
///
/// * `Error::StateError` unless EncryptedExtensions or CertificateRequest was
///   just processed.
/// * `Error::ParseFailure` for a wrong message type, a missing required
///   staple, or an expired/revoked staple.
/// * Any error from the verifier or from chain validation.
pub fn recv_certificate(&mut self, msg: &[u8]) -> Result<()> {
    let certificate_expected = self.state == HandshakeState::ServerEncryptedExtensionsReceived
        || self.state == HandshakeState::ServerCertificateRequestReceived;
    if !certificate_expected {
        return Err(Error::StateError(
            "certificate can only be processed after encrypted extensions/certificate request",
        ));
    }
    let (message_type, payload) = parse_handshake_message(msg)?;
    if message_type != HANDSHAKE_CERTIFICATE {
        return Err(Error::ParseFailure("invalid certificate type"));
    }
    let parsed = parse_certificate_body(payload)?;
    self.tls13_server_ocsp_staple = parsed.leaf_ocsp_staple.clone();
    self.tls13_server_ocsp_staple_verified = false;
    match (
        parsed.leaf_ocsp_staple.as_deref(),
        self.tls13_ocsp_staple_verifier,
    ) {
        (None, _) => {
            if self.tls13_require_ocsp_staple {
                return Err(Error::ParseFailure(
                    "certificate message missing required ocsp staple",
                ));
            }
        }
        (Some(staple), Some(verifier)) => match verifier(staple)? {
            Tls13OcspStapleVerification::Good => self.tls13_server_ocsp_staple_verified = true,
            Tls13OcspStapleVerification::Expired => {
                return Err(Error::ParseFailure("ocsp staple expired"));
            }
            Tls13OcspStapleVerification::Revoked => {
                return Err(Error::ParseFailure("ocsp staple revoked"));
            }
        },
        // NOTE(review): with no verifier installed, a present staple is
        // flagged as verified without being inspected. Confirm that callers
        // read this flag as "accepted" rather than "cryptographically checked".
        (Some(_), None) => self.tls13_server_ocsp_staple_verified = true,
    }
    if self.tls13_require_certificate_auth {
        self.validate_tls13_server_certificate_chain(&parsed.certificates)?;
    }
    self.append_transcript(msg);
    self.state = HandshakeState::ServerCertificateReceived;
    Ok(())
}
pub fn process_server_handshake_flight(&mut self, messages: &[Vec<u8>]) -> Result<()> {
if messages.len() < 5 {
return Err(Error::ParseFailure("server handshake flight is too short"));
}
let mut index = 0_usize;
self.recv_server_hello(&messages[index])?;
index += 1;
self.recv_encrypted_extensions(&messages[index])?;
index += 1;
let (next_type, _) = parse_handshake_message(&messages[index])?;
if next_type == HANDSHAKE_CERTIFICATE_REQUEST {
self.recv_certificate_request(&messages[index])?;
index += 1;
}
self.recv_certificate(&messages[index])?;
index += 1;
self.recv_certificate_verify(&messages[index])?;
index += 1;
self.derive_handshake_secret()?;
self.recv_finished_message(&messages[index])?;
index += 1;
if index != messages.len() {
return Err(Error::ParseFailure(
"unexpected trailing server handshake messages",
));
}
Ok(())
}
/// Drives a TLS 1.2 server flight through the client state machine.
///
/// The expected order is: ServerHello, Certificate, optional
/// ServerKeyExchange, optional CertificateRequest, ServerHelloDone. On
/// success the state becomes `ServerCertificateVerified`.
///
/// # Errors
///
/// * `Error::StateError` for a non-TLS-1.2 connection, or when no ClientHello
///   is outstanding.
/// * `Error::ParseFailure` for a short flight, a missing Certificate or
///   ServerHelloDone, out-of-order or duplicated optional messages, or
///   trailing messages.
/// * Any error from the individual message handlers.
pub fn process_tls12_server_handshake_flight(&mut self, messages: &[Vec<u8>]) -> Result<()> {
    if self.version != TlsVersion::Tls12 {
        return Err(Error::StateError(
            "tls12 server flight processing requires tls1.2 connection version",
        ));
    }
    if self.state != HandshakeState::ClientHelloSent {
        return Err(Error::StateError(
            "tls12 server flight can only be processed after client hello",
        ));
    }
    if messages.len() < 3 {
        return Err(Error::ParseFailure(
            "tls12 server handshake flight is too short",
        ));
    }
    let mut index = 0_usize;
    self.recv_server_hello(&messages[index])?;
    index += 1;
    let (next_type, _body) = parse_handshake_message(&messages[index])?;
    if next_type != HANDSHAKE_CERTIFICATE {
        return Err(Error::ParseFailure(
            "tls12 server handshake flight expected certificate after server hello",
        ));
    }
    self.recv_tls12_server_certificate(&messages[index])?;
    index += 1;
    // RFC 5246 §7.3/§7.4 fixes the order of the optional messages:
    // ServerKeyExchange, then CertificateRequest, each at most once.
    // Duplicates or reversed order are not consumed here. They reach
    // ServerHelloDone parsing, which rejects them as the wrong message type.
    if index < messages.len() {
        let (message_type, _body) = parse_handshake_message(&messages[index])?;
        if message_type == HANDSHAKE_SERVER_KEY_EXCHANGE {
            self.recv_tls12_server_key_exchange(&messages[index])?;
            index += 1;
        }
    }
    if index < messages.len() {
        let (message_type, _body) = parse_handshake_message(&messages[index])?;
        if message_type == HANDSHAKE_CERTIFICATE_REQUEST {
            self.recv_tls12_server_certificate_request(&messages[index])?;
            index += 1;
        }
    }
    if index >= messages.len() {
        return Err(Error::ParseFailure(
            "tls12 server handshake flight missing server hello done",
        ));
    }
    self.recv_tls12_server_hello_done(&messages[index])?;
    index += 1;
    if index != messages.len() {
        return Err(Error::ParseFailure(
            "unexpected trailing tls12 server handshake messages",
        ));
    }
    self.state = HandshakeState::ServerCertificateVerified;
    Ok(())
}
/// Records that the peer's TLS 1.2 ChangeCipherSpec arrived.
///
/// It is only legal once the server flight has been processed. The flag is
/// checked, and then cleared, by `process_tls12_client_handshake_flight`.
pub fn recv_tls12_change_cipher_spec(&mut self) -> Result<()> {
    let is_tls12 = self.version == TlsVersion::Tls12;
    if !is_tls12 {
        return Err(Error::StateError(
            "tls12 change cipher spec requires tls1.2 connection version",
        ));
    }
    let server_flight_done = self.state == HandshakeState::ServerCertificateVerified;
    if !server_flight_done {
        return Err(Error::StateError(
            "tls12 change cipher spec can only be processed after server handshake flight",
        ));
    }
    self.tls12_change_cipher_spec_seen = true;
    Ok(())
}
pub fn process_tls12_client_handshake_flight(&mut self, messages: &[Vec<u8>]) -> Result<()> {
if self.version != TlsVersion::Tls12 {
return Err(Error::StateError(
"tls12 client flight processing requires tls1.2 connection version",
));
}
if self.state != HandshakeState::ServerCertificateVerified {
return Err(Error::StateError(
"tls12 client flight can only be processed after server handshake flight",
));
}
if messages.len() < 2 {
return Err(Error::ParseFailure(
"tls12 client handshake flight is too short",
));
}
let mut index = 0_usize;
let (next_type, _body) = parse_handshake_message(&messages[index])?;
if next_type != HANDSHAKE_CLIENT_KEY_EXCHANGE {
return Err(Error::ParseFailure(
"tls12 client handshake flight expected client key exchange first",
));
}
self.recv_tls12_client_key_exchange(&messages[index])?;
index += 1;
if index < messages.len() {
let (message_type, _body) = parse_handshake_message(&messages[index])?;
if message_type == HANDSHAKE_CERTIFICATE_VERIFY {
self.recv_tls12_client_certificate_verify(&messages[index])?;
index += 1;
}
}
if !self.tls12_change_cipher_spec_seen {
return Err(Error::ParseFailure(
"tls12 expected change cipher spec before finished",
));
}
if index >= messages.len() {
return Err(Error::ParseFailure(
"tls12 client handshake flight missing finished message",
));
}
self.recv_tls12_client_finished(&messages[index])?;
index += 1;
if index != messages.len() {
return Err(Error::ParseFailure(
"unexpected trailing tls12 client handshake messages",
));
}
self.tls12_change_cipher_spec_seen = false;
self.state = HandshakeState::Finished;
Ok(())
}
/// Runs `process_tls12_server_handshake_flight`. On failure it also tries to
/// build the matching TLS 1.2 alert packet.
///
/// The error is returned together with the alert, or `None` if building the
/// alert itself failed.
pub fn process_tls12_server_handshake_flight_with_alert(
    &mut self,
    messages: &[Vec<u8>],
) -> core::result::Result<(), (Error, Option<Vec<u8>>)> {
    self.process_tls12_server_handshake_flight(messages)
        .map_err(|error| {
            let alert_packet = self.send_tls12_alert_for_handshake_error(&error).ok();
            (error, alert_packet)
        })
}
/// Runs `process_tls12_client_handshake_flight`. On failure it also tries to
/// build the matching TLS 1.2 alert packet.
///
/// The error is returned together with the alert, or `None` if building the
/// alert itself failed.
pub fn process_tls12_client_handshake_flight_with_alert(
    &mut self,
    messages: &[Vec<u8>],
) -> core::result::Result<(), (Error, Option<Vec<u8>>)> {
    self.process_tls12_client_handshake_flight(messages)
        .map_err(|error| {
            let alert_packet = self.send_tls12_alert_for_handshake_error(&error).ok();
            (error, alert_packet)
        })
}
#[must_use]
pub fn tls12_alert_for_handshake_error(error: &Error) -> (AlertLevel, AlertDescription) {
let description = match error {
Error::StateError(message) => {
if message.contains("can only be processed")
|| message.contains("expected")
|| message.contains("missing")
{
AlertDescription::UnexpectedMessage
} else {
AlertDescription::InternalError
}
}
Error::ParseFailure(message) | Error::InvalidLength(message) => {
if message.contains("expected")
|| message.contains("missing")
|| message.contains("unexpected trailing")
|| message.contains("invalid")
|| message.contains("malformed")
|| message.contains("must be empty")
|| message.contains("must not be empty")
{
AlertDescription::UnexpectedMessage
} else {
AlertDescription::IllegalParameter
}
}
Error::InvalidEncoding(_message) => AlertDescription::IllegalParameter,
Error::UnsupportedFeature(_message) | Error::CryptoFailure(_message) => {
AlertDescription::HandshakeFailure
}
};
(AlertLevel::Fatal, description)
}
/// Handles the TLS 1.2 server Certificate message that follows ServerHello.
///
/// The certificate list is always parsed. Chain validation runs only when
/// `tls13_require_certificate_auth` is set; the TLS 1.3 path uses the same
/// flag. On success the message joins the transcript and the state becomes
/// `ServerCertificateReceived`.
fn recv_tls12_server_certificate(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerHelloReceived {
        return Err(Error::StateError(
            "tls12 certificate can only be processed after server hello",
        ));
    }
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_CERTIFICATE {
        return Err(Error::ParseFailure(
            "invalid tls12 certificate message type",
        ));
    }
    let chain = parse_tls12_certificate_list(payload)?;
    if self.tls13_require_certificate_auth {
        self.validate_tls13_server_certificate_chain(&chain)?;
    }
    self.append_transcript(msg);
    self.state = HandshakeState::ServerCertificateReceived;
    Ok(())
}
/// Handles an optional TLS 1.2 ServerKeyExchange after the server Certificate.
///
/// The body is structurally validated but not retained. The message is
/// appended to the transcript, and the state stays `ServerCertificateReceived`.
fn recv_tls12_server_key_exchange(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerCertificateReceived {
        return Err(Error::StateError(
            "tls12 server key exchange can only be processed after certificate",
        ));
    }
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_SERVER_KEY_EXCHANGE {
        return Err(Error::ParseFailure(
            "invalid tls12 server key exchange message type",
        ));
    }
    let _key_exchange = parse_tls12_server_key_exchange_body(payload)?;
    self.append_transcript(msg);
    Ok(())
}
/// Handles an optional TLS 1.2 CertificateRequest after the server Certificate.
///
/// Only non-emptiness of the body is checked. The message is appended to the
/// transcript, and the state stays `ServerCertificateReceived`.
fn recv_tls12_server_certificate_request(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerCertificateReceived {
        return Err(Error::StateError(
            "tls12 certificate request can only be processed after certificate",
        ));
    }
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_CERTIFICATE_REQUEST {
        return Err(Error::ParseFailure(
            "invalid tls12 certificate request message type",
        ));
    }
    if payload.is_empty() {
        return Err(Error::ParseFailure(
            "tls12 certificate request body must not be empty",
        ));
    }
    self.append_transcript(msg);
    Ok(())
}
/// Handles the TLS 1.2 ServerHelloDone that closes the server flight.
///
/// The body must be empty. The message is appended to the transcript; the
/// caller advances the state afterwards.
fn recv_tls12_server_hello_done(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerCertificateReceived {
        return Err(Error::StateError(
            "tls12 server hello done can only be processed after certificate flight",
        ));
    }
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_SERVER_HELLO_DONE {
        return Err(Error::ParseFailure(
            "invalid tls12 server hello done message type",
        ));
    }
    let has_body = !payload.is_empty();
    if has_body {
        return Err(Error::ParseFailure(
            "tls12 server hello done body must be empty",
        ));
    }
    self.append_transcript(msg);
    Ok(())
}
/// Handles a TLS 1.2 ClientKeyExchange.
///
/// This checks the message type and that the body is non-empty, then appends
/// the message to the transcript. State ordering is enforced by the caller,
/// `process_tls12_client_handshake_flight`.
fn recv_tls12_client_key_exchange(&mut self, msg: &[u8]) -> Result<()> {
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_CLIENT_KEY_EXCHANGE {
        return Err(Error::ParseFailure(
            "invalid tls12 client key exchange message type",
        ));
    }
    if payload.is_empty() {
        return Err(Error::ParseFailure(
            "tls12 client key exchange body must not be empty",
        ));
    }
    self.append_transcript(msg);
    Ok(())
}
/// Handles a TLS 1.2 client CertificateVerify.
///
/// The body is structurally parsed; no signature is verified here. The
/// message is then appended to the transcript.
fn recv_tls12_client_certificate_verify(&mut self, msg: &[u8]) -> Result<()> {
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_CERTIFICATE_VERIFY {
        return Err(Error::ParseFailure(
            "invalid tls12 client certificate verify message type",
        ));
    }
    let _fields = parse_tls12_certificate_verify_body(payload)?;
    self.append_transcript(msg);
    Ok(())
}
/// Handles a TLS 1.2 client Finished.
///
/// Only the type and non-emptiness are checked; `verify_data` is not compared
/// here. The message is then appended to the transcript.
fn recv_tls12_client_finished(&mut self, msg: &[u8]) -> Result<()> {
    let (kind, payload) = parse_handshake_message(msg)?;
    if kind != HANDSHAKE_FINISHED {
        return Err(Error::ParseFailure("invalid tls12 finished message type"));
    }
    if payload.is_empty() {
        return Err(Error::ParseFailure("tls12 finished body must not be empty"));
    }
    self.append_transcript(msg);
    Ok(())
}
/// Encodes a TLS 1.3 Certificate message carrying one DER certificate and no
/// OCSP staple.
pub fn build_certificate_message(certificate_der: &[u8]) -> Result<Vec<u8>> {
    let no_staple: Option<&[u8]> = None;
    Self::build_certificate_message_with_ocsp_staple(certificate_der, no_staple)
}
pub fn build_certificate_message_with_ocsp_staple(
certificate_der: &[u8],
ocsp_staple: Option<&[u8]>,
) -> Result<Vec<u8>> {
if certificate_der.is_empty() {
return Err(Error::InvalidLength("certificate der must not be empty"));
}
if certificate_der.len() > 0x00FF_FFFF {
return Err(Error::InvalidLength("certificate der is too large"));
}
let certificate_extensions = if let Some(staple) = ocsp_staple {
encode_certificate_entry_status_request_extension(staple)?
} else {
Vec::new()
};
let mut body = Vec::new();
body.push(0x00); let cert_entry_len = 3 + certificate_der.len() + 2 + certificate_extensions.len();
let list_len = cert_entry_len as u32;
body.extend_from_slice(&list_len.to_be_bytes()[1..4]);
let cert_len = certificate_der.len() as u32;
body.extend_from_slice(&cert_len.to_be_bytes()[1..4]);
body.extend_from_slice(certificate_der);
body.extend_from_slice(&(certificate_extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&certificate_extensions);
Ok(encode_handshake_message(HANDSHAKE_CERTIFICATE, &body))
}
/// Processes the server's TLS 1.3 CertificateVerify.
///
/// The signature must be non-empty and use a supported scheme. When
/// `tls13_require_certificate_auth` is set, the server chain must already be
/// validated and the signature is verified. On success the message joins the
/// transcript and the state becomes `ServerCertificateVerified`.
pub fn recv_certificate_verify(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::ServerCertificateReceived {
        return Err(Error::StateError(
            "certificate verify can only be processed after certificate",
        ));
    }
    let (message_type, payload) = parse_handshake_message(msg)?;
    if message_type != HANDSHAKE_CERTIFICATE_VERIFY {
        return Err(Error::ParseFailure("invalid certificate verify type"));
    }
    let (scheme, signature) = parse_certificate_verify_fields(payload)?;
    if signature.is_empty() {
        return Err(Error::ParseFailure(
            "certificate verify signature must not be empty",
        ));
    }
    if !tls13_supported_certificate_verify_signature_scheme(scheme) {
        return Err(Error::UnsupportedFeature(
            "unsupported tls13 certificate verify signature scheme",
        ));
    }
    let must_authenticate = self.tls13_require_certificate_auth;
    if must_authenticate && !self.tls13_server_certificate_chain_validated {
        return Err(Error::StateError(
            "certificate verify requires validated server certificate chain",
        ));
    }
    if must_authenticate {
        self.verify_tls13_server_certificate_verify_signature(scheme, signature)?;
    }
    self.append_transcript(msg);
    self.state = HandshakeState::ServerCertificateVerified;
    Ok(())
}
pub fn build_certificate_verify_message(
signature_scheme: u16,
signature: &[u8],
) -> Result<Vec<u8>> {
if signature.is_empty() {
return Err(Error::InvalidLength(
"certificate verify signature must not be empty",
));
}
if signature.len() > usize::from(u16::MAX) {
return Err(Error::InvalidLength(
"certificate verify signature is too large",
));
}
let mut body = Vec::new();
body.extend_from_slice(&signature_scheme.to_be_bytes());
body.extend_from_slice(&(signature.len() as u16).to_be_bytes());
body.extend_from_slice(signature);
Ok(encode_handshake_message(
HANDSHAKE_CERTIFICATE_VERIFY,
&body,
))
}
/// Derives the handshake secret material for the negotiated version and
/// installs the keys that depend on it.
///
/// Allowed only in `ServerHelloReceived` or `ServerCertificateVerified`.
/// The material is derived from the current transcript hash with the
/// negotiated hash algorithm, then:
/// * passed to `install_traffic_keys` and `install_tls13_finished_key`;
/// * stored in `handshake_secret`;
/// * used to move the state to `KeysDerived`.
///
/// Returns the first 32 bytes of the material. If the material is shorter
/// than 32 bytes, the result is zero-padded; if longer (e.g. a 48-byte
/// SHA-384 output, presumably), it is truncated.
///
/// # Errors
///
/// `Error::StateError` when called in any other state. Also propagates
/// errors from derivation or key installation.
pub fn derive_handshake_secret(&mut self) -> Result<[u8; 32]> {
    if self.state != HandshakeState::ServerHelloReceived
        && self.state != HandshakeState::ServerCertificateVerified
    {
        return Err(Error::StateError(
            "cannot derive handshake secret before server hello",
        ));
    }
    let transcript_hash = self.transcript_hash();
    let hash_algorithm = self.negotiated_hash_algorithm();
    let secret_material = match self.version {
        // NOTE(review): when no (EC)DHE/KEM shared secret was established,
        // the transcript hash is substituted as the key-exchange input.
        // Confirm this fallback is intended (e.g. for PSK-only modelling).
        TlsVersion::Tls13 | TlsVersion::Dtls13 => derive_tls13_handshake_secret(
            hash_algorithm,
            self.tls13_shared_secret
                .as_ref()
                .map_or(&transcript_hash, |secret| secret),
            self.selected_cipher_suite,
        )?,
        // TLS/DTLS 1.2: HKDF-extract over the transcript hash, then a 32-byte
        // PRF expansion labelled "handshake secret". This looks like a
        // simplified model rather than the RFC 5246 master-secret
        // derivation. TODO confirm.
        TlsVersion::Tls12 | TlsVersion::Dtls12 => {
            let prk = hkdf_extract_for_hash(hash_algorithm, &transcript_hash);
            tls12_prf_for_hash(
                hash_algorithm,
                &prk,
                b"handshake secret",
                &transcript_hash,
                32,
            )?
        }
        // TLS 1.0/1.1: HKDF-extract then a 32-byte HKDF-expand with the same label.
        TlsVersion::Tls10 | TlsVersion::Tls11 => {
            let prk = hkdf_extract_for_hash(hash_algorithm, &transcript_hash);
            hkdf_expand_for_hash(hash_algorithm, &prk, b"handshake secret", 32)?
        }
    };
    self.install_traffic_keys(hash_algorithm, &secret_material, &transcript_hash)?;
    self.install_tls13_finished_key(hash_algorithm, &secret_material)?;
    self.handshake_secret = Some(secret_material.clone());
    // Copy at most 32 bytes into the fixed-size return value.
    let mut secret = [0_u8; 32];
    let copy_len = secret_material.len().min(32);
    secret[..copy_len].copy_from_slice(&secret_material[..copy_len]);
    self.state = HandshakeState::KeysDerived;
    Ok(secret)
}
/// Verifies the peer's Finished `verify_data` and, on success, completes the
/// handshake.
///
/// The expected value comes from `compute_expected_finished` and is compared
/// in constant time. On a match:
/// * `verify_data` is appended to the transcript;
/// * TLS 1.3 application traffic keys are installed;
/// * the state becomes `Finished`.
///
/// NOTE(review): only `verify_data` is appended, not the framed Finished
/// handshake message; `recv_finished_message` passes just the body. Confirm
/// the peer's transcript is computed the same way.
///
/// # Errors
///
/// * `Error::StateError` unless in `KeysDerived` or `ServerCertificateVerified`.
/// * `Error::CryptoFailure` when `verify_data` does not match.
/// * Any error from computing the expected value or installing keys.
pub fn finish(&mut self, verify_data: &[u8]) -> Result<()> {
    /// Compares two byte strings without a data-dependent early exit, so
    /// timing does not reveal how many leading bytes of a forged value were
    /// correct. Lengths are public and may short-circuit.
    fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
        left.len() == right.len()
            && left
                .iter()
                .zip(right)
                .fold(0_u8, |diff, (a, b)| diff | (a ^ b))
                == 0
    }

    if self.state != HandshakeState::KeysDerived
        && self.state != HandshakeState::ServerCertificateVerified
    {
        return Err(Error::StateError("finish must follow key derivation"));
    }
    let expected = self.compute_expected_finished()?;
    if !constant_time_eq(verify_data, &expected) {
        return Err(Error::CryptoFailure("finished verify_data mismatch"));
    }
    self.append_transcript(verify_data);
    self.install_tls13_application_traffic_keys()?;
    self.state = HandshakeState::Finished;
    Ok(())
}
/// Parses a framed Finished handshake message and passes its body to `finish`.
///
/// The type is checked first, then the state, then the body length against
/// the expected verify_data length.
pub fn recv_finished_message(&mut self, msg: &[u8]) -> Result<()> {
    let (message_type, verify_data) = parse_handshake_message(msg)?;
    if message_type != HANDSHAKE_FINISHED {
        return Err(Error::ParseFailure("invalid finished type"));
    }
    let keys_ready = self.state == HandshakeState::KeysDerived
        || self.state == HandshakeState::ServerCertificateVerified;
    if !keys_ready {
        return Err(Error::StateError("finish must follow key derivation"));
    }
    let expected_len = self.compute_expected_finished()?.len();
    if verify_data.len() != expected_len {
        return Err(Error::ParseFailure("finished verify_data length mismatch"));
    }
    self.finish(verify_data)
}
/// Encodes this side's Finished handshake message from
/// `compute_finished_verify_data`.
pub fn build_finished_message(&self) -> Result<Vec<u8>> {
    self.compute_finished_verify_data()
        .map(|verify_data| encode_handshake_message(HANDSHAKE_FINISHED, &verify_data))
}
pub fn build_new_session_ticket_message(
ticket_lifetime: u32,
ticket_age_add: u32,
ticket_nonce: &[u8],
ticket: &[u8],
) -> Result<Vec<u8>> {
if ticket_nonce.len() > usize::from(u8::MAX) {
return Err(Error::InvalidLength("ticket nonce is too large"));
}
if ticket.len() > usize::from(u16::MAX) {
return Err(Error::InvalidLength("ticket identity is too large"));
}
let mut body = Vec::new();
body.extend_from_slice(&ticket_lifetime.to_be_bytes());
body.extend_from_slice(&ticket_age_add.to_be_bytes());
body.push(ticket_nonce.len() as u8);
body.extend_from_slice(ticket_nonce);
body.extend_from_slice(&(ticket.len() as u16).to_be_bytes());
body.extend_from_slice(ticket);
body.extend_from_slice(&0_u16.to_be_bytes()); Ok(encode_handshake_message(
HANDSHAKE_NEW_SESSION_TICKET,
&body,
))
}
/// Accepts a post-handshake NewSessionTicket once the handshake is `Finished`.
///
/// The body is validated with `parse_new_session_ticket_body`, but nothing
/// parsed is kept. The raw message is appended to the transcript.
///
/// NOTE(review): RFC 8446 §4.4.1 transcripts cover handshake messages only
/// through Finished. Confirm that appending post-handshake tickets here is
/// intentional.
pub fn recv_new_session_ticket_message(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::Finished {
        return Err(Error::StateError(
            "new session ticket requires finished handshake state",
        ));
    }
    let (message_type, payload) = parse_handshake_message(msg)?;
    if message_type != HANDSHAKE_NEW_SESSION_TICKET {
        return Err(Error::ParseFailure("invalid new session ticket type"));
    }
    let _ticket = parse_new_session_ticket_body(payload)?;
    self.append_transcript(msg);
    Ok(())
}
/// Encodes a KeyUpdate message. The single body byte is 1 when the peer is
/// asked to update too, and 0 otherwise.
pub fn build_key_update_message(request_update: bool) -> Vec<u8> {
    let request = [u8::from(request_update)];
    encode_handshake_message(HANDSHAKE_KEY_UPDATE, &request)
}
/// Processes a KeyUpdate received after the handshake completed.
///
/// The body must be exactly one byte with value 0 or 1. The traffic keys are
/// then rotated via `update_tls13_traffic_keys`, and the message is appended
/// to the transcript.
///
/// NOTE(review): a request value of 1 does not trigger a reply KeyUpdate here.
pub fn recv_key_update_message(&mut self, msg: &[u8]) -> Result<()> {
    if self.state != HandshakeState::Finished {
        return Err(Error::StateError(
            "key update requires finished handshake state",
        ));
    }
    let (message_type, payload) = parse_handshake_message(msg)?;
    if message_type != HANDSHAKE_KEY_UPDATE {
        return Err(Error::ParseFailure("invalid key update type"));
    }
    if !matches!(payload, [0] | [1]) {
        return Err(Error::ParseFailure("invalid key update request value"));
    }
    self.update_tls13_traffic_keys()?;
    self.append_transcript(msg);
    Ok(())
}
/// Returns the current handshake transcript hash, as produced by
/// `snapshot_hash()` on the running transcript hasher. Takes `&self`, so the
/// transcript is not advanced.
#[must_use]
pub fn transcript_hash(&self) -> Vec<u8> {
    self.transcript_hash.snapshot_hash()
}
/// Returns the cipher suite recorded from the peer's ServerHello, or `None`
/// before one has been processed.
#[must_use]
pub fn selected_cipher_suite(&self) -> Option<CipherSuite> {
    self.selected_cipher_suite
}
pub fn build_server_hello(
version: TlsVersion,
suite: CipherSuite,
random: &[u8],
) -> Result<Vec<u8>> {
if random.len() != 32 {
return Err(Error::InvalidLength("server hello random must be 32 bytes"));
}
let body = encode_server_hello_body(version, suite, random)?;
Ok(encode_handshake_message(HANDSHAKE_SERVER_HELLO, &body))
}
/// Encodes a ServerHello that carries a key share for `named_group` with the
/// given key-exchange bytes. The random must be exactly 32 bytes.
pub fn build_server_hello_with_key_share(
    version: TlsVersion,
    suite: CipherSuite,
    random: &[u8],
    named_group: u16,
    key_exchange: &[u8],
) -> Result<Vec<u8>> {
    if random.len() != 32 {
        return Err(Error::InvalidLength("server hello random must be 32 bytes"));
    }
    let server_share = Some((named_group, key_exchange));
    let body = encode_server_hello_body_with_key_share(version, suite, random, server_share)?;
    Ok(encode_handshake_message(HANDSHAKE_SERVER_HELLO, &body))
}
/// Like `build_server_hello`, but draws the 32-byte random from `drbg`
/// (label `server_hello_random`).
pub fn build_server_hello_auto(
    version: TlsVersion,
    suite: CipherSuite,
    drbg: &mut HmacDrbgSha256,
) -> Result<Vec<u8>> {
    let server_random = drbg.generate(32, b"server_hello_random")?;
    Self::build_server_hello(version, suite, &server_random)
}
/// Parses a ClientHello and returns only the cipher suites it offers.
///
/// # Errors
/// Propagates any ClientHello parse error.
pub fn parse_client_hello_cipher_suites(msg: &[u8]) -> Result<Vec<CipherSuite>> {
    let hello = parse_client_hello_info(msg)?;
    Ok(hello.offered_cipher_suites)
}
/// Associated-function wrapper around the module-level ClientHello parser.
///
/// # Errors
/// Propagates any ClientHello parse error.
pub fn parse_client_hello_info(msg: &[u8]) -> Result<ClientHelloInfo> {
    let info = parse_client_hello_info(msg)?;
    Ok(info)
}
/// Returns the content a TLS 1.3 server signs in `CertificateVerify`.
///
/// The content is built for the given transcript hash by
/// `build_tls13_server_certificate_verify_message`.
#[must_use]
pub fn tls13_server_certificate_verify_content(transcript_hash: &[u8]) -> Vec<u8> {
    build_tls13_server_certificate_verify_message(transcript_hash)
}
/// Chooses a cipher suite both peers support.
///
/// The ClientHello is parsed, then `pick_intersection_suite` chooses a suite
/// from `server_preferred` for `version`.
///
/// # Errors
/// Propagates parse errors and any error from `pick_intersection_suite`.
pub fn select_cipher_suite_from_client_hello(
    client_hello: &[u8],
    server_preferred: &[CipherSuite],
    version: TlsVersion,
) -> Result<CipherSuite> {
    parse_client_hello_info(client_hello)
        .and_then(|hello| pick_intersection_suite(&hello, server_preferred, version))
}
/// Negotiates a cipher suite against `client_hello` and builds the ServerHello.
///
/// # Errors
/// Propagates suite-selection errors first, then ServerHello construction
/// errors (for example a `server_random` that is not 32 bytes).
pub fn build_server_hello_for_client(
    version: TlsVersion,
    client_hello: &[u8],
    server_random: &[u8],
    server_preferred: &[CipherSuite],
) -> Result<Vec<u8>> {
    Self::select_cipher_suite_from_client_hello(client_hello, server_preferred, version)
        .and_then(|chosen| Self::build_server_hello(version, chosen, server_random))
}
/// Like `build_server_hello_for_client`, but draws the server random from `drbg`.
///
/// The 32 random bytes are generated before the ClientHello is inspected.
///
/// # Errors
/// Propagates DRBG failures and any negotiation or encoding error.
pub fn build_server_hello_for_client_auto(
    version: TlsVersion,
    client_hello: &[u8],
    server_preferred: &[CipherSuite],
    drbg: &mut HmacDrbgSha256,
) -> Result<Vec<u8>> {
    let fresh_random = drbg.generate(32, b"server_hello_random")?;
    Self::build_server_hello_for_client(version, client_hello, &fresh_random, server_preferred)
}
/// Public entry point for the Finished `verify_data` this side expects.
///
/// # Errors
/// Propagates any error from `compute_expected_finished`.
pub fn compute_finished_verify_data(&self) -> Result<Vec<u8>> {
    let verify_data = self.compute_expected_finished()?;
    Ok(verify_data)
}
/// Rotates both TLS 1.3 application traffic secrets one generation forward.
///
/// Each secret becomes `HKDF-Expand-Label(secret, "traffic upd", "", Hash.length)`.
/// The record-protection keys derived from the new secrets are installed.
/// Only after that succeeds are the stored secrets replaced and both record
/// sequence numbers reset to zero.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3, the handshake has
/// not finished, or either application traffic secret is missing.
/// Key-schedule and key-installation errors are propagated.
pub fn update_tls13_traffic_keys(&mut self) -> Result<()> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 traffic key update is only valid for TLS 1.3",
        ));
    }
    if self.state != HandshakeState::Finished {
        return Err(Error::StateError(
            "tls13 traffic key update requires finished handshake",
        ));
    }
    let hash = self.negotiated_hash_algorithm();
    let secret_len = hash.output_len();
    let (updated_client, updated_server) = {
        let Some(current_client) = self.tls13_client_application_traffic_secret.as_deref() else {
            return Err(Error::StateError(
                "tls13 application client traffic secret is not installed",
            ));
        };
        let Some(current_server) = self.tls13_server_application_traffic_secret.as_deref() else {
            return Err(Error::StateError(
                "tls13 application server traffic secret is not installed",
            ));
        };
        // Tuple fields are evaluated left to right: client first, then server.
        (
            tls13_expand_label_for_hash(hash, current_client, b"traffic upd", &[], secret_len)?,
            tls13_expand_label_for_hash(hash, current_server, b"traffic upd", &[], secret_len)?,
        )
    };
    self.install_tls13_record_protection_keys(hash, &updated_client, &updated_server)?;
    self.tls13_client_application_traffic_secret = Some(updated_client);
    self.tls13_server_application_traffic_secret = Some(updated_server);
    // New keys start a fresh nonce space, so both sequence counters restart.
    self.client_sequence = 0;
    self.server_sequence = 0;
    Ok(())
}
/// Derives the QUIC v1 Initial secrets from the client's Destination Connection ID.
///
/// `initial_secret = HKDF-Extract(TLS13_QUIC_V1_INITIAL_SALT, dcid)` (SHA-256).
/// The client and server secrets expand it with the labels `"client in"` and
/// `"server in"`, each to 32 bytes.
///
/// # Errors
/// `Error::InvalidLength` for an empty connection id. HKDF-Expand-Label
/// errors are propagated.
pub fn derive_tls13_quic_initial_secrets_v1(
    destination_connection_id: &[u8],
) -> Result<Tls13QuicInitialSecrets> {
    if destination_connection_id.is_empty() {
        return Err(Error::InvalidLength(
            "quic destination connection id must not be empty",
        ));
    }
    let initial_secret =
        hkdf_extract_sha256(&TLS13_QUIC_V1_INITIAL_SALT, destination_connection_id).to_vec();
    let expand_initial = |label: &[u8]| {
        tls13_expand_label_for_hash(HashAlgorithm::Sha256, &initial_secret, label, &[], 32)
    };
    let client_initial_secret = expand_initial(b"client in")?;
    let server_initial_secret = expand_initial(b"server in")?;
    Ok(Tls13QuicInitialSecrets {
        initial_secret,
        client_initial_secret,
        server_initial_secret,
    })
}
/// Derives QUIC packet-protection material from a TLS 1.3 traffic secret.
///
/// Produces `key` ("quic key", `key_len` bytes), `iv` ("quic iv", 12 bytes)
/// and `header_protection_key` ("quic hp", `header_protection_key_len` bytes),
/// each via HKDF-Expand-Label with an empty context.
///
/// # Errors
/// `Error::InvalidLength` when either requested key length is zero.
/// HKDF-Expand-Label errors are propagated.
pub fn derive_tls13_quic_packet_protection_keys(
    hash_algorithm: HashAlgorithm,
    traffic_secret: &[u8],
    key_len: usize,
    header_protection_key_len: usize,
) -> Result<Tls13QuicPacketProtectionKeys> {
    // 12 bytes is the nonce length of every QUIC v1 AEAD.
    const QUIC_IV_LEN: usize = 12;
    if key_len == 0 {
        return Err(Error::InvalidLength(
            "quic key length must be greater than zero",
        ));
    }
    if header_protection_key_len == 0 {
        return Err(Error::InvalidLength(
            "quic header protection key length must be greater than zero",
        ));
    }
    let expand = |label: &[u8], out_len: usize| {
        tls13_expand_label_for_hash(hash_algorithm, traffic_secret, label, &[], out_len)
    };
    // Struct fields are evaluated in the order written: key, iv, hp.
    Ok(Tls13QuicPacketProtectionKeys {
        key: expand(b"quic key", key_len)?,
        iv: expand(b"quic iv", QUIC_IV_LEN)?,
        header_protection_key: expand(b"quic hp", header_protection_key_len)?,
    })
}
pub fn tls13_quic_traffic_secret_snapshot(&self) -> Result<Tls13QuicTrafficSecretSnapshot> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"quic traffic secret snapshot is only defined for TLS 1.3",
));
}
let client_handshake_secret =
self.tls13_client_handshake_traffic_secret
.clone()
.ok_or(Error::StateError(
"tls13 client handshake traffic secret is not installed",
))?;
let server_handshake_secret =
self.tls13_server_handshake_traffic_secret
.clone()
.ok_or(Error::StateError(
"tls13 server handshake traffic secret is not installed",
))?;
let client_application_secret = self
.tls13_client_application_traffic_secret
.clone()
.ok_or(Error::StateError(
"tls13 client application traffic secret is not installed",
))?;
let server_application_secret = self
.tls13_server_application_traffic_secret
.clone()
.ok_or(Error::StateError(
"tls13 server application traffic secret is not installed",
))?;
Ok(Tls13QuicTrafficSecretSnapshot {
client_handshake_secret,
server_handshake_secret,
client_application_secret,
server_application_secret,
})
}
/// Computes the next-generation QUIC application secrets ("quic ku").
///
/// Both current application traffic secrets are expanded with the
/// `"quic ku"` label to the negotiated hash length. The connection itself is
/// not modified.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3 or either
/// application traffic secret is missing. HKDF-Expand-Label errors are propagated.
pub fn derive_tls13_quic_next_traffic_secrets(&self) -> Result<Tls13QuicNextTrafficSecrets> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "quic key update secrets are only defined for TLS 1.3",
        ));
    }
    let hash_algorithm = self.negotiated_hash_algorithm();
    let secret_len = hash_algorithm.output_len();
    let Some(current_client) = self.tls13_client_application_traffic_secret.as_deref() else {
        return Err(Error::StateError(
            "tls13 application client traffic secret is not installed",
        ));
    };
    let Some(current_server) = self.tls13_server_application_traffic_secret.as_deref() else {
        return Err(Error::StateError(
            "tls13 application server traffic secret is not installed",
        ));
    };
    let next_generation = |secret: &[u8]| {
        tls13_expand_label_for_hash(hash_algorithm, secret, b"quic ku", &[], secret_len)
    };
    Ok(Tls13QuicNextTrafficSecrets {
        client_next_application_secret: next_generation(current_client)?,
        server_next_application_secret: next_generation(current_server)?,
    })
}
pub fn export_quic_keying_material(
&self,
label: &[u8],
context: &[u8],
len: usize,
) -> Result<Vec<u8>> {
if !label.starts_with(b"EXPORTER-QUIC ") {
return Err(Error::StateError(
"quic exporter requires label prefix EXPORTER-QUIC ",
));
}
self.export_keying_material(label, context, len)
}
/// Derives TLS 1.3 exported keying material (RFC 8446 §7.5).
///
/// Computes
/// `TLS-Exporter(label, context, len) =
///  HKDF-Expand-Label(Derive-Secret(exporter_master_secret, label, ""),
///                    "exporter", Hash(context), len)`.
/// `Derive-Secret(S, label, "")` is `HKDF-Expand-Label(S, label, Hash(""), Hash.length)`.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3, the handshake has
/// not finished, or no exporter master secret is installed. Errors from
/// `tls13_expand_label_for_hash` are propagated.
pub fn export_keying_material(
    &self,
    label: &[u8],
    context: &[u8],
    len: usize,
) -> Result<Vec<u8>> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "key exporter is currently only modeled for TLS 1.3",
        ));
    }
    if self.state != HandshakeState::Finished {
        return Err(Error::StateError(
            "key exporter requires finished handshake state",
        ));
    }
    let hash_algorithm = self.negotiated_hash_algorithm();
    let hash_len = hash_algorithm.output_len();
    let exporter_master =
        self.tls13_exporter_master_secret
            .as_ref()
            .ok_or(Error::StateError(
                "tls13 exporter master secret is not installed",
            ))?;
    // Step 1: Derive-Secret(exporter_master, label, ""). The caller's label is
    // used here, and the context is Hash(""), the transcript hash of no
    // messages. It is not Hash(context).
    let empty_transcript_hash = hash_bytes_for_algorithm(hash_algorithm, &[]);
    let derived_secret = tls13_expand_label_for_hash(
        hash_algorithm,
        exporter_master,
        label,
        &empty_transcript_hash,
        hash_len,
    )?;
    // Step 2: expand with the fixed "exporter" label over Hash(context).
    let context_hash = hash_bytes_for_algorithm(hash_algorithm, context);
    tls13_expand_label_for_hash(hash_algorithm, &derived_secret, b"exporter", &context_hash, len)
}
pub fn tls13_resumption_master_secret(&self) -> Result<Vec<u8>> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"resumption master secret is only defined for TLS 1.3",
));
}
if self.state != HandshakeState::Finished {
return Err(Error::StateError(
"resumption master secret requires finished handshake state",
));
}
self.tls13_resumption_master_secret
.clone()
.ok_or(Error::StateError(
"tls13 resumption master secret is not installed",
))
}
pub fn derive_tls13_resumption_psk(&self, ticket_nonce: &[u8]) -> Result<Vec<u8>> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"resumption psk derivation is only defined for TLS 1.3",
));
}
if ticket_nonce.is_empty() {
return Err(Error::InvalidLength("ticket nonce must not be empty"));
}
let hash_algorithm = self.negotiated_hash_algorithm();
let hash_len = hash_algorithm.output_len();
let resumption_master =
self.tls13_resumption_master_secret
.as_ref()
.ok_or(Error::StateError(
"tls13 resumption master secret is not installed",
))?;
tls13_expand_label_for_hash(
hash_algorithm,
resumption_master,
b"resumption",
ticket_nonce,
hash_len,
)
}
/// Issues a resumption ticket with issue time `0` and lifetime `u64::MAX`.
///
/// Presumably this makes the ticket effectively non-expiring under
/// `ticket_age_matches_policy`. Confirm against that function.
///
/// # Errors
/// Same as `issue_tls13_resumption_ticket_with_time`.
pub fn issue_tls13_resumption_ticket(
    &self,
    drbg: &mut HmacDrbgSha256,
    age_add: u32,
) -> Result<ResumptionTicket> {
    const ISSUED_AT_MS: u64 = 0;
    const LIFETIME_MS: u64 = u64::MAX;
    self.issue_tls13_resumption_ticket_with_time(drbg, age_add, ISSUED_AT_MS, LIFETIME_MS)
}
/// Issues a ticket via `issue_tls13_resumption_ticket`, stores a copy in
/// `ticket_store`, and returns the ticket.
///
/// # Errors
/// Propagates issuance errors. Nothing is stored on failure.
pub fn issue_tls13_resumption_ticket_into_store(
    &self,
    drbg: &mut HmacDrbgSha256,
    age_add: u32,
    ticket_store: &mut TicketStore,
) -> Result<ResumptionTicket> {
    let issued = self.issue_tls13_resumption_ticket(drbg, age_add)?;
    ticket_store.insert(issued.clone());
    Ok(issued)
}
/// Issues a TLS 1.3 resumption ticket stamped with the given issue time and lifetime.
///
/// A fresh 16-byte nonce is drawn from `drbg`. The 16-byte ticket identity is
/// `HKDF-Expand-Label(resumption_master_secret, "ticket", nonce, 16)`. The
/// obfuscated age starts at `age_add` (age zero), and `max_early_data_size`
/// defaults to `TLS_MAX_RECORD_PLAINTEXT_LEN`.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3 or the handshake
/// has not finished. DRBG and key-schedule errors are propagated.
pub fn issue_tls13_resumption_ticket_with_time(
    &self,
    drbg: &mut HmacDrbgSha256,
    age_add: u32,
    issued_at_ms: u64,
    lifetime_ms: u64,
) -> Result<ResumptionTicket> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "resumption ticket issuance is only defined for TLS 1.3",
        ));
    }
    if self.state != HandshakeState::Finished {
        return Err(Error::StateError(
            "resumption ticket issuance requires finished handshake state",
        ));
    }
    let ticket_nonce = drbg.generate(16, b"tls13_ticket_nonce")?;
    let hash_algorithm = self.negotiated_hash_algorithm();
    let resumption_master = self.tls13_resumption_master_secret()?;
    let identity = tls13_expand_label_for_hash(
        hash_algorithm,
        &resumption_master,
        b"ticket",
        &ticket_nonce,
        16,
    )?;
    Ok(ResumptionTicket {
        identity,
        ticket_nonce,
        obfuscated_ticket_age: age_add,
        age_add,
        issued_at_ms,
        lifetime_ms,
        max_early_data_size: TLS_MAX_RECORD_PLAINTEXT_LEN as u32,
        consumed: false,
    })
}
pub fn issue_tls13_resumption_ticket_with_time_and_early_data(
&self,
drbg: &mut HmacDrbgSha256,
age_add: u32,
issued_at_ms: u64,
lifetime_ms: u64,
max_early_data_size: u32,
) -> Result<ResumptionTicket> {
let mut ticket =
self.issue_tls13_resumption_ticket_with_time(drbg, age_add, issued_at_ms, lifetime_ms)?;
ticket.max_early_data_size = max_early_data_size;
Ok(ticket)
}
/// Issues a timed resumption ticket, stores a copy in `ticket_store`, and
/// returns the ticket.
///
/// # Errors
/// Propagates issuance errors. Nothing is stored on failure.
pub fn issue_tls13_resumption_ticket_with_time_into_store(
    &self,
    drbg: &mut HmacDrbgSha256,
    age_add: u32,
    issued_at_ms: u64,
    lifetime_ms: u64,
    ticket_store: &mut TicketStore,
) -> Result<ResumptionTicket> {
    let issued =
        self.issue_tls13_resumption_ticket_with_time(drbg, age_add, issued_at_ms, lifetime_ms)?;
    ticket_store.insert(issued.clone());
    Ok(issued)
}
/// Issues a timed ticket with an explicit early-data limit and stores a copy
/// in `ticket_store`.
///
/// # Errors
/// Propagates issuance errors. Nothing is stored on failure.
pub fn issue_tls13_resumption_ticket_with_time_and_early_data_into_store(
    &self,
    drbg: &mut HmacDrbgSha256,
    age_add: u32,
    issued_at_ms: u64,
    lifetime_ms: u64,
    max_early_data_size: u32,
    ticket_store: &mut TicketStore,
) -> Result<ResumptionTicket> {
    let issued = self.issue_tls13_resumption_ticket_with_time_and_early_data(
        drbg,
        age_add,
        issued_at_ms,
        lifetime_ms,
        max_early_data_size,
    )?;
    ticket_store.insert(issued.clone());
    Ok(issued)
}
/// Computes a TLS 1.3 resumption PSK binder (RFC 8446 §4.2.11.2, §7.1).
///
/// ```text
/// early_secret = HKDF-Extract(0, psk)
/// binder_key   = Derive-Secret(early_secret, "res binder", "")
///              = HKDF-Expand-Label(early_secret, "res binder", Hash(""), Hash.length)
/// finished_key = HKDF-Expand-Label(binder_key, "finished", "", Hash.length)
/// binder       = HMAC(finished_key, Hash(truncated_client_hello))
/// ```
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3.
/// `Error::InvalidLength` for an empty `psk` or an empty ClientHello.
/// Key-schedule errors are propagated.
pub fn compute_tls13_psk_binder(
    &self,
    psk: &[u8],
    truncated_client_hello: &[u8],
) -> Result<Vec<u8>> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "psk binder computation is only defined for TLS 1.3",
        ));
    }
    if psk.is_empty() {
        return Err(Error::InvalidLength("psk must not be empty"));
    }
    if truncated_client_hello.is_empty() {
        return Err(Error::InvalidLength(
            "truncated client hello must not be empty",
        ));
    }
    let hash_algorithm = self.negotiated_hash_algorithm();
    let hash_len = hash_algorithm.output_len();
    let early_secret = hkdf_extract_for_hash(hash_algorithm, psk);
    // Derive-Secret with no messages uses Transcript-Hash("") = Hash("") as the
    // HKDF-Expand-Label context. It is not an empty context (see the RFC 8448
    // "binder key" test vector).
    let empty_transcript_hash = hash_bytes_for_algorithm(hash_algorithm, &[]);
    let binder_key = tls13_expand_label_for_hash(
        hash_algorithm,
        &early_secret,
        b"res binder",
        &empty_transcript_hash,
        hash_len,
    )?;
    // The finished_key expansion, by contrast, really does use an empty context.
    let finished_key =
        tls13_expand_label_for_hash(hash_algorithm, &binder_key, b"finished", &[], hash_len)?;
    let transcript_hash = hash_bytes_for_algorithm(hash_algorithm, truncated_client_hello);
    Ok(finished_hmac_for_hash(
        hash_algorithm,
        &finished_key,
        &transcript_hash,
    ))
}
/// Recomputes the binder for `psk` and compares it with `received_binder`
/// in constant time.
///
/// # Errors
/// Propagates errors from `compute_tls13_psk_binder`. A mismatch is
/// `Ok(false)`, not an error.
pub fn verify_tls13_psk_binder(
    &self,
    psk: &[u8],
    truncated_client_hello: &[u8],
    received_binder: &[u8],
) -> Result<bool> {
    self.compute_tls13_psk_binder(psk, truncated_client_hello)
        .map(|expected_binder| constant_time_eq(&expected_binder, received_binder))
}
/// Verifies the first PSK binder carried in a full ClientHello.
///
/// The first binder is extracted. The binder is then recomputed over the
/// ClientHello as rewritten by `zero_client_hello_psk_binders`.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3.
/// `Error::InvalidLength` for an empty `psk`. ClientHello extraction and
/// binder computation errors are propagated.
pub fn verify_client_hello_psk_binder(&self, client_hello: &[u8], psk: &[u8]) -> Result<bool> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "psk binder verification is only defined for TLS 1.3",
        ));
    }
    if psk.is_empty() {
        return Err(Error::InvalidLength("psk must not be empty"));
    }
    let offered_binder = extract_first_psk_binder_from_client_hello(client_hello)?;
    let binder_free_hello = zero_client_hello_psk_binders(client_hello)?;
    self.verify_tls13_psk_binder(psk, &binder_free_hello, &offered_binder)
}
/// Verifies the ClientHello binder against `ticket`, evaluating the age check
/// at the ticket's own issue time with the maximum allowed skew.
///
/// # Errors
/// Same as `verify_client_hello_psk_binder_for_ticket_with_age`.
pub fn verify_client_hello_psk_binder_for_ticket(
    &self,
    client_hello: &[u8],
    ticket: &ResumptionTicket,
) -> Result<bool> {
    let evaluated_at_ms = ticket.issued_at_ms;
    let widest_skew_ms = u32::MAX;
    self.verify_client_hello_psk_binder_for_ticket_with_age(
        client_hello,
        ticket,
        evaluated_at_ms,
        widest_skew_ms,
    )
}
/// Checks that the ClientHello's first PSK offer matches `ticket` and carries a valid binder.
///
/// Returns `Ok(false)` in any of these cases:
/// - the first offered identity differs from the ticket identity,
/// - no obfuscated age is offered,
/// - the ticket is consumed,
/// - `ticket_age_matches_policy` rejects the offered age.
///
/// Otherwise the resumption PSK is derived and the first binder is verified.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3. Parse, PSK
/// derivation and binder errors are propagated.
pub fn verify_client_hello_psk_binder_for_ticket_with_age(
    &self,
    client_hello: &[u8],
    ticket: &ResumptionTicket,
    current_time_ms: u64,
    max_skew_ms: u32,
) -> Result<bool> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "psk binder verification is only defined for TLS 1.3",
        ));
    }
    let info = parse_client_hello_info(client_hello)?;
    let identity_matches = info
        .extensions
        .psk_identities
        .first()
        .map_or(false, |offered| offered.as_slice() == ticket.identity.as_slice());
    let first_offered_age = info.extensions.psk_obfuscated_ticket_ages.first().copied();
    match first_offered_age {
        // `&&` short-circuits, so the age policy is consulted only for a
        // matching, unconsumed ticket.
        Some(offered_age)
            if identity_matches
                && !ticket.consumed
                && ticket_age_matches_policy(ticket, offered_age, current_time_ms, max_skew_ms) =>
        {
            let psk = self.derive_tls13_resumption_psk(&ticket.ticket_nonce)?;
            self.verify_client_hello_psk_binder(client_hello, &psk)
        }
        _ => Ok(false),
    }
}
/// Finds the first ticket that matches an offered PSK identity with a valid binder.
///
/// Offered identities are scanned in ClientHello order, and `tickets` in slice
/// order for each identity. A candidate ticket must match the identity, be
/// unconsumed, and pass `ticket_age_matches_policy`. Its binder is then
/// compared in constant time against the binder at the same position.
///
/// Returns the index into `tickets` of the first match, or `None`.
///
/// # Errors
/// `Error::StateError` when the connection is not TLS 1.3. Parse,
/// binder-normalisation, PSK-derivation and binder errors are propagated.
pub fn verify_client_hello_psk_binder_for_tickets_with_age(
    &self,
    client_hello: &[u8],
    tickets: &[ResumptionTicket],
    current_time_ms: u64,
    max_skew_ms: u32,
) -> Result<Option<usize>> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "psk binder verification is only defined for TLS 1.3",
        ));
    }
    if tickets.is_empty() {
        return Ok(None);
    }
    let info = parse_client_hello_info(client_hello)?;
    let offer = &info.extensions;
    if offer.psk_identities.is_empty() || offer.psk_binders.is_empty() {
        return Ok(None);
    }
    let binder_free_hello = zero_client_hello_psk_binders(client_hello)?;
    // Zipping stops at the shortest list, which skips exactly those identities
    // that lack an obfuscated age or a binder at the same position.
    let positional_offers = offer
        .psk_identities
        .iter()
        .zip(offer.psk_obfuscated_ticket_ages.iter().copied())
        .zip(offer.psk_binders.iter());
    for ((identity, offered_age), received_binder) in positional_offers {
        let candidates = tickets.iter().enumerate().filter(|(_, ticket)| {
            identity.as_slice() == ticket.identity.as_slice()
                && !ticket.consumed
                && ticket_age_matches_policy(ticket, offered_age, current_time_ms, max_skew_ms)
        });
        for (ticket_idx, ticket) in candidates {
            let psk = self.derive_tls13_resumption_psk(&ticket.ticket_nonce)?;
            let expected_binder = self.compute_tls13_psk_binder(&psk, &binder_free_hello)?;
            if constant_time_eq(&expected_binder, received_binder) {
                return Ok(Some(ticket_idx));
            }
        }
    }
    Ok(None)
}
/// Verifies the ClientHello PSK offer against `tickets`, then applies `usage_policy`.
///
/// Under `TicketUsagePolicy::SingleUse`, the matched ticket is marked consumed.
///
/// # Errors
/// Propagates errors from `verify_client_hello_psk_binder_for_tickets_with_age`.
pub fn verify_and_apply_client_hello_psk_policy(
    &self,
    client_hello: &[u8],
    tickets: &mut [ResumptionTicket],
    current_time_ms: u64,
    max_skew_ms: u32,
    usage_policy: TicketUsagePolicy,
) -> Result<Option<usize>> {
    let matched_index = self.verify_client_hello_psk_binder_for_tickets_with_age(
        client_hello,
        tickets,
        current_time_ms,
        max_skew_ms,
    )?;
    if usage_policy == TicketUsagePolicy::SingleUse {
        if let Some(used) = matched_index.and_then(|index| tickets.get_mut(index)) {
            used.consumed = true;
        }
    }
    Ok(matched_index)
}
/// `verify_and_apply_client_hello_psk_policy` over the tickets held in `ticket_store`.
///
/// # Errors
/// Same as `verify_and_apply_client_hello_psk_policy`.
pub fn verify_and_apply_client_hello_psk_policy_with_store(
    &self,
    client_hello: &[u8],
    ticket_store: &mut TicketStore,
    current_time_ms: u64,
    max_skew_ms: u32,
    usage_policy: TicketUsagePolicy,
) -> Result<Option<usize>> {
    let stored_tickets = ticket_store.tickets_mut();
    self.verify_and_apply_client_hello_psk_policy(
        client_hello,
        stored_tickets,
        current_time_ms,
        max_skew_ms,
        usage_policy,
    )
}
/// Decides whether to accept TLS 1.3 0-RTT data offered in `client_hello`.
///
/// First resets the per-handshake early-data bookkeeping: the offered flag,
/// the encrypted-extensions acceptance flag, the opened byte count and the
/// early-data transcript. It then validates the PSK offer against `tickets`
/// via `verify_and_apply_client_hello_psk_policy`, which marks the matched
/// ticket consumed under `TicketUsagePolicy::SingleUse`.
///
/// Early data is accepted only when all of the following hold:
/// - a ticket matched,
/// - the client sent the `early_data` extension,
/// - the matched ticket is the client's *first* offered PSK identity
///   (RFC 8446 §4.2.10: "the server MUST ... select the first key offered"),
/// - the ticket allows a non-zero `max_early_data_size`.
///
/// On acceptance, the resumption PSK and byte budget are recorded and the
/// early-data replay window is reset. On rejection, both are cleared.
///
/// # Errors
/// Propagates ClientHello parse, PSK-policy and PSK-derivation errors.
/// `Error::StateError` if the matched ticket index is out of range.
pub fn accept_tls13_early_data_with_ticket_policy(
    &mut self,
    client_hello: &[u8],
    tickets: &mut [ResumptionTicket],
    current_time_ms: u64,
    max_skew_ms: u32,
    usage_policy: TicketUsagePolicy,
) -> Result<bool> {
    let info = parse_client_hello_info(client_hello)?;
    self.tls13_early_data_offered_in_client_hello = info.extensions.early_data_offered;
    self.tls13_early_data_accepted_in_encrypted_extensions = false;
    self.tls13_early_data_opened_bytes = 0;
    self.reset_tls13_early_data_transcript_to_client_hello(client_hello);
    let matched = self.verify_and_apply_client_hello_psk_policy(
        client_hello,
        tickets,
        current_time_ms,
        max_skew_ms,
        usage_policy,
    )?;
    let Some(ticket_index) = matched else {
        self.tls13_early_data_accepted_psk = None;
        self.tls13_early_data_max_bytes = None;
        return Ok(false);
    };
    if !self.tls13_early_data_offered_in_client_hello {
        self.tls13_early_data_accepted_psk = None;
        self.tls13_early_data_max_bytes = None;
        return Ok(false);
    }
    let ticket = tickets
        .get(ticket_index)
        .ok_or(Error::StateError("matched ticket index is out of range"))?;
    // A PSK matched through a later identity may still be used for the
    // handshake, but never for 0-RTT data.
    let is_first_offered_identity = info
        .extensions
        .psk_identities
        .first()
        .map_or(false, |first| first.as_slice() == ticket.identity.as_slice());
    if !is_first_offered_identity || ticket.max_early_data_size == 0 {
        self.tls13_early_data_accepted_psk = None;
        self.tls13_early_data_max_bytes = None;
        return Ok(false);
    }
    let psk = self.derive_tls13_resumption_psk(&ticket.ticket_nonce)?;
    self.tls13_early_data_accepted_psk = Some(psk);
    self.tls13_early_data_max_bytes = Some(ticket.max_early_data_size);
    self.tls13_early_data_replay_window = DtlsReplayWindow::new();
    Ok(true)
}
/// `accept_tls13_early_data_with_ticket_policy` over the tickets held in `ticket_store`.
///
/// # Errors
/// Same as `accept_tls13_early_data_with_ticket_policy`.
pub fn accept_tls13_early_data_with_ticket_store(
    &mut self,
    client_hello: &[u8],
    ticket_store: &mut TicketStore,
    current_time_ms: u64,
    max_skew_ms: u32,
    usage_policy: TicketUsagePolicy,
) -> Result<bool> {
    let stored_tickets = ticket_store.tickets_mut();
    self.accept_tls13_early_data_with_ticket_policy(
        client_hello,
        stored_tickets,
        current_time_ms,
        max_skew_ms,
        usage_policy,
    )
}
/// Seals `plaintext` with the client write key/IV and returns the protected record.
///
/// Preconditions, checked in this order:
/// - the handshake state is `Finished`,
/// - the plaintext is no longer than `max_record_plaintext_len`,
/// - `client_sequence` is not exhausted,
/// - a cipher suite is selected,
/// - the client write key and IV are installed.
///
/// The AEAD nonce is built from the client write IV and the current
/// `client_sequence`. The sequence number is incremented only after a
/// successful seal.
///
/// # Errors
/// `Error::StateError` / `Error::InvalidLength` for the preconditions above.
/// AEAD and cipher-construction errors are propagated.
pub fn seal_record(&mut self, plaintext: &[u8], aad: &[u8]) -> Result<ProtectedRecord> {
if self.state != HandshakeState::Finished {
return Err(Error::StateError(
"cannot seal record before handshake finish",
));
}
if plaintext.len() > self.max_record_plaintext_len {
return Err(Error::InvalidLength(
"record plaintext exceeds configured limit",
));
}
// Refusing to seal at u64::MAX guarantees the sequence number, and hence
// the nonce, never wraps back to a previously used value.
if self.client_sequence == u64::MAX {
return Err(Error::StateError("client record sequence exhausted"));
}
let suite = self.selected_cipher_suite.ok_or(Error::StateError(
"cipher suite must be selected before sealing records",
))?;
let key = self
.client_write_key
.ok_or(Error::StateError("client write key is not installed"))?;
let iv = self
.client_write_iv
.ok_or(Error::StateError("client write iv is not installed"))?;
let nonce = build_record_nonce(&iv, self.client_sequence);
let (ciphertext, tag) = match suite {
CipherSuite::TlsChacha20Poly1305Sha256 => {
chacha20_poly1305_encrypt(&key, &nonce, aad, plaintext)?
}
CipherSuite::TlsAes128GcmSha256 | CipherSuite::TlsAes256GcmSha384 => {
// TLS 1.3 AES suites take the key-length prefix reported by the suite.
let key_len = suite.tls13_traffic_key_len().ok_or(Error::StateError(
"tls 1.3 aes suites must define traffic key length",
))?;
let cipher = AesCipher::new(&key[..key_len])?;
aes_gcm_encrypt(&cipher, &nonce, aad, plaintext)?
}
CipherSuite::TlsEcdheRsaWithAes128GcmSha256
| CipherSuite::TlsEcdheRsaWithAes256GcmSha384 => {
// NOTE(review): both TLS 1.2 GCM suites use only the first 16 key bytes
// (AES-128), although *_AES_256_GCM_* calls for a 32-byte key. Confirm
// against the TLS 1.2 key-block installation. Any change must also be
// made in open_record/open_own_record.
let cipher = AesCipher::new(&key[..16])?;
aes_gcm_encrypt(&cipher, &nonce, aad, plaintext)?
}
};
let record = ProtectedRecord {
sequence: self.client_sequence,
ciphertext,
tag,
};
self.client_sequence = self.client_sequence.wrapping_add(1);
Ok(record)
}
/// Seals one TLS 1.3 early-data (0-RTT) record under keys derived from `psk`.
///
/// Checks, in this order:
/// - the connection is TLS 1.3,
/// - `psk` is non-empty,
/// - the plaintext fits `max_record_plaintext_len`,
/// - the state is `ClientHelloSent`.
///
/// The key/IV come from `derive_tls13_early_data_record_key_iv`. The nonce is
/// built from the IV and the caller-supplied `sequence`; no sequence counter
/// on `self` is read or advanced. ChaCha20-Poly1305 is used when
/// `tls13_early_data_uses_chacha20_poly1305()` is true; otherwise AES-GCM
/// is used with the full derived key.
///
/// # Errors
/// `Error::StateError` / `Error::InvalidLength` for the checks above.
/// Key-derivation and AEAD errors are propagated.
pub fn seal_tls13_early_data_record(
&self,
psk: &[u8],
plaintext: &[u8],
aad: &[u8],
sequence: u64,
) -> Result<ProtectedRecord> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"tls13 early-data records require TLS 1.3 connection",
));
}
if psk.is_empty() {
return Err(Error::InvalidLength(
"tls13 early-data psk must not be empty",
));
}
if plaintext.len() > self.max_record_plaintext_len {
return Err(Error::InvalidLength(
"record plaintext exceeds configured limit",
));
}
if self.state != HandshakeState::ClientHelloSent {
return Err(Error::StateError(
"tls13 early-data may only be sealed in ClientHelloSent state",
));
}
let (key, iv) = self.derive_tls13_early_data_record_key_iv(psk)?;
let nonce = build_record_nonce(&iv, sequence);
let (ciphertext, tag) = if self.tls13_early_data_uses_chacha20_poly1305() {
// The ChaCha20 API takes a fixed 32-byte key; reject any other derived length.
let key_32: [u8; 32] = key.as_slice().try_into().map_err(|_| {
Error::InvalidLength("tls13 early-data chacha key must be 32 bytes")
})?;
chacha20_poly1305_encrypt(&key_32, &nonce, aad, plaintext)?
} else {
let cipher = AesCipher::new(&key)?;
aes_gcm_encrypt(&cipher, &nonce, aad, plaintext)?
};
Ok(ProtectedRecord {
sequence,
ciphertext,
tag,
})
}
/// Authenticates and decrypts one TLS 1.3 early-data (0-RTT) record.
///
/// Checks, in this order:
/// - the connection is TLS 1.3,
/// - `psk` is non-empty,
/// - the state is `ClientHelloSent`, `ServerHelloReceived` or `Finished`,
/// - when `tls13_early_data_require_acceptance` is set, `psk` equals the PSK
///   recorded by ticket-policy acceptance (compared in constant time).
///
/// The record is then decrypted with the early-data key/IV derived from `psk`,
/// using a nonce built from `record.sequence`.
///
/// The anti-replay window, when enabled, is consulted only *after* the record
/// authenticates. Marking a sequence number before authentication would let
/// forged records burn sequence numbers, so genuine records carrying them
/// would later be rejected as replays. See RFC 6347 §4.1.2.6 / RFC 9147
/// §4.5.1: the window is updated only for records that pass authentication.
///
/// After that, the plaintext must fit `max_record_plaintext_len`. When a byte
/// budget is installed, the running total in `tls13_early_data_opened_bytes`
/// must not exceed `tls13_early_data_max_bytes`.
///
/// Each policy rejection increments the matching `tls13_early_data_telemetry`
/// counter; successes increment `accepted_records`.
///
/// # Errors
/// `Error::StateError` / `Error::InvalidLength` for the policy failures above.
/// AEAD failures are propagated after being counted in `rejected_decrypt_or_policy`.
/// Key-derivation errors are propagated.
pub fn open_tls13_early_data_record(
    &mut self,
    psk: &[u8],
    record: &ProtectedRecord,
    aad: &[u8],
) -> Result<Vec<u8>> {
    if !self.version.uses_tls13_handshake_semantics() {
        self.tls13_early_data_telemetry.rejected_invalid_input = self
            .tls13_early_data_telemetry
            .rejected_invalid_input
            .saturating_add(1);
        return Err(Error::StateError(
            "tls13 early-data records require TLS 1.3 connection",
        ));
    }
    if psk.is_empty() {
        self.tls13_early_data_telemetry.rejected_invalid_input = self
            .tls13_early_data_telemetry
            .rejected_invalid_input
            .saturating_add(1);
        return Err(Error::InvalidLength(
            "tls13 early-data psk must not be empty",
        ));
    }
    if !matches!(
        self.state,
        HandshakeState::ClientHelloSent
            | HandshakeState::ServerHelloReceived
            | HandshakeState::Finished
    ) {
        self.tls13_early_data_telemetry.rejected_decrypt_or_policy = self
            .tls13_early_data_telemetry
            .rejected_decrypt_or_policy
            .saturating_add(1);
        return Err(Error::StateError(
            "tls13 early-data may only be opened before encrypted extensions",
        ));
    }
    if self.tls13_early_data_require_acceptance {
        let Some(accepted_psk) = self.tls13_early_data_accepted_psk.as_deref() else {
            self.tls13_early_data_telemetry.rejected_missing_acceptance = self
                .tls13_early_data_telemetry
                .rejected_missing_acceptance
                .saturating_add(1);
            return Err(Error::StateError(
                "tls13 early-data requires prior ticket-policy acceptance",
            ));
        };
        if !constant_time_eq(accepted_psk, psk) {
            self.tls13_early_data_telemetry.rejected_psk_mismatch = self
                .tls13_early_data_telemetry
                .rejected_psk_mismatch
                .saturating_add(1);
            return Err(Error::StateError(
                "tls13 early-data psk does not match accepted ticket context",
            ));
        }
    }
    // Authenticate first. Forged or corrupted records must not touch the replay window.
    let (key, iv) = self.derive_tls13_early_data_record_key_iv(psk)?;
    let nonce = build_record_nonce(&iv, record.sequence);
    let plaintext = if self.tls13_early_data_uses_chacha20_poly1305() {
        let key_32: [u8; 32] = key.as_slice().try_into().map_err(|_| {
            Error::InvalidLength("tls13 early-data chacha key must be 32 bytes")
        })?;
        chacha20_poly1305_decrypt(&key_32, &nonce, aad, &record.ciphertext, &record.tag)
            .map_err(|err| {
                self.tls13_early_data_telemetry.rejected_decrypt_or_policy = self
                    .tls13_early_data_telemetry
                    .rejected_decrypt_or_policy
                    .saturating_add(1);
                err
            })?
    } else {
        let cipher = AesCipher::new(&key)?;
        aes_gcm_decrypt(&cipher, &nonce, aad, &record.ciphertext, &record.tag).map_err(
            |err| {
                self.tls13_early_data_telemetry.rejected_decrypt_or_policy = self
                    .tls13_early_data_telemetry
                    .rejected_decrypt_or_policy
                    .saturating_add(1);
                err
            },
        )?
    };
    // The record is authentic; now enforce (and record) sequence uniqueness.
    if self.tls13_early_data_anti_replay_enabled
        && !self
            .tls13_early_data_replay_window
            .check_and_mark(record.sequence)
    {
        self.tls13_early_data_telemetry.rejected_replay_or_too_old = self
            .tls13_early_data_telemetry
            .rejected_replay_or_too_old
            .saturating_add(1);
        return Err(Error::StateError(
            "tls13 early-data replay detected or sequence is too old",
        ));
    }
    if plaintext.len() > self.max_record_plaintext_len {
        self.tls13_early_data_telemetry.rejected_decrypt_or_policy = self
            .tls13_early_data_telemetry
            .rejected_decrypt_or_policy
            .saturating_add(1);
        return Err(Error::InvalidLength(
            "record plaintext exceeds configured limit",
        ));
    }
    if let Some(max_bytes) = self.tls13_early_data_max_bytes {
        let next_total = self
            .tls13_early_data_opened_bytes
            .saturating_add(plaintext.len() as u64);
        if next_total > u64::from(max_bytes) {
            self.tls13_early_data_telemetry.rejected_decrypt_or_policy = self
                .tls13_early_data_telemetry
                .rejected_decrypt_or_policy
                .saturating_add(1);
            return Err(Error::InvalidLength(
                "tls13 early-data exceeds accepted ticket max_early_data_size",
            ));
        }
        self.tls13_early_data_opened_bytes = next_total;
    }
    self.tls13_early_data_telemetry.accepted_records = self
        .tls13_early_data_telemetry
        .accepted_records
        .saturating_add(1);
    Ok(plaintext)
}
/// Wraps `content` as a TLS 1.3 inner plaintext, seals it as early data, and
/// encodes the wire packet.
///
/// When `aad` is empty, the record-header AAD produced by
/// `build_tls13_record_aad` is used instead, sized as the inner plaintext plus
/// a 16-byte AEAD tag. That default AAD is built even when `aad` is non-empty.
///
/// # Errors
/// Propagates AAD construction, sealing and encoding errors.
pub fn seal_tls13_early_data_record_packet(
    &self,
    psk: &[u8],
    content: &[u8],
    content_type: u8,
    aad: &[u8],
    sequence: u64,
    padding_len: usize,
) -> Result<Vec<u8>> {
    let inner_plaintext = encode_tls13_inner_plaintext(content, content_type, padding_len);
    let sealed_len = inner_plaintext.len().saturating_add(16);
    let default_aad = self.build_tls13_record_aad(sealed_len)?;
    let effective_aad: &[u8] = if aad.is_empty() { &default_aad } else { aad };
    let sealed = self.seal_tls13_early_data_record(psk, &inner_plaintext, effective_aad, sequence)?;
    self.encode_tls13_record_packet(&sealed)
}
/// Decodes an early-data wire packet, opens it, and returns `(content, content_type)`
/// from the inner plaintext.
///
/// When `aad` is empty, the record-header AAD for the decoded ciphertext plus
/// tag length is used.
///
/// # Errors
/// Propagates decode, AAD construction, open and inner-plaintext errors.
pub fn open_tls13_early_data_record_packet(
    &mut self,
    psk: &[u8],
    packet: &[u8],
    aad: &[u8],
    sequence: u64,
) -> Result<(Vec<u8>, u8)> {
    let decoded = self.decode_tls13_record_packet(packet, sequence)?;
    let sealed_len = decoded.ciphertext.len().saturating_add(decoded.tag.len());
    let default_aad = self.build_tls13_record_aad(sealed_len)?;
    let effective_aad: &[u8] = if aad.is_empty() { &default_aad } else { aad };
    let inner_plaintext = self.open_tls13_early_data_record(psk, &decoded, effective_aad)?;
    decode_tls13_inner_plaintext(&inner_plaintext)
}
pub fn open_tls13_early_data_client_flight_packets(
&mut self,
psk: &[u8],
packets: &[Vec<u8>],
first_sequence: u64,
) -> Result<Vec<Vec<u8>>> {
let mut out = Vec::with_capacity(packets.len());
for (idx, packet) in packets.iter().enumerate() {
let sequence = first_sequence.saturating_add(idx as u64);
let (payload, content_type) =
self.open_tls13_early_data_record_packet(psk, packet, &[], sequence)?;
if content_type != RecordContentType::ApplicationData.to_u8() {
return Err(Error::ParseFailure(
"tls13 early-data packet inner content type must be application_data",
));
}
out.push(payload);
}
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn accept_and_open_tls13_early_data_client_flight_with_ticket_policy(
&mut self,
client_hello: &[u8],
tickets: &mut [ResumptionTicket],
current_time_ms: u64,
max_skew_ms: u32,
usage_policy: TicketUsagePolicy,
packets: &[Vec<u8>],
first_sequence: u64,
) -> Result<Vec<Vec<u8>>> {
if !self.accept_tls13_early_data_with_ticket_policy(
client_hello,
tickets,
current_time_ms,
max_skew_ms,
usage_policy,
)? {
return Ok(Vec::new());
}
let accepted_psk = self
.tls13_early_data_accepted_psk
.clone()
.ok_or(Error::StateError(
"tls13 early-data accepted ticket context is not installed",
))?;
self.open_tls13_early_data_client_flight_packets(&accepted_psk, packets, first_sequence)
}
#[allow(clippy::too_many_arguments)]
pub fn accept_and_open_tls13_early_data_client_flight_with_ticket_store(
&mut self,
client_hello: &[u8],
ticket_store: &mut TicketStore,
current_time_ms: u64,
max_skew_ms: u32,
usage_policy: TicketUsagePolicy,
packets: &[Vec<u8>],
first_sequence: u64,
) -> Result<Vec<Vec<u8>>> {
if !self.accept_tls13_early_data_with_ticket_store(
client_hello,
ticket_store,
current_time_ms,
max_skew_ms,
usage_policy,
)? {
return Ok(Vec::new());
}
let accepted_psk = self
.tls13_early_data_accepted_psk
.clone()
.ok_or(Error::StateError(
"tls13 early-data accepted ticket context is not installed",
))?;
self.open_tls13_early_data_client_flight_packets(&accepted_psk, packets, first_sequence)
}
/// Opens a record received from the peer, using the server write key/IV.
///
/// The record must carry exactly the expected `server_sequence`. Records are
/// accepted strictly in order; there is no replay window here. The sequence
/// counter advances only after successful authentication and the plaintext
/// length check.
///
/// # Errors
/// `Error::StateError` before `Finished`, on an exhausted or unexpected
/// sequence number, or when the suite/key/IV is missing.
/// `Error::InvalidLength` for oversized plaintext. AEAD errors are propagated.
pub fn open_record(&mut self, record: &ProtectedRecord, aad: &[u8]) -> Result<Vec<u8>> {
if self.state != HandshakeState::Finished {
return Err(Error::StateError(
"cannot open record before handshake finish",
));
}
// Refuse at u64::MAX so the receive counter can never wrap.
if self.server_sequence == u64::MAX {
return Err(Error::StateError("server record sequence exhausted"));
}
if record.sequence != self.server_sequence {
return Err(Error::StateError(
"unexpected server record sequence number",
));
}
let suite = self.selected_cipher_suite.ok_or(Error::StateError(
"cipher suite must be selected before opening records",
))?;
let key = self
.server_write_key
.ok_or(Error::StateError("server write key is not installed"))?;
let iv = self
.server_write_iv
.ok_or(Error::StateError("server write iv is not installed"))?;
let nonce = build_record_nonce(&iv, record.sequence);
let plaintext = match suite {
CipherSuite::TlsChacha20Poly1305Sha256 => {
chacha20_poly1305_decrypt(&key, &nonce, aad, &record.ciphertext, &record.tag)?
}
CipherSuite::TlsAes128GcmSha256 | CipherSuite::TlsAes256GcmSha384 => {
let key_len = suite.tls13_traffic_key_len().ok_or(Error::StateError(
"tls 1.3 aes suites must define traffic key length",
))?;
let cipher = AesCipher::new(&key[..key_len])?;
aes_gcm_decrypt(&cipher, &nonce, aad, &record.ciphertext, &record.tag)?
}
CipherSuite::TlsEcdheRsaWithAes128GcmSha256
| CipherSuite::TlsEcdheRsaWithAes256GcmSha384 => {
// NOTE(review): 16-byte (AES-128) key for the AES_256_GCM suite as well.
// This must stay in sync with seal_record and open_own_record.
let cipher = AesCipher::new(&key[..16])?;
aes_gcm_decrypt(&cipher, &nonce, aad, &record.ciphertext, &record.tag)?
}
};
if plaintext.len() > self.max_record_plaintext_len {
return Err(Error::InvalidLength(
"record plaintext exceeds configured limit",
));
}
self.server_sequence = self.server_sequence.wrapping_add(1);
Ok(plaintext)
}
/// Opens a record this side sealed itself, using the client write key/IV.
///
/// Unlike `open_record`, there is no handshake-state check and no sequence
/// bookkeeping. The nonce comes from `record.sequence` as given, and `self`
/// is not modified.
///
/// # Errors
/// `Error::StateError` when the suite/key/IV is missing.
/// `Error::InvalidLength` for oversized plaintext. AEAD errors are propagated.
pub fn open_own_record(&self, record: &ProtectedRecord, aad: &[u8]) -> Result<Vec<u8>> {
let suite = self.selected_cipher_suite.ok_or(Error::StateError(
"cipher suite must be selected before opening own records",
))?;
let key = self
.client_write_key
.ok_or(Error::StateError("client write key is not installed"))?;
let iv = self
.client_write_iv
.ok_or(Error::StateError("client write iv is not installed"))?;
let nonce = build_record_nonce(&iv, record.sequence);
let plaintext = match suite {
CipherSuite::TlsChacha20Poly1305Sha256 => {
chacha20_poly1305_decrypt(&key, &nonce, aad, &record.ciphertext, &record.tag)?
}
CipherSuite::TlsAes128GcmSha256 | CipherSuite::TlsAes256GcmSha384 => {
let key_len = suite.tls13_traffic_key_len().ok_or(Error::StateError(
"tls 1.3 aes suites must define traffic key length",
))?;
let cipher = AesCipher::new(&key[..key_len])?;
aes_gcm_decrypt(&cipher, &nonce, aad, &record.ciphertext, &record.tag)?
}
CipherSuite::TlsEcdheRsaWithAes128GcmSha256
| CipherSuite::TlsEcdheRsaWithAes256GcmSha384 => {
// NOTE(review): mirrors seal_record's 16-byte key for both TLS 1.2 GCM suites.
let cipher = AesCipher::new(&key[..16])?;
aes_gcm_decrypt(&cipher, &nonce, aad, &record.ciphertext, &record.tag)?
}
};
if plaintext.len() > self.max_record_plaintext_len {
return Err(Error::InvalidLength(
"record plaintext exceeds configured limit",
));
}
Ok(plaintext)
}
/// Seals `plaintext` as a TLS 1.2 record of `content_type` and encodes the packet.
///
/// The AAD is built from the current `client_sequence`, taken before
/// `seal_record` advances it.
///
/// # Errors
/// Propagates wire-mode, AAD, sealing and encoding errors.
pub fn seal_tls12_record_packet(
    &mut self,
    plaintext: &[u8],
    content_type: RecordContentType,
) -> Result<Vec<u8>> {
    self.ensure_tls12_wire_mode()?;
    let record_aad = self.build_tls12_record_aad(self.client_sequence, content_type, plaintext.len())?;
    let sealed = self.seal_record(plaintext, &record_aad)?;
    self.encode_tls12_record_packet(&sealed, content_type)
}
/// Decodes and opens the next TLS 1.2 record from the peer.
///
/// Returns the record's content type and plaintext.
///
/// # Errors
/// Propagates wire-mode, decode, AAD and `open_record` errors.
pub fn open_tls12_record_packet(
    &mut self,
    packet: &[u8],
) -> Result<(RecordContentType, Vec<u8>)> {
    self.ensure_tls12_wire_mode()?;
    let expected_sequence = self.server_sequence;
    let (decoded, content_type) = self.decode_tls12_record_packet(packet, expected_sequence)?;
    let record_aad =
        self.build_tls12_record_aad(expected_sequence, content_type, decoded.ciphertext.len())?;
    self.open_record(&decoded, &record_aad)
        .map(|plaintext| (content_type, plaintext))
}
/// Decodes and opens a TLS 1.2 record this side sealed, at an explicit `sequence`.
///
/// # Errors
/// Propagates wire-mode, decode, AAD and `open_own_record` errors.
pub fn open_own_tls12_record_packet(
    &self,
    packet: &[u8],
    sequence: u64,
) -> Result<(RecordContentType, Vec<u8>)> {
    self.ensure_tls12_wire_mode()?;
    let (decoded, content_type) = self.decode_tls12_record_packet(packet, sequence)?;
    let record_aad = self.build_tls12_record_aad(sequence, content_type, decoded.ciphertext.len())?;
    self.open_own_record(&decoded, &record_aad)
        .map(|plaintext| (content_type, plaintext))
}
/// Seals the two-byte alert (`level`, `description`) into a TLS 1.x alert record.
///
/// # Errors
/// Returns `StateError` unless the connection is TLS 1.0, 1.1 or 1.2, and
/// propagates any error from sealing the record.
pub fn send_tls12_alert_packet(
    &mut self,
    level: AlertLevel,
    description: AlertDescription,
) -> Result<Vec<u8>> {
    // Accept the same versions as the TLS 1.x record layer (`ensure_tls12_wire_mode`).
    // This lets legacy TLS 1.0/1.1 connections send alerts, just as
    // `recv_tls12_alert_packet` already lets them receive alerts. All other
    // versions keep the historical error message.
    let tls1x_record_layer = self.version == TlsVersion::Tls10
        || self.version == TlsVersion::Tls11
        || self.version == TlsVersion::Tls12;
    if !tls1x_record_layer {
        return Err(Error::StateError(
            "tls12 alert records require TLS 1.2 connection",
        ));
    }
    self.seal_tls12_record_packet(
        &[level.to_u8(), description.to_u8()],
        RecordContentType::Alert,
    )
}
pub fn send_tls12_alert_for_handshake_error(&mut self, error: &Error) -> Result<Vec<u8>> {
let (level, description) = Self::tls12_alert_for_handshake_error(error);
self.send_tls12_alert_packet(level, description)
}
/// Opens a peer record and parses it as a TLS 1.x alert.
///
/// # Errors
/// Fails when the record cannot be opened, is not an alert, or carries an
/// unknown or malformed alert payload.
pub fn recv_tls12_alert_packet(
    &mut self,
    packet: &[u8],
) -> Result<(AlertLevel, AlertDescription)> {
    self.open_tls12_record_packet(packet)
        .and_then(|(record_type, plaintext)| self.parse_tls12_alert_payload(record_type, &plaintext))
}
/// Opens a record this side sealed at `sequence` and parses it as a TLS 1.x alert.
///
/// # Errors
/// Fails when the record cannot be opened, is not an alert, or carries an
/// unknown or malformed alert payload.
pub fn recv_own_tls12_alert_packet(
    &self,
    packet: &[u8],
    sequence: u64,
) -> Result<(AlertLevel, AlertDescription)> {
    self.open_own_tls12_record_packet(packet, sequence)
        .and_then(|(record_type, plaintext)| self.parse_tls12_alert_payload(record_type, &plaintext))
}
fn parse_tls12_alert_payload(
&self,
content_type: RecordContentType,
payload: &[u8],
) -> Result<(AlertLevel, AlertDescription)> {
if content_type != RecordContentType::Alert {
return Err(Error::ParseFailure("record is not an alert content type"));
}
if payload.len() != 2 {
return Err(Error::ParseFailure("tls12 alert payload must be two bytes"));
}
let level =
AlertLevel::from_u8(payload[0]).ok_or(Error::ParseFailure("unknown alert level"))?;
let description = AlertDescription::from_u8(payload[1])
.ok_or(Error::ParseFailure("unknown alert description"))?;
Ok((level, description))
}
pub fn build_dtls12_record_packet(
&self,
content_type: RecordContentType,
epoch: u16,
sequence: u64,
payload: &[u8],
) -> Result<Vec<u8>> {
if self.version != TlsVersion::Dtls12 {
return Err(Error::StateError(
"dtls12 record packet builder requires DTLS1.2 connection",
));
}
encode_dtls_record_packet(content_type, [0xFE, 0xFD], epoch, sequence, payload)
}
pub fn parse_dtls12_record_packet(&self, packet: &[u8]) -> Result<(DtlsRecordHeader, Vec<u8>)> {
if self.version != TlsVersion::Dtls12 {
return Err(Error::StateError(
"dtls12 record packet parser requires DTLS1.2 connection",
));
}
let (header, payload) = parse_dtls_record_packet(packet)?;
if header.version != [0xFE, 0xFD] {
return Err(Error::ParseFailure("dtls record version mismatch"));
}
Ok((header, payload))
}
pub fn fragment_dtls12_handshake_message(
&self,
handshake_type: u8,
message_seq: u16,
body: &[u8],
max_fragment_len: usize,
) -> Result<Vec<Vec<u8>>> {
if self.version != TlsVersion::Dtls12 {
return Err(Error::StateError(
"dtls12 handshake fragmentation requires DTLS1.2 connection",
));
}
encode_dtls12_handshake_fragments(handshake_type, message_seq, body, max_fragment_len)
}
pub fn reassemble_dtls12_handshake_fragments(
&self,
fragments: &[Vec<u8>],
max_message_len: usize,
) -> Result<(u8, u16, Vec<u8>)> {
if self.version != TlsVersion::Dtls12 {
return Err(Error::StateError(
"dtls12 handshake reassembly requires DTLS1.2 connection",
));
}
reassemble_dtls12_handshake_fragments(fragments, max_message_len)
}
/// Turns the DTLS 1.2 anti-amplification send budget check on or off.
pub fn set_dtls12_anti_amplification_enforced(&mut self, enforced: bool) {
    self.dtls12_anti_amplification_enforced = enforced;
}
/// Adds `bytes` to the saturating count of received DTLS 1.2 datagram bytes,
/// which feeds the anti-amplification budget.
pub fn record_dtls12_inbound_datagram(&mut self, bytes: usize) {
    let received = bytes as u64;
    self.dtls12_inbound_bytes = self.dtls12_inbound_bytes.saturating_add(received);
}
/// Reports whether sending `bytes` more stays within the anti-amplification budget.
///
/// Always `true` when enforcement is disabled or the handshake has passed
/// cookie validation (client-key-exchange, finished or connected phases).
/// Otherwise the total outbound bytes must not exceed
/// `inbound_bytes * DTLS12_ANTI_AMPLIFICATION_FACTOR` (both saturating).
#[must_use]
pub fn dtls12_can_send_datagram_bytes(&self, bytes: usize) -> bool {
    let past_cookie_validation = matches!(
        self.dtls12_handshake_phase,
        Dtls12HandshakePhase::AwaitingClientKeyExchange
            | Dtls12HandshakePhase::AwaitingFinished
            | Dtls12HandshakePhase::Connected
    );
    if !self.dtls12_anti_amplification_enforced || past_cookie_validation {
        return true;
    }
    let projected_outbound = self.dtls12_outbound_bytes.saturating_add(bytes as u64);
    let allowed = self
        .dtls12_inbound_bytes
        .saturating_mul(DTLS12_ANTI_AMPLIFICATION_FACTOR);
    projected_outbound <= allowed
}
pub fn record_dtls12_outbound_datagram(&mut self, bytes: usize) -> Result<()> {
if !self.dtls12_can_send_datagram_bytes(bytes) {
return Err(Error::StateError(
"dtls12 anti-amplification budget exceeded before cookie validation",
));
}
self.dtls12_outbound_bytes = self.dtls12_outbound_bytes.saturating_add(bytes as u64);
Ok(())
}
/// Handles the initial cookie-less DTLS 1.2 ClientHello by issuing a cookie challenge.
///
/// The cookie is derived from the ClientHello bytes and `cookie_secret`, then
/// stored as the expected cookie. The phase moves to
/// `AwaitingClientHelloWithCookie`, and the encoded HelloVerifyRequest is
/// returned.
///
/// # Errors
/// `StateError` when not DTLS 1.2 or not in the initial client-hello phase.
/// `ParseFailure` when the message is not a ClientHello. Propagates
/// cookie/encoding errors.
pub fn process_dtls12_client_hello_without_cookie(
    &mut self,
    client_hello: &[u8],
    cookie_secret: &[u8],
) -> Result<Vec<u8>> {
    if self.version != TlsVersion::Dtls12 {
        return Err(Error::StateError(
            "dtls12 cookie exchange requires DTLS1.2 connection",
        ));
    }
    if self.dtls12_handshake_phase != Dtls12HandshakePhase::AwaitingClientHello {
        return Err(Error::StateError(
            "dtls12 cookie challenge requires initial client-hello phase",
        ));
    }
    let (handshake_type, _) = parse_handshake_message(client_hello)?;
    if handshake_type != HANDSHAKE_CLIENT_HELLO {
        return Err(Error::ParseFailure(
            "dtls12 cookie exchange requires client hello message",
        ));
    }
    let challenge = self.compute_dtls12_cookie(client_hello, cookie_secret)?;
    // State is committed before the request is built, matching the original ordering.
    self.dtls12_expected_cookie = Some(challenge.clone());
    self.dtls12_handshake_phase = Dtls12HandshakePhase::AwaitingClientHelloWithCookie;
    self.build_dtls12_hello_verify_request(&challenge)
}
/// Verifies the cookie carried by a retried DTLS 1.2 ClientHello.
///
/// The cookie must equal (in constant time) both the issued challenge and a
/// cookie recomputed from `client_hello` and `cookie_secret`. On success the
/// stored challenge is cleared, the phase moves to `AwaitingClientKeyExchange`,
/// and the handshake state becomes `ClientHelloSent`.
///
/// # Errors
/// - `StateError`: not DTLS 1.2, wrong phase, or no challenge was issued.
/// - `InvalidLength`: empty cookie (or empty secret, via `compute_dtls12_cookie`).
/// - `ParseFailure`: not a ClientHello, or the cookie does not match.
pub fn process_dtls12_client_hello_with_cookie(
    &mut self,
    client_hello: &[u8],
    cookie: &[u8],
    cookie_secret: &[u8],
) -> Result<()> {
    if self.version != TlsVersion::Dtls12 {
        return Err(Error::StateError(
            "dtls12 cookie exchange requires DTLS1.2 connection",
        ));
    }
    if self.dtls12_handshake_phase != Dtls12HandshakePhase::AwaitingClientHelloWithCookie {
        return Err(Error::StateError(
            "dtls12 cookie verification requires retry client-hello phase",
        ));
    }
    if cookie.is_empty() {
        return Err(Error::InvalidLength(
            "dtls12 client cookie must not be empty",
        ));
    }
    let (handshake_type, _) = parse_handshake_message(client_hello)?;
    if handshake_type != HANDSHAKE_CLIENT_HELLO {
        return Err(Error::ParseFailure(
            "dtls12 cookie verification requires client hello message",
        ));
    }
    let recomputed = self.compute_dtls12_cookie(client_hello, cookie_secret)?;
    let cookie_accepted = match self.dtls12_expected_cookie.as_ref() {
        Some(issued) => constant_time_eq(issued, cookie) && constant_time_eq(&recomputed, cookie),
        None => {
            return Err(Error::StateError(
                "dtls12 cookie challenge must be issued before verification",
            ));
        }
    };
    if !cookie_accepted {
        return Err(Error::ParseFailure("dtls12 client cookie mismatch"));
    }
    self.dtls12_expected_cookie = None;
    self.dtls12_handshake_phase = Dtls12HandshakePhase::AwaitingClientKeyExchange;
    self.state = HandshakeState::ClientHelloSent;
    Ok(())
}
pub fn process_dtls12_client_handshake_message(&mut self, message: &[u8]) -> Result<()> {
if self.version != TlsVersion::Dtls12 {
return Err(Error::StateError(
"dtls12 handshake sequencing requires DTLS1.2 connection",
));
}
let (message_type, _body) = parse_handshake_message(message)?;
match self.dtls12_handshake_phase {
Dtls12HandshakePhase::AwaitingClientKeyExchange => {
if message_type != HANDSHAKE_CLIENT_KEY_EXCHANGE {
return Err(Error::ParseFailure(
"dtls12 expected client key exchange handshake message",
));
}
self.dtls12_handshake_phase = Dtls12HandshakePhase::AwaitingFinished;
Ok(())
}
Dtls12HandshakePhase::AwaitingFinished => {
if message_type != HANDSHAKE_FINISHED {
return Err(Error::ParseFailure(
"dtls12 expected finished handshake message",
));
}
self.dtls12_handshake_phase = Dtls12HandshakePhase::Connected;
self.state = HandshakeState::Finished;
Ok(())
}
_ => Err(Error::StateError(
"dtls12 handshake message received in invalid phase",
)),
}
}
/// Returns a stable snake_case label for the current DTLS 1.2 handshake phase.
#[must_use]
pub fn dtls12_handshake_phase(&self) -> &'static str {
    let label = match self.dtls12_handshake_phase {
        Dtls12HandshakePhase::AwaitingClientHello => "awaiting_client_hello",
        Dtls12HandshakePhase::AwaitingClientHelloWithCookie => "awaiting_client_hello_with_cookie",
        Dtls12HandshakePhase::AwaitingClientKeyExchange => "awaiting_client_key_exchange",
        Dtls12HandshakePhase::AwaitingFinished => "awaiting_finished",
        Dtls12HandshakePhase::Connected => "connected",
    };
    label
}
/// Installs AES-128 client/server write keys and IVs for DTLS 1.3 records.
///
/// Both inbound replay trackers are reset. The outbound epoch and sequence
/// are not touched.
///
/// # Errors
/// `StateError` when the connection is not DTLS.
pub fn install_dtls13_traffic_keys(
    &mut self,
    client_key: [u8; 16],
    client_iv: [u8; 12],
    server_key: [u8; 16],
    server_iv: [u8; 12],
) -> Result<()> {
    self.ensure_dtls13_mode()?;
    self.dtls13_client_write_key = Some(client_key);
    self.dtls13_server_write_key = Some(server_key);
    self.dtls13_client_write_iv = Some(client_iv);
    self.dtls13_server_write_iv = Some(server_iv);
    self.dtls13_client_inbound_replay_tracker = DtlsEpochReplayTracker::new();
    self.dtls13_inbound_replay_tracker = DtlsEpochReplayTracker::new();
    Ok(())
}
/// Switches the DTLS 1.3 outbound epoch.
///
/// Moving to a strictly higher epoch resets the outbound sequence number to 0.
/// Selecting the epoch already in use is a no-op and keeps the current
/// sequence number.
///
/// # Errors
/// `StateError` when not DTLS, when an active flight is still incomplete, or
/// when `epoch` is lower than the current outbound epoch.
pub fn set_dtls13_outbound_epoch(&mut self, epoch: u16) -> Result<()> {
    self.ensure_dtls13_mode()?;
    if !self.dtls13_active_flight.is_empty() && !self.is_dtls13_active_flight_complete()? {
        return Err(Error::StateError(
            "cannot change dtls13 outbound epoch while active flight is incomplete",
        ));
    }
    if epoch < self.dtls13_outbound_epoch {
        return Err(Error::StateError("dtls13 outbound epoch must be monotonic"));
    }
    if epoch == self.dtls13_outbound_epoch {
        // Rewinding the sequence within the same epoch would make later seals
        // repeat (epoch, sequence) inputs to `seal_dtls13_aes128gcm_record`
        // under the same key/IV. That is presumably an AEAD nonce reuse, so
        // keep the sequence where it is.
        return Ok(());
    }
    self.dtls13_outbound_epoch = epoch;
    self.dtls13_outbound_sequence = 0;
    Ok(())
}
pub fn seal_dtls13_record(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
self.ensure_dtls13_mode()?;
self.ensure_dtls13_tx_sequence_available()?;
let key = self.dtls13_client_write_key.ok_or(Error::StateError(
"dtls13 client write key is not installed",
))?;
let iv = self
.dtls13_client_write_iv
.ok_or(Error::StateError("dtls13 client write iv is not installed"))?;
let packet = seal_dtls13_aes128gcm_record(
self.dtls13_outbound_epoch,
self.dtls13_outbound_sequence,
&key,
&iv,
plaintext,
)?;
self.dtls13_outbound_sequence = self.dtls13_outbound_sequence.saturating_add(1);
Ok(packet)
}
/// Seals one DTLS 1.3 record and registers it for retransmission.
///
/// The epoch and sequence are read back from the sealed packet's header. The
/// record is scheduled at `now_ms` with the configured initial retransmit
/// timeout.
///
/// # Errors
/// Propagates errors from sealing, header parsing, or the retransmit tracker.
pub fn seal_dtls13_record_for_flight(
    &mut self,
    plaintext: &[u8],
    now_ms: u64,
) -> Result<Vec<u8>> {
    self.ensure_dtls13_mode()?;
    let sealed = self.seal_dtls13_record(plaintext)?;
    let (header, _) = parse_dtls_record_packet(&sealed)?;
    let initial_timeout_ms = self.dtls_retransmit_initial_timeout_ms;
    self.dtls_retransmit_tracker.track_outbound_with_schedule(
        header.epoch,
        header.sequence,
        &sealed,
        now_ms,
        initial_timeout_ms,
    )?;
    Ok(sealed)
}
/// Seals every payload as a tracked DTLS 1.3 record, in order.
///
/// Stops at the first failure; records sealed before it stay tracked.
///
/// # Errors
/// `InvalidLength` for an empty input list; otherwise propagates per-record
/// sealing errors.
pub fn seal_dtls13_record_flight(
    &mut self,
    plaintext_records: &[&[u8]],
    now_ms: u64,
) -> Result<Vec<Vec<u8>>> {
    self.ensure_dtls13_mode()?;
    if plaintext_records.is_empty() {
        return Err(Error::InvalidLength(
            "dtls13 record flight must contain at least one payload",
        ));
    }
    plaintext_records
        .iter()
        .map(|plaintext| self.seal_dtls13_record_for_flight(plaintext, now_ms))
        .collect()
}
/// Seals a new DTLS 1.3 flight and makes it the active flight.
///
/// The `(epoch, sequence)` key of every sealed packet is remembered. The flight
/// start time is set to `now_ms` and the failure flag is cleared.
///
/// # Errors
/// `StateError` when a previous flight is still incomplete. Otherwise
/// propagates sealing/parsing errors.
pub fn start_dtls13_active_flight(
    &mut self,
    plaintext_records: &[&[u8]],
    now_ms: u64,
) -> Result<Vec<Vec<u8>>> {
    self.ensure_dtls13_mode()?;
    let previous_pending =
        !self.dtls13_active_flight.is_empty() && !self.is_dtls13_active_flight_complete()?;
    if previous_pending {
        return Err(Error::StateError(
            "cannot start new dtls13 active flight while previous flight is incomplete",
        ));
    }
    let sealed_packets = self.seal_dtls13_record_flight(plaintext_records, now_ms)?;
    self.dtls13_active_flight.clear();
    for sealed in &sealed_packets {
        let record_key = self.parse_dtls_packet_key(sealed)?;
        self.dtls13_active_flight.push(record_key);
    }
    self.dtls13_active_flight_started_at_ms = Some(now_ms);
    self.dtls13_active_flight_failed = false;
    Ok(sealed_packets)
}
/// Sets the overall active-flight timeout in milliseconds (clamped to at least 1).
///
/// # Errors
/// `StateError` when the connection is not DTLS.
pub fn set_dtls13_active_flight_timeout_ms(&mut self, timeout_ms: u64) -> Result<()> {
    self.ensure_dtls13_mode()?;
    let clamped = timeout_ms.max(1);
    self.dtls13_active_flight_timeout_ms = clamped;
    Ok(())
}
pub fn open_dtls13_record(&mut self, packet: &[u8]) -> Result<(DtlsRecordHeader, Vec<u8>)> {
self.ensure_dtls13_mode()?;
let key = self.dtls13_server_write_key.ok_or(Error::StateError(
"dtls13 server write key is not installed",
))?;
let iv = self
.dtls13_server_write_iv
.ok_or(Error::StateError("dtls13 server write iv is not installed"))?;
open_dtls13_aes128gcm_record(packet, &key, &iv, &mut self.dtls13_inbound_replay_tracker)
}
pub fn open_dtls13_client_record(
&mut self,
packet: &[u8],
) -> Result<(DtlsRecordHeader, Vec<u8>)> {
self.ensure_dtls13_mode()?;
let key = self.dtls13_client_write_key.ok_or(Error::StateError(
"dtls13 client write key is not installed",
))?;
let iv = self
.dtls13_client_write_iv
.ok_or(Error::StateError("dtls13 client write iv is not installed"))?;
open_dtls13_aes128gcm_record(
packet,
&key,
&iv,
&mut self.dtls13_client_inbound_replay_tracker,
)
}
pub fn process_dtls13_encrypted_server_flight_after_hello(
&mut self,
packets: &[Vec<u8>],
) -> Result<()> {
self.ensure_dtls13_mode()?;
if self.state != HandshakeState::ServerHelloReceived {
return Err(Error::StateError(
"dtls13 encrypted server flight requires server hello state",
));
}
if packets.len() < 4 {
return Err(Error::ParseFailure(
"dtls13 encrypted server flight is too short",
));
}
let mut messages = Vec::with_capacity(packets.len());
for packet in packets {
let (_header, plaintext) = self.open_dtls13_record(packet)?;
messages.push(plaintext);
}
let mut index = 0_usize;
self.recv_encrypted_extensions(&messages[index])?;
index += 1;
let (next_type, _) = parse_handshake_message(&messages[index])?;
if next_type == HANDSHAKE_CERTIFICATE_REQUEST {
self.recv_certificate_request(&messages[index])?;
index += 1;
}
self.recv_certificate(&messages[index])?;
index += 1;
self.recv_certificate_verify(&messages[index])?;
index += 1;
self.derive_handshake_secret()?;
self.recv_finished_message(&messages[index])?;
index += 1;
if index != messages.len() {
return Err(Error::ParseFailure(
"unexpected trailing dtls13 encrypted server handshake messages",
));
}
Ok(())
}
/// Processes a plaintext ServerHello followed by the encrypted server flight.
///
/// # Errors
/// Propagates errors from `recv_server_hello` and from
/// `process_dtls13_encrypted_server_flight_after_hello`.
pub fn process_dtls13_full_server_handshake_flight(
    &mut self,
    server_hello: &[u8],
    encrypted_packets: &[Vec<u8>],
) -> Result<()> {
    self.ensure_dtls13_mode()?;
    self.recv_server_hello(server_hello)
        .and_then(|_| self.process_dtls13_encrypted_server_flight_after_hello(encrypted_packets))
}
/// Decrypts the client's encrypted DTLS 1.3 flight and checks its message order.
///
/// Accepted orderings are `[Finished]` and
/// `[Certificate, CertificateVerify, Finished]`. Only the handshake types are
/// inspected here; message bodies are not processed.
///
/// # Errors
/// `ParseFailure` for an empty flight or an invalid ordering. Otherwise
/// propagates decryption/parsing errors.
pub fn process_dtls13_encrypted_client_flight_after_server_hello(
    &mut self,
    packets: &[Vec<u8>],
) -> Result<()> {
    self.ensure_dtls13_mode()?;
    if packets.is_empty() {
        return Err(Error::ParseFailure(
            "dtls13 encrypted client flight is too short",
        ));
    }
    let handshake_types = packets
        .iter()
        .map(|packet| {
            let (_header, plaintext) = self.open_dtls13_client_record(packet)?;
            parse_handshake_message(&plaintext).map(|(handshake_type, _body)| handshake_type)
        })
        .collect::<Result<Vec<_>>>()?;
    let finished_only = handshake_types == [HANDSHAKE_FINISHED];
    let with_client_auth = handshake_types
        == [
            HANDSHAKE_CERTIFICATE,
            HANDSHAKE_CERTIFICATE_VERIFY,
            HANDSHAKE_FINISHED,
        ];
    if finished_only || with_client_auth {
        Ok(())
    } else {
        Err(Error::ParseFailure(
            "invalid dtls13 encrypted client flight message ordering",
        ))
    }
}
/// Validates and seals the client's post-ServerHello flight as the active DTLS 1.3 flight.
///
/// # Errors
/// `StateError` unless the state is `ServerHelloReceived`,
/// `ServerCertificateVerified` or `KeysDerived`. Propagates ordering
/// validation and sealing errors.
pub fn build_dtls13_encrypted_client_flight_after_server_hello(
    &mut self,
    messages: &[Vec<u8>],
    now_ms: u64,
) -> Result<Vec<Vec<u8>>> {
    self.ensure_dtls13_mode()?;
    let post_server_hello = matches!(
        self.state,
        HandshakeState::ServerHelloReceived
            | HandshakeState::ServerCertificateVerified
            | HandshakeState::KeysDerived
    );
    if !post_server_hello {
        return Err(Error::StateError(
            "dtls13 encrypted client flight requires post-server-hello state",
        ));
    }
    self.validate_dtls13_client_post_hello_flight_order(messages)?;
    let payloads: Vec<&[u8]> = messages.iter().map(Vec::as_slice).collect();
    self.start_dtls13_active_flight(&payloads, now_ms)
}
pub fn advance_dtls13_outbound_epoch(&mut self) -> Result<u16> {
self.ensure_dtls13_mode()?;
if !self.dtls13_active_flight.is_empty() && !self.is_dtls13_active_flight_complete()? {
return Err(Error::StateError(
"cannot advance dtls13 outbound epoch while active flight is incomplete",
));
}
if self.dtls13_outbound_epoch == u16::MAX {
return Err(Error::StateError("dtls13 outbound epoch exhausted"));
}
self.dtls13_outbound_epoch = self.dtls13_outbound_epoch.saturating_add(1);
self.dtls13_outbound_sequence = 0;
Ok(self.dtls13_outbound_epoch)
}
pub fn open_own_dtls13_record(&self, packet: &[u8]) -> Result<(DtlsRecordHeader, Vec<u8>)> {
self.ensure_dtls13_mode()?;
let key = self.dtls13_client_write_key.ok_or(Error::StateError(
"dtls13 client write key is not installed",
))?;
let iv = self
.dtls13_client_write_iv
.ok_or(Error::StateError("dtls13 client write iv is not installed"))?;
let mut replay_tracker = DtlsEpochReplayTracker::new();
open_dtls13_aes128gcm_record(packet, &key, &iv, &mut replay_tracker)
}
pub fn mark_dtls13_record_acked_from_packet(&mut self, packet: &[u8]) -> Result<bool> {
self.ensure_dtls13_mode()?;
let (header, _payload) = parse_dtls_record_packet(packet)?;
if header.version != [0xFE, 0xFD] {
return Err(Error::ParseFailure("dtls record version mismatch"));
}
Ok(self
.dtls_retransmit_tracker
.mark_acked(header.epoch, header.sequence))
}
/// Acknowledges every packet in `packets` and counts how many `mark_acked` calls returned `true`.
///
/// # Errors
/// Stops at, and returns, the first parsing error.
pub fn mark_dtls13_flight_acked_from_packets(&mut self, packets: &[Vec<u8>]) -> Result<usize> {
    self.ensure_dtls13_mode()?;
    packets
        .iter()
        .try_fold(0_usize, |count, packet| -> Result<usize> {
            let newly_acked = self.mark_dtls13_record_acked_from_packet(packet)?;
            Ok(if newly_acked { count.saturating_add(1) } else { count })
        })
}
/// Returns the retransmissions now due that belong to the active DTLS 1.3 flight.
///
/// Returns an empty list when no flight is active. Due packets that are not
/// part of the active flight are left out of the result.
///
/// # Errors
/// - `StateError` when not DTLS.
/// - `StateError` when the flight exceeded its overall timeout. The flight is
///   aborted and `dtls13_active_flight_failed()` then reports `true`.
/// - `StateError` when any flight record is no longer held by the retransmit
///   tracker, checked both before and after polling (presumably dropped after
///   exhausting its retransmit budget). The flight is cleared and marked failed.
pub fn poll_dtls13_active_flight_due_packets(&mut self, now_ms: u64) -> Result<Vec<Vec<u8>>> {
    self.ensure_dtls13_mode()?;
    if self.dtls13_active_flight.is_empty() {
        return Ok(Vec::new());
    }
    if self.dtls13_active_flight_has_timed_out(now_ms) {
        let _ = self.abort_dtls13_active_flight()?;
        // `abort_dtls13_active_flight` clears the failure flag. A timeout is a
        // flight failure, though, so report it the same way as the
        // retransmit-budget path below.
        self.dtls13_active_flight_failed = true;
        return Err(Error::StateError(
            "dtls13 active flight timed out before completion",
        ));
    }
    if self.dtls13_active_flight_missing_tracked_records() {
        self.dtls13_active_flight.clear();
        self.dtls13_active_flight_started_at_ms = None;
        self.dtls13_active_flight_failed = true;
        return Err(Error::StateError(
            "dtls13 active flight failed after retransmit budget exhausted",
        ));
    }
    let due_packets = self.poll_dtls12_due_retransmit_packets(now_ms)?;
    let mut filtered = Vec::new();
    for packet in due_packets {
        let key = self.parse_dtls_packet_key(&packet)?;
        if self.dtls13_active_flight.contains(&key) {
            filtered.push(packet);
        }
    }
    // Polling itself may drop records that ran out of attempts.
    if self.dtls13_active_flight_missing_tracked_records() {
        self.dtls13_active_flight.clear();
        self.dtls13_active_flight_started_at_ms = None;
        self.dtls13_active_flight_failed = true;
        return Err(Error::StateError(
            "dtls13 active flight failed after retransmit budget exhausted",
        ));
    }
    Ok(filtered)
}
/// Acknowledges the packets that belong to the active flight and prunes acked records.
///
/// Returns how many `mark_dtls12_record_acked` calls returned `true`. Packets
/// outside the active flight are ignored. When the flight becomes complete, it
/// is cleared and its failure flag reset.
///
/// # Errors
/// Propagates parsing, acknowledgement and pruning errors.
pub fn acknowledge_dtls13_active_flight_packets(
    &mut self,
    packets: &[Vec<u8>],
) -> Result<usize> {
    self.ensure_dtls13_mode()?;
    if self.dtls13_active_flight.is_empty() {
        return Ok(0);
    }
    let mut newly_acked = 0_usize;
    for packet in packets {
        let (epoch, sequence) = self.parse_dtls_packet_key(packet)?;
        let in_flight = self.dtls13_active_flight.contains(&(epoch, sequence));
        if in_flight && self.mark_dtls12_record_acked(epoch, sequence)? {
            newly_acked = newly_acked.saturating_add(1);
        }
    }
    let _ = self.prune_dtls12_acked_records()?;
    if self.is_dtls13_active_flight_complete()? {
        self.dtls13_active_flight.clear();
        self.dtls13_active_flight_started_at_ms = None;
        self.dtls13_active_flight_failed = false;
    }
    Ok(newly_acked)
}
/// Abandons the active flight.
///
/// Every flight record is marked acked, acked records are pruned, and the
/// flight state is cleared (failure flag set to `false`). Returns the count
/// reported by `prune_dtls12_acked_records`. That count may presumably include
/// previously acked records outside this flight.
///
/// # Errors
/// `StateError` when not DTLS.
pub fn abort_dtls13_active_flight(&mut self) -> Result<usize> {
    self.ensure_dtls13_mode()?;
    if self.dtls13_active_flight.is_empty() {
        return Ok(0);
    }
    for &(epoch, sequence) in &self.dtls13_active_flight {
        let _ = self.dtls_retransmit_tracker.mark_acked(epoch, sequence);
    }
    let pruned = self.prune_dtls12_acked_records()?;
    self.dtls13_active_flight.clear();
    self.dtls13_active_flight_started_at_ms = None;
    self.dtls13_active_flight_failed = false;
    Ok(pruned)
}
/// Returns `true` when the most recent active flight was marked failed.
#[must_use]
pub fn dtls13_active_flight_failed(&self) -> bool {
    let failed = self.dtls13_active_flight_failed;
    failed
}
/// Reports whether no record of the active flight is still tracked as unacknowledged.
///
/// An empty flight counts as complete. Records no longer present in the
/// tracker also count as not pending.
///
/// # Errors
/// `StateError` when not DTLS.
pub fn is_dtls13_active_flight_complete(&self) -> Result<bool> {
    self.ensure_dtls13_mode()?;
    let records = self.dtls_retransmit_tracker.records();
    let complete = self.dtls13_active_flight.iter().all(|&(epoch, sequence)| {
        !records.iter().any(|record| {
            record.epoch == epoch && record.sequence == sequence && !record.acknowledged
        })
    });
    Ok(complete)
}
/// Derives a DTLS 1.2 cookie as SHA-256(`cookie_secret` || `client_hello`),
/// truncated to at most 16 bytes.
///
/// NOTE(review): this is a secret-prefix hash, not an HMAC, and it covers the
/// exact `client_hello` bytes given. If callers ever pass a retry ClientHello
/// that embeds the cookie, the recomputed value will differ. Confirm that
/// callers pass the same bytes in both steps.
///
/// # Errors
/// `InvalidLength` when `cookie_secret` is empty.
fn compute_dtls12_cookie(&self, client_hello: &[u8], cookie_secret: &[u8]) -> Result<Vec<u8>> {
    if cookie_secret.is_empty() {
        return Err(Error::InvalidLength(
            "dtls12 cookie secret must not be empty",
        ));
    }
    let material: Vec<u8> = cookie_secret.iter().chain(client_hello).copied().collect();
    let digest = hash_bytes_for_algorithm(HashAlgorithm::Sha256, &material);
    Ok(digest.iter().take(16).copied().collect())
}
fn build_dtls12_hello_verify_request(&self, cookie: &[u8]) -> Result<Vec<u8>> {
if cookie.is_empty() {
return Err(Error::InvalidLength("dtls12 cookie must not be empty"));
}
if cookie.len() > DTLS12_MAX_COOKIE_LEN {
return Err(Error::InvalidLength(
"dtls12 cookie exceeds 8-bit cookie length field",
));
}
let mut body = Vec::with_capacity(3 + cookie.len());
body.extend_from_slice(&[0xFE, 0xFD]);
body.push(cookie.len() as u8);
body.extend_from_slice(cookie);
Ok(encode_handshake_message(
HANDSHAKE_HELLO_VERIFY_REQUEST,
&body,
))
}
fn parse_dtls_packet_key(&self, packet: &[u8]) -> Result<(u16, u64)> {
let (header, _payload) = parse_dtls_record_packet(packet)?;
if header.version != [0xFE, 0xFD] {
return Err(Error::ParseFailure("dtls record version mismatch"));
}
Ok((header.epoch, header.sequence))
}
/// Returns `true` when a started flight has been running longer than the
/// configured flight timeout (using saturating elapsed time).
fn dtls13_active_flight_has_timed_out(&self, now_ms: u64) -> bool {
    self.dtls13_active_flight_started_at_ms.map_or(false, |started_at_ms| {
        now_ms.saturating_sub(started_at_ms) > self.dtls13_active_flight_timeout_ms
    })
}
/// Returns `true` when some record of the active flight is no longer present
/// in the retransmit tracker.
fn dtls13_active_flight_missing_tracked_records(&self) -> bool {
    let records = self.dtls_retransmit_tracker.records();
    !self.dtls13_active_flight.iter().all(|&(epoch, sequence)| {
        records
            .iter()
            .any(|record| record.epoch == epoch && record.sequence == sequence)
    })
}
/// Checks that plaintext client flight messages follow an accepted ordering.
///
/// Accepted orderings: `[Finished]` or `[Certificate, CertificateVerify, Finished]`.
///
/// # Errors
/// `InvalidLength` for an empty list, and `ParseFailure` for an unparsable
/// message or any other ordering.
fn validate_dtls13_client_post_hello_flight_order(&self, messages: &[Vec<u8>]) -> Result<()> {
    if messages.is_empty() {
        return Err(Error::InvalidLength(
            "dtls13 encrypted client flight must contain at least one message",
        ));
    }
    let handshake_types = messages
        .iter()
        .map(|message| parse_handshake_message(message).map(|(handshake_type, _body)| handshake_type))
        .collect::<Result<Vec<_>>>()?;
    let finished_only = handshake_types == [HANDSHAKE_FINISHED];
    let with_client_auth = handshake_types
        == [
            HANDSHAKE_CERTIFICATE,
            HANDSHAKE_CERTIFICATE_VERIFY,
            HANDSHAKE_FINISHED,
        ];
    if finished_only || with_client_auth {
        Ok(())
    } else {
        Err(Error::ParseFailure(
            "invalid dtls13 client post-hello flight message ordering",
        ))
    }
}
fn build_tls12_record_aad(
&self,
sequence: u64,
content_type: RecordContentType,
plaintext_len: usize,
) -> Result<[u8; 13]> {
let len = u16::try_from(plaintext_len)
.map_err(|_| Error::InvalidLength("tls12 plaintext length exceeds 16-bit field"))?;
let mut aad = [0_u8; 13];
aad[..8].copy_from_slice(&sequence.to_be_bytes());
aad[8] = content_type.to_u8();
aad[9..11].copy_from_slice(&legacy_wire_version(self.version));
aad[11..13].copy_from_slice(&len.to_be_bytes());
Ok(aad)
}
fn build_tls13_record_aad(&self, payload_len: usize) -> Result<[u8; 5]> {
let len = u16::try_from(payload_len)
.map_err(|_| Error::InvalidLength("tls13 record payload length exceeds u16 range"))?;
let mut aad = [0_u8; 5];
aad[0] = RecordContentType::ApplicationData.to_u8();
aad[1..3].copy_from_slice(&0x0303_u16.to_be_bytes());
aad[3..5].copy_from_slice(&len.to_be_bytes());
Ok(aad)
}
/// Serializes a protected record as `ciphertext || tag` inside a TLS 1.x
/// record header that uses the connection's legacy wire version.
fn encode_tls12_record_packet(
    &self,
    record: &ProtectedRecord,
    content_type: RecordContentType,
) -> Result<Vec<u8>> {
    let mut sealed_payload = Vec::with_capacity(record.ciphertext.len() + record.tag.len());
    sealed_payload.extend(record.ciphertext.iter().chain(record.tag.iter()));
    let wire_version = legacy_wire_version(self.version);
    encode_tls12_ciphertext_record(content_type.to_u8(), wire_version, &sealed_payload)
}
/// Parses a TLS 1.x ciphertext record into a `ProtectedRecord` tagged with `sequence`.
///
/// The record version must equal the connection's legacy wire version. When
/// `tls12_allow_legacy_record_versions` is set, `0x0301`/`0x0302` are also
/// accepted. The last 16 payload bytes become the tag.
///
/// # Errors
/// `ParseFailure` for a bad version, an unknown content type, or a payload
/// shorter than 16 bytes. Propagates decoding errors.
fn decode_tls12_record_packet(
    &self,
    packet: &[u8],
    sequence: u64,
) -> Result<(ProtectedRecord, RecordContentType)> {
    let (content_type_byte, wire_version, payload) = decode_tls12_ciphertext_record(packet)?;
    let matches_connection = wire_version == legacy_wire_version(self.version);
    let tolerated_legacy = self.tls12_allow_legacy_record_versions
        && (wire_version == [0x03, 0x01] || wire_version == [0x03, 0x02]);
    if !(matches_connection || tolerated_legacy) {
        return Err(Error::ParseFailure(
            "tls12 record has invalid legacy version",
        ));
    }
    let content_type = RecordContentType::from_u8(content_type_byte)
        .ok_or(Error::ParseFailure("unknown tls12 record content type"))?;
    let Some(tag_offset) = payload.len().checked_sub(16) else {
        return Err(Error::ParseFailure("tls12 record payload too short"));
    };
    let (ciphertext, tag_bytes) = payload.split_at(tag_offset);
    let mut tag = [0_u8; 16];
    tag.copy_from_slice(tag_bytes);
    let record = ProtectedRecord {
        sequence,
        ciphertext: ciphertext.to_vec(),
        tag,
    };
    Ok((record, content_type))
}
/// Serializes a protected record as `ciphertext || tag` inside a TLS 1.3 record header.
fn encode_tls13_record_packet(&self, record: &ProtectedRecord) -> Result<Vec<u8>> {
    let mut sealed_payload = Vec::with_capacity(record.ciphertext.len() + record.tag.len());
    sealed_payload.extend(record.ciphertext.iter().chain(record.tag.iter()));
    encode_tls13_ciphertext_record(&sealed_payload)
}
fn decode_tls13_record_packet(&self, packet: &[u8], sequence: u64) -> Result<ProtectedRecord> {
let payload = decode_tls13_ciphertext_record(packet)?;
let tag_offset = payload.len() - 16;
let mut tag = [0_u8; 16];
tag.copy_from_slice(&payload[tag_offset..]);
Ok(ProtectedRecord {
sequence,
ciphertext: payload[..tag_offset].to_vec(),
tag,
})
}
fn ensure_dtls13_tx_sequence_available(&self) -> Result<()> {
if self.dtls13_outbound_sequence > DTLS13_MAX_SEQUENCE {
return Err(Error::StateError(
"dtls13 outbound record sequence exhausted",
));
}
Ok(())
}
/// Sets the initial retransmit timeout in milliseconds (clamped to at least 1).
///
/// # Errors
/// `StateError` when not DTLS.
pub fn set_dtls12_retransmit_initial_timeout_ms(&mut self, timeout_ms: u64) -> Result<()> {
    self.ensure_dtls12_mode()?;
    let clamped = timeout_ms.max(1);
    self.dtls_retransmit_initial_timeout_ms = clamped;
    Ok(())
}
/// Sets the maximum retransmit attempts per record (clamped to at least 1).
///
/// # Errors
/// `StateError` when not DTLS.
pub fn set_dtls12_max_retransmit_attempts(&mut self, attempts: u8) -> Result<()> {
    self.ensure_dtls12_mode()?;
    let clamped = attempts.max(1);
    self.dtls_max_retransmit_attempts = clamped;
    Ok(())
}
/// Builds a DTLS 1.2 record and registers it with the retransmit tracker at
/// `now_ms`, using the configured initial timeout.
///
/// # Errors
/// Propagates mode, encoding and tracker errors.
pub fn build_dtls12_record_packet_for_flight(
    &mut self,
    content_type: RecordContentType,
    epoch: u16,
    sequence: u64,
    payload: &[u8],
    now_ms: u64,
) -> Result<Vec<u8>> {
    self.ensure_dtls12_mode()?;
    let encoded = self.build_dtls12_record_packet(content_type, epoch, sequence, payload)?;
    let initial_timeout_ms = self.dtls_retransmit_initial_timeout_ms;
    self.dtls_retransmit_tracker.track_outbound_with_schedule(
        epoch,
        sequence,
        &encoded,
        now_ms,
        initial_timeout_ms,
    )?;
    Ok(encoded)
}
/// Marks the tracked record `(epoch, sequence)` as acknowledged and returns the tracker's result.
///
/// # Errors
/// `StateError` when not DTLS.
pub fn mark_dtls12_record_acked(&mut self, epoch: u16, sequence: u64) -> Result<bool> {
    self.ensure_dtls12_mode()
        .map(|()| self.dtls_retransmit_tracker.mark_acked(epoch, sequence))
}
/// Collects the packets whose retransmit deadline has passed at `now_ms`,
/// bounded by the configured maximum attempts.
///
/// # Errors
/// `StateError` when not DTLS.
pub fn poll_dtls12_due_retransmit_packets(&mut self, now_ms: u64) -> Result<Vec<Vec<u8>>> {
    self.ensure_dtls12_mode()?;
    let max_attempts = self.dtls_max_retransmit_attempts;
    let due = self
        .dtls_retransmit_tracker
        .collect_due_retransmit_packets(now_ms, max_attempts);
    Ok(due)
}
/// Returns the tracker's pending retransmit packets, or an empty list for non-DTLS connections.
#[must_use]
pub fn dtls12_pending_retransmit_packets(&self) -> Vec<Vec<u8>> {
    if self.version.is_dtls() {
        self.dtls_retransmit_tracker.pending_retransmit_packets()
    } else {
        Vec::new()
    }
}
/// Removes acknowledged records from the retransmit tracker and returns how many were removed.
///
/// # Errors
/// `StateError` when not DTLS.
pub fn prune_dtls12_acked_records(&mut self) -> Result<usize> {
    self.ensure_dtls12_mode()
        .map(|()| self.dtls_retransmit_tracker.prune_acked())
}
fn ensure_dtls12_mode(&self) -> Result<()> {
if !self.version.is_dtls() {
return Err(Error::StateError(
"dtls retransmit scheduler requires DTLS connection",
));
}
Ok(())
}
fn ensure_dtls13_mode(&self) -> Result<()> {
if !self.version.is_dtls() {
return Err(Error::StateError("dtls13 APIs require DTLS connection"));
}
Ok(())
}
/// Guards the TLS 1.x record-packet APIs: succeeds for TLS 1.0, 1.1 and 1.2.
fn ensure_tls12_wire_mode(&self) -> Result<()> {
    let tls1x_versions = [TlsVersion::Tls10, TlsVersion::Tls11, TlsVersion::Tls12];
    if tls1x_versions.contains(&self.version) {
        Ok(())
    } else {
        Err(Error::StateError(
            "tls12 record packets require TLS 1.0/1.1/1.2 connection",
        ))
    }
}
pub fn seal_record_fragments(
&mut self,
plaintext: &[u8],
aad: &[u8],
fragment_len: usize,
) -> Result<Vec<ProtectedRecord>> {
if fragment_len == 0 {
return Err(Error::InvalidLength(
"fragment length must be greater than zero",
));
}
if fragment_len > self.max_record_plaintext_len {
return Err(Error::InvalidLength(
"fragment length exceeds configured record plaintext limit",
));
}
if plaintext.is_empty() {
return Ok(Vec::new());
}
let fragment_count = plaintext.len().div_ceil(fragment_len);
let required_sequences = u64::try_from(fragment_count)
.map_err(|_| Error::InvalidLength("too many record fragments requested"))?;
let highest_sequence = self
.client_sequence
.checked_add(required_sequences.saturating_sub(1));
if highest_sequence.is_none() {
return Err(Error::StateError(
"insufficient record sequence space for all fragments",
));
}
let mut out = Vec::with_capacity(fragment_count);
let mut offset = 0_usize;
while offset < plaintext.len() {
let end = (offset + fragment_len).min(plaintext.len());
out.push(self.seal_record(&plaintext[offset..end], aad)?);
offset = end;
}
Ok(out)
}
/// Opens peer record fragments and concatenates their plaintexts.
///
/// The fragments must carry contiguous sequence numbers starting at the
/// current server sequence. This is checked before any record is opened.
///
/// # Errors
/// `ParseFailure` on sequence overflow or a gap. Propagates `open_record` errors.
pub fn open_record_fragments(
    &mut self,
    records: &[ProtectedRecord],
    aad: &[u8],
) -> Result<Vec<u8>> {
    if records.is_empty() {
        return Ok(Vec::new());
    }
    let first_sequence = self.server_sequence;
    for (offset, record) in (0_u64..).zip(records) {
        let expected = first_sequence
            .checked_add(offset)
            .ok_or(Error::ParseFailure("record fragment sequence overflow"))?;
        if record.sequence != expected {
            return Err(Error::ParseFailure(
                "record fragments must be contiguous sequences",
            ));
        }
    }
    let mut joined = Vec::new();
    for record in records {
        let opened = self.open_record(record, aad)?;
        joined.extend(opened);
    }
    Ok(joined)
}
/// Opens record fragments this side sealed and concatenates their plaintexts.
///
/// The fragments must be contiguous, starting from the first record's own
/// sequence number.
///
/// # Errors
/// `ParseFailure` on sequence overflow or a gap. Propagates `open_own_record` errors.
pub fn open_own_record_fragments(
    &self,
    records: &[ProtectedRecord],
    aad: &[u8],
) -> Result<Vec<u8>> {
    let Some(first) = records.first() else {
        return Ok(Vec::new());
    };
    let first_sequence = first.sequence;
    for (offset, record) in (0_u64..).zip(records) {
        let expected = first_sequence
            .checked_add(offset)
            .ok_or(Error::ParseFailure("record fragment sequence overflow"))?;
        if record.sequence != expected {
            return Err(Error::ParseFailure(
                "record fragments must be contiguous sequences",
            ));
        }
    }
    records
        .iter()
        .try_fold(Vec::new(), |mut joined, record| -> Result<Vec<u8>> {
            joined.extend_from_slice(&self.open_own_record(record, aad)?);
            Ok(joined)
        })
}
pub fn seal_tls13_inner_record(
&mut self,
content: &[u8],
content_type: u8,
aad: &[u8],
padding_len: usize,
) -> Result<ProtectedRecord> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"tls13 inner plaintext records require TLS 1.3 connection",
));
}
let inner = encode_tls13_inner_plaintext(content, content_type, padding_len);
self.seal_record(&inner, aad)
}
/// Opens a peer TLS 1.3 record and decodes its inner plaintext into `(content, content_type)`.
///
/// # Errors
/// `StateError` when not TLS 1.3. Propagates open/decode errors.
pub fn open_tls13_inner_record(
    &mut self,
    record: &ProtectedRecord,
    aad: &[u8],
) -> Result<(Vec<u8>, u8)> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 inner plaintext records require TLS 1.3 connection",
        ));
    }
    self.open_record(record, aad)
        .and_then(|inner_plaintext| decode_tls13_inner_plaintext(&inner_plaintext))
}
/// Opens a TLS 1.3 record this side sealed and decodes its inner plaintext.
///
/// # Errors
/// `StateError` when not TLS 1.3. Propagates open/decode errors.
pub fn open_own_tls13_inner_record(
    &self,
    record: &ProtectedRecord,
    aad: &[u8],
) -> Result<(Vec<u8>, u8)> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 inner plaintext records require TLS 1.3 connection",
        ));
    }
    self.open_own_record(record, aad)
        .and_then(|inner_plaintext| decode_tls13_inner_plaintext(&inner_plaintext))
}
/// Seals `content` as a TLS 1.3 inner record and serializes it to wire bytes.
///
/// # Errors
/// `StateError` when not TLS 1.3. Propagates sealing/encoding errors.
pub fn seal_tls13_record_packet(
    &mut self,
    content: &[u8],
    content_type: u8,
    aad: &[u8],
    padding_len: usize,
) -> Result<Vec<u8>> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 record packets require TLS 1.3 connection",
        ));
    }
    self.seal_tls13_inner_record(content, content_type, aad, padding_len)
        .and_then(|record| self.encode_tls13_record_packet(&record))
}
/// Decodes a peer TLS 1.3 record packet at the current server sequence and
/// opens its inner plaintext.
///
/// # Errors
/// `StateError` when not TLS 1.3. Propagates decode/open errors.
pub fn open_tls13_record_packet(&mut self, packet: &[u8], aad: &[u8]) -> Result<(Vec<u8>, u8)> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 record packets require TLS 1.3 connection",
        ));
    }
    let expected_sequence = self.server_sequence;
    let protected = self.decode_tls13_record_packet(packet, expected_sequence)?;
    self.open_tls13_inner_record(&protected, aad)
}
/// Decodes a TLS 1.3 record packet this side sealed at `sequence` and opens
/// its inner plaintext.
///
/// # Errors
/// `StateError` when not TLS 1.3. Propagates decode/open errors.
pub fn open_own_tls13_record_packet(
    &self,
    packet: &[u8],
    sequence: u64,
    aad: &[u8],
) -> Result<(Vec<u8>, u8)> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 record packets require TLS 1.3 connection",
        ));
    }
    self.decode_tls13_record_packet(packet, sequence)
        .and_then(|protected| self.open_own_tls13_inner_record(&protected, aad))
}
/// Seals a two-byte alert (`level`, `description`) as a TLS 1.3 Alert-content record with
/// no padding, then records the local effects of sending it (close_notify flag, fatal
/// state transition) via `apply_tls13_alert_effects`.
///
/// # Errors
/// `Error::StateError` unless the connection uses TLS 1.3 handshake semantics; otherwise any
/// error from `seal_tls13_inner_record` (in which case no alert effects are applied).
pub fn send_tls13_alert(
    &mut self,
    level: AlertLevel,
    description: AlertDescription,
    aad: &[u8],
) -> Result<ProtectedRecord> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 alert records require TLS 1.3 connection",
        ));
    }
    let alert_body = [level.to_u8(), description.to_u8()];
    let alert_content_type = RecordContentType::Alert.to_u8();
    let sealed = self.seal_tls13_inner_record(&alert_body, alert_content_type, aad, 0)?;
    // Bookkeeping only happens once the alert has actually been sealed.
    self.apply_tls13_alert_effects(level, description, true);
    Ok(sealed)
}
pub fn send_tls13_alert_packet(
&mut self,
level: AlertLevel,
description: AlertDescription,
aad: &[u8],
) -> Result<Vec<u8>> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"tls13 alert records require TLS 1.3 connection",
));
}
let record = self.send_tls13_alert(level, description, aad)?;
self.encode_tls13_record_packet(&record)
}
pub fn recv_tls13_alert(
&mut self,
record: &ProtectedRecord,
aad: &[u8],
) -> Result<(AlertLevel, AlertDescription)> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"tls13 alert records require TLS 1.3 connection",
));
}
let (payload, content_type) = self.open_tls13_inner_record(record, aad)?;
self.process_parsed_tls13_alert(payload, content_type)
}
/// Opens `record` with `open_own_tls13_inner_record` and interprets it as an alert,
/// applying its effects via `process_parsed_tls13_alert`.
///
/// # Errors
/// `Error::StateError` unless the connection uses TLS 1.3 handshake semantics; otherwise any
/// error from opening the record or from alert parsing.
pub fn recv_own_tls13_alert(
    &mut self,
    record: &ProtectedRecord,
    aad: &[u8],
) -> Result<(AlertLevel, AlertDescription)> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Err(Error::StateError(
            "tls13 alert records require TLS 1.3 connection",
        ));
    }
    let (alert_body, inner_type) = self.open_own_tls13_inner_record(record, aad)?;
    self.process_parsed_tls13_alert(alert_body, inner_type)
}
pub fn recv_tls13_alert_packet(
&mut self,
packet: &[u8],
aad: &[u8],
) -> Result<(AlertLevel, AlertDescription)> {
if !self.version.uses_tls13_handshake_semantics() {
return Err(Error::StateError(
"tls13 alert records require TLS 1.3 connection",
));
}
let (payload, content_type) = self.open_tls13_record_packet(packet, aad)?;
self.process_parsed_tls13_alert(payload, content_type)
}
/// Decodes and opens an alert wire packet with an explicit `sequence` via
/// `open_own_tls13_record_packet`, then parses and applies the alert.
///
/// # Errors
/// `Error::StateError` unless the connection uses TLS 1.3 handshake semantics; otherwise any
/// error from opening the packet or from alert parsing.
pub fn recv_own_tls13_alert_packet(
    &mut self,
    packet: &[u8],
    sequence: u64,
    aad: &[u8],
) -> Result<(AlertLevel, AlertDescription)> {
    let is_tls13 = self.version.uses_tls13_handshake_semantics();
    if !is_tls13 {
        return Err(Error::StateError(
            "tls13 alert records require TLS 1.3 connection",
        ));
    }
    let (alert_body, inner_type) = self.open_own_tls13_record_packet(packet, sequence, aad)?;
    self.process_parsed_tls13_alert(alert_body, inner_type)
}
fn process_parsed_tls13_alert(
&mut self,
payload: Vec<u8>,
content_type: u8,
) -> Result<(AlertLevel, AlertDescription)> {
if RecordContentType::from_u8(content_type) != Some(RecordContentType::Alert) {
return Err(Error::ParseFailure("record is not an alert content type"));
}
if payload.len() != 2 {
return Err(Error::ParseFailure("tls13 alert payload must be two bytes"));
}
let level =
AlertLevel::from_u8(payload[0]).ok_or(Error::ParseFailure("unknown alert level"))?;
let description = AlertDescription::from_u8(payload[1])
.ok_or(Error::ParseFailure("unknown alert description"))?;
self.apply_tls13_alert_effects(level, description, false);
Ok((level, description))
}
/// Records the connection-level effects of an alert that was sent or received.
///
/// - `close_notify` sets `tls13_local_close_notify_sent` when `from_local_send` is true,
///   and `tls13_peer_close_notify_received` otherwise.
/// - Any `AlertLevel::Fatal` alert, in either direction, resets the handshake state to
///   `HandshakeState::Idle`.
fn apply_tls13_alert_effects(
    &mut self,
    level: AlertLevel,
    description: AlertDescription,
    from_local_send: bool,
) {
    let is_close_notify = description == AlertDescription::CloseNotify;
    if is_close_notify && from_local_send {
        self.tls13_local_close_notify_sent = true;
    }
    if is_close_notify && !from_local_send {
        self.tls13_peer_close_notify_received = true;
    }
    if level == AlertLevel::Fatal {
        self.state = HandshakeState::Idle;
    }
}
/// Returns whether a `close_notify` alert has been received from the peer.
///
/// The flag is set by `apply_tls13_alert_effects` on the receive path.
#[must_use]
pub fn tls13_peer_close_notify_received(&self) -> bool {
    self.tls13_peer_close_notify_received
}
/// Returns whether this endpoint has sent a `close_notify` alert.
///
/// The flag is set by `apply_tls13_alert_effects` on the send path.
#[must_use]
pub fn tls13_local_close_notify_sent(&self) -> bool {
    self.tls13_local_close_notify_sent
}
/// Clears all per-handshake server authentication results.
///
/// This covers the validated leaf key and chain flag, the SNI acknowledgement, the selected
/// ALPN protocol, and OCSP staple data.
fn reset_tls13_certificate_auth_state(&mut self) {
    // Certificate chain outcome.
    self.tls13_server_certificate_chain_validated = false;
    self.tls13_server_leaf_public_key_der = None;
    // OCSP stapling outcome.
    self.tls13_server_ocsp_staple_verified = false;
    self.tls13_server_ocsp_staple = None;
    // Negotiated extensions tied to server identity.
    self.tls13_server_name_acknowledged = false;
    self.tls13_selected_alpn_protocol = None;
}
fn validate_tls13_hrr_retry_group_support(&self) -> Result<()> {
if !self.version.uses_tls13_handshake_semantics() || !self.tls13_hrr_seen {
return Ok(());
}
let requested_group = self.tls13_hrr_requested_group.ok_or(Error::ParseFailure(
"hello retry request is missing requested key_share group",
))?;
if !super::keyshare::tls13_key_share_group_supported(requested_group) {
return Err(Error::StateError(
"hello retry request requested unsupported key_share group",
));
}
Ok(())
}
/// Derives the 0-RTT (early data) record key and 12-byte IV from `psk`.
///
/// Steps:
/// 1. `early_secret = hkdf_extract_for_hash(hash, psk)`.
/// 2. `client_early_traffic_secret = Expand-Label(early_secret, "c e traffic",
///    Hash(transcript), HashLen)`, where the transcript is hashed from the raw bytes in
///    `self.transcript`.
/// 3. The key ("key", length `tls13_early_data_key_len()`) and IV ("iv", 12 bytes) are
///    expanded from that secret.
///
/// # Errors
/// Propagates errors from `tls13_expand_label_for_hash`.
fn derive_tls13_early_data_record_key_iv(&self, psk: &[u8]) -> Result<(Vec<u8>, [u8; 12])> {
    let hash_algorithm = self.negotiated_hash_algorithm();
    let expand = |secret: &[u8], label: &[u8], context: &[u8], len: usize| {
        tls13_expand_label_for_hash(hash_algorithm, secret, label, context, len)
    };
    let client_hello_hash = hash_bytes_for_algorithm(hash_algorithm, &self.transcript);
    let early_secret = hkdf_extract_for_hash(hash_algorithm, psk);
    let early_traffic = expand(
        &early_secret,
        b"c e traffic",
        &client_hello_hash,
        hash_algorithm.output_len(),
    )?;
    let key = expand(&early_traffic, b"key", &[], self.tls13_early_data_key_len())?;
    // Expand-Label was asked for exactly 12 bytes, so the conversion cannot fail.
    let iv: [u8; 12] = expand(&early_traffic, b"iv", &[], 12)?
        .try_into()
        .expect("tls13 early-data iv should be 12 bytes");
    Ok((key, iv))
}
/// Returns the early-data AEAD key length in bytes.
///
/// This is 32 for `TlsAes256GcmSha384` and `TlsChacha20Poly1305Sha256`, and 16 otherwise,
/// including when no suite has been selected yet.
fn tls13_early_data_key_len(&self) -> usize {
    let needs_256_bit_key = matches!(
        self.selected_cipher_suite,
        Some(CipherSuite::TlsAes256GcmSha384) | Some(CipherSuite::TlsChacha20Poly1305Sha256)
    );
    if needs_256_bit_key {
        32
    } else {
        16
    }
}
/// Returns `true` when the selected suite is `TlsChacha20Poly1305Sha256`.
///
/// Returns `false` for any other suite, or when no suite has been selected.
fn tls13_early_data_uses_chacha20_poly1305(&self) -> bool {
    let Some(suite) = self.selected_cipher_suite else {
        return false;
    };
    matches!(suite, CipherSuite::TlsChacha20Poly1305Sha256)
}
fn validate_tls13_server_certificate_chain(&mut self, certificates: &[Vec<u8>]) -> Result<()> {
if certificates.is_empty() {
return Err(Error::ParseFailure(
"certificate list must include leaf certificate",
));
}
if self.tls13_server_trust_anchors_der.is_empty() {
return Err(Error::StateError(
"tls13 server trust anchors are not configured",
));
}
let validation_time =
self.tls13_server_validation_time
.as_deref()
.ok_or(Error::StateError(
"tls13 server validation time is not configured",
))?;
let leaf = parse_certificate(&certificates[0])
.map_err(|_| Error::ParseFailure("failed to parse server leaf certificate"))?;
if let Some(expected_hostname) = self.tls13_server_expected_hostname.as_deref() {
if !certificate_matches_hostname(&leaf, expected_hostname) {
return Err(Error::CryptoFailure(
"server certificate hostname validation failed",
));
}
}
let mut parsed_intermediates = Vec::new();
for der in &certificates[1..] {
let parsed = parse_certificate(der).map_err(|_| {
Error::ParseFailure("failed to parse server intermediate certificate")
})?;
parsed_intermediates.push(parsed);
}
for der in &self.tls13_server_intermediates_der {
let parsed = parse_certificate(der).map_err(|_| {
Error::ParseFailure("failed to parse configured server intermediate certificate")
})?;
parsed_intermediates.push(parsed);
}
let mut parsed_anchors = Vec::new();
for der in &self.tls13_server_trust_anchors_der {
let parsed = parse_certificate(der).map_err(|_| {
Error::ParseFailure("failed to parse configured trust anchor certificate")
})?;
parsed_anchors.push(parsed);
}
validate_certificate_chain(
&leaf,
&parsed_intermediates,
&parsed_anchors,
validation_time,
)
.map_err(|_| Error::CryptoFailure("server certificate chain validation failed"))?;
self.tls13_server_leaf_public_key_der = Some(leaf.subject_public_key.clone());
self.tls13_server_certificate_chain_validated = true;
Ok(())
}
fn verify_tls13_server_certificate_verify_signature(
&self,
signature_scheme: u16,
signature: &[u8],
) -> Result<()> {
let leaf_spki =
self.tls13_server_leaf_public_key_der
.as_deref()
.ok_or(Error::StateError(
"server leaf public key is unavailable for certificate verify",
))?;
let signed_message = build_tls13_server_certificate_verify_message(&self.transcript_hash());
match signature_scheme {
TLS13_SIGALG_ECDSA_SECP256R1_SHA256 => {
let public_key = P256PublicKey::from_uncompressed(leaf_spki)?;
let (r, s) = parse_ecdsa_signature_der(signature)?;
p256_ecdsa_verify_sha256(&public_key, &signed_message, &r, &s).map_err(|_| {
Error::CryptoFailure("tls13 certificate verify signature validation failed")
})
}
TLS13_SIGALG_RSA_PSS_RSAE_SHA256 => {
let public_key = parse_rsa_public_key_der(leaf_spki)?;
rsassa_pss_sha256_verify(&public_key, &signed_message, signature, 32).map_err(
|_| {
Error::CryptoFailure("tls13 certificate verify signature validation failed")
},
)
}
TLS13_SIGALG_RSA_PSS_RSAE_SHA384 => {
let public_key = parse_rsa_public_key_der(leaf_spki)?;
rsassa_pss_sha384_verify(&public_key, &signed_message, signature, 48).map_err(
|_| {
Error::CryptoFailure("tls13 certificate verify signature validation failed")
},
)
}
TLS13_SIGALG_ED25519 => {
let public_key = ed25519_public_key_from_subject_public_key_info(leaf_spki)?;
ed25519_verify(&public_key, &signed_message, signature).map_err(|_| {
Error::CryptoFailure("tls13 certificate verify signature validation failed")
})
}
TLS13_SIGALG_MLDSA65 => {
let public_key = MlDsaPublicKey::from_bytes(leaf_spki).map_err(|_| {
Error::ParseFailure("failed to parse mldsa server public key bytes")
})?;
mldsa_verify(&public_key, &signed_message, signature).map_err(|_| {
Error::CryptoFailure("tls13 certificate verify signature validation failed")
})
}
_ => Err(Error::UnsupportedFeature(
"unsupported tls13 certificate verify signature scheme",
)),
}
}
/// Test hook: overwrites the client and server record sequence counters.
pub fn set_record_sequences_for_test(&mut self, client_sequence: u64, server_sequence: u64) {
    (self.client_sequence, self.server_sequence) = (client_sequence, server_sequence);
}
/// Test hook: installs a leaf public key and marks the chain as validated.
///
/// This lets CertificateVerify checks run without a real certificate chain.
pub fn set_tls13_certificate_verify_material_for_test(&mut self, leaf_spki_der: Vec<u8>) {
    self.tls13_server_certificate_chain_validated = true;
    self.tls13_server_leaf_public_key_der = Some(leaf_spki_der);
}
/// Test hook: returns the exact bytes a server CertificateVerify signature must cover for
/// the current transcript.
pub fn tls13_server_certificate_verify_message_for_test(&self) -> Vec<u8> {
    let current_transcript_hash = self.transcript_hash();
    build_tls13_server_certificate_verify_message(&current_transcript_hash)
}
/// Installs record-protection keys derived from `secret`.
///
/// - TLS 1.3 / DTLS 1.3:
///   - Derives and stores `c hs traffic` / `s hs traffic` handshake traffic secrets, using
///     `transcript_hash` as context.
///   - Installs record keys from them via `install_tls13_record_protection_keys`.
/// - TLS 1.0–1.2 / DTLS 1.2:
///   - Expands 16-byte write keys (zero-padded into 32-byte slots) and 12-byte write IVs
///     from `secret` with `hkdf_expand_for_hash` and fixed labels.
///   - Syncs the DTLS 1.3 shim state.
///
/// NOTE(review): the legacy path is not the RFC 5246 PRF key_block derivation. It is
/// presumably intentional for this implementation; confirm.
///
/// # Errors
/// Propagates derivation errors. On error, no write keys are changed.
fn install_traffic_keys(
    &mut self,
    hash_algorithm: HashAlgorithm,
    secret: &[u8],
    transcript_hash: &[u8],
) -> Result<()> {
    match self.version {
        TlsVersion::Tls13 | TlsVersion::Dtls13 => {
            let hash_len = hash_algorithm.output_len();
            let handshake_traffic = |label: &[u8]| {
                tls13_expand_label_for_hash(hash_algorithm, secret, label, transcript_hash, hash_len)
            };
            let client_hs_traffic = handshake_traffic(b"c hs traffic")?;
            let server_hs_traffic = handshake_traffic(b"s hs traffic")?;
            self.tls13_client_handshake_traffic_secret = Some(client_hs_traffic.clone());
            self.tls13_server_handshake_traffic_secret = Some(server_hs_traffic.clone());
            self.install_tls13_record_protection_keys(
                hash_algorithm,
                &client_hs_traffic,
                &server_hs_traffic,
            )
        }
        TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => {
            // 16-byte key material stored in the low half of a 32-byte key slot.
            let padded_write_key = |label: &[u8]| -> Result<[u8; 32]> {
                let short: [u8; 16] = hkdf_expand_for_hash(hash_algorithm, secret, label, 16)?
                    .try_into()
                    .expect("hkdf output length should be 16");
                let mut slot = [0_u8; 32];
                slot[..16].copy_from_slice(&short);
                Ok(slot)
            };
            let write_iv = |label: &[u8]| -> Result<[u8; 12]> {
                Ok(hkdf_expand_for_hash(hash_algorithm, secret, label, 12)?
                    .try_into()
                    .expect("hkdf output length should be 12"))
            };
            let client_key = padded_write_key(b"client_write_key")?;
            let server_key = padded_write_key(b"server_write_key")?;
            let client_iv = write_iv(b"client_write_iv")?;
            let server_iv = write_iv(b"server_write_iv")?;
            self.client_write_key = Some(client_key);
            self.server_write_key = Some(server_key);
            self.client_write_iv = Some(client_iv);
            self.server_write_iv = Some(server_iv);
            self.sync_dtls13_traffic_keys_from_record_protection_state();
            Ok(())
        }
    }
}
fn install_tls13_application_traffic_keys(&mut self) -> Result<()> {
if !self.version.uses_tls13_handshake_semantics() {
return Ok(());
}
let hash_algorithm = self.negotiated_hash_algorithm();
let hash_len = hash_algorithm.output_len();
let transcript_hash = self.transcript_hash();
let handshake_secret = self.handshake_secret.as_ref().ok_or(Error::StateError(
"handshake secret must be available before tls13 application traffic keys",
))?;
let derived = tls13_expand_label_for_hash(
hash_algorithm,
handshake_secret,
b"derived",
&[],
hash_len,
)?;
let zero_ikm = vec![0_u8; hash_len];
let master_secret = hkdf_extract_with_salt_for_hash(hash_algorithm, &derived, &zero_ikm);
let client_app_secret = tls13_expand_label_for_hash(
hash_algorithm,
&master_secret,
b"c ap traffic",
&transcript_hash,
hash_len,
)?;
let server_app_secret = tls13_expand_label_for_hash(
hash_algorithm,
&master_secret,
b"s ap traffic",
&transcript_hash,
hash_len,
)?;
self.install_tls13_record_protection_keys(
hash_algorithm,
&client_app_secret,
&server_app_secret,
)?;
self.install_tls13_exporter_and_resumption_secrets(
hash_algorithm,
&master_secret,
&transcript_hash,
)?;
self.tls13_master_secret = Some(master_secret);
self.tls13_client_application_traffic_secret = Some(client_app_secret);
self.tls13_server_application_traffic_secret = Some(server_app_secret);
self.client_sequence = 0;
self.server_sequence = 0;
Ok(())
}
/// Derives and stores the `exp master` and `res master` secrets from `master_secret`.
///
/// Both use `transcript_hash` as context. The exporter secret is stored before the
/// resumption secret is derived.
///
/// NOTE(review): RFC 8446 §7.1 derives `res master` over the transcript through the client
/// Finished, and `exp master` through the server Finished. Here both use the same
/// `transcript_hash`. Confirm against callers.
///
/// # Errors
/// Propagates errors from `tls13_expand_label_for_hash`.
fn install_tls13_exporter_and_resumption_secrets(
    &mut self,
    hash_algorithm: HashAlgorithm,
    master_secret: &[u8],
    transcript_hash: &[u8],
) -> Result<()> {
    let hash_len = hash_algorithm.output_len();
    let exporter_master_secret = tls13_expand_label_for_hash(
        hash_algorithm,
        master_secret,
        b"exp master",
        transcript_hash,
        hash_len,
    )?;
    self.tls13_exporter_master_secret = Some(exporter_master_secret);
    let resumption_master_secret = tls13_expand_label_for_hash(
        hash_algorithm,
        master_secret,
        b"res master",
        transcript_hash,
        hash_len,
    )?;
    self.tls13_resumption_master_secret = Some(resumption_master_secret);
    Ok(())
}
fn install_tls13_record_protection_keys(
&mut self,
hash_algorithm: HashAlgorithm,
client_traffic_secret: &[u8],
server_traffic_secret: &[u8],
) -> Result<()> {
let suite = self.selected_cipher_suite.ok_or(Error::StateError(
"cipher suite must be selected before tls13 record protection keys",
))?;
let key_len = suite.tls13_traffic_key_len().ok_or(Error::StateError(
"tls 1.3 record protection requires a tls 1.3 AEAD cipher suite",
))?;
let client_key_material = tls13_expand_label_for_hash(
hash_algorithm,
client_traffic_secret,
b"key",
&[],
key_len,
)?;
let server_key_material = tls13_expand_label_for_hash(
hash_algorithm,
server_traffic_secret,
b"key",
&[],
key_len,
)?;
let mut client_key = [0_u8; 32];
let mut server_key = [0_u8; 32];
client_key[..key_len].copy_from_slice(&client_key_material);
server_key[..key_len].copy_from_slice(&server_key_material);
let client_iv: [u8; 12] =
tls13_expand_label_for_hash(hash_algorithm, client_traffic_secret, b"iv", &[], 12)?
.try_into()
.expect("tls13 iv length should be 12");
let server_iv: [u8; 12] =
tls13_expand_label_for_hash(hash_algorithm, server_traffic_secret, b"iv", &[], 12)?
.try_into()
.expect("tls13 iv length should be 12");
self.client_write_key = Some(client_key);
self.server_write_key = Some(server_key);
self.client_write_iv = Some(client_iv);
self.server_write_iv = Some(server_iv);
self.sync_dtls13_traffic_keys_from_record_protection_state();
Ok(())
}
/// For DTLS connections, mirrors the current write keys and IVs into the `dtls13_*` fields.
///
/// It also resets the DTLS 1.3 outbound epoch and sequence to 0, and replaces both inbound
/// replay trackers. It does nothing for non-DTLS versions.
///
/// NOTE(review): only the first 16 bytes of each 32-byte key slot are copied (`full[..16]`).
/// A 32-byte suite key would therefore be truncated in the DTLS 1.3 shim.
fn sync_dtls13_traffic_keys_from_record_protection_state(&mut self) {
    if self.version.is_dtls() {
        self.dtls13_client_write_iv = self.client_write_iv;
        self.dtls13_server_write_iv = self.server_write_iv;
        self.dtls13_client_write_key = self.client_write_key.map(|full| {
            full[..16]
                .try_into()
                .expect("dtls13 shim copies first 16 bytes of traffic key material")
        });
        self.dtls13_server_write_key = self.server_write_key.map(|full| {
            full[..16]
                .try_into()
                .expect("dtls13 shim copies first 16 bytes of traffic key material")
        });
        self.dtls13_outbound_epoch = 0;
        self.dtls13_outbound_sequence = 0;
        self.dtls13_inbound_replay_tracker = DtlsEpochReplayTracker::new();
        self.dtls13_client_inbound_replay_tracker = DtlsEpochReplayTracker::new();
    }
}
/// Sets `tls13_finished_key` for TLS 1.3 / DTLS 1.3, and clears it for earlier versions.
///
/// The key is `hkdf_expand_for_hash(hash, prk, "tls13 finished", HashLen)`.
///
/// NOTE(review): RFC 8446 §4.4.4 uses HKDF-Expand-Label(BaseKey, "finished", "", HashLen).
/// Whether the raw info string here matches that encoding depends on
/// `hkdf_expand_for_hash`. TODO confirm.
///
/// # Errors
/// Propagates errors from `hkdf_expand_for_hash`. The field is untouched on error.
fn install_tls13_finished_key(
    &mut self,
    hash_algorithm: HashAlgorithm,
    prk: &[u8],
) -> Result<()> {
    let finished_key = match self.version {
        TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => None,
        TlsVersion::Tls13 | TlsVersion::Dtls13 => Some(hkdf_expand_for_hash(
            hash_algorithm,
            prk,
            b"tls13 finished",
            hash_algorithm.output_len(),
        )?),
    };
    self.tls13_finished_key = finished_key;
    Ok(())
}
fn compute_expected_finished(&self) -> Result<Vec<u8>> {
let hash = self.transcript_hash();
match self.version {
TlsVersion::Tls12 | TlsVersion::Dtls12 => {
let secret = self.handshake_secret.as_ref().ok_or(Error::StateError(
"handshake secret must be available before finished",
))?;
tls12_prf_for_hash(
self.negotiated_hash_algorithm(),
secret,
b"client finished",
&hash,
12,
)
}
TlsVersion::Tls13 | TlsVersion::Dtls13 => {
let key = self.tls13_finished_key.as_ref().ok_or(Error::StateError(
"tls13 finished key must be available before finished",
))?;
Ok(finished_hmac_for_hash(
self.negotiated_hash_algorithm(),
key,
&hash,
))
}
TlsVersion::Tls10 | TlsVersion::Tls11 => Ok(finished_hmac_for_hash(
self.negotiated_hash_algorithm(),
b"finished",
&hash,
)),
}
}
/// Appends `message` to the transcript.
///
/// It feeds both the running hash state and the raw byte log, so the two stay in sync.
fn append_transcript(&mut self, message: &[u8]) {
    self.transcript_hash.update(message);
    self.transcript.extend_from_slice(message);
}
/// Empties the raw transcript and installs a fresh hash state for the connection version.
fn reset_transcript_for_new_handshake(&mut self) {
    self.transcript_hash = TranscriptHashState::for_version(self.version);
    self.transcript.clear();
}
fn reset_tls13_early_data_transcript_to_client_hello(&mut self, client_hello: &[u8]) {
self.transcript.clear();
self.transcript_hash = TranscriptHashState::for_version(self.version);
self.append_transcript(client_hello);
}
/// Generates the client's TLS 1.3 key shares from `random` and keeps the private halves.
///
/// The offered shares are X25519, secp256r1 (uncompressed point), ML-KEM-768, and the
/// X25519+ML-KEM-768 hybrid. The hybrid is the X25519 public key followed by the ML-KEM
/// public key. Each private key is stored on `self` right after it is derived. For
/// non-TLS 1.3 versions, an empty (default) share set is returned.
///
/// # Errors
/// Propagates secp256r1 / ML-KEM derivation and encoding errors.
fn prepare_client_key_share(&mut self, random: &[u8]) -> Result<Tls13ClientPublicKeyShares> {
    if !self.version.uses_tls13_handshake_semantics() {
        return Ok(Tls13ClientPublicKeyShares::default());
    }
    // X25519
    let x25519_secret = derive_deterministic_x25519_private(random, b"tls13 client x25519");
    let x25519_share = x25519_secret.clone().public_key().bytes;
    self.tls13_client_x25519_private = Some(x25519_secret);
    // secp256r1
    let p256_secret = derive_deterministic_p256_private(random, b"tls13 client secp256r1")?;
    let p256_share = p256_secret.public_key()?.to_uncompressed()?;
    self.tls13_client_p256_private = Some(p256_secret);
    // ML-KEM-768
    let (mlkem_secret, mlkem_public_key) =
        derive_deterministic_mlkem768_keypair(random, b"tls13 client mlkem768")?;
    self.tls13_client_mlkem768_private = Some(mlkem_secret);
    let mlkem_share = mlkem_public_key.as_bytes().to_vec();
    // Hybrid: X25519 public key || ML-KEM-768 public key.
    let hybrid_share = [&x25519_share[..], &mlkem_share[..]].concat();
    Ok(Tls13ClientPublicKeyShares {
        x25519: Some(x25519_share),
        secp256r1_uncompressed: Some(p256_share),
        mlkem768: Some(mlkem_share),
        x25519_mlkem768_hybrid: Some(hybrid_share),
    })
}
/// Re-hashes the raw transcript with the selected suite's hash state.
///
/// This is a no-op when no suite is selected yet.
fn rebuild_transcript_hash_from_selected_suite(&mut self) {
    if let Some(suite) = self.selected_cipher_suite {
        let mut rebuilt = suite.transcript_hash_state();
        rebuilt.update(&self.transcript);
        self.transcript_hash = rebuilt;
    }
}
fn reset_transcript_for_hrr(&mut self) {
let prior_hash = self.transcript_hash();
self.transcript.clear();
if let Some(suite) = self.selected_cipher_suite {
self.transcript_hash = suite.transcript_hash_state();
} else {
self.transcript_hash = TranscriptHashState::for_version(self.version);
}
let message_hash = encode_handshake_message(0xFE, &prior_hash);
self.append_transcript(&message_hash);
}
fn negotiated_hash_algorithm(&self) -> HashAlgorithm {
self.selected_cipher_suite
.map(CipherSuite::hash_algorithm)
.unwrap_or_else(|| self.transcript_hash.algorithm())
}
}
fn derive_tls13_handshake_secret(
hash_algorithm: HashAlgorithm,
shared_secret: &[u8],
suite: Option<CipherSuite>,
) -> Result<Vec<u8>> {
let hash_len = hash_algorithm.output_len();
let zero_ikm = vec![0_u8; hash_len];
let early_secret = hkdf_extract_for_hash(hash_algorithm, &zero_ikm);
let derived =
tls13_expand_label_for_hash(hash_algorithm, &early_secret, b"derived", &[], hash_len)?;
let mut handshake_secret =
hkdf_extract_with_salt_for_hash(hash_algorithm, &derived, shared_secret);
if let Some(selected) = suite {
if selected.hash_algorithm() != hash_algorithm {
handshake_secret =
hkdf_extract_with_salt_for_hash(selected.hash_algorithm(), &derived, shared_secret);
}
}
Ok(handshake_secret)
}
/// Combines the classical and post-quantum shared secrets of a hybrid key exchange.
///
/// Returns SHA-256 over `classical || pq`.
fn combine_tls13_hybrid_shared_secret(classical: &[u8; 32], pq: &[u8; 32]) -> [u8; 32] {
    let mut combined_input = [0_u8; 64];
    combined_input[..32].copy_from_slice(classical);
    combined_input[32..].copy_from_slice(pq);
    sha256(&combined_input)
}
fn tls12_prf_for_hash(
hash_algorithm: HashAlgorithm,
secret: &[u8],
label: &[u8],
seed: &[u8],
len: usize,
) -> Result<Vec<u8>> {
match hash_algorithm {
HashAlgorithm::Sha256 => tls12_prf_sha256(secret, label, seed, len),
HashAlgorithm::Sha384 => tls12_prf_sha384(secret, label, seed, len),
}
}
/// Compares two byte strings without returning early at the first difference.
///
/// Every index up to the longer length is visited; a missing byte reads as 0. The length
/// mismatch is folded into the accumulator, so the loop's work depends only on the input
/// lengths, not their contents. Returns `true` only when the inputs are identical.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let span = usize::max(left.len(), right.len());
    let byte_at = |bytes: &[u8], index: usize| bytes.get(index).copied().unwrap_or(0);
    let accumulated = (0..span).fold(left.len() ^ right.len(), |acc, index| {
        acc | usize::from(byte_at(left, index) ^ byte_at(right, index))
    });
    accumulated == 0
}
fn extract_first_psk_binder_from_client_hello(client_hello: &[u8]) -> Result<Vec<u8>> {
let info = parse_client_hello_info(client_hello)?;
info.extensions
.psk_binders
.first()
.cloned()
.ok_or(Error::ParseFailure(
"client hello missing pre_shared_key binder",
))
}
/// Returns a copy of `client_hello` with every binder in its `pre_shared_key` extension
/// overwritten with zero bytes.
///
/// All length fields and the message structure are left unchanged.
///
/// The body is walked field by field:
/// - the legacy_session_id length byte at body offset 34, after the 2-byte version and
///   32-byte random;
/// - cipher_suites;
/// - compression methods;
/// - the extension list.
///
/// The function returns as soon as the first `pre_shared_key` extension's binders have
/// been zeroed.
///
/// NOTE(review): body offsets are mapped into `out` by adding a fixed `body_offset` of 4,
/// i.e. this assumes a 4-byte handshake header. Confirm it is never called with a
/// DTLS-framed (12-byte header) message.
///
/// # Errors
/// `Error::ParseFailure` when:
/// - the message is not a ClientHello;
/// - any length field runs past the available bytes;
/// - the binder vector does not end exactly at the extension end;
/// - there is no `pre_shared_key` extension.
///
/// Also propagates `parse_handshake_message` errors.
fn zero_client_hello_psk_binders(client_hello: &[u8]) -> Result<Vec<u8>> {
    let (handshake_type, body) = parse_handshake_message(client_hello)?;
    if handshake_type != HANDSHAKE_CLIENT_HELLO {
        return Err(Error::ParseFailure("invalid client hello type"));
    }
    if body.len() < 39 {
        return Err(Error::ParseFailure("client hello body too short"));
    }
    // Work on a copy; only binder bytes are modified below.
    let mut out = client_hello.to_vec();
    // legacy_session_id<0..32>: 1-byte length prefix after version (2) + random (32).
    let session_id_len = body[34] as usize;
    let suites_len_offset = 35 + session_id_len;
    if body.len() < suites_len_offset + 2 {
        return Err(Error::ParseFailure(
            "client hello missing cipher suites length",
        ));
    }
    // cipher_suites<2..2^16-2>: 2-byte length prefix.
    let suites_len =
        u16::from_be_bytes([body[suites_len_offset], body[suites_len_offset + 1]]) as usize;
    let suites_end = suites_len_offset + 2 + suites_len;
    // Minimum-length sanity check before reading the compression-methods length byte.
    if body.len() < suites_end + 3 {
        return Err(Error::ParseFailure(
            "client hello missing compression methods",
        ));
    }
    // legacy_compression_methods<1..2^8-1>: 1-byte length prefix.
    let compression_methods_len = body[suites_end] as usize;
    let compression_methods_end = suites_end + 1 + compression_methods_len;
    if body.len() < compression_methods_end + 2 {
        return Err(Error::ParseFailure("client hello missing extension length"));
    }
    // extensions<8..2^16-1>: 2-byte length prefix.
    let extensions_len = u16::from_be_bytes([
        body[compression_methods_end],
        body[compression_methods_end + 1],
    ]) as usize;
    let extensions_start_in_body = compression_methods_end + 2;
    let extensions_end_in_body = extensions_start_in_body + extensions_len;
    if body.len() < extensions_end_in_body {
        return Err(Error::ParseFailure("client hello extensions truncated"));
    }
    // Offset of `body` within `client_hello` (the handshake header length assumed here).
    let body_offset = 4;
    let mut ext_cursor = extensions_start_in_body;
    while ext_cursor < extensions_end_in_body {
        // Each extension: 2-byte type, 2-byte length, then data.
        if extensions_end_in_body - ext_cursor < 4 {
            return Err(Error::ParseFailure(
                "client hello extension header truncated",
            ));
        }
        let ext_type = u16::from_be_bytes([body[ext_cursor], body[ext_cursor + 1]]);
        let ext_len = u16::from_be_bytes([body[ext_cursor + 2], body[ext_cursor + 3]]) as usize;
        let ext_data_start = ext_cursor + 4;
        let ext_data_end = ext_data_start + ext_len;
        if ext_data_end > extensions_end_in_body {
            return Err(Error::ParseFailure("client hello extension truncated"));
        }
        if ext_type == EXT_PRE_SHARED_KEY {
            // OfferedPsks: identities<7..2^16-1> followed by binders<33..2^16-1>.
            if ext_len < 4 {
                return Err(Error::ParseFailure("pre_shared_key extension too short"));
            }
            let identities_len =
                u16::from_be_bytes([body[ext_data_start], body[ext_data_start + 1]]) as usize;
            // Identities plus their own 2-byte prefix plus the binders' 2-byte prefix.
            if ext_len < 2 + identities_len + 2 {
                return Err(Error::ParseFailure("pre_shared_key identities truncated"));
            }
            let binders_len_offset = ext_data_start + 2 + identities_len;
            let binders_len =
                u16::from_be_bytes([body[binders_len_offset], body[binders_len_offset + 1]])
                    as usize;
            let mut binder_cursor = binders_len_offset + 2;
            let binders_end = binder_cursor + binders_len;
            // The binder vector must fill the rest of the extension exactly.
            if binders_end != ext_data_end {
                return Err(Error::ParseFailure(
                    "invalid pre_shared_key binder vector length",
                ));
            }
            while binder_cursor < binders_end {
                // Each PskBinderEntry: 1-byte length prefix, then the binder bytes.
                let binder_len = body[binder_cursor] as usize;
                binder_cursor += 1;
                if binder_cursor + binder_len > binders_end {
                    return Err(Error::ParseFailure("pre_shared_key binder bytes truncated"));
                }
                // Translate the body-relative span into `out` (which includes the header).
                let start = body_offset + binder_cursor;
                let end = start + binder_len;
                out[start..end].fill(0);
                binder_cursor += binder_len;
            }
            return Ok(out);
        }
        ext_cursor = ext_data_end;
    }
    Err(Error::ParseFailure(
        "client hello missing pre_shared_key extension",
    ))
}
/// Returns the client's default cipher suite preference list for `version`.
///
/// - TLS 1.3 / DTLS 1.3: AES-256-GCM-SHA384, AES-128-GCM-SHA256, ChaCha20-Poly1305-SHA256.
/// - Earlier versions: ECDHE-RSA with AES-256-GCM-SHA384, then AES-128-GCM-SHA256.
fn default_client_cipher_suites(version: TlsVersion) -> Vec<CipherSuite> {
    let preferred: &[CipherSuite] = match version {
        TlsVersion::Tls13 | TlsVersion::Dtls13 => &[
            CipherSuite::TlsAes256GcmSha384,
            CipherSuite::TlsAes128GcmSha256,
            CipherSuite::TlsChacha20Poly1305Sha256,
        ],
        TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => &[
            CipherSuite::TlsEcdheRsaWithAes256GcmSha384,
            CipherSuite::TlsEcdheRsaWithAes128GcmSha256,
        ],
    };
    preferred.to_vec()
}
#[allow(clippy::too_many_arguments)]
fn encode_client_hello_body(
version: TlsVersion,
random: &[u8],
suites: &[CipherSuite],
key_shares: &Tls13ClientPublicKeyShares,
sni_server_name: Option<&str>,
alpn_protocols: &[Vec<u8>],
request_ocsp_stapling: bool,
offer_early_data: bool,
psk_offer: Option<&PskClientOffer<'_>>,
tls12_session_id: Option<&[u8]>,
) -> Result<Vec<u8>> {
if random.len() != 32 {
return Err(Error::InvalidLength("client hello random must be 32 bytes"));
}
if suites.is_empty() {
return Err(Error::InvalidLength(
"client hello suite list must not be empty",
));
}
let mut body = Vec::new();
body.extend_from_slice(&legacy_wire_version(version));
body.extend_from_slice(random);
if version == TlsVersion::Tls12 {
let session_id = tls12_session_id.unwrap_or(&[]);
if session_id.len() > 32 {
return Err(Error::InvalidLength(
"tls12 session id must not exceed 32 bytes",
));
}
body.push(session_id.len() as u8);
body.extend_from_slice(session_id);
} else {
body.push(0x00); }
body.extend_from_slice(&((suites.len() * 2) as u16).to_be_bytes());
for suite in suites {
body.extend_from_slice(&suite.to_u16().to_be_bytes());
}
body.extend_from_slice(&[0x01, 0x00]); let extensions = build_client_hello_extensions(
version,
key_shares,
sni_server_name,
alpn_protocols,
request_ocsp_stapling,
offer_early_data,
psk_offer,
)?;
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
Ok(body)
}
/// Encodes a ServerHello body with no explicit key_share override.
///
/// See `encode_server_hello_body_with_key_share`. For TLS 1.3, the X25519 key_share is then
/// derived from `random` inside that function.
///
/// # Errors
/// Same as `encode_server_hello_body_with_key_share`.
fn encode_server_hello_body(
    version: TlsVersion,
    suite: CipherSuite,
    random: &[u8],
) -> Result<Vec<u8>> {
    let no_override: Option<(u16, &[u8])> = None;
    encode_server_hello_body_with_key_share(version, suite, random, no_override)
}
fn encode_server_hello_body_with_key_share(
version: TlsVersion,
suite: CipherSuite,
random: &[u8],
key_share_override: Option<(u16, &[u8])>,
) -> Result<Vec<u8>> {
if random.len() != 32 {
return Err(Error::InvalidLength("server hello random must be 32 bytes"));
}
let mut body = Vec::new();
body.extend_from_slice(&legacy_wire_version(version));
body.extend_from_slice(random);
body.push(0x00); body.extend_from_slice(&suite.to_u16().to_be_bytes());
body.push(0x00); let mut extensions = Vec::new();
if version.uses_tls13_handshake_semantics() {
push_extension(
&mut extensions,
EXT_SUPPORTED_VERSIONS,
&0x0304_u16.to_be_bytes(),
);
let mut key_share = Vec::new();
if let Some((g, bytes)) = key_share_override {
if g == TLS13_KEY_SHARE_GROUP_X25519 && bytes.len() != 32 {
return Err(Error::ParseFailure(
"invalid x25519 server key_share key_exchange length",
));
}
if g == TLS13_KEY_SHARE_GROUP_SECP256R1 && bytes.len() != 65 {
return Err(Error::ParseFailure(
"invalid secp256r1 server key_share key_exchange length",
));
}
if g == TLS13_KEY_SHARE_GROUP_MLKEM768 && bytes.len() != MLKEM_CIPHERTEXT_LEN {
return Err(Error::ParseFailure(
"invalid mlkem768 server key_share key_exchange length",
));
}
if g == TLS13_KEY_SHARE_GROUP_X25519_MLKEM768_HYBRID
&& bytes.len() != (32 + MLKEM_CIPHERTEXT_LEN)
{
return Err(Error::ParseFailure(
"invalid x25519_mlkem768 hybrid server key_share key_exchange length",
));
}
if g != TLS13_KEY_SHARE_GROUP_X25519
&& g != TLS13_KEY_SHARE_GROUP_SECP256R1
&& g != TLS13_KEY_SHARE_GROUP_MLKEM768
&& g != TLS13_KEY_SHARE_GROUP_X25519_MLKEM768_HYBRID
{
return Err(Error::ParseFailure("unsupported server key_share group"));
}
key_share.extend_from_slice(&g.to_be_bytes());
key_share.extend_from_slice(&(bytes.len() as u16).to_be_bytes());
key_share.extend_from_slice(bytes);
} else {
let private = derive_deterministic_x25519_private(random, b"tls13 server x25519");
let public = private.public_key().bytes;
key_share.extend_from_slice(&TLS13_KEY_SHARE_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&32_u16.to_be_bytes());
key_share.extend_from_slice(&public);
}
push_extension(&mut extensions, EXT_KEY_SHARE, &key_share);
}
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
Ok(body)
}
fn parse_server_hello(msg: &[u8]) -> Result<ParsedServerHello> {
if msg.len() == 3 && msg.first().copied() == Some(HANDSHAKE_SERVER_HELLO) {
let suite_id = u16::from_be_bytes([msg[1], msg[2]]);
let suite = CipherSuite::from_u16(suite_id)
.ok_or(Error::ParseFailure("unsupported cipher suite"))?;
return Ok(ParsedServerHello {
suite,
key_share: None,
hello_retry_request: false,
requested_group: None,
});
}
let (handshake_type, body) = parse_handshake_message(msg)?;
if handshake_type != HANDSHAKE_SERVER_HELLO {
return Err(Error::ParseFailure("invalid server hello type"));
}
if body.len() < 40 {
return Err(Error::ParseFailure("server hello body too short"));
}
let session_id_len = body[34] as usize;
let suite_start = 35 + session_id_len;
let suite_end = suite_start + 2;
if body.len() < suite_end + 3 {
return Err(Error::ParseFailure("server hello missing cipher suite"));
}
let suite_id = u16::from_be_bytes([body[suite_start], body[suite_start + 1]]);
let suite =
CipherSuite::from_u16(suite_id).ok_or(Error::ParseFailure("unsupported cipher suite"))?;
let legacy_version = u16::from_be_bytes([body[0], body[1]]);
if is_tls13_suite(suite) && legacy_version != 0x0303 && legacy_version != 0xFEFD {
return Err(Error::ParseFailure(
"invalid tls13 server hello legacy_version",
));
}
let compression_method = body[suite_end];
if compression_method != 0x00 {
return Err(Error::ParseFailure(
"invalid server hello compression method",
));
}
let random = &body[2..34];
let hello_retry_request = random == TLS13_HRR_RANDOM;
let mut key_share_parsed = None;
let mut requested_group = None;
let mut seen_key_share_extension = false;
let mut seen_supported_versions_extension = false;
let mut supports_tls13 = false;
let mut seen_extension_types = Vec::new();
let ext_len_offset = suite_end + 1;
let ext_len = u16::from_be_bytes([body[ext_len_offset], body[ext_len_offset + 1]]) as usize;
let ext_start = ext_len_offset + 2;
let ext_end = ext_start + ext_len;
if ext_end > body.len() {
return Err(Error::ParseFailure("server hello extensions truncated"));
}
let mut cursor = &body[ext_start..ext_end];
while !cursor.is_empty() {
if cursor.len() < 4 {
return Err(Error::ParseFailure(
"server hello extension header truncated",
));
}
let ext_type = u16::from_be_bytes([cursor[0], cursor[1]]);
let ext_data_len = u16::from_be_bytes([cursor[2], cursor[3]]) as usize;
cursor = &cursor[4..];
if cursor.len() < ext_data_len {
return Err(Error::ParseFailure("server hello extension truncated"));
}
if seen_extension_types.contains(&ext_type) {
return Err(Error::ParseFailure("duplicate server hello extension type"));
}
seen_extension_types.push(ext_type);
let ext_data = &cursor[..ext_data_len];
match ext_type {
EXT_SIGNATURE_ALGORITHMS | EXT_PSK_KEY_EXCHANGE_MODES | EXT_SERVER_NAME => {
return Err(Error::ParseFailure(
"server hello contains forbidden extension type",
));
}
EXT_SUPPORTED_VERSIONS => {
if ext_data_len != 2 {
return Err(Error::ParseFailure(
"invalid server hello supported_versions length",
));
}
seen_supported_versions_extension = true;
let selected_version = u16::from_be_bytes([ext_data[0], ext_data[1]]);
if selected_version != 0x0304 {
return Err(Error::ParseFailure(
"invalid tls13 server hello supported_versions value",
));
}
supports_tls13 = true;
}
EXT_KEY_SHARE => {
seen_key_share_extension = true;
if hello_retry_request {
if ext_data_len != 2 {
return Err(Error::ParseFailure("invalid hrr key_share length"));
}
requested_group = Some(u16::from_be_bytes([ext_data[0], ext_data[1]]));
} else {
if ext_data_len < 4 {
return Err(Error::ParseFailure("invalid server key_share length"));
}
let group = u16::from_be_bytes([ext_data[0], ext_data[1]]);
let key_len = u16::from_be_bytes([ext_data[2], ext_data[3]]) as usize;
if ext_data_len != 4 + key_len {
return Err(Error::ParseFailure("invalid server key_share length"));
}
key_share_parsed = Some(match group {
TLS13_KEY_SHARE_GROUP_X25519 => {
if key_len != 32 {
return Err(Error::ParseFailure(
"invalid x25519 server key_share key_exchange length",
));
}
let mut key = [0_u8; 32];
key.copy_from_slice(&ext_data[4..36]);
Tls13ServerKeyShareParsed::X25519(key)
}
TLS13_KEY_SHARE_GROUP_SECP256R1 => {
if key_len != 65 {
return Err(Error::ParseFailure(
"invalid secp256r1 server key_share key_exchange length",
));
}
let mut key = [0_u8; 65];
key.copy_from_slice(&ext_data[4..69]);
Tls13ServerKeyShareParsed::Secp256r1(key)
}
TLS13_KEY_SHARE_GROUP_MLKEM768 => {
if key_len != MLKEM_CIPHERTEXT_LEN {
return Err(Error::ParseFailure(
"invalid mlkem768 server key_share key_exchange length",
));
}
Tls13ServerKeyShareParsed::MlKem768(ext_data[4..].to_vec())
}
TLS13_KEY_SHARE_GROUP_X25519_MLKEM768_HYBRID => {
if key_len != (32 + MLKEM_CIPHERTEXT_LEN) {
return Err(Error::ParseFailure(
"invalid x25519_mlkem768 hybrid server key_share key_exchange length",
));
}
let mut x25519 = [0_u8; 32];
x25519.copy_from_slice(&ext_data[4..36]);
let mlkem768 = ext_data[36..].to_vec();
Tls13ServerKeyShareParsed::X25519MlKem768Hybrid { x25519, mlkem768 }
}
_ => {
return Err(Error::ParseFailure("unsupported server key_share"));
}
});
}
}
_ => {}
}
cursor = &cursor[ext_data_len..];
}
if hello_retry_request && !seen_key_share_extension {
return Err(Error::ParseFailure("hrr missing key_share extension"));
}
if !hello_retry_request
&& is_tls13_suite(suite)
&& legacy_version == 0x0303
&& !seen_supported_versions_extension
{
return Err(Error::ParseFailure(
"tls13 server hello missing supported_versions extension",
));
}
if !hello_retry_request && is_tls13_suite(suite) && legacy_version == 0x0303 && !supports_tls13
{
return Err(Error::ParseFailure(
"invalid tls13 server hello supported_versions value",
));
}
if !hello_retry_request
&& is_tls13_suite(suite)
&& legacy_version == 0x0303
&& !seen_key_share_extension
{
return Err(Error::ParseFailure(
"tls13 server hello missing key_share extension",
));
}
Ok(ParsedServerHello {
suite,
key_share: key_share_parsed,
hello_retry_request,
requested_group,
})
}
/// Returns `true` when `suite` is one of the three TLS 1.3 AEAD suites this
/// stack implements (AES-128-GCM/SHA-256, AES-256-GCM/SHA-384,
/// ChaCha20-Poly1305/SHA-256).
fn is_tls13_suite(suite: CipherSuite) -> bool {
    const TLS13_SUITES: [CipherSuite; 3] = [
        CipherSuite::TlsAes128GcmSha256,
        CipherSuite::TlsAes256GcmSha384,
        CipherSuite::TlsChacha20Poly1305Sha256,
    ];
    TLS13_SUITES.contains(&suite)
}
/// Parses a ClientHello handshake message into its recognized cipher suites
/// and decoded extensions.
///
/// Cipher suite codepoints unknown to `CipherSuite::from_u16` are skipped, but
/// at least one recognized suite must remain. The extensions block must end
/// exactly at the end of the body.
///
/// # Errors
/// Returns `Error::ParseFailure` when:
/// - the message is not a ClientHello;
/// - a length field is inconsistent with the body;
/// - `legacy_session_id` exceeds 32 bytes;
/// - the compression methods do not include `null` (0);
/// - no supported cipher suite is offered;
/// - extension parsing fails.
fn parse_client_hello_info(msg: &[u8]) -> Result<ClientHelloInfo> {
    let (handshake_type, body) = parse_handshake_message(msg)?;
    if handshake_type != HANDSHAKE_CLIENT_HELLO {
        return Err(Error::ParseFailure("invalid client hello type"));
    }
    // legacy_version(2) + random(32) + session_id len(1) + suites len(2) + one suite(2).
    if body.len() < 39 {
        return Err(Error::ParseFailure("client hello body too short"));
    }
    let session_id_len = body[34] as usize;
    // legacy_session_id is opaque<0..32> (RFC 8446 §4.1.2, RFC 5246 §7.4.1.2).
    if session_id_len > 32 {
        return Err(Error::ParseFailure(
            "client hello legacy_session_id too long",
        ));
    }
    let suites_len_offset = 35 + session_id_len;
    if body.len() < suites_len_offset + 2 {
        return Err(Error::ParseFailure(
            "client hello missing cipher suites length",
        ));
    }
    let suites_len =
        u16::from_be_bytes([body[suites_len_offset], body[suites_len_offset + 1]]) as usize;
    if suites_len == 0 || !suites_len.is_multiple_of(2) {
        return Err(Error::ParseFailure(
            "invalid client hello cipher suites length",
        ));
    }
    let suites_start = suites_len_offset + 2;
    let suites_end = suites_start + suites_len;
    // Room for compression_methods length(1) and at least the extensions length(2).
    if body.len() < suites_end + 3 {
        return Err(Error::ParseFailure("client hello cipher suites truncated"));
    }
    let mut suites = Vec::new();
    for chunk in body[suites_start..suites_end].chunks_exact(2) {
        let codepoint = u16::from_be_bytes([chunk[0], chunk[1]]);
        if let Some(suite) = CipherSuite::from_u16(codepoint) {
            suites.push(suite);
        }
    }
    if suites.is_empty() {
        return Err(Error::ParseFailure(
            "client hello has no supported cipher suite",
        ));
    }
    let compression_methods_len = body[suites_end] as usize;
    let compression_methods_start = suites_end + 1;
    let compression_methods_end = compression_methods_start + compression_methods_len;
    if body.len() < compression_methods_end + 2 {
        return Err(Error::ParseFailure(
            "client hello missing compression methods",
        ));
    }
    // compression_methods is opaque<1..2^8-1> and MUST offer null compression;
    // an empty list also fails this check.
    if !body[compression_methods_start..compression_methods_end].contains(&0x00) {
        return Err(Error::ParseFailure(
            "client hello must offer null compression method",
        ));
    }
    let extensions_len = u16::from_be_bytes([
        body[compression_methods_end],
        body[compression_methods_end + 1],
    ]) as usize;
    let extensions_start = compression_methods_end + 2;
    let extensions_end = extensions_start + extensions_len;
    if body.len() < extensions_end {
        return Err(Error::ParseFailure("client hello extensions truncated"));
    }
    if body.len() != extensions_end {
        return Err(Error::ParseFailure("client hello has trailing bytes"));
    }
    let extensions = parse_client_hello_extensions(&body[extensions_start..extensions_end])?;
    Ok(ClientHelloInfo {
        offered_cipher_suites: suites,
        extensions,
    })
}
fn pick_intersection_suite(
hello: &ClientHelloInfo,
preferred: &[CipherSuite],
version: TlsVersion,
) -> Result<CipherSuite> {
for suite in preferred {
if !hello.offered_cipher_suites.contains(suite) {
continue;
}
if !suite_supported_by_version(*suite, version) {
continue;
}
if suite_allowed_by_extensions(*suite, version, &hello.extensions) {
return Ok(*suite);
}
}
Err(Error::ParseFailure("no mutually supported cipher suite"))
}
fn suite_supported_by_version(suite: CipherSuite, version: TlsVersion) -> bool {
match version {
TlsVersion::Tls13 | TlsVersion::Dtls13 => matches!(
suite,
CipherSuite::TlsAes128GcmSha256
| CipherSuite::TlsAes256GcmSha384
| CipherSuite::TlsChacha20Poly1305Sha256
),
TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => {
matches!(
suite,
CipherSuite::TlsEcdheRsaWithAes128GcmSha256
| CipherSuite::TlsEcdheRsaWithAes256GcmSha384
)
}
}
}
fn suite_allowed_by_extensions(
suite: CipherSuite,
version: TlsVersion,
extensions: &ClientHelloExtensions,
) -> bool {
match version {
TlsVersion::Tls13 | TlsVersion::Dtls13 => {
if matches!(
suite,
CipherSuite::TlsAes128GcmSha256
| CipherSuite::TlsAes256GcmSha384
| CipherSuite::TlsChacha20Poly1305Sha256
) {
return tls13_client_hello_offers_supported_key_exchange(
&extensions.supported_versions,
&extensions.key_share_groups,
&extensions.signature_algorithms,
);
}
true
}
TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => true,
}
}
/// Builds the serialized ClientHello extensions block (without its outer u16
/// length prefix) for the requested protocol `version`.
///
/// For TLS 1.3 / DTLS 1.3 the extensions are emitted in this order:
/// 1. `supported_versions` (0x0304, 0x0303).
/// 2. `signature_algorithms` (five schemes).
/// 3. `key_share`, with one entry per public share present in `key_shares`.
/// 4. Optionally `server_name`, `status_request` (OCSP) and ALPN.
/// 5. Optionally `early_data`.
/// 6. When `psk_offer` is set, `psk_key_exchange_modes` followed by
///    `pre_shared_key`, which is emitted last.
///
/// For TLS 1.0–1.2 / DTLS 1.2 only a `signature_algorithms` extension listing
/// 0x0401 (rsa_pkcs1_sha256) is emitted.
///
/// # Errors
/// - `Error::InvalidLength` if no key share is available for a TLS 1.3 hello.
/// - `Error::StateError` if `offer_early_data` is set without a `psk_offer`.
/// - Any error from the SNI, status_request, ALPN or pre_shared_key encoders.
fn build_client_hello_extensions(
    version: TlsVersion,
    key_shares: &Tls13ClientPublicKeyShares,
    sni_server_name: Option<&str>,
    alpn_protocols: &[Vec<u8>],
    request_ocsp_stapling: bool,
    offer_early_data: bool,
    psk_offer: Option<&PskClientOffer<'_>>,
) -> Result<Vec<u8>> {
    let mut extensions = Vec::new();
    match version {
        TlsVersion::Tls13 | TlsVersion::Dtls13 => {
            // supported_versions: u8 byte length (4) then two u16 versions.
            // NOTE(review): DTLS 1.3 also advertises 0x0304/0x0303 here rather
            // than the DTLS codepoints; confirm this matches the peer expectations.
            let mut supported_versions = Vec::new();
            supported_versions.push(4_u8);
            supported_versions.extend_from_slice(&0x0304_u16.to_be_bytes());
            supported_versions.extend_from_slice(&0x0303_u16.to_be_bytes());
            push_extension(&mut extensions, EXT_SUPPORTED_VERSIONS, &supported_versions);
            // signature_algorithms: u16 byte length then one u16 per scheme.
            let mut sigalgs = Vec::new();
            let supported_sigalgs = [
                TLS13_SIGALG_ECDSA_SECP256R1_SHA256,
                TLS13_SIGALG_RSA_PSS_RSAE_SHA256,
                TLS13_SIGALG_RSA_PSS_RSAE_SHA384,
                TLS13_SIGALG_ED25519,
                TLS13_SIGALG_MLDSA65,
            ];
            sigalgs.extend_from_slice(&((supported_sigalgs.len() * 2) as u16).to_be_bytes());
            for sigalg in supported_sigalgs {
                sigalgs.extend_from_slice(&sigalg.to_be_bytes());
            }
            push_extension(&mut extensions, EXT_SIGNATURE_ALGORITHMS, &sigalgs);
            // key_share client_shares: each entry is group(u16) || len(u16) || key.
            let mut key_share_list = Vec::new();
            if let Some(public) = key_shares.x25519 {
                key_share_list.extend_from_slice(&TLS13_KEY_SHARE_GROUP_X25519.to_be_bytes());
                key_share_list.extend_from_slice(&32_u16.to_be_bytes());
                key_share_list.extend_from_slice(&public);
            }
            if let Some(public) = key_shares.secp256r1_uncompressed {
                key_share_list.extend_from_slice(&TLS13_KEY_SHARE_GROUP_SECP256R1.to_be_bytes());
                key_share_list.extend_from_slice(&65_u16.to_be_bytes());
                key_share_list.extend_from_slice(&public);
            }
            // Assumes the ML-KEM public keys fit a u16 length (the cast is unchecked).
            if let Some(public) = key_shares.mlkem768.as_ref() {
                key_share_list.extend_from_slice(&TLS13_KEY_SHARE_GROUP_MLKEM768.to_be_bytes());
                key_share_list.extend_from_slice(&(public.len() as u16).to_be_bytes());
                key_share_list.extend_from_slice(public);
            }
            if let Some(public) = key_shares.x25519_mlkem768_hybrid.as_ref() {
                key_share_list
                    .extend_from_slice(&TLS13_KEY_SHARE_GROUP_X25519_MLKEM768_HYBRID.to_be_bytes());
                key_share_list.extend_from_slice(&(public.len() as u16).to_be_bytes());
                key_share_list.extend_from_slice(public);
            }
            // A TLS 1.3 hello without any key share cannot complete (EC)DHE/KEM.
            if key_share_list.is_empty() {
                return Err(Error::InvalidLength(
                    "tls13 client hello key_share extension must not be empty",
                ));
            }
            let mut key_share_ext = Vec::new();
            key_share_ext.extend_from_slice(&(key_share_list.len() as u16).to_be_bytes());
            key_share_ext.extend_from_slice(&key_share_list);
            push_extension(&mut extensions, EXT_KEY_SHARE, &key_share_ext);
            if let Some(server_name) = sni_server_name {
                let server_name_extension_data = encode_server_name_extension_data(server_name)?;
                push_extension(
                    &mut extensions,
                    EXT_SERVER_NAME,
                    &server_name_extension_data,
                );
            }
            if request_ocsp_stapling {
                let status_request_data = encode_status_request_ocsp_extension_data()?;
                push_extension(&mut extensions, EXT_STATUS_REQUEST, &status_request_data);
            }
            if !alpn_protocols.is_empty() {
                let alpn_extension_data = encode_alpn_extension_data(alpn_protocols)?;
                push_extension(&mut extensions, EXT_ALPN, &alpn_extension_data);
            }
            // early_data is only meaningful together with a PSK offer.
            if offer_early_data {
                if psk_offer.is_none() {
                    return Err(Error::StateError(
                        "tls13 early_data extension requires pre_shared_key offer",
                    ));
                }
                push_extension(&mut extensions, EXT_EARLY_DATA, &[]);
            }
            // pre_shared_key is pushed after every other extension, so it is last.
            if let Some(psk) = psk_offer {
                let psk_key_exchange_modes = [1_u8, TLS13_PSK_KEY_EXCHANGE_MODE_PSK_DHE_KE];
                push_extension(
                    &mut extensions,
                    EXT_PSK_KEY_EXCHANGE_MODES,
                    &psk_key_exchange_modes,
                );
                let psk_extension = encode_pre_shared_key_extension(psk)?;
                push_extension(&mut extensions, EXT_PRE_SHARED_KEY, &psk_extension);
            }
        }
        TlsVersion::Tls10 | TlsVersion::Tls11 | TlsVersion::Tls12 | TlsVersion::Dtls12 => {
            // Single-entry signature_algorithms list: 0x0401.
            let mut sigalgs = Vec::new();
            sigalgs.extend_from_slice(&2_u16.to_be_bytes());
            sigalgs.extend_from_slice(&0x0401_u16.to_be_bytes());
            push_extension(&mut extensions, EXT_SIGNATURE_ALGORITHMS, &sigalgs);
        }
    }
    Ok(extensions)
}
/// Appends one TLS extension, encoded as `type(u16) || length(u16) || data`, to `out`.
///
/// The length field is 16 bits wide, so callers must keep `ext_data` within
/// `u16::MAX` bytes. Debug builds assert this; release builds keep the previous
/// behavior of truncating the length prefix.
fn push_extension(out: &mut Vec<u8>, ext_type: u16, ext_data: &[u8]) {
    debug_assert!(
        ext_data.len() <= u16::MAX as usize,
        "tls extension data exceeds the 16-bit length field"
    );
    out.reserve(4 + ext_data.len());
    out.extend_from_slice(&ext_type.to_be_bytes());
    out.extend_from_slice(&(ext_data.len() as u16).to_be_bytes());
    out.extend_from_slice(ext_data);
}
fn parse_client_hello_extensions(input: &[u8]) -> Result<ClientHelloExtensions> {
let mut out = ClientHelloExtensions::default();
let mut cursor = input;
let mut seen_supported_versions = false;
let mut seen_signature_algorithms = false;
let mut seen_key_share = false;
let mut seen_psk_key_exchange_modes = false;
let mut seen_pre_shared_key = false;
let mut seen_early_data = false;
let mut seen_extension_types = Vec::new();
while !cursor.is_empty() {
if cursor.len() < 4 {
return Err(Error::ParseFailure(
"client hello extension header truncated",
));
}
let ext_type = u16::from_be_bytes([cursor[0], cursor[1]]);
let ext_len = u16::from_be_bytes([cursor[2], cursor[3]]) as usize;
cursor = &cursor[4..];
if cursor.len() < ext_len {
return Err(Error::ParseFailure("client hello extension truncated"));
}
let ext_data = &cursor[..ext_len];
if seen_extension_types.contains(&ext_type) {
return Err(Error::ParseFailure("duplicate client hello extension type"));
}
seen_extension_types.push(ext_type);
if seen_pre_shared_key {
return Err(Error::ParseFailure(
"pre_shared_key extension must be the last extension",
));
}
match ext_type {
EXT_SUPPORTED_VERSIONS => {
if seen_supported_versions {
return Err(Error::ParseFailure(
"duplicate supported_versions extension",
));
}
out.supported_versions = parse_supported_versions_extension(ext_data)?;
seen_supported_versions = true;
}
EXT_SIGNATURE_ALGORITHMS => {
if seen_signature_algorithms {
return Err(Error::ParseFailure(
"duplicate signature_algorithms extension",
));
}
out.signature_algorithms = parse_u16_vector_with_len(ext_data)?;
if out.signature_algorithms.is_empty() {
return Err(Error::ParseFailure(
"signature_algorithms extension must not be empty",
));
}
seen_signature_algorithms = true;
}
EXT_KEY_SHARE => {
if seen_key_share {
return Err(Error::ParseFailure("duplicate key_share extension"));
}
out.key_share_groups = parse_key_share_groups_extension(ext_data)?;
seen_key_share = true;
}
EXT_SERVER_NAME => {
out.sni_server_name = Some(parse_server_name_extension(ext_data)?);
}
EXT_ALPN => {
out.alpn_protocols = parse_alpn_protocol_name_list(ext_data)?;
}
EXT_STATUS_REQUEST => {
out.status_request_ocsp = parse_status_request_ocsp_extension(ext_data)?;
}
EXT_PSK_KEY_EXCHANGE_MODES => {
if seen_psk_key_exchange_modes {
return Err(Error::ParseFailure(
"duplicate psk_key_exchange_modes extension",
));
}
out.psk_key_exchange_modes = parse_u8_vector_with_len(ext_data)?;
if !out
.psk_key_exchange_modes
.contains(&TLS13_PSK_KEY_EXCHANGE_MODE_PSK_DHE_KE)
{
return Err(Error::ParseFailure(
"psk_key_exchange_modes must include psk_dhe_ke",
));
}
seen_psk_key_exchange_modes = true;
}
EXT_PRE_SHARED_KEY => {
if seen_pre_shared_key {
return Err(Error::ParseFailure("duplicate pre_shared_key extension"));
}
let (identity_count, identities, obfuscated_ages, binders) =
parse_pre_shared_key_extension(ext_data)?;
out.psk_identity_count = identity_count;
out.psk_identities = identities;
out.psk_obfuscated_ticket_ages = obfuscated_ages;
out.psk_binders = binders;
seen_pre_shared_key = true;
}
EXT_EARLY_DATA => {
if seen_early_data {
return Err(Error::ParseFailure("duplicate early_data extension"));
}
if !ext_data.is_empty() {
return Err(Error::ParseFailure(
"client hello early_data extension must be empty",
));
}
out.early_data_offered = true;
seen_early_data = true;
}
_ => {}
}
cursor = &cursor[ext_len..];
}
if seen_pre_shared_key && !seen_psk_key_exchange_modes {
return Err(Error::ParseFailure(
"pre_shared_key extension requires psk_key_exchange_modes extension",
));
}
if seen_early_data && !seen_pre_shared_key {
return Err(Error::ParseFailure(
"early_data extension requires pre_shared_key extension",
));
}
if seen_psk_key_exchange_modes && !seen_pre_shared_key {
return Err(Error::ParseFailure(
"psk_key_exchange_modes extension requires pre_shared_key extension",
));
}
if seen_key_share && out.key_share_groups.is_empty() {
return Err(Error::ParseFailure("key_share extension must not be empty"));
}
let advertises_tls13 = out.supported_versions.contains(&0x0304);
if seen_pre_shared_key && !advertises_tls13 {
return Err(Error::ParseFailure(
"pre_shared_key extension requires tls13 supported_versions entry",
));
}
if seen_key_share && !advertises_tls13 {
return Err(Error::ParseFailure(
"key_share extension requires tls13 supported_versions entry",
));
}
if advertises_tls13 && !seen_signature_algorithms {
return Err(Error::ParseFailure(
"tls13 supported_versions requires signature_algorithms extension",
));
}
if advertises_tls13 && !seen_key_share {
return Err(Error::ParseFailure(
"tls13 supported_versions requires key_share extension",
));
}
if seen_pre_shared_key && !seen_key_share {
return Err(Error::ParseFailure(
"pre_shared_key with psk_dhe_ke requires key_share extension",
));
}
Ok(out)
}
/// Parses a ClientHello `supported_versions` body (RFC 8446 §4.2.1).
///
/// The body is a u8 byte length followed by `ProtocolVersion versions<2..254>`.
///
/// # Errors
/// `Error::ParseFailure` when:
/// - the body is empty;
/// - the declared length is zero (at least one version is required), odd, or
///   different from the number of bytes that follow;
/// - a version appears twice.
fn parse_supported_versions_extension(input: &[u8]) -> Result<Vec<u16>> {
    let Some((&declared, entries)) = input.split_first() else {
        return Err(Error::ParseFailure("supported_versions extension is empty"));
    };
    let declared = usize::from(declared);
    if declared == 0 || entries.len() != declared || !declared.is_multiple_of(2) {
        return Err(Error::ParseFailure(
            "invalid supported_versions extension length",
        ));
    }
    let mut versions = Vec::with_capacity(declared / 2);
    for chunk in entries.chunks_exact(2) {
        let version = u16::from_be_bytes([chunk[0], chunk[1]]);
        if versions.contains(&version) {
            return Err(Error::ParseFailure(
                "duplicate supported_versions entry in extension body",
            ));
        }
        versions.push(version);
    }
    Ok(versions)
}
/// Returns `true` when `name` is acceptable as an SNI `host_name`.
///
/// Requirements:
/// - non-empty ASCII, with one optional trailing dot tolerated;
/// - at most `u16::MAX` bytes once that dot is removed;
/// - every dot-separated label is 1..=63 letters, digits or hyphens and does
///   not start or end with a hyphen.
fn is_valid_sni_dns_name(name: &str) -> bool {
    if name.is_empty() || !name.is_ascii() {
        return false;
    }
    let host = name.strip_suffix('.').unwrap_or(name);
    if host.is_empty() || host.len() > u16::MAX as usize {
        return false;
    }
    // Reject control characters, space and DEL before looking at labels.
    if host.bytes().any(|b| !(0x21..0x7f).contains(&b)) {
        return false;
    }
    host.split('.').all(|label| {
        let bytes = label.as_bytes();
        (1..=63).contains(&bytes.len())
            && !bytes.starts_with(b"-")
            && !bytes.ends_with(b"-")
            && bytes.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-')
    })
}
/// Parses a ClientHello `server_name` body (RFC 6066 §3) and returns its single
/// `host_name` entry.
///
/// Exactly one entry of name_type `host_name` (0) is supported. The name must
/// pass `is_valid_sni_dns_name`.
///
/// # Errors
/// `Error::ParseFailure` on bad lengths, an unsupported name type, non-UTF-8
/// bytes, or an invalid DNS name.
fn parse_server_name_extension(input: &[u8]) -> Result<String> {
    // list length(2) || name_type(1) || host_name length(2) || host_name
    let [list_hi, list_lo, name_type, name_hi, name_lo, host_name @ ..] = input else {
        return Err(Error::ParseFailure("server_name extension too short"));
    };
    let list_len = usize::from(u16::from_be_bytes([*list_hi, *list_lo]));
    if list_len == 0 || input.len() != list_len + 2 {
        return Err(Error::ParseFailure("invalid server_name extension length"));
    }
    if *name_type != 0x00 {
        return Err(Error::ParseFailure("unsupported server_name type"));
    }
    let name_len = usize::from(u16::from_be_bytes([*name_hi, *name_lo]));
    if name_len == 0 || host_name.len() != name_len {
        return Err(Error::ParseFailure("invalid server_name host_name length"));
    }
    let decoded = core::str::from_utf8(host_name)
        .map_err(|_| Error::ParseFailure("invalid sni server_name"))?;
    if !is_valid_sni_dns_name(decoded) {
        return Err(Error::ParseFailure("invalid sni server_name"));
    }
    Ok(decoded.to_owned())
}
fn encode_server_name_extension_data(server_name: &str) -> Result<Vec<u8>> {
if !is_valid_sni_dns_name(server_name) {
return Err(Error::ParseFailure("invalid sni server_name"));
}
let name_bytes = server_name.as_bytes();
let mut entry = Vec::new();
entry.push(0x00); entry.extend_from_slice(&(name_bytes.len() as u16).to_be_bytes());
entry.extend_from_slice(name_bytes);
let mut out = Vec::new();
out.extend_from_slice(&(entry.len() as u16).to_be_bytes());
out.extend_from_slice(&entry);
Ok(out)
}
/// Builds the client `status_request` body asking for an OCSP staple.
///
/// Layout: status_type `ocsp` (1), then an empty `responder_id_list` (u16 length
/// 0) and empty `request_extensions` (u16 length 0).
fn encode_status_request_ocsp_extension_data() -> Result<Vec<u8>> {
    const STATUS_TYPE_OCSP: u8 = 0x01;
    let encoded: [u8; 5] = [
        STATUS_TYPE_OCSP,
        0x00, // responder_id_list length (hi)
        0x00, // responder_id_list length (lo)
        0x00, // request_extensions length (hi)
        0x00, // request_extensions length (lo)
    ];
    Ok(encoded.to_vec())
}
/// Parses a ClientHello `status_request` body and returns `Ok(true)` when it is
/// a plain OCSP request.
///
/// Only status_type `ocsp` (1) with empty responder-id and request-extension
/// vectors is accepted, i.e. exactly five bytes.
///
/// # Errors
/// `Error::ParseFailure` for any other length, status type, or non-empty vector.
fn parse_status_request_ocsp_extension(input: &[u8]) -> Result<bool> {
    let [status_type, rid_hi, rid_lo, ext_hi, ext_lo] = input else {
        return Err(Error::ParseFailure(
            "invalid status_request extension length",
        ));
    };
    if *status_type != 0x01 {
        return Err(Error::ParseFailure(
            "status_request extension must use ocsp status type",
        ));
    }
    let responder_ids_empty = u16::from_be_bytes([*rid_hi, *rid_lo]) == 0;
    let request_extensions_empty = u16::from_be_bytes([*ext_hi, *ext_lo]) == 0;
    if !(responder_ids_empty && request_extensions_empty) {
        return Err(Error::ParseFailure(
            "status_request extension non-empty responder/request vectors are unsupported",
        ));
    }
    Ok(true)
}
/// Parses an ALPN `ProtocolNameList` (RFC 7301 §3.1).
///
/// The list is a u16 length followed by u8-length-prefixed protocol names.
/// Names must be non-empty and unique, and the list itself must be non-empty.
///
/// # Errors
/// `Error::ParseFailure` on length mismatches, empty or truncated names, or
/// duplicate names.
fn parse_alpn_protocol_name_list(input: &[u8]) -> Result<Vec<Vec<u8>>> {
    if input.len() < 2 {
        return Err(Error::ParseFailure(
            "alpn extension missing protocol_name_list",
        ));
    }
    let (length_prefix, mut remaining) = input.split_at(2);
    let declared_len = usize::from(u16::from_be_bytes([length_prefix[0], length_prefix[1]]));
    if declared_len == 0 || remaining.len() != declared_len {
        return Err(Error::ParseFailure("invalid alpn extension length"));
    }
    let mut protocols: Vec<Vec<u8>> = Vec::new();
    while let Some((&name_len, rest)) = remaining.split_first() {
        let name_len = usize::from(name_len);
        if name_len == 0 {
            return Err(Error::ParseFailure("alpn protocol must not be empty"));
        }
        if rest.len() < name_len {
            return Err(Error::ParseFailure("alpn protocol truncated"));
        }
        let (name, tail) = rest.split_at(name_len);
        if protocols.iter().any(|seen| seen.as_slice() == name) {
            return Err(Error::ParseFailure("duplicate alpn protocol"));
        }
        protocols.push(name.to_vec());
        remaining = tail;
    }
    Ok(protocols)
}
/// Encodes an ALPN extension body (RFC 7301 §3.1).
///
/// The body is a u16-length-prefixed `ProtocolNameList` of u8-length-prefixed
/// protocol names.
///
/// # Errors
/// - `Error::InvalidLength` if `protocols` is empty, a name is empty or longer
///   than 255 bytes, or the encoded body would not fit a 16-bit extension length.
/// - `Error::ParseFailure` if a protocol name appears more than once.
fn encode_alpn_extension_data(protocols: &[Vec<u8>]) -> Result<Vec<u8>> {
    if protocols.is_empty() {
        return Err(Error::InvalidLength(
            "alpn extension must include at least one protocol",
        ));
    }
    let mut protocol_name_list = Vec::new();
    for (index, protocol) in protocols.iter().enumerate() {
        if protocol.is_empty() {
            return Err(Error::InvalidLength("alpn protocol must not be empty"));
        }
        if protocol.len() > u8::MAX as usize {
            return Err(Error::InvalidLength(
                "alpn protocol length must not exceed 255 bytes",
            ));
        }
        // Compare against earlier entries in place; no per-protocol clones needed.
        if protocols[..index].contains(protocol) {
            return Err(Error::ParseFailure("duplicate alpn protocol"));
        }
        protocol_name_list.push(protocol.len() as u8);
        protocol_name_list.extend_from_slice(protocol);
    }
    // The body (u16 list length + list) is itself wrapped with a u16 extension
    // length, so it must fit in u16::MAX bytes. Otherwise the length casts
    // would silently truncate.
    if protocol_name_list.len() + 2 > u16::MAX as usize {
        return Err(Error::InvalidLength(
            "alpn protocol_name_list exceeds maximum extension length",
        ));
    }
    let mut extension_data = Vec::with_capacity(2 + protocol_name_list.len());
    extension_data.extend_from_slice(&(protocol_name_list.len() as u16).to_be_bytes());
    extension_data.extend_from_slice(&protocol_name_list);
    Ok(extension_data)
}
/// Parses a u16-byte-length-prefixed vector of big-endian u16 values.
///
/// Duplicate values are rejected. An empty vector is returned as-is; callers
/// decide whether emptiness is an error.
///
/// # Errors
/// `Error::ParseFailure` if the prefix is missing, the length is odd or does
/// not match the remaining bytes, or a value repeats.
fn parse_u16_vector_with_len(input: &[u8]) -> Result<Vec<u16>> {
    if input.len() < 2 {
        return Err(Error::ParseFailure("u16 vector missing length prefix"));
    }
    let (prefix, entries) = input.split_at(2);
    let declared = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
    if entries.len() != declared || !declared.is_multiple_of(2) {
        return Err(Error::ParseFailure("invalid u16 vector length"));
    }
    let mut values = Vec::with_capacity(declared / 2);
    for pair in entries.chunks_exact(2) {
        let value = u16::from_be_bytes([pair[0], pair[1]]);
        if values.contains(&value) {
            return Err(Error::ParseFailure("duplicate u16 vector entry"));
        }
        values.push(value);
    }
    Ok(values)
}
/// Parses a u8-length-prefixed, non-empty vector of unique bytes.
///
/// # Errors
/// `Error::ParseFailure` if the prefix is missing, the length does not match,
/// the vector is empty, or a byte repeats.
fn parse_u8_vector_with_len(input: &[u8]) -> Result<Vec<u8>> {
    let Some((&declared, values)) = input.split_first() else {
        return Err(Error::ParseFailure("u8 vector missing length prefix"));
    };
    if values.len() != usize::from(declared) {
        return Err(Error::ParseFailure("invalid u8 vector length"));
    }
    if declared == 0 {
        return Err(Error::ParseFailure("u8 vector must not be empty"));
    }
    let mut unique = Vec::with_capacity(values.len());
    for &value in values {
        if unique.contains(&value) {
            return Err(Error::ParseFailure("duplicate u8 vector entry"));
        }
        unique.push(value);
    }
    Ok(unique)
}
/// Validates a TLS 1.3 `CertificateRequest` body.
///
/// Layout: u8-length `certificate_request_context`, then a u16-length extensions
/// block that must end exactly at the end of the body. The extensions are
/// checked by `parse_certificate_request_extensions`.
///
/// # Errors
/// `Error::ParseFailure` on truncation, length mismatch, or invalid extensions.
fn parse_certificate_request_body(body: &[u8]) -> Result<()> {
    if body.len() < 3 {
        return Err(Error::ParseFailure("certificate request body too short"));
    }
    let ext_len_offset = usize::from(body[0]) + 1;
    let Some(len_bytes) = body.get(ext_len_offset..ext_len_offset + 2) else {
        return Err(Error::ParseFailure("certificate request context truncated"));
    };
    let declared_ext_len = usize::from(u16::from_be_bytes([len_bytes[0], len_bytes[1]]));
    let extensions = &body[ext_len_offset + 2..];
    if extensions.len() != declared_ext_len {
        return Err(Error::ParseFailure(
            "certificate request extensions truncated",
        ));
    }
    parse_certificate_request_extensions(extensions)
}
/// Validates the extensions of a TLS 1.3 `CertificateRequest`.
///
/// Each type may appear once. `supported_versions`, `key_share`,
/// `pre_shared_key`, `psk_key_exchange_modes` and `server_name` are rejected.
/// A non-empty `signature_algorithms` extension is mandatory.
///
/// # Errors
/// `Error::ParseFailure` on framing errors, duplicates, forbidden types, or a
/// missing or empty `signature_algorithms`.
fn parse_certificate_request_extensions(input: &[u8]) -> Result<()> {
    let mut remaining = input;
    let mut seen_types: Vec<u16> = Vec::new();
    let mut has_signature_algorithms = false;
    while !remaining.is_empty() {
        if remaining.len() < 4 {
            return Err(Error::ParseFailure(
                "certificate request extension header truncated",
            ));
        }
        let (header, rest) = remaining.split_at(4);
        let ext_type = u16::from_be_bytes([header[0], header[1]]);
        let ext_len = usize::from(u16::from_be_bytes([header[2], header[3]]));
        if seen_types.contains(&ext_type) {
            return Err(Error::ParseFailure(
                "duplicate certificate request extension type",
            ));
        }
        let forbidden = matches!(
            ext_type,
            EXT_SUPPORTED_VERSIONS
                | EXT_KEY_SHARE
                | EXT_PRE_SHARED_KEY
                | EXT_PSK_KEY_EXCHANGE_MODES
                | EXT_SERVER_NAME
        );
        if forbidden {
            return Err(Error::ParseFailure(
                "certificate request contains forbidden extension type",
            ));
        }
        seen_types.push(ext_type);
        if rest.len() < ext_len {
            return Err(Error::ParseFailure(
                "certificate request extension truncated",
            ));
        }
        let (ext_data, tail) = rest.split_at(ext_len);
        if ext_type == EXT_SIGNATURE_ALGORITHMS {
            if parse_u16_vector_with_len(ext_data)?.is_empty() {
                return Err(Error::ParseFailure(
                    "certificate request signature_algorithms must not be empty",
                ));
            }
            has_signature_algorithms = true;
        }
        remaining = tail;
    }
    if !has_signature_algorithms {
        return Err(Error::ParseFailure(
            "certificate request missing signature_algorithms extension",
        ));
    }
    Ok(())
}
fn parse_encrypted_extensions_body(body: &[u8]) -> Result<ParsedEncryptedExtensions> {
if body.len() < 2 {
return Err(Error::ParseFailure("encrypted extensions body too short"));
}
let extensions_len = u16::from_be_bytes([body[0], body[1]]) as usize;
if body.len() != 2 + extensions_len {
return Err(Error::ParseFailure("encrypted extensions malformed length"));
}
let mut cursor = &body[2..];
let mut seen_extension_types = Vec::new();
let mut selected_alpn_protocol = None;
let mut server_name_acknowledged = false;
let mut early_data_accepted = false;
while !cursor.is_empty() {
if cursor.len() < 4 {
return Err(Error::ParseFailure(
"encrypted extensions entry header truncated",
));
}
let ext_type = u16::from_be_bytes([cursor[0], cursor[1]]);
let ext_len = u16::from_be_bytes([cursor[2], cursor[3]]) as usize;
if ext_len > TLS13_MAX_EXTENSION_VALUE_BYTES {
return Err(Error::ParseFailure(
"encrypted extensions extension value exceeds modeled maximum",
));
}
if seen_extension_types.contains(&ext_type) {
return Err(Error::ParseFailure("duplicate encrypted extensions type"));
}
seen_extension_types.push(ext_type);
cursor = &cursor[4..];
if cursor.len() < ext_len {
return Err(Error::ParseFailure("encrypted extensions entry truncated"));
}
let ext_data = &cursor[..ext_len];
match ext_type {
EXT_SERVER_NAME => {
if !ext_data.is_empty() {
return Err(Error::ParseFailure(
"encrypted extensions server_name must be empty",
));
}
server_name_acknowledged = true;
}
EXT_ALPN => {
let protocols = parse_alpn_protocol_name_list(ext_data)?;
if protocols.len() != 1 {
return Err(Error::ParseFailure(
"encrypted extensions alpn must select exactly one protocol",
));
}
selected_alpn_protocol = protocols.first().cloned();
}
EXT_EARLY_DATA => {
if !ext_data.is_empty() {
return Err(Error::ParseFailure(
"encrypted extensions early_data must be empty",
));
}
early_data_accepted = true;
}
EXT_SUPPORTED_VERSIONS
| EXT_KEY_SHARE
| EXT_PRE_SHARED_KEY
| EXT_PSK_KEY_EXCHANGE_MODES => {
return Err(Error::ParseFailure(
"encrypted extensions contains forbidden extension type",
));
}
_ => {}
}
cursor = &cursor[ext_len..];
}
Ok(ParsedEncryptedExtensions {
selected_alpn_protocol,
server_name_acknowledged,
early_data_accepted,
})
}
fn parse_certificate_body(body: &[u8]) -> Result<ParsedTls13CertificateBody> {
if body.len() < 4 {
return Err(Error::ParseFailure("certificate body too short"));
}
let context_len = body[0] as usize;
let list_len_offset = 1 + context_len;
if body.len() < list_len_offset + 3 {
return Err(Error::ParseFailure("certificate list length missing"));
}
let cert_list_len = u32::from_be_bytes([
0x00,
body[list_len_offset],
body[list_len_offset + 1],
body[list_len_offset + 2],
]) as usize;
let cert_list_start = list_len_offset + 3;
let cert_list_end = cert_list_start + cert_list_len;
if cert_list_end > body.len() {
return Err(Error::ParseFailure("certificate list truncated"));
}
let mut certificates = Vec::new();
let mut cursor = &body[cert_list_start..cert_list_end];
let mut leaf_ocsp_staple = None;
while !cursor.is_empty() {
if cursor.len() < 5 {
return Err(Error::ParseFailure("certificate entry truncated"));
}
let cert_len = u32::from_be_bytes([0x00, cursor[0], cursor[1], cursor[2]]) as usize;
let cert_end = 3 + cert_len;
if cursor.len() < cert_end + 2 {
return Err(Error::ParseFailure("certificate bytes truncated"));
}
certificates.push(cursor[3..cert_end].to_vec());
let ext_len = u16::from_be_bytes([cursor[cert_end], cursor[cert_end + 1]]) as usize;
let ext_end = cert_end + 2 + ext_len;
if cursor.len() < ext_end {
return Err(Error::ParseFailure(
"certificate entry extensions truncated",
));
}
let parsed_staple = parse_certificate_entry_extensions(&cursor[cert_end + 2..ext_end])?;
if certificates.len() == 1 {
leaf_ocsp_staple = parsed_staple;
}
cursor = &cursor[ext_end..];
}
if certificates.is_empty() {
return Err(Error::ParseFailure("certificate list must not be empty"));
}
if cert_list_end != body.len() {
return Err(Error::ParseFailure("certificate body trailing bytes"));
}
Ok(ParsedTls13CertificateBody {
certificates,
leaf_ocsp_staple,
})
}
/// Parses a TLS 1.2 `Certificate` body into its certificates, in order.
///
/// The body is a u24 list length covering the rest of the body, followed by
/// u24-length-prefixed, non-empty certificates.
///
/// # Errors
/// `Error::ParseFailure` on a malformed or zero list length, a truncated or
/// empty entry, or an empty list.
fn parse_tls12_certificate_list(body: &[u8]) -> Result<Vec<Vec<u8>>> {
    if body.len() < 3 {
        return Err(Error::ParseFailure(
            "tls12 certificate message is malformed",
        ));
    }
    let read_u24 = |bytes: &[u8]| -> usize {
        (usize::from(bytes[0]) << 16) | (usize::from(bytes[1]) << 8) | usize::from(bytes[2])
    };
    let (list_header, mut remaining) = body.split_at(3);
    let list_len = read_u24(list_header);
    if list_len == 0 || list_len != remaining.len() {
        return Err(Error::ParseFailure(
            "tls12 certificate list length is malformed",
        ));
    }
    let mut certificates = Vec::new();
    while !remaining.is_empty() {
        if remaining.len() < 3 {
            return Err(Error::ParseFailure(
                "tls12 certificate entry length is truncated",
            ));
        }
        let (entry_header, rest) = remaining.split_at(3);
        let cert_len = read_u24(entry_header);
        if cert_len == 0 {
            return Err(Error::ParseFailure(
                "tls12 certificate entry must not be empty",
            ));
        }
        if rest.len() < cert_len {
            return Err(Error::ParseFailure("tls12 certificate entry is truncated"));
        }
        let (der, tail) = rest.split_at(cert_len);
        certificates.push(der.to_vec());
        remaining = tail;
    }
    if certificates.is_empty() {
        return Err(Error::ParseFailure(
            "tls12 certificate list must not be empty",
        ));
    }
    Ok(certificates)
}
/// Structurally validates a TLS 1.2 ECDHE `ServerKeyExchange` body.
///
/// Layout:
/// - curve_type byte, which must be 0x03;
/// - two bytes following the curve type, not inspected here;
/// - u8-length public key, which must be non-empty;
/// - u16 signature scheme, which must be one accepted by
///   `tls12_signature_scheme_is_modern`;
/// - u16-length signature, which must be non-empty and end the body exactly.
///
/// The signature itself is not verified here.
///
/// # Errors
/// `Error::ParseFailure` on any structural violation.
fn parse_tls12_server_key_exchange_body(body: &[u8]) -> Result<()> {
    if body.len() < 8 {
        return Err(Error::ParseFailure(
            "tls12 server key exchange body must include key share and signature fields",
        ));
    }
    if body[0] != 0x03 {
        return Err(Error::ParseFailure(
            "tls12 server key exchange requires named_curve parameters",
        ));
    }
    let public_len = usize::from(body[3]);
    if public_len == 0 {
        return Err(Error::ParseFailure(
            "tls12 server key exchange public key must not be empty",
        ));
    }
    let params_end = 4 + public_len;
    let Some(sig_header) = body.get(params_end..params_end + 4) else {
        return Err(Error::ParseFailure(
            "tls12 server key exchange signature header is truncated",
        ));
    };
    let signature_scheme = u16::from_be_bytes([sig_header[0], sig_header[1]]);
    if !tls12_signature_scheme_is_modern(signature_scheme) {
        return Err(Error::ParseFailure(
            "tls12 server key exchange uses unsupported signature scheme",
        ));
    }
    let signature_len = usize::from(u16::from_be_bytes([sig_header[2], sig_header[3]]));
    if signature_len == 0 {
        return Err(Error::ParseFailure(
            "tls12 server key exchange signature must not be empty",
        ));
    }
    let signature = &body[params_end + 4..];
    if signature.len() != signature_len {
        return Err(Error::ParseFailure(
            "tls12 server key exchange signature length is malformed",
        ));
    }
    Ok(())
}
/// Structurally validates a TLS 1.2 client `CertificateVerify` body.
///
/// Layout: u16 signature scheme (must pass `tls12_signature_scheme_is_modern`)
/// followed by a non-empty u16-length signature that ends the body exactly.
///
/// # Errors
/// `Error::ParseFailure` on any structural violation.
fn parse_tls12_certificate_verify_body(body: &[u8]) -> Result<()> {
    let [scheme_hi, scheme_lo, len_hi, len_lo, signature @ ..] = body else {
        return Err(Error::ParseFailure(
            "tls12 client certificate verify body must include signature scheme and length",
        ));
    };
    let signature_scheme = u16::from_be_bytes([*scheme_hi, *scheme_lo]);
    if !tls12_signature_scheme_is_modern(signature_scheme) {
        return Err(Error::ParseFailure(
            "tls12 client certificate verify uses unsupported signature scheme",
        ));
    }
    let signature_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
    if signature_len == 0 {
        return Err(Error::ParseFailure(
            "tls12 client certificate verify signature must not be empty",
        ));
    }
    if signature.len() != signature_len {
        return Err(Error::ParseFailure(
            "tls12 client certificate verify signature length is malformed",
        ));
    }
    Ok(())
}
fn tls12_signature_scheme_is_modern(signature_scheme: u16) -> bool {
matches!(
signature_scheme,
TLS13_SIGALG_ECDSA_SECP256R1_SHA256
| TLS13_SIGALG_RSA_PSS_RSAE_SHA256
| TLS13_SIGALG_RSA_PSS_RSAE_SHA384
| TLS13_SIGALG_ED25519
| TLS13_SIGALG_MLDSA65
)
}
/// Parses the extension list attached to a TLS 1.3 CertificateEntry.
///
/// Each extension is a 2-byte type, a 2-byte length, and that many data bytes. Repeated
/// extension types are rejected. Only `EXT_STATUS_REQUEST` is interpreted; other types are
/// skipped after their framing is validated.
///
/// # Returns
/// The OCSP response carried by a `status_request` extension, or `None` if it is absent.
///
/// # Errors
/// Returns `Error::ParseFailure` on a truncated header or body, a duplicate extension type,
/// or a malformed `status_request` payload.
fn parse_certificate_entry_extensions(input: &[u8]) -> Result<Option<Vec<u8>>> {
    let mut cursor = input;
    let mut seen_extension_types: Vec<u16> = Vec::new();
    let mut status_request_ocsp = None;
    while !cursor.is_empty() {
        if cursor.len() < 4 {
            return Err(Error::ParseFailure(
                "certificate entry extension header truncated",
            ));
        }
        let ext_type = u16::from_be_bytes([cursor[0], cursor[1]]);
        let ext_len = usize::from(u16::from_be_bytes([cursor[2], cursor[3]]));
        // This also guarantees at most one status_request reaches the parse below, so no
        // separate status_request duplicate check is needed.
        if seen_extension_types.contains(&ext_type) {
            return Err(Error::ParseFailure(
                "duplicate certificate entry extension type",
            ));
        }
        seen_extension_types.push(ext_type);
        let after_header = &cursor[4..];
        if after_header.len() < ext_len {
            return Err(Error::ParseFailure("certificate entry extension truncated"));
        }
        let (ext_data, rest) = after_header.split_at(ext_len);
        if ext_type == EXT_STATUS_REQUEST {
            status_request_ocsp = Some(parse_certificate_entry_status_request_extension(ext_data)?);
        }
        cursor = rest;
    }
    Ok(status_request_ocsp)
}
/// Encodes a CertificateEntry `status_request` extension carrying a single OCSP staple.
///
/// Output layout: extension type (u16) || extension data length (u16) ||
/// status type `0x01` (ocsp) || OCSP response length (u24) || OCSP response bytes.
///
/// # Errors
/// Returns `Error::InvalidLength` when the staple is empty, or when the encoded payload
/// (4 header bytes plus the staple) would not fit the 16-bit extension length field.
fn encode_certificate_entry_status_request_extension(ocsp_staple: &[u8]) -> Result<Vec<u8>> {
    // status type (1 byte) + u24 OCSP response length (3 bytes).
    const PAYLOAD_HEADER_LEN: usize = 4;
    if ocsp_staple.is_empty() {
        return Err(Error::InvalidLength("ocsp staple must not be empty"));
    }
    // The extension data is framed by a u16 length. Only checking the u24 inner length let
    // staples above 65531 bytes through, and the length cast then silently truncated,
    // producing a malformed extension. This bound also implies the staple fits in u24.
    let payload_len = PAYLOAD_HEADER_LEN + ocsp_staple.len();
    if payload_len > usize::from(u16::MAX) {
        return Err(Error::InvalidLength("ocsp staple is too large"));
    }
    let staple_len = ocsp_staple.len() as u32;
    let mut extension = Vec::with_capacity(4 + payload_len);
    extension.extend_from_slice(&EXT_STATUS_REQUEST.to_be_bytes());
    extension.extend_from_slice(&(payload_len as u16).to_be_bytes());
    extension.push(0x01);
    extension.extend_from_slice(&staple_len.to_be_bytes()[1..4]);
    extension.extend_from_slice(ocsp_staple);
    Ok(extension)
}
/// Parses the data of a CertificateEntry `status_request` extension and returns the OCSP
/// response it carries.
///
/// The data must be status type `0x01`, then a 24-bit big-endian length, then exactly that
/// many non-zero-length response bytes.
///
/// # Errors
/// Returns `Error::ParseFailure` when the header is short, the status type is not OCSP, the
/// response is empty, or the declared length does not match the remaining bytes.
fn parse_certificate_entry_status_request_extension(input: &[u8]) -> Result<Vec<u8>> {
    let [status_type, len_b0, len_b1, len_b2, ocsp_response @ ..] = input else {
        return Err(Error::ParseFailure(
            "certificate entry status_request extension is truncated",
        ));
    };
    if *status_type != 0x01 {
        return Err(Error::ParseFailure(
            "certificate entry status_request must use ocsp status type",
        ));
    }
    // Widen the u24 length into a u32 with a zero high byte.
    let declared_len = u32::from_be_bytes([0, *len_b0, *len_b1, *len_b2]) as usize;
    if declared_len == 0 {
        return Err(Error::ParseFailure(
            "certificate entry status_request ocsp response must not be empty",
        ));
    }
    if ocsp_response.len() != declared_len {
        return Err(Error::ParseFailure(
            "certificate entry status_request ocsp response is truncated",
        ));
    }
    Ok(ocsp_response.to_vec())
}
/// Splits a CertificateVerify body into its signature scheme and signature bytes.
///
/// The body is a 2-byte scheme, a 2-byte length, and exactly that many signature bytes. The
/// returned slice borrows from `body`.
///
/// # Errors
/// Returns `Error::ParseFailure` when the header is short or the declared signature length
/// does not match the remaining bytes.
fn parse_certificate_verify_fields(body: &[u8]) -> Result<(u16, &[u8])> {
    let [scheme_hi, scheme_lo, len_hi, len_lo, signature @ ..] = body else {
        return Err(Error::ParseFailure("certificate verify body too short"));
    };
    let declared_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
    if signature.len() != declared_len {
        return Err(Error::ParseFailure(
            "certificate verify signature truncated",
        ));
    }
    Ok((u16::from_be_bytes([*scheme_hi, *scheme_lo]), signature))
}
fn tls13_supported_certificate_verify_signature_scheme(signature_scheme: u16) -> bool {
matches!(
signature_scheme,
TLS13_SIGALG_ECDSA_SECP256R1_SHA256
| TLS13_SIGALG_RSA_PSS_RSAE_SHA256
| TLS13_SIGALG_RSA_PSS_RSAE_SHA384
| TLS13_SIGALG_ED25519
| TLS13_SIGALG_MLDSA65
)
}
fn parse_new_session_ticket_body(body: &[u8]) -> Result<()> {
if body.len() < 11 {
return Err(Error::ParseFailure("new session ticket body too short"));
}
let nonce_len = body[8] as usize;
let ticket_len_offset = 9 + nonce_len;
if body.len() < ticket_len_offset + 2 {
return Err(Error::ParseFailure("new session ticket nonce truncated"));
}
let ticket_len =
u16::from_be_bytes([body[ticket_len_offset], body[ticket_len_offset + 1]]) as usize;
let ext_len_offset = ticket_len_offset + 2 + ticket_len;
if body.len() < ext_len_offset + 2 {
return Err(Error::ParseFailure("new session ticket bytes truncated"));
}
let ext_len = u16::from_be_bytes([body[ext_len_offset], body[ext_len_offset + 1]]) as usize;
if body.len() != ext_len_offset + 2 + ext_len {
return Err(Error::ParseFailure(
"new session ticket extensions truncated",
));
}
Ok(())
}
/// Builds the byte string a TLS 1.3 server signs in its CertificateVerify message.
///
/// The layout is 64 bytes of `0x20`, the server context string, one `0x00` separator byte,
/// and then `transcript_hash` verbatim.
fn build_tls13_server_certificate_verify_message(transcript_hash: &[u8]) -> Vec<u8> {
    const PADDING: [u8; 64] = [0x20; 64];
    const CONTEXT: &[u8] = b"TLS 1.3, server CertificateVerify";
    const SEPARATOR: u8 = 0x00;
    let total_len = PADDING.len() + CONTEXT.len() + 1 + transcript_hash.len();
    let mut message = Vec::with_capacity(total_len);
    message.extend_from_slice(&PADDING);
    message.extend_from_slice(CONTEXT);
    message.push(SEPARATOR);
    message.extend_from_slice(transcript_hash);
    message
}
/// Decodes a DER `RSAPublicKey` (SEQUENCE of modulus INTEGER and exponent INTEGER) into an
/// `RsaPublicKey`.
///
/// No bytes may follow the outer SEQUENCE or the exponent inside it.
///
/// # Errors
/// Returns `Error::ParseFailure` for DER framing or tag problems, and
/// `Error::CryptoFailure` if `RsaPublicKey::from_be_bytes` rejects the integers.
fn parse_rsa_public_key_der(public_key_der: &[u8]) -> Result<RsaPublicKey> {
    let (sequence, after_sequence) = parse_der_node(public_key_der)
        .map_err(|_| Error::ParseFailure("failed to parse server RSA public key"))?;
    let is_lone_sequence = sequence.tag == 0x30 && after_sequence.is_empty();
    if !is_lone_sequence {
        return Err(Error::ParseFailure(
            "invalid server RSA public key sequence",
        ));
    }
    let (modulus, after_modulus) = parse_der_node(sequence.body)
        .map_err(|_| Error::ParseFailure("failed to parse server RSA modulus"))?;
    let (exponent, after_exponent) = parse_der_node(after_modulus)
        .map_err(|_| Error::ParseFailure("failed to parse server RSA exponent"))?;
    // Both members must be INTEGERs (tag 0x02) and nothing may follow the exponent.
    let integers_ok = modulus.tag == 0x02 && exponent.tag == 0x02 && after_exponent.is_empty();
    if !integers_ok {
        return Err(Error::ParseFailure(
            "invalid server RSA public key integer fields",
        ));
    }
    RsaPublicKey::from_be_bytes(modulus.body, exponent.body)
        .map_err(|_| Error::CryptoFailure("failed to construct server RSA public key"))
}
/// Parses a `key_share` extension's client_shares list and returns the offered groups in
/// wire order.
///
/// The input is a 2-byte list length covering the rest of the input exactly. Each entry is
/// a 2-byte group, a 2-byte key_exchange length, and that many non-empty key bytes. Groups
/// must be unique.
///
/// # Errors
/// Returns `Error::ParseFailure` on bad list framing, a truncated entry, a duplicate group,
/// or an empty or truncated key_exchange.
fn parse_key_share_groups_extension(input: &[u8]) -> Result<Vec<u16>> {
    let [list_len_hi, list_len_lo, entries @ ..] = input else {
        return Err(Error::ParseFailure(
            "key_share extension missing list length",
        ));
    };
    if entries.len() != usize::from(u16::from_be_bytes([*list_len_hi, *list_len_lo])) {
        return Err(Error::ParseFailure("invalid key_share extension length"));
    }
    let mut groups: Vec<u16> = Vec::new();
    let mut remaining = entries;
    while !remaining.is_empty() {
        let [group_hi, group_lo, key_len_hi, key_len_lo, rest @ ..] = remaining else {
            return Err(Error::ParseFailure("key_share entry truncated"));
        };
        let group = u16::from_be_bytes([*group_hi, *group_lo]);
        let key_len = usize::from(u16::from_be_bytes([*key_len_hi, *key_len_lo]));
        if groups.contains(&group) {
            return Err(Error::ParseFailure("duplicate key_share group"));
        }
        if key_len == 0 {
            return Err(Error::ParseFailure(
                "key_share key_exchange must not be empty",
            ));
        }
        if rest.len() < key_len {
            return Err(Error::ParseFailure("key_share key_exchange truncated"));
        }
        groups.push(group);
        remaining = &rest[key_len..];
    }
    Ok(groups)
}
fn encode_pre_shared_key_extension(offer: &PskClientOffer<'_>) -> Result<Vec<u8>> {
if offer.identities.is_empty() || offer.binders.is_empty() {
return Err(Error::InvalidLength(
"psk identity/binder list must not be empty",
));
}
if offer.identities.len() != offer.binders.len() {
return Err(Error::InvalidLength(
"psk identity and binder list lengths must match",
));
}
let mut identities = Vec::new();
let mut binders = Vec::new();
for (identity, binder) in offer.identities.iter().zip(offer.binders.iter()) {
if identity.identity.is_empty() || binder.is_empty() {
return Err(Error::InvalidLength(
"psk identity and binder must not be empty",
));
}
if identity.identity.len() > u16::MAX as usize || binder.len() > u8::MAX as usize {
return Err(Error::InvalidLength("psk identity or binder too long"));
}
identities.extend_from_slice(&(identity.identity.len() as u16).to_be_bytes());
identities.extend_from_slice(identity.identity);
identities.extend_from_slice(&identity.obfuscated_ticket_age.to_be_bytes());
binders.push(binder.len() as u8);
binders.extend_from_slice(binder);
}
let mut out = Vec::new();
out.extend_from_slice(&(identities.len() as u16).to_be_bytes());
out.extend_from_slice(&identities);
out.extend_from_slice(&(binders.len() as u16).to_be_bytes());
out.extend_from_slice(&binders);
Ok(out)
}
/// Parses a ClientHello `pre_shared_key` extension body.
///
/// The body is a u16-length-prefixed identities list followed by a u16-length-prefixed
/// binders list that must end exactly at the end of `input`.
/// - Each identity is a u16-length-prefixed, non-empty, unique byte string followed by a
///   u32 obfuscated ticket age.
/// - Each binder is a u8-length-prefixed, non-empty byte string.
///
/// # Returns
/// `(identity_count, identities, obfuscated_ages, binders)`, with all three vectors in wire
/// order.
///
/// # Errors
/// Returns `Error::ParseFailure` on truncation, empty or duplicate identities, empty
/// binders, a binder-vector length mismatch, differing identity and binder counts, or an
/// empty offer.
fn parse_pre_shared_key_extension(
    input: &[u8],
) -> Result<(usize, Vec<Vec<u8>>, Vec<u32>, Vec<Vec<u8>>)> {
    if input.len() < 4 {
        return Err(Error::ParseFailure("pre_shared_key extension too short"));
    }
    let identities_end = 2 + u16::from_be_bytes([input[0], input[1]]) as usize;
    // The binders length prefix must also be present after the identities list.
    if input.len() < identities_end + 2 {
        return Err(Error::ParseFailure("pre_shared_key identities truncated"));
    }

    let mut identities: Vec<Vec<u8>> = Vec::new();
    let mut obfuscated_ages = Vec::new();
    let mut pending = &input[2..identities_end];
    while !pending.is_empty() {
        if pending.len() < 6 {
            return Err(Error::ParseFailure(
                "pre_shared_key identity entry truncated",
            ));
        }
        let id_len = u16::from_be_bytes([pending[0], pending[1]]) as usize;
        if id_len == 0 {
            return Err(Error::ParseFailure(
                "pre_shared_key identity must not be empty",
            ));
        }
        let entry_len = 2 + id_len + 4;
        if pending.len() < entry_len {
            return Err(Error::ParseFailure(
                "pre_shared_key identity bytes truncated",
            ));
        }
        let (entry, after_entry) = pending.split_at(entry_len);
        let identity = &entry[2..2 + id_len];
        if identities.iter().any(|seen| seen.as_slice() == identity) {
            return Err(Error::ParseFailure("duplicate pre_shared_key identity"));
        }
        identities.push(identity.to_vec());
        let age = &entry[2 + id_len..];
        obfuscated_ages.push(u32::from_be_bytes([age[0], age[1], age[2], age[3]]));
        pending = after_entry;
    }

    let binders_len =
        u16::from_be_bytes([input[identities_end], input[identities_end + 1]]) as usize;
    let binders_start = identities_end + 2;
    if input.len() != binders_start + binders_len {
        return Err(Error::ParseFailure(
            "invalid pre_shared_key binder vector length",
        ));
    }
    let mut binders = Vec::new();
    // The length check above guarantees the binders list runs to the end of `input`.
    let mut pending = &input[binders_start..];
    while let Some((&len_byte, after_len)) = pending.split_first() {
        let binder_len = len_byte as usize;
        if binder_len == 0 {
            return Err(Error::ParseFailure(
                "pre_shared_key binder must not be empty",
            ));
        }
        if after_len.len() < binder_len {
            return Err(Error::ParseFailure("pre_shared_key binder bytes truncated"));
        }
        let (binder, after_binder) = after_len.split_at(binder_len);
        binders.push(binder.to_vec());
        pending = after_binder;
    }

    let identity_count = identities.len();
    if identity_count != binders.len() {
        return Err(Error::ParseFailure(
            "pre_shared_key identity and binder counts differ",
        ));
    }
    if identity_count == 0 {
        return Err(Error::ParseFailure(
            "pre_shared_key extension must include at least one identity",
        ));
    }
    Ok((identity_count, identities, obfuscated_ages, binders))
}
/// Maps a `TlsVersion` to the two-byte version value written on the wire.
///
/// TLS 1.3 reports `0x0303`, and both DTLS variants report `0xFEFD`.
fn legacy_wire_version(version: TlsVersion) -> [u8; 2] {
    let wire: u16 = match version {
        TlsVersion::Tls10 => 0x0301,
        TlsVersion::Tls11 => 0x0302,
        TlsVersion::Tls12 | TlsVersion::Tls13 => 0x0303,
        TlsVersion::Dtls12 | TlsVersion::Dtls13 => 0xFEFD,
    };
    wire.to_be_bytes()
}