use crate::column_store::ColumnStore;
use crate::error::{Error, Result};
use crate::point::{Point, SearchResult};
use crate::velesql::{ColumnRef, JoinClause, JoinCondition, JoinType};
use std::collections::{HashMap, HashSet};
/// A search result paired with the column values attached to it by a JOIN.
#[derive(Debug, Clone)]
pub struct JoinedResult {
    /// The search result this joined row is built around.
    pub search_result: SearchResult,
    /// Column name to JSON value mapping for the joined row.
    pub column_data: HashMap<String, serde_json::Value>,
}

impl JoinedResult {
    /// Bundles `search_result` together with its joined `column_data`.
    #[must_use]
    pub fn new(
        search_result: SearchResult,
        column_data: HashMap<String, serde_json::Value>,
    ) -> Self {
        JoinedResult { search_result, column_data }
    }
}
/// Key counts up to this value are looked up in a single batch.
const SMALL_BATCH_THRESHOLD: usize = 100;
/// Key counts up to this value (and above the small threshold) use `MEDIUM_BATCH_SIZE`.
const MEDIUM_BATCH_THRESHOLD: usize = 10_000;
/// Batch size used for medium-sized key sets.
const MEDIUM_BATCH_SIZE: usize = 1_000;
/// Batch size used for key sets larger than `MEDIUM_BATCH_THRESHOLD`.
const LARGE_BATCH_SIZE: usize = 5_000;

