// fast_cache/persistence/snapshot.rs

use std::fs;
use std::path::{Path, PathBuf};

use lz4_flex::{compress_prepend_size, decompress_size_prepended};

use crate::storage::{Bytes, StoredEntry, hash_key};
use crate::{FastCacheError, Result};

/// Eight-byte tag written at the very start of every snapshot body.
const SNAPSHOT_MAGIC: &[u8; 8] = b"FCSNAP1\0";
/// Body format version, stored right after the magic tag.
const SNAPSHOT_VERSION: u32 = 1;
/// Size of the fixed body header: magic, `u32` version, `u64` timestamp,
/// and `u64` entry count.
const SNAPSHOT_HEADER_LEN: usize = SNAPSHOT_MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>();
/// Size of the fixed per-entry header: `u32` key length, `u32` value length,
/// and `u64` expiration.
const SNAPSHOT_ENTRY_HEADER_LEN: usize =
    std::mem::size_of::<u32>() + std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
/// File extension that marks a snapshot whose body is lz4-compressed.
const SNAPSHOT_COMPRESSED_EXT: &str = "lz4";

15#[derive(Debug, Clone)]
16pub struct LoadedSnapshot {
17    /// Path of the snapshot file that was loaded.
18    pub path: PathBuf,
19    /// Snapshot timestamp captured when the file was written.
20    pub timestamp_ms: u64,
21    /// Live cache entries encoded in the snapshot.
22    pub entries: Vec<StoredEntry>,
23}
24
/// Selects how an encoded snapshot body is stored on disk.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SnapshotCompression {
    /// Write the encoded body unchanged.
    None,
    /// Compress the encoded body with lz4, prefixing the uncompressed size.
    Lz4,
}

34impl SnapshotCompression {
35    pub const fn from_enabled(enabled: bool) -> Self {
36        if enabled { Self::Lz4 } else { Self::None }
37    }
38
39    fn file_name(self, timestamp_ms: u64) -> String {
40        match self {
41            Self::None => format!("snapshot-{timestamp_ms}.bin"),
42            Self::Lz4 => format!("snapshot-{timestamp_ms}.bin.lz4"),
43        }
44    }
45
46    fn encode(self, body: Vec<u8>) -> Vec<u8> {
47        match self {
48            Self::None => body,
49            Self::Lz4 => compress_prepend_size(&body),
50        }
51    }
52
53    fn decode_path(path: &Path, bytes: Vec<u8>) -> Result<Vec<u8>> {
54        if path
55            .extension()
56            .is_some_and(|ext| ext == SNAPSHOT_COMPRESSED_EXT)
57        {
58            decompress_size_prepended(&bytes).map_err(|error| {
59                FastCacheError::Persistence(format!("invalid compressed snapshot: {error}"))
60            })
61        } else {
62            Ok(bytes)
63        }
64    }
65}
66
/// Snapshot repository that keeps its files in a directory on the local
/// filesystem.
///
/// This type handles only locating, reading, and writing snapshot files.
/// Converting entries to bytes (and back) is delegated to `SnapshotCodec`.
/// That way the on-disk format can change without touching directory
/// handling or spreading parsing logic through persistence recovery.
#[derive(Debug, Clone)]
pub struct SnapshotStore {
    /// Directory that holds the snapshot files.
    data_dir: PathBuf,
}

