1#[cfg(feature = "metrics")]
2use crate::metrics::ContextMetricsSnapshot;
3use crate::{MctxError, Publication, PublicationConfig, PublicationId, SendReport};
4use socket2::Socket;
5use std::net::UdpSocket;
6#[cfg(feature = "metrics")]
7use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
8#[cfg(feature = "metrics")]
9use std::time::SystemTime;
10
/// Lifetime counters kept by a [`Context`] when the `metrics` feature is enabled.
///
/// Each counter is an [`AtomicU64`] so it can be incremented through `&self`
/// (for example from `Context::send`), which keeps `Context` `Sync`. Counters are
/// only ever increased with `fetch_add`; removing a publication does not roll back
/// the send totals it contributed.
#[cfg(feature = "metrics")]
#[derive(Debug, Default)]
struct ContextMetricsInner {
    /// Publications successfully inserted into the context.
    publications_added: AtomicU64,
    /// Publications removed via `take_publication` / `remove_publication`.
    publications_removed: AtomicU64,
    /// Send attempts on an existing publication, whether they succeeded or failed.
    total_send_calls: AtomicU64,
    /// Successful sends; each successful send counts as exactly one packet.
    total_packets_sent: AtomicU64,
    /// Sum of `SendReport::bytes_sent` over successful sends.
    total_bytes_sent: AtomicU64,
    /// Send attempts that returned an error.
    total_send_errors: AtomicU64,
}
21
/// A collection of [`Publication`]s addressed by [`PublicationId`].
///
/// The context hands out ids (starting at 1 and incrementing, so never reused),
/// rejects a new publication whose config equals one already held, and routes
/// payloads either to one publication ([`Context::send`]) or to all of them
/// ([`Context::send_all`]). With the `metrics` feature it also keeps lifetime
/// counters readable via `Context::metrics_snapshot`.
#[derive(Debug)]
pub struct Context {
    /// Held publications. Removal uses `swap_remove`, so the order matches
    /// insertion order only until the first removal.
    publications: Vec<Publication>,
    /// Raw value of the next id to hand out; starts at 1 and only increments.
    next_id: u64,
    /// Lifetime counters (present only with the `metrics` feature).
    #[cfg(feature = "metrics")]
    metrics: ContextMetricsInner,
}
30
31impl Default for Context {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl Context {
38 #[cfg(feature = "metrics")]
39 fn record_send_success(&self, bytes_sent: usize) {
40 self.metrics.total_send_calls.fetch_add(1, Relaxed);
41 self.metrics.total_packets_sent.fetch_add(1, Relaxed);
42 self.metrics
43 .total_bytes_sent
44 .fetch_add(bytes_sent as u64, Relaxed);
45 }
46
47 #[cfg(feature = "metrics")]
48 fn record_send_error(&self) {
49 self.metrics.total_send_calls.fetch_add(1, Relaxed);
50 self.metrics.total_send_errors.fetch_add(1, Relaxed);
51 }
52
53 fn ensure_publication_config_is_unique(
54 &self,
55 config: &PublicationConfig,
56 ) -> Result<(), MctxError> {
57 if self
58 .publications
59 .iter()
60 .any(|publication| publication.config() == config)
61 {
62 return Err(MctxError::DuplicatePublication);
63 }
64
65 Ok(())
66 }
67
68 fn insert_publication(&mut self, publication: Publication) -> PublicationId {
69 let id = publication.id();
70 self.publications.push(publication);
71
72 #[cfg(feature = "metrics")]
73 self.metrics.publications_added.fetch_add(1, Relaxed);
74
75 id
76 }
77
78 fn finish_publication_removal(&mut self, index: usize) -> Publication {
79 let publication = self.publications.swap_remove(index);
80
81 #[cfg(feature = "metrics")]
82 self.metrics.publications_removed.fetch_add(1, Relaxed);
83
84 publication
85 }
86
87 pub fn new() -> Self {
89 Self {
90 publications: Vec::new(),
91 next_id: 1,
92 #[cfg(feature = "metrics")]
93 metrics: ContextMetricsInner::default(),
94 }
95 }
96
97 pub fn publication_count(&self) -> usize {
99 self.publications.len()
100 }
101
102 pub fn contains_publication(&self, id: PublicationId) -> bool {
104 self.publications
105 .iter()
106 .any(|publication| publication.id() == id)
107 }
108
109 pub fn get_publication(&self, id: PublicationId) -> Option<&Publication> {
111 self.publications
112 .iter()
113 .find(|publication| publication.id() == id)
114 }
115
116 pub fn get_publication_mut(&mut self, id: PublicationId) -> Option<&mut Publication> {
118 self.publications
119 .iter_mut()
120 .find(|publication| publication.id() == id)
121 }
122
123 pub fn add_publication(
125 &mut self,
126 config: PublicationConfig,
127 ) -> Result<PublicationId, MctxError> {
128 self.ensure_publication_config_is_unique(&config)?;
129
130 let id = self.next_publication_id();
131 let publication = Publication::new(id, config)?;
132 Ok(self.insert_publication(publication))
133 }
134
135 pub fn add_publication_with_socket(
137 &mut self,
138 config: PublicationConfig,
139 socket: Socket,
140 ) -> Result<PublicationId, MctxError> {
141 self.ensure_publication_config_is_unique(&config)?;
142
143 let id = self.next_publication_id();
144 let publication = Publication::new_with_socket(id, config, socket)?;
145 Ok(self.insert_publication(publication))
146 }
147
148 pub fn add_publication_with_udp_socket(
150 &mut self,
151 config: PublicationConfig,
152 socket: UdpSocket,
153 ) -> Result<PublicationId, MctxError> {
154 self.ensure_publication_config_is_unique(&config)?;
155
156 let id = self.next_publication_id();
157 let publication = Publication::new_with_udp_socket(id, config, socket)?;
158 Ok(self.insert_publication(publication))
159 }
160
161 pub fn remove_publication(&mut self, id: PublicationId) -> bool {
163 self.take_publication(id).is_some()
164 }
165
166 pub fn take_publication(&mut self, id: PublicationId) -> Option<Publication> {
168 let index = self
169 .publications
170 .iter()
171 .position(|publication| publication.id() == id)?;
172
173 Some(self.finish_publication_removal(index))
174 }
175
176 pub fn publications(&self) -> &[Publication] {
178 &self.publications
179 }
180
181 pub fn publications_mut(&mut self) -> &mut [Publication] {
183 &mut self.publications
184 }
185
186 pub fn send(&self, id: PublicationId, payload: &[u8]) -> Result<SendReport, MctxError> {
188 let publication = self
189 .get_publication(id)
190 .ok_or(MctxError::PublicationNotFound)?;
191
192 match publication.send(payload) {
193 Ok(report) => {
194 #[cfg(feature = "metrics")]
195 self.record_send_success(report.bytes_sent);
196
197 Ok(report)
198 }
199 Err(error) => {
200 #[cfg(feature = "metrics")]
201 self.record_send_error();
202
203 Err(error)
204 }
205 }
206 }
207
208 pub fn send_all(&self, payload: &[u8], out: &mut Vec<SendReport>) -> Result<usize, MctxError> {
212 let before = out.len();
213
214 for publication in &self.publications {
215 match publication.send(payload) {
216 Ok(report) => {
217 #[cfg(feature = "metrics")]
218 self.record_send_success(report.bytes_sent);
219
220 out.push(report);
221 }
222 Err(error) => {
223 #[cfg(feature = "metrics")]
224 self.record_send_error();
225
226 return Err(error);
227 }
228 }
229 }
230
231 Ok(out.len() - before)
232 }
233
234 #[cfg(feature = "metrics")]
241 pub fn metrics_snapshot(&self) -> ContextMetricsSnapshot {
242 ContextMetricsSnapshot {
243 publications_added: self.metrics.publications_added.load(Relaxed),
244 publications_removed: self.metrics.publications_removed.load(Relaxed),
245 active_publications: self.publications.len(),
246 total_send_calls: self.metrics.total_send_calls.load(Relaxed),
247 total_packets_sent: self.metrics.total_packets_sent.load(Relaxed),
248 total_bytes_sent: self.metrics.total_bytes_sent.load(Relaxed),
249 total_send_errors: self.metrics.total_send_errors.load(Relaxed),
250 captured_at: SystemTime::now(),
251 }
252 }
253
254 fn next_publication_id(&mut self) -> PublicationId {
255 let id = PublicationId(self.next_id);
256 self.next_id += 1;
257 id
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "metrics")]
    use crate::metrics::ContextMetricsSampler;
    use crate::test_support::{TEST_GROUP, recv_payload, test_multicast_receiver};
    use std::net::Ipv4Addr;

    /// Config that needs no receiver; only used for bookkeeping tests.
    fn offline_config(port: u16) -> PublicationConfig {
        PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), port)
    }

    #[test]
    fn context_send_reaches_a_local_receiver() {
        let (receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let config = PublicationConfig::new(TEST_GROUP, port);
        let id = context.add_publication(config).unwrap();

        let report = context.send(id, b"context hello").unwrap();
        let payload = recv_payload(&receiver);

        assert_eq!(report.bytes_sent, b"context hello".len());
        assert_eq!(payload, b"context hello");
    }

    #[test]
    fn duplicate_publications_are_rejected() {
        let mut context = Context::new();
        let config = offline_config(5000);

        context.add_publication(config.clone()).unwrap();
        let result = context.add_publication(config);

        assert!(matches!(result, Err(MctxError::DuplicatePublication)));
    }

    #[test]
    fn send_to_unknown_publication_is_an_error() {
        let context = Context::new();

        let result = context.send(PublicationId(42), b"nobody home");

        assert!(matches!(result, Err(MctxError::PublicationNotFound)));
    }

    #[test]
    fn publication_ids_start_at_one_and_increment() {
        let mut context = Context::new();

        let first = context.add_publication(offline_config(5000)).unwrap();
        let second = context.add_publication(offline_config(5001)).unwrap();

        assert!(first == PublicationId(1));
        assert!(second == PublicationId(2));
        assert_eq!(context.publication_count(), 2);
    }

    #[test]
    fn take_publication_removes_it_exactly_once() {
        let mut context = Context::new();
        let config = offline_config(5000);
        let id = context.add_publication(config.clone()).unwrap();
        assert!(context.contains_publication(id));
        assert_eq!(context.publication_count(), 1);

        let taken = context
            .take_publication(id)
            .expect("publication was just added");

        assert!(taken.id() == id);
        assert!(!context.contains_publication(id));
        assert!(context.get_publication(id).is_none());
        assert_eq!(context.publication_count(), 0);
        assert!(context.take_publication(id).is_none());
        assert!(!context.remove_publication(id));

        // Once removed, the same config is no longer a duplicate and gets a fresh id.
        let readded = context.add_publication(config).unwrap();
        assert!(readded != id);
    }

    #[test]
    fn send_all_on_empty_context_sends_nothing() {
        let context = Context::default();
        let mut reports = Vec::new();

        assert_eq!(context.send_all(b"unused", &mut reports).unwrap(), 0);
        assert!(reports.is_empty());
    }

    #[test]
    fn send_all_reaches_a_local_receiver() {
        let (receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();

        let mut reports = Vec::new();
        let sent = context.send_all(b"fan out", &mut reports).unwrap();
        let payload = recv_payload(&receiver);

        assert_eq!(sent, 1);
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].bytes_sent, b"fan out".len());
        assert_eq!(payload, b"fan out");
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn context_metrics_track_successful_sends() {
        let (_receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let id = context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();
        let mut sampler = ContextMetricsSampler::new(&context);

        assert!(sampler.sample().is_none());
        context.send(id, b"metrics").unwrap();

        let snapshot = context.metrics_snapshot();
        let delta = sampler.sample().unwrap();

        assert_eq!(snapshot.publications_added, 1);
        assert_eq!(snapshot.publications_removed, 0);
        assert_eq!(snapshot.active_publications, 1);
        assert_eq!(snapshot.total_send_calls, 1);
        assert_eq!(snapshot.total_packets_sent, 1);
        assert_eq!(snapshot.total_bytes_sent, b"metrics".len() as u64);
        assert_eq!(snapshot.total_send_errors, 0);
        assert_eq!(delta.publications_added, 0);
        assert_eq!(delta.publications_removed, 0);
        assert_eq!(delta.send_calls, 1);
        assert_eq!(delta.packets_sent, 1);
        assert_eq!(delta.bytes_sent, b"metrics".len() as u64);
        assert_eq!(delta.send_errors, 0);
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn context_metrics_totals_survive_publication_removal() {
        let (_receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let id = context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();

        context.send(id, b"lifetime").unwrap();
        let before_removal = context.metrics_snapshot();
        assert!(context.remove_publication(id));

        let after_removal = context.metrics_snapshot();

        assert_eq!(before_removal.total_packets_sent, 1);
        assert_eq!(before_removal.total_bytes_sent, b"lifetime".len() as u64);
        assert_eq!(after_removal.total_packets_sent, 1);
        assert_eq!(after_removal.total_bytes_sent, b"lifetime".len() as u64);
        assert_eq!(after_removal.active_publications, 0);
        assert_eq!(after_removal.publications_removed, 1);
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn context_is_sync_with_metrics_enabled() {
        fn assert_sync<T: Sync>() {}

        assert_sync::<Context>();
    }
}