use std::io::Cursor;
use iqdb_index::Index;
use crate::Persistable;
use crate::checksum;
use crate::config::{Compression, PersistConfig};
use crate::error::{PersistError, Result};
use crate::format::{self, CURRENT_VERSION, FileHeader, MAGIC};
use crate::storage::{StdFsStorage, Storage};
/// An index of type `I` bundled with the [`PersistConfig`] and the
/// [`Storage`] backend used to write it to, and read it back from, a
/// snapshot file.
pub struct PersistedIndex<I>
where
    I: Index + Persistable,
{
    /// The wrapped index, reachable through `index()` and `index_mut()`.
    inner: I,
    /// Persistence settings; `save` reads `path` and `fsync_policy` from it.
    config: PersistConfig,
    /// Backend that performs the actual reads and writes.
    storage: Box<dyn Storage>,
}
// Hand-written `Debug`: the wrapped index and the config are formatted
// normally, while the storage backend is shown as a fixed placeholder string
// instead of being formatted itself.
impl<I> core::fmt::Debug for PersistedIndex<I>
where
    I: Index + Persistable + core::fmt::Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("PersistedIndex");
        out.field("inner", &self.inner);
        out.field("config", &self.config);
        out.field("storage", &"<dyn Storage>");
        out.finish()
    }
}
impl<I> PersistedIndex<I>
where
    I: Index + Persistable,
{
    /// Wraps an already-built index so it can later be written with
    /// [`save`](Self::save), using `StdFsStorage` as the backend.
    ///
    /// Nothing is read or written here; only `config` is validated.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::Unsupported` when `config` enables the WAL or
    /// selects a compression other than `Compression::None`.
    pub fn open_with(inner: I, config: PersistConfig) -> Result<Self> {
        validate_config(&config)?;
        let storage: Box<dyn Storage> = Box::new(StdFsStorage);
        Ok(Self {
            inner,
            config,
            storage,
        })
    }

    /// Validates `config`, then reads the snapshot at `config.path` through
    /// `StdFsStorage` and rebuilds the index from it.
    ///
    /// # Errors
    ///
    /// Returns `PersistError::Unsupported` for a rejected configuration, or
    /// any error produced while reading and checking the snapshot.
    pub fn load(config: PersistConfig) -> Result<Self> {
        validate_config(&config)?;
        let storage: Box<dyn Storage> = Box::new(StdFsStorage);
        Self::load_with_storage(config, storage)
    }

    /// Test-only variant of [`open_with`](Self::open_with) that takes the
    /// storage backend from the caller.
    #[cfg(test)]
    pub(crate) fn open_with_storage(
        inner: I,
        config: PersistConfig,
        storage: Box<dyn Storage>,
    ) -> Result<Self> {
        validate_config(&config)?;
        Ok(Self {
            inner,
            config,
            storage,
        })
    }

    /// Reads the snapshot at `config.path` from `storage` and rebuilds the
    /// index, checking it against the file header.
    ///
    /// The checks run in this order: the header parses, the payload offset
    /// lies within the file, the header's index type equals `I::INDEX_TYPE`,
    /// the payload checksum verifies, the payload deserializes, and finally
    /// the rebuilt index agrees with the header on `dim`, `metric` and
    /// `n_vectors`.
    ///
    /// This function does not validate `config` itself.
    ///
    /// # Errors
    ///
    /// Returns the first failure among the checks above, or the error from
    /// `storage.read_all`.
    pub(crate) fn load_with_storage(
        config: PersistConfig,
        storage: Box<dyn Storage>,
    ) -> Result<Self> {
        let raw = storage.read_all(&config.path)?;

        // After the header is parsed, the cursor position is the offset at
        // which the payload begins.
        let mut reader = Cursor::new(&raw[..]);
        let header = format::read_header(&mut reader)?;
        let payload_start = match usize::try_from(reader.position()) {
            Ok(offset) => offset,
            Err(_) => {
                return Err(PersistError::InvalidPayload {
                    reason: "header position does not fit in usize",
                });
            }
        };

        // `get` yields `None` exactly when the offset lies past the end of
        // the file contents.
        let Some(payload) = raw.get(payload_start..) else {
            return Err(PersistError::TruncatedHeader {
                needed: payload_start,
                found: raw.len(),
            });
        };

        if header.index_type != I::INDEX_TYPE {
            return Err(PersistError::InvalidIndexType {
                found: header.index_type,
                expected: I::INDEX_TYPE,
            });
        }

        // Verify integrity before handing the bytes to the deserializer.
        checksum::verify(payload, header.crc32)?;
        let mut payload_reader = Cursor::new(payload);
        let inner = <I as Persistable>::load_from(&mut payload_reader)?;

        // Cross-check the rebuilt index against the header; the first
        // disagreement found decides the error message.
        let mismatch = if inner.dim() != header.dim {
            Some("header dim disagrees with payload-reconstructed index")
        } else if inner.metric() != header.metric {
            Some("header metric disagrees with payload-reconstructed index")
        } else if inner.len() != header.n_vectors {
            Some("header n_vectors disagrees with payload-reconstructed index")
        } else {
            None
        };

        match mismatch {
            Some(reason) => Err(PersistError::InvalidPayload { reason }),
            None => Ok(Self {
                inner,
                config,
                storage,
            }),
        }
    }

    /// Shared access to the wrapped index.
    #[must_use]
    pub fn index(&self) -> &I {
        &self.inner
    }

    /// Exclusive access to the wrapped index. Changes made through it are not
    /// written anywhere by this method; call [`save`](Self::save) to write
    /// them out.
    pub fn index_mut(&mut self) -> &mut I {
        &mut self.inner
    }

    /// The persistence configuration this wrapper was created with.
    #[must_use]
    pub fn config(&self) -> &PersistConfig {
        &self.config
    }

    /// Serializes the wrapped index and writes header plus payload to
    /// `config.path` through the storage backend's `write_atomic`.
    ///
    /// The payload is first serialized into memory so its checksum can be
    /// stored in the header that precedes it in the file.
    ///
    /// # Errors
    ///
    /// Propagates failures from `Persistable::save_to`,
    /// `format::write_header` and `Storage::write_atomic`.
    #[tracing::instrument(level = "debug", skip_all, fields(
        path = %self.config.path.display(),
        index_type = I::INDEX_TYPE,
        n = self.inner.len(),
    ))]
    pub fn save(&self) -> Result<()> {
        let mut payload: Vec<u8> = Vec::new();
        <I as Persistable>::save_to(&self.inner, &mut payload)?;

        let crc32 = checksum::compute(&payload);
        let header = FileHeader {
            magic: MAGIC,
            version: CURRENT_VERSION,
            index_type: I::INDEX_TYPE.to_string(),
            dim: self.inner.dim(),
            metric: self.inner.metric(),
            n_vectors: self.inner.len(),
            crc32,
        };

        // 64 bytes of headroom for the encoded header, presumably an
        // estimate of its size; the buffer simply grows if it is larger.
        let mut file_bytes: Vec<u8> = Vec::with_capacity(payload.len() + 64);
        format::write_header(&mut file_bytes, &header)?;
        file_bytes.extend_from_slice(&payload);

        let target = &self.config.path;
        let policy = self.config.fsync_policy;
        self.storage.write_atomic(target, &file_bytes, policy)
    }
}
/// Rejects configurations that ask for features this code does not support.
///
/// The WAL flag is checked first, then compression; only
/// `Compression::None` is accepted.
///
/// # Errors
///
/// Returns `PersistError::Unsupported` naming the first rejected feature.
fn validate_config(config: &PersistConfig) -> Result<()> {
    // Pick the first unsupported feature, paired with the version string
    // reported for it.
    let rejected = if config.wal_enabled {
        Some(("wal_enabled", "v0.3"))
    } else {
        match config.compression {
            Compression::None => None,
            _ => Some(("compression", "v0.4")),
        }
    };

    match rejected {
        Some((feature, available_in)) => Err(PersistError::Unsupported {
            feature,
            available_in,
        }),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used, clippy::expect_used)]
