// oxirs_embed/acceleration/gpu.rs
1//! GPU Acceleration for Embedding Computations
2//!
3//! This module provides GPU-accelerated implementations of embedding operations
4//! using scirs2-linalg GPU features for CUDA, OpenCL, ROCm, and Metal backends.
5
6use crate::models::common::*;
7use anyhow::Result;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9#[cfg(feature = "gpu")]
10use std::collections::VecDeque;
11#[cfg(feature = "gpu")]
12use std::sync::atomic::{AtomicU64, Ordering};
13#[cfg(feature = "gpu")]
14use std::sync::{Arc, Mutex, RwLock};
15#[cfg(feature = "gpu")]
16use std::time::{Duration, Instant};
17
#[cfg(feature = "gpu")]
// TODO: scirs2_linalg::gpu module is not yet available
// Enable this when the GPU module is implemented in scirs2_linalg
// use scirs2_linalg::gpu::{GpuArray, GpuContext, GpuError};
// Placeholder types until scirs2_linalg::gpu is available
/// Placeholder for a device-resident buffer; currently just a host-side `Vec`.
pub type GpuArray<T> = Vec<T>;
#[cfg(feature = "gpu")]
/// Placeholder for the GPU device context; unit until a real backend exists.
pub type GpuContext = ();
#[cfg(feature = "gpu")]
/// Minimal stand-in for the future `scirs2_linalg::gpu::GpuError`;
/// wraps a human-readable error message.
#[derive(Debug)]
pub struct GpuError(String);
29
#[cfg(feature = "gpu")]
impl std::fmt::Display for GpuError {
    /// Displays the wrapped error message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

#[cfg(feature = "gpu")]
impl std::error::Error for GpuError {}
39
/// Memory pool for GPU buffers
///
/// Recycles fixed-size `GpuArray<f32>` buffers so repeated batch operations can
/// reuse allocations instead of allocating fresh each time.
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct GpuMemoryPool {
    // FIFO of recycled buffers; every buffer held here is exactly `buffer_size` long.
    available_buffers: VecDeque<GpuArray<f32>>,
    // Required element count for buffers accepted back into the pool.
    buffer_size: usize,
    // NOTE(review): these counters are only read (and initialized to 0) in the
    // visible code — nothing here ever increments them; confirm they are
    // updated elsewhere or wire them up when real allocation is added.
    total_allocated: AtomicU64,
    peak_usage: AtomicU64,
}
49
/// Adaptive batch sizing configuration
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct AdaptiveBatchConfig {
    // Hard lower bound on the adaptively chosen batch size.
    pub min_batch_size: usize,
    // Hard upper bound on the adaptively chosen batch size.
    pub max_batch_size: usize,
    // NOTE(review): not referenced by the visible code; presumably the desired
    // fraction (0..1) of wall time spent computing — confirm intended use.
    pub target_gpu_utilization: f32,
    // NOTE(review): not referenced by the visible code; confirm intended use.
    pub memory_usage_threshold: f32,
}
59
/// Enhanced GPU-accelerated embedding computations with memory pooling and adaptive batching
#[cfg(feature = "gpu")]
pub struct GpuEmbeddingAccelerator {
    // Placeholder device context (unit type until scirs2_linalg::gpu exists).
    context: GpuContext,
    // Identifier of the GPU device this accelerator targets.
    device_id: u32,
    // Shared pool of recyclable GPU buffers.
    memory_pool: Arc<Mutex<GpuMemoryPool>>,
    // Bounds and targets for adaptive batch sizing.
    batch_config: AdaptiveBatchConfig,
    // Aggregated runtime statistics, updated by adaptive batch processing.
    performance_stats: Arc<RwLock<GpuPerformanceStats>>,
    // Current adaptive batch-size estimate (u64 so it can be updated atomically).
    optimal_batch_size: Arc<AtomicU64>,
}
70
/// GPU performance statistics
#[cfg(feature = "gpu")]
#[derive(Debug, Default)]
pub struct GpuPerformanceStats {
    // Number of adaptive batch-processing runs recorded.
    pub total_operations: u64,
    // Total wall time accumulated across recorded runs.
    pub total_compute_time: Duration,
    // NOTE(review): never incremented by the visible code; confirm it is
    // updated elsewhere.
    pub memory_transfers: u64,
    // NOTE(review): read when computing the cache hit rate but never
    // incremented by the visible code.
    pub cache_hits: u64,
    pub cache_misses: u64,
    // Average batch size across recorded runs.
    pub average_batch_size: f32,
    // Utilization of the most recent run, in percent.
    pub gpu_utilization_percentage: f32,
}
83
/// Comprehensive GPU performance report
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct GpuPerformanceReport {
    // Device this report describes.
    pub device_id: u32,
    // Total recorded batch-processing runs.
    pub total_operations: u64,
    // Mean wall time per recorded run (zero when no runs were recorded).
    pub average_compute_time: Duration,
    // Utilization of the most recent run, in percent.
    pub gpu_utilization: f32,
    // Pool-reported allocation counters, converted to MiB.
    pub memory_allocated_mb: f64,
    pub memory_peak_mb: f64,
    // hits / (hits + misses), or 0.0 when no cache activity was recorded.
    pub cache_hit_rate: f32,
    // Current adaptive batch-size estimate.
    pub optimal_batch_size: usize,
}
97
#[cfg(feature = "gpu")]
impl GpuMemoryPool {
    /// Create a pool for fixed-size buffers of `buffer_size` elements, reserving
    /// queue capacity for `initial_pool_size` recycled buffers.
    pub fn new(buffer_size: usize, initial_pool_size: usize) -> Self {
        GpuMemoryPool {
            available_buffers: VecDeque::with_capacity(initial_pool_size),
            buffer_size,
            total_allocated: AtomicU64::new(0),
            peak_usage: AtomicU64::new(0),
        }
    }

