dbx_core/sql/
parallel_parser.rs1use crate::error::{DbxError, DbxResult};
7use rayon::prelude::*;
8use sqlparser::ast::Statement;
9use sqlparser::dialect::GenericDialect;
10use sqlparser::parser::Parser;
11use std::sync::Arc;
12
13pub struct ParallelSqlParser {
15 dialect: GenericDialect,
16 thread_pool: Option<Arc<rayon::ThreadPool>>,
17}
18
19impl ParallelSqlParser {
20 pub fn new() -> Self {
22 Self {
23 dialect: GenericDialect {},
24 thread_pool: None,
25 }
26 }
27
28 pub fn with_thread_pool(thread_pool: Arc<rayon::ThreadPool>) -> Self {
30 Self {
31 dialect: GenericDialect {},
32 thread_pool: Some(thread_pool),
33 }
34 }
35
36 pub fn parse(&self, sql: &str) -> DbxResult<Vec<Statement>> {
38 Parser::parse_sql(&self.dialect, sql).map_err(|e| DbxError::SqlParse {
39 message: e.to_string(),
40 sql: sql.to_string(),
41 })
42 }
43
44 pub fn parse_batch(&self, sqls: &[&str]) -> DbxResult<Vec<Vec<Statement>>> {
70 let len = sqls.len();
71 if len == 0 {
72 return Ok(Vec::new());
73 }
74
75 if len < 4 {
77 return sqls
78 .iter()
79 .map(|sql| self.parse(sql))
80 .collect::<DbxResult<Vec<_>>>();
81 }
82
83 let avg_complexity = if len <= 20 {
85 sqls.iter()
86 .map(|s| Self::estimate_complexity(s))
87 .sum::<f64>()
88 / len as f64
89 } else {
90 sqls.iter()
92 .take(10)
93 .map(|s| Self::estimate_complexity(s))
94 .sum::<f64>()
95 / 10.0
96 };
97
98 let parallel_threshold = if avg_complexity > 5.0 {
100 4
101 } else if avg_complexity > 2.0 {
102 8
103 } else {
104 16 };
106
107 if len < parallel_threshold {
108 return sqls
109 .iter()
110 .map(|sql| self.parse(sql))
111 .collect::<DbxResult<Vec<_>>>();
112 }
113
114 let results: Vec<Option<DbxResult<Vec<Statement>>>> = if let Some(pool) = &self.thread_pool
116 {
117 pool.install(|| self.parallel_parse_adaptive(sqls, avg_complexity))
118 } else {
119 self.parallel_parse_adaptive(sqls, avg_complexity)
120 };
121
122 results
123 .into_iter()
124 .map(|opt| {
125 opt.unwrap_or_else(|| {
126 Err(DbxError::SqlParse {
127 message: "Missing parse result".to_string(),
128 sql: String::new(),
129 })
130 })
131 })
132 .collect()
133 }
134
135 fn parallel_parse_adaptive(
137 &self,
138 sqls: &[&str],
139 avg_complexity: f64,
140 ) -> Vec<Option<DbxResult<Vec<Statement>>>> {
141 let len = sqls.len();
142
143 if avg_complexity > 5.0 {
144 sqls.par_iter().map(|sql| Some(self.parse(sql))).collect()
147 } else {
148 let num_threads = rayon::current_num_threads();
150 let chunk_size = (len / num_threads).max(1);
151
152 let mut results: Vec<Option<DbxResult<Vec<Statement>>>> = Vec::with_capacity(len);
154 results.resize_with(len, || None);
155
156 let chunk_results: Vec<(usize, Vec<DbxResult<Vec<Statement>>>)> = sqls
158 .par_chunks(chunk_size)
159 .enumerate()
160 .map(|(chunk_idx, chunk)| {
161 let start_idx = chunk_idx * chunk_size;
162 let parsed: Vec<DbxResult<Vec<Statement>>> =
163 chunk.iter().map(|sql| self.parse(sql)).collect();
164 (start_idx, parsed)
165 })
166 .collect();
167
168 for (start_idx, chunk_results_vec) in chunk_results {
170 for (offset, result) in chunk_results_vec.into_iter().enumerate() {
171 if start_idx + offset < len {
172 results[start_idx + offset] = Some(result);
173 }
174 }
175 }
176
177 results
178 }
179 }
180
181 fn estimate_complexity(sql: &str) -> f64 {
183 let bytes = sql.as_bytes();
184 let len = bytes.len();
185 let mut score = 1.0;
186
187 score += Self::count_keyword_ci(bytes, b"JOIN") as f64 * 2.0;
189 let select_count = Self::count_keyword_ci(bytes, b"SELECT");
190 score += select_count.saturating_sub(1) as f64 * 3.0;
191 if Self::contains_keyword_ci(bytes, b"WITH ") {
192 score += 4.0;
193 }
194 score += Self::count_keyword_ci(bytes, b"UNION") as f64 * 2.5;
195
196 score += (len as f64 / 200.0).min(5.0);
198 score
199 }
200
201 #[inline]
203 fn count_keyword_ci(haystack: &[u8], needle: &[u8]) -> usize {
204 if needle.len() > haystack.len() {
205 return 0;
206 }
207 let mut count = 0;
208 for i in 0..=(haystack.len() - needle.len()) {
209 if haystack[i..i + needle.len()]
210 .iter()
211 .zip(needle.iter())
212 .all(|(h, n)| h.to_ascii_uppercase() == *n)
213 {
214 count += 1;
215 }
216 }
217 count
218 }
219
220 #[inline]
222 fn contains_keyword_ci(haystack: &[u8], needle: &[u8]) -> bool {
223 Self::count_keyword_ci(haystack, needle) > 0
224 }
225
226 pub fn parse_batch_partial(
239 &self,
240 sqls: &[&str],
241 ) -> (Vec<Vec<Statement>>, Vec<(usize, DbxError)>) {
242 let results = if let Some(pool) = &self.thread_pool {
243 pool.install(|| {
244 sqls.par_iter()
245 .enumerate()
246 .map(|(idx, sql)| (idx, self.parse(sql)))
247 .collect::<Vec<_>>()
248 })
249 } else {
250 sqls.par_iter()
251 .enumerate()
252 .map(|(idx, sql)| (idx, self.parse(sql)))
253 .collect::<Vec<_>>()
254 };
255
256 let mut successes = Vec::new();
257 let mut errors = Vec::new();
258
259 for (idx, result) in results {
260 match result {
261 Ok(statements) => successes.push(statements),
262 Err(e) => errors.push((idx, e)),
263 }
264 }
265
266 (successes, errors)
267 }
268
269 pub fn parse_batch_with_callback<F>(&self, sqls: &[&str], mut callback: F) -> DbxResult<()>
279 where
280 F: FnMut(usize, DbxResult<Vec<Statement>>) -> DbxResult<()>,
281 {
282 let results = if let Some(pool) = &self.thread_pool {
283 pool.install(|| {
284 sqls.par_iter()
285 .enumerate()
286 .map(|(idx, sql)| (idx, self.parse(sql)))
287 .collect::<Vec<_>>()
288 })
289 } else {
290 sqls.par_iter()
291 .enumerate()
292 .map(|(idx, sql)| (idx, self.parse(sql)))
293 .collect::<Vec<_>>()
294 };
295
296 for (idx, result) in results {
297 callback(idx, result)?;
298 }
299
300 Ok(())
301 }
302}
303
304impl Default for ParallelSqlParser {
305 fn default() -> Self {
306 Self::new()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` distinct, trivially simple SELECT statements.
    ///
    /// Their estimated complexity is far below 2.0, so `parse_batch` only
    /// goes parallel from 16 statements upwards.
    fn numbered_selects(count: usize) -> Vec<String> {
        (0..count)
            .map(|i| format!("SELECT * FROM table_{}", i))
            .collect()
    }

    /// Borrows a slice of owned strings as the `&[&str]` shape the API takes.
    fn as_strs(owned: &[String]) -> Vec<&str> {
        owned.iter().map(String::as_str).collect()
    }

    #[test]
    fn test_parse_single() {
        let parser = ParallelSqlParser::new();
        let result = parser.parse("SELECT * FROM users");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[test]
    fn test_parse_batch_empty() {
        let parser = ParallelSqlParser::new();
        let results = parser.parse_batch(&[]).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_parse_batch_small() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "SELECT * FROM orders"];
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 1);
        assert_eq!(results[1].len(), 1);
    }

    #[test]
    fn test_parse_batch_large() {
        // 40 simple statements is above the 16-statement threshold, so this
        // exercises the chunked parallel path rather than the sequential one.
        let parser = ParallelSqlParser::new();
        let owned = numbered_selects(40);
        let sqls = as_strs(&owned);
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), sqls.len());
        for (i, statements) in results.iter().enumerate() {
            assert_eq!(statements.len(), 1);
            // Results must come back in input order.
            assert!(statements[0]
                .to_string()
                .ends_with(&format!("table_{}", i)));
        }
    }

    #[test]
    fn test_parse_batch_complex_statements() {
        // Two JOINs push the estimated complexity above 5.0, which lowers the
        // parallel threshold to 4 and selects per-statement scheduling.
        let parser = ParallelSqlParser::new();
        let sql = "SELECT a.id FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let sqls = vec![sql; 6];
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), 6);
        assert!(results.iter().all(|statements| statements.len() == 1));
    }

    #[test]
    fn test_parse_batch_with_error() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "INVALID SQL", "SELECT * FROM orders"];
        let result = parser.parse_batch(&sqls);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_batch_with_error_parallel() {
        // Same as above, but with enough statements to take the parallel path.
        let parser = ParallelSqlParser::new();
        let mut owned = numbered_selects(32);
        owned[17] = "INVALID SQL".to_string();
        let sqls = as_strs(&owned);
        assert!(parser.parse_batch(&sqls).is_err());
    }

    #[test]
    fn test_parse_batch_partial() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "INVALID SQL", "SELECT * FROM orders"];
        let (successes, errors) = parser.parse_batch_partial(&sqls);
        assert_eq!(successes.len(), 2);
        assert_eq!(errors.len(), 1);
        // The error carries the index of the failing input.
        assert_eq!(errors[0].0, 1);
    }

    #[test]
    fn test_parse_batch_with_callback() {
        let parser = ParallelSqlParser::new();
        let sqls = vec!["SELECT * FROM users", "SELECT * FROM orders"];
        let mut count = 0;
        parser
            .parse_batch_with_callback(&sqls, |_idx, result| {
                assert!(result.is_ok());
                count += 1;
                Ok(())
            })
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_with_custom_thread_pool() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .unwrap();
        let parser = ParallelSqlParser::with_thread_pool(Arc::new(pool));
        // Large enough to be dispatched onto the custom pool; a batch of
        // fewer than 4 would be parsed sequentially and never touch it.
        let owned = numbered_selects(32);
        let sqls = as_strs(&owned);
        let results = parser.parse_batch(&sqls).unwrap();
        assert_eq!(results.len(), 32);
        assert!(results.iter().all(|statements| statements.len() == 1));
    }

    #[test]
    fn test_parse_multiple_statements() {
        let parser = ParallelSqlParser::new();
        let result = parser.parse("SELECT * FROM users; SELECT * FROM orders;");
        assert!(result.is_ok());
        let statements = result.unwrap();
        assert_eq!(statements.len(), 2);
    }
}