1use memmap2::Mmap;
46use ndarray::Array2;
47use std::fs::File;
48use std::path::{Path, PathBuf};
49use std::sync::Arc;
50
/// Eight-byte magic that must appear at the very start of every shard file.
pub const SHARD_MAGIC: [u8; 8] = *b"GAMSAE01";
/// Header dtype tag for an `f32` payload; stored as a little-endian `u32` at header
/// bytes 24..28. It is the only tag the reader in this file accepts.
pub const DTYPE_F32: u32 = 0;
/// Size in bytes of the fixed shard header: 8-byte magic, `u64` row count, `u64` row width,
/// `u32` dtype tag, and a trailing `u32` that the encoder writes as zero.
pub const HEADER_LEN: usize = 32;

/// Default upper bound (8 MiB) on the number of payload bytes touched ahead of a batch read.
pub(super) const DEFAULT_PREFETCH_WINDOW_BYTES: usize = 8 * 1024 * 1024;

/// Default upper bound on the number of rows returned in a single batch.
pub(super) const DEFAULT_BATCH_ROWS: usize = 1024;
69
/// Errors produced while opening, validating, or streaming shard files.
#[derive(Debug)]
pub enum ShardError {
    /// An underlying I/O operation failed (opening, mapping, or listing files).
    Io(std::io::Error),
    /// The first eight bytes of the file are not `SHARD_MAGIC`.
    BadMagic {
        /// File whose magic did not match.
        path: PathBuf,
    },
    /// The header carries a dtype tag other than `DTYPE_F32`.
    BadDtype {
        /// File with the unsupported tag.
        path: PathBuf,
        /// The tag value read from the header.
        tag: u32,
    },
    /// The file is shorter than its header, or shorter than the header says it should be.
    Truncated {
        /// File that is too short.
        path: PathBuf,
        /// Number of bytes required; `usize::MAX` when the required size is not
        /// representable.
        expected: usize,
        /// Actual length of the mapped file in bytes.
        actual: usize,
    },
    /// A shard's row width differs from the width of the first shard opened.
    WidthMismatch {
        /// Width established by the first shard.
        expected: usize,
        /// Width found in the offending shard.
        found: usize,
        /// File with the mismatching width.
        path: PathBuf,
    },
    /// The read cursor and the front of the shard window disagree.
    ///
    /// NOTE(review): nothing in the visible part of this file constructs this variant;
    /// presumably it is raised by a windowed source defined elsewhere — confirm.
    ResidencyInvariant {
        /// Shard index the read cursor points at.
        cursor_shard: usize,
        /// Shard index at the front of the window.
        front_shard: usize,
    },
    /// No shard paths were supplied, or the shards contain no rows in total.
    Empty,
}
103
104impl std::fmt::Display for ShardError {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 match self {
107 ShardError::Io(e) => write!(f, "shard I/O error: {e}"),
108 ShardError::BadMagic { path } => {
109 write!(f, "shard '{}' has wrong magic header", path.display())
110 }
111 ShardError::BadDtype { path, tag } => write!(
112 f,
113 "shard '{}' has unsupported dtype tag {tag} (only f32={DTYPE_F32})",
114 path.display()
115 ),
116 ShardError::Truncated {
117 path,
118 expected,
119 actual,
120 } => write!(
121 f,
122 "shard '{}' is truncated: header expects {expected} bytes, file has {actual}",
123 path.display()
124 ),
125 ShardError::WidthMismatch {
126 expected,
127 found,
128 path,
129 } => write!(
130 f,
131 "shard '{}' has width p={found}, expected p={expected}",
132 path.display()
133 ),
134 ShardError::ResidencyInvariant {
135 cursor_shard,
136 front_shard,
137 } => write!(
138 f,
139 "shard window residency invariant violated: read cursor is at shard {cursor_shard} but the window front is shard {front_shard}"
140 ),
141 ShardError::Empty => write!(f, "shard source has no shards / no rows"),
142 }
143 }
144}
145
/// Marks `ShardError` as a standard error type. No trait methods are overridden, so the
/// wrapped I/O error of `ShardError::Io` is surfaced through the `Display` text rather than
/// through `source()`.
impl std::error::Error for ShardError {}
147
148impl From<std::io::Error> for ShardError {
149 fn from(e: std::io::Error) -> Self {
150 ShardError::Io(e)
151 }
152}
153
/// A contiguous group of corpus rows together with their global row ids.
#[derive(Debug, Clone)]
pub struct RowBatch {
    /// Row values widened to `f64`; one array row per entry of `row_ids`.
    pub rows: Array2<f64>,
    /// Global id of each row, in the same order as the rows of `rows`.
    pub row_ids: Vec<u64>,
}
165
166impl RowBatch {
167 #[inline]
168 pub fn len(&self) -> usize {
169 self.row_ids.len()
170 }
171
172 #[inline]
173 pub fn is_empty(&self) -> bool {
174 self.row_ids.is_empty()
175 }
176}
177
/// A resettable, forward-only source of corpus rows delivered in batches.
pub trait CorpusRowSource {
    /// Total number of rows the source yields over one full pass.
    fn total_rows(&self) -> u64;
    /// Number of columns in every row.
    fn width(&self) -> usize;
    /// Returns the next batch of rows, `Ok(None)` once the source is exhausted, or an error
    /// if the batch could not be produced.
    fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError>;
    /// Rewinds the source so that the next `next_batch` call starts from the first row again.
    fn reset(&mut self);
    /// Upper bound on the number of rows in a single batch.
    // NOTE(review): the mmap-backed implementation in this file may return fewer rows at a
    // shard boundary; confirm whether other implementors share that behavior.
    fn batch_rows(&self) -> usize;
}
203
/// One shard file mapped into memory, plus the header fields needed to index its rows.
struct MappedShard {
    /// Read-only mapping of the whole file (header and payload).
    mmap: Arc<Mmap>,
    /// Row count read from the header.
    n_rows: usize,
    /// Row width (values per row) read from the header.
    p: usize,
    /// Byte offset of the first payload byte; set to `HEADER_LEN` when the shard is opened.
    data_offset: usize,
    /// Global id of this shard's first row; zero until the owning source assigns it.
    global_row_base: u64,
}
213
214impl MappedShard {
215 fn open(path: PathBuf) -> Result<Self, ShardError> {
216 let file = File::open(&path)?;
217 let mmap = unsafe { Mmap::map(&file)? };
221 if mmap.len() < HEADER_LEN {
222 return Err(ShardError::Truncated {
223 path,
224 expected: HEADER_LEN,
225 actual: mmap.len(),
226 });
227 }
228 if mmap[0..8] != SHARD_MAGIC {
229 return Err(ShardError::BadMagic { path });
230 }
231 let n_rows = u64::from_le_bytes(mmap[8..16].try_into().expect("8 bytes")) as usize;
232 let p = u64::from_le_bytes(mmap[16..24].try_into().expect("8 bytes")) as usize;
233 let dtype = u32::from_le_bytes(mmap[24..28].try_into().expect("4 bytes"));
234 if dtype != DTYPE_F32 {
235 return Err(ShardError::BadDtype { path, tag: dtype });
236 }
237 let payload_bytes = n_rows
238 .checked_mul(p)
239 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f32>()))
240 .ok_or_else(|| ShardError::Truncated {
241 path: path.clone(),
242 expected: usize::MAX,
243 actual: mmap.len(),
244 })?;
245 let expected = HEADER_LEN + payload_bytes;
246 if mmap.len() < expected {
247 return Err(ShardError::Truncated {
248 path,
249 expected,
250 actual: mmap.len(),
251 });
252 }
253 Ok(Self {
254 mmap: Arc::new(mmap),
255 n_rows,
256 p,
257 data_offset: HEADER_LEN,
258 global_row_base: 0,
259 })
260 }
261
262 #[inline]
264 fn read_row_into(&self, local_row: usize, out: &mut [f64]) {
265 assert_eq!(out.len(), self.p);
266 let byte_start = self.data_offset + local_row * self.p * std::mem::size_of::<f32>();
267 let bytes = &self.mmap[byte_start..byte_start + self.p * std::mem::size_of::<f32>()];
268 for (c, slot) in out.iter_mut().enumerate() {
269 let b = c * std::mem::size_of::<f32>();
270 let lane = f32::from_le_bytes(bytes[b..b + 4].try_into().expect("4 bytes"));
271 *slot = f64::from(lane);
272 }
273 }
274
275 fn prefetch(&self, byte_start: usize, window: usize) {
278 let payload_end = self.data_offset + self.n_rows * self.p * std::mem::size_of::<f32>();
279 let end = byte_start.saturating_add(window).min(payload_end);
280 if end <= byte_start {
281 return;
282 }
283 let page = 4096usize;
288 let base = self.mmap.as_ptr();
289 let mut off = byte_start;
290 while off < end {
291 unsafe {
294 std::ptr::read_volatile(base.add(off));
295 }
296 off += page;
297 }
298 }
299}
300
/// A `CorpusRowSource` that streams rows from an ordered list of memory-mapped shard files.
pub struct MmapShardSource {
    /// Mapped shards, in the order their paths were supplied.
    shards: Vec<MappedShard>,
    /// Row width shared by every shard.
    p: usize,
    /// Sum of the row counts of all shards.
    total_rows: u64,
    /// Maximum number of rows returned per batch.
    batch_rows: usize,
    /// Cap on the number of bytes touched ahead of each batch read.
    prefetch_window_bytes: usize,
    /// Index of the shard the next batch is read from; equals `shards.len()` once every
    /// shard has been drained.
    cursor_shard: usize,
    /// Index, within the current shard, of the next row to read.
    cursor_local_row: usize,
}
314
315impl MmapShardSource {
316 pub fn open(paths: &[PathBuf]) -> Result<Self, ShardError> {
320 if paths.is_empty() {
321 return Err(ShardError::Empty);
322 }
323 let mut shards = Vec::with_capacity(paths.len());
324 let mut p: Option<usize> = None;
325 let mut running_base: u64 = 0;
326 for path in paths {
327 let mut shard = MappedShard::open(path.clone())?;
328 match p {
329 None => p = Some(shard.p),
330 Some(expected) if expected != shard.p => {
331 return Err(ShardError::WidthMismatch {
332 expected,
333 found: shard.p,
334 path: path.clone(),
335 });
336 }
337 Some(_) => {}
338 }
339 shard.global_row_base = running_base;
340 running_base = running_base.saturating_add(shard.n_rows as u64);
341 shards.push(shard);
342 }
343 let p = p.ok_or(ShardError::Empty)?;
344 let total_rows = running_base;
345 if total_rows == 0 {
346 return Err(ShardError::Empty);
347 }
348 let batch_rows = DEFAULT_BATCH_ROWS.min(total_rows as usize).max(1);
351 Ok(Self {
352 shards,
353 p,
354 total_rows,
355 batch_rows,
356 prefetch_window_bytes: DEFAULT_PREFETCH_WINDOW_BYTES,
357 cursor_shard: 0,
358 cursor_local_row: 0,
359 })
360 }
361
362 pub fn open_dir(dir: &Path) -> Result<Self, ShardError> {
366 let mut paths: Vec<PathBuf> = Vec::new();
367 for entry in std::fs::read_dir(dir)? {
368 let entry = entry?;
369 let path = entry.path();
370 if path.extension().and_then(|e| e.to_str()) == Some("shard") {
371 paths.push(path);
372 }
373 }
374 paths.sort_by(|a, b| a.file_name().cmp(&b.file_name()));
377 if paths.is_empty() {
378 return Err(ShardError::Empty);
379 }
380 Self::open(&paths)
381 }
382
383 #[inline]
385 fn at_end(&self) -> bool {
386 self.cursor_shard >= self.shards.len()
387 }
388
389 fn skip_drained_shards(&mut self) {
393 while self.cursor_shard < self.shards.len()
394 && self.cursor_local_row >= self.shards[self.cursor_shard].n_rows
395 {
396 self.cursor_shard += 1;
397 self.cursor_local_row = 0;
398 }
399 }
400}
401
402impl CorpusRowSource for MmapShardSource {
403 fn total_rows(&self) -> u64 {
404 self.total_rows
405 }
406
407 fn width(&self) -> usize {
408 self.p
409 }
410
411 fn batch_rows(&self) -> usize {
412 self.batch_rows
413 }
414
415 fn reset(&mut self) {
416 self.cursor_shard = 0;
417 self.cursor_local_row = 0;
418 }
419
420 fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError> {
421 self.skip_drained_shards();
422 if self.at_end() {
423 return Ok(None);
424 }
425 let shard_idx = self.cursor_shard;
430 let take = {
431 let shard = &self.shards[shard_idx];
432 let remaining = shard.n_rows - self.cursor_local_row;
433 self.batch_rows.min(remaining)
434 };
435
436 {
439 let shard = &self.shards[shard_idx];
440 let first_byte =
441 shard.data_offset + self.cursor_local_row * shard.p * std::mem::size_of::<f32>();
442 let want = take * shard.p * std::mem::size_of::<f32>();
443 shard.prefetch(first_byte, want.min(self.prefetch_window_bytes));
444 }
445
446 let p = self.p;
447 let mut rows = Array2::<f64>::zeros((take, p));
448 let mut row_ids = Vec::with_capacity(take);
449 {
450 let shard = &self.shards[shard_idx];
451 for k in 0..take {
452 let local = self.cursor_local_row + k;
453 let mut row_view = rows.row_mut(k);
454 let slice = row_view
455 .as_slice_mut()
456 .expect("freshly allocated contiguous row");
457 shard.read_row_into(local, slice);
458 row_ids.push(shard.global_row_base + local as u64);
459 }
460 }
461 self.cursor_local_row += take;
462 self.skip_drained_shards();
463 Ok(Some(RowBatch { rows, row_ids }))
464 }
465}
466
467pub fn encode_shard_bytes(rows: ndarray::ArrayView2<'_, f64>) -> Vec<u8> {
475 let n_rows = rows.nrows();
476 let p = rows.ncols();
477 let mut out = Vec::with_capacity(HEADER_LEN + n_rows * p * std::mem::size_of::<f32>());
478 out.extend_from_slice(&SHARD_MAGIC);
479 out.extend_from_slice(&(n_rows as u64).to_le_bytes());
480 out.extend_from_slice(&(p as u64).to_le_bytes());
481 out.extend_from_slice(&DTYPE_F32.to_le_bytes());
482 out.extend_from_slice(&0u32.to_le_bytes());
483 for row in rows.outer_iter() {
484 for &v in row.iter() {
485 out.extend_from_slice(&(v as f32).to_le_bytes());
486 }
487 }
488 out
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use std::io::Write;

    /// Encodes `rows` as a shard and writes it to a per-process file in the system temp
    /// directory; `name` keeps files from different tests apart.
    fn write_temp_shard(name: &str, rows: ndarray::ArrayView2<'_, f64>) -> PathBuf {
        let file_name = format!(
            "gam-sae-corpus-test-{}-{}.shard",
            std::process::id(),
            name
        );
        let path = std::env::temp_dir().join(file_name);
        let bytes = encode_shard_bytes(rows);
        let mut file = File::create(&path).expect("create temp shard");
        file.write_all(&bytes).expect("write shard");
        file.sync_all().expect("sync shard");
        path
    }

    /// Pulls batches until the source is exhausted and returns every row id seen, in order.
    fn drain_row_ids(src: &mut MmapShardSource) -> Vec<u64> {
        let mut ids = Vec::new();
        while let Some(batch) = src.next_batch().expect("b") {
            ids.extend(batch.row_ids);
        }
        ids
    }

    #[test]
    fn single_shard_round_trips_rows_and_ids() {
        let data = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let path = write_temp_shard("single", data.view());
        let mut src = MmapShardSource::open(&[path.clone()]).expect("open");
        assert_eq!(src.total_rows(), 3);
        assert_eq!(src.width(), 3);

        let batch = src.next_batch().expect("batch").expect("some");
        assert_eq!(batch.row_ids, vec![0, 1, 2]);
        assert_eq!(batch.rows, data);
        assert!(src.next_batch().expect("end").is_none());

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn multi_shard_global_ids_are_contiguous() {
        let first = array![[1.0_f64], [2.0]];
        let second = array![[3.0_f64], [4.0], [5.0]];
        let pa = write_temp_shard("multi-a", first.view());
        let pb = write_temp_shard("multi-b", second.view());
        let mut src = MmapShardSource::open(&[pa.clone(), pb.clone()]).expect("open");
        assert_eq!(src.total_rows(), 5);

        let (mut all_ids, mut all_vals) = (Vec::new(), Vec::new());
        while let Some(batch) = src.next_batch().expect("batch") {
            all_ids.extend_from_slice(&batch.row_ids);
            all_vals.extend(batch.rows.outer_iter().map(|row| row[0]));
        }
        assert_eq!(all_ids, vec![0, 1, 2, 3, 4]);
        assert_eq!(all_vals, vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        std::fs::remove_file(&pa).ok();
        std::fs::remove_file(&pb).ok();
    }

    #[test]
    fn reset_replays_identical_sequence() {
        let data = array![[1.0_f64, 1.0], [2.0, 2.0]];
        let path = write_temp_shard("reset", data.view());
        let mut src = MmapShardSource::open(&[path.clone()]).expect("open");

        let first = drain_row_ids(&mut src);
        src.reset();
        let second = drain_row_ids(&mut src);
        assert_eq!(first, second);

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn bad_magic_is_rejected() {
        let path = std::env::temp_dir().join(format!(
            "gam-sae-corpus-badmagic-{}.shard",
            std::process::id()
        ));
        // 64 zero bytes: long enough for a header, but the magic cannot match.
        let mut file = File::create(&path).expect("create");
        file.write_all(&[0u8; 64]).expect("write");
        file.sync_all().ok();

        let outcome = MmapShardSource::open(&[path.clone()]);
        assert!(matches!(outcome, Err(ShardError::BadMagic { .. })));

        std::fs::remove_file(&path).ok();
    }
}