use super::error::{OtomlError, Result};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use toml::Value;
/// Four-byte signature that opens every OBIN buffer (ASCII `OBIN`).
const MAGIC: &[u8; 4] = &[b'O', b'B', b'I', b'N'];
/// Format version byte written right after [`MAGIC`]; `load_obin` rejects any other value.
const VERSION: u8 = 0x01;
/// One-byte type tags; every encoded value starts with exactly one of these.
mod types {
    // Tags with no payload.
    /// Decoded as the string `"null"`; `write_value` never emits it.
    pub const NULL: u8 = 0x00;
    pub const BOOL_FALSE: u8 = 0x01;
    pub const BOOL_TRUE: u8 = 0x02;

    // Signed little-endian integers; the encoder picks the narrowest width that fits.
    pub const INT8: u8 = 0x10;
    pub const INT16: u8 = 0x11;
    pub const INT32: u8 = 0x12;
    pub const INT64: u8 = 0x13;

    // 8-byte little-endian `f64`.
    pub const FLOAT64: u8 = 0x20;

    // Varint byte length followed by UTF-8 bytes.
    pub const STRING: u8 = 0x30;

    // Varint element count followed by that many tagged values.
    pub const ARRAY: u8 = 0x40;

    // Varint entry count followed by (string key, tagged value) pairs.
    pub const TABLE: u8 = 0x50;

    // Datetime stored as its string form, laid out like STRING.
    pub const DATETIME: u8 = 0x60;
}
/// Serializes `value` into an OBIN byte buffer.
///
/// Layout: [`MAGIC`], one [`VERSION`] byte, then the encoded root value. The value is
/// first converted to a [`toml::Value`] and canonicalized so table keys are ordered.
///
/// # Errors
/// Returns [`OtomlError::BinarySerialize`] when `value` cannot be represented as a
/// TOML value, and propagates any error from `write_value`.
pub fn dump_obin<T: Serialize>(value: &T) -> Result<Vec<u8>> {
    let canonical = toml::Value::try_from(value)
        .map(|tree| to_canonical(&tree))
        .map_err(|err| OtomlError::BinarySerialize(err.to_string()))?;

    let mut out = MAGIC.to_vec();
    out.push(VERSION);
    write_value(&canonical, &mut out)?;
    Ok(out)
}
/// Deserializes a `T` from an OBIN buffer produced by [`dump_obin`].
///
/// The buffer must be [`MAGIC`], one [`VERSION`] byte, and then exactly one encoded
/// value. Extra bytes after that value are rejected rather than silently ignored.
///
/// # Errors
/// Returns [`OtomlError::BinaryDeserialize`] when:
/// - the header is too short or wrong;
/// - the payload is malformed or followed by trailing bytes;
/// - the decoded value cannot be converted into `T`.
pub fn load_obin<T: for<'de> Deserialize<'de>>(data: &[u8]) -> Result<T> {
    let header_len = MAGIC.len() + 1;
    if data.len() < header_len {
        return Err(OtomlError::BinaryDeserialize(
            "data too short for OBIN header".to_string(),
        ));
    }
    if !data.starts_with(MAGIC) {
        return Err(OtomlError::BinaryDeserialize(
            "invalid OBIN magic bytes".to_string(),
        ));
    }
    let version = data[MAGIC.len()];
    if version != VERSION {
        return Err(OtomlError::BinaryDeserialize(format!(
            "unsupported OBIN version: {}",
            version
        )));
    }

    let mut pos = header_len;
    let value = read_value(data, &mut pos)?;
    // A well-formed buffer holds exactly one root value. Leftover bytes indicate
    // corruption or concatenated payloads, so refuse them.
    if pos != data.len() {
        return Err(OtomlError::BinaryDeserialize(format!(
            "unexpected {} trailing byte(s) after OBIN value",
            data.len() - pos
        )));
    }

    // The decoded tree is rendered to TOML text and parsed back into `T`.
    // NOTE(review): this presumably requires the root to be a table that
    // `toml::to_string` accepts — confirm whether non-table roots must be supported.
    let toml_str =
        toml::to_string(&value).map_err(|e| OtomlError::BinaryDeserialize(e.to_string()))?;
    toml::from_str(&toml_str).map_err(|e| OtomlError::BinaryDeserialize(e.to_string()))
}
/// Returns a deep copy of `value` whose tables are rebuilt in ascending key order.
///
/// Recurses into nested tables and into tables held inside arrays; scalars are cloned.
/// NOTE(review): the rebuilt order is only observable if toml's map keeps insertion
/// order; `write_value` sorts table keys again in any case.
fn to_canonical(value: &Value) -> Value {
    match value {
        Value::Array(items) => Value::Array(items.iter().map(to_canonical).collect()),
        Value::Table(table) => {
            let ordered: BTreeMap<String, Value> = table
                .iter()
                .map(|(key, child)| (key.clone(), to_canonical(child)))
                .collect();
            Value::Table(ordered.into_iter().collect())
        }
        scalar => scalar.clone(),
    }
}
/// Appends the OBIN encoding of `value` (type tag, then payload) to `buf`.
///
/// Table entries are always written in ascending key order, independent of the
/// map's iteration order, so equal tables encode to identical bytes.
///
/// # Errors
/// Never fails today; the `Result` lets callers propagate with `?`.
fn write_value(value: &Value, buf: &mut Vec<u8>) -> Result<()> {
    match value {
        Value::Boolean(flag) => {
            buf.push(if *flag {
                types::BOOL_TRUE
            } else {
                types::BOOL_FALSE
            });
        }
        Value::Integer(i) => write_integer(*i, buf),
        Value::Float(f) => {
            buf.push(types::FLOAT64);
            buf.extend_from_slice(&f.to_le_bytes());
        }
        Value::String(s) => {
            buf.push(types::STRING);
            write_string(s, buf);
        }
        Value::Array(arr) => {
            buf.push(types::ARRAY);
            write_varint(arr.len() as u64, buf);
            for item in arr {
                write_value(item, buf)?;
            }
        }
        Value::Table(table) => {
            buf.push(types::TABLE);
            // Sort the (key, value) pairs themselves: no second lookup, no unwrap.
            // Keys are unique, so an unstable sort yields the same order as a stable one.
            let mut entries: Vec<(&String, &Value)> = table.iter().collect();
            entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
            write_varint(entries.len() as u64, buf);
            for (key, item) in entries {
                write_string(key, buf);
                write_value(item, buf)?;
            }
        }
        Value::Datetime(dt) => {
            buf.push(types::DATETIME);
            write_string(&dt.to_string(), buf);
        }
    }
    Ok(())
}
/// Decodes one tagged value starting at `data[*pos]` and advances `*pos` past it.
///
/// `NULL` decodes to the string `"null"`. A `DATETIME` whose text does not parse as a
/// TOML datetime falls back to a plain string.
///
/// # Errors
/// Returns `OtomlError::BinaryDeserialize` for any of:
/// - truncated input or an unknown type tag;
/// - invalid UTF-8;
/// - container lengths that do not fit in `usize`;
/// - nesting deeper than [`MAX_NESTING_DEPTH`].
fn read_value(data: &[u8], pos: &mut usize) -> Result<Value> {
    read_value_nested(data, pos, 0)
}

