mod batch;
pub mod behavior;
mod cancel_registry;
pub mod channel;
pub mod compute;
mod config;
pub mod contested;
pub mod continuity;
#[cfg(feature = "cortex")]
pub mod cortex;
mod crypto;
mod failure;
pub mod identity;
mod mesh;
#[cfg(feature = "dataforts")]
pub mod dataforts;
#[cfg(feature = "cortex")]
pub mod mesh_rpc;
#[cfg(feature = "cortex")]
pub mod mesh_rpc_metrics;
#[cfg(feature = "netdb")]
pub mod netdb;
mod pool;
mod protocol;
mod proxy;
#[cfg(feature = "redex")]
pub mod redex;
mod reliability;
mod reroute;
mod route;
mod router;
mod session;
pub mod state;
mod stream;
pub mod subnet;
pub mod subprotocol;
mod swarm;
mod transport;
#[cfg(feature = "nat-traversal")]
pub mod traversal;
#[cfg(target_os = "linux")]
mod linux;
pub use batch::AdaptiveBatcher;
pub use channel::{
AckReason, AuthGuard, AuthVerdict, ChannelConfig, ChannelConfigRegistry, ChannelError,
ChannelHash, ChannelId, ChannelName, ChannelPublisher, ChannelRegistry, MembershipMsg,
OnFailure, PublishConfig, PublishReport, SubscriberRoster, Visibility,
SUBPROTOCOL_CHANNEL_MEMBERSHIP,
};
pub use compute::{
DaemonError, DaemonFactoryRegistry, DaemonHost, DaemonHostConfig, DaemonRegistry, DaemonStats,
FactoryEntry, MeshDaemon, MigrationError, MigrationMessage, MigrationOrchestrator,
MigrationPhase, MigrationSourceHandler, MigrationState, MigrationTargetHandler,
PlacementDecision, Scheduler, SchedulerError, SUBPROTOCOL_MIGRATION,
};
pub use config::{ConnectionRole, NetAdapterConfig, ReliabilityConfig};
pub use contested::{
CorrelatedFailureConfig, CorrelatedFailureDetector, CorrelationVerdict, FailureCause,
PartitionDetector, PartitionPhase, PartitionRecord, ReconcileOutcome, Side,
SUBPROTOCOL_PARTITION,
};
pub use continuity::{
assess_continuity, CausalCone, Causality, ContinuityProof, ContinuityStatus, Discontinuity,
DiscontinuityReason, ForkRecord, HorizonDivergence, ObservationWindow, ProofError,
PropagationModel, SuperpositionPhase, SuperpositionState, SUBPROTOCOL_CONTINUITY,
};
#[cfg(feature = "cortex")]
pub use cortex::{
CortexAdapter, CortexAdapterConfig, CortexAdapterError, EventEnvelope, EventMeta,
FoldErrorPolicy, IntoRedexPayload, StartPosition, EVENT_META_SIZE,
};
pub use crypto::{CryptoError, SessionKeys, StaticKeypair};
pub use failure::{
CircuitBreaker, CircuitState, FailureDetector, FailureDetectorConfig, FailureStats,
LossSimulator, NodeStatus, RecoveryAction, RecoveryManager, RecoveryStats,
};
pub use identity::{
EntityError, EntityId, EntityKeypair, OriginStamp, PermissionToken, TokenCache, TokenError,
TokenScope,
};
pub use mesh::{MeshNode, MeshNodeConfig, PartitionFilter};
#[cfg(feature = "netdb")]
pub use netdb::{MemoriesFilter, NetDb, NetDbBuilder, NetDbError, NetDbSnapshot, TasksFilter};
pub use pool::{PacketBuilder, PacketPool, SharedLocalPool, ThreadLocalPool};
pub use protocol::{
EventFrame, NackPayload, NetHeader, PacketFlags, HEADER_SIZE, NONCE_SIZE, TAG_SIZE,
};
pub use proxy::{
ForwardResult, HopStats, MultiHopPacketBuilder, NetProxy, ProxyConfig, ProxyError, ProxyStats,
};
#[cfg(feature = "redex")]
pub use redex::{
FsyncPolicy, IndexOp, IndexStart, OrderedAppender, Redex, RedexEntry, RedexError, RedexEvent,
RedexFile, RedexFileConfig, RedexFlags, RedexFold, RedexIndex, TypedRedexFile,
};
pub use reliability::{FireAndForget, ReliabilityMode, ReliableStream, RetransmitDescriptor};
pub use reroute::ReroutePolicy;
pub use route::{
AggregateStats, RouteEntry, RouteFlags, RoutingHeader, RoutingTable, SchedulerStreamStats,
ROUTING_HEADER_SIZE,
};
pub use router::{FairScheduler, NetRouter, RouteAction, RouterConfig, RouterError, RouterStats};
pub use session::{NetSession, SessionManager, StreamState, TxAdmit, TxSlotGuard};
pub use state::{
CausalChainBuilder, CausalEvent, CausalLink, ChainError, EntityLog, HorizonEncoder, LogError,
LogIndex, ObservedHorizon, SnapshotStore, StateSnapshot, CAUSAL_LINK_SIZE, SUBPROTOCOL_CAUSAL,
SUBPROTOCOL_SNAPSHOT,
};
pub use stream::{
CloseBehavior, Reliability, Stream, StreamConfig, StreamError, StreamStats,
DEFAULT_STREAM_WINDOW_BYTES,
};
pub use subnet::{DropReason, ForwardDecision, SubnetGateway, SubnetId, SubnetPolicy, SubnetRule};
pub use subprotocol::{
negotiate, MigrationSubprotocolHandler, NegotiatedSet, OutboundMigrationMessage,
SubprotocolDescriptor, SubprotocolManifest, SubprotocolRegistry, SubprotocolVersion,
SUBPROTOCOL_NEGOTIATION,
};
pub use swarm::{
Capabilities, CapabilityAd, EdgeInfo, GraphStats, LocalGraph, NodeInfo, Pingwave,
MAX_GRAPH_NODES, MAX_SEEN_PINGWAVES, PINGWAVE_SIZE,
};
pub use transport::{NetSocket, PacketReceiver, PacketSender, ParsedPacket, SocketBufferConfig};
use async_trait::async_trait;
use bytes::Bytes;
use crossbeam_queue::SegQueue;
use dashmap::DashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use crate::adapter::{Adapter, ShardPollResult};
use crate::error::AdapterError;
use crate::event::{Batch, StoredEvent};
use crypto::NoiseHandshake;
use session::SessionManager as SessionMgr;
use transport::NetSocket as Socket;
pub use routing::{route_to_shard, stream_id_from_bytes, stream_id_from_key};
/// Returns the current wall-clock time as nanoseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; a value that does not fit in a
/// `u64` saturates to `u64::MAX`.
#[inline]
pub(crate) fn current_timestamp() -> u64 {
    let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        // Clock is before the epoch: report zero elapsed time.
        Err(_) => 0,
    };
    if nanos > u128::from(u64::MAX) {
        u64::MAX
    } else {
        nanos as u64
    }
}
/// Returns the current wall-clock time as microseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`. A value that does not fit in a
/// `u64` saturates to `u64::MAX` instead of being silently truncated, which
/// matches the overflow behaviour of `current_timestamp`.
#[inline]
pub(crate) fn current_timestamp_micros() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| {
            let micros = elapsed.as_micros();
            // Clamp before narrowing so the cast below can never wrap.
            micros.min(u128::from(u64::MAX)) as u64
        })
        .unwrap_or(0)
}
/// Hash-based helpers for deriving stream ids and picking a shard.
mod routing {
    use xxhash_rust::xxh3::xxh3_64;