    /// Take a recycled buffer from the pool, if one is available.
    pub fn get_buffer(&mut self) -> Option<GpuArray<f32>> {
        self.available_buffers.pop_front()
    }

    /// Recycle a buffer back into the pool.
    ///
    /// Buffers whose length does not match the pool's `buffer_size` are simply
    /// dropped (auto-deallocated) rather than stored.
    pub fn return_buffer(&mut self, buffer: GpuArray<f32>) {
        if buffer.len() != self.buffer_size {
            // Size mismatch: let the buffer drop here.
            return;
        }
        self.available_buffers.push_back(buffer);
    }

    /// Snapshot of `(total allocated bytes, peak usage bytes)`.
    pub fn get_memory_stats(&self) -> (u64, u64) {
        let allocated = self.total_allocated.load(Ordering::Relaxed);
        let peak = self.peak_usage.load(Ordering::Relaxed);
        (allocated, peak)
    }
}
127
#[cfg(feature = "gpu")]
impl GpuEmbeddingAccelerator {
    /// Create a new enhanced GPU accelerator with memory pooling and adaptive batching.
    ///
    /// Note: currently a placeholder until `scirs2_linalg::gpu` is available, so
    /// no real device is opened and this never fails in practice.
    pub fn new(device_id: u32) -> Result<Self, GpuError> {
        let context = (); // Placeholder GpuContext

        // 1 MiB buffers, queue capacity reserved for 10 pooled buffers.
        let memory_pool = Arc::new(Mutex::new(GpuMemoryPool::new(1024 * 1024, 10)));

        let batch_config = AdaptiveBatchConfig {
            min_batch_size: 32,
            max_batch_size: 8192,
            target_gpu_utilization: 0.85,
            memory_usage_threshold: 0.8,
        };

        Ok(Self {
            context,
            device_id,
            memory_pool,
            batch_config,
            performance_stats: Arc::new(RwLock::new(GpuPerformanceStats::default())),
            optimal_batch_size: Arc::new(AtomicU64::new(512)), // reasonable starting default
        })
    }

    /// Get the optimal batch size based on recent performance feedback, bounded
    /// by the configured min/max and by `data_size`.
    ///
    /// The result is always at least `min_batch_size`, even for tiny or empty
    /// inputs: callers feed the result to `chunks`, which panics on 0, and
    /// `clamp` itself panics when min > max — the previous code hit exactly
    /// that when `data_size < min_batch_size`.
    pub async fn get_optimal_batch_size(&self, data_size: usize) -> usize {
        let optimal = self.optimal_batch_size.load(Ordering::Relaxed) as usize;
        let config_min = self.batch_config.min_batch_size;
        let config_max = self.batch_config.max_batch_size;

        // Keep the upper bound >= the lower bound so `clamp` never panics.
        let upper = config_max.min(data_size).max(config_min);
        optimal.clamp(config_min, upper)
    }

