// dbx_core/engine/parallel_engine.rs
use crate::error::{DbxError, DbxResult};
use rayon::ThreadPoolBuilder;
use std::sync::Arc;

/// Strategy for choosing how many worker threads the engine's pool gets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ParallelizationPolicy {
    /// One thread per CPU, capped at 16 (see `determine_thread_count`).
    #[default]
    Auto,
    /// Exactly this many threads, regardless of workload or CPU count.
    Fixed(usize),
    /// Half the CPUs (at least one), leaving headroom for other work.
    Adaptive,
}
21
/// Executes closures on a dedicated rayon thread pool sized according to a
/// [`ParallelizationPolicy`].
pub struct ParallelExecutionEngine {
    // Reference-counted so the pool can be shared without copying it.
    thread_pool: Arc<rayon::ThreadPool>,
    // Policy the engine was constructed with; consulted by `auto_tune_weighted`.
    policy: ParallelizationPolicy,
}
27
28impl ParallelExecutionEngine {
29 pub fn new(policy: ParallelizationPolicy) -> DbxResult<Self> {
31 let num_threads = Self::determine_thread_count(policy);
32
33 let thread_pool = ThreadPoolBuilder::new()
34 .num_threads(num_threads)
35 .thread_name(|i| format!("dbx-parallel-{}", i))
36 .build()
37 .map_err(|e| {
38 DbxError::NotImplemented(format!("Failed to create thread pool: {}", e))
39 })?;
40
41 Ok(Self {
42 thread_pool: Arc::new(thread_pool),
43 policy,
44 })
45 }
46
47 pub fn new_auto() -> DbxResult<Self> {
49 Self::new(ParallelizationPolicy::Auto)
50 }
51
52 pub fn new_fixed(num_threads: usize) -> DbxResult<Self> {
54 if num_threads == 0 {
55 return Err(DbxError::InvalidArguments(
56 "Thread count must be greater than 0".to_string(),
57 ));
58 }
59 Self::new(ParallelizationPolicy::Fixed(num_threads))
60 }
61
62 pub fn policy(&self) -> ParallelizationPolicy {
64 self.policy
65 }
66
67 pub fn thread_count(&self) -> usize {
69 self.thread_pool.current_num_threads()
70 }
71
72 pub fn thread_pool(&self) -> &rayon::ThreadPool {
74 &self.thread_pool
75 }
76
77 pub fn execute<F, R>(&self, f: F) -> R
79 where
80 F: FnOnce() -> R + Send,
81 R: Send,
82 {
83 self.thread_pool.install(f)
84 }
85
86 fn determine_thread_count(policy: ParallelizationPolicy) -> usize {
88 match policy {
89 ParallelizationPolicy::Auto => {
90 let num_cpus = num_cpus::get();
92 num_cpus.min(16)
93 }
94 ParallelizationPolicy::Fixed(n) => n,
95 ParallelizationPolicy::Adaptive => {
96 let num_cpus = num_cpus::get();
99 (num_cpus / 2).max(1)
100 }
101 }
102 }
103
104 pub fn auto_tune(&self, workload_size: usize) -> usize {
108 self.auto_tune_weighted(workload_size, 1.0)
109 }
110
111 pub fn auto_tune_weighted(&self, workload_size: usize, avg_complexity: f64) -> usize {
115 let thread_count = self.thread_count();
116
117 match self.policy {
118 ParallelizationPolicy::Auto | ParallelizationPolicy::Adaptive => {
119 let base_threshold: f64 = 1000.0;
123 let adjusted_threshold =
124 (base_threshold / avg_complexity.max(0.1)).max(1.0) as usize;
125
126 if workload_size < adjusted_threshold {
127 1
128 } else {
129 let optimal = (workload_size / adjusted_threshold).min(thread_count);
130 optimal.max(1)
131 }
132 }
133 ParallelizationPolicy::Fixed(_) => thread_count,
134 }
135 }
136
137 pub fn estimate_query_complexity(sql: &str) -> f64 {
141 let sql_upper = sql.to_uppercase();
142 let mut score = 1.0;
143
144 let join_count = sql_upper.matches("JOIN").count();
146 score += join_count as f64 * 2.0;
147
148 let subquery_depth = sql_upper.matches("SELECT").count().saturating_sub(1);
150 score += subquery_depth as f64 * 3.0;
151
152 if sql_upper.contains("WITH ") {
154 score += 4.0;
155 }
156
157 let union_count = sql_upper.matches("UNION").count();
159 score += union_count as f64 * 2.5;
160
161 for func in ["COUNT(", "SUM(", "AVG(", "MAX(", "MIN("] {
163 score += sql_upper.matches(func).count() as f64 * 0.5;
164 }
165
166 if sql_upper.contains("OVER(") || sql_upper.contains("OVER (") {
168 score += 3.0;
169 }
170
171 if sql_upper.contains("ORDER BY") {
173 score += 0.5;
174 }
175 if sql_upper.contains("GROUP BY") {
176 score += 1.0;
177 }
178 if sql_upper.contains("HAVING") {
179 score += 1.0;
180 }
181
182 score += (sql.len() as f64 / 200.0).min(5.0);
184
185 score
186 }
187
188 pub fn should_parallelize(&self, workload_size: usize) -> bool {
190 self.auto_tune(workload_size) > 1
191 }
192}
193
impl Default for ParallelExecutionEngine {
    /// Builds an engine with the [`ParallelizationPolicy::Auto`] policy.
    ///
    /// # Panics
    ///
    /// Panics if the rayon thread pool cannot be created.
    fn default() -> Self {
        Self::new_auto().expect("Failed to create default parallel execution engine")
    }
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// The auto policy must come up with at least one worker thread.
    #[test]
    fn test_new_auto() {
        let eng = ParallelExecutionEngine::new_auto().unwrap();
        assert_eq!(eng.policy(), ParallelizationPolicy::Auto);
        assert!(eng.thread_count() > 0);
    }

