use crate::client::{ClientConfig, ClientConnectionData};
use crate::common_state::{CommonState, Protocol, Side};
use crate::conn::{ConnectionCore, SideData};
use crate::crypto::cipher::{AeadKey, Iv};
use crate::crypto::tls13::{Hkdf, HkdfExpander, OkmBlock};
use crate::enums::{AlertDescription, ProtocolVersion};
use crate::error::Error;
use crate::msgs::handshake::{ClientExtension, ServerExtension};
use crate::server::{ServerConfig, ServerConnectionData};
use crate::tls13::key_schedule::{
hkdf_expand_label, hkdf_expand_label_aead_key, hkdf_expand_label_block,
};
use crate::tls13::Tls13CipherSuite;
use pki_types::ServerName;
use alloc::boxed::Box;
use alloc::collections::VecDeque;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use core::ops::{Deref, DerefMut};
/// A QUIC-enabled TLS connection of either side.
///
/// Wraps either a [`ClientConnection`] or a [`ServerConnection`], so callers can drive
/// the QUIC handshake without matching on the side themselves.
#[derive(Debug)]
pub enum Connection {
    /// The client side of a QUIC connection.
    Client(ClientConnection),
    /// The server side of a QUIC connection.
    Server(ServerConnection),
}
impl Connection {
pub fn quic_transport_parameters(&self) -> Option<&[u8]> {
match self {
Self::Client(conn) => conn.quic_transport_parameters(),
Self::Server(conn) => conn.quic_transport_parameters(),
}
}
pub fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
match self {
Self::Client(conn) => conn.zero_rtt_keys(),
Self::Server(conn) => conn.zero_rtt_keys(),
}
}
pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
match self {
Self::Client(conn) => conn.read_hs(plaintext),
Self::Server(conn) => conn.read_hs(plaintext),
}
}
pub fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
match self {
Self::Client(conn) => conn.write_hs(buf),
Self::Server(conn) => conn.write_hs(buf),
}
}
pub fn alert(&self) -> Option<AlertDescription> {
match self {
Self::Client(conn) => conn.alert(),
Self::Server(conn) => conn.alert(),
}
}
#[inline]
pub fn export_keying_material<T: AsMut<[u8]>>(
&self,
output: T,
label: &[u8],
context: Option<&[u8]>,
) -> Result<T, Error> {
match self {
Self::Client(conn) => conn
.core
.export_keying_material(output, label, context),
Self::Server(conn) => conn
.core
.export_keying_material(output, label, context),
}
}
}
impl Deref for Connection {
type Target = CommonState;
fn deref(&self) -> &Self::Target {
match self {
Self::Client(conn) => &conn.core.common_state,
Self::Server(conn) => &conn.core.common_state,
}
}
}
impl DerefMut for Connection {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
Self::Client(conn) => &mut conn.core.common_state,
Self::Server(conn) => &mut conn.core.common_state,
}
}
}
/// The client side of a QUIC connection.
///
/// Most of its API is reached through `Deref` to the wrapped
/// [`ConnectionCommon<ClientConnectionData>`](ConnectionCommon).
pub struct ClientConnection {
    inner: ConnectionCommon<ClientConnectionData>,
}
impl ClientConnection {
pub fn new(
config: Arc<ClientConfig>,
quic_version: Version,
name: ServerName<'static>,
params: Vec<u8>,
) -> Result<Self, Error> {
if !config.supports_version(ProtocolVersion::TLSv1_3) {
return Err(Error::General(
"TLS 1.3 support is required for QUIC".into(),
));
}
if !config.supports_protocol(Protocol::Quic) {
return Err(Error::General(
"at least one ciphersuite must support QUIC".into(),
));
}
let ext = match quic_version {
Version::V1Draft => ClientExtension::TransportParametersDraft(params),
Version::V1 | Version::V2 => ClientExtension::TransportParameters(params),
};
let mut inner = ConnectionCore::for_client(config, name, vec![ext], Protocol::Quic)?;
inner.common_state.quic.version = quic_version;
Ok(Self {
inner: inner.into(),
})
}
pub fn is_early_data_accepted(&self) -> bool {
self.inner.core.is_early_data_accepted()
}
}
impl Deref for ClientConnection {
type Target = ConnectionCommon<ClientConnectionData>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for ClientConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl Debug for ClientConnection {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("quic::ClientConnection")
.finish()
}
}
impl From<ClientConnection> for Connection {
fn from(c: ClientConnection) -> Self {
Self::Client(c)
}
}
/// The server side of a QUIC connection.
///
/// Most of its API is reached through `Deref` to the wrapped
/// [`ConnectionCommon<ServerConnectionData>`](ConnectionCommon).
pub struct ServerConnection {
    inner: ConnectionCommon<ServerConnectionData>,
}
impl ServerConnection {
pub fn new(
config: Arc<ServerConfig>,
quic_version: Version,
params: Vec<u8>,
) -> Result<Self, Error> {
if !config.supports_version(ProtocolVersion::TLSv1_3) {
return Err(Error::General(
"TLS 1.3 support is required for QUIC".into(),
));
}
if !config.supports_protocol(Protocol::Quic) {
return Err(Error::General(
"at least one ciphersuite must support QUIC".into(),
));
}
if config.max_early_data_size != 0 && config.max_early_data_size != 0xffff_ffff {
return Err(Error::General(
"QUIC sessions must set a max early data of 0 or 2^32-1".into(),
));
}
let ext = match quic_version {
Version::V1Draft => ServerExtension::TransportParametersDraft(params),
Version::V1 | Version::V2 => ServerExtension::TransportParameters(params),
};
let mut core = ConnectionCore::for_server(config, vec![ext])?;
core.common_state.protocol = Protocol::Quic;
core.common_state.quic.version = quic_version;
Ok(Self { inner: core.into() })
}
pub fn reject_early_data(&mut self) {
self.inner.core.reject_early_data()
}
pub fn server_name(&self) -> Option<&str> {
self.inner.core.get_sni_str()
}
}
impl Deref for ServerConnection {
type Target = ConnectionCommon<ServerConnectionData>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for ServerConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl Debug for ServerConnection {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("quic::ServerConnection")
.finish()
}
}
impl From<ServerConnection> for Connection {
fn from(c: ServerConnection) -> Self {
Self::Server(c)
}
}
/// QUIC-specific operations shared by client and server connections.
///
/// `Data` is the side-specific connection data (`ClientConnectionData` or
/// `ServerConnectionData` in this file). The TLS common state is reachable via `Deref`.
pub struct ConnectionCommon<Data> {
    core: ConnectionCore<Data>,
}
impl<Data: SideData> ConnectionCommon<Data> {
    /// Returns the QUIC transport parameters stored in the connection state, if any.
    ///
    /// NOTE(review): presumably the peer's parameters; they are set outside this file.
    pub fn quic_transport_parameters(&self) -> Option<&[u8]> {
        self.core.common_state.quic.params.as_deref()
    }

