use crate::codec_file::header::{decode_header, CorpusFileHeader};
use crate::errors::IoError;
use crate::zero_copy::view::CompressedVectorView;
use memmap2::Mmap;
use std::fs::File;
use std::path::Path;
/// A corpus file opened via a memory map, together with its decoded header.
///
/// Record views handed out by the iterator borrow directly from the mapping,
/// so they cannot outlive the reader.
pub struct CorpusFileReader {
    // The entire file contents, mapped with `Mmap::map` when the reader is opened.
    mmap: Mmap,
    // Header produced by `decode_header` from the mapped bytes at open time.
    header: CorpusFileHeader,
}
impl std::fmt::Debug for CorpusFileReader {
    /// Summarises the reader: the mapping's length plus the main header
    /// fields. The mapped bytes themselves are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let header = &self.header;
        let mut out = f.debug_struct("CorpusFileReader");
        out.field("mmap_len", &self.mmap.len());
        out.field("vector_count", &header.vector_count);
        out.field("dimension", &header.dimension);
        out.field("bit_width", &header.bit_width);
        out.field("body_offset", &header.body_offset);
        out.finish()
    }
}
impl CorpusFileReader {
    /// Opens `path`, memory-maps the whole file and decodes its header.
    ///
    /// The mapping is owned by the returned reader. Record views produced by
    /// [`CorpusFileReader::iter`] borrow straight from it.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening the file, creating the mapping, or
    /// [`decode_header`], each converted into [`IoError`].
    pub fn open(path: &Path) -> Result<Self, IoError> {
        let handle = File::open(path)?;
        // SAFETY: `Mmap::map` is unsafe because the mapping aliases the file on
        // disk. If the file is truncated or rewritten while the mapping is
        // alive, reads through it are undefined behaviour. Nothing in this type
        // prevents that. Soundness relies on the file not being modified while
        // this reader (or any view borrowed from it) exists.
        let mapping = unsafe { Mmap::map(&handle)? };
        let decoded = decode_header(&mapping)?;
        Ok(Self {
            header: decoded,
            mmap: mapping,
        })
    }

    /// Borrows the header that was decoded when the file was opened.
    pub const fn header(&self) -> &CorpusFileHeader {
        &self.header
    }

    /// Returns an iterator over the records stored from `header.body_offset`
    /// onwards, yielding at most `header.vector_count` items.
    ///
    /// If the offset lies past the end of the mapping, the body is treated as
    /// empty. A non-zero vector count then surfaces as a truncation error on
    /// the first call to `next`.
    pub fn iter(&self) -> CorpusFileIter<'_> {
        let body: &[u8] = match self.mmap.get(self.header.body_offset..) {
            Some(bytes) => bytes,
            None => &[],
        };
        CorpusFileIter {
            remaining: body,
            count: self.header.vector_count,
            errored: false,
        }
    }
}
impl<'a> IntoIterator for &'a CorpusFileReader {
type Item = Result<CompressedVectorView<'a>, IoError>;
type IntoIter = CorpusFileIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator over the length-prefixed records in a corpus file body.
///
/// It yields at most `count` items. After the first `Err` it yields nothing
/// further.
pub struct CorpusFileIter<'a> {
    // Unread part of the body; the next record's 4-byte length prefix starts here.
    remaining: &'a [u8],
    // Records still expected; decremented after each successfully parsed record.
    count: u64,
    // Set once an error has been returned. Every later `next` call returns `None`.
    errored: bool,
}
/// Splits the 4-byte little-endian record-length prefix off the front of `data`.
///
/// Returns the decoded length together with the bytes that follow the prefix.
/// The length is not checked against the remaining data here. That is the
/// caller's job.
///
/// # Errors
///
/// Returns [`IoError::Truncated`] with `needed: 4` when `data` holds fewer than
/// four bytes.
fn read_record_len(data: &[u8]) -> Result<(usize, &[u8]), IoError> {
    const PREFIX_LEN: usize = 4;
    // A single length check makes both the split and the array copy below
    // infallible. No further error paths are needed.
    if data.len() < PREFIX_LEN {
        return Err(IoError::Truncated {
            needed: PREFIX_LEN,
            got: data.len(),
        });
    }
    let (prefix, tail) = data.split_at(PREFIX_LEN);
    let mut len_bytes = [0u8; PREFIX_LEN];
    len_bytes.copy_from_slice(prefix);
    // u32 -> usize is lossless wherever usize is at least 32 bits wide.
    Ok((u32::from_le_bytes(len_bytes) as usize, tail))
}
impl<'a> CorpusFileIter<'a> {
    /// Decodes the record at the front of `remaining` and advances past it.
    ///
    /// Each record is a 4-byte little-endian length prefix followed by that
    /// many payload bytes, which are handed to `CompressedVectorView::parse`.
    /// `remaining` and `count` change only when the whole record decodes.
    /// Callers must ensure `count > 0` before calling.
    ///
    /// # Errors
    ///
    /// Returns [`IoError::Truncated`] when the length prefix or the payload
    /// runs past the end of the body, or the error reported by `parse`.
    fn read_next_record(&mut self) -> Result<CompressedVectorView<'a>, IoError> {
        let (record_len, after_len) = read_record_len(self.remaining)?;
        if after_len.len() < record_len {
            return Err(IoError::Truncated {
                needed: record_len,
                got: after_len.len(),
            });
        }
        let (payload, rest) = after_len.split_at(record_len);
        // NOTE(review): any bytes `parse` leaves unconsumed inside the record
        // are ignored. Confirm whether trailing record bytes are legal.
        let (view, _tail) = CompressedVectorView::parse(payload)?;
        self.remaining = rest;
        self.count -= 1;
        Ok(view)
    }
}

impl<'a> Iterator for CorpusFileIter<'a> {
    type Item = Result<CompressedVectorView<'a>, IoError>;

    /// Yields the next record view, or the first decoding error.
    ///
    /// After an error, or once `count` records have been read, it returns
    /// `None` forever.
    fn next(&mut self) -> Option<Self::Item> {
        if self.errored || self.count == 0 {
            return None;
        }
        let item = self.read_next_record();
        // Latch the first failure: state past a bad record cannot be trusted.
        self.errored = item.is_err();
        Some(item)
    }

    /// While live, at least one more item (`Ok` or `Err`) is guaranteed, and
    /// at most `count` items remain.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.errored || self.count == 0 {
            (0, Some(0))
        } else {
            (1, usize::try_from(self.count).ok())
        }
    }
}

// `next` returns `None` permanently once `errored` is set or `count` hits 0.
impl std::iter::FusedIterator for CorpusFileIter<'_> {}