use std::sync::Arc;
use qbase::{
Epoch, GetEpoch,
error::{Error, QuicError},
frame::{ConnectionCloseFrame, CryptoFrame, Frame},
net::tx::Signals,
packet::{header::long::HandshakeHeader, io::PacketSpace, keys::ArcKeys},
util::BoundQueue,
};
use qcongestion::{Feedback, Transport};
use qevent::{
quic::{
PacketHeader, PacketType, QuicFramesCollector,
recovery::{PacketLost, PacketLostTrigger},
},
telemetry::Instrument,
};
use qinterface::component::route::{CipherPacket, PlainPacket, Way};
use qrecovery::crypto::CryptoStream;
use tokio::sync::mpsc;
use crate::{
Components, HandshakeJournal,
events::{ArcEventBroker, EmitEvent, Event},
path::{self, Path, error::CreatePathFailure},
space::{
AckHandshakeSpace, assemble_closing_packet, filter_odcid_packet, pipe, read_plain_packet,
},
state,
termination::Terminator,
tx::{PacketWriter, TrivialPacketWriter},
};
/// A Handshake packet as received, still protected (encrypted).
///
/// The misspelled name is kept for compatibility with existing callers.
pub type CipherHanshakePacket = CipherPacket<HandshakeHeader>;
/// A Handshake packet after header and payload protection have been removed.
pub type PlainHandshakePacket = PlainPacket<HandshakeHeader>;
/// A protected Handshake packet paired with the [`Way`] (bind URI, pathway, link) it was
/// received through.
pub type ReceivedFrom = (CipherPacket<HandshakeHeader>, Way);
/// State of the QUIC Handshake packet-number space.
///
/// Holds the Handshake packet-protection keys and the journal that records sent and
/// received Handshake packets.
pub struct HandshakeSpace {
    /// Handshake keys; created in the pending state by [`HandshakeSpace::new`].
    keys: ArcKeys,
    /// Journal of sent (`of_sent_packets`) and received (`of_rcvd_packets`) Handshake packets.
    journal: HandshakeJournal,
}
impl AsRef<HandshakeJournal> for HandshakeSpace {
    /// Borrows the space's packet journal.
    fn as_ref(&self) -> &HandshakeJournal {
        let Self { journal, .. } = self;
        journal
    }
}
impl HandshakeSpace {
    /// Creates a Handshake space with pending keys and an empty journal
    /// (`HandshakeJournal::with_capacity(16, None)`).
    pub fn new() -> Self {
        let keys = ArcKeys::new_pending();
        let journal = HandshakeJournal::with_capacity(16, None);
        Self { keys, journal }
    }

    /// Returns a clone of this space's key handle.
    pub fn keys(&self) -> ArcKeys {
        ArcKeys::clone(&self.keys)
    }

    /// Removes the protection from a received Handshake packet.
    ///
    /// Awaits the remote keys first. If none are obtained, the packet is recorded as dropped
    /// via `drop_on_key_unavailable` and `None` is returned. Otherwise the result of
    /// `decrypt_long_packet` is returned. It uses the remote header and packet keys, and
    /// decodes the packet number through the received-packet journal.
    pub async fn decrypt_packet(
        &self,
        packet: CipherHanshakePacket,
    ) -> Option<Result<PlainHandshakePacket, QuicError>> {
        let Some(keys) = self.keys.get_remote_keys().await else {
            packet.drop_on_key_unavailable();
            return None;
        };
        packet.decrypt_long_packet(
            keys.remote.header.as_ref(),
            keys.remote.packet.as_ref(),
            |pn| self.journal.of_rcvd_packets().decode_pn(pn),
        )
    }

