use crate::error::StorageError;
use async_trait::async_trait;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io;
use std::path::{Path, PathBuf};
/// Result alias used by every storage operation in this module; the error type is [`StorageError`].
pub type StorageResult<T> = std::result::Result<T, StorageError>;
/// Asynchronous key-value storage for JSON values, addressed by string key.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Storage: Send + Sync {
/// Stores `value` under `key`. Both implementations in this file overwrite any existing entry.
async fn save(&self, key: &str, value: &serde_json::Value) -> StorageResult<()>;
/// Loads the value stored under `key`. Both implementations in this file return `Ok(None)`
/// when no entry exists for the key.
async fn load(&self, key: &str) -> StorageResult<Option<serde_json::Value>>;
/// Removes the entry for `key`. Both implementations in this file treat a missing entry as
/// success.
async fn delete(&self, key: &str) -> StorageResult<()>;
/// Returns the keys currently stored. Neither implementation in this file sorts the result.
async fn list(&self) -> StorageResult<Vec<String>>;
}
/// Record serialized to disk by `FileStorage`: the payload plus the metadata needed to
/// recover the original key and to detect a corrupted payload.
#[derive(Serialize, Deserialize)]
struct Envelope {
// The caller-supplied key. The file name is only a hash of the key, so `list` reads
// the key back from this field.
key: String,
// Hex SHA-256 of the serialized `data`, as produced by `checksum`.
checksum: String,
// The stored JSON value.
data: serde_json::Value,
}
/// Computes the lowercase hexadecimal SHA-256 digest of `value` as serialized by
/// `serde_json::to_vec`.
///
/// # Errors
/// Propagates any serialization error from `serde_json`.
fn checksum(value: &serde_json::Value) -> StorageResult<String> {
    let serialized = serde_json::to_vec(value)?;
    let digest = Sha256::digest(&serialized);
    Ok(format!("{:x}", digest))
}
/// Maps a storage key to a file name of the form `<hex sha256 of key>.json`.
///
/// Hashing keeps arbitrary key contents (path separators, dots, etc.) out of the file name.
///
/// # Errors
/// Returns `StorageError::InvalidKey` when the key is empty or longer than 1024 bytes.
fn safe_filename(key: &str) -> StorageResult<String> {
    // Upper bound on the key length, measured in bytes (`str::len`).
    const MAX_KEY_BYTES: usize = 1024;

    if !(1..=MAX_KEY_BYTES).contains(&key.len()) {
        return Err(StorageError::InvalidKey(key.to_string()));
    }
    let digest = Sha256::digest(key.as_bytes());
    Ok(format!("{:x}.json", digest))
}
/// Best-effort `sync_all` on the directory that contains `path`.
///
/// Called after a rename so that the directory entry itself is flushed. Every failure
/// (no parent, directory cannot be opened, sync fails) is deliberately ignored.
#[cfg(unix)]
fn sync_parent_dir(path: &Path) {
    let Some(parent) = path.parent() else {
        return;
    };
    // `Path::parent` returns an empty path for a bare relative file name such as
    // "record.json". Opening "" always fails, which would silently skip the sync, so
    // resolve it to the current directory instead.
    let dir = if parent.as_os_str().is_empty() {
        Path::new(".")
    } else {
        parent
    };
    if let Ok(handle) = std::fs::File::open(dir) {
        let _ = handle.sync_all();
    }
}
/// Non-Unix counterpart of the directory sync helper: intentionally does nothing.
#[cfg(not(unix))]
fn sync_parent_dir(_path: &Path) {}
/// Write strategy used by `FileStorage::save`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Durability {
/// Truncates and writes the destination file in place; nothing is synced.
Fast,
/// Writes to a `.tmp-<uuid>` file in the storage root, then renames it over the destination.
Atomic,
/// Like `Atomic`, but calls `sync_all` on the temporary file before the rename and then
/// makes a best-effort sync of the parent directory.
Durable,
}
/// `Storage` implementation that keeps one JSON envelope file per key inside a directory.
pub struct FileStorage {
// Directory that holds the envelope files (and temporary files during a save).
root: PathBuf,
// Write strategy applied by `save`.
durability: Durability,
}
impl FileStorage {
    /// Opens a store rooted at `dir` with [`Durability::Fast`].
    ///
    /// # Errors
    /// Fails if the directory (or any missing ancestor) cannot be created.
    pub fn open(dir: impl AsRef<Path>) -> StorageResult<Self> {
        Self::open_with(dir, Durability::Fast)
    }

