use std::fs;
use std::path::{Path, PathBuf};
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
use crate::storage::{Bytes, StoredEntry, hash_key};
use crate::{FastCacheError, Result};
/// Bytes that open every uncompressed snapshot body; `SnapshotCodec::decode` rejects anything else.
const SNAPSHOT_MAGIC: &[u8; 8] = b"FCSNAP1\0";
/// On-disk format version written right after the magic.
const SNAPSHOT_VERSION: u32 = 1;
/// Fixed header: magic, version (`u32`), timestamp (`u64`), entry count (`u64`).
const SNAPSHOT_HEADER_LEN: usize = SNAPSHOT_MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>();
/// Per-entry header: key length (`u32`), value length (`u32`), expiry (`u64`).
const SNAPSHOT_ENTRY_HEADER_LEN: usize =
    2 * std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
/// File extension that marks an LZ4-compressed snapshot.
const SNAPSHOT_COMPRESSED_EXT: &str = "lz4";
/// A snapshot that was read back from disk and decoded.
#[derive(Debug, Clone)]
pub struct LoadedSnapshot {
    /// File the snapshot bytes were read from.
    pub path: PathBuf,
    /// Timestamp stored in the snapshot header, in milliseconds.
    pub timestamp_ms: u64,
    /// Decoded entries, in the order they appear in the file.
    pub entries: Vec<StoredEntry>,
}
/// How a snapshot body is stored on disk.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SnapshotCompression {
    /// The encoded body is written as-is, to `snapshot-<timestamp>.bin`.
    None,
    /// The body is compressed with `lz4_flex::compress_prepend_size` and written to
    /// `snapshot-<timestamp>.bin.lz4`.
    Lz4,
}
impl SnapshotCompression {
    /// Maps a "compression enabled" flag to the matching mode.
    pub const fn from_enabled(enabled: bool) -> Self {
        if enabled { Self::Lz4 } else { Self::None }
    }
    /// File name (no directory) for a snapshot taken at `timestamp_ms`.
    ///
    /// The compressed suffix is built from `SNAPSHOT_COMPRESSED_EXT`, the same constant
    /// `decode_path` tests for. Writer and reader therefore cannot drift apart if the
    /// extension is ever changed.
    fn file_name(self, timestamp_ms: u64) -> String {
        match self {
            Self::None => format!("snapshot-{timestamp_ms}.bin"),
            Self::Lz4 => format!("snapshot-{timestamp_ms}.bin.{SNAPSHOT_COMPRESSED_EXT}"),
        }
    }
    /// Applies this mode to an encoded snapshot body.
    fn encode(self, body: Vec<u8>) -> Vec<u8> {
        match self {
            Self::None => body,
            Self::Lz4 => compress_prepend_size(&body),
        }
    }
    /// Reverses `encode` for the bytes of the file at `path`.
    ///
    /// The mode is inferred from the file extension rather than passed in. A
    /// `SNAPSHOT_COMPRESSED_EXT` extension means size-prepended LZ4; anything else is
    /// returned unchanged.
    ///
    /// # Errors
    /// `FastCacheError::Persistence` if a compressed file fails to decompress.
    fn decode_path(path: &Path, bytes: Vec<u8>) -> Result<Vec<u8>> {
        if path
            .extension()
            .is_some_and(|ext| ext == SNAPSHOT_COMPRESSED_EXT)
        {
            decompress_size_prepended(&bytes).map_err(|error| {
                FastCacheError::Persistence(format!("invalid compressed snapshot: {error}"))
            })
        } else {
            Ok(bytes)
        }
    }
}
/// Filesystem-backed snapshot storage rooted at a single data directory.
#[derive(Debug, Clone)]
pub struct SnapshotStore {
    data_dir: PathBuf,
}
impl SnapshotStore {
    /// Creates a store that reads and writes snapshots in `data_dir`.
    ///
    /// Nothing is touched on disk here; `write` creates the directory on demand.
    pub fn new(data_dir: impl AsRef<Path>) -> Self {
        Self {
            data_dir: data_dir.as_ref().to_path_buf(),
        }
    }
    /// Encodes `entries`, applies `compression`, and stores the result as the snapshot for
    /// `timestamp_ms`. Returns the final file path.
    ///
    /// The bytes are first written and synced to a hidden temporary file in the same
    /// directory, then renamed into place. A crash mid-write therefore cannot leave a
    /// truncated file under a snapshot name, assuming a filesystem with atomic rename.
    /// `load_latest` would otherwise pick such a file up and fail to decode it.
    fn write(
        &self,
        entries: &[StoredEntry],
        timestamp_ms: u64,
        compression: SnapshotCompression,
    ) -> Result<PathBuf> {
        fs::create_dir_all(&self.data_dir)?;
        let body = SnapshotCodec::encode(entries, timestamp_ms)?;
        let bytes = compression.encode(body);
        let path = SnapshotName::path(&self.data_dir, timestamp_ms, compression);
        // The leading dot keeps the temp file from starting with "snapshot-", so
        // `SnapshotName::matches` never treats it as a snapshot.
        let tmp_path = self
            .data_dir
            .join(format!(".{}.tmp", compression.file_name(timestamp_ms)));
        Self::write_atomically(&path, &tmp_path, &bytes)?;
        Ok(path)
    }
    /// Publishes `bytes` at `path` via `tmp_path`, removing the temp file if anything fails.
    fn write_atomically(path: &Path, tmp_path: &Path, bytes: &[u8]) -> Result<()> {
        let outcome = Self::write_synced_then_rename(path, tmp_path, bytes);
        if outcome.is_err() {
            // Best effort: the original error is the one worth reporting, and a leftover
            // temp file is never mistaken for a snapshot.
            let _ = fs::remove_file(tmp_path);
        }
        outcome?;
        Ok(())
    }
    /// Writes `bytes` to `tmp_path`, flushes them to disk, and renames the file to `path`.
    fn write_synced_then_rename(
        path: &Path,
        tmp_path: &Path,
        bytes: &[u8],
    ) -> std::io::Result<()> {
        let mut file = fs::File::create(tmp_path)?;
        std::io::Write::write_all(&mut file, bytes)?;
        file.sync_all()?;
        // Close the handle before renaming; some platforms refuse to rename an open file.
        drop(file);
        fs::rename(tmp_path, path)
    }
    /// Reads, decompresses (based on the file extension), and decodes the newest snapshot.
    ///
    /// Returns `Ok(None)` when no snapshot file exists.
    fn load_latest(&self) -> Result<Option<LoadedSnapshot>> {
        let Some(path) = self.latest_path()? else {
            return Ok(None);
        };
        let bytes = fs::read(&path)?;
        let raw = SnapshotCompression::decode_path(&path, bytes)?;
        let (timestamp_ms, entries) = SnapshotCodec::decode(&raw)?;
        Ok(Some(LoadedSnapshot {
            path,
            timestamp_ms,
            entries,
        }))
    }
    /// Returns the snapshot file with the highest timestamp in the data directory, if any.
    ///
    /// A missing data directory means nothing has been written yet and yields `Ok(None)`.
    /// Candidates are ordered by the numeric timestamp in their name, so `snapshot-999`
    /// ranks below `snapshot-1000`, unlike a plain string sort. Ties fall back to path
    /// order, and names without a parseable timestamp rank lowest.
    fn latest_path(&self) -> Result<Option<PathBuf>> {
        let dir = match fs::read_dir(&self.data_dir) {
            Ok(dir) => dir,
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(error) => return Err(error.into()),
        };
        let mut snapshots = Vec::new();
        for entry in dir {
            let path = entry?.path();
            if SnapshotName::matches(&path) {
                snapshots.push((Self::snapshot_timestamp(&path), path));
            }
        }
        Ok(snapshots.into_iter().max().map(|(_, path)| path))
    }
    /// Extracts `<timestamp>` from a `snapshot-<timestamp>.<ext>...` file name.
    fn snapshot_timestamp(path: &Path) -> Option<u64> {
        path.file_name()?
            .to_str()?
            .strip_prefix("snapshot-")?
            .split('.')
            .next()?
            .parse()
            .ok()
    }
}
/// Storage backend able to persist cache snapshots and load the most recent one.
pub trait SnapshotRepository {
    /// Persists `entries` as the snapshot for `timestamp_ms` using `compression`.
    ///
    /// Returns the path of the written snapshot.
    fn write_snapshot(
        &self,
        entries: &[StoredEntry],
        timestamp_ms: u64,
        compression: SnapshotCompression,
    ) -> Result<PathBuf>;