    /// Update the optimal batch size based on performance feedback.
    ///
    /// `performance_score` is expected in `[0, 1]`: > 0.8 grows the batch by
    /// 10%, < 0.5 shrinks it by 10%, otherwise it is unchanged. The result is
    /// clamped to the configured bounds.
    pub async fn update_batch_size_feedback(&self, _batch_size: usize, performance_score: f32) {
        let current_optimal = self.optimal_batch_size.load(Ordering::Relaxed) as usize;

        // Simple multiplicative-increase / multiplicative-decrease adaptation.
        let new_optimal = if performance_score > 0.8 {
            // Good performance, try larger batches
            (current_optimal as f32 * 1.1).round() as usize
        } else if performance_score < 0.5 {
            // Poor performance, try smaller batches
            (current_optimal as f32 * 0.9).round() as usize
        } else {
            current_optimal
        };

        let clamped_optimal = new_optimal.clamp(
            self.batch_config.min_batch_size,
            self.batch_config.max_batch_size,
        );

        self.optimal_batch_size
            .store(clamped_optimal as u64, Ordering::Relaxed);
    }

    /// GPU-accelerated batch L2 distance computation.
    ///
    /// Pairs the inputs element-wise (`vectors_a[i]` vs `vectors_b[i]`); the
    /// result length is the shorter of the two inputs.
    ///
    /// NOTE(review): the CPU fallback `batch_l2_distances` appears to compute
    /// all cross-pair distances instead (see `test_fallback_distance_computation`,
    /// which expects 4 results for 2x2 inputs) — confirm which contract callers
    /// expect before unifying the two paths.
    pub fn batch_l2_distances_gpu(
        &self,
        vectors_a: &[Array1<f64>],
        vectors_b: &[Array1<f64>],
    ) -> Result<Vec<f64>, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        // For now, use CPU implementation as fallback
        let mut distances = Vec::with_capacity(vectors_a.len());
        for (a, b) in vectors_a.iter().zip(vectors_b.iter()) {
            let dist: f64 = a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<f64>()
                .sqrt();
            distances.push(dist);
        }
        Ok(distances)
    }

    /// GPU-accelerated cosine similarity matrix.
    ///
    /// Returns an `n x n` symmetric matrix of cosine similarities.
    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available.
    pub fn cosine_similarity_matrix_gpu(
        &self,
        vectors: &[Array1<f64>],
    ) -> Result<Array2<f64>, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        // For now, use CPU implementation as fallback
        let n = vectors.len();
        let mut similarity_matrix = Array2::zeros((n, n));

        // Precompute all norms once (O(n*d)) instead of recomputing both norms
        // inside the O(n^2) pair loop, as the previous version did.
        let norms: Vec<f64> = vectors
            .iter()
            .map(|v| v.iter().map(|x| x * x).sum::<f64>().sqrt())
            .collect();

