dbx_core/storage/gpu/
hash_join.rs1#[cfg(feature = "gpu")]
4use cudarc::driver::{LaunchConfig, PushKernelArg};
5
6#[cfg(feature = "gpu")]
7use super::data::GpuData;
8use super::manager::GpuManager;
9use crate::error::{DbxError, DbxResult};
10
11impl GpuManager {
13 pub fn hash_join(
16 &self,
17 build_table: &str,
18 build_key_column: &str,
19 probe_table: &str,
20 probe_key_column: &str,
21 ) -> DbxResult<Vec<(i32, i32)>> {
22 #[cfg(not(feature = "gpu"))]
23 {
24 let _ = (build_table, build_key_column, probe_table, probe_key_column);
25 Err(DbxError::NotImplemented(
26 "GPU acceleration is not enabled".to_string(),
27 ))
28 }
29
30 #[cfg(feature = "gpu")]
31 {
32 let build_keys_data = self
34 .get_gpu_data(build_table, build_key_column)
35 .ok_or_else(|| {
36 DbxError::Gpu(format!(
37 "Column {}.{} not found in GPU cache",
38 build_table, build_key_column
39 ))
40 })?;
41 let (build_keys_slice, build_n) = match &*build_keys_data {
42 GpuData::Int32(slice) => (slice, slice.len()),
43 _ => {
44 return Err(DbxError::NotImplemented(
45 "Hash join keys must be Int32".to_string(),
46 ));
47 }
48 };
49
50 let probe_keys_data = self
52 .get_gpu_data(probe_table, probe_key_column)
53 .ok_or_else(|| {
54 DbxError::Gpu(format!(
55 "Column {}.{} not found in GPU cache",
56 probe_table, probe_key_column
57 ))
58 })?;
59 let (probe_keys_slice, probe_n) = match &*probe_keys_data {
60 GpuData::Int32(slice) => (slice, slice.len()),
61 _ => {
62 return Err(DbxError::NotImplemented(
63 "Hash join keys must be Int32".to_string(),
64 ));
65 }
66 };
67
68 let stream = self.device.default_stream();
69
70 let table_size = (build_n * 2).next_power_of_two();
72 let mut hash_table_keys = vec![-1i32; table_size];
73 let mut hash_table_row_ids = vec![-1i32; table_size];
74
75 let mut hash_table_keys_dev = stream
76 .clone_htod(&hash_table_keys)
77 .map_err(|e| DbxError::Gpu(format!("Failed to alloc hash table keys: {:?}", e)))?;
78 let mut hash_table_row_ids_dev =
79 stream.clone_htod(&hash_table_row_ids).map_err(|e| {
80 DbxError::Gpu(format!("Failed to alloc hash table row IDs: {:?}", e))
81 })?;
82
83 let build_row_ids: Vec<i32> = (0..build_n as i32).collect();
85 let build_row_ids_dev = stream
86 .clone_htod(&build_row_ids)
87 .map_err(|e| DbxError::Gpu(format!("Failed to alloc build row IDs: {:?}", e)))?;
88
89 let build_func = self
90 .module
91 .load_function("hash_join_build_i32")
92 .map_err(|_| DbxError::Gpu("Kernel hash_join_build_i32 not found".to_string()))?;
93
94 let build_cfg = LaunchConfig::for_num_elems(build_n as u32);
95 let build_n_i32 = build_n as i32;
96 let table_size_i32 = table_size as i32;
97
98 let mut builder = stream.launch_builder(&build_func);
99 builder.arg(build_keys_slice);
100 builder.arg(&build_row_ids_dev);
101 builder.arg(&mut hash_table_keys_dev);
102 builder.arg(&mut hash_table_row_ids_dev);
103 builder.arg(&build_n_i32);
104 builder.arg(&table_size_i32);
105 unsafe { builder.launch(build_cfg) }
106 .map_err(|e| DbxError::Gpu(format!("Build kernel launch failed: {:?}", e)))?;
107
108 stream
109 .synchronize()
110 .map_err(|e| DbxError::Gpu(format!("Build stream sync failed: {:?}", e)))?;
111
112 let max_output_size = probe_n * 2; let mut output_probe_ids = vec![0i32; max_output_size];
115 let mut output_build_ids = vec![0i32; max_output_size];
116 let mut match_count = vec![0i32; 1];
117
118 let mut output_probe_ids_dev = stream
119 .clone_htod(&output_probe_ids)
120 .map_err(|e| DbxError::Gpu(format!("Failed to alloc output probe IDs: {:?}", e)))?;
121 let mut output_build_ids_dev = stream
122 .clone_htod(&output_build_ids)
123 .map_err(|e| DbxError::Gpu(format!("Failed to alloc output build IDs: {:?}", e)))?;
124 let mut match_count_dev = stream
125 .clone_htod(&match_count)
126 .map_err(|e| DbxError::Gpu(format!("Failed to alloc match count: {:?}", e)))?;
127
128 let probe_func = self
129 .module
130 .load_function("hash_join_probe_i32")
131 .map_err(|_| DbxError::Gpu("Kernel hash_join_probe_i32 not found".to_string()))?;
132
133 let probe_cfg = LaunchConfig::for_num_elems(probe_n as u32);
134 let probe_n_i32 = probe_n as i32;
135 let max_output_size_i32 = max_output_size as i32;
136
137 let mut builder = stream.launch_builder(&probe_func);
138 builder.arg(probe_keys_slice);
139 builder.arg(&hash_table_keys_dev);
140 builder.arg(&hash_table_row_ids_dev);
141 builder.arg(&mut output_probe_ids_dev);
142 builder.arg(&mut output_build_ids_dev);
143 builder.arg(&mut match_count_dev);
144 builder.arg(&probe_n_i32);
145 builder.arg(&table_size_i32);
146 builder.arg(&max_output_size_i32);
147 unsafe { builder.launch(probe_cfg) }
148 .map_err(|e| DbxError::Gpu(format!("Probe kernel launch failed: {:?}", e)))?;
149
150 stream
151 .synchronize()
152 .map_err(|e| DbxError::Gpu(format!("Probe stream sync failed: {:?}", e)))?;
153
154 match_count = stream
156 .clone_dtoh(&match_count_dev)
157 .map_err(|e| DbxError::Gpu(format!("Failed to copy match count: {:?}", e)))?;
158 let actual_matches = match_count[0] as usize;
159
160 output_probe_ids = stream
161 .clone_dtoh(&output_probe_ids_dev)
162 .map_err(|e| DbxError::Gpu(format!("Failed to copy output probe IDs: {:?}", e)))?;
163 output_build_ids = stream
164 .clone_dtoh(&output_build_ids_dev)
165 .map_err(|e| DbxError::Gpu(format!("Failed to copy output build IDs: {:?}", e)))?;
166
167 let mut results = Vec::new();
169 for i in 0..actual_matches.min(max_output_size) {
170 results.push((output_probe_ids[i], output_build_ids[i]));
171 }
172
173 Ok(results)
174 }
175 }
176}