use alloc::boxed::Box;
use alloc::vec::Vec;
use crate::quic::cid::{CidPair, ConnectionId};
use crate::quic::crypto::{AeadAlg, derive_dir_keys, derive_initial_secrets};
use crate::quic::endpoint::Endpoint;
use crate::quic::tls_glue::{HookHandle, build_hooks};
use crate::rng::OsRng;
use crate::tls::Error;
use crate::tls::conn::{ServerConfig, ServerConnection};
use crate::tls::quic_hooks::Level;
/// Length in bytes of the connection IDs produced by `random_default_scid`.
///
/// The router also relies on it when parsing short-header packets: it takes exactly this
/// many bytes after the first byte as the DCID, so every locally issued CID is assumed to
/// have this length.
pub(crate) const DEFAULT_SCID_LEN: usize = 8;
/// Creates the server-side TLS engine for one QUIC connection.
///
/// `transport_params` is the encoded local transport-parameter blob passed to
/// `build_hooks` (presumably emitted in the TLS transport-parameters extension; confirm in
/// `tls_glue`). Returns the engine, driven by the OS RNG, together with the `HookHandle`
/// that pairs with the hooks installed in it.
///
/// # Errors
///
/// Currently never fails.
pub(crate) fn build_tls_engine(
    tls_cfg: ServerConfig,
    transport_params: Vec<u8>,
) -> Result<(ServerConnection<OsRng>, HookHandle), Error> {
    let (quic_hooks, hook_handle) = build_hooks(transport_params);
    Ok((
        ServerConnection::new_for_quic(tls_cfg, OsRng, quic_hooks as Box<_>),
        hook_handle,
    ))
}
/// Installs the Initial-level packet protection keys on a server endpoint.
///
/// Both secrets are derived from `client_dcid`, the DCID chosen by the client in its
/// first Initial. The server secret keys outgoing (`tx`) packets and the client secret
/// keys incoming (`rx`) packets. Both directions use AES-128-GCM.
pub(crate) fn install_initial_keys(endpoint: &mut Endpoint, client_dcid: &[u8]) {
    let (client_secret, server_secret) = derive_initial_secrets(client_dcid);
    let initial = &mut endpoint.crypto.levels[Level::Initial as usize];
    initial.tx = Some(derive_dir_keys(AeadAlg::Aes128Gcm, &server_secret));
    initial.rx = Some(derive_dir_keys(AeadAlg::Aes128Gcm, &client_secret));
}
/// Creates an endpoint whose CID pair is two empty placeholder IDs.
///
/// The real IDs are presumably filled in later, e.g. via `set_cids_from_first_initial`.
pub(crate) fn build_pending_endpoint() -> Endpoint {
    let placeholder_cids = CidPair::new(ConnectionId::empty(), ConnectionId::empty());
    Endpoint::new(placeholder_cids)
}
/// Replaces the endpoint's CID pair with `(peer_scid, our_local)`.
///
/// `peer_scid` is the SCID the client put in its first Initial, and `our_local` is the CID
/// this side will be addressed by. They are passed to `CidPair::new` in that order.
pub(crate) fn set_cids_from_first_initial(
    endpoint: &mut Endpoint,
    peer_scid: ConnectionId,
    our_local: ConnectionId,
) {
    let cids = CidPair::new(peer_scid, our_local);
    endpoint.cids = cids;
}
/// Generates a fresh random connection ID of `DEFAULT_SCID_LEN` bytes from the OS RNG.
pub(crate) fn random_default_scid() -> ConnectionId {
    ConnectionId::random(&mut OsRng, DEFAULT_SCID_LEN)
}
use core::time::Duration;
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::time::Instant;
use crate::quic::connection::{QuicConfig, QuicConnection};
use crate::quic::ecn::EcnCodepoint;
use crate::quic::pkt::{LongHeader, LongType, QUIC_V1, build_version_negotiation};
use crate::quic::reset::{MIN_STATELESS_RESET_LEN, build_stateless_reset, stateless_reset_token};
use crate::rng::RngCore;
/// Parses the version-independent prefix of a long-header packet.
///
/// Layout: `first byte | version (u32 BE) | dcid_len | dcid | scid_len | scid | ...`.
/// Returns `(version, dcid, scid)` borrowed from `buf`. Returns `None` in these cases:
/// - the buffer is too short for any field;
/// - either CID length exceeds 20 bytes.
///
/// The first byte is not inspected, so the caller must already know the packet has a long
/// header.
fn parse_long_invariant(buf: &[u8]) -> Option<(u32, &[u8], &[u8])> {
    const MAX_CID_LEN: usize = 20;

    // Bytes 0..6 hold the first byte, the 4-byte version and the DCID length.
    let fixed = buf.get(..6)?;
    let version = u32::from_be_bytes([fixed[1], fixed[2], fixed[3], fixed[4]]);
    let dcid_len = usize::from(fixed[5]);
    if dcid_len > MAX_CID_LEN {
        return None;
    }

    // The DCID and at least the SCID length byte must follow.
    let after_dcid = buf.get(6 + dcid_len..)?;
    let (&scid_len_byte, scid_area) = after_dcid.split_first()?;
    let dcid = &buf[6..6 + dcid_len];

    let scid_len = usize::from(scid_len_byte);
    if scid_len > MAX_CID_LEN {
        return None;
    }
    let scid = scid_area.get(..scid_len)?;

    Some((version, dcid, scid))
}
/// One connection owned by a [`QuicServer`], together with where to send its datagrams.
struct Hosted {
    // The connection state machine for this peer.
    conn: QuicConnection,
    // Destination used by `poll_transmit` for this connection's outgoing datagrams;
    // updated by the router as datagrams arrive from the peer.
    addr: SocketAddr,
}
/// A socket-less QUIC server that demultiplexes datagrams across many connections.
///
/// The caller owns the UDP socket. It feeds received datagrams to `recv`, drains outgoing
/// ones with `poll_transmit`, and drives timers with `next_timeout` / `on_timeout`.
pub struct QuicServer {
    // Factory invoked once per accepted connection to build its config.
    make_config: Box<dyn FnMut() -> Result<QuicConfig, Error>>,
    // Key for stateless-reset tokens. It is copied into each connection's config and is
    // also used for resets the router sends for unknown CIDs.
    reset_key: [u8; 32],
    // Hosted connections, keyed by an internal monotonically increasing id.
    conns: HashMap<u64, Hosted>,
    // Local connection ID -> connection id. This is the primary routing table.
    by_cid: HashMap<ConnectionId, u64>,
    // Peer address -> connection id. This is the fallback when the DCID is not (yet) known.
    by_addr: HashMap<SocketAddr, u64>,
    // Router-generated datagrams (Version Negotiation, stateless resets) awaiting send.
    pending: VecDeque<(SocketAddr, EcnCodepoint, Vec<u8>)>,
    // Next id handed out by `accept`.
    next_id: u64,
    // Wall-clock seconds propagated to connections via `set_now_secs`.
    now_secs: u64,
}
impl QuicServer {
    /// Smallest UDP payload that may carry a client Initial (RFC 9000 §14.1).
    ///
    /// Undersized datagrams that would open a new connection are dropped before any
    /// per-connection (TLS) state is allocated for them.
    const MIN_INITIAL_DATAGRAM_LEN: usize = 1200;

