use crate::executor::operators::Operator;
use crate::executor::pipeline::{ExecutionContext, RowBatch};
use crate::executor::plan::PhysicalPlan;
use crate::executor::{ExecutionError, Result};
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
/// Tuning knobs controlling whether and how queries run in parallel.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    pub enabled: bool,
    pub num_threads: usize,
    pub batch_size: usize,
}

impl ParallelConfig {
    /// Parallel execution with automatic thread detection.
    ///
    /// `num_threads == 0` is the "auto" sentinel: the executor resolves it
    /// to the machine's logical CPU count when the pool is built.
    pub fn new() -> Self {
        ParallelConfig {
            enabled: true,
            num_threads: 0, // 0 = auto-detect when the executor is created
            batch_size: 1024,
        }
    }

    /// Single-threaded configuration with parallelism turned off.
    pub fn sequential() -> Self {
        ParallelConfig {
            enabled: false,
            num_threads: 1,
            batch_size: 1024,
        }
    }

    /// Parallel execution pinned to an explicit thread count.
    pub fn with_threads(num_threads: usize) -> Self {
        ParallelConfig {
            num_threads,
            ..Self::new()
        }
    }
}
impl Default for ParallelConfig {
    /// Defaults to the parallel, auto-threaded configuration (`Self::new`).
    fn default() -> Self {
        Self::new()
    }
}
/// Executes physical plans on a dedicated rayon thread pool.
pub struct ParallelExecutor {
    // Options used at execution time; `num_threads == 0` was already
    // resolved to a concrete CPU count when the pool below was built.
    config: ParallelConfig,
    // Pool owned by this executor; all parallel work runs on it.
    thread_pool: rayon::ThreadPool,
}
impl ParallelExecutor {
    /// Builds an executor with its own rayon thread pool.
    ///
    /// A `config.num_threads` of 0 means "auto": the pool gets one thread
    /// per logical CPU, as reported by `num_cpus::get()`.
    ///
    /// # Panics
    /// Panics if the rayon thread pool cannot be constructed.
    pub fn new(config: ParallelConfig) -> Self {
        let num_threads = if config.num_threads == 0 {
            num_cpus::get()
        } else {
            config.num_threads
        };
        let thread_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()
            .expect("Failed to create thread pool");
        Self {
            config,
            thread_pool,
        }
    }

    /// Executes `plan`, selecting a strategy: sequential when parallelism
    /// is disabled, a single parallel scan when the plan has no pipeline
    /// breakers, otherwise staged execution split at each breaker.
    pub fn execute(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        if !self.config.enabled {
            return self.execute_sequential(plan);
        }
        if plan.pipeline_breakers.is_empty() {
            self.execute_parallel_scan(plan)
        } else {
            self.execute_parallel_staged(plan)
        }
    }

    /// Sequential fallback used when parallelism is disabled.
    /// TODO(review): stub — currently produces no batches.
    fn execute_sequential(&self, _plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        Ok(Vec::new())
    }

    /// Runs the plan as a single pipeline, scanning `num_partitions` row
    /// partitions concurrently on the pool.
    ///
    /// Batches are pushed in completion order, so ordering across
    /// partitions is nondeterministic.
    fn execute_parallel_scan(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        let results = Arc::new(Mutex::new(Vec::new()));
        // BUG FIX: partition by the pool's *actual* thread count. The config
        // stores 0 for "auto", so the previous `config.num_threads.max(1)`
        // collapsed the auto configuration to a single partition, silently
        // serializing every scan.
        let num_partitions = self.thread_pool.current_num_threads().max(1);
        self.thread_pool.scope(|s| {
            for partition_id in 0..num_partitions {
                let results = Arc::clone(&results);
                s.spawn(move |_| {
                    let batch = self.execute_partition(plan, partition_id, num_partitions);
                    // TODO(review): partition errors are silently dropped
                    // here; consider collecting and propagating them.
                    if let Ok(Some(b)) = batch {
                        results.lock().unwrap().push(b);
                    }
                });
            }
        });
        // All spawned tasks have finished once `scope` returns, so this Arc
        // is unique and the mutex cannot be contended.
        let final_results = Arc::try_unwrap(results)
            .map_err(|_| ExecutionError::Internal("Failed to unwrap results".to_string()))?
            .into_inner()
            .map_err(|_| ExecutionError::Internal("Failed to acquire lock".to_string()))?;
        Ok(final_results)
    }

