1mod batch;
34pub mod behavior;
35#[cfg_attr(not(feature = "cortex"), allow(dead_code))]
42mod cancel_registry;
43pub mod channel;
44pub mod compute;
45mod config;
46pub mod contested;
47pub mod continuity;
48#[cfg(feature = "cortex")]
49pub mod cortex;
50mod crypto;
51mod failure;
52pub mod identity;
53mod mesh;
54#[cfg(feature = "dataforts")]
61pub mod dataforts;
62#[cfg(feature = "cortex")]
63pub mod mesh_rpc;
64#[cfg(feature = "cortex")]
65pub mod mesh_rpc_metrics;
66#[cfg(feature = "netdb")]
67pub mod netdb;
68mod pool;
69mod protocol;
70mod proxy;
71#[cfg(feature = "redex")]
72pub mod redex;
73mod reliability;
74mod reroute;
75mod route;
76mod router;
77mod session;
78pub mod state;
79mod stream;
80pub mod subnet;
81pub mod subprotocol;
82mod swarm;
83mod transport;
84#[cfg(feature = "nat-traversal")]
85pub mod traversal;
86
87#[cfg(target_os = "linux")]
88mod linux;
89
90pub use batch::AdaptiveBatcher;
91pub use channel::{
92 AckReason, AuthGuard, AuthVerdict, ChannelConfig, ChannelConfigRegistry, ChannelError,
93 ChannelHash, ChannelId, ChannelName, ChannelPublisher, ChannelRegistry, MembershipMsg,
94 OnFailure, PublishConfig, PublishReport, SubscriberRoster, Visibility,
95 SUBPROTOCOL_CHANNEL_MEMBERSHIP,
96};
97pub use compute::{
98 DaemonError, DaemonFactoryRegistry, DaemonHost, DaemonHostConfig, DaemonRegistry, DaemonStats,
99 FactoryEntry, MeshDaemon, MigrationError, MigrationMessage, MigrationOrchestrator,
100 MigrationPhase, MigrationSourceHandler, MigrationState, MigrationTargetHandler,
101 PlacementDecision, Scheduler, SchedulerError, SUBPROTOCOL_MIGRATION,
102};
103pub use config::{ConnectionRole, NetAdapterConfig, ReliabilityConfig};
104pub use contested::{
105 CorrelatedFailureConfig, CorrelatedFailureDetector, CorrelationVerdict, FailureCause,
106 PartitionDetector, PartitionPhase, PartitionRecord, ReconcileOutcome, Side,
107 SUBPROTOCOL_PARTITION,
108};
109pub use continuity::{
110 assess_continuity, CausalCone, Causality, ContinuityProof, ContinuityStatus, Discontinuity,
111 DiscontinuityReason, ForkRecord, HorizonDivergence, ObservationWindow, ProofError,
112 PropagationModel, SuperpositionPhase, SuperpositionState, SUBPROTOCOL_CONTINUITY,
113};
114#[cfg(feature = "cortex")]
115pub use cortex::{
116 CortexAdapter, CortexAdapterConfig, CortexAdapterError, EventEnvelope, EventMeta,
117 FoldErrorPolicy, IntoRedexPayload, StartPosition, EVENT_META_SIZE,
118};
119pub use crypto::{CryptoError, SessionKeys, StaticKeypair};
120pub use failure::{
121 CircuitBreaker, CircuitState, FailureDetector, FailureDetectorConfig, FailureStats,
122 LossSimulator, NodeStatus, RecoveryAction, RecoveryManager, RecoveryStats,
123};
124pub use identity::{
125 EntityError, EntityId, EntityKeypair, OriginStamp, PermissionToken, TokenCache, TokenError,
126 TokenScope,
127};
128pub use mesh::{MeshNode, MeshNodeConfig, PartitionFilter};
129#[cfg(feature = "netdb")]
130pub use netdb::{MemoriesFilter, NetDb, NetDbBuilder, NetDbError, NetDbSnapshot, TasksFilter};
131pub use pool::{PacketBuilder, PacketPool, SharedLocalPool, ThreadLocalPool};
137pub use protocol::{
138 EventFrame, NackPayload, NetHeader, PacketFlags, HEADER_SIZE, NONCE_SIZE, TAG_SIZE,
139};
140pub use proxy::{
141 ForwardResult, HopStats, MultiHopPacketBuilder, NetProxy, ProxyConfig, ProxyError, ProxyStats,
142};
143#[cfg(feature = "redex")]
144pub use redex::{
145 FsyncPolicy, IndexOp, IndexStart, OrderedAppender, Redex, RedexEntry, RedexError, RedexEvent,
146 RedexFile, RedexFileConfig, RedexFlags, RedexFold, RedexIndex, TypedRedexFile,
147};
148pub use reliability::{FireAndForget, ReliabilityMode, ReliableStream, RetransmitDescriptor};
149pub use reroute::ReroutePolicy;
150pub use route::{
151 AggregateStats, RouteEntry, RouteFlags, RoutingHeader, RoutingTable, SchedulerStreamStats,
152 ROUTING_HEADER_SIZE,
153};
154pub use router::{FairScheduler, NetRouter, RouteAction, RouterConfig, RouterError, RouterStats};
155pub use session::{NetSession, SessionManager, StreamState, TxAdmit, TxSlotGuard};
156pub use state::{
157 CausalChainBuilder, CausalEvent, CausalLink, ChainError, EntityLog, HorizonEncoder, LogError,
158 LogIndex, ObservedHorizon, SnapshotStore, StateSnapshot, CAUSAL_LINK_SIZE, SUBPROTOCOL_CAUSAL,
159 SUBPROTOCOL_SNAPSHOT,
160};
161pub use stream::{
162 CloseBehavior, Reliability, Stream, StreamConfig, StreamError, StreamStats,
163 DEFAULT_STREAM_WINDOW_BYTES,
164};
165pub use subnet::{DropReason, ForwardDecision, SubnetGateway, SubnetId, SubnetPolicy, SubnetRule};
166pub use subprotocol::{
167 negotiate, MigrationSubprotocolHandler, NegotiatedSet, OutboundMigrationMessage,
168 SubprotocolDescriptor, SubprotocolManifest, SubprotocolRegistry, SubprotocolVersion,
169 SUBPROTOCOL_NEGOTIATION,
170};
171pub use swarm::{
172 Capabilities, CapabilityAd, EdgeInfo, GraphStats, LocalGraph, NodeInfo, Pingwave,
173 MAX_GRAPH_NODES, MAX_SEEN_PINGWAVES, PINGWAVE_SIZE,
174};
175pub use transport::{NetSocket, PacketReceiver, PacketSender, ParsedPacket, SocketBufferConfig};
176
177use async_trait::async_trait;
178use bytes::Bytes;
179use crossbeam_queue::SegQueue;
180use dashmap::DashMap;
181use std::sync::atomic::{AtomicBool, Ordering};
182use std::sync::Arc;
183use tokio::sync::Mutex as TokioMutex;
184use tokio::sync::Notify;
185use tokio::task::JoinHandle;
186
187use crate::adapter::{Adapter, ShardPollResult};
188use crate::error::AdapterError;
189use crate::event::{Batch, StoredEvent};
190
191use crypto::NoiseHandshake;
192use session::SessionManager as SessionMgr;
193use transport::NetSocket as Socket;
194
195pub use routing::{route_to_shard, stream_id_from_bytes, stream_id_from_key};
197
/// Current wall-clock time in nanoseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; a nanosecond count too large for a
/// `u64` saturates to `u64::MAX`.
#[inline]
pub(crate) fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_nanos())
        .unwrap_or(0);
    u64::try_from(nanos).unwrap_or(u64::MAX)
}
220
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`. A microsecond count too large for a
/// `u64` saturates to `u64::MAX` instead of being silently truncated, matching
/// `current_timestamp`.
#[inline]
pub(crate) fn current_timestamp_micros() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| u64::try_from(d.as_micros()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
233
234mod routing {
238 use xxhash_rust::xxh3::xxh3_64;
239
240 #[inline]
244 pub fn stream_id_from_bytes(data: &[u8]) -> u64 {
245 xxh3_64(data)
246 }
247
248 #[inline]
252 pub fn stream_id_from_key(key: &str) -> u64 {
253 xxh3_64(key.as_bytes())
254 }
255
256 #[inline]
264 pub fn route_to_shard(data: &[u8], num_shards: u16) -> u16 {
265 assert!(num_shards > 0, "num_shards must be > 0");
266 (xxh3_64(data) % num_shards as u64) as u16
267 }
268
269 #[cfg(test)]
270 mod tests {
271 use super::*;
272
273 #[test]
274 fn test_stream_id_deterministic() {
275 let data = b"test event data";
276 let id1 = stream_id_from_bytes(data);
277 let id2 = stream_id_from_bytes(data);
278 assert_eq!(id1, id2);
279 }
280
281 #[test]
282 fn test_stream_id_different_for_different_data() {
283 let id1 = stream_id_from_bytes(b"event1");
284 let id2 = stream_id_from_bytes(b"event2");
285 assert_ne!(id1, id2);
286 }
287
288 #[test]
289 fn test_stream_id_from_key() {
290 let id = stream_id_from_key("user:12345");
291 assert_ne!(id, 0);
292 }
293
294 #[test]
295 fn test_route_to_shard_range() {
296 let num_shards = 16u16;
297 for i in 0..1000 {
298 let data = format!("event_{}", i);
299 let shard = route_to_shard(data.as_bytes(), num_shards);
300 assert!(shard < num_shards);
301 }
302 }
303
304 #[test]
305 #[should_panic(expected = "num_shards must be > 0")]
306 fn test_route_to_shard_zero_shards_panics() {
307 route_to_shard(b"test", 0);
310 }
311
312 #[test]
313 fn test_route_to_shard_distribution() {
314 let num_shards = 8u16;
315 let mut counts = [0u32; 8];
316
317 for i in 0..8000 {
318 let data = format!("event_{}", i);
319 let shard = route_to_shard(data.as_bytes(), num_shards);
320 counts[shard as usize] += 1;
321 }
322
323 let expected = 1000;
325 for count in counts {
326 assert!(count > expected / 2, "shard count {} too low", count);
327 assert!(count < expected * 2, "shard count {} too high", count);
328 }
329 }
330 }
331}
332
/// Per-shard inbound event queues keyed by shard id.
///
/// Shared via `Arc` between the receiver task, which pushes decrypted events in
/// `process_packet`, and `poll_shard`, which pops them.
type InboundQueues = Arc<DashMap<u16, SegQueue<StoredEvent>>>;
335
/// Per-source rate limiter for inbound handshake packets.
///
/// Each source address gets a fixed window of length `window` in which at most
/// `max_per_window` handshakes are admitted. Entries older than two windows are
/// swept out periodically (on a timer and when the map grows large).
pub(crate) struct HandshakePacer {
    /// Per-source `(count, window_start)`.
    entries: std::collections::HashMap<std::net::SocketAddr, (u32, std::time::Instant)>,
    /// Maximum handshakes admitted from one source per window.
    max_per_window: u32,
    /// Length of one rate-limit window.
    window: std::time::Duration,
    /// When the last GC sweep ran.
    last_gc: std::time::Instant,
    /// Map size at which a sweep runs even if the timer has not elapsed.
    gc_size_threshold: usize,
    /// Number of entries that survived the most recent sweep. A size-triggered
    /// sweep is deferred until the map has doubled past this, so a map full of
    /// still-live entries is not re-swept on every call.
    retained_after_gc: usize,
}

impl HandshakePacer {
    /// Creates a pacer admitting `max_per_window` handshakes per source per `window`.
    pub(crate) fn new(max_per_window: u32, window: std::time::Duration) -> Self {
        Self {
            entries: std::collections::HashMap::new(),
            max_per_window,
            window,
            last_gc: std::time::Instant::now(),
            gc_size_threshold: 4096,
            retained_after_gc: 0,
        }
    }

    /// Records a handshake attempt from `source` and returns whether it is admitted.
    ///
    /// A source's window restarts once more than `window` has elapsed since it
    /// began. Returns `false` once the source has made more than
    /// `max_per_window` attempts in the current window.
    pub(crate) fn check_and_record(&mut self, source: std::net::SocketAddr) -> bool {
        let now = std::time::Instant::now();

        let gc_due_by_time = now.duration_since(self.last_gc) >= self.window;
        // Without the doubling rule, a flood from many distinct (e.g. spoofed)
        // sources keeps every entry live, so the sweep removes nothing and the
        // O(n) `retain` would run again on every subsequent call.
        let size_trigger = self
            .gc_size_threshold
            .max(self.retained_after_gc.saturating_mul(2));
        let gc_due_by_size = self.entries.len() >= size_trigger;

        if gc_due_by_time || gc_due_by_size {
            // Keep two windows of history; anything older would be reset on its
            // next access anyway.
            let cutoff = self.window.saturating_mul(2);
            self.entries
                .retain(|_, (_, start)| now.duration_since(*start) < cutoff);
            self.last_gc = now;
            self.retained_after_gc = self.entries.len();
        }

        let entry = self.entries.entry(source).or_insert((0, now));
        if now.duration_since(entry.1) > self.window {
            // The source's window expired: start a fresh one.
            *entry = (0, now);
        }
        entry.0 = entry.0.saturating_add(1);
        entry.0 <= self.max_per_window
    }
}
410
/// [`Adapter`] implementation that exchanges event batches with a single peer
/// over an encrypted session established by a Noise handshake.
pub struct NetAdapter {
    /// Adapter configuration, validated in `new`.
    config: NetAdapterConfig,
    /// Transport socket; `None` until `init` creates it.
    socket: Option<Arc<Socket>>,
    /// Session with the peer; `None` until the handshake in `init` completes.
    session: Option<Arc<NetSession>>,
    /// Holds the active session; consulted by `is_healthy` and cleared by `shutdown`.
    session_manager: SessionMgr,
    /// Per-shard queues filled by the receiver task and drained by `poll_shard`.
    inbound: InboundQueues,
    /// Background receiver and heartbeat task handles, awaited by `shutdown`.
    tasks: TokioMutex<Vec<JoinHandle<()>>>,
    /// Set by `shutdown`; polled by the background tasks.
    shutdown: Arc<AtomicBool>,
    /// Used by `shutdown` to wake the background tasks.
    shutdown_notify: Arc<Notify>,
    /// Whether `init` has completed (reset to `false` by `shutdown`).
    initialized: AtomicBool,
    /// Per-source limiter applied to inbound handshake packets when acting as responder.
    handshake_pacer: parking_lot::Mutex<HandshakePacer>,
}
437
438impl NetAdapter {
439 pub fn new(config: NetAdapterConfig) -> Result<Self, AdapterError> {
441 config
442 .validate()
443 .map_err(|e| AdapterError::Fatal(format!("invalid config: {}", e)))?;
444
445 Ok(Self {
446 session_manager: SessionMgr::new(config.session_timeout),
447 config,
448 socket: None,
449 session: None,
450 inbound: Arc::new(DashMap::new()),
451 tasks: TokioMutex::new(Vec::new()),
452 shutdown: Arc::new(AtomicBool::new(false)),
453 shutdown_notify: Arc::new(Notify::new()),
454 initialized: AtomicBool::new(false),
455 handshake_pacer: parking_lot::Mutex::new(HandshakePacer::new(
459 5,
460 std::time::Duration::from_secs(1),
461 )),
462 })
463 }
464
465 async fn perform_handshake(
468 &self,
469 socket: &Socket,
470 ) -> Result<(SessionKeys, std::net::SocketAddr), AdapterError> {
471 let mut attempt = 0;
472 let max_attempts = self.config.handshake_retries;
473
474 const HANDSHAKE_RETRY_SLEEP_CAP_MS: u64 = 5_000;
482
483 loop {
484 attempt += 1;
485 match self.try_handshake(socket).await {
486 Ok(result) => return Ok(result),
487 Err(e) if attempt < max_attempts => {
488 tracing::warn!(
489 attempt = attempt,
490 max = max_attempts,
491 error = %e,
492 "handshake failed, retrying"
493 );
494 let backoff_ms =
495 (100u64.saturating_mul(attempt as u64)).min(HANDSHAKE_RETRY_SLEEP_CAP_MS);
496 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
497 }
498 Err(e) => return Err(e),
499 }
500 }
501 }
502
503 async fn try_handshake(
506 &self,
507 socket: &Socket,
508 ) -> Result<(SessionKeys, std::net::SocketAddr), AdapterError> {
509 let timeout = self.config.handshake_timeout;
510 let socket_arc = socket.socket_arc();
511
512 if self.config.is_initiator() {
513 let peer_pubkey = self
515 .config
516 .peer_static_pubkey
517 .as_ref()
518 .ok_or_else(|| AdapterError::Fatal("missing peer public key".into()))?;
519
520 let mut handshake = NoiseHandshake::initiator(&self.config.psk, peer_pubkey)
521 .map_err(|e| AdapterError::Fatal(format!("handshake init failed: {}", e)))?;
522
523 let msg1 = handshake
525 .write_message(&[])
526 .map_err(|e| AdapterError::Connection(format!("write_message failed: {}", e)))?;
527
528 let mut builder = PacketBuilder::new(&[0u8; 32], 0);
529 let packet = builder.build_handshake(&msg1);
530
531 socket
532 .send_to(&packet, self.config.peer_addr)
533 .await
534 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
535
536 let (parsed, _source) = tokio::time::timeout(timeout, async {
540 let mut recv_buf = [0u8; protocol::MAX_PACKET_SIZE];
553 loop {
554 let (n, source) = socket_arc
555 .recv_from(&mut recv_buf)
556 .await
557 .map_err(|e| AdapterError::Connection(format!("recv failed: {}", e)))?;
558
559 if source != self.config.peer_addr {
561 continue;
562 }
563
564 let data = bytes::Bytes::copy_from_slice(&recv_buf[..n]);
565
566 if let Some(p) = ParsedPacket::parse(data, source) {
567 if p.header.flags.is_handshake() {
568 return Ok::<_, AdapterError>((p, source));
569 }
570 }
571 }
573 })
574 .await
575 .map_err(|_| AdapterError::Connection("handshake timeout".into()))??;
576
577 handshake
579 .read_message(&parsed.payload)
580 .map_err(|e| AdapterError::Connection(format!("read_message failed: {}", e)))?;
581
582 let keys = handshake
584 .into_session_keys()
585 .map_err(|e| AdapterError::Fatal(format!("key extraction failed: {}", e)))?;
586 Ok((keys, self.config.peer_addr))
587 } else {
588 let keypair = self
590 .config
591 .static_keypair
592 .as_ref()
593 .ok_or_else(|| AdapterError::Fatal("missing static keypair".into()))?;
594
595 let (parsed, source) = tokio::time::timeout(timeout, async {
602 loop {
603 let mut recv_buf = bytes::BytesMut::with_capacity(protocol::MAX_PACKET_SIZE);
604 recv_buf.resize(protocol::MAX_PACKET_SIZE, 0);
605
606 let (n, source) = socket_arc
607 .recv_from(&mut recv_buf)
608 .await
609 .map_err(|e| AdapterError::Connection(format!("recv failed: {}", e)))?;
610
611 recv_buf.truncate(n);
612 let data = recv_buf.freeze();
613
614 if let Some(p) = ParsedPacket::parse(data, source) {
615 if p.header.flags.is_handshake() {
616 let allowed = self.handshake_pacer.lock().check_and_record(source);
619 if !allowed {
620 tracing::debug!(
621 %source,
622 "handshake responder: dropping packet from \
623 rate-limited source"
624 );
625 continue;
626 }
627 return Ok::<_, AdapterError>((p, source));
628 }
629 }
630 }
632 })
633 .await
634 .map_err(|_| AdapterError::Connection("handshake timeout".into()))??;
635
636 let mut handshake = NoiseHandshake::responder(&self.config.psk, keypair)
637 .map_err(|e| AdapterError::Fatal(format!("handshake init failed: {}", e)))?;
638
639 handshake
641 .read_message(&parsed.payload)
642 .map_err(|e| AdapterError::Connection(format!("read_message failed: {}", e)))?;
643
644 let msg2 = handshake
646 .write_message(&[])
647 .map_err(|e| AdapterError::Connection(format!("write_message failed: {}", e)))?;
648
649 let mut builder = PacketBuilder::new(&[0u8; 32], 0);
650 let packet = builder.build_handshake(&msg2);
651
652 socket
655 .send_to(&packet, source)
656 .await
657 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
658
659 let keys = handshake
661 .into_session_keys()
662 .map_err(|e| AdapterError::Fatal(format!("key extraction failed: {}", e)))?;
663 Ok((keys, source))
664 }
665 }
666
667 fn process_packet(
669 data: Bytes,
670 source: std::net::SocketAddr,
671 session: &NetSession,
672 inbound: &InboundQueues,
673 num_shards: u16,
674 ) {
675 let mut parsed = match ParsedPacket::parse(data, source) {
677 Some(p) => p,
678 None => return,
679 };
680
681 if !parsed.header.flags.is_handshake()
685 && !parsed.header.flags.is_heartbeat()
686 && !parsed.is_valid_length()
687 {
688 return;
689 }
690
691 if parsed.header.flags.is_handshake() {
693 return;
694 }
695
696 if parsed.header.session_id != session.session_id() {
698 return;
699 }
700
701 if parsed.header.flags.is_heartbeat() {
718 if source == session.peer_addr() {
719 session.verify_and_touch_heartbeat(&parsed);
720 }
721 return;
722 }
723
724 let aad = parsed.header.aad();
729 let counter = u64::from_le_bytes(parsed.header.nonce[4..12].try_into().unwrap_or([0u8; 8]));
730 let rx_cipher = session.rx_cipher();
731 let payload = std::mem::take(&mut parsed.payload);
732 let decrypted = match rx_cipher.decrypt_to_bytes(counter, &aad, payload) {
740 Ok(d) => {
741 if !rx_cipher.try_admit_rx_counter(counter) {
742 return;
743 }
744 d
745 }
746 Err(_) => return,
747 };
748
749 let events = EventFrame::read_events(decrypted, parsed.header.event_count);
751
752 let stream_id = parsed.header.stream_id;
754 let shard_id = if num_shards > 0 {
755 (stream_id % num_shards as u64) as u16
756 } else {
757 0
758 };
759
760 let is_fresh = {
773 let stream = session.get_or_create_stream(stream_id);
774 let fresh = stream.with_reliability(|r| r.on_receive(parsed.header.sequence));
780 stream.update_rx_seq(parsed.header.sequence);
781 fresh
782 };
783
784 if is_fresh {
785 let queue = inbound.entry(shard_id).or_default();
787 let seq = parsed.header.sequence;
788 for (i, event_data) in events.into_iter().enumerate() {
789 use std::fmt::Write;
790 let mut event_id = String::with_capacity(24);
791 let _ = write!(event_id, "{}:{}", seq, i);
792 queue.push(StoredEvent::new(event_id, event_data, seq, shard_id));
793 }
794 } else {
795 tracing::debug!(
796 seq = parsed.header.sequence,
797 stream_id,
798 "Dropping duplicate packet"
799 );
800 }
801
802 session.touch();
803 }
804
805 #[cfg(target_os = "linux")]
810 fn spawn_receiver(
811 shutdown: Arc<AtomicBool>,
812 shutdown_notify: Arc<Notify>,
813 socket: Arc<Socket>,
814 session: Arc<NetSession>,
815 inbound: InboundQueues,
816 num_shards: u16,
817 ) -> JoinHandle<()> {
818 let mut receiver = transport::BatchedPacketReceiver::new(socket.socket_arc());
819
820 tokio::spawn(async move {
821 while !shutdown.load(Ordering::Acquire) {
822 tokio::select! {
823 result = receiver.recv() => {
824 match result {
825 Ok((data, source)) => {
826 Self::process_packet(data, source, &session, &inbound, num_shards);
827 }
828 Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
829 tracing::warn!("batch receiver thread exited, stopping receiver");
830 break;
831 }
832 Err(e) => {
833 if !shutdown.load(Ordering::Acquire) {
834 tracing::warn!(error = %e, "receive error");
835 }
836 }
837 }
838 }
839 _ = shutdown_notify.notified() => {
840 break;
841 }
842 }
843 }
844 })
845 }
846
847 #[cfg(not(target_os = "linux"))]
849 fn spawn_receiver(
850 shutdown: Arc<AtomicBool>,
851 shutdown_notify: Arc<Notify>,
852 socket: Arc<Socket>,
853 session: Arc<NetSession>,
854 inbound: InboundQueues,
855 num_shards: u16,
856 ) -> JoinHandle<()> {
857 tokio::spawn(async move {
858 let mut receiver = PacketReceiver::new(socket.socket_arc());
859
860 while !shutdown.load(Ordering::Acquire) {
861 tokio::select! {
865 result = receiver.recv() => {
866 match result {
867 Ok((data, source)) => {
868 Self::process_packet(data, source, &session, &inbound, num_shards);
869 }
870 Err(e) => {
871 if !shutdown.load(Ordering::Acquire) {
872 tracing::warn!(error = %e, "receive error");
873 }
874 }
875 }
876 }
877 _ = shutdown_notify.notified() => {
878 break;
879 }
880 }
881 }
882 })
883 }
884
885 fn spawn_heartbeat(
887 shutdown: Arc<AtomicBool>,
888 shutdown_notify: Arc<Notify>,
889 socket: Arc<Socket>,
890 session: Arc<NetSession>,
891 interval: std::time::Duration,
892 peer_addr: std::net::SocketAddr,
893 ) -> JoinHandle<()> {
894 tokio::spawn(async move {
895 let mut ticker = tokio::time::interval(interval);
896
897 loop {
898 tokio::select! {
899 _ = ticker.tick() => {
900 if shutdown.load(Ordering::Acquire) || !session.is_active() {
901 break;
902 }
903
904 let packet = session.build_heartbeat();
918
919 if let Err(e) = socket.send_to(&packet, peer_addr).await {
920 tracing::warn!(error = %e, "heartbeat send failed");
921 }
922 }
923 _ = shutdown_notify.notified() => {
924 break;
925 }
926 }
927 }
928 })
929 }
930}
931
932#[async_trait]
933impl Adapter for NetAdapter {
934 async fn init(&mut self) -> Result<(), AdapterError> {
935 if self.initialized.load(Ordering::Acquire) {
936 return Ok(());
937 }
938
939 let socket_config = match (
941 self.config.socket_recv_buffer,
942 self.config.socket_send_buffer,
943 ) {
944 (Some(recv), Some(send)) => transport::SocketBufferConfig {
945 recv_buffer_size: recv,
946 send_buffer_size: send,
947 },
948 _ => transport::SocketBufferConfig::default(),
949 };
950 let socket = Socket::with_config(self.config.bind_addr, socket_config)
951 .await
952 .map_err(|e| AdapterError::Connection(format!("socket creation failed: {}", e)))?;
953
954 let socket = Arc::new(socket);
955 self.socket = Some(socket.clone());
956
957 let (keys, actual_peer) = self.perform_handshake(&socket).await?;
959
960 let session = Arc::new(NetSession::new(
964 keys,
965 actual_peer,
966 self.config.packet_pool_size,
967 self.config.default_reliability.is_reliable(),
968 ));
969 self.session = Some(session.clone());
970
971 self.session_manager.set_session_arc(session.clone());
973
974 let recv_task = Self::spawn_receiver(
976 self.shutdown.clone(),
977 self.shutdown_notify.clone(),
978 socket.clone(),
979 session.clone(),
980 self.inbound.clone(),
981 self.config.num_shards,
982 );
983
984 let heartbeat_task = Self::spawn_heartbeat(
985 self.shutdown.clone(),
986 self.shutdown_notify.clone(),
987 socket,
988 session,
989 self.config.heartbeat_interval,
990 actual_peer,
991 );
992
993 {
994 let mut tasks = self.tasks.lock().await;
995 tasks.push(recv_task);
996 tasks.push(heartbeat_task);
997 }
998
999 self.initialized.store(true, Ordering::Release);
1000
1001 tracing::info!(
1002 bind_addr = %self.config.bind_addr,
1003 peer_addr = %self.config.peer_addr,
1004 role = ?self.config.role,
1005 "Net adapter initialized"
1006 );
1007
1008 Ok(())
1009 }
1010
1011 async fn on_batch(&self, batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
1012 let session = self
1013 .session
1014 .as_ref()
1015 .ok_or_else(|| AdapterError::Connection("not connected".into()))?;
1016
1017 let socket = self
1018 .socket
1019 .as_ref()
1020 .ok_or_else(|| AdapterError::Connection("socket not initialized".into()))?;
1021
1022 let stream_id = batch.shard_id as u64;
1023 let peer_addr = session.peer_addr();
1024
1025 let reliable = {
1029 let stream = session.get_or_create_stream(stream_id);
1030 stream.with_reliability(|r| r.needs_ack())
1031 };
1033
1034 let mut current_batch: Vec<Bytes> = Vec::with_capacity(64);
1036 let mut current_size = 0usize;
1037
1038 let pool = session.thread_local_pool();
1040 let mut builder = pool.get();
1041
1042 for event in &batch.events {
1043 let event_bytes = event.raw.clone();
1044 let frame_size = EventFrame::LEN_SIZE + event_bytes.len();
1045
1046 if current_size + frame_size > protocol::MAX_PAYLOAD_SIZE && !current_batch.is_empty() {
1048 let seq;
1050 {
1051 let stream = session.get_or_create_stream(stream_id);
1052 seq = stream.next_tx_seq();
1053 }
1054
1055 let flags = if reliable {
1056 PacketFlags::RELIABLE
1057 } else {
1058 PacketFlags::NONE
1059 };
1060
1061 let packet = builder.build(stream_id, seq, ¤t_batch, flags);
1062
1063 socket
1065 .send_to(&packet, peer_addr)
1066 .await
1067 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
1068
1069 if reliable {
1076 let descriptor = std::sync::Arc::new(reliability::RetransmitDescriptor {
1081 seq,
1082 stream_id,
1083 events: current_batch.clone(),
1084 flags,
1085 });
1086 let stream = session.get_or_create_stream(stream_id);
1087 stream.with_reliability(|r| r.on_send(descriptor));
1088 }
1089
1090 current_batch.clear();
1091 current_size = 0;
1092 }
1093
1094 current_batch.push(event_bytes);
1095 current_size += frame_size;
1096 }
1097
1098 if !current_batch.is_empty() {
1100 let seq;
1101 {
1102 let stream = session.get_or_create_stream(stream_id);
1103 seq = stream.next_tx_seq();
1104 }
1105
1106 let flags = if reliable {
1107 PacketFlags::RELIABLE
1108 } else {
1109 PacketFlags::NONE
1110 };
1111
1112 let packet = builder.build(stream_id, seq, ¤t_batch, flags);
1113
1114 socket
1115 .send_to(&packet, peer_addr)
1116 .await
1117 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
1118
1119 if reliable {
1120 let descriptor = std::sync::Arc::new(reliability::RetransmitDescriptor {
1122 seq,
1123 stream_id,
1124 events: current_batch.clone(),
1125 flags,
1126 });
1127 let stream = session.get_or_create_stream(stream_id);
1128 stream.with_reliability(|r| r.on_send(descriptor));
1129 }
1130 }
1131
1132 session.touch();
1133
1134 Ok(())
1135 }
1136
1137 async fn poll_shard(
1138 &self,
1139 shard_id: u16,
1140 from_id: Option<&str>,
1141 limit: usize,
1142 ) -> Result<ShardPollResult, AdapterError> {
1143 let mut events = Vec::with_capacity(limit);
1144
1145 if let Some(queue) = self.inbound.get(&shard_id) {
1146 while events.len() < limit {
1147 if let Some(event) = queue.pop() {
1148 if from_id.is_none() || event_id_gt(&event.id, from_id.unwrap_or("")) {
1149 events.push(event);
1150 }
1151 } else {
1156 break;
1157 }
1158 }
1159 }
1160
1161 let has_more = self
1162 .inbound
1163 .get(&shard_id)
1164 .map(|q| !q.is_empty())
1165 .unwrap_or(false);
1166 let next_id = events.last().map(|e| e.id.clone());
1167
1168 Ok(ShardPollResult {
1169 events,
1170 next_id,
1171 has_more,
1172 })
1173 }
1174
1175 async fn flush(&self) -> Result<(), AdapterError> {
1176 Ok(())
1179 }
1180
1181 async fn shutdown(&self) -> Result<(), AdapterError> {
1182 self.shutdown.store(true, Ordering::Release);
1183
1184 self.shutdown_notify.notify_waiters();
1187
1188 self.session_manager.clear_session();
1190
1191 let mut tasks = self.tasks.lock().await;
1193 for task in tasks.drain(..) {
1194 let _ = task.await;
1195 }
1196
1197 self.initialized.store(false, Ordering::Release);
1198
1199 tracing::info!("Net adapter shutdown complete");
1200
1201 Ok(())
1202 }
1203
1204 fn name(&self) -> &'static str {
1205 "net"
1206 }
1207
1208 async fn is_healthy(&self) -> bool {
1209 self.initialized.load(Ordering::Acquire) && self.session_manager.check_session()
1210 }
1211}
1212
1213impl std::fmt::Debug for NetAdapter {
1214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1215 f.debug_struct("NetAdapter")
1216 .field("config", &self.config)
1217 .field("initialized", &self.initialized.load(Ordering::Relaxed))
1218 .finish()
1219 }
1220}
1221
/// Returns `true` if event id `a` sorts strictly after event id `b`.
///
/// Ids shaped like `"<seq>:<index>"` (both parts parse as `u64`) are compared
/// numerically as `(seq, index)` pairs, so `"10:0"` sorts after `"9:0"`. If
/// either id does not have that shape, the raw strings are compared
/// lexicographically instead.
fn event_id_gt(a: &str, b: &str) -> bool {
    /// Splits `"<seq>:<index>"` at the first ':' and parses both halves.
    fn numeric_key(id: &str) -> Option<(u64, u64)> {
        let mut halves = id.splitn(2, ':');
        let seq = halves.next()?.parse::<u64>().ok()?;
        let index = halves.next()?.parse::<u64>().ok()?;
        Some((seq, index))
    }

    if let (Some(lhs), Some(rhs)) = (numeric_key(a), numeric_key(b)) {
        lhs > rhs
    } else {
        a > b
    }
}
1239
1240#[cfg(test)]
1241mod tests {
1242 use super::*;
1243
1244 #[test]
1245 fn test_adapter_creation() {
1246 let psk = [0x42u8; 32];
1247 let peer_pubkey = [0x24u8; 32];
1248
1249 let config = NetAdapterConfig::initiator(
1250 "127.0.0.1:0".parse().unwrap(),
1251 "127.0.0.1:9999".parse().unwrap(),
1252 psk,
1253 peer_pubkey,
1254 );
1255
1256 let adapter = NetAdapter::new(config).unwrap();
1257 assert_eq!(adapter.name(), "net");
1258 }
1259
1260 #[test]
1261 fn test_shard_id_from_stream_id_uses_modulo() {
1262 let num_shards: u16 = 8;
1266
1267 let stream_a: u64 = 0xDEAD_BEEF_0000_0003;
1272 let stream_b: u64 = 0xCAFE_BABE_0000_0003;
1273
1274 let shard_a = (stream_a % num_shards as u64) as u16;
1275 let shard_b = (stream_b % num_shards as u64) as u16;
1276
1277 assert!(
1278 shard_a < num_shards,
1279 "shard must be in range [0, num_shards)"
1280 );
1281 assert!(
1282 shard_b < num_shards,
1283 "shard must be in range [0, num_shards)"
1284 );
1285
1286 let big_stream: u64 = 0xFFFF_FFFF_FFFF_FFFF;
1288 let shard_big = (big_stream % num_shards as u64) as u16;
1289 assert!(shard_big < num_shards);
1290
1291 assert_ne!(
1294 big_stream as u16, shard_big,
1295 "modulo must differ from truncation for large stream IDs"
1296 );
1297 }
1298
1299 #[test]
1300 fn test_invalid_config() {
1301 let psk = [0x42u8; 32];
1302 let peer_pubkey = [0x24u8; 32];
1303
1304 let mut config = NetAdapterConfig::initiator(
1305 "127.0.0.1:0".parse().unwrap(),
1306 "127.0.0.1:9999".parse().unwrap(),
1307 psk,
1308 peer_pubkey,
1309 );
1310 config.peer_static_pubkey = None;
1311
1312 let result = NetAdapter::new(config);
1313 assert!(result.is_err());
1314 }
1315
1316 #[test]
1319 fn test_event_id_gt_numeric_ordering() {
1320 assert!(event_id_gt("2:0", "1:0"));
1322 assert!(!event_id_gt("1:0", "2:0"));
1323 assert!(!event_id_gt("1:0", "1:0"));
1324
1325 assert!(event_id_gt("10:0", "9:0"));
1327 assert!(event_id_gt("100:0", "99:0"));
1328 assert!(!event_id_gt("9:0", "10:0"));
1329
1330 assert!(event_id_gt("5:2", "5:1"));
1332 assert!(!event_id_gt("5:1", "5:2"));
1333
1334 assert!(event_id_gt("1000000:0", "999999:0"));
1336 }
1337
1338 #[test]
1344 fn test_event_id_gt_edge_cases() {
1345 assert!(event_id_gt("1:0", ""));
1347 assert!(event_id_gt("b", "a"));
1349 assert!(!event_id_gt("a", "b"));
1350 }
1351
1352 #[test]
1357 fn test_build_then_process_packet_roundtrip() {
1358 use crate::adapter::net::crypto::{NoiseHandshake, StaticKeypair};
1359 use dashmap::DashMap;
1360 use std::sync::Arc;
1361
1362 let psk = [0x42u8; 32];
1364 let responder_kp = StaticKeypair::generate();
1365
1366 let mut initiator = NoiseHandshake::initiator(&psk, &responder_kp.public).unwrap();
1367 let mut responder = NoiseHandshake::responder(&psk, &responder_kp).unwrap();
1368
1369 let msg1 = initiator.write_message(&[]).unwrap();
1370 responder.read_message(&msg1).unwrap();
1371 let msg2 = responder.write_message(&[]).unwrap();
1372 initiator.read_message(&msg2).unwrap();
1373
1374 let init_keys = initiator.into_session_keys().unwrap();
1375 let resp_keys = responder.into_session_keys().unwrap();
1376
1377 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1379 let events = vec![
1380 Bytes::from(r#"{"token":"hello"}"#),
1381 Bytes::from(r#"{"token":"world"}"#),
1382 ];
1383 let packet = builder.build(0, 0, &events, PacketFlags::NONE);
1384
1385 let resp_session = Arc::new(NetSession::new(
1387 resp_keys,
1388 "127.0.0.1:5000".parse().unwrap(),
1389 4,
1390 false,
1391 ));
1392 let inbound: InboundQueues = Arc::new(DashMap::new());
1393 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1394
1395 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1396
1397 let queue = inbound.get(&0).expect("shard 0 should have events");
1399 assert_eq!(queue.len(), 2, "expected 2 events, got {}", queue.len());
1400
1401 let e1 = queue.pop().unwrap();
1402 assert_eq!(&e1.raw[..], br#"{"token":"hello"}"#);
1403
1404 let e2 = queue.pop().unwrap();
1405 assert_eq!(&e2.raw[..], br#"{"token":"world"}"#);
1406 }
1407
1408 fn make_session_keys() -> (SessionKeys, SessionKeys) {
1410 use crate::adapter::net::crypto::{NoiseHandshake, StaticKeypair};
1411
1412 let psk = [0x42u8; 32];
1413 let responder_kp = StaticKeypair::generate();
1414
1415 let mut initiator = NoiseHandshake::initiator(&psk, &responder_kp.public).unwrap();
1416 let mut responder = NoiseHandshake::responder(&psk, &responder_kp).unwrap();
1417
1418 let msg1 = initiator.write_message(&[]).unwrap();
1419 responder.read_message(&msg1).unwrap();
1420 let msg2 = responder.write_message(&[]).unwrap();
1421 initiator.read_message(&msg2).unwrap();
1422
1423 (
1424 initiator.into_session_keys().unwrap(),
1425 responder.into_session_keys().unwrap(),
1426 )
1427 }
1428
/// A sealed packet with its last 10 bytes cut off must be dropped by
/// `process_packet`: no event may be enqueued on shard 0.
#[test]
fn test_process_packet_rejects_truncated_packet() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);

    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    // Chop the last 10 bytes off the wire image so the packet is incomplete.
    let truncated = packet.slice(..packet.len() - 10);
    NetAdapter::process_packet(truncated, source, &resp_session, &inbound, 1);

    // One map lookup (no second `get` while a guard is alive, no unwrap): an
    // absent shard counts as zero delivered events.
    let delivered = inbound.get(&0).map_or(0, |queue| queue.len());
    assert_eq!(delivered, 0, "truncated packet must be silently dropped");
}
1457
/// Flipping a single byte just past the packet header must cause
/// `process_packet` to reject the packet: nothing reaches shard 0.
#[test]
fn test_process_packet_rejects_tampered_payload() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);

    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    // Invert one byte two positions after the header. NOTE(review): presumably
    // this lands in the encrypted body, so AEAD verification should fail.
    let mut tampered = bytes::BytesMut::from(&packet[..]);
    tampered[super::protocol::HEADER_SIZE + 2] ^= 0xFF;
    NetAdapter::process_packet(tampered.freeze(), source, &resp_session, &inbound, 1);

    // Single lookup: a missing shard means zero delivered events.
    let delivered = inbound.get(&0).map_or(0, |queue| queue.len());
    assert_eq!(delivered, 0, "tampered packet must be rejected by AEAD");
}
1487
/// A packet built for the real session id must be dropped by a receiving
/// session whose id has been overwritten with `0xDEAD`.
#[test]
fn test_process_packet_rejects_wrong_session_id() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);

    // Same key material, but the receiver's session id no longer matches the
    // id the builder was constructed with.
    let mut wrong_keys = resp_keys;
    wrong_keys.session_id = 0xDEAD;
    let resp_session = Arc::new(NetSession::new(
        wrong_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);

    // Single lookup: a missing shard means zero delivered events.
    let delivered = inbound.get(&0).map_or(0, |queue| queue.len());
    assert_eq!(delivered, 0, "packet with wrong session_id must be dropped");
}
1517
/// Packs 200 ~170-byte events greedily into packets whose framed payload stays
/// within `protocol::MAX_PAYLOAD_SIZE`. Each packet is fed to `process_packet`
/// with an increasing sequence number, and the test checks that every event
/// lands on shard 0.
#[test]
fn test_process_packet_multi_packet_batch_all_events_arrive() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let total_events = 200;
    let mut seq = 0u64;

    let mut current_batch: Vec<Bytes> = Vec::new();
    // Bytes the current batch will occupy once framed (length prefix + body).
    let mut current_size = 0;

    for i in 0..total_events {
        let data = format!("{{\"i\":{},\"pad\":\"{}\"}}", i, "x".repeat(150));
        let event_bytes = Bytes::from(data);
        let frame_size = EventFrame::LEN_SIZE + event_bytes.len();

        // Flush before this event would overflow the payload budget. The
        // non-empty check guarantees every emitted packet carries an event.
        if current_size + frame_size > protocol::MAX_PAYLOAD_SIZE && !current_batch.is_empty() {
            let packet = builder.build(0, seq, &current_batch, PacketFlags::NONE);
            NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
            seq += 1;
            current_batch.clear();
            current_size = 0;
        }

        current_batch.push(event_bytes);
        current_size += frame_size;
    }

    // Flush the trailing partial batch.
    if !current_batch.is_empty() {
        let packet = builder.build(0, seq, &current_batch, PacketFlags::NONE);
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    }

    // Guard the premise of the test: at least one mid-loop flush happened, so
    // the events really were split across more than one packet.
    assert!(
        seq > 0,
        "test setup must split the events across more than one packet"
    );

    let queue = inbound.get(&0).expect("shard 0 should have events");
    assert_eq!(
        queue.len(),
        total_events,
        "all {} events must arrive across multiple packets",
        total_events
    );
}
1576
/// Verifies that both halves of a handshake can send to each other. A packet
/// sealed with the initiator's keys must be accepted by a session built from
/// the responder's keys, and the reverse direction must work as well.
#[test]
fn test_build_then_process_packet_both_directions() {
    use dashmap::DashMap;
    use std::sync::Arc;

    /// Seals a single-event packet with `sender`'s tx key and session id.
    /// Feeds it to `process_packet` on a fresh session built from `receiver`.
    /// Returns that session (kept alive by the caller) and its inbound queues.
    fn deliver_one(
        sender: &SessionKeys,
        receiver: SessionKeys,
        source: std::net::SocketAddr,
        payload: &'static [u8],
    ) -> (std::sync::Arc<NetSession>, InboundQueues) {
        let mut builder = PacketBuilder::new(&sender.tx_key, sender.session_id);
        let packet = builder.build(0, 0, &[Bytes::from_static(payload)], PacketFlags::NONE);

        let session = Arc::new(NetSession::new(receiver, source, 4, false));
        let inbound: InboundQueues = Arc::new(DashMap::new());
        NetAdapter::process_packet(packet, source, &session, &inbound, 1);
        (session, inbound)
    }

    let (init_keys, resp_keys) = make_session_keys();
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    // Initiator -> responder.
    {
        let (_session, inbound) = deliver_one(&init_keys, resp_keys.clone(), source, b"i2r");
        let queue = inbound.get(&0).expect("i2r: shard 0 should have events");
        assert_eq!(queue.len(), 1, "i2r: expected 1 event");
        assert_eq!(&queue.pop().unwrap().raw[..], b"i2r");
    }

    // Responder -> initiator.
    {
        let (_session, inbound) = deliver_one(&resp_keys, init_keys.clone(), source, b"r2i");
        let queue = inbound.get(&0).expect("r2i: shard 0 should have events");
        assert_eq!(queue.len(), 1, "r2i: expected 1 event");
        assert_eq!(&queue.pop().unwrap().raw[..], b"r2i");
    }
}
1613
/// Enqueues three single-event packets on shard 0, then runs a cursor-style
/// drain inline. It pops events, keeps only those whose id sorts after `"0:0"`
/// per `event_id_gt`, and stops after 10 kept events or when the queue runs
/// dry. Afterwards it checks that the kept ids are `1:0` and `2:0` and that the
/// skipped event was consumed as well, leaving the queue empty.
#[test]
fn test_poll_shard_cursor_drops_consumed_events() {
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    for seq in 0..3u64 {
        let body = vec![Bytes::from(format!("event-{}", seq))];
        let packet = builder.build(0, seq, &body, PacketFlags::NONE);
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    }

    let queue = inbound.get(&0u16).unwrap();
    assert_eq!(queue.len(), 3);

    let from_id = "0:0";
    let limit = 10;
    // `from_fn` pops lazily, and `take` stops pulling once `limit` events have
    // been kept. Filtered-out events are still popped, i.e. consumed.
    let events: Vec<_> = std::iter::from_fn(|| queue.pop())
        .filter(|event| event_id_gt(&event.id, from_id))
        .take(limit)
        .collect();

    assert_eq!(events.len(), 2, "should get 2 events after cursor 0:0");
    assert_eq!(events[0].id, "1:0");
    assert_eq!(events[1].id, "2:0");

    assert_eq!(queue.len(), 0, "queue should be empty after poll drains it");
}
1666
/// Accepts 1100 packets from one builder, then sends a packet from a second
/// builder on the same key. That packet must not be enqueued.
#[test]
fn test_process_packet_old_counter_rejected() {
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();
    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    // Current number of events queued on shard 0 (panics if the shard is absent).
    let shard0_len = || inbound.get(&0).unwrap().len();

    let mut warm_builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    (0..1100u64).for_each(|seq| {
        let packet = warm_builder.build(0, seq, &[Bytes::from_static(b"x")], PacketFlags::NONE);
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    });
    assert_eq!(shard0_len(), 1100);

    // NOTE(review): a brand-new builder presumably restarts its per-packet
    // counter, so this packet's counter trails the 1100 already accepted. The
    // stream sequence 9999 is not what makes it stale; confirm against
    // PacketBuilder.
    let mut stale_builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let stale_packet =
        stale_builder.build(0, 9999, &[Bytes::from_static(b"stale")], PacketFlags::NONE);
    NetAdapter::process_packet(stale_packet, source, &resp_session, &inbound, 1);

    assert_eq!(
        shard0_len(),
        1100,
        "packet with stale counter must be rejected"
    );
}
1706
/// Probes the receive cipher's counter check on a fresh session.
/// `u64::MAX` must be refused, while counter 0 must be accepted.
#[test]
fn test_process_packet_far_future_counter_rejected() {
    use std::sync::Arc;

    // Only the responder side is needed; the initiator keys just stay alive.
    let (_initiator_keys, resp_keys) = make_session_keys();

    let session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));

    let rx = session.rx_cipher();

    let max_counter_ok = rx.is_valid_rx_counter(u64::MAX);
    assert!(
        !max_counter_ok,
        "counter at u64::MAX must be rejected (far beyond MAX_FORWARD)"
    );

    let zero_counter_ok = rx.is_valid_rx_counter(0);
    assert!(zero_counter_ok, "counter 0 should be valid initially");
}
1738
/// Sends three packets on stream 7: seq 0, seq 1, then a second packet that
/// reuses seq 0 with a different body. Only the first two bodies may be
/// enqueued, in order.
#[test]
fn process_packet_drops_duplicates_per_reliability_decision() {
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    // NOTE(review): unlike the other tests this passes `true` as the fourth
    // argument, presumably enabling the reliability layer that makes the
    // duplicate-drop decision; confirm against NetSession::new.
    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        true,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let seq0 = builder.build(7, 0, &[Bytes::from(r#"{"first":0}"#)], PacketFlags::NONE);
    let seq1 = builder.build(7, 1, &[Bytes::from(r#"{"first":1}"#)], PacketFlags::NONE);
    let seq0_again = builder.build(
        7,
        0,
        &[Bytes::from(r#"{"dup":"should_not_appear"}"#)],
        PacketFlags::NONE,
    );

    // Deliver in order: original seq 0, seq 1, then the seq-0 duplicate.
    for packet in [seq0, seq1, seq0_again] {
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
    }

    let queue = inbound.get(&0).expect("shard 0 should exist");
    assert_eq!(
        queue.len(),
        2,
        "duplicate packet must NOT enqueue (BUG_REPORT.md #5); \
         got {} events, expected exactly 2 (seq=0 and seq=1, no dup)",
        queue.len()
    );

    let oldest = queue.pop().unwrap();
    assert_eq!(&oldest.raw[..], br#"{"first":0}"#);
    let newest = queue.pop().unwrap();
    assert_eq!(&newest.raw[..], br#"{"first":1}"#);
    assert!(queue.is_empty());
}
1813
/// Only a heartbeat produced by `PacketBuilder::build_heartbeat` may advance
/// the session's `last_activity_ns`. A bare heartbeat header and a header
/// followed by 16 bytes of `0xAA` must both leave it unchanged.
#[test]
fn heartbeat_is_aead_authenticated() {
    use crate::adapter::net::pool::PacketBuilder;
    use dashmap::DashMap;
    use std::sync::Arc;

    let (init_keys, resp_keys) = make_session_keys();

    let resp_session = Arc::new(NetSession::new(
        resp_keys,
        "127.0.0.1:5000".parse().unwrap(),
        4,
        false,
    ));
    let inbound: InboundQueues = Arc::new(DashMap::new());
    let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();

    // Snapshot last_activity_ns, wait 2 ms so that a touch() produces a later
    // timestamp, run `packet` through process_packet, and return the
    // (before, after) pair.
    let activity_around = |packet| {
        let before = resp_session.last_activity_ns();
        std::thread::sleep(std::time::Duration::from_millis(2));
        NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
        (before, resp_session.last_activity_ns())
    };

    // 1. Genuine heartbeat from the peer's builder.
    let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
    let heartbeat = builder.build_heartbeat();
    let (before, after) = activity_around(heartbeat);
    assert!(
        after > before,
        "legitimate AEAD-tagged heartbeat must call session.touch()"
    );

    // 2. Header only, with nothing after it.
    let mut header_only = bytes::BytesMut::new();
    header_only.extend_from_slice(&NetHeader::heartbeat(resp_session.session_id()).to_bytes());
    let (before, after) = activity_around(header_only.freeze());
    assert_eq!(
        before, after,
        "unauthenticated heartbeat (no AEAD tag) must NOT touch the session"
    );

    // 3. Patched header plus a 16-byte trailer of 0xAA.
    // NOTE(review): bytes 12..16 are zeroed and 16..24 carry the little-endian
    // value 1. This presumably fakes the header's length/counter fields so the
    // packet reaches AEAD verification; confirm against the NetHeader layout.
    let mut header_bytes = NetHeader::heartbeat(resp_session.session_id()).to_bytes();
    header_bytes[12..16].copy_from_slice(&[0u8; 4]);
    header_bytes[16..24].copy_from_slice(&1u64.to_le_bytes());
    let mut garbage_tagged = bytes::BytesMut::new();
    garbage_tagged.extend_from_slice(&header_bytes);
    garbage_tagged.extend_from_slice(&[0xAAu8; 16]);
    let (before, after) = activity_around(garbage_tagged.freeze());
    assert_eq!(
        before, after,
        "heartbeat with garbage AEAD tag must NOT touch the session"
    );
}
1889
/// `HandshakePacer` grants each source address a budget of 3 attempts per
/// 50 ms window. Excess attempts from one source are refused without
/// affecting other sources, and the budget is restored once the window has
/// passed.
#[test]
fn handshake_pacer_rejects_floods_per_source() {
    use std::time::Duration;

    let window = Duration::from_millis(50);
    let mut pacer = HandshakePacer::new(3, window);

    let attacker: std::net::SocketAddr = "10.0.0.1:9000".parse().unwrap();
    let legit: std::net::SocketAddr = "10.0.0.2:9000".parse().unwrap();

    // The full budget is admitted...
    for _ in 0..3 {
        assert!(pacer.check_and_record(attacker));
    }
    // ...and every further attempt inside the same window is refused.
    for _ in 0..10 {
        assert!(
            !pacer.check_and_record(attacker),
            "attacker exceeding budget must be dropped"
        );
    }

    // The flood must not consume another source's budget.
    assert!(
        pacer.check_and_record(legit),
        "legitimate source must still get through despite attacker flood"
    );

    // Sleep just past the window (55 ms total).
    std::thread::sleep(window + Duration::from_millis(5));
    assert!(
        pacer.check_and_record(attacker),
        "attacker budget must refill after window"
    );
}
1929}