Skip to main content

reifydb_engine/vm/volcano/join/
natural.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::collections::{HashMap, HashSet};
5
6use reifydb_core::{
7	common::JoinType,
8	value::column::{columns::Columns, headers::ColumnHeaders},
9};
10use reifydb_runtime::hash::Hash128;
11use reifydb_transaction::transaction::Transaction;
12use reifydb_type::{
13	fragment::Fragment,
14	value::{Value, row_number::RowNumber},
15};
16use tracing::instrument;
17
18use super::common::{JoinContext, compute_join_hash, load_and_merge_all, resolve_column_names};
19use crate::{
20	Result,
21	vm::volcano::query::{QueryContext, QueryNode},
22};
23
24pub struct NaturalJoinNode {
25	left: Box<dyn QueryNode>,
26	right: Box<dyn QueryNode>,
27	join_type: JoinType,
28	alias: Option<Fragment>,
29	headers: Option<ColumnHeaders>,
30	context: JoinContext,
31}
32
33impl NaturalJoinNode {
34	pub(crate) fn new(
35		left: Box<dyn QueryNode>,
36		right: Box<dyn QueryNode>,
37		join_type: JoinType,
38		alias: Option<Fragment>,
39	) -> Self {
40		Self {
41			left,
42			right,
43			join_type,
44			alias,
45			headers: None,
46			context: JoinContext::new(),
47		}
48	}
49
50	fn find_common_columns(left_columns: &Columns, right_columns: &Columns) -> Vec<(String, usize, usize)> {
51		let mut common_columns = Vec::new();
52
53		for (left_idx, left_col) in left_columns.iter().enumerate() {
54			for (right_idx, right_col) in right_columns.iter().enumerate() {
55				if left_col.name() == right_col.name() {
56					common_columns.push((left_col.name().text().to_string(), left_idx, right_idx));
57				}
58			}
59		}
60
61		common_columns
62	}
63}
64
65impl QueryNode for NaturalJoinNode {
66	#[instrument(name = "volcano::join::natural::initialize", level = "trace", skip_all)]
67	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
68		self.context.set(ctx);
69		self.left.initialize(rx, ctx)?;
70		self.right.initialize(rx, ctx)?;
71		Ok(())
72	}
73
74	#[instrument(name = "volcano::join::natural::next", level = "trace", skip_all)]
75	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
76		debug_assert!(self.context.is_initialized(), "NaturalJoinNode::next() called before initialize()");
77
78		if self.headers.is_some() {
79			return Ok(None);
80		}
81
82		let left_columns = load_and_merge_all(&mut self.left, rx, ctx)?;
83		let right_columns = load_and_merge_all(&mut self.right, rx, ctx)?;
84
85		let left_rows = left_columns.row_count();
86		let left_row_numbers = left_columns.row_numbers.to_vec();
87
88		let common_columns = Self::find_common_columns(&left_columns, &right_columns);
89
90		if common_columns.is_empty() {
91			return Ok(None);
92		}
93
94		let excluded_right_cols: HashSet<usize> =
95			common_columns.iter().map(|(_, _, right_idx)| *right_idx).collect();
96
97		let excluded_indices: Vec<usize> = excluded_right_cols.iter().copied().collect();
98
99		let resolved =
100			resolve_column_names(&left_columns, &right_columns, &self.alias, Some(&excluded_indices));
101
102		let mut result_rows = Vec::new();
103		let mut result_row_numbers: Vec<RowNumber> = Vec::new();
104
105		let right_col_indices: Vec<usize> = common_columns.iter().map(|(_, _, ri)| *ri).collect();
106		let mut hash_buf = Vec::with_capacity(256);
107		let mut hash_table: HashMap<Hash128, Vec<usize>> = HashMap::new();
108		let right_rows = right_columns.row_count();
109		for j in 0..right_rows {
110			if let Some(h) = compute_join_hash(&right_columns, &right_col_indices, j, &mut hash_buf) {
111				hash_table.entry(h).or_default().push(j);
112			}
113		}
114
115		let left_col_indices: Vec<usize> = common_columns.iter().map(|(_, li, _)| *li).collect();
116
117		for i in 0..left_rows {
118			let left_row = left_columns.get_row(i);
119			let mut matched = false;
120
121			let candidates = compute_join_hash(&left_columns, &left_col_indices, i, &mut hash_buf)
122				.and_then(|h| hash_table.get(&h));
123
124			if let Some(indices) = candidates {
125				for &j in indices {
126					let right_row = right_columns.get_row(j);
127
128					let all_match = common_columns.iter().all(|(_, left_idx, right_idx)| {
129						left_row[*left_idx] == right_row[*right_idx]
130					});
131
132					if all_match {
133						let mut combined = left_row.clone();
134						for (idx, value) in right_row.iter().enumerate() {
135							if !excluded_right_cols.contains(&idx) {
136								combined.push(value.clone());
137							}
138						}
139						result_rows.push(combined);
140						matched = true;
141						if !left_row_numbers.is_empty() {
142							result_row_numbers.push(left_row_numbers[i]);
143						}
144					}
145				}
146			}
147
148			if !matched && matches!(self.join_type, JoinType::Left) {
149				let mut combined = left_row.clone();
150
151				let undefined_count = right_columns.len() - excluded_right_cols.len();
152				combined.extend(vec![Value::none(); undefined_count]);
153				result_rows.push(combined);
154				if !left_row_numbers.is_empty() {
155					result_row_numbers.push(left_row_numbers[i]);
156				}
157			}
158		}
159
160		let names_refs: Vec<&str> = resolved.qualified_names.iter().map(|s| s.as_str()).collect();
161		let columns = if result_row_numbers.is_empty() {
162			Columns::from_rows(&names_refs, &result_rows)
163		} else {
164			Columns::from_rows(&names_refs, &result_rows).with_row_numbers(result_row_numbers)
165		};
166
167		self.headers = Some(ColumnHeaders::from_columns(&columns));
168		Ok(Some(columns))
169	}
170
171	fn headers(&self) -> Option<ColumnHeaders> {
172		self.headers.clone()
173	}
174}