Skip to main content

nodedb_columnar/
wal_record.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! WAL record types for columnar operations.
4//!
5//! Each mutation (INSERT, DELETE, compaction commit) produces a WAL record
6//! that is written before the mutation is applied. On crash recovery, WAL
7//! records are replayed to reconstruct the memtable, delete bitmaps, and
8//! segment metadata.
9//!
10//! Records are serialized as MessagePack for compact wire representation.
11
12use serde::{Deserialize, Serialize};
13use sonic_rs;
14use zerompk::{FromMessagePack, ToMessagePack};
15
/// A WAL record for a columnar collection operation.
///
/// One variant exists per logged mutation kind, and every variant carries the
/// name of the collection it applies to (see [`ColumnarWalRecord::collection`]).
/// Records are converted to and from bytes with
/// [`ColumnarWalRecord::to_bytes`] / [`ColumnarWalRecord::from_bytes`].
// The explicit per-variant `rename` values below match what
// `rename_all = "snake_case"` would generate for the same variant names.
#[derive(Debug, Clone, Serialize, Deserialize, ToMessagePack, FromMessagePack)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ColumnarWalRecord {
    /// A row was inserted into the memtable.
    ///
    /// Contains the collection name and the row data as packed binary
    /// (the columnar wire format, not MessagePack). On replay, the row
    /// is re-inserted into the memtable.
    #[serde(rename = "insert_row")]
    InsertRow {
        /// Name of the collection the row was inserted into.
        collection: String,
        /// Row data as packed binary values in the format written by
        /// `encode_row_for_wal`: each value is a one-byte type tag followed
        /// by its payload (i64 as 8 LE bytes, f64 as 8 LE bytes, strings as
        /// u32-length-prefixed UTF-8, etc.).
        row_data: Vec<u8>,
    },

    /// Rows were marked as deleted in a segment's delete bitmap.
    ///
    /// On replay, these row indices are re-applied to the segment's
    /// delete bitmap.
    #[serde(rename = "delete_rows")]
    DeleteRows {
        /// Name of the collection that owns the segment.
        collection: String,
        /// Segment whose delete bitmap the indices apply to.
        segment_id: u64,
        /// Indices of the rows marked as deleted.
        row_indices: Vec<u32>,
    },

    /// A compaction was committed: old segments replaced with new ones.
    ///
    /// This is the atomic commit point of the 3-phase compaction protocol.
    /// On replay:
    /// - If new segments exist on disk: complete the metadata swap.
    /// - If new segments don't exist: the compaction was interrupted before
    ///   writing; discard and treat old segments as authoritative.
    #[serde(rename = "compaction_commit")]
    CompactionCommit {
        /// Name of the collection that was compacted.
        collection: String,
        /// Segments being replaced by the compaction.
        old_segment_ids: Vec<u64>,
        /// Segments that replace `old_segment_ids`.
        new_segment_ids: Vec<u64>,
    },

    /// The memtable was flushed to a new segment.
    ///
    /// On replay, if the segment file exists, update metadata to include it.
    /// If it doesn't exist, the flush was interrupted; rows are already in
    /// the memtable via InsertRow records.
    #[serde(rename = "memtable_flushed")]
    MemtableFlushed {
        /// Name of the collection whose memtable was flushed.
        collection: String,
        /// Segment the memtable was flushed to.
        segment_id: u64,
        /// Presumably the number of rows written to the segment — confirm
        /// against the flush path.
        row_count: u64,
    },
}
72
73impl ColumnarWalRecord {
74    /// Collection name this record belongs to.
75    pub fn collection(&self) -> &str {
76        match self {
77            Self::InsertRow { collection, .. }
78            | Self::DeleteRows { collection, .. }
79            | Self::CompactionCommit { collection, .. }
80            | Self::MemtableFlushed { collection, .. } => collection,
81        }
82    }
83
84    /// Serialize the record to bytes.
85    pub fn to_bytes(&self) -> Result<Vec<u8>, crate::error::ColumnarError> {
86        zerompk::to_msgpack_vec(self)
87            .map_err(|e| crate::error::ColumnarError::Serialization(e.to_string()))
88    }
89
90    /// Deserialize a record from bytes.
91    pub fn from_bytes(data: &[u8]) -> Result<Self, crate::error::ColumnarError> {
92        zerompk::from_msgpack(data)
93            .map_err(|e| crate::error::ColumnarError::Serialization(e.to_string()))
94    }
95}
96
97/// Encode a row of values into the columnar wire format for WAL records.
98///
99/// Each value is written as: [type_tag: u8][value_bytes].
100/// This is more compact than MessagePack for typed columns and enables
101/// direct replay into the memtable without schema interpretation overhead.
102pub fn encode_row_for_wal(
103    values: &[nodedb_types::value::Value],
104) -> Result<Vec<u8>, crate::error::ColumnarError> {
105    use nodedb_types::value::Value;
106
107    // no-governor: WAL encode per row; rough estimate, hot write path where instrument cost exceeds benefit
108    let mut buf = Vec::with_capacity(values.len() * 10); // Rough estimate.
109
110    for value in values {
111        match value {
112            Value::Null => buf.push(0),
113            Value::Integer(v) => {
114                buf.push(1);
115                buf.extend_from_slice(&v.to_le_bytes());
116            }
117            Value::Float(v) => {
118                buf.push(2);
119                buf.extend_from_slice(&v.to_le_bytes());
120            }
121            Value::Bool(v) => {
122                buf.push(3);
123                buf.push(*v as u8);
124            }
125            Value::String(s) => {
126                buf.push(4);
127                let bytes = s.as_bytes();
128                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
129                buf.extend_from_slice(bytes);
130            }
131            Value::Bytes(b) => {
132                buf.push(5);
133                buf.extend_from_slice(&(b.len() as u32).to_le_bytes());
134                buf.extend_from_slice(b);
135            }
136            Value::DateTime(dt) => {
137                buf.push(6);
138                buf.extend_from_slice(&dt.micros.to_le_bytes());
139            }
140            Value::Decimal(d) => {
141                buf.push(7);
142                buf.extend_from_slice(&d.serialize());
143            }
144            Value::Uuid(s) => {
145                buf.push(8);
146                let bytes = s.as_bytes();
147                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
148                buf.extend_from_slice(bytes);
149            }
150            Value::Array(arr) => {
151                // Vectors stored as: tag(9) + count(u32) + f32 values.
152                buf.push(9);
153                buf.extend_from_slice(&(arr.len() as u32).to_le_bytes());
154                for v in arr {
155                    let f = match v {
156                        Value::Float(f) => *f as f32,
157                        Value::Integer(n) => *n as f32,
158                        _ => 0.0,
159                    };
160                    buf.extend_from_slice(&f.to_le_bytes());
161                }
162            }
163            _ => {
164                // Geometry and other complex types: serialize as JSON bytes.
165                buf.push(10);
166                let json = sonic_rs::to_vec(value).map_err(|e| {
167                    crate::error::ColumnarError::Serialization(format!(
168                        "failed to serialize value as JSON: {e}"
169                    ))
170                })?;
171                buf.extend_from_slice(&(json.len() as u32).to_le_bytes());
172                buf.extend_from_slice(&json);
173            }
174        }
175    }
176
177    Ok(buf)
178}
179
/// Maximum length for a variable-length field in a WAL record (256 MiB).
/// Prevents OOM from crafted/corrupt records with bogus length prefixes.
///
/// Enforced on decode by `read_length_prefixed`; `decode_row_from_wal` also
/// caps vector element counts at `MAX_FIELD_LEN / 4`.
const MAX_FIELD_LEN: usize = 256 * 1024 * 1024;
183
184/// Read exactly `n` bytes from `data` at `cursor`, advancing cursor.
185/// Returns `Err` if not enough bytes remain.
186fn read_slice<'a>(
187    data: &'a [u8],
188    cursor: &mut usize,
189    n: usize,
190    context: &str,
191) -> Result<&'a [u8], crate::error::ColumnarError> {
192    let end = cursor.checked_add(n).ok_or_else(|| {
193        crate::error::ColumnarError::Serialization(format!("overflow in {context}"))
194    })?;
195    if end > data.len() {
196        return Err(crate::error::ColumnarError::Serialization(format!(
197            "truncated {context}: need {n} bytes at offset {cursor}, have {}",
198            data.len().saturating_sub(*cursor)
199        )));
200    }
201    let slice = &data[*cursor..end];
202    *cursor = end;
203    Ok(slice)
204}
205
206/// Read a u32 length prefix, validate it against MAX_FIELD_LEN, then read
207/// that many bytes. Returns the payload slice.
208fn read_length_prefixed<'a>(
209    data: &'a [u8],
210    cursor: &mut usize,
211    context: &str,
212) -> Result<&'a [u8], crate::error::ColumnarError> {
213    let len_bytes = read_slice(data, cursor, 4, context)?;
214    let len = u32::from_le_bytes(len_bytes.try_into().map_err(|_| {
215        crate::error::ColumnarError::Serialization(format!("truncated {context} len"))
216    })?) as usize;
217    if len > MAX_FIELD_LEN {
218        return Err(crate::error::ColumnarError::Serialization(format!(
219            "{context} length {len} exceeds maximum {MAX_FIELD_LEN}"
220        )));
221    }
222    read_slice(data, cursor, len, context)
223}
224
225/// Decode a row from the columnar wire format back into Values.
226pub fn decode_row_from_wal(
227    data: &[u8],
228) -> Result<Vec<nodedb_types::value::Value>, crate::error::ColumnarError> {
229    use nodedb_types::value::Value;
230
231    let mut values = Vec::new();
232    let mut cursor = 0;
233
234    while cursor < data.len() {
235        let tag_slice = read_slice(data, &mut cursor, 1, "tag")?;
236        let tag = tag_slice[0];
237
238        let value = match tag {
239            0 => Value::Null,
240            1 => {
241                let bytes = read_slice(data, &mut cursor, 8, "i64")?;
242                let v = i64::from_le_bytes(bytes.try_into().map_err(|_| {
243                    crate::error::ColumnarError::Serialization("truncated i64".into())
244                })?);
245                Value::Integer(v)
246            }
247            2 => {
248                let bytes = read_slice(data, &mut cursor, 8, "f64")?;
249                let v = f64::from_le_bytes(bytes.try_into().map_err(|_| {
250                    crate::error::ColumnarError::Serialization("truncated f64".into())
251                })?);
252                Value::Float(v)
253            }
254            3 => {
255                let bytes = read_slice(data, &mut cursor, 1, "bool")?;
256                Value::Bool(bytes[0] != 0)
257            }
258            4 | 5 | 8 => {
259                let bytes = read_length_prefixed(
260                    data,
261                    &mut cursor,
262                    match tag {
263                        4 => "string",
264                        5 => "bytes",
265                        8 => "uuid",
266                        _ => unreachable!(),
267                    },
268                )?;
269                match tag {
270                    4 => Value::String(String::from_utf8_lossy(bytes).into_owned()),
271                    5 => Value::Bytes(bytes.to_vec()),
272                    8 => Value::Uuid(String::from_utf8_lossy(bytes).into_owned()),
273                    _ => unreachable!(),
274                }
275            }
276            6 => {
277                let bytes = read_slice(data, &mut cursor, 8, "timestamp")?;
278                let micros = i64::from_le_bytes(bytes.try_into().map_err(|_| {
279                    crate::error::ColumnarError::Serialization("truncated timestamp".into())
280                })?);
281                Value::DateTime(nodedb_types::datetime::NdbDateTime::from_micros(micros))
282            }
283            7 => {
284                let bytes = read_slice(data, &mut cursor, 16, "decimal")?;
285                let mut arr = [0u8; 16];
286                arr.copy_from_slice(bytes);
287                Value::Decimal(rust_decimal::Decimal::deserialize(arr))
288            }
289            9 => {
290                let count_bytes = read_slice(data, &mut cursor, 4, "vector count")?;
291                let count = u32::from_le_bytes(count_bytes.try_into().map_err(|_| {
292                    crate::error::ColumnarError::Serialization("truncated vector count".into())
293                })?) as usize;
294                if count > MAX_FIELD_LEN / 4 {
295                    return Err(crate::error::ColumnarError::Serialization(format!(
296                        "vector count {count} exceeds maximum {}",
297                        MAX_FIELD_LEN / 4
298                    )));
299                }
300                // no-governor: WAL decode inner array; bounded by MAX_FIELD_LEN/4 guard above
301                let mut arr = Vec::with_capacity(count);
302                for _ in 0..count {
303                    let fb = read_slice(data, &mut cursor, 4, "vector f32")?;
304                    let f = f32::from_le_bytes(fb.try_into().map_err(|_| {
305                        crate::error::ColumnarError::Serialization("truncated f32".into())
306                    })?);
307                    arr.push(Value::Float(f as f64));
308                }
309                Value::Array(arr)
310            }
311            10 => {
312                let json_bytes = read_length_prefixed(data, &mut cursor, "json")?;
313                sonic_rs::from_slice(json_bytes).unwrap_or(Value::Null)
314            }
315            _ => {
316                return Err(crate::error::ColumnarError::Serialization(format!(
317                    "unknown WAL value tag: {tag}"
318                )));
319            }
320        };
321
322        values.push(value);
323    }
324
325    Ok(values)
326}
327
#[cfg(test)]
mod tests {
    use nodedb_types::datetime::NdbDateTime;
    use nodedb_types::value::Value;

