use crate::crypto::hybrid_sign::HybridVerifyingKey;
use crate::errors::CoreError;
use crate::observability::attrs::{AeadAlgorithm, ReplayReason};
use crate::observability::{Observability, ObservabilityConfig};
use crate::runtime::{Runtime, TokioRuntime};
use crate::transport::handshake::{
HandshakeClient, HelloRetryRequest, ServerHello, ServerReject, EARLY_DATA_MAX_LEN,
};
use crate::transport::multiplexer::StreamDemultiplexer;
use crate::transport::packet_coalescer_codec::unwrap_coalesced_packet;
use crate::transport::path_validation_codec::build_path_validation_packet;
use crate::transport::session::Session;
use crate::transport::stream::Stream;
use crate::transport::types::{
LegType, PacketFlags, PacketHeader, PhantomPacket, SessionId, StreamId as TransportStreamId,
WIRE_VERSION,
};
use bytes::Bytes;
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, Mutex};
/// Generates a fresh random session identifier for logging and reporting.
///
/// The result is the literal prefix `phantom-` followed by 16 random bytes
/// rendered with `hex::encode`.
fn new_session_id() -> String {
    let entropy = rand::random::<[u8; 16]>();
    let mut id = String::from("phantom-");
    id.push_str(&hex::encode(entropy));
    id
}
/// Lifecycle state of a session, stored as its `u8` discriminant in an `AtomicU8`.
///
/// The explicit discriminants are the values written to and read from that
/// atomic, so they must remain stable.
#[cfg_attr(feature = "bindings", derive(uniffi::Enum))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum ConnectionState {
    Connecting = 0,
    ClassicalReady = 1,
    PqcUpgrading = 2,
    PqcReady = 3,
    Connected = 4,
    Failed = 5,
    Closed = 6,
}

impl ConnectionState {
    /// Decodes a raw discriminant loaded from the session's state atomic.
    ///
    /// Values outside `0..=6` decode to `Failed` instead of panicking.
    fn from_u8(v: u8) -> Self {
        // Indexed by discriminant; `Self` is not usable inside a nested item.
        const BY_DISCRIMINANT: [ConnectionState; 7] = [
            ConnectionState::Connecting,
            ConnectionState::ClassicalReady,
            ConnectionState::PqcUpgrading,
            ConnectionState::PqcReady,
            ConnectionState::Connected,
            ConnectionState::Failed,
            ConnectionState::Closed,
        ];
        BY_DISCRIMINANT
            .get(usize::from(v))
            .copied()
            .unwrap_or(Self::Failed)
    }

    /// Returns `true` for every state in which application data may be sent
    /// directly, rather than queued or rejected.
    pub fn is_data_ready(&self) -> bool {
        matches!(
            *self,
            Self::Connected | Self::PqcReady | Self::PqcUpgrading | Self::ClassicalReady
        )
    }
}
/// Material a client can retain in order to resume this session later.
///
/// Both fields are raw bytes. `Debug` never prints their contents: the id is
/// shown only by length and the secret is always `"REDACTED"`.
#[cfg_attr(feature = "bindings", derive(uniffi::Record))]
#[derive(Clone)]
#[non_exhaustive]
pub struct ResumptionHint {
    pub session_id: Vec<u8>,
    pub resumption_secret: Vec<u8>,
}

impl std::fmt::Debug for ResumptionHint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id_len = self.session_id.len();
        let mut out = f.debug_struct("ResumptionHint");
        out.field("session_id", &format_args!("<{id_len} bytes>"));
        // Never render the secret, not even its length.
        out.field("resumption_secret", &"REDACTED");
        out.finish()
    }
}
pub use crate::transport::session_transport::{FramePhase, SessionTransport};
/// Decorator around a `SessionTransport` that reports successful traffic to
/// `Observability`, attributed to the leg it was constructed with.
///
/// Only successful sends and receives are recorded. Errors are returned to
/// the caller unchanged and are not counted.
struct ObservedTransport<T> {
    inner: T,
    observability: Arc<Observability>,
    leg: LegType,
}

impl<T> ObservedTransport<T> {
    /// Wraps `inner`, attributing all of its traffic to `leg`.
    fn new(inner: T, observability: Arc<Observability>, leg: LegType) -> Self {
        Self { inner, observability, leg }
    }
}

impl<T: SessionTransport> SessionTransport for ObservedTransport<T> {
    async fn send_bytes(&self, data: &[u8]) -> Result<(), CoreError> {
        let sent = self.inner.send_bytes(data).await;
        if sent.is_ok() {
            self.observability.record_send(data.len(), self.leg);
        }
        sent
    }

    async fn recv_bytes(&self) -> Result<Bytes, CoreError> {
        let received = self.inner.recv_bytes().await;
        if let Ok(payload) = &received {
            self.observability.record_recv(payload.len(), self.leg);
        }
        received
    }
}
/// Handle to one encrypted Phantom session (client- or server-side).
///
/// Most state is behind `Arc`s shared with the background task that runs the
/// handshake and/or data pump. This handle reaches that task through `cmd_tx`
/// and receives application bytes through `recv_rx`.
#[cfg_attr(feature = "bindings", derive(uniffi::Object))]
pub struct PhantomSession {
    /// Random identifier from `new_session_id` (`phantom-<hex>`).
    id: String,
    /// Peer address as supplied by the caller; exposed via `peer_addr()` and `Debug`.
    peer_addr: String,
    /// Current `ConnectionState`, stored as its `u8` discriminant.
    state: Arc<AtomicU8>,
    /// Messages submitted while `Connecting`. The data pump drains them onto the
    /// raw-app stream when it starts.
    send_queue: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Command channel into the data pump (also cloned into opened streams).
    cmd_tx: mpsc::Sender<SessionCommand>,
    /// Only `connect()` parks a receiver here. The spawning constructors store
    /// `None` and hand the receiver to the background task. The field is never
    /// read, hence the `dead_code` allow.
    #[allow(dead_code)]
    cmd_rx: Mutex<Option<mpsc::Receiver<SessionCommand>>>,
    /// Application bytes forwarded by the data pump's delivery task; read by `recv`.
    recv_rx: Mutex<mpsc::Receiver<Bytes>>,
    /// Per-stream routing of inbound data, ACK and close events.
    demux: Arc<StreamDemultiplexer>,
    /// Transport-level stream state keyed by stream id, shared with the pump.
    streams: Arc<DashMap<u32, Arc<Stream>>>,
    /// Negotiated crypto session; `None` until the handshake stores it.
    inner_session: Arc<Mutex<Option<Arc<Session>>>>,
    /// 0-RTT verdict reported by the handshake. `None` until set, and it may
    /// stay `None`.
    early_data_accepted: Arc<Mutex<Option<bool>>>,
    /// Metrics sink shared with the transport wrapper and the data pump.
    observability: Arc<Observability>,
}
/// Instructions sent from a `PhantomSession` (or its streams) to the data pump.
pub enum SessionCommand {
    /// Send bytes on the raw application stream (id 1), split into MTU-sized
    /// reliable chunks.
    Send(Vec<u8>),
    /// Queue `data` as reliable segments on an open stream. Ignored if the
    /// stream id is unknown.
    SendStreamReliable { stream_id: u32, data: bytes::Bytes },
    /// Queue `data` as unreliable segments on an open stream. Ignored if the
    /// stream id is unknown.
    SendStreamUnreliable { stream_id: u32, data: bytes::Bytes },
    /// Finish the stream, send an empty FIN packet, and forget it locally.
    CloseStream { stream_id: u32 },
    /// Stop the data pump.
    Close,
}
impl PhantomSession {
    /// Starts a client session over `transport` on the default Tokio runtime.
    ///
    /// Returns immediately. The handshake and data pump run on a spawned task,
    /// and progress is observable through `connection_state()`.
    /// `expected_server_key` is handed to `process_server_hello` as the key the
    /// server must present.
    pub fn connect_with_transport<T: SessionTransport>(
        peer_addr: &str,
        transport: T,
        expected_server_key: HybridVerifyingKey,
    ) -> Self {
        Self::connect_with_transport_with_runtime(
            peer_addr,
            transport,
            expected_server_key,
            Arc::new(TokioRuntime),
        )
    }

    /// Like `connect_with_transport`, but spawns the background task on `runtime`.
    pub fn connect_with_transport_with_runtime<T: SessionTransport>(
        peer_addr: &str,
        transport: T,
        expected_server_key: HybridVerifyingKey,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        Self::spawn_client(peer_addr, transport, expected_server_key, runtime, None)
    }

    /// Starts a client session that attempts resumption, optionally carrying 0-RTT data.
    ///
    /// `resumption_hint` is the `(session id, resumption secret)` pair. A
    /// non-empty `early_data` is offered as 0-RTT. If the server rejects it, it
    /// is re-queued and sent as ordinary data once connected. Always uses the
    /// default Tokio runtime.
    ///
    /// # Errors
    /// * `CoreError::FipsSelfTestFailure` (only with the `fips` feature) if the
    ///   POST self-tests have not passed.
    /// * `CoreError::ValidationError` if `early_data` exceeds `EARLY_DATA_MAX_LEN`.
    pub fn connect_with_resumption<T: SessionTransport>(
        peer_addr: &str,
        transport: T,
        expected_server_key: HybridVerifyingKey,
        resumption_hint: ([u8; 32], [u8; 32]),
        early_data: Vec<u8>,
    ) -> Result<Self, CoreError> {
        #[cfg(feature = "fips")]
        crate::crypto::self_tests::ensure_post_passed()
            .map_err(|e| CoreError::FipsSelfTestFailure(format!("{e:?}")))?;
        if early_data.len() > EARLY_DATA_MAX_LEN {
            return Err(CoreError::ValidationError(format!(
                "early_data is {} bytes, exceeds the {}-byte 0-RTT cap",
                early_data.len(),
                EARLY_DATA_MAX_LEN
            )));
        }
        let (resume_id, resume_secret) = resumption_hint;
        Ok(Self::spawn_client(
            peer_addr,
            transport,
            expected_server_key,
            Arc::new(TokioRuntime),
            Some((resume_id, resume_secret, early_data)),
        ))
    }

    /// Builds the shared session state, spawns `background_task` on `runtime`,
    /// and returns the user-facing handle in the `Connecting` state.
    fn spawn_client<T: SessionTransport>(
        peer_addr: &str,
        transport: T,
        expected_server_key: HybridVerifyingKey,
        runtime: Arc<dyn Runtime>,
        resumption_request: Option<([u8; 32], [u8; 32], Vec<u8>)>,
    ) -> Self {
        let (cmd_tx, cmd_rx) = mpsc::channel(256);
        let (recv_tx, recv_rx) = mpsc::channel(256);
        let state = Arc::new(AtomicU8::new(ConnectionState::Connecting as u8));
        let send_queue = Arc::new(Mutex::new(Vec::new()));
        let peer = peer_addr.to_string();
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(256);
        let demux = Arc::new(demux);
        let streams = Arc::new(DashMap::new());
        let inner_session: Arc<Mutex<Option<Arc<Session>>>> = Arc::new(Mutex::new(None));
        let early_data_accepted: Arc<Mutex<Option<bool>>> = Arc::new(Mutex::new(None));
        let observability = Observability::new(ObservabilityConfig::default());
        let session = Self {
            id: new_session_id(),
            peer_addr: peer.clone(),
            state: state.clone(),
            send_queue: send_queue.clone(),
            cmd_tx: cmd_tx.clone(),
            cmd_rx: Mutex::new(None),
            recv_rx: Mutex::new(recv_rx),
            demux: demux.clone(),
            streams: streams.clone(),
            inner_session: inner_session.clone(),
            early_data_accepted: early_data_accepted.clone(),
            observability: observability.clone(),
        };
        let runtime_for_pump = runtime.clone();
        let _detached = runtime.spawn(Box::pin(Self::background_task(
            state,
            send_queue,
            cmd_tx,
            cmd_rx,
            recv_tx,
            transport,
            peer,
            demux,
            streams,
            expected_server_key,
            runtime_for_pump,
            inner_session,
            early_data_accepted,
            resumption_request,
            observability,
        )));
        session
    }

    /// Wraps an already-handshaken server session on the default Tokio runtime
    /// with a fresh `Observability`.
    // May be unused in some build configurations, hence the allow.
    #[allow(dead_code)]
    pub(crate) fn from_accepted_server_session<T: SessionTransport>(
        peer_addr: String,
        transport: T,
        server_session: Arc<Session>,
    ) -> Arc<Self> {
        Self::from_accepted_server_session_with_runtime(
            peer_addr,
            transport,
            server_session,
            Arc::new(TokioRuntime),
            Observability::new(ObservabilityConfig::default()),
        )
    }

