use std::collections::{HashMap, HashSet};
use reifydb_core::{
common::JoinType,
value::column::{columns::Columns, headers::ColumnHeaders},
};
use reifydb_runtime::hash::Hash128;
use reifydb_transaction::transaction::Transaction;
use reifydb_type::{
fragment::Fragment,
value::{Value, row_number::RowNumber},
};
use tracing::instrument;
use super::common::{JoinContext, compute_join_hash, load_and_merge_all, resolve_column_names};
use crate::{
Result,
vm::volcano::query::{QueryContext, QueryNode},
};
/// Volcano query operator that joins two child operators on every column name they share
/// (a NATURAL JOIN).
///
/// The operator materializes both inputs and emits the entire join result as a single batch.
pub struct NaturalJoinNode {
    /// Operator producing the left input; its columns appear first in every output row.
    left: Box<dyn QueryNode>,
    /// Operator producing the right input; its shared (join) columns are dropped from the output.
    right: Box<dyn QueryNode>,
    /// `JoinType::Left` keeps unmatched left rows (padded with `Value::none()`); any other
    /// variant only emits matched rows.
    join_type: JoinType,
    /// Optional alias forwarded to `resolve_column_names` when building output column names.
    alias: Option<Fragment>,
    /// `None` until the joined batch has been produced; afterwards it holds that batch's
    /// headers and marks the operator as exhausted.
    headers: Option<ColumnHeaders>,
    /// Join state captured from the `QueryContext` during `initialize`.
    context: JoinContext,
}
impl NaturalJoinNode {
    /// Creates a natural-join operator over the `left` and `right` child operators.
    ///
    /// The node starts with no headers (nothing produced yet) and a freshly constructed
    /// `JoinContext`.
    pub(crate) fn new(
        left: Box<dyn QueryNode>,
        right: Box<dyn QueryNode>,
        join_type: JoinType,
        alias: Option<Fragment>,
    ) -> Self {
        let context = JoinContext::new();
        Self { left, right, join_type, alias, headers: None, context }
    }

