use crate::{MemError, MemResult};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use uuid::Uuid;
/// Four-byte magic prefix "RKME". NOTE(review): presumably intended for
/// non-WAL memory records; nothing in this module references it, hence the
/// `dead_code` allowance.
#[allow(dead_code)]
const MAGIC_BYTES: [u8; 4] = [b'R', b'K', b'M', b'E'];
/// Format version byte written into, and required when reading, WAL entry
/// headers.
const FORMAT_VERSION: u8 = 1;
/// Header length of 6 bytes, which matches 4 magic bytes + 1 version byte +
/// 1 flags byte. NOTE(review): unused in this module (the WAL code uses a
/// 10-byte header that also carries a 4-byte checksum), hence the
/// `dead_code` allowance.
#[allow(dead_code)]
const HEADER_SIZE: usize = 6;
/// Leading marker byte written by `serialize_entry_compressed` when the
/// payload that follows is compressed (0x00 marks an uncompressed payload).
const FLAG_COMPRESSED: u8 = 0x01;
/// WAL header flag bit: when set, the stored CRC32 of the payload is verified.
const FLAG_CHECKSUM: u8 = 0x02;
/// A stored memory item: text content, its embedding vector, and descriptive
/// metadata.
///
/// The persisted form is produced by `serialize_entry`, which converts to
/// `MemoryEntryCompact` and packs the embedding as little-endian bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier of this entry (`MemoryEntry::new` generates a random v4 UUID).
    pub id: Uuid,
    /// The text content.
    pub content: String,
    /// Embedding vector for this entry.
    ///
    /// Marked `#[serde(skip)]`: serializing a `MemoryEntry` directly through
    /// serde omits this field, and deserializing yields an empty `Vec`.
    /// Use `serialize_entry` / `deserialize_entry` to round-trip it.
    #[serde(skip)]
    pub embedding: Vec<f32>,
    /// Arbitrary string key/value metadata.
    pub metadata: std::collections::HashMap<String, String>,
    /// Creation time; `MemoryEntry::new` sets it from
    /// `chrono::Utc::now().timestamp()` (seconds since the Unix epoch).
    pub created_at: i64,
    /// Last-update time, if any (presumably the same unit as `created_at`).
    /// NOTE(review): `MemoryEntryCompact` has no counterpart field, so this
    /// value is lost across `serialize_entry` / `deserialize_entry`.
    pub updated_at: Option<i64>,
    /// Presumably the source document this entry was chunked from.
    pub document_id: Option<Uuid>,
    /// Presumably this chunk's position within `document_id`.
    pub chunk_index: Option<usize>,
    /// Free-form tags.
    pub tags: Vec<String>,
}
impl MemoryEntry {
    /// Creates an entry identified by a freshly generated random (v4) UUID.
    ///
    /// `created_at` is stamped with the current UTC time in seconds since the
    /// Unix epoch; metadata and tags start empty and every optional field
    /// starts as `None`.
    pub fn new(content: String, embedding: Vec<f32>) -> Self {
        Self::with_id(Uuid::new_v4(), content, embedding)
    }

