#![allow(dead_code, unreachable_pub)]
use crate::ct::ConstantTimeEq;
use crate::ec::x25519::X25519PrivateKey;
use crate::ec::{BoxedEcdhPrivateKey, BoxedEcdsaPrivateKey, BoxedEcdsaPublicKey, CurveId};
use crate::mlkem::{ENCAPS_KEY_BYTES, MlKem768EncapsKey};
use crate::rng::RngCore;
use crate::signature_registry::SignaturePolicy;
use crate::tls::codec::extension as ext;
use crate::tls::codec::{
ClientHello, ExtensionType, NamedGroup, Random, ReadCursor, ServerHello, hs_type, put_u16,
with_len_u16, with_len_u24,
};
use crate::tls::crypto::sign::sign_certificate_verify;
use crate::tls::crypto::{
KeySchedule, RecordCrypter, SuiteParams, Transcript, certificate_verify_content,
finished_verify_data, supported_suites,
};
use crate::tls::keylog::KeyLog;
use crate::tls::{ContentType, Error, ProtocolVersion};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::time::Duration;
use super::ack::{ACK_CONTENT_TYPE, RecordNumber, decode as decode_ack, encode as encode_ack};
use super::client13::{
decrypt_dtls13_record, derive_sn_key, encrypt_dtls13_record, sn_key_len_for,
};
use super::cookie::CookieGenerator;
use super::reassembly::{HandshakeFragment, Reassembler, read_fragment, write_message};
use super::record::{self, ParsedDtlsRecord};
use super::record13::{self, peek_header_layout, reconstruct_seq, sn_mask_for};
use super::reliability13::{InFlightRecord, Retransmit13};
/// Fixed `ServerHello.random` value that marks a ServerHello as a
/// HelloRetryRequest (RFC 8446 §4.1.3; reused unchanged by DTLS 1.3).
const HRR_RANDOM: [u8; 32] = [
    0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
    0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
    0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
    0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
];
/// Extension type code of the TLS `cookie` extension (RFC 8446 §4.2, value 44).
const EXT_COOKIE: u16 = 0x2C;
/// Fragment size limit handed to `write_message` when framing outgoing handshake messages.
const DEFAULT_MAX_FRAGMENT: usize = 1_100;
/// Server-side configuration shared (via `Arc`) by every DTLS 1.3 server
/// connection created from it.
pub(crate) struct ServerConfig13Internal {
/// Certificate entries sent, in order, in the Certificate message
/// (presumably DER-encoded, leaf first — TODO confirm with callers).
pub cert_chain: Vec<Vec<u8>>,
/// Key used to sign the CertificateVerify message.
pub key: crate::tls::conn::ServerKey,
/// Secret for minting/validating HelloRetryRequest cookies; the cookie
/// exchange only happens when this is set AND `require_cookie` is true.
pub cookie_secret: Option<[u8; 32]>,
/// Whether to demand a cookie round trip (only effective with `cookie_secret`).
pub require_cookie: bool,
// Not read by the DTLS 1.3 server code in this file.
#[allow(dead_code)]
pub signature_policy: Arc<SignaturePolicy>,
/// Optional sink for handshake, traffic and exporter secrets, keyed by the
/// client random.
pub key_log: Option<Arc<dyn KeyLog>>,
}
impl ServerConfig13Internal {
    /// Creates a configuration presenting `cert_chain` and signing with `key`.
    ///
    /// Defaults: cookies are required (enforced only once a secret is
    /// installed with [`Self::with_cookie_secret`]), the modern signature
    /// policy is used, and no key log is attached.
    pub fn with_signing_key(cert_chain: Vec<Vec<u8>>, key: crate::tls::conn::ServerKey) -> Self {
        let signature_policy = Arc::new(SignaturePolicy::modern());
        Self {
            cert_chain,
            key,
            signature_policy,
            key_log: None,
            cookie_secret: None,
            require_cookie: true,
        }
    }

    /// Convenience constructor wrapping an ECDSA private key as the signing key.
    // May be unused in some build configurations.
    #[allow(dead_code)]
    pub fn with_ecdsa(cert_chain: Vec<Vec<u8>>, key: BoxedEcdsaPrivateKey) -> Self {
        let signing_key = crate::tls::conn::ServerKey::Ecdsa(key);
        Self::with_signing_key(cert_chain, signing_key)
    }

    /// Installs the secret used to generate and validate HelloRetryRequest cookies.
    pub fn with_cookie_secret(self, secret: [u8; 32]) -> Self {
        Self {
            cookie_secret: Some(secret),
            ..self
        }
    }

