use memmap2::Mmap;
use ndarray::Array2;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Eight-byte magic that must open every shard file.
pub const SHARD_MAGIC: [u8; 8] = *b"GAMSAE01";
/// Header dtype tag for little-endian `f32` payloads (the only dtype `open` accepts).
pub const DTYPE_F32: u32 = 0;
/// Header size in bytes: magic (8), `n_rows` u64 (8), `p` u64 (8), dtype u32 (4),
/// and 4 trailing bytes that the encoder writes as zero.
pub const HEADER_LEN: usize = 32;
/// Cap on the number of payload bytes touched ahead of each batch read (8 MiB).
pub(super) const DEFAULT_PREFETCH_WINDOW_BYTES: usize = 8 << 20;
/// Rows per batch unless the whole corpus is smaller than this.
pub(super) const DEFAULT_BATCH_ROWS: usize = 1 << 10;
/// Failures that can occur while opening or reading corpus shards.
#[derive(Debug)]
pub enum ShardError {
    /// An underlying I/O operation failed (e.g. opening, mapping or listing files).
    Io(std::io::Error),
    /// The file's first eight bytes are not `SHARD_MAGIC`.
    BadMagic { path: PathBuf },
    /// The header's dtype tag is something other than `DTYPE_F32`.
    BadDtype { path: PathBuf, tag: u32 },
    /// The file is shorter than its header requires. `expected` is `usize::MAX`
    /// when the header-declared size cannot even be represented in `usize`.
    Truncated {
        path: PathBuf,
        expected: usize,
        actual: usize,
    },
    /// A shard's row width `found` differs from the width `expected` from earlier shards.
    WidthMismatch {
        expected: usize,
        found: usize,
        path: PathBuf,
    },
    /// The read cursor and the window front refer to different shards.
    /// NOTE(review): not constructed anywhere in this section.
    ResidencyInvariant {
        cursor_shard: usize,
        front_shard: usize,
    },
    /// No shard paths were supplied, or the supplied shards hold zero rows.
    Empty,
}

impl std::fmt::Display for ShardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "shard I/O error: {e}"),
            Self::BadMagic { path } => {
                write!(f, "shard '{}' has wrong magic header", path.display())
            }
            Self::BadDtype { path, tag } => write!(
                f,
                "shard '{}' has unsupported dtype tag {tag} (only f32={DTYPE_F32})",
                path.display()
            ),
            Self::Truncated {
                path,
                expected,
                actual,
            } => write!(
                f,
                "shard '{}' is truncated: header expects {expected} bytes, file has {actual}",
                path.display()
            ),
            Self::WidthMismatch {
                path,
                expected,
                found,
            } => write!(
                f,
                "shard '{}' has width p={found}, expected p={expected}",
                path.display()
            ),
            Self::ResidencyInvariant {
                cursor_shard,
                front_shard,
            } => write!(
                f,
                "shard window residency invariant violated: read cursor is at shard {cursor_shard} but the window front is shard {front_shard}"
            ),
            Self::Empty => f.write_str("shard source has no shards / no rows"),
        }
    }
}

impl std::error::Error for ShardError {}

/// Lets `?` lift raw I/O failures into [`ShardError::Io`].
impl From<std::io::Error> for ShardError {
    fn from(source: std::io::Error) -> Self {
        Self::Io(source)
    }
}
/// A batch of corpus rows paired with their global row ids.
///
/// In batches built by `MmapShardSource`, row `k` of `rows` is the row whose
/// global id is `row_ids[k]`.
#[derive(Debug, Clone)]
pub struct RowBatch {
    /// Row values, one matrix row per batch entry.
    pub rows: Array2<f64>,
    /// Global id of each batch entry, in the same order as `rows`.
    pub row_ids: Vec<u64>,
}

impl RowBatch {
    /// Number of entries in the batch (counted from `row_ids`).
    #[inline]
    pub fn len(&self) -> usize {
        self.row_ids.len()
    }

