use crate::error::{PersistError, Result};
/// A type that is identified by a primary key and carries static type-level
/// metadata (a name and a version number).
pub trait Entity: Sized {
/// The key type for this entity; it must implement [`PrimaryKey`].
type PrimaryKey: PrimaryKey;
/// Borrows this instance's primary key.
fn primary_key(&self) -> &Self::PrimaryKey;
/// Returns a static name for the entity type.
///
/// NOTE(review): presumably used to name the backing store for this
/// entity type — confirm against callers.
fn entity_name() -> &'static str;
/// Returns a version number for the entity type. The default is `0`.
///
/// NOTE(review): looks like a schema/format version for stored records —
/// confirm against callers before relying on that meaning.
fn class_version() -> u16 {
0
}
}
/// A key type that can be converted to and from raw bytes.
///
/// Two encodings are exposed:
///
/// * `to_bytes` / `from_bytes` — the decoder treats the whole input slice as
///   one encoded key.
/// * `to_sortable_bytes` / `from_sortable_bytes` — the decoder additionally
///   reports how many input bytes it consumed, so a caller can keep decoding
///   whatever follows the key in the same buffer.
pub trait PrimaryKey: Clone + Eq + std::hash::Hash {
    /// Encodes the key as bytes.
    fn to_bytes(&self) -> Vec<u8>;

    /// Decodes a key from `bytes`, where the entire slice is the encoded key.
    ///
    /// # Errors
    ///
    /// Implementations return an error when `bytes` is not a valid encoding.
    fn from_bytes(bytes: &[u8]) -> Result<Self>;

    /// Encodes the key in its sortable form.
    ///
    /// The default simply reuses [`PrimaryKey::to_bytes`].
    fn to_sortable_bytes(&self) -> Vec<u8> {
        Self::to_bytes(self)
    }

    /// Decodes a key from the front of `bytes`, returning the key together
    /// with the number of bytes consumed.
    ///
    /// The default hands the *whole* slice to [`PrimaryKey::from_bytes`] and
    /// reports every byte as consumed, so implementors whose encoding can be
    /// followed by other data need to override it.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`PrimaryKey::from_bytes`].
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        let consumed = bytes.len();
        Self::from_bytes(bytes).map(|key| (key, consumed))
    }
}
/// Appends the sortable encoding of `data` to `out`.
///
/// Every `0x00` byte in `data` is written as the pair `0x00 0x01`, all other
/// bytes are copied unchanged, and the pair `0x00 0x00` is appended as a
/// terminator. Existing contents of `out` are left in place.
fn encode_sortable_byte_string(data: &[u8], out: &mut Vec<u8>) {
    const ESCAPED_NUL: [u8; 2] = [0x00, 0x01];
    const TERMINATOR: [u8; 2] = [0x00, 0x00];

    // Splitting on NUL yields the runs of non-NUL bytes; each boundary
    // between two runs corresponds to exactly one NUL in the input.
    for (index, run) in data.split(|&byte| byte == 0x00).enumerate() {
        if index > 0 {
            out.extend_from_slice(&ESCAPED_NUL);
        }
        out.extend_from_slice(run);
    }
    out.extend_from_slice(&TERMINATOR);
}
/// Decodes one sortable byte string from the front of `bytes`.
///
/// Reverses the escaping used by the sortable encoding: `0x00 0x01` becomes a
/// single `0x00`, and `0x00 0x00` ends the string. On success returns the
/// decoded bytes and the number of input bytes consumed, terminator included;
/// anything after the terminator is left untouched.
///
/// # Errors
///
/// Returns `PersistError::SerializationError` when a `0x00` is followed by
/// anything other than `0x00` or `0x01` (including end of input), or when the
/// input ends before a terminator is found.
fn decode_sortable_byte_string(bytes: &[u8]) -> Result<(Vec<u8>, usize)> {
    let mut decoded = Vec::new();
    let mut remaining = bytes;

    loop {
        match remaining {
            // Terminator: report everything up to and including it.
            [0x00, 0x00, ..] => {
                let consumed = bytes.len() - remaining.len() + 2;
                return Ok((decoded, consumed));
            }
            // Escaped NUL.
            [0x00, 0x01, tail @ ..] => {
                decoded.push(0x00);
                remaining = tail;
            }
            // A NUL followed by any other byte, or by nothing at all.
            [0x00, ..] => {
                return Err(PersistError::SerializationError(
                    "invalid escape sequence decoding sortable byte string".into(),
                ));
            }
            // Ordinary byte: copied through unchanged.
            [byte, tail @ ..] => {
                decoded.push(*byte);
                remaining = tail;
            }
            // Ran out of input without seeing a terminator.
            [] => {
                return Err(PersistError::SerializationError(
                    "unterminated sortable byte string".into(),
                ));
            }
        }
    }
}
/// `u64` keys are encoded as exactly 8 big-endian bytes.
impl PrimaryKey for u64 {
    /// Returns the 8-byte big-endian representation.
    fn to_bytes(&self) -> Vec<u8> {
        Vec::from(self.to_be_bytes())
    }