    /// Lists every (left, right) column pair whose names compare equal.
    ///
    /// Each entry is `(name, left_index, right_index)`, where `name` is the text of the left
    /// column's name. Entries are ordered by left index first, then by right index. Every left
    /// column is compared against every right column.
    fn find_common_columns(left_columns: &Columns, right_columns: &Columns) -> Vec<(String, usize, usize)> {
        let mut pairs = Vec::new();
        for (li, lc) in left_columns.iter().enumerate() {
            pairs.extend(
                right_columns
                    .iter()
                    .enumerate()
                    .filter(|(_, rc)| lc.name() == rc.name())
                    .map(|(ri, _)| (lc.name().text().to_string(), li, ri)),
            );
        }
        pairs
    }
}
impl QueryNode for NaturalJoinNode {
    /// Stores the query context in the join state, then initializes the left child followed
    /// by the right child, stopping at the first error.
    #[instrument(name = "volcano::join::natural::initialize", level = "trace", skip_all)]
    fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
        self.context.set(ctx);
        for child in [&mut self.left, &mut self.right] {
            child.initialize(rx, ctx)?;
        }
        Ok(())
    }

    /// Produces the complete natural-join result as one batch, then `None` on later calls.
    ///
    /// Both inputs are fully loaded. If they share no column names, nothing is produced. Right
    /// rows are hashed on the shared columns, and each left row probes that table. A candidate
    /// is accepted only when every shared column is equal. Output rows are the left row followed
    /// by the right row's non-shared columns. Under `JoinType::Left`, a left row with no match
    /// is emitted once, padded with `Value::none()`. Left row numbers are carried over to the
    /// output only when the left input has any.
    ///
    /// # Errors
    /// Propagates any error from loading either child input.
    #[instrument(name = "volcano::join::natural::next", level = "trace", skip_all)]
    fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
        debug_assert!(self.context.is_initialized(), "NaturalJoinNode::next() called before initialize()");

        // The join is emitted as a single batch; `headers` doubles as the "already produced" flag.
        if self.headers.is_some() {
            return Ok(None);
        }

        // Fully materialize both inputs: left first, then right.
        let lhs = load_and_merge_all(&mut self.left, rx, ctx)?;
        let rhs = load_and_merge_all(&mut self.right, rx, ctx)?;

        let shared = Self::find_common_columns(&lhs, &rhs);
        if shared.is_empty() {
            // No equally-named columns: yield nothing (note: `headers` stays unset here).
            return Ok(None);
        }

        let lhs_keys: Vec<usize> = shared.iter().map(|&(_, li, _)| li).collect();
        let rhs_keys: Vec<usize> = shared.iter().map(|&(_, _, ri)| ri).collect();

        // Shared columns are emitted once, from the left side, so their right-side copies are dropped.
        let dropped: HashSet<usize> = rhs_keys.iter().copied().collect();
        let dropped_list: Vec<usize> = dropped.iter().copied().collect();
        let resolved = resolve_column_names(&lhs, &rhs, &self.alias, Some(&dropped_list));

        // Build phase: bucket right row indices by join-key hash. Rows for which
        // `compute_join_hash` yields `None` are never inserted, so they can never match
        // (presumably a key value that cannot be hashed/compared — confirm in `common`).
        let mut scratch = Vec::with_capacity(256);
        let mut buckets: HashMap<Hash128, Vec<usize>> = HashMap::new();
        for r in 0..rhs.row_count() {
            if let Some(key) = compute_join_hash(&rhs, &rhs_keys, r, &mut scratch) {
                buckets.entry(key).or_default().push(r);
            }
        }

        let lhs_numbers = lhs.row_numbers.to_vec();
        let track_numbers = !lhs_numbers.is_empty();
        let keep_unmatched = matches!(self.join_type, JoinType::Left);
        // Number of right columns that survive into the output; used to pad unmatched left rows.
        let pad_width = rhs.len() - dropped.len();

        let mut out_rows = Vec::new();
        let mut out_numbers: Vec<RowNumber> = Vec::new();

        // Probe phase: each left row looks up its bucket; an absent hash/bucket means no candidates.
        for i in 0..lhs.row_count() {
            let lrow = lhs.get_row(i);
            let mut hit = false;

            let bucket = compute_join_hash(&lhs, &lhs_keys, i, &mut scratch).and_then(|key| buckets.get(&key));
            for &r in bucket.into_iter().flatten() {
                let rrow = rhs.get_row(r);
                // Hashes may collide, so confirm every shared column is actually equal.
                if !shared.iter().all(|&(_, li, ri)| lrow[li] == rrow[ri]) {
                    continue;
                }
                let mut joined = lrow.clone();
                joined.extend(
                    rrow.iter()
                        .enumerate()
                        .filter(|(idx, _)| !dropped.contains(idx))
                        .map(|(_, value)| value.clone()),
                );
                out_rows.push(joined);
                hit = true;
                if track_numbers {
                    out_numbers.push(lhs_numbers[i]);
                }
            }

            if keep_unmatched && !hit {
                let mut padded = lrow.clone();
                padded.resize(padded.len() + pad_width, Value::none());
                out_rows.push(padded);
                if track_numbers {
                    out_numbers.push(lhs_numbers[i]);
                }
            }
        }

        let names: Vec<&str> = resolved.qualified_names.iter().map(|name| name.as_str()).collect();
        let columns = if !out_numbers.is_empty() {
            Columns::from_rows_with_row_numbers(&names, &out_rows, out_numbers)
        } else {
            Columns::from_rows(&names, &out_rows)
        };
        self.headers = Some(ColumnHeaders::from_columns(&columns));
        Ok(Some(columns))
    }

    /// Returns a copy of the headers of the batch produced by `next`, or `None` if no batch
    /// has been produced.
    fn headers(&self) -> Option<ColumnHeaders> {
        self.headers.as_ref().cloned()
    }
}