Skip to main content

reifydb_engine/vm/volcano/join/
hash.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB

4use std::collections::HashMap;
5
6use reifydb_core::value::column::{columns::Columns, headers::ColumnHeaders};
7use reifydb_rql::expression::Expression;
8use reifydb_runtime::hash::Hash128;
9use reifydb_transaction::transaction::Transaction;
10use reifydb_type::{
11	fragment::Fragment,
12	value::{Value, row_number::RowNumber},
13};
14use tracing::instrument;
15
16use super::common::{
17	JoinContext, compute_join_hash, eval_join_condition, keys_equal_by_index, load_and_merge_all,
18	resolve_column_names,
19};
20use crate::{
21	Result,
22	expression::{
23		compile::{CompiledExpr, compile_expression},
24		context::CompileContext,
25	},
26	vm::volcano::query::{QueryContext, QueryNode},
27};
28
/// One equality between a left-input column and a right-input column, usable as a hash-join key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct EquiKeyPair {
	/// Name of the key column on the left (probe) input.
	pub left_col_name: String,
	/// Name of the key column on the right (build) input.
	pub right_col_name: String,
}
33
34pub(crate) struct EquiJoinAnalysis {
35	pub equi_keys: Vec<EquiKeyPair>,
36	pub residual: Vec<Expression>,
37}
38
39pub(crate) fn extract_equi_keys(on: &[Expression]) -> EquiJoinAnalysis {
40	let mut leaves = Vec::new();
41	for expr in on {
42		if contains_or(expr) {
43			return EquiJoinAnalysis {
44				equi_keys: vec![],
45				residual: on.to_vec(),
46			};
47		}
48		flatten_and(expr, &mut leaves);
49	}
50
51	let mut equi_keys = Vec::new();
52	let mut residual = Vec::new();
53
54	for leaf in leaves {
55		match try_extract_equi_pair(&leaf) {
56			Some(pair) => equi_keys.push(pair),
57			None => residual.push(leaf),
58		}
59	}
60
61	EquiJoinAnalysis {
62		equi_keys,
63		residual,
64	}
65}
66
67fn contains_or(expr: &Expression) -> bool {
68	match expr {
69		Expression::Or(_) => true,
70		Expression::And(and) => contains_or(&and.left) || contains_or(&and.right),
71		_ => false,
72	}
73}
74
75fn flatten_and(expr: &Expression, out: &mut Vec<Expression>) {
76	match expr {
77		Expression::And(and) => {
78			flatten_and(&and.left, out);
79			flatten_and(&and.right, out);
80		}
81		other => out.push(other.clone()),
82	}
83}
84
85fn try_extract_equi_pair(expr: &Expression) -> Option<EquiKeyPair> {
86	if let Expression::Equal(eq) = expr {
87		if let (Expression::Column(col), Expression::AccessSource(acc)) = (eq.left.as_ref(), eq.right.as_ref())
88		{
89			return Some(EquiKeyPair {
90				left_col_name: col.0.name.text().to_string(),
91				right_col_name: acc.column.name.text().to_string(),
92			});
93		}
94
95		if let (Expression::AccessSource(acc), Expression::Column(col)) = (eq.left.as_ref(), eq.right.as_ref())
96		{
97			return Some(EquiKeyPair {
98				left_col_name: col.0.name.text().to_string(),
99				right_col_name: acc.column.name.text().to_string(),
100			});
101		}
102	}
103	None
104}
105
/// Which join semantics a [`HashJoinNode`] applies to probe rows without a match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HashJoinMode {
	/// Unmatched probe rows are dropped.
	Inner,
	/// Unmatched probe rows are emitted once, padded with `none` values for the right side.
	Left,
}
111
112struct HashJoinState {
113	build_columns: Columns,
114	hash_table: HashMap<Hash128, Vec<usize>>,
115	resolved_names: Vec<String>,
116	right_width: usize,
117	right_key_indices: Vec<usize>,
118	left_key_indices: Vec<usize>,
119
120	probe_batch: Option<Columns>,
121	probe_row_idx: usize,
122	current_matches: Vec<usize>,
123	current_match_idx: usize,
124	current_row_matched: bool,
125	probe_exhausted: bool,
126
127	compiled_residual: Vec<CompiledExpr>,
128
129	hash_buf: Vec<u8>,
130}
131
132pub(crate) struct HashJoinNode {
133	left: Box<dyn QueryNode>,
134	right: Box<dyn QueryNode>,
135
136	left_key_names: Vec<String>,
137	right_key_names: Vec<String>,
138	residual: Vec<Expression>,
139	alias: Option<Fragment>,
140	mode: HashJoinMode,
141
142	headers: Option<ColumnHeaders>,
143	context: JoinContext,
144
145	state: Option<HashJoinState>,
146}
147
148impl HashJoinNode {
149	pub(crate) fn new_inner(
150		left: Box<dyn QueryNode>,
151		right: Box<dyn QueryNode>,
152		analysis: EquiJoinAnalysis,
153		alias: Option<Fragment>,
154	) -> Self {
155		let (left_keys, right_keys) = split_key_names(&analysis.equi_keys);
156		Self {
157			left,
158			right,
159			left_key_names: left_keys,
160			right_key_names: right_keys,
161			residual: analysis.residual,
162			alias,
163			mode: HashJoinMode::Inner,
164			headers: None,
165			context: JoinContext::new(),
166			state: None,
167		}
168	}
169
170	pub(crate) fn new_left(
171		left: Box<dyn QueryNode>,
172		right: Box<dyn QueryNode>,
173		analysis: EquiJoinAnalysis,
174		alias: Option<Fragment>,
175	) -> Self {
176		let (left_keys, right_keys) = split_key_names(&analysis.equi_keys);
177		Self {
178			left,
179			right,
180			left_key_names: left_keys,
181			right_key_names: right_keys,
182			residual: analysis.residual,
183			alias,
184			mode: HashJoinMode::Left,
185			headers: None,
186			context: JoinContext::new(),
187			state: None,
188		}
189	}
190
191	fn build<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<()> {
192		let build_columns = load_and_merge_all(&mut self.right, rx, ctx)?;
193		let right_width = build_columns.len();
194
195		let right_key_indices: Vec<usize> = if build_columns.is_empty() {
196			Vec::new()
197		} else {
198			self.right_key_names
199				.iter()
200				.map(|name| {
201					build_columns
202						.iter()
203						.position(|c| c.name().text() == name)
204						.unwrap_or_else(|| panic!("right key column '{}' not found", name))
205				})
206				.collect()
207		};
208
209		let mut hash_table: HashMap<Hash128, Vec<usize>> = HashMap::new();
210		let mut hash_buf = Vec::with_capacity(256);
211		let row_count = build_columns.row_count();
212		for j in 0..row_count {
213			if let Some(h) = compute_join_hash(&build_columns, &right_key_indices, j, &mut hash_buf) {
214				hash_table.entry(h).or_default().push(j);
215			}
216		}
217
218		let compile_ctx = CompileContext {
219			symbols: &ctx.symbols,
220		};
221		let compiled_residual: Vec<CompiledExpr> = self
222			.residual
223			.iter()
224			.map(|e| compile_expression(&compile_ctx, e).expect("compile residual"))
225			.collect();
226
227		self.state = Some(HashJoinState {
228			build_columns,
229			hash_table,
230			resolved_names: Vec::new(),
231			right_width,
232			right_key_indices,
233			left_key_indices: Vec::new(),
234			probe_batch: None,
235			probe_row_idx: 0,
236			current_matches: Vec::new(),
237			current_match_idx: 0,
238			current_row_matched: false,
239			probe_exhausted: false,
240			compiled_residual,
241			hash_buf,
242		});
243
244		Ok(())
245	}
246}
247
248fn split_key_names(pairs: &[EquiKeyPair]) -> (Vec<String>, Vec<String>) {
249	let left: Vec<String> = pairs.iter().map(|p| p.left_col_name.clone()).collect();
250	let right: Vec<String> = pairs.iter().map(|p| p.right_col_name.clone()).collect();
251	(left, right)
252}
253
254fn compute_matches_for_probe_row(
255	hash_table: &HashMap<Hash128, Vec<usize>>,
256	build_columns: &Columns,
257	probe: &Columns,
258	probe_row_idx: usize,
259	left_key_indices: &[usize],
260	right_key_indices: &[usize],
261	buf: &mut Vec<u8>,
262) -> Vec<usize> {
263	match compute_join_hash(probe, left_key_indices, probe_row_idx, buf) {
264		Some(h) => hash_table
265			.get(&h)
266			.map(|indices| {
267				indices.iter()
268					.copied()
269					.filter(|&build_idx| {
270						keys_equal_by_index(
271							probe,
272							probe_row_idx,
273							left_key_indices,
274							build_columns,
275							build_idx,
276							right_key_indices,
277						)
278					})
279					.collect()
280			})
281			.unwrap_or_default(),
282		None => Vec::new(),
283	}
284}
285
286impl QueryNode for HashJoinNode {
287	#[instrument(level = "trace", skip_all, name = "volcano::join::hash::initialize")]
288	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
289		self.context.set(ctx);
290		self.left.initialize(rx, ctx)?;
291		self.right.initialize(rx, ctx)?;
292		Ok(())
293	}
294
295	#[instrument(level = "trace", skip_all, name = "volcano::join::hash::next")]
296	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
297		debug_assert!(self.context.is_initialized(), "HashJoinNode::next() called before initialize()");
298
299		if self.state.is_none() {
300			self.build(rx, ctx)?;
301		}
302
303		let batch_size = ctx.batch_size as usize;
304		let stored_ctx = self.context.get().clone();
305
306		let mut state = self.state.take().unwrap();
307
308		if state.probe_exhausted && state.probe_batch.is_none() {
309			if self.headers.is_some() {
310				self.state = Some(state);
311				return Ok(None);
312			}
313			if state.resolved_names.is_empty() {
314				let empty_left = Columns::empty();
315				let resolved =
316					resolve_column_names(&empty_left, &state.build_columns, &self.alias, None);
317				state.resolved_names = resolved.qualified_names;
318			}
319			let names_refs: Vec<&str> = state.resolved_names.iter().map(|s| s.as_str()).collect();
320			let empty: Vec<Vec<Value>> = Vec::new();
321			let columns = Columns::from_rows(&names_refs, &empty);
322			self.headers = Some(ColumnHeaders::from_columns(&columns));
323			self.state = Some(state);
324			return Ok(Some(columns));
325		}
326
327		let mut result_rows: Vec<Vec<Value>> = Vec::new();
328		let mut result_row_numbers: Vec<RowNumber> = Vec::new();
329
330		let resolve_names_and_indices = |state: &mut HashJoinState,
331		                                 probe: &Columns,
332		                                 left_key_names: &[String]| {
333			if state.resolved_names.is_empty() {
334				let resolved = resolve_column_names(probe, &state.build_columns, &self.alias, None);
335				state.resolved_names = resolved.qualified_names;
336			}
337			if state.left_key_indices.is_empty() {
338				state.left_key_indices = left_key_names
339					.iter()
340					.map(|name| {
341						probe.iter().position(|c| c.name().text() == name).unwrap_or_else(
342							|| panic!("left key column '{}' not found", name),
343						)
344					})
345					.collect();
346			}
347		};
348
349		while result_rows.len() < batch_size {
350			if state.probe_batch.is_none() {
351				if state.probe_exhausted {
352					break;
353				}
354				match self.left.next(rx, ctx)? {
355					Some(batch) => {
356						resolve_names_and_indices(&mut state, &batch, &self.left_key_names);
357						state.probe_batch = Some(batch);
358						state.probe_row_idx = 0;
359
360						let probe = state.probe_batch.as_ref().unwrap();
361						if probe.row_count() == 0 {
362							state.probe_batch = None;
363							continue;
364						}
365						state.current_matches = compute_matches_for_probe_row(
366							&state.hash_table,
367							&state.build_columns,
368							probe,
369							0,
370							&state.left_key_indices,
371							&state.right_key_indices,
372							&mut state.hash_buf,
373						);
374						state.current_match_idx = 0;
375						state.current_row_matched = false;
376					}
377					None => {
378						state.probe_exhausted = true;
379						break;
380					}
381				}
382			}
383
384			let probe = state.probe_batch.as_ref().unwrap();
385			let probe_row_count = probe.row_count();
386
387			if state.current_match_idx >= state.current_matches.len() {
388				if self.mode == HashJoinMode::Left && !state.current_row_matched {
389					let left_row = probe.get_row(state.probe_row_idx);
390					let mut combined = left_row;
391					combined.extend(vec![Value::none(); state.right_width]);
392					result_rows.push(combined);
393					if !probe.row_numbers.is_empty() {
394						result_row_numbers.push(probe.row_numbers[state.probe_row_idx]);
395					}
396				}
397
398				state.probe_row_idx += 1;
399				if state.probe_row_idx >= probe_row_count {
400					state.probe_batch = None;
401					continue;
402				}
403
404				state.current_matches = compute_matches_for_probe_row(
405					&state.hash_table,
406					&state.build_columns,
407					probe,
408					state.probe_row_idx,
409					&state.left_key_indices,
410					&state.right_key_indices,
411					&mut state.hash_buf,
412				);
413				state.current_match_idx = 0;
414				state.current_row_matched = false;
415				continue;
416			}
417
418			let build_idx = state.current_matches[state.current_match_idx];
419			state.current_match_idx += 1;
420
421			let left_row = probe.get_row(state.probe_row_idx);
422			let right_row = state.build_columns.get_row(build_idx);
423
424			if !state.compiled_residual.is_empty()
425				&& !eval_join_condition(
426					&state.compiled_residual,
427					probe,
428					&state.build_columns,
429					&left_row,
430					&right_row,
431					&self.alias,
432					&stored_ctx,
433				) {
434				continue;
435			}
436
437			state.current_row_matched = true;
438			let mut combined = left_row;
439			combined.extend(right_row);
440			result_rows.push(combined);
441			if !probe.row_numbers.is_empty() {
442				result_row_numbers.push(probe.row_numbers[state.probe_row_idx]);
443			}
444		}
445
446		self.state = Some(state);
447
448		if result_rows.is_empty() {
449			if self.headers.is_some() {
450				return Ok(None);
451			}
452			if let Some(ref mut state) = self.state {
453				if state.resolved_names.is_empty() {
454					let empty_left = Columns::empty();
455					let resolved = resolve_column_names(
456						&empty_left,
457						&state.build_columns,
458						&self.alias,
459						None,
460					);
461					state.resolved_names = resolved.qualified_names;
462				}
463				let names_refs: Vec<&str> = state.resolved_names.iter().map(|s| s.as_str()).collect();
464				let columns = Columns::from_rows(&names_refs, &result_rows);
465				self.headers = Some(ColumnHeaders::from_columns(&columns));
466				return Ok(Some(columns));
467			}
468			return Ok(None);
469		}
470
471		let state = self.state.as_ref().unwrap();
472		let names_refs: Vec<&str> = state.resolved_names.iter().map(|s| s.as_str()).collect();
473		let columns = if result_row_numbers.is_empty() {
474			Columns::from_rows(&names_refs, &result_rows)
475		} else {
476			Columns::from_rows(&names_refs, &result_rows).with_row_numbers(result_row_numbers)
477		};
478
479		self.headers = Some(ColumnHeaders::from_columns(&columns));
480		Ok(Some(columns))
481	}
482
483	fn headers(&self) -> Option<ColumnHeaders> {
484		self.headers.clone()
485	}
486}