use std::fs::{File, OpenOptions};
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::RwLock;
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(unix)]
use std::os::unix::fs::FileExt;
const MAGIC: [u8; 8] = *b"BSTK\x00\x01\x00\x00";
const HEADER_SIZE: u64 = 16;
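
/// Flush writes to stable storage.
///
/// On macOS, `fsync(2)` does not force data through the drive's write cache,
/// so `fcntl(F_FULLFSYNC)` is tried first; `sync_data` is the fallback when
/// that fails (some filesystems do not support it) and the behaviour on
/// every other platform.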
fn durable_sync(file: &File) -> io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        let ret = unsafe { libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC) };
        if ret != -1 {
            return Ok(());
        }
    }
    file.sync_data()
}
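
/// Take a non-blocking exclusive advisory lock on the whole file, so that
/// two processes cannot have the same stack open at once. Fails immediately
/// with `EWOULDBLOCK` instead of waiting if another holder exists.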
#[cfg(unix)]
fn flock_exclusive(file: &File) -> io::Result<()> {
    let ret = unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
    if ret == 0 {
        Ok(())
    } else {
        Err(io::Error::last_os_error())
    }
}
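
/// Write a fresh header: the magic bytes followed by a committed length of 0.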
fn init_header(file: &mut File) -> io::Result<()> {
    file.seek(SeekFrom::Start(0))?;
    file.write_all(&MAGIC)?;
    file.write_all(&0u64.to_le_bytes())
}
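
/// Overwrite the committed-length field (bytes 8..16 of the header).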
fn write_committed_len(file: &mut File, len: u64) -> io::Result<()> {
    file.seek(SeekFrom::Start(8))?;
    file.write_all(&len.to_le_bytes())
}
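
/// Read exactly `len` bytes at `offset` via `pread`, without touching the
/// shared file cursor, so concurrent readers cannot race on `seek`.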
#[cfg(unix)]
fn pread_exact(file: &File, offset: u64, len: usize) -> io::Result<Vec<u8>> {
    let mut buf = vec![0u8; len];
    file.read_exact_at(&mut buf, offset)?;
    Ok(buf)
}
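
/// Validate the magic bytes and return the committed payload length.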
fn read_header(file: &mut File) -> io::Result<u64> {
    file.seek(SeekFrom::Start(0))?;
    let mut hdr = [0u8; 16];
    file.read_exact(&mut hdr)?;
    if hdr[0..8] != MAGIC {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "bstack: bad magic number — not a bstack file or wrong version",
        ));
    }
    Ok(u64::from_le_bytes(hdr[8..16].try_into().unwrap()))
}
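
/// A crash-safe byte stack stored in a single file.
///
/// Layout: a 16-byte header (8 magic bytes, then the committed payload length
/// as a little-endian `u64`) followed by the payload. The committed length is
/// the commit record: `open` discards any bytes past it and shrinks it if the
/// file is shorter than it claims. The `RwLock` serialises writers while, on
/// Unix, `peek`/`get`/`len` can run concurrently under the read lock.
///
/// Usage sketch (`ignore`d doctest; the path is illustrative and the crate
/// is assumed to be named `bstack`):
///
/// ```ignore
/// let stack = bstack::BStack::open("/tmp/example.bstack")?;
/// let off = stack.push(b"hello")?;                // returns logical offset 0
/// assert_eq!(stack.get(off, off + 5)?, b"hello"); // random access
/// assert_eq!(stack.pop(5)?, b"hello");            // LIFO removal
/// ```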
pub struct BStack {
    lock: RwLock<File>,
}

impl BStack {
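    /// Open (or create) the stack at `path`, take an exclusive lock, and
    /// recover from torn writes: if the committed length in the header and
    /// the actual payload size disagree, both are rolled back to the smaller
    /// of the two.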
    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)?;
        #[cfg(unix)]
        flock_exclusive(&file)?;
        let raw_size = file.metadata()?.len();
        if raw_size == 0 {
            // Brand-new file: write and persist an empty header.
            init_header(&mut file)?;
            durable_sync(&file)?;
        } else if raw_size < HEADER_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "bstack: file is {raw_size} bytes — too small to contain the 16-byte header"
                ),
            ));
        } else {
            let committed_len = read_header(&mut file)?;
            let actual_data_len = raw_size - HEADER_SIZE;
            if actual_data_len != committed_len {
                // A crash mid-push leaves extra uncommitted bytes; a crash
                // mid-pop leaves a header that claims too much. Either way,
                // the smaller of the two values is the last consistent state.
                let correct_len = committed_len.min(actual_data_len);
                file.set_len(HEADER_SIZE + correct_len)?;
                write_committed_len(&mut file, correct_len)?;
                durable_sync(&file)?;
            }
        }
        Ok(BStack {
            lock: RwLock::new(file),
        })
    }
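
    /// Append `data`, returning the logical offset at which it begins.
    ///
    /// Commit protocol: the payload bytes are written first, then the
    /// committed length, then a single durable sync; if any step fails, the
    /// file is rolled back (best effort) to its previous state.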
    pub fn push(&self, data: &[u8]) -> io::Result<u64> {
        let mut file = self.lock.write().unwrap();
        let file_end = file.seek(SeekFrom::End(0))?;
        let logical_offset = file_end - HEADER_SIZE;
        if data.is_empty() {
            return Ok(logical_offset);
        }
        if let Err(e) = file.write_all(data) {
            // Best-effort rollback of a partial append.
            let _ = file.set_len(file_end);
            return Err(e);
        }
        let new_len = logical_offset + data.len() as u64;
        if let Err(e) = write_committed_len(&mut file, new_len).and_then(|_| durable_sync(&*file)) {
            let _ = file.set_len(file_end);
            let _ = write_committed_len(&mut file, logical_offset);
            return Err(e);
        }
        Ok(logical_offset)
    }
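
    /// Remove and return the last `n` bytes of the payload (in file order).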
    pub fn pop(&self, n: u64) -> io::Result<Vec<u8>> {
        let mut file = self.lock.write().unwrap();
        let raw_size = file.seek(SeekFrom::End(0))?;
        let data_size = raw_size - HEADER_SIZE;
        if n > data_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("pop({n}) exceeds payload size ({data_size})"),
            ));
        }
        let new_data_len = data_size - n;
        file.seek(SeekFrom::Start(HEADER_SIZE + new_data_len))?;
        let mut buf = vec![0u8; n as usize];
        file.read_exact(&mut buf)?;
        // Truncate first, then commit the new length: if we crash in between,
        // recovery in `open` shrinks the header to match the file.
        file.set_len(HEADER_SIZE + new_data_len)?;
        write_committed_len(&mut file, new_data_len)?;
        durable_sync(&*file)?;
        Ok(buf)
    }
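
    /// Read from `offset` to the end of the payload without modifying the
    /// file. On Unix this takes the shared read lock and uses a positional
    /// read; elsewhere it must take the write lock, because seeking moves
    /// the file cursor shared with writers.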
    pub fn peek(&self, offset: u64) -> io::Result<Vec<u8>> {
        #[cfg(unix)]
        {
            let file = self.lock.read().unwrap();
            let data_size = file.metadata()?.len().saturating_sub(HEADER_SIZE);
            if offset > data_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("peek offset ({offset}) exceeds payload size ({data_size})"),
                ));
            }
            return pread_exact(&*file, HEADER_SIZE + offset, (data_size - offset) as usize);
        }
        #[cfg(not(unix))]
        {
            let mut file = self.lock.write().unwrap();
            let raw_size = file.seek(SeekFrom::End(0))?;
            let data_size = raw_size.saturating_sub(HEADER_SIZE);
            if offset > data_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("peek offset ({offset}) exceeds payload size ({data_size})"),
                ));
            }
            file.seek(SeekFrom::Start(HEADER_SIZE + offset))?;
            let mut buf = vec![0u8; (data_size - offset) as usize];
            file.read_exact(&mut buf)?;
            Ok(buf)
        }
    }
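
    /// Read the half-open byte range `[start, end)` of the payload.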
    pub fn get(&self, start: u64, end: u64) -> io::Result<Vec<u8>> {
        if end < start {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("get: end ({end}) < start ({start})"),
            ));
        }
        #[cfg(unix)]
        {
            let file = self.lock.read().unwrap();
            let data_size = file.metadata()?.len().saturating_sub(HEADER_SIZE);
            if end > data_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("get: end ({end}) exceeds payload size ({data_size})"),
                ));
            }
            return pread_exact(&*file, HEADER_SIZE + start, (end - start) as usize);
        }
        #[cfg(not(unix))]
        {
            let mut file = self.lock.write().unwrap();
            let raw_size = file.seek(SeekFrom::End(0))?;
            let data_size = raw_size.saturating_sub(HEADER_SIZE);
            if end > data_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("get: end ({end}) exceeds payload size ({data_size})"),
                ));
            }
            file.seek(SeekFrom::Start(HEADER_SIZE + start))?;
            let mut buf = vec![0u8; (end - start) as usize];
            file.read_exact(&mut buf)?;
            Ok(buf)
        }
    }
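
    /// Current payload size in bytes (file size minus the 16-byte header).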
    pub fn len(&self) -> io::Result<u64> {
        let file = self.lock.read().unwrap();
        Ok(file.metadata()?.len().saturating_sub(HEADER_SIZE))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;

    /// Create a stack backed by a unique temp file (pid plus a counter keeps
    /// parallel tests and concurrent test processes from colliding).
    fn mk_stack() -> (BStack, std::path::PathBuf) {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let path = std::env::temp_dir().join(format!("bstack_test_{pid}_{id}.bin"));
        let stack = BStack::open(&path).unwrap();
        (stack, path)
    }

    /// Deletes the backing file when the test ends.
    struct Guard(std::path::PathBuf);

    impl Drop for Guard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn push_returns_correct_offsets() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let off0 = s.push(b"hello").unwrap();
        let off1 = s.push(b"world").unwrap();
        let off2 = s.push(b"!").unwrap();
        assert_eq!(off0, 0);
        assert_eq!(off1, 5);
        assert_eq!(off2, 10);
        assert_eq!(s.len().unwrap(), 11);
    }

    #[test]
    fn pop_returns_correct_bytes_and_shrinks() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"abcde").unwrap();
        s.push(b"fghij").unwrap();
        assert_eq!(s.len().unwrap(), 10);
        let bytes = s.pop(5).unwrap();
        assert_eq!(bytes, b"fghij");
        assert_eq!(s.len().unwrap(), 5);
        let bytes = s.pop(5).unwrap();
        assert_eq!(bytes, b"abcde");
        assert_eq!(s.len().unwrap(), 0);
    }

    #[test]
    fn pop_across_push_boundary() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"12345").unwrap();
        s.push(b"67890").unwrap();
        let bytes = s.pop(7).unwrap();
        assert_eq!(bytes, b"4567890");
        assert_eq!(s.len().unwrap(), 3);
    }

    #[test]
    fn pop_on_empty_file_returns_error() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let err = s.pop(1).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn pop_n_exceeds_file_size_returns_error() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"abc").unwrap();
        let err = s.pop(10).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(s.len().unwrap(), 3);
    }

    #[test]
    fn peek_reads_from_offset_to_end() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        assert_eq!(s.peek(0).unwrap(), b"helloworld");
        assert_eq!(s.peek(5).unwrap(), b"world");
        assert_eq!(s.peek(7).unwrap(), b"rld");
        assert_eq!(s.peek(10).unwrap(), b"");
    }

    #[test]
    fn peek_offset_exceeds_size_returns_error() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"abc").unwrap();
        let err = s.peek(10).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(s.len().unwrap(), 3);
    }

    #[test]
    fn get_reads_half_open_range() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        assert_eq!(s.get(0, 5).unwrap(), b"hello");
        assert_eq!(s.get(5, 10).unwrap(), b"world");
        assert_eq!(s.get(3, 8).unwrap(), b"lowor");
        assert_eq!(s.get(4, 4).unwrap(), b"");
    }

    #[test]
    fn get_end_exceeds_size_returns_error() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"abc").unwrap();
        let err = s.get(0, 10).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn get_end_less_than_start_returns_error() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"abcde").unwrap();
        let err = s.get(4, 2).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn get_does_not_modify_file() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        let _ = s.get(2, 8).unwrap();
        assert_eq!(s.len().unwrap(), 10);
        let off = s.push(b"!").unwrap();
        assert_eq!(off, 10);
    }

    #[test]
    fn interleaved_push_pop_correct_state() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let o0 = s.push(b"AAA").unwrap();
        assert_eq!(o0, 0);
        let o1 = s.push(b"BB").unwrap();
        assert_eq!(o1, 3);
        let popped = s.pop(2).unwrap();
        assert_eq!(popped, b"BB");
        let o2 = s.push(b"CCCC").unwrap();
        assert_eq!(o2, 3);
        assert_eq!(s.len().unwrap(), 7);
        let all = s.pop(7).unwrap();
        assert_eq!(all, b"AAACCCC");
        assert_eq!(s.len().unwrap(), 0);
    }

    #[test]
    fn reopen_reads_back_correct_data() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        drop(s);
        let s2 = BStack::open(&p).unwrap();
        assert_eq!(s2.len().unwrap(), 10);
        assert_eq!(s2.peek(0).unwrap(), b"helloworld");
    }

    #[test]
    fn reopen_and_continue_pushing() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        let off0 = s.push(b"first").unwrap();
        assert_eq!(off0, 0);
        drop(s);
        let s2 = BStack::open(&p).unwrap();
        let off1 = s2.push(b"second").unwrap();
        assert_eq!(off1, 5);
        assert_eq!(s2.len().unwrap(), 11);
        assert_eq!(s2.peek(0).unwrap(), b"firstsecond");
    }

    #[test]
    fn reopen_after_pop_sees_truncated_file() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        s.pop(5).unwrap();
        drop(s);
        let s2 = BStack::open(&p).unwrap();
        assert_eq!(s2.len().unwrap(), 5);
        assert_eq!(s2.peek(0).unwrap(), b"hello");
    }

    #[test]
    fn push_empty_slice() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let off0 = s.push(b"abc").unwrap();
        let off1 = s.push(&[]).unwrap();
        let off2 = s.push(b"def").unwrap();
        assert_eq!(off0, 0);
        assert_eq!(off1, 3);
        assert_eq!(off2, 3);
        assert_eq!(s.len().unwrap(), 6);
        assert_eq!(s.peek(0).unwrap(), b"abcdef");
    }

    #[test]
    fn pop_zero_bytes() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"abc").unwrap();
        let bytes = s.pop(0).unwrap();
        assert_eq!(bytes, b"");
        assert_eq!(s.len().unwrap(), 3);
        let off = s.push(b"d").unwrap();
        assert_eq!(off, 3);
    }

    #[test]
    fn peek_zero_offset_on_empty_file() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        assert_eq!(s.peek(0).unwrap(), b"");
    }

    #[test]
    fn get_zero_range_on_empty_file() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        assert_eq!(s.get(0, 0).unwrap(), b"");
    }

    #[test]
    fn drain_to_zero_then_push_starts_at_offset_zero() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"hello").unwrap();
        s.pop(5).unwrap();
        assert_eq!(s.len().unwrap(), 0);
        let off = s.push(b"world").unwrap();
        assert_eq!(off, 0);
        assert_eq!(s.len().unwrap(), 5);
        assert_eq!(s.peek(0).unwrap(), b"world");
    }

    #[test]
    fn peek_does_not_modify_file() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        let _ = s.peek(3).unwrap();
        assert_eq!(s.len().unwrap(), 10);
        let off = s.push(b"!").unwrap();
        assert_eq!(off, 10);
    }

    #[test]
    fn binary_roundtrip_all_byte_values() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let data: Vec<u8> = (0u16..512).map(|i| (i % 256) as u8).collect();
        s.push(&data).unwrap();
        let got = s.pop(data.len() as u64).unwrap();
        assert_eq!(got, data);
        assert_eq!(s.len().unwrap(), 0);
    }

    #[test]
    fn large_payload_roundtrip() {
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let payload: Vec<u8> = (0..1024 * 1024)
            .map(|i: usize| (i.wrapping_mul(7).wrapping_add(13)) as u8)
            .collect();
        s.push(&payload).unwrap();
        let got = s.get(0, payload.len() as u64).unwrap();
        assert_eq!(got, payload);
        assert_eq!(s.len().unwrap(), payload.len() as u64);
    }

    #[test]
    fn new_file_has_valid_header() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        drop(s);
        let raw = std::fs::read(&p).unwrap();
        assert_eq!(raw.len(), HEADER_SIZE as usize, "new file should be exactly 16 bytes");
        assert_eq!(&raw[0..8], &MAGIC, "magic mismatch");
        let clen = u64::from_le_bytes(raw[8..16].try_into().unwrap());
        assert_eq!(clen, 0, "committed length should be 0 for empty stack");
    }

    #[test]
    fn header_committed_len_matches_after_pushes() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        drop(s);
        let raw = std::fs::read(&p).unwrap();
        let clen = u64::from_le_bytes(raw[8..16].try_into().unwrap());
        assert_eq!(clen, 10);
        assert_eq!(raw.len() as u64, HEADER_SIZE + 10);
    }

    #[test]
    fn header_committed_len_matches_after_pop() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        s.pop(5).unwrap();
        drop(s);
        let raw = std::fs::read(&p).unwrap();
        let clen = u64::from_le_bytes(raw[8..16].try_into().unwrap());
        assert_eq!(clen, 5);
        assert_eq!(raw.len() as u64, HEADER_SIZE + 5);
    }

    #[test]
    fn open_rejects_bad_magic() {
        let path = {
            use std::sync::atomic::{AtomicU64, Ordering};
            static C: AtomicU64 = AtomicU64::new(0);
            let id = C.fetch_add(1, Ordering::Relaxed);
            std::env::temp_dir().join(format!("bstack_badmagic_{}.bin", id))
        };
        let _g = Guard(path.clone());
        let mut bad: Vec<u8> = b"WRONGHDR".to_vec();
        bad.extend_from_slice(&0u64.to_le_bytes());
        std::fs::write(&path, &bad).unwrap();
        let err = BStack::open(&path).err().unwrap();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert!(err.to_string().contains("magic"));
    }

    #[test]
    fn open_rejects_truncated_header() {
        let path = {
            use std::sync::atomic::{AtomicU64, Ordering};
            static C: AtomicU64 = AtomicU64::new(0);
            let id = C.fetch_add(1, Ordering::Relaxed);
            std::env::temp_dir().join(format!("bstack_smallfile_{}.bin", id))
        };
        let _g = Guard(path.clone());
        std::fs::write(&path, b"tooshort").unwrap();
        let err = BStack::open(&path).err().unwrap();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }
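
    // The next two tests simulate crashes by editing the file directly:
    // appending bytes past the committed length (a torn push) and truncating
    // below it (a torn pop), then checking that `open` repairs both.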
    #[test]
    fn recovery_truncates_partial_push() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        s.push(b"committed").unwrap();
        drop(s);
        // Simulate a crash mid-push: append bytes the header never committed.
        {
            let mut f = OpenOptions::new().append(true).open(&p).unwrap();
            f.write_all(b"ghost").unwrap();
        }
        let raw = std::fs::read(&p).unwrap();
        assert_eq!(raw.len(), (HEADER_SIZE + 9 + 5) as usize);
        let clen_before = u64::from_le_bytes(raw[8..16].try_into().unwrap());
        assert_eq!(clen_before, 9);
        // Reopening must discard the uncommitted tail.
        let s2 = BStack::open(&p).unwrap();
        assert_eq!(s2.len().unwrap(), 9);
        assert_eq!(s2.peek(0).unwrap(), b"committed");
        drop(s2);
        let raw2 = std::fs::read(&p).unwrap();
        assert_eq!(raw2.len(), (HEADER_SIZE + 9) as usize);
    }

    #[test]
    fn recovery_repairs_header_after_partial_pop() {
        let (s, p) = mk_stack();
        let _g = Guard(p.clone());
        s.push(b"hello").unwrap();
        s.push(b"world").unwrap();
        drop(s);
        // Simulate a crash mid-pop: the file was truncated, but the header
        // still claims the old length.
        {
            let f = OpenOptions::new().write(true).open(&p).unwrap();
            f.set_len(HEADER_SIZE + 5).unwrap();
        }
        let raw = std::fs::read(&p).unwrap();
        assert_eq!(raw.len(), (HEADER_SIZE + 5) as usize);
        let clen_before = u64::from_le_bytes(raw[8..16].try_into().unwrap());
        assert_eq!(clen_before, 10, "header should still claim 10 before recovery");
        let s2 = BStack::open(&p).unwrap();
        assert_eq!(s2.len().unwrap(), 5);
        assert_eq!(s2.peek(0).unwrap(), b"hello");
        drop(s2);
        let raw2 = std::fs::read(&p).unwrap();
        let clen_after = u64::from_le_bytes(raw2[8..16].try_into().unwrap());
        assert_eq!(clen_after, 5, "clen should be repaired to 5 after recovery");
    }
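
    // Concurrency tests: readers use the shared lock (on Unix), writers the
    // exclusive lock; offsets and lengths must stay consistent throughout.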
    #[cfg(unix)]
    #[test]
    fn concurrent_reads_do_not_serialise() {
        use std::sync::Arc;
        use std::thread;
        let (s, p) = mk_stack();
        let _g = Guard(p);
        const RECORDS: usize = 8;
        const RSIZE: u64 = 16;
        for i in 0..RECORDS {
            let mut rec = [0u8; RSIZE as usize];
            rec[0] = i as u8;
            s.push(&rec).unwrap();
        }
        let s = Arc::new(s);
        let handles: Vec<_> = (0..32)
            .map(|_| {
                let s = Arc::clone(&s);
                thread::spawn(move || {
                    for i in 0..RECORDS {
                        let off = i as u64 * RSIZE;
                        let via_get = s.get(off, off + RSIZE).unwrap();
                        assert_eq!(via_get[0], i as u8);
                        let via_peek = s.peek(off).unwrap();
                        assert_eq!(via_peek[0], i as u8);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn concurrent_pushes_non_overlapping() {
        use std::collections::HashSet;
        use std::sync::Arc;
        use std::thread;
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let s = Arc::new(s);
        const THREADS: usize = 8;
        const PER_THREAD: usize = 100;
        const ITEM: usize = 16;
        let handles: Vec<_> = (0..THREADS)
            .map(|t| {
                let s = Arc::clone(&s);
                thread::spawn(move || {
                    (0..PER_THREAD)
                        .map(|i| {
                            let mut data = [0u8; ITEM];
                            data[0] = t as u8;
                            data[1..9].copy_from_slice(&(i as u64).to_le_bytes());
                            let off = s.push(&data).unwrap();
                            (off, t, i)
                        })
                        .collect::<Vec<_>>()
                })
            })
            .collect();
        let results: Vec<_> = handles.into_iter().flat_map(|h| h.join().unwrap()).collect();
        for &(off, _, _) in &results {
            assert_eq!(off % ITEM as u64, 0, "offset {off} is not aligned to ITEM");
        }
        let mut seen: HashSet<u64> = HashSet::new();
        for &(off, _, _) in &results {
            assert!(seen.insert(off), "duplicate offset {off}");
        }
        assert_eq!(s.len().unwrap(), (THREADS * PER_THREAD * ITEM) as u64);
        for &(off, t, i) in &results {
            let slot = s.get(off, off + ITEM as u64).unwrap();
            assert_eq!(slot[0], t as u8, "thread id mismatch at offset {off}");
            let idx = u64::from_le_bytes(slot[1..9].try_into().unwrap());
            assert_eq!(idx, i as u64, "item index mismatch at offset {off}");
        }
    }

    #[test]
    fn concurrent_len_is_multiple_of_item_size() {
        use std::sync::Arc;
        use std::thread;
        let (s, p) = mk_stack();
        let _g = Guard(p);
        let s = Arc::new(s);
        const ITEM: u64 = 8;
        const PUSH_THREADS: usize = 4;
        const PUSHES_PER_THREAD: usize = 200;
        let push_handles: Vec<_> = (0..PUSH_THREADS)
            .map(|_| {
                let s = Arc::clone(&s);
                thread::spawn(move || {
                    for _ in 0..PUSHES_PER_THREAD {
                        s.push(&[0xBEu8; ITEM as usize]).unwrap();
                    }
                })
            })
            .collect();
        let len_handle = {
            let s = Arc::clone(&s);
            thread::spawn(move || {
                for _ in 0..2000 {
                    let size = s.len().unwrap();
                    assert_eq!(
                        size % ITEM,
                        0,
                        "torn write: size {size} is not a multiple of {ITEM}"
                    );
                }
            })
        };
        for h in push_handles {
            h.join().unwrap();
        }
        len_handle.join().unwrap();
        assert_eq!(
            s.len().unwrap(),
            (PUSH_THREADS * PUSHES_PER_THREAD) as u64 * ITEM
        );
    }
}