    /// Creates a server whose stateless-reset key is drawn from the OS RNG.
    ///
    /// `make_config` is called once for every connection the server accepts. Its
    /// `reset_key` field is overwritten with the server's key.
    ///
    /// # Errors
    ///
    /// Returns whatever [`QuicServer::with_reset_key`] returns (currently always `Ok`).
    pub fn new<F>(make_config: F) -> Result<Self, Error>
    where
        F: FnMut() -> Result<QuicConfig, Error> + 'static,
    {
        let mut reset_key = [0u8; 32];
        OsRng.fill_bytes(&mut reset_key);
        Self::with_reset_key(reset_key, make_config)
    }

    /// Creates a server with an explicit stateless-reset key.
    ///
    /// The key is copied into every accepted connection's config and is also used to derive
    /// the tokens of stateless resets sent for unknown connection IDs. A fixed key lets
    /// those tokens be recomputed, as the unit tests do.
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub fn with_reset_key<F>(reset_key: [u8; 32], make_config: F) -> Result<Self, Error>
    where
        F: FnMut() -> Result<QuicConfig, Error> + 'static,
    {
        Ok(QuicServer {
            make_config: Box::new(make_config),
            reset_key,
            conns: HashMap::new(),
            by_cid: HashMap::new(),
            by_addr: HashMap::new(),
            pending: VecDeque::new(),
            next_id: 0,
            now_secs: 0,
        })
    }

