rag_plusplus_core/wal/
entry.rs

1//! WAL Entry Types
2//!
3//! Defines the structure of write-ahead log entries.
4
5use crate::types::{MemoryRecord, RecordId};
6use rkyv::{Archive, Deserialize, Serialize};
7
/// Type of WAL entry.
///
/// Variant order is significant: the discriminant is cast to `u8` and mixed
/// into the entry checksum, so reordering variants changes the checksum of
/// otherwise identical entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub enum WalEntryType {
    /// Insert a new record (the entry carries the full record payload)
    Insert,
    /// Update statistics for a record (the entry carries an outcome value)
    UpdateStats,
    /// Delete a record (the entry carries only the record ID)
    Delete,
    /// Checkpoint marker (the entry carries no record ID or payload)
    Checkpoint,
}
21
/// A single WAL entry.
///
/// Each entry is self-contained with all information needed to replay it.
/// Entries are built through the `insert` / `update_stats` / `delete` /
/// `checkpoint` constructors, which also fill in `timestamp_ms` and `checksum`.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub struct WalEntry {
    /// Sequence number supplied by the caller (expected to be monotonically
    /// increasing; this type does not enforce it)
    pub sequence: u64,
    /// Entry type
    pub entry_type: WalEntryType,
    /// Wall-clock creation time (Unix epoch millis)
    pub timestamp_ms: u64,
    /// Record ID; the empty string for `Checkpoint` entries
    pub record_id: String,
    /// Full record data (`Some` only for `Insert`)
    pub record_data: Option<WalRecordData>,
    /// Outcome value (`Some` only for `UpdateStats`)
    pub outcome: Option<f64>,
    /// 32-bit xxHash (seed 0) integrity checksum; it is not a CRC32
    pub checksum: u32,
}
43
/// Serializable record data for WAL.
///
/// Subset of MemoryRecord needed for reconstruction. Status and statistics
/// are not part of this snapshot.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub struct WalRecordData {
    /// Record ID (string form of the source record's ID)
    pub id: String,
    /// Embedding vector
    pub embedding: Vec<f32>,
    /// Context string
    pub context: String,
    /// Initial outcome
    pub outcome: f64,
    /// Metadata serialized as a JSON string.
    /// NOTE: `from_record` currently always writes `"{}"`, so the source
    /// record's metadata is not preserved yet.
    pub metadata_json: String,
    /// Creation timestamp, copied verbatim from the source record
    pub created_at: u64,
}
63
64impl WalRecordData {
65    /// Create from a MemoryRecord.
66    #[must_use]
67    pub fn from_record(record: &MemoryRecord) -> Self {
68        Self {
69            id: record.id.to_string(),
70            embedding: record.embedding.clone(),
71            context: record.context.clone(),
72            outcome: record.outcome,
73            metadata_json: "{}".to_string(), // TODO: serialize metadata
74            created_at: record.created_at,
75        }
76    }
77}
78
79impl WalEntry {
80    /// Create a new Insert entry.
81    #[must_use]
82    pub fn insert(sequence: u64, record: &MemoryRecord) -> Self {
83        let mut entry = Self {
84            sequence,
85            entry_type: WalEntryType::Insert,
86            timestamp_ms: current_time_ms(),
87            record_id: record.id.to_string(),
88            record_data: Some(WalRecordData::from_record(record)),
89            outcome: None,
90            checksum: 0,
91        };
92        entry.checksum = entry.compute_checksum();
93        entry
94    }
95
96    /// Create a new UpdateStats entry.
97    #[must_use]
98    pub fn update_stats(sequence: u64, record_id: &RecordId, outcome: f64) -> Self {
99        let mut entry = Self {
100            sequence,
101            entry_type: WalEntryType::UpdateStats,
102            timestamp_ms: current_time_ms(),
103            record_id: record_id.to_string(),
104            record_data: None,
105            outcome: Some(outcome),
106            checksum: 0,
107        };
108        entry.checksum = entry.compute_checksum();
109        entry
110    }
111
112    /// Create a new Delete entry.
113    #[must_use]
114    pub fn delete(sequence: u64, record_id: &RecordId) -> Self {
115        let mut entry = Self {
116            sequence,
117            entry_type: WalEntryType::Delete,
118            timestamp_ms: current_time_ms(),
119            record_id: record_id.to_string(),
120            record_data: None,
121            outcome: None,
122            checksum: 0,
123        };
124        entry.checksum = entry.compute_checksum();
125        entry
126    }
127
128    /// Create a checkpoint marker.
129    #[must_use]
130    pub fn checkpoint(sequence: u64) -> Self {
131        let mut entry = Self {
132            sequence,
133            entry_type: WalEntryType::Checkpoint,
134            timestamp_ms: current_time_ms(),
135            record_id: String::new(),
136            record_data: None,
137            outcome: None,
138            checksum: 0,
139        };
140        entry.checksum = entry.compute_checksum();
141        entry
142    }
143
144    /// Compute CRC32 checksum for this entry.
145    fn compute_checksum(&self) -> u32 {
146        use xxhash_rust::xxh32::xxh32;
147
148        // Build a byte buffer of key fields
149        let mut data = Vec::new();
150
151        // Hash key fields
152        data.extend_from_slice(&self.sequence.to_le_bytes());
153        data.push(self.entry_type as u8);
154        data.extend_from_slice(&self.timestamp_ms.to_le_bytes());
155        data.extend_from_slice(self.record_id.as_bytes());
156
157        if let Some(ref record_data) = self.record_data {
158            data.extend_from_slice(record_data.id.as_bytes());
159            data.extend_from_slice(record_data.context.as_bytes());
160            data.extend_from_slice(&record_data.outcome.to_bits().to_le_bytes());
161        }
162
163        if let Some(outcome) = self.outcome {
164            data.extend_from_slice(&outcome.to_bits().to_le_bytes());
165        }
166
167        xxh32(&data, 0)
168    }
169
170    /// Verify the checksum.
171    #[must_use]
172    pub fn verify_checksum(&self) -> bool {
173        let mut copy = self.clone();
174        copy.checksum = 0;
175        copy.checksum = copy.compute_checksum();
176        copy.checksum == self.checksum
177    }
178
179    /// Serialize to bytes using rkyv.
180    #[must_use]
181    pub fn to_bytes(&self) -> Vec<u8> {
182        rkyv::to_bytes::<_, 256>(self)
183            .expect("WAL entry serialization should not fail")
184            .to_vec()
185    }
186
187    /// Deserialize from bytes.
188    ///
189    /// # Errors
190    ///
191    /// Returns error if deserialization fails or checksum is invalid.
192    pub fn from_bytes(bytes: &[u8]) -> Result<Self, WalError> {
193        let archived = rkyv::check_archived_root::<Self>(bytes)
194            .map_err(|e| WalError::Corrupted(format!("Failed to deserialize: {e}")))?;
195
196        let entry: Self = archived
197            .deserialize(&mut rkyv::Infallible)
198            .map_err(|_| WalError::Corrupted("Deserialization failed".into()))?;
199
200        if !entry.verify_checksum() {
201            return Err(WalError::ChecksumMismatch);
202        }
203
204        Ok(entry)
205    }
206}
207
/// WAL-specific errors.
///
/// Payloads are plain `String`s, which lets the enum derive `Clone`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum WalError {
    /// The bytes failed archive validation or could not be deserialized;
    /// the string carries the underlying detail.
    #[error("WAL entry corrupted: {0}")]
    Corrupted(String),

    /// The entry decoded cleanly but its stored checksum does not match the
    /// value recomputed from its contents.
    #[error("Checksum mismatch")]
    ChecksumMismatch,

    /// An I/O failure, carried as a message. Not produced in this file;
    /// presumably raised by the WAL reader/writer (TODO confirm).
    #[error("IO error: {0}")]
    Io(String),

    /// The WAL cannot accept more entries. Not produced in this file; the
    /// capacity rule lives elsewhere (TODO confirm).
    #[error("WAL is full")]
    Full,
}
223
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Yields `0` when the clock reports a time earlier than the epoch.
fn current_time_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::OutcomeStats;
    use crate::types::RecordStatus;

    /// Build a small, fully populated record shared by the tests below.
    fn create_test_record() -> MemoryRecord {
        MemoryRecord {
            id: "test-record".into(),
            embedding: vec![1.0, 2.0, 3.0],
            context: "Test context".into(),
            outcome: 0.8,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    #[test]
    fn test_insert_entry() {
        let record = create_test_record();
        let entry = WalEntry::insert(1, &record);

        assert_eq!(entry.sequence, 1);
        assert_eq!(entry.entry_type, WalEntryType::Insert);
        assert_eq!(entry.record_id, "test-record");
        assert!(entry.record_data.is_some());
        // Insert entries carry a payload but no stats outcome.
        assert!(entry.outcome.is_none());
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_update_stats_entry() {
        let entry = WalEntry::update_stats(2, &"rec-1".into(), 0.9);

        assert_eq!(entry.sequence, 2);
        assert_eq!(entry.entry_type, WalEntryType::UpdateStats);
        assert_eq!(entry.record_id, "rec-1");
        assert_eq!(entry.outcome, Some(0.9));
        assert!(entry.record_data.is_none());
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_delete_entry() {
        let entry = WalEntry::delete(3, &"rec-2".into());

        assert_eq!(entry.sequence, 3);
        assert_eq!(entry.entry_type, WalEntryType::Delete);
        assert_eq!(entry.record_id, "rec-2");
        assert!(entry.record_data.is_none());
        assert!(entry.outcome.is_none());
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_checkpoint_entry() {
        let entry = WalEntry::checkpoint(100);

        assert_eq!(entry.sequence, 100);
        assert_eq!(entry.entry_type, WalEntryType::Checkpoint);
        // A checkpoint references no record at all.
        assert!(entry.record_id.is_empty());
        assert!(entry.record_data.is_none());
        assert!(entry.outcome.is_none());
        assert!(entry.verify_checksum());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let record = create_test_record();
        let entry = WalEntry::insert(1, &record);

        let bytes = entry.to_bytes();
        let restored = WalEntry::from_bytes(&bytes).unwrap();

        assert_eq!(restored.sequence, entry.sequence);
        assert_eq!(restored.entry_type, entry.entry_type);
        assert_eq!(restored.timestamp_ms, entry.timestamp_ms);
        assert_eq!(restored.record_id, entry.record_id);
        assert_eq!(restored.outcome, entry.outcome);
        assert_eq!(restored.checksum, entry.checksum);

        // The payload must survive the round trip too, not just the header.
        let data = restored
            .record_data
            .as_ref()
            .expect("insert entry must keep its payload");
        assert_eq!(data.id, "test-record");
        assert_eq!(data.embedding, vec![1.0, 2.0, 3.0]);
        assert_eq!(data.context, "Test context");
        assert_eq!(data.outcome.to_bits(), 0.8_f64.to_bits());
        assert_eq!(data.created_at, 1234567890);

        assert!(restored.verify_checksum());
    }

    #[test]
    fn test_corrupted_bytes() {
        let record = create_test_record();
        let entry = WalEntry::insert(1, &record);

        let mut bytes = entry.to_bytes();
        // An empty buffer would leave nothing to corrupt; fail loudly instead
        // of silently skipping the corruption step.
        assert!(!bytes.is_empty());

        // Corrupt a byte
        let mid = bytes.len() / 2;
        bytes[mid] ^= 0xFF;

        let result = WalEntry::from_bytes(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_checksum_tamper_detection() {
        let mut entry = WalEntry::update_stats(1, &"test".into(), 0.5);

        // Tamper with data after checksum
        entry.outcome = Some(0.99);

        assert!(!entry.verify_checksum());
    }
}
334}