1use crate::hnsw::query_cache::{QueryCache, QueryCacheConfig};
4use crate::hnsw::{HnswConfig, HnswPerformanceStats, Node};
5use crate::{Vector, VectorIndex};
6use anyhow::Result;
7use std::collections::HashMap;
8use std::sync::atomic::AtomicU64;
9#[cfg(feature = "gpu")]
10use std::sync::Arc;
11
12#[cfg(feature = "gpu")]
13use crate::gpu::GpuAccelerator;
14
/// HNSW index over vectors identified by URI.
///
/// Nodes live in a flat `Vec` and are addressed by their position in it;
/// `uri_to_id` translates URIs into those positions.
pub struct HnswIndex {
    /// Configuration the index was created with.
    config: HnswConfig,
    /// Node storage; a node's id is its position in this vector.
    nodes: Vec<Node>,
    /// Maps each URI to the id (position in `nodes`) of its node.
    uri_to_id: HashMap<String, usize>,
    /// Id of the current entry-point node, or `None` when there is none.
    entry_point: Option<usize>,
    /// Initialised to `1 / ln(2)` by the constructors.
    /// NOTE(review): presumably scales the random level assignment on insert — confirm.
    level_multiplier: f64,
    /// Random-number state, initialised to the fixed value 42 by the constructors.
    rng_state: u64,
    /// Performance counters (e.g. `total_deletions`, `total_updates`).
    stats: HnswPerformanceStats,
    /// Atomic counter, initialised to zero.
    /// NOTE(review): presumably counts distance computations during search — confirm.
    distance_calculations: AtomicU64,
    /// Query cache; `None` while caching is disabled.
    query_cache: Option<QueryCache>,
    /// Single accelerator, populated when GPU is enabled without multi-GPU.
    #[cfg(feature = "gpu")]
    gpu_accelerator: Option<Arc<GpuAccelerator>>,
    /// One accelerator per preferred GPU id when multi-GPU is enabled.
    #[cfg(feature = "gpu")]
    multi_gpu_accelerators: Vec<Arc<GpuAccelerator>>,
}
36
37impl HnswIndex {
38 pub fn new(config: HnswConfig) -> Result<Self> {
39 #[cfg(feature = "gpu")]
41 let (gpu_accelerator, multi_gpu_accelerators) = if config.enable_gpu {
42 let gpu_config = config.gpu_config.clone().unwrap_or_default();
43
44 if config.enable_multi_gpu && gpu_config.preferred_gpu_ids.len() > 1 {
45 let mut accelerators = Vec::new();
47 for &gpu_id in &gpu_config.preferred_gpu_ids {
48 let mut gpu_conf = gpu_config.clone();
49 gpu_conf.device_id = gpu_id;
50 let accelerator = GpuAccelerator::new(gpu_conf)?;
51 accelerators.push(Arc::new(accelerator));
52 }
53 (None, accelerators)
54 } else {
55 let accelerator = GpuAccelerator::new(gpu_config)?;
57 (Some(Arc::new(accelerator)), Vec::new())
58 }
59 } else {
60 (None, Vec::new())
61 };
62
63 let query_cache = Some(QueryCache::new(QueryCacheConfig::default()));
65
66 Ok(Self {
67 config,
68 nodes: Vec::new(),
69 uri_to_id: HashMap::new(),
70 entry_point: None,
71 level_multiplier: 1.0 / (2.0_f64).ln(),
72 rng_state: 42, stats: HnswPerformanceStats::default(),
74 distance_calculations: AtomicU64::new(0),
75 query_cache,
76 #[cfg(feature = "gpu")]
77 gpu_accelerator,
78 #[cfg(feature = "gpu")]
79 multi_gpu_accelerators,
80 })
81 }
82
83 pub fn new_cpu_only(config: HnswConfig) -> Self {
85 let mut cpu_config = config;
86 cpu_config.enable_gpu = false;
87 cpu_config.enable_multi_gpu = false;
88
89 let query_cache = Some(QueryCache::new(QueryCacheConfig::default()));
91
92 Self {
93 config: cpu_config,
94 nodes: Vec::new(),
95 uri_to_id: HashMap::new(),
96 entry_point: None,
97 level_multiplier: 1.0 / (2.0_f64).ln(),
98 rng_state: 42,
99 stats: HnswPerformanceStats::default(),
100 distance_calculations: AtomicU64::new(0),
101 query_cache,
102 #[cfg(feature = "gpu")]
103 gpu_accelerator: None,
104 #[cfg(feature = "gpu")]
105 multi_gpu_accelerators: Vec::new(),
106 }
107 }
108
109 pub fn enable_query_cache(&mut self, config: QueryCacheConfig) {
111 self.query_cache = Some(QueryCache::new(config));
112 }
113
114 pub fn disable_query_cache(&mut self) {
116 self.query_cache = None;
117 }
118
119 pub fn get_query_cache_stats(&self) -> Option<crate::hnsw::query_cache::QueryCacheStats> {
121 self.query_cache.as_ref().map(|cache| cache.get_stats())
122 }
123
124 pub fn clear_query_cache(&self) {
126 if let Some(ref cache) = self.query_cache {
127 cache.clear();
128 }
129 }
130
131 pub(crate) fn query_cache(&self) -> &Option<QueryCache> {
133 &self.query_cache
134 }
135
136 pub fn uri_to_id(&self) -> &HashMap<String, usize> {
138 &self.uri_to_id
139 }
140
141 pub fn uri_to_id_mut(&mut self) -> &mut HashMap<String, usize> {
143 &mut self.uri_to_id
144 }
145
146 pub fn nodes(&self) -> &Vec<Node> {
148 &self.nodes
149 }
150
151 pub fn nodes_mut(&mut self) -> &mut Vec<Node> {
153 &mut self.nodes
154 }
155
156 pub fn entry_point(&self) -> Option<usize> {
158 self.entry_point
159 }
160
161 pub fn set_entry_point(&mut self, entry_point: Option<usize>) {
163 self.entry_point = entry_point;
164 }
165
166 pub fn config(&self) -> &HnswConfig {
168 &self.config
169 }
170
171 pub fn get_stats(&self) -> &HnswPerformanceStats {
173 &self.stats
174 }
175
176 #[cfg(feature = "gpu")]
178 pub fn is_gpu_available(&self) -> bool {
179 self.config.enable_gpu
180 && (self.gpu_accelerator.is_some() || !self.multi_gpu_accelerators.is_empty())
181 }
182
183 #[cfg(not(feature = "gpu"))]
184 pub fn is_gpu_available(&self) -> bool {
185 false
186 }
187
188 #[cfg(feature = "gpu")]
190 pub fn get_gpu_stats(&self) -> Option<crate::gpu::GpuPerformanceStats> {
191 if let Some(ref _accelerator) = self.gpu_accelerator {
192 None } else {
195 None
196 }
197 }
198
199 #[cfg(feature = "gpu")]
201 pub fn gpu_accelerator(&self) -> Option<&Arc<GpuAccelerator>> {
202 self.gpu_accelerator.as_ref()
203 }
204
205 #[cfg(feature = "gpu")]
207 pub fn multi_gpu_accelerators(&self) -> &Vec<Arc<GpuAccelerator>> {
208 &self.multi_gpu_accelerators
209 }
210
211 pub fn len(&self) -> usize {
213 self.nodes.len()
214 }
215
216 pub fn is_empty(&self) -> bool {
218 self.nodes.is_empty()
219 }
220
221 pub fn stats_mut(&mut self) -> &mut HnswPerformanceStats {
225 &mut self.stats
226 }
227
228 pub fn level_multiplier(&self) -> f64 {
230 self.level_multiplier
231 }
232
233 pub fn rng_state_mut(&mut self) -> &mut u64 {
235 &mut self.rng_state
236 }
237
238 pub fn rng_state(&self) -> u64 {
240 self.rng_state
241 }
242}
243
244impl VectorIndex for HnswIndex {
245 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
246 self.add_vector(uri, vector)
248 }
249
250 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
251 HnswIndex::search_knn(self, query, k)
254 }
255
256 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
257 HnswIndex::range_search(self, query, threshold)
260 }
261
262 fn get_vector(&self, uri: &str) -> Option<&Vector> {
263 self.uri_to_id
264 .get(uri)
265 .and_then(|&id| self.nodes.get(id))
266 .map(|node| &node.vector)
267 }
268}
269
270impl HnswIndex {
271 pub fn remove(&mut self, uri: &str) -> Result<()> {
273 let node_id = if let Some(&id) = self.uri_to_id.get(uri) {
277 id
278 } else {
279 return Err(anyhow::anyhow!("URI not found: {}", uri));
280 };
281
282 if let Some(node) = self.nodes.get(node_id) {
284 let node_connections = node.connections.clone();
285
286 for (level, connections) in node_connections.iter().enumerate() {
288 for &connected_id in connections {
289 if let Some(connected_node) = self.nodes.get_mut(connected_id) {
290 connected_node.remove_connection(level, node_id);
291 }
292 }
293 }
294 }
295
296 if self.entry_point == Some(node_id) {
298 self.entry_point = None;
299
300 let mut highest_level = 0;
302 let mut new_entry_point = None;
303
304 for (id, node) in self.nodes.iter().enumerate() {
305 if id != node_id && node.level() >= highest_level {
306 highest_level = node.level();
307 new_entry_point = Some(id);
308 }
309 }
310
311 self.entry_point = new_entry_point;
312 }
313
314 self.uri_to_id.remove(uri);
316
317 if let Some(node) = self.nodes.get_mut(node_id) {
321 node.connections.clear();
322 }
324
325 self.stats
327 .total_deletions
328 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
329
330 Ok(())
331 }
332
333 pub fn update(&mut self, uri: String, vector: Vector) -> Result<()> {
335 if !self.uri_to_id.contains_key(&uri) {
340 return Err(anyhow::anyhow!("URI not found: {}", uri));
341 }
342
343 let node_id = self.uri_to_id[&uri];
345 let _old_connections = self.nodes.get(node_id).map(|node| node.connections.clone());
346
347 self.remove(&uri)?;
349
350 self.insert(uri.clone(), vector)?;
352
353 self.stats
355 .total_updates
356 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
357
358 Ok(())
364 }
365
366 pub fn clear(&mut self) -> Result<()> {
368 self.nodes.clear();
369 self.uri_to_id.clear();
370 self.entry_point = None;
371 Ok(())
372 }
373
374 pub fn size(&self) -> usize {
376 self.nodes.len()
377 }
378}