1use std::fs;
2use std::path::{Path, PathBuf};
3
4use lz4_flex::{compress_prepend_size, decompress_size_prepended};
5
6use crate::storage::{Bytes, StoredEntry, hash_key};
7use crate::{FastCacheError, Result};
8
/// Magic bytes that open every snapshot body.
const SNAPSHOT_MAGIC: &[u8; 8] = b"FCSNAP1\0";
/// Snapshot body format version written right after the magic bytes.
const SNAPSHOT_VERSION: u32 = 1;
/// Fixed body header: magic, version (`u32`), timestamp (`u64`), entry count (`u64`).
const SNAPSHOT_HEADER_LEN: usize = SNAPSHOT_MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>();
/// Per-entry header: key length (`u32`), value length (`u32`), expiration (`u64`).
const SNAPSHOT_ENTRY_HEADER_LEN: usize =
    2 * std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
/// File extension that marks an LZ4-compressed snapshot.
const SNAPSHOT_COMPRESSED_EXT: &str = "lz4";
14
/// A snapshot read back from disk, together with the file it was decoded from.
#[derive(Debug, Clone)]
pub struct LoadedSnapshot {
    /// Path of the snapshot file that was read.
    pub path: PathBuf,
    /// Timestamp stored in the snapshot header (milliseconds; epoch presumably Unix — confirm with writer).
    pub timestamp_ms: u64,
    /// Entries decoded from the snapshot, in the order they appear in the file.
    pub entries: Vec<StoredEntry>,
}
24
/// Compression applied to a snapshot body before it is written to disk.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SnapshotCompression {
    /// The body is stored as-is in a `.bin` file.
    None,
    /// The body is LZ4-compressed, with its uncompressed size prepended, in a `.bin.lz4` file.
    Lz4,
}
33
34impl SnapshotCompression {
35 pub const fn from_enabled(enabled: bool) -> Self {
36 if enabled { Self::Lz4 } else { Self::None }
37 }
38
39 fn file_name(self, timestamp_ms: u64) -> String {
40 match self {
41 Self::None => format!("snapshot-{timestamp_ms}.bin"),
42 Self::Lz4 => format!("snapshot-{timestamp_ms}.bin.lz4"),
43 }
44 }
45
46 fn encode(self, body: Vec<u8>) -> Vec<u8> {
47 match self {
48 Self::None => body,
49 Self::Lz4 => compress_prepend_size(&body),
50 }
51 }
52
53 fn decode_path(path: &Path, bytes: Vec<u8>) -> Result<Vec<u8>> {
54 if path
55 .extension()
56 .is_some_and(|ext| ext == SNAPSHOT_COMPRESSED_EXT)
57 {
58 decompress_size_prepended(&bytes).map_err(|error| {
59 FastCacheError::Persistence(format!("invalid compressed snapshot: {error}"))
60 })
61 } else {
62 Ok(bytes)
63 }
64 }
65}
66
/// Filesystem-backed snapshot store.
///
/// Each snapshot is a separate file inside `data_dir`, named after its timestamp.
#[derive(Debug, Clone)]
pub struct SnapshotStore {
    /// Directory that holds the snapshot files.
    data_dir: PathBuf,
}
77
78impl SnapshotStore {
79 pub fn new(data_dir: impl AsRef<Path>) -> Self {
80 Self {
81 data_dir: data_dir.as_ref().to_path_buf(),
82 }
83 }
84
85 fn write(
86 &self,
87 entries: &[StoredEntry],
88 timestamp_ms: u64,
89 compression: SnapshotCompression,
90 ) -> Result<PathBuf> {
91 fs::create_dir_all(&self.data_dir)?;
92
93 let body = SnapshotCodec::encode(entries, timestamp_ms)?;
94 let bytes = compression.encode(body);
95 let path = SnapshotName::path(&self.data_dir, timestamp_ms, compression);
96 fs::write(&path, bytes)?;
97 Ok(path)
98 }
99
100 fn load_latest(&self) -> Result<Option<LoadedSnapshot>> {
101 let Some(path) = self.latest_path()? else {
102 return Ok(None);
103 };
104 let bytes = fs::read(&path)?;
105 let raw = SnapshotCompression::decode_path(&path, bytes)?;
106 let (timestamp_ms, entries) = SnapshotCodec::decode(&raw)?;
107 Ok(Some(LoadedSnapshot {
108 path,
109 timestamp_ms,
110 entries,
111 }))
112 }
113
114 fn latest_path(&self) -> Result<Option<PathBuf>> {
115 let mut snapshots = Vec::new();
116 for entry in fs::read_dir(&self.data_dir)? {
117 let entry = entry?;
118 let path = entry.path();
119 if SnapshotName::matches(&path) {
120 snapshots.push(path);
121 }
122 }
123 snapshots.sort();
124 Ok(snapshots.pop())
125 }
126}
127
/// Persistence boundary for writing and loading cache snapshots.
pub trait SnapshotRepository {
    /// Persists `entries` as a snapshot stamped with `timestamp_ms`, using `compression`.
    ///
    /// Returns the path of the written snapshot.
    fn write_snapshot(
        &self,
        entries: &[StoredEntry],
        timestamp_ms: u64,
        compression: SnapshotCompression,
    ) -> Result<PathBuf>;

