use bytes::{Bytes, BytesMut};
use crossbeam_queue::ArrayQueue;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use super::crypto::{session_prefix_from_id, PacketCipher};
use super::protocol::{
EventFrame, NetHeader, PacketFlags, MAX_PACKET_SIZE, MAX_PAYLOAD_SIZE, NONCE_SIZE,
PAYLOAD_LEN_OFFSET,
};
/// Reusable builder for encrypted wire packets.
///
/// Events are length-framed into the `payload` scratch buffer, sealed in place
/// by the session's [`PacketCipher`], and appended behind a serialized
/// [`NetHeader`] in the `packet` scratch buffer. Both buffers are cleared and
/// reused on every call; each finished packet is detached with
/// `BytesMut::split().freeze()`.
pub struct PacketBuilder {
    /// Scratch buffer holding the framed (and then sealed) event payload.
    payload: BytesMut,
    /// AEAD cipher; `encrypt_in_place` also yields the per-packet nonce counter.
    cipher: PacketCipher,
    /// Scratch buffer in which header + sealed payload are assembled.
    packet: BytesMut,
    /// Session id written into headers and used to derive the nonce prefix.
    session_id: u64,
    /// Stamped into data-packet headers via `NetHeader::with_origin`.
    origin_hash: u64,
    /// Stamped into data-packet headers via `NetHeader::with_channel_hash`;
    /// `0` for freshly constructed builders.
    channel_hash: u16,
}
impl PacketBuilder {
    /// Bytes that `encrypt_in_place` appends to the payload (the 16-byte
    /// ChaCha20-Poly1305 tag); subtracted to recover the plaintext length.
    const AEAD_TAG_SIZE: usize = 16;

    /// Creates a builder with a cipher from `PacketCipher::new` and no origin hash.
    pub(crate) fn new(key: &[u8; 32], session_id: u64) -> Self {
        Self::with_origin(key, session_id, 0)
    }

    /// Creates a builder with a cipher from `PacketCipher::new` that stamps
    /// `origin_hash` into every data-packet header.
    pub fn with_origin(key: &[u8; 32], session_id: u64, origin_hash: u64) -> Self {
        Self::from_cipher(PacketCipher::new(key, session_id), session_id, origin_hash)
    }

    /// Creates a builder whose cipher draws nonce counters from `tx_counter`,
    /// so several builders sharing the same counter never hand out the same
    /// counter value under the same key.
    pub fn with_shared_counter(
        key: &[u8; 32],
        session_id: u64,
        origin_hash: u64,
        tx_counter: Arc<AtomicU64>,
    ) -> Self {
        Self::from_cipher(
            PacketCipher::with_shared_tx_counter(key, session_id, tx_counter),
            session_id,
            origin_hash,
        )
    }

    /// Common constructor: preallocates both scratch buffers at their maximum
    /// sizes and starts with no channel hash.
    fn from_cipher(cipher: PacketCipher, session_id: u64, origin_hash: u64) -> Self {
        Self {
            payload: BytesMut::with_capacity(MAX_PAYLOAD_SIZE),
            cipher,
            packet: BytesMut::with_capacity(MAX_PACKET_SIZE),
            session_id,
            origin_hash,
            channel_hash: 0,
        }
    }

    /// Replaces the cipher with a fresh `PacketCipher::new(key, session_id)`.
    ///
    /// Origin and channel hashes are left unchanged.
    /// NOTE(review): the new cipher's counter state is whatever
    /// `PacketCipher::new` starts with. Re-installing a key that was already
    /// used may reuse nonces; confirm callers only pass fresh keys.
    pub fn set_key(&mut self, key: &[u8; 32], session_id: u64) {
        self.cipher = PacketCipher::new(key, session_id);
        self.session_id = session_id;
    }

    /// Replaces the cipher with one that draws counters from `tx_counter`.
    pub fn set_key_shared(&mut self, key: &[u8; 32], session_id: u64, tx_counter: Arc<AtomicU64>) {
        self.cipher = PacketCipher::with_shared_tx_counter(key, session_id, tx_counter);
        self.session_id = session_id;
    }

    /// Sets the origin hash stamped into subsequent data packets.
    pub fn set_origin_hash(&mut self, origin_hash: u64) {
        self.origin_hash = origin_hash;
    }

    /// Sets the channel hash stamped into subsequent data packets.
    pub fn set_channel_hash(&mut self, channel_hash: u16) {
        self.channel_hash = channel_hash;
    }

    /// Frames `events`, seals them, and returns a complete data packet.
    ///
    /// The header carries this builder's session id, origin hash and channel
    /// hash. Its nonce bytes (12..24) are filled after encryption with the
    /// session prefix and the little-endian counter returned by the cipher.
    ///
    /// # Panics
    /// - If `events.len()` exceeds `NetHeader::MAX_EVENTS_PER_PACKET`. The
    ///   batching layer must split first; see [`Self::would_fit`].
    /// - If AEAD encryption fails, which is treated as an internal bug.
    #[inline]
    pub fn build(
        &mut self,
        stream_id: u64,
        sequence: u64,
        events: &[Bytes],
        flags: PacketFlags,
    ) -> Bytes {
        assert!(
            events.len() <= NetHeader::MAX_EVENTS_PER_PACKET as usize,
            "PacketBuilder::build called with {} events; \
             MAX_EVENTS_PER_PACKET is {}. The batching layer must \
             split before calling build().",
            events.len(),
            NetHeader::MAX_EVENTS_PER_PACKET,
        );
        self.reset_and_frame(events);
        // The AAD is computed over the plaintext length. After sealing, the wire
        // length field is rewritten with the same value (sealed length - tag).
        let header = NetHeader::new(
            self.session_id,
            stream_id,
            sequence,
            [0u8; NONCE_SIZE],
            self.payload.len() as u16,
            events.len() as u16,
            flags,
        )
        .with_origin(self.origin_hash)
        .with_channel_hash(self.channel_hash);
        let aad = header.aad();
        let mut header_bytes = header.to_bytes();
        let counter = match self.cipher.encrypt_in_place(&aad, &mut self.payload) {
            Ok(c) => c,
            Err(e) => panic!(
                "BUG: ChaCha20-Poly1305 encryption failed (session={:016x}, payload_len={}): {}",
                self.session_id,
                self.payload.len(),
                e
            ),
        };
        let payload_len = self.sealed_payload_len();
        self.finish_packet(&mut header_bytes, counter.to_le_bytes(), payload_len)
    }

