use crate::codec_file::header::{
decode_header, CorpusFileHeader, MAX_CONFIG_HASH_LEN, MAX_METADATA_LEN,
};
use crate::compressed_vector::from_bytes;
use crate::compressed_vector::header::HEADER_SIZE;
use crate::errors::IoError;
use std::io::{Read, Seek, SeekFrom};
use tinyquant_core::codec::CompressedVector;
/// Absolute ceiling (4 MiB) on the length prefix of any single record, applied on top of
/// whatever bound the file header implies.
const MAX_RECORD_LEN: usize = 4 << 20;
/// Sequential reader for a corpus codec file: a decoded [`CorpusFileHeader`] followed by
/// length-prefixed compressed-vector records.
pub struct CodecFileReader<R>
where
    R: Read + Seek,
{
    /// Underlying byte stream.
    inner: R,
    /// Header decoded when the reader was constructed.
    header: CorpusFileHeader,
    /// Count of records successfully returned so far.
    records_read: u64,
}
impl<R: Read + Seek> CodecFileReader<R> {
    /// Creates a reader over `inner`, decoding the corpus file header that begins at
    /// `inner`'s current stream position and then seeking to the first record.
    ///
    /// The header's `body_offset` is resolved relative to the position `inner` was at
    /// on entry, so a codec file embedded inside a larger stream can be read by first
    /// seeking `inner` to the start of that file.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `inner` and errors from header decoding. Returns
    /// [`IoError::InvalidHeader`] if the body start cannot be represented as a `u64`
    /// stream position.
    pub fn new(mut inner: R) -> Result<Self, IoError> {
        // The header is read from the current position, so its offsets must be
        // resolved against that position rather than against byte 0 of `inner`.
        let file_start = inner.stream_position()?;
        let header = read_and_decode_header(&mut inner)?;
        let body_start = u64::try_from(header.body_offset)
            .ok()
            .and_then(|offset| file_start.checked_add(offset))
            .ok_or(IoError::InvalidHeader)?;
        inner.seek(SeekFrom::Start(body_start))?;
        Ok(Self {
            inner,
            header,
            records_read: 0,
        })
    }

    /// Returns the header decoded when this reader was constructed.
    pub const fn header(&self) -> &CorpusFileHeader {
        &self.header
    }

    /// Reads and decodes the next record.
    ///
    /// Returns `Ok(None)` once `header().vector_count` records have been returned.
    ///
    /// # Errors
    ///
    /// Propagates read and decode errors for the record. A failed call does not
    /// advance [`records_read`](Self::records_read), even if bytes were already
    /// consumed from the stream.
    pub fn next_vector(&mut self) -> Result<Option<CompressedVector>, IoError> {
        if self.records_read >= self.header.vector_count {
            return Ok(None);
        }
        let max_record_len = max_record_len_for_header(&self.header);
        let cv = read_record(&mut self.inner, max_record_len)?;
        self.records_read += 1;
        Ok(Some(cv))
    }

