oxirs_vec/hnsw/
index.rs

1//! Main HNSW index implementation
2
use crate::hnsw::{HnswConfig, HnswPerformanceStats, Node};
use crate::{Vector, VectorIndex};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::atomic::AtomicU64;

#[cfg(feature = "gpu")]
use crate::gpu::GpuAccelerator;
#[cfg(feature = "gpu")]
use std::sync::Arc;
11
12/// HNSW index implementation
13pub struct HnswIndex {
14    config: HnswConfig,
15    nodes: Vec<Node>,
16    uri_to_id: HashMap<String, usize>,
17    entry_point: Option<usize>,
18    level_multiplier: f64,
19    rng_state: u64,
20    /// Performance statistics
21    stats: HnswPerformanceStats,
22    /// Distance calculation count (for metrics)
23    distance_calculations: AtomicU64,
24    /// GPU accelerator for CUDA-accelerated operations
25    #[cfg(feature = "gpu")]
26    gpu_accelerator: Option<Arc<GpuAccelerator>>,
27    /// Multi-GPU accelerators for distributed computation
28    #[cfg(feature = "gpu")]
29    multi_gpu_accelerators: Vec<Arc<GpuAccelerator>>,
30}
31
32impl HnswIndex {
33    pub fn new(config: HnswConfig) -> Result<Self> {
34        // Initialize GPU accelerators if enabled
35        #[cfg(feature = "gpu")]
36        let (gpu_accelerator, multi_gpu_accelerators) = if config.enable_gpu {
37            let gpu_config = config.gpu_config.clone().unwrap_or_default();
38
39            if config.enable_multi_gpu && gpu_config.preferred_gpu_ids.len() > 1 {
40                // Initialize multi-GPU setup
41                let mut accelerators = Vec::new();
42                for &gpu_id in &gpu_config.preferred_gpu_ids {
43                    let mut gpu_conf = gpu_config.clone();
44                    gpu_conf.device_id = gpu_id;
45                    let accelerator = GpuAccelerator::new(gpu_conf)?;
46                    accelerators.push(Arc::new(accelerator));
47                }
48                (None, accelerators)
49            } else {
50                // Single GPU setup
51                let accelerator = GpuAccelerator::new(gpu_config)?;
52                (Some(Arc::new(accelerator)), Vec::new())
53            }
54        } else {
55            (None, Vec::new())
56        };
57
58        Ok(Self {
59            config,
60            nodes: Vec::new(),
61            uri_to_id: HashMap::new(),
62            entry_point: None,
63            level_multiplier: 1.0 / (2.0_f64).ln(),
64            rng_state: 42, // Simple deterministic seed
65            stats: HnswPerformanceStats::default(),
66            distance_calculations: AtomicU64::new(0),
67            #[cfg(feature = "gpu")]
68            gpu_accelerator,
69            #[cfg(feature = "gpu")]
70            multi_gpu_accelerators,
71        })
72    }
73
74    /// Create a new HNSW index without GPU acceleration (for compatibility)
75    pub fn new_cpu_only(config: HnswConfig) -> Self {
76        let mut cpu_config = config;
77        cpu_config.enable_gpu = false;
78        cpu_config.enable_multi_gpu = false;
79
80        Self {
81            config: cpu_config,
82            nodes: Vec::new(),
83            uri_to_id: HashMap::new(),
84            entry_point: None,
85            level_multiplier: 1.0 / (2.0_f64).ln(),
86            rng_state: 42,
87            stats: HnswPerformanceStats::default(),
88            distance_calculations: AtomicU64::new(0),
89            #[cfg(feature = "gpu")]
90            gpu_accelerator: None,
91            #[cfg(feature = "gpu")]
92            multi_gpu_accelerators: Vec::new(),
93        }
94    }
95
96    /// Get the URI to ID mapping
97    pub fn uri_to_id(&self) -> &HashMap<String, usize> {
98        &self.uri_to_id
99    }
100
101    /// Get mutable URI to ID mapping
102    pub fn uri_to_id_mut(&mut self) -> &mut HashMap<String, usize> {
103        &mut self.uri_to_id
104    }
105
106    /// Get the nodes
107    pub fn nodes(&self) -> &Vec<Node> {
108        &self.nodes
109    }
110
111    /// Get mutable nodes
112    pub fn nodes_mut(&mut self) -> &mut Vec<Node> {
113        &mut self.nodes
114    }
115
116    /// Get the entry point
117    pub fn entry_point(&self) -> Option<usize> {
118        self.entry_point
119    }
120
121    /// Set the entry point
122    pub fn set_entry_point(&mut self, entry_point: Option<usize>) {
123        self.entry_point = entry_point;
124    }
125
126    /// Get the configuration
127    pub fn config(&self) -> &HnswConfig {
128        &self.config
129    }
130
131    /// Get performance statistics
132    pub fn get_stats(&self) -> &HnswPerformanceStats {
133        &self.stats
134    }
135
136    /// Check if GPU acceleration is available and enabled
137    #[cfg(feature = "gpu")]
138    pub fn is_gpu_available(&self) -> bool {
139        self.config.enable_gpu
140            && (self.gpu_accelerator.is_some() || !self.multi_gpu_accelerators.is_empty())
141    }
142
143    #[cfg(not(feature = "gpu"))]
144    pub fn is_gpu_available(&self) -> bool {
145        false
146    }
147
148    /// Get GPU performance statistics
149    #[cfg(feature = "gpu")]
150    pub fn get_gpu_stats(&self) -> Option<crate::gpu::GpuPerformanceStats> {
151        if let Some(ref accelerator) = self.gpu_accelerator {
152            // Would need to implement stats retrieval in GpuAccelerator
153            None // Placeholder
154        } else {
155            None
156        }
157    }
158
159    /// Get the number of nodes in the index
160    pub fn len(&self) -> usize {
161        self.nodes.len()
162    }
163
164    /// Check if the index is empty
165    pub fn is_empty(&self) -> bool {
166        self.nodes.is_empty()
167    }
168
169    // Duplicate methods removed - already defined above
170
171    /// Get mutable reference to stats
172    pub fn stats_mut(&mut self) -> &mut HnswPerformanceStats {
173        &mut self.stats
174    }
175
176    /// Get level multiplier
177    pub fn level_multiplier(&self) -> f64 {
178        self.level_multiplier
179    }
180
181    /// Get mutable reference to RNG state
182    pub fn rng_state_mut(&mut self) -> &mut u64 {
183        &mut self.rng_state
184    }
185
186    /// Get RNG state
187    pub fn rng_state(&self) -> u64 {
188        self.rng_state
189    }
190}
191
192impl VectorIndex for HnswIndex {
193    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
194        // Use the add_vector implementation from construction module
195        self.add_vector(uri, vector)
196    }
197
198    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
199        if self.nodes.is_empty() || self.entry_point.is_none() {
200            return Ok(Vec::new());
201        }
202
203        // Simple brute force search for now (placeholder)
204        // TODO: Implement proper HNSW search algorithm
205        let mut results = Vec::new();
206
207        for (uri, &node_id) in &self.uri_to_id {
208            if let Some(node) = self.nodes.get(node_id) {
209                // Calculate distance using configured metric
210                let distance = self.config.metric.distance(query, &node.vector)?;
211                results.push((uri.clone(), distance));
212            }
213        }
214
215        // Sort by distance and take k closest
216        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
217        results.truncate(k);
218
219        Ok(results)
220    }
221
222    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
223        if self.nodes.is_empty() || self.entry_point.is_none() {
224            return Ok(Vec::new());
225        }
226
227        // Simple brute force threshold search for now
228        // TODO: Implement proper HNSW range search algorithm
229        let mut results = Vec::new();
230
231        for (uri, &node_id) in &self.uri_to_id {
232            if let Some(node) = self.nodes.get(node_id) {
233                // Calculate distance using configured metric
234                let distance = self.config.metric.distance(query, &node.vector)?;
235                if distance <= threshold {
236                    results.push((uri.clone(), distance));
237                }
238            }
239        }
240
241        // Sort by distance
242        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
243
244        Ok(results)
245    }
246
247    fn get_vector(&self, uri: &str) -> Option<&Vector> {
248        self.uri_to_id
249            .get(uri)
250            .and_then(|&id| self.nodes.get(id))
251            .map(|node| &node.vector)
252    }
253}
254
255impl HnswIndex {
256    /// Remove a vector by its URI (not part of VectorIndex trait)
257    pub fn remove(&mut self, uri: &str) -> Result<()> {
258        // Implementation of vector removal from HNSW index
259
260        // Find the node ID for the URI
261        let node_id = if let Some(&id) = self.uri_to_id.get(uri) {
262            id
263        } else {
264            return Err(anyhow::anyhow!("URI not found: {}", uri));
265        };
266
267        // Remove the node from all its connections at all levels
268        if let Some(node) = self.nodes.get(node_id) {
269            let node_connections = node.connections.clone();
270
271            // Remove this node from all connected nodes
272            for (level, connections) in node_connections.iter().enumerate() {
273                for &connected_id in connections {
274                    if let Some(connected_node) = self.nodes.get_mut(connected_id) {
275                        connected_node.remove_connection(level, node_id);
276                    }
277                }
278            }
279        }
280
281        // If this node was the entry point, find a new entry point
282        if self.entry_point == Some(node_id) {
283            self.entry_point = None;
284
285            // Find a node with the highest level as the new entry point
286            let mut highest_level = 0;
287            let mut new_entry_point = None;
288
289            for (id, node) in self.nodes.iter().enumerate() {
290                if id != node_id && node.level() >= highest_level {
291                    highest_level = node.level();
292                    new_entry_point = Some(id);
293                }
294            }
295
296            self.entry_point = new_entry_point;
297        }
298
299        // Remove the node from URI mapping
300        self.uri_to_id.remove(uri);
301
302        // Mark the node as removed (we don't actually remove it to avoid ID shifts)
303        // In a production implementation, you might use a tombstone approach
304        // or compact the index periodically
305        if let Some(node) = self.nodes.get_mut(node_id) {
306            node.connections.clear();
307            // We could add a "deleted" flag here if needed
308        }
309
310        // Update statistics
311        self.stats
312            .total_deletions
313            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
314
315        Ok(())
316    }
317
318    /// Update a vector by its URI (not part of VectorIndex trait)
319    pub fn update(&mut self, uri: String, vector: Vector) -> Result<()> {
320        // Implementation of vector update in HNSW index
321        // This is a simplified approach: remove and re-add the vector
322
323        // Check if the URI exists
324        if !self.uri_to_id.contains_key(&uri) {
325            return Err(anyhow::anyhow!("URI not found: {}", uri));
326        }
327
328        // Store the current connections before removal for potential optimization
329        let node_id = self.uri_to_id[&uri];
330        let _old_connections = self.nodes.get(node_id).map(|node| node.connections.clone());
331
332        // Remove the old vector
333        self.remove(&uri)?;
334
335        // Add the new vector with the same URI
336        self.insert(uri.clone(), vector)?;
337
338        // Update statistics
339        self.stats
340            .total_updates
341            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
342
343        // In a more sophisticated implementation, we could:
344        // 1. Check if the vector is similar enough to keep some connections
345        // 2. Incrementally update the graph structure
346        // 3. Use lazy updates to batch multiple updates
347
348        Ok(())
349    }
350
351    /// Clear all vectors from the index (not part of VectorIndex trait)
352    pub fn clear(&mut self) -> Result<()> {
353        self.nodes.clear();
354        self.uri_to_id.clear();
355        self.entry_point = None;
356        Ok(())
357    }
358
359    /// Get the number of vectors in the index (not part of VectorIndex trait)
360    pub fn size(&self) -> usize {
361        self.nodes.len()
362    }
363}