oxirs_vec/gpu/
config.rs

1//! GPU configuration structures and enums
2
3use serde::{Deserialize, Serialize};
4
/// Configuration for GPU operations
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct GpuConfig {
    /// Primary GPU device ordinal to run on (0-based).
    pub device_id: i32,
    /// Enable mixed-precision execution (see also `precision_mode`).
    pub enable_mixed_precision: bool,
    /// Use tensor cores when the hardware supports them.
    pub enable_tensor_cores: bool,
    /// Number of vectors processed per batch (must be > 0; see `validate`).
    pub batch_size: usize,
    /// Size of the GPU memory pool in bytes (default 1 GiB; see `Default`).
    pub memory_pool_size: usize,
    /// Number of execution streams to create (must be > 0).
    pub stream_count: usize,
    /// Enable peer-to-peer access between GPUs.
    pub enable_peer_access: bool,
    /// Use unified (host/device shared) memory.
    pub enable_unified_memory: bool,
    /// Allow asynchronous kernel execution.
    pub enable_async_execution: bool,
    /// Distribute work across multiple GPUs.
    pub enable_multi_gpu: bool,
    /// Device ordinals to prefer, in order (must be non-empty; see `validate`).
    pub preferred_gpu_ids: Vec<i32>,
    /// Adjust batch size at runtime (see `calculate_optimal_batch_size`).
    pub dynamic_batch_sizing: bool,
    /// Compress data held in GPU memory.
    pub enable_memory_compression: bool,
    /// Maximum number of compiled kernels kept in the cache (must be > 0).
    pub kernel_cache_size: usize,
    /// Compiler/runtime optimization aggressiveness.
    pub optimization_level: OptimizationLevel,
    /// Numeric precision used for computations.
    pub precision_mode: PrecisionMode,
}
25
/// GPU optimization levels
#[derive(
    Debug, Clone, Copy, PartialEq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
)]
pub enum OptimizationLevel {
    /// Maximum debugging, minimal optimization.
    Debug,
    /// Good balance of performance and debugging.
    Balanced,
    /// Maximum performance, minimal debugging.
    Performance,
    /// Aggressive optimizations, may reduce precision.
    Extreme,
}
36
/// Precision modes for GPU computations
#[derive(
    Debug, Clone, Copy, PartialEq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
)]
pub enum PrecisionMode {
    /// Single precision.
    FP32,
    /// Half precision.
    FP16,
    /// Mixed precision (FP16 for compute, FP32 for storage).
    Mixed,
    /// 8-bit integer quantization.
    INT8,
    /// Adaptive precision based on data characteristics.
    Adaptive,
}
48
49impl Default for GpuConfig {
50    fn default() -> Self {
51        Self {
52            device_id: 0,
53            enable_mixed_precision: true,
54            enable_tensor_cores: true,
55            batch_size: 1024,
56            memory_pool_size: 1024 * 1024 * 1024, // 1GB
57            stream_count: 4,
58            enable_peer_access: false,
59            enable_unified_memory: false,
60            enable_async_execution: true,
61            enable_multi_gpu: false,
62            preferred_gpu_ids: vec![0],
63            dynamic_batch_sizing: true,
64            enable_memory_compression: false,
65            kernel_cache_size: 100, // Cache up to 100 compiled kernels
66            optimization_level: OptimizationLevel::Balanced,
67            precision_mode: PrecisionMode::FP32,
68        }
69    }
70}
71
72impl GpuConfig {
73    /// Create a high-performance configuration
74    pub fn high_performance() -> Self {
75        Self {
76            optimization_level: OptimizationLevel::Performance,
77            enable_mixed_precision: true,
78            enable_tensor_cores: true,
79            enable_async_execution: true,
80            batch_size: 2048,
81            stream_count: 8,
82            ..Default::default()
83        }
84    }
85
86    /// Create a memory-optimized configuration
87    pub fn memory_optimized() -> Self {
88        Self {
89            enable_memory_compression: true,
90            enable_unified_memory: true,
91            batch_size: 512,
92            memory_pool_size: 512 * 1024 * 1024, // 512MB
93            ..Default::default()
94        }
95    }
96
97    /// Create a debug-friendly configuration
98    pub fn debug() -> Self {
99        Self {
100            optimization_level: OptimizationLevel::Debug,
101            enable_mixed_precision: false,
102            enable_async_execution: false,
103            batch_size: 64,
104            stream_count: 1,
105            ..Default::default()
106        }
107    }
108
109    /// Validate the configuration
110    pub fn validate(&self) -> anyhow::Result<()> {
111        if self.batch_size == 0 {
112            return Err(anyhow::anyhow!("Batch size must be greater than 0"));
113        }
114        if self.stream_count == 0 {
115            return Err(anyhow::anyhow!("Stream count must be greater than 0"));
116        }
117        if self.memory_pool_size == 0 {
118            return Err(anyhow::anyhow!("Memory pool size must be greater than 0"));
119        }
120        if self.kernel_cache_size == 0 {
121            return Err(anyhow::anyhow!("Kernel cache size must be greater than 0"));
122        }
123        if self.preferred_gpu_ids.is_empty() {
124            return Err(anyhow::anyhow!(
125                "Must specify at least one preferred GPU ID"
126            ));
127        }
128        Ok(())
129    }
130
131    /// Calculate optimal batch size based on available memory
132    pub fn calculate_optimal_batch_size(
133        &self,
134        vector_dim: usize,
135        available_memory: usize,
136    ) -> usize {
137        let bytes_per_vector = vector_dim * std::mem::size_of::<f32>();
138        let max_vectors = available_memory / bytes_per_vector / 4; // Reserve 75% for safety
139        max_vectors
140            .min(self.batch_size * 4)
141            .max(self.batch_size / 4)
142    }
143}