    /// Like [`Self::build`], but also tags the header with `subprotocol_id`.
    ///
    /// # Panics
    /// Same conditions as [`Self::build`].
    #[inline]
    pub fn build_subprotocol(
        &mut self,
        stream_id: u64,
        sequence: u64,
        events: &[Bytes],
        flags: PacketFlags,
        subprotocol_id: u16,
    ) -> Bytes {
        assert!(
            events.len() <= NetHeader::MAX_EVENTS_PER_PACKET as usize,
            "PacketBuilder::build_subprotocol called with {} events; \
             MAX_EVENTS_PER_PACKET is {}",
            events.len(),
            NetHeader::MAX_EVENTS_PER_PACKET,
        );
        self.reset_and_frame(events);
        let header = NetHeader::new(
            self.session_id,
            stream_id,
            sequence,
            [0u8; NONCE_SIZE],
            self.payload.len() as u16,
            events.len() as u16,
            flags,
        )
        .with_origin(self.origin_hash)
        .with_channel_hash(self.channel_hash)
        .with_subprotocol(subprotocol_id);
        let aad = header.aad();
        let mut header_bytes = header.to_bytes();
        let counter = match self.cipher.encrypt_in_place(&aad, &mut self.payload) {
            Ok(c) => c,
            Err(e) => panic!(
                "BUG: ChaCha20-Poly1305 encryption failed (session={:016x}): {}",
                self.session_id, e
            ),
        };
        let payload_len = self.sealed_payload_len();
        self.finish_packet(&mut header_bytes, counter.to_le_bytes(), payload_len)
    }

    /// Builds an unencrypted handshake packet: a handshake header followed
    /// verbatim by `payload`.
    #[inline]
    pub fn build_handshake(&mut self, payload: &[u8]) -> Bytes {
        debug_assert!(
            payload.len() <= u16::MAX as usize,
            "handshake payload length {} would truncate the u16 wire field",
            payload.len(),
        );
        self.packet.clear();
        let header = NetHeader::handshake(payload.len() as u16);
        self.packet.extend_from_slice(&header.to_bytes());
        self.packet.extend_from_slice(payload);
        self.packet.split().freeze()
    }

    /// Builds an authenticated heartbeat: an empty plaintext sealed under the
    /// heartbeat header as AAD. The wire length field is always 0.
    ///
    /// # Panics
    /// If AEAD encryption fails, which is treated as an internal bug.
    #[inline]
    pub fn build_heartbeat(&mut self) -> Bytes {
        self.payload.clear();
        self.packet.clear();
        let header = NetHeader::heartbeat(self.session_id);
        let aad = header.aad();
        let mut header_bytes = header.to_bytes();
        let counter = match self.cipher.encrypt_in_place(&aad, &mut self.payload) {
            Ok(c) => c,
            Err(e) => panic!(
                "BUG: heartbeat AEAD encryption failed (session={:016x}): {}",
                self.session_id, e
            ),
        };
        self.finish_packet(&mut header_bytes, counter.to_le_bytes(), 0)
    }

    /// Estimates how many events of `avg_event_size` bytes fit in one packet.
    ///
    /// The result is capped at `NetHeader::MAX_EVENTS_PER_PACKET`, so it is
    /// always a count that [`Self::build`] accepts.
    #[inline]
    pub fn max_events_for_size(&self, avg_event_size: usize) -> usize {
        let per_event = avg_event_size
            .saturating_add(EventFrame::LEN_SIZE)
            .max(1);
        (MAX_PAYLOAD_SIZE / per_event).min(NetHeader::MAX_EVENTS_PER_PACKET as usize)
    }

    /// Returns `true` when `events` can be passed to [`Self::build`] as-is:
    /// the count is within `MAX_EVENTS_PER_PACKET` and the framed size is
    /// within `MAX_PAYLOAD_SIZE`.
    #[inline]
    pub fn would_fit(&self, events: &[Bytes]) -> bool {
        events.len() <= NetHeader::MAX_EVENTS_PER_PACKET as usize
            && EventFrame::calculate_size(events) <= MAX_PAYLOAD_SIZE
    }

    /// Session id this builder is currently keyed for.
    #[inline]
    pub fn session_id(&self) -> u64 {
        self.session_id
    }

    /// Clears both scratch buffers and frames `events` into `payload`.
    #[inline]
    fn reset_and_frame(&mut self, events: &[Bytes]) {
        self.payload.clear();
        self.packet.clear();
        EventFrame::write_events(events, &mut self.payload);
    }

    /// Plaintext length of the sealed payload (sealed length minus the AEAD
    /// tag), narrowed to the header's `u16` length field.
    #[inline]
    fn sealed_payload_len(&self) -> u16 {
        let plaintext_len = self.payload.len() - Self::AEAD_TAG_SIZE;
        debug_assert!(
            plaintext_len <= u16::MAX as usize,
            "payload length {} would truncate the u16 wire field; \
             revisit MAX_PAYLOAD_SIZE before raising the cap past u16::MAX + 16",
            plaintext_len,
        );
        plaintext_len as u16
    }