/// Maximum number of nested arrays/tables the decoder will descend into.
///
/// Decoding recurses once per level, so without a bound a crafted buffer of nested
/// containers could exhaust the stack.
const MAX_NESTING_DEPTH: usize = 128;

/// Recursive worker for [`read_value`]; `depth` counts the enclosing containers.
fn read_value_nested(data: &[u8], pos: &mut usize, depth: usize) -> Result<Value> {
    let type_tag = *data
        .get(*pos)
        .ok_or_else(|| OtomlError::BinaryDeserialize("unexpected end of data".to_string()))?;
    *pos += 1;
    match type_tag {
        types::NULL => Ok(Value::String("null".to_string())),
        types::BOOL_FALSE => Ok(Value::Boolean(false)),
        types::BOOL_TRUE => Ok(Value::Boolean(true)),
        types::INT8 => Ok(Value::Integer(i64::from(i8::from_le_bytes(
            read_fixed(data, pos)?,
        )))),
        types::INT16 => Ok(Value::Integer(i64::from(i16::from_le_bytes(
            read_fixed(data, pos)?,
        )))),
        types::INT32 => Ok(Value::Integer(i64::from(i32::from_le_bytes(
            read_fixed(data, pos)?,
        )))),
        types::INT64 => Ok(Value::Integer(i64::from_le_bytes(read_fixed(data, pos)?))),
        types::FLOAT64 => Ok(Value::Float(f64::from_le_bytes(read_fixed(data, pos)?))),
        types::STRING => read_string(data, pos).map(Value::String),
        types::ARRAY => {
            let len = read_container_len(data, pos, depth)?;
            // `len` comes from the input and is untrusted. Every element needs at least
            // its 1-byte tag, so the remaining bytes bound how many can really follow.
            let mut arr = Vec::with_capacity(len.min(data.len() - *pos));
            for _ in 0..len {
                arr.push(read_value_nested(data, pos, depth + 1)?);
            }
            Ok(Value::Array(arr))
        }
        types::TABLE => {
            let len = read_container_len(data, pos, depth)?;
            let mut table = toml::map::Map::new();
            for _ in 0..len {
                let key = read_string(data, pos)?;
                let value = read_value_nested(data, pos, depth + 1)?;
                table.insert(key, value);
            }
            Ok(Value::Table(table))
        }
        types::DATETIME => {
            let text = read_string(data, pos)?;
            Ok(match text.parse::<toml::value::Datetime>() {
                Ok(dt) => Value::Datetime(dt),
                Err(_) => Value::String(text),
            })
        }
        _ => Err(OtomlError::BinaryDeserialize(format!(
            "unknown type tag: 0x{:02X}",
            type_tag
        ))),
    }
}

