use anyhow::{Context, Result, anyhow};
use getrandom::fill;
use std::fs::{self, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
#[cfg(not(target_os = "windows"))]
use std::fs::File;
#[cfg(unix)]
use std::os::unix::fs::{MetadataExt, PermissionsExt};
/// Handle to a single store file on disk.
///
/// The handle only remembers where the file lives; all I/O happens in the
/// methods of the accompanying `impl` block. Cloning is cheap (it copies the
/// path) and every clone refers to the same file.
#[derive(Clone, Debug)]
pub struct Storage {
    /// Location of the store file. Temporary files used while saving are
    /// created next to it, in the same directory.
    path: PathBuf,
}
impl Storage {
    /// Creates a handle for the store file at `path`.
    ///
    /// Nothing is read or written until [`Storage::load`] or
    /// [`Storage::save`] is called.
    pub fn new(path: PathBuf) -> Self {
        Self { path }
    }

    /// Returns `true` if something currently exists at the store path.
    pub fn exists(&self) -> bool {
        self.path.exists()
    }

    /// Reads the whole store file into memory.
    ///
    /// On Unix the file and its parent directory are vetted first: a symlinked
    /// store is rejected and overly permissive file modes are tightened.
    ///
    /// # Errors
    ///
    /// Fails if the security check fails or the file cannot be read (for
    /// example because it does not exist).
    pub fn load(&self) -> Result<Vec<u8>> {
        #[cfg(unix)]
        self.security_check()?;
        Ok(fs::read(&self.path)?)
    }

    /// Atomically replaces the store file with `data`.
    ///
    /// The data is written to a randomly named temporary file next to the
    /// target, flushed to disk, and then moved over the target. The parent
    /// directory is created if needed; on Unix it is also restricted to the
    /// owner, and on non-Windows platforms it is fsynced after the rename so
    /// that the new directory entry is durable.
    ///
    /// A path without a directory component (e.g. `store.db`) is stored in the
    /// current directory, which is then subject to the same treatment.
    ///
    /// # Errors
    ///
    /// Fails on any I/O error. Once the temporary file has been created it is
    /// removed again (best effort) if any later step up to and including the
    /// replace fails.
    pub fn save(&self, data: &[u8]) -> Result<()> {
        let parent = self.parent_dir();
        if let Some(dir) = parent {
            fs::create_dir_all(dir)?;
            #[cfg(unix)]
            self.ensure_dir_permissions(dir)?;
        }

        let tmp_path = self.random_tmp_path()?;
        // If creation itself fails we must not delete anything: the path may
        // belong to a file we did not create.
        let tmp_file = Self::create_tmp_file(&tmp_path)?;

        // `write_and_sync` consumes the handle, so the file is closed before
        // it is moved into place (and before any cleanup below).
        let outcome =
            Self::write_and_sync(tmp_file, data).and_then(|()| self.atomic_replace(&tmp_path));
        if let Err(e) = outcome {
            // Best-effort cleanup; the original error is what matters.
            let _ = fs::remove_file(&tmp_path);
            return Err(e);
        }

        #[cfg(not(target_os = "windows"))]
        if let Some(dir) = parent {
            File::open(dir)?.sync_all()?;
        }
        Ok(())
    }

    /// Returns the path of the store file.
    pub fn path(&self) -> &PathBuf {
        &self.path
    }

    /// Directory that contains the store file, if the path has one.
    ///
    /// `Path::parent` yields an empty path for a bare file name; that is
    /// mapped to `"."` because the empty path cannot be opened or stat-ed.
    fn parent_dir(&self) -> Option<&Path> {
        self.path.parent().map(|dir| {
            if dir.as_os_str().is_empty() {
                Path::new(".")
            } else {
                dir
            }
        })
    }

    /// Creates the temporary file exclusively (fails if it already exists).
    ///
    /// On Unix the file is created with mode `0600` so that it is never
    /// visible to other users, not even for the short time before the
    /// explicit chmod in [`Storage::write_and_sync`].
    fn create_tmp_file(tmp_path: &Path) -> Result<fs::File> {
        let mut options = OpenOptions::new();
        options.write(true).create_new(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            options.mode(0o600);
        }
        options
            .open(tmp_path)
            .context("failed to create temporary file")
    }

    /// Writes `data` to the temporary file and flushes it to disk.
    ///
    /// Takes the handle by value so the file is closed when this returns.
    fn write_and_sync(mut tmp_file: fs::File, data: &[u8]) -> Result<()> {
        // The creation mode is filtered by the umask; force exactly 0600.
        #[cfg(unix)]
        tmp_file.set_permissions(fs::Permissions::from_mode(0o600))?;
        tmp_file.write_all(data)?;
        tmp_file.sync_all()?;
        Ok(())
    }

    /// Builds a sibling path of the form `<file name>.tmp.<16 hex digits>`.
    ///
    /// # Errors
    ///
    /// Fails if random bytes cannot be obtained or the store path has no
    /// file name component.
    fn random_tmp_path(&self) -> Result<PathBuf> {
        let mut buf = [0u8; 8];
        fill(&mut buf)?;
        let rand_string = buf.iter().map(|b| format!("{:02x}", b)).collect::<String>();
        let file_name = self
            .path
            .file_name()
            .ok_or_else(|| anyhow!("invalid storage path"))?
            .to_string_lossy();
        let tmp_name = format!("{}.tmp.{}", file_name, rand_string);
        Ok(self.path.with_file_name(tmp_name))
    }

    /// Moves `tmp_path` over the store file using `ReplaceFileW`.
    ///
    /// When the call reports `ERROR_FILE_NOT_FOUND` the temporary file is
    /// renamed into place instead (presumably the first save, when there is
    /// no existing store to replace).
    #[cfg(target_os = "windows")]
    fn atomic_replace(&self, tmp_path: &Path) -> Result<()> {
        use std::ffi::OsStr;
        use std::os::windows::ffi::OsStrExt;
        use windows_sys::Win32::Foundation::GetLastError;
        use windows_sys::Win32::{
            Foundation::ERROR_FILE_NOT_FOUND,
            Storage::FileSystem::{REPLACEFILE_WRITE_THROUGH, ReplaceFileW},
        };
        /// Encodes `s` as a NUL-terminated UTF-16 string.
        fn to_wide(s: &OsStr) -> Vec<u16> {
            s.encode_wide().chain(std::iter::once(0)).collect()
        }
        let target_w = to_wide(self.path.as_os_str());
        let tmp_w = to_wide(tmp_path.as_os_str());
        // SAFETY: both name pointers refer to NUL-terminated UTF-16 buffers
        // (`target_w`, `tmp_w`) that stay alive for the whole call; the
        // remaining pointer arguments are passed as null.
        let result = unsafe {
            ReplaceFileW(
                target_w.as_ptr(),
                tmp_w.as_ptr(),
                std::ptr::null(),
                REPLACEFILE_WRITE_THROUGH,
                std::ptr::null(),
                std::ptr::null(),
            )
        };
        if result != 0 {
            return Ok(());
        }
        // SAFETY: `GetLastError` takes no arguments and only reads the calling
        // thread's last-error value. It is called right after the failed call
        // so that nothing else can overwrite that value first.
        let err_code = unsafe { GetLastError() };
        if err_code == ERROR_FILE_NOT_FOUND {
            fs::rename(tmp_path, &self.path).context("failed to create initial file")?;
            return Ok(());
        }
        Err(std::io::Error::from_raw_os_error(err_code as i32)).context("atomic replace failed")
    }

    /// Moves `tmp_path` over the store file with a rename.
    #[cfg(not(target_os = "windows"))]
    fn atomic_replace(&self, tmp_path: &Path) -> Result<()> {
        fs::rename(tmp_path, &self.path).context("atomic rename failed")
    }

    /// Vets the store file and its directory before the file is read.
    ///
    /// A symlinked store is rejected. If group or other permission bits are
    /// set on the file, a warning is printed and the mode is reset to `0600`.
    /// The parent directory is checked with
    /// [`Storage::check_dir_permissions`].
    #[cfg(unix)]
    fn security_check(&self) -> Result<()> {
        let meta = fs::symlink_metadata(&self.path)?;
        if meta.file_type().is_symlink() {
            return Err(anyhow!("keystore path must not be a symlink"));
        }
        let mode = meta.mode() & 0o777;
        if mode & 0o077 != 0 {
            eprintln!(
                "Warning: keynest store permissions too open ({:o}). Fixing to 600.",
                mode
            );
            let mut perms = meta.permissions();
            perms.set_mode(0o600);
            fs::set_permissions(&self.path, perms)?;
        }
        if let Some(parent) = self.parent_dir() {
            self.check_dir_permissions(parent)?;
        }
        Ok(())
    }

    /// Returns the metadata of `dir` itself, rejecting a symlink.
    #[cfg(unix)]
    fn non_symlink_dir_metadata(dir: &Path) -> Result<fs::Metadata> {
        let meta = fs::symlink_metadata(dir)?;
        if meta.file_type().is_symlink() {
            return Err(anyhow!("directory must not be a symlink"));
        }
        Ok(meta)
    }

    /// Read-only directory check: rejects a symlink and warns (without
    /// changing anything) when group or other permission bits are set.
    #[cfg(unix)]
    fn check_dir_permissions(&self, dir: &Path) -> Result<()> {
        let meta = Self::non_symlink_dir_metadata(dir)?;
        let mode = meta.mode() & 0o777;
        if mode & 0o077 != 0 {
            eprintln!(
                "Warning: keynest directory permissions too open ({:o}). Recommended 700",
                mode
            );
        }
        Ok(())
    }

    /// Enforcing directory check: rejects a symlink and resets the mode to
    /// `0700` when group or other permission bits are set.
    #[cfg(unix)]
    fn ensure_dir_permissions(&self, dir: &Path) -> Result<()> {
        let meta = Self::non_symlink_dir_metadata(dir)?;
        let mode = meta.mode() & 0o777;
        if mode & 0o077 != 0 {
            let mut perms = meta.permissions();
            perms.set_mode(0o700);
            fs::set_permissions(dir, perms)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Sets up a fresh temporary directory and a `Storage` for `file_name`
    /// inside it.
    ///
    /// The directory guard is returned as well; it must be kept alive for the
    /// duration of the test, since dropping it deletes the directory.
    fn storage_in_tempdir(file_name: &str) -> (tempfile::TempDir, PathBuf, Storage) {
        let guard = tempdir().unwrap();
        let target = guard.path().join(file_name);
        let storage = Storage::new(target.clone());
        (guard, target, storage)
    }

    // --- load / exists -----------------------------------------------------

    #[test]
    fn load_returns_written_data() {
        let (_guard, _target, storage) = storage_in_tempdir("store.db");
        storage.save(b"hello world").unwrap();
        assert_eq!(storage.load().unwrap(), b"hello world");
    }

    #[test]
    fn load_fails_if_file_does_not_exist() {
        let (_guard, _target, storage) = storage_in_tempdir("missing.db");
        assert!(storage.load().is_err());
    }

    #[test]
    fn exists_returns_false_if_missing() {
        let (_guard, _target, storage) = storage_in_tempdir("store.db");
        assert!(!storage.exists());
    }

    #[test]
    fn exists_returns_true_after_save() {
        let (_guard, _target, storage) = storage_in_tempdir("store.db");
        storage.save(b"data").unwrap();
        assert!(storage.exists());
    }

    // --- temporary path generation -------------------------------------------

    #[test]
    fn random_tmp_path_has_same_parent() {
        let (_guard, target, storage) = storage_in_tempdir("store.db");
        let candidate = storage.random_tmp_path().unwrap();
        assert_eq!(candidate.parent(), target.parent());
    }

    #[test]
    fn random_tmp_path_is_not_equal_to_final_path() {
        let (_guard, target, storage) = storage_in_tempdir("store.db");
        let candidate = storage.random_tmp_path().unwrap();
        assert_ne!(candidate, target);
    }

    #[test]
    fn tmp_names_are_unique() {
        let (_guard, _target, storage) = storage_in_tempdir("store.db");
        let first = storage.random_tmp_path().unwrap();
        let second = storage.random_tmp_path().unwrap();
        assert_ne!(first, second);
    }

    // --- save ----------------------------------------------------------------

    #[test]
    fn save_overwrites_large_data() {
        let (_guard, _target, storage) = storage_in_tempdir("store.db");
        let payload = vec![42u8; 10_000];
        storage.save(&payload).unwrap();
        let round_trip = storage.load().unwrap();
        assert_eq!(round_trip.len(), 10_000);
        assert_eq!(round_trip, payload);
    }

    #[test]
    fn save_replaces_existing_file() {
        let (_guard, target, storage) = storage_in_tempdir("store.db");
        storage.save(b"first").unwrap();
        storage.save(b"second").unwrap();
        assert_eq!(fs::read(target).unwrap(), b"second");
    }

    #[test]
    fn tmp_file_is_removed_after_success() {
        let (guard, _target, storage) = storage_in_tempdir("store.db");
        storage.save(b"data").unwrap();
        // Only the final store file may remain in the directory.
        let mut names = Vec::new();
        for entry in fs::read_dir(guard.path()).unwrap() {
            names.push(entry.unwrap().file_name());
        }
        assert_eq!(names.len(), 1);
        assert_eq!(names[0], "store.db");
    }

    #[test]
    fn parent_directory_is_created() {
        let guard = tempdir().unwrap();
        let deep: PathBuf = guard.path().join("a").join("b").join("c").join("store.db");
        Storage::new(deep.clone()).save(b"data").unwrap();
        assert!(deep.exists());
    }

    #[test]
    fn save_sets_file_permissions_0600() {
        let (_guard, _target, storage) = storage_in_tempdir("store.db");
        storage.save(b"data").unwrap();
        // Permission bits only exist on Unix; elsewhere a successful save is
        // all this test checks.
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let bits = std::fs::metadata(storage.path())
                .unwrap()
                .permissions()
                .mode()
                & 0o777;
            assert_eq!(bits, 0o600, "file should be 0600");
        }
    }
}