    /// Number of connections currently hosted. Closed connections are only removed by
    /// [`QuicServer::on_timeout`].
    pub fn connection_count(&self) -> usize {
        self.conns.len()
    }

    /// Records the wall-clock time in seconds and propagates it to every hosted connection.
    /// Connections accepted later also start with this value.
    pub fn set_now_secs(&mut self, secs: u64) {
        self.now_secs = secs;
        for h in self.conns.values_mut() {
            h.conn.set_now_secs(secs);
        }
    }

    /// Routes one received UDP datagram.
    ///
    /// Long-header datagrams are handled as follows:
    /// - Version 0 (Version Negotiation) is dropped.
    /// - Any other non-v1 version gets a Version Negotiation reply.
    /// - v1 packets are routed by DCID, then by source address.
    /// - An unrouted Initial in a datagram of at least 1200 bytes opens a new connection.
    /// - Other unrouted packets get a stateless reset.
    ///
    /// Short-header datagrams take the `DEFAULT_SCID_LEN` bytes after the first byte as the
    /// DCID. They are routed the same way, or answered with a stateless reset.
    ///
    /// Malformed datagrams are silently dropped.
    ///
    /// # Errors
    ///
    /// Only when accepting a new connection fails (the config factory or
    /// `QuicConnection::server` returned an error).
    pub fn recv(
        &mut self,
        from: SocketAddr,
        ecn: EcnCodepoint,
        datagram: &[u8],
    ) -> Result<(), Error> {
        if datagram.is_empty() {
            return Ok(());
        }
        if datagram[0] & 0x80 != 0 {
            let (version, inv_dcid, inv_scid) = match parse_long_invariant(datagram) {
                Some(t) => t,
                None => return Ok(()),
            };
            // Never answer a Version Negotiation packet.
            if version == 0 {
                return Ok(());
            }
            if version != QUIC_V1 {
                // NOTE(review): RFC 9000 §5.2.2 says servers MUST drop unsupported-version
                // packets in datagrams too small to start a connection. This path currently
                // answers any size, and the `version_negotiation_for_unsupported_version` test
                // relies on that.
                // The VN packet echoes the client's SCID as its DCID and vice versa.
                let vn = build_version_negotiation(inv_scid, inv_dcid, &[QUIC_V1]);
                self.pending.push_back((from, EcnCodepoint::NotEct, vn));
                return Ok(());
            }
            let hdr = match LongHeader::parse(datagram) {
                Ok(h) => h,
                Err(_) => return Ok(()),
            };
            let dcid = match ConnectionId::from_slice(hdr.dcid) {
                Some(c) => c,
                None => return Ok(()),
            };
            if let Some(&id) = self.by_cid.get(&dcid) {
                self.feed(id, from, ecn, datagram);
            } else if let Some(&id) = self.by_addr.get(&from) {
                self.feed(id, from, ecn, datagram);
            } else if hdr.typ == LongType::Initial {
                // RFC 9000 §14.1: an Initial in a datagram shorter than 1200 bytes MUST be
                // discarded. Checking before `accept` also stops tiny spoofed datagrams from
                // each allocating a full connection + TLS state.
                if datagram.len() >= Self::MIN_INITIAL_DATAGRAM_LEN {
                    self.accept(from, ecn, datagram)?;
                }
            } else {
                self.queue_reset(from, &dcid, datagram.len());
            }
        } else {
            // Short headers do not encode the DCID length; every CID this server issues is
            // assumed to be DEFAULT_SCID_LEN bytes long.
            let dlen = DEFAULT_SCID_LEN;
            if datagram.len() < 1 + dlen {
                return Ok(());
            }
            let dcid = ConnectionId::from_slice(&datagram[1..1 + dlen]).expect("len <= 20");
            if let Some(&id) = self.by_cid.get(&dcid) {
                self.feed(id, from, ecn, datagram);
            } else if let Some(&id) = self.by_addr.get(&from) {
                self.feed(id, from, ecn, datagram);
            } else {
                self.queue_reset(from, &dcid, datagram.len());
            }
        }
        Ok(())
    }