78impl SnapshotStore {
79    pub fn new(data_dir: impl AsRef<Path>) -> Self {
80        Self {
81            data_dir: data_dir.as_ref().to_path_buf(),
82        }
83    }
84
85    fn write(
86        &self,
87        entries: &[StoredEntry],
88        timestamp_ms: u64,
89        compression: SnapshotCompression,
90    ) -> Result<PathBuf> {
91        fs::create_dir_all(&self.data_dir)?;
92
93        let body = SnapshotCodec::encode(entries, timestamp_ms)?;
94        let bytes = compression.encode(body);
95        let path = SnapshotName::path(&self.data_dir, timestamp_ms, compression);
96        fs::write(&path, bytes)?;
97        Ok(path)
98    }
99
100    fn load_latest(&self) -> Result<Option<LoadedSnapshot>> {
101        let Some(path) = self.latest_path()? else {
102            return Ok(None);
103        };
104        let bytes = fs::read(&path)?;
105        let raw = SnapshotCompression::decode_path(&path, bytes)?;
106        let (timestamp_ms, entries) = SnapshotCodec::decode(&raw)?;
107        Ok(Some(LoadedSnapshot {
108            path,
109            timestamp_ms,
110            entries,
111        }))
112    }
113
114    fn latest_path(&self) -> Result<Option<PathBuf>> {
115        let mut snapshots = Vec::new();
116        for entry in fs::read_dir(&self.data_dir)? {
117            let entry = entry?;
118            let path = entry.path();
119            if SnapshotName::matches(&path) {
120                snapshots.push(path);
121            }
122        }
123        snapshots.sort();
124        Ok(snapshots.pop())
125    }
126}
127
128/// Snapshot persistence behavior used by WAL recovery and embedded callers.
129pub trait SnapshotRepository {
130    /// Writes `entries` into a timestamped snapshot file and returns its path.
131    fn write_snapshot(
132        &self,
133        entries: &[StoredEntry],
134        timestamp_ms: u64,
135        compression: SnapshotCompression,
136    ) -> Result<PathBuf>;
137
138    /// Loads the newest snapshot file in the repository, if one exists.
139    fn load_latest_snapshot(&self) -> Result<Option<LoadedSnapshot>>;
140}
141
142impl SnapshotRepository for SnapshotStore {
143    fn write_snapshot(
144        &self,
145        entries: &[StoredEntry],
146        timestamp_ms: u64,
147        compression: SnapshotCompression,
148    ) -> Result<PathBuf> {
149        self.write(entries, timestamp_ms, compression)
150    }
151
152    fn load_latest_snapshot(&self) -> Result<Option<LoadedSnapshot>> {
153        self.load_latest()
154    }
155}
156
157struct SnapshotCodec;
158
159impl SnapshotCodec {
160    fn encode(entries: &[StoredEntry], timestamp_ms: u64) -> Result<Vec<u8>> {
161        let mut entries = entries.to_vec();
162        entries.sort_by_key(|entry| hash_key(entry.key.as_ref()));
163
164        let mut body = Vec::with_capacity(
165            SNAPSHOT_HEADER_LEN + entries.len().saturating_mul(SNAPSHOT_ENTRY_HEADER_LEN),
166        );
167        body.extend_from_slice(SNAPSHOT_MAGIC);
168        body.extend_from_slice(&SNAPSHOT_VERSION.to_le_bytes());
169        body.extend_from_slice(&timestamp_ms.to_le_bytes());
170        body.extend_from_slice(&(entries.len() as u64).to_le_bytes());
171        for entry in entries {
172            Self::encode_entry(&mut body, &entry)?;
173        }
174        Ok(body)
175    }
176
177    fn encode_entry(body: &mut Vec<u8>, entry: &StoredEntry) -> Result<()> {
178        let key_len = Self::encoded_len(entry.key.len(), "snapshot key is too large")?;
179        let value_len = Self::encoded_len(entry.value.len(), "snapshot value is too large")?;
180        body.extend_from_slice(&key_len.to_le_bytes());
181        body.extend_from_slice(&value_len.to_le_bytes());
182        body.extend_from_slice(&entry.expire_at_ms.unwrap_or(u64::MAX).to_le_bytes());
183        body.extend_from_slice(entry.key.as_ref());
184        body.extend_from_slice(entry.value.as_ref());
185        Ok(())
186    }
187
188    fn encoded_len(len: usize, message: &'static str) -> Result<u32> {
189        u32::try_from(len).map_err(|_| FastCacheError::Persistence(message.into()))
190    }
191
192    fn decode(raw: &[u8]) -> Result<(u64, Vec<StoredEntry>)> {
193        match raw {
194            bytes if bytes.len() < SNAPSHOT_HEADER_LEN => Err(FastCacheError::Persistence(
195                "snapshot header is truncated".into(),
196            )),
197            bytes if !bytes.starts_with(SNAPSHOT_MAGIC) => Err(FastCacheError::Persistence(
198                "snapshot magic mismatch".into(),
199            )),
200            bytes => Self::decode_validated(bytes),
201        }
202    }
203
204    fn decode_validated(raw: &[u8]) -> Result<(u64, Vec<StoredEntry>)> {
205        let mut cursor = SNAPSHOT_MAGIC.len();
206        match Self::read_u32(raw, &mut cursor, "snapshot version")? {
207            SNAPSHOT_VERSION => Self::decode_body(raw, &mut cursor),
208            version => Err(FastCacheError::Persistence(format!(
209                "unsupported snapshot version: {version}"
210            ))),
211        }
212    }
213
214    fn decode_body(raw: &[u8], cursor: &mut usize) -> Result<(u64, Vec<StoredEntry>)> {
215        let timestamp_ms = Self::read_u64(raw, cursor, "snapshot timestamp")?;
216        let entry_count = usize::try_from(Self::read_u64(raw, cursor, "snapshot entry count")?)
217            .map_err(|_| FastCacheError::Persistence("snapshot entry count is too large".into()))?;
218        let mut entries = Vec::with_capacity(entry_count);
219        for _ in 0..entry_count {
220            entries.push(Self::decode_entry(raw, cursor)?);
221        }
222        Ok((timestamp_ms, entries))
223    }
224
225    fn decode_entry(raw: &[u8], cursor: &mut usize) -> Result<StoredEntry> {
226        if raw.len().saturating_sub(*cursor) < SNAPSHOT_ENTRY_HEADER_LEN {
227            return Err(FastCacheError::Persistence(
228                "snapshot entry header is truncated".into(),
229            ));
230        }
231
232        let key_len = Self::read_u32(raw, cursor, "snapshot key length")? as usize;
233        let value_len = Self::read_u32(raw, cursor, "snapshot value length")? as usize;
234        let expire_raw = Self::read_u64(raw, cursor, "snapshot expiration")?;
235        let body_len = key_len.saturating_add(value_len);
236        if raw.len().saturating_sub(*cursor) < body_len {
237            return Err(FastCacheError::Persistence(
238                "snapshot entry body is truncated".into(),
239            ));
240        }
241
242        let key = raw[*cursor..*cursor + key_len].to_vec();
243        *cursor += key_len;
244        let value = raw[*cursor..*cursor + value_len].to_vec();
245        *cursor += value_len;
246        Ok(StoredEntry {
247            key: Bytes::from(key),
248            value: Bytes::from(value),
249            expire_at_ms: if expire_raw == u64::MAX {
250                None
251            } else {
252                Some(expire_raw)
253            },
254        })
255    }
256
257    fn read_u32(raw: &[u8], cursor: &mut usize, field: &str) -> Result<u32> {
258        let bytes = Self::read_exact(raw, cursor, 4, field)?;
259        let mut value = [0; 4];
260        value.copy_from_slice(bytes);
261        Ok(u32::from_le_bytes(value))
262    }
263
264    fn read_u64(raw: &[u8], cursor: &mut usize, field: &str) -> Result<u64> {
265        let bytes = Self::read_exact(raw, cursor, 8, field)?;
266        let mut value = [0; 8];
267        value.copy_from_slice(bytes);
268        Ok(u64::from_le_bytes(value))
269    }
270
271    fn read_exact<'a>(
272        raw: &'a [u8],
273        cursor: &mut usize,
274        len: usize,
275        field: &str,
276    ) -> Result<&'a [u8]> {
277        if raw.len().saturating_sub(*cursor) < len {
278            return Err(FastCacheError::Persistence(format!("{field} is truncated")));
279        }
280        let bytes = &raw[*cursor..*cursor + len];
281        *cursor += len;
282        Ok(bytes)
283    }
284}
285
286struct SnapshotName;
287
288impl SnapshotName {
289    fn path(data_dir: &Path, timestamp_ms: u64, compression: SnapshotCompression) -> PathBuf {
290        data_dir.join(compression.file_name(timestamp_ms))
291    }
292
293    fn matches(path: &Path) -> bool {
294        path.file_name()
295            .and_then(|value| value.to_str())
296            .is_some_and(|name| name.starts_with("snapshot-"))
297    }
298}