Skip to main content

oxirs_embed/acceleration/
gpu.rs

1//! GPU Acceleration for Embedding Computations
2//!
3//! This module provides GPU-accelerated implementations of embedding operations
4//! using scirs2-linalg GPU features for CUDA, OpenCL, ROCm, and Metal backends.
5
6use crate::models::common::*;
7use anyhow::Result;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9#[cfg(feature = "gpu")]
10use std::collections::VecDeque;
11#[cfg(feature = "gpu")]
12use std::sync::atomic::{AtomicU64, Ordering};
13#[cfg(feature = "gpu")]
14use std::sync::{Arc, Mutex, RwLock};
15#[cfg(feature = "gpu")]
16use std::time::{Duration, Instant};
17
// Placeholder GPU types.
//
// TODO: `scirs2_linalg::gpu` is not yet available. Once that module is
// implemented, replace the aliases below with:
//     use scirs2_linalg::gpu::{GpuArray, GpuContext, GpuError};

/// Stand-in for a device-resident array; currently a plain host `Vec<T>`.
#[cfg(feature = "gpu")]
pub type GpuArray<T> = Vec<T>;
/// Stand-in for a GPU context handle; currently the unit type.
#[cfg(feature = "gpu")]
pub type GpuContext = ();
/// Errors reported by the (placeholder) GPU acceleration layer.
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub enum GpuError {
    /// The GPU backend needed by the requested operation does not exist yet.
    ///
    /// Every operation that is still waiting on `scirs2_linalg::gpu`
    /// reports this variant instead of fabricating a result.
    BackendUnavailable {
        /// What is missing, in human-readable form.
        reason: String,
        /// What the caller can do instead.
        fallback: String,
    },
    /// Any other GPU-layer failure, carried as a free-form message.
    Other(String),
}
42
/// User-facing rendering of [`GpuError`]: one line per variant.
#[cfg(feature = "gpu")]
impl std::fmt::Display for GpuError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Other(msg) => {
                f.write_str("GPU error: ")?;
                f.write_str(msg)
            }
            Self::BackendUnavailable { reason, fallback } => {
                f.write_str("GPU backend unavailable: ")?;
                f.write_str(reason)?;
                f.write_str(". Fallback: ")?;
                f.write_str(fallback)
            }
        }
    }
}

