1#[cfg(feature = "metrics")]
2use crate::metrics::ContextMetricsSnapshot;
3use crate::{MctxError, Publication, PublicationConfig, PublicationId, SendReport};
4use socket2::Socket;
5#[cfg(feature = "metrics")]
6use std::cell::Cell;
7use std::net::UdpSocket;
8#[cfg(feature = "metrics")]
9use std::time::SystemTime;
10
/// Lifetime counters kept by a `Context` when the `metrics` feature is enabled.
///
/// Every counter lives in a [`Cell`] so it can be bumped through `&self`;
/// the send paths only borrow the context immutably.
#[cfg(feature = "metrics")]
#[derive(Debug, Default)]
struct ContextMetricsInner {
    /// Number of publications inserted into the context.
    publications_added: Cell<u64>,
    /// Number of publications removed from the context.
    publications_removed: Cell<u64>,
    /// Number of send attempts, successful or not.
    total_send_calls: Cell<u64>,
    /// Number of successful sends (one packet is counted per success).
    total_packets_sent: Cell<u64>,
    /// Sum of `bytes_sent` over all successful sends.
    total_bytes_sent: Cell<u64>,
    /// Number of send attempts that returned an error.
    total_send_errors: Cell<u64>,
}
21
22#[derive(Debug)]
24pub struct Context {
25 publications: Vec<Publication>,
26 next_id: u64,
27 #[cfg(feature = "metrics")]
28 metrics: ContextMetricsInner,
29}
30
31impl Default for Context {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl Context {
/// Records one successful send: one call, one packet and `bytes_sent` bytes.
#[cfg(feature = "metrics")]
fn record_send_success(&self, bytes_sent: usize) {
    let counters = &self.metrics;

    let calls = &counters.total_send_calls;
    calls.set(calls.get() + 1);

    let packets = &counters.total_packets_sent;
    packets.set(packets.get() + 1);

    let bytes = &counters.total_bytes_sent;
    bytes.set(bytes.get() + bytes_sent as u64);
}
50
/// Records one failed send: the attempt still counts as a send call.
#[cfg(feature = "metrics")]
fn record_send_error(&self) {
    let counters = &self.metrics;

    let calls = &counters.total_send_calls;
    calls.set(calls.get() + 1);

    let errors = &counters.total_send_errors;
    errors.set(errors.get() + 1);
}
60
61 fn ensure_publication_config_is_unique(
62 &self,
63 config: &PublicationConfig,
64 ) -> Result<(), MctxError> {
65 if self
66 .publications
67 .iter()
68 .any(|publication| publication.config() == config)
69 {
70 return Err(MctxError::DuplicatePublication);
71 }
72
73 Ok(())
74 }
75
76 fn insert_publication(&mut self, publication: Publication) -> PublicationId {
77 let id = publication.id();
78 self.publications.push(publication);
79
80 #[cfg(feature = "metrics")]
81 self.metrics
82 .publications_added
83 .set(self.metrics.publications_added.get() + 1);
84
85 id
86 }
87
88 fn finish_publication_removal(&mut self, index: usize) -> Publication {
89 let publication = self.publications.swap_remove(index);
90
91 #[cfg(feature = "metrics")]
92 self.metrics
93 .publications_removed
94 .set(self.metrics.publications_removed.get() + 1);
95
96 publication
97 }
98
99 pub fn new() -> Self {
101 Self {
102 publications: Vec::new(),
103 next_id: 1,
104 #[cfg(feature = "metrics")]
105 metrics: ContextMetricsInner::default(),
106 }
107 }
108
109 pub fn publication_count(&self) -> usize {
111 self.publications.len()
112 }
113
114 pub fn contains_publication(&self, id: PublicationId) -> bool {
116 self.publications
117 .iter()
118 .any(|publication| publication.id() == id)
119 }
120
121 pub fn get_publication(&self, id: PublicationId) -> Option<&Publication> {
123 self.publications
124 .iter()
125 .find(|publication| publication.id() == id)
126 }
127
128 pub fn get_publication_mut(&mut self, id: PublicationId) -> Option<&mut Publication> {
130 self.publications
131 .iter_mut()
132 .find(|publication| publication.id() == id)
133 }
134
135 pub fn add_publication(
137 &mut self,
138 config: PublicationConfig,
139 ) -> Result<PublicationId, MctxError> {
140 self.ensure_publication_config_is_unique(&config)?;
141
142 let id = self.next_publication_id();
143 let publication = Publication::new(id, config)?;
144 Ok(self.insert_publication(publication))
145 }
146
147 pub fn add_publication_with_socket(
149 &mut self,
150 config: PublicationConfig,
151 socket: Socket,
152 ) -> Result<PublicationId, MctxError> {
153 self.ensure_publication_config_is_unique(&config)?;
154
155 let id = self.next_publication_id();
156 let publication = Publication::new_with_socket(id, config, socket)?;
157 Ok(self.insert_publication(publication))
158 }
159
160 pub fn add_publication_with_udp_socket(
162 &mut self,
163 config: PublicationConfig,
164 socket: UdpSocket,
165 ) -> Result<PublicationId, MctxError> {
166 self.ensure_publication_config_is_unique(&config)?;
167
168 let id = self.next_publication_id();
169 let publication = Publication::new_with_udp_socket(id, config, socket)?;
170 Ok(self.insert_publication(publication))
171 }
172
173 pub fn remove_publication(&mut self, id: PublicationId) -> bool {
175 self.take_publication(id).is_some()
176 }
177
178 pub fn take_publication(&mut self, id: PublicationId) -> Option<Publication> {
180 let index = self
181 .publications
182 .iter()
183 .position(|publication| publication.id() == id)?;
184
185 Some(self.finish_publication_removal(index))
186 }
187
188 pub fn publications(&self) -> &[Publication] {
190 &self.publications
191 }
192
193 pub fn publications_mut(&mut self) -> &mut [Publication] {
195 &mut self.publications
196 }
197
198 pub fn send(&self, id: PublicationId, payload: &[u8]) -> Result<SendReport, MctxError> {
200 let publication = self
201 .get_publication(id)
202 .ok_or(MctxError::PublicationNotFound)?;
203
204 match publication.send(payload) {
205 Ok(report) => {
206 #[cfg(feature = "metrics")]
207 self.record_send_success(report.bytes_sent);
208
209 Ok(report)
210 }
211 Err(error) => {
212 #[cfg(feature = "metrics")]
213 self.record_send_error();
214
215 Err(error)
216 }
217 }
218 }
219
220 pub fn send_all(&self, payload: &[u8], out: &mut Vec<SendReport>) -> Result<usize, MctxError> {
224 let before = out.len();
225
226 for publication in &self.publications {
227 match publication.send(payload) {
228 Ok(report) => {
229 #[cfg(feature = "metrics")]
230 self.record_send_success(report.bytes_sent);
231
232 out.push(report);
233 }
234 Err(error) => {
235 #[cfg(feature = "metrics")]
236 self.record_send_error();
237
238 return Err(error);
239 }
240 }
241 }
242
243 Ok(out.len() - before)
244 }
245
/// Captures the current counter values together with the capture time.
///
/// `active_publications` is the number of publications registered right
/// now; the remaining fields are lifetime totals.
#[cfg(feature = "metrics")]
pub fn metrics_snapshot(&self) -> ContextMetricsSnapshot {
    // Destructuring keeps this in sync with the counter struct: adding a
    // counter there without handling it here is a compile error.
    let ContextMetricsInner {
        publications_added,
        publications_removed,
        total_send_calls,
        total_packets_sent,
        total_bytes_sent,
        total_send_errors,
    } = &self.metrics;

    ContextMetricsSnapshot {
        publications_added: publications_added.get(),
        publications_removed: publications_removed.get(),
        active_publications: self.publications.len(),
        total_send_calls: total_send_calls.get(),
        total_packets_sent: total_packets_sent.get(),
        total_bytes_sent: total_bytes_sent.get(),
        total_send_errors: total_send_errors.get(),
        captured_at: SystemTime::now(),
    }
}
265
266 fn next_publication_id(&mut self) -> PublicationId {
267 let id = PublicationId(self.next_id);
268 self.next_id += 1;
269 id
270 }
271}
272
273#[cfg(test)]
274mod tests {
275 use super::*;
276 #[cfg(feature = "metrics")]
277 use crate::metrics::ContextMetricsSampler;
278 use crate::test_support::{TEST_GROUP, recv_payload, test_multicast_receiver};
279 use std::net::Ipv4Addr;
280
/// A payload sent through the context arrives unchanged at a local receiver.
#[test]
fn context_send_reaches_a_local_receiver() {
    let message = b"context hello";

    let (receiver, port) = test_multicast_receiver();
    let mut context = Context::new();
    let id = context
        .add_publication(PublicationConfig::new(TEST_GROUP, port))
        .unwrap();

    let report = context.send(id, message).unwrap();
    let received = recv_payload(&receiver);

    assert_eq!(report.bytes_sent, message.len());
    assert_eq!(received, message);
}
294
/// Registering the same configuration twice fails with `DuplicatePublication`.
#[test]
fn duplicate_publications_are_rejected() {
    let config = PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), 5000);
    let mut context = Context::new();

    context.add_publication(config.clone()).unwrap();

    let second_attempt = context.add_publication(config);
    assert!(matches!(
        second_attempt,
        Err(MctxError::DuplicatePublication)
    ));
}
305
/// One successful send shows up in both the snapshot totals and the
/// sampler delta.
#[cfg(feature = "metrics")]
#[test]
fn context_metrics_track_successful_sends() {
    let message = b"metrics";
    let expected_bytes = message.len() as u64;

    let (_receiver, port) = test_multicast_receiver();
    let mut context = Context::new();
    let id = context
        .add_publication(PublicationConfig::new(TEST_GROUP, port))
        .unwrap();
    let mut sampler = ContextMetricsSampler::new(&context);

    // The first sample is expected to yield nothing.
    assert!(sampler.sample().is_none());
    context.send(id, message).unwrap();

    let totals = context.metrics_snapshot();
    let since_last_sample = sampler.sample().unwrap();

    // Lifetime totals include the publication added before sampling began.
    assert_eq!(totals.publications_added, 1);
    assert_eq!(totals.publications_removed, 0);
    assert_eq!(totals.active_publications, 1);
    assert_eq!(totals.total_send_calls, 1);
    assert_eq!(totals.total_packets_sent, 1);
    assert_eq!(totals.total_bytes_sent, expected_bytes);
    assert_eq!(totals.total_send_errors, 0);

    // The delta only covers the send made between the two samples.
    assert_eq!(since_last_sample.publications_added, 0);
    assert_eq!(since_last_sample.publications_removed, 0);
    assert_eq!(since_last_sample.send_calls, 1);
    assert_eq!(since_last_sample.packets_sent, 1);
    assert_eq!(since_last_sample.bytes_sent, expected_bytes);
    assert_eq!(since_last_sample.send_errors, 0);
}
336
/// Send totals are kept by the context, so removing the publication that
/// produced them does not reset them.
#[cfg(feature = "metrics")]
#[test]
fn context_metrics_totals_survive_publication_removal() {
    let message = b"lifetime";
    let expected_bytes = message.len() as u64;

    let (_receiver, port) = test_multicast_receiver();
    let mut context = Context::new();
    let id = context
        .add_publication(PublicationConfig::new(TEST_GROUP, port))
        .unwrap();

    context.send(id, message).unwrap();
    let while_registered = context.metrics_snapshot();

    assert!(context.remove_publication(id));
    let once_removed = context.metrics_snapshot();

    assert_eq!(while_registered.total_packets_sent, 1);
    assert_eq!(while_registered.total_bytes_sent, expected_bytes);
    assert_eq!(once_removed.total_packets_sent, 1);
    assert_eq!(once_removed.total_bytes_sent, expected_bytes);
    assert_eq!(once_removed.active_publications, 0);
    assert_eq!(once_removed.publications_removed, 1);
}
359}