use crate::algebra::{Solution, Variable};
use crate::cost_model::CostModel;
use crate::executor::parallel_optimized::{CacheFriendlyHashJoin, SortMergeJoin};
use anyhow::Result;
use std::collections::HashSet;
/// Chooses a join strategy for a pair of solution sets and executes it.
///
/// Owns one instance of each join executor plus the memory budget that
/// `select_join_algorithm` compares against the estimated memory requirement.
pub struct JoinAlgorithmSelector {
// Stored at construction but not read by any method in this impl, hence the
// `dead_code` allowance. The cost estimates used for validation are computed
// by the `estimate_*_cost` methods instead.
#[allow(dead_code)]
cost_model: CostModel,
// Executor used for both the `HashJoin` and `IndexJoin` strategies.
hash_join: CacheFriendlyHashJoin,
// Executor used for the `SortMergeJoin` strategy.
sort_merge_join: SortMergeJoin,
// Budget compared with `JoinCharacteristics::memory_requirement`; exceeding it
// steers the heuristic towards sort-merge. Also handed to `SortMergeJoin::new`.
memory_threshold: usize,
}
impl JoinAlgorithmSelector {
/// Builds a selector around `cost_model` with the given memory budget.
///
/// `memory_threshold` is kept for algorithm selection and is also passed to
/// `SortMergeJoin::new`. The hash join executor is always created with the
/// fixed argument `16`.
pub fn new(cost_model: CostModel, memory_threshold: usize) -> Self {
    // NOTE(review): 16 is presumably a partition count — confirm against
    // `CacheFriendlyHashJoin::new`.
    let hash_join = CacheFriendlyHashJoin::new(16);
    let sort_merge_join = SortMergeJoin::new(memory_threshold);
    Self {
        cost_model,
        hash_join,
        sort_merge_join,
        memory_threshold,
    }
}
/// Joins `left_solutions` with `right_solutions` on `join_variables` using the
/// algorithm selected for these inputs.
///
/// Returns the joined solutions together with [`JoinExecutionStats`]
/// describing the run. The reported execution time covers input analysis and
/// algorithm selection as well as the join itself.
///
/// # Errors
///
/// Propagates any error from algorithm selection or from the chosen join
/// executor.
pub fn execute_optimal_join(
    &mut self,
    left_solutions: Vec<Solution>,
    right_solutions: Vec<Solution>,
    join_variables: &[Variable],
) -> Result<(Vec<Solution>, JoinExecutionStats)> {
    let timer = std::time::Instant::now();

    let characteristics =
        self.analyze_join_characteristics(&left_solutions, &right_solutions, join_variables);
    let algorithm = self.select_join_algorithm(&characteristics)?;

    let joined = match algorithm {
        // There is no dedicated index-join executor: both variants run the hash join.
        OptimalJoinAlgorithm::HashJoin | OptimalJoinAlgorithm::IndexJoin => self
            .hash_join
            .join_parallel(left_solutions, right_solutions, join_variables)?,
        OptimalJoinAlgorithm::SortMergeJoin => self
            .sort_merge_join
            .join(left_solutions, right_solutions, join_variables)?,
        OptimalJoinAlgorithm::NestedLoopJoin => {
            self.execute_nested_loop_join(left_solutions, right_solutions, join_variables)?
        }
    };
    let execution_time = timer.elapsed();

    let left_cardinality = characteristics.left_cardinality;
    let right_cardinality = characteristics.right_cardinality;
    let output_cardinality = joined.len();
    // Clamp the denominator so an empty input gives selectivity 0.0 instead of NaN.
    let possible_pairs = (left_cardinality as f64 * right_cardinality as f64).max(1.0);

    let stats = JoinExecutionStats {
        algorithm_used: algorithm,
        execution_time,
        input_cardinalities: (left_cardinality, right_cardinality),
        output_cardinality,
        memory_used: self.estimate_memory_usage(&joined),
        join_selectivity: output_cardinality as f64 / possible_pairs,
    };
    Ok((joined, stats))
}
/// Collects the statistics that drive algorithm selection and cost estimation.
///
/// * Cardinalities are the input lengths.
/// * Distinct counts come from `estimate_distinct_values`.
/// * Selectivity is `1 / max(left_distinct, right_distinct)`, or `0.1` when
///   either side has no bound join term.
/// * Sortedness comes from the sampled check in `is_sorted_by_join_keys`.
/// * The memory requirement assumes a flat 100 units per input solution.
fn analyze_join_characteristics(
    &self,
    left_solutions: &[Solution],
    right_solutions: &[Solution],
    join_variables: &[Variable],
) -> JoinCharacteristics {
    let left_cardinality = left_solutions.len();
    let right_cardinality = right_solutions.len();

    let left_distinct_values = self.estimate_distinct_values(left_solutions, join_variables);
    let right_distinct_values = self.estimate_distinct_values(right_solutions, join_variables);

    // Without any distinct join value on one side there is nothing to base an
    // estimate on, so fall back to a fixed guess.
    let estimated_selectivity = if left_distinct_values == 0 || right_distinct_values == 0 {
        0.1
    } else {
        1.0 / left_distinct_values.max(right_distinct_values) as f64
    };

    JoinCharacteristics {
        left_cardinality,
        right_cardinality,
        left_distinct_values,
        right_distinct_values,
        estimated_selectivity,
        left_sorted: self.is_sorted_by_join_keys(left_solutions, join_variables),
        right_sorted: self.is_sorted_by_join_keys(right_solutions, join_variables),
        memory_requirement: (left_cardinality + right_cardinality) * 100,
        join_variable_count: join_variables.len(),
    }
}
/// Picks a join algorithm heuristically, then lets the cost estimates veto it.
///
/// Heuristic, in priority order:
/// 1. either input has fewer than 1000 solutions → nested loop;
/// 2. both inputs look sorted, or the memory requirement exceeds the
///    threshold → sort-merge;
/// 3. otherwise → hash join.
///
/// The candidate is then passed to `validate_with_cost_model`, whose result
/// (or error) is returned unchanged.
fn select_join_algorithm(
    &mut self,
    join_info: &JoinCharacteristics,
) -> Result<OptimalJoinAlgorithm> {
    let either_side_small =
        join_info.left_cardinality < 1000 || join_info.right_cardinality < 1000;
    let both_sorted = join_info.left_sorted && join_info.right_sorted;
    let over_memory_budget = join_info.memory_requirement > self.memory_threshold;

    let candidate = if either_side_small {
        OptimalJoinAlgorithm::NestedLoopJoin
    } else if both_sorted || over_memory_budget {
        OptimalJoinAlgorithm::SortMergeJoin
    } else {
        // Selectivity does not affect the heuristic: every remaining case
        // uses the hash join.
        OptimalJoinAlgorithm::HashJoin
    };

    self.validate_with_cost_model(candidate, join_info)
}
/// Checks the heuristic `candidate` against the cost estimates and replaces it
/// when another algorithm is estimated to be more than 20% cheaper.
///
/// Hash join and sort-merge join are always considered as alternatives.
/// Nested loop is offered as an alternative only when both inputs have fewer
/// than 10000 solutions. The candidate's own cost is always estimated, so a
/// nested-loop candidate on a large input can be overridden. Previously its
/// cost silently defaulted to the optimum, which made an override impossible.
///
/// `IndexJoin` has no cost estimator and is therefore returned unchanged.
///
/// # Errors
///
/// Never fails at present; the `Result` is kept for interface stability.
fn validate_with_cost_model(
    &self,
    candidate: OptimalJoinAlgorithm,
    join_info: &JoinCharacteristics,
) -> Result<OptimalJoinAlgorithm> {
    // Relative saving required before the heuristic choice is overridden.
    const IMPROVEMENT_THRESHOLD: f64 = 0.20;
    // Nested loop is only proposed as an alternative below this input size.
    const NESTED_LOOP_MAX_CARDINALITY: usize = 10_000;

    let hash_join_cost = self.estimate_hash_join_cost(join_info);
    let sort_merge_cost = self.estimate_sort_merge_cost(join_info);
    let nested_loop_cost = self.estimate_nested_loop_cost(join_info);

    let candidate_cost = match candidate {
        OptimalJoinAlgorithm::HashJoin => hash_join_cost,
        OptimalJoinAlgorithm::SortMergeJoin => sort_merge_cost,
        OptimalJoinAlgorithm::NestedLoopJoin => nested_loop_cost,
        // No estimator exists for index joins, so there is nothing to compare.
        OptimalJoinAlgorithm::IndexJoin => return Ok(candidate),
    };

    let mut alternatives = vec![
        (OptimalJoinAlgorithm::HashJoin, hash_join_cost),
        (OptimalJoinAlgorithm::SortMergeJoin, sort_merge_cost),
    ];
    if join_info.left_cardinality < NESTED_LOOP_MAX_CARDINALITY
        && join_info.right_cardinality < NESTED_LOOP_MAX_CARDINALITY
    {
        alternatives.push((OptimalJoinAlgorithm::NestedLoopJoin, nested_loop_cost));
    }

    // Incomparable (NaN) costs are treated as equal; ties keep the earliest entry.
    let (optimal_algorithm, optimal_cost) = alternatives
        .into_iter()
        .min_by(|(_, cost_a), (_, cost_b)| {
            cost_a
                .partial_cmp(cost_b)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .unwrap_or((candidate, candidate_cost));

    if optimal_cost < candidate_cost * (1.0 - IMPROVEMENT_THRESHOLD) {
        tracing::debug!(
            "Cost model override: {:?} -> {:?} (cost: {:.2} -> {:.2}, {:.1}% improvement)",
            candidate,
            optimal_algorithm,
            candidate_cost,
            optimal_cost,
            (candidate_cost - optimal_cost) / candidate_cost * 100.0
        );
        Ok(optimal_algorithm)
    } else {
        Ok(candidate)
    }
}
/// Estimates the cost of a hash join.
///
/// The smaller input is charged as the build side and the larger as the probe
/// side, one unit per solution. The output is charged at
/// `left * right * selectivity`, with a floor of 1.
fn estimate_hash_join_cost(&self, join_info: &JoinCharacteristics) -> f64 {
    let (left, right) = (join_info.left_cardinality, join_info.right_cardinality);

    let build_cost = left.min(right) as f64;
    let probe_cost = left.max(right) as f64;
    let output_cost = (left as f64 * right as f64 * join_info.estimated_selectivity).max(1.0);

    build_cost + probe_cost + output_cost
}
/// Estimates the cost of a sort-merge join.
///
/// Each input that is not already sorted is charged `n * log2(n)`. The merge
/// is charged one unit per input solution, and the output is charged at
/// `left * right * selectivity`, with a floor of 1.
///
/// Inputs with at most one solution never incur a sort cost. This also avoids
/// `0 * log2(0)`, which is NaN and would poison every later cost comparison.
fn estimate_sort_merge_cost(&self, join_info: &JoinCharacteristics) -> f64 {
    // Cost of sorting `cardinality` solutions, or zero when no sort is needed.
    fn sort_cost(cardinality: usize, already_sorted: bool) -> f64 {
        if already_sorted || cardinality <= 1 {
            return 0.0;
        }
        let n = cardinality as f64;
        n * n.log2()
    }

    let left = join_info.left_cardinality;
    let right = join_info.right_cardinality;

    let left_sort_cost = sort_cost(left, join_info.left_sorted);
    let right_sort_cost = sort_cost(right, join_info.right_sorted);
    // Add as floats so two very large inputs cannot overflow `usize`.
    let merge_cost = left as f64 + right as f64;
    let output_cost = (left as f64 * right as f64 * join_info.estimated_selectivity).max(1.0);

    left_sort_cost + right_sort_cost + merge_cost + output_cost
}
/// Estimates the cost of a nested-loop join: one unit per compared pair plus
/// the expected output (`pairs * selectivity`).
fn estimate_nested_loop_cost(&self, join_info: &JoinCharacteristics) -> f64 {
    let pairs_compared = join_info.left_cardinality as f64 * join_info.right_cardinality as f64;
    pairs_compared + pairs_compared * join_info.estimated_selectivity
}
/// Counts the distinct terms bound to any join variable across `solutions`.
///
/// Terms from all join variables are pooled into a single set. For
/// multi-variable joins the result is therefore the number of distinct terms,
/// not the number of distinct key combinations.
fn estimate_distinct_values(
    &self,
    solutions: &[Solution],
    join_variables: &[Variable],
) -> usize {
    let mut seen_terms = HashSet::new();
    for binding in solutions.iter().flatten() {
        seen_terms.extend(
            join_variables
                .iter()
                .filter_map(|var| binding.get(var))
                .cloned(),
        );
    }
    seen_terms.len()
}
/// Heuristically checks whether `solutions` are ordered by the join keys.
///
/// Only a sample is inspected: between 10 and 100 evenly spaced positions,
/// each compared with the position one stride before it. A `true` result
/// therefore means "no inversion found in the sample", not a proof of order.
/// Inputs with at most one solution are always considered sorted.
///
/// The stride is at least 1. Previously inputs with 2 to 9 solutions produced
/// a stride of 0, so every sample compared a solution with itself and
/// unsorted input was reported as sorted.
fn is_sorted_by_join_keys(&self, solutions: &[Solution], join_variables: &[Variable]) -> bool {
    if solutions.len() <= 1 {
        return true;
    }

    let sample_size = (solutions.len() / 10).clamp(10, 100);
    // For short inputs the quotient is 0; fall back to checking adjacent pairs.
    let step = (solutions.len() / sample_size).max(1);

    // Visit step, 2*step, ... while in bounds, at most `sample_size - 1` times.
    (step..solutions.len())
        .step_by(step)
        .take(sample_size - 1)
        .all(|idx| {
            self.compare_solutions_by_join_key(
                &solutions[idx - step],
                &solutions[idx],
                join_variables,
            ) != std::cmp::Ordering::Greater
        })
}
/// Orders two solutions by their join keys.
///
/// Only the first binding of each solution is considered. A solution without
/// any binding sorts before one that has a binding. Join variables are
/// compared in order by the textual (`Display`) form of their terms, and an
/// unbound variable sorts before a bound one. The first non-equal comparison
/// decides the result.
fn compare_solutions_by_join_key(
    &self,
    left: &Solution,
    right: &Solution,
    join_variables: &[Variable],
) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let (l_binding, r_binding) = match (left.first(), right.first()) {
        (Some(l), Some(r)) => (l, r),
        (Some(_), None) => return Ordering::Greater,
        (None, Some(_)) => return Ordering::Less,
        (None, None) => return Ordering::Equal,
    };

    join_variables
        .iter()
        .map(|var| match (l_binding.get(var), r_binding.get(var)) {
            (Some(l_term), Some(r_term)) => l_term.to_string().cmp(&r_term.to_string()),
            (Some(_), None) => Ordering::Greater,
            (None, Some(_)) => Ordering::Less,
            (None, None) => Ordering::Equal,
        })
        // The iterator is lazy, so later variables are not formatted once a
        // difference has been found.
        .find(|ordering| *ordering != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
/// Joins the inputs by trying every left/right pair of solutions.
///
/// Each pair is passed to `try_merge_solutions`. Pairs that merge contribute
/// one output solution, in left-major order.
///
/// # Errors
///
/// Returns the first error reported by `try_merge_solutions`.
fn execute_nested_loop_join(
    &self,
    left_solutions: Vec<Solution>,
    right_solutions: Vec<Solution>,
    join_variables: &[Variable],
) -> Result<Vec<Solution>> {
    let mut joined = Vec::new();
    for outer in left_solutions.iter() {
        for inner in right_solutions.iter() {
            let merged = self.try_merge_solutions(outer, inner, join_variables)?;
            // `Option` iterates over zero or one item, so this pushes only on `Some`.
            joined.extend(merged);
        }
    }
    Ok(joined)
}
/// Merges every compatible pair of bindings from `left` and `right`.
///
/// Two bindings are compatible unless some join variable is bound on both
/// sides to different terms. A variable bound on only one side never causes a
/// conflict. The merged binding starts as a copy of the left binding and
/// gains every right-hand variable it does not already contain, so
/// left-hand values win.
///
/// Returns `Ok(None)` when no pair is compatible; otherwise the merged
/// bindings in left-major order. Never returns an error at present.
fn try_merge_solutions(
    &self,
    left: &Solution,
    right: &Solution,
    join_variables: &[Variable],
) -> Result<Option<Solution>> {
    let mut merged_rows = Vec::new();

    for left_binding in left {
        for right_binding in right {
            let conflicting = join_variables.iter().any(|var| {
                matches!(
                    (left_binding.get(var), right_binding.get(var)),
                    (Some(l_term), Some(r_term)) if l_term != r_term
                )
            });
            if conflicting {
                continue;
            }

            let mut merged = left_binding.clone();
            for (var, term) in right_binding {
                if !merged.contains_key(var) {
                    merged.insert(var.clone(), term.clone());
                }
            }
            merged_rows.push(merged);
        }
    }

    Ok(if merged_rows.is_empty() {
        None
    } else {
        Some(merged_rows)
    })
}
/// Rough memory estimate for a result set: a flat 1024 bytes per solution,
/// regardless of how many bindings each solution holds.
fn estimate_memory_usage(&self, solutions: &[Solution]) -> usize {
    const ASSUMED_BYTES_PER_SOLUTION: usize = 1024;
    solutions.len() * ASSUMED_BYTES_PER_SOLUTION
}
}
/// Join strategies that `JoinAlgorithmSelector` can report and execute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimalJoinAlgorithm {
/// Executed through `CacheFriendlyHashJoin::join_parallel`.
HashJoin,
/// Executed through `SortMergeJoin::join`.
SortMergeJoin,
/// Executed by comparing every left/right pair of solutions.
NestedLoopJoin,
/// Currently executed with the hash join executor; the selection heuristic
/// in this file never produces this variant.
IndexJoin,
}
/// Input statistics used to select a join algorithm and to estimate costs.
///
/// The field descriptions reflect how `analyze_join_characteristics` fills
/// them in; all fields are public, so other producers may differ.
#[derive(Debug, Clone)]
pub struct JoinCharacteristics {
/// Number of solutions on the left input.
pub left_cardinality: usize,
/// Number of solutions on the right input.
pub right_cardinality: usize,
/// Distinct terms bound to any join variable on the left input.
pub left_distinct_values: usize,
/// Distinct terms bound to any join variable on the right input.
pub right_distinct_values: usize,
/// `1 / max(left_distinct_values, right_distinct_values)`, or `0.1` when
/// either distinct count is zero.
pub estimated_selectivity: f64,
/// Whether a sampled check found the left input ordered by the join keys.
pub left_sorted: bool,
/// Whether a sampled check found the right input ordered by the join keys.
pub right_sorted: bool,
/// `(left_cardinality + right_cardinality) * 100`; compared with the
/// selector's memory threshold.
pub memory_requirement: usize,
/// Number of join variables.
pub join_variable_count: usize,
}
/// Outcome statistics reported by `JoinAlgorithmSelector::execute_optimal_join`.
#[derive(Debug, Clone)]
pub struct JoinExecutionStats {
/// Algorithm that was selected for the join.
pub algorithm_used: OptimalJoinAlgorithm,
/// Wall-clock time measured from before input analysis until the join returned.
pub execution_time: std::time::Duration,
/// `(left, right)` input sizes in solutions.
pub input_cardinalities: (usize, usize),
/// Number of solutions produced.
pub output_cardinality: usize,
/// Estimated size of the output: 1024 bytes per output solution.
pub memory_used: usize,
/// `output_cardinality / max(left * right, 1)`.
pub join_selectivity: f64,
}
impl JoinExecutionStats {
    /// Renders all statistics as a single human-readable line.
    ///
    /// The algorithm and duration use their `Debug` form; selectivity is
    /// printed with four decimal places.
    pub fn performance_summary(&self) -> String {
        let Self {
            algorithm_used,
            execution_time,
            input_cardinalities: (left_rows, right_rows),
            output_cardinality,
            memory_used,
            join_selectivity,
        } = self;

        format!(
            "Algorithm: {:?}, Time: {:?}, Input: ({}, {}), Output: {}, Selectivity: {:.4}, Memory: {} bytes",
            algorithm_used,
            execution_time,
            left_rows,
            right_rows,
            output_cardinality,
            join_selectivity,
            memory_used
        )
    }
}
/// Stateless helper offering rayon-based hashing and equi-join primitives
/// over `u64` keys. Only compiled when the `parallel` feature is enabled.
#[cfg(feature = "parallel")]
pub struct ParallelHashJoinAccelerator;
// `Default` simply delegates to `new`, so both construction paths stay in sync.
#[cfg(feature = "parallel")]
impl Default for ParallelHashJoinAccelerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "parallel")]
impl ParallelHashJoinAccelerator {
/// Creates the accelerator; the type carries no state.
pub fn new() -> Self {
Self
}
/// Hashes every key with a freshly created `DefaultHasher`, in parallel.
///
/// The output has one hash per input key, in input order. Never returns an
/// error at present.
pub fn compute_hashes_parallel(&self, keys: &[String]) -> Result<Vec<u64>> {
    use rayon::prelude::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // `String` hashes exactly like the `str` it derefs to.
    fn hash_key(key: &str) -> u64 {
        let mut state = DefaultHasher::new();
        key.hash(&mut state);
        state.finish()
    }

    Ok(keys.par_iter().map(|key| hash_key(key)).collect())
}
/// Finds every `(left_index, right_index)` pair whose keys are equal.
///
/// An index over `right_keys` is built sequentially. `left_keys` is then
/// probed in parallel in chunks of 64, and the per-chunk results are
/// concatenated in chunk order, so pairs come out ordered by left index.
///
/// NOTE(review): if the keys are hashes (e.g. from `compute_hashes_parallel`),
/// colliding hashes are reported as matches — callers presumably re-check the
/// real keys; confirm.
pub fn parallel_equi_join(
    &self,
    left_keys: &[u64],
    right_keys: &[u64],
) -> Result<Vec<(usize, usize)>> {
    use rayon::prelude::*;
    use std::collections::HashMap;
    const CHUNK_SIZE: usize = 64;

    let mut right_index: HashMap<u64, Vec<usize>> = HashMap::new();
    for (position, &key) in right_keys.iter().enumerate() {
        right_index.entry(key).or_default().push(position);
    }

    let per_chunk: Vec<Vec<(usize, usize)>> = left_keys
        .par_chunks(CHUNK_SIZE)
        .enumerate()
        .map(|(chunk_no, chunk)| {
            // Position of this chunk's first key within `left_keys`.
            let base = chunk_no * CHUNK_SIZE;
            chunk
                .iter()
                .enumerate()
                .filter_map(|(offset, key)| {
                    right_index.get(key).map(|hits| (base + offset, hits))
                })
                .flat_map(|(left_idx, hits)| {
                    hits.iter().map(move |&right_idx| (left_idx, right_idx))
                })
                .collect()
        })
        .collect();

    Ok(per_chunk.into_iter().flatten().collect())
}
pub fn parallel_partition_join(
&self,
left: Vec<(u64, usize)>,
right: Vec<(u64, usize)>,
num_partitions: usize,
) -> Result<Vec<(usize, usize)>> {
use rayon::prelude::*;
use std::sync::Arc;
let partition_mask = num_partitions - 1;
let mut left_partitions: Vec<Vec<(u64, usize)>> = vec![Vec::new(); num_partitions];
let mut right_partitions: Vec<Vec<(u64, usize)>> = vec![Vec::new(); num_partitions];
for (key, idx) in left {
let partition = (key as usize) & partition_mask;
left_partitions[partition].push((key, idx));
}
for (key, idx) in right {
let partition = (key as usize) & partition_mask;
right_partitions[partition].push((key, idx));
}
let right_partitions = Arc::new(right_partitions);
let matches: Vec<Vec<(usize, usize)>> = left_partitions
.into_par_iter()
.enumerate()
.map(|(p, left_partition)| {
let mut partition_matches = Vec::new();
let left_keys: Vec<u64> = left_partition.iter().map(|(k, _)| *k).collect();
let right_keys: Vec<u64> = right_partitions[p].iter().map(|(k, _)| *k).collect();
if let Ok(local_matches) = self.parallel_equi_join(&left_keys, &right_keys) {
for (l, r) in local_matches {
if l < left_partition.len() && r < right_partitions[p].len() {
partition_matches.push((left_partition[l].1, right_partitions[p][r].1));
}
}
}
partition_matches
})
.collect();
Ok(matches.into_iter().flatten().collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::algebra::{Binding, Term, Variable};
use crate::cost_model::{CostModel, CostModelConfig};
use oxirs_core::model::NamedNode;
// Smoke test: two single-row inputs that agree on `?x` produce output.
#[test]
fn test_join_algorithm_selection() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions = vec![create_test_solution(&var_x, "value1")];
    let right_solutions = vec![create_test_solution(&var_x, "value1")];
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("joining matching single-row inputs should succeed");

    assert!(!solutions.is_empty());
    println!("Join stats: {}", stats.performance_summary());
}
// Builds a one-binding solution mapping `variable` to the IRI
// `http://example.org/<value>`.
fn create_test_solution(variable: &Variable, value: &str) -> Solution {
    let iri = NamedNode::new_unchecked(format!("http://example.org/{value}"));
    let mut binding = Binding::new();
    binding.insert(variable.clone(), Term::Iri(iri));
    vec![binding]
}
// Builds a one-binding solution with two variables: `var_x` maps to
// `http://example.org/x/<val_x>` and `var_y` to `http://example.org/y/<val_y>`.
fn create_multi_var_solution(
    var_x: &Variable,
    val_x: &str,
    var_y: &Variable,
    val_y: &str,
) -> Solution {
    let iri_x = NamedNode::new_unchecked(format!("http://example.org/x/{val_x}"));
    let iri_y = NamedNode::new_unchecked(format!("http://example.org/y/{val_y}"));

    let mut binding = Binding::new();
    binding.insert(var_x.clone(), Term::Iri(iri_x));
    binding.insert(var_y.clone(), Term::Iri(iri_y));
    vec![binding]
}
// An empty left input must produce an empty result.
#[test]
fn test_empty_left_input() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions = Vec::new();
    let right_solutions = vec![create_test_solution(&var_x, "value1")];
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("join with an empty left input should succeed");

    assert!(solutions.is_empty());
    assert_eq!(stats.output_cardinality, 0);
}
// An empty right input must produce an empty result.
#[test]
fn test_empty_right_input() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions = vec![create_test_solution(&var_x, "value1")];
    let right_solutions = Vec::new();
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("join with an empty right input should succeed");

