use std::{
collections::{HashMap, VecDeque},
io,
net::SocketAddr,
time::Instant,
};
use quiche::{ConnectionId, Header, RecvInfo, SendInfo};
use rand::seq::IteratorRandom;
use ring::{hmac::Key, rand::SystemRandom};
use crate::{
errors::into_io_error,
quic::{Config, QuicConn},
};
use super::QuicConnState;
/// A stateless reply owed to a client in response to an Initial packet.
///
/// Values are queued by `QuicAcceptor::negotiation_version` and
/// `QuicAcceptor::retry` and turned into datagrams by `QuicAcceptor::send`.
enum InitAck {
/// Version Negotiation reply, queued when the client's version is not
/// supported; serialized with `quiche::negotiate_version`.
NegotiationVersion {
/// Source connection id copied from the client's packet header.
scid: ConnectionId<'static>,
/// Destination connection id copied from the client's packet header.
dcid: ConnectionId<'static>,
/// Addressing for the reply: the received packet's `to`/`from` swapped.
send_info: SendInfo,
},
/// Retry reply, queued when the client's Initial carried an empty token;
/// serialized with `quiche::retry`.
Retry {
/// Source connection id copied from the client's packet header.
scid: ConnectionId<'static>,
/// Destination connection id copied from the client's packet header.
dcid: ConnectionId<'static>,
/// Connection id derived from `dcid` by HMAC with the acceptor's seed,
/// truncated to `quiche::MAX_CONN_ID_LEN` bytes.
new_scid: ConnectionId<'static>,
/// Token produced by `QuicAcceptor::mint_token`.
token: Vec<u8>,
/// QUIC version copied from the client's packet header.
version: u32,
/// Addressing for the reply: the received packet's `to`/`from` swapped.
send_info: SendInfo,
},
}
/// Server-side acceptor that drives incoming QUIC handshakes.
///
/// Connections live in `conns` until they report `is_established()`, at which
/// point `pop_established` hands them out as `QuicConn` values.
pub(crate) struct QuicAcceptor {
/// Configuration passed to `quiche::accept` for every new connection.
config: Config,
/// HMAC key (generated in `new`) used to derive the new source connection
/// id placed in Retry packets.
conn_id_seed: Key,
/// Connections still handshaking, keyed by the destination connection id
/// of the Initial packet that created them.
conns: HashMap<ConnectionId<'static>, quiche::Connection>,
/// FIFO queue of Version Negotiation / Retry replies waiting for `send`.
init_acks: VecDeque<InitAck>,
}
impl QuicAcceptor {
pub fn new(config: Config) -> io::Result<Self> {
let rng = SystemRandom::new();
let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng)
.map_err(|err| io::Error::new(io::ErrorKind::Other, format!("${err}")))?;
Ok(Self {
config,
conn_id_seed,
conns: Default::default(),
init_acks: Default::default(),
})
}
/// Removes every connection whose handshake has completed and returns each
/// one, wrapped in a `QuicConn`, together with its source connection id.
///
/// Connections that are not yet established stay registered.
///
/// # Panics
///
/// Panics if an established connection is not stored under its own source
/// connection id (the lookup key used for removal).
pub fn pop_established(&mut self) -> Vec<(ConnectionId<'static>, QuicConn)> {
    // First pass: snapshot the ids, because the map cannot be mutated while
    // it is being iterated.
    let ready: Vec<ConnectionId<'static>> = self
        .conns
        .values()
        .filter_map(|conn| {
            conn.is_established()
                .then(|| conn.source_id().clone().into_owned())
        })
        .collect();

    // Second pass: detach each established connection from the map.
    ready
        .into_iter()
        .map(|id| {
            let raw = self.conns.remove(&id).unwrap();
            let wrapped = QuicConn::new(QuicConnState::new(raw, 5));
            (id, wrapped)
        })
        .collect()
}
pub fn send(&mut self, buf: &mut [u8]) -> io::Result<(usize, SendInfo)> {
if let Some(ack) = self.init_acks.pop_front() {
match ack {
InitAck::NegotiationVersion {
scid,
dcid,
send_info,
} => {
let send_size =
quiche::negotiate_version(&scid, &dcid, buf).map_err(into_io_error)?;
return Ok((send_size, send_info));
}
InitAck::Retry {
scid,
dcid,
new_scid,
token,
version,
send_info,
} => {
let send_size = quiche::retry(&scid, &dcid, &new_scid, &token, version, buf)
.map_err(into_io_error)?;
return Ok((send_size, send_info));
}
}
}
let len = self.conns.len();
let conns = self
.conns
.values_mut()
.choose_multiple(&mut rand::thread_rng(), len);
for conn in conns {
match conn.send(buf) {
Ok((send_size, send_info)) => {
return Ok((send_size, send_info));
}
Err(quiche::Error::Done) => {}
Err(_) => todo!(),
}
}
return Err(io::Error::new(
io::ErrorKind::WouldBlock,
"No more data to send",
));
}
/// Processes one received datagram during the handshake phase.
///
/// The packet header is parsed from `buf` and dispatched by packet type:
///
/// * `Initial` for a known connection: the token must be non-empty and pass
///   `validate_token`, then the packet is fed to that connection.
/// * `Initial` for an unknown connection: handled by `client_hello`.
/// * `Handshake`: fed to the connection matching the header's `dcid`.
/// * Any other type: ignored, reported as `0` bytes processed.
///
/// Returns the number of bytes processed together with the parsed header.
///
/// # Errors
///
/// * Header parsing or `quiche` processing failures, converted by
///   `into_io_error`.
/// * `io::ErrorKind::ConnectionRefused` when an `Initial` for a known
///   connection carries no token, or when a `Handshake` packet matches no
///   connection.
/// * Whatever `validate_token` / `client_hello` report.
pub fn recv<'a>(
    &mut self,
    buf: &'a mut [u8],
    recv_info: RecvInfo,
) -> io::Result<(usize, Header<'a>)> {
    let header =
        quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(into_io_error)?;
    log::trace!("Recv package {:?}", header.ty);
    match header.ty {
        quiche::Type::Initial => {
            if let Some(conn) = self.conns.get_mut(&header.dcid) {
                // The token comes straight off the wire: treat a missing
                // token like an empty one instead of panicking on `unwrap`.
                let token = header.token.as_deref().unwrap_or_default();
                if token.is_empty() {
                    return Err(io::Error::new(
                        io::ErrorKind::ConnectionRefused,
                        "Quic token is null",
                    ));
                }
                // Only the validity matters here; the embedded id is unused.
                Self::validate_token(token, &recv_info.from)?;
                let read_size = conn.recv(buf, recv_info).map_err(into_io_error)?;
                Ok((read_size, header))
            } else {
                let len = self.client_hello(&header, buf, recv_info)?;
                Ok((len, header))
            }
        }
        quiche::Type::Handshake => {
            // `ok_or_else` keeps the error (and its `format!`) off the
            // success path; `ok_or` built it for every Handshake packet.
            let conn = self.conns.get_mut(&header.dcid).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::ConnectionRefused,
                    format!("Quic connection not found, decid={:?}", header.dcid),
                )
            })?;
            conn.recv(buf, recv_info)
                .map_err(into_io_error)
                .map(|len| (len, header))
        }
        _ => Ok((0, header)),
    }
}
/// Handles an `Initial` packet that does not belong to a known connection.
///
/// * Unsupported version: a Version Negotiation reply is queued.
/// * No token (absent or empty): a Retry reply is queued.
/// * Valid token: a new connection is accepted, the packet is fed to it and
///   the connection is registered under the packet's `dcid`.
///
/// Returns `buf.len()` when a stateless reply was queued, otherwise the
/// number of bytes the new connection processed.
///
/// # Errors
///
/// * Whatever `validate_token` reports for a malformed token.
/// * `io::ErrorKind::Interrupted` when the `dcid` length differs from
///   `quiche::MAX_CONN_ID_LEN`.
/// * `quiche` accept/receive failures, converted by `into_io_error`.
fn client_hello(
    &mut self,
    header: &Header<'_>,
    buf: &mut [u8],
    recv_info: RecvInfo,
) -> io::Result<usize> {
    if !quiche::version_is_supported(header.version) {
        self.negotiation_version(header, recv_info)?;
        return Ok(buf.len());
    }

    // The token comes straight off the wire: treat a missing token like an
    // empty one (ask the client to retry) instead of panicking on `unwrap`.
    let token = header.token.as_deref().unwrap_or_default();
    if token.is_empty() {
        self.retry(header, recv_info)?;
        return Ok(buf.len());
    }

    let odcid = Self::validate_token(token, &recv_info.from)?;

    // The client addresses us by the id handed out in the Retry packet, which
    // `retry` always truncates to MAX_CONN_ID_LEN bytes.
    let scid: ConnectionId<'_> = header.dcid.clone();
    if quiche::MAX_CONN_ID_LEN != scid.len() {
        return Err(io::Error::new(
            io::ErrorKind::Interrupted,
            format!("Check dcid length error, len={}", scid.len()),
        ));
    }

    let mut conn = quiche::accept(
        &scid,
        Some(&odcid),
        recv_info.to,
        recv_info.from,
        &mut self.config,
    )
    .map_err(into_io_error)?;
    let read_size = conn.recv(buf, recv_info).map_err(into_io_error)?;
    log::trace!(
        "Create new incoming conn, scid={:?},dcid={:?},handshake={}",
        conn.source_id(),
        conn.destination_id(),
        read_size,
    );
    self.conns.insert(scid.into_owned(), conn);
    Ok(read_size)
}
/// Queues a Retry reply for the client that sent `header`.
///
/// The reply carries a token minted for the sender's address and a new source
/// connection id derived from the packet's `dcid` by HMAC with the acceptor's
/// seed, truncated to `quiche::MAX_CONN_ID_LEN` bytes. The packet itself is
/// built later, by `send`. Always returns `Ok(())`.
fn retry(&mut self, header: &Header<'_>, recv_info: RecvInfo) -> io::Result<()> {
    let token = self.mint_token(header, &recv_info.from);

    let digest = ring::hmac::sign(&self.conn_id_seed, &header.dcid);
    let id_bytes = digest.as_ref()[..quiche::MAX_CONN_ID_LEN].to_vec();
    let new_scid = ConnectionId::from_vec(id_bytes);

    // The reply travels back along the path the request arrived on.
    let reply = InitAck::Retry {
        scid: header.scid.clone().into_owned(),
        dcid: header.dcid.clone().into_owned(),
        new_scid,
        token,
        version: header.version,
        send_info: SendInfo {
            from: recv_info.to,
            to: recv_info.from,
            at: Instant::now(),
        },
    };
    self.init_acks.push_back(reply);
    Ok(())
}
/// Checks an address-validation token and extracts the connection id stored
/// in it.
///
/// The expected layout is the literal `quiche`, followed by the raw octets of
/// `src`'s IP address, followed by the connection id bytes. The returned id
/// borrows from `token`.
///
/// # Errors
///
/// Returns `io::ErrorKind::Interrupted` when the token is shorter than the
/// literal prefix, does not start with it, or does not contain `src`'s
/// address right after it.
fn validate_token<'a>(
    token: &'a [u8],
    src: &SocketAddr,
) -> io::Result<quiche::ConnectionId<'a>> {
    const MAGIC: &[u8] = b"quiche";
    let reject = |reason: &'static str| io::Error::new(io::ErrorKind::Interrupted, reason);

    // Kept as a separate check so that short tokens get their own message.
    if token.len() < MAGIC.len() {
        return Err(reject("Invalid token, token length < 6"));
    }
    let rest = match token.strip_prefix(MAGIC) {
        Some(rest) => rest,
        None => return Err(reject("Invalid token, not start with 'quiche'")),
    };

    let ip_octets = match src.ip() {
        std::net::IpAddr::V4(v4) => v4.octets().to_vec(),
        std::net::IpAddr::V6(v6) => v6.octets().to_vec(),
    };

    // `strip_prefix` fails both when `rest` is too short and when the address
    // bytes differ; everything after the address is the connection id.
    match rest.strip_prefix(ip_octets.as_slice()) {
        Some(id_bytes) => Ok(quiche::ConnectionId::from_ref(id_bytes)),
        None => Err(reject("Invalid token, address mismatch")),
    }
}
/// Builds an address-validation token for the client at `src`.
///
/// The token is the literal `quiche`, followed by the raw octets of `src`'s
/// IP address, followed by the destination connection id from `hdr` — the
/// layout `validate_token` parses.
fn mint_token<'a>(&self, hdr: &quiche::Header<'a>, src: &SocketAddr) -> Vec<u8> {
    let mut out = b"quiche".to_vec();
    match src.ip() {
        std::net::IpAddr::V4(v4) => out.extend_from_slice(&v4.octets()),
        std::net::IpAddr::V6(v6) => out.extend_from_slice(&v6.octets()),
    }
    out.extend_from_slice(&hdr.dcid);
    out
}
/// Queues a Version Negotiation reply for the client that sent `header`.
///
/// Only the connection ids and the reply addressing are recorded here; the
/// packet itself is built later, by `send`. Always returns `Ok(())`.
fn negotiation_version(&mut self, header: &Header<'_>, recv_info: RecvInfo) -> io::Result<()> {
    // The reply travels back along the path the request arrived on.
    let reply = InitAck::NegotiationVersion {
        scid: header.scid.clone().into_owned(),
        dcid: header.dcid.clone().into_owned(),
        send_info: SendInfo {
            from: recv_info.to,
            to: recv_info.from,
            at: Instant::now(),
        },
    };
    self.init_acks.push_back(reply);
    Ok(())
}
}