    /// Whether the batch holds no entries.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A restartable, batched stream of fixed-width corpus rows.
pub trait CorpusRowSource {
    /// Total number of rows the source yields in one full pass.
    fn total_rows(&self) -> u64;
    /// Width (number of columns) of every row.
    fn width(&self) -> usize;
    /// Next batch of rows, or `Ok(None)` once the pass is exhausted.
    fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError>;
    /// Rewind to the start so the next call to `next_batch` begins a new pass.
    fn reset(&mut self);
    /// Configured upper bound on rows per batch.
    fn batch_rows(&self) -> usize;
}
/// One validated, memory-mapped shard file.
///
/// Layout (little-endian), as checked by `MappedShard::open`: bytes 0..8 magic,
/// 8..16 `n_rows` (u64), 16..24 `p` (u64), 24..28 dtype tag (u32), 28..32 not read,
/// then `n_rows * p` f32 cells stored row by row.
struct MappedShard {
    mmap: Arc<Mmap>,
    /// Number of rows stored in this shard.
    n_rows: usize,
    /// Row width: f32 cells per row.
    p: usize,
    /// Byte offset of the first payload cell (always `HEADER_LEN` here).
    data_offset: usize,
    /// Global id of this shard's first row; assigned by `MmapShardSource::open`.
    global_row_base: u64,
}

impl MappedShard {
    /// Bytes per stored cell.
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    /// Map `path` and validate its header against the mapped length.
    ///
    /// # Errors
    /// - `Io` if the file cannot be opened or mapped.
    /// - `Truncated` if the file is shorter than the header, or shorter than the
    ///   header-declared payload. `expected` is `usize::MAX` when the declared
    ///   size does not fit in `usize`.
    /// - `BadMagic` / `BadDtype` for a wrong magic or dtype tag.
    fn open(path: PathBuf) -> Result<Self, ShardError> {
        let file = File::open(&path)?;
        // SAFETY: memmap2 requires that the file is not modified or truncated while
        // mapped. This reader assumes shard files are immutable once written.
        // NOTE(review): nothing in this file enforces that assumption.
        let mmap = unsafe { Mmap::map(&file)? };
        let actual = mmap.len();
        if actual < HEADER_LEN {
            return Err(ShardError::Truncated {
                path,
                expected: HEADER_LEN,
                actual,
            });
        }
        if mmap[0..8] != SHARD_MAGIC {
            return Err(ShardError::BadMagic { path });
        }
        let n_rows_raw = u64::from_le_bytes(mmap[8..16].try_into().expect("8 bytes"));
        let p_raw = u64::from_le_bytes(mmap[16..24].try_into().expect("8 bytes"));
        let dtype = u32::from_le_bytes(mmap[24..28].try_into().expect("4 bytes"));
        if dtype != DTYPE_F32 {
            return Err(ShardError::BadDtype { path, tag: dtype });
        }
        // Convert with `try_from` rather than `as`, which would silently truncate on
        // 32-bit targets. Check every step of header + payload size, including the
        // final `+ HEADER_LEN`, so a crafted header cannot wrap past validation.
        let layout = usize::try_from(n_rows_raw)
            .ok()
            .zip(usize::try_from(p_raw).ok())
            .and_then(|(n_rows, p)| {
                let expected = n_rows
                    .checked_mul(p)?
                    .checked_mul(Self::F32_BYTES)?
                    .checked_add(HEADER_LEN)?;
                Some((n_rows, p, expected))
            });
        let (n_rows, p, expected) = match layout {
            Some(layout) => layout,
            None => {
                return Err(ShardError::Truncated {
                    path,
                    expected: usize::MAX,
                    actual,
                })
            }
        };
        if actual < expected {
            return Err(ShardError::Truncated {
                path,
                expected,
                actual,
            });
        }
        Ok(Self {
            mmap: Arc::new(mmap),
            n_rows,
            p,
            data_offset: HEADER_LEN,
            global_row_base: 0,
        })
    }

    /// Decode row `local_row` into `out`, widening each f32 cell to f64.
    ///
    /// # Panics
    /// If `out.len() != self.p`, or if `local_row` lies past the mapped payload
    /// (the slice index panics).
    #[inline]
    fn read_row_into(&self, local_row: usize, out: &mut [f64]) {
        assert_eq!(out.len(), self.p);
        let row_bytes = self.p * Self::F32_BYTES;
        let start = self.data_offset + local_row * row_bytes;
        let bytes = &self.mmap[start..start + row_bytes];
        for (slot, cell) in out.iter_mut().zip(bytes.chunks_exact(Self::F32_BYTES)) {
            let lane = f32::from_le_bytes(cell.try_into().expect("chunk is 4 bytes"));
            *slot = f64::from(lane);
        }
    }

