Skip to main content

leann_core/
backend.rs

//! Backend abstraction for LEANN index engines.
//!
//! Dispatches build, read, and search operations to the appropriate backend
//! (currently HNSW only) via enum matching. Backend selection is static
//! dispatch: no trait objects are involved, so each match arm is a direct call.
6
7use std::collections::HashMap;
8use std::io::Cursor;
9use std::path::Path;
10
11use anyhow::Result;
12use ndarray::Array2;
13use tracing::info;
14
15use crate::hnsw::build::build_hnsw_with_threads;
16use crate::hnsw::csr::convert_to_csr;
17use crate::hnsw::graph::{HnswConfig, HnswGraph, VectorStorage};
18use crate::hnsw::io::{read_hnsw_index, write_hnsw_compact, write_hnsw_standard};
19use crate::hnsw::search::{SearchParams, search_hnsw, search_hnsw_recompute};
20use crate::index::DistanceMetric;
21
22// Re-export search types so callers don't need to reach into hnsw::search.
23pub use crate::hnsw::search::PruningStrategy;
24
25// ---------------------------------------------------------------------------
26// BackendConfig
27// ---------------------------------------------------------------------------
28
29/// Backend-specific build configuration.
30///
31/// Each variant holds all parameters needed to build an index with that backend.
32#[derive(Debug)]
33pub enum BackendConfig {
34    Hnsw {
35        m: usize,
36        ef_construction: usize,
37        distance_metric: DistanceMetric,
38        is_compact: bool,
39        is_recompute: bool,
40        num_threads: usize,
41        seed: Option<u64>,
42    },
43    // Future: Ivf { nlist: usize, distance_metric: DistanceMetric },
44}
45
46impl BackendConfig {
47    /// Default HNSW configuration (matches `HnswConfig::default()`).
48    pub fn hnsw_default() -> Self {
49        let defaults = HnswConfig::default();
50        Self::Hnsw {
51            m: defaults.m,
52            ef_construction: defaults.ef_construction,
53            distance_metric: defaults.distance_metric,
54            is_compact: defaults.is_compact,
55            is_recompute: defaults.is_recompute,
56            num_threads: std::thread::available_parallelism()
57                .map(|n| n.get())
58                .unwrap_or(1),
59            seed: defaults.seed,
60        }
61    }
62
63    /// Create a default config for the given backend name.
64    pub fn from_name(name: &str) -> Result<Self> {
65        match name {
66            "hnsw" => Ok(Self::hnsw_default()),
67            other => anyhow::bail!(
68                "Backend '{}' is not supported. Available backends: hnsw",
69                other
70            ),
71        }
72    }
73
74    /// Backend name string (for serialization into `IndexMeta`).
75    pub fn name(&self) -> &str {
76        match self {
77            Self::Hnsw { .. } => "hnsw",
78        }
79    }
80
81    /// Distance metric for this backend configuration.
82    pub fn distance_metric(&self) -> DistanceMetric {
83        match self {
84            Self::Hnsw {
85                distance_metric, ..
86            } => *distance_metric,
87        }
88    }
89
90    /// Set the distance metric.
91    pub fn set_distance_metric(&mut self, metric: DistanceMetric) {
92        match self {
93            Self::Hnsw {
94                distance_metric, ..
95            } => *distance_metric = metric,
96        }
97    }
98
99    /// Set M parameter (HNSW only).
100    pub fn set_m(&mut self, val: usize) {
101        match self {
102            Self::Hnsw { m, .. } => *m = val,
103        }
104    }
105
106    /// Set efConstruction parameter (HNSW only).
107    pub fn set_ef_construction(&mut self, val: usize) {
108        match self {
109            Self::Hnsw {
110                ef_construction, ..
111            } => *ef_construction = val,
112        }
113    }
114
115    /// Set compact mode (HNSW only).
116    pub fn set_compact(&mut self, val: bool) {
117        match self {
118            Self::Hnsw { is_compact, .. } => *is_compact = val,
119        }
120    }
121
122    /// Set recompute mode (HNSW only).
123    pub fn set_recompute(&mut self, val: bool) {
124        match self {
125            Self::Hnsw { is_recompute, .. } => *is_recompute = val,
126        }
127    }
128
129    /// Set number of build threads (HNSW only).
130    pub fn set_num_threads(&mut self, val: usize) {
131        match self {
132            Self::Hnsw { num_threads, .. } => *num_threads = val.max(1),
133        }
134    }
135
136    /// Convert to `backend_kwargs` map for `IndexMeta` serialization.
137    pub fn to_backend_kwargs(&self) -> HashMap<String, serde_json::Value> {
138        match self {
139            Self::Hnsw {
140                m,
141                ef_construction,
142                distance_metric,
143                is_compact,
144                is_recompute,
145                ..
146            } => {
147                let mut kwargs = HashMap::new();
148                kwargs.insert("M".to_string(), serde_json::json!(m));
149                kwargs.insert(
150                    "efConstruction".to_string(),
151                    serde_json::json!(ef_construction),
152                );
153                kwargs.insert(
154                    "distance_metric".to_string(),
155                    serde_json::json!(match distance_metric {
156                        DistanceMetric::L2 => "l2",
157                        DistanceMetric::Cosine => "cosine",
158                        DistanceMetric::Mips => "mips",
159                    }),
160                );
161                kwargs.insert("is_compact".to_string(), serde_json::json!(is_compact));
162                kwargs.insert("is_recompute".to_string(), serde_json::json!(is_recompute));
163                kwargs
164            }
165        }
166    }
167
168    /// Extract an `HnswConfig` from this configuration (panics if not HNSW).
169    pub fn to_hnsw_config(&self) -> HnswConfig {
170        match self {
171            Self::Hnsw {
172                m,
173                ef_construction,
174                distance_metric,
175                is_compact,
176                is_recompute,
177                seed,
178                ..
179            } => HnswConfig {
180                m: *m,
181                ef_construction: *ef_construction,
182                ef_search: 64, // search-time param, not stored in build config
183                distance_metric: *distance_metric,
184                is_compact: *is_compact,
185                is_recompute: *is_recompute,
186                seed: *seed,
187            },
188        }
189    }
190
191    /// Whether this config uses compact storage.
192    pub fn is_compact(&self) -> bool {
193        match self {
194            Self::Hnsw { is_compact, .. } => *is_compact,
195        }
196    }
197
198    /// Whether this config uses recompute mode.
199    pub fn is_recompute(&self) -> bool {
200        match self {
201            Self::Hnsw { is_recompute, .. } => *is_recompute,
202        }
203    }
204}
205
206// ---------------------------------------------------------------------------
207// BackendIndex
208// ---------------------------------------------------------------------------
209
210/// A loaded backend index, ready for search.
211pub enum BackendIndex {
212    Hnsw(HnswGraph),
213    // Future: Ivf { ... },
214}
215
216impl std::fmt::Debug for BackendIndex {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        match self {
219            Self::Hnsw(g) => f
220                .debug_struct("BackendIndex::Hnsw")
221                .field("ntotal", &g.ntotal)
222                .field("dimensions", &g.dimensions)
223                .finish(),
224        }
225    }
226}
227
228impl BackendIndex {
229    /// Total number of indexed vectors.
230    pub fn ntotal(&self) -> usize {
231        match self {
232            Self::Hnsw(g) => g.ntotal,
233        }
234    }
235
236    /// Dimensionality of indexed vectors.
237    pub fn dimensions(&self) -> usize {
238        match self {
239            Self::Hnsw(g) => g.dimensions,
240        }
241    }
242
243    /// Whether vector storage has been pruned (recompute mode).
244    pub fn is_pruned(&self) -> bool {
245        match self {
246            Self::Hnsw(g) => g.is_pruned(),
247        }
248    }
249}
250
251// ---------------------------------------------------------------------------
252// Dispatch functions
253// ---------------------------------------------------------------------------
254
255/// Build an index and write it to disk.
256///
257/// Handles graph construction, optional CSR compaction, optional vector storage,
258/// and serialization — the full pipeline from embeddings to on-disk index.
259pub fn build_backend(
260    config: &BackendConfig,
261    embeddings: &Array2<f32>,
262    index_file: &Path,
263    progress: Option<&dyn crate::hnsw::IndexProgress>,
264) -> Result<()> {
265    match config {
266        BackendConfig::Hnsw {
267            num_threads,
268            is_recompute,
269            is_compact,
270            distance_metric,
271            ..
272        } => {
273            let hnsw_config = config.to_hnsw_config();
274
275            info!(
276                "Building HNSW graph (M={}, efConstruction={})",
277                hnsw_config.m, hnsw_config.ef_construction
278            );
279            let mut graph =
280                build_hnsw_with_threads(embeddings, &hnsw_config, *num_threads, progress)?;
281
282            // Store vectors if not using recompute
283            if !is_recompute {
284                let flat: Vec<f32> = embeddings.iter().copied().collect();
285                let storage_bytes = flat
286                    .iter()
287                    .flat_map(|f| f.to_le_bytes())
288                    .collect::<Vec<u8>>();
289
290                let fourcc = match distance_metric {
291                    DistanceMetric::L2 => u32::from_le_bytes(*b"IxFl"),
292                    _ => u32::from_le_bytes(*b"IxFI"),
293                };
294
295                graph.vector_storage = VectorStorage::Raw {
296                    fourcc,
297                    data: storage_bytes,
298                };
299            }
300
301            // Convert to CSR if compact mode
302            let graph = if *is_compact {
303                info!("Converting to compact CSR format");
304                convert_to_csr(&graph)?
305            } else {
306                graph
307            };
308
309            // Write index file
310            let mut file = std::fs::File::create(index_file)?;
311            if graph.is_compact() {
312                write_hnsw_compact(&mut file, &graph)?;
313            } else {
314                write_hnsw_standard(&mut file, &graph)?;
315            }
316
317            Ok(())
318        }
319    }
320}
321
322/// Read an index from disk.
323pub fn read_backend_index(backend_name: &str, index_file: &Path) -> Result<BackendIndex> {
324    match backend_name {
325        "hnsw" => {
326            let index_data = std::fs::read(index_file)?;
327            let mut cursor = Cursor::new(index_data);
328            let graph = read_hnsw_index(&mut cursor)?;
329            Ok(BackendIndex::Hnsw(graph))
330        }
331        other => anyhow::bail!("Unknown backend '{}' — cannot read index", other),
332    }
333}
334
335/// Search using stored vectors.
336pub fn search_backend(
337    index: &BackendIndex,
338    query: &[f32],
339    top_k: usize,
340    params: &SearchParams,
341) -> (Vec<usize>, Vec<f32>) {
342    match index {
343        BackendIndex::Hnsw(graph) => {
344            match &graph.vector_storage {
345                VectorStorage::Raw { data, .. } => {
346                    let flat_vectors: Vec<f32> = data
347                        .chunks_exact(4)
348                        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
349                        .collect();
350                    search_hnsw(graph, query, top_k, &flat_vectors, params)
351                }
352                VectorStorage::Null => {
353                    // No stored vectors — return empty results.
354                    // Caller should use search_backend_recompute instead.
355                    (Vec::new(), Vec::new())
356                }
357            }
358        }
359    }
360}
361
362/// Search using recomputed distances (embedding provider callback).
363pub fn search_backend_recompute<F>(
364    index: &BackendIndex,
365    query: &[f32],
366    top_k: usize,
367    params: &SearchParams,
368    compute_distance: F,
369) -> (Vec<usize>, Vec<f32>)
370where
371    F: FnMut(&[usize], &[f32], &mut [f32]),
372{
373    match index {
374        BackendIndex::Hnsw(graph) => {
375            search_hnsw_recompute(graph, query, top_k, params, compute_distance)
376        }
377    }
378}
379
380// ---------------------------------------------------------------------------
381// Tests
382// ---------------------------------------------------------------------------
383
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config reports the HNSW backend and the default flags.
    #[test]
    fn test_backend_config_hnsw_default() {
        let cfg = BackendConfig::hnsw_default();
        assert_eq!(cfg.name(), "hnsw");
        assert_eq!(cfg.distance_metric(), DistanceMetric::Mips);
        assert!(cfg.is_compact());
        assert!(cfg.is_recompute());
    }

    /// Only "hnsw" is accepted; the error names the rejected backend.
    #[test]
    fn test_backend_config_from_name() {
        assert!(BackendConfig::from_name("hnsw").is_ok());
        assert!(BackendConfig::from_name("ivf").is_err());
        assert!(BackendConfig::from_name("unknown").is_err());

        let err = BackendConfig::from_name("ivf").unwrap_err();
        assert!(err.to_string().contains("ivf"));
    }

    /// Every setter is reflected by the getters and by `to_hnsw_config`.
    #[test]
    fn test_backend_config_setters() {
        let mut cfg = BackendConfig::hnsw_default();
        cfg.set_m(16);
        cfg.set_ef_construction(100);
        cfg.set_compact(false);
        cfg.set_recompute(false);
        cfg.set_distance_metric(DistanceMetric::L2);
        cfg.set_num_threads(4);

        assert!(!cfg.is_compact());
        assert!(!cfg.is_recompute());
        assert_eq!(cfg.distance_metric(), DistanceMetric::L2);
        // There is no getter for the thread count, so inspect the variant.
        assert!(matches!(cfg, BackendConfig::Hnsw { num_threads: 4, .. }));

        let hnsw = cfg.to_hnsw_config();
        assert_eq!(hnsw.m, 16);
        assert_eq!(hnsw.ef_construction, 100);
        assert!(!hnsw.is_compact);
        assert!(!hnsw.is_recompute);
        assert_eq!(hnsw.distance_metric, DistanceMetric::L2);
    }

    /// A request for zero build threads is clamped to one.
    #[test]
    fn test_set_num_threads_clamps_zero_to_one() {
        let mut cfg = BackendConfig::hnsw_default();
        cfg.set_num_threads(0);
        assert!(matches!(cfg, BackendConfig::Hnsw { num_threads: 1, .. }));
    }

    /// The kwargs map carries the serialized build parameters.
    #[test]
    fn test_backend_kwargs_serialization() {
        let cfg = BackendConfig::hnsw_default();
        let kwargs = cfg.to_backend_kwargs();
        assert_eq!(kwargs["M"], serde_json::json!(32));
        assert_eq!(kwargs["efConstruction"], serde_json::json!(200));
        assert_eq!(kwargs["distance_metric"], serde_json::json!("mips"));
        assert_eq!(kwargs["is_compact"], serde_json::json!(true));
        assert_eq!(kwargs["is_recompute"], serde_json::json!(true));
    }

    /// An unknown backend name is rejected before the file is touched.
    #[test]
    fn test_read_backend_index_unknown() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let result = read_backend_index("unknown", tmp.path());
        assert!(result.is_err());
    }

    /// A missing index file yields an error rather than a panic.
    #[test]
    fn test_read_backend_index_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("missing.index");
        let result = read_backend_index("hnsw", &missing);
        assert!(result.is_err());
    }
}