1use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::{Duration, Instant};
13
14use bytes::Bytes;
15use tracing::warn;
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
21#[derive(Debug, Clone)]
23pub enum RecoveredValue {
24 String(Bytes),
25 List(VecDeque<Bytes>),
26 SortedSet(Vec<(f64, String)>),
28 Hash(HashMap<String, Bytes>),
30 Set(HashSet<String>),
32}
33
34impl From<SnapValue> for RecoveredValue {
35 fn from(sv: SnapValue) -> Self {
36 match sv {
37 SnapValue::String(data) => RecoveredValue::String(data),
38 SnapValue::List(deque) => RecoveredValue::List(deque),
39 SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
40 SnapValue::Hash(map) => RecoveredValue::Hash(map),
41 SnapValue::Set(set) => RecoveredValue::Set(set),
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
48pub struct RecoveredEntry {
49 pub key: String,
50 pub value: RecoveredValue,
51 pub expires_at: Option<Instant>,
54}
55
56#[derive(Debug)]
58pub struct RecoveryResult {
59 pub entries: Vec<RecoveredEntry>,
61 pub loaded_snapshot: bool,
63 pub replayed_aof: bool,
65}
66
67pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
72 let now = Instant::now();
73 let mut map: HashMap<String, (RecoveredValue, Option<Instant>)> = HashMap::new();
74 let mut loaded_snapshot = false;
75 let mut replayed_aof = false;
76
77 let snap_path = snapshot::snapshot_path(data_dir, shard_id);
79 if snap_path.exists() {
80 match load_snapshot(&snap_path, now) {
81 Ok(entries) => {
82 for (key, value, expires_at) in entries {
83 map.insert(key, (RecoveredValue::from(value), expires_at));
84 }
85 loaded_snapshot = true;
86 }
87 Err(e) => {
88 warn!(shard_id, "failed to load snapshot, starting empty: {e}");
89 }
90 }
91 }
92
93 let aof_path = aof::aof_path(data_dir, shard_id);
95 if aof_path.exists() {
96 match replay_aof(&aof_path, &mut map, now) {
97 Ok(count) => {
98 if count > 0 {
99 replayed_aof = true;
100 }
101 }
102 Err(e) => {
103 warn!(
104 shard_id,
105 "failed to replay aof, using snapshot state only: {e}"
106 );
107 }
108 }
109 }
110
111 let entries = map
113 .into_iter()
114 .filter(|(_, (_, expires_at))| match expires_at {
115 Some(deadline) => *deadline > now,
116 None => true,
117 })
118 .map(|(key, (value, expires_at))| RecoveredEntry {
119 key,
120 value,
121 expires_at,
122 })
123 .collect();
124
125 RecoveryResult {
126 entries,
127 loaded_snapshot,
128 replayed_aof,
129 }
130}
131
132fn load_snapshot(
134 path: &Path,
135 now: Instant,
136) -> Result<Vec<(String, SnapValue, Option<Instant>)>, FormatError> {
137 let mut reader = SnapshotReader::open(path)?;
138 let mut entries = Vec::new();
139
140 while let Some(entry) = reader.read_entry()? {
141 let expires_at = if entry.expire_ms >= 0 {
142 Some(now + Duration::from_millis(entry.expire_ms as u64))
143 } else {
144 None
145 };
146 entries.push((entry.key, entry.value, expires_at));
147 }
148
149 reader.verify_footer()?;
150 Ok(entries)
151}
152
153fn apply_incr(
156 map: &mut HashMap<String, (RecoveredValue, Option<Instant>)>,
157 key: String,
158 delta: i64,
159) {
160 let entry = map
161 .entry(key)
162 .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), None));
163 if let RecoveredValue::String(ref mut data) = entry.0 {
164 let current = std::str::from_utf8(data)
165 .ok()
166 .and_then(|s| s.parse::<i64>().ok());
167 if let Some(n) = current {
168 if let Some(new_val) = n.checked_add(delta) {
169 *data = Bytes::from(new_val.to_string());
170 }
171 }
172 }
173}
174
175fn replay_aof(
178 path: &Path,
179 map: &mut HashMap<String, (RecoveredValue, Option<Instant>)>,
180 now: Instant,
181) -> Result<usize, FormatError> {
182 let mut reader = AofReader::open(path)?;
183 let mut count = 0;
184
185 while let Some(record) = reader.read_record()? {
186 match record {
187 AofRecord::Set {
188 key,
189 value,
190 expire_ms,
191 } => {
192 let expires_at = if expire_ms >= 0 {
193 Some(now + Duration::from_millis(expire_ms as u64))
194 } else {
195 None
196 };
197 map.insert(key, (RecoveredValue::String(value), expires_at));
198 }
199 AofRecord::Del { key } => {
200 map.remove(&key);
201 }
202 AofRecord::Expire { key, seconds } => {
203 if let Some(entry) = map.get_mut(&key) {
204 entry.1 = Some(now + Duration::from_secs(seconds));
205 }
206 }
207 AofRecord::LPush { key, values } => {
208 let entry = map
209 .entry(key)
210 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), None));
211 if let RecoveredValue::List(ref mut deque) = entry.0 {
212 for v in values {
213 deque.push_front(v);
214 }
215 }
216 }
217 AofRecord::RPush { key, values } => {
218 let entry = map
219 .entry(key)
220 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), None));
221 if let RecoveredValue::List(ref mut deque) = entry.0 {
222 for v in values {
223 deque.push_back(v);
224 }
225 }
226 }
227 AofRecord::LPop { key } => {
228 if let Some(entry) = map.get_mut(&key) {
229 if let RecoveredValue::List(ref mut deque) = entry.0 {
230 deque.pop_front();
231 if deque.is_empty() {
232 map.remove(&key);
233 count += 1;
234 continue;
235 }
236 }
237 }
238 }
239 AofRecord::RPop { key } => {
240 if let Some(entry) = map.get_mut(&key) {
241 if let RecoveredValue::List(ref mut deque) = entry.0 {
242 deque.pop_back();
243 if deque.is_empty() {
244 map.remove(&key);
245 count += 1;
246 continue;
247 }
248 }
249 }
250 }
251 AofRecord::ZAdd { key, members } => {
252 let entry = map
253 .entry(key)
254 .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), None));
255 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
256 let mut index: HashMap<String, usize> = existing
258 .iter()
259 .enumerate()
260 .map(|(i, (_, m))| (m.clone(), i))
261 .collect();
262 for (score, member) in members {
263 if let Some(&pos) = index.get(&member) {
264 existing[pos].0 = score;
265 } else {
266 let pos = existing.len();
267 index.insert(member.clone(), pos);
268 existing.push((score, member));
269 }
270 }
271 }
272 }
273 AofRecord::ZRem { key, members } => {
274 if let Some(entry) = map.get_mut(&key) {
275 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
276 let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
277 existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
278 if existing.is_empty() {
279 map.remove(&key);
280 count += 1;
281 continue;
282 }
283 }
284 }
285 }
286 AofRecord::Persist { key } => {
287 if let Some(entry) = map.get_mut(&key) {
288 entry.1 = None;
289 }
290 }
291 AofRecord::Pexpire { key, milliseconds } => {
292 if let Some(entry) = map.get_mut(&key) {
293 entry.1 = Some(now + Duration::from_millis(milliseconds));
294 }
295 }
296 AofRecord::Incr { key } => {
297 apply_incr(map, key, 1);
298 }
299 AofRecord::Decr { key } => {
300 apply_incr(map, key, -1);
301 }
302 AofRecord::HSet { key, fields } => {
303 let entry = map
304 .entry(key)
305 .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), None));
306 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
307 for (field, value) in fields {
308 hash.insert(field, value);
309 }
310 }
311 }
312 AofRecord::HDel { key, fields } => {
313 if let Some(entry) = map.get_mut(&key) {
314 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
315 for field in fields {
316 hash.remove(&field);
317 }
318 if hash.is_empty() {
319 map.remove(&key);
320 count += 1;
321 continue;
322 }
323 }
324 }
325 }
326 AofRecord::HIncrBy { key, field, delta } => {
327 let entry = map
328 .entry(key)
329 .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), None));
330 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
331 let current: i64 = hash
332 .get(&field)
333 .and_then(|v| std::str::from_utf8(v).ok())
334 .and_then(|s| s.parse().ok())
335 .unwrap_or(0);
336 let new_val = current.saturating_add(delta);
337 hash.insert(field, Bytes::from(new_val.to_string()));
338 }
339 }
340 AofRecord::SAdd { key, members } => {
341 let entry = map
342 .entry(key)
343 .or_insert_with(|| (RecoveredValue::Set(HashSet::new()), None));
344 if let RecoveredValue::Set(ref mut set) = entry.0 {
345 for member in members {
346 set.insert(member);
347 }
348 }
349 }
350 AofRecord::SRem { key, members } => {
351 if let Some(entry) = map.get_mut(&key) {
352 if let RecoveredValue::Set(ref mut set) = entry.0 {
353 for member in members {
354 set.remove(&member);
355 }
356 if set.is_empty() {
357 map.remove(&key);
358 count += 1;
359 continue;
360 }
361 }
362 }
363 }
364 }
365 count += 1;
366 }
367
368 Ok(count)
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aof::AofWriter;
    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};

    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Writes `entries` as a finished shard-0 snapshot under `dir`.
    fn write_snapshot(dir: &Path, entries: &[SnapEntry]) {
        let path = snapshot::snapshot_path(dir, 0);
        let mut writer = SnapshotWriter::create(&path, 0).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Writes `records` to the shard-0 AOF under `dir` and syncs it.
    fn write_aof(dir: &Path, records: &[AofRecord]) {
        let path = aof::aof_path(dir, 0);
        let mut writer = AofWriter::open(&path).unwrap();
        for record in records {
            writer.write_record(record).unwrap();
        }
        writer.sync().unwrap();
    }

    /// Indexes the recovered values by key for order-independent assertions.
    fn values_by_key(result: &RecoveryResult) -> HashMap<&str, RecoveredValue> {
        result
            .entries
            .iter()
            .map(|e| (e.key.as_str(), e.value.clone()))
            .collect()
    }

    #[test]
    fn empty_dir_returns_empty_result() {
        let dir = temp_dir();

        let result = recover_shard(dir.path(), 0);

        assert!(result.entries.is_empty());
        assert!(!result.loaded_snapshot);
        assert!(!result.replayed_aof);
    }

    #[test]
    fn snapshot_only_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[
                SnapEntry {
                    key: "a".into(),
                    value: SnapValue::String(Bytes::from("1")),
                    expire_ms: -1,
                },
                SnapEntry {
                    key: "b".into(),
                    value: SnapValue::String(Bytes::from("2")),
                    expire_ms: 60_000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(result.loaded_snapshot);
        assert!(!result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn aof_only_recovery() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "x".into(),
                    value: Bytes::from("10"),
                    expire_ms: -1,
                },
                AofRecord::Set {
                    key: "y".into(),
                    value: Bytes::from("20"),
                    expire_ms: -1,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(!result.loaded_snapshot);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn snapshot_plus_aof_overlay() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[SnapEntry {
                key: "a".into(),
                value: SnapValue::String(Bytes::from("old")),
                expire_ms: -1,
            }],
        );
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "a".into(),
                    value: Bytes::from("new"),
                    expire_ms: -1,
                },
                AofRecord::Set {
                    key: "b".into(),
                    value: Bytes::from("added"),
                    expire_ms: -1,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(result.loaded_snapshot);
        assert!(result.replayed_aof);

        let map = values_by_key(&result);
        assert!(matches!(&map["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
        assert!(matches!(&map["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
    }

    #[test]
    fn del_removes_entry_during_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "gone".into(),
                    value: Bytes::from("temp"),
                    expire_ms: -1,
                },
                AofRecord::Del { key: "gone".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(result.entries.is_empty());
    }

    #[test]
    fn expired_entries_skipped() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[
                // Zero remaining TTL: the deadline equals the recovery instant.
                SnapEntry {
                    key: "dead".into(),
                    value: SnapValue::String(Bytes::from("gone")),
                    expire_ms: 0,
                },
                SnapEntry {
                    key: "alive".into(),
                    value: SnapValue::String(Bytes::from("here")),
                    expire_ms: 60_000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].key, "alive");
    }

    #[test]
    fn corrupt_snapshot_starts_empty() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);
        std::fs::write(&path, b"garbage data").unwrap();

        let result = recover_shard(dir.path(), 0);

        assert!(!result.loaded_snapshot);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn sorted_set_snapshot_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                ]),
                expire_ms: -1,
            }],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(result.loaded_snapshot);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 2);
                assert!(members.contains(&(100.0, "alice".into())));
                assert!(members.contains(&(200.0, "bob".into())));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_aof_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 1);
                assert_eq!(members[0], (200.0, "bob".into()));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_zrem_auto_deletes_empty() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert!(result.entries.is_empty());
    }

    #[test]
    fn expire_record_updates_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: -1,
                },
                AofRecord::Expire {
                    key: "k".into(),
                    seconds: 300,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_some());
    }

    #[test]
    fn persist_record_removes_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: 60_000,
                },
                AofRecord::Persist { key: "k".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_none());
    }

    #[test]
    fn incr_decr_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "n".into(),
                    value: Bytes::from("10"),
                    expire_ms: -1,
                },
                AofRecord::Incr { key: "n".into() },
                AofRecord::Incr { key: "n".into() },
                AofRecord::Decr { key: "n".into() },
                // Incrementing a key that was never set starts from zero.
                AofRecord::Incr {
                    key: "fresh".into(),
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        let map = values_by_key(&result);

        match &map["n"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
            other => panic!("expected String(\"11\"), got {other:?}"),
        }
        match &map["fresh"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
            other => panic!("expected String(\"1\"), got {other:?}"),
        }
    }

    #[test]
    fn pexpire_record_sets_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: -1,
                },
                AofRecord::Pexpire {
                    key: "k".into(),
                    milliseconds: 5000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);

        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_some());
    }
}