use std::sync::Arc;
use crate::database::MoteDB;
use crate::types::{Value, SqlRow};
use crate::error::Result;
/// Equi-join executor that, for each outer row, fetches matching rows of an
/// inner table from a [`MoteDB`] — through a column lookup when one appears to
/// be available, otherwise through a bounded row-id scan.
pub struct IndexNestedLoopJoin {
    /// Shared database handle used for index lookups, row reads and schema reads.
    db: Arc<MoteDB>,
}
impl IndexNestedLoopJoin {
    /// Upper bound on the number of inner-table row ids probed by the fallback
    /// scan when no usable column lookup is available.
    const MAX_SCAN_ROWS: u64 = 100_000;

    /// Creates a join executor backed by `db`.
    pub fn new(db: Arc<MoteDB>) -> Self {
        Self { db }
    }

    /// Joins `outer_rows` with `inner_table` on equality of `join_column`.
    ///
    /// The join key of a row is read from the column named exactly
    /// `join_column`. If the row has no such column, it is read from the
    /// table-qualified name `"{table}.{join_column}"`. Outer rows with no join
    /// key produce no output.
    ///
    /// If a lookup on `inner_table.join_column` succeeds, each outer key is
    /// resolved through `query_by_column`. Otherwise the join falls back to
    /// scanning the inner table. Every match yields the outer row merged with
    /// the inner row, whose columns are named `"{inner_table}.{col}"`.
    /// Results follow outer-row order.
    ///
    /// # Errors
    ///
    /// Returns an error if the inner table's schema cannot be read while
    /// converting a matched inner row. Failed per-key lookups and unreadable
    /// inner rows are skipped (best effort) rather than reported.
    pub fn execute(
        &self,
        outer_table: &str,
        inner_table: &str,
        join_column: &str,
        outer_rows: Vec<SqlRow>,
    ) -> Result<Vec<SqlRow>> {
        if !self.check_index_exists(inner_table, join_column) {
            return self.nested_loop_join(outer_rows, outer_table, inner_table, join_column);
        }

        let mut results = Vec::with_capacity(outer_rows.len());
        for outer_row in outer_rows {
            let key = match Self::column_value(&outer_row, outer_table, join_column) {
                Some(key) => key,
                None => continue,
            };
            // Best effort: a failed lookup for one key skips that outer row
            // instead of aborting the whole join.
            let inner_row_ids = match self.index_lookup(inner_table, join_column, key) {
                Ok(ids) => ids,
                Err(_) => continue,
            };
            for row_id in inner_row_ids {
                // Missing rows and read errors are skipped, as above.
                if let Ok(Some(inner_row_data)) = self.db.get_table_row(inner_table, row_id) {
                    let inner_row = self.vec_to_sql_row(&inner_row_data, inner_table)?;
                    results.push(Self::merge_rows(&outer_row, &inner_row));
                }
            }
        }
        Ok(results)
    }

    /// Reports whether `table_name.column_name` can be queried by column.
    ///
    /// This is a probe: it issues a lookup for `Value::Integer(0)` and treats
    /// any `Ok` as "index present".
    /// NOTE(review): this assumes `query_by_column` errors when no index
    /// exists. A probe that errors for another reason (e.g. a type mismatch on
    /// a non-integer column) would presumably also report "no index". Confirm
    /// against MoteDB.
    fn check_index_exists(&self, table_name: &str, column_name: &str) -> bool {
        self.db.query_by_column(table_name, column_name, &Value::Integer(0)).is_ok()
    }

    /// Returns the row ids of `table_name` whose `column_name` equals `key`.
    fn index_lookup(&self, table_name: &str, column_name: &str, key: &Value) -> Result<Vec<u64>> {
        self.db.query_by_column(table_name, column_name, key)
    }

    /// Fallback join used when no column lookup is available.
    ///
    /// The inner table is read once (see `scan_table`) and then compared in
    /// memory against every outer row. This avoids re-reading it from the
    /// database once per outer row.
    fn nested_loop_join(
        &self,
        outer_rows: Vec<SqlRow>,
        outer_table: &str,
        inner_table: &str,
        join_column: &str,
    ) -> Result<Vec<SqlRow>> {
        // With no outer rows nothing can match, so don't read the inner table.
        if outer_rows.is_empty() {
            return Ok(Vec::new());
        }

        let inner_rows = self.scan_table(inner_table)?;
        let mut results = Vec::with_capacity(outer_rows.len() * 2);
        for outer_row in &outer_rows {
            let outer_key = match Self::column_value(outer_row, outer_table, join_column) {
                Some(key) => key,
                None => continue,
            };
            for inner_row in &inner_rows {
                if Self::column_value(inner_row, inner_table, join_column) == Some(outer_key) {
                    results.push(Self::merge_rows(outer_row, inner_row));
                }
            }
        }
        Ok(results)
    }

