1#![allow(clippy::doc_markdown)]
5
6use alloc::collections::BTreeMap;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
/// Persisted state of a single subscription.
///
/// The subscription's name is not stored here: it is the key under which the
/// value lives in [`Subscriptions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Subscription {
    /// Connection string for the upstream side of the subscription
    /// (NOTE(review): presumably the publisher; tests use the `host=... port=...`
    /// form — confirm the expected format). Serialised with a 32-bit length
    /// prefix, unlike names, which use 16 bits.
    pub conn_str: String,
    /// Names of the publications this subscription consumes; order is preserved
    /// through serialisation.
    pub publications: Vec<String>,
    /// Whether the subscription is enabled.
    pub enabled: bool,
    /// Highest position received so far. `Subscriptions::update_last_received_pos`
    /// only ever moves this forward.
    pub last_received_pos: u64,
}
41
/// Collection of subscriptions keyed by unique name.
///
/// Backed by a `BTreeMap`, so iteration and serialised output are in ascending
/// name order regardless of the order in which subscriptions were created.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Subscriptions {
    // Name -> subscription. The ordered map is what makes `serialize` deterministic.
    inner: BTreeMap<String, Subscription>,
}
46
/// Errors produced by [`Subscriptions`] operations.
#[derive(Debug, PartialEq, Eq)]
pub enum SubscriptionError {
    /// `Subscriptions::create` was given a name that is already registered;
    /// carries the rejected name.
    DuplicateName(String),
    /// A serialised payload could not be decoded; carries a human-readable
    /// description of what was wrong.
    Corrupt(String),
}