    /// Builds the loss-feedback tracker for this space.
    ///
    /// The tracker gets a clone of this space's journal and takes ownership of `crypto_stream`.
    pub fn tracker(&self, crypto_stream: CryptoStream) -> HandshakeTracker {
        let journal = self.journal.clone();
        HandshakeTracker {
            journal,
            crypto_stream,
        }
    }
}
impl Default for HandshakeSpace {
fn default() -> Self {
Self::new()
}
}
impl GetEpoch for HandshakeSpace {
    /// Always returns [`Epoch::Handshake`]; this space serves only the Handshake epoch.
    fn epoch(&self) -> Epoch {
        Epoch::Handshake
    }
}
impl path::PacketSpace<HandshakeHeader> for HandshakeSpace {
    type JournalFrame = CryptoFrame;

    /// Starts a new long-header Handshake packet in `buffer`.
    ///
    /// Returns `Err(Signals::KEYS)` when local keys are not available yet. The congestion
    /// controller is consulted for the Handshake retransmit and expire times only after keys
    /// were found.
    fn new_packet<'b, 's>(
        &'s self,
        header: HandshakeHeader,
        cc: &qcongestion::ArcCC,
        buffer: &'b mut [u8],
    ) -> Result<PacketWriter<'b, 's, CryptoFrame>, Signals> {
        let Some(keys) = self.keys.get_local_keys() else {
            return Err(Signals::KEYS);
        };
        let (retran, expire) = cc.retransmit_and_expire_time(Epoch::Handshake);
        PacketWriter::new_long(
            header,
            buffer,
            keys.local.clone(),
            self.journal.as_ref(),
            retran,
            expire,
        )
    }
}
impl PacketSpace<HandshakeHeader> for HandshakeSpace {
    type PacketAssembler<'a> = TrivialPacketWriter<'a, 'a, CryptoFrame>;

