use std::fs::File;
use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
#[cfg(unix)]
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tempfile::NamedTempFile;
/// A fixed-width record that can be encoded into, and decoded from, a byte slice.
///
/// The readers and writers in this file size their staging buffers as
/// multiples of `SIZE` and hand each record a sub-slice of exactly `SIZE` bytes.
pub trait BucketRecord: Copy + Send + Sync {
/// Number of bytes one encoded record occupies.
const SIZE: usize;
/// Encodes `self` into `out`. Callers in this file pass a slice of `SIZE` bytes.
fn write_to(&self, out: &mut [u8]);
/// Decodes one record from `bytes`. Callers in this file pass a slice of `SIZE` bytes.
fn read_from(bytes: &[u8]) -> Self;
}
/// A bare `u64` is stored as its eight little-endian bytes.
impl BucketRecord for u64 {
    const SIZE: usize = std::mem::size_of::<u64>();

    /// Copies the little-endian encoding of `self` into `out`, which must be
    /// exactly eight bytes long.
    #[inline]
    fn write_to(&self, out: &mut [u8]) {
        debug_assert_eq!(out.len(), 8);
        let encoded = u64::to_le_bytes(*self);
        out.copy_from_slice(&encoded);
    }

    /// Rebuilds the value from its eight-byte little-endian encoding.
    ///
    /// Panics if `bytes` is not exactly eight bytes long.
    #[inline]
    fn read_from(bytes: &[u8]) -> Self {
        debug_assert_eq!(bytes.len(), 8);
        let raw: [u8; 8] = bytes.try_into().unwrap();
        Self::from_le_bytes(raw)
    }
}
/// A pair of same-typed values, `pos` and `lcp`, stored together as one record.
///
/// `BucketRecord` is implemented in this file for the `u32` and `u64`
/// instantiations.
// NOTE(review): the names suggest a suffix-array position and its
// longest-common-prefix length — confirm against the producer of these records.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct SaLcp<I> {
/// First value of the pair; serialized first.
pub pos: I,
/// Second value of the pair; serialized after `pos`.
pub lcp: I,
}
/// 32-bit pairs occupy eight bytes: `pos` then `lcp`, each little-endian.
impl BucketRecord for SaLcp<u32> {
    const SIZE: usize = 2 * std::mem::size_of::<u32>();

    /// Writes `pos` into bytes `0..4` and `lcp` into bytes `4..8` of `out`.
    #[inline]
    fn write_to(&self, out: &mut [u8]) {
        debug_assert_eq!(out.len(), 8);
        let (pos_bytes, rest) = out.split_at_mut(4);
        pos_bytes.copy_from_slice(&u32::to_le_bytes(self.pos));
        rest[..4].copy_from_slice(&u32::to_le_bytes(self.lcp));
    }

    /// Reads `pos` from bytes `0..4` and `lcp` from bytes `4..8` of `bytes`.
    #[inline]
    fn read_from(bytes: &[u8]) -> Self {
        debug_assert_eq!(bytes.len(), 8);
        let (pos_bytes, rest) = bytes.split_at(4);
        let pos_raw: [u8; 4] = pos_bytes.try_into().unwrap();
        let lcp_raw: [u8; 4] = rest[..4].try_into().unwrap();
        Self {
            pos: u32::from_le_bytes(pos_raw),
            lcp: u32::from_le_bytes(lcp_raw),
        }
    }
}
/// 64-bit pairs occupy sixteen bytes: `pos` then `lcp`, each little-endian.
impl BucketRecord for SaLcp<u64> {
    const SIZE: usize = 2 * std::mem::size_of::<u64>();

    /// Writes `pos` into bytes `0..8` and `lcp` into bytes `8..16` of `out`.
    #[inline]
    fn write_to(&self, out: &mut [u8]) {
        debug_assert_eq!(out.len(), 16);
        let (pos_bytes, rest) = out.split_at_mut(8);
        pos_bytes.copy_from_slice(&u64::to_le_bytes(self.pos));
        rest[..8].copy_from_slice(&u64::to_le_bytes(self.lcp));
    }