    /// Opens a store rooted at `dir` with the given write strategy, creating the directory
    /// and any missing ancestors first.
    ///
    /// # Errors
    /// Fails if the directory cannot be created.
    pub fn open_with(dir: impl AsRef<Path>, durability: Durability) -> StorageResult<Self> {
        let root: PathBuf = dir.as_ref().to_owned();
        std::fs::create_dir_all(root.as_path())?;
        Ok(Self { root, durability })
    }

    /// Returns the write strategy this store was opened with.
    pub fn durability(&self) -> Durability {
        self.durability
    }

    /// Resolves `key` to the path of its envelope file inside the storage root.
    ///
    /// # Errors
    /// Propagates the key validation error from `safe_filename`.
    fn path_for(&self, key: &str) -> StorageResult<PathBuf> {
        safe_filename(key).map(|file_name| self.root.join(file_name))
    }
}
#[async_trait]
impl Storage for FileStorage {
    /// Wraps `value` in a checksummed [`Envelope`] and writes it to the file for `key`,
    /// using the store's [`Durability`] strategy.
    ///
    /// * `Fast` truncates and rewrites the destination file in place.
    /// * `Atomic` writes a `.tmp-<uuid>` file in the storage root and renames it over the
    ///   destination.
    /// * `Durable` additionally calls `sync_all` on the temporary file before the rename and
    ///   makes a best-effort sync of the parent directory afterwards.
    ///
    /// If writing, syncing or renaming the temporary file fails, the temporary file is
    /// removed (best effort) so failed saves do not accumulate stale `.tmp-*` files.
    ///
    /// # Errors
    /// Returns an error for an invalid key, a serialization failure, an I/O failure, or a
    /// failure to join the blocking task (reported as `StorageError::Io`).
    async fn save(&self, key: &str, value: &serde_json::Value) -> StorageResult<()> {
        let path = self.path_for(key)?;
        let envelope = Envelope {
            key: key.to_string(),
            checksum: checksum(value)?,
            data: value.clone(),
        };
        let bytes = serde_json::to_vec(&envelope)?;
        let root = self.root.clone();
        let durability = self.durability;
        // The std file APIs used below block, so run them on the blocking thread pool.
        tokio::task::spawn_blocking(move || -> StorageResult<()> {
            use std::io::Write;
            match durability {
                Durability::Fast => {
                    let mut f = std::fs::OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(true)
                        .open(&path)?;
                    f.write_all(&bytes)?;
                    Ok(())
                }
                Durability::Atomic | Durability::Durable => {
                    let fsync = durability == Durability::Durable;
                    let tmp = root.join(format!(".tmp-{}", uuid::Uuid::new_v4()));
                    // Stage the payload in the temporary file and move it into place. The
                    // steps are grouped so that any failure can be followed by one cleanup.
                    let staged = (|| -> io::Result<()> {
                        let mut f = std::fs::OpenOptions::new()
                            .create(true)
                            .write(true)
                            .truncate(true)
                            .open(&tmp)?;
                        f.write_all(&bytes)?;
                        if fsync {
                            f.sync_all()?;
                        }
                        // Close the handle before renaming the file.
                        drop(f);
                        std::fs::rename(&tmp, &path)
                    })();
                    if let Err(err) = staged {
                        // Best-effort cleanup; the original error is the one worth reporting.
                        let _ = std::fs::remove_file(&tmp);
                        return Err(err.into());
                    }
                    if fsync {
                        sync_parent_dir(&path);
                    }
                    Ok(())
                }
            }
        })
        .await
        .map_err(|e| StorageError::Io(io::Error::other(e)))?
    }

