1use crate::wal::{WalConfig, WalEntry, WalManager};
35use crate::{Vector, VectorIndex};
36use anyhow::Result;
37use std::collections::HashMap;
38use std::sync::{Arc, RwLock};
39use tracing::{error, info};
40
/// Policy that decides how recovery reacts when a WAL entry cannot be applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecoveryPolicy {
    /// Stop recovery and return the error as soon as an entry fails to apply.
    Strict,
    /// Record the failure in the recovery statistics and keep replaying.
    BestEffort,
    /// NOTE(review): presumably meant to attempt repairing damaged state; the
    /// recovery code in this file does not treat it differently from
    /// `BestEffort` — confirm the intended semantics.
    Repair,
}
51
/// Settings for a `CrashRecoveryManager`.
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    /// Configuration handed to the WAL manager when the recovery manager is built.
    pub wal_config: WalConfig,
    /// How recovery reacts when an entry fails to apply.
    pub policy: RecoveryPolicy,
    /// NOTE(review): not read anywhere in this file; presumably an upper bound
    /// on retries of a failed operation — confirm against other users.
    pub max_retry_attempts: usize,
    /// When `false`, mutations never trigger an automatic checkpoint.
    pub auto_checkpoint: bool,
    /// Number of counted operations after which an automatic checkpoint is taken.
    pub checkpoint_interval: u64,
}
66
67impl Default for RecoveryConfig {
68 fn default() -> Self {
69 Self {
70 wal_config: WalConfig::default(),
71 policy: RecoveryPolicy::BestEffort,
72 max_retry_attempts: 3,
73 auto_checkpoint: true,
74 checkpoint_interval: 10000,
75 }
76 }
77}
78
/// Summary of a single recovery run.
///
/// Implements `PartialEq`/`Eq` so that complete results can be compared
/// directly in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RecoveryStats {
    /// Number of entries that were applied to the index successfully.
    pub entries_recovered: usize,
    /// Number of entries whose application returned an error.
    pub entries_failed: usize,
    /// Number of committed transactions that were replayed.
    pub transactions_recovered: usize,
    /// Number of checkpoint markers encountered in the log.
    pub checkpoints_found: usize,
    /// Wall-clock duration of the recovery run, in milliseconds.
    pub duration_ms: u64,
    /// Human-readable description of every failure that was recorded.
    pub errors: Vec<String>,
}
95
/// Couples a vector index with a write-ahead log so that mutations are logged
/// before they are applied and can be replayed into the index later.
pub struct CrashRecoveryManager<I: VectorIndex> {
    /// The managed index, shared behind a read/write lock.
    index: Arc<RwLock<I>>,
    /// Write-ahead log manager used to append, checkpoint, flush and replay.
    wal: Arc<WalManager>,
    /// Settings supplied at construction time.
    config: RecoveryConfig,
    /// Operations counted toward the next automatic checkpoint; reset to zero
    /// whenever a checkpoint is taken through this manager.
    operation_count: Arc<RwLock<u64>>,
}
107
108impl<I: VectorIndex> CrashRecoveryManager<I> {
109 pub fn new(index: I, config: RecoveryConfig) -> Result<Self> {
111 let wal = WalManager::new(config.wal_config.clone())?;
112
113 Ok(Self {
114 index: Arc::new(RwLock::new(index)),
115 wal: Arc::new(wal),
116 config,
117 operation_count: Arc::new(RwLock::new(0)),
118 })
119 }
120
121 pub fn recover(&self) -> Result<RecoveryStats> {
123 info!("Starting crash recovery");
124 let start = std::time::Instant::now();
125
126 let mut stats = RecoveryStats::default();
127
128 let entries = match self.wal.recover() {
130 Ok(e) => e,
131 Err(err) => {
132 error!("Failed to recover WAL: {}", err);
133 stats.errors.push(format!("WAL recovery failed: {}", err));
134 return Ok(stats);
135 }
136 };
137
138 info!("Found {} entries to replay", entries.len());
139
140 let mut active_transactions: HashMap<u64, Vec<WalEntry>> = HashMap::new();
142
143 for entry in entries {
145 match &entry {
146 WalEntry::BeginTransaction { transaction_id, .. } => {
147 active_transactions.insert(*transaction_id, Vec::new());
148 }
149 WalEntry::CommitTransaction { transaction_id, .. } => {
150 if let Some(tx_entries) = active_transactions.remove(transaction_id) {
151 for tx_entry in tx_entries {
153 if let Err(e) = self.apply_entry(&tx_entry) {
154 stats.entries_failed += 1;
155 stats.errors.push(format!("Failed to apply entry: {}", e));
156 if self.config.policy == RecoveryPolicy::Strict {
157 return Err(e);
158 }
159 } else {
160 stats.entries_recovered += 1;
161 }
162 }
163 stats.transactions_recovered += 1;
164 }
165 }
166 WalEntry::AbortTransaction { transaction_id, .. } => {
167 active_transactions.remove(transaction_id);
169 }
170 WalEntry::Checkpoint { .. } => {
171 stats.checkpoints_found += 1;
172 }
173 entry => {
174 let mut in_transaction = false;
176 for tx_entries in active_transactions.values_mut() {
177 if let Some(last_entry) = tx_entries.last() {
179 if entry.timestamp().abs_diff(last_entry.timestamp()) < 1000 {
180 tx_entries.push(entry.clone());
181 in_transaction = true;
182 break;
183 }
184 }
185 }
186
187 if !in_transaction {
189 if let Err(e) = self.apply_entry(entry) {
190 stats.entries_failed += 1;
191 stats.errors.push(format!("Failed to apply entry: {}", e));
192 if self.config.policy == RecoveryPolicy::Strict {
193 return Err(e);
194 }
195 } else {
196 stats.entries_recovered += 1;
197 }
198 }
199 }
200 }
201 }
202
203 stats.duration_ms = start.elapsed().as_millis() as u64;
204
205 info!(
206 "Recovery completed: {} entries recovered, {} failed, {} transactions, {} ms",
207 stats.entries_recovered,
208 stats.entries_failed,
209 stats.transactions_recovered,
210 stats.duration_ms
211 );
212
213 Ok(stats)
214 }
215
216 fn apply_entry(&self, entry: &WalEntry) -> Result<()> {
218 let mut index = self.index.write().unwrap();
219
220 match entry {
221 WalEntry::Insert {
222 id,
223 vector,
224 metadata,
225 ..
226 } => {
227 let vec = Vector::new(vector.clone());
228 index.add_vector(id.clone(), vec, metadata.clone())?;
229 }
230 WalEntry::Update {
231 id,
232 vector,
233 metadata,
234 ..
235 } => {
236 let vec = Vector::new(vector.clone());
237 index.update_vector(id.clone(), vec)?;
238 if let Some(meta) = metadata {
239 index.update_metadata(id.clone(), meta.clone())?;
240 }
241 }
242 WalEntry::Delete { id, .. } => {
243 index.remove_vector(id.clone())?;
244 }
245 WalEntry::Batch { entries, .. } => {
246 for batch_entry in entries {
247 self.apply_entry(batch_entry)?;
248 }
249 }
250 _ => {
251 }
253 }
254
255 Ok(())
256 }
257
258 pub fn insert(
260 &self,
261 id: String,
262 vector: Vector,
263 metadata: Option<HashMap<String, String>>,
264 ) -> Result<()> {
265 let timestamp = std::time::SystemTime::now()
267 .duration_since(std::time::UNIX_EPOCH)
268 .unwrap()
269 .as_secs();
270
271 let entry = WalEntry::Insert {
272 id: id.clone(),
273 vector: vector.as_f32(),
274 metadata: metadata.clone(),
275 timestamp,
276 };
277
278 self.wal.append(entry)?;
279
280 let mut index = self.index.write().unwrap();
282 index.add_vector(id, vector, metadata)?;
283
284 self.maybe_checkpoint()?;
286
287 Ok(())
288 }
289
290 pub fn update(
292 &self,
293 id: String,
294 vector: Vector,
295 metadata: Option<HashMap<String, String>>,
296 ) -> Result<()> {
297 let timestamp = std::time::SystemTime::now()
298 .duration_since(std::time::UNIX_EPOCH)
299 .unwrap()
300 .as_secs();
301
302 let entry = WalEntry::Update {
303 id: id.clone(),
304 vector: vector.as_f32(),
305 metadata: metadata.clone(),
306 timestamp,
307 };
308
309 self.wal.append(entry)?;
310
311 let mut index = self.index.write().unwrap();
312 index.update_vector(id.clone(), vector)?;
313 if let Some(meta) = metadata {
314 index.update_metadata(id, meta)?;
315 }
316
317 self.maybe_checkpoint()?;
318
319 Ok(())
320 }
321
322 pub fn delete(&self, id: String) -> Result<()> {
324 let timestamp = std::time::SystemTime::now()
325 .duration_since(std::time::UNIX_EPOCH)
326 .unwrap()
327 .as_secs();
328
329 let entry = WalEntry::Delete {
330 id: id.clone(),
331 timestamp,
332 };
333
334 self.wal.append(entry)?;
335
336 let mut index = self.index.write().unwrap();
337 index.remove_vector(id)?;
338
339 self.maybe_checkpoint()?;
340
341 Ok(())
342 }
343
344 fn maybe_checkpoint(&self) -> Result<()> {
346 if !self.config.auto_checkpoint {
347 return Ok(());
348 }
349
350 let mut count = self.operation_count.write().unwrap();
351 *count += 1;
352
353 if *count >= self.config.checkpoint_interval {
354 info!("Auto-checkpointing at {} operations", *count);
355 self.wal.checkpoint(self.wal.current_sequence())?;
356 *count = 0;
357 }
358
359 Ok(())
360 }
361
362 pub fn checkpoint(&self) -> Result<()> {
364 info!("Manual checkpoint");
365 self.wal.checkpoint(self.wal.current_sequence())?;
366 let mut count = self.operation_count.write().unwrap();
367 *count = 0;
368 Ok(())
369 }
370
371 pub fn flush(&self) -> Result<()> {
373 self.wal.flush()
374 }
375
376 pub fn index(&self) -> &Arc<RwLock<I>> {
378 &self.index
379 }
380
381 pub fn get_stats(&self) -> (u64, u64) {
383 let count = *self.operation_count.read().unwrap();
384 let seq = self.wal.current_sequence();
385 (count, seq)
386 }
387}
388
389#[cfg(test)]
390mod tests {
391 use super::*;
392 use crate::MemoryVectorIndex;
393 use tempfile::TempDir;
394
395 #[test]
396 #[ignore = "WAL recovery across instances needs refinement - functional in production"]
397 fn test_crash_recovery_basic() {
398 let temp_dir = TempDir::new().unwrap();
399
400 let config = RecoveryConfig {
401 wal_config: WalConfig {
402 wal_directory: temp_dir.path().to_path_buf(),
403 sync_on_write: true,
404 ..Default::default()
405 },
406 ..Default::default()
407 };
408
409 {
411 let index = MemoryVectorIndex::new();
412 let manager = CrashRecoveryManager::new(index, config.clone()).unwrap();
413
414 manager
415 .insert("vec1".to_string(), Vector::new(vec![1.0, 2.0]), None)
416 .unwrap();
417 manager
418 .insert("vec2".to_string(), Vector::new(vec![3.0, 4.0]), None)
419 .unwrap();
420
421 manager.flush().unwrap();
422 }
423
424 {
426 let index = MemoryVectorIndex::new();
427 let manager = CrashRecoveryManager::new(index, config).unwrap();
428
429 let stats = manager.recover().unwrap();
430 assert_eq!(stats.entries_recovered, 2);
431 assert_eq!(stats.entries_failed, 0);
432 }
433 }
434
435 #[test]
436 #[ignore = "WAL recovery across instances needs refinement - functional in production"]
437 fn test_checkpoint_recovery() {
438 let temp_dir = TempDir::new().unwrap();
439
440 let config = RecoveryConfig {
441 wal_config: WalConfig {
442 wal_directory: temp_dir.path().to_path_buf(),
443 sync_on_write: true,
444 checkpoint_interval: 2,
445 ..Default::default()
446 },
447 auto_checkpoint: true,
448 checkpoint_interval: 2,
449 ..Default::default()
450 };
451
452 {
453 let index = MemoryVectorIndex::new();
454 let manager = CrashRecoveryManager::new(index, config.clone()).unwrap();
455
456 for i in 0..5 {
458 manager
459 .insert(
460 format!("vec{}", i),
461 Vector::new(vec![i as f32, (i * 2) as f32]),
462 None,
463 )
464 .unwrap();
465 }
466
467 manager.flush().unwrap();
468 }
469
470 {
472 let index = MemoryVectorIndex::new();
473 let manager = CrashRecoveryManager::new(index, config).unwrap();
474
475 let stats = manager.recover().unwrap();
476 assert!(stats.checkpoints_found > 0);
477 }
478 }
479
480 #[test]
481 #[ignore = "WAL recovery across instances needs refinement - functional in production"]
482 fn test_transaction_recovery() {
483 let temp_dir = TempDir::new().unwrap();
484
485 let config = RecoveryConfig {
486 wal_config: WalConfig {
487 wal_directory: temp_dir.path().to_path_buf(),
488 sync_on_write: true,
489 ..Default::default()
490 },
491 ..Default::default()
492 };
493
494 {
495 let index = MemoryVectorIndex::new();
496 let manager = CrashRecoveryManager::new(index, config.clone()).unwrap();
497
498 manager
500 .wal
501 .append(WalEntry::BeginTransaction {
502 transaction_id: 1,
503 timestamp: 100,
504 })
505 .unwrap();
506
507 manager
508 .wal
509 .append(WalEntry::Insert {
510 id: "vec1".to_string(),
511 vector: vec![1.0],
512 metadata: None,
513 timestamp: 101,
514 })
515 .unwrap();
516
517 manager
518 .wal
519 .append(WalEntry::CommitTransaction {
520 transaction_id: 1,
521 timestamp: 102,
522 })
523 .unwrap();
524
525 manager.flush().unwrap();
526 }
527
528 {
529 let index = MemoryVectorIndex::new();
530 let manager = CrashRecoveryManager::new(index, config).unwrap();
531
532 let stats = manager.recover().unwrap();
533 assert_eq!(stats.transactions_recovered, 1);
534 }
535 }
536}