    /// Derives 0-RTT (early data) keys from the stored early secret.
    ///
    /// Returns `None` in any of these cases:
    /// - no suite is known,
    /// - the suite is not TLS 1.3,
    /// - the suite has no QUIC algorithm,
    /// - no early secret has been stored.
    pub fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
        let common = &self.core.common_state;
        let suite = common.suite.and_then(|s| s.tls13())?;
        let algorithm = suite.quic?;
        let early_secret = common.quic.early_secret.as_ref()?;
        Some(DirectionalKeys::new(
            suite,
            algorithm,
            early_secret,
            common.quic.version,
        ))
    }

    /// Pushes `plaintext` handshake bytes into the deframer as TLS 1.3 data, then
    /// processes whatever complete messages are now available.
    ///
    /// # Errors
    ///
    /// Propagates errors from the deframer and from message processing.
    pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
        let core = &mut self.core;
        core.message_deframer
            .push(ProtocolVersion::TLSv1_3, plaintext)?;
        core.process_new_packets()?;
        Ok(())
    }

    /// Appends queued handshake bytes to `buf`, returning newly available keys if any.
    pub fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
        let quic = &mut self.core.common_state.quic;
        quic.write_hs(buf)
    }

    /// Returns the alert recorded in the QUIC state, if any.
    pub fn alert(&self) -> Option<AlertDescription> {
        let quic = &self.core.common_state.quic;
        quic.alert
    }
}
impl<Data> Deref for ConnectionCommon<Data> {
type Target = CommonState;
fn deref(&self) -> &Self::Target {
&self.core.common_state
}
}
impl<Data> DerefMut for ConnectionCommon<Data> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.core.common_state
}
}
impl<Data> From<ConnectionCore<Data>> for ConnectionCommon<Data> {
fn from(core: ConnectionCore<Data>) -> Self {
Self { core }
}
}
/// Per-connection QUIC state kept alongside the TLS state.
#[derive(Default)]
pub(crate) struct Quic {
    /// Encoded QUIC transport parameters, returned by `quic_transport_parameters`.
    /// NOTE(review): presumably the peer's parameters — they are set outside this file.
    pub(crate) params: Option<Vec<u8>>,
    /// Alert recorded for the QUIC layer, returned by `alert()`.
    pub(crate) alert: Option<AlertDescription>,
    /// Handshake bytes waiting to be handed out by `write_hs`.
    /// The `bool` flags an entry before which `write_hs` stops while handshake secrets
    /// are pending, so the caller can switch keys first.
    pub(crate) hs_queue: VecDeque<(bool, Vec<u8>)>,
    /// Secret from which `zero_rtt_keys` derives 0-RTT keys.
    pub(crate) early_secret: Option<OkmBlock>,
    /// Handshake secrets not yet handed out; `write_hs` consumes them.
    pub(crate) hs_secrets: Option<Secrets>,
    /// 1-RTT traffic secrets not yet handed out; `write_hs` consumes them.
    pub(crate) traffic_secrets: Option<Secrets>,
    /// Whether `write_hs` has already returned `KeyChange::OneRtt`.
    pub(crate) returned_traffic_keys: bool,
    /// QUIC version in use; selects salts and labels for key derivation.
    pub(crate) version: Version,
}
impl Quic {
    /// Moves queued handshake bytes into `buf` and reports any newly available keys.
    ///
    /// Queue entries are drained in order. Draining stops early when both of these hold:
    /// handshake secrets are pending, and the next entry is flagged `true`. This lets the
    /// caller install the handshake keys before the remaining bytes are written.
    ///
    /// Returns:
    /// - `KeyChange::Handshake` if handshake secrets were pending. The secrets are consumed.
    /// - otherwise `KeyChange::OneRtt` the first time traffic secrets are present. `next`
    ///   is those secrets advanced by one update.
    /// - otherwise `None`. Traffic secrets present after that first hand-off are dropped.
    pub(crate) fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
        while let Some((_, chunk)) = self.hs_queue.pop_front() {
            buf.extend_from_slice(&chunk);
            let flagged_next = matches!(self.hs_queue.front(), Some((true, _)));
            if flagged_next && self.hs_secrets.is_some() {
                // Allow the caller to switch keys before proceeding.
                break;
            }
        }