    /// Decodes a `u64` from exactly 8 big-endian bytes.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when `bytes.len() != 8`.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        // The slice-to-array conversion succeeds only for an exact length match.
        let raw = <[u8; 8] as std::convert::TryFrom<&[u8]>>::try_from(bytes).map_err(|_| {
            PersistError::SerializationError(format!(
                "expected 8 bytes for u64, got {}",
                bytes.len()
            ))
        })?;
        Ok(u64::from_be_bytes(raw))
    }

    /// Decodes a `u64` from the first 8 bytes of `bytes` and reports 8 bytes
    /// consumed; any trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when fewer than 8 bytes are
    /// available.
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        match bytes.get(..8) {
            Some(prefix) => Self::from_bytes(prefix).map(|value| (value, 8)),
            None => Err(PersistError::SerializationError(format!(
                "expected 8 bytes for u64, got {}",
                bytes.len()
            ))),
        }
    }
}
/// `i64` keys are encoded as 8 big-endian bytes with the sign bit flipped, so
/// that negative values produce byte strings that compare lower than those of
/// non-negative values.
impl PrimaryKey for i64 {
    /// Returns the sign-flipped, big-endian 8-byte encoding.
    fn to_bytes(&self) -> Vec<u8> {
        const SIGN_BIT: u64 = 0x8000_0000_0000_0000;
        let flipped = (*self as u64) ^ SIGN_BIT;
        Vec::from(flipped.to_be_bytes())
    }

    /// Decodes an `i64` from exactly 8 bytes produced by `to_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when `bytes.len() != 8`.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        const SIGN_BIT: u64 = 0x8000_0000_0000_0000;
        // The slice-to-array conversion succeeds only for an exact length match.
        let raw = <[u8; 8] as std::convert::TryFrom<&[u8]>>::try_from(bytes).map_err(|_| {
            PersistError::SerializationError(format!(
                "expected 8 bytes for i64, got {}",
                bytes.len()
            ))
        })?;
        // Undo the sign-bit flip applied by the encoder.
        Ok((u64::from_be_bytes(raw) ^ SIGN_BIT) as i64)
    }

    /// Decodes an `i64` from the first 8 bytes of `bytes` and reports 8 bytes
    /// consumed; any trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when fewer than 8 bytes are
    /// available.
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        match bytes.get(..8) {
            Some(prefix) => Self::from_bytes(prefix).map(|value| (value, 8)),
            None => Err(PersistError::SerializationError(format!(
                "expected 8 bytes for i64, got {}",
                bytes.len()
            ))),
        }
    }
}
/// `u32` keys are encoded as exactly 4 big-endian bytes.
impl PrimaryKey for u32 {
    /// Returns the 4-byte big-endian representation.
    fn to_bytes(&self) -> Vec<u8> {
        Vec::from(self.to_be_bytes())
    }

    /// Decodes a `u32` from exactly 4 big-endian bytes.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when `bytes.len() != 4`.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        // The slice-to-array conversion succeeds only for an exact length match.
        let raw = <[u8; 4] as std::convert::TryFrom<&[u8]>>::try_from(bytes).map_err(|_| {
            PersistError::SerializationError(format!(
                "expected 4 bytes for u32, got {}",
                bytes.len()
            ))
        })?;
        Ok(u32::from_be_bytes(raw))
    }

    /// Decodes a `u32` from the first 4 bytes of `bytes` and reports 4 bytes
    /// consumed; any trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when fewer than 4 bytes are
    /// available.
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        match bytes.get(..4) {
            Some(prefix) => Self::from_bytes(prefix).map(|value| (value, 4)),
            None => Err(PersistError::SerializationError(format!(
                "expected 4 bytes for u32, got {}",
                bytes.len()
            ))),
        }
    }
}
/// `i32` keys are encoded as 4 big-endian bytes with the sign bit flipped, so
/// that negative values produce byte strings that compare lower than those of
/// non-negative values.
impl PrimaryKey for i32 {
    /// Returns the sign-flipped, big-endian 4-byte encoding.
    fn to_bytes(&self) -> Vec<u8> {
        const SIGN_BIT: u32 = 0x8000_0000;
        let flipped = (*self as u32) ^ SIGN_BIT;
        Vec::from(flipped.to_be_bytes())
    }

