use serde::{Deserialize, Serialize};
/// Magic bytes stored in the `magic` field of every footer built by `SegmentFooter::new`.
///
/// Unlike the other magic constants this one is a `&[u8; 4]` reference; the type is kept
/// as-is because callers dereference it (e.g. `*SEGMENT_MAGIC`).
pub const SEGMENT_MAGIC: &[u8; 4] = &[b'V', b'C', b'N', b'T'];
/// Magic bytes presumably identifying checkpoint files (no user in this chunk — confirm).
pub const CHECKPOINT_MAGIC: [u8; 4] = [b'V', b'C', b'K', b'P'];
/// Magic bytes presumably identifying write-ahead-log files (no user in this chunk — confirm).
pub const WAL_MAGIC: [u8; 4] = [b'V', b'W', b'A', b'L'];
/// Current on-disk format version, stamped into footers by `SegmentFooter::new`.
pub const FORMAT_VERSION: u32 = 1;
/// Kind of index recorded in `SegmentHeader::index_type` and `IndexManifest::index_type`.
///
/// Serialized with serde: self-describing formats (such as JSON) store the variant *name*,
/// while index-based formats store the variant *position*. Do not rename or reorder
/// variants without a format migration.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize)]
pub enum IndexType {
    Hnsw,
    DiskAnn,
    IvfPq,
    ScaNN,
    Sng,
    LearnedSparse,
    Flat,
}
/// Compression scheme recorded in `SegmentHeader::compression`.
///
/// Serialized with serde: variant names (self-describing formats) and variant positions
/// (index-based formats) are both part of the persisted form, so keep them stable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize)]
pub enum CompressionType {
    None,
    ProductQuantization,
    ScalarQuantization,
    BinaryQuantization,
    RaBitQ,
}
/// Descriptive header of a single segment.
///
/// serde-derived encodings follow the field declaration order below; keep it stable.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct SegmentHeader {
    /// Identifier of this segment.
    pub segment_id: u64,
    /// Vector count recorded for this segment.
    pub vector_count: u64,
    /// Vector dimensionality recorded for this segment.
    pub dimension: u32,
    /// Kind of index the segment was built with.
    pub index_type: IndexType,
    /// Compression scheme applied to the segment.
    pub compression: CompressionType,
    /// Creation time — presumably a Unix timestamp; the unit is not fixed here (confirm).
    pub created_at: u64,
    /// Free-form string key/value metadata.
    pub metadata: std::collections::HashMap<String, String>,
}
/// Tag distinguishing the kinds of write-ahead-log entries.
///
/// `#[repr(u8)]` with explicit discriminants makes `kind as u8` a stable one-byte tag
/// (presumably the byte written to the WAL — confirm); `TryFrom<u8>` is the inverse.
/// Never change the discriminant values.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WalEntryType {
    Insert = 0x01,
    Delete = 0x02,
    Update = 0x03,
    Checkpoint = 0x04,
}
impl TryFrom<u8> for WalEntryType {
    type Error = ();

