1use std::{
2 io,
3 net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket},
4 sync::Arc,
5 thread::{sleep, yield_now},
6 time::{Duration, Instant},
7};
8
9use bitfold_core::{
10 config::Config, error::Result, interceptor::Interceptor, transport::Socket as TransportSocket,
11};
12use bitfold_peer::Peer;
13use bitfold_protocol::packet::{DeliveryGuarantee, OrderingGuarantee, Packet};
14use crossbeam_channel::{Receiver, Sender, TryRecvError};
15use socket2::Socket as Socket2;
16
17use super::{
18 event_types::SocketEvent,
19 session_manager::SessionManager,
20 time::{Clock, SystemClock},
21};
22
23fn apply_socket_options(socket: &UdpSocket, config: &Config) -> io::Result<()> {
25 let socket2 = Socket2::from(socket.try_clone()?);
27
28 if let Some(size) = config.socket_recv_buffer_size {
30 socket2.set_recv_buffer_size(size)?;
31 }
32
33 if let Some(size) = config.socket_send_buffer_size {
35 socket2.set_send_buffer_size(size)?;
36 }
37
38 if let Some(ttl) = config.socket_ttl {
40 socket.set_ttl(ttl)?;
41 }
42
43 if config.socket_broadcast {
45 socket.set_broadcast(true)?;
46 }
47
48 Ok(())
49}
50
/// `UdpSocket` wrapper that implements the transport `Socket` trait for `Host`.
///
/// NOTE(review): despite the name, no link conditioning (simulated loss/latency) is
/// applied — the trait impl forwards every call straight to the OS socket. The name
/// is presumably a leftover; confirm before relying on it.
#[derive(Debug)]
struct SocketWithConditioner {
    /// Whether the socket was left in blocking mode at construction (see `new`).
    is_blocking_mode: bool,
    /// The bound OS socket that all sends and receives go through.
    socket: UdpSocket,
}
56
57impl SocketWithConditioner {
58 pub fn new(socket: UdpSocket, is_blocking_mode: bool) -> Result<Self> {
59 socket.set_nonblocking(!is_blocking_mode)?;
60 Ok(SocketWithConditioner { is_blocking_mode, socket })
61 }
62}
63
64impl TransportSocket for SocketWithConditioner {
65 fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> std::io::Result<usize> {
66 self.socket.send_to(payload, addr)
67 }
68 fn receive_packet<'a>(
69 &mut self,
70 buffer: &'a mut [u8],
71 ) -> std::io::Result<(&'a [u8], SocketAddr)> {
72 self.socket.recv_from(buffer).map(move |(recv_len, address)| (&buffer[..recv_len], address))
73 }
74 fn local_addr(&self) -> std::io::Result<SocketAddr> {
75 self.socket.local_addr()
76 }
77 fn is_blocking_mode(&self) -> bool {
78 self.is_blocking_mode
79 }
80}
81
/// A UDP endpoint: one bound socket plus the session manager that drives it.
///
/// Outgoing packets are queued with `send` / `get_packet_sender`, incoming events are
/// read with `recv` / `get_event_receiver`, and `manual_poll` (or the `start_polling*`
/// loops) hands the current time to the session manager.
/// NOTE(review): socket I/O presumably happens inside `SessionManager::manual_poll`
/// (the tests in this file always poll before expecting traffic) — confirm there.
pub struct Host {
    /// Session manager bound to the wrapped UDP socket, using `Peer` sessions.
    handler: SessionManager<SocketWithConditioner, Peer>,
    /// Time source read by `start_polling_with_duration` before every poll.
    clock: Arc<dyn Clock>,
}
89
90impl std::fmt::Debug for Host {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("Host").field("handler", &self.handler).finish()
93 }
94}
95
96impl Host {
97 pub fn bind<A: ToSocketAddrs>(addresses: A) -> Result<Self> {
99 Self::bind_with_config(addresses, Config::default())
100 }
101
102 pub fn bind_any() -> Result<Self> {
104 Self::bind_any_with_config(Config::default())
105 }
106
107 pub fn bind_any_with_config(config: Config) -> Result<Self> {
109 let loopback = Ipv4Addr::new(127, 0, 0, 1);
110 let address = SocketAddrV4::new(loopback, 0);
111 let socket = UdpSocket::bind(address)?;
112 Self::bind_with_config_and_clock(socket, config, Arc::new(SystemClock))
113 }
114
115 pub fn bind_with_config<A: ToSocketAddrs>(addresses: A, config: Config) -> Result<Self> {
117 let socket = UdpSocket::bind(addresses)?;
118 Self::bind_with_config_and_clock(socket, config, Arc::new(SystemClock))
119 }
120
121 pub fn bind_with_config_and_clock(
123 socket: UdpSocket,
124 config: Config,
125 clock: Arc<dyn Clock>,
126 ) -> Result<Self> {
127 Self::bind_with_config_clock_and_interceptor(socket, config, clock, None)
128 }
129
130 pub fn bind_with_config_clock_and_interceptor(
132 socket: UdpSocket,
133 config: Config,
134 clock: Arc<dyn Clock>,
135 interceptor: Option<Box<dyn Interceptor>>,
136 ) -> Result<Self> {
137 apply_socket_options(&socket, &config)?;
139
140 Ok(Host {
141 handler: SessionManager::new_with_interceptor(
142 SocketWithConditioner::new(socket, config.blocking_mode)?,
143 config,
144 interceptor,
145 ),
146 clock,
147 })
148 }
149
150 pub fn bind_with_interceptor<A: ToSocketAddrs>(
184 addresses: A,
185 config: Config,
186 interceptor: Box<dyn Interceptor>,
187 ) -> Result<Self> {
188 let socket = UdpSocket::bind(addresses)?;
189 Self::bind_with_config_clock_and_interceptor(
190 socket,
191 config,
192 Arc::new(SystemClock),
193 Some(interceptor),
194 )
195 }
196 pub fn get_packet_sender(&self) -> Sender<Packet> {
198 self.handler.event_sender().clone()
199 }
200
201 pub fn get_event_receiver(&self) -> Receiver<SocketEvent> {
203 self.handler.event_receiver().clone()
204 }
205
206 pub fn send(&mut self, packet: Packet) -> Result<()> {
208 self.handler.event_sender().send(packet).expect("Receiver must exists.");
209 Ok(())
210 }
211
212 pub fn recv(&mut self) -> Option<SocketEvent> {
214 match self.handler.event_receiver().try_recv() {
215 Ok(pkt) => Some(pkt),
216 Err(TryRecvError::Empty) => None,
217 Err(TryRecvError::Disconnected) => panic!["This can never happen"],
218 }
219 }
220
221 pub fn start_polling(&mut self) {
223 self.start_polling_with_duration(Some(Duration::from_millis(1)))
224 }
225
226 pub fn start_polling_with_duration(&mut self, sleep_duration: Option<Duration>) {
228 loop {
229 self.manual_poll(self.clock.now());
230 match sleep_duration {
231 None => yield_now(),
232 Some(duration) => sleep(duration),
233 };
234 }
235 }
236
237 pub fn manual_poll(&mut self, time: Instant) {
239 self.handler.manual_poll(time);
240 }
241
242 pub fn local_addr(&self) -> Result<SocketAddr> {
244 Ok(self.handler.socket().local_addr()?)
245 }
246
247 pub fn disconnect(&mut self, addr: SocketAddr) -> Result<()> {
249 if let Some(session) = self.handler.session_mut(&addr) {
250 session.disconnect();
251 }
252 Ok(())
253 }
254
255 pub fn broadcast(
269 &mut self,
270 channel_id: u8,
271 data: Vec<u8>,
272 delivery: DeliveryGuarantee,
273 ordering: OrderingGuarantee,
274 ) -> Result<usize> {
275 let addresses: Vec<SocketAddr> = self.handler.established_sessions().copied().collect();
276 let count = addresses.len();
277
278 let shared = std::sync::Arc::<[u8]>::from(data);
280
281 for addr in addresses {
282 let packet = Packet::new(addr, shared.clone(), delivery, ordering, channel_id);
283 self.send(packet)?;
284 }
285
286 Ok(count)
287 }
288
289 pub fn broadcast_reliable(&mut self, channel_id: u8, data: Vec<u8>) -> Result<usize> {
293 self.broadcast(
294 channel_id,
295 data,
296 DeliveryGuarantee::Reliable,
297 OrderingGuarantee::Ordered(None),
298 )
299 }
300
301 pub fn broadcast_unreliable(&mut self, channel_id: u8, data: Vec<u8>) -> Result<usize> {
305 self.broadcast(channel_id, data, DeliveryGuarantee::Unreliable, OrderingGuarantee::None)
306 }
307
308 pub fn established_connections_count(&self) -> usize {
310 self.handler.established_sessions_count()
311 }
312}
313
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use bitfold_core::interceptor::Interceptor;

    use super::*;

    /// Poll budget for tests waiting for something to *happen*. Each step sleeps about
    /// 1 ms, so the full budget is only spent when the expectation is failing.
    const MAX_POLL_STEPS: u64 = 200;

    /// Poll budget for tests asserting that something does *not* happen. It is long
    /// enough for a loopback datagram to arrive, so the negative assertion means something.
    const QUIET_POLL_STEPS: u64 = 20;

    /// `Config::default()` with `blocking_mode` off, so the socket is put in
    /// non-blocking mode (see `SocketWithConditioner::new`).
    fn nonblocking_config() -> Config {
        let mut config = Config::default();
        config.blocking_mode = false;
        config
    }

    /// Builds a channel-0 packet to `addr` carrying `bytes`, with no ordering guarantee.
    fn packet_to(addr: SocketAddr, bytes: &[u8], delivery: DeliveryGuarantee) -> Packet {
        Packet::new(addr, Arc::<[u8]>::from(bytes), delivery, OrderingGuarantee::None, 0)
    }

    /// Polls `host` at `start + 1 ms`, `start + 2 ms`, … (sleeping ~1 ms of real time
    /// between polls) until `done(host)` returns `true` or `max_steps` polls have run.
    /// Returns whether `done` was satisfied. Replaces fixed sleeps that made tests flaky.
    fn poll_until(
        host: &mut Host,
        start: Instant,
        max_steps: u64,
        mut done: impl FnMut(&mut Host) -> bool,
    ) -> bool {
        for step in 1..=max_steps {
            host.manual_poll(start + Duration::from_millis(step));
            if done(&mut *host) {
                return true;
            }
            std::thread::sleep(Duration::from_millis(1));
        }
        false
    }

    #[test]
    fn test_broadcast_to_no_connections() {
        let mut host = Host::bind_any().unwrap();

        let count = host.broadcast_reliable(0, vec![1, 2, 3]).unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_broadcast_sends_to_all_established_connections() {
        let config = nonblocking_config();

        let mut server = Host::bind_any_with_config(config.clone()).unwrap();
        let server_addr = server.local_addr().unwrap();

        let mut clients: Vec<Host> = (1..=3u8)
            .map(|id| {
                let mut client = Host::bind_any_with_config(config.clone()).unwrap();
                client.send(packet_to(server_addr, &[id], DeliveryGuarantee::Reliable)).unwrap();
                client
            })
            .collect();

        let now = Instant::now();
        for client in &mut clients {
            client.manual_poll(now);
        }
        server.manual_poll(now);

        let established = poll_until(&mut server, now, MAX_POLL_STEPS, |h| {
            h.established_connections_count() > 0
        });
        assert!(established, "Server should have established connections");

        let established_count = server.established_connections_count();
        let count = server.broadcast_reliable(0, vec![10, 20, 30]).unwrap();
        assert_eq!(count, established_count);
    }

    #[test]
    fn test_broadcast_reliable_convenience() {
        let mut host = Host::bind_any().unwrap();

        let count = host.broadcast_reliable(0, vec![1, 2, 3]).unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_broadcast_unreliable_convenience() {
        let mut host = Host::bind_any().unwrap();

        let count = host.broadcast_unreliable(0, vec![1, 2, 3]).unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_broadcast_with_different_delivery_guarantees() {
        let mut host = Host::bind_any().unwrap();

        let count1 = host
            .broadcast(
                0,
                vec![1, 2, 3],
                DeliveryGuarantee::Reliable,
                OrderingGuarantee::Ordered(None),
            )
            .unwrap();
        assert_eq!(count1, 0);

        let count2 = host
            .broadcast(1, vec![4, 5, 6], DeliveryGuarantee::Unreliable, OrderingGuarantee::None)
            .unwrap();
        assert_eq!(count2, 0);
    }

    #[test]
    fn test_established_connections_count() {
        let host = Host::bind_any().unwrap();

        assert_eq!(host.established_connections_count(), 0);
    }

    #[test]
    fn test_duplicate_peers_unlimited_by_default() {
        let config = nonblocking_config();
        let mut server = Host::bind_any_with_config(config.clone()).unwrap();
        let server_addr = server.local_addr().unwrap();

        let mut clients: Vec<Host> = (0..5)
            .map(|_| {
                let mut client = Host::bind_any_with_config(config.clone()).unwrap();
                client.send(packet_to(server_addr, &[1], DeliveryGuarantee::Reliable)).unwrap();
                client
            })
            .collect();

        let now = Instant::now();
        for client in &mut clients {
            client.manual_poll(now);
        }
        server.manual_poll(now);

        let has_sessions =
            poll_until(&mut server, now, MAX_POLL_STEPS, |h| h.handler.sessions_count() > 0);
        assert!(has_sessions, "Server should track sessions for the connecting clients");
    }

    #[test]
    fn test_duplicate_peer_tracking() {
        let mut config = nonblocking_config();
        config.max_duplicate_peers = 3;

        let server = Host::bind_any_with_config(config).unwrap();

        for port in [5001u16, 5002, 5003] {
            let addr = SocketAddr::from(([127, 0, 0, 1], port));
            assert_eq!(server.handler.duplicate_peer_count(&addr), 0);
        }
    }

    #[test]
    fn test_config_default_max_duplicate_peers() {
        let config = Config::default();
        assert_eq!(config.max_duplicate_peers, 0);
    }

    #[test]
    fn test_config_custom_max_duplicate_peers() {
        let mut config = Config::default();
        config.max_duplicate_peers = 5;
        assert_eq!(config.max_duplicate_peers, 5);
    }

    #[test]
    fn test_socket_options_default() {
        let config = Config::default();
        assert_eq!(config.socket_recv_buffer_size, None);
        assert_eq!(config.socket_send_buffer_size, None);
        assert_eq!(config.socket_ttl, None);
        assert!(!config.socket_broadcast);
    }

    #[test]
    fn test_socket_options_custom() {
        let mut config = Config::default();
        config.socket_recv_buffer_size = Some(65536);
        config.socket_send_buffer_size = Some(32768);
        config.socket_ttl = Some(64);
        config.socket_broadcast = true;

        assert_eq!(config.socket_recv_buffer_size, Some(65536));
        assert_eq!(config.socket_send_buffer_size, Some(32768));
        assert_eq!(config.socket_ttl, Some(64));
        assert!(config.socket_broadcast);
    }

    #[test]
    fn test_socket_options_applied() {
        let mut config = nonblocking_config();
        config.socket_recv_buffer_size = Some(131072);
        config.socket_send_buffer_size = Some(65536);
        config.socket_ttl = Some(128);

        let host = Host::bind_any_with_config(config);
        assert!(host.is_ok(), "Host creation with socket options should succeed");
    }

    #[test]
    fn test_socket_broadcast_option() {
        let mut config = nonblocking_config();
        config.socket_broadcast = true;

        let host = Host::bind_any_with_config(config);
        assert!(host.is_ok(), "Host creation with broadcast option should succeed");
    }

    #[test]
    fn test_socket_options_none_uses_defaults() {
        let mut config = nonblocking_config();
        config.socket_recv_buffer_size = None;
        config.socket_send_buffer_size = None;
        config.socket_ttl = None;
        config.socket_broadcast = false;

        let host = Host::bind_any_with_config(config);
        assert!(host.is_ok(), "Host creation with default socket options should succeed");
    }

    /// Interceptor that passes everything through while counting calls.
    #[derive(Clone)]
    struct CountingInterceptor {
        received: Arc<Mutex<usize>>,
        sent: Arc<Mutex<usize>>,
    }

    impl CountingInterceptor {
        fn new() -> Self {
            Self { received: Arc::new(Mutex::new(0)), sent: Arc::new(Mutex::new(0)) }
        }

        fn received_count(&self) -> usize {
            *self.received.lock().unwrap()
        }
    }

    impl Interceptor for CountingInterceptor {
        fn on_receive(&mut self, _addr: &SocketAddr, _data: &mut [u8]) -> bool {
            *self.received.lock().unwrap() += 1;
            true
        }

        fn on_send(&mut self, _addr: &SocketAddr, _data: &mut Vec<u8>) -> bool {
            *self.sent.lock().unwrap() += 1;
            true
        }
    }

    #[test]
    fn test_interceptor_creation() {
        let config = Config::default();
        let interceptor = Box::new(CountingInterceptor::new());

        let host = Host::bind_with_interceptor("127.0.0.1:0", config, interceptor);
        assert!(host.is_ok(), "Should create host with interceptor");
    }

    #[test]
    fn test_interceptor_counts_packets() {
        let config = nonblocking_config();

        let counter = CountingInterceptor::new();
        let probe = counter.clone();

        let mut server =
            Host::bind_with_interceptor("127.0.0.1:0", config.clone(), Box::new(counter)).unwrap();
        let server_addr = server.local_addr().unwrap();

        let mut client = Host::bind_any_with_config(config).unwrap();
        client.send(packet_to(server_addr, &[1, 2, 3], DeliveryGuarantee::Unreliable)).unwrap();

        let now = Instant::now();
        client.manual_poll(now);

        let seen = poll_until(&mut server, now, MAX_POLL_STEPS, |_| probe.received_count() > 0);
        assert!(seen, "Interceptor should count received packets");
    }

    /// Interceptor that rejects every incoming and outgoing datagram.
    struct DroppingInterceptor;

    impl Interceptor for DroppingInterceptor {
        fn on_receive(&mut self, _addr: &SocketAddr, _data: &mut [u8]) -> bool {
            false
        }

        fn on_send(&mut self, _addr: &SocketAddr, _data: &mut Vec<u8>) -> bool {
            false
        }
    }

    #[test]
    fn test_interceptor_can_drop_packets() {
        let config = nonblocking_config();

        let mut server = Host::bind_with_interceptor(
            "127.0.0.1:0",
            config.clone(),
            Box::new(DroppingInterceptor),
        )
        .unwrap();
        let server_addr = server.local_addr().unwrap();

        let mut client = Host::bind_any_with_config(config).unwrap();
        client.send(packet_to(server_addr, &[1, 2, 3], DeliveryGuarantee::Unreliable)).unwrap();

        let now = Instant::now();
        client.manual_poll(now);

        // Poll long enough for the datagram to have reached the server; a single
        // immediate poll would pass even if nothing had arrived yet.
        let got_event = poll_until(&mut server, now, QUIET_POLL_STEPS, |h| h.recv().is_some());
        assert!(!got_event, "Interceptor should have dropped the packet");
    }

    /// Symmetric XOR "cipher": applying it twice restores the original bytes.
    struct XorInterceptor;

    impl Interceptor for XorInterceptor {
        fn on_receive(&mut self, _addr: &SocketAddr, data: &mut [u8]) -> bool {
            for byte in data.iter_mut() {
                *byte ^= 0x55;
            }
            true
        }

        fn on_send(&mut self, _addr: &SocketAddr, data: &mut Vec<u8>) -> bool {
            for byte in data.iter_mut() {
                *byte ^= 0x55;
            }
            true
        }
    }

    #[test]
    fn test_interceptor_can_modify_packets() {
        let config = nonblocking_config();

        let mut server =
            Host::bind_with_interceptor("127.0.0.1:0", config.clone(), Box::new(XorInterceptor))
                .unwrap();
        let server_addr = server.local_addr().unwrap();

        let mut client =
            Host::bind_with_interceptor("127.0.0.1:0", config, Box::new(XorInterceptor)).unwrap();
        client
            .send(packet_to(server_addr, &[0xAA, 0xBB, 0xCC], DeliveryGuarantee::Unreliable))
            .unwrap();

        let now = Instant::now();
        client.manual_poll(now);

        // The wire bytes are XOR-scrambled by the client; only if the server's
        // interceptor restores them can the packet be parsed and surface as an event.
        let delivered = poll_until(&mut server, now, MAX_POLL_STEPS, |h| h.recv().is_some());
        assert!(delivered, "XOR-encoded packet should be decoded by the receiving interceptor");
    }
}