use rayon::prelude::*;
use rustc_hash::FxHashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::core::{Result, Row, Value};
use crate::functions::FunctionRegistry;
use crate::parser::ast::Expression;
use super::expression::{ExpressionEval, RowFilter};
use super::utils::{hash_composite_key, hash_row, rows_equal, verify_composite_key_equality};
/// Default row count at or above which filtering switches to the parallel path.
pub const DEFAULT_PARALLEL_FILTER_THRESHOLD: usize = 10_000;
/// Default row count at or above which sorting switches to the parallel path.
pub const DEFAULT_PARALLEL_SORT_THRESHOLD: usize = 50_000;
/// Default row count at or above which hash-join work switches to the parallel path.
pub const DEFAULT_PARALLEL_JOIN_THRESHOLD: usize = 5_000;
/// Default lower bound on the number of rows handed to one parallel chunk.
pub const DEFAULT_PARALLEL_CHUNK_SIZE: usize = 2048;

/// Tunables that decide when an operator runs in parallel and how work is chunked.
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Master switch; when `false` every `should_parallel_*` check returns `false`.
    pub enabled: bool,
    /// Minimum row count for the parallel filter path.
    pub min_rows_for_parallel_filter: usize,
    /// Minimum row count for the parallel sort path.
    pub min_rows_for_parallel_sort: usize,
    /// Minimum row count for the parallel join path.
    pub min_rows_for_parallel_join: usize,
    /// Lower bound on rows per parallel chunk.
    pub chunk_size: usize,
}

impl Default for ParallelConfig {
    /// Parallelism enabled, with every threshold taken from the `DEFAULT_PARALLEL_*` constants.
    fn default() -> Self {
        Self::new(
            true,
            DEFAULT_PARALLEL_FILTER_THRESHOLD,
            DEFAULT_PARALLEL_SORT_THRESHOLD,
            DEFAULT_PARALLEL_JOIN_THRESHOLD,
            DEFAULT_PARALLEL_CHUNK_SIZE,
        )
    }
}

impl ParallelConfig {
    /// Builds a configuration from explicit values for every field.
    pub fn new(
        enabled: bool,
        min_rows_for_parallel_filter: usize,
        min_rows_for_parallel_sort: usize,
        min_rows_for_parallel_join: usize,
        chunk_size: usize,
    ) -> Self {
        Self {
            chunk_size,
            min_rows_for_parallel_join,
            min_rows_for_parallel_sort,
            min_rows_for_parallel_filter,
            enabled,
        }
    }

    /// Default thresholds, but with parallel execution switched off.
    pub fn disabled() -> Self {
        Self::new(
            false,
            DEFAULT_PARALLEL_FILTER_THRESHOLD,
            DEFAULT_PARALLEL_SORT_THRESHOLD,
            DEFAULT_PARALLEL_JOIN_THRESHOLD,
            DEFAULT_PARALLEL_CHUNK_SIZE,
        )
    }

    /// Shared rule behind the `should_parallel_*` checks: parallelism must be
    /// enabled and the row count must reach the given threshold.
    #[inline]
    fn reaches(&self, rows: usize, threshold: usize) -> bool {
        if !self.enabled {
            return false;
        }
        rows >= threshold
    }

    /// Whether a filter over `row_count` rows should run in parallel.
    #[inline]
    pub fn should_parallel_filter(&self, row_count: usize) -> bool {
        self.reaches(row_count, self.min_rows_for_parallel_filter)
    }

    /// Whether a sort over `row_count` rows should run in parallel.
    #[inline]
    pub fn should_parallel_sort(&self, row_count: usize) -> bool {
        self.reaches(row_count, self.min_rows_for_parallel_sort)
    }

