use std::io::{Read, Write};
use iqdb_types::DistanceMetric;
use crate::error::{PersistError, Result};
/// Eight-byte signature that `read_header` requires at the very start of a
/// header; any other value is rejected with `PersistError::BadMagic`.
pub const MAGIC: [u8; 8] = *b"IQDBPRST";
/// Highest header version accepted by `read_header`, and the value reported as
/// `supported` when a version is rejected.
pub const CURRENT_VERSION: u32 = 2;
/// Lowest header version accepted by `read_header`.
pub(crate) const MIN_SUPPORTED_VERSION: u32 = 1;
/// In-memory form of the file header handled by `write_header` and
/// `read_header`.
///
/// Fields are serialized in declaration order. All integers are little-endian;
/// `dim`, `n_vectors` and the `index_type` length prefix are stored as `u64`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileHeader {
/// File signature. `read_header` only returns headers whose magic equals
/// `MAGIC`; `write_header` emits whatever is stored here without checking it.
pub magic: [u8; 8],
/// Format version. `read_header` accepts values in
/// `MIN_SUPPORTED_VERSION..=CURRENT_VERSION`.
pub version: u32,
/// Index type name, stored on disk as a `u64` byte length followed by UTF-8
/// bytes. `read_header` rejects lengths above 4096 bytes.
pub index_type: String,
/// Value stored in the header's `dim` slot (presumably the vector
/// dimensionality; confirm against callers).
pub dim: usize,
/// Distance metric, stored on disk as a one-byte tag (see `metric_to_tag`).
pub metric: DistanceMetric,
/// Value stored in the header's `n_vectors` slot (presumably the number of
/// stored vectors; confirm against callers).
pub n_vectors: usize,
/// CRC-32 value carried by the header. This module only reads and writes the
/// field; it neither computes nor verifies it.
pub crc32: u32,
}
/// Maps a [`DistanceMetric`] to the one-byte tag stored in the header.
///
/// # Errors
///
/// Returns `PersistError::UnsupportedMetric` for any metric that has no tag
/// assigned here.
pub(crate) fn metric_to_tag(metric: DistanceMetric) -> Result<u8> {
    match metric {
        DistanceMetric::Cosine => Ok(0),
        DistanceMetric::DotProduct => Ok(1),
        DistanceMetric::Euclidean => Ok(2),
        DistanceMetric::Manhattan => Ok(3),
        DistanceMetric::Hamming => Ok(4),
        _ => Err(PersistError::UnsupportedMetric { metric }),
    }
}
/// Decodes the one-byte metric tag stored in the header; the inverse of
/// `metric_to_tag` for tags `0..=4`.
///
/// # Errors
///
/// Returns `PersistError::InvalidMetric` carrying the offending tag when it is
/// not one of the known values.
pub(crate) fn tag_to_metric(tag: u8) -> Result<DistanceMetric> {
    let metric = match tag {
        0 => DistanceMetric::Cosine,
        1 => DistanceMetric::DotProduct,
        2 => DistanceMetric::Euclidean,
        3 => DistanceMetric::Manhattan,
        4 => DistanceMetric::Hamming,
        _ => return Err(PersistError::InvalidMetric { tag }),
    };
    Ok(metric)
}
/// Converts a host `usize` to the `u64` used for on-disk lengths and counts.
///
/// `what` names the field being converted (`"dim"`, `"n_vectors"` or
/// `"index_type_len"`) and only selects the error message; any other label
/// falls back to a generic message.
///
/// # Errors
///
/// Returns `PersistError::InvalidPayload` if `value` does not fit in a `u64`.
fn usize_to_u64(value: usize, what: &'static str) -> Result<u64> {
    match u64::try_from(value) {
        Ok(widened) => Ok(widened),
        Err(_) => {
            let reason = match what {
                "dim" => "dim does not fit in u64",
                "n_vectors" => "n_vectors does not fit in u64",
                "index_type_len" => "index_type length does not fit in u64",
                _ => "usize value does not fit in u64",
            };
            Err(PersistError::InvalidPayload { reason })
        }
    }
}
/// Converts an on-disk `u64` length or count to the host's `usize`.
///
/// `what` names the field being converted (`"dim"`, `"n_vectors"` or
/// `"index_type_len"`) and only selects the error message; any other label
/// falls back to a generic message.
///
/// # Errors
///
/// Returns `PersistError::InvalidPayload` if `value` does not fit in `usize`
/// on this host.
fn u64_to_usize(value: u64, what: &'static str) -> Result<usize> {
    let Ok(narrowed) = usize::try_from(value) else {
        let reason = match what {
            "dim" => "dim does not fit in usize on this host",
            "n_vectors" => "n_vectors does not fit in usize on this host",
            "index_type_len" => "index_type length does not fit in usize on this host",
            _ => "u64 value does not fit in usize on this host",
        };
        return Err(PersistError::InvalidPayload { reason });
    };
    Ok(narrowed)
}
/// Serializes `header` to `writer`.
///
/// Layout, in order (all integers little-endian):
///
/// 1. `magic`: 8 raw bytes, written as stored in `header` (not checked
///    against `MAGIC`)
/// 2. `version`: `u32`
/// 3. `index_type` length in bytes: `u64`
/// 4. `index_type`: UTF-8 bytes
/// 5. `dim`: `u64`
/// 6. metric tag: 1 byte (see `metric_to_tag`)
/// 7. `n_vectors`: `u64`
/// 8. `crc32`: `u32`
///
/// Every fallible conversion is done before the first byte is written, so an
/// unsupported metric or an unrepresentable length never leaves a partial
/// header in `writer`.
///
/// # Errors
///
/// * `PersistError::UnsupportedMetric` if the metric has no on-disk tag.
/// * `PersistError::InvalidPayload` if a length or count does not fit in `u64`.
/// * `PersistError::Io` if the underlying writer fails.
pub fn write_header(writer: &mut dyn Write, header: &FileHeader) -> Result<()> {
    // Validate and convert first; nothing below this group can fail except I/O.
    let index_type = header.index_type.as_bytes();
    let index_type_len = usize_to_u64(index_type.len(), "index_type_len")?;
    let dim = usize_to_u64(header.dim, "dim")?;
    let metric_tag = metric_to_tag(header.metric)?;
    let n_vectors = usize_to_u64(header.n_vectors, "n_vectors")?;

    write_all(writer, &header.magic)?;
    write_all(writer, &header.version.to_le_bytes())?;
    write_all(writer, &index_type_len.to_le_bytes())?;
    write_all(writer, index_type)?;
    write_all(writer, &dim.to_le_bytes())?;
    write_all(writer, &[metric_tag])?;
    write_all(writer, &n_vectors.to_le_bytes())?;
    write_all(writer, &header.crc32.to_le_bytes())?;
    Ok(())
}
/// Parses a [`FileHeader`] from `reader`, validating each field as it is read.
///
/// Fields are consumed in the order `write_header` emits them: magic, version,
/// `index_type` (length-prefixed), `dim`, metric tag, `n_vectors`, `crc32`.
/// The CRC value is returned as read; it is not verified here.
///
/// # Errors
///
/// * `PersistError::TruncatedHeader` if the input ends inside a field.
/// * `PersistError::BadMagic` if the first 8 bytes differ from `MAGIC`.
/// * `PersistError::UnsupportedVersion` if the version lies outside
///   `MIN_SUPPORTED_VERSION..=CURRENT_VERSION`.
/// * `PersistError::InvalidPayload` if the `index_type` length exceeds 4096
///   bytes, the `index_type` bytes are not UTF-8, or a `u64` field does not
///   fit in `usize`.
/// * `PersistError::InvalidMetric` if the metric tag is unknown.
/// * `PersistError::Io` for any other reader failure.
pub fn read_header(reader: &mut dyn Read) -> Result<FileHeader> {
    // Upper bound on the declared `index_type` length, enforced before the
    // buffer for it is allocated.
    const MAX_INDEX_TYPE_LEN: usize = 4096;

    // Reads exactly `N` bytes into a fresh fixed-size array.
    fn take<const N: usize>(reader: &mut dyn Read) -> Result<[u8; N]> {
        let mut bytes = [0u8; N];
        read_exact_or_truncated(reader, &mut bytes)?;
        Ok(bytes)
    }

    let magic = take::<8>(reader)?;
    if magic != MAGIC {
        return Err(PersistError::BadMagic { found: magic });
    }

    let version = u32::from_le_bytes(take::<4>(reader)?);
    let accepted_versions = MIN_SUPPORTED_VERSION..=CURRENT_VERSION;
    if !accepted_versions.contains(&version) {
        return Err(PersistError::UnsupportedVersion {
            found: version,
            supported: CURRENT_VERSION,
        });
    }

    let declared_len = u64::from_le_bytes(take::<8>(reader)?);
    let declared_len = u64_to_usize(declared_len, "index_type_len")?;
    if declared_len > MAX_INDEX_TYPE_LEN {
        return Err(PersistError::InvalidPayload {
            reason: "index_type length exceeds the 4 KiB cap",
        });
    }
    let mut raw_index_type = vec![0u8; declared_len];
    read_exact_or_truncated(reader, &mut raw_index_type)?;
    let index_type = match String::from_utf8(raw_index_type) {
        Ok(text) => text,
        Err(_) => {
            return Err(PersistError::InvalidPayload {
                reason: "index_type is not valid UTF-8",
            });
        }
    };

    let dim = u64_to_usize(u64::from_le_bytes(take::<8>(reader)?), "dim")?;

    let [metric_tag] = take::<1>(reader)?;
    let metric = tag_to_metric(metric_tag)?;

    let n_vectors = u64_to_usize(u64::from_le_bytes(take::<8>(reader)?), "n_vectors")?;

    let crc32 = u32::from_le_bytes(take::<4>(reader)?);

    Ok(FileHeader {
        magic,
        version,
        index_type,
        dim,
        metric,
        n_vectors,
        crc32,
    })
}
/// Writes all of `bytes` to `writer`.
///
/// # Errors
///
/// Any I/O failure is wrapped in `PersistError::Io` with an empty path, since
/// no file path is available at this level.
fn write_all(writer: &mut dyn Write, bytes: &[u8]) -> Result<()> {
    match writer.write_all(bytes) {
        Ok(()) => Ok(()),
        Err(source) => Err(PersistError::Io {
            path: std::path::PathBuf::new(),
            source,
        }),
    }
}
/// Fills `buf` completely from `reader`, reporting a short input as a
/// truncated header rather than a generic I/O error.
///
/// The bytes are pulled with a counting loop instead of `Read::read_exact` so
/// that the error can say how many bytes of the field were actually obtained
/// before the input ended. An empty `buf` succeeds without touching `reader`.
///
/// # Errors
///
/// * `PersistError::TruncatedHeader` if the input ends first; `needed` is
///   `buf.len()` and `found` is the number of bytes read into `buf` so far.
/// * `PersistError::Io` (with an empty path) for any other reader failure.
fn read_exact_or_truncated(reader: &mut dyn Read, buf: &mut [u8]) -> Result<()> {
    let needed = buf.len();
    let mut filled = 0usize;
    while filled < needed {
        match reader.read(&mut buf[filled..]) {
            // A zero-length read on a non-empty buffer means end of input.
            Ok(0) => {
                return Err(PersistError::TruncatedHeader {
                    needed,
                    found: filled,
                });
            }
            Ok(n) => filled += n,
            // Interrupted reads are retried, as `read_exact` does.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            // Some readers signal end of input with an error instead of `Ok(0)`.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                return Err(PersistError::TruncatedHeader {
                    needed,
                    found: filled,
                });
            }
            Err(source) => {
                return Err(PersistError::Io {
                    path: std::path::PathBuf::new(),
                    source,
                });
            }
        }
    }
    Ok(())
}