    /// Creates an entry with a caller-chosen `id`; every other field is
    /// initialised exactly as in [`MemoryEntry::new`].
    pub fn with_id(id: Uuid, content: String, embedding: Vec<f32>) -> Self {
        Self {
            id,
            content,
            embedding,
            metadata: std::collections::HashMap::new(),
            created_at: chrono::Utc::now().timestamp(),
            updated_at: None,
            document_id: None,
            chunk_index: None,
            tags: Vec::new(),
        }
    }
}
/// One write-ahead-log record; framed on disk by `serialize_wal_entry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalEntry {
    /// Sequence number of this record. NOTE(review): presumably monotonically
    /// increasing per log; nothing here enforces it.
    pub sequence: u64,
    /// Kind of operation this record describes.
    pub operation: WalOperation,
    /// Payload. NOTE(review): the type does not tie a `WalEntryData` variant
    /// to a particular `operation`; callers must keep them consistent.
    pub data: WalEntryData,
    /// Timestamp in nanoseconds (per the field name); the epoch and clock
    /// source are not established in this module.
    pub timestamp_ns: u64,
}
/// Kind of change recorded by a `WalEntry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WalOperation {
    /// Insert operation.
    Insert,
    /// Update operation.
    Update,
    /// Delete operation.
    Delete,
    /// Checkpoint marker.
    Checkpoint,
}
/// Payload carried by a `WalEntry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WalEntryData {
    /// A full entry in compact form (presumably for inserts/updates).
    Entry(MemoryEntryCompact),
    /// Only the id of the affected entry (presumably for deletes).
    EntryId(Uuid),
    /// Checkpoint payload; `last_sequence` is presumably the last WAL
    /// sequence number covered by the checkpoint.
    Checkpoint { last_sequence: u64 },
}
/// Compact, serde-serializable form of `MemoryEntry`, used by
/// `serialize_entry` and as the `WalEntryData::Entry` payload.
///
/// The embedding is held as raw little-endian `f32` bytes (produced by
/// `serialize_embedding`). There is no `updated_at` field: converting from a
/// `MemoryEntry` drops it, and `to_memory_entry` restores it as `None`.
///
/// NOTE(review): `serialize_msgpack` uses `rmp_serde::to_vec`, which per the
/// rmp_serde docs encodes structs positionally (as arrays), so field order is
/// presumably part of the on-disk/WAL format. Confirm before reordering or
/// inserting fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntryCompact {
    pub id: Uuid,
    pub content: String,
    /// Routed through `serde_bytes` so formats with a native byte-string type
    /// store it as one binary blob rather than a sequence of integers.
    #[serde(with = "serde_bytes")]
    pub embedding_bytes: Vec<u8>,
    pub metadata: std::collections::HashMap<String, String>,
    pub created_at: i64,
    pub document_id: Option<Uuid>,
    pub chunk_index: Option<usize>,
    pub tags: Vec<String>,
}
impl From<&MemoryEntry> for MemoryEntryCompact {
fn from(entry: &MemoryEntry) -> Self {
Self {
id: entry.id,
content: entry.content.clone(),
embedding_bytes: serialize_embedding(&entry.embedding),
metadata: entry.metadata.clone(),
created_at: entry.created_at,
document_id: entry.document_id,
chunk_index: entry.chunk_index,
tags: entry.tags.clone(),
}
}
}
impl MemoryEntryCompact {
    /// Expands this compact record back into a [`MemoryEntry`].
    ///
    /// The compact form does not carry `updated_at`, so the result always has
    /// `updated_at: None`.
    ///
    /// # Errors
    /// Returns `MemError::Serialization` if `embedding_bytes` is not a whole
    /// number of 4-byte `f32` values.
    pub fn to_memory_entry(&self) -> MemResult<MemoryEntry> {
        deserialize_embedding(&self.embedding_bytes).map(|embedding| MemoryEntry {
            id: self.id,
            content: self.content.clone(),
            embedding,
            metadata: self.metadata.clone(),
            created_at: self.created_at,
            updated_at: None,
            document_id: self.document_id,
            chunk_index: self.chunk_index,
            tags: self.tags.clone(),
        })
    }
}
/// Packs `embedding` into `4 * embedding.len()` bytes, each value written as
/// a little-endian IEEE-754 `f32`.
///
/// The output has no length prefix; see [`serialize_embedding_with_dim`] for a
/// self-describing variant.
#[inline]
pub fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(embedding.len() * 4);
    out.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    out
}
/// Decodes a buffer of little-endian IEEE-754 `f32` values, as produced by
/// [`serialize_embedding`].
///
/// # Errors
/// Returns `MemError::Serialization` if `bytes.len()` is not a multiple of 4.
#[inline]
pub fn deserialize_embedding(bytes: &[u8]) -> MemResult<Vec<f32>> {
    let words = bytes.chunks_exact(4);
    // A non-empty remainder means a trailing partial value.
    if !words.remainder().is_empty() {
        return Err(MemError::Serialization(format!(
            "Invalid embedding byte length: {} (must be divisible by 4)",
            bytes.len()
        )));
    }
    Ok(words
        .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
        .collect())
}
/// Packs `embedding` behind a 4-byte little-endian `u32` dimension prefix,
/// followed by each value as a little-endian `f32`.
///
/// NOTE: the length is narrowed with `as u32`, so an embedding longer than
/// `u32::MAX` values would get a truncated prefix.
#[inline]
pub fn serialize_embedding_with_dim(embedding: &[f32]) -> Vec<u8> {
    let header = (embedding.len() as u32).to_le_bytes();
    let mut out = Vec::with_capacity(header.len() + embedding.len() * 4);
    out.extend_from_slice(&header);
    for value in embedding {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}
/// Decodes an embedding produced by [`serialize_embedding_with_dim`]: a
/// little-endian `u32` dimension prefix followed by exactly `dim`
/// little-endian `f32` values.
///
/// When `expected_dim` is `Some`, the prefix must match it.
///
/// # Errors
/// Returns `MemError::Serialization` when:
/// - `bytes` is shorter than the 4-byte prefix;
/// - the prefix differs from `expected_dim`;
/// - the length implied by the prefix overflows `usize`;
/// - the buffer is not exactly `4 + 4 * dim` bytes long.
#[inline]
pub fn deserialize_embedding_with_dim(
    bytes: &[u8],
    expected_dim: Option<usize>,
) -> MemResult<Vec<f32>> {
    if bytes.len() < 4 {
        return Err(MemError::Serialization(
            "Invalid embedding bytes: too short for dimension prefix".to_string(),
        ));
    }
    let (prefix, values) = bytes.split_at(4);
    let dim = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
    if let Some(expected) = expected_dim {
        if dim != expected {
            return Err(MemError::Serialization(format!(
                "Embedding dimension mismatch: expected {}, got {}",
                expected, dim
            )));
        }
    }
    // Checked arithmetic: on 32-bit targets `4 + dim * 4` can wrap for a hostile
    // prefix. In release builds that would let a short buffer pass the
    // exact-length check below.
    let expected_len = dim
        .checked_mul(4)
        .and_then(|n| n.checked_add(4))
        .ok_or_else(|| {
            MemError::Serialization(format!(
                "Invalid embedding dimension: {} overflows the addressable length",
                dim
            ))
        })?;
    if bytes.len() != expected_len {
        return Err(MemError::Serialization(format!(
            "Invalid embedding byte length: expected {}, got {}",
            expected_len,
            bytes.len()
        )));
    }
    deserialize_embedding(values)
}
/// Encodes `entry` as MessagePack via its [`MemoryEntryCompact`] form (the
/// embedding packed as little-endian `f32` bytes).
///
/// `updated_at` is not part of the compact form and is not written.
///
/// # Errors
/// Returns `MemError::Serialization` if MessagePack encoding fails.
pub fn serialize_entry(entry: &MemoryEntry) -> MemResult<Vec<u8>> {
    serialize_msgpack(&MemoryEntryCompact::from(entry))
}
/// Decodes bytes written by [`serialize_entry`] back into a [`MemoryEntry`].
///
/// # Errors
/// Returns `MemError::Serialization` if the MessagePack is invalid or the
/// embedding bytes are not a whole number of `f32` values.
pub fn deserialize_entry(bytes: &[u8]) -> MemResult<MemoryEntry> {
    deserialize_msgpack::<MemoryEntryCompact>(bytes).and_then(|compact| compact.to_memory_entry())
}
/// Serializes `value` to MessagePack with `rmp_serde::to_vec`.
///
/// # Errors
/// Wraps any encoder failure in `MemError::Serialization`.
pub fn serialize_msgpack<T: Serialize>(value: &T) -> MemResult<Vec<u8>> {
    match rmp_serde::to_vec(value) {
        Ok(bytes) => Ok(bytes),
        Err(err) => Err(MemError::Serialization(format!(
            "MessagePack serialization failed: {}",
            err
        ))),
    }
}
/// Deserializes a `T` from MessagePack with `rmp_serde::from_slice`.
///
/// # Errors
/// Wraps any decoder failure in `MemError::Serialization`.
pub fn deserialize_msgpack<T: DeserializeOwned>(bytes: &[u8]) -> MemResult<T> {
    let decoded: Result<T, _> = rmp_serde::from_slice(bytes);
    decoded.map_err(|err| {
        MemError::Serialization(format!("MessagePack deserialization failed: {}", err))
    })
}
/// Frames `entry` as a WAL record with a 10-byte header.
///
/// Layout: `b"RKWE"` magic | `FORMAT_VERSION` | flags (`FLAG_CHECKSUM`) |
/// CRC32 of the payload as a little-endian `u32` | MessagePack payload.
///
/// # Errors
/// Returns `MemError::Serialization` if MessagePack encoding fails.
pub fn serialize_wal_entry(entry: &WalEntry) -> MemResult<Vec<u8>> {
    const WAL_MAGIC: [u8; 4] = [b'R', b'K', b'W', b'E'];
    let payload = serialize_msgpack(entry)?;
    let checksum = crc32_checksum(&payload).to_le_bytes();
    let parts: [&[u8]; 4] = [
        &WAL_MAGIC,
        &[FORMAT_VERSION, FLAG_CHECKSUM],
        &checksum,
        &payload,
    ];
    Ok(parts.concat())
}
/// Parses and validates a WAL record produced by [`serialize_wal_entry`].
///
/// Checks, in order: minimum length (10-byte header), the `RKWE` magic, the
/// format version, and the payload CRC32 when the `FLAG_CHECKSUM` bit is set
/// in the flags byte. The checksum slot is always present in the header; it
/// is ignored when the flag is clear.
///
/// # Errors
/// Returns `MemError::Serialization` for any failed check, or when the payload
/// is not valid MessagePack for a [`WalEntry`].
pub fn deserialize_wal_entry(bytes: &[u8]) -> MemResult<WalEntry> {
    const WAL_MAGIC: [u8; 4] = [b'R', b'K', b'W', b'E'];
    const MIN_SIZE: usize = 10;

    if bytes.len() < MIN_SIZE {
        return Err(MemError::Serialization(format!(
            "WAL entry too short: {} bytes (minimum {})",
            bytes.len(),
            MIN_SIZE
        )));
    }
    let (header, payload) = bytes.split_at(MIN_SIZE);

    if header[..4] != WAL_MAGIC {
        return Err(MemError::Serialization(
            "Invalid WAL entry: wrong magic bytes".to_string(),
        ));
    }

    let version = header[4];
    if version != FORMAT_VERSION {
        return Err(MemError::Serialization(format!(
            "Unsupported WAL format version: {} (expected {})",
            version, FORMAT_VERSION
        )));
    }

    if (header[5] & FLAG_CHECKSUM) != 0 {
        let stored = u32::from_le_bytes([header[6], header[7], header[8], header[9]]);
        let computed = crc32_checksum(payload);
        if stored != computed {
            return Err(MemError::Serialization(format!(
                "WAL entry checksum mismatch: stored {:08x}, computed {:08x}",
                stored, computed
            )));
        }
    }

    deserialize_msgpack(payload)
}
/// LZ4-compresses `data` with `lz4_flex::compress_prepend_size`, which
/// prefixes the output with the uncompressed length so [`decompress`] can
/// size its buffer.
///
/// This variant never fails; the `Result` keeps the signature identical to the
/// fallback used when the `compression` feature is disabled.
#[cfg(feature = "compression")]
pub fn compress(data: &[u8]) -> MemResult<Vec<u8>> {
    let framed = lz4_flex::compress_prepend_size(data);
    Ok(framed)
}
/// Reverses [`compress`]: decodes an LZ4 block whose uncompressed size is
/// prepended, using `lz4_flex::decompress_size_prepended`.
///
/// # Errors
/// Returns `MemError::Serialization` if the LZ4 data is malformed.
#[cfg(feature = "compression")]
pub fn decompress(data: &[u8]) -> MemResult<Vec<u8>> {
    match lz4_flex::decompress_size_prepended(data) {
        Ok(raw) => Ok(raw),
        Err(err) => Err(MemError::Serialization(format!(
            "LZ4 decompression failed: {}",
            err
        ))),
    }
}
/// Fallback used when the `compression` feature is disabled: performs no
/// compression and returns `data` behind a single `0x00` "stored" marker byte.
///
/// The output is always one byte longer than the input. Its framing differs
/// from the LZ4 variant, so data is only readable by a build with the same
/// feature setting.
#[cfg(not(feature = "compression"))]
pub fn compress(data: &[u8]) -> MemResult<Vec<u8>> {
    Ok([&[0x00u8][..], data].concat())
}
/// Fallback used when the `compression` feature is disabled: strips the
/// `0x00` "stored" marker written by the fallback [`compress`].
///
/// # Errors
/// Returns `MemError::Serialization` for empty input, or for any other leading
/// byte (assumed to be real compressed data this build cannot decode).
#[cfg(not(feature = "compression"))]
pub fn decompress(data: &[u8]) -> MemResult<Vec<u8>> {
    match data.split_first() {
        None => Err(MemError::Serialization(
            "Empty data cannot be decompressed".to_string(),
        )),
        Some((&0x00, stored)) => Ok(stored.to_vec()),
        Some(_) => Err(MemError::Serialization(
            "Compressed data requires 'compression' feature".to_string(),
        )),
    }
}
/// Serializes `entry` behind a one-byte encoding flag, compressing when
/// worthwhile.
///
/// Compression is attempted only when the raw encoding is at least
/// `compression_threshold` bytes. The compressed form is kept only if it is
/// strictly smaller. Output: `FLAG_COMPRESSED` + compressed bytes, or `0x00` +
/// raw [`serialize_entry`] bytes.
///
/// # Errors
/// Propagates failures from [`serialize_entry`] and [`compress`].
pub fn serialize_entry_compressed(
    entry: &MemoryEntry,
    compression_threshold: usize,
) -> MemResult<Vec<u8>> {
    let raw = serialize_entry(entry)?;

    let compressed = if raw.len() < compression_threshold {
        None
    } else {
        Some(compress(&raw)?).filter(|candidate| candidate.len() < raw.len())
    };

    let (flag, body) = match compressed {
        Some(packed) => (FLAG_COMPRESSED, packed),
        None => (0x00, raw),
    };

    let mut framed = Vec::with_capacity(1 + body.len());
    framed.push(flag);
    framed.extend(body);
    Ok(framed)
}
/// Decodes an entry produced by [`serialize_entry_compressed`].
///
/// The first byte selects the encoding: `0x00` means raw [`serialize_entry`]
/// bytes follow; `FLAG_COMPRESSED` means [`compress`]ed bytes follow.
///
/// # Errors
/// Returns `MemError::Serialization` when:
/// - `bytes` is empty;
/// - the leading byte is not a recognised encoding flag;
/// - the payload is compressed but this build lacks the `compression` feature;
/// - decompression or decoding fails.
pub fn deserialize_entry_compressed(bytes: &[u8]) -> MemResult<MemoryEntry> {
    let (flag, payload) = match bytes.split_first() {
        Some((&flag, payload)) => (flag, payload),
        None => return Err(MemError::Serialization("Empty data".to_string())),
    };

    if flag == 0x00 {
        // Stored uncompressed: decode directly from the borrowed slice, no copy.
        deserialize_entry(payload)
    } else if flag == FLAG_COMPRESSED {
        // The fallback `decompress` only understands its own 0x00-marker
        // framing. It would misread an LZ4 stream whose prepended size begins
        // with a zero byte, so fail clearly instead.
        if !cfg!(feature = "compression") {
            return Err(MemError::Serialization(
                "Compressed data requires 'compression' feature".to_string(),
            ));
        }
        let raw = decompress(payload)?;
        deserialize_entry(&raw)
    } else {
        // The writer only emits 0x00 or FLAG_COMPRESSED. Anything else is
        // corrupt or not produced by `serialize_entry_compressed`.
        Err(MemError::Serialization(format!(
            "Unknown entry encoding flag: 0x{:02x}",
            flag
        )))
    }
}
pub fn serialize_batch(entries: &[MemoryEntry]) -> MemResult<Vec<u8>> {
let count = entries.len() as u32;
let estimated_size = 4 + entries.len() * 1024;
let mut buffer = Vec::with_capacity(estimated_size);
buffer.extend_from_slice(&count.to_le_bytes());
for entry in entries {
let serialized = serialize_entry(entry)?;
let len = serialized.len() as u32;
buffer.extend_from_slice(&len.to_le_bytes());
buffer.extend(serialized);
}
Ok(buffer)
}
/// Parses a batch written by [`serialize_batch`].
///
/// Layout: little-endian `u32` entry count, then per entry a little-endian
/// `u32` byte length followed by that many [`serialize_entry`] bytes. Bytes
/// after the last counted entry are ignored.
///
/// # Errors
/// Returns `MemError::Serialization` if the header or any length prefix or
/// entry body is truncated. Propagates [`deserialize_entry`] failures.
pub fn deserialize_batch(bytes: &[u8]) -> MemResult<Vec<MemoryEntry>> {
    if bytes.len() < 4 {
        return Err(MemError::Serialization("Batch data too short".to_string()));
    }
    let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;

    // `count` comes from untrusted input. Each entry needs at least a 4-byte
    // length prefix, so cap the preallocation by what the buffer could hold;
    // otherwise a forged header could request a multi-gigabyte allocation.
    let max_possible = (bytes.len() - 4) / 4;
    let mut entries = Vec::with_capacity(count.min(max_possible));

    let mut offset = 4;
    for _ in 0..count {
        let remaining = &bytes[offset..];
        if remaining.len() < 4 {
            return Err(MemError::Serialization("Batch data truncated".to_string()));
        }
        let len = u32::from_le_bytes([remaining[0], remaining[1], remaining[2], remaining[3]])
            as usize;
        // Compare against what is left instead of computing `offset + len`,
        // which could overflow on 32-bit targets.
        if len > remaining.len() - 4 {
            return Err(MemError::Serialization("Batch entry truncated".to_string()));
        }
        let start = offset + 4;
        entries.push(deserialize_entry(&bytes[start..start + len])?);
        offset = start + len;
    }
    Ok(entries)
}
/// Computes the standard CRC-32 of `data`: reflected polynomial `0xEDB88320`,
/// initial value `0xFFFF_FFFF`, final XOR `0xFFFF_FFFF`.
///
/// Check value: `crc32_checksum(b"123456789") == 0xCBF43926`.
fn crc32_checksum(data: &[u8]) -> u32 {
    // `static` rather than `const`: a const array is inlined at each use site,
    // while a static is a single compile-time-built table in read-only memory.
    static CRC32_TABLE: [u32; 256] = generate_crc32_table();

    let crc = data.iter().fold(0xFFFF_FFFF_u32, |crc, &byte| {
        let index = ((crc ^ u32::from(byte)) & 0xFF) as usize;
        CRC32_TABLE[index] ^ (crc >> 8)
    });
    crc ^ 0xFFFF_FFFF
}

/// Builds the 256-entry lookup table for the reflected CRC-32 polynomial
/// `0xEDB88320`, one entry per possible low byte.
///
/// A `const fn` using `while` loops (a `for` loop is not allowed in const
/// context), so the table is evaluated at compile time.
const fn generate_crc32_table() -> [u32; 256] {
    const POLYNOMIAL: u32 = 0xEDB88320;
    let mut table = [0u32; 256];
    let mut i = 0;
    while i < 256 {
        let mut crc = i as u32;
        let mut j = 0;
        // Shift out 8 bits, folding in the polynomial whenever a 1 drops off.
        while j < 8 {
            if crc & 1 != 0 {
                crc = (crc >> 1) ^ POLYNOMIAL;
            } else {
                crc >>= 1;
            }
            j += 1;
        }
        table[i] = crc;
        i += 1;
    }
    table
}
/// Returns a heuristic estimate, in bytes, of an entry's serialized size.
///
/// Adds up the following; it is not the exact length of [`serialize_entry`]'s
/// output:
/// - a fixed 64-byte base (presumably covering ids, timestamps and framing);
/// - the content length;
/// - 4 bytes per embedding value;
/// - per metadata pair, key + value length + 8;
/// - per tag, its length + 4.
pub fn estimate_entry_size(entry: &MemoryEntry) -> usize {
    const BASE_BYTES: usize = 64;
    const PER_PAIR_OVERHEAD: usize = 8;
    const PER_TAG_OVERHEAD: usize = 4;

    let metadata_bytes = entry
        .metadata
        .iter()
        .fold(0, |acc, (key, value)| acc + key.len() + value.len() + PER_PAIR_OVERHEAD);
    let tag_bytes = entry
        .tags
        .iter()
        .fold(0, |acc, tag| acc + tag.len() + PER_TAG_OVERHEAD);

    BASE_BYTES + entry.content.len() + entry.embedding.len() * 4 + metadata_bytes + tag_bytes
}
#[cfg(test)]
mod tests {
    //! Unit tests for the embedding, entry, WAL, compression and batch codecs.

    use super::*;

    #[test]
    fn test_serialize_embedding_roundtrip() {
        let original: Vec<f32> = vec![1.0, -2.5, std::f32::consts::PI, 0.0, f32::MAX, f32::MIN];
        let bytes = serialize_embedding(&original);
        let recovered = deserialize_embedding(&bytes).unwrap();
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_serialize_embedding_empty() {
        let original: Vec<f32> = vec![];
        let bytes = serialize_embedding(&original);
        assert!(bytes.is_empty());
        let recovered = deserialize_embedding(&bytes).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn test_deserialize_embedding_invalid_length() {
        let bytes = vec![1, 2, 3];
        let result = deserialize_embedding(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_serialize_embedding_with_dim() {
        let original = vec![1.0, 2.0, 3.0];
        let bytes = serialize_embedding_with_dim(&original);
        assert_eq!(bytes.len(), 4 + 12);
        let recovered = deserialize_embedding_with_dim(&bytes, Some(3)).unwrap();
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_deserialize_embedding_dim_mismatch() {
        let original = vec![1.0, 2.0, 3.0];
        let bytes = serialize_embedding_with_dim(&original);
        let result = deserialize_embedding_with_dim(&bytes, Some(4));
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_embedding_with_dim_bad_length() {
        // Too short to hold the dimension prefix at all.
        assert!(deserialize_embedding_with_dim(&[1, 0, 0], None).is_err());
        // Prefix promises two values but the last byte is missing.
        let mut bytes = serialize_embedding_with_dim(&[1.0, 2.0]);
        bytes.pop();
        assert!(deserialize_embedding_with_dim(&bytes, None).is_err());
    }

    #[test]
    fn test_memory_entry_roundtrip() {
        let mut entry = MemoryEntry::new(
            "Test content for memory entry".to_string(),
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
        );
        entry
            .metadata
            .insert("key".to_string(), "value".to_string());
        entry.tags.push("test".to_string());
        entry.document_id = Some(Uuid::new_v4());
        entry.chunk_index = Some(3);
        let bytes = serialize_entry(&entry).unwrap();
        let recovered = deserialize_entry(&bytes).unwrap();
        assert_eq!(entry.id, recovered.id);
        assert_eq!(entry.content, recovered.content);
        assert_eq!(entry.embedding, recovered.embedding);
        assert_eq!(entry.metadata, recovered.metadata);
        assert_eq!(entry.tags, recovered.tags);
        assert_eq!(entry.created_at, recovered.created_at);
        assert_eq!(entry.document_id, recovered.document_id);
        assert_eq!(entry.chunk_index, recovered.chunk_index);
    }

    #[test]
    fn test_wal_entry_roundtrip() {
        let entry = WalEntry {
            sequence: 42,
            operation: WalOperation::Insert,
            data: WalEntryData::Entry(MemoryEntryCompact {
                id: Uuid::new_v4(),
                content: "WAL test".to_string(),
                embedding_bytes: serialize_embedding(&[1.0, 2.0]),
                metadata: std::collections::HashMap::new(),
                created_at: 1234567890,
                document_id: None,
                chunk_index: Some(0),
                tags: vec![],
            }),
            timestamp_ns: 1234567890123456789,
        };
        let bytes = serialize_wal_entry(&entry).unwrap();
        let recovered = deserialize_wal_entry(&bytes).unwrap();
        assert_eq!(entry.sequence, recovered.sequence);
        assert_eq!(entry.operation, recovered.operation);
        assert_eq!(entry.timestamp_ns, recovered.timestamp_ns);
        // The payload itself must survive the round trip too.
        match recovered.data {
            WalEntryData::Entry(compact) => {
                assert_eq!(compact.content, "WAL test");
                assert_eq!(compact.chunk_index, Some(0));
                assert_eq!(
                    deserialize_embedding(&compact.embedding_bytes).unwrap(),
                    vec![1.0, 2.0]
                );
            }
            other => panic!("unexpected WAL payload: {:?}", other),
        }
    }

    #[test]
    fn test_wal_entry_checksum_corruption() {
        let entry = WalEntry {
            sequence: 1,
            operation: WalOperation::Delete,
            data: WalEntryData::EntryId(Uuid::new_v4()),
            timestamp_ns: 0,
        };
        let mut bytes = serialize_wal_entry(&entry).unwrap();
        if let Some(last) = bytes.last_mut() {
            *last ^= 0xFF;
        }
        let result = deserialize_wal_entry(&bytes);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("checksum mismatch"));
    }

    #[test]
    fn test_wal_entry_invalid_magic() {
        let bytes = vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0];
        let result = deserialize_wal_entry(&bytes);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("wrong magic bytes"));
    }

    #[test]
    fn test_wal_entry_too_short() {
        let result = deserialize_wal_entry(&[0u8; 9]);
        assert!(result.unwrap_err().to_string().contains("too short"));
    }

    #[test]
    fn test_wal_entry_unsupported_version() {
        let bytes = vec![b'R', b'K', b'W', b'E', 99, 0, 0, 0, 0, 0];
        let result = deserialize_wal_entry(&bytes);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unsupported WAL format version"));
    }

    #[test]
    fn test_batch_serialization() {
        let entries: Vec<MemoryEntry> = (0..5)
            .map(|i| MemoryEntry::new(format!("Entry {}", i), vec![i as f32; 4]))
            .collect();
        let bytes = serialize_batch(&entries).unwrap();
        let recovered = deserialize_batch(&bytes).unwrap();
        assert_eq!(entries.len(), recovered.len());
        for (orig, rec) in entries.iter().zip(recovered.iter()) {
            assert_eq!(orig.content, rec.content);
            assert_eq!(orig.embedding, rec.embedding);
        }
    }

    #[test]
    fn test_batch_empty() {
        let bytes = serialize_batch(&[]).unwrap();
        assert_eq!(bytes, vec![0, 0, 0, 0]);
        assert!(deserialize_batch(&bytes).unwrap().is_empty());
    }

    #[test]
    fn test_batch_truncated() {
        let entries = vec![MemoryEntry::new("x".to_string(), vec![1.0; 2])];
        let mut bytes = serialize_batch(&entries).unwrap();
        bytes.pop();
        assert!(deserialize_batch(&bytes).is_err());
    }

    #[test]
    fn test_compression_roundtrip() {
        let entry = MemoryEntry::new("A".repeat(10000), vec![0.1; 1000]);
        let bytes = serialize_entry_compressed(&entry, 1000).unwrap();
        let recovered = deserialize_entry_compressed(&bytes).unwrap();
        assert_eq!(entry.content, recovered.content);
        assert_eq!(entry.embedding, recovered.embedding);
    }

    #[test]
    fn test_deserialize_entry_compressed_empty() {
        assert!(deserialize_entry_compressed(&[]).is_err());
    }

    #[test]
    fn test_estimate_entry_size() {
        let entry = MemoryEntry::new("Hello".to_string(), vec![1.0; 384]);
        let estimated = estimate_entry_size(&entry);
        let actual = serialize_entry(&entry).unwrap().len();
        assert!(estimated > actual / 2);
        assert!(estimated < actual * 3);
    }

    #[test]
    fn test_crc32_known_values() {
        let checksum = crc32_checksum(b"123456789");
        assert_eq!(checksum, 0xCBF43926);
        assert_eq!(crc32_checksum(b""), 0);
    }

    #[test]
    fn test_large_embedding() {
        let large_embedding: Vec<f32> = (0..1536).map(|i| i as f32 * 0.001).collect();
        let bytes = serialize_embedding(&large_embedding);
        assert_eq!(bytes.len(), 1536 * 4);
        let recovered = deserialize_embedding(&bytes).unwrap();
        assert_eq!(large_embedding, recovered);
    }

    #[test]
    fn test_special_float_values() {
        let special = vec![
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            0.0,
            -0.0,
            f32::EPSILON,
        ];
        let bytes = serialize_embedding(&special);
        let recovered = deserialize_embedding(&bytes).unwrap();
        assert!(recovered[0].is_infinite() && recovered[0].is_sign_positive());
        assert!(recovered[1].is_infinite() && recovered[1].is_sign_negative());
        assert!(recovered[2].is_nan());
        // `0.0 == -0.0`, so plain equality cannot detect a lost sign; check it.
        assert!(recovered[3] == 0.0 && recovered[3].is_sign_positive());
        assert!(recovered[4] == 0.0 && recovered[4].is_sign_negative());
        assert_eq!(recovered[5], f32::EPSILON);
        // The byte codec must be bit-exact for every value.
        for (before, after) in special.iter().zip(&recovered) {
            assert_eq!(before.to_bits(), after.to_bits());
        }
    }
}