    /// Derives a stream id by hashing `data` with XXH3-64.
    #[inline]
    pub fn stream_id_from_bytes(data: &[u8]) -> u64 {
        xxh3_64(data)
    }

    /// Derives a stream id by hashing the UTF-8 bytes of `key` with XXH3-64.
    #[inline]
    pub fn stream_id_from_key(key: &str) -> u64 {
        let raw = key.as_bytes();
        xxh3_64(raw)
    }

    /// Maps `data` onto a shard index in `0..num_shards`.
    ///
    /// # Panics
    ///
    /// Panics if `num_shards` is zero.
    #[inline]
    pub fn route_to_shard(data: &[u8], num_shards: u16) -> u16 {
        assert!(num_shards > 0, "num_shards must be > 0");
        let digest = xxh3_64(data);
        // The remainder is strictly below `num_shards`, so it always fits in u16.
        (digest % u64::from(num_shards)) as u16
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_stream_id_deterministic() {
            let payload = b"test event data";
            let first = stream_id_from_bytes(payload);
            let second = stream_id_from_bytes(payload);
            assert_eq!(first, second);
        }

        #[test]
        fn test_stream_id_different_for_different_data() {
            assert_ne!(
                stream_id_from_bytes(b"event1"),
                stream_id_from_bytes(b"event2")
            );
        }

        #[test]
        fn test_stream_id_from_key() {
            assert_ne!(stream_id_from_key("user:12345"), 0);
        }

        #[test]
        fn test_route_to_shard_range() {
            let num_shards = 16u16;
            for n in 0..1000 {
                let payload = format!("event_{}", n);
                assert!(route_to_shard(payload.as_bytes(), num_shards) < num_shards);
            }
        }

        #[test]
        #[should_panic(expected = "num_shards must be > 0")]
        fn test_route_to_shard_zero_shards_panics() {
            route_to_shard(b"test", 0);
        }

        #[test]
        fn test_route_to_shard_distribution() {
            let num_shards = 8u16;
            let mut counts = [0u32; 8];
            for n in 0..8000 {
                let payload = format!("event_{}", n);
                let slot = usize::from(route_to_shard(payload.as_bytes(), num_shards));
                counts[slot] += 1;
            }
            // 8000 events over 8 shards: a uniform hash gives ~1000 per shard.
            let expected = 1000;
            for count in counts {
                assert!(count > expected / 2, "shard count {} too low", count);
                assert!(count < expected * 2, "shard count {} too high", count);
            }
        }
    }
}
/// Shared map from shard id to that shard's lock-free queue of received events.
/// `process_packet` pushes into it and `poll_shard` pops from it.
type InboundQueues = Arc<DashMap<u16, SegQueue<StoredEvent>>>;
/// Per-source-address rate limiter for inbound handshake packets.
///
/// Each source gets a counter and the instant its current window started.
/// Stale entries are swept lazily from inside `check_and_record`.
pub(crate) struct HandshakePacer {
    /// `(attempts in current window, window start)` keyed by source address.
    entries: std::collections::HashMap<std::net::SocketAddr, (u32, std::time::Instant)>,
    /// Maximum attempts a single source may make per window.
    max_per_window: u32,
    /// Length of one counting window.
    window: std::time::Duration,
    /// When the last sweep of `entries` ran.
    last_gc: std::time::Instant,
    /// Entry count at or above which a sweep is attempted on every call.
    gc_size_threshold: usize,
}

impl HandshakePacer {
    /// Creates a pacer allowing `max_per_window` attempts per source per `window`.
    pub(crate) fn new(max_per_window: u32, window: std::time::Duration) -> Self {
        let created_at = std::time::Instant::now();
        Self {
            entries: std::collections::HashMap::new(),
            max_per_window,
            window,
            last_gc: created_at,
            gc_size_threshold: 4096,
        }
    }