    /// Number of records successfully returned by [`next_vector`](Self::next_vector).
    pub const fn records_read(&self) -> u64 {
        self.records_read
    }
}
/// Reads the corpus file header starting at the current position of `r` and decodes it
/// with [`decode_header`].
///
/// The header is consumed in this order:
/// 1. a 24-byte fixed prefix whose final two bytes are the little-endian config-hash
///    length;
/// 2. that many config-hash bytes;
/// 3. a 4-byte little-endian metadata length;
/// 4. the metadata plus padding, up to the next 8-byte boundary measured from the
///    start of the header.
///
/// On success `r` has advanced exactly to that 8-byte-aligned offset.
///
/// # Errors
///
/// - [`IoError::InvalidHeader`] if the config-hash length exceeds `MAX_CONFIG_HASH_LEN`,
///   the metadata length exceeds `MAX_METADATA_LEN`, or the size arithmetic overflows.
/// - Any read error from `r`, and any error returned by [`decode_header`].
fn read_and_decode_header<R: Read + Seek>(r: &mut R) -> Result<CorpusFileHeader, IoError> {
    const FIXED_PREFIX_LEN: usize = 24;
    const LENGTH_FIELD_LEN: usize = 4;
    const BODY_ALIGNMENT: usize = 8;

    let mut fixed = [0u8; FIXED_PREFIX_LEN];
    r.read_exact(&mut fixed)?;
    // The config-hash length occupies the last two bytes of the fixed prefix.
    let [.., hash_len_lo, hash_len_hi] = fixed;
    let config_hash_len = usize::from(u16::from_le_bytes([hash_len_lo, hash_len_hi]));
    if config_hash_len > MAX_CONFIG_HASH_LEN {
        return Err(IoError::InvalidHeader);
    }

    let mut config_hash = vec![0u8; config_hash_len];
    r.read_exact(&mut config_hash)?;
    let mut metadata_len_bytes = [0u8; LENGTH_FIELD_LEN];
    r.read_exact(&mut metadata_len_bytes)?;
    let metadata_len = u32::from_le_bytes(metadata_len_bytes) as usize;
    if metadata_len > MAX_METADATA_LEN {
        return Err(IoError::InvalidHeader);
    }

    // Bytes consumed so far; the body begins after the metadata, rounded up to the
    // alignment boundary.
    let consumed = FIXED_PREFIX_LEN
        .checked_add(config_hash_len)
        .and_then(|n| n.checked_add(LENGTH_FIELD_LEN))
        .ok_or(IoError::InvalidHeader)?;
    let body_offset = consumed
        .checked_add(metadata_len)
        .ok_or(IoError::InvalidHeader)?
        .next_multiple_of(BODY_ALIGNMENT);
    let tail_len = body_offset
        .checked_sub(consumed)
        .ok_or(IoError::InvalidHeader)?;

    // Metadata plus alignment padding.
    let mut tail = vec![0u8; tail_len];
    r.read_exact(&mut tail)?;

    let header_bytes = [
        fixed.as_slice(),
        config_hash.as_slice(),
        metadata_len_bytes.as_slice(),
        tail.as_slice(),
    ]
    .concat();
    decode_header(&header_bytes)
}
/// Computes the largest record length prefix accepted for the layout described by
/// `header`, capped at [`MAX_RECORD_LEN`].
///
/// The bound is the sum of:
/// - [`HEADER_SIZE`] from `compressed_vector::header`;
/// - the bit-packed codes, `dimension * bit_width` bits rounded up to whole bytes;
/// - one extra byte (NOTE(review): presumably a flag byte — confirm against the
///   record encoder);
/// - when `header.residual` is set, `4 + 2 * dimension` bytes.
///
/// Every step saturates, so extreme header values clamp to the cap instead of
/// overflowing.
fn max_record_len_for_header(header: &CorpusFileHeader) -> usize {
    let dimension = header.dimension as usize;
    let bit_width = header.bit_width as usize;
    let packed_bytes = dimension.saturating_mul(bit_width).saturating_add(7) / 8;
    let residual_bytes = if header.residual {
        dimension.saturating_mul(2).saturating_add(4)
    } else {
        0
    };
    let uncapped = [HEADER_SIZE, packed_bytes, 1, residual_bytes]
        .iter()
        .copied()
        .fold(0_usize, usize::saturating_add);
    uncapped.min(MAX_RECORD_LEN)
}
/// Reads one length-prefixed record from `r` and decodes it with [`from_bytes`].
///
/// A record is a little-endian `u32` byte count followed by that many payload bytes.
/// The count is validated before the payload buffer is allocated.
///
/// # Errors
///
/// Returns [`IoError::InvalidHeader`] if the declared length exceeds `max_record_len`.
/// Propagates read errors from `r` and any error returned by [`from_bytes`].
fn read_record<R: Read>(r: &mut R, max_record_len: usize) -> Result<CompressedVector, IoError> {
    let mut length_prefix = [0u8; 4];
    r.read_exact(&mut length_prefix)?;
    let declared_len = u32::from_le_bytes(length_prefix) as usize;
    if declared_len <= max_record_len {
        let mut payload = vec![0u8; declared_len];
        r.read_exact(&mut payload)?;
        from_bytes(&payload)
    } else {
        Err(IoError::InvalidHeader)
    }
}