1use scirs2_core::ndarray_ext::{Array1, Array2, Axis};
36use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
37
38use crate::Result;
39
40use dashmap::DashMap;
41use parking_lot::RwLock;
42use serde::{Deserialize, Serialize};
43use std::sync::Arc;
44
/// Configuration for [`GpuQueryEngine`].
///
/// Within this file the engine only consults `enable_caching` (to gate the
/// result cache) and `batch_size` (reported by `gpu_info`). The remaining
/// fields are stored and exposed through `config()` but are not read by any
/// code visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// Requested compute device. Not read by the engine code in this file.
    pub device_type: DeviceSelection,

    /// Mixed-precision flag. Not read by the engine code in this file.
    pub use_mixed_precision: bool,

    /// Batch size. Only surfaced in the `gpu_info` description string here.
    pub batch_size: usize,

    /// Fallback flag. Not read by the engine code in this file.
    // NOTE(review): presumably "fall back to CPU when no GPU is usable" — confirm.
    pub auto_fallback: bool,

    /// When `true`, `vector_similarity_search` consults and populates the
    /// engine's result cache.
    pub enable_caching: bool,
}
63
/// Compute device requested through [`GpuConfig::device_type`].
///
/// Only `Cpu` is ever constructed in this file (by `GpuConfig::default`), and
/// no code here branches on the selected variant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DeviceSelection {
    /// Let the implementation choose a device.
    // NOTE(review): no selection logic is visible in this file — confirm intent.
    Auto,
    // Never constructed in this file; the lint is silenced rather than the
    // variant removed, presumably to reserve it for a future GPU backend.
    #[allow(dead_code)]
    Cuda,
    // Never constructed in this file; see the note on `Cuda`.
    #[allow(dead_code)]
    Metal,
    /// Run on the CPU. This is the value used by `GpuConfig::default`.
    Cpu,
}
77
78impl Default for GpuConfig {
79 fn default() -> Self {
80 Self {
81 device_type: DeviceSelection::Cpu, use_mixed_precision: false,
83 batch_size: 1024,
84 auto_fallback: true,
85 enable_caching: true,
86 }
87 }
88}
89
90impl GpuConfig {
91 pub fn auto_detect() -> Self {
93 Self::default()
94 }
95
96 pub fn cpu_only() -> Self {
98 Self::default()
99 }
100
101 pub fn high_performance() -> Self {
103 Self {
104 batch_size: 4096,
105 enable_caching: true,
106 ..Default::default()
107 }
108 }
109
110 pub fn low_memory() -> Self {
112 Self {
113 batch_size: 256,
114 enable_caching: false,
115 ..Default::default()
116 }
117 }
118}
119
/// Counters describing the work performed by a [`GpuQueryEngine`].
///
/// A snapshot is obtained with `GpuQueryEngine::stats` and the counters are
/// zeroed with `GpuQueryEngine::reset_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuOperationStats {
    /// Number of `vector_similarity_search` calls recorded.
    pub total_operations: u64,

    /// Operations executed on a GPU. Never incremented by the code in this file.
    pub gpu_operations: u64,

    /// Searches whose similarities were computed on the CPU, i.e. every
    /// search that was not answered from the result cache.
    pub cpu_fallback_operations: u64,

    /// Accumulated wall-clock time, in milliseconds, of searches that
    /// returned successfully.
    pub total_time_ms: f64,

    /// Searches answered from the result cache.
    pub cache_hits: u64,

    /// Searches that found no usable cache entry. Only counted while caching
    /// is enabled.
    pub cache_misses: u64,
}
141
142impl GpuOperationStats {
143 pub fn cache_hit_rate(&self) -> f64 {
145 let total = self.cache_hits + self.cache_misses;
146 if total == 0 {
147 0.0
148 } else {
149 (self.cache_hits as f64 / total as f64) * 100.0
150 }
151 }
152
153 pub fn avg_time_ms(&self) -> f64 {
155 if self.total_operations == 0 {
156 0.0
157 } else {
158 self.total_time_ms / self.total_operations as f64
159 }
160 }
161}
162
/// Similarity-search engine.
///
/// Despite the name, the implementation in this file always computes on the
/// CPU: `is_gpu_available` returns `false` unconditionally.
pub struct GpuQueryEngine {
    /// Configuration supplied at construction; exposed through `config()`.
    config: GpuConfig,

    /// Operation counters, shared behind a read/write lock.
    stats: Arc<RwLock<GpuOperationStats>>,

