// oxirs_vec/gpu_hnsw_index.rs
//! GPU-simulated HNSW index with parallel insert and fast approximate search.
//!
//! This module provides `GpuHnswIndex`, which simulates GPU-accelerated HNSW
//! graph construction using CPU-side parallelism (via `std::thread`) as a
//! pure-Rust stand-in for actual GPU batching.  The API intentionally mirrors
//! a real GPU implementation so that the caller can be swapped for a CUDA
//! version later without interface changes.
//!
//! # Design
//!
//! * **Simulated GPU batching**: vectors are accumulated in a staging batch;
//!   when the batch is full the entire batch is "uploaded" (simulated) and
//!   graph edges are computed in parallel across batch items.
//! * **Layered graph**: a standard HNSW multi-layer graph where each node
//!   stores at most `max_connections` bi-directional edges per layer.
//! * **Approximate search**: greedy beam search starting from the entry point.
//! * **No `unwrap()`**: all fallible operations propagate `anyhow::Error`.

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::{Arc, Mutex};

// ── configuration ─────────────────────────────────────────────────────────────

26/// Configuration for `GpuHnswIndex`.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct GpuHnswConfig {
29    /// Simulated GPU batch size (number of vectors per "GPU upload").
30    pub batch_size: usize,
31    /// Maximum connections per layer (M parameter).
32    pub max_connections: usize,
33    /// Maximum connections at layer 0 (M0 parameter, typically 2×M).
34    pub max_connections_layer0: usize,
35    /// ef_construction: candidate list size during graph construction.
36    pub ef_construction: usize,
37    /// ef_search: candidate list size during approximate search.
38    pub ef_search: usize,
39    /// Layer probability multiplier (1/ln(M)).
40    pub level_multiplier: f64,
41    /// Number of simulated GPU worker threads for batch construction.
42    pub gpu_workers: usize,
43}
44
45impl Default for GpuHnswConfig {
46    fn default() -> Self {
47        Self {
48            batch_size: 64,
49            max_connections: 16,
50            max_connections_layer0: 32,
51            ef_construction: 200,
52            ef_search: 50,
53            level_multiplier: 1.0 / (16_f64).ln(),
54            gpu_workers: 4,
55        }
56    }
57}
58
// ── graph node ────────────────────────────────────────────────────────────────

/// A node in the HNSW graph.
#[derive(Debug, Clone)]
struct HnswNode {
    /// The raw floating-point vector.
    vector: Vec<f32>,
    /// `neighbors[layer]` holds the neighbor IDs for that layer.
    neighbors: Vec<Vec<usize>>,
    /// Maximum layer this node appears in.
    max_layer: usize,
}

72impl HnswNode {
73    fn new(vector: Vec<f32>, max_layer: usize, layers: usize) -> Self {
74        Self {
75            vector,
76            neighbors: vec![Vec::new(); layers],
77            max_layer,
78        }
79    }
80}
81
// ── candidate / search helpers ────────────────────────────────────────────────

/// A (distance, node_id) pair ordered for a max-heap (farthest first).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Candidate {
    dist: f32,
    id: usize,
}

// `dist` is produced by `euclidean_sq` on finite inputs and is never NaN in
// practice; incomparable distances are treated as equal (see `Ord` below).
impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Max-heap on distance (farthest first); NaN compares as Equal.
        self.dist
            .partial_cmp(&other.dist)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}

// ── GPU batch stats ───────────────────────────────────────────────────────────

110/// Statistics collected during GPU-simulated batch construction.
111#[derive(Debug, Clone, Default, Serialize, Deserialize)]
112pub struct GpuBatchStats {
113    /// Total number of batches processed.
114    pub batches_processed: u64,
115    /// Total vectors inserted.
116    pub vectors_inserted: u64,
117    /// Total distance computations performed.
118    pub distance_computations: u64,
119    /// Average batch processing time in microseconds (simulated).
120    pub avg_batch_us: f64,
121}
122
// ── index stats ───────────────────────────────────────────────────────────────

