use crate::core::value::NULL_VALUE;
use crate::core::{Result, Row, Value};
use crate::executor::expression::JoinFilter;
use crate::executor::operator::{ColumnInfo, Operator, RowRef};
use crate::functions::registry::global_registry;
use crate::parser::ast::Expression;
use super::hash_join::JoinType;
/// Nested-loop join operator: the right input is buffered in full, then every
/// buffered right row is tested against each left row in turn.
pub struct NestedLoopJoinOperator {
/// Left input; rows are pulled one at a time as the join advances.
left: Box<dyn Operator>,
/// Right input; drained completely into `right_rows` when the join is opened.
right: Box<dyn Operator>,
/// Join semantics (INNER, LEFT, RIGHT, FULL, CROSS, SEMI, ANTI).
join_type: JoinType,
/// Optional join predicate; with no predicate every (left, right) pair matches.
condition: Option<Expression>,
/// `condition` compiled against the child column names; built when the join is opened.
filter: Option<JoinFilter>,
/// Output schema: all left columns followed by all right columns.
schema: Vec<ColumnInfo>,
/// Number of columns in the left input's schema.
left_col_count: usize,
/// Number of columns in the right input's schema.
right_col_count: usize,
/// Materialized right input, rescanned for every left row.
right_rows: Vec<Row>,
/// Left row currently being compared against `right_rows`.
current_left_row: Option<Row>,
/// Index of the next entry in `right_rows` to test against `current_left_row`.
current_right_idx: usize,
/// Whether `current_left_row` has matched at least one right row so far.
left_had_match: bool,
/// Per-right-row "matched at least once" flags; populated only for RIGHT/FULL joins.
right_matched: Vec<bool>,
/// Set once the left input is exhausted and unmatched right rows are being emitted.
returning_unmatched_right: bool,
/// Cursor into `right_rows` for the unmatched-right-row phase.
unmatched_right_idx: usize,
/// NULL values used to pad the right side of unmatched left rows (filled when opened).
cached_null_right: Vec<Value>,
/// NULL values used to pad the left side of unmatched right rows (filled when opened).
cached_null_left: Vec<Value>,
/// Whether `open` has completed; `next` returns an error before that.
opened: bool,
/// Set once the left input has returned `None`.
left_exhausted: bool,
}
impl NestedLoopJoinOperator {
pub fn new(
left: Box<dyn Operator>,
right: Box<dyn Operator>,
join_type: JoinType,
condition: Option<Expression>,
) -> Self {
let mut schema = Vec::new();
schema.extend(left.schema().iter().cloned());
schema.extend(right.schema().iter().cloned());
let left_col_count = left.schema().len();
let right_col_count = right.schema().len();
Self {
left,
right,
join_type,
condition,
filter: None,
schema,
left_col_count,
right_col_count,
right_rows: Vec::new(),
current_left_row: None,
current_right_idx: 0,
left_had_match: false,
right_matched: Vec::new(),
returning_unmatched_right: false,
unmatched_right_idx: 0,
cached_null_right: Vec::new(), cached_null_left: Vec::new(), opened: false,
left_exhausted: false,
}
}
#[inline]
fn null_left_row(&self) -> Row {
Row::from_values(self.cached_null_left.clone())
}
#[inline]
fn null_right_row(&self) -> Row {
Row::from_values(self.cached_null_right.clone())
}
#[inline]
fn combine(&self, left: &Row, right: &Row) -> Row {
Row::from_combined(left, right)
}
#[inline]
fn advance_left(&mut self) -> Result<bool> {
match self.left.next()? {
Some(row_ref) => {
self.current_left_row = Some(row_ref.into_owned());
self.current_right_idx = 0;
self.left_had_match = false;
Ok(true)
}
None => {
self.left_exhausted = true;
Ok(false)
}
}
}
}
impl Operator for NestedLoopJoinOperator {
    /// Opens both children, buffers the whole right input, compiles the join
    /// condition and positions on the first left row.
    ///
    /// All per-execution state is reset first. Calling `open` again (e.g. for
    /// a rescan) therefore starts a fresh join instead of appending a second
    /// copy of the right input to the buffer or resuming from stale cursors.
    ///
    /// # Errors
    /// Propagates errors from opening or reading either child and from
    /// compiling the join condition.
    fn open(&mut self) -> Result<()> {
        self.left.open()?;
        self.right.open()?;

        // Discard cursors, buffers and flags left over from a previous execution.
        self.opened = false;
        self.left_exhausted = false;
        self.current_left_row = None;
        self.current_right_idx = 0;
        self.left_had_match = false;
        self.returning_unmatched_right = false;
        self.unmatched_right_idx = 0;
        self.right_rows.clear();
        self.right_matched.clear();

        // NULL padding for outer-join and ANTI output rows; cheap, so always prepared.
        self.cached_null_right = vec![NULL_VALUE; self.right_col_count];
        self.cached_null_left = vec![NULL_VALUE; self.left_col_count];

        let left_cols: Vec<String> = self.left.schema().iter().map(|c| c.name.clone()).collect();
        let right_cols: Vec<String> = self.right.schema().iter().map(|c| c.name.clone()).collect();
        self.filter = match self.condition {
            Some(ref cond) => Some(JoinFilter::new(
                cond,
                &left_cols,
                &right_cols,
                global_registry(),
            )?),
            None => None,
        };

        // The right side is rescanned once per left row, so materialize it.
        while let Some(row_ref) = self.right.next()? {
            self.right_rows.push(row_ref.into_owned());
        }
        if matches!(self.join_type, JoinType::Right | JoinType::Full) {
            self.right_matched = vec![false; self.right_rows.len()];
        }

        self.advance_left()?;
        self.opened = true;
        Ok(())
    }

