// oxirs_vec/hnsw/index.rs

1//! Main HNSW index implementation
2
3use crate::hnsw::query_cache::{QueryCache, QueryCacheConfig};
4use crate::hnsw::{HnswConfig, HnswPerformanceStats, Node};
5use crate::{Vector, VectorIndex};
6use anyhow::Result;
7use std::collections::HashMap;
8use std::sync::atomic::AtomicU64;
9#[cfg(feature = "gpu")]
10use std::sync::Arc;
11
12#[cfg(feature = "gpu")]
13use crate::gpu::GpuAccelerator;
14
/// HNSW index implementation
pub struct HnswIndex {
    /// Index configuration (graph parameters plus GPU/caching flags).
    config: HnswConfig,
    /// All graph nodes; a node's position in this vector is its id.
    /// Removed nodes are tombstoned in place, not popped (see `remove`).
    nodes: Vec<Node>,
    /// Maps a resource URI to its node id (index into `nodes`).
    uri_to_id: HashMap<String, usize>,
    /// Id of the search entry node (`None` while the index is empty).
    entry_point: Option<usize>,
    /// Level-sampling multiplier, initialized to 1/ln(2) — presumably used
    /// by the construction module when drawing random levels; confirm there.
    level_multiplier: f64,
    /// State of the simple deterministic PRNG (seeded with 42 in `new`).
    rng_state: u64,
    /// Performance statistics
    stats: HnswPerformanceStats,
    /// Distance calculation count (for metrics)
    distance_calculations: AtomicU64,
    /// Query result cache for improved performance
    query_cache: Option<QueryCache>,
    /// GPU accelerator for CUDA-accelerated operations
    #[cfg(feature = "gpu")]
    gpu_accelerator: Option<Arc<GpuAccelerator>>,
    /// Multi-GPU accelerators for distributed computation
    #[cfg(feature = "gpu")]
    multi_gpu_accelerators: Vec<Arc<GpuAccelerator>>,
}
36
37impl HnswIndex {
38    pub fn new(config: HnswConfig) -> Result<Self> {
39        // Initialize GPU accelerators if enabled
40        #[cfg(feature = "gpu")]
41        let (gpu_accelerator, multi_gpu_accelerators) = if config.enable_gpu {
42            let gpu_config = config.gpu_config.clone().unwrap_or_default();
43
44            if config.enable_multi_gpu && gpu_config.preferred_gpu_ids.len() > 1 {
45                // Initialize multi-GPU setup
46                let mut accelerators = Vec::new();
47                for &gpu_id in &gpu_config.preferred_gpu_ids {
48                    let mut gpu_conf = gpu_config.clone();
49                    gpu_conf.device_id = gpu_id;
50                    let accelerator = GpuAccelerator::new(gpu_conf)?;
51                    accelerators.push(Arc::new(accelerator));
52                }
53                (None, accelerators)
54            } else {
55                // Single GPU setup
56                let accelerator = GpuAccelerator::new(gpu_config)?;
57                (Some(Arc::new(accelerator)), Vec::new())
58            }
59        } else {
60            (None, Vec::new())
61        };
62
63        // Initialize query cache with default configuration
64        let query_cache = Some(QueryCache::new(QueryCacheConfig::default()));
65
66        Ok(Self {
67            config,
68            nodes: Vec::new(),
69            uri_to_id: HashMap::new(),
70            entry_point: None,
71            level_multiplier: 1.0 / (2.0_f64).ln(),
72            rng_state: 42, // Simple deterministic seed
73            stats: HnswPerformanceStats::default(),
74            distance_calculations: AtomicU64::new(0),
75            query_cache,
76            #[cfg(feature = "gpu")]
77            gpu_accelerator,
78            #[cfg(feature = "gpu")]
79            multi_gpu_accelerators,
80        })
81    }
82
83    /// Create a new HNSW index without GPU acceleration (for compatibility)
84    pub fn new_cpu_only(config: HnswConfig) -> Self {
85        let mut cpu_config = config;
86        cpu_config.enable_gpu = false;
87        cpu_config.enable_multi_gpu = false;
88
89        // Initialize query cache with default configuration
90        let query_cache = Some(QueryCache::new(QueryCacheConfig::default()));
91
92        Self {
93            config: cpu_config,
94            nodes: Vec::new(),
95            uri_to_id: HashMap::new(),
96            entry_point: None,
97            level_multiplier: 1.0 / (2.0_f64).ln(),
98            rng_state: 42,
99            stats: HnswPerformanceStats::default(),
100            distance_calculations: AtomicU64::new(0),
101            query_cache,
102            #[cfg(feature = "gpu")]
103            gpu_accelerator: None,
104            #[cfg(feature = "gpu")]
105            multi_gpu_accelerators: Vec::new(),
106        }
107    }
108
109    /// Enable query result caching with custom configuration
110    pub fn enable_query_cache(&mut self, config: QueryCacheConfig) {
111        self.query_cache = Some(QueryCache::new(config));
112    }
113
114    /// Disable query result caching
115    pub fn disable_query_cache(&mut self) {
116        self.query_cache = None;
117    }
118
119    /// Get query cache statistics if caching is enabled
120    pub fn get_query_cache_stats(&self) -> Option<crate::hnsw::query_cache::QueryCacheStats> {
121        self.query_cache.as_ref().map(|cache| cache.get_stats())
122    }
123
124    /// Clear query cache if caching is enabled
125    pub fn clear_query_cache(&self) {
126        if let Some(ref cache) = self.query_cache {
127            cache.clear();
128        }
129    }
130
131    /// Get reference to query cache
132    pub(crate) fn query_cache(&self) -> &Option<QueryCache> {
133        &self.query_cache
134    }
135
136    /// Get the URI to ID mapping
137    pub fn uri_to_id(&self) -> &HashMap<String, usize> {
138        &self.uri_to_id
139    }
140
141    /// Get mutable URI to ID mapping
142    pub fn uri_to_id_mut(&mut self) -> &mut HashMap<String, usize> {
143        &mut self.uri_to_id
144    }
145
146    /// Get the nodes
147    pub fn nodes(&self) -> &Vec<Node> {
148        &self.nodes
149    }
150
151    /// Get mutable nodes
152    pub fn nodes_mut(&mut self) -> &mut Vec<Node> {
153        &mut self.nodes
154    }
155
156    /// Get the entry point
157    pub fn entry_point(&self) -> Option<usize> {
158        self.entry_point
159    }
160
161    /// Set the entry point
162    pub fn set_entry_point(&mut self, entry_point: Option<usize>) {
163        self.entry_point = entry_point;
164    }
165
166    /// Get the configuration
167    pub fn config(&self) -> &HnswConfig {
168        &self.config
169    }
170
171    /// Get performance statistics
172    pub fn get_stats(&self) -> &HnswPerformanceStats {
173        &self.stats
174    }
175
176    /// Check if GPU acceleration is available and enabled
177    #[cfg(feature = "gpu")]
178    pub fn is_gpu_available(&self) -> bool {
179        self.config.enable_gpu
180            && (self.gpu_accelerator.is_some() || !self.multi_gpu_accelerators.is_empty())
181    }
182
183    #[cfg(not(feature = "gpu"))]
184    pub fn is_gpu_available(&self) -> bool {
185        false
186    }
187
188    /// Get GPU performance statistics
189    #[cfg(feature = "gpu")]
190    pub fn get_gpu_stats(&self) -> Option<crate::gpu::GpuPerformanceStats> {
191        if let Some(ref _accelerator) = self.gpu_accelerator {
192            // Would need to implement stats retrieval in GpuAccelerator
193            None // Placeholder
194        } else {
195            None
196        }
197    }
198
199    /// Get reference to GPU accelerator
200    #[cfg(feature = "gpu")]
201    pub fn gpu_accelerator(&self) -> Option<&Arc<GpuAccelerator>> {
202        self.gpu_accelerator.as_ref()
203    }
204
205    /// Get reference to multi-GPU accelerators
206    #[cfg(feature = "gpu")]
207    pub fn multi_gpu_accelerators(&self) -> &Vec<Arc<GpuAccelerator>> {
208        &self.multi_gpu_accelerators
209    }
210
211    /// Get the number of nodes in the index
212    pub fn len(&self) -> usize {
213        self.nodes.len()
214    }
215
216    /// Check if the index is empty
217    pub fn is_empty(&self) -> bool {
218        self.nodes.is_empty()
219    }
220
221    // Duplicate methods removed - already defined above
222
223    /// Get mutable reference to stats
224    pub fn stats_mut(&mut self) -> &mut HnswPerformanceStats {
225        &mut self.stats
226    }
227
228    /// Get level multiplier
229    pub fn level_multiplier(&self) -> f64 {
230        self.level_multiplier
231    }
232
233    /// Get mutable reference to RNG state
234    pub fn rng_state_mut(&mut self) -> &mut u64 {
235        &mut self.rng_state
236    }
237
238    /// Get RNG state
239    pub fn rng_state(&self) -> u64 {
240        self.rng_state
241    }
242}
243
impl VectorIndex for HnswIndex {
    /// Insert a vector under `uri`.
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        // Use the add_vector implementation from construction module
        self.add_vector(uri, vector)
    }

    /// Find the `k` nearest neighbors of `query`.
    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        // Use the proper HNSW search algorithm from search.rs
        // This implements hierarchical navigable small world graph traversal.
        // The fully-qualified call makes explicit that this dispatches to the
        // inherent `search_knn` (search.rs), not back into this trait method.
        HnswIndex::search_knn(self, query, k)
    }

    /// Find all vectors within `threshold` distance of `query`.
    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        // Use the proper HNSW range search algorithm from search.rs
        // This implements distance-based filtering with graph traversal
        HnswIndex::range_search(self, query, threshold)
    }

    /// Look up the stored vector for `uri`, or `None` if the URI is unknown.
    fn get_vector(&self, uri: &str) -> Option<&Vector> {
        self.uri_to_id
            .get(uri)
            .and_then(|&id| self.nodes.get(id))
            .map(|node| &node.vector)
    }
}
269
270impl HnswIndex {
271    /// Remove a vector by its URI (not part of VectorIndex trait)
272    pub fn remove(&mut self, uri: &str) -> Result<()> {
273        // Implementation of vector removal from HNSW index
274
275        // Find the node ID for the URI
276        let node_id = if let Some(&id) = self.uri_to_id.get(uri) {
277            id
278        } else {
279            return Err(anyhow::anyhow!("URI not found: {}", uri));
280        };
281
282        // Remove the node from all its connections at all levels
283        if let Some(node) = self.nodes.get(node_id) {
284            let node_connections = node.connections.clone();
285
286            // Remove this node from all connected nodes
287            for (level, connections) in node_connections.iter().enumerate() {
288                for &connected_id in connections {
289                    if let Some(connected_node) = self.nodes.get_mut(connected_id) {
290                        connected_node.remove_connection(level, node_id);
291                    }
292                }
293            }
294        }
295
296        // If this node was the entry point, find a new entry point
297        if self.entry_point == Some(node_id) {
298            self.entry_point = None;
299
300            // Find a node with the highest level as the new entry point
301            let mut highest_level = 0;
302            let mut new_entry_point = None;
303
304            for (id, node) in self.nodes.iter().enumerate() {
305                if id != node_id && node.level() >= highest_level {
306                    highest_level = node.level();
307                    new_entry_point = Some(id);
308                }
309            }
310
311            self.entry_point = new_entry_point;
312        }
313
314        // Remove the node from URI mapping
315        self.uri_to_id.remove(uri);
316
317        // Mark the node as removed (we don't actually remove it to avoid ID shifts)
318        // In a production implementation, you might use a tombstone approach
319        // or compact the index periodically
320        if let Some(node) = self.nodes.get_mut(node_id) {
321            node.connections.clear();
322            // We could add a "deleted" flag here if needed
323        }
324
325        // Update statistics
326        self.stats
327            .total_deletions
328            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
329
330        Ok(())
331    }
332
333    /// Update a vector by its URI (not part of VectorIndex trait)
334    pub fn update(&mut self, uri: String, vector: Vector) -> Result<()> {
335        // Implementation of vector update in HNSW index
336        // This is a simplified approach: remove and re-add the vector
337
338        // Check if the URI exists
339        if !self.uri_to_id.contains_key(&uri) {
340            return Err(anyhow::anyhow!("URI not found: {}", uri));
341        }
342
343        // Store the current connections before removal for potential optimization
344        let node_id = self.uri_to_id[&uri];
345        let _old_connections = self.nodes.get(node_id).map(|node| node.connections.clone());
346
347        // Remove the old vector
348        self.remove(&uri)?;
349
350        // Add the new vector with the same URI
351        self.insert(uri.clone(), vector)?;
352
353        // Update statistics
354        self.stats
355            .total_updates
356            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
357
358        // In a more sophisticated implementation, we could:
359        // 1. Check if the vector is similar enough to keep some connections
360        // 2. Incrementally update the graph structure
361        // 3. Use lazy updates to batch multiple updates
362
363        Ok(())
364    }
365
366    /// Clear all vectors from the index (not part of VectorIndex trait)
367    pub fn clear(&mut self) -> Result<()> {
368        self.nodes.clear();
369        self.uri_to_id.clear();
370        self.entry_point = None;
371        Ok(())
372    }
373
374    /// Get the number of vectors in the index (not part of VectorIndex trait)
375    pub fn size(&self) -> usize {
376        self.nodes.len()
377    }
378}