1#![allow(clippy::doc_markdown)]
5
6use alloc::collections::BTreeMap;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
34use spg_sql::ast::CreateSubscriptionStatement;
35use spg_storage::{ColumnSchema, DataType, Row, Value};
36
37use crate::{Engine, EngineError, QueryResult};
38
/// Catalog entry for one subscription, as created by `CREATE SUBSCRIPTION`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Subscription {
    /// Connection string for the publisher, stored verbatim as supplied.
    pub conn_str: String,
    /// Names of the publications this subscription consumes, in the order supplied.
    pub publications: Vec<String>,
    /// Whether the subscription is active; `CREATE SUBSCRIPTION` sets this to `true`.
    pub enabled: bool,
    /// Last received position. Starts at 0 on creation and is only moved
    /// forward by `Subscriptions::update_last_received_pos`.
    pub last_received_pos: u64,
}
46
47#[derive(Debug, Clone, PartialEq, Eq, Default)]
48pub struct Subscriptions {
49 inner: BTreeMap<String, Subscription>,
50}
51
/// Errors raised by the [`Subscriptions`] registry and its binary codec.
#[derive(Debug, PartialEq, Eq)]
pub enum SubscriptionError {
    /// A subscription with this name is already registered.
    DuplicateName(String),
    /// A serialised payload could not be decoded; the string says why.
    Corrupt(String),
}

