#[cfg(feature = "parallel")]
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use crate::common::CompactVec;
use crate::core::value::NULL_VALUE;
use crate::core::{Result, Row, RowVec, Value};
use crate::functions::FunctionRegistry;
use crate::parser::ast::Expression;
use super::context::ExecutionContext;
use super::expression::ExpressionEval;
#[cfg(feature = "parallel")]
use super::expression::RowFilter;
use super::utils::{hash_composite_key, hash_row, rows_equal, verify_composite_key_equality};
/// Number of worker threads the parallel operators can spread work across.
///
/// This build has the `parallel` feature enabled, so the answer comes from
/// rayon's current thread pool.
#[cfg(feature = "parallel")]
#[inline]
fn num_threads() -> usize {
    rayon::current_num_threads()
}

/// Number of worker threads the parallel operators can spread work across.
///
/// This build has the `parallel` feature disabled, so everything runs on the
/// calling thread and the answer is always `1`.
#[cfg(not(feature = "parallel"))]
#[inline]
fn num_threads() -> usize {
    1
}
pub use super::operators::hash_join::JoinType;
/// Default minimum number of rows before a filter is run in parallel.
pub const DEFAULT_PARALLEL_FILTER_THRESHOLD: usize = 10_000;
/// Default minimum number of rows before a sort is run in parallel.
pub const DEFAULT_PARALLEL_SORT_THRESHOLD: usize = 50_000;
/// Default minimum number of rows before a hash join is run in parallel.
pub const DEFAULT_PARALLEL_JOIN_THRESHOLD: usize = 10_000;
/// Default number of rows handed to a worker as one chunk.
pub const DEFAULT_PARALLEL_CHUNK_SIZE: usize = 2048;

/// Tuning knobs that decide when an operator switches to its parallel path.
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Master switch; when `false` every `should_parallel_*` check fails.
    pub enabled: bool,
    /// Row count at or above which filters go parallel.
    pub min_rows_for_parallel_filter: usize,
    /// Row count at or above which sorts go parallel.
    pub min_rows_for_parallel_sort: usize,
    /// Row count at or above which hash joins go parallel.
    pub min_rows_for_parallel_join: usize,
    /// Preferred number of rows per work chunk.
    pub chunk_size: usize,
}

impl Default for ParallelConfig {
    /// Parallelism enabled, with every threshold at its `DEFAULT_*` constant.
    fn default() -> Self {
        Self::new(
            true,
            DEFAULT_PARALLEL_FILTER_THRESHOLD,
            DEFAULT_PARALLEL_SORT_THRESHOLD,
            DEFAULT_PARALLEL_JOIN_THRESHOLD,
            DEFAULT_PARALLEL_CHUNK_SIZE,
        )
    }
}

impl ParallelConfig {
    /// Builds a configuration from explicit values for every field.
    pub fn new(
        enabled: bool,
        min_rows_for_parallel_filter: usize,
        min_rows_for_parallel_sort: usize,
        min_rows_for_parallel_join: usize,
        chunk_size: usize,
    ) -> Self {
        Self {
            enabled,
            min_rows_for_parallel_filter,
            min_rows_for_parallel_sort,
            min_rows_for_parallel_join,
            chunk_size,
        }
    }

    /// Default thresholds, but with parallel execution switched off.
    pub fn disabled() -> Self {
        Self::new(
            false,
            DEFAULT_PARALLEL_FILTER_THRESHOLD,
            DEFAULT_PARALLEL_SORT_THRESHOLD,
            DEFAULT_PARALLEL_JOIN_THRESHOLD,
            DEFAULT_PARALLEL_CHUNK_SIZE,
        )
    }

    /// `true` when a filter over `row_count` rows should run in parallel.
    #[inline]
    pub fn should_parallel_filter(&self, row_count: usize) -> bool {
        self.reaches(self.min_rows_for_parallel_filter, row_count)
    }

    /// `true` when a sort over `row_count` rows should run in parallel.
    #[inline]
    pub fn should_parallel_sort(&self, row_count: usize) -> bool {
        self.reaches(self.min_rows_for_parallel_sort, row_count)
    }

    /// `true` when a join with `build_rows` rows should run in parallel.
    #[inline]
    pub fn should_parallel_join(&self, build_rows: usize) -> bool {
        self.reaches(self.min_rows_for_parallel_join, build_rows)
    }

