use std::collections::HashMap;
use std::io::{Read, Seek, SeekFrom, Write};
use tempfile::NamedTempFile;
use tracing::debug;
/// Encoded `parent_slot` value meaning "this record has no parent".
const NO_PARENT_SLOT: u64 = u64::MAX;

/// One swapped-out entry: a cost plus the slot index of its parent, if any.
///
/// The on-disk form is 16 little-endian bytes: `cost` (bytes 0..8) followed by
/// `parent_slot` (bytes 8..16).
#[derive(Clone, Copy, Debug)]
struct SpillRecord {
    /// Cost carried by the record.
    cost: u64,
    /// Parent slot index, or `NO_PARENT_SLOT` when there is none.
    parent_slot: u64,
}

impl SpillRecord {
    /// Builds a record from a cost and an optional parent slot.
    ///
    /// `None` is stored as `NO_PARENT_SLOT`. A parent slot that cannot be represented as
    /// `u64` is also stored as `NO_PARENT_SLOT`.
    fn from_parts(cost: u64, parent_slot: Option<usize>) -> Self {
        let encoded_parent = match parent_slot.map(u64::try_from) {
            Some(Ok(raw)) => raw,
            _ => NO_PARENT_SLOT,
        };
        Self {
            cost,
            parent_slot: encoded_parent,
        }
    }

    /// Decodes the parent slot.
    ///
    /// Returns `None` for the sentinel, or when the stored value does not fit in `usize`.
    fn parent_slot(self) -> Option<usize> {
        match self.parent_slot {
            NO_PARENT_SLOT => None,
            raw => usize::try_from(raw).ok(),
        }
    }

    /// Serializes the record as `cost` then `parent_slot`, each 8 little-endian bytes.
    fn to_bytes(self) -> [u8; 16] {
        let mut encoded = [0_u8; 16];
        let (cost_half, parent_half) = encoded.split_at_mut(8);
        cost_half.copy_from_slice(&self.cost.to_le_bytes());
        parent_half.copy_from_slice(&self.parent_slot.to_le_bytes());
        encoded
    }

    /// Inverse of `to_bytes`. Reads the raw fields back without interpreting them.
    fn from_bytes(bytes: [u8; 16]) -> Self {
        let read_le_u64 = |chunk: &[u8]| {
            let mut word = [0_u8; 8];
            word.copy_from_slice(chunk);
            u64::from_le_bytes(word)
        };
        let (cost_half, parent_half) = bytes.split_at(8);
        Self {
            cost: read_le_u64(cost_half),
            parent_slot: read_le_u64(parent_half),
        }
    }
}
/// Disk-backed table of fixed-size `(cost, parent_slot)` records addressed by slot index.
///
/// Writes are staged in an in-memory buffer. The buffer is written to a private temporary
/// file once it holds `write_buffer_capacity` distinct slots, or on an explicit
/// [`SwapStore::flush`]. [`SwapStore::read_slot`] consults the buffer before the file, so a
/// pending write is always visible to readers.
#[derive(Debug)]
pub(super) struct SwapStore {
    /// Backing file. Slot `n` is stored at byte offset `n * SPILL_RECORD_BYTES`.
    file: NamedTempFile,
    /// Records written since the last successful flush, keyed by slot index.
    write_buffer: HashMap<u64, SpillRecord>,
    /// Number of buffered slots that triggers an automatic flush (clamped to at least 1).
    write_buffer_capacity: usize,
}

impl SwapStore {
    /// Size in bytes of one record in the backing file.
    ///
    /// This follows the on-disk encoding (`SpillRecord::to_bytes` yields `[u8; 16]`) rather
    /// than the in-memory `size_of::<SpillRecord>()`. The file layout therefore cannot drift
    /// from the encoding if the struct's Rust layout ever changes.
    pub(super) const SPILL_RECORD_BYTES: u64 = Self::ENCODED_RECORD_LEN as u64;

    /// Encoded record length as a `usize`, used to size read buffers.
    const ENCODED_RECORD_LEN: usize = 16;

    /// Creates an empty store backed by a fresh temporary file.
    ///
    /// `write_buffer_capacity` is clamped to at least 1.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating the temporary file.
    pub(super) fn new(write_buffer_capacity: usize) -> std::io::Result<Self> {
        let file = tempfile::Builder::new()
            .prefix("revrt-routing-swap-")
            .suffix(".bin")
            .tempfile()?;
        let write_buffer_capacity = write_buffer_capacity.max(1);
        debug!("Swap for Dijkstra graph at {:?}", file.path());
        debug!(
            "Swap buffer capacity set to {} entries",
            write_buffer_capacity
        );
        Ok(Self {
            file,
            write_buffer: HashMap::with_capacity(write_buffer_capacity),
            write_buffer_capacity,
        })
    }