    /// Returns the next datagram to send as `(destination, ECN marking, bytes)`.
    ///
    /// Router-generated replies (Version Negotiation, stateless resets) go first, then any
    /// hosted connection's pending datagram. Call repeatedly until it returns `None`.
    pub fn poll_transmit(&mut self) -> Option<(SocketAddr, EcnCodepoint, Vec<u8>)> {
        if let Some(p) = self.pending.pop_front() {
            return Some(p);
        }
        for h in self.conns.values_mut() {
            let dg = h.conn.pop_datagram();
            if !dg.is_empty() {
                return Some((h.addr, h.conn.egress_ecn(), dg));
            }
        }
        None
    }

    /// Returns how long from now until the earliest pending connection timer. The result
    /// is zero if that timer is already overdue, and `None` if no connection has a timer.
    ///
    /// Each connection's `next_timeout` is treated as an offset from its `started_at`.
    pub fn next_timeout(&self) -> Option<Duration> {
        let now = Instant::now();
        self.conns
            .values()
            .filter_map(|h| {
                h.conn
                    .next_timeout()
                    .map(|d| (h.conn.started_at() + d).saturating_duration_since(now))
            })
            .min()
    }

    /// Drives every hosted connection's timers with the time elapsed since its start, then
    /// removes connections that report themselves closed.
    pub fn on_timeout(&mut self) {
        let now = Instant::now();
        for h in self.conns.values_mut() {
            let elapsed = now.saturating_duration_since(h.conn.started_at());
            h.conn.on_timeout(elapsed);
        }
        self.reap_closed();
    }

    /// Iterates over all hosted connections with their current peer addresses.
    pub fn connections_mut(&mut self) -> impl Iterator<Item = (SocketAddr, &mut QuicConnection)> {
        self.conns.values_mut().map(|h| (h.addr, &mut h.conn))
    }

    /// Feeds `datagram` to connection `id`, then registers its current local CIDs for
    /// routing. Does nothing if `id` is unknown.
    ///
    /// The recorded peer address only moves to `from` when the connection accepted the
    /// datagram. Otherwise an off-path attacker who observed a CID could spoof `from` with
    /// garbage and redirect all of this connection's outgoing traffic.
    /// NOTE(review): this is still not full path validation (RFC 9000 §9); an on-path
    /// attacker can rewrite the source address of a genuine packet.
    fn feed(&mut self, id: u64, from: SocketAddr, ecn: EcnCodepoint, datagram: &[u8]) {
        let h = match self.conns.get_mut(&id) {
            Some(h) => h,
            None => return,
        };
        if h.conn.feed_datagram_from_with_ecn(from, ecn, datagram).is_ok() {
            h.addr = from;
        }
        // Retired CIDs are not removed here; they stay mapped until the connection is reaped.
        for cid in h.conn.local_cids() {
            self.by_cid.insert(cid, id);
        }
    }