// No underlying source error is carried, so the default trait methods suffice.
#[cfg(feature = "gpu")]
impl std::error::Error for GpuError {}
57
/// Pool of reusable GPU buffers (host `Vec<f32>` while the backend is a placeholder).
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Idle buffers, handed out front-first and returned at the back.
    available_buffers: VecDeque<GpuArray<f32>>,
    /// Length (element count) a buffer must have to be accepted back into the pool.
    buffer_size: usize,
    /// Allocation counter reported by `get_memory_stats`; initialised to 0 and
    /// not modified by any code visible in this file.
    total_allocated: AtomicU64,
    /// Peak-usage counter reported by `get_memory_stats`; initialised to 0 and
    /// not modified by any code visible in this file.
    peak_usage: AtomicU64,
}
67
/// Bounds and targets used when adapting the batch size.
#[cfg(feature = "gpu")]
#[derive(Clone, Debug)]
pub struct AdaptiveBatchConfig {
    /// Lower bound applied when the learned batch size is clamped.
    pub min_batch_size: usize,
    /// Upper bound applied when the learned batch size is clamped.
    pub max_batch_size: usize,
    /// Desired utilisation target; stored only, not read by code in this file.
    pub target_gpu_utilization: f32,
    /// Memory usage threshold; stored only, not read by code in this file.
    pub memory_usage_threshold: f32,
}
77
/// Embedding accelerator with a buffer pool, adaptive batch sizing and
/// performance bookkeeping. All compute currently runs on the CPU.
#[cfg(feature = "gpu")]
pub struct GpuEmbeddingAccelerator {
    /// Placeholder context (unit) until a real GPU context type exists.
    context: GpuContext,
    /// Device index supplied at construction; only echoed in reports.
    device_id: u32,
    /// Shared buffer pool.
    memory_pool: Arc<Mutex<GpuMemoryPool>>,
    /// Bounds for the adaptive batch size.
    batch_config: AdaptiveBatchConfig,
    /// Running statistics updated by `adaptive_batch_processing`.
    performance_stats: Arc<RwLock<GpuPerformanceStats>>,
    /// Current learned batch size.
    optimal_batch_size: Arc<AtomicU64>,
}
88
/// Running counters describing accelerator activity.
#[cfg(feature = "gpu")]
#[derive(Debug, Default)]
pub struct GpuPerformanceStats {
    /// Number of completed `adaptive_batch_processing` calls.
    pub total_operations: u64,
    /// Wall-clock time summed over those calls.
    pub total_compute_time: Duration,
    /// Transfer counter; not updated by any code visible in this file.
    pub memory_transfers: u64,
    /// Cache-hit counter; not updated by any code visible in this file.
    pub cache_hits: u64,
    /// Cache-miss counter; not updated by any code visible in this file.
    pub cache_misses: u64,
    /// Batch-size average maintained by `adaptive_batch_processing`.
    pub average_batch_size: f32,
    /// Utilisation of the most recent operation, in percent.
    pub gpu_utilization_percentage: f32,
}
101
/// Snapshot produced by `GpuEmbeddingAccelerator::get_performance_report`.
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct GpuPerformanceReport {
    /// Device index the accelerator was created with.
    pub device_id: u32,
    /// Number of recorded operations.
    pub total_operations: u64,
    /// Total compute time divided by the operation count (zero when none ran).
    pub average_compute_time: Duration,
    /// Last recorded utilisation, in percent.
    pub gpu_utilization: f32,
    /// Pool allocation counter converted to MiB.
    pub memory_allocated_mb: f64,
    /// Pool peak counter converted to MiB.
    pub memory_peak_mb: f64,
    /// `hits / (hits + misses)`, or 0.0 when both are zero.
    pub cache_hit_rate: f32,
    /// Learned batch size at the time of the snapshot.
    pub optimal_batch_size: usize,
}
115
#[cfg(feature = "gpu")]
impl GpuMemoryPool {
    /// Create an empty pool that accepts buffers of exactly `buffer_size`
    /// elements. `initial_pool_size` only reserves queue capacity; no buffers
    /// are allocated up front.
    pub fn new(buffer_size: usize, initial_pool_size: usize) -> Self {
        let available_buffers = VecDeque::with_capacity(initial_pool_size);
        Self {
            available_buffers,
            buffer_size,
            total_allocated: AtomicU64::new(0),
            peak_usage: AtomicU64::new(0),
        }
    }

    /// Take the oldest idle buffer, or `None` when the pool is empty.
    pub fn get_buffer(&mut self) -> Option<GpuArray<f32>> {
        self.available_buffers.pop_front()
    }

    /// Hand a buffer back to the pool.
    ///
    /// Buffers whose length differs from the pool's buffer size are not
    /// retained; they are simply dropped here.
    pub fn return_buffer(&mut self, buffer: GpuArray<f32>) {
        if buffer.len() != self.buffer_size {
            return;
        }
        self.available_buffers.push_back(buffer);
    }