        for i in 0..n {
            // Cosine similarity is symmetric, so only the upper triangle is
            // computed and mirrored.
            for j in i..n {
                let dot: f64 = vectors[i]
                    .iter()
                    .zip(vectors[j].iter())
                    .map(|(a, b)| a * b)
                    .sum();
                // Small epsilon guards against division by zero for zero vectors.
                let sim = dot / (norms[i] * norms[j] + 1e-8);
                similarity_matrix[[i, j]] = sim;
                similarity_matrix[[j, i]] = sim;
            }
        }
        Ok(similarity_matrix)
    }

    /// GPU-accelerated gradient updates for large embedding matrices.
    ///
    /// Applies `emb -= lr * (grad + l2_reg * emb)` element-wise to each
    /// positionally matched embedding/gradient pair.
    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available.
    pub fn batch_gradient_update_gpu(
        &self,
        embeddings: &mut [Array2<f64>],
        gradients: &[Array2<f64>],
        learning_rate: f64,
        l2_reg: f64,
    ) -> Result<(), GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        // For now, use CPU implementation as fallback
        for (embedding, gradient) in embeddings.iter_mut().zip(gradients.iter()) {
            // Apply gradient update with L2 regularization folded into the gradient.
            for (emb, grad) in embedding.iter_mut().zip(gradient.iter()) {
                *emb -= learning_rate * (grad + l2_reg * *emb);
            }
        }
        Ok(())
    }

    /// Advanced adaptive batch processing with performance feedback.
    ///
    /// Splits `data` into chunks of the current optimal batch size, runs
    /// `process_fn` on each chunk, then feeds the observed utilization back
    /// into the adaptive batch-size estimate and the performance statistics.
    pub async fn adaptive_batch_processing<T, R>(
        &self,
        data: &[T],
        mut process_fn: impl FnMut(&[T]) -> Result<Vec<R>, GpuError>,
    ) -> Result<Vec<R>, GpuError> {
        let start_time = Instant::now();
        let batch_size = self.get_optimal_batch_size(data.len()).await;

        let mut results = Vec::with_capacity(data.len());
        let mut total_processing_time = Duration::ZERO;

        for chunk in data.chunks(batch_size) {
            let chunk_start = Instant::now();
            let chunk_results = process_fn(chunk)?;
            total_processing_time += chunk_start.elapsed();
            results.extend(chunk_results);
        }

        // Utilization = fraction of wall time spent inside process_fn. Guard the
        // division: an empty/instantaneous run would otherwise produce 0/0 = NaN.
        let total_time = start_time.elapsed();
        let gpu_utilization = if total_time.as_secs_f32() > 0.0 {
            total_processing_time.as_secs_f32() / total_time.as_secs_f32()
        } else {
            0.0
        };
        let performance_score = gpu_utilization.min(1.0);

        self.update_batch_size_feedback(batch_size, performance_score)
            .await;

        // Update performance statistics
        let mut stats = self
            .performance_stats
            .write()
            .expect("lock should not be poisoned");
        stats.total_operations += 1;
        stats.total_compute_time += total_time;
        stats.gpu_utilization_percentage = gpu_utilization * 100.0;
        // Incremental running mean over all operations; the previous formula
        // `(avg + batch) / 2` only averaged with the latest value and was biased
        // toward recent batches (and toward the initial 0).
        let ops = stats.total_operations as f32;
        stats.average_batch_size += (batch_size as f32 - stats.average_batch_size) / ops;

        Ok(results)
    }

    /// GPU-accelerated matrix multiplication with memory reuse.
    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available.
    pub async fn optimized_matrix_multiply(
        &self,
        a: &Array2<f32>,
        b: &Array2<f32>,
    ) -> Result<Array2<f32>, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        // For now, use CPU implementation as fallback
        Ok(a.dot(b))
    }

    /// High-performance embedding search.
    ///
    /// Scores every database embedding against `query_embedding` by cosine
    /// similarity (processed in adaptive batches) and returns the best `top_k`
    /// `(database index, similarity)` pairs, highest similarity first.
    pub async fn gpu_embedding_search(
        &self,
        query_embedding: &Array1<f32>,
        database_embeddings: &[Array1<f32>],
        top_k: usize,
    ) -> Result<Vec<(usize, f32)>, GpuError> {
        // Use adaptive batching for large databases
        let batch_size = self.get_optimal_batch_size(database_embeddings.len()).await;
        let mut all_similarities = Vec::with_capacity(database_embeddings.len());

        // Process in adaptive batches
        for (batch_idx, batch) in database_embeddings.chunks(batch_size).enumerate() {
            let similarities = self
                .compute_batch_similarities(query_embedding, batch)
                .await?;

            // Translate the index within this batch back to a database index.
            for (local_idx, similarity) in similarities.iter().enumerate() {
                let global_idx = batch_idx * batch_size + local_idx;
                all_similarities.push((global_idx, *similarity));
            }
        }

        // Sort descending by similarity (NaNs compare equal) and keep top-k.
        all_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        all_similarities.truncate(top_k);

        Ok(all_similarities)
    }

    /// Compute cosine similarities between `query` and every vector in `batch`.
    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available.
    async fn compute_batch_similarities(
        &self,
        query: &Array1<f32>,
        batch: &[Array1<f32>],
    ) -> Result<Vec<f32>, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        // For now, use CPU implementation as fallback.
        // The query norm is loop-invariant — compute it once, not per element.
        let norm_query: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();

        let mut similarities = Vec::with_capacity(batch.len());
        for emb in batch {
            // Cosine similarity: (a . b) / (||a|| * ||b||), epsilon-guarded.
            let dot_product: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
            let norm_emb: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            similarities.push(dot_product / (norm_query * norm_emb + 1e-8));
        }

        Ok(similarities)
    }

    /// Xavier (Glorot) uniform initialization for a batch of embedding matrices.
    ///
    /// Samples each entry uniformly from `[-limit, limit)` with
    /// `limit = sqrt(6 / (fan_in + fan_out))`, seeded for reproducibility.
    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available.
    pub fn xavier_init_gpu(
        &self,
        shapes: &[(usize, usize)],
        fan_in: usize,
        fan_out: usize,
        seed: u64,
    ) -> Result<Vec<Array2<f64>>, GpuError> {
        use scirs2_core::random::Random;

        let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
        let mut rng = Random::seed(seed);
        // scale/shift maps a unit-interval sample onto [-limit, limit).
        let scale = 2.0 * limit;

        let mut results = Vec::with_capacity(shapes.len());
        for &(rows, cols) in shapes {
            // Generate uniform random numbers in [-limit, limit]
            let data: Vec<f64> = (0..rows * cols)
                .map(|_| rng.random_f64() * scale - limit)
                .collect();
            let array = Array2::from_shape_vec((rows, cols), data)
                .map_err(|e| GpuError(format!("Failed to create array: {}", e)))?;
            results.push(array);
        }
        Ok(results)
    }

    /// GPU-accelerated contrastive learning updates (not yet implemented).
    ///
    /// Currently a no-op that returns a placeholder loss of 0.0.
    pub fn contrastive_learning_gpu(
        &self,
        _entity_embeddings: &mut [Array1<f32>],
        _similarity_pairs: &[(usize, usize)],
        _negative_samples: &[(usize, usize)],
        _temperature: f32,
        _learning_rate: f32,
    ) -> Result<f32, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        // For now, return placeholder loss
        Ok(0.0)
    }

    /// Helper to upload f64 vectors to the GPU (placeholder: empty buffer).
    // Kept (currently unused) for the future scirs2_linalg::gpu integration.
    #[allow(dead_code)]
    fn upload_vectors_to_gpu(&self, _vectors: &[Array1<f64>]) -> Result<GpuArray<f64>, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        Ok(Vec::new())
    }

    /// Helper to upload f32 vectors to the GPU (placeholder: empty buffer).
    // Kept (currently unused) for the future scirs2_linalg::gpu integration.
    #[allow(dead_code)]
    fn upload_f32_vectors_to_gpu(
        &self,
        _vectors: &[Array1<f32>],
    ) -> Result<GpuArray<f32>, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        Ok(Vec::new())
    }

    /// Get a human-readable description of the GPU device.
    pub fn device_info(&self) -> String {
        format!(
            "GPU Device {} (placeholder - scirs2_linalg::gpu not yet available)",
            self.device_id
        )
    }

    /// Get available GPU memory in bytes.
    /// Note: placeholder — always 0 until scirs2_linalg::gpu is available.
    pub fn available_memory(&self) -> Result<u64, GpuError> {
        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
        Ok(0)
    }

    /// Build a snapshot report of GPU memory and performance statistics.
    pub async fn get_performance_report(&self) -> GpuPerformanceReport {
        let stats = self
            .performance_stats
            .read()
            .expect("lock should not be poisoned");
        let (allocated, peak) = {
            let pool = self
                .memory_pool
                .lock()
                .expect("lock should not be poisoned");
            pool.get_memory_stats()
        };

        GpuPerformanceReport {
            device_id: self.device_id,
            total_operations: stats.total_operations,
            average_compute_time: if stats.total_operations > 0 {
                // Duration division takes a u32; saturate instead of silently
                // truncating for (unrealistically) huge operation counts.
                let ops = u32::try_from(stats.total_operations).unwrap_or(u32::MAX);
                stats.total_compute_time / ops
            } else {
                Duration::ZERO
            },
            gpu_utilization: stats.gpu_utilization_percentage,
            memory_allocated_mb: allocated as f64 / (1024.0 * 1024.0),
            memory_peak_mb: peak as f64 / (1024.0 * 1024.0),
            cache_hit_rate: if stats.cache_hits + stats.cache_misses > 0 {
                stats.cache_hits as f32 / (stats.cache_hits + stats.cache_misses) as f32
            } else {
                0.0
            },
            optimal_batch_size: self.optimal_batch_size.load(Ordering::Relaxed) as usize,
        }
    }

    /// Reset performance statistics and the adaptive batch size to defaults.
    pub fn reset_performance_stats(&self) {
        let mut stats = self
            .performance_stats
            .write()
            .expect("lock should not be poisoned");
        *stats = GpuPerformanceStats::default();
        self.optimal_batch_size.store(512, Ordering::Relaxed);
    }

    /// Get current memory pool status as
    /// `(pooled buffer count, allocated bytes, peak bytes)`.
    pub fn get_memory_pool_status(&self) -> (usize, u64, u64) {
        let pool = self
            .memory_pool
            .lock()
            .expect("lock should not be poisoned");
        let (allocated, peak) = pool.get_memory_stats();
        (pool.available_buffers.len(), allocated, peak)
    }
}
495
/// CPU fallback implementations when GPU is not available
#[cfg(not(feature = "gpu"))]
use scirs2_core::random::Random;