    /// Reads `pos` from bytes `0..8` and `lcp` from bytes `8..16` of `bytes`.
    #[inline]
    fn read_from(bytes: &[u8]) -> Self {
        debug_assert_eq!(bytes.len(), 16);
        let (pos_bytes, rest) = bytes.split_at(8);
        let pos_raw: [u8; 8] = pos_bytes.try_into().unwrap();
        let lcp_raw: [u8; 8] = rest[..8].try_into().unwrap();
        Self {
            pos: u64::from_le_bytes(pos_raw),
            lcp: u64::from_le_bytes(lcp_raw),
        }
    }
}
/// In-memory buffer capacity, in records, used by the `new` constructors of
/// the disk-backed buckets in this file.
const DEFAULT_BUFFER_RECORDS: usize = 2048;
/// Common interface over the bucket implementations in this file.
///
/// A store accumulates records in insertion order and keeps a list of
/// boundary offsets (record counts). The implementations in this file start
/// that list with a single `0`.
pub trait BucketStore<T>: Send {
/// Appends every record in `rs` to the store.
fn add_slice(&mut self, rs: &[T]) -> io::Result<()>;
/// Records the current record count as a boundary. The implementations in
/// this file skip the push when it equals the most recent boundary.
fn mark_boundary(&mut self);
/// Number of records currently held by the store.
fn total_records(&self) -> usize;
/// Boundary offsets recorded so far, oldest first.
fn boundaries(&self) -> &[usize];
/// Returns the stored records in insertion order.
///
/// NOTE(review): the in-memory implementation in this file moves its
/// records out (leaving it empty), while the disk-backed ones read them back
/// from their files — confirm whether callers rely on either behavior.
fn load_all(&mut self) -> io::Result<Vec<T>>;
}
/// Bucket that keeps every record in a `Vec`; nothing is written to disk.
pub struct InMemBucket<T> {
// Records in insertion order.
records: Vec<T>,
// Boundary offsets into `records`; initialized to `[0]` by `new`.
boundaries: Vec<usize>,
}
impl<T: Copy> InMemBucket<T> {
    /// Creates an empty bucket whose boundary list holds the single offset `0`.
    pub fn new() -> Self {
        let records = Vec::new();
        let boundaries = vec![0];
        Self {
            records,
            boundaries,
        }
    }
}
impl<T: Copy> Default for InMemBucket<T> {
fn default() -> Self {
Self::new()
}
}
/// `BucketStore` backed directly by the in-memory `records` vector.
impl<T: Copy + Send + Sync> BucketStore<T> for InMemBucket<T> {
    /// Appends `rs` to the in-memory vector; never fails.
    fn add_slice(&mut self, rs: &[T]) -> io::Result<()> {
        self.records.extend(rs.iter().copied());
        Ok(())
    }

    /// Pushes the current record count as a boundary unless it equals the
    /// most recent one.
    fn mark_boundary(&mut self) {
        let previous = self.boundaries.last().copied().unwrap();
        let current = self.records.len();
        if previous == current {
            return;
        }
        self.boundaries.push(current);
    }

    /// Number of records currently held.
    fn total_records(&self) -> usize {
        self.records.len()
    }

    /// Boundary offsets recorded so far.
    fn boundaries(&self) -> &[usize] {
        self.boundaries.as_slice()
    }