    /// Return `(total_allocated, peak_usage)` as currently stored.
    pub fn get_memory_stats(&self) -> (u64, u64) {
        let allocated = self.total_allocated.load(Ordering::Relaxed);
        let peak = self.peak_usage.load(Ordering::Relaxed);
        (allocated, peak)
    }
}
145
146#[cfg(feature = "gpu")]
147impl GpuEmbeddingAccelerator {
148    /// Create a new enhanced GPU accelerator with memory pooling and adaptive batching
149    /// Note: Currently using placeholder until scirs2_linalg::gpu is available
150    pub fn new(device_id: u32) -> Result<Self, GpuError> {
151        let context = (); // Placeholder GpuContext
152
153        let memory_pool = Arc::new(Mutex::new(GpuMemoryPool::new(1024 * 1024, 10))); // 1MB buffers, 10 initial
154
155        let batch_config = AdaptiveBatchConfig {
156            min_batch_size: 32,
157            max_batch_size: 8192,
158            target_gpu_utilization: 0.85,
159            memory_usage_threshold: 0.8,
160        };
161
162        Ok(Self {
163            context,
164            device_id,
165            memory_pool,
166            batch_config,
167            performance_stats: Arc::new(RwLock::new(GpuPerformanceStats::default())),
168            optimal_batch_size: Arc::new(AtomicU64::new(512)), // Start with reasonable default
169        })
170    }
171
172    /// Get optimal batch size based on recent performance
173    pub async fn get_optimal_batch_size(&self, data_size: usize) -> usize {
174        let optimal = self.optimal_batch_size.load(Ordering::Relaxed) as usize;
175        let config_min = self.batch_config.min_batch_size;
176        let config_max = self.batch_config.max_batch_size;
177
178        // Clamp to configuration bounds and data size
179        optimal.clamp(config_min, config_max.min(data_size))
180    }
181
182    /// Update optimal batch size based on performance feedback
183    pub async fn update_batch_size_feedback(&self, _batch_size: usize, performance_score: f32) {
184        let current_optimal = self.optimal_batch_size.load(Ordering::Relaxed) as usize;
185
186        // Simple adaptive algorithm: increase if performance is good, decrease if poor
187        let new_optimal = if performance_score > 0.8 {
188            // Good performance, try larger batches
189            (current_optimal as f32 * 1.1).round() as usize
190        } else if performance_score < 0.5 {
191            // Poor performance, try smaller batches
192            (current_optimal as f32 * 0.9).round() as usize
193        } else {
194            current_optimal
195        };
196
197        let clamped_optimal = new_optimal.clamp(
198            self.batch_config.min_batch_size,
199            self.batch_config.max_batch_size,
200        );
201
202        self.optimal_batch_size
203            .store(clamped_optimal as u64, Ordering::Relaxed);
204    }
205
206    /// GPU-accelerated batch distance computation
207    pub fn batch_l2_distances_gpu(
208        &self,
209        vectors_a: &[Array1<f64>],
210        vectors_b: &[Array1<f64>],
211    ) -> Result<Vec<f64>, GpuError> {
212        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
213        // For now, use CPU implementation as fallback
214        let mut distances = Vec::with_capacity(vectors_a.len());
215        for (a, b) in vectors_a.iter().zip(vectors_b.iter()) {
216            let dist: f64 = a
217                .iter()
218                .zip(b.iter())
219                .map(|(x, y)| (x - y).powi(2))
220                .sum::<f64>()
221                .sqrt();
222            distances.push(dist);
223        }
224        Ok(distances)
225    }
226
227    /// GPU-accelerated cosine similarity matrix
228    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available
229    pub fn cosine_similarity_matrix_gpu(
230        &self,
231        vectors: &[Array1<f64>],
232    ) -> Result<Array2<f64>, GpuError> {
233        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
234        // For now, use CPU implementation as fallback
235        use scirs2_core::ndarray_ext::Array2;
236
237        let n = vectors.len();
238        let mut similarity_matrix = Array2::zeros((n, n));
239
240        for i in 0..n {
241            for j in 0..n {
242                let dot: f64 = vectors[i]
243                    .iter()
244                    .zip(vectors[j].iter())
245                    .map(|(a, b)| a * b)
246                    .sum();
247                let norm_i: f64 = vectors[i].iter().map(|x| x * x).sum::<f64>().sqrt();
248                let norm_j: f64 = vectors[j].iter().map(|x| x * x).sum::<f64>().sqrt();
249                similarity_matrix[[i, j]] = dot / (norm_i * norm_j + 1e-8);
250            }
251        }
252        Ok(similarity_matrix)
253    }
254
255    /// GPU-accelerated gradient updates for large embedding matrices
256    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available
257    pub fn batch_gradient_update_gpu(
258        &self,
259        embeddings: &mut [Array2<f64>],
260        gradients: &[Array2<f64>],
261        learning_rate: f64,
262        l2_reg: f64,
263    ) -> Result<(), GpuError> {
264        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
265        // For now, use CPU implementation as fallback
266        for (embedding, gradient) in embeddings.iter_mut().zip(gradients.iter()) {
267            // Apply gradient update with L2 regularization
268            for (emb, grad) in embedding.iter_mut().zip(gradient.iter()) {
269                *emb -= learning_rate * (grad + l2_reg * *emb);
270            }
271        }
272        Ok(())
273    }
274
275    /// Advanced GPU-accelerated adaptive batch processing with memory pooling
276    pub async fn adaptive_batch_processing<T, R>(
277        &self,
278        data: &[T],
279        mut process_fn: impl FnMut(&[T]) -> Result<Vec<R>, GpuError>,
280    ) -> Result<Vec<R>, GpuError> {
281        let start_time = Instant::now();
282        let batch_size = self.get_optimal_batch_size(data.len()).await;
283
284        let mut results = Vec::with_capacity(data.len());
285        let mut total_processing_time = Duration::ZERO;
286
287        for chunk in data.chunks(batch_size) {
288            let chunk_start = Instant::now();
289            let chunk_results = process_fn(chunk)?;
290            let chunk_time = chunk_start.elapsed();
291
292            results.extend(chunk_results);
293            total_processing_time += chunk_time;
294        }
295
296        // Calculate performance score and update batch size
297        let total_time = start_time.elapsed();
298        let gpu_utilization = total_processing_time.as_secs_f32() / total_time.as_secs_f32();
299        let performance_score = gpu_utilization.min(1.0);
300
301        self.update_batch_size_feedback(batch_size, performance_score)
302            .await;
303
304        // Update performance statistics
305        let mut stats = self
306            .performance_stats
307            .write()
308            .expect("lock should not be poisoned");
309        stats.total_operations += 1;
310        stats.total_compute_time += total_time;
311        stats.gpu_utilization_percentage = gpu_utilization * 100.0;
312        stats.average_batch_size = (stats.average_batch_size + batch_size as f32) / 2.0;
313
314        Ok(results)
315    }
316
317    /// GPU-accelerated matrix multiplication with memory reuse
318    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available
319    pub async fn optimized_matrix_multiply(
320        &self,
321        a: &Array2<f32>,
322        b: &Array2<f32>,
323    ) -> Result<Array2<f32>, GpuError> {
324        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
325        // For now, use CPU implementation as fallback
326        let result = a.dot(b);
327
328        Ok(result)
329    }
330
331    /// High-performance embedding search with GPU acceleration
332    pub async fn gpu_embedding_search(
333        &self,
334        query_embedding: &Array1<f32>,
335        database_embeddings: &[Array1<f32>],
336        top_k: usize,
337    ) -> Result<Vec<(usize, f32)>, GpuError> {
338        // Use adaptive batching for large databases
339        let batch_size = self.get_optimal_batch_size(database_embeddings.len()).await;
340        let mut all_similarities = Vec::with_capacity(database_embeddings.len());
341
342        // Process in adaptive batches
343        for (batch_idx, batch) in database_embeddings.chunks(batch_size).enumerate() {
344            let similarities = self
345                .compute_batch_similarities(query_embedding, batch)
346                .await?;
347
348            for (local_idx, similarity) in similarities.iter().enumerate() {
349                let global_idx = batch_idx * batch_size + local_idx;
350                all_similarities.push((global_idx, *similarity));
351            }
352        }
353
354        // Sort and return top-k
355        all_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
356        all_similarities.truncate(top_k);
357
358        Ok(all_similarities)
359    }
360
361    /// Compute similarities for a batch with GPU acceleration
362    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available
363    async fn compute_batch_similarities(
364        &self,
365        query: &Array1<f32>,
366        batch: &[Array1<f32>],
367    ) -> Result<Vec<f32>, GpuError> {
368        // TODO: Use actual GPU operations when scirs2_linalg::gpu is available
369        // For now, use CPU implementation as fallback
370        let mut similarities = Vec::with_capacity(batch.len());
371
372        for emb in batch {
373            // Compute cosine similarity: (a ยท b) / (||a|| * ||b||)
374            let dot_product: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
375            let norm_query: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
376            let norm_emb: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
377            let similarity = dot_product / (norm_query * norm_emb + 1e-8);
378            similarities.push(similarity);
379        }
380
381        Ok(similarities)
382    }
383
384    /// GPU-accelerated Xavier initialization for large embedding matrices
385    /// Note: Currently using CPU fallback until scirs2_linalg::gpu is available
386    pub fn xavier_init_gpu(
387        &self,
388        shapes: &[(usize, usize)],
389        fan_in: usize,
390        fan_out: usize,
391        seed: u64,
392    ) -> Result<Vec<Array2<f64>>, GpuError> {
393        use scirs2_core::random::Random;
394
395        let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
396        let mut rng = Random::seed(seed);
397        let scale = 2.0 * limit;
398
399        let mut results = Vec::with_capacity(shapes.len());
400        for &(rows, cols) in shapes {
401            // Generate uniform random numbers in [-limit, limit]
402            let data: Vec<f64> = (0..rows * cols)
403                .map(|_| rng.random_f64() * scale - limit)
404                .collect();
405            let array = Array2::from_shape_vec((rows, cols), data)
406                .map_err(|e| GpuError::Other(format!("Failed to create array: {}", e)))?;
407            results.push(array);
408        }
409        Ok(results)
410    }
411
412    /// GPU-accelerated contrastive learning updates
413    ///
414    /// This operation requires `scirs2_linalg::gpu` which is not yet stable.
415    /// Returning `Ok(0.0)` would silently report zero loss even when the model
416    /// has not converged, masking training failures from callers.
417    pub fn contrastive_learning_gpu(
418        &self,
419        _entity_embeddings: &mut [Array1<f32>],
420        _similarity_pairs: &[(usize, usize)],
421        _negative_samples: &[(usize, usize)],
422        _temperature: f32,
423        _learning_rate: f32,
424    ) -> Result<f32, GpuError> {
425        Err(GpuError::BackendUnavailable {
426            reason: "contrastive_learning_gpu requires scirs2_linalg::gpu \
427                     GPU tensor operations which are not yet stable"
428                .to_string(),
429            fallback: "implement contrastive learning on the CPU embedding arrays directly \
430                       without calling this method"
431                .to_string(),
432        })
433    }
434
435    /// Helper function to upload vectors to GPU.
436    ///
437    /// Returns an error rather than an empty `Vec` to surface the waiting state
438    /// to callers: an empty slice is indistinguishable from an empty upload, which
439    /// would silently produce wrong results in downstream GPU kernels.
440    fn upload_vectors_to_gpu(&self, _vectors: &[Array1<f64>]) -> Result<GpuArray<f64>, GpuError> {
441        Err(GpuError::BackendUnavailable {
442            reason: "upload_vectors_to_gpu requires scirs2_linalg::gpu \
443                     GPU memory transfer which is not yet stable"
444                .to_string(),
445            fallback: "operate on CPU Array1<f64> slices directly; \
446                       GPU upload is a no-op until the backend is available"
447                .to_string(),
448        })
449    }
450
451    /// Helper function to upload f32 vectors to GPU.
452    ///
453    /// Returns an error rather than an empty `Vec` to surface the waiting state
454    /// to callers.
455    fn upload_f32_vectors_to_gpu(
456        &self,
457        _vectors: &[Array1<f32>],
458    ) -> Result<GpuArray<f32>, GpuError> {
459        Err(GpuError::BackendUnavailable {
460            reason: "upload_f32_vectors_to_gpu requires scirs2_linalg::gpu \
461                     GPU memory transfer which is not yet stable"
462                .to_string(),
463            fallback: "operate on CPU Array1<f32> slices directly; \
464                       GPU upload is a no-op until the backend is available"
465                .to_string(),
466        })
467    }
468
469    /// Get GPU device info
470    pub fn device_info(&self) -> String {
471        format!(
472            "GPU Device {} (placeholder - scirs2_linalg::gpu not yet available)",
473            self.device_id
474        )
475    }
476
477    /// Get available GPU memory.
478    ///
479    /// Cannot return a meaningful value without `scirs2_linalg::gpu`; returning
480    /// `Ok(0)` would be interpreted by callers as "GPU has no memory" rather than
481    /// "GPU memory query is not yet implemented", causing incorrect adaptive batching.
482    pub fn available_memory(&self) -> Result<u64, GpuError> {
483        Err(GpuError::BackendUnavailable {
484            reason: "available_memory requires scirs2_linalg::gpu device query \
485                     which is not yet stable"
486                .to_string(),
487            fallback: "use GpuEmbeddingAccelerator::device_info() for a status string, \
488                       or check available system RAM via the non-GPU fallback"
489                .to_string(),
490        })
491    }
492
493    /// GPU memory and performance monitoring
494    pub async fn get_performance_report(&self) -> GpuPerformanceReport {
495        let stats = self
496            .performance_stats
497            .read()
498            .expect("lock should not be poisoned");
499        let (allocated, peak) = {
500            let pool = self
501                .memory_pool
502                .lock()
503                .expect("lock should not be poisoned");
504            pool.get_memory_stats()
505        };
506
507        GpuPerformanceReport {
508            device_id: self.device_id,
509            total_operations: stats.total_operations,
510            average_compute_time: if stats.total_operations > 0 {
511                stats.total_compute_time / stats.total_operations as u32
512            } else {
513                Duration::ZERO
514            },
515            gpu_utilization: stats.gpu_utilization_percentage,
516            memory_allocated_mb: allocated as f64 / (1024.0 * 1024.0),
517            memory_peak_mb: peak as f64 / (1024.0 * 1024.0),
518            cache_hit_rate: if stats.cache_hits + stats.cache_misses > 0 {
519                stats.cache_hits as f32 / (stats.cache_hits + stats.cache_misses) as f32
520            } else {
521                0.0
522            },
523            optimal_batch_size: self.optimal_batch_size.load(Ordering::Relaxed) as usize,
524        }
525    }
526
527    /// Reset performance statistics
528    pub fn reset_performance_stats(&self) {
529        let mut stats = self
530            .performance_stats
531            .write()
532            .expect("lock should not be poisoned");
533        *stats = GpuPerformanceStats::default();
534        self.optimal_batch_size.store(512, Ordering::Relaxed);
535    }
536
537    /// Get current memory pool status
538    pub fn get_memory_pool_status(&self) -> (usize, u64, u64) {
539        let pool = self
540            .memory_pool
541            .lock()
542            .expect("lock should not be poisoned");
543        let (allocated, peak) = pool.get_memory_stats();
544        (pool.available_buffers.len(), allocated, peak)
545    }
546}
547
548/// CPU fallback implementations when GPU is not available
549#[cfg(not(feature = "gpu"))]
550use scirs2_core::random::Random;
551
552#[cfg(not(feature = "gpu"))]
553pub struct GpuEmbeddingAccelerator;
554
555#[cfg(not(feature = "gpu"))]
556impl GpuEmbeddingAccelerator {
557    pub fn new(_device_id: u32) -> Result<Self> {
558        Ok(Self)
559    }
560
561    /// Fallback to CPU implementation
562    pub fn batch_l2_distances_gpu(
563        &self,
564        vectors_a: &[Array1<f64>],
565        vectors_b: &[Array1<f64>],
566    ) -> Result<Vec<f64>> {
567        Ok(batch_l2_distances(vectors_a, vectors_b))
568    }
569
570    /// Fallback to CPU implementation
571    pub fn cosine_similarity_matrix_gpu(&self, vectors: &[Array1<f64>]) -> Result<Array2<f64>> {
572        Ok(pairwise_distances(vectors))
573    }
574
575    /// Fallback to CPU implementation
576    pub fn batch_gradient_update_gpu(
577        &self,
578        embeddings: &mut [Array2<f64>],
579        gradients: &[Array2<f64>],
580        learning_rate: f64,
581        l2_reg: f64,
582    ) -> Result<()> {
583        batch_gradient_update(embeddings, gradients, learning_rate, l2_reg);
584        Ok(())
585    }
586
587    /// Fallback to CPU implementation
588    pub fn xavier_init_gpu(
589        &self,
590        shapes: &[(usize, usize)],
591        fan_in: usize,
592        fan_out: usize,
593        _seed: u64,
594    ) -> Result<Vec<Array2<f64>>> {
595        let mut rng = Random::default();
596        Ok(batch_xavier_init(shapes, fan_in, fan_out, &mut rng))
597    }
598
599    pub fn device_info(&self) -> String {
600        "CPU (GPU acceleration not available)".to_string()
601    }
602
603    pub fn available_memory(&self) -> Result<u64> {
604        // Return available system RAM as approximation
605        Ok(8 * 1024 * 1024 * 1024) // 8GB default
606    }
607}
608
609/// Adaptive acceleration that chooses between GPU and CPU based on problem size
610pub struct AdaptiveEmbeddingAccelerator {
611    gpu_accelerator: Option<GpuEmbeddingAccelerator>,
612    gpu_threshold: usize,
613}
614
615impl AdaptiveEmbeddingAccelerator {
616    /// Create adaptive accelerator with optional GPU support
617    pub fn new(device_id: Option<u32>, gpu_threshold: usize) -> Result<Self> {
618        #[allow(unused_variables)]
619        let gpu_accelerator = if let Some(id) = device_id {
620            #[cfg(feature = "gpu")]
621            {
622                GpuEmbeddingAccelerator::new(id).ok()
623            }
624            #[cfg(not(feature = "gpu"))]
625            {
626                None
627            }
628        } else {
629            None
630        };
631
632        Ok(Self {
633            gpu_accelerator,
634            gpu_threshold,
635        })
636    }
637
638    /// Intelligently choose between GPU and CPU for distance computation
639    pub fn adaptive_batch_distances(
640        &self,
641        vectors_a: &[Array1<f64>],
642        vectors_b: &[Array1<f64>],
643    ) -> Result<Vec<f64>> {
644        if self.should_use_gpu(vectors_a.len() * vectors_b.len()) {
645            if let Some(ref gpu) = self.gpu_accelerator {
646                return gpu
647                    .batch_l2_distances_gpu(vectors_a, vectors_b)
648                    .map_err(|e| anyhow::anyhow!("GPU error: {:?}", e));
649            }
650        }
651
652        // Fallback to optimized CPU implementation
653        Ok(batch_l2_distances(vectors_a, vectors_b))
654    }
655
656    /// Intelligently choose between GPU and CPU for gradient updates
657    pub fn adaptive_gradient_update(
658        &self,
659        embeddings: &mut [Array2<f64>],
660        gradients: &[Array2<f64>],
661        learning_rate: f64,
662        l2_reg: f64,
663    ) -> Result<()> {
664        let total_elements: usize = embeddings.iter().map(|e| e.len()).sum();
665
666        if self.should_use_gpu(total_elements) {
667            if let Some(ref gpu) = self.gpu_accelerator {
668                return gpu
669                    .batch_gradient_update_gpu(embeddings, gradients, learning_rate, l2_reg)
670                    .map_err(|e| anyhow::anyhow!("GPU error: {:?}", e));
671            }
672        }
673
674        // Fallback to optimized CPU implementation
675        batch_gradient_update(embeddings, gradients, learning_rate, l2_reg);
676        Ok(())
677    }
678
679    /// Check if GPU should be used based on problem size
680    fn should_use_gpu(&self, problem_size: usize) -> bool {
681        self.gpu_accelerator.is_some() && problem_size >= self.gpu_threshold
682    }
683
684    /// Get acceleration info
685    pub fn info(&self) -> String {
686        match &self.gpu_accelerator {
687            Some(gpu) => format!(
688                "Adaptive: {} (threshold: {})",
689                gpu.device_info(),
690                self.gpu_threshold
691            ),
692            None => format!("Adaptive: CPU only (threshold: {})", self.gpu_threshold),
693        }
694    }
695}
696
#[cfg(test)]
mod tests {
    use super::*;

