Skip to main content

koda_rs/
decoder.rs

1//! Binary decoding from KODA wire format (Go-compatible).
2
3use crate::error::{KodaError, Result};
4use crate::value::Value;
5use std::borrow::Cow;
6use std::collections::BTreeMap;
7
/// Four-byte signature that must open every KODA document.
const MAGIC: &[u8; 4] = b"KODA";
/// The single wire-format version this decoder accepts.
const VERSION: u8 = 1;

// Type tags: the one byte that introduces every encoded value.

/// Null; carries no payload.
const TAG_NULL: u8 = 0x01;
/// Boolean `false`; carries no payload.
const TAG_FALSE: u8 = 0x02;
/// Boolean `true`; carries no payload.
const TAG_TRUE: u8 = 0x03;
/// Integer; followed by an 8-byte big-endian `i64`.
const TAG_INTEGER: u8 = 0x04;
/// Float; followed by the 8-byte big-endian bit pattern of an `f64`.
const TAG_FLOAT: u8 = 0x05;
/// String; followed by a big-endian `u32` byte length and that many UTF-8 bytes.
const TAG_STRING: u8 = 0x06;
/// Array; followed by a big-endian `u32` element count and the encoded elements.
const TAG_ARRAY: u8 = 0x10;
/// Object; followed by a big-endian `u32` entry count, then for each entry a
/// big-endian `u32` index into the key dictionary and the encoded value.
const TAG_OBJECT: u8 = 0x11;
19
/// Nesting-depth limit applied when the caller does not supply one.
pub const DEFAULT_MAX_DEPTH: usize = 256;
/// Limit on the number of key-dictionary entries applied when the caller does not supply one.
pub const DEFAULT_MAX_DICT_SIZE: usize = 65536;
/// Limit, in bytes, on any single string applied when the caller does not supply one.
pub const DEFAULT_MAX_STRING_LENGTH: usize = 1_000_000;
26
27/// Decodes KODA binary data into an owned `Value<'static>`.
28///
29/// Uses default limits. Fails if magic/version are wrong or trailing bytes remain.
30pub fn decode(data: &[u8]) -> Result<Value<'static>> {
31    decode_with_options(
32        data,
33        DecodeOptions {
34            max_depth: DEFAULT_MAX_DEPTH,
35            max_dict_size: DEFAULT_MAX_DICT_SIZE,
36            max_string_length: DEFAULT_MAX_STRING_LENGTH,
37        },
38    )
39}
40
/// Limits enforced while decoding, so that hostile or corrupt input is
/// rejected instead of being decoded without bound.
#[derive(Clone, Debug)]
pub struct DecodeOptions {
    /// Deepest nesting level accepted. The root value is at depth 0, and
    /// decoding fails as soon as a value's depth exceeds this number.
    pub max_depth: usize,
    /// Largest accepted number of entries in the key dictionary.
    pub max_dict_size: usize,
    /// Largest accepted byte length of a dictionary key or string value.
    pub max_string_length: usize,
}
48
49/// Decodes with custom limits.
50pub fn decode_with_options(data: &[u8], opts: DecodeOptions) -> Result<Value<'static>> {
51    let mut d = Decoder {
52        buf: data,
53        off: 0,
54        max_depth: opts.max_depth,
55        max_dict: opts.max_dict_size,
56        max_str: opts.max_string_length,
57        dict: Vec::new(),
58    };
59    let value = d.decode_root()?;
60    if d.off != data.len() {
61        return Err(KodaError::decode("trailing bytes after root value"));
62    }
63    Ok(value)
64}
65
/// Cursor over a KODA byte buffer together with the limits and key
/// dictionary needed to decode values from it.
struct Decoder<'a> {
    /// The complete input being decoded.
    buf: &'a [u8],
    /// Index into `buf` of the next unread byte.
    off: usize,
    /// Deepest nesting level accepted.
    max_depth: usize,
    /// Largest accepted number of dictionary entries.
    max_dict: usize,
    /// Largest accepted string length in bytes.
    max_str: usize,
    /// Object keys, looked up by the indices stored in encoded objects.
    dict: Vec<String>,
}
74
75impl Decoder<'_> {
76    fn fail(&self, msg: &str) -> KodaError {
77        KodaError::decode(format!("{} (at offset {})", msg, self.off))
78    }
79
80    fn ensure(&self, n: usize) -> Result<()> {
81        if self.off + n > self.buf.len() {
82            return Err(self.fail("truncated input"));
83        }
84        Ok(())
85    }
86
87    fn read_u8(&mut self) -> Result<u8> {
88        self.ensure(1)?;
89        let b = self.buf[self.off];
90        self.off += 1;
91        Ok(b)
92    }
93
94    fn read_u32(&mut self) -> Result<u32> {
95        self.ensure(4)?;
96        let v = u32::from_be_bytes(self.buf[self.off..self.off + 4].try_into().unwrap());
97        self.off += 4;
98        Ok(v)
99    }
100
101    fn read_i64(&mut self) -> Result<i64> {
102        self.ensure(8)?;
103        let v = i64::from_be_bytes(self.buf[self.off..self.off + 8].try_into().unwrap());
104        self.off += 8;
105        Ok(v)
106    }
107
108    fn read_f64(&mut self) -> Result<f64> {
109        self.ensure(8)?;
110        let bits = u64::from_be_bytes(self.buf[self.off..self.off + 8].try_into().unwrap());
111        self.off += 8;
112        Ok(f64::from_bits(bits))
113    }
114
115    fn decode_root(&mut self) -> Result<Value<'static>> {
116        self.ensure(5)?;
117        if &self.buf[self.off..self.off + 4] != MAGIC {
118            return Err(self.fail("invalid magic number"));
119        }
120        self.off += 4;
121        let ver = self.read_u8()?;
122        if ver != VERSION {
123            return Err(self.fail("unsupported version"));
124        }
125
126        let dict_len = self.read_u32()? as usize;
127        if dict_len > self.max_dict {
128            return Err(self.fail("dictionary too large"));
129        }
130        self.dict.clear();
131        self.dict.reserve(dict_len);
132        for _ in 0..dict_len {
133            let key_len = self.read_u32()? as usize;
134            if key_len > self.max_str {
135                return Err(self.fail("key string too long"));
136            }
137            self.ensure(key_len)?;
138            let key_bytes = &self.buf[self.off..self.off + key_len];
139            self.off += key_len;
140            let key = std::str::from_utf8(key_bytes)
141                .map_err(|_| self.fail("invalid UTF-8 in dictionary key"))?;
142            self.dict.push(key.to_string());
143        }
144
145        let value = self.decode_value(0)?;
146        Ok(value)
147    }
148
149    fn decode_value(&mut self, depth: usize) -> Result<Value<'static>> {
150        if depth > self.max_depth {
151            return Err(self.fail("maximum nesting depth exceeded"));
152        }
153        let tag = self.read_u8()?;
154        match tag {
155            TAG_NULL => Ok(Value::Null),
156            TAG_FALSE => Ok(Value::Bool(false)),
157            TAG_TRUE => Ok(Value::Bool(true)),
158            TAG_INTEGER => {
159                let n = self.read_i64()?;
160                Ok(Value::Number(n as f64))
161            }
162            TAG_FLOAT => Ok(Value::Number(self.read_f64()?)),
163            TAG_STRING => {
164                let length = self.read_u32()? as usize;
165                if length > self.max_str {
166                    return Err(self.fail("string too long"));
167                }
168                self.ensure(length)?;
169                let b = &self.buf[self.off..self.off + length];
170                self.off += length;
171                let s = std::str::from_utf8(b).map_err(|_| self.fail("invalid UTF-8 in string"))?;
172                Ok(Value::String(Cow::Owned(s.to_string())))
173            }
174            TAG_ARRAY => {
175                let count = self.read_u32()?;
176                let mut arr = Vec::with_capacity(count as usize);
177                for _ in 0..count {
178                    arr.push(self.decode_value(depth + 1)?);
179                }
180                Ok(Value::Array(arr))
181            }
182            TAG_OBJECT => {
183                let count = self.read_u32()?;
184                let mut obj = BTreeMap::new();
185                for _ in 0..count {
186                    let key_idx = self.read_u32()? as usize;
187                    if key_idx >= self.dict.len() {
188                        return Err(self.fail("invalid key index"));
189                    }
190                    let key = Cow::Owned(self.dict[key_idx].clone());
191                    let val = self.decode_value(depth + 1)?;
192                    obj.insert(key, val);
193                }
194                Ok(Value::Object(obj))
195            }
196            _ => Err(self.fail("unknown type tag")),
197        }
198    }
199}
200
201// -----------------------------------------------------------------------------
202// Parallel decoding (optional rayon feature)
203// -----------------------------------------------------------------------------
204
#[cfg(feature = "parallel")]
mod parallel {
    use super::*;
    use rayon::prelude::*;

