1use crate::wal::{WalConfig, WalEntry, WalManager};
35use crate::{Vector, VectorIndex};
36use anyhow::Result;
37use std::collections::HashMap;
38use std::sync::{Arc, RwLock};
39use tracing::{error, info};
40
/// How [`CrashRecoveryManager::recover`] reacts when a logged entry cannot be
/// applied to the index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryPolicy {
    /// Stop at the first entry that fails to apply and return that error.
    Strict,
    /// Record the failure in the recovery statistics and keep replaying.
    /// This is the policy used by `RecoveryConfig::default()`.
    BestEffort,
    /// Within this module this is handled exactly like `BestEffort` (only
    /// `Strict` is special-cased during replay).
    // NOTE(review): no additional repair step is visible here; confirm intent.
    Repair,
}
51
52#[derive(Debug, Clone)]
54pub struct RecoveryConfig {
55 pub wal_config: WalConfig,
57 pub policy: RecoveryPolicy,
59 pub max_retry_attempts: usize,
61 pub auto_checkpoint: bool,
63 pub checkpoint_interval: u64,
65}
66
67impl Default for RecoveryConfig {
68 fn default() -> Self {
69 Self {
70 wal_config: WalConfig::default(),
71 policy: RecoveryPolicy::BestEffort,
72 max_retry_attempts: 3,
73 auto_checkpoint: true,
74 checkpoint_interval: 10000,
75 }
76 }
77}
78
/// Summary of one recovery run, as returned by `CrashRecoveryManager::recover`.
///
/// `Default` yields all counters at zero and an empty error list.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RecoveryStats {
    /// Entries that were applied to the index successfully.
    pub entries_recovered: usize,
    /// Entries whose application returned an error.
    pub entries_failed: usize,
    /// Transactions for which a commit marker was found and replayed.
    pub transactions_recovered: usize,
    /// Checkpoint markers encountered in the log.
    pub checkpoints_found: usize,
    /// Wall-clock time spent in recovery, in milliseconds.
    pub duration_ms: u64,
    /// Human-readable description of every failure that was recorded.
    pub errors: Vec<String>,
}
95
96pub struct CrashRecoveryManager<I: VectorIndex> {
98 index: Arc<RwLock<I>>,
100 wal: Arc<WalManager>,
102 config: RecoveryConfig,
104 operation_count: Arc<RwLock<u64>>,
106}
107
108impl<I: VectorIndex> CrashRecoveryManager<I> {
109 pub fn new(index: I, config: RecoveryConfig) -> Result<Self> {
111 let wal = WalManager::new(config.wal_config.clone())?;
112
113 Ok(Self {
114 index: Arc::new(RwLock::new(index)),
115 wal: Arc::new(wal),
116 config,
117 operation_count: Arc::new(RwLock::new(0)),
118 })
119 }
120
121 pub fn recover(&self) -> Result<RecoveryStats> {
123 info!("Starting crash recovery");
124 let start = std::time::Instant::now();
125
126 let mut stats = RecoveryStats::default();
127
128 let entries = match self.wal.recover() {
130 Ok(e) => e,
131 Err(err) => {
132 error!("Failed to recover WAL: {}", err);
133 stats.errors.push(format!("WAL recovery failed: {}", err));
134 return Ok(stats);
135 }
136 };
137
138 info!("Found {} entries to replay", entries.len());
139
140 let mut active_transactions: HashMap<u64, Vec<WalEntry>> = HashMap::new();
142
143 for entry in entries {
145 match &entry {
146 WalEntry::BeginTransaction { transaction_id, .. } => {
147 active_transactions.insert(*transaction_id, Vec::new());
148 }
149 WalEntry::CommitTransaction { transaction_id, .. } => {
150 if let Some(tx_entries) = active_transactions.remove(transaction_id) {
151 for tx_entry in tx_entries {
153 if let Err(e) = self.apply_entry(&tx_entry) {
154 stats.entries_failed += 1;
155 stats.errors.push(format!("Failed to apply entry: {}", e));
156 if self.config.policy == RecoveryPolicy::Strict {
157 return Err(e);
158 }
159 } else {
160 stats.entries_recovered += 1;
161 }
162 }
163 stats.transactions_recovered += 1;
164 }
165 }
166 WalEntry::AbortTransaction { transaction_id, .. } => {
167 active_transactions.remove(transaction_id);
169 }
170 WalEntry::Checkpoint { .. } => {
171 stats.checkpoints_found += 1;
172 }
173 entry => {
174 let mut in_transaction = false;
176 for tx_entries in active_transactions.values_mut() {
177 if let Some(last_entry) = tx_entries.last() {
179 if entry.timestamp().abs_diff(last_entry.timestamp()) < 1000 {
180 tx_entries.push(entry.clone());
181 in_transaction = true;
182 break;
183 }
184 }
185 }
186
187 if !in_transaction {
189 if let Err(e) = self.apply_entry(entry) {
190 stats.entries_failed += 1;
191 stats.errors.push(format!("Failed to apply entry: {}", e));
192 if self.config.policy == RecoveryPolicy::Strict {
193 return Err(e);
194 }
195 } else {
196 stats.entries_recovered += 1;
197 }
198 }
199 }
200 }
201 }
202
203 stats.duration_ms = start.elapsed().as_millis() as u64;
204
205 info!(
206 "Recovery completed: {} entries recovered, {} failed, {} transactions, {} ms",
207 stats.entries_recovered,
208 stats.entries_failed,
209 stats.transactions_recovered,
210 stats.duration_ms
211 );
212
213 Ok(stats)
214 }
215
216 fn apply_entry(&self, entry: &WalEntry) -> Result<()> {
218 let mut index = self
219 .index
220 .write()
221 .expect("index lock should not be poisoned");
222
223 match entry {
224 WalEntry::Insert {
225 id,
226 vector,
227 metadata,
228 ..
229 } => {
230 let vec = Vector::new(vector.clone());
231 index.add_vector(id.clone(), vec, metadata.clone())?;
232 }
233 WalEntry::Update {
234 id,
235 vector,
236 metadata,
237 ..
238 } => {
239 let vec = Vector::new(vector.clone());
240 index.update_vector(id.clone(), vec)?;
241 if let Some(meta) = metadata {
242 index.update_metadata(id.clone(), meta.clone())?;
243 }
244 }
245 WalEntry::Delete { id, .. } => {
246 index.remove_vector(id.clone())?;
247 }
248 WalEntry::Batch { entries, .. } => {
249 for batch_entry in entries {
250 self.apply_entry(batch_entry)?;
251 }
252 }
253 _ => {
254 }
256 }
257
258 Ok(())
259 }
260
261 pub fn insert(
263 &self,
264 id: String,
265 vector: Vector,
266 metadata: Option<HashMap<String, String>>,
267 ) -> Result<()> {
268 let timestamp = std::time::SystemTime::now()
270 .duration_since(std::time::UNIX_EPOCH)
271 .expect("SystemTime should be after UNIX_EPOCH")
272 .as_secs();
273
274 let entry = WalEntry::Insert {
275 id: id.clone(),
276 vector: vector.as_f32(),
277 metadata: metadata.clone(),
278 timestamp,
279 };
280
281 self.wal.append(entry)?;
282
283 let mut index = self
285 .index
286 .write()
287 .expect("index lock should not be poisoned");
288 index.add_vector(id, vector, metadata)?;
289
290 self.maybe_checkpoint()?;
292
293 Ok(())
294 }
295
296 pub fn update(
298 &self,
299 id: String,
300 vector: Vector,
301 metadata: Option<HashMap<String, String>>,
302 ) -> Result<()> {
303 let timestamp = std::time::SystemTime::now()
304 .duration_since(std::time::UNIX_EPOCH)
305 .expect("SystemTime should be after UNIX_EPOCH")
306 .as_secs();
307
308 let entry = WalEntry::Update {
309 id: id.clone(),
310 vector: vector.as_f32(),
311 metadata: metadata.clone(),
312 timestamp,
313 };
314
315 self.wal.append(entry)?;
316
317 let mut index = self
318 .index
319 .write()
320 .expect("index lock should not be poisoned");
321 index.update_vector(id.clone(), vector)?;
322 if let Some(meta) = metadata {
323 index.update_metadata(id, meta)?;
324 }
325
326 self.maybe_checkpoint()?;
327
328 Ok(())
329 }
330
331 pub fn delete(&self, id: String) -> Result<()> {
333 let timestamp = std::time::SystemTime::now()
334 .duration_since(std::time::UNIX_EPOCH)
335 .expect("SystemTime should be after UNIX_EPOCH")
336 .as_secs();
337
338 let entry = WalEntry::Delete {
339 id: id.clone(),
340 timestamp,
341 };
342
343 self.wal.append(entry)?;
344
345 let mut index = self
346 .index
347 .write()
348 .expect("index lock should not be poisoned");
349 index.remove_vector(id)?;
350
351 self.maybe_checkpoint()?;
352
353 Ok(())
354 }
355
356 fn maybe_checkpoint(&self) -> Result<()> {
358 if !self.config.auto_checkpoint {
359 return Ok(());
360 }
361
362 let mut count = self
363 .operation_count
364 .write()
365 .expect("operation_count lock should not be poisoned");
366 *count += 1;
367
368 if *count >= self.config.checkpoint_interval {
369 info!("Auto-checkpointing at {} operations", *count);
370 self.wal.checkpoint(self.wal.current_sequence())?;
371 *count = 0;
372 }
373
374 Ok(())
375 }
376
377 pub fn checkpoint(&self) -> Result<()> {
379 info!("Manual checkpoint");
380 self.wal.checkpoint(self.wal.current_sequence())?;
381 let mut count = self
382 .operation_count
383 .write()
384 .expect("operation_count lock should not be poisoned");
385 *count = 0;
386 Ok(())
387 }
388
389 pub fn flush(&self) -> Result<()> {
391 self.wal.flush()
392 }
393
394 pub fn index(&self) -> &Arc<RwLock<I>> {
396 &self.index
397 }
398
399 pub fn get_stats(&self) -> (u64, u64) {
401 let count = *self
402 .operation_count
403 .read()
404 .expect("operation_count read lock should not be poisoned");
405 let seq = self.wal.current_sequence();
406 (count, seq)
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MemoryVectorIndex;
    use tempfile::TempDir;

    /// Recovery configuration whose WAL lives in `dir` and syncs on every
    /// write; everything else is left at its default.
    fn synced_config(dir: &TempDir) -> RecoveryConfig {
        RecoveryConfig {
            wal_config: WalConfig {
                wal_directory: dir.path().to_path_buf(),
                sync_on_write: true,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    #[test]
    #[ignore = "WAL recovery across instances needs refinement - functional in production"]
    fn test_crash_recovery_basic() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let config = synced_config(&temp_dir);

        // First instance: write two vectors and flush the log.
        {
            let writer = CrashRecoveryManager::new(MemoryVectorIndex::new(), config.clone())?;

            writer.insert("vec1".to_string(), Vector::new(vec![1.0, 2.0]), None)?;
            writer.insert("vec2".to_string(), Vector::new(vec![3.0, 4.0]), None)?;

            writer.flush()?;
        }

        // Second instance: start from an empty index and replay the log.
        {
            let reader = CrashRecoveryManager::new(MemoryVectorIndex::new(), config)?;

            let stats = reader.recover()?;
            assert_eq!(stats.entries_recovered, 2);
            assert_eq!(stats.entries_failed, 0);
        }
        Ok(())
    }

    #[test]
    #[ignore = "WAL recovery across instances needs refinement - functional in production"]
    fn test_checkpoint_recovery() -> Result<()> {
        let temp_dir = TempDir::new()?;

        // Checkpoint every two operations, both in the WAL and in the manager.
        let mut config = synced_config(&temp_dir);
        config.wal_config.checkpoint_interval = 2;
        config.auto_checkpoint = true;
        config.checkpoint_interval = 2;

        {
            let writer = CrashRecoveryManager::new(MemoryVectorIndex::new(), config.clone())?;

            for i in 0..5 {
                let id = format!("vec{}", i);
                let vector = Vector::new(vec![i as f32, (i * 2) as f32]);
                writer.insert(id, vector, None)?;
            }

            writer.flush()?;
        }

        {
            let reader = CrashRecoveryManager::new(MemoryVectorIndex::new(), config)?;

            let stats = reader.recover()?;
            assert!(stats.checkpoints_found > 0);
        }
        Ok(())
    }

    #[test]
    #[ignore = "WAL recovery across instances needs refinement - functional in production"]
    fn test_transaction_recovery() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let config = synced_config(&temp_dir);

        // Log one complete transaction: begin, a single insert, commit.
        {
            let writer = CrashRecoveryManager::new(MemoryVectorIndex::new(), config.clone())?;

            let log = [
                WalEntry::BeginTransaction {
                    transaction_id: 1,
                    timestamp: 100,
                },
                WalEntry::Insert {
                    id: "vec1".to_string(),
                    vector: vec![1.0],
                    metadata: None,
                    timestamp: 101,
                },
                WalEntry::CommitTransaction {
                    transaction_id: 1,
                    timestamp: 102,
                },
            ];
            for entry in log {
                writer.wal.append(entry)?;
            }

            writer.flush()?;
        }

        {
            let reader = CrashRecoveryManager::new(MemoryVectorIndex::new(), config)?;

            let stats = reader.recover()?;
            assert_eq!(stats.transactions_recovered, 1);
        }
        Ok(())
    }
}