Skip to main content

mctx_core/
context.rs

1#[cfg(feature = "metrics")]
2use crate::metrics::ContextMetricsSnapshot;
3use crate::{MctxError, Publication, PublicationConfig, PublicationId, SendReport};
4use socket2::Socket;
5#[cfg(feature = "metrics")]
6use std::cell::Cell;
7use std::net::UdpSocket;
8#[cfg(feature = "metrics")]
9use std::time::SystemTime;
10
/// Lifetime counters backing `Context::metrics_snapshot`.
///
/// Each counter lives in a `Cell` so it can be updated from `&self` methods
/// such as `Context::send` and `Context::send_all`.
///
/// NOTE(review): `Cell` is `!Sync`, so enabling the `metrics` feature makes
/// `Context` `!Sync` as well, while the feature being off does not add that
/// restriction. If `Context` must be shared across threads, consider
/// `AtomicU64` counters instead — confirm against downstream usage.
#[cfg(feature = "metrics")]
#[derive(Debug, Default)]
struct ContextMetricsInner {
    /// Publications stored through the `Context::add_publication*` methods.
    publications_added: Cell<u64>,
    /// Publications removed through `take_publication` / `remove_publication`.
    publications_removed: Cell<u64>,
    /// Send attempts issued through `Context`, successful or not.
    total_send_calls: Cell<u64>,
    /// Successful sends; each success counts as one packet.
    total_packets_sent: Cell<u64>,
    /// Sum of `SendReport::bytes_sent` over successful sends.
    total_bytes_sent: Cell<u64>,
    /// Failed send attempts.
    total_send_errors: Cell<u64>,
}
21
/// Small owner for a set of multicast publication sockets.
///
/// Each publication is identified by a `PublicationId` allocated from an
/// internal counter, and no two tracked publications may share an equal
/// `PublicationConfig`.
#[derive(Debug)]
pub struct Context {
    /// Tracked publications. Removal uses `swap_remove`, so order is not stable.
    publications: Vec<Publication>,
    /// Next value handed out as a `PublicationId`; starts at 1 and only increases.
    next_id: u64,
    /// Lifetime counters exposed through `metrics_snapshot`.
    #[cfg(feature = "metrics")]
    metrics: ContextMetricsInner,
}
30
31impl Default for Context {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl Context {
38    #[cfg(feature = "metrics")]
39    fn record_send_success(&self, bytes_sent: usize) {
40        self.metrics
41            .total_send_calls
42            .set(self.metrics.total_send_calls.get() + 1);
43        self.metrics
44            .total_packets_sent
45            .set(self.metrics.total_packets_sent.get() + 1);
46        self.metrics
47            .total_bytes_sent
48            .set(self.metrics.total_bytes_sent.get() + bytes_sent as u64);
49    }
50
51    #[cfg(feature = "metrics")]
52    fn record_send_error(&self) {
53        self.metrics
54            .total_send_calls
55            .set(self.metrics.total_send_calls.get() + 1);
56        self.metrics
57            .total_send_errors
58            .set(self.metrics.total_send_errors.get() + 1);
59    }
60
61    fn ensure_publication_config_is_unique(
62        &self,
63        config: &PublicationConfig,
64    ) -> Result<(), MctxError> {
65        if self
66            .publications
67            .iter()
68            .any(|publication| publication.config() == config)
69        {
70            return Err(MctxError::DuplicatePublication);
71        }
72
73        Ok(())
74    }
75
76    fn insert_publication(&mut self, publication: Publication) -> PublicationId {
77        let id = publication.id();
78        self.publications.push(publication);
79
80        #[cfg(feature = "metrics")]
81        self.metrics
82            .publications_added
83            .set(self.metrics.publications_added.get() + 1);
84
85        id
86    }
87
88    fn finish_publication_removal(&mut self, index: usize) -> Publication {
89        let publication = self.publications.swap_remove(index);
90
91        #[cfg(feature = "metrics")]
92        self.metrics
93            .publications_removed
94            .set(self.metrics.publications_removed.get() + 1);
95
96        publication
97    }
98
99    /// Creates an empty multicast sender context.
100    pub fn new() -> Self {
101        Self {
102            publications: Vec::new(),
103            next_id: 1,
104            #[cfg(feature = "metrics")]
105            metrics: ContextMetricsInner::default(),
106        }
107    }
108
109    /// Returns the number of tracked publications.
110    pub fn publication_count(&self) -> usize {
111        self.publications.len()
112    }
113
114    /// Returns whether a publication ID exists in the context.
115    pub fn contains_publication(&self, id: PublicationId) -> bool {
116        self.publications
117            .iter()
118            .any(|publication| publication.id() == id)
119    }
120
121    /// Returns an immutable reference to one publication.
122    pub fn get_publication(&self, id: PublicationId) -> Option<&Publication> {
123        self.publications
124            .iter()
125            .find(|publication| publication.id() == id)
126    }
127
128    /// Returns a mutable reference to one publication.
129    pub fn get_publication_mut(&mut self, id: PublicationId) -> Option<&mut Publication> {
130        self.publications
131            .iter_mut()
132            .find(|publication| publication.id() == id)
133    }
134
135    /// Creates a new publication socket from configuration and stores it.
136    pub fn add_publication(
137        &mut self,
138        config: PublicationConfig,
139    ) -> Result<PublicationId, MctxError> {
140        self.ensure_publication_config_is_unique(&config)?;
141
142        let id = self.next_publication_id();
143        let publication = Publication::new(id, config)?;
144        Ok(self.insert_publication(publication))
145    }
146
147    /// Stores an existing socket as a publication after configuring it.
148    pub fn add_publication_with_socket(
149        &mut self,
150        config: PublicationConfig,
151        socket: Socket,
152    ) -> Result<PublicationId, MctxError> {
153        self.ensure_publication_config_is_unique(&config)?;
154
155        let id = self.next_publication_id();
156        let publication = Publication::new_with_socket(id, config, socket)?;
157        Ok(self.insert_publication(publication))
158    }
159
160    /// Stores an existing standard-library UDP socket as a publication after configuring it.
161    pub fn add_publication_with_udp_socket(
162        &mut self,
163        config: PublicationConfig,
164        socket: UdpSocket,
165    ) -> Result<PublicationId, MctxError> {
166        self.ensure_publication_config_is_unique(&config)?;
167
168        let id = self.next_publication_id();
169        let publication = Publication::new_with_udp_socket(id, config, socket)?;
170        Ok(self.insert_publication(publication))
171    }
172
173    /// Removes one publication and drops its socket.
174    pub fn remove_publication(&mut self, id: PublicationId) -> bool {
175        self.take_publication(id).is_some()
176    }
177
178    /// Extracts one publication from the context.
179    pub fn take_publication(&mut self, id: PublicationId) -> Option<Publication> {
180        let index = self
181            .publications
182            .iter()
183            .position(|publication| publication.id() == id)?;
184
185        Some(self.finish_publication_removal(index))
186    }
187
188    /// Returns all tracked publications.
189    pub fn publications(&self) -> &[Publication] {
190        &self.publications
191    }
192
193    /// Returns all tracked publications mutably.
194    pub fn publications_mut(&mut self) -> &mut [Publication] {
195        &mut self.publications
196    }
197
198    /// Sends one payload through the selected publication.
199    pub fn send(&self, id: PublicationId, payload: &[u8]) -> Result<SendReport, MctxError> {
200        let publication = self
201            .get_publication(id)
202            .ok_or(MctxError::PublicationNotFound)?;
203
204        match publication.send(payload) {
205            Ok(report) => {
206                #[cfg(feature = "metrics")]
207                self.record_send_success(report.bytes_sent);
208
209                Ok(report)
210            }
211            Err(error) => {
212                #[cfg(feature = "metrics")]
213                self.record_send_error();
214
215                Err(error)
216            }
217        }
218    }
219
220    /// Sends the same payload through every publication and pushes reports into `out`.
221    ///
222    /// If one publication fails, reports already written into `out` are preserved.
223    pub fn send_all(&self, payload: &[u8], out: &mut Vec<SendReport>) -> Result<usize, MctxError> {
224        let before = out.len();
225
226        for publication in &self.publications {
227            match publication.send(payload) {
228                Ok(report) => {
229                    #[cfg(feature = "metrics")]
230                    self.record_send_success(report.bytes_sent);
231
232                    out.push(report);
233                }
234                Err(error) => {
235                    #[cfg(feature = "metrics")]
236                    self.record_send_error();
237
238                    return Err(error);
239                }
240            }
241        }
242
243        Ok(out.len() - before)
244    }
245
246    /// Returns a snapshot of the context's current metrics.
247    ///
248    /// Counter fields such as `total_packets_sent` are cumulative for the
249    /// lifetime of the context for send activity issued through `Context`
250    /// methods. They are not recomputed from the currently active publications,
251    /// and they do not decrease when a publication is removed.
252    #[cfg(feature = "metrics")]
253    pub fn metrics_snapshot(&self) -> ContextMetricsSnapshot {
254        ContextMetricsSnapshot {
255            publications_added: self.metrics.publications_added.get(),
256            publications_removed: self.metrics.publications_removed.get(),
257            active_publications: self.publications.len(),
258            total_send_calls: self.metrics.total_send_calls.get(),
259            total_packets_sent: self.metrics.total_packets_sent.get(),
260            total_bytes_sent: self.metrics.total_bytes_sent.get(),
261            total_send_errors: self.metrics.total_send_errors.get(),
262            captured_at: SystemTime::now(),
263        }
264    }
265
266    fn next_publication_id(&mut self) -> PublicationId {
267        let id = PublicationId(self.next_id);
268        self.next_id += 1;
269        id
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "metrics")]
    use crate::metrics::ContextMetricsSampler;
    use crate::test_support::{TEST_GROUP, recv_payload, test_multicast_receiver};
    use std::net::Ipv4Addr;

    #[test]
    fn context_send_reaches_a_local_receiver() {
        let (receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let config = PublicationConfig::new(TEST_GROUP, port);
        let id = context.add_publication(config).unwrap();

        let report = context.send(id, b"context hello").unwrap();
        let payload = recv_payload(&receiver);

        assert_eq!(report.bytes_sent, b"context hello".len());
        assert_eq!(payload, b"context hello");
    }

    #[test]
    fn send_all_reaches_every_publication() {
        let (receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();
        let mut reports = Vec::new();

        let sent = context.send_all(b"fan out", &mut reports).unwrap();
        let payload = recv_payload(&receiver);

        assert_eq!(sent, 1);
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].bytes_sent, b"fan out".len());
        assert_eq!(payload, b"fan out");
    }

    #[test]
    fn send_all_on_an_empty_context_appends_nothing() {
        let context = Context::default();
        let mut reports = Vec::new();

        let sent = context.send_all(b"nobody listening", &mut reports).unwrap();

        assert_eq!(sent, 0);
        assert!(reports.is_empty());
    }

    #[test]
    fn duplicate_publications_are_rejected() {
        let mut context = Context::new();
        let config = PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), 5000);

        context.add_publication(config.clone()).unwrap();
        let result = context.add_publication(config);

        assert!(matches!(result, Err(MctxError::DuplicatePublication)));
    }

    #[test]
    fn removed_publications_are_no_longer_reachable() {
        let mut context = Context::new();
        let id = context
            .add_publication(PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), 5000))
            .unwrap();

        assert!(context.remove_publication(id));
        assert_eq!(context.publication_count(), 0);
        assert!(!context.contains_publication(id));
        assert!(context.get_publication(id).is_none());
        assert!(context.take_publication(id).is_none());
        assert!(!context.remove_publication(id));
        assert!(matches!(
            context.send(id, b"gone"),
            Err(MctxError::PublicationNotFound)
        ));
    }

    #[test]
    fn ids_are_not_reused_after_removal() {
        let mut context = Context::new();
        let config = PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), 5000);

        let first = context.add_publication(config.clone()).unwrap();
        assert!(context.remove_publication(first));
        // Uniqueness is checked against tracked publications only, so the same
        // config may be added again once the first publication is gone.
        let second = context.add_publication(config).unwrap();

        assert!(first != second);
        assert!(context.contains_publication(second));
        assert!(!context.contains_publication(first));
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn context_metrics_track_successful_sends() {
        let (_receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let id = context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();
        let mut sampler = ContextMetricsSampler::new(&context);

        assert!(sampler.sample().is_none());
        context.send(id, b"metrics").unwrap();

        let snapshot = context.metrics_snapshot();
        let delta = sampler.sample().unwrap();

        assert_eq!(snapshot.publications_added, 1);
        assert_eq!(snapshot.publications_removed, 0);
        assert_eq!(snapshot.active_publications, 1);
        assert_eq!(snapshot.total_send_calls, 1);
        assert_eq!(snapshot.total_packets_sent, 1);
        assert_eq!(snapshot.total_bytes_sent, b"metrics".len() as u64);
        assert_eq!(snapshot.total_send_errors, 0);
        assert_eq!(delta.publications_added, 0);
        assert_eq!(delta.publications_removed, 0);
        assert_eq!(delta.send_calls, 1);
        assert_eq!(delta.packets_sent, 1);
        assert_eq!(delta.bytes_sent, b"metrics".len() as u64);
        assert_eq!(delta.send_errors, 0);
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn context_metrics_totals_survive_publication_removal() {
        let (_receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let id = context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();

        context.send(id, b"lifetime").unwrap();
        let before_removal = context.metrics_snapshot();
        assert!(context.remove_publication(id));

        let after_removal = context.metrics_snapshot();

        assert_eq!(before_removal.total_packets_sent, 1);
        assert_eq!(before_removal.total_bytes_sent, b"lifetime".len() as u64);
        assert_eq!(after_removal.total_packets_sent, 1);
        assert_eq!(after_removal.total_bytes_sent, b"lifetime".len() as u64);
        assert_eq!(after_removal.active_publications, 0);
        assert_eq!(after_removal.publications_removed, 1);
    }
}