reifydb_engine/vm/volcano/join/
natural.rs1use std::collections::{HashMap, HashSet};
5
6use reifydb_core::{
7 common::JoinType,
8 value::column::{columns::Columns, headers::ColumnHeaders},
9};
10use reifydb_runtime::hash::Hash128;
11use reifydb_transaction::transaction::Transaction;
12use reifydb_type::{
13 fragment::Fragment,
14 value::{Value, row_number::RowNumber},
15};
16use tracing::instrument;
17
18use super::common::{JoinContext, compute_join_hash, load_and_merge_all, resolve_column_names};
19use crate::{
20 Result,
21 vm::volcano::query::{QueryContext, QueryNode},
22};
23
24pub struct NaturalJoinNode {
25 left: Box<dyn QueryNode>,
26 right: Box<dyn QueryNode>,
27 join_type: JoinType,
28 alias: Option<Fragment>,
29 headers: Option<ColumnHeaders>,
30 context: JoinContext,
31}
32
33impl NaturalJoinNode {
34 pub(crate) fn new(
35 left: Box<dyn QueryNode>,
36 right: Box<dyn QueryNode>,
37 join_type: JoinType,
38 alias: Option<Fragment>,
39 ) -> Self {
40 Self {
41 left,
42 right,
43 join_type,
44 alias,
45 headers: None,
46 context: JoinContext::new(),
47 }
48 }
49
50 fn find_common_columns(left_columns: &Columns, right_columns: &Columns) -> Vec<(String, usize, usize)> {
51 let mut common_columns = Vec::new();
52
53 for (left_idx, left_col) in left_columns.iter().enumerate() {
54 for (right_idx, right_col) in right_columns.iter().enumerate() {
55 if left_col.name() == right_col.name() {
56 common_columns.push((left_col.name().text().to_string(), left_idx, right_idx));
57 }
58 }
59 }
60
61 common_columns
62 }
63}
64
65impl QueryNode for NaturalJoinNode {
66 #[instrument(name = "volcano::join::natural::initialize", level = "trace", skip_all)]
67 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
68 self.context.set(ctx);
69 self.left.initialize(rx, ctx)?;
70 self.right.initialize(rx, ctx)?;
71 Ok(())
72 }
73
74 #[instrument(name = "volcano::join::natural::next", level = "trace", skip_all)]
75 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
76 debug_assert!(self.context.is_initialized(), "NaturalJoinNode::next() called before initialize()");
77
78 if self.headers.is_some() {
79 return Ok(None);
80 }
81
82 let left_columns = load_and_merge_all(&mut self.left, rx, ctx)?;
83 let right_columns = load_and_merge_all(&mut self.right, rx, ctx)?;
84
85 let left_rows = left_columns.row_count();
86 let left_row_numbers = left_columns.row_numbers.to_vec();
87
88 let common_columns = Self::find_common_columns(&left_columns, &right_columns);
89
90 if common_columns.is_empty() {
91 return Ok(None);
92 }
93
94 let excluded_right_cols: HashSet<usize> =
95 common_columns.iter().map(|(_, _, right_idx)| *right_idx).collect();
96
97 let excluded_indices: Vec<usize> = excluded_right_cols.iter().copied().collect();
98
99 let resolved =
100 resolve_column_names(&left_columns, &right_columns, &self.alias, Some(&excluded_indices));
101
102 let mut result_rows = Vec::new();
103 let mut result_row_numbers: Vec<RowNumber> = Vec::new();
104
105 let right_col_indices: Vec<usize> = common_columns.iter().map(|(_, _, ri)| *ri).collect();
106 let mut hash_buf = Vec::with_capacity(256);
107 let mut hash_table: HashMap<Hash128, Vec<usize>> = HashMap::new();
108 let right_rows = right_columns.row_count();
109 for j in 0..right_rows {
110 if let Some(h) = compute_join_hash(&right_columns, &right_col_indices, j, &mut hash_buf) {
111 hash_table.entry(h).or_default().push(j);
112 }
113 }
114
115 let left_col_indices: Vec<usize> = common_columns.iter().map(|(_, li, _)| *li).collect();
116
117 for i in 0..left_rows {
118 let left_row = left_columns.get_row(i);
119 let mut matched = false;
120
121 let candidates = compute_join_hash(&left_columns, &left_col_indices, i, &mut hash_buf)
122 .and_then(|h| hash_table.get(&h));
123
124 if let Some(indices) = candidates {
125 for &j in indices {
126 let right_row = right_columns.get_row(j);
127
128 let all_match = common_columns.iter().all(|(_, left_idx, right_idx)| {
129 left_row[*left_idx] == right_row[*right_idx]
130 });
131
132 if all_match {
133 let mut combined = left_row.clone();
134 for (idx, value) in right_row.iter().enumerate() {
135 if !excluded_right_cols.contains(&idx) {
136 combined.push(value.clone());
137 }
138 }
139 result_rows.push(combined);
140 matched = true;
141 if !left_row_numbers.is_empty() {
142 result_row_numbers.push(left_row_numbers[i]);
143 }
144 }
145 }
146 }
147
148 if !matched && matches!(self.join_type, JoinType::Left) {
149 let mut combined = left_row.clone();
150
151 let undefined_count = right_columns.len() - excluded_right_cols.len();
152 combined.extend(vec![Value::none(); undefined_count]);
153 result_rows.push(combined);
154 if !left_row_numbers.is_empty() {
155 result_row_numbers.push(left_row_numbers[i]);
156 }
157 }
158 }
159
160 let names_refs: Vec<&str> = resolved.qualified_names.iter().map(|s| s.as_str()).collect();
161 let columns = if result_row_numbers.is_empty() {
162 Columns::from_rows(&names_refs, &result_rows)
163 } else {
164 Columns::from_rows(&names_refs, &result_rows).with_row_numbers(result_row_numbers)
165 };
166
167 self.headers = Some(ColumnHeaders::from_columns(&columns));
168 Ok(Some(columns))
169 }
170
171 fn headers(&self) -> Option<ColumnHeaders> {
172 self.headers.clone()
173 }
174}