use std::fs;
use std::io;
use std::io::Write;
use std::path::{Path, PathBuf};
/// A source of successive `i64` sequence numbers.
///
/// The `Send` supertrait lets a provider (including a `Box<dyn SequenceProvider>`)
/// be moved to another thread.
pub trait SequenceProvider: Send {
    /// Produces the next number in the sequence.
    ///
    /// # Errors
    ///
    /// Implementation-defined; any failure is reported as an [`io::Error`].
    fn next(&mut self) -> io::Result<i64>;
}
/// Process-local sequence backed by a plain counter; nothing is persisted.
#[derive(Debug)]
pub struct InMemorySeq {
    /// Last value produced, or the starting value if no value has been produced yet.
    current: i64,
}

impl InMemorySeq {
    /// Creates a sequence whose stored counter starts at `initial`; the first value
    /// produced by `next` is `initial + 1`.
    #[must_use]
    pub const fn starting_at(initial: i64) -> Self {
        Self { current: initial }
    }

    /// Creates a sequence whose counter starts at `0`, so the first value produced is `1`.
    #[must_use]
    pub const fn from_zero() -> Self {
        Self::starting_at(0)
    }

    /// Returns the stored counter: the last value produced, or the starting value
    /// if none has been produced yet.
    ///
    /// Mirrors `FileSeq::current` so both providers can be inspected the same way.
    #[must_use]
    pub const fn current(&self) -> i64 {
        self.current
    }
}

impl Default for InMemorySeq {
    /// Equivalent to [`InMemorySeq::from_zero`].
    fn default() -> Self {
        Self::from_zero()
    }
}
impl SequenceProvider for InMemorySeq {
    /// Bumps the counter by one and returns the new value.
    ///
    /// # Errors
    ///
    /// Fails with an `Other`-kind error if the counter is already `i64::MAX`; the
    /// stored counter is left untouched in that case.
    fn next(&mut self) -> io::Result<i64> {
        match self.current.checked_add(1) {
            Some(advanced) => {
                self.current = advanced;
                Ok(advanced)
            }
            None => Err(io::Error::other("sequence counter overflowed i64")),
        }
    }
}
/// A sequence whose counter is written to a file after every increment, so it
/// survives process restarts.
///
/// The file holds the counter as decimal text; surrounding whitespace is ignored
/// when it is read back.
#[derive(Debug)]
pub struct FileSeq {
    /// Location of the counter file.
    path: PathBuf,
    /// In-memory counter: the value read on open (`0` if the file was absent), or
    /// the one most recently assigned by an increment.
    current: i64,
}

impl FileSeq {
    /// Opens the counter stored at `path`, starting from `0` if no file exists there.
    ///
    /// The file is not written by this call; it is first written when the counter
    /// is next persisted.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file exists but cannot be read. Returns
    /// an [`io::ErrorKind::InvalidData`] error if its trimmed contents are not a valid `i64`.
    pub fn open_or_create(path: impl AsRef<Path>) -> io::Result<Self> {
        let path = path.as_ref().to_path_buf();
        // Read directly instead of probing with `Path::exists`. `exists` reports *any*
        // metadata error (e.g. permission denied) as `false`, which would silently
        // restart the sequence at 0. The probe-then-read pair is also racy.
        let current = match fs::read_to_string(&path) {
            Ok(raw) => Self::parse_counter(&raw)?,
            Err(err) if err.kind() == io::ErrorKind::NotFound => 0,
            Err(err) => return Err(err),
        };
        Ok(Self { path, current })
    }

    /// Returns the in-memory counter value.
    #[must_use]
    pub const fn current(&self) -> i64 {
        self.current
    }

    /// Writes `self.current` to `self.path`.
    ///
    /// The value is first written and fsynced to a sibling `<path>.tmp`, then renamed
    /// over `path`. This relies on the rename replacing `path` in a single step, which
    /// holds on POSIX filesystems. On Unix the parent directory is fsynced afterwards
    /// so the rename itself is durable. If writing or renaming fails, the staging file
    /// is removed on a best-effort basis and the original error is returned.
    fn persist(&self) -> io::Result<()> {
        let tmp = self.tmp_path();
        let staged =
            Self::write_synced(&tmp, self.current).and_then(|()| fs::rename(&tmp, &self.path));
        if let Err(err) = staged {
            // Don't leave a stale staging file behind; a cleanup failure is secondary
            // to the error we are already reporting.
            let _ = fs::remove_file(&tmp);
            return Err(err);
        }
        #[cfg(unix)]
        {
            // `Path::parent` yields `Some("")` for a bare file name, which means the
            // current directory.
            let dir = self
                .path
                .parent()
                .filter(|p| !p.as_os_str().is_empty())
                .unwrap_or_else(|| Path::new("."));
            fs::File::open(dir)?.sync_all()?;
        }
        Ok(())
    }

