use scirs2_core::ndarray_ext::{Array1, Array2, Axis};
use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
use crate::Result;
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
pub device_type: DeviceSelection,
pub use_mixed_precision: bool,
pub batch_size: usize,
pub auto_fallback: bool,
pub enable_caching: bool,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DeviceSelection {
Auto,
#[allow(dead_code)]
Cuda,
#[allow(dead_code)]
Metal,
Cpu,
}
impl Default for GpuConfig {
fn default() -> Self {
Self {
device_type: DeviceSelection::Cpu, use_mixed_precision: false,
batch_size: 1024,
auto_fallback: true,
enable_caching: true,
}
}
}
impl GpuConfig {
pub fn auto_detect() -> Self {
Self::default()
}
pub fn cpu_only() -> Self {
Self::default()
}
pub fn high_performance() -> Self {
Self {
batch_size: 4096,
enable_caching: true,
..Default::default()
}
}
pub fn low_memory() -> Self {
Self {
batch_size: 256,
enable_caching: false,
..Default::default()
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuOperationStats {
pub total_operations: u64,
pub gpu_operations: u64,
pub cpu_fallback_operations: u64,
pub total_time_ms: f64,
pub cache_hits: u64,
pub cache_misses: u64,
}
impl GpuOperationStats {
pub fn cache_hit_rate(&self) -> f64 {
let total = self.cache_hits + self.cache_misses;
if total == 0 {
0.0
} else {
(self.cache_hits as f64 / total as f64) * 100.0
}
}
pub fn avg_time_ms(&self) -> f64 {
if self.total_operations == 0 {
0.0
} else {
self.total_time_ms / self.total_operations as f64
}
}
}
pub struct GpuQueryEngine {
config: GpuConfig,
stats: Arc<RwLock<GpuOperationStats>>,
result_cache: Arc<DashMap<u64, Vec<f32>>>,
}
impl GpuQueryEngine {
pub fn new(config: GpuConfig) -> Result<Self> {
Ok(Self {
config,
stats: Arc::new(RwLock::new(GpuOperationStats::default())),
result_cache: Arc::new(DashMap::new()),
})
}
pub fn vector_similarity_search(
&self,
embeddings: &Array2<f32>,
query: &Array1<f32>,
top_k: usize,
) -> Result<Vec<(usize, f32)>> {
let start = std::time::Instant::now();
let mut stats = self.stats.write();
stats.total_operations += 1;
let query_hash = self.hash_query(query);
if self.config.enable_caching {
if let Some(cached) = self.result_cache.get(&query_hash) {
stats.cache_hits += 1;
let results = Self::extract_top_k(&cached, top_k);
stats.total_time_ms += start.elapsed().as_secs_f64() * 1000.0;
return Ok(results);
}
stats.cache_misses += 1;
}
stats.cpu_fallback_operations += 1;
let results = self.simd_similarity_search_impl(embeddings, query, top_k)?;
if self.config.enable_caching {
let scores: Vec<f32> = results.iter().map(|(_, score)| *score).collect();
self.result_cache.insert(query_hash, scores);
}
stats.total_time_ms += start.elapsed().as_secs_f64() * 1000.0;
drop(stats);
Ok(results)
}
fn simd_similarity_search_impl(
&self,
embeddings: &Array2<f32>,
query: &Array1<f32>,
top_k: usize,
) -> Result<Vec<(usize, f32)>> {
let query_slice = query
.as_slice()
.ok_or_else(|| anyhow::anyhow!("Query vector must be contiguous"))?;
let similarities: Vec<f32> = embeddings
.axis_iter(Axis(0))
.into_par_iter()
.map(|embedding| {
let emb_slice = embedding
.as_slice()
.expect("embedding array should be contiguous");
Self::cosine_similarity_simd(emb_slice, query_slice)
})
.collect();
Ok(Self::extract_top_k(&similarities, top_k))
}
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
let mut dot = 0.0f32;
let mut norm_a_sq = 0.0f32;
let mut norm_b_sq = 0.0f32;
for i in 0..a.len().min(b.len()) {
dot += a[i] * b[i];
norm_a_sq += a[i] * a[i];
norm_b_sq += b[i] * b[i];
}
let norm_a = norm_a_sq.sqrt();
let norm_b = norm_b_sq.sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
0.0
} else {
dot / (norm_a * norm_b)
}
}
fn extract_top_k(scores: &[f32], top_k: usize) -> Vec<(usize, f32)> {
let mut indexed: Vec<_> = scores.iter().enumerate().map(|(i, &s)| (i, s)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
indexed.truncate(top_k);
indexed
}
fn hash_query(&self, query: &Array1<f32>) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
for &v in query.iter() {
v.to_bits().hash(&mut hasher);
}
hasher.finish()
}
pub fn stats(&self) -> GpuOperationStats {
self.stats.read().clone()
}
pub fn reset_stats(&self) {
*self.stats.write() = GpuOperationStats::default();
}
pub fn clear_cache(&self) {
self.result_cache.clear();
}
pub fn gpu_info(&self) -> Option<String> {
Some(format!(
"CPU-optimized SIMD mode (batch_size: {})",
self.config.batch_size
))
}
pub fn is_gpu_available(&self) -> bool {
false }
pub fn config(&self) -> &GpuConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray_ext::array;
#[test]
fn test_gpu_config_creation() {
let config = GpuConfig::auto_detect();
assert!(config.auto_fallback);
assert_eq!(config.batch_size, 1024);
let cpu_config = GpuConfig::cpu_only();
assert!(matches!(cpu_config.device_type, DeviceSelection::Cpu));
let high_perf = GpuConfig::high_performance();
assert_eq!(high_perf.batch_size, 4096);
let low_mem = GpuConfig::low_memory();
assert_eq!(low_mem.batch_size, 256);
}
#[test]
fn test_gpu_stats() {
let stats = GpuOperationStats {
total_operations: 100,
cpu_fallback_operations: 100,
cache_hits: 30,
cache_misses: 70,
total_time_ms: 100.0,
..Default::default()
};
assert_eq!(stats.cache_hit_rate(), 30.0);
assert_eq!(stats.avg_time_ms(), 1.0);
}
#[test]
fn test_engine_creation() {
let config = GpuConfig::cpu_only();
let engine = GpuQueryEngine::new(config);
assert!(engine.is_ok());
let engine = engine.unwrap();
assert!(!engine.is_gpu_available());
assert!(matches!(engine.config().device_type, DeviceSelection::Cpu));
}
#[test]
fn test_vector_similarity_cpu() {
let config = GpuConfig::cpu_only();
let engine = GpuQueryEngine::new(config).unwrap();
let embeddings = array![
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.707, 0.707, 0.0],
];
let query = array![1.0, 0.0, 0.0];
let results = engine.vector_similarity_search(&embeddings, &query, 2);
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, 0); assert!((results[0].1 - 1.0).abs() < 1e-6);
}
#[test]
fn test_stats_tracking() {
let config = GpuConfig::cpu_only();
let engine = GpuQueryEngine::new(config).unwrap();
let embeddings = array![[1.0, 0.0], [0.0, 1.0]];
let query = array![1.0, 0.0];
let _ = engine.vector_similarity_search(&embeddings, &query, 1);
let _ = engine.vector_similarity_search(&embeddings, &query, 1);
let stats = engine.stats();
assert_eq!(stats.total_operations, 2);
assert_eq!(stats.cache_hits, 1); }
#[test]
fn test_cache_operations() {
let config = GpuConfig::cpu_only();
let engine = GpuQueryEngine::new(config).unwrap();
let embeddings = array![[1.0, 0.0], [0.0, 1.0]];
let query = array![1.0, 0.0];
let _ = engine.vector_similarity_search(&embeddings, &query, 1);
assert_eq!(engine.stats().cache_misses, 1);
let _ = engine.vector_similarity_search(&embeddings, &query, 1);
assert_eq!(engine.stats().cache_hits, 1);
engine.clear_cache();
let _ = engine.vector_similarity_search(&embeddings, &query, 1);
assert_eq!(engine.stats().cache_misses, 2);
}
#[test]
fn test_extract_top_k() {
let scores = vec![0.1, 0.9, 0.3, 0.7, 0.5];
let top_3 = GpuQueryEngine::extract_top_k(&scores, 3);
assert_eq!(top_3.len(), 3);
assert_eq!(top_3[0].0, 1); assert_eq!(top_3[1].0, 3); assert_eq!(top_3[2].0, 4); }
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
let sim = GpuQueryEngine::cosine_similarity_simd(&a, &b);
assert!((sim - 1.0).abs() < 1e-6);
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let sim = GpuQueryEngine::cosine_similarity_simd(&a, &b);
assert!(sim.abs() < 1e-6);
let a = vec![1.0, 1.0];
let b = vec![1.0, 1.0];
let sim = GpuQueryEngine::cosine_similarity_simd(&a, &b);
assert!((sim - 1.0).abs() < 1e-6);
}
#[test]
fn test_high_performance_config() {
let config = GpuConfig::high_performance();
let engine = GpuQueryEngine::new(config).unwrap();
let embeddings = array![[1.0, 0.0], [0.0, 1.0]];
let query = array![1.0, 0.0];
let results = engine.vector_similarity_search(&embeddings, &query, 1);
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 1);
}
}