use std::net::IpAddr;
use std::path::Path;
use serde::{Deserialize, Serialize};
use xxhash_rust::xxh3::xxh3_64;
use crate::error::{Error, Result};
/// File signature written at offset 0 of every state file.
const MAGIC: &[u8; 4] = b"F2RS";
/// On-disk format version byte; `load` rejects files carrying any other value.
const VERSION: u8 = 3;
/// Length of the fixed header preceding the postcard payload:
/// magic, one version byte, then the little-endian xxh3-64 checksum.
///
/// Derived from its parts instead of hard-coded so it cannot drift from the layout.
const HEADER_SIZE: usize = MAGIC.len() + 1 + std::mem::size_of::<u64>();
/// Point-in-time copy of the persisted ban state, written by `save` and read back by `load`.
///
/// Encoded with postcard, which serializes fields positionally without names. Reordering,
/// adding, removing, or retyping a field therefore changes the on-disk format, and `VERSION`
/// must be bumped alongside such a change (`load` rejects files with a different version).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateSnapshot {
/// Ban records included in this snapshot.
pub bans: Vec<BanRecord>,
/// Per-IP counters. NOTE(review): presumably the number of times each address has been
/// banned — confirm against the code that builds snapshots.
pub ban_counts: Vec<(IpAddr, u32)>,
/// When the snapshot was taken. NOTE(review): unit/epoch is not visible here; presumably
/// Unix seconds — confirm.
pub snapshot_time: i64,
}
/// One persisted ban entry.
///
/// Stored inside the postcard-encoded `StateSnapshot`. The field order and types are part of
/// the on-disk format, so changing them requires bumping `VERSION`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BanRecord {
/// The banned address.
pub ip: IpAddr,
/// Identifier of the jail that issued the ban. NOTE(review): presumably a jail name from
/// configuration — confirm.
pub jail_id: String,
/// When the ban started. NOTE(review): presumably Unix seconds — confirm unit.
pub banned_at: i64,
/// When the ban lapses, if ever. NOTE(review): `None` presumably means the ban does not
/// expire — confirm with the code that enforces expiry.
pub expires_at: Option<i64>,
}
pub fn save(path: &Path, snapshot: &StateSnapshot) -> Result<()> {
let payload =
postcard::to_allocvec(snapshot).map_err(|e| Error::state_corrupt(format!("{e}")))?;
let hash = xxh3_64(&payload);
let mut buf = Vec::with_capacity(HEADER_SIZE + payload.len());
buf.extend_from_slice(MAGIC);
buf.push(VERSION);
buf.extend_from_slice(&hash.to_le_bytes());
buf.extend_from_slice(&payload);
let dir = path.parent().ok_or_else(|| {
Error::io(
"state file has no parent directory",
std::io::Error::new(std::io::ErrorKind::InvalidInput, "no parent"),
)
})?;
std::fs::create_dir_all(dir).map_err(|e| Error::io("creating state directory", e))?;
let tmp_path = path.with_extension("tmp");
std::fs::write(&tmp_path, &buf).map_err(|e| Error::io("writing state temp file", e))?;
let f = std::fs::File::open(&tmp_path).map_err(|e| Error::io("opening state temp file", e))?;
f.sync_all()
.map_err(|e| Error::io("fsyncing state file", e))?;
std::fs::rename(&tmp_path, path).map_err(|e| Error::io("renaming state file", e))?;
Ok(())
}
pub fn load(path: &Path) -> Result<Option<StateSnapshot>> {
let data = match std::fs::read(path) {
Ok(d) => d,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => return Err(Error::io("reading state file", e)),
};
if data.len() < HEADER_SIZE {
return Err(Error::state_corrupt("file too small"));
}
if &data[..4] != MAGIC {
return Err(Error::state_corrupt(format!(
"bad magic: expected F2RS, got {:?}",
&data[..4]
)));
}
let version = data[4];
if version != VERSION {
return Err(Error::state_corrupt(format!(
"unsupported version: {} (expected {})",
version, VERSION
)));
}
let stored_hash = u64::from_le_bytes([
data[5], data[6], data[7], data[8], data[9], data[10], data[11], data[12],
]);
let payload = &data[HEADER_SIZE..];
let computed_hash = xxh3_64(payload);
if stored_hash != computed_hash {
return Err(Error::state_corrupt(format!(
"xxh3 mismatch: stored={stored_hash:#x}, computed={computed_hash:#x}"
)));
}
let snapshot =
postcard::from_bytes(payload).map_err(|e| Error::state_corrupt(format!("{e}")))?;
Ok(Some(snapshot))
}