    use super::*;

    /// Every record variant survives a to_bytes / from_bytes roundtrip.
    #[test]
    fn wal_record_roundtrip() {
        let records = vec![
            ColumnarWalRecord::InsertRow {
                collection: "test".into(),
                row_data: vec![1, 2, 3],
            },
            ColumnarWalRecord::DeleteRows {
                collection: "test".into(),
                segment_id: 0,
                row_indices: vec![5, 10, 15],
            },
            ColumnarWalRecord::CompactionCommit {
                collection: "test".into(),
                old_segment_ids: vec![0, 1],
                new_segment_ids: vec![2],
            },
            ColumnarWalRecord::MemtableFlushed {
                collection: "test".into(),
                segment_id: 3,
                row_count: 1024,
            },
        ];

        for record in &records {
            let bytes = record.to_bytes().expect("serialize");
            let restored = ColumnarWalRecord::from_bytes(&bytes).expect("deserialize");
            assert_eq!(restored.collection(), record.collection());
            // The record type has no `PartialEq`; compare the Debug rendering
            // so the variant and every field are checked, not just the name.
            assert_eq!(format!("{restored:?}"), format!("{record:?}"));
        }
    }

    /// Every directly encoded value kind survives an encode / decode roundtrip.
    #[test]
    fn row_wire_format_roundtrip() {
        let values = vec![
            Value::Integer(42),
            Value::Float(0.75),
            Value::Bool(true),
            Value::String("hello".into()),
            Value::Bytes(vec![0xDE, 0xAD]),
            Value::DateTime(NdbDateTime::from_micros(1_700_000_000)),
            Value::Decimal(rust_decimal::Decimal::new(314, 2)),
            Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()),
            Value::Null,
            Value::Array(vec![Value::Float(1.0), Value::Float(2.0)]),
        ];

