#![allow(dead_code, unreachable_pub)]
use crate::ct::ConstantTimeEq;
use crate::ec::x25519::X25519PrivateKey;
use crate::ec::{BoxedEcdhPrivateKey, BoxedEcdsaPublicKey, CurveId};
use crate::mlkem::{CIPHERTEXT_BYTES, MlKem768Ciphertext, MlKem768DecapsKey};
use crate::rng::RngCore;
use crate::signature_registry::SignaturePolicy;
use crate::tls::codec::extension as ext;
use crate::tls::codec::{
CipherSuite, ClientHello, ExtensionType, NamedGroup, Random, ReadCursor, ServerHello,
SignatureScheme, hs_type,
};
use crate::tls::crypto::{
AeadAlg, HashAlg, KeySchedule, RecordCrypter, Secret, SuiteParams, Transcript,
certificate_verify_content, expand_label_dyn, finished_verify_data, lookup_suite,
supported_suites, verify_signature,
};
use crate::tls::keylog::KeyLog;
use crate::tls::pki::{CrlStore, RootCertStore, verify_chain_with_crls, verify_hostname};
use crate::tls::{ContentType, Error, ProtocolVersion};
use crate::x509::{AnyPublicKey, Certificate, Time};
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::time::Duration;
use super::ack::{ACK_CONTENT_TYPE, RecordNumber, decode as decode_ack, encode as encode_ack};
use super::reassembly::{HandshakeFragment, Reassembler, read_fragment, write_message};
use super::record::{self, ParsedDtlsRecord};
use super::record13::{self, peek_header_layout, reconstruct_seq, sn_mask_for};
use super::reliability13::{InFlightRecord, Retransmit13};
/// `ServerHello.random` value that marks a HelloRetryRequest (the fixed
/// sentinel defined in RFC 8446 §4.1.3).
///
/// `on_server_hello` compares the received random against this value and
/// routes a match to `on_hello_retry_request`.
const HRR_RANDOM: [u8; 32] = [
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
];
/// Extension type code of the TLS 1.3 `cookie` extension (RFC 8446 §4.2.2).
/// The cookie body from a HelloRetryRequest is stored and echoed verbatim in
/// the second ClientHello.
const EXT_COOKIE: u16 = 0x002C;
/// Fragment size limit passed to `write_message` when this client fragments
/// its outgoing handshake messages (ClientHello and Finished).
const DEFAULT_MAX_FRAGMENT: usize = 1100;
/// Default value of `ClientConfig13Internal::max_record_size`.
// NOTE(review): `max_record_size` is not read anywhere in the code visible
// here — outgoing fragmentation always uses DEFAULT_MAX_FRAGMENT. Confirm
// whether the config knob is meant to feed into fragmentation.
const DEFAULT_MAX_RECORD_SIZE: usize = 1200;
/// Configuration for a DTLS 1.3 client connection (`DtlsClientConnection13`).
pub(crate) struct ClientConfig13Internal {
/// Trust anchors passed to chain verification.
pub roots: RootCertStore,
/// Hostname sent as SNI and, when verification is enabled, checked against
/// the leaf certificate.
pub server_name: Option<String>,
/// ALPN protocols offered in the ClientHello; a server selection outside
/// this list is rejected.
pub alpn_protocols: Vec<Vec<u8>>,
/// Policy applied to certificate-chain signatures and to the
/// CertificateVerify signature.
pub signature_policy: Arc<SignaturePolicy>,
/// Maximum record size setting.
// NOTE(review): not read by the connection code visible in this file.
pub max_record_size: usize,
/// When `false`, chain and hostname verification are skipped; the
/// CertificateVerify signature is still checked against the leaf key.
pub verify_certificates: bool,
/// Fixed time used for chain verification; `None` is forwarded as-is to
/// the verifier (presumably meaning "current time" — confirm in pki).
pub verification_time: Option<Time>,
/// Revocation lists consulted during chain verification.
pub crls: CrlStore,
/// Optional sink for handshake, traffic and exporter secrets.
pub key_log: Option<Arc<dyn KeyLog>>,
/// Cipher suites offered in the ClientHello.
pub cipher_suites: Vec<CipherSuite>,
/// Groups advertised in `supported_groups`, in preference order.
pub groups: Vec<NamedGroup>,
/// Optional subset of `groups` that receive a key share in the initial
/// ClientHello; after a HelloRetryRequest only the selected group is used.
pub key_share_groups: Option<Vec<NamedGroup>>,
}
impl ClientConfig13Internal {
pub fn new(roots: RootCertStore, server_name: &str) -> Self {
Self {
roots,
server_name: Some(String::from(server_name)),
alpn_protocols: Vec::new(),
signature_policy: Arc::new(SignaturePolicy::modern()),
max_record_size: DEFAULT_MAX_RECORD_SIZE,
verify_certificates: true,
verification_time: None,
crls: CrlStore::new(),
key_log: None,
cipher_suites: supported_suites().iter().map(|s| s.suite).collect(),
groups: alloc::vec![
NamedGroup::X25519MLKEM768,
NamedGroup::X25519,
NamedGroup::SECP256R1,
NamedGroup::SECP384R1,
],
key_share_groups: None,
}
}
pub fn with_crls(mut self, crls: CrlStore) -> Self {
self.crls = crls;
self
}
pub fn with_verification_time(mut self, t: Time) -> Self {
self.verification_time = Some(t);
self
}
pub fn without_certificate_verification(mut self) -> Self {
self.verify_certificates = false;
self
}
pub fn with_signature_policy(mut self, p: Arc<SignaturePolicy>) -> Self {
self.signature_policy = p;
self
}
}
/// Client handshake progress; `dispatch_one` routes each complete incoming
/// handshake message according to the current state.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
enum State {
/// A ClientHello has been sent (also the state after a HelloRetryRequest,
/// once the second ClientHello is out).
WaitServerHello,
/// ServerHello accepted and epoch-2 (handshake) keys installed.
WaitEncryptedExtensions,
/// Expecting the server's Certificate message.
WaitCertificate,
/// Expecting CertificateVerify over the transcript so far.
WaitCertificateVerify,
/// Expecting the server Finished.
WaitFinished,
/// Server Finished verified, client Finished queued, epoch-3 keys active.
Connected,
/// Entered when the retransmission logic gives up (`on_timeout`).
Closed,
}
/// DTLS 1.3 client connection.
///
/// The caller owns the transport: received datagrams go into
/// `feed_datagram`, outgoing ones are drained with `pop_outbound_datagrams`,
/// and retransmission timers are driven via `next_timeout` / `on_timeout`.
pub struct DtlsClientConnection13 {
/// Configuration supplied at construction.
config: ClientConfig13Internal,
/// Peer address bytes; stored but not read in this file.
#[allow(dead_code)]
peer_addr: Vec<u8>,
/// Current handshake state.
state: State,
/// `message_seq` assigned to the next outgoing handshake message.
out_msg_seq: u16,
/// Reassembles incoming handshake fragments into complete messages.
reassembler: Reassembler,
/// Datagrams queued for the caller to send.
out_dgrams: Vec<Vec<u8>>,
/// Decrypted application data not yet taken by the caller.
app_in: Vec<u8>,
/// Epoch and next sequence number for DTLSPlaintext records.
plain_write_epoch: u16,
plain_write_seq: u64,
/// Epoch and next sequence number for protected records
/// (2 after ServerHello, 3 once connected).
enc_write_epoch: u16,
enc_write_seq: u64,
/// Highest accepted protected-record sequence number in the current read
/// epoch; reference point for sequence-number reconstruction.
enc_read_seq: u64,
/// Replay window for the current read epoch.
read_replay: crate::dtls::replay::AntiReplayWindow,
/// Ephemeral private keys, generated for every supported group up front.
x25519: X25519PrivateKey,
p256: BoxedEcdhPrivateKey,
p384: BoxedEcdhPrivateKey,
mlkem: MlKem768DecapsKey,
/// ClientHello random (also the key-log identifier).
client_random: Random,
/// ServerHello random once received.
server_random: Option<Random>,
/// Cookie extension body from a HelloRetryRequest, echoed in the retry.
cookie_extension: Option<Vec<u8>>,
/// Group selected by a HelloRetryRequest, if any.
hrr_selected_group: Option<NamedGroup>,
/// Whether a HelloRetryRequest was already handled (only one is allowed).
hrr_processed: bool,
/// Running handshake transcript hash.
transcript: Transcript,
/// TLS 1.3 key schedule, created on ServerHello.
ks: Option<KeySchedule>,
/// Handshake traffic secrets (client / server).
client_hs_secret: Option<Secret>,
server_hs_secret: Option<Secret>,
/// Application traffic secrets (client / server); stored, not read here.
client_app_secret: Option<Secret>,
server_app_secret: Option<Secret>,
/// Active record protection for writing / reading.
write_crypter: Option<RecordCrypter>,
read_crypter: Option<RecordCrypter>,
/// Active record-number-encryption keys for writing / reading.
write_sn_key: Option<Vec<u8>>,
read_sn_key: Option<Vec<u8>>,
/// Epoch-3 record-number keys staged until the client Finished is sent.
read_app_sn_key: Option<Vec<u8>>,
write_app_sn_key: Option<Vec<u8>>,
/// Epoch-3 crypters staged until the client Finished is sent.
pending_read_app_crypter: Option<RecordCrypter>,
pending_write_app_crypter: Option<RecordCrypter>,
/// Parameters of the selected cipher suite.
suite: Option<SuiteParams>,
/// Exporter master secret, set once the handshake finishes.
exporter_secret: Option<Secret>,
/// DER certificates from the server's Certificate message.
cert_chain: Vec<Vec<u8>>,
/// Public key that verified the CertificateVerify signature.
leaf_key: Option<AnyPublicKey>,
/// ALPN protocol selected by the server, if any.
alpn_negotiated: Option<Vec<u8>>,
/// Received record numbers awaiting an ACK record.
pending_acks: Vec<RecordNumber>,
/// In-flight records and retransmission timer state.
retransmit: Retransmit13,
/// Most recent time passed to `on_timeout`; timestamps sent records.
last_now: Duration,
}
impl DtlsClientConnection13 {
/// Creates a client connection and immediately queues the first ClientHello.
///
/// Ephemeral keys for X25519, P-256, P-384 and ML-KEM-768 are all generated
/// here, regardless of `config.groups`, followed by the ClientHello random.
/// The ClientHello datagram is placed in the outbound queue and registered
/// for retransmission.
pub(crate) fn new<R: RngCore>(
config: ClientConfig13Internal,
peer_addr: Vec<u8>,
rng: &mut R,
) -> Self {
// Keep this generation order stable: it fixes how RNG output is consumed.
let x25519 = X25519PrivateKey::generate(rng);
let p256 = BoxedEcdhPrivateKey::generate(CurveId::P256, rng);
let p384 = BoxedEcdhPrivateKey::generate(CurveId::P384, rng);
let (mlkem, _) = MlKem768DecapsKey::generate(rng);
let mut client_random: Random = [0u8; 32];
rng.fill_bytes(&mut client_random);
let mut conn = Self {
config,
peer_addr,
state: State::WaitServerHello,
out_msg_seq: 0,
reassembler: Reassembler::new(),
out_dgrams: Vec::new(),
app_in: Vec::new(),
plain_write_epoch: 0,
plain_write_seq: 0,
enc_write_epoch: 0,
enc_write_seq: 0,
enc_read_seq: 0,
read_replay: crate::dtls::replay::AntiReplayWindow::new(),
x25519,
p256,
p384,
mlkem,
client_random,
server_random: None,
cookie_extension: None,
hrr_selected_group: None,
hrr_processed: false,
transcript: Transcript::new(),
ks: None,
client_hs_secret: None,
server_hs_secret: None,
client_app_secret: None,
server_app_secret: None,
write_crypter: None,
read_crypter: None,
write_sn_key: None,
read_sn_key: None,
read_app_sn_key: None,
write_app_sn_key: None,
pending_read_app_crypter: None,
pending_write_app_crypter: None,
suite: None,
exporter_secret: None,
cert_chain: Vec::new(),
leaf_key: None,
alpn_negotiated: None,
pending_acks: Vec::new(),
retransmit: Retransmit13::new(),
last_now: Duration::from_secs(0),
};
// Build (and add to the transcript) the first ClientHello, then queue it
// as a plaintext datagram tracked for retransmission.
let dgram = conn.build_client_hello();
conn.emit_plaintext(dgram);
conn
}
/// Reports whether the handshake has finished (state `Connected`).
pub fn is_handshake_complete(&self) -> bool {
    matches!(self.state, State::Connected)
}
/// Returns the negotiated cipher suite code once the handshake is complete,
/// `None` before that.
pub fn negotiated_cipher_suite(&self) -> Option<u16> {
    if !self.is_handshake_complete() {
        return None;
    }
    self.suite.map(|params| params.suite.0)
}
pub fn tls_exporter(&self, label: &[u8], context: &[u8], out: &mut [u8]) -> Result<(), Error> {
let ems = self
.exporter_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let suite = self.suite.ok_or(Error::InappropriateState)?;
crate::tls::crypto::tls_exporter(suite.hash, ems, label, context, out);
Ok(())
}
/// Drains every queued outbound datagram, first appending an ACK record for
/// any received records still awaiting acknowledgement.
pub fn pop_outbound_datagrams(&mut self) -> Vec<Vec<u8>> {
    self.flush_pending_acks();
    core::mem::replace(&mut self.out_dgrams, Vec::new())
}
/// Removes and returns all application data decrypted so far.
pub fn take_received(&mut self) -> Vec<u8> {
    core::mem::replace(&mut self.app_in, Vec::new())
}
/// DER certificates received from the server (leaf first); empty until the
/// Certificate message has been processed.
pub fn peer_certificates(&self) -> &[Vec<u8>] {
    self.cert_chain.as_slice()
}
/// ALPN protocol chosen by the server, if it selected one.
pub fn alpn_protocol(&self) -> Option<&[u8]> {
    self.alpn_negotiated.as_ref().map(|proto| proto.as_slice())
}
pub fn send(&mut self, plaintext: &[u8]) -> Result<(), Error> {
if self.state != State::Connected {
return Err(Error::InappropriateState);
}
let dg = self.encrypt_protected_record(ContentType::ApplicationData, plaintext)?;
self.out_dgrams.push(dg);
Ok(())
}
/// Deadline at which `on_timeout` should next be called, if any timer is
/// armed.
pub fn next_timeout(&self) -> Option<Duration> {
    let timers = &self.retransmit;
    timers.next_timeout()
}
/// Drives retransmission at time `now`.
///
/// `now` becomes the timestamp for records sent afterwards. When the
/// retransmission logic asks for a resend, every in-flight datagram is
/// queued again; when it gives up, the connection moves to `Closed`.
pub fn on_timeout(&mut self, now: Duration) {
    use super::reliability::Action;

    self.last_now = now;
    match self.retransmit.on_timeout(now) {
        Action::Idle => {}
        Action::GiveUp => self.state = State::Closed,
        Action::Retransmit => {
            let resend = self.retransmit.in_flight_datagrams();
            self.out_dgrams
                .extend(resend.into_iter().map(|dg| dg.to_vec()));
        }
    }
}
/// Processes every record contained in one received datagram.
///
/// A first byte below 32 is treated as a DTLSPlaintext record; a first byte
/// of the form `001xxxxx` as a DTLS 1.3 unified-header (protected) record.
/// Anything else, or a truncated record, ends processing of the datagram
/// without error.
///
/// # Errors
/// The first error returned while processing a record.
pub fn feed_datagram(&mut self, datagram: &[u8]) -> Result<(), Error> {
    let mut cursor = 0usize;
    loop {
        let Some(&lead) = datagram.get(cursor) else {
            break;
        };
        let rest = &datagram[cursor..];
        if lead < 32 {
            let Some(rec) = record::read_record(rest)? else {
                break;
            };
            cursor += rec.len;
            self.process_plaintext_record(rec)?;
        } else if (lead & 0b1110_0000) == 0b0010_0000 {
            match self.process_protected_record(rest)? {
                0 => break,
                used => cursor += used,
            }
        } else {
            break;
        }
    }
    Ok(())
}
fn process_plaintext_record(&mut self, rec: ParsedDtlsRecord<'_>) -> Result<(), Error> {
if rec.version != ProtocolVersion::DTLSv1_2 && rec.version != ProtocolVersion::DTLSv1_0 {
return Err(Error::UnsupportedVersion);
}
if rec.epoch != 0 {
return Ok(());
}
match rec.content_type {
ContentType::Handshake => self.process_handshake_record(rec.fragment),
ContentType::Alert => Ok(()),
ContentType::ChangeCipherSpec => Ok(()), _ => Err(Error::UnexpectedMessage),
}
}
fn process_protected_record(&mut self, buf: &[u8]) -> Result<usize, Error> {
let (hdr_len, body_len) = peek_header_layout(buf)?;
let total = hdr_len + body_len;
if total > buf.len() {
return Ok(0);
}
let body = &buf[hdr_len..total];
if body.len() < 16 {
return Err(Error::Decode);
}
let suite = self.suite.ok_or(Error::UnexpectedMessage)?;
let sn_key = self.read_sn_key.as_ref().ok_or(Error::UnexpectedMessage)?;
let mask_full = sn_mask_for(suite, sn_key, body)?;
let mask: &[u8] = if (buf[0] & 0b0000_1000) != 0 {
&mask_full[..2]
} else {
&mask_full[..1]
};
let (hdr, ct_body) = record13::decode_record(buf, mask)?;
let consumed = hdr.header_len + ct_body.len();
let read_epoch = self.current_read_epoch();
if (read_epoch as u8 & 0b11) != hdr.epoch_low2 {
return Ok(consumed);
}
let seq = reconstruct_seq(
hdr.seq_low,
hdr.seq_is_16bit,
self.enc_read_seq.wrapping_add(1),
);
let mut aad = buf[..hdr.header_len].to_vec();
if hdr.seq_is_16bit {
aad[1] ^= mask[0];
aad[2] ^= mask[1];
} else {
aad[1] ^= mask[0];
}
if !self.read_replay.check(seq) {
return Ok(consumed);
}
let crypter = self.read_crypter.as_mut().ok_or(Error::UnexpectedMessage)?;
let (inner_type, plain) = decrypt_dtls13_record(crypter, seq, &aad, ct_body)?;
self.read_replay.mark(seq);
if seq > self.enc_read_seq {
self.enc_read_seq = seq;
}
let is_handshake = matches!(
inner_type,
ContentType::Handshake | ContentType::Alert | ContentType::Unknown(ACK_CONTENT_TYPE)
);
if is_handshake {
self.pending_acks.push(RecordNumber {
epoch: read_epoch as u64,
seq,
});
}
match inner_type {
ContentType::Handshake => self.process_handshake_record(&plain)?,
ContentType::ApplicationData => {
if self.state != State::Connected {
return Err(Error::UnexpectedMessage);
}
self.app_in.extend_from_slice(&plain);
}
ContentType::Alert => {
}
ContentType::Unknown(t) if t == ACK_CONTENT_TYPE => {
let acks = decode_ack(&plain)?;
self.retransmit.on_ack(&acks);
}
_ => return Err(Error::UnexpectedMessage),
}
Ok(consumed)
}
/// Epoch expected on incoming protected records: 3 (application keys) once
/// connected, 2 (handshake keys) otherwise.
fn current_read_epoch(&self) -> u16 {
    match self.state {
        State::Connected => 3,
        _ => 2,
    }
}
/// Feeds every handshake fragment in `plain` to the reassembler.
///
/// `plain` is the content of one handshake record (plaintext or decrypted).
/// Each message is dispatched as soon as it is complete and in order.
fn process_handshake_record(&mut self, plain: &[u8]) -> Result<(), Error> {
    let mut remaining = plain;
    while !remaining.is_empty() {
        let parsed = read_fragment(remaining)?;
        // Step past this fragment; an over-long length simply ends the loop.
        remaining = remaining.get(parsed.len..).unwrap_or(&[]);
        let fragment = HandshakeFragment {
            msg_type: parsed.msg_type,
            total_length: parsed.total_length,
            message_seq: parsed.message_seq,
            fragment_offset: parsed.fragment_offset,
            fragment: parsed.fragment,
            len: parsed.len,
        };
        if let Some((msg_type, body)) = self.reassembler.feed(fragment) {
            self.dispatch_one(msg_type, &body)?;
        }
        // Completing one message may unblock later, already-buffered ones.
        while let Some((msg_type, body)) = self.reassembler.pop_ready() {
            self.dispatch_one(msg_type, &body)?;
        }
    }
    Ok(())
}
/// Routes one reassembled handshake message to the handler for the current
/// state.
///
/// Handlers also receive `raw`, the message re-encoded with the TLS 4-byte
/// header (type plus 24-bit length), which is what enters the transcript.
fn dispatch_one(&mut self, msg_type: u8, body: &[u8]) -> Result<(), Error> {
    let len_be = (body.len() as u32).to_be_bytes();
    let mut raw = Vec::with_capacity(4 + body.len());
    raw.push(msg_type);
    raw.extend_from_slice(&len_be[1..]);
    raw.extend_from_slice(body);
    match self.state {
        State::WaitServerHello => self.on_server_hello(msg_type, body, &raw),
        State::WaitEncryptedExtensions => self.on_encrypted_extensions(msg_type, &raw),
        State::WaitCertificate => self.on_certificate(msg_type, body, &raw),
        State::WaitCertificateVerify => self.on_certificate_verify(msg_type, body, &raw),
        State::WaitFinished => self.on_finished(msg_type, body, &raw),
        State::Connected | State::Closed => Err(Error::UnexpectedMessage),
    }
}
fn on_server_hello(&mut self, msg_type: u8, body: &[u8], raw: &[u8]) -> Result<(), Error> {
if msg_type != hs_type::SERVER_HELLO {
return Err(Error::UnexpectedMessage);
}
let sh = ServerHello::decode(body)?;
if sh.random == HRR_RANDOM {
return self.on_hello_retry_request(sh, raw);
}
let suite = lookup_suite(sh.cipher_suite).ok_or(Error::HandshakeFailure)?;
if !supported_suites().iter().any(|s| s.suite == suite.suite) {
return Err(Error::HandshakeFailure);
}
let sv = ext::find(&sh.extensions, ExtensionType::SUPPORTED_VERSIONS)
.ok_or(Error::UnsupportedVersion)?;
if ext::parse_selected_version(sv)? != ProtocolVersion::TLSv1_3 {
return Err(Error::UnsupportedVersion);
}
self.server_random = Some(sh.random);
self.suite = Some(suite);
let ks_ext =
ext::find(&sh.extensions, ExtensionType::KEY_SHARE).ok_or(Error::HandshakeFailure)?;
let (group, server_pub) = ext::parse_server_key_share(ks_ext)?;
if !self.config.groups.contains(&group) {
return Err(Error::HandshakeFailure);
}
let shared = self.key_agreement(group, &server_pub)?;
self.transcript.set_alg(suite.hash);
self.transcript.update(raw);
let mut ks = KeySchedule::new(suite.hash);
ks.enter_handshake(&shared);
let th = self.transcript.current_hash();
let chts = ks.client_handshake_traffic_secret(th.as_slice());
let shts = ks.server_handshake_traffic_secret(th.as_slice());
if let Some(kl) = self.config.key_log.as_ref() {
kl.log(
"CLIENT_HANDSHAKE_TRAFFIC_SECRET",
&self.client_random,
chts.as_slice(),
);
kl.log(
"SERVER_HANDSHAKE_TRAFFIC_SECRET",
&self.client_random,
shts.as_slice(),
);
}
let w_crypter = RecordCrypter::new(suite.hash, suite.aead, suite.key_len, &chts);
let r_crypter = RecordCrypter::new(suite.hash, suite.aead, suite.key_len, &shts);
self.write_crypter = Some(w_crypter);
self.read_crypter = Some(r_crypter);
let sn_len = sn_key_len_for(suite.aead);
self.write_sn_key = Some(derive_sn_key(suite.hash, &chts, sn_len));
self.read_sn_key = Some(derive_sn_key(suite.hash, &shts, sn_len));
self.enc_write_epoch = 2;
self.enc_write_seq = 0;
self.enc_read_seq = 0;
self.read_replay = crate::dtls::replay::AntiReplayWindow::new();
self.ks = Some(ks);
self.client_hs_secret = Some(chts);
self.server_hs_secret = Some(shts);
self.state = State::WaitEncryptedExtensions;
Ok(())
}
fn on_hello_retry_request(&mut self, hrr: ServerHello, raw: &[u8]) -> Result<(), Error> {
if self.hrr_processed {
return Err(Error::UnexpectedMessage);
}
let suite = lookup_suite(hrr.cipher_suite).ok_or(Error::IllegalParameter)?;
if !supported_suites().iter().any(|s| s.suite == suite.suite) {
return Err(Error::IllegalParameter);
}
let sv = ext::find(&hrr.extensions, ExtensionType::SUPPORTED_VERSIONS)
.ok_or(Error::UnsupportedVersion)?;
if ext::parse_selected_version(sv)? != ProtocolVersion::TLSv1_3 {
return Err(Error::UnsupportedVersion);
}
let cookie_body = hrr
.extensions
.iter()
.find(|(t, _)| t.0 == EXT_COOKIE)
.map(|(_, v)| v.clone());
let selected_group = match ext::find(&hrr.extensions, ExtensionType::KEY_SHARE) {
Some(body) => {
let g = ext::parse_hrr_key_share(body)?;
if !self.config.groups.contains(&g) {
return Err(Error::IllegalParameter);
}
Some(g)
}
None => None,
};
if cookie_body.is_none() && selected_group.is_none() {
return Err(Error::IllegalParameter);
}
self.cookie_extension = cookie_body;
self.hrr_selected_group = selected_group;
self.suite = Some(suite);
self.transcript.set_alg(suite.hash);
self.transcript.replace_with_message_hash();
self.transcript.update(raw);
self.hrr_processed = true;
self.retransmit = Retransmit13::new();
let dgram = self.build_client_hello();
self.emit_plaintext(dgram);
Ok(())
}
fn key_agreement(&self, group: NamedGroup, server_pub: &[u8]) -> Result<Vec<u8>, Error> {
match group {
NamedGroup::X25519 => {
let peer: [u8; 32] = server_pub.try_into().map_err(|_| Error::Decode)?;
let ss = self
.x25519
.diffie_hellman(&peer)
.map_err(|_| Error::IllegalParameter)?;
Ok(ss.to_vec())
}
NamedGroup::SECP256R1 => {
let peer = BoxedEcdsaPublicKey::from_sec1(CurveId::P256, server_pub)
.map_err(|_| Error::Decode)?;
let ss = self
.p256
.diffie_hellman(&peer)
.map_err(|_| Error::PeerMisbehaved)?;
Ok(ss)
}
NamedGroup::SECP384R1 => {
let peer = BoxedEcdsaPublicKey::from_sec1(CurveId::P384, server_pub)
.map_err(|_| Error::Decode)?;
let ss = self
.p384
.diffie_hellman(&peer)
.map_err(|_| Error::PeerMisbehaved)?;
Ok(ss)
}
NamedGroup::X25519MLKEM768 => {
if server_pub.len() != CIPHERTEXT_BYTES + 32 {
return Err(Error::Decode);
}
let mut ct = [0u8; CIPHERTEXT_BYTES];
ct.copy_from_slice(&server_pub[..CIPHERTEXT_BYTES]);
let peer: [u8; 32] = server_pub[CIPHERTEXT_BYTES..]
.try_into()
.map_err(|_| Error::Decode)?;
let ml_ss = self.mlkem.decapsulate(&MlKem768Ciphertext::from_bytes(ct));
let x_ss = self
.x25519
.diffie_hellman(&peer)
.map_err(|_| Error::IllegalParameter)?;
let mut combined = Vec::with_capacity(64);
combined.extend_from_slice(&ml_ss);
combined.extend_from_slice(&x_ss);
Ok(combined)
}
_ => Err(Error::HandshakeFailure),
}
}
fn on_encrypted_extensions(&mut self, msg_type: u8, raw: &[u8]) -> Result<(), Error> {
if msg_type != hs_type::ENCRYPTED_EXTENSIONS {
return Err(Error::UnexpectedMessage);
}
if raw.len() >= 4 {
let body = &raw[4..];
let mut c = ReadCursor::new(body);
let exts_bytes = c.vec_u16()?;
let mut ec = ReadCursor::new(exts_bytes);
while !ec.is_empty() {
let ty = ec.u16()?;
let ext_body = ec.vec_u16()?;
if ty == ExtensionType::ALPN.0 {
let names = ext::parse_alpn(ext_body)?;
if names.len() != 1 {
return Err(Error::IllegalParameter);
}
if !self.config.alpn_protocols.iter().any(|p| p == &names[0]) {
return Err(Error::IllegalParameter);
}
self.alpn_negotiated = Some(names.into_iter().next().unwrap());
}
}
}
self.transcript.update(raw);
self.state = State::WaitCertificate;
Ok(())
}
fn on_certificate(&mut self, msg_type: u8, body: &[u8], raw: &[u8]) -> Result<(), Error> {
if msg_type != hs_type::CERTIFICATE {
return Err(Error::UnexpectedMessage);
}
self.cert_chain = parse_certificate_list(body)?;
if self.cert_chain.is_empty() {
return Err(Error::BadCertificate);
}
self.transcript.update(raw);
self.state = State::WaitCertificateVerify;
Ok(())
}
/// Verifies the server certificate (unless disabled) and its CertificateVerify.
///
/// When verification is enabled, the chain is checked against roots, CRLs and
/// the signature policy, and the leaf is matched to `server_name`. Otherwise
/// the leaf's public key is used as-is. The signature is then checked over
/// the transcript up to (excluding) this message.
///
/// # Errors
/// * Decode errors.
/// * `BadCertificate` for an unusable leaf.
/// * Chain, hostname or signature errors.
fn on_certificate_verify(
    &mut self,
    msg_type: u8,
    body: &[u8],
    raw: &[u8],
) -> Result<(), Error> {
    if msg_type != hs_type::CERTIFICATE_VERIFY {
        return Err(Error::UnexpectedMessage);
    }
    // struct { SignatureScheme algorithm; opaque signature<0..2^16-1>; }
    let mut reader = ReadCursor::new(body);
    let scheme = SignatureScheme(reader.u16()?);
    let signature = reader.vec_u16()?.to_vec();
    reader.expect_empty()?;

    let leaf =
        Certificate::from_der(self.cert_chain[0].clone()).map_err(|_| Error::BadCertificate)?;
    leaf.check_well_formed()
        .map_err(|_| Error::BadCertificate)?;

    let leaf_key = if !self.config.verify_certificates {
        leaf.subject_public_key()
            .map_err(|_| Error::BadCertificate)?
    } else {
        let verified_key = verify_chain_with_crls(
            &self.config.roots,
            &self.config.crls,
            &self.cert_chain,
            self.config.verification_time.as_ref(),
            &self.config.signature_policy,
        )?;
        if let Some(host) = self.config.server_name.as_deref() {
            verify_hostname(&leaf, host)?;
        }
        verified_key
    };

    let transcript_hash = self.transcript.current_hash();
    let signed_content = certificate_verify_content(true, transcript_hash.as_slice());
    verify_signature(
        scheme,
        &leaf_key,
        &signed_content,
        &signature,
        &self.config.signature_policy,
    )?;

    self.leaf_key = Some(leaf_key);
    self.transcript.update(raw);
    self.state = State::WaitFinished;
    Ok(())
}
/// Verifies the server Finished, derives application secrets, sends the
/// client Finished and switches to epoch-3 keys.
///
/// Order matters:
/// 1. The server Finished is checked (constant-time) against the server
///    handshake secret and the transcript before it.
/// 2. After adding it to the transcript, the application and exporter
///    secrets are derived from that transcript hash.
/// 3. The client Finished is computed over the same transcript.
/// 4. The client Finished is encrypted under the *handshake* (epoch-2) write
///    keys and tracked for retransmission.
/// 5. Only then are the staged epoch-3 crypters and sn keys swapped in and
///    the state set to `Connected`.
///
/// # Errors
/// `UnexpectedMessage` for a non-Finished message, `HandshakeFailure` on a
/// verify_data mismatch, `InappropriateState` if handshake secrets are
/// missing, plus record-protection errors.
fn on_finished(&mut self, msg_type: u8, body: &[u8], raw: &[u8]) -> Result<(), Error> {
if msg_type != hs_type::FINISHED {
return Err(Error::UnexpectedMessage);
}
let suite = self.suite.ok_or(Error::InappropriateState)?;
let shts = self
.server_hs_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let th = self.transcript.current_hash();
let expected = finished_verify_data(suite.hash, shts, th.as_slice());
if !bool::from(expected.as_slice().ct_eq(body)) {
return Err(Error::HandshakeFailure);
}
self.transcript.update(raw);
// Application and exporter secrets use the transcript through server Finished.
let (cats, sats, ems) = {
let ks = self.ks.as_mut().ok_or(Error::InappropriateState)?;
ks.enter_master();
let th_app = self.transcript.current_hash();
let cats = ks.client_application_traffic_secret(th_app.as_slice());
let sats = ks.server_application_traffic_secret(th_app.as_slice());
let ems = ks.exporter_master_secret(th_app.as_slice());
(cats, sats, ems)
};
if let Some(kl) = self.config.key_log.as_ref() {
kl.log(
"CLIENT_TRAFFIC_SECRET_0",
&self.client_random,
cats.as_slice(),
);
kl.log(
"SERVER_TRAFFIC_SECRET_0",
&self.client_random,
sats.as_slice(),
);
kl.log("EXPORTER_SECRET", &self.client_random, ems.as_slice());
}
self.exporter_secret = Some(ems);
// Stage epoch-3 protection; it is not activated until the client Finished
// has been sealed under the handshake keys below.
self.pending_write_app_crypter = Some(RecordCrypter::new(
suite.hash,
suite.aead,
suite.key_len,
&cats,
));
self.pending_read_app_crypter = Some(RecordCrypter::new(
suite.hash,
suite.aead,
suite.key_len,
&sats,
));
let sn_len = sn_key_len_for(suite.aead);
self.write_app_sn_key = Some(derive_sn_key(suite.hash, &cats, sn_len));
self.read_app_sn_key = Some(derive_sn_key(suite.hash, &sats, sn_len));
self.client_app_secret = Some(cats);
self.server_app_secret = Some(sats);
let chts = self
.client_hs_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let th_for_cfin = self.transcript.current_hash();
let verify_data = finished_verify_data(suite.hash, chts, th_for_cfin.as_slice());
let fin_body = verify_data.as_slice().to_vec();
// TLS-form Finished (type + 24-bit length + body) for the transcript.
let mut fin_tls = Vec::with_capacity(4 + fin_body.len());
fin_tls.push(hs_type::FINISHED);
let n = fin_body.len() as u32;
fin_tls.push(((n >> 16) & 0xff) as u8);
fin_tls.push(((n >> 8) & 0xff) as u8);
fin_tls.push((n & 0xff) as u8);
fin_tls.extend_from_slice(&fin_body);
self.transcript.update(&fin_tls);
let fin_msg_seq = self.out_msg_seq;
self.out_msg_seq += 1;
let mut frag_buf = Vec::new();
write_message(
&mut frag_buf,
hs_type::FINISHED,
fin_msg_seq,
&fin_body,
DEFAULT_MAX_FRAGMENT,
);
// Still epoch 2 here: the client Finished is protected with handshake keys.
let fin_dgram = self.encrypt_protected_record(ContentType::Handshake, &frag_buf)?;
self.emit_protected(fin_dgram, true);
// Switch both directions to epoch 3 and reset per-epoch counters.
self.write_crypter = self.pending_write_app_crypter.take();
self.read_crypter = self.pending_read_app_crypter.take();
self.write_sn_key = self.write_app_sn_key.take();
self.read_sn_key = self.read_app_sn_key.take();
self.enc_write_epoch = 3;
self.enc_write_seq = 0;
self.enc_read_seq = 0;
self.read_replay = crate::dtls::replay::AntiReplayWindow::new();
self.state = State::Connected;
Ok(())
}
/// Builds the next ClientHello, adds it to the transcript, and returns it
/// wrapped in a plaintext handshake record.
///
/// Key shares: after an HRR only the server-selected group; otherwise every
/// configured group allowed by `key_share_groups`. Only groups this client
/// can build shares for are included. The HRR cookie, if any, is echoed
/// back.
fn build_client_hello(&mut self) -> Vec<u8> {
    let groups = self.config.groups.clone();
    let wants_share = |g: NamedGroup| match self.hrr_selected_group {
        Some(selected) => g == selected,
        None => self
            .config
            .key_share_groups
            .as_ref()
            .map_or(true, |allowed| allowed.contains(&g)),
    };
    let key_shares: Vec<(NamedGroup, Vec<u8>)> = groups
        .iter()
        .copied()
        .filter(|&g| wants_share(g))
        .filter_map(|g| {
            let share = match g {
                NamedGroup::X25519 => self.x25519.public_key().to_vec(),
                NamedGroup::SECP256R1 => self.p256.public_key().to_sec1(),
                NamedGroup::SECP384R1 => self.p384.public_key().to_sec1(),
                NamedGroup::X25519MLKEM768 => {
                    // Hybrid share: ML-KEM encapsulation key || X25519 public key.
                    let mut hybrid = self.mlkem.encapsulation_key().to_bytes().to_vec();
                    hybrid.extend_from_slice(&self.x25519.public_key());
                    hybrid
                }
                _ => return None,
            };
            Some((g, share))
        })
        .collect();

    let mut extensions = Vec::new();
    if let Some(name) = self.config.server_name.as_deref() {
        extensions.push(ext::server_name(name));
    }
    extensions.push(ext::supported_groups_list(&groups));
    extensions.push(ext::signature_algorithms());
    extensions.push(ext::client_supported_versions());
    extensions.push(ext::client_key_shares(&key_shares));
    if !self.config.alpn_protocols.is_empty() {
        let offered: Vec<&[u8]> = self
            .config
            .alpn_protocols
            .iter()
            .map(Vec::as_slice)
            .collect();
        extensions.push(ext::alpn_protocols(&offered));
    }
    if let Some(cookie) = self.cookie_extension.as_ref() {
        extensions.push((ExtensionType(EXT_COOKIE), cookie.clone()));
    }

    let encoded = ClientHello {
        legacy_version: 0x0303,
        random: self.client_random,
        session_id: Vec::new(),
        cipher_suites: self.config.cipher_suites.clone(),
        extensions,
    }
    .encode();
    self.transcript.update(&encoded);

    // The DTLS handshake header replaces the 4-byte TLS header.
    let hello_body = &encoded[4..];
    let message_seq = self.out_msg_seq;
    self.out_msg_seq += 1;
    let mut fragments = Vec::new();
    write_message(
        &mut fragments,
        hs_type::CLIENT_HELLO,
        message_seq,
        hello_body,
        DEFAULT_MAX_FRAGMENT,
    );
    self.wrap_plain_record(ContentType::Handshake, &fragments)
}
/// Wraps `fragment` in a DTLSPlaintext record (version DTLS 1.2) using the
/// plaintext epoch and the next plaintext sequence number, which is then
/// advanced.
fn wrap_plain_record(&mut self, ct: ContentType, fragment: &[u8]) -> Vec<u8> {
    let epoch = self.plain_write_epoch;
    let seq = self.plain_write_seq;
    let mut record_bytes = Vec::new();
    record::write_record(
        &mut record_bytes,
        ct,
        ProtocolVersion::DTLSv1_2,
        epoch,
        seq,
        fragment,
    );
    self.plain_write_seq = seq + 1;
    record_bytes
}
/// Seals `payload` (inner content type `ct`) into one DTLS 1.3 ciphertext
/// record under the current write epoch and returns the wire bytes.
///
/// The header always uses a 16-bit sequence number and carries an explicit
/// length. The write sequence number is checked against its cap and then
/// consumed, even if sealing later fails.
///
/// # Errors
/// `InappropriateState` if no suite, crypter or sn key is installed; the
/// sequence-cap error; AEAD or mask-derivation errors.
fn encrypt_protected_record(
&mut self,
ct: ContentType,
payload: &[u8],
) -> Result<Vec<u8>, Error> {
let suite = self.suite.ok_or(Error::InappropriateState)?;
let crypter = self
.write_crypter
.as_mut()
.ok_or(Error::InappropriateState)?;
let sn_key = self
.write_sn_key
.as_ref()
.ok_or(Error::InappropriateState)?;
let epoch = self.enc_write_epoch;
let seq = self.enc_write_seq;
record::check_seq_cap(seq)?;
self.enc_write_seq += 1;
let seq_is_16bit = true;
let omit_length = false;
// DTLSInnerPlaintext: content || true content type (no padding added).
let mut inner = Vec::with_capacity(payload.len() + 1);
inner.extend_from_slice(payload);
inner.push(ct.as_u8());
// AAD = the record header with an unmasked sequence number. It is produced
// by encoding a placeholder body of the final ciphertext length (assumes a
// 16-byte tag, as decrypt_dtls13_record does) with a zero mask, then keeping
// only the header bytes.
let mut aad = Vec::new();
let aad_zero_mask = [0u8; 2];
let ct_len = inner.len() + 16;
record13::encode_record(
&mut aad,
epoch,
seq,
seq_is_16bit,
omit_length,
&alloc::vec![0u8; ct_len],
&aad_zero_mask,
);
let hdr_len = aad.len() - ct_len;
aad.truncate(hdr_len);
encrypt_dtls13_record(crypter, seq, &aad, &mut inner)?;
// The record-number mask is derived from the ciphertext, so it can only be
// computed after sealing; it is applied to the on-wire header.
let mask_full = sn_mask_for(suite, sn_key, &inner)?;
let mask: &[u8] = if seq_is_16bit {
&mask_full[..2]
} else {
&mask_full[..1]
};
let mut wire = Vec::new();
record13::encode_record(
&mut wire,
epoch,
seq,
seq_is_16bit,
omit_length,
&inner,
mask,
);
Ok(wire)
}
/// Queues a plaintext datagram and registers it for retransmission.
///
/// Callers pass the datagram produced by the most recent
/// `wrap_plain_record` call, so its record number is the plaintext sequence
/// counter minus one.
fn emit_plaintext(&mut self, datagram: Vec<u8>) {
    let record_number = RecordNumber {
        epoch: u64::from(self.plain_write_epoch),
        seq: self.plain_write_seq.saturating_sub(1),
    };
    self.out_dgrams.push(datagram.clone());
    let tracked = InFlightRecord {
        record_number,
        datagram,
    };
    self.retransmit.on_record_sent(tracked, self.last_now);
}
/// Queues a protected datagram. When `track` is set, it is also registered
/// for retransmission under the record number just consumed by
/// `encrypt_protected_record` (write sequence counter minus one).
fn emit_protected(&mut self, datagram: Vec<u8>, track: bool) {
    if !track {
        self.out_dgrams.push(datagram);
        return;
    }
    let record_number = RecordNumber {
        epoch: u64::from(self.enc_write_epoch),
        seq: self.enc_write_seq.saturating_sub(1),
    };
    self.out_dgrams.push(datagram.clone());
    let tracked = InFlightRecord {
        record_number,
        datagram,
    };
    self.retransmit.on_record_sent(tracked, self.last_now);
}
/// Sends one ACK record covering all received records still awaiting
/// acknowledgement, protected with the current write keys.
///
/// RFC 9147 §7 requires the record numbers in an ACK to be in numerically
/// increasing order. Records can be received out of order, so the list is
/// sorted by (epoch, sequence number) and deduplicated first.
///
/// Best effort: nothing is sent before write keys exist. If the ACK cannot
/// be protected (e.g. the sequence-number cap is reached), the pending
/// record numbers are dropped rather than retried.
fn flush_pending_acks(&mut self) {
    if self.pending_acks.is_empty() || self.write_crypter.is_none() {
        return;
    }
    let mut acks = core::mem::take(&mut self.pending_acks);
    acks.sort_unstable_by_key(|rn| (rn.epoch, rn.seq));
    acks.dedup_by_key(|rn| (rn.epoch, rn.seq));
    let body = encode_ack(&acks);
    if let Ok(dg) = self.encrypt_protected_record(ContentType::Unknown(ACK_CONTENT_TYPE), &body)
    {
        self.out_dgrams.push(dg);
    }
}
}
/// Parses a TLS 1.3 Certificate message body into DER certificates.
///
/// The certificate_request_context and the per-entry extensions are read
/// and discarded. The body must contain nothing after the certificate list.
///
/// # Errors
/// Decode errors for malformed framing; `BadCertificate` for an empty
/// certificate entry.
fn parse_certificate_list(body: &[u8]) -> Result<Vec<Vec<u8>>, Error> {
    let mut outer = ReadCursor::new(body);
    outer.vec_u8()?;
    let entry_bytes = outer.vec_u24()?;
    outer.expect_empty()?;

    let mut cursor = ReadCursor::new(entry_bytes);
    let mut chain = Vec::new();
    while !cursor.is_empty() {
        let der = cursor.vec_u24()?.to_vec();
        if der.is_empty() {
            return Err(Error::BadCertificate);
        }
        cursor.vec_u16()?;
        chain.push(der);
    }
    Ok(chain)
}
/// Derives the record-number-encryption key:
/// `HKDF-Expand-Label(secret, "sn", "", len)`.
pub(crate) fn derive_sn_key(hash: HashAlg, secret: &Secret, len: usize) -> Vec<u8> {
    let mut key = Vec::new();
    key.resize(len, 0u8);
    expand_label_dyn(hash, secret.as_slice(), b"sn", &[], &mut key);
    key
}
/// Length of the record-number-encryption key for an AEAD: 16 bytes for
/// AES-128-GCM, 32 bytes for AES-256-GCM and ChaCha20-Poly1305.
pub(crate) fn sn_key_len_for(aead: AeadAlg) -> usize {
    match aead {
        AeadAlg::Aes128Gcm => 16,
        AeadAlg::Aes256Gcm => 32,
        AeadAlg::ChaCha20Poly1305 => 32,
    }
}
/// Encrypts `inner` in place with sequence number `seq` and additional data
/// `aad`, then appends the detached tag so `inner` becomes
/// `ciphertext || tag`.
pub(crate) fn encrypt_dtls13_record(
    crypter: &mut RecordCrypter,
    seq: u64,
    aad: &[u8],
    inner: &mut Vec<u8>,
) -> Result<(), Error> {
    let tag = crypter.encrypt_raw(seq, aad, inner)?;
    inner.extend(tag.iter().copied());
    Ok(())
}
pub(crate) fn decrypt_dtls13_record(
crypter: &mut RecordCrypter,
seq: u64,
aad: &[u8],
ciphertext: &[u8],
) -> Result<(ContentType, Vec<u8>), Error> {
if ciphertext.len() < 16 {
return Err(Error::Decode);
}
let (ct, tag_bytes) = ciphertext.split_at(ciphertext.len() - 16);
let mut tag = [0u8; 16];
tag.copy_from_slice(tag_bytes);
let mut buf = ct.to_vec();
crypter.decrypt_raw(seq, aad, &mut buf, &tag)?;
let end = match buf.iter().rposition(|&b| b != 0) {
Some(p) => p,
None => return Err(Error::PeerMisbehaved),
};
let true_type = buf[end];
buf.truncate(end);
let ct = ContentType::from_u8(true_type);
Ok((ct, buf))
}