Skip to main content

dbx_core/storage/gpu/
hash_join.rs

//! GPU hash join operations.
2
3#[cfg(feature = "gpu")]
4use cudarc::driver::{LaunchConfig, PushKernelArg};
5
6#[cfg(feature = "gpu")]
7use super::data::GpuData;
8use super::manager::GpuManager;
9use crate::error::{DbxError, DbxResult};
10
11/// Hash Join operations impl block
12impl GpuManager {
13    /// Hash Join on GPU: Build + Probe phases.
14    /// Returns Vec<(probe_row_id, build_row_id)> for matched rows.
15    pub fn hash_join(
16        &self,
17        build_table: &str,
18        build_key_column: &str,
19        probe_table: &str,
20        probe_key_column: &str,
21    ) -> DbxResult<Vec<(i32, i32)>> {
22        #[cfg(not(feature = "gpu"))]
23        {
24            let _ = (build_table, build_key_column, probe_table, probe_key_column);
25            Err(DbxError::NotImplemented(
26                "GPU acceleration is not enabled".to_string(),
27            ))
28        }
29
30        #[cfg(feature = "gpu")]
31        {
32            // Get build side keys
33            let build_keys_data = self
34                .get_gpu_data(build_table, build_key_column)
35                .ok_or_else(|| {
36                    DbxError::Gpu(format!(
37                        "Column {}.{} not found in GPU cache",
38                        build_table, build_key_column
39                    ))
40                })?;
41            let (build_keys_slice, build_n) = match &*build_keys_data {
42                GpuData::Int32(slice) => (slice, slice.len()),
43                _ => {
44                    return Err(DbxError::NotImplemented(
45                        "Hash join keys must be Int32".to_string(),
46                    ));
47                }
48            };
49
50            // Get probe side keys
51            let probe_keys_data = self
52                .get_gpu_data(probe_table, probe_key_column)
53                .ok_or_else(|| {
54                    DbxError::Gpu(format!(
55                        "Column {}.{} not found in GPU cache",
56                        probe_table, probe_key_column
57                    ))
58                })?;
59            let (probe_keys_slice, probe_n) = match &*probe_keys_data {
60                GpuData::Int32(slice) => (slice, slice.len()),
61                _ => {
62                    return Err(DbxError::NotImplemented(
63                        "Hash join keys must be Int32".to_string(),
64                    ));
65                }
66            };
67
68            let stream = self.device.default_stream();
69
70            // Phase 1: Build hash table
71            let table_size = (build_n * 2).next_power_of_two();
72            let mut hash_table_keys = vec![-1i32; table_size];
73            let mut hash_table_row_ids = vec![-1i32; table_size];
74
75            let mut hash_table_keys_dev = stream
76                .clone_htod(&hash_table_keys)
77                .map_err(|e| DbxError::Gpu(format!("Failed to alloc hash table keys: {:?}", e)))?;
78            let mut hash_table_row_ids_dev =
79                stream.clone_htod(&hash_table_row_ids).map_err(|e| {
80                    DbxError::Gpu(format!("Failed to alloc hash table row IDs: {:?}", e))
81                })?;
82
83            // Create build row IDs (0, 1, 2, ...)
84            let build_row_ids: Vec<i32> = (0..build_n as i32).collect();
85            let build_row_ids_dev = stream
86                .clone_htod(&build_row_ids)
87                .map_err(|e| DbxError::Gpu(format!("Failed to alloc build row IDs: {:?}", e)))?;
88
89            let build_func = self
90                .module
91                .load_function("hash_join_build_i32")
92                .map_err(|_| DbxError::Gpu("Kernel hash_join_build_i32 not found".to_string()))?;
93
94            let build_cfg = LaunchConfig::for_num_elems(build_n as u32);
95            let build_n_i32 = build_n as i32;
96            let table_size_i32 = table_size as i32;
97
98            let mut builder = stream.launch_builder(&build_func);
99            builder.arg(build_keys_slice);
100            builder.arg(&build_row_ids_dev);
101            builder.arg(&mut hash_table_keys_dev);
102            builder.arg(&mut hash_table_row_ids_dev);
103            builder.arg(&build_n_i32);
104            builder.arg(&table_size_i32);
105            unsafe { builder.launch(build_cfg) }
106                .map_err(|e| DbxError::Gpu(format!("Build kernel launch failed: {:?}", e)))?;
107
108            stream
109                .synchronize()
110                .map_err(|e| DbxError::Gpu(format!("Build stream sync failed: {:?}", e)))?;
111
112            // Phase 2: Probe hash table
113            let max_output_size = probe_n * 2; // Conservative estimate
114            let mut output_probe_ids = vec![0i32; max_output_size];
115            let mut output_build_ids = vec![0i32; max_output_size];
116            let mut match_count = vec![0i32; 1];
117
118            let mut output_probe_ids_dev = stream
119                .clone_htod(&output_probe_ids)
120                .map_err(|e| DbxError::Gpu(format!("Failed to alloc output probe IDs: {:?}", e)))?;
121            let mut output_build_ids_dev = stream
122                .clone_htod(&output_build_ids)
123                .map_err(|e| DbxError::Gpu(format!("Failed to alloc output build IDs: {:?}", e)))?;
124            let mut match_count_dev = stream
125                .clone_htod(&match_count)
126                .map_err(|e| DbxError::Gpu(format!("Failed to alloc match count: {:?}", e)))?;
127
128            let probe_func = self
129                .module
130                .load_function("hash_join_probe_i32")
131                .map_err(|_| DbxError::Gpu("Kernel hash_join_probe_i32 not found".to_string()))?;
132
133            let probe_cfg = LaunchConfig::for_num_elems(probe_n as u32);
134            let probe_n_i32 = probe_n as i32;
135            let max_output_size_i32 = max_output_size as i32;
136
137            let mut builder = stream.launch_builder(&probe_func);
138            builder.arg(probe_keys_slice);
139            builder.arg(&hash_table_keys_dev);
140            builder.arg(&hash_table_row_ids_dev);
141            builder.arg(&mut output_probe_ids_dev);
142            builder.arg(&mut output_build_ids_dev);
143            builder.arg(&mut match_count_dev);
144            builder.arg(&probe_n_i32);
145            builder.arg(&table_size_i32);
146            builder.arg(&max_output_size_i32);
147            unsafe { builder.launch(probe_cfg) }
148                .map_err(|e| DbxError::Gpu(format!("Probe kernel launch failed: {:?}", e)))?;
149
150            stream
151                .synchronize()
152                .map_err(|e| DbxError::Gpu(format!("Probe stream sync failed: {:?}", e)))?;
153
154            // Copy results back
155            match_count = stream
156                .clone_dtoh(&match_count_dev)
157                .map_err(|e| DbxError::Gpu(format!("Failed to copy match count: {:?}", e)))?;
158            let actual_matches = match_count[0] as usize;
159
160            output_probe_ids = stream
161                .clone_dtoh(&output_probe_ids_dev)
162                .map_err(|e| DbxError::Gpu(format!("Failed to copy output probe IDs: {:?}", e)))?;
163            output_build_ids = stream
164                .clone_dtoh(&output_build_ids_dev)
165                .map_err(|e| DbxError::Gpu(format!("Failed to copy output build IDs: {:?}", e)))?;
166
167            // Extract matched pairs
168            let mut results = Vec::new();
169            for i in 0..actual_matches.min(max_output_size) {
170                results.push((output_probe_ids[i], output_build_ids[i]));
171            }
172
173            Ok(results)
174        }
175    }
176}