    /// Decodes a WAL tag byte back into its [`WalEntryType`].
    ///
    /// # Errors
    /// Returns `Err(())` for any byte that is not the discriminant of a declared variant.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        const KNOWN: [WalEntryType; 4] = [
            WalEntryType::Insert,
            WalEntryType::Delete,
            WalEntryType::Update,
            WalEntryType::Checkpoint,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|candidate| *candidate as u8 == value)
            .ok_or(())
    }
}
/// Index-level manifest describing the segments and state of an index.
///
/// Presumably the content of the `manifest.json` file probed by
/// `IndexPersistence::exists` (confirm with the writer). serde-derived encodings follow the
/// field declaration order below; keep it stable.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct IndexManifest {
    /// Format version the manifest was written with.
    pub version: u32,
    /// Kind of index described by this manifest.
    pub index_type: IndexType,
    /// Vector dimensionality of the index.
    pub dimension: u32,
    /// Total vector count recorded for the index.
    pub total_vectors: u64,
    /// Segment identifiers — presumably matching `SegmentHeader::segment_id` (confirm).
    pub segments: Vec<u64>,
    /// WAL sequence number — presumably the last one reflected in this manifest (confirm).
    pub wal_sequence: u64,
    /// Checkpoint identifier, if any.
    pub checkpoint_id: Option<u64>,
    /// Arbitrary JSON configuration for the index.
    pub config: serde_json::Value,
    /// Creation time — unit not fixed here (presumably Unix time; confirm).
    pub created_at: u64,
    /// Last-modification time — same unit as `created_at`.
    pub modified_at: u64,
}
/// Section offset/length pairs consumed by `SegmentFooter::new`.
///
/// Only the `term_dict_*` and `postings_*` pairs have slots in `SegmentFooter`; the
/// `term_info_*` and `doc_lengths_*` pairs are not carried into the footer.
#[derive(Default, Clone, Debug)]
pub struct SegmentOffsets {
    /// Term dictionary offset.
    pub term_dict_offset: u64,
    /// Term dictionary length.
    pub term_dict_len: u64,
    /// Term info offset (not stored in the footer).
    pub term_info_offset: u64,
    /// Term info length (not stored in the footer).
    pub term_info_len: u64,
    /// Postings offset.
    pub postings_offset: u64,
    /// Postings length.
    pub postings_len: u64,
    /// Document-lengths offset (not stored in the footer).
    pub doc_lengths_offset: u64,
    /// Document-lengths length (not stored in the footer).
    pub doc_lengths_len: u64,
}
/// Fixed-layout footer of a segment (presumably stored at the end of the segment file —
/// confirm with the writer).
///
/// Besides the serde derives, this type has a hand-written little-endian binary codec
/// (`SegmentFooter::read` / `SegmentFooter::write`) that mirrors the field declaration
/// order below. Do not reorder or resize fields without a format migration.
#[derive(Clone, Debug, Default)]
#[derive(Serialize, Deserialize)]
pub struct SegmentFooter {
    /// Magic bytes; `SegmentFooter::new` sets them to `SEGMENT_MAGIC`.
    pub magic: [u8; 4],
    /// Format version; `SegmentFooter::new` sets it to `FORMAT_VERSION`.
    pub format_version: u32,
    /// Header offset (left at 0 by `SegmentFooter::new`).
    pub header_offset: u64,
    /// Vectors offset (left at 0 by `SegmentFooter::new`).
    pub vectors_offset: u64,
    /// Graph offset (left at 0 by `SegmentFooter::new`).
    pub graph_offset: u64,
    /// Ids offset (left at 0 by `SegmentFooter::new`).
    pub ids_offset: u64,
    /// Term dictionary offset, copied from `SegmentOffsets`.
    pub term_dict_offset: u64,
    /// Term dictionary length, copied from `SegmentOffsets`.
    pub term_dict_len: u64,
    /// Postings offset, copied from `SegmentOffsets`.
    pub postings_offset: u64,
    /// Postings length, copied from `SegmentOffsets`.
    pub postings_len: u64,
    /// Document count passed to `SegmentFooter::new`.
    pub doc_count: u32,
    /// Maximum document id passed to `SegmentFooter::new`.
    pub max_doc_id: u32,
    /// Checksum slot; `SegmentFooter::new` leaves it at 0 and nothing here computes it.
    pub checksum: u32,
}
impl SegmentFooter {
    /// Size in bytes of the binary footer handled by [`SegmentFooter::read`] and
    /// [`SegmentFooter::write`]: 4 (magic) + 4 (format_version) + 8 × 8 (u64 offsets and
    /// lengths) + 3 × 4 (doc_count, max_doc_id, checksum) = 84.
    const SERIALIZED_SIZE: usize = 4 + 4 + 8 * 8 + 3 * 4;