    /// Creates a server connection for a new client Initial and feeds it the datagram.
    ///
    /// # Errors
    ///
    /// Propagates failures from the config factory or `QuicConnection::server`. In that
    /// case no state is recorded.
    fn accept(
        &mut self,
        from: SocketAddr,
        ecn: EcnCodepoint,
        datagram: &[u8],
    ) -> Result<(), Error> {
        let mut cfg = (self.make_config)()?;
        cfg.reset_key = Some(self.reset_key);
        let mut conn = QuicConnection::server(cfg)?;
        conn.set_peer_addr(from);
        conn.set_now_secs(self.now_secs);
        let id = self.next_id;
        self.next_id += 1;
        self.conns.insert(id, Hosted { conn, addr: from });
        // Route by address until the client starts using a CID we issued.
        self.by_addr.insert(from, id);
        self.feed(id, from, ecn, datagram);
        Ok(())
    }

    /// Queues a stateless reset for a packet addressed to an unknown `dcid`.
    ///
    /// Per RFC 9000 §10.3, the reset must be smaller than the triggering packet (to prevent
    /// reset loops) and at least `MIN_STATELESS_RESET_LEN` bytes. Triggers too small to
    /// satisfy both constraints get no reply.
    fn queue_reset(&mut self, from: SocketAddr, dcid: &ConnectionId, triggering_len: usize) {
        if triggering_len <= MIN_STATELESS_RESET_LEN {
            return;
        }
        let token = stateless_reset_token(&self.reset_key, dcid);
        // The guard above already ensures triggering_len - 1 >= MIN_STATELESS_RESET_LEN.
        let len = (triggering_len - 1).max(MIN_STATELESS_RESET_LEN);
        let pkt = build_stateless_reset(&mut OsRng, &token, len);
        self.pending.push_back((from, EcnCodepoint::NotEct, pkt));
    }