    /// Produces the next joined row, or `Ok(None)` once the join is complete.
    ///
    /// Output order is as follows. For each left row, its matches are emitted
    /// in right-buffer order. For LEFT/FULL, a left row that matched nothing
    /// is followed by a NULL-padded row. After the left input is exhausted,
    /// RIGHT/FULL emit every right row that never matched, padded with NULLs
    /// on the left.
    ///
    /// SEMI emits each qualifying left row once, paired with its first match.
    /// ANTI emits only left rows with no match, paired with NULLs. Both keep
    /// the output as wide as `schema()`.
    ///
    /// # Errors
    /// Returns an internal error if called before `open`, and propagates
    /// errors from the left child.
    fn next(&mut self) -> Result<Option<RowRef>> {
        if !self.opened {
            return Err(crate::core::Error::internal(
                "NestedLoopJoinOperator::next called before open",
            ));
        }
        let is_left_outer = matches!(self.join_type, JoinType::Left | JoinType::Full);
        let is_right_outer = matches!(self.join_type, JoinType::Right | JoinType::Full);
        let is_cross = matches!(self.join_type, JoinType::Cross);
        let is_semi = matches!(self.join_type, JoinType::Semi);
        let is_anti = matches!(self.join_type, JoinType::Anti);

        loop {
            // Final phase (RIGHT/FULL only): right rows that no left row matched.
            if self.returning_unmatched_right {
                while self.unmatched_right_idx < self.right_rows.len() {
                    let idx = self.unmatched_right_idx;
                    self.unmatched_right_idx += 1;
                    if !self.right_matched[idx] {
                        let null_left = self.null_left_row();
                        let combined = self.combine(&null_left, &self.right_rows[idx]);
                        return Ok(Some(RowRef::Owned(combined)));
                    }
                }
                return Ok(None);
            }

            // Ensure there is a current left row. Once the left input is
            // exhausted, either switch to the final phase or finish.
            let has_left = !self.left_exhausted
                && (self.current_left_row.is_some() || self.advance_left()?);
            if !has_left {
                if is_right_outer {
                    self.returning_unmatched_right = true;
                    self.unmatched_right_idx = 0;
                    continue;
                }
                return Ok(None);
            }

            let left_row = self
                .current_left_row
                .as_ref()
                .expect("a current left row exists while the left input is not exhausted");

            // Resume scanning the buffered right rows for the current left row.
            while self.current_right_idx < self.right_rows.len() {
                let right_idx = self.current_right_idx;
                self.current_right_idx += 1;
                let right_row = &self.right_rows[right_idx];
                let matches = match &self.filter {
                    Some(filter) => filter.matches(left_row, right_row),
                    None => is_cross || self.condition.is_none(),
                };
                if !matches {
                    continue;
                }
                self.left_had_match = true;
                if is_right_outer {
                    self.right_matched[right_idx] = true;
                }
                if is_anti {
                    // One match disqualifies this left row; the rest of the scan is moot.
                    break;
                }
                if is_semi {
                    // Emit the left row once; further matches would only duplicate it.
                    self.current_right_idx = self.right_rows.len();
                }
                let combined = self.combine(left_row, right_row);
                return Ok(Some(RowRef::Owned(combined)));
            }

            // Scan finished. LEFT/FULL pad an unmatched left row with NULLs;
            // ANTI emits exactly those unmatched left rows.
            if (is_left_outer || is_anti) && !self.left_had_match {
                let left_row = self
                    .current_left_row
                    .take()
                    .expect("the current left row was present for the scan");
                self.advance_left()?;
                let null_right = self.null_right_row();
                let combined = Row::from_combined_owned(left_row, null_right);
                return Ok(Some(RowRef::Owned(combined)));
            }

            // Nothing more for this left row; the loop head handles exhaustion.
            self.advance_left()?;
        }
    }

