1use crate::error::{KodaError, Result};
4use crate::value::Value;
5use std::borrow::Cow;
6use std::collections::BTreeMap;
7
/// Four-byte signature every encoded document must begin with.
const MAGIC: &[u8; 4] = b"KODA";
/// The only format version this decoder accepts (the byte right after `MAGIC`).
const VERSION: u8 = 1;

// One-byte type tags; each encoded value starts with exactly one of these.
/// Tag for `Value::Null` (no payload).
const TAG_NULL: u8 = 0x01;
/// Tag for `Value::Bool(false)` (no payload).
const TAG_FALSE: u8 = 0x02;
/// Tag for `Value::Bool(true)` (no payload).
const TAG_TRUE: u8 = 0x03;
/// Tag for an integer; followed by 8 bytes read as a big-endian `i64`.
const TAG_INTEGER: u8 = 0x04;
/// Tag for a float; followed by 8 bytes read as big-endian `f64` bits.
const TAG_FLOAT: u8 = 0x05;
/// Tag for a string; followed by a big-endian `u32` byte length and UTF-8 bytes.
const TAG_STRING: u8 = 0x06;
/// Tag for an array; followed by a big-endian `u32` element count and the elements.
const TAG_ARRAY: u8 = 0x10;
/// Tag for an object; followed by a big-endian `u32` entry count, then
/// (`u32` dictionary key index, value) pairs.
const TAG_OBJECT: u8 = 0x11;