    /// Records one attempt from `source` and reports whether it is within the limit.
    ///
    /// Returns `true` while the source has made at most `max_per_window`
    /// attempts in its current window, `false` once it has exceeded that.
    pub(crate) fn check_and_record(&mut self, source: std::net::SocketAddr) -> bool {
        let now = std::time::Instant::now();

        // Sweep when a full window has passed since the last sweep, or when the
        // table has grown to the size threshold.
        let sweep_due = self.entries.len() >= self.gc_size_threshold
            || now.duration_since(self.last_gc) >= self.window;
        if sweep_due {
            // Keep entries for two windows so a recently active source is not forgotten early.
            let max_age = self.window.saturating_mul(2);
            self.entries
                .retain(|_, slot| now.duration_since(slot.1) < max_age);
            self.last_gc = now;
        }

        let (count, window_start) = self.entries.entry(source).or_insert((0, now));
        if now.duration_since(*window_start) > self.window {
            // The previous window expired: start counting afresh.
            *count = 0;
            *window_start = now;
        }
        *count = count.saturating_add(1);
        *count <= self.max_per_window
    }
}
/// Adapter that exchanges event batches with a peer over a socket, using
/// session keys obtained from a handshake performed in `init`.
pub struct NetAdapter {
// Configuration; checked with `validate()` in `NetAdapter::new`.
config: NetAdapterConfig,
// Socket created in `init`; `None` before that.
socket: Option<Arc<Socket>>,
// Session created in `init` after the handshake succeeds; `None` before that.
session: Option<Arc<NetSession>>,
// Built from `config.session_timeout`; receives the session in `init` and
// is consulted by `is_healthy`.
session_manager: SessionMgr,
// Per-shard queues of received events, shared with the receiver task.
inbound: InboundQueues,
// Join handles of the background tasks spawned in `init`; drained in `shutdown`.
tasks: TokioMutex<Vec<JoinHandle<()>>>,
// Set to `true` by `shutdown`; polled by the background tasks.
shutdown: Arc<AtomicBool>,
// Signalled by `shutdown` to wake the background tasks.
shutdown_notify: Arc<Notify>,
// `true` between a successful `init` and the next `shutdown`.
initialized: AtomicBool,
// Per-source limiter applied to handshake packets on the responder path.
handshake_pacer: parking_lot::Mutex<HandshakePacer>,
}
impl NetAdapter {
/// Builds an uninitialised adapter from `config`.
///
/// No socket is opened and no handshake is attempted here; that happens in `init`.
///
/// # Errors
///
/// Returns `AdapterError::Fatal` when `config.validate()` rejects the configuration.
pub fn new(config: NetAdapterConfig) -> Result<Self, AdapterError> {
    // Handshake pacing applied per source address on the responder path.
    const HANDSHAKES_PER_WINDOW: u32 = 5;
    const HANDSHAKE_WINDOW: std::time::Duration = std::time::Duration::from_secs(1);

    if let Err(e) = config.validate() {
        return Err(AdapterError::Fatal(format!("invalid config: {}", e)));
    }

    // Read the timeout before `config` is moved into the struct.
    let session_manager = SessionMgr::new(config.session_timeout);
    let pacer = HandshakePacer::new(HANDSHAKES_PER_WINDOW, HANDSHAKE_WINDOW);

    Ok(Self {
        config,
        socket: None,
        session: None,
        session_manager,
        inbound: Arc::new(DashMap::new()),
        tasks: TokioMutex::new(Vec::new()),
        shutdown: Arc::new(AtomicBool::new(false)),
        shutdown_notify: Arc::new(Notify::new()),
        initialized: AtomicBool::new(false),
        handshake_pacer: parking_lot::Mutex::new(pacer),
    })
}
/// Runs the handshake, retrying failed attempts with a linear backoff.
///
/// Up to `config.handshake_retries` attempts are made (always at least one).
/// After the n-th failed attempt the task sleeps `100 * n` milliseconds,
/// capped at five seconds, before trying again.
///
/// # Errors
///
/// Returns the error of the final attempt once the attempt budget is exhausted.
async fn perform_handshake(
    &self,
    socket: &Socket,
) -> Result<(SessionKeys, std::net::SocketAddr), AdapterError> {
    const HANDSHAKE_RETRY_SLEEP_CAP_MS: u64 = 5_000;

    let max_attempts = self.config.handshake_retries;
    let mut attempt = 0;
    loop {
        attempt += 1;
        let failure = match self.try_handshake(socket).await {
            Ok(established) => return Ok(established),
            Err(err) => err,
        };
        if attempt >= max_attempts {
            return Err(failure);
        }
        tracing::warn!(
            attempt = attempt,
            max = max_attempts,
            error = %failure,
            "handshake failed, retrying"
        );
        let delay_ms = 100u64
            .saturating_mul(attempt as u64)
            .min(HANDSHAKE_RETRY_SLEEP_CAP_MS);
        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
    }
}
/// Performs a single handshake attempt in the role selected by
/// `config.is_initiator()`.
///
/// On success returns the derived session keys together with the peer address:
/// `config.peer_addr` for the initiator, or the source address of the accepted
/// handshake packet for the responder.
///
/// # Errors
///
/// * `AdapterError::Fatal` when the required key material is missing from the
///   config, or when handshake construction or key extraction fails.
/// * `AdapterError::Connection` on send/receive failures, on handshake message
///   read/write failures, or when no acceptable handshake packet arrives
///   within `config.handshake_timeout`.
async fn try_handshake(
&self,
socket: &Socket,
) -> Result<(SessionKeys, std::net::SocketAddr), AdapterError> {
let timeout = self.config.handshake_timeout;
let socket_arc = socket.socket_arc();
if self.config.is_initiator() {
// Initiator: send the first message, then wait for the peer's reply.
let peer_pubkey = self
.config
.peer_static_pubkey
.as_ref()
.ok_or_else(|| AdapterError::Fatal("missing peer public key".into()))?;
let mut handshake = NoiseHandshake::initiator(&self.config.psk, peer_pubkey)
.map_err(|e| AdapterError::Fatal(format!("handshake init failed: {}", e)))?;
let msg1 = handshake
.write_message(&[])
.map_err(|e| AdapterError::Connection(format!("write_message failed: {}", e)))?;
// NOTE(review): the builder is created with an all-zero 32-byte array and 0,
// presumably a placeholder key and session id since no session exists yet — confirm.
let mut builder = PacketBuilder::new(&[0u8; 32], 0);
let packet = builder.build_handshake(&msg1);
socket
.send_to(&packet, self.config.peer_addr)
.await
.map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
// Wait (bounded by `timeout`) for a handshake-flagged packet from the
// configured peer; anything else is ignored.
let (parsed, _source) = tokio::time::timeout(timeout, async {
let mut recv_buf = [0u8; protocol::MAX_PACKET_SIZE];
loop {
let (n, source) = socket_arc
.recv_from(&mut recv_buf)
.await
.map_err(|e| AdapterError::Connection(format!("recv failed: {}", e)))?;
// Datagrams from any other address are skipped.
if source != self.config.peer_addr {
continue;
}
let data = bytes::Bytes::copy_from_slice(&recv_buf[..n]);
if let Some(p) = ParsedPacket::parse(data, source) {
if p.header.flags.is_handshake() {
return Ok::<_, AdapterError>((p, source));
}
}
}
})
.await
// Outer `?` handles the timeout, inner `?` the receive error.
.map_err(|_| AdapterError::Connection("handshake timeout".into()))??;
handshake
.read_message(&parsed.payload)
.map_err(|e| AdapterError::Connection(format!("read_message failed: {}", e)))?;
let keys = handshake
.into_session_keys()
.map_err(|e| AdapterError::Fatal(format!("key extraction failed: {}", e)))?;
Ok((keys, self.config.peer_addr))
} else {
// Responder: wait for a first message from any source, then reply to it.
let keypair = self
.config
.static_keypair
.as_ref()
.ok_or_else(|| AdapterError::Fatal("missing static keypair".into()))?;
let (parsed, source) = tokio::time::timeout(timeout, async {
loop {
// A fresh buffer is allocated per datagram because it is frozen and
// handed to the parser below.
let mut recv_buf = bytes::BytesMut::with_capacity(protocol::MAX_PACKET_SIZE);
recv_buf.resize(protocol::MAX_PACKET_SIZE, 0);
let (n, source) = socket_arc
.recv_from(&mut recv_buf)
.await
.map_err(|e| AdapterError::Connection(format!("recv failed: {}", e)))?;
recv_buf.truncate(n);
let data = recv_buf.freeze();
if let Some(p) = ParsedPacket::parse(data, source) {
if p.header.flags.is_handshake() {
// Per-source pacing; the lock is released before the next await.
let allowed = self.handshake_pacer.lock().check_and_record(source);
if !allowed {
tracing::debug!(
%source,
"handshake responder: dropping packet from \
rate-limited source"
);
continue;
}
return Ok::<_, AdapterError>((p, source));
}
}
}
})
.await
// Outer `?` handles the timeout, inner `?` the receive error.
.map_err(|_| AdapterError::Connection("handshake timeout".into()))??;
let mut handshake = NoiseHandshake::responder(&self.config.psk, keypair)
.map_err(|e| AdapterError::Fatal(format!("handshake init failed: {}", e)))?;
handshake
.read_message(&parsed.payload)
.map_err(|e| AdapterError::Connection(format!("read_message failed: {}", e)))?;
let msg2 = handshake
.write_message(&[])
.map_err(|e| AdapterError::Connection(format!("write_message failed: {}", e)))?;
let mut builder = PacketBuilder::new(&[0u8; 32], 0);
let packet = builder.build_handshake(&msg2);
// The reply goes to the address the accepted handshake packet came from.
socket
.send_to(&packet, source)
.await
.map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
let keys = handshake
.into_session_keys()
.map_err(|e| AdapterError::Fatal(format!("key extraction failed: {}", e)))?;
Ok((keys, source))
}
}
/// Handles one received datagram for an established session.
///
/// The packet is dropped silently when it fails to parse, is a handshake
/// packet, has an invalid length, carries a different session id, fails
/// decryption, or has a receive counter the cipher refuses to admit.
/// Heartbeats are verified only when they come from the session's peer address
/// and never produce events. Events of a fresh data packet are pushed onto the
/// inbound queue of shard `stream_id % num_shards` (shard `0` when
/// `num_shards` is zero); a packet reported as already seen by the stream's
/// reliability state is logged and discarded.
fn process_packet(
    data: Bytes,
    source: std::net::SocketAddr,
    session: &NetSession,
    inbound: &InboundQueues,
    num_shards: u16,
) {
    let mut parsed = match ParsedPacket::parse(data, source) {
        Some(packet) => packet,
        None => return,
    };

    // Handshake packets are never processed on an established session.
    if parsed.header.flags.is_handshake() {
        return;
    }
    let heartbeat = parsed.header.flags.is_heartbeat();
    // Only data packets are subject to the length check.
    if !heartbeat && !parsed.is_valid_length() {
        return;
    }
    if parsed.header.session_id != session.session_id() {
        return;
    }
    if heartbeat {
        if source == session.peer_addr() {
            session.verify_and_touch_heartbeat(&parsed);
        }
        return;
    }

    let aad = parsed.header.aad();
    // The receive counter is read from bytes 4..12 of the nonce, little-endian.
    let counter_bytes: [u8; 8] = parsed.header.nonce[4..12]
        .try_into()
        .unwrap_or([0u8; 8]);
    let counter = u64::from_le_bytes(counter_bytes);
    let rx_cipher = session.rx_cipher();
    let ciphertext = std::mem::take(&mut parsed.payload);
    // The counter is admitted only after the payload has decrypted successfully.
    let plaintext = match rx_cipher.decrypt_to_bytes(counter, &aad, ciphertext) {
        Ok(plain) if rx_cipher.try_admit_rx_counter(counter) => plain,
        _ => return,
    };

    let events = EventFrame::read_events(plaintext, parsed.header.event_count);
    let stream_id = parsed.header.stream_id;
    let shard_id = match num_shards {
        0 => 0,
        shards => (stream_id % shards as u64) as u16,
    };

    let seq = parsed.header.sequence;
    let is_fresh = {
        let stream = session.get_or_create_stream(stream_id);
        let fresh = stream.with_reliability(|r| r.on_receive(seq));
        stream.update_rx_seq(seq);
        fresh
    };

    if !is_fresh {
        tracing::debug!(
            seq = parsed.header.sequence,
            stream_id,
            "Dropping duplicate packet"
        );
    } else {
        use std::fmt::Write;
        let queue = inbound.entry(shard_id).or_default();
        for (index, body) in events.into_iter().enumerate() {
            // Event ids have the form "<sequence>:<index within packet>".
            let mut event_id = String::with_capacity(24);
            let _ = write!(event_id, "{}:{}", seq, index);
            queue.push(StoredEvent::new(event_id, body, seq, shard_id));
        }
    }
    session.touch();
}
/// Spawns the background task that receives packets through the batched
/// receiver and feeds them to `process_packet`.
///
/// The task ends when `shutdown` is set, when `shutdown_notify` fires, or when
/// the receiver reports `ConnectionReset`. Other receive errors are logged
/// (unless shutdown is already in progress) and the loop continues.
#[cfg(target_os = "linux")]
fn spawn_receiver(
    shutdown: Arc<AtomicBool>,
    shutdown_notify: Arc<Notify>,
    socket: Arc<Socket>,
    session: Arc<NetSession>,
    inbound: InboundQueues,
    num_shards: u16,
) -> JoinHandle<()> {
    let mut receiver = transport::BatchedPacketReceiver::new(socket.socket_arc());
    tokio::spawn(async move {
        loop {
            // Create the notification future BEFORE reading the flag. A
            // `notify_waiters()` issued after this point is guaranteed to wake
            // it, so a shutdown that lands between the flag check and the
            // `select!` can no longer be missed, which would otherwise leave
            // the task parked in `recv()` until another packet arrives.
            let shutdown_signal = shutdown_notify.notified();
            if shutdown.load(Ordering::Acquire) {
                break;
            }
            tokio::select! {
                result = receiver.recv() => {
                    match result {
                        Ok((data, source)) => {
                            Self::process_packet(data, source, &session, &inbound, num_shards);
                        }
                        Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
                            tracing::warn!("batch receiver thread exited, stopping receiver");
                            break;
                        }
                        Err(e) => {
                            if !shutdown.load(Ordering::Acquire) {
                                tracing::warn!(error = %e, "receive error");
                            }
                        }
                    }
                }
                _ = shutdown_signal => {
                    break;
                }
            }
        }
    })
}
/// Spawns the background task that receives packets and feeds them to
/// `process_packet`.
///
/// The task ends when `shutdown` is set or when `shutdown_notify` fires.
/// Receive errors are logged (unless shutdown is already in progress) and the
/// loop continues.
#[cfg(not(target_os = "linux"))]
fn spawn_receiver(
    shutdown: Arc<AtomicBool>,
    shutdown_notify: Arc<Notify>,
    socket: Arc<Socket>,
    session: Arc<NetSession>,
    inbound: InboundQueues,
    num_shards: u16,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut receiver = PacketReceiver::new(socket.socket_arc());
        loop {
            // Create the notification future BEFORE reading the flag. A
            // `notify_waiters()` issued after this point is guaranteed to wake
            // it, so a shutdown that lands between the flag check and the
            // `select!` can no longer be missed, which would otherwise leave
            // the task parked in `recv()` until another packet arrives.
            let shutdown_signal = shutdown_notify.notified();
            if shutdown.load(Ordering::Acquire) {
                break;
            }
            tokio::select! {
                result = receiver.recv() => {
                    match result {
                        Ok((data, source)) => {
                            Self::process_packet(data, source, &session, &inbound, num_shards);
                        }
                        Err(e) => {
                            if !shutdown.load(Ordering::Acquire) {
                                tracing::warn!(error = %e, "receive error");
                            }
                        }
                    }
                }
                _ = shutdown_signal => {
                    break;
                }
            }
        }
    })
}
/// Spawns the background task that sends a heartbeat packet to `peer_addr`
/// once per `interval`.
///
/// The task stops when `shutdown_notify` fires, or on the first tick at which
/// `shutdown` is set or the session is no longer active. A failed send is
/// logged and does not stop the task.
fn spawn_heartbeat(
    shutdown: Arc<AtomicBool>,
    shutdown_notify: Arc<Notify>,
    socket: Arc<Socket>,
    session: Arc<NetSession>,
    interval: std::time::Duration,
    peer_addr: std::net::SocketAddr,
) -> JoinHandle<()> {
    let heartbeat_loop = async move {
        let mut ticker = tokio::time::interval(interval);
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    let stopping = shutdown.load(Ordering::Acquire);
                    if stopping || !session.is_active() {
                        break;
                    }
                    let packet = session.build_heartbeat();
                    match socket.send_to(&packet, peer_addr).await {
                        Ok(_) => {}
                        Err(e) => tracing::warn!(error = %e, "heartbeat send failed"),
                    }
                }
                _ = shutdown_notify.notified() => {
                    break;
                }
            }
        }
    };
    tokio::spawn(heartbeat_loop)
}
}
#[async_trait]
impl Adapter for NetAdapter {
/// Opens the socket, performs the handshake and starts the background tasks.
///
/// Calling `init` on an already initialised adapter is a no-op.
///
/// Socket buffer sizes start from `SocketBufferConfig::default()` and each
/// size set in the config overrides its default individually, so specifying
/// only one of the two is honoured rather than ignored.
///
/// The socket is stored on the adapter only after the handshake succeeds. If
/// the handshake fails the socket is dropped on return, so a later `init`
/// retry can bind the same address again.
///
/// # Errors
///
/// Returns `AdapterError::Connection` when the socket cannot be created, and
/// whatever `perform_handshake` returns when the handshake fails.
async fn init(&mut self) -> Result<(), AdapterError> {
    if self.initialized.load(Ordering::Acquire) {
        return Ok(());
    }

    let mut socket_config = transport::SocketBufferConfig::default();
    if let Some(recv) = self.config.socket_recv_buffer {
        socket_config.recv_buffer_size = recv;
    }
    if let Some(send) = self.config.socket_send_buffer {
        socket_config.send_buffer_size = send;
    }

    let socket = Socket::with_config(self.config.bind_addr, socket_config)
        .await
        .map_err(|e| AdapterError::Connection(format!("socket creation failed: {}", e)))?;
    let socket = Arc::new(socket);

    let (keys, actual_peer) = self.perform_handshake(&socket).await?;

    // Publish the socket only once the handshake has produced a session.
    self.socket = Some(socket.clone());
    let session = Arc::new(NetSession::new(
        keys,
        actual_peer,
        self.config.packet_pool_size,
        self.config.default_reliability.is_reliable(),
    ));
    self.session = Some(session.clone());
    self.session_manager.set_session_arc(session.clone());

    let recv_task = Self::spawn_receiver(
        self.shutdown.clone(),
        self.shutdown_notify.clone(),
        socket.clone(),
        session.clone(),
        self.inbound.clone(),
        self.config.num_shards,
    );
    let heartbeat_task = Self::spawn_heartbeat(
        self.shutdown.clone(),
        self.shutdown_notify.clone(),
        socket,
        session,
        self.config.heartbeat_interval,
        actual_peer,
    );
    {
        let mut tasks = self.tasks.lock().await;
        tasks.push(recv_task);
        tasks.push(heartbeat_task);
    }

    self.initialized.store(true, Ordering::Release);
    tracing::info!(
        bind_addr = %self.config.bind_addr,
        peer_addr = %self.config.peer_addr,
        role = ?self.config.role,
        "Net adapter initialized"
    );
    Ok(())
}
/// Sends every event of `batch` to the peer, packing events into as few
/// packets as fit under `protocol::MAX_PAYLOAD_SIZE`.
///
/// The batch's shard id is used as the stream id. When the stream's
/// reliability state asks for acknowledgements, packets are flagged
/// `RELIABLE` and a retransmit descriptor is registered after each send.
///
/// # Errors
///
/// Returns `AdapterError::Connection` when the adapter has no session or
/// socket yet, or when a send fails. Packets sent before the failure are not
/// rolled back.
async fn on_batch(&self, batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
    let session = self
        .session
        .as_ref()
        .ok_or_else(|| AdapterError::Connection("not connected".into()))?;
    let socket = self
        .socket
        .as_ref()
        .ok_or_else(|| AdapterError::Connection("socket not initialized".into()))?;
    let stream_id = batch.shard_id as u64;
    let peer_addr = session.peer_addr();
    let reliable = {
        let stream = session.get_or_create_stream(stream_id);
        stream.with_reliability(|r| r.needs_ack())
    };
    // The flags depend only on `reliable`, so compute them once for all packets.
    let flags = if reliable {
        PacketFlags::RELIABLE
    } else {
        PacketFlags::NONE
    };

    let mut current_batch: Vec<Bytes> = Vec::with_capacity(64);
    let mut current_size = 0usize;
    let pool = session.thread_local_pool();
    let mut builder = pool.get();

    for event in &batch.events {
        let event_bytes = event.raw.clone();
        let frame_size = EventFrame::LEN_SIZE + event_bytes.len();
        // Flush the pending packet first if this event would overflow it.
        if current_size + frame_size > protocol::MAX_PAYLOAD_SIZE && !current_batch.is_empty() {
            let seq = {
                let stream = session.get_or_create_stream(stream_id);
                stream.next_tx_seq()
            };
            let packet = builder.build(stream_id, seq, &current_batch, flags);
            socket
                .send_to(&packet, peer_addr)
                .await
                .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
            if reliable {
                let descriptor = std::sync::Arc::new(reliability::RetransmitDescriptor {
                    seq,
                    stream_id,
                    events: current_batch.clone(),
                    flags,
                });
                let stream = session.get_or_create_stream(stream_id);
                stream.with_reliability(|r| r.on_send(descriptor));
            }
            current_batch.clear();
            current_size = 0;
        }
        current_batch.push(event_bytes);
        current_size += frame_size;
    }

    // Send whatever is left after the last event.
    if !current_batch.is_empty() {
        let seq = {
            let stream = session.get_or_create_stream(stream_id);
            stream.next_tx_seq()
        };
        let packet = builder.build(stream_id, seq, &current_batch, flags);
        socket
            .send_to(&packet, peer_addr)
            .await
            .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
        if reliable {
            let descriptor = std::sync::Arc::new(reliability::RetransmitDescriptor {
                seq,
                stream_id,
                events: current_batch.clone(),
                flags,
            });
            let stream = session.get_or_create_stream(stream_id);
            stream.with_reliability(|r| r.on_send(descriptor));
        }
    }

    session.touch();
    Ok(())
}
/// Pops up to `limit` received events from the inbound queue of `shard_id`.
///
/// When `from_id` is given, only events whose id compares greater than it
/// (see `event_id_gt`) are returned. Popped events that fail this filter are
/// discarded, not re-queued.
///
/// The result carries the returned events, the id of the last one as
/// `next_id`, and `has_more` telling whether the queue was non-empty after
/// polling. An unknown shard yields an empty result.
///
/// The result buffer is pre-sized to at most the current queue length, so a
/// very large `limit` cannot trigger an oversized allocation.
async fn poll_shard(
    &self,
    shard_id: u16,
    from_id: Option<&str>,
    limit: usize,
) -> Result<ShardPollResult, AdapterError> {
    let mut events = Vec::new();
    let mut has_more = false;

    if let Some(queue) = self.inbound.get(&shard_id) {
        events.reserve(limit.min(queue.len()));
        while events.len() < limit {
            match queue.pop() {
                Some(event) => {
                    let past_cursor =
                        from_id.map_or(true, |cursor| event_id_gt(&event.id, cursor));
                    if past_cursor {
                        events.push(event);
                    }
                }
                None => break,
            }
        }
        has_more = !queue.is_empty();
    }

    let next_id = events.last().map(|e| e.id.clone());
    Ok(ShardPollResult {
        events,
        next_id,
        has_more,
    })
}
/// No-op that always returns `Ok(())`; `on_batch` already awaits every send
/// it issues.
async fn flush(&self) -> Result<(), AdapterError> {
Ok(())
}
async fn shutdown(&self) -> Result<(), AdapterError> {
self.shutdown.store(true, Ordering::Release);
self.shutdown_notify.notify_waiters();
self.session_manager.clear_session();
let mut tasks = self.tasks.lock().await;
for task in tasks.drain(..) {
let _ = task.await;
}
self.initialized.store(false, Ordering::Release);
tracing::info!("Net adapter shutdown complete");
Ok(())
}
/// Returns the fixed adapter name, `"net"`.
fn name(&self) -> &'static str {
"net"
}
/// Reports `true` only when the adapter is currently initialised and the
/// session manager's `check_session()` also returns `true`.
async fn is_healthy(&self) -> bool {
self.initialized.load(Ordering::Acquire) && self.session_manager.check_session()
}
}
/// Debug output shows only the configuration and the initialised flag.
impl std::fmt::Debug for NetAdapter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let initialized = self.initialized.load(Ordering::Relaxed);
        let mut out = f.debug_struct("NetAdapter");
        out.field("config", &self.config);
        out.field("initialized", &initialized);
        out.finish()
    }
}
/// Returns `true` when event id `a` orders strictly after event id `b`.
///
/// Ids of the form `"<seq>:<index>"` with both parts parsing as `u64` are
/// compared numerically as `(seq, index)` pairs. If either id does not have
/// that form, the two ids are compared as plain strings instead.
fn event_id_gt(a: &str, b: &str) -> bool {
    /// Splits at the first `:` and parses both halves as `u64`.
    fn numeric_parts(id: &str) -> Option<(u64, u64)> {
        let mut halves = id.splitn(2, ':');
        let seq = halves.next()?.parse::<u64>().ok()?;
        let idx = halves.next()?.parse::<u64>().ok()?;
        Some((seq, idx))
    }

    if let (Some(lhs), Some(rhs)) = (numeric_parts(a), numeric_parts(b)) {
        return lhs > rhs;
    }
    // At least one id is not numeric: fall back to lexicographic order.
    a > b
}
#[cfg(test)]
mod tests {
use super::*;
/// A valid initiator config yields an adapter that reports the name "net".
#[test]
fn test_adapter_creation() {
    let bind_addr = "127.0.0.1:0".parse().unwrap();
    let peer_addr = "127.0.0.1:9999".parse().unwrap();
    let config = NetAdapterConfig::initiator(bind_addr, peer_addr, [0x42u8; 32], [0x24u8; 32]);

    let adapter = NetAdapter::new(config).unwrap();

    assert_eq!(adapter.name(), "net");
}
/// Shard selection must reduce the stream id modulo the shard count rather
/// than truncating it to 16 bits.
#[test]
fn test_shard_id_from_stream_id_uses_modulo() {
    let num_shards: u16 = 8;
    let shard_of = |stream: u64| (stream % num_shards as u64) as u16;

    for stream in [0xDEAD_BEEF_0000_0003u64, 0xCAFE_BABE_0000_0003u64] {
        assert!(
            shard_of(stream) < num_shards,
            "shard must be in range [0, num_shards)"
        );
    }

    // For an all-ones id, truncation gives 0xFFFF while modulo gives 7.
    let big_stream: u64 = 0xFFFF_FFFF_FFFF_FFFF;
    let shard_big = shard_of(big_stream);
    assert!(shard_big < num_shards);
    assert_ne!(
        big_stream as u16, shard_big,
        "modulo must differ from truncation for large stream IDs"
    );
}
/// An initiator config without the peer's static public key must be rejected.
#[test]
fn test_invalid_config() {
    let bind_addr = "127.0.0.1:0".parse().unwrap();
    let peer_addr = "127.0.0.1:9999".parse().unwrap();
    let mut config =
        NetAdapterConfig::initiator(bind_addr, peer_addr, [0x42u8; 32], [0x24u8; 32]);
    config.peer_static_pubkey = None;

    assert!(NetAdapter::new(config).is_err());
}
/// Well-formed "<seq>:<index>" ids compare numerically, not as strings.
#[test]
fn test_event_id_gt_numeric_ordering() {
    // Pairs where the left id must order strictly after the right id.
    let greater = [
        ("2:0", "1:0"),
        ("10:0", "9:0"),
        ("100:0", "99:0"),
        ("5:2", "5:1"),
        ("1000000:0", "999999:0"),
    ];
    for (a, b) in greater {
        assert!(event_id_gt(a, b));
    }

    // Pairs where the left id must NOT order after the right id.
    let not_greater = [
        ("1:0", "2:0"),
        ("1:0", "1:0"),
        ("9:0", "10:0"),
        ("5:1", "5:2"),
    ];
    for (a, b) in not_greater {
        assert!(!event_id_gt(a, b));
    }
}
/// Ids that are not "<seq>:<index>" fall back to plain string comparison.
#[test]
fn test_event_id_gt_edge_cases() {
    let after_empty = event_id_gt("1:0", "");
    let b_after_a = event_id_gt("b", "a");
    let a_after_b = event_id_gt("a", "b");

    assert!(after_empty);
    assert!(b_after_a);
    assert!(!a_after_b);
}
/// A packet built with the initiator's keys is decrypted by the responder's
/// session and its events arrive in order on shard 0.
#[test]
fn test_build_then_process_packet_roundtrip() {
    use dashmap::DashMap;
    use std::sync::Arc;

    // Same in-memory handshake as the inline original, via the shared helper.
    let (init_keys, resp_keys) = make_session_keys();

    let payloads = vec![
        Bytes::from(r#"{"token":"hello"}"#),
        Bytes::from(r#"{"token":"world"}"#),
    ];
    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let packet = builder.build(0, 0, &payloads, PacketFlags::NONE);

    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);

    let queue = inbound.get(&0).expect("shard 0 should have events");
    assert_eq!(queue.len(), 2, "expected 2 events, got {}", queue.len());
    let first = queue.pop().unwrap();
    assert_eq!(&first.raw[..], br#"{"token":"hello"}"#);
    let second = queue.pop().unwrap();
    assert_eq!(&second.raw[..], br#"{"token":"world"}"#);
}
/// Runs a complete two-message handshake in memory and returns the resulting
/// `(initiator keys, responder keys)` pair.
fn make_session_keys() -> (SessionKeys, SessionKeys) {
    use crate::adapter::net::crypto::{NoiseHandshake, StaticKeypair};

    let psk = [0x42u8; 32];
    let responder_static = StaticKeypair::generate();
    let mut init_side = NoiseHandshake::initiator(&psk, &responder_static.public).unwrap();
    let mut resp_side = NoiseHandshake::responder(&psk, &responder_static).unwrap();

    // Message 1: initiator -> responder.
    let first = init_side.write_message(&[]).unwrap();
    resp_side.read_message(&first).unwrap();
    // Message 2: responder -> initiator.
    let second = resp_side.write_message(&[]).unwrap();
    init_side.read_message(&second).unwrap();

    let init_keys = init_side.into_session_keys().unwrap();
    let resp_keys = resp_side.into_session_keys().unwrap();
    (init_keys, resp_keys)
}
/// A packet with its last ten bytes cut off must not produce any events.
#[test]
fn test_process_packet_rejects_truncated_packet() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let (init_keys, resp_keys) = make_session_keys();
    let resp_session = Arc::new(NetSession::new(resp_keys, source, 4, false));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let intact = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);
    let cut_len = intact.len() - 10;
    let truncated = intact.slice(..cut_len);

    NetAdapter::process_packet(truncated, source, &resp_session, &inbound, 1);

    let shard_empty = inbound.get(&0).map_or(true, |queue| queue.is_empty());
    assert!(shard_empty, "truncated packet must be silently dropped");
}
/// Flipping a byte inside the encrypted payload must make the packet fail
/// authentication and produce no events.
#[test]
fn test_process_packet_rejects_tampered_payload() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let (init_keys, resp_keys) = make_session_keys();
    let resp_session = Arc::new(NetSession::new(resp_keys, source, 4, false));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let intact = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);