    /// Best-effort read-ahead.
    ///
    /// Performs a volatile one-byte read in every 4 KiB page that overlaps
    /// `[byte_start, byte_start + window)`, clamped to the end of the payload.
    /// The intent is for those pages to be resident before rows are copied out.
    fn prefetch(&self, byte_start: usize, window: usize) {
        const PAGE: usize = 4096;
        let payload_end = self.data_offset + self.n_rows * self.p * Self::F32_BYTES;
        let end = byte_start.saturating_add(window).min(payload_end);
        let base = self.mmap.as_ptr();
        let mut off = byte_start;
        while off < end {
            // SAFETY: `off < end <= payload_end <= self.mmap.len()` (the payload size was
            // validated in `open`), so the read stays inside the live mapping.
            let _ = unsafe { std::ptr::read_volatile(base.add(off)) };
            // Jump to the start of the next page. Stepping `off += PAGE` from an
            // unaligned `byte_start` could skip the last page overlapping the range.
            off = (off / PAGE + 1) * PAGE;
        }
    }
}
/// Sequential [`CorpusRowSource`] over an ordered list of memory-mapped shard files.
///
/// Shards are read in the order given, rows within a shard in file order. Global
/// row ids are contiguous across shards. A batch never spans two shards, so the
/// last batch of each shard may be shorter than `batch_rows`.
pub struct MmapShardSource {
    /// Shards in read order, each carrying its global row offset.
    shards: Vec<MappedShard>,
    /// Row width shared by every shard.
    p: usize,
    /// Sum of all shards' row counts.
    total_rows: u64,
    /// Upper bound on rows per batch.
    batch_rows: usize,
    /// Cap on bytes prefetched ahead of each batch.
    prefetch_window_bytes: usize,
    /// Index of the shard the next batch reads from (`== shards.len()` at end).
    cursor_shard: usize,
    /// Next unread row within `shards[cursor_shard]`.
    cursor_local_row: usize,
}

impl MmapShardSource {
    /// Open and validate every shard in `paths`, in order.
    ///
    /// # Errors
    /// - `Empty` if `paths` is empty or the shards hold no rows.
    /// - `WidthMismatch` if a shard's width differs from the first shard's.
    /// - Any error from opening an individual shard.
    pub fn open(paths: &[PathBuf]) -> Result<Self, ShardError> {
        if paths.is_empty() {
            return Err(ShardError::Empty);
        }
        let mut shards = Vec::with_capacity(paths.len());
        let mut p: Option<usize> = None;
        let mut running_base: u64 = 0;
        for path in paths {
            let mut shard = MappedShard::open(path.clone())?;
            match p {
                None => p = Some(shard.p),
                Some(expected) if expected != shard.p => {
                    return Err(ShardError::WidthMismatch {
                        expected,
                        found: shard.p,
                        path: path.clone(),
                    });
                }
                Some(_) => {}
            }
            shard.global_row_base = running_base;
            running_base = running_base.saturating_add(shard.n_rows as u64);
            shards.push(shard);
        }
        let p = p.ok_or(ShardError::Empty)?;
        let total_rows = running_base;
        if total_rows == 0 {
            return Err(ShardError::Empty);
        }
        // `total_rows as usize` truncates on 32-bit targets (e.g. 2^32 rows -> 0),
        // which would shrink batches to a single row; saturate instead.
        let total_rows_usize = usize::try_from(total_rows).unwrap_or(usize::MAX);
        let batch_rows = DEFAULT_BATCH_ROWS.min(total_rows_usize).max(1);
        Ok(Self {
            shards,
            p,
            total_rows,
            batch_rows,
            prefetch_window_bytes: DEFAULT_PREFETCH_WINDOW_BYTES,
            cursor_shard: 0,
            cursor_local_row: 0,
        })
    }