use std::io::{Read, Write};
use std::sync::Arc;
use iqdb_index::{Index, IndexCore, IndexStats};
use iqdb_types::{DistanceMetric, Hit, Metadata, Result as IqdbResult, SearchParams, VectorId};
use super::*;
use crate::format::{metric_to_tag, tag_to_metric};
/// Test double for an index: it stores only a dimension, a metric and a
/// vector count, and holds no vector data.
#[derive(Debug)]
struct MockIndex {
/// Value reported by `dim()`.
dim: usize,
/// Value reported by `metric()`.
metric: DistanceMetric,
/// Value reported by `len()`; incremented by every `insert`.
n: usize,
}
// `IndexCore` for the mock: only the vector count ever changes; no vectors
// are stored and searches find nothing.
impl IndexCore for MockIndex {
    /// Counts the insertion; the id, vector and metadata are discarded.
    fn insert(
        &mut self,
        _id: VectorId,
        _vector: Arc<[f32]>,
        _metadata: Option<Metadata>,
    ) -> IqdbResult<()> {
        self.n += 1;
        Ok(())
    }

    /// No-op; the count is not decremented.
    fn delete(&mut self, _id: &VectorId) -> IqdbResult<()> {
        Ok(())
    }

    /// Always returns an empty hit list.
    fn search(&self, _query: &[f32], _params: &SearchParams) -> IqdbResult<Vec<Hit>> {
        Ok(vec![])
    }

