Skip to main content

mcrx_core/
context.rs

1use crate::config::SubscriptionConfig;
2use crate::error::McrxError;
3use crate::packet::{Packet, PacketWithMetadata};
4use crate::platform::{
5    ReceiveSocket, join_multicast_group, leave_multicast_group, open_bound_socket,
6    prepare_existing_socket,
7};
8use crate::subscription::{Subscription, SubscriptionId};
9use socket2::Socket;
10
11#[cfg(feature = "metrics")]
12use crate::metrics::ContextMetricsSnapshot;
13#[cfg(feature = "metrics")]
14use std::sync::atomic::{AtomicU64, Ordering};
15#[cfg(feature = "metrics")]
16use std::time::SystemTime;
17
/// Cumulative atomic counters backing `Context::metrics_snapshot()`.
///
/// Every counter is only ever incremented (via `fetch_add` with
/// `Ordering::Relaxed`) and is never reset or decremented by `Context`.
#[cfg(feature = "metrics")]
#[derive(Debug, Default)]
struct ContextMetricsInner {
    /// Incremented each time a subscription is inserted into the context.
    subscriptions_added: AtomicU64,
    /// Incremented each time a subscription is removed or taken from the context.
    subscriptions_removed: AtomicU64,
    /// Number of packets returned by the round-robin receive path.
    total_packets_received: AtomicU64,
    /// Sum of the payload lengths of the packets counted in `total_packets_received`.
    total_bytes_received: AtomicU64,
    /// Incremented once per joined subscription polled that had no packet ready.
    total_would_block_count: AtomicU64,
    /// Incremented once per receive call that returned an error.
    total_receive_errors: AtomicU64,
    /// Incremented after each successful `join_subscription()`.
    total_join_count: AtomicU64,
    /// Incremented after each successful `leave_subscription()`.
    total_leave_count: AtomicU64,
    /// Incremented once per call to a batch or drain receive method.
    batch_calls: AtomicU64,
    /// Packets collected by the batch and drain receive methods.
    batch_packets_received: AtomicU64,
}
32
33/// Owns and manages the set of active subscriptions.
34#[derive(Debug, Default)]
35pub struct Context {
36    subscriptions: Vec<Subscription>,
37    next_subscription_id: u64,
38    next_recv_index: usize,
39    #[cfg(feature = "metrics")]
40    metrics: ContextMetricsInner,
41}
42
43impl Context {
44    #[cfg(feature = "metrics")]
45    fn record_received_packet(&self, payload_len: usize) {
46        self.metrics
47            .total_packets_received
48            .fetch_add(1, Ordering::Relaxed);
49        self.metrics
50            .total_bytes_received
51            .fetch_add(payload_len as u64, Ordering::Relaxed);
52    }
53
54    #[cfg(feature = "metrics")]
55    fn record_would_block(&self) {
56        self.metrics
57            .total_would_block_count
58            .fetch_add(1, Ordering::Relaxed);
59    }
60
61    #[cfg(feature = "metrics")]
62    fn record_receive_error(&self) {
63        self.metrics
64            .total_receive_errors
65            .fetch_add(1, Ordering::Relaxed);
66    }
67
68    #[cfg(feature = "metrics")]
69    fn record_batch_call(&self) {
70        self.metrics.batch_calls.fetch_add(1, Ordering::Relaxed);
71    }
72
73    #[cfg(feature = "metrics")]
74    fn record_batch_packets_received(&self, packet_count: usize) {
75        self.metrics
76            .batch_packets_received
77            .fetch_add(packet_count as u64, Ordering::Relaxed);
78    }
79
80    fn ensure_subscription_config_is_unique(
81        &self,
82        config: &SubscriptionConfig,
83    ) -> Result<(), McrxError> {
84        if self
85            .subscriptions
86            .iter()
87            .any(|subscription| subscription.config() == config)
88        {
89            return Err(McrxError::DuplicateSubscription);
90        }
91
92        Ok(())
93    }
94
95    fn insert_subscription(
96        &mut self,
97        config: SubscriptionConfig,
98        socket: ReceiveSocket,
99    ) -> SubscriptionId {
100        let id = SubscriptionId(self.next_subscription_id);
101        self.next_subscription_id += 1;
102
103        let subscription = Subscription::from_receive_socket(id, config, socket);
104        self.subscriptions.push(subscription);
105
106        #[cfg(feature = "metrics")]
107        self.metrics
108            .subscriptions_added
109            .fetch_add(1, Ordering::Relaxed);
110
111        id
112    }
113
114    fn finish_subscription_removal(&mut self, index: usize) -> Subscription {
115        let removed = self.subscriptions.swap_remove(index);
116
117        if self.subscriptions.is_empty() {
118            self.next_recv_index = 0;
119        } else {
120            self.next_recv_index %= self.subscriptions.len();
121        }
122
123        #[cfg(feature = "metrics")]
124        self.metrics
125            .subscriptions_removed
126            .fetch_add(1, Ordering::Relaxed);
127
128        removed
129    }
130
131    /// Creates an empty context with no subscriptions.
132    pub fn new() -> Self {
133        Self {
134            subscriptions: Vec::new(),
135            next_subscription_id: 1,
136            next_recv_index: 0,
137            #[cfg(feature = "metrics")]
138            metrics: ContextMetricsInner::default(),
139        }
140    }
141
142    /// Returns the number of active subscriptions currently stored in the context.
143    pub fn subscription_count(&self) -> usize {
144        self.subscriptions.len()
145    }
146
147    /// Returns a snapshot of the context's current metrics.
148    ///
149    /// Counter fields such as `total_packets_received` are cumulative for the
150    /// lifetime of the context and are not reduced when subscriptions are
151    /// removed.
152    #[cfg(feature = "metrics")]
153    pub fn metrics_snapshot(&self) -> ContextMetricsSnapshot {
154        let mut joined_subscriptions = 0usize;
155
156        for subscription in &self.subscriptions {
157            if subscription.is_joined() {
158                joined_subscriptions += 1;
159            }
160        }
161
162        ContextMetricsSnapshot {
163            subscriptions_added: self.metrics.subscriptions_added.load(Ordering::Relaxed),
164            subscriptions_removed: self.metrics.subscriptions_removed.load(Ordering::Relaxed),
165            active_subscriptions: self.subscriptions.len(),
166            joined_subscriptions,
167            total_packets_received: self.metrics.total_packets_received.load(Ordering::Relaxed),
168            total_bytes_received: self.metrics.total_bytes_received.load(Ordering::Relaxed),
169            total_would_block_count: self.metrics.total_would_block_count.load(Ordering::Relaxed),
170            total_receive_errors: self.metrics.total_receive_errors.load(Ordering::Relaxed),
171            total_join_count: self.metrics.total_join_count.load(Ordering::Relaxed),
172            total_leave_count: self.metrics.total_leave_count.load(Ordering::Relaxed),
173            batch_calls: self.metrics.batch_calls.load(Ordering::Relaxed),
174            batch_packets_received: self.metrics.batch_packets_received.load(Ordering::Relaxed),
175            captured_at: SystemTime::now(),
176        }
177    }
178
179    /// Returns true if a subscription with the given ID exists.
180    pub fn contains_subscription(&self, id: SubscriptionId) -> bool {
181        self.subscriptions
182            .iter()
183            .any(|subscription| subscription.id() == id)
184    }
185
186    /// Returns a read-only reference to the subscription with the given ID, if it exists.
187    pub fn get_subscription(&self, id: SubscriptionId) -> Option<&Subscription> {
188        self.subscriptions
189            .iter()
190            .find(|subscription| subscription.id() == id)
191    }
192
193    /// Returns a mutable reference to the subscription with the given ID, if it exists.
194    pub fn get_subscription_mut(&mut self, id: SubscriptionId) -> Option<&mut Subscription> {
195        self.subscriptions
196            .iter_mut()
197            .find(|subscription| subscription.id() == id)
198    }
199
200    /// Adds a new subscription to the context.
201    ///
202    /// The configuration is validated before insertion. If an identical subscription
203    /// already exists, an error is returned instead of creating a duplicate.
204    ///
205    /// This function creates and binds the underlying socket, but does not join the
206    /// multicast group yet. Call `join_subscription()` to activate multicast reception.
207    pub fn add_subscription(
208        &mut self,
209        config: SubscriptionConfig,
210    ) -> Result<SubscriptionId, McrxError> {
211        config.validate()?;
212        self.ensure_subscription_config_is_unique(&config)?;
213
214        let socket = open_bound_socket(&config)?;
215        Ok(self.insert_subscription(config, socket))
216    }
217
218    /// Adds a new subscription using a caller-provided socket.
219    ///
220    /// The socket must already be bound to the destination port from `config`.
221    /// This method preserves the existing lifecycle model: the context will still
222    /// perform multicast join/leave operations later via `join_subscription()` and
223    /// `leave_subscription()`.
224    ///
225    /// The supplied socket is switched to non-blocking mode so the receive APIs keep
226    /// their usual non-blocking behavior.
227    pub fn add_subscription_with_socket(
228        &mut self,
229        config: SubscriptionConfig,
230        socket: Socket,
231    ) -> Result<SubscriptionId, McrxError> {
232        config.validate()?;
233        self.ensure_subscription_config_is_unique(&config)?;
234
235        let socket = prepare_existing_socket(socket, &config)?;
236
237        Ok(self.insert_subscription(config, socket))
238    }
239
240    /// Removes the subscription with the given ID.
241    ///
242    /// Returns true if a subscription was removed and false if no matching
243    /// subscription was found.
244    ///
245    /// This uses `swap_remove`, so subscription order is not preserved.
246    pub fn remove_subscription(&mut self, id: SubscriptionId) -> bool {
247        self.take_subscription(id).is_some()
248    }
249
250    /// Removes the subscription with the given ID and returns it to the caller.
251    ///
252    /// This preserves the current socket ownership and lifecycle state, which is
253    /// useful when moving a subscription into an external event loop or runtime.
254    ///
255    /// This uses `swap_remove`, so subscription order is not preserved.
256    pub fn take_subscription(&mut self, id: SubscriptionId) -> Option<Subscription> {
257        let index = self
258            .subscriptions
259            .iter()
260            .position(|subscription| subscription.id() == id)?;
261
262        Some(self.finish_subscription_removal(index))
263    }
264
265    /// Joins the multicast group for the given subscription.
266    pub fn join_subscription(&mut self, id: SubscriptionId) -> Result<(), McrxError> {
267        let subscription = self
268            .get_subscription_mut(id)
269            .ok_or(McrxError::SubscriptionNotFound)?;
270
271        if subscription.is_joined() {
272            return Err(McrxError::SubscriptionAlreadyJoined);
273        }
274
275        join_multicast_group(subscription.socket(), subscription.config())?;
276        subscription.mark_joined()?;
277
278        #[cfg(feature = "metrics")]
279        self.metrics
280            .total_join_count
281            .fetch_add(1, Ordering::Relaxed);
282
283        Ok(())
284    }
285
286    fn try_recv_from_joined<T>(
287        &mut self,
288        mut recv: impl FnMut(&Subscription) -> Result<Option<T>, McrxError>,
289        _packet_len: impl Fn(&T) -> usize,
290    ) -> Result<Option<T>, McrxError> {
291        let subscription_count = self.subscriptions.len();
292
293        if subscription_count == 0 {
294            return Ok(None);
295        }
296
297        for offset in 0..subscription_count {
298            let index = (self.next_recv_index + offset) % subscription_count;
299            let subscription = &self.subscriptions[index];
300
301            if !subscription.is_joined() {
302                continue;
303            }
304
305            match recv(subscription) {
306                Ok(Some(packet)) => {
307                    self.next_recv_index = (index + 1) % subscription_count;
308
309                    #[cfg(feature = "metrics")]
310                    self.record_received_packet(_packet_len(&packet));
311
312                    return Ok(Some(packet));
313                }
314                Ok(None) => {
315                    #[cfg(feature = "metrics")]
316                    self.record_would_block();
317                }
318                Err(err) => {
319                    #[cfg(feature = "metrics")]
320                    self.record_receive_error();
321
322                    return Err(err);
323                }
324            }
325        }
326
327        Ok(None)
328    }
329
330    /// Leaves the multicast group for the given subscription while keeping the socket bound.
331    pub fn leave_subscription(&mut self, id: SubscriptionId) -> Result<(), McrxError> {
332        let subscription = self
333            .get_subscription_mut(id)
334            .ok_or(McrxError::SubscriptionNotFound)?;
335
336        if !subscription.is_joined() {
337            return Err(McrxError::SubscriptionNotJoined);
338        }
339
340        leave_multicast_group(subscription.socket(), subscription.config())?;
341        subscription.mark_bound()?;
342
343        #[cfg(feature = "metrics")]
344        self.metrics
345            .total_leave_count
346            .fetch_add(1, Ordering::Relaxed);
347
348        Ok(())
349    }
350
351    /// Returns a read-only slice of all subscriptions currently stored in the context.
352    pub fn subscriptions(&self) -> &[Subscription] {
353        &self.subscriptions
354    }
355
356    /// Returns a mutable slice of all subscriptions currently stored in the context.
357    pub fn subscriptions_mut(&mut self) -> &mut [Subscription] {
358        &mut self.subscriptions
359    }
360
361    /// Attempts to receive a single packet from any joined subscription without blocking.
362    ///
363    /// Subscriptions are scanned using round-robin style fairness so that repeated
364    /// calls do not always favor the first subscription.
365    ///
366    /// Returns the first available packet, if any joined subscription currently has
367    /// one ready to be read.
368    pub fn try_recv_any(&mut self) -> Result<Option<Packet>, McrxError> {
369        self.try_recv_from_joined(Subscription::try_recv, |packet| packet.payload.len())
370    }
371
372    /// Attempts to receive a single packet with richer receive metadata from any
373    /// joined subscription without blocking.
374    ///
375    /// This uses the same round-robin fairness logic as `try_recv_any()`.
376    pub fn try_recv_any_with_metadata(&mut self) -> Result<Option<PacketWithMetadata>, McrxError> {
377        self.try_recv_from_joined(Subscription::try_recv_with_metadata, |packet| {
378            packet.packet.payload.len()
379        })
380    }
381
382    /// Attempts to receive up to `max_packets` packets from any subscriptions without blocking.
383    ///
384    /// This method repeatedly calls `try_recv_any()` using the same round-robin fairness logic
385    /// and pushes received packets into the provided `out` vector.
386    ///
387    /// Returns the number of packets that were added to `out`.
388    ///
389    /// Behavior:
390    /// - Stops early if no more packets are available
391    /// - Does not block or wait for new packets
392    /// - Preserves fairness across subscriptions
393    pub fn try_recv_batch_into(
394        &mut self,
395        out: &mut Vec<Packet>,
396        max_packets: usize,
397    ) -> Result<usize, McrxError> {
398        #[cfg(feature = "metrics")]
399        self.record_batch_call();
400
401        self.try_recv_batch_into_impl(out, max_packets)
402    }
403
404    fn try_recv_batch_into_impl(
405        &mut self,
406        out: &mut Vec<Packet>,
407        max_packets: usize,
408    ) -> Result<usize, McrxError> {
409        let mut received = 0;
410
411        for _ in 0..max_packets {
412            match self.try_recv_any()? {
413                Some(packet) => {
414                    out.push(packet);
415                    received += 1;
416                }
417                None => break,
418            }
419        }
420
421        #[cfg(feature = "metrics")]
422        self.record_batch_packets_received(received);
423
424        Ok(received)
425    }
426
427    /// Attempts to receive up to `max_packets` packets with richer receive
428    /// metadata from any subscriptions without blocking.
429    pub fn try_recv_batch_with_metadata_into(
430        &mut self,
431        out: &mut Vec<PacketWithMetadata>,
432        max_packets: usize,
433    ) -> Result<usize, McrxError> {
434        #[cfg(feature = "metrics")]
435        self.record_batch_call();
436
437        self.try_recv_batch_with_metadata_into_impl(out, max_packets)
438    }
439
440    fn try_recv_batch_with_metadata_into_impl(
441        &mut self,
442        out: &mut Vec<PacketWithMetadata>,
443        max_packets: usize,
444    ) -> Result<usize, McrxError> {
445        let mut received = 0;
446
447        for _ in 0..max_packets {
448            match self.try_recv_any_with_metadata()? {
449                Some(packet) => {
450                    out.push(packet);
451                    received += 1;
452                }
453                None => break,
454            }
455        }
456
457        #[cfg(feature = "metrics")]
458        self.record_batch_packets_received(received);
459
460        Ok(received)
461    }
462
463    /// Attempts to receive all currently available packets without blocking.
464    ///
465    /// This is a convenience wrapper around `try_recv_batch_into` that continues
466    /// draining until no more packets are available.
467    ///
468    /// Note: this only drains packets that are currently available without blocking.
469    /// It may result in unbounded growth of `out` if a large number of packets are queued,
470    /// so callers should ensure capacity if needed.
471    pub fn try_recv_all_into(&mut self, out: &mut Vec<Packet>) -> Result<usize, McrxError> {
472        let mut total_received = 0;
473
474        #[cfg(feature = "metrics")]
475        self.record_batch_call();
476
477        loop {
478            let received = self.try_recv_batch_into_impl(out, usize::MAX)?;
479            total_received += received;
480
481            if received == 0 {
482                break;
483            }
484        }
485
486        Ok(total_received)
487    }
488
489    /// Attempts to receive all currently available packets with richer receive
490    /// metadata without blocking.
491    pub fn try_recv_all_with_metadata_into(
492        &mut self,
493        out: &mut Vec<PacketWithMetadata>,
494    ) -> Result<usize, McrxError> {
495        let mut total_received = 0;
496
497        #[cfg(feature = "metrics")]
498        self.record_batch_call();
499
500        loop {
501            let received = self.try_recv_batch_with_metadata_into_impl(out, usize::MAX)?;
502            total_received += received;
503
504            if received == 0 {
505                break;
506            }
507        }
508
509        Ok(total_received)
510    }
511}
512
513#[cfg(test)]
514mod tests {
515    use super::*;
516    use crate::config::SourceFilter;
517    use crate::subscription::SubscriptionState;
518    use socket2::{Domain, Protocol, SockAddr, Socket, Type};
519    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
520    use std::thread;
521    use std::time::{Duration, Instant};
522
523    use crate::test_support::{
524        ipv6_group_socket_addr, make_multicast_sender, make_multicast_sender_v6, recv_next_packet,
525        sample_config_on_unused_port, sample_config_v6_on_unused_port,
526    };
527
528    fn make_bound_external_socket(port: u16) -> Socket {
529        let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap();
530        socket.set_reuse_address(true).unwrap();
531        socket
532            .bind(&SockAddr::from(SocketAddrV4::new(
533                Ipv4Addr::UNSPECIFIED,
534                port,
535            )))
536            .unwrap();
537        socket
538    }
539
540    fn ipv4_group(config: &SubscriptionConfig) -> Ipv4Addr {
541        config.ipv4_membership().unwrap().group
542    }
543
544    fn ipv4_group_socket_addr(config: &SubscriptionConfig) -> SocketAddrV4 {
545        SocketAddrV4::new(ipv4_group(config), config.dst_port)
546    }
547
548    fn assert_pktinfo_metadata(
549        packet: &PacketWithMetadata,
550        expected_destination: std::net::IpAddr,
551    ) {
552        #[cfg(any(
553            target_os = "linux",
554            target_os = "android",
555            windows,
556            target_vendor = "apple",
557            target_os = "freebsd",
558            target_os = "dragonfly",
559            target_os = "netbsd",
560            target_os = "openbsd"
561        ))]
562        {
563            assert_eq!(
564                packet.metadata.destination_local_ip,
565                Some(expected_destination)
566            );
567            assert!(packet.metadata.ingress_interface_index.is_some());
568        }
569
570        #[cfg(not(any(
571            target_os = "linux",
572            target_os = "android",
573            windows,
574            target_vendor = "apple",
575            target_os = "freebsd",
576            target_os = "dragonfly",
577            target_os = "netbsd",
578            target_os = "openbsd"
579        )))]
580        {
581            let _ = expected_destination;
582            assert_eq!(packet.metadata.destination_local_ip, None);
583            assert_eq!(packet.metadata.ingress_interface_index, None);
584        }
585    }
586
587    fn send_round_robin_test_packets(
588        first_config: &SubscriptionConfig,
589        second_config: &SubscriptionConfig,
590    ) {
591        let sender = make_multicast_sender();
592
593        sender
594            .send_to(b"first-1", ipv4_group_socket_addr(first_config))
595            .unwrap();
596        sender
597            .send_to(b"second-1", ipv4_group_socket_addr(second_config))
598            .unwrap();
599        sender
600            .send_to(b"first-2", ipv4_group_socket_addr(first_config))
601            .unwrap();
602    }
603
604    #[test]
605    fn new_context_starts_empty() {
606        let context = Context::new();
607
608        assert_eq!(context.subscription_count(), 0);
609        assert!(context.subscriptions().is_empty());
610    }
611
612    #[test]
613    fn add_subscription_returns_id_and_increases_count() {
614        let mut context = Context::new();
615
616        let id = context
617            .add_subscription(sample_config_on_unused_port())
618            .unwrap();
619
620        assert_eq!(id, SubscriptionId(1));
621        assert_eq!(context.subscription_count(), 1);
622        assert_eq!(context.subscriptions()[0].id(), id);
623    }
624
625    #[test]
626    fn adding_two_subscriptions_generates_different_ids() {
627        let mut context = Context::new();
628
629        let first = context
630            .add_subscription(sample_config_on_unused_port())
631            .unwrap();
632        let second = context
633            .add_subscription(sample_config_on_unused_port())
634            .unwrap();
635
636        assert_ne!(first, second);
637        assert_eq!(first, SubscriptionId(1));
638        assert_eq!(second, SubscriptionId(2));
639    }
640
641    #[test]
642    fn invalid_subscription_is_rejected() {
643        let mut context = Context::new();
644
645        let invalid_config = SubscriptionConfig {
646            group: std::net::IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)),
647            source: SourceFilter::Any,
648            dst_port: 5000,
649            interface: None,
650            interface_index: None,
651        };
652
653        let result = context.add_subscription(invalid_config);
654
655        assert!(matches!(result, Err(McrxError::InvalidMulticastGroup)));
656        assert_eq!(context.subscription_count(), 0);
657    }
658
659    #[test]
660    fn try_recv_any_returns_packet_from_ready_ipv6_subscription() {
661        let mut context = Context::new();
662        let config = sample_config_v6_on_unused_port();
663        let id = context.add_subscription(config.clone()).unwrap();
664        context.join_subscription(id).unwrap();
665
666        let sender = make_multicast_sender_v6(std::net::Ipv6Addr::LOCALHOST);
667        let payload = b"context try_recv_any ipv6 asm";
668        sender
669            .send_to(payload, ipv6_group_socket_addr(&config))
670            .unwrap();
671
672        let deadline = Instant::now() + Duration::from_secs(1);
673        let packet = recv_next_packet(&mut context, deadline);
674
675        assert_eq!(packet.subscription_id, id);
676        assert_eq!(packet.group, config.group);
677        assert_eq!(packet.dst_port, config.dst_port);
678        assert_eq!(&packet.payload[..], payload);
679    }
680
681    #[test]
682    fn try_recv_any_with_metadata_returns_packet_from_ready_ipv6_subscription() {
683        let mut context = Context::new();
684        let config = sample_config_v6_on_unused_port();
685        let id = context.add_subscription(config.clone()).unwrap();
686        context.join_subscription(id).unwrap();
687
688        let sender = make_multicast_sender_v6(std::net::Ipv6Addr::LOCALHOST);
689
690        let payload = b"context try_recv_any_with_metadata ipv6 asm";
691        sender
692            .send_to(payload, ipv6_group_socket_addr(&config))
693            .unwrap();
694
695        let deadline = Instant::now() + Duration::from_secs(1);
696        let packet = loop {
697            match context.try_recv_any_with_metadata().unwrap() {
698                Some(packet) => break packet,
699                None if Instant::now() < deadline => {
700                    thread::sleep(Duration::from_millis(10));
701                }
702                None => panic!("timed out waiting for IPv6 packet with metadata from context"),
703            }
704        };
705
706        assert_eq!(packet.packet.subscription_id, id);
707        assert_eq!(packet.packet.group, config.group);
708        assert_eq!(packet.packet.dst_port, config.dst_port);
709        assert_eq!(&packet.packet.payload[..], payload);
710        assert_pktinfo_metadata(&packet, config.group);
711        assert_eq!(packet.metadata.configured_interface, config.interface);
712        assert_eq!(
713            packet.metadata.configured_interface_index,
714            config.interface_index
715        );
716        assert_eq!(
717            packet.metadata.socket_local_addr,
718            Some(std::net::SocketAddr::V6(SocketAddrV6::new(
719                Ipv6Addr::UNSPECIFIED,
720                config.dst_port,
721                0,
722                0,
723            )))
724        );
725    }
726
727    #[test]
728    fn remove_existing_subscription_returns_true() {
729        let mut context = Context::new();
730
731        let id = context
732            .add_subscription(sample_config_on_unused_port())
733            .unwrap();
734
735        let removed = context.remove_subscription(id);
736
737        assert!(removed);
738        assert_eq!(context.subscription_count(), 0);
739    }
740
741    #[test]
742    fn remove_missing_subscription_returns_false() {
743        let mut context = Context::new();
744
745        let removed = context.remove_subscription(SubscriptionId(999));
746
747        assert!(!removed);
748    }
749
750    #[test]
751    fn take_subscription_returns_owned_subscription_and_removes_it() {
752        let mut context = Context::new();
753        let config = sample_config_on_unused_port();
754        let id = context.add_subscription(config.clone()).unwrap();
755        context.join_subscription(id).unwrap();
756
757        let subscription = context.take_subscription(id).unwrap();
758
759        assert_eq!(subscription.id(), id);
760        assert_eq!(subscription.config(), &config);
761        assert_eq!(subscription.state(), SubscriptionState::Joined);
762        assert_eq!(context.subscription_count(), 0);
763        assert!(!context.contains_subscription(id));
764    }
765
766    #[test]
767    fn taken_subscription_can_be_split_into_owned_parts() {
768        let mut context = Context::new();
769        let config = sample_config_on_unused_port();
770        let id = context.add_subscription(config.clone()).unwrap();
771        context.join_subscription(id).unwrap();
772
773        let subscription = context.take_subscription(id).unwrap();
774        let parts = subscription.into_parts();
775
776        assert_eq!(parts.id, id);
777        assert_eq!(parts.config, config);
778        assert_eq!(parts.state, SubscriptionState::Joined);
779
780        let local_addr = parts.socket.local_addr().unwrap().as_socket().unwrap();
781        assert_eq!(local_addr.port(), parts.config.dst_port);
782    }
783
784    #[test]
785    fn taken_joined_subscription_can_still_receive_packets() {
786        let mut context = Context::new();
787        let config = sample_config_on_unused_port();
788        let id = context.add_subscription(config.clone()).unwrap();
789        context.join_subscription(id).unwrap();
790
791        let subscription = context.take_subscription(id).unwrap();
792        let sender = make_multicast_sender();
793        let payload = b"context take_subscription handoff";
794
795        sender
796            .send_to(payload, ipv4_group_socket_addr(&config))
797            .unwrap();
798
799        let deadline = Instant::now() + Duration::from_secs(1);
800        let packet = loop {
801            match subscription.try_recv().unwrap() {
802                Some(packet) => break packet,
803                None if Instant::now() < deadline => thread::sleep(Duration::from_millis(10)),
804                None => panic!("timed out waiting for packet after subscription handoff"),
805            }
806        };
807
808        assert_eq!(packet.subscription_id, id);
809        assert_eq!(packet.group, std::net::IpAddr::V4(ipv4_group(&config)));
810        assert_eq!(packet.dst_port, config.dst_port);
811        assert_eq!(&packet.payload[..], payload);
812    }
813
814    #[test]
815    fn three_subscriptions_have_len_3() {
816        let mut context = Context::new();
817
818        context
819            .add_subscription(sample_config_on_unused_port())
820            .unwrap();
821        context
822            .add_subscription(sample_config_on_unused_port())
823            .unwrap();
824        context
825            .add_subscription(sample_config_on_unused_port())
826            .unwrap();
827
828        assert_eq!(context.subscription_count(), 3);
829    }
830
831    #[test]
832    fn duplicate_subscription_is_rejected() {
833        let mut context = Context::new();
834        let config = sample_config_on_unused_port();
835
836        let first = context.add_subscription(config.clone());
837        let second = context.add_subscription(config);
838
839        assert!(first.is_ok());
840        assert!(matches!(second, Err(McrxError::DuplicateSubscription)));
841        assert_eq!(context.subscription_count(), 1);
842    }
843
844    #[test]
845    fn add_subscription_with_socket_returns_id_and_increases_count() {
846        let mut context = Context::new();
847        let config = sample_config_on_unused_port();
848        let socket = make_bound_external_socket(config.dst_port);
849
850        let id = context
851            .add_subscription_with_socket(config, socket)
852            .unwrap();
853
854        assert_eq!(id, SubscriptionId(1));
855        assert_eq!(context.subscription_count(), 1);
856        assert_eq!(context.subscriptions()[0].id(), id);
857    }
858
859    #[test]
860    fn add_subscription_with_socket_rejects_port_mismatch() {
861        let mut context = Context::new();
862        let config = sample_config_on_unused_port();
863        let expected_port = config.dst_port;
864        let socket = make_bound_external_socket(0);
865
866        let result = context.add_subscription_with_socket(config, socket);
867
868        assert!(matches!(
869            result,
870            Err(McrxError::ExistingSocketPortMismatch { expected, .. })
871                if expected == expected_port
872        ));
873    }
874
875    #[test]
876    fn contains_subscription_returns_true_for_existing_id() {
877        let mut context = Context::new();
878        let id = context
879            .add_subscription(sample_config_on_unused_port())
880            .unwrap();
881
882        assert!(context.contains_subscription(id));
883    }
884
885    #[test]
886    fn contains_subscription_returns_false_for_missing_id() {
887        let context = Context::new();
888
889        assert!(!context.contains_subscription(SubscriptionId(999)));
890    }
891
892    #[test]
893    fn get_subscription_returns_matching_subscription() {
894        let mut context = Context::new();
895        let id = context
896            .add_subscription(sample_config_on_unused_port())
897            .unwrap();
898
899        let subscription = context.get_subscription(id);
900
901        assert!(subscription.is_some());
902        assert_eq!(subscription.unwrap().id(), id);
903    }
904
905    #[test]
906    fn get_subscription_returns_none_for_missing_id() {
907        let context = Context::new();
908
909        let subscription = context.get_subscription(SubscriptionId(999));
910
911        assert!(subscription.is_none());
912    }
913
914    #[test]
915    fn get_subscription_mut_returns_matching_subscription() {
916        let mut context = Context::new();
917        let id = context
918            .add_subscription(sample_config_on_unused_port())
919            .unwrap();
920
921        let subscription = context.get_subscription_mut(id);
922
923        assert!(subscription.is_some());
924        assert_eq!(subscription.unwrap().id(), id);
925    }
926
927    #[test]
928    fn get_subscription_mut_returns_none_for_missing_id() {
929        let mut context = Context::new();
930
931        let subscription = context.get_subscription_mut(SubscriptionId(999));
932
933        assert!(subscription.is_none());
934    }
935
936    #[test]
937    fn try_recv_any_returns_none_when_no_packet_is_available() {
938        let mut context = Context::new();
939        let id = context
940            .add_subscription(sample_config_on_unused_port())
941            .unwrap();
942        context.join_subscription(id).unwrap();
943
944        let result = context.try_recv_any().unwrap();
945
946        assert!(result.is_none());
947    }
948
949    #[test]
950    fn try_recv_any_returns_packet_from_ready_subscription() {
951        let mut context = Context::new();
952        let config = sample_config_on_unused_port();
953        let id = context.add_subscription(config.clone()).unwrap();
954        context.join_subscription(id).unwrap();
955
956        let sender = make_multicast_sender();
957
958        let payload = b"context try_recv_any";
959        sender
960            .send_to(payload, ipv4_group_socket_addr(&config))
961            .unwrap();
962
963        let deadline = Instant::now() + Duration::from_secs(1);
964        let packet = recv_next_packet(&mut context, deadline);
965
966        assert_eq!(packet.group, std::net::IpAddr::V4(ipv4_group(&config)));
967        assert_eq!(packet.dst_port, config.dst_port);
968        assert_eq!(&packet.payload[..], payload);
969    }
970
971    #[test]
972    fn try_recv_any_with_metadata_returns_packet_from_ready_subscription() {
973        let mut context = Context::new();
974        let config = sample_config_on_unused_port();
975        let id = context.add_subscription(config.clone()).unwrap();
976        context.join_subscription(id).unwrap();
977
978        let sender = make_multicast_sender();
979
980        let payload = b"context try_recv_any_with_metadata";
981        sender
982            .send_to(payload, ipv4_group_socket_addr(&config))
983            .unwrap();
984
985        let deadline = Instant::now() + Duration::from_secs(1);
986        let packet = loop {
987            match context.try_recv_any_with_metadata().unwrap() {
988                Some(packet) => break packet,
989                None if Instant::now() < deadline => {
990                    thread::sleep(Duration::from_millis(10));
991                }
992                None => panic!("timed out waiting for packet with metadata from context"),
993            }
994        };
995
996        assert_eq!(packet.packet.subscription_id, id);
997        assert_eq!(
998            packet.packet.group,
999            std::net::IpAddr::V4(ipv4_group(&config))
1000        );
1001        assert_eq!(packet.packet.dst_port, config.dst_port);
1002        assert_eq!(&packet.packet.payload[..], payload);
1003        assert_pktinfo_metadata(&packet, std::net::IpAddr::V4(ipv4_group(&config)));
1004        assert_eq!(
1005            packet.metadata.socket_local_addr,
1006            Some(std::net::SocketAddr::V4(SocketAddrV4::new(
1007                Ipv4Addr::UNSPECIFIED,
1008                config.dst_port,
1009            )))
1010        );
1011    }
1012
1013    #[test]
1014    fn try_recv_any_works_with_caller_provided_socket() {
1015        let mut context = Context::new();
1016        let config = sample_config_on_unused_port();
1017        let socket = make_bound_external_socket(config.dst_port);
1018        let id = context
1019            .add_subscription_with_socket(config.clone(), socket)
1020            .unwrap();
1021        context.join_subscription(id).unwrap();
1022
1023        let sender = make_multicast_sender();
1024
1025        let payload = b"context add_subscription_with_socket";
1026        sender
1027            .send_to(payload, ipv4_group_socket_addr(&config))
1028            .unwrap();
1029
1030        let deadline = Instant::now() + Duration::from_secs(1);
1031        let packet = recv_next_packet(&mut context, deadline);
1032
1033        assert_eq!(packet.subscription_id, id);
1034        assert_eq!(packet.group, std::net::IpAddr::V4(ipv4_group(&config)));
1035        assert_eq!(packet.dst_port, config.dst_port);
1036        assert_eq!(&packet.payload[..], payload);
1037    }
1038
1039    #[test]
1040    fn try_recv_any_round_robins_between_ready_subscriptions() {
1041        let mut context = Context::new();
1042        let first_config = sample_config_on_unused_port();
1043        let second_config = sample_config_on_unused_port();
1044
1045        let first_id = context.add_subscription(first_config.clone()).unwrap();
1046        context.join_subscription(first_id).unwrap();
1047        let second_id = context.add_subscription(second_config.clone()).unwrap();
1048        context.join_subscription(second_id).unwrap();
1049
1050        send_round_robin_test_packets(&first_config, &second_config);
1051
1052        let deadline = Instant::now() + Duration::from_secs(1);
1053
1054        let first_packet = recv_next_packet(&mut context, deadline);
1055        let second_packet = recv_next_packet(&mut context, deadline);
1056        let third_packet = recv_next_packet(&mut context, deadline);
1057
1058        assert_eq!(
1059            (first_packet.subscription_id, &first_packet.payload[..]),
1060            (first_id, b"first-1".as_slice())
1061        );
1062
1063        assert_eq!(
1064            (second_packet.subscription_id, &second_packet.payload[..]),
1065            (second_id, b"second-1".as_slice())
1066        );
1067
1068        assert_eq!(
1069            (third_packet.subscription_id, &third_packet.payload[..]),
1070            (first_id, b"first-2".as_slice())
1071        );
1072    }
1073
1074    #[test]
1075    fn try_recv_batch_into_returns_zero_when_no_packet_is_available() {
1076        let mut context = Context::new();
1077        let id = context
1078            .add_subscription(sample_config_on_unused_port())
1079            .unwrap();
1080        context.join_subscription(id).unwrap();
1081
1082        let mut packets = Vec::new();
1083        let received = context.try_recv_batch_into(&mut packets, 8).unwrap();
1084
1085        assert_eq!(received, 0);
1086        assert!(packets.is_empty());
1087    }
1088
1089    #[test]
1090    fn try_recv_batch_into_receives_up_to_max_packets() {
1091        let mut context = Context::new();
1092        let first_config = sample_config_on_unused_port();
1093        let second_config = sample_config_on_unused_port();
1094
1095        let first_id = context.add_subscription(first_config.clone()).unwrap();
1096        context.join_subscription(first_id).unwrap();
1097        let second_id = context.add_subscription(second_config.clone()).unwrap();
1098        context.join_subscription(second_id).unwrap();
1099
1100        send_round_robin_test_packets(&first_config, &second_config);
1101
1102        let deadline = Instant::now() + Duration::from_secs(1);
1103        let mut packets = Vec::new();
1104
1105        while packets.len() < 2 && Instant::now() < deadline {
1106            let before = packets.len();
1107            let received = context.try_recv_batch_into(&mut packets, 2).unwrap();
1108            assert!(received <= 2);
1109            assert_eq!(packets.len(), before + received);
1110
1111            if packets.len() < 2 {
1112                thread::sleep(Duration::from_millis(10));
1113            }
1114        }
1115
1116        assert!(packets.len() >= 2);
1117
1118        assert_eq!(
1119            (packets[0].subscription_id, &packets[0].payload[..]),
1120            (first_id, b"first-1".as_slice())
1121        );
1122
1123        assert_eq!(
1124            (packets[1].subscription_id, &packets[1].payload[..]),
1125            (second_id, b"second-1".as_slice())
1126        );
1127    }
1128
1129    #[test]
1130    fn try_recv_batch_with_metadata_into_receives_up_to_max_packets() {
1131        let mut context = Context::new();
1132        let first_config = sample_config_on_unused_port();
1133        let second_config = sample_config_on_unused_port();
1134
1135        let first_id = context.add_subscription(first_config.clone()).unwrap();
1136        context.join_subscription(first_id).unwrap();
1137        let second_id = context.add_subscription(second_config.clone()).unwrap();
1138        context.join_subscription(second_id).unwrap();
1139
1140        send_round_robin_test_packets(&first_config, &second_config);
1141
1142        let deadline = Instant::now() + Duration::from_secs(1);
1143        let mut packets = Vec::new();
1144
1145        while packets.len() < 2 && Instant::now() < deadline {
1146            let before = packets.len();
1147            let received = context
1148                .try_recv_batch_with_metadata_into(&mut packets, 2)
1149                .unwrap();
1150            assert!(received <= 2);
1151            assert_eq!(packets.len(), before + received);
1152
1153            if packets.len() < 2 {
1154                thread::sleep(Duration::from_millis(10));
1155            }
1156        }
1157
1158        assert!(packets.len() >= 2);
1159
1160        assert_eq!(
1161            (
1162                packets[0].packet.subscription_id,
1163                &packets[0].packet.payload[..]
1164            ),
1165            (first_id, b"first-1".as_slice())
1166        );
1167
1168        assert_eq!(
1169            (
1170                packets[1].packet.subscription_id,
1171                &packets[1].packet.payload[..]
1172            ),
1173            (second_id, b"second-1".as_slice())
1174        );
1175
1176        assert_eq!(
1177            packets[0].metadata.socket_local_addr,
1178            Some(std::net::SocketAddr::V4(SocketAddrV4::new(
1179                Ipv4Addr::UNSPECIFIED,
1180                first_config.dst_port,
1181            )))
1182        );
1183        assert_eq!(
1184            packets[1].metadata.socket_local_addr,
1185            Some(std::net::SocketAddr::V4(SocketAddrV4::new(
1186                Ipv4Addr::UNSPECIFIED,
1187                second_config.dst_port,
1188            )))
1189        );
1190    }
1191
1192    #[test]
1193    fn try_recv_all_into_drains_all_available_packets() {
1194        let mut context = Context::new();
1195        let first_config = sample_config_on_unused_port();
1196        let second_config = sample_config_on_unused_port();
1197
1198        let first_id = context.add_subscription(first_config.clone()).unwrap();
1199        context.join_subscription(first_id).unwrap();
1200        let second_id = context.add_subscription(second_config.clone()).unwrap();
1201        context.join_subscription(second_id).unwrap();
1202
1203        send_round_robin_test_packets(&first_config, &second_config);
1204
1205        let deadline = Instant::now() + Duration::from_secs(1);
1206        let mut packets = Vec::new();
1207
1208        while packets.len() < 3 && Instant::now() < deadline {
1209            context.try_recv_all_into(&mut packets).unwrap();
1210
1211            if packets.len() < 3 {
1212                thread::sleep(Duration::from_millis(10));
1213            }
1214        }
1215
1216        assert_eq!(packets.len(), 3);
1217
1218        assert_eq!(
1219            (packets[0].subscription_id, &packets[0].payload[..]),
1220            (first_id, b"first-1".as_slice())
1221        );
1222
1223        assert_eq!(
1224            (packets[1].subscription_id, &packets[1].payload[..]),
1225            (second_id, b"second-1".as_slice())
1226        );
1227
1228        assert_eq!(
1229            (packets[2].subscription_id, &packets[2].payload[..]),
1230            (first_id, b"first-2".as_slice())
1231        );
1232    }
1233
1234    #[test]
1235    fn try_recv_all_with_metadata_into_drains_all_available_packets() {
1236        let mut context = Context::new();
1237        let first_config = sample_config_on_unused_port();
1238        let second_config = sample_config_on_unused_port();
1239
1240        let first_id = context.add_subscription(first_config.clone()).unwrap();
1241        context.join_subscription(first_id).unwrap();
1242        let second_id = context.add_subscription(second_config.clone()).unwrap();
1243        context.join_subscription(second_id).unwrap();
1244
1245        send_round_robin_test_packets(&first_config, &second_config);
1246
1247        let deadline = Instant::now() + Duration::from_secs(1);
1248        let mut packets = Vec::new();
1249
1250        while packets.len() < 3 && Instant::now() < deadline {
1251            context
1252                .try_recv_all_with_metadata_into(&mut packets)
1253                .unwrap();
1254
1255            if packets.len() < 3 {
1256                thread::sleep(Duration::from_millis(10));
1257            }
1258        }
1259
1260        assert_eq!(packets.len(), 3);
1261
1262        assert_eq!(
1263            (
1264                packets[0].packet.subscription_id,
1265                &packets[0].packet.payload[..]
1266            ),
1267            (first_id, b"first-1".as_slice())
1268        );
1269
1270        assert_eq!(
1271            (
1272                packets[1].packet.subscription_id,
1273                &packets[1].packet.payload[..]
1274            ),
1275            (second_id, b"second-1".as_slice())
1276        );
1277
1278        assert_eq!(
1279            (
1280                packets[2].packet.subscription_id,
1281                &packets[2].packet.payload[..]
1282            ),
1283            (first_id, b"first-2".as_slice())
1284        );
1285    }
1286
1287    #[test]
1288    fn add_subscription_creates_bound_subscription() {
1289        let mut context = Context::new();
1290        let id = context
1291            .add_subscription(sample_config_on_unused_port())
1292            .unwrap();
1293
1294        let subscription = context.get_subscription(id).unwrap();
1295        assert_eq!(subscription.state(), SubscriptionState::Bound);
1296    }
1297
1298    #[test]
1299    fn join_subscription_transitions_bound_to_joined() {
1300        let mut context = Context::new();
1301        let id = context
1302            .add_subscription(sample_config_on_unused_port())
1303            .unwrap();
1304
1305        context.join_subscription(id).unwrap();
1306
1307        let subscription = context.get_subscription(id).unwrap();
1308        assert_eq!(subscription.state(), SubscriptionState::Joined);
1309    }
1310
1311    #[test]
1312    fn leave_subscription_transitions_joined_to_bound() {
1313        let mut context = Context::new();
1314        let id = context
1315            .add_subscription(sample_config_on_unused_port())
1316            .unwrap();
1317
1318        context.join_subscription(id).unwrap();
1319        context.leave_subscription(id).unwrap();
1320
1321        let subscription = context.get_subscription(id).unwrap();
1322        assert_eq!(subscription.state(), SubscriptionState::Bound);
1323    }
1324
1325    #[test]
1326    fn join_subscription_rejects_already_joined_subscription() {
1327        let mut context = Context::new();
1328        let id = context
1329            .add_subscription(sample_config_on_unused_port())
1330            .unwrap();
1331
1332        context.join_subscription(id).unwrap();
1333        let result = context.join_subscription(id);
1334
1335        assert!(matches!(result, Err(McrxError::SubscriptionAlreadyJoined)));
1336    }
1337
1338    #[test]
1339    fn leave_subscription_rejects_not_joined_subscription() {
1340        let mut context = Context::new();
1341        let id = context
1342            .add_subscription(sample_config_on_unused_port())
1343            .unwrap();
1344
1345        let result = context.leave_subscription(id);
1346
1347        assert!(matches!(result, Err(McrxError::SubscriptionNotJoined)));
1348    }
1349}