/// Human-readable rendering, so callers need not fall back to `Debug`.
impl core::fmt::Display for SubscriptionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::DuplicateName(name) => write!(f, "subscription {name:?} already exists"),
            Self::Corrupt(detail) => write!(f, "corrupt subscriptions payload: {detail}"),
        }
    }
}
57
58impl Subscriptions {
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn len(&self) -> usize {
64 self.inner.len()
65 }
66
67 pub fn is_empty(&self) -> bool {
68 self.inner.is_empty()
69 }
70
71 pub fn contains(&self, name: &str) -> bool {
72 self.inner.contains_key(name)
73 }
74
75 pub fn get(&self, name: &str) -> Option<&Subscription> {
76 self.inner.get(name)
77 }
78
79 pub fn iter(&self) -> impl Iterator<Item = (&String, &Subscription)> {
80 self.inner.iter()
81 }
82
83 pub fn create(&mut self, name: String, sub: Subscription) -> Result<(), SubscriptionError> {
84 if self.inner.contains_key(&name) {
85 return Err(SubscriptionError::DuplicateName(name));
86 }
87 self.inner.insert(name, sub);
88 Ok(())
89 }
90
91 pub fn drop(&mut self, name: &str) -> bool {
92 self.inner.remove(name).is_some()
93 }
94
95 pub fn update_last_received_pos(&mut self, name: &str, pos: u64) -> bool {
101 if let Some(s) = self.inner.get_mut(name) {
102 if pos > s.last_received_pos {
106 s.last_received_pos = pos;
107 }
108 true
109 } else {
110 false
111 }
112 }
113
114 pub fn serialize(&self) -> Vec<u8> {
126 let mut out = Vec::with_capacity(2 + self.inner.len() * 64);
127 let n = u16::try_from(self.inner.len()).expect("≤ 65,535 subscriptions per cluster");
128 out.extend_from_slice(&n.to_le_bytes());
129 for (name, sub) in &self.inner {
130 write_short_str(&mut out, name);
131 write_long_str(&mut out, &sub.conn_str);
132 let np = u16::try_from(sub.publications.len())
133 .expect("≤ 65,535 publications per subscription");
134 out.extend_from_slice(&np.to_le_bytes());
135 for p in &sub.publications {
136 write_short_str(&mut out, p);
137 }
138 out.push(u8::from(sub.enabled));
139 out.extend_from_slice(&sub.last_received_pos.to_le_bytes());
140 }
141 out
142 }
143
144 pub fn deserialize(buf: &[u8]) -> Result<Self, SubscriptionError> {
145 let mut p = 0usize;
146 let n = read_u16(buf, &mut p)? as usize;
147 let mut inner = BTreeMap::new();
148 for _ in 0..n {
149 let name = read_short_str(buf, &mut p)?;
150 let conn_str = read_long_str(buf, &mut p)?;
151 let np = read_u16(buf, &mut p)? as usize;
152 let mut publications = Vec::with_capacity(np);
153 for _ in 0..np {
154 publications.push(read_short_str(buf, &mut p)?);
155 }
156 let enabled_byte = read_u8(buf, &mut p)?;
157 let enabled = match enabled_byte {
158 0 => false,
159 1 => true,
160 other => {
161 return Err(SubscriptionError::Corrupt(alloc::format!(
162 "invalid `enabled` byte {other}, expected 0 or 1"
163 )));
164 }
165 };
166 let last_received_pos = read_u64(buf, &mut p)?;
167 if inner
168 .insert(
169 name.clone(),
170 Subscription {
171 conn_str,
172 publications,
173 enabled,
174 last_received_pos,
175 },
176 )
177 .is_some()
178 {
179 return Err(SubscriptionError::Corrupt(alloc::format!(
180 "duplicate subscription name {name:?} in serialised payload"
181 )));
182 }
183 }
184 if p != buf.len() {
185 return Err(SubscriptionError::Corrupt(alloc::format!(
186 "trailing bytes in subscriptions payload: read {p}, len {}",
187 buf.len()
188 )));
189 }
190 Ok(Self { inner })
191 }
192}
193
/// Appends `s` to `out` as a little-endian `u16` byte length followed by the
/// string's UTF-8 bytes.
///
/// # Panics
/// If `s` is longer than `u16::MAX` bytes.
fn write_short_str(out: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    let len_prefix =
        u16::try_from(bytes.len()).expect("subscription / publication name fits in u16");
    out.extend_from_slice(&len_prefix.to_le_bytes());
    out.extend_from_slice(bytes);
}
199
/// Appends `s` to `out` as a little-endian `u32` byte length followed by the
/// string's UTF-8 bytes. `Subscriptions::serialize` uses it for `conn_str`.
///
/// # Panics
/// If `s` is longer than `u32::MAX` bytes.
fn write_long_str(out: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    let prefix = u32::try_from(payload.len())
        .expect("conn_str fits in u32")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(payload);
}
206
207fn read_u8(buf: &[u8], p: &mut usize) -> Result<u8, SubscriptionError> {
208 let v = buf
209 .get(*p)
210 .copied()
211 .ok_or_else(|| SubscriptionError::Corrupt("short read (u8)".to_string()))?;
212 *p += 1;
213 Ok(v)
214}
215
216fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, SubscriptionError> {
217 let slice = buf
218 .get(*p..*p + 2)
219 .ok_or_else(|| SubscriptionError::Corrupt("short read (u16)".to_string()))?;
220 let arr: [u8; 2] = slice
221 .try_into()
222 .map_err(|_| SubscriptionError::Corrupt("u16 slice".to_string()))?;
223 *p += 2;
224 Ok(u16::from_le_bytes(arr))
225}
226
227fn read_u32_as_usize(buf: &[u8], p: &mut usize) -> Result<usize, SubscriptionError> {
228 let slice = buf
229 .get(*p..*p + 4)
230 .ok_or_else(|| SubscriptionError::Corrupt("short read (u32)".to_string()))?;
231 let arr: [u8; 4] = slice
232 .try_into()
233 .map_err(|_| SubscriptionError::Corrupt("u32 slice".to_string()))?;
234 *p += 4;
235 Ok(u32::from_le_bytes(arr) as usize)
236}
237
238fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, SubscriptionError> {
239 let slice = buf
240 .get(*p..*p + 8)
241 .ok_or_else(|| SubscriptionError::Corrupt("short read (u64)".to_string()))?;
242 let arr: [u8; 8] = slice
243 .try_into()
244 .map_err(|_| SubscriptionError::Corrupt("u64 slice".to_string()))?;
245 *p += 8;
246 Ok(u64::from_le_bytes(arr))
247}
248
249fn read_short_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
250 let n = read_u16(buf, p)? as usize;
251 let slice = buf.get(*p..*p + n).ok_or_else(|| {
252 SubscriptionError::Corrupt(alloc::format!("short read (short str, {n} bytes)"))
253 })?;
254 *p += n;
255 core::str::from_utf8(slice)
256 .map(ToString::to_string)
257 .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 str: {e}")))
258}
259
260fn read_long_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
261 let n = read_u32_as_usize(buf, p)?;
262 let slice = buf.get(*p..*p + n).ok_or_else(|| {
263 SubscriptionError::Corrupt(alloc::format!("short read (long str, {n} bytes)"))
264 })?;
265 *p += n;
266 core::str::from_utf8(slice)
267 .map(ToString::to_string)
268 .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 conn_str: {e}")))
269}
270
271impl Engine {
272 pub(crate) fn exec_show_subscriptions(&self) -> QueryResult {
278 let columns = alloc::vec![
279 ColumnSchema::new("name", DataType::Text, false),
280 ColumnSchema::new("conn_str", DataType::Text, false),
281 ColumnSchema::new("publications", DataType::Text, false),
282 ColumnSchema::new("enabled", DataType::Bool, false),
283 ColumnSchema::new("last_received_pos", DataType::BigInt, false),
284 ];
285 let rows: Vec<Row> = self
286 .subscriptions
287 .iter()
288 .map(|(name, sub)| {
289 Row::new(alloc::vec![
290 Value::Text(name.clone()),
291 Value::Text(sub.conn_str.clone()),
292 Value::Text(sub.publications.join(", ")),
293 Value::Bool(sub.enabled),
294 Value::BigInt(i64::try_from(sub.last_received_pos).unwrap_or(i64::MAX)),
295 ])
296 })
297 .collect();
298 QueryResult::Rows { columns, rows }
299 }
300
301 pub(crate) fn exec_create_subscription(
306 &mut self,
307 s: CreateSubscriptionStatement,
308 ) -> Result<QueryResult, EngineError> {
309 let sub = Subscription {
313 conn_str: s.conn_str,
314 publications: s.publications,
315 enabled: true,
316 last_received_pos: 0,
317 };
318 self.subscriptions
319 .create(s.name, sub)
320 .map_err(|e| EngineError::Unsupported(alloc::format!("CREATE SUBSCRIPTION: {e:?}")))?;
321 Ok(QueryResult::CommandOk {
322 affected: 1,
323 modified_catalog: true,
324 })
325 }
326
327 pub(crate) fn exec_drop_subscription(
335 &mut self,
336 name: &str,
337 ) -> Result<QueryResult, EngineError> {
338 let removed = self.subscriptions.drop(name);
339 Ok(QueryResult::CommandOk {
340 affected: usize::from(removed),
341 modified_catalog: removed,
342 })
343 }
344
345 pub const fn subscriptions(&self) -> &Subscriptions {
350 &self.subscriptions
351 }
352
353 pub fn subscription_advance(&mut self, name: &str, pos: u64) -> bool {
359 self.subscriptions.update_last_received_pos(name, pos)
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(name, Subscription)` pair whose `conn_str` targets
    /// `127.0.0.1` on `port`.
    fn mk(
        name: &str,
        port: &str,
        pubs: &[&str],
        enabled: bool,
        pos: u64,
    ) -> (String, Subscription) {
        (
            name.to_string(),
            Subscription {
                conn_str: alloc::format!("host=127.0.0.1 port={port}"),
                publications: pubs.iter().map(|s| (*s).to_string()).collect(),
                enabled,
                last_received_pos: pos,
            },
        )
    }

    /// Appends one hand-encoded record in the layout written by
    /// `Subscriptions::serialize`. The record has an empty `conn_str`, no
    /// publications, the given raw `enabled` byte, and `last_received_pos = 0`.
    fn push_record(buf: &mut Vec<u8>, name: &[u8], enabled_byte: u8) {
        let name_len = u16::try_from(name.len()).unwrap();
        buf.extend_from_slice(&name_len.to_le_bytes());
        buf.extend_from_slice(name);
        buf.extend_from_slice(&0u32.to_le_bytes()); // conn_str length
        buf.extend_from_slice(&0u16.to_le_bytes()); // publication count
        buf.push(enabled_byte);
        buf.extend_from_slice(&0u64.to_le_bytes()); // last_received_pos
    }

    #[test]
    fn empty_roundtrips() {
        let s = Subscriptions::new();
        let bytes = s.serialize();
        assert_eq!(Subscriptions::deserialize(&bytes).unwrap(), s);
    }

    #[test]
    fn single_subscription_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        let s2 = Subscriptions::deserialize(&bytes).unwrap();
        assert_eq!(s2, s);
        assert!(s2.contains("sub_a"));
    }

    #[test]
    fn multi_publication_roundtrips_with_nontrivial_last_pos() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_z", "20002", &["p1", "p2", "p3"], true, 1_234_567_890);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2, s);
        let r = s2.get("sub_z").unwrap();
        assert_eq!(r.publications, alloc::vec!["p1", "p2", "p3"]);
        assert_eq!(r.last_received_pos, 1_234_567_890);
    }

    #[test]
    fn disabled_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_off", "20002", &["pub_a"], false, 42);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert!(!s2.get("sub_off").unwrap().enabled);
    }

    #[test]
    fn duplicate_name_errors() {
        let mut s = Subscriptions::new();
        let (n1, sub1) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n1, sub1).unwrap();
        let (n2, sub2) = mk("sub_a", "20003", &["pub_b"], true, 0);
        assert_eq!(
            s.create(n2, sub2).unwrap_err(),
            SubscriptionError::DuplicateName("sub_a".into())
        );
    }

    #[test]
    fn drop_present_and_absent() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        assert!(s.drop("sub_a"));
        assert!(!s.drop("sub_a"));
        assert!(!s.drop("never"));
    }

    #[test]
    fn update_last_pos_monotone_and_absent_returns_false() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 100);
        s.create(n, sub).unwrap();
        // A smaller position is accepted (returns true) but ignored.
        assert!(s.update_last_received_pos("sub_a", 50));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 100);
        assert!(s.update_last_received_pos("sub_a", 200));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 200);
        assert!(!s.update_last_received_pos("missing", 1));
    }

    #[test]
    fn corrupt_enabled_byte_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        push_record(&mut buf, b"bad", 2);
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn deterministic_order_independent_of_insert_sequence() {
        let mut s1 = Subscriptions::new();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s1.create(n, sub).unwrap();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s1.create(n, sub).unwrap();
        let mut s2 = Subscriptions::new();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s2.create(n, sub).unwrap();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s2.create(n, sub).unwrap();
        assert_eq!(s1.serialize(), s2.serialize());
    }

    #[test]
    fn trailing_bytes_error() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 7);
        s.create(n, sub).unwrap();
        let mut bytes = s.serialize();
        bytes.push(0);
        assert!(matches!(
            Subscriptions::deserialize(&bytes),
            Err(SubscriptionError::Corrupt(_))
        ));
    }

    #[test]
    fn every_truncation_errors() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["p1", "p2"], true, 9);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        for cut in 0..bytes.len() {
            assert!(
                matches!(
                    Subscriptions::deserialize(&bytes[..cut]),
                    Err(SubscriptionError::Corrupt(_))
                ),
                "prefix of {cut} bytes unexpectedly decoded"
            );
        }
    }

    #[test]
    fn duplicate_name_in_payload_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&2u16.to_le_bytes());
        push_record(&mut buf, b"dup", 1);
        push_record(&mut buf, b"dup", 1);
        assert!(matches!(
            Subscriptions::deserialize(&buf),
            Err(SubscriptionError::Corrupt(_))
        ));
    }

    #[test]
    fn non_utf8_name_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        push_record(&mut buf, &[0xff, 0xfe], 1);
        assert!(matches!(
            Subscriptions::deserialize(&buf),
            Err(SubscriptionError::Corrupt(_))
        ));
    }

    #[test]
    fn oversized_length_prefixes_are_short_reads() {
        // conn_str claims u32::MAX bytes but the payload ends right after.
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.push(b'a');
        buf.extend_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            Subscriptions::deserialize(&buf),
            Err(SubscriptionError::Corrupt(_))
        ));

        // Publication count claims u16::MAX entries but none follow.
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.push(b'a');
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&u16::MAX.to_le_bytes());
        assert!(matches!(
            Subscriptions::deserialize(&buf),
            Err(SubscriptionError::Corrupt(_))
        ));
    }
}