    /// Builds a footer for a segment with `doc_count` documents and `max_doc_id`.
    ///
    /// `magic` and `format_version` are set to [`SEGMENT_MAGIC`] and [`FORMAT_VERSION`].
    /// Only the term-dictionary and postings pairs are copied from `offsets`.
    /// `header_offset`, `vectors_offset`, `graph_offset`, `ids_offset` and `checksum`
    /// are left at 0.
    ///
    /// NOTE(review): `offsets.term_info_*` and `offsets.doc_lengths_*` have no slot in the
    /// footer layout and are silently dropped. Persisting them requires a format change.
    pub fn new(doc_count: u32, max_doc_id: u32, offsets: SegmentOffsets) -> Self {
        Self {
            magic: *SEGMENT_MAGIC,
            format_version: FORMAT_VERSION,
            header_offset: 0,
            vectors_offset: 0,
            graph_offset: 0,
            ids_offset: 0,
            term_dict_offset: offsets.term_dict_offset,
            term_dict_len: offsets.term_dict_len,
            postings_offset: offsets.postings_offset,
            postings_len: offsets.postings_len,
            doc_count,
            max_doc_id,
            checksum: 0,
        }
    }

    /// Reads exactly [`Self::SERIALIZED_SIZE`] bytes from `reader` and decodes them as a
    /// little-endian footer.
    ///
    /// No validation is performed. `magic`, `format_version` and `checksum` are returned
    /// as found, so callers must check them (e.g. `footer.magic == *SEGMENT_MAGIC`).
    ///
    /// # Errors
    /// Propagates the I/O error from `read_exact` (converted via `?`), e.g. when fewer
    /// than `SERIALIZED_SIZE` bytes are available.
    pub fn read<R: std::io::Read>(reader: &mut R) -> super::error::PersistenceResult<Self> {
        // Fixed-size footer: a stack buffer avoids a heap allocation per read.
        let mut buf = [0u8; Self::SERIALIZED_SIZE];
        reader.read_exact(&mut buf)?;
        Ok(Self::decode_le(&buf))
    }

    /// Writes the footer as [`Self::SERIALIZED_SIZE`] little-endian bytes, in the field
    /// order that [`SegmentFooter::read`] expects.
    ///
    /// The footer is encoded into one buffer and handed to `writer` with a single
    /// `write_all`, rather than one call per field. On an unbuffered `File` the per-field
    /// approach costs one syscall per field.
    ///
    /// # Errors
    /// Propagates any I/O error from `writer` (converted via `?`).
    pub fn write<W: std::io::Write>(&self, writer: &mut W) -> super::error::PersistenceResult<()> {
        writer.write_all(&self.encode_le())?;
        Ok(())
    }

    /// Encodes every field, in declaration order, as little-endian bytes.
    fn encode_le(&self) -> [u8; Self::SERIALIZED_SIZE] {
        let mut buf = [0u8; Self::SERIALIZED_SIZE];
        let mut pos = 0;
        let mut put = |bytes: &[u8]| {
            buf[pos..pos + bytes.len()].copy_from_slice(bytes);
            pos += bytes.len();
        };
        put(&self.magic);
        put(&self.format_version.to_le_bytes());
        put(&self.header_offset.to_le_bytes());
        put(&self.vectors_offset.to_le_bytes());
        put(&self.graph_offset.to_le_bytes());
        put(&self.ids_offset.to_le_bytes());
        put(&self.term_dict_offset.to_le_bytes());
        put(&self.term_dict_len.to_le_bytes());
        put(&self.postings_offset.to_le_bytes());
        put(&self.postings_len.to_le_bytes());
        put(&self.doc_count.to_le_bytes());
        put(&self.max_doc_id.to_le_bytes());
        put(&self.checksum.to_le_bytes());
        // Catches SERIALIZED_SIZE drifting out of sync with the encoded field list.
        debug_assert_eq!(pos, Self::SERIALIZED_SIZE, "footer layout/size mismatch");
        buf
    }