    /// A fixed policy pins the pool to exactly the requested size.
    #[test]
    fn test_new_fixed() {
        let eng = ParallelExecutionEngine::new_fixed(4).unwrap();
        assert_eq!(eng.policy(), ParallelizationPolicy::Fixed(4));
        assert_eq!(eng.thread_count(), 4);
    }

    /// Requesting zero threads is rejected up front.
    #[test]
    fn test_new_fixed_zero_threads() {
        assert!(ParallelExecutionEngine::new_fixed(0).is_err());
    }

    /// `execute` hands back the closure's return value.
    #[test]
    fn test_execute() {
        let eng = ParallelExecutionEngine::new_auto().unwrap();
        assert_eq!(eng.execute(|| 42), 42);
    }

    /// Small workloads stay sequential under the auto policy.
    #[test]
    fn test_auto_tune_small_workload() {
        let eng = ParallelExecutionEngine::new_auto().unwrap();
        assert_eq!(eng.auto_tune(500), 1);
    }

    /// Large workloads fan out across more than one thread.
    #[test]
    fn test_auto_tune_large_workload() {
        let eng = ParallelExecutionEngine::new_auto().unwrap();
        assert!(eng.auto_tune(100_000) > 1);
    }

    /// `should_parallelize` mirrors `auto_tune(..) > 1`.
    #[test]
    fn test_should_parallelize() {
        let eng = ParallelExecutionEngine::new_auto().unwrap();
        assert!(!eng.should_parallelize(500));
        assert!(eng.should_parallelize(100_000));
    }

    /// Fixed policy ignores workload size and always reports full width.
    #[test]
    fn test_fixed_policy_always_uses_all_threads() {
        let eng = ParallelExecutionEngine::new_fixed(8).unwrap();
        assert_eq!(eng.auto_tune(100), 8);
    }
}