        if let Some(handshake) = self.hs_secrets.take() {
            return Some(KeyChange::Handshake {
                keys: Keys::new(&handshake),
            });
        }

        // The secrets are taken even when keys were already returned, so stale traffic
        // secrets do not linger in the state.
        let traffic = self.traffic_secrets.take()?;
        if self.returned_traffic_keys {
            return None;
        }
        self.returned_traffic_keys = true;

        let keys = Keys::new(&traffic);
        let mut next = traffic;
        next.update();
        Some(KeyChange::OneRtt { keys, next })
    }
}
/// Client and server secrets for one key generation.
///
/// Also holds what is needed to derive packet keys from the secrets and to advance them
/// with key updates.
#[derive(Clone)]
pub struct Secrets {
    /// Client-side secret; it is the local secret when `side` is `Side::Client`.
    pub(crate) client: OkmBlock,
    /// Server-side secret; it is the local secret when `side` is `Side::Server`.
    pub(crate) server: OkmBlock,
    /// Cipher suite whose HKDF provider is used for all derivations.
    suite: &'static Tls13CipherSuite,
    /// QUIC algorithm that turns derived material into key objects.
    quic: &'static dyn Algorithm,
    /// Which side this endpoint is; decides which secret is local and which is remote.
    side: Side,
    /// QUIC version; selects the derivation labels.
    version: Version,
}
impl Secrets {
    /// Bundles the client and server secrets with the parameters needed to derive keys
    /// from them.
    pub(crate) fn new(
        client: OkmBlock,
        server: OkmBlock,
        suite: &'static Tls13CipherSuite,
        quic: &'static dyn Algorithm,
        side: Side,
        version: Version,
    ) -> Self {
        Self {
            side,
            version,
            suite,
            quic,
            client,
            server,
        }
    }