    /// Removes every closed connection together with all of its CID and address routes.
    ///
    /// This does one pass over each map, rather than one pass per dead connection.
    fn reap_closed(&mut self) {
        let dead: std::collections::HashSet<u64> = self
            .conns
            .iter()
            .filter(|(_, h)| h.conn.is_closed())
            .map(|(&id, _)| id)
            .collect();
        if dead.is_empty() {
            return;
        }
        self.conns.retain(|id, _| !dead.contains(id));
        self.by_cid.retain(|_, &mut id| !dead.contains(&id));
        self.by_addr.retain(|_, &mut id| !dead.contains(&id));
    }
}
// Router-level tests: `QuicServer` is driven by in-process `QuicConnection` clients using
// fake loopback socket addresses. No real sockets are opened.
#[cfg(test)]
mod server_tests {
use super::*;
use crate::ec::Ed25519PrivateKey;
use crate::hash::Sha256;
use crate::quic::transport_params::TransportParameters;
use crate::rng::HmacDrbg;
use crate::tls::{Config, Identity, ProtocolVersion, RootCertStore, SigningKey};
use crate::x509::{CertSigner, Certificate, DistinguishedName, Time, Validity};
use std::net::SocketAddr;
// Server TLS config with a self-signed Ed25519 certificate for "loopback.example"
// (ALPN "test", TLS 1.3 only). Returns the config and the certificate DER.
// The key comes from a fixed-seed HMAC-DRBG, so repeated calls presumably yield the same
// certificate. The multiplexing test relies on that: clients trust the DER from one call
// while the server factory calls this again for each connection.
fn server_identity() -> (Config, Vec<u8>) {
let mut rng = HmacDrbg::<Sha256>::new(b"quic-server-router-key", b"nonce", &[]);
let key = Ed25519PrivateKey::generate(&mut rng);
let name = DistinguishedName::common_name("loopback.example");
let validity = Validity::new(
Time::utc(2024, 1, 1, 0, 0, 0),
Time::utc(2034, 1, 1, 0, 0, 0),
);
let cert = Certificate::self_signed_general(
&CertSigner::Ed25519(&key),
&name,
&validity,
1,
false,
&["loopback.example"],
)
.unwrap();
let der = cert.to_der().to_vec();
let cfg = Config {
identity: Some(Identity {
cert_chain: alloc::vec![der.clone()],
key: SigningKey::Ed25519(key),
}),
alpn_protocols: alloc::vec![b"test".to_vec()],
max_version: ProtocolVersion::TLSv1_3,
min_version: ProtocolVersion::TLSv1_3,
..Config::default()
};
(cfg, der)
}
// Client TLS config that trusts only `cert_der` and offers ALPN "test" over TLS 1.3.
fn client_config(cert_der: &[u8]) -> Config {
let mut roots = RootCertStore::new();
roots.add_der(cert_der.to_vec()).unwrap();
Config {
roots,
alpn_protocols: alloc::vec![b"test".to_vec()],
max_version: ProtocolVersion::TLSv1_3,
min_version: ProtocolVersion::TLSv1_3,
..Config::default()
}
}
// Transport parameters used by both the clients and the server in these tests.
fn tp() -> TransportParameters {
TransportParameters {
max_idle_timeout_ms: Some(30_000),
max_udp_payload_size: Some(1500),
initial_max_data: Some(1 << 20),
initial_max_stream_data_bidi_local: Some(1 << 16),
initial_max_stream_data_bidi_remote: Some(1 << 16),
initial_max_stream_data_uni: Some(1 << 16),
initial_max_streams_bidi: Some(100),
initial_max_streams_uni: Some(3),
ack_delay_exponent: Some(3),
max_ack_delay_ms: Some(25),
active_connection_id_limit: Some(2),
..TransportParameters::default()
}
}
// Client connection to "loopback.example", which matches the server certificate's name.
fn client(cert_der: &[u8]) -> QuicConnection {
QuicConnection::client(
QuicConfig {
tls: client_config(cert_der),
transport_params: tp(),
..QuicConfig::default()
},
"loopback.example",
)
.expect("client build")
}
// 127.0.0.1:`port`; used purely as routing keys.
fn addr(port: u16) -> SocketAddr {
SocketAddr::from(([127, 0, 0, 1], port))
}
// Server with a fixed reset key. The factory builds a fresh server config for every
// accepted connection.
fn server(reset_key: [u8; 32]) -> QuicServer {
QuicServer::with_reset_key(reset_key, || {
Ok(QuicConfig {
tls: server_identity().0,
transport_params: tp(),
..QuicConfig::default()
})
})
.expect("server build")
}
// Two clients handshake through one server concurrently. Once both are complete, each
// sends one bidi stream with FIN. The server must host both connections and deliver both
// payloads.
#[test]
fn two_clients_multiplexed_with_stream_data() {
let (_, cert) = server_identity();
let mut srv = server([0x11; 32]);
let mut c1 = client(&cert);
let mut c2 = client(&cert);
let (a1, a2, sa) = (addr(40001), addr(40002), addr(443));
let mut opened = false;
// Each round: flush both clients into the server, then deliver every server datagram
// to the client it is addressed to, with the server address `sa` as the source.
for _ in 0..128 {
if !opened && c1.is_handshake_complete() && c2.is_handshake_complete() {
let s1 = c1.open_bidi().unwrap();
c1.write(s1, b"hello-from-client-1").unwrap();
c1.finish(s1).unwrap();
let s2 = c2.open_bidi().unwrap();
c2.write(s2, b"hello-from-client-2").unwrap();
c2.finish(s2).unwrap();
opened = true;
}
for (c, a) in [(&mut c1, a1), (&mut c2, a2)] {
loop {
let d = c.pop_datagram();
if d.is_empty() {
break;
}
srv.recv(a, EcnCodepoint::NotEct, &d).unwrap();
}
}
while let Some((to, _ecn, d)) = srv.poll_transmit() {
if to == a1 {
let _ = c1.feed_datagram_from(sa, &d);
} else if to == a2 {
let _ = c2.feed_datagram_from(sa, &d);
}
}
}
assert!(c1.is_handshake_complete(), "client 1 handshake");
assert!(c2.is_handshake_complete(), "client 2 handshake");
assert_eq!(srv.connection_count(), 2, "server hosts both connections");
// Gather whatever the server-side connections can read from their streams.
let mut got: Vec<Vec<u8>> = Vec::new();
for (_addr, conn) in srv.connections_mut() {
let ids: Vec<_> = conn.readable_streams().collect();
for sid in ids {
let mut buf = [0u8; 256];
let (n, _fin) = conn.read(sid, &mut buf).unwrap();
if n > 0 {
got.push(buf[..n].to_vec());
}
}
}
assert!(
got.iter().any(|p| p == b"hello-from-client-1"),
"client 1 payload delivered"
);
assert!(
got.iter().any(|p| p == b"hello-from-client-2"),
"client 2 payload delivered"
);
}
// A long-header packet with an unsupported version must get a Version Negotiation reply.
// The reply has version 0, swaps the client's CIDs, and lists QUIC v1.
#[test]
fn version_negotiation_for_unsupported_version() {
let mut srv = server([0x22; 32]);
let dcid = [0xA1u8; 8];
let scid = [0xB2u8; 4];
let mut pkt = alloc::vec![0xC0u8];
// Version 0x1a2a3a4a follows the 0x?a?a?a?a pattern reserved by RFC 9000 §15 to force
// version negotiation. It is followed by the DCID length byte.
pkt.extend_from_slice(&0x1a2a_3a4au32.to_be_bytes()); pkt.push(dcid.len() as u8);
pkt.extend_from_slice(&dcid);
pkt.push(scid.len() as u8);
pkt.extend_from_slice(&scid);
pkt.extend_from_slice(&[0u8; 16]);
srv.recv(addr(50001), EcnCodepoint::NotEct, &pkt).unwrap();
let (to, _ecn, vn) = srv.poll_transmit().expect("a Version Negotiation reply");
assert_eq!(to, addr(50001));
assert!(vn[0] & 0x80 != 0);
assert_eq!(&vn[1..5], &[0, 0, 0, 0]);
// The VN's DCID must be the client's SCID...
assert_eq!(vn[5] as usize, scid.len());
assert_eq!(&vn[6..6 + scid.len()], &scid);
// ...and the VN's SCID must be the client's DCID.
let scid_len_pos = 6 + scid.len();
assert_eq!(vn[scid_len_pos] as usize, dcid.len());
let echoed_dcid = &vn[scid_len_pos + 1..scid_len_pos + 1 + dcid.len()];
assert_eq!(echoed_dcid, &dcid);
let versions = &vn[scid_len_pos + 1 + dcid.len()..];
assert!(
versions
.chunks_exact(4)
.any(|w| u32::from_be_bytes([w[0], w[1], w[2], w[3]]) == QUIC_V1),
"VN advertises QUIC v1"
);
}
// A short-header packet with an unknown 8-byte DCID must get a stateless reset. The reset
// must be smaller than the 49-byte trigger, look like a short header, and end with the
// token derived from the reset key and that DCID.
#[test]
fn stateless_reset_for_unknown_short_header() {
const KEY: [u8; 32] = [0x33; 32];
let mut srv = server(KEY);
let unknown = crate::quic::cid::ConnectionId::from_slice(&[0xCDu8; 8]).unwrap();
let mut pkt = alloc::vec![0x42u8];
pkt.extend_from_slice(unknown.as_slice());
pkt.extend_from_slice(&[0u8; 40]);
let trigger_len = pkt.len();
srv.recv(addr(60001), EcnCodepoint::NotEct, &pkt).unwrap();
let (to, _ecn, reset) = srv.poll_transmit().expect("a stateless reset reply");
assert_eq!(to, addr(60001));
assert!(reset.len() >= MIN_STATELESS_RESET_LEN, "reset >= 21 bytes");
assert!(reset.len() < trigger_len, "reset shorter than the trigger");
assert_eq!(reset[0] & 0xc0, 0x40, "short-header form (0b01xxxxxx)");
// The stateless reset token occupies the final 16 bytes.
let token = &reset[reset.len() - 16..];
let expected = stateless_reset_token(&KEY, &unknown);
assert_eq!(
token, &expected,
"reset carries the derivable token for the CID"
);
}
}