// dbx_core/storage/gpu/manager.rs

//! GPU Manager core - initialization, data upload, and cache management.
use std::sync::Arc;

#[cfg(feature = "gpu")]
use arrow::array::{Array, Float64Array, Int32Array, Int64Array};
use arrow::record_batch::RecordBatch;
#[cfg(feature = "gpu")]
use cudarc::driver::{CudaContext, CudaModule, PushKernelArg};
#[cfg(feature = "gpu")]
use cudarc::nvrtc::compile_ptx;
#[cfg(feature = "gpu")]
use dashmap::DashMap;

use crate::error::{DbxError, DbxResult};

#[cfg(feature = "gpu")]
use super::data::GpuData;
#[cfg(feature = "gpu")]
use super::memory_pool::GpuMemoryPool;
use super::strategy::{GpuHashStrategy, GpuReductionStrategy};
/// CUDA kernel source, embedded at build time and compiled to PTX when
/// the manager is initialized (see `GpuManager::try_new`).
#[cfg(feature = "gpu")]
const KERNELS_SRC: &str = include_str!("../kernels.cu");
24/// Manager for GPU-accelerated operations.
25pub struct GpuManager {
26    /// CUDA device context (pub(super) for impl blocks in other files)
27    #[cfg(feature = "gpu")]
28    pub(super) device: Arc<CudaContext>,
29
30    /// Compiled CUDA module (pub(super) for impl blocks in other files)
31    #[cfg(feature = "gpu")]
32    pub(super) module: Arc<CudaModule>,
33
34    /// Buffer cache: table_name -> column_name -> GpuData
35    /// This avoids re-uploading data that hasn't changed.
36    #[cfg(feature = "gpu")]
37    pub(super) buffer_cache: DashMap<String, DashMap<String, Arc<GpuData>>>,
38
39    /// Hash strategy for GROUP BY operations (runtime configurable)
40    pub(super) hash_strategy: GpuHashStrategy,
41
42    /// Reduction strategy for SUM operations (runtime configurable)
43    pub(super) reduction_strategy: GpuReductionStrategy,
44
45    /// Memory pool for efficient GPU memory allocation
46    #[cfg(feature = "gpu")]
47    pub(super) memory_pool: Arc<GpuMemoryPool>,
48}
50impl GpuManager {
51    /// Create a new GpuManager. Returns None if GPU acceleration is disabled
52    /// or if no compatible device is found.
53    pub fn try_new() -> Option<Self> {
54        #[cfg(feature = "gpu")]
55        {
56            let device = match CudaContext::new(0) {
57                Ok(d) => d,
58                Err(e) => {
59                    eprintln!(
60                        "⚠️  GPU Manager: Failed to initialize CUDA device 0: {:?}",
61                        e
62                    );
63                    return None;
64                }
65            };
66
67            // Compile and Load kernels
68            let ptx = match compile_ptx(KERNELS_SRC) {
69                Ok(p) => p,
70                Err(e) => {
71                    eprintln!("⚠️  GPU Manager: Failed to compile PTX kernels: {:?}", e);
72                    return None;
73                }
74            };
75
76            let module = match device.load_module(ptx) {
77                Ok(m) => m,
78                Err(e) => {
79                    eprintln!("⚠️  GPU Manager: Failed to load CUDA module: {:?}", e);
80                    return None;
81                }
82            };
83
84            let memory_pool = Arc::new(GpuMemoryPool::new(
85                device.clone(),
86                256, // 256MB default cache
87            ));
88
89            eprintln!("✅ GPU Manager initialized successfully");
90            Some(Self {
91                device,
92                module,
93                buffer_cache: DashMap::new(),
94                hash_strategy: GpuHashStrategy::default(), // Linear by default
95                reduction_strategy: GpuReductionStrategy::default(), // Auto by default
96                memory_pool,
97            })
98        }
99        #[cfg(not(feature = "gpu"))]
100        {
101            #[allow(unreachable_code)]
102            {
103                None
104            }
105        }
106    }
107
108    /// Set GPU hash strategy for GROUP BY operations
109    pub fn set_hash_strategy(&mut self, strategy: GpuHashStrategy) {
110        self.hash_strategy = strategy;
111    }
112
113    /// Get current GPU hash strategy
114    pub fn hash_strategy(&self) -> GpuHashStrategy {
115        self.hash_strategy
116    }
117
118    /// Set GPU reduction strategy for SUM operations
119    pub fn set_reduction_strategy(&mut self, strategy: GpuReductionStrategy) {
120        self.reduction_strategy = strategy;
121    }
122
123    /// Get current GPU reduction strategy
124    pub fn reduction_strategy(&self) -> GpuReductionStrategy {
125        self.reduction_strategy
126    }
127
128    /// Upload a RecordBatch to GPU memory and cache it.
129    pub fn upload_batch(&self, table: &str, batch: &RecordBatch) -> DbxResult<()> {
130        #[cfg(not(feature = "gpu"))]
131        {
132            let _ = (table, batch);
133            Err(DbxError::NotImplemented(
134                "GPU acceleration is not enabled".to_string(),
135            ))
136        }
137
138        #[cfg(feature = "gpu")]
139        {
140            tracing::debug!(target: "gpu", table = %table, rows = batch.num_rows(), "GPU upload_batch start");
141            let start = std::time::Instant::now();
142
143            let table_cache = self
144                .buffer_cache
145                .entry(table.to_string())
146                .or_insert_with(DashMap::new);
147            let schema = batch.schema();
148
149            for (i, column) in batch.columns().iter().enumerate() {
150                let column_name = schema.field(i).name();
151                if table_cache.contains_key(column_name) {
152                    continue;
153                }
154
155                let gpu_data = self.convert_and_upload(column)?;
156                table_cache.insert(column_name.clone(), Arc::new(gpu_data));
157            }
158
159            tracing::debug!(target: "gpu", table = %table, elapsed_us = start.elapsed().as_micros(), "GPU upload_batch complete");
160            Ok(())
161        }
162    }
163
164    #[cfg(feature = "gpu")]
165    fn convert_and_upload(&self, array: &Arc<dyn Array>) -> DbxResult<GpuData> {
166        match array.data_type() {
167            arrow::datatypes::DataType::Int32 => {
168                let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
169                let stream = self.device.default_stream();
170                // Zero-copy: Access the underlying slice directly
171                let slice = stream
172                    .clone_htod(&arr.values()[..])
173                    .map_err(|e| DbxError::Gpu(format!("CUDA HTOD copy (i32) failed: {:?}", e)))?;
174                Ok(GpuData::Int32(slice))
175            }
176            arrow::datatypes::DataType::Int64 => {
177                let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
178                let stream = self.device.default_stream();
179                let slice = stream
180                    .clone_htod(&arr.values()[..])
181                    .map_err(|e| DbxError::Gpu(format!("CUDA HTOD copy (i64) failed: {:?}", e)))?;
182                Ok(GpuData::Int64(slice))
183            }
184            arrow::datatypes::DataType::Float64 => {
185                let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
186                let stream = self.device.default_stream();
187                let slice = stream
188                    .clone_htod(&arr.values()[..])
189                    .map_err(|e| DbxError::Gpu(format!("CUDA HTOD copy (f64) failed: {:?}", e)))?;
190                Ok(GpuData::Float64(slice))
191            }
192            _ => Err(DbxError::NotImplemented(format!(
193                "GPU upload for type {:?} not supported yet",
194                array.data_type()
195            ))),
196        }
197    }
198
199    /// Upload a RecordBatch to GPU memory using Pinned Memory for faster DMA transfer.
200    pub fn upload_batch_pinned(&self, table: &str, batch: &RecordBatch) -> DbxResult<()> {
201        #[cfg(not(feature = "gpu"))]
202        {
203            let _ = (table, batch);
204            Err(DbxError::NotImplemented(
205                "GPU acceleration is not enabled".to_string(),
206            ))
207        }
208
209        #[cfg(feature = "gpu")]
210        {
211            let table_cache = self
212                .buffer_cache
213                .entry(table.to_string())
214                .or_insert_with(DashMap::new);
215            let schema = batch.schema();
216
217            for (i, column) in batch.columns().iter().enumerate() {
218                let column_name = schema.field(i).name();
219                if table_cache.contains_key(column_name) {
220                    continue;
221                }
222
223                // For Int32, use pinned memory
224                if column.data_type() == &arrow::datatypes::DataType::Int32 {
225                    let arr = column.as_any().downcast_ref::<Int32Array>().unwrap();
226                    let values = &arr.values()[..];
227
228                    let mut pinned = unsafe { self.device.alloc_pinned::<i32>(values.len()) }
229                        .map_err(|e| {
230                            DbxError::Gpu(format!("Failed to alloc pinned memory: {:?}", e))
231                        })?;
232                    // Use unsafe pointer copy as a fallback
233                    unsafe {
234                        let ptr = pinned.as_mut_ptr().map_err(|e| {
235                            DbxError::Gpu(format!("Failed to get pinned memory pointer: {:?}", e))
236                        })?;
237                        std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len());
238                    }
239
240                    let stream = self.device.default_stream();
241                    let slice = stream.clone_htod(&pinned).map_err(|e| {
242                        DbxError::Gpu(format!("CUDA pinned HTOD copy failed: {:?}", e))
243                    })?;
244
245                    table_cache.insert(column_name.clone(), Arc::new(GpuData::Int32(slice)));
246                } else {
247                    let gpu_data = self.convert_and_upload(column)?;
248                    table_cache.insert(column_name.clone(), Arc::new(gpu_data));
249                }
250            }
251            Ok(())
252        }
253    }
254
255    /// Retrieve cached GPU data for a specific column.
256    #[cfg(feature = "gpu")]
257    pub(super) fn get_gpu_data(&self, table: &str, column: &str) -> Option<Arc<GpuData>> {
258        self.buffer_cache
259            .get(table)?
260            .get(column)
261            .map(|v| Arc::clone(&v))
262    }
263
264    pub fn clear_table_cache(&self, table: &str) {
265        #[cfg(feature = "gpu")]
266        {
267            self.buffer_cache.remove(table);
268        }
269        #[cfg(not(feature = "gpu"))]
270        {
271            let _ = table;
272        }
273    }
274
275    pub fn clear_all_cache(&self) {
276        #[cfg(feature = "gpu")]
277        {
278            self.buffer_cache.clear();
279        }
280    }
281}