    /// Derives packet keys from the current secrets and then advances the secrets by
    /// one key update.
    pub fn next_packet_keys(&mut self) -> PacketKeySet {
        let current = PacketKeySet::new(self);
        self.update();
        current
    }

    /// Replaces both secrets with their successors.
    ///
    /// Each successor is HKDF-Expand-Label'd from the old secret with the version's
    /// key-update label and an empty context.
    pub(crate) fn update(&mut self) {
        let suite = self.suite;
        let label = self.version.key_update_label();
        let advance = |secret: &OkmBlock| {
            let expander = suite.hkdf_provider.expander_for_okm(secret);
            hkdf_expand_label_block(expander.as_ref(), label, &[])
        };

        let next_client = advance(&self.client);
        let next_server = advance(&self.server);
        self.client = next_client;
        self.server = next_server;
    }

    /// Returns `(local, remote)` secrets according to which side this endpoint is.
    fn local_remote(&self) -> (&OkmBlock, &OkmBlock) {
        let (client, server) = (&self.client, &self.server);
        match self.side {
            Side::Client => (client, server),
            Side::Server => (server, client),
        }
    }
}
/// Header-protection and packet-protection keys for one direction of traffic.
pub struct DirectionalKeys {
    /// Key used for header protection.
    pub header: Box<dyn HeaderProtectionKey>,
    /// Key used for packet payload protection.
    pub packet: Box<dyn PacketKey>,
}
impl DirectionalKeys {
    /// Derives both keys for one direction from `secret`.
    ///
    /// The suite's HKDF provider is used together with the labels for `version`, and
    /// `quic` turns the derived material into key objects.
    pub(crate) fn new(
        suite: &'static Tls13CipherSuite,
        quic: &'static dyn Algorithm,
        secret: &OkmBlock,
        version: Version,
    ) -> Self {
        let keys = KeyBuilder::new(secret, version, quic, suite.hkdf_provider);
        let header = keys.header_protection_key();
        let packet = keys.packet_key();
        Self { header, packet }
    }
}
/// Length in bytes of a QUIC packet-protection authentication tag.
const TAG_LEN: usize = 16;

/// Authentication tag returned by `PacketKey::encrypt_in_place`.
///
/// Always exactly `TAG_LEN` (16) bytes long.
#[derive(Clone, Debug)]
pub struct Tag([u8; TAG_LEN]);

impl From<&[u8]> for Tag {
    /// Copies a tag out of `value`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not exactly `TAG_LEN` (16) bytes long.
    fn from(value: &[u8]) -> Self {
        // State the precondition explicitly instead of relying on the generic
        // `copy_from_slice` length-mismatch panic.
        assert_eq!(
            value.len(),
            TAG_LEN,
            "QUIC packet tag must be exactly {} bytes",
            TAG_LEN
        );
        let mut array = [0u8; TAG_LEN];
        array.copy_from_slice(value);
        Self(array)
    }
}