125/// Overall index statistics.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct GpuHnswStats {
128    /// Number of vectors in the index.
129    pub vector_count: usize,
130    /// Number of layers in the graph.
131    pub layer_count: usize,
132    /// GPU batch statistics.
133    pub batch_stats: GpuBatchStats,
134    /// Configured batch size.
135    pub batch_size: usize,
136    /// ef_search parameter.
137    pub ef_search: usize,
138}
139
// ── main struct ───────────────────────────────────────────────────────────────

142/// GPU-simulated HNSW index.
143///
144/// Inserts are accumulated in a staging buffer; once the buffer reaches
145/// `config.batch_size` the batch is flushed via parallel construction
146/// threads (simulating GPU parallelism).
147pub struct GpuHnswIndex {
148    config: GpuHnswConfig,
149    /// All nodes keyed by numeric ID.
150    nodes: Vec<HnswNode>,
151    /// URI → node ID.
152    uri_to_id: HashMap<String, usize>,
153    /// Node ID → URI.
154    id_to_uri: Vec<String>,
155    /// Entry point into the top layer.
156    entry_point: Option<usize>,
157    /// Current top layer in the graph.
158    top_layer: usize,
159    /// Staging batch: (uri, vector) pairs waiting to be flushed.
160    pending_batch: Vec<(String, Vec<f32>)>,
161    /// Accumulated statistics.
162    batch_stats: GpuBatchStats,
163    /// Simple LCG RNG state for deterministic level generation.
164    rng_state: u64,
165}
166
167impl GpuHnswIndex {
168    /// Create a new GPU-simulated HNSW index.
169    pub fn new(config: GpuHnswConfig) -> Self {
170        Self {
171            config,
172            nodes: Vec::new(),
173            uri_to_id: HashMap::new(),
174            id_to_uri: Vec::new(),
175            entry_point: None,
176            top_layer: 0,
177            pending_batch: Vec::new(),
178            batch_stats: GpuBatchStats::default(),
179            rng_state: 0x9e3779b97f4a7c15,
180        }
181    }
182
183    // ── public API ────────────────────────────────────────────────────────────
184
185    /// Insert a vector into the index.
186    ///
187    /// The vector is first placed into the staging batch.  When the batch
188    /// reaches `config.batch_size` it is flushed automatically.
189    pub fn insert(&mut self, uri: String, vector: Vec<f32>) -> Result<()> {
190        if self.uri_to_id.contains_key(&uri) {
191            return Err(anyhow!("URI '{}' already exists in index", uri));
192        }
193        self.pending_batch.push((uri, vector));
194        if self.pending_batch.len() >= self.config.batch_size {
195            self.flush_batch()?;
196        }
197        Ok(())
198    }
199
200    /// Force-flush the pending staging batch regardless of its size.
201    pub fn flush(&mut self) -> Result<()> {
202        if !self.pending_batch.is_empty() {
203            self.flush_batch()?;
204        }
205        Ok(())
206    }
207
208    /// Search for the `k` approximate nearest neighbours of `query`.
209    ///
210    /// Any unflushed vectors in the staging batch are **not** searched.
211    /// Call `flush` first if you need them included.
212    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
213        if self.nodes.is_empty() {
214            return Ok(Vec::new());
215        }
216
217        let entry = self
218            .entry_point
219            .ok_or_else(|| anyhow!("No entry point set"))?;
220
221        // Greedy descent through upper layers
222        let mut current_nearest = entry;
223        for layer in (1..=self.top_layer).rev() {
224            current_nearest = self.greedy_search_layer(query, current_nearest, layer)?;
225        }
226
227        // Beam search at layer 0
228        let candidates =
229            self.beam_search_layer(query, current_nearest, 0, self.config.ef_search)?;
230
231        // Collect top-k
232        let mut results: Vec<(String, f32)> = candidates
233            .into_iter()
234            .map(|c| {
235                let uri = self.id_to_uri[c.id].clone();
236                (uri, c.dist)
237            })
238            .collect();
239
240        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
241        results.truncate(k);
242        Ok(results)
243    }
244
245    /// Number of vectors currently committed to the graph (excludes pending batch).
246    pub fn len(&self) -> usize {
247        self.nodes.len()
248    }
249
250    /// Number of vectors in the pending (un-flushed) batch.
251    pub fn pending_count(&self) -> usize {
252        self.pending_batch.len()
253    }
254
255    /// Returns `true` if the graph contains no committed nodes.
256    pub fn is_empty(&self) -> bool {
257        self.nodes.is_empty()
258    }
259
260    /// Return a snapshot of current statistics.
261    pub fn stats(&self) -> GpuHnswStats {
262        GpuHnswStats {
263            vector_count: self.nodes.len(),
264            layer_count: self.top_layer + 1,
265            batch_stats: self.batch_stats.clone(),
266            batch_size: self.config.batch_size,
267            ef_search: self.config.ef_search,
268        }
269    }
270
271    /// Access the configuration.
272    pub fn config(&self) -> &GpuHnswConfig {
273        &self.config
274    }
275
276    // ── private internals ─────────────────────────────────────────────────────
277
278    /// Flush the pending staging batch by constructing graph edges in parallel.
279    ///
280    /// We simulate GPU batching by distributing the level-assignment step
281    /// (embarrassingly parallel) across `gpu_workers` threads, then serially
282    /// inserting each node into the graph (graph mutation requires the global
283    /// state so cannot be done in parallel without complex locking).
284    fn flush_batch(&mut self) -> Result<()> {
285        let batch = std::mem::take(&mut self.pending_batch);
286        let batch_len = batch.len();
287
288        // ── Simulate GPU batch: compute random levels in parallel ─────────
289        let workers = self.config.gpu_workers.max(1);
290        let level_multiplier = self.config.level_multiplier;
291
292        // Share seeds for parallel workers
293        let seeds: Vec<u64> = (0..batch_len)
294            .map(|i| {
295                let mut s = self
296                    .rng_state
297                    .wrapping_add((i as u64).wrapping_mul(0x9e3779b97f4a7c15));
298                // xorshift64
299                s ^= s << 13;
300                s ^= s >> 7;
301                s ^= s << 17;
302                s
303            })
304            .collect();
305
306        // Assign one seed to rng_state for next call
307        self.rng_state = seeds.last().copied().unwrap_or(self.rng_state);
308
309        // Parallel level computation (simulate GPU kernel)
310        let levels_shared: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(vec![0usize; batch_len]));
311        let chunk_size = (batch_len + workers - 1) / workers;
312
313        std::thread::scope(|scope| {
314            let seeds_ref = &seeds;
315            let levels_ref = Arc::clone(&levels_shared);
316            for worker_id in 0..workers {
317                let start = worker_id * chunk_size;
318                let end = (start + chunk_size).min(batch_len);
319                if start >= end {
320                    break;
321                }
322                let lm = level_multiplier;
323                let levels_clone = Arc::clone(&levels_ref);
324                scope.spawn(move || {
325                    let mut local_results = Vec::with_capacity(end - start);
326                    for (i, &seed) in seeds_ref.iter().enumerate().skip(start).take(end - start) {
327                        // Use the seed to generate a uniform float in (0, 1)
328                        let uniform = (seed >> 11) as f64 / (u64::MAX >> 11) as f64;
329                        // HNSW level = floor(-ln(uniform) * level_multiplier)
330                        let level = if uniform > 0.0 {
331                            (-uniform.ln() * lm).floor() as usize
332                        } else {
333                            0
334                        };
335                        local_results.push((i, level));
336                    }
337                    // Write back
338                    if let Ok(mut guard) = levels_clone.lock() {
339                        for (idx, lvl) in local_results {
340                            guard[idx] = lvl;
341                        }
342                    }
343                });
344            }
345        });
346
347        let levels = Arc::try_unwrap(levels_shared)
348            .map_err(|_| anyhow!("Arc unwrap failed"))?
349            .into_inner()
350            .map_err(|e| anyhow!("Mutex poisoned: {e}"))?;
351
352        // ── Serial graph construction ──────────────────────────────────────
353        let mut dist_count = 0u64;
354        for (item_idx, (uri, vector)) in batch.into_iter().enumerate() {
355            let node_level = levels[item_idx];
356            let node_id = self.nodes.len();
357            let layer_count = node_level + 1;
358
359            // Pre-allocate layers (extend top_layer if needed)
360            let total_layers = self.top_layer.max(node_level) + 1;
361            let new_node = HnswNode::new(vector.clone(), node_level, total_layers);
362            self.nodes.push(new_node);
363            self.uri_to_id.insert(uri.clone(), node_id);
364            self.id_to_uri.push(uri);
365
366            if let Some(ep) = self.entry_point {
367                // Extend existing nodes' neighbor lists if necessary
368                let current_max = self.top_layer;
369                if node_level > current_max {
370                    // Extend all existing node neighbor vecs
371                    for n in &mut self.nodes {
372                        let extra = (node_level + 1).saturating_sub(n.neighbors.len());
373                        n.neighbors
374                            .extend(std::iter::repeat_with(Vec::new).take(extra));
375                    }
376                    self.top_layer = node_level;
377                }
378
379                // Greedy descent through layers above node_level
380                let mut current_ep = ep;
381                for layer in (layer_count..=self.top_layer).rev() {
382                    current_ep =
383                        self.greedy_search_layer_mut(&vector, current_ep, layer, &mut dist_count)?;
384                }
385
386                // For each layer from min(node_level, top_layer) down to 0
387                let max_conns = self.config.max_connections;
388                let max_conns_l0 = self.config.max_connections_layer0;
389
390                for layer in (0..layer_count).rev() {
391                    let ef = self.config.ef_construction;
392                    let candidates = self.beam_search_layer_with_count(
393                        &vector,
394                        current_ep,
395                        layer,
396                        ef,
397                        &mut dist_count,
398                    )?;
399
400                    // Pick the best neighbors (simple select-n)
401                    let m = if layer == 0 { max_conns_l0 } else { max_conns };
402                    let selected: Vec<usize> = candidates.iter().take(m).map(|c| c.id).collect();
403
404                    // Add bidirectional edges
405                    self.nodes[node_id].neighbors[layer].extend_from_slice(&selected);
406
407                    for &neighbor_id in &selected {
408                        // Prune neighbor's list if over capacity
409                        self.nodes[neighbor_id].neighbors[layer].push(node_id);
410                        let cap = if layer == 0 { max_conns_l0 } else { max_conns };
411                        self.prune_connections(neighbor_id, layer, cap);
412                    }
413
414                    // Update ep for next layer
415                    if let Some(best) = candidates.first() {
416                        current_ep = best.id;
417                    }
418                }
419
420                if node_level > current_max {
421                    self.entry_point = Some(node_id);
422                }
423            } else {
424                // First node — just set as entry point.
425                // Ensure neighbor vecs have the right length.
426                let total = self.top_layer.max(node_level) + 1;
427                let extra = total.saturating_sub(self.nodes[node_id].neighbors.len());
428                self.nodes[node_id]
429                    .neighbors
430                    .extend(std::iter::repeat_with(Vec::new).take(extra));
431                self.top_layer = node_level;
432                self.entry_point = Some(node_id);
433            }
434        }
435
436        // Update stats
437        let time_us = (batch_len as f64 * 12.5) + 100.0; // Simulated GPU time
438        let prev_batches = self.batch_stats.batches_processed as f64;
439        let new_avg = if prev_batches > 0.0 {
440            (self.batch_stats.avg_batch_us * prev_batches + time_us) / (prev_batches + 1.0)
441        } else {
442            time_us
443        };
444
445        self.batch_stats.batches_processed += 1;
446        self.batch_stats.vectors_inserted += batch_len as u64;
447        self.batch_stats.distance_computations += dist_count;
448        self.batch_stats.avg_batch_us = new_avg;
449
450        Ok(())
451    }
452
453    /// Greedy single-hop search at `layer` (used during graph construction descent).
454    fn greedy_search_layer(&self, query: &[f32], entry: usize, layer: usize) -> Result<usize> {
455        let mut current = entry;
456        let mut current_dist = self.euclidean_sq(query, &self.nodes[current].vector);
457
458        loop {
459            let mut improved = false;
460            for &neighbor in &self.nodes[current].neighbors[layer] {
461                if neighbor >= self.nodes.len() {
462                    continue;
463                }
464                let d = self.euclidean_sq(query, &self.nodes[neighbor].vector);
465                if d < current_dist {
466                    current_dist = d;
467                    current = neighbor;
468                    improved = true;
469                }
470            }
471            if !improved {
472                break;
473            }
474        }
475        Ok(current)
476    }
477
478    /// Greedy search during graph construction (mutable, tracks distance count).
479    fn greedy_search_layer_mut(
480        &self,
481        query: &[f32],
482        entry: usize,
483        layer: usize,
484        dist_count: &mut u64,
485    ) -> Result<usize> {
486        let mut current = entry;
487        *dist_count += 1;
488        let mut current_dist = self.euclidean_sq(query, &self.nodes[current].vector);
489
490        loop {
491            let mut improved = false;
492            let neighbors = self.nodes[current].neighbors[layer].clone();
493            for neighbor in neighbors {
494                if neighbor >= self.nodes.len() {
495                    continue;
496                }
497                *dist_count += 1;
498                let d = self.euclidean_sq(query, &self.nodes[neighbor].vector);
499                if d < current_dist {
500                    current_dist = d;
501                    current = neighbor;
502                    improved = true;
503                }
504            }
505            if !improved {
506                break;
507            }
508        }
509        Ok(current)
510    }
511
512    /// Beam (ef) search at a specific layer — returns ordered candidate list (closest first).
513    fn beam_search_layer(
514        &self,
515        query: &[f32],
516        entry: usize,
517        layer: usize,
518        ef: usize,
519    ) -> Result<Vec<Candidate>> {
520        let mut dummy = 0u64;
521        self.beam_search_layer_with_count(query, entry, layer, ef, &mut dummy)
522    }
523
524    /// Beam search with distance counter (used during construction).
525    fn beam_search_layer_with_count(
526        &self,
527        query: &[f32],
528        entry: usize,
529        layer: usize,
530        ef: usize,
531        dist_count: &mut u64,
532    ) -> Result<Vec<Candidate>> {
533        if entry >= self.nodes.len() {
534            return Ok(Vec::new());
535        }
536
537        let mut visited: HashSet<usize> = HashSet::new();
538        visited.insert(entry);
539
540        *dist_count += 1;
541        let d_entry = self.euclidean_sq(query, &self.nodes[entry].vector);
542
543        // candidates = max-heap (farthest first, for pruning)
544        let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
545        // to_visit = min-heap (closest first, for expansion)
546        let mut to_visit: BinaryHeap<std::cmp::Reverse<Candidate>> = BinaryHeap::new();
547
548        candidates.push(Candidate {
549            dist: d_entry,
550            id: entry,
551        });
552        to_visit.push(std::cmp::Reverse(Candidate {
553            dist: d_entry,
554            id: entry,
555        }));
556
557        while let Some(std::cmp::Reverse(current)) = to_visit.pop() {
558            // Terminate if current candidate is farther than worst in result set
559            if let Some(worst) = candidates.peek() {
560                if current.dist > worst.dist {
561                    break;
562                }
563            }
564
565            let neighbors = if layer < self.nodes[current.id].neighbors.len() {
566                self.nodes[current.id].neighbors[layer].clone()
567            } else {
568                Vec::new()
569            };
570
571            for neighbor in neighbors {
572                if neighbor >= self.nodes.len() || visited.contains(&neighbor) {
573                    continue;
574                }
575                visited.insert(neighbor);
576
577                *dist_count += 1;
578                let d = self.euclidean_sq(query, &self.nodes[neighbor].vector);
579                let worst_dist = candidates.peek().map(|c| c.dist).unwrap_or(f32::MAX);
580
581                if d < worst_dist || candidates.len() < ef {
582                    candidates.push(Candidate {
583                        dist: d,
584                        id: neighbor,
585                    });
586                    to_visit.push(std::cmp::Reverse(Candidate {
587                        dist: d,
588                        id: neighbor,
589                    }));
590                    if candidates.len() > ef {
591                        candidates.pop();
592                    }
593                }
594            }
595        }
596
597        // Convert to sorted vec (closest first)
598        let mut result: Vec<Candidate> = candidates.into_vec();
599        result.sort_by(|a, b| {
600            a.dist
601                .partial_cmp(&b.dist)
602                .unwrap_or(std::cmp::Ordering::Equal)
603        });
604        Ok(result)
605    }
606
607    /// Prune a node's neighbor list at `layer` to at most `cap` entries.
608    fn prune_connections(&mut self, node_id: usize, layer: usize, cap: usize) {
609        if self.nodes[node_id].neighbors[layer].len() > cap {
610            // Collect with distances from node's own vector
611            let node_vec = self.nodes[node_id].vector.clone();
612            let mut with_dist: Vec<(usize, f32)> = self.nodes[node_id].neighbors[layer]
613                .iter()
614                .filter(|&&n| n < self.nodes.len())
615                .map(|&n| (n, self.euclidean_sq(&node_vec, &self.nodes[n].vector)))
616                .collect();
617            with_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
618            with_dist.truncate(cap);
619            self.nodes[node_id].neighbors[layer] = with_dist.into_iter().map(|(n, _)| n).collect();
620        }
621    }
622
623    /// Squared Euclidean distance (used as distance metric; no sqrt needed for ordering).
624    #[inline]
625    fn euclidean_sq(&self, a: &[f32], b: &[f32]) -> f32 {
626        a.iter()
627            .zip(b.iter())
628            .map(|(&x, &y)| {
629                let d = x - y;
630                d * d
631            })
632            .sum()
633    }
634}
635
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Build an index with small, test-friendly HNSW parameters.
    fn make_index(batch_size: usize) -> GpuHnswIndex {
        let config = GpuHnswConfig {
            batch_size,
            max_connections: 8,
            max_connections_layer0: 16,
            ef_construction: 20,
            ef_search: 16,
            gpu_workers: 2,
            ..Default::default()
        };
        GpuHnswIndex::new(config)
    }

    fn vec2(x: f32, y: f32) -> Vec<f32> {
        vec![x, y]
    }

    // ── basic functionality ────────────────────────────────────────────────

    #[test]
    fn test_new_index_is_empty() {
        let index = make_index(4);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.pending_count(), 0);
    }

    #[test]
    fn test_insert_pending_accumulates() -> Result<()> {
        let mut index = make_index(8);
        index.insert("a".to_string(), vec2(1.0, 0.0))?;
        index.insert("b".to_string(), vec2(0.0, 1.0))?;
        assert_eq!(index.pending_count(), 2);
        assert_eq!(index.len(), 0); // Not yet flushed
        Ok(())
    }

    #[test]
    fn test_auto_flush_on_batch_full() -> Result<()> {
        let mut index = make_index(3);
        for i in 0..3 {
            index.insert(format!("v{}", i), vec![i as f32, 0.0])?;
        }
        // Batch of 3 triggers auto-flush
        assert_eq!(index.len(), 3);
        assert_eq!(index.pending_count(), 0);
        Ok(())
    }

    #[test]
    fn test_manual_flush() -> Result<()> {
        let mut index = make_index(16);
        index.insert("x".to_string(), vec2(1.0, 1.0))?;
        assert_eq!(index.pending_count(), 1);
        index.flush()?;
        assert_eq!(index.len(), 1);
        assert_eq!(index.pending_count(), 0);
        Ok(())
    }

    #[test]
    fn test_search_empty_returns_empty() -> Result<()> {
        let index = make_index(4);
        let result = index.search(&[1.0, 0.0], 5)?;
        assert!(result.is_empty());
        Ok(())
    }

    #[test]
    fn test_search_single_vector() -> Result<()> {
        let mut index = make_index(4);
        index.insert("only".to_string(), vec2(1.0, 0.0))?;
        index.flush()?;
        let result = index.search(&[1.0, 0.0], 1)?;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "only");
        Ok(())
    }

    #[test]
    fn test_search_nearest_neighbour() -> Result<()> {
        let mut index = make_index(8);
        index.insert("origin".to_string(), vec2(0.0, 0.0))?;
        index.insert("right".to_string(), vec2(10.0, 0.0))?;
        index.insert("up".to_string(), vec2(0.0, 10.0))?;
        index.flush()?;

        // Query near origin
        let result = index.search(&[0.1, 0.1], 1)?;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "origin");
        Ok(())
    }

    #[test]
    fn test_search_top_k_ordering() -> Result<()> {
        let mut index = make_index(4);
        for i in 0..4 {
            index.insert(format!("v{}", i), vec![i as f32 * 2.0, 0.0])?;
        }
        index.flush()?;

        let result = index.search(&[0.0, 0.0], 2)?;
        assert!(result.len() <= 2);
        // Closest should come first
        if result.len() == 2 {
            assert!(
                result[0].1 <= result[1].1,
                "Results should be ordered by distance"
            );
        }
        Ok(())
    }

    #[test]
    fn test_duplicate_uri_rejected() -> Result<()> {
        let mut index = make_index(8);
        index.insert("dup".to_string(), vec2(1.0, 0.0))?;
        index.flush()?;
        let err = index.insert("dup".to_string(), vec2(2.0, 0.0));
        assert!(err.is_err());
        Ok(())
    }

    #[test]
    fn test_stats_accumulate() -> Result<()> {
        let mut index = make_index(4);
        for i in 0..8 {
            index.insert(format!("v{}", i), vec![i as f32, 0.0])?;
        }
        index.flush()?; // flush any remainder
        let stats = index.stats();
        assert_eq!(stats.vector_count, 8);
        assert!(stats.batch_stats.batches_processed >= 2);
        assert_eq!(stats.batch_stats.vectors_inserted, 8);
        Ok(())
    }

    #[test]
    fn test_stats_avg_batch_time_positive() -> Result<()> {
        let mut index = make_index(2);
        index.insert("a".to_string(), vec2(0.0, 0.0))?;
        index.insert("b".to_string(), vec2(1.0, 0.0))?;
        let stats = index.stats();
        assert!(stats.batch_stats.avg_batch_us > 0.0);
        Ok(())
    }

    #[test]
    fn test_larger_dataset_correctness() -> Result<()> {
        let mut index = make_index(10);
        // Add 50 vectors in a line along x-axis
        for i in 0..50 {
            index.insert(format!("v{}", i), vec![i as f32, 0.0])?;
        }
        index.flush()?;

        assert_eq!(index.len(), 50);

        // Nearest to x=25 should be v25
        let result = index.search(&[25.0, 0.0], 3)?;
        assert!(!result.is_empty());
        // The closest vector should be very close to 25.0
        assert!(result[0].1 < 2.0_f32);
        Ok(())
    }

    #[test]
    fn test_multi_batch_flush_consistency() -> Result<()> {
        let mut index = make_index(5);
        for i in 0..20 {
            index.insert(format!("v{}", i), vec![i as f32, (i % 3) as f32])?;
        }
        index.flush()?;
        let stats = index.stats();
        assert_eq!(stats.vector_count, 20);
        assert!(stats.batch_stats.batches_processed >= 4);
        Ok(())
    }

    #[test]
    fn test_config_accessors() {
        let config = GpuHnswConfig {
            batch_size: 32,
            max_connections: 12,
            ..Default::default()
        };
        let index = GpuHnswIndex::new(config);
        assert_eq!(index.config().batch_size, 32);
        assert_eq!(index.config().max_connections, 12);
    }

    #[test]
    fn test_gpu_workers_default() {
        let config = GpuHnswConfig::default();
        assert_eq!(config.gpu_workers, 4);
        assert_eq!(config.batch_size, 64);
    }

    #[test]
    fn test_single_dimension_vectors() -> Result<()> {
        let mut index = make_index(4);
        index.insert("a".to_string(), vec![1.0])?;
        index.insert("b".to_string(), vec![5.0])?;
        index.insert("c".to_string(), vec![10.0])?;
        index.insert("d".to_string(), vec![3.0])?;
        index.flush()?;
        let result = index.search(&[4.5], 2)?;
        assert!(!result.is_empty());
        Ok(())
    }

    #[test]
    fn test_high_dimensional_vectors() -> Result<()> {
        let dim = 128;
        let mut index = make_index(8);
        for i in 0..16 {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32 * 0.01).collect();
            index.insert(format!("v{}", i), v)?;
        }
        index.flush()?;
        let query: Vec<f32> = (0..dim).map(|d| d as f32 * 0.01).collect();
        let result = index.search(&query, 3)?;
        assert!(!result.is_empty());
        assert_eq!(result[0].0, "v0"); // v0 is at 0..dim * 0.01
        Ok(())
    }

    #[test]
    fn test_search_returns_at_most_k() -> Result<()> {
        let mut index = make_index(4);
        for i in 0..10 {
            index.insert(format!("v{}", i), vec![i as f32])?;
        }
        index.flush()?;
        let result = index.search(&[5.0], 3)?;
        assert!(result.len() <= 3);
        Ok(())
    }

    #[test]
    fn test_distance_computations_counted() -> Result<()> {
        let mut index = make_index(4);
        for i in 0..8 {
            index.insert(format!("v{}", i), vec![i as f32, 0.0])?;
        }
        index.flush()?;
        let stats = index.stats();
        // Some distance computations should have occurred during construction
        assert!(stats.batch_stats.distance_computations > 0);
        Ok(())
    }

    #[test]
    fn test_pending_not_searched() -> Result<()> {
        let mut index = make_index(100); // Large batch so nothing auto-flushes
        index.insert("pending".to_string(), vec2(0.0, 0.0))?;
        // pending_count = 1, len = 0
        assert_eq!(index.pending_count(), 1);
        assert_eq!(index.len(), 0);
        // Search on empty committed graph
        let result = index.search(&[0.0, 0.0], 1)?;
        assert!(result.is_empty());
        Ok(())
    }

    #[test]
    fn test_flush_empty_pending_noop() -> Result<()> {
        let mut index = make_index(4);
        index.insert("a".to_string(), vec2(1.0, 0.0))?;
        index.flush()?;
        // Second flush on empty pending
        index.flush()?;
        assert_eq!(index.len(), 1);
        Ok(())
    }

    #[test]
    fn test_layer_count_in_stats() -> Result<()> {
        let mut index = make_index(4);
        index.insert("a".to_string(), vec2(0.0, 0.0))?;
        index.flush()?;
        let stats = index.stats();
        assert!(stats.layer_count >= 1);
        Ok(())
    }
}