chess_vector_engine/
gpu_acceleration.rs

1use candle_core::{Device, Result as CandleResult, Tensor};
2use ndarray::{Array1, Array2};
3use std::sync::OnceLock;
4
5/// GPU acceleration backend with intelligent device detection and CPU fallback
6#[derive(Debug, Clone)]
7pub struct GPUAccelerator {
8    device: Device,
9    device_type: DeviceType,
10    /// Available GPU devices for multi-GPU operations
11    available_devices: Vec<Device>,
12    /// Current device index for multi-GPU operations
13    current_device_index: usize,
14}
15
/// Backend family an accelerator runs its tensor work on.
///
/// The enum is fieldless, so it is `Copy` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU; always available as the fallback.
    CPU,
    /// NVIDIA GPU through CUDA (only probed with the `cuda` feature enabled).
    CUDA,
    /// Apple GPU through Metal (only probed with the `metal` feature enabled).
    Metal,
}
22
23static GPU_ACCELERATOR: OnceLock<GPUAccelerator> = OnceLock::new();
24
25impl GPUAccelerator {
26    /// Get the global GPU accelerator instance (singleton pattern for efficiency)
27    pub fn global() -> &'static GPUAccelerator {
28        GPU_ACCELERATOR.get_or_init(|| {
29            Self::new().unwrap_or_else(|_| {
30                println!("Warning: GPU acceleration failed to initialize, using CPU fallback");
31                GPUAccelerator {
32                    device: Device::Cpu,
33                    device_type: DeviceType::CPU,
34                    available_devices: vec![Device::Cpu],
35                    current_device_index: 0,
36                }
37            })
38        })
39    }
40
41    /// Create a new GPU accelerator with intelligent device detection
42    pub fn new() -> CandleResult<Self> {
43        // Try GPU devices in order of preference: CUDA > Metal > CPU
44
45        #[cfg(feature = "cuda")]
46        {
47            match Self::try_cuda() {
48                Ok(accelerator) => {
49                    println!("GPU acceleration enabled: CUDA device detected");
50                    return Ok(accelerator);
51                }
52                Err(e) => {
53                    println!("CUDA initialization failed: {e}, trying Metal...");
54                }
55            }
56        }
57
58        #[cfg(feature = "metal")]
59        {
60            match Self::try_metal() {
61                Ok(accelerator) => {
62                    println!("GPU acceleration enabled: Metal device detected");
63                    return Ok(accelerator);
64                }
65                Err(e) => {
66                    println!("Metal initialization failed: {e}, falling back to CPU");
67                }
68            }
69        }
70
71        println!("GPU acceleration not available, using CPU");
72        Ok(GPUAccelerator {
73            device: Device::Cpu,
74            device_type: DeviceType::CPU,
75            available_devices: vec![Device::Cpu],
76            current_device_index: 0,
77        })
78    }
79
80    #[cfg(feature = "cuda")]
81    fn try_cuda() -> CandleResult<Self> {
82        // Try to detect multiple CUDA devices
83        let mut available_devices = Vec::new();
84        let mut device_count = 0;
85
86        // Try to detect up to 8 CUDA devices
87        for i in 0..8 {
88            if let Ok(device) = Device::new_cuda(i) {
89                available_devices.push(device);
90                device_count += 1;
91            } else {
92                break;
93            }
94        }
95
96        if available_devices.is_empty() {
97            return Err(candle_core::Error::Msg("No CUDA devices available".into()));
98        }
99
100        println!("🚀 Detected {device_count} CUDA device(s)");
101
102        Ok(GPUAccelerator {
103            device: available_devices[0].clone(),
104            device_type: DeviceType::CUDA,
105            available_devices,
106            current_device_index: 0,
107        })
108    }
109
110    #[cfg(not(feature = "cuda"))]
111    #[allow(dead_code)]
112    fn try_cuda() -> CandleResult<Self> {
113        Err(candle_core::Error::Msg("CUDA not compiled".into()))
114    }
115
116    #[cfg(feature = "metal")]
117    fn try_metal() -> CandleResult<Self> {
118        // Try to detect multiple Metal devices
119        let mut available_devices = Vec::new();
120        let mut device_count = 0;
121
122        // Try to detect up to 4 Metal devices (typically fewer than CUDA)
123        for i in 0..4 {
124            if let Ok(device) = Device::new_metal(i) {
125                available_devices.push(device);
126                device_count += 1;
127            } else {
128                break;
129            }
130        }
131
132        if available_devices.is_empty() {
133            return Err(candle_core::Error::Msg("No Metal devices available".into()));
134        }
135
136        println!("🍎 Detected {device_count} Metal device(s)");
137
138        Ok(GPUAccelerator {
139            device: available_devices[0].clone(),
140            device_type: DeviceType::Metal,
141            available_devices,
142            current_device_index: 0,
143        })
144    }
145
146    #[cfg(not(feature = "metal"))]
147    #[allow(dead_code)]
148    fn try_metal() -> CandleResult<Self> {
149        Err(candle_core::Error::Msg("Metal not compiled".into()))
150    }
151
152    /// Get the device type being used
153    pub fn device_type(&self) -> &DeviceType {
154        &self.device_type
155    }
156
157    /// Get the underlying Candle device
158    pub fn device(&self) -> &Device {
159        &self.device
160    }
161
162    /// Check if GPU acceleration is available
163    pub fn is_gpu_enabled(&self) -> bool {
164        matches!(self.device_type, DeviceType::CUDA | DeviceType::Metal)
165    }
166
167    /// Get number of available GPU devices
168    pub fn device_count(&self) -> usize {
169        self.available_devices.len()
170    }
171
172    /// Check if multiple GPU devices are available
173    pub fn is_multi_gpu_available(&self) -> bool {
174        self.is_gpu_enabled() && self.available_devices.len() > 1
175    }
176
177    /// Get all available devices for multi-GPU operations
178    pub fn all_devices(&self) -> &[Device] {
179        &self.available_devices
180    }
181
182    /// Switch to a specific device (for multi-GPU operations)
183    pub fn switch_device(&mut self, device_index: usize) -> Result<(), String> {
184        if device_index >= self.available_devices.len() {
185            return Err(format!(
186                "Device index {} out of range (have {} devices)",
187                device_index,
188                self.available_devices.len()
189            ));
190        }
191
192        self.device = self.available_devices[device_index].clone();
193        self.current_device_index = device_index;
194        Ok(())
195    }
196
197    /// Get current device index
198    pub fn current_device_index(&self) -> usize {
199        self.current_device_index
200    }
201
202    /// Convert ndarray to Candle tensor on the appropriate device
203    pub fn array_to_tensor(&self, array: &Array1<f32>) -> CandleResult<Tensor> {
204        let data = array.as_slice().expect("Array must be contiguous");
205        Tensor::from_slice(data, array.len(), &self.device)
206    }
207
208    /// Convert 2D ndarray to Candle tensor on the appropriate device
209    pub fn array2_to_tensor(&self, array: &Array2<f32>) -> CandleResult<Tensor> {
210        let shape = array.shape();
211        let data = array.as_slice().expect("Array must be contiguous");
212        Tensor::from_slice(data, (shape[0], shape[1]), &self.device)
213    }
214
215    /// Convert Candle tensor back to ndarray
216    pub fn tensor_to_array(&self, tensor: &Tensor) -> CandleResult<Array1<f32>> {
217        let data = tensor.to_vec1::<f32>()?;
218        Ok(Array1::from_vec(data))
219    }
220
221    /// Convert 2D Candle tensor back to ndarray
222    pub fn tensor_to_array2(&self, tensor: &Tensor) -> CandleResult<Array2<f32>> {
223        let dims = tensor.dims();
224        if dims.len() != 2 {
225            return Err(candle_core::Error::Msg("Expected 2D tensor".into()));
226        }
227        let data = tensor.to_vec2::<f32>()?;
228        let flat_data: Vec<f32> = data.into_iter().flatten().collect();
229        Array2::from_shape_vec((dims[0], dims[1]), flat_data)
230            .map_err(|_e| candle_core::Error::Msg("Processing...".to_string()))
231    }
232
233    /// Accelerated cosine similarity computation
234    pub fn cosine_similarity_batch(
235        &self,
236        query: &Array1<f32>,
237        vectors: &Array2<f32>,
238    ) -> CandleResult<Array1<f32>> {
239        if !self.is_gpu_enabled() || vectors.nrows() < 100 {
240            // Fall back to CPU for small batches or when GPU not available
241            return Ok(self.cosine_similarity_cpu(query, vectors));
242        }
243
244        // GPU-accelerated computation
245        let query_tensor = self.array_to_tensor(query)?;
246        let vectors_tensor = self.array2_to_tensor(vectors)?;
247
248        // Normalize query vector
249        let query_norm = query_tensor.sqr()?.sum_keepdim(0)?.sqrt()?;
250        let query_normalized = query_tensor.div(&query_norm)?;
251
252        // Normalize all vectors
253        let vectors_norm = vectors_tensor.sqr()?.sum_keepdim(1)?.sqrt()?;
254        let vectors_normalized = vectors_tensor.div(&vectors_norm)?;
255
256        // Compute dot products (cosine similarity)
257        let similarities = vectors_normalized
258            .matmul(&query_normalized.unsqueeze(1)?)?
259            .squeeze(1)?;
260
261        self.tensor_to_array(&similarities)
262    }
263
264    /// CPU fallback for cosine similarity
265    fn cosine_similarity_cpu(&self, query: &Array1<f32>, vectors: &Array2<f32>) -> Array1<f32> {
266        let query_norm = query.dot(query).sqrt();
267        let mut similarities = Array1::zeros(vectors.nrows());
268
269        for (i, vector) in vectors.outer_iter().enumerate() {
270            let dot_product = query.dot(&vector);
271            let vector_norm = vector.dot(&vector).sqrt();
272            similarities[i] = if vector_norm > 0.0 && query_norm > 0.0 {
273                dot_product / (query_norm * vector_norm)
274            } else {
275                0.0
276            };
277        }
278
279        similarities
280    }
281
282    /// Accelerated matrix multiplication
283    pub fn matmul(&self, a: &Array2<f32>, b: &Array2<f32>) -> CandleResult<Array2<f32>> {
284        if !self.is_gpu_enabled() || a.nrows() < 64 || a.ncols() < 64 {
285            // CPU fallback for small matrices
286            return Ok(a.dot(b));
287        }
288
289        let a_tensor = self.array2_to_tensor(a)?;
290        let b_tensor = self.array2_to_tensor(b)?;
291        let result_tensor = a_tensor.matmul(&b_tensor)?;
292        self.tensor_to_array2(&result_tensor)
293    }
294
295    /// Accelerated vector addition
296    pub fn add_vectors(&self, vectors: &[Array1<f32>]) -> CandleResult<Array1<f32>> {
297        if vectors.is_empty() {
298            return Err(candle_core::Error::Msg(
299                "Cannot add empty vector list".into(),
300            ));
301        }
302
303        if !self.is_gpu_enabled() || vectors.len() < 10 {
304            // CPU fallback
305            let mut result = vectors[0].clone();
306            for vector in &vectors[1..] {
307                result = &result + vector;
308            }
309            return Ok(result);
310        }
311
312        // GPU acceleration
313        let mut result_tensor = self.array_to_tensor(&vectors[0])?;
314        for vector in &vectors[1..] {
315            let vector_tensor = self.array_to_tensor(vector)?;
316            result_tensor = result_tensor.add(&vector_tensor)?;
317        }
318
319        self.tensor_to_array(&result_tensor)
320    }
321
322    /// Get memory usage information
323    pub fn memory_info(&self) -> String {
324        match self.device_type {
325            DeviceType::CPU => "CPU memory (system RAM)".to_string(),
326            DeviceType::CUDA => {
327                #[cfg(feature = "cuda")]
328                {
329                    // Would need CUDA runtime API calls to get actual memory info
330                    "CUDA GPU memory (use nvidia-smi for details)".to_string()
331                }
332                #[cfg(not(feature = "cuda"))]
333                "CUDA not available".to_string()
334            }
335            DeviceType::Metal => {
336                #[cfg(feature = "metal")]
337                {
338                    "Metal GPU memory (system shared)".to_string()
339                }
340                #[cfg(not(feature = "metal"))]
341                "Metal not available".to_string()
342            }
343        }
344    }
345
346    /// Benchmark the device performance
347    pub fn benchmark(&self) -> CandleResult<f64> {
348        let size = 1000;
349        let a = Array2::<f32>::ones((size, size));
350        let b = Array2::<f32>::ones((size, size));
351
352        let start = std::time::Instant::now();
353        let _result = self.matmul(&a, &b)?;
354        let duration = start.elapsed();
355
356        let ops = (size * size * size) as f64; // Matrix multiplication operations
357        let gflops = ops / duration.as_secs_f64() / 1e9;
358
359        Ok(gflops)
360    }
361
362    /// Multi-GPU parallel similarity search (when multiple GPUs available)
363    pub fn multi_gpu_similarity_search(
364        &self,
365        query: &Array1<f32>,
366        vectors: &Array2<f32>,
367    ) -> CandleResult<Array1<f32>> {
368        if !self.is_multi_gpu_available() || vectors.nrows() < 1000 {
369            // Fall back to single GPU/CPU
370            return self.cosine_similarity_batch(query, vectors);
371        }
372
373        println!(
374            "🚀 Using multi-GPU similarity search across {} devices",
375            self.device_count()
376        );
377
378        let chunk_size = vectors.nrows().div_ceil(self.device_count());
379        let mut results = Vec::new();
380
381        // Process chunks in parallel across different GPUs
382        for (device_idx, chunk) in vectors
383            .axis_chunks_iter(ndarray::Axis(0), chunk_size)
384            .enumerate()
385        {
386            if device_idx >= self.available_devices.len() {
387                break;
388            }
389
390            // Create tensor on specific device
391            let device = &self.available_devices[device_idx];
392            let query_tensor = Tensor::from_slice(
393                query.as_slice().expect("Array must be contiguous"),
394                query.len(),
395                device,
396            )?;
397
398            let chunk_data = chunk.as_slice().expect("Chunk must be contiguous");
399            let chunk_tensor =
400                Tensor::from_slice(chunk_data, (chunk.nrows(), chunk.ncols()), device)?;
401
402            // Compute similarities on this GPU
403            let similarities =
404                self.compute_cosine_similarity_tensor(&query_tensor, &chunk_tensor)?;
405            let similarities_array = self.tensor_to_array(&similarities)?;
406            results.push(similarities_array);
407        }
408
409        // Concatenate results
410        let total_len: usize = results.iter().map(|r| r.len()).sum();
411        let mut combined = Vec::with_capacity(total_len);
412        for result in results {
413            combined.extend(result.iter());
414        }
415
416        Ok(Array1::from_vec(combined))
417    }
418
419    /// Helper method to compute cosine similarity on tensor
420    fn compute_cosine_similarity_tensor(
421        &self,
422        query: &Tensor,
423        vectors: &Tensor,
424    ) -> CandleResult<Tensor> {
425        // Normalize query
426        let query_norm = query.sqr()?.sum_keepdim(0)?.sqrt()?;
427        let query_normalized = query.broadcast_div(&query_norm)?;
428
429        // Normalize vectors
430        let vectors_norm = vectors.sqr()?.sum_keepdim(1)?.sqrt()?;
431        let vectors_normalized = vectors.broadcast_div(&vectors_norm)?;
432
433        // Compute dot product (cosine similarity)
434        vectors_normalized
435            .matmul(&query_normalized.unsqueeze(1)?)?
436            .squeeze(1)
437    }
438
439    /// Multi-GPU batch processing for large operations
440    pub fn multi_gpu_batch_process<T, F>(&self, data: &[T], process_fn: F) -> Result<Vec<T>, String>
441    where
442        T: Clone + Send + Sync,
443        F: Fn(&[T], usize) -> Result<Vec<T>, String> + Send + Sync,
444    {
445        if !self.is_multi_gpu_available() || data.len() < 1000 {
446            // Single device processing
447            return process_fn(data, 0);
448        }
449
450        use rayon::prelude::*;
451
452        let chunk_size = data.len().div_ceil(self.device_count());
453
454        println!(
455            "🚀 Multi-GPU batch processing: {} items across {} devices",
456            data.len(),
457            self.device_count()
458        );
459
460        let results: Result<Vec<Vec<T>>, String> = data
461            .par_chunks(chunk_size)
462            .enumerate()
463            .map(|(device_idx, chunk)| {
464                let gpu_idx = device_idx % self.device_count();
465                process_fn(chunk, gpu_idx)
466            })
467            .collect();
468
469        match results {
470            Ok(chunks) => Ok(chunks.into_iter().flatten().collect()),
471            Err(e) => Err(e),
472        }
473    }
474}
475
476impl Default for GPUAccelerator {
477    fn default() -> Self {
478        Self::new().unwrap_or_else(|_| GPUAccelerator {
479            device: Device::Cpu,
480            device_type: DeviceType::CPU,
481            available_devices: vec![Device::Cpu],
482            current_device_index: 0,
483        })
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_gpu_accelerator_creation() {
        let accelerator = GPUAccelerator::new().unwrap();
        println!("Device type: {:?}", accelerator.device_type());

        // Detection always yields at least one device and starts on the first.
        assert!(accelerator.device_count() >= 1);
        assert_eq!(accelerator.current_device_index(), 0);
        assert_eq!(
            accelerator.is_gpu_enabled(),
            *accelerator.device_type() != DeviceType::CPU
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let accelerator = GPUAccelerator::global();
        let query = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
        let vectors =
            Array2::from_shape_vec((2, 3), vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let similarities = accelerator
            .cosine_similarity_batch(&query, &vectors)
            .unwrap();
        assert_eq!(similarities.len(), 2);
        // Row 0 equals the query.
        assert!((similarities[0] - 1.0).abs() < 1e-6);
        // Row 1: 32 / (sqrt(14) * sqrt(77)).
        let expected = 32.0_f32 / (14.0_f32.sqrt() * 77.0_f32.sqrt());
        assert!((similarities[1] - expected).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let accelerator = GPUAccelerator::global();
        let query = Array1::from_vec(vec![1.0_f32, 0.0]);
        let vectors = Array2::from_shape_vec((1, 2), vec![0.0_f32, 0.0]).unwrap();

        let similarities = accelerator
            .cosine_similarity_batch(&query, &vectors)
            .unwrap();
        assert_eq!(similarities[0], 0.0);
    }

    #[test]
    fn test_matrix_multiplication() {
        let accelerator = GPUAccelerator::global();
        let a = Array2::<f32>::ones((2, 3));
        let b = Array2::<f32>::ones((3, 2));

        let result = accelerator.matmul(&a, &b).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        // Every entry is the sum of three 1*1 products.
        assert!(result.iter().all(|&x| x == 3.0));
    }

    #[test]
    fn test_add_vectors() {
        let accelerator = GPUAccelerator::global();
        let vectors = vec![
            Array1::from_vec(vec![1.0_f32, 2.0]),
            Array1::from_vec(vec![3.0_f32, 4.0]),
            Array1::from_vec(vec![0.5_f32, 0.5]),
        ];

        let sum = accelerator.add_vectors(&vectors).unwrap();
        assert_eq!(sum, Array1::from_vec(vec![4.5_f32, 6.5]));
        assert!(accelerator.add_vectors(&[]).is_err());
    }

    #[test]
    fn test_switch_device_bounds() {
        let mut accelerator = GPUAccelerator::global().clone();
        let count = accelerator.device_count();

        assert!(accelerator.switch_device(count).is_err());
        assert!(accelerator.switch_device(0).is_ok());
        assert_eq!(accelerator.current_device_index(), 0);
    }

    #[test]
    fn test_benchmark() {
        let accelerator = GPUAccelerator::global();
        let gflops = accelerator.benchmark().unwrap();
        println!(
            "Benchmark: {:.2} GFLOPS on {:?}",
            gflops,
            accelerator.device_type()
        );
        assert!(gflops > 0.0);
    }
}