1use crate::hnsw::{HnswConfig, HnswPerformanceStats, Node};
4use crate::{Vector, VectorIndex};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::atomic::AtomicU64;
8#[cfg(feature = "gpu")]
9use std::sync::Arc;
10
11#[cfg(feature = "gpu")]
12use crate::gpu::GpuAccelerator;
13
14pub struct HnswIndex {
16 config: HnswConfig,
17 nodes: Vec<Node>,
18 uri_to_id: HashMap<String, usize>,
19 entry_point: Option<usize>,
20 level_multiplier: f64,
21 rng_state: u64,
22 stats: HnswPerformanceStats,
24 distance_calculations: AtomicU64,
26 #[cfg(feature = "gpu")]
28 gpu_accelerator: Option<Arc<GpuAccelerator>>,
29 #[cfg(feature = "gpu")]
31 multi_gpu_accelerators: Vec<Arc<GpuAccelerator>>,
32}
33
34impl HnswIndex {
35 pub fn new(config: HnswConfig) -> Result<Self> {
36 #[cfg(feature = "gpu")]
38 let (gpu_accelerator, multi_gpu_accelerators) = if config.enable_gpu {
39 let gpu_config = config.gpu_config.clone().unwrap_or_default();
40
41 if config.enable_multi_gpu && gpu_config.preferred_gpu_ids.len() > 1 {
42 let mut accelerators = Vec::new();
44 for &gpu_id in &gpu_config.preferred_gpu_ids {
45 let mut gpu_conf = gpu_config.clone();
46 gpu_conf.device_id = gpu_id;
47 let accelerator = GpuAccelerator::new(gpu_conf)?;
48 accelerators.push(Arc::new(accelerator));
49 }
50 (None, accelerators)
51 } else {
52 let accelerator = GpuAccelerator::new(gpu_config)?;
54 (Some(Arc::new(accelerator)), Vec::new())
55 }
56 } else {
57 (None, Vec::new())
58 };
59
60 Ok(Self {
61 config,
62 nodes: Vec::new(),
63 uri_to_id: HashMap::new(),
64 entry_point: None,
65 level_multiplier: 1.0 / (2.0_f64).ln(),
66 rng_state: 42, stats: HnswPerformanceStats::default(),
68 distance_calculations: AtomicU64::new(0),
69 #[cfg(feature = "gpu")]
70 gpu_accelerator,
71 #[cfg(feature = "gpu")]
72 multi_gpu_accelerators,
73 })
74 }
75
76 pub fn new_cpu_only(config: HnswConfig) -> Self {
78 let mut cpu_config = config;
79 cpu_config.enable_gpu = false;
80 cpu_config.enable_multi_gpu = false;
81
82 Self {
83 config: cpu_config,
84 nodes: Vec::new(),
85 uri_to_id: HashMap::new(),
86 entry_point: None,
87 level_multiplier: 1.0 / (2.0_f64).ln(),
88 rng_state: 42,
89 stats: HnswPerformanceStats::default(),
90 distance_calculations: AtomicU64::new(0),
91 #[cfg(feature = "gpu")]
92 gpu_accelerator: None,
93 #[cfg(feature = "gpu")]
94 multi_gpu_accelerators: Vec::new(),
95 }
96 }
97
98 pub fn uri_to_id(&self) -> &HashMap<String, usize> {
100 &self.uri_to_id
101 }
102
103 pub fn uri_to_id_mut(&mut self) -> &mut HashMap<String, usize> {
105 &mut self.uri_to_id
106 }
107
108 pub fn nodes(&self) -> &Vec<Node> {
110 &self.nodes
111 }
112
113 pub fn nodes_mut(&mut self) -> &mut Vec<Node> {
115 &mut self.nodes
116 }
117
118 pub fn entry_point(&self) -> Option<usize> {
120 self.entry_point
121 }
122
123 pub fn set_entry_point(&mut self, entry_point: Option<usize>) {
125 self.entry_point = entry_point;
126 }
127
128 pub fn config(&self) -> &HnswConfig {
130 &self.config
131 }
132
133 pub fn get_stats(&self) -> &HnswPerformanceStats {
135 &self.stats
136 }
137
138 #[cfg(feature = "gpu")]
140 pub fn is_gpu_available(&self) -> bool {
141 self.config.enable_gpu
142 && (self.gpu_accelerator.is_some() || !self.multi_gpu_accelerators.is_empty())
143 }
144
145 #[cfg(not(feature = "gpu"))]
146 pub fn is_gpu_available(&self) -> bool {
147 false
148 }
149
150 #[cfg(feature = "gpu")]
152 pub fn get_gpu_stats(&self) -> Option<crate::gpu::GpuPerformanceStats> {
153 if let Some(ref accelerator) = self.gpu_accelerator {
154 None } else {
157 None
158 }
159 }
160
161 #[cfg(feature = "gpu")]
163 pub fn gpu_accelerator(&self) -> Option<&Arc<GpuAccelerator>> {
164 self.gpu_accelerator.as_ref()
165 }
166
167 #[cfg(feature = "gpu")]
169 pub fn multi_gpu_accelerators(&self) -> &Vec<Arc<GpuAccelerator>> {
170 &self.multi_gpu_accelerators
171 }
172
173 pub fn len(&self) -> usize {
175 self.nodes.len()
176 }
177
178 pub fn is_empty(&self) -> bool {
180 self.nodes.is_empty()
181 }
182
183 pub fn stats_mut(&mut self) -> &mut HnswPerformanceStats {
187 &mut self.stats
188 }
189
190 pub fn level_multiplier(&self) -> f64 {
192 self.level_multiplier
193 }
194
195 pub fn rng_state_mut(&mut self) -> &mut u64 {
197 &mut self.rng_state
198 }
199
200 pub fn rng_state(&self) -> u64 {
202 self.rng_state
203 }
204}
205
206impl VectorIndex for HnswIndex {
207 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
208 self.add_vector(uri, vector)
210 }
211
212 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
213 if self.nodes.is_empty() || self.entry_point.is_none() {
214 return Ok(Vec::new());
215 }
216
217 let mut results = Vec::new();
220
221 for (uri, &node_id) in &self.uri_to_id {
222 if let Some(node) = self.nodes.get(node_id) {
223 let distance = self.config.metric.distance(query, &node.vector)?;
225 results.push((uri.clone(), distance));
226 }
227 }
228
229 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
231 results.truncate(k);
232
233 Ok(results)
234 }
235
236 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
237 if self.nodes.is_empty() || self.entry_point.is_none() {
238 return Ok(Vec::new());
239 }
240
241 let mut results = Vec::new();
244
245 for (uri, &node_id) in &self.uri_to_id {
246 if let Some(node) = self.nodes.get(node_id) {
247 let distance = self.config.metric.distance(query, &node.vector)?;
249 if distance <= threshold {
250 results.push((uri.clone(), distance));
251 }
252 }
253 }
254
255 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
257
258 Ok(results)
259 }
260
261 fn get_vector(&self, uri: &str) -> Option<&Vector> {
262 self.uri_to_id
263 .get(uri)
264 .and_then(|&id| self.nodes.get(id))
265 .map(|node| &node.vector)
266 }
267}
268
269impl HnswIndex {
270 pub fn remove(&mut self, uri: &str) -> Result<()> {
272 let node_id = if let Some(&id) = self.uri_to_id.get(uri) {
276 id
277 } else {
278 return Err(anyhow::anyhow!("URI not found: {}", uri));
279 };
280
281 if let Some(node) = self.nodes.get(node_id) {
283 let node_connections = node.connections.clone();
284
285 for (level, connections) in node_connections.iter().enumerate() {
287 for &connected_id in connections {
288 if let Some(connected_node) = self.nodes.get_mut(connected_id) {
289 connected_node.remove_connection(level, node_id);
290 }
291 }
292 }
293 }
294
295 if self.entry_point == Some(node_id) {
297 self.entry_point = None;
298
299 let mut highest_level = 0;
301 let mut new_entry_point = None;
302
303 for (id, node) in self.nodes.iter().enumerate() {
304 if id != node_id && node.level() >= highest_level {
305 highest_level = node.level();
306 new_entry_point = Some(id);
307 }
308 }
309
310 self.entry_point = new_entry_point;
311 }
312
313 self.uri_to_id.remove(uri);
315
316 if let Some(node) = self.nodes.get_mut(node_id) {
320 node.connections.clear();
321 }
323
324 self.stats
326 .total_deletions
327 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
328
329 Ok(())
330 }
331
332 pub fn update(&mut self, uri: String, vector: Vector) -> Result<()> {
334 if !self.uri_to_id.contains_key(&uri) {
339 return Err(anyhow::anyhow!("URI not found: {}", uri));
340 }
341
342 let node_id = self.uri_to_id[&uri];
344 let _old_connections = self.nodes.get(node_id).map(|node| node.connections.clone());
345
346 self.remove(&uri)?;
348
349 self.insert(uri.clone(), vector)?;
351
352 self.stats
354 .total_updates
355 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
356
357 Ok(())
363 }
364
365 pub fn clear(&mut self) -> Result<()> {
367 self.nodes.clear();
368 self.uri_to_id.clear();
369 self.entry_point = None;
370 Ok(())
371 }
372
373 pub fn size(&self) -> usize {
375 self.nodes.len()
376 }
377}