    fn len(&self) -> usize {
        self.n
    }

    fn dim(&self) -> usize {
        self.dim
    }

    fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// No-op; there is nothing to flush.
    fn flush(&mut self) -> IqdbResult<()> {
        Ok(())
    }

    /// Reports the current count under the type name "mock"; every other
    /// field keeps its default.
    fn stats(&self) -> IndexStats {
        IndexStats {
            index_type: "mock",
            n_vectors: self.n,
            ..Default::default()
        }
    }
}
impl Index for MockIndex {
type Config = ();
fn new(dim: usize, metric: DistanceMetric, _: ()) -> IqdbResult<Self> {
Ok(Self { dim, metric, n: 0 })
}
}
// Payload layout used by the mock: one metric tag byte, then `dim` and `n`,
// each written as a little-endian `u64`.
impl Persistable for MockIndex {
    const INDEX_TYPE: &'static str = "mock";

    /// Writes the metric tag followed by `dim` and `n`.
    fn save_to(&self, writer: &mut dyn Write) -> Result<()> {
        let tag = metric_to_tag(self.metric)?;
        writer.write_all(&[tag]).map_err(io_err)?;

        // Each field is converted and written before the next one is
        // touched, so a failure leaves the earlier bytes already written.
        let fields = [
            (self.dim, "mock dim does not fit in u64"),
            (self.n, "mock n does not fit in u64"),
        ];
        for (value, reason) in fields {
            let wide =
                u64::try_from(value).map_err(|_| PersistError::InvalidPayload { reason })?;
            writer.write_all(&wide.to_le_bytes()).map_err(io_err)?;
        }
        Ok(())
    }

    /// Reads back the layout produced by `save_to`.
    fn load_from(reader: &mut dyn Read) -> Result<Self> {
        let mut tag = [0u8; 1];
        reader.read_exact(&mut tag).map_err(io_err)?;
        let metric = tag_to_metric(tag[0])?;

        // Reads one little-endian u64 and narrows it to usize.
        let mut read_usize = |reason: &'static str| -> Result<usize> {
            let mut raw = [0u8; 8];
            reader.read_exact(&mut raw).map_err(io_err)?;
            usize::try_from(u64::from_le_bytes(raw))
                .map_err(|_| PersistError::InvalidPayload { reason })
        };
        let dim = read_usize("mock dim does not fit in usize")?;
        let n = read_usize("mock n does not fit in usize")?;