    /// Completes a sealed packet and detaches it from the scratch buffer.
    ///
    /// Writes the nonce (session prefix at bytes 12..16, little-endian counter
    /// at 16..24) and `payload_len` into `header_bytes`, then appends
    /// header + sealed payload.
    #[inline]
    fn finish_packet(&mut self, header_bytes: &mut [u8], counter_le: [u8; 8], payload_len: u16) -> Bytes {
        header_bytes[12..16].copy_from_slice(&session_prefix_from_id(self.session_id));
        header_bytes[16..24].copy_from_slice(&counter_le);
        header_bytes[PAYLOAD_LEN_OFFSET..PAYLOAD_LEN_OFFSET + 2]
            .copy_from_slice(&payload_len.to_le_bytes());
        self.packet.extend_from_slice(header_bytes);
        self.packet.extend_from_slice(&self.payload);
        self.packet.split().freeze()
    }
}
impl std::fmt::Debug for PacketBuilder {
    /// Prints the session id (zero-padded hex) and scratch-buffer capacities.
    /// The cipher field is not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let session_hex = format!("{:016x}", self.session_id);
        let mut out = f.debug_struct("PacketBuilder");
        out.field("session_id", &session_hex);
        out.field("payload_capacity", &self.payload.capacity());
        out.field("packet_capacity", &self.packet.capacity());
        out.finish()
    }
}
/// Pool of [`PacketBuilder`]s for one session, backed by a bounded `ArrayQueue`.
///
/// Every builder the pool creates shares `tx_counter`. This keeps packets built
/// through different leases on distinct nonce counters under the same key.
pub struct PacketPool {
    /// Idle builders. `get` pops from here and `PooledBuilder::drop` pushes back.
    builders: ArrayQueue<PacketBuilder>,
    /// Key used for every builder created by the pool.
    key: [u8; 32],
    /// Session id used for every builder created by the pool.
    session_id: u64,
    /// Origin hash applied to builders created by, or returned to, the pool.
    origin_hash: u64,
    /// Requested pool size, as passed to the constructor.
    capacity: usize,
    /// Nonce counter shared by all builders for the current key.
    tx_counter: Arc<AtomicU64>,
}
impl PacketPool {
    /// Creates a pool of `size` builders with no origin hash.
    pub fn new(size: usize, key: &[u8; 32], session_id: u64) -> Self {
        Self::with_origin(size, key, session_id, 0)
    }

    /// Creates a pool prefilled with `size` builders that share one TX counter
    /// and stamp `origin_hash` into their headers.
    pub fn with_origin(size: usize, key: &[u8; 32], session_id: u64, origin_hash: u64) -> Self {
        let tx_counter = Arc::new(AtomicU64::new(0));
        // `ArrayQueue::new` panics on a zero capacity. Keep at least one slot so
        // a zero-sized pool still works: it just retains at most one returned builder.
        let builders = ArrayQueue::new(size.max(1));
        for _ in 0..size {
            let _ = builders.push(PacketBuilder::with_shared_counter(
                key,
                session_id,
                origin_hash,
                tx_counter.clone(),
            ));
        }
        Self {
            builders,
            key: *key,
            session_id,
            origin_hash,
            capacity: size,
            tx_counter,
        }
    }

    /// Switches the pool to `key` / `session_id` and discards every idle builder.
    ///
    /// The shared TX counter restarts at 0 only when `key` actually changes.
    /// Re-installing the current key keeps the counter running. Resetting it
    /// would replay nonces already used with that key, and nonce reuse breaks
    /// ChaCha20-Poly1305.
    pub fn set_key(&mut self, key: &[u8; 32], session_id: u64) {
        if self.key != *key {
            self.tx_counter = Arc::new(AtomicU64::new(0));
        }
        self.key = *key;
        self.session_id = session_id;
        // Idle builders hold the previous cipher; `get` rebuilds them on demand.
        while self.builders.pop().is_some() {}
    }

