1mod batch;
34pub mod behavior;
35mod cancel_registry;
41pub mod channel;
42pub mod compute;
43mod config;
44pub mod contested;
45pub mod continuity;
46#[cfg(feature = "cortex")]
47pub mod cortex;
48mod crypto;
49mod failure;
50pub mod identity;
51mod mesh;
52#[cfg(feature = "dataforts")]
59pub mod dataforts;
60#[cfg(feature = "cortex")]
61pub mod mesh_rpc;
62#[cfg(feature = "cortex")]
63pub mod mesh_rpc_metrics;
64#[cfg(feature = "netdb")]
65pub mod netdb;
66mod pool;
67mod protocol;
68mod proxy;
69#[cfg(feature = "redex")]
70pub mod redex;
71mod reliability;
72mod reroute;
73mod route;
74mod router;
75mod session;
76pub mod state;
77mod stream;
78pub mod subnet;
79pub mod subprotocol;
80mod swarm;
81mod transport;
82#[cfg(feature = "nat-traversal")]
83pub mod traversal;
84
85#[cfg(target_os = "linux")]
86mod linux;
87
88pub use batch::AdaptiveBatcher;
89pub use channel::{
90 AckReason, AuthGuard, AuthVerdict, ChannelConfig, ChannelConfigRegistry, ChannelError,
91 ChannelHash, ChannelId, ChannelName, ChannelPublisher, ChannelRegistry, MembershipMsg,
92 OnFailure, PublishConfig, PublishReport, SubscriberRoster, Visibility,
93 SUBPROTOCOL_CHANNEL_MEMBERSHIP,
94};
95pub use compute::{
96 DaemonError, DaemonFactoryRegistry, DaemonHost, DaemonHostConfig, DaemonRegistry, DaemonStats,
97 FactoryEntry, MeshDaemon, MigrationError, MigrationMessage, MigrationOrchestrator,
98 MigrationPhase, MigrationSourceHandler, MigrationState, MigrationTargetHandler,
99 PlacementDecision, Scheduler, SchedulerError, SUBPROTOCOL_MIGRATION,
100};
101pub use config::{ConnectionRole, NetAdapterConfig, ReliabilityConfig};
102pub use contested::{
103 CorrelatedFailureConfig, CorrelatedFailureDetector, CorrelationVerdict, FailureCause,
104 PartitionDetector, PartitionPhase, PartitionRecord, ReconcileOutcome, Side,
105 SUBPROTOCOL_PARTITION,
106};
107pub use continuity::{
108 assess_continuity, CausalCone, Causality, ContinuityProof, ContinuityStatus, Discontinuity,
109 DiscontinuityReason, ForkRecord, HorizonDivergence, ObservationWindow, ProofError,
110 PropagationModel, SuperpositionPhase, SuperpositionState, SUBPROTOCOL_CONTINUITY,
111};
112#[cfg(feature = "cortex")]
113pub use cortex::{
114 CortexAdapter, CortexAdapterConfig, CortexAdapterError, EventEnvelope, EventMeta,
115 FoldErrorPolicy, IntoRedexPayload, StartPosition, EVENT_META_SIZE,
116};
117pub use crypto::{CryptoError, SessionKeys, StaticKeypair};
118pub use failure::{
119 CircuitBreaker, CircuitState, FailureDetector, FailureDetectorConfig, FailureStats,
120 LossSimulator, NodeStatus, RecoveryAction, RecoveryManager, RecoveryStats,
121};
122pub use identity::{
123 EntityError, EntityId, EntityKeypair, OriginStamp, PermissionToken, TokenCache, TokenError,
124 TokenScope,
125};
126pub use mesh::{MeshNode, MeshNodeConfig, PartitionFilter};
127#[cfg(feature = "netdb")]
128pub use netdb::{MemoriesFilter, NetDb, NetDbBuilder, NetDbError, NetDbSnapshot, TasksFilter};
129pub use pool::{PacketBuilder, PacketPool, SharedLocalPool, ThreadLocalPool};
135pub use protocol::{
136 EventFrame, NackPayload, NetHeader, PacketFlags, HEADER_SIZE, NONCE_SIZE, TAG_SIZE,
137};
138pub use proxy::{
139 ForwardResult, HopStats, MultiHopPacketBuilder, NetProxy, ProxyConfig, ProxyError, ProxyStats,
140};
141#[cfg(feature = "redex")]
142pub use redex::{
143 FsyncPolicy, IndexOp, IndexStart, OrderedAppender, Redex, RedexEntry, RedexError, RedexEvent,
144 RedexFile, RedexFileConfig, RedexFlags, RedexFold, RedexIndex, TypedRedexFile,
145};
146pub use reliability::{FireAndForget, ReliabilityMode, ReliableStream, RetransmitDescriptor};
147pub use reroute::ReroutePolicy;
148pub use route::{
149 AggregateStats, RouteEntry, RouteFlags, RoutingHeader, RoutingTable, SchedulerStreamStats,
150 ROUTING_HEADER_SIZE,
151};
152pub use router::{FairScheduler, NetRouter, RouteAction, RouterConfig, RouterError, RouterStats};
153pub use session::{NetSession, SessionManager, StreamState, TxAdmit, TxSlotGuard};
154pub use state::{
155 CausalChainBuilder, CausalEvent, CausalLink, ChainError, EntityLog, HorizonEncoder, LogError,
156 LogIndex, ObservedHorizon, SnapshotStore, StateSnapshot, CAUSAL_LINK_SIZE, SUBPROTOCOL_CAUSAL,
157 SUBPROTOCOL_SNAPSHOT,
158};
159pub use stream::{
160 CloseBehavior, Reliability, Stream, StreamConfig, StreamError, StreamStats,
161 DEFAULT_STREAM_WINDOW_BYTES,
162};
163pub use subnet::{DropReason, ForwardDecision, SubnetGateway, SubnetId, SubnetPolicy, SubnetRule};
164pub use subprotocol::{
165 negotiate, MigrationSubprotocolHandler, NegotiatedSet, OutboundMigrationMessage,
166 SubprotocolDescriptor, SubprotocolManifest, SubprotocolRegistry, SubprotocolVersion,
167 SUBPROTOCOL_NEGOTIATION,
168};
169pub use swarm::{
170 Capabilities, CapabilityAd, EdgeInfo, GraphStats, LocalGraph, NodeInfo, Pingwave,
171 MAX_GRAPH_NODES, MAX_SEEN_PINGWAVES, PINGWAVE_SIZE,
172};
173pub use transport::{NetSocket, PacketReceiver, PacketSender, ParsedPacket, SocketBufferConfig};
174
175use async_trait::async_trait;
176use bytes::Bytes;
177use crossbeam_queue::SegQueue;
178use dashmap::DashMap;
179use std::sync::atomic::{AtomicBool, Ordering};
180use std::sync::Arc;
181use tokio::sync::Mutex as TokioMutex;
182use tokio::sync::Notify;
183use tokio::task::JoinHandle;
184
185use crate::adapter::{Adapter, ShardPollResult};
186use crate::error::AdapterError;
187use crate::event::{Batch, StoredEvent};
188
189use crypto::NoiseHandshake;
190use session::SessionManager as SessionMgr;
191use transport::NetSocket as Socket;
192
193pub use routing::{route_to_shard, stream_id_from_bytes, stream_id_from_key};
195
/// Returns the current wall-clock time in nanoseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`; a nanosecond count that
/// does not fit in `u64` saturates to `u64::MAX`.
#[inline]
pub(crate) fn current_timestamp() -> u64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(duration) => duration,
        Err(_) => std::time::Duration::ZERO,
    };
    u64::try_from(since_epoch.as_nanos()).unwrap_or(u64::MAX)
}
218
/// Returns the current wall-clock time in microseconds since the Unix epoch.
///
/// Matches `current_timestamp`: a system clock set before the epoch yields
/// `0`, and a microsecond count that does not fit in `u64` saturates to
/// `u64::MAX` instead of being silently truncated by an `as` cast.
#[inline]
pub(crate) fn current_timestamp_micros() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| u64::try_from(d.as_micros()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
231
232mod routing {
236 use xxhash_rust::xxh3::xxh3_64;
237
238 #[inline]
242 pub fn stream_id_from_bytes(data: &[u8]) -> u64 {
243 xxh3_64(data)
244 }
245
246 #[inline]
250 pub fn stream_id_from_key(key: &str) -> u64 {
251 xxh3_64(key.as_bytes())
252 }
253
254 #[inline]
262 pub fn route_to_shard(data: &[u8], num_shards: u16) -> u16 {
263 assert!(num_shards > 0, "num_shards must be > 0");
264 (xxh3_64(data) % num_shards as u64) as u16
265 }
266
267 #[cfg(test)]
268 mod tests {
269 use super::*;
270
271 #[test]
272 fn test_stream_id_deterministic() {
273 let data = b"test event data";
274 let id1 = stream_id_from_bytes(data);
275 let id2 = stream_id_from_bytes(data);
276 assert_eq!(id1, id2);
277 }
278
279 #[test]
280 fn test_stream_id_different_for_different_data() {
281 let id1 = stream_id_from_bytes(b"event1");
282 let id2 = stream_id_from_bytes(b"event2");
283 assert_ne!(id1, id2);
284 }
285
286 #[test]
287 fn test_stream_id_from_key() {
288 let id = stream_id_from_key("user:12345");
289 assert_ne!(id, 0);
290 }
291
292 #[test]
293 fn test_route_to_shard_range() {
294 let num_shards = 16u16;
295 for i in 0..1000 {
296 let data = format!("event_{}", i);
297 let shard = route_to_shard(data.as_bytes(), num_shards);
298 assert!(shard < num_shards);
299 }
300 }
301
302 #[test]
303 #[should_panic(expected = "num_shards must be > 0")]
304 fn test_route_to_shard_zero_shards_panics() {
305 route_to_shard(b"test", 0);
308 }
309
310 #[test]
311 fn test_route_to_shard_distribution() {
312 let num_shards = 8u16;
313 let mut counts = [0u32; 8];
314
315 for i in 0..8000 {
316 let data = format!("event_{}", i);
317 let shard = route_to_shard(data.as_bytes(), num_shards);
318 counts[shard as usize] += 1;
319 }
320
321 let expected = 1000;
323 for count in counts {
324 assert!(count > expected / 2, "shard count {} too low", count);
325 assert!(count < expected * 2, "shard count {} too high", count);
326 }
327 }
328 }
329}
330
331type InboundQueues = Arc<DashMap<u16, SegQueue<StoredEvent>>>;
333
/// Per-source-address rate limiter for inbound handshake packets.
///
/// Each source gets a fixed window that starts at its first recorded attempt.
/// At most `max_per_window` attempts per window are admitted.
pub(crate) struct HandshakePacer {
    /// Source address -> (attempts recorded in the current window, window start).
    entries: std::collections::HashMap<std::net::SocketAddr, (u32, std::time::Instant)>,
    /// Maximum number of attempts admitted per source per window.
    max_per_window: u32,
    /// Length of one rate-limiting window.
    window: std::time::Duration,
    /// Time of the last garbage-collection pass over `entries`.
    last_gc: std::time::Instant,
    /// Map size that forces an early GC pass.
    ///
    /// It is raised after a pass that could not shrink the map well below it.
    /// This keeps a flood of distinct, still-fresh sources from triggering a
    /// full O(n) scan on every call.
    gc_size_threshold: usize,
}

impl HandshakePacer {
    /// Creates a pacer admitting `max_per_window` attempts per source per `window`.
    pub(crate) fn new(max_per_window: u32, window: std::time::Duration) -> Self {
        Self {
            entries: std::collections::HashMap::new(),
            max_per_window,
            window,
            last_gc: std::time::Instant::now(),
            gc_size_threshold: 4096,
        }
    }

    /// Records one handshake attempt from `source`.
    ///
    /// Returns `true` if the attempt is within the source's per-window budget.
    /// Rejected attempts still count, so a source that keeps sending stays
    /// blocked until its window expires.
    pub(crate) fn check_and_record(&mut self, source: std::net::SocketAddr) -> bool {
        let now = std::time::Instant::now();
        if now.duration_since(self.last_gc) >= self.window
            || self.entries.len() >= self.gc_size_threshold
        {
            self.collect_garbage(now);
        }

        let entry = self.entries.entry(source).or_insert((0, now));
        if now.duration_since(entry.1) > self.window {
            // The previous window has expired: start a fresh one.
            *entry = (0, now);
        }
        entry.0 = entry.0.saturating_add(1);
        entry.0 <= self.max_per_window
    }

    /// Drops entries whose window started two or more windows ago.
    ///
    /// If the surviving map is still large, this raises `gc_size_threshold` to
    /// twice its size. Size-triggered passes then stay amortised O(1) per call
    /// instead of rescanning the whole map on every call during a flood.
    fn collect_garbage(&mut self, now: std::time::Instant) {
        let cutoff = self.window.saturating_mul(2);
        self.entries
            .retain(|_, (_, start)| now.duration_since(*start) < cutoff);
        self.last_gc = now;
        self.gc_size_threshold = self
            .gc_size_threshold
            .max(self.entries.len().saturating_mul(2));
    }
}
408
/// Adapter that exchanges event batches with a single peer over a socket,
/// after establishing session keys with a Noise handshake.
///
/// NOTE(review): field declaration order is also drop order; keep it stable.
pub struct NetAdapter {
    /// Configuration, validated in `new`.
    config: NetAdapterConfig,
    /// Bound socket; `None` until `init` creates it.
    socket: Option<Arc<Socket>>,
    /// Session established by the handshake in `init`; `None` before that.
    session: Option<Arc<NetSession>>,
    /// Tracks the active session; consulted by `is_healthy` and cleared by `shutdown`.
    session_manager: SessionMgr,
    /// Per-shard queues filled by the receiver task and drained by `poll_shard`.
    inbound: InboundQueues,
    /// Handles of the receiver and heartbeat tasks, awaited by `shutdown`.
    tasks: TokioMutex<Vec<JoinHandle<()>>>,
    /// Flag checked by the background tasks to decide whether to stop.
    shutdown: Arc<AtomicBool>,
    /// Used by `shutdown` to wake background tasks blocked in `select!`.
    shutdown_notify: Arc<Notify>,
    /// Set once `init` completes; cleared again by `shutdown`.
    initialized: AtomicBool,
    /// Per-source rate limiter applied to handshake packets on the responder path.
    handshake_pacer: parking_lot::Mutex<HandshakePacer>,
}
435
436impl NetAdapter {
437 pub fn new(config: NetAdapterConfig) -> Result<Self, AdapterError> {
439 config
440 .validate()
441 .map_err(|e| AdapterError::Fatal(format!("invalid config: {}", e)))?;
442
443 Ok(Self {
444 session_manager: SessionMgr::new(config.session_timeout),
445 config,
446 socket: None,
447 session: None,
448 inbound: Arc::new(DashMap::new()),
449 tasks: TokioMutex::new(Vec::new()),
450 shutdown: Arc::new(AtomicBool::new(false)),
451 shutdown_notify: Arc::new(Notify::new()),
452 initialized: AtomicBool::new(false),
453 handshake_pacer: parking_lot::Mutex::new(HandshakePacer::new(
457 5,
458 std::time::Duration::from_secs(1),
459 )),
460 })
461 }
462
463 async fn perform_handshake(
466 &self,
467 socket: &Socket,
468 ) -> Result<(SessionKeys, std::net::SocketAddr), AdapterError> {
469 let mut attempt = 0;
470 let max_attempts = self.config.handshake_retries;
471
472 const HANDSHAKE_RETRY_SLEEP_CAP_MS: u64 = 5_000;
480
481 loop {
482 attempt += 1;
483 match self.try_handshake(socket).await {
484 Ok(result) => return Ok(result),
485 Err(e) if attempt < max_attempts => {
486 tracing::warn!(
487 attempt = attempt,
488 max = max_attempts,
489 error = %e,
490 "handshake failed, retrying"
491 );
492 let backoff_ms =
493 (100u64.saturating_mul(attempt as u64)).min(HANDSHAKE_RETRY_SLEEP_CAP_MS);
494 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
495 }
496 Err(e) => return Err(e),
497 }
498 }
499 }
500
501 async fn try_handshake(
504 &self,
505 socket: &Socket,
506 ) -> Result<(SessionKeys, std::net::SocketAddr), AdapterError> {
507 let timeout = self.config.handshake_timeout;
508 let socket_arc = socket.socket_arc();
509
510 if self.config.is_initiator() {
511 let peer_pubkey = self
513 .config
514 .peer_static_pubkey
515 .as_ref()
516 .ok_or_else(|| AdapterError::Fatal("missing peer public key".into()))?;
517
518 let mut handshake = NoiseHandshake::initiator(&self.config.psk, peer_pubkey)
519 .map_err(|e| AdapterError::Fatal(format!("handshake init failed: {}", e)))?;
520
521 let msg1 = handshake
523 .write_message(&[])
524 .map_err(|e| AdapterError::Connection(format!("write_message failed: {}", e)))?;
525
526 let mut builder = PacketBuilder::new(&[0u8; 32], 0);
527 let packet = builder.build_handshake(&msg1);
528
529 socket
530 .send_to(&packet, self.config.peer_addr)
531 .await
532 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
533
534 let (parsed, _source) = tokio::time::timeout(timeout, async {
538 let mut recv_buf = [0u8; protocol::MAX_PACKET_SIZE];
551 loop {
552 let (n, source) = socket_arc
553 .recv_from(&mut recv_buf)
554 .await
555 .map_err(|e| AdapterError::Connection(format!("recv failed: {}", e)))?;
556
557 if source != self.config.peer_addr {
559 continue;
560 }
561
562 let data = bytes::Bytes::copy_from_slice(&recv_buf[..n]);
563
564 if let Some(p) = ParsedPacket::parse(data, source) {
565 if p.header.flags.is_handshake() {
566 return Ok::<_, AdapterError>((p, source));
567 }
568 }
569 }
571 })
572 .await
573 .map_err(|_| AdapterError::Connection("handshake timeout".into()))??;
574
575 handshake
577 .read_message(&parsed.payload)
578 .map_err(|e| AdapterError::Connection(format!("read_message failed: {}", e)))?;
579
580 let keys = handshake
582 .into_session_keys()
583 .map_err(|e| AdapterError::Fatal(format!("key extraction failed: {}", e)))?;
584 Ok((keys, self.config.peer_addr))
585 } else {
586 let keypair = self
588 .config
589 .static_keypair
590 .as_ref()
591 .ok_or_else(|| AdapterError::Fatal("missing static keypair".into()))?;
592
593 let (parsed, source) = tokio::time::timeout(timeout, async {
600 loop {
601 let mut recv_buf = bytes::BytesMut::with_capacity(protocol::MAX_PACKET_SIZE);
602 recv_buf.resize(protocol::MAX_PACKET_SIZE, 0);
603
604 let (n, source) = socket_arc
605 .recv_from(&mut recv_buf)
606 .await
607 .map_err(|e| AdapterError::Connection(format!("recv failed: {}", e)))?;
608
609 recv_buf.truncate(n);
610 let data = recv_buf.freeze();
611
612 if let Some(p) = ParsedPacket::parse(data, source) {
613 if p.header.flags.is_handshake() {
614 let allowed = self.handshake_pacer.lock().check_and_record(source);
617 if !allowed {
618 tracing::debug!(
619 %source,
620 "handshake responder: dropping packet from \
621 rate-limited source"
622 );
623 continue;
624 }
625 return Ok::<_, AdapterError>((p, source));
626 }
627 }
628 }
630 })
631 .await
632 .map_err(|_| AdapterError::Connection("handshake timeout".into()))??;
633
634 let mut handshake = NoiseHandshake::responder(&self.config.psk, keypair)
635 .map_err(|e| AdapterError::Fatal(format!("handshake init failed: {}", e)))?;
636
637 handshake
639 .read_message(&parsed.payload)
640 .map_err(|e| AdapterError::Connection(format!("read_message failed: {}", e)))?;
641
642 let msg2 = handshake
644 .write_message(&[])
645 .map_err(|e| AdapterError::Connection(format!("write_message failed: {}", e)))?;
646
647 let mut builder = PacketBuilder::new(&[0u8; 32], 0);
648 let packet = builder.build_handshake(&msg2);
649
650 socket
653 .send_to(&packet, source)
654 .await
655 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
656
657 let keys = handshake
659 .into_session_keys()
660 .map_err(|e| AdapterError::Fatal(format!("key extraction failed: {}", e)))?;
661 Ok((keys, source))
662 }
663 }
664
665 fn process_packet(
667 data: Bytes,
668 source: std::net::SocketAddr,
669 session: &NetSession,
670 inbound: &InboundQueues,
671 num_shards: u16,
672 ) {
673 let mut parsed = match ParsedPacket::parse(data, source) {
675 Some(p) => p,
676 None => return,
677 };
678
679 if !parsed.header.flags.is_handshake()
683 && !parsed.header.flags.is_heartbeat()
684 && !parsed.is_valid_length()
685 {
686 return;
687 }
688
689 if parsed.header.flags.is_handshake() {
691 return;
692 }
693
694 if parsed.header.session_id != session.session_id() {
696 return;
697 }
698
699 if parsed.header.flags.is_heartbeat() {
716 if source == session.peer_addr() {
717 session.verify_and_touch_heartbeat(&parsed);
718 }
719 return;
720 }
721
722 let aad = parsed.header.aad();
727 let counter = u64::from_le_bytes(parsed.header.nonce[4..12].try_into().unwrap_or([0u8; 8]));
728 let rx_cipher = session.rx_cipher();
729 let payload = std::mem::take(&mut parsed.payload);
730 let decrypted = match rx_cipher.decrypt_to_bytes(counter, &aad, payload) {
738 Ok(d) => {
739 if !rx_cipher.try_admit_rx_counter(counter) {
740 return;
741 }
742 d
743 }
744 Err(_) => return,
745 };
746
747 let events = EventFrame::read_events(decrypted, parsed.header.event_count);
749
750 let stream_id = parsed.header.stream_id;
752 let shard_id = if num_shards > 0 {
753 (stream_id % num_shards as u64) as u16
754 } else {
755 0
756 };
757
758 let is_fresh = {
771 let stream = session.get_or_create_stream(stream_id);
772 let fresh = stream.with_reliability(|r| r.on_receive(parsed.header.sequence));
778 stream.update_rx_seq(parsed.header.sequence);
779 fresh
780 };
781
782 if is_fresh {
783 let queue = inbound.entry(shard_id).or_default();
785 let seq = parsed.header.sequence;
786 for (i, event_data) in events.into_iter().enumerate() {
787 use std::fmt::Write;
788 let mut event_id = String::with_capacity(24);
789 let _ = write!(event_id, "{}:{}", seq, i);
790 queue.push(StoredEvent::new(event_id, event_data, seq, shard_id));
791 }
792 } else {
793 tracing::debug!(
794 seq = parsed.header.sequence,
795 stream_id,
796 "Dropping duplicate packet"
797 );
798 }
799
800 session.touch();
801 }
802
803 #[cfg(target_os = "linux")]
808 fn spawn_receiver(
809 shutdown: Arc<AtomicBool>,
810 shutdown_notify: Arc<Notify>,
811 socket: Arc<Socket>,
812 session: Arc<NetSession>,
813 inbound: InboundQueues,
814 num_shards: u16,
815 ) -> JoinHandle<()> {
816 let mut receiver = transport::BatchedPacketReceiver::new(socket.socket_arc());
817
818 tokio::spawn(async move {
819 while !shutdown.load(Ordering::Acquire) {
820 tokio::select! {
821 result = receiver.recv() => {
822 match result {
823 Ok((data, source)) => {
824 Self::process_packet(data, source, &session, &inbound, num_shards);
825 }
826 Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
827 tracing::warn!("batch receiver thread exited, stopping receiver");
828 break;
829 }
830 Err(e) => {
831 if !shutdown.load(Ordering::Acquire) {
832 tracing::warn!(error = %e, "receive error");
833 }
834 }
835 }
836 }
837 _ = shutdown_notify.notified() => {
838 break;
839 }
840 }
841 }
842 })
843 }
844
845 #[cfg(not(target_os = "linux"))]
847 fn spawn_receiver(
848 shutdown: Arc<AtomicBool>,
849 shutdown_notify: Arc<Notify>,
850 socket: Arc<Socket>,
851 session: Arc<NetSession>,
852 inbound: InboundQueues,
853 num_shards: u16,
854 ) -> JoinHandle<()> {
855 tokio::spawn(async move {
856 let mut receiver = PacketReceiver::new(socket.socket_arc());
857
858 while !shutdown.load(Ordering::Acquire) {
859 tokio::select! {
863 result = receiver.recv() => {
864 match result {
865 Ok((data, source)) => {
866 Self::process_packet(data, source, &session, &inbound, num_shards);
867 }
868 Err(e) => {
869 if !shutdown.load(Ordering::Acquire) {
870 tracing::warn!(error = %e, "receive error");
871 }
872 }
873 }
874 }
875 _ = shutdown_notify.notified() => {
876 break;
877 }
878 }
879 }
880 })
881 }
882
883 fn spawn_heartbeat(
885 shutdown: Arc<AtomicBool>,
886 shutdown_notify: Arc<Notify>,
887 socket: Arc<Socket>,
888 session: Arc<NetSession>,
889 interval: std::time::Duration,
890 peer_addr: std::net::SocketAddr,
891 ) -> JoinHandle<()> {
892 tokio::spawn(async move {
893 let mut ticker = tokio::time::interval(interval);
894
895 loop {
896 tokio::select! {
897 _ = ticker.tick() => {
898 if shutdown.load(Ordering::Acquire) || !session.is_active() {
899 break;
900 }
901
902 let packet = session.build_heartbeat();
916
917 if let Err(e) = socket.send_to(&packet, peer_addr).await {
918 tracing::warn!(error = %e, "heartbeat send failed");
919 }
920 }
921 _ = shutdown_notify.notified() => {
922 break;
923 }
924 }
925 }
926 })
927 }
928}
929
930#[async_trait]
931impl Adapter for NetAdapter {
932 async fn init(&mut self) -> Result<(), AdapterError> {
933 if self.initialized.load(Ordering::Acquire) {
934 return Ok(());
935 }
936
937 let socket_config = match (
939 self.config.socket_recv_buffer,
940 self.config.socket_send_buffer,
941 ) {
942 (Some(recv), Some(send)) => transport::SocketBufferConfig {
943 recv_buffer_size: recv,
944 send_buffer_size: send,
945 },
946 _ => transport::SocketBufferConfig::default(),
947 };
948 let socket = Socket::with_config(self.config.bind_addr, socket_config)
949 .await
950 .map_err(|e| AdapterError::Connection(format!("socket creation failed: {}", e)))?;
951
952 let socket = Arc::new(socket);
953 self.socket = Some(socket.clone());
954
955 let (keys, actual_peer) = self.perform_handshake(&socket).await?;
957
958 let session = Arc::new(NetSession::new(
962 keys,
963 actual_peer,
964 self.config.packet_pool_size,
965 self.config.default_reliability.is_reliable(),
966 ));
967 self.session = Some(session.clone());
968
969 self.session_manager.set_session_arc(session.clone());
971
972 let recv_task = Self::spawn_receiver(
974 self.shutdown.clone(),
975 self.shutdown_notify.clone(),
976 socket.clone(),
977 session.clone(),
978 self.inbound.clone(),
979 self.config.num_shards,
980 );
981
982 let heartbeat_task = Self::spawn_heartbeat(
983 self.shutdown.clone(),
984 self.shutdown_notify.clone(),
985 socket,
986 session,
987 self.config.heartbeat_interval,
988 actual_peer,
989 );
990
991 {
992 let mut tasks = self.tasks.lock().await;
993 tasks.push(recv_task);
994 tasks.push(heartbeat_task);
995 }
996
997 self.initialized.store(true, Ordering::Release);
998
999 tracing::info!(
1000 bind_addr = %self.config.bind_addr,
1001 peer_addr = %self.config.peer_addr,
1002 role = ?self.config.role,
1003 "Net adapter initialized"
1004 );
1005
1006 Ok(())
1007 }
1008
1009 async fn on_batch(&self, batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
1010 let session = self
1011 .session
1012 .as_ref()
1013 .ok_or_else(|| AdapterError::Connection("not connected".into()))?;
1014
1015 let socket = self
1016 .socket
1017 .as_ref()
1018 .ok_or_else(|| AdapterError::Connection("socket not initialized".into()))?;
1019
1020 let stream_id = batch.shard_id as u64;
1021 let peer_addr = session.peer_addr();
1022
1023 let reliable = {
1027 let stream = session.get_or_create_stream(stream_id);
1028 stream.with_reliability(|r| r.needs_ack())
1029 };
1031
1032 let mut current_batch: Vec<Bytes> = Vec::with_capacity(64);
1034 let mut current_size = 0usize;
1035
1036 let pool = session.thread_local_pool();
1038 let mut builder = pool.get();
1039
1040 for event in &batch.events {
1041 let event_bytes = event.raw.clone();
1042 let frame_size = EventFrame::LEN_SIZE + event_bytes.len();
1043
1044 if current_size + frame_size > protocol::MAX_PAYLOAD_SIZE && !current_batch.is_empty() {
1046 let seq;
1048 {
1049 let stream = session.get_or_create_stream(stream_id);
1050 seq = stream.next_tx_seq();
1051 }
1052
1053 let flags = if reliable {
1054 PacketFlags::RELIABLE
1055 } else {
1056 PacketFlags::NONE
1057 };
1058
1059 let packet = builder.build(stream_id, seq, ¤t_batch, flags);
1060
1061 socket
1063 .send_to(&packet, peer_addr)
1064 .await
1065 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
1066
1067 if reliable {
1074 let descriptor = std::sync::Arc::new(reliability::RetransmitDescriptor {
1079 seq,
1080 stream_id,
1081 events: current_batch.clone(),
1082 flags,
1083 });
1084 let stream = session.get_or_create_stream(stream_id);
1085 stream.with_reliability(|r| r.on_send(descriptor));
1086 }
1087
1088 current_batch.clear();
1089 current_size = 0;
1090 }
1091
1092 current_batch.push(event_bytes);
1093 current_size += frame_size;
1094 }
1095
1096 if !current_batch.is_empty() {
1098 let seq;
1099 {
1100 let stream = session.get_or_create_stream(stream_id);
1101 seq = stream.next_tx_seq();
1102 }
1103
1104 let flags = if reliable {
1105 PacketFlags::RELIABLE
1106 } else {
1107 PacketFlags::NONE
1108 };
1109
1110 let packet = builder.build(stream_id, seq, ¤t_batch, flags);
1111
1112 socket
1113 .send_to(&packet, peer_addr)
1114 .await
1115 .map_err(|e| AdapterError::Connection(format!("send failed: {}", e)))?;
1116
1117 if reliable {
1118 let descriptor = std::sync::Arc::new(reliability::RetransmitDescriptor {
1120 seq,
1121 stream_id,
1122 events: current_batch.clone(),
1123 flags,
1124 });
1125 let stream = session.get_or_create_stream(stream_id);
1126 stream.with_reliability(|r| r.on_send(descriptor));
1127 }
1128 }
1129
1130 session.touch();
1131
1132 Ok(())
1133 }
1134
1135 async fn poll_shard(
1136 &self,
1137 shard_id: u16,
1138 from_id: Option<&str>,
1139 limit: usize,
1140 ) -> Result<ShardPollResult, AdapterError> {
1141 let mut events = Vec::with_capacity(limit);
1142
1143 if let Some(queue) = self.inbound.get(&shard_id) {
1144 while events.len() < limit {
1145 if let Some(event) = queue.pop() {
1146 if from_id.is_none() || event_id_gt(&event.id, from_id.unwrap_or("")) {
1147 events.push(event);
1148 }
1149 } else {
1154 break;
1155 }
1156 }
1157 }
1158
1159 let has_more = self
1160 .inbound
1161 .get(&shard_id)
1162 .map(|q| !q.is_empty())
1163 .unwrap_or(false);
1164 let next_id = events.last().map(|e| e.id.clone());
1165
1166 Ok(ShardPollResult {
1167 events,
1168 next_id,
1169 has_more,
1170 })
1171 }
1172
1173 async fn flush(&self) -> Result<(), AdapterError> {
1174 Ok(())
1177 }
1178
1179 async fn shutdown(&self) -> Result<(), AdapterError> {
1180 self.shutdown.store(true, Ordering::Release);
1181
1182 self.shutdown_notify.notify_waiters();
1185
1186 self.session_manager.clear_session();
1188
1189 let mut tasks = self.tasks.lock().await;
1191 for task in tasks.drain(..) {
1192 let _ = task.await;
1193 }
1194
1195 self.initialized.store(false, Ordering::Release);
1196
1197 tracing::info!("Net adapter shutdown complete");
1198
1199 Ok(())
1200 }
1201
1202 fn name(&self) -> &'static str {
1203 "net"
1204 }
1205
1206 async fn is_healthy(&self) -> bool {
1207 self.initialized.load(Ordering::Acquire) && self.session_manager.check_session()
1208 }
1209}
1210
1211impl std::fmt::Debug for NetAdapter {
1212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1213 f.debug_struct("NetAdapter")
1214 .field("config", &self.config)
1215 .field("initialized", &self.initialized.load(Ordering::Relaxed))
1216 .finish()
1217 }
1218}
1219
/// Returns `true` if event id `a` sorts strictly after `b`.
///
/// Ids shaped like `"<seq>:<idx>"` (both parts `u64`) are compared numerically
/// as `(seq, idx)` pairs, so `"10:0"` sorts after `"9:0"`. If either id does
/// not have that shape, plain lexicographic string order is used instead.
fn event_id_gt(a: &str, b: &str) -> bool {
    let as_pair = |id: &str| {
        id.split_once(':')
            .and_then(|(seq, idx)| Some((seq.parse::<u64>().ok()?, idx.parse::<u64>().ok()?)))
    };

    if let (Some(lhs), Some(rhs)) = (as_pair(a), as_pair(b)) {
        lhs > rhs
    } else {
        a > b
    }
}
1237
1238#[cfg(test)]
1239mod tests {
1240 use super::*;
1241
1242 #[test]
1243 fn test_adapter_creation() {
1244 let psk = [0x42u8; 32];
1245 let peer_pubkey = [0x24u8; 32];
1246
1247 let config = NetAdapterConfig::initiator(
1248 "127.0.0.1:0".parse().unwrap(),
1249 "127.0.0.1:9999".parse().unwrap(),
1250 psk,
1251 peer_pubkey,
1252 );
1253
1254 let adapter = NetAdapter::new(config).unwrap();
1255 assert_eq!(adapter.name(), "net");
1256 }
1257
1258 #[test]
1259 fn test_shard_id_from_stream_id_uses_modulo() {
1260 let num_shards: u16 = 8;
1264
1265 let stream_a: u64 = 0xDEAD_BEEF_0000_0003;
1270 let stream_b: u64 = 0xCAFE_BABE_0000_0003;
1271
1272 let shard_a = (stream_a % num_shards as u64) as u16;
1273 let shard_b = (stream_b % num_shards as u64) as u16;
1274
1275 assert!(
1276 shard_a < num_shards,
1277 "shard must be in range [0, num_shards)"
1278 );
1279 assert!(
1280 shard_b < num_shards,
1281 "shard must be in range [0, num_shards)"
1282 );
1283
1284 let big_stream: u64 = 0xFFFF_FFFF_FFFF_FFFF;
1286 let shard_big = (big_stream % num_shards as u64) as u16;
1287 assert!(shard_big < num_shards);
1288
1289 assert_ne!(
1292 big_stream as u16, shard_big,
1293 "modulo must differ from truncation for large stream IDs"
1294 );
1295 }
1296
1297 #[test]
1298 fn test_invalid_config() {
1299 let psk = [0x42u8; 32];
1300 let peer_pubkey = [0x24u8; 32];
1301
1302 let mut config = NetAdapterConfig::initiator(
1303 "127.0.0.1:0".parse().unwrap(),
1304 "127.0.0.1:9999".parse().unwrap(),
1305 psk,
1306 peer_pubkey,
1307 );
1308 config.peer_static_pubkey = None;
1309
1310 let result = NetAdapter::new(config);
1311 assert!(result.is_err());
1312 }
1313
1314 #[test]
1317 fn test_event_id_gt_numeric_ordering() {
1318 assert!(event_id_gt("2:0", "1:0"));
1320 assert!(!event_id_gt("1:0", "2:0"));
1321 assert!(!event_id_gt("1:0", "1:0"));
1322
1323 assert!(event_id_gt("10:0", "9:0"));
1325 assert!(event_id_gt("100:0", "99:0"));
1326 assert!(!event_id_gt("9:0", "10:0"));
1327
1328 assert!(event_id_gt("5:2", "5:1"));
1330 assert!(!event_id_gt("5:1", "5:2"));
1331
1332 assert!(event_id_gt("1000000:0", "999999:0"));
1334 }
1335
1336 #[test]
1342 fn test_event_id_gt_edge_cases() {
1343 assert!(event_id_gt("1:0", ""));
1345 assert!(event_id_gt("b", "a"));
1347 assert!(!event_id_gt("a", "b"));
1348 }
1349
1350 #[test]
1355 fn test_build_then_process_packet_roundtrip() {
1356 use crate::adapter::net::crypto::{NoiseHandshake, StaticKeypair};
1357 use dashmap::DashMap;
1358 use std::sync::Arc;
1359
1360 let psk = [0x42u8; 32];
1362 let responder_kp = StaticKeypair::generate();
1363
1364 let mut initiator = NoiseHandshake::initiator(&psk, &responder_kp.public).unwrap();
1365 let mut responder = NoiseHandshake::responder(&psk, &responder_kp).unwrap();
1366
1367 let msg1 = initiator.write_message(&[]).unwrap();
1368 responder.read_message(&msg1).unwrap();
1369 let msg2 = responder.write_message(&[]).unwrap();
1370 initiator.read_message(&msg2).unwrap();
1371
1372 let init_keys = initiator.into_session_keys().unwrap();
1373 let resp_keys = responder.into_session_keys().unwrap();
1374
1375 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1377 let events = vec![
1378 Bytes::from(r#"{"token":"hello"}"#),
1379 Bytes::from(r#"{"token":"world"}"#),
1380 ];
1381 let packet = builder.build(0, 0, &events, PacketFlags::NONE);
1382
1383 let resp_session = Arc::new(NetSession::new(
1385 resp_keys,
1386 "127.0.0.1:5000".parse().unwrap(),
1387 4,
1388 false,
1389 ));
1390 let inbound: InboundQueues = Arc::new(DashMap::new());
1391 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1392
1393 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1394
1395 let queue = inbound.get(&0).expect("shard 0 should have events");
1397 assert_eq!(queue.len(), 2, "expected 2 events, got {}", queue.len());
1398
1399 let e1 = queue.pop().unwrap();
1400 assert_eq!(&e1.raw[..], br#"{"token":"hello"}"#);
1401
1402 let e2 = queue.pop().unwrap();
1403 assert_eq!(&e2.raw[..], br#"{"token":"world"}"#);
1404 }
1405
1406 fn make_session_keys() -> (SessionKeys, SessionKeys) {
1408 use crate::adapter::net::crypto::{NoiseHandshake, StaticKeypair};
1409
1410 let psk = [0x42u8; 32];
1411 let responder_kp = StaticKeypair::generate();
1412
1413 let mut initiator = NoiseHandshake::initiator(&psk, &responder_kp.public).unwrap();
1414 let mut responder = NoiseHandshake::responder(&psk, &responder_kp).unwrap();
1415
1416 let msg1 = initiator.write_message(&[]).unwrap();
1417 responder.read_message(&msg1).unwrap();
1418 let msg2 = responder.write_message(&[]).unwrap();
1419 initiator.read_message(&msg2).unwrap();
1420
1421 (
1422 initiator.into_session_keys().unwrap(),
1423 responder.into_session_keys().unwrap(),
1424 )
1425 }
1426
1427 #[test]
1428 fn test_process_packet_rejects_truncated_packet() {
1429 use dashmap::DashMap;
1430 use std::sync::Arc;
1431
1432 let (init_keys, resp_keys) = make_session_keys();
1433
1434 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1436 let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);
1437
1438 let resp_session = Arc::new(NetSession::new(
1439 resp_keys,
1440 "127.0.0.1:5000".parse().unwrap(),
1441 4,
1442 false,
1443 ));
1444 let inbound: InboundQueues = Arc::new(DashMap::new());
1445 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1446
1447 let truncated = packet.slice(..packet.len() - 10);
1449 NetAdapter::process_packet(truncated, source, &resp_session, &inbound, 1);
1450 assert!(
1451 inbound.get(&0).is_none() || inbound.get(&0).unwrap().is_empty(),
1452 "truncated packet must be silently dropped"
1453 );
1454 }
1455
1456 #[test]
1457 fn test_process_packet_rejects_tampered_payload() {
1458 use dashmap::DashMap;
1459 use std::sync::Arc;
1460
1461 let (init_keys, resp_keys) = make_session_keys();
1462
1463 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1464 let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);
1465
1466 let resp_session = Arc::new(NetSession::new(
1467 resp_keys,
1468 "127.0.0.1:5000".parse().unwrap(),
1469 4,
1470 false,
1471 ));
1472 let inbound: InboundQueues = Arc::new(DashMap::new());
1473 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1474
1475 let mut tampered = bytes::BytesMut::from(&packet[..]);
1477 tampered[super::protocol::HEADER_SIZE + 2] ^= 0xFF;
1478 NetAdapter::process_packet(tampered.freeze(), source, &resp_session, &inbound, 1);
1479
1480 assert!(
1481 inbound.get(&0).is_none() || inbound.get(&0).unwrap().is_empty(),
1482 "tampered packet must be rejected by AEAD"
1483 );
1484 }
1485
1486 #[test]
1487 fn test_process_packet_rejects_wrong_session_id() {
1488 use dashmap::DashMap;
1489 use std::sync::Arc;
1490
1491 let (init_keys, resp_keys) = make_session_keys();
1492
1493 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1494 let packet = builder.build(0, 0, &[Bytes::from_static(b"hello")], PacketFlags::NONE);
1495
1496 let mut wrong_keys = resp_keys;
1498 wrong_keys.session_id = 0xDEAD;
1499 let resp_session = Arc::new(NetSession::new(
1500 wrong_keys,
1501 "127.0.0.1:5000".parse().unwrap(),
1502 4,
1503 false,
1504 ));
1505 let inbound: InboundQueues = Arc::new(DashMap::new());
1506 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1507
1508 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1509
1510 assert!(
1511 inbound.get(&0).is_none() || inbound.get(&0).unwrap().is_empty(),
1512 "packet with wrong session_id must be dropped"
1513 );
1514 }
1515
1516 #[test]
1517 fn test_process_packet_multi_packet_batch_all_events_arrive() {
1518 use dashmap::DashMap;
1519 use std::sync::Arc;
1520
1521 let (init_keys, resp_keys) = make_session_keys();
1522
1523 let resp_session = Arc::new(NetSession::new(
1524 resp_keys,
1525 "127.0.0.1:5000".parse().unwrap(),
1526 4,
1527 false,
1528 ));
1529 let inbound: InboundQueues = Arc::new(DashMap::new());
1530 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1531
1532 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1536 let total_events = 200;
1537 let mut seq = 0u64;
1538
1539 let mut current_batch: Vec<Bytes> = Vec::new();
1541 let mut current_size = 0;
1542
1543 for i in 0..total_events {
1544 let data = format!("{{\"i\":{},\"pad\":\"{}\"}}", i, "x".repeat(150));
1545 let event_bytes = Bytes::from(data);
1546 let frame_size = EventFrame::LEN_SIZE + event_bytes.len();
1547
1548 if current_size + frame_size > protocol::MAX_PAYLOAD_SIZE && !current_batch.is_empty() {
1549 let packet = builder.build(0, seq, ¤t_batch, PacketFlags::NONE);
1550 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1551 seq += 1;
1552 current_batch.clear();
1553 current_size = 0;
1554 }
1555
1556 current_batch.push(event_bytes);
1557 current_size += frame_size;
1558 }
1559
1560 if !current_batch.is_empty() {
1561 let packet = builder.build(0, seq, ¤t_batch, PacketFlags::NONE);
1562 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1563 }
1564
1565 let queue = inbound.get(&0).expect("shard 0 should have events");
1567 assert_eq!(
1568 queue.len(),
1569 total_events,
1570 "all {} events must arrive across multiple packets",
1571 total_events
1572 );
1573 }
1574
1575 #[test]
1576 fn test_build_then_process_packet_both_directions() {
1577 use dashmap::DashMap;
1578 use std::sync::Arc;
1579
1580 let (init_keys, resp_keys) = make_session_keys();
1581 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1582
1583 {
1585 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1586 let packet = builder.build(0, 0, &[Bytes::from_static(b"i2r")], PacketFlags::NONE);
1587
1588 let session = Arc::new(NetSession::new(resp_keys.clone(), source, 4, false));
1589 let inbound: InboundQueues = Arc::new(DashMap::new());
1590 NetAdapter::process_packet(packet, source, &session, &inbound, 1);
1591
1592 let queue = inbound.get(&0).expect("i2r: shard 0 should have events");
1593 assert_eq!(queue.len(), 1, "i2r: expected 1 event");
1594 assert_eq!(&queue.pop().unwrap().raw[..], b"i2r");
1595 }
1596
1597 {
1599 let mut builder = PacketBuilder::new(&resp_keys.tx_key, resp_keys.session_id);
1600 let packet = builder.build(0, 0, &[Bytes::from_static(b"r2i")], PacketFlags::NONE);
1601
1602 let session = Arc::new(NetSession::new(init_keys.clone(), source, 4, false));
1603 let inbound: InboundQueues = Arc::new(DashMap::new());
1604 NetAdapter::process_packet(packet, source, &session, &inbound, 1);
1605
1606 let queue = inbound.get(&0).expect("r2i: shard 0 should have events");
1607 assert_eq!(queue.len(), 1, "r2i: expected 1 event");
1608 assert_eq!(&queue.pop().unwrap().raw[..], b"r2i");
1609 }
1610 }
1611
1612 #[test]
1613 fn test_poll_shard_cursor_drops_consumed_events() {
1614 use std::sync::Arc;
1619
1620 let (init_keys, resp_keys) = make_session_keys();
1621
1622 let resp_session = Arc::new(NetSession::new(
1623 resp_keys,
1624 "127.0.0.1:5000".parse().unwrap(),
1625 4,
1626 false,
1627 ));
1628 let inbound: InboundQueues = Arc::new(DashMap::new());
1629 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1630
1631 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1633 for seq in 0..3u64 {
1634 let events = vec![Bytes::from(format!("event-{}", seq))];
1635 let packet = builder.build(0, seq, &events, PacketFlags::NONE);
1636 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1637 }
1638
1639 let queue = inbound.get(&0u16).unwrap();
1640 assert_eq!(queue.len(), 3);
1641
1642 let from_id = "0:0";
1645 let mut events = Vec::new();
1646 while events.len() < 10 {
1647 if let Some(event) = queue.pop() {
1648 if event_id_gt(&event.id, from_id) {
1649 events.push(event);
1650 }
1651 } else {
1653 break;
1654 }
1655 }
1656
1657 assert_eq!(events.len(), 2, "should get 2 events after cursor 0:0");
1658 assert_eq!(events[0].id, "1:0");
1659 assert_eq!(events[1].id, "2:0");
1660
1661 assert_eq!(queue.len(), 0, "queue should be empty after poll drains it");
1663 }
1664
1665 #[test]
1666 fn test_process_packet_old_counter_rejected() {
1667 use std::sync::Arc;
1670
1671 let (init_keys, resp_keys) = make_session_keys();
1672 let resp_session = Arc::new(NetSession::new(
1673 resp_keys,
1674 "127.0.0.1:5000".parse().unwrap(),
1675 4,
1676 false,
1677 ));
1678 let inbound: InboundQueues = Arc::new(DashMap::new());
1679 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1680
1681 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1683 for seq in 0..1100u64 {
1684 let packet = builder.build(0, seq, &[Bytes::from_static(b"x")], PacketFlags::NONE);
1685 NetAdapter::process_packet(packet, source, &resp_session, &inbound, 1);
1686 }
1687 assert_eq!(inbound.get(&0).unwrap().len(), 1100);
1688
1689 let mut stale_builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1693 let stale_packet =
1694 stale_builder.build(0, 9999, &[Bytes::from_static(b"stale")], PacketFlags::NONE);
1695 NetAdapter::process_packet(stale_packet, source, &resp_session, &inbound, 1);
1696
1697 assert_eq!(
1699 inbound.get(&0).unwrap().len(),
1700 1100,
1701 "packet with stale counter must be rejected"
1702 );
1703 }
1704
1705 #[test]
1706 fn test_process_packet_far_future_counter_rejected() {
1707 use std::sync::Arc;
1711
1712 let (_init_keys, resp_keys) = make_session_keys();
1713
1714 let resp_session = Arc::new(NetSession::new(
1719 resp_keys,
1720 "127.0.0.1:5000".parse().unwrap(),
1721 4,
1722 false,
1723 ));
1724
1725 let rx_cipher = resp_session.rx_cipher();
1727 assert!(
1728 !rx_cipher.is_valid_rx_counter(u64::MAX),
1729 "counter at u64::MAX must be rejected (far beyond MAX_FORWARD)"
1730 );
1731 assert!(
1732 rx_cipher.is_valid_rx_counter(0),
1733 "counter 0 should be valid initially"
1734 );
1735 }
1736
1737 #[test]
1754 fn process_packet_drops_duplicates_per_reliability_decision() {
1755 use dashmap::DashMap;
1756 use std::sync::Arc;
1757
1758 let (init_keys, resp_keys) = make_session_keys();
1759
1760 let resp_session = Arc::new(NetSession::new(
1764 resp_keys,
1765 "127.0.0.1:5000".parse().unwrap(),
1766 4,
1767 true, ));
1769 let inbound: InboundQueues = Arc::new(DashMap::new());
1770 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1771
1772 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1778 let packet0 = builder.build(7, 0, &[Bytes::from(r#"{"first":0}"#)], PacketFlags::NONE);
1779 let packet1 = builder.build(7, 1, &[Bytes::from(r#"{"first":1}"#)], PacketFlags::NONE);
1780 let packet0_dup = builder.build(
1784 7,
1785 0,
1786 &[Bytes::from(r#"{"dup":"should_not_appear"}"#)],
1787 PacketFlags::NONE,
1788 );
1789
1790 NetAdapter::process_packet(packet0, source, &resp_session, &inbound, 1);
1791 NetAdapter::process_packet(packet1, source, &resp_session, &inbound, 1);
1792 NetAdapter::process_packet(packet0_dup, source, &resp_session, &inbound, 1);
1793
1794 let queue = inbound.get(&0).expect("shard 0 should exist");
1795 assert_eq!(
1796 queue.len(),
1797 2,
1798 "duplicate packet must NOT enqueue (BUG_REPORT.md #5); \
1799 got {} events, expected exactly 2 (seq=0 and seq=1, no dup)",
1800 queue.len()
1801 );
1802
1803 let e0 = queue.pop().unwrap();
1806 assert_eq!(&e0.raw[..], br#"{"first":0}"#);
1807 let e1 = queue.pop().unwrap();
1808 assert_eq!(&e1.raw[..], br#"{"first":1}"#);
1809 assert!(queue.is_empty());
1810 }
1811
1812 #[test]
1820 fn heartbeat_is_aead_authenticated() {
1821 use crate::adapter::net::pool::PacketBuilder;
1822 use dashmap::DashMap;
1823 use std::sync::Arc;
1824
1825 let (init_keys, resp_keys) = make_session_keys();
1826
1827 let resp_session = Arc::new(NetSession::new(
1828 resp_keys,
1829 "127.0.0.1:5000".parse().unwrap(),
1830 4,
1831 false,
1832 ));
1833 let inbound: InboundQueues = Arc::new(DashMap::new());
1834 let source: std::net::SocketAddr = "127.0.0.1:5000".parse().unwrap();
1835
1836 let mut builder = PacketBuilder::new(&init_keys.tx_key, init_keys.session_id);
1839 let heartbeat = builder.build_heartbeat();
1840 let last_activity_before = resp_session.last_activity_ns();
1841 std::thread::sleep(std::time::Duration::from_millis(2));
1842
1843 NetAdapter::process_packet(heartbeat, source, &resp_session, &inbound, 1);
1845 let last_activity_after = resp_session.last_activity_ns();
1846 assert!(
1847 last_activity_after > last_activity_before,
1848 "legitimate AEAD-tagged heartbeat must call session.touch()"
1849 );
1850
1851 let mut forged = bytes::BytesMut::new();
1855 let header = NetHeader::heartbeat(resp_session.session_id());
1856 forged.extend_from_slice(&header.to_bytes());
1857 let forged = forged.freeze();
1858 let last_activity_before = resp_session.last_activity_ns();
1859 std::thread::sleep(std::time::Duration::from_millis(2));
1860 NetAdapter::process_packet(forged, source, &resp_session, &inbound, 1);
1861 let last_activity_after = resp_session.last_activity_ns();
1862 assert_eq!(
1863 last_activity_before, last_activity_after,
1864 "unauthenticated heartbeat (no AEAD tag) must NOT touch the session"
1865 );
1866
1867 let mut forged_tag = bytes::BytesMut::new();
1870 let mut header_bytes = NetHeader::heartbeat(resp_session.session_id()).to_bytes();
1871 header_bytes[12..16].copy_from_slice(&[0u8; 4]);
1874 header_bytes[16..24].copy_from_slice(&1u64.to_le_bytes());
1875 forged_tag.extend_from_slice(&header_bytes);
1876 forged_tag.extend_from_slice(&[0xAAu8; 16]); let forged_tag = forged_tag.freeze();
1878 let last_activity_before = resp_session.last_activity_ns();
1879 std::thread::sleep(std::time::Duration::from_millis(2));
1880 NetAdapter::process_packet(forged_tag, source, &resp_session, &inbound, 1);
1881 let last_activity_after = resp_session.last_activity_ns();
1882 assert_eq!(
1883 last_activity_before, last_activity_after,
1884 "heartbeat with garbage AEAD tag must NOT touch the session"
1885 );
1886 }
1887
1888 #[test]
1894 fn handshake_pacer_rejects_floods_per_source() {
1895 use std::time::Duration;
1896 let mut pacer = HandshakePacer::new(3, Duration::from_millis(50));
1897
1898 let attacker: std::net::SocketAddr = "10.0.0.1:9000".parse().unwrap();
1899 let legit: std::net::SocketAddr = "10.0.0.2:9000".parse().unwrap();
1900
1901 for _ in 0..3 {
1903 assert!(pacer.check_and_record(attacker));
1904 }
1905 for _ in 0..10 {
1907 assert!(
1908 !pacer.check_and_record(attacker),
1909 "attacker exceeding budget must be dropped"
1910 );
1911 }
1912
1913 assert!(
1916 pacer.check_and_record(legit),
1917 "legitimate source must still get through despite attacker flood"
1918 );
1919
1920 std::thread::sleep(Duration::from_millis(55));
1922 assert!(
1923 pacer.check_and_record(attacker),
1924 "attacker budget must refill after window"
1925 );
1926 }
1927}