1use crate::hnsw::{HnswConfig, HnswPerformanceStats, Node};
4use crate::{Vector, VectorIndex};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::atomic::AtomicU64;
8
9#[cfg(feature = "gpu")]
10use crate::gpu::GpuAccelerator;
11
/// Hierarchical Navigable Small World (HNSW) vector index keyed by URI.
pub struct HnswIndex {
    /// Index configuration (metric, GPU switches, ...).
    config: HnswConfig,
    /// Node storage; a node id is an index into this vector. Removed entries stay here as
    /// tombstones (their URI mapping is dropped and their connections are cleared).
    nodes: Vec<Node>,
    /// Maps each live URI to its node id in `nodes`.
    uri_to_id: HashMap<String, usize>,
    /// Node id used as the graph entry point; `None` when no entry point is set.
    entry_point: Option<usize>,
    /// Level-generation multiplier; both constructors initialise it to `1 / ln 2`.
    level_multiplier: f64,
    /// PRNG state, seeded with 42 by the constructors; presumably drives level sampling
    /// elsewhere in the crate — confirm against `add_vector`.
    rng_state: u64,
    /// Performance counters (insertions/deletions/updates, ...).
    stats: HnswPerformanceStats,
    /// Counter of distance evaluations; not updated by the code in this section — TODO confirm
    /// where it is incremented.
    distance_calculations: AtomicU64,
    /// Single-device accelerator, set when GPU is enabled without multi-GPU.
    #[cfg(feature = "gpu")]
    gpu_accelerator: Option<Arc<GpuAccelerator>>,
    /// One accelerator per preferred device id when multi-GPU mode is active.
    #[cfg(feature = "gpu")]
    multi_gpu_accelerators: Vec<Arc<GpuAccelerator>>,
}
31
32impl HnswIndex {
33 pub fn new(config: HnswConfig) -> Result<Self> {
34 #[cfg(feature = "gpu")]
36 let (gpu_accelerator, multi_gpu_accelerators) = if config.enable_gpu {
37 let gpu_config = config.gpu_config.clone().unwrap_or_default();
38
39 if config.enable_multi_gpu && gpu_config.preferred_gpu_ids.len() > 1 {
40 let mut accelerators = Vec::new();
42 for &gpu_id in &gpu_config.preferred_gpu_ids {
43 let mut gpu_conf = gpu_config.clone();
44 gpu_conf.device_id = gpu_id;
45 let accelerator = GpuAccelerator::new(gpu_conf)?;
46 accelerators.push(Arc::new(accelerator));
47 }
48 (None, accelerators)
49 } else {
50 let accelerator = GpuAccelerator::new(gpu_config)?;
52 (Some(Arc::new(accelerator)), Vec::new())
53 }
54 } else {
55 (None, Vec::new())
56 };
57
58 Ok(Self {
59 config,
60 nodes: Vec::new(),
61 uri_to_id: HashMap::new(),
62 entry_point: None,
63 level_multiplier: 1.0 / (2.0_f64).ln(),
64 rng_state: 42, stats: HnswPerformanceStats::default(),
66 distance_calculations: AtomicU64::new(0),
67 #[cfg(feature = "gpu")]
68 gpu_accelerator,
69 #[cfg(feature = "gpu")]
70 multi_gpu_accelerators,
71 })
72 }
73
74 pub fn new_cpu_only(config: HnswConfig) -> Self {
76 let mut cpu_config = config;
77 cpu_config.enable_gpu = false;
78 cpu_config.enable_multi_gpu = false;
79
80 Self {
81 config: cpu_config,
82 nodes: Vec::new(),
83 uri_to_id: HashMap::new(),
84 entry_point: None,
85 level_multiplier: 1.0 / (2.0_f64).ln(),
86 rng_state: 42,
87 stats: HnswPerformanceStats::default(),
88 distance_calculations: AtomicU64::new(0),
89 #[cfg(feature = "gpu")]
90 gpu_accelerator: None,
91 #[cfg(feature = "gpu")]
92 multi_gpu_accelerators: Vec::new(),
93 }
94 }
95
96 pub fn uri_to_id(&self) -> &HashMap<String, usize> {
98 &self.uri_to_id
99 }
100
101 pub fn uri_to_id_mut(&mut self) -> &mut HashMap<String, usize> {
103 &mut self.uri_to_id
104 }
105
106 pub fn nodes(&self) -> &Vec<Node> {
108 &self.nodes
109 }
110
111 pub fn nodes_mut(&mut self) -> &mut Vec<Node> {
113 &mut self.nodes
114 }
115
116 pub fn entry_point(&self) -> Option<usize> {
118 self.entry_point
119 }
120
121 pub fn set_entry_point(&mut self, entry_point: Option<usize>) {
123 self.entry_point = entry_point;
124 }
125
126 pub fn config(&self) -> &HnswConfig {
128 &self.config
129 }
130
131 pub fn get_stats(&self) -> &HnswPerformanceStats {
133 &self.stats
134 }
135
136 #[cfg(feature = "gpu")]
138 pub fn is_gpu_available(&self) -> bool {
139 self.config.enable_gpu
140 && (self.gpu_accelerator.is_some() || !self.multi_gpu_accelerators.is_empty())
141 }
142
143 #[cfg(not(feature = "gpu"))]
144 pub fn is_gpu_available(&self) -> bool {
145 false
146 }
147
148 #[cfg(feature = "gpu")]
150 pub fn get_gpu_stats(&self) -> Option<crate::gpu::GpuPerformanceStats> {
151 if let Some(ref accelerator) = self.gpu_accelerator {
152 None } else {
155 None
156 }
157 }
158
159 pub fn len(&self) -> usize {
161 self.nodes.len()
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.nodes.is_empty()
167 }
168
169 pub fn stats_mut(&mut self) -> &mut HnswPerformanceStats {
173 &mut self.stats
174 }
175
176 pub fn level_multiplier(&self) -> f64 {
178 self.level_multiplier
179 }
180
181 pub fn rng_state_mut(&mut self) -> &mut u64 {
183 &mut self.rng_state
184 }
185
186 pub fn rng_state(&self) -> u64 {
188 self.rng_state
189 }
190}
191
192impl VectorIndex for HnswIndex {
193 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
194 self.add_vector(uri, vector)
196 }
197
198 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
199 if self.nodes.is_empty() || self.entry_point.is_none() {
200 return Ok(Vec::new());
201 }
202
203 let mut results = Vec::new();
206
207 for (uri, &node_id) in &self.uri_to_id {
208 if let Some(node) = self.nodes.get(node_id) {
209 let distance = self.config.metric.distance(query, &node.vector)?;
211 results.push((uri.clone(), distance));
212 }
213 }
214
215 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
217 results.truncate(k);
218
219 Ok(results)
220 }
221
222 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
223 if self.nodes.is_empty() || self.entry_point.is_none() {
224 return Ok(Vec::new());
225 }
226
227 let mut results = Vec::new();
230
231 for (uri, &node_id) in &self.uri_to_id {
232 if let Some(node) = self.nodes.get(node_id) {
233 let distance = self.config.metric.distance(query, &node.vector)?;
235 if distance <= threshold {
236 results.push((uri.clone(), distance));
237 }
238 }
239 }
240
241 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
243
244 Ok(results)
245 }
246
247 fn get_vector(&self, uri: &str) -> Option<&Vector> {
248 self.uri_to_id
249 .get(uri)
250 .and_then(|&id| self.nodes.get(id))
251 .map(|node| &node.vector)
252 }
253}
254
255impl HnswIndex {
256 pub fn remove(&mut self, uri: &str) -> Result<()> {
258 let node_id = if let Some(&id) = self.uri_to_id.get(uri) {
262 id
263 } else {
264 return Err(anyhow::anyhow!("URI not found: {}", uri));
265 };
266
267 if let Some(node) = self.nodes.get(node_id) {
269 let node_connections = node.connections.clone();
270
271 for (level, connections) in node_connections.iter().enumerate() {
273 for &connected_id in connections {
274 if let Some(connected_node) = self.nodes.get_mut(connected_id) {
275 connected_node.remove_connection(level, node_id);
276 }
277 }
278 }
279 }
280
281 if self.entry_point == Some(node_id) {
283 self.entry_point = None;
284
285 let mut highest_level = 0;
287 let mut new_entry_point = None;
288
289 for (id, node) in self.nodes.iter().enumerate() {
290 if id != node_id && node.level() >= highest_level {
291 highest_level = node.level();
292 new_entry_point = Some(id);
293 }
294 }
295
296 self.entry_point = new_entry_point;
297 }
298
299 self.uri_to_id.remove(uri);
301
302 if let Some(node) = self.nodes.get_mut(node_id) {
306 node.connections.clear();
307 }
309
310 self.stats
312 .total_deletions
313 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
314
315 Ok(())
316 }
317
318 pub fn update(&mut self, uri: String, vector: Vector) -> Result<()> {
320 if !self.uri_to_id.contains_key(&uri) {
325 return Err(anyhow::anyhow!("URI not found: {}", uri));
326 }
327
328 let node_id = self.uri_to_id[&uri];
330 let _old_connections = self.nodes.get(node_id).map(|node| node.connections.clone());
331
332 self.remove(&uri)?;
334
335 self.insert(uri.clone(), vector)?;
337
338 self.stats
340 .total_updates
341 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
342
343 Ok(())
349 }
350
351 pub fn clear(&mut self) -> Result<()> {
353 self.nodes.clear();
354 self.uri_to_id.clear();
355 self.entry_point = None;
356 Ok(())
357 }
358
359 pub fn size(&self) -> usize {
361 self.nodes.len()
362 }
363}