        Ok(Self { dim, metric, n })
    }
}
/// Wraps an I/O error in `PersistError::Io` with an empty path, for the mock
/// serializer, which works on readers and writers rather than on a path.
fn io_err(source: std::io::Error) -> PersistError {
    let path = std::path::PathBuf::new();
    PersistError::Io { path, source }
}
struct FailingRenameStorage;
impl Storage for FailingRenameStorage {
fn read_all(&self, path: &std::path::Path) -> Result<Vec<u8>> {
StdFsStorage.read_all(path)
}
fn write_atomic(
&self,
target: &std::path::Path,
payload: &[u8],
_policy: crate::config::FsyncPolicy,
) -> Result<()> {
use std::fs::OpenOptions;
let target_dir = target.parent().unwrap_or_else(|| std::path::Path::new("."));
let file_name = target.file_name().unwrap();
let temp_path = target_dir.join(format!(
"{}.tmp.failtest.{}",
file_name.to_string_lossy(),
std::process::id(),
));
{
let mut f = OpenOptions::new()
.create_new(true)
.write(true)
.open(&temp_path)
.map_err(|source| PersistError::Io {
path: temp_path.clone(),
source,
})?;
f.write_all(payload).map_err(|source| PersistError::Io {
path: temp_path.clone(),
source,
})?;
f.sync_all().map_err(|source| PersistError::Io {
path: temp_path.clone(),
source,
})?;
}
let _cleanup = std::fs::remove_file(&temp_path);
Err(PersistError::Io {
path: target.to_path_buf(),
source: std::io::Error::other("simulated rename failure"),
})
}
}
/// A failed save through `FailingRenameStorage` must leave the previously
/// written snapshot byte-for-byte unchanged and still loadable.
#[test]
fn save_failure_leaves_original_file_intact() {
    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("idx.iqdb");
    let cfg = PersistConfig::new(&path);

    // First, a normal save that produces the known-good snapshot.
    let original = MockIndex {
        dim: 16,
        metric: DistanceMetric::Cosine,
        n: 7,
    };
    let good = PersistedIndex::open_with(original, cfg.clone()).unwrap();
    good.save().unwrap();
    let before = std::fs::read(&path).unwrap();
    assert!(!before.is_empty(), "good save produced empty file");

    // Then a save of different contents that is forced to fail.
    let replacement = MockIndex {
        dim: 16,
        metric: DistanceMetric::Cosine,
        n: 99,
    };
    let failing = PersistedIndex::open_with_storage(
        replacement,
        cfg.clone(),
        Box::new(FailingRenameStorage),
    )
    .unwrap();
    let err = failing.save().unwrap_err();
    assert!(matches!(err, PersistError::Io { .. }));

    let after = std::fs::read(&path).unwrap();
    assert_eq!(after, before, "rename failure corrupted the snapshot");

    // The snapshot still holds the first index, not the replacement.
    let restored = PersistedIndex::<MockIndex>::load(cfg).unwrap();
    assert_eq!(restored.index().len(), 7);
}
/// `open_with` must refuse a config with the WAL enabled and a config with
/// LZ4 compression, naming the offending feature in each case.
#[test]
fn validate_config_rejects_wal_and_compression() {
    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("idx.iqdb");
    let empty_index = || MockIndex {
        dim: 4,
        metric: DistanceMetric::Euclidean,
        n: 0,
    };

    let mut wal_cfg = PersistConfig::new(&path);
    wal_cfg.wal_enabled = true;
    let wal_err = PersistedIndex::open_with(empty_index(), wal_cfg).unwrap_err();
    assert!(matches!(
        wal_err,
        PersistError::Unsupported {
            feature: "wal_enabled",
            ..
        }
    ));

    let mut lz4_cfg = PersistConfig::new(&path);
    lz4_cfg.compression = Compression::Lz4;
    let lz4_err = PersistedIndex::open_with(empty_index(), lz4_cfg).unwrap_err();
    assert!(matches!(
        lz4_err,
        PersistError::Unsupported {
            feature: "compression",
            ..
        }
    ));
}
/// Flipping one bit in the last byte of a saved snapshot must make `load`
/// fail with `ChecksumMismatch`.
#[test]
fn crc_mismatch_after_byte_flip_in_payload() {
    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("idx.iqdb");
    let cfg = PersistConfig::new(&path);

    let index = MockIndex {
        dim: 8,
        metric: DistanceMetric::Cosine,
        n: 11,
    };
    PersistedIndex::open_with(index, cfg.clone())
        .unwrap()
        .save()
        .unwrap();

    // Corrupt the final byte of the file, which belongs to the payload.
    let mut contents = std::fs::read(&path).unwrap();
    *contents.last_mut().unwrap() ^= 0x01;
    std::fs::write(&path, &contents).unwrap();

    let err: PersistError = PersistedIndex::<MockIndex>::load(cfg).unwrap_err();
    assert!(
        matches!(err, PersistError::ChecksumMismatch { .. }),
        "expected ChecksumMismatch, got {err:?}",
    );
}
/// Loading a snapshot written by `MockIndex` as a different index type must
/// fail with `InvalidIndexType` that names the requested type.
#[test]
fn invalid_index_type_on_wrong_i_surfaces_loudly() {
    /// Second mock whose `INDEX_TYPE` ("other") differs from `MockIndex`'s.
    #[derive(Debug)]
    struct OtherMock;

    impl IndexCore for OtherMock {
        fn insert(
            &mut self,
            _id: VectorId,
            _vector: Arc<[f32]>,
            _metadata: Option<Metadata>,
        ) -> IqdbResult<()> {
            Ok(())
        }

        fn delete(&mut self, _id: &VectorId) -> IqdbResult<()> {
            Ok(())
        }

        fn search(&self, _query: &[f32], _params: &SearchParams) -> IqdbResult<Vec<Hit>> {
            Ok(vec![])
        }

        fn len(&self) -> usize {
            0
        }

        fn dim(&self) -> usize {
            4
        }

        fn metric(&self) -> DistanceMetric {
            DistanceMetric::Euclidean
        }

        fn flush(&mut self) -> IqdbResult<()> {
            Ok(())
        }

        fn stats(&self) -> IndexStats {
            IndexStats {
                index_type: "other",
                ..Default::default()
            }
        }
    }

    impl Index for OtherMock {
        type Config = ();

        fn new(_dim: usize, _metric: DistanceMetric, _config: ()) -> IqdbResult<Self> {
            Ok(OtherMock)
        }
    }

    impl Persistable for OtherMock {
        const INDEX_TYPE: &'static str = "other";

        fn save_to(&self, _writer: &mut dyn Write) -> Result<()> {
            Ok(())
        }

        fn load_from(_reader: &mut dyn Read) -> Result<Self> {
            Ok(OtherMock)
        }
    }

    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("idx.iqdb");
    let cfg = PersistConfig::new(&path);

    // Write a snapshot tagged with MockIndex's type.
    let index = MockIndex {
        dim: 4,
        metric: DistanceMetric::Euclidean,
        n: 3,
    };
    PersistedIndex::open_with(index, cfg.clone())
        .unwrap()
        .save()
        .unwrap();

    // Read it back as the other type.
    let err = PersistedIndex::<OtherMock>::load(cfg).unwrap_err();
    assert!(
        matches!(
            err,
            PersistError::InvalidIndexType {
                expected: "other",
                ..
            }
        ),
        "expected InvalidIndexType, got {err:?}",
    );
}
/// Saving and then loading must give back the same dim, metric and count.
#[test]
fn roundtrip_through_storage_recovers_state() {
    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("idx.iqdb");
    let cfg = PersistConfig::new(&path);

    let source = MockIndex {
        dim: 32,
        metric: DistanceMetric::Manhattan,
        n: 42,
    };
    let writer = PersistedIndex::open_with(source, cfg.clone()).unwrap();
    writer.save().unwrap();

    let restored = PersistedIndex::<MockIndex>::load(cfg).unwrap();
    let loaded = restored.index();
    assert_eq!(loaded.dim(), 32);
    assert_eq!(loaded.metric(), DistanceMetric::Manhattan);
    assert_eq!(loaded.len(), 42);
}
}