use std::fs::{self, File};
use std::io::{self, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use crate::error::{Result, ShamirError};
use crate::shamir::Share;
/// Four-byte signature written at the very start of every share file.
const MAGIC_NUMBER: &[u8] = b"SHS1";

/// On-disk format version written by `store_share`; `load_share` rejects
/// files whose version byte is greater than this.
const VERSION: u8 = 2;
/// Storage backend for [`Share`]s, addressed by the share's `index`.
pub trait ShareStore {
    /// Persists `share` so it can later be retrieved by `share.index`.
    fn store_share(&mut self, share: &Share) -> Result<()>;
    /// Loads the share previously stored under `index`.
    fn load_share(&self, index: u8) -> Result<Share>;
    /// Returns the indices of all shares currently held by the store.
    fn list_shares(&self) -> Result<Vec<u8>>;
    /// Removes the share stored under `index`.
    fn delete_share(&mut self, index: u8) -> Result<()>;
}
/// [`ShareStore`] backed by a directory on the local filesystem, holding one
/// file per share.
#[derive(Debug, Clone)]
pub struct FileShareStore {
    /// Directory that contains the individual share files.
    base_dir: PathBuf,
}
impl FileShareStore {
    /// Creates a store rooted at `base_dir`, creating that directory (and any
    /// missing parents) if it does not exist yet.
    ///
    /// # Errors
    /// Propagates any failure from [`fs::create_dir_all`].
    pub fn new<P: AsRef<Path>>(base_dir: P) -> Result<Self> {
        let root: PathBuf = base_dir.as_ref().to_owned();
        fs::create_dir_all(&root)?;
        Ok(Self { base_dir: root })
    }

    /// Path of the file holding share `index`: `share_` followed by the index
    /// zero-padded to three digits (e.g. `share_007`).
    fn share_path(&self, index: u8) -> PathBuf {
        let file_name = format!("share_{index:03}");
        self.base_dir.join(file_name)
    }
}
impl ShareStore for FileShareStore {
fn store_share(&mut self, share: &Share) -> Result<()> {
let path = self.share_path(share.index);
let file = File::create(path)?;
let mut writer = BufWriter::new(file);
writer.write_all(MAGIC_NUMBER)?;
writer.write_all(&[VERSION])?;
let integrity_flag = if share.integrity_check { 1 } else { 0 };
let compression_flag = if share.compression { 2 } else { 0 };
let flags = integrity_flag | compression_flag;
writer.write_all(&[flags])?;
writer.write_all(&[share.index, share.threshold, share.total_shares])?;
let len = share.data.len() as u32;
writer.write_all(&len.to_le_bytes())?;
writer.write_all(&share.data)?;
Ok(())
}
fn load_share(&self, index: u8) -> Result<Share> {
let path = self.share_path(index);
let mut file = File::open(path).map_err(|e| {
if e.kind() == io::ErrorKind::NotFound {
ShamirError::InvalidShareIndex(index)
} else {
e.into()
}
})?;
let mut magic = [0u8; 4];
file.read_exact(&mut magic)?;
if magic != MAGIC_NUMBER {
return Err(ShamirError::InvalidShareFormat);
}
let mut version = [0u8; 1];
file.read_exact(&mut version)?;
if version[0] > VERSION {
return Err(ShamirError::InvalidShareFormat);
}
let mut flags = [0u8; 1];
file.read_exact(&mut flags)?;
let integrity_check = (flags[0] & 1) != 0;
let compression = (flags[0] & 2) != 0;
let mut header = [0u8; 3];
file.read_exact(&mut header)?;
let (stored_index, threshold, total_shares) = (header[0], header[1], header[2]);
if stored_index != index {
return Err(ShamirError::InvalidShareFormat);
}
let mut len_bytes = [0u8; 4];
file.read_exact(&mut len_bytes)?;
let len = u32::from_le_bytes(len_bytes) as usize;
let mut data = vec![0u8; len];
file.read_exact(&mut data)?;
Ok(Share {
index,
data,
threshold,
total_shares,
integrity_check,
compression,
})
}
fn list_shares(&self) -> Result<Vec<u8>> {
let mut indices = Vec::new();
for entry in fs::read_dir(&self.base_dir)? {
let entry = entry?;
let file_name = entry.file_name();
let file_name = file_name.to_string_lossy();
if let Some(stripped) = file_name.strip_prefix("share_") {
if let Ok(index) = stripped.parse::<u8>() {
indices.push(index);
}
}
}
indices.sort_unstable();
Ok(indices)
}
fn delete_share(&mut self, index: u8) -> Result<()> {
let path = self.share_path(index);
fs::remove_file(path).map_err(|e| {
if e.kind() == io::ErrorKind::NotFound {
ShamirError::InvalidShareIndex(index)
} else {
e.into()
}
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Share with the metadata most tests use; only index and payload vary.
    fn sample_share(index: u8, data: Vec<u8>) -> Share {
        Share {
            index,
            data,
            threshold: 3,
            total_shares: 5,
            integrity_check: true,
            compression: false,
        }
    }

    #[test]
    fn test_file_store() -> Result<()> {
        let temp_dir = tempdir()?;
        let mut store = FileShareStore::new(temp_dir.path())?;
        let share = sample_share(1, vec![1, 2, 3, 4, 5]);
        store.store_share(&share)?;
        assert_eq!(store.list_shares()?, vec![1]);

        let loaded = store.load_share(1)?;
        assert_eq!(loaded.index, share.index);
        assert_eq!(loaded.data, share.data);
        // Header fields must survive the round trip too, not just the payload.
        assert_eq!(loaded.threshold, share.threshold);
        assert_eq!(loaded.total_shares, share.total_shares);
        assert_eq!(loaded.integrity_check, share.integrity_check);
        assert_eq!(loaded.compression, share.compression);

        store.delete_share(1)?;
        assert!(store.load_share(1).is_err());
        assert!(store.list_shares()?.is_empty());
        Ok(())
    }

    #[test]
    fn test_flags_round_trip() -> Result<()> {
        let temp_dir = tempdir()?;
        let mut store = FileShareStore::new(temp_dir.path())?;
        let share = Share {
            index: 4,
            data: vec![0xAB],
            threshold: 2,
            total_shares: 4,
            integrity_check: false,
            compression: true,
        };
        store.store_share(&share)?;
        let loaded = store.load_share(4)?;
        assert!(!loaded.integrity_check);
        assert!(loaded.compression);
        assert_eq!(loaded.threshold, 2);
        assert_eq!(loaded.total_shares, 4);
        Ok(())
    }

    #[test]
    fn test_overwrite_replaces_share() -> Result<()> {
        let temp_dir = tempdir()?;
        let mut store = FileShareStore::new(temp_dir.path())?;
        store.store_share(&sample_share(2, vec![9; 8]))?;
        store.store_share(&sample_share(2, vec![7, 7]))?;
        assert_eq!(store.load_share(2)?.data, vec![7, 7]);
        assert_eq!(store.list_shares()?, vec![2]);
        Ok(())
    }

    #[test]
    fn test_invalid_share_access() {
        let temp_dir = tempdir().unwrap();
        let mut store = FileShareStore::new(temp_dir.path()).unwrap();
        assert!(matches!(
            store.load_share(1),
            Err(ShamirError::InvalidShareIndex(1))
        ));
        assert!(matches!(
            store.delete_share(1),
            Err(ShamirError::InvalidShareIndex(1))
        ));
    }

    #[test]
    fn test_multiple_shares() -> Result<()> {
        let temp_dir = tempdir()?;
        let mut store = FileShareStore::new(temp_dir.path())?;
        for i in 1..=5u8 {
            store.store_share(&sample_share(i, vec![i; 5]))?;
        }
        assert_eq!(store.list_shares()?, vec![1, 2, 3, 4, 5]);
        for i in 1..=5u8 {
            let share = store.load_share(i)?;
            assert_eq!(share.index, i);
            assert_eq!(share.data, vec![i; 5]);
        }
        Ok(())
    }

    #[test]
    fn test_special_characters_path() -> Result<()> {
        let temp_dir = tempdir()?;
        let dir_path = temp_dir.path().join("special!@#$%^&()_-=+ chars");
        let mut store = FileShareStore::new(&dir_path)?;
        let share = sample_share(1, vec![1, 2, 3]);
        store.store_share(&share)?;
        let loaded = store.load_share(1)?;
        assert_eq!(loaded.data, share.data);
        Ok(())
    }

    #[test]
    fn test_read_only_directory() {
        let temp_dir = tempdir().unwrap();
        let dir = temp_dir.path();
        let original_perms = fs::metadata(dir).unwrap().permissions();
        let mut readonly = original_perms.clone();
        readonly.set_readonly(true);
        fs::set_permissions(dir, readonly).unwrap();

        // Windows ignores the read-only attribute on directories, and root
        // bypasses Unix permission bits. Only assert when the OS actually
        // refuses to create files here.
        let probe = dir.join("probe");
        let enforced = match File::create(&probe) {
            Ok(file) => {
                drop(file);
                let _ = fs::remove_file(&probe);
                false
            }
            Err(_) => true,
        };

        let result = if enforced {
            let mut store = FileShareStore::new(dir).unwrap();
            Some(store.store_share(&sample_share(1, vec![1, 2, 3])))
        } else {
            None
        };

        // Restore permissions before asserting so cleanup works even on failure.
        fs::set_permissions(dir, original_perms).unwrap();

        if let Some(result) = result {
            assert!(matches!(result, Err(ShamirError::IoError(_))));
        }
    }
}