/// Chooses how many join keys to look up per batch for `key_count` keys.
///
/// * `key_count <= 100`: the key count itself, but never less than 1.
/// * `key_count <= 10_000`: 1 000.
/// * otherwise: 5 000.
///
/// The result is always non-zero, so it is safe to pass to `slice::chunks`.
#[must_use]
pub fn adaptive_batch_size(key_count: usize) -> usize {
    if key_count <= SMALL_BATCH_THRESHOLD {
        // A zero-sized batch would be meaningless (and would panic in `chunks`).
        key_count.max(1)
    } else if key_count <= MEDIUM_BATCH_THRESHOLD {
        MEDIUM_BATCH_SIZE
    } else {
        LARGE_BATCH_SIZE
    }
}
#[must_use]
pub fn extract_join_keys(results: &[SearchResult], condition: &JoinCondition) -> Vec<(usize, i64)> {
let key_column = &condition.right.column;
results
.iter()
.enumerate()
.filter_map(|(idx, r)| {
r.point
.payload
.as_ref()
.and_then(|payload| {
payload.get(key_column).and_then(|v| {
v.as_i64().or_else(|| {
if key_column == "id" {
i64::try_from(r.point.id).ok()
} else {
None
}
})
})
})
.or_else(|| {
if key_column == "id" {
i64::try_from(r.point.id).ok()
} else {
None
}
})
.map(|key| (idx, key))
})
.collect()
}
/// Joins `results` (the left side) against `column_store` (the right side)
/// as described by `join`.
///
/// Join keys are extracted from the left rows, looked up in the column store
/// in batches, and combined into [`JoinedResult`]s. For `Left`/`Full` joins,
/// left rows without a match (or without a join key) are kept with all-null
/// column data. For `Right`/`Full` joins, right rows that no left row matched
/// are appended with a synthetic search result.
///
/// # Errors
///
/// Returns `Error::Query` when the join has no usable `ON`/`USING` condition,
/// when the column store has no primary key, or when the join column is not
/// that primary key.
pub fn execute_join(
    results: &[SearchResult],
    join: &JoinClause,
    column_store: &ColumnStore,
) -> Result<Vec<JoinedResult>> {
    let condition = validate_join_condition(join, column_store)?;
    let join_keys = extract_join_keys(results, &condition);

    // With no join keys nothing can match. That only means "empty output" for
    // joins that drop unmatched rows; outer joins must still emit the
    // unmatched left and/or right rows below.
    let preserves_unmatched = matches!(
        join.join_type,
        JoinType::Left | JoinType::Right | JoinType::Full
    );
    if join_keys.is_empty() && !preserves_unmatched {
        return Ok(Vec::new());
    }

    let batch_size = adaptive_batch_size(join_keys.len());
    let null_row_data = build_null_row_data(column_store);
    // Tracks left rows already emitted by the batch pass so they are not
    // emitted a second time by `append_unmatched_left_rows`.
    let mut matched_left_indices = vec![false; results.len()];
    let mut matched_right_pks: HashSet<i64> = HashSet::with_capacity(join_keys.len());

    let mut joined_results = process_join_batches(
        results,
        &join_keys,
        batch_size,
        join.join_type,
        column_store,
        &null_row_data,
        &mut matched_left_indices,
        &mut matched_right_pks,
    );

    if matches!(join.join_type, JoinType::Right | JoinType::Full) {
        append_unmatched_right_rows(
            column_store,
            &condition.left.column,
            &matched_right_pks,
            &mut joined_results,
        );
    }
    if matches!(join.join_type, JoinType::Left | JoinType::Full) {
        append_unmatched_left_rows(
            results,
            &matched_left_indices,
            &null_row_data,
            &mut joined_results,
        );
    }

    Ok(joined_results)
}
/// Resolves the join condition and checks that it targets the store's primary key.
///
/// The checks run in this order, each producing an `Error::Query`:
/// 1. the join must resolve to a condition (`ON`, or `USING` with one column);
/// 2. the column store must report a primary key column;
/// 3. the join-side column (`condition.left`) must equal that primary key.
///
/// On success the resolved (normalized) condition is returned.
fn validate_join_condition(join: &JoinClause, column_store: &ColumnStore) -> Result<JoinCondition> {
    let Some(condition) = resolve_join_condition(join) else {
        return Err(Error::Query(format!(
            "JOIN on table '{}' must use ON condition or USING(single_column).",
            join.table
        )));
    };

    let Some(pk_column) = column_store.primary_key_column() else {
        return Err(Error::Query(format!(
            "JOIN target '{}' has no primary key configured.",
            join.table
        )));
    };

    // Lookups go through the primary-key index, so no other column can be joined on.
    let join_column = &condition.left.column;
    if join_column != pk_column {
        return Err(Error::Query(format!(
            "JOIN on table '{}' requires primary key '{}', got '{}'.",
            join.table, pk_column, join_column
        )));
    }

    Ok(condition)
}
/// Looks up join keys in `column_store` batch by batch and builds the joined rows.
///
/// For each `(result index, pk)` in `join_keys`:
/// * if the store has a row for `pk`, the left result is emitted with that
///   row's data, and `pk` is recorded in `matched_right_pks`;
/// * otherwise, for `Left`/`Full` joins only, the left result is emitted with
///   `null_row_data`.
///
/// In both emitting cases `matched_left_indices[result index]` is set, i.e.
/// the flag means "already emitted here", not "found a right-side match".
///
/// `batch_size` must be non-zero because it is passed to `slice::chunks`.
// Private helper: the arguments are the working state of `execute_join`,
// passed individually rather than through a dedicated context struct.
#[allow(clippy::too_many_arguments)]
fn process_join_batches(
    results: &[SearchResult],
    join_keys: &[(usize, i64)],
    batch_size: usize,
    join_type: JoinType,
    column_store: &ColumnStore,
    null_row_data: &HashMap<String, serde_json::Value>,
    matched_left_indices: &mut [bool],
    matched_right_pks: &mut HashSet<i64>,
) -> Vec<JoinedResult> {
    let keep_unmatched_left = matches!(join_type, JoinType::Left | JoinType::Full);
    let mut joined_results = Vec::with_capacity(join_keys.len());

    for chunk in join_keys.chunks(batch_size) {
        let pks: Vec<i64> = chunk.iter().map(|&(_, pk)| pk).collect();
        let row_map = batch_get_rows(column_store, &pks);

        for &(result_idx, pk) in chunk {
            let matched_row = row_map.get(&pk);
            let column_data = match matched_row {
                Some(row) => row.clone(),
                None if keep_unmatched_left => null_row_data.clone(),
                // Unmatched row in a join that does not preserve the left side.
                None => continue,
            };

            joined_results.push(JoinedResult::new(
                results[result_idx].clone(),
                column_data,
            ));
            matched_left_indices[result_idx] = true;
            if matched_row.is_some() {
                matched_right_pks.insert(pk);
            }
        }
    }

    joined_results
}
/// Appends every live right-side row whose key was not matched by a left row.
///
/// Each appended row carries a synthetic search result: a metadata-only point
/// whose id is the row's key, with an empty JSON object payload and score `0.0`.
///
/// Rows are silently skipped when the `join_column` value is missing, is not
/// an `i64`, is already in `matched_right_pks`, or is negative (it cannot be
/// converted to a `u64` point id).
fn append_unmatched_right_rows(
    column_store: &ColumnStore,
    join_column: &str,
    matched_right_pks: &HashSet<i64>,
    joined_results: &mut Vec<JoinedResult>,
) {
    let unmatched = column_store.live_row_indices().filter_map(|row_idx| {
        let pk = column_store
            .get_value_as_json(join_column, row_idx)?
            .as_i64()?;
        if matched_right_pks.contains(&pk) {
            return None;
        }
        let point_id = u64::try_from(pk).ok()?;

        let row_data = row_as_json_map(column_store, row_idx);
        let placeholder = Point::metadata_only(point_id, serde_json::json!({}));
        Some(JoinedResult::new(
            SearchResult::new(placeholder, 0.0),
            row_data,
        ))
    });
    joined_results.extend(unmatched);
}
/// Appends, with all-null column data, every left row not yet flagged in
/// `matched_left_indices`.
///
/// `matched_left_indices` is indexed by position in `results`; indexing
/// panics if it is shorter than `results`.
fn append_unmatched_left_rows(
    results: &[SearchResult],
    matched_left_indices: &[bool],
    null_row_data: &HashMap<String, serde_json::Value>,
    joined_results: &mut Vec<JoinedResult>,
) {
    let pending = results
        .iter()
        .enumerate()
        .filter(|(idx, _)| !matched_left_indices[*idx])
        .map(|(_, left_result)| JoinedResult::new(left_result.clone(), null_row_data.clone()));
    joined_results.extend(pending);
}
/// Derives the effective join condition from a JOIN clause.
///
/// * An explicit `ON` condition wins and is normalized so the joined table's
///   column is on the left.
/// * Otherwise a `USING` list with exactly one column is expanded into an
///   equivalent condition: the left side is qualified with the joined table,
///   the right side is unqualified.
/// * Anything else (no condition, or `USING` with zero or several columns)
///   yields `None`.
fn resolve_join_condition(join: &JoinClause) -> Option<JoinCondition> {
    if let Some(explicit) = &join.condition {
        return Some(normalize_join_condition(explicit, join));
    }

    let using_columns = join.using_columns.as_ref()?;
    (using_columns.len() == 1).then(|| {
        let shared_column = using_columns[0].clone();
        JoinCondition {
            left: ColumnRef {
                table: Some(join.table.clone()),
                column: shared_column.clone(),
            },
            right: ColumnRef {
                table: None,
                column: shared_column,
            },
        }
    })
}
fn normalize_join_condition(condition: &JoinCondition, join: &JoinClause) -> JoinCondition {
let is_join_side = |table: Option<&str>| {
table.is_some_and(|t| t == join.table || join.alias.as_deref().is_some_and(|a| a == t))
};
if is_join_side(condition.left.table.as_deref()) {
return condition.clone();
}
if is_join_side(condition.right.table.as_deref()) {
return JoinCondition {
left: condition.right.clone(),
right: condition.left.clone(),
};
}
condition.clone()
}
/// Materializes one column-store row as a column-name to JSON-value map.
///
/// Columns for which the store returns no value at `row_idx` are left out of
/// the map rather than being inserted as `null`.
fn row_as_json_map(
    column_store: &ColumnStore,
    row_idx: usize,
) -> HashMap<String, serde_json::Value> {
    column_store
        .column_names()
        .filter_map(|col_name| {
            column_store
                .get_value_as_json(col_name, row_idx)
                .map(|value| (col_name.to_string(), value))
        })
        .collect()
}
/// Fetches the rows for a batch of primary keys.
///
/// Returns a map from primary key to that row's column data. Keys the store
/// cannot resolve to a row index are simply absent from the map.
fn batch_get_rows(
    column_store: &ColumnStore,
    pks: &[i64],
) -> HashMap<i64, HashMap<String, serde_json::Value>> {
    let mut rows_by_pk = HashMap::with_capacity(pks.len());
    rows_by_pk.extend(pks.iter().filter_map(|&pk| {
        column_store
            .get_row_idx_by_pk(pk)
            .map(|row_idx| (pk, row_as_json_map(column_store, row_idx)))
    }));
    rows_by_pk
}
/// Builds the placeholder row used for unmatched left rows: every column
/// name of the store mapped to JSON `null`.
fn build_null_row_data(column_store: &ColumnStore) -> HashMap<String, serde_json::Value> {
    let mut null_row = HashMap::new();
    for name in column_store.column_names() {
        null_row.insert(name.to_string(), serde_json::Value::Null);
    }
    null_row
}
/// Flattens joined rows back into plain search results.
///
/// For each row, the joined column data is merged into the point's payload
/// object; on a key collision the joined column value overwrites the payload
/// value. A missing payload, or one that is not a JSON object, is replaced by
/// an object containing only the joined columns. The resulting payload is
/// always `Some(Object)`.
#[must_use]
pub fn joined_to_search_results(joined: Vec<JoinedResult>) -> Vec<SearchResult> {
    joined
        .into_iter()
        .map(|jr| {
            let mut result = jr.search_result;
            // The payload is owned here, so move the map out instead of
            // cloning it through `as_object()`.
            let mut payload = match result.point.payload.take() {
                Some(serde_json::Value::Object(map)) => map,
                _ => serde_json::Map::new(),
            };
            payload.extend(jr.column_data);
            result.point.payload = Some(serde_json::Value::Object(payload));
            result
        })
        .collect()
}