1use crate::types::DistanceMetric;
36use anyhow::{anyhow, Result};
37use serde::{Deserialize, Serialize};
38#[cfg(all(feature = "cuda", target_os = "linux"))]
39use std::sync::Arc;
40
/// Configuration for GPU-accelerated batch distance computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// Minimum total operation count (queries × vectors) before the GPU path is considered.
    pub min_batch_size_for_gpu: usize,
    /// CUDA device ordinal passed to context creation.
    pub device_id: usize,
    /// Master switch: when false, all work runs on the CPU fallback.
    pub enabled: bool,
    /// Upper bound on batch size. NOTE(review): not enforced anywhere in this
    /// file — confirm callers consult it.
    pub max_batch_size: usize,
}
53
54impl Default for GpuConfig {
55 fn default() -> Self {
56 Self {
57 min_batch_size_for_gpu: 100,
58 device_id: 0,
59 enabled: true,
60 max_batch_size: 10_000,
61 }
62 }
63}
64
65impl GpuConfig {
66 pub fn cpu_preferred() -> Self {
68 Self {
69 min_batch_size_for_gpu: 10_000,
70 enabled: false,
71 ..Default::default()
72 }
73 }
74
75 pub fn gpu_preferred() -> Self {
77 Self {
78 min_batch_size_for_gpu: 10,
79 enabled: true,
80 max_batch_size: 100_000,
81 ..Default::default()
82 }
83 }
84}
85
/// Batch distance processor that dispatches large batches to the GPU (only in
/// CUDA-enabled Linux builds, and only when `config.enabled` is set) and
/// everything else to a SIMD CPU fallback.
pub struct GpuBatchProcessor {
    config: GpuConfig,
    // Shared CUDA context; present only in CUDA-enabled Linux builds.
    #[cfg(all(feature = "cuda", target_os = "linux"))]
    context: Arc<GpuContext>,
}
95
/// Thin wrapper owning the cudarc CUDA context for the processor's lifetime.
#[cfg(all(feature = "cuda", target_os = "linux"))]
struct GpuContext {
    // Underscore-prefixed: held to keep the context alive; also used to obtain
    // the default stream in `batch_distance_gpu`.
    _ctx: Arc<cudarc::driver::CudaContext>,
}
100
impl GpuBatchProcessor {
    /// Creates a new processor from `config`.
    ///
    /// In CUDA-enabled Linux builds this eagerly initializes a CUDA context;
    /// in every other build it is a pure CPU processor and merely logs a
    /// warning if the config asked for GPU acceleration.
    ///
    /// # Errors
    /// Returns an error if CUDA context creation fails (CUDA builds only).
    pub fn new(config: GpuConfig) -> Result<Self> {
        #[cfg(all(feature = "cuda", target_os = "linux"))]
        {
            if config.enabled {
                let ctx = cudarc::driver::CudaContext::new(config.device_id)
                    .map_err(|e| anyhow!("Failed to initialize CUDA context: {}", e))?;

                Ok(Self {
                    config,
                    context: Arc::new(GpuContext { _ctx: ctx }),
                })
            } else {
                // NOTE(review): even with the GPU disabled, a context is created
                // on device 0 so the non-optional `context` field can be filled.
                // On a machine without a CUDA device this makes construction
                // fail despite `enabled == false` — consider `Option<GpuContext>`.
                Ok(Self {
                    config,
                    context: Arc::new(GpuContext {
                        _ctx: cudarc::driver::CudaContext::new(0)
                            .map_err(|e| anyhow!("Failed to create default CUDA context: {}", e))?,
                    }),
                })
            }
        }

        #[cfg(not(all(feature = "cuda", target_os = "linux")))]
        {
            if config.enabled {
                tracing::warn!(
                    "GPU acceleration requested but CUDA feature not enabled. Using CPU fallback."
                );
            }
            Ok(Self { config })
        }
    }

    /// Returns `true` only in CUDA-enabled Linux builds with `config.enabled`
    /// set; always `false` otherwise.
    pub fn is_gpu_available(&self) -> bool {
        #[cfg(all(feature = "cuda", target_os = "linux"))]
        {
            self.config.enabled
        }
        #[cfg(not(all(feature = "cuda", target_os = "linux")))]
        {
            false
        }
    }

