use candle_core::{Device, Result as CandleResult, Tensor};
use ndarray::{Array1, Array2};
use std::sync::OnceLock;

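/// Wraps the best available compute backend (CUDA, Metal, or CPU) and provides
/// GPU-accelerated linear-algebra helpers with automatic CPU fallbacks for
/// small inputs.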
#[derive(Debug, Clone)]
pub struct GPUAccelerator {
    device: Device,
    device_type: DeviceType,
    available_devices: Vec<Device>,
    current_device_index: usize,
}

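/// The backend the accelerator is running on.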
#[derive(Debug, Clone, PartialEq)]
pub enum DeviceType {
    CPU,
    CUDA,
    Metal,
}

static GPU_ACCELERATOR: OnceLock<GPUAccelerator> = OnceLock::new();

impl GPUAccelerator {
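    /// Returns the process-wide accelerator, initializing it on first use.
    /// If initialization fails, a CPU-only instance is used instead of panicking.
    ///
    /// ```ignore
    /// let acc = GPUAccelerator::global();
    /// println!("GPU enabled: {}", acc.is_gpu_enabled());
    /// ```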
    pub fn global() -> &'static GPUAccelerator {
        GPU_ACCELERATOR.get_or_init(|| {
            Self::new().unwrap_or_else(|_| {
                println!("Warning: GPU acceleration failed to initialize, using CPU fallback");
                GPUAccelerator {
                    device: Device::Cpu,
                    device_type: DeviceType::CPU,
                    available_devices: vec![Device::Cpu],
                    current_device_index: 0,
                }
            })
        })
    }

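    /// Builds an accelerator by probing CUDA first, then Metal, then falling
    /// back to CPU. Only backends enabled at compile time via cargo features
    /// are tried.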
    pub fn new() -> CandleResult<Self> {
        #[cfg(feature = "cuda")]
        {
            match Self::try_cuda() {
                Ok(accelerator) => {
                    println!("GPU acceleration enabled: CUDA device detected");
                    return Ok(accelerator);
                }
                Err(e) => {
                    println!("CUDA initialization failed: {e}, trying Metal...");
                }
            }
        }

        #[cfg(feature = "metal")]
        {
            match Self::try_metal() {
                Ok(accelerator) => {
                    println!("GPU acceleration enabled: Metal device detected");
                    return Ok(accelerator);
                }
                Err(e) => {
                    println!("Metal initialization failed: {e}, falling back to CPU");
                }
            }
        }

        println!("GPU acceleration not available, using CPU");
        Ok(GPUAccelerator {
            device: Device::Cpu,
            device_type: DeviceType::CPU,
            available_devices: vec![Device::Cpu],
            current_device_index: 0,
        })
    }

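    /// Enumerates CUDA devices by probing ordinals 0..8 until one fails to open.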
    #[cfg(feature = "cuda")]
    fn try_cuda() -> CandleResult<Self> {
        let mut available_devices = Vec::new();

        // Probe CUDA ordinals in order; stop at the first one that fails to open.
        for i in 0..8 {
            if let Ok(device) = Device::new_cuda(i) {
                available_devices.push(device);
            } else {
                break;
            }
        }

        if available_devices.is_empty() {
            return Err(candle_core::Error::Msg("No CUDA devices available".into()));
        }

        println!("🚀 Detected {} CUDA device(s)", available_devices.len());

        Ok(GPUAccelerator {
            device: available_devices[0].clone(),
            device_type: DeviceType::CUDA,
            available_devices,
            current_device_index: 0,
        })
    }

    #[cfg(not(feature = "cuda"))]
    #[allow(dead_code)]
    fn try_cuda() -> CandleResult<Self> {
        Err(candle_core::Error::Msg("CUDA not compiled".into()))
    }

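    /// Enumerates Metal devices by probing ordinals 0..4 until one fails to open.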
    #[cfg(feature = "metal")]
    fn try_metal() -> CandleResult<Self> {
        let mut available_devices = Vec::new();

        // Probe Metal ordinals in order; stop at the first one that fails to open.
        for i in 0..4 {
            if let Ok(device) = Device::new_metal(i) {
                available_devices.push(device);
            } else {
                break;
            }
        }

        if available_devices.is_empty() {
            return Err(candle_core::Error::Msg("No Metal devices available".into()));
        }

        println!("🍎 Detected {} Metal device(s)", available_devices.len());

        Ok(GPUAccelerator {
            device: available_devices[0].clone(),
            device_type: DeviceType::Metal,
            available_devices,
            current_device_index: 0,
        })
    }

    #[cfg(not(feature = "metal"))]
    #[allow(dead_code)]
    fn try_metal() -> CandleResult<Self> {
        Err(candle_core::Error::Msg("Metal not compiled".into()))
    }

    pub fn device_type(&self) -> &DeviceType {
        &self.device_type
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn is_gpu_enabled(&self) -> bool {
        matches!(self.device_type, DeviceType::CUDA | DeviceType::Metal)
    }

    pub fn device_count(&self) -> usize {
        self.available_devices.len()
    }

    pub fn is_multi_gpu_available(&self) -> bool {
        self.is_gpu_enabled() && self.available_devices.len() > 1
    }

    pub fn all_devices(&self) -> &[Device] {
        &self.available_devices
    }

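    /// Makes the device at `device_index` the active device for subsequent
    /// operations. Returns an error if the index is out of range.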
    pub fn switch_device(&mut self, device_index: usize) -> Result<(), String> {
        if device_index >= self.available_devices.len() {
            return Err(format!(
                "Device index {} out of range (have {} devices)",
                device_index,
                self.available_devices.len()
            ));
        }

        self.device = self.available_devices[device_index].clone();
        self.current_device_index = device_index;
        Ok(())
    }

    pub fn current_device_index(&self) -> usize {
        self.current_device_index
    }

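    /// Copies a contiguous 1-D ndarray onto the active device as a candle tensor.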
    pub fn array_to_tensor(&self, array: &Array1<f32>) -> CandleResult<Tensor> {
        let data = array.as_slice().expect("Array must be contiguous");
        Tensor::from_slice(data, array.len(), &self.device)
    }

    pub fn array2_to_tensor(&self, array: &Array2<f32>) -> CandleResult<Tensor> {
        let shape = array.shape();
        let data = array.as_slice().expect("Array must be contiguous");
        Tensor::from_slice(data, (shape[0], shape[1]), &self.device)
    }

    pub fn tensor_to_array(&self, tensor: &Tensor) -> CandleResult<Array1<f32>> {
        let data = tensor.to_vec1::<f32>()?;
        Ok(Array1::from_vec(data))
    }

    pub fn tensor_to_array2(&self, tensor: &Tensor) -> CandleResult<Array2<f32>> {
        let dims = tensor.dims();
        if dims.len() != 2 {
            return Err(candle_core::Error::Msg("Expected 2D tensor".into()));
        }
        let data = tensor.to_vec2::<f32>()?;
        let flat_data: Vec<f32> = data.into_iter().flatten().collect();
        Array2::from_shape_vec((dims[0], dims[1]), flat_data)
            .map_err(|e| candle_core::Error::Msg(format!("Shape mismatch building Array2: {e}")))
    }

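    /// Cosine similarity between `query` and every row of `vectors`.
    /// Batches under 100 rows are handled on the CPU, where the device
    /// transfer overhead would outweigh any GPU speedup.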
    pub fn cosine_similarity_batch(
        &self,
        query: &Array1<f32>,
        vectors: &Array2<f32>,
    ) -> CandleResult<Array1<f32>> {
        if !self.is_gpu_enabled() || vectors.nrows() < 100 {
            return Ok(self.cosine_similarity_cpu(query, vectors));
        }

        let query_tensor = self.array_to_tensor(query)?;
        let vectors_tensor = self.array2_to_tensor(vectors)?;

        // The norms have shapes (1,) and (m, 1), so broadcasting division is required.
        let query_norm = query_tensor.sqr()?.sum_keepdim(0)?.sqrt()?;
        let query_normalized = query_tensor.broadcast_div(&query_norm)?;

        let vectors_norm = vectors_tensor.sqr()?.sum_keepdim(1)?.sqrt()?;
        let vectors_normalized = vectors_tensor.broadcast_div(&vectors_norm)?;

        // (m, n) x (n, 1) -> (m, 1), then squeeze to (m,).
        let similarities = vectors_normalized
            .matmul(&query_normalized.unsqueeze(1)?)?
            .squeeze(1)?;

        self.tensor_to_array(&similarities)
    }

    fn cosine_similarity_cpu(&self, query: &Array1<f32>, vectors: &Array2<f32>) -> Array1<f32> {
        let query_norm = query.dot(query).sqrt();
        let mut similarities = Array1::zeros(vectors.nrows());

        for (i, vector) in vectors.outer_iter().enumerate() {
            let dot_product = query.dot(&vector);
            let vector_norm = vector.dot(&vector).sqrt();
            similarities[i] = if vector_norm > 0.0 && query_norm > 0.0 {
                dot_product / (query_norm * vector_norm)
            } else {
                0.0
            };
        }

        similarities
    }

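    /// Matrix multiplication with automatic dispatch: inputs where `a` is
    /// smaller than 64x64 use ndarray's CPU `dot`, larger ones go through
    /// candle on the active device.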
    pub fn matmul(&self, a: &Array2<f32>, b: &Array2<f32>) -> CandleResult<Array2<f32>> {
        if !self.is_gpu_enabled() || a.nrows() < 64 || a.ncols() < 64 {
            return Ok(a.dot(b));
        }

        let a_tensor = self.array2_to_tensor(a)?;
        let b_tensor = self.array2_to_tensor(b)?;
        let result_tensor = a_tensor.matmul(&b_tensor)?;
        self.tensor_to_array2(&result_tensor)
    }

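    /// Element-wise sum of a non-empty list of equal-length vectors.
    /// Fewer than ten vectors are summed directly on the CPU.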
    pub fn add_vectors(&self, vectors: &[Array1<f32>]) -> CandleResult<Array1<f32>> {
        if vectors.is_empty() {
            return Err(candle_core::Error::Msg(
                "Cannot add empty vector list".into(),
            ));
        }

        if !self.is_gpu_enabled() || vectors.len() < 10 {
            let mut result = vectors[0].clone();
            for vector in &vectors[1..] {
                result = &result + vector;
            }
            return Ok(result);
        }

        let mut result_tensor = self.array_to_tensor(&vectors[0])?;
        for vector in &vectors[1..] {
            let vector_tensor = self.array_to_tensor(vector)?;
            result_tensor = result_tensor.add(&vector_tensor)?;
        }

        self.tensor_to_array(&result_tensor)
    }

    pub fn memory_info(&self) -> String {
        match self.device_type {
            DeviceType::CPU => "CPU memory (system RAM)".to_string(),
            DeviceType::CUDA => {
                #[cfg(feature = "cuda")]
                {
                    "CUDA GPU memory (use nvidia-smi for details)".to_string()
                }
                #[cfg(not(feature = "cuda"))]
                "CUDA not available".to_string()
            }
            DeviceType::Metal => {
                #[cfg(feature = "metal")]
                {
                    "Metal GPU memory (system shared)".to_string()
                }
                #[cfg(not(feature = "metal"))]
                "Metal not available".to_string()
            }
        }
    }

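    /// Times a 1000x1000 matrix multiplication on the active device and
    /// returns the throughput in GFLOPS.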
    pub fn benchmark(&self) -> CandleResult<f64> {
        let size = 1000;
        let a = Array2::<f32>::ones((size, size));
        let b = Array2::<f32>::ones((size, size));

        let start = std::time::Instant::now();
        let _result = self.matmul(&a, &b)?;
        let duration = start.elapsed();

        // An n x n matmul performs n^3 multiply-adds, i.e. 2 * n^3 floating-point ops.
        let ops = 2.0 * (size * size * size) as f64;
        let gflops = ops / duration.as_secs_f64() / 1e9;

        Ok(gflops)
    }

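    /// Splits the rows of `vectors` evenly across all available GPUs, scores
    /// each chunk against `query`, and concatenates the per-device results in
    /// row order. Falls back to the single-device path below 1000 rows or when
    /// only one device is present.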
    pub fn multi_gpu_similarity_search(
        &self,
        query: &Array1<f32>,
        vectors: &Array2<f32>,
    ) -> CandleResult<Array1<f32>> {
        if !self.is_multi_gpu_available() || vectors.nrows() < 1000 {
            return self.cosine_similarity_batch(query, vectors);
        }

        println!(
            "🚀 Using multi-GPU similarity search across {} devices",
            self.device_count()
        );

        let chunk_size = vectors.nrows().div_ceil(self.device_count());
        let mut results = Vec::new();

        for (device_idx, chunk) in vectors
            .axis_chunks_iter(ndarray::Axis(0), chunk_size)
            .enumerate()
        {
            if device_idx >= self.available_devices.len() {
                break;
            }

            let device = &self.available_devices[device_idx];
            let query_tensor = Tensor::from_slice(
                query.as_slice().expect("Array must be contiguous"),
                query.len(),
                device,
            )?;

            let chunk_data = chunk.as_slice().expect("Chunk must be contiguous");
            let chunk_tensor =
                Tensor::from_slice(chunk_data, (chunk.nrows(), chunk.ncols()), device)?;

            let similarities =
                self.compute_cosine_similarity_tensor(&query_tensor, &chunk_tensor)?;
            let similarities_array = self.tensor_to_array(&similarities)?;
            results.push(similarities_array);
        }

        let total_len: usize = results.iter().map(|r| r.len()).sum();
        let mut combined = Vec::with_capacity(total_len);
        for result in results {
            combined.extend(result.iter());
        }

        Ok(Array1::from_vec(combined))
    }

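    /// Normalizes `query` (shape `(n,)`) and `vectors` (shape `(m, n)`) to unit
    /// length, then multiplies them to produce `(m,)` cosine similarities.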
    fn compute_cosine_similarity_tensor(
        &self,
        query: &Tensor,
        vectors: &Tensor,
    ) -> CandleResult<Tensor> {
        let query_norm = query.sqr()?.sum_keepdim(0)?.sqrt()?;
        let query_normalized = query.broadcast_div(&query_norm)?;

        let vectors_norm = vectors.sqr()?.sum_keepdim(1)?.sqrt()?;
        let vectors_normalized = vectors.broadcast_div(&vectors_norm)?;

        vectors_normalized
            .matmul(&query_normalized.unsqueeze(1)?)?
            .squeeze(1)
    }

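    /// Partitions `data` into one chunk per device and runs `process_fn` on the
    /// chunks in parallel via rayon; each invocation receives the index of the
    /// GPU it should target. Inputs under 1000 items run serially on device 0.
    ///
    /// Illustrative sketch (the closure body is up to the caller):
    ///
    /// ```ignore
    /// let acc = GPUAccelerator::global();
    /// let doubled = acc.multi_gpu_batch_process(&data, |chunk, _gpu_idx| {
    ///     Ok(chunk.iter().map(|x| x * 2.0).collect())
    /// })?;
    /// ```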
    pub fn multi_gpu_batch_process<T, F>(&self, data: &[T], process_fn: F) -> Result<Vec<T>, String>
    where
        T: Clone + Send + Sync,
        F: Fn(&[T], usize) -> Result<Vec<T>, String> + Send + Sync,
    {
        if !self.is_multi_gpu_available() || data.len() < 1000 {
            return process_fn(data, 0);
        }

        use rayon::prelude::*;

        let chunk_size = data.len().div_ceil(self.device_count());

        println!(
            "🚀 Multi-GPU batch processing: {} items across {} devices",
            data.len(),
            self.device_count()
        );

        let results: Result<Vec<Vec<T>>, String> = data
            .par_chunks(chunk_size)
            .enumerate()
            .map(|(device_idx, chunk)| {
                let gpu_idx = device_idx % self.device_count();
                process_fn(chunk, gpu_idx)
            })
            .collect();

        results.map(|chunks| chunks.into_iter().flatten().collect())
    }
}

impl Default for GPUAccelerator {
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| GPUAccelerator {
            device: Device::Cpu,
            device_type: DeviceType::CPU,
            available_devices: vec![Device::Cpu],
            current_device_index: 0,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_gpu_accelerator_creation() {
        let accelerator = GPUAccelerator::new().unwrap();
        println!("Device type: {:?}", accelerator.device_type());
    }

    #[test]
    fn test_cosine_similarity() {
        let accelerator = GPUAccelerator::global();
        let query = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let vectors = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let similarities = accelerator
            .cosine_similarity_batch(&query, &vectors)
            .unwrap();
        assert_eq!(similarities.len(), 2);
        // The first row equals the query, so its similarity should be ~1.0.
        assert!(similarities[0] > 0.9);
    }

    #[test]
    fn test_matrix_multiplication() {
        let accelerator = GPUAccelerator::global();
        let a = Array2::<f32>::ones((2, 3));
        let b = Array2::<f32>::ones((3, 2));

        let result = accelerator.matmul(&a, &b).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        // Each element is a dot product of a length-3 row of ones with a column of ones.
        assert_eq!(result[(0, 0)], 3.0);
    }

    #[test]
    fn test_benchmark() {
        let accelerator = GPUAccelerator::global();
        let gflops = accelerator.benchmark().unwrap();
        println!(
            "Benchmark: {:.2} GFLOPS on {:?}",
            gflops,
            accelerator.device_type()
        );
        assert!(gflops > 0.0);
    }
}