use crate::error::{Error, Result};
use crate::utils::security::{ensure_secure_dir, set_secure_file_permissions};
use serde::{Serialize, de::DeserializeOwned};
use std::path::Path;
pub trait StorageBackend: Clone + Send + Sync {
fn extension(&self) -> &str;
fn serialize<T: Serialize>(&self, data: &T) -> Result<String>;
fn deserialize<T: DeserializeOwned>(&self, content: &str) -> Result<T>;
fn read<T: DeserializeOwned>(&self, path: &Path) -> Result<T> {
let content = std::fs::read_to_string(path).map_err(|e| Error::FileRead {
path: path.to_path_buf(),
source: e,
})?;
self.deserialize(&content)
}
fn write<T: Serialize>(&self, path: &Path, data: &T) -> Result<()> {
use std::io::Write;
let content = self.serialize(data)?;
if let Some(parent) = path.parent()
&& !parent.exists()
{
ensure_secure_dir(parent)?;
}
let file_name = path.file_name().ok_or_else(|| {
Error::Config(format!(
"Invalid path '{}': must have a filename",
path.display()
))
})?;
let mut temp_filename = file_name.to_os_string();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
temp_filename.push(format!(".{now}.tmp"));
let temp_path = path.with_file_name(temp_filename);
let result = (|| -> Result<()> {
let mut temp_file =
std::fs::File::create(&temp_path).map_err(|e| Error::FileWrite {
path: temp_path.clone(),
source: e,
})?;
temp_file
.write_all(content.as_bytes())
.map_err(|e| Error::FileWrite {
path: temp_path.clone(),
source: e,
})?;
temp_file.sync_all().map_err(|e| Error::FileWrite {
path: temp_path.clone(),
source: e,
})?;
set_secure_file_permissions(&temp_path)?;
drop(temp_file);
#[cfg(not(windows))]
{
std::fs::rename(&temp_path, path).map_err(|e| Error::FileWrite {
path: path.to_path_buf(),
source: e,
})?;
}
#[cfg(windows)]
{
let mut retries = 0;
let max_retries = 5;
loop {
match std::fs::rename(&temp_path, path) {
Ok(_) => break,
Err(e)
if e.kind() == std::io::ErrorKind::PermissionDenied
&& retries < max_retries =>
{
retries += 1;
std::thread::sleep(std::time::Duration::from_millis(10 * retries));
continue;
}
Err(e) => {
return Err(Error::FileWrite {
path: path.to_path_buf(),
source: e,
});
}
}
}
}
set_secure_file_permissions(path)?;
Ok(())
})();
if result.is_err() {
let _ = std::fs::remove_file(&temp_path);
}
result
}
}
/// Storage backend that encodes values as JSON.
///
/// [`JsonStorage::new`] (and [`Default`]) produce pretty-printed output;
/// [`JsonStorage::compact`] produces single-line output.
#[derive(Clone, Debug)]
pub struct JsonStorage {
    /// `true` selects indented, multi-line JSON; `false` selects compact JSON.
    pretty: bool,
}

impl Default for JsonStorage {
    /// Same as [`JsonStorage::new`]: pretty-printed output.
    fn default() -> Self {
        Self::new()
    }
}

impl JsonStorage {
    /// Creates a backend that writes pretty-printed JSON.
    #[must_use]
    pub fn new() -> Self {
        Self { pretty: true }
    }

    /// Creates a backend that writes compact, single-line JSON.
    #[must_use]
    pub fn compact() -> Self {
        Self { pretty: false }
    }
}
impl StorageBackend for JsonStorage {
fn extension(&self) -> &'static str {
"json"
}
fn serialize<T: Serialize>(&self, data: &T) -> Result<String> {
if self.pretty {
serde_json::to_string_pretty(data).map_err(Error::from)
} else {
serde_json::to_string(data).map_err(Error::from)
}
}
fn deserialize<T: DeserializeOwned>(&self, content: &str) -> Result<T> {
serde_json::from_str(content).map_err(Error::from)
}
}
/// Storage backend that encodes values as TOML (requires the `toml` feature).
#[cfg(feature = "toml")]
#[derive(Clone, Debug, Default)]
pub struct TomlStorage;

#[cfg(feature = "toml")]
impl TomlStorage {
    /// Creates a TOML backend; identical to [`TomlStorage::default`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
#[cfg(feature = "toml")]
impl StorageBackend for TomlStorage {
    /// Always `"toml"`.
    fn extension(&self) -> &'static str {
        "toml"
    }

    /// Encodes `data` with `toml`; any failure becomes `Error::Parse` carrying the
    /// library's message.
    fn serialize<T: Serialize>(&self, data: &T) -> Result<String> {
        let outcome = toml::to_string(data);
        outcome.map_err(|failure| Error::Parse(failure.to_string()))
    }