    /// Moves all records out of the bucket, leaving it with zero records.
    /// The boundary list is left untouched.
    fn load_all(&mut self) -> io::Result<Vec<T>> {
        let drained = std::mem::take(&mut self.records);
        Ok(drained)
    }
}
// Forwards the `BucketStore` trait to `ExtMemBucket`'s inherent methods of the
// same names. Each call is written as a path (`ExtMemBucket::method(self, ..)`);
// inherent associated functions take priority over trait ones in path
// resolution, so these calls reach the inherent implementation rather than
// recursing into the trait method being defined.
impl<T: BucketRecord> BucketStore<T> for ExtMemBucket<T> {
// Delegates to the inherent `ExtMemBucket::add_slice`.
fn add_slice(&mut self, rs: &[T]) -> io::Result<()> {
ExtMemBucket::add_slice(self, rs)
}
// Delegates to the inherent `ExtMemBucket::mark_boundary`.
fn mark_boundary(&mut self) {
ExtMemBucket::mark_boundary(self)
}
// Delegates to the inherent `ExtMemBucket::total_records`.
fn total_records(&self) -> usize {
ExtMemBucket::total_records(self)
}
// Delegates to the inherent `ExtMemBucket::boundaries`.
fn boundaries(&self) -> &[usize] {
ExtMemBucket::boundaries(self)
}
// Delegates to the inherent `ExtMemBucket::load_all`.
fn load_all(&mut self) -> io::Result<Vec<T>> {
ExtMemBucket::load_all(self)
}
}
/// Bucket that buffers records in memory and spills them to its own temp file.
///
/// The temp file and its writer are created lazily, the first time records
/// have to be written out.
pub struct ExtMemBucket<T: BucketRecord> {
// Records accepted but not yet handed to `writer`.
buf: Vec<T>,
// Buffer threshold, in records, that triggers a spill.
buffer_records: usize,
// Backing temp file; `None` until the first spill.
file: Option<NamedTempFile>,
// Buffered writer over a reopened handle to `file`; `None` until the first spill.
writer: Option<BufWriter<File>>,
// Number of records handed to `writer` so far.
on_disk: usize,
// Boundary offsets (record counts); starts as `[0]`.
boundaries: Vec<usize>,
// Directory in which the temp file is created.
work_dir: std::path::PathBuf,
// Caller-supplied tag embedded in the temp file's name.
prefix: String,
}
impl<T: BucketRecord> ExtMemBucket<T> {
#[allow(dead_code)]
pub fn new(work_dir: impl AsRef<Path>, prefix: impl Into<String>) -> Self {
Self::with_buffer_records(work_dir, prefix, DEFAULT_BUFFER_RECORDS)
}
#[allow(dead_code)]
pub fn with_buffer_records(
work_dir: impl AsRef<Path>,
prefix: impl Into<String>,
buffer_records: usize,
) -> Self {
Self {
buf: Vec::with_capacity(buffer_records),
buffer_records,
file: None,
writer: None,
on_disk: 0,
boundaries: vec![0],
work_dir: work_dir.as_ref().to_path_buf(),
prefix: prefix.into(),
}
}
#[allow(dead_code)]
pub fn add(&mut self, r: T) -> io::Result<()> {
self.buf.push(r);
if self.buf.len() >= self.buffer_records {
self.flush()?;
}
Ok(())
}
pub fn add_slice(&mut self, rs: &[T]) -> io::Result<()> {
if self.buf.len() + rs.len() <= self.buffer_records {
self.buf.extend_from_slice(rs);
return Ok(());
}
self.flush()?;
self.ensure_file()?;
let writer = self.writer.as_mut().unwrap();
write_records(writer, rs)?;
self.on_disk += rs.len();
Ok(())
}
#[allow(dead_code)]
pub fn mark_boundary(&mut self) {
let last = *self.boundaries.last().unwrap();
let now = self.total_records();
if now != last {
self.boundaries.push(now);
}
}
pub fn total_records(&self) -> usize {
self.on_disk + self.buf.len()
}
#[allow(dead_code)]
pub fn boundaries(&self) -> &[usize] {
&self.boundaries
}
pub fn flush(&mut self) -> io::Result<()> {
if self.buf.is_empty() {
return Ok(());
}
self.ensure_file()?;
let writer = self.writer.as_mut().unwrap();
let recs = std::mem::take(&mut self.buf);
write_records(writer, &recs)?;
self.on_disk += recs.len();
self.buf = Vec::with_capacity(self.buffer_records);
Ok(())
}
#[allow(dead_code)]
pub fn load_all(&mut self) -> io::Result<Vec<T>> {
self.flush()?;
if let Some(w) = self.writer.as_mut() {
w.flush()?;
}
let total = self.total_records();
let mut out = Vec::with_capacity(total);
if let Some(file) = self.file.as_ref() {
let mut reader = BufReader::new(file.reopen()?);
reader.seek(SeekFrom::Start(0))?;
read_records(&mut reader, total, &mut out)?;
}
Ok(out)
}
#[allow(dead_code)]
pub fn open_reader(&mut self) -> io::Result<BufReader<File>> {
self.flush()?;
if let Some(w) = self.writer.as_mut() {
w.flush()?;
}
let file = self
.file
.as_ref()
.expect("open_reader on empty bucket — guard with total_records() > 0");
let mut reader = BufReader::new(file.reopen()?);
reader.seek(SeekFrom::Start(0))?;
Ok(reader)
}
fn ensure_file(&mut self) -> io::Result<()> {
if self.file.is_some() {
return Ok(());
}
let f = tempfile::Builder::new()
.prefix(&format!("caps-sa-{}-", self.prefix))
.suffix(".bin")
.tempfile_in(&self.work_dir)?;
let writer = BufWriter::new(f.reopen()?);
self.file = Some(f);
self.writer = Some(writer);
Ok(())
}
}
/// One file shared by several pooled buckets.
struct PhysicalFile {
// Handle used for positioned reads and writes.
file: File,
// Next free byte offset; writers claim a range by `fetch_add`-ing its length.
cursor: AtomicU64,
}
/// A fixed set of shared files onto which buckets are mapped by `bucket_id`.
pub struct BucketPool {
// Shared files; `new_bucket` picks entry `bucket_id % files.len()`.
files: Vec<Arc<PhysicalFile>>,
}
impl BucketPool {
    /// Creates a pool of `max(n_physical, 1)` temp files (via
    /// `tempfile::tempfile_in`) inside `work_dir`, each with its write cursor
    /// at offset 0.
    ///
    /// # Errors
    /// Returns the first I/O error raised while creating a file.
    pub fn new(n_physical: usize, work_dir: impl AsRef<Path>) -> io::Result<Self> {
        let dir = work_dir.as_ref();
        let count = std::cmp::max(n_physical, 1);
        let files = (0..count)
            .map(|_| {
                tempfile::tempfile_in(dir).map(|file| {
                    Arc::new(PhysicalFile {
                        file,
                        cursor: AtomicU64::new(0),
                    })
                })
            })
            .collect::<io::Result<Vec<_>>>()?;
        Ok(Self { files })
    }