    /// Reads rows `0, 1, 2, ...` of `table_name` as [`SqlRow`]s.
    ///
    /// Reading stops at the first missing row, at the first read error, or
    /// after [`Self::MAX_SCAN_ROWS`] ids, whichever comes first. The schema is
    /// fetched once, and only if at least one row was read.
    /// NOTE(review): stopping at the first missing id assumes row ids are dense
    /// from 0. A deleted row would hide every row after it. Confirm against
    /// MoteDB's row-id allocation.
    fn scan_table(&self, table_name: &str) -> Result<Vec<SqlRow>> {
        let mut raw_rows = Vec::new();
        for row_id in 0..Self::MAX_SCAN_ROWS {
            match self.db.get_table_row(table_name, row_id) {
                Ok(Some(values)) => raw_rows.push(values),
                Ok(None) | Err(_) => break,
            }
        }
        if raw_rows.is_empty() {
            return Ok(Vec::new());
        }

        let columns = self.qualified_columns(table_name)?;
        Ok(raw_rows
            .iter()
            .map(|values| Self::build_row(&columns, values))
            .collect())
    }

    /// Returns the schema's column names for `table_name`, in schema order,
    /// each qualified as `"{table_name}.{column}"`.
    fn qualified_columns(&self, table_name: &str) -> Result<Vec<String>> {
        let schema = self.db.get_table_schema(table_name)?;
        Ok(schema
            .columns
            .iter()
            .map(|col| format!("{}.{}", table_name, col.name))
            .collect())
    }

    /// Pairs `columns` with `values` positionally. Surplus entries on either
    /// side are ignored.
    fn build_row(columns: &[String], values: &[Value]) -> SqlRow {
        let mut row = SqlRow::new();
        for (name, value) in columns.iter().zip(values) {
            row.insert(name.clone(), value.clone());
        }
        row
    }

    /// Converts raw row `values` of `table_name` into a [`SqlRow`] keyed by
    /// `"{table_name}.{column}"`.
    fn vec_to_sql_row(&self, values: &[Value], table_name: &str) -> Result<SqlRow> {
        let columns = self.qualified_columns(table_name)?;
        Ok(Self::build_row(&columns, values))
    }

    /// Looks up `column_name` in `row`, first by its exact name and then by its
    /// table-qualified name `"{table_name}.{column_name}"`.
    fn column_value<'a>(
        row: &'a SqlRow,
        table_name: &str,
        column_name: &str,
    ) -> Option<&'a Value> {
        row.get(column_name)
            .or_else(|| row.get(format!("{}.{}", table_name, column_name).as_str()))
    }

    /// Returns a new row containing every column of `outer_row` followed by
    /// every column of `inner_row`.
    ///
    /// Inner columns are inserted last. Assuming `SqlRow::insert` replaces
    /// existing keys (as map inserts do), the inner value wins when both rows
    /// share a column name.
    fn merge_rows(outer_row: &SqlRow, inner_row: &SqlRow) -> SqlRow {
        let mut merged = SqlRow::with_capacity(outer_row.len() + inner_row.len());
        for (col, val) in outer_row.iter() {
            merged.insert(col.clone(), val.clone());
        }
        for (col, val) in inner_row.iter() {
            merged.insert(col.clone(), val.clone());
        }
        merged
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that need a live `MoteDB` (index path, scan path) still need a
    // fixture. These cover the pure row-merging helper.

    #[test]
    fn merge_rows_combines_columns_from_both_sides() {
        let mut outer = SqlRow::new();
        outer.insert("orders.id".to_string(), Value::Integer(1));
        let mut inner = SqlRow::new();
        inner.insert("customers.id".to_string(), Value::Integer(2));

        let merged = IndexNestedLoopJoin::merge_rows(&outer, &inner);

        assert_eq!(merged.len(), 2);
        assert!(merged.get("orders.id") == Some(&Value::Integer(1)));
        assert!(merged.get("customers.id") == Some(&Value::Integer(2)));
    }

    #[test]
    fn merge_rows_prefers_inner_value_on_key_collision() {
        let mut outer = SqlRow::new();
        outer.insert("id".to_string(), Value::Integer(1));
        let mut inner = SqlRow::new();
        inner.insert("id".to_string(), Value::Integer(2));

        let merged = IndexNestedLoopJoin::merge_rows(&outer, &inner);

        assert_eq!(merged.len(), 1);
        assert!(merged.get("id") == Some(&Value::Integer(2)));
    }
}