// Zero-sized stand-in that mirrors the GPU accelerator's API on CPU-only builds.
#[cfg(not(feature = "gpu"))]
pub struct GpuEmbeddingAccelerator;
502
503#[cfg(not(feature = "gpu"))]
504impl GpuEmbeddingAccelerator {
505    pub fn new(_device_id: u32) -> Result<Self> {
506        Ok(Self)
507    }
508
509    /// Fallback to CPU implementation
510    pub fn batch_l2_distances_gpu(
511        &self,
512        vectors_a: &[Array1<f64>],
513        vectors_b: &[Array1<f64>],
514    ) -> Result<Vec<f64>> {
515        Ok(batch_l2_distances(vectors_a, vectors_b))
516    }
517
518    /// Fallback to CPU implementation
519    pub fn cosine_similarity_matrix_gpu(&self, vectors: &[Array1<f64>]) -> Result<Array2<f64>> {
520        Ok(pairwise_distances(vectors))
521    }
522
523    /// Fallback to CPU implementation
524    pub fn batch_gradient_update_gpu(
525        &self,
526        embeddings: &mut [Array2<f64>],
527        gradients: &[Array2<f64>],
528        learning_rate: f64,
529        l2_reg: f64,
530    ) -> Result<()> {
531        batch_gradient_update(embeddings, gradients, learning_rate, l2_reg);
532        Ok(())
533    }
534
535    /// Fallback to CPU implementation
536    pub fn xavier_init_gpu(
537        &self,
538        shapes: &[(usize, usize)],
539        fan_in: usize,
540        fan_out: usize,
541        _seed: u64,
542    ) -> Result<Vec<Array2<f64>>> {
543        let mut rng = Random::default();
544        Ok(batch_xavier_init(shapes, fan_in, fan_out, &mut rng))
545    }
546
547    pub fn device_info(&self) -> String {
548        "CPU (GPU acceleration not available)".to_string()
549    }
550
551    pub fn available_memory(&self) -> Result<u64> {
552        // Return available system RAM as approximation
553        Ok(8 * 1024 * 1024 * 1024) // 8GB default
554    }
555}
556
/// Adaptive acceleration that chooses between GPU and CPU based on problem size
pub struct AdaptiveEmbeddingAccelerator {
    // Present only when a device id was supplied and GPU initialization succeeded.
    gpu_accelerator: Option<GpuEmbeddingAccelerator>,
    // Minimum problem size (element count) before the GPU path is attempted.
    gpu_threshold: usize,
}
562
563impl AdaptiveEmbeddingAccelerator {
564    /// Create adaptive accelerator with optional GPU support
565    pub fn new(device_id: Option<u32>, gpu_threshold: usize) -> Result<Self> {
566        #[allow(unused_variables)]
567        let gpu_accelerator = if let Some(id) = device_id {
568            #[cfg(feature = "gpu")]
569            {
570                GpuEmbeddingAccelerator::new(id).ok()
571            }
572            #[cfg(not(feature = "gpu"))]
573            {
574                None
575            }
576        } else {
577            None
578        };
579
580        Ok(Self {
581            gpu_accelerator,
582            gpu_threshold,
583        })
584    }
585
586    /// Intelligently choose between GPU and CPU for distance computation
587    pub fn adaptive_batch_distances(
588        &self,
589        vectors_a: &[Array1<f64>],
590        vectors_b: &[Array1<f64>],
591    ) -> Result<Vec<f64>> {
592        if self.should_use_gpu(vectors_a.len() * vectors_b.len()) {
593            if let Some(ref gpu) = self.gpu_accelerator {
594                return gpu
595                    .batch_l2_distances_gpu(vectors_a, vectors_b)
596                    .map_err(|e| anyhow::anyhow!("GPU error: {:?}", e));
597            }
598        }
599
600        // Fallback to optimized CPU implementation
601        Ok(batch_l2_distances(vectors_a, vectors_b))
602    }
603
604    /// Intelligently choose between GPU and CPU for gradient updates
605    pub fn adaptive_gradient_update(
606        &self,
607        embeddings: &mut [Array2<f64>],
608        gradients: &[Array2<f64>],
609        learning_rate: f64,
610        l2_reg: f64,
611    ) -> Result<()> {
612        let total_elements: usize = embeddings.iter().map(|e| e.len()).sum();
613
614        if self.should_use_gpu(total_elements) {
615            if let Some(ref gpu) = self.gpu_accelerator {
616                return gpu
617                    .batch_gradient_update_gpu(embeddings, gradients, learning_rate, l2_reg)
618                    .map_err(|e| anyhow::anyhow!("GPU error: {:?}", e));
619            }
620        }
621
622        // Fallback to optimized CPU implementation
623        batch_gradient_update(embeddings, gradients, learning_rate, l2_reg);
624        Ok(())
625    }
626
627    /// Check if GPU should be used based on problem size
628    fn should_use_gpu(&self, problem_size: usize) -> bool {
629        self.gpu_accelerator.is_some() && problem_size >= self.gpu_threshold
630    }
631
632    /// Get acceleration info
633    pub fn info(&self) -> String {
634        match &self.gpu_accelerator {
635            Some(gpu) => format!(
636                "Adaptive: {} (threshold: {})",
637                gpu.device_info(),
638                self.gpu_threshold
639            ),
640            None => format!("Adaptive: CPU only (threshold: {})", self.gpu_threshold),
641        }
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    // Without a device id the adaptive accelerator must report CPU-only mode.
    #[test]
    fn test_adaptive_accelerator_creation() {
        let accelerator = AdaptiveEmbeddingAccelerator::new(None, 1000).unwrap();
        assert!(accelerator.info().contains("CPU only"));
    }

    // The CPU fallback distance helper returns one distance per (a, b)
    // combination — the full cross product (2 x 2 = 4 results here).
    #[test]
    fn test_fallback_distance_computation() {
        let accelerator = AdaptiveEmbeddingAccelerator::new(None, 1000).unwrap();

        let vectors_a = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let vectors_b = vec![
            Array1::from_vec(vec![7.0, 8.0, 9.0]),
            Array1::from_vec(vec![10.0, 11.0, 12.0]),
        ];

        let distances = accelerator
            .adaptive_batch_distances(&vectors_a, &vectors_b)
            .unwrap();
        assert_eq!(distances.len(), 4); // 2x2 combinations
    }

    // A gradient step with non-zero gradients must change the embeddings.
    #[test]
    fn test_fallback_gradient_update() {
        let accelerator = AdaptiveEmbeddingAccelerator::new(None, 1000).unwrap();

        let mut embeddings = vec![Array2::zeros((2, 3))];
        let gradients = vec![Array2::ones((2, 3))];

        accelerator
            .adaptive_gradient_update(&mut embeddings, &gradients, 0.01, 0.001)
            .unwrap();

        // Check that gradients were applied
        assert!(embeddings[0][[0, 0]] != 0.0);
    }

    // Smoke test for the GPU build; tolerates missing hardware by design.
    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_accelerator_creation() {
        // This test will only run when GPU features are enabled
        match GpuEmbeddingAccelerator::new(0) {
            Ok(gpu) => {
                println!("GPU Accelerator: {}", gpu.device_info());
                let memory = gpu.available_memory().unwrap_or(0);
                println!("Available GPU Memory: {} MB", memory / (1024 * 1024));
            }
            Err(_) => {
                println!("GPU not available for testing");
            }
        }
    }
}
704}