1use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::{Duration, Instant};
13
14use bytes::Bytes;
15use tracing::warn;
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
21#[derive(Debug, Clone)]
23pub enum RecoveredValue {
24 String(Bytes),
25 List(VecDeque<Bytes>),
26 SortedSet(Vec<(f64, String)>),
28}
29
30impl From<SnapValue> for RecoveredValue {
31 fn from(sv: SnapValue) -> Self {
32 match sv {
33 SnapValue::String(data) => RecoveredValue::String(data),
34 SnapValue::List(deque) => RecoveredValue::List(deque),
35 SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
36 }
37 }
38}
39
/// A single key recovered for a shard by [`recover_shard`].
#[derive(Debug, Clone)]
pub struct RecoveredEntry {
    /// The key name.
    pub key: String,
    /// The value associated with `key` after snapshot load and AOF replay.
    pub value: RecoveredValue,
    /// Absolute expiry deadline, or `None` if the key has no TTL.
    /// `recover_shard` computes deadlines relative to the instant recovery
    /// started.
    pub expires_at: Option<Instant>,
}
49
/// Outcome of recovering one shard with [`recover_shard`].
#[derive(Debug)]
pub struct RecoveryResult {
    /// Entries whose deadline had not passed at recovery time. They come in
    /// no particular order because they are drained from a `HashMap`.
    pub entries: Vec<RecoveredEntry>,
    /// `true` if a snapshot file existed and loaded without error.
    pub loaded_snapshot: bool,
    /// `true` if the AOF replayed without error and contained at least one
    /// record.
    pub replayed_aof: bool,
}
60
61pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
66 let now = Instant::now();
67 let mut map: HashMap<String, (RecoveredValue, Option<Instant>)> = HashMap::new();
68 let mut loaded_snapshot = false;
69 let mut replayed_aof = false;
70
71 let snap_path = snapshot::snapshot_path(data_dir, shard_id);
73 if snap_path.exists() {
74 match load_snapshot(&snap_path, now) {
75 Ok(entries) => {
76 for (key, value, expires_at) in entries {
77 map.insert(key, (RecoveredValue::from(value), expires_at));
78 }
79 loaded_snapshot = true;
80 }
81 Err(e) => {
82 warn!(shard_id, "failed to load snapshot, starting empty: {e}");
83 }
84 }
85 }
86
87 let aof_path = aof::aof_path(data_dir, shard_id);
89 if aof_path.exists() {
90 match replay_aof(&aof_path, &mut map, now) {
91 Ok(count) => {
92 if count > 0 {
93 replayed_aof = true;
94 }
95 }
96 Err(e) => {
97 warn!(
98 shard_id,
99 "failed to replay aof, using snapshot state only: {e}"
100 );
101 }
102 }
103 }
104
105 let entries = map
107 .into_iter()
108 .filter(|(_, (_, expires_at))| match expires_at {
109 Some(deadline) => *deadline > now,
110 None => true,
111 })
112 .map(|(key, (value, expires_at))| RecoveredEntry {
113 key,
114 value,
115 expires_at,
116 })
117 .collect();
118
119 RecoveryResult {
120 entries,
121 loaded_snapshot,
122 replayed_aof,
123 }
124}
125
126fn load_snapshot(
128 path: &Path,
129 now: Instant,
130) -> Result<Vec<(String, SnapValue, Option<Instant>)>, FormatError> {
131 let mut reader = SnapshotReader::open(path)?;
132 let mut entries = Vec::new();
133
134 while let Some(entry) = reader.read_entry()? {
135 let expires_at = if entry.expire_ms >= 0 {
136 Some(now + Duration::from_millis(entry.expire_ms as u64))
137 } else {
138 None
139 };
140 entries.push((entry.key, entry.value, expires_at));
141 }
142
143 reader.verify_footer()?;
144 Ok(entries)
145}
146
147fn apply_incr(
150 map: &mut HashMap<String, (RecoveredValue, Option<Instant>)>,
151 key: String,
152 delta: i64,
153) {
154 let entry = map
155 .entry(key)
156 .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), None));
157 if let RecoveredValue::String(ref mut data) = entry.0 {
158 let current = std::str::from_utf8(data)
159 .ok()
160 .and_then(|s| s.parse::<i64>().ok());
161 if let Some(n) = current {
162 if let Some(new_val) = n.checked_add(delta) {
163 *data = Bytes::from(new_val.to_string());
164 }
165 }
166 }
167}
168
169fn replay_aof(
172 path: &Path,
173 map: &mut HashMap<String, (RecoveredValue, Option<Instant>)>,
174 now: Instant,
175) -> Result<usize, FormatError> {
176 let mut reader = AofReader::open(path)?;
177 let mut count = 0;
178
179 while let Some(record) = reader.read_record()? {
180 match record {
181 AofRecord::Set {
182 key,
183 value,
184 expire_ms,
185 } => {
186 let expires_at = if expire_ms >= 0 {
187 Some(now + Duration::from_millis(expire_ms as u64))
188 } else {
189 None
190 };
191 map.insert(key, (RecoveredValue::String(value), expires_at));
192 }
193 AofRecord::Del { key } => {
194 map.remove(&key);
195 }
196 AofRecord::Expire { key, seconds } => {
197 if let Some(entry) = map.get_mut(&key) {
198 entry.1 = Some(now + Duration::from_secs(seconds));
199 }
200 }
201 AofRecord::LPush { key, values } => {
202 let entry = map
203 .entry(key)
204 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), None));
205 if let RecoveredValue::List(ref mut deque) = entry.0 {
206 for v in values {
207 deque.push_front(v);
208 }
209 }
210 }
211 AofRecord::RPush { key, values } => {
212 let entry = map
213 .entry(key)
214 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), None));
215 if let RecoveredValue::List(ref mut deque) = entry.0 {
216 for v in values {
217 deque.push_back(v);
218 }
219 }
220 }
221 AofRecord::LPop { key } => {
222 if let Some(entry) = map.get_mut(&key) {
223 if let RecoveredValue::List(ref mut deque) = entry.0 {
224 deque.pop_front();
225 if deque.is_empty() {
226 map.remove(&key);
227 count += 1;
228 continue;
229 }
230 }
231 }
232 }
233 AofRecord::RPop { key } => {
234 if let Some(entry) = map.get_mut(&key) {
235 if let RecoveredValue::List(ref mut deque) = entry.0 {
236 deque.pop_back();
237 if deque.is_empty() {
238 map.remove(&key);
239 count += 1;
240 continue;
241 }
242 }
243 }
244 }
245 AofRecord::ZAdd { key, members } => {
246 let entry = map
247 .entry(key)
248 .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), None));
249 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
250 let mut index: HashMap<String, usize> = existing
252 .iter()
253 .enumerate()
254 .map(|(i, (_, m))| (m.clone(), i))
255 .collect();
256 for (score, member) in members {
257 if let Some(&pos) = index.get(&member) {
258 existing[pos].0 = score;
259 } else {
260 let pos = existing.len();
261 index.insert(member.clone(), pos);
262 existing.push((score, member));
263 }
264 }
265 }
266 }
267 AofRecord::ZRem { key, members } => {
268 if let Some(entry) = map.get_mut(&key) {
269 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
270 let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
271 existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
272 if existing.is_empty() {
273 map.remove(&key);
274 count += 1;
275 continue;
276 }
277 }
278 }
279 }
280 AofRecord::Persist { key } => {
281 if let Some(entry) = map.get_mut(&key) {
282 entry.1 = None;
283 }
284 }
285 AofRecord::Pexpire { key, milliseconds } => {
286 if let Some(entry) = map.get_mut(&key) {
287 entry.1 = Some(now + Duration::from_millis(milliseconds));
288 }
289 }
290 AofRecord::Incr { key } => {
291 apply_incr(map, key, 1);
292 }
293 AofRecord::Decr { key } => {
294 apply_incr(map, key, -1);
295 }
296 }
297 count += 1;
298 }
299
300 Ok(count)
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aof::AofWriter;
    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};

    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    #[test]
    fn empty_dir_returns_empty_result() {
        let dir = temp_dir();
        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
        assert!(!result.loaded_snapshot);
        assert!(!result.replayed_aof);
    }

    #[test]
    fn snapshot_only_recovery() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "a".into(),
                    value: SnapValue::String(Bytes::from("1")),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "b".into(),
                    value: SnapValue::String(Bytes::from("2")),
                    expire_ms: 60_000,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(!result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn aof_only_recovery() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "x".into(),
                    value: Bytes::from("10"),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "y".into(),
                    value: Bytes::from("20"),
                    expire_ms: -1,
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn snapshot_plus_aof_overlay() {
        let dir = temp_dir();

        {
            // Base state: "a" = "old".
            let path = snapshot::snapshot_path(dir.path(), 0);
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "a".into(),
                    value: SnapValue::String(Bytes::from("old")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        {
            // AOF overwrites "a" and adds "b".
            let path = aof::aof_path(dir.path(), 0);
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "a".into(),
                    value: Bytes::from("new"),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "b".into(),
                    value: Bytes::from("added"),
                    expire_ms: -1,
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(result.replayed_aof);

        let map: HashMap<_, _> = result
            .entries
            .iter()
            .map(|e| (e.key.as_str(), e.value.clone()))
            .collect();
        assert!(matches!(&map["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
        assert!(matches!(&map["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
    }

    #[test]
    fn del_removes_entry_during_replay() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "gone".into(),
                    value: Bytes::from("temp"),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Del { key: "gone".into() })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expired_entries_skipped() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                // Zero remaining TTL: deadline == recovery time, so it is dropped.
                .write_entry(&SnapEntry {
                    key: "dead".into(),
                    value: SnapValue::String(Bytes::from("gone")),
                    expire_ms: 0,
                })
                .unwrap();
            writer
                // A minute of TTL left: survives recovery.
                .write_entry(&SnapEntry {
                    key: "alive".into(),
                    value: SnapValue::String(Bytes::from("here")),
                    expire_ms: 60_000,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].key, "alive");
    }

    #[test]
    fn corrupt_snapshot_starts_empty() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);

        std::fs::write(&path, b"garbage data").unwrap();

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn sorted_set_snapshot_recovery() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "board".into(),
                    value: SnapValue::SortedSet(vec![
                        (100.0, "alice".into()),
                        (200.0, "bob".into()),
                    ]),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 2);
                assert!(members.contains(&(100.0, "alice".into())));
                assert!(members.contains(&(200.0, "bob".into())));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_aof_replay() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
                })
                .unwrap();
            writer
                .write_record(&AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 1);
                assert_eq!(members[0], (200.0, "bob".into()));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_zrem_auto_deletes_empty() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into())],
                })
                .unwrap();
            writer
                .write_record(&AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn list_push_pop_replay() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            // list = a, b, c
            writer
                .write_record(&AofRecord::RPush {
                    key: "list".into(),
                    values: vec![Bytes::from("a"), Bytes::from("b"), Bytes::from("c")],
                })
                .unwrap();
            // Each value is pushed onto the head in turn: list = y, x, a, b, c
            writer
                .write_record(&AofRecord::LPush {
                    key: "list".into(),
                    values: vec![Bytes::from("x"), Bytes::from("y")],
                })
                .unwrap();
            // list = x, a, b, c
            writer
                .write_record(&AofRecord::LPop { key: "list".into() })
                .unwrap();
            // list = x, a, b
            writer
                .write_record(&AofRecord::RPop { key: "list".into() })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::List(deque) => {
                let items: Vec<Bytes> = deque.iter().cloned().collect();
                assert_eq!(
                    items,
                    vec![Bytes::from("x"), Bytes::from("a"), Bytes::from("b")]
                );
            }
            other => panic!("expected List, got {other:?}"),
        }
    }

    #[test]
    fn list_pop_to_empty_auto_deletes() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::RPush {
                    key: "list".into(),
                    values: vec![Bytes::from("only")],
                })
                .unwrap();
            writer
                .write_record(&AofRecord::LPop { key: "list".into() })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expire_record_updates_ttl() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Expire {
                    key: "k".into(),
                    seconds: 300,
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_some());
    }

    #[test]
    fn persist_record_removes_ttl() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: 60_000,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Persist { key: "k".into() })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_none());
    }

    #[test]
    fn incr_decr_replay() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "n".into(),
                    value: Bytes::from("10"),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Incr { key: "n".into() })
                .unwrap();
            writer
                .write_record(&AofRecord::Incr { key: "n".into() })
                .unwrap();
            writer
                .write_record(&AofRecord::Decr { key: "n".into() })
                .unwrap();
            writer
                // INCR on a missing key starts from 0.
                .write_record(&AofRecord::Incr {
                    key: "fresh".into(),
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        let map: HashMap<_, _> = result
            .entries
            .iter()
            .map(|e| (e.key.as_str(), e.value.clone()))
            .collect();

        match &map["n"] {
            // 10 + 1 + 1 - 1
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
            other => panic!("expected String(\"11\"), got {other:?}"),
        }
        match &map["fresh"] {
            // 0 + 1
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
            other => panic!("expected String(\"1\"), got {other:?}"),
        }
    }

    #[test]
    fn pexpire_record_sets_ttl() {
        let dir = temp_dir();
        let path = aof::aof_path(dir.path(), 0);

        {
            let mut writer = AofWriter::open(&path).unwrap();
            writer
                .write_record(&AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: -1,
                })
                .unwrap();
            writer
                .write_record(&AofRecord::Pexpire {
                    key: "k".into(),
                    milliseconds: 5000,
                })
                .unwrap();
            writer.sync().unwrap();
        }

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_some());
    }
}