//! GPU acceleration utilities for embedding workloads: memory pooling, tensor
//! caching, mixed-precision helpers, multi-stream batching, and batch-size tuning.

use anyhow::{anyhow, Result};
use scirs2_core::gpu::{GpuBackend, GpuContext};
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};

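/// Configuration knobs for GPU-accelerated embedding computation.
///
/// Minimal usage sketch (field values are illustrative only; the struct supports
/// plain struct-update syntax over `Default`):
///
/// ```ignore
/// let config = GpuAccelerationConfig {
///     memory_pool_size_mb: 4096,
///     num_streams: 8,
///     ..Default::default()
/// };
/// assert!(config.enabled);
/// ```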
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuAccelerationConfig {
    /// Enable GPU acceleration.
    pub enabled: bool,
    /// GPU device IDs to use.
    pub device_ids: Vec<usize>,
    /// Size of the per-device memory pool in megabytes.
    pub memory_pool_size_mb: usize,
    /// Use mixed-precision (FP16) computation.
    pub mixed_precision: bool,
    /// Enable tensor caching.
    pub tensor_caching: bool,
    /// Tensor cache capacity in megabytes.
    pub cache_size_mb: usize,
    /// Enable kernel fusion.
    pub kernel_fusion: bool,
    /// Enable memory mapping.
    pub memory_mapping: bool,
    /// Use unified (shared host/device) memory.
    pub unified_memory: bool,
    /// Enable multi-stream execution.
    pub multi_stream: bool,
    /// Number of GPU streams.
    pub num_streams: usize,
    /// Enable pipeline parallelism.
    pub pipeline_parallelism: bool,
    /// Number of pipeline stages.
    pub pipeline_stages: usize,
}

impl Default for GpuAccelerationConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            device_ids: vec![0],
            memory_pool_size_mb: 2048,
            mixed_precision: true,
            tensor_caching: true,
            cache_size_mb: 512,
            kernel_fusion: true,
            memory_mapping: true,
            unified_memory: false,
            multi_stream: true,
            num_streams: 4,
            pipeline_parallelism: false,
            pipeline_stages: 2,
        }
    }
}

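/// Pooled allocator that tracks and reuses GPU memory blocks.
///
/// Minimal usage sketch (block handles are bookkeeping IDs in this module, so no
/// physical GPU is required to exercise the pool):
///
/// ```ignore
/// let pool = GpuMemoryPool::new(GpuAccelerationConfig::default());
/// let block = pool.allocate(4 * 1024 * 1024, 0).unwrap(); // 4 MiB on device 0
/// pool.deallocate(block).unwrap();
/// // A follow-up allocation of the same size on the same device reuses the block.
/// assert_eq!(pool.allocate(4 * 1024 * 1024, 0).unwrap(), block);
/// ```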
pub struct GpuMemoryPool {
    config: GpuAccelerationConfig,
    allocated_blocks: Arc<Mutex<HashMap<usize, MemoryBlock>>>,
    free_blocks: Arc<Mutex<VecDeque<MemoryBlock>>>,
    total_allocated: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

#[derive(Debug, Clone)]
struct MemoryBlock {
    device_id: usize,
    size_bytes: usize,
    /// Block handle; stands in for a device pointer in this bookkeeping model.
    ptr: usize,
    allocated_at: Instant,
    last_used: Instant,
}

#[derive(Debug, Default, Clone)]
pub struct AllocationStats {
    pub total_allocations: usize,
    pub total_deallocations: usize,
    pub peak_memory_usage: usize,
    pub current_memory_usage: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
}

impl GpuMemoryPool {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            allocated_blocks: Arc::new(Mutex::new(HashMap::new())),
            free_blocks: Arc::new(Mutex::new(VecDeque::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::default())),
        }
    }

    pub fn allocate(&self, size_bytes: usize, device_id: usize) -> Result<usize> {
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut allocated_blocks = self.allocated_blocks.lock().unwrap();
        let mut stats = self.allocation_stats.lock().unwrap();

        // Try to reuse a free block on the same device that is large enough.
        if let Some(i) = free_blocks
            .iter()
            .position(|block| block.size_bytes >= size_bytes && block.device_id == device_id)
        {
            let mut reused_block = free_blocks.remove(i).unwrap();
            let block_id = reused_block.ptr;
            reused_block.last_used = Instant::now();

            allocated_blocks.insert(block_id, reused_block);
            stats.cache_hits += 1;

            debug!(
                "Reused GPU memory block {} of size {}",
                block_id, size_bytes
            );
            return Ok(block_id);
        }

        stats.cache_misses += 1;
        stats.total_allocations += 1;

        // Use the monotonically increasing allocation count as the block handle.
        let block_id = stats.total_allocations;
        let now = Instant::now();

        let block = MemoryBlock {
            device_id,
            size_bytes,
            ptr: block_id,
            allocated_at: now,
            last_used: now,
        };

        allocated_blocks.insert(block_id, block);

        let mut total_allocated = self.total_allocated.lock().unwrap();
        *total_allocated += size_bytes;
        stats.current_memory_usage += size_bytes;

        if stats.current_memory_usage > stats.peak_memory_usage {
            stats.peak_memory_usage = stats.current_memory_usage;
        }

        info!(
            "Allocated new GPU memory block {} of size {} bytes",
            block_id, size_bytes
        );
        Ok(block_id)
    }

    pub fn deallocate(&self, block_id: usize) -> Result<()> {
        // Lock in the same order as `allocate` to avoid lock-order inversion.
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut allocated_blocks = self.allocated_blocks.lock().unwrap();
        let mut stats = self.allocation_stats.lock().unwrap();

        if let Some(block) = allocated_blocks.remove(&block_id) {
            stats.total_deallocations += 1;
            stats.current_memory_usage -= block.size_bytes;

            free_blocks.push_back(block);

            // Bound the free list so stale blocks do not accumulate indefinitely.
            if free_blocks.len() > 100 {
                free_blocks.pop_front();
            }

            debug!("Deallocated GPU memory block {}", block_id);
            Ok(())
        } else {
            Err(anyhow!("Block {} not found for deallocation", block_id))
        }
    }

    pub fn get_stats(&self) -> AllocationStats {
        self.allocation_stats.lock().unwrap().clone()
    }

    pub fn defragment(&self) -> Result<()> {
        let mut free_blocks = self.free_blocks.lock().unwrap();

        // Sort free blocks by device and size, then coalesce blocks per device.
        let mut blocks: Vec<_> = free_blocks.drain(..).collect();
        blocks.sort_by_key(|b| (b.device_id, b.size_bytes));

        let mut merged_blocks = VecDeque::new();
        let mut current_block: Option<MemoryBlock> = None;

        for block in blocks {
            if let Some(ref mut current) = current_block {
                if current.device_id == block.device_id {
                    current.size_bytes += block.size_bytes;
                } else {
                    merged_blocks.push_back(current.clone());
                    current_block = Some(block);
                }
            } else {
                current_block = Some(block);
            }
        }

        if let Some(block) = current_block {
            merged_blocks.push_back(block);
        }

        *free_blocks = merged_blocks;

        info!(
            "Memory defragmentation completed, {} free blocks remaining",
            free_blocks.len()
        );
        Ok(())
    }
}

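/// In-memory cache for entity tensors, attention weights, and intermediate activations.
///
/// Minimal usage sketch (the cached value is cloned out on a hit; shapes are
/// illustrative):
///
/// ```ignore
/// let cache = TensorCache::new(GpuAccelerationConfig::default());
/// cache.cache_entity_tensor("alice", Array2::<f32>::zeros((1, 128)), 0);
/// assert!(cache.get_entity_tensor("alice").is_some());
/// assert!(cache.get_entity_tensor("bob").is_none());
/// assert_eq!(cache.get_stats().hits, 1);
/// ```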
pub struct TensorCache {
    config: GpuAccelerationConfig,
    entity_tensors: Arc<Mutex<HashMap<String, CachedTensor>>>,
    attention_weights: Arc<Mutex<HashMap<String, CachedTensor>>>,
    intermediate_activations: Arc<Mutex<HashMap<String, CachedTensor>>>,
    cache_stats: Arc<Mutex<CacheStats>>,
}

#[derive(Debug, Clone)]
struct CachedTensor {
    data: Array2<f32>,
    device_id: usize,
    last_accessed: Instant,
    access_count: usize,
    size_bytes: usize,
}

#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    pub hits: usize,
    pub misses: usize,
    pub evictions: usize,
    pub total_memory_usage: usize,
}

impl TensorCache {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            entity_tensors: Arc::new(Mutex::new(HashMap::new())),
            attention_weights: Arc::new(Mutex::new(HashMap::new())),
            intermediate_activations: Arc::new(Mutex::new(HashMap::new())),
            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    pub fn cache_entity_tensor(&self, entity: &str, tensor: Array2<f32>, device_id: usize) {
        let mut cache = self.entity_tensors.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        let size_bytes = tensor.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: tensor,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(entity.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached entity tensor for {}", entity);
    }

    pub fn get_entity_tensor(&self, entity: &str) -> Option<Array2<f32>> {
        let mut cache = self.entity_tensors.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        if let Some(cached) = cache.get_mut(entity) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for entity tensor {}", entity);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for entity tensor {}", entity);
            None
        }
    }

    pub fn cache_attention_weights(&self, key: &str, weights: Array2<f32>, device_id: usize) {
        let mut cache = self.attention_weights.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        let size_bytes = weights.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: weights,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(key.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached attention weights for key {}", key);
    }

    pub fn get_attention_weights(&self, key: &str) -> Option<Array2<f32>> {
        let mut cache = self.attention_weights.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        if let Some(cached) = cache.get_mut(key) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for attention weights {}", key);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for attention weights {}", key);
            None
        }
    }

    fn evict_if_needed(&self, stats: &mut CacheStats) {
        let max_memory = self.config.cache_size_mb * 1024 * 1024;
        if stats.total_memory_usage > max_memory {
            // Simplified eviction: only the accounted memory is reduced here; a
            // full implementation would also drop least-recently-used entries.
            stats.evictions += 1;
            stats.total_memory_usage = max_memory / 2;
            warn!("Tensor cache eviction triggered, freed memory");
        }
    }

    pub fn get_stats(&self) -> CacheStats {
        self.cache_stats.lock().unwrap().clone()
    }

    pub fn clear_all(&self) {
        self.entity_tensors.lock().unwrap().clear();
        self.attention_weights.lock().unwrap().clear();
        self.intermediate_activations.lock().unwrap().clear();

        let mut stats = self.cache_stats.lock().unwrap();
        stats.total_memory_usage = 0;

        info!("Cleared all tensor caches");
    }
}

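/// Simulated mixed-precision (FP16) helper with dynamic loss scaling.
///
/// Minimal usage sketch of the loss-scaling round trip (values are illustrative;
/// `unscale_gradients` returns `false` when an overflow is detected):
///
/// ```ignore
/// let mut mp = MixedPrecisionProcessor::new(GpuAccelerationConfig::default());
/// let scaled = mp.scale_loss(0.25);
/// let mut grads = Array2::<f32>::from_elem((2, 2), scaled);
/// let ok = mp.unscale_gradients(&mut grads);
/// mp.adjust_loss_scaling(!ok); // shrink the scale only if an overflow occurred
/// ```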
pub struct MixedPrecisionProcessor {
    config: GpuAccelerationConfig,
    fp16_enabled: bool,
    loss_scaling: f32,
    overflow_detection: bool,
}

impl MixedPrecisionProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config: config.clone(),
            fp16_enabled: config.mixed_precision,
            loss_scaling: 65536.0,
            overflow_detection: true,
        }
    }

    pub fn to_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        if !self.fp16_enabled {
            return tensor.clone();
        }

        // Simulate FP16 precision loss: clamp to the representable range and
        // quantize the value.
        tensor.mapv(|x| {
            let clamped = x.clamp(-65504.0, 65504.0);
            (clamped * 1024.0).round() / 1024.0
        })
    }

    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.fp16_enabled {
            loss * self.loss_scaling
        } else {
            loss
        }
    }

    pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) -> bool {
        if !self.fp16_enabled {
            return true;
        }

        if self.overflow_detection {
            let has_overflow = gradients.iter().any(|&x| !x.is_finite());
            if has_overflow {
                warn!("Gradient overflow detected in mixed precision training");
                return false;
            }
        }

        gradients.mapv_inplace(|x| x / self.loss_scaling);
        true
    }

    pub fn adjust_loss_scaling(&mut self, overflow_detected: bool) {
        if overflow_detected {
            self.loss_scaling = (self.loss_scaling / 2.0).max(1.0);
            info!("Reduced loss scaling to {}", self.loss_scaling);
        } else {
            self.loss_scaling = (self.loss_scaling * 1.1).min(65536.0);
        }
    }
}

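/// Round-robin scheduler over a fixed set of logical GPU streams.
///
/// Minimal usage sketch (the closure stands in for a per-entity embedding kernel
/// and must be `Copy + Send + Sync + 'static`; requires a Tokio runtime):
///
/// ```ignore
/// let mut streams = MultiStreamProcessor::new(GpuAccelerationConfig::default());
/// let embeddings = streams
///     .process_batch_parallel(vec!["a".into(), "b".into()], |entity, _stream| {
///         Array1::from_vec(vec![entity.len() as f32])
///     })
///     .await?;
/// assert_eq!(embeddings.len(), 2);
/// ```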
pub struct MultiStreamProcessor {
    config: GpuAccelerationConfig,
    pub stream_ids: Vec<usize>,
    current_stream: usize,
}

impl MultiStreamProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let stream_ids = (0..config.num_streams.max(1)).collect();

        Self {
            config,
            stream_ids,
            current_stream: 0,
        }
    }

    pub fn get_next_stream(&mut self) -> usize {
        let stream_id = self.stream_ids[self.current_stream];
        self.current_stream = (self.current_stream + 1) % self.stream_ids.len();
        stream_id
    }

    pub async fn process_batch_parallel(
        &mut self,
        entities: Vec<String>,
        process_fn: impl Fn(String, usize) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if entities.is_empty() {
            return Ok(Vec::new());
        }

        // Split the batch roughly evenly across the configured streams.
        let num_streams = self.config.num_streams.max(1);
        let chunk_size = (entities.len() + num_streams - 1) / num_streams;
        let mut tasks = Vec::new();

        for chunk in entities.chunks(chunk_size) {
            let stream_id = self.get_next_stream();
            let chunk_entities = chunk.to_vec();

            let task = tokio::spawn(async move {
                let mut results = Vec::new();
                for entity in chunk_entities {
                    let embedding = process_fn(entity, stream_id);
                    results.push(embedding);
                }
                results
            });

            tasks.push(task);
        }

        // Drain the per-chunk tasks in submission order to preserve input order.
        let mut all_results = Vec::new();
        for task in tasks {
            let chunk_results = task.await?;
            all_results.extend(chunk_results);
        }

        Ok(all_results)
    }

    pub fn synchronize_all(&self) {
        debug!("Synchronized {} GPU streams", self.stream_ids.len());
    }
}

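/// Facade that wires the memory pool, tensor cache, mixed precision, and
/// multi-stream processing together.
///
/// Minimal usage sketch (the compute closure is a stand-in for the real
/// embedding function; requires a Tokio runtime):
///
/// ```ignore
/// let mut manager = GpuAccelerationManager::new(GpuAccelerationConfig::default());
/// let embeddings = manager
///     .accelerated_embedding_generation(vec!["alice".into()], |entity| {
///         Array1::from_vec(vec![entity.len() as f32])
///     })
///     .await?;
/// let stats = manager.get_performance_stats();
/// println!("active streams: {}", stats.num_active_streams);
/// ```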
pub struct GpuAccelerationManager {
    config: GpuAccelerationConfig,
    memory_pool: GpuMemoryPool,
    tensor_cache: TensorCache,
    mixed_precision: MixedPrecisionProcessor,
    multi_stream: MultiStreamProcessor,
}

impl GpuAccelerationManager {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_pool = GpuMemoryPool::new(config.clone());
        let tensor_cache = TensorCache::new(config.clone());
        let mixed_precision = MixedPrecisionProcessor::new(config.clone());
        let multi_stream = MultiStreamProcessor::new(config.clone());

        Self {
            config,
            memory_pool,
            tensor_cache,
            mixed_precision,
            multi_stream,
        }
    }

    pub fn memory_pool(&self) -> &GpuMemoryPool {
        &self.memory_pool
    }

    pub fn tensor_cache(&self) -> &TensorCache {
        &self.tensor_cache
    }

    pub fn mixed_precision(&mut self) -> &mut MixedPrecisionProcessor {
        &mut self.mixed_precision
    }

    pub fn multi_stream(&mut self) -> &mut MultiStreamProcessor {
        &mut self.multi_stream
    }

    pub async fn accelerated_embedding_generation(
        &mut self,
        entities: Vec<String>,
        base_compute_fn: impl Fn(&str) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if !self.config.enabled {
            return Ok(entities.iter().map(|e| base_compute_fn(e)).collect());
        }

        let results = self
            .multi_stream
            .process_batch_parallel(entities, move |entity, stream_id| {
                debug!("Processing entity {} on stream {}", entity, stream_id);
                base_compute_fn(&entity)
            })
            .await?;

        self.multi_stream.synchronize_all();
        Ok(results)
    }

    pub fn get_performance_stats(&self) -> GpuPerformanceStats {
        let memory_stats = self.memory_pool.get_stats();
        let cache_stats = self.tensor_cache.get_stats();

        GpuPerformanceStats {
            memory_allocations: memory_stats.total_allocations,
            memory_deallocations: memory_stats.total_deallocations,
            peak_memory_usage_mb: memory_stats.peak_memory_usage / (1024 * 1024),
            current_memory_usage_mb: memory_stats.current_memory_usage / (1024 * 1024),
            memory_pool_hits: memory_stats.cache_hits,
            memory_pool_misses: memory_stats.cache_misses,
            tensor_cache_hits: cache_stats.hits,
            tensor_cache_misses: cache_stats.misses,
            tensor_cache_evictions: cache_stats.evictions,
            tensor_cache_memory_mb: cache_stats.total_memory_usage / (1024 * 1024),
            loss_scaling_factor: self.mixed_precision.loss_scaling,
            num_active_streams: self.config.num_streams,
        }
    }
}

#[derive(Debug, Serialize)]
pub struct GpuPerformanceStats {
    pub memory_allocations: usize,
    pub memory_deallocations: usize,
    pub peak_memory_usage_mb: usize,
    pub current_memory_usage_mb: usize,
    pub memory_pool_hits: usize,
    pub memory_pool_misses: usize,
    pub tensor_cache_hits: usize,
    pub tensor_cache_misses: usize,
    pub tensor_cache_evictions: usize,
    pub tensor_cache_memory_mb: usize,
    pub loss_scaling_factor: f32,
    pub num_active_streams: usize,
}

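/// Decides when the memory pool is fragmented enough to be worth compacting.
///
/// Minimal usage sketch (the fragmentation ratio is derived from the pool's
/// allocation statistics, so this runs without real GPU memory):
///
/// ```ignore
/// let pool = GpuMemoryPool::new(GpuAccelerationConfig::default());
/// let mut defrag = MemoryDefragmenter::new(GpuAccelerationConfig::default());
/// if defrag.should_defragment(&pool) {
///     let report = defrag.defragment(&pool)?;
///     println!("freed {} bytes in {:?}", report.memory_freed, report.duration);
/// }
/// ```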
pub struct MemoryDefragmenter {
    config: GpuAccelerationConfig,
    defrag_threshold: f32,
    last_defrag: Instant,
    defrag_interval: Duration,
}

impl MemoryDefragmenter {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            defrag_threshold: 0.7,
            last_defrag: Instant::now(),
            defrag_interval: Duration::from_secs(300),
        }
    }

    pub fn should_defragment(&self, memory_pool: &GpuMemoryPool) -> bool {
        let stats = memory_pool.get_stats();
        let fragmentation_ratio = self.calculate_fragmentation_ratio(&stats);

        fragmentation_ratio > self.defrag_threshold
            && self.last_defrag.elapsed() > self.defrag_interval
    }

    fn calculate_fragmentation_ratio(&self, stats: &AllocationStats) -> f32 {
        if stats.current_memory_usage == 0 {
            return 0.0;
        }

        // Treat the gap between peak and current usage as fragmentation.
        let theoretical_optimal = stats.current_memory_usage;
        let actual_allocated = stats.peak_memory_usage;

        if actual_allocated == 0 {
            0.0
        } else {
            1.0 - (theoretical_optimal as f32 / actual_allocated as f32)
        }
    }

    pub fn defragment(&mut self, memory_pool: &GpuMemoryPool) -> Result<DefragmentationResult> {
        info!("Starting GPU memory defragmentation");
        let start_time = Instant::now();

        let stats_before = memory_pool.get_stats();

        // Compact the pool's free list and simulate the cost of the device-side pass.
        memory_pool.defragment()?;
        std::thread::sleep(Duration::from_millis(100));

        let stats_after = memory_pool.get_stats();
        self.last_defrag = Instant::now();

        let result = DefragmentationResult {
            duration: start_time.elapsed(),
            memory_freed: stats_before
                .peak_memory_usage
                .saturating_sub(stats_after.current_memory_usage),
            fragmentation_before: self.calculate_fragmentation_ratio(&stats_before),
            fragmentation_after: self.calculate_fragmentation_ratio(&stats_after),
        };

        info!("Defragmentation completed: {:?}", result);
        Ok(result)
    }
}

#[derive(Debug, Clone)]
pub struct DefragmentationResult {
    pub duration: Duration,
    pub memory_freed: usize,
    pub fragmentation_before: f32,
    pub fragmentation_after: f32,
}

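/// Streams oversized batches through memory-bounded chunks, optionally with overlap.
///
/// Minimal usage sketch (the closure stands in for a chunk-level embedding kernel;
/// requires a Tokio runtime):
///
/// ```ignore
/// let processor = OutOfCoreProcessor::new(GpuAccelerationConfig::default());
/// let inputs: Vec<String> = (0..10_000).map(|i| format!("entity{i}")).collect();
/// let embeddings = processor
///     .process_large_batch(inputs, |chunk| {
///         Ok(chunk.iter().map(|e| Array1::from_vec(vec![e.len() as f32])).collect())
///     })
///     .await?;
/// assert_eq!(embeddings.len(), 10_000);
/// ```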
pub struct OutOfCoreProcessor {
    config: GpuAccelerationConfig,
    chunk_size: usize,
    overlap_size: usize,
    memory_limit: usize,
}

impl OutOfCoreProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        // Budget a quarter of the pool per chunk, with a 10% overlap window.
        let memory_limit = config.memory_pool_size_mb * 1024 * 1024;
        let chunk_size = memory_limit / 4;
        let overlap_size = chunk_size / 10;

        Self {
            config,
            chunk_size,
            overlap_size,
            memory_limit,
        }
    }

    pub async fn process_large_batch<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        // Derive an item count per chunk from the byte budget.
        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);

        info!(
            "Processing {} items in chunks of {}",
            data.len(),
            chunk_size
        );

        let mut results = Vec::new();
        let mut processed_count = 0;

        for chunk in data.chunks(chunk_size) {
            let chunk_results = process_fn(chunk)?;
            results.extend(chunk_results);

            processed_count += chunk.len();

            if processed_count % (chunk_size * 10) == 0 {
                info!("Processed {}/{} items", processed_count, data.len());
            }

            // Yield so other tasks can make progress between chunks.
            tokio::task::yield_now().await;
        }

        Ok(results)
    }

    pub async fn process_with_overlap<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);
        let overlap_items = self.overlap_size / item_size;
        // Always advance by at least one item so the loop terminates.
        let step = chunk_size.saturating_sub(overlap_items).max(1);

        let mut results = Vec::new();
        let mut start_idx = 0;

        while start_idx < data.len() {
            let end_idx = (start_idx + chunk_size).min(data.len());
            let chunk = &data[start_idx..end_idx];

            let chunk_results = process_fn(chunk)?;

            // Keep every result from the first chunk; afterwards skip the
            // overlapping prefix so items are not emitted twice.
            let skip_count = if start_idx == 0 { 0 } else { overlap_items };
            results.extend(chunk_results.into_iter().skip(skip_count));

            start_idx += step;
            tokio::task::yield_now().await;
        }

        Ok(results)
    }
}

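/// Caches shape metadata and rounds tensor dimensions up to the warp size.
///
/// Minimal usage sketch (the 32-element warp alignment comes from
/// `calculate_optimal_shape` below):
///
/// ```ignore
/// let mut shapes = DynamicShapeHandler::new(GpuAccelerationConfig::default());
/// assert_eq!(shapes.optimize_shape(vec![100, 60]), vec![128, 64]);
/// let batch = shapes.get_optimal_batch_size(&[100, 60]);
/// assert!(batch >= 1);
/// ```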
pub struct DynamicShapeHandler {
    config: GpuAccelerationConfig,
    shape_cache: HashMap<Vec<usize>, ShapeInfo>,
    max_cached_shapes: usize,
}

#[derive(Debug, Clone)]
struct ShapeInfo {
    shape: Vec<usize>,
    memory_requirement: usize,
    optimal_batch_size: usize,
    last_used: Instant,
}

impl DynamicShapeHandler {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            shape_cache: HashMap::new(),
            max_cached_shapes: 100,
        }
    }

    pub fn optimize_shape(&mut self, shape: Vec<usize>) -> Vec<usize> {
        if let Some(shape_info) = self.shape_cache.get_mut(&shape) {
            shape_info.last_used = Instant::now();
            return shape_info.shape.clone();
        }

        let optimized_shape = self.calculate_optimal_shape(&shape);

        self.cache_shape_info(shape.clone(), optimized_shape.clone());

        optimized_shape
    }

    fn calculate_optimal_shape(&self, shape: &[usize]) -> Vec<usize> {
        let mut optimized = shape.to_vec();

        // Round each dimension up to a multiple of the CUDA warp size.
        const WARP_SIZE: usize = 32;

        for dim in &mut optimized {
            if *dim > 0 {
                *dim = ((*dim + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
            }
        }

        optimized
    }

    fn cache_shape_info(&mut self, original_shape: Vec<usize>, optimized_shape: Vec<usize>) {
        if self.shape_cache.len() >= self.max_cached_shapes {
            self.evict_oldest_shape();
        }

        // Memory estimate assumes 4-byte (f32) elements.
        let memory_requirement = optimized_shape.iter().product::<usize>() * 4;
        let optimal_batch_size = self.calculate_optimal_batch_size(memory_requirement);

        let shape_info = ShapeInfo {
            shape: optimized_shape,
            memory_requirement,
            optimal_batch_size,
            last_used: Instant::now(),
        };

        self.shape_cache.insert(original_shape, shape_info);
    }

    fn calculate_optimal_batch_size(&self, memory_per_item: usize) -> usize {
        if memory_per_item == 0 {
            return 1;
        }

        // Use half of the configured pool as the working budget.
        let available_memory = (self.config.memory_pool_size_mb * 1024 * 1024) / 2;
        let max_batch_size = available_memory / memory_per_item;

        max_batch_size.clamp(1, 1024)
    }

    fn evict_oldest_shape(&mut self) {
        if let Some(oldest_key) = self
            .shape_cache
            .iter()
            .min_by_key(|(_, info)| info.last_used)
            .map(|(key, _)| key.clone())
        {
            self.shape_cache.remove(&oldest_key);
        }
    }

    pub fn get_optimal_batch_size(&self, shape: &[usize]) -> usize {
        self.shape_cache
            .get(shape)
            .map(|info| info.optimal_batch_size)
            .unwrap_or(1)
    }
}

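/// Empirically searches for the batch size with the best measured throughput.
///
/// Minimal usage sketch (the closure stands in for the real batched embedding
/// kernel; requires a Tokio runtime):
///
/// ```ignore
/// let mut optimizer = BatchSizeOptimizer::new(GpuAccelerationConfig::default());
/// let samples: Vec<String> = (0..256).map(|i| format!("entity{i}")).collect();
/// let best = optimizer
///     .find_optimal_batch_size(samples, |chunk| {
///         Ok(chunk.iter().map(|e| Array1::from_vec(vec![e.len() as f32])).collect())
///     })
///     .await?;
/// assert!(best >= 1);
/// println!("{:?}", optimizer.get_performance_stats());
/// ```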
pub struct BatchSizeOptimizer {
    config: GpuAccelerationConfig,
    performance_history: VecDeque<BatchPerformance>,
    max_history_size: usize,
    current_optimal_batch_size: usize,
}

#[derive(Debug, Clone)]
struct BatchPerformance {
    batch_size: usize,
    processing_time: Duration,
    memory_usage: usize,
    /// Items processed per second.
    throughput: f64,
    gpu_utilization: f64,
    timestamp: Instant,
}

impl BatchSizeOptimizer {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            performance_history: VecDeque::new(),
            max_history_size: 50,
            current_optimal_batch_size: 32,
        }
    }

    pub async fn find_optimal_batch_size<T>(
        &mut self,
        sample_data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<usize>
    where
        T: Clone + Send + Sync + 'static,
    {
        if sample_data.is_empty() {
            return Ok(1);
        }

        info!("Optimizing batch size for embedding generation");

        let test_sizes = vec![1, 8, 16, 32, 64, 128, 256, 512];
        let max_test_size = sample_data.len().min(512);

        let mut best_batch_size = 1;
        let mut best_throughput = 0.0;

        for &batch_size in &test_sizes {
            if batch_size > max_test_size {
                break;
            }

            let performance = self
                .test_batch_size(
                    &sample_data[..batch_size.min(sample_data.len())],
                    batch_size,
                    process_fn,
                )
                .await?;

            info!(
                "Batch size {}: {:.2} items/sec, {:.1}ms processing time",
                batch_size,
                performance.throughput,
                performance.processing_time.as_millis()
            );

            if performance.throughput > best_throughput {
                best_throughput = performance.throughput;
                best_batch_size = batch_size;
            }

            self.performance_history.push_back(performance);
            if self.performance_history.len() > self.max_history_size {
                self.performance_history.pop_front();
            }

            // Pause briefly between trials to let the device settle.
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        self.current_optimal_batch_size = best_batch_size;
        info!("Optimal batch size determined: {}", best_batch_size);

        Ok(best_batch_size)
    }

    async fn test_batch_size<T>(
        &self,
        sample_data: &[T],
        batch_size: usize,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>>,
    ) -> Result<BatchPerformance>
    where
        T: Clone,
    {
        let start_time = Instant::now();
        let memory_before = self.estimate_memory_usage();

        let _results = process_fn(sample_data)?;

        let processing_time = start_time.elapsed();
        let memory_after = self.estimate_memory_usage();
        let memory_usage = memory_after.saturating_sub(memory_before);

        let throughput = if processing_time.as_secs_f64() > 0.0 {
            sample_data.len() as f64 / processing_time.as_secs_f64()
        } else {
            0.0
        };

        let gpu_utilization = self.estimate_gpu_utilization(batch_size, processing_time);

        Ok(BatchPerformance {
            batch_size,
            processing_time,
            memory_usage,
            throughput,
            gpu_utilization,
            timestamp: Instant::now(),
        })
    }

    fn estimate_memory_usage(&self) -> usize {
        // Placeholder estimate: assume a quarter of the configured pool is in use.
        (self.config.memory_pool_size_mb * 1024 * 1024) / 4
    }

    fn estimate_gpu_utilization(&self, batch_size: usize, processing_time: Duration) -> f64 {
        // Heuristic: larger batches utilize the device better, while very short
        // or very long runs are penalized.
        let base_utilization = (batch_size as f64).log2() / 10.0;
        let time_factor = if processing_time.as_millis() < 10 {
            0.5
        } else if processing_time.as_millis() > 1000 {
            0.7
        } else {
            1.0
        };

        (base_utilization * time_factor).clamp(0.0, 1.0)
    }

    pub fn get_optimal_batch_size(&self) -> usize {
        self.current_optimal_batch_size
    }

    pub fn get_performance_stats(&self) -> BatchSizeOptimizerStats {
        let avg_throughput = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.throughput)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        let avg_gpu_utilization = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.gpu_utilization)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        BatchSizeOptimizerStats {
            current_optimal_batch_size: self.current_optimal_batch_size,
            avg_throughput,
            avg_gpu_utilization,
            total_tests_performed: self.performance_history.len(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeOptimizerStats {
    pub current_optimal_batch_size: usize,
    pub avg_throughput: f64,
    pub avg_gpu_utilization: f64,
    pub total_tests_performed: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_acceleration_config_default() {
        let config = GpuAccelerationConfig::default();
        assert!(config.enabled);
        assert_eq!(config.device_ids, vec![0]);
        assert_eq!(config.memory_pool_size_mb, 2048);
        assert!(config.mixed_precision);
        assert!(config.tensor_caching);
    }

    #[test]
    fn test_memory_pool_allocation() {
        let config = GpuAccelerationConfig::default();
        let pool = GpuMemoryPool::new(config);

        let block_id = pool.allocate(1024, 0).unwrap();
        assert!(block_id > 0);

        pool.deallocate(block_id).unwrap();

        let block_id2 = pool.allocate(1024, 0).unwrap();
        assert_eq!(block_id, block_id2);
    }

    #[test]
    fn test_tensor_cache() {
        let config = GpuAccelerationConfig::default();
        let cache = TensorCache::new(config);

        let tensor = Array2::zeros((10, 20));
        cache.cache_entity_tensor("test_entity", tensor.clone(), 0);

        let cached = cache.get_entity_tensor("test_entity").unwrap();
        assert_eq!(cached.shape(), tensor.shape());
    }

    #[test]
    fn test_mixed_precision() {
        let config = GpuAccelerationConfig::default();
        let processor = MixedPrecisionProcessor::new(config);

        let tensor = Array2::from_elem((2, 2), 1.0001);
        let fp16_tensor = processor.to_fp16(&tensor);

        if processor.fp16_enabled {
            assert!(fp16_tensor[[0, 0]] != tensor[[0, 0]]);
        } else {
            assert_eq!(fp16_tensor[[0, 0]], tensor[[0, 0]]);
        }
    }

    #[tokio::test]
    async fn test_multi_stream_processing() {
        let config = GpuAccelerationConfig::default();
        let mut processor = MultiStreamProcessor::new(config);

        let entities = vec!["entity1".to_string(), "entity2".to_string()];
        let process_fn = |entity: String, _stream_id: usize| -> Array1<f32> {
            Array1::from_vec(vec![entity.len() as f32])
        };

        let results = processor
            .process_batch_parallel(entities, process_fn)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_scirs2_gpu_accelerator() {
        let config = GpuAccelerationConfig::default();

        match SciRS2GpuAccelerator::new(config) {
            Ok(accelerator) => {
                assert!(accelerator.num_devices() > 0);
            }
            Err(_) => {
                println!("Skipping GPU test: no hardware available");
            }
        }
    }

    #[test]
    fn test_tensor_core_operations() {
        let config = GpuAccelerationConfig::default();

        if let Ok(accelerator) = SciRS2GpuAccelerator::new(config) {
            let _matrix_a = Array2::<f32>::ones((256, 512));
            let _matrix_b = Array2::<f32>::ones((512, 256));

            let stats = accelerator.get_stats();
            assert_eq!(stats.total_operations, 0);
        } else {
            println!("Skipping tensor core test: no GPU hardware available");
        }
    }
}

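/// Thin wrapper over `scirs2_core` GPU contexts used for dense linear algebra.
///
/// Minimal usage sketch (construction fails when no CUDA-capable device is
/// available, so callers should branch on the `Result`):
///
/// ```ignore
/// if let Ok(accel) = SciRS2GpuAccelerator::new(GpuAccelerationConfig::default()) {
///     let a = Array2::<f32>::ones((256, 512));
///     let b = Array2::<f32>::ones((512, 256));
///     let c = accel.tensor_core_gemm(&a, &b, true)?;
///     assert_eq!(c.dim(), (256, 256));
///     println!("ops so far: {}", accel.get_stats().total_operations);
/// }
/// ```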
pub struct SciRS2GpuAccelerator {
    config: GpuAccelerationConfig,
    contexts: Vec<GpuContext>,
    operations: Arc<AtomicUsize>,
}

impl SciRS2GpuAccelerator {
    pub fn new(config: GpuAccelerationConfig) -> Result<Self> {
        let mut contexts = Vec::new();

        // Create one CUDA context per configured device id (the id itself is not
        // yet forwarded to the backend); devices that fail to initialize are
        // skipped with a warning.
        for _device_id in &config.device_ids {
            match GpuContext::new(GpuBackend::Cuda) {
                Ok(ctx) => {
                    info!("Initialized GPU context");
                    contexts.push(ctx);
                }
                Err(e) => {
                    warn!("Failed to initialize GPU device: {}", e);
                }
            }
        }

        if contexts.is_empty() {
            return Err(anyhow!("No GPU devices available for acceleration"));
        }

        Ok(Self {
            config,
            contexts,
            operations: Arc::new(AtomicUsize::new(0)),
        })
    }

    pub fn num_devices(&self) -> usize {
        self.contexts.len()
    }

    pub fn tensor_core_gemm(
        &self,
        a: &Array2<f32>,
        b: &Array2<f32>,
        use_mixed_precision: bool,
    ) -> Result<Array2<f32>> {
        // Both paths currently fall back to ndarray's GEMM; the mixed-precision
        // branch is kept separate as the hook for a tensor-core kernel.
        let result = if use_mixed_precision && self.config.mixed_precision {
            a.dot(b)
        } else {
            a.dot(b)
        };

        self.operations.fetch_add(1, Ordering::Relaxed);

        Ok(result)
    }

    pub fn batch_embed(
        &self,
        inputs: &[Array1<f32>],
        embedding_matrix: &Array2<f32>,
    ) -> Result<Vec<Array1<f32>>> {
        let batch_size = inputs.len();
        let mut results = Vec::with_capacity(batch_size);

        // Split the batch across streams when multi-stream execution is enabled.
        let stream_batch_size = if self.config.multi_stream {
            (batch_size + self.config.num_streams - 1) / self.config.num_streams
        } else {
            batch_size
        };

        for chunk in inputs.chunks(stream_batch_size.max(1)) {
            for input in chunk {
                let embedding = embedding_matrix.dot(input);
                results.push(embedding);
            }
        }

        self.operations.fetch_add(batch_size, Ordering::Relaxed);

        Ok(results)
    }

    pub fn simd_similarity(
        &self,
        query: &Array1<f32>,
        candidates: &[Array1<f32>],
    ) -> Result<Vec<f32>> {
        // Dot-product similarity; actual SIMD use depends on the ndarray backend.
        let similarities: Vec<f32> = candidates
            .iter()
            .map(|candidate| query.dot(candidate))
            .collect();

        self.operations
            .fetch_add(candidates.len(), Ordering::Relaxed);

        Ok(similarities)
    }

    pub fn get_stats(&self) -> AcceleratorStats {
        AcceleratorStats {
            total_operations: self.operations.load(Ordering::Relaxed),
            num_devices: self.contexts.len(),
            profiler_report: "Stats available".to_string(),
        }
    }

    pub fn clear_stats(&self) {
        self.operations.store(0, Ordering::Relaxed);
    }
}

#[derive(Debug, Clone)]
pub struct AcceleratorStats {
    pub total_operations: usize,
    pub num_devices: usize,
    pub profiler_report: String,
}