use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use rand::RngCore;
use crate::error::{Error, Result};
/// A source of `u64` serial numbers.
///
/// Implementors must be `Send + Sync + 'static`; `next` takes `&self`, so any
/// mutable state has to live behind interior mutability.
pub trait SerialStore: Send + Sync + 'static {
/// Produces the next serial number, or an error if one cannot be issued.
fn next(&self) -> Result<u64>;
}
/// A [`SerialStore`] that keeps its counter in a file so it survives restarts.
pub struct FileSerialStore {
/// File holding the most recently issued serial as decimal text.
path: PathBuf,
/// Most recently issued serial; the mutex guards both this value and the
/// file rewrite performed while issuing the next one.
current: parking_lot::Mutex<u64>,
}
impl FileSerialStore {
pub fn open(path: impl Into<PathBuf>) -> Result<Self> {
let path = path.into();
let start = match std::fs::read_to_string(&path) {
Ok(s) => s.trim().parse::<u64>().unwrap_or(0),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => 0,
Err(e) => return Err(Error::Io(e)),
};
Ok(Self { path, current: parking_lot::Mutex::new(start) })
}
}
impl SerialStore for FileSerialStore {
fn next(&self) -> Result<u64> {
let mut guard = self.current.lock();
*guard = guard.saturating_add(1);
let next = *guard;
atomic_write(&self.path, next.to_string().as_bytes())?;
Ok(next)
}
}
/// A [`SerialStore`] backed only by an in-process atomic counter.
///
/// Nothing is written to disk, so the sequence starts over with every new
/// instance.
#[derive(Default)]
pub struct InMemorySerialStore(AtomicU64);

impl InMemorySerialStore {
    /// Creates a store whose counter starts at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
impl SerialStore for InMemorySerialStore {
fn next(&self) -> Result<u64> {
Ok(self.0.fetch_add(1, Ordering::Relaxed).saturating_add(1))
}
}
/// Generates 20 random bytes from the operating system RNG for use as a serial.
///
/// The most significant bit of the first byte is always cleared.
/// NOTE(review): presumably this keeps the value non-negative when the bytes
/// are read as a big-endian signed integer (e.g. a DER INTEGER) — confirm
/// against the caller.
pub fn random_serial_bytes() -> [u8; 20] {
    let mut serial = [0u8; 20];
    let mut rng = rand::rngs::OsRng;
    rng.fill_bytes(&mut serial);
    // Clear the top bit of the leading byte.
    serial[0] &= !0x80;
    serial
}
/// Replaces the file at `path` with `bytes` via a temp file and rename.
///
/// The data is first written to `.<file name>.tmp` in the same directory and
/// synced to disk, then renamed over `path`, so readers see either the old
/// contents or the new ones. On Unix the containing directory is synced
/// afterwards so the rename itself is durable across a crash.
///
/// If any step up to and including the rename fails, the temp file is removed
/// (best effort) so failed attempts do not leave stray files behind.
///
/// # Errors
///
/// * `Error::Unexpected` if `path` has no final file-name component.
/// * An I/O error converted into `Error` if creating, writing, syncing or
///   renaming the temp file fails, or if syncing the directory fails.
fn atomic_write(path: &Path, bytes: &[u8]) -> Result<()> {
    let file_name = path
        .file_name()
        .ok_or_else(|| Error::Unexpected("serial path has no file name".into()))?;
    // `Path::parent` yields an empty path for a bare relative name such as
    // `serial`; normalise that to `.` so the directory can be opened below.
    let dir = match path.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => parent,
        _ => Path::new("."),
    };
    let tmp = dir.join(format!(".{}.tmp", file_name.to_string_lossy()));

    let staged = std::fs::File::create(&tmp)
        .and_then(|mut f| {
            f.write_all(bytes)?;
            // Flush data and metadata before the rename publishes the file.
            f.sync_all()
            // `f` is dropped (closed) here, before the rename runs.
        })
        .and_then(|()| std::fs::rename(&tmp, path));

    if let Err(e) = staged {
        // Best effort: the original failure is what the caller needs to see.
        let _ = std::fs::remove_file(&tmp);
        return Err(e.into());
    }

    // A rename only becomes durable once the directory entry itself is synced.
    #[cfg(unix)]
    {
        std::fs::File::open(dir)?.sync_all()?;
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// The counter written by one store instance is picked up by the next one
    /// opened on the same file.
    #[test]
    fn file_serial_survives_reopen() {
        let scratch = tempdir().unwrap();
        let serial_file = scratch.path().join("serial");

        {
            let first = FileSerialStore::open(&serial_file).unwrap();
            assert_eq!(first.next().unwrap(), 1);
            assert_eq!(first.next().unwrap(), 2);
            // `first` goes out of scope here, before the file is reopened.
        }

        let reopened = FileSerialStore::open(&serial_file).unwrap();
        assert_eq!(reopened.next().unwrap(), 3);
    }

    /// After a successful write only the target file remains in the directory.
    #[test]
    fn atomic_write_does_not_leave_tmp_on_success() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("serial");

        atomic_write(&target, b"42").unwrap();
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "42");

        let entry_count = std::fs::read_dir(scratch.path()).unwrap().count();
        assert_eq!(entry_count, 1, "temp file should have been renamed away");
    }
}