/// Enforces the nesting limit, then reads a container's varint element count.
fn read_container_len(data: &[u8], pos: &mut usize, depth: usize) -> Result<usize> {
    if depth >= MAX_NESTING_DEPTH {
        return Err(OtomlError::BinaryDeserialize(format!(
            "nesting deeper than {} levels",
            MAX_NESTING_DEPTH
        )));
    }
    let len = read_varint(data, pos)?;
    // Checked conversion: `as usize` would silently truncate on 32-bit targets.
    usize::try_from(len).map_err(|_| {
        OtomlError::BinaryDeserialize(format!("container length {} too large", len))
    })
}

/// Reads exactly `N` bytes at `*pos` and advances `*pos` past them.
fn read_fixed<const N: usize>(data: &[u8], pos: &mut usize) -> Result<[u8; N]> {
    let bytes = data
        .get(*pos..)
        .and_then(|rest| rest.get(..N))
        .and_then(|chunk| <[u8; N]>::try_from(chunk).ok())
        .ok_or_else(|| OtomlError::BinaryDeserialize("unexpected end".to_string()))?;
    *pos += N;
    Ok(bytes)
}
fn write_integer(i: i64, buf: &mut Vec<u8>) {
if i >= i8::MIN as i64 && i <= i8::MAX as i64 {
buf.push(types::INT8);
buf.push(i as i8 as u8);
} else if i >= i16::MIN as i64 && i <= i16::MAX as i64 {
buf.push(types::INT16);
buf.extend_from_slice(&(i as i16).to_le_bytes());
} else if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
buf.push(types::INT32);
buf.extend_from_slice(&(i as i32).to_le_bytes());
} else {
buf.push(types::INT64);
buf.extend_from_slice(&i.to_le_bytes());
}
}
/// Appends `s` as a varint byte-length prefix followed by its raw UTF-8 bytes.
fn write_string(s: &str, buf: &mut Vec<u8>) {
    write_varint(s.len() as u64, buf);
    buf.extend_from_slice(s.as_bytes());
}
fn read_string(data: &[u8], pos: &mut usize) -> Result<String> {
let len = read_varint(data, pos)? as usize;
if *pos + len > data.len() {
return Err(OtomlError::BinaryDeserialize(
"string length exceeds data".to_string(),
));
}
let s = std::str::from_utf8(&data[*pos..*pos + len])
.map_err(|e| OtomlError::BinaryDeserialize(e.to_string()))?
.to_string();
*pos += len;
Ok(s)
}
/// Appends `value` as an unsigned LEB128 varint.
///
/// The value is emitted 7 bits at a time, least-significant group first. The high bit
/// of each byte is set when more bytes follow, so `0` encodes as a single `0x00`.
fn write_varint(mut value: u64, buf: &mut Vec<u8>) {
    while value >= 0x80 {
        // Low 7 bits with the continuation flag set.
        buf.push((value as u8) | 0x80);
        value >>= 7;
    }
    buf.push(value as u8);
}
/// Reads an unsigned LEB128 varint at `*pos` and advances `*pos` past it.
///
/// # Errors
/// Returns `OtomlError::BinaryDeserialize` if the data ends mid-varint, or if the
/// encoded value does not fit in a `u64`. That covers more than 10 bytes, and a 10th
/// byte carrying more than the single remaining bit.
fn read_varint(data: &[u8], pos: &mut usize) -> Result<u64> {
    let mut result: u64 = 0;
    let mut shift: u32 = 0;
    loop {
        let byte = *data.get(*pos).ok_or_else(|| {
            OtomlError::BinaryDeserialize("unexpected end reading varint".to_string())
        })?;
        *pos += 1;
        let payload = u64::from(byte & 0x7F);
        // At shift 63 only bit 0 still fits in a u64. Larger payloads used to be
        // silently truncated by the shift, decoding to a wrong value.
        if shift == 63 && payload > 1 {
            return Err(OtomlError::BinaryDeserialize("varint overflow".to_string()));
        }
        result |= payload << shift;
        if byte & 0x80 == 0 {
            return Ok(result);
        }
        shift += 7;
        if shift > 63 {
            return Err(OtomlError::BinaryDeserialize("varint overflow".to_string()));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestStruct {
        name: String,
        value: i32,
        enabled: bool,
    }

    /// Shared fixture for tests that only need some valid value.
    fn sample() -> TestStruct {
        TestStruct {
            name: "test".to_string(),
            value: 42,
            enabled: true,
        }
    }

    #[test]
    fn test_roundtrip() {
        let data = sample();
        let bytes = dump_obin(&data).unwrap();
        let parsed: TestStruct = load_obin(&bytes).unwrap();
        assert_eq!(data, parsed);
    }

    #[test]
    fn test_header() {
        let bytes = dump_obin(&sample()).unwrap();
        assert_eq!(&bytes[0..4], b"OBIN");
        assert_eq!(bytes[4], VERSION);
    }

    #[test]
    fn test_deterministic() {
        let data = sample();
        assert_eq!(dump_obin(&data).unwrap(), dump_obin(&data).unwrap());
    }

    #[test]
    fn test_integer_compression() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Ints {
            small: i8,
            medium: i16,
            large: i32,
            huge: i64,
        }
        let data = Ints {
            small: 42,
            medium: 1000,
            large: 100000,
            huge: 10000000000,
        };
        let bytes = dump_obin(&data).unwrap();
        let parsed: Ints = load_obin(&bytes).unwrap();
        assert_eq!(data, parsed);
    }

    #[test]
    fn test_negative_integers() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Negs {
            a: i64,
            b: i64,
            c: i64,
            d: i64,
        }
        let data = Negs {
            a: -1,
            b: -129,
            c: -40_000,
            d: -10_000_000_000,
        };
        let bytes = dump_obin(&data).unwrap();
        let parsed: Negs = load_obin(&bytes).unwrap();
        assert_eq!(data, parsed);
    }

    #[test]
    fn test_integer_widths() {
        let encode = |i: i64| {
            let mut buf = Vec::new();
            write_integer(i, &mut buf);
            buf
        };
        assert_eq!(encode(42), vec![types::INT8, 42]);
        assert_eq!(encode(1000), vec![types::INT16, 0xE8, 0x03]);
        assert_eq!(encode(100_000)[0], types::INT32);
        assert_eq!(encode(100_000).len(), 5);
        assert_eq!(encode(10_000_000_000)[0], types::INT64);
        assert_eq!(encode(10_000_000_000).len(), 9);
    }

    #[test]
    fn test_varint_roundtrip() {
        for &n in &[0u64, 1, 127, 128, 300, 16_384, u64::MAX] {
            let mut buf = Vec::new();
            write_varint(n, &mut buf);
            let mut pos = 0;
            assert_eq!(read_varint(&buf, &mut pos).unwrap(), n);
            assert_eq!(pos, buf.len());
        }
    }

    #[test]
    fn test_string_and_float_roundtrip() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Mixed {
            label: String,
            ratio: f64,
        }
        let data = Mixed {
            label: "héllo ✓ \"quoted\"".to_string(),
            ratio: 0.25,
        };
        let bytes = dump_obin(&data).unwrap();
        let parsed: Mixed = load_obin(&bytes).unwrap();
        assert_eq!(data, parsed);
    }

    #[test]
    fn test_nested() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Outer {
            inner: Inner,
        }
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Inner {
            value: i32,
        }
        let data = Outer {
            inner: Inner { value: 42 },
        };
        let bytes = dump_obin(&data).unwrap();
        let parsed: Outer = load_obin(&bytes).unwrap();
        assert_eq!(data, parsed);
    }

    #[test]
    fn test_array() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct WithArray {
            items: Vec<i32>,
        }
        let data = WithArray {
            items: vec![1, 2, 3, 4, 5],
        };
        let bytes = dump_obin(&data).unwrap();
        let parsed: WithArray = load_obin(&bytes).unwrap();
        assert_eq!(data, parsed);
    }

    #[test]
    fn test_invalid_magic() {
        let data = b"XXXX\x01";
        let result: Result<TestStruct> = load_obin(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_version() {
        let data = b"OBIN\xFF";
        let result: Result<TestStruct> = load_obin(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_too_short_header() {
        let result: Result<TestStruct> = load_obin(b"OBI");
        assert!(result.is_err());
    }

    #[test]
    fn test_truncated_input() {
        // Header only: no root value follows.
        let result: Result<TestStruct> = load_obin(b"OBIN\x01");
        assert!(result.is_err());
        // STRING tag declaring 5 bytes when only 2 follow.
        let result: Result<TestStruct> = load_obin(b"OBIN\x01\x30\x05ab");
        assert!(result.is_err());
    }

    #[test]
    fn test_unknown_type_tag() {
        let result: Result<TestStruct> = load_obin(b"OBIN\x01\xEE");
        assert!(result.is_err());
    }
}