    /// Whether join work over `build_rows` rows should run in parallel.
    #[inline]
    pub fn should_parallel_join(&self, build_rows: usize) -> bool {
        self.reaches(build_rows, self.min_rows_for_parallel_join)
    }
}
/// Keeps the rows for which `filter_expr` matches, splitting the work across
/// rayon workers when `config.should_parallel_filter` allows it.
///
/// Below the threshold the call is forwarded to `sequential_filter`. Otherwise
/// the expression is compiled once into a `RowFilter` that every chunk shares.
///
/// # Errors
/// Propagates the error returned by `RowFilter::new` (or by the sequential
/// fallback) when the expression cannot be prepared.
pub fn parallel_filter(
    rows: Vec<Row>,
    filter_expr: &Expression,
    columns: &[String],
    function_registry: &FunctionRegistry,
    config: &ParallelConfig,
) -> Result<Vec<Row>> {
    let total_rows = rows.len();
    if !config.should_parallel_filter(total_rows) {
        return sequential_filter(rows, filter_expr, columns, function_registry);
    }

    // Aim for about four chunks per worker, but never go below the configured
    // chunk size or 512 rows per chunk.
    let desired_chunks = rayon::current_num_threads() * 4;
    let rows_per_chunk = (total_rows / desired_chunks)
        .max(config.chunk_size)
        .max(512);

    let owned_columns: Vec<String> = columns.to_vec();
    let row_filter = RowFilter::new(filter_expr, &owned_columns)?;

    let per_chunk: Vec<Vec<Row>> = rows
        .into_par_iter()
        .chunks(rows_per_chunk)
        .map(|batch| {
            let mut kept = Vec::with_capacity(batch.len() / 2);
            kept.extend(batch.into_iter().filter(|row| row_filter.matches(row)));
            kept
        })
        .collect();

    // Stitch the per-chunk results together with a single exact allocation.
    let survivor_count: usize = per_chunk.iter().map(Vec::len).sum();
    let mut survivors = Vec::with_capacity(survivor_count);
    survivors.extend(per_chunk.into_iter().flatten());
    Ok(survivors)
}
/// Single-threaded filter: compiles `filter_expr` against `columns` once and
/// keeps every row for which the compiled expression evaluates to `true`.
///
/// `_function_registry` is accepted for signature parity with the parallel
/// entry point and is not used here.
///
/// # Errors
/// Propagates the error from `ExpressionEval::compile`.
fn sequential_filter(
    rows: Vec<Row>,
    filter_expr: &Expression,
    columns: &[String],
    _function_registry: &FunctionRegistry,
) -> Result<Vec<Row>> {
    let owned_columns: Vec<String> = columns.to_vec();
    let mut evaluator = ExpressionEval::compile(filter_expr, &owned_columns)?;

    let mut kept = Vec::new();
    for row in rows {
        if evaluator.eval_bool(&row) {
            kept.push(row);
        }
    }
    Ok(kept)
}
/// Filters owned rows with an arbitrary predicate, using rayon when the row
/// count reaches the filter threshold in `config` and a plain iterator otherwise.
pub fn parallel_filter_owned(
    rows: Vec<Row>,
    predicate: impl Fn(&Row) -> bool + Sync + Send,
    config: &ParallelConfig,
) -> Vec<Row> {
    if config.should_parallel_filter(rows.len()) {
        rows.into_par_iter().filter(|row| predicate(row)).collect()
    } else {
        rows.into_iter().filter(|row| predicate(row)).collect()
    }
}
/// Stable in-place sort of `rows` by `compare`; uses rayon's `par_sort_by`
/// once the slice length reaches the sort threshold in `config`.
pub fn parallel_sort<F>(rows: &mut [Row], compare: F, config: &ParallelConfig)
where
    F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
    if !config.should_parallel_sort(rows.len()) {
        rows.sort_by(compare);
        return;
    }
    rows.par_sort_by(compare);
}
/// Unstable in-place sort of `rows` by `compare`; uses rayon's
/// `par_sort_unstable_by` once the slice length reaches the sort threshold.
pub fn parallel_sort_unstable<F>(rows: &mut [Row], compare: F, config: &ParallelConfig)
where
    F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
    if !config.should_parallel_sort(rows.len()) {
        rows.sort_unstable_by(compare);
        return;
    }
    rows.par_sort_unstable_by(compare);
}
/// Removes duplicate rows in two phases: each chunk is de-duplicated in
/// parallel, then the per-chunk survivors are merged sequentially.
///
/// Rows are bucketed by `hash_row` and confirmed with `rows_equal`, so two
/// different rows that share a hash are both kept. The parallel path is gated
/// by the *filter* threshold of `config`; below it `sequential_distinct` runs.
pub fn parallel_distinct(rows: Vec<Row>, config: &ParallelConfig) -> Vec<Row> {
    let total_rows = rows.len();
    if !config.should_parallel_filter(total_rows) {
        return sequential_distinct(rows);
    }

    let workers = rayon::current_num_threads();
    let rows_per_chunk = config.chunk_size.max(total_rows / workers).max(1000);

    // Phase 1: de-duplicate inside each chunk, keeping the hash next to every
    // survivor so the merge phase does not have to hash again.
    let partials: Vec<Vec<(u64, Row)>> = rows
        .into_par_iter()
        .chunks(rows_per_chunk)
        .map(|batch| {
            let mut buckets: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
            let mut kept: Vec<(u64, Row)> = Vec::with_capacity(batch.len());
            for candidate in batch {
                let digest = hash_row(&candidate);
                let bucket = buckets.entry(digest).or_default();
                let already_seen = bucket
                    .iter()
                    .any(|&pos| rows_equal(&kept[pos].1, &candidate));
                if already_seen {
                    continue;
                }
                bucket.push(kept.len());
                kept.push((digest, candidate));
            }
            kept
        })
        .collect();

    // Phase 2: merge chunk survivors, dropping rows already seen in an earlier chunk.
    let survivor_total: usize = partials.iter().map(Vec::len).sum();
    let mut distinct = Vec::with_capacity((survivor_total * 3) / 4);
    let mut buckets: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
    for (digest, candidate) in partials.into_iter().flatten() {
        let bucket = buckets.entry(digest).or_default();
        let already_seen = bucket
            .iter()
            .any(|&pos| rows_equal(&distinct[pos], &candidate));
        if already_seen {
            continue;
        }
        bucket.push(distinct.len());
        distinct.push(candidate);
    }
    distinct
}
/// Single-threaded de-duplication that keeps the first occurrence of each row.
///
/// Rows are bucketed by `hash_row`; `rows_equal` confirms equality inside a
/// bucket so that hash collisions never merge different rows.
fn sequential_distinct(rows: Vec<Row>) -> Vec<Row> {
    let mut buckets: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
    let mut distinct = Vec::with_capacity(rows.len());
    for candidate in rows {
        let bucket = buckets.entry(hash_row(&candidate)).or_default();
        let already_seen = bucket
            .iter()
            .any(|&pos| rows_equal(&distinct[pos], &candidate));
        if already_seen {
            continue;
        }
        bucket.push(distinct.len());
        distinct.push(candidate);
    }
    distinct
}
/// Applies `project_fn` to every row and collects the projected rows.
///
/// The parallel path is gated by the *filter* threshold of `config`.
pub fn parallel_project<F>(rows: Vec<Row>, project_fn: F, config: &ParallelConfig) -> Vec<Row>
where
    F: Fn(&Row) -> Row + Sync + Send,
{
    if config.should_parallel_filter(rows.len()) {
        return rows.into_par_iter().map(|row| project_fn(&row)).collect();
    }

    let mut projected = Vec::with_capacity(rows.len());
    for row in &rows {
        projected.push(project_fn(row));
    }
    projected
}
/// Execution statistics reported alongside a filter result
/// (filled in by `parallel_filter_with_stats`).
#[derive(Clone, Debug, Default)]
pub struct ParallelStats {
/// Number of input rows.
pub rows_processed: usize,
/// Number of rows that survived the filter.
pub rows_passed: usize,
/// Number of chunks the input was (nominally) split into.
pub chunks_used: usize,
/// Whether the parallel path was selected.
pub parallel_used: bool,
}
/// Backing store of a join hash table: key hash -> indices of build rows with that hash.
enum HashTableStorage {
// Plain hash map, filled by the single-threaded build path.
Sequential(FxHashMap<u64, Vec<usize>>),
// Concurrent map, filled by the parallel build path.
Parallel(dashmap::DashMap<u64, Vec<usize>>),
}
/// One atomic flag per build row, recording whether any probe row matched it.
/// Shared by reference between worker threads during an outer join.
struct BuildMatchedTracker {
    matched: Vec<AtomicBool>,
}

impl BuildMatchedTracker {
    /// Creates a tracker for `size` build rows, all initially unmatched.
    fn new(size: usize) -> Self {
        let matched = std::iter::repeat_with(|| AtomicBool::new(false))
            .take(size)
            .collect();
        BuildMatchedTracker { matched }
    }

    /// Flags build row `idx` as matched. Panics if `idx` is out of range.
    #[inline]
    fn mark_matched(&self, idx: usize) {
        let flag = &self.matched[idx];
        flag.store(true, Ordering::Release);
    }

