use std::io::Cursor;
use std::sync::Arc;
use iqdb_index::Index;
use iqdb_types::{Metadata, VectorId};
use crate::Persistable;
use crate::config::{Compression, PersistConfig};
use crate::error::{PersistError, Result};
use crate::format::{self, CURRENT_VERSION, FileHeader, MAGIC};
use crate::storage::{StdFsStorage, Storage};
use crate::wal::Wal;
use crate::{checksum, compression, recovery};
/// An index `I` paired with on-disk persistence: full snapshots written through a pluggable
/// `Storage` backend, plus an optional write-ahead log for `insert`/`delete`.
pub struct PersistedIndex<I: Index + Persistable> {
    /// The wrapped in-memory index.
    inner: I,
    /// Snapshot path, compression scheme, fsync policy and WAL toggle.
    config: PersistConfig,
    /// Backend used to read snapshot files and to write them via `write_atomic`.
    storage: Box<dyn Storage>,
    /// `Some` only when `config.wal_enabled` was set at open/load time.
    wal: Option<Wal>,
}
impl<I: Index + Persistable + core::fmt::Debug> core::fmt::Debug for PersistedIndex<I> {
    /// Formats the wrapper for diagnostics.
    ///
    /// The storage backend is a trait object with no `Debug` bound, so it is shown as a fixed
    /// placeholder. The WAL is summarized as a boolean saying whether one is attached.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            inner, config, wal, ..
        } = self;
        let has_wal = wal.is_some();
        let mut out = f.debug_struct("PersistedIndex");
        out.field("inner", inner);
        out.field("config", config);
        out.field("storage", &"<dyn Storage>");
        out.field("wal", &has_wal);
        out.finish()
    }
}
impl<I: Index + Persistable> PersistedIndex<I> {
pub fn open_with(inner: I, config: PersistConfig) -> Result<Self> {
Self::open_with_storage(inner, config, Box::new(StdFsStorage))
}
pub fn load(config: PersistConfig) -> Result<Self> {
Self::load_with_storage(config, Box::new(StdFsStorage))
}
pub(crate) fn open_with_storage(
inner: I,
config: PersistConfig,
storage: Box<dyn Storage>,
) -> Result<Self> {
validate_config(&config)?;
let mut this = Self {
inner,
config,
storage,
wal: None,
};
if this.config.wal_enabled {
this.write_snapshot()?;
let wal = Wal::create(&this.config.path, this.config.fsync_policy)?;
this.wal = Some(wal);
}
Ok(this)
}
pub(crate) fn load_with_storage(
config: PersistConfig,
storage: Box<dyn Storage>,
) -> Result<Self> {
validate_config(&config)?;
let bytes = storage.read_all(&config.path)?;
let mut cursor = Cursor::new(&bytes[..]);
let header = format::read_header(&mut cursor)?;
let header_end =
usize::try_from(cursor.position()).map_err(|_| PersistError::InvalidPayload {
reason: "header position does not fit in usize",
})?;
if header_end > bytes.len() {
return Err(PersistError::TruncatedHeader {
needed: header_end,
found: bytes.len(),
});
}
if header.index_type != I::INDEX_TYPE {
return Err(PersistError::InvalidIndexType {
found: header.index_type,
expected: I::INDEX_TYPE,
});
}
let payload = &bytes[header_end..];
checksum::verify(payload, header.crc32)?;
let raw: Vec<u8> = if header.version >= 2 {
if payload.len() < 9 {
return Err(PersistError::TruncatedPayload {
needed: 9,
found: payload.len() as u64,
});
}
let tag = payload[0];
let mut len_bytes = [0u8; 8];
len_bytes.copy_from_slice(&payload[1..9]);
let uncompressed_len =
usize::try_from(u64::from_le_bytes(len_bytes)).map_err(|_| {
PersistError::InvalidPayload {
reason: "uncompressed length does not fit in usize on this host",
}
})?;
compression::decode(tag, &payload[9..], uncompressed_len)?
} else {
payload.to_vec()
};
let mut payload_cursor = Cursor::new(&raw[..]);
let inner = <I as Persistable>::load_from(&mut payload_cursor)?;
if inner.dim() != header.dim {
return Err(PersistError::InvalidPayload {
reason: "header dim disagrees with payload-reconstructed index",
});
}
if inner.metric() != header.metric {
return Err(PersistError::InvalidPayload {
reason: "header metric disagrees with payload-reconstructed index",
});
}
if inner.len() != header.n_vectors {
return Err(PersistError::InvalidPayload {
reason: "header n_vectors disagrees with payload-reconstructed index",
});
}
let mut this = Self {
inner,
config,
storage,
wal: None,
};
if this.config.wal_enabled {
let _applied = recovery::replay(&this.config.path, &mut this.inner)?;
let wal = Wal::open_for_append(&this.config.path, this.config.fsync_policy)?;
this.wal = Some(wal);
}
Ok(this)
}
#[must_use]
pub fn index(&self) -> &I {
&self.inner
}
pub fn index_mut(&mut self) -> &mut I {
&mut self.inner
}
#[must_use]
pub fn config(&self) -> &PersistConfig {
&self.config
}
pub fn insert(
&mut self,
id: VectorId,
vector: Arc<[f32]>,
meta: Option<Metadata>,
) -> Result<()> {
let Self { inner, wal, .. } = self;
match wal {
Some(w) => {
let mark = w.mark()?;
w.append_insert(&id, &vector, meta.as_ref())?;
match inner.insert(id, vector, meta) {
Ok(()) => Ok(()),
Err(e) => {
w.rollback(mark)?;
Err(PersistError::from(e))
}
}
}
None => {
inner.insert(id, vector, meta)?;
Ok(())
}
}
}
pub fn delete(&mut self, id: &VectorId) -> Result<()> {
let Self { inner, wal, .. } = self;
match wal {
Some(w) => {
let mark = w.mark()?;
w.append_delete(id)?;
match inner.delete(id) {
Ok(()) => Ok(()),
Err(e) => {
w.rollback(mark)?;
Err(PersistError::from(e))
}
}
}
None => {
inner.delete(id)?;
Ok(())
}
}
}
pub fn save(&self) -> Result<()> {
self.write_snapshot()
}
pub fn checkpoint(&mut self) -> Result<()> {
self.write_snapshot()?;
if let Some(wal) = &mut self.wal {
wal.reset()?;
}
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, fields(
path = %self.config.path.display(),
index_type = I::INDEX_TYPE,
n = self.inner.len(),
))]
fn write_snapshot(&self) -> Result<()> {
let mut payload_buf: Vec<u8> = Vec::new();
<I as Persistable>::save_to(&self.inner, &mut payload_buf)?;
let scheme = self.config.compression;
let data = compression::encode(scheme, &payload_buf)?;
let uncompressed_len =
u64::try_from(payload_buf.len()).map_err(|_| PersistError::InvalidPayload {
reason: "payload length does not fit in u64",
})?;
let mut region: Vec<u8> = Vec::with_capacity(9 + data.len());
region.push(compression::scheme_tag(scheme));
region.extend_from_slice(&uncompressed_len.to_le_bytes());
region.extend_from_slice(&data);
let crc32 = checksum::compute(®ion);
let header = FileHeader {
magic: MAGIC,
version: CURRENT_VERSION,
index_type: I::INDEX_TYPE.to_string(),
dim: self.inner.dim(),
metric: self.inner.metric(),
n_vectors: self.inner.len(),
crc32,
};
let mut full: Vec<u8> = Vec::with_capacity(region.len() + 64);
format::write_header(&mut full, &header)?;
full.extend_from_slice(®ion);
self.storage
.write_atomic(&self.config.path, &full, self.config.fsync_policy)
}
}
/// Rejects configurations that ask for a compression scheme not compiled into this build.
///
/// # Errors
///
/// Returns `PersistError::Unsupported` naming the missing cargo feature in either case:
/// - `config.compression` is `Zstd` without the `zstd` feature;
/// - `config.compression` is `Lz4` without the `lz4` feature.
fn validate_config(config: &PersistConfig) -> Result<()> {
    // Per scheme: (human-readable name, where to enable it, compiled into this build?).
    let requirement = match config.compression {
        Compression::None => None,
        Compression::Zstd { .. } => Some((
            "Zstd compression",
            "the `zstd` cargo feature",
            cfg!(feature = "zstd"),
        )),
        Compression::Lz4 => Some((
            "LZ4 compression",
            "the `lz4` cargo feature",
            cfg!(feature = "lz4"),
        )),
    };
    match requirement {
        Some((feature, available_in, false)) => Err(PersistError::Unsupported {
            feature,
            available_in,
        }),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `PersistedIndex`, driven by a tiny `MockIndex` whose serialized form is
    //! only its metric tag, dimension and vector count.
    // unwrap/expect are acceptable in tests: a panic is simply a failed test.
    #![allow(clippy::unwrap_used, clippy::expect_used)]
    use std::io::{Read, Write};
    use std::sync::Arc;
    use iqdb_index::{Index, IndexCore, IndexStats};
    use iqdb_types::{DistanceMetric, Hit, Metadata, Result as IqdbResult, SearchParams, VectorId};
    use super::*;
    use crate::format::{metric_to_tag, tag_to_metric};

    /// Minimal stand-in index: stores its shape plus a counter of inserts.
    #[derive(Debug)]
    struct MockIndex {
        dim: usize,
        metric: DistanceMetric,
        // Incremented by every `insert`; `delete` leaves it unchanged.
        n: usize,
    }

    impl IndexCore for MockIndex {
        // Counts the insert; id, vector and metadata are discarded.
        fn insert(&mut self, _: VectorId, _: Arc<[f32]>, _: Option<Metadata>) -> IqdbResult<()> {
            self.n += 1;
            Ok(())
        }
        // No-op: does not decrement `n`.
        fn delete(&mut self, _: &VectorId) -> IqdbResult<()> {
            Ok(())
        }
        // Always returns no hits.
        fn search(&self, _: &[f32], _: &SearchParams) -> IqdbResult<Vec<Hit>> {
            Ok(Vec::new())
        }
        fn len(&self) -> usize {
            self.n
        }
        fn dim(&self) -> usize {
            self.dim
        }
        fn metric(&self) -> DistanceMetric {
            self.metric
        }
        fn flush(&mut self) -> IqdbResult<()> {
            Ok(())
        }
        fn stats(&self) -> IndexStats {
            IndexStats {
                n_vectors: self.n,
                index_type: "mock",
                ..IndexStats::default()
            }
        }
    }

    impl Index for MockIndex {
        type Config = ();
        // Starts empty (`n = 0`); the config is unused.
        fn new(dim: usize, metric: DistanceMetric, _: ()) -> IqdbResult<Self> {
            Ok(Self { dim, metric, n: 0 })
        }
    }

    // Serialized form: [metric tag: u8][dim: u64 LE][n: u64 LE].
    impl Persistable for MockIndex {
        const INDEX_TYPE: &'static str = "mock";
        fn save_to(&self, writer: &mut dyn Write) -> Result<()> {
            writer
                .write_all(&[metric_to_tag(self.metric)?])
                .map_err(io_err)?;
            let dim_u64 = u64::try_from(self.dim).map_err(|_| PersistError::InvalidPayload {
                reason: "mock dim does not fit in u64",
            })?;
            writer.write_all(&dim_u64.to_le_bytes()).map_err(io_err)?;
            let n_u64 = u64::try_from(self.n).map_err(|_| PersistError::InvalidPayload {
                reason: "mock n does not fit in u64",
            })?;
            writer.write_all(&n_u64.to_le_bytes()).map_err(io_err)?;
            Ok(())
        }
        // Inverse of `save_to`; reads the three fields back in the same order.
        fn load_from(reader: &mut dyn Read) -> Result<Self> {
            let mut tag = [0u8; 1];
            reader.read_exact(&mut tag).map_err(io_err)?;
            let metric = tag_to_metric(tag[0])?;
            // One 8-byte buffer, reused for both little-endian u64 fields.
            let mut buf = [0u8; 8];
            reader.read_exact(&mut buf).map_err(io_err)?;
            let dim = usize::try_from(u64::from_le_bytes(buf)).map_err(|_| {
                PersistError::InvalidPayload {
                    reason: "mock dim does not fit in usize",
                }
            })?;
            reader.read_exact(&mut buf).map_err(io_err)?;
            let n = usize::try_from(u64::from_le_bytes(buf)).map_err(|_| {
                PersistError::InvalidPayload {
                    reason: "mock n does not fit in usize",
                }
            })?;
            Ok(Self { dim, metric, n })
        }
    }

    /// Wraps a raw I/O error from the mock's (de)serializer. The reader/writer has no
    /// associated file, so the path is left empty.
    fn io_err(source: std::io::Error) -> PersistError {
        PersistError::Io {
            path: std::path::PathBuf::new(),
            source,
        }
    }

    /// Storage whose `write_atomic` writes and syncs a temp file next to the target, then
    /// deletes it and reports a simulated rename failure. The target is never modified.
    /// Reads are delegated to `StdFsStorage`.
    struct FailingRenameStorage;

    impl Storage for FailingRenameStorage {
        fn read_all(&self, path: &std::path::Path) -> Result<Vec<u8>> {
            StdFsStorage.read_all(path)
        }
        fn write_atomic(
            &self,
            target: &std::path::Path,
            payload: &[u8],
            _policy: crate::config::FsyncPolicy,
        ) -> Result<()> {
            use std::fs::OpenOptions;
            let target_dir = target.parent().unwrap_or_else(|| std::path::Path::new("."));
            let file_name = target.file_name().unwrap();
            // Process id in the name keeps concurrent test processes from colliding.
            let temp_path = target_dir.join(format!(
                "{}.tmp.failtest.{}",
                file_name.to_string_lossy(),
                std::process::id(),
            ));
            // Scoped so the temp file handle is closed before the cleanup below.
            {
                let mut f = OpenOptions::new()
                    .create_new(true)
                    .write(true)
                    .open(&temp_path)
                    .map_err(|source| PersistError::Io {
                        path: temp_path.clone(),
                        source,
                    })?;
                f.write_all(payload).map_err(|source| PersistError::Io {
                    path: temp_path.clone(),
                    source,
                })?;
                f.sync_all().map_err(|source| PersistError::Io {
                    path: temp_path.clone(),
                    source,
                })?;
            }
            // Best-effort cleanup; its result is deliberately ignored.
            let _cleanup = std::fs::remove_file(&temp_path);
            Err(PersistError::Io {
                path: target.to_path_buf(),
                source: std::io::Error::other("simulated rename failure"),
            })
        }
    }

    /// A failed snapshot write must leave the previously saved file byte-for-byte intact
    /// and still loadable.
    #[test]
    fn save_failure_leaves_original_file_intact() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let inner = MockIndex {
            dim: 16,
            metric: DistanceMetric::Cosine,
            n: 7,
        };
        let cfg = PersistConfig::new(&snapshot);
        let wrap = PersistedIndex::open_with(inner, cfg.clone()).unwrap();
        wrap.save().unwrap();
        let good_bytes = std::fs::read(&snapshot).unwrap();
        assert!(!good_bytes.is_empty(), "good save produced empty file");
        // Different state (n = 99) so an accidental overwrite would be detectable.
        let other = MockIndex {
            dim: 16,
            metric: DistanceMetric::Cosine,
            n: 99,
        };
        // NOTE(review): presumably `PersistConfig::new` leaves the WAL disabled. Otherwise
        // this open would already attempt a (failing) snapshot write and the unwrap would panic.
        let wrap2 =
            PersistedIndex::open_with_storage(other, cfg.clone(), Box::new(FailingRenameStorage))
                .unwrap();
        let err = wrap2.save().unwrap_err();
        assert!(matches!(err, PersistError::Io { .. }));
        let after_bytes = std::fs::read(&snapshot).unwrap();
        assert_eq!(
            after_bytes, good_bytes,
            "rename failure corrupted the snapshot"
        );
        let restored: PersistedIndex<MockIndex> = PersistedIndex::load(cfg).unwrap();
        assert_eq!(restored.index().len(), 7);
    }

    /// Opening with `wal_enabled = true` succeeds (initial snapshot + WAL creation).
    #[test]
    fn wal_config_is_accepted() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let mut cfg = PersistConfig::new(&snapshot);
        cfg.wal_enabled = true;
        let inner = MockIndex {
            dim: 4,
            metric: DistanceMetric::Euclidean,
            n: 0,
        };
        assert!(PersistedIndex::open_with(inner, cfg).is_ok());
    }

    /// Without the `lz4` feature, requesting LZ4 is rejected up front as `Unsupported`.
    #[cfg(not(feature = "lz4"))]
    #[test]
    fn validate_config_rejects_lz4_without_feature() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let mut cfg = PersistConfig::new(&snapshot);
        cfg.compression = Compression::Lz4;
        let inner = MockIndex {
            dim: 4,
            metric: DistanceMetric::Euclidean,
            n: 0,
        };
        let err = PersistedIndex::open_with(inner, cfg).unwrap_err();
        assert!(matches!(err, PersistError::Unsupported { .. }));
    }

    /// Without the `zstd` feature, requesting Zstd is rejected up front as `Unsupported`.
    #[cfg(not(feature = "zstd"))]
    #[test]
    fn validate_config_rejects_zstd_without_feature() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let mut cfg = PersistConfig::new(&snapshot);
        cfg.compression = Compression::Zstd { level: 3 };
        let inner = MockIndex {
            dim: 4,
            metric: DistanceMetric::Euclidean,
            n: 0,
        };
        let err = PersistedIndex::open_with(inner, cfg).unwrap_err();
        assert!(matches!(err, PersistError::Unsupported { .. }));
    }

    /// Flipping one bit in the CRC-covered region (the file's last byte) must surface as
    /// `ChecksumMismatch` on load.
    #[test]
    fn crc_mismatch_after_byte_flip_in_payload() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let inner = MockIndex {
            dim: 8,
            metric: DistanceMetric::Cosine,
            n: 11,
        };
        let cfg = PersistConfig::new(&snapshot);
        PersistedIndex::open_with(inner, cfg.clone())
            .unwrap()
            .save()
            .unwrap();
        let mut bytes = std::fs::read(&snapshot).unwrap();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        std::fs::write(&snapshot, &bytes).unwrap();
        let err: PersistError = PersistedIndex::<MockIndex>::load(cfg).unwrap_err();
        assert!(
            matches!(err, PersistError::ChecksumMismatch { .. }),
            "expected ChecksumMismatch, got {err:?}",
        );
    }

    /// Loading a "mock" snapshot as a different index type must fail with
    /// `InvalidIndexType` rather than silently misparsing the payload.
    #[test]
    fn invalid_index_type_on_wrong_i_surfaces_loudly() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let cfg = PersistConfig::new(&snapshot);
        let inner = MockIndex {
            dim: 4,
            metric: DistanceMetric::Euclidean,
            n: 3,
        };
        PersistedIndex::open_with(inner, cfg.clone())
            .unwrap()
            .save()
            .unwrap();
        // A second index type, identified as "other", used only as the load target.
        #[derive(Debug)]
        struct OtherMock;
        impl IndexCore for OtherMock {
            fn insert(
                &mut self,
                _: VectorId,
                _: Arc<[f32]>,
                _: Option<Metadata>,
            ) -> IqdbResult<()> {
                Ok(())
            }
            fn delete(&mut self, _: &VectorId) -> IqdbResult<()> {
                Ok(())
            }
            fn search(&self, _: &[f32], _: &SearchParams) -> IqdbResult<Vec<Hit>> {
                Ok(Vec::new())
            }
            fn len(&self) -> usize {
                0
            }
            fn dim(&self) -> usize {
                4
            }
            fn metric(&self) -> DistanceMetric {
                DistanceMetric::Euclidean
            }
            fn flush(&mut self) -> IqdbResult<()> {
                Ok(())
            }
            fn stats(&self) -> IndexStats {
                IndexStats {
                    index_type: "other",
                    ..IndexStats::default()
                }
            }
        }
        impl Index for OtherMock {
            type Config = ();
            fn new(_: usize, _: DistanceMetric, _: ()) -> IqdbResult<Self> {
                Ok(Self)
            }
        }
        impl Persistable for OtherMock {
            const INDEX_TYPE: &'static str = "other";
            fn save_to(&self, _w: &mut dyn Write) -> Result<()> {
                Ok(())
            }
            fn load_from(_r: &mut dyn Read) -> Result<Self> {
                Ok(Self)
            }
        }
        let err = PersistedIndex::<OtherMock>::load(cfg).unwrap_err();
        assert!(
            matches!(
                err,
                PersistError::InvalidIndexType {
                    expected: "other",
                    ..
                }
            ),
            "expected InvalidIndexType, got {err:?}",
        );
    }

    /// Save followed by load reproduces dim, metric and vector count.
    #[test]
    fn roundtrip_through_storage_recovers_state() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot = dir.path().join("idx.iqdb");
        let cfg = PersistConfig::new(&snapshot);
        let inner = MockIndex {
            dim: 32,
            metric: DistanceMetric::Manhattan,
            n: 42,
        };
        let wrap = PersistedIndex::open_with(inner, cfg.clone()).unwrap();
        wrap.save().unwrap();
        let restored: PersistedIndex<MockIndex> = PersistedIndex::load(cfg).unwrap();
        assert_eq!(restored.index().dim(), 32);
        assert_eq!(restored.index().metric(), DistanceMetric::Manhattan);
        assert_eq!(restored.index().len(), 42);
    }
}