    /// Creates an empty bucket backed by the file at index
    /// `bucket_id % n_physical()`.
    pub fn new_bucket<T: BucketRecord>(&self, bucket_id: usize) -> PooledExtMemBucket<T> {
        let slot = bucket_id % self.files.len();
        let shared = &self.files[slot];
        PooledExtMemBucket::new(Arc::clone(shared))
    }

    /// Number of files in the pool.
    #[allow(dead_code)]
    pub fn n_physical(&self) -> usize {
        self.files.len()
    }
}
/// Bucket that buffers records in memory and spills them as extents (byte
/// ranges) of a file shared with other buckets.
pub struct PooledExtMemBucket<T: BucketRecord> {
// Shared file this bucket writes its extents to.
phys: Arc<PhysicalFile>,
// Records accepted but not yet written to `phys`.
buf: Vec<T>,
// Buffer threshold, in records, above which `add_slice` spills.
buffer_records: usize,
// `(byte offset, byte length)` of each extent written, in write order.
extents: Vec<(u64, u32)>,
// Boundary offsets (record counts); starts as `[0]`.
boundaries: Vec<usize>,
// Number of records written to `phys` so far.
on_disk: usize,
}
impl<T: BucketRecord> PooledExtMemBucket<T> {
    /// Creates an empty bucket on `phys` with the default buffer threshold.
    fn new(phys: Arc<PhysicalFile>) -> Self {
        Self::with_buffer_records(phys, DEFAULT_BUFFER_RECORDS)
    }