    /// Open every `*.shard` file directly inside `dir`, ordered by file name.
    ///
    /// # Errors
    /// - `Io` if the directory cannot be listed.
    /// - `Empty` if no `*.shard` entries exist.
    /// - Any error from [`MmapShardSource::open`].
    pub fn open_dir(dir: &Path) -> Result<Self, ShardError> {
        let mut paths: Vec<PathBuf> = Vec::new();
        for entry in std::fs::read_dir(dir)? {
            let entry = entry?;
            let path = entry.path();
            if path.extension().and_then(|e| e.to_str()) == Some("shard") {
                paths.push(path);
            }
        }
        // `read_dir` order is platform-dependent; sort for a deterministic row order.
        paths.sort_by(|a, b| a.file_name().cmp(&b.file_name()));
        if paths.is_empty() {
            return Err(ShardError::Empty);
        }
        Self::open(&paths)
    }

    /// Whether the cursor has moved past the last shard.
    #[inline]
    fn at_end(&self) -> bool {
        self.cursor_shard >= self.shards.len()
    }

    /// Advance the cursor past fully consumed (or empty) shards.
    fn skip_drained_shards(&mut self) {
        while self.cursor_shard < self.shards.len()
            && self.cursor_local_row >= self.shards[self.cursor_shard].n_rows
        {
            self.cursor_shard += 1;
            self.cursor_local_row = 0;
        }
    }
}
impl CorpusRowSource for MmapShardSource {
    fn total_rows(&self) -> u64 {
        self.total_rows
    }

    fn width(&self) -> usize {
        self.p
    }

    fn batch_rows(&self) -> usize {
        self.batch_rows
    }

    /// Rewind to the first row of the first shard.
    fn reset(&mut self) {
        self.cursor_shard = 0;
        self.cursor_local_row = 0;
    }