    /// Closes both children and releases the buffered right input.
    ///
    /// The right child is closed even if closing the left child fails; the
    /// left child's error takes precedence when both fail. After `close`,
    /// `next` reports an error until the operator is opened again.
    fn close(&mut self) -> Result<()> {
        self.opened = false;
        self.current_left_row = None;
        self.right_rows = Vec::new();
        self.right_matched = Vec::new();
        let left_result = self.left.close();
        let right_result = self.right.close();
        left_result.and(right_result)
    }

    /// Output schema: all left columns followed by all right columns.
    fn schema(&self) -> &[ColumnInfo] {
        &self.schema
    }

    /// Rough output-cardinality estimate; `None` when either child cannot estimate.
    ///
    /// Products and sums saturate at `usize::MAX` rather than overflowing,
    /// which would panic in debug builds for very large inputs.
    fn estimated_rows(&self) -> Option<usize> {
        let left_est = self.left.estimated_rows()?;
        let right_est = self.right.estimated_rows()?;
        Some(match self.join_type {
            // Heuristic: assume the condition keeps about 10% of all pairs.
            JoinType::Inner => left_est.saturating_mul(right_est) / 10,
            JoinType::Left => left_est,
            JoinType::Right => right_est,
            JoinType::Full => left_est.saturating_add(right_est),
            JoinType::Cross => left_est.saturating_mul(right_est),
            JoinType::Semi => left_est.min(right_est),
            JoinType::Anti => left_est,
        })
    }

    /// Display name used in plans, including the join type.
    fn name(&self) -> &str {
        match self.join_type {
            JoinType::Inner => "NestedLoop (INNER)",
            JoinType::Left => "NestedLoop (LEFT)",
            JoinType::Right => "NestedLoop (RIGHT)",
            JoinType::Full => "NestedLoop (FULL)",
            JoinType::Cross => "NestedLoop (CROSS)",
            JoinType::Semi => "NestedLoop (SEMI)",
            JoinType::Anti => "NestedLoop (ANTI)",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::operator::MaterializedOperator;
    use crate::parser::ast::{Identifier, InfixExpression};
    use crate::parser::token::{Position, Token, TokenType};

    /// Converts nested integer vectors into rows of `Value::integer`.
    fn make_rows(data: Vec<Vec<i64>>) -> Vec<Row> {
        data.into_iter()
            .map(|vals| Row::from_values(vals.into_iter().map(Value::integer).collect()))
            .collect()
    }

    /// Wraps integer rows and column names in a `MaterializedOperator`.
    fn make_operator(data: Vec<Vec<i64>>, cols: Vec<&str>) -> Box<dyn Operator> {
        let rows = make_rows(data);
        let schema = cols.into_iter().map(ColumnInfo::new).collect();
        Box::new(MaterializedOperator::new(rows, schema))
    }

    /// Runs `op` to completion (open, drain, close) and returns every row.
    fn collect_results(op: &mut dyn Operator) -> Result<Vec<Row>> {
        let mut results = Vec::new();
        op.open()?;
        while let Some(row_ref) = op.next()? {
            results.push(row_ref.into_owned());
        }
        op.close()?;
        Ok(results)
    }

    /// Builds the expression `left_col = right_col`.
    fn make_eq_condition(left_col: &str, right_col: &str) -> Expression {
        Expression::Infix(InfixExpression::new(
            Token::new(TokenType::Operator, "=", Position::default()),
            Box::new(Expression::Identifier(Identifier::new(
                Token::new(TokenType::Identifier, left_col, Position::default()),
                left_col.to_string(),
            ))),
            "=".to_string(),
            Box::new(Expression::Identifier(Identifier::new(
                Token::new(TokenType::Identifier, right_col, Position::default()),
                right_col.to_string(),
            ))),
        ))
    }

    /// True when column `idx` of `row` is the integer `expected`.
    fn int_at(row: &Row, idx: usize, expected: i64) -> bool {
        row.get(idx) == Some(&Value::integer(expected))
    }

    /// True when column `idx` of `row` exists and is NULL.
    fn null_at(row: &Row, idx: usize) -> bool {
        row.get(idx).map_or(false, |v| v.is_null())
    }

    #[test]
    fn test_inner_nested_loop() {
        let left = make_operator(
            vec![vec![1, 10], vec![2, 20], vec![3, 30]],
            vec!["left_id", "value"],
        );
        let right = make_operator(vec![vec![1, 100], vec![3, 300]], vec!["right_id", "data"]);
        let condition = make_eq_condition("left_id", "right_id");
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Inner, Some(condition));
        let results = collect_results(&mut join).unwrap();
        assert_eq!(results.len(), 2);
        // Matches come out in left order, each paired with its right row.
        assert!(int_at(&results[0], 0, 1) && int_at(&results[0], 3, 100));
        assert!(int_at(&results[1], 0, 3) && int_at(&results[1], 3, 300));
    }

    #[test]
    fn test_cross_join() {
        let left = make_operator(vec![vec![1], vec![2]], vec!["a"]);
        let right = make_operator(vec![vec![10], vec![20]], vec!["b"]);
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Cross, None);
        let results = collect_results(&mut join).unwrap();
        assert_eq!(results.len(), 4);
        assert!(int_at(&results[0], 0, 1) && int_at(&results[0], 1, 10));
        assert!(int_at(&results[3], 0, 2) && int_at(&results[3], 1, 20));
    }