    /// Computes the full distance matrix between `queries` and `vectors`.
    ///
    /// Returns one row per query; row `i`, column `j` holds the distance from
    /// `queries[i]` to `vectors[j]` under `metric`. Dispatches to the GPU when
    /// `should_use_gpu` approves, otherwise to the SIMD CPU path.
    ///
    /// # Errors
    /// Fails when the first query's and first vector's dimensions differ, or
    /// when the GPU path fails.
    pub fn batch_distance(
        &self,
        queries: &[Vec<f32>],
        vectors: &[Vec<f32>],
        metric: DistanceMetric,
    ) -> Result<Vec<Vec<f32>>> {
        // Empty on either side: nothing to compute.
        if queries.is_empty() || vectors.is_empty() {
            return Ok(vec![]);
        }

        // NOTE(review): only the first element of each slice is dimension-
        // checked; remaining rows are assumed to share its length.
        let query_dim = queries[0].len();
        let vector_dim = vectors[0].len();
        if query_dim != vector_dim {
            return Err(anyhow!(
                "Dimension mismatch: queries have {} dims, vectors have {} dims",
                query_dim,
                vector_dim
            ));
        }

        let use_gpu = self.should_use_gpu(queries.len(), vectors.len());

        if use_gpu {
            #[cfg(all(feature = "cuda", target_os = "linux"))]
            {
                self.batch_distance_gpu(queries, vectors, metric)
            }
            // `should_use_gpu` can't return true here, but the fallback keeps
            // the branch well-typed in non-CUDA builds.
            #[cfg(not(all(feature = "cuda", target_os = "linux")))]
            {
                self.batch_distance_cpu(queries, vectors, metric)
            }
        } else {
            self.batch_distance_cpu(queries, vectors, metric)
        }
    }

    /// Decides whether a `_num_queries` × `_num_vectors` batch should run on
    /// the GPU: requires a CUDA Linux build, `config.enabled`, and a total
    /// operation count at or above `min_batch_size_for_gpu`.
    fn should_use_gpu(&self, _num_queries: usize, _num_vectors: usize) -> bool {
        if !self.config.enabled {
            return false;
        }

        // Exactly one of the two cfg-gated blocks below survives compilation
        // and serves as the function's tail expression.
        #[cfg(not(all(feature = "cuda", target_os = "linux")))]
        {
            false
        }

        #[cfg(all(feature = "cuda", target_os = "linux"))]
        {
            let total_operations = _num_queries * _num_vectors;
            total_operations >= self.config.min_batch_size_for_gpu
        }
    }

    /// CPU fallback: fills the distance matrix with the crate's SIMD helpers,
    /// one (query, vector) pair at a time.
    fn batch_distance_cpu(
        &self,
        queries: &[Vec<f32>],
        vectors: &[Vec<f32>],
        metric: DistanceMetric,
    ) -> Result<Vec<Vec<f32>>> {
        use crate::simd;

        let mut results = vec![vec![0.0; vectors.len()]; queries.len()];

        for (i, query) in queries.iter().enumerate() {
            for (j, vector) in vectors.iter().enumerate() {
                let distance = match metric {
                    // Similarity converted to a distance: 1 - cos_sim.
                    DistanceMetric::Cosine => 1.0 - simd::cosine_similarity_simd(query, vector),
                    DistanceMetric::Euclidean => simd::euclidean_distance_simd(query, vector),
                    // Negated so larger dot products rank as smaller distances.
                    DistanceMetric::DotProduct => -simd::dot_product_simd(query, vector),
                    DistanceMetric::Manhattan => simd::manhattan_distance_simd(query, vector),
                };
                results[i][j] = distance;
            }
        }

        Ok(results)
    }