/// Nesting-depth limit used by [`decode`] when no options are supplied.
pub const DEFAULT_MAX_DEPTH: usize = 256;
/// Key-dictionary entry limit used by [`decode`] when no options are supplied.
pub const DEFAULT_MAX_DICT_SIZE: usize = 65536;
/// Per-string byte-length limit used by [`decode`] when no options are supplied.
pub const DEFAULT_MAX_STRING_LENGTH: usize = 1_000_000;
26
27pub fn decode(data: &[u8]) -> Result<Value<'static>> {
31 decode_with_options(
32 data,
33 DecodeOptions {
34 max_depth: DEFAULT_MAX_DEPTH,
35 max_dict_size: DEFAULT_MAX_DICT_SIZE,
36 max_string_length: DEFAULT_MAX_STRING_LENGTH,
37 },
38 )
39}
40
41#[derive(Clone, Debug)]
43pub struct DecodeOptions {
44 pub max_depth: usize,
45 pub max_dict_size: usize,
46 pub max_string_length: usize,
47}
48
49pub fn decode_with_options(data: &[u8], opts: DecodeOptions) -> Result<Value<'static>> {
51 let mut d = Decoder {
52 buf: data,
53 off: 0,
54 max_depth: opts.max_depth,
55 max_dict: opts.max_dict_size,
56 max_str: opts.max_string_length,
57 dict: Vec::new(),
58 };
59 let value = d.decode_root()?;
60 if d.off != data.len() {
61 return Err(KodaError::decode("trailing bytes after root value"));
62 }
63 Ok(value)
64}
65
/// Cursor-based decoder state over a borrowed input buffer.
struct Decoder<'a> {
    /// The complete input being decoded.
    buf: &'a [u8],
    /// Index into `buf` of the next unread byte.
    off: usize,
    /// Nesting-depth limit; a value whose depth exceeds it is rejected.
    max_depth: usize,
    /// Upper bound on the number of key-dictionary entries.
    max_dict: usize,
    /// Upper bound on the byte length of any key or string value.
    max_str: usize,
    /// Key dictionary; object entries refer to their keys by index into it.
    dict: Vec<String>,
}
74
75impl Decoder<'_> {
76 fn fail(&self, msg: &str) -> KodaError {
77 KodaError::decode(format!("{} (at offset {})", msg, self.off))
78 }
79
80 fn ensure(&self, n: usize) -> Result<()> {
81 if self.off + n > self.buf.len() {
82 return Err(self.fail("truncated input"));
83 }
84 Ok(())
85 }
86
87 fn read_u8(&mut self) -> Result<u8> {
88 self.ensure(1)?;
89 let b = self.buf[self.off];
90 self.off += 1;
91 Ok(b)
92 }
93
94 fn read_u32(&mut self) -> Result<u32> {
95 self.ensure(4)?;
96 let v = u32::from_be_bytes(self.buf[self.off..self.off + 4].try_into().unwrap());
97 self.off += 4;
98 Ok(v)
99 }
100
101 fn read_i64(&mut self) -> Result<i64> {
102 self.ensure(8)?;
103 let v = i64::from_be_bytes(self.buf[self.off..self.off + 8].try_into().unwrap());
104 self.off += 8;
105 Ok(v)
106 }
107
108 fn read_f64(&mut self) -> Result<f64> {
109 self.ensure(8)?;
110 let bits = u64::from_be_bytes(self.buf[self.off..self.off + 8].try_into().unwrap());
111 self.off += 8;
112 Ok(f64::from_bits(bits))
113 }
114
115 fn decode_root(&mut self) -> Result<Value<'static>> {
116 self.ensure(5)?;
117 if &self.buf[self.off..self.off + 4] != MAGIC {
118 return Err(self.fail("invalid magic number"));
119 }
120 self.off += 4;
121 let ver = self.read_u8()?;
122 if ver != VERSION {
123 return Err(self.fail("unsupported version"));
124 }
125
126 let dict_len = self.read_u32()? as usize;
127 if dict_len > self.max_dict {
128 return Err(self.fail("dictionary too large"));
129 }
130 self.dict.clear();
131 self.dict.reserve(dict_len);
132 for _ in 0..dict_len {
133 let key_len = self.read_u32()? as usize;
134 if key_len > self.max_str {
135 return Err(self.fail("key string too long"));
136 }
137 self.ensure(key_len)?;
138 let key_bytes = &self.buf[self.off..self.off + key_len];
139 self.off += key_len;
140 let key = std::str::from_utf8(key_bytes)
141 .map_err(|_| self.fail("invalid UTF-8 in dictionary key"))?;
142 self.dict.push(key.to_string());
143 }
144
145 let value = self.decode_value(0)?;
146 Ok(value)
147 }
148
149 fn decode_value(&mut self, depth: usize) -> Result<Value<'static>> {
150 if depth > self.max_depth {
151 return Err(self.fail("maximum nesting depth exceeded"));
152 }
153 let tag = self.read_u8()?;
154 match tag {
155 TAG_NULL => Ok(Value::Null),
156 TAG_FALSE => Ok(Value::Bool(false)),
157 TAG_TRUE => Ok(Value::Bool(true)),
158 TAG_INTEGER => {
159 let n = self.read_i64()?;
160 Ok(Value::Number(n as f64))
161 }
162 TAG_FLOAT => Ok(Value::Number(self.read_f64()?)),
163 TAG_STRING => {
164 let length = self.read_u32()? as usize;
165 if length > self.max_str {
166 return Err(self.fail("string too long"));
167 }
168 self.ensure(length)?;
169 let b = &self.buf[self.off..self.off + length];
170 self.off += length;
171 let s = std::str::from_utf8(b).map_err(|_| self.fail("invalid UTF-8 in string"))?;
172 Ok(Value::String(Cow::Owned(s.to_string())))
173 }
174 TAG_ARRAY => {
175 let count = self.read_u32()?;
176 let mut arr = Vec::with_capacity(count as usize);
177 for _ in 0..count {
178 arr.push(self.decode_value(depth + 1)?);
179 }
180 Ok(Value::Array(arr))
181 }
182 TAG_OBJECT => {
183 let count = self.read_u32()?;
184 let mut obj = BTreeMap::new();
185 for _ in 0..count {
186 let key_idx = self.read_u32()? as usize;
187 if key_idx >= self.dict.len() {
188 return Err(self.fail("invalid key index"));
189 }
190 let key = Cow::Owned(self.dict[key_idx].clone());
191 let val = self.decode_value(depth + 1)?;
192 obj.insert(key, val);
193 }
194 Ok(Value::Object(obj))
195 }
196 _ => Err(self.fail("unknown type tag")),
197 }
198 }
199}
200
/// Rayon-backed decoder: a cheap scanning pass finds where each child of a
/// large array/object starts and ends, then the children are decoded in
/// parallel. Small containers and scalars go through the sequential `Decoder`.
#[cfg(feature = "parallel")]
mod parallel {
    use super::*;
    use rayon::prelude::*;