impl AsRef<[u8]> for Tag {
    /// Borrows the raw tag bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
/// A QUIC packet-protection algorithm that turns derived key material into key objects.
pub trait Algorithm: Send + Sync {
    /// Builds a packet-protection key from an AEAD key and IV.
    fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn PacketKey>;
    /// Builds a header-protection key from the given key material.
    fn header_protection_key(&self, key: AeadKey) -> Box<dyn HeaderProtectionKey>;
    /// Length of key material to derive.
    /// `KeyBuilder` uses it for both the packet key and the header-protection key.
    fn aead_key_len(&self) -> usize;
}
/// Applies and removes QUIC header protection.
///
/// NOTE(review): the parameter descriptions below follow the parameter names and RFC 9001
/// section 5.4. The exact contract is defined by the implementations, which are not in
/// this file.
pub trait HeaderProtectionKey {
    /// Protects `first` (the first header byte) and the `packet_number` bytes in place,
    /// using `sample`.
    fn encrypt_in_place(
        &self,
        sample: &[u8],
        first: &mut u8,
        packet_number: &mut [u8],
    ) -> Result<(), Error>;
    /// Reverses `encrypt_in_place` for the same `sample`.
    fn decrypt_in_place(
        &self,
        sample: &[u8],
        first: &mut u8,
        packet_number: &mut [u8],
    ) -> Result<(), Error>;
    /// Number of `sample` bytes this key expects.
    fn sample_len(&self) -> usize;
}
/// Encrypts and decrypts QUIC packet payloads.
pub trait PacketKey {
    /// Encrypts `payload` in place and returns the authentication tag.
    /// `header` is also authenticated (presumably as AEAD associated data — confirm
    /// against the implementations).
    fn encrypt_in_place(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &mut [u8],
    ) -> Result<Tag, Error>;
    /// Decrypts `payload` in place and returns the plaintext.
    /// The returned slice borrows from `payload`.
    fn decrypt_in_place<'a>(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &'a mut [u8],
    ) -> Result<&'a [u8], Error>;
    /// Length in bytes of the tag this key produces (presumably `TAG_LEN`).
    fn tag_len(&self) -> usize;
}
/// Packet-protection keys for both directions, without header protection.
///
/// Produced by `Secrets::next_packet_keys`.
pub struct PacketKeySet {
    /// Packet key derived from this endpoint's own secret.
    pub local: Box<dyn PacketKey>,
    /// Packet key derived from the peer's secret.
    pub remote: Box<dyn PacketKey>,
}
impl PacketKeySet {
fn new(secrets: &Secrets) -> Self {
let (local, remote) = secrets.local_remote();
let (version, alg, hkdf) = (secrets.version, secrets.quic, secrets.suite.hkdf_provider);
Self {
local: KeyBuilder::new(local, version, alg, hkdf).packet_key(),
remote: KeyBuilder::new(remote, version, alg, hkdf).packet_key(),
}
}
}
/// Derives QUIC packet and header-protection keys from a single secret.
pub(crate) struct KeyBuilder<'a> {
    /// HKDF expander seeded with the secret.
    expander: Box<dyn HkdfExpander>,
    /// QUIC version; selects the derivation labels.
    version: Version,
    /// Algorithm that turns derived key material into key objects.
    alg: &'a dyn Algorithm,
}
impl<'a> KeyBuilder<'a> {
pub(crate) fn new(
secret: &OkmBlock,
version: Version,
alg: &'a dyn Algorithm,
hkdf: &'a dyn Hkdf,
) -> Self {
Self {
expander: hkdf.expander_for_okm(secret),
version,
alg,
}
}
pub(crate) fn packet_key(&self) -> Box<dyn PacketKey> {
let aead_key_len = self.alg.aead_key_len();
let packet_key = hkdf_expand_label_aead_key(
self.expander.as_ref(),
aead_key_len,
self.version.packet_key_label(),
&[],
);
let packet_iv =
hkdf_expand_label(self.expander.as_ref(), self.version.packet_iv_label(), &[]);
self.alg
.packet_key(packet_key, packet_iv)
}
pub(crate) fn header_protection_key(&self) -> Box<dyn HeaderProtectionKey> {
let header_key = hkdf_expand_label_aead_key(
self.expander.as_ref(),
self.alg.aead_key_len(),
self.version.header_key_label(),
&[],
);
self.alg
.header_protection_key(header_key)
}
}
/// Header and packet keys for both directions.
pub struct Keys {
    /// Keys derived from this endpoint's own secret.
    pub local: DirectionalKeys,
    /// Keys derived from the peer's secret.
    pub remote: DirectionalKeys,
}
impl Keys {
pub fn initial(
version: Version,
suite: &'static Tls13CipherSuite,
quic: &'static dyn Algorithm,
client_dst_connection_id: &[u8],
side: Side,
) -> Self {
const CLIENT_LABEL: &[u8] = b"client in";
const SERVER_LABEL: &[u8] = b"server in";
let salt = version.initial_salt();
let hs_secret = suite
.hkdf_provider
.extract_from_secret(Some(salt), client_dst_connection_id);
let secrets = Secrets {
version,
client: hkdf_expand_label_block(hs_secret.as_ref(), CLIENT_LABEL, &[]),
server: hkdf_expand_label_block(hs_secret.as_ref(), SERVER_LABEL, &[]),
suite,
quic,
side,
};
Self::new(&secrets)
}
fn new(secrets: &Secrets) -> Self {
let (local, remote) = secrets.local_remote();
Self {
local: DirectionalKeys::new(secrets.suite, secrets.quic, local, secrets.version),
remote: DirectionalKeys::new(secrets.suite, secrets.quic, remote, secrets.version),
}
}
}
/// New keys handed to the caller by `write_hs`.
// The variants differ in size (`OneRtt` also carries `Secrets`) and are left unboxed,
// hence the Clippy allowance.
#[allow(clippy::large_enum_variant)]
pub enum KeyChange {
    /// Keys derived from the handshake secrets.
    Handshake {
        /// Local and remote handshake keys.
        keys: Keys,
    },
    /// Keys derived from the 1-RTT traffic secrets.
    OneRtt {
        /// Local and remote 1-RTT keys.
        keys: Keys,
        /// The traffic secrets already advanced by one update, for deriving later keys
        /// (e.g. via `Secrets::next_packet_keys`).
        next: Secrets,
    },
}
/// QUIC protocol versions, which differ in their initial salt and derivation labels.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum Version {
    /// Pre-RFC draft of QUIC v1; uses the draft transport-parameters extension.
    V1Draft,
    /// QUIC version 1.
    V1,
    /// QUIC version 2, which uses `quicv2`-prefixed labels.
    V2,
}

