use arcstr::ArcStr;
use grafeo_common::types::Value;
use std::collections::BTreeMap;
use std::io::{Read, Write};
use std::sync::Arc;
const TAG_NULL: u8 = 0;
const TAG_BOOL: u8 = 1;
const TAG_INT64: u8 = 2;
const TAG_FLOAT64: u8 = 3;
const TAG_STRING: u8 = 4;
const TAG_BYTES: u8 = 5;
const TAG_TIMESTAMP: u8 = 6;
const TAG_LIST: u8 = 7;
const TAG_MAP: u8 = 8;
const TAG_VECTOR: u8 = 9;
const TAG_DATE: u8 = 10;
const TAG_TIME: u8 = 11;
const TAG_DURATION: u8 = 12;
const TAG_PATH: u8 = 13;
const TAG_ZONED_DATETIME: u8 = 14;
const TAG_GCOUNTER: u8 = 15;
const TAG_PNCOUNTER: u8 = 16;
pub fn serialize_value<W: Write + ?Sized>(value: &Value, w: &mut W) -> std::io::Result<usize> {
match value {
Value::Null => {
w.write_all(&[TAG_NULL])?;
Ok(1)
}
Value::Bool(b) => {
w.write_all(&[TAG_BOOL, u8::from(*b)])?;
Ok(2)
}
Value::Int64(i) => {
w.write_all(&[TAG_INT64])?;
w.write_all(&i.to_le_bytes())?;
Ok(9)
}
Value::Float64(f) => {
w.write_all(&[TAG_FLOAT64])?;
w.write_all(&f.to_le_bytes())?;
Ok(9)
}
Value::String(s) => {
w.write_all(&[TAG_STRING])?;
let bytes = s.as_bytes();
w.write_all(&(bytes.len() as u64).to_le_bytes())?;
w.write_all(bytes)?;
Ok(1 + 8 + bytes.len())
}
Value::Bytes(b) => {
w.write_all(&[TAG_BYTES])?;
w.write_all(&(b.len() as u64).to_le_bytes())?;
w.write_all(b)?;
Ok(1 + 8 + b.len())
}
Value::Timestamp(t) => {
w.write_all(&[TAG_TIMESTAMP])?;
let micros = t.as_micros();
w.write_all(µs.to_le_bytes())?;
Ok(9)
}
Value::List(items) => {
w.write_all(&[TAG_LIST])?;
w.write_all(&(items.len() as u64).to_le_bytes())?;
let mut total = 1 + 8;
for item in items.iter() {
total += serialize_value(item, w)?;
}
Ok(total)
}
Value::Map(map) => {
w.write_all(&[TAG_MAP])?;
w.write_all(&(map.len() as u64).to_le_bytes())?;
let mut total = 1 + 8;
for (key, val) in map.iter() {
let key_bytes = key.as_str().as_bytes();
w.write_all(&(key_bytes.len() as u64).to_le_bytes())?;
w.write_all(key_bytes)?;
total += 8 + key_bytes.len();
total += serialize_value(val, w)?;
}
Ok(total)
}
Value::Vector(v) => {
w.write_all(&[TAG_VECTOR])?;
w.write_all(&(v.len() as u64).to_le_bytes())?;
for &f in v.iter() {
w.write_all(&f.to_le_bytes())?;
}
Ok(1 + 8 + v.len() * 4)
}
Value::Date(d) => {
w.write_all(&[TAG_DATE])?;
w.write_all(&d.as_days().to_le_bytes())?;
Ok(5)
}
Value::Time(t) => {
w.write_all(&[TAG_TIME])?;
w.write_all(&t.as_nanos().to_le_bytes())?;
let offset = t.offset_seconds().unwrap_or(i32::MIN);
w.write_all(&offset.to_le_bytes())?;
Ok(13)
}
Value::Duration(d) => {
w.write_all(&[TAG_DURATION])?;
w.write_all(&d.months().to_le_bytes())?;
w.write_all(&d.days().to_le_bytes())?;
w.write_all(&d.nanos().to_le_bytes())?;
Ok(25)
}
Value::ZonedDatetime(zdt) => {
w.write_all(&[TAG_ZONED_DATETIME])?;
w.write_all(&zdt.as_timestamp().as_micros().to_le_bytes())?;
w.write_all(&zdt.offset_seconds().to_le_bytes())?;
Ok(13)
}
Value::Path { nodes, edges } => {
w.write_all(&[TAG_PATH])?;
w.write_all(&(nodes.len() as u64).to_le_bytes())?;
let mut total = 1 + 8;
for node in nodes.iter() {
total += serialize_value(node, w)?;
}
w.write_all(&(edges.len() as u64).to_le_bytes())?;
total += 8;
for edge in edges.iter() {
total += serialize_value(edge, w)?;
}
Ok(total)
}
Value::GCounter(counts) => {
w.write_all(&[TAG_GCOUNTER])?;
w.write_all(&(counts.len() as u64).to_le_bytes())?;
let mut total = 1 + 8;
for (k, v) in counts.iter() {
let key_bytes = k.as_bytes();
w.write_all(&(key_bytes.len() as u64).to_le_bytes())?;
w.write_all(key_bytes)?;
w.write_all(&v.to_le_bytes())?;
total += 8 + key_bytes.len() + 8;
}
Ok(total)
}
Value::OnCounter { pos, neg } => {
w.write_all(&[TAG_PNCOUNTER])?;
let mut total = 1;
for map in [pos.as_ref(), neg.as_ref()] {
w.write_all(&(map.len() as u64).to_le_bytes())?;
total += 8;
for (k, v) in map {
let key_bytes = k.as_bytes();
w.write_all(&(key_bytes.len() as u64).to_le_bytes())?;
w.write_all(key_bytes)?;
w.write_all(&v.to_le_bytes())?;
total += 8 + key_bytes.len() + 8;
}
}
Ok(total)
}
_ => {
w.write_all(&[TAG_NULL])?;
Ok(1)
}
}
}
pub fn deserialize_value<R: Read + ?Sized>(r: &mut R) -> std::io::Result<Value> {
let mut tag = [0u8; 1];
r.read_exact(&mut tag)?;
match tag[0] {
TAG_NULL => Ok(Value::Null),
TAG_BOOL => {
let mut buf = [0u8; 1];
r.read_exact(&mut buf)?;
Ok(Value::Bool(buf[0] != 0))
}
TAG_INT64 => {
let mut buf = [0u8; 8];
r.read_exact(&mut buf)?;
Ok(Value::Int64(i64::from_le_bytes(buf)))
}
TAG_FLOAT64 => {
let mut buf = [0u8; 8];
r.read_exact(&mut buf)?;
Ok(Value::Float64(f64::from_le_bytes(buf)))
}
TAG_STRING => {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let len = u64::from_le_bytes(len_buf) as usize;
let mut str_buf = vec![0u8; len];
r.read_exact(&mut str_buf)?;
let s = String::from_utf8(str_buf)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
Ok(Value::String(ArcStr::from(s)))
}
TAG_BYTES => {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let len = u64::from_le_bytes(len_buf) as usize;
let mut bytes_buf = vec![0u8; len];
r.read_exact(&mut bytes_buf)?;
Ok(Value::Bytes(Arc::from(bytes_buf)))
}
TAG_TIMESTAMP => {
let mut buf = [0u8; 8];
r.read_exact(&mut buf)?;
let micros = i64::from_le_bytes(buf);
Ok(Value::Timestamp(
grafeo_common::types::Timestamp::from_micros(micros),
))
}
TAG_LIST => {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let len = u64::from_le_bytes(len_buf) as usize;
let mut items = Vec::with_capacity(len);
for _ in 0..len {
items.push(deserialize_value(r)?);
}
Ok(Value::List(Arc::from(items)))
}
TAG_MAP => {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let len = u64::from_le_bytes(len_buf) as usize;
let mut map = BTreeMap::new();
for _ in 0..len {
let mut key_len_buf = [0u8; 8];
r.read_exact(&mut key_len_buf)?;
let key_len = u64::from_le_bytes(key_len_buf) as usize;
let mut key_buf = vec![0u8; key_len];
r.read_exact(&mut key_buf)?;
let key_str = String::from_utf8(key_buf).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
})?;
let val = deserialize_value(r)?;
map.insert(grafeo_common::types::PropertyKey::new(key_str), val);
}
Ok(Value::Map(Arc::new(map)))
}
TAG_VECTOR => {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let len = u64::from_le_bytes(len_buf) as usize;
let mut floats = Vec::with_capacity(len);
let mut buf = [0u8; 4];
for _ in 0..len {
r.read_exact(&mut buf)?;
floats.push(f32::from_le_bytes(buf));
}
Ok(Value::Vector(Arc::from(floats)))
}
TAG_DATE => {
let mut buf = [0u8; 4];
r.read_exact(&mut buf)?;
Ok(Value::Date(grafeo_common::types::Date::from_days(
i32::from_le_bytes(buf),
)))
}
TAG_TIME => {
let mut nanos_buf = [0u8; 8];
r.read_exact(&mut nanos_buf)?;
let nanos = u64::from_le_bytes(nanos_buf);
let mut offset_buf = [0u8; 4];
r.read_exact(&mut offset_buf)?;
let offset = i32::from_le_bytes(offset_buf);
let time = grafeo_common::types::Time::from_nanos(nanos).ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid time nanos")
})?;
if offset == i32::MIN {
Ok(Value::Time(time))
} else {
Ok(Value::Time(time.with_offset(offset)))
}
}
TAG_DURATION => {
let mut buf = [0u8; 8];
r.read_exact(&mut buf)?;
let months = i64::from_le_bytes(buf);
r.read_exact(&mut buf)?;
let days = i64::from_le_bytes(buf);
r.read_exact(&mut buf)?;
let nanos = i64::from_le_bytes(buf);
Ok(Value::Duration(grafeo_common::types::Duration::new(
months, days, nanos,
)))
}
TAG_ZONED_DATETIME => {
let mut micros_buf = [0u8; 8];
r.read_exact(&mut micros_buf)?;
let micros = i64::from_le_bytes(micros_buf);
let mut offset_buf = [0u8; 4];
r.read_exact(&mut offset_buf)?;
let offset = i32::from_le_bytes(offset_buf);
Ok(Value::ZonedDatetime(
grafeo_common::types::ZonedDatetime::from_timestamp_offset(
grafeo_common::types::Timestamp::from_micros(micros),
offset,
),
))
}
TAG_PATH => {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let nodes_len = u64::from_le_bytes(len_buf) as usize;
let mut nodes = Vec::with_capacity(nodes_len);
for _ in 0..nodes_len {
nodes.push(deserialize_value(r)?);
}
r.read_exact(&mut len_buf)?;
let edges_len = u64::from_le_bytes(len_buf) as usize;
let mut edges = Vec::with_capacity(edges_len);
for _ in 0..edges_len {
edges.push(deserialize_value(r)?);
}
Ok(Value::Path {
nodes: Arc::from(nodes),
edges: Arc::from(edges),
})
}
TAG_GCOUNTER => {
let mut u64_buf = [0u8; 8];
r.read_exact(&mut u64_buf)?;
let count = u64::from_le_bytes(u64_buf) as usize;
let mut map = std::collections::HashMap::with_capacity(count);
for _ in 0..count {
r.read_exact(&mut u64_buf)?;
let key_len = u64::from_le_bytes(u64_buf) as usize;
let mut key_buf = vec![0u8; key_len];
r.read_exact(&mut key_buf)?;
let key = String::from_utf8(key_buf).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
})?;
r.read_exact(&mut u64_buf)?;
let value = u64::from_le_bytes(u64_buf);
map.insert(key, value);
}
Ok(Value::GCounter(Arc::new(map)))
}
TAG_PNCOUNTER => {
let mut u64_buf = [0u8; 8];
let mut maps = [
std::collections::HashMap::new(),
std::collections::HashMap::new(),
];
for map in &mut maps {
r.read_exact(&mut u64_buf)?;
let count = u64::from_le_bytes(u64_buf) as usize;
map.reserve(count);
for _ in 0..count {
r.read_exact(&mut u64_buf)?;
let key_len = u64::from_le_bytes(u64_buf) as usize;
let mut key_buf = vec![0u8; key_len];
r.read_exact(&mut key_buf)?;
let key = String::from_utf8(key_buf).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
})?;
r.read_exact(&mut u64_buf)?;
let value = u64::from_le_bytes(u64_buf);
map.insert(key, value);
}
}
let [pos, neg] = maps;
Ok(Value::OnCounter {
pos: Arc::new(pos),
neg: Arc::new(neg),
})
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Unknown value tag: {}", tag[0]),
)),
}
}
pub fn serialize_row<W: Write + ?Sized>(row: &[Value], w: &mut W) -> std::io::Result<usize> {
w.write_all(&(row.len() as u64).to_le_bytes())?;
let mut total = 8;
for value in row {
total += serialize_value(value, w)?;
}
Ok(total)
}
pub fn deserialize_row<R: Read + ?Sized>(
r: &mut R,
expected_columns: usize,
) -> std::io::Result<Vec<Value>> {
let mut len_buf = [0u8; 8];
r.read_exact(&mut len_buf)?;
let num_columns = u64::from_le_bytes(len_buf) as usize;
if expected_columns > 0 && num_columns != expected_columns {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Column count mismatch: expected {}, got {}",
expected_columns, num_columns
),
));
}
let mut row = Vec::with_capacity(num_columns);
for _ in 0..num_columns {
row.push(deserialize_value(r)?);
}
Ok(row)
}
#[cfg(test)]
mod tests {
use super::*;
use arcstr::ArcStr;
use std::io::Cursor;
fn roundtrip_value(value: Value) -> Value {
let mut buf = Vec::new();
serialize_value(&value, &mut buf).unwrap();
let mut cursor = Cursor::new(buf);
deserialize_value(&mut cursor).unwrap()
}
#[test]
fn test_serialize_null() {
let result = roundtrip_value(Value::Null);
assert_eq!(result, Value::Null);
}
#[test]
fn test_serialize_bool() {
assert_eq!(roundtrip_value(Value::Bool(true)), Value::Bool(true));
assert_eq!(roundtrip_value(Value::Bool(false)), Value::Bool(false));
}
#[test]
fn test_serialize_int64() {
assert_eq!(roundtrip_value(Value::Int64(0)), Value::Int64(0));
assert_eq!(
roundtrip_value(Value::Int64(i64::MAX)),
Value::Int64(i64::MAX)
);
assert_eq!(
roundtrip_value(Value::Int64(i64::MIN)),
Value::Int64(i64::MIN)
);
assert_eq!(roundtrip_value(Value::Int64(-42)), Value::Int64(-42));
}
#[test]
fn test_serialize_float64() {
assert_eq!(roundtrip_value(Value::Float64(0.0)), Value::Float64(0.0));
assert_eq!(
roundtrip_value(Value::Float64(std::f64::consts::PI)),
Value::Float64(std::f64::consts::PI)
);
let nan_result = roundtrip_value(Value::Float64(f64::NAN));
assert!(matches!(nan_result, Value::Float64(f) if f.is_nan()));
}
#[test]
fn test_serialize_string() {
let result = roundtrip_value(Value::String(ArcStr::from("hello world")));
assert_eq!(result.as_str(), Some("hello world"));
let result = roundtrip_value(Value::String(ArcStr::from("")));
assert_eq!(result.as_str(), Some(""));
let result = roundtrip_value(Value::String(ArcStr::from("héllo 世界 🌍")));
assert_eq!(result.as_str(), Some("héllo 世界 🌍"));
}
#[test]
fn test_serialize_bytes() {
let data = vec![0u8, 1, 2, 255, 128];
let result = roundtrip_value(Value::Bytes(Arc::from(data.clone())));
assert_eq!(result.as_bytes(), Some(&data[..]));
let result = roundtrip_value(Value::Bytes(Arc::from(vec![])));
assert_eq!(result.as_bytes(), Some(&[][..]));
}
#[test]
fn test_serialize_timestamp() {
let ts = grafeo_common::types::Timestamp::from_micros(1234567890);
let result = roundtrip_value(Value::Timestamp(ts));
assert_eq!(result.as_timestamp(), Some(ts));
}
#[test]
fn test_serialize_list() {
let list = Value::List(Arc::from(vec![
Value::Int64(1),
Value::String(ArcStr::from("two")),
Value::Bool(true),
]));
let result = roundtrip_value(list.clone());
assert_eq!(result, list);
let nested = Value::List(Arc::from(vec![
Value::List(Arc::from(vec![Value::Int64(1), Value::Int64(2)])),
Value::List(Arc::from(vec![Value::Int64(3)])),
]));
let result = roundtrip_value(nested.clone());
assert_eq!(result, nested);
let empty = Value::List(Arc::from(vec![]));
let result = roundtrip_value(empty.clone());
assert_eq!(result, empty);
}
#[test]
fn test_serialize_map() {
let mut map = BTreeMap::new();
map.insert(
grafeo_common::types::PropertyKey::new("name"),
Value::String(ArcStr::from("Alix")),
);
map.insert(
grafeo_common::types::PropertyKey::new("age"),
Value::Int64(30),
);
let value = Value::Map(Arc::new(map));
let result = roundtrip_value(value.clone());
assert_eq!(result, value);
}
#[test]
fn test_serialize_row() {
let row = vec![
Value::Int64(1),
Value::String(ArcStr::from("test")),
Value::Bool(true),
Value::Null,
];
let mut buf = Vec::new();
serialize_row(&row, &mut buf).unwrap();
let mut cursor = Cursor::new(buf);
let result = deserialize_row(&mut cursor, 4).unwrap();
assert_eq!(result, row);
}
#[test]
fn test_serialize_row_column_count_check() {
let row = vec![Value::Int64(1), Value::Int64(2)];
let mut buf = Vec::new();
serialize_row(&row, &mut buf).unwrap();
let mut cursor = Cursor::new(buf.clone());
let result = deserialize_row(&mut cursor, 3);
assert!(result.is_err());
let mut cursor = Cursor::new(buf);
let result = deserialize_row(&mut cursor, 0).unwrap();
assert_eq!(result.len(), 2);
}
#[test]
fn test_serialize_multiple_rows() {
let rows = vec![
vec![Value::Int64(1), Value::String(ArcStr::from("a"))],
vec![Value::Int64(2), Value::String(ArcStr::from("b"))],
vec![Value::Int64(3), Value::String(ArcStr::from("c"))],
];
let mut buf = Vec::new();
for row in &rows {
serialize_row(row, &mut buf).unwrap();
}
let mut cursor = Cursor::new(buf);
for expected in &rows {
let result = deserialize_row(&mut cursor, 2).unwrap();
assert_eq!(&result, expected);
}
}
#[test]
fn test_serialize_gcounter_roundtrip() {
let mut counts = std::collections::HashMap::new();
counts.insert("replica-1".to_string(), 42u64);
counts.insert("replica-2".to_string(), 17u64);
let v = Value::GCounter(Arc::new(counts));
let result = roundtrip_value(v.clone());
assert_eq!(result, v);
}
#[test]
fn test_serialize_gcounter_empty() {
let v = Value::GCounter(Arc::new(std::collections::HashMap::new()));
let result = roundtrip_value(v.clone());
assert_eq!(result, v);
}
#[test]
fn test_serialize_pncounter_roundtrip() {
let mut pos = std::collections::HashMap::new();
pos.insert("node-a".to_string(), 10u64);
pos.insert("node-b".to_string(), 5u64);
let mut neg = std::collections::HashMap::new();
neg.insert("node-a".to_string(), 3u64);
let v = Value::OnCounter {
pos: Arc::new(pos),
neg: Arc::new(neg),
};
let result = roundtrip_value(v.clone());
assert_eq!(result, v);
}
#[test]
fn test_serialization_size() {
let mut buf = Vec::new();
serialize_value(&Value::Null, &mut buf).unwrap();
assert_eq!(buf.len(), 1);
buf.clear();
serialize_value(&Value::Bool(true), &mut buf).unwrap();
assert_eq!(buf.len(), 2);
buf.clear();
serialize_value(&Value::Int64(42), &mut buf).unwrap();
assert_eq!(buf.len(), 9);
buf.clear();
serialize_value(&Value::String(ArcStr::from("hi")), &mut buf).unwrap();
assert_eq!(buf.len(), 11);
}
}