oxirs_embed/gpu_acceleration.rs

//! GPU acceleration and optimization features for embedding models
//!
//! This module provides advanced GPU acceleration capabilities including
//! memory pooling, tensor caching, mixed precision, and compute optimization.
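//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`: it assumes a Tokio runtime and a
//! user-supplied `compute_embedding` closure, neither of which is part of this module):
//!
//! ```ignore
//! let config = GpuAccelerationConfig::default();
//! let mut manager = GpuAccelerationManager::new(config);
//!
//! let entities = vec!["http://example.org/alice".to_string()];
//! let embeddings = manager
//!     .accelerated_embedding_generation(entities, compute_embedding)
//!     .await?;
//! assert_eq!(embeddings.len(), 1);
//! ```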

use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};

/// GPU acceleration configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuAccelerationConfig {
    /// Enable GPU acceleration
    pub enabled: bool,
    /// GPU device IDs to use
    pub device_ids: Vec<usize>,
    /// Memory pool size in MB
    pub memory_pool_size_mb: usize,
    /// Enable mixed precision
    pub mixed_precision: bool,
    /// Enable tensor caching
    pub tensor_caching: bool,
    /// Cache size in MB
    pub cache_size_mb: usize,
    /// Enable kernel fusion
    pub kernel_fusion: bool,
    /// Enable memory mapping
    pub memory_mapping: bool,
    /// Enable unified memory
    pub unified_memory: bool,
    /// Enable multi-stream processing
    pub multi_stream: bool,
    /// Number of streams for multi-stream processing
    pub num_streams: usize,
    /// Enable pipeline parallelism
    pub pipeline_parallelism: bool,
    /// Pipeline stages
    pub pipeline_stages: usize,
}

impl Default for GpuAccelerationConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            device_ids: vec![0],
            memory_pool_size_mb: 2048, // 2GB default
            mixed_precision: true,
            tensor_caching: true,
            cache_size_mb: 512, // 512MB cache
            kernel_fusion: true,
            memory_mapping: true,
            unified_memory: false, // Conservative default
            multi_stream: true,
            num_streams: 4,
            pipeline_parallelism: false, // Requires careful setup
            pipeline_stages: 2,
        }
    }
}

/// GPU memory pool for efficient memory management
pub struct GpuMemoryPool {
    config: GpuAccelerationConfig,
    allocated_blocks: Arc<Mutex<HashMap<usize, MemoryBlock>>>,
    free_blocks: Arc<Mutex<VecDeque<MemoryBlock>>>,
    total_allocated: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

/// Memory block descriptor
#[derive(Debug, Clone)]
struct MemoryBlock {
    device_id: usize,
    size_bytes: usize,
    ptr: usize, // In real implementation, this would be a GPU pointer
    allocated_at: Instant,
    last_used: Instant,
}

/// Memory allocation statistics
#[derive(Debug, Default, Clone)]
pub struct AllocationStats {
    pub total_allocations: usize,
    pub total_deallocations: usize,
    pub peak_memory_usage: usize,
    pub current_memory_usage: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
}

impl GpuMemoryPool {
    /// Create new GPU memory pool
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            allocated_blocks: Arc::new(Mutex::new(HashMap::new())),
            free_blocks: Arc::new(Mutex::new(VecDeque::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::default())),
        }
    }

    /// Allocate GPU memory block
    pub fn allocate(&self, size_bytes: usize, device_id: usize) -> Result<usize> {
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut allocated_blocks = self.allocated_blocks.lock().unwrap();
        let mut stats = self.allocation_stats.lock().unwrap();

        // Try to reuse a suitable free block first (find the index, then remove it,
        // so the deque is not borrowed immutably while being mutated)
        if let Some(i) = free_blocks
            .iter()
            .position(|block| block.size_bytes >= size_bytes && block.device_id == device_id)
        {
            let mut reused_block = free_blocks.remove(i).unwrap();
            let block_id = reused_block.ptr;
            reused_block.last_used = Instant::now();

            // Account for the reused block so usage stays balanced with deallocate()
            stats.current_memory_usage += reused_block.size_bytes;
            if stats.current_memory_usage > stats.peak_memory_usage {
                stats.peak_memory_usage = stats.current_memory_usage;
            }

            allocated_blocks.insert(block_id, reused_block);
            stats.cache_hits += 1;

            debug!(
                "Reused GPU memory block {} of size {}",
                block_id, size_bytes
            );
            return Ok(block_id);
        }

        // No suitable free block found, allocate a new one
        stats.cache_misses += 1;
        stats.total_allocations += 1;

        let block_id = stats.total_allocations; // Simple ID generation
        let now = Instant::now();

        let block = MemoryBlock {
            device_id,
            size_bytes,
            ptr: block_id,
            allocated_at: now,
            last_used: now,
        };

        allocated_blocks.insert(block_id, block);

        let mut total_allocated = self.total_allocated.lock().unwrap();
        *total_allocated += size_bytes;
        stats.current_memory_usage += size_bytes;

        if stats.current_memory_usage > stats.peak_memory_usage {
            stats.peak_memory_usage = stats.current_memory_usage;
        }

        info!(
            "Allocated new GPU memory block {} of size {} bytes",
            block_id, size_bytes
        );
        Ok(block_id)
    }

    /// Deallocate GPU memory block
    pub fn deallocate(&self, block_id: usize) -> Result<()> {
        // Acquire locks in the same order as `allocate` to avoid deadlocks
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut allocated_blocks = self.allocated_blocks.lock().unwrap();
        let mut stats = self.allocation_stats.lock().unwrap();

        if let Some(block) = allocated_blocks.remove(&block_id) {
            stats.total_deallocations += 1;
            stats.current_memory_usage = stats.current_memory_usage.saturating_sub(block.size_bytes);

            // Add to free blocks for reuse
            free_blocks.push_back(block);

            // Cap the free list so unused blocks don't accumulate indefinitely
            if free_blocks.len() > 100 {
                free_blocks.pop_front();
            }

            debug!("Deallocated GPU memory block {}", block_id);
            Ok(())
        } else {
            Err(anyhow!("Block {} not found for deallocation", block_id))
        }
    }

    /// Get allocation statistics
    pub fn get_stats(&self) -> AllocationStats {
        (*self.allocation_stats.lock().unwrap()).clone()
    }

    /// Defragment memory by consolidating free blocks
    pub fn defragment(&self) -> Result<()> {
        let mut free_blocks = self.free_blocks.lock().unwrap();

        // Sort free blocks by device and size
        let mut blocks: Vec<_> = free_blocks.drain(..).collect();
        blocks.sort_by_key(|b| (b.device_id, b.size_bytes));

        // Merge adjacent blocks (simplified implementation)
        let mut merged_blocks = VecDeque::new();
        let mut current_block: Option<MemoryBlock> = None;

        for block in blocks {
            if let Some(ref mut current) = current_block {
                if current.device_id == block.device_id {
                    // In a real implementation, we'd check if blocks are adjacent
                    current.size_bytes += block.size_bytes;
                } else {
                    merged_blocks.push_back(current.clone());
                    current_block = Some(block);
                }
            } else {
                current_block = Some(block);
            }
        }

        if let Some(block) = current_block {
            merged_blocks.push_back(block);
        }

        *free_blocks = merged_blocks;

        info!(
            "Memory defragmentation completed, {} free blocks remaining",
            free_blocks.len()
        );
        Ok(())
    }
}

/// Tensor cache for frequently used tensors
pub struct TensorCache {
    config: GpuAccelerationConfig,
    entity_tensors: Arc<Mutex<HashMap<String, CachedTensor>>>,
    attention_weights: Arc<Mutex<HashMap<String, CachedTensor>>>,
    intermediate_activations: Arc<Mutex<HashMap<String, CachedTensor>>>,
    cache_stats: Arc<Mutex<CacheStats>>,
}

/// Cached tensor with metadata
#[derive(Debug, Clone)]
struct CachedTensor {
    data: Array2<f32>, // In real implementation, this would be GPU tensor
    device_id: usize,
    last_accessed: Instant,
    access_count: usize,
    size_bytes: usize,
}

/// Cache statistics
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    pub hits: usize,
    pub misses: usize,
    pub evictions: usize,
    pub total_memory_usage: usize,
}

impl TensorCache {
    /// Create new tensor cache
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            entity_tensors: Arc::new(Mutex::new(HashMap::new())),
            attention_weights: Arc::new(Mutex::new(HashMap::new())),
            intermediate_activations: Arc::new(Mutex::new(HashMap::new())),
            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    /// Cache entity tensor
    pub fn cache_entity_tensor(&self, entity: &str, tensor: Array2<f32>, device_id: usize) {
        let mut cache = self.entity_tensors.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        let size_bytes = tensor.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: tensor,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        // Check if we need to evict old entries
        self.evict_if_needed(&mut stats);

        cache.insert(entity.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached entity tensor for {}", entity);
    }

    /// Get cached entity tensor
    pub fn get_entity_tensor(&self, entity: &str) -> Option<Array2<f32>> {
        let mut cache = self.entity_tensors.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        if let Some(cached) = cache.get_mut(entity) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for entity tensor {}", entity);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for entity tensor {}", entity);
            None
        }
    }

    /// Cache attention weights
    pub fn cache_attention_weights(&self, key: &str, weights: Array2<f32>, device_id: usize) {
        let mut cache = self.attention_weights.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        let size_bytes = weights.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: weights,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(key.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached attention weights for key {}", key);
    }

    /// Get cached attention weights
    pub fn get_attention_weights(&self, key: &str) -> Option<Array2<f32>> {
        let mut cache = self.attention_weights.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        if let Some(cached) = cache.get_mut(key) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for attention weights {}", key);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for attention weights {}", key);
            None
        }
    }

    /// Evict old entries if cache is too large
    fn evict_if_needed(&self, stats: &mut CacheStats) {
        let max_memory = self.config.cache_size_mb * 1024 * 1024; // Convert MB to bytes

        if stats.total_memory_usage > max_memory {
            // Simple LRU eviction (would be more sophisticated in real implementation)
            stats.evictions += 1;
            stats.total_memory_usage = max_memory / 2; // Simplified

            warn!("Tensor cache eviction triggered, freed memory");
        }
    }

    /// Get cache statistics
    pub fn get_stats(&self) -> CacheStats {
        (*self.cache_stats.lock().unwrap()).clone()
    }

    /// Clear all caches
    pub fn clear_all(&self) {
        self.entity_tensors.lock().unwrap().clear();
        self.attention_weights.lock().unwrap().clear();
        self.intermediate_activations.lock().unwrap().clear();

        let mut stats = self.cache_stats.lock().unwrap();
        stats.total_memory_usage = 0;

        info!("Cleared all tensor caches");
    }
}

/// Mixed precision training and inference
pub struct MixedPrecisionProcessor {
    config: GpuAccelerationConfig,
    fp16_enabled: bool,
    loss_scaling: f32,
    overflow_detection: bool,
}

impl MixedPrecisionProcessor {
    /// Create new mixed precision processor
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config: config.clone(),
            fp16_enabled: config.mixed_precision,
            loss_scaling: 65536.0, // Default loss scaling for FP16
            overflow_detection: true,
        }
    }

    /// Convert tensor to FP16 for computation
    pub fn to_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        if !self.fp16_enabled {
            return tensor.clone();
        }

        // Simulate FP16 conversion (real implementation would use GPU ops)
        tensor.mapv(|x| {
            // Clamp to FP16 range and simulate precision loss
            let clamped = x.clamp(-65504.0, 65504.0);
            (clamped * 1024.0).round() / 1024.0 // Simulate FP16 precision
        })
    }

    /// Apply loss scaling for gradient computation
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.fp16_enabled {
            loss * self.loss_scaling
        } else {
            loss
        }
    }

    /// Unscale gradients after loss scaling
    pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) -> bool {
        if !self.fp16_enabled {
            return true;
        }

        // Check for overflow
        if self.overflow_detection {
            let has_overflow = gradients.iter().any(|&x| !x.is_finite());
            if has_overflow {
                warn!("Gradient overflow detected in mixed precision training");
                return false;
            }
        }

        // Unscale gradients
        gradients.mapv_inplace(|x| x / self.loss_scaling);
        true
    }

    /// Adjust loss scaling based on overflow detection
    pub fn adjust_loss_scaling(&mut self, overflow_detected: bool) {
        if overflow_detected {
            self.loss_scaling = (self.loss_scaling / 2.0).max(1.0);
            info!("Reduced loss scaling to {}", self.loss_scaling);
        } else {
            // Gradually increase loss scaling if no overflow
            self.loss_scaling = (self.loss_scaling * 1.1).min(65536.0);
        }
    }
}

/// Multi-stream processor for parallel GPU operations
pub struct MultiStreamProcessor {
    config: GpuAccelerationConfig,
    pub stream_ids: Vec<usize>,
    current_stream: usize,
}

impl MultiStreamProcessor {
    /// Create new multi-stream processor
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let stream_ids = (0..config.num_streams).collect();

        Self {
            config,
            stream_ids,
            current_stream: 0,
        }
    }

    /// Get next available stream
    pub fn get_next_stream(&mut self) -> usize {
        let stream_id = self.stream_ids[self.current_stream];
        self.current_stream = (self.current_stream + 1) % self.stream_ids.len();
        stream_id
    }

    /// Process embeddings in parallel across multiple streams
    pub async fn process_batch_parallel(
        &mut self,
        entities: Vec<String>,
        process_fn: impl Fn(String, usize) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if entities.is_empty() {
            return Ok(Vec::new());
        }

        // Ceiling division so every entity lands in exactly one chunk
        let num_streams = self.config.num_streams.max(1);
        let chunk_size = (entities.len() + num_streams - 1) / num_streams;
        let mut tasks = Vec::new();

        for chunk in entities.chunks(chunk_size) {
            let stream_id = self.get_next_stream();
            let chunk_entities = chunk.to_vec();

            let task = tokio::spawn(async move {
                let mut results = Vec::new();
                for entity in chunk_entities {
                    let embedding = process_fn(entity, stream_id);
                    results.push(embedding);
                }
                results
            });

            tasks.push(task);
        }

        // Collect results from all streams
        let mut all_results = Vec::new();
        for task in tasks {
            let chunk_results = task.await?;
            all_results.extend(chunk_results);
        }

        Ok(all_results)
    }

    /// Synchronize all streams
    pub fn synchronize_all(&self) {
        // In real implementation, this would synchronize GPU streams
        debug!("Synchronized {} GPU streams", self.stream_ids.len());
    }
}

/// Main GPU acceleration manager
pub struct GpuAccelerationManager {
    config: GpuAccelerationConfig,
    memory_pool: GpuMemoryPool,
    tensor_cache: TensorCache,
    mixed_precision: MixedPrecisionProcessor,
    multi_stream: MultiStreamProcessor,
}

impl GpuAccelerationManager {
    /// Create new GPU acceleration manager
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_pool = GpuMemoryPool::new(config.clone());
        let tensor_cache = TensorCache::new(config.clone());
        let mixed_precision = MixedPrecisionProcessor::new(config.clone());
        let multi_stream = MultiStreamProcessor::new(config.clone());

        Self {
            config,
            memory_pool,
            tensor_cache,
            mixed_precision,
            multi_stream,
        }
    }

    /// Get memory pool
    pub fn memory_pool(&self) -> &GpuMemoryPool {
        &self.memory_pool
    }

    /// Get tensor cache
    pub fn tensor_cache(&self) -> &TensorCache {
        &self.tensor_cache
    }

    /// Get mixed precision processor
    pub fn mixed_precision(&mut self) -> &mut MixedPrecisionProcessor {
        &mut self.mixed_precision
    }

    /// Get multi-stream processor
    pub fn multi_stream(&mut self) -> &mut MultiStreamProcessor {
        &mut self.multi_stream
    }

    /// Optimize embedding computation with GPU acceleration
    pub async fn accelerated_embedding_generation(
        &mut self,
        entities: Vec<String>,
        base_compute_fn: impl Fn(&str) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if !self.config.enabled {
            // Fallback to CPU computation
            return Ok(entities.iter().map(|e| base_compute_fn(e)).collect());
        }

        // Use multi-stream processing for parallel computation
        let results = self
            .multi_stream
            .process_batch_parallel(entities, move |entity, stream_id| {
                // In real implementation, this would use the appropriate GPU stream
                debug!("Processing entity {} on stream {}", entity, stream_id);
                base_compute_fn(&entity)
            })
            .await?;

        self.multi_stream.synchronize_all();
        Ok(results)
    }

    /// Get comprehensive performance stats
    pub fn get_performance_stats(&self) -> GpuPerformanceStats {
        let memory_stats = self.memory_pool.get_stats();
        let cache_stats = self.tensor_cache.get_stats();

        GpuPerformanceStats {
            memory_allocations: memory_stats.total_allocations,
            memory_deallocations: memory_stats.total_deallocations,
            peak_memory_usage_mb: memory_stats.peak_memory_usage / (1024 * 1024),
            current_memory_usage_mb: memory_stats.current_memory_usage / (1024 * 1024),
            memory_pool_hits: memory_stats.cache_hits,
            memory_pool_misses: memory_stats.cache_misses,
            tensor_cache_hits: cache_stats.hits,
            tensor_cache_misses: cache_stats.misses,
            tensor_cache_evictions: cache_stats.evictions,
            tensor_cache_memory_mb: cache_stats.total_memory_usage / (1024 * 1024),
            loss_scaling_factor: self.mixed_precision.loss_scaling,
            num_active_streams: self.config.num_streams,
        }
    }
}

/// GPU performance statistics
#[derive(Debug, Serialize)]
pub struct GpuPerformanceStats {
    pub memory_allocations: usize,
    pub memory_deallocations: usize,
    pub peak_memory_usage_mb: usize,
    pub current_memory_usage_mb: usize,
    pub memory_pool_hits: usize,
    pub memory_pool_misses: usize,
    pub tensor_cache_hits: usize,
    pub tensor_cache_misses: usize,
    pub tensor_cache_evictions: usize,
    pub tensor_cache_memory_mb: usize,
    pub loss_scaling_factor: f32,
    pub num_active_streams: usize,
}

/// Memory defragmentation utilities
pub struct MemoryDefragmenter {
    config: GpuAccelerationConfig,
    defrag_threshold: f32,
    last_defrag: Instant,
    defrag_interval: Duration,
}

impl MemoryDefragmenter {
    /// Create new memory defragmenter
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            defrag_threshold: 0.7, // Defrag when 70% fragmented
            last_defrag: Instant::now(),
            defrag_interval: Duration::from_secs(300), // Defrag every 5 minutes max
        }
    }

    /// Check if defragmentation is needed
    pub fn should_defragment(&self, memory_pool: &GpuMemoryPool) -> bool {
        let stats = memory_pool.get_stats();
        let fragmentation_ratio = self.calculate_fragmentation_ratio(&stats);

        fragmentation_ratio > self.defrag_threshold
            && self.last_defrag.elapsed() > self.defrag_interval
    }

    /// Calculate memory fragmentation ratio
    fn calculate_fragmentation_ratio(&self, stats: &AllocationStats) -> f32 {
        if stats.current_memory_usage == 0 {
            return 0.0;
        }

        // Simplified fragmentation calculation
        // In real implementation, would analyze actual memory layout
        let theoretical_optimal = stats.current_memory_usage;
        let actual_allocated = stats.peak_memory_usage;

        if actual_allocated == 0 {
            0.0
        } else {
            1.0 - (theoretical_optimal as f32 / actual_allocated as f32)
        }
    }

    /// Perform memory defragmentation
    pub fn defragment(&mut self, memory_pool: &GpuMemoryPool) -> Result<DefragmentationResult> {
        info!("Starting GPU memory defragmentation");
        let start_time = Instant::now();

        // In real implementation, would:
        // 1. Identify fragmented memory regions
        // 2. Move active allocations to contiguous regions
        // 3. Release fragmented blocks back to the pool

        // Simulate defragmentation work
        std::thread::sleep(Duration::from_millis(100));

        let stats_before = memory_pool.get_stats();

        // Simulate memory compaction (in real implementation would actually move memory)
        // This would involve GPU kernel calls to move data

        let stats_after = memory_pool.get_stats();
        self.last_defrag = Instant::now();

        let result = DefragmentationResult {
            duration: start_time.elapsed(),
            memory_freed: stats_before
                .peak_memory_usage
                .saturating_sub(stats_after.current_memory_usage),
            fragmentation_before: self.calculate_fragmentation_ratio(&stats_before),
            fragmentation_after: self.calculate_fragmentation_ratio(&stats_after),
        };

        info!("Defragmentation completed: {:?}", result);
        Ok(result)
    }
}

/// Results of memory defragmentation operation
#[derive(Debug, Clone)]
pub struct DefragmentationResult {
    pub duration: Duration,
    pub memory_freed: usize,
    pub fragmentation_before: f32,
    pub fragmentation_after: f32,
}

/// Out-of-core processing for handling datasets larger than GPU memory
pub struct OutOfCoreProcessor {
    config: GpuAccelerationConfig,
    chunk_size: usize,
    overlap_size: usize,
    memory_limit: usize,
}

impl OutOfCoreProcessor {
    /// Create new out-of-core processor
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_limit = config.memory_pool_size_mb * 1024 * 1024; // Convert to bytes
        let chunk_size = memory_limit / 4; // Use 25% of available memory per chunk
        let overlap_size = chunk_size / 10; // 10% overlap between chunks

        Self {
            config,
            chunk_size,
            overlap_size,
            memory_limit,
        }
    }

    /// Process large embedding batch using out-of-core strategy
    pub async fn process_large_batch<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        // Calculate optimal chunk size based on data size and memory constraints
        // (guard against zero-sized types to avoid division by zero)
        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000); // Between 1 and 1000 items

        info!(
            "Processing {} items in chunks of {}",
            data.len(),
            chunk_size
        );

        let mut results = Vec::new();
        let mut processed_count = 0;

        for chunk in data.chunks(chunk_size) {
            // Process chunk on GPU
            let chunk_results = process_fn(chunk)?;
            results.extend(chunk_results);

            processed_count += chunk.len();

            if processed_count % (chunk_size * 10) == 0 {
                info!("Processed {}/{} items", processed_count, data.len());
            }

            // Yield control to allow other tasks to run
            tokio::task::yield_now().await;
        }

        Ok(results)
    }

    /// Process with overlapping windows for context-dependent embeddings
    pub async fn process_with_overlap<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);
        // Overlap is configured in bytes; convert it to items and keep it well
        // below the chunk size so each window always advances
        let overlap_items = (self.overlap_size / item_size).min(chunk_size / 2);

        let mut results = Vec::new();
        let mut start_idx = 0;

        while start_idx < data.len() {
            let end_idx = (start_idx + chunk_size).min(data.len());
            let chunk = &data[start_idx..end_idx];

            let chunk_results = process_fn(chunk)?;

            // Windows after the first repeat `overlap_items` results at their
            // start; skip those so each item is emitted exactly once
            let skip_count = if start_idx == 0 { 0 } else { overlap_items };
            results.extend(chunk_results.into_iter().skip(skip_count));

            start_idx += (chunk_size - overlap_items).max(1);
            tokio::task::yield_now().await;
        }

        Ok(results)
    }
}

/// Dynamic shape handling for variable-size inputs
pub struct DynamicShapeHandler {
    config: GpuAccelerationConfig,
    shape_cache: HashMap<Vec<usize>, ShapeInfo>,
    max_cached_shapes: usize,
}

/// Information about tensor shapes for optimization
#[derive(Debug, Clone)]
struct ShapeInfo {
    shape: Vec<usize>,
    memory_requirement: usize,
    optimal_batch_size: usize,
    last_used: Instant,
}

impl DynamicShapeHandler {
    /// Create new dynamic shape handler
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            shape_cache: HashMap::new(),
            max_cached_shapes: 100,
        }
    }

    /// Optimize tensor shapes for GPU processing
    pub fn optimize_shape(&mut self, shape: Vec<usize>) -> Vec<usize> {
        // Check cache first
        if let Some(shape_info) = self.shape_cache.get_mut(&shape) {
            shape_info.last_used = Instant::now();
            return shape_info.shape.clone();
        }

        // Calculate optimal shape based on GPU characteristics
        let optimized_shape = self.calculate_optimal_shape(&shape);

        // Cache the result
        self.cache_shape_info(shape.clone(), optimized_shape.clone());

        optimized_shape
    }

    /// Calculate optimal shape for GPU processing
    fn calculate_optimal_shape(&self, shape: &[usize]) -> Vec<usize> {
        let mut optimized = shape.to_vec();

        // Align dimensions to GPU warp/wavefront sizes (typically 32)
        const WARP_SIZE: usize = 32;

        for dim in &mut optimized {
            if *dim > 0 {
                // Round up to next multiple of warp size for better GPU utilization
                *dim = ((*dim + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
            }
        }

        optimized
    }

    /// Cache shape information
    fn cache_shape_info(&mut self, original_shape: Vec<usize>, optimized_shape: Vec<usize>) {
        // Evict old entries if cache is full
        if self.shape_cache.len() >= self.max_cached_shapes {
            self.evict_oldest_shape();
        }

        let memory_requirement = optimized_shape.iter().product::<usize>() * 4; // Assume f32
        let optimal_batch_size = self.calculate_optimal_batch_size(memory_requirement);

        let shape_info = ShapeInfo {
            shape: optimized_shape,
            memory_requirement,
            optimal_batch_size,
            last_used: Instant::now(),
        };

        self.shape_cache.insert(original_shape, shape_info);
    }

    /// Calculate optimal batch size for given memory requirement
    fn calculate_optimal_batch_size(&self, memory_per_item: usize) -> usize {
        if memory_per_item == 0 {
            return 1;
        }

        let available_memory = (self.config.memory_pool_size_mb * 1024 * 1024) / 2; // Use 50% of available memory
        let max_batch_size = available_memory / memory_per_item;

        // Clamp to reasonable range
        max_batch_size.clamp(1, 1024)
    }

    /// Evict oldest cached shape
    fn evict_oldest_shape(&mut self) {
        if let Some(oldest_key) = self
            .shape_cache
            .iter()
            .min_by_key(|(_, info)| info.last_used)
            .map(|(key, _)| key.clone())
        {
            self.shape_cache.remove(&oldest_key);
        }
    }

    /// Get optimal batch size for given shape
    pub fn get_optimal_batch_size(&self, shape: &[usize]) -> usize {
        self.shape_cache
            .get(shape)
            .map(|info| info.optimal_batch_size)
            .unwrap_or(1)
    }
}

/// Batch size optimizer for maximizing GPU utilization
pub struct BatchSizeOptimizer {
    config: GpuAccelerationConfig,
    performance_history: VecDeque<BatchPerformance>,
    max_history_size: usize,
    current_optimal_batch_size: usize,
}

/// Performance metrics for a batch processing operation
#[derive(Debug, Clone)]
struct BatchPerformance {
    batch_size: usize,
    processing_time: Duration,
    memory_usage: usize,
    throughput: f64, // items per second
    gpu_utilization: f64,
    timestamp: Instant,
}

impl BatchSizeOptimizer {
    /// Create new batch size optimizer
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            performance_history: VecDeque::new(),
            max_history_size: 50,
            current_optimal_batch_size: 32, // Conservative starting point
        }
    }

    /// Find optimal batch size through adaptive testing
    pub async fn find_optimal_batch_size<T>(
        &mut self,
        sample_data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<usize>
    where
        T: Clone + Send + Sync + 'static,
    {
        if sample_data.is_empty() {
            return Ok(1);
        }

        info!("Optimizing batch size for embedding generation");

        let test_sizes = vec![1, 8, 16, 32, 64, 128, 256, 512];
        let max_test_size = sample_data.len().min(512);

        let mut best_batch_size = 1;
        let mut best_throughput = 0.0;

        for &batch_size in &test_sizes {
            if batch_size > max_test_size {
                break;
            }

            // Test this batch size
            let performance = self
                .test_batch_size(
                    &sample_data[..batch_size.min(sample_data.len())],
                    batch_size,
                    process_fn,
                )
                .await?;

            info!(
                "Batch size {}: {:.2} items/sec, {:.1}ms processing time",
                batch_size,
                performance.throughput,
                performance.processing_time.as_secs_f64() * 1000.0
            );

            if performance.throughput > best_throughput {
                best_throughput = performance.throughput;
                best_batch_size = batch_size;
            }

            // Add to performance history
            self.performance_history.push_back(performance);
            if self.performance_history.len() > self.max_history_size {
                self.performance_history.pop_front();
            }

            // Small delay between tests
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        self.current_optimal_batch_size = best_batch_size;
        info!("Optimal batch size determined: {}", best_batch_size);

        Ok(best_batch_size)
    }

    /// Test performance of a specific batch size
    async fn test_batch_size<T>(
        &self,
        sample_data: &[T],
        batch_size: usize,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>>,
    ) -> Result<BatchPerformance>
    where
        T: Clone,
    {
        let start_time = Instant::now();
        let memory_before = self.estimate_memory_usage();

        // Process the batch
        let _results = process_fn(sample_data)?;

        let processing_time = start_time.elapsed();
        let memory_after = self.estimate_memory_usage();
        let memory_usage = memory_after.saturating_sub(memory_before);

        // Calculate throughput
        let throughput = if processing_time.as_secs_f64() > 0.0 {
            sample_data.len() as f64 / processing_time.as_secs_f64()
        } else {
            0.0
        };

        // Estimate GPU utilization (simplified)
        let gpu_utilization = self.estimate_gpu_utilization(batch_size, processing_time);

        Ok(BatchPerformance {
            batch_size,
            processing_time,
            memory_usage,
            throughput,
            gpu_utilization,
            timestamp: Instant::now(),
        })
    }

    /// Estimate current memory usage
    fn estimate_memory_usage(&self) -> usize {
        // In real implementation, would query actual GPU memory usage
        // For simulation, return a reasonable estimate
        (self.config.memory_pool_size_mb * 1024 * 1024) / 4 // Assume 25% usage
    }

    /// Estimate GPU utilization based on batch size and processing time
    fn estimate_gpu_utilization(&self, batch_size: usize, processing_time: Duration) -> f64 {
        // Simplified model: larger batches generally improve utilization up to a point
        let base_utilization = (batch_size as f64).log2() / 10.0; // Log scale
        let time_factor = if processing_time.as_millis() < 10 {
            0.5 // Very fast suggests underutilization
        } else if processing_time.as_millis() > 1000 {
            0.7 // Very slow might indicate bottlenecks
        } else {
            1.0
        };

        (base_utilization * time_factor).clamp(0.0, 1.0)
    }

    /// Get current optimal batch size
    pub fn get_optimal_batch_size(&self) -> usize {
        self.current_optimal_batch_size
    }

    /// Get performance statistics
    pub fn get_performance_stats(&self) -> BatchSizeOptimizerStats {
        let avg_throughput = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.throughput)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        let avg_gpu_utilization = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.gpu_utilization)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        BatchSizeOptimizerStats {
            current_optimal_batch_size: self.current_optimal_batch_size,
            avg_throughput,
            avg_gpu_utilization,
            total_tests_performed: self.performance_history.len(),
        }
    }
}

/// Statistics from batch size optimization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeOptimizerStats {
    pub current_optimal_batch_size: usize,
    pub avg_throughput: f64,
    pub avg_gpu_utilization: f64,
    pub total_tests_performed: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_acceleration_config_default() {
        let config = GpuAccelerationConfig::default();
        assert!(config.enabled);
        assert_eq!(config.device_ids, vec![0]);
        assert_eq!(config.memory_pool_size_mb, 2048);
        assert!(config.mixed_precision);
        assert!(config.tensor_caching);
    }

    #[test]
    fn test_memory_pool_allocation() {
        let config = GpuAccelerationConfig::default();
        let pool = GpuMemoryPool::new(config);

        let block_id = pool.allocate(1024, 0).unwrap();
        assert!(block_id > 0);

        pool.deallocate(block_id).unwrap();

        // Should reuse the block
        let block_id2 = pool.allocate(1024, 0).unwrap();
        assert_eq!(block_id, block_id2);
    }
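
    // Additional sketch of the pool statistics API; the exact counter values
    // assume the simulated allocator implemented above (sequential block IDs,
    // reused blocks counted as cache hits).
    #[test]
    fn test_memory_pool_stats_and_defragment() {
        let config = GpuAccelerationConfig::default();
        let pool = GpuMemoryPool::new(config);

        let a = pool.allocate(1024, 0).unwrap();
        let b = pool.allocate(2048, 0).unwrap();
        pool.deallocate(a).unwrap();
        pool.deallocate(b).unwrap();

        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.total_deallocations, 2);
        assert!(stats.peak_memory_usage >= 3072);

        // Defragmentation should merge the two free blocks for the same device
        pool.defragment().unwrap();

        // Deallocating an unknown block is reported as an error
        assert!(pool.deallocate(9999).is_err());
    }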

    #[test]
    fn test_tensor_cache() {
        let config = GpuAccelerationConfig::default();
        let cache = TensorCache::new(config);

        let tensor = Array2::zeros((10, 20));
        cache.cache_entity_tensor("test_entity", tensor.clone(), 0);

        let cached = cache.get_entity_tensor("test_entity").unwrap();
        assert_eq!(cached.shape(), tensor.shape());
    }
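
    // Sketch of the cache bookkeeping: a miss is recorded for unknown keys and
    // clear_all resets the tracked memory usage. The values follow the
    // simplified accounting in TensorCache above.
    #[test]
    fn test_tensor_cache_stats_and_clear() {
        let config = GpuAccelerationConfig::default();
        let cache = TensorCache::new(config);

        assert!(cache.get_entity_tensor("missing").is_none());

        cache.cache_attention_weights("layer0", Array2::zeros((4, 4)), 0);
        assert!(cache.get_attention_weights("layer0").is_some());

        let stats = cache.get_stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert!(stats.total_memory_usage > 0);

        cache.clear_all();
        assert_eq!(cache.get_stats().total_memory_usage, 0);
    }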

    #[test]
    fn test_mixed_precision() {
        let config = GpuAccelerationConfig::default();
        let processor = MixedPrecisionProcessor::new(config);

        // Use a value that will definitely cause precision loss in FP16 simulation
        let tensor = Array2::from_elem((2, 2), 1.0001);
        let fp16_tensor = processor.to_fp16(&tensor);

        if processor.fp16_enabled {
            // Should have some precision loss in FP16 simulation
            assert!(fp16_tensor[[0, 0]] != tensor[[0, 0]]);
        } else {
            // If FP16 is disabled, values should be identical
            assert_eq!(fp16_tensor[[0, 0]], tensor[[0, 0]]);
        }
    }
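
    // Loss-scaling round trip: scale the loss, then unscale gradients; a
    // non-finite gradient should be flagged and shrink the scaling factor.
    // This exercises only the simulated FP16 path defined above.
    #[test]
    fn test_mixed_precision_loss_scaling() {
        let config = GpuAccelerationConfig::default();
        let mut processor = MixedPrecisionProcessor::new(config);

        let scaled = processor.scale_loss(2.0);
        assert_eq!(scaled, 2.0 * 65536.0);

        // Finite gradients unscale cleanly
        let mut gradients = Array2::from_elem((2, 2), scaled);
        assert!(processor.unscale_gradients(&mut gradients));
        assert!((gradients[[0, 0]] - 2.0).abs() < 1e-6);

        // An overflowed gradient is detected and loss scaling is reduced
        let mut bad_gradients = Array2::from_elem((2, 2), f32::INFINITY);
        assert!(!processor.unscale_gradients(&mut bad_gradients));
        processor.adjust_loss_scaling(true);
        assert!(processor.loss_scaling < 65536.0);
    }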

    #[tokio::test]
    async fn test_multi_stream_processing() {
        let config = GpuAccelerationConfig::default();
        let mut processor = MultiStreamProcessor::new(config);

        let entities = vec!["entity1".to_string(), "entity2".to_string()];
        let process_fn = |entity: String, _stream_id: usize| -> Array1<f32> {
            Array1::from_vec(vec![entity.len() as f32])
        };

        let results = processor
            .process_batch_parallel(entities, process_fn)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }
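
    // The remaining tests are illustrative sketches for the higher-level
    // helpers; they rely only on the simulated behaviour implemented in this
    // module (no real GPU is touched).
    #[tokio::test]
    async fn test_accelerated_embedding_generation() {
        let config = GpuAccelerationConfig::default();
        let mut manager = GpuAccelerationManager::new(config);

        let entities = vec!["a".to_string(), "bb".to_string(), "ccc".to_string()];
        let compute = |entity: &str| -> Array1<f32> { Array1::from_vec(vec![entity.len() as f32]) };

        let embeddings = manager
            .accelerated_embedding_generation(entities, compute)
            .await
            .unwrap();
        assert_eq!(embeddings.len(), 3);

        let stats = manager.get_performance_stats();
        assert_eq!(stats.num_active_streams, 4);
    }

    #[tokio::test]
    async fn test_out_of_core_processing() {
        let config = GpuAccelerationConfig::default();
        let processor = OutOfCoreProcessor::new(config);

        let data: Vec<u64> = (0..2500).collect();
        let results = processor
            .process_large_batch(data, |chunk: &[u64]| -> Result<Vec<Array1<f32>>> {
                Ok(chunk
                    .iter()
                    .map(|&x| Array1::from_vec(vec![x as f32]))
                    .collect())
            })
            .await
            .unwrap();
        assert_eq!(results.len(), 2500);
    }

    #[test]
    fn test_dynamic_shape_handler_alignment() {
        let config = GpuAccelerationConfig::default();
        let mut handler = DynamicShapeHandler::new(config);

        // Dimensions are rounded up to the assumed warp size of 32
        let optimized = handler.optimize_shape(vec![100, 33]);
        assert_eq!(optimized, vec![128, 64]);
        assert!(handler.get_optimal_batch_size(&[100, 33]) >= 1);
    }

    #[test]
    fn test_memory_defragmenter_threshold() {
        let config = GpuAccelerationConfig::default();
        let defragmenter = MemoryDefragmenter::new(config.clone());
        let pool = GpuMemoryPool::new(config);

        // A fresh pool has no measurable fragmentation, so no defrag is needed
        assert!(!defragmenter.should_defragment(&pool));
    }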
}