use anyhow::{anyhow, Result};
use scirs2_core::gpu::{GpuBackend, GpuContext};
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};

/// Configuration for GPU-accelerated embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuAccelerationConfig {
    pub enabled: bool,
    pub device_ids: Vec<usize>,
    pub memory_pool_size_mb: usize,
    pub mixed_precision: bool,
    pub tensor_caching: bool,
    pub cache_size_mb: usize,
    pub kernel_fusion: bool,
    pub memory_mapping: bool,
    pub unified_memory: bool,
    pub multi_stream: bool,
    pub num_streams: usize,
    pub pipeline_parallelism: bool,
    pub pipeline_stages: usize,
}

impl Default for GpuAccelerationConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            device_ids: vec![0],
            memory_pool_size_mb: 2048,
            mixed_precision: true,
            tensor_caching: true,
            cache_size_mb: 512,
            kernel_fusion: true,
            memory_mapping: true,
            unified_memory: false,
            multi_stream: true,
            num_streams: 4,
            pipeline_parallelism: false,
            pipeline_stages: 2,
        }
    }
}

/// Pool of reusable GPU memory blocks with simple allocation statistics.
pub struct GpuMemoryPool {
    config: GpuAccelerationConfig,
    allocated_blocks: Arc<Mutex<HashMap<usize, MemoryBlock>>>,
    free_blocks: Arc<Mutex<VecDeque<MemoryBlock>>>,
    total_allocated: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

#[derive(Debug, Clone)]
struct MemoryBlock {
    device_id: usize,
    size_bytes: usize,
    /// Simulated device pointer, used here as the block handle.
    ptr: usize,
    allocated_at: Instant,
    last_used: Instant,
}

#[derive(Debug, Default, Clone)]
pub struct AllocationStats {
    pub total_allocations: usize,
    pub total_deallocations: usize,
    pub peak_memory_usage: usize,
    pub current_memory_usage: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
}

impl GpuMemoryPool {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            allocated_blocks: Arc::new(Mutex::new(HashMap::new())),
            free_blocks: Arc::new(Mutex::new(VecDeque::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::default())),
        }
    }

    pub fn allocate(&self, size_bytes: usize, device_id: usize) -> Result<usize> {
        // Lock order: free_blocks -> allocated_blocks -> allocation_stats.
        // `deallocate` uses the same order to avoid deadlocks.
        let mut free_blocks = self.free_blocks.lock().expect("lock poisoned");
        let mut allocated_blocks = self.allocated_blocks.lock().expect("lock poisoned");
        let mut stats = self.allocation_stats.lock().expect("lock poisoned");

        // Try to reuse a free block that is large enough and on the right device.
        let reuse_idx = free_blocks
            .iter()
            .position(|block| block.size_bytes >= size_bytes && block.device_id == device_id);

        if let Some(i) = reuse_idx {
            let mut reused_block = free_blocks
                .remove(i)
                .expect("index returned by position should be valid");
            let block_id = reused_block.ptr;
            reused_block.last_used = Instant::now();

            allocated_blocks.insert(block_id, reused_block);
            stats.cache_hits += 1;

            debug!(
                "Reused GPU memory block {} of size {}",
                block_id, size_bytes
            );
            return Ok(block_id);
        }

        stats.cache_misses += 1;
        stats.total_allocations += 1;

        let block_id = stats.total_allocations;
        let now = Instant::now();

        let block = MemoryBlock {
            device_id,
            size_bytes,
            ptr: block_id,
            allocated_at: now,
            last_used: now,
        };

        allocated_blocks.insert(block_id, block);

        let mut total_allocated = self.total_allocated.lock().expect("lock poisoned");
        *total_allocated += size_bytes;
        stats.current_memory_usage += size_bytes;

        if stats.current_memory_usage > stats.peak_memory_usage {
            stats.peak_memory_usage = stats.current_memory_usage;
        }

        info!(
            "Allocated new GPU memory block {} of size {} bytes",
            block_id, size_bytes
        );
        Ok(block_id)
    }

    pub fn deallocate(&self, block_id: usize) -> Result<()> {
        // Same lock order as `allocate`: free_blocks -> allocated_blocks -> stats.
        let mut free_blocks = self.free_blocks.lock().expect("lock poisoned");
        let mut allocated_blocks = self.allocated_blocks.lock().expect("lock poisoned");
        let mut stats = self.allocation_stats.lock().expect("lock poisoned");

        if let Some(block) = allocated_blocks.remove(&block_id) {
            stats.total_deallocations += 1;
            stats.current_memory_usage -= block.size_bytes;

            free_blocks.push_back(block);

            // Bound the free list so stale blocks do not accumulate forever.
            if free_blocks.len() > 100 {
                free_blocks.pop_front();
            }

            debug!("Deallocated GPU memory block {}", block_id);
            Ok(())
        } else {
            Err(anyhow!("Block {} not found for deallocation", block_id))
        }
    }

    pub fn get_stats(&self) -> AllocationStats {
        self.allocation_stats
            .lock()
            .expect("lock poisoned")
            .clone()
    }

    pub fn defragment(&self) -> Result<()> {
        let mut free_blocks = self.free_blocks.lock().expect("lock poisoned");

        let mut blocks: Vec<_> = free_blocks.drain(..).collect();
        blocks.sort_by_key(|b| (b.device_id, b.size_bytes));

        // Coalesce free blocks that live on the same device into larger blocks.
        let mut merged_blocks = VecDeque::new();
        let mut current_block: Option<MemoryBlock> = None;

        for block in blocks {
            if let Some(ref mut current) = current_block {
                if current.device_id == block.device_id {
                    current.size_bytes += block.size_bytes;
                } else {
                    merged_blocks.push_back(current.clone());
                    current_block = Some(block);
                }
            } else {
                current_block = Some(block);
            }
        }

        if let Some(block) = current_block {
            merged_blocks.push_back(block);
        }

        *free_blocks = merged_blocks;

        info!(
            "Memory defragmentation completed, {} free blocks remaining",
            free_blocks.len()
        );
        Ok(())
    }
}
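
// Illustrative usage sketch for the memory pool (not exercised elsewhere in
// this module; the 4 MiB size and device 0 are arbitrary example values):
//
//     let pool = GpuMemoryPool::new(GpuAccelerationConfig::default());
//     let block_id = pool.allocate(4 * 1024 * 1024, 0)?;
//     // ... launch kernels that use the block ...
//     pool.deallocate(block_id)?;
//     pool.defragment()?;
//     assert_eq!(pool.get_stats().total_allocations, 1);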

/// Cache of frequently used tensors kept resident on the GPU.
pub struct TensorCache {
    config: GpuAccelerationConfig,
    entity_tensors: Arc<Mutex<HashMap<String, CachedTensor>>>,
    attention_weights: Arc<Mutex<HashMap<String, CachedTensor>>>,
    intermediate_activations: Arc<Mutex<HashMap<String, CachedTensor>>>,
    cache_stats: Arc<Mutex<CacheStats>>,
}

#[derive(Debug, Clone)]
struct CachedTensor {
    data: Array2<f32>,
    device_id: usize,
    last_accessed: Instant,
    access_count: usize,
    size_bytes: usize,
}

#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    pub hits: usize,
    pub misses: usize,
    pub evictions: usize,
    pub total_memory_usage: usize,
}

impl TensorCache {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            entity_tensors: Arc::new(Mutex::new(HashMap::new())),
            attention_weights: Arc::new(Mutex::new(HashMap::new())),
            intermediate_activations: Arc::new(Mutex::new(HashMap::new())),
            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    pub fn cache_entity_tensor(&self, entity: &str, tensor: Array2<f32>, device_id: usize) {
        let mut cache = self.entity_tensors.lock().expect("lock poisoned");
        let mut stats = self.cache_stats.lock().expect("lock poisoned");

        let size_bytes = tensor.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: tensor,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(entity.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached entity tensor for {}", entity);
    }

    pub fn get_entity_tensor(&self, entity: &str) -> Option<Array2<f32>> {
        let mut cache = self.entity_tensors.lock().expect("lock poisoned");
        let mut stats = self.cache_stats.lock().expect("lock poisoned");

        if let Some(cached) = cache.get_mut(entity) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for entity tensor {}", entity);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for entity tensor {}", entity);
            None
        }
    }

    pub fn cache_attention_weights(&self, key: &str, weights: Array2<f32>, device_id: usize) {
        let mut cache = self.attention_weights.lock().expect("lock poisoned");
        let mut stats = self.cache_stats.lock().expect("lock poisoned");

        let size_bytes = weights.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: weights,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(key.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached attention weights for key {}", key);
    }

    pub fn get_attention_weights(&self, key: &str) -> Option<Array2<f32>> {
        let mut cache = self.attention_weights.lock().expect("lock poisoned");
        let mut stats = self.cache_stats.lock().expect("lock poisoned");

        if let Some(cached) = cache.get_mut(key) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for attention weights {}", key);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for attention weights {}", key);
            None
        }
    }

    fn evict_if_needed(&self, stats: &mut CacheStats) {
        let max_memory = self.config.cache_size_mb * 1024 * 1024;

        // Coarse eviction: only the accounting is reset here; individual
        // entries are not removed (a full implementation would evict the
        // least recently used tensors).
        if stats.total_memory_usage > max_memory {
            stats.evictions += 1;
            stats.total_memory_usage = max_memory / 2;
            warn!("Tensor cache eviction triggered, freed memory");
        }
    }

    pub fn get_stats(&self) -> CacheStats {
        self.cache_stats.lock().expect("lock poisoned").clone()
    }

    pub fn clear_all(&self) {
        self.entity_tensors.lock().expect("lock poisoned").clear();
        self.attention_weights
            .lock()
            .expect("lock poisoned")
            .clear();
        self.intermediate_activations
            .lock()
            .expect("lock poisoned")
            .clear();

        let mut stats = self.cache_stats.lock().expect("lock poisoned");
        stats.total_memory_usage = 0;

        info!("Cleared all tensor caches");
    }
}
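
// Illustrative usage sketch for the tensor cache (the key "alice" and the
// 1x128 shape are arbitrary placeholders):
//
//     let cache = TensorCache::new(GpuAccelerationConfig::default());
//     cache.cache_entity_tensor("alice", Array2::<f32>::zeros((1, 128)), 0);
//     if let Some(tensor) = cache.get_entity_tensor("alice") {
//         assert_eq!(tensor.shape(), &[1, 128]);
//     }
//     assert_eq!(cache.get_stats().hits, 1);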

/// Helper for FP16 mixed-precision training with dynamic loss scaling.
pub struct MixedPrecisionProcessor {
    config: GpuAccelerationConfig,
    fp16_enabled: bool,
    loss_scaling: f32,
    overflow_detection: bool,
}

impl MixedPrecisionProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config: config.clone(),
            fp16_enabled: config.mixed_precision,
            // 2^16, a common initial loss scale.
            loss_scaling: 65536.0,
            overflow_detection: true,
        }
    }

    pub fn to_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        if !self.fp16_enabled {
            return tensor.clone();
        }

        tensor.mapv(|x| {
            // Clamp to the f16 range, then coarsely quantize to emulate the
            // reduced precision while keeping f32 storage.
            let clamped = x.clamp(-65504.0, 65504.0);
            (clamped * 1024.0).round() / 1024.0
        })
    }

    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.fp16_enabled {
            loss * self.loss_scaling
        } else {
            loss
        }
    }

    pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) -> bool {
        if !self.fp16_enabled {
            return true;
        }

        if self.overflow_detection {
            let has_overflow = gradients.iter().any(|&x| !x.is_finite());
            if has_overflow {
                warn!("Gradient overflow detected in mixed precision training");
                return false;
            }
        }

        gradients.mapv_inplace(|x| x / self.loss_scaling);
        true
    }

    pub fn adjust_loss_scaling(&mut self, overflow_detected: bool) {
        if overflow_detected {
            self.loss_scaling = (self.loss_scaling / 2.0).max(1.0);
            info!("Reduced loss scaling to {}", self.loss_scaling);
        } else {
            self.loss_scaling = (self.loss_scaling * 1.1).min(65536.0);
        }
    }
}
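
// Illustrative training-step sketch for the mixed-precision helper (the loss
// value and gradient shape are placeholders):
//
//     let mut mp = MixedPrecisionProcessor::new(GpuAccelerationConfig::default());
//     let scaled_loss = mp.scale_loss(0.42);
//     let mut gradients = Array2::<f32>::from_elem((8, 8), scaled_loss);
//     let ok = mp.unscale_gradients(&mut gradients);
//     mp.adjust_loss_scaling(!ok); // shrink the scale only after an overflow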

/// Round-robin dispatcher that splits batches across multiple GPU streams.
pub struct MultiStreamProcessor {
    config: GpuAccelerationConfig,
    pub stream_ids: Vec<usize>,
    current_stream: usize,
}

impl MultiStreamProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        // Always keep at least one stream so round-robin indexing stays valid.
        let stream_ids = (0..config.num_streams.max(1)).collect();

        Self {
            config,
            stream_ids,
            current_stream: 0,
        }
    }

    pub fn get_next_stream(&mut self) -> usize {
        let stream_id = self.stream_ids[self.current_stream];
        self.current_stream = (self.current_stream + 1) % self.stream_ids.len();
        stream_id
    }

    pub async fn process_batch_parallel(
        &mut self,
        entities: Vec<String>,
        process_fn: impl Fn(String, usize) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if entities.is_empty() {
            return Ok(Vec::new());
        }

        let num_streams = self.config.num_streams.max(1);
        let chunk_size = (entities.len() + num_streams - 1) / num_streams;
        let mut tasks = Vec::new();

        for chunk in entities.chunks(chunk_size) {
            let stream_id = self.get_next_stream();
            let chunk_entities = chunk.to_vec();

            let task = tokio::spawn(async move {
                let mut results = Vec::new();
                for entity in chunk_entities {
                    let embedding = process_fn(entity, stream_id);
                    results.push(embedding);
                }
                results
            });

            tasks.push(task);
        }

        let mut all_results = Vec::new();
        for task in tasks {
            let chunk_results = task.await?;
            all_results.extend(chunk_results);
        }

        Ok(all_results)
    }

    pub fn synchronize_all(&self) {
        debug!("Synchronized {} GPU streams", self.stream_ids.len());
    }
}
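
// Illustrative usage sketch for stream-chunked processing (`entities` is a
// caller-provided Vec<String>; the closure is a stand-in for a real
// per-entity embedding kernel):
//
//     let mut streams = MultiStreamProcessor::new(GpuAccelerationConfig::default());
//     let embeddings = streams
//         .process_batch_parallel(entities, |entity, _stream_id| {
//             Array1::from_vec(vec![entity.len() as f32])
//         })
//         .await?;
//     streams.synchronize_all();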

pub struct GpuAccelerationManager {
    config: GpuAccelerationConfig,
    memory_pool: GpuMemoryPool,
    tensor_cache: TensorCache,
    mixed_precision: MixedPrecisionProcessor,
    multi_stream: MultiStreamProcessor,
}

impl GpuAccelerationManager {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_pool = GpuMemoryPool::new(config.clone());
        let tensor_cache = TensorCache::new(config.clone());
        let mixed_precision = MixedPrecisionProcessor::new(config.clone());
        let multi_stream = MultiStreamProcessor::new(config.clone());

        Self {
            config,
            memory_pool,
            tensor_cache,
            mixed_precision,
            multi_stream,
        }
    }

    pub fn memory_pool(&self) -> &GpuMemoryPool {
        &self.memory_pool
    }

    pub fn tensor_cache(&self) -> &TensorCache {
        &self.tensor_cache
    }

    pub fn mixed_precision(&mut self) -> &mut MixedPrecisionProcessor {
        &mut self.mixed_precision
    }

    pub fn multi_stream(&mut self) -> &mut MultiStreamProcessor {
        &mut self.multi_stream
    }

    pub async fn accelerated_embedding_generation(
        &mut self,
        entities: Vec<String>,
        base_compute_fn: impl Fn(&str) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if !self.config.enabled {
            return Ok(entities.iter().map(|e| base_compute_fn(e)).collect());
        }

        let results = self
            .multi_stream
            .process_batch_parallel(entities, move |entity, stream_id| {
                debug!("Processing entity {} on stream {}", entity, stream_id);
                base_compute_fn(&entity)
            })
            .await?;

        self.multi_stream.synchronize_all();
        Ok(results)
    }

    pub fn get_performance_stats(&self) -> GpuPerformanceStats {
        let memory_stats = self.memory_pool.get_stats();
        let cache_stats = self.tensor_cache.get_stats();

        GpuPerformanceStats {
            memory_allocations: memory_stats.total_allocations,
            memory_deallocations: memory_stats.total_deallocations,
            peak_memory_usage_mb: memory_stats.peak_memory_usage / (1024 * 1024),
            current_memory_usage_mb: memory_stats.current_memory_usage / (1024 * 1024),
            memory_pool_hits: memory_stats.cache_hits,
            memory_pool_misses: memory_stats.cache_misses,
            tensor_cache_hits: cache_stats.hits,
            tensor_cache_misses: cache_stats.misses,
            tensor_cache_evictions: cache_stats.evictions,
            tensor_cache_memory_mb: cache_stats.total_memory_usage / (1024 * 1024),
            loss_scaling_factor: self.mixed_precision.loss_scaling,
            num_active_streams: self.config.num_streams,
        }
    }
}
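
// Illustrative end-to-end sketch for the acceleration manager (`entities` is a
// caller-provided Vec<String>; the closure stands in for the real embedding
// function):
//
//     let mut manager = GpuAccelerationManager::new(GpuAccelerationConfig::default());
//     let embeddings = manager
//         .accelerated_embedding_generation(entities, |entity| {
//             Array1::from_vec(vec![entity.len() as f32])
//         })
//         .await?;
//     let stats = manager.get_performance_stats();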

#[derive(Debug, Serialize)]
pub struct GpuPerformanceStats {
    pub memory_allocations: usize,
    pub memory_deallocations: usize,
    pub peak_memory_usage_mb: usize,
    pub current_memory_usage_mb: usize,
    pub memory_pool_hits: usize,
    pub memory_pool_misses: usize,
    pub tensor_cache_hits: usize,
    pub tensor_cache_misses: usize,
    pub tensor_cache_evictions: usize,
    pub tensor_cache_memory_mb: usize,
    pub loss_scaling_factor: f32,
    pub num_active_streams: usize,
}

/// Periodically coalesces the memory pool's free list to reduce fragmentation.
pub struct MemoryDefragmenter {
    config: GpuAccelerationConfig,
    defrag_threshold: f32,
    last_defrag: Instant,
    defrag_interval: Duration,
}

impl MemoryDefragmenter {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            // Defragment once the estimated fragmentation exceeds 70%...
            defrag_threshold: 0.7,
            last_defrag: Instant::now(),
            // ...but no more often than every five minutes.
            defrag_interval: Duration::from_secs(300),
        }
    }

    pub fn should_defragment(&self, memory_pool: &GpuMemoryPool) -> bool {
        let stats = memory_pool.get_stats();
        let fragmentation_ratio = self.calculate_fragmentation_ratio(&stats);

        fragmentation_ratio > self.defrag_threshold
            && self.last_defrag.elapsed() > self.defrag_interval
    }

    fn calculate_fragmentation_ratio(&self, stats: &AllocationStats) -> f32 {
        if stats.current_memory_usage == 0 {
            return 0.0;
        }

        // Treat the gap between peak and current usage as fragmentation.
        let theoretical_optimal = stats.current_memory_usage;
        let actual_allocated = stats.peak_memory_usage;

        if actual_allocated == 0 {
            0.0
        } else {
            1.0 - (theoretical_optimal as f32 / actual_allocated as f32)
        }
    }

    pub fn defragment(&mut self, memory_pool: &GpuMemoryPool) -> Result<DefragmentationResult> {
        info!("Starting GPU memory defragmentation");
        let start_time = Instant::now();

        let stats_before = memory_pool.get_stats();

        // Coalesce the pool's free list.
        memory_pool.defragment()?;

        let stats_after = memory_pool.get_stats();
        self.last_defrag = Instant::now();

        let result = DefragmentationResult {
            duration: start_time.elapsed(),
            memory_freed: stats_before
                .peak_memory_usage
                .saturating_sub(stats_after.current_memory_usage),
            fragmentation_before: self.calculate_fragmentation_ratio(&stats_before),
            fragmentation_after: self.calculate_fragmentation_ratio(&stats_after),
        };

        info!("Defragmentation completed: {:?}", result);
        Ok(result)
    }
}
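
// Illustrative maintenance-loop sketch (`pool` is a caller-owned GpuMemoryPool;
// a periodic task would typically own both the pool and the defragmenter):
//
//     let mut defragmenter = MemoryDefragmenter::new(GpuAccelerationConfig::default());
//     if defragmenter.should_defragment(&pool) {
//         let report = defragmenter.defragment(&pool)?;
//         info!("freed {} bytes in {:?}", report.memory_freed, report.duration);
//     }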

#[derive(Debug, Clone)]
pub struct DefragmentationResult {
    pub duration: Duration,
    pub memory_freed: usize,
    pub fragmentation_before: f32,
    pub fragmentation_after: f32,
}

/// Processes datasets larger than GPU memory by streaming fixed-size chunks.
pub struct OutOfCoreProcessor {
    config: GpuAccelerationConfig,
    chunk_size: usize,
    overlap_size: usize,
    memory_limit: usize,
}

impl OutOfCoreProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_limit = config.memory_pool_size_mb * 1024 * 1024;
        // Keep chunks at a quarter of the pool, with a 10% overlap between chunks.
        let chunk_size = memory_limit / 4;
        let overlap_size = chunk_size / 10;

        Self {
            config,
            chunk_size,
            overlap_size,
            memory_limit,
        }
    }

    pub async fn process_large_batch<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        // Guard against zero-sized types to avoid division by zero.
        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);

        info!(
            "Processing {} items in chunks of {}",
            data.len(),
            chunk_size
        );

        let mut results = Vec::new();
        let mut processed_count = 0;

        for chunk in data.chunks(chunk_size) {
            let chunk_results = process_fn(chunk)?;
            results.extend(chunk_results);

            processed_count += chunk.len();

            if processed_count % (chunk_size * 10) == 0 {
                info!("Processed {}/{} items", processed_count, data.len());
            }

            tokio::task::yield_now().await;
        }

        Ok(results)
    }

    pub async fn process_with_overlap<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);
        let overlap_items = self.overlap_size / item_size;

        let mut results = Vec::new();
        let mut start_idx = 0;

        while start_idx < data.len() {
            let end_idx = (start_idx + chunk_size).min(data.len());
            let chunk = &data[start_idx..end_idx];

            let chunk_results = process_fn(chunk)?;

            // Every chunk after the first starts with items already covered by
            // the previous chunk, so skip those to avoid duplicate results.
            let skip_count = if start_idx == 0 {
                0
            } else {
                overlap_items.min(chunk_results.len())
            };
            results.extend(chunk_results.into_iter().skip(skip_count));

            // Advance by at least one item so the loop always terminates.
            start_idx += chunk_size.saturating_sub(overlap_items).max(1);
            tokio::task::yield_now().await;
        }

        Ok(results)
    }
}
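
// Illustrative sketch for out-of-core batching (`large_input` is a caller-provided
// Vec; the closure stands in for a kernel launch over one resident chunk):
//
//     let processor = OutOfCoreProcessor::new(GpuAccelerationConfig::default());
//     let embeddings = processor
//         .process_large_batch(large_input, |chunk| {
//             Ok(chunk.iter().map(|_| Array1::<f32>::zeros(128)).collect())
//         })
//         .await?;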

/// Caches shape metadata and pads tensor shapes to GPU-friendly sizes.
pub struct DynamicShapeHandler {
    config: GpuAccelerationConfig,
    shape_cache: HashMap<Vec<usize>, ShapeInfo>,
    max_cached_shapes: usize,
}

#[derive(Debug, Clone)]
struct ShapeInfo {
    shape: Vec<usize>,
    memory_requirement: usize,
    optimal_batch_size: usize,
    last_used: Instant,
}

impl DynamicShapeHandler {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            shape_cache: HashMap::new(),
            max_cached_shapes: 100,
        }
    }

    pub fn optimize_shape(&mut self, shape: Vec<usize>) -> Vec<usize> {
        if let Some(shape_info) = self.shape_cache.get_mut(&shape) {
            shape_info.last_used = Instant::now();
            return shape_info.shape.clone();
        }

        let optimized_shape = self.calculate_optimal_shape(&shape);

        self.cache_shape_info(shape.clone(), optimized_shape.clone());

        optimized_shape
    }

    fn calculate_optimal_shape(&self, shape: &[usize]) -> Vec<usize> {
        let mut optimized = shape.to_vec();

        const WARP_SIZE: usize = 32;

        // Round each dimension up to a multiple of the warp size.
        for dim in &mut optimized {
            if *dim > 0 {
                *dim = ((*dim + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
            }
        }

        optimized
    }

    fn cache_shape_info(&mut self, original_shape: Vec<usize>, optimized_shape: Vec<usize>) {
        if self.shape_cache.len() >= self.max_cached_shapes {
            self.evict_oldest_shape();
        }

        // Assume f32 elements (4 bytes per value) when estimating memory needs.
        let memory_requirement = optimized_shape.iter().product::<usize>() * 4;
        let optimal_batch_size = self.calculate_optimal_batch_size(memory_requirement);

        let shape_info = ShapeInfo {
            shape: optimized_shape,
            memory_requirement,
            optimal_batch_size,
            last_used: Instant::now(),
        };

        self.shape_cache.insert(original_shape, shape_info);
    }

    fn calculate_optimal_batch_size(&self, memory_per_item: usize) -> usize {
        if memory_per_item == 0 {
            return 1;
        }

        // Allow a single batch to use at most half of the memory pool.
        let available_memory = (self.config.memory_pool_size_mb * 1024 * 1024) / 2;
        let max_batch_size = available_memory / memory_per_item;

        max_batch_size.clamp(1, 1024)
    }

    fn evict_oldest_shape(&mut self) {
        if let Some(oldest_key) = self
            .shape_cache
            .iter()
            .min_by_key(|(_, info)| info.last_used)
            .map(|(key, _)| key.clone())
        {
            self.shape_cache.remove(&oldest_key);
        }
    }

    pub fn get_optimal_batch_size(&self, shape: &[usize]) -> usize {
        self.shape_cache
            .get(shape)
            .map(|info| info.optimal_batch_size)
            .unwrap_or(1)
    }
}
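
// Illustrative usage sketch for shape optimization (the example shape is
// arbitrary; 300 rounds up to 320, the next multiple of the warp size):
//
//     let mut shapes = DynamicShapeHandler::new(GpuAccelerationConfig::default());
//     let padded = shapes.optimize_shape(vec![300, 128]);
//     assert_eq!(padded, vec![320, 128]);
//     let batch_size = shapes.get_optimal_batch_size(&[300, 128]);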

/// Empirically tunes the batch size by benchmarking candidate sizes.
pub struct BatchSizeOptimizer {
    config: GpuAccelerationConfig,
    performance_history: VecDeque<BatchPerformance>,
    max_history_size: usize,
    current_optimal_batch_size: usize,
}

#[derive(Debug, Clone)]
struct BatchPerformance {
    batch_size: usize,
    processing_time: Duration,
    memory_usage: usize,
    /// Items per second.
    throughput: f64,
    gpu_utilization: f64,
    timestamp: Instant,
}

impl BatchSizeOptimizer {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            performance_history: VecDeque::new(),
            max_history_size: 50,
            // Reasonable starting point before any measurements are taken.
            current_optimal_batch_size: 32,
        }
    }

    pub async fn find_optimal_batch_size<T>(
        &mut self,
        sample_data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<usize>
    where
        T: Clone + Send + Sync + 'static,
    {
        if sample_data.is_empty() {
            return Ok(1);
        }

        info!("Optimizing batch size for embedding generation");

        let test_sizes = vec![1, 8, 16, 32, 64, 128, 256, 512];
        let max_test_size = sample_data.len().min(512);

        let mut best_batch_size = 1;
        let mut best_throughput = 0.0;

        for &batch_size in &test_sizes {
            if batch_size > max_test_size {
                break;
            }

            let performance = self
                .test_batch_size(
                    &sample_data[..batch_size.min(sample_data.len())],
                    batch_size,
                    process_fn,
                )
                .await?;

            info!(
                "Batch size {}: {:.2} items/sec, {:.1}ms processing time",
                batch_size,
                performance.throughput,
                performance.processing_time.as_secs_f64() * 1000.0
            );

            if performance.throughput > best_throughput {
                best_throughput = performance.throughput;
                best_batch_size = batch_size;
            }

            self.performance_history.push_back(performance);
            if self.performance_history.len() > self.max_history_size {
                self.performance_history.pop_front();
            }

            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        self.current_optimal_batch_size = best_batch_size;
        info!("Optimal batch size determined: {}", best_batch_size);

        Ok(best_batch_size)
    }

    async fn test_batch_size<T>(
        &self,
        sample_data: &[T],
        batch_size: usize,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>>,
    ) -> Result<BatchPerformance>
    where
        T: Clone,
    {
        let start_time = Instant::now();
        let memory_before = self.estimate_memory_usage();

        let _results = process_fn(sample_data)?;

        let processing_time = start_time.elapsed();
        let memory_after = self.estimate_memory_usage();
        let memory_usage = memory_after.saturating_sub(memory_before);

        let throughput = if processing_time.as_secs_f64() > 0.0 {
            sample_data.len() as f64 / processing_time.as_secs_f64()
        } else {
            0.0
        };

        let gpu_utilization = self.estimate_gpu_utilization(batch_size, processing_time);

        Ok(BatchPerformance {
            batch_size,
            processing_time,
            memory_usage,
            throughput,
            gpu_utilization,
            timestamp: Instant::now(),
        })
    }

    fn estimate_memory_usage(&self) -> usize {
        // Placeholder estimate: assume a quarter of the pool is in use.
        (self.config.memory_pool_size_mb * 1024 * 1024) / 4
    }

    fn estimate_gpu_utilization(&self, batch_size: usize, processing_time: Duration) -> f64 {
        // Heuristic: larger batches utilize the device better, discounted when
        // processing is either very fast or very slow.
        let base_utilization = (batch_size as f64).log2() / 10.0;
        let time_factor = if processing_time.as_millis() < 10 {
            0.5
        } else if processing_time.as_millis() > 1000 {
            0.7
        } else {
            1.0
        };

        (base_utilization * time_factor).clamp(0.0, 1.0)
    }

    pub fn get_optimal_batch_size(&self) -> usize {
        self.current_optimal_batch_size
    }

    pub fn get_performance_stats(&self) -> BatchSizeOptimizerStats {
        let avg_throughput = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.throughput)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        let avg_gpu_utilization = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.gpu_utilization)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        BatchSizeOptimizerStats {
            current_optimal_batch_size: self.current_optimal_batch_size,
            avg_throughput,
            avg_gpu_utilization,
            total_tests_performed: self.performance_history.len(),
        }
    }
}
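
// Illustrative tuning sketch (`sample_entities` is a caller-provided sample of
// the real workload; the closure stands in for the real embedding kernel):
//
//     let mut optimizer = BatchSizeOptimizer::new(GpuAccelerationConfig::default());
//     let best = optimizer
//         .find_optimal_batch_size(sample_entities, |chunk| {
//             Ok(chunk.iter().map(|_| Array1::<f32>::zeros(128)).collect())
//         })
//         .await?;
//     assert_eq!(best, optimizer.get_optimal_batch_size());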

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeOptimizerStats {
    pub current_optimal_batch_size: usize,
    pub avg_throughput: f64,
    pub avg_gpu_utilization: f64,
    pub total_tests_performed: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_acceleration_config_default() {
        let config = GpuAccelerationConfig::default();
        assert!(config.enabled);
        assert_eq!(config.device_ids, vec![0]);
        assert_eq!(config.memory_pool_size_mb, 2048);
        assert!(config.mixed_precision);
        assert!(config.tensor_caching);
    }

    #[test]
    fn test_memory_pool_allocation() {
        let config = GpuAccelerationConfig::default();
        let pool = GpuMemoryPool::new(config);

        let block_id = pool.allocate(1024, 0).unwrap();
        assert!(block_id > 0);

        pool.deallocate(block_id).unwrap();

        let block_id2 = pool.allocate(1024, 0).unwrap();
        assert_eq!(block_id, block_id2);
    }

    #[test]
    fn test_tensor_cache() {
        let config = GpuAccelerationConfig::default();
        let cache = TensorCache::new(config);

        let tensor = Array2::zeros((10, 20));
        cache.cache_entity_tensor("test_entity", tensor.clone(), 0);

        let cached = cache.get_entity_tensor("test_entity").unwrap();
        assert_eq!(cached.shape(), tensor.shape());
    }

    #[test]
    fn test_mixed_precision() {
        let config = GpuAccelerationConfig::default();
        let processor = MixedPrecisionProcessor::new(config);

        let tensor = Array2::from_elem((2, 2), 1.0001);
        let fp16_tensor = processor.to_fp16(&tensor);

        if processor.fp16_enabled {
            assert!(fp16_tensor[[0, 0]] != tensor[[0, 0]]);
        } else {
            assert_eq!(fp16_tensor[[0, 0]], tensor[[0, 0]]);
        }
    }

    #[tokio::test]
    async fn test_multi_stream_processing() {
        let config = GpuAccelerationConfig::default();
        let mut processor = MultiStreamProcessor::new(config);

        let entities = vec!["entity1".to_string(), "entity2".to_string()];
        let process_fn = |entity: String, _stream_id: usize| -> Array1<f32> {
            Array1::from_vec(vec![entity.len() as f32])
        };

        let results = processor
            .process_batch_parallel(entities, process_fn)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_scirs2_gpu_accelerator() {
        let config = GpuAccelerationConfig::default();

        match SciRS2GpuAccelerator::new(config) {
            Ok(accelerator) => {
                assert!(accelerator.num_devices() > 0);
            }
            Err(_) => {
                println!("Skipping GPU test: no hardware available");
            }
        }
    }

    #[test]
    fn test_tensor_core_operations() {
        let config = GpuAccelerationConfig::default();

        if let Ok(accelerator) = SciRS2GpuAccelerator::new(config) {
            let _matrix_a = Array2::<f32>::ones((256, 512));
            let _matrix_b = Array2::<f32>::ones((512, 256));

            let stats = accelerator.get_stats();
            assert_eq!(stats.total_operations, 0);
        } else {
            println!("Skipping tensor core test: no GPU hardware available");
        }
    }
}

/// GPU accelerator backed by scirs2-core GPU contexts.
pub struct SciRS2GpuAccelerator {
    config: GpuAccelerationConfig,
    contexts: Vec<GpuContext>,
    operations: Arc<AtomicUsize>,
}

impl SciRS2GpuAccelerator {
    pub fn new(config: GpuAccelerationConfig) -> Result<Self> {
        let mut contexts = Vec::new();

        // Try to create one context per configured device.
        for _device_id in &config.device_ids {
            match GpuContext::new(GpuBackend::Cuda) {
                Ok(ctx) => {
                    info!("Initialized GPU context");
                    contexts.push(ctx);
                }
                Err(e) => {
                    warn!("Failed to initialize GPU device: {}", e);
                }
            }
        }

        if contexts.is_empty() {
            return Err(anyhow!("No GPU devices available for acceleration"));
        }

        Ok(Self {
            config,
            contexts,
            operations: Arc::new(AtomicUsize::new(0)),
        })
    }

    pub fn num_devices(&self) -> usize {
        self.contexts.len()
    }

    pub fn tensor_core_gemm(
        &self,
        a: &Array2<f32>,
        b: &Array2<f32>,
        use_mixed_precision: bool,
    ) -> Result<Array2<f32>> {
        // Both paths currently fall back to the same f32 GEMM; a dedicated
        // tensor-core kernel would have to be dispatched by the GPU backend.
        let result = if use_mixed_precision && self.config.mixed_precision {
            a.dot(b)
        } else {
            a.dot(b)
        };

        self.operations.fetch_add(1, Ordering::Relaxed);

        Ok(result)
    }

    pub fn batch_embed(
        &self,
        inputs: &[Array1<f32>],
        embedding_matrix: &Array2<f32>,
    ) -> Result<Vec<Array1<f32>>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        let batch_size = inputs.len();
        let mut results = Vec::with_capacity(batch_size);

        let stream_batch_size = if self.config.multi_stream {
            let num_streams = self.config.num_streams.max(1);
            (batch_size + num_streams - 1) / num_streams
        } else {
            batch_size
        };

        for chunk in inputs.chunks(stream_batch_size) {
            for input in chunk {
                let embedding = embedding_matrix.dot(input);
                results.push(embedding);
            }
        }

        self.operations.fetch_add(batch_size, Ordering::Relaxed);

        Ok(results)
    }

    pub fn simd_similarity(
        &self,
        query: &Array1<f32>,
        candidates: &[Array1<f32>],
    ) -> Result<Vec<f32>> {
        // Dot-product similarity; vectorization is left to the ndarray/BLAS backend.
        let similarities: Vec<f32> = candidates
            .iter()
            .map(|candidate| query.dot(candidate))
            .collect();

        self.operations
            .fetch_add(candidates.len(), Ordering::Relaxed);

        Ok(similarities)
    }

    pub fn get_stats(&self) -> AcceleratorStats {
        AcceleratorStats {
            total_operations: self.operations.load(Ordering::Relaxed),
            num_devices: self.contexts.len(),
            profiler_report: "Stats available".to_string(),
        }
    }

    pub fn clear_stats(&self) {
        self.operations.store(0, Ordering::Relaxed);
    }
}
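
// Illustrative usage sketch for the scirs2-backed accelerator (construction
// fails on hosts without a usable GPU backend, so callers should keep a CPU
// fallback; the matrix sizes are arbitrary):
//
//     if let Ok(accelerator) = SciRS2GpuAccelerator::new(GpuAccelerationConfig::default()) {
//         let a = Array2::<f32>::ones((128, 256));
//         let b = Array2::<f32>::ones((256, 64));
//         let c = accelerator.tensor_core_gemm(&a, &b, true)?;
//         assert_eq!(c.shape(), &[128, 64]);
//     }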

#[derive(Debug, Clone)]
pub struct AcceleratorStats {
    pub total_operations: usize,
    pub num_devices: usize,
    pub profiler_report: String,
}