use crate::codec_file::header::{decode_header, CorpusFileHeader};
use crate::compressed_vector::from_bytes;
use crate::errors::IoError;
use std::io::{Read, Seek, SeekFrom};
use tinyquant_core::codec::CompressedVector;
/// Sequential reader over a codec file: a decoded header followed by a run of
/// length-prefixed compressed-vector records.
///
/// The reader tracks how many records it has handed out so that iteration can
/// stop once the count declared in the header has been reached.
pub struct CodecFileReader<R>
where
    R: Read + Seek,
{
    /// Underlying byte source the header and records are read from.
    inner: R,
    /// Header decoded when the reader was constructed.
    header: CorpusFileHeader,
    /// Number of records successfully returned so far.
    records_read: u64,
}
impl<R> CodecFileReader<R>
where
    R: Read + Seek,
{
    /// Builds a reader by decoding the header from `inner` and then seeking
    /// `inner` to the absolute position given by the header's `body_offset`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while reading or decoding the header,
    /// [`IoError::InvalidHeader`] if `body_offset` cannot be represented as a
    /// `u64`, and any I/O error raised by the seek.
    pub fn new(mut inner: R) -> Result<Self, IoError> {
        let header = read_and_decode_header(&mut inner)?;
        let body_start =
            u64::try_from(header.body_offset).map_err(|_| IoError::InvalidHeader)?;
        inner.seek(SeekFrom::Start(body_start))?;

        Ok(Self {
            inner,
            header,
            records_read: 0,
        })
    }

    /// Returns the header that was decoded when the reader was created.
    pub const fn header(&self) -> &CorpusFileHeader {
        &self.header
    }

    /// Reads the next record from the stream.
    ///
    /// Returns `Ok(None)` once `records_read` has reached the header's
    /// `vector_count`; otherwise reads one record, bumps the counter and
    /// returns the vector.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading or parsing the record. The record
    /// counter is left untouched when an error occurs.
    pub fn next_vector(&mut self) -> Result<Option<CompressedVector>, IoError> {
        if self.records_read < self.header.vector_count {
            let record = read_record(&mut self.inner)?;
            self.records_read += 1;
            Ok(Some(record))
        } else {
            Ok(None)
        }
    }

    /// Returns how many records have been successfully read so far.
    pub const fn records_read(&self) -> u64 {
        self.records_read
    }
}
/// Reads the complete, padded header from `r` and decodes it.
///
/// The header length is only known after parts of it have been parsed, so it
/// is read in stages, all accumulated into a single buffer:
///
/// 1. a 24-byte fixed prefix whose last two bytes hold a little-endian `u16`
///    length (`config_hash_len`),
/// 2. `config_hash_len` bytes followed by a little-endian `u32` metadata
///    length,
/// 3. the metadata bytes plus padding up to the next multiple of 8 bytes
///    counted from the start of the header.
///
/// The accumulated bytes are then handed to [`decode_header`].
///
/// The variable-length parts are read through a length-limited reader so that
/// memory grows only with bytes actually present in the stream; a corrupt
/// length field cannot force a multi-gigabyte allocation up front.
///
/// # Errors
///
/// * An I/O error (kind `UnexpectedEof` when the stream ends early) if the
///   stream is shorter than the header claims.
/// * [`IoError::InvalidHeader`] if the declared sizes do not fit in `usize`.
/// * Whatever [`decode_header`] returns for the assembled bytes.
fn read_and_decode_header<R: Read + Seek>(r: &mut R) -> Result<CorpusFileHeader, IoError> {
    /// Byte length of the fixed-size header prefix.
    const FIXED_LEN: usize = 24;
    /// Byte length of the little-endian metadata-length field.
    const METADATA_LEN_FIELD: usize = 4;
    /// The padded header ends on a multiple of this many bytes.
    const BODY_ALIGNMENT: usize = 8;

    /// Appends exactly `len` bytes from `src` to `dst`, failing with
    /// `UnexpectedEof` if the stream ends first. `dst` grows only as data
    /// arrives, never by `len` up front.
    fn append_exact<S: Read>(src: &mut S, dst: &mut Vec<u8>, len: usize) -> std::io::Result<()> {
        let limit = u64::try_from(len)
            .map_err(|_| std::io::Error::from(std::io::ErrorKind::InvalidInput))?;
        let got = src.by_ref().take(limit).read_to_end(dst)?;
        if got == len {
            Ok(())
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "failed to fill whole buffer",
            ))
        }
    }

    let mut fixed = [0u8; FIXED_LEN];
    r.read_exact(&mut fixed)?;
    // The length of the next section lives in the final two bytes of the
    // fixed prefix; an array pattern extracts them without a fallible slice.
    let [.., len_lo, len_hi] = fixed;
    let config_hash_len = usize::from(u16::from_le_bytes([len_lo, len_hi]));

    let mut full = fixed.to_vec();
    append_exact(r, &mut full, config_hash_len)?;

    let mut ml_bytes = [0u8; METADATA_LEN_FIELD];
    r.read_exact(&mut ml_bytes)?;
    full.extend_from_slice(&ml_bytes);
    let metadata_len =
        usize::try_from(u32::from_le_bytes(ml_bytes)).map_err(|_| IoError::InvalidHeader)?;

    // `full.len()` is now 24 + config_hash_len + 4. Checked arithmetic keeps
    // hostile lengths from wrapping on targets with a narrow `usize`.
    let header_end = full
        .len()
        .checked_add(metadata_len)
        .ok_or(IoError::InvalidHeader)?;
    let body_offset = header_end
        .checked_add(BODY_ALIGNMENT - 1)
        .ok_or(IoError::InvalidHeader)?
        / BODY_ALIGNMENT
        * BODY_ALIGNMENT;

    // Cannot underflow: body_offset >= header_end >= full.len().
    let remaining_header = body_offset - full.len();
    append_exact(r, &mut full, remaining_header)?;

    decode_header(&full)
}
fn read_record<R: Read>(r: &mut R) -> Result<CompressedVector, IoError> {
let mut len_buf = [0u8; 4];
r.read_exact(&mut len_buf)?;
let record_len = u32::from_le_bytes(len_buf) as usize;
let mut payload = vec![0u8; record_len];
r.read_exact(&mut payload)?;
from_bytes(&payload)
}