// dbx_core/sql/parallel_parser.rs
1//! Parallel SQL Parser — Batch SQL parsing using Rayon
2//!
3//! This module provides parallel SQL parsing capabilities to process multiple
4//! SQL statements concurrently, improving throughput for batch operations.
5
6use crate::error::{DbxError, DbxResult};
7use rayon::prelude::*;
8use sqlparser::ast::Statement;
9use sqlparser::dialect::GenericDialect;
10use sqlparser::parser::Parser;
11use std::sync::Arc;
12
/// Parallel SQL parser for batch processing.
///
/// Wraps `sqlparser` with Rayon-based batch APIs. Cheap to construct; holds
/// no parse state between calls.
pub struct ParallelSqlParser {
    // Dialect used for every parse; fixed to the generic SQL dialect.
    dialect: GenericDialect,
    // Optional dedicated Rayon pool. When `None`, parallel work runs on
    // Rayon's global thread pool.
    thread_pool: Option<Arc<rayon::ThreadPool>>,
}
18
19impl ParallelSqlParser {
20    /// Create a new parallel SQL parser
21    pub fn new() -> Self {
22        Self {
23            dialect: GenericDialect {},
24            thread_pool: None,
25        }
26    }
27
28    /// Create a new parallel SQL parser with a custom thread pool
29    pub fn with_thread_pool(thread_pool: Arc<rayon::ThreadPool>) -> Self {
30        Self {
31            dialect: GenericDialect {},
32            thread_pool: Some(thread_pool),
33        }
34    }
35
36    /// Parse a single SQL string into AST
37    pub fn parse(&self, sql: &str) -> DbxResult<Vec<Statement>> {
38        Parser::parse_sql(&self.dialect, sql).map_err(|e| DbxError::SqlParse {
39            message: e.to_string(),
40            sql: sql.to_string(),
41        })
42    }
43
44    /// Parse multiple SQL strings in parallel with optimized scheduling
45    ///
46    /// Applies three optimization layers:
47    /// 1. Dynamic thread pool: adjusts parallelism based on workload complexity
48    /// 2. Adaptive batch splitting: distributes work by estimated query complexity
49    /// 3. Lock-free result collection: pre-allocated indexed output
50    ///
51    /// # Arguments
52    ///
53    /// * `sqls` - A slice of SQL strings to parse
54    ///
55    /// # Example
56    ///
57    /// ```rust
58    /// use dbx_core::sql::parallel_parser::ParallelSqlParser;
59    ///
60    /// let parser = ParallelSqlParser::new();
61    /// let sqls = vec![
62    ///     "SELECT * FROM users",
63    ///     "SELECT * FROM orders",
64    ///     "SELECT * FROM products",
65    /// ];
66    /// let results = parser.parse_batch(&sqls).unwrap();
67    /// assert_eq!(results.len(), 3);
68    /// ```
69    pub fn parse_batch(&self, sqls: &[&str]) -> DbxResult<Vec<Vec<Statement>>> {
70        let len = sqls.len();
71        if len == 0 {
72            return Ok(Vec::new());
73        }
74
75        // Fast-path: small batches always sequential (no complexity estimation overhead)
76        if len < 4 {
77            return sqls
78                .iter()
79                .map(|sql| self.parse(sql))
80                .collect::<DbxResult<Vec<_>>>();
81        }
82
83        // For medium+ batches, sample complexity to decide parallelism strategy
84        let avg_complexity = if len <= 20 {
85            sqls.iter()
86                .map(|s| Self::estimate_complexity(s))
87                .sum::<f64>()
88                / len as f64
89        } else {
90            // Sample first 10 for speed
91            sqls.iter()
92                .take(10)
93                .map(|s| Self::estimate_complexity(s))
94                .sum::<f64>()
95                / 10.0
96        };
97
98        // Dynamic threshold
99        let parallel_threshold = if avg_complexity > 5.0 {
100            4
101        } else if avg_complexity > 2.0 {
102            8
103        } else {
104            16 // Simple queries: only parallelize large batches
105        };
106
107        if len < parallel_threshold {
108            return sqls
109                .iter()
110                .map(|sql| self.parse(sql))
111                .collect::<DbxResult<Vec<_>>>();
112        }
113
114        // Parallel execution
115        let results: Vec<Option<DbxResult<Vec<Statement>>>> = if let Some(pool) = &self.thread_pool
116        {
117            pool.install(|| self.parallel_parse_adaptive(sqls, avg_complexity))
118        } else {
119            self.parallel_parse_adaptive(sqls, avg_complexity)
120        };
121
122        results
123            .into_iter()
124            .map(|opt| {
125                opt.unwrap_or_else(|| {
126                    Err(DbxError::SqlParse {
127                        message: "Missing parse result".to_string(),
128                        sql: String::new(),
129                    })
130                })
131            })
132            .collect()
133    }
134
135    /// Adaptive parallel parsing with weighted work distribution
136    fn parallel_parse_adaptive(
137        &self,
138        sqls: &[&str],
139        avg_complexity: f64,
140    ) -> Vec<Option<DbxResult<Vec<Statement>>>> {
141        let len = sqls.len();
142
143        if avg_complexity > 5.0 {
144            // High complexity: use work-stealing with fine-grained tasks
145            // Each query is its own task — Rayon's work-stealing handles load balancing
146            sqls.par_iter().map(|sql| Some(self.parse(sql))).collect()
147        } else {
148            // Low/medium complexity: chunk-based parallelism to reduce scheduling overhead
149            let num_threads = rayon::current_num_threads();
150            let chunk_size = (len / num_threads).max(1);
151
152            // Pre-allocate result slots
153            let mut results: Vec<Option<DbxResult<Vec<Statement>>>> = Vec::with_capacity(len);
154            results.resize_with(len, || None);
155
156            // Use parallel chunks with index tracking
157            let chunk_results: Vec<(usize, Vec<DbxResult<Vec<Statement>>>)> = sqls
158                .par_chunks(chunk_size)
159                .enumerate()
160                .map(|(chunk_idx, chunk)| {
161                    let start_idx = chunk_idx * chunk_size;
162                    let parsed: Vec<DbxResult<Vec<Statement>>> =
163                        chunk.iter().map(|sql| self.parse(sql)).collect();
164                    (start_idx, parsed)
165                })
166                .collect();
167
168            // Merge results into pre-allocated slots (single-threaded, no locks needed)
169            for (start_idx, chunk_results_vec) in chunk_results {
170                for (offset, result) in chunk_results_vec.into_iter().enumerate() {
171                    if start_idx + offset < len {
172                        results[start_idx + offset] = Some(result);
173                    }
174                }
175            }
176
177            results
178        }
179    }
180
181    /// Fast complexity estimation using byte-level scanning (zero allocation)
182    fn estimate_complexity(sql: &str) -> f64 {
183        let bytes = sql.as_bytes();
184        let len = bytes.len();
185        let mut score = 1.0;
186
187        // Byte-level case-insensitive keyword counting
188        score += Self::count_keyword_ci(bytes, b"JOIN") as f64 * 2.0;
189        let select_count = Self::count_keyword_ci(bytes, b"SELECT");
190        score += select_count.saturating_sub(1) as f64 * 3.0;
191        if Self::contains_keyword_ci(bytes, b"WITH ") {
192            score += 4.0;
193        }
194        score += Self::count_keyword_ci(bytes, b"UNION") as f64 * 2.5;
195
196        // Length as proxy
197        score += (len as f64 / 200.0).min(5.0);
198        score
199    }
200
201    /// Count occurrences of keyword (case-insensitive, ASCII only)
202    #[inline]
203    fn count_keyword_ci(haystack: &[u8], needle: &[u8]) -> usize {
204        if needle.len() > haystack.len() {
205            return 0;
206        }
207        let mut count = 0;
208        for i in 0..=(haystack.len() - needle.len()) {
209            if haystack[i..i + needle.len()]
210                .iter()
211                .zip(needle.iter())
212                .all(|(h, n)| h.to_ascii_uppercase() == *n)
213            {
214                count += 1;
215            }
216        }
217        count
218    }
219
220    /// Check if keyword exists (case-insensitive, ASCII only)
221    #[inline]
222    fn contains_keyword_ci(haystack: &[u8], needle: &[u8]) -> bool {
223        Self::count_keyword_ci(haystack, needle) > 0
224    }
225
226    /// Parse multiple SQL strings in parallel, collecting only successful results
227    ///
228    /// Returns a vector of successful parse results and a vector of errors.
229    /// This is useful when you want to continue processing even if some SQL strings fail.
230    ///
231    /// # Arguments
232    ///
233    /// * `sqls` - A slice of SQL strings to parse
234    ///
235    /// # Returns
236    ///
237    /// A tuple of (successful_results, errors)
238    pub fn parse_batch_partial(
239        &self,
240        sqls: &[&str],
241    ) -> (Vec<Vec<Statement>>, Vec<(usize, DbxError)>) {
242        let results = if let Some(pool) = &self.thread_pool {
243            pool.install(|| {
244                sqls.par_iter()
245                    .enumerate()
246                    .map(|(idx, sql)| (idx, self.parse(sql)))
247                    .collect::<Vec<_>>()
248            })
249        } else {
250            sqls.par_iter()
251                .enumerate()
252                .map(|(idx, sql)| (idx, self.parse(sql)))
253                .collect::<Vec<_>>()
254        };
255
256        let mut successes = Vec::new();
257        let mut errors = Vec::new();
258
259        for (idx, result) in results {
260            match result {
261                Ok(statements) => successes.push(statements),
262                Err(e) => errors.push((idx, e)),
263            }
264        }
265
266        (successes, errors)
267    }
268
269    /// Parse multiple SQL strings and execute a callback for each result
270    ///
271    /// This is useful for streaming processing where you want to handle each
272    /// result as it becomes available.
273    ///
274    /// # Arguments
275    ///
276    /// * `sqls` - A slice of SQL strings to parse
277    /// * `callback` - A function to call for each parse result
278    pub fn parse_batch_with_callback<F>(&self, sqls: &[&str], mut callback: F) -> DbxResult<()>
279    where
280        F: FnMut(usize, DbxResult<Vec<Statement>>) -> DbxResult<()>,
281    {
282        let results = if let Some(pool) = &self.thread_pool {
283            pool.install(|| {
284                sqls.par_iter()
285                    .enumerate()
286                    .map(|(idx, sql)| (idx, self.parse(sql)))
287                    .collect::<Vec<_>>()
288            })
289        } else {
290            sqls.par_iter()
291                .enumerate()
292                .map(|(idx, sql)| (idx, self.parse(sql)))
293                .collect::<Vec<_>>()
294        };
295
296        for (idx, result) in results {
297            callback(idx, result)?;
298        }
299
300        Ok(())
301    }
302}
303
304impl Default for ParallelSqlParser {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    // Single statement through the non-batch entry point.
    #[test]
    fn test_parse_single() {
        let parser = ParallelSqlParser::new();
        let result = parser.parse("SELECT * FROM users");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1);
    }

    // len < 4 exercises parse_batch's sequential fast path.
    #[test]
    fn test_parse_batch_small() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "SELECT * FROM orders"];
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 1);
        assert_eq!(results[1].len(), 1);
    }

    // len >= 4 takes the complexity-estimation path; with simple SELECTs the
    // threshold keeps execution sequential, but ordering must still hold.
    #[test]
    fn test_parse_batch_large() {
        let parser = ParallelSqlParser::new();
        let sqls = vec![
            "SELECT * FROM users",
            "SELECT * FROM orders",
            "SELECT * FROM products",
            "SELECT * FROM categories",
            "SELECT * FROM reviews",
        ];
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), 5);
    }

    // parse_batch is all-or-nothing: any invalid statement fails the batch.
    #[test]
    fn test_parse_batch_with_error() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "INVALID SQL", "SELECT * FROM orders"];
        let result = parser.parse_batch(&sqls);
        assert!(result.is_err());
    }

    // parse_batch_partial keeps going past failures and reports the failing
    // input's index alongside the error.
    #[test]
    fn test_parse_batch_partial() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "INVALID SQL", "SELECT * FROM orders"];
        let (successes, errors) = parser.parse_batch_partial(&sqls);
        assert_eq!(successes.len(), 2);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].0, 1); // Error at index 1
    }

    // The callback runs once per input; mutable capture verifies FnMut works.
    #[test]
    fn test_parse_batch_with_callback() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "SELECT * FROM orders"];
        let mut count = 0;
        parser
            .parse_batch_with_callback(&sqls, |_idx, result| {
                assert!(result.is_ok());
                count += 1;
                Ok(())
            })
            .unwrap();
        assert_eq!(count, 2);
    }

    // A 2-thread dedicated pool must produce the same results as the global
    // pool path.
    #[test]
    fn test_with_custom_thread_pool() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .unwrap();
        let parser = ParallelSqlParser::with_thread_pool(Arc::new(pool));
        let sqls = vec![
            "SELECT * FROM users",
            "SELECT * FROM orders",
            "SELECT * FROM products",
        ];
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), 3);
    }

    // One input string containing two statements yields two AST entries.
    #[test]
    fn test_parse_multiple_statements() {
        let parser = ParallelSqlParser::new();
        let result = parser.parse("SELECT * FROM users; SELECT * FROM orders;");
        assert!(result.is_ok());
        let statements = result.unwrap();
        assert_eq!(statements.len(), 2);
    }
}