Skip to main content

ember_persistence/
recovery.rs

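//! Recovery: loading snapshots and replaying AOF on shard startup.
//!
//! The recovery sequence is:
//! 1. Load the snapshot if it exists (bulk restore of entries).
//! 2. Replay the AOF if it exists (apply mutations on top of the snapshot state).
//! 3. Drop entries with no remaining TTL. Persisted TTLs are remaining
//!    durations, so deadlines are measured from recovery time and downtime is
//!    not deducted.
//! 4. If no files exist, start with an empty state.
//! 5. If a file is corrupt, log a warning and continue. An unreadable snapshot
//!    is skipped. AOF replay stops at the first bad record but keeps the
//!    records applied before it.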
9
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::{Duration, Instant};
13
14use bytes::Bytes;
15use tracing::warn;
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
21/// The value of a recovered entry.
22#[derive(Debug, Clone)]
23pub enum RecoveredValue {
24    String(Bytes),
25    List(VecDeque<Bytes>),
26    /// Sorted set stored as (score, member) pairs.
27    SortedSet(Vec<(f64, String)>),
28}
29
30impl From<SnapValue> for RecoveredValue {
31    fn from(sv: SnapValue) -> Self {
32        match sv {
33            SnapValue::String(data) => RecoveredValue::String(data),
34            SnapValue::List(deque) => RecoveredValue::List(deque),
35            SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
36        }
37    }
38}
39
40/// A single recovered entry ready to be inserted into a keyspace.
41#[derive(Debug, Clone)]
42pub struct RecoveredEntry {
43    pub key: String,
44    pub value: RecoveredValue,
45    /// Absolute deadline computed from the persisted remaining TTL.
46    /// `None` means no expiration.
47    pub expires_at: Option<Instant>,
48}
49
50/// The result of recovering a shard's persisted state.
51#[derive(Debug)]
52pub struct RecoveryResult {
53    /// Recovered entries, keyed by name for easy insertion.
54    pub entries: Vec<RecoveredEntry>,
55    /// Whether a snapshot was loaded.
56    pub loaded_snapshot: bool,
57    /// Whether an AOF was replayed.
58    pub replayed_aof: bool,
59}
60
61/// Recovers a shard's state from snapshot and/or AOF files.
62///
63/// Returns a list of live entries to restore into the keyspace.
64/// Entries whose TTL expired during downtime are silently skipped.
65pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
66    let now = Instant::now();
67    let mut map: HashMap<String, (RecoveredValue, Option<Instant>)> = HashMap::new();
68    let mut loaded_snapshot = false;
69    let mut replayed_aof = false;
70
71    // step 1: load snapshot
72    let snap_path = snapshot::snapshot_path(data_dir, shard_id);
73    if snap_path.exists() {
74        match load_snapshot(&snap_path, now) {
75            Ok(entries) => {
76                for (key, value, expires_at) in entries {
77                    map.insert(key, (RecoveredValue::from(value), expires_at));
78                }
79                loaded_snapshot = true;
80            }
81            Err(e) => {
82                warn!(shard_id, "failed to load snapshot, starting empty: {e}");
83            }
84        }
85    }
86
87    // step 2: replay AOF
88    let aof_path = aof::aof_path(data_dir, shard_id);
89    if aof_path.exists() {
90        match replay_aof(&aof_path, &mut map, now) {
91            Ok(count) => {
92                if count > 0 {
93                    replayed_aof = true;
94                }
95            }
96            Err(e) => {
97                warn!(
98                    shard_id,
99                    "failed to replay aof, using snapshot state only: {e}"
100                );
101            }
102        }
103    }
104
105    // step 3: filter out expired entries and build result
106    let entries = map
107        .into_iter()
108        .filter(|(_, (_, expires_at))| match expires_at {
109            Some(deadline) => *deadline > now,
110            None => true,
111        })
112        .map(|(key, (value, expires_at))| RecoveredEntry {
113            key,
114            value,
115            expires_at,
116        })
117        .collect();
118
119    RecoveryResult {
120        entries,
121        loaded_snapshot,
122        replayed_aof,
123    }
124}
125
126/// Loads entries from a snapshot file.
127fn load_snapshot(
128    path: &Path,
129    now: Instant,
130) -> Result<Vec<(String, SnapValue, Option<Instant>)>, FormatError> {
131    let mut reader = SnapshotReader::open(path)?;
132    let mut entries = Vec::new();
133
134    while let Some(entry) = reader.read_entry()? {
135        let expires_at = if entry.expire_ms >= 0 {
136            Some(now + Duration::from_millis(entry.expire_ms as u64))
137        } else {
138            None
139        };
140        entries.push((entry.key, entry.value, expires_at));
141    }
142
143    reader.verify_footer()?;
144    Ok(entries)
145}
146
147/// Applies an increment/decrement to a recovered entry. If the key doesn't
148/// exist, initializes it to "0" first. Non-integer values are silently skipped.
149fn apply_incr(
150    map: &mut HashMap<String, (RecoveredValue, Option<Instant>)>,
151    key: String,
152    delta: i64,
153) {
154    let entry = map
155        .entry(key)
156        .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), None));
157    if let RecoveredValue::String(ref mut data) = entry.0 {
158        let current = std::str::from_utf8(data)
159            .ok()
160            .and_then(|s| s.parse::<i64>().ok());
161        if let Some(n) = current {
162            if let Some(new_val) = n.checked_add(delta) {
163                *data = Bytes::from(new_val.to_string());
164            }
165        }
166    }
167}
168
169/// Replays AOF records into the in-memory map. Returns the number of
170/// records replayed.
171fn replay_aof(
172    path: &Path,
173    map: &mut HashMap<String, (RecoveredValue, Option<Instant>)>,
174    now: Instant,
175) -> Result<usize, FormatError> {
176    let mut reader = AofReader::open(path)?;
177    let mut count = 0;
178
179    while let Some(record) = reader.read_record()? {
180        match record {
181            AofRecord::Set {
182                key,
183                value,
184                expire_ms,
185            } => {
186                let expires_at = if expire_ms >= 0 {
187                    Some(now + Duration::from_millis(expire_ms as u64))
188                } else {
189                    None
190                };
191                map.insert(key, (RecoveredValue::String(value), expires_at));
192            }
193            AofRecord::Del { key } => {
194                map.remove(&key);
195            }
196            AofRecord::Expire { key, seconds } => {
197                if let Some(entry) = map.get_mut(&key) {
198                    entry.1 = Some(now + Duration::from_secs(seconds));
199                }
200            }
201            AofRecord::LPush { key, values } => {
202                let entry = map
203                    .entry(key)
204                    .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), None));
205                if let RecoveredValue::List(ref mut deque) = entry.0 {
206                    for v in values {
207                        deque.push_front(v);
208                    }
209                }
210            }
211            AofRecord::RPush { key, values } => {
212                let entry = map
213                    .entry(key)
214                    .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), None));
215                if let RecoveredValue::List(ref mut deque) = entry.0 {
216                    for v in values {
217                        deque.push_back(v);
218                    }
219                }
220            }
221            AofRecord::LPop { key } => {
222                if let Some(entry) = map.get_mut(&key) {
223                    if let RecoveredValue::List(ref mut deque) = entry.0 {
224                        deque.pop_front();
225                        if deque.is_empty() {
226                            map.remove(&key);
227                            count += 1;
228                            continue;
229                        }
230                    }
231                }
232            }
233            AofRecord::RPop { key } => {
234                if let Some(entry) = map.get_mut(&key) {
235                    if let RecoveredValue::List(ref mut deque) = entry.0 {
236                        deque.pop_back();
237                        if deque.is_empty() {
238                            map.remove(&key);
239                            count += 1;
240                            continue;
241                        }
242                    }
243                }
244            }
245            AofRecord::ZAdd { key, members } => {
246                let entry = map
247                    .entry(key)
248                    .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), None));
249                if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
250                    // build a position index for O(1) member lookups
251                    let mut index: HashMap<String, usize> = existing
252                        .iter()
253                        .enumerate()
254                        .map(|(i, (_, m))| (m.clone(), i))
255                        .collect();
256                    for (score, member) in members {
257                        if let Some(&pos) = index.get(&member) {
258                            existing[pos].0 = score;
259                        } else {
260                            let pos = existing.len();
261                            index.insert(member.clone(), pos);
262                            existing.push((score, member));
263                        }
264                    }
265                }
266            }
267            AofRecord::ZRem { key, members } => {
268                if let Some(entry) = map.get_mut(&key) {
269                    if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
270                        let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
271                        existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
272                        if existing.is_empty() {
273                            map.remove(&key);
274                            count += 1;
275                            continue;
276                        }
277                    }
278                }
279            }
280            AofRecord::Persist { key } => {
281                if let Some(entry) = map.get_mut(&key) {
282                    entry.1 = None;
283                }
284            }
285            AofRecord::Pexpire { key, milliseconds } => {
286                if let Some(entry) = map.get_mut(&key) {
287                    entry.1 = Some(now + Duration::from_millis(milliseconds));
288                }
289            }
290            AofRecord::Incr { key } => {
291                apply_incr(map, key, 1);
292            }
293            AofRecord::Decr { key } => {
294                apply_incr(map, key, -1);
295            }
296        }
297        count += 1;
298    }
299
300    Ok(count)
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aof::AofWriter;
    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};

    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Writes `entries` as shard 0's snapshot under `dir`, then finishes it.
    fn write_snapshot(dir: &Path, entries: &[SnapEntry]) {
        let path = snapshot::snapshot_path(dir, 0);
        let mut writer = SnapshotWriter::create(&path, 0).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Appends `records` to shard 0's AOF under `dir`, then syncs it.
    fn write_aof(dir: &Path, records: &[AofRecord]) {
        let path = aof::aof_path(dir, 0);
        let mut writer = AofWriter::open(&path).unwrap();
        for record in records {
            writer.write_record(record).unwrap();
        }
        writer.sync().unwrap();
    }

    /// Builds an AOF `Set` record with no expiry.
    fn set_record(key: &str, value: &'static str) -> AofRecord {
        AofRecord::Set {
            key: key.into(),
            value: Bytes::from(value),
            expire_ms: -1,
        }
    }

    /// Indexes recovered entries by key for direct lookups in assertions.
    fn values_by_key(result: &RecoveryResult) -> HashMap<&str, RecoveredValue> {
        result
            .entries
            .iter()
            .map(|e| (e.key.as_str(), e.value.clone()))
            .collect()
    }

    #[test]
    fn empty_dir_returns_empty_result() {
        let dir = temp_dir();
        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
        assert!(!result.loaded_snapshot);
        assert!(!result.replayed_aof);
    }

    #[test]
    fn snapshot_only_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[
                SnapEntry {
                    key: "a".into(),
                    value: SnapValue::String(Bytes::from("1")),
                    expire_ms: -1,
                },
                SnapEntry {
                    key: "b".into(),
                    value: SnapValue::String(Bytes::from("2")),
                    expire_ms: 60_000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(!result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn aof_only_recovery() {
        let dir = temp_dir();
        write_aof(dir.path(), &[set_record("x", "10"), set_record("y", "20")]);

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn snapshot_plus_aof_overlay() {
        let dir = temp_dir();
        // the snapshot holds "a" = "old" ...
        write_snapshot(
            dir.path(),
            &[SnapEntry {
                key: "a".into(),
                value: SnapValue::String(Bytes::from("old")),
                expire_ms: -1,
            }],
        );
        // ... then the AOF overwrites "a" and adds "b"
        write_aof(dir.path(), &[set_record("a", "new"), set_record("b", "added")]);

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(result.replayed_aof);

        let values = values_by_key(&result);
        assert!(matches!(&values["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
        assert!(matches!(&values["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
    }

    #[test]
    fn del_removes_entry_during_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("gone", "temp"),
                AofRecord::Del { key: "gone".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expired_entries_skipped() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[
                // zero remaining time: already expired
                SnapEntry {
                    key: "dead".into(),
                    value: SnapValue::String(Bytes::from("gone")),
                    expire_ms: 0,
                },
                // a full minute left
                SnapEntry {
                    key: "alive".into(),
                    value: SnapValue::String(Bytes::from("here")),
                    expire_ms: 60_000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].key, "alive");
    }

    #[test]
    fn corrupt_snapshot_starts_empty() {
        let dir = temp_dir();
        std::fs::write(snapshot::snapshot_path(dir.path(), 0), b"garbage data").unwrap();

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn sorted_set_snapshot_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                ]),
                expire_ms: -1,
            }],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert_eq!(result.entries.len(), 1);
        let RecoveredValue::SortedSet(members) = &result.entries[0].value else {
            panic!("expected SortedSet, got {:?}", result.entries[0].value);
        };
        assert_eq!(members.len(), 2);
        assert!(members.contains(&(100.0, "alice".into())));
        assert!(members.contains(&(200.0, "bob".into())));
    }

    #[test]
    fn sorted_set_aof_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 1);
        let RecoveredValue::SortedSet(members) = &result.entries[0].value else {
            panic!("expected SortedSet, got {:?}", result.entries[0].value);
        };
        assert_eq!(members.len(), 1);
        assert_eq!(members[0], (200.0, "bob".into()));
    }

    #[test]
    fn sorted_set_zrem_auto_deletes_empty() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expire_record_updates_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v"),
                AofRecord::Expire {
                    key: "k".into(),
                    seconds: 300,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_some());
    }

    #[test]
    fn persist_record_removes_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::Set {
                    key: "k".into(),
                    value: Bytes::from("v"),
                    expire_ms: 60_000,
                },
                AofRecord::Persist { key: "k".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_none());
    }

    #[test]
    fn incr_decr_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("n", "10"),
                AofRecord::Incr { key: "n".into() },
                AofRecord::Incr { key: "n".into() },
                AofRecord::Decr { key: "n".into() },
                // INCR on a key that was never set
                AofRecord::Incr {
                    key: "fresh".into(),
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        let values = values_by_key(&result);

        // 10 + 1 + 1 - 1 = 11
        match &values["n"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
            other => panic!("expected String(\"11\"), got {other:?}"),
        }
        // 0 + 1 = 1
        match &values["fresh"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
            other => panic!("expected String(\"1\"), got {other:?}"),
        }
    }

    #[test]
    fn pexpire_record_sets_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v"),
                AofRecord::Pexpire {
                    key: "k".into(),
                    milliseconds: 5000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].expires_at.is_some());
    }
}