1use std::collections::HashMap;
12
13use crate::checksum::compute_xxh64;
14use crate::error::{CrousError, Result};
15use crate::header::{FLAGS_NONE, FileHeader};
16use crate::limits::Limits;
17use crate::value::Value;
18use crate::varint::encode_varint_vec;
19use crate::wire::{BlockType, CompressionType, WireType};
20
/// Streaming encoder that turns [`Value`]s into a framed byte stream.
///
/// Values are first appended to a pending block buffer; `flush_block` frames
/// that buffer into the output and `finish` appends the trailer.
pub struct Encoder {
    /// Finished bytes: file header, framed blocks and, after `finish`, the trailer.
    output: Vec<u8>,
    /// Encoded values accumulated since the last flush.
    block_buf: Vec<u8>,
    /// Number of arrays/objects currently open while encoding recursively.
    depth: usize,
    /// Nesting-depth and item-count limits enforced while encoding containers.
    limits: Limits,
    /// Whether the file header has already been appended to `output`.
    header_written: bool,
    /// Flags value handed to `FileHeader::new` when the header is written.
    flags: u8,
    /// Compression requested for data blocks.
    compression: CompressionType,
    /// Strings seen in the current block, mapped to their assigned reference
    /// index; only populated when `dedup_strings` is set.
    string_dict: HashMap<String, u32>,
    /// Whether repeated strings are replaced by dictionary references.
    dedup_strings: bool,
}
53
54impl Encoder {
55 pub fn new() -> Self {
57 Self {
58 output: Vec::with_capacity(4096),
59 block_buf: Vec::with_capacity(4096),
60 depth: 0,
61 limits: Limits::default(),
62 header_written: false,
63 flags: FLAGS_NONE,
64 compression: CompressionType::None,
65 string_dict: HashMap::new(),
66 dedup_strings: false,
67 }
68 }
69
70 pub fn with_limits(limits: Limits) -> Self {
72 Self {
73 limits,
74 ..Self::new()
75 }
76 }
77
    /// Turns on string deduplication.
    ///
    /// Once enabled, a `Value::Str` that was already encoded in the current
    /// block is written as a `Reference` to its dictionary index, and
    /// `flush_block` emits a string-dictionary block ahead of the data block.
    pub fn enable_dedup(&mut self) {
        self.dedup_strings = true;
    }
83
    /// Selects the compression to attempt on data blocks at flush time.
    ///
    /// `flush_block` only keeps the compressed form when it is strictly
    /// smaller than the raw block; otherwise the block is stored uncompressed.
    pub fn set_compression(&mut self, comp: CompressionType) {
        self.compression = comp;
    }
88
    /// Sets the flags value that will be passed to `FileHeader::new`.
    ///
    /// The header is written only once (on the first non-empty flush or on
    /// `finish`), so a call made after that point does not change the bytes
    /// already emitted.
    pub fn set_flags(&mut self, flags: u8) {
        self.flags = flags;
    }
93
94 fn ensure_header(&mut self) {
96 if !self.header_written {
97 let header = FileHeader::new(self.flags);
98 self.output.extend_from_slice(&header.encode());
99 self.header_written = true;
100 }
101 }
102
103 pub fn encode_value(&mut self, value: &Value) -> Result<()> {
108 self.encode_value_inner(value)
109 }
110
111 fn encode_value_inner(&mut self, value: &Value) -> Result<()> {
112 match value {
113 Value::Null => {
114 self.block_buf.push(WireType::Null.to_tag());
115 }
116 Value::Bool(b) => {
117 self.block_buf.push(WireType::Bool.to_tag());
118 self.block_buf.push(if *b { 0x01 } else { 0x00 });
119 }
120 Value::UInt(n) => {
121 self.block_buf.push(WireType::VarUInt.to_tag());
122 encode_varint_vec(*n, &mut self.block_buf);
123 }
124 Value::Int(n) => {
125 self.block_buf.push(WireType::VarInt.to_tag());
126 crate::varint::encode_signed_varint_vec(*n, &mut self.block_buf);
127 }
128 Value::Float(f) => {
129 self.block_buf.push(WireType::Fixed64.to_tag());
130 self.block_buf.extend_from_slice(&f.to_le_bytes());
131 }
132 Value::Str(s) => {
133 if self.dedup_strings {
134 if let Some(&idx) = self.string_dict.get(s.as_str()) {
135 self.block_buf.push(WireType::Reference.to_tag());
137 encode_varint_vec(idx as u64, &mut self.block_buf);
138 return Ok(());
139 }
140 let idx = self.string_dict.len() as u32;
142 self.string_dict.insert(s.clone(), idx);
143 }
144 self.block_buf.push(WireType::LenDelimited.to_tag());
145 self.block_buf.push(0x00);
147 encode_varint_vec(s.len() as u64, &mut self.block_buf);
148 self.block_buf.extend_from_slice(s.as_bytes());
149 }
150 Value::Bytes(b) => {
151 self.block_buf.push(WireType::LenDelimited.to_tag());
152 self.block_buf.push(0x01);
154 encode_varint_vec(b.len() as u64, &mut self.block_buf);
155 self.block_buf.extend_from_slice(b);
156 }
157 Value::Array(items) => {
158 if self.depth >= self.limits.max_nesting_depth {
159 return Err(CrousError::NestingTooDeep(
160 self.depth,
161 self.limits.max_nesting_depth,
162 ));
163 }
164 if items.len() > self.limits.max_items {
165 return Err(CrousError::TooManyItems(items.len(), self.limits.max_items));
166 }
167 self.block_buf.push(WireType::StartArray.to_tag());
168 encode_varint_vec(items.len() as u64, &mut self.block_buf);
170 self.depth += 1;
171 for item in items {
172 self.encode_value_inner(item)?;
173 }
174 self.depth -= 1;
175 self.block_buf.push(WireType::EndArray.to_tag());
176 }
177 Value::Object(entries) => {
178 if self.depth >= self.limits.max_nesting_depth {
179 return Err(CrousError::NestingTooDeep(
180 self.depth,
181 self.limits.max_nesting_depth,
182 ));
183 }
184 if entries.len() > self.limits.max_items {
185 return Err(CrousError::TooManyItems(
186 entries.len(),
187 self.limits.max_items,
188 ));
189 }
190 self.block_buf.push(WireType::StartObject.to_tag());
191 encode_varint_vec(entries.len() as u64, &mut self.block_buf);
193 self.depth += 1;
194 for (key, val) in entries {
195 encode_varint_vec(key.len() as u64, &mut self.block_buf);
197 self.block_buf.extend_from_slice(key.as_bytes());
198 self.encode_value_inner(val)?;
200 }
201 self.depth -= 1;
202 self.block_buf.push(WireType::EndObject.to_tag());
203 }
204 }
205 Ok(())
206 }
207
208 pub fn flush_block(&mut self) -> Result<usize> {
222 if self.block_buf.is_empty() {
223 return Ok(0);
224 }
225
226 self.ensure_header();
227
228 let mut total_size = 0;
229
230 if self.dedup_strings && !self.string_dict.is_empty() {
232 let dict_payload = self.encode_string_dict_payload();
233 let dict_checksum = compute_xxh64(&dict_payload);
234
235 self.output.push(BlockType::StringDict as u8);
236 encode_varint_vec(dict_payload.len() as u64, &mut self.output);
237 self.output.push(CompressionType::None as u8);
238 self.output.extend_from_slice(&dict_checksum.to_le_bytes());
239 self.output.extend_from_slice(&dict_payload);
240
241 total_size += 1 + 1 + 1 + 8 + dict_payload.len();
242 }
243
244 let checksum = compute_xxh64(&self.block_buf);
248
249 let (wire_payload, wire_comp) = if self.compression != CompressionType::None {
251 match self.compress_payload(&self.block_buf) {
252 Some(compressed) if compressed.len() < self.block_buf.len() => {
253 let mut framed = Vec::with_capacity(10 + compressed.len());
256 encode_varint_vec(self.block_buf.len() as u64, &mut framed);
257 framed.extend_from_slice(&compressed);
258 (framed, self.compression)
259 }
260 _ => {
261 (self.block_buf.clone(), CompressionType::None)
263 }
264 }
265 } else {
266 (self.block_buf.clone(), CompressionType::None)
267 };
268
269 let block_type = BlockType::Data as u8;
272
273 self.output.push(block_type);
274 encode_varint_vec(wire_payload.len() as u64, &mut self.output);
275 self.output.push(wire_comp as u8);
276 self.output.extend_from_slice(&checksum.to_le_bytes());
277 self.output.extend_from_slice(&wire_payload);
278
279 total_size += 1 + 1 + 8 + wire_payload.len();
280 self.block_buf.clear();
281 self.string_dict.clear(); Ok(total_size)
283 }
284
285 fn encode_string_dict_payload(&self) -> Vec<u8> {
298 let mut entries: Vec<(&str, u32)> = self
300 .string_dict
301 .iter()
302 .map(|(s, &idx)| (s.as_str(), idx))
303 .collect();
304 entries.sort_by(|a, b| a.0.cmp(b.0));
305
306 let mut payload = Vec::with_capacity(entries.len() * 16);
307 encode_varint_vec(entries.len() as u64, &mut payload);
308
309 let mut prev = "";
310 for (s, original_idx) in &entries {
311 let prefix_len = s
313 .as_bytes()
314 .iter()
315 .zip(prev.as_bytes().iter())
316 .take_while(|(a, b)| a == b)
317 .count();
318 let suffix = &s.as_bytes()[prefix_len..];
319
320 encode_varint_vec(*original_idx as u64, &mut payload);
321 encode_varint_vec(prefix_len as u64, &mut payload);
322 encode_varint_vec(suffix.len() as u64, &mut payload);
323 payload.extend_from_slice(suffix);
324
325 prev = s;
326 }
327 payload
328 }
329
    /// Compresses `data` with the currently selected algorithm.
    ///
    /// Returns `None` when no compression is selected, when the matching
    /// Cargo feature (`zstd`, `snappy`, `lz4`) is not compiled in, or when
    /// the zstd/snappy encoder reports an error.
    // `data` is unused when none of the compression features are enabled.
    #[allow(unused_variables)]
    fn compress_payload(&self, data: &[u8]) -> Option<Vec<u8>> {
        match self.compression {
            CompressionType::None => None,
            // zstd at compression level 3; errors are mapped to `None`.
            #[cfg(feature = "zstd")]
            CompressionType::Zstd => zstd::encode_all(std::io::Cursor::new(data), 3).ok(),
            #[cfg(not(feature = "zstd"))]
            CompressionType::Zstd => None,
            // Raw (unframed) snappy encoder; errors are mapped to `None`.
            #[cfg(feature = "snappy")]
            CompressionType::Snappy => {
                let mut enc = snap::raw::Encoder::new();
                enc.compress_vec(data).ok()
            }
            #[cfg(not(feature = "snappy"))]
            CompressionType::Snappy => None,
            // NOTE(review): going by its name, `compress_prepend_size` embeds
            // the uncompressed size in its output - confirm against lz4_flex docs.
            #[cfg(feature = "lz4")]
            CompressionType::Lz4 => Some(lz4_flex::compress_prepend_size(data)),
            #[cfg(not(feature = "lz4"))]
            CompressionType::Lz4 => None,
        }
    }
353
354 pub fn finish(mut self) -> Result<Vec<u8>> {
358 self.flush_block()?;
359 self.ensure_header();
360
361 let overall_checksum = compute_xxh64(&self.output);
363 self.output.push(BlockType::Trailer as u8);
365 encode_varint_vec(8, &mut self.output);
366 self.output.push(CompressionType::None as u8);
367 let trailer_checksum = compute_xxh64(&overall_checksum.to_le_bytes());
368 self.output
369 .extend_from_slice(&trailer_checksum.to_le_bytes());
370 self.output
371 .extend_from_slice(&overall_checksum.to_le_bytes());
372
373 Ok(self.output)
374 }
375
376 pub fn current_size(&self) -> usize {
378 self.output.len() + self.block_buf.len()
379 }
380
381 pub fn block_buffer(&self) -> &[u8] {
383 &self.block_buf
384 }
385}
386
387impl Default for Encoder {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` with a fresh default encoder and returns the bytes
    /// sitting in its unflushed block buffer.
    fn pending_bytes(value: &Value) -> Vec<u8> {
        let mut enc = Encoder::new();
        enc.encode_value(value).unwrap();
        enc.block_buffer().to_vec()
    }

    #[test]
    fn encode_null() {
        // A null is expected to be the single tag byte 0x00.
        assert_eq!(pending_bytes(&Value::Null), [0x00]);
    }

    #[test]
    fn encode_bool() {
        assert_eq!(pending_bytes(&Value::Bool(true)), [0x01, 0x01]);
        assert_eq!(pending_bytes(&Value::Bool(false)), [0x01, 0x00]);
    }

    #[test]
    fn encode_uint_small() {
        // Values below 128 fit in one varint byte.
        assert_eq!(pending_bytes(&Value::UInt(42)), [0x02, 42]);
    }

    #[test]
    fn encode_uint_large() {
        // 300 needs two varint bytes.
        assert_eq!(pending_bytes(&Value::UInt(300)), [0x02, 0xac, 0x02]);
    }

    #[test]
    fn encode_int_negative() {
        // -1 is expected to encode as the single signed-varint byte 0x01.
        assert_eq!(pending_bytes(&Value::Int(-1)), [0x03, 0x01]);
    }

    #[test]
    fn encode_float() {
        let expected = [&[0x04u8][..], &3.125f64.to_le_bytes()[..]].concat();
        assert_eq!(pending_bytes(&Value::Float(3.125)), expected);
    }

    #[test]
    fn encode_string() {
        // tag, string sub-type, length, bytes
        let expected = [&[0x05u8, 0x00, 5][..], &b"hello"[..]].concat();
        assert_eq!(pending_bytes(&Value::Str("hello".into())), expected);
    }

    #[test]
    fn encode_bytes() {
        // tag, bytes sub-type, length, bytes
        assert_eq!(
            pending_bytes(&Value::Bytes(vec![0xDE, 0xAD])),
            [0x05, 0x01, 2, 0xDE, 0xAD]
        );
    }

    #[test]
    fn encode_array() {
        let arr = Value::Array(vec![Value::UInt(1), Value::UInt(2)]);
        // start tag, count, two tagged uints, end tag
        assert_eq!(
            pending_bytes(&arr),
            [0x08, 0x02, 0x02, 0x01, 0x02, 0x02, 0x09]
        );
    }

    #[test]
    fn encode_object() {
        let obj = Value::Object(vec![("x".into(), Value::UInt(10))]);
        // start tag, count, key length, key, tagged uint, end tag
        assert_eq!(
            pending_bytes(&obj),
            [0x06, 0x01, 0x01, b'x', 0x02, 0x0a, 0x07]
        );
    }

    #[test]
    fn finish_produces_valid_file() {
        let mut enc = Encoder::new();
        enc.encode_value(&Value::Null).unwrap();
        let bytes = enc.finish().unwrap();

        assert_eq!(&bytes[..7], b"CROUSv1");
        // Trailer is 19 bytes: type + length + compression + two 8-byte checksums.
        let trailer_start = bytes.len() - 19;
        assert_eq!(bytes[trailer_start], BlockType::Trailer as u8);
    }

    #[test]
    fn nesting_depth_limit() {
        let limits = Limits {
            max_nesting_depth: 2,
            ..Limits::default()
        };
        let mut enc = Encoder::with_limits(limits);

        // Three levels of nesting exceed a limit of two.
        let innermost = Value::Array(vec![]);
        let middle = Value::Array(vec![innermost]);
        let outer = Value::Array(vec![middle]);
        assert!(enc.encode_value(&outer).is_err());
    }
}