1use crate::config::SubscriptionConfig;
2use crate::error::McrxError;
3use crate::packet::{Packet, PacketWithMetadata};
4use crate::platform::{
5 ReceiveSocket, join_multicast_group, leave_multicast_group, open_bound_socket,
6 prepare_existing_socket,
7};
8use crate::subscription::{Subscription, SubscriptionId};
9use socket2::Socket;
10
11#[cfg(feature = "metrics")]
12use crate::metrics::ContextMetricsSnapshot;
13#[cfg(feature = "metrics")]
14use std::sync::atomic::{AtomicU64, Ordering};
15#[cfg(feature = "metrics")]
16use std::time::SystemTime;
17
/// Atomic counters that `Context` bumps while it manages and polls
/// subscriptions; read back as a whole by `Context::metrics_snapshot`.
///
/// `Context` only ever increments these (via `fetch_add` with relaxed
/// ordering), so each field is a running total since the context was built.
#[cfg(feature = "metrics")]
#[derive(Default, Debug)]
struct ContextMetricsInner {
    /// Subscriptions inserted into the context.
    subscriptions_added: AtomicU64,
    /// Subscriptions removed or taken out of the context.
    subscriptions_removed: AtomicU64,
    /// Packets returned by any receive method.
    total_packets_received: AtomicU64,
    /// Sum of payload lengths of the packets counted above.
    total_bytes_received: AtomicU64,
    /// Polls of a joined subscription that found nothing ready.
    total_would_block_count: AtomicU64,
    /// Receive attempts that returned an error.
    total_receive_errors: AtomicU64,
    /// Successful multicast joins.
    total_join_count: AtomicU64,
    /// Successful multicast leaves.
    total_leave_count: AtomicU64,
    /// Calls to the public batch / drain receive methods.
    batch_calls: AtomicU64,
    /// Packets appended by batch / drain receive passes.
    batch_packets_received: AtomicU64,
}
32
33#[derive(Debug, Default)]
35pub struct Context {
36 subscriptions: Vec<Subscription>,
37 next_subscription_id: u64,
38 next_recv_index: usize,
39 #[cfg(feature = "metrics")]
40 metrics: ContextMetricsInner,
41}
42
43impl Context {
44 #[cfg(feature = "metrics")]
45 fn record_received_packet(&self, payload_len: usize) {
46 self.metrics
47 .total_packets_received
48 .fetch_add(1, Ordering::Relaxed);
49 self.metrics
50 .total_bytes_received
51 .fetch_add(payload_len as u64, Ordering::Relaxed);
52 }
53
54 #[cfg(feature = "metrics")]
55 fn record_would_block(&self) {
56 self.metrics
57 .total_would_block_count
58 .fetch_add(1, Ordering::Relaxed);
59 }
60
61 #[cfg(feature = "metrics")]
62 fn record_receive_error(&self) {
63 self.metrics
64 .total_receive_errors
65 .fetch_add(1, Ordering::Relaxed);
66 }
67
68 #[cfg(feature = "metrics")]
69 fn record_batch_call(&self) {
70 self.metrics.batch_calls.fetch_add(1, Ordering::Relaxed);
71 }
72
73 #[cfg(feature = "metrics")]
74 fn record_batch_packets_received(&self, packet_count: usize) {
75 self.metrics
76 .batch_packets_received
77 .fetch_add(packet_count as u64, Ordering::Relaxed);
78 }
79
80 fn ensure_subscription_config_is_unique(
81 &self,
82 config: &SubscriptionConfig,
83 ) -> Result<(), McrxError> {
84 if self
85 .subscriptions
86 .iter()
87 .any(|subscription| subscription.config() == config)
88 {
89 return Err(McrxError::DuplicateSubscription);
90 }
91
92 Ok(())
93 }
94
95 fn insert_subscription(
96 &mut self,
97 config: SubscriptionConfig,
98 socket: ReceiveSocket,
99 ) -> SubscriptionId {
100 let id = SubscriptionId(self.next_subscription_id);
101 self.next_subscription_id += 1;
102
103 let subscription = Subscription::from_receive_socket(id, config, socket);
104 self.subscriptions.push(subscription);
105
106 #[cfg(feature = "metrics")]
107 self.metrics
108 .subscriptions_added
109 .fetch_add(1, Ordering::Relaxed);
110
111 id
112 }
113
114 fn finish_subscription_removal(&mut self, index: usize) -> Subscription {
115 let removed = self.subscriptions.swap_remove(index);
116
117 if self.subscriptions.is_empty() {
118 self.next_recv_index = 0;
119 } else {
120 self.next_recv_index %= self.subscriptions.len();
121 }
122
123 #[cfg(feature = "metrics")]
124 self.metrics
125 .subscriptions_removed
126 .fetch_add(1, Ordering::Relaxed);
127
128 removed
129 }
130
131 pub fn new() -> Self {
133 Self {
134 subscriptions: Vec::new(),
135 next_subscription_id: 1,
136 next_recv_index: 0,
137 #[cfg(feature = "metrics")]
138 metrics: ContextMetricsInner::default(),
139 }
140 }
141
142 pub fn subscription_count(&self) -> usize {
144 self.subscriptions.len()
145 }
146
147 #[cfg(feature = "metrics")]
153 pub fn metrics_snapshot(&self) -> ContextMetricsSnapshot {
154 let mut joined_subscriptions = 0usize;
155
156 for subscription in &self.subscriptions {
157 if subscription.is_joined() {
158 joined_subscriptions += 1;
159 }
160 }
161
162 ContextMetricsSnapshot {
163 subscriptions_added: self.metrics.subscriptions_added.load(Ordering::Relaxed),
164 subscriptions_removed: self.metrics.subscriptions_removed.load(Ordering::Relaxed),
165 active_subscriptions: self.subscriptions.len(),
166 joined_subscriptions,
167 total_packets_received: self.metrics.total_packets_received.load(Ordering::Relaxed),
168 total_bytes_received: self.metrics.total_bytes_received.load(Ordering::Relaxed),
169 total_would_block_count: self.metrics.total_would_block_count.load(Ordering::Relaxed),
170 total_receive_errors: self.metrics.total_receive_errors.load(Ordering::Relaxed),
171 total_join_count: self.metrics.total_join_count.load(Ordering::Relaxed),
172 total_leave_count: self.metrics.total_leave_count.load(Ordering::Relaxed),
173 batch_calls: self.metrics.batch_calls.load(Ordering::Relaxed),
174 batch_packets_received: self.metrics.batch_packets_received.load(Ordering::Relaxed),
175 captured_at: SystemTime::now(),
176 }
177 }
178
179 pub fn contains_subscription(&self, id: SubscriptionId) -> bool {
181 self.subscriptions
182 .iter()
183 .any(|subscription| subscription.id() == id)
184 }
185
186 pub fn get_subscription(&self, id: SubscriptionId) -> Option<&Subscription> {
188 self.subscriptions
189 .iter()
190 .find(|subscription| subscription.id() == id)
191 }
192
193 pub fn get_subscription_mut(&mut self, id: SubscriptionId) -> Option<&mut Subscription> {
195 self.subscriptions
196 .iter_mut()
197 .find(|subscription| subscription.id() == id)
198 }
199
200 pub fn add_subscription(
208 &mut self,
209 config: SubscriptionConfig,
210 ) -> Result<SubscriptionId, McrxError> {
211 config.validate()?;
212 self.ensure_subscription_config_is_unique(&config)?;
213
214 let socket = open_bound_socket(&config)?;
215 Ok(self.insert_subscription(config, socket))
216 }
217
218 pub fn add_subscription_with_socket(
228 &mut self,
229 config: SubscriptionConfig,
230 socket: Socket,
231 ) -> Result<SubscriptionId, McrxError> {
232 config.validate()?;
233 self.ensure_subscription_config_is_unique(&config)?;
234
235 let socket = prepare_existing_socket(socket, &config)?;
236
237 Ok(self.insert_subscription(config, socket))
238 }
239
240 pub fn remove_subscription(&mut self, id: SubscriptionId) -> bool {
247 self.take_subscription(id).is_some()
248 }
249
250 pub fn take_subscription(&mut self, id: SubscriptionId) -> Option<Subscription> {
257 let index = self
258 .subscriptions
259 .iter()
260 .position(|subscription| subscription.id() == id)?;
261
262 Some(self.finish_subscription_removal(index))
263 }
264
265 pub fn join_subscription(&mut self, id: SubscriptionId) -> Result<(), McrxError> {
267 let subscription = self
268 .get_subscription_mut(id)
269 .ok_or(McrxError::SubscriptionNotFound)?;
270
271 if subscription.is_joined() {
272 return Err(McrxError::SubscriptionAlreadyJoined);
273 }
274
275 join_multicast_group(subscription.socket(), subscription.config())?;
276 subscription.mark_joined()?;
277
278 #[cfg(feature = "metrics")]
279 self.metrics
280 .total_join_count
281 .fetch_add(1, Ordering::Relaxed);
282
283 Ok(())
284 }
285
286 fn try_recv_from_joined<T>(
287 &mut self,
288 mut recv: impl FnMut(&Subscription) -> Result<Option<T>, McrxError>,
289 _packet_len: impl Fn(&T) -> usize,
290 ) -> Result<Option<T>, McrxError> {
291 let subscription_count = self.subscriptions.len();
292
293 if subscription_count == 0 {
294 return Ok(None);
295 }
296
297 for offset in 0..subscription_count {
298 let index = (self.next_recv_index + offset) % subscription_count;
299 let subscription = &self.subscriptions[index];
300
301 if !subscription.is_joined() {
302 continue;
303 }
304
305 match recv(subscription) {
306 Ok(Some(packet)) => {
307 self.next_recv_index = (index + 1) % subscription_count;
308
309 #[cfg(feature = "metrics")]
310 self.record_received_packet(_packet_len(&packet));
311
312 return Ok(Some(packet));
313 }
314 Ok(None) => {
315 #[cfg(feature = "metrics")]
316 self.record_would_block();
317 }
318 Err(err) => {
319 #[cfg(feature = "metrics")]
320 self.record_receive_error();
321
322 return Err(err);
323 }
324 }
325 }
326
327 Ok(None)
328 }
329
330 pub fn leave_subscription(&mut self, id: SubscriptionId) -> Result<(), McrxError> {
332 let subscription = self
333 .get_subscription_mut(id)
334 .ok_or(McrxError::SubscriptionNotFound)?;
335
336 if !subscription.is_joined() {
337 return Err(McrxError::SubscriptionNotJoined);
338 }
339
340 leave_multicast_group(subscription.socket(), subscription.config())?;
341 subscription.mark_bound()?;
342
343 #[cfg(feature = "metrics")]
344 self.metrics
345 .total_leave_count
346 .fetch_add(1, Ordering::Relaxed);
347
348 Ok(())
349 }
350
351 pub fn subscriptions(&self) -> &[Subscription] {
353 &self.subscriptions
354 }
355
356 pub fn subscriptions_mut(&mut self) -> &mut [Subscription] {
358 &mut self.subscriptions
359 }
360
361 pub fn try_recv_any(&mut self) -> Result<Option<Packet>, McrxError> {
369 self.try_recv_from_joined(Subscription::try_recv, |packet| packet.payload.len())
370 }
371
372 pub fn try_recv_any_with_metadata(&mut self) -> Result<Option<PacketWithMetadata>, McrxError> {
377 self.try_recv_from_joined(Subscription::try_recv_with_metadata, |packet| {
378 packet.packet.payload.len()
379 })
380 }
381
382 pub fn try_recv_batch_into(
394 &mut self,
395 out: &mut Vec<Packet>,
396 max_packets: usize,
397 ) -> Result<usize, McrxError> {
398 #[cfg(feature = "metrics")]
399 self.record_batch_call();
400
401 self.try_recv_batch_into_impl(out, max_packets)
402 }
403
404 fn try_recv_batch_into_impl(
405 &mut self,
406 out: &mut Vec<Packet>,
407 max_packets: usize,
408 ) -> Result<usize, McrxError> {
409 let mut received = 0;
410
411 for _ in 0..max_packets {
412 match self.try_recv_any()? {
413 Some(packet) => {
414 out.push(packet);
415 received += 1;
416 }
417 None => break,
418 }
419 }
420
421 #[cfg(feature = "metrics")]
422 self.record_batch_packets_received(received);
423
424 Ok(received)
425 }
426
427 pub fn try_recv_batch_with_metadata_into(
430 &mut self,
431 out: &mut Vec<PacketWithMetadata>,
432 max_packets: usize,
433 ) -> Result<usize, McrxError> {
434 #[cfg(feature = "metrics")]
435 self.record_batch_call();
436
437 self.try_recv_batch_with_metadata_into_impl(out, max_packets)
438 }
439
440 fn try_recv_batch_with_metadata_into_impl(
441 &mut self,
442 out: &mut Vec<PacketWithMetadata>,
443 max_packets: usize,
444 ) -> Result<usize, McrxError> {
445 let mut received = 0;
446
447 for _ in 0..max_packets {
448 match self.try_recv_any_with_metadata()? {
449 Some(packet) => {
450 out.push(packet);
451 received += 1;
452 }
453 None => break,
454 }
455 }
456
457 #[cfg(feature = "metrics")]
458 self.record_batch_packets_received(received);
459
460 Ok(received)
461 }
462
463 pub fn try_recv_all_into(&mut self, out: &mut Vec<Packet>) -> Result<usize, McrxError> {
472 let mut total_received = 0;
473
474 #[cfg(feature = "metrics")]
475 self.record_batch_call();
476
477 loop {
478 let received = self.try_recv_batch_into_impl(out, usize::MAX)?;
479 total_received += received;
480
481 if received == 0 {
482 break;
483 }
484 }
485
486 Ok(total_received)
487 }
488
489 pub fn try_recv_all_with_metadata_into(
492 &mut self,
493 out: &mut Vec<PacketWithMetadata>,
494 ) -> Result<usize, McrxError> {
495 let mut total_received = 0;
496
497 #[cfg(feature = "metrics")]
498 self.record_batch_call();
499
500 loop {
501 let received = self.try_recv_batch_with_metadata_into_impl(out, usize::MAX)?;
502 total_received += received;
503
504 if received == 0 {
505 break;
506 }
507 }
508
509 Ok(total_received)
510 }
511}
512
513#[cfg(test)]
514mod tests {
515 use super::*;
516 use crate::config::SourceFilter;
517 use crate::subscription::SubscriptionState;
518 use socket2::{Domain, Protocol, SockAddr, Socket, Type};
519 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
520 use std::thread;
521 use std::time::{Duration, Instant};
522
523 use crate::test_support::{
524 ipv6_group_socket_addr, make_multicast_sender, make_multicast_sender_v6, recv_next_packet,
525 sample_config_on_unused_port, sample_config_v6_on_unused_port,
526 };
527
528 fn make_bound_external_socket(port: u16) -> Socket {
529 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap();
530 socket.set_reuse_address(true).unwrap();
531 socket
532 .bind(&SockAddr::from(SocketAddrV4::new(
533 Ipv4Addr::UNSPECIFIED,
534 port,
535 )))
536 .unwrap();
537 socket
538 }
539
540 fn ipv4_group(config: &SubscriptionConfig) -> Ipv4Addr {
541 config.ipv4_membership().unwrap().group
542 }
543
544 fn ipv4_group_socket_addr(config: &SubscriptionConfig) -> SocketAddrV4 {
545 SocketAddrV4::new(ipv4_group(config), config.dst_port)
546 }
547
548 fn assert_pktinfo_metadata(
549 packet: &PacketWithMetadata,
550 expected_destination: std::net::IpAddr,
551 ) {
552 #[cfg(any(
553 target_os = "linux",
554 target_os = "android",
555 windows,
556 target_vendor = "apple",
557 target_os = "freebsd",
558 target_os = "dragonfly",
559 target_os = "netbsd",
560 target_os = "openbsd"
561 ))]
562 {
563 assert_eq!(
564 packet.metadata.destination_local_ip,
565 Some(expected_destination)
566 );
567 assert!(packet.metadata.ingress_interface_index.is_some());
568 }
569
570 #[cfg(not(any(
571 target_os = "linux",
572 target_os = "android",
573 windows,
574 target_vendor = "apple",
575 target_os = "freebsd",
576 target_os = "dragonfly",
577 target_os = "netbsd",
578 target_os = "openbsd"
579 )))]
580 {
581 let _ = expected_destination;
582 assert_eq!(packet.metadata.destination_local_ip, None);
583 assert_eq!(packet.metadata.ingress_interface_index, None);
584 }
585 }
586
587 fn send_round_robin_test_packets(
588 first_config: &SubscriptionConfig,
589 second_config: &SubscriptionConfig,
590 ) {
591 let sender = make_multicast_sender();
592
593 sender
594 .send_to(b"first-1", ipv4_group_socket_addr(first_config))
595 .unwrap();
596 sender
597 .send_to(b"second-1", ipv4_group_socket_addr(second_config))
598 .unwrap();
599 sender
600 .send_to(b"first-2", ipv4_group_socket_addr(first_config))
601 .unwrap();
602 }
603
604 #[test]
605 fn new_context_starts_empty() {
606 let context = Context::new();
607
608 assert_eq!(context.subscription_count(), 0);
609 assert!(context.subscriptions().is_empty());
610 }
611
612 #[test]
613 fn add_subscription_returns_id_and_increases_count() {
614 let mut context = Context::new();
615
616 let id = context
617 .add_subscription(sample_config_on_unused_port())
618 .unwrap();
619
620 assert_eq!(id, SubscriptionId(1));
621 assert_eq!(context.subscription_count(), 1);
622 assert_eq!(context.subscriptions()[0].id(), id);
623 }
624
625 #[test]
626 fn adding_two_subscriptions_generates_different_ids() {
627 let mut context = Context::new();
628
629 let first = context
630 .add_subscription(sample_config_on_unused_port())
631 .unwrap();
632 let second = context
633 .add_subscription(sample_config_on_unused_port())
634 .unwrap();
635
636 assert_ne!(first, second);
637 assert_eq!(first, SubscriptionId(1));
638 assert_eq!(second, SubscriptionId(2));
639 }
640
641 #[test]
642 fn invalid_subscription_is_rejected() {
643 let mut context = Context::new();
644
645 let invalid_config = SubscriptionConfig {
646 group: std::net::IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)),
647 source: SourceFilter::Any,
648 dst_port: 5000,
649 interface: None,
650 interface_index: None,
651 };
652
653 let result = context.add_subscription(invalid_config);
654
655 assert!(matches!(result, Err(McrxError::InvalidMulticastGroup)));
656 assert_eq!(context.subscription_count(), 0);
657 }
658
659 #[test]
660 fn try_recv_any_returns_packet_from_ready_ipv6_subscription() {
661 let mut context = Context::new();
662 let config = sample_config_v6_on_unused_port();
663 let id = context.add_subscription(config.clone()).unwrap();
664 context.join_subscription(id).unwrap();
665
666 let sender = make_multicast_sender_v6(std::net::Ipv6Addr::LOCALHOST);
667 let payload = b"context try_recv_any ipv6 asm";
668 sender
669 .send_to(payload, ipv6_group_socket_addr(&config))
670 .unwrap();
671
672 let deadline = Instant::now() + Duration::from_secs(1);
673 let packet = recv_next_packet(&mut context, deadline);
674
675 assert_eq!(packet.subscription_id, id);
676 assert_eq!(packet.group, config.group);
677 assert_eq!(packet.dst_port, config.dst_port);
678 assert_eq!(&packet.payload[..], payload);
679 }
680
681 #[test]
682 fn try_recv_any_with_metadata_returns_packet_from_ready_ipv6_subscription() {
683 let mut context = Context::new();
684 let config = sample_config_v6_on_unused_port();
685 let id = context.add_subscription(config.clone()).unwrap();
686 context.join_subscription(id).unwrap();
687
688 let sender = make_multicast_sender_v6(std::net::Ipv6Addr::LOCALHOST);
689
690 let payload = b"context try_recv_any_with_metadata ipv6 asm";
691 sender
692 .send_to(payload, ipv6_group_socket_addr(&config))
693 .unwrap();
694
695 let deadline = Instant::now() + Duration::from_secs(1);
696 let packet = loop {
697 match context.try_recv_any_with_metadata().unwrap() {
698 Some(packet) => break packet,
699 None if Instant::now() < deadline => {
700 thread::sleep(Duration::from_millis(10));
701 }
702 None => panic!("timed out waiting for IPv6 packet with metadata from context"),
703 }
704 };
705
706 assert_eq!(packet.packet.subscription_id, id);
707 assert_eq!(packet.packet.group, config.group);
708 assert_eq!(packet.packet.dst_port, config.dst_port);
709 assert_eq!(&packet.packet.payload[..], payload);
710 assert_pktinfo_metadata(&packet, config.group);
711 assert_eq!(packet.metadata.configured_interface, config.interface);
712 assert_eq!(
713 packet.metadata.configured_interface_index,
714 config.interface_index
715 );
716 assert_eq!(
717 packet.metadata.socket_local_addr,
718 Some(std::net::SocketAddr::V6(SocketAddrV6::new(
719 Ipv6Addr::UNSPECIFIED,
720 config.dst_port,
721 0,
722 0,
723 )))
724 );
725 }
726
727 #[test]
728 fn remove_existing_subscription_returns_true() {
729 let mut context = Context::new();
730
731 let id = context
732 .add_subscription(sample_config_on_unused_port())
733 .unwrap();
734
735 let removed = context.remove_subscription(id);
736
737 assert!(removed);
738 assert_eq!(context.subscription_count(), 0);
739 }
740
741 #[test]
742 fn remove_missing_subscription_returns_false() {
743 let mut context = Context::new();
744
745 let removed = context.remove_subscription(SubscriptionId(999));
746
747 assert!(!removed);
748 }
749
750 #[test]
751 fn take_subscription_returns_owned_subscription_and_removes_it() {
752 let mut context = Context::new();
753 let config = sample_config_on_unused_port();
754 let id = context.add_subscription(config.clone()).unwrap();
755 context.join_subscription(id).unwrap();
756
757 let subscription = context.take_subscription(id).unwrap();
758
759 assert_eq!(subscription.id(), id);
760 assert_eq!(subscription.config(), &config);
761 assert_eq!(subscription.state(), SubscriptionState::Joined);
762 assert_eq!(context.subscription_count(), 0);
763 assert!(!context.contains_subscription(id));
764 }
765
766 #[test]
767 fn taken_subscription_can_be_split_into_owned_parts() {
768 let mut context = Context::new();
769 let config = sample_config_on_unused_port();
770 let id = context.add_subscription(config.clone()).unwrap();
771 context.join_subscription(id).unwrap();
772
773 let subscription = context.take_subscription(id).unwrap();
774 let parts = subscription.into_parts();
775
776 assert_eq!(parts.id, id);
777 assert_eq!(parts.config, config);
778 assert_eq!(parts.state, SubscriptionState::Joined);
779
780 let local_addr = parts.socket.local_addr().unwrap().as_socket().unwrap();
781 assert_eq!(local_addr.port(), parts.config.dst_port);
782 }
783
784 #[test]
785 fn taken_joined_subscription_can_still_receive_packets() {
786 let mut context = Context::new();
787 let config = sample_config_on_unused_port();
788 let id = context.add_subscription(config.clone()).unwrap();
789 context.join_subscription(id).unwrap();
790
791 let subscription = context.take_subscription(id).unwrap();
792 let sender = make_multicast_sender();
793 let payload = b"context take_subscription handoff";
794
795 sender
796 .send_to(payload, ipv4_group_socket_addr(&config))
797 .unwrap();
798
799 let deadline = Instant::now() + Duration::from_secs(1);
800 let packet = loop {
801 match subscription.try_recv().unwrap() {
802 Some(packet) => break packet,
803 None if Instant::now() < deadline => thread::sleep(Duration::from_millis(10)),
804 None => panic!("timed out waiting for packet after subscription handoff"),
805 }
806 };
807
808 assert_eq!(packet.subscription_id, id);
809 assert_eq!(packet.group, std::net::IpAddr::V4(ipv4_group(&config)));
810 assert_eq!(packet.dst_port, config.dst_port);
811 assert_eq!(&packet.payload[..], payload);
812 }
813
814 #[test]
815 fn three_subscriptions_have_len_3() {
816 let mut context = Context::new();
817
818 context
819 .add_subscription(sample_config_on_unused_port())
820 .unwrap();
821 context
822 .add_subscription(sample_config_on_unused_port())
823 .unwrap();
824 context
825 .add_subscription(sample_config_on_unused_port())
826 .unwrap();
827
828 assert_eq!(context.subscription_count(), 3);
829 }
830
831 #[test]
832 fn duplicate_subscription_is_rejected() {
833 let mut context = Context::new();
834 let config = sample_config_on_unused_port();
835
836 let first = context.add_subscription(config.clone());
837 let second = context.add_subscription(config);
838
839 assert!(first.is_ok());
840 assert!(matches!(second, Err(McrxError::DuplicateSubscription)));
841 assert_eq!(context.subscription_count(), 1);
842 }
843
844 #[test]
845 fn add_subscription_with_socket_returns_id_and_increases_count() {
846 let mut context = Context::new();
847 let config = sample_config_on_unused_port();
848 let socket = make_bound_external_socket(config.dst_port);
849
850 let id = context
851 .add_subscription_with_socket(config, socket)
852 .unwrap();
853
854 assert_eq!(id, SubscriptionId(1));
855 assert_eq!(context.subscription_count(), 1);
856 assert_eq!(context.subscriptions()[0].id(), id);
857 }
858
859 #[test]
860 fn add_subscription_with_socket_rejects_port_mismatch() {
861 let mut context = Context::new();
862 let config = sample_config_on_unused_port();
863 let expected_port = config.dst_port;
864 let socket = make_bound_external_socket(0);
865
866 let result = context.add_subscription_with_socket(config, socket);
867
868 assert!(matches!(
869 result,
870 Err(McrxError::ExistingSocketPortMismatch { expected, .. })
871 if expected == expected_port
872 ));
873 }
874
875 #[test]
876 fn contains_subscription_returns_true_for_existing_id() {
877 let mut context = Context::new();
878 let id = context
879 .add_subscription(sample_config_on_unused_port())
880 .unwrap();
881
882 assert!(context.contains_subscription(id));
883 }
884
885 #[test]
886 fn contains_subscription_returns_false_for_missing_id() {
887 let context = Context::new();
888
889 assert!(!context.contains_subscription(SubscriptionId(999)));
890 }
891
892 #[test]
893 fn get_subscription_returns_matching_subscription() {
894 let mut context = Context::new();
895 let id = context
896 .add_subscription(sample_config_on_unused_port())
897 .unwrap();
898
899 let subscription = context.get_subscription(id);
900
901 assert!(subscription.is_some());
902 assert_eq!(subscription.unwrap().id(), id);
903 }
904
905 #[test]
906 fn get_subscription_returns_none_for_missing_id() {
907 let context = Context::new();
908
909 let subscription = context.get_subscription(SubscriptionId(999));
910
911 assert!(subscription.is_none());
912 }
913
914 #[test]
915 fn get_subscription_mut_returns_matching_subscription() {
916 let mut context = Context::new();
917 let id = context
918 .add_subscription(sample_config_on_unused_port())
919 .unwrap();
920
921 let subscription = context.get_subscription_mut(id);
922
923 assert!(subscription.is_some());
924 assert_eq!(subscription.unwrap().id(), id);
925 }
926
927 #[test]
928 fn get_subscription_mut_returns_none_for_missing_id() {
929 let mut context = Context::new();
930
931 let subscription = context.get_subscription_mut(SubscriptionId(999));
932
933 assert!(subscription.is_none());
934 }
935
936 #[test]
937 fn try_recv_any_returns_none_when_no_packet_is_available() {
938 let mut context = Context::new();
939 let id = context
940 .add_subscription(sample_config_on_unused_port())
941 .unwrap();
942 context.join_subscription(id).unwrap();
943
944 let result = context.try_recv_any().unwrap();
945
946 assert!(result.is_none());
947 }
948
949 #[test]
950 fn try_recv_any_returns_packet_from_ready_subscription() {
951 let mut context = Context::new();
952 let config = sample_config_on_unused_port();
953 let id = context.add_subscription(config.clone()).unwrap();
954 context.join_subscription(id).unwrap();
955
956 let sender = make_multicast_sender();
957
958 let payload = b"context try_recv_any";
959 sender
960 .send_to(payload, ipv4_group_socket_addr(&config))
961 .unwrap();
962
963 let deadline = Instant::now() + Duration::from_secs(1);
964 let packet = recv_next_packet(&mut context, deadline);
965
966 assert_eq!(packet.group, std::net::IpAddr::V4(ipv4_group(&config)));
967 assert_eq!(packet.dst_port, config.dst_port);
968 assert_eq!(&packet.payload[..], payload);
969 }
970
971 #[test]
972 fn try_recv_any_with_metadata_returns_packet_from_ready_subscription() {
973 let mut context = Context::new();
974 let config = sample_config_on_unused_port();
975 let id = context.add_subscription(config.clone()).unwrap();
976 context.join_subscription(id).unwrap();
977
978 let sender = make_multicast_sender();
979
980 let payload = b"context try_recv_any_with_metadata";
981 sender
982 .send_to(payload, ipv4_group_socket_addr(&config))
983 .unwrap();
984
985 let deadline = Instant::now() + Duration::from_secs(1);
986 let packet = loop {
987 match context.try_recv_any_with_metadata().unwrap() {
988 Some(packet) => break packet,
989 None if Instant::now() < deadline => {
990 thread::sleep(Duration::from_millis(10));
991 }
992 None => panic!("timed out waiting for packet with metadata from context"),
993 }
994 };
995
996 assert_eq!(packet.packet.subscription_id, id);
997 assert_eq!(
998 packet.packet.group,
999 std::net::IpAddr::V4(ipv4_group(&config))
1000 );
1001 assert_eq!(packet.packet.dst_port, config.dst_port);
1002 assert_eq!(&packet.packet.payload[..], payload);
1003 assert_pktinfo_metadata(&packet, std::net::IpAddr::V4(ipv4_group(&config)));
1004 assert_eq!(
1005 packet.metadata.socket_local_addr,
1006 Some(std::net::SocketAddr::V4(SocketAddrV4::new(
1007 Ipv4Addr::UNSPECIFIED,
1008 config.dst_port,
1009 )))
1010 );
1011 }
1012
1013 #[test]
1014 fn try_recv_any_works_with_caller_provided_socket() {
1015 let mut context = Context::new();
1016 let config = sample_config_on_unused_port();
1017 let socket = make_bound_external_socket(config.dst_port);
1018 let id = context
1019 .add_subscription_with_socket(config.clone(), socket)
1020 .unwrap();
1021 context.join_subscription(id).unwrap();
1022
1023 let sender = make_multicast_sender();
1024
1025 let payload = b"context add_subscription_with_socket";
1026 sender
1027 .send_to(payload, ipv4_group_socket_addr(&config))
1028 .unwrap();
1029
1030 let deadline = Instant::now() + Duration::from_secs(1);
1031 let packet = recv_next_packet(&mut context, deadline);
1032
1033 assert_eq!(packet.subscription_id, id);
1034 assert_eq!(packet.group, std::net::IpAddr::V4(ipv4_group(&config)));
1035 assert_eq!(packet.dst_port, config.dst_port);
1036 assert_eq!(&packet.payload[..], payload);
1037 }
1038
1039 #[test]
1040 fn try_recv_any_round_robins_between_ready_subscriptions() {
1041 let mut context = Context::new();
1042 let first_config = sample_config_on_unused_port();
1043 let second_config = sample_config_on_unused_port();
1044
1045 let first_id = context.add_subscription(first_config.clone()).unwrap();
1046 context.join_subscription(first_id).unwrap();
1047 let second_id = context.add_subscription(second_config.clone()).unwrap();
1048 context.join_subscription(second_id).unwrap();
1049
1050 send_round_robin_test_packets(&first_config, &second_config);
1051
1052 let deadline = Instant::now() + Duration::from_secs(1);
1053
1054 let first_packet = recv_next_packet(&mut context, deadline);
1055 let second_packet = recv_next_packet(&mut context, deadline);
1056 let third_packet = recv_next_packet(&mut context, deadline);
1057
1058 assert_eq!(
1059 (first_packet.subscription_id, &first_packet.payload[..]),
1060 (first_id, b"first-1".as_slice())
1061 );
1062
1063 assert_eq!(
1064 (second_packet.subscription_id, &second_packet.payload[..]),
1065 (second_id, b"second-1".as_slice())
1066 );
1067
1068 assert_eq!(
1069 (third_packet.subscription_id, &third_packet.payload[..]),
1070 (first_id, b"first-2".as_slice())
1071 );
1072 }
1073
1074 #[test]
1075 fn try_recv_batch_into_returns_zero_when_no_packet_is_available() {
1076 let mut context = Context::new();
1077 let id = context
1078 .add_subscription(sample_config_on_unused_port())
1079 .unwrap();
1080 context.join_subscription(id).unwrap();
1081
1082 let mut packets = Vec::new();
1083 let received = context.try_recv_batch_into(&mut packets, 8).unwrap();
1084
1085 assert_eq!(received, 0);
1086 assert!(packets.is_empty());
1087 }
1088
1089 #[test]
1090 fn try_recv_batch_into_receives_up_to_max_packets() {
1091 let mut context = Context::new();
1092 let first_config = sample_config_on_unused_port();
1093 let second_config = sample_config_on_unused_port();
1094
1095 let first_id = context.add_subscription(first_config.clone()).unwrap();
1096 context.join_subscription(first_id).unwrap();
1097 let second_id = context.add_subscription(second_config.clone()).unwrap();
1098 context.join_subscription(second_id).unwrap();
1099
1100 send_round_robin_test_packets(&first_config, &second_config);
1101
1102 let deadline = Instant::now() + Duration::from_secs(1);
1103 let mut packets = Vec::new();
1104
1105 while packets.len() < 2 && Instant::now() < deadline {
1106 let before = packets.len();
1107 let received = context.try_recv_batch_into(&mut packets, 2).unwrap();
1108 assert!(received <= 2);
1109 assert_eq!(packets.len(), before + received);
1110
1111 if packets.len() < 2 {
1112 thread::sleep(Duration::from_millis(10));
1113 }
1114 }
1115
1116 assert!(packets.len() >= 2);
1117
1118 assert_eq!(
1119 (packets[0].subscription_id, &packets[0].payload[..]),
1120 (first_id, b"first-1".as_slice())
1121 );
1122
1123 assert_eq!(
1124 (packets[1].subscription_id, &packets[1].payload[..]),
1125 (second_id, b"second-1".as_slice())
1126 );
1127 }
1128
1129 #[test]
1130 fn try_recv_batch_with_metadata_into_receives_up_to_max_packets() {
1131 let mut context = Context::new();
1132 let first_config = sample_config_on_unused_port();
1133 let second_config = sample_config_on_unused_port();
1134
1135 let first_id = context.add_subscription(first_config.clone()).unwrap();
1136 context.join_subscription(first_id).unwrap();
1137 let second_id = context.add_subscription(second_config.clone()).unwrap();
1138 context.join_subscription(second_id).unwrap();
1139
1140 send_round_robin_test_packets(&first_config, &second_config);
1141
1142 let deadline = Instant::now() + Duration::from_secs(1);
1143 let mut packets = Vec::new();
1144
1145 while packets.len() < 2 && Instant::now() < deadline {
1146 let before = packets.len();
1147 let received = context
1148 .try_recv_batch_with_metadata_into(&mut packets, 2)
1149 .unwrap();
1150 assert!(received <= 2);
1151 assert_eq!(packets.len(), before + received);
1152
1153 if packets.len() < 2 {
1154 thread::sleep(Duration::from_millis(10));
1155 }
1156 }
1157
1158 assert!(packets.len() >= 2);
1159
1160 assert_eq!(
1161 (
1162 packets[0].packet.subscription_id,
1163 &packets[0].packet.payload[..]
1164 ),
1165 (first_id, b"first-1".as_slice())
1166 );
1167
1168 assert_eq!(
1169 (
1170 packets[1].packet.subscription_id,
1171 &packets[1].packet.payload[..]
1172 ),
1173 (second_id, b"second-1".as_slice())
1174 );
1175
1176 assert_eq!(
1177 packets[0].metadata.socket_local_addr,
1178 Some(std::net::SocketAddr::V4(SocketAddrV4::new(
1179 Ipv4Addr::UNSPECIFIED,
1180 first_config.dst_port,
1181 )))
1182 );
1183 assert_eq!(
1184 packets[1].metadata.socket_local_addr,
1185 Some(std::net::SocketAddr::V4(SocketAddrV4::new(
1186 Ipv4Addr::UNSPECIFIED,
1187 second_config.dst_port,
1188 )))
1189 );
1190 }
1191
1192 #[test]
1193 fn try_recv_all_into_drains_all_available_packets() {
1194 let mut context = Context::new();
1195 let first_config = sample_config_on_unused_port();
1196 let second_config = sample_config_on_unused_port();
1197
1198 let first_id = context.add_subscription(first_config.clone()).unwrap();
1199 context.join_subscription(first_id).unwrap();
1200 let second_id = context.add_subscription(second_config.clone()).unwrap();
1201 context.join_subscription(second_id).unwrap();
1202
1203 send_round_robin_test_packets(&first_config, &second_config);
1204
1205 let deadline = Instant::now() + Duration::from_secs(1);
1206 let mut packets = Vec::new();
1207
1208 while packets.len() < 3 && Instant::now() < deadline {
1209 context.try_recv_all_into(&mut packets).unwrap();
1210
1211 if packets.len() < 3 {
1212 thread::sleep(Duration::from_millis(10));
1213 }
1214 }
1215
1216 assert_eq!(packets.len(), 3);
1217
1218 assert_eq!(
1219 (packets[0].subscription_id, &packets[0].payload[..]),
1220 (first_id, b"first-1".as_slice())
1221 );
1222
1223 assert_eq!(
1224 (packets[1].subscription_id, &packets[1].payload[..]),
1225 (second_id, b"second-1".as_slice())
1226 );
1227
1228 assert_eq!(
1229 (packets[2].subscription_id, &packets[2].payload[..]),
1230 (first_id, b"first-2".as_slice())
1231 );
1232 }
1233
1234 #[test]
1235 fn try_recv_all_with_metadata_into_drains_all_available_packets() {
1236 let mut context = Context::new();
1237 let first_config = sample_config_on_unused_port();
1238 let second_config = sample_config_on_unused_port();
1239
1240 let first_id = context.add_subscription(first_config.clone()).unwrap();
1241 context.join_subscription(first_id).unwrap();
1242 let second_id = context.add_subscription(second_config.clone()).unwrap();
1243 context.join_subscription(second_id).unwrap();
1244
1245 send_round_robin_test_packets(&first_config, &second_config);
1246
1247 let deadline = Instant::now() + Duration::from_secs(1);
1248 let mut packets = Vec::new();
1249
1250 while packets.len() < 3 && Instant::now() < deadline {
1251 context
1252 .try_recv_all_with_metadata_into(&mut packets)
1253 .unwrap();
1254
1255 if packets.len() < 3 {
1256 thread::sleep(Duration::from_millis(10));
1257 }
1258 }
1259
1260 assert_eq!(packets.len(), 3);
1261
1262 assert_eq!(
1263 (
1264 packets[0].packet.subscription_id,
1265 &packets[0].packet.payload[..]
1266 ),
1267 (first_id, b"first-1".as_slice())
1268 );
1269
1270 assert_eq!(
1271 (
1272 packets[1].packet.subscription_id,
1273 &packets[1].packet.payload[..]
1274 ),
1275 (second_id, b"second-1".as_slice())
1276 );
1277
1278 assert_eq!(
1279 (
1280 packets[2].packet.subscription_id,
1281 &packets[2].packet.payload[..]
1282 ),
1283 (first_id, b"first-2".as_slice())
1284 );
1285 }
1286
1287 #[test]
1288 fn add_subscription_creates_bound_subscription() {
1289 let mut context = Context::new();
1290 let id = context
1291 .add_subscription(sample_config_on_unused_port())
1292 .unwrap();
1293
1294 let subscription = context.get_subscription(id).unwrap();
1295 assert_eq!(subscription.state(), SubscriptionState::Bound);
1296 }
1297
1298 #[test]
1299 fn join_subscription_transitions_bound_to_joined() {
1300 let mut context = Context::new();
1301 let id = context
1302 .add_subscription(sample_config_on_unused_port())
1303 .unwrap();
1304
1305 context.join_subscription(id).unwrap();
1306
1307 let subscription = context.get_subscription(id).unwrap();
1308 assert_eq!(subscription.state(), SubscriptionState::Joined);
1309 }
1310
1311 #[test]
1312 fn leave_subscription_transitions_joined_to_bound() {
1313 let mut context = Context::new();
1314 let id = context
1315 .add_subscription(sample_config_on_unused_port())
1316 .unwrap();
1317
1318 context.join_subscription(id).unwrap();
1319 context.leave_subscription(id).unwrap();
1320
1321 let subscription = context.get_subscription(id).unwrap();
1322 assert_eq!(subscription.state(), SubscriptionState::Bound);
1323 }
1324
1325 #[test]
1326 fn join_subscription_rejects_already_joined_subscription() {
1327 let mut context = Context::new();
1328 let id = context
1329 .add_subscription(sample_config_on_unused_port())
1330 .unwrap();
1331
1332 context.join_subscription(id).unwrap();
1333 let result = context.join_subscription(id);
1334
1335 assert!(matches!(result, Err(McrxError::SubscriptionAlreadyJoined)));
1336 }
1337
1338 #[test]
1339 fn leave_subscription_rejects_not_joined_subscription() {
1340 let mut context = Context::new();
1341 let id = context
1342 .add_subscription(sample_config_on_unused_port())
1343 .unwrap();
1344
1345 let result = context.leave_subscription(id);
1346
1347 assert!(matches!(result, Err(McrxError::SubscriptionNotJoined)));
1348 }
1349}