    /// Turns the cookie exchange off, even if a cookie secret is configured.
    pub fn with_no_cookie(self) -> Self {
        Self {
            require_cookie: false,
            ..self
        }
    }
}
/// Handshake progress of a server connection.
///
/// * `WaitFirstClientHello` — initial state.
/// * `WaitSecondClientHello` — a HelloRetryRequest was sent.
/// * `WaitClientFinished` — the server flight (up to server Finished) was sent.
/// * `Connected` — the client Finished verified; application keys installed.
/// * `Closed` — retransmission gave up.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
enum State {
WaitFirstClientHello,
WaitSecondClientHello,
WaitClientFinished,
Connected,
Closed,
}
/// Server side of a single DTLS 1.3 association with one peer.
///
/// The connection is sans-I/O: datagrams are pushed in with `feed_datagram`
/// and outgoing datagrams are collected with `pop_outbound_datagrams`.
pub struct DtlsServerConnection13<R: RngCore> {
/// Shared server configuration.
config: Arc<ServerConfig13Internal>,
/// Randomness for the server random, ephemeral keys and signatures.
rng: R,
/// Peer address bytes, bound into HelloRetryRequest cookies.
peer_addr: Vec<u8>,
/// Current handshake state.
state: State,
/// Next DTLS handshake `message_seq` to use for outgoing messages.
out_msg_seq: u16,
/// Handshake reassembler, created once a ClientHello has been accepted.
reassembler: Option<Reassembler>,
/// Reassembler for ClientHello fragments received before `reassembler` exists.
pre_state_reasm: Option<Reassembler>,
/// Datagrams waiting to be collected by the caller.
out_dgrams: Vec<Vec<u8>>,
/// Decrypted application data waiting to be collected by the caller.
app_in: Vec<u8>,
/// Epoch used for plaintext (epoch 0) records.
plain_write_epoch: u16,
/// Next sequence number for plaintext records.
plain_write_seq: u64,
/// Epoch of outgoing protected records (2 during the handshake, 3 afterwards).
enc_write_epoch: u16,
/// Next sequence number for outgoing protected records.
enc_write_seq: u64,
/// Highest accepted sequence number in the current read epoch; base for
/// reconstructing truncated sequence numbers.
enc_read_seq: u64,
/// Anti-replay window, reset whenever the read epoch changes.
read_replay: crate::dtls::replay::AntiReplayWindow,
// Ephemeral key from a plain X25519 exchange; stored but not read in the code shown.
#[allow(dead_code)]
x25519: Option<X25519PrivateKey>,
/// Random from the accepted ClientHello.
client_random: Option<Random>,
/// Random sent in the ServerHello.
server_random: Option<Random>,
/// Handshake transcript hash.
transcript: Transcript,
/// Key schedule, created when the ServerHello is produced.
ks: Option<KeySchedule>,
/// Client handshake traffic secret (verifies the client Finished).
client_hs_secret: Option<crate::tls::crypto::Secret>,
/// Server handshake traffic secret (produces the server Finished).
server_hs_secret: Option<crate::tls::crypto::Secret>,
/// Client application traffic secret (not read in the code shown).
client_app_secret: Option<crate::tls::crypto::Secret>,
/// Server application traffic secret (not read in the code shown).
server_app_secret: Option<crate::tls::crypto::Secret>,
/// Active protection for outgoing protected records.
write_crypter: Option<RecordCrypter>,
/// Active protection for incoming protected records.
read_crypter: Option<RecordCrypter>,
/// Active sequence-number-encryption key for outgoing records.
write_sn_key: Option<Vec<u8>>,
/// Active sequence-number-encryption key for incoming records.
read_sn_key: Option<Vec<u8>>,
/// Application-epoch read sn key, installed when the client Finished verifies.
read_app_sn_key: Option<Vec<u8>>,
/// Application-epoch write sn key, installed when the client Finished verifies.
write_app_sn_key: Option<Vec<u8>>,
/// Application read crypter, installed when the client Finished verifies.
pending_read_app_crypter: Option<RecordCrypter>,
/// Application write crypter, installed when the client Finished verifies.
pending_write_app_crypter: Option<RecordCrypter>,
/// Negotiated cipher suite.
suite: Option<SuiteParams>,
/// Exporter master secret, used by `tls_exporter`.
exporter_secret: Option<crate::tls::crypto::Secret>,
/// Group requested in the HelloRetryRequest key_share extension, if any.
hrr_selected_group: Option<NamedGroup>,
/// Received record numbers still to be acknowledged.
pending_acks: Vec<RecordNumber>,
/// Retransmission tracking for the server's handshake records.
retransmit: Retransmit13,
/// Time last supplied to `on_timeout`; timestamps sends and cookies.
last_now: Duration,
}
impl<R: RngCore> DtlsServerConnection13<R> {
/// Creates a fresh server connection for the peer identified by `peer_addr`.
///
/// `peer_addr` is bound into HelloRetryRequest cookies. `rng` supplies the
/// server random, ephemeral keys and signature randomness. The connection
/// starts in `WaitFirstClientHello` with no keys installed.
pub(crate) fn new(config: Arc<ServerConfig13Internal>, peer_addr: Vec<u8>, rng: R) -> Self {
    Self {
        config,
        rng,
        peer_addr,
        state: State::WaitFirstClientHello,
        transcript: Transcript::new(),
        // Handshake framing.
        out_msg_seq: 0,
        reassembler: None,
        pre_state_reasm: None,
        hrr_selected_group: None,
        // Buffers exchanged with the caller.
        out_dgrams: Vec::new(),
        app_in: Vec::new(),
        pending_acks: Vec::new(),
        // Record-layer counters.
        plain_write_epoch: 0,
        plain_write_seq: 0,
        enc_write_epoch: 0,
        enc_write_seq: 0,
        enc_read_seq: 0,
        read_replay: crate::dtls::replay::AntiReplayWindow::new(),
        // Key material, populated once a ClientHello is accepted.
        x25519: None,
        client_random: None,
        server_random: None,
        ks: None,
        suite: None,
        client_hs_secret: None,
        server_hs_secret: None,
        client_app_secret: None,
        server_app_secret: None,
        exporter_secret: None,
        write_crypter: None,
        read_crypter: None,
        pending_read_app_crypter: None,
        pending_write_app_crypter: None,
        write_sn_key: None,
        read_sn_key: None,
        read_app_sn_key: None,
        write_app_sn_key: None,
        // Retransmission.
        retransmit: Retransmit13::new(),
        last_now: Duration::ZERO,
    }
}
/// Returns `true` once the client Finished has been verified.
pub fn is_handshake_complete(&self) -> bool {
    matches!(self.state, State::Connected)
}
/// Returns the negotiated cipher suite code point, or `None` until the
/// handshake has completed.
pub fn negotiated_cipher_suite(&self) -> Option<u16> {
    let params = self.suite.filter(|_| self.is_handshake_complete())?;
    Some(params.suite.0)
}
pub fn tls_exporter(&self, label: &[u8], context: &[u8], out: &mut [u8]) -> Result<(), Error> {
let ems = self
.exporter_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let suite = self.suite.ok_or(Error::InappropriateState)?;
crate::tls::crypto::tls_exporter(suite.hash, ems, label, context, out);
Ok(())
}
/// Drains every datagram queued for the peer. Pending ACKs are flushed first
/// so that they are included.
pub fn pop_outbound_datagrams(&mut self) -> Vec<Vec<u8>> {
    self.flush_pending_acks();
    self.out_dgrams.drain(..).collect()
}
/// Hands over all application data decrypted so far, leaving the buffer empty.
pub fn take_received(&mut self) -> Vec<u8> {
    let mut delivered = Vec::new();
    core::mem::swap(&mut delivered, &mut self.app_in);
    delivered
}
pub fn send(&mut self, plaintext: &[u8]) -> Result<(), Error> {
if self.state != State::Connected {
return Err(Error::InappropriateState);
}
let dg = self.encrypt_protected_record(ContentType::ApplicationData, plaintext)?;
self.out_dgrams.push(dg);
Ok(())
}
/// Returns the deadline reported by the retransmission tracker, if any.
pub fn next_timeout(&self) -> Option<Duration> {
    let timer = &self.retransmit;
    timer.next_timeout()
}
/// Advances the connection clock to `now` and applies whatever the
/// retransmission tracker decides: re-queue the in-flight datagrams, close
/// the connection, or do nothing.
pub fn on_timeout(&mut self, now: Duration) {
    use super::reliability::Action;

    self.last_now = now;
    match self.retransmit.on_timeout(now) {
        Action::Idle => {}
        Action::GiveUp => self.state = State::Closed,
        Action::Retransmit => {
            let resend = self
                .retransmit
                .in_flight_datagrams()
                .into_iter()
                .map(|dg| dg.to_vec());
            self.out_dgrams.extend(resend);
        }
    }
}
/// Processes every record packed into one received datagram.
///
/// A first byte below 32 is parsed as a plaintext record. A first byte of
/// `0b001xxxxx` is a DTLS 1.3 unified-header record. Any other byte, or a
/// truncated record, ends processing of the datagram without error.
///
/// # Errors
/// Propagates the first error raised while processing a record.
pub fn feed_datagram(&mut self, datagram: &[u8]) -> Result<(), Error> {
    let mut rest = datagram;
    while let Some(&lead) = rest.first() {
        let used = match lead {
            0..=31 => {
                let Some(rec) = record::read_record(rest)? else {
                    return Ok(());
                };
                let len = rec.len;
                self.process_plaintext_record(rec)?;
                len
            }
            0x20..=0x3F => match self.process_protected_record(rest)? {
                0 => return Ok(()),
                n => n,
            },
            _ => return Ok(()),
        };
        // Running past the end simply finishes the datagram.
        rest = rest.get(used..).unwrap_or_default();
    }
    Ok(())
}
/// Handles one plaintext (`DTLSPlaintext`) record.
///
/// Only epoch-0 records are considered. Handshake data is accepted in
/// plaintext only until a ClientHello has been accepted, which is when the
/// handshake reassembler is created. After that point every handshake message
/// must arrive encrypted. Later plaintext handshake records are dropped, so an
/// unauthenticated sender cannot inject fragments into the protected handshake
/// stream. Legitimate ClientHello retransmissions are already obsolete at
/// that point. Alerts and ChangeCipherSpec are ignored.
///
/// # Errors
/// `UnsupportedVersion` for an unexpected record version, `UnexpectedMessage`
/// for other content types, and any error from ClientHello processing.
fn process_plaintext_record(&mut self, rec: ParsedDtlsRecord<'_>) -> Result<(), Error> {
    if rec.version != ProtocolVersion::DTLSv1_2 && rec.version != ProtocolVersion::DTLSv1_0 {
        return Err(Error::UnsupportedVersion);
    }
    if rec.epoch != 0 {
        return Ok(());
    }
    match rec.content_type {
        // Handshake already moved to protected records: ignore plaintext data.
        ContentType::Handshake if self.reassembler.is_some() => Ok(()),
        ContentType::Handshake => self.process_handshake_record(rec.fragment),
        ContentType::Alert | ContentType::ChangeCipherSpec => Ok(()),
        _ => Err(Error::UnexpectedMessage),
    }
}
fn process_protected_record(&mut self, buf: &[u8]) -> Result<usize, Error> {
let (hdr_len, body_len) = peek_header_layout(buf)?;
let total = hdr_len + body_len;
if total > buf.len() {
return Ok(0);
}
let body = &buf[hdr_len..total];
if body.len() < 16 {
return Err(Error::Decode);
}
let suite = self.suite.ok_or(Error::UnexpectedMessage)?;
let sn_key = self.read_sn_key.as_ref().ok_or(Error::UnexpectedMessage)?;
let mask_full = sn_mask_for(suite, sn_key, body)?;
let mask: &[u8] = if (buf[0] & 0b0000_1000) != 0 {
&mask_full[..2]
} else {
&mask_full[..1]
};
let (hdr, ct_body) = record13::decode_record(buf, mask)?;
let consumed = hdr.header_len + ct_body.len();
let read_epoch = self.current_read_epoch();
if (read_epoch as u8 & 0b11) != hdr.epoch_low2 {
return Ok(consumed);
}
let seq = reconstruct_seq(
hdr.seq_low,
hdr.seq_is_16bit,
self.enc_read_seq.wrapping_add(1),
);
let mut aad = buf[..hdr.header_len].to_vec();
if hdr.seq_is_16bit {
aad[1] ^= mask[0];
aad[2] ^= mask[1];
} else {
aad[1] ^= mask[0];
}
let crypter = self.read_crypter.as_mut().ok_or(Error::UnexpectedMessage)?;
let (inner_type, plain) = decrypt_dtls13_record(crypter, seq, &aad, ct_body)?;
if !self.read_replay.accept(seq) {
return Ok(consumed);
}
if seq > self.enc_read_seq {
self.enc_read_seq = seq;
}
let is_handshake = matches!(
inner_type,
ContentType::Handshake | ContentType::Alert | ContentType::Unknown(ACK_CONTENT_TYPE)
);
if is_handshake {
self.pending_acks.push(RecordNumber {
epoch: read_epoch as u64,
seq,
});
}
match inner_type {
ContentType::Handshake => self.process_handshake_record(&plain)?,
ContentType::ApplicationData => {
if self.state != State::Connected {
return Err(Error::UnexpectedMessage);
}
self.app_in.extend_from_slice(&plain);
}
ContentType::Alert => {}
ContentType::Unknown(t) if t == ACK_CONTENT_TYPE => {
let acks = decode_ack(&plain)?;
self.retransmit.on_ack(&acks);
}
_ => return Err(Error::UnexpectedMessage),
}
Ok(consumed)
}
/// Epoch expected on incoming protected records: 3 (application keys) once
/// connected, otherwise 2 (handshake keys).
fn current_read_epoch(&self) -> u16 {
    match self.state {
        State::Connected => 3,
        State::WaitFirstClientHello
        | State::WaitSecondClientHello
        | State::WaitClientFinished
        | State::Closed => 2,
    }
}
/// Splits a handshake record payload into DTLS handshake fragments and feeds
/// them to the appropriate reassembler.
///
/// Before a ClientHello has been accepted (`reassembler` is `None`), only
/// ClientHello fragments are allowed. They are collected in
/// `pre_state_reasm`, and a completed ClientHello is passed to
/// `handle_pre_state_client_hello`. Afterwards, fragments go to the main
/// reassembler, and every completed message is handed to `dispatch_one` in
/// order.
///
/// # Errors
/// `UnexpectedMessage` for a non-ClientHello fragment before the handshake
/// starts, plus fragment-parsing and dispatch errors.
fn process_handshake_record(&mut self, plain: &[u8]) -> Result<(), Error> {
let mut off = 0;
while off < plain.len() {
let frag = read_fragment(&plain[off..])?;
let consumed = frag.len;
if self.reassembler.is_none() {
// Pre-handshake: only (fragments of) a ClientHello are acceptable.
if frag.msg_type != hs_type::CLIENT_HELLO {
return Err(Error::UnexpectedMessage);
}
let msg_seq = frag.message_seq;
let f = HandshakeFragment {
msg_type: frag.msg_type,
total_length: frag.total_length,
message_seq: frag.message_seq,
fragment_offset: frag.fragment_offset,
fragment: frag.fragment,
len: frag.len,
};
off += consumed;
// On first use, seed the reassembler with empty placeholder messages
// for every message_seq below this fragment's. Presumably this lets a
// ClientHello with message_seq > 0 (the second one after an HRR)
// complete — TODO confirm against Reassembler's ordering rules.
let reasm = self.pre_state_reasm.get_or_insert_with(|| {
let mut r = Reassembler::new();
for s in 0..msg_seq {
let mut buf = Vec::new();
write_message(&mut buf, hs_type::CLIENT_HELLO, s, b"", 0);
if let Ok(empty) = read_fragment(&buf) {
let _ = r.feed(empty);
}
}
r
});
if let Some((_mt, body)) = reasm.feed(f) {
// Complete ClientHello: drop the temporary reassembler before handling it.
self.pre_state_reasm = None;
self.handle_pre_state_client_hello(msg_seq, &body)?;
}
continue;
}
let frag = HandshakeFragment {
msg_type: frag.msg_type,
total_length: frag.total_length,
message_seq: frag.message_seq,
fragment_offset: frag.fragment_offset,
fragment: frag.fragment,
len: frag.len,
};
off += consumed;
// `reassembler` was checked to be `Some` at the top of this iteration.
let feeding = self
.reassembler
.as_mut()
.expect("reassembler built")
.feed(frag);
if let Some((mt, body)) = feeding {
self.dispatch_one(mt, &body)?;
}
// Deliver any further messages that became ready in order.
loop {
let popped = self
.reassembler
.as_mut()
.expect("reassembler built")
.pop_ready();
match popped {
Some((mt, body)) => self.dispatch_one(mt, &body)?,
None => break,
}
}
}
Ok(())
}
fn dispatch_one(&mut self, msg_type: u8, body: &[u8]) -> Result<(), Error> {
let mut raw = Vec::with_capacity(4 + body.len());
raw.push(msg_type);
let n = body.len() as u32;
raw.push(((n >> 16) & 0xff) as u8);
raw.push(((n >> 8) & 0xff) as u8);
raw.push((n & 0xff) as u8);
raw.extend_from_slice(body);
match self.state {
State::WaitClientFinished => self.on_client_finished(msg_type, body, &raw),
_ => Err(Error::UnexpectedMessage),
}
}
fn handle_pre_state_client_hello(&mut self, msg_seq: u16, body: &[u8]) -> Result<(), Error> {
let ch = ClientHello::decode(body)?;
let cookie_required = self.config.require_cookie && self.config.cookie_secret.is_some();
let presented_cookie = ch
.extensions
.iter()
.find(|(t, _)| t.0 == EXT_COOKIE)
.map(|(_, b)| b.clone());
let suite = supported_suites()
.iter()
.copied()
.find(|s| ch.cipher_suites.contains(&s.suite))
.ok_or(Error::HandshakeFailure)?;
let groups_ext = ext::find(&ch.extensions, ExtensionType::SUPPORTED_GROUPS)
.ok_or(Error::HandshakeFailure)?;
let offered_groups = parse_supported_groups(groups_ext)?;
let ks_ext =
ext::find(&ch.extensions, ExtensionType::KEY_SHARE).ok_or(Error::HandshakeFailure)?;
let client_shares = ext::parse_client_key_shares(ks_ext)?;
let preferred_group = supported_server_groups()
.iter()
.copied()
.find(|g| offered_groups.contains(g));
let preferred_share =
preferred_group.and_then(|g| client_shares.iter().find(|(sg, _)| *sg == g).cloned());
if cookie_required && presented_cookie.is_none() {
let group_needed = if preferred_share.is_none() {
preferred_group.ok_or(Error::HandshakeFailure)?.into()
} else {
None
};
let secret = self
.config
.cookie_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let cg = CookieGenerator::new(*secret);
let now_min = (self.last_now.as_secs() / 60) as u32;
let cookie = cg.generate(&self.peer_addr, &ch.random, now_min);
self.suite = Some(suite);
self.hrr_selected_group = group_needed;
self.emit_hello_retry_request(Some(&cookie))?;
self.state = State::WaitSecondClientHello;
let mut t = Transcript::new();
t.set_alg(suite.hash);
let mut tls_ch = Vec::with_capacity(4 + body.len());
tls_ch.push(hs_type::CLIENT_HELLO);
let n = body.len() as u32;
tls_ch.push(((n >> 16) & 0xff) as u8);
tls_ch.push(((n >> 8) & 0xff) as u8);
tls_ch.push((n & 0xff) as u8);
tls_ch.extend_from_slice(body);
t.update(&tls_ch);
self.transcript = t;
self.out_msg_seq = 1;
let _ = msg_seq;
return Ok(());
}
if cookie_required {
let cookie_bytes = presented_cookie
.as_ref()
.ok_or(Error::IllegalParameter)?
.clone();
let secret = self
.config
.cookie_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let cg = CookieGenerator::new(*secret);
if cookie_bytes.len() < 2 {
return Err(Error::Decode);
}
let clen = u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]) as usize;
if cookie_bytes.len() != 2 + clen {
return Err(Error::Decode);
}
let cookie = &cookie_bytes[2..];
let now_min = (self.last_now.as_secs() / 60) as u32;
if !cg.validate(&self.peer_addr, &ch.random, now_min, cookie) {
return Err(Error::IllegalParameter);
}
let pinned = self.suite.ok_or(Error::InappropriateState)?;
if !ch.cipher_suites.contains(&pinned.suite) {
return Err(Error::IllegalParameter);
}
self.transcript.replace_with_message_hash();
let hrr_bytes = self.build_hrr_bytes(Some(cookie), self.hrr_selected_group);
self.transcript.update(&hrr_bytes);
} else {
if self.hrr_selected_group.is_none() {
if preferred_share.is_none() {
let group_needed = preferred_group.ok_or(Error::HandshakeFailure)?;
self.suite = Some(suite);
self.hrr_selected_group = Some(group_needed);
let mut t = Transcript::new();
t.set_alg(suite.hash);
let mut tls_ch = Vec::with_capacity(4 + body.len());
tls_ch.push(hs_type::CLIENT_HELLO);
let n = body.len() as u32;
tls_ch.push(((n >> 16) & 0xff) as u8);
tls_ch.push(((n >> 8) & 0xff) as u8);
tls_ch.push((n & 0xff) as u8);
tls_ch.extend_from_slice(body);
t.update(&tls_ch);
self.transcript = t;
self.emit_hello_retry_request(None)?;
self.state = State::WaitSecondClientHello;
self.out_msg_seq = 1;
let _ = msg_seq;
return Ok(());
}
self.suite = Some(suite);
} else {
self.transcript.replace_with_message_hash();
let hrr_bytes = self.build_hrr_bytes(None, self.hrr_selected_group);
self.transcript.update(&hrr_bytes);
}
}
let (selected_group, client_pub) = if let Some(g) = self.hrr_selected_group {
let share = client_shares
.iter()
.find(|(sg, _)| *sg == g)
.ok_or(Error::IllegalParameter)?;
(g, share.1.clone())
} else {
let (g, k) = preferred_share.ok_or(Error::HandshakeFailure)?;
(g, k)
};
let suite = self.suite.ok_or(Error::InappropriateState)?;
self.client_random = Some(ch.random);
let mut tls_ch = Vec::with_capacity(4 + body.len());
tls_ch.push(hs_type::CLIENT_HELLO);
let n = body.len() as u32;
tls_ch.push(((n >> 16) & 0xff) as u8);
tls_ch.push(((n >> 8) & 0xff) as u8);
tls_ch.push((n & 0xff) as u8);
tls_ch.extend_from_slice(body);
if !cookie_required && self.hrr_selected_group.is_none() {
self.transcript = Transcript::new();
self.transcript.set_alg(suite.hash);
}
self.transcript.update(&tls_ch);
let mut reasm = Reassembler::new();
for s in 0..=msg_seq {
let mut buf = Vec::new();
write_message(&mut buf, hs_type::CLIENT_HELLO, s, b"", 0);
let f = read_fragment(&buf)?;
let _ = reasm.feed(f);
}
self.reassembler = Some(reasm);
let mut sr: Random = [0u8; 32];
self.rng.fill_bytes(&mut sr);
self.server_random = Some(sr);
let (server_pub, shared) = self.key_agreement(selected_group, &client_pub)?;
let sh_extensions = alloc::vec![
ext::server_key_share(selected_group, &server_pub),
ext::server_supported_versions(),
];
let sh_bytes = ServerHello {
random: sr,
session_id: ch.session_id.clone(),
cipher_suite: suite.suite,
extensions: sh_extensions,
}
.encode();
self.transcript.update(&sh_bytes);
let sh_body = &sh_bytes[4..];
let sh_msg_seq = self.out_msg_seq;
self.out_msg_seq += 1;
let mut frag_buf = Vec::new();
write_message(
&mut frag_buf,
hs_type::SERVER_HELLO,
sh_msg_seq,
sh_body,
DEFAULT_MAX_FRAGMENT,
);
let sh_dgram = self.wrap_plain_record(ContentType::Handshake, &frag_buf);
self.emit_plaintext(sh_dgram);
let mut ks = KeySchedule::new(suite.hash);
ks.enter_handshake(&shared);
let th = self.transcript.current_hash();
let chts = ks.client_handshake_traffic_secret(th.as_slice());
let shts = ks.server_handshake_traffic_secret(th.as_slice());
if let Some(kl) = self.config.key_log.as_ref() {
kl.log(
"CLIENT_HANDSHAKE_TRAFFIC_SECRET",
&ch.random,
chts.as_slice(),
);
kl.log(
"SERVER_HANDSHAKE_TRAFFIC_SECRET",
&ch.random,
shts.as_slice(),
);
}
let w_crypter = RecordCrypter::new(suite.hash, suite.aead, suite.key_len, &shts);
let r_crypter = RecordCrypter::new(suite.hash, suite.aead, suite.key_len, &chts);
self.write_crypter = Some(w_crypter);
self.read_crypter = Some(r_crypter);
let sn_len = sn_key_len_for(suite.aead);
self.write_sn_key = Some(derive_sn_key(suite.hash, &shts, sn_len));
self.read_sn_key = Some(derive_sn_key(suite.hash, &chts, sn_len));
self.enc_write_epoch = 2;
self.enc_write_seq = 0;
self.enc_read_seq = 0;
self.read_replay = crate::dtls::replay::AntiReplayWindow::new();
self.ks = Some(ks);
self.client_hs_secret = Some(chts);
self.server_hs_secret = Some(shts);
self.send_encrypted_extensions()?;
self.send_certificate()?;
self.send_certificate_verify()?;
self.send_finished()?;
let (cats, sats, ems) = {
let ks = self.ks.as_mut().expect("ks");
ks.enter_master();
let th_app = self.transcript.current_hash();
let cats = ks.client_application_traffic_secret(th_app.as_slice());
let sats = ks.server_application_traffic_secret(th_app.as_slice());
let ems = ks.exporter_master_secret(th_app.as_slice());
(cats, sats, ems)
};
if let Some(kl) = self.config.key_log.as_ref() {
kl.log("CLIENT_TRAFFIC_SECRET_0", &ch.random, cats.as_slice());
kl.log("SERVER_TRAFFIC_SECRET_0", &ch.random, sats.as_slice());
kl.log("EXPORTER_SECRET", &ch.random, ems.as_slice());
}
self.exporter_secret = Some(ems);
self.pending_write_app_crypter = Some(RecordCrypter::new(
suite.hash,
suite.aead,
suite.key_len,
&sats,
));
self.pending_read_app_crypter = Some(RecordCrypter::new(
suite.hash,
suite.aead,
suite.key_len,
&cats,
));
self.write_app_sn_key = Some(derive_sn_key(suite.hash, &sats, sn_len));
self.read_app_sn_key = Some(derive_sn_key(suite.hash, &cats, sn_len));
self.client_app_secret = Some(cats);
self.server_app_secret = Some(sats);
self.state = State::WaitClientFinished;
Ok(())
}
/// Sends an EncryptedExtensions message with an empty extension list and
/// adds it to the transcript.
fn send_encrypted_extensions(&mut self) -> Result<(), Error> {
    let mut ee_body = Vec::new();
    with_len_u16(&mut ee_body, |_| {});
    let len = (ee_body.len() as u32).to_be_bytes();
    let mut framed = Vec::with_capacity(4 + ee_body.len());
    framed.push(hs_type::ENCRYPTED_EXTENSIONS);
    framed.extend_from_slice(&len[1..]);
    framed.extend_from_slice(&ee_body);
    self.transcript.update(&framed);
    self.emit_encrypted_handshake(hs_type::ENCRYPTED_EXTENSIONS, &ee_body)
}
/// Sends the configured certificate chain as a TLS 1.3 Certificate message
/// and adds it to the transcript.
///
/// The message has an empty certificate_request_context. Each certificate
/// entry carries an empty extension list.
fn send_certificate(&mut self) -> Result<(), Error> {
    let chain = &self.config.cert_chain;
    let mut cert_body = Vec::new();
    // certificate_request_context: zero-length.
    cert_body.push(0);
    with_len_u24(&mut cert_body, |entries| {
        for der in chain {
            with_len_u24(entries, |e| e.extend_from_slice(der));
            // Per-certificate extensions: none.
            with_len_u16(entries, |_| {});
        }
    });
    let len = (cert_body.len() as u32).to_be_bytes();
    let mut framed = Vec::with_capacity(4 + cert_body.len());
    framed.push(hs_type::CERTIFICATE);
    framed.extend_from_slice(&len[1..]);
    framed.extend_from_slice(&cert_body);
    self.transcript.update(&framed);
    self.emit_encrypted_handshake(hs_type::CERTIFICATE, &cert_body)
}
/// Signs the current transcript hash with the configured key and sends the
/// CertificateVerify message (scheme + u16-length signature).
fn send_certificate_verify(&mut self) -> Result<(), Error> {
    let transcript_hash = self.transcript.current_hash();
    let signed_content = certificate_verify_content(true, transcript_hash.as_slice());
    let (scheme, signature) =
        sign_certificate_verify(&self.config.key, &signed_content, &mut self.rng)?;
    let mut cv_body = Vec::new();
    cv_body.extend_from_slice(&scheme.0.to_be_bytes());
    with_len_u16(&mut cv_body, |b| b.extend_from_slice(&signature));
    let len = (cv_body.len() as u32).to_be_bytes();
    let mut framed = Vec::with_capacity(4 + cv_body.len());
    framed.push(hs_type::CERTIFICATE_VERIFY);
    framed.extend_from_slice(&len[1..]);
    framed.extend_from_slice(&cv_body);
    self.transcript.update(&framed);
    self.emit_encrypted_handshake(hs_type::CERTIFICATE_VERIFY, &cv_body)
}
fn send_finished(&mut self) -> Result<(), Error> {
let suite = self.suite.ok_or(Error::InappropriateState)?;
let shts = self
.server_hs_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let th = self.transcript.current_hash();
let verify_data = finished_verify_data(suite.hash, shts, th.as_slice());
let body = verify_data.as_slice().to_vec();
let mut tls_msg = Vec::with_capacity(4 + body.len());
tls_msg.push(hs_type::FINISHED);
let n = body.len() as u32;
tls_msg.push(((n >> 16) & 0xff) as u8);
tls_msg.push(((n >> 8) & 0xff) as u8);
tls_msg.push((n & 0xff) as u8);
tls_msg.extend_from_slice(&body);
self.transcript.update(&tls_msg);
self.emit_encrypted_handshake(hs_type::FINISHED, &body)?;
Ok(())
}
fn on_client_finished(&mut self, msg_type: u8, body: &[u8], raw: &[u8]) -> Result<(), Error> {
if msg_type != hs_type::FINISHED {
return Err(Error::UnexpectedMessage);
}
let suite = self.suite.ok_or(Error::InappropriateState)?;
let chts = self
.client_hs_secret
.as_ref()
.ok_or(Error::InappropriateState)?;
let th = self.transcript.current_hash();
let expected = finished_verify_data(suite.hash, chts, th.as_slice());
if !bool::from(expected.as_slice().ct_eq(body)) {
return Err(Error::HandshakeFailure);
}
self.transcript.update(raw);
self.write_crypter = self.pending_write_app_crypter.take();
self.read_crypter = self.pending_read_app_crypter.take();
self.write_sn_key = self.write_app_sn_key.take();
self.read_sn_key = self.read_app_sn_key.take();
self.enc_write_epoch = 3;
self.enc_write_seq = 0;
self.enc_read_seq = 0;
self.read_replay = crate::dtls::replay::AntiReplayWindow::new();
self.state = State::Connected;
Ok(())
}
/// Encodes a HelloRetryRequest as a TLS-framed ServerHello message.
///
/// The message carries `HRR_RANDOM` and an empty session id. Its extensions
/// are supported_versions, plus a key_share naming `group` when given, plus a
/// cookie when given. It uses the negotiated suite, falling back to the first
/// supported suite.
fn build_hrr_bytes(&self, cookie: Option<&[u8]>, group: Option<NamedGroup>) -> Vec<u8> {
    let key_share_ext = group.map(|g| {
        let mut selected = Vec::with_capacity(2);
        put_u16(&mut selected, g.0);
        (ExtensionType::KEY_SHARE, selected)
    });
    let cookie_ext = cookie.map(|c| {
        let mut encoded = Vec::with_capacity(2 + c.len());
        encoded.extend_from_slice(&(c.len() as u16).to_be_bytes());
        encoded.extend_from_slice(c);
        (ExtensionType(EXT_COOKIE), encoded)
    });
    let mut extensions = alloc::vec![ext::server_supported_versions()];
    extensions.extend(key_share_ext);
    extensions.extend(cookie_ext);
    let cipher_suite = match self.suite {
        Some(params) => params.suite,
        None => supported_suites()[0].suite,
    };
    ServerHello {
        random: HRR_RANDOM,
        session_id: Vec::new(),
        cipher_suite,
        extensions,
    }
    .encode()
}
/// Queues a HelloRetryRequest (with the stored HRR group and the optional
/// cookie) as a plaintext record.
///
/// It is always sent with message_seq 0 and is not registered for
/// retransmission.
fn emit_hello_retry_request(&mut self, cookie: Option<&[u8]>) -> Result<(), Error> {
    let hrr = self.build_hrr_bytes(cookie, self.hrr_selected_group);
    let mut fragments = Vec::new();
    write_message(
        &mut fragments,
        hs_type::SERVER_HELLO,
        0,
        &hrr[4..],
        DEFAULT_MAX_FRAGMENT,
    );
    let record = self.wrap_plain_record(ContentType::Handshake, &fragments);
    self.out_dgrams.push(record);
    Ok(())
}
/// Wraps `fragment` in a DTLS 1.2-versioned plaintext record, using the next
/// plaintext sequence number.
fn wrap_plain_record(&mut self, ct: ContentType, fragment: &[u8]) -> Vec<u8> {
    let seq = self.plain_write_seq;
    self.plain_write_seq = seq + 1;
    let mut record_bytes = Vec::new();
    record::write_record(
        &mut record_bytes,
        ct,
        ProtocolVersion::DTLSv1_2,
        self.plain_write_epoch,
        seq,
        fragment,
    );
    record_bytes
}
/// Builds one DTLS 1.3 unified-header record protecting `payload` as content
/// type `ct` with the current write keys and epoch.
///
/// The record always uses a 16-bit sequence number and an explicit length.
/// The write sequence number is consumed before encryption, so it advances
/// even if encryption fails.
///
/// # Errors
/// `InappropriateState` if no suite, write crypter or sn key is installed,
/// plus encryption and mask errors.
fn encrypt_protected_record(
&mut self,
ct: ContentType,
payload: &[u8],
) -> Result<Vec<u8>, Error> {
let suite = self.suite.ok_or(Error::InappropriateState)?;
let crypter = self
.write_crypter
.as_mut()
.ok_or(Error::InappropriateState)?;
let sn_key = self
.write_sn_key
.as_ref()
.ok_or(Error::InappropriateState)?;
let epoch = self.enc_write_epoch;
let seq = self.enc_write_seq;
self.enc_write_seq += 1;
let seq_is_16bit = true;
let omit_length = false;
// Inner plaintext: payload followed by the real content type byte.
let mut inner = Vec::with_capacity(payload.len() + 1);
inner.extend_from_slice(payload);
inner.push(ct.as_u8());
// First pass: encode the header over a zero-filled placeholder of the final
// ciphertext length (assumes a 16-byte AEAD tag) with an all-zero mask, then
// keep only the header bytes. These unmasked header bytes are the AAD.
let mut aad = Vec::new();
let aad_zero_mask = [0u8; 2];
let ct_len = inner.len() + 16;
record13::encode_record(
&mut aad,
epoch,
seq,
seq_is_16bit,
omit_length,
&alloc::vec![0u8; ct_len],
&aad_zero_mask,
);
let hdr_len = aad.len() - ct_len;
aad.truncate(hdr_len);
// `inner` becomes the ciphertext (in place).
encrypt_dtls13_record(crypter, seq, &aad, &mut inner)?;
// The sequence-number mask is derived from the ciphertext.
let mask_full = sn_mask_for(suite, sn_key, &inner)?;
let mask: &[u8] = if seq_is_16bit {
&mask_full[..2]
} else {
&mask_full[..1]
};
// Second pass: the wire record with the masked sequence number.
let mut wire = Vec::new();
record13::encode_record(
&mut wire,
epoch,
seq,
seq_is_16bit,
omit_length,
&inner,
mask,
);
Ok(wire)
}
/// Queues a plaintext handshake datagram and registers it for retransmission.
///
/// The datagram is expected to come from the most recent `wrap_plain_record`
/// call. Its record number is therefore one behind the plaintext write
/// counter.
fn emit_plaintext(&mut self, datagram: Vec<u8>) {
    let record_number = RecordNumber {
        epoch: u64::from(self.plain_write_epoch),
        seq: self.plain_write_seq.saturating_sub(1),
    };
    self.out_dgrams.push(datagram.clone());
    let in_flight = InFlightRecord {
        record_number,
        datagram,
    };
    self.retransmit.on_record_sent(in_flight, self.last_now);
}
/// Frames `body` as a DTLS handshake message with the next outgoing
/// message_seq, encrypts it into one protected record, queues it, and
/// registers it for retransmission.
fn emit_encrypted_handshake(&mut self, msg_type: u8, body: &[u8]) -> Result<(), Error> {
    let message_seq = self.out_msg_seq;
    self.out_msg_seq = message_seq + 1;
    let mut fragments = Vec::new();
    write_message(&mut fragments, msg_type, message_seq, body, DEFAULT_MAX_FRAGMENT);
    let datagram = self.encrypt_protected_record(ContentType::Handshake, &fragments)?;
    // encrypt_protected_record has already advanced the write counter.
    let record_number = RecordNumber {
        epoch: u64::from(self.enc_write_epoch),
        seq: self.enc_write_seq.saturating_sub(1),
    };
    self.out_dgrams.push(datagram.clone());
    let in_flight = InFlightRecord {
        record_number,
        datagram,
    };
    self.retransmit.on_record_sent(in_flight, self.last_now);
    Ok(())
}
/// Encrypts all queued record numbers into one ACK record and queues it.
///
/// This does nothing if nothing is pending or no protected write keys exist
/// yet. It is best effort: if encryption fails, the batch is dropped.
fn flush_pending_acks(&mut self) {
    if self.pending_acks.is_empty() || self.write_crypter.is_none() {
        return;
    }
    let batch = core::mem::take(&mut self.pending_acks);
    let ack_body = encode_ack(&batch);
    let record = self
        .encrypt_protected_record(ContentType::Unknown(ACK_CONTENT_TYPE), &ack_body)
        .ok();
    self.out_dgrams.extend(record);
}
fn key_agreement(
&mut self,
group: NamedGroup,
client_pub: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), Error> {
match group {
NamedGroup::X25519 => {
let sk = X25519PrivateKey::generate(&mut self.rng);
let peer: [u8; 32] = client_pub.try_into().map_err(|_| Error::Decode)?;
let ss = sk
.diffie_hellman(&peer)
.map_err(|_| Error::IllegalParameter)?;
let pk = sk.public_key().to_vec();
self.x25519 = Some(sk);
Ok((pk, ss.to_vec()))
}
NamedGroup::SECP256R1 => {
let sk = BoxedEcdhPrivateKey::generate(CurveId::P256, &mut self.rng);
let peer = BoxedEcdsaPublicKey::from_sec1(CurveId::P256, client_pub)
.map_err(|_| Error::Decode)?;
let ss = sk
.diffie_hellman(&peer)
.map_err(|_| Error::PeerMisbehaved)?;
Ok((sk.public_key().to_sec1(), ss))
}
NamedGroup::SECP384R1 => {
let sk = BoxedEcdhPrivateKey::generate(CurveId::P384, &mut self.rng);
let peer = BoxedEcdsaPublicKey::from_sec1(CurveId::P384, client_pub)
.map_err(|_| Error::Decode)?;
let ss = sk
.diffie_hellman(&peer)
.map_err(|_| Error::PeerMisbehaved)?;
Ok((sk.public_key().to_sec1(), ss))
}
NamedGroup::X25519MLKEM768 => {
if client_pub.len() != ENCAPS_KEY_BYTES + 32 {
return Err(Error::Decode);
}
let mut ek = [0u8; ENCAPS_KEY_BYTES];
ek.copy_from_slice(&client_pub[..ENCAPS_KEY_BYTES]);
let peer: [u8; 32] = client_pub[ENCAPS_KEY_BYTES..]
.try_into()
.map_err(|_| Error::Decode)?;
let validated_ek = MlKem768EncapsKey::from_bytes_validated(ek)
.map_err(|_| Error::IllegalParameter)?;
let (ct, ml_ss) = validated_ek.encapsulate(&mut self.rng);
let sk = X25519PrivateKey::generate(&mut self.rng);
let x_ss = sk
.diffie_hellman(&peer)
.map_err(|_| Error::IllegalParameter)?;
let mut share = ct.to_bytes().to_vec();
share.extend_from_slice(&sk.public_key());
let mut combined = Vec::with_capacity(64);
combined.extend_from_slice(&ml_ss);
combined.extend_from_slice(&x_ss);
Ok((share, combined))
}
_ => Err(Error::HandshakeFailure),
}
}
}
/// Key-exchange groups the server accepts, most preferred first; the first
/// entry also offered by the client is selected.
fn supported_server_groups() -> [NamedGroup; 4] {
    const PREFERENCE: [NamedGroup; 4] = [
        NamedGroup::X25519MLKEM768,
        NamedGroup::X25519,
        NamedGroup::SECP256R1,
        NamedGroup::SECP384R1,
    ];
    PREFERENCE
}
fn parse_supported_groups(body: &[u8]) -> Result<Vec<NamedGroup>, Error> {
let mut outer = ReadCursor::new(body);
let list = outer.vec_u16()?;
outer.expect_empty()?;
if list.len() % 2 != 0 {
return Err(Error::Decode);
}
let mut c = ReadCursor::new(list);
let mut out = Vec::with_capacity(list.len() / 2);
while !c.is_empty() {
out.push(NamedGroup(c.u16()?));
}
Ok(out)
}