    /// Byte offset of `slot` in the backing file.
    ///
    /// # Errors
    ///
    /// Fails if the offset does not fit in `u64`.
    fn slot_offset(slot: u64) -> std::io::Result<u64> {
        slot.checked_mul(Self::SPILL_RECORD_BYTES)
            .ok_or_else(|| std::io::Error::other("swap slot offset overflow"))
    }

    /// Stages `record` (`(cost, parent_slot)`) for `slot`, replacing any pending record for it.
    ///
    /// Calls [`SwapStore::flush`] once the buffer holds `write_buffer_capacity` slots.
    ///
    /// # Errors
    ///
    /// Fails, without buffering anything, if `slot` does not fit in `u64` or its file offset
    /// overflows. Also fails if the triggered flush fails. In that case the record stays
    /// buffered and is written by a later flush.
    pub(super) fn write_slot(
        &mut self,
        slot: usize,
        record: (u64, Option<usize>),
    ) -> std::io::Result<()> {
        let slot = u64::try_from(slot).map_err(|_| std::io::Error::other("slot overflow"))?;
        // Validate the file offset up front. An unflushable slot must never enter the buffer,
        // where it would make every later flush fail.
        Self::slot_offset(slot)?;
        self.write_buffer
            .insert(slot, SpillRecord::from_parts(record.0, record.1));
        if self.write_buffer.len() >= self.write_buffer_capacity {
            self.flush()?;
        }
        Ok(())
    }

    /// Writes every buffered record to its slot in the backing file, then empties the buffer.
    ///
    /// Records are written in ascending slot order. The buffer is cleared only after every
    /// write succeeds. On error no pending record is lost: reads keep seeing the buffered
    /// values, and the next flush rewrites them.
    ///
    /// # Errors
    ///
    /// Returns the first seek or write error, or a slot-offset overflow.
    pub(super) fn flush(&mut self) -> std::io::Result<()> {
        if self.write_buffer.is_empty() {
            return Ok(());
        }
        debug!("Flushing {} entries to disk", self.write_buffer.len());
        // Copy rather than `drain()`. Draining up front would drop every not-yet-written
        // record if a seek or write failed part-way through.
        let mut pending: Vec<(u64, SpillRecord)> = self
            .write_buffer
            .iter()
            .map(|(&slot, &record)| (slot, record))
            .collect();
        pending.sort_unstable_by_key(|&(slot, _)| slot);

        let file = self.file.as_file_mut();
        for (slot, record) in pending {
            file.seek(SeekFrom::Start(Self::slot_offset(slot)?))?;
            file.write_all(&record.to_bytes())?;
        }
        file.flush()?;
        self.write_buffer.clear();
        Ok(())
    }

    /// Returns the `(cost, parent_slot)` record stored for `slot`.
    ///
    /// A pending buffered write takes precedence over the file contents.
    ///
    /// NOTE(review): a slot inside the file that was never written is not detected. This
    /// covers a gap left by writing a higher slot, which presumably reads back as zeros. It
    /// decodes as whatever bytes are there, so callers are assumed to read only slots they
    /// have written.
    ///
    /// # Errors
    ///
    /// Fails if `slot` does not fit in `u64`, if its offset overflows, or if the read fails.
    /// A slot beyond the end of the file yields `UnexpectedEof`.
    pub(super) fn read_slot(&mut self, slot: usize) -> std::io::Result<(u64, Option<usize>)> {
        let slot = u64::try_from(slot).map_err(|_| std::io::Error::other("slot overflow"))?;
        let record = match self.write_buffer.get(&slot) {
            Some(&record) => record,
            None => {
                let file = self.file.as_file_mut();
                file.seek(SeekFrom::Start(Self::slot_offset(slot)?))?;
                let mut bytes = [0_u8; Self::ENCODED_RECORD_LEN];
                file.read_exact(&mut bytes)?;
                SpillRecord::from_bytes(bytes)
            }
        };
        Ok((record.cost, record.parent_slot()))
    }

    /// Reports whether `slot` has a pending write that has not yet been flushed.
    pub(super) fn slot_in_buffer(&self, slot: usize) -> bool {
        u64::try_from(slot).is_ok_and(|slot| self.write_buffer.contains_key(&slot))
    }
}
impl Drop for SwapStore {
fn drop(&mut self) {
let _ = self.flush();
}
}
#[cfg(test)]
mod tests {
    use std::sync::Barrier;
    use std::thread;

    use super::*;

