use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};

/// Configuration for GPU-accelerated embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuAccelerationConfig {
    /// Enable GPU acceleration.
    pub enabled: bool,
    /// GPU device IDs to use.
    pub device_ids: Vec<usize>,
    /// Size of the GPU memory pool in megabytes.
    pub memory_pool_size_mb: usize,
    /// Enable mixed-precision (FP16) computation.
    pub mixed_precision: bool,
    /// Enable tensor caching.
    pub tensor_caching: bool,
    /// Tensor cache size in megabytes.
    pub cache_size_mb: usize,
    /// Enable kernel fusion.
    pub kernel_fusion: bool,
    /// Enable memory mapping.
    pub memory_mapping: bool,
    /// Use unified memory shared between host and device.
    pub unified_memory: bool,
    /// Process work on multiple GPU streams.
    pub multi_stream: bool,
    /// Number of GPU streams.
    pub num_streams: usize,
    /// Enable pipeline parallelism.
    pub pipeline_parallelism: bool,
    /// Number of pipeline stages.
    pub pipeline_stages: usize,
}

impl Default for GpuAccelerationConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            device_ids: vec![0],
            memory_pool_size_mb: 2048,
            mixed_precision: true,
            tensor_caching: true,
            cache_size_mb: 512,
            kernel_fusion: true,
            memory_mapping: true,
            unified_memory: false,
            multi_stream: true,
            num_streams: 4,
            pipeline_parallelism: false,
            pipeline_stages: 2,
        }
    }
}

/// Pool of reusable GPU memory blocks with allocation statistics.
pub struct GpuMemoryPool {
    config: GpuAccelerationConfig,
    allocated_blocks: Arc<Mutex<HashMap<usize, MemoryBlock>>>,
    free_blocks: Arc<Mutex<VecDeque<MemoryBlock>>>,
    total_allocated: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

#[derive(Debug, Clone)]
struct MemoryBlock {
    device_id: usize,
    size_bytes: usize,
    /// Simulated device pointer, used as the block handle.
    ptr: usize,
    allocated_at: Instant,
    last_used: Instant,
}

#[derive(Debug, Default, Clone)]
pub struct AllocationStats {
    pub total_allocations: usize,
    pub total_deallocations: usize,
    pub peak_memory_usage: usize,
    pub current_memory_usage: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
}

impl GpuMemoryPool {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            allocated_blocks: Arc::new(Mutex::new(HashMap::new())),
            free_blocks: Arc::new(Mutex::new(VecDeque::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::default())),
        }
    }

    pub fn allocate(&self, size_bytes: usize, device_id: usize) -> Result<usize> {
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut allocated_blocks = self.allocated_blocks.lock().unwrap();
        let mut stats = self.allocation_stats.lock().unwrap();

        // Try to reuse a compatible free block before allocating a new one.
        // `position` avoids holding an iterator borrow while mutating the deque.
        if let Some(i) = free_blocks
            .iter()
            .position(|block| block.size_bytes >= size_bytes && block.device_id == device_id)
        {
            let mut reused_block = free_blocks.remove(i).unwrap();
            let block_id = reused_block.ptr;
            reused_block.last_used = Instant::now();

            allocated_blocks.insert(block_id, reused_block);
            stats.cache_hits += 1;

            debug!(
                "Reused GPU memory block {} of size {}",
                block_id, size_bytes
            );
            return Ok(block_id);
        }

        stats.cache_misses += 1;
        stats.total_allocations += 1;

        // Use the monotonically increasing allocation count as the block handle.
        let block_id = stats.total_allocations;
        let now = Instant::now();

        let block = MemoryBlock {
            device_id,
            size_bytes,
            ptr: block_id,
            allocated_at: now,
            last_used: now,
        };

        allocated_blocks.insert(block_id, block);

        let mut total_allocated = self.total_allocated.lock().unwrap();
        *total_allocated += size_bytes;
        stats.current_memory_usage += size_bytes;

        if stats.current_memory_usage > stats.peak_memory_usage {
            stats.peak_memory_usage = stats.current_memory_usage;
        }

        info!(
            "Allocated new GPU memory block {} of size {} bytes",
            block_id, size_bytes
        );
        Ok(block_id)
    }

    pub fn deallocate(&self, block_id: usize) -> Result<()> {
        let mut allocated_blocks = self.allocated_blocks.lock().unwrap();
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut stats = self.allocation_stats.lock().unwrap();

        if let Some(block) = allocated_blocks.remove(&block_id) {
            stats.total_deallocations += 1;
            stats.current_memory_usage -= block.size_bytes;

            free_blocks.push_back(block);

            // Cap the free list so the pool does not grow without bound.
            if free_blocks.len() > 100 {
                free_blocks.pop_front();
            }

            debug!("Deallocated GPU memory block {}", block_id);
            Ok(())
        } else {
            Err(anyhow!("Block {} not found for deallocation", block_id))
        }
    }

    pub fn get_stats(&self) -> AllocationStats {
        (*self.allocation_stats.lock().unwrap()).clone()
    }

    pub fn defragment(&self) -> Result<()> {
        let mut free_blocks = self.free_blocks.lock().unwrap();

        let mut blocks: Vec<_> = free_blocks.drain(..).collect();
        blocks.sort_by_key(|b| (b.device_id, b.size_bytes));

        // Coalesce free blocks that live on the same device (simulated merge).
        let mut merged_blocks = VecDeque::new();
        let mut current_block: Option<MemoryBlock> = None;

        for block in blocks {
            if let Some(ref mut current) = current_block {
                if current.device_id == block.device_id {
                    current.size_bytes += block.size_bytes;
                } else {
                    merged_blocks.push_back(current.clone());
                    current_block = Some(block);
                }
            } else {
                current_block = Some(block);
            }
        }

        if let Some(block) = current_block {
            merged_blocks.push_back(block);
        }

        *free_blocks = merged_blocks;

        info!(
            "Memory defragmentation completed, {} free blocks remaining",
            free_blocks.len()
        );
        Ok(())
    }
}

/// Cache of GPU-resident tensors keyed by entity or attention identifier.
pub struct TensorCache {
    config: GpuAccelerationConfig,
    entity_tensors: Arc<Mutex<HashMap<String, CachedTensor>>>,
    attention_weights: Arc<Mutex<HashMap<String, CachedTensor>>>,
    intermediate_activations: Arc<Mutex<HashMap<String, CachedTensor>>>,
    cache_stats: Arc<Mutex<CacheStats>>,
}

#[derive(Debug, Clone)]
struct CachedTensor {
    data: Array2<f32>,
    device_id: usize,
    last_accessed: Instant,
    access_count: usize,
    size_bytes: usize,
}

#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    pub hits: usize,
    pub misses: usize,
    pub evictions: usize,
    pub total_memory_usage: usize,
}

impl TensorCache {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            entity_tensors: Arc::new(Mutex::new(HashMap::new())),
            attention_weights: Arc::new(Mutex::new(HashMap::new())),
            intermediate_activations: Arc::new(Mutex::new(HashMap::new())),
            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    pub fn cache_entity_tensor(&self, entity: &str, tensor: Array2<f32>, device_id: usize) {
        let mut cache = self.entity_tensors.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        let size_bytes = tensor.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: tensor,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(entity.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached entity tensor for {}", entity);
    }

    pub fn get_entity_tensor(&self, entity: &str) -> Option<Array2<f32>> {
        let mut cache = self.entity_tensors.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        if let Some(cached) = cache.get_mut(entity) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for entity tensor {}", entity);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for entity tensor {}", entity);
            None
        }
    }

    pub fn cache_attention_weights(&self, key: &str, weights: Array2<f32>, device_id: usize) {
        let mut cache = self.attention_weights.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        let size_bytes = weights.len() * std::mem::size_of::<f32>();

        let cached_tensor = CachedTensor {
            data: weights,
            device_id,
            last_accessed: Instant::now(),
            access_count: 1,
            size_bytes,
        };

        self.evict_if_needed(&mut stats);

        cache.insert(key.to_string(), cached_tensor);
        stats.total_memory_usage += size_bytes;

        debug!("Cached attention weights for key {}", key);
    }

    pub fn get_attention_weights(&self, key: &str) -> Option<Array2<f32>> {
        let mut cache = self.attention_weights.lock().unwrap();
        let mut stats = self.cache_stats.lock().unwrap();

        if let Some(cached) = cache.get_mut(key) {
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            stats.hits += 1;

            debug!("Cache hit for attention weights {}", key);
            Some(cached.data.clone())
        } else {
            stats.misses += 1;
            debug!("Cache miss for attention weights {}", key);
            None
        }
    }

    fn evict_if_needed(&self, stats: &mut CacheStats) {
        let max_memory = self.config.cache_size_mb * 1024 * 1024;
        if stats.total_memory_usage > max_memory {
            stats.evictions += 1;
            // Simplified eviction: only the accounting is reduced here; entries are
            // dropped when the caches are cleared or their keys are overwritten.
            stats.total_memory_usage = max_memory / 2;
            warn!("Tensor cache eviction triggered, freed memory");
        }
    }

    pub fn get_stats(&self) -> CacheStats {
        (*self.cache_stats.lock().unwrap()).clone()
    }

    pub fn clear_all(&self) {
        self.entity_tensors.lock().unwrap().clear();
        self.attention_weights.lock().unwrap().clear();
        self.intermediate_activations.lock().unwrap().clear();

        let mut stats = self.cache_stats.lock().unwrap();
        stats.total_memory_usage = 0;

        info!("Cleared all tensor caches");
    }
}

/// Mixed-precision (FP16) helper with dynamic loss scaling and overflow detection.
pub struct MixedPrecisionProcessor {
    config: GpuAccelerationConfig,
    fp16_enabled: bool,
    loss_scaling: f32,
    overflow_detection: bool,
}

impl MixedPrecisionProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config: config.clone(),
            fp16_enabled: config.mixed_precision,
            // 2^16: a common initial loss scale for FP16 training.
            loss_scaling: 65536.0,
            overflow_detection: true,
        }
    }

    pub fn to_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        if !self.fp16_enabled {
            return tensor.clone();
        }

        tensor.mapv(|x| {
            // Clamp to the largest finite FP16 value, then quantize to simulate
            // the reduced precision of a half-precision representation.
            let clamped = x.clamp(-65504.0, 65504.0);
            (clamped * 1024.0).round() / 1024.0
        })
    }

    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.fp16_enabled {
            loss * self.loss_scaling
        } else {
            loss
        }
    }

    pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) -> bool {
        if !self.fp16_enabled {
            return true;
        }

        if self.overflow_detection {
            let has_overflow = gradients.iter().any(|&x| !x.is_finite());
            if has_overflow {
                warn!("Gradient overflow detected in mixed precision training");
                return false;
            }
        }

        gradients.mapv_inplace(|x| x / self.loss_scaling);
        true
    }

    pub fn adjust_loss_scaling(&mut self, overflow_detected: bool) {
        if overflow_detected {
            self.loss_scaling = (self.loss_scaling / 2.0).max(1.0);
            info!("Reduced loss scaling to {}", self.loss_scaling);
        } else {
            self.loss_scaling = (self.loss_scaling * 1.1).min(65536.0);
        }
    }
}

/// Dispatches work across multiple logical GPU streams in round-robin order.
pub struct MultiStreamProcessor {
    config: GpuAccelerationConfig,
    pub stream_ids: Vec<usize>,
    current_stream: usize,
}

impl MultiStreamProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let stream_ids = (0..config.num_streams).collect();

        Self {
            config,
            stream_ids,
            current_stream: 0,
        }
    }

    pub fn get_next_stream(&mut self) -> usize {
        let stream_id = self.stream_ids[self.current_stream];
        self.current_stream = (self.current_stream + 1) % self.stream_ids.len();
        stream_id
    }

    pub async fn process_batch_parallel(
        &mut self,
        entities: Vec<String>,
        process_fn: impl Fn(String, usize) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        // Ceiling division; `max(1)` guards against a zero chunk size when the
        // input is empty, which would make `chunks` panic.
        let chunk_size =
            ((entities.len() + self.config.num_streams - 1) / self.config.num_streams).max(1);
        let mut tasks = Vec::new();

        for chunk in entities.chunks(chunk_size) {
            let stream_id = self.get_next_stream();
            let chunk_entities = chunk.to_vec();

            let task = tokio::spawn(async move {
                let mut results = Vec::new();
                for entity in chunk_entities {
                    let embedding = process_fn(entity, stream_id);
                    results.push(embedding);
                }
                results
            });

            tasks.push(task);
        }

        let mut all_results = Vec::new();
        for task in tasks {
            let chunk_results = task.await?;
            all_results.extend(chunk_results);
        }

        Ok(all_results)
    }

    pub fn synchronize_all(&self) {
        // Placeholder: a real implementation would synchronize each device stream here.
        debug!("Synchronized {} GPU streams", self.stream_ids.len());
    }
}

/// Top-level coordinator that wires together the memory pool, tensor cache,
/// mixed-precision processing, and multi-stream dispatch.
pub struct GpuAccelerationManager {
    config: GpuAccelerationConfig,
    memory_pool: GpuMemoryPool,
    tensor_cache: TensorCache,
    mixed_precision: MixedPrecisionProcessor,
    multi_stream: MultiStreamProcessor,
}

impl GpuAccelerationManager {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_pool = GpuMemoryPool::new(config.clone());
        let tensor_cache = TensorCache::new(config.clone());
        let mixed_precision = MixedPrecisionProcessor::new(config.clone());
        let multi_stream = MultiStreamProcessor::new(config.clone());

        Self {
            config,
            memory_pool,
            tensor_cache,
            mixed_precision,
            multi_stream,
        }
    }

    pub fn memory_pool(&self) -> &GpuMemoryPool {
        &self.memory_pool
    }

    pub fn tensor_cache(&self) -> &TensorCache {
        &self.tensor_cache
    }

    pub fn mixed_precision(&mut self) -> &mut MixedPrecisionProcessor {
        &mut self.mixed_precision
    }

    pub fn multi_stream(&mut self) -> &mut MultiStreamProcessor {
        &mut self.multi_stream
    }

    pub async fn accelerated_embedding_generation(
        &mut self,
        entities: Vec<String>,
        base_compute_fn: impl Fn(&str) -> Array1<f32> + Send + Sync + Copy + 'static,
    ) -> Result<Vec<Array1<f32>>> {
        if !self.config.enabled {
            // GPU acceleration disabled: fall back to sequential computation.
            return Ok(entities.iter().map(|e| base_compute_fn(e)).collect());
        }

        let results = self
            .multi_stream
            .process_batch_parallel(entities, move |entity, stream_id| {
                debug!("Processing entity {} on stream {}", entity, stream_id);
                base_compute_fn(&entity)
            })
            .await?;

        self.multi_stream.synchronize_all();
        Ok(results)
    }

    pub fn get_performance_stats(&self) -> GpuPerformanceStats {
        let memory_stats = self.memory_pool.get_stats();
        let cache_stats = self.tensor_cache.get_stats();

        GpuPerformanceStats {
            memory_allocations: memory_stats.total_allocations,
            memory_deallocations: memory_stats.total_deallocations,
            peak_memory_usage_mb: memory_stats.peak_memory_usage / (1024 * 1024),
            current_memory_usage_mb: memory_stats.current_memory_usage / (1024 * 1024),
            memory_pool_hits: memory_stats.cache_hits,
            memory_pool_misses: memory_stats.cache_misses,
            tensor_cache_hits: cache_stats.hits,
            tensor_cache_misses: cache_stats.misses,
            tensor_cache_evictions: cache_stats.evictions,
            tensor_cache_memory_mb: cache_stats.total_memory_usage / (1024 * 1024),
            loss_scaling_factor: self.mixed_precision.loss_scaling,
            num_active_streams: self.config.num_streams,
        }
    }
}

/// Aggregated performance counters across all GPU acceleration components.
#[derive(Debug, Serialize)]
pub struct GpuPerformanceStats {
    pub memory_allocations: usize,
    pub memory_deallocations: usize,
    pub peak_memory_usage_mb: usize,
    pub current_memory_usage_mb: usize,
    pub memory_pool_hits: usize,
    pub memory_pool_misses: usize,
    pub tensor_cache_hits: usize,
    pub tensor_cache_misses: usize,
    pub tensor_cache_evictions: usize,
    pub tensor_cache_memory_mb: usize,
    pub loss_scaling_factor: f32,
    pub num_active_streams: usize,
}

/// Periodically defragments the GPU memory pool when fragmentation gets high.
pub struct MemoryDefragmenter {
    config: GpuAccelerationConfig,
    defrag_threshold: f32,
    last_defrag: Instant,
    defrag_interval: Duration,
}

impl MemoryDefragmenter {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            // Defragment once the estimated fragmentation ratio exceeds 70%.
            defrag_threshold: 0.7,
            last_defrag: Instant::now(),
            // At most one defragmentation pass every five minutes.
            defrag_interval: Duration::from_secs(300),
        }
    }

    pub fn should_defragment(&self, memory_pool: &GpuMemoryPool) -> bool {
        let stats = memory_pool.get_stats();
        let fragmentation_ratio = self.calculate_fragmentation_ratio(&stats);

        fragmentation_ratio > self.defrag_threshold
            && self.last_defrag.elapsed() > self.defrag_interval
    }

    fn calculate_fragmentation_ratio(&self, stats: &AllocationStats) -> f32 {
        if stats.current_memory_usage == 0 {
            return 0.0;
        }

        // Fragmentation estimate: how far current usage has fallen below the peak
        // that was actually reserved at some point.
        let theoretical_optimal = stats.current_memory_usage;
        let actual_allocated = stats.peak_memory_usage;

        if actual_allocated == 0 {
            0.0
        } else {
            1.0 - (theoretical_optimal as f32 / actual_allocated as f32)
        }
    }

    pub fn defragment(&mut self, memory_pool: &GpuMemoryPool) -> Result<DefragmentationResult> {
        info!("Starting GPU memory defragmentation");
        let start_time = Instant::now();

        let stats_before = memory_pool.get_stats();

        // Coalesce the pool's free list; the short sleep stands in for the cost of
        // a real device-side defragmentation pass.
        memory_pool.defragment()?;
        std::thread::sleep(Duration::from_millis(100));

        let stats_after = memory_pool.get_stats();
        self.last_defrag = Instant::now();

        let result = DefragmentationResult {
            duration: start_time.elapsed(),
            memory_freed: stats_before
                .peak_memory_usage
                .saturating_sub(stats_after.current_memory_usage),
            fragmentation_before: self.calculate_fragmentation_ratio(&stats_before),
            fragmentation_after: self.calculate_fragmentation_ratio(&stats_after),
        };

        info!("Defragmentation completed: {:?}", result);
        Ok(result)
    }
}

/// Result of a memory defragmentation pass.
#[derive(Debug, Clone)]
pub struct DefragmentationResult {
    pub duration: Duration,
    pub memory_freed: usize,
    pub fragmentation_before: f32,
    pub fragmentation_after: f32,
}

/// Processes batches that are too large to fit in GPU memory by chunking them.
pub struct OutOfCoreProcessor {
    config: GpuAccelerationConfig,
    chunk_size: usize,
    overlap_size: usize,
    memory_limit: usize,
}

impl OutOfCoreProcessor {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        let memory_limit = config.memory_pool_size_mb * 1024 * 1024;
        // Process in chunks of a quarter of the memory budget, with 10% overlap.
        let chunk_size = memory_limit / 4;
        let overlap_size = chunk_size / 10;

        Self {
            config,
            chunk_size,
            overlap_size,
            memory_limit,
        }
    }

    pub async fn process_large_batch<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        // Estimate how many items fit in one chunk; `max(1)` guards against
        // zero-sized types, and the upper bound keeps individual chunks small.
        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);

        info!(
            "Processing {} items in chunks of {}",
            data.len(),
            chunk_size
        );

        let mut results = Vec::new();
        let mut processed_count = 0;

        for chunk in data.chunks(chunk_size) {
            let chunk_results = process_fn(chunk)?;
            results.extend(chunk_results);

            processed_count += chunk.len();

            if processed_count % (chunk_size * 10) == 0 {
                info!("Processed {}/{} items", processed_count, data.len());
            }

            // Yield so other tasks can make progress between chunks.
            tokio::task::yield_now().await;
        }

        Ok(results)
    }

    pub async fn process_with_overlap<T>(
        &self,
        data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<Vec<Array1<f32>>>
    where
        T: Clone + Send + Sync + 'static,
    {
        if data.is_empty() {
            return Ok(Vec::new());
        }

        let item_size = std::mem::size_of::<T>().max(1);
        let max_items_per_chunk = self.chunk_size / item_size;
        let chunk_size = max_items_per_chunk.clamp(1, 1000);
        // Overlap expressed in items, bounded so every iteration still advances.
        let overlap_items = (self.overlap_size / item_size).min(chunk_size - 1);

        let mut results = Vec::new();
        let mut start_idx = 0;

        while start_idx < data.len() {
            let end_idx = (start_idx + chunk_size).min(data.len());
            let chunk = &data[start_idx..end_idx];

            let chunk_results = process_fn(chunk)?;

            // The first chunk is kept in full; later chunks skip the overlapping
            // prefix that the previous chunk already produced.
            let skip_count = if start_idx == 0 { 0 } else { overlap_items };
            results.extend(chunk_results.into_iter().skip(skip_count));

            start_idx += chunk_size - overlap_items;
            tokio::task::yield_now().await;
        }

        Ok(results)
    }
}

/// Caches shape metadata and pads tensor shapes toward GPU-friendly sizes.
pub struct DynamicShapeHandler {
    config: GpuAccelerationConfig,
    shape_cache: HashMap<Vec<usize>, ShapeInfo>,
    max_cached_shapes: usize,
}

#[derive(Debug, Clone)]
struct ShapeInfo {
    shape: Vec<usize>,
    memory_requirement: usize,
    optimal_batch_size: usize,
    last_used: Instant,
}

impl DynamicShapeHandler {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            shape_cache: HashMap::new(),
            max_cached_shapes: 100,
        }
    }

    pub fn optimize_shape(&mut self, shape: Vec<usize>) -> Vec<usize> {
        // Return the cached optimized shape if this shape has been seen before.
        if let Some(shape_info) = self.shape_cache.get_mut(&shape) {
            shape_info.last_used = Instant::now();
            return shape_info.shape.clone();
        }

        let optimized_shape = self.calculate_optimal_shape(&shape);

        self.cache_shape_info(shape.clone(), optimized_shape.clone());

        optimized_shape
    }

    fn calculate_optimal_shape(&self, shape: &[usize]) -> Vec<usize> {
        let mut optimized = shape.to_vec();

        // Pad each dimension up to a multiple of the GPU warp size.
        const WARP_SIZE: usize = 32;

        for dim in &mut optimized {
            if *dim > 0 {
                *dim = ((*dim + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
            }
        }

        optimized
    }

    fn cache_shape_info(&mut self, original_shape: Vec<usize>, optimized_shape: Vec<usize>) {
        if self.shape_cache.len() >= self.max_cached_shapes {
            self.evict_oldest_shape();
        }

        // 4 bytes per f32 element.
        let memory_requirement = optimized_shape.iter().product::<usize>() * 4;
        let optimal_batch_size = self.calculate_optimal_batch_size(memory_requirement);

        let shape_info = ShapeInfo {
            shape: optimized_shape,
            memory_requirement,
            optimal_batch_size,
            last_used: Instant::now(),
        };

        self.shape_cache.insert(original_shape, shape_info);
    }

    fn calculate_optimal_batch_size(&self, memory_per_item: usize) -> usize {
        if memory_per_item == 0 {
            return 1;
        }

        // Budget at most half of the memory pool for a single batch.
        let available_memory = (self.config.memory_pool_size_mb * 1024 * 1024) / 2;
        let max_batch_size = available_memory / memory_per_item;

        max_batch_size.clamp(1, 1024)
    }

    fn evict_oldest_shape(&mut self) {
        if let Some(oldest_key) = self
            .shape_cache
            .iter()
            .min_by_key(|(_, info)| info.last_used)
            .map(|(key, _)| key.clone())
        {
            self.shape_cache.remove(&oldest_key);
        }
    }

    pub fn get_optimal_batch_size(&self, shape: &[usize]) -> usize {
        self.shape_cache
            .get(shape)
            .map(|info| info.optimal_batch_size)
            .unwrap_or(1)
    }
}

/// Empirically searches for the batch size with the best observed throughput.
pub struct BatchSizeOptimizer {
    config: GpuAccelerationConfig,
    performance_history: VecDeque<BatchPerformance>,
    max_history_size: usize,
    current_optimal_batch_size: usize,
}

#[derive(Debug, Clone)]
struct BatchPerformance {
    batch_size: usize,
    processing_time: Duration,
    memory_usage: usize,
    throughput: f64,
    gpu_utilization: f64,
    timestamp: Instant,
}

impl BatchSizeOptimizer {
    pub fn new(config: GpuAccelerationConfig) -> Self {
        Self {
            config,
            performance_history: VecDeque::new(),
            max_history_size: 50,
            // Reasonable default until an empirical search has been run.
            current_optimal_batch_size: 32,
        }
    }

    pub async fn find_optimal_batch_size<T>(
        &mut self,
        sample_data: Vec<T>,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>> + Send + Sync + Copy,
    ) -> Result<usize>
    where
        T: Clone + Send + Sync + 'static,
    {
        if sample_data.is_empty() {
            return Ok(1);
        }

        info!("Optimizing batch size for embedding generation");

        let test_sizes = vec![1, 8, 16, 32, 64, 128, 256, 512];
        let max_test_size = sample_data.len().min(512);

        let mut best_batch_size = 1;
        let mut best_throughput = 0.0;

        for &batch_size in &test_sizes {
            if batch_size > max_test_size {
                break;
            }

            let performance = self
                .test_batch_size(
                    &sample_data[..batch_size.min(sample_data.len())],
                    batch_size,
                    process_fn,
                )
                .await?;

            info!(
                "Batch size {}: {:.2} items/sec, {:.1}ms processing time",
                batch_size,
                performance.throughput,
                performance.processing_time.as_millis()
            );

            if performance.throughput > best_throughput {
                best_throughput = performance.throughput;
                best_batch_size = batch_size;
            }

            self.performance_history.push_back(performance);
            if self.performance_history.len() > self.max_history_size {
                self.performance_history.pop_front();
            }

            // Brief pause between measurements.
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        self.current_optimal_batch_size = best_batch_size;
        info!("Optimal batch size determined: {}", best_batch_size);

        Ok(best_batch_size)
    }

    async fn test_batch_size<T>(
        &self,
        sample_data: &[T],
        batch_size: usize,
        process_fn: impl Fn(&[T]) -> Result<Vec<Array1<f32>>>,
    ) -> Result<BatchPerformance>
    where
        T: Clone,
    {
        let start_time = Instant::now();
        let memory_before = self.estimate_memory_usage();

        let _results = process_fn(sample_data)?;

        let processing_time = start_time.elapsed();
        let memory_after = self.estimate_memory_usage();
        let memory_usage = memory_after.saturating_sub(memory_before);

        let throughput = if processing_time.as_secs_f64() > 0.0 {
            sample_data.len() as f64 / processing_time.as_secs_f64()
        } else {
            0.0
        };

        let gpu_utilization = self.estimate_gpu_utilization(batch_size, processing_time);

        Ok(BatchPerformance {
            batch_size,
            processing_time,
            memory_usage,
            throughput,
            gpu_utilization,
            timestamp: Instant::now(),
        })
    }

    fn estimate_memory_usage(&self) -> usize {
        // Placeholder estimate: assume a quarter of the configured pool is in use.
        (self.config.memory_pool_size_mb * 1024 * 1024) / 4
    }

    fn estimate_gpu_utilization(&self, batch_size: usize, processing_time: Duration) -> f64 {
        // Heuristic: utilization grows with the log of the batch size, discounted
        // for workloads that finish too quickly or run too long.
        let base_utilization = (batch_size as f64).log2() / 10.0;
        let time_factor = if processing_time.as_millis() < 10 {
            0.5
        } else if processing_time.as_millis() > 1000 {
            0.7
        } else {
            1.0
        };

        (base_utilization * time_factor).clamp(0.0, 1.0)
    }

    pub fn get_optimal_batch_size(&self) -> usize {
        self.current_optimal_batch_size
    }

    pub fn get_performance_stats(&self) -> BatchSizeOptimizerStats {
        let avg_throughput = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.throughput)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        let avg_gpu_utilization = if !self.performance_history.is_empty() {
            self.performance_history
                .iter()
                .map(|p| p.gpu_utilization)
                .sum::<f64>()
                / self.performance_history.len() as f64
        } else {
            0.0
        };

        BatchSizeOptimizerStats {
            current_optimal_batch_size: self.current_optimal_batch_size,
            avg_throughput,
            avg_gpu_utilization,
            total_tests_performed: self.performance_history.len(),
        }
    }
}

/// Summary statistics produced by [`BatchSizeOptimizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeOptimizerStats {
    pub current_optimal_batch_size: usize,
    pub avg_throughput: f64,
    pub avg_gpu_utilization: f64,
    pub total_tests_performed: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_acceleration_config_default() {
        let config = GpuAccelerationConfig::default();
        assert!(config.enabled);
        assert_eq!(config.device_ids, vec![0]);
        assert_eq!(config.memory_pool_size_mb, 2048);
        assert!(config.mixed_precision);
        assert!(config.tensor_caching);
    }
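
    // Additional sketch (not in the original suite), relying only on the APIs
    // defined above: the manager should wire its components together and report
    // default counters through `get_performance_stats`.
    #[test]
    fn test_acceleration_manager_default_stats() {
        let manager = GpuAccelerationManager::new(GpuAccelerationConfig::default());
        let stats = manager.get_performance_stats();

        assert_eq!(stats.memory_allocations, 0);
        assert_eq!(stats.tensor_cache_hits, 0);
        assert_eq!(stats.num_active_streams, 4);
        assert_eq!(stats.loss_scaling_factor, 65536.0);
    }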

    #[test]
    fn test_memory_pool_allocation() {
        let config = GpuAccelerationConfig::default();
        let pool = GpuMemoryPool::new(config);

        let block_id = pool.allocate(1024, 0).unwrap();
        assert!(block_id > 0);

        pool.deallocate(block_id).unwrap();

        // Allocating again should reuse the freed block and return the same handle.
        let block_id2 = pool.allocate(1024, 0).unwrap();
        assert_eq!(block_id, block_id2);
    }
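
    // Additional sketch (not in the original suite): allocation statistics should
    // track usage, and `defragment` should coalesce free blocks on one device.
    #[test]
    fn test_memory_pool_stats_and_defragment() {
        let pool = GpuMemoryPool::new(GpuAccelerationConfig::default());

        let a = pool.allocate(1024, 0).unwrap();
        let b = pool.allocate(2048, 0).unwrap();
        pool.deallocate(a).unwrap();
        pool.deallocate(b).unwrap();

        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.total_deallocations, 2);
        assert_eq!(stats.current_memory_usage, 0);
        assert_eq!(stats.peak_memory_usage, 3072);

        pool.defragment().unwrap();
        assert_eq!(pool.free_blocks.lock().unwrap().len(), 1);
    }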

    #[test]
    fn test_tensor_cache() {
        let config = GpuAccelerationConfig::default();
        let cache = TensorCache::new(config);

        let tensor = Array2::zeros((10, 20));
        cache.cache_entity_tensor("test_entity", tensor.clone(), 0);

        let cached = cache.get_entity_tensor("test_entity").unwrap();
        assert_eq!(cached.shape(), tensor.shape());
    }
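
    // Additional sketch (not in the original suite): cache lookups should be
    // reflected in the hit/miss counters and the memory accounting.
    #[test]
    fn test_tensor_cache_stats() {
        let cache = TensorCache::new(GpuAccelerationConfig::default());

        assert!(cache.get_entity_tensor("missing").is_none());

        cache.cache_entity_tensor("entity", Array2::zeros((4, 4)), 0);
        assert!(cache.get_entity_tensor("entity").is_some());

        let stats = cache.get_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.total_memory_usage, 4 * 4 * std::mem::size_of::<f32>());
    }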

    #[test]
    fn test_mixed_precision() {
        let config = GpuAccelerationConfig::default();
        let processor = MixedPrecisionProcessor::new(config);

        let tensor = Array2::from_elem((2, 2), 1.0001);
        let fp16_tensor = processor.to_fp16(&tensor);

        if processor.fp16_enabled {
            // Quantization should perturb the value slightly.
            assert!(fp16_tensor[[0, 0]] != tensor[[0, 0]]);
        } else {
            assert_eq!(fp16_tensor[[0, 0]], tensor[[0, 0]]);
        }
    }
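
    // Additional sketch (not in the original suite): loss scaling should be applied
    // and removed symmetrically, and overflowing gradients should be rejected.
    #[test]
    fn test_loss_scaling_round_trip() {
        let mut processor = MixedPrecisionProcessor::new(GpuAccelerationConfig::default());

        let scaled = processor.scale_loss(2.0);
        assert_eq!(scaled, 2.0 * 65536.0);

        let mut gradients = Array2::from_elem((2, 2), 65536.0_f32);
        assert!(processor.unscale_gradients(&mut gradients));
        assert_eq!(gradients[[0, 0]], 1.0);

        let mut bad_gradients = Array2::from_elem((2, 2), f32::INFINITY);
        assert!(!processor.unscale_gradients(&mut bad_gradients));

        processor.adjust_loss_scaling(true);
        assert_eq!(processor.loss_scaling, 32768.0);
    }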

    #[tokio::test]
    async fn test_multi_stream_processing() {
        let config = GpuAccelerationConfig::default();
        let mut processor = MultiStreamProcessor::new(config);

        let entities = vec!["entity1".to_string(), "entity2".to_string()];
        let process_fn = |entity: String, _stream_id: usize| -> Array1<f32> {
            Array1::from_vec(vec![entity.len() as f32])
        };

        let results = processor
            .process_batch_parallel(entities, process_fn)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }
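
    // Additional sketches (not in the original suite): shape padding should round
    // dimensions up to the warp size, and out-of-core processing should return one
    // embedding per input item.
    #[test]
    fn test_dynamic_shape_handler_padding() {
        let mut handler = DynamicShapeHandler::new(GpuAccelerationConfig::default());

        let optimized = handler.optimize_shape(vec![100, 60]);
        assert_eq!(optimized, vec![128, 64]);

        // The optimal batch size is cached under the original shape.
        assert!(handler.get_optimal_batch_size(&[100, 60]) >= 1);
    }

    #[tokio::test]
    async fn test_out_of_core_processing() {
        let processor = OutOfCoreProcessor::new(GpuAccelerationConfig::default());

        let data: Vec<u64> = (0..10).collect();
        let results = processor
            .process_large_batch(data, |chunk| {
                Ok(chunk
                    .iter()
                    .map(|&x| Array1::from_vec(vec![x as f32]))
                    .collect())
            })
            .await
            .unwrap();

        assert_eq!(results.len(), 10);
    }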
}