    /// Decodes `content` with `toml`; any failure becomes `Error::Parse` carrying the
    /// library's message.
    fn deserialize<T: DeserializeOwned>(&self, content: &str) -> Result<T> {
        let outcome = toml::from_str(content);
        outcome.map_err(|failure| Error::Parse(failure.to_string()))
    }
}
/// Storage backend that encodes values as YAML (requires the `yaml` feature).
#[cfg(feature = "yaml")]
#[derive(Clone, Debug, Default)]
pub struct YamlStorage;

#[cfg(feature = "yaml")]
impl YamlStorage {
    /// Creates a YAML backend; identical to [`YamlStorage::default`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
#[cfg(feature = "yaml")]
impl StorageBackend for YamlStorage {
    /// Always `"yaml"`.
    fn extension(&self) -> &'static str {
        "yaml"
    }

    /// Encodes `data` with `serde_yaml`; any failure becomes `Error::Parse` carrying
    /// the library's message.
    fn serialize<T: Serialize>(&self, data: &T) -> Result<String> {
        let outcome = serde_yaml::to_string(data);
        outcome.map_err(|failure| Error::Parse(failure.to_string()))
    }

    /// Decodes `content` with `serde_yaml`; any failure becomes `Error::Parse`
    /// carrying the library's message.
    fn deserialize<T: DeserializeOwned>(&self, content: &str) -> Result<T> {
        let outcome = serde_yaml::from_str(content);
        outcome.map_err(|failure| Error::Parse(failure.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestData {
        name: String,
        value: i32,
    }

    #[test]
    fn test_json_serialize_pretty() {
        let storage = JsonStorage::new();
        let data = TestData {
            name: "test".into(),
            value: 42,
        };
        let json = storage.serialize(&data).unwrap();
        assert!(json.contains('\n'));
        assert!(json.contains("\"name\": \"test\""));
    }

    #[test]
    fn test_json_serialize_compact() {
        let storage = JsonStorage::compact();
        let data = TestData {
            name: "test".into(),
            value: 42,
        };
        let json = storage.serialize(&data).unwrap();
        assert!(!json.contains('\n'));
    }

    #[test]
    fn test_json_roundtrip_sync() {
        let storage = JsonStorage::new();
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.json");
        let data = TestData {
            name: "hello".into(),
            value: 123,
        };
        storage.write(&path, &data).unwrap();
        let loaded: TestData = storage.read(&path).unwrap();
        assert_eq!(data, loaded);
    }

    /// `write` must create a missing parent directory before writing the file.
    #[test]
    fn test_json_write_creates_missing_parent_dir() {
        let storage = JsonStorage::new();
        let dir = tempdir().unwrap();
        let path = dir.path().join("subdir/test.json");
        let data = TestData {
            name: "nested".into(),
            value: 999,
        };
        storage.write(&path, &data).unwrap();
        assert!(dir.path().join("subdir").is_dir());
        let loaded: TestData = storage.read(&path).unwrap();
        assert_eq!(data, loaded);
    }

    /// Overwriting replaces the content and leaves no temporary files behind.
    #[test]
    fn test_json_write_overwrites_without_leftover_temp_files() {
        let storage = JsonStorage::new();
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.json");
        let first = TestData {
            name: "first".into(),
            value: 1,
        };
        let second = TestData {
            name: "second".into(),
            value: 2,
        };
        storage.write(&path, &first).unwrap();
        storage.write(&path, &second).unwrap();

        let loaded: TestData = storage.read(&path).unwrap();
        assert_eq!(second, loaded);

        let entries: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(entries, [std::ffi::OsString::from("test.json")]);
    }

    #[test]
    fn test_write_rejects_path_without_file_name() {
        let storage = JsonStorage::new();
        let dir = tempdir().unwrap();
        let data = TestData {
            name: "nameless".into(),
            value: 0,
        };
        // A path ending in `..` has no file-name component.
        let result = storage.write(&dir.path().join(".."), &data);
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[test]
    fn test_read_nonexistent_file() {
        let storage = JsonStorage::new();
        let dir = tempdir().unwrap();
        let result: Result<TestData> = storage.read(&dir.path().join("missing.json"));
        assert!(matches!(result, Err(Error::FileRead { .. })));
    }

    #[test]
    #[cfg(feature = "toml")]
    fn test_toml_roundtrip() {
        let storage = TomlStorage::new();
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.toml");
        let data = TestData {
            name: "toml_test".into(),
            value: 99,
        };
        storage.write(&path, &data).unwrap();
        let loaded: TestData = storage.read(&path).unwrap();
        assert_eq!(data, loaded);
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("name = \"toml_test\""));
        assert!(content.contains("value = 99"));
    }

    #[test]
    #[cfg(feature = "yaml")]
    fn test_yaml_roundtrip() {
        let storage = YamlStorage::new();
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.yaml");
        let data = TestData {
            name: "yaml_test".into(),
            value: 99,
        };
        storage.write(&path, &data).unwrap();
        let loaded: TestData = storage.read(&path).unwrap();
        assert_eq!(data, loaded);
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("name: yaml_test"));
        assert!(content.contains("value: 99"));
    }
}