Skip to main content

ember_persistence/
recovery.rs

1//! Recovery: loading snapshots and replaying AOF on shard startup.
2//!
3//! The recovery sequence is:
4//! 1. Load snapshot if it exists (bulk restore of entries).
5//! 2. Replay AOF if it exists (apply mutations on top of snapshot state).
6//! 3. Skip entries whose TTL expired during downtime.
7//! 4. If no files exist, start with an empty state.
8//! 5. If a file is corrupt, log a warning and continue with whatever state was recovered up to that point (possibly empty).
9
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tracing::{error, warn};
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
21/// Borrowed encryption key as accepted by the recovery helpers, which take
22/// it wrapped as `Option<EncryptionKeyRef<'_>>`. With the `encryption`
23/// feature enabled this is a reference to the real key type.
24#[cfg(feature = "encryption")]
25type EncryptionKeyRef<'a> = &'a crate::encryption::EncryptionKey;
/// Placeholder used when the `encryption` feature is disabled: a reference to
/// the unit type, so signatures stay the same in both configurations.
26#[cfg(not(feature = "encryption"))]
27type EncryptionKeyRef<'a> = &'a ();
28
29/// The value of a recovered entry, with one variant per supported data type.
30#[derive(Debug, Clone)]
31pub enum RecoveredValue {
    /// Plain string value held as raw bytes.
32    String(Bytes),
    /// List of byte strings; the front of the deque is the head of the list.
33    List(VecDeque<Bytes>),
34    /// Sorted set stored as (score, member) pairs.
    /// The pairs are not necessarily ordered by score: AOF replay appends
    /// newly added members at the end of the vector.
35    SortedSet(Vec<(f64, String)>),
36    /// Hash map of field names to values.
37    Hash(HashMap<String, Bytes>),
38    /// Unordered set of unique string members.
39    Set(HashSet<String>),
40    /// A vector set: index config + accumulated (element, vector) pairs.
41    #[cfg(feature = "vector")]
42    Vector {
43        metric: u8,
44        quantization: u8,
45        connectivity: u32,
46        expansion_add: u32,
47        elements: Vec<(String, Vec<f32>)>,
48    },
49    /// A protobuf message: type name + serialized bytes.
50    #[cfg(feature = "protobuf")]
51    Proto {
52        type_name: String,
53        data: Bytes,
54    },
55}
56
57impl From<SnapValue> for RecoveredValue {
58    fn from(sv: SnapValue) -> Self {
59        match sv {
60            SnapValue::String(data) => RecoveredValue::String(data),
61            SnapValue::List(deque) => RecoveredValue::List(deque),
62            SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
63            SnapValue::Hash(map) => RecoveredValue::Hash(map),
64            SnapValue::Set(set) => RecoveredValue::Set(set),
65            #[cfg(feature = "vector")]
66            SnapValue::Vector {
67                metric,
68                quantization,
69                connectivity,
70                expansion_add,
71                elements,
72                ..
73            } => RecoveredValue::Vector {
74                metric,
75                quantization,
76                connectivity,
77                expansion_add,
78                elements,
79            },
80            #[cfg(feature = "protobuf")]
81            SnapValue::Proto { type_name, data } => RecoveredValue::Proto { type_name, data },
82        }
83    }
84}
85
86/// A single recovered entry ready to be inserted into a keyspace.
87#[derive(Debug, Clone)]
88pub struct RecoveredEntry {
    /// Name of the key this entry belongs to.
89    pub key: String,
    /// The recovered value stored under `key`.
90    pub value: RecoveredValue,
91    /// Remaining TTL. `None` means no expiration.
92    pub ttl: Option<Duration>,
93}
94
95/// The result of recovering a shard's persisted state.
96#[derive(Debug)]
97pub struct RecoveryResult {
98    /// Recovered live entries, one per key, in no particular order.
99    pub entries: Vec<RecoveredEntry>,
100    /// Whether a snapshot file existed and was loaded without error.
101    pub loaded_snapshot: bool,
102    /// Whether an AOF was replayed to the end without error and contained at least one record.
103    pub replayed_aof: bool,
104    /// Schemas found in the AOF, deduplicated by name (last wins).
105    /// Each entry is `(schema_name, descriptor_bytes)`.
106    #[cfg(feature = "protobuf")]
107    pub schemas: Vec<(String, Bytes)>,
108}
109
110/// Recovers a shard's state from snapshot and/or AOF files.
111///
112/// Returns a list of live entries to restore into the keyspace.
113/// Entries whose TTL expired during downtime are silently skipped.
114pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
115    recover_shard_impl(data_dir, shard_id, None)
116}
117
/// Recovers one shard's persisted state, using `key` to decrypt
/// v3 persistence files. Also handles plaintext v2 files transparently.
///
/// The key is taken by value and only borrowed for the duration of the call.
#[cfg(feature = "encryption")]
pub fn recover_shard_encrypted(
    data_dir: &Path,
    shard_id: u16,
    key: crate::encryption::EncryptionKey,
) -> RecoveryResult {
    let key_ref = &key;
    recover_shard_impl(data_dir, shard_id, Some(key_ref))
}
128
129/// Shared implementation. When encryption is not compiled in, the key
130/// parameter is always `None` and all encryption branches are dead code
131/// that the compiler will remove.
132fn recover_shard_impl(
133    data_dir: &Path,
134    shard_id: u16,
135    #[allow(unused_variables)] encryption_key: Option<EncryptionKeyRef<'_>>,
136) -> RecoveryResult {
137    // Track remaining TTL in ms (-1 = no expiry, 0+ = remaining ms)
138    let mut map: HashMap<String, (RecoveredValue, i64)> = HashMap::new();
139    let mut loaded_snapshot = false;
140    let mut replayed_aof = false;
141    #[cfg(feature = "protobuf")]
142    let mut schema_map: HashMap<String, Bytes> = HashMap::new();
143
144    // step 1: load snapshot
145    let snap_path = snapshot::snapshot_path(data_dir, shard_id);
146    if snap_path.exists() {
147        match load_snapshot(&snap_path, shard_id, encryption_key) {
148            Ok(entries) => {
149                for (key, value, ttl_ms) in entries {
150                    map.insert(key, (RecoveredValue::from(value), ttl_ms));
151                }
152                loaded_snapshot = true;
153            }
154            Err(e) => {
155                warn!(shard_id, "failed to load snapshot, starting empty: {e}");
156            }
157        }
158    }
159
160    // step 2: replay AOF
161    let aof_path = aof::aof_path(data_dir, shard_id);
162    if aof_path.exists() {
163        match replay_aof(
164            &aof_path,
165            &mut map,
166            encryption_key,
167            #[cfg(feature = "protobuf")]
168            &mut schema_map,
169        ) {
170            Ok(count) => {
171                if count > 0 {
172                    replayed_aof = true;
173                }
174            }
175            Err(e) => {
176                warn!(
177                    shard_id,
178                    "failed to replay aof, using snapshot state only: {e}"
179                );
180            }
181        }
182    }
183
184    // step 3: filter out expired entries (ttl_ms == 0) and build result.
185    //
186    // TTL values are preserved as-is during replay. Keys that expired while
187    // the server was down (ttl_ms == 0 after decrement in replay_aof) are
188    // skipped here. Any keys with positive remaining TTL will be lazily evicted
189    // on first access after startup.
190    let entries = map
191        .into_iter()
192        .filter(|(_, (_, ttl_ms))| *ttl_ms != 0) // 0 means expired, -1 means no expiry
193        .map(|(key, (value, ttl_ms))| RecoveredEntry {
194            key,
195            value,
196            ttl: if ttl_ms < 0 {
197                None
198            } else {
199                Some(Duration::from_millis(ttl_ms as u64))
200            },
201        })
202        .collect();
203
204    RecoveryResult {
205        entries,
206        loaded_snapshot,
207        replayed_aof,
208        #[cfg(feature = "protobuf")]
209        schemas: schema_map.into_iter().collect(),
210    }
211}
212
213/// Loads entries from a snapshot file.
214/// Returns (key, value, ttl_ms) where ttl_ms is -1 for no expiry.
215fn load_snapshot(
216    path: &Path,
217    expected_shard_id: u16,
218    #[allow(unused_variables)] encryption_key: Option<EncryptionKeyRef<'_>>,
219) -> Result<Vec<(String, SnapValue, i64)>, FormatError> {
220    #[cfg(feature = "encryption")]
221    let mut reader = if let Some(key) = encryption_key {
222        SnapshotReader::open_encrypted(path, key.clone())?
223    } else {
224        SnapshotReader::open(path)?
225    };
226    #[cfg(not(feature = "encryption"))]
227    let mut reader = SnapshotReader::open(path)?;
228
229    if reader.shard_id != expected_shard_id {
230        return Err(FormatError::InvalidData(format!(
231            "snapshot shard_id {} does not match expected {}",
232            reader.shard_id, expected_shard_id
233        )));
234    }
235
236    let mut entries = Vec::new();
237
238    while let Some(entry) = reader.read_entry()? {
239        // entry.expire_ms is -1 for no expiry, or remaining ms
240        entries.push((entry.key, entry.value, entry.expire_ms));
241    }
242
243    reader.verify_footer()?;
244    Ok(entries)
245}
246
247/// Applies an increment/decrement to a recovered entry. If the key doesn't
248/// exist, initializes it to "0" first.
249///
250/// Emits a warning and leaves the entry unchanged if the stored value is not
251/// a valid integer. This matches the liveness policy used at runtime (an error
252/// response rather than a crash), and gives operators a signal that the AOF
253/// contains unexpected data.
254fn apply_incr(map: &mut HashMap<String, (RecoveredValue, i64)>, key: String, delta: i64) {
255    // -1 means no expiry
256    let entry = map
257        .entry(key.clone())
258        .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), -1));
259    if let RecoveredValue::String(ref mut data) = entry.0 {
260        let current = std::str::from_utf8(data)
261            .ok()
262            .and_then(|s| s.parse::<i64>().ok());
263        if let Some(n) = current {
264            if let Some(new_val) = n.checked_add(delta) {
265                *data = Bytes::from(new_val.to_string());
266            } else {
267                error!(
268                    key = %key,
269                    value = n,
270                    delta,
271                    "INCR overflow during AOF replay: value unchanged — AOF may be corrupt"
272                );
273            }
274        } else {
275            error!(
276                key = %key,
277                "skipping INCR replay: stored value is not an integer — AOF may be corrupt"
278            );
279        }
280    }
281}
282
283/// Replays AOF records into the in-memory map. Returns the number of
284/// records processed. TTL is stored as remaining ms (-1 = no expiry).
///
/// Records are applied to `map` in the order the reader yields them. Every
/// record read is counted, including ones that end up changing nothing
/// (for example a write aimed at a key that currently holds another type,
/// which is silently ignored). With the `protobuf` feature, schema
/// registrations are collected into `schema_map`, keyed by schema name.
///
/// # Errors
///
/// Propagates any error from opening the file or from reading a record.
/// Mutations applied before the failing read are not rolled back.
285fn replay_aof(
286    path: &Path,
287    map: &mut HashMap<String, (RecoveredValue, i64)>,
288    #[allow(unused_variables)] encryption_key: Option<EncryptionKeyRef<'_>>,
289    #[cfg(feature = "protobuf")] schema_map: &mut HashMap<String, Bytes>,
290) -> Result<usize, FormatError> {
291    #[cfg(feature = "encryption")]
292    let mut reader = if let Some(key) = encryption_key {
293        AofReader::open_encrypted(path, key.clone())?
294    } else {
295        AofReader::open(path)?
296    };
297    #[cfg(not(feature = "encryption"))]
298    let mut reader = AofReader::open(path)?;
299    let mut count = 0;
300
301    while let Some(record) = reader.read_record()? {
302        match record {
303            AofRecord::Set {
304                key,
305                value,
306                expire_ms,
307            } => {
308                // replaces whatever the key held before; expire_ms is -1 for no expiry, or remaining ms
309                map.insert(key, (RecoveredValue::String(value), expire_ms));
310            }
311            AofRecord::Del { key } => {
312                map.remove(&key);
313            }
314            AofRecord::Expire { key, seconds } => {
                // only affects existing keys; seconds -> ms with saturation,
                // then clamped so the cast to i64 cannot wrap
315                if let Some(entry) = map.get_mut(&key) {
316                    entry.1 = seconds.saturating_mul(1000).min(i64::MAX as u64) as i64;
317                }
318            }
319            AofRecord::LPush { key, values } => {
                // a missing key starts as an empty list with no expiry
320                let entry = map
321                    .entry(key)
322                    .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
323                if let RecoveredValue::List(ref mut deque) = entry.0 {
                    // pushed one by one, so the last value ends up at the head
324                    for v in values {
325                        deque.push_front(v);
326                    }
327                }
328            }
329            AofRecord::RPush { key, values } => {
330                let entry = map
331                    .entry(key)
332                    .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
333                if let RecoveredValue::List(ref mut deque) = entry.0 {
334                    for v in values {
335                        deque.push_back(v);
336                    }
337                }
338            }
339            AofRecord::LPop { key } => {
340                if let Some(entry) = map.get_mut(&key) {
341                    if let RecoveredValue::List(ref mut deque) = entry.0 {
342                        deque.pop_front();
                        // an emptied list removes the key; `continue` skips the
                        // shared increment below, hence the explicit one here
343                        if deque.is_empty() {
344                            map.remove(&key);
345                            count += 1;
346                            continue;
347                        }
348                    }
349                }
350            }
351            AofRecord::RPop { key } => {
352                if let Some(entry) = map.get_mut(&key) {
353                    if let RecoveredValue::List(ref mut deque) = entry.0 {
354                        deque.pop_back();
355                        if deque.is_empty() {
356                            map.remove(&key);
357                            count += 1;
358                            continue;
359                        }
360                    }
361                }
362            }
363            AofRecord::ZAdd { key, members } => {
364                let entry = map
365                    .entry(key)
366                    .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), -1));
367                if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
368                    // build a member -> position index so each incoming member is a hash lookup, not a scan
369                    let mut index: HashMap<String, usize> = existing
370                        .iter()
371                        .enumerate()
372                        .map(|(i, (_, m))| (m.clone(), i))
373                        .collect();
374                    for (score, member) in members {
                        // known member: update its score in place; new member: append
375                        if let Some(&pos) = index.get(&member) {
376                            existing[pos].0 = score;
377                        } else {
378                            let pos = existing.len();
379                            index.insert(member.clone(), pos);
380                            existing.push((score, member));
381                        }
382                    }
383                }
384            }
385            AofRecord::ZRem { key, members } => {
386                if let Some(entry) = map.get_mut(&key) {
387                    if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
388                        let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
389                        existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
390                        if existing.is_empty() {
391                            map.remove(&key);
392                            count += 1;
393                            continue;
394                        }
395                    }
396                }
397            }
398            AofRecord::Persist { key } => {
399                if let Some(entry) = map.get_mut(&key) {
400                    entry.1 = -1; // -1 means no expiry
401                }
402            }
403            AofRecord::Pexpire { key, milliseconds } => {
                // same as Expire but already in ms; clamped before the cast to i64
404                if let Some(entry) = map.get_mut(&key) {
405                    entry.1 = milliseconds.min(i64::MAX as u64) as i64;
406                }
407            }
408            AofRecord::Incr { key } => {
409                apply_incr(map, key, 1);
410            }
411            AofRecord::Decr { key } => {
412                apply_incr(map, key, -1);
413            }
414            AofRecord::IncrBy { key, delta } => {
415                apply_incr(map, key, delta);
416            }
417            AofRecord::DecrBy { key, delta } => {
                // saturating_neg avoids overflow when delta is i64::MIN
418                apply_incr(map, key, delta.saturating_neg());
419            }
420            AofRecord::Append { key, value } => {
                // a missing key starts as an empty string with no expiry
421                let entry = map
422                    .entry(key)
423                    .or_insert_with(|| (RecoveredValue::String(Bytes::new()), -1));
424                if let RecoveredValue::String(ref mut data) = entry.0 {
425                    let mut new_data = Vec::with_capacity(data.len() + value.len());
426                    new_data.extend_from_slice(data);
427                    new_data.extend_from_slice(&value);
428                    *data = Bytes::from(new_data);
429                }
430            }
431            AofRecord::Rename { key, newkey } => {
                // moves value and TTL together, replacing anything at newkey;
                // does nothing when the source key is absent
432                if let Some(entry) = map.remove(&key) {
433                    map.insert(newkey, entry);
434                }
435            }
436            AofRecord::HSet { key, fields } => {
437                let entry = map
438                    .entry(key)
439                    .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
440                if let RecoveredValue::Hash(ref mut hash) = entry.0 {
441                    for (field, value) in fields {
442                        hash.insert(field, value);
443                    }
444                }
445            }
446            AofRecord::HDel { key, fields } => {
447                if let Some(entry) = map.get_mut(&key) {
448                    if let RecoveredValue::Hash(ref mut hash) = entry.0 {
449                        for field in fields {
450                            hash.remove(&field);
451                        }
452                        if hash.is_empty() {
453                            map.remove(&key);
454                            count += 1;
455                            continue;
456                        }
457                    }
458                }
459            }
460            AofRecord::HIncrBy { key, field, delta } => {
461                let entry = map
462                    .entry(key)
463                    .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
464                if let RecoveredValue::Hash(ref mut hash) = entry.0 {
                    // a missing or non-integer field counts as 0; overflow saturates
465                    let current: i64 = hash
466                        .get(&field)
467                        .and_then(|v| std::str::from_utf8(v).ok())
468                        .and_then(|s| s.parse().ok())
469                        .unwrap_or(0);
470                    let new_val = current.saturating_add(delta);
471                    hash.insert(field, Bytes::from(new_val.to_string()));
472                }
473            }
474            AofRecord::SAdd { key, members } => {
475                let entry = map
476                    .entry(key)
477                    .or_insert_with(|| (RecoveredValue::Set(HashSet::new()), -1));
478                if let RecoveredValue::Set(ref mut set) = entry.0 {
479                    for member in members {
480                        set.insert(member);
481                    }
482                }
483            }
484            AofRecord::SRem { key, members } => {
485                if let Some(entry) = map.get_mut(&key) {
486                    if let RecoveredValue::Set(ref mut set) = entry.0 {
487                        for member in members {
488                            set.remove(&member);
489                        }
490                        if set.is_empty() {
491                            map.remove(&key);
492                            count += 1;
493                            continue;
494                        }
495                    }
496                }
497            }
498            #[cfg(feature = "vector")]
499            AofRecord::VAdd {
500                key,
501                element,
502                vector,
503                metric,
504                quantization,
505                connectivity,
506                expansion_add,
507            } => {
                // the index config carried by the record is only used when the
                // key is created here; an existing vector set keeps its own
508                let entry = map.entry(key).or_insert_with(|| {
509                    (
510                        RecoveredValue::Vector {
511                            metric,
512                            quantization,
513                            connectivity,
514                            expansion_add,
515                            elements: Vec::new(),
516                        },
517                        -1, // newly created vector sets start with no expiry
518                    )
519                });
520                if let RecoveredValue::Vector {
521                    ref mut elements, ..
522                } = entry.0
523                {
524                    // replace the vector of an existing element, or append a new element
525                    if let Some(pos) = elements.iter().position(|(e, _)| *e == element) {
526                        elements[pos].1 = vector;
527                    } else {
528                        elements.push((element, vector));
529                    }
530                }
531            }
532            #[cfg(feature = "vector")]
533            AofRecord::VRem { key, element } => {
534                if let Some(entry) = map.get_mut(&key) {
535                    if let RecoveredValue::Vector {
536                        ref mut elements, ..
537                    } = entry.0
538                    {
539                        elements.retain(|(e, _)| *e != element);
                        // no `continue` here: falls through to the shared increment
540                        if elements.is_empty() {
541                            map.remove(&key);
542                        }
543                    }
544                }
545            }
546            #[cfg(feature = "protobuf")]
547            AofRecord::ProtoSet {
548                key,
549                type_name,
550                data,
551                expire_ms,
552            } => {
553                map.insert(key, (RecoveredValue::Proto { type_name, data }, expire_ms));
554            }
555            #[cfg(feature = "protobuf")]
556            AofRecord::ProtoRegister { name, descriptor } => {
557                // last-wins: if the same schema name appears multiple times
558                // in the AOF, the final registration is the one we keep.
559                schema_map.insert(name, descriptor);
560            }
561        }
562        count += 1;
563    }
564
565    Ok(count)
566}
567
568#[cfg(test)]
569mod tests {
570    use super::*;
571    use crate::aof::AofWriter;
572    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};
573
574    fn temp_dir() -> tempfile::TempDir {
575        tempfile::tempdir().expect("create temp dir")
576    }
577
578    #[test]
579    fn empty_dir_returns_empty_result() {
580        let dir = temp_dir();
581        let result = recover_shard(dir.path(), 0);
582        assert!(result.entries.is_empty());
583        assert!(!result.loaded_snapshot);
584        assert!(!result.replayed_aof);
585    }
586
587    #[test]
588    fn snapshot_only_recovery() {
589        let dir = temp_dir();
590        let path = snapshot::snapshot_path(dir.path(), 0);
591
592        {
593            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
594            writer
595                .write_entry(&SnapEntry {
596                    key: "a".into(),
597                    value: SnapValue::String(Bytes::from("1")),
598                    expire_ms: -1,
599                })
600                .unwrap();
601            writer
602                .write_entry(&SnapEntry {
603                    key: "b".into(),
604                    value: SnapValue::String(Bytes::from("2")),
605                    expire_ms: 60_000,
606                })
607                .unwrap();
608            writer.finish().unwrap();
609        }
610
611        let result = recover_shard(dir.path(), 0);
612        assert!(result.loaded_snapshot);
613        assert!(!result.replayed_aof);
614        assert_eq!(result.entries.len(), 2);
615    }
616
617    #[test]
618    fn aof_only_recovery() {
619        let dir = temp_dir();
620        let path = aof::aof_path(dir.path(), 0);
621
622        {
623            let mut writer = AofWriter::open(&path).unwrap();
624            writer
625                .write_record(&AofRecord::Set {
626                    key: "x".into(),
627                    value: Bytes::from("10"),
628                    expire_ms: -1,
629                })
630                .unwrap();
631            writer
632                .write_record(&AofRecord::Set {
633                    key: "y".into(),
634                    value: Bytes::from("20"),
635                    expire_ms: -1,
636                })
637                .unwrap();
638            writer.sync().unwrap();
639        }
640
641        let result = recover_shard(dir.path(), 0);
642        assert!(!result.loaded_snapshot);
643        assert!(result.replayed_aof);
644        assert_eq!(result.entries.len(), 2);
645    }
646
647    #[test]
648    fn snapshot_plus_aof_overlay() {
649        let dir = temp_dir();
650
651        // snapshot with key "a" = "old"
652        {
653            let path = snapshot::snapshot_path(dir.path(), 0);
654            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
655            writer
656                .write_entry(&SnapEntry {
657                    key: "a".into(),
658                    value: SnapValue::String(Bytes::from("old")),
659                    expire_ms: -1,
660                })
661                .unwrap();
662            writer.finish().unwrap();
663        }
664
665        // AOF overwrites "a" to "new" and adds "b"
666        {
667            let path = aof::aof_path(dir.path(), 0);
668            let mut writer = AofWriter::open(&path).unwrap();
669            writer
670                .write_record(&AofRecord::Set {
671                    key: "a".into(),
672                    value: Bytes::from("new"),
673                    expire_ms: -1,
674                })
675                .unwrap();
676            writer
677                .write_record(&AofRecord::Set {
678                    key: "b".into(),
679                    value: Bytes::from("added"),
680                    expire_ms: -1,
681                })
682                .unwrap();
683            writer.sync().unwrap();
684        }
685
686        let result = recover_shard(dir.path(), 0);
687        assert!(result.loaded_snapshot);
688        assert!(result.replayed_aof);
689
690        let map: HashMap<_, _> = result
691            .entries
692            .iter()
693            .map(|e| (e.key.as_str(), e.value.clone()))
694            .collect();
695        assert!(matches!(&map["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
696        assert!(matches!(&map["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
697    }
698
699    #[test]
700    fn del_removes_entry_during_replay() {
701        let dir = temp_dir();
702        let path = aof::aof_path(dir.path(), 0);
703
704        {
705            let mut writer = AofWriter::open(&path).unwrap();
706            writer
707                .write_record(&AofRecord::Set {
708                    key: "gone".into(),
709                    value: Bytes::from("temp"),
710                    expire_ms: -1,
711                })
712                .unwrap();
713            writer
714                .write_record(&AofRecord::Del { key: "gone".into() })
715                .unwrap();
716            writer.sync().unwrap();
717        }
718
719        let result = recover_shard(dir.path(), 0);
720        assert!(result.entries.is_empty());
721    }
722
723    #[test]
724    fn expired_entries_skipped() {
725        let dir = temp_dir();
726        let path = snapshot::snapshot_path(dir.path(), 0);
727
728        {
729            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
730            // this entry has 0ms remaining — already expired
731            writer
732                .write_entry(&SnapEntry {
733                    key: "dead".into(),
734                    value: SnapValue::String(Bytes::from("gone")),
735                    expire_ms: 0,
736                })
737                .unwrap();
738            // this one has plenty of time
739            writer
740                .write_entry(&SnapEntry {
741                    key: "alive".into(),
742                    value: SnapValue::String(Bytes::from("here")),
743                    expire_ms: 60_000,
744                })
745                .unwrap();
746            writer.finish().unwrap();
747        }
748
749        let result = recover_shard(dir.path(), 0);
750        assert_eq!(result.entries.len(), 1);
751        assert_eq!(result.entries[0].key, "alive");
752    }
753
754    #[test]
755    fn corrupt_snapshot_starts_empty() {
756        let dir = temp_dir();
757        let path = snapshot::snapshot_path(dir.path(), 0);
758
759        std::fs::write(&path, b"garbage data").unwrap();
760
761        let result = recover_shard(dir.path(), 0);
762        assert!(!result.loaded_snapshot);
763        assert!(result.entries.is_empty());
764    }
765
766    #[test]
767    fn sorted_set_snapshot_recovery() {
768        let dir = temp_dir();
769        let path = snapshot::snapshot_path(dir.path(), 0);
770
771        {
772            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
773            writer
774                .write_entry(&SnapEntry {
775                    key: "board".into(),
776                    value: SnapValue::SortedSet(vec![
777                        (100.0, "alice".into()),
778                        (200.0, "bob".into()),
779                    ]),
780                    expire_ms: -1,
781                })
782                .unwrap();
783            writer.finish().unwrap();
784        }
785
786        let result = recover_shard(dir.path(), 0);
787        assert!(result.loaded_snapshot);
788        assert_eq!(result.entries.len(), 1);
789        match &result.entries[0].value {
790            RecoveredValue::SortedSet(members) => {
791                assert_eq!(members.len(), 2);
792                assert!(members.contains(&(100.0, "alice".into())));
793                assert!(members.contains(&(200.0, "bob".into())));
794            }
795            other => panic!("expected SortedSet, got {other:?}"),
796        }
797    }
798
799    #[test]
800    fn sorted_set_aof_replay() {
801        let dir = temp_dir();
802        let path = aof::aof_path(dir.path(), 0);
803
804        {
805            let mut writer = AofWriter::open(&path).unwrap();
806            writer
807                .write_record(&AofRecord::ZAdd {
808                    key: "board".into(),
809                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
810                })
811                .unwrap();
812            writer
813                .write_record(&AofRecord::ZRem {
814                    key: "board".into(),
815                    members: vec!["alice".into()],
816                })
817                .unwrap();
818            writer.sync().unwrap();
819        }
820
821        let result = recover_shard(dir.path(), 0);
822        assert!(result.replayed_aof);
823        assert_eq!(result.entries.len(), 1);
824        match &result.entries[0].value {
825            RecoveredValue::SortedSet(members) => {
826                assert_eq!(members.len(), 1);
827                assert_eq!(members[0], (200.0, "bob".into()));
828            }
829            other => panic!("expected SortedSet, got {other:?}"),
830        }
831    }
832
/// Removing the only member of a sorted set during replay must leave no
/// entry behind for that key.
#[test]
fn sorted_set_zrem_auto_deletes_empty() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    // A single member is added and then immediately removed.
    let records = [
        AofRecord::ZAdd {
            key: "board".into(),
            members: vec![(100.0, "alice".into())],
        },
        AofRecord::ZRem {
            key: "board".into(),
            members: vec!["alice".into()],
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert!(recovered.entries.is_empty());
}
858
/// An `Expire` record replayed after a `Set` written with `expire_ms: -1`
/// must leave the recovered entry with a TTL.
#[test]
fn expire_record_updates_ttl() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    // Write the key with `expire_ms: -1`, then attach a 300 second expiry.
    let records = [
        AofRecord::Set {
            key: "k".into(),
            value: Bytes::from("v"),
            expire_ms: -1,
        },
        AofRecord::Expire {
            key: "k".into(),
            seconds: 300,
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert_eq!(recovered.entries.len(), 1);

    let entry = &recovered.entries[0];
    assert!(entry.ttl.is_some());
}
886
/// A `Persist` record replayed after a `Set` with an expiry must leave the
/// recovered entry without a TTL.
#[test]
fn persist_record_removes_ttl() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    // Write the key with a 60 second expiry, then persist it.
    let records = [
        AofRecord::Set {
            key: "k".into(),
            value: Bytes::from("v"),
            expire_ms: 60_000,
        },
        AofRecord::Persist { key: "k".into() },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert_eq!(recovered.entries.len(), 1);

    let entry = &recovered.entries[0];
    assert!(entry.ttl.is_none());
}
911
/// Replaying `Incr`/`Decr` records must apply them arithmetically to an
/// existing numeric string, and an `Incr` on a key that was never set must
/// produce "1".
#[test]
fn incr_decr_replay() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    let records = [
        AofRecord::Set {
            key: "n".into(),
            value: Bytes::from("10"),
            expire_ms: -1,
        },
        AofRecord::Incr { key: "n".into() },
        AofRecord::Incr { key: "n".into() },
        AofRecord::Decr { key: "n".into() },
        // INCR against a key with no preceding SET
        AofRecord::Incr {
            key: "fresh".into(),
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);

    // Index the recovered values by key so each can be checked by name.
    let mut by_key: HashMap<_, _> = HashMap::new();
    for entry in recovered.entries.iter() {
        by_key.insert(entry.key.as_str(), entry.value.clone());
    }

    // 10 + 1 + 1 - 1 = 11
    let counter = &by_key["n"];
    match counter {
        RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
        other => panic!("expected String(\"11\"), got {other:?}"),
    }

    // 0 + 1 = 1
    let fresh = &by_key["fresh"];
    match fresh {
        RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
        other => panic!("expected String(\"1\"), got {other:?}"),
    }
}
962
/// A `Pexpire` record replayed after a `Set` written with `expire_ms: -1`
/// must leave the recovered entry with a TTL.
#[test]
fn pexpire_record_sets_ttl() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    // Write the key with `expire_ms: -1`, then attach a 5000 ms expiry.
    let records = [
        AofRecord::Set {
            key: "k".into(),
            value: Bytes::from("v"),
            expire_ms: -1,
        },
        AofRecord::Pexpire {
            key: "k".into(),
            milliseconds: 5000,
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert_eq!(recovered.entries.len(), 1);

    let entry = &recovered.entries[0];
    assert!(entry.ttl.is_some());
}
990
/// A vector-set entry written to a snapshot must be recovered with its
/// index configuration and all of its elements.
#[cfg(feature = "vector")]
#[test]
fn vector_snapshot_recovery() {
    let tmp = temp_dir();
    let snap_file = snapshot::snapshot_path(tmp.path(), 0);

    // One vector set holding two 3-dimensional elements.
    let entry = SnapEntry {
        key: "embeddings".into(),
        value: SnapValue::Vector {
            metric: 0,
            quantization: 0,
            connectivity: 16,
            expansion_add: 64,
            dim: 3,
            elements: vec![
                ("doc1".into(), vec![1.0, 0.0, 0.0]),
                ("doc2".into(), vec![0.0, 1.0, 0.0]),
            ],
        },
        expire_ms: -1,
    };

    {
        let mut snap = SnapshotWriter::create(&snap_file, 0).unwrap();
        snap.write_entry(&entry).unwrap();
        snap.finish().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert!(recovered.loaded_snapshot);
    assert_eq!(recovered.entries.len(), 1);

    let value = &recovered.entries[0].value;
    match value {
        RecoveredValue::Vector {
            metric,
            quantization,
            elements,
            ..
        } => {
            assert_eq!(*metric, 0);
            assert_eq!(*quantization, 0);
            assert_eq!(elements.len(), 2);
            // the recovered value has no `dim` field; check the vector length instead
            assert_eq!(elements[0].1.len(), 3);
        }
        other => panic!("expected Vector, got {other:?}"),
    }
}
1038
/// Replaying two `VAdd` records followed by a `VRem` of the first element
/// must recover a vector set holding only the second element.
#[cfg(feature = "vector")]
#[test]
fn vector_aof_replay() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    // Add "a" and "b" with identical index settings, then remove "a".
    let records = [
        AofRecord::VAdd {
            key: "vecs".into(),
            element: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            metric: 0,
            quantization: 0,
            connectivity: 16,
            expansion_add: 64,
        },
        AofRecord::VAdd {
            key: "vecs".into(),
            element: "b".into(),
            vector: vec![0.0, 1.0, 0.0],
            metric: 0,
            quantization: 0,
            connectivity: 16,
            expansion_add: 64,
        },
        AofRecord::VRem {
            key: "vecs".into(),
            element: "a".into(),
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert!(recovered.replayed_aof);
    assert_eq!(recovered.entries.len(), 1);

    let value = &recovered.entries[0].value;
    match value {
        RecoveredValue::Vector { elements, .. } => {
            assert_eq!(elements.len(), 1);
            assert_eq!(elements[0].0, "b");
        }
        other => panic!("expected Vector, got {other:?}"),
    }
}
1089
/// Removing the only element of a vector set during replay must leave no
/// entry behind for that key.
#[cfg(feature = "vector")]
#[test]
fn vector_vrem_auto_deletes_empty() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    // A single element is added and then immediately removed.
    let records = [
        AofRecord::VAdd {
            key: "vecs".into(),
            element: "only".into(),
            vector: vec![1.0, 2.0],
            metric: 0,
            quantization: 0,
            connectivity: 16,
            expansion_add: 64,
        },
        AofRecord::VRem {
            key: "vecs".into(),
            element: "only".into(),
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert!(recovered.entries.is_empty());
}
1121
/// `ProtoRegister` records in the AOF must be surfaced in the recovery
/// result's `schemas`, with a later registration under the same name
/// replacing the earlier one.
#[cfg(feature = "protobuf")]
#[test]
fn proto_schemas_recovered_from_aof() {
    let tmp = temp_dir();
    let aof_file = aof::aof_path(tmp.path(), 0);

    let records = [
        AofRecord::ProtoRegister {
            name: "users".into(),
            descriptor: Bytes::from("fake-descriptor-a"),
        },
        // a proto value written between the two registrations
        AofRecord::ProtoSet {
            key: "user:1".into(),
            type_name: "test.User".into(),
            data: Bytes::from("some-proto-data"),
            expire_ms: -1,
        },
        // same schema name registered again (last wins)
        AofRecord::ProtoRegister {
            name: "users".into(),
            descriptor: Bytes::from("fake-descriptor-b"),
        },
    ];

    {
        // Scoped so the writer is dropped before recovery reads the file.
        let mut log = AofWriter::open(&aof_file).unwrap();
        for record in &records {
            log.write_record(record).unwrap();
        }
        log.sync().unwrap();
    }

    let recovered = recover_shard(tmp.path(), 0);
    assert!(recovered.replayed_aof);
    assert_eq!(recovered.entries.len(), 1);

    // exactly one schema should remain, carrying the second descriptor
    assert_eq!(recovered.schemas.len(), 1);
    let (schema_name, schema_bytes) = &recovered.schemas[0];
    assert_eq!(schema_name, "users");
    assert_eq!(schema_bytes, &Bytes::from("fake-descriptor-b"));
}
1165}