use reifydb_core::value::column::{columns::Columns, headers::ColumnHeaders};
use reifydb_rql::expression::Expression;
use reifydb_transaction::transaction::Transaction;
use reifydb_type::{
fragment::Fragment,
value::{Value, row_number::RowNumber},
};
use tracing::instrument;
use super::common::{JoinContext, build_eval_columns, load_and_merge_all, resolve_column_names};
use crate::{
Result,
expression::{
compile::compile_expression,
context::{CompileContext, EvalContext},
},
vm::volcano::query::{QueryContext, QueryNode},
};
#[derive(Clone, Copy, PartialEq)]
enum NestedLoopMode {
Inner,
Left,
}
pub struct NestedLoopJoinNode {
left: Box<dyn QueryNode>,
right: Box<dyn QueryNode>,
on: Vec<Expression>,
alias: Option<Fragment>,
mode: NestedLoopMode,
headers: Option<ColumnHeaders>,
context: JoinContext,
}
impl NestedLoopJoinNode {
pub(crate) fn new_inner(
left: Box<dyn QueryNode>,
right: Box<dyn QueryNode>,
on: Vec<Expression>,
alias: Option<Fragment>,
) -> Self {
Self {
left,
right,
on,
alias,
mode: NestedLoopMode::Inner,
headers: None,
context: JoinContext::new(),
}
}
pub(crate) fn new_left(
left: Box<dyn QueryNode>,
right: Box<dyn QueryNode>,
on: Vec<Expression>,
alias: Option<Fragment>,
) -> Self {
Self {
left,
right,
on,
alias,
mode: NestedLoopMode::Left,
headers: None,
context: JoinContext::new(),
}
}
}
impl QueryNode for NestedLoopJoinNode {
#[instrument(level = "trace", skip_all, name = "volcano::join::nested_loop::initialize")]
fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
let compile_ctx = CompileContext {
functions: &ctx.services.functions,
symbols: &ctx.symbols,
};
self.context.compiled =
self.on.iter().map(|e| compile_expression(&compile_ctx, e).expect("compile")).collect();
self.context.set(ctx);
self.left.initialize(rx, ctx)?;
self.right.initialize(rx, ctx)?;
Ok(())
}
#[instrument(level = "trace", skip_all, name = "volcano::join::nested_loop::next")]
fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
debug_assert!(self.context.is_initialized(), "NestedLoopJoinNode::next() called before initialize()");
let _stored_ctx = self.context.get();
if self.headers.is_some() {
return Ok(None);
}
let left_columns = load_and_merge_all(&mut self.left, rx, ctx)?;
let right_columns = load_and_merge_all(&mut self.right, rx, ctx)?;
let left_rows = left_columns.row_count();
let right_rows = right_columns.row_count();
let right_width = right_columns.len();
let left_row_numbers = left_columns.row_numbers.to_vec();
let resolved = resolve_column_names(&left_columns, &right_columns, &self.alias, None);
let session = EvalContext::from_query(ctx);
let mut result_rows = Vec::new();
let mut result_row_numbers: Vec<RowNumber> = Vec::new();
for i in 0..left_rows {
let left_row = left_columns.get_row(i);
let mut matched = false;
for j in 0..right_rows {
let right_row = right_columns.get_row(j);
let eval_columns = build_eval_columns(
&left_columns,
&right_columns,
&left_row,
&right_row,
&self.alias,
);
let exec_ctx = session.with_eval_join(Columns::new(eval_columns));
let all_true = self.context.compiled.iter().fold(true, |acc, compiled_expr| {
let col = compiled_expr.execute(&exec_ctx).unwrap();
matches!(col.data().get_value(0), Value::Boolean(true)) && acc
});
if all_true {
let mut combined = left_row.clone();
combined.extend(right_row.clone());
result_rows.push(combined);
matched = true;
if !left_row_numbers.is_empty() {
result_row_numbers.push(left_row_numbers[i]);
}
}
}
if self.mode == NestedLoopMode::Left && !matched {
let mut combined = left_row.clone();
combined.extend(vec![Value::none(); right_width]);
result_rows.push(combined);
if !left_row_numbers.is_empty() {
result_row_numbers.push(left_row_numbers[i]);
}
}
}
let names_refs: Vec<&str> = resolved.qualified_names.iter().map(|s| s.as_str()).collect();
let columns = if result_row_numbers.is_empty() {
Columns::from_rows(&names_refs, &result_rows)
} else {
Columns::from_rows_with_row_numbers(&names_refs, &result_rows, result_row_numbers)
};
self.headers = Some(ColumnHeaders::from_columns(&columns));
Ok(Some(columns))
}
fn headers(&self) -> Option<ColumnHeaders> {
self.headers.clone()
}
}