    #[test]
    fn test_left_nested_loop() {
        let left = make_operator(
            vec![vec![1, 10], vec![2, 20], vec![3, 30]],
            vec!["left_id", "value"],
        );
        let right = make_operator(vec![vec![1, 100]], vec!["right_id", "data"]);
        let condition = make_eq_condition("left_id", "right_id");
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Left, Some(condition));
        let results = collect_results(&mut join).unwrap();
        assert_eq!(results.len(), 3);
        let row2 = results.iter().find(|r| int_at(r, 0, 2)).unwrap();
        assert!(null_at(row2, 2));
        assert!(null_at(row2, 3));
    }

    #[test]
    fn test_left_join_with_empty_right() {
        let left = make_operator(vec![vec![1], vec![2]], vec!["a"]);
        let right = make_operator(vec![], vec!["b"]);
        let condition = make_eq_condition("a", "b");
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Left, Some(condition));
        let results = collect_results(&mut join).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| null_at(r, 1)));
    }

    #[test]
    fn test_right_nested_loop() {
        let left = make_operator(vec![vec![1, 10]], vec!["left_id", "value"]);
        let right = make_operator(
            vec![vec![1, 100], vec![2, 200], vec![3, 300]],
            vec!["right_id", "data"],
        );
        let condition = make_eq_condition("left_id", "right_id");
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Right, Some(condition));
        let results = collect_results(&mut join).unwrap();
        assert_eq!(results.len(), 3);
        let unmatched = results.iter().find(|r| int_at(r, 2, 2)).unwrap();
        assert!(null_at(unmatched, 0));
        assert!(null_at(unmatched, 1));
    }

    #[test]
    fn test_full_nested_loop() {
        let left = make_operator(vec![vec![1, 10], vec![2, 20]], vec!["left_id", "value"]);
        let right = make_operator(vec![vec![1, 100], vec![3, 300]], vec!["right_id", "data"]);
        let condition = make_eq_condition("left_id", "right_id");
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Full, Some(condition));
        let results = collect_results(&mut join).unwrap();
        assert_eq!(results.len(), 3);
        // Matched pair, then the unmatched left row, then the unmatched right row.
        assert!(int_at(&results[0], 0, 1) && int_at(&results[0], 2, 1));
        assert!(int_at(&results[1], 0, 2) && null_at(&results[1], 2));
        assert!(null_at(&results[2], 0) && int_at(&results[2], 2, 3));
    }

    #[test]
    fn test_next_before_open_errors() {
        let left = make_operator(vec![vec![1]], vec!["a"]);
        let right = make_operator(vec![vec![1]], vec!["b"]);
        let mut join = NestedLoopJoinOperator::new(left, right, JoinType::Cross, None);
        assert!(join.next().is_err());
    }

    #[test]
    fn test_name_reports_join_type() {
        let left = make_operator(vec![vec![1]], vec!["a"]);
        let right = make_operator(vec![vec![1]], vec!["b"]);
        let join = NestedLoopJoinOperator::new(left, right, JoinType::Full, None);
        assert_eq!(join.name(), "NestedLoop (FULL)");
    }
}