use crate::codec_file::header::{encode_header, FIXED_HEADER_SIZE, MAGIC_FINAL};
use crate::compressed_vector::to_bytes;
use crate::errors::IoError;
use std::fs::{File, OpenOptions};
use std::io::{Seek, SeekFrom, Write};
use std::path::Path;
use tinyquant_core::codec::CompressedVector;
/// Append-only writer for a codec file.
///
/// A writer is obtained from [`CodecFileWriter::create`], fed records with
/// [`CodecFileWriter::append`], and must be completed with
/// [`CodecFileWriter::finalize`], which patches the vector count and the
/// final magic bytes into the header.
#[derive(Debug)]
pub struct CodecFileWriter {
    /// Open handle to the file being written.
    file: File,
    /// Number of records appended so far; written into the header on finalize.
    vector_count: u64,
    /// Configuration hash that was encoded into the header at creation time.
    config_hash: String,
    /// Vector dimension that was encoded into the header at creation time.
    dimension: u32,
    /// Bit width that was encoded into the header at creation time.
    bit_width: u8,
    /// Residual flag that was encoded into the header at creation time.
    residual: bool,
}
impl CodecFileWriter {
pub fn create(
path: &Path,
config_hash: &str,
dimension: u32,
bit_width: u8,
residual: bool,
metadata: &[u8],
) -> Result<Self, IoError> {
let header = encode_header(config_hash, dimension, bit_width, residual, metadata, 0)?;
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(path)?;
file.write_all(&header)?;
file.sync_data()?;
Ok(Self {
file,
vector_count: 0,
config_hash: config_hash.to_owned(),
dimension,
bit_width,
residual,
})
}
pub fn append(&mut self, cv: &CompressedVector) -> Result<(), IoError> {
let payload = to_bytes(cv);
#[allow(clippy::cast_possible_truncation)]
let record_len = payload.len() as u32;
self.file.write_all(&record_len.to_le_bytes())?;
self.file.write_all(&payload)?;
self.vector_count += 1;
Ok(())
}
pub fn finalize(mut self) -> Result<(), IoError> {
self.file.sync_data()?;
self.file.seek(SeekFrom::Start(8))?;
self.file.write_all(&self.vector_count.to_le_bytes())?;
self.file.sync_data()?;
self.file.seek(SeekFrom::Start(0))?;
self.file.write_all(MAGIC_FINAL)?;
self.file.sync_data()?;
Ok(())
}
pub const fn vector_count(&self) -> u64 {
self.vector_count
}
pub fn config_hash(&self) -> &str {
&self.config_hash
}
pub const fn dimension(&self) -> u32 {
self.dimension
}
pub const fn bit_width(&self) -> u8 {
self.bit_width
}
pub const fn residual(&self) -> bool {
self.residual
}
#[allow(clippy::missing_const_for_fn)]
pub fn body_offset(config_hash: &str, metadata_len: usize) -> Result<usize, IoError> {
let hash_len = config_hash.as_bytes().len();
if hash_len > 256 {
return Err(IoError::InvalidHeader);
}
let header_end = FIXED_HEADER_SIZE + hash_len + 4 + metadata_len;
Ok(((header_end + 7) / 8) * 8)
}
}