1use std::{
4 path::{Path, PathBuf},
5 time::{SystemTime, UNIX_EPOCH},
6};
7
8use serde::{Deserialize, Serialize};
9use tracing::instrument;
10
11use crate::{
12 config::VectorConfig,
13 error::{VectorError, VectorResult},
14 index::{flat::FlatIndex, hnsw::HnswIndex},
15 types::DistanceMetric,
16};
17
/// Maximum number of vectors kept in a flat index.
///
/// Once a flat index holds strictly more than this many vectors, inserts migrate it to HNSW.
pub const HNSW_THRESHOLD: usize = 1000;
20
/// JSON payload written to `<collection>.idx` by `IndexSelector::save`.
///
/// The field names are the JSON keys of the on-disk format. Renaming a field breaks loading of
/// previously saved indexes.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedIndex {
    /// Backend that produced the snapshot: `"flat"` or `"hnsw"`.
    index_type: String,
    /// Every stored vector as an `(id, vector)` pair.
    points: Vec<(usize, Vec<f32>)>,
}
26
/// Sidecar metadata written as pretty-printed JSON to `<collection>.idx.manifest`.
///
/// `IndexSelector::load` refuses a payload whose BLAKE3 hash differs from `blake3`.
#[derive(Debug, Serialize, Deserialize)]
struct IndexManifest {
    /// Hex-encoded BLAKE3 hash of the serialized payload bytes.
    blake3: String,
    /// Number of `(id, vector)` pairs in the payload.
    vector_count: u64,
    /// Vector dimensionality at save time; 0 for an empty HNSW index.
    dimensions: u32,
    /// Save time in milliseconds since the Unix epoch (0 if the system clock is before it).
    saved_at_ms: u64,
}
34
/// Vector index that starts out flat and is promoted to HNSW once it holds more than
/// [`HNSW_THRESHOLD`] vectors.
pub enum IndexSelector {
    /// Flat index over all stored vectors.
    Flat(FlatIndex),
    /// HNSW graph index (boxed, presumably to keep the enum small — TODO confirm).
    Hnsw(Box<HnswIndex>),
}
42
43impl IndexSelector {
44 pub fn new(dimensions: usize, distance: DistanceMetric, _config: &VectorConfig) -> Self {
46 IndexSelector::Flat(FlatIndex::new(dimensions, distance))
47 }
48
49 #[instrument(skip(self, vector, config))]
51 pub fn insert(
52 &mut self,
53 id: usize,
54 vector: Vec<f32>,
55 config: &VectorConfig,
56 ) -> VectorResult<()> {
57 match self {
58 IndexSelector::Flat(flat) => {
59 flat.insert(id, vector)?;
60 if flat.len() > HNSW_THRESHOLD {
61 self.migrate_to_hnsw(config)?;
62 }
63 }
64 IndexSelector::Hnsw(hnsw) => hnsw.insert(id, &vector)?,
65 }
66 Ok(())
67 }
68
69 #[instrument(skip(self, items, config))]
71 pub fn insert_batch(
72 &mut self,
73 items: Vec<(usize, Vec<f32>)>,
74 config: &VectorConfig,
75 ) -> VectorResult<()> {
76 match self {
77 IndexSelector::Flat(flat) => {
78 flat.insert_batch(items)?;
79 if flat.len() > HNSW_THRESHOLD {
80 self.migrate_to_hnsw(config)?;
81 }
82 }
83 IndexSelector::Hnsw(hnsw) => hnsw.insert_batch(&items)?,
84 }
85 Ok(())
86 }
87
88 #[instrument(skip(self, query))]
90 pub fn search(
91 &self,
92 query: &[f32],
93 top_k: usize,
94 ef_search: usize,
95 ) -> VectorResult<Vec<(usize, f32)>> {
96 match self {
97 IndexSelector::Flat(flat) => flat.search(query, top_k),
98 IndexSelector::Hnsw(hnsw) => hnsw.search(query, top_k, ef_search),
99 }
100 }
101
102 #[instrument(skip(self))]
104 pub fn delete(&mut self, id: usize) -> VectorResult<bool> {
105 match self {
106 IndexSelector::Flat(flat) => flat.delete(id),
107 IndexSelector::Hnsw(hnsw) => {
108 hnsw.delete(id)?;
109 Ok(true)
110 }
111 }
112 }
113
114 pub fn len(&self) -> usize {
116 match self {
117 IndexSelector::Flat(f) => f.len(),
118 IndexSelector::Hnsw(h) => h.len(),
119 }
120 }
121
122 pub fn is_empty(&self) -> bool {
124 self.len() == 0
125 }
126
127 pub fn is_hnsw(&self) -> bool {
129 matches!(self, IndexSelector::Hnsw(_))
130 }
131
132 #[instrument(skip(self, config))]
134 pub fn migrate_to_hnsw(&mut self, config: &VectorConfig) -> VectorResult<()> {
135 let hnsw = match self {
136 IndexSelector::Flat(flat) => {
137 tracing::info!(elements = flat.len(), "migrating flat index to HNSW");
138 flat.to_hnsw(config)?
139 }
140 IndexSelector::Hnsw(_) => return Ok(()),
141 };
142 *self = IndexSelector::Hnsw(Box::new(hnsw));
143 Ok(())
144 }
145
146 #[instrument(skip(self))]
148 pub fn save(&self, dir: &Path, workspace_id: &str, collection: &str) -> VectorResult<()> {
149 let col_dir = dir.join(workspace_id).join(collection);
150 std::fs::create_dir_all(&col_dir)?;
151
152 let persisted = match self {
153 IndexSelector::Flat(flat) => PersistedIndex {
154 index_type: "flat".to_string(),
155 points: flat.all_vectors()?,
156 },
157 IndexSelector::Hnsw(hnsw) => PersistedIndex {
158 index_type: "hnsw".to_string(),
159 points: hnsw.snapshot_points()?,
160 },
161 };
162
163 let payload = serde_json::to_vec(&persisted)?;
164 let final_path = idx_file(&col_dir, collection);
165 let tmp_path = idx_tmp_file(&col_dir, collection);
166 std::fs::write(&tmp_path, &payload)?;
167 std::fs::rename(&tmp_path, &final_path)?;
168
169 let saved_at_ms = SystemTime::now()
170 .duration_since(UNIX_EPOCH)
171 .map(|duration| duration.as_millis() as u64)
172 .unwrap_or(0);
173 let dimensions = match self {
174 IndexSelector::Flat(flat) => flat.dimensions,
175 IndexSelector::Hnsw(_) => {
176 if persisted.points.is_empty() {
177 0
178 } else {
179 persisted.points[0].1.len()
180 }
181 }
182 };
183 let manifest = IndexManifest {
184 blake3: blake3::hash(&payload).to_hex().to_string(),
185 vector_count: persisted.points.len() as u64,
186 dimensions: dimensions as u32,
187 saved_at_ms,
188 };
189 std::fs::write(
190 idx_manifest_file(&col_dir, collection),
191 serde_json::to_vec_pretty(&manifest)?,
192 )?;
193 Ok(())
194 }
195
196 #[instrument(skip(config))]
198 pub fn load(
199 dir: &Path,
200 workspace_id: &str,
201 collection: &str,
202 config: &VectorConfig,
203 distance: DistanceMetric,
204 dimensions: usize,
205 ) -> VectorResult<Self> {
206 let col_dir = dir.join(workspace_id).join(collection);
207 let final_path = idx_file(&col_dir, collection);
208 let tmp_path = idx_tmp_file(&col_dir, collection);
209 let manifest_path = idx_manifest_file(&col_dir, collection);
210
211 if tmp_path.exists() && final_path.exists() {
212 let _ = std::fs::remove_file(&tmp_path);
213 }
214
215 if !manifest_path.exists() {
216 return Err(VectorError::Index("missing index manifest".into()));
217 }
218
219 let manifest: IndexManifest = serde_json::from_slice(&std::fs::read(&manifest_path)?)?;
220 let payload = std::fs::read(&final_path)?;
221 let digest = blake3::hash(&payload);
222 let expected = hex::decode(manifest.blake3)
223 .map_err(|err| VectorError::Index(format!("invalid manifest checksum: {err}")))?;
224 if !constant_time_eq(digest.as_bytes(), &expected) {
225 return Err(VectorError::Index("index checksum mismatch".into()));
226 }
227
228 let persisted: PersistedIndex = serde_json::from_slice(&payload)?;
229 match persisted.index_type.as_str() {
230 "flat" => {
231 let flat = FlatIndex::new(dimensions, distance);
232 flat.insert_batch(persisted.points)?;
233 Ok(IndexSelector::Flat(flat))
234 }
235 "hnsw" => {
236 let hnsw = HnswIndex::new_with_dimensions(config, distance, dimensions)?;
237 hnsw.insert_batch(&persisted.points)?;
238 Ok(IndexSelector::Hnsw(Box::new(hnsw)))
239 }
240 other => Err(VectorError::Index(format!("unknown index_type '{other}'"))),
241 }
242 }
243}
244
/// Path of the committed index payload, `<path>/<collection>.idx`.
fn idx_file(path: &Path, collection: &str) -> PathBuf {
    let file_name = [collection, ".idx"].concat();
    path.join(file_name)
}
248
/// Path of the temporary index payload, `<path>/<collection>.idx.tmp`, that is renamed over
/// the committed file.
fn idx_tmp_file(path: &Path, collection: &str) -> PathBuf {
    const SUFFIX: &str = ".idx.tmp";
    let mut name = String::with_capacity(collection.len() + SUFFIX.len());
    name.push_str(collection);
    name.push_str(SUFFIX);
    path.join(name)
}
252
/// Path of the checksum manifest, `<path>/<collection>.idx.manifest`.
fn idx_manifest_file(path: &Path, collection: &str) -> PathBuf {
    let mut target = path.to_path_buf();
    target.push(format!("{collection}.idx.manifest"));
    target
}
256
/// Compares two byte slices without short-circuiting on the first differing byte.
///
/// Slices of different length compare unequal immediately (the length is not secret). For
/// equal lengths, every byte pair is XOR-ed into one accumulator before the result is
/// inspected.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let same_len = left.len() == right.len();
    same_len
        && left
            .iter()
            .zip(right)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y))
            == 0
}