Skip to main content

reddb_wire/replication/
wal_stream.rs

1use serde_json::Value as JsonValue;
2
3use super::catchup::CatchupModeReply;
4use super::util::{
5    get_bool_default, get_opt_string, get_opt_u64, get_string, get_u64, hex_decode, hex_encode,
6    object_from_slice, ReplicationPayloadError, Result,
7};
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct WalStreamOpen {
11    pub since_lsn: u64,
12    pub max_count: usize,
13    pub replica_id: Option<String>,
14    pub await_data: bool,
15    pub await_timeout_ms: u64,
16}
17
18impl WalStreamOpen {
19    pub fn encode_json(&self) -> Vec<u8> {
20        let mut obj = serde_json::Map::new();
21        obj.insert(
22            "since_lsn".to_string(),
23            JsonValue::Number(self.since_lsn.into()),
24        );
25        obj.insert(
26            "max_count".to_string(),
27            JsonValue::Number((self.max_count as u64).into()),
28        );
29        if let Some(replica_id) = &self.replica_id {
30            obj.insert(
31                "replica_id".to_string(),
32                JsonValue::String(replica_id.clone()),
33            );
34        }
35        obj.insert("await_data".to_string(), JsonValue::Bool(self.await_data));
36        obj.insert(
37            "await_timeout_ms".to_string(),
38            JsonValue::Number(self.await_timeout_ms.into()),
39        );
40        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
41    }
42
43    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
44        let obj = object_from_slice(bytes)?;
45        let max_count = get_opt_u64(&obj, "max_count").unwrap_or(1000);
46        Ok(Self {
47            since_lsn: get_opt_u64(&obj, "since_lsn").unwrap_or(0),
48            max_count: usize::try_from(max_count)
49                .map_err(|_| ReplicationPayloadError::InvalidField("max_count"))?,
50            replica_id: get_opt_string(&obj, "replica_id"),
51            await_data: get_bool_default(&obj, "await_data", false),
52            await_timeout_ms: get_opt_u64(&obj, "await_timeout_ms").unwrap_or(30_000),
53        })
54    }
55}
56
57#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct WalStreamRecord {
59    pub lsn: u64,
60    pub data: Vec<u8>,
61}
62
63impl WalStreamRecord {
64    fn to_json(&self) -> JsonValue {
65        let mut obj = serde_json::Map::new();
66        obj.insert("lsn".to_string(), JsonValue::Number(self.lsn.into()));
67        obj.insert(
68            "data".to_string(),
69            JsonValue::String(hex_encode(&self.data)),
70        );
71        JsonValue::Object(obj)
72    }
73
74    fn from_json(value: &JsonValue) -> Result<Self> {
75        let obj = value
76            .as_object()
77            .ok_or(ReplicationPayloadError::InvalidField("records"))?;
78        let data_hex = get_string(obj, "data")?;
79        Ok(Self {
80            lsn: get_u64(obj, "lsn")?,
81            data: hex_decode("data", &data_hex)?,
82        })
83    }
84}
85
86#[derive(Debug, Clone, PartialEq, Eq)]
87pub struct WalStreamChunk {
88    pub records: Vec<WalStreamRecord>,
89    pub current_lsn: u64,
90    pub oldest_available_lsn: Option<u64>,
91    pub partial_resync: bool,
92    pub partial_resync_count: u64,
93    pub needs_rebootstrap: bool,
94    pub invalidation_reason: Option<String>,
95    pub catchup: Option<CatchupModeReply>,
96}
97
98impl WalStreamChunk {
99    pub fn encode_json(&self) -> Vec<u8> {
100        let mut obj = serde_json::Map::new();
101        obj.insert(
102            "records".to_string(),
103            JsonValue::Array(self.records.iter().map(WalStreamRecord::to_json).collect()),
104        );
105        obj.insert(
106            "current_lsn".to_string(),
107            JsonValue::Number(self.current_lsn.into()),
108        );
109        if let Some(lsn) = self.oldest_available_lsn {
110            obj.insert(
111                "oldest_available_lsn".to_string(),
112                JsonValue::Number(lsn.into()),
113            );
114        }
115        obj.insert(
116            "partial_resync".to_string(),
117            JsonValue::Bool(self.partial_resync),
118        );
119        obj.insert(
120            "partial_resync_count".to_string(),
121            JsonValue::Number(self.partial_resync_count.into()),
122        );
123        obj.insert(
124            "needs_rebootstrap".to_string(),
125            JsonValue::Bool(self.needs_rebootstrap),
126        );
127        if let Some(reason) = &self.invalidation_reason {
128            obj.insert(
129                "invalidation_reason".to_string(),
130                JsonValue::String(reason.clone()),
131            );
132        }
133        if let Some(catchup) = &self.catchup {
134            let catchup_obj = object_from_slice(&catchup.encode_json())
135                .expect("CatchupModeReply emits a JSON object");
136            for (key, value) in catchup_obj {
137                obj.insert(key, value);
138            }
139        }
140        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
141    }
142
143    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
144        let obj = object_from_slice(bytes)?;
145        let records = match obj.get("records") {
146            Some(JsonValue::Array(values)) => values
147                .iter()
148                .map(WalStreamRecord::from_json)
149                .collect::<Result<Vec<_>>>()?,
150            Some(_) => return Err(ReplicationPayloadError::InvalidField("records")),
151            None => Vec::new(),
152        };
153        Ok(Self {
154            records,
155            current_lsn: get_u64(&obj, "current_lsn")?,
156            oldest_available_lsn: get_opt_u64(&obj, "oldest_available_lsn"),
157            partial_resync: get_bool_default(&obj, "partial_resync", false),
158            partial_resync_count: get_opt_u64(&obj, "partial_resync_count").unwrap_or(0),
159            needs_rebootstrap: get_bool_default(&obj, "needs_rebootstrap", false),
160            invalidation_reason: get_opt_string(&obj, "invalidation_reason"),
161            catchup: CatchupModeReply::from_wal_rebootstrap_object(&obj)?,
162        })
163    }
164}
165
166#[derive(Debug, Clone, PartialEq, Eq)]
167pub struct WalStreamAck {
168    pub replica_id: String,
169    pub applied_lsn: u64,
170    pub durable_lsn: u64,
171    pub apply_errors_total: u64,
172    pub divergence_total: u64,
173}
174
175impl WalStreamAck {
176    pub fn encode_json(&self) -> Vec<u8> {
177        let mut obj = serde_json::Map::new();
178        obj.insert(
179            "replica_id".to_string(),
180            JsonValue::String(self.replica_id.clone()),
181        );
182        obj.insert(
183            "applied_lsn".to_string(),
184            JsonValue::Number(self.applied_lsn.into()),
185        );
186        obj.insert(
187            "durable_lsn".to_string(),
188            JsonValue::Number(self.durable_lsn.into()),
189        );
190        obj.insert(
191            "apply_errors_total".to_string(),
192            JsonValue::Number(self.apply_errors_total.into()),
193        );
194        obj.insert(
195            "divergence_total".to_string(),
196            JsonValue::Number(self.divergence_total.into()),
197        );
198        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
199    }
200
201    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
202        let obj = object_from_slice(bytes)?;
203        let applied_lsn = get_u64(&obj, "applied_lsn")?;
204        Ok(Self {
205            replica_id: get_string(&obj, "replica_id")?,
206            applied_lsn,
207            durable_lsn: get_opt_u64(&obj, "durable_lsn").unwrap_or(applied_lsn),
208            apply_errors_total: get_opt_u64(&obj, "apply_errors_total").unwrap_or(0),
209            divergence_total: get_opt_u64(&obj, "divergence_total").unwrap_or(0),
210        })
211    }
212}
213
214#[derive(Debug, Clone, PartialEq, Eq)]
215pub struct WalStreamAckReply {
216    pub ok: bool,
217    pub replica_id: String,
218    pub applied_lsn: u64,
219    pub durable_lsn: u64,
220    pub apply_errors_total: u64,
221    pub divergence_total: u64,
222}
223
224impl WalStreamAckReply {
225    pub fn from_ack(ack: &WalStreamAck) -> Self {
226        Self {
227            ok: true,
228            replica_id: ack.replica_id.clone(),
229            applied_lsn: ack.applied_lsn,
230            durable_lsn: ack.durable_lsn,
231            apply_errors_total: ack.apply_errors_total,
232            divergence_total: ack.divergence_total,
233        }
234    }
235
236    pub fn encode_json(&self) -> Vec<u8> {
237        let mut obj = serde_json::Map::new();
238        obj.insert("ok".to_string(), JsonValue::Bool(self.ok));
239        obj.insert(
240            "replica_id".to_string(),
241            JsonValue::String(self.replica_id.clone()),
242        );
243        obj.insert(
244            "applied_lsn".to_string(),
245            JsonValue::Number(self.applied_lsn.into()),
246        );
247        obj.insert(
248            "durable_lsn".to_string(),
249            JsonValue::Number(self.durable_lsn.into()),
250        );
251        obj.insert(
252            "apply_errors_total".to_string(),
253            JsonValue::Number(self.apply_errors_total.into()),
254        );
255        obj.insert(
256            "divergence_total".to_string(),
257            JsonValue::Number(self.divergence_total.into()),
258        );
259        serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
260    }
261
262    pub fn decode_json(bytes: &[u8]) -> Result<Self> {
263        let obj = object_from_slice(bytes)?;
264        Ok(Self {
265            ok: get_bool_default(&obj, "ok", false),
266            replica_id: get_string(&obj, "replica_id")?,
267            applied_lsn: get_u64(&obj, "applied_lsn")?,
268            durable_lsn: get_u64(&obj, "durable_lsn")?,
269            apply_errors_total: get_opt_u64(&obj, "apply_errors_total").unwrap_or(0),
270            divergence_total: get_opt_u64(&obj, "divergence_total").unwrap_or(0),
271        })
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replication::CatchupMode;

    /// An open message survives an encode/decode cycle unchanged.
    #[test]
    fn wal_stream_open_round_trips() {
        let open = WalStreamOpen {
            since_lsn: 10,
            max_count: 128,
            replica_id: Some(String::from("replica-a")),
            await_data: true,
            await_timeout_ms: 5000,
        };

        let wire = open.encode_json();
        let decoded = WalStreamOpen::decode_json(&wire).unwrap();

        assert_eq!(decoded, open);
    }

    /// A chunk with one record and a catch-up reply survives an
    /// encode/decode cycle unchanged.
    #[test]
    fn wal_stream_chunk_round_trips_records_and_rebootstrap_hint() {
        let record = WalStreamRecord {
            lsn: 11,
            data: b"record".to_vec(),
        };
        let catchup = CatchupModeReply {
            mode: CatchupMode::BaseBackupThenWal,
            available_from_lsn: Some(9),
            replica_lsn: None,
            reason: Some(String::from("retention")),
        };
        let chunk = WalStreamChunk {
            records: vec![record],
            current_lsn: 12,
            oldest_available_lsn: Some(9),
            partial_resync: true,
            partial_resync_count: 3,
            needs_rebootstrap: true,
            invalidation_reason: Some(String::from("retention")),
            catchup: Some(catchup),
        };

        let wire = chunk.encode_json();
        let decoded = WalStreamChunk::decode_json(&wire).unwrap();

        assert_eq!(decoded, chunk);
    }

    /// Omitting `durable_lsn` makes it fall back to `applied_lsn`, and the
    /// omitted error counter reads as zero.
    #[test]
    fn wal_ack_defaults_durable_to_applied() {
        let payload = br#"{"replica_id":"r","applied_lsn":7}"#;

        let ack = WalStreamAck::decode_json(payload).unwrap();

        assert_eq!(ack.durable_lsn, 7);
        assert_eq!(ack.apply_errors_total, 0);
    }

    /// A reply derived from an acknowledgement survives an encode/decode
    /// cycle unchanged.
    #[test]
    fn wal_ack_reply_round_trips() {
        let ack = WalStreamAck {
            replica_id: String::from("replica-a"),
            applied_lsn: 11,
            durable_lsn: 10,
            apply_errors_total: 2,
            divergence_total: 1,
        };
        let reply = WalStreamAckReply::from_ack(&ack);

        let wire = reply.encode_json();
        let decoded = WalStreamAckReply::decode_json(&wire).unwrap();

        assert_eq!(decoded, reply);
    }
}