use crate::error::{DbxError, DbxResult};
use rayon::ThreadPoolBuilder;
use std::sync::Arc;
/// Strategy for sizing the engine's worker thread pool.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ParallelizationPolicy {
    /// Use the detected CPU count, capped at 16 (see `determine_thread_count`).
    #[default]
    Auto,
    /// Use exactly this many threads (before the `cpu_cap` scaling applied at pool creation).
    Fixed(usize),
    /// Use half the detected CPUs, but at least one.
    Adaptive,
}
/// Tuning knobs controlling when and how aggressively work is parallelized.
#[derive(Debug, Clone)]
pub struct ParallelismConfig {
    /// Fraction of the policy-derived thread count to actually use;
    /// clamped to 0.01..=1.0 when the pool is built.
    pub cpu_cap: f64,
    /// Minimum row count before row-based work is split across threads.
    pub min_rows_for_parallel: usize,
}
impl Default for ParallelismConfig {
fn default() -> Self {
Self {
cpu_cap: 1.0,
min_rows_for_parallel: 1_000,
}
}
}
impl ParallelismConfig {
pub fn conservative() -> Self {
Self {
cpu_cap: 0.5,
min_rows_for_parallel: 5_000,
}
}
pub fn aggressive() -> Self {
Self {
cpu_cap: 1.0,
min_rows_for_parallel: 500,
}
}
}
/// Runs closures on a dedicated rayon thread pool sized from a
/// [`ParallelizationPolicy`] and scaled by `ParallelismConfig::cpu_cap`.
pub struct ParallelExecutionEngine {
    // Shared pool; all work submitted through `execute` runs here.
    thread_pool: Arc<rayon::ThreadPool>,
    // Policy the pool was sized with; reported by `policy()` and consulted by `auto_tune_weighted`.
    policy: ParallelizationPolicy,
    // Full database config; this engine reads only the `parallelism` section.
    pub db_config: DbConfig,
}
impl ParallelExecutionEngine {
    /// Builds an engine for `policy` with the default [`DbConfig`].
    pub fn new(policy: ParallelizationPolicy) -> DbxResult<Self> {
        Self::new_with_config(policy, DbConfig::default())
    }

    /// Builds an engine whose thread count is derived from `policy`, then
    /// scaled down by `config.parallelism.cpu_cap` (clamped to 0.01..=1.0),
    /// never dropping below one thread.
    pub fn new_with_config(policy: ParallelizationPolicy, config: DbConfig) -> DbxResult<Self> {
        let cap = config.parallelism.cpu_cap.clamp(0.01, 1.0);
        let requested = Self::determine_thread_count(policy);
        // Ceil so a non-zero cap never rounds a small pool down to zero.
        let workers = ((requested as f64 * cap).ceil() as usize).max(1);
        let pool = ThreadPoolBuilder::new()
            .num_threads(workers)
            .thread_name(|idx| format!("dbx-parallel-{}", idx))
            .build()
            .map_err(|e| {
                DbxError::NotImplemented(format!("Failed to create thread pool: {}", e))
            })?;
        Ok(Self {
            thread_pool: Arc::new(pool),
            policy,
            db_config: config,
        })
    }

    /// Shorthand for [`ParallelizationPolicy::Auto`].
    pub fn new_auto() -> DbxResult<Self> {
        Self::new(ParallelizationPolicy::Auto)
    }

    /// Shorthand for a fixed-size pool; a zero thread count is rejected.
    pub fn new_fixed(num_threads: usize) -> DbxResult<Self> {
        if num_threads == 0 {
            Err(DbxError::InvalidArguments(
                "Thread count must be greater than 0".to_string(),
            ))
        } else {
            Self::new(ParallelizationPolicy::Fixed(num_threads))
        }
    }

    /// The policy this engine was built with.
    pub fn policy(&self) -> ParallelizationPolicy {
        self.policy
    }

    /// Number of worker threads in the underlying pool.
    pub fn thread_count(&self) -> usize {
        self.thread_pool.current_num_threads()
    }

    /// Borrow the underlying rayon pool.
    pub fn thread_pool(&self) -> &rayon::ThreadPool {
        &self.thread_pool
    }

    /// The parallelism section of the configuration.
    pub fn config(&self) -> &ParallelismConfig {
        &self.db_config.parallelism
    }

    /// The full database configuration.
    pub fn db_config(&self) -> &DbConfig {
        &self.db_config
    }

    /// True when `row_count` reaches the configured threshold and more than
    /// one worker thread is available.
    pub fn should_parallelize_rows(&self, row_count: usize) -> bool {
        self.thread_count() > 1
            && row_count >= self.db_config.parallelism.min_rows_for_parallel
    }

