Skip to main content

claw_vector/index/
selector.rs

1// index/selector.rs — auto-selecting index that migrates from FlatIndex to HnswIndex
2// when the collection surpasses HNSW_THRESHOLD (1 000 vectors).
3use std::{
4    path::{Path, PathBuf},
5    time::{SystemTime, UNIX_EPOCH},
6};
7
8use serde::{Deserialize, Serialize};
9use tracing::instrument;
10
11use crate::{
12    config::VectorConfig,
13    error::{VectorError, VectorResult},
14    index::{flat::FlatIndex, hnsw::HnswIndex},
15    types::DistanceMetric,
16};
17
/// Maximum number of vectors a flat selector may hold; as soon as the count exceeds this
/// value the selector migrates itself to an HNSW index.
pub const HNSW_THRESHOLD: usize = 1000;
20
21#[derive(Debug, Serialize, Deserialize)]
22struct PersistedIndex {
23    index_type: String,
24    points: Vec<(usize, Vec<f32>)>,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
28struct IndexManifest {
29    blake3: String,
30    vector_count: u64,
31    dimensions: u32,
32    saved_at_ms: u64,
33}
34
35/// Transparently routes between a [`FlatIndex`] and a [`HnswIndex`].
36pub enum IndexSelector {
37    /// Brute-force index for small collections.
38    Flat(FlatIndex),
39    /// Approximate NN index for larger collections.
40    Hnsw(Box<HnswIndex>),
41}
42
43impl IndexSelector {
44    /// Create a new selector (always starts as Flat).
45    pub fn new(dimensions: usize, distance: DistanceMetric, _config: &VectorConfig) -> Self {
46        IndexSelector::Flat(FlatIndex::new(dimensions, distance))
47    }
48
49    /// Insert a single vector, migrating to HNSW if the threshold is crossed.
50    #[instrument(skip(self, vector, config))]
51    pub fn insert(
52        &mut self,
53        id: usize,
54        vector: Vec<f32>,
55        config: &VectorConfig,
56    ) -> VectorResult<()> {
57        match self {
58            IndexSelector::Flat(flat) => {
59                flat.insert(id, vector)?;
60                if flat.len() > HNSW_THRESHOLD {
61                    self.migrate_to_hnsw(config)?;
62                }
63            }
64            IndexSelector::Hnsw(hnsw) => hnsw.insert(id, &vector)?,
65        }
66        Ok(())
67    }
68
69    /// Insert a batch of vectors, migrating to HNSW if the threshold is crossed.
70    #[instrument(skip(self, items, config))]
71    pub fn insert_batch(
72        &mut self,
73        items: Vec<(usize, Vec<f32>)>,
74        config: &VectorConfig,
75    ) -> VectorResult<()> {
76        match self {
77            IndexSelector::Flat(flat) => {
78                flat.insert_batch(items)?;
79                if flat.len() > HNSW_THRESHOLD {
80                    self.migrate_to_hnsw(config)?;
81                }
82            }
83            IndexSelector::Hnsw(hnsw) => hnsw.insert_batch(&items)?,
84        }
85        Ok(())
86    }
87
88    /// Search for `top_k` nearest neighbours of `query`.
89    #[instrument(skip(self, query))]
90    pub fn search(
91        &self,
92        query: &[f32],
93        top_k: usize,
94        ef_search: usize,
95    ) -> VectorResult<Vec<(usize, f32)>> {
96        match self {
97            IndexSelector::Flat(flat) => flat.search(query, top_k),
98            IndexSelector::Hnsw(hnsw) => hnsw.search(query, top_k, ef_search),
99        }
100    }
101
102    /// Delete a vector by id. Returns `true` if the id was present.
103    #[instrument(skip(self))]
104    pub fn delete(&mut self, id: usize) -> VectorResult<bool> {
105        match self {
106            IndexSelector::Flat(flat) => flat.delete(id),
107            IndexSelector::Hnsw(hnsw) => {
108                hnsw.delete(id)?;
109                Ok(true)
110            }
111        }
112    }
113
114    /// Return the number of live elements.
115    pub fn len(&self) -> usize {
116        match self {
117            IndexSelector::Flat(f) => f.len(),
118            IndexSelector::Hnsw(h) => h.len(),
119        }
120    }
121
122    /// Return `true` if the selector contains no live elements.
123    pub fn is_empty(&self) -> bool {
124        self.len() == 0
125    }
126
127    /// Return `true` if the selector is backed by HNSW.
128    pub fn is_hnsw(&self) -> bool {
129        matches!(self, IndexSelector::Hnsw(_))
130    }
131
132    /// Migrate from FlatIndex to HnswIndex, replacing `self`.
133    #[instrument(skip(self, config))]
134    pub fn migrate_to_hnsw(&mut self, config: &VectorConfig) -> VectorResult<()> {
135        let hnsw = match self {
136            IndexSelector::Flat(flat) => {
137                tracing::info!(elements = flat.len(), "migrating flat index to HNSW");
138                flat.to_hnsw(config)?
139            }
140            IndexSelector::Hnsw(_) => return Ok(()),
141        };
142        *self = IndexSelector::Hnsw(Box::new(hnsw));
143        Ok(())
144    }
145
146    /// Persist the index under `<dir>/<collection>/`.
147    #[instrument(skip(self))]
148    pub fn save(&self, dir: &Path, workspace_id: &str, collection: &str) -> VectorResult<()> {
149        let col_dir = dir.join(workspace_id).join(collection);
150        std::fs::create_dir_all(&col_dir)?;
151
152        let persisted = match self {
153            IndexSelector::Flat(flat) => PersistedIndex {
154                index_type: "flat".to_string(),
155                points: flat.all_vectors()?,
156            },
157            IndexSelector::Hnsw(hnsw) => PersistedIndex {
158                index_type: "hnsw".to_string(),
159                points: hnsw.snapshot_points()?,
160            },
161        };
162
163        let payload = serde_json::to_vec(&persisted)?;
164        let final_path = idx_file(&col_dir, collection);
165        let tmp_path = idx_tmp_file(&col_dir, collection);
166        std::fs::write(&tmp_path, &payload)?;
167        std::fs::rename(&tmp_path, &final_path)?;
168
169        let saved_at_ms = SystemTime::now()
170            .duration_since(UNIX_EPOCH)
171            .map(|duration| duration.as_millis() as u64)
172            .unwrap_or(0);
173        let dimensions = match self {
174            IndexSelector::Flat(flat) => flat.dimensions,
175            IndexSelector::Hnsw(_) => {
176                if persisted.points.is_empty() {
177                    0
178                } else {
179                    persisted.points[0].1.len()
180                }
181            }
182        };
183        let manifest = IndexManifest {
184            blake3: blake3::hash(&payload).to_hex().to_string(),
185            vector_count: persisted.points.len() as u64,
186            dimensions: dimensions as u32,
187            saved_at_ms,
188        };
189        std::fs::write(
190            idx_manifest_file(&col_dir, collection),
191            serde_json::to_vec_pretty(&manifest)?,
192        )?;
193        Ok(())
194    }
195
196    /// Reload a previously saved index from `<dir>/<collection>/`.
197    #[instrument(skip(config))]
198    pub fn load(
199        dir: &Path,
200        workspace_id: &str,
201        collection: &str,
202        config: &VectorConfig,
203        distance: DistanceMetric,
204        dimensions: usize,
205    ) -> VectorResult<Self> {
206        let col_dir = dir.join(workspace_id).join(collection);
207        let final_path = idx_file(&col_dir, collection);
208        let tmp_path = idx_tmp_file(&col_dir, collection);
209        let manifest_path = idx_manifest_file(&col_dir, collection);
210
211        if tmp_path.exists() && final_path.exists() {
212            let _ = std::fs::remove_file(&tmp_path);
213        }
214
215        if !manifest_path.exists() {
216            return Err(VectorError::Index("missing index manifest".into()));
217        }
218
219        let manifest: IndexManifest = serde_json::from_slice(&std::fs::read(&manifest_path)?)?;
220        let payload = std::fs::read(&final_path)?;
221        let digest = blake3::hash(&payload);
222        let expected = hex::decode(manifest.blake3)
223            .map_err(|err| VectorError::Index(format!("invalid manifest checksum: {err}")))?;
224        if !constant_time_eq(digest.as_bytes(), &expected) {
225            return Err(VectorError::Index("index checksum mismatch".into()));
226        }
227
228        let persisted: PersistedIndex = serde_json::from_slice(&payload)?;
229        match persisted.index_type.as_str() {
230            "flat" => {
231                let flat = FlatIndex::new(dimensions, distance);
232                flat.insert_batch(persisted.points)?;
233                Ok(IndexSelector::Flat(flat))
234            }
235            "hnsw" => {
236                let hnsw = HnswIndex::new_with_dimensions(config, distance, dimensions)?;
237                hnsw.insert_batch(&persisted.points)?;
238                Ok(IndexSelector::Hnsw(Box::new(hnsw)))
239            }
240            other => Err(VectorError::Index(format!("unknown index_type '{other}'"))),
241        }
242    }
243}
244
/// Location of the persisted index payload: `<path>/<collection>.idx`.
fn idx_file(path: &Path, collection: &str) -> PathBuf {
    let file_name = [collection, ".idx"].concat();
    path.join(file_name)
}
248
/// Location of the temporary payload written before it is renamed into place:
/// `<path>/<collection>.idx.tmp`.
fn idx_tmp_file(path: &Path, collection: &str) -> PathBuf {
    let file_name = [collection, ".idx.tmp"].concat();
    path.join(file_name)
}
252
/// Location of the checksum manifest that accompanies the payload:
/// `<path>/<collection>.idx.manifest`.
fn idx_manifest_file(path: &Path, collection: &str) -> PathBuf {
    let file_name = [collection, ".idx.manifest"].concat();
    path.join(file_name)
}
256
/// Byte-slice equality that folds over every byte instead of stopping at the first
/// difference. Slices of different lengths compare unequal immediately.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    left.len() == right.len()
        && left
            .iter()
            .zip(right)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y))
            == 0
}