Skip to main content

crous_core/
encoder.rs

//! Encoder for the Crous binary format.
//!
//! Encodes `Value` instances into the canonical Crous binary representation.
//! The encoder handles:
//! - File header emission
//! - Block framing with XXH64 checksums
//! - Wire-type-tagged field encoding
//! - Varint/ZigZag integer encoding
//! - Length-delimited strings and bytes
//! - Optional per-block string deduplication and block compression
//! - The file trailer checksum written by `Encoder::finish`
10
11use std::collections::HashMap;
12
13use crate::checksum::compute_xxh64;
14use crate::error::{CrousError, Result};
15use crate::header::{FLAGS_NONE, FileHeader};
16use crate::limits::Limits;
17use crate::value::Value;
18use crate::varint::encode_varint_vec;
19use crate::wire::{BlockType, CompressionType, WireType};
20
21/// Encoder that serializes `Value`s into Crous binary format.
22///
23/// # Example
24/// ```
25/// use crous_core::{Encoder, Value};
26///
27/// let mut enc = Encoder::new();
28/// enc.encode_value(&Value::UInt(42)).unwrap();
29/// let bytes = enc.finish().unwrap();
30/// assert!(bytes.len() > 8); // header + block
31/// ```
32pub struct Encoder {
33    /// The output buffer accumulating the binary output.
34    output: Vec<u8>,
35    /// Buffer for the current block's payload (before framing).
36    block_buf: Vec<u8>,
37    /// Current nesting depth (for overflow protection).
38    depth: usize,
39    /// Resource limits.
40    limits: Limits,
41    /// Whether the file header has been written.
42    header_written: bool,
43    /// File header flags.
44    flags: u8,
45    /// Compression type for blocks.
46    compression: CompressionType,
47    /// Per-block string dictionary: string → index.
48    /// When `dedup_strings` is true, repeated strings are encoded as Reference.
49    string_dict: HashMap<String, u32>,
50    /// Whether to enable string deduplication.
51    dedup_strings: bool,
52}
53
54impl Encoder {
55    /// Create a new encoder with default settings.
56    pub fn new() -> Self {
57        Self {
58            output: Vec::with_capacity(4096),
59            block_buf: Vec::with_capacity(4096),
60            depth: 0,
61            limits: Limits::default(),
62            header_written: false,
63            flags: FLAGS_NONE,
64            compression: CompressionType::None,
65            string_dict: HashMap::new(),
66            dedup_strings: false,
67        }
68    }
69
70    /// Create an encoder with custom limits.
71    pub fn with_limits(limits: Limits) -> Self {
72        Self {
73            limits,
74            ..Self::new()
75        }
76    }
77
78    /// Enable string deduplication. Repeated strings within a block
79    /// will be encoded as Reference wire types pointing to the dictionary.
80    pub fn enable_dedup(&mut self) {
81        self.dedup_strings = true;
82    }
83
84    /// Set the compression type for subsequent blocks.
85    pub fn set_compression(&mut self, comp: CompressionType) {
86        self.compression = comp;
87    }
88
89    /// Set the file header flags.
90    pub fn set_flags(&mut self, flags: u8) {
91        self.flags = flags;
92    }
93
94    /// Ensure the file header has been written.
95    fn ensure_header(&mut self) {
96        if !self.header_written {
97            let header = FileHeader::new(self.flags);
98            self.output.extend_from_slice(&header.encode());
99            self.header_written = true;
100        }
101    }
102
103    /// Encode a single `Value` into the current block buffer.
104    ///
105    /// This is the main entry point for encoding. Values are accumulated
106    /// in the block buffer; call `finish()` to flush and produce the final bytes.
107    pub fn encode_value(&mut self, value: &Value) -> Result<()> {
108        self.encode_value_inner(value)
109    }
110
111    fn encode_value_inner(&mut self, value: &Value) -> Result<()> {
112        match value {
113            Value::Null => {
114                self.block_buf.push(WireType::Null.to_tag());
115            }
116            Value::Bool(b) => {
117                self.block_buf.push(WireType::Bool.to_tag());
118                self.block_buf.push(if *b { 0x01 } else { 0x00 });
119            }
120            Value::UInt(n) => {
121                self.block_buf.push(WireType::VarUInt.to_tag());
122                encode_varint_vec(*n, &mut self.block_buf);
123            }
124            Value::Int(n) => {
125                self.block_buf.push(WireType::VarInt.to_tag());
126                crate::varint::encode_signed_varint_vec(*n, &mut self.block_buf);
127            }
128            Value::Float(f) => {
129                self.block_buf.push(WireType::Fixed64.to_tag());
130                self.block_buf.extend_from_slice(&f.to_le_bytes());
131            }
132            Value::Str(s) => {
133                if self.dedup_strings {
134                    if let Some(&idx) = self.string_dict.get(s.as_str()) {
135                        // Emit a Reference to the dictionary entry.
136                        self.block_buf.push(WireType::Reference.to_tag());
137                        encode_varint_vec(idx as u64, &mut self.block_buf);
138                        return Ok(());
139                    }
140                    // First occurrence: record in dictionary.
141                    let idx = self.string_dict.len() as u32;
142                    self.string_dict.insert(s.clone(), idx);
143                }
144                self.block_buf.push(WireType::LenDelimited.to_tag());
145                // Sub-type marker: 0x00 = UTF-8 string
146                self.block_buf.push(0x00);
147                encode_varint_vec(s.len() as u64, &mut self.block_buf);
148                self.block_buf.extend_from_slice(s.as_bytes());
149            }
150            Value::Bytes(b) => {
151                self.block_buf.push(WireType::LenDelimited.to_tag());
152                // Sub-type marker: 0x01 = raw binary
153                self.block_buf.push(0x01);
154                encode_varint_vec(b.len() as u64, &mut self.block_buf);
155                self.block_buf.extend_from_slice(b);
156            }
157            Value::Array(items) => {
158                if self.depth >= self.limits.max_nesting_depth {
159                    return Err(CrousError::NestingTooDeep(
160                        self.depth,
161                        self.limits.max_nesting_depth,
162                    ));
163                }
164                if items.len() > self.limits.max_items {
165                    return Err(CrousError::TooManyItems(items.len(), self.limits.max_items));
166                }
167                self.block_buf.push(WireType::StartArray.to_tag());
168                // Encode item count as a varint for fast skipping.
169                encode_varint_vec(items.len() as u64, &mut self.block_buf);
170                self.depth += 1;
171                for item in items {
172                    self.encode_value_inner(item)?;
173                }
174                self.depth -= 1;
175                self.block_buf.push(WireType::EndArray.to_tag());
176            }
177            Value::Object(entries) => {
178                if self.depth >= self.limits.max_nesting_depth {
179                    return Err(CrousError::NestingTooDeep(
180                        self.depth,
181                        self.limits.max_nesting_depth,
182                    ));
183                }
184                if entries.len() > self.limits.max_items {
185                    return Err(CrousError::TooManyItems(
186                        entries.len(),
187                        self.limits.max_items,
188                    ));
189                }
190                self.block_buf.push(WireType::StartObject.to_tag());
191                // Encode entry count for fast skipping.
192                encode_varint_vec(entries.len() as u64, &mut self.block_buf);
193                self.depth += 1;
194                for (key, val) in entries {
195                    // Encode key as a length-delimited string inline.
196                    encode_varint_vec(key.len() as u64, &mut self.block_buf);
197                    self.block_buf.extend_from_slice(key.as_bytes());
198                    // Encode value.
199                    self.encode_value_inner(val)?;
200                }
201                self.depth -= 1;
202                self.block_buf.push(WireType::EndObject.to_tag());
203            }
204        }
205        Ok(())
206    }
207
208    /// Flush the current block buffer into a framed block and append to output.
209    /// Returns the number of bytes in the flushed block.
210    ///
211    /// When `dedup_strings` is enabled and the per-block string dictionary is
212    /// non-empty, a `StringDict` block is emitted *before* the data block.
213    /// The dictionary entries are sorted and stored using prefix-delta
214    /// compression for compactness.
215    ///
216    /// When `compression` is set to something other than `None`, the block
217    /// payload is compressed before framing. The checksum is always computed
218    /// on the **uncompressed** payload so the decoder can verify integrity
219    /// after decompression. The block_len field reflects the **compressed**
220    /// size written to the wire.
221    pub fn flush_block(&mut self) -> Result<usize> {
222        if self.block_buf.is_empty() {
223            return Ok(0);
224        }
225
226        self.ensure_header();
227
228        let mut total_size = 0;
229
230        // --- Emit StringDict block before the data block ---
231        if self.dedup_strings && !self.string_dict.is_empty() {
232            let dict_payload = self.encode_string_dict_payload();
233            let dict_checksum = compute_xxh64(&dict_payload);
234
235            self.output.push(BlockType::StringDict as u8);
236            encode_varint_vec(dict_payload.len() as u64, &mut self.output);
237            self.output.push(CompressionType::None as u8);
238            self.output.extend_from_slice(&dict_checksum.to_le_bytes());
239            self.output.extend_from_slice(&dict_payload);
240
241            total_size += 1 + 1 + 1 + 8 + dict_payload.len();
242        }
243
244        // --- Emit the data block ---
245
246        // Checksum is always over the uncompressed payload.
247        let checksum = compute_xxh64(&self.block_buf);
248
249        // Compress if requested. The on-wire payload may differ from block_buf.
250        let (wire_payload, wire_comp) = if self.compression != CompressionType::None {
251            match self.compress_payload(&self.block_buf) {
252                Some(compressed) if compressed.len() < self.block_buf.len() => {
253                    // Store uncompressed length as a varint prefix so the decoder
254                    // can pre-allocate the decompression buffer.
255                    let mut framed = Vec::with_capacity(10 + compressed.len());
256                    encode_varint_vec(self.block_buf.len() as u64, &mut framed);
257                    framed.extend_from_slice(&compressed);
258                    (framed, self.compression)
259                }
260                _ => {
261                    // Compression didn't help — store uncompressed.
262                    (self.block_buf.clone(), CompressionType::None)
263                }
264            }
265        } else {
266            (self.block_buf.clone(), CompressionType::None)
267        };
268
269        // Block header:
270        //   block_type (1B) | block_len (varint) | comp_type (1B) | checksum (8B) | payload
271        let block_type = BlockType::Data as u8;
272
273        self.output.push(block_type);
274        encode_varint_vec(wire_payload.len() as u64, &mut self.output);
275        self.output.push(wire_comp as u8);
276        self.output.extend_from_slice(&checksum.to_le_bytes());
277        self.output.extend_from_slice(&wire_payload);
278
279        total_size += 1 + 1 + 8 + wire_payload.len();
280        self.block_buf.clear();
281        self.string_dict.clear(); // Reset per-block dictionary.
282        Ok(total_size)
283    }
284
285    /// Encode the per-block string dictionary as a prefix-delta-compressed
286    /// payload for a `StringDict` block.
287    ///
288    /// Layout:
289    ///   `entry_count(varint)` | entries...
290    ///
291    /// Each entry (prefix-delta encoded):
292    ///   `original_index(varint)` | `prefix_len(varint)` | `suffix_len(varint)` | `suffix_bytes`
293    ///
294    /// Entries are sorted lexicographically for prefix sharing. The
295    /// `original_index` preserves the insertion order so the decoder can
296    /// rebuild the reference table with the correct indices.
297    fn encode_string_dict_payload(&self) -> Vec<u8> {
298        // Collect entries and sort by string for prefix-delta compression.
299        let mut entries: Vec<(&str, u32)> = self
300            .string_dict
301            .iter()
302            .map(|(s, &idx)| (s.as_str(), idx))
303            .collect();
304        entries.sort_by(|a, b| a.0.cmp(b.0));
305
306        let mut payload = Vec::with_capacity(entries.len() * 16);
307        encode_varint_vec(entries.len() as u64, &mut payload);
308
309        let mut prev = "";
310        for (s, original_idx) in &entries {
311            // Compute shared prefix length with previous entry.
312            let prefix_len = s
313                .as_bytes()
314                .iter()
315                .zip(prev.as_bytes().iter())
316                .take_while(|(a, b)| a == b)
317                .count();
318            let suffix = &s.as_bytes()[prefix_len..];
319
320            encode_varint_vec(*original_idx as u64, &mut payload);
321            encode_varint_vec(prefix_len as u64, &mut payload);
322            encode_varint_vec(suffix.len() as u64, &mut payload);
323            payload.extend_from_slice(suffix);
324
325            prev = s;
326        }
327        payload
328    }
329
330    /// Compress the payload using the configured compression algorithm.
331    /// Returns `None` if the compression feature is not available.
332    #[allow(unused_variables)]
333    fn compress_payload(&self, data: &[u8]) -> Option<Vec<u8>> {
334        match self.compression {
335            CompressionType::None => None,
336            #[cfg(feature = "zstd")]
337            CompressionType::Zstd => zstd::encode_all(std::io::Cursor::new(data), 3).ok(),
338            #[cfg(not(feature = "zstd"))]
339            CompressionType::Zstd => None,
340            #[cfg(feature = "snappy")]
341            CompressionType::Snappy => {
342                let mut enc = snap::raw::Encoder::new();
343                enc.compress_vec(data).ok()
344            }
345            #[cfg(not(feature = "snappy"))]
346            CompressionType::Snappy => None,
347            #[cfg(feature = "lz4")]
348            CompressionType::Lz4 => Some(lz4_flex::compress_prepend_size(data)),
349            #[cfg(not(feature = "lz4"))]
350            CompressionType::Lz4 => None,
351        }
352    }
353
354    /// Finish encoding: flush remaining data and return the complete binary output.
355    ///
356    /// The output includes: file header + data blocks + file trailer checksum.
357    pub fn finish(mut self) -> Result<Vec<u8>> {
358        self.flush_block()?;
359        self.ensure_header();
360
361        // Write file trailer: XXH64 checksum over everything written so far.
362        let overall_checksum = compute_xxh64(&self.output);
363        // Trailer block: type=0xFF, length=8, no compression, checksum of checksum, payload=checksum
364        self.output.push(BlockType::Trailer as u8);
365        encode_varint_vec(8, &mut self.output);
366        self.output.push(CompressionType::None as u8);
367        let trailer_checksum = compute_xxh64(&overall_checksum.to_le_bytes());
368        self.output
369            .extend_from_slice(&trailer_checksum.to_le_bytes());
370        self.output
371            .extend_from_slice(&overall_checksum.to_le_bytes());
372
373        Ok(self.output)
374    }
375
376    /// Get the current size of the output buffer (including unflushed block data).
377    pub fn current_size(&self) -> usize {
378        self.output.len() + self.block_buf.len()
379    }
380
381    /// Get access to the raw block buffer (for testing/inspection).
382    pub fn block_buffer(&self) -> &[u8] {
383        &self.block_buf
384    }
385}
386
387impl Default for Encoder {
388    fn default() -> Self {
389        Self::new()
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `value` with a fresh default encoder and return the unflushed block payload.
    fn block_of(value: &Value) -> Vec<u8> {
        let mut enc = Encoder::new();
        enc.encode_value(value).unwrap();
        enc.block_buffer().to_vec()
    }

    #[test]
    fn encode_null() {
        // A bare WireType::Null tag.
        assert_eq!(block_of(&Value::Null), [0x00]);
    }

    #[test]
    fn encode_bool() {
        // One encoder reused for both values, with the block reset in between.
        let mut enc = Encoder::new();
        for (flag, payload) in [(true, 0x01u8), (false, 0x00u8)] {
            enc.block_buf.clear();
            enc.encode_value(&Value::Bool(flag)).unwrap();
            assert_eq!(enc.block_buffer(), &[0x01, payload]);
        }
    }

    #[test]
    fn encode_uint_small() {
        // WireType::VarUInt followed by the one-byte varint 42.
        assert_eq!(block_of(&Value::UInt(42)), [0x02, 42]);
    }

    #[test]
    fn encode_uint_large() {
        // 300 needs two LEB128 bytes.
        assert_eq!(block_of(&Value::UInt(300)), [0x02, 0xac, 0x02]);
    }

    #[test]
    fn encode_int_negative() {
        // ZigZag(-1) = 1, LEB128(1) = 0x01
        assert_eq!(block_of(&Value::Int(-1)), [0x03, 0x01]);
    }

    #[test]
    fn encode_float() {
        // WireType::Fixed64 tag, then the f64 bits in little-endian order.
        let expected = [&[0x04u8][..], &3.125f64.to_le_bytes()[..]].concat();
        assert_eq!(block_of(&Value::Float(3.125)), expected);
    }

    #[test]
    fn encode_string() {
        // WireType::LenDelimited (0x05), UTF-8 sub-type 0x00, length 5, then "hello".
        let expected = [&[0x05u8, 0x00, 5][..], &b"hello"[..]].concat();
        assert_eq!(block_of(&Value::Str("hello".into())), expected);
    }

    #[test]
    fn encode_bytes() {
        // WireType::LenDelimited (0x05), binary sub-type 0x01, length 2, then the bytes.
        assert_eq!(
            block_of(&Value::Bytes(vec![0xDE, 0xAD])),
            [0x05, 0x01, 2, 0xDE, 0xAD]
        );
    }

    #[test]
    fn encode_array() {
        // StartArray(0x08) + count(2) + UInt(1) + UInt(2) + EndArray(0x09)
        let two_uints = Value::Array(vec![Value::UInt(1), Value::UInt(2)]);
        assert_eq!(
            block_of(&two_uints),
            [0x08, 0x02, 0x02, 0x01, 0x02, 0x02, 0x09]
        );
    }

    #[test]
    fn encode_object() {
        // StartObject(0x06) + count(1) + key_len(1) + "x" + UInt(10) + EndObject(0x07)
        let single_entry = Value::Object(vec![("x".into(), Value::UInt(10))]);
        assert_eq!(
            block_of(&single_entry),
            [0x06, 0x01, 0x01, b'x', 0x02, 0x0a, 0x07]
        );
    }

    #[test]
    fn finish_produces_valid_file() {
        let mut enc = Encoder::new();
        enc.encode_value(&Value::Null).unwrap();
        let file = enc.finish().unwrap();
        // The file opens with the magic string.
        assert_eq!(&file[..7], b"CROUSv1");
        // The trailer is the last 19 bytes: type(1) + varint(1) + comp(1) + checksum(8) + payload(8).
        let trailer_start = file.len() - 19;
        assert_eq!(file[trailer_start], BlockType::Trailer as u8);
    }

    #[test]
    fn nesting_depth_limit() {
        let limits = Limits {
            max_nesting_depth: 2,
            ..Limits::default()
        };
        let mut enc = Encoder::with_limits(limits);
        // Three nested arrays exceed a limit of two.
        let innermost = Value::Array(vec![]);
        let nested = Value::Array(vec![Value::Array(vec![innermost])]);
        assert!(enc.encode_value(&nested).is_err());
    }
}
508}