impl Version {
    /// Salt used to extract the initial secret from the client's destination
    /// connection ID.
    fn initial_salt(self) -> &'static [u8; 20] {
        static V1_DRAFT_SALT: [u8; 20] = [
            0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61,
            0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99,
        ];
        static V1_SALT: [u8; 20] = [
            0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
            0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
        ];
        static V2_SALT: [u8; 20] = [
            0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26,
            0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9,
        ];
        match self {
            Self::V1Draft => &V1_DRAFT_SALT,
            Self::V1 => &V1_SALT,
            Self::V2 => &V2_SALT,
        }
    }

    /// Picks the v1 spelling of a label for v1/draft, and the v2 spelling for v2.
    fn pick_label(self, v1: &'static [u8], v2: &'static [u8]) -> &'static [u8] {
        match self {
            Self::V1Draft | Self::V1 => v1,
            Self::V2 => v2,
        }
    }

    /// HKDF label for the packet-protection key.
    pub(crate) fn packet_key_label(&self) -> &'static [u8] {
        self.pick_label(b"quic key", b"quicv2 key")
    }

    /// HKDF label for the packet-protection IV.
    pub(crate) fn packet_iv_label(&self) -> &'static [u8] {
        self.pick_label(b"quic iv", b"quicv2 iv")
    }

    /// HKDF label for the header-protection key.
    pub(crate) fn header_key_label(&self) -> &'static [u8] {
        self.pick_label(b"quic hp", b"quicv2 hp")
    }

    /// HKDF label used to advance secrets on a key update.
    fn key_update_label(&self) -> &'static [u8] {
        self.pick_label(b"quic ku", b"quicv2 ku")
    }
}

impl Default for Version {
    /// Defaults to QUIC version 1.
    fn default() -> Version {
        Version::V1
    }
}