    /// Wraps an already-handshaken server session and immediately starts its data pump.
    ///
    /// The handle starts in `Connected`. The transport is switched to
    /// `FramePhase::Established`, wrapped in `ObservedTransport` (TCP leg), and
    /// driven by `run_data_pump` on `runtime`.
    pub(crate) fn from_accepted_server_session_with_runtime<T: SessionTransport>(
        peer_addr: String,
        transport: T,
        server_session: Arc<Session>,
        runtime: Arc<dyn Runtime>,
        observability: Arc<Observability>,
    ) -> Arc<Self> {
        let (cmd_tx, cmd_rx) = mpsc::channel(256);
        let (recv_tx, recv_rx) = mpsc::channel(256);
        let state = Arc::new(AtomicU8::new(ConnectionState::Connected as u8));
        let send_queue = Arc::new(Mutex::new(Vec::new()));
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(256);
        let demux = Arc::new(demux);
        let streams = Arc::new(DashMap::new());
        let inner_session: Arc<Mutex<Option<Arc<Session>>>> =
            Arc::new(Mutex::new(Some(server_session.clone())));
        let session = Arc::new(Self {
            id: new_session_id(),
            peer_addr: peer_addr.clone(),
            state: state.clone(),
            send_queue: send_queue.clone(),
            cmd_tx,
            cmd_rx: Mutex::new(None),
            recv_rx: Mutex::new(recv_rx),
            demux: demux.clone(),
            streams: streams.clone(),
            inner_session,
            early_data_accepted: Arc::new(Mutex::new(None)),
            observability: observability.clone(),
        });
        let session_id = *server_session.id();
        let runtime_for_pump = runtime.clone();
        transport.set_frame_phase(FramePhase::Established);
        let observed = Arc::new(ObservedTransport::new(
            transport,
            observability.clone(),
            LegType::Tcp,
        ));
        let _detached = runtime.spawn(Box::pin(run_data_pump(
            server_session,
            session_id,
            observed,
            state,
            send_queue,
            cmd_rx,
            recv_tx,
            demux,
            streams,
            runtime_for_pump,
            observability,
            LegType::Tcp,
        )));
        session
    }

