1use std::collections::HashMap;
8use std::io::Cursor;
9use std::path::Path;
10
11use anyhow::Result;
12use ndarray::Array2;
13use tracing::info;
14
15use crate::hnsw::build::build_hnsw_with_threads;
16use crate::hnsw::csr::convert_to_csr;
17use crate::hnsw::graph::{HnswConfig, HnswGraph, VectorStorage};
18use crate::hnsw::io::{read_hnsw_index, write_hnsw_compact, write_hnsw_standard};
19use crate::hnsw::search::{SearchParams, search_hnsw, search_hnsw_recompute};
20use crate::index::DistanceMetric;
21
22pub use crate::hnsw::search::PruningStrategy;
24
/// Build-time configuration for an index backend.
///
/// HNSW is the only backend represented; every accessor on this type matches
/// on the single `Hnsw` variant.
#[derive(Debug)]
pub enum BackendConfig {
    /// Settings for the HNSW backend.
    Hnsw {
        /// Copied into `HnswConfig::m` and exported under the `"M"` kwarg.
        m: usize,
        /// Copied into `HnswConfig::ef_construction` and exported under the
        /// `"efConstruction"` kwarg.
        ef_construction: usize,
        /// Distance metric; also selects the fourcc tag that `build_backend`
        /// attaches to embedded raw vectors.
        distance_metric: DistanceMetric,
        /// When `true`, `build_backend` converts the graph with
        /// `convert_to_csr` before writing it.
        is_compact: bool,
        /// When `false`, `build_backend` embeds the raw vectors in the index;
        /// when `true`, no vector data is attached by `build_backend`.
        is_recompute: bool,
        /// Thread count handed to `build_hnsw_with_threads`.
        num_threads: usize,
        /// Forwarded unchanged to `HnswConfig::seed`.
        // NOTE(review): presumably an RNG seed for reproducible builds — confirm in `HnswConfig`.
        seed: Option<u64>,
    },
}
45
46impl BackendConfig {
47 pub fn hnsw_default() -> Self {
49 let defaults = HnswConfig::default();
50 Self::Hnsw {
51 m: defaults.m,
52 ef_construction: defaults.ef_construction,
53 distance_metric: defaults.distance_metric,
54 is_compact: defaults.is_compact,
55 is_recompute: defaults.is_recompute,
56 num_threads: std::thread::available_parallelism()
57 .map(|n| n.get())
58 .unwrap_or(1),
59 seed: defaults.seed,
60 }
61 }
62
63 pub fn from_name(name: &str) -> Result<Self> {
65 match name {
66 "hnsw" => Ok(Self::hnsw_default()),
67 other => anyhow::bail!(
68 "Backend '{}' is not supported. Available backends: hnsw",
69 other
70 ),
71 }
72 }
73
74 pub fn name(&self) -> &str {
76 match self {
77 Self::Hnsw { .. } => "hnsw",
78 }
79 }
80
81 pub fn distance_metric(&self) -> DistanceMetric {
83 match self {
84 Self::Hnsw {
85 distance_metric, ..
86 } => *distance_metric,
87 }
88 }
89
90 pub fn set_distance_metric(&mut self, metric: DistanceMetric) {
92 match self {
93 Self::Hnsw {
94 distance_metric, ..
95 } => *distance_metric = metric,
96 }
97 }
98
99 pub fn set_m(&mut self, val: usize) {
101 match self {
102 Self::Hnsw { m, .. } => *m = val,
103 }
104 }
105
106 pub fn set_ef_construction(&mut self, val: usize) {
108 match self {
109 Self::Hnsw {
110 ef_construction, ..
111 } => *ef_construction = val,
112 }
113 }
114
115 pub fn set_compact(&mut self, val: bool) {
117 match self {
118 Self::Hnsw { is_compact, .. } => *is_compact = val,
119 }
120 }
121
122 pub fn set_recompute(&mut self, val: bool) {
124 match self {
125 Self::Hnsw { is_recompute, .. } => *is_recompute = val,
126 }
127 }
128
129 pub fn set_num_threads(&mut self, val: usize) {
131 match self {
132 Self::Hnsw { num_threads, .. } => *num_threads = val.max(1),
133 }
134 }
135
136 pub fn to_backend_kwargs(&self) -> HashMap<String, serde_json::Value> {
138 match self {
139 Self::Hnsw {
140 m,
141 ef_construction,
142 distance_metric,
143 is_compact,
144 is_recompute,
145 ..
146 } => {
147 let mut kwargs = HashMap::new();
148 kwargs.insert("M".to_string(), serde_json::json!(m));
149 kwargs.insert(
150 "efConstruction".to_string(),
151 serde_json::json!(ef_construction),
152 );
153 kwargs.insert(
154 "distance_metric".to_string(),
155 serde_json::json!(match distance_metric {
156 DistanceMetric::L2 => "l2",
157 DistanceMetric::Cosine => "cosine",
158 DistanceMetric::Mips => "mips",
159 }),
160 );
161 kwargs.insert("is_compact".to_string(), serde_json::json!(is_compact));
162 kwargs.insert("is_recompute".to_string(), serde_json::json!(is_recompute));
163 kwargs
164 }
165 }
166 }
167
168 pub fn to_hnsw_config(&self) -> HnswConfig {
170 match self {
171 Self::Hnsw {
172 m,
173 ef_construction,
174 distance_metric,
175 is_compact,
176 is_recompute,
177 seed,
178 ..
179 } => HnswConfig {
180 m: *m,
181 ef_construction: *ef_construction,
182 ef_search: 64, distance_metric: *distance_metric,
184 is_compact: *is_compact,
185 is_recompute: *is_recompute,
186 seed: *seed,
187 },
188 }
189 }
190
191 pub fn is_compact(&self) -> bool {
193 match self {
194 Self::Hnsw { is_compact, .. } => *is_compact,
195 }
196 }
197
198 pub fn is_recompute(&self) -> bool {
200 match self {
201 Self::Hnsw { is_recompute, .. } => *is_recompute,
202 }
203 }
204}
205
/// A loaded index, tagged by the backend it belongs to.
///
/// `Debug` is implemented by hand for this type rather than derived.
pub enum BackendIndex {
    /// An HNSW graph, e.g. as produced by `read_backend_index`.
    Hnsw(HnswGraph),
}
215
216impl std::fmt::Debug for BackendIndex {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 match self {
219 Self::Hnsw(g) => f
220 .debug_struct("BackendIndex::Hnsw")
221 .field("ntotal", &g.ntotal)
222 .field("dimensions", &g.dimensions)
223 .finish(),
224 }
225 }
226}
227
228impl BackendIndex {
229 pub fn ntotal(&self) -> usize {
231 match self {
232 Self::Hnsw(g) => g.ntotal,
233 }
234 }
235
236 pub fn dimensions(&self) -> usize {
238 match self {
239 Self::Hnsw(g) => g.dimensions,
240 }
241 }
242
243 pub fn is_pruned(&self) -> bool {
245 match self {
246 Self::Hnsw(g) => g.is_pruned(),
247 }
248 }
249}
250
251pub fn build_backend(
260 config: &BackendConfig,
261 embeddings: &Array2<f32>,
262 index_file: &Path,
263 progress: Option<&dyn crate::hnsw::IndexProgress>,
264) -> Result<()> {
265 match config {
266 BackendConfig::Hnsw {
267 num_threads,
268 is_recompute,
269 is_compact,
270 distance_metric,
271 ..
272 } => {
273 let hnsw_config = config.to_hnsw_config();
274
275 info!(
276 "Building HNSW graph (M={}, efConstruction={})",
277 hnsw_config.m, hnsw_config.ef_construction
278 );
279 let mut graph =
280 build_hnsw_with_threads(embeddings, &hnsw_config, *num_threads, progress)?;
281
282 if !is_recompute {
284 let flat: Vec<f32> = embeddings.iter().copied().collect();
285 let storage_bytes = flat
286 .iter()
287 .flat_map(|f| f.to_le_bytes())
288 .collect::<Vec<u8>>();
289
290 let fourcc = match distance_metric {
291 DistanceMetric::L2 => u32::from_le_bytes(*b"IxFl"),
292 _ => u32::from_le_bytes(*b"IxFI"),
293 };
294
295 graph.vector_storage = VectorStorage::Raw {
296 fourcc,
297 data: storage_bytes,
298 };
299 }
300
301 let graph = if *is_compact {
303 info!("Converting to compact CSR format");
304 convert_to_csr(&graph)?
305 } else {
306 graph
307 };
308
309 let mut file = std::fs::File::create(index_file)?;
311 if graph.is_compact() {
312 write_hnsw_compact(&mut file, &graph)?;
313 } else {
314 write_hnsw_standard(&mut file, &graph)?;
315 }
316
317 Ok(())
318 }
319 }
320}
321
322pub fn read_backend_index(backend_name: &str, index_file: &Path) -> Result<BackendIndex> {
324 match backend_name {
325 "hnsw" => {
326 let index_data = std::fs::read(index_file)?;
327 let mut cursor = Cursor::new(index_data);
328 let graph = read_hnsw_index(&mut cursor)?;
329 Ok(BackendIndex::Hnsw(graph))
330 }
331 other => anyhow::bail!("Unknown backend '{}' — cannot read index", other),
332 }
333}
334
335pub fn search_backend(
337 index: &BackendIndex,
338 query: &[f32],
339 top_k: usize,
340 params: &SearchParams,
341) -> (Vec<usize>, Vec<f32>) {
342 match index {
343 BackendIndex::Hnsw(graph) => {
344 match &graph.vector_storage {
345 VectorStorage::Raw { data, .. } => {
346 let flat_vectors: Vec<f32> = data
347 .chunks_exact(4)
348 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
349 .collect();
350 search_hnsw(graph, query, top_k, &flat_vectors, params)
351 }
352 VectorStorage::Null => {
353 (Vec::new(), Vec::new())
356 }
357 }
358 }
359 }
360}
361
362pub fn search_backend_recompute<F>(
364 index: &BackendIndex,
365 query: &[f32],
366 top_k: usize,
367 params: &SearchParams,
368 compute_distance: F,
369) -> (Vec<usize>, Vec<f32>)
370where
371 F: FnMut(&[usize], &[f32], &mut [f32]),
372{
373 match index {
374 BackendIndex::Hnsw(graph) => {
375 search_hnsw_recompute(graph, query, top_k, params, compute_distance)
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration reports the HNSW name and the expected
    /// metric and flags.
    #[test]
    fn test_backend_config_hnsw_default() {
        let config = BackendConfig::hnsw_default();
        assert_eq!(config.name(), "hnsw");
        assert_eq!(config.distance_metric(), DistanceMetric::Mips);
        assert!(config.is_compact());
        assert!(config.is_recompute());
    }

    /// Only `"hnsw"` is accepted as a backend name.
    #[test]
    fn test_backend_config_from_name() {
        assert!(BackendConfig::from_name("hnsw").is_ok());
        for unsupported in ["ivf", "unknown"] {
            assert!(BackendConfig::from_name(unsupported).is_err());
        }
    }

    /// Every setter is visible through the getters and through the derived
    /// `HnswConfig`.
    #[test]
    fn test_backend_config_setters() {
        let mut config = BackendConfig::hnsw_default();
        config.set_m(16);
        config.set_ef_construction(100);
        config.set_compact(false);
        config.set_recompute(false);
        config.set_distance_metric(DistanceMetric::L2);
        config.set_num_threads(4);

        assert_eq!((config.is_compact(), config.is_recompute()), (false, false));
        assert_eq!(config.distance_metric(), DistanceMetric::L2);

        let derived = config.to_hnsw_config();
        assert_eq!((derived.m, derived.ef_construction), (16, 100));
        assert_eq!((derived.is_compact, derived.is_recompute), (false, false));
        assert_eq!(derived.distance_metric, DistanceMetric::L2);
    }

    /// The default configuration serializes to the expected kwargs.
    #[test]
    fn test_backend_kwargs_serialization() {
        let kwargs = BackendConfig::hnsw_default().to_backend_kwargs();
        let expected = [
            ("M", serde_json::json!(32)),
            ("efConstruction", serde_json::json!(200)),
            ("distance_metric", serde_json::json!("mips")),
            ("is_compact", serde_json::json!(true)),
            ("is_recompute", serde_json::json!(true)),
        ];
        for (key, value) in expected {
            assert_eq!(kwargs[key], value);
        }
    }

    /// An unknown backend name is rejected even when the file exists.
    #[test]
    fn test_read_backend_index_unknown() {
        let scratch = tempfile::NamedTempFile::new().unwrap();
        assert!(read_backend_index("unknown", scratch.path()).is_err());
    }
}