1use ndarray::Array2;
54use std::collections::VecDeque;
55use std::fs::File;
56use std::io::{Read, Seek, SeekFrom};
57use std::path::PathBuf;
58use std::sync::Arc;
59
60use super::shard_reader::{
61 CorpusRowSource, DEFAULT_BATCH_ROWS, DEFAULT_PREFETCH_WINDOW_BYTES, DTYPE_F32, HEADER_LEN,
62 RowBatch, SHARD_MAGIC, ShardError,
63};
64
/// Number of shards kept fetched beyond the shard the read cursor is in.
///
/// `ObjectStoreShardSource::fill_window` fetches the cursor's shard plus up to
/// this many following shards (clamped to the last shard), so at most
/// `1 + PREFETCH_SHARDS_AHEAD` shard payloads are resident at a time.
pub const PREFETCH_SHARDS_AHEAD: usize = 2;
72
/// Smallest corpus size, in rows, for which [`designed_sampling_mandatory`]
/// answers `true`.
pub const DESIGNED_SAMPLE_MANDATORY_MIN_ROWS: u64 = 100_000_000;

/// Reports whether a corpus of `total_rows` rows has reached
/// [`DESIGNED_SAMPLE_MANDATORY_MIN_ROWS`].
///
/// This is a pure threshold comparison: the bound itself is inclusive, and the
/// answer never flips back to `false` as `total_rows` grows.
#[inline]
pub fn designed_sampling_mandatory(total_rows: u64) -> bool {
    DESIGNED_SAMPLE_MANDATORY_MIN_ROWS <= total_rows
}
90
91pub trait ObjectStore: Send + Sync {
99 fn list_shards(&self) -> Result<Vec<String>, ShardError>;
102
103 fn fetch(&self, key: &str) -> Result<Vec<u8>, ShardError>;
105
106 fn fetch_range(&self, key: &str, offset: u64, len: usize) -> Result<Vec<u8>, ShardError> {
110 let full = self.fetch(key)?;
111 let start = (offset as usize).min(full.len());
112 let end = start.saturating_add(len).min(full.len());
113 Ok(full[start..end].to_vec())
114 }
115}
116
/// [`ObjectStore`] whose objects are the files of one local directory.
pub struct FsObjectStore {
    /// Directory that holds the shard files; keys are resolved against it.
    root: PathBuf,
}

impl FsObjectStore {
    /// Creates a store over the directory `root`.
    ///
    /// The path is only recorded here; nothing is read from disk until one of
    /// the [`ObjectStore`] methods is called.
    pub fn new(root: PathBuf) -> Self {
        FsObjectStore { root }
    }