    /// Reads the envelope stored for `key` and returns its payload.
    ///
    /// Returns `Ok(None)` when the file for `key` does not exist.
    ///
    /// # Errors
    /// Returns an error for an invalid key, any other read failure, an envelope that does
    /// not parse, or `StorageError::ChecksumMismatch` when the recomputed checksum of the
    /// payload differs from the stored one.
    async fn load(&self, key: &str) -> StorageResult<Option<serde_json::Value>> {
        let path = self.path_for(key)?;
        let bytes = match tokio::fs::read(&path).await {
            Ok(b) => b,
            Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e.into()),
        };
        let envelope: Envelope = serde_json::from_slice(&bytes)?;
        if checksum(&envelope.data)? != envelope.checksum {
            return Err(StorageError::ChecksumMismatch(key.to_string()));
        }
        Ok(Some(envelope.data))
    }

    /// Removes the file stored for `key`; a missing file counts as success.
    ///
    /// In `Durable` mode a successful removal is followed by a best-effort sync of the
    /// parent directory, mirroring what `save` does after its rename.
    ///
    /// # Errors
    /// Returns an error for an invalid key, a removal failure other than "not found", or a
    /// failure to join the blocking sync task (reported as `StorageError::Io`).
    async fn delete(&self, key: &str) -> StorageResult<()> {
        let path = self.path_for(key)?;
        match tokio::fs::remove_file(&path).await {
            Ok(()) => {}
            Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(()),
            Err(e) => return Err(e.into()),
        }
        if self.durability == Durability::Durable {
            tokio::task::spawn_blocking(move || sync_parent_dir(&path))
                .await
                .map_err(|e| StorageError::Io(io::Error::other(e)))?;
        }
        Ok(())
    }

    /// Returns the keys recorded in every readable envelope file in the storage root.
    ///
    /// Only entries whose name ends in `.json` and does not start with `.tmp-` are
    /// considered. Files that cannot be read or do not parse as an [`Envelope`] are skipped
    /// silently. Checksums are not verified here.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be opened or iterated.
    async fn list(&self) -> StorageResult<Vec<String>> {
        let mut keys = Vec::new();
        let mut dir = tokio::fs::read_dir(&self.root).await?;
        while let Some(entry) = dir.next_entry().await? {
            let file_name = entry.file_name();
            let file_name = file_name.to_string_lossy();
            if !file_name.ends_with(".json") || file_name.starts_with(".tmp-") {
                continue;
            }
            let Ok(bytes) = tokio::fs::read(entry.path()).await else {
                continue;
            };
            if let Ok(envelope) = serde_json::from_slice::<Envelope>(&bytes) {
                keys.push(envelope.key);
            }
        }
        Ok(keys)
    }
}
/// `Storage` implementation that keeps all entries in a mutex-guarded in-process `HashMap`.
#[derive(Default)]
pub struct MemoryStorage {
// Key -> stored JSON value.
map: Mutex<HashMap<String, serde_json::Value>>,
}
impl MemoryStorage {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl Storage for MemoryStorage {
    /// Inserts a clone of `value` under `key`, replacing any previous entry.
    ///
    /// NOTE(review): keys are not validated here, whereas the file-backed store rejects
    /// empty keys and keys longer than 1024 bytes. Confirm the difference is intended.
    async fn save(&self, key: &str, value: &serde_json::Value) -> StorageResult<()> {
        let mut entries = self.map.lock();
        entries.insert(key.to_owned(), value.clone());
        Ok(())
    }

    /// Returns a clone of the value stored under `key`, or `None` when there is none.
    async fn load(&self, key: &str) -> StorageResult<Option<serde_json::Value>> {
        let entries = self.map.lock();
        let found = entries.get(key).cloned();
        Ok(found)
    }

    /// Removes the entry for `key`; removing a missing key is not an error.
    async fn delete(&self, key: &str) -> StorageResult<()> {
        let mut entries = self.map.lock();
        entries.remove(key);
        Ok(())
    }

    /// Returns a clone of every stored key, in the map's iteration order.
    async fn list(&self) -> StorageResult<Vec<String>> {
        let entries = self.map.lock();
        let keys: Vec<String> = entries.keys().cloned().collect();
        Ok(keys)
    }
}