use std::marker::PhantomData;
use serde::Serialize;
use serde::de::DeserializeOwned;
use noxu_db::DatabaseEntry;
use crate::Result;
use crate::entry_binding::EntryBinding;
use crate::serial::simple_serial;
/// First byte of every entry written by [`SerdeBinding`].
///
/// Decoding rejects any entry whose first byte differs, so changing this value
/// makes every previously stored entry undecodable by this binding.
pub const SERDE_BINDING_MAGIC: u8 = 0xCB;
/// Second byte of every entry written by [`SerdeBinding`]: the format version.
///
/// Decoding accepts only this exact value. Bumping it causes older entries to be
/// rejected with `BindError::VersionMismatch` rather than mis-decoded.
pub const SERDE_BINDING_VERSION: u8 = 0x01;
/// Number of header bytes (magic + version) that precede the serialized payload.
pub const SERDE_BINDING_HEADER_LEN: usize = 2;
/// An [`EntryBinding`] that stores any `T: Serialize + DeserializeOwned` as a
/// two-byte header (magic, version) followed by the `simple_serial` encoding
/// of the value.
///
/// The binding is stateless and zero-sized. Construct it with
/// [`SerdeBinding::new`] or [`Default::default`].
pub struct SerdeBinding<T> {
    // The binding never owns a `T`; it only produces and consumes them.
    // `fn() -> T` keeps `SerdeBinding<T>` covariant in `T` while making it
    // `Send + Sync` regardless of `T`. Plain `PhantomData<T>` would make it
    // inherit `T`'s auto traits, e.g. `!Sync` for `T = Cell<_>`.
    _phantom: PhantomData<fn() -> T>,
}
impl<T> SerdeBinding<T> {
    /// Creates a new binding for values of type `T`.
    ///
    /// This is a `const fn`, so bindings can be built in `const`/`static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self { _phantom: PhantomData }
    }
}
impl<T> Default for SerdeBinding<T> {
    /// Equivalent to [`SerdeBinding::new`].
    fn default() -> Self {
        Self::new()
    }
}
// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// needless `T: Clone` bound even though no `T` is stored.
impl<T> Clone for SerdeBinding<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}
// Implemented by hand so that `T: Debug` is not required; the bound type is
// reported by name instead.
impl<T> std::fmt::Debug for SerdeBinding<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SerdeBinding")
            .field("type", &std::any::type_name::<T>())
            .finish()
    }
}
impl<T: Serialize + DeserializeOwned> EntryBinding<T> for SerdeBinding<T> {
    /// Decodes a value from `entry`.
    ///
    /// The entry must begin with [`SERDE_BINDING_MAGIC`] followed by
    /// [`SERDE_BINDING_VERSION`]. The remaining bytes are handed to
    /// `simple_serial::from_bytes`.
    ///
    /// # Errors
    ///
    /// - Returns `BindError::VersionMismatch` when the entry is shorter than the
    ///   header or its first two bytes are not the expected magic/version.
    ///   Bytes missing from a short entry are reported as `0`.
    /// - Any error from `simple_serial::from_bytes` is returned unchanged.
    fn entry_to_object(&self, entry: &DatabaseEntry) -> Result<T> {
        let raw: &[u8] = entry.data();
        match raw {
            // Header present and recognised: decode the body that follows it.
            [SERDE_BINDING_MAGIC, SERDE_BINDING_VERSION, payload @ ..] => {
                simple_serial::from_bytes(payload)
            }
            // Too short, or wrong magic/version: report what was actually found.
            other => Err(crate::BindError::VersionMismatch {
                expected_magic: SERDE_BINDING_MAGIC,
                expected_version: SERDE_BINDING_VERSION,
                found_magic: other.first().copied().unwrap_or(0),
                found_version: other.get(1).copied().unwrap_or(0),
            }),
        }
    }