    assert!(solutions.is_empty());
    assert_eq!(stats.output_cardinality, 0);
}
// Inputs that bind `?x` to different IRIs must not join.
#[test]
fn test_no_matching_values() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions = vec![create_test_solution(&var_x, "value1")];
    let right_solutions = vec![create_test_solution(&var_x, "value2")];
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, _stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("join of disjoint inputs should succeed");

    assert!(
        solutions.is_empty(),
        "No matches should result in empty output"
    );
}
// Joining on two variables keeps only the row that agrees on both.
#[test]
fn test_multiple_join_variables() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let var_y = Variable::new("y").unwrap();
    let left_solutions = vec![
        create_multi_var_solution(&var_x, "a", &var_y, "1"),
        create_multi_var_solution(&var_x, "b", &var_y, "2"),
    ];
    let right_solutions = vec![
        create_multi_var_solution(&var_x, "a", &var_y, "1"),
        create_multi_var_solution(&var_x, "c", &var_y, "3"),
    ];
    let join_variables = vec![var_x, var_y];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("two-variable join should succeed");

    assert_eq!(solutions.len(), 1, "Should find one matching solution");
    assert!(stats.join_selectivity > 0.0);
}
// 100 x 100 inputs overlapping on value50..value99 must yield 50 rows.
#[test]
fn test_large_dataset_join() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions: Vec<_> = (0..100)
        .map(|i| create_test_solution(&var_x, &format!("value{i}")))
        .collect();
    let right_solutions: Vec<_> = (50..150)
        .map(|i| create_test_solution(&var_x, &format!("value{i}")))
        .collect();
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("join of overlapping 100-row inputs should succeed");

    assert_eq!(solutions.len(), 50);
    assert!(
        !stats.execution_time.is_zero(),
        "Execution time should be measured"
    );
    println!("Large join stats: {}", stats.performance_summary());
}
// With 1000 rows per side the selector must not fall back to nested loop.
#[test]
fn test_hash_join_preferred_for_large_inputs() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 10 * 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions: Vec<_> = (0..1000)
        .map(|i| create_test_solution(&var_x, &format!("value{i}")))
        .collect();
    let right_solutions: Vec<_> = (0..1000)
        .map(|i| create_test_solution(&var_x, &format!("value{i}")))
        .collect();
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (_solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("join of 1000-row inputs should succeed");

    assert!(
        matches!(
            stats.algorithm_used,
            OptimalJoinAlgorithm::HashJoin | OptimalJoinAlgorithm::SortMergeJoin
        ),
        "Large inputs should use hash join or sort-merge join"
    );
}
// 5 matches out of 10 x 10 possible pairs gives a selectivity of 0.05.
#[test]
fn test_selectivity_calculation() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions: Vec<_> = (0..10)
        .map(|i| create_test_solution(&var_x, &format!("value{i}")))
        .collect();
    let right_solutions: Vec<_> = (5..15)
        .map(|i| create_test_solution(&var_x, &format!("value{i}")))
        .collect();
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("join of overlapping 10-row inputs should succeed");

    assert_eq!(solutions.len(), 5);
    assert!(
        (stats.join_selectivity - 0.05).abs() < 0.01,
        "Selectivity calculation incorrect"
    );
}
// Every statistic is populated for a 1 x 1 join and appears in the summary.
#[test]
fn test_join_stats_reporting() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let left_solutions = vec![create_test_solution(&var_x, "value1")];
    let right_solutions = vec![create_test_solution(&var_x, "value1")];
    let join_variables = vec![var_x];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (_solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("joining matching single-row inputs should succeed");

    assert_eq!(stats.input_cardinalities.0, 1);
    assert_eq!(stats.input_cardinalities.1, 1);
    assert_eq!(stats.output_cardinality, 1);
    assert!(stats.memory_used > 0);
    assert!(stats.join_selectivity > 0.0);

    let summary = stats.performance_summary();
    assert!(
        summary.contains("Algorithm"),
        "Summary should contain algorithm info: {}",
        summary
    );
}
// With no join variables every left row pairs with every right row.
#[test]
fn test_cartesian_product() {
    let cost_model = CostModel::new(CostModelConfig::default());
    let mut selector = JoinAlgorithmSelector::new(cost_model, 1024 * 1024);
    let var_x = Variable::new("x").unwrap();
    let var_y = Variable::new("y").unwrap();
    let left_solutions = vec![
        create_test_solution(&var_x, "a"),
        create_test_solution(&var_x, "b"),
    ];
    let right_solutions = vec![
        create_test_solution(&var_y, "1"),
        create_test_solution(&var_y, "2"),
    ];
    let join_variables = vec![];

    // `expect` prints the underlying error; `assert!(result.is_ok())` would hide it.
    let (solutions, stats) = selector
        .execute_optimal_join(left_solutions, right_solutions, &join_variables)
        .expect("cartesian join should succeed");

    assert_eq!(solutions.len(), 4);
    assert!(
        (stats.join_selectivity - 1.0).abs() < 0.01,
        "Cartesian product selectivity should be ~1.0"
    );
}
}