    /// Accelerator with no device id, so every call takes the CPU path.
    fn cpu_only() -> AdaptiveEmbeddingAccelerator {
        AdaptiveEmbeddingAccelerator::new(None, 1000).expect("should succeed")
    }

    #[test]
    fn test_adaptive_accelerator_creation() {
        let info = cpu_only().info();
        assert!(info.contains("CPU only"));
    }

    #[test]
    fn test_fallback_distance_computation() {
        let accelerator = cpu_only();

        let vectors_a = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let vectors_b = vec![
            Array1::from_vec(vec![7.0, 8.0, 9.0]),
            Array1::from_vec(vec![10.0, 11.0, 12.0]),
        ];

        let distances = accelerator
            .adaptive_batch_distances(&vectors_a, &vectors_b)
            .expect("should succeed");

        // One distance per (a, b) combination: 2 x 2.
        assert_eq!(distances.len(), 4);
    }

    #[test]
    fn test_fallback_gradient_update() {
        let accelerator = cpu_only();

        let mut embeddings = vec![Array2::zeros((2, 3))];
        let gradients = vec![Array2::ones((2, 3))];

        accelerator
            .adaptive_gradient_update(&mut embeddings, &gradients, 0.01, 0.001)
            .expect("should succeed");

        // A non-zero gradient must have moved the zero-initialised weights.
        assert!(embeddings[0][[0, 0]] != 0.0);
    }

    /// Only compiled when GPU features are enabled.
    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_accelerator_creation() {
        let gpu = match GpuEmbeddingAccelerator::new(0) {
            Ok(gpu) => gpu,
            Err(_) => {
                println!("GPU not available for testing");
                return;
            }
        };

        println!("GPU Accelerator: {}", gpu.device_info());
        // available_memory surfaces a typed error while scirs2_linalg::gpu is pending
        match gpu.available_memory() {
            Ok(bytes) => println!("Available GPU Memory: {} MB", bytes / (1024 * 1024)),
            Err(e) => println!("GPU memory query pending: {}", e),
        }
    }
}