    /// Cached similarity scores, keyed by a 64-bit hash of the query vector.
    result_cache: Arc<DashMap<u64, Vec<f32>>>,
}
176
177impl GpuQueryEngine {
178 pub fn new(config: GpuConfig) -> Result<Self> {
180 Ok(Self {
181 config,
182 stats: Arc::new(RwLock::new(GpuOperationStats::default())),
183 result_cache: Arc::new(DashMap::new()),
184 })
185 }
186
187 pub fn vector_similarity_search(
199 &self,
200 embeddings: &Array2<f32>,
201 query: &Array1<f32>,
202 top_k: usize,
203 ) -> Result<Vec<(usize, f32)>> {
204 let start = std::time::Instant::now();
205
206 let mut stats = self.stats.write();
207 stats.total_operations += 1;
208
209 let query_hash = self.hash_query(query);
211 if self.config.enable_caching {
212 if let Some(cached) = self.result_cache.get(&query_hash) {
213 stats.cache_hits += 1;
214 let results = Self::extract_top_k(&cached, top_k);
215 stats.total_time_ms += start.elapsed().as_secs_f64() * 1000.0;
216 return Ok(results);
217 }
218 stats.cache_misses += 1;
219 }
220
221 stats.cpu_fallback_operations += 1;
223 let results = self.simd_similarity_search_impl(embeddings, query, top_k)?;
224
225 if self.config.enable_caching {
227 let scores: Vec<f32> = results.iter().map(|(_, score)| *score).collect();
228 self.result_cache.insert(query_hash, scores);
229 }
230
231 stats.total_time_ms += start.elapsed().as_secs_f64() * 1000.0;
232 drop(stats);
233
234 Ok(results)
235 }
236
237 fn simd_similarity_search_impl(
239 &self,
240 embeddings: &Array2<f32>,
241 query: &Array1<f32>,
242 top_k: usize,
243 ) -> Result<Vec<(usize, f32)>> {
244 let query_slice = query
246 .as_slice()
247 .ok_or_else(|| anyhow::anyhow!("Query vector must be contiguous"))?;
248
249 let similarities: Vec<f32> = embeddings
250 .axis_iter(Axis(0))
251 .into_par_iter()
252 .map(|embedding| {
253 let emb_slice = embedding
254 .as_slice()
255 .expect("embedding array should be contiguous");
256 Self::cosine_similarity_simd(emb_slice, query_slice)
257 })
258 .collect();
259
260 Ok(Self::extract_top_k(&similarities, top_k))
261 }
262
263 fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
265 let mut dot = 0.0f32;
269 let mut norm_a_sq = 0.0f32;
270 let mut norm_b_sq = 0.0f32;
271
272 for i in 0..a.len().min(b.len()) {
274 dot += a[i] * b[i];
275 norm_a_sq += a[i] * a[i];
276 norm_b_sq += b[i] * b[i];
277 }
278
279 let norm_a = norm_a_sq.sqrt();
280 let norm_b = norm_b_sq.sqrt();
281
282 if norm_a == 0.0 || norm_b == 0.0 {
283 0.0
284 } else {
285 dot / (norm_a * norm_b)
286 }
287 }
288
289 fn extract_top_k(scores: &[f32], top_k: usize) -> Vec<(usize, f32)> {
291 let mut indexed: Vec<_> = scores.iter().enumerate().map(|(i, &s)| (i, s)).collect();
292 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
293 indexed.truncate(top_k);
294 indexed
295 }
296
297 fn hash_query(&self, query: &Array1<f32>) -> u64 {
299 use std::collections::hash_map::DefaultHasher;
300 use std::hash::{Hash, Hasher};
301
302 let mut hasher = DefaultHasher::new();
303 for &v in query.iter() {
304 v.to_bits().hash(&mut hasher);
305 }
306 hasher.finish()
307 }
308
309 pub fn stats(&self) -> GpuOperationStats {
311 self.stats.read().clone()
312 }
313
314 pub fn reset_stats(&self) {
316 *self.stats.write() = GpuOperationStats::default();
317 }
318
319 pub fn clear_cache(&self) {
321 self.result_cache.clear();
322 }
323
324 pub fn gpu_info(&self) -> Option<String> {
326 Some(format!(
328 "CPU-optimized SIMD mode (batch_size: {})",
329 self.config.batch_size
330 ))
331 }
332
333 pub fn is_gpu_available(&self) -> bool {
335 false }
337
338 pub fn config(&self) -> &GpuConfig {
340 &self.config
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Builds an engine from the CPU-only preset.
    fn cpu_engine() -> GpuQueryEngine {
        GpuQueryEngine::new(GpuConfig::cpu_only()).unwrap()
    }

    /// Two orthogonal unit rows and a query equal to the first row.
    fn unit_pair_fixture() -> (Array2<f32>, Array1<f32>) {
        (array![[1.0, 0.0], [0.0, 1.0]], array![1.0, 0.0])
    }

    #[test]
    fn test_gpu_config_creation() {
        let auto = GpuConfig::auto_detect();
        assert!(auto.auto_fallback);
        assert_eq!(auto.batch_size, 1024);

        assert!(matches!(
            GpuConfig::cpu_only().device_type,
            DeviceSelection::Cpu
        ));
        assert_eq!(GpuConfig::high_performance().batch_size, 4096);
        assert_eq!(GpuConfig::low_memory().batch_size, 256);
    }

    #[test]
    fn test_gpu_stats() {
        let snapshot = GpuOperationStats {
            total_operations: 100,
            cpu_fallback_operations: 100,
            cache_hits: 30,
            cache_misses: 70,
            total_time_ms: 100.0,
            ..Default::default()
        };

        assert_eq!(snapshot.cache_hit_rate(), 30.0);
        assert_eq!(snapshot.avg_time_ms(), 1.0);
    }

    #[test]
    fn test_engine_creation() {
        let created = GpuQueryEngine::new(GpuConfig::cpu_only());
        assert!(created.is_ok());

        let engine = created.unwrap();
        assert!(!engine.is_gpu_available());
        assert!(matches!(engine.config().device_type, DeviceSelection::Cpu));
    }

    #[test]
    fn test_vector_similarity_cpu() {
        let engine = cpu_engine();

        let embeddings = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.707, 0.707, 0.0],
        ];
        let query = array![1.0, 0.0, 0.0];

        let outcome = engine.vector_similarity_search(&embeddings, &query, 2);
        assert!(outcome.is_ok());

        let hits = outcome.unwrap();
        assert_eq!(hits.len(), 2);

        // The first row is identical to the query, so it must rank first
        // with a similarity of one.
        let (best_index, best_score) = hits[0];
        assert_eq!(best_index, 0);
        assert!((best_score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_stats_tracking() {
        let engine = cpu_engine();
        let (embeddings, query) = unit_pair_fixture();

        // Same query twice: the second search is served from the cache.
        for _ in 0..2 {
            let _ = engine.vector_similarity_search(&embeddings, &query, 1);
        }

        let snapshot = engine.stats();
        assert_eq!(snapshot.total_operations, 2);
        assert_eq!(snapshot.cache_hits, 1);
    }

    #[test]
    fn test_cache_operations() {
        let engine = cpu_engine();
        let (embeddings, query) = unit_pair_fixture();

        // Cold cache: first search is a miss.
        let _ = engine.vector_similarity_search(&embeddings, &query, 1);
        assert_eq!(engine.stats().cache_misses, 1);

        // Warm cache: repeating the query is a hit.
        let _ = engine.vector_similarity_search(&embeddings, &query, 1);
        assert_eq!(engine.stats().cache_hits, 1);

        // After clearing, the same query misses again.
        engine.clear_cache();
        let _ = engine.vector_similarity_search(&embeddings, &query, 1);
        assert_eq!(engine.stats().cache_misses, 2);
    }

    #[test]
    fn test_extract_top_k() {
        let scores = vec![0.1, 0.9, 0.3, 0.7, 0.5];
        let best = GpuQueryEngine::extract_top_k(&scores, 3);

        assert_eq!(best.len(), 3);
        assert_eq!(best[0].0, 1);
        assert_eq!(best[1].0, 3);
        assert_eq!(best[2].0, 4);
    }

    #[test]
    fn test_cosine_similarity() {
        // (left, right, expected similarity)
        let cases: [(Vec<f32>, Vec<f32>, f32); 3] = [
            (vec![1.0, 0.0, 0.0], vec![1.0, 0.0, 0.0], 1.0),
            (vec![1.0, 0.0], vec![0.0, 1.0], 0.0),
            (vec![1.0, 1.0], vec![1.0, 1.0], 1.0),
        ];

        for (left, right, expected) in cases.iter() {
            let sim = GpuQueryEngine::cosine_similarity_simd(left, right);
            assert!((sim - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_high_performance_config() {
        let engine = GpuQueryEngine::new(GpuConfig::high_performance()).unwrap();
        let (embeddings, query) = unit_pair_fixture();

        let outcome = engine.vector_similarity_search(&embeddings, &query, 1);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().len(), 1);
    }
}