1use super::{GpuAccelerator, GpuConfig, GpuMemoryPool, GpuPerformanceStats};
4use crate::{similarity::SimilarityMetric, Vector, VectorData};
5use anyhow::{anyhow, Result};
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9
/// A flat (brute-force) vector index that offloads similarity scoring to a
/// GPU through [`GpuAccelerator`].
#[derive(Debug)]
pub struct GpuVectorIndex {
    /// Shared handle to the GPU device / kernel dispatcher.
    accelerator: Arc<GpuAccelerator>,
    /// Original vectors, kept host-side for retrieval by index.
    vectors: Vec<Vector>,
    /// Flattened f32 copy of all components (row-major,
    /// `vectors.len() * dimension` elements) fed to GPU kernels.
    vector_data: Vec<f32>,
    /// Dimensionality of every stored vector; 0 until the first insert.
    dimension: usize,
    /// Pool of reusable GPU buffers, shared with wrapper indexes.
    memory_pool: Arc<RwLock<GpuMemoryPool>>,
    /// Maps a vector's URI to its position in `vectors`.
    uri_map: HashMap<String, usize>,
}
20
21impl GpuVectorIndex {
22 pub fn new(config: GpuConfig) -> Result<Self> {
24 let accelerator = Arc::new(GpuAccelerator::new(config.clone())?);
25 let memory_pool = Arc::new(RwLock::new(GpuMemoryPool::new(&config, 1024)?));
26
27 Ok(Self {
28 accelerator,
29 vectors: Vec::new(),
30 vector_data: Vec::new(),
31 dimension: 0,
32 memory_pool,
33 uri_map: HashMap::new(),
34 })
35 }
36
37 pub fn add_vectors(&mut self, vectors: Vec<Vector>) -> Result<()> {
39 if vectors.is_empty() {
40 return Ok(());
41 }
42
43 if self.dimension == 0 {
45 self.dimension = vectors[0].dimensions;
46 }
47
48 for vector in &vectors {
50 if vector.dimensions != self.dimension {
51 return Err(anyhow!(
52 "Vector dimension mismatch: expected {}, got {}",
53 self.dimension,
54 vector.dimensions
55 ));
56 }
57 }
58
59 for vector in &vectors {
61 match &vector.values {
62 VectorData::F32(data) => self.vector_data.extend(data),
63 VectorData::F64(data) => {
64 self.vector_data.extend(data.iter().map(|&x| x as f32));
66 }
67 _ => return Err(anyhow!("Unsupported vector precision for GPU processing")),
68 }
69 }
70
71 self.vectors.extend(vectors);
72 Ok(())
73 }
74
75 pub fn search(
77 &self,
78 query: &Vector,
79 k: usize,
80 metric: SimilarityMetric,
81 ) -> Result<Vec<(usize, f32)>> {
82 if self.vectors.is_empty() {
83 return Ok(Vec::new());
84 }
85
86 let query_data = match &query.values {
87 VectorData::F32(data) => data.clone(),
88 VectorData::F64(data) => data.iter().map(|&x| x as f32).collect(),
89 _ => {
90 return Err(anyhow!(
91 "Unsupported query vector precision for GPU processing"
92 ))
93 }
94 };
95
96 if query.dimensions != self.dimension {
97 return Err(anyhow!(
98 "Query dimension mismatch: expected {}, got {}",
99 self.dimension,
100 query.dimensions
101 ));
102 }
103
104 let similarities = self.accelerator.compute_similarity(
106 &query_data,
107 &self.vector_data,
108 1,
109 self.vectors.len(),
110 self.dimension,
111 metric,
112 )?;
113
114 let mut results: Vec<(usize, f32)> = similarities.into_iter().enumerate().collect();
116
117 match metric {
119 SimilarityMetric::Cosine | SimilarityMetric::Pearson | SimilarityMetric::Jaccard => {
120 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
121 }
122 _ => {
123 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
124 }
125 }
126
127 results.truncate(k);
128 Ok(results)
129 }
130
131 pub fn batch_search(
133 &self,
134 queries: &[Vector],
135 k: usize,
136 metric: SimilarityMetric,
137 ) -> Result<Vec<Vec<(usize, f32)>>> {
138 let mut results = Vec::new();
139
140 for query in queries {
141 let query_results = self.search(query, k, metric)?;
142 results.push(query_results);
143 }
144
145 Ok(results)
146 }
147
148 pub fn len(&self) -> usize {
150 self.vectors.len()
151 }
152
153 pub fn is_empty(&self) -> bool {
155 self.vectors.is_empty()
156 }
157
158 pub fn dimension(&self) -> usize {
160 self.dimension
161 }
162
163 pub fn performance_stats(&self) -> Arc<parking_lot::RwLock<GpuPerformanceStats>> {
165 self.accelerator.performance_stats()
166 }
167
168 pub fn clear(&mut self) {
170 self.vectors.clear();
171 self.vector_data.clear();
172 self.dimension = 0;
173 self.accelerator.reset_stats();
174 }
175}
176
177impl crate::VectorIndex for GpuVectorIndex {
178 fn insert(&mut self, uri: String, vector: crate::Vector) -> Result<()> {
179 let index = self.vectors.len();
181 self.uri_map.insert(uri, index);
182 self.add_vectors(vec![vector])?;
183 Ok(())
184 }
185
186 fn search_knn(&self, query: &crate::Vector, k: usize) -> Result<Vec<(String, f32)>> {
187 let results = self.search(query, k, SimilarityMetric::Cosine)?;
188 Ok(results
189 .into_iter()
190 .filter_map(|(index, score)| {
191 self.uri_map
193 .iter()
194 .find(|&(_, &idx)| idx == index)
195 .map(|(uri, _)| (uri.clone(), score))
196 })
197 .collect())
198 }
199
200 fn search_threshold(
201 &self,
202 query: &crate::Vector,
203 threshold: f32,
204 ) -> Result<Vec<(String, f32)>> {
205 let results = self.search(query, 1000, SimilarityMetric::Cosine)?;
207 Ok(results
208 .into_iter()
209 .filter(|(_, score)| *score >= threshold)
210 .filter_map(|(index, score)| {
211 self.uri_map
212 .iter()
213 .find(|&(_, &idx)| idx == index)
214 .map(|(uri, _)| (uri.clone(), score))
215 })
216 .collect())
217 }
218
219 fn get_vector(&self, uri: &str) -> Option<&crate::Vector> {
220 if let Some(&index) = self.uri_map.get(uri) {
221 self.vectors.get(index)
222 } else {
223 None
224 }
225 }
226}
227
/// Wrapper around [`GpuVectorIndex`] that enables advanced GPU features
/// (tensor cores, mixed precision) and exposes batched processing and
/// memory reporting.
#[derive(Debug)]
pub struct AdvancedGpuVectorIndex {
    /// Underlying flat GPU index that stores the vectors.
    base_index: GpuVectorIndex,
    /// Whether quantization has been requested via `enable_quantization`.
    // NOTE(review): this flag is set here but not consumed in this file —
    // presumably read by the accelerator/kernels elsewhere; confirm.
    enable_quantization: bool,
    /// Bit width used when quantization is enabled (default 8).
    quantization_bits: u8,
    /// Whether tensor-core kernels were requested at construction.
    use_tensor_cores: bool,
}
236
237impl AdvancedGpuVectorIndex {
238 pub fn new(mut config: GpuConfig) -> Result<Self> {
240 config.enable_tensor_cores = true;
241 config.enable_mixed_precision = true;
242
243 let base_index = GpuVectorIndex::new(config)?;
244
245 Ok(Self {
246 base_index,
247 enable_quantization: false,
248 quantization_bits: 8,
249 use_tensor_cores: true,
250 })
251 }
252
253 pub fn enable_quantization(&mut self, bits: u8) {
255 self.enable_quantization = true;
256 self.quantization_bits = bits;
257 }
258
259 pub fn batch_process(
261 &self,
262 queries: &[Vector],
263 batch_size: usize,
264 k: usize,
265 metric: SimilarityMetric,
266 ) -> Result<Vec<Vec<(usize, f32)>>> {
267 let mut all_results = Vec::new();
268
269 for batch in queries.chunks(batch_size) {
270 let batch_results = self.base_index.batch_search(batch, k, metric)?;
271 all_results.extend(batch_results);
272 }
273
274 Ok(all_results)
275 }
276
277 pub fn memory_stats(&self) -> Result<MemoryUsageStats> {
279 let device = self.base_index.accelerator.device();
280 let pool_stats = self.base_index.memory_pool.read().stats();
281
282 Ok(MemoryUsageStats {
283 total_gpu_memory: device.total_memory,
284 free_gpu_memory: device.free_memory,
285 used_by_index: pool_stats.used_memory,
286 vector_count: self.base_index.len(),
287 dimension: self.base_index.dimension(),
288 memory_per_vector: if !self.base_index.is_empty() {
289 pool_stats.used_memory / self.base_index.len()
290 } else {
291 0
292 },
293 })
294 }
295}
296
/// Point-in-time GPU memory accounting for a vector index, produced by
/// `AdvancedGpuVectorIndex::memory_stats`. All sizes are in bytes.
#[derive(Debug, Clone)]
pub struct MemoryUsageStats {
    /// Total memory reported by the GPU device.
    pub total_gpu_memory: usize,
    /// Memory the device currently reports as free.
    pub free_gpu_memory: usize,
    /// Pool memory attributed to the index's vectors.
    pub used_by_index: usize,
    /// Number of vectors stored in the index.
    pub vector_count: usize,
    /// Dimensionality of the stored vectors.
    pub dimension: usize,
    /// Average pool bytes per stored vector (0 when the index is empty).
    pub memory_per_vector: usize,
}
307
308impl MemoryUsageStats {
309 pub fn utilization(&self) -> f64 {
311 if self.total_gpu_memory > 0 {
312 (self.total_gpu_memory - self.free_gpu_memory) as f64 / self.total_gpu_memory as f64
313 } else {
314 0.0
315 }
316 }
317
318 pub fn print(&self) {
320 println!("GPU Vector Index Memory Usage:");
321 println!(
322 " Total GPU Memory: {:.2} GB",
323 self.total_gpu_memory as f64 / 1024.0 / 1024.0 / 1024.0
324 );
325 println!(
326 " Free GPU Memory: {:.2} GB",
327 self.free_gpu_memory as f64 / 1024.0 / 1024.0 / 1024.0
328 );
329 println!(
330 " Used by Index: {:.2} MB",
331 self.used_by_index as f64 / 1024.0 / 1024.0
332 );
333 println!(" Vectors: {} ({}D)", self.vector_count, self.dimension);
334 println!(
335 " Memory per Vector: {:.2} KB",
336 self.memory_per_vector as f64 / 1024.0
337 );
338 println!(" GPU Utilization: {:.1}%", self.utilization() * 100.0);
339 }
340}
341
/// Utility that splits vector workloads into fixed-size batches and applies
/// a caller-supplied operation, sequentially or on worker threads.
#[derive(Debug)]
pub struct BatchVectorProcessor {
    /// Accelerator handle; available to operations that need GPU access.
    // NOTE(review): not consumed by the methods visible in this file —
    // presumably used by callers via the config; confirm.
    accelerator: Arc<GpuAccelerator>,
    /// Number of vectors per batch handed to the operation.
    batch_size: usize,
    /// Upper bound on batches processed concurrently by
    /// `parallel_process_batches`.
    max_concurrent_batches: usize,
}
349
350impl BatchVectorProcessor {
351 pub fn new(config: GpuConfig, batch_size: usize) -> Result<Self> {
353 let accelerator = Arc::new(GpuAccelerator::new(config)?);
354 let max_concurrent_batches = 4; Ok(Self {
357 accelerator,
358 batch_size,
359 max_concurrent_batches,
360 })
361 }
362
363 pub fn process_batches<F, R>(&self, vectors: &[Vector], operation: F) -> Result<Vec<R>>
365 where
366 F: Fn(&[Vector]) -> Result<Vec<R>> + Send + Sync,
367 R: Send,
368 {
369 let mut results = Vec::new();
370
371 for batch in vectors.chunks(self.batch_size) {
372 let batch_results = operation(batch)?;
373 results.extend(batch_results);
374 }
375
376 Ok(results)
377 }
378
379 pub fn parallel_process_batches<F, R>(&self, vectors: &[Vector], operation: F) -> Result<Vec<R>>
381 where
382 F: Fn(&[Vector]) -> Result<Vec<R>> + Send + Sync + Clone + 'static,
383 R: Send + 'static,
384 {
385 use std::thread;
386
387 let chunks: Vec<&[Vector]> = vectors.chunks(self.batch_size).collect();
388 let mut handles = Vec::new();
389 let mut results = Vec::new();
390
391 for chunk_batch in chunks.chunks(self.max_concurrent_batches) {
392 for chunk in chunk_batch {
393 let chunk_vec = chunk.to_vec();
394 let op = operation.clone();
395
396 let handle = thread::spawn(move || op(&chunk_vec));
397 handles.push(handle);
398 }
399
400 for handle in handles.drain(..) {
402 match handle.join() {
403 Ok(Ok(batch_results)) => results.extend(batch_results),
404 Ok(Err(e)) => return Err(e),
405 Err(_) => return Err(anyhow!("Thread panicked during batch processing")),
406 }
407 }
408 }
409
410 Ok(results)
411 }
412
413 pub fn throughput(&self, vectors_processed: usize, duration: std::time::Duration) -> f64 {
415 if duration.as_secs_f64() > 0.0 {
416 vectors_processed as f64 / duration.as_secs_f64()
417 } else {
418 0.0
419 }
420 }
421}