    /// GPU path: flattens both inputs row-major, copies them to the device,
    /// launches the metric-specific kernel, and copies the result matrix back.
    ///
    /// NOTE(review): all four `launch_*_kernel` functions are unimplemented
    /// stubs that return an error, so this path currently always fails and
    /// there is no automatic CPU fallback on GPU error.
    #[cfg(all(feature = "cuda", target_os = "linux"))]
    fn batch_distance_gpu(
        &self,
        queries: &[Vec<f32>],
        vectors: &[Vec<f32>],
        metric: DistanceMetric,
    ) -> Result<Vec<Vec<f32>>> {
        let num_queries = queries.len();
        let num_vectors = vectors.len();
        let dims = queries[0].len();

        let stream = self.context._ctx.default_stream();

        // Flatten queries row-major: query i occupies [i*dims, (i+1)*dims).
        let mut queries_flat = Vec::with_capacity(num_queries * dims);
        for query in queries {
            queries_flat.extend_from_slice(query);
        }

        let mut vectors_flat = Vec::with_capacity(num_vectors * dims);
        for vector in vectors {
            vectors_flat.extend_from_slice(vector);
        }

        // Host-to-device copies; `clone_htod`/`clone_dtoh` names follow the
        // pinned cudarc version — confirm against its stream API if upgrading.
        let queries_gpu = stream
            .clone_htod(&queries_flat)
            .map_err(|e| anyhow!("Failed to copy queries to GPU: {}", e))?;

        let vectors_gpu = stream
            .clone_htod(&vectors_flat)
            .map_err(|e| anyhow!("Failed to copy vectors to GPU: {}", e))?;

        // Output buffer: num_queries * num_vectors distances, zero-initialized.
        let mut results_gpu = stream
            .alloc_zeros::<f32>(num_queries * num_vectors)
            .map_err(|e| anyhow!("Failed to allocate GPU memory for results: {}", e))?;

        match metric {
            DistanceMetric::Cosine => {
                launch_cosine_kernel(
                    &self.context._ctx,
                    &queries_gpu,
                    &vectors_gpu,
                    &mut results_gpu,
                    num_queries,
                    num_vectors,
                    dims,
                )?;
            }
            DistanceMetric::Euclidean => {
                launch_euclidean_kernel(
                    &self.context._ctx,
                    &queries_gpu,
                    &vectors_gpu,
                    &mut results_gpu,
                    num_queries,
                    num_vectors,
                    dims,
                )?;
            }
            DistanceMetric::DotProduct => {
                launch_dot_product_kernel(
                    &self.context._ctx,
                    &queries_gpu,
                    &vectors_gpu,
                    &mut results_gpu,
                    num_queries,
                    num_vectors,
                    dims,
                )?;
            }
            DistanceMetric::Manhattan => {
                launch_manhattan_kernel(
                    &self.context._ctx,
                    &queries_gpu,
                    &vectors_gpu,
                    &mut results_gpu,
                    num_queries,
                    num_vectors,
                    dims,
                )?;
            }
        }

        // Device-to-host copy of the flat result matrix.
        let results_flat: Vec<f32> = stream
            .clone_dtoh(&results_gpu)
            .map_err(|e| anyhow!("Failed to copy results from GPU: {}", e))?;

        // Un-flatten: row i, column j lives at i * num_vectors + j.
        let mut results = vec![vec![0.0; num_vectors]; num_queries];
        for i in 0..num_queries {
            for j in 0..num_vectors {
                results[i][j] = results_flat[i * num_vectors + j];
            }
        }

        Ok(results)
    }
}
349
/// Placeholder for the cosine-distance CUDA kernel launch.
///
/// # Errors
/// Always fails until the kernel is implemented.
#[cfg(all(feature = "cuda", target_os = "linux"))]
fn launch_cosine_kernel(
    _ctx: &Arc<cudarc::driver::CudaContext>,
    _queries: &cudarc::driver::CudaSlice<f32>,
    _vectors: &cudarc::driver::CudaSlice<f32>,
    _results: &mut cudarc::driver::CudaSlice<f32>,
    _num_queries: usize,
    _num_vectors: usize,
    _dims: usize,
) -> Result<()> {
    anyhow::bail!("CUDA kernel not yet implemented")
}
364
/// Placeholder for the Euclidean-distance CUDA kernel launch.
///
/// # Errors
/// Always fails until the kernel is implemented.
#[cfg(all(feature = "cuda", target_os = "linux"))]
fn launch_euclidean_kernel(
    _ctx: &Arc<cudarc::driver::CudaContext>,
    _queries: &cudarc::driver::CudaSlice<f32>,
    _vectors: &cudarc::driver::CudaSlice<f32>,
    _results: &mut cudarc::driver::CudaSlice<f32>,
    _num_queries: usize,
    _num_vectors: usize,
    _dims: usize,
) -> Result<()> {
    anyhow::bail!("CUDA kernel not yet implemented")
}
378
/// Placeholder for the dot-product CUDA kernel launch.
///
/// # Errors
/// Always fails until the kernel is implemented.
#[cfg(all(feature = "cuda", target_os = "linux"))]
fn launch_dot_product_kernel(
    _ctx: &Arc<cudarc::driver::CudaContext>,
    _queries: &cudarc::driver::CudaSlice<f32>,
    _vectors: &cudarc::driver::CudaSlice<f32>,
    _results: &mut cudarc::driver::CudaSlice<f32>,
    _num_queries: usize,
    _num_vectors: usize,
    _dims: usize,
) -> Result<()> {
    anyhow::bail!("CUDA kernel not yet implemented")
}
392
/// Placeholder for the Manhattan-distance CUDA kernel launch.
///
/// # Errors
/// Always fails until the kernel is implemented.
#[cfg(all(feature = "cuda", target_os = "linux"))]
fn launch_manhattan_kernel(
    _ctx: &Arc<cudarc::driver::CudaContext>,
    _queries: &cudarc::driver::CudaSlice<f32>,
    _vectors: &cudarc::driver::CudaSlice<f32>,
    _results: &mut cudarc::driver::CudaSlice<f32>,
    _num_queries: usize,
    _num_vectors: usize,
    _dims: usize,
) -> Result<()> {
    anyhow::bail!("CUDA kernel not yet implemented")
}
406
407#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct GpuStats {
410 pub total_operations: u64,
412 pub gpu_operations: u64,
414 pub cpu_operations: u64,
416 pub avg_batch_size: f64,
418}
419
420impl Default for GpuStats {
421 fn default() -> Self {
422 Self {
423 total_operations: 0,
424 gpu_operations: 0,
425 cpu_operations: 0,
426 avg_batch_size: 0.0,
427 }
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_config_default() {
        let cfg = GpuConfig::default();
        assert_eq!(cfg.min_batch_size_for_gpu, 100);
        assert_eq!(cfg.device_id, 0);
        assert!(cfg.enabled);
    }

    #[test]
    fn test_gpu_config_cpu_preferred() {
        let cfg = GpuConfig::cpu_preferred();
        assert_eq!(cfg.min_batch_size_for_gpu, 10_000);
        assert!(!cfg.enabled);
    }

    #[test]
    fn test_gpu_config_gpu_preferred() {
        let cfg = GpuConfig::gpu_preferred();
        assert_eq!(cfg.min_batch_size_for_gpu, 10);
        assert!(cfg.enabled);
        assert_eq!(cfg.max_batch_size, 100_000);
    }

    #[test]
    fn test_gpu_processor_creation_cpu_fallback() {
        // A CPU-preferred config must always construct successfully.
        assert!(GpuBatchProcessor::new(GpuConfig::cpu_preferred()).is_ok());
    }

    #[test]
    fn test_gpu_availability() {
        let processor = GpuBatchProcessor::new(GpuConfig::cpu_preferred()).unwrap();

        // cpu_preferred() disables the GPU, so availability is false in
        // every build configuration.
        #[cfg(all(feature = "cuda", target_os = "linux"))]
        assert!(!processor.is_gpu_available());

        #[cfg(not(all(feature = "cuda", target_os = "linux")))]
        assert!(!processor.is_gpu_available());
    }

    #[test]
    fn test_batch_distance_cpu_cosine() {
        let processor = GpuBatchProcessor::new(GpuConfig::cpu_preferred()).unwrap();

        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0]];

        let distances = processor
            .batch_distance(&queries, &vectors, DistanceMetric::Cosine)
            .unwrap();

        assert_eq!(distances.len(), 2);
        assert_eq!(distances[0].len(), 2);

        // Identical vectors -> distance near 0; orthogonal -> near 1.
        assert!(distances[0][0] < 0.01);
        assert!(distances[0][1] > 0.99);
    }

    #[test]
    fn test_batch_distance_cpu_euclidean() {
        let processor = GpuBatchProcessor::new(GpuConfig::cpu_preferred()).unwrap();

        // 3-4-5 right triangle: distance from the origin is exactly 5.
        let queries = vec![vec![0.0, 0.0, 0.0]];
        let vectors = vec![vec![3.0, 4.0, 0.0]];

        let distances = processor
            .batch_distance(&queries, &vectors, DistanceMetric::Euclidean)
            .unwrap();

        assert_eq!(distances.len(), 1);
        assert_eq!(distances[0].len(), 1);
        assert!((distances[0][0] - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_distance_empty_input() {
        let processor = GpuBatchProcessor::new(GpuConfig::cpu_preferred()).unwrap();

        let empty_queries: Vec<Vec<f32>> = Vec::new();
        let vectors = vec![vec![1.0, 2.0, 3.0]];

        let distances = processor
            .batch_distance(&empty_queries, &vectors, DistanceMetric::Cosine)
            .unwrap();

        assert!(distances.is_empty());
    }

    #[test]
    fn test_batch_distance_dimension_mismatch() {
        let processor = GpuBatchProcessor::new(GpuConfig::cpu_preferred()).unwrap();

        // A 3-dim query against a 2-dim vector must be rejected.
        let queries = vec![vec![1.0, 2.0, 3.0]];
        let vectors = vec![vec![1.0, 2.0]];
        let outcome = processor.batch_distance(&queries, &vectors, DistanceMetric::Cosine);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_should_use_gpu_threshold() {
        let processor = GpuBatchProcessor::new(GpuConfig {
            min_batch_size_for_gpu: 100,
            enabled: true,
            ..Default::default()
        })
        .unwrap();

        // 5 * 10 = 50 operations: below the threshold in every build.
        assert!(!processor.should_use_gpu(5, 10));

        // 10 * 20 = 200 operations: above the threshold, but the GPU is
        // only chosen in CUDA-enabled Linux builds.
        #[cfg(all(feature = "cuda", target_os = "linux"))]
        assert!(processor.should_use_gpu(10, 20));
        #[cfg(not(all(feature = "cuda", target_os = "linux")))]
        assert!(!processor.should_use_gpu(10, 20));
    }

    #[test]
    fn test_gpu_stats_default() {
        let stats = GpuStats::default();
        assert_eq!(stats.total_operations, 0);
        assert_eq!(stats.gpu_operations, 0);
        assert_eq!(stats.cpu_operations, 0);
        assert_eq!(stats.avg_batch_size, 0.0);
    }
}