use std::{
io,
ops::Deref,
sync::{Arc, atomic::Ordering::Acquire},
};
use bytes::BufMut;
use derive_more::From;
use qbase::{
Epoch, GetEpoch,
cid::{BorrowedCid, ConnectionId},
frame::PingFrame,
net::tx::{ArcSendWaker, Signals},
packet::{
AssemblePacket, Package, PacketProperties, ProductHeader,
header::{
long::{HandshakeHeader, InitialHeader, ZeroRttHeader, io::LongHeaderBuilder},
short::OneRttHeader,
},
io::{Packages, PadProbe, PadTo20, PadToFull, Repeat},
signal::SpinBit,
},
role::Role,
token::TokenRegistry,
};
use qcongestion::{ArcCC, Transport};
use qinterface::io::IO;
use qrecovery::journal::{AckPackege, ArcRcvdJournal, Journal};
use qtraversal::packet::{ForwardHeader, WriteForwardHeader};
use crate::{
ArcDcidCell, ArcReliableFrameDeque, CidRegistry, Components,
path::{AntiAmplifier, Constraints},
space::{Spaces, data::DataSpace, handshake::HandshakeSpace, initial::InitialSpace},
tls::ArcTlsHandshake,
tx::PacketWriter,
};
pub trait PacketSpace<H> {
type JournalFrame;
fn new_packet<'b, 's>(
&'s self,
header: H,
cc: &ArcCC,
buffer: &'b mut [u8],
) -> Result<PacketWriter<'b, 's, Self::JournalFrame>, Signals>;
}
pub struct Burst {
path: Arc<super::Path>,
initial_token: Vec<u8>,
cid_registry: CidRegistry,
spin: bool,
spaces: Spaces,
tls_handshake: ArcTlsHandshake,
}
impl super::Path {
pub fn new_burst(self: &Arc<Self>, components: &Components) -> Burst {
Burst {
path: self.clone(),
initial_token: match components.token_registry.deref() {
TokenRegistry::Client((server_name, token_sink)) => {
token_sink.fetch_token(server_name)
}
TokenRegistry::Server(..) => vec![],
},
cid_registry: components.cid_registry.clone(),
spin: false, spaces: components.spaces.clone(),
tls_handshake: components.tls_handshake.clone(),
}
}
}
#[derive(From)]
pub enum BurstError {
Signals(Signals),
PathDeactived,
}
pub struct PacketsAssembler<'a> {
cc: &'a ArcCC,
constraints: Constraints,
cid_registry: &'a CidRegistry,
borrowed_dcid: Result<BorrowedCid<'a, ArcReliableFrameDeque>, Signals>,
initial_token: &'a [u8],
spin: SpinBit,
}
impl<'a> PacketsAssembler<'a> {
fn new(
cid_registry: &'a CidRegistry,
dcid_cell: &'a ArcDcidCell,
anti_amplifier: &AntiAmplifier,
cc: &'a ArcCC,
tx_waker: ArcSendWaker,
initial_token: &'a [u8],
spin: impl Into<SpinBit>,
) -> Result<PacketsAssembler<'a>, BurstError> {
let send_quota = cc.send_quota()?;
let Some(credit_limit) = anti_amplifier.balance()? else {
return Err(BurstError::PathDeactived);
};
let Some(borrowed_dcid) = dcid_cell.borrow_cid(tx_waker).transpose() else {
return Err(BurstError::PathDeactived);
};
let constraints = Constraints::new(credit_limit, send_quota);
Ok(Self {
cid_registry,
borrowed_dcid,
cc,
constraints,
initial_token,
spin: spin.into(),
})
}
fn initial_scid(&self) -> Result<ConnectionId, Signals> {
self.cid_registry
.local
.initial_scid()
.ok_or(Signals::empty())
}
fn applied_dcid(&self) -> Result<ConnectionId, Signals> {
self.borrowed_dcid.as_deref().copied().map_err(|e| *e)
}
fn initial_dcid(&self) -> Result<ConnectionId, Signals> {
match self.cid_registry.role() {
Role::Client => Ok(self.cid_registry.origin_dcid()),
Role::Server => self.applied_dcid(),
}
}
pub fn commit(&mut self, sent_bytes: usize, packet_props: PacketProperties) {
self.constraints
.commit(sent_bytes, packet_props.in_flight());
self.cc.on_pkt_sent(
packet_props.epoch().expect("todo"),
packet_props.packet_number(),
packet_props.ack_eliciting(),
sent_bytes,
packet_props.in_flight(),
packet_props.largest_ack(),
);
}
}
impl ProductHeader<InitialHeader> for PacketsAssembler<'_> {
fn new_header(&self) -> Result<InitialHeader, Signals> {
Ok(
LongHeaderBuilder::with_cid(self.initial_dcid()?, self.initial_scid()?)
.initial(self.initial_token.to_vec()),
)
}
}
impl ProductHeader<ZeroRttHeader> for PacketsAssembler<'_> {
fn new_header(&self) -> Result<ZeroRttHeader, Signals> {
Ok(LongHeaderBuilder::with_cid(self.initial_dcid()?, self.initial_scid()?).zero_rtt())
}
}
impl ProductHeader<HandshakeHeader> for PacketsAssembler<'_> {
fn new_header(&self) -> Result<HandshakeHeader, Signals> {
Ok(LongHeaderBuilder::with_cid(self.applied_dcid()?, self.initial_scid()?).handshake())
}
}
impl ProductHeader<OneRttHeader> for PacketsAssembler<'_> {
fn new_header(&self) -> Result<OneRttHeader, Signals> {
Ok(OneRttHeader::new(self.spin, self.applied_dcid()?))
}
}
impl<'a> PacketsAssembler<'a> {
pub fn assemble<'s, 'b, H, Space, P>(
&mut self,
space: &'s Space,
data_sources: P,
buffer: &'b mut [u8],
) -> Result<usize, Signals>
where
Self: ProductHeader<H>,
Space: PacketSpace<H> + GetEpoch,
Space::JournalFrame: 's,
P: Package<PacketWriter<'b, 's, Space::JournalFrame>>,
{
let buffer = self.constraints.constrain(buffer);
let mut packet = space.new_packet(self.new_header()?, self.cc, buffer)?;
packet.assemble_packet(&mut Packages((data_sources, PadTo20)))?;
let (sent_bytes, props) = packet.encrypt_and_protect_packet();
self.commit(sent_bytes, props);
Result::<_, Signals>::Ok(sent_bytes)
}
}
pub type PackageIntoSpace<H, S> =
dyn for<'b, 's> Package<PacketWriter<'b, 's, <S as PacketSpace<H>>::JournalFrame>> + Send;
pub struct DataSources {
initial: Box<PackageIntoSpace<InitialHeader, InitialSpace>>,
zero_rtt: Box<PackageIntoSpace<ZeroRttHeader, DataSpace>>,
handshake: Box<PackageIntoSpace<HandshakeHeader, HandshakeSpace>>,
one_rtt: Box<PackageIntoSpace<OneRttHeader, DataSpace>>,
}
impl Components {
pub(super) fn packages(&self) -> DataSources {
let initial_packages = self.crypto_streams[Epoch::Initial]
.outgoing()
.package(Epoch::Initial);
let zero_rtt_packages = Packages((
Repeat(self.reliable_frames.clone()),
Repeat(
self.data_streams
.package(self.flow_ctrl.sender.clone(), true),
),
));
let handshake_packages = self.crypto_streams[Epoch::Handshake]
.outgoing()
.package(Epoch::Handshake);
let one_rtt_packages = Packages((
self.crypto_streams[Epoch::Data]
.outgoing()
.package(Epoch::Data),
Repeat(self.reliable_frames.clone()),
Repeat(
self.data_streams
.package(self.flow_ctrl.sender.clone(), false),
),
));
DataSources {
initial: Box::new(initial_packages),
zero_rtt: Box::new(zero_rtt_packages),
handshake: Box::new(handshake_packages),
one_rtt: Box::new(one_rtt_packages),
}
}
}
fn ack_package<'s, S, F>(space: &'s S, cc: &ArcCC) -> AckPackege<'s>
where
S: GetEpoch + AsRef<Journal<F>>,
F: 's,
{
ArcRcvdJournal::ack_package(space.as_ref().as_ref(), cc.need_ack(space.epoch()))
}
impl Burst {
fn assembler<'a>(&'a self) -> Result<PacketsAssembler<'a>, BurstError> {
PacketsAssembler::new(
&self.cid_registry,
&self.path.dcid_cell,
&self.path.anti_amplifier,
&self.path.cc,
self.path.tx_waker.clone(),
&self.initial_token,
self.spin,
)
}
fn load_spaces(
&self,
DataSources {
initial: initial_data_sources,
zero_rtt: zero_rtt_data_sources,
handshake: handshake_data_sources,
one_rtt: one_rtt_data_sources,
}: &mut DataSources,
mut buffer: &mut [u8],
) -> Result<usize, BurstError> {
let Self {
path,
spaces,
tls_handshake,
..
} = self;
let initial_space = spaces.initial().as_ref();
let handshake_space = spaces.handshake().as_ref();
let data_space = spaces.data().as_ref();
let origin = buffer.remaining_mut();
let mut assembler = self.assembler()?;
let mut signals = Signals::empty();
let Ok(tls_fin) = tls_handshake.is_finished() else {
return Err(BurstError::PathDeactived);
};
match assembler.assemble(
initial_space,
&mut Packages((ack_package(initial_space, &path.cc), initial_data_sources)),
buffer,
) {
Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(),
Err(s) => signals |= s,
};
let loaded_initial = buffer.remaining_mut() != origin;
if !tls_fin {
match assembler.assemble::<ZeroRttHeader, _, _>(
data_space,
zero_rtt_data_sources,
buffer,
) {
Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(),
Err(s) => signals |= s,
}
}
match assembler.assemble(
handshake_space,
&mut Packages((
ack_package(handshake_space, &path.cc),
handshake_data_sources,
)),
buffer,
) {
Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(),
Err(s) => signals |= s,
}
if tls_fin {
let result = if path.validated.load(Acquire) {
assembler.assemble::<OneRttHeader, _, _>(
data_space,
&mut Packages((
ack_package(data_space, &path.cc),
&path.challenge_sndbuf,
&path.response_sndbuf,
one_rtt_data_sources,
loaded_initial.then_some(PadToFull),
PadProbe,
)),
buffer,
)
} else {
assembler.assemble::<OneRttHeader, _, _>(
data_space,
&mut Packages((
ack_package(data_space, &path.cc),
&path.challenge_sndbuf,
&path.response_sndbuf,
loaded_initial.then_some(PadToFull),
PadProbe,
)),
buffer,
)
};
match result {
Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(),
Err(s) => signals |= s,
}
}
if loaded_initial {
assert!(buffer.remaining_mut() != origin);
buffer.put_bytes(0, buffer.remaining_mut());
return Ok(origin);
}
let sent_bytes = origin - buffer.remaining_mut();
(sent_bytes > 0)
.then_some(sent_bytes)
.ok_or(BurstError::Signals(signals))
}
}
struct PingSource {
need_send_ack_eliciting: usize,
}
impl<Target: ?Sized> Package<Target> for PingSource
where
PingFrame: Package<Target>,
{
fn dump(&mut self, target: &mut Target) -> Result<(), Signals> {
if self.need_send_ack_eliciting > 0 {
return PingFrame.dump(target);
}
Err(Signals::PING)
}
}
fn ping_package(cc: &ArcCC, epoch: Epoch) -> PingSource {
PingSource {
need_send_ack_eliciting: cc.need_send_ack_eliciting(epoch),
}
}
impl Burst {
fn load_ping(&self, buffer: &mut [u8]) -> Result<usize, BurstError> {
let Self { spaces, path, .. } = self;
let mut assembler = self.assembler()?;
let mut signals = Signals::empty();
for &epoch in Epoch::iter().rev() {
let result = match epoch {
Epoch::Data => {
let ack_package = ack_package(spaces.data().as_ref(), &path.cc);
let ping_package = ping_package(&path.cc, epoch);
assembler.assemble::<OneRttHeader, _, _>(
spaces.data().as_ref(),
&mut Packages((ack_package, ping_package, PadToFull)),
buffer,
)
}
Epoch::Handshake => {
let ack_package = ack_package(spaces.handshake().as_ref(), &path.cc);
let ping_package = ping_package(&path.cc, epoch);
assembler.assemble(
spaces.handshake().as_ref(),
&mut Packages((ack_package, ping_package, PadToFull)),
buffer,
)
}
Epoch::Initial => {
let ack_package = ack_package(spaces.initial().as_ref(), &path.cc);
let ping_package = ping_package(&path.cc, epoch);
assembler.assemble(
spaces.initial().as_ref(),
&mut Packages((ack_package, ping_package, PadToFull)),
buffer,
)
}
};
match result {
Ok(sent_bytes) => return Ok(sent_bytes),
Err(s) => signals |= s,
}
}
Err(BurstError::Signals(signals))
}
fn load_heartbeat(&self, buffer: &mut [u8]) -> Result<usize, BurstError> {
let Self { spaces, path, .. } = self;
let mut assembler = self.assembler()?;
Ok(assembler.assemble::<OneRttHeader, _, _>(
spaces.data().as_ref(),
&path.heartbeat,
buffer,
)?)
}
pub async fn burst<'b>(
&self,
data_sources: &mut DataSources,
buffers: &'b mut Vec<Vec<u8>>,
) -> Result<Vec<io::IoSlice<'b>>, BurstError> {
let Ok(max_segments) = self.path.interface.max_segments() else {
return Err(BurstError::PathDeactived);
};
let Ok(max_segment_size) = self.path.interface.max_segment_size() else {
return Err(BurstError::PathDeactived);
};
if buffers.len() < max_segments {
buffers.resize_with(max_segments, || vec![0; max_segment_size]);
}
use core::ops::ControlFlow::*;
let reversed_size =
ForwardHeader::encoding_size(&self.path.pathway.try_into().expect("BLE"));
let (Break(result) | Continue(result)) = buffers
.iter_mut()
.map(move |buffer| {
if buffer.len() < max_segment_size {
buffer.resize(max_segment_size, 0);
}
&mut buffer[..max_segment_size]
})
.map(move |segment| {
let buffer_size = segment.len().min(self.path.mtu() as _);
let buffer = &mut segment[..buffer_size][reversed_size..];
self.load_spaces(data_sources, buffer)
.inspect(|_| {
self.path.heartbeat.renew_on_effective_communicated();
})
.or_else(|error| match error {
BurstError::Signals(signals) => {
self.load_ping(buffer).map_err(|e| match e {
BurstError::Signals(s) => BurstError::Signals(signals | s),
e @ BurstError::PathDeactived => e,
})
}
e @ BurstError::PathDeactived => Err(e),
})
.or_else(|error| match error {
BurstError::Signals(signals) => {
self.load_heartbeat(buffer).map_err(|e| match e {
BurstError::Signals(s) => BurstError::Signals(signals | s),
e @ BurstError::PathDeactived => e,
})
}
e @ BurstError::PathDeactived => Err(e),
})
.map(|packet_size| {
if reversed_size > 0 {
let (mut header, payload) = segment.split_at_mut(reversed_size);
let forward_hdr = ForwardHeader::new(
0,
&self.path.pathway.try_into().unwrap(),
payload,
);
tracing::trace!(?forward_hdr, link=%self.path.link(),"put forward header");
header.put_forward_header(&forward_hdr);
}
io::IoSlice::new(&segment[..reversed_size + packet_size])
})
})
.try_fold(
Ok(Vec::with_capacity(max_segments)),
|segments, load_result| match (segments, load_result) {
(Ok(segments), Err(signals)) if segments.is_empty() => Break(Err(signals)),
(Ok(segments), Err(_signals)) => Break(Ok(segments)),
(Ok(mut segments), Ok(segment))
if segment.len() < segments.last().copied().unwrap_or_default() =>
{
segments.push(segment.len());
Break(Ok(segments))
}
(Ok(mut segments), Ok(segment)) => {
segments.push(segment.len());
Continue(Ok(segments))
}
(Err(_), _) => unreachable!("segments should not be Err in this context"),
},
);
Ok(result?
.iter()
.zip(buffers)
.map(|(&len, buffer)| io::IoSlice::new(&buffer[..len]))
.collect())
}
}