    /// Leases a builder. An idle one is reused when available; otherwise a new
    /// one is created on the shared counter.
    #[inline]
    pub fn get(&self) -> PooledBuilder<'_> {
        let builder = self.builders.pop().unwrap_or_else(|| {
            PacketBuilder::with_shared_counter(
                &self.key,
                self.session_id,
                self.origin_hash,
                self.tx_counter.clone(),
            )
        });
        PooledBuilder {
            pool: self,
            builder: Some(builder),
        }
    }

    /// Requested pool size.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of idle builders currently queued.
    #[inline]
    pub fn available(&self) -> usize {
        self.builders.len()
    }

    /// Current session id.
    #[inline]
    pub fn session_id(&self) -> u64 {
        self.session_id
    }
}
impl std::fmt::Debug for PacketPool {
    /// Prints sizing information and the session id in zero-padded hex.
    /// The key is not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let idle = self.builders.len();
        let session_hex = format!("{:016x}", self.session_id);
        f.debug_struct("PacketPool")
            .field("capacity", &self.capacity)
            .field("available", &idle)
            .field("session_id", &session_hex)
            .finish()
    }
}
/// RAII lease on a [`PacketBuilder`] taken from a [`PacketPool`].
///
/// `builder` is `Some` for the whole life of the lease. `Drop` takes it and
/// offers it back to `pool`.
pub struct PooledBuilder<'a> {
    pool: &'a PacketPool,
    builder: Option<PacketBuilder>,
}
#[expect(
    clippy::expect_used,
    reason = "self.builder is Some between construction and Drop::drop; calling these methods after drop is a caller-side use-after-free invariant violation, not a recoverable runtime condition"
)]
impl PooledBuilder<'_> {
    /// Shared access to the leased builder.
    #[inline]
    fn leased(&self) -> &PacketBuilder {
        self.builder
            .as_ref()
            .expect("BUG: PooledBuilder used after drop")
    }

    /// Exclusive access to the leased builder.
    #[inline]
    fn leased_mut(&mut self) -> &mut PacketBuilder {
        self.builder
            .as_mut()
            .expect("BUG: PooledBuilder used after drop")
    }

    /// Delegates to [`PacketBuilder::build`].
    #[inline]
    pub fn build(
        &mut self,
        stream_id: u64,
        sequence: u64,
        events: &[Bytes],
        flags: PacketFlags,
    ) -> Bytes {
        let inner = self.leased_mut();
        inner.build(stream_id, sequence, events, flags)
    }

    /// Delegates to [`PacketBuilder::build_handshake`].
    #[inline]
    pub fn build_handshake(&mut self, payload: &[u8]) -> Bytes {
        let inner = self.leased_mut();
        inner.build_handshake(payload)
    }

    /// Delegates to [`PacketBuilder::build_heartbeat`].
    #[inline]
    pub fn build_heartbeat(&mut self) -> Bytes {
        let inner = self.leased_mut();
        inner.build_heartbeat()
    }

    /// Delegates to [`PacketBuilder::would_fit`].
    #[inline]
    pub fn would_fit(&self, events: &[Bytes]) -> bool {
        let inner = self.leased();
        inner.would_fit(events)
    }
}
impl Drop for PooledBuilder<'_> {
    /// Re-keys the builder if the pool's session changed, restores the pool's
    /// origin hash, and offers the builder back to the pool's queue.
    fn drop(&mut self) {
        let Some(mut builder) = self.builder.take() else {
            return;
        };
        let pool = self.pool;
        let stale = builder.session_id() != pool.session_id;
        if stale {
            builder.set_key_shared(&pool.key, pool.session_id, pool.tx_counter.clone());
        }
        builder.set_origin_hash(pool.origin_hash);
        // Best-effort: if the queue is already full the builder is dropped here.
        let _ = pool.builders.push(builder);
    }
}
impl std::fmt::Debug for PooledBuilder<'_> {
    /// Reports only whether the lease still holds its builder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let holding = self.builder.is_some();
        let mut out = f.debug_struct("PooledBuilder");
        out.field("has_builder", &holding);
        out.finish()
    }
}
use std::cell::RefCell;
use std::sync::Weak;
/// Per-thread cache slot for one `ThreadLocalPool`: a `Weak` downgraded from
/// the pool's `alive` handle (used to detect that the pool was dropped) plus
/// the builders cached on this thread.
type LocalBuildersEntry = (Weak<()>, Vec<PacketBuilder>);
thread_local! {
    // Per-thread builder caches, keyed by `ThreadLocalPool::pool_id`.
    static LOCAL_BUILDERS: RefCell<std::collections::HashMap<u64, LocalBuildersEntry>> =
        RefCell::new(std::collections::HashMap::new());
    // Counts cache accesses on this thread; drives the amortized sweep in
    // `maybe_reap_dead_pools`.
    static LOCAL_REAP_COUNTER: std::cell::Cell<u32> = const { std::cell::Cell::new(0) };
}
/// Number of cache accesses on a thread between sweeps that evict the cache
/// entries of dropped pools.
const REAP_INTERVAL: u32 = 4096;
/// Every `REAP_INTERVAL`-th call on the current thread, removes cache entries
/// whose owning pool is gone (their `Weak` has no strong references left).
///
/// The O(entries) `retain` is amortized so it does not run on every
/// acquire/release.
#[inline]
fn maybe_reap_dead_pools(pools: &mut std::collections::HashMap<u64, LocalBuildersEntry>) {
    let tick = LOCAL_REAP_COUNTER.with(|counter| {
        let advanced = counter.get().wrapping_add(1);
        counter.set(advanced);
        advanced
    });
    if !tick.is_multiple_of(REAP_INTERVAL) {
        return;
    }
    pools.retain(|_, entry| entry.0.strong_count() > 0);
}
/// Source of the unique ids that key each pool's slot in `LOCAL_BUILDERS`.
static NEXT_POOL_ID: AtomicU64 = AtomicU64::new(0);
/// Builder pool with a per-thread cache (`LOCAL_BUILDERS`) in front of a
/// shared bounded `ArrayQueue`.
///
/// All builders share `tx_counter`, so builders used on different threads
/// draw distinct nonce counters.
pub struct ThreadLocalPool {
    /// Key of this pool's slot in every thread's `LOCAL_BUILDERS` map.
    pool_id: u64,
    /// Strong handle whose `Weak`s in the thread caches reveal when the pool is dropped.
    alive: Arc<()>,
    /// Overflow/refill queue shared across threads.
    shared: ArrayQueue<PacketBuilder>,
    key: [u8; 32],
    session_id: u64,
    /// Origin hash applied to every builder handed out or returned.
    origin_hash: u64,
    /// Batch size for refilling a thread cache from `shared`.
    local_capacity: usize,
    /// Requested pool size, as passed to the constructor.
    capacity: usize,
    /// Nonce counter shared by all builders of this pool.
    tx_counter: Arc<AtomicU64>,
}
impl ThreadLocalPool {
    /// Thread-cache refill batch size used by [`Self::new`] and [`Self::with_origin`].
    pub const DEFAULT_LOCAL_CAPACITY: usize = 8;

    /// Creates a pool of `size` builders with no origin hash.
    pub fn new(size: usize, key: &[u8; 32], session_id: u64) -> Self {
        Self::with_local_capacity(size, key, session_id, 0, Self::DEFAULT_LOCAL_CAPACITY)
    }

    /// Creates a pool of `size` builders stamping `origin_hash` into headers.
    pub fn with_origin(size: usize, key: &[u8; 32], session_id: u64, origin_hash: u64) -> Self {
        Self::with_local_capacity(
            size,
            key,
            session_id,
            origin_hash,
            Self::DEFAULT_LOCAL_CAPACITY,
        )
    }

