1use std::collections::HashMap;
5
6use reifydb_core::value::column::{columns::Columns, headers::ColumnHeaders};
7use reifydb_rql::expression::Expression;
8use reifydb_runtime::hash::Hash128;
9use reifydb_transaction::transaction::Transaction;
10use reifydb_type::{
11 fragment::Fragment,
12 value::{Value, row_number::RowNumber},
13};
14use tracing::instrument;
15
16use super::common::{
17 JoinContext, compute_join_hash, eval_join_condition, keys_equal_by_index, load_and_merge_all,
18 resolve_column_names,
19};
20use crate::{
21 Result,
22 expression::{
23 compile::{CompiledExpr, compile_expression},
24 context::CompileContext,
25 },
26 vm::volcano::query::{QueryContext, QueryNode},
27};
28
/// One equality predicate usable as a hash-join key: a column of the left (probe)
/// input compared for equality with a column accessed on the right (build) source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct EquiKeyPair {
	/// Name of the key column looked up in the left (probe) batches.
	pub left_col_name: String,
	/// Name of the key column looked up in the right (build) columns.
	pub right_col_name: String,
}
33
34pub(crate) struct EquiJoinAnalysis {
35 pub equi_keys: Vec<EquiKeyPair>,
36 pub residual: Vec<Expression>,
37}
38
39pub(crate) fn extract_equi_keys(on: &[Expression]) -> EquiJoinAnalysis {
40 let mut leaves = Vec::new();
41 for expr in on {
42 if contains_or(expr) {
43 return EquiJoinAnalysis {
44 equi_keys: vec![],
45 residual: on.to_vec(),
46 };
47 }
48 flatten_and(expr, &mut leaves);
49 }
50
51 let mut equi_keys = Vec::new();
52 let mut residual = Vec::new();
53
54 for leaf in leaves {
55 match try_extract_equi_pair(&leaf) {
56 Some(pair) => equi_keys.push(pair),
57 None => residual.push(leaf),
58 }
59 }
60
61 EquiJoinAnalysis {
62 equi_keys,
63 residual,
64 }
65}
66
67fn contains_or(expr: &Expression) -> bool {
68 match expr {
69 Expression::Or(_) => true,
70 Expression::And(and) => contains_or(&and.left) || contains_or(&and.right),
71 _ => false,
72 }
73}
74
75fn flatten_and(expr: &Expression, out: &mut Vec<Expression>) {
76 match expr {
77 Expression::And(and) => {
78 flatten_and(&and.left, out);
79 flatten_and(&and.right, out);
80 }
81 other => out.push(other.clone()),
82 }
83}
84
85fn try_extract_equi_pair(expr: &Expression) -> Option<EquiKeyPair> {
86 if let Expression::Equal(eq) = expr {
87 if let (Expression::Column(col), Expression::AccessSource(acc)) = (eq.left.as_ref(), eq.right.as_ref())
88 {
89 return Some(EquiKeyPair {
90 left_col_name: col.0.name.text().to_string(),
91 right_col_name: acc.column.name.text().to_string(),
92 });
93 }
94
95 if let (Expression::AccessSource(acc), Expression::Column(col)) = (eq.left.as_ref(), eq.right.as_ref())
96 {
97 return Some(EquiKeyPair {
98 left_col_name: col.0.name.text().to_string(),
99 right_col_name: acc.column.name.text().to_string(),
100 });
101 }
102 }
103 None
104}
105
/// Join semantics implemented by `HashJoinNode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HashJoinMode {
	/// Emit only probe (left) rows that have at least one accepted build (right) match.
	Inner,
	/// Emit every probe row. A row without an accepted match is padded with
	/// `Value::none()` for each right-hand column.
	Left,
}
111
112struct HashJoinState {
113 build_columns: Columns,
114 hash_table: HashMap<Hash128, Vec<usize>>,
115 resolved_names: Vec<String>,
116 right_width: usize,
117 right_key_indices: Vec<usize>,
118 left_key_indices: Vec<usize>,
119
120 probe_batch: Option<Columns>,
121 probe_row_idx: usize,
122 current_matches: Vec<usize>,
123 current_match_idx: usize,
124 current_row_matched: bool,
125 probe_exhausted: bool,
126
127 compiled_residual: Vec<CompiledExpr>,
128
129 hash_buf: Vec<u8>,
130}
131
132pub(crate) struct HashJoinNode {
133 left: Box<dyn QueryNode>,
134 right: Box<dyn QueryNode>,
135
136 left_key_names: Vec<String>,
137 right_key_names: Vec<String>,
138 residual: Vec<Expression>,
139 alias: Option<Fragment>,
140 mode: HashJoinMode,
141
142 headers: Option<ColumnHeaders>,
143 context: JoinContext,
144
145 state: Option<HashJoinState>,
146}
147
148impl HashJoinNode {
149 pub(crate) fn new_inner(
150 left: Box<dyn QueryNode>,
151 right: Box<dyn QueryNode>,
152 analysis: EquiJoinAnalysis,
153 alias: Option<Fragment>,
154 ) -> Self {
155 let (left_keys, right_keys) = split_key_names(&analysis.equi_keys);
156 Self {
157 left,
158 right,
159 left_key_names: left_keys,
160 right_key_names: right_keys,
161 residual: analysis.residual,
162 alias,
163 mode: HashJoinMode::Inner,
164 headers: None,
165 context: JoinContext::new(),
166 state: None,
167 }
168 }
169
170 pub(crate) fn new_left(
171 left: Box<dyn QueryNode>,
172 right: Box<dyn QueryNode>,
173 analysis: EquiJoinAnalysis,
174 alias: Option<Fragment>,
175 ) -> Self {
176 let (left_keys, right_keys) = split_key_names(&analysis.equi_keys);
177 Self {
178 left,
179 right,
180 left_key_names: left_keys,
181 right_key_names: right_keys,
182 residual: analysis.residual,
183 alias,
184 mode: HashJoinMode::Left,
185 headers: None,
186 context: JoinContext::new(),
187 state: None,
188 }
189 }
190
191 fn build<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<()> {
192 let build_columns = load_and_merge_all(&mut self.right, rx, ctx)?;
193 let right_width = build_columns.len();
194
195 let right_key_indices: Vec<usize> = if build_columns.is_empty() {
196 Vec::new()
197 } else {
198 self.right_key_names
199 .iter()
200 .map(|name| {
201 build_columns
202 .iter()
203 .position(|c| c.name().text() == name)
204 .unwrap_or_else(|| panic!("right key column '{}' not found", name))
205 })
206 .collect()
207 };
208
209 let mut hash_table: HashMap<Hash128, Vec<usize>> = HashMap::new();
210 let mut hash_buf = Vec::with_capacity(256);
211 let row_count = build_columns.row_count();
212 for j in 0..row_count {
213 if let Some(h) = compute_join_hash(&build_columns, &right_key_indices, j, &mut hash_buf) {
214 hash_table.entry(h).or_default().push(j);
215 }
216 }
217
218 let compile_ctx = CompileContext {
219 symbols: &ctx.symbols,
220 };
221 let compiled_residual: Vec<CompiledExpr> = self
222 .residual
223 .iter()
224 .map(|e| compile_expression(&compile_ctx, e).expect("compile residual"))
225 .collect();
226
227 self.state = Some(HashJoinState {
228 build_columns,
229 hash_table,
230 resolved_names: Vec::new(),
231 right_width,
232 right_key_indices,
233 left_key_indices: Vec::new(),
234 probe_batch: None,
235 probe_row_idx: 0,
236 current_matches: Vec::new(),
237 current_match_idx: 0,
238 current_row_matched: false,
239 probe_exhausted: false,
240 compiled_residual,
241 hash_buf,
242 });
243
244 Ok(())
245 }
246}
247
248fn split_key_names(pairs: &[EquiKeyPair]) -> (Vec<String>, Vec<String>) {
249 let left: Vec<String> = pairs.iter().map(|p| p.left_col_name.clone()).collect();
250 let right: Vec<String> = pairs.iter().map(|p| p.right_col_name.clone()).collect();
251 (left, right)
252}
253
254fn compute_matches_for_probe_row(
255 hash_table: &HashMap<Hash128, Vec<usize>>,
256 build_columns: &Columns,
257 probe: &Columns,
258 probe_row_idx: usize,
259 left_key_indices: &[usize],
260 right_key_indices: &[usize],
261 buf: &mut Vec<u8>,
262) -> Vec<usize> {
263 match compute_join_hash(probe, left_key_indices, probe_row_idx, buf) {
264 Some(h) => hash_table
265 .get(&h)
266 .map(|indices| {
267 indices.iter()
268 .copied()
269 .filter(|&build_idx| {
270 keys_equal_by_index(
271 probe,
272 probe_row_idx,
273 left_key_indices,
274 build_columns,
275 build_idx,
276 right_key_indices,
277 )
278 })
279 .collect()
280 })
281 .unwrap_or_default(),
282 None => Vec::new(),
283 }
284}
285
286impl QueryNode for HashJoinNode {
287 #[instrument(level = "trace", skip_all, name = "volcano::join::hash::initialize")]
288 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
289 self.context.set(ctx);
290 self.left.initialize(rx, ctx)?;
291 self.right.initialize(rx, ctx)?;
292 Ok(())
293 }
294
295 #[instrument(level = "trace", skip_all, name = "volcano::join::hash::next")]
296 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
297 debug_assert!(self.context.is_initialized(), "HashJoinNode::next() called before initialize()");
298
299 if self.state.is_none() {
300 self.build(rx, ctx)?;
301 }
302
303 let batch_size = ctx.batch_size as usize;
304 let stored_ctx = self.context.get().clone();
305
306 let mut state = self.state.take().unwrap();
307
308 if state.probe_exhausted && state.probe_batch.is_none() {
309 if self.headers.is_some() {
310 self.state = Some(state);
311 return Ok(None);
312 }
313 if state.resolved_names.is_empty() {
314 let empty_left = Columns::empty();
315 let resolved =
316 resolve_column_names(&empty_left, &state.build_columns, &self.alias, None);
317 state.resolved_names = resolved.qualified_names;
318 }
319 let names_refs: Vec<&str> = state.resolved_names.iter().map(|s| s.as_str()).collect();
320 let empty: Vec<Vec<Value>> = Vec::new();
321 let columns = Columns::from_rows(&names_refs, &empty);
322 self.headers = Some(ColumnHeaders::from_columns(&columns));
323 self.state = Some(state);
324 return Ok(Some(columns));
325 }
326
327 let mut result_rows: Vec<Vec<Value>> = Vec::new();
328 let mut result_row_numbers: Vec<RowNumber> = Vec::new();
329
330 let resolve_names_and_indices = |state: &mut HashJoinState,
331 probe: &Columns,
332 left_key_names: &[String]| {
333 if state.resolved_names.is_empty() {
334 let resolved = resolve_column_names(probe, &state.build_columns, &self.alias, None);
335 state.resolved_names = resolved.qualified_names;
336 }
337 if state.left_key_indices.is_empty() {
338 state.left_key_indices = left_key_names
339 .iter()
340 .map(|name| {
341 probe.iter().position(|c| c.name().text() == name).unwrap_or_else(
342 || panic!("left key column '{}' not found", name),
343 )
344 })
345 .collect();
346 }
347 };
348
349 while result_rows.len() < batch_size {
350 if state.probe_batch.is_none() {
351 if state.probe_exhausted {
352 break;
353 }
354 match self.left.next(rx, ctx)? {
355 Some(batch) => {
356 resolve_names_and_indices(&mut state, &batch, &self.left_key_names);
357 state.probe_batch = Some(batch);
358 state.probe_row_idx = 0;
359
360 let probe = state.probe_batch.as_ref().unwrap();
361 if probe.row_count() == 0 {
362 state.probe_batch = None;
363 continue;
364 }
365 state.current_matches = compute_matches_for_probe_row(
366 &state.hash_table,
367 &state.build_columns,
368 probe,
369 0,
370 &state.left_key_indices,
371 &state.right_key_indices,
372 &mut state.hash_buf,
373 );
374 state.current_match_idx = 0;
375 state.current_row_matched = false;
376 }
377 None => {
378 state.probe_exhausted = true;
379 break;
380 }
381 }
382 }
383
384 let probe = state.probe_batch.as_ref().unwrap();
385 let probe_row_count = probe.row_count();
386
387 if state.current_match_idx >= state.current_matches.len() {
388 if self.mode == HashJoinMode::Left && !state.current_row_matched {
389 let left_row = probe.get_row(state.probe_row_idx);
390 let mut combined = left_row;
391 combined.extend(vec![Value::none(); state.right_width]);
392 result_rows.push(combined);
393 if !probe.row_numbers.is_empty() {
394 result_row_numbers.push(probe.row_numbers[state.probe_row_idx]);
395 }
396 }
397
398 state.probe_row_idx += 1;
399 if state.probe_row_idx >= probe_row_count {
400 state.probe_batch = None;
401 continue;
402 }
403
404 state.current_matches = compute_matches_for_probe_row(
405 &state.hash_table,
406 &state.build_columns,
407 probe,
408 state.probe_row_idx,
409 &state.left_key_indices,
410 &state.right_key_indices,
411 &mut state.hash_buf,
412 );
413 state.current_match_idx = 0;
414 state.current_row_matched = false;
415 continue;
416 }
417
418 let build_idx = state.current_matches[state.current_match_idx];
419 state.current_match_idx += 1;
420
421 let left_row = probe.get_row(state.probe_row_idx);
422 let right_row = state.build_columns.get_row(build_idx);
423
424 if !state.compiled_residual.is_empty()
425 && !eval_join_condition(
426 &state.compiled_residual,
427 probe,
428 &state.build_columns,
429 &left_row,
430 &right_row,
431 &self.alias,
432 &stored_ctx,
433 ) {
434 continue;
435 }
436
437 state.current_row_matched = true;
438 let mut combined = left_row;
439 combined.extend(right_row);
440 result_rows.push(combined);
441 if !probe.row_numbers.is_empty() {
442 result_row_numbers.push(probe.row_numbers[state.probe_row_idx]);
443 }
444 }
445
446 self.state = Some(state);
447
448 if result_rows.is_empty() {
449 if self.headers.is_some() {
450 return Ok(None);
451 }
452 if let Some(ref mut state) = self.state {
453 if state.resolved_names.is_empty() {
454 let empty_left = Columns::empty();
455 let resolved = resolve_column_names(
456 &empty_left,
457 &state.build_columns,
458 &self.alias,
459 None,
460 );
461 state.resolved_names = resolved.qualified_names;
462 }
463 let names_refs: Vec<&str> = state.resolved_names.iter().map(|s| s.as_str()).collect();
464 let columns = Columns::from_rows(&names_refs, &result_rows);
465 self.headers = Some(ColumnHeaders::from_columns(&columns));
466 return Ok(Some(columns));
467 }
468 return Ok(None);
469 }
470
471 let state = self.state.as_ref().unwrap();
472 let names_refs: Vec<&str> = state.resolved_names.iter().map(|s| s.as_str()).collect();
473 let columns = if result_row_numbers.is_empty() {
474 Columns::from_rows(&names_refs, &result_rows)
475 } else {
476 Columns::from_rows(&names_refs, &result_rows).with_row_numbers(result_row_numbers)
477 };
478
479 self.headers = Some(ColumnHeaders::from_columns(&columns));
480 Ok(Some(columns))
481 }
482
483 fn headers(&self) -> Option<ColumnHeaders> {
484 self.headers.clone()
485 }
486}