    /// Returns whether build row `idx` was flagged. Panics if `idx` is out of range.
    #[inline]
    fn was_matched(&self, idx: usize) -> bool {
        let flag = &self.matched[idx];
        flag.load(Ordering::Acquire)
    }
}
impl HashTableStorage {
    /// Returns a copy of the build-row indices stored under `key`, if any.
    ///
    /// The bucket is cloned in both variants so no map guard or borrow
    /// outlives the call.
    #[inline]
    fn get(&self, key: &u64) -> Option<Vec<usize>> {
        match self {
            HashTableStorage::Parallel(shared) => {
                shared.get(key).map(|entry| entry.value().clone())
            }
            HashTableStorage::Sequential(local) => local.get(key).cloned(),
        }
    }
}
/// Join hash table produced by `parallel_hash_build`.
pub struct ParallelHashTable {
// Sequential or concurrent map, depending on which build path ran.
storage: HashTableStorage,
/// Number of build rows the table was built from.
pub row_count: usize,
}
impl ParallelHashTable {
/// Returns a copy of the build-row indices whose key hash equals `key`,
/// or `None` when no build row hashed to it. Callers still have to verify
/// key equality, since different keys can share a hash.
#[inline]
pub fn get(&self, key: &u64) -> Option<Vec<usize>> {
self.storage.get(key)
}
}
/// Builds the join hash table over `build_rows`, keyed by the hash of the
/// columns listed in `key_indices`.
///
/// When the row count reaches the join threshold the rows are hashed in
/// parallel chunks into a `DashMap`; otherwise a plain `FxHashMap` is filled
/// on the calling thread. In the parallel path several chunks may append to
/// the same bucket concurrently, so the order of indices inside a bucket is
/// not fixed.
pub fn parallel_hash_build(
    build_rows: &[Row],
    key_indices: &[usize],
    config: &ParallelConfig,
) -> ParallelHashTable {
    use dashmap::DashMap;

    let row_count = build_rows.len();

    if config.should_parallel_join(row_count) {
        let shared: DashMap<u64, Vec<usize>> = DashMap::with_capacity(row_count);
        let workers = rayon::current_num_threads();
        let rows_per_chunk = config.chunk_size.max(row_count / workers).max(1000);

        build_rows
            .par_chunks(rows_per_chunk)
            .enumerate()
            .for_each(|(chunk_no, chunk)| {
                // Index of the chunk's first row within `build_rows`.
                let first_row = chunk_no * rows_per_chunk;
                for (offset, row) in chunk.iter().enumerate() {
                    debug_assert!(
                        first_row.checked_add(offset).is_some(),
                        "Index overflow in parallel hash build: base_idx={} + local_idx={}",
                        first_row,
                        offset
                    );
                    let row_idx = first_row + offset;
                    let key_hash = hash_composite_key(row, key_indices);
                    shared.entry(key_hash).or_default().push(row_idx);
                }
            });

        return ParallelHashTable {
            storage: HashTableStorage::Parallel(shared),
            row_count,
        };
    }

    let mut local: FxHashMap<u64, Vec<usize>> =
        FxHashMap::with_capacity_and_hasher(row_count, Default::default());
    for (row_idx, row) in build_rows.iter().enumerate() {
        let key_hash = hash_composite_key(row, key_indices);
        local.entry(key_hash).or_default().push(row_idx);
    }
    ParallelHashTable {
        storage: HashTableStorage::Sequential(local),
        row_count,
    }
}
/// Probes `hash_table` with every row of `probe_rows` and returns the
/// `(probe_index, build_index)` pairs accepted by `verify_match`.
///
/// Candidates come from the bucket that the probe key hashes to;
/// `verify_match` is the final arbiter, which filters out hash collisions.
/// The probe runs in parallel chunks once `probe_rows.len()` reaches the join
/// threshold in `config`.
///
/// # Panics
/// Panics if the table holds an index that is out of bounds for `build_rows`,
/// i.e. if it was built from a different slice.
pub fn parallel_hash_probe<F>(
    probe_rows: &[Row],
    probe_key_indices: &[usize],
    hash_table: &ParallelHashTable,
    build_rows: &[Row],
    verify_match: F,
    config: &ParallelConfig,
) -> Vec<(usize, usize)>
where
    F: Fn(&Row, &Row) -> bool + Sync + Send,
{
    // Appends every verified match of one probe row to `out`.
    let collect_matches = |probe_idx: usize, probe_row: &Row, out: &mut Vec<(usize, usize)>| {
        let key_hash = hash_composite_key(probe_row, probe_key_indices);
        let Some(candidates) = hash_table.get(&key_hash) else {
            return;
        };
        for build_idx in candidates {
            if verify_match(probe_row, &build_rows[build_idx]) {
                out.push((probe_idx, build_idx));
            }
        }
    };

    let probe_count = probe_rows.len();
    if !config.should_parallel_join(probe_count) {
        let mut pairs = Vec::new();
        for (probe_idx, probe_row) in probe_rows.iter().enumerate() {
            collect_matches(probe_idx, probe_row, &mut pairs);
        }
        return pairs;
    }

    let workers = rayon::current_num_threads();
    let rows_per_chunk = config.chunk_size.max(probe_count / workers).max(1000);
    probe_rows
        .par_chunks(rows_per_chunk)
        .enumerate()
        .flat_map(|(chunk_no, chunk)| {
            // Index of the chunk's first row within `probe_rows`.
            let first_row = chunk_no * rows_per_chunk;
            let mut pairs = Vec::new();
            for (offset, probe_row) in chunk.iter().enumerate() {
                collect_matches(first_row + offset, probe_row, &mut pairs);
            }
            pairs
        })
        .collect()
}
/// Hashes the columns of `row` selected by `key_indices`.
///
/// Thin public wrapper around `hash_composite_key`, the same function
/// `parallel_hash_build` uses, so hashes from here can be looked up in a
/// `ParallelHashTable`.
#[inline]
pub fn hash_row_by_keys(row: &Row, key_indices: &[usize]) -> u64 {
hash_composite_key(row, key_indices)
}
/// Checks whether the key columns of `probe_row` equal the key columns of `build_row`.
///
/// Thin public wrapper around `verify_composite_key_equality`; used after a
/// hash lookup to reject candidates that merely share a hash.
#[inline]
pub fn verify_key_match(
probe_row: &Row,
build_row: &Row,
probe_key_indices: &[usize],
build_key_indices: &[usize],
) -> bool {
verify_composite_key_equality(probe_row, build_row, probe_key_indices, build_key_indices)
}
/// Logical join kind handled by the hash join.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JoinType {
    /// Only matching pairs are emitted.
    Inner,
    /// Every left-side row is emitted, NULL-padded when unmatched.
    Left,
    /// Every right-side row is emitted, NULL-padded when unmatched.
    Right,
    /// Rows from both sides are emitted, NULL-padded when unmatched.
    Full,
}