    /// Decodes an `i32` from exactly 4 bytes produced by `to_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when `bytes.len() != 4`.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        const SIGN_BIT: u32 = 0x8000_0000;
        // The slice-to-array conversion succeeds only for an exact length match.
        let raw = <[u8; 4] as std::convert::TryFrom<&[u8]>>::try_from(bytes).map_err(|_| {
            PersistError::SerializationError(format!(
                "expected 4 bytes for i32, got {}",
                bytes.len()
            ))
        })?;
        // Undo the sign-bit flip applied by the encoder.
        Ok((u32::from_be_bytes(raw) ^ SIGN_BIT) as i32)
    }

    /// Decodes an `i32` from the first 4 bytes of `bytes` and reports 4 bytes
    /// consumed; any trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when fewer than 4 bytes are
    /// available.
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        match bytes.get(..4) {
            Some(prefix) => Self::from_bytes(prefix).map(|value| (value, 4)),
            None => Err(PersistError::SerializationError(format!(
                "expected 4 bytes for i32, got {}",
                bytes.len()
            ))),
        }
    }
}
/// `String` keys use their UTF-8 bytes as the plain encoding and the
/// escaped, terminated byte-string form as the sortable encoding.
impl PrimaryKey for String {
    /// Returns a copy of the string's UTF-8 bytes.
    fn to_bytes(&self) -> Vec<u8> {
        Vec::from(self.as_bytes())
    }

    /// Builds a `String` from `bytes`, which must be valid UTF-8.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::SerializationError` when `bytes` is not valid
    /// UTF-8.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        match std::str::from_utf8(bytes) {
            Ok(text) => Ok(String::from(text)),
            Err(e) => Err(PersistError::SerializationError(format!(
                "invalid UTF-8 for String key: {}",
                e
            ))),
        }
    }

    /// Encodes the UTF-8 bytes with `encode_sortable_byte_string`.
    fn to_sortable_bytes(&self) -> Vec<u8> {
        // Two extra bytes for the terminator; embedded NULs grow the buffer further.
        let mut encoded = Vec::with_capacity(self.len() + 2);
        encode_sortable_byte_string(self.as_bytes(), &mut encoded);
        encoded
    }

    /// Decodes one sortable string from the front of `bytes`, returning it
    /// with the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Propagates errors from `decode_sortable_byte_string`, and returns
    /// `PersistError::SerializationError` when the unescaped bytes are not
    /// valid UTF-8.
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        decode_sortable_byte_string(bytes).and_then(|(raw, consumed)| {
            match String::from_utf8(raw) {
                Ok(text) => Ok((text, consumed)),
                Err(e) => Err(PersistError::SerializationError(format!(
                    "invalid UTF-8 for String key: {}",
                    e
                ))),
            }
        })
    }
}
/// `Vec<u8>` keys use the raw bytes as the plain encoding and the escaped,
/// terminated byte-string form as the sortable encoding.
impl PrimaryKey for Vec<u8> {
    /// Returns a copy of the bytes.
    fn to_bytes(&self) -> Vec<u8> {
        self.to_vec()
    }