    /// Loads the newest stored snapshot, or `Ok(None)` if the backend holds none.
    fn load_latest_snapshot(&self) -> Result<Option<LoadedSnapshot>>;
}
impl SnapshotRepository for SnapshotStore {
fn write_snapshot(
&self,
entries: &[StoredEntry],
timestamp_ms: u64,
compression: SnapshotCompression,
) -> Result<PathBuf> {
self.write(entries, timestamp_ms, compression)
}
fn load_latest_snapshot(&self) -> Result<Option<LoadedSnapshot>> {
self.load_latest()
}
}
/// Binary codec for the uncompressed snapshot body.
///
/// Layout, every integer little-endian:
///
/// - header (`SNAPSHOT_HEADER_LEN` bytes): `SNAPSHOT_MAGIC`, format version (`u32`),
///   snapshot timestamp (`u64`), entry count (`u64`);
/// - per entry: key length (`u32`), value length (`u32`), expiry (`u64`, where `u64::MAX`
///   stands for "no expiry"), then the key bytes, then the value bytes.
struct SnapshotCodec;
impl SnapshotCodec {
    /// Serializes `entries` into a snapshot body stamped with `timestamp_ms`.
    ///
    /// Entries are written sorted by `hash_key` of their key. Entries with equal hashes
    /// keep their input order.
    ///
    /// # Errors
    /// `FastCacheError::Persistence` if any key or value is longer than `u32::MAX` bytes.
    fn encode(entries: &[StoredEntry], timestamp_ms: u64) -> Result<Vec<u8>> {
        // Sort borrowed references instead of cloning every entry (keys *and* values).
        // `sort_by_cached_key` is stable and hashes each key once, not once per comparison.
        let mut ordered: Vec<&StoredEntry> = entries.iter().collect();
        ordered.sort_by_cached_key(|entry| hash_key(entry.key.as_ref()));
        let mut body = Vec::with_capacity(
            SNAPSHOT_HEADER_LEN + ordered.len().saturating_mul(SNAPSHOT_ENTRY_HEADER_LEN),
        );
        body.extend_from_slice(SNAPSHOT_MAGIC);
        body.extend_from_slice(&SNAPSHOT_VERSION.to_le_bytes());
        body.extend_from_slice(&timestamp_ms.to_le_bytes());
        // `usize` -> `u64` cannot truncate on any supported (<= 64-bit) target.
        body.extend_from_slice(&(ordered.len() as u64).to_le_bytes());
        for entry in ordered {
            Self::encode_entry(&mut body, entry)?;
        }
        Ok(body)
    }
    /// Appends one entry (header, key, value) to `body`.
    ///
    /// Both length checks run before anything is appended, so a rejected entry leaves
    /// `body` untouched.
    fn encode_entry(body: &mut Vec<u8>, entry: &StoredEntry) -> Result<()> {
        let key_len = Self::encoded_len(entry.key.len(), "snapshot key is too large")?;
        let value_len = Self::encoded_len(entry.value.len(), "snapshot value is too large")?;
        // `u64::MAX` is the on-disk "no expiry" marker, so `Some(u64::MAX)` decodes as `None`.
        let expire_raw = entry.expire_at_ms.unwrap_or(u64::MAX);
        body.extend_from_slice(&key_len.to_le_bytes());
        body.extend_from_slice(&value_len.to_le_bytes());
        body.extend_from_slice(&expire_raw.to_le_bytes());
        body.extend_from_slice(entry.key.as_ref());
        body.extend_from_slice(entry.value.as_ref());
        Ok(())
    }
    /// Converts a length to its `u32` wire form, failing with `message` if it does not fit.
    fn encoded_len(len: usize, message: &'static str) -> Result<u32> {
        u32::try_from(len).map_err(|_| FastCacheError::Persistence(message.into()))
    }
    /// Parses a body produced by `encode`, returning its timestamp and entries.
    ///
    /// # Errors
    /// `FastCacheError::Persistence` if:
    /// - the header or an entry is truncated;
    /// - the magic does not match;
    /// - the version is unsupported;
    /// - the entry count does not fit in `usize`.
    fn decode(raw: &[u8]) -> Result<(u64, Vec<StoredEntry>)> {
        match raw {
            bytes if bytes.len() < SNAPSHOT_HEADER_LEN => Err(FastCacheError::Persistence(
                "snapshot header is truncated".into(),
            )),
            bytes if !bytes.starts_with(SNAPSHOT_MAGIC) => Err(FastCacheError::Persistence(
                "snapshot magic mismatch".into(),
            )),
            bytes => Self::decode_validated(bytes),
        }
    }
    /// Checks the version field (magic already verified), then decodes the rest.
    fn decode_validated(raw: &[u8]) -> Result<(u64, Vec<StoredEntry>)> {
        let mut cursor = SNAPSHOT_MAGIC.len();
        match Self::read_u32(raw, &mut cursor, "snapshot version")? {
            SNAPSHOT_VERSION => Self::decode_body(raw, &mut cursor),
            version => Err(FastCacheError::Persistence(format!(
                "unsupported snapshot version: {version}"
            ))),
        }
    }
    /// Reads the timestamp and entry count, then every entry, starting at `cursor`.
    fn decode_body(raw: &[u8], cursor: &mut usize) -> Result<(u64, Vec<StoredEntry>)> {
        let timestamp_ms = Self::read_u64(raw, cursor, "snapshot timestamp")?;
        let entry_count = usize::try_from(Self::read_u64(raw, cursor, "snapshot entry count")?)
            .map_err(|_| FastCacheError::Persistence("snapshot entry count is too large".into()))?;
        // The count comes straight from the file, so it must not size the allocation.
        // Every entry needs at least SNAPSHOT_ENTRY_HEADER_LEN bytes, which bounds how many
        // can actually be present. An inflated count then fails below as a truncated entry
        // instead of panicking or aborting on a huge `with_capacity`.
        let max_entries = raw.len().saturating_sub(*cursor) / SNAPSHOT_ENTRY_HEADER_LEN;
        let mut entries = Vec::with_capacity(entry_count.min(max_entries));
        for _ in 0..entry_count {
            entries.push(Self::decode_entry(raw, cursor)?);
        }
        Ok((timestamp_ms, entries))
    }
    /// Decodes one entry at `cursor` and advances the cursor past it.
    fn decode_entry(raw: &[u8], cursor: &mut usize) -> Result<StoredEntry> {
        if raw.len().saturating_sub(*cursor) < SNAPSHOT_ENTRY_HEADER_LEN {
            return Err(FastCacheError::Persistence(
                "snapshot entry header is truncated".into(),
            ));
        }
        let key_len = Self::read_u32(raw, cursor, "snapshot key length")? as usize;
        let value_len = Self::read_u32(raw, cursor, "snapshot value length")? as usize;
        let expire_raw = Self::read_u64(raw, cursor, "snapshot expiration")?;
        // Check key and value together so a short key or a short value reports one error.
        if raw.len().saturating_sub(*cursor) < key_len.saturating_add(value_len) {
            return Err(FastCacheError::Persistence(
                "snapshot entry body is truncated".into(),
            ));
        }
        // Both reads are in bounds after the check above.
        let key = Self::read_exact(raw, cursor, key_len, "snapshot key")?.to_vec();
        let value = Self::read_exact(raw, cursor, value_len, "snapshot value")?.to_vec();
        Ok(StoredEntry {
            key: Bytes::from(key),
            value: Bytes::from(value),
            expire_at_ms: (expire_raw != u64::MAX).then_some(expire_raw),
        })
    }
    /// Reads a little-endian `u32` at `cursor`; `field` names it in the truncation error.
    fn read_u32(raw: &[u8], cursor: &mut usize, field: &str) -> Result<u32> {
        let bytes = Self::read_exact(raw, cursor, 4, field)?;
        let mut value = [0; 4];
        value.copy_from_slice(bytes);
        Ok(u32::from_le_bytes(value))
    }
    /// Reads a little-endian `u64` at `cursor`; `field` names it in the truncation error.
    fn read_u64(raw: &[u8], cursor: &mut usize, field: &str) -> Result<u64> {
        let bytes = Self::read_exact(raw, cursor, 8, field)?;
        let mut value = [0; 8];
        value.copy_from_slice(bytes);
        Ok(u64::from_le_bytes(value))
    }
    /// Borrows `len` bytes at `cursor` and advances it, or fails with "`field` is truncated".
    fn read_exact<'a>(
        raw: &'a [u8],
        cursor: &mut usize,
        len: usize,
        field: &str,
    ) -> Result<&'a [u8]> {
        if raw.len().saturating_sub(*cursor) < len {
            return Err(FastCacheError::Persistence(format!("{field} is truncated")));
        }
        let bytes = &raw[*cursor..*cursor + len];
        *cursor += len;
        Ok(bytes)
    }
}
struct SnapshotName;
impl SnapshotName {
fn path(data_dir: &Path, timestamp_ms: u64, compression: SnapshotCompression) -> PathBuf {
data_dir.join(compression.file_name(timestamp_ms))
}
fn matches(path: &Path) -> bool {
path.file_name()
.and_then(|value| value.to_str())
.is_some_and(|name| name.starts_with("snapshot-"))
}
}