impl core::fmt::Display for SubscriptionError {
    /// Renders the error for logs and user-facing messages.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::DuplicateName(name) => write!(f, "subscription {name:?} already exists"),
            Self::Corrupt(reason) => write!(f, "corrupt subscriptions payload: {reason}"),
        }
    }
}
52
53impl Subscriptions {
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn len(&self) -> usize {
59 self.inner.len()
60 }
61
62 pub fn is_empty(&self) -> bool {
63 self.inner.is_empty()
64 }
65
66 pub fn contains(&self, name: &str) -> bool {
67 self.inner.contains_key(name)
68 }
69
70 pub fn get(&self, name: &str) -> Option<&Subscription> {
71 self.inner.get(name)
72 }
73
74 pub fn iter(&self) -> impl Iterator<Item = (&String, &Subscription)> {
75 self.inner.iter()
76 }
77
78 pub fn create(
79 &mut self,
80 name: String,
81 sub: Subscription,
82 ) -> Result<(), SubscriptionError> {
83 if self.inner.contains_key(&name) {
84 return Err(SubscriptionError::DuplicateName(name));
85 }
86 self.inner.insert(name, sub);
87 Ok(())
88 }
89
90 pub fn drop(&mut self, name: &str) -> bool {
91 self.inner.remove(name).is_some()
92 }
93
94 pub fn update_last_received_pos(&mut self, name: &str, pos: u64) -> bool {
100 if let Some(s) = self.inner.get_mut(name) {
101 if pos > s.last_received_pos {
105 s.last_received_pos = pos;
106 }
107 true
108 } else {
109 false
110 }
111 }
112
113 pub fn serialize(&self) -> Vec<u8> {
125 let mut out = Vec::with_capacity(2 + self.inner.len() * 64);
126 let n = u16::try_from(self.inner.len()).expect("≤ 65,535 subscriptions per cluster");
127 out.extend_from_slice(&n.to_le_bytes());
128 for (name, sub) in &self.inner {
129 write_short_str(&mut out, name);
130 write_long_str(&mut out, &sub.conn_str);
131 let np =
132 u16::try_from(sub.publications.len()).expect("≤ 65,535 publications per subscription");
133 out.extend_from_slice(&np.to_le_bytes());
134 for p in &sub.publications {
135 write_short_str(&mut out, p);
136 }
137 out.push(u8::from(sub.enabled));
138 out.extend_from_slice(&sub.last_received_pos.to_le_bytes());
139 }
140 out
141 }
142
143 pub fn deserialize(buf: &[u8]) -> Result<Self, SubscriptionError> {
144 let mut p = 0usize;
145 let n = read_u16(buf, &mut p)? as usize;
146 let mut inner = BTreeMap::new();
147 for _ in 0..n {
148 let name = read_short_str(buf, &mut p)?;
149 let conn_str = read_long_str(buf, &mut p)?;
150 let np = read_u16(buf, &mut p)? as usize;
151 let mut publications = Vec::with_capacity(np);
152 for _ in 0..np {
153 publications.push(read_short_str(buf, &mut p)?);
154 }
155 let enabled_byte = read_u8(buf, &mut p)?;
156 let enabled = match enabled_byte {
157 0 => false,
158 1 => true,
159 other => {
160 return Err(SubscriptionError::Corrupt(alloc::format!(
161 "invalid `enabled` byte {other}, expected 0 or 1"
162 )));
163 }
164 };
165 let last_received_pos = read_u64(buf, &mut p)?;
166 if inner
167 .insert(
168 name.clone(),
169 Subscription {
170 conn_str,
171 publications,
172 enabled,
173 last_received_pos,
174 },
175 )
176 .is_some()
177 {
178 return Err(SubscriptionError::Corrupt(alloc::format!(
179 "duplicate subscription name {name:?} in serialised payload"
180 )));
181 }
182 }
183 if p != buf.len() {
184 return Err(SubscriptionError::Corrupt(alloc::format!(
185 "trailing bytes in subscriptions payload: read {p}, len {}",
186 buf.len()
187 )));
188 }
189 Ok(Self { inner })
190 }
191}
192
/// Appends `s` to `out` as a little-endian `u16` byte length followed by the
/// string's UTF-8 bytes. Used for subscription and publication names.
///
/// # Panics
///
/// Panics if `s` is longer than `u16::MAX` bytes.
fn write_short_str(out: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    let prefix = u16::try_from(bytes.len())
        .expect("subscription / publication name fits in u16")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(bytes);
}
198
/// Appends `s` to `out` as a little-endian `u32` byte length followed by the
/// string's UTF-8 bytes. Used for `conn_str`, which is allowed to exceed the
/// 16-bit limit applied to names.
///
/// # Panics
///
/// Panics if `s` is longer than `u32::MAX` bytes.
fn write_long_str(out: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    let prefix = u32::try_from(bytes.len())
        .expect("conn_str fits in u32")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(bytes);
}
205
206fn read_u8(buf: &[u8], p: &mut usize) -> Result<u8, SubscriptionError> {
207 let v = buf
208 .get(*p)
209 .copied()
210 .ok_or_else(|| SubscriptionError::Corrupt("short read (u8)".to_string()))?;
211 *p += 1;
212 Ok(v)
213}
214
215fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, SubscriptionError> {
216 let slice = buf
217 .get(*p..*p + 2)
218 .ok_or_else(|| SubscriptionError::Corrupt("short read (u16)".to_string()))?;
219 let arr: [u8; 2] = slice
220 .try_into()
221 .map_err(|_| SubscriptionError::Corrupt("u16 slice".to_string()))?;
222 *p += 2;
223 Ok(u16::from_le_bytes(arr))
224}
225
226fn read_u32_as_usize(buf: &[u8], p: &mut usize) -> Result<usize, SubscriptionError> {
227 let slice = buf
228 .get(*p..*p + 4)
229 .ok_or_else(|| SubscriptionError::Corrupt("short read (u32)".to_string()))?;
230 let arr: [u8; 4] = slice
231 .try_into()
232 .map_err(|_| SubscriptionError::Corrupt("u32 slice".to_string()))?;
233 *p += 4;
234 Ok(u32::from_le_bytes(arr) as usize)
235}
236
237fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, SubscriptionError> {
238 let slice = buf
239 .get(*p..*p + 8)
240 .ok_or_else(|| SubscriptionError::Corrupt("short read (u64)".to_string()))?;
241 let arr: [u8; 8] = slice
242 .try_into()
243 .map_err(|_| SubscriptionError::Corrupt("u64 slice".to_string()))?;
244 *p += 8;
245 Ok(u64::from_le_bytes(arr))
246}
247
248fn read_short_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
249 let n = read_u16(buf, p)? as usize;
250 let slice = buf.get(*p..*p + n).ok_or_else(|| {
251 SubscriptionError::Corrupt(alloc::format!("short read (short str, {n} bytes)"))
252 })?;
253 *p += n;
254 core::str::from_utf8(slice)
255 .map(ToString::to_string)
256 .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 str: {e}")))
257}
258
259fn read_long_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
260 let n = read_u32_as_usize(buf, p)?;
261 let slice = buf.get(*p..*p + n).ok_or_else(|| {
262 SubscriptionError::Corrupt(alloc::format!("short read (long str, {n} bytes)"))
263 })?;
264 *p += n;
265 core::str::from_utf8(slice)
266 .map(ToString::to_string)
267 .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 conn_str: {e}")))
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(name, Subscription)` pair whose `conn_str` points at
    /// `127.0.0.1` on `port`.
    fn mk(
        name: &str,
        port: &str,
        pubs: &[&str],
        enabled: bool,
        pos: u64,
    ) -> (String, Subscription) {
        (
            name.to_string(),
            Subscription {
                conn_str: alloc::format!("host=127.0.0.1 port={port}"),
                publications: pubs.iter().map(|s| (*s).to_string()).collect(),
                enabled,
                last_received_pos: pos,
            },
        )
    }

    /// Serialised form of a collection holding exactly one subscription, `sub_a`.
    fn one_sub_bytes() -> Vec<u8> {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 7);
        s.create(n, sub).unwrap();
        s.serialize()
    }

    fn is_corrupt(buf: &[u8]) -> bool {
        matches!(Subscriptions::deserialize(buf), Err(SubscriptionError::Corrupt(_)))
    }

    #[test]
    fn empty_roundtrips() {
        let s = Subscriptions::new();
        let bytes = s.serialize();
        assert_eq!(Subscriptions::deserialize(&bytes).unwrap(), s);
    }

    #[test]
    fn single_subscription_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        let s2 = Subscriptions::deserialize(&bytes).unwrap();
        assert_eq!(s2, s);
        assert!(s2.contains("sub_a"));
    }

    #[test]
    fn multi_publication_roundtrips_with_nontrivial_last_pos() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_z", "20002", &["p1", "p2", "p3"], true, 1_234_567_890);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2, s);
        let r = s2.get("sub_z").unwrap();
        assert_eq!(r.publications, alloc::vec!["p1", "p2", "p3"]);
        assert_eq!(r.last_received_pos, 1_234_567_890);
    }

    #[test]
    fn disabled_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_off", "20002", &["pub_a"], false, 42);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert!(!s2.get("sub_off").unwrap().enabled);
    }

    #[test]
    fn duplicate_name_errors() {
        let mut s = Subscriptions::new();
        let (n1, sub1) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n1, sub1).unwrap();
        let (n2, sub2) = mk("sub_a", "20003", &["pub_b"], true, 0);
        assert_eq!(
            s.create(n2, sub2).unwrap_err(),
            SubscriptionError::DuplicateName("sub_a".into())
        );
    }

    #[test]
    fn drop_present_and_absent() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        assert!(s.drop("sub_a"));
        assert!(!s.drop("sub_a"));
        assert!(!s.drop("never"));
    }

    #[test]
    fn update_last_pos_monotone_and_absent_returns_false() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 100);
        s.create(n, sub).unwrap();
        // A lower position is accepted (returns true) but does not rewind.
        assert!(s.update_last_received_pos("sub_a", 50));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 100);
        assert!(s.update_last_received_pos("sub_a", 200));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 200);
        assert!(!s.update_last_received_pos("missing", 1));
    }

    #[test]
    fn corrupt_enabled_byte_errors() {
        let mut buf = Vec::new();
        // subscription count = 1
        buf.extend_from_slice(&1u16.to_le_bytes());
        // name: 3 bytes, "bad"
        buf.extend_from_slice(&3u16.to_le_bytes());
        buf.extend_from_slice(b"bad");
        // conn_str: empty (u32 length prefix)
        buf.extend_from_slice(&0u32.to_le_bytes());
        // publication count = 0
        buf.extend_from_slice(&0u16.to_le_bytes());
        // enabled = 2, which is neither 0 nor 1
        buf.push(2);
        // last_received_pos
        buf.extend_from_slice(&0u64.to_le_bytes());
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn trailing_bytes_error() {
        let mut buf = one_sub_bytes();
        buf.push(0);
        assert!(is_corrupt(&buf));
    }

    #[test]
    fn every_truncation_errors() {
        let buf = one_sub_bytes();
        for cut in 0..buf.len() {
            assert!(is_corrupt(&buf[..cut]), "truncation at {cut} should be rejected");
        }
    }

    #[test]
    fn duplicate_name_in_payload_errors() {
        let one = one_sub_bytes();
        // Skip the 2-byte count to get the single encoded record, then repeat it.
        let record = &one[2..];
        let mut buf = 2u16.to_le_bytes().to_vec();
        buf.extend_from_slice(record);
        buf.extend_from_slice(record);
        assert!(is_corrupt(&buf));
    }

    #[test]
    fn non_utf8_name_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.push(0xFF);
        assert!(is_corrupt(&buf));
    }

    #[test]
    fn huge_publication_count_with_short_buffer_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        // name "a"
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(b"a");
        // empty conn_str
        buf.extend_from_slice(&0u32.to_le_bytes());
        // claims 65,535 publications but supplies none
        buf.extend_from_slice(&u16::MAX.to_le_bytes());
        assert!(is_corrupt(&buf));
    }

    #[test]
    fn deterministic_order_independent_of_insert_sequence() {
        let mut s1 = Subscriptions::new();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s1.create(n, sub).unwrap();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s1.create(n, sub).unwrap();
        let mut s2 = Subscriptions::new();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s2.create(n, sub).unwrap();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s2.create(n, sub).unwrap();
        assert_eq!(s1.serialize(), s2.serialize());
    }
}