    /// Shared threshold test: parallelism is on and `rows` is at least `threshold`.
    #[inline]
    fn reaches(&self, threshold: usize, rows: usize) -> bool {
        self.enabled && rows >= threshold
    }
}
/// Keeps the rows for which `filter_expr` evaluates to true.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_filter`, the predicate is evaluated across rows
/// with rayon and the surviving rows are then gathered in their original
/// order. Otherwise the work is delegated to `sequential_filter`.
///
/// `_function_registry` is accepted for signature compatibility and is not
/// used by this function.
///
/// # Errors
/// Propagates any error from compiling the filter or evaluating it on a row.
pub fn parallel_filter(
    rows: RowVec,
    filter_expr: &Expression,
    columns: &[String],
    _function_registry: &FunctionRegistry,
    config: &ParallelConfig,
    ctx: &ExecutionContext,
) -> Result<RowVec> {
    #[cfg(feature = "parallel")]
    {
        if config.should_parallel_filter(rows.len()) {
            let column_names: Vec<String> = columns.to_vec();
            let filter = RowFilter::new(filter_expr, &column_names)?.with_context(ctx);
            let entries: Vec<(i64, Row)> = rows.into_vec();

            // Pass 1 (parallel): one verdict per row, in row order.
            let verdicts = entries
                .par_iter()
                .map(|(_, row)| filter.matches_checked(row))
                .collect::<Result<Vec<bool>>>()?;

            // Pass 2 (sequential): move the accepted rows into an exactly
            // sized buffer.
            let survivor_count = verdicts.iter().filter(|&&keep| keep).count();
            let mut survivors = Vec::with_capacity(survivor_count);
            for (entry, keep) in entries.into_iter().zip(verdicts) {
                if keep {
                    survivors.push(entry);
                }
            }
            return Ok(RowVec::from_vec(survivors));
        }
    }
    let _ = config;
    sequential_filter(rows, filter_expr, columns, ctx)
}
/// Single-threaded filter: compiles `filter_expr` once against `columns` and
/// keeps, in input order, every row for which it evaluates to true.
///
/// # Errors
/// Propagates compilation errors and the first row-evaluation error.
fn sequential_filter(
    rows: RowVec,
    filter_expr: &Expression,
    columns: &[String],
    ctx: &ExecutionContext,
) -> Result<RowVec> {
    let column_names: Vec<String> = columns.to_vec();
    let mut evaluator = ExpressionEval::compile(filter_expr, &column_names)?.with_context(ctx);
    let mut kept = RowVec::with_capacity(rows.len());
    for entry in rows {
        let passes = evaluator.eval_bool_checked(&entry.1)?;
        if passes {
            kept.push(entry);
        }
    }
    Ok(kept)
}
/// Keeps the rows accepted by `predicate`, preserving input order.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_filter`, the predicate is evaluated chunk by chunk
/// with rayon into a keep/drop mask, and the owned rows are then moved out
/// according to that mask. Otherwise the rows are filtered on the calling
/// thread.
pub fn parallel_filter_owned(
    rows: RowVec,
    predicate: impl Fn(&Row) -> bool + Sync + Send,
    config: &ParallelConfig,
) -> RowVec {
    #[cfg(feature = "parallel")]
    {
        let total_rows = rows.len();
        if config.should_parallel_filter(total_rows) {
            // At least 1000 rows per chunk.
            let rows_per_chunk = (total_rows / num_threads()).max(1000);
            let mask: Vec<bool> = rows
                .par_chunks(rows_per_chunk)
                .flat_map(|chunk| {
                    let verdicts: Vec<bool> =
                        chunk.iter().map(|entry| predicate(&entry.1)).collect();
                    verdicts
                })
                .collect();

            let accepted = mask.iter().filter(|&&keep| keep).count();
            let mut kept = RowVec::with_capacity(accepted);
            for (entry, keep) in rows.into_iter().zip(mask) {
                if keep {
                    kept.push(entry);
                }
            }
            return kept;
        }
    }
    let _ = config;
    rows.into_iter()
        .filter(|entry| predicate(&entry.1))
        .collect()
}
/// Sorts `rows` in place by `compare`, keeping equal rows in their original
/// relative order (stable sort).
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_sort`, rayon's parallel stable sort is used;
/// otherwise the standard library's stable sort runs on the calling thread.
///
/// Use `parallel_sort_unstable` when the order of equal rows does not matter.
pub fn parallel_sort<F>(rows: &mut [(i64, Row)], compare: F, config: &ParallelConfig)
where
    F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
    #[cfg(feature = "parallel")]
    if config.should_parallel_sort(rows.len()) {
        // Stable on purpose: the unstable variant has its own entry point.
        rows.par_sort_by(|(_, a), (_, b)| compare(a, b));
        return;
    }
    let _ = config;
    rows.sort_by(|(_, a), (_, b)| compare(a, b));
}
/// Sorts `rows` in place by `compare` without preserving the relative order
/// of rows that compare equal.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_sort`, rayon's parallel unstable sort is used;
/// otherwise the standard library's unstable sort runs on the calling thread.
pub fn parallel_sort_unstable<F>(rows: &mut [(i64, Row)], compare: F, config: &ParallelConfig)
where
    F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
    // Adapts the row comparator to the `(id, row)` pairs being sorted.
    let by_row = |lhs: &(i64, Row), rhs: &(i64, Row)| compare(&lhs.1, &rhs.1);

    #[cfg(feature = "parallel")]
    if config.should_parallel_sort(rows.len()) {
        rows.par_sort_unstable_by(by_row);
        return;
    }
    let _ = config;
    rows.sort_unstable_by(by_row);
}
/// Removes duplicate rows, keeping the first occurrence of each distinct row.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_filter`, deduplication runs in two phases: each
/// chunk is deduplicated independently in parallel, then the per-chunk
/// survivors are merged sequentially, in chunk order, through a second hash
/// index. Rows are bucketed by `hash_row` and confirmed with `rows_equal`,
/// so hash collisions never merge different rows. Otherwise the work is
/// delegated to `sequential_distinct`.
pub fn parallel_distinct(rows: RowVec, config: &ParallelConfig) -> RowVec {
    #[cfg(feature = "parallel")]
    {
        let total_rows = rows.len();
        if config.should_parallel_filter(total_rows) {
            let rows_per_chunk = config
                .chunk_size
                .max(total_rows / num_threads())
                .max(1000);
            let owned_rows: Vec<(i64, Row)> = rows.into_vec();

            // Phase 1: chunk-local dedup. Each survivor carries its hash so
            // the merge phase does not have to rehash it.
            let per_chunk: Vec<Vec<(u64, i64, Row)>> = owned_rows
                .into_par_iter()
                .chunks(rows_per_chunk)
                .map(|chunk| {
                    let mut slots_by_hash: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
                    let mut survivors: Vec<(u64, i64, Row)> = Vec::with_capacity(chunk.len());
                    for (row_id, row) in chunk {
                        let hash = hash_row(&row);
                        let slots = slots_by_hash.entry(hash).or_default();
                        let seen_before = slots
                            .iter()
                            .any(|&slot| rows_equal(&survivors[slot].2, &row));
                        if seen_before {
                            continue;
                        }
                        slots.push(survivors.len());
                        survivors.push((hash, row_id, row));
                    }
                    survivors
                })
                .collect();

            // Phase 2: sequential merge across chunks.
            let candidate_count: usize = per_chunk.iter().map(Vec::len).sum();
            let mut distinct = RowVec::with_capacity((candidate_count * 3) / 4);
            let mut slots_by_hash: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
            for (hash, row_id, row) in per_chunk.into_iter().flatten() {
                let slots = slots_by_hash.entry(hash).or_default();
                let seen_before = slots
                    .iter()
                    .any(|&slot| rows_equal(&distinct[slot].1, &row));
                if seen_before {
                    continue;
                }
                slots.push(distinct.len());
                distinct.push((row_id, row));
            }
            return distinct;
        }
    }
    let _ = config;
    sequential_distinct(rows)
}
/// Single-threaded DISTINCT: keeps the first occurrence of every row, in
/// input order.
///
/// Rows are bucketed by `hash_row`; a row is dropped only when `rows_equal`
/// confirms it matches a row already kept in the same bucket.
fn sequential_distinct(rows: RowVec) -> RowVec {
    let mut distinct = RowVec::with_capacity(rows.len());
    // hash -> positions in `distinct` of the rows carrying that hash
    let mut slots_by_hash: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
    for (row_id, row) in rows {
        let slots = slots_by_hash.entry(hash_row(&row)).or_default();
        let seen_before = slots
            .iter()
            .any(|&slot| rows_equal(&distinct[slot].1, &row));
        if seen_before {
            continue;
        }
        slots.push(distinct.len());
        distinct.push((row_id, row));
    }
    distinct
}
/// Applies `project_fn` to every row, keeping each row's id and the input
/// order.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_filter`, rows are projected chunk by chunk with
/// rayon; otherwise they are projected on the calling thread.
pub fn parallel_project<F>(rows: RowVec, project_fn: F, config: &ParallelConfig) -> RowVec
where
    F: Fn(&Row) -> Row + Sync + Send,
{
    #[cfg(feature = "parallel")]
    {
        let total_rows = rows.len();
        if config.should_parallel_filter(total_rows) {
            // At least 1000 rows per chunk.
            let rows_per_chunk = (total_rows / num_threads()).max(1000);
            let projected: Vec<(i64, Row)> = rows
                .par_chunks(rows_per_chunk)
                .flat_map(|chunk| {
                    let mapped: Vec<(i64, Row)> = chunk
                        .iter()
                        .map(|entry| (entry.0, project_fn(&entry.1)))
                        .collect();
                    mapped
                })
                .collect();
            return RowVec::from_vec(projected);
        }
    }
    let _ = config;
    rows.into_iter()
        .map(|(id, row)| {
            let projected = project_fn(&row);
            (id, projected)
        })
        .collect()
}
/// Summary of one filter run, as filled in by `parallel_filter_with_stats`.
#[derive(Clone, Debug, Default)]
pub struct ParallelStats {
/// Number of input rows handed to the filter.
pub rows_processed: usize,
/// Number of rows that survived the filter.
pub rows_passed: usize,
/// Number of chunks attributed to the run.
pub chunks_used: usize,
/// Whether the configuration selected the parallel path for this row count.
pub parallel_used: bool,
}
/// Backing store of a join hash table: key hash -> indices of the build rows
/// carrying that hash.
enum HashTableStorage {
/// Plain hash map filled on a single thread.
Sequential(FxHashMap<u64, Vec<usize>>),
/// Concurrent map filled from several rayon workers (only with the
/// `parallel` feature).
#[cfg(feature = "parallel")]
Parallel(dashmap::DashMap<u64, Vec<usize>>),
}
/// One atomic flag per build row, recording whether that row has found at
/// least one join partner. Flags can be set through a shared reference, so
/// several probing threads may mark rows concurrently.
struct BuildMatchedTracker {
    matched: Vec<AtomicBool>,
}

impl BuildMatchedTracker {
    /// Creates a tracker for `size` build rows, all initially unmatched.
    fn new(size: usize) -> Self {
        let mut matched = Vec::with_capacity(size);
        matched.resize_with(size, || AtomicBool::new(false));
        Self { matched }
    }

    /// Flags build row `idx` as matched. Panics if `idx` is out of range.
    #[inline]
    fn mark_matched(&self, idx: usize) {
        let flag = &self.matched[idx];
        flag.store(true, Ordering::Release);
    }

