1use crate::types::{MemoryRecord, RecordId};
6use rkyv::{Archive, Deserialize, Serialize};
7
/// Kind of operation recorded by a [`WalEntry`].
///
/// NOTE(review): variant order is significant. `WalEntry::compute_checksum`
/// hashes the discriminant via `as u8`, and the rkyv archive presumably encodes
/// the variant by position, so reordering or inserting variants before existing
/// ones would change checksums/encoding of already-written entries. Append new
/// variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub enum WalEntryType {
    /// Record insertion; `WalEntry::insert` attaches the full `record_data`.
    Insert,
    /// Outcome update; `WalEntry::update_stats` attaches `outcome`.
    UpdateStats,
    /// Record deletion; only `record_id` is populated by `WalEntry::delete`.
    Delete,
    /// Checkpoint marker; `WalEntry::checkpoint` leaves `record_id` empty.
    Checkpoint,
}
21
/// A single write-ahead-log entry.
///
/// Normally built through the constructors in `impl WalEntry`, which stamp
/// `timestamp_ms` from the wall clock and compute `checksum` before returning.
///
/// NOTE(review): field order defines the rkyv archived layout; reordering
/// fields changes the on-disk encoding of entries.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub struct WalEntry {
    /// Sequence number supplied by the caller of the constructor.
    pub sequence: u64,
    /// Which operation this entry records.
    pub entry_type: WalEntryType,
    /// Milliseconds since the Unix epoch at construction (0 if the clock reads before the epoch).
    pub timestamp_ms: u64,
    /// String form of the affected record's id; empty for checkpoints.
    pub record_id: String,
    /// Full record snapshot; set only by `WalEntry::insert`.
    pub record_data: Option<WalRecordData>,
    /// Outcome value; set only by `WalEntry::update_stats`.
    pub outcome: Option<f64>,
    /// xxh32 checksum produced by `compute_checksum` (covers a subset of the fields).
    pub checksum: u32,
}
43
/// Snapshot of a `MemoryRecord` carried by an `Insert` WAL entry.
///
/// NOTE(review): field order defines the rkyv archived layout; do not reorder.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub struct WalRecordData {
    /// String form of the record id.
    pub id: String,
    /// Embedding vector copied from the record.
    pub embedding: Vec<f32>,
    /// Context text copied from the record.
    pub context: String,
    /// Outcome value copied from the record.
    pub outcome: f64,
    /// Presumably meant to hold the record's metadata as JSON (inferred from the
    /// name). NOTE(review): `from_record` currently always stores `"{}"`.
    pub metadata_json: String,
    /// Creation timestamp copied from `MemoryRecord::created_at` (same units).
    pub created_at: u64,
}
63
64impl WalRecordData {
65 #[must_use]
67 pub fn from_record(record: &MemoryRecord) -> Self {
68 Self {
69 id: record.id.to_string(),
70 embedding: record.embedding.clone(),
71 context: record.context.clone(),
72 outcome: record.outcome,
73 metadata_json: "{}".to_string(), created_at: record.created_at,
75 }
76 }
77}
78
79impl WalEntry {
80 #[must_use]
82 pub fn insert(sequence: u64, record: &MemoryRecord) -> Self {
83 let mut entry = Self {
84 sequence,
85 entry_type: WalEntryType::Insert,
86 timestamp_ms: current_time_ms(),
87 record_id: record.id.to_string(),
88 record_data: Some(WalRecordData::from_record(record)),
89 outcome: None,
90 checksum: 0,
91 };
92 entry.checksum = entry.compute_checksum();
93 entry
94 }
95
96 #[must_use]
98 pub fn update_stats(sequence: u64, record_id: &RecordId, outcome: f64) -> Self {
99 let mut entry = Self {
100 sequence,
101 entry_type: WalEntryType::UpdateStats,
102 timestamp_ms: current_time_ms(),
103 record_id: record_id.to_string(),
104 record_data: None,
105 outcome: Some(outcome),
106 checksum: 0,
107 };
108 entry.checksum = entry.compute_checksum();
109 entry
110 }
111
112 #[must_use]
114 pub fn delete(sequence: u64, record_id: &RecordId) -> Self {
115 let mut entry = Self {
116 sequence,
117 entry_type: WalEntryType::Delete,
118 timestamp_ms: current_time_ms(),
119 record_id: record_id.to_string(),
120 record_data: None,
121 outcome: None,
122 checksum: 0,
123 };
124 entry.checksum = entry.compute_checksum();
125 entry
126 }
127
128 #[must_use]
130 pub fn checkpoint(sequence: u64) -> Self {
131 let mut entry = Self {
132 sequence,
133 entry_type: WalEntryType::Checkpoint,
134 timestamp_ms: current_time_ms(),
135 record_id: String::new(),
136 record_data: None,
137 outcome: None,
138 checksum: 0,
139 };
140 entry.checksum = entry.compute_checksum();
141 entry
142 }
143
144 fn compute_checksum(&self) -> u32 {
146 use xxhash_rust::xxh32::xxh32;
147
148 let mut data = Vec::new();
150
151 data.extend_from_slice(&self.sequence.to_le_bytes());
153 data.push(self.entry_type as u8);
154 data.extend_from_slice(&self.timestamp_ms.to_le_bytes());
155 data.extend_from_slice(self.record_id.as_bytes());
156
157 if let Some(ref record_data) = self.record_data {
158 data.extend_from_slice(record_data.id.as_bytes());
159 data.extend_from_slice(record_data.context.as_bytes());
160 data.extend_from_slice(&record_data.outcome.to_bits().to_le_bytes());
161 }
162
163 if let Some(outcome) = self.outcome {
164 data.extend_from_slice(&outcome.to_bits().to_le_bytes());
165 }
166
167 xxh32(&data, 0)
168 }
169
170 #[must_use]
172 pub fn verify_checksum(&self) -> bool {
173 let mut copy = self.clone();
174 copy.checksum = 0;
175 copy.checksum = copy.compute_checksum();
176 copy.checksum == self.checksum
177 }
178
179 #[must_use]
181 pub fn to_bytes(&self) -> Vec<u8> {
182 rkyv::to_bytes::<_, 256>(self)
183 .expect("WAL entry serialization should not fail")
184 .to_vec()
185 }
186
187 pub fn from_bytes(bytes: &[u8]) -> Result<Self, WalError> {
193 let archived = rkyv::check_archived_root::<Self>(bytes)
194 .map_err(|e| WalError::Corrupted(format!("Failed to deserialize: {e}")))?;
195
196 let entry: Self = archived
197 .deserialize(&mut rkyv::Infallible)
198 .map_err(|_| WalError::Corrupted("Deserialization failed".into()))?;
199
200 if !entry.verify_checksum() {
201 return Err(WalError::ChecksumMismatch);
202 }
203
204 Ok(entry)
205 }
206}
207
208#[derive(Debug, Clone, thiserror::Error)]
210pub enum WalError {
211 #[error("WAL entry corrupted: {0}")]
212 Corrupted(String),
213
214 #[error("Checksum mismatch")]
215 ChecksumMismatch,
216
217 #[error("IO error: {0}")]
218 Io(String),
219
220 #[error("WAL is full")]
221 Full,
222}
223
/// Wall-clock time in whole milliseconds since the Unix epoch.
///
/// Returns 0 instead of panicking when the system clock reads earlier than the
/// epoch. The `u128 -> u64` narrowing cannot truncate for any realistic date.
fn current_time_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::OutcomeStats;
    use crate::types::RecordStatus;

    /// Builds a fully populated record for the insert/roundtrip tests.
    fn create_test_record() -> MemoryRecord {
        MemoryRecord {
            id: "test-record".into(),
            embedding: vec![1.0, 2.0, 3.0],
            context: "Test context".into(),
            outcome: 0.8,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Serializes `entry` and decodes it again, panicking if decoding fails.
    fn roundtrip(entry: &WalEntry) -> WalEntry {
        WalEntry::from_bytes(&entry.to_bytes()).expect("freshly serialized entry should decode")
    }

    #[test]
    fn test_insert_entry() {
        let record = create_test_record();
        let entry = WalEntry::insert(1, &record);

        assert_eq!(entry.sequence, 1);
        assert_eq!(entry.entry_type, WalEntryType::Insert);
        assert_eq!(entry.record_id, "test-record");
        assert!(entry.record_data.is_some());
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_update_stats_entry() {
        let entry = WalEntry::update_stats(2, &"rec-1".into(), 0.9);

        assert_eq!(entry.sequence, 2);
        assert_eq!(entry.entry_type, WalEntryType::UpdateStats);
        assert_eq!(entry.record_id, "rec-1");
        assert_eq!(entry.outcome, Some(0.9));
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_delete_entry() {
        let entry = WalEntry::delete(3, &"rec-2".into());

        assert_eq!(entry.sequence, 3);
        assert_eq!(entry.entry_type, WalEntryType::Delete);
        assert_eq!(entry.record_id, "rec-2");
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_checkpoint_entry() {
        let entry = WalEntry::checkpoint(100);

        assert_eq!(entry.sequence, 100);
        assert_eq!(entry.entry_type, WalEntryType::Checkpoint);
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let record = create_test_record();
        let entry = WalEntry::insert(1, &record);

        let restored = roundtrip(&entry);

        assert_eq!(restored.sequence, entry.sequence);
        assert_eq!(restored.entry_type, entry.entry_type);
        assert_eq!(restored.timestamp_ms, entry.timestamp_ms);
        assert_eq!(restored.record_id, entry.record_id);
        assert_eq!(restored.outcome, entry.outcome);
        assert_eq!(restored.checksum, entry.checksum);
        assert!(restored.verify_checksum());

        // The payload must survive intact, not just the header fields.
        let data = restored
            .record_data
            .expect("insert entry should carry record data");
        assert_eq!(data.id, record.id.to_string());
        assert_eq!(data.embedding, record.embedding);
        assert_eq!(data.context, record.context);
        assert_eq!(data.outcome, record.outcome);
        assert_eq!(data.created_at, record.created_at);
    }

    #[test]
    fn test_roundtrip_non_insert_entries() {
        let entries = [
            WalEntry::update_stats(2, &"rec-1".into(), 0.9),
            WalEntry::delete(3, &"rec-2".into()),
            WalEntry::checkpoint(4),
        ];

        for entry in &entries {
            let restored = roundtrip(entry);

            assert_eq!(restored.sequence, entry.sequence);
            assert_eq!(restored.entry_type, entry.entry_type);
            assert_eq!(restored.timestamp_ms, entry.timestamp_ms);
            assert_eq!(restored.record_id, entry.record_id);
            assert_eq!(restored.outcome, entry.outcome);
            assert!(restored.record_data.is_none());
            assert!(restored.verify_checksum());
        }
    }

    #[test]
    fn test_corrupted_bytes() {
        let record = create_test_record();
        let entry = WalEntry::insert(1, &record);

        let mut bytes = entry.to_bytes();
        // Fail loudly instead of silently skipping the corruption step.
        assert!(!bytes.is_empty(), "serialized entry should not be empty");
        let mid = bytes.len() / 2;
        bytes[mid] ^= 0xFF;

        assert!(WalEntry::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_empty_bytes_rejected() {
        assert!(matches!(
            WalEntry::from_bytes(&[]),
            Err(WalError::Corrupted(_))
        ));
    }

    #[test]
    fn test_checksum_tamper_detection() {
        let mut entry = WalEntry::update_stats(1, &"test".into(), 0.5);

        entry.outcome = Some(0.99);

        assert!(!entry.verify_checksum());
    }

    #[test]
    fn test_tampered_entry_fails_decode() {
        let mut entry = WalEntry::update_stats(1, &"test".into(), 0.5);
        entry.outcome = Some(0.99);

        // Structurally valid bytes whose checksum no longer matches the content.
        let bytes = entry.to_bytes();

        assert!(matches!(
            WalEntry::from_bytes(&bytes),
            Err(WalError::ChecksumMismatch)
        ));
    }
}