    /// Read up to `batch_rows` rows from the current shard (never crossing into the
    /// next one). Rows are widened to f64 and tagged with their global ids.
    fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError> {
        self.skip_drained_shards();
        if self.at_end() {
            return Ok(None);
        }

        let start = self.cursor_local_row;
        let shard = &self.shards[self.cursor_shard];
        let take = (shard.n_rows - start).min(self.batch_rows);

        // Warm the pages this batch will read, capped at the prefetch window.
        let row_bytes = shard.p * std::mem::size_of::<f32>();
        let first_byte = shard.data_offset + start * row_bytes;
        shard.prefetch(first_byte, (take * row_bytes).min(self.prefetch_window_bytes));

        let mut rows = Array2::<f64>::zeros((take, self.p));
        for (local, mut dst) in (start..start + take).zip(rows.outer_iter_mut()) {
            let slice = dst
                .as_slice_mut()
                .expect("freshly allocated contiguous row");
            shard.read_row_into(local, slice);
        }
        let row_ids = (start..start + take)
            .map(|local| shard.global_row_base + local as u64)
            .collect();

        self.cursor_local_row = start + take;
        self.skip_drained_shards();
        Ok(Some(RowBatch { rows, row_ids }))
    }
}
/// Serialize `rows` into the on-disk shard format read by `MmapShardSource`.
///
/// Writes a `HEADER_LEN`-byte header (magic, `n_rows`, `p`, dtype tag, 4 zero
/// bytes), then every value narrowed to little-endian f32 in row-major order.
/// The f64 -> f32 narrowing is lossy.
pub fn encode_shard_bytes(rows: ndarray::ArrayView2<'_, f64>) -> Vec<u8> {
    let (n_rows, p) = rows.dim();
    let cell_bytes = std::mem::size_of::<f32>();
    let mut buf = Vec::with_capacity(HEADER_LEN + n_rows * p * cell_bytes);

    buf.extend_from_slice(&SHARD_MAGIC);
    for word in [n_rows as u64, p as u64] {
        buf.extend_from_slice(&word.to_le_bytes());
    }
    buf.extend_from_slice(&DTYPE_F32.to_le_bytes());
    buf.extend_from_slice(&[0u8; 4]);

    // `iter()` walks the view in logical (row-major) order regardless of its strides.
    buf.extend(rows.iter().flat_map(|&v| (v as f32).to_le_bytes()));
    buf
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use std::io::Write;

    /// Temporary shard file, deleted on drop so a failing assertion does not leak it.
    ///
    /// Declare the guard before any `MmapShardSource` that maps it. Locals drop in
    /// reverse order, so the mapping is released before the file is removed.
    struct TempShard {
        path: PathBuf,
    }

    impl TempShard {
        /// Write `bytes` to a per-process unique temp file named after `name`.
        fn with_bytes(name: &str, bytes: &[u8]) -> Self {
            let mut path = std::env::temp_dir();
            path.push(format!(
                "gam-sae-corpus-test-{}-{}.shard",
                std::process::id(),
                name
            ));
            // Create the guard first so a failed write still cleans up.
            let shard = Self { path };
            let mut f = File::create(&shard.path).expect("create temp shard");
            f.write_all(bytes).expect("write shard");
            f.sync_all().expect("sync shard");
            shard
        }
    }

    impl Drop for TempShard {
        fn drop(&mut self) {
            std::fs::remove_file(&self.path).ok();
        }
    }

    /// Encode `rows` and write them to a guarded temp shard.
    fn write_temp_shard(name: &str, rows: ndarray::ArrayView2<'_, f64>) -> TempShard {
        TempShard::with_bytes(name, &encode_shard_bytes(rows))
    }

    #[test]
    fn single_shard_round_trips_rows_and_ids() {
        let data = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let shard = write_temp_shard("single", data.view());
        let mut src = MmapShardSource::open(&[shard.path.clone()]).expect("open");
        assert_eq!(src.total_rows(), 3);
        assert_eq!(src.width(), 3);
        let batch = src.next_batch().expect("batch").expect("some");
        assert_eq!(batch.row_ids, vec![0, 1, 2]);
        assert_eq!(batch.rows, data);
        assert!(src.next_batch().expect("end").is_none());
    }

    #[test]
    fn multi_shard_global_ids_are_contiguous() {
        let a = array![[1.0_f64], [2.0]];
        let b = array![[3.0_f64], [4.0], [5.0]];
        let sa = write_temp_shard("multi-a", a.view());
        let sb = write_temp_shard("multi-b", b.view());
        let mut src = MmapShardSource::open(&[sa.path.clone(), sb.path.clone()]).expect("open");
        assert_eq!(src.total_rows(), 5);
        let mut all_ids = Vec::new();
        let mut all_vals = Vec::new();
        while let Some(batch) = src.next_batch().expect("batch") {
            all_ids.extend(batch.row_ids.iter().copied());
            all_vals.extend(batch.rows.outer_iter().map(|r| r[0]));
        }
        assert_eq!(all_ids, vec![0, 1, 2, 3, 4]);
        assert_eq!(all_vals, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn reset_replays_identical_sequence() {
        let data = array![[1.0_f64, 1.0], [2.0, 2.0]];
        let shard = write_temp_shard("reset", data.view());
        let mut src = MmapShardSource::open(&[shard.path.clone()]).expect("open");
        let mut drain = |src: &mut MmapShardSource| {
            let mut ids = Vec::new();
            while let Some(b) = src.next_batch().expect("b") {
                ids.extend(b.row_ids);
            }
            ids
        };
        let first = drain(&mut src);
        src.reset();
        let second = drain(&mut src);
        assert_eq!(first, second);
    }

    #[test]
    fn bad_magic_is_rejected() {
        let shard = TempShard::with_bytes("badmagic", &[0u8; 64]);
        let err = MmapShardSource::open(&[shard.path.clone()]);
        assert!(matches!(err, Err(ShardError::BadMagic { .. })));
    }

    #[test]
    fn truncated_payload_is_rejected() {
        // 2x2 f32 payload => 32 + 16 = 48 bytes; drop the final byte.
        let mut bytes = encode_shard_bytes(array![[1.0_f64, 2.0], [3.0, 4.0]].view());
        bytes.pop();
        let shard = TempShard::with_bytes("truncated", &bytes);
        let err = MmapShardSource::open(&[shard.path.clone()]);
        assert!(matches!(
            err,
            Err(ShardError::Truncated {
                expected: 48,
                actual: 47,
                ..
            })
        ));
    }

    #[test]
    fn width_mismatch_is_rejected() {
        let a = write_temp_shard("width-a", array![[1.0_f64]].view());
        let b = write_temp_shard("width-b", array![[1.0_f64, 2.0]].view());
        let err = MmapShardSource::open(&[a.path.clone(), b.path.clone()]);
        assert!(matches!(
            err,
            Err(ShardError::WidthMismatch {
                expected: 1,
                found: 2,
                ..
            })
        ));
    }

    #[test]
    fn empty_path_list_is_rejected() {
        assert!(matches!(
            MmapShardSource::open(&[]),
            Err(ShardError::Empty)
        ));
    }
}