    /// Containers with fewer direct children than this are decoded sequentially.
    const PARALLEL_THRESHOLD: usize = 128;

    /// Decodes a complete KODA document using the default limits, decoding
    /// large containers in parallel.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`decode_parallel_with_options`] reports.
    pub fn decode_parallel(bytes: &[u8]) -> Result<Value<'static>> {
        decode_parallel_with_options(
            bytes,
            DecodeOptions {
                max_depth: DEFAULT_MAX_DEPTH,
                max_dict_size: DEFAULT_MAX_DICT_SIZE,
                max_string_length: DEFAULT_MAX_STRING_LENGTH,
            },
        )
    }

    /// Decodes a complete KODA document with the limits in `opts`, decoding
    /// large containers in parallel.
    ///
    /// # Errors
    ///
    /// Fails on a malformed header or dictionary, on a missing root value, on
    /// bytes trailing the root value, or on any error found while scanning or
    /// decoding the value tree.
    pub fn decode_parallel_with_options(
        data: &[u8],
        opts: DecodeOptions,
    ) -> Result<Value<'static>> {
        let (dict, value_start) = parse_header(data, &opts)?;
        if value_start >= data.len() {
            return Err(KodaError::decode("truncated input (no root value)"));
        }
        // The scan validates the structure and tells us how long the root is.
        let root_len = scan_one(data, value_start, &dict, &opts, 0)?;
        if root_len != data.len() - value_start {
            return Err(KodaError::decode("trailing bytes after root value"));
        }
        // The root decoder takes ownership of the dictionary; no copy needed.
        let mut root = Decoder {
            buf: data,
            off: value_start,
            max_depth: opts.max_depth,
            max_dict: opts.max_dict_size,
            max_str: opts.max_string_length,
            dict,
        };
        decode_value_parallel(&mut root, value_start, root_len, &opts, 0)
    }

    /// Checks that `n` bytes starting at `off` lie inside `buf`, without
    /// letting `off + n` overflow.
    fn require(buf: &[u8], off: usize, n: usize) -> Result<()> {
        match off.checked_add(n) {
            Some(end) if end <= buf.len() => Ok(()),
            _ => Err(KodaError::decode("truncated input")),
        }
    }

    /// Parses magic, version and the key dictionary.
    ///
    /// Returns the dictionary and the offset of the first byte after it.
    fn parse_header(data: &[u8], opts: &DecodeOptions) -> Result<(Vec<String>, usize)> {
        if data.len() < 5 {
            return Err(KodaError::decode("truncated input"));
        }
        if !data.starts_with(MAGIC) {
            return Err(KodaError::decode("invalid magic number"));
        }
        if data[4] != VERSION {
            return Err(KodaError::decode("unsupported version"));
        }
        let mut off = 5;
        let dict_len = read_u32(data, &mut off)? as usize;
        if dict_len > opts.max_dict_size {
            return Err(KodaError::decode("dictionary too large"));
        }
        // Each key needs at least a 4-byte length prefix; cap the allocation
        // by what the remaining input could actually contain.
        let mut dict = Vec::with_capacity(dict_len.min((data.len() - off) / 4));
        for _ in 0..dict_len {
            let key_len = read_u32(data, &mut off)? as usize;
            if key_len > opts.max_string_length {
                return Err(KodaError::decode("key string too long"));
            }
            require(data, off, key_len)?;
            let key = std::str::from_utf8(&data[off..off + key_len])
                .map_err(|_| KodaError::decode("invalid UTF-8 in dictionary key"))?;
            off += key_len;
            dict.push(key.to_owned());
        }
        Ok((dict, off))
    }

    /// Reads a big-endian `u32` at `*off` and advances `*off` past it.
    fn read_u32(buf: &[u8], off: &mut usize) -> Result<u32> {
        require(buf, *off, 4)?;
        let start = *off;
        *off += 4;
        let bytes: [u8; 4] = buf[start..*off]
            .try_into()
            .expect("range is exactly 4 bytes");
        Ok(u32::from_be_bytes(bytes))
    }

    /// Byte extent of one encoded value, plus the extents of its direct
    /// children when it is a container.
    enum ValueExtent {
        /// A non-container value occupying this many bytes (tag included).
        Scalar(usize),
        /// An array: total byte length and `(start, len)` of each element.
        Array {
            total: usize,
            children: Vec<(usize, usize)>,
        },
        /// An object: total byte length and `(key index, start, len)` of each
        /// entry's value.
        Object {
            total: usize,
            children: Vec<(u32, usize, usize)>,
        },
    }

    impl ValueExtent {
        /// Total number of bytes the value occupies, tag included.
        fn total_len(&self) -> usize {
            match self {
                ValueExtent::Scalar(len) => *len,
                ValueExtent::Array { total, .. } | ValueExtent::Object { total, .. } => *total,
            }
        }
    }

    /// Walks the value starting at `base` without building it, validating
    /// tags, bounds, string lengths, key indices and nesting depth.
    ///
    /// String contents are not checked for UTF-8 here; that happens when the
    /// value is actually decoded.
    fn scan_value_extent(
        buf: &[u8],
        base: usize,
        dict: &[String],
        opts: &DecodeOptions,
        depth: usize,
    ) -> Result<ValueExtent> {
        if depth > opts.max_depth {
            return Err(KodaError::decode("maximum nesting depth exceeded"));
        }
        let Some(&tag) = buf.get(base) else {
            return Err(KodaError::decode("truncated input"));
        };
        let mut off = base + 1;
        match tag {
            TAG_NULL | TAG_FALSE | TAG_TRUE => Ok(ValueExtent::Scalar(off - base)),
            TAG_INTEGER | TAG_FLOAT => {
                require(buf, off, 8)?;
                Ok(ValueExtent::Scalar(off - base + 8))
            }
            TAG_STRING => {
                let length = read_u32(buf, &mut off)? as usize;
                if length > opts.max_string_length {
                    return Err(KodaError::decode("string too long"));
                }
                require(buf, off, length)?;
                Ok(ValueExtent::Scalar(off - base + length))
            }
            TAG_ARRAY => {
                let count = read_u32(buf, &mut off)? as usize;
                // `count` is untrusted; every element takes at least one byte,
                // so never pre-allocate beyond what the input could hold.
                let mut children = Vec::with_capacity(count.min(buf.len() - off));
                for _ in 0..count {
                    let child_len = scan_one(buf, off, dict, opts, depth + 1)?;
                    children.push((off, child_len));
                    off += child_len;
                }
                Ok(ValueExtent::Array {
                    total: off - base,
                    children,
                })
            }
            TAG_OBJECT => {
                let count = read_u32(buf, &mut off)? as usize;
                // Each entry is at least a 4-byte key index plus a 1-byte tag.
                let mut children = Vec::with_capacity(count.min((buf.len() - off) / 5));
                for _ in 0..count {
                    let key_idx = read_u32(buf, &mut off)?;
                    if key_idx as usize >= dict.len() {
                        return Err(KodaError::decode("invalid key index"));
                    }
                    let child_len = scan_one(buf, off, dict, opts, depth + 1)?;
                    children.push((key_idx, off, child_len));
                    off += child_len;
                }
                Ok(ValueExtent::Object {
                    total: off - base,
                    children,
                })
            }
            _ => Err(KodaError::decode("unknown type tag")),
        }
    }

    /// Scans the value at `off` and returns only its total byte length.
    fn scan_one(
        buf: &[u8],
        off: usize,
        dict: &[String],
        opts: &DecodeOptions,
        depth: usize,
    ) -> Result<usize> {
        scan_value_extent(buf, off, dict, opts, depth).map(|extent| extent.total_len())
    }

    /// Builds a fresh sequential decoder for one rayon job.
    ///
    /// `Decoder` owns its dictionary, so this copies `dict`; it is called once
    /// per job rather than once per decoded child.
    /// NOTE(review): if `Decoder` borrowed its dictionary this copy could be
    /// dropped entirely.
    fn worker_decoder<'a>(buf: &'a [u8], dict: &[String], opts: &DecodeOptions) -> Decoder<'a> {
        Decoder {
            buf,
            off: 0,
            max_depth: opts.max_depth,
            max_dict: opts.max_dict_size,
            max_str: opts.max_string_length,
            dict: dict.to_vec(),
        }
    }

    /// Decodes the value at `base` sequentially with `dec` and checks that it
    /// consumed exactly `len` bytes.
    fn decode_sequential(
        dec: &mut Decoder<'_>,
        base: usize,
        len: usize,
        depth: usize,
    ) -> Result<Value<'static>> {
        dec.off = base;
        let value = dec.decode_value(depth)?;
        if dec.off != base + len {
            return Err(KodaError::decode("internal: slice length mismatch"));
        }
        Ok(value)
    }

    /// Decodes the value occupying `len` bytes at `base`.
    ///
    /// Arrays and objects with at least `PARALLEL_THRESHOLD` direct children
    /// have their children decoded in parallel; everything else is decoded
    /// sequentially with `dec`.
    fn decode_value_parallel<'a>(
        dec: &mut Decoder<'a>,
        base: usize,
        len: usize,
        opts: &DecodeOptions,
        depth: usize,
    ) -> Result<Value<'static>> {
        let buf: &'a [u8] = dec.buf;
        let Some(&tag) = buf.get(base) else {
            return Err(KodaError::decode("truncated input"));
        };
        if !matches!(tag, TAG_ARRAY | TAG_OBJECT) {
            // Scalars and unknown tags: the sequential decoder handles both.
            return decode_sequential(dec, base, len, depth);
        }

        // Peek at the child count that follows the container tag.
        let Some(count_bytes) = buf.get(base + 1..base + 5) else {
            return Err(KodaError::decode("truncated input"));
        };
        let count_bytes: [u8; 4] = count_bytes
            .try_into()
            .expect("range is exactly 4 bytes");
        if (u32::from_be_bytes(count_bytes) as usize) < PARALLEL_THRESHOLD {
            return decode_sequential(dec, base, len, depth);
        }

        match scan_value_extent(buf, base, &dec.dict, opts, depth)? {
            ValueExtent::Array { children, .. } => {
                let dict: &[String] = &dec.dict;
                let items = children
                    .par_iter()
                    .map_init(
                        || worker_decoder(buf, dict, opts),
                        |worker, &(start, child_len)| {
                            decode_value_parallel(worker, start, child_len, opts, depth + 1)
                        },
                    )
                    .collect::<Result<Vec<_>>>()?;
                Ok(Value::Array(items))
            }
            ValueExtent::Object { children, .. } => {
                let dict: &[String] = &dec.dict;
                let pairs = children
                    .par_iter()
                    .map_init(
                        || worker_decoder(buf, dict, opts),
                        |worker, &(key_idx, start, child_len)| {
                            decode_value_parallel(worker, start, child_len, opts, depth + 1)
                                .map(|value| (key_idx, value))
                        },
                    )
                    .collect::<Result<Vec<_>>>()?;
                // Insert in encoded order so a repeated key keeps its last
                // value, as in the sequential decoder.
                let mut entries = BTreeMap::new();
                for (key_idx, value) in pairs {
                    entries.insert(Cow::Owned(dict[key_idx as usize].clone()), value);
                }
                Ok(Value::Object(entries))
            }
            // Not expected for a container tag; decode sequentially regardless.
            ValueExtent::Scalar(_) => decode_sequential(dec, base, len, depth),
        }
    }
}
490
491#[cfg(feature = "parallel")]
492pub use parallel::{decode_parallel, decode_parallel_with_options};
493
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;
    use std::borrow::Cow;
    use std::collections::BTreeMap;

    /// Shorthand for an owned object key.
    fn key(name: &str) -> Cow<'static, str> {
        Cow::Owned(name.to_string())
    }

    /// The parallel decoder must produce exactly what the sequential one does,
    /// including for an array long enough (150 elements) to take the parallel path.
    #[test]
    fn decode_parallel_matches_decode() {
        let items: Vec<Value<'static>> = (0..150)
            .map(|i| Value::Object(BTreeMap::from([(key("x"), Value::Number(i as f64))])))
            .collect();

        let root = Value::Object(BTreeMap::from([
            (key("name"), Value::String(Cow::Owned("test".to_string()))),
            (key("count"), Value::Number(42.0)),
            (key("items"), Value::Array(items)),
        ]));

        let encoded = crate::encoder::encode(&root).unwrap();
        let sequential = decode(&encoded).unwrap();
        let parallel = decode_parallel(&encoded).unwrap();
        assert_eq!(sequential, parallel);
    }
}