// reddb_wire/replication/wal_stream.rs

1use serde_json::Value as JsonValue;
2
3use super::util::{
4    get_bool_default, get_opt_string, get_opt_u64, get_string, get_u64, hex_decode, hex_encode,
5    object_from_slice, ReplicationPayloadError, Result,
6};
7use super::{catchup::CatchupModeReply, DEFAULT_REPLICATION_TERM};
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct WalStreamOpen {
11    pub since_lsn: u64,
12    pub max_count: usize,
13    pub replica_id: Option<String>,
14    pub term: u64,
15    pub await_data: bool,
16    pub await_timeout_ms: u64,
17}
18
19impl WalStreamOpen {
20    pub fn encode_json(&self) -> Vec<u8> {
21        let mut obj = serde_json::Map::new();
22        obj.insert(
23            "since_lsn".to_string(),
24            JsonValue::Number(self.since_lsn.into()),
25        );
26        obj.insert(
27            "max_count".to_string(),
28            JsonValue::Number((self.max_count as u64).into()),
29        );
30        if let Some(replica_id) = &self.replica_id {
31            obj.insert(
32                "replica_id".to_string(),
33                JsonValue::String(replica_id.clone()),
34            );
35        }
36        obj.insert("term".to_string(), JsonValue::Number(self.term.into()));
37        obj.insert("await_data".to_string(), JsonValue::Bool(self.await_data));
38        obj.insert(
39            "await_timeout_ms".to_string(),
40            JsonValue::Number(self.await_timeout_ms.into()),
41        );
42        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
43    }
44
45    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
46        let obj = object_from_slice(bytes)?;
47        let max_count = get_opt_u64(&obj, "max_count").unwrap_or(1000);
48        Ok(Self {
49            since_lsn: get_opt_u64(&obj, "since_lsn").unwrap_or(0),
50            max_count: usize::try_from(max_count)
51                .map_err(|_| ReplicationPayloadError::InvalidField("max_count"))?,
52            replica_id: get_opt_string(&obj, "replica_id"),
53            term: get_opt_u64(&obj, "term").unwrap_or(DEFAULT_REPLICATION_TERM),
54            await_data: get_bool_default(&obj, "await_data", false),
55            await_timeout_ms: get_opt_u64(&obj, "await_timeout_ms").unwrap_or(30_000),
56        })
57    }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct WalStreamRecord {
62    pub lsn: u64,
63    pub data: Vec<u8>,
64}
65
66impl WalStreamRecord {
67    fn to_json(&self) -> JsonValue {
68        let mut obj = serde_json::Map::new();
69        obj.insert("lsn".to_string(), JsonValue::Number(self.lsn.into()));
70        obj.insert(
71            "data".to_string(),
72            JsonValue::String(hex_encode(&self.data)),
73        );
74        JsonValue::Object(obj)
75    }
76
77    fn from_json(value: &JsonValue) -> Result<Self> {
78        let obj = value
79            .as_object()
80            .ok_or(ReplicationPayloadError::InvalidField("records"))?;
81        let data_hex = get_string(obj, "data")?;
82        Ok(Self {
83            lsn: get_u64(obj, "lsn")?,
84            data: hex_decode("data", &data_hex)?,
85        })
86    }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq)]
90pub struct WalStreamChunk {
91    pub records: Vec<WalStreamRecord>,
92    pub current_lsn: u64,
93    pub oldest_available_lsn: Option<u64>,
94    pub partial_resync: bool,
95    pub partial_resync_count: u64,
96    pub needs_rebootstrap: bool,
97    pub invalidation_reason: Option<String>,
98    pub catchup: Option<CatchupModeReply>,
99}
100
101impl WalStreamChunk {
102    pub fn encode_json(&self) -> Vec<u8> {
103        let mut obj = serde_json::Map::new();
104        obj.insert(
105            "records".to_string(),
106            JsonValue::Array(self.records.iter().map(WalStreamRecord::to_json).collect()),
107        );
108        obj.insert(
109            "current_lsn".to_string(),
110            JsonValue::Number(self.current_lsn.into()),
111        );
112        if let Some(lsn) = self.oldest_available_lsn {
113            obj.insert(
114                "oldest_available_lsn".to_string(),
115                JsonValue::Number(lsn.into()),
116            );
117        }
118        obj.insert(
119            "partial_resync".to_string(),
120            JsonValue::Bool(self.partial_resync),
121        );
122        obj.insert(
123            "partial_resync_count".to_string(),
124            JsonValue::Number(self.partial_resync_count.into()),
125        );
126        obj.insert(
127            "needs_rebootstrap".to_string(),
128            JsonValue::Bool(self.needs_rebootstrap),
129        );
130        if let Some(reason) = &self.invalidation_reason {
131            obj.insert(
132                "invalidation_reason".to_string(),
133                JsonValue::String(reason.clone()),
134            );
135        }
136        if let Some(catchup) = &self.catchup {
137            let catchup_obj = object_from_slice(&catchup.encode_json())
138                .expect("CatchupModeReply emits a JSON object");
139            for (key, value) in catchup_obj {
140                obj.insert(key, value);
141            }
142        }
143        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
144    }
145
146    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
147        let obj = object_from_slice(bytes)?;
148        let records = match obj.get("records") {
149            Some(JsonValue::Array(values)) => values
150                .iter()
151                .map(WalStreamRecord::from_json)
152                .collect::<Result<Vec<_>>>()?,
153            Some(_) => return Err(ReplicationPayloadError::InvalidField("records")),
154            None => Vec::new(),
155        };
156        Ok(Self {
157            records,
158            current_lsn: get_u64(&obj, "current_lsn")?,
159            oldest_available_lsn: get_opt_u64(&obj, "oldest_available_lsn"),
160            partial_resync: get_bool_default(&obj, "partial_resync", false),
161            partial_resync_count: get_opt_u64(&obj, "partial_resync_count").unwrap_or(0),
162            needs_rebootstrap: get_bool_default(&obj, "needs_rebootstrap", false),
163            invalidation_reason: get_opt_string(&obj, "invalidation_reason"),
164            catchup: CatchupModeReply::from_wal_rebootstrap_object(&obj)?,
165        })
166    }
167}
168
169#[derive(Debug, Clone, PartialEq, Eq)]
170pub struct WalStreamAck {
171    pub replica_id: String,
172    pub applied_lsn: u64,
173    pub durable_lsn: u64,
174    pub apply_errors_total: u64,
175    pub divergence_total: u64,
176}
177
178impl WalStreamAck {
179    pub fn encode_json(&self) -> Vec<u8> {
180        let mut obj = serde_json::Map::new();
181        obj.insert(
182            "replica_id".to_string(),
183            JsonValue::String(self.replica_id.clone()),
184        );
185        obj.insert(
186            "applied_lsn".to_string(),
187            JsonValue::Number(self.applied_lsn.into()),
188        );
189        obj.insert(
190            "durable_lsn".to_string(),
191            JsonValue::Number(self.durable_lsn.into()),
192        );
193        obj.insert(
194            "apply_errors_total".to_string(),
195            JsonValue::Number(self.apply_errors_total.into()),
196        );
197        obj.insert(
198            "divergence_total".to_string(),
199            JsonValue::Number(self.divergence_total.into()),
200        );
201        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
202    }
203
204    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
205        let obj = object_from_slice(bytes)?;
206        let applied_lsn = get_u64(&obj, "applied_lsn")?;
207        Ok(Self {
208            replica_id: get_string(&obj, "replica_id")?,
209            applied_lsn,
210            durable_lsn: get_opt_u64(&obj, "durable_lsn").unwrap_or(applied_lsn),
211            apply_errors_total: get_opt_u64(&obj, "apply_errors_total").unwrap_or(0),
212            divergence_total: get_opt_u64(&obj, "divergence_total").unwrap_or(0),
213        })
214    }
215}
216
217#[derive(Debug, Clone, PartialEq, Eq)]
218pub struct WalStreamAckReply {
219    pub ok: bool,
220    pub replica_id: String,
221    pub applied_lsn: u64,
222    pub durable_lsn: u64,
223    pub apply_errors_total: u64,
224    pub divergence_total: u64,
225}
226
227impl WalStreamAckReply {
228    pub fn from_ack(ack: &WalStreamAck) -> Self {
229        Self {
230            ok: true,
231            replica_id: ack.replica_id.clone(),
232            applied_lsn: ack.applied_lsn,
233            durable_lsn: ack.durable_lsn,
234            apply_errors_total: ack.apply_errors_total,
235            divergence_total: ack.divergence_total,
236        }
237    }
238
239    pub fn encode_json(&self) -> Vec<u8> {
240        let mut obj = serde_json::Map::new();
241        obj.insert("ok".to_string(), JsonValue::Bool(self.ok));
242        obj.insert(
243            "replica_id".to_string(),
244            JsonValue::String(self.replica_id.clone()),
245        );
246        obj.insert(
247            "applied_lsn".to_string(),
248            JsonValue::Number(self.applied_lsn.into()),
249        );
250        obj.insert(
251            "durable_lsn".to_string(),
252            JsonValue::Number(self.durable_lsn.into()),
253        );
254        obj.insert(
255            "apply_errors_total".to_string(),
256            JsonValue::Number(self.apply_errors_total.into()),
257        );
258        obj.insert(
259            "divergence_total".to_string(),
260            JsonValue::Number(self.divergence_total.into()),
261        );
262        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
263    }
264
265    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
266        let obj = object_from_slice(bytes)?;
267        Ok(Self {
268            ok: get_bool_default(&obj, "ok", false),
269            replica_id: get_string(&obj, "replica_id")?,
270            applied_lsn: get_u64(&obj, "applied_lsn")?,
271            durable_lsn: get_u64(&obj, "durable_lsn")?,
272            apply_errors_total: get_opt_u64(&obj, "apply_errors_total").unwrap_or(0),
273            divergence_total: get_opt_u64(&obj, "divergence_total").unwrap_or(0),
274        })
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replication::CatchupMode;

    #[test]
    fn wal_stream_open_round_trips() {
        let open = WalStreamOpen {
            since_lsn: 10,
            max_count: 128,
            replica_id: Some("replica-a".to_string()),
            term: 7,
            await_data: true,
            await_timeout_ms: 5000,
        };
        assert_eq!(
            WalStreamOpen::decode_json(&open.encode_json()).unwrap(),
            open
        );
    }

    #[test]
    fn wal_stream_open_round_trips_without_replica_id() {
        let open = WalStreamOpen {
            since_lsn: 0,
            max_count: 1,
            replica_id: None,
            term: 3,
            await_data: false,
            await_timeout_ms: 0,
        };
        let encoded = open.encode_json();
        // `None` must be omitted from the wire rather than written as `null`.
        let json: JsonValue = serde_json::from_slice(&encoded).unwrap();
        assert!(json.get("replica_id").is_none());
        assert_eq!(WalStreamOpen::decode_json(&encoded).unwrap(), open);
    }

    #[test]
    fn wal_stream_open_defaults_missing_term_to_legacy_term() {
        let open = WalStreamOpen::decode_json(
            br#"{"since_lsn":10,"max_count":128,"replica_id":"replica-a"}"#,
        )
        .unwrap();
        assert_eq!(open.term, DEFAULT_REPLICATION_TERM);
    }

    #[test]
    fn wal_stream_open_applies_defaults_for_missing_fields() {
        let open = WalStreamOpen::decode_json(br#"{}"#).unwrap();
        assert_eq!(
            open,
            WalStreamOpen {
                since_lsn: 0,
                max_count: 1000,
                replica_id: None,
                term: DEFAULT_REPLICATION_TERM,
                await_data: false,
                await_timeout_ms: 30_000,
            }
        );
    }

    #[test]
    fn wal_stream_chunk_round_trips_records_and_rebootstrap_hint() {
        let chunk = WalStreamChunk {
            records: vec![WalStreamRecord {
                lsn: 11,
                data: b"record".to_vec(),
            }],
            current_lsn: 12,
            oldest_available_lsn: Some(9),
            partial_resync: true,
            partial_resync_count: 3,
            needs_rebootstrap: true,
            invalidation_reason: Some("retention".to_string()),
            catchup: Some(CatchupModeReply {
                mode: CatchupMode::BaseBackupThenWal,
                available_from_lsn: Some(9),
                replica_lsn: None,
                reason: Some("retention".to_string()),
            }),
        };
        assert_eq!(
            WalStreamChunk::decode_json(&chunk.encode_json()).unwrap(),
            chunk
        );
    }

    #[test]
    fn wal_stream_chunk_rejects_non_array_records() {
        let err = WalStreamChunk::decode_json(br#"{"records":"nope","current_lsn":1}"#)
            .unwrap_err();
        assert!(matches!(
            err,
            ReplicationPayloadError::InvalidField("records")
        ));
    }

    #[test]
    fn wal_stream_chunk_rejects_non_object_record() {
        let err =
            WalStreamChunk::decode_json(br#"{"records":[1],"current_lsn":1}"#).unwrap_err();
        assert!(matches!(
            err,
            ReplicationPayloadError::InvalidField("records")
        ));
    }

    #[test]
    fn wal_stream_chunk_rejects_invalid_hex_record_data() {
        assert!(WalStreamChunk::decode_json(
            br#"{"records":[{"lsn":1,"data":"zz"}],"current_lsn":1}"#
        )
        .is_err());
    }

    #[test]
    fn wal_stream_chunk_requires_current_lsn() {
        assert!(WalStreamChunk::decode_json(br#"{"records":[]}"#).is_err());
    }

    #[test]
    fn wal_ack_defaults_durable_to_applied() {
        let ack = WalStreamAck::decode_json(br#"{"replica_id":"r","applied_lsn":7}"#).unwrap();
        assert_eq!(ack.durable_lsn, 7);
        assert_eq!(ack.apply_errors_total, 0);
    }

    #[test]
    fn wal_ack_requires_replica_id() {
        assert!(WalStreamAck::decode_json(br#"{"applied_lsn":7}"#).is_err());
    }

    #[test]
    fn wal_ack_reply_round_trips() {
        let ack = WalStreamAck {
            replica_id: "replica-a".to_string(),
            applied_lsn: 11,
            durable_lsn: 10,
            apply_errors_total: 2,
            divergence_total: 1,
        };
        let reply = WalStreamAckReply::from_ack(&ack);

        assert_eq!(
            WalStreamAckReply::decode_json(&reply.encode_json()).unwrap(),
            reply
        );
    }

    #[test]
    fn wal_ack_reply_defaults_ok_to_false() {
        let reply = WalStreamAckReply::decode_json(
            br#"{"replica_id":"r","applied_lsn":3,"durable_lsn":2}"#,
        )
        .unwrap();
        assert!(!reply.ok);
        assert_eq!(reply.durable_lsn, 2);
        assert_eq!(reply.divergence_total, 0);
    }
}
357}