1use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tracing::warn;
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
/// A value rebuilt during shard recovery, either converted from a snapshot
/// entry or produced by replaying AOF records.
#[derive(Debug, Clone)]
pub enum RecoveredValue {
    /// A plain string payload.
    String(Bytes),
    /// List elements ordered from front to back.
    List(VecDeque<Bytes>),
    /// `(score, member)` pairs. AOF replay appends new members at the end, so
    /// the vector is not guaranteed to be ordered by score.
    SortedSet(Vec<(f64, String)>),
    /// Hash fields mapped to their values.
    Hash(HashMap<String, Bytes>),
    /// Unordered set members.
    Set(HashSet<String>),
}
33
34impl From<SnapValue> for RecoveredValue {
35 fn from(sv: SnapValue) -> Self {
36 match sv {
37 SnapValue::String(data) => RecoveredValue::String(data),
38 SnapValue::List(deque) => RecoveredValue::List(deque),
39 SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
40 SnapValue::Hash(map) => RecoveredValue::Hash(map),
41 SnapValue::Set(set) => RecoveredValue::Set(set),
42 }
43 }
44}
45
/// One key recovered for a shard, together with its value and expiry.
#[derive(Debug, Clone)]
pub struct RecoveredEntry {
    /// The key name.
    pub key: String,
    /// The recovered value stored under `key`.
    pub value: RecoveredValue,
    /// Time-to-live built from the stored millisecond count; `None` when the
    /// stored count was negative (no expiry).
    // NOTE(review): presumably a remaining TTL rather than an absolute
    // deadline — confirm against the snapshot/AOF writers.
    pub ttl: Option<Duration>,
}
54
/// Outcome of recovering a single shard.
#[derive(Debug)]
pub struct RecoveryResult {
    /// Every surviving entry; produced from a `HashMap`, so the order is
    /// unspecified.
    pub entries: Vec<RecoveredEntry>,
    /// `true` when a snapshot file existed and was loaded without error.
    pub loaded_snapshot: bool,
    /// `true` when the AOF was replayed without error and contained at least
    /// one record.
    pub replayed_aof: bool,
}
65
66pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
71 let mut map: HashMap<String, (RecoveredValue, i64)> = HashMap::new();
73 let mut loaded_snapshot = false;
74 let mut replayed_aof = false;
75
76 let snap_path = snapshot::snapshot_path(data_dir, shard_id);
78 if snap_path.exists() {
79 match load_snapshot(&snap_path) {
80 Ok(entries) => {
81 for (key, value, ttl_ms) in entries {
82 map.insert(key, (RecoveredValue::from(value), ttl_ms));
83 }
84 loaded_snapshot = true;
85 }
86 Err(e) => {
87 warn!(shard_id, "failed to load snapshot, starting empty: {e}");
88 }
89 }
90 }
91
92 let aof_path = aof::aof_path(data_dir, shard_id);
94 if aof_path.exists() {
95 match replay_aof(&aof_path, &mut map) {
96 Ok(count) => {
97 if count > 0 {
98 replayed_aof = true;
99 }
100 }
101 Err(e) => {
102 warn!(
103 shard_id,
104 "failed to replay aof, using snapshot state only: {e}"
105 );
106 }
107 }
108 }
109
110 let entries = map
112 .into_iter()
113 .filter(|(_, (_, ttl_ms))| *ttl_ms != 0) .map(|(key, (value, ttl_ms))| RecoveredEntry {
115 key,
116 value,
117 ttl: if ttl_ms < 0 {
118 None
119 } else {
120 Some(Duration::from_millis(ttl_ms as u64))
121 },
122 })
123 .collect();
124
125 RecoveryResult {
126 entries,
127 loaded_snapshot,
128 replayed_aof,
129 }
130}
131
132fn load_snapshot(path: &Path) -> Result<Vec<(String, SnapValue, i64)>, FormatError> {
135 let mut reader = SnapshotReader::open(path)?;
136 let mut entries = Vec::new();
137
138 while let Some(entry) = reader.read_entry()? {
139 entries.push((entry.key, entry.value, entry.expire_ms));
141 }
142
143 reader.verify_footer()?;
144 Ok(entries)
145}
146
147fn apply_incr(map: &mut HashMap<String, (RecoveredValue, i64)>, key: String, delta: i64) {
150 let entry = map
152 .entry(key)
153 .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), -1));
154 if let RecoveredValue::String(ref mut data) = entry.0 {
155 let current = std::str::from_utf8(data)
156 .ok()
157 .and_then(|s| s.parse::<i64>().ok());
158 if let Some(n) = current {
159 if let Some(new_val) = n.checked_add(delta) {
160 *data = Bytes::from(new_val.to_string());
161 }
162 }
163 }
164}
165
166fn replay_aof(
169 path: &Path,
170 map: &mut HashMap<String, (RecoveredValue, i64)>,
171) -> Result<usize, FormatError> {
172 let mut reader = AofReader::open(path)?;
173 let mut count = 0;
174
175 while let Some(record) = reader.read_record()? {
176 match record {
177 AofRecord::Set {
178 key,
179 value,
180 expire_ms,
181 } => {
182 map.insert(key, (RecoveredValue::String(value), expire_ms));
184 }
185 AofRecord::Del { key } => {
186 map.remove(&key);
187 }
188 AofRecord::Expire { key, seconds } => {
189 if let Some(entry) = map.get_mut(&key) {
190 entry.1 = (seconds * 1000) as i64;
191 }
192 }
193 AofRecord::LPush { key, values } => {
194 let entry = map
195 .entry(key)
196 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
197 if let RecoveredValue::List(ref mut deque) = entry.0 {
198 for v in values {
199 deque.push_front(v);
200 }
201 }
202 }
203 AofRecord::RPush { key, values } => {
204 let entry = map
205 .entry(key)
206 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
207 if let RecoveredValue::List(ref mut deque) = entry.0 {
208 for v in values {
209 deque.push_back(v);
210 }
211 }
212 }
213 AofRecord::LPop { key } => {
214 if let Some(entry) = map.get_mut(&key) {
215 if let RecoveredValue::List(ref mut deque) = entry.0 {
216 deque.pop_front();
217 if deque.is_empty() {
218 map.remove(&key);
219 count += 1;
220 continue;
221 }
222 }
223 }
224 }
225 AofRecord::RPop { key } => {
226 if let Some(entry) = map.get_mut(&key) {
227 if let RecoveredValue::List(ref mut deque) = entry.0 {
228 deque.pop_back();
229 if deque.is_empty() {
230 map.remove(&key);
231 count += 1;
232 continue;
233 }
234 }
235 }
236 }
237 AofRecord::ZAdd { key, members } => {
238 let entry = map
239 .entry(key)
240 .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), -1));
241 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
242 let mut index: HashMap<String, usize> = existing
244 .iter()
245 .enumerate()
246 .map(|(i, (_, m))| (m.clone(), i))
247 .collect();
248 for (score, member) in members {
249 if let Some(&pos) = index.get(&member) {
250 existing[pos].0 = score;
251 } else {
252 let pos = existing.len();
253 index.insert(member.clone(), pos);
254 existing.push((score, member));
255 }
256 }
257 }
258 }
259 AofRecord::ZRem { key, members } => {
260 if let Some(entry) = map.get_mut(&key) {
261 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
262 let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
263 existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
264 if existing.is_empty() {
265 map.remove(&key);
266 count += 1;
267 continue;
268 }
269 }
270 }
271 }
272 AofRecord::Persist { key } => {
273 if let Some(entry) = map.get_mut(&key) {
274 entry.1 = -1; }
276 }
277 AofRecord::Pexpire { key, milliseconds } => {
278 if let Some(entry) = map.get_mut(&key) {
279 entry.1 = milliseconds as i64;
280 }
281 }
282 AofRecord::Incr { key } => {
283 apply_incr(map, key, 1);
284 }
285 AofRecord::Decr { key } => {
286 apply_incr(map, key, -1);
287 }
288 AofRecord::HSet { key, fields } => {
289 let entry = map
290 .entry(key)
291 .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
292 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
293 for (field, value) in fields {
294 hash.insert(field, value);
295 }
296 }
297 }
298 AofRecord::HDel { key, fields } => {
299 if let Some(entry) = map.get_mut(&key) {
300 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
301 for field in fields {
302 hash.remove(&field);
303 }
304 if hash.is_empty() {
305 map.remove(&key);
306 count += 1;
307 continue;
308 }
309 }
310 }
311 }
312 AofRecord::HIncrBy { key, field, delta } => {
313 let entry = map
314 .entry(key)
315 .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
316 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
317 let current: i64 = hash
318 .get(&field)
319 .and_then(|v| std::str::from_utf8(v).ok())
320 .and_then(|s| s.parse().ok())
321 .unwrap_or(0);
322 let new_val = current.saturating_add(delta);
323 hash.insert(field, Bytes::from(new_val.to_string()));
324 }
325 }
326 AofRecord::SAdd { key, members } => {
327 let entry = map
328 .entry(key)
329 .or_insert_with(|| (RecoveredValue::Set(HashSet::new()), -1));
330 if let RecoveredValue::Set(ref mut set) = entry.0 {
331 for member in members {
332 set.insert(member);
333 }
334 }
335 }
336 AofRecord::SRem { key, members } => {
337 if let Some(entry) = map.get_mut(&key) {
338 if let RecoveredValue::Set(ref mut set) = entry.0 {
339 for member in members {
340 set.remove(&member);
341 }
342 if set.is_empty() {
343 map.remove(&key);
344 count += 1;
345 continue;
346 }
347 }
348 }
349 }
350 }
351 count += 1;
352 }
353
354 Ok(count)
355}
356
/// Recovery tests. Each test writes a snapshot and/or AOF for shard 0 into a
/// fresh temporary directory and checks what `recover_shard` rebuilds.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aof::AofWriter;
    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};

    /// Creates a fresh temporary directory for one test.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Writes `entries` as the shard-0 snapshot under `dir` and finishes the file.
    fn write_snapshot(dir: &Path, entries: &[SnapEntry]) {
        let path = snapshot::snapshot_path(dir, 0);
        let mut writer = SnapshotWriter::create(&path, 0).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Writes `records` to the shard-0 AOF under `dir` and syncs it.
    fn write_aof(dir: &Path, records: &[AofRecord]) {
        let path = aof::aof_path(dir, 0);
        let mut writer = AofWriter::open(&path).unwrap();
        for record in records {
            writer.write_record(record).unwrap();
        }
        writer.sync().unwrap();
    }

    /// Builds a string-valued snapshot entry.
    fn string_entry(key: &str, value: &'static str, expire_ms: i64) -> SnapEntry {
        SnapEntry {
            key: key.into(),
            value: SnapValue::String(Bytes::from(value)),
            expire_ms,
        }
    }

    /// Builds an AOF `Set` record.
    fn set_record(key: &str, value: &'static str, expire_ms: i64) -> AofRecord {
        AofRecord::Set {
            key: key.into(),
            value: Bytes::from(value),
            expire_ms,
        }
    }

    /// Indexes recovered values by key, since entry order is unspecified.
    fn values_by_key(result: &RecoveryResult) -> HashMap<&str, RecoveredValue> {
        result
            .entries
            .iter()
            .map(|e| (e.key.as_str(), e.value.clone()))
            .collect()
    }

    #[test]
    fn empty_dir_returns_empty_result() {
        let dir = temp_dir();
        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
        assert!(!result.loaded_snapshot);
        assert!(!result.replayed_aof);
    }

    #[test]
    fn snapshot_only_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[string_entry("a", "1", -1), string_entry("b", "2", 60_000)],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(!result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn aof_only_recovery() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[set_record("x", "10", -1), set_record("y", "20", -1)],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn snapshot_plus_aof_overlay() {
        let dir = temp_dir();
        write_snapshot(dir.path(), &[string_entry("a", "old", -1)]);
        write_aof(
            dir.path(),
            &[set_record("a", "new", -1), set_record("b", "added", -1)],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(result.replayed_aof);

        let map = values_by_key(&result);
        assert!(matches!(&map["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
        assert!(matches!(&map["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
    }

    #[test]
    fn del_removes_entry_during_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("gone", "temp", -1),
                AofRecord::Del { key: "gone".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expired_entries_skipped() {
        let dir = temp_dir();
        // A TTL of exactly 0 marks the entry as already expired.
        write_snapshot(
            dir.path(),
            &[
                string_entry("dead", "gone", 0),
                string_entry("alive", "here", 60_000),
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].key, "alive");
    }

    #[test]
    fn corrupt_snapshot_starts_empty() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);

        std::fs::write(&path, b"garbage data").unwrap();

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn sorted_set_snapshot_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                ]),
                expire_ms: -1,
            }],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 2);
                assert!(members.contains(&(100.0, "alice".into())));
                assert!(members.contains(&(200.0, "bob".into())));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_aof_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 1);
                assert_eq!(members[0], (200.0, "bob".into()));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_zrem_auto_deletes_empty() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expire_record_updates_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v", -1),
                AofRecord::Expire {
                    key: "k".into(),
                    seconds: 300,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        // Pin the exact value so a wrong seconds -> milliseconds scaling fails.
        assert_eq!(result.entries[0].ttl, Some(Duration::from_secs(300)));
    }

    #[test]
    fn persist_record_removes_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v", 60_000),
                AofRecord::Persist { key: "k".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].ttl, None);
    }

    #[test]
    fn incr_decr_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("n", "10", -1),
                AofRecord::Incr { key: "n".into() },
                AofRecord::Incr { key: "n".into() },
                AofRecord::Decr { key: "n".into() },
                // Incrementing a missing key starts from 0.
                AofRecord::Incr {
                    key: "fresh".into(),
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        let map = values_by_key(&result);

        match &map["n"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
            other => panic!("expected String(\"11\"), got {other:?}"),
        }
        match &map["fresh"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
            other => panic!("expected String(\"1\"), got {other:?}"),
        }
    }

    #[test]
    fn pexpire_record_sets_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v", -1),
                AofRecord::Pexpire {
                    key: "k".into(),
                    milliseconds: 5000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].ttl, Some(Duration::from_millis(5000)));
    }
}