1#[cfg(feature = "metrics")]
2use crate::metrics::ContextMetricsSnapshot;
3use crate::{MctxError, Publication, PublicationConfig, PublicationId, SendReport};
4use socket2::Socket;
5use std::net::UdpSocket;
6
/// Owns a set of [`Publication`]s and hands out a distinct [`PublicationId`]
/// for each one that is added.
#[derive(Debug)]
pub struct Context {
    /// Registered publications. Lookups by id scan this vector linearly.
    publications: Vec<Publication>,
    /// Raw value of the next id to hand out. It starts at 1 and is only ever
    /// incremented, so ids are not reused within one `Context`.
    next_id: u64,
}
13
14impl Default for Context {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl Context {
21 pub fn new() -> Self {
23 Self {
24 publications: Vec::new(),
25 next_id: 1,
26 }
27 }
28
29 pub fn publication_count(&self) -> usize {
31 self.publications.len()
32 }
33
34 pub fn contains_publication(&self, id: PublicationId) -> bool {
36 self.publications
37 .iter()
38 .any(|publication| publication.id() == id)
39 }
40
41 pub fn get_publication(&self, id: PublicationId) -> Option<&Publication> {
43 self.publications
44 .iter()
45 .find(|publication| publication.id() == id)
46 }
47
48 pub fn get_publication_mut(&mut self, id: PublicationId) -> Option<&mut Publication> {
50 self.publications
51 .iter_mut()
52 .find(|publication| publication.id() == id)
53 }
54
55 pub fn add_publication(
57 &mut self,
58 config: PublicationConfig,
59 ) -> Result<PublicationId, MctxError> {
60 if self
61 .publications
62 .iter()
63 .any(|publication| publication.config() == &config)
64 {
65 return Err(MctxError::DuplicatePublication);
66 }
67
68 let id = self.next_publication_id();
69 let publication = Publication::new(id, config)?;
70 self.publications.push(publication);
71 Ok(id)
72 }
73
74 pub fn add_publication_with_socket(
76 &mut self,
77 config: PublicationConfig,
78 socket: Socket,
79 ) -> Result<PublicationId, MctxError> {
80 if self
81 .publications
82 .iter()
83 .any(|publication| publication.config() == &config)
84 {
85 return Err(MctxError::DuplicatePublication);
86 }
87
88 let id = self.next_publication_id();
89 let publication = Publication::new_with_socket(id, config, socket)?;
90 self.publications.push(publication);
91 Ok(id)
92 }
93
94 pub fn add_publication_with_udp_socket(
96 &mut self,
97 config: PublicationConfig,
98 socket: UdpSocket,
99 ) -> Result<PublicationId, MctxError> {
100 if self
101 .publications
102 .iter()
103 .any(|publication| publication.config() == &config)
104 {
105 return Err(MctxError::DuplicatePublication);
106 }
107
108 let id = self.next_publication_id();
109 let publication = Publication::new_with_udp_socket(id, config, socket)?;
110 self.publications.push(publication);
111 Ok(id)
112 }
113
114 pub fn remove_publication(&mut self, id: PublicationId) -> bool {
116 let Some(index) = self
117 .publications
118 .iter()
119 .position(|publication| publication.id() == id)
120 else {
121 return false;
122 };
123
124 self.publications.swap_remove(index);
125 true
126 }
127
128 pub fn take_publication(&mut self, id: PublicationId) -> Option<Publication> {
130 let index = self
131 .publications
132 .iter()
133 .position(|publication| publication.id() == id)?;
134
135 Some(self.publications.swap_remove(index))
136 }
137
138 pub fn publications(&self) -> &[Publication] {
140 &self.publications
141 }
142
143 pub fn publications_mut(&mut self) -> &mut [Publication] {
145 &mut self.publications
146 }
147
148 pub fn send(&self, id: PublicationId, payload: &[u8]) -> Result<SendReport, MctxError> {
150 let publication = self
151 .get_publication(id)
152 .ok_or(MctxError::PublicationNotFound)?;
153
154 publication.send(payload)
155 }
156
157 pub fn send_all(&self, payload: &[u8], out: &mut Vec<SendReport>) -> Result<usize, MctxError> {
161 let before = out.len();
162
163 for publication in &self.publications {
164 out.push(publication.send(payload)?);
165 }
166
167 Ok(out.len() - before)
168 }
169
170 #[cfg(feature = "metrics")]
172 pub fn metrics_snapshot(&self) -> ContextMetricsSnapshot {
173 let mut snapshot = ContextMetricsSnapshot {
174 publication_count: self.publications.len(),
175 ..ContextMetricsSnapshot::default()
176 };
177
178 for publication in &self.publications {
179 let publication_metrics = publication.metrics_snapshot();
180 snapshot.send_calls += publication_metrics.send_calls;
181 snapshot.packets_sent += publication_metrics.packets_sent;
182 snapshot.bytes_sent += publication_metrics.bytes_sent;
183 snapshot.send_errors += publication_metrics.send_errors;
184 }
185
186 snapshot
187 }
188
189 fn next_publication_id(&mut self) -> PublicationId {
190 let id = PublicationId(self.next_id);
191 self.next_id += 1;
192 id
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "metrics")]
    use crate::metrics::ContextMetricsSampler;
    use crate::test_support::{TEST_GROUP, recv_payload, test_multicast_receiver};
    use std::net::Ipv4Addr;

    /// A fixed multicast config for tests that only exercise bookkeeping and
    /// never send. Each test passes its own port to keep configs distinct.
    fn bookkeeping_config(port: u16) -> PublicationConfig {
        PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), port)
    }

    #[test]
    fn context_send_reaches_a_local_receiver() {
        let (receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let config = PublicationConfig::new(TEST_GROUP, port);
        let id = context.add_publication(config).unwrap();

        let report = context.send(id, b"context hello").unwrap();
        let payload = recv_payload(&receiver);

        assert_eq!(report.bytes_sent, b"context hello".len());
        assert_eq!(payload, b"context hello");
    }

    #[test]
    fn duplicate_publications_are_rejected() {
        let mut context = Context::new();
        let config = PublicationConfig::new(Ipv4Addr::new(239, 1, 2, 3), 5000);

        context.add_publication(config.clone()).unwrap();
        let result = context.add_publication(config);

        assert!(matches!(result, Err(MctxError::DuplicatePublication)));
    }

    #[test]
    fn removed_publications_are_no_longer_reachable() {
        let mut context = Context::new();
        let id = context.add_publication(bookkeeping_config(5001)).unwrap();

        assert!(context.contains_publication(id));
        assert!(context.remove_publication(id));
        assert!(!context.contains_publication(id));
        // A second removal of the same id reports that nothing was removed.
        assert!(!context.remove_publication(id));
        assert_eq!(context.publication_count(), 0);
        assert!(matches!(
            context.send(id, b"gone"),
            Err(MctxError::PublicationNotFound)
        ));
    }

    #[test]
    fn take_publication_returns_ownership_once() {
        let mut context = Context::new();
        let id = context.add_publication(bookkeeping_config(5002)).unwrap();

        let taken = context
            .take_publication(id)
            .expect("publication was just added");
        assert!(taken.id() == id);
        assert!(context.take_publication(id).is_none());
        assert!(context.get_publication(id).is_none());
        assert_eq!(context.publication_count(), 0);
    }

    #[test]
    fn a_removed_config_can_be_added_again_with_a_fresh_id() {
        let mut context = Context::new();
        let config = bookkeeping_config(5003);

        let first = context.add_publication(config.clone()).unwrap();
        assert!(context.remove_publication(first));
        let second = context.add_publication(config).unwrap();

        assert!(first != second);
        assert_eq!(context.publication_count(), 1);
    }

    #[test]
    fn send_all_on_an_empty_context_sends_nothing() {
        let context = Context::new();
        let mut reports = Vec::new();

        assert_eq!(context.send_all(b"nobody", &mut reports).unwrap(), 0);
        assert!(reports.is_empty());
    }

    #[test]
    fn send_all_reports_one_entry_per_publication() {
        let (receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();

        let mut reports = Vec::new();
        let sent = context.send_all(b"fan out", &mut reports).unwrap();
        let payload = recv_payload(&receiver);

        assert_eq!(sent, 1);
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].bytes_sent, b"fan out".len());
        assert_eq!(payload, b"fan out");
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn context_metrics_track_successful_sends() {
        let (_receiver, port) = test_multicast_receiver();
        let mut context = Context::new();
        let id = context
            .add_publication(PublicationConfig::new(TEST_GROUP, port))
            .unwrap();
        let sampler = ContextMetricsSampler::new(&context);

        context.send(id, b"metrics").unwrap();

        let delta = sampler.delta();
        assert_eq!(delta.publication_count_change, 0);
        assert_eq!(delta.send_calls, 1);
        assert_eq!(delta.packets_sent, 1);
        assert_eq!(delta.bytes_sent, b"metrics".len() as u64);
        assert_eq!(delta.send_errors, 0);
    }
}