    /// Resolves `key` to a filesystem path by appending it to the root.
    fn path_of(&self, key: &str) -> PathBuf {
        let mut path = self.root.clone();
        path.push(key);
        path
    }
}
132
133impl ObjectStore for FsObjectStore {
134 fn list_shards(&self) -> Result<Vec<String>, ShardError> {
135 let mut keys = Vec::new();
136 for entry in std::fs::read_dir(&self.root)? {
137 let entry = entry?;
138 let path = entry.path();
139 if path.extension().and_then(|e| e.to_str()) == Some("shard") {
140 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
141 keys.push(name.to_string());
142 }
143 }
144 }
145 Ok(keys)
146 }
147
148 fn fetch(&self, key: &str) -> Result<Vec<u8>, ShardError> {
149 let mut bytes = Vec::new();
150 File::open(self.path_of(key))?.read_to_end(&mut bytes)?;
151 Ok(bytes)
152 }
153
154 fn fetch_range(&self, key: &str, offset: u64, len: usize) -> Result<Vec<u8>, ShardError> {
155 let mut file = File::open(self.path_of(key))?;
156 let total = file.metadata()?.len();
157 let start = offset.min(total);
158 let take = (len as u64).min(total - start) as usize;
159 file.seek(SeekFrom::Start(start))?;
160 let mut buf = vec![0u8; take];
161 file.read_exact(&mut buf)?;
162 Ok(buf)
163 }
164}
165
/// Per-shard bookkeeping recorded by `ObjectStoreShardSource::open` from the
/// shard's header, without fetching its payload.
#[derive(Clone, Debug)]
struct ShardMeta {
    /// Object-store key of the shard, as returned by `ObjectStore::list_shards`.
    key: String,
    /// Row count read from the shard header.
    n_rows: usize,
    /// Global id of this shard's first row: the sum of `n_rows` over all
    /// shards that precede it in sorted key order.
    global_row_base: u64,
}
174
/// A shard whose payload has been fetched and is held in the prefetch window.
struct ResidentShard {
    /// Index of this shard in `ObjectStoreShardSource::shards`.
    shard_idx: usize,
    /// Bytes fetched from offset `HEADER_LEN` onward; `fetch_shard` rejects
    /// payloads shorter than `n_rows * p * size_of::<f32>()` bytes.
    payload: Vec<u8>,
}
183
184fn parse_header(key: &str, header: &[u8]) -> Result<(usize, usize), ShardError> {
187 let path = PathBuf::from(key);
188 if header.len() < HEADER_LEN {
189 return Err(ShardError::Truncated {
190 path,
191 expected: HEADER_LEN,
192 actual: header.len(),
193 });
194 }
195 if header[0..8] != SHARD_MAGIC {
196 return Err(ShardError::BadMagic { path });
197 }
198 let n_rows = u64::from_le_bytes(header[8..16].try_into().expect("8 bytes")) as usize;
199 let p = u64::from_le_bytes(header[16..24].try_into().expect("8 bytes")) as usize;
200 let dtype = u32::from_le_bytes(header[24..28].try_into().expect("4 bytes"));
201 if dtype != DTYPE_F32 {
202 return Err(ShardError::BadDtype { path, tag: dtype });
203 }
204 Ok((n_rows, p))
205}
206
/// Row source that reads shards sequentially from an [`ObjectStore`], keeping
/// a small window of fetched shard payloads in memory.
pub struct ObjectStoreShardSource {
    /// Store the shard headers and payloads are fetched from.
    store: Arc<dyn ObjectStore>,
    /// Metadata for every shard, in sorted key order.
    shards: Vec<ShardMeta>,
    /// Row width shared by all shards (`open` rejects a shard whose width differs).
    p: usize,
    /// Sum of the row counts of all shards.
    total_rows: u64,
    /// Maximum number of rows returned by one `next_batch` call.
    batch_rows: usize,
    /// Fetched shard payloads, oldest first; maintained by `fill_window`.
    window: VecDeque<ResidentShard>,
    /// Index of the shard the cursor is in; `shards.len()` once exhausted.
    cursor_shard: usize,
    /// Index, within the cursor's shard, of the next row to return.
    cursor_local_row: usize,
}
226
227impl ObjectStoreShardSource {
228 pub fn open(store: Arc<dyn ObjectStore>) -> Result<Self, ShardError> {
232 let mut keys = store.list_shards()?;
233 keys.sort();
236 if keys.is_empty() {
237 return Err(ShardError::Empty);
238 }
239 let mut shards = Vec::with_capacity(keys.len());
240 let mut p: Option<usize> = None;
241 let mut running_base: u64 = 0;
242 for key in keys {
243 let header = store.fetch_range(&key, 0, HEADER_LEN)?;
244 let (n_rows, shard_p) = parse_header(&key, &header)?;
245 match p {
246 None => p = Some(shard_p),
247 Some(expected) if expected != shard_p => {
248 return Err(ShardError::WidthMismatch {
249 expected,
250 found: shard_p,
251 path: PathBuf::from(&key),
252 });
253 }
254 Some(_) => {}
255 }
256 shards.push(ShardMeta {
257 key,
258 n_rows,
259 global_row_base: running_base,
260 });
261 running_base = running_base.saturating_add(n_rows as u64);
262 }
263 let p = p.ok_or(ShardError::Empty)?;
264 let total_rows = running_base;
265 if total_rows == 0 {
266 return Err(ShardError::Empty);
267 }
268 let row_bytes = p.max(1) * std::mem::size_of::<f32>();
272 let window_rows = (DEFAULT_PREFETCH_WINDOW_BYTES / row_bytes).max(1);
273 let batch_rows = DEFAULT_BATCH_ROWS
274 .min(total_rows as usize)
275 .min(window_rows)
276 .max(1);
277 Ok(Self {
278 store,
279 shards,
280 p,
281 total_rows,
282 batch_rows,
283 window: VecDeque::new(),
284 cursor_shard: 0,
285 cursor_local_row: 0,
286 })
287 }
288
289 #[inline]
291 fn at_end(&self) -> bool {
292 self.cursor_shard >= self.shards.len()
293 }
294
295 fn fetch_shard(&self, shard_idx: usize) -> Result<ResidentShard, ShardError> {
297 let meta = &self.shards[shard_idx];
298 let payload_len = meta
299 .n_rows
300 .checked_mul(self.p)
301 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f32>()))
302 .ok_or_else(|| ShardError::Truncated {
303 path: PathBuf::from(&meta.key),
304 expected: usize::MAX,
305 actual: 0,
306 })?;
307 let payload = self
308 .store
309 .fetch_range(&meta.key, HEADER_LEN as u64, payload_len)?;
310 if payload.len() < payload_len {
311 return Err(ShardError::Truncated {
312 path: PathBuf::from(&meta.key),
313 expected: HEADER_LEN + payload_len,
314 actual: HEADER_LEN + payload.len(),
315 });
316 }
317 Ok(ResidentShard { shard_idx, payload })
318 }
319
320 fn fill_window(&mut self) -> Result<(), ShardError> {
324 while self.cursor_shard < self.shards.len()
326 && self.cursor_local_row >= self.shards[self.cursor_shard].n_rows
327 {
328 self.cursor_shard += 1;
329 self.cursor_local_row = 0;
330 if let Some(front) = self.window.front() {
331 if front.shard_idx < self.cursor_shard {
332 self.window.pop_front();
333 }
334 }
335 }
336 if self.at_end() {
337 self.window.clear();
338 return Ok(());
339 }
340 while let Some(front) = self.window.front() {
342 if front.shard_idx < self.cursor_shard {
343 self.window.pop_front();
344 } else {
345 break;
346 }
347 }
348 let want_last = (self.cursor_shard + PREFETCH_SHARDS_AHEAD).min(self.shards.len() - 1);
350 let mut next_fetch = match self.window.back() {
351 Some(back) => back.shard_idx + 1,
352 None => self.cursor_shard,
353 };
354 while next_fetch <= want_last {
355 let resident = self.fetch_shard(next_fetch)?;
356 self.window.push_back(resident);
357 next_fetch += 1;
358 }
359 Ok(())
360 }
361}
362
363impl CorpusRowSource for ObjectStoreShardSource {
364 fn total_rows(&self) -> u64 {
365 self.total_rows
366 }
367
368 fn width(&self) -> usize {
369 self.p
370 }
371
372 fn batch_rows(&self) -> usize {
373 self.batch_rows
374 }
375
376 fn reset(&mut self) {
377 self.cursor_shard = 0;
378 self.cursor_local_row = 0;
379 self.window.clear();
380 }
381
382 fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError> {
383 self.fill_window()?;
384 if self.at_end() {
385 return Ok(None);
386 }
387 let meta = &self.shards[self.cursor_shard];
388 let front = self
389 .window
390 .front()
391 .expect("fill_window leaves the current shard resident");
392 if front.shard_idx != self.cursor_shard {
398 return Err(ShardError::ResidencyInvariant {
399 cursor_shard: self.cursor_shard,
400 front_shard: front.shard_idx,
401 });
402 }
403
404 let remaining = meta.n_rows - self.cursor_local_row;
407 let take = self.batch_rows.min(remaining);
408 let row_bytes = self.p * std::mem::size_of::<f32>();
409 let mut rows = Array2::<f64>::zeros((take, self.p));
410 let mut row_ids = Vec::with_capacity(take);
411 for k in 0..take {
412 let local = self.cursor_local_row + k;
413 let start = local * row_bytes;
414 let bytes = &front.payload[start..start + row_bytes];
415 let mut row_view = rows.row_mut(k);
416 let slice = row_view
417 .as_slice_mut()
418 .expect("freshly allocated contiguous row");
419 for (c, slot) in slice.iter_mut().enumerate() {
420 let b = c * std::mem::size_of::<f32>();
421 let lane = f32::from_le_bytes(bytes[b..b + 4].try_into().expect("4 bytes"));
422 *slot = f64::from(lane);
423 }
424 row_ids.push(meta.global_row_base + local as u64);
425 }
426 self.cursor_local_row += take;
427 Ok(Some(RowBatch { rows, row_ids }))
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::super::shard_reader::{MmapShardSource, encode_shard_bytes};
    use super::*;
    use ndarray::array;
    use std::io::Write;
    use std::sync::Mutex;

    /// Creates a per-process, per-test directory under the system temp dir and
    /// writes each `(key, rows)` pair into it as an encoded shard file.
    fn temp_store(name: &str, shards: &[(&str, Array2<f64>)]) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "gam-sae-objstore-test-{}-{}",
            std::process::id(),
            name
        ));
        std::fs::create_dir_all(&dir).expect("create store dir");
        for (key, rows) in shards {
            let encoded = encode_shard_bytes(rows.view());
            let mut file = File::create(dir.join(key)).expect("create shard");
            file.write_all(&encoded).expect("write shard");
            file.sync_all().expect("sync shard");
        }
        dir
    }

    /// Reads `src` to exhaustion, returning all row ids and all values in the
    /// order the batches produced them.
    fn drain(src: &mut dyn CorpusRowSource) -> (Vec<u64>, Vec<f64>) {
        let (mut ids, mut vals) = (Vec::new(), Vec::new());
        loop {
            match src.next_batch().expect("batch") {
                Some(batch) => {
                    ids.extend(batch.row_ids.iter().copied());
                    vals.extend(batch.rows.iter().copied());
                }
                None => break (ids, vals),
            }
        }
    }

    /// The object-store source must yield the same ids and bit-identical
    /// values as the mmap source over the same directory, and replay them
    /// after `reset`.
    #[test]
    fn object_store_replays_the_mmap_row_sequence_exactly() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let b = array![[7.0_f64, 8.0], [9.0, 10.0]];
        let dir = temp_store("parity", &[("a.shard", a), ("b.shard", b)]);

        let store = Arc::new(FsObjectStore::new(dir.clone()));
        let mut remote = ObjectStoreShardSource::open(store).expect("open object-store source");
        let mut local = MmapShardSource::open_dir(&dir).expect("open mmap source");

        assert_eq!(remote.total_rows(), local.total_rows());
        assert_eq!(remote.width(), local.width());

        let bits = |vals: &[f64]| vals.iter().map(|v| v.to_bits()).collect::<Vec<_>>();
        let (ids_r, vals_r) = drain(&mut remote);
        let (ids_l, vals_l) = drain(&mut local);
        assert_eq!(ids_r, ids_l);
        assert_eq!(
            bits(&vals_r),
            bits(&vals_l),
            "object-store rows must be bit-identical to mmap rows"
        );

        remote.reset();
        let (ids_again, vals_again) = drain(&mut remote);
        assert_eq!(ids_again, ids_r);
        assert_eq!(vals_again, vals_r);
        std::fs::remove_dir_all(&dir).ok();
    }

    /// Filesystem store wrapper that records the key of every ranged fetch
    /// starting at or beyond the header, i.e. every payload fetch.
    struct CountingStore {
        inner: FsObjectStore,
        payload_fetches: Mutex<Vec<String>>,
    }

    impl ObjectStore for CountingStore {
        fn list_shards(&self) -> Result<Vec<String>, ShardError> {
            self.inner.list_shards()
        }

        fn fetch(&self, key: &str) -> Result<Vec<u8>, ShardError> {
            self.inner.fetch(key)
        }

        fn fetch_range(&self, key: &str, offset: u64, len: usize) -> Result<Vec<u8>, ShardError> {
            let is_payload = offset as usize >= HEADER_LEN;
            if is_payload {
                self.payload_fetches.lock().unwrap().push(key.to_string());
            }
            self.inner.fetch_range(key, offset, len)
        }
    }

    /// The first batch may fetch at most the window's worth of payloads, in
    /// key order, and a full pass fetches each payload exactly once.
    #[test]
    fn prefetch_is_bounded_and_in_key_order() {
        let mk = |v: f64| array![[v]];
        let dir = temp_store(
            "bounded",
            &[
                ("s0.shard", mk(0.0)),
                ("s1.shard", mk(1.0)),
                ("s2.shard", mk(2.0)),
                ("s3.shard", mk(3.0)),
                ("s4.shard", mk(4.0)),
                ("s5.shard", mk(5.0)),
            ],
        );
        let store = Arc::new(CountingStore {
            inner: FsObjectStore::new(dir.clone()),
            payload_fetches: Mutex::new(Vec::new()),
        });
        let mut src =
            ObjectStoreShardSource::open(Arc::clone(&store) as Arc<dyn ObjectStore>).expect("open");

        let first = src.next_batch().expect("batch").expect("some");
        assert_eq!(first.row_ids, vec![0]);
        {
            // Scoped so the lock is released before the source fetches again.
            let fetched = store.payload_fetches.lock().unwrap();
            let window_cap = 1 + PREFETCH_SHARDS_AHEAD;
            assert!(
                fetched.len() <= window_cap,
                "first batch fetched {} shard payloads; window allows {}",
                fetched.len(),
                window_cap
            );
            let mut sorted = fetched.clone();
            sorted.sort();
            assert_eq!(*fetched, sorted, "payload fetches must be in key order");
        }

        let (ids, _) = drain(&mut src);
        assert_eq!(ids, vec![1, 2, 3, 4, 5]);
        let fetched = store.payload_fetches.lock().unwrap();
        assert_eq!(fetched.len(), 6, "each shard payload fetched exactly once");
        std::fs::remove_dir_all(&dir).ok();
    }

    /// The threshold is inclusive and holds for every larger row count.
    #[test]
    fn mandatory_selectivity_threshold_is_pure_and_monotone() {
        let threshold = DESIGNED_SAMPLE_MANDATORY_MIN_ROWS;
        assert!(!designed_sampling_mandatory(0));
        assert!(!designed_sampling_mandatory(threshold - 1));
        assert!(designed_sampling_mandatory(threshold));
        assert!(designed_sampling_mandatory(u64::MAX));
    }

    /// Shards of different widths in one store must fail `open`.
    #[test]
    fn width_mismatch_is_rejected() {
        let wide = array![[1.0_f64, 2.0]];
        let narrow = array![[3.0_f64]];
        let dir = temp_store("width", &[("a.shard", wide), ("b.shard", narrow)]);
        let store = Arc::new(FsObjectStore::new(dir.clone()));
        let err = ObjectStoreShardSource::open(store);
        assert!(matches!(err, Err(ShardError::WidthMismatch { .. })));
        std::fs::remove_dir_all(&dir).ok();
    }
}