    /// Decodes a footer from its little-endian byte form (inverse of `encode_le`).
    fn decode_le(buf: &[u8; Self::SERIALIZED_SIZE]) -> Self {
        let mut pos = 0;
        let magic: [u8; 4] = Self::take_bytes(buf, &mut pos);
        let format_version = u32::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let header_offset = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let vectors_offset = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let graph_offset = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let ids_offset = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let term_dict_offset = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let term_dict_len = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let postings_offset = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let postings_len = u64::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let doc_count = u32::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let max_doc_id = u32::from_le_bytes(Self::take_bytes(buf, &mut pos));
        let checksum = u32::from_le_bytes(Self::take_bytes(buf, &mut pos));
        debug_assert_eq!(pos, Self::SERIALIZED_SIZE, "footer layout/size mismatch");
        Self {
            magic,
            format_version,
            header_offset,
            vectors_offset,
            graph_offset,
            ids_offset,
            term_dict_offset,
            term_dict_len,
            postings_offset,
            postings_len,
            doc_count,
            max_doc_id,
            checksum,
        }
    }

    /// Copies the `N` bytes starting at `buf[*pos]` into an array and advances `*pos`.
    ///
    /// # Panics
    /// Panics if fewer than `N` bytes remain. `decode_le` only passes
    /// `SERIALIZED_SIZE`-byte buffers whose layout exactly covers every field.
    fn take_bytes<const N: usize>(buf: &[u8], pos: &mut usize) -> [u8; N] {
        let mut out = [0u8; N];
        out.copy_from_slice(&buf[*pos..*pos + N]);
        *pos += N;
        out
    }
}
/// Values that can be converted to and from an owned byte buffer.
pub trait Persistable
where
    Self: Sized,
{
    /// Serializes `self` into a newly allocated byte vector.
    ///
    /// # Errors
    /// Implementation-defined; reported through `crate::Result`.
    fn to_bytes(&self) -> crate::Result<Vec<u8>>;

    /// Reconstructs a value from `bytes` — presumably bytes produced by `to_bytes`.
    ///
    /// # Errors
    /// Implementation-defined (e.g. malformed input); reported through `crate::Result`.
    fn from_bytes(bytes: &[u8]) -> crate::Result<Self>;

    /// Size hint — presumably the expected length of the `to_bytes` output, useful for
    /// pre-sizing buffers. Nothing here enforces it (confirm with implementors).
    fn size_hint(&self) -> usize;
}
/// Save/load contract for an index stored under a directory path.
pub trait IndexPersistence
where
    Self: Sized,
{
    /// Persists the index under `path`.
    fn save(&self, path: &std::path::Path) -> crate::Result<()>;

    /// Loads an index from `path`.
    fn load(path: &std::path::Path) -> crate::Result<Self>;

    /// Returns `true` when metadata for `path/manifest.json` can be read (following symlinks).
    ///
    /// Any metadata error yields `false`, including a missing file, permission denied or a
    /// broken symlink. Only the manifest's presence is checked, not whether it parses.
    fn exists(path: &std::path::Path) -> bool {
        let manifest = path.join("manifest.json");
        std::fs::metadata(manifest).is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wal_entry_type_roundtrip() {
        assert_eq!(WalEntryType::try_from(1), Ok(WalEntryType::Insert));
        assert_eq!(WalEntryType::try_from(2), Ok(WalEntryType::Delete));
        assert_eq!(WalEntryType::try_from(3), Ok(WalEntryType::Update));
        assert_eq!(WalEntryType::try_from(4), Ok(WalEntryType::Checkpoint));
        assert_eq!(WalEntryType::try_from(99), Err(()));
    }

    /// Every variant's `repr(u8)` discriminant must decode back to that same variant.
    #[test]
    fn test_wal_entry_type_discriminants_decode_to_themselves() {
        for kind in [
            WalEntryType::Insert,
            WalEntryType::Delete,
            WalEntryType::Update,
            WalEntryType::Checkpoint,
        ] {
            assert_eq!(WalEntryType::try_from(kind as u8), Ok(kind));
        }
        assert_eq!(WalEntryType::try_from(0), Err(()));
    }

    #[test]
    fn test_segment_header_serde() {
        let header = SegmentHeader {
            segment_id: 1,
            vector_count: 1000,
            dimension: 128,
            index_type: IndexType::Hnsw,
            compression: CompressionType::None,
            created_at: 1234567890,
            metadata: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&header).unwrap();
        let parsed: SegmentHeader = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.segment_id, 1);
        assert_eq!(parsed.index_type, IndexType::Hnsw);
    }

    #[test]
    fn test_manifest_serde() {
        let manifest = IndexManifest {
            version: FORMAT_VERSION,
            index_type: IndexType::DiskAnn,
            dimension: 384,
            total_vectors: 10000,
            segments: vec![1, 2, 3],
            wal_sequence: 42,
            checkpoint_id: Some(5),
            config: serde_json::json!({"M": 16, "ef_construction": 200}),
            created_at: 1234567890,
            modified_at: 1234567899,
        };
        let json = serde_json::to_string_pretty(&manifest).unwrap();
        let parsed: IndexManifest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.version, FORMAT_VERSION);
        assert_eq!(parsed.segments.len(), 3);
    }

    /// The hand-written binary footer codec must round-trip every field and emit exactly
    /// `SERIALIZED_SIZE` bytes, starting with the segment magic.
    #[test]
    fn test_segment_footer_write_read_roundtrip() {
        let offsets = SegmentOffsets {
            term_dict_offset: 100,
            term_dict_len: 20,
            postings_offset: 120,
            postings_len: 300,
            ..SegmentOffsets::default()
        };
        let mut footer = SegmentFooter::new(42, 41, offsets);
        footer.graph_offset = 7;
        footer.checksum = 0xDEAD_BEEF;

        let mut bytes = Vec::new();
        assert!(footer.write(&mut bytes).is_ok());
        assert_eq!(bytes.len(), SegmentFooter::SERIALIZED_SIZE);
        assert_eq!(&bytes[..4], &SEGMENT_MAGIC[..]);

        // `match` rather than `unwrap` so the test does not require the error type to be `Debug`.
        let parsed = match SegmentFooter::read(&mut bytes.as_slice()) {
            Ok(parsed) => parsed,
            Err(_) => panic!("a freshly written footer must decode"),
        };
        assert_eq!(parsed.magic, *SEGMENT_MAGIC);
        assert_eq!(parsed.format_version, FORMAT_VERSION);
        assert_eq!(parsed.header_offset, 0);
        assert_eq!(parsed.vectors_offset, 0);
        assert_eq!(parsed.graph_offset, 7);
        assert_eq!(parsed.ids_offset, 0);
        assert_eq!(parsed.term_dict_offset, 100);
        assert_eq!(parsed.term_dict_len, 20);
        assert_eq!(parsed.postings_offset, 120);
        assert_eq!(parsed.postings_len, 300);
        assert_eq!(parsed.doc_count, 42);
        assert_eq!(parsed.max_doc_id, 41);
        assert_eq!(parsed.checksum, 0xDEAD_BEEF);
    }

    /// A footer one byte short must be reported as an error, not decoded.
    #[test]
    fn test_segment_footer_read_rejects_truncated_input() {
        let mut bytes = Vec::new();
        assert!(SegmentFooter::new(1, 0, SegmentOffsets::default())
            .write(&mut bytes)
            .is_ok());
        let mut truncated: &[u8] = &bytes[..bytes.len() - 1];
        assert!(SegmentFooter::read(&mut truncated).is_err());
    }
}