    /// Minimum number of array/object children to trigger parallel decoding.
    /// Set to balance parallelism vs rayon task overhead.
    const PARALLEL_THRESHOLD: usize = 128;

    /// Decodes KODA binary data in parallel with the default limits.
    ///
    /// Uses rayon to decode the children of large arrays and objects
    /// concurrently. A successful result equals what `decode` returns for the
    /// same input. Error messages can differ: failures found while parsing the
    /// header or measuring values carry no byte offset.
    /// Requires the `parallel` feature.
    pub fn decode_parallel(bytes: &[u8]) -> Result<Value<'static>> {
        decode_parallel_with_options(
            bytes,
            DecodeOptions {
                max_depth: DEFAULT_MAX_DEPTH,
                max_dict_size: DEFAULT_MAX_DICT_SIZE,
                max_string_length: DEFAULT_MAX_STRING_LENGTH,
            },
        )
    }

    /// Decodes in parallel with custom limits.
    ///
    /// The input is first measured in full, which validates its structure and
    /// rejects trailing bytes, and only then decoded.
    pub fn decode_parallel_with_options(
        data: &[u8],
        opts: DecodeOptions,
    ) -> Result<Value<'static>> {
        let (dict, value_start) = parse_header(data, &opts)?;
        if value_start >= data.len() {
            return Err(KodaError::decode("truncated input (no root value)"));
        }
        let root_len = scan_value_extent(data, value_start, &dict, &opts, 0)?.total_len();
        if root_len != data.len() - value_start {
            return Err(KodaError::decode("trailing bytes after root value"));
        }
        decode_value_parallel(data, value_start, root_len, &dict, &opts, 0)
    }

    /// Parses magic, version, and dictionary; returns (dict, offset where root value starts).
    fn parse_header(data: &[u8], opts: &DecodeOptions) -> Result<(Vec<String>, usize)> {
        if data.len() < 5 {
            return Err(KodaError::decode("truncated input"));
        }
        if &data[0..4] != MAGIC {
            return Err(KodaError::decode("invalid magic number"));
        }
        if data[4] != VERSION {
            return Err(KodaError::decode("unsupported version"));
        }
        let mut off = 5;
        let dict_len = read_u32(data, &mut off)? as usize;
        if dict_len > opts.max_dict_size {
            return Err(KodaError::decode("dictionary too large"));
        }
        // Each key needs at least its 4-byte length prefix; do not preallocate
        // beyond what the remaining input could possibly contain.
        let mut dict = Vec::with_capacity(dict_len.min(data.len().saturating_sub(off) / 4));
        for _ in 0..dict_len {
            let key_len = read_u32(data, &mut off)? as usize;
            if key_len > opts.max_string_length {
                return Err(KodaError::decode("key string too long"));
            }
            let key_bytes = take(data, &mut off, key_len)?;
            let key = std::str::from_utf8(key_bytes)
                .map_err(|_| KodaError::decode("invalid UTF-8 in dictionary key"))?;
            dict.push(key.to_owned());
        }
        Ok((dict, off))
    }

    /// Returns `buf[*off..*off + n]` and advances `*off` past it.
    ///
    /// The end position is computed with checked arithmetic so an untrusted
    /// length cannot wrap `usize`. `*off` is unchanged on failure.
    fn take<'b>(buf: &'b [u8], off: &mut usize, n: usize) -> Result<&'b [u8]> {
        let end = off
            .checked_add(n)
            .filter(|&end| end <= buf.len())
            .ok_or_else(|| KodaError::decode("truncated input"))?;
        let bytes = &buf[*off..end];
        *off = end;
        Ok(bytes)
    }

    /// Reads a big-endian `u32` at `*off` and advances past it.
    fn read_u32(buf: &[u8], off: &mut usize) -> Result<u32> {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(take(buf, off, 4)?);
        Ok(u32::from_be_bytes(raw))
    }

    /// Extent of a value: total byte length and optional child extents for arrays/objects.
    enum ValueExtent {
        Scalar(usize),
        Array {
            total: usize,
            children: Vec<(usize, usize)>, // (offset, len)
        },
        Object {
            total: usize,
            children: Vec<(u32, usize, usize)>, // (key_idx, offset, len)
        },
    }

    impl ValueExtent {
        /// Encoded length of the value in bytes, tag included.
        fn total_len(&self) -> usize {
            match self {
                ValueExtent::Scalar(n) => *n,
                ValueExtent::Array { total, .. } | ValueExtent::Object { total, .. } => *total,
            }
        }
    }

    /// Measures the value starting at `buf[base]` without building it.
    ///
    /// Enforces the depth and string-length limits and validates object key
    /// indices against `dict`. String bytes are not checked for UTF-8 here.
    /// For arrays and objects the direct children's offsets and lengths are
    /// returned; deeper descendants are measured but not recorded.
    fn scan_value_extent(
        buf: &[u8],
        base: usize,
        dict: &[String],
        opts: &DecodeOptions,
        depth: usize,
    ) -> Result<ValueExtent> {
        if depth > opts.max_depth {
            return Err(KodaError::decode("maximum nesting depth exceeded"));
        }
        let tag = *buf
            .get(base)
            .ok_or_else(|| KodaError::decode("truncated input"))?;
        let mut off = base + 1;
        match tag {
            TAG_NULL | TAG_FALSE | TAG_TRUE => Ok(ValueExtent::Scalar(off - base)),
            TAG_INTEGER | TAG_FLOAT => {
                take(buf, &mut off, 8)?;
                Ok(ValueExtent::Scalar(off - base))
            }
            TAG_STRING => {
                let length = read_u32(buf, &mut off)? as usize;
                if length > opts.max_string_length {
                    return Err(KodaError::decode("string too long"));
                }
                take(buf, &mut off, length)?;
                Ok(ValueExtent::Scalar(off - base))
            }
            TAG_ARRAY => {
                let count = read_u32(buf, &mut off)? as usize;
                // `count` is untrusted; every element occupies at least one
                // byte, so the unread byte count caps the preallocation.
                let mut children = Vec::with_capacity(count.min(buf.len().saturating_sub(off)));
                for _ in 0..count {
                    let child_len = scan_value_extent(buf, off, dict, opts, depth + 1)?.total_len();
                    children.push((off, child_len));
                    off += child_len;
                }
                Ok(ValueExtent::Array {
                    total: off - base,
                    children,
                })
            }
            TAG_OBJECT => {
                let count = read_u32(buf, &mut off)? as usize;
                // Every entry occupies at least five bytes (key index + tag).
                let mut children =
                    Vec::with_capacity(count.min(buf.len().saturating_sub(off) / 5));
                for _ in 0..count {
                    let key_idx = read_u32(buf, &mut off)?;
                    if key_idx as usize >= dict.len() {
                        return Err(KodaError::decode("invalid key index"));
                    }
                    let child_len = scan_value_extent(buf, off, dict, opts, depth + 1)?.total_len();
                    children.push((key_idx, off, child_len));
                    off += child_len;
                }
                Ok(ValueExtent::Object {
                    total: off - base,
                    children,
                })
            }
            _ => Err(KodaError::decode("unknown type tag")),
        }
    }

    /// Sequentially decodes the value at `buf[base..base + len]`, ensuring
    /// exactly `len` bytes are consumed.
    ///
    /// `dict` is copied into the sequential decoder, so callers should pass an
    /// empty slice when the value cannot contain object keys.
    fn decode_value_from_slice(
        buf: &[u8],
        base: usize,
        len: usize,
        dict: &[String],
        opts: &DecodeOptions,
        depth: usize,
    ) -> Result<Value<'static>> {
        let mut d = Decoder {
            buf,
            off: base,
            max_depth: opts.max_depth,
            max_dict: opts.max_dict_size,
            max_str: opts.max_string_length,
            dict: dict.to_vec(),
        };
        let value = d.decode_value(depth)?;
        if d.off != base + len {
            return Err(KodaError::decode("internal: slice length mismatch"));
        }
        Ok(value)
    }

    /// Reads the child count stored right after the tag of the array or object at `base`.
    fn child_count(buf: &[u8], base: usize) -> Result<usize> {
        let mut off = base + 1;
        Ok(read_u32(buf, &mut off)? as usize)
    }

    /// Decodes the value occupying `buf[base..base + len]`.
    ///
    /// Arrays and objects with at least `PARALLEL_THRESHOLD` children have
    /// their children decoded concurrently; everything else goes through the
    /// sequential decoder. Results are collected in input order, so a repeated
    /// object key resolves to its last occurrence, as in the sequential path.
    fn decode_value_parallel(
        buf: &[u8],
        base: usize,
        len: usize,
        dict: &[String],
        opts: &DecodeOptions,
        depth: usize,
    ) -> Result<Value<'static>> {
        let Some(&tag) = buf.get(base) else {
            return Err(KodaError::decode("truncated input"));
        };
        match tag {
            // Scalars never look at the key dictionary, so hand the sequential
            // decoder an empty one instead of cloning every key per value.
            TAG_NULL | TAG_FALSE | TAG_TRUE | TAG_INTEGER | TAG_FLOAT | TAG_STRING => {
                decode_value_from_slice(buf, base, len, &[], opts, depth)
            }
            TAG_ARRAY => {
                if child_count(buf, base)? < PARALLEL_THRESHOLD {
                    return decode_value_from_slice(buf, base, len, dict, opts, depth);
                }
                let ValueExtent::Array { children, .. } =
                    scan_value_extent(buf, base, dict, opts, depth)?
                else {
                    return decode_value_from_slice(buf, base, len, dict, opts, depth);
                };
                let items = children
                    .par_iter()
                    .map(|&(start, child_len)| {
                        decode_value_parallel(buf, start, child_len, dict, opts, depth + 1)
                    })
                    .collect::<Result<Vec<_>>>()?;
                Ok(Value::Array(items))
            }
            TAG_OBJECT => {
                if child_count(buf, base)? < PARALLEL_THRESHOLD {
                    return decode_value_from_slice(buf, base, len, dict, opts, depth);
                }
                let ValueExtent::Object { children, .. } =
                    scan_value_extent(buf, base, dict, opts, depth)?
                else {
                    return decode_value_from_slice(buf, base, len, dict, opts, depth);
                };
                let pairs = children
                    .par_iter()
                    .map(|&(key_idx, start, child_len)| {
                        decode_value_parallel(buf, start, child_len, dict, opts, depth + 1)
                            .map(|v| (key_idx, v))
                    })
                    .collect::<Result<Vec<_>>>()?;
                let mut obj = BTreeMap::new();
                // Key indices were range-checked by `scan_value_extent` above.
                for (key_idx, val) in pairs {
                    obj.insert(Cow::Owned(dict[key_idx as usize].clone()), val);
                }
                Ok(Value::Object(obj))
            }
            _ => decode_value_from_slice(buf, base, len, dict, opts, depth),
        }
    }
}
490
491#[cfg(feature = "parallel")]
492pub use parallel::{decode_parallel, decode_parallel_with_options};
493
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;
    use std::borrow::Cow;
    use std::collections::BTreeMap;

    /// Encodes an object holding a 150-element array (above the parallel
    /// threshold) and checks that both decoders return the same value.
    #[test]
    fn decode_parallel_matches_decode() {
        fn key(name: &str) -> Cow<'static, str> {
            Cow::Owned(name.to_string())
        }

        let items: Vec<Value<'static>> = (0..150)
            .map(|i| Value::Object(BTreeMap::from([(key("x"), Value::Number(i as f64))])))
            .collect();
        let root = Value::Object(BTreeMap::from([
            (key("name"), Value::String(key("test"))),
            (key("count"), Value::Number(42.0)),
            (key("items"), Value::Array(items)),
        ]));

        let bytes = crate::encoder::encode(&root).unwrap();
        let sequential = decode(&bytes).unwrap();
        let parallel = decode_parallel(&bytes).unwrap();
        assert_eq!(sequential, parallel);
    }
}
524}