oxirs_vec/hnsw/
index.rs

1//! Main HNSW index implementation
2
3use crate::hnsw::{HnswConfig, HnswPerformanceStats, Node};
4use crate::{Vector, VectorIndex};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::atomic::AtomicU64;
8#[cfg(feature = "gpu")]
9use std::sync::Arc;
10
11#[cfg(feature = "gpu")]
12use crate::gpu::GpuAccelerator;
13
14/// HNSW index implementation
/// HNSW (Hierarchical Navigable Small World) index over URI-keyed vectors.
///
/// Vectors live in [`Node`]s stored in `nodes` and are addressed by their position in
/// that `Vec` (the node id); `uri_to_id` maps each live URI to its node id.
pub struct HnswIndex {
    /// Index configuration (distance metric, GPU flags, ...).
    config: HnswConfig,
    /// Graph nodes indexed by node id. `remove` leaves removed nodes in place (with
    /// their connections cleared) so that ids of other nodes never shift.
    nodes: Vec<Node>,
    /// Maps each live URI to its node id in `nodes`.
    uri_to_id: HashMap<String, usize>,
    /// Node id used as the graph entry point; `None` when there is none.
    entry_point: Option<usize>,
    /// Level-generation multiplier (the constructors set it to `1 / ln 2`).
    level_multiplier: f64,
    /// Pseudo-random generator state (the constructors seed it with 42);
    /// presumably consumed by level sampling in the construction code — confirm.
    rng_state: u64,
    /// Performance statistics
    stats: HnswPerformanceStats,
    /// Distance calculation count (for metrics)
    /// NOTE(review): nothing in this file increments it — confirm where it is updated.
    distance_calculations: AtomicU64,
    /// GPU accelerator for CUDA-accelerated operations
    #[cfg(feature = "gpu")]
    gpu_accelerator: Option<Arc<GpuAccelerator>>,
    /// Multi-GPU accelerators for distributed computation
    #[cfg(feature = "gpu")]
    multi_gpu_accelerators: Vec<Arc<GpuAccelerator>>,
}
33
34impl HnswIndex {
35    pub fn new(config: HnswConfig) -> Result<Self> {
36        // Initialize GPU accelerators if enabled
37        #[cfg(feature = "gpu")]
38        let (gpu_accelerator, multi_gpu_accelerators) = if config.enable_gpu {
39            let gpu_config = config.gpu_config.clone().unwrap_or_default();
40
41            if config.enable_multi_gpu && gpu_config.preferred_gpu_ids.len() > 1 {
42                // Initialize multi-GPU setup
43                let mut accelerators = Vec::new();
44                for &gpu_id in &gpu_config.preferred_gpu_ids {
45                    let mut gpu_conf = gpu_config.clone();
46                    gpu_conf.device_id = gpu_id;
47                    let accelerator = GpuAccelerator::new(gpu_conf)?;
48                    accelerators.push(Arc::new(accelerator));
49                }
50                (None, accelerators)
51            } else {
52                // Single GPU setup
53                let accelerator = GpuAccelerator::new(gpu_config)?;
54                (Some(Arc::new(accelerator)), Vec::new())
55            }
56        } else {
57            (None, Vec::new())
58        };
59
60        Ok(Self {
61            config,
62            nodes: Vec::new(),
63            uri_to_id: HashMap::new(),
64            entry_point: None,
65            level_multiplier: 1.0 / (2.0_f64).ln(),
66            rng_state: 42, // Simple deterministic seed
67            stats: HnswPerformanceStats::default(),
68            distance_calculations: AtomicU64::new(0),
69            #[cfg(feature = "gpu")]
70            gpu_accelerator,
71            #[cfg(feature = "gpu")]
72            multi_gpu_accelerators,
73        })
74    }
75
76    /// Create a new HNSW index without GPU acceleration (for compatibility)
77    pub fn new_cpu_only(config: HnswConfig) -> Self {
78        let mut cpu_config = config;
79        cpu_config.enable_gpu = false;
80        cpu_config.enable_multi_gpu = false;
81
82        Self {
83            config: cpu_config,
84            nodes: Vec::new(),
85            uri_to_id: HashMap::new(),
86            entry_point: None,
87            level_multiplier: 1.0 / (2.0_f64).ln(),
88            rng_state: 42,
89            stats: HnswPerformanceStats::default(),
90            distance_calculations: AtomicU64::new(0),
91            #[cfg(feature = "gpu")]
92            gpu_accelerator: None,
93            #[cfg(feature = "gpu")]
94            multi_gpu_accelerators: Vec::new(),
95        }
96    }
97
98    /// Get the URI to ID mapping
99    pub fn uri_to_id(&self) -> &HashMap<String, usize> {
100        &self.uri_to_id
101    }
102
103    /// Get mutable URI to ID mapping
104    pub fn uri_to_id_mut(&mut self) -> &mut HashMap<String, usize> {
105        &mut self.uri_to_id
106    }
107
108    /// Get the nodes
109    pub fn nodes(&self) -> &Vec<Node> {
110        &self.nodes
111    }
112
113    /// Get mutable nodes
114    pub fn nodes_mut(&mut self) -> &mut Vec<Node> {
115        &mut self.nodes
116    }
117
118    /// Get the entry point
119    pub fn entry_point(&self) -> Option<usize> {
120        self.entry_point
121    }
122
123    /// Set the entry point
124    pub fn set_entry_point(&mut self, entry_point: Option<usize>) {
125        self.entry_point = entry_point;
126    }
127
128    /// Get the configuration
129    pub fn config(&self) -> &HnswConfig {
130        &self.config
131    }
132
133    /// Get performance statistics
134    pub fn get_stats(&self) -> &HnswPerformanceStats {
135        &self.stats
136    }
137
138    /// Check if GPU acceleration is available and enabled
139    #[cfg(feature = "gpu")]
140    pub fn is_gpu_available(&self) -> bool {
141        self.config.enable_gpu
142            && (self.gpu_accelerator.is_some() || !self.multi_gpu_accelerators.is_empty())
143    }
144
145    #[cfg(not(feature = "gpu"))]
146    pub fn is_gpu_available(&self) -> bool {
147        false
148    }
149
150    /// Get GPU performance statistics
151    #[cfg(feature = "gpu")]
152    pub fn get_gpu_stats(&self) -> Option<crate::gpu::GpuPerformanceStats> {
153        if let Some(ref accelerator) = self.gpu_accelerator {
154            // Would need to implement stats retrieval in GpuAccelerator
155            None // Placeholder
156        } else {
157            None
158        }
159    }
160
161    /// Get reference to GPU accelerator
162    #[cfg(feature = "gpu")]
163    pub fn gpu_accelerator(&self) -> Option<&Arc<GpuAccelerator>> {
164        self.gpu_accelerator.as_ref()
165    }
166
167    /// Get reference to multi-GPU accelerators
168    #[cfg(feature = "gpu")]
169    pub fn multi_gpu_accelerators(&self) -> &Vec<Arc<GpuAccelerator>> {
170        &self.multi_gpu_accelerators
171    }
172
173    /// Get the number of nodes in the index
174    pub fn len(&self) -> usize {
175        self.nodes.len()
176    }
177
178    /// Check if the index is empty
179    pub fn is_empty(&self) -> bool {
180        self.nodes.is_empty()
181    }
182
183    // Duplicate methods removed - already defined above
184
185    /// Get mutable reference to stats
186    pub fn stats_mut(&mut self) -> &mut HnswPerformanceStats {
187        &mut self.stats
188    }
189
190    /// Get level multiplier
191    pub fn level_multiplier(&self) -> f64 {
192        self.level_multiplier
193    }
194
195    /// Get mutable reference to RNG state
196    pub fn rng_state_mut(&mut self) -> &mut u64 {
197        &mut self.rng_state
198    }
199
200    /// Get RNG state
201    pub fn rng_state(&self) -> u64 {
202        self.rng_state
203    }
204}
205
206impl VectorIndex for HnswIndex {
207    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
208        // Use the add_vector implementation from construction module
209        self.add_vector(uri, vector)
210    }
211
212    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
213        if self.nodes.is_empty() || self.entry_point.is_none() {
214            return Ok(Vec::new());
215        }
216
217        // Simple brute force search for now (placeholder)
218        // TODO: Implement proper HNSW search algorithm
219        let mut results = Vec::new();
220
221        for (uri, &node_id) in &self.uri_to_id {
222            if let Some(node) = self.nodes.get(node_id) {
223                // Calculate distance using configured metric
224                let distance = self.config.metric.distance(query, &node.vector)?;
225                results.push((uri.clone(), distance));
226            }
227        }
228
229        // Sort by distance and take k closest
230        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
231        results.truncate(k);
232
233        Ok(results)
234    }
235
236    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
237        if self.nodes.is_empty() || self.entry_point.is_none() {
238            return Ok(Vec::new());
239        }
240
241        // Simple brute force threshold search for now
242        // TODO: Implement proper HNSW range search algorithm
243        let mut results = Vec::new();
244
245        for (uri, &node_id) in &self.uri_to_id {
246            if let Some(node) = self.nodes.get(node_id) {
247                // Calculate distance using configured metric
248                let distance = self.config.metric.distance(query, &node.vector)?;
249                if distance <= threshold {
250                    results.push((uri.clone(), distance));
251                }
252            }
253        }
254
255        // Sort by distance
256        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
257
258        Ok(results)
259    }
260
261    fn get_vector(&self, uri: &str) -> Option<&Vector> {
262        self.uri_to_id
263            .get(uri)
264            .and_then(|&id| self.nodes.get(id))
265            .map(|node| &node.vector)
266    }
267}
268
269impl HnswIndex {
270    /// Remove a vector by its URI (not part of VectorIndex trait)
271    pub fn remove(&mut self, uri: &str) -> Result<()> {
272        // Implementation of vector removal from HNSW index
273
274        // Find the node ID for the URI
275        let node_id = if let Some(&id) = self.uri_to_id.get(uri) {
276            id
277        } else {
278            return Err(anyhow::anyhow!("URI not found: {}", uri));
279        };
280
281        // Remove the node from all its connections at all levels
282        if let Some(node) = self.nodes.get(node_id) {
283            let node_connections = node.connections.clone();
284
285            // Remove this node from all connected nodes
286            for (level, connections) in node_connections.iter().enumerate() {
287                for &connected_id in connections {
288                    if let Some(connected_node) = self.nodes.get_mut(connected_id) {
289                        connected_node.remove_connection(level, node_id);
290                    }
291                }
292            }
293        }
294
295        // If this node was the entry point, find a new entry point
296        if self.entry_point == Some(node_id) {
297            self.entry_point = None;
298
299            // Find a node with the highest level as the new entry point
300            let mut highest_level = 0;
301            let mut new_entry_point = None;
302
303            for (id, node) in self.nodes.iter().enumerate() {
304                if id != node_id && node.level() >= highest_level {
305                    highest_level = node.level();
306                    new_entry_point = Some(id);
307                }
308            }
309
310            self.entry_point = new_entry_point;
311        }
312
313        // Remove the node from URI mapping
314        self.uri_to_id.remove(uri);
315
316        // Mark the node as removed (we don't actually remove it to avoid ID shifts)
317        // In a production implementation, you might use a tombstone approach
318        // or compact the index periodically
319        if let Some(node) = self.nodes.get_mut(node_id) {
320            node.connections.clear();
321            // We could add a "deleted" flag here if needed
322        }
323
324        // Update statistics
325        self.stats
326            .total_deletions
327            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
328
329        Ok(())
330    }
331
332    /// Update a vector by its URI (not part of VectorIndex trait)
333    pub fn update(&mut self, uri: String, vector: Vector) -> Result<()> {
334        // Implementation of vector update in HNSW index
335        // This is a simplified approach: remove and re-add the vector
336
337        // Check if the URI exists
338        if !self.uri_to_id.contains_key(&uri) {
339            return Err(anyhow::anyhow!("URI not found: {}", uri));
340        }
341
342        // Store the current connections before removal for potential optimization
343        let node_id = self.uri_to_id[&uri];
344        let _old_connections = self.nodes.get(node_id).map(|node| node.connections.clone());
345
346        // Remove the old vector
347        self.remove(&uri)?;
348
349        // Add the new vector with the same URI
350        self.insert(uri.clone(), vector)?;
351
352        // Update statistics
353        self.stats
354            .total_updates
355            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
356
357        // In a more sophisticated implementation, we could:
358        // 1. Check if the vector is similar enough to keep some connections
359        // 2. Incrementally update the graph structure
360        // 3. Use lazy updates to batch multiple updates
361
362        Ok(())
363    }
364
365    /// Clear all vectors from the index (not part of VectorIndex trait)
366    pub fn clear(&mut self) -> Result<()> {
367        self.nodes.clear();
368        self.uri_to_id.clear();
369        self.entry_point = None;
370        Ok(())
371    }
372
373    /// Get the number of vectors in the index (not part of VectorIndex trait)
374    pub fn size(&self) -> usize {
375        self.nodes.len()
376    }
377}