    /// Reports whether build row `idx` was ever flagged. Panics if `idx` is
    /// out of range.
    #[inline]
    fn was_matched(&self, idx: usize) -> bool {
        let flag = &self.matched[idx];
        flag.load(Ordering::Acquire)
    }
}
impl HashTableStorage {
#[inline]
fn get(&self, key: &u64) -> Option<Vec<usize>> {
match self {
HashTableStorage::Sequential(map) => map.get(key).cloned(),
#[cfg(feature = "parallel")]
HashTableStorage::Parallel(map) => map.get(key).map(|v| v.clone()),
}
}
}
/// Hash table over the build side of a join, as produced by
/// `parallel_hash_build`.
pub struct ParallelHashTable {
// Sequential or concurrent map, depending on how the table was built.
storage: HashTableStorage,
/// Number of build rows that were indexed.
pub row_count: usize,
}
impl ParallelHashTable {
/// Returns an owned copy of the build-row indices stored under the key
/// hash `key`, or `None` when no build row has that hash.
#[inline]
pub fn get(&self, key: &u64) -> Option<Vec<usize>> {
self.storage.get(key)
}
}
/// Builds the join hash table for `build_rows`, keyed by the hash of the
/// columns listed in `key_indices`.
///
/// Each entry maps a key hash to the indices (into `build_rows`) of every row
/// producing that hash. With the `parallel` feature enabled and enough rows
/// to satisfy `config.should_parallel_join`, chunks of rows are inserted
/// concurrently into a `DashMap`; otherwise a plain hash map is filled on the
/// calling thread.
pub fn parallel_hash_build(
    build_rows: &[Row],
    key_indices: &[usize],
    config: &ParallelConfig,
) -> ParallelHashTable {
    let row_count = build_rows.len();

    #[cfg(feature = "parallel")]
    if config.should_parallel_join(row_count) {
        let shared: dashmap::DashMap<u64, Vec<usize>> =
            dashmap::DashMap::with_capacity(row_count);
        let rows_per_chunk = config
            .chunk_size
            .max(row_count / num_threads())
            .max(1000);

        build_rows
            .par_chunks(rows_per_chunk)
            .enumerate()
            .for_each(|(chunk_no, chunk)| {
                // Position of this chunk's first row within `build_rows`.
                let chunk_start = chunk_no * rows_per_chunk;
                for (offset, row) in chunk.iter().enumerate() {
                    debug_assert!(
                        chunk_start.checked_add(offset).is_some(),
                        "Index overflow in parallel hash build: base_idx={} + local_idx={}",
                        chunk_start,
                        offset
                    );
                    let key_hash = hash_composite_key(row, key_indices);
                    shared.entry(key_hash).or_default().push(chunk_start + offset);
                }
            });

        return ParallelHashTable {
            storage: HashTableStorage::Parallel(shared),
            row_count,
        };
    }

    let _ = config;
    let mut table: FxHashMap<u64, Vec<usize>> =
        FxHashMap::with_capacity_and_hasher(row_count, Default::default());
    build_rows.iter().enumerate().for_each(|(row_idx, row)| {
        let key_hash = hash_composite_key(row, key_indices);
        table.entry(key_hash).or_default().push(row_idx);
    });
    ParallelHashTable {
        storage: HashTableStorage::Sequential(table),
        row_count,
    }
}
/// Probes `hash_table` with every row of `probe_rows` and returns the
/// `(probe_index, build_index)` pairs accepted by `verify_match`.
///
/// A probe row's key hash (over `probe_key_indices`) selects candidate build
/// rows; `verify_match(probe_row, build_row)` then decides whether the pair
/// is a real match. With the `parallel` feature enabled and enough probe rows
/// to satisfy `config.should_parallel_join`, probe rows are processed chunk
/// by chunk with rayon; otherwise they are processed on the calling thread.
pub fn parallel_hash_probe<F>(
    probe_rows: &[Row],
    probe_key_indices: &[usize],
    hash_table: &ParallelHashTable,
    build_rows: &[Row],
    verify_match: F,
    config: &ParallelConfig,
) -> Vec<(usize, usize)>
where
    F: Fn(&Row, &Row) -> bool + Sync + Send,
{
    #[cfg(feature = "parallel")]
    {
        let probe_total = probe_rows.len();
        if config.should_parallel_join(probe_total) {
            let rows_per_chunk = config
                .chunk_size
                .max(probe_total / num_threads())
                .max(1000);
            return probe_rows
                .par_chunks(rows_per_chunk)
                .enumerate()
                .flat_map(|(chunk_no, chunk)| {
                    // Position of this chunk's first row within `probe_rows`.
                    let chunk_start = chunk_no * rows_per_chunk;
                    let mut pairs = Vec::new();
                    for (offset, probe_row) in chunk.iter().enumerate() {
                        let key_hash = hash_composite_key(probe_row, probe_key_indices);
                        let Some(candidates) = hash_table.get(&key_hash) else {
                            continue;
                        };
                        for build_idx in candidates {
                            if verify_match(probe_row, &build_rows[build_idx]) {
                                pairs.push((chunk_start + offset, build_idx));
                            }
                        }
                    }
                    pairs
                })
                .collect();
        }
    }

    let _ = config;
    let mut pairs = Vec::new();
    for (probe_idx, probe_row) in probe_rows.iter().enumerate() {
        let key_hash = hash_composite_key(probe_row, probe_key_indices);
        let Some(candidates) = hash_table.get(&key_hash) else {
            continue;
        };
        for build_idx in candidates {
            if verify_match(probe_row, &build_rows[build_idx]) {
                pairs.push((probe_idx, build_idx));
            }
        }
    }
    pairs
}
/// Hashes the columns of `row` selected by `key_indices`.
///
/// Thin public wrapper around `hash_composite_key`, the same hash used when
/// building and probing the join hash table in this module.
#[inline]
pub fn hash_row_by_keys(row: &Row, key_indices: &[usize]) -> u64 {
hash_composite_key(row, key_indices)
}
/// Checks whether the join key of `probe_row` (columns `probe_key_indices`)
/// equals the join key of `build_row` (columns `build_key_indices`).
///
/// Thin public wrapper around `verify_composite_key_equality`; the join code
/// in this module calls it after a hash hit to confirm the match.
#[inline]
pub fn verify_key_match(
probe_row: &Row,
build_row: &Row,
probe_key_indices: &[usize],
build_key_indices: &[usize],
) -> bool {
verify_composite_key_equality(probe_row, build_row, probe_key_indices, build_key_indices)
}
/// Output of `parallel_hash_join`: the joined rows plus execution counters.
pub struct ParallelJoinResult {
/// All rows produced by the join, including NULL-padded rows for outer joins.
pub rows: Vec<Row>,
/// Whether the join selected a parallel probe path.
pub parallel_used: bool,
/// Number of probe-side input rows.
pub probe_rows_processed: usize,
/// Number of build-side input rows.
pub build_rows_count: usize,
/// Set to the total number of rows in `rows`, so it also counts NULL-padded
/// rows, not only key matches.
pub matches_found: usize,
}
/// Single-threaded probe phase of the hash join.
///
/// For every probe row, candidate build rows are looked up by key hash and
/// confirmed with `verify_key_match`; each confirmed pair yields one combined
/// row and, when a tracker is supplied, flags the build row as matched. When
/// the join type needs unmatched probe rows, a probe row without any partner
/// yields one NULL-padded row at the position where it was processed.
///
/// Returns `(rows, unmatched_probe_rows)`; the second vector is always empty
/// here because unmatched probe rows are emitted inline into the first.
#[allow(clippy::too_many_arguments)]
fn sequential_probe(
    probe_rows: &[Row],
    build_rows: &[Row],
    hash_table: &ParallelHashTable,
    probe_key_indices: &[usize],
    build_key_indices: &[usize],
    join_type: &JoinType,
    probe_col_count: usize,
    build_col_count: usize,
    swapped: bool,
    build_matched: &Option<BuildMatchedTracker>,
) -> (Vec<Row>, Vec<Row>) {
    let keep_unmatched_probe = join_type.needs_unmatched_probe(swapped);
    let mut output = Vec::new();

    for probe_row in probe_rows {
        let key_hash = hash_row_by_keys(probe_row, probe_key_indices);
        let mut found_partner = false;

        // A missing bucket behaves like an empty candidate list.
        for build_idx in hash_table.get(&key_hash).unwrap_or_default() {
            let build_row = &build_rows[build_idx];
            if !verify_key_match(probe_row, build_row, probe_key_indices, build_key_indices) {
                continue;
            }
            found_partner = true;
            if let Some(tracker) = build_matched.as_ref() {
                tracker.mark_matched(build_idx);
            }
            let joined = combine_join_rows(
                probe_row,
                build_row,
                probe_col_count,
                build_col_count,
                swapped,
            );
            output.push(Row::from_compact_vec(joined));
        }

        if keep_unmatched_probe && !found_partner {
            let padded = combine_with_nulls(probe_row, probe_col_count, build_col_count, swapped);
            output.push(Row::from_compact_vec(padded));
        }
    }

    (output, Vec::new())
}
/// Hash-joins `probe_rows` against `build_rows` on the given key columns.
///
/// Steps, in order:
/// 1. Build a hash table over `build_rows` with `parallel_hash_build`.
/// 2. Probe it with every probe row. With the `parallel` feature and either
///    side reaching the join threshold, probing runs chunk by chunk with
///    rayon (a dedicated fast path for `JoinType::Inner`, a general path for
///    the other join types); otherwise `sequential_probe` is used.
/// 3. Append unmatched probe rows collected by the parallel general path.
/// 4. When the join type needs unmatched build rows, append one NULL-padded
///    row per build row that never matched.
///
/// `swapped` is forwarded to the row-combining helpers and to the
/// `needs_unmatched_*` checks on `join_type`.
///
/// `matches_found` in the result is the total number of output rows.
#[allow(clippy::too_many_arguments)]
pub fn parallel_hash_join(
probe_rows: &[Row],
build_rows: &[Row],
probe_key_indices: &[usize],
build_key_indices: &[usize],
join_type: JoinType,
probe_col_count: usize,
build_col_count: usize,
swapped: bool,
config: &ParallelConfig,
) -> ParallelJoinResult {
let probe_count = probe_rows.len();
let build_count = build_rows.len();
// Go parallel when either input is large enough; never without the feature.
#[cfg(feature = "parallel")]
let use_parallel =
config.should_parallel_join(build_count) || config.should_parallel_join(probe_count);
#[cfg(not(feature = "parallel"))]
let use_parallel = false;
let hash_table = parallel_hash_build(build_rows, build_key_indices, config);
// Only allocate per-build-row flags when unmatched build rows must be emitted.
let build_matched: Option<BuildMatchedTracker> = if join_type.needs_unmatched_build(swapped) {
Some(BuildMatchedTracker::new(build_count))
} else {
None
};
#[allow(unused_variables)]
let (matched_rows, unmatched_probe_rows) = {
#[cfg(feature = "parallel")]
{
if use_parallel && join_type == JoinType::Inner {
// Inner-join fast path: only matched pairs are produced, so no
// matched/unmatched bookkeeping is done here.
let matches: Vec<Row> = probe_rows
.par_chunks(config.chunk_size.max(1000))
.flat_map(|chunk| {
let mut local_results = Vec::new();
for probe_row in chunk {
let hash = hash_row_by_keys(probe_row, probe_key_indices);
if let Some(build_indices) = hash_table.get(&hash) {
for build_idx in build_indices {
let build_row = &build_rows[build_idx];
// Hash hit is only a candidate; confirm the keys really match.
if verify_key_match(
probe_row,
build_row,
probe_key_indices,
build_key_indices,
) {
let combined = combine_join_rows(
probe_row,
build_row,
probe_col_count,
build_col_count,
swapped,
);
local_results.push(Row::from_compact_vec(combined));
}
}
}
}
local_results
})
.collect();
(matches, Vec::new())
} else if use_parallel {
// General parallel path: each chunk returns its matched rows and its
// NULL-padded unmatched probe rows separately.
let needs_unmatched_probe = join_type.needs_unmatched_probe(swapped);
let chunk_results: Vec<(Vec<Row>, Vec<Row>)> = probe_rows
.par_chunks(config.chunk_size.max(1000))
.map(|chunk| {
let mut matched_results = Vec::new();
let mut unmatched_results = Vec::new();
for probe_row in chunk.iter() {
let mut matched = false;
let hash = hash_row_by_keys(probe_row, probe_key_indices);
if let Some(build_indices) = hash_table.get(&hash) {
for build_idx in build_indices {
let build_row = &build_rows[build_idx];
if verify_key_match(
probe_row,
build_row,
probe_key_indices,
build_key_indices,
) {
matched = true;
// Flag the build row so it is not emitted as unmatched later.
if let Some(ref tracker) = build_matched {
tracker.mark_matched(build_idx);
}
let combined = combine_join_rows(
probe_row,
build_row,
probe_col_count,
build_col_count,
swapped,
);
matched_results.push(Row::from_compact_vec(combined));
}
}
}
if !matched && needs_unmatched_probe {
let values = combine_with_nulls(
probe_row,
probe_col_count,
build_col_count,
swapped,
);
unmatched_results.push(Row::from_compact_vec(values));
}
}
(matched_results, unmatched_results)
})
.collect();
// Size the merged vectors exactly before concatenating the chunk results.
let total_matched: usize = chunk_results.iter().map(|(m, _)| m.len()).sum();
let total_unmatched: usize = chunk_results.iter().map(|(_, u)| u.len()).sum();
let mut matched_rows = Vec::with_capacity(total_matched);
let mut unmatched_rows = Vec::with_capacity(total_unmatched);
for (matched, unmatched) in chunk_results {
matched_rows.extend(matched);
unmatched_rows.extend(unmatched);
}
// Acquire fence before the tracker flags are read further down.
std::sync::atomic::fence(Ordering::Acquire);
(matched_rows, unmatched_rows)
} else {
sequential_probe(
probe_rows,
build_rows,
&hash_table,
probe_key_indices,
build_key_indices,
&join_type,
probe_col_count,
build_col_count,
swapped,
&build_matched,
)
}
}
#[cfg(not(feature = "parallel"))]
{
sequential_probe(
probe_rows,
build_rows,
&hash_table,
probe_key_indices,
build_key_indices,
&join_type,
probe_col_count,
build_col_count,
swapped,
&build_matched,
)
}
};
// Output order: matched rows, then separately collected unmatched probe rows,
// then unmatched build rows.
let mut result_rows = matched_rows;
result_rows.extend(unmatched_probe_rows);
if let Some(ref tracker) = build_matched {
for (build_idx, build_row) in build_rows.iter().enumerate() {
if !tracker.was_matched(build_idx) {
let values =
combine_build_with_nulls(build_row, build_col_count, probe_col_count, swapped);
result_rows.push(Row::from_compact_vec(values));
}
}
}
// Counts every output row, including NULL-padded ones.
let matches_found = result_rows.len();
ParallelJoinResult {
rows: result_rows,
parallel_used: use_parallel,
probe_rows_processed: probe_count,
build_rows_count: build_count,
matches_found,
}
}
/// Concatenates a matched probe row and build row into one output row.
///
/// Exactly `probe_col_count` values are taken from `probe_row` and
/// `build_col_count` from `build_row`; a column missing from a row is filled
/// with `NULL_VALUE`. The probe columns come first unless `swapped` is set,
/// in which case the build columns come first.
#[inline]
fn combine_join_rows(
    probe_row: &Row,
    build_row: &Row,
    probe_col_count: usize,
    build_col_count: usize,
    swapped: bool,
) -> CompactVec<Value> {
    let mut out: CompactVec<Value> =
        CompactVec::with_capacity(probe_col_count + build_col_count);

    // Decide which side is laid out first, then copy both sides the same way.
    let layout = if swapped {
        [(build_row, build_col_count), (probe_row, probe_col_count)]
    } else {
        [(probe_row, probe_col_count), (build_row, build_col_count)]
    };
    for (row, width) in layout {
        for col in 0..width {
            out.push(row.get(col).cloned().unwrap_or(NULL_VALUE));
        }
    }
    out
}
/// Builds the output row for a probe row that found no join partner.
///
/// The probe side contributes `probe_col_count` values (missing columns
/// become `NULL_VALUE`); the build side contributes `build_col_count`
/// `NULL_VALUE`s. The probe columns come first unless `swapped` is set.
#[inline]
fn combine_with_nulls(
    probe_row: &Row,
    probe_col_count: usize,
    build_col_count: usize,
    swapped: bool,
) -> CompactVec<Value> {
    let mut out: CompactVec<Value> =
        CompactVec::with_capacity(probe_col_count + build_col_count);
    let null_padding = std::iter::repeat_n(NULL_VALUE, build_col_count);
    let copy_probe = |dst: &mut CompactVec<Value>| {
        for col in 0..probe_col_count {
            dst.push(probe_row.get(col).cloned().unwrap_or(NULL_VALUE));
        }
    };

    if swapped {
        out.extend(null_padding);
        copy_probe(&mut out);
    } else {
        copy_probe(&mut out);
        out.extend(null_padding);
    }
    out
}
/// Builds the output row for a build row that found no join partner.
///
/// The build side contributes `build_col_count` values (missing columns
/// become `NULL_VALUE`); the probe side contributes `probe_col_count`
/// `NULL_VALUE`s. The probe (NULL) columns come first unless `swapped` is
/// set, in which case the build columns come first.
#[inline]
fn combine_build_with_nulls(
    build_row: &Row,
    build_col_count: usize,
    probe_col_count: usize,
    swapped: bool,
) -> CompactVec<Value> {
    let mut out: CompactVec<Value> =
        CompactVec::with_capacity(probe_col_count + build_col_count);
    let null_padding = std::iter::repeat_n(NULL_VALUE, probe_col_count);
    let copy_build = |dst: &mut CompactVec<Value>| {
        for col in 0..build_col_count {
            dst.push(build_row.get(col).cloned().unwrap_or(NULL_VALUE));
        }
    };

    if swapped {
        copy_build(&mut out);
        out.extend(null_padding);
    } else {
        out.extend(null_padding);
        copy_build(&mut out);
    }
    out
}
/// Direction of one ORDER BY key.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortDirection {
/// Keep the comparison result as computed.
Ascending,
/// Reverse the comparison result.
Descending,
}
/// One ORDER BY key, as consumed by `parallel_order_by`.
#[derive(Clone, Debug)]
pub struct SortSpec {
/// Index of the column, within each row, to compare.
pub column_index: usize,
/// Whether the comparison result is kept or reversed.
pub direction: SortDirection,
/// Whether NULL (or missing) values compare before non-NULL values, prior to
/// any reversal applied for `direction`.
pub nulls_first: bool,
}
/// Sorts `rows` in place according to `sort_specs`, evaluated left to right.
///
/// For each spec the selected column is compared; the first spec that yields
/// a non-equal ordering decides. A missing column is treated like NULL. Two
/// NULLs compare equal; a NULL against a value is ordered by `nulls_first`.
/// Values that cannot be compared (`partial_cmp` returns `None`) are treated
/// as equal.
///
/// The sort is stable, so rows that tie on every key keep their input order.
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_sort`, rayon's parallel stable sort is used.
///
/// NOTE(review): the `Descending` reversal is applied after the NULL
/// placement, so for descending keys `nulls_first: true` ends up placing
/// NULLs last (and vice versa). Kept as is; confirm callers expect this.
pub fn parallel_order_by(
    rows: &mut [(i64, Row)],
    sort_specs: &[SortSpec],
    config: &ParallelConfig,
) {
    let compare = |(_, a): &(i64, Row), (_, b): &(i64, Row)| -> std::cmp::Ordering {
        for spec in sort_specs {
            let a_val = a.get(spec.column_index);
            let b_val = b.get(spec.column_index);
            // A column missing from the row counts as NULL.
            let a_is_null = a_val.map(|v| v.is_null()).unwrap_or(true);
            let b_is_null = b_val.map(|v| v.is_null()).unwrap_or(true);
            let ordering = match (a_is_null, b_is_null) {
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => {
                    if spec.nulls_first {
                        std::cmp::Ordering::Less
                    } else {
                        std::cmp::Ordering::Greater
                    }
                }
                (false, true) => {
                    if spec.nulls_first {
                        std::cmp::Ordering::Greater
                    } else {
                        std::cmp::Ordering::Less
                    }
                }
                (false, false) => {
                    // Both are present and non-NULL, so the unwraps cannot fail.
                    let a_v = a_val.unwrap();
                    let b_v = b_val.unwrap();
                    a_v.partial_cmp(b_v).unwrap_or(std::cmp::Ordering::Equal)
                }
            };
            let ordering = if spec.direction == SortDirection::Descending {
                ordering.reverse()
            } else {
                ordering
            };
            if ordering != std::cmp::Ordering::Equal {
                return ordering;
            }
        }
        std::cmp::Ordering::Equal
    };

    #[cfg(feature = "parallel")]
    if config.should_parallel_sort(rows.len()) {
        // Stable, so rows tying on every key come back in input order.
        rows.par_sort_by(compare);
        return;
    }
    let _ = config;
    rows.sort_by(compare);
}
/// Sorts `rows` in place with a caller-supplied row comparator, keeping rows
/// that compare equal in their input order (stable sort).
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_sort`, rayon's parallel stable sort is used;
/// otherwise the standard library's stable sort runs on the calling thread.
///
/// Use `parallel_order_by_unstable` when the order of equal rows does not
/// matter.
pub fn parallel_order_by_fn<F>(rows: &mut [(i64, Row)], compare: F, config: &ParallelConfig)
where
    F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
    #[cfg(feature = "parallel")]
    if config.should_parallel_sort(rows.len()) {
        // Stable on purpose: the unstable variant has its own entry point.
        rows.par_sort_by(|(_, a), (_, b)| compare(a, b));
        return;
    }
    let _ = config;
    rows.sort_by(|(_, a), (_, b)| compare(a, b));
}
/// Sorts `rows` in place with a caller-supplied row comparator, without
/// preserving the relative order of rows that compare equal.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_sort`, rayon's parallel unstable sort is used;
/// otherwise the standard library's unstable sort runs on the calling thread.
pub fn parallel_order_by_unstable<F>(rows: &mut [(i64, Row)], compare: F, config: &ParallelConfig)
where
    F: Fn(&Row, &Row) -> std::cmp::Ordering + Sync + Send,
{
    // Adapts the row comparator to the `(id, row)` pairs being sorted.
    let by_row = |lhs: &(i64, Row), rhs: &(i64, Row)| compare(&lhs.1, &rhs.1);

    #[cfg(feature = "parallel")]
    if config.should_parallel_sort(rows.len()) {
        rows.par_sort_unstable_by(by_row);
        return;
    }
    let _ = config;
    rows.sort_unstable_by(by_row);
}
/// Runs `parallel_filter` and reports counters describing the run.
///
/// * `rows_processed` — number of input rows.
/// * `rows_passed` — number of rows that survived the filter.
/// * `parallel_used` — whether `config` selects the parallel path for this
///   row count.
/// * `chunks_used` — for the parallel path, the row count divided (rounding
///   up) by a nominal chunk size derived from `config.chunk_size` and the
///   thread count; for the sequential path, always `1`.
///
/// The sequential case is computed without dividing by the row count, so an
/// empty input no longer triggers a division-by-zero panic.
///
/// NOTE(review): `parallel_used` reflects the configuration only;
/// `parallel_filter` runs sequentially regardless when the `parallel` feature
/// is compiled out.
///
/// # Errors
/// Propagates any error returned by `parallel_filter`.
pub fn parallel_filter_with_stats(
    rows: RowVec,
    filter_expr: &Expression,
    columns: &[String],
    function_registry: &FunctionRegistry,
    config: &ParallelConfig,
    ctx: &ExecutionContext,
) -> Result<(RowVec, ParallelStats)> {
    let row_count = rows.len();
    let parallel_used = config.should_parallel_filter(row_count);

    let chunks_used = if parallel_used {
        let n_threads = num_threads();
        // Clamped to at least 512, so the divisor is never zero.
        let chunk_size = config.chunk_size.max(row_count / (n_threads * 4)).max(512);
        row_count.div_ceil(chunk_size)
    } else {
        // The sequential path is a single pass over the input, empty or not.
        1
    };

    let result = parallel_filter(rows, filter_expr, columns, function_registry, config, ctx)?;
    let stats = ParallelStats {
        rows_processed: row_count,
        rows_passed: result.len(),
        chunks_used,
        parallel_used,
    };
    Ok((result, stats))
}
/// Distance function used by `parallel_topn_vector_search`.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
/// Scored with `l2_distance_bytes`.
L2,
/// Scored with `cosine_distance_bytes`.
Cosine,
/// Scored with `ip_distance_bytes`.
InnerProduct,
}
/// Returns the `k` rows whose vector column is closest to `query_bytes`,
/// sorted by ascending distance, as `(row_id, row, distance)` triples.
///
/// The vector is read from column `vector_col_idx`, which must hold a
/// `Value::Extension` payload whose first byte is the `DataType::Vector` tag;
/// the remaining bytes are passed to the distance function chosen by
/// `metric`. Rows whose column is missing, is not a vector, or whose payload
/// length differs from `query_bytes` get a distance of `f64::INFINITY`.
///
/// With the `parallel` feature enabled and enough rows to satisfy
/// `config.should_parallel_filter`, each chunk selects its own top `k` in
/// parallel and the per-chunk winners are merged, sorted and truncated to
/// `k`. Otherwise a single top-`k` pass runs on the calling thread.
///
/// Returns an empty vector when `k == 0` or `rows` is empty.
pub fn parallel_topn_vector_search(
    rows: RowVec,
    vector_col_idx: usize,
    query_bytes: &[u8],
    k: usize,
    metric: DistanceMetric,
    config: &ParallelConfig,
) -> Vec<(i64, Row, f64)> {
    use std::collections::BinaryHeap;

    if k == 0 || rows.is_empty() {
        return Vec::new();
    }

    let distance_fn: fn(&[u8], &[u8]) -> f64 = match metric {
        DistanceMetric::L2 => crate::functions::scalar::vector::l2_distance_bytes,
        DistanceMetric::Cosine => crate::functions::scalar::vector::cosine_distance_bytes,
        DistanceMetric::InnerProduct => crate::functions::scalar::vector::ip_distance_bytes,
    };

    /// Extracts the raw vector payload (without the leading type tag) from
    /// column `col_idx`, or `None` when the column is absent or not a vector.
    #[inline]
    fn get_vector_bytes(row: &Row, col_idx: usize) -> Option<&[u8]> {
        match row.get(col_idx)? {
            Value::Extension(data)
                if data.first() == Some(&(crate::core::DataType::Vector as u8)) =>
            {
                Some(&data[1..])
            }
            _ => None,
        }
    }

    /// Max-heap entry: the heap's top is the current worst (largest) distance;
    /// `idx` points at the slot in the side buffer holding the row itself.
    struct HeapEntry {
        distance: f64,
        idx: usize,
    }
    impl PartialEq for HeapEntry {
        fn eq(&self, other: &Self) -> bool {
            self.distance == other.distance
        }
    }
    impl Eq for HeapEntry {}
    impl PartialOrd for HeapEntry {
        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }
    impl Ord for HeapEntry {
        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
            self.distance.total_cmp(&other.distance)
        }
    }

    /// Scores every candidate and keeps the `k` with the smallest distance.
    /// The result is not sorted. Shared by the parallel (per-chunk) and
    /// sequential paths so both apply exactly the same selection rule.
    fn select_top_k(
        candidates: Vec<(i64, Row)>,
        vector_col_idx: usize,
        query_bytes: &[u8],
        k: usize,
        distance_fn: fn(&[u8], &[u8]) -> f64,
    ) -> Vec<(i64, Row, f64)> {
        // Never more than min(k, candidates) slots are used; bounding the
        // capacity avoids huge allocations (and `k + 1` overflow) for large k.
        let capacity = k.min(candidates.len()).saturating_add(1);
        let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(capacity);
        let mut entries: Vec<(i64, Row, f64)> = Vec::with_capacity(capacity);

        for (row_id, row) in candidates {
            let dist = match get_vector_bytes(&row, vector_col_idx) {
                Some(vec_bytes) if vec_bytes.len() == query_bytes.len() => {
                    distance_fn(vec_bytes, query_bytes)
                }
                // Missing, non-vector or wrong-length payloads rank last.
                _ => f64::INFINITY,
            };

            if entries.len() < k {
                let idx = entries.len();
                entries.push((row_id, row, dist));
                heap.push(HeapEntry {
                    distance: dist,
                    idx,
                });
            } else if let Some(worst) = heap.peek() {
                if dist < worst.distance {
                    // Reuse the evicted row's slot for the better candidate.
                    let evict_idx = worst.idx;
                    heap.pop();
                    entries[evict_idx] = (row_id, row, dist);
                    heap.push(HeapEntry {
                        distance: dist,
                        idx: evict_idx,
                    });
                }
            }
        }

        let live_indices: FxHashSet<usize> = heap.into_iter().map(|e| e.idx).collect();
        entries
            .into_iter()
            .enumerate()
            .filter_map(|(i, e)| {
                if live_indices.contains(&i) {
                    Some(e)
                } else {
                    None
                }
            })
            .collect()
    }

    let row_vec: Vec<(i64, Row)> = rows.into_vec();

    #[cfg(feature = "parallel")]
    if config.should_parallel_filter(row_vec.len()) {
        let n_threads = num_threads();
        let chunk_size = (row_vec.len() / (n_threads * 4)).max(1024);
        let chunk_results: Vec<Vec<(i64, Row, f64)>> = row_vec
            .into_par_iter()
            .chunks(chunk_size)
            .map(|chunk| select_top_k(chunk, vector_col_idx, query_bytes, k, distance_fn))
            .collect();
        // Merge the per-chunk winners and keep the global best k.
        let mut merged: Vec<(i64, Row, f64)> = chunk_results.into_iter().flatten().collect();
        merged.sort_unstable_by(|a, b| a.2.total_cmp(&b.2));
        merged.truncate(k);
        return merged;
    }
    let _ = config;

    let mut result = select_top_k(row_vec, vector_col_idx, query_bytes, k, distance_fn);
    result.sort_unstable_by(|a, b| a.2.total_cmp(&b.2));
    result
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::Value;
fn make_test_rows(count: usize) -> RowVec {
(0..count)
.map(|i| {
(
i as i64,
Row::from_values(vec![
Value::Integer(i as i64),
Value::Integer(i as i64 % 10),
]),
)
})
.collect()
}
#[test]
fn test_parallel_distinct() {
let mut rows = RowVec::new();
for i in 0i64..1000 {
rows.push((i, Row::from_values(vec![Value::Integer(i % 100)])));
}
let config = ParallelConfig {
min_rows_for_parallel_filter: 100, ..Default::default()
};
let result = parallel_distinct(rows, &config);
assert_eq!(result.len(), 100);
}
#[test]
fn test_parallel_sort() {
let mut rows: RowVec = (0..1000)
.rev()
.map(|i| (i, Row::from_values(vec![Value::Integer(i)])))
.collect();
let config = ParallelConfig {
min_rows_for_parallel_sort: 100,
..Default::default()
};
parallel_sort(
&mut rows,
|a, b| {
let a_val = a.get(0).and_then(|v| v.as_int64()).unwrap_or(0);
let b_val = b.get(0).and_then(|v| v.as_int64()).unwrap_or(0);
a_val.cmp(&b_val)
},
&config,
);
for (i, (_, row)) in rows.iter().enumerate() {
assert_eq!(row.get(0), Some(&Value::Integer(i as i64)));
}
}
#[test]
fn test_parallel_config_thresholds() {
let config = ParallelConfig::default();
assert!(!config.should_parallel_filter(1000)); assert!(config.should_parallel_filter(20_000));
assert!(!config.should_parallel_sort(10_000)); assert!(config.should_parallel_sort(100_000));
let disabled = ParallelConfig::disabled();
assert!(!disabled.should_parallel_filter(1_000_000)); }
#[test]
fn test_parallel_filter_owned() {
let rows = make_test_rows(50_000);
let config = ParallelConfig::default();
let result = parallel_filter_owned(
rows,
|row| {
if let Some(Value::Integer(v)) = row.get(1) {
*v < 5
} else {
false
}
},
&config,
);
assert_eq!(result.len(), 25_000);
}
#[test]
fn test_sequential_fallback_small_dataset() {
let rows = make_test_rows(100); let config = ParallelConfig::default();
let result = parallel_filter_owned(
rows,
|row| {
if let Some(Value::Integer(v)) = row.get(1) {
*v < 5
} else {
false
}
},
&config,
);
assert_eq!(result.len(), 50);
}
#[test]
fn test_parallel_hash_build() {
let build_rows: Vec<Row> = (0..10_000)
.map(|i| {
Row::from_values(vec![
Value::Integer(i),
Value::Text(format!("build_{}", i).into()),
])
})
.collect();
let config = ParallelConfig {
min_rows_for_parallel_join: 1000,
..Default::default()
};
let hash_table = parallel_hash_build(&build_rows, &[0], &config);
assert_eq!(hash_table.row_count, 10_000);
let test_hash = hash_row_by_keys(&build_rows[500], &[0]);
assert!(hash_table.get(&test_hash).is_some());
}
#[test]
fn test_parallel_hash_probe() {
let build_rows: Vec<Row> = (0..5_000)
.map(|i| {
Row::from_values(vec![
Value::Integer(i),
Value::Text(format!("build_{}", i).into()),
])
})
.collect();
let probe_rows: Vec<Row> = (0..10_000)
.map(|i| {
Row::from_values(vec![
Value::Integer(i * 2 % 5_000), Value::Text(format!("probe_{}", i).into()),
])
})
.collect();
let config = ParallelConfig {
min_rows_for_parallel_join: 1000,
..Default::default()
};
let hash_table = parallel_hash_build(&build_rows, &[0], &config);
let matches = parallel_hash_probe(
&probe_rows,
&[0],
&hash_table,
&build_rows,
|probe, build| {
probe.get(0) == build.get(0)
},
&config,
);
assert_eq!(matches.len(), 10_000);
}
#[test]
fn test_verify_key_match() {
let row1 = Row::from_values(vec![Value::Integer(1), Value::Text("a".to_string().into())]);
let row2 = Row::from_values(vec![Value::Integer(1), Value::Text("b".to_string().into())]);
let row3 = Row::from_values(vec![Value::Integer(2), Value::Text("a".to_string().into())]);
assert!(verify_key_match(&row1, &row2, &[0], &[0]));
assert!(!verify_key_match(&row1, &row3, &[0], &[0]));
assert!(verify_key_match(&row1, &row3, &[1], &[1]));
}
#[test]
fn test_parallel_order_by() {
let mut rows: RowVec = (0..1000)
.map(|i| {
(
i,
Row::from_values(vec![
Value::Integer((i * 7 + 13) % 1000), Value::Text(format!("row_{}", i).into()),
]),
)
})
.collect();
let config = ParallelConfig {
min_rows_for_parallel_sort: 100,
..Default::default()
};
let sort_specs = vec![SortSpec {
column_index: 0,
direction: SortDirection::Ascending,
nulls_first: false,
}];
parallel_order_by(&mut rows, &sort_specs, &config);
for i in 1..rows.len() {
let prev = rows[i - 1].1.get(0).and_then(|v| v.as_int64()).unwrap();
let curr = rows[i].1.get(0).and_then(|v| v.as_int64()).unwrap();
assert!(prev <= curr, "Row {} should be <= row {}", i - 1, i);
}
}
#[test]
fn test_parallel_order_by_descending() {
let mut rows: RowVec = (0..500)
.map(|i| (i, Row::from_values(vec![Value::Integer(i)])))
.collect();
let config = ParallelConfig {
min_rows_for_parallel_sort: 100,
..Default::default()
};
let sort_specs = vec![SortSpec {
column_index: 0,
direction: SortDirection::Descending,
nulls_first: false,
}];
parallel_order_by(&mut rows, &sort_specs, &config);
for i in 1..rows.len() {
let prev = rows[i - 1].1.get(0).and_then(|v| v.as_int64()).unwrap();
let curr = rows[i].1.get(0).and_then(|v| v.as_int64()).unwrap();
assert!(prev >= curr, "Row {} should be >= row {}", i - 1, i);
}
}
#[test]
fn test_parallel_order_by_with_nulls() {
let mut rows = RowVec::new();
rows.push((0, Row::from_values(vec![Value::Integer(3)])));
rows.push((1, Row::from_values(vec![Value::null_unknown()])));
rows.push((2, Row::from_values(vec![Value::Integer(1)])));
rows.push((3, Row::from_values(vec![Value::null_unknown()])));
rows.push((4, Row::from_values(vec![Value::Integer(2)])));
let config = ParallelConfig {
min_rows_for_parallel_sort: 1, ..Default::default()
};
let sort_specs = vec![SortSpec {
column_index: 0,
direction: SortDirection::Ascending,
nulls_first: true,
}];
parallel_order_by(&mut rows, &sort_specs, &config);
assert!(rows[0].1.get(0).map(|v| v.is_null()).unwrap_or(false));
assert!(rows[1].1.get(0).map(|v| v.is_null()).unwrap_or(false));
assert_eq!(rows[2].1.get(0), Some(&Value::Integer(1)));
assert_eq!(rows[3].1.get(0), Some(&Value::Integer(2)));
assert_eq!(rows[4].1.get(0), Some(&Value::Integer(3)));
}
/// Runs `parallel_distinct` over five two-column rows — (1,'a'), (1,'b'),
/// (1,'a'), (2,'a'), (2,'a') — and expects exactly the three distinct
/// combinations to remain.
///
/// The rows share individual column values with one another, so the
/// expectation only holds when whole rows are compared.
/// `min_rows_for_parallel_filter` is set to 1 (presumably to take the
/// parallel path — confirm in `parallel_distinct`).
#[test]
fn test_distinct_hash_collision_handling() {
    let pairs = vec![(1, "a"), (1, "b"), (1, "a"), (2, "a"), (2, "a")];
    let mut rows = RowVec::new();
    for (id, (num, text)) in (0..).zip(pairs) {
        rows.push((
            id,
            Row::from_values(vec![Value::Integer(num), Value::Text(text.into())]),
        ));
    }
    let config = ParallelConfig {
        min_rows_for_parallel_filter: 1,
        ..Default::default()
    };

    let result = parallel_distinct(rows, &config);
    assert_eq!(result.len(), 3, "Should have 3 unique rows");

    // True when some surviving row holds exactly (`num`, `text`) in its
    // first two columns.
    let contains = |num, text: &'static str| {
        let want_num = Value::Integer(num);
        let want_text = Value::Text(text.into());
        result
            .iter()
            .any(|(_, r)| r.get(0) == Some(&want_num) && r.get(1) == Some(&want_text))
    };
    assert!(contains(1, "a"), "Should contain (1, 'a')");
    assert!(contains(1, "b"), "Should contain (1, 'b')");
    assert!(contains(2, "a"), "Should contain (2, 'a')");
}
/// Runs `parallel_distinct` over the single-column values 100, 200, 100, 300
/// and expects three rows back.
///
/// `min_rows_for_parallel_filter` is 10000, far above the four input rows
/// (presumably forcing the sequential path — confirm in `parallel_distinct`).
#[test]
fn test_sequential_distinct_hash_collision() {
    let mut rows = RowVec::new();
    for (id, n) in (0..).zip(vec![100, 200, 100, 300]) {
        rows.push((id, Row::from_values(vec![Value::Integer(n)])));
    }
    let config = ParallelConfig {
        min_rows_for_parallel_filter: 10_000,
        ..Default::default()
    };

    let distinct_rows = parallel_distinct(rows, &config);

    assert_eq!(
        distinct_rows.len(),
        3,
        "Should have 3 unique values: 100, 200, 300"
    );
}
/// Inner-joins three probe rows (keys 1, 2, 4) against three build rows
/// (keys 1, 2, 3) on column 0 and expects two output rows, each four columns
/// wide.
///
/// `min_rows_for_parallel_join` is set to 1 (presumably to take the parallel
/// path). The two `2` arguments are presumably the probe/build column counts
/// — confirm against the `parallel_hash_join` signature.
#[test]
fn test_parallel_hash_join_collision_handling() {
    // Builds a two-column row: integer key followed by a text label.
    let row_of = |key, label: &'static str| {
        Row::from_values(vec![Value::Integer(key), Value::Text(label.into())])
    };
    let build_rows: Vec<Row> = vec![
        row_of(1, "build_a"),
        row_of(2, "build_b"),
        row_of(3, "build_c"),
    ];
    let probe_rows: Vec<Row> = vec![
        row_of(1, "probe_x"),
        row_of(2, "probe_y"),
        row_of(4, "probe_z"),
    ];
    let config = ParallelConfig {
        min_rows_for_parallel_join: 1,
        ..Default::default()
    };

    let result = parallel_hash_join(
        &probe_rows,
        &build_rows,
        &[0],
        &[0],
        JoinType::Inner,
        2,
        2,
        false,
        &config,
    );

    assert_eq!(result.rows.len(), 2, "INNER JOIN should have 2 matches");
    result.rows.iter().for_each(|row| {
        assert_eq!(row.len(), 4);
    });
}
/// Left-joins three probe rows (keys 1, 2, 3) against a single build row
/// (key 1) and expects three output rows, two of which carry NULL in
/// columns 2 and 3 (the build-side columns, per the assertion message).
///
/// `min_rows_for_parallel_join` is set to 1 (presumably to take the parallel
/// path — confirm in `parallel_hash_join`).
#[test]
fn test_parallel_left_join_unmatched() {
    // Builds a two-column row: integer key followed by a text label.
    let row_of = |key, label: &'static str| {
        Row::from_values(vec![Value::Integer(key), Value::Text(label.into())])
    };
    let build_rows: Vec<Row> = vec![row_of(1, "match")];
    let probe_rows: Vec<Row> = vec![row_of(1, "p1"), row_of(2, "p2"), row_of(3, "p3")];
    let config = ParallelConfig {
        min_rows_for_parallel_join: 1,
        ..Default::default()
    };

    let result = parallel_hash_join(
        &probe_rows,
        &build_rows,
        &[0],
        &[0],
        JoinType::Left,
        2,
        2,
        false,
        &config,
    );
    assert_eq!(result.rows.len(), 3, "LEFT JOIN should have 3 rows");

    // Count rows whose columns 2 and 3 are both present and NULL.
    let null_count = result
        .rows
        .iter()
        .filter(|row| {
            [2, 3]
                .iter()
                .all(|&col| matches!(row.get(col), Some(v) if v.is_null()))
        })
        .count();
    assert_eq!(
        null_count, 2,
        "Should have 2 unmatched rows with NULL build columns"
    );
}
/// Right-joins a single probe row (key 1) against three build rows
/// (keys 1, 2, 3) and expects three output rows, two of which carry NULL in
/// columns 0 and 1 (the probe-side columns, per the assertion message).
///
/// `min_rows_for_parallel_join` is set to 1 (presumably to take the parallel
/// path — confirm in `parallel_hash_join`).
#[test]
fn test_parallel_right_join_unmatched() {
    // Builds a two-column row: integer key followed by a text label.
    let row_of = |key, label: &'static str| {
        Row::from_values(vec![Value::Integer(key), Value::Text(label.into())])
    };
    let build_rows: Vec<Row> = vec![row_of(1, "b1"), row_of(2, "b2"), row_of(3, "b3")];
    let probe_rows: Vec<Row> = vec![row_of(1, "p1")];
    let config = ParallelConfig {
        min_rows_for_parallel_join: 1,
        ..Default::default()
    };

    let result = parallel_hash_join(
        &probe_rows,
        &build_rows,
        &[0],
        &[0],
        JoinType::Right,
        2,
        2,
        false,
        &config,
    );
    assert_eq!(result.rows.len(), 3, "RIGHT JOIN should have 3 rows");

    // Count rows whose columns 0 and 1 are both present and NULL.
    let mut null_count = 0usize;
    for row in &result.rows {
        let probe_side_null = (0..2).all(|col| matches!(row.get(col), Some(v) if v.is_null()));
        if probe_side_null {
            null_count += 1;
        }
    }
    assert_eq!(
        null_count, 2,
        "Should have 2 unmatched rows with NULL probe columns"
    );
}
/// Full-outer-joins probe rows with keys 1 and 2 against build rows with
/// keys 1 and 3 on column 0 and expects three output rows in total.
///
/// Only the row count is checked. `min_rows_for_parallel_join` is set to 1
/// (presumably to take the parallel path — confirm in `parallel_hash_join`).
#[test]
fn test_parallel_full_outer_join() {
    // Builds a two-column row: integer key followed by a text label.
    let row_of = |key, label: &'static str| {
        Row::from_values(vec![Value::Integer(key), Value::Text(label.into())])
    };
    let build_rows: Vec<Row> = vec![row_of(1, "b1"), row_of(3, "b3")];
    let probe_rows: Vec<Row> = vec![row_of(1, "p1"), row_of(2, "p2")];
    let config = ParallelConfig {
        min_rows_for_parallel_join: 1,
        ..Default::default()
    };

    let joined = parallel_hash_join(
        &probe_rows,
        &build_rows,
        &[0],
        &[0],
        JoinType::Full,
        2,
        2,
        false,
        &config,
    );

    assert_eq!(joined.rows.len(), 3, "FULL OUTER JOIN should have 3 rows");
}
/// Checks `parallel_hash_join` row counts when one side is empty, using the
/// default `ParallelConfig`:
///
/// * INNER with an empty probe side yields 0 rows,
/// * INNER with an empty build side yields 0 rows,
/// * LEFT with two probe rows and an empty build side yields 2 rows.
#[test]
fn test_parallel_join_empty_tables() {
    let config = ParallelConfig::default();
    let no_rows: &[Row] = &[];
    let one_row: &[Row] = &[Row::from_values(vec![Value::Integer(1)])];
    let two_rows: &[Row] = &[
        Row::from_values(vec![Value::Integer(1)]),
        Row::from_values(vec![Value::Integer(2)]),
    ];

    // Joins single-column inputs on column 0 and reports the output row count.
    let joined_len = |probe: &[Row], build: &[Row], join_type: JoinType| {
        parallel_hash_join(probe, build, &[0], &[0], join_type, 1, 1, false, &config)
            .rows
            .len()
    };

    assert_eq!(
        joined_len(no_rows, one_row, JoinType::Inner),
        0,
        "Empty probe should give empty result for INNER"
    );
    assert_eq!(
        joined_len(one_row, no_rows, JoinType::Inner),
        0,
        "Empty build should give empty result for INNER"
    );
    assert_eq!(
        joined_len(two_rows, no_rows, JoinType::Left),
        2,
        "LEFT JOIN with empty build should have all probe rows"
    );
}
/// Runs a LEFT join with `true` passed for the final boolean argument
/// (presumably the "swapped" flag the test name refers to — confirm against
/// the `parallel_hash_join` signature).
///
/// Build rows have keys 1 and 3, probe rows keys 1 and 2. The test expects
/// two output rows: one with key 1 in both column 0 and column 2, and one
/// with key 3 in column 0 whose columns 2 and 3 are NULL.
#[test]
fn test_parallel_join_swapped() {
    // Builds a two-column row: integer key followed by a text label.
    let row_of = |key, label: &'static str| {
        Row::from_values(vec![Value::Integer(key), Value::Text(label.into())])
    };
    let build_rows: Vec<Row> = vec![row_of(1, "b1"), row_of(3, "b3")];
    let probe_rows: Vec<Row> = vec![row_of(1, "p1"), row_of(2, "p2")];
    let config = ParallelConfig {
        min_rows_for_parallel_join: 1,
        ..Default::default()
    };

    let result = parallel_hash_join(
        &probe_rows,
        &build_rows,
        &[0],
        &[0],
        JoinType::Left,
        2,
        2,
        true,
        &config,
    );
    assert_eq!(result.rows.len(), 2, "LEFT JOIN swapped should have 2 rows");

    let key_one = Value::Integer(1);
    let has_match = result
        .rows
        .iter()
        .any(|r| r.get(0) == Some(&key_one) && r.get(2) == Some(&key_one));
    assert!(has_match, "Should have a matched row with id=1");

    let key_three = Value::Integer(3);
    let unmatched = result
        .rows
        .iter()
        .find(|r| r.get(0) == Some(&key_three))
        .expect("Should have unmatched build row with id=3");
    // Columns 2 and 3 of the unmatched row must both be present and NULL.
    for col in 2..4 {
        assert!(
            matches!(unmatched.get(col), Some(v) if v.is_null()),
            "Probe col should be NULL"
        );
    }
}
/// Checks `rows_equal` on three cases: two rows with identical values compare
/// equal, rows differing in one value do not, and rows of different lengths
/// do not.
#[test]
fn test_rows_equal() {
    let base = Row::from_values(vec![Value::Integer(1), Value::Text("a".into())]);
    let same = Row::from_values(vec![Value::Integer(1), Value::Text("a".into())]);
    let other_text = Row::from_values(vec![Value::Integer(1), Value::Text("b".into())]);
    let shorter = Row::from_values(vec![Value::Integer(1)]);

    let identical_match = rows_equal(&base, &same);
    assert!(identical_match, "Identical rows should be equal");

    let value_mismatch = rows_equal(&base, &other_text);
    assert!(!value_mismatch, "Different values should not be equal");

    let length_mismatch = rows_equal(&base, &shorter);
    assert!(!length_mismatch, "Different lengths should not be equal");
}
/// Checks how `rows_equal` treats NULLs: two rows with a NULL in the same
/// position compare equal, while a NULL never equals a non-NULL value in that
/// position.
#[test]
fn test_rows_equal_with_nulls() {
    let null_a = Row::from_values(vec![Value::Integer(1), Value::null_unknown()]);
    let null_b = Row::from_values(vec![Value::Integer(1), Value::null_unknown()]);
    let all_present = Row::from_values(vec![Value::Integer(1), Value::Integer(2)]);

    let nulls_match = rows_equal(&null_a, &null_b);
    assert!(nulls_match, "Rows with same NULL positions should be equal");

    let null_vs_value = rows_equal(&null_a, &all_present);
    assert!(!null_vs_value, "NULL should not equal non-NULL");
}
}