impl JoinType {
    /// Derives the join type from free-form text, case-insensitively.
    ///
    /// The first keyword found among `FULL`, `LEFT`, `RIGHT` (checked in that
    /// order) wins; text containing none of them is treated as `Inner`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        // Keywords in priority order, matching the original if/else chain.
        const KEYWORDS: [(&str, JoinType); 3] = [
            ("FULL", JoinType::Full),
            ("LEFT", JoinType::Left),
            ("RIGHT", JoinType::Right),
        ];
        let normalized = s.to_uppercase();
        KEYWORDS
            .iter()
            .find(|(keyword, _)| normalized.contains(*keyword))
            .map_or(JoinType::Inner, |&(_, join_type)| join_type)
    }

    /// Whether probe rows without a match must still be emitted.
    /// `swapped` tells whether the probe side is the right-hand input.
    fn needs_unmatched_probe(&self, swapped: bool) -> bool {
        match (self, swapped) {
            (JoinType::Inner, _) => false,
            (JoinType::Full, _) => true,
            (JoinType::Left, probe_is_right) => !probe_is_right,
            (JoinType::Right, probe_is_right) => probe_is_right,
        }
    }

    /// Whether build rows without a match must still be emitted.
    /// `swapped` tells whether the build side is the left-hand input.
    fn needs_unmatched_build(&self, swapped: bool) -> bool {
        match (self, swapped) {
            (JoinType::Inner, _) => false,
            (JoinType::Full, _) => true,
            (JoinType::Left, build_is_left) => build_is_left,
            (JoinType::Right, build_is_left) => !build_is_left,
        }
    }
}
/// Output of `parallel_hash_join`: the joined rows plus execution statistics.
pub struct ParallelJoinResult {
/// Joined output rows (matched pairs and, for outer joins, NULL-padded rows).
pub rows: Vec<Row>,
/// Whether the parallel code path was selected for this join.
pub parallel_used: bool,
/// Number of probe-side input rows.
pub probe_rows_processed: usize,
/// Number of build-side input rows.
pub build_rows_count: usize,
/// Match counter reported by the join.
pub matches_found: usize,
}
/// Hash-joins `probe_rows` against `build_rows` on the given key columns.
///
/// A hash table is built over `build_rows`; every probe row is looked up by
/// key hash and each candidate is confirmed with `verify_key_match`, so hash
/// collisions never produce false matches.
///
/// * `join_type` / `swapped` decide which side's unmatched rows are emitted
///   NULL-padded: `swapped == false` means the probe side is the left input,
///   `swapped == true` means the build side is the left input. Output columns
///   are laid out left input first.
/// * `probe_col_count` / `build_col_count` are the column counts of each side;
///   missing cells are filled with NULL.
/// * The parallel path is taken when either side reaches the join threshold in
///   `config`.
///
/// Output order: matched rows first. On the parallel outer-join path the
/// unmatched probe rows follow as a group; on the sequential path they are
/// interleaved in probe order. Unmatched build rows always come last.
///
/// `matches_found` in the result counts key-verified (probe, build) pairs
/// only; NULL-padded rows emitted for unmatched inputs are not counted.
#[allow(clippy::too_many_arguments)]
pub fn parallel_hash_join(
    probe_rows: &[Row],
    build_rows: &[Row],
    probe_key_indices: &[usize],
    build_key_indices: &[usize],
    join_type: JoinType,
    probe_col_count: usize,
    build_col_count: usize,
    swapped: bool,
    config: &ParallelConfig,
) -> ParallelJoinResult {
    let probe_count = probe_rows.len();
    let build_count = build_rows.len();
    let use_parallel =
        config.should_parallel_join(build_count) || config.should_parallel_join(probe_count);
    let hash_table = parallel_hash_build(build_rows, build_key_indices, config);

    // Only allocated when unmatched build rows have to be emitted afterwards.
    let build_matched: Option<BuildMatchedTracker> = if join_type.needs_unmatched_build(swapped) {
        Some(BuildMatchedTracker::new(build_count))
    } else {
        None
    };

    let (matched_rows, unmatched_probe_rows, matches_found) = if use_parallel
        && join_type == JoinType::Inner
    {
        // Fast path: an inner join tracks neither unmatched probe nor build rows.
        let matches: Vec<Row> = probe_rows
            .par_chunks(config.chunk_size.max(1000))
            .flat_map(|chunk| {
                let mut local_results = Vec::new();
                for probe_row in chunk {
                    let hash = hash_row_by_keys(probe_row, probe_key_indices);
                    if let Some(build_indices) = hash_table.get(&hash) {
                        for build_idx in build_indices {
                            let build_row = &build_rows[build_idx];
                            if verify_key_match(
                                probe_row,
                                build_row,
                                probe_key_indices,
                                build_key_indices,
                            ) {
                                let combined = combine_join_rows(
                                    probe_row,
                                    build_row,
                                    probe_col_count,
                                    build_col_count,
                                    swapped,
                                );
                                local_results.push(Row::from_values(combined));
                            }
                        }
                    }
                }
                local_results
            })
            .collect();
        let match_count = matches.len();
        (matches, Vec::new(), match_count)
    } else if use_parallel {
        let needs_unmatched_probe = join_type.needs_unmatched_probe(swapped);
        let chunk_results: Vec<(Vec<Row>, Vec<Row>)> = probe_rows
            .par_chunks(config.chunk_size.max(1000))
            .map(|chunk| {
                let mut matched_results = Vec::new();
                let mut unmatched_results = Vec::new();
                for probe_row in chunk.iter() {
                    let mut matched = false;
                    let hash = hash_row_by_keys(probe_row, probe_key_indices);
                    if let Some(build_indices) = hash_table.get(&hash) {
                        for build_idx in build_indices {
                            let build_row = &build_rows[build_idx];
                            if verify_key_match(
                                probe_row,
                                build_row,
                                probe_key_indices,
                                build_key_indices,
                            ) {
                                matched = true;
                                if let Some(ref tracker) = build_matched {
                                    tracker.mark_matched(build_idx);
                                }
                                let combined = combine_join_rows(
                                    probe_row,
                                    build_row,
                                    probe_col_count,
                                    build_col_count,
                                    swapped,
                                );
                                matched_results.push(Row::from_values(combined));
                            }
                        }
                    }
                    if !matched && needs_unmatched_probe {
                        let values = combine_with_nulls(
                            probe_row,
                            probe_col_count,
                            build_col_count,
                            swapped,
                        );
                        unmatched_results.push(Row::from_values(values));
                    }
                }
                (matched_results, unmatched_results)
            })
            .collect();
        let total_matched: usize = chunk_results.iter().map(|(m, _)| m.len()).sum();
        let total_unmatched: usize = chunk_results.iter().map(|(_, u)| u.len()).sum();
        let mut matched_rows = Vec::with_capacity(total_matched);
        let mut unmatched_rows = Vec::with_capacity(total_unmatched);
        for (matched, unmatched) in chunk_results {
            matched_rows.extend(matched);
            unmatched_rows.extend(unmatched);
        }
        // Kept from the original as an explicit barrier before the tracker is
        // read below on this thread.
        std::sync::atomic::fence(Ordering::Acquire);
        (matched_rows, unmatched_rows, total_matched)
    } else {
        // Sequential path: NULL-padded probe rows are interleaved with the
        // matches in probe order, so real matches are counted separately.
        let mut output_rows = Vec::new();
        let mut match_count = 0usize;
        let needs_unmatched_probe = join_type.needs_unmatched_probe(swapped);
        for probe_row in probe_rows.iter() {
            let hash = hash_row_by_keys(probe_row, probe_key_indices);
            let mut matched = false;
            if let Some(build_indices) = hash_table.get(&hash) {
                for build_idx in build_indices {
                    let build_row = &build_rows[build_idx];
                    if verify_key_match(probe_row, build_row, probe_key_indices, build_key_indices)
                    {
                        matched = true;
                        match_count += 1;
                        if let Some(ref tracker) = build_matched {
                            tracker.mark_matched(build_idx);
                        }
                        let combined = combine_join_rows(
                            probe_row,
                            build_row,
                            probe_col_count,
                            build_col_count,
                            swapped,
                        );
                        output_rows.push(Row::from_values(combined));
                    }
                }
            }
            if !matched && needs_unmatched_probe {
                let values =
                    combine_with_nulls(probe_row, probe_col_count, build_col_count, swapped);
                output_rows.push(Row::from_values(values));
            }
        }
        (output_rows, Vec::new(), match_count)
    };

    let mut result_rows = matched_rows;
    result_rows.extend(unmatched_probe_rows);

    // Emit build rows that no probe row matched, padded with NULL probe columns.
    if let Some(ref tracker) = build_matched {
        for (build_idx, build_row) in build_rows.iter().enumerate() {
            if !tracker.was_matched(build_idx) {
                let values =
                    combine_build_with_nulls(build_row, build_col_count, probe_col_count, swapped);
                result_rows.push(Row::from_values(values));
            }
        }
    }

    ParallelJoinResult {
        rows: result_rows,
        parallel_used: use_parallel,
        probe_rows_processed: probe_count,
        build_rows_count: build_count,
        matches_found,
    }
}
/// Concatenates a matched probe row and build row into one output row.
///
/// With `swapped == false` the probe columns come first; with
/// `swapped == true` the build columns come first. Each side contributes
/// exactly its declared column count; cells missing from a short row are
/// filled with `Value::null_unknown()`.
#[inline]
fn combine_join_rows(
    probe_row: &Row,
    build_row: &Row,
    probe_col_count: usize,
    build_col_count: usize,
    swapped: bool,
) -> Vec<Value> {
    // Cell `i` of `row`, or NULL when the row has no such column.
    let cell = |row: &Row, i: usize| row.get(i).cloned().unwrap_or_else(Value::null_unknown);

    let (first, first_width, second, second_width) = if swapped {
        (build_row, build_col_count, probe_row, probe_col_count)
    } else {
        (probe_row, probe_col_count, build_row, build_col_count)
    };

    let mut combined = Vec::with_capacity(probe_col_count + build_col_count);
    combined.extend((0..first_width).map(|i| cell(first, i)));
    combined.extend((0..second_width).map(|i| cell(second, i)));
    combined
}
/// Builds the output row for a probe row that matched nothing: the probe
/// columns plus `build_col_count` NULLs standing in for the build side.
///
/// With `swapped == false` the probe columns come first; with
/// `swapped == true` the NULL build columns come first.
#[inline]
fn combine_with_nulls(
    probe_row: &Row,
    probe_col_count: usize,
    build_col_count: usize,
    swapped: bool,
) -> Vec<Value> {
    let probe_cells = (0..probe_col_count).map(|i| {
        probe_row
            .get(i)
            .cloned()
            .unwrap_or_else(Value::null_unknown)
    });
    let null_cells = std::iter::repeat_with(Value::null_unknown).take(build_col_count);

    let mut combined = Vec::with_capacity(probe_col_count + build_col_count);
    if swapped {
        combined.extend(null_cells);
        combined.extend(probe_cells);
    } else {
        combined.extend(probe_cells);
        combined.extend(null_cells);
    }
    combined
}
/// Builds the output row for a build row that no probe row matched: the build
/// columns plus `probe_col_count` NULLs standing in for the probe side.
///
/// With `swapped == false` the NULL probe columns come first; with
/// `swapped == true` the build columns come first.
#[inline]
fn combine_build_with_nulls(
    build_row: &Row,
    build_col_count: usize,
    probe_col_count: usize,
    swapped: bool,
) -> Vec<Value> {
    let build_cells = (0..build_col_count).map(|i| {
        build_row
            .get(i)
            .cloned()
            .unwrap_or_else(Value::null_unknown)
    });
    let null_cells = std::iter::repeat_with(Value::null_unknown).take(probe_col_count);

    let mut combined = Vec::with_capacity(probe_col_count + build_col_count);
    if swapped {
        combined.extend(build_cells);
        combined.extend(null_cells);
    } else {
        combined.extend(null_cells);
        combined.extend(build_cells);
    }
    combined
}
/// Direction of one ORDER BY key.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortDirection {
/// Smaller values first.
Ascending,
/// Larger values first.
Descending,
}
/// One ORDER BY key, as consumed by `parallel_order_by`.
#[derive(Clone, Debug)]
pub struct SortSpec {
/// Index of the column to compare.
pub column_index: usize,
/// Sort direction for this key.
pub direction: SortDirection,
/// Whether NULLs are ordered before non-NULL values.
/// NOTE(review): `parallel_order_by` applies this before reversing for
/// `Descending`, so NULL placement flips for descending keys - confirm
/// callers expect that.
pub nulls_first: bool,
}
/// Stable in-place sort of `rows` by the keys in `sort_specs`, in priority order.
///
/// A missing column is treated like NULL. Two NULLs compare equal; NULL
/// against a value is placed according to `nulls_first`. Values that
/// `partial_cmp` cannot order compare equal. The sort runs on rayon once the
/// slice length reaches the sort threshold in `config`.
///
/// NOTE(review): for `Descending` keys the whole comparison is reversed,
/// NULL placement included, so `nulls_first == true` puts NULLs last on a
/// descending key. Confirm this is what callers expect.
pub fn parallel_order_by(rows: &mut [Row], sort_specs: &[SortSpec], config: &ParallelConfig) {
    use std::cmp::Ordering as CmpOrdering;

    let by_specs = |left: &Row, right: &Row| -> CmpOrdering {
        for spec in sort_specs {
            // `None` stands for both "column missing" and "value is NULL".
            let lhs = left.get(spec.column_index).filter(|v| !v.is_null());
            let rhs = right.get(spec.column_index).filter(|v| !v.is_null());

            let natural = match (lhs, rhs) {
                (None, None) => CmpOrdering::Equal,
                (None, Some(_)) if spec.nulls_first => CmpOrdering::Less,
                (None, Some(_)) => CmpOrdering::Greater,
                (Some(_), None) if spec.nulls_first => CmpOrdering::Greater,
                (Some(_), None) => CmpOrdering::Less,
                (Some(x), Some(y)) => x.partial_cmp(y).unwrap_or(CmpOrdering::Equal),
            };

            let directed = match spec.direction {
                SortDirection::Descending => natural.reverse(),
                SortDirection::Ascending => natural,
            };
            if directed.is_ne() {
                return directed;
            }
        }
        CmpOrdering::Equal
    };

    if !config.should_parallel_sort(rows.len()) {
        rows.sort_by(by_specs);
        return;
    }
    rows.par_sort_by(by_specs);
}
pub fn parallel_order_by_fn<F>(rows: &mut [Row], compare: F, config: &ParallelConfig)
where
F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
if config.should_parallel_sort(rows.len()) {
rows.par_sort_by(compare);
} else {
rows.sort_by(compare);
}
}
pub fn parallel_order_by_unstable<F>(rows: &mut [Row], compare: F, config: &ParallelConfig)
where
F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
if config.should_parallel_sort(rows.len()) {
rows.par_sort_unstable_by(compare);
} else {
rows.sort_unstable_by(compare);
}
}
/// Runs `parallel_filter` and also reports how the work was executed.
///
/// `chunks_used` is an estimate that mirrors the chunk sizing used by
/// `parallel_filter`: `0` for empty input, `1` on the sequential path, and
/// `ceil(row_count / chunk_size)` on the parallel path.
///
/// Empty input on the sequential path used to compute `0.div_ceil(0)` and
/// panic with a division by zero; it now reports zero chunks.
///
/// # Errors
/// Propagates any error from the underlying filter.
pub fn parallel_filter_with_stats(
    rows: Vec<Row>,
    filter_expr: &Expression,
    columns: &[String],
    function_registry: &FunctionRegistry,
    config: &ParallelConfig,
) -> Result<(Vec<Row>, ParallelStats)> {
    let row_count = rows.len();
    let parallel_used = config.should_parallel_filter(row_count);

    let chunks_used = if row_count == 0 {
        // No rows means no chunks; also avoids dividing by a zero chunk size.
        0
    } else if parallel_used {
        let target_chunks = rayon::current_num_threads() * 4;
        let chunk_size = (row_count / target_chunks)
            .max(config.chunk_size)
            .max(512);
        row_count.div_ceil(chunk_size)
    } else {
        // The sequential path processes the whole input as one unit.
        1
    };

    // `parallel_filter` applies the same threshold check and falls back to the
    // sequential filter itself, so a single call covers both paths.
    let result = parallel_filter(rows, filter_expr, columns, function_registry, config)?;

    let stats = ParallelStats {
        rows_processed: row_count,
        rows_passed: result.len(),
        chunks_used,
        parallel_used,
    };
    Ok((result, stats))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::Value;
// Builds `count` two-column rows: column 0 is the row number `i`,
// column 1 is `i % 10`.
fn make_test_rows(count: usize) -> Vec<Row> {
(0..count)
.map(|i| {
Row::from_values(vec![
Value::Integer(i as i64),
Value::Integer(i as i64 % 10),
])
})
.collect()
}
// 1000 rows holding only 100 distinct values must collapse to 100 rows.
// The threshold is lowered to 100 so the parallel branch is taken; with the
// default chunk size of 2048 all rows land in a single chunk.
#[test]
fn test_parallel_distinct() {
let mut rows = Vec::new();
for i in 0..1000 {
rows.push(Row::from_values(vec![Value::Integer(i % 100)]));
}
let config = ParallelConfig {
min_rows_for_parallel_filter: 100, ..Default::default()
};
let result = parallel_distinct(rows, &config);
assert_eq!(result.len(), 100);
}
// Sorts 1000 rows given in descending order and checks that they end up in
// ascending order 0..1000. The sort threshold is lowered to 100 so the
// parallel sort is used.
#[test]
fn test_parallel_sort() {
let mut rows: Vec<Row> = (0..1000)
.rev()
.map(|i| Row::from_values(vec![Value::Integer(i)]))
.collect();
let config = ParallelConfig {
min_rows_for_parallel_sort: 100,
..Default::default()
};
parallel_sort(
&mut rows,
|a, b| {
let a_val = a.get(0).and_then(|v| v.as_int64()).unwrap_or(0);
let b_val = b.get(0).and_then(|v| v.as_int64()).unwrap_or(0);
a_val.cmp(&b_val)
},
&config,
);
for (i, row) in rows.iter().enumerate() {
assert_eq!(row.get(0), Some(&Value::Integer(i as i64)));
}
}
// Default thresholds: filter switches at 10_000 rows and sort at 50_000 rows;
// a disabled config never selects the parallel path.
#[test]
fn test_parallel_config_thresholds() {
let config = ParallelConfig::default();
assert!(!config.should_parallel_filter(1000)); assert!(config.should_parallel_filter(20_000));
assert!(!config.should_parallel_sort(10_000)); assert!(config.should_parallel_sort(100_000));
let disabled = ParallelConfig::disabled();
assert!(!disabled.should_parallel_filter(1_000_000)); }
// 50_000 rows exceed the default filter threshold, so the parallel branch
// runs. Column 1 is `i % 10`, so `< 5` keeps exactly half of the rows.
#[test]
fn test_parallel_filter_owned() {
let rows = make_test_rows(50_000);
let config = ParallelConfig::default();
let result = parallel_filter_owned(
rows,
|row| {
if let Some(Value::Integer(v)) = row.get(1) {
*v < 5
} else {
false
}
},
&config,
);
assert_eq!(result.len(), 25_000);
}
// 100 rows are below the default filter threshold, so the sequential branch
// runs; the predicate still keeps exactly half of the rows.
#[test]
fn test_sequential_fallback_small_dataset() {
let rows = make_test_rows(100); let config = ParallelConfig::default();
let result = parallel_filter_owned(
rows,
|row| {
if let Some(Value::Integer(v)) = row.get(1) {
*v < 5
} else {
false
}
},
&config,
);
assert_eq!(result.len(), 50);
}
// Builds a hash table over 10_000 rows with the join threshold lowered to
// 1000 (parallel build) and checks the row count and that a known key's hash
// is present in the table.
#[test]
fn test_parallel_hash_build() {
let build_rows: Vec<Row> = (0..10_000)
.map(|i| {
Row::from_values(vec![
Value::Integer(i),
Value::Text(format!("build_{}", i).into()),
])
})
.collect();
let config = ParallelConfig {
min_rows_for_parallel_join: 1000,
..Default::default()
};
let hash_table = parallel_hash_build(&build_rows, &[0], &config);
assert_eq!(hash_table.row_count, 10_000);
let test_hash = hash_row_by_keys(&build_rows[500], &[0]);
assert!(hash_table.get(&test_hash).is_some());
}
// Probes 10_000 rows against 5_000 unique build keys. Every probe key
// (`i * 2 % 5_000`) exists exactly once on the build side, so each probe row
// yields exactly one match.
#[test]
fn test_parallel_hash_probe() {
let build_rows: Vec<Row> = (0..5_000)
.map(|i| {
Row::from_values(vec![
Value::Integer(i),
Value::Text(format!("build_{}", i).into()),
])
})
.collect();
let probe_rows: Vec<Row> = (0..10_000)
.map(|i| {
Row::from_values(vec![
Value::Integer(i * 2 % 5_000), Value::Text(format!("probe_{}", i).into()),
])
})
.collect();
let config = ParallelConfig {
min_rows_for_parallel_join: 1000,
..Default::default()
};
let hash_table = parallel_hash_build(&build_rows, &[0], &config);
let matches = parallel_hash_probe(
&probe_rows,
&[0],
&hash_table,
&build_rows,
|probe, build| {
probe.get(0) == build.get(0)
},
&config,
);
assert_eq!(matches.len(), 10_000);
}
// Key equality only looks at the selected key columns: rows agreeing on
// column 0 match on `[0]`, rows agreeing on column 1 match on `[1]`.
#[test]
fn test_verify_key_match() {
let row1 = Row::from_values(vec![Value::Integer(1), Value::Text("a".to_string().into())]);
let row2 = Row::from_values(vec![Value::Integer(1), Value::Text("b".to_string().into())]);
let row3 = Row::from_values(vec![Value::Integer(2), Value::Text("a".to_string().into())]);
assert!(verify_key_match(&row1, &row2, &[0], &[0]));
assert!(!verify_key_match(&row1, &row3, &[0], &[0]));
assert!(verify_key_match(&row1, &row3, &[1], &[1]));
}
// Ascending ORDER BY on column 0 over 1000 shuffled values; every adjacent
// pair must be non-decreasing. The threshold is lowered so the parallel sort runs.
#[test]
fn test_parallel_order_by() {
let mut rows: Vec<Row> = (0..1000)
.map(|i| {
Row::from_values(vec![
Value::Integer((i * 7 + 13) % 1000), Value::Text(format!("row_{}", i).into()),
])
})
.collect();
let config = ParallelConfig {
min_rows_for_parallel_sort: 100,
..Default::default()
};
let sort_specs = vec![SortSpec {
column_index: 0,
direction: SortDirection::Ascending,
nulls_first: false,
}];
parallel_order_by(&mut rows, &sort_specs, &config);
for i in 1..rows.len() {
let prev = rows[i - 1].get(0).and_then(|v| v.as_int64()).unwrap();
let curr = rows[i].get(0).and_then(|v| v.as_int64()).unwrap();
assert!(prev <= curr, "Row {} should be <= row {}", i - 1, i);
}
}
// Descending ORDER BY on column 0 over 500 ascending values; every adjacent
// pair must be non-increasing. No NULLs are involved.
#[test]
fn test_parallel_order_by_descending() {
let mut rows: Vec<Row> = (0..500)
.map(|i| Row::from_values(vec![Value::Integer(i)]))
.collect();
let config = ParallelConfig {
min_rows_for_parallel_sort: 100,
..Default::default()
};
let sort_specs = vec![SortSpec {
column_index: 0,
direction: SortDirection::Descending,
nulls_first: false,
}];
parallel_order_by(&mut rows, &sort_specs, &config);
for i in 1..rows.len() {
let prev = rows[i - 1].get(0).and_then(|v| v.as_int64()).unwrap();
let curr = rows[i].get(0).and_then(|v| v.as_int64()).unwrap();
assert!(prev >= curr, "Row {} should be >= row {}", i - 1, i);
}
}
// Ascending sort with `nulls_first: true`: the two NULL rows come first,
// followed by 1, 2, 3. Only the ascending direction is covered here.
#[test]
fn test_parallel_order_by_with_nulls() {
let mut rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(3)]),
Row::from_values(vec![Value::null_unknown()]),
Row::from_values(vec![Value::Integer(1)]),
Row::from_values(vec![Value::null_unknown()]),
Row::from_values(vec![Value::Integer(2)]),
];
let config = ParallelConfig {
min_rows_for_parallel_sort: 1, ..Default::default()
};
let sort_specs = vec![SortSpec {
column_index: 0,
direction: SortDirection::Ascending,
nulls_first: true,
}];
parallel_order_by(&mut rows, &sort_specs, &config);
assert!(rows[0].get(0).map(|v| v.is_null()).unwrap_or(false));
assert!(rows[1].get(0).map(|v| v.is_null()).unwrap_or(false));
assert_eq!(rows[2].get(0), Some(&Value::Integer(1)));
assert_eq!(rows[3].get(0), Some(&Value::Integer(2)));
assert_eq!(rows[4].get(0), Some(&Value::Integer(3)));
}
// DISTINCT compares whole rows: (1,'a'), (1,'b') and (2,'a') stay separate
// while exact duplicates are dropped. The threshold of 1 forces the parallel
// branch of `parallel_distinct`.
#[test]
fn test_distinct_hash_collision_handling() {
let rows = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("a".into())]),
Row::from_values(vec![Value::Integer(1), Value::Text("b".into())]), Row::from_values(vec![Value::Integer(1), Value::Text("a".into())]), Row::from_values(vec![Value::Integer(2), Value::Text("a".into())]), Row::from_values(vec![Value::Integer(2), Value::Text("a".into())]), ];
let config = ParallelConfig {
min_rows_for_parallel_filter: 1, ..Default::default()
};
let result = parallel_distinct(rows, &config);
assert_eq!(result.len(), 3, "Should have 3 unique rows");
let has_1_a = result.iter().any(|r| {
r.get(0) == Some(&Value::Integer(1)) && r.get(1) == Some(&Value::Text("a".into()))
});
let has_1_b = result.iter().any(|r| {
r.get(0) == Some(&Value::Integer(1)) && r.get(1) == Some(&Value::Text("b".into()))
});
let has_2_a = result.iter().any(|r| {
r.get(0) == Some(&Value::Integer(2)) && r.get(1) == Some(&Value::Text("a".into()))
});
assert!(has_1_a, "Should contain (1, 'a')");
assert!(has_1_b, "Should contain (1, 'b')");
assert!(has_2_a, "Should contain (2, 'a')");
}
// With the threshold at 10000, four rows take the sequential DISTINCT path;
// the duplicate 100 must be removed, leaving three rows. The test does not
// construct an actual hash collision.
#[test]
fn test_sequential_distinct_hash_collision() {
let rows = vec![
Row::from_values(vec![Value::Integer(100)]),
Row::from_values(vec![Value::Integer(200)]),
Row::from_values(vec![Value::Integer(100)]), Row::from_values(vec![Value::Integer(300)]),
];
let config = ParallelConfig {
min_rows_for_parallel_filter: 10000,
..Default::default()
};
let result = parallel_distinct(rows, &config);
assert_eq!(
result.len(),
3,
"Should have 3 unique values: 100, 200, 300"
);
}
// INNER join on column 0: probe keys 1 and 2 match, key 4 does not, so two
// four-column rows are expected. The threshold of 1 selects the parallel
// inner-join branch.
#[test]
fn test_parallel_hash_join_collision_handling() {
let build_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("build_a".into())]),
Row::from_values(vec![Value::Integer(2), Value::Text("build_b".into())]),
Row::from_values(vec![Value::Integer(3), Value::Text("build_c".into())]),
];
let probe_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("probe_x".into())]),
Row::from_values(vec![Value::Integer(2), Value::Text("probe_y".into())]),
Row::from_values(vec![Value::Integer(4), Value::Text("probe_z".into())]), ];
let config = ParallelConfig {
min_rows_for_parallel_join: 1, ..Default::default()
};
let result = parallel_hash_join(
&probe_rows,
&build_rows,
&[0], &[0], JoinType::Inner,
2, 2, false,
&config,
);
assert_eq!(result.rows.len(), 2, "INNER JOIN should have 2 matches");
for row in &result.rows {
assert_eq!(row.len(), 4);
}
}
// LEFT join (probe side is the left input): all three probe rows are
// emitted; the two without a build match carry NULLs in the build columns
// (indices 2 and 3).
#[test]
fn test_parallel_left_join_unmatched() {
let build_rows: Vec<Row> = vec![Row::from_values(vec![
Value::Integer(1),
Value::Text("match".into()),
])];
let probe_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("p1".into())]), Row::from_values(vec![Value::Integer(2), Value::Text("p2".into())]), Row::from_values(vec![Value::Integer(3), Value::Text("p3".into())]), ];
let config = ParallelConfig {
min_rows_for_parallel_join: 1,
..Default::default()
};
let result = parallel_hash_join(
&probe_rows,
&build_rows,
&[0],
&[0],
JoinType::Left,
2,
2,
false,
&config,
);
assert_eq!(result.rows.len(), 3, "LEFT JOIN should have 3 rows");
let null_count = result
.rows
.iter()
.filter(|r| {
r.get(2).map(|v| v.is_null()).unwrap_or(false)
&& r.get(3).map(|v| v.is_null()).unwrap_or(false)
})
.count();
assert_eq!(
null_count, 2,
"Should have 2 unmatched rows with NULL build columns"
);
}
// RIGHT join (build side is the right input): all three build rows are
// emitted; the two without a probe match carry NULLs in the probe columns
// (indices 0 and 1).
#[test]
fn test_parallel_right_join_unmatched() {
let build_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("b1".into())]), Row::from_values(vec![Value::Integer(2), Value::Text("b2".into())]), Row::from_values(vec![Value::Integer(3), Value::Text("b3".into())]), ];
let probe_rows: Vec<Row> = vec![Row::from_values(vec![
Value::Integer(1),
Value::Text("p1".into()),
])];
let config = ParallelConfig {
min_rows_for_parallel_join: 1,
..Default::default()
};
let result = parallel_hash_join(
&probe_rows,
&build_rows,
&[0],
&[0],
JoinType::Right,
2,
2,
false,
&config,
);
assert_eq!(result.rows.len(), 3, "RIGHT JOIN should have 3 rows");
let null_count = result
.rows
.iter()
.filter(|r| {
r.get(0).map(|v| v.is_null()).unwrap_or(false)
&& r.get(1).map(|v| v.is_null()).unwrap_or(false)
})
.count();
assert_eq!(
null_count, 2,
"Should have 2 unmatched rows with NULL probe columns"
);
}
// FULL OUTER join: key 1 matches on both sides, probe key 2 and build key 3
// are each emitted once as unmatched rows, giving three rows in total.
#[test]
fn test_parallel_full_outer_join() {
let build_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("b1".into())]), Row::from_values(vec![Value::Integer(3), Value::Text("b3".into())]), ];
let probe_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("p1".into())]), Row::from_values(vec![Value::Integer(2), Value::Text("p2".into())]), ];
let config = ParallelConfig {
min_rows_for_parallel_join: 1,
..Default::default()
};
let result = parallel_hash_join(
&probe_rows,
&build_rows,
&[0],
&[0],
JoinType::Full,
2,
2,
false,
&config,
);
assert_eq!(result.rows.len(), 3, "FULL OUTER JOIN should have 3 rows");
}
// Empty-input edge cases under the default config (sequential path):
// an INNER join with an empty probe or build side yields nothing, while a
// LEFT join with an empty build side still returns every probe row.
#[test]
fn test_parallel_join_empty_tables() {
let config = ParallelConfig::default();
let result = parallel_hash_join(
&[],
&[Row::from_values(vec![Value::Integer(1)])],
&[0],
&[0],
JoinType::Inner,
1,
1,
false,
&config,
);
assert_eq!(
result.rows.len(),
0,
"Empty probe should give empty result for INNER"
);
let result = parallel_hash_join(
&[Row::from_values(vec![Value::Integer(1)])],
&[],
&[0],
&[0],
JoinType::Inner,
1,
1,
false,
&config,
);
assert_eq!(
result.rows.len(),
0,
"Empty build should give empty result for INNER"
);
let result = parallel_hash_join(
&[
Row::from_values(vec![Value::Integer(1)]),
Row::from_values(vec![Value::Integer(2)]),
],
&[],
&[0],
&[0],
JoinType::Left,
1,
1,
false,
&config,
);
assert_eq!(
result.rows.len(),
2,
"LEFT JOIN with empty build should have all probe rows"
);
}
// LEFT join with `swapped = true`: the build side is the left input, so
// build columns come first in the output and unmatched *build* rows are
// preserved with NULL probe columns. The unmatched probe row (key 2) is dropped.
#[test]
fn test_parallel_join_swapped() {
let build_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("b1".into())]),
Row::from_values(vec![Value::Integer(3), Value::Text("b3".into())]), ];
let probe_rows: Vec<Row> = vec![
Row::from_values(vec![Value::Integer(1), Value::Text("p1".into())]),
Row::from_values(vec![Value::Integer(2), Value::Text("p2".into())]), ];
let config = ParallelConfig {
min_rows_for_parallel_join: 1,
..Default::default()
};
let result = parallel_hash_join(
&probe_rows,
&build_rows,
&[0],
&[0],
JoinType::Left,
2,
2,
true, &config,
);
assert_eq!(result.rows.len(), 2, "LEFT JOIN swapped should have 2 rows");
let matched_row = result
.rows
.iter()
.find(|r| r.get(0) == Some(&Value::Integer(1)) && r.get(2) == Some(&Value::Integer(1)));
assert!(matched_row.is_some(), "Should have a matched row with id=1");
let unmatched_row = result
.rows
.iter()
.find(|r| r.get(0) == Some(&Value::Integer(3)));
assert!(
unmatched_row.is_some(),
"Should have unmatched build row with id=3"
);
let unmatched = unmatched_row.unwrap();
assert!(
unmatched.get(2).map(|v| v.is_null()).unwrap_or(false),
"Probe col should be NULL"
);
assert!(
unmatched.get(3).map(|v| v.is_null()).unwrap_or(false),
"Probe col should be NULL"
);
}
// `rows_equal` requires the same length and the same value in every column.
#[test]
fn test_rows_equal() {
let row1 = Row::from_values(vec![Value::Integer(1), Value::Text("a".into())]);
let row2 = Row::from_values(vec![Value::Integer(1), Value::Text("a".into())]);
let row3 = Row::from_values(vec![Value::Integer(1), Value::Text("b".into())]);
let row4 = Row::from_values(vec![Value::Integer(1)]);
assert!(rows_equal(&row1, &row2), "Identical rows should be equal");
assert!(
!rows_equal(&row1, &row3),
"Different values should not be equal"
);
assert!(
!rows_equal(&row1, &row4),
"Different lengths should not be equal"
);
}
// For DISTINCT purposes NULL equals NULL in the same position, and NULL
// never equals a non-NULL value.
#[test]
fn test_rows_equal_with_nulls() {
let row_with_null1 = Row::from_values(vec![Value::Integer(1), Value::null_unknown()]);
let row_with_null2 = Row::from_values(vec![Value::Integer(1), Value::null_unknown()]);
let row_no_null = Row::from_values(vec![Value::Integer(1), Value::Integer(2)]);
assert!(
rows_equal(&row_with_null1, &row_with_null2),
"Rows with same NULL positions should be equal"
);
assert!(
!rows_equal(&row_with_null1, &row_no_null),
"NULL should not equal non-NULL"
);
}
}