Skip to main content

ember_persistence/
recovery.rs

//! Recovery: loading snapshots and replaying AOF on shard startup.
//!
//! The recovery sequence is:
//! 1. Load snapshot if it exists (bulk restore of entries).
//! 2. Replay AOF if it exists (apply mutations on top of snapshot state).
//! 3. Skip entries whose TTL expired during downtime.
//! 4. If no files exist, start with an empty state.
//! 5. If a file is corrupt, log a warning and continue with whatever state
//!    was recovered before the error (which may be empty).
9
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tracing::{error, warn};
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
/// Borrowed encryption key type threaded through the recovery helpers.
///
/// Call sites wrap this alias in `Option`, i.e. they take
/// `Option<EncryptionKeyRef<'_>>`. With the `encryption` feature enabled the
/// alias is a reference to the real key. With the feature disabled it
/// degrades to `&()`; the only entry point compiled in that configuration
/// passes `None`, so the key-handling branches are never taken.
#[cfg(feature = "encryption")]
type EncryptionKeyRef<'a> = &'a crate::encryption::EncryptionKey;
#[cfg(not(feature = "encryption"))]
type EncryptionKeyRef<'a> = &'a ();
28
29/// The value of a recovered entry.
30#[derive(Debug, Clone)]
31pub enum RecoveredValue {
32    String(Bytes),
33    List(VecDeque<Bytes>),
34    /// Sorted set stored as (score, member) pairs.
35    SortedSet(Vec<(f64, String)>),
36    /// Hash map of field names to values.
37    Hash(HashMap<String, Bytes>),
38    /// Unordered set of unique string members.
39    Set(HashSet<String>),
40    /// A vector set: index config + accumulated (element, vector) pairs.
41    #[cfg(feature = "vector")]
42    Vector {
43        metric: u8,
44        quantization: u8,
45        connectivity: u32,
46        expansion_add: u32,
47        elements: Vec<(String, Vec<f32>)>,
48    },
49    /// A protobuf message: type name + serialized bytes.
50    #[cfg(feature = "protobuf")]
51    Proto {
52        type_name: String,
53        data: Bytes,
54    },
55}
56
57impl From<SnapValue> for RecoveredValue {
58    fn from(sv: SnapValue) -> Self {
59        match sv {
60            SnapValue::String(data) => RecoveredValue::String(data),
61            SnapValue::List(deque) => RecoveredValue::List(deque),
62            SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
63            SnapValue::Hash(map) => RecoveredValue::Hash(map),
64            SnapValue::Set(set) => RecoveredValue::Set(set),
65            #[cfg(feature = "vector")]
66            SnapValue::Vector {
67                metric,
68                quantization,
69                connectivity,
70                expansion_add,
71                elements,
72                ..
73            } => RecoveredValue::Vector {
74                metric,
75                quantization,
76                connectivity,
77                expansion_add,
78                elements,
79            },
80            #[cfg(feature = "protobuf")]
81            SnapValue::Proto { type_name, data } => RecoveredValue::Proto { type_name, data },
82        }
83    }
84}
85
86/// A single recovered entry ready to be inserted into a keyspace.
87#[derive(Debug, Clone)]
88pub struct RecoveredEntry {
89    pub key: String,
90    pub value: RecoveredValue,
91    /// Remaining TTL. `None` means no expiration.
92    pub ttl: Option<Duration>,
93}
94
95/// The result of recovering a shard's persisted state.
96#[derive(Debug)]
97pub struct RecoveryResult {
98    /// Recovered entries, keyed by name for easy insertion.
99    pub entries: Vec<RecoveredEntry>,
100    /// Whether a snapshot was loaded.
101    pub loaded_snapshot: bool,
102    /// Whether an AOF was replayed.
103    pub replayed_aof: bool,
104    /// Schemas found in the AOF, deduplicated by name (last wins).
105    /// Each entry is `(schema_name, descriptor_bytes)`.
106    #[cfg(feature = "protobuf")]
107    pub schemas: Vec<(String, Bytes)>,
108}
109
110/// Recovers a shard's state from snapshot and/or AOF files.
111///
112/// Returns a list of live entries to restore into the keyspace.
113/// Entries whose TTL expired during downtime are silently skipped.
114pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
115    recover_shard_impl(data_dir, shard_id, None)
116}
117
/// Rebuilds a shard's state using `key` to decrypt v3 persistence files.
/// Plaintext v2 files are also handled transparently.
///
/// The key is taken by value and only borrowed for the duration of the call.
#[cfg(feature = "encryption")]
pub fn recover_shard_encrypted(
    data_dir: &Path,
    shard_id: u16,
    key: crate::encryption::EncryptionKey,
) -> RecoveryResult {
    let key_ref: EncryptionKeyRef<'_> = &key;
    recover_shard_impl(data_dir, shard_id, Some(key_ref))
}
128
129/// Shared implementation. When encryption is not compiled in, the key
130/// parameter is always `None` and all encryption branches are dead code
131/// that the compiler will remove.
132fn recover_shard_impl(
133    data_dir: &Path,
134    shard_id: u16,
135    #[allow(unused_variables)] encryption_key: Option<EncryptionKeyRef<'_>>,
136) -> RecoveryResult {
137    // Track remaining TTL in ms (-1 = no expiry, 0+ = remaining ms)
138    let mut map: HashMap<String, (RecoveredValue, i64)> = HashMap::new();
139    let mut loaded_snapshot = false;
140    let mut replayed_aof = false;
141    #[cfg(feature = "protobuf")]
142    let mut schema_map: HashMap<String, Bytes> = HashMap::new();
143
144    // step 1: load snapshot
145    let snap_path = snapshot::snapshot_path(data_dir, shard_id);
146    if snap_path.exists() {
147        match load_snapshot(&snap_path, shard_id, encryption_key) {
148            Ok(entries) => {
149                for (key, value, ttl_ms) in entries {
150                    map.insert(key, (RecoveredValue::from(value), ttl_ms));
151                }
152                loaded_snapshot = true;
153            }
154            Err(e) => {
155                warn!(shard_id, "failed to load snapshot, starting empty: {e}");
156            }
157        }
158    }
159
160    // step 2: replay AOF
161    let aof_path = aof::aof_path(data_dir, shard_id);
162    if aof_path.exists() {
163        match replay_aof(
164            &aof_path,
165            &mut map,
166            encryption_key,
167            #[cfg(feature = "protobuf")]
168            &mut schema_map,
169        ) {
170            Ok(count) => {
171                if count > 0 {
172                    replayed_aof = true;
173                }
174            }
175            Err(e) => {
176                warn!(
177                    shard_id,
178                    "failed to replay aof, using snapshot state only: {e}"
179                );
180            }
181        }
182    }
183
184    // step 3: filter out expired entries (ttl_ms == 0) and build result.
185    //
186    // TTL values are preserved as-is during replay. Keys that expired while
187    // the server was down (ttl_ms == 0 after decrement in replay_aof) are
188    // skipped here. Any keys with positive remaining TTL will be lazily evicted
189    // on first access after startup.
190    let entries = map
191        .into_iter()
192        .filter(|(_, (_, ttl_ms))| *ttl_ms != 0) // 0 means expired, -1 means no expiry
193        .map(|(key, (value, ttl_ms))| RecoveredEntry {
194            key,
195            value,
196            ttl: if ttl_ms < 0 {
197                None
198            } else {
199                Some(Duration::from_millis(ttl_ms as u64))
200            },
201        })
202        .collect();
203
204    RecoveryResult {
205        entries,
206        loaded_snapshot,
207        replayed_aof,
208        #[cfg(feature = "protobuf")]
209        schemas: schema_map.into_iter().collect(),
210    }
211}
212
213/// Loads entries from a snapshot file.
214/// Returns (key, value, ttl_ms) where ttl_ms is -1 for no expiry.
215fn load_snapshot(
216    path: &Path,
217    expected_shard_id: u16,
218    #[allow(unused_variables)] encryption_key: Option<EncryptionKeyRef<'_>>,
219) -> Result<Vec<(String, SnapValue, i64)>, FormatError> {
220    #[cfg(feature = "encryption")]
221    let mut reader = if let Some(key) = encryption_key {
222        SnapshotReader::open_encrypted(path, key.clone())?
223    } else {
224        SnapshotReader::open(path)?
225    };
226    #[cfg(not(feature = "encryption"))]
227    let mut reader = SnapshotReader::open(path)?;
228
229    if reader.shard_id != expected_shard_id {
230        return Err(FormatError::InvalidData(format!(
231            "snapshot shard_id {} does not match expected {}",
232            reader.shard_id, expected_shard_id
233        )));
234    }
235
236    let mut entries = Vec::new();
237
238    while let Some(entry) = reader.read_entry()? {
239        // entry.expire_ms is -1 for no expiry, or remaining ms
240        entries.push((entry.key, entry.value, entry.expire_ms));
241    }
242
243    reader.verify_footer()?;
244    Ok(entries)
245}
246
247/// Applies an increment/decrement to a recovered entry. If the key doesn't
248/// exist, initializes it to "0" first.
249///
250/// Emits a warning and leaves the entry unchanged if the stored value is not
251/// a valid integer. This matches the liveness policy used at runtime (an error
252/// response rather than a crash), and gives operators a signal that the AOF
253/// contains unexpected data.
254fn apply_incr(map: &mut HashMap<String, (RecoveredValue, i64)>, key: String, delta: i64) {
255    // -1 means no expiry
256    let entry = map
257        .entry(key.clone())
258        .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), -1));
259    if let RecoveredValue::String(ref mut data) = entry.0 {
260        let current = std::str::from_utf8(data)
261            .ok()
262            .and_then(|s| s.parse::<i64>().ok());
263        if let Some(n) = current {
264            if let Some(new_val) = n.checked_add(delta) {
265                *data = Bytes::from(new_val.to_string());
266            } else {
267                error!(
268                    key = %key,
269                    value = n,
270                    delta,
271                    "INCR overflow during AOF replay: value unchanged — AOF may be corrupt"
272                );
273            }
274        } else {
275            error!(
276                key = %key,
277                "skipping INCR replay: stored value is not an integer — AOF may be corrupt"
278            );
279        }
280    }
281}
282
283/// Replays AOF records into the in-memory map. Returns the number of
284/// records replayed. TTL is stored as remaining ms (-1 = no expiry).
285fn replay_aof(
286    path: &Path,
287    map: &mut HashMap<String, (RecoveredValue, i64)>,
288    #[allow(unused_variables)] encryption_key: Option<EncryptionKeyRef<'_>>,
289    #[cfg(feature = "protobuf")] schema_map: &mut HashMap<String, Bytes>,
290) -> Result<usize, FormatError> {
291    #[cfg(feature = "encryption")]
292    let mut reader = if let Some(key) = encryption_key {
293        AofReader::open_encrypted(path, key.clone())?
294    } else {
295        AofReader::open(path)?
296    };
297    #[cfg(not(feature = "encryption"))]
298    let mut reader = AofReader::open(path)?;
299    let mut count = 0;
300
301    while let Some(record) = reader.read_record()? {
302        match record {
303            AofRecord::Set {
304                key,
305                value,
306                expire_ms,
307            } => {
308                // expire_ms is -1 for no expiry, or remaining ms
309                map.insert(key, (RecoveredValue::String(value), expire_ms));
310            }
311            AofRecord::Del { key } => {
312                map.remove(&key);
313            }
314            AofRecord::Expire { key, seconds } => {
315                if let Some(entry) = map.get_mut(&key) {
316                    entry.1 = seconds.saturating_mul(1000).min(i64::MAX as u64) as i64;
317                }
318            }
319            AofRecord::LPush { key, values } => {
320                let entry = map
321                    .entry(key)
322                    .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
323                if let RecoveredValue::List(ref mut deque) = entry.0 {
324                    for v in values {
325                        deque.push_front(v);
326                    }
327                }
328            }
329            AofRecord::RPush { key, values } => {
330                let entry = map
331                    .entry(key)
332                    .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
333                if let RecoveredValue::List(ref mut deque) = entry.0 {
334                    for v in values {
335                        deque.push_back(v);
336                    }
337                }
338            }
339            AofRecord::LPop { key } => {
340                if let Some(entry) = map.get_mut(&key) {
341                    if let RecoveredValue::List(ref mut deque) = entry.0 {
342                        deque.pop_front();
343                        if deque.is_empty() {
344                            map.remove(&key);
345                            count += 1;
346                            continue;
347                        }
348                    }
349                }
350            }
351            AofRecord::RPop { key } => {
352                if let Some(entry) = map.get_mut(&key) {
353                    if let RecoveredValue::List(ref mut deque) = entry.0 {
354                        deque.pop_back();
355                        if deque.is_empty() {
356                            map.remove(&key);
357                            count += 1;
358                            continue;
359                        }
360                    }
361                }
362            }
363            AofRecord::LSet { key, index, value } => {
364                if let Some(entry) = map.get_mut(&key) {
365                    if let RecoveredValue::List(ref mut deque) = entry.0 {
366                        let len = deque.len();
367                        let resolved = if index < 0 {
368                            (len as i64 + index) as usize
369                        } else {
370                            index as usize
371                        };
372                        if resolved < len {
373                            deque[resolved] = value;
374                        }
375                    }
376                }
377            }
378            AofRecord::LTrim { key, start, stop } => {
379                if let Some(entry) = map.get_mut(&key) {
380                    if let RecoveredValue::List(ref mut deque) = entry.0 {
381                        let len = deque.len() as i64;
382                        let s = if start < 0 {
383                            (start + len).max(0) as usize
384                        } else {
385                            (start as usize).min(len as usize)
386                        };
387                        let e = if stop < 0 {
388                            (stop + len).max(-1) as usize
389                        } else {
390                            (stop as usize).min((len as usize).saturating_sub(1))
391                        };
392                        if s > e || s >= len as usize {
393                            deque.clear();
394                        } else {
395                            *deque = deque.drain(s..=e).collect();
396                        }
397                        if deque.is_empty() {
398                            map.remove(&key);
399                            count += 1;
400                            continue;
401                        }
402                    }
403                }
404            }
405            AofRecord::LInsert {
406                key,
407                before,
408                pivot,
409                value,
410            } => {
411                if let Some(entry) = map.get_mut(&key) {
412                    if let RecoveredValue::List(ref mut deque) = entry.0 {
413                        if let Some(pos) = deque.iter().position(|el| el.as_ref() == pivot.as_ref())
414                        {
415                            let insert_at = if before { pos } else { pos + 1 };
416                            deque.insert(insert_at, value);
417                        }
418                    }
419                }
420            }
421            AofRecord::LRem {
422                key,
423                count: cnt,
424                value,
425            } => {
426                if let Some(entry) = map.get_mut(&key) {
427                    if let RecoveredValue::List(ref mut deque) = entry.0 {
428                        let max = if cnt == 0 {
429                            usize::MAX
430                        } else {
431                            cnt.unsigned_abs() as usize
432                        };
433                        let mut removed = 0;
434                        if cnt >= 0 {
435                            // forward scan
436                            deque.retain(|el| {
437                                if removed < max && el.as_ref() == value.as_ref() {
438                                    removed += 1;
439                                    false
440                                } else {
441                                    true
442                                }
443                            });
444                        } else {
445                            // reverse scan: collect indices, remove from back
446                            let mut indices: Vec<usize> = Vec::new();
447                            for (i, el) in deque.iter().enumerate().rev() {
448                                if el.as_ref() == value.as_ref() {
449                                    indices.push(i);
450                                    if indices.len() >= max {
451                                        break;
452                                    }
453                                }
454                            }
455                            // indices are already highest-first
456                            for i in indices {
457                                deque.remove(i);
458                            }
459                        }
460                        if deque.is_empty() {
461                            map.remove(&key);
462                            count += 1;
463                            continue;
464                        }
465                    }
466                }
467            }
468            AofRecord::ZAdd { key, members } => {
469                let entry = map
470                    .entry(key)
471                    .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), -1));
472                if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
473                    // build a position index for O(1) member lookups
474                    let mut index: HashMap<String, usize> = existing
475                        .iter()
476                        .enumerate()
477                        .map(|(i, (_, m))| (m.clone(), i))
478                        .collect();
479                    for (score, member) in members {
480                        if let Some(&pos) = index.get(&member) {
481                            existing[pos].0 = score;
482                        } else {
483                            let pos = existing.len();
484                            index.insert(member.clone(), pos);
485                            existing.push((score, member));
486                        }
487                    }
488                }
489            }
490            AofRecord::ZRem { key, members } => {
491                if let Some(entry) = map.get_mut(&key) {
492                    if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
493                        let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
494                        existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
495                        if existing.is_empty() {
496                            map.remove(&key);
497                            count += 1;
498                            continue;
499                        }
500                    }
501                }
502            }
503            AofRecord::Persist { key } => {
504                if let Some(entry) = map.get_mut(&key) {
505                    entry.1 = -1; // -1 means no expiry
506                }
507            }
508            AofRecord::Pexpire { key, milliseconds } => {
509                if let Some(entry) = map.get_mut(&key) {
510                    entry.1 = milliseconds.min(i64::MAX as u64) as i64;
511                }
512            }
513            AofRecord::Pexpireat { key, timestamp_ms } => {
514                if let Some(entry) = map.get_mut(&key) {
515                    // Convert the absolute unix timestamp to a remaining TTL
516                    // relative to now. This preserves the exact wall-clock
517                    // deadline across restarts.
518                    let now_ms = std::time::SystemTime::now()
519                        .duration_since(std::time::UNIX_EPOCH)
520                        .unwrap_or_default()
521                        .as_millis() as u64;
522                    if timestamp_ms <= now_ms {
523                        // Already expired — mark as 0 so the filter removes it.
524                        entry.1 = 0;
525                    } else {
526                        let remaining = timestamp_ms.saturating_sub(now_ms);
527                        entry.1 = remaining.min(i64::MAX as u64) as i64;
528                    }
529                }
530            }
531            AofRecord::Incr { key } => {
532                apply_incr(map, key, 1);
533            }
534            AofRecord::Decr { key } => {
535                apply_incr(map, key, -1);
536            }
537            AofRecord::IncrBy { key, delta } => {
538                apply_incr(map, key, delta);
539            }
540            AofRecord::DecrBy { key, delta } => {
541                apply_incr(map, key, delta.saturating_neg());
542            }
543            AofRecord::Append { key, value } => {
544                let entry = map
545                    .entry(key)
546                    .or_insert_with(|| (RecoveredValue::String(Bytes::new()), -1));
547                if let RecoveredValue::String(ref mut data) = entry.0 {
548                    let mut new_data = Vec::with_capacity(data.len() + value.len());
549                    new_data.extend_from_slice(data);
550                    new_data.extend_from_slice(&value);
551                    *data = Bytes::from(new_data);
552                }
553            }
554            AofRecord::SetRange { key, offset, value } => {
555                let entry = map
556                    .entry(key)
557                    .or_insert_with(|| (RecoveredValue::String(Bytes::new()), -1));
558                if let RecoveredValue::String(ref mut data) = entry.0 {
559                    let needed = offset.saturating_add(value.len());
560                    let new_len = data.len().max(needed);
561                    let mut buf = Vec::with_capacity(new_len);
562                    let copy_len = data.len().min(offset);
563                    buf.extend_from_slice(&data[..copy_len]);
564                    if offset > data.len() {
565                        buf.resize(offset, 0);
566                    }
567                    buf.extend_from_slice(&value);
568                    if offset + value.len() < data.len() {
569                        buf.extend_from_slice(&data[offset + value.len()..]);
570                    }
571                    *data = Bytes::from(buf);
572                }
573            }
574            AofRecord::SetBit { key, offset, value } => {
575                let entry = map
576                    .entry(key)
577                    .or_insert_with(|| (RecoveredValue::String(Bytes::new()), -1));
578                if let RecoveredValue::String(ref mut data) = entry.0 {
579                    let byte_idx = (offset / 8) as usize;
580                    let bit_pos = 7 - (offset % 8) as u32;
581                    let mask = 1u8 << bit_pos;
582                    let new_len = data.len().max(byte_idx + 1);
583                    let mut buf = data.to_vec();
584                    buf.resize(new_len, 0);
585                    if value == 1 {
586                        buf[byte_idx] |= mask;
587                    } else {
588                        buf[byte_idx] &= !mask;
589                    }
590                    *data = Bytes::from(buf);
591                }
592            }
593            AofRecord::BitOp { op, dest, keys } => {
594                // op byte: 0=AND, 1=OR, 2=XOR, 3=NOT (matches aof.rs encoding)
595                let sources: Vec<Bytes> = keys
596                    .iter()
597                    .map(|k| {
598                        map.get(k)
599                            .and_then(|(v, _)| {
600                                if let RecoveredValue::String(b) = v {
601                                    Some(b.clone())
602                                } else {
603                                    None
604                                }
605                            })
606                            .unwrap_or_default()
607                    })
608                    .collect();
609                let result_len = sources.iter().map(|s| s.len()).max().unwrap_or(0);
610                let mut result = vec![0u8; result_len];
611                match op {
612                    3 => {
613                        // NOT
614                        let src = sources.first().map(|b| b.as_ref()).unwrap_or(&[]);
615                        for (i, b) in result.iter_mut().enumerate() {
616                            *b = if i < src.len() { !src[i] } else { 0xFF };
617                        }
618                    }
619                    0 => {
620                        // AND
621                        if let Some(first) = sources.first() {
622                            for (i, b) in result.iter_mut().enumerate() {
623                                *b = if i < first.len() { first[i] } else { 0 };
624                            }
625                        }
626                        for src in sources.iter().skip(1) {
627                            for (i, b) in result.iter_mut().enumerate() {
628                                *b &= if i < src.len() { src[i] } else { 0 };
629                            }
630                        }
631                    }
632                    1 => {
633                        // OR
634                        for src in &sources {
635                            for (i, b) in result.iter_mut().enumerate() {
636                                if i < src.len() {
637                                    *b |= src[i];
638                                }
639                            }
640                        }
641                    }
642                    _ => {
643                        // XOR (op == 2)
644                        for src in &sources {
645                            for (i, b) in result.iter_mut().enumerate() {
646                                if i < src.len() {
647                                    *b ^= src[i];
648                                }
649                            }
650                        }
651                    }
652                }
653                map.insert(dest, (RecoveredValue::String(Bytes::from(result)), -1));
654            }
655            AofRecord::Rename { key, newkey } => {
656                if let Some(entry) = map.remove(&key) {
657                    map.insert(newkey, entry);
658                }
659            }
660            AofRecord::Copy {
661                source,
662                destination,
663                replace,
664            } => {
665                if let Some(entry) = map.get(&source) {
666                    let cloned = entry.clone();
667                    if replace || !map.contains_key(&destination) {
668                        map.insert(destination, cloned);
669                    }
670                }
671            }
672            AofRecord::HSet { key, fields } => {
673                let entry = map
674                    .entry(key)
675                    .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
676                if let RecoveredValue::Hash(ref mut hash) = entry.0 {
677                    for (field, value) in fields {
678                        hash.insert(field, value);
679                    }
680                }
681            }
682            AofRecord::HDel { key, fields } => {
683                if let Some(entry) = map.get_mut(&key) {
684                    if let RecoveredValue::Hash(ref mut hash) = entry.0 {
685                        for field in fields {
686                            hash.remove(&field);
687                        }
688                        if hash.is_empty() {
689                            map.remove(&key);
690                            count += 1;
691                            continue;
692                        }
693                    }
694                }
695            }
696            AofRecord::HIncrBy { key, field, delta } => {
697                let entry = map
698                    .entry(key)
699                    .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
700                if let RecoveredValue::Hash(ref mut hash) = entry.0 {
701                    let current: i64 = hash
702                        .get(&field)
703                        .and_then(|v| std::str::from_utf8(v).ok())
704                        .and_then(|s| s.parse().ok())
705                        .unwrap_or(0);
706                    let new_val = current.saturating_add(delta);
707                    hash.insert(field, Bytes::from(new_val.to_string()));
708                }
709            }
710            AofRecord::SAdd { key, members } => {
711                let entry = map
712                    .entry(key)
713                    .or_insert_with(|| (RecoveredValue::Set(HashSet::new()), -1));
714                if let RecoveredValue::Set(ref mut set) = entry.0 {
715                    for member in members {
716                        set.insert(member);
717                    }
718                }
719            }
720            AofRecord::SRem { key, members } => {
721                if let Some(entry) = map.get_mut(&key) {
722                    if let RecoveredValue::Set(ref mut set) = entry.0 {
723                        for member in members {
724                            set.remove(&member);
725                        }
726                        if set.is_empty() {
727                            map.remove(&key);
728                            count += 1;
729                            continue;
730                        }
731                    }
732                }
733            }
734            #[cfg(feature = "vector")]
735            AofRecord::VAdd {
736                key,
737                element,
738                vector,
739                metric,
740                quantization,
741                connectivity,
742                expansion_add,
743            } => {
744                let entry = map.entry(key).or_insert_with(|| {
745                    (
746                        RecoveredValue::Vector {
747                            metric,
748                            quantization,
749                            connectivity,
750                            expansion_add,
751                            elements: Vec::new(),
752                        },
753                        -1, // no expiry for vector sets
754                    )
755                });
756                if let RecoveredValue::Vector {
757                    ref mut elements, ..
758                } = entry.0
759                {
760                    // replace existing element or add new
761                    if let Some(pos) = elements.iter().position(|(e, _)| *e == element) {
762                        elements[pos].1 = vector;
763                    } else {
764                        elements.push((element, vector));
765                    }
766                }
767            }
768            #[cfg(feature = "vector")]
769            AofRecord::VRem { key, element } => {
770                if let Some(entry) = map.get_mut(&key) {
771                    if let RecoveredValue::Vector {
772                        ref mut elements, ..
773                    } = entry.0
774                    {
775                        elements.retain(|(e, _)| *e != element);
776                        if elements.is_empty() {
777                            map.remove(&key);
778                        }
779                    }
780                }
781            }
782            #[cfg(feature = "protobuf")]
783            AofRecord::ProtoSet {
784                key,
785                type_name,
786                data,
787                expire_ms,
788            } => {
789                map.insert(key, (RecoveredValue::Proto { type_name, data }, expire_ms));
790            }
791            #[cfg(feature = "protobuf")]
792            AofRecord::ProtoRegister { name, descriptor } => {
793                // last-wins: if the same schema name appears multiple times
794                // in the AOF, the final registration is the one we keep.
795                schema_map.insert(name, descriptor);
796            }
797        }
798        count += 1;
799    }
800
801    Ok(count)
802}
803
804#[cfg(test)]
805mod tests {
806    use super::*;
807    use crate::aof::AofWriter;
808    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};
809
810    fn temp_dir() -> tempfile::TempDir {
811        tempfile::tempdir().expect("create temp dir")
812    }
813
814    #[test]
815    fn empty_dir_returns_empty_result() {
816        let dir = temp_dir();
817        let result = recover_shard(dir.path(), 0);
818        assert!(result.entries.is_empty());
819        assert!(!result.loaded_snapshot);
820        assert!(!result.replayed_aof);
821    }
822
823    #[test]
824    fn snapshot_only_recovery() {
825        let dir = temp_dir();
826        let path = snapshot::snapshot_path(dir.path(), 0);
827
828        {
829            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
830            writer
831                .write_entry(&SnapEntry {
832                    key: "a".into(),
833                    value: SnapValue::String(Bytes::from("1")),
834                    expire_ms: -1,
835                })
836                .unwrap();
837            writer
838                .write_entry(&SnapEntry {
839                    key: "b".into(),
840                    value: SnapValue::String(Bytes::from("2")),
841                    expire_ms: 60_000,
842                })
843                .unwrap();
844            writer.finish().unwrap();
845        }
846
847        let result = recover_shard(dir.path(), 0);
848        assert!(result.loaded_snapshot);
849        assert!(!result.replayed_aof);
850        assert_eq!(result.entries.len(), 2);
851    }
852
853    #[test]
854    fn aof_only_recovery() {
855        let dir = temp_dir();
856        let path = aof::aof_path(dir.path(), 0);
857
858        {
859            let mut writer = AofWriter::open(&path).unwrap();
860            writer
861                .write_record(&AofRecord::Set {
862                    key: "x".into(),
863                    value: Bytes::from("10"),
864                    expire_ms: -1,
865                })
866                .unwrap();
867            writer
868                .write_record(&AofRecord::Set {
869                    key: "y".into(),
870                    value: Bytes::from("20"),
871                    expire_ms: -1,
872                })
873                .unwrap();
874            writer.sync().unwrap();
875        }
876
877        let result = recover_shard(dir.path(), 0);
878        assert!(!result.loaded_snapshot);
879        assert!(result.replayed_aof);
880        assert_eq!(result.entries.len(), 2);
881    }
882
883    #[test]
884    fn snapshot_plus_aof_overlay() {
885        let dir = temp_dir();
886
887        // snapshot with key "a" = "old"
888        {
889            let path = snapshot::snapshot_path(dir.path(), 0);
890            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
891            writer
892                .write_entry(&SnapEntry {
893                    key: "a".into(),
894                    value: SnapValue::String(Bytes::from("old")),
895                    expire_ms: -1,
896                })
897                .unwrap();
898            writer.finish().unwrap();
899        }
900
901        // AOF overwrites "a" to "new" and adds "b"
902        {
903            let path = aof::aof_path(dir.path(), 0);
904            let mut writer = AofWriter::open(&path).unwrap();
905            writer
906                .write_record(&AofRecord::Set {
907                    key: "a".into(),
908                    value: Bytes::from("new"),
909                    expire_ms: -1,
910                })
911                .unwrap();
912            writer
913                .write_record(&AofRecord::Set {
914                    key: "b".into(),
915                    value: Bytes::from("added"),
916                    expire_ms: -1,
917                })
918                .unwrap();
919            writer.sync().unwrap();
920        }
921
922        let result = recover_shard(dir.path(), 0);
923        assert!(result.loaded_snapshot);
924        assert!(result.replayed_aof);
925
926        let map: HashMap<_, _> = result
927            .entries
928            .iter()
929            .map(|e| (e.key.as_str(), e.value.clone()))
930            .collect();
931        assert!(matches!(&map["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
932        assert!(matches!(&map["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
933    }
934
935    #[test]
936    fn del_removes_entry_during_replay() {
937        let dir = temp_dir();
938        let path = aof::aof_path(dir.path(), 0);
939
940        {
941            let mut writer = AofWriter::open(&path).unwrap();
942            writer
943                .write_record(&AofRecord::Set {
944                    key: "gone".into(),
945                    value: Bytes::from("temp"),
946                    expire_ms: -1,
947                })
948                .unwrap();
949            writer
950                .write_record(&AofRecord::Del { key: "gone".into() })
951                .unwrap();
952            writer.sync().unwrap();
953        }
954
955        let result = recover_shard(dir.path(), 0);
956        assert!(result.entries.is_empty());
957    }
958
959    #[test]
960    fn expired_entries_skipped() {
961        let dir = temp_dir();
962        let path = snapshot::snapshot_path(dir.path(), 0);
963
964        {
965            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
966            // this entry has 0ms remaining — already expired
967            writer
968                .write_entry(&SnapEntry {
969                    key: "dead".into(),
970                    value: SnapValue::String(Bytes::from("gone")),
971                    expire_ms: 0,
972                })
973                .unwrap();
974            // this one has plenty of time
975            writer
976                .write_entry(&SnapEntry {
977                    key: "alive".into(),
978                    value: SnapValue::String(Bytes::from("here")),
979                    expire_ms: 60_000,
980                })
981                .unwrap();
982            writer.finish().unwrap();
983        }
984
985        let result = recover_shard(dir.path(), 0);
986        assert_eq!(result.entries.len(), 1);
987        assert_eq!(result.entries[0].key, "alive");
988    }
989
990    #[test]
991    fn corrupt_snapshot_starts_empty() {
992        let dir = temp_dir();
993        let path = snapshot::snapshot_path(dir.path(), 0);
994
995        std::fs::write(&path, b"garbage data").unwrap();
996
997        let result = recover_shard(dir.path(), 0);
998        assert!(!result.loaded_snapshot);
999        assert!(result.entries.is_empty());
1000    }
1001
1002    #[test]
1003    fn sorted_set_snapshot_recovery() {
1004        let dir = temp_dir();
1005        let path = snapshot::snapshot_path(dir.path(), 0);
1006
1007        {
1008            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
1009            writer
1010                .write_entry(&SnapEntry {
1011                    key: "board".into(),
1012                    value: SnapValue::SortedSet(vec![
1013                        (100.0, "alice".into()),
1014                        (200.0, "bob".into()),
1015                    ]),
1016                    expire_ms: -1,
1017                })
1018                .unwrap();
1019            writer.finish().unwrap();
1020        }
1021
1022        let result = recover_shard(dir.path(), 0);
1023        assert!(result.loaded_snapshot);
1024        assert_eq!(result.entries.len(), 1);
1025        match &result.entries[0].value {
1026            RecoveredValue::SortedSet(members) => {
1027                assert_eq!(members.len(), 2);
1028                assert!(members.contains(&(100.0, "alice".into())));
1029                assert!(members.contains(&(200.0, "bob".into())));
1030            }
1031            other => panic!("expected SortedSet, got {other:?}"),
1032        }
1033    }
1034
1035    #[test]
1036    fn sorted_set_aof_replay() {
1037        let dir = temp_dir();
1038        let path = aof::aof_path(dir.path(), 0);
1039
1040        {
1041            let mut writer = AofWriter::open(&path).unwrap();
1042            writer
1043                .write_record(&AofRecord::ZAdd {
1044                    key: "board".into(),
1045                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
1046                })
1047                .unwrap();
1048            writer
1049                .write_record(&AofRecord::ZRem {
1050                    key: "board".into(),
1051                    members: vec!["alice".into()],
1052                })
1053                .unwrap();
1054            writer.sync().unwrap();
1055        }
1056
1057        let result = recover_shard(dir.path(), 0);
1058        assert!(result.replayed_aof);
1059        assert_eq!(result.entries.len(), 1);
1060        match &result.entries[0].value {
1061            RecoveredValue::SortedSet(members) => {
1062                assert_eq!(members.len(), 1);
1063                assert_eq!(members[0], (200.0, "bob".into()));
1064            }
1065            other => panic!("expected SortedSet, got {other:?}"),
1066        }
1067    }
1068
1069    #[test]
1070    fn sorted_set_zrem_auto_deletes_empty() {
1071        let dir = temp_dir();
1072        let path = aof::aof_path(dir.path(), 0);
1073
1074        {
1075            let mut writer = AofWriter::open(&path).unwrap();
1076            writer
1077                .write_record(&AofRecord::ZAdd {
1078                    key: "board".into(),
1079                    members: vec![(100.0, "alice".into())],
1080                })
1081                .unwrap();
1082            writer
1083                .write_record(&AofRecord::ZRem {
1084                    key: "board".into(),
1085                    members: vec!["alice".into()],
1086                })
1087                .unwrap();
1088            writer.sync().unwrap();
1089        }
1090
1091        let result = recover_shard(dir.path(), 0);
1092        assert!(result.entries.is_empty());
1093    }
1094
1095    #[test]
1096    fn expire_record_updates_ttl() {
1097        let dir = temp_dir();
1098        let path = aof::aof_path(dir.path(), 0);
1099
1100        {
1101            let mut writer = AofWriter::open(&path).unwrap();
1102            writer
1103                .write_record(&AofRecord::Set {
1104                    key: "k".into(),
1105                    value: Bytes::from("v"),
1106                    expire_ms: -1,
1107                })
1108                .unwrap();
1109            writer
1110                .write_record(&AofRecord::Expire {
1111                    key: "k".into(),
1112                    seconds: 300,
1113                })
1114                .unwrap();
1115            writer.sync().unwrap();
1116        }
1117
1118        let result = recover_shard(dir.path(), 0);
1119        assert_eq!(result.entries.len(), 1);
1120        assert!(result.entries[0].ttl.is_some());
1121    }
1122
1123    #[test]
1124    fn persist_record_removes_ttl() {
1125        let dir = temp_dir();
1126        let path = aof::aof_path(dir.path(), 0);
1127
1128        {
1129            let mut writer = AofWriter::open(&path).unwrap();
1130            writer
1131                .write_record(&AofRecord::Set {
1132                    key: "k".into(),
1133                    value: Bytes::from("v"),
1134                    expire_ms: 60_000,
1135                })
1136                .unwrap();
1137            writer
1138                .write_record(&AofRecord::Persist { key: "k".into() })
1139                .unwrap();
1140            writer.sync().unwrap();
1141        }
1142
1143        let result = recover_shard(dir.path(), 0);
1144        assert_eq!(result.entries.len(), 1);
1145        assert!(result.entries[0].ttl.is_none());
1146    }
1147
1148    #[test]
1149    fn incr_decr_replay() {
1150        let dir = temp_dir();
1151        let path = aof::aof_path(dir.path(), 0);
1152
1153        {
1154            let mut writer = AofWriter::open(&path).unwrap();
1155            writer
1156                .write_record(&AofRecord::Set {
1157                    key: "n".into(),
1158                    value: Bytes::from("10"),
1159                    expire_ms: -1,
1160                })
1161                .unwrap();
1162            writer
1163                .write_record(&AofRecord::Incr { key: "n".into() })
1164                .unwrap();
1165            writer
1166                .write_record(&AofRecord::Incr { key: "n".into() })
1167                .unwrap();
1168            writer
1169                .write_record(&AofRecord::Decr { key: "n".into() })
1170                .unwrap();
1171            // also test INCR on a new key
1172            writer
1173                .write_record(&AofRecord::Incr {
1174                    key: "fresh".into(),
1175                })
1176                .unwrap();
1177            writer.sync().unwrap();
1178        }
1179
1180        let result = recover_shard(dir.path(), 0);
1181        let map: HashMap<_, _> = result
1182            .entries
1183            .iter()
1184            .map(|e| (e.key.as_str(), e.value.clone()))
1185            .collect();
1186
1187        // 10 + 1 + 1 - 1 = 11
1188        match &map["n"] {
1189            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
1190            other => panic!("expected String(\"11\"), got {other:?}"),
1191        }
1192        // 0 + 1 = 1
1193        match &map["fresh"] {
1194            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
1195            other => panic!("expected String(\"1\"), got {other:?}"),
1196        }
1197    }
1198
1199    #[test]
1200    fn pexpire_record_sets_ttl() {
1201        let dir = temp_dir();
1202        let path = aof::aof_path(dir.path(), 0);
1203
1204        {
1205            let mut writer = AofWriter::open(&path).unwrap();
1206            writer
1207                .write_record(&AofRecord::Set {
1208                    key: "k".into(),
1209                    value: Bytes::from("v"),
1210                    expire_ms: -1,
1211                })
1212                .unwrap();
1213            writer
1214                .write_record(&AofRecord::Pexpire {
1215                    key: "k".into(),
1216                    milliseconds: 5000,
1217                })
1218                .unwrap();
1219            writer.sync().unwrap();
1220        }
1221
1222        let result = recover_shard(dir.path(), 0);
1223        assert_eq!(result.entries.len(), 1);
1224        assert!(result.entries[0].ttl.is_some());
1225    }
1226
1227    #[cfg(feature = "vector")]
1228    #[test]
1229    fn vector_snapshot_recovery() {
1230        let dir = temp_dir();
1231        let path = snapshot::snapshot_path(dir.path(), 0);
1232
1233        {
1234            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
1235            writer
1236                .write_entry(&SnapEntry {
1237                    key: "embeddings".into(),
1238                    value: SnapValue::Vector {
1239                        metric: 0,
1240                        quantization: 0,
1241                        connectivity: 16,
1242                        expansion_add: 64,
1243                        dim: 3,
1244                        elements: vec![
1245                            ("doc1".into(), vec![1.0, 0.0, 0.0]),
1246                            ("doc2".into(), vec![0.0, 1.0, 0.0]),
1247                        ],
1248                    },
1249                    expire_ms: -1,
1250                })
1251                .unwrap();
1252            writer.finish().unwrap();
1253        }
1254
1255        let result = recover_shard(dir.path(), 0);
1256        assert!(result.loaded_snapshot);
1257        assert_eq!(result.entries.len(), 1);
1258        match &result.entries[0].value {
1259            RecoveredValue::Vector {
1260                metric,
1261                quantization,
1262                elements,
1263                ..
1264            } => {
1265                assert_eq!(*metric, 0);
1266                assert_eq!(*quantization, 0);
1267                assert_eq!(elements.len(), 2);
1268                // dim is inferred from the vector length
1269                assert_eq!(elements[0].1.len(), 3);
1270            }
1271            other => panic!("expected Vector, got {other:?}"),
1272        }
1273    }
1274
1275    #[cfg(feature = "vector")]
1276    #[test]
1277    fn vector_aof_replay() {
1278        let dir = temp_dir();
1279        let path = aof::aof_path(dir.path(), 0);
1280
1281        {
1282            let mut writer = AofWriter::open(&path).unwrap();
1283            writer
1284                .write_record(&AofRecord::VAdd {
1285                    key: "vecs".into(),
1286                    element: "a".into(),
1287                    vector: vec![1.0, 0.0, 0.0],
1288                    metric: 0,
1289                    quantization: 0,
1290                    connectivity: 16,
1291                    expansion_add: 64,
1292                })
1293                .unwrap();
1294            writer
1295                .write_record(&AofRecord::VAdd {
1296                    key: "vecs".into(),
1297                    element: "b".into(),
1298                    vector: vec![0.0, 1.0, 0.0],
1299                    metric: 0,
1300                    quantization: 0,
1301                    connectivity: 16,
1302                    expansion_add: 64,
1303                })
1304                .unwrap();
1305            writer
1306                .write_record(&AofRecord::VRem {
1307                    key: "vecs".into(),
1308                    element: "a".into(),
1309                })
1310                .unwrap();
1311            writer.sync().unwrap();
1312        }
1313
1314        let result = recover_shard(dir.path(), 0);
1315        assert!(result.replayed_aof);
1316        assert_eq!(result.entries.len(), 1);
1317        match &result.entries[0].value {
1318            RecoveredValue::Vector { elements, .. } => {
1319                assert_eq!(elements.len(), 1);
1320                assert_eq!(elements[0].0, "b");
1321            }
1322            other => panic!("expected Vector, got {other:?}"),
1323        }
1324    }
1325
1326    #[cfg(feature = "vector")]
1327    #[test]
1328    fn vector_vrem_auto_deletes_empty() {
1329        let dir = temp_dir();
1330        let path = aof::aof_path(dir.path(), 0);
1331
1332        {
1333            let mut writer = AofWriter::open(&path).unwrap();
1334            writer
1335                .write_record(&AofRecord::VAdd {
1336                    key: "vecs".into(),
1337                    element: "only".into(),
1338                    vector: vec![1.0, 2.0],
1339                    metric: 0,
1340                    quantization: 0,
1341                    connectivity: 16,
1342                    expansion_add: 64,
1343                })
1344                .unwrap();
1345            writer
1346                .write_record(&AofRecord::VRem {
1347                    key: "vecs".into(),
1348                    element: "only".into(),
1349                })
1350                .unwrap();
1351            writer.sync().unwrap();
1352        }
1353
1354        let result = recover_shard(dir.path(), 0);
1355        assert!(result.entries.is_empty());
1356    }
1357
1358    #[cfg(feature = "protobuf")]
1359    #[test]
1360    fn proto_schemas_recovered_from_aof() {
1361        let dir = temp_dir();
1362        let path = aof::aof_path(dir.path(), 0);
1363
1364        {
1365            let mut writer = AofWriter::open(&path).unwrap();
1366            writer
1367                .write_record(&AofRecord::ProtoRegister {
1368                    name: "users".into(),
1369                    descriptor: Bytes::from("fake-descriptor-a"),
1370                })
1371                .unwrap();
1372            // a proto value that depends on the schema
1373            writer
1374                .write_record(&AofRecord::ProtoSet {
1375                    key: "user:1".into(),
1376                    type_name: "test.User".into(),
1377                    data: Bytes::from("some-proto-data"),
1378                    expire_ms: -1,
1379                })
1380                .unwrap();
1381            // re-registration of same schema (last wins)
1382            writer
1383                .write_record(&AofRecord::ProtoRegister {
1384                    name: "users".into(),
1385                    descriptor: Bytes::from("fake-descriptor-b"),
1386                })
1387                .unwrap();
1388            writer.sync().unwrap();
1389        }
1390
1391        let result = recover_shard(dir.path(), 0);
1392        assert!(result.replayed_aof);
1393        assert_eq!(result.entries.len(), 1);
1394
1395        // schemas should be collected with last-wins dedup
1396        assert_eq!(result.schemas.len(), 1);
1397        let (name, desc) = &result.schemas[0];
1398        assert_eq!(name, "users");
1399        assert_eq!(desc, &Bytes::from("fake-descriptor-b"));
1400    }
1401}