    /// Produces the output batch for one scan partition.
    /// TODO(review): stub — always yields `None`.
    fn execute_partition(
        &self,
        _plan: &PhysicalPlan,
        _partition_id: usize,
        _num_partitions: usize,
    ) -> Result<Option<RowBatch>> {
        Ok(None)
    }

    /// Runs the plan stage-by-stage, splitting at each pipeline breaker.
    ///
    /// TODO(review): intermediate stage output is not yet threaded into
    /// the next stage (`execute_stage` is a stub); only the final stage's
    /// result is returned. Errors from earlier stages still propagate.
    fn execute_parallel_staged(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        let mut start = 0;
        for &breaker in &plan.pipeline_breakers {
            // Run the stage for error propagation; its output is currently
            // discarded (see TODO above).
            self.execute_stage(plan, start, breaker)?;
            start = breaker + 1;
        }
        self.execute_stage(plan, start, plan.operators.len())
    }

    /// Executes operators `start..end` of the plan as one stage.
    /// TODO(review): stub — currently produces no batches.
    fn execute_stage(
        &self,
        _plan: &PhysicalPlan,
        _start: usize,
        _end: usize,
    ) -> Result<Vec<RowBatch>> {
        Ok(Vec::new())
    }

    /// Applies `processor` to every batch on the pool.
    ///
    /// On success the output order matches the input order. On failure one
    /// of the encountered errors is returned; rayon stops scheduling new
    /// items once an error is seen, instead of processing everything and
    /// discarding the results as the previous two-pass collect did.
    pub fn process_batches_parallel<F>(
        &self,
        batches: Vec<RowBatch>,
        processor: F,
    ) -> Result<Vec<RowBatch>>
    where
        F: Fn(RowBatch) -> Result<RowBatch> + Send + Sync,
    {
        // Collect straight into Result<Vec<_>>: no intermediate
        // Vec<Result<_>> allocation.
        self.thread_pool
            .install(|| batches.into_par_iter().map(processor).collect())
    }

    /// Groups `batches` by `key_fn`, then reduces each group with `agg_fn`
    /// in parallel. Result order is unspecified (hash-map iteration order).
    pub fn aggregate_parallel<K, V, F, G>(
        &self,
        batches: Vec<RowBatch>,
        key_fn: F,
        agg_fn: G,
    ) -> Result<Vec<(K, V)>>
    where
        K: Send + Sync + Eq + std::hash::Hash,
        V: Send + Sync,
        F: Fn(&RowBatch) -> K + Send + Sync,
        G: Fn(Vec<RowBatch>) -> V + Send + Sync,
    {
        use std::collections::HashMap;
        // Grouping is sequential (cheap pointer moves); only the per-group
        // aggregation fans out to the pool.
        let mut groups: HashMap<K, Vec<RowBatch>> = HashMap::new();
        for batch in batches {
            let key = key_fn(&batch);
            groups.entry(key).or_default().push(batch);
        }
        let results: Vec<_> = self.thread_pool.install(|| {
            groups
                .into_par_iter()
                .map(|(key, group)| (key, agg_fn(group)))
                .collect()
        });
        Ok(results)
    }

    /// Number of worker threads in the underlying pool.
    pub fn num_threads(&self) -> usize {
        self.thread_pool.current_num_threads()
    }
}
/// Splits a scan of `total_rows` rows into `num_partitions` contiguous,
/// roughly equal row ranges.
pub struct ScanPartitioner {
    total_rows: usize,
    num_partitions: usize,
}

impl ScanPartitioner {
    /// Creates a partitioner over `total_rows` rows split into
    /// `num_partitions` ranges. A `num_partitions` of 0 makes
    /// `partition_range` panic (division by zero).
    pub fn new(total_rows: usize, num_partitions: usize) -> Self {
        Self {
            total_rows,
            num_partitions,
        }
    }