    /// Loads the most recent snapshot, or `Ok(None)` if there is none to load.
    fn load_latest_snapshot(&self) -> Result<Option<LoadedSnapshot>>;
}
141
142impl SnapshotRepository for SnapshotStore {
143 fn write_snapshot(
144 &self,
145 entries: &[StoredEntry],
146 timestamp_ms: u64,
147 compression: SnapshotCompression,
148 ) -> Result<PathBuf> {
149 self.write(entries, timestamp_ms, compression)
150 }
151
152 fn load_latest_snapshot(&self) -> Result<Option<LoadedSnapshot>> {
153 self.load_latest()
154 }
155}
156
157struct SnapshotCodec;
158
159impl SnapshotCodec {
160 fn encode(entries: &[StoredEntry], timestamp_ms: u64) -> Result<Vec<u8>> {
161 let mut entries = entries.to_vec();
162 entries.sort_by_key(|entry| hash_key(entry.key.as_ref()));
163
164 let mut body = Vec::with_capacity(
165 SNAPSHOT_HEADER_LEN + entries.len().saturating_mul(SNAPSHOT_ENTRY_HEADER_LEN),
166 );
167 body.extend_from_slice(SNAPSHOT_MAGIC);
168 body.extend_from_slice(&SNAPSHOT_VERSION.to_le_bytes());
169 body.extend_from_slice(×tamp_ms.to_le_bytes());
170 body.extend_from_slice(&(entries.len() as u64).to_le_bytes());
171 for entry in entries {
172 Self::encode_entry(&mut body, &entry)?;
173 }
174 Ok(body)
175 }
176
177 fn encode_entry(body: &mut Vec<u8>, entry: &StoredEntry) -> Result<()> {
178 let key_len = Self::encoded_len(entry.key.len(), "snapshot key is too large")?;
179 let value_len = Self::encoded_len(entry.value.len(), "snapshot value is too large")?;
180 body.extend_from_slice(&key_len.to_le_bytes());
181 body.extend_from_slice(&value_len.to_le_bytes());
182 body.extend_from_slice(&entry.expire_at_ms.unwrap_or(u64::MAX).to_le_bytes());
183 body.extend_from_slice(entry.key.as_ref());
184 body.extend_from_slice(entry.value.as_ref());
185 Ok(())
186 }
187
188 fn encoded_len(len: usize, message: &'static str) -> Result<u32> {
189 u32::try_from(len).map_err(|_| FastCacheError::Persistence(message.into()))
190 }
191
192 fn decode(raw: &[u8]) -> Result<(u64, Vec<StoredEntry>)> {
193 match raw {
194 bytes if bytes.len() < SNAPSHOT_HEADER_LEN => Err(FastCacheError::Persistence(
195 "snapshot header is truncated".into(),
196 )),
197 bytes if !bytes.starts_with(SNAPSHOT_MAGIC) => Err(FastCacheError::Persistence(
198 "snapshot magic mismatch".into(),
199 )),
200 bytes => Self::decode_validated(bytes),
201 }
202 }
203
204 fn decode_validated(raw: &[u8]) -> Result<(u64, Vec<StoredEntry>)> {
205 let mut cursor = SNAPSHOT_MAGIC.len();
206 match Self::read_u32(raw, &mut cursor, "snapshot version")? {
207 SNAPSHOT_VERSION => Self::decode_body(raw, &mut cursor),
208 version => Err(FastCacheError::Persistence(format!(
209 "unsupported snapshot version: {version}"
210 ))),
211 }
212 }
213
214 fn decode_body(raw: &[u8], cursor: &mut usize) -> Result<(u64, Vec<StoredEntry>)> {
215 let timestamp_ms = Self::read_u64(raw, cursor, "snapshot timestamp")?;
216 let entry_count = usize::try_from(Self::read_u64(raw, cursor, "snapshot entry count")?)
217 .map_err(|_| FastCacheError::Persistence("snapshot entry count is too large".into()))?;
218 let mut entries = Vec::with_capacity(entry_count);
219 for _ in 0..entry_count {
220 entries.push(Self::decode_entry(raw, cursor)?);
221 }
222 Ok((timestamp_ms, entries))
223 }
224
225 fn decode_entry(raw: &[u8], cursor: &mut usize) -> Result<StoredEntry> {
226 if raw.len().saturating_sub(*cursor) < SNAPSHOT_ENTRY_HEADER_LEN {
227 return Err(FastCacheError::Persistence(
228 "snapshot entry header is truncated".into(),
229 ));
230 }
231
232 let key_len = Self::read_u32(raw, cursor, "snapshot key length")? as usize;
233 let value_len = Self::read_u32(raw, cursor, "snapshot value length")? as usize;
234 let expire_raw = Self::read_u64(raw, cursor, "snapshot expiration")?;
235 let body_len = key_len.saturating_add(value_len);
236 if raw.len().saturating_sub(*cursor) < body_len {
237 return Err(FastCacheError::Persistence(
238 "snapshot entry body is truncated".into(),
239 ));
240 }
241
242 let key = raw[*cursor..*cursor + key_len].to_vec();
243 *cursor += key_len;
244 let value = raw[*cursor..*cursor + value_len].to_vec();
245 *cursor += value_len;
246 Ok(StoredEntry {
247 key: Bytes::from(key),
248 value: Bytes::from(value),
249 expire_at_ms: if expire_raw == u64::MAX {
250 None
251 } else {
252 Some(expire_raw)
253 },
254 })
255 }
256
257 fn read_u32(raw: &[u8], cursor: &mut usize, field: &str) -> Result<u32> {
258 let bytes = Self::read_exact(raw, cursor, 4, field)?;
259 let mut value = [0; 4];
260 value.copy_from_slice(bytes);
261 Ok(u32::from_le_bytes(value))
262 }
263
264 fn read_u64(raw: &[u8], cursor: &mut usize, field: &str) -> Result<u64> {
265 let bytes = Self::read_exact(raw, cursor, 8, field)?;
266 let mut value = [0; 8];
267 value.copy_from_slice(bytes);
268 Ok(u64::from_le_bytes(value))
269 }
270
271 fn read_exact<'a>(
272 raw: &'a [u8],
273 cursor: &mut usize,
274 len: usize,
275 field: &str,
276 ) -> Result<&'a [u8]> {
277 if raw.len().saturating_sub(*cursor) < len {
278 return Err(FastCacheError::Persistence(format!("{field} is truncated")));
279 }
280 let bytes = &raw[*cursor..*cursor + len];
281 *cursor += len;
282 Ok(bytes)
283 }
284}
285
286struct SnapshotName;
287
288impl SnapshotName {
289 fn path(data_dir: &Path, timestamp_ms: u64, compression: SnapshotCompression) -> PathBuf {
290 data_dir.join(compression.file_name(timestamp_ms))
291 }
292
293 fn matches(path: &Path) -> bool {
294 path.file_name()
295 .and_then(|value| value.to_str())
296 .is_some_and(|name| name.starts_with("snapshot-"))
297 }
298}