    /// Starts a trivial long-header Handshake packet writer over `buffer`.
    ///
    /// Returns `Err(Signals::KEYS)` when local keys are not available yet.
    #[inline]
    fn new_packet<'a>(
        &'a self,
        header: HandshakeHeader,
        buffer: &'a mut [u8],
    ) -> Result<Self::PacketAssembler<'a>, Signals> {
        let Some(keys) = self.keys.get_local_keys() else {
            return Err(Signals::KEYS);
        };
        TrivialPacketWriter::new_long(header, buffer, keys.local, self.journal.as_ref())
    }
}
/// Builds the callback that handles each frame of a received Handshake packet.
///
/// CRYPTO and ACK frames are forwarded over unbounded channels to consumers set up with
/// [`pipe`]. CRYPTO data goes to the Handshake crypto stream's incoming side, and ACKs go
/// to an [`AckHandshakeSpace`].
///
/// Before forwarding, an ACK frame is also reported to the path's congestion controller
/// and to the received-packet journal. After forwarding, the handshake status is told via
/// `received_handshake_ack`.
///
/// CONNECTION_CLOSE frames are emitted as [`Event::Closed`]. PADDING and PING are ignored.
///
/// # Panics
///
/// The returned closure panics on any other frame type.
/// NOTE(review): this relies on `read_plain_packet` rejecting frame types that are not
/// allowed in Handshake packets before dispatch — confirm, since frames are peer-controlled.
fn frame_dispathcer(
    space: &HandshakeSpace,
    components: &Components,
    event_broker: &ArcEventBroker,
) -> impl for<'p> Fn(Frame, &'p Path) + use<> {
    let crypto_stream = &components.crypto_streams[space.epoch()];
    let (crypto_tx, crypto_rx) = mpsc::unbounded_channel();
    let (ack_tx, ack_rx) = mpsc::unbounded_channel();

    pipe(crypto_rx, crypto_stream.incoming(), event_broker.clone());
    pipe(
        ack_rx,
        AckHandshakeSpace::new(&space.journal, crypto_stream),
        event_broker.clone(),
    );

    // Everything captured below is owned, so the closure borrows nothing from the inputs.
    let handshake_status = components.quic_handshake.status();
    let event_broker = event_broker.clone();
    let rcvd_journal = space.journal.of_rcvd_packets();

    move |frame: Frame, path: &Path| match frame {
        Frame::Ack(ack) => {
            path.cc().on_ack_rcvd(Epoch::Handshake, &ack);
            rcvd_journal.on_rcvd_ack(&ack);
            let _ = ack_tx.send(ack);
            handshake_status.received_handshake_ack();
        }
        Frame::Close(close) => event_broker.emit(Event::Closed(close)),
        Frame::Crypto(crypto, data) => {
            let _ = crypto_tx.send((crypto, data));
        }
        Frame::Padding(_) | Frame::Ping(_) => {}
        other => unreachable!("unexpected frame: {:?} in handshake packet", other),
    }
}
/// Processes one Handshake packet received while the connection is still alive.
///
/// The steps run in this order:
/// 1. Decrypt the packet.
/// 2. Find or create the path it arrived on.
/// 3. Filter it with [`filter_odcid_packet`].
/// 4. Grant the path anti-amplification credit.
/// 5. Dispatch every frame through `dispatch_frame`.
/// 6. Record the packet number in the received-packet journal and notify the path.
///
/// A packet that cannot be used is dropped and `Ok(())` is returned. That covers keys not
/// available, no path (connection closed or no interface), and filtered out.
///
/// # Errors
///
/// Returns decryption errors and errors from [`read_plain_packet`].
async fn parse_normal_packet(
    (packet, (bind_uri, pathway, link)): ReceivedFrom,
    space: &HandshakeSpace,
    components: &Components,
    dispatch_frame: impl Fn(Frame, &Path),
) -> Result<(), Error> {
    let packet = match space.decrypt_packet(packet).await {
        Some(decrypted) => decrypted?,
        // No keys: `decrypt_packet` has already recorded the drop.
        None => return Ok(()),
    };

    let path = match components.get_or_try_create_path(bind_uri, link, pathway, true) {
        Ok(path) => path,
        Err(failure) => {
            match failure {
                CreatePathFailure::ConnectionClosed(..) => {
                    packet.drop_on_conenction_closed();
                }
                CreatePathFailure::NoInterface(..) => {
                    packet.drop_on_interface_not_found();
                }
            }
            return Ok(());
        }
    };

    let Some(packet) = filter_odcid_packet(packet, &components.specific) else {
        return Ok(());
    };

    path.grant_anti_amplification();
    let content = read_plain_packet(&packet, |frame| dispatch_frame(frame, &path))?;
    space.journal.of_rcvd_packets().on_rcvd_pn(
        packet.pn(),
        content.is_ack_eliciting(),
        path.cc().get_pto(Epoch::Handshake),
    );
    path.on_packet_rcvd(Epoch::Handshake, packet.pn(), packet.size(), content);
    Ok(())
}
/// Best-effort parse of a Handshake packet received after the connection started closing.
///
/// The keys come from `get_local_keys`, which is not awaited, so this function stays
/// synchronous. The packet is decrypted with the remote half of those keys, and the first
/// CONNECTION_CLOSE frame found is returned.
///
/// Returns `None` if there are no keys, decryption fails, or the packet has no
/// CONNECTION_CLOSE frame. Errors from `read_plain_packet` are ignored.
fn parse_closing_packet(
    space: &HandshakeSpace,
    packet: CipherHanshakePacket,
) -> Option<ConnectionCloseFrame> {
    let remote = space.keys.get_local_keys()?.remote;
    let Some(Ok(plain)) = packet.decrypt_long_packet(
        remote.header.as_ref(),
        remote.packet.as_ref(),
        |pn| space.journal.of_rcvd_packets().decode_pn(pn),
    ) else {
        return None;
    };

    let mut first_close = None;
    let _ = read_plain_packet(&plain, |frame| match frame {
        // Later CONNECTION_CLOSE frames are discarded; only the first one is kept.
        Frame::Close(close) if first_close.is_none() => first_close = Some(close),
        _ => {}
    });
    first_close
}
/// Receive task for Handshake packets. It runs in two phases.
///
/// 1. **Normal.** Each packet from `packets` is handled by `parse_normal_packet` inside a
///    `qevent` span tagged with the packet's link. An `Error::Quic` result is emitted as
///    `Event::Failed`; other results are ignored. The task returns if `packets` runs dry.
///    The phase also ends when the connection state reports termination.
/// 2. **Closing.** Entered only if the connection state is `state::CLOSING` at
///    termination; otherwise the task returns.
///    - A `Terminator` is built from the termination error's `ConnectionCloseFrame`, and
///      `components` is released.
///    - Each further packet is checked for a peer CONNECTION_CLOSE, which is emitted as
///      `Event::Closed`.
///    - When the terminator allows it, the packet is answered with a closing packet on its
///      pathway.
pub async fn deliver_and_parse_packets(
    packets: BoundQueue<ReceivedFrom>,
    space: Arc<HandshakeSpace>,
    components: Components,
    event_broker: ArcEventBroker,
) {
    let conn_state = &components.conn_state;
    let dispatch_frame = frame_dispathcer(&space, &components, &event_broker);

    let normal_phase = async {
        while let Some(received) = packets.recv().await {
            let span = qevent::span!(@current, path=received.1.2.to_string());
            let parse = parse_normal_packet(received, &space, &components, &dispatch_frame);
            if let Err(Error::Quic(error)) = Instrument::instrument(parse, span).await {
                event_broker.emit(Event::Failed(error));
            }
        }
    };

    let ccf = tokio::select! {
        _ = normal_phase => return,
        error = conn_state.terminated() => {
            // Only the CLOSING state keeps answering the peer; any other state ends the task.
            if conn_state.current() != Some(state::CLOSING) {
                return;
            }
            ConnectionCloseFrame::from(error)
        }
    };

    let terminator = Terminator::new(ccf, &components);
    // The closing phase only needs the space, the terminator and the event broker.
    drop(components);

    while let Some((packet, (_bind_uri, pathway, _link))) = packets.recv().await {
        if let Some(peer_ccf) = parse_closing_packet(&space, packet) {
            event_broker.emit(Event::Closed(peer_ccf));
        }
        if !terminator.should_send() {
            continue;
        }
        terminator
            .try_send_on(pathway, |buffer, ccf| {
                assemble_closing_packet(space.as_ref(), &terminator, buffer, ccf)
            })
            .await;
    }
}
/// Loss-feedback handler for the Handshake space, created by [`HandshakeSpace::tracker`].
///
/// When packets are reported as possibly lost, it passes their recorded frames to the
/// crypto stream's outgoing side (`may_loss_data`). It also emits a `PacketLost` event
/// through `qevent`.
pub struct HandshakeTracker {
    /// Clone of the owning space's packet journal.
    journal: HandshakeJournal,
    /// Crypto stream supplied by the caller of `tracker`.
    /// Presumably the Handshake-epoch stream — confirm at the call site.
    crypto_stream: CryptoStream,
}
impl Feedback for HandshakeTracker {
    /// Handles Handshake packets that may have been lost.
    ///
    /// For each packet number in `pns`, every frame recorded for that packet in the rotated
    /// sent-packet journal is handed to the crypto stream's `may_loss_data`. The frames are
    /// also collected into a `PacketLost` event for a Handshake packet carrying `trigger`.
    fn may_loss(&self, trigger: PacketLostTrigger, pns: &mut dyn Iterator<Item = u64>) {
        // Declaration order is kept so the rotated view is dropped before the handles it uses.
        let sent_journal = self.journal.of_sent_packets();
        let crypto_outgoing = self.crypto_stream.outgoing();
        let mut rotated = sent_journal.rotate();

        for pn in pns {
            let mut lost_frames = QuicFramesCollector::<PacketLost>::new();
            for frame in rotated.may_loss_packet(pn) {
                lost_frames.extend([&frame]);
                crypto_outgoing.may_loss_data(&frame);
            }
            qevent::event!(PacketLost {
                header: PacketHeader {
                    packet_type: PacketType::Handshake,
                    packet_number: pn
                },
                frames: lost_frames,
                trigger
            });
        }
    }
}