    /// Creates an empty bucket on `phys` that buffers up to `buffer_records`
    /// records in memory.
    fn with_buffer_records(phys: Arc<PhysicalFile>, buffer_records: usize) -> Self {
        Self {
            phys,
            buf: Vec::with_capacity(buffer_records),
            buffer_records,
            extents: Vec::new(),
            boundaries: vec![0],
            on_disk: 0,
        }
    }

    /// Writes `rs` to the shared file and records where it landed.
    ///
    /// Extent lengths are stored as `u32` byte counts, so a slice whose
    /// encoding exceeds `u32::MAX` bytes is split into several consecutive
    /// extents instead of having its length silently truncated.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `T::SIZE` is zero or larger than `u32::MAX`
    /// (no record could be described by an extent), or any I/O error from the
    /// write. If a later piece of a split write fails, the pieces already
    /// written stay recorded.
    fn write_extent(&mut self, rs: &[T]) -> io::Result<()> {
        if rs.is_empty() {
            return Ok(());
        }
        let max_records = match (u32::MAX as usize).checked_div(T::SIZE) {
            Some(n) if n > 0 => n,
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "record size must be between 1 and u32::MAX bytes to fit in an extent",
                ));
            }
        };
        for piece in rs.chunks(max_records) {
            self.write_single_extent(piece)?;
        }
        Ok(())
    }

    /// Encodes `rs`, claims a byte range on the shared file and writes one extent.
    fn write_single_extent(&mut self, rs: &[T]) -> io::Result<()> {
        let n_bytes = rs.len() * T::SIZE;
        // Checked conversion: never record a length that differs from what was written.
        let extent_len = u32::try_from(n_bytes).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "single extent exceeds u32::MAX bytes",
            )
        })?;
        let mut scratch = vec![0u8; n_bytes];
        for (slot, record) in scratch.chunks_exact_mut(T::SIZE).zip(rs) {
            record.write_to(slot);
        }
        // Claim a private byte range first; other buckets sharing the file
        // claim theirs the same way, so ranges never overlap.
        let offset = self
            .phys
            .cursor
            .fetch_add(n_bytes as u64, Ordering::Relaxed);
        pwrite_all(&self.phys.file, &scratch, offset)?;
        self.extents.push((offset, extent_len));
        self.on_disk += rs.len();
        Ok(())
    }

    /// Writes the buffered records out as an extent and empties the buffer,
    /// keeping its allocation. Does nothing when the buffer is empty.
    ///
    /// # Errors
    /// Returns the error from `write_extent`; in that case the buffered
    /// records are not retained.
    fn flush(&mut self) -> io::Result<()> {
        if self.buf.is_empty() {
            return Ok(());
        }
        // Move the buffer out so `write_extent` can borrow `self` mutably.
        let buf = std::mem::take(&mut self.buf);
        self.write_extent(&buf)?;
        self.buf = buf;
        self.buf.clear();
        Ok(())
    }

    /// Total number of records accepted so far (written plus buffered).
    pub fn total_records(&self) -> usize {
        self.on_disk + self.buf.len()
    }
}
/// `BucketStore` implementation that spills to extents of the shared file.
impl<T: BucketRecord> BucketStore<T> for PooledExtMemBucket<T> {
    /// Buffers `rs` when it fits in the remaining buffer space; otherwise
    /// spills the buffer first (preserving order) and writes `rs` directly as
    /// its own extent.
    fn add_slice(&mut self, rs: &[T]) -> io::Result<()> {
        let fits_in_buffer = self.buf.len() + rs.len() <= self.buffer_records;
        if !fits_in_buffer {
            self.flush()?;
            return self.write_extent(rs);
        }
        self.buf.extend_from_slice(rs);
        Ok(())
    }