    #[test]
    fn swap_store_reads_written_slot() {
        let mut swap = SwapStore::new(2).unwrap();
        swap.write_slot(42, (7, Some(55))).unwrap();
        swap.flush().unwrap();
        assert_eq!(swap.read_slot(42).unwrap(), (7, Some(55)));
    }

    #[test]
    fn swap_store_round_trips_missing_parent() {
        let mut swap = SwapStore::new(4).unwrap();
        swap.write_slot(3, (9, None)).unwrap();
        swap.flush().unwrap();
        assert_eq!(swap.read_slot(3).unwrap(), (9, None));
    }

    /// A read never flushes: it is served from the pending buffer until `flush` runs.
    #[test]
    fn swap_store_reads_pending_write_before_flush() {
        let mut swap = SwapStore::new(8).unwrap();
        swap.write_slot(4, (11, Some(3))).unwrap();
        assert!(swap.slot_in_buffer(4));
        assert_eq!(swap.read_slot(4).unwrap(), (11, Some(3)));
        assert!(swap.slot_in_buffer(4));
        swap.flush().unwrap();
        assert!(!swap.slot_in_buffer(4));
        assert_eq!(swap.read_slot(4).unwrap(), (11, Some(3)));
    }

    #[test]
    fn swap_store_read_prefers_pending_buffer() {
        let mut swap = SwapStore::new(8).unwrap();
        swap.write_slot(9, (5, Some(1))).unwrap();
        swap.flush().unwrap();
        swap.write_slot(9, (8, Some(7))).unwrap();
        assert_eq!(swap.read_slot(9).unwrap(), (8, Some(7)));
        swap.flush().unwrap();
        assert_eq!(swap.read_slot(9).unwrap(), (8, Some(7)));
    }

    #[test]
    fn swap_store_flushes_automatically_at_capacity() {
        let mut swap = SwapStore::new(2).unwrap();
        swap.write_slot(0, (1, None)).unwrap();
        assert!(swap.slot_in_buffer(0));
        // Overwriting a pending slot does not grow the buffer, so no flush yet.
        swap.write_slot(0, (2, None)).unwrap();
        assert!(swap.slot_in_buffer(0));
        swap.write_slot(1, (3, Some(0))).unwrap();
        assert!(!swap.slot_in_buffer(0));
        assert!(!swap.slot_in_buffer(1));
        assert_eq!(swap.read_slot(0).unwrap(), (2, None));
        assert_eq!(swap.read_slot(1).unwrap(), (3, Some(0)));
    }

    #[test]
    fn swap_store_zero_capacity_flushes_every_write() {
        let mut swap = SwapStore::new(0).unwrap();
        swap.write_slot(2, (5, Some(1))).unwrap();
        assert!(!swap.slot_in_buffer(2));
        assert_eq!(swap.read_slot(2).unwrap(), (5, Some(1)));
    }

    #[test]
    fn swap_store_flushes_out_of_order_slots() {
        let mut swap = SwapStore::new(8).unwrap();
        let writes = [(5, (50, Some(1))), (1, (10, None)), (3, (30, Some(5)))];
        for (slot, record) in writes {
            swap.write_slot(slot, record).unwrap();
        }
        swap.flush().unwrap();
        for (slot, record) in writes {
            assert_eq!(swap.read_slot(slot).unwrap(), record);
        }
    }

    #[test]
    fn swap_store_read_past_end_of_file_errors() {
        let mut swap = SwapStore::new(4).unwrap();
        let err = swap.read_slot(10).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn swap_store_record_size_matches_encoding() {
        assert_eq!(SwapStore::SPILL_RECORD_BYTES, 16);
    }

    #[test]
    fn swap_store_isolates_parallel_instances() {
        let thread_count = 8;
        // Scoped threads may borrow locals, so plain references replace `Arc`.
        let start_barrier = &Barrier::new(thread_count);
        let read_barrier = &Barrier::new(thread_count);
        thread::scope(|scope| {
            let handles: Vec<_> = (0..thread_count)
                .map(|worker| {
                    scope.spawn(move || {
                        let mut swap = SwapStore::new(1).unwrap();
                        start_barrier.wait();
                        swap.write_slot(0, (worker as u64, Some(worker))).unwrap();
                        swap.flush().unwrap();
                        // Every worker has flushed before any reads back, so a shared
                        // backing file would surface as a mismatched record.
                        read_barrier.wait();
                        swap.read_slot(0).unwrap()
                    })
                })
                .collect();
            for (worker, handle) in handles.into_iter().enumerate() {
                assert_eq!(handle.join().unwrap(), (worker as u64, Some(worker)));
            }
        });
    }
}