    /// Path of the staging file used by `persist`: `path` with `.tmp` appended.
    fn tmp_path(&self) -> PathBuf {
        let mut tmp = self.path.clone().into_os_string();
        tmp.push(".tmp");
        PathBuf::from(tmp)
    }

    /// Creates or truncates `path`, writes `value` as decimal text, and fsyncs the file.
    fn write_synced(path: &Path, value: i64) -> io::Result<()> {
        let mut file = fs::File::create(path)?;
        file.write_all(value.to_string().as_bytes())?;
        file.sync_all()
    }

    /// Parses counter-file contents, ignoring surrounding whitespace.
    fn parse_counter(raw: &str) -> io::Result<i64> {
        raw.trim().parse::<i64>().map_err(|err| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("seq counter file is corrupted: {err}"),
            )
        })
    }
}
impl SequenceProvider for FileSeq {
    /// Advances the counter by one, persists it, and returns the new value.
    ///
    /// If persisting fails, the in-memory counter is restored to its previous value.
    /// After a failed call, `current()` therefore still reports the last value
    /// successfully returned (or the value loaded on open).
    ///
    /// # Errors
    ///
    /// Returns an `Other`-kind error if the counter would overflow `i64`. Otherwise
    /// returns any I/O error raised while persisting the new value.
    fn next(&mut self) -> io::Result<i64> {
        let previous = self.current;
        let next = previous
            .checked_add(1)
            .ok_or_else(|| io::Error::other("sequence counter overflowed i64"))?;
        // `persist` writes `self.current`, so stage the new value in memory first.
        self.current = next;
        if let Err(err) = self.persist() {
            // The caller never receives `next`, so rolling back is safe: a retry
            // will persist and hand out `next` at most once.
            self.current = previous;
            return Err(err);
        }
        Ok(next)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::{Path, PathBuf};

    /// Builds a temp-dir path that is unique per thread and per call.
    fn unique_tempfile(suffix: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or_default();
        let tid = std::thread::current().id();
        env::temp_dir().join(format!("hdm-am-seq-test-{tid:?}-{nanos}-{suffix}"))
    }

    /// Owns a unique temp-file path. On drop it deletes that file and the
    /// `<path>.tmp` staging file `FileSeq` writes next to it. Cleanup therefore
    /// also happens when an assertion panics part-way through a test.
    struct TempPath(PathBuf);

    impl TempPath {
        fn new(suffix: &str) -> Self {
            let path = unique_tempfile(suffix);
            // Start from a clean slate in case an earlier run left something behind.
            let _ = fs::remove_file(&path);
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
            let mut staging = self.0.clone().into_os_string();
            staging.push(".tmp");
            let _ = fs::remove_file(PathBuf::from(staging));
        }
    }

    #[test]
    fn in_memory_seq_starts_at_initial_plus_one() {
        let mut seq = InMemorySeq::starting_at(42);
        assert_eq!(seq.next().unwrap(), 43);
        assert_eq!(seq.next().unwrap(), 44);
        assert_eq!(seq.next().unwrap(), 45);
    }

    #[test]
    fn in_memory_seq_default_starts_at_one() {
        let mut seq = InMemorySeq::default();
        assert_eq!(seq.next().unwrap(), 1);
    }

    #[test]
    fn in_memory_seq_overflows_safely() {
        let mut seq = InMemorySeq::starting_at(i64::MAX);
        let err = seq.next().expect_err("expected overflow");
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }

    #[test]
    fn file_seq_persists_across_reopen() {
        let tmp = TempPath::new("persist");
        {
            let mut seq = FileSeq::open_or_create(tmp.path()).unwrap();
            assert_eq!(seq.next().unwrap(), 1);
            assert_eq!(seq.next().unwrap(), 2);
            assert_eq!(seq.next().unwrap(), 3);
        }
        {
            let mut seq = FileSeq::open_or_create(tmp.path()).unwrap();
            assert_eq!(seq.current(), 3);
            assert_eq!(seq.next().unwrap(), 4);
        }
    }

    #[test]
    fn file_seq_starts_at_zero_when_missing() {
        let tmp = TempPath::new("missing");
        let mut seq = FileSeq::open_or_create(tmp.path()).unwrap();
        assert_eq!(seq.current(), 0);
        assert_eq!(seq.next().unwrap(), 1);
    }

    #[test]
    fn file_seq_refuses_to_open_corrupted_file() {
        let tmp = TempPath::new("corrupt");
        fs::write(tmp.path(), "this is not a number").unwrap();
        let err = FileSeq::open_or_create(tmp.path()).expect_err("expected parse error");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn file_seq_persists_to_disk_each_call() {
        let tmp = TempPath::new("disk");
        let mut seq = FileSeq::open_or_create(tmp.path()).unwrap();
        seq.next().unwrap();
        seq.next().unwrap();
        seq.next().unwrap();
        let raw = fs::read_to_string(tmp.path()).unwrap();
        assert_eq!(raw.trim(), "3");
    }
}