1#![allow(dead_code)]
20
21use std::collections::VecDeque;
22use std::io::{self};
23use tokio::io::{AsyncReadExt, AsyncWriteExt};
24use tokio::net::TcpStream;
25
/// Errors that can be produced by the network stub.
#[derive(Debug, thiserror::Error)]
pub enum NetworkError {
    /// No connection is open (displayed as "not connected").
    #[error("not connected")]
    NotConnected,
    /// An underlying I/O error; `#[from]` lets `?` convert a `std::io::Error` into this variant.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
    /// The remote side refused the connection. The `String` presumably carries the
    /// address or a detail message — TODO confirm with whoever constructs this variant.
    #[error("connection refused: {0}")]
    ConnectionRefused(String),
    /// Returned by `encode_frame` when a channel name does not fit the one-byte length
    /// field of a frame; holds the offending length in bytes.
    #[error("frame too large: channel name length {0} exceeds 255 bytes")]
    ChannelTooLong(usize),
}
42
/// Lifecycle state of a `NetworkStub` connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection is open (e.g. right after construction and after `disconnect_stub`).
    Disconnected,
    /// Set by `connect_stub` while a connection attempt is in progress.
    Connecting,
    /// A TCP connection is established; `send_packet` and `receive_packet` operate only in this state.
    Connected,
    /// A connection attempt or a socket operation failed.
    Error,
}
55
/// Connection and simulation settings for a `NetworkStub`.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Host name or IP address that `connect_stub` connects to.
    pub host: String,
    /// TCP port that `connect_stub` connects to.
    pub port: u16,
    /// Artificial delay in milliseconds applied before each connection attempt; also
    /// used to synthesize timestamps in `push_to_recv_buffer`.
    pub latency_ms: u32,
    /// Probability that `send_packet` drops an outgoing packet (presumably in `[0, 1]`);
    /// values `<= 0.0` disable loss simulation entirely.
    pub packet_loss_prob: f32,
    /// Maximum number of packets held in the receive queue.
    pub recv_buffer_size: usize,
    /// `host:port` string reported by `network_stub_to_json`. It is not used for
    /// connecting, so it must be kept in sync with `host`/`port` by hand.
    pub endpoint: String,
}
72
/// A single message carried over the stub's framed TCP stream.
#[derive(Debug, Clone)]
pub struct NetworkPacket {
    /// Sequence id assigned by the stub. `0` marks a packet decoded from the wire that
    /// has not yet been handed out by `receive_packet`.
    pub id: u64,
    /// Raw message bytes.
    pub payload: Vec<u8>,
    /// Logical channel name (at most 255 bytes on the wire).
    pub channel: String,
    /// Milliseconds: wall-clock time since the Unix epoch for packets decoded from the
    /// socket, or the synthetic value `send_count * latency_ms` for packets injected
    /// with `push_to_recv_buffer`.
    pub timestamp_ms: u64,
}
85
/// Blocking facade over a tokio TCP connection with simulated latency and packet loss.
pub struct NetworkStub {
    /// Active configuration; public so callers (and tests) can adjust it in place.
    pub config: NetworkConfig,
    /// Private runtime used to drive the async socket operations synchronously.
    runtime: tokio::runtime::Runtime,
    /// The open socket, if any.
    stream: Option<TcpStream>,
    /// Decoded or injected packets waiting to be returned by `receive_packet`.
    recv_buffer: VecDeque<NetworkPacket>,
    /// Bytes read from the socket that do not yet form a complete frame.
    partial_buf: Vec<u8>,
    /// Current connection lifecycle state.
    state: ConnectionState,
    /// Number of packets successfully written by `send_packet`.
    send_count: u64,
    /// Number of packets handed out by `receive_packet`.
    recv_count: u64,
    /// Next sequence id to assign to a packet.
    next_id: u64,
    /// State of the LCG that drives packet-loss simulation.
    lcg_state: u64,
}
102
/// Advances the 64-bit linear congruential generator in `state` and returns a
/// pseudo-random `f32` uniformly distributed in `[0.0, 1.0)`.
///
/// The recurrence uses Knuth's MMIX multiplier and increment. The sample is taken
/// from the top 24 bits of the new state. Those are the best-mixed bits of an LCG,
/// and 24 bits fit exactly in an `f32` mantissa. The division is therefore exact,
/// and the result can never round up to `1.0`.
fn lcg_next(state: &mut u64) -> f32 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    // Shifting by 33 (31 bits) while dividing by 2^32 would cap the output at 0.5,
    // so every loss probability above one half would drop every packet.
    ((*state >> 40) as f32) / ((1u32 << 24) as f32)
}
113
114pub fn default_network_config() -> NetworkConfig {
120 NetworkConfig {
121 host: "127.0.0.1".to_string(),
122 port: 7878,
123 latency_ms: 20,
124 packet_loss_prob: 0.0,
125 recv_buffer_size: 256,
126 endpoint: "127.0.0.1:7878".to_string(),
127 }
128}
129
130pub fn new_network_stub(config: NetworkConfig) -> NetworkStub {
134 let runtime = tokio::runtime::Builder::new_current_thread()
135 .enable_all()
136 .build()
137 .expect("tokio current-thread runtime is infallible");
138
139 NetworkStub {
140 config,
141 runtime,
142 stream: None,
143 recv_buffer: VecDeque::new(),
144 partial_buf: Vec::new(),
145 state: ConnectionState::Disconnected,
146 send_count: 0,
147 recv_count: 0,
148 next_id: 1,
149 lcg_state: 0xDEAD_BEEF_CAFE_1234,
150 }
151}
152
153fn encode_frame(channel: &str, payload: &[u8]) -> Result<Vec<u8>, NetworkError> {
164 let ch_bytes = channel.as_bytes();
165 let ch_len = ch_bytes.len();
166 if ch_len > 255 {
167 return Err(NetworkError::ChannelTooLong(ch_len));
168 }
169 let total_len: u32 = (1 + ch_len + payload.len()) as u32;
170 let mut frame = Vec::with_capacity(4 + 1 + ch_len + payload.len());
171 frame.extend_from_slice(&total_len.to_le_bytes());
172 frame.push(ch_len as u8);
173 frame.extend_from_slice(ch_bytes);
174 frame.extend_from_slice(payload);
175 Ok(frame)
176}
177
178fn try_parse_frame(buf: &[u8]) -> Option<(NetworkPacket, usize)> {
183 if buf.len() < 4 {
184 return None;
185 }
186 let total_len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
187 let frame_end = 4 + total_len;
188 if buf.len() < frame_end {
189 return None;
190 }
191 if total_len < 1 {
192 return None;
194 }
195 let ch_len = buf[4] as usize;
196 if total_len < 1 + ch_len {
197 return None;
198 }
199 let ch_start = 5;
200 let ch_end = ch_start + ch_len;
201 let payload_start = ch_end;
202 let payload_end = frame_end;
203
204 let channel = String::from_utf8_lossy(&buf[ch_start..ch_end]).into_owned();
205 let payload = buf[payload_start..payload_end].to_vec();
206 let pkt = NetworkPacket {
207 id: 0, payload,
209 channel,
210 timestamp_ms: current_millis(),
211 };
212 Some((pkt, frame_end))
213}
214
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reports a time before the epoch.
fn current_millis() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
221
222pub fn connect_stub(stub: &mut NetworkStub) -> bool {
231 if stub.state == ConnectionState::Connected {
232 return true;
233 }
234 stub.state = ConnectionState::Connecting;
235
236 let addr = format!("{}:{}", stub.config.host, stub.config.port);
237 let latency_ms = stub.config.latency_ms;
238
239 let result = stub.runtime.block_on(async move {
240 if latency_ms > 0 {
241 tokio::time::sleep(std::time::Duration::from_millis(latency_ms as u64)).await;
242 }
243 TcpStream::connect(&addr).await
244 });
245
246 match result {
247 Ok(stream) => {
248 stub.stream = Some(stream);
249 stub.state = ConnectionState::Connected;
250 true
251 }
252 Err(e) if e.kind() == io::ErrorKind::ConnectionRefused => {
253 stub.state = ConnectionState::Error;
254 false
255 }
256 Err(_) => {
257 stub.state = ConnectionState::Error;
258 false
259 }
260 }
261}
262
263pub fn disconnect_stub(stub: &mut NetworkStub) {
265 if let Some(stream) = stub.stream.take() {
267 stub.runtime.block_on(async move {
268 drop(stream);
269 });
270 }
271 stub.state = ConnectionState::Disconnected;
272 stub.recv_buffer.clear();
273 stub.partial_buf.clear();
274}
275
276pub fn send_packet(stub: &mut NetworkStub, channel: &str, payload: Vec<u8>) -> Option<u64> {
285 if stub.state != ConnectionState::Connected {
286 return None;
287 }
288 if simulate_packet_loss(stub) {
289 return None;
290 }
291
292 let frame = encode_frame(channel, &payload).ok()?;
293
294 let stream = stub.stream.take()?;
296 let write_result = stub.runtime.block_on(async move {
297 let mut s = stream;
298 match s.write_all(&frame).await {
299 Ok(()) => Ok(s),
300 Err(e) => Err((e, s)),
301 }
302 });
303
304 match write_result {
305 Ok(s) => {
306 stub.stream = Some(s);
307 stub.send_count += 1;
308 Some(stub.send_count)
309 }
310 Err((_, s)) => {
311 stub.stream = Some(s);
312 stub.state = ConnectionState::Error;
313 None
314 }
315 }
316}
317
318pub fn receive_packet(stub: &mut NetworkStub) -> Option<NetworkPacket> {
324 if stub.state != ConnectionState::Connected {
325 return None;
326 }
327
328 if let Some(pkt) = stub.recv_buffer.pop_front() {
330 let pkt = if pkt.id == 0 {
333 let mut p = pkt;
334 p.id = stub.next_id;
335 stub.next_id += 1;
336 p
337 } else {
338 pkt
339 };
340 stub.recv_count += 1;
341 return Some(pkt);
342 }
343
344 let stream = stub.stream.take()?;
346 let partial = std::mem::take(&mut stub.partial_buf);
347
348 let (returned_stream, returned_partial, new_packets) = stub
349 .runtime
350 .block_on(async move { try_read_packets(stream, partial).await });
351
352 stub.stream = Some(returned_stream);
353 stub.partial_buf = returned_partial;
354
355 for pkt in new_packets {
357 if stub.recv_buffer.len() < stub.config.recv_buffer_size {
358 stub.recv_buffer.push_back(pkt);
359 }
360 }
361
362 if let Some(pkt) = stub.recv_buffer.pop_front() {
364 let pkt = if pkt.id == 0 {
365 let mut p = pkt;
366 p.id = stub.next_id;
367 stub.next_id += 1;
368 p
369 } else {
370 pkt
371 };
372 stub.recv_count += 1;
373 return Some(pkt);
374 }
375
376 None
377}
378
379async fn try_read_packets(
384 mut stream: TcpStream,
385 mut partial: Vec<u8>,
386) -> (TcpStream, Vec<u8>, Vec<NetworkPacket>) {
387 let mut tmp = [0u8; 4096];
388
389 let read_result =
391 tokio::time::timeout(std::time::Duration::from_millis(10), stream.read(&mut tmp)).await;
392
393 match read_result {
394 Ok(Ok(0)) => {
395 (stream, partial, vec![])
397 }
398 Ok(Ok(n)) => {
399 partial.extend_from_slice(&tmp[..n]);
400 let packets = parse_all_frames(&mut partial);
401 (stream, partial, packets)
402 }
403 Ok(Err(_)) | Err(_) => (stream, partial, vec![]),
405 }
406}
407
408fn parse_all_frames(buf: &mut Vec<u8>) -> Vec<NetworkPacket> {
411 let mut packets = Vec::new();
412 let mut cursor = 0usize;
413
414 while let Some((pkt, consumed)) = try_parse_frame(&buf[cursor..]) {
415 packets.push(pkt);
416 cursor += consumed;
417 }
418
419 if cursor > 0 {
420 buf.drain(..cursor);
421 }
422
423 packets
424}
425
426pub fn push_to_recv_buffer(stub: &mut NetworkStub, channel: &str, payload: Vec<u8>) -> bool {
430 if stub.recv_buffer.len() >= stub.config.recv_buffer_size {
431 return false;
432 }
433 stub.recv_buffer.push_back(NetworkPacket {
434 id: stub.next_id,
435 payload,
436 channel: channel.to_string(),
437 timestamp_ms: stub.send_count * stub.config.latency_ms as u64,
438 });
439 stub.next_id += 1;
440 true
441}
442
443pub fn flush_receive_buffer(stub: &mut NetworkStub) {
445 stub.recv_buffer.clear();
446}
447
448pub fn connection_state(stub: &NetworkStub) -> ConnectionState {
454 stub.state
455}
456
457pub fn packet_count_sent(stub: &NetworkStub) -> u64 {
459 stub.send_count
460}
461
462pub fn packet_count_received(stub: &NetworkStub) -> u64 {
464 stub.recv_count
465}
466
467pub fn set_latency_ms(stub: &mut NetworkStub, latency_ms: u32) {
469 stub.config.latency_ms = latency_ms;
470}
471
472pub fn simulate_packet_loss(stub: &mut NetworkStub) -> bool {
476 if stub.config.packet_loss_prob <= 0.0 {
477 return false;
478 }
479 lcg_next(&mut stub.lcg_state) < stub.config.packet_loss_prob
480}
481
482pub fn network_stub_to_json(stub: &NetworkStub) -> String {
484 let state_str = match stub.state {
485 ConnectionState::Disconnected => "disconnected",
486 ConnectionState::Connecting => "connecting",
487 ConnectionState::Connected => "connected",
488 ConnectionState::Error => "error",
489 };
490 format!(
491 "{{\"state\":\"{}\",\"endpoint\":\"{}\",\"latency_ms\":{},\
492 \"packet_loss_prob\":{:.4},\"sent\":{},\"received\":{},\"recv_buffer\":{}}}",
493 state_str,
494 stub.config.endpoint,
495 stub.config.latency_ms,
496 stub.config.packet_loss_prob,
497 stub.send_count,
498 stub.recv_count,
499 stub.recv_buffer.len(),
500 )
501}
502
// Unit tests for the network stub. Tests that need a live connection bind an
// OS-assigned port on 127.0.0.1 and accept a single connection on a background thread.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener as StdTcpListener;

    /// Binds a std listener on an OS-assigned port of 127.0.0.1 and returns it with that port.
    fn bind_test_listener() -> (StdTcpListener, u16) {
        let listener = StdTcpListener::bind("127.0.0.1:0").expect("bind test listener");
        let port = listener.local_addr().expect("local_addr").port();
        (listener, port)
    }

    /// Accepts one connection on a background thread and returns the listener's port.
    ///
    /// The accepted socket is dropped when the thread's closure returns, so the server
    /// side closes right after the handshake.
    fn accept_one_in_background(listener: StdTcpListener) -> u16 {
        let port = listener.local_addr().expect("local_addr").port();
        std::thread::spawn(move || {
            let _conn = listener.accept();
        });
        port
    }

    /// Config pointing at `127.0.0.1:port`, with no latency and no packet loss.
    fn test_config(port: u16) -> NetworkConfig {
        NetworkConfig {
            host: "127.0.0.1".to_string(),
            port,
            latency_ms: 0,
            packet_loss_prob: 0.0,
            recv_buffer_size: 256,
            endpoint: format!("127.0.0.1:{}", port),
        }
    }

    /// Returns a stub connected to a fresh background listener; panics if connecting fails.
    fn connected_stub() -> NetworkStub {
        let (listener, port) = bind_test_listener();
        accept_one_in_background(listener);
        let cfg = test_config(port);
        let mut s = new_network_stub(cfg);
        assert!(connect_stub(&mut s), "connected_stub: connect failed");
        s
    }

    #[test]
    fn default_config_sane() {
        let cfg = default_network_config();
        assert_eq!(cfg.latency_ms, 20);
        assert_eq!(cfg.packet_loss_prob, 0.0);
        assert!(cfg.recv_buffer_size > 0);
        assert!(!cfg.endpoint.is_empty());
    }

    #[test]
    fn new_stub_disconnected() {
        let s = new_network_stub(default_network_config());
        assert_eq!(connection_state(&s), ConnectionState::Disconnected);
    }

    #[test]
    fn connect_transitions_to_connected() {
        let s = connected_stub();
        assert_eq!(connection_state(&s), ConnectionState::Connected);
    }

    #[test]
    fn connect_refused_sets_error() {
        let (listener, port) = bind_test_listener();
        // Closing the listener leaves the port with nothing accepting on it, so the
        // connect is expected to be refused (unless another process grabs the port).
        drop(listener);
        let cfg = test_config(port);
        let mut s = new_network_stub(cfg);
        let ok = connect_stub(&mut s);
        assert!(!ok);
        assert_eq!(connection_state(&s), ConnectionState::Error);
    }

    #[test]
    fn disconnect_sets_disconnected() {
        let mut s = connected_stub();
        disconnect_stub(&mut s);
        assert_eq!(connection_state(&s), ConnectionState::Disconnected);
    }

    #[test]
    fn send_when_disconnected_returns_none() {
        let mut s = new_network_stub(default_network_config());
        assert!(send_packet(&mut s, "ch", vec![1, 2, 3]).is_none());
    }

    #[test]
    fn send_increments_count() {
        let mut s = connected_stub();
        send_packet(&mut s, "data", vec![42]).expect("should succeed");
        send_packet(&mut s, "data", vec![43]).expect("should succeed");
        assert_eq!(packet_count_sent(&s), 2);
    }

    #[test]
    fn recv_empty_buffer_returns_none() {
        let mut s = connected_stub();
        // The background peer never writes, so no complete frame can arrive.
        assert!(receive_packet(&mut s).is_none());
    }

    #[test]
    fn push_recv_round_trip() {
        let mut s = connected_stub();
        assert!(push_to_recv_buffer(&mut s, "ch", vec![9, 8, 7]));
        let pkt = receive_packet(&mut s).expect("must have a packet");
        assert_eq!(pkt.payload, vec![9, 8, 7]);
        assert_eq!(pkt.channel, "ch");
    }

    #[test]
    fn recv_when_disconnected_returns_none() {
        let mut s = new_network_stub(default_network_config());
        // Injection works while disconnected, but receive_packet refuses to hand it out.
        push_to_recv_buffer(&mut s, "ch", vec![1]);
        assert!(receive_packet(&mut s).is_none());
    }

    #[test]
    fn flush_empties_buffer() {
        let mut s = connected_stub();
        push_to_recv_buffer(&mut s, "x", vec![1]);
        push_to_recv_buffer(&mut s, "x", vec![2]);
        flush_receive_buffer(&mut s);
        assert!(receive_packet(&mut s).is_none());
    }

    #[test]
    fn set_latency_updates() {
        let mut s = new_network_stub(default_network_config());
        set_latency_ms(&mut s, 100);
        assert_eq!(s.config.latency_ms, 100);
    }

    #[test]
    fn zero_loss_never_loses() {
        let mut cfg = default_network_config();
        cfg.packet_loss_prob = 0.0;
        let mut s = new_network_stub(cfg);
        for _ in 0..50 {
            assert!(!simulate_packet_loss(&mut s));
        }
    }

    #[test]
    fn full_loss_always_loses() {
        let mut cfg = default_network_config();
        cfg.packet_loss_prob = 1.0;
        let mut s = new_network_stub(cfg);
        for _ in 0..20 {
            assert!(simulate_packet_loss(&mut s));
        }
    }

    #[test]
    fn recv_buffer_size_enforced() {
        let (listener, port) = bind_test_listener();
        accept_one_in_background(listener);
        let mut cfg = test_config(port);
        cfg.recv_buffer_size = 2;
        let mut s = new_network_stub(cfg);
        // The connect result is not checked; push_to_recv_buffer works in any state.
        connect_stub(&mut s);
        assert!(push_to_recv_buffer(&mut s, "a", vec![1]));
        assert!(push_to_recv_buffer(&mut s, "b", vec![2]));
        assert!(!push_to_recv_buffer(&mut s, "c", vec![3]));
    }

    #[test]
    fn to_json_contains_endpoint() {
        let s = new_network_stub(default_network_config());
        let json = network_stub_to_json(&s);
        assert!(json.contains("127.0.0.1:7878"));
    }

    #[test]
    fn recv_count_increments() {
        let mut s = connected_stub();
        push_to_recv_buffer(&mut s, "x", vec![5]);
        push_to_recv_buffer(&mut s, "x", vec![6]);
        receive_packet(&mut s);
        assert_eq!(packet_count_received(&s), 1);
        receive_packet(&mut s);
        assert_eq!(packet_count_received(&s), 2);
    }

    #[test]
    fn connection_state_error_distinct() {
        assert_ne!(ConnectionState::Error, ConnectionState::Connected);
        assert_ne!(ConnectionState::Error, ConnectionState::Disconnected);
    }

    #[test]
    fn connect_when_already_connected() {
        let mut s = connected_stub();
        assert!(connect_stub(&mut s));
        assert_eq!(connection_state(&s), ConnectionState::Connected);
    }

    #[test]
    fn disconnect_clears_recv() {
        let (listener, port) = bind_test_listener();
        accept_one_in_background(listener);
        let cfg = test_config(port);
        let mut s = new_network_stub(cfg);
        connect_stub(&mut s);
        push_to_recv_buffer(&mut s, "z", vec![0xFF]);
        disconnect_stub(&mut s);

        // Reconnect to a second listener; the packet injected before the disconnect
        // must not reappear.
        let (listener2, port2) = bind_test_listener();
        accept_one_in_background(listener2);
        s.config.host = "127.0.0.1".to_string();
        s.config.port = port2;
        s.config.endpoint = format!("127.0.0.1:{}", port2);
        connect_stub(&mut s);
        assert!(receive_packet(&mut s).is_none());
    }

    #[test]
    fn to_json_contains_counts() {
        let mut s = connected_stub();
        send_packet(&mut s, "ch", vec![1]);
        let json = network_stub_to_json(&s);
        assert!(json.contains("\"sent\":1"));
        assert!(json.contains("\"received\":0"));
    }

    #[test]
    fn frame_encode_decode_round_trip() {
        let frame = encode_frame("test-channel", b"hello world").expect("encode");
        let (pkt, consumed) = try_parse_frame(&frame).expect("parse");
        assert_eq!(consumed, frame.len());
        assert_eq!(pkt.channel, "test-channel");
        assert_eq!(pkt.payload, b"hello world");
    }

    #[test]
    fn parse_all_frames_two_frames() {
        let mut buf = encode_frame("a", b"foo").expect("enc a");
        buf.extend(encode_frame("b", b"bar").expect("enc b"));
        let pkts = parse_all_frames(&mut buf);
        assert_eq!(pkts.len(), 2);
        assert_eq!(pkts[0].channel, "a");
        assert_eq!(pkts[1].channel, "b");
        assert!(buf.is_empty());
    }

    #[test]
    fn parse_all_frames_partial_frame() {
        let mut buf = encode_frame("x", b"data").expect("enc");
        let full_len = buf.len();
        // Cut off the last two payload bytes so the frame is incomplete.
        buf.truncate(full_len - 2);
        let pkts = parse_all_frames(&mut buf);
        assert!(pkts.is_empty());
        // Incomplete bytes must be left in place for the next read.
        assert_eq!(buf.len(), full_len - 2);
    }

    #[test]
    fn encode_frame_rejects_long_channel() {
        let long_name: String = "x".repeat(256);
        let result = encode_frame(&long_name, b"payload");
        assert!(result.is_err());
    }
}