use std::collections::BTreeMap;
use crate::document::value::Value;
use crate::error::{GrumpyError, Result};
// Type tags: the first byte of every encoded value, identifying which
// `Value` variant follows. `decode_recursive` rejects any other tag byte.
const TAG_NULL: u8 = 0x00;
const TAG_BOOL: u8 = 0x01;
const TAG_INTEGER: u8 = 0x02;
const TAG_FLOAT: u8 = 0x03;
const TAG_STRING: u8 = 0x04;
const TAG_BYTES: u8 = 0x05;
const TAG_ARRAY: u8 = 0x06;
const TAG_OBJECT: u8 = 0x07;
// Decoder limits. They are checked while decoding only; `encode` does not
// consult them.
/// Deepest nesting level the decoder accepts (the root value is depth 0).
const MAX_NESTING_DEPTH: usize = 64;
/// Largest accepted length prefix for strings, byte blobs and object keys (16 MiB).
const MAX_BLOB_LEN: u32 = 16 * 1024 * 1024;
/// Largest accepted element count for a single array.
const MAX_ARRAY_LEN: u32 = 1_000_000;
/// Largest accepted entry count for a single object.
const MAX_OBJECT_KEYS: u32 = 100_000;
/// Appends the binary encoding of `value` to the end of `buf`.
///
/// Every value starts with a one-byte type tag. Integers and floats follow as
/// 8 little-endian bytes, booleans as a single `0`/`1` byte. Strings, byte
/// blobs, arrays and objects carry a 4-byte little-endian length/count prefix
/// before their payload; object entries are written as a length-prefixed key
/// followed by the encoded value.
///
/// Existing contents of `buf` are left untouched.
pub fn encode(value: &Value, buf: &mut Vec<u8>) {
    // Writes `len` as the 4-byte little-endian prefix used by every
    // variable-sized variant.
    // NOTE(review): the `as u32` cast wraps for lengths above `u32::MAX`;
    // this function has no way to report that to the caller.
    fn put_len(out: &mut Vec<u8>, len: usize) {
        out.extend_from_slice(&(len as u32).to_le_bytes());
    }

    // Writes a length prefix followed by the raw payload bytes.
    fn put_prefixed(out: &mut Vec<u8>, payload: &[u8]) {
        put_len(out, payload.len());
        out.extend_from_slice(payload);
    }

    match value {
        Value::Null => buf.push(TAG_NULL),
        Value::Bool(flag) => buf.extend_from_slice(&[TAG_BOOL, u8::from(*flag)]),
        Value::Integer(int) => {
            buf.push(TAG_INTEGER);
            buf.extend_from_slice(&int.to_le_bytes());
        }
        Value::Float(float) => {
            buf.push(TAG_FLOAT);
            buf.extend_from_slice(&float.to_le_bytes());
        }
        Value::String(text) => {
            buf.push(TAG_STRING);
            put_prefixed(buf, text.as_bytes());
        }
        Value::Bytes(blob) => {
            buf.push(TAG_BYTES);
            put_prefixed(buf, blob);
        }
        Value::Array(items) => {
            buf.push(TAG_ARRAY);
            put_len(buf, items.len());
            for element in items {
                encode(element, buf);
            }
        }
        Value::Object(fields) => {
            buf.push(TAG_OBJECT);
            put_len(buf, fields.len());
            for (name, child) in fields {
                put_prefixed(buf, name.as_bytes());
                encode(child, buf);
            }
        }
    }
}
/// Encodes `value` into a freshly allocated byte vector.
///
/// The vector is pre-sized with `encoded_size(value)` before `encode` fills it.
pub fn encode_to_vec(value: &Value) -> Vec<u8> {
    let capacity = encoded_size(value);
    let mut out: Vec<u8> = Vec::with_capacity(capacity);
    encode(value, &mut out);
    out
}
/// Returns the number of bytes `encode` writes for `value`.
///
/// The total is one tag byte plus the payload: 1 byte for a bool, 8 for an
/// integer or float, and a 4-byte length/count prefix plus contents for
/// strings, byte blobs, arrays and objects. Each object entry adds a 4-byte
/// key-length prefix, the key bytes, and the size of its value.
pub fn encoded_size(value: &Value) -> usize {
    const TAG: usize = 1;
    const LEN_PREFIX: usize = 4;
    const SCALAR: usize = 8;

    match value {
        Value::Null => TAG,
        Value::Bool(_) => TAG + 1,
        Value::Integer(_) | Value::Float(_) => TAG + SCALAR,
        Value::String(text) => TAG + LEN_PREFIX + text.len(),
        Value::Bytes(blob) => TAG + LEN_PREFIX + blob.len(),
        Value::Array(items) => {
            let elements: usize = items.iter().map(encoded_size).sum();
            TAG + LEN_PREFIX + elements
        }
        Value::Object(fields) => {
            let entries: usize = fields
                .iter()
                .map(|(key, child)| LEN_PREFIX + key.len() + encoded_size(child))
                .sum();
            TAG + LEN_PREFIX + entries
        }
    }
}
/// Decodes a single value from the start of `data`.
///
/// Decoding begins at nesting depth 0. Bytes left over after the first
/// complete value are not inspected, so trailing data is silently ignored.
///
/// # Errors
/// Propagates any error produced by `decode_recursive`.
pub fn decode(data: &[u8]) -> Result<Value> {
    let mut remaining: &[u8] = data;
    decode_recursive(&mut remaining, 0)
}
/// Decodes one value from the front of `cursor`, advancing `cursor` past the
/// bytes that were consumed so further values can be read from the same buffer.
///
/// On failure the cursor may already have been advanced past the bytes read
/// before the error was detected.
///
/// # Errors
/// Propagates any error produced by `decode_recursive`.
pub fn decode_from_cursor(cursor: &mut &[u8]) -> Result<Value> {
    // A value read directly from the cursor is a root value.
    let root_depth: usize = 0;
    decode_recursive(cursor, root_depth)
}
/// Decodes one value from the front of `cursor`, advancing it past the bytes
/// consumed.
///
/// `depth` is the nesting level of the value being read (0 for a root value);
/// array elements and object values are decoded at `depth + 1`.
///
/// Any non-zero byte after a bool tag decodes as `true`. When an object
/// repeats a key, the last occurrence wins.
///
/// # Errors
/// Returns `GrumpyError::Codec` when:
/// - `depth` is greater than `MAX_NESTING_DEPTH`,
/// - the tag byte is not one of the `TAG_*` constants,
/// - a string, bytes or object-key length exceeds `MAX_BLOB_LEN`,
/// - an array count exceeds `MAX_ARRAY_LEN` or an object count exceeds
///   `MAX_OBJECT_KEYS`,
/// - the input ends early, or a string/key is not valid UTF-8 (reported by the
///   `read_*` helpers).
fn decode_recursive(cursor: &mut &[u8], depth: usize) -> Result<Value> {
    if depth > MAX_NESTING_DEPTH {
        return Err(GrumpyError::Codec(format!(
            "nesting depth exceeds maximum ({MAX_NESTING_DEPTH})"
        )));
    }

    let tag = read_u8(cursor)?;
    match tag {
        TAG_NULL => Ok(Value::Null),
        TAG_BOOL => {
            let b = read_u8(cursor)?;
            Ok(Value::Bool(b != 0))
        }
        TAG_INTEGER => {
            let n = read_i64_le(cursor)?;
            Ok(Value::Integer(n))
        }
        TAG_FLOAT => {
            let f = read_f64_le(cursor)?;
            Ok(Value::Float(f))
        }
        TAG_STRING => {
            let len = read_u32_le(cursor)?;
            if len > MAX_BLOB_LEN {
                return Err(GrumpyError::Codec(format!(
                    "string length {len} exceeds maximum ({MAX_BLOB_LEN})"
                )));
            }
            let s = read_string(cursor, len as usize)?;
            Ok(Value::String(s))
        }
        TAG_BYTES => {
            let len = read_u32_le(cursor)?;
            if len > MAX_BLOB_LEN {
                return Err(GrumpyError::Codec(format!(
                    "bytes length {len} exceeds maximum ({MAX_BLOB_LEN})"
                )));
            }
            let b = read_bytes(cursor, len as usize)?;
            Ok(Value::Bytes(b))
        }
        TAG_ARRAY => {
            let count = read_u32_le(cursor)?;
            if count > MAX_ARRAY_LEN {
                return Err(GrumpyError::Codec(format!(
                    "array length {count} exceeds maximum ({MAX_ARRAY_LEN})"
                )));
            }
            // `count` comes straight from the input and must not be trusted as
            // an allocation size: a few bytes could otherwise demand room for
            // a million elements at every nesting level. Each element needs at
            // least its tag byte, so the bytes still available are an upper
            // bound on how many elements can actually follow.
            let reserve = (count as usize).min(cursor.len());
            let mut arr = Vec::with_capacity(reserve);
            for _ in 0..count {
                arr.push(decode_recursive(cursor, depth + 1)?);
            }
            Ok(Value::Array(arr))
        }
        TAG_OBJECT => {
            let count = read_u32_le(cursor)?;
            if count > MAX_OBJECT_KEYS {
                return Err(GrumpyError::Codec(format!(
                    "object key count {count} exceeds maximum ({MAX_OBJECT_KEYS})"
                )));
            }
            let mut map = BTreeMap::new();
            for _ in 0..count {
                let key_len = read_u32_le(cursor)?;
                if key_len > MAX_BLOB_LEN {
                    return Err(GrumpyError::Codec(format!(
                        "object key length {key_len} exceeds maximum ({MAX_BLOB_LEN})"
                    )));
                }
                let key = read_string(cursor, key_len as usize)?;
                let val = decode_recursive(cursor, depth + 1)?;
                // A repeated key replaces the earlier entry.
                map.insert(key, val);
            }
            Ok(Value::Object(map))
        }
        _ => Err(GrumpyError::Codec(format!(
            "unknown type tag: 0x{tag:02x}"
        ))),
    }
}
/// Reads a single byte from the front of `cursor` and advances past it.
///
/// # Errors
/// Returns `GrumpyError::Codec` if `cursor` is empty; the cursor is left
/// unchanged in that case.
fn read_u8(cursor: &mut &[u8]) -> Result<u8> {
    match cursor.split_first() {
        Some((&byte, rest)) => {
            *cursor = rest;
            Ok(byte)
        }
        None => Err(GrumpyError::Codec("unexpected end of data".into())),
    }
}
fn read_u32_le(cursor: &mut &[u8]) -> Result<u32> {
if cursor.len() < 4 {
return Err(GrumpyError::Codec("unexpected end of data reading u32".into()));
}
let val = u32::from_le_bytes(cursor[..4].try_into().unwrap());
*cursor = &cursor[4..];
Ok(val)
}
fn read_i64_le(cursor: &mut &[u8]) -> Result<i64> {
if cursor.len() < 8 {
return Err(GrumpyError::Codec("unexpected end of data reading i64".into()));
}
let val = i64::from_le_bytes(cursor[..8].try_into().unwrap());
*cursor = &cursor[8..];
Ok(val)
}
fn read_f64_le(cursor: &mut &[u8]) -> Result<f64> {
if cursor.len() < 8 {
return Err(GrumpyError::Codec("unexpected end of data reading f64".into()));
}
let val = f64::from_le_bytes(cursor[..8].try_into().unwrap());
*cursor = &cursor[8..];
Ok(val)
}
/// Copies the next `len` bytes from the front of `cursor` into a new vector
/// and advances the cursor past them.
///
/// # Errors
/// Returns `GrumpyError::Codec` if fewer than `len` bytes remain; the cursor is
/// left unchanged in that case.
fn read_bytes(cursor: &mut &[u8], len: usize) -> Result<Vec<u8>> {
    let available = cursor.len();
    if len > available {
        return Err(GrumpyError::Codec(format!(
            "unexpected end of data: need {len} bytes, have {}",
            available
        )));
    }
    let (payload, rest) = cursor.split_at(len);
    let owned = payload.to_vec();
    *cursor = rest;
    Ok(owned)
}
/// Reads `len` bytes from the front of `cursor` via `read_bytes` and converts
/// them into a `String`.
///
/// # Errors
/// Propagates the `read_bytes` error when too few bytes remain, and returns
/// `GrumpyError::Codec` when the bytes are not valid UTF-8. In the UTF-8 case
/// the cursor has already been advanced past the rejected bytes.
fn read_string(cursor: &mut &[u8], len: usize) -> Result<String> {
    let raw = read_bytes(cursor, len)?;
    match String::from_utf8(raw) {
        Ok(text) => Ok(text),
        Err(e) => Err(GrumpyError::Codec(format!("invalid UTF-8: {e}"))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value`, checks that `encoded_size` predicted the exact byte
    /// count, then decodes the bytes and asserts the result equals the input.
    fn round_trip(value: &Value) {
        let encoded = encode_to_vec(value);
        assert_eq!(
            encoded.len(),
            encoded_size(value),
            "encoded_size mismatch for {value:?}"
        );
        let decoded = decode(&encoded).unwrap();
        assert_eq!(*value, decoded);
    }

    /// Builds `tag` followed by `len` as a little-endian `u32`: the header of
    /// any length-prefixed value, with no payload after it.
    fn header(tag: u8, len: u32) -> Vec<u8> {
        let mut data = vec![tag];
        data.extend_from_slice(&len.to_le_bytes());
        data
    }

    /// Asserts that decoding `data` fails with a message containing `needle`.
    fn assert_decode_error(data: &[u8], needle: &str) {
        let msg = decode(data).unwrap_err().to_string();
        assert!(
            msg.contains(needle),
            "expected {needle:?} in error message, got {msg:?}"
        );
    }

    #[test]
    fn test_null_round_trip() {
        round_trip(&Value::Null);
    }

    #[test]
    fn test_bool_round_trip() {
        round_trip(&Value::Bool(true));
        round_trip(&Value::Bool(false));
    }

    #[test]
    fn test_integer_round_trip() {
        round_trip(&Value::Integer(0));
        round_trip(&Value::Integer(42));
        round_trip(&Value::Integer(-1));
        round_trip(&Value::Integer(i64::MAX));
        round_trip(&Value::Integer(i64::MIN));
    }

    #[test]
    fn test_float_round_trip() {
        round_trip(&Value::Float(0.0));
        round_trip(&Value::Float(std::f64::consts::PI));
        round_trip(&Value::Float(-1.0e100));
        round_trip(&Value::Float(f64::INFINITY));
        round_trip(&Value::Float(f64::NEG_INFINITY));
    }

    #[test]
    fn test_string_round_trip() {
        round_trip(&Value::String(String::new()));
        round_trip(&Value::String("hello".into()));
        round_trip(&Value::String("émoji: 🦀".into()));
        round_trip(&Value::String("a".repeat(10_000)));
    }

    #[test]
    fn test_bytes_round_trip() {
        round_trip(&Value::Bytes(vec![]));
        round_trip(&Value::Bytes(vec![0, 1, 2, 255]));
        round_trip(&Value::Bytes(vec![0xAB; 5000]));
    }

    #[test]
    fn test_array_round_trip() {
        round_trip(&Value::Array(vec![]));
        round_trip(&Value::Array(vec![
            Value::Integer(1),
            Value::String("two".into()),
            Value::Null,
        ]));
    }

    #[test]
    fn test_object_round_trip() {
        round_trip(&Value::Object(BTreeMap::new()));
        round_trip(&Value::Object(BTreeMap::from([
            ("name".into(), Value::String("grumpy".into())),
            ("version".into(), Value::Integer(1)),
        ])));
    }

    #[test]
    fn test_nested_complex_document() {
        let value = Value::Object(BTreeMap::from([
            ("name".into(), Value::String("GrumpyDB".into())),
            ("version".into(), Value::Integer(1)),
            ("active".into(), Value::Bool(true)),
            ("score".into(), Value::Float(99.5)),
            ("data".into(), Value::Bytes(vec![0xDE, 0xAD])),
            (
                "tags".into(),
                Value::Array(vec![
                    Value::String("db".into()),
                    Value::String("rust".into()),
                    Value::Null,
                ]),
            ),
            (
                "metadata".into(),
                Value::Object(BTreeMap::from([
                    ("created".into(), Value::Integer(1234567890)),
                    (
                        "nested".into(),
                        Value::Object(BTreeMap::from([(
                            "deep".into(),
                            Value::Bool(true),
                        )])),
                    ),
                ])),
            ),
        ]));
        round_trip(&value);
    }

    #[test]
    fn test_encoded_size_matches() {
        let values = vec![
            Value::Null,
            Value::Bool(true),
            Value::Integer(42),
            Value::Float(std::f64::consts::PI),
            Value::String("test".into()),
            Value::Bytes(vec![1, 2, 3]),
            Value::Array(vec![Value::Integer(1), Value::Integer(2)]),
            Value::Object(BTreeMap::from([("k".into(), Value::Null)])),
        ];
        for v in &values {
            let encoded = encode_to_vec(v);
            assert_eq!(encoded.len(), encoded_size(v), "mismatch for {v:?}");
        }
    }

    #[test]
    fn test_decode_unknown_tag() {
        let data = [0xFF];
        let result = decode(&data);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("unknown type tag"));
    }

    #[test]
    fn test_decode_truncated_string() {
        let encoded = encode_to_vec(&Value::String("hello".into()));
        // Cut inside the 4-byte length prefix.
        let truncated = &encoded[..3];
        assert!(decode(truncated).is_err());
    }

    #[test]
    fn test_decode_truncated_integer() {
        let encoded = encode_to_vec(&Value::Integer(42));
        // Tag plus only 4 of the 8 payload bytes.
        let truncated = &encoded[..5];
        assert!(decode(truncated).is_err());
    }

    #[test]
    fn test_decode_empty_data() {
        assert!(decode(&[]).is_err());
    }

    #[test]
    fn test_decode_invalid_utf8() {
        let mut data = vec![TAG_STRING];
        data.extend_from_slice(&3u32.to_le_bytes());
        data.extend_from_slice(&[0xFF, 0xFE, 0xFD]);
        let result = decode(&data);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("UTF-8"));
    }

    #[test]
    fn test_nesting_depth_limit() {
        let mut value = Value::Null;
        for _ in 0..MAX_NESTING_DEPTH + 5 {
            value = Value::Array(vec![value]);
        }
        let encoded = encode_to_vec(&value);
        let result = decode(&encoded);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("nesting depth"));
    }

    #[test]
    fn test_nesting_at_max_depth_ok() {
        let mut value = Value::Null;
        for _ in 0..MAX_NESTING_DEPTH {
            value = Value::Array(vec![value]);
        }
        let encoded = encode_to_vec(&value);
        assert!(decode(&encoded).is_ok());
    }

    #[test]
    fn test_empty_containers() {
        round_trip(&Value::String(String::new()));
        round_trip(&Value::Bytes(vec![]));
        round_trip(&Value::Array(vec![]));
        round_trip(&Value::Object(BTreeMap::new()));
    }

    #[test]
    fn test_float_nan() {
        // NaN never compares equal to itself, so `round_trip` cannot be used.
        let encoded = encode_to_vec(&Value::Float(f64::NAN));
        let decoded = decode(&encoded).unwrap();
        match decoded {
            Value::Float(f) => assert!(f.is_nan()),
            _ => panic!("expected Float"),
        }
    }

    #[test]
    fn test_decode_rejects_oversized_string_length() {
        assert_decode_error(&header(TAG_STRING, MAX_BLOB_LEN + 1), "string length");
    }

    #[test]
    fn test_decode_rejects_oversized_bytes_length() {
        assert_decode_error(&header(TAG_BYTES, MAX_BLOB_LEN + 1), "bytes length");
    }

    #[test]
    fn test_decode_rejects_oversized_array_count() {
        assert_decode_error(&header(TAG_ARRAY, MAX_ARRAY_LEN + 1), "array length");
    }

    #[test]
    fn test_decode_rejects_oversized_object_count() {
        assert_decode_error(&header(TAG_OBJECT, MAX_OBJECT_KEYS + 1), "object key count");
    }

    #[test]
    fn test_decode_rejects_oversized_object_key_length() {
        let mut data = header(TAG_OBJECT, 1);
        data.extend_from_slice(&(MAX_BLOB_LEN + 1).to_le_bytes());
        assert_decode_error(&data, "object key length");
    }

    #[test]
    fn test_decode_array_count_beyond_available_data() {
        // Claims the maximum element count but supplies no elements.
        assert!(decode(&header(TAG_ARRAY, MAX_ARRAY_LEN)).is_err());
    }

    #[test]
    fn test_decode_from_cursor_advances() {
        let mut buf = Vec::new();
        encode(&Value::Integer(7), &mut buf);
        encode(&Value::String("x".into()), &mut buf);

        let mut cursor = buf.as_slice();
        assert_eq!(decode_from_cursor(&mut cursor).unwrap(), Value::Integer(7));
        assert_eq!(
            decode_from_cursor(&mut cursor).unwrap(),
            Value::String("x".into())
        );
        assert!(cursor.is_empty());
        assert!(decode_from_cursor(&mut cursor).is_err());
    }
}