    /// Pushes the current record count as a boundary unless it equals the
    /// most recent one.
    fn mark_boundary(&mut self) {
        let previous = self.boundaries.last().copied().unwrap();
        let current = PooledExtMemBucket::total_records(self);
        if previous == current {
            return;
        }
        self.boundaries.push(current);
    }

    /// Total number of records accepted so far (written plus buffered).
    fn total_records(&self) -> usize {
        PooledExtMemBucket::total_records(self)
    }

    /// Boundary offsets recorded so far.
    fn boundaries(&self) -> &[usize] {
        self.boundaries.as_slice()
    }

    /// Spills the buffer, then reads every extent back in write order and
    /// decodes its records.
    fn load_all(&mut self) -> io::Result<Vec<T>> {
        self.flush()?;
        let mut records = Vec::with_capacity(PooledExtMemBucket::total_records(self));
        // One byte buffer reused across extents; it only ever grows.
        let mut raw: Vec<u8> = Vec::new();
        for &(offset, extent_len) in self.extents.iter() {
            let extent_len = extent_len as usize;
            let n_records = extent_len / T::SIZE;
            if raw.len() < extent_len {
                raw.resize(extent_len, 0);
            }
            let window = &mut raw[..extent_len];
            pread_all(&self.phys.file, window, offset)?;
            records.extend(
                (0..n_records).map(|i| T::read_from(&window[i * T::SIZE..(i + 1) * T::SIZE])),
            );
        }
        Ok(records)
    }
}
/// Writes all of `buf` to `file` starting at byte `offset`, using positioned
/// writes that leave the file cursor untouched.
///
/// Delegates to the standard library's `FileExt::write_all_at`, which retries
/// short writes and `Interrupted` errors.
///
/// # Errors
/// Returns `WriteZero` if a write makes no progress, or the first
/// non-`Interrupted` I/O error.
#[cfg(unix)]
fn pwrite_all(file: &File, buf: &[u8], offset: u64) -> io::Result<()> {
    file.write_all_at(buf, offset)
}
/// Fills all of `buf` from `file` starting at byte `offset`, using positioned
/// reads that leave the file cursor untouched.
///
/// Delegates to the standard library's `FileExt::read_exact_at`, which retries
/// short reads and `Interrupted` errors.
///
/// # Errors
/// Returns `UnexpectedEof` if the file ends before `buf` is full, or the first
/// non-`Interrupted` I/O error. The contents of `buf` are unspecified on error.
#[cfg(unix)]
fn pread_all(file: &File, buf: &mut [u8], offset: u64) -> io::Result<()> {
    file.read_exact_at(buf, offset)
}
// `pwrite_all` / `pread_all` are only defined under `#[cfg(unix)]`, so fail the
// build on other targets with an explicit message rather than with
// unresolved-name errors.
#[cfg(not(unix))]
compile_error!(
"PooledExtMemBucket currently requires Unix file extension API; \
add Windows support via seek_read/seek_write if needed."
);
/// Encodes `rs` in order and writes the bytes to `w`.
///
/// Records are staged in batches of at most `CHUNK_RECORDS` in one reusable
/// byte buffer, so each batch costs a single `write_all` call.
///
/// # Errors
/// Returns the first error reported by `w.write_all`.
fn write_records<T: BucketRecord, W: Write>(w: &mut W, rs: &[T]) -> io::Result<()> {
    const CHUNK_RECORDS: usize = 1024;
    let mut staging = vec![0u8; CHUNK_RECORDS * T::SIZE];
    for batch in rs.chunks(CHUNK_RECORDS) {
        // Running end offset of the encoded bytes within `staging`.
        let mut filled = 0;
        for record in batch {
            let end = filled + T::SIZE;
            record.write_to(&mut staging[filled..end]);
            filled = end;
        }
        w.write_all(&staging[..filled])?;
    }
    Ok(())
}
/// Reads exactly `count` encoded records from `r` and appends them to `out`.
///
/// Bytes are pulled in batches of at most `CHUNK_RECORDS` records through one
/// reusable staging buffer.
///
/// # Errors
/// Returns the error from `read_exact` (e.g. `UnexpectedEof` when `r` runs out
/// early); records from batches decoded before the failure stay in `out`.
#[allow(dead_code)]
fn read_records<T: BucketRecord, R: Read>(
    r: &mut R,
    count: usize,
    out: &mut Vec<T>,
) -> io::Result<()> {
    const CHUNK_RECORDS: usize = 1024;
    let mut staging = vec![0u8; CHUNK_RECORDS * T::SIZE];
    let mut done = 0;
    while done < count {
        let batch = CHUNK_RECORDS.min(count - done);
        let window = &mut staging[..batch * T::SIZE];
        r.read_exact(window)?;
        out.extend((0..batch).map(|i| T::read_from(&window[i * T::SIZE..(i + 1) * T::SIZE])));
        done += batch;
    }
    Ok(())
}
// Round-trip tests for `ExtMemBucket` and `PooledExtMemBucket`; every test
// works inside its own temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
// Ten records added one at a time to a default-sized bucket come back from
// `load_all` unchanged and in order.
#[test]
fn round_trip_below_buffer() {
let dir = tempdir().unwrap();
let mut b: ExtMemBucket<SaLcp<u64>> = ExtMemBucket::new(dir.path(), "test");
for i in 0..10 {
b.add(SaLcp { pos: i, lcp: i * 2 }).unwrap();
}
assert_eq!(b.total_records(), 10);
let loaded = b.load_all().unwrap();
assert_eq!(loaded.len(), 10);
for (i, r) in loaded.iter().enumerate() {
assert_eq!(
*r,
SaLcp {
pos: i as u64,
lcp: (i * 2) as u64
}
);
}
}
// A buffer of 3 records forces several spills while adding 10 records;
// order must still be preserved.
#[test]
fn round_trip_with_spill() {
let dir = tempdir().unwrap();
let mut b: ExtMemBucket<SaLcp<u64>> =
ExtMemBucket::with_buffer_records(dir.path(), "spill", 3);
for i in 0..10 {
b.add(SaLcp { pos: i, lcp: 0 }).unwrap();
}
assert_eq!(b.total_records(), 10);
let loaded = b.load_all().unwrap();
assert_eq!(
loaded.iter().map(|r| r.pos).collect::<Vec<_>>(),
(0..10u64).collect::<Vec<_>>()
);
}
// A slice larger than the buffer takes the direct-write path of `add_slice`;
// single adds afterwards must land after it.
#[test]
fn add_slice_bulk_path() {
let dir = tempdir().unwrap();
let mut b: ExtMemBucket<SaLcp<u64>> =
ExtMemBucket::with_buffer_records(dir.path(), "bulk", 4);
let mut input: Vec<SaLcp<u64>> = (0..100).map(|i| SaLcp { pos: i, lcp: 0 }).collect();
b.add_slice(&input).unwrap();
assert_eq!(b.total_records(), 100);
for i in 100..103 {
b.add(SaLcp { pos: i, lcp: 0 }).unwrap();
}
input.extend((100..103).map(|i| SaLcp { pos: i, lcp: 0 }));
let loaded = b.load_all().unwrap();
assert_eq!(loaded, input);
}
// Boundaries are recorded at the running record count, and a repeated
// `mark_boundary` with no new records adds nothing.
#[test]
fn boundaries_track_sub_subarrays() {
let dir = tempdir().unwrap();
let mut b: ExtMemBucket<SaLcp<u64>> = ExtMemBucket::new(dir.path(), "bounds");
for i in 0..3 {
b.add(SaLcp { pos: i, lcp: 0 }).unwrap();
}
b.mark_boundary();
for i in 3..7 {
b.add(SaLcp { pos: i, lcp: 0 }).unwrap();
}
b.mark_boundary();
b.mark_boundary();
for i in 7..10 {
b.add(SaLcp { pos: i, lcp: 0 }).unwrap();
}
b.mark_boundary();
assert_eq!(b.boundaries(), &[0, 3, 7, 10]);
let loaded = b.load_all().unwrap();
assert_eq!(loaded.len(), 10);
}
// `load_all` on a bucket that never received a record returns an empty vector.
#[test]
fn empty_bucket() {
let dir = tempdir().unwrap();
let mut b: ExtMemBucket<SaLcp<u64>> = ExtMemBucket::new(dir.path(), "empty");
assert_eq!(b.total_records(), 0);
let loaded = b.load_all().unwrap();
assert!(loaded.is_empty());
}
// Pooled bucket with a 3-record buffer: one-record slices are spilled in
// several extents and read back in order.
#[test]
fn pooled_round_trip_below_and_above_buffer() {
let dir = tempdir().unwrap();
let pool = BucketPool::new(1, dir.path()).unwrap();
let mut b: PooledExtMemBucket<SaLcp<u64>> =
PooledExtMemBucket::with_buffer_records(Arc::clone(&pool.files[0]), 3);
for i in 0..10 {
b.add_slice(&[SaLcp { pos: i, lcp: 0 }]).unwrap();
}
assert_eq!(b.total_records(), 10);
let loaded = b.load_all().unwrap();
assert_eq!(
loaded.iter().map(|r| r.pos).collect::<Vec<_>>(),
(0..10u64).collect::<Vec<_>>()
);
}
// With a single physical file, bucket ids 0 and 1 map to the same file; each
// bucket must still read back only its own records.
#[test]
fn pooled_two_buckets_share_one_file() {
let dir = tempdir().unwrap();
let pool = BucketPool::new(1, dir.path()).unwrap();
let mut a: PooledExtMemBucket<SaLcp<u64>> = pool.new_bucket(0);
let mut b: PooledExtMemBucket<SaLcp<u64>> = pool.new_bucket(1);
for i in 0..10u64 {
a.add_slice(&[SaLcp { pos: i, lcp: 0 }]).unwrap();
b.add_slice(&[SaLcp {
pos: 100 + i,
lcp: 0,
}])
.unwrap();
}
let loaded_a = a.load_all().unwrap();
let loaded_b = b.load_all().unwrap();
assert_eq!(
loaded_a.iter().map(|r| r.pos).collect::<Vec<_>>(),
(0..10u64).collect::<Vec<_>>()
);
assert_eq!(
loaded_b.iter().map(|r| r.pos).collect::<Vec<_>>(),
(100..110u64).collect::<Vec<_>>()
);
}
// Pooled counterpart of the boundary test: one boundary after each added slice.
#[test]
fn pooled_boundaries_track_sub_subarrays() {
let dir = tempdir().unwrap();
let pool = BucketPool::new(2, dir.path()).unwrap();
let mut b: PooledExtMemBucket<SaLcp<u64>> = pool.new_bucket(0);
for chunk in &[0..3, 3..7, 7..10] {
let records: Vec<_> = chunk
.clone()
.map(|i| SaLcp {
pos: i as u64,
lcp: 0,
})
.collect();
b.add_slice(&records).unwrap();
b.mark_boundary();
}
assert_eq!(b.boundaries(), &[0, 3, 7, 10]);
let loaded = b.load_all().unwrap();
assert_eq!(loaded.len(), 10);
}
// `load_all` on a pooled bucket that never received a record returns an
// empty vector.
#[test]
fn pooled_empty_bucket() {
let dir = tempdir().unwrap();
let pool = BucketPool::new(2, dir.path()).unwrap();
let mut b: PooledExtMemBucket<SaLcp<u64>> = pool.new_bucket(0);
assert_eq!(b.total_records(), 0);
let loaded = b.load_all().unwrap();
assert!(loaded.is_empty());
}
}