    /// Run `f` inside this engine's pool, so rayon work it spawns uses
    /// these workers rather than the global pool.
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.thread_pool.install(f)
    }

    /// Map a policy to a raw (pre-`cpu_cap`) thread count.
    fn determine_thread_count(policy: ParallelizationPolicy) -> usize {
        match policy {
            ParallelizationPolicy::Fixed(n) => n,
            ParallelizationPolicy::Auto => num_cpus::get().min(16),
            ParallelizationPolicy::Adaptive => (num_cpus::get() / 2).max(1),
        }
    }

    /// Suggested degree of parallelism for a workload of baseline complexity.
    pub fn auto_tune(&self, workload_size: usize) -> usize {
        self.auto_tune_weighted(workload_size, 1.0)
    }

    /// Suggested degree of parallelism for `workload_size` items whose average
    /// per-item complexity is `avg_complexity` (1.0 = baseline). A fixed policy
    /// always commits the whole pool; adaptive policies shrink the per-thread
    /// item threshold as complexity grows.
    pub fn auto_tune_weighted(&self, workload_size: usize, avg_complexity: f64) -> usize {
        let threads = self.thread_count();
        if let ParallelizationPolicy::Fixed(_) = self.policy {
            return threads;
        }
        // Higher complexity lowers the item threshold; floors keep it >= 1.
        let threshold = (1000.0_f64 / avg_complexity.max(0.1)).max(1.0) as usize;
        if workload_size < threshold {
            1
        } else {
            (workload_size / threshold).min(threads).max(1)
        }
    }

    /// Crude heuristic complexity score for a SQL string: weights joins,
    /// nested selects, CTEs, unions, aggregates, window functions, and
    /// sort/group clauses, plus a length term capped at 5.0.
    pub fn estimate_query_complexity(sql: &str) -> f64 {
        let upper = sql.to_uppercase();
        let mut score = 1.0;
        score += upper.matches("JOIN").count() as f64 * 2.0;
        // Every SELECT beyond the first is treated as a nested query.
        score += upper.matches("SELECT").count().saturating_sub(1) as f64 * 3.0;
        if upper.contains("WITH ") {
            score += 4.0;
        }
        score += upper.matches("UNION").count() as f64 * 2.5;
        for agg in ["COUNT(", "SUM(", "AVG(", "MAX(", "MIN("] {
            score += upper.matches(agg).count() as f64 * 0.5;
        }
        if upper.contains("OVER(") || upper.contains("OVER (") {
            score += 3.0;
        }
        for (clause, weight) in [("ORDER BY", 0.5), ("GROUP BY", 1.0), ("HAVING", 1.0)] {
            if upper.contains(clause) {
                score += weight;
            }
        }
        score += (sql.len() as f64 / 200.0).min(5.0);
        score
    }

    /// True when [`auto_tune`](Self::auto_tune) would use more than one thread.
    pub fn should_parallelize(&self, workload_size: usize) -> bool {
        self.auto_tune(workload_size) > 1
    }
}
impl Default for ParallelExecutionEngine {
fn default() -> Self {
Self::new_auto().expect("Failed to create default parallel execution engine")
}
}
use crate::replication::transport::ReplicationConfig;
use crate::storage::realtime_sync::RealtimeSyncConfig;
/// Selects the map implementation backing the dirty buffer.
/// NOTE(review): buffer semantics aren't visible in this file — confirm
/// against the storage layer that consumes `DbConfig::dirty_buffer_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DirtyBufferMode {
    /// Ordered map backing (default).
    #[default]
    BTreeMap,
    /// Concurrent hash-map backing.
    DashMap,
}
/// Top-level database configuration aggregating per-subsystem settings.
#[derive(Debug, Clone, Default)]
pub struct DbConfig {
    /// Thread-pool sizing and parallel-execution thresholds.
    pub parallelism: ParallelismConfig,
    /// Realtime sync settings (see `storage::realtime_sync`).
    pub sync: RealtimeSyncConfig,
    /// Replication transport settings (see `replication::transport`).
    pub replication: ReplicationConfig,
    /// Which map type backs the dirty buffer.
    pub dirty_buffer_mode: DirtyBufferMode,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_auto() {
        let engine = ParallelExecutionEngine::new_auto().unwrap();
        assert_eq!(engine.policy(), ParallelizationPolicy::Auto);
        assert!(engine.thread_count() > 0);
    }

    #[test]
    fn test_new_fixed() {
        let engine = ParallelExecutionEngine::new_fixed(4).unwrap();
        assert_eq!(engine.policy(), ParallelizationPolicy::Fixed(4));
        assert_eq!(engine.thread_count(), 4);
    }

    #[test]
    fn test_new_fixed_zero_threads() {
        // Zero threads is rejected up front with InvalidArguments.
        assert!(ParallelExecutionEngine::new_fixed(0).is_err());
    }

    #[test]
    fn test_execute() {
        let engine = ParallelExecutionEngine::new_auto().unwrap();
        assert_eq!(engine.execute(|| 42), 42);
    }

    #[test]
    fn test_auto_tune_small_workload() {
        let engine = ParallelExecutionEngine::new_auto().unwrap();
        // Below the 1000-item baseline threshold: stay single-threaded.
        assert_eq!(engine.auto_tune(500), 1);
    }

    #[test]
    fn test_auto_tune_large_workload() {
        let engine = ParallelExecutionEngine::new_auto().unwrap();
        assert!(engine.auto_tune(100_000) > 1);
    }

    #[test]
    fn test_should_parallelize() {
        let engine = ParallelExecutionEngine::new_auto().unwrap();
        assert!(!engine.should_parallelize(500));
        assert!(engine.should_parallelize(100_000));
    }

    #[test]
    fn test_fixed_policy_always_uses_all_threads() {
        // Fixed policy ignores workload size and commits the whole pool.
        let engine = ParallelExecutionEngine::new_fixed(8).unwrap();
        assert_eq!(engine.auto_tune(100), 8);
    }
}