    /// Encodes `object` into `entry` as `[magic, version, payload...]`.
    ///
    /// `entry` is only overwritten after serialization succeeds. If
    /// `simple_serial::to_bytes` fails, its error is returned and `entry` is
    /// left untouched.
    fn object_to_entry(&self, object: &T, entry: &mut DatabaseEntry) -> Result<()> {
        let payload = simple_serial::to_bytes(object)?;
        let header = [SERDE_BINDING_MAGIC, SERDE_BINDING_VERSION];
        let mut framed = Vec::with_capacity(SERDE_BINDING_HEADER_LEN + payload.len());
        framed.extend_from_slice(&header);
        framed.extend_from_slice(&payload);
        entry.set_data_vec(framed);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Round-trip, header-format and error-path tests for `SerdeBinding`.
    use super::*;
    use serde::{Deserialize, Serialize};

    #[test]
    fn test_u32_round_trip() {
        let binding = SerdeBinding::<u32>::new();
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&42u32, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), 42u32);
    }

    #[test]
    fn test_string_round_trip() {
        let binding = SerdeBinding::<String>::new();
        let mut entry = DatabaseEntry::new();
        let s = "hello world".to_string();
        binding.object_to_entry(&s, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), s);
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestRecord {
        id: u64,
        name: String,
        active: bool,
    }

    #[test]
    fn test_struct_round_trip() {
        let binding = SerdeBinding::<TestRecord>::new();
        let record =
            TestRecord { id: 12345, name: "test".to_string(), active: true };
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&record, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), record);
    }

    #[test]
    fn test_vec_round_trip() {
        let binding = SerdeBinding::<Vec<u32>>::new();
        let v = vec![1, 2, 3, 4, 5];
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&v, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), v);
    }

    #[test]
    fn test_option_round_trip() {
        let binding = SerdeBinding::<Option<String>>::new();
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&Some("yes".to_string()), &mut entry).unwrap();
        assert_eq!(
            binding.entry_to_object(&entry).unwrap(),
            Some("yes".to_string())
        );
        binding.object_to_entry(&None, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), None);
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    enum Status {
        Active,
        Inactive,
        Pending(String),
    }

    #[test]
    fn test_enum_round_trip() {
        let binding = SerdeBinding::<Status>::new();
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&Status::Active, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), Status::Active);
        binding
            .object_to_entry(&Status::Pending("review".to_string()), &mut entry)
            .unwrap();
        assert_eq!(
            binding.entry_to_object(&entry).unwrap(),
            Status::Pending("review".to_string())
        );
    }

    #[test]
    fn test_default() {
        let binding = SerdeBinding::<u32>::default();
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&7u32, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), 7u32);
    }

    #[test]
    fn test_clone() {
        let binding = SerdeBinding::<u32>::new();
        let cloned = binding.clone();
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&99u32, &mut entry).unwrap();
        assert_eq!(cloned.entry_to_object(&entry).unwrap(), 99u32);
    }

    #[test]
    fn test_debug() {
        let binding = SerdeBinding::<u32>::new();
        let debug = format!("{:?}", binding);
        assert!(debug.contains("SerdeBinding"));
    }

    #[test]
    fn test_empty_entry_error() {
        let binding = SerdeBinding::<u32>::new();
        let entry = DatabaseEntry::new();
        assert!(binding.entry_to_object(&entry).is_err());
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Nested {
        inner: TestRecord,
        tags: Vec<String>,
    }

    #[test]
    fn test_nested_struct_round_trip() {
        let binding = SerdeBinding::<Nested>::new();
        let nested = Nested {
            inner: TestRecord {
                id: 1,
                name: "nested".to_string(),
                active: false,
            },
            tags: vec!["a".to_string(), "b".to_string()],
        };
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&nested, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), nested);
    }

    #[test]
    fn test_tuple_round_trip() {
        let binding = SerdeBinding::<(u32, String, bool)>::new();
        let val = (42u32, "hello".to_string(), true);
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&val, &mut entry).unwrap();
        assert_eq!(binding.entry_to_object(&entry).unwrap(), val);
    }

    #[test]
    fn test_entry_data_is_set() {
        let binding = SerdeBinding::<u32>::new();
        let mut entry = DatabaseEntry::new();
        assert!(entry.is_empty());
        binding.object_to_entry(&42u32, &mut entry).unwrap();
        assert!(!entry.is_empty());
        assert!(entry.get_data().is_some());
    }

    #[test]
    fn test_encoded_payload_starts_with_version_header() {
        let binding = SerdeBinding::<u32>::new();
        let mut entry = DatabaseEntry::new();
        binding.object_to_entry(&42u32, &mut entry).unwrap();
        let bytes = entry.get_data().unwrap();
        assert!(
            bytes.len() >= SERDE_BINDING_HEADER_LEN,
            "encoded entry must include the 2-byte header",
        );
        assert_eq!(bytes[0], SERDE_BINDING_MAGIC);
        assert_eq!(bytes[1], SERDE_BINDING_VERSION);
        assert_eq!(&bytes[2..], &[0, 0, 0, 42]);
    }

    #[test]
    fn test_decode_unprefixed_payload_returns_version_mismatch() {
        let entry = DatabaseEntry::from_bytes(&[0, 0, 0, 42]);
        let binding = SerdeBinding::<u32>::new();
        let err = binding
            .entry_to_object(&entry)
            .expect_err("unprefixed payload must fail to decode");
        match err {
            crate::BindError::VersionMismatch {
                expected_magic,
                expected_version,
                found_magic,
                found_version,
            } => {
                assert_eq!(expected_magic, SERDE_BINDING_MAGIC);
                assert_eq!(expected_version, SERDE_BINDING_VERSION);
                assert_eq!(found_magic, 0x00);
                assert_eq!(found_version, 0x00);
            }
            other => panic!("expected VersionMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_decode_short_payload_returns_version_mismatch() {
        let binding = SerdeBinding::<u32>::new();
        for short in &[&[][..], &[SERDE_BINDING_MAGIC][..]] {
            let entry = DatabaseEntry::from_bytes(short);
            let err = binding
                .entry_to_object(&entry)
                .expect_err("short payload must fail to decode");
            assert!(
                matches!(err, crate::BindError::VersionMismatch { .. }),
                "short payload (len={}) must fail with VersionMismatch, got {:?}",
                short.len(),
                err,
            );
        }
    }

    // A 1-byte entry has a readable magic but no version byte; the missing
    // byte must be reported as 0 rather than panicking on an index.
    #[test]
    fn test_decode_magic_only_reports_missing_version_as_zero() {
        let entry = DatabaseEntry::from_bytes(&[SERDE_BINDING_MAGIC]);
        let binding = SerdeBinding::<u32>::new();
        let err = binding
            .entry_to_object(&entry)
            .expect_err("magic-only payload must fail to decode");
        match err {
            crate::BindError::VersionMismatch {
                found_magic,
                found_version,
                ..
            } => {
                assert_eq!(found_magic, SERDE_BINDING_MAGIC);
                assert_eq!(found_version, 0x00);
            }
            other => panic!("expected VersionMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_decode_wrong_version_returns_version_mismatch() {
        let mut bytes = vec![SERDE_BINDING_MAGIC, 0xFF];
        bytes.extend_from_slice(&42u32.to_be_bytes());
        let entry = DatabaseEntry::from_bytes(&bytes);
        let binding = SerdeBinding::<u32>::new();
        let err = binding
            .entry_to_object(&entry)
            .expect_err("wrong-version payload must fail to decode");
        match err {
            crate::BindError::VersionMismatch {
                found_magic,
                found_version,
                ..
            } => {
                assert_eq!(found_magic, SERDE_BINDING_MAGIC);
                assert_eq!(found_version, 0xFF);
            }
            other => panic!("expected VersionMismatch, got {:?}", other),
        }
    }

    // Complements the wrong-version test: a correct version byte must not
    // mask a bad magic byte.
    #[test]
    fn test_decode_wrong_magic_returns_version_mismatch() {
        let mut bytes = vec![0x00, SERDE_BINDING_VERSION];
        bytes.extend_from_slice(&42u32.to_be_bytes());
        let entry = DatabaseEntry::from_bytes(&bytes);
        let binding = SerdeBinding::<u32>::new();
        let err = binding
            .entry_to_object(&entry)
            .expect_err("wrong-magic payload must fail to decode");
        match err {
            crate::BindError::VersionMismatch {
                found_magic,
                found_version,
                ..
            } => {
                assert_eq!(found_magic, 0x00);
                assert_eq!(found_version, SERDE_BINDING_VERSION);
            }
            other => panic!("expected VersionMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_version_mismatch_display() {
        let err = crate::BindError::VersionMismatch {
            expected_magic: 0xCB,
            expected_version: 0x01,
            found_magic: 0x00,
            found_version: 0x00,
        };
        let s = err.to_string();
        assert!(s.contains("0xCB"), "display must include expected magic: {s}");
        assert!(
            s.contains("0x01"),
            "display must include expected version: {s}"
        );
        assert!(
            s.contains("version mismatch"),
            "display must name the failure: {s}"
        );
    }
}