    /// Copies `bytes` into a new vector; every input is accepted.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        Ok(Vec::from(bytes))
    }

    /// Encodes the bytes with `encode_sortable_byte_string`.
    fn to_sortable_bytes(&self) -> Vec<u8> {
        // Two extra bytes for the terminator; embedded NULs grow the buffer further.
        let mut encoded = Vec::with_capacity(self.len() + 2);
        encode_sortable_byte_string(self.as_slice(), &mut encoded);
        encoded
    }

    /// Decodes one sortable byte string from the front of `bytes`, returning
    /// it with the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Propagates errors from `decode_sortable_byte_string`.
    fn from_sortable_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        decode_sortable_byte_string(bytes)
    }
}
/// Unit tests for the `Entity` / `PrimaryKey` traits and the built-in key
/// encodings.
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // Plain `to_bytes` / `from_bytes` round trips
    // ---------------------------------------------------------------------

    #[test]
    fn test_u64_round_trip() {
        let val: u64 = 42;
        let bytes = val.to_bytes();
        assert_eq!(bytes.len(), 8);
        let decoded = u64::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_u64_zero() {
        let val: u64 = 0;
        let bytes = val.to_bytes();
        let decoded = u64::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_u64_max() {
        let val: u64 = u64::MAX;
        let bytes = val.to_bytes();
        let decoded = u64::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_u64_wrong_length() {
        let result = u64::from_bytes(&[1, 2, 3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_i64_round_trip() {
        let val: i64 = -42;
        let bytes = val.to_bytes();
        let decoded = i64::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_i64_negative() {
        let val: i64 = i64::MIN;
        let bytes = val.to_bytes();
        let decoded = i64::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_i64_wrong_length() {
        let result = i64::from_bytes(&[1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_u32_round_trip() {
        let val: u32 = 12345;
        let bytes = val.to_bytes();
        assert_eq!(bytes.len(), 4);
        let decoded = u32::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_u32_wrong_length() {
        let result = u32::from_bytes(&[1, 2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_i32_round_trip() {
        let val: i32 = -999;
        let bytes = val.to_bytes();
        let decoded = i32::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_i32_wrong_length() {
        let result = i32::from_bytes(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_string_round_trip() {
        let val = String::from("hello world");
        let bytes = val.to_bytes();
        let decoded = String::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_string_empty() {
        let val = String::from("");
        let bytes = val.to_bytes();
        let decoded = String::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_string_unicode() {
        let val = String::from("hello \u{1F600} world");
        let bytes = val.to_bytes();
        let decoded = String::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_string_invalid_utf8() {
        let result = String::from_bytes(&[0xFF, 0xFE]);
        assert!(result.is_err());
    }

    #[test]
    fn test_vec_u8_round_trip() {
        let val: Vec<u8> = vec![1, 2, 3, 4, 5];
        let bytes = val.to_bytes();
        let decoded = Vec::<u8>::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_vec_u8_empty() {
        let val: Vec<u8> = vec![];
        let bytes = val.to_bytes();
        let decoded = Vec::<u8>::from_bytes(&bytes).unwrap();
        assert_eq!(val, decoded);
    }

    // ---------------------------------------------------------------------
    // `Entity`
    // ---------------------------------------------------------------------

    /// Minimal entity used to exercise the `Entity` trait.
    #[derive(Clone, Debug, PartialEq)]
    struct TestEntity {
        id: u64,
        name: String,
    }

    impl Entity for TestEntity {
        type PrimaryKey = u64;

        fn primary_key(&self) -> &u64 {
            &self.id
        }

        fn entity_name() -> &'static str {
            "TestEntity"
        }
    }

    #[test]
    fn test_entity_primary_key() {
        let entity = TestEntity { id: 42, name: "test".to_string() };
        assert_eq!(*entity.primary_key(), 42);
        // Reading `name` keeps the fixture field from being reported as dead code.
        assert_eq!(entity.name, "test");
    }

    #[test]
    fn test_entity_name() {
        assert_eq!(TestEntity::entity_name(), "TestEntity");
    }

    #[test]
    fn test_entity_class_version_default() {
        // `TestEntity` does not override `class_version`, so the trait default applies.
        assert_eq!(TestEntity::class_version(), 0);
    }

    // ---------------------------------------------------------------------
    // Byte ordering of the integer encodings
    // ---------------------------------------------------------------------

    #[test]
    fn test_u64_byte_ordering() {
        let a: u64 = 1;
        let b: u64 = 256;
        let bytes_a = a.to_bytes();
        let bytes_b = b.to_bytes();
        assert!(bytes_a < bytes_b);
    }

    #[test]
    fn test_u32_byte_ordering() {
        let a: u32 = 100;
        let b: u32 = 200;
        let bytes_a = a.to_bytes();
        let bytes_b = b.to_bytes();
        assert!(bytes_a < bytes_b);
    }

    #[test]
    fn test_i32_min_sorts_before_max() {
        let bytes_min = i32::MIN.to_bytes();
        let bytes_max = i32::MAX.to_bytes();
        assert!(bytes_min < bytes_max, "i32::MIN should sort before i32::MAX");
    }

    #[test]
    fn test_i32_negative_one_sorts_before_zero() {
        let bytes_neg = (-1i32).to_bytes();
        let bytes_zero = 0i32.to_bytes();
        assert!(bytes_neg < bytes_zero, "-1 should sort before 0");
    }

    #[test]
    fn test_i32_sort_order_sequence() {
        let values: Vec<i32> = vec![i32::MIN, -1000, -1, 0, 1, 1000, i32::MAX];
        let encoded: Vec<Vec<u8>> = values.iter().map(|v| v.to_bytes()).collect();
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "i32 sort order: {} (encoded {:?}) should be < {} (encoded {:?})",
                values[i],
                encoded[i],
                values[i + 1],
                encoded[i + 1]
            );
        }
    }

    #[test]
    fn test_i32_round_trip_with_sort_encoding() {
        for val in [i32::MIN, -1, 0, 1, i32::MAX] {
            let bytes = val.to_bytes();
            let decoded = i32::from_bytes(&bytes).unwrap();
            assert_eq!(val, decoded, "i32 round-trip failed for {}", val);
        }
    }

    #[test]
    fn test_i64_min_sorts_before_max() {
        let bytes_min = i64::MIN.to_bytes();
        let bytes_max = i64::MAX.to_bytes();
        assert!(bytes_min < bytes_max, "i64::MIN should sort before i64::MAX");
    }

    #[test]
    fn test_i64_negative_one_sorts_before_zero() {
        let bytes_neg = (-1i64).to_bytes();
        let bytes_zero = 0i64.to_bytes();
        assert!(bytes_neg < bytes_zero, "-1i64 should sort before 0");
    }

    #[test]
    fn test_i64_sort_order_sequence() {
        let values: Vec<i64> = vec![i64::MIN, -1000, -1, 0, 1, 1000, i64::MAX];
        let encoded: Vec<Vec<u8>> = values.iter().map(|v| v.to_bytes()).collect();
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "i64 sort order: {} should be < {}",
                values[i],
                values[i + 1]
            );
        }
    }

    #[test]
    fn test_i64_round_trip_with_sort_encoding() {
        for val in [i64::MIN, -1, 0, 1, i64::MAX] {
            let bytes = val.to_bytes();
            let decoded = i64::from_bytes(&bytes).unwrap();
            assert_eq!(val, decoded, "i64 round-trip failed for {}", val);
        }
    }

    #[test]
    fn test_i32_encoding_known_values() {
        assert_eq!(i32::MIN.to_bytes(), vec![0x00, 0x00, 0x00, 0x00]);
        assert_eq!((-1i32).to_bytes(), vec![0x7f, 0xff, 0xff, 0xff]);
        assert_eq!(0i32.to_bytes(), vec![0x80, 0x00, 0x00, 0x00]);
        assert_eq!(i32::MAX.to_bytes(), vec![0xff, 0xff, 0xff, 0xff]);
    }

    #[test]
    fn test_i64_encoding_known_values() {
        assert_eq!(
            i64::MIN.to_bytes(),
            vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
        );
        assert_eq!(
            0i64.to_bytes(),
            vec![0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
        );
        assert_eq!(
            i64::MAX.to_bytes(),
            vec![0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]
        );
    }

    // ---------------------------------------------------------------------
    // Sortable encoding: round trips and ordering
    // ---------------------------------------------------------------------

    /// Encodes `v` in sortable form, decodes it again, and checks that both
    /// the value and the reported consumed length match.
    fn rt_one<T: PrimaryKey + std::fmt::Debug>(v: T) {
        let enc = v.to_sortable_bytes();
        let (dec, consumed) = T::from_sortable_bytes(&enc).unwrap();
        assert_eq!(v, dec, "round-trip value mismatch");
        assert_eq!(consumed, enc.len(), "consumed != encoded length");
    }

    #[test]
    fn sortable_round_trip_each_type() {
        rt_one(0u64);
        rt_one(u64::MAX);
        rt_one(-1i64);
        rt_one(i64::MIN);
        rt_one(0u32);
        rt_one(u32::MAX);
        rt_one(-7i32);
        rt_one(i32::MIN);
        rt_one(String::from("hello"));
        rt_one(String::new());
        rt_one(vec![1u8, 2, 3]);
        rt_one(Vec::<u8>::new());
    }

    #[test]
    fn sortable_round_trip_embedded_nulls() {
        let v = vec![0u8, 1, 0, 0, 2, 255];
        rt_one(v);
        rt_one(String::from("a\0b\0"));
    }

    #[test]
    fn sortable_sequential_decode_two_fields() {
        let s = String::from("region");
        let n = 42u32;
        let mut buf = s.to_sortable_bytes();
        buf.extend_from_slice(&n.to_sortable_bytes());
        let (s2, c0) = String::from_sortable_bytes(&buf).unwrap();
        let (n2, c1) = u32::from_sortable_bytes(&buf[c0..]).unwrap();
        assert_eq!(s, s2);
        assert_eq!(n, n2);
        assert_eq!(c0 + c1, buf.len());
    }

    #[test]
    fn sortable_string_order_preserving() {
        // Already in ascending logical order.
        let logical = vec!["", "a", "aa", "aaa", "ab", "b", "ba"];
        let mut encoded: Vec<(Vec<u8>, &str)> = logical
            .iter()
            .map(|s| (String::from(*s).to_sortable_bytes(), *s))
            .collect();
        encoded.sort();
        let by_bytes: Vec<&str> = encoded.iter().map(|(_, s)| *s).collect();
        assert_eq!(by_bytes, logical, "byte order must equal logical order");
    }

    #[test]
    fn sortable_bytes_order_with_null() {
        let a = vec![0u8].to_sortable_bytes();
        let empty = Vec::<u8>::new().to_sortable_bytes();
        let b = vec![1u8].to_sortable_bytes();
        assert!(empty < a);
        assert!(a < b);
    }

    // ---------------------------------------------------------------------
    // Sortable encoding: malformed input and trailing data
    // ---------------------------------------------------------------------

    #[test]
    fn sortable_unterminated_string_errors() {
        let r = String::from_sortable_bytes(b"abc");
        assert!(matches!(r, Err(PersistError::SerializationError(_))));
    }

    #[test]
    fn sortable_empty_input_errors() {
        let r = Vec::<u8>::from_sortable_bytes(&[]);
        assert!(matches!(r, Err(PersistError::SerializationError(_))));
    }

    #[test]
    fn sortable_invalid_escape_errors() {
        // 0x00 followed by a byte that is neither 0x00 nor 0x01.
        let r = String::from_sortable_bytes(&[b'a', 0x00, 0x02]);
        assert!(matches!(r, Err(PersistError::SerializationError(_))));

        // 0x00 as the very last byte of the input.
        let r = Vec::<u8>::from_sortable_bytes(&[b'a', 0x00]);
        assert!(matches!(r, Err(PersistError::SerializationError(_))));
    }

    #[test]
    fn sortable_string_invalid_utf8_errors() {
        // Well-formed byte-string framing around a payload that is not UTF-8.
        let r = String::from_sortable_bytes(&[0xFF, 0x00, 0x00]);
        assert!(matches!(r, Err(PersistError::SerializationError(_))));
    }

    #[test]
    fn sortable_bytes_stop_at_terminator() {
        let (decoded, consumed) = Vec::<u8>::from_sortable_bytes(&[1, 0, 0, 9]).unwrap();
        assert_eq!(decoded, vec![1u8]);
        assert_eq!(consumed, 3);
    }

    #[test]
    fn sortable_integer_short_input_errors() {
        assert!(matches!(
            u64::from_sortable_bytes(&[0u8; 7]),
            Err(PersistError::SerializationError(_))
        ));
        assert!(matches!(
            i64::from_sortable_bytes(&[0u8; 7]),
            Err(PersistError::SerializationError(_))
        ));
        assert!(matches!(
            u32::from_sortable_bytes(&[0u8; 3]),
            Err(PersistError::SerializationError(_))
        ));
        assert!(matches!(
            i32::from_sortable_bytes(&[0u8; 3]),
            Err(PersistError::SerializationError(_))
        ));
    }

    #[test]
    fn sortable_integer_ignores_trailing_bytes() {
        let (v, consumed) = u32::from_sortable_bytes(&[0, 0, 0, 42, 0xFF]).unwrap();
        assert_eq!(v, 42);
        assert_eq!(consumed, 4);

        let (v, consumed) = i32::from_sortable_bytes(&[0x80, 0, 0, 0, 1, 2]).unwrap();
        assert_eq!(v, 0);
        assert_eq!(consumed, 4);

        let (v, consumed) =
            u64::from_sortable_bytes(&[0, 0, 0, 0, 0, 0, 1, 2, 0xFF]).unwrap();
        assert_eq!(v, 258);
        assert_eq!(consumed, 8);

        let mut buf = (-7i64).to_sortable_bytes();
        buf.push(0xAB);
        let (v, consumed) = i64::from_sortable_bytes(&buf).unwrap();
        assert_eq!(v, -7);
        assert_eq!(consumed, 8);
    }
}