    /// Client-side driver: runs the handshake, publishes its results, then runs
    /// the data pump until the session ends.
    ///
    /// * With the `fips` feature, refuses to start unless the POST self-tests passed.
    /// * The handshake is bounded by a 10 s deadline measured with `runtime.sleep`.
    /// * On handshake failure the state becomes `Failed` and the task returns.
    ///   That drops `recv_tx`, so `recv` reports the session as closed.
    /// * On success the crypto session and 0-RTT verdict are stored. If the
    ///   server rejected the early data, it goes to the front of `send_queue`
    ///   so the pump sends it as ordinary data.
    #[allow(clippy::too_many_arguments)]
    async fn background_task<T: SessionTransport>(
        state: Arc<AtomicU8>,
        send_queue: Arc<Mutex<Vec<Vec<u8>>>>,
        _cmd_tx: mpsc::Sender<SessionCommand>,
        cmd_rx: mpsc::Receiver<SessionCommand>,
        recv_tx: mpsc::Sender<Bytes>,
        transport: T,
        peer: String,
        demux: Arc<StreamDemultiplexer>,
        streams: Arc<DashMap<u32, Arc<Stream>>>,
        expected_server_key: HybridVerifyingKey,
        runtime: Arc<dyn Runtime>,
        inner_session: Arc<Mutex<Option<Arc<Session>>>>,
        early_data_accepted: Arc<Mutex<Option<bool>>>,
        resumption_request: Option<([u8; 32], [u8; 32], Vec<u8>)>,
        observability: Arc<Observability>,
    ) {
        // This task never issues commands. Holding a sender for the lifetime of
        // the data pump would keep `cmd_rx` open forever, so the pump could never
        // see "all senders dropped". It would then outlive every user handle.
        // Release it up front.
        drop(_cmd_tx);
        log::debug!("PhantomSession: starting handshake with {}", peer);
        #[cfg(feature = "fips")]
        if let Err(e) = crate::crypto::self_tests::ensure_post_passed() {
            log::error!(
                "PhantomSession: FIPS POST self-test failed; refusing to handshake: {:?}",
                e
            );
            state.store(ConnectionState::Failed as u8, Ordering::Relaxed);
            return;
        }
        // Keep a copy of non-empty 0-RTT data in case the server rejects it.
        let pending_early_data: Option<Vec<u8>> = resumption_request
            .as_ref()
            .and_then(|(_, _, ed)| (!ed.is_empty()).then(|| ed.clone()));
        const CLIENT_HANDSHAKE_DEADLINE: std::time::Duration = std::time::Duration::from_secs(10);
        let handshake_result = {
            let handshake_fut =
                run_client_handshake(&transport, &expected_server_key, resumption_request);
            let handshake_timeout = runtime.sleep(CLIENT_HANDSHAKE_DEADLINE);
            tokio::pin!(handshake_fut);
            tokio::select! {
                r = &mut handshake_fut => r,
                _ = handshake_timeout => Err(CoreError::Timeout),
            }
        };
        let (crypto_session, ed_accepted) = match handshake_result {
            Ok((session, accepted)) => (Arc::new(session), accepted),
            Err(e) => {
                log::error!("PhantomSession: handshake failed: {}", e);
                state.store(ConnectionState::Failed as u8, Ordering::Relaxed);
                return;
            }
        };
        log::info!("PhantomSession: Handshake complete — hybrid channel ready");
        {
            let mut guard = inner_session.lock().await;
            *guard = Some(crypto_session.clone());
        }
        *early_data_accepted.lock().await = ed_accepted;
        if ed_accepted == Some(false) {
            if let Some(ed) = pending_early_data {
                // Front of the queue so it precedes anything sent while connecting.
                send_queue.lock().await.insert(0, ed);
                log::debug!(
                    "PhantomSession: 0-RTT early-data rejected; re-queued for 1-RTT delivery"
                );
            }
        }
        let session_id = *crypto_session.id();
        state.store(ConnectionState::Connected as u8, Ordering::Relaxed);
        log::debug!("PhantomSession: fully connected to {}", peer);
        transport.set_frame_phase(FramePhase::Established);
        let observed = Arc::new(ObservedTransport::new(
            transport,
            observability.clone(),
            LegType::Tcp,
        ));
        run_data_pump(
            crypto_session,
            session_id,
            observed,
            state,
            send_queue,
            cmd_rx,
            recv_tx,
            demux,
            streams,
            runtime,
            observability,
            LegType::Tcp,
        )
        .await;
    }
}
/// Drives the client side of the handshake over `transport` until a session exists.
///
/// Sends a `ClientHello`. When `resumption` is `Some`, this is a resumption
/// hello carrying its early data, or none if that is empty. Each reply is
/// interpreted, in order of precedence, as `ServerHello`, `ServerReject`, or
/// `HelloRetryRequest`. Up to `MAX_CLIENT_RETRY_ROUNDS` retries are honoured.
/// Each retry echoes the server's cookie and solves any proof-of-work challenge,
/// capped at `MAX_CLIENT_POW_DIFFICULTY`.
///
/// Returns the established session and the 0-RTT verdict produced by
/// `process_server_hello`.
///
/// # Errors
/// Propagates transport, serialization and `process_server_hello` errors.
/// Returns `HandshakeError` for a rejection, an unrecognised reply, too many
/// retry rounds, or a failed PoW solve.
async fn run_client_handshake<T: SessionTransport>(
    transport: &T,
    expected_server_key: &HybridVerifyingKey,
    resumption: Option<([u8; 32], [u8; 32], Vec<u8>)>,
) -> Result<(Session, Option<bool>), CoreError> {
    const MAX_CLIENT_RETRY_ROUNDS: u32 = 3;
    let invalid_reply = || {
        CoreError::HandshakeError("invalid ServerHello, Retry, or Reject received".into())
    };

    let handshake = HandshakeClient::new()?;
    let mut hello = if let Some((resume_id, resume_secret, early_data)) = &resumption {
        let early = (!early_data.is_empty()).then_some(early_data.as_slice());
        handshake.create_client_hello_with_resume(*resume_id, resume_secret, early)
    } else {
        handshake.create_client_hello()
    };

    let mut retries_taken: u32 = 0;
    loop {
        let encoded = borsh::to_vec(&hello).map_err(|e| {
            CoreError::SerializationError(format!("ClientHello encode failed: {}", e))
        })?;
        transport.send_bytes(&encoded).await?;
        let reply = transport.recv_bytes().await?;

        if let Ok(server_hello) = borsh::from_slice::<ServerHello>(&reply) {
            let (session, accepted) =
                handshake.process_server_hello(&hello, &server_hello, Some(expected_server_key))?;
            return Ok((session, accepted));
        }

        if let Ok(reject) = borsh::from_slice::<ServerReject>(&reply) {
            return Err(if reject.has_marker() {
                CoreError::HandshakeError(format!(
                    "server rejected the handshake: unsupported protocol version \
                     (client speaks v{}, server speaks v{})",
                    hello.version, reject.supported_version
                ))
            } else {
                invalid_reply()
            });
        }

        let Ok(retry) = borsh::from_slice::<HelloRetryRequest>(&reply) else {
            return Err(invalid_reply());
        };

        retries_taken += 1;
        if retries_taken > MAX_CLIENT_RETRY_ROUNDS {
            return Err(CoreError::HandshakeError(format!(
                "server demanded more than {MAX_CLIENT_RETRY_ROUNDS} HelloRetryRequest rounds"
            )));
        }
        log::info!("PhantomSession: Received HelloRetryRequest, retrying...");
        hello.cookie = retry.cookie;
        if let Some(challenge) = retry.challenge {
            log::info!("PhantomSession: Solving PoW challenge...");
            let solution = challenge
                .solve_capped(crate::crypto::pow::MAX_CLIENT_POW_DIFFICULTY)
                .map_err(|e| CoreError::HandshakeError(e.to_string()))?;
            hello.pow_solution = Some(solution);
        }
    }
}
/// Runs the post-handshake data plane of one session until it is closed.
///
/// * Registers the raw application stream (id `1`). Messages queued while
///   connecting are moved onto it in `TRANSPORT_MTU`-sized reliable chunks.
/// * Spawns a delivery task. It routes each decrypted payload to the
///   demultiplexer and to `recv_tx`, returns flow-control credit per stream,
///   and decrements `undelivered_bytes`.
/// * Spawns a receive task that reads from `transport`. It skips packets that
///   fail to parse or carry a foreign wire version, and hands the rest to
///   `handle_packet`. It stops on a transport error, or when more than 4 MiB
///   is waiting to be delivered.
/// * On this task, flushes window updates and drains streams on a 10 ms tick
///   or when the session signals outbound readiness, and executes
///   `SessionCommand`s.
///
/// The loop ends on `SessionCommand::Close`, when every command sender is
/// dropped, or when the receive task finishes. The session is then marked
/// `Closed`.
///
/// `Arc<Stream>` handles are cloned out of the DashMap before any `.await`, so
/// no shard lock is held across a suspension point. The receive task takes
/// those shard locks concurrently.
#[allow(clippy::too_many_arguments)]
async fn run_data_pump<T: SessionTransport>(
    crypto_session: Arc<Session>,
    session_id: SessionId,
    transport: Arc<T>,
    state: Arc<AtomicU8>,
    send_queue: Arc<Mutex<Vec<Vec<u8>>>>,
    mut cmd_rx: mpsc::Receiver<SessionCommand>,
    recv_tx: mpsc::Sender<Bytes>,
    demux: Arc<StreamDemultiplexer>,
    streams: Arc<DashMap<u32, Arc<Stream>>>,
    runtime: Arc<dyn Runtime>,
    observability: Arc<Observability>,
    leg: LegType,
) {
    // Largest application chunk placed in a single stream segment.
    const TRANSPORT_MTU: usize = 1300;
    // Stream carrying `PhantomSession::send` / `recv` traffic.
    const RAW_APP_STREAM_ID: u32 = 1;

    observability.session_opened(leg);
    let raw_stream = Arc::new(Stream::new(RAW_APP_STREAM_ID as TransportStreamId));
    streams.insert(RAW_APP_STREAM_ID, raw_stream.clone());

    // One-time drain of messages queued before the pump existed, including
    // re-queued 0-RTT data. Done under the queue lock.
    {
        let mut queue = send_queue.lock().await;
        let count = queue.len();
        for msg in queue.drain(..) {
            for chunk in msg.chunks(TRANSPORT_MTU) {
                raw_stream
                    .send_reliable(Bytes::copy_from_slice(chunk))
                    .await;
            }
        }
        if count > 0 {
            log::info!(
                "PhantomSession: queued {} early-data message(s) onto the raw-app stream",
                count
            );
            crypto_session.notify_outbound_ready();
        }
    }

    // Delivery task: decouples the receive loop from the (bounded) app channel.
    let (deliver_tx, mut deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
    let undelivered_bytes = Arc::new(AtomicU64::new(0));
    {
        let recv_tx_deliver = recv_tx;
        let demux_deliver = demux.clone();
        let streams_deliver = streams.clone();
        let crypto_deliver = crypto_session.clone();
        let undelivered_deliver = undelivered_bytes.clone();
        runtime.spawn(Box::pin(async move {
            while let Some((stream_id, bytes)) = deliver_rx.recv().await {
                let len = bytes.len() as u64;
                demux_deliver.route_data(stream_id, bytes.clone());
                undelivered_deliver.fetch_sub(len, Ordering::AcqRel);
                // No `.await` while this guard is alive.
                if let Some(stream) = streams_deliver.get(&stream_id) {
                    if let Some(credit) = stream.record_app_consumed(len as u32) {
                        stream.stage_window_update_credit(credit);
                        crypto_deliver.notify_outbound_ready();
                    }
                }
                if recv_tx_deliver.send(bytes).await.is_err() {
                    break;
                }
            }
        }));
    }

    // Receive task.
    let transport_recv = transport.clone();
    let transport_send_ack = transport.clone();
    let crypto_recv = crypto_session.clone();
    let demux_recv = demux.clone();
    let streams_recv = streams.clone();
    let undelivered_reader = undelivered_bytes.clone();
    let observability_recv = observability.clone();
    let (recv_done_tx, mut recv_done_rx) = oneshot::channel::<()>();
    let transport_for_path = transport.clone();
    let recv_handle = runtime.spawn(Box::pin(async move {
        let mut ack_buf: Vec<u8> = Vec::with_capacity(256);
        let mut path_validation_seq: u32 = 0;
        const RECV_DELIVERY_HARD_CAP: u64 = 4 * 1024 * 1024;
        loop {
            if undelivered_reader.load(Ordering::Acquire) > RECV_DELIVERY_HARD_CAP {
                log::warn!(
                    "PhantomSession: receive backlog {} B exceeds cap — peer ignoring flow \
                     control; closing session",
                    undelivered_reader.load(Ordering::Acquire)
                );
                break;
            }
            let data = match transport_recv.recv_bytes().await {
                Ok(b) => b,
                Err(_) => break,
            };
            let packet = match PhantomPacket::from_wire(&data) {
                Ok(v) => v,
                Err(_) => continue,
            };
            if packet.header.version != WIRE_VERSION {
                continue;
            }
            handle_packet(
                packet,
                session_id,
                &crypto_recv,
                &streams_recv,
                &demux_recv,
                &transport_send_ack,
                &transport_for_path,
                &deliver_tx,
                &undelivered_reader,
                &mut ack_buf,
                &mut path_validation_seq,
                &observability_recv,
                leg,
            )
            .await;
        }
        // Dropping the sender lets the delivery task finish.
        drop(deliver_tx);
        let _ = recv_done_tx.send(());
    }));

    let mut poll_interval = tokio::time::interval(std::time::Duration::from_millis(10));
    let send_notify = crypto_session.send_notifier();
    loop {
        tokio::select! {
            _ = poll_interval.tick() => {
                flush_pending_window_updates(
                    &transport, &crypto_session, session_id, &streams,
                )
                .await;
                drain_streams_priority_ordered(
                    &transport,
                    &crypto_session,
                    session_id,
                    &streams,
                )
                .await;
            }
            _ = send_notify.notified() => {
                flush_pending_window_updates(
                    &transport, &crypto_session, session_id, &streams,
                )
                .await;
                drain_streams_priority_ordered(
                    &transport,
                    &crypto_session,
                    session_id,
                    &streams,
                )
                .await;
            }
            cmd_opt = cmd_rx.recv() => {
                match cmd_opt {
                    Some(SessionCommand::Send(data)) => {
                        for chunk in data.chunks(TRANSPORT_MTU) {
                            raw_stream
                                .send_reliable(Bytes::copy_from_slice(chunk))
                                .await;
                        }
                        crypto_session.notify_outbound_ready();
                    }
                    Some(SessionCommand::SendStreamReliable { stream_id, data }) => {
                        // Clone the Arc out so the shard guard is released before awaiting.
                        let target = streams.get(&stream_id).map(|entry| entry.value().clone());
                        if let Some(stream) = target {
                            for chunk in data.chunks(TRANSPORT_MTU) {
                                stream.send_reliable(Bytes::copy_from_slice(chunk)).await;
                            }
                            crypto_session.notify_outbound_ready();
                        }
                    }
                    Some(SessionCommand::SendStreamUnreliable { stream_id, data }) => {
                        let target = streams.get(&stream_id).map(|entry| entry.value().clone());
                        if let Some(stream) = target {
                            for chunk in data.chunks(TRANSPORT_MTU) {
                                stream.send_unreliable(Bytes::copy_from_slice(chunk)).await;
                            }
                            crypto_session.notify_outbound_ready();
                        }
                    }
                    Some(SessionCommand::CloseStream { stream_id }) => {
                        let closing = streams.get(&stream_id).map(|entry| entry.value().clone());
                        if let Some(stream) = closing {
                            stream.finish().await;
                            let seq = stream.next_send_sequence();
                            let _ = send_app_data(
                                &transport,
                                &crypto_session,
                                session_id,
                                stream_id as TransportStreamId,
                                seq,
                                &[],
                                PacketFlags::FIN,
                            ).await;
                        }
                        streams.remove(&stream_id);
                        demux.close_stream(stream_id);
                    }
                    Some(SessionCommand::Close) => {
                        log::info!("PhantomSession: closing");
                        break;
                    }
                    None => {
                        log::info!("PhantomSession: command channel dropped");
                        break;
                    }
                }
            }
            _ = &mut recv_done_rx => {
                log::error!("PhantomSession: receive task ended unexpectedly (transport closed)");
                break;
            }
        }
    }
    recv_handle.abort();
    state.store(ConnectionState::Closed as u8, Ordering::Relaxed);
    observability.session_closed(leg);
}
/// Sends a WINDOW_UPDATE for every stream that has staged flow-control credit.
///
/// Credit is first taken from each stream and snapshotted into a `Vec`, so no
/// DashMap guard is held across `.await`. If sending an update fails, that
/// credit is re-staged so a later flush retries it.
async fn flush_pending_window_updates<T: SessionTransport>(
    transport: &Arc<T>,
    crypto_session: &Arc<Session>,
    session_id: SessionId,
    streams: &Arc<DashMap<u32, Arc<Stream>>>,
) {
    let mut due: Vec<(u32, u32, Arc<Stream>)> = Vec::new();
    for entry in streams.iter() {
        if let Some(credit) = entry.value().take_pending_window_update() {
            due.push((*entry.key(), credit, Arc::clone(entry.value())));
        }
    }

    for (stream_id, credit, stream) in due {
        let sequence = stream.next_send_sequence();
        let delivered = send_window_update(
            transport,
            crypto_session,
            session_id,
            stream_id as TransportStreamId,
            sequence,
            credit,
        )
        .await;
        if !delivered {
            stream.stage_window_update_credit(credit);
        }
    }
}
/// Sends queued segments from every stream, highest priority first.
///
/// Streams are snapshotted as `(priority, id, Arc<Stream>)`, so no DashMap
/// guard is held while sending. They are ordered by descending priority, with
/// ascending stream id as the tie-breaker. Each stream is polled until it has
/// nothing to send within the current congestion budget
/// (`cwnd - inflight`, saturating) or a send fails.
///
/// Retransmitted segments are reported to the session as lost before being
/// re-sent. A reliable segment whose send fails is marked unsent for a later
/// retry.
async fn drain_streams_priority_ordered<T: SessionTransport>(
    transport: &Arc<T>,
    crypto_session: &Arc<Session>,
    session_id: SessionId,
    streams: &Arc<DashMap<u32, Arc<Stream>>>,
) {
    let mut order: Vec<(u32, u32, Arc<Stream>)> = streams
        .iter()
        .map(|entry| (entry.value().priority(), *entry.key(), Arc::clone(entry.value())))
        .collect();
    // Ids are unique, so this key yields a total order identical to the
    // "priority desc, id asc" comparator.
    order.sort_by_key(|(priority, id, _)| (std::cmp::Reverse(*priority), *id));

    for (_, stream_id, stream) in order {
        loop {
            let bandwidth = crypto_session.bandwidth_snapshot();
            let budget = bandwidth.cwnd_bytes.saturating_sub(bandwidth.inflight_bytes);
            let seg = match stream.poll_send(budget).await {
                Some(seg) => seg,
                None => break,
            };
            if seg.retransmit {
                crypto_session.on_packet_lost(seg.data.len() as u64);
            }
            let base_flags = if seg.reliable {
                PacketFlags::RELIABLE
            } else {
                PacketFlags::UNRELIABLE
            };
            let sent = send_app_data(
                transport,
                crypto_session,
                session_id,
                stream_id as TransportStreamId,
                seg.seq,
                &seg.data,
                base_flags,
            )
            .await;
            if sent {
                continue;
            }
            log::error!("PhantomSession: priority-ordered drain send failed");
            if seg.reliable {
                stream.mark_unsent(seg.seq).await;
            }
            break;
        }
    }
}
/// Feeds one acknowledged packet into the session's bandwidth estimator.
///
/// Builds a `DeliverySample` spanning `sent_at` to now for `packet_bytes`
/// bytes, with the peer-reported `ack_delay_us`. The estimator's return value
/// is discarded.
// NOTE(review): `delivered_bytes` is always 0 and `is_app_limited` always
// false here — confirm the estimator does not rely on them.
fn feed_bbr_on_ack(
    crypto_session: &Arc<Session>,
    sent_at: tokio::time::Instant,
    packet_bytes: u64,
    ack_delay_us: u64,
) {
    let sent = sent_at.into_std();
    let acked_at = std::time::Instant::now();
    let sample = crate::transport::bandwidth_estimator::DeliverySample {
        delivered_bytes: 0,
        sent_at: sent,
        acked_at,
        packet_bytes,
        is_app_limited: false,
        ack_delay_us,
    };
    let _ = crypto_session.on_packet_acked(sample);
}
/// Waits until the session's pacer admits `bytes` bytes, then consumes them.
///
/// Returns immediately when pacing is disabled. Otherwise it retries
/// `try_consume`, sleeping for the pacer's estimated wait (capped at 50 ms)
/// between attempts.
///
/// If the pacer refuses the tokens but reports a zero wait, the task yields to
/// the executor before retrying. Previously it looped straight back, which
/// could spin a worker thread without ever suspending.
async fn pace_send(crypto_session: &Arc<Session>, bytes: u64) {
    const MAX_PACING_SLEEP: std::time::Duration = std::time::Duration::from_millis(50);
    let pacer = crypto_session.pacer();
    if !pacer.is_enabled() {
        return;
    }
    while !pacer.try_consume(bytes) {
        let wait = pacer.time_until_available(bytes);
        if wait.is_zero() {
            tokio::task::yield_now().await;
        } else {
            tokio::time::sleep(wait.min(MAX_PACING_SLEEP)).await;
        }
    }
}
/// Rekeys the session, if required, before a packet is stamped with `sequence` on `stream_id`.
///
/// Returns the extra header flag bits to OR in:
/// * `Some(PacketFlags::REKEY)` after a successful rekey;
/// * `Some(0)` when no rekey was needed;
/// * `None` when a required rekey failed (logged). The caller should not send.
fn rekey_before_stamp(
    crypto_session: &Arc<Session>,
    stream_id: TransportStreamId,
    sequence: u32,
) -> Option<u16> {
    let must_rekey = crypto_session.send_needs_rekey()
        || crypto_session.stream_seq_needs_rekey(stream_id, sequence);
    if !must_rekey {
        return Some(0);
    }
    crypto_session
        .rekey()
        .map(|_| PacketFlags::REKEY)
        .map_err(|e| {
            log::error!("PhantomSession: mid-session rekey failed: {}", e);
        })
        .ok()
}
/// Encrypts `payload` into one packet on `stream_id` / `sequence`, paces it, and sends it.
///
/// The header flags are `base_flags | ENCRYPTED`, plus `REKEY` if a rekey
/// happened first, and the header carries the current epoch. On success the
/// payload length (not the wire length) is reported via `on_packet_sent`.
/// Returns `false`, after logging, if the rekey, encryption, or transport send
/// fails.
async fn send_app_data<T: SessionTransport>(
    transport: &Arc<T>,
    crypto_session: &Arc<Session>,
    session_id: SessionId,
    stream_id: TransportStreamId,
    sequence: u32,
    payload: &[u8],
    base_flags: u16,
) -> bool {
    let Some(rekey_bits) = rekey_before_stamp(crypto_session, stream_id, sequence) else {
        return false;
    };
    let flags = PacketFlags::new(base_flags | PacketFlags::ENCRYPTED | rekey_bits);
    let header = PacketHeader::new(session_id, stream_id, sequence, flags)
        .with_epoch(crypto_session.current_epoch());

    let ciphertext = match crypto_session.encrypt_packet(&header, payload) {
        Ok(sealed) => sealed,
        Err(e) => {
            log::error!("PhantomSession: encrypt_packet failed: {}", e);
            return false;
        }
    };

    let wire = PhantomPacket::new(header, ciphertext).to_wire();
    pace_send(crypto_session, wire.len() as u64).await;
    match transport.send_bytes(&wire).await {
        Ok(_) => {
            crypto_session.on_packet_sent(payload.len() as u64);
            true
        }
        Err(e) => {
            log::error!("PhantomSession: transport send failed: {}", e);
            false
        }
    }
}
/// Sends an encrypted WINDOW_UPDATE granting `new_window` bytes of credit on `stream_id`.
///
/// The payload is `new_window` as 4 big-endian bytes. Unlike `send_app_data`,
/// this is neither paced nor reported to `on_packet_sent`. Returns `false`,
/// after logging, if the rekey, encryption, or transport send fails.
async fn send_window_update<T: SessionTransport>(
    transport: &Arc<T>,
    crypto_session: &Arc<Session>,
    session_id: SessionId,
    stream_id: TransportStreamId,
    sequence: u32,
    new_window: u32,
) -> bool {
    let Some(rekey_bits) = rekey_before_stamp(crypto_session, stream_id, sequence) else {
        return false;
    };
    let flags =
        PacketFlags::new(PacketFlags::ENCRYPTED | PacketFlags::WINDOW_UPDATE | rekey_bits);
    let header = PacketHeader::new(session_id, stream_id, sequence, flags)
        .with_epoch(crypto_session.current_epoch());

    let credit_bytes = new_window.to_be_bytes();
    let ciphertext = match crypto_session.encrypt_packet(&header, &credit_bytes) {
        Ok(sealed) => sealed,
        Err(e) => {
            log::error!("PhantomSession: WINDOW_UPDATE encrypt failed: {}", e);
            return false;
        }
    };

    let wire = PhantomPacket::new(header, ciphertext).to_wire();
    match transport.send_bytes(&wire).await {
        Ok(_) => true,
        Err(e) => {
            log::error!("PhantomSession: WINDOW_UPDATE send failed: {}", e);
            false
        }
    }
}
/// Sends an encrypted PATH_VALIDATION packet for `path_id` carrying `payload`.
///
/// The packet is built by `build_path_validation_packet`. `ENCRYPTED` is then
/// added to its flags, it is stamped with the current epoch, and its payload
/// is replaced by the ciphertext. No rekey check or pacing is applied. Returns
/// `false`, after logging, on encryption or transport failure.
async fn send_path_validation<T: SessionTransport>(
    transport: &Arc<T>,
    crypto_session: &Arc<Session>,
    session_id: SessionId,
    path_id: u8,
    sequence: u32,
    payload: [u8; crate::transport::path::PATH_CHALLENGE_LEN],
) -> bool {
    let mut packet = build_path_validation_packet(session_id, path_id, sequence, payload);
    packet.header.flags = PacketFlags::new(packet.header.flags.0 | PacketFlags::ENCRYPTED);
    packet.header.epoch = crypto_session.current_epoch();

    let plaintext = std::mem::take(&mut packet.payload);
    let sealed = crypto_session.encrypt_packet(&packet.header, &plaintext);
    packet.payload = match sealed {
        Ok(ciphertext) => ciphertext,
        Err(e) => {
            log::error!("PhantomSession: PATH_VALIDATION encrypt failed: {}", e);
            return false;
        }
    };

    let wire = packet.to_wire();
    match transport.send_bytes(&wire).await {
        Ok(_) => true,
        Err(e) => {
            log::error!("PhantomSession: PATH_VALIDATION send failed: {}", e);
            false
        }
    }
}
/// Processes one inbound packet that has already been parsed and version-checked.
///
/// In order:
/// 1. Drops packets for a different `session_id`, then marks `path_id` as seen.
/// 2. Decrypts `ENCRYPTED` packets. Replays are counted via
///    `record_replay_rejected` and other failures via `record_aead_failure`;
///    either way the packet is dropped. Unencrypted packets with a non-empty
///    payload are dropped as a possible downgrade.
/// 3. Handles control packets:
///    * `ACK`: 4-byte big-endian acked sequence. Acks the stream, feeds the
///      bandwidth estimator, routes the ack to the demux, and closes on `FIN`.
///    * `WINDOW_UPDATE`: 4-byte credit.
///    * `PATH_VALIDATION`: completes a validating path, ignores
///      validated/failed paths, and otherwise echoes the challenge back.
/// 4. Accepts application data only on validated paths. Otherwise the path is
///    registered as unvalidated and the packet dropped.
/// 5. `COALESCED` bundles are split and each non-empty part queued for
///    delivery. These return without sending an ACK or handling `FIN`.
///    Otherwise, `RELIABLE` packets are ACKed on the same path, the payload is
///    queued for delivery, and `FIN` closes the stream in the demux.
///
/// Bytes accepted by `deliver_tx` are added to `undelivered_bytes`.
#[allow(clippy::too_many_arguments)]
async fn handle_packet<T: SessionTransport>(
    packet: PhantomPacket,
    session_id: SessionId,
    crypto_recv: &Arc<Session>,
    streams_recv: &Arc<DashMap<u32, Arc<Stream>>>,
    demux_recv: &Arc<StreamDemultiplexer>,
    transport_send_ack: &Arc<T>,
    transport_for_path: &Arc<T>,
    deliver_tx: &mpsc::UnboundedSender<(u32, Bytes)>,
    undelivered_bytes: &AtomicU64,
    ack_buf: &mut Vec<u8>,
    path_validation_seq: &mut u32,
    observability: &Observability,
    leg: LegType,
) {
    let stream_id: u32 = packet.header.stream_id.into();
    let path_id = packet.header.path_id;
    if packet.header.session_id != session_id {
        return;
    }
    crypto_recv.mark_path_seen(path_id);
    let plaintext: Vec<u8> = if packet.header.flags.contains(PacketFlags::ENCRYPTED) {
        match crypto_recv.decrypt_packet_accepting_rekey(&packet.header, &packet.payload) {
            Ok(pt) => pt,
            Err(e) => {
                if matches!(e, CoreError::ReplayDetected(_)) {
                    observability.record_replay_rejected(ReplayReason::Duplicate);
                } else {
                    observability.record_aead_failure(leg, AeadAlgorithm::Aes256Gcm);
                }
                log::warn!("PhantomSession: V2 decrypt failed (dropping packet): {}", e);
                return;
            }
        }
    } else if !packet.payload.is_empty() {
        observability.record_unencrypted_dropped(leg);
        log::warn!(
            "PhantomSession: dropping unencrypted V2 post-handshake data packet (downgrade?)"
        );
        return;
    } else {
        Vec::new()
    };

    if packet.header.flags.contains(PacketFlags::ACK) {
        if plaintext.len() != 4 {
            log::warn!(
                "PhantomSession: ACK payload length {} (expected 4)",
                plaintext.len()
            );
            return;
        }
        let acked_seq =
            u32::from_be_bytes([plaintext[0], plaintext[1], plaintext[2], plaintext[3]]);
        // Clone the Arc out of the map so no DashMap shard guard is held across
        // `ack().await`. The pump task locks shards concurrently.
        let acked_stream = streams_recv.get(&stream_id).map(|entry| entry.value().clone());
        if let Some(stream) = acked_stream {
            if let Some((sent_at, bytes)) = stream.ack(acked_seq).await {
                feed_bbr_on_ack(crypto_recv, sent_at, bytes, packet.header.ack_delay as u64);
            }
        }
        demux_recv.route_ack(stream_id, acked_seq);
        if packet.header.flags.contains(PacketFlags::FIN) {
            demux_recv.route_close(stream_id);
        }
        return;
    }

    if packet.header.flags.contains(PacketFlags::WINDOW_UPDATE) {
        if plaintext.len() != 4 {
            log::warn!(
                "PhantomSession: WINDOW_UPDATE payload length {} (expected 4)",
                plaintext.len()
            );
            return;
        }
        let credit = u32::from_be_bytes([plaintext[0], plaintext[1], plaintext[2], plaintext[3]]);
        // No `.await` while this guard is alive.
        if let Some(stream) = streams_recv.get(&stream_id) {
            stream.apply_peer_window_update(credit);
            crypto_recv.notify_outbound_ready();
        }
        return;
    }

    if packet.header.flags.contains(PacketFlags::PATH_VALIDATION) {
        if plaintext.len() != crate::transport::path::PATH_CHALLENGE_LEN {
            log::warn!(
                "PhantomSession: PATH_VALIDATION plaintext length {} (expected {})",
                plaintext.len(),
                crate::transport::path::PATH_CHALLENGE_LEN
            );
            return;
        }
        let mut payload_buf = [0u8; crate::transport::path::PATH_CHALLENGE_LEN];
        payload_buf.copy_from_slice(&plaintext);
        match crypto_recv.path_state(path_id) {
            Some(crate::transport::path::PathStateKind::Validating) => {
                // This is the response to our own challenge.
                let _ = crypto_recv.complete_path_validation(path_id, &payload_buf);
                return;
            }
            Some(crate::transport::path::PathStateKind::Validated)
            | Some(crate::transport::path::PathStateKind::Failed) => {
                return;
            }
            _ => {
                // Peer-initiated challenge: echo the same bytes back on this path.
                let seq = *path_validation_seq;
                *path_validation_seq = path_validation_seq.wrapping_add(1);
                let _ = send_path_validation(
                    transport_for_path,
                    crypto_recv,
                    session_id,
                    path_id,
                    seq,
                    payload_buf,
                )
                .await;
                return;
            }
        }
    }

    if !matches!(
        crypto_recv.path_state(path_id),
        Some(crate::transport::path::PathStateKind::Validated)
    ) {
        crypto_recv.register_unvalidated_path(path_id);
        log::warn!(
            "PhantomSession: dropping application data on non-validated path_id {}",
            path_id
        );
        return;
    }

    if packet.header.flags.contains(PacketFlags::COALESCED) {
        let inner_for_codec = PhantomPacket {
            header: packet.header,
            payload: plaintext,
            extensions: Vec::new(),
        };
        match unwrap_coalesced_packet(&inner_for_codec) {
            Ok(Some(subs)) => {
                for sub in subs {
                    if sub.is_empty() {
                        continue;
                    }
                    let len = sub.len() as u64;
                    if deliver_tx.send((stream_id, Bytes::from(sub))).is_ok() {
                        undelivered_bytes.fetch_add(len, Ordering::AcqRel);
                    }
                }
            }
            Ok(None) => {
                log::warn!("PhantomSession: COALESCED flag set but bundle didn't parse");
            }
            Err(e) => {
                log::warn!("PhantomSession: COALESCED parse error: {}", e);
            }
        }
        return;
    }

    if packet.header.flags.contains(PacketFlags::RELIABLE) {
        // `.clone()` yields an owned Arc, so the entry guard is released here.
        let local = streams_recv
            .entry(stream_id)
            .or_insert_with(|| Arc::new(Stream::new(stream_id as TransportStreamId)))
            .clone();
        let ack_seq = local.next_send_sequence();
        let mut ack_flag_bits = PacketFlags::ENCRYPTED | PacketFlags::ACK;
        match rekey_before_stamp(crypto_recv, stream_id as TransportStreamId, ack_seq) {
            Some(extra) => ack_flag_bits |= extra,
            None => return,
        }
        let ack_header = PacketHeader::new(
            session_id,
            stream_id as TransportStreamId,
            ack_seq,
            PacketFlags::new(ack_flag_bits),
        )
        .with_epoch(crypto_recv.current_epoch())
        .with_path_id(path_id);
        let ack_payload = packet.header.sequence.to_be_bytes();
        match crypto_recv.encrypt_packet(&ack_header, &ack_payload) {
            Ok(ct) => {
                let ack_packet = PhantomPacket::new(ack_header, ct);
                ack_buf.clear();
                ack_buf.extend_from_slice(&ack_packet.to_wire());
                let size = ack_buf.len();
                let _ = transport_send_ack.send_bytes(&ack_buf[..size]).await;
            }
            Err(e) => log::error!("PhantomSession: ACK encrypt failed: {}", e),
        }
    }

    if !plaintext.is_empty() {
        let len = plaintext.len() as u64;
        if deliver_tx.send((stream_id, Bytes::from(plaintext))).is_ok() {
            undelivered_bytes.fetch_add(len, Ordering::AcqRel);
        }
    }
    if packet.header.flags.contains(PacketFlags::FIN) {
        demux_recv.route_close(stream_id);
    }
}
impl PhantomSession {
    /// Overwrites the session's connection state (relaxed store).
    pub(crate) fn set_state(&self, new_state: ConnectionState) {
        let raw = new_state as u8;
        self.state.store(raw, Ordering::Relaxed);
    }

    /// Returns a shared handle to this session's observability sink.
    pub fn observability(&self) -> Arc<Observability> {
        Arc::clone(&self.observability)
    }
}
#[cfg_attr(feature = "bindings", uniffi::export(async_runtime = "tokio"))]
impl PhantomSession {
#[cfg_attr(feature = "bindings", uniffi::constructor)]
pub fn connect(peer_addr: String) -> Arc<Self> {
let (cmd_tx, cmd_rx) = mpsc::channel(256);
let (_recv_tx, recv_rx) = mpsc::channel(256);
let (demux, _ctrl_rx) = StreamDemultiplexer::new(256);
let streams = Arc::new(DashMap::new());
Arc::new(Self {
id: new_session_id(),
peer_addr,
state: Arc::new(AtomicU8::new(ConnectionState::Connecting as u8)),
send_queue: Arc::new(Mutex::new(Vec::new())),
cmd_tx,
cmd_rx: Mutex::new(Some(cmd_rx)),
recv_rx: Mutex::new(recv_rx),
demux: Arc::new(demux),
streams,
inner_session: Arc::new(Mutex::new(None)),
early_data_accepted: Arc::new(Mutex::new(None)),
observability: Observability::new(ObservabilityConfig::default()),
})
}
pub fn open_stream(&self) -> Arc<crate::api::stream::PhantomStream> {
let handle = self.demux.open_stream(1024);
let stream_id = handle.stream_id;
let transport_stream = Arc::new(Stream::new(stream_id as TransportStreamId));
self.streams.insert(stream_id, transport_stream);
Arc::new(crate::api::stream::PhantomStream::new(
handle,
self.cmd_tx.clone(),
))
}
pub async fn send(&self, data: Vec<u8>) -> Result<(), CoreError> {
let state = self.connection_state();
if state.is_data_ready() {
self.cmd_tx
.send(SessionCommand::Send(data))
.await
.map_err(|_| CoreError::NetworkError("Session closed".into()))?;
} else if state == ConnectionState::Connecting {
self.send_queue.lock().await.push(data);
} else {
return Err(CoreError::NetworkError(format!(
"Cannot send in state {:?}",
state
)));
}
Ok(())
}
pub async fn recv(&self) -> Result<Vec<u8>, CoreError> {
let mut rx = self.recv_rx.lock().await;
let bytes = rx
.recv()
.await
.ok_or_else(|| CoreError::NetworkError("Session closed".into()))?;
Ok(bytes.to_vec())
}
pub fn connection_state(&self) -> ConnectionState {
ConnectionState::from_u8(self.state.load(Ordering::Relaxed))
}
pub fn is_data_ready(&self) -> bool {
self.connection_state().is_data_ready()
}
pub fn is_pqc_ready(&self) -> bool {
matches!(
self.connection_state(),
ConnectionState::PqcReady | ConnectionState::Connected
)
}
/// Forwards every queued payload to the session task in FIFO order.
///
/// Returns the number of payloads that were queued when the flush started. The queue
/// lock is held for the entire flush, so a concurrent `send` that needs the queue waits
/// until it finishes.
///
/// # Errors
///
/// Returns [`CoreError::NetworkError`] if the command channel is closed mid-flush. The
/// payloads not yet forwarded are discarded, because dropping a `Vec::drain` iterator
/// early removes the rest of the drained range.
pub async fn flush_queue(&self) -> Result<u32, CoreError> {
    let mut pending = self.send_queue.lock().await;
    let flushed = pending.len() as u32;
    for payload in pending.drain(..) {
        let command = SessionCommand::Send(payload);
        if self.cmd_tx.send(command).await.is_err() {
            return Err(CoreError::NetworkError("Session closed during flush".into()));
        }
    }
    Ok(flushed)
}
/// Returns how many payloads are currently waiting in the pre-handshake send queue.
pub async fn queued_count(&self) -> u32 {
    let pending = self.send_queue.lock().await;
    pending.len() as u32
}
/// Returns an owned copy of this session's local identifier.
pub fn id(&self) -> String {
    self.id.as_str().to_owned()
}
/// Returns an owned copy of the peer address string this session was created with.
pub fn peer_addr(&self) -> String {
    self.peer_addr.as_str().to_owned()
}
/// Returns whether the server accepted 0-RTT early data.
///
/// Returns `None` while that outcome has not been recorded.
pub async fn early_data_accepted(&self) -> Option<bool> {
    let recorded = self.early_data_accepted.lock().await;
    *recorded
}
/// Exports a resumption ticket for a future `connect_pinned_with_resumption` call.
///
/// Returns `None` in either of two cases: no inner transport session exists yet, or the
/// inner session has no resumption material to offer.
pub async fn resumption_hint(&self) -> Option<ResumptionHint> {
    let guard = self.inner_session.lock().await;
    let session = guard.as_ref()?;
    let (session_id, resumption_secret) = session.resumption_hint()?;
    Some(ResumptionHint {
        session_id: session_id.to_vec(),
        resumption_secret: resumption_secret.to_vec(),
    })
}
/// Returns the inner session's current key epoch.
///
/// Returns `None` if the inner session has not been established.
pub async fn current_epoch(&self) -> Option<u8> {
    let guard = self.inner_session.lock().await;
    guard.as_ref().map(|session| session.current_epoch())
}
/// Forwards `n` to the inner session's `set_rekey_threshold`.
///
/// Returns `true` if an inner session existed and was updated. Returns `false` if the
/// inner session has not been established yet, in which case nothing is changed.
///
/// NOTE(review): the unit of `n` (packets vs. bytes) is defined by the inner session.
/// Confirm there before documenting it for callers.
pub async fn set_rekey_threshold(&self, n: u64) -> bool {
    let guard = self.inner_session.lock().await;
    let Some(session) = guard.as_ref() else {
        return false;
    };
    session.set_rekey_threshold(n);
    true
}
/// Closes the session.
///
/// First the state is set to `Closed`, so later `send` calls are rejected. Then a
/// `Close` command is sent to the session task. Delivering that command is best-effort:
/// if the command channel is already closed, the failure is ignored. The send queue is
/// not touched. Always returns `Ok(())`.
pub async fn disconnect(&self) -> Result<(), CoreError> {
    self.set_state(ConnectionState::Closed);
    let notify = self.cmd_tx.send(SessionCommand::Close);
    let _ = notify.await;
    Ok(())
}
}
impl PhantomSession {
    /// Returns a shared (reference-counted) handle to this session's stream demultiplexer.
    pub fn demux(&self) -> Arc<StreamDemultiplexer> {
        Arc::clone(&self.demux)
    }
}
impl std::fmt::Debug for PhantomSession {
    /// Formats only the identifying fields (id, peer, state).
    ///
    /// Queues, channels and key material are left out of debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.connection_state();
        let mut out = f.debug_struct("PhantomSession");
        out.field("id", &self.id);
        out.field("peer", &self.peer_addr);
        out.field("state", &state);
        out.finish()
    }
}
/// Opens a TCP connection to `host:port` and starts a session pinned to `pinned_key`.
///
/// Steps:
/// 1. With the `fips` feature, require that the power-on self tests have passed.
/// 2. Decode `pinned_key` as the server's expected `HybridVerifyingKey`.
/// 3. Connect over TCP.
/// 4. Hand the socket to `PhantomSession::connect_with_transport`.
///
/// # Errors
///
/// * `CoreError::FipsSelfTestFailure`: only with `fips`, if the self tests failed.
/// * `CoreError::CryptoError`: `pinned_key` is not a valid verifying key.
/// * `CoreError::NetworkError`: the TCP connect failed.
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(feature = "bindings", uniffi::export(async_runtime = "tokio"))]
pub async fn connect_pinned(
    host: String,
    port: u16,
    pinned_key: Vec<u8>,
) -> Result<Arc<PhantomSession>, CoreError> {
    #[cfg(feature = "fips")]
    crate::crypto::self_tests::ensure_post_passed()
        .map_err(|e| CoreError::FipsSelfTestFailure(format!("{e:?}")))?;

    // Reject a malformed pin before opening any socket.
    let server_key = HybridVerifyingKey::from_bytes(&pinned_key).map_err(|err| {
        CoreError::CryptoError(format!("invalid pinned key: {}", err))
    })?;

    let addr = format!("{}:{}", host, port);
    let tcp = match tokio::net::TcpStream::connect(&addr).await {
        Ok(tcp) => tcp,
        Err(err) => {
            return Err(CoreError::NetworkError(format!("connect {}: {}", addr, err)));
        }
    };
    let transport = crate::api::tcp_transport::TcpSessionTransport::new(tcp);
    let session = PhantomSession::connect_with_transport(&addr, transport, server_key);
    Ok(Arc::new(session))
}
/// Opens a TCP connection to `host:port` and resumes a session pinned to `pinned_key`.
///
/// The resumption material comes from `hint`. `early_data` is passed to
/// `PhantomSession::connect_with_resumption` as 0-RTT data.
///
/// All local validation runs before the socket is opened:
/// * the FIPS self-test check (with the `fips` feature);
/// * the pinned-key decode;
/// * the 32-byte length checks on both halves of `hint`;
/// * the `EARLY_DATA_MAX_LEN` cap.
///
/// # Errors
///
/// * `CoreError::FipsSelfTestFailure`: only with `fips`, if the self tests failed.
/// * `CoreError::CryptoError`: `pinned_key` is not a valid verifying key.
/// * `CoreError::ValidationError`: a hint field is not 32 bytes, or `early_data` is
///   over the cap.
/// * `CoreError::NetworkError`: the TCP connect failed.
/// * Any error returned by `PhantomSession::connect_with_resumption`.
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(feature = "bindings", uniffi::export(async_runtime = "tokio"))]
pub async fn connect_pinned_with_resumption(
    host: String,
    port: u16,
    pinned_key: Vec<u8>,
    hint: ResumptionHint,
    early_data: Vec<u8>,
) -> Result<Arc<PhantomSession>, CoreError> {
    #[cfg(feature = "fips")]
    crate::crypto::self_tests::ensure_post_passed()
        .map_err(|e| CoreError::FipsSelfTestFailure(format!("{e:?}")))?;

    let server_key = HybridVerifyingKey::from_bytes(&pinned_key).map_err(|err| {
        CoreError::CryptoError(format!("invalid pinned key: {}", err))
    })?;

    let Ok(session_id) = <[u8; 32]>::try_from(hint.session_id.as_slice()) else {
        return Err(CoreError::ValidationError(format!(
            "resumption hint session_id must be 32 bytes, got {}",
            hint.session_id.len()
        )));
    };
    let Ok(resumption_secret) = <[u8; 32]>::try_from(hint.resumption_secret.as_slice()) else {
        return Err(CoreError::ValidationError(format!(
            "resumption hint resumption_secret must be 32 bytes, got {}",
            hint.resumption_secret.len()
        )));
    };

    let early_len = early_data.len();
    if early_len > EARLY_DATA_MAX_LEN {
        return Err(CoreError::ValidationError(format!(
            "early_data is {} bytes, exceeds the {}-byte 0-RTT cap",
            early_len, EARLY_DATA_MAX_LEN
        )));
    }

    let addr = format!("{}:{}", host, port);
    let tcp = tokio::net::TcpStream::connect(&addr)
        .await
        .map_err(|err| CoreError::NetworkError(format!("connect {}: {}", addr, err)))?;
    let transport = crate::api::tcp_transport::TcpSessionTransport::new(tcp);
    let ticket = (session_id, resumption_secret);
    let session = PhantomSession::connect_with_resumption(
        &addr,
        transport,
        server_key,
        ticket,
        early_data,
    )?;
    Ok(Arc::new(session))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::handshake::{ClientHello, HandshakeResponse, HandshakeServer};
struct ChannelTransport {
tx: mpsc::Sender<Vec<u8>>,
rx: Mutex<mpsc::Receiver<Vec<u8>>>,
}
impl ChannelTransport {
fn pair() -> (Self, Self) {
let (a_tx, b_rx) = mpsc::channel(64);
let (b_tx, a_rx) = mpsc::channel(64);
(
Self {
tx: a_tx,
rx: Mutex::new(a_rx),
},
Self {
tx: b_tx,
rx: Mutex::new(b_rx),
},
)
}
}
impl SessionTransport for ChannelTransport {
async fn send_bytes(&self, data: &[u8]) -> Result<(), CoreError> {
self.tx
.send(data.to_vec())
.await
.map_err(|_| CoreError::NetworkError("channel closed".into()))
}
async fn recv_bytes(&self) -> Result<Bytes, CoreError> {
let mut rx = self.rx.lock().await;
let v = rx
.recv()
.await
.ok_or_else(|| CoreError::NetworkError("channel closed".into()))?;
Ok(Bytes::from(v))
}
}
#[tokio::test]
async fn client_surfaces_server_reject_as_version_error() {
use crate::transport::handshake::ServerReject;
let (client_transport, server_transport) = ChannelTransport::pair();
let (_sk, expected_vk) = crate::crypto::hybrid_sign::HybridSigningKey::generate();
let server = tokio::spawn(async move {
let _hello = server_transport.recv_bytes().await.unwrap();
let reject = borsh::to_vec(&ServerReject::unsupported_version()).unwrap();
server_transport.send_bytes(&reject).await.unwrap();
});
let result = run_client_handshake(&client_transport, &expected_vk, None).await;
server.await.unwrap();
let err = result.expect_err("client must surface the reject as an error");
let msg = format!("{err:?}");
assert!(
msg.contains("unsupported protocol version"),
"expected a version-mismatch error, got: {msg}"
);
}
#[tokio::test]
async fn client_handshake_caps_retry_rounds() {
use crate::transport::handshake::HelloRetryRequest;
let (client_transport, server_transport) = ChannelTransport::pair();
let (_sk, expected_vk) = crate::crypto::hybrid_sign::HybridSigningKey::generate();
let server = tokio::spawn(async move {
loop {
if server_transport.recv_bytes().await.is_err() {
break;
}
let retry = borsh::to_vec(&HelloRetryRequest {
challenge: None,
cookie: None,
})
.expect("encode retry");
if server_transport.send_bytes(&retry).await.is_err() {
break;
}
}
});
let result = run_client_handshake(&client_transport, &expected_vk, None).await;
drop(client_transport); let _ = server.await;
assert!(
matches!(result, Err(CoreError::HandshakeError(_))),
"client must error after the retry-round cap, not loop forever; got {result:?}"
);
}
#[test]
fn resumption_hint_debug_redacts_secret() {
let hint = ResumptionHint {
session_id: vec![0xAB; 32],
resumption_secret: vec![0xCD; 32],
};
let dbg = format!("{hint:?}");
assert!(dbg.contains("REDACTED"), "secret must be redacted: {dbg}");
assert!(
!dbg.contains("205"),
"no decimal secret bytes in Debug: {dbg}"
);
assert!(
!dbg.to_lowercase().contains("cd, cd"),
"no hex secret bytes: {dbg}"
);
}
#[tokio::test]
async fn test_phantom_session_instant_connect() {
let session = PhantomSession::connect("example.com:443".to_string());
assert_eq!(session.connection_state(), ConnectionState::Connecting);
assert!(!session.is_data_ready());
assert_eq!(session.peer_addr(), "example.com:443");
}
#[tokio::test]
async fn test_phantom_session_send_queue() {
let session = PhantomSession::connect("example.com:443".to_string());
session.send(vec![1, 2, 3]).await.unwrap();
session.send(vec![4, 5, 6]).await.unwrap();
assert_eq!(session.queued_count().await, 2);
session.set_state(ConnectionState::ClassicalReady);
assert!(session.is_data_ready());
let flushed = session.flush_queue().await.unwrap();
assert_eq!(flushed, 2);
assert_eq!(session.queued_count().await, 0);
}
#[tokio::test]
async fn test_phantom_session_state_progression() {
let session = PhantomSession::connect("example.com:443".to_string());
assert_eq!(session.connection_state(), ConnectionState::Connecting);
assert!(!session.is_data_ready());
session.set_state(ConnectionState::ClassicalReady);
assert!(session.is_data_ready());
assert!(!session.is_pqc_ready());
session.set_state(ConnectionState::PqcUpgrading);
assert!(session.is_data_ready());
assert!(!session.is_pqc_ready());
session.set_state(ConnectionState::PqcReady);
assert!(session.is_data_ready());
assert!(session.is_pqc_ready());
session.set_state(ConnectionState::Connected);
assert!(session.is_data_ready());
assert!(session.is_pqc_ready());
}
#[tokio::test]
async fn test_phantom_session_close() {
let session = PhantomSession::connect("example.com:443".to_string());
session.disconnect().await.unwrap();
assert_eq!(session.connection_state(), ConnectionState::Closed);
assert!(!session.is_data_ready());
}
fn decrypt_incoming(
server_session: &crate::transport::session::Session,
bytes: &[u8],
) -> Vec<u8> {
let pkt = PhantomPacket::from_wire(bytes).expect("deserialize PhantomPacket");
assert!(
pkt.header.flags.contains(PacketFlags::ENCRYPTED),
"expected ENCRYPTED flag on application data"
);
server_session
.decrypt_packet(&pkt.header, &pkt.payload)
.expect("decrypt application data")
}
fn encrypt_outgoing(
server_session: &crate::transport::session::Session,
session_id: SessionId,
stream_id: TransportStreamId,
sequence: u32,
payload: &[u8],
) -> Vec<u8> {
let flag_bits = PacketFlags::RELIABLE | PacketFlags::ENCRYPTED;
let header =
PacketHeader::new(session_id, stream_id, sequence, PacketFlags::new(flag_bits))
.with_epoch(server_session.current_epoch());
let ct = server_session
.encrypt_packet(&header, payload)
.expect("encrypt reply");
let packet = PhantomPacket::new(header, ct);
packet.to_wire()
}
#[tokio::test]
async fn test_phantom_session_handshake_via_transport() {
let (client_transport, server_transport) = ChannelTransport::pair();
let server_hs = HandshakeServer::new().unwrap();
let server_pinned_key = server_hs.verifying_key().clone();
let session = PhantomSession::connect_with_transport(
"test-server:9000",
client_transport,
server_pinned_key,
);
session.send(b"early-data".to_vec()).await.unwrap();
let server_handle = tokio::spawn(async move {
let client_ip = "127.0.0.1".parse().unwrap();
let client_hello_bytes = server_transport.recv_bytes().await.unwrap();
let client_hello = borsh::from_slice::<ClientHello>(&client_hello_bytes).unwrap();
let server_session = loop {
let response = server_hs.process_client_hello(&client_hello, 0, client_ip);
match response {
HandshakeResponse::Retry(retry) => {
let retry_bytes = borsh::to_vec(&retry).unwrap();
server_transport.send_bytes(&retry_bytes).await.unwrap();
let next_bytes = server_transport.recv_bytes().await.unwrap();
let next_hello = borsh::from_slice::<ClientHello>(&next_bytes).unwrap();
let resp2 = server_hs.process_client_hello(&next_hello, 0, client_ip);
match resp2 {
HandshakeResponse::Success(server_hello, session, _) => {
let server_hello_bytes = borsh::to_vec(&server_hello).unwrap();
server_transport
.send_bytes(&server_hello_bytes)
.await
.unwrap();
break session;
}
_ => panic!("Expected success after retry"),
}
}
HandshakeResponse::Success(server_hello, session, _) => {
let server_hello_bytes = borsh::to_vec(&server_hello).unwrap();
server_transport
.send_bytes(&server_hello_bytes)
.await
.unwrap();
break session;
}
HandshakeResponse::Reject(r) => panic!("unexpected reject: {:?}", r),
HandshakeResponse::Fail(e) => panic!("handshake failed: {:?}", e),
}
};
let session_id = *server_session.id();
let early_frame = server_transport.recv_bytes().await.unwrap();
assert!(
!early_frame
.windows(b"early-data".len())
.any(|w| w == b"early-data"),
"encrypted frame must not contain plaintext early-data"
);
let early_plain = decrypt_incoming(&server_session, &early_frame);
assert_eq!(early_plain, b"early-data");
let post_frame = server_transport.recv_bytes().await.unwrap();
let post_plain = decrypt_incoming(&server_session, &post_frame);
assert_eq!(post_plain, b"after-handshake");
let reply = encrypt_outgoing(&server_session, session_id, 1, 1, b"server-reply");
server_transport.send_bytes(&reply).await.unwrap();
});
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
assert_eq!(session.connection_state(), ConnectionState::Connected);
session.send(b"after-handshake".to_vec()).await.unwrap();
let reply = session.recv().await.unwrap();
assert_eq!(reply, b"server-reply");
server_handle.await.unwrap();
session.disconnect().await.unwrap();
}
#[tokio::test]
async fn reliable_send_survives_a_dropped_data_frame() {
use crate::test_harness::fault_transport::{FaultControl, LossyTransport};
let (client_transport, server_transport) = ChannelTransport::pair();
let faults = FaultControl::new();
let lossy_client = LossyTransport::new(client_transport, faults.clone());
let server_hs = HandshakeServer::new().unwrap();
let server_pinned_key = server_hs.verifying_key().clone();
let session = PhantomSession::connect_with_transport(
"test-server:9000",
lossy_client,
server_pinned_key,
);
let server_handle = tokio::spawn(async move {
let client_ip = "127.0.0.1".parse().unwrap();
let client_hello_bytes = server_transport.recv_bytes().await.unwrap();
let client_hello = borsh::from_slice::<ClientHello>(&client_hello_bytes).unwrap();
let server_session = loop {
match server_hs.process_client_hello(&client_hello, 0, client_ip) {
HandshakeResponse::Retry(retry) => {
let retry_bytes = borsh::to_vec(&retry).unwrap();
server_transport.send_bytes(&retry_bytes).await.unwrap();
let next_bytes = server_transport.recv_bytes().await.unwrap();
let next_hello = borsh::from_slice::<ClientHello>(&next_bytes).unwrap();
match server_hs.process_client_hello(&next_hello, 0, client_ip) {
HandshakeResponse::Success(server_hello, session, _) => {
let b = borsh::to_vec(&server_hello).unwrap();
server_transport.send_bytes(&b).await.unwrap();
break session;
}
_ => panic!("expected success after retry"),
}
}
HandshakeResponse::Success(server_hello, session, _) => {
let b = borsh::to_vec(&server_hello).unwrap();
server_transport.send_bytes(&b).await.unwrap();
break session;
}
HandshakeResponse::Reject(r) => panic!("unexpected reject: {:?}", r),
HandshakeResponse::Fail(e) => panic!("handshake failed: {:?}", e),
}
};
let data_frame = tokio::time::timeout(
std::time::Duration::from_secs(3),
server_transport.recv_bytes(),
)
.await
.expect(
"reliable payload never arrived within 3s — the dropped data frame was not \
retransmitted (loss-recovery regression)",
)
.unwrap();
let plain = decrypt_incoming(&server_session, &data_frame);
assert_eq!(plain, b"reliable-payload");
});
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
assert_eq!(session.connection_state(), ConnectionState::Connected);
faults.arm_drop_next(1);
session.send(b"reliable-payload".to_vec()).await.unwrap();
server_handle.await.unwrap();
session.disconnect().await.unwrap();
}
#[tokio::test]
async fn drain_reports_a_retransmit_as_loss_to_bbr() {
use crate::transport::bandwidth_estimator::BbrState;
tokio::time::pause();
let sid = fixed_session_id();
let (client, _server) = paired_sessions(sid);
let stream = Arc::new(TransportStream::new(1));
stream.send_reliable(Bytes::from("payload")).await;
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
streams.insert(1u32, stream);
let (client_t, _server_t) = ChannelTransport::pair();
let transport = Arc::new(client_t);
drain_streams_priority_ordered(&transport, &client, sid, &streams).await;
assert_ne!(client.bbr_state(), BbrState::FastRecovery);
tokio::time::advance(std::time::Duration::from_millis(1100)).await;
drain_streams_priority_ordered(&transport, &client, sid, &streams).await;
assert_eq!(
client.bbr_state(),
BbrState::FastRecovery,
"a retransmit must be reported to BBR as a loss"
);
}
#[tokio::test]
async fn drain_withholds_new_data_when_inflight_exceeds_cwnd() {
let sid = fixed_session_id();
let (client, _server) = paired_sessions(sid);
client.on_packet_sent(100_000_000);
let inflight_before = client.bandwidth_snapshot().inflight_bytes;
let stream = Arc::new(TransportStream::new(1));
stream.send_reliable(Bytes::from("new-data")).await;
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
streams.insert(1u32, stream);
let (client_t, _server_t) = ChannelTransport::pair();
let transport = Arc::new(client_t);
drain_streams_priority_ordered(&transport, &client, sid, &streams).await;
assert_eq!(
client.bandwidth_snapshot().inflight_bytes,
inflight_before,
"no new data should be sent when inflight >= cwnd"
);
}
use crate::transport::multiplexer::StreamDemultiplexer;
use crate::transport::session::Session as InnerSession;
use crate::transport::stream::Stream as TransportStream;
fn paired_sessions(session_id: SessionId) -> (Arc<InnerSession>, Arc<InnerSession>) {
let secret = [0x11u8; 32];
let client = Arc::new(InnerSession::new(session_id, &secret, false).unwrap());
let server = Arc::new(InnerSession::new(session_id, &secret, true).unwrap());
(client, server)
}
fn fixed_session_id() -> SessionId {
SessionId::from_bytes([0x88; 32])
}
fn build_app_frame(
client_session: &InnerSession,
session_id: SessionId,
stream_id: TransportStreamId,
sequence: u32,
payload: &[u8],
) -> Vec<u8> {
let flag_bits = PacketFlags::RELIABLE | PacketFlags::ENCRYPTED;
let header =
PacketHeader::new(session_id, stream_id, sequence, PacketFlags::new(flag_bits))
.with_epoch(client_session.current_epoch());
let ciphertext = client_session
.encrypt_packet(&header, payload)
.expect("encrypt_packet");
let packet = PhantomPacket::new(header, ciphertext);
packet.to_wire()
}
#[tokio::test]
async fn v2_recv_routes_encrypted_app_data_through_recv_channel() {
let session_id = fixed_session_id();
let (client_session, server_session) = paired_sessions(session_id);
let stream_id: TransportStreamId = 1;
let frame = build_app_frame(&client_session, session_id, stream_id, 0, b"hello-v2");
let v2 = PhantomPacket::from_wire(&frame).unwrap();
let (demux, _ctrl_rx) = StreamDemultiplexer::new(16);
let demux = Arc::new(demux);
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
let (deliver_tx, mut deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
let undelivered = AtomicU64::new(0);
let (ack_a, ack_b) = mpsc::channel::<Vec<u8>>(4);
let transport_send: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: ack_a,
rx: Mutex::new(ack_b),
});
let mut ack_buf = Vec::with_capacity(256);
let mut path_validation_seq: u32 = 0;
let obs = Observability::new(ObservabilityConfig::default());
handle_packet(
v2,
session_id,
&server_session,
&streams,
&demux,
&transport_send,
&transport_send,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut path_validation_seq,
&obs,
LegType::Tcp,
)
.await;
let (sid, received) = deliver_rx.recv().await.expect("delivery hand-off");
assert_eq!(sid, stream_id as u32);
assert_eq!(&received[..], b"hello-v2");
assert_eq!(
undelivered.load(Ordering::Acquire),
b"hello-v2".len() as u64
);
}
fn build_app_frame_on_path(
client_session: &InnerSession,
session_id: SessionId,
stream_id: TransportStreamId,
sequence: u32,
path_id: u8,
payload: &[u8],
) -> Vec<u8> {
let flag_bits = PacketFlags::RELIABLE | PacketFlags::ENCRYPTED;
let header =
PacketHeader::new(session_id, stream_id, sequence, PacketFlags::new(flag_bits))
.with_epoch(client_session.current_epoch())
.with_path_id(path_id);
let ciphertext = client_session
.encrypt_packet(&header, payload)
.expect("encrypt_packet");
PhantomPacket::new(header, ciphertext).to_wire()
}
#[tokio::test]
async fn app_data_on_non_validated_path_is_dropped() {
use crate::transport::path::PathStateKind;
let session_id = fixed_session_id();
let (client_session, server_session) = paired_sessions(session_id);
let stream_id: TransportStreamId = 1;
let bad = build_app_frame_on_path(
&client_session,
session_id,
stream_id,
0,
7, b"on-bad-path",
);
let bad = PhantomPacket::from_wire(&bad).unwrap();
let (demux, _ctrl_rx) = StreamDemultiplexer::new(16);
let demux = Arc::new(demux);
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
let (deliver_tx, mut deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
let undelivered = AtomicU64::new(0);
let (ack_a, ack_b) = mpsc::channel::<Vec<u8>>(4);
let transport_send: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: ack_a,
rx: Mutex::new(ack_b),
});
let mut ack_buf = Vec::with_capacity(256);
let mut path_validation_seq: u32 = 0;
let obs = Observability::new(ObservabilityConfig::default());
handle_packet(
bad,
session_id,
&server_session,
&streams,
&demux,
&transport_send,
&transport_send,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut path_validation_seq,
&obs,
LegType::Tcp,
)
.await;
assert!(
deliver_rx.try_recv().is_err(),
"application data on a non-validated path must be dropped"
);
assert_eq!(
undelivered.load(Ordering::Acquire),
0,
"dropped data must not count toward the backlog"
);
assert_eq!(
server_session.path_state(7),
Some(PathStateKind::Unvalidated),
"the unseen path id must be registered for a later challenge"
);
let good = build_app_frame_on_path(
&client_session,
session_id,
stream_id,
1,
0,
b"on-good-path",
);
let good = PhantomPacket::from_wire(&good).unwrap();
handle_packet(
good,
session_id,
&server_session,
&streams,
&demux,
&transport_send,
&transport_send,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut path_validation_seq,
&obs,
LegType::Tcp,
)
.await;
let (sid, received) = deliver_rx.recv().await.expect("path-0 delivery");
assert_eq!(sid, stream_id as u32);
assert_eq!(&received[..], b"on-good-path");
}
fn build_encrypted_ack(
acker_session: &InnerSession,
session_id: SessionId,
stream_id: TransportStreamId,
ack_header_seq: u32,
acked_seq: u32,
) -> Vec<u8> {
let flag_bits = PacketFlags::ENCRYPTED | PacketFlags::ACK;
let header = PacketHeader::new(
session_id,
stream_id,
ack_header_seq,
PacketFlags::new(flag_bits),
)
.with_epoch(acker_session.current_epoch());
let ct = acker_session
.encrypt_packet(&header, &acked_seq.to_be_bytes())
.expect("encrypt ack");
PhantomPacket::new(header, ct).to_wire()
}
async fn run_recv(
pkt: PhantomPacket,
session_id: SessionId,
server_session: &Arc<InnerSession>,
streams: &Arc<DashMap<u32, Arc<TransportStream>>>,
) {
let (demux, _ctrl_rx) = StreamDemultiplexer::new(16);
let demux = Arc::new(demux);
let (deliver_tx, _deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
let undelivered = AtomicU64::new(0);
let (ack_a, ack_b) = mpsc::channel::<Vec<u8>>(4);
let transport: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: ack_a,
rx: Mutex::new(ack_b),
});
let mut ack_buf = Vec::with_capacity(64);
let mut path_validation_seq: u32 = 0;
let obs = Observability::new(ObservabilityConfig::default());
handle_packet(
pkt,
session_id,
server_session,
streams,
&demux,
&transport,
&transport,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut path_validation_seq,
&obs,
LegType::Tcp,
)
.await;
}
async fn staged_pending_segment() -> (
Arc<TransportStream>,
Arc<DashMap<u32, Arc<TransportStream>>>,
u32,
) {
let stream_id: TransportStreamId = 1;
let stream = Arc::new(TransportStream::new(stream_id));
let seq = stream
.send_reliable(Bytes::from_static(b"reliable-payload"))
.await;
let _ = stream.poll_send(u64::MAX).await.expect("segment in-flight");
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
streams.insert(stream_id as u32, stream.clone());
(stream, streams, seq)
}
#[tokio::test]
async fn forged_plaintext_ack_does_not_retire_pending_segment() {
let session_id = fixed_session_id();
let (_client, server_session) = paired_sessions(session_id);
let (stream, streams, seq) = staged_pending_segment().await;
let stream_id: TransportStreamId = 1;
run_recv(
PhantomPacket::new(
PacketHeader::new(
session_id,
stream_id,
seq,
PacketFlags::new(PacketFlags::ACK),
),
Vec::new(),
),
session_id,
&server_session,
&streams,
)
.await;
run_recv(
PhantomPacket::new(
PacketHeader::new(
session_id,
stream_id,
999,
PacketFlags::new(PacketFlags::ACK),
),
seq.to_be_bytes().to_vec(),
),
session_id,
&server_session,
&streams,
)
.await;
assert!(
stream.ack(seq).await.is_some(),
"a forged unauthenticated ACK must not retire the pending reliable segment"
);
}
#[tokio::test]
async fn authenticated_ack_retires_pending_segment() {
let session_id = fixed_session_id();
let (client_session, server_session) = paired_sessions(session_id);
let (stream, streams, seq) = staged_pending_segment().await;
let stream_id: TransportStreamId = 1;
let ack_header_seq = seq.wrapping_add(54_321);
let frame =
build_encrypted_ack(&client_session, session_id, stream_id, ack_header_seq, seq);
let ack_pkt = PhantomPacket::from_wire(&frame).expect("parse ack");
run_recv(ack_pkt, session_id, &server_session, &streams).await;
assert!(
stream.ack(seq).await.is_none(),
"an authenticated ACK must retire the acked pending segment"
);
}
#[tokio::test]
async fn ack_with_wrong_session_id_is_dropped() {
let session_id = fixed_session_id();
let (_client, server_session) = paired_sessions(session_id);
let (stream, streams, seq) = staged_pending_segment().await;
let stream_id: TransportStreamId = 1;
let wrong_id = SessionId::from_bytes([0x11; 32]);
run_recv(
PhantomPacket::new(
PacketHeader::new(wrong_id, stream_id, seq, PacketFlags::new(PacketFlags::ACK)),
Vec::new(),
),
session_id,
&server_session,
&streams,
)
.await;
assert!(
stream.ack(seq).await.is_some(),
"an ACK for a different session id must not retire the segment"
);
}
#[tokio::test]
async fn v2_recv_drops_unencrypted_non_empty_post_handshake_payload() {
let session_id = fixed_session_id();
let (_, server_session) = paired_sessions(session_id);
let stream_id: TransportStreamId = 2;
let bad_header = PacketHeader::new(
session_id,
stream_id,
0,
PacketFlags::new(PacketFlags::RELIABLE), );
let bad_packet = PhantomPacket::new(bad_header, b"leaked-cleartext".to_vec());
let (demux, _ctrl_rx) = StreamDemultiplexer::new(16);
let demux = Arc::new(demux);
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
let (deliver_tx, mut deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
let undelivered = AtomicU64::new(0);
let (ack_a, ack_b) = mpsc::channel::<Vec<u8>>(4);
let transport_send: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: ack_a,
rx: Mutex::new(ack_b),
});
let mut ack_buf = Vec::with_capacity(256);
let mut path_validation_seq: u32 = 0;
let obs = Observability::new(ObservabilityConfig::default());
handle_packet(
bad_packet,
session_id,
&server_session,
&streams,
&demux,
&transport_send,
&transport_send,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut path_validation_seq,
&obs,
LegType::Tcp,
)
.await;
assert!(
deliver_rx.try_recv().is_err(),
"unencrypted post-handshake payload must NOT be handed off for delivery"
);
assert_eq!(undelivered.load(Ordering::Acquire), 0);
}
#[tokio::test]
async fn v2_recv_handles_coalesced_bundle_and_routes_each_subpayload() {
use crate::transport::packet_coalescer::{CoalescerConfig, PacketCoalescer};
let session_id = fixed_session_id();
let (client_session, server_session) = paired_sessions(session_id);
let mut coalescer = PacketCoalescer::new(CoalescerConfig::default());
coalescer.push(b"alpha");
coalescer.push(b"bravo");
coalescer.push(b"charlie");
let bundle = coalescer.flush().expect("bundle");
let stream_id: TransportStreamId = 3;
let flag_bits = PacketFlags::ENCRYPTED | PacketFlags::COALESCED;
let header = PacketHeader::new(session_id, stream_id, 0, PacketFlags::new(flag_bits))
.with_epoch(client_session.current_epoch());
let ciphertext = client_session
.encrypt_packet(&header, &bundle)
.expect("encrypt bundle");
let v2 = PhantomPacket::new(header, ciphertext);
let (demux, _ctrl_rx) = StreamDemultiplexer::new(16);
let demux = Arc::new(demux);
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
let (deliver_tx, mut deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
let undelivered = AtomicU64::new(0);
let (ack_a, ack_b) = mpsc::channel::<Vec<u8>>(4);
let transport_send: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: ack_a,
rx: Mutex::new(ack_b),
});
let mut ack_buf = Vec::with_capacity(256);
let mut path_validation_seq: u32 = 0;
let obs = Observability::new(ObservabilityConfig::default());
handle_packet(
v2,
session_id,
&server_session,
&streams,
&demux,
&transport_send,
&transport_send,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut path_validation_seq,
&obs,
LegType::Tcp,
)
.await;
let (sa, a) = deliver_rx.recv().await.expect("alpha");
let (sb, b) = deliver_rx.recv().await.expect("bravo");
let (sc, c) = deliver_rx.recv().await.expect("charlie");
assert_eq!(
(sa, sb, sc),
(stream_id as u32, stream_id as u32, stream_id as u32)
);
assert_eq!(&a[..], b"alpha");
assert_eq!(&b[..], b"bravo");
assert_eq!(&c[..], b"charlie");
assert_eq!(undelivered.load(Ordering::Acquire), (5 + 5 + 7) as u64);
}
#[tokio::test]
async fn delivery_preserves_order_across_coalesced_then_normal_frame() {
use crate::transport::packet_coalescer::{CoalescerConfig, PacketCoalescer};
let session_id = fixed_session_id();
let (client_session, server_session) = paired_sessions(session_id);
let stream_id: TransportStreamId = 1;
let mut coalescer = PacketCoalescer::new(CoalescerConfig::default());
coalescer.push(b"A");
coalescer.push(b"B");
coalescer.push(b"C");
let bundle = coalescer.flush().expect("bundle");
let flag_bits = PacketFlags::ENCRYPTED | PacketFlags::COALESCED;
let h1 = PacketHeader::new(session_id, stream_id, 0, PacketFlags::new(flag_bits))
.with_epoch(client_session.current_epoch());
let ct1 = client_session
.encrypt_packet(&h1, &bundle)
.expect("encrypt bundle");
let coalesced = PhantomPacket::new(h1, ct1);
let d_wire = build_app_frame(&client_session, session_id, stream_id, 1, b"D");
let normal = PhantomPacket::from_wire(&d_wire).unwrap();
let (demux, _ctrl) = StreamDemultiplexer::new(16);
let demux = Arc::new(demux);
let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
let (deliver_tx, mut deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
let undelivered = AtomicU64::new(0);
let (ack_a, ack_b) = mpsc::channel::<Vec<u8>>(8);
let transport_send: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: ack_a,
rx: Mutex::new(ack_b),
});
let mut ack_buf = Vec::with_capacity(256);
let mut pv_seq: u32 = 0;
let obs = Observability::new(ObservabilityConfig::default());
for pkt in [coalesced, normal] {
handle_packet(
pkt,
session_id,
&server_session,
&streams,
&demux,
&transport_send,
&transport_send,
&deliver_tx,
&undelivered,
&mut ack_buf,
&mut pv_seq,
&obs,
LegType::Tcp,
)
.await;
}
let mut got: Vec<Bytes> = Vec::new();
while let Ok((_sid, b)) = deliver_rx.try_recv() {
got.push(b);
}
let seen: Vec<&[u8]> = got.iter().map(|b| &b[..]).collect();
assert_eq!(seen, vec![&b"A"[..], b"B", b"C", b"D"]);
}
#[tokio::test]
async fn peer_ignoring_flow_control_trips_delivery_hard_cap_and_closes_session() {
    use std::time::Duration;

    let session_id = fixed_session_id();
    let (client_inner, server_inner) = paired_sessions(session_id);
    let (client_t, server_t) = ChannelTransport::pair();
    let client_t = Arc::new(client_t);
    let server = PhantomSession::from_accepted_server_session(
        "flooder".to_string(),
        server_t,
        server_inner,
    );

    // Keep reading (and discarding) everything the server sends back to the
    // flooder until that receive fails.
    let drain_t = Arc::clone(&client_t);
    let drainer = tokio::spawn(async move {
        while drain_t.recv_bytes().await.is_ok() {}
    });

    // Flood RELIABLE 64 KiB frames without consulting any flow-control window.
    // Stop as soon as the server reports Closed, or when a send errors or does
    // not complete within 5 s — either counts as the session being torn down.
    let payload = vec![0xABu8; 64 * 1024];
    let mut torn_down = false;
    for seq in 0u32..4000 {
        if server.connection_state() == ConnectionState::Closed {
            torn_down = true;
            break;
        }
        let flags = PacketFlags::new(PacketFlags::RELIABLE | PacketFlags::ENCRYPTED);
        let header =
            PacketHeader::new(session_id, 1, seq, flags).with_epoch(client_inner.current_epoch());
        let ciphertext = client_inner.encrypt_packet(&header, &payload).expect("encrypt");
        let wire = PhantomPacket::new(header, ciphertext).to_wire();
        let send = tokio::time::timeout(Duration::from_secs(5), client_t.send_bytes(&wire));
        if !matches!(send.await, Ok(Ok(()))) {
            torn_down = true;
            break;
        }
        tokio::task::yield_now().await;
    }
    assert!(
        torn_down,
        "a peer flooding past the delivery hard cap must get its session torn down"
    );

    // The send path may observe the teardown before the state flips, so poll
    // up to 200 times at 5 ms intervals for the Closed state.
    let mut closed = false;
    for _ in 0..200 {
        closed = server.connection_state() == ConnectionState::Closed;
        if closed {
            break;
        }
        tokio::time::sleep(Duration::from_millis(5)).await;
    }
    drainer.abort();
    assert!(closed, "session state must be Closed after the hard cap trips");
}
#[tokio::test]
async fn bbr_on_ack_drives_pacer_rate() {
    use crate::transport::bandwidth_estimator::DeliverySample;
    use std::time::{Duration, Instant};

    let session_id = fixed_session_id();
    let (client_session, _server_session) = paired_sessions(session_id);
    // Before any ACK samples are fed in, the pacer must report itself disabled.
    assert!(!client_session.pacer().is_enabled());

    // Feed 16 synthetic ACK samples, each with a 20 ms send→ack interval, stepping
    // the timestamps back by 5 ms per sample.
    let base = Instant::now();
    for step in 0..16u64 {
        let offset = Duration::from_millis(step * 5);
        client_session.on_packet_sent(1500);
        let _ = client_session.on_packet_acked(DeliverySample {
            delivered_bytes: 0,
            sent_at: base - offset - Duration::from_millis(20),
            acked_at: base - offset,
            packet_bytes: 1500,
            is_app_limited: false,
            ack_delay_us: 100,
        });
    }

    // The estimator's pacing rate must be non-zero and mirrored by the pacer.
    let pacing_rate_bps = client_session.bandwidth_snapshot().pacing_rate_bps;
    assert!(
        pacing_rate_bps > 0,
        "expected pacing_rate to be non-zero, got {}",
        pacing_rate_bps,
    );
    assert_eq!(client_session.pacer().rate(), pacing_rate_bps);
}
#[tokio::test]
async fn flow_control_window_update_round_trip() {
use crate::transport::stream::INITIAL_STREAM_WINDOW;
let session_id = fixed_session_id();
let (client_session, server_session) = paired_sessions(session_id);
let stream_id: TransportStreamId = 9;
let server_streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
let server_stream = Arc::new(TransportStream::new(stream_id));
server_streams.insert(stream_id as u32, server_stream.clone());
let client_stream = Arc::new(TransportStream::new(stream_id));
let drain = INITIAL_STREAM_WINDOW - 1000;
assert!(client_stream.try_consume_send_window(drain));
assert_eq!(client_stream.peer_send_window(), 1000);
let consumed = INITIAL_STREAM_WINDOW / 2 + 1;
let credit = server_stream
.record_app_consumed(consumed)
.expect("threshold crossed → credit granted");
server_stream.stage_window_update_credit(credit);
let (out_tx, mut out_rx) = mpsc::channel::<Vec<u8>>(4);
let (back_tx, back_rx) = mpsc::channel::<Vec<u8>>(4);
let server_outbound: Arc<ChannelTransport> = Arc::new(ChannelTransport {
tx: out_tx,
rx: Mutex::new(back_rx),
});
let _keep = back_tx;
flush_pending_window_updates(
&server_outbound,
&server_session,
session_id,
&server_streams,
)
.await;
let frame = tokio::time::timeout(std::time::Duration::from_millis(100), out_rx.recv())
.await
.expect("expected a WINDOW_UPDATE frame")
.expect("channel open");
let pv2 = PhantomPacket::from_wire(&frame).unwrap();
assert!(pv2.header.flags.contains(PacketFlags::WINDOW_UPDATE));
let pt = client_session
.decrypt_packet(&pv2.header, &pv2.payload)
.expect("decrypt WINDOW_UPDATE");
assert_eq!(pt.len(), 4);
let announced = u32::from_be_bytes([pt[0], pt[1], pt[2], pt[3]]);
assert_eq!(
announced, credit,
"WINDOW_UPDATE carries the relative credit (bytes consumed since last update)"
);
assert!(
out_rx.try_recv().is_err(),
"exactly one WINDOW_UPDATE must be emitted"
);
flush_pending_window_updates(
&server_outbound,
&server_session,
session_id,
&server_streams,
)
.await;
assert!(
out_rx.try_recv().is_err(),
"no spurious second WINDOW_UPDATE after the credit was already flushed"
);
client_stream.apply_peer_window_update(announced);
assert_eq!(client_stream.peer_send_window(), 1000 + credit);
}
#[tokio::test]
async fn priority_scheduler_drains_higher_priority_stream_first() {
    use std::time::Duration;

    let session_id = fixed_session_id();
    // The server half decrypts what the client-side scheduler emits below, so it is a
    // regular (non-underscore) binding.
    let (client_session, server_session) = paired_sessions(session_id);
    let (tx_a, mut rx_a) = mpsc::channel::<Vec<u8>>(32);
    let (tx_b, rx_b) = mpsc::channel::<Vec<u8>>(32);
    let transport: Arc<ChannelTransport> = Arc::new(ChannelTransport {
        tx: tx_a,
        rx: Mutex::new(rx_b),
    });
    // Keep the inbound sender alive so the transport's receive channel is not closed.
    let _keep = tx_b;

    // Two streams with three queued reliable frames each: stream 11 at priority 1
    // ("L*") and stream 22 at priority 100 ("H*").
    let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
    let low = Arc::new(TransportStream::new(11));
    low.set_priority(1);
    low.send_reliable(Bytes::from_static(b"L0")).await;
    low.send_reliable(Bytes::from_static(b"L1")).await;
    low.send_reliable(Bytes::from_static(b"L2")).await;
    streams.insert(11, low);
    let hi = Arc::new(TransportStream::new(22));
    hi.set_priority(100);
    hi.send_reliable(Bytes::from_static(b"H0")).await;
    hi.send_reliable(Bytes::from_static(b"H1")).await;
    hi.send_reliable(Bytes::from_static(b"H2")).await;
    streams.insert(22, hi);

    drain_streams_priority_ordered(&transport, &client_session, session_id, &streams).await;

    // Collect emitted frames until the channel is idle for 50 ms or closes.
    let mut order: Vec<&'static str> = Vec::new();
    while let Ok(Some(bytes)) =
        tokio::time::timeout(Duration::from_millis(50), rx_a.recv()).await
    {
        let packet = PhantomPacket::from_wire(&bytes).unwrap();
        let plaintext = server_session
            .decrypt_packet(&packet.header, &packet.payload)
            .expect("decrypt");
        let tag: &'static str = match &plaintext[..] {
            b"H0" => "H0",
            b"H1" => "H1",
            b"H2" => "H2",
            b"L0" => "L0",
            b"L1" => "L1",
            b"L2" => "L2",
            other => panic!("unexpected payload {:?}", other),
        };
        order.push(tag);
    }

    // Strict priority: the last high-priority frame precedes the first low-priority one.
    let first_low = order
        .iter()
        .position(|s| s.starts_with('L'))
        .unwrap_or(order.len());
    let last_high = order
        .iter()
        .rposition(|s| s.starts_with('H'))
        .unwrap_or_else(|| panic!("no high-priority frame was emitted: order = {:?}", order));
    assert!(
        last_high < first_low,
        "strict priority violated: order = {:?}",
        order
    );
}
#[tokio::test]
async fn v2_recv_echoes_path_validation_challenge_back_as_response() {
    use std::time::Duration;
    const PATH_ID: u8 = 7;

    let session_id = fixed_session_id();
    let (client_session, server_session) = paired_sessions(session_id);

    // Client → server: an encrypted PATH_VALIDATION challenge on path PATH_ID.
    let challenge = [0xDEu8; crate::transport::path::PATH_CHALLENGE_LEN];
    let header = PacketHeader::new(
        session_id,
        0,
        0,
        PacketFlags::new(PacketFlags::ENCRYPTED | PacketFlags::PATH_VALIDATION),
    )
    .with_epoch(client_session.current_epoch())
    .with_path_id(PATH_ID);
    let sealed = client_session
        .encrypt_packet(&header, &challenge)
        .expect("encrypt challenge");
    let inbound = PhantomPacket::new(header, sealed);

    // Server-side receive plumbing. Anything the server sends goes to `echo_rx`;
    // the keepalive sender holds the transport's inbound channel open.
    let (demux, _ctrl_rx) = StreamDemultiplexer::new(16);
    let demux = Arc::new(demux);
    let streams: Arc<DashMap<u32, Arc<TransportStream>>> = Arc::new(DashMap::new());
    let (deliver_tx, _deliver_rx) = mpsc::unbounded_channel::<(u32, Bytes)>();
    let undelivered = AtomicU64::new(0);
    let (echo_tx, mut echo_rx) = mpsc::channel::<Vec<u8>>(4);
    let (_back_tx_keepalive, back_rx) = mpsc::channel::<Vec<u8>>(4);
    let transport_send: Arc<ChannelTransport> = Arc::new(ChannelTransport {
        tx: echo_tx,
        rx: Mutex::new(back_rx),
    });
    let mut ack_buf = Vec::with_capacity(256);
    let mut path_validation_seq: u32 = 100;
    let obs = Observability::new(ObservabilityConfig::default());

    handle_packet(
        inbound,
        session_id,
        &server_session,
        &streams,
        &demux,
        &transport_send,
        &transport_send,
        &deliver_tx,
        &undelivered,
        &mut ack_buf,
        &mut path_validation_seq,
        &obs,
        LegType::Tcp,
    )
    .await;

    // The server must answer with a PATH_VALIDATION packet on the same path.
    let echo_bytes = tokio::time::timeout(Duration::from_millis(200), echo_rx.recv())
        .await
        .expect("echo should arrive")
        .expect("channel open");
    let echo = PhantomPacket::from_wire(&echo_bytes).unwrap();
    assert!(echo.header.flags.contains(PacketFlags::PATH_VALIDATION));
    assert_eq!(echo.header.path_id, PATH_ID);
    // Emitting the response advanced the path-validation sequence by exactly one.
    assert_eq!(path_validation_seq, 101);
}
#[tokio::test]
async fn zero_rtt_early_data_full_round_trip() {
    use std::time::{Duration, Instant};

    // The client sessions advance their handshakes on their own (the `connect_*`
    // constructors are not awaited here). Poll for completion against a deadline
    // instead of sleeping a fixed amount: that is fast when healthy and still
    // tolerant when the test host is loaded.
    let handshake_deadline = Duration::from_secs(5);
    let poll_interval = Duration::from_millis(5);

    let server_hs = HandshakeServer::new().unwrap();
    let server_pinned_key = server_hs.verifying_key().clone();
    let client_ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();

    // Phase 1: full handshake (Retry, then Success) to obtain a resumption hint.
    let (c1, s1) = ChannelTransport::pair();
    let phase1_session =
        PhantomSession::connect_with_transport("test:9000", c1, server_pinned_key.clone());
    let hello_bytes = s1.recv_bytes().await.unwrap();
    let ch = borsh::from_slice::<ClientHello>(&hello_bytes).unwrap();
    let retry = match server_hs.process_client_hello(&ch, 0, client_ip) {
        HandshakeResponse::Retry(r) => r,
        _ => panic!("expected Retry"),
    };
    s1.send_bytes(&borsh::to_vec(&retry).unwrap())
        .await
        .unwrap();
    let next = s1.recv_bytes().await.unwrap();
    let ch2 = borsh::from_slice::<ClientHello>(&next).unwrap();
    match server_hs.process_client_hello(&ch2, 0, client_ip) {
        HandshakeResponse::Success(sh, _session, _) => {
            s1.send_bytes(&borsh::to_vec(&sh).unwrap()).await.unwrap();
        }
        _ => panic!("expected Success"),
    }
    let deadline = Instant::now() + handshake_deadline;
    while phase1_session.connection_state() != ConnectionState::Connected
        && Instant::now() < deadline
    {
        tokio::time::sleep(poll_interval).await;
    }
    assert_eq!(
        phase1_session.connection_state(),
        ConnectionState::Connected
    );
    let hint = phase1_session
        .resumption_hint()
        .await
        .expect("phase 1 produced a resumption hint");
    let hint = (
        <[u8; 32]>::try_from(hint.session_id.as_slice()).expect("session_id is 32 bytes"),
        <[u8; 32]>::try_from(hint.resumption_secret.as_slice())
            .expect("resumption_secret is 32 bytes"),
    );

    // Phase 2: resume with the hint and carry sealed 0-RTT early data in the hello.
    let early_payload = b"zero-rtt application bytes".to_vec();
    let (c2, s2) = ChannelTransport::pair();
    let phase2_session = PhantomSession::connect_with_resumption(
        "test:9000",
        c2,
        server_pinned_key.clone(),
        hint,
        early_payload.clone(),
    )
    .expect("early_data is within the size cap");
    let hello_bytes = s2.recv_bytes().await.unwrap();
    let ch3 = borsh::from_slice::<ClientHello>(&hello_bytes).unwrap();
    assert!(
        ch3.early_data.is_some(),
        "phase 2 hello carries sealed 0-RTT early-data"
    );
    match server_hs.process_client_hello(&ch3, 0, client_ip) {
        HandshakeResponse::Success(sh, _session, early_data) => {
            assert_eq!(early_data.as_deref(), Some(&early_payload[..]));
            assert!(sh.early_data_accepted);
            s2.send_bytes(&borsh::to_vec(&sh).unwrap()).await.unwrap();
        }
        _ => {
            panic!("expected Success with accepted early-data — the resumption ticket is fresh")
        }
    }
    // Wait until the client is Connected and has recorded the server's 0-RTT verdict.
    let deadline = Instant::now() + handshake_deadline;
    while (phase2_session.connection_state() != ConnectionState::Connected
        || phase2_session.early_data_accepted().await.is_none())
        && Instant::now() < deadline
    {
        tokio::time::sleep(poll_interval).await;
    }
    assert_eq!(
        phase2_session.connection_state(),
        ConnectionState::Connected
    );
    assert_eq!(
        phase2_session.early_data_accepted().await,
        Some(true),
        "client must see the server accepted its 0-RTT early-data"
    );
    // Server-side transport halves stay alive until every assertion above has run.
    drop((s1, s2));
}
#[tokio::test]
async fn connect_pinned_with_resumption_rejects_malformed_hint() {
    let server_hs = HandshakeServer::new().unwrap();

    // Resumption fields are 32 bytes each; a 5-byte session_id is malformed.
    let malformed = ResumptionHint {
        session_id: vec![0u8; 5],
        resumption_secret: vec![0u8; 32],
    };

    // Host/port are placeholders: the test requires a ValidationError, i.e. the
    // hint must be rejected because of its shape.
    let err = connect_pinned_with_resumption(
        "127.0.0.1".to_string(),
        9,
        server_hs.verifying_key().to_bytes(),
        malformed,
        Vec::new(),
    )
    .await
    .expect_err("a 5-byte session_id must be rejected");
    assert!(
        matches!(err, CoreError::ValidationError(_)),
        "expected ValidationError, got {err:?}"
    );
}
}