    // Corrupt the third byte after the header.
    let mut corrupted = bytes::BytesMut::from(&intact[..]);
    let victim = super::protocol::HEADER_SIZE + 2;
    corrupted[victim] ^= 0xFF;

    NetAdapter::process_packet(corrupted.freeze(), source, &resp_session, &inbound, 1);

    let shard_empty = inbound.get(&0).map_or(true, |queue| queue.is_empty());
    assert!(shard_empty, "tampered packet must be rejected by AEAD");
}
/// A packet whose session id differs from the receiving session's id must be
/// dropped without producing events.
#[test]
fn test_process_packet_rejects_wrong_session_id() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let (init_keys, resp_keys) = make_session_keys();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);

    // Receiver keeps the real keys but is given a different session id.
    let mut mismatched_keys = resp_keys;
    mismatched_keys.session_id = 0xDEAD;
    let resp_session = Arc::new(NetSession::new(mismatched_keys, source, 4, false));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);

    let shard_empty = inbound.get(&0).map_or(true, |queue| queue.is_empty());
    assert!(shard_empty, "packet with wrong session_id must be dropped");
}
/// Splits 200 padded events across several packets, using the same size rule
/// as `on_batch`, and checks that every event reaches shard 0.
#[test]
fn test_process_packet_multi_packet_batch_all_events_arrive() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();
    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);

    let total_events = 200;
    let mut seq = 0u64;
    let mut current_batch: Vec<Bytes> = Vec::new();
    let mut current_size = 0;
    for i in 0..total_events {
        let data = format!("{{\"i\":{},\"pad\":\"{}\"}}", i, "x".repeat(150));
        let event_bytes = Bytes::from(data);
        let frame_size = EventFrame::LEN_SIZE + event_bytes.len();
        // Flush the pending packet when the next event would not fit.
        if current_size + frame_size > protocol::MAX_PAYLOAD_SIZE && !current_batch.is_empty() {
            let packet = builder.build(0, seq, &current_batch, PacketFlags::NONE);
            NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
            seq += 1;
            current_batch.clear();
            current_size = 0;
        }
        current_batch.push(event_bytes);
        current_size += frame_size;
    }
    // Deliver the final, partially filled packet.
    if !current_batch.is_empty() {
        let packet = builder.build(0, seq, &current_batch, PacketFlags::NONE);
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    }

    let queue = inbound.get(&0).expect("shard 0 should have events");
    assert_eq!(
        queue.len(),
        total_events,
        "all {} events must arrive across multiple packets",
        total_events
    );
}
/// Packets must decrypt in both directions: initiator to responder and
/// responder to initiator.
#[test]
fn test_build_then_process_packet_both_directions() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    // Initiator -> responder.
    {
        let mut sender = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
        let packet = sender.build(0, 0, &[Bytes::from_static(b"i2r")], PacketFlags::NONE);
        let receiver = Arc::new(NetSession::new(resp_keys.clone(), source, 4, false));
        let inbound: InboundQueues = Arc::new(DashMap::new());

        NetAdapter::process_packet(packet, source, &receiver, &inbound, 1);

        let queue = inbound.get(&0).expect("i2r: shard 0 should have events");
        assert_eq!(queue.len(), 1, "i2r: expected 1 event");
        let delivered = queue.pop().unwrap();
        assert_eq!(&delivered.raw[..], b"i2r");
    }

    // Responder -> initiator.
    {
        let mut sender = PacketBuilder::new(&resp_keys.tx_key, resp_keys.session_id);
        let packet = sender.build(0, 0, &[Bytes::from_static(b"r2i")], PacketFlags::NONE);
        let receiver = Arc::new(NetSession::new(init_keys.clone(), source, 4, false));
        let inbound: InboundQueues = Arc::new(DashMap::new());

        NetAdapter::process_packet(packet, source, &receiver, &inbound, 1);

        let queue = inbound.get(&0).expect("r2i: shard 0 should have events");
        assert_eq!(queue.len(), 1, "r2i: expected 1 event");
        let delivered = queue.pop().unwrap();
        assert_eq!(&delivered.raw[..], b"r2i");
    }
}
/// Mirrors the `poll_shard` loop: with cursor "0:0", the event at the cursor
/// is popped and discarded, the two later events are returned, and the queue
/// ends up empty.
#[test]
fn test_poll_shard_cursor_drops_consumed_events() {
    use std::sync::Arc;

    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let (init_keys, resp_keys) = make_session_keys();
    let resp_session = Arc::new(NetSession::new(resp_keys, source, 4, false));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    // Deliver three single-event packets with sequence numbers 0, 1 and 2.
    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    for seq in 0..3u64 {
        let body = vec![Bytes::from(format!("event-{}", seq))];
        let packet = builder.build(0, seq, &body, PacketFlags::NONE);
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    }

    let queue = inbound.get(&0u16).unwrap();
    assert_eq!(queue.len(), 3);

    let cursor = "0:0";
    let mut polled = Vec::new();
    while polled.len() < 10 {
        match queue.pop() {
            Some(event) => {
                if event_id_gt(&event.id, cursor) {
                    polled.push(event);
                }
            }
            None => break,
        }
    }

    assert_eq!(polled.len(), 2, "should get 2 events after cursor 0:0");
    assert_eq!(polled[0].id, "1:0");
    assert_eq!(polled[1].id, "2:0");
    assert_eq!(queue.len(), 0, "queue should be empty after poll drains it");
}
/// Feeds 1100 consecutive packets through `NetAdapter::process_packet`, then
/// submits one more packet produced by a brand-new `PacketBuilder` over the
/// same key and session id, and checks that the shard-0 queue did not grow.
#[test]
fn test_process_packet_old_counter_rejected() {
    use std::sync::Arc;

    // Number of packets delivered before the stale one is injected.
    const ACCEPTED: usize = 1100;

    let (init_keys, resp_keys) = make_session_keys();
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let resp_session = Arc::new(NetSession::new(resp_keys, source, 4, false));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    (0..ACCEPTED as u64).for_each(|seq| {
        let packet = builder.build(0, seq, &[Bytes::from_static(b"x")], PacketFlags::NONE);
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    });
    assert_eq!(inbound.get(&0).unwrap().len(), ACCEPTED);

    // NOTE(review): a fresh builder presumably restarts the sender-side
    // counter, which is what makes this packet "stale" from the receiver's
    // point of view — confirm against `PacketBuilder`.
    let stale_packet = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id).build(
        0,
        9999,
        &[Bytes::from_static(b"stale")],
        PacketFlags::NONE,
    );
    NetAdapter::process_packet(stale_packet, source, &resp_session, &inbound, 1);

    let queued_after_stale = inbound.get(&0).unwrap().len();
    assert_eq!(
        queued_after_stale, ACCEPTED,
        "packet with stale counter must be rejected"
    );
}
/// Queries the receive cipher of a freshly created session directly:
/// counter `u64::MAX` must be reported as invalid, counter `0` as valid.
#[test]
fn test_process_packet_far_future_counter_rejected() {
    use std::sync::Arc;

    // Only the responder half of the key pair is used by this test.
    let (_unused_init_keys, resp_keys) = make_session_keys();
    let peer = "127.0.0.1:5000".parse().unwrap();
    let resp_session = Arc::new(NetSession::new(resp_keys, peer, 4, false));
    let rx_cipher = resp_session.rx_cipher();

    let accepts_max = rx_cipher.is_valid_rx_counter(u64::MAX);
    assert!(
        !accepts_max,
        "counter at u64::MAX must be rejected (far beyond MAX_FORWARD)"
    );

    let accepts_zero = rx_cipher.is_valid_rx_counter(0);
    assert!(accepts_zero, "counter 0 should be valid initially");
}
/// Delivers three packets on stream 7 with sequence numbers 0, 1 and 0 again,
/// and checks that only the first two payloads end up in the shard-0 queue,
/// in order; the repeated sequence number must not enqueue anything.
#[test]
fn process_packet_drops_duplicates_per_reliability_decision() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    // The last argument is `true` here, unlike the neighbouring tests, which
    // pass `false`. NOTE(review): presumably this turns on reliability
    // tracking for the session — confirm against `NetSession::new`.
    let resp_session = Arc::new(NetSession::new(resp_keys, source, 4, true));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    // All three packets are built up front, before any of them is processed.
    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let packets = [
        builder.build(7, 0, &[Bytes::from(r#"{"first":0}"#)], PacketFlags::NONE),
        builder.build(7, 1, &[Bytes::from(r#"{"first":1}"#)], PacketFlags::NONE),
        // Same sequence number as the first packet, different payload.
        builder.build(
            7,
            0,
            &[Bytes::from(r#"{"dup":"should_not_appear"}"#)],
            PacketFlags::NONE,
        ),
    ];
    for packet in packets {
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    }

    let queue = inbound.get(&0).expect("shard 0 should exist");
    let enqueued = queue.len();
    assert_eq!(
        enqueued, 2,
        "duplicate packet must NOT enqueue (BUG_REPORT.md #5); \
         got {} events, expected exactly 2 (seq=0 and seq=1, no dup)",
        enqueued
    );

    let expected_payloads = [&br#"{"first":0}"#[..], &br#"{"first":1}"#[..]];
    for expected in expected_payloads {
        let event = queue.pop().unwrap();
        assert_eq!(&event.raw[..], expected);
    }
    assert!(queue.is_empty());
}
/// Checks that only a heartbeat produced by `PacketBuilder::build_heartbeat`
/// advances the session's `last_activity_ns`, while two hand-crafted
/// heartbeats (a bare header, and a header followed by sixteen `0xAA` bytes)
/// leave it unchanged.
#[test]
fn heartbeat_is_aead_authenticated() {
    use crate::adapter::net::pool::PacketBuilder;
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let resp_session = Arc::new(NetSession::new(resp_keys, source, 4, false));
    let inbound: InboundQueues = Arc::new(DashMap::new());

    // Processes one packet and returns the session's activity timestamp
    // sampled before and after. The short sleep separates the two samples so
    // that a refreshed timestamp is strictly greater than the earlier one.
    let deliver = |packet| {
        let before = resp_session.last_activity_ns();
        std::thread::sleep(std::time::Duration::from_millis(2));
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
        (before, resp_session.last_activity_ns())
    };

    // 1. Heartbeat built by the real sender-side builder.
    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let heartbeat = builder.build_heartbeat();
    let (before, after) = deliver(heartbeat);
    assert!(
        after > before,
        "legitimate AEAD-tagged heartbeat must call session.touch()"
    );

    // 2. Header bytes only, with nothing appended after them.
    let header = NetHeader::heartbeat(resp_session.session_id());
    let mut header_only = bytes::BytesMut::new();
    header_only.extend_from_slice(&header.to_bytes());
    let (before, after) = deliver(header_only.freeze());
    assert_eq!(
        before, after,
        "unauthenticated heartbeat (no AEAD tag) must NOT touch the session"
    );

    // 3. Patched header followed by a 16-byte filler in place of a tag.
    let mut header_bytes = NetHeader::heartbeat(resp_session.session_id()).to_bytes();
    // NOTE(review): raw header offsets — bytes 12..16 are zeroed and bytes
    // 16..24 receive little-endian 1; which header fields these are is not
    // visible here, confirm against the `NetHeader` layout.
    header_bytes[12..16].copy_from_slice(&[0u8; 4]);
    header_bytes[16..24].copy_from_slice(&1u64.to_le_bytes());
    let mut garbage_tagged = bytes::BytesMut::new();
    garbage_tagged.extend_from_slice(&header_bytes);
    garbage_tagged.extend_from_slice(&[0xAAu8; 16]);
    let (before, after) = deliver(garbage_tagged.freeze());
    assert_eq!(
        before, after,
        "heartbeat with garbage AEAD tag must NOT touch the session"
    );
}
/// Exercises `HandshakePacer` with a budget of 3 per 50 ms window: one source
/// uses up its budget and is then refused, a second source is still admitted,
/// and after sleeping past the window the first source is admitted again.
#[test]
fn handshake_pacer_rejects_floods_per_source() {
    use std::time::Duration;

    let budget = 3;
    let window = Duration::from_millis(50);
    let mut pacer = HandshakePacer::new(budget, window);

    let attacker: std::net::SocketAddr = "10.0.0.1:9000".parse().unwrap();
    let legit: std::net::SocketAddr = "10.0.0.2:9000".parse().unwrap();

    // Everything up to the budget is admitted.
    for _ in 0..budget {
        assert!(pacer.check_and_record(attacker));
    }

    // Anything beyond it, within the same window, is refused.
    let flood_attempts = 10;
    for _ in 0..flood_attempts {
        assert!(
            !pacer.check_and_record(attacker),
            "attacker exceeding budget must be dropped"
        );
    }

    // A different source address is expected to be unaffected by the flood.
    assert!(
        pacer.check_and_record(legit),
        "legitimate source must still get through despite attacker flood"
    );

    // Wait slightly longer than one window (55 ms) before retrying.
    std::thread::sleep(window + Duration::from_millis(5));
    assert!(
        pacer.check_and_record(attacker),
        "attacker budget must refill after window"
    );
}
}