        let encoded = encode_row_for_wal(&values).expect("encode");
        let decoded = decode_row_from_wal(&encoded).expect("decode");

        assert_eq!(decoded.len(), values.len());
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Float(0.75));
        assert_eq!(decoded[2], Value::Bool(true));
        assert_eq!(decoded[3], Value::String("hello".into()));
        assert_eq!(decoded[4], Value::Bytes(vec![0xDE, 0xAD]));
        assert_eq!(
            decoded[5],
            Value::DateTime(NdbDateTime::from_micros(1_700_000_000))
        );
        assert_eq!(
            decoded[6],
            Value::Decimal(rust_decimal::Decimal::new(314, 2))
        );
        assert_eq!(
            decoded[7],
            Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into())
        );
        assert_eq!(decoded[8], Value::Null);
        // 1.0 and 2.0 are exactly representable as f32, so the f32 vector
        // encoding returns them unchanged.
        assert_eq!(
            decoded[9],
            Value::Array(vec![Value::Float(1.0), Value::Float(2.0)])
        );
    }

    #[test]
    fn decode_truncated_i64_returns_error() {
        // Tag 1 (i64) requires 8 payload bytes; supply none. The bounds-checked
        // reader must report this as a Serialization error rather than panic.
        let result = decode_row_from_wal(&[1]);
        assert!(
            result.is_err(),
            "truncated i64 payload must return Err, not panic"
        );
    }

    #[test]
    fn decode_truncated_string_returns_error() {
        // Tag 4 (string): the length prefix says 255 bytes but the input ends
        // immediately after the 4-byte length field.
        let input = {
            let mut v = vec![4u8]; // tag = string
            v.extend_from_slice(&255u32.to_le_bytes()); // len = 255
            // no payload bytes follow
            v
        };
        let result = decode_row_from_wal(&input);
        assert!(
            result.is_err(),
            "truncated string payload must return Err, not panic"
        );
    }

    #[test]
    fn decode_huge_vector_count_returns_error() {
        // Tag 9 (vector array): count = 0x7FFFFFFF, which is above the
        // MAX_FIELD_LEN / 4 cap, so the record must be rejected before any
        // allocation proportional to the count is attempted.
        let input = {
            let mut v = vec![9u8]; // tag = vector array
            v.extend_from_slice(&0x7FFF_FFFFu32.to_le_bytes()); // count
            // no f32 bytes follow
            v
        };
        let result = decode_row_from_wal(&input);
        assert!(
            result.is_err(),
            "huge vector count with no payload must return Err, not panic or OOM"
        );
    }

    #[test]
    fn decode_truncated_decimal_returns_error() {
        // Tag 7 (Decimal) requires 16 bytes; supply only 4.
        let input = {
            let mut v = vec![7u8]; // tag = decimal
            v.extend_from_slice(&[0u8; 4]); // only 4 bytes, need 16
            v
        };
        let result = decode_row_from_wal(&input);
        assert!(
            result.is_err(),
            "truncated decimal payload must return Err, not panic"
        );
    }
}