use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use crate::dsl::VectorIndexType;
use crate::error::{Error, Result};
use crate::schema::Schema;
/// File name of the index metadata document.
///
/// [`IndexMetadata::load`] and [`IndexMetadata::save`] pass this name, as a
/// bare path, to the directory they are given.
pub const INDEX_META_FILENAME: &str = "metadata.json";
/// Build state of one vector field's index.
///
/// No serde attributes are applied to the enum itself, so it is serialized
/// with serde's default enum representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum VectorIndexState {
/// Default state, assigned by [`IndexMetadata::init_field`].
/// [`IndexMetadata::load_trained_structures`] skips fields in this state.
#[default]
Flat,
/// State set by [`IndexMetadata::mark_field_built`].
Built {
/// Vector count that was passed to `mark_field_built`.
vector_count: usize,
/// Cluster count that was passed to `mark_field_built`.
num_clusters: usize,
},
}
/// Per-field vector index metadata, stored as the values of
/// [`IndexMetadata::vector_fields`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldVectorMeta {
/// Field identifier. [`IndexMetadata::init_field`] sets it to the same value
/// it uses as the map key.
pub field_id: u32,
/// Index type supplied when the field was registered through `init_field`.
pub index_type: VectorIndexType,
/// Current build state of this field's index.
pub state: VectorIndexState,
/// Name of the centroids file, set by [`IndexMetadata::mark_field_built`].
/// [`IndexMetadata::load_trained_structures`] opens it and decodes it as
/// JSON. Left out of the serialized form while it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub centroids_file: Option<String>,
/// Optional name of the codebook file, set by
/// [`IndexMetadata::mark_field_built`] and decoded as JSON by
/// [`IndexMetadata::load_trained_structures`]. Left out of the serialized
/// form while it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub codebook_file: Option<String>,
}
/// Index-wide metadata, read and written as JSON under
/// [`INDEX_META_FILENAME`] by [`IndexMetadata::load`] and
/// [`IndexMetadata::save`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexMetadata {
/// Metadata version number; [`IndexMetadata::new`] sets it to 1.
pub version: u32,
/// Schema supplied to [`IndexMetadata::new`].
pub schema: Schema,
/// Segment identifiers. [`IndexMetadata::add_segment`] skips ids that are
/// already present; [`IndexMetadata::remove_segments`] drops entries.
pub segments: Vec<String>,
/// Vector field metadata keyed by field id. Falls back to an empty map when
/// the key is absent from the deserialized input.
#[serde(default)]
pub vector_fields: HashMap<u32, FieldVectorMeta>,
/// Vector count that [`IndexMetadata::should_build_field`] compares against
/// its threshold. Falls back to 0 when absent from the deserialized input.
#[serde(default)]
pub total_vectors: usize,
}
impl IndexMetadata {
pub fn new(schema: Schema) -> Self {
Self {
version: 1,
schema,
segments: Vec::new(),
vector_fields: HashMap::new(),
total_vectors: 0,
}
}
pub fn is_field_built(&self, field_id: u32) -> bool {
self.vector_fields
.get(&field_id)
.map(|f| matches!(f.state, VectorIndexState::Built { .. }))
.unwrap_or(false)
}
pub fn get_field_meta(&self, field_id: u32) -> Option<&FieldVectorMeta> {
self.vector_fields.get(&field_id)
}
pub fn init_field(&mut self, field_id: u32, index_type: VectorIndexType) {
self.vector_fields
.entry(field_id)
.or_insert(FieldVectorMeta {
field_id,
index_type,
state: VectorIndexState::Flat,
centroids_file: None,
codebook_file: None,
});
}
pub fn mark_field_built(
&mut self,
field_id: u32,
vector_count: usize,
num_clusters: usize,
centroids_file: String,
codebook_file: Option<String>,
) {
if let Some(field) = self.vector_fields.get_mut(&field_id) {
field.state = VectorIndexState::Built {
vector_count,
num_clusters,
};
field.centroids_file = Some(centroids_file);
field.codebook_file = codebook_file;
}
}
pub fn should_build_field(&self, field_id: u32, threshold: usize) -> bool {
if self.is_field_built(field_id) {
return false;
}
self.total_vectors >= threshold
}
pub fn add_segment(&mut self, segment_id: String) {
if !self.segments.contains(&segment_id) {
self.segments.push(segment_id);
}
}
pub fn remove_segments(&mut self, to_remove: &[String]) {
self.segments.retain(|s| !to_remove.contains(s));
}
pub async fn load<D: crate::directories::Directory>(dir: &D) -> Result<Self> {
let path = Path::new(INDEX_META_FILENAME);
let slice = dir.open_read(path).await?;
let bytes = slice.read_bytes().await?;
serde_json::from_slice(bytes.as_slice()).map_err(|e| Error::Serialization(e.to_string()))
}
pub async fn save<D: crate::directories::DirectoryWriter>(&self, dir: &D) -> Result<()> {
let path = Path::new(INDEX_META_FILENAME);
let bytes =
serde_json::to_vec_pretty(self).map_err(|e| Error::Serialization(e.to_string()))?;
dir.write(path, &bytes).await.map_err(Error::Io)
}
pub async fn load_trained_structures<D: crate::directories::Directory>(
&self,
dir: &D,
) -> (
rustc_hash::FxHashMap<u32, std::sync::Arc<crate::structures::CoarseCentroids>>,
rustc_hash::FxHashMap<u32, std::sync::Arc<crate::structures::PQCodebook>>,
) {
use std::sync::Arc;
let mut centroids = rustc_hash::FxHashMap::default();
let mut codebooks = rustc_hash::FxHashMap::default();
for (field_id, field_meta) in &self.vector_fields {
if !matches!(field_meta.state, VectorIndexState::Built { .. }) {
continue;
}
if let Some(ref file) = field_meta.centroids_file
&& let Ok(slice) = dir.open_read(Path::new(file)).await
&& let Ok(bytes) = slice.read_bytes().await
&& let Ok(c) =
serde_json::from_slice::<crate::structures::CoarseCentroids>(bytes.as_slice())
{
centroids.insert(*field_id, Arc::new(c));
}
if let Some(ref file) = field_meta.codebook_file
&& let Ok(slice) = dir.open_read(Path::new(file)).await
&& let Ok(bytes) = slice.read_bytes().await
&& let Ok(c) =
serde_json::from_slice::<crate::structures::PQCodebook>(bytes.as_slice())
{
codebooks.insert(*field_id, Arc::new(c));
}
}
(centroids, codebooks)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema shared by all tests; the methods under test only store it.
    fn test_schema() -> Schema {
        Schema::default()
    }

    #[test]
    fn test_metadata_init() {
        let mut meta = IndexMetadata::new(test_schema());
        assert_eq!(meta.total_vectors, 0);
        assert!(meta.segments.is_empty());
        assert!(!meta.is_field_built(0));

        // Registering a field must not mark it as built.
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        assert!(!meta.is_field_built(0));
        assert!(meta.vector_fields.contains_key(&0));
    }

    #[test]
    fn test_init_field_keeps_existing_entry() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        meta.mark_field_built(0, 10, 2, "centroids.bin".to_string(), None);

        // A second registration must not reset the built state.
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        assert!(meta.is_field_built(0));
        assert_eq!(meta.vector_fields.len(), 1);
    }

    #[test]
    fn test_metadata_segments() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.add_segment("abc123".to_string());
        meta.add_segment("def456".to_string());
        assert_eq!(meta.segments.len(), 2);

        // Duplicates are ignored.
        meta.add_segment("abc123".to_string());
        assert_eq!(meta.segments.len(), 2);

        meta.remove_segments(&["abc123".to_string()]);
        assert_eq!(meta.segments.len(), 1);
        assert_eq!(meta.segments[0], "def456");
    }

    #[test]
    fn test_remove_segments_ignores_unknown_ids() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.add_segment("a".to_string());
        meta.add_segment("b".to_string());

        meta.remove_segments(&["missing".to_string()]);
        assert_eq!(meta.segments, vec!["a".to_string(), "b".to_string()]);

        meta.remove_segments(&[]);
        assert_eq!(meta.segments, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_mark_field_built() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        meta.total_vectors = 10000;
        assert!(!meta.is_field_built(0));

        meta.mark_field_built(0, 10000, 256, "field_0_centroids.bin".to_string(), None);
        assert!(meta.is_field_built(0));

        let field = meta.get_field_meta(0).unwrap();
        assert_eq!(
            field.centroids_file.as_deref(),
            Some("field_0_centroids.bin")
        );
        assert_eq!(
            field.state,
            VectorIndexState::Built {
                vector_count: 10000,
                num_clusters: 256,
            }
        );
        assert!(field.codebook_file.is_none());
    }

    #[test]
    fn test_mark_field_built_unknown_field_is_noop() {
        let mut meta = IndexMetadata::new(test_schema());

        // Field 7 was never registered, so nothing may be created for it.
        meta.mark_field_built(7, 100, 4, "centroids.bin".to_string(), None);
        assert!(!meta.is_field_built(7));
        assert!(meta.get_field_meta(7).is_none());
        assert!(meta.vector_fields.is_empty());
    }

    #[test]
    fn test_should_build_field() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.init_field(0, VectorIndexType::IvfRaBitQ);

        meta.total_vectors = 500;
        assert!(!meta.should_build_field(0, 1000));

        meta.total_vectors = 1500;
        assert!(meta.should_build_field(0, 1000));

        // Once built, the field never asks to be built again.
        meta.mark_field_built(0, 1500, 256, "centroids.bin".to_string(), None);
        assert!(!meta.should_build_field(0, 1000));
    }

    #[test]
    fn test_should_build_field_threshold_is_inclusive() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.init_field(0, VectorIndexType::IvfRaBitQ);

        meta.total_vectors = 999;
        assert!(!meta.should_build_field(0, 1000));

        meta.total_vectors = 1000;
        assert!(meta.should_build_field(0, 1000));
    }

    #[test]
    fn test_serialization() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.add_segment("seg1".to_string());
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        meta.total_vectors = 5000;

        let json = serde_json::to_string_pretty(&meta).unwrap();
        let loaded: IndexMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.segments, meta.segments);
        assert_eq!(loaded.total_vectors, meta.total_vectors);
        assert!(loaded.vector_fields.contains_key(&0));
    }

    #[test]
    fn test_serialization_built_field_round_trip() {
        let mut meta = IndexMetadata::new(test_schema());
        meta.init_field(3, VectorIndexType::IvfRaBitQ);
        meta.mark_field_built(
            3,
            42,
            8,
            "centroids.json".to_string(),
            Some("codebook.json".to_string()),
        );

        let json = serde_json::to_string(&meta).unwrap();
        let loaded: IndexMetadata = serde_json::from_str(&json).unwrap();

        assert!(loaded.is_field_built(3));
        let field = loaded.get_field_meta(3).unwrap();
        assert_eq!(field.field_id, 3);
        assert_eq!(
            field.state,
            VectorIndexState::Built {
                vector_count: 42,
                num_clusters: 8,
            }
        );
        assert_eq!(field.centroids_file.as_deref(), Some("centroids.json"));
        assert_eq!(field.codebook_file.as_deref(), Some("codebook.json"));
    }
}