    /// Creates a pool whose shared queue is prefilled with `size` builders.
    /// Each thread refills its cache in batches of up to `local_capacity`.
    pub fn with_local_capacity(
        size: usize,
        key: &[u8; 32],
        session_id: u64,
        origin_hash: u64,
        local_capacity: usize,
    ) -> Self {
        let tx_counter = Arc::new(AtomicU64::new(0));
        // `ArrayQueue::new` panics on a zero capacity; keep at least one slot.
        let shared = ArrayQueue::new(size.max(1));
        for _ in 0..size {
            let _ = shared.push(PacketBuilder::with_shared_counter(
                key,
                session_id,
                origin_hash,
                tx_counter.clone(),
            ));
        }
        Self {
            pool_id: NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed),
            alive: Arc::new(()),
            shared,
            key: *key,
            session_id,
            origin_hash,
            local_capacity,
            capacity: size,
            tx_counter,
        }
    }

    /// Brings a builder back to this pool's defaults.
    ///
    /// Re-keys it onto the pool's session and shared counter if needed, and
    /// resets the per-lease header fields (origin hash to the pool's value,
    /// channel hash to 0). A caller's `set_channel_hash`/`set_origin_hash`
    /// therefore never leaks into the next lease.
    #[inline]
    fn prepare_builder(&self, builder: &mut PacketBuilder) {
        if builder.session_id() != self.session_id {
            builder.set_key_shared(&self.key, self.session_id, self.tx_counter.clone());
        }
        builder.set_origin_hash(self.origin_hash);
        builder.set_channel_hash(0);
    }

    /// Takes a builder from this thread's cache.
    ///
    /// When the cache is empty it is refilled with up to `local_capacity`
    /// builders from the shared queue. A new builder is constructed only when
    /// both are empty.
    #[inline]
    pub fn acquire(&self) -> PacketBuilder {
        LOCAL_BUILDERS.with(|pools| {
            let mut pools = pools.borrow_mut();
            maybe_reap_dead_pools(&mut pools);
            let (_, local) = pools
                .entry(self.pool_id)
                .or_insert_with(|| (Arc::downgrade(&self.alive), Vec::new()));
            if local.is_empty() {
                let refill = self.local_capacity.min(self.shared.len());
                local.extend(std::iter::from_fn(|| self.shared.pop()).take(refill));
            }
            match local.pop() {
                Some(mut builder) => {
                    self.prepare_builder(&mut builder);
                    builder
                }
                None => PacketBuilder::with_shared_counter(
                    &self.key,
                    self.session_id,
                    self.origin_hash,
                    self.tx_counter.clone(),
                ),
            }
        })
    }

    /// Returns a builder to the pool after resetting it (see `prepare_builder`).
    ///
    /// The builder stays in this thread's cache until the cache holds
    /// `2 * local_capacity` entries; beyond that it goes to the shared queue.
    /// If the shared queue is full, the builder is dropped.
    #[inline]
    pub fn release(&self, mut builder: PacketBuilder) {
        self.prepare_builder(&mut builder);
        LOCAL_BUILDERS.with(|pools| {
            let mut pools = pools.borrow_mut();
            maybe_reap_dead_pools(&mut pools);
            let (_, local) = pools
                .entry(self.pool_id)
                .or_insert_with(|| (Arc::downgrade(&self.alive), Vec::new()));
            if local.len() < self.local_capacity.saturating_mul(2) {
                local.push(builder);
            } else {
                let _ = self.shared.push(builder);
            }
        })
    }

    /// Requested pool size.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Builders currently in the shared queue; thread caches are not counted.
    #[inline]
    pub fn shared_available(&self) -> usize {
        self.shared.len()
    }

    /// Session id of this pool.
    #[inline]
    pub fn session_id(&self) -> u64 {
        self.session_id
    }

    /// Thread-cache refill batch size.
    #[inline]
    pub fn local_capacity(&self) -> usize {
        self.local_capacity
    }
}
impl std::fmt::Debug for ThreadLocalPool {
    /// Prints sizing information and the session id in zero-padded hex.
    /// The key is not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let queued = self.shared.len();
        let session_hex = format!("{:016x}", self.session_id);
        let mut out = f.debug_struct("ThreadLocalPool");
        out.field("capacity", &self.capacity);
        out.field("shared_available", &queued);
        out.field("local_capacity", &self.local_capacity);
        out.field("session_id", &session_hex);
        out.finish()
    }
}
/// RAII lease on a [`PacketBuilder`] obtained from a [`ThreadLocalPool`].
///
/// `builder` is `Some` for the whole life of the lease. `Drop` hands it back
/// through `ThreadLocalPool::release`.
pub struct ThreadLocalPooledBuilder<'a> {
    pool: &'a ThreadLocalPool,
    builder: Option<PacketBuilder>,
}
#[expect(
    clippy::expect_used,
    reason = "self.builder is Some between construction and Drop::drop; calling these methods after drop is a caller-side use-after-free invariant violation, not a recoverable runtime condition"
)]
impl ThreadLocalPooledBuilder<'_> {
    /// Shared access to the leased builder.
    #[inline]
    fn leased(&self) -> &PacketBuilder {
        self.builder
            .as_ref()
            .expect("BUG: PooledBuilder used after drop")
    }

    /// Exclusive access to the leased builder.
    #[inline]
    fn leased_mut(&mut self) -> &mut PacketBuilder {
        self.builder
            .as_mut()
            .expect("BUG: PooledBuilder used after drop")
    }

    /// Delegates to [`PacketBuilder::build`].
    #[inline]
    pub fn build(
        &mut self,
        stream_id: u64,
        sequence: u64,
        events: &[Bytes],
        flags: PacketFlags,
    ) -> Bytes {
        let inner = self.leased_mut();
        inner.build(stream_id, sequence, events, flags)
    }

    /// Delegates to [`PacketBuilder::build_handshake`].
    #[inline]
    pub fn build_handshake(&mut self, payload: &[u8]) -> Bytes {
        let inner = self.leased_mut();
        inner.build_handshake(payload)
    }

    /// Delegates to [`PacketBuilder::build_heartbeat`].
    #[inline]
    pub fn build_heartbeat(&mut self) -> Bytes {
        let inner = self.leased_mut();
        inner.build_heartbeat()
    }

    /// Delegates to [`PacketBuilder::build_subprotocol`].
    #[inline]
    pub fn build_subprotocol(
        &mut self,
        stream_id: u64,
        sequence: u64,
        events: &[Bytes],
        flags: PacketFlags,
        subprotocol_id: u16,
    ) -> Bytes {
        let inner = self.leased_mut();
        inner.build_subprotocol(stream_id, sequence, events, flags, subprotocol_id)
    }

    /// Delegates to [`PacketBuilder::set_channel_hash`].
    #[inline]
    pub fn set_channel_hash(&mut self, channel_hash: u16) {
        let inner = self.leased_mut();
        inner.set_channel_hash(channel_hash);
    }

    /// Delegates to [`PacketBuilder::set_origin_hash`].
    #[inline]
    pub fn set_origin_hash(&mut self, origin_hash: u64) {
        let inner = self.leased_mut();
        inner.set_origin_hash(origin_hash);
    }

    /// Delegates to [`PacketBuilder::would_fit`].
    #[inline]
    pub fn would_fit(&self, events: &[Bytes]) -> bool {
        let inner = self.leased();
        inner.would_fit(events)
    }
}
impl Drop for ThreadLocalPooledBuilder<'_> {
    /// Hands the leased builder, if still present, back to the pool.
    fn drop(&mut self) {
        let Some(builder) = self.builder.take() else {
            return;
        };
        self.pool.release(builder);
    }
}
impl std::fmt::Debug for ThreadLocalPooledBuilder<'_> {
    /// Reports only whether the lease still holds its builder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let holding = self.builder.is_some();
        let mut out = f.debug_struct("ThreadLocalPooledBuilder");
        out.field("has_builder", &holding);
        out.finish()
    }
}
/// A [`ThreadLocalPool`] behind an `Arc`, for sharing across threads.
pub type SharedLocalPool = Arc<ThreadLocalPool>;
/// Builds a [`ThreadLocalPool::new`] pool and wraps it in an `Arc`.
pub fn shared_local_pool(size: usize, key: &[u8; 32], session_id: u64) -> SharedLocalPool {
    let pool = ThreadLocalPool::new(size, key, session_id);
    SharedLocalPool::new(pool)
}
impl ThreadLocalPool {
    /// Leases a builder via [`Self::acquire`], wrapped in a guard that calls
    /// [`Self::release`] when dropped.
    #[inline]
    pub fn get(&self) -> ThreadLocalPooledBuilder<'_> {
        let builder = self.acquire();
        ThreadLocalPooledBuilder {
            pool: self,
            builder: Some(builder),
        }
    }
}
// Unit and regression tests for the builders and pools. Several tests inspect
// private state (`pool_id`, `builders`, `LOCAL_BUILDERS`, `LOCAL_REAP_COUNTER`)
// directly, which is why they live inside this module.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::disallowed_methods,
        reason = "test code legitimately uses std::sync::{Mutex,RwLock} for SUT setup; tests have no real poison concern"
    )]
    use super::*;
    // The constructor records size and session and uses the default cache batch size.
    #[test]
    fn test_thread_local_pool_basic() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF;
        let pool = ThreadLocalPool::new(8, &key, session_id);
        assert_eq!(pool.capacity(), 8);
        assert_eq!(pool.session_id(), session_id);
        assert_eq!(
            pool.local_capacity(),
            ThreadLocalPool::DEFAULT_LOCAL_CAPACITY
        );
    }
    // `with_origin` stores both the session id and the origin hash.
    #[test]
    fn packet_builder_with_origin_records_session_and_origin() {
        let key = [0x42u8; 32];
        let session_id = 0xAA_BBCC_DDEE_FF11;
        let origin_hash = 0xCAFEBABE_DEADBEEF;
        let builder = PacketBuilder::with_origin(&key, session_id, origin_hash);
        assert_eq!(builder.session_id, session_id);
        assert_eq!(builder.origin_hash, origin_hash);
    }
    // Dropping a lease returns the builder, so the queue is full again.
    #[test]
    fn packet_pool_new_and_get_round_trip() {
        let key = [0x33u8; 32];
        let session_id = 0xDEAD_FACE;
        let pool = PacketPool::new(4, &key, session_id);
        assert_eq!(pool.session_id, session_id);
        {
            let _b = pool.get();
        }
        assert_eq!(pool.builders.len(), 4);
    }
    // `with_origin` on the thread-local pool also uses the default cache batch size.
    #[test]
    fn thread_local_pool_with_origin_constructs_with_default_local_capacity() {
        let key = [0x55u8; 32];
        let session_id = 0xFEED_FACE;
        let origin_hash = 0x1234_5678;
        let pool = ThreadLocalPool::with_origin(8, &key, session_id, origin_hash);
        assert_eq!(pool.capacity(), 8);
        assert_eq!(pool.session_id(), session_id);
        assert_eq!(
            pool.local_capacity(),
            ThreadLocalPool::DEFAULT_LOCAL_CAPACITY
        );
    }
    // Debug output shows the session id as zero-padded 16-digit hex.
    #[test]
    fn packet_builder_debug_includes_session_id() {
        let key = [0x42u8; 32];
        let builder = PacketBuilder::new(&key, 0xAB_CDEF);
        let s = format!("{:?}", builder);
        assert!(s.contains("PacketBuilder"));
        assert!(s.contains("0000000000abcdef"), "got: {s}");
    }
    #[test]
    fn packet_pool_debug_includes_capacity_available_and_session() {
        let key = [0x42u8; 32];
        let pool = PacketPool::new(4, &key, 0x1234_5678_9ABC_DEF0);
        let s = format!("{:?}", pool);
        assert!(s.contains("PacketPool"));
        assert!(s.contains("capacity: 4"));
        assert!(s.contains("available"));
        assert!(s.contains("123456789abcdef0"), "got: {s}");
    }
    #[test]
    fn thread_local_pool_debug_includes_capacities_and_session() {
        let key = [0x42u8; 32];
        let pool = ThreadLocalPool::new(8, &key, 0xC0FFEE);
        let s = format!("{:?}", pool);
        assert!(s.contains("ThreadLocalPool"));
        assert!(s.contains("capacity: 8"));
        assert!(s.contains("shared_available"));
        assert!(s.contains("local_capacity"));
        assert!(s.contains("0000000000c0ffee"), "got: {s}");
    }
    // The PacketPool lease forwards handshake, heartbeat and would_fit; an
    // event larger than MAX_PAYLOAD_SIZE must not fit.
    #[test]
    fn pooled_builder_delegates_handshake_heartbeat_and_would_fit() {
        let key = [0x42u8; 32];
        let pool = PacketPool::new(2, &key, 0xABCD_1234);
        let mut b = pool.get();
        let hs = b.build_handshake(b"hello");
        assert!(!hs.is_empty(), "handshake packet should not be empty");
        let hb = b.build_heartbeat();
        assert!(!hb.is_empty(), "heartbeat packet should not be empty");
        assert!(b.would_fit(&[Bytes::from_static(b"x")]));
        let big = Bytes::from(vec![0u8; MAX_PAYLOAD_SIZE + 1]);
        assert!(!b.would_fit(&[big]));
    }
    // A manual acquire/release cycle keeps the pool's session id.
    #[test]
    fn test_thread_local_pool_acquire_release() {
        let key = [0x42u8; 32];
        let session_id = 0xDEADBEEF;
        let pool = ThreadLocalPool::new(4, &key, session_id);
        let builder = pool.acquire();
        assert_eq!(builder.session_id(), session_id);
        pool.release(builder);
        let builder2 = pool.acquire();
        assert_eq!(builder2.session_id(), session_id);
        pool.release(builder2);
    }
    // A packet built through the RAII guard parses back with the same header fields.
    #[test]
    fn test_thread_local_pool_raii_guard() {
        let key = [0x42u8; 32];
        let session_id = 0xCAFEBABE;
        let pool = ThreadLocalPool::new(4, &key, session_id);
        {
            let mut builder = pool.get();
            let events = vec![Bytes::from_static(b"test event")];
            let packet = builder.build(1, 42, &events, PacketFlags::NONE);
            let header = NetHeader::from_bytes(&packet).unwrap();
            assert_eq!(header.stream_id, 1);
            assert_eq!(header.sequence, 42);
            assert_eq!(header.event_count, 1);
        }
    }
    // Acquiring more builders than local_capacity exercises the batch refill path.
    #[test]
    fn test_thread_local_pool_batch_refill() {
        let key = [0x42u8; 32];
        let session_id = 0x1111;
        let pool = ThreadLocalPool::with_local_capacity(16, &key, session_id, 0, 4);
        let mut builders = Vec::new();
        for _ in 0..8 {
            builders.push(pool.acquire());
        }
        for b in &builders {
            assert_eq!(b.session_id(), session_id);
        }
        for b in builders {
            pool.release(b);
        }
    }
    // Repeated cycles with a tiny local cache stay keyed to the pool's session.
    #[test]
    fn test_thread_local_pool_overflow_to_shared() {
        let key = [0x42u8; 32];
        let session_id = 0x2222;
        let pool = ThreadLocalPool::with_local_capacity(8, &key, session_id, 0, 2);
        for _ in 0..10 {
            let b = pool.acquire();
            pool.release(b);
        }
        let builder = pool.acquire();
        assert_eq!(builder.session_id(), session_id);
    }
    // Leases can be taken through any clone of the Arc-wrapped pool.
    #[test]
    fn test_shared_local_pool() {
        let key = [0x42u8; 32];
        let session_id = 0x3333;
        let pool = shared_local_pool(8, &key, session_id);
        let pool_clone = pool.clone();
        let _b1 = pool.get();
        let _b2 = pool_clone.get();
    }
    // Nonce bytes 12..24 = 4-byte session prefix + 8-byte LE counter. Two
    // leases from one pool must draw consecutive counters from the shared counter.
    #[test]
    fn test_regression_pool_builders_share_tx_counter() {
        let key = [0x42u8; 32];
        let session_id = 0xAAAA;
        let pool = PacketPool::new(4, &key, session_id);
        let events = vec![Bytes::from_static(b"test")];
        let mut b1 = pool.get();
        let pkt1 = b1.build(0, 0, &events, PacketFlags::NONE);
        drop(b1);
        let mut b2 = pool.get();
        let pkt2 = b2.build(0, 1, &events, PacketFlags::NONE);
        drop(b2);
        let nonce1 = &pkt1[12..24];
        let nonce2 = &pkt2[12..24];
        assert_ne!(
            nonce1, nonce2,
            "two builders from the same pool must produce different nonces"
        );
        let counter1 = u64::from_le_bytes(nonce1[4..12].try_into().unwrap());
        let counter2 = u64::from_le_bytes(nonce2[4..12].try_into().unwrap());
        assert_eq!(counter1, 0, "first builder should use counter 0");
        assert_eq!(counter2, 1, "second builder should use counter 1");
    }
    // Builders created on different threads (separate TLS caches) still share one counter.
    #[test]
    fn test_regression_thread_local_pool_builders_share_tx_counter() {
        let key = [0x42u8; 32];
        let session_id = 0xBBBB;
        let pool = Arc::new(ThreadLocalPool::new(4, &key, session_id));
        let pool1 = pool.clone();
        let pkt1 = std::thread::spawn(move || {
            let mut b = pool1.get();
            b.build(0, 0, &[Bytes::from_static(b"test")], PacketFlags::NONE)
        })
        .join()
        .unwrap();
        let pool2 = pool.clone();
        let pkt2 = std::thread::spawn(move || {
            let mut b = pool2.get();
            b.build(0, 1, &[Bytes::from_static(b"test")], PacketFlags::NONE)
        })
        .join()
        .unwrap();
        let nonce1 = &pkt1[12..24];
        let nonce2 = &pkt2[12..24];
        assert_ne!(
            nonce1, nonce2,
            "builders from different threads must produce different nonces"
        );
        let counter1 = u64::from_le_bytes(nonce1[4..12].try_into().unwrap());
        let counter2 = u64::from_le_bytes(nonce2[4..12].try_into().unwrap());
        assert_ne!(
            counter1, counter2,
            "shared counter must prevent nonce reuse across threads"
        );
    }
    // Stress: 8 threads x 100 packets; every nonce must be unique.
    #[test]
    fn test_concurrent_pool_no_nonce_collision() {
        use std::collections::HashSet;
        use std::sync::Mutex;
        use std::thread;
        let key = [0x42u8; 32];
        let session_id = 0xCCCC;
        let pool = Arc::new(ThreadLocalPool::new(16, &key, session_id));
        let nonces = Arc::new(Mutex::new(Vec::new()));
        let num_threads = 8;
        let packets_per_thread = 100;
        let mut handles = Vec::new();
        for _ in 0..num_threads {
            let pool = pool.clone();
            let nonces = nonces.clone();
            handles.push(thread::spawn(move || {
                let mut local_nonces = Vec::with_capacity(packets_per_thread);
                for seq in 0..packets_per_thread {
                    let mut b = pool.get();
                    let pkt = b.build(
                        0,
                        seq as u64,
                        &[Bytes::from_static(b"x")],
                        PacketFlags::NONE,
                    );
                    let mut nonce = [0u8; 12];
                    nonce.copy_from_slice(&pkt[12..24]);
                    local_nonces.push(nonce);
                }
                nonces.lock().unwrap().extend(local_nonces);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let all_nonces = nonces.lock().unwrap();
        assert_eq!(all_nonces.len(), num_threads * packets_per_thread);
        let unique: HashSet<_> = all_nonces.iter().collect();
        assert_eq!(
            unique.len(),
            all_nonces.len(),
            "all {} nonces must be unique — found {} duplicates",
            all_nonces.len(),
            all_nonces.len() - unique.len()
        );
    }
    // Two pools on one thread must not serve each other's cached builders.
    #[test]
    fn test_regression_thread_local_pool_isolation() {
        let key_a = [0xAAu8; 32];
        let key_b = [0xBBu8; 32];
        let session_a = 0x1111;
        let session_b = 0x2222;
        let pool_a = ThreadLocalPool::new(4, &key_a, session_a);
        let pool_b = ThreadLocalPool::new(4, &key_b, session_b);
        let builder_a = pool_a.acquire();
        assert_eq!(builder_a.session_id(), session_a);
        pool_a.release(builder_a);
        let builder_b = pool_b.acquire();
        assert_eq!(
            builder_b.session_id(),
            session_b,
            "builder acquired from pool B must have pool B's session_id, \
             not pool A's — thread-local cache must be keyed by pool_id"
        );
        pool_b.release(builder_b);
        let builder_a2 = pool_a.acquire();
        assert_eq!(
            builder_a2.session_id(),
            session_a,
            "builder acquired from pool A after pool B activity must still \
             have pool A's session_id"
        );
    }
    // Each acquire/release pair ticks the reap counter twice. REAP_INTERVAL
    // pairs therefore guarantee at least one sweep, which must evict the
    // dropped pool's slot.
    #[test]
    fn dropped_thread_local_pool_evicts_tls_entry_within_reap_interval() {
        let key = [0x33u8; 32];
        let pool_a = ThreadLocalPool::new(4, &key, 0xA);
        let b = pool_a.acquire();
        pool_a.release(b);
        let pool_a_id = pool_a.pool_id;
        let with_a = LOCAL_BUILDERS.with(|m| m.borrow().contains_key(&pool_a_id));
        assert!(
            with_a,
            "pool A's TLS slot must be populated after acquire/release"
        );
        drop(pool_a);
        let pool_b = ThreadLocalPool::new(4, &key, 0xB);
        for _ in 0..REAP_INTERVAL {
            let b = pool_b.acquire();
            pool_b.release(b);
        }
        let still_has_a = LOCAL_BUILDERS.with(|m| m.borrow().contains_key(&pool_a_id));
        assert!(
            !still_has_a,
            "pool A's dead TLS entry must be reaped within REAP_INTERVAL \
             accesses — pre-fix this leaked forever (production OOM under \
             peer churn)"
        );
    }
    // With the counter reset to 0, only 4 ticks happen here, well below
    // REAP_INTERVAL, so the dead slot must still be present (no per-call sweep).
    #[test]
    fn dead_tls_entry_lingers_until_amortized_reap() {
        let key = [0x44u8; 32];
        LOCAL_REAP_COUNTER.with(|c| c.set(0));
        let pool_a = ThreadLocalPool::new(4, &key, 0xA);
        let b = pool_a.acquire();
        pool_a.release(b);
        let pool_a_id = pool_a.pool_id;
        drop(pool_a);
        let pool_b = ThreadLocalPool::new(4, &key, 0xB);
        let b = pool_b.acquire();
        pool_b.release(b);
        let still_has_a = LOCAL_BUILDERS.with(|m| m.borrow().contains_key(&pool_a_id));
        assert!(
            still_has_a,
            "dead entry must linger until amortized reap fires; \
             a regression here is the perf-degrading per-call retain returning"
        );
    }
    // Rotating to a new key drains idle builders; nonces from both keys are distinct.
    #[test]
    fn test_regression_set_key_drains_stale_builders() {
        use std::collections::HashSet;
        let key1 = [0x11u8; 32];
        let key2 = [0x22u8; 32];
        let session1 = 0xAAAA;
        let session2 = 0xBBBB;
        let mut pool = PacketPool::new(4, &key1, session1);
        let events = vec![Bytes::from_static(b"test")];
        let mut nonces_before = Vec::new();
        for seq in 0..3u64 {
            let mut b = pool.get();
            let pkt = b.build(0, seq, &events, PacketFlags::NONE);
            let mut nonce = [0u8; 12];
            nonce.copy_from_slice(&pkt[12..24]);
            nonces_before.push(nonce);
        }
        pool.set_key(&key2, session2);
        assert_eq!(
            pool.available(),
            0,
            "set_key must drain stale builders from the pool"
        );
        let mut nonces_after = Vec::new();
        for seq in 0..3u64 {
            let mut b = pool.get();
            let pkt = b.build(0, seq, &events, PacketFlags::NONE);
            let mut nonce = [0u8; 12];
            nonce.copy_from_slice(&pkt[12..24]);
            nonces_after.push(nonce);
        }
        let all_nonces: Vec<_> = nonces_before.iter().chain(&nonces_after).collect();
        let unique: HashSet<_> = all_nonces.iter().collect();
        assert_eq!(
            unique.len(),
            all_nonces.len(),
            "nonces must not collide across key rotations"
        );
    }
    // One event over MAX_EVENTS_PER_PACKET must trip build()'s assert.
    #[test]
    #[should_panic(expected = "PacketBuilder::build called with")]
    fn build_panics_on_event_count_exceeding_cap() {
        let key = [0x42u8; 32];
        let session_id = 0xCAFEu64;
        let pool = ThreadLocalPool::new(2, &key, session_id);
        let mut builder = pool.get();
        let too_many: Vec<Bytes> = (0..=NetHeader::MAX_EVENTS_PER_PACKET as usize)
            .map(|_| Bytes::new())
            .collect();
        let _ = builder.build(0, 0, &too_many, PacketFlags::NONE);
    }
    // Same cap check on the subprotocol path.
    #[test]
    #[should_panic(expected = "PacketBuilder::build_subprotocol called with")]
    fn build_subprotocol_panics_on_event_count_exceeding_cap() {
        let key = [0x42u8; 32];
        let session_id = 0xBABEu64;
        let pool = ThreadLocalPool::new(2, &key, session_id);
        let mut builder = pool.get();
        let too_many: Vec<Bytes> = (0..=NetHeader::MAX_EVENTS_PER_PACKET as usize)
            .map(|_| Bytes::new())
            .collect();
        let _ = builder.build_subprotocol(0, 0, &too_many, PacketFlags::NONE, 1);
    }
    // Exactly MAX_EVENTS_PER_PACKET events is accepted (the cap is inclusive).
    #[test]
    fn build_accepts_exactly_max_events_per_packet() {
        let key = [0x42u8; 32];
        let session_id = 0x4242u64;
        let pool = ThreadLocalPool::new(2, &key, session_id);
        let mut builder = pool.get();
        let one_event = vec![Bytes::from_static(b"hi")];
        let _ = builder.build(0, 0, &one_event, PacketFlags::NONE);
        let cap_events: Vec<Bytes> = (0..NetHeader::MAX_EVENTS_PER_PACKET as usize)
            .map(|_| Bytes::new())
            .collect();
        let _ = builder.build(0, 1, &cap_events, PacketFlags::NONE);
    }
}