    /// Returns the half-open row range `[start, end)` for `partition_id`.
    ///
    /// Sizes use ceiling division, so trailing partitions may be smaller
    /// or empty when rows don't divide evenly. Both endpoints are clamped
    /// to `total_rows`.
    ///
    /// BUG FIX: `start` is now clamped as well. Previously a trailing or
    /// out-of-range partition could return an inverted range — e.g.
    /// `new(5, 4).partition_range(3)` yielded `(6, 5)`; it now yields the
    /// empty range `(5, 5)`.
    ///
    /// # Panics
    /// Panics if `num_partitions` is 0.
    pub fn partition_range(&self, partition_id: usize) -> (usize, usize) {
        // Ceiling division: every partition except possibly the last gets
        // the same number of rows.
        let rows_per_partition = (self.total_rows + self.num_partitions - 1) / self.num_partitions;
        let start = (partition_id * rows_per_partition).min(self.total_rows);
        let end = (start + rows_per_partition).min(self.total_rows);
        (start, end)
    }

    /// True when `partition_id` names one of the configured partitions.
    pub fn is_valid_partition(&self, partition_id: usize) -> bool {
        partition_id < self.num_partitions
    }
}
/// Strategy used to parallelize a join.
///
/// Derives added for consistency with `ParallelConfig` and so callers can
/// log, copy, and compare strategies (the enum is a plain unit-variant tag).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelJoinStrategy {
    /// Build a table from the smaller input and probe with the larger.
    Broadcast,
    /// Partitioned hash join (implementation pending).
    PartitionedHash,
    /// Sort-merge join (implementation pending).
    SortMerge,
}
/// Runs a join between two sets of row batches using a chosen strategy.
pub struct ParallelJoin {
    // Which join algorithm `execute` dispatches to.
    strategy: ParallelJoinStrategy,
    // Shared executor; held for the (stubbed) join implementations.
    executor: Arc<ParallelExecutor>,
}
impl ParallelJoin {
pub fn new(strategy: ParallelJoinStrategy, executor: Arc<ParallelExecutor>) -> Self {
Self { strategy, executor }
}
pub fn execute(&self, left: Vec<RowBatch>, right: Vec<RowBatch>) -> Result<Vec<RowBatch>> {
match self.strategy {
ParallelJoinStrategy::Broadcast => self.broadcast_join(left, right),
ParallelJoinStrategy::PartitionedHash => self.partitioned_hash_join(left, right),
ParallelJoinStrategy::SortMerge => self.sort_merge_join(left, right),
}
}
fn broadcast_join(&self, left: Vec<RowBatch>, right: Vec<RowBatch>) -> Result<Vec<RowBatch>> {
let (build_side, probe_side) = if left.len() < right.len() {
(left, right)
} else {
(right, left)
};
Ok(Vec::new())
}
fn partitioned_hash_join(
&self,
left: Vec<RowBatch>,
right: Vec<RowBatch>,
) -> Result<Vec<RowBatch>> {
Ok(Vec::new())
}
fn sort_merge_join(&self, left: Vec<RowBatch>, right: Vec<RowBatch>) -> Result<Vec<RowBatch>> {
Ok(Vec::new())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // `new()` enables parallelism with the auto sentinel (0 threads);
    // `sequential()` disables it.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::new();
        assert!(config.enabled);
        assert_eq!(config.num_threads, 0);
        let seq_config = ParallelConfig::sequential();
        assert!(!seq_config.enabled);
    }

    // An explicit thread count must be reflected by the built pool.
    #[test]
    fn test_parallel_executor_creation() {
        let config = ParallelConfig::with_threads(4);
        let executor = ParallelExecutor::new(config);
        assert_eq!(executor.num_threads(), 4);
    }

    // 100 rows over 4 partitions divide evenly into 25-row ranges.
    #[test]
    fn test_scan_partitioner() {
        let partitioner = ScanPartitioner::new(100, 4);
        let (start, end) = partitioner.partition_range(0);
        assert_eq!(start, 0);
        assert_eq!(end, 25);
        let (start, end) = partitioner.partition_range(3);
        assert_eq!(start, 75);
        assert_eq!(end, 100);
    }

    // Partition ids are valid only in 0..num_partitions.
    #[test]
    fn test_partition_validity() {
        let partitioner = ScanPartitioner::new(100, 4);
        assert!(partitioner.is_valid_partition(0));
        assert!(partitioner.is_valid_partition(3));
        assert!(!partitioner.is_valid_partition(4));
    }
}