1use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::Path;
16
17use crate::dsl::VectorIndexType;
18use crate::error::{Error, Result};
19
/// Name of the file, relative to the index directory, that [`IndexMetadata::save`]
/// writes and [`IndexMetadata::load`] reads first.
pub const INDEX_META_FILENAME: &str = "metadata.json";
22
/// Build state of the vector index of a single field.
///
/// The default state is [`VectorIndexState::Flat`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum VectorIndexState {
    /// No trained structures have been recorded for the field yet; this is the
    /// state assigned by `IndexMetadata::init_field`.
    #[default]
    Flat,
    /// The field has been marked as built via `IndexMetadata::mark_field_built`.
    Built {
        /// Vector count supplied when the field was marked as built.
        // NOTE(review): presumably the number of vectors used for training — confirm with callers.
        vector_count: usize,
        /// Cluster count supplied when the field was marked as built.
        num_clusters: usize,
    },
}
37
/// Per-field vector index metadata stored inside [`IndexMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldVectorMeta {
    /// Identifier of the field this entry describes.
    pub field_id: u32,
    /// Kind of vector index configured for the field.
    pub index_type: VectorIndexType,
    /// Current build state of the field's index.
    pub state: VectorIndexState,
    /// File holding the field's coarse centroids, if one has been recorded.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub centroids_file: Option<String>,
    /// File holding the field's PQ codebook, if one has been recorded.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub codebook_file: Option<String>,
}
54
/// Index-level metadata persisted as JSON in [`INDEX_META_FILENAME`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexMetadata {
    /// Format version; `IndexMetadata::new` writes `1`.
    pub version: u32,
    /// Identifiers of the segments that make up the index.
    pub segments: Vec<String>,
    /// Vector index metadata keyed by field id. Deserializes to an empty map
    /// when the key is absent from the input.
    #[serde(default)]
    pub vector_fields: HashMap<u32, FieldVectorMeta>,
    /// Vector count that `should_build_field` compares against its threshold.
    /// Deserializes to `0` when the key is absent from the input.
    #[serde(default)]
    pub total_vectors: usize,
}
69
70impl Default for IndexMetadata {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl IndexMetadata {
77 pub fn new() -> Self {
79 Self {
80 version: 1,
81 segments: Vec::new(),
82 vector_fields: HashMap::new(),
83 total_vectors: 0,
84 }
85 }
86
87 pub fn from_segments(segments: Vec<String>) -> Self {
89 Self {
90 version: 1,
91 segments,
92 vector_fields: HashMap::new(),
93 total_vectors: 0,
94 }
95 }
96
97 pub fn is_field_built(&self, field_id: u32) -> bool {
99 self.vector_fields
100 .get(&field_id)
101 .map(|f| matches!(f.state, VectorIndexState::Built { .. }))
102 .unwrap_or(false)
103 }
104
105 pub fn get_field_meta(&self, field_id: u32) -> Option<&FieldVectorMeta> {
107 self.vector_fields.get(&field_id)
108 }
109
110 pub fn init_field(&mut self, field_id: u32, index_type: VectorIndexType) {
112 self.vector_fields
113 .entry(field_id)
114 .or_insert(FieldVectorMeta {
115 field_id,
116 index_type,
117 state: VectorIndexState::Flat,
118 centroids_file: None,
119 codebook_file: None,
120 });
121 }
122
123 pub fn mark_field_built(
125 &mut self,
126 field_id: u32,
127 vector_count: usize,
128 num_clusters: usize,
129 centroids_file: String,
130 codebook_file: Option<String>,
131 ) {
132 if let Some(field) = self.vector_fields.get_mut(&field_id) {
133 field.state = VectorIndexState::Built {
134 vector_count,
135 num_clusters,
136 };
137 field.centroids_file = Some(centroids_file);
138 field.codebook_file = codebook_file;
139 }
140 }
141
142 pub fn should_build_field(&self, field_id: u32, threshold: usize) -> bool {
144 if self.is_field_built(field_id) {
146 return false;
147 }
148 self.total_vectors >= threshold
150 }
151
152 pub fn add_segment(&mut self, segment_id: String) {
154 if !self.segments.contains(&segment_id) {
155 self.segments.push(segment_id);
156 }
157 }
158
159 pub fn remove_segments(&mut self, to_remove: &[String]) {
161 self.segments.retain(|s| !to_remove.contains(s));
162 }
163
164 pub async fn load<D: crate::directories::Directory>(dir: &D) -> Result<Self> {
166 let path = Path::new(INDEX_META_FILENAME);
167 match dir.open_read(path).await {
168 Ok(slice) => {
169 let bytes = slice.read_bytes().await?;
170 serde_json::from_slice(bytes.as_slice())
171 .map_err(|e| Error::Serialization(e.to_string()))
172 }
173 Err(_) => {
174 let old_path = Path::new("segments.json");
176 if let Ok(slice) = dir.open_read(old_path).await
177 && let Ok(bytes) = slice.read_bytes().await
178 && let Ok(segments) = serde_json::from_slice::<Vec<String>>(bytes.as_slice())
179 {
180 Ok(Self::from_segments(segments))
181 } else {
182 Ok(Self::new())
183 }
184 }
185 }
186 }
187
188 pub async fn save<D: crate::directories::DirectoryWriter>(&self, dir: &D) -> Result<()> {
190 let path = Path::new(INDEX_META_FILENAME);
191 let bytes =
192 serde_json::to_vec_pretty(self).map_err(|e| Error::Serialization(e.to_string()))?;
193 dir.write(path, &bytes).await.map_err(Error::Io)
194 }
195
196 pub async fn load_trained_structures<D: crate::directories::Directory>(
200 &self,
201 dir: &D,
202 ) -> (
203 rustc_hash::FxHashMap<u32, std::sync::Arc<crate::structures::CoarseCentroids>>,
204 rustc_hash::FxHashMap<u32, std::sync::Arc<crate::structures::PQCodebook>>,
205 ) {
206 use std::sync::Arc;
207
208 let mut centroids = rustc_hash::FxHashMap::default();
209 let mut codebooks = rustc_hash::FxHashMap::default();
210
211 for (field_id, field_meta) in &self.vector_fields {
212 if !matches!(field_meta.state, VectorIndexState::Built { .. }) {
213 continue;
214 }
215
216 if let Some(ref file) = field_meta.centroids_file
218 && let Ok(slice) = dir.open_read(Path::new(file)).await
219 && let Ok(bytes) = slice.read_bytes().await
220 && let Ok(c) =
221 serde_json::from_slice::<crate::structures::CoarseCentroids>(bytes.as_slice())
222 {
223 centroids.insert(*field_id, Arc::new(c));
224 }
225
226 if let Some(ref file) = field_meta.codebook_file
228 && let Ok(slice) = dir.open_read(Path::new(file)).await
229 && let Ok(bytes) = slice.read_bytes().await
230 && let Ok(c) =
231 serde_json::from_slice::<crate::structures::PQCodebook>(bytes.as_slice())
232 {
233 codebooks.insert(*field_id, Arc::new(c));
234 }
235 }
236
237 (centroids, codebooks)
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds empty metadata with field 0 registered as an `IvfRaBitQ` index.
    fn meta_with_field_zero() -> IndexMetadata {
        let mut meta = IndexMetadata::new();
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        meta
    }

    #[test]
    fn test_metadata_init() {
        // Fresh metadata is empty and knows no fields.
        let mut meta = IndexMetadata::new();
        assert_eq!(meta.total_vectors, 0);
        assert!(meta.segments.is_empty());
        assert!(!meta.is_field_built(0));

        // Registering a field makes it known but not built.
        meta.init_field(0, VectorIndexType::IvfRaBitQ);
        assert!(!meta.is_field_built(0));
        assert!(meta.vector_fields.contains_key(&0));
    }

    #[test]
    fn test_metadata_segments() {
        let mut meta = IndexMetadata::new();
        for id in ["abc123", "def456"] {
            meta.add_segment(String::from(id));
        }
        assert_eq!(meta.segments.len(), 2);

        // Adding a segment that is already listed must not create a duplicate.
        meta.add_segment(String::from("abc123"));
        assert_eq!(meta.segments.len(), 2);

        let removed = [String::from("abc123")];
        meta.remove_segments(&removed);
        assert_eq!(meta.segments.len(), 1);
        assert_eq!(meta.segments[0], "def456");
    }

    #[test]
    fn test_mark_field_built() {
        let mut meta = meta_with_field_zero();
        meta.total_vectors = 10000;
        assert!(!meta.is_field_built(0));

        meta.mark_field_built(0, 10000, 256, String::from("field_0_centroids.bin"), None);
        assert!(meta.is_field_built(0));

        let centroids_file = meta
            .get_field_meta(0)
            .unwrap()
            .centroids_file
            .as_deref();
        assert_eq!(centroids_file, Some("field_0_centroids.bin"));
    }

    #[test]
    fn test_should_build_field() {
        let mut meta = meta_with_field_zero();

        // Below the threshold: no build.
        meta.total_vectors = 500;
        assert!(!meta.should_build_field(0, 1000));

        // At or above the threshold: build.
        meta.total_vectors = 1500;
        assert!(meta.should_build_field(0, 1000));

        // Once built, the field must never be rebuilt.
        meta.mark_field_built(0, 1500, 256, String::from("centroids.bin"), None);
        assert!(!meta.should_build_field(0, 1000));
    }

    #[test]
    fn test_serialization() {
        let mut meta = meta_with_field_zero();
        meta.add_segment(String::from("seg1"));
        meta.total_vectors = 5000;

        // Round-trip through pretty-printed JSON.
        let json = serde_json::to_string_pretty(&meta).unwrap();
        let loaded: IndexMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.segments, meta.segments);
        assert_eq!(loaded.total_vectors, meta.total_vectors);
        assert!(loaded.vector_fields.contains_key(&0));
    }
}