Skip to main content

spg_engine/
subscriptions.rs

1// pedantic doc_markdown flags the embedded wire-format spec block
2// and a handful of proper nouns; allowing at the module level
3// keeps the spec readable.
4#![allow(clippy::doc_markdown)]
5
6//! v6.1.4 — logical-replication subscription catalog.
7//!
8//! In-memory table of subscriptions, owned by the engine. The
9//! catalog persists across restarts via the snapshot envelope's
10//! v4 trailer block (see `crate::lib::build_envelope`) — same
11//! mechanism v6.1.2 added for publications, just an extra section.
12//!
13//! Subscriptions are the receive side of logical replication. A
14//! `CreateSubscription` row holds:
15//!   - `name`              the local identifier
16//!   - `conn_str`          PG keyword=value string the worker
17//!                         parses for `host=…` and `port=…`
18//!   - `publications`      list of remote publication names
19//!   - `enabled`           v6.1.4 hard-codes to `true`; ALTER
20//!                         SUBSCRIPTION ENABLE / DISABLE lands
21//!                         in a future sub-version
22//!   - `last_received_pos` master-WAL byte offset the worker has
23//!                         applied through (updated live by the
24//!                         worker, persisted at the next snapshot)
25//!
26//! The worker itself lives in `spg-server::replication::
27//! run_subscription_worker` — the engine layer only owns the
28//! catalog state, snapshots, and answers `SHOW SUBSCRIPTIONS`.
29
30use alloc::collections::BTreeMap;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
/// One row of the subscription catalog: the receive-side state of a
/// single logical-replication subscription.
///
/// The subscription's name is not stored in the row itself; it is the
/// key under which [`Subscriptions`] files the row.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Subscription {
    /// PG-style `keyword=value` connection string; the worker parses
    /// it for `host=…` and `port=…`.
    pub conn_str: String,
    /// Names of the remote publications this subscription follows.
    pub publications: Vec<String>,
    /// Whether the subscription is enabled. Serialised as a single
    /// byte that must be `0` or `1`.
    pub enabled: bool,
    /// Master-WAL byte offset the worker has applied through.
    pub last_received_pos: u64,
}
41
42#[derive(Debug, Clone, PartialEq, Eq, Default)]
43pub struct Subscriptions {
44    inner: BTreeMap<String, Subscription>,
45}
46
/// Failures reported by the subscription catalog.
#[derive(Debug, Eq, PartialEq)]
pub enum SubscriptionError {
    /// A subscription with this name already exists; carries the
    /// rejected name.
    DuplicateName(String),
    /// A serialised catalog payload could not be decoded; carries a
    /// human-readable description of what was wrong.
    Corrupt(String),
}
52
53impl Subscriptions {
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    pub fn len(&self) -> usize {
59        self.inner.len()
60    }
61
62    pub fn is_empty(&self) -> bool {
63        self.inner.is_empty()
64    }
65
66    pub fn contains(&self, name: &str) -> bool {
67        self.inner.contains_key(name)
68    }
69
70    pub fn get(&self, name: &str) -> Option<&Subscription> {
71        self.inner.get(name)
72    }
73
74    pub fn iter(&self) -> impl Iterator<Item = (&String, &Subscription)> {
75        self.inner.iter()
76    }
77
78    pub fn create(
79        &mut self,
80        name: String,
81        sub: Subscription,
82    ) -> Result<(), SubscriptionError> {
83        if self.inner.contains_key(&name) {
84            return Err(SubscriptionError::DuplicateName(name));
85        }
86        self.inner.insert(name, sub);
87        Ok(())
88    }
89
90    pub fn drop(&mut self, name: &str) -> bool {
91        self.inner.remove(name).is_some()
92    }
93
94    /// v6.1.4 — update the worker's last-applied master-WAL
95    /// offset. Called by the subscription worker after each apply
96    /// batch. Returns false when the subscription was dropped
97    /// between when the worker fetched the record and when this
98    /// call landed (so the worker can shut down cleanly).
99    pub fn update_last_received_pos(&mut self, name: &str, pos: u64) -> bool {
100        if let Some(s) = self.inner.get_mut(name) {
101            // Monotone: ignore stale updates (a future restart
102            // resuming from a sidecar may send an older pos than
103            // the live worker has already passed).
104            if pos > s.last_received_pos {
105                s.last_received_pos = pos;
106            }
107            true
108        } else {
109            false
110        }
111    }
112
113    // ── serialisation (envelope v4 trailer) ─────────────────────
114
115    /// Format:
116    ///   [u16 num_subscriptions]
117    ///   for each:
118    ///     [u16 name_len][name bytes]
119    ///     [u32 conn_str_len][conn_str bytes]
120    ///     [u16 num_pubs]
121    ///     for each: [u16 p_len][p bytes]
122    ///     [u8 enabled]
123    ///     [u64 last_received_pos]
124    pub fn serialize(&self) -> Vec<u8> {
125        let mut out = Vec::with_capacity(2 + self.inner.len() * 64);
126        let n = u16::try_from(self.inner.len()).expect("≤ 65,535 subscriptions per cluster");
127        out.extend_from_slice(&n.to_le_bytes());
128        for (name, sub) in &self.inner {
129            write_short_str(&mut out, name);
130            write_long_str(&mut out, &sub.conn_str);
131            let np =
132                u16::try_from(sub.publications.len()).expect("≤ 65,535 publications per subscription");
133            out.extend_from_slice(&np.to_le_bytes());
134            for p in &sub.publications {
135                write_short_str(&mut out, p);
136            }
137            out.push(u8::from(sub.enabled));
138            out.extend_from_slice(&sub.last_received_pos.to_le_bytes());
139        }
140        out
141    }
142
143    pub fn deserialize(buf: &[u8]) -> Result<Self, SubscriptionError> {
144        let mut p = 0usize;
145        let n = read_u16(buf, &mut p)? as usize;
146        let mut inner = BTreeMap::new();
147        for _ in 0..n {
148            let name = read_short_str(buf, &mut p)?;
149            let conn_str = read_long_str(buf, &mut p)?;
150            let np = read_u16(buf, &mut p)? as usize;
151            let mut publications = Vec::with_capacity(np);
152            for _ in 0..np {
153                publications.push(read_short_str(buf, &mut p)?);
154            }
155            let enabled_byte = read_u8(buf, &mut p)?;
156            let enabled = match enabled_byte {
157                0 => false,
158                1 => true,
159                other => {
160                    return Err(SubscriptionError::Corrupt(alloc::format!(
161                        "invalid `enabled` byte {other}, expected 0 or 1"
162                    )));
163                }
164            };
165            let last_received_pos = read_u64(buf, &mut p)?;
166            if inner
167                .insert(
168                    name.clone(),
169                    Subscription {
170                        conn_str,
171                        publications,
172                        enabled,
173                        last_received_pos,
174                    },
175                )
176                .is_some()
177            {
178                return Err(SubscriptionError::Corrupt(alloc::format!(
179                    "duplicate subscription name {name:?} in serialised payload"
180                )));
181            }
182        }
183        if p != buf.len() {
184            return Err(SubscriptionError::Corrupt(alloc::format!(
185                "trailing bytes in subscriptions payload: read {p}, len {}",
186                buf.len()
187            )));
188        }
189        Ok(Self { inner })
190    }
191}
192
/// Appends `s` to `out` as a little-endian `u16` byte length followed
/// by the UTF-8 bytes of `s`.
///
/// # Panics
///
/// Panics when `s` is longer than `u16::MAX` bytes.
fn write_short_str(out: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    let prefix: [u8; 2] = u16::try_from(payload.len())
        .expect("subscription / publication name fits in u16")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(payload);
}
198
/// Appends `s` to `out` as a little-endian `u32` byte length followed
/// by the UTF-8 bytes of `s`.
///
/// # Panics
///
/// Panics when `s` is longer than `u32::MAX` bytes.
fn write_long_str(out: &mut Vec<u8>, s: &str) {
    // A conn_str is normally a few hundred bytes at most; the u32
    // prefix simply leaves headroom.
    let payload = s.as_bytes();
    let prefix: [u8; 4] = u32::try_from(payload.len())
        .expect("conn_str fits in u32")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(payload);
}
205
206fn read_u8(buf: &[u8], p: &mut usize) -> Result<u8, SubscriptionError> {
207    let v = buf
208        .get(*p)
209        .copied()
210        .ok_or_else(|| SubscriptionError::Corrupt("short read (u8)".to_string()))?;
211    *p += 1;
212    Ok(v)
213}
214
215fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, SubscriptionError> {
216    let slice = buf
217        .get(*p..*p + 2)
218        .ok_or_else(|| SubscriptionError::Corrupt("short read (u16)".to_string()))?;
219    let arr: [u8; 2] = slice
220        .try_into()
221        .map_err(|_| SubscriptionError::Corrupt("u16 slice".to_string()))?;
222    *p += 2;
223    Ok(u16::from_le_bytes(arr))
224}
225
226fn read_u32_as_usize(buf: &[u8], p: &mut usize) -> Result<usize, SubscriptionError> {
227    let slice = buf
228        .get(*p..*p + 4)
229        .ok_or_else(|| SubscriptionError::Corrupt("short read (u32)".to_string()))?;
230    let arr: [u8; 4] = slice
231        .try_into()
232        .map_err(|_| SubscriptionError::Corrupt("u32 slice".to_string()))?;
233    *p += 4;
234    Ok(u32::from_le_bytes(arr) as usize)
235}
236
237fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, SubscriptionError> {
238    let slice = buf
239        .get(*p..*p + 8)
240        .ok_or_else(|| SubscriptionError::Corrupt("short read (u64)".to_string()))?;
241    let arr: [u8; 8] = slice
242        .try_into()
243        .map_err(|_| SubscriptionError::Corrupt("u64 slice".to_string()))?;
244    *p += 8;
245    Ok(u64::from_le_bytes(arr))
246}
247
248fn read_short_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
249    let n = read_u16(buf, p)? as usize;
250    let slice = buf.get(*p..*p + n).ok_or_else(|| {
251        SubscriptionError::Corrupt(alloc::format!("short read (short str, {n} bytes)"))
252    })?;
253    *p += n;
254    core::str::from_utf8(slice)
255        .map(ToString::to_string)
256        .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 str: {e}")))
257}
258
259fn read_long_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
260    let n = read_u32_as_usize(buf, p)?;
261    let slice = buf.get(*p..*p + n).ok_or_else(|| {
262        SubscriptionError::Corrupt(alloc::format!("short read (long str, {n} bytes)"))
263    })?;
264    *p += n;
265    core::str::from_utf8(slice)
266        .map(ToString::to_string)
267        .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 conn_str: {e}")))
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(name, Subscription)` pair whose connection string
    /// targets 127.0.0.1 on `port`.
    fn mk(name: &str, port: &str, pubs: &[&str], enabled: bool, pos: u64) -> (String, Subscription) {
        (
            name.to_string(),
            Subscription {
                conn_str: alloc::format!("host=127.0.0.1 port={port}"),
                publications: pubs.iter().map(|s| (*s).to_string()).collect(),
                enabled,
                last_received_pos: pos,
            },
        )
    }

    /// Serialises a catalog that holds exactly one subscription.
    fn single_sub_bytes(name: &str) -> Vec<u8> {
        let mut s = Subscriptions::new();
        let (n, sub) = mk(name, "20002", &["pub_a"], true, 7);
        s.create(n, sub).unwrap();
        s.serialize()
    }

    #[test]
    fn empty_roundtrips() {
        let s = Subscriptions::new();
        let bytes = s.serialize();
        assert_eq!(Subscriptions::deserialize(&bytes).unwrap(), s);
    }

    #[test]
    fn single_subscription_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        let s2 = Subscriptions::deserialize(&bytes).unwrap();
        assert_eq!(s2, s);
        assert!(s2.contains("sub_a"));
    }

    #[test]
    fn multi_publication_roundtrips_with_nontrivial_last_pos() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_z", "20002", &["p1", "p2", "p3"], true, 1_234_567_890);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2, s);
        let r = s2.get("sub_z").unwrap();
        assert_eq!(r.publications, alloc::vec!["p1", "p2", "p3"]);
        assert_eq!(r.last_received_pos, 1_234_567_890);
    }

    #[test]
    fn disabled_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_off", "20002", &["pub_a"], false, 42);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert!(!s2.get("sub_off").unwrap().enabled);
    }

    #[test]
    fn duplicate_name_errors() {
        let mut s = Subscriptions::new();
        let (n1, sub1) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n1, sub1).unwrap();
        let (n2, sub2) = mk("sub_a", "20003", &["pub_b"], true, 0);
        assert_eq!(
            s.create(n2, sub2).unwrap_err(),
            SubscriptionError::DuplicateName("sub_a".into())
        );
    }

    #[test]
    fn drop_present_and_absent() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        assert!(s.drop("sub_a"));
        assert!(!s.drop("sub_a"));
        assert!(!s.drop("never"));
    }

    #[test]
    fn update_last_pos_monotone_and_absent_returns_false() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 100);
        s.create(n, sub).unwrap();
        assert!(s.update_last_received_pos("sub_a", 50)); // ignored (older)
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 100);
        assert!(s.update_last_received_pos("sub_a", 200));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 200);
        assert!(!s.update_last_received_pos("missing", 1));
    }

    #[test]
    fn corrupt_enabled_byte_errors() {
        // Forge a payload with an invalid enabled byte (2).
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // n = 1
        // name
        buf.extend_from_slice(&3u16.to_le_bytes());
        buf.extend_from_slice(b"bad");
        // conn_str
        buf.extend_from_slice(&0u32.to_le_bytes()); // empty
        // pubs (zero)
        buf.extend_from_slice(&0u16.to_le_bytes());
        // bogus enabled
        buf.push(2);
        // last_received_pos
        buf.extend_from_slice(&0u64.to_le_bytes());
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn deterministic_order_independent_of_insert_sequence() {
        let mut s1 = Subscriptions::new();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s1.create(n, sub).unwrap();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s1.create(n, sub).unwrap();
        let mut s2 = Subscriptions::new();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s2.create(n, sub).unwrap();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s2.create(n, sub).unwrap();
        assert_eq!(s1.serialize(), s2.serialize());
    }

    #[test]
    fn every_truncation_of_a_valid_payload_errors() {
        // A record is only complete once its final byte is present, so
        // every strict prefix must hit a short read somewhere.
        let bytes = single_sub_bytes("sub_a");
        for cut in 0..bytes.len() {
            let err = Subscriptions::deserialize(&bytes[..cut]).unwrap_err();
            assert!(matches!(err, SubscriptionError::Corrupt(_)), "cut at {cut}");
        }
    }

    #[test]
    fn trailing_bytes_error() {
        let mut bytes = Subscriptions::new().serialize();
        bytes.push(0);
        let err = Subscriptions::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn duplicate_name_in_payload_errors() {
        // Reuse one encoded record twice under a count of 2.
        let bytes = single_sub_bytes("dup");
        let record = &bytes[2..];
        let mut forged = Vec::new();
        forged.extend_from_slice(&2u16.to_le_bytes());
        forged.extend_from_slice(record);
        forged.extend_from_slice(record);
        let err = Subscriptions::deserialize(&forged).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn non_utf8_name_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // n = 1
        buf.extend_from_slice(&1u16.to_le_bytes()); // name_len = 1
        buf.push(0xFF); // never valid in UTF-8
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }
}