flowscope_core/completion/
parse_strategies.rs1#![allow(clippy::single_range_in_vec_init)]
12
13use std::ops::Range;
14
15use sqlparser::ast::Statement;
16use sqlparser::parser::Parser;
17
18use crate::types::{Dialect, ParseStrategy};
19
/// Upper bound on how many truncation points `try_truncated_parse` tries,
/// which caps the number of parser invocations spent on that strategy.
const MAX_TRUNCATION_ATTEMPTS: usize = 50;

/// Signature shared by the SQL repair heuristics used by `try_with_fixes`.
///
/// A repair receives the SQL text and the cursor byte offset and returns the
/// repaired SQL together with the byte ranges (in the repaired text) of any
/// text it inserted, or `None` when the heuristic does not apply.
type SqlFixFn = fn(&str, usize) -> Option<(String, Vec<Range<usize>>)>;
25
/// Returns `true` if `b` can be part of an SQL identifier or keyword.
///
/// Non-ASCII bytes count as identifier bytes, so a keyword glued to a Unicode
/// identifier (e.g. `éfrom`) is not mistaken for a standalone keyword.
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
}

/// Returns `true` if `keyword` occurs at byte offset `pos` of `sql`, compared
/// ASCII case-insensitively, as a whole word.
///
/// A word boundary is only required on an edge of `keyword` that is itself a
/// word byte. A pattern such as `" FROM"` (leading space) therefore matches
/// after any character, but must still not be followed by identifier bytes.
/// An empty `keyword` never matches.
fn keyword_at(sql: &[u8], keyword: &[u8], pos: usize) -> bool {
    if keyword.is_empty() {
        return false;
    }
    let end = pos + keyword.len();
    if end > sql.len() || !sql[pos..end].eq_ignore_ascii_case(keyword) {
        return false;
    }
    let clean_start = !is_word_byte(keyword[0]) || pos == 0 || !is_word_byte(sql[pos - 1]);
    let clean_end = !is_word_byte(keyword[keyword.len() - 1])
        || end == sql.len()
        || !is_word_byte(sql[end]);
    clean_start && clean_end
}

/// Returns the byte offsets of every whole-word, ASCII case-insensitive
/// occurrence of `keyword` in `sql`, in ascending order.
///
/// Occurrences embedded in longer identifiers (`from_date`, `orders`) are
/// skipped. Quoting and comments are not taken into account. An empty
/// `keyword` yields an empty vector.
fn find_keyword_positions(sql: &str, keyword: &str) -> Vec<usize> {
    let (haystack, needle) = (sql.as_bytes(), keyword.as_bytes());
    (0..=haystack.len())
        .filter(|&pos| keyword_at(haystack, needle, pos))
        .collect()
}

/// Returns the byte offset of the last whole-word, ASCII case-insensitive
/// occurrence of `keyword` in `sql`. Returns `None` if there is none or if
/// `keyword` is empty.
fn rfind_keyword(sql: &str, keyword: &str) -> Option<usize> {
    let (haystack, needle) = (sql.as_bytes(), keyword.as_bytes());
    (0..=haystack.len())
        .rev()
        .find(|&pos| keyword_at(haystack, needle, pos))
}

/// Returns `true` if `sql`, ignoring trailing whitespace, ends with `keyword`
/// as a whole word (ASCII case-insensitive). An empty `keyword` never matches.
fn ends_with_keyword(sql: &str, keyword: &str) -> bool {
    let (haystack, needle) = (sql.trim_end().as_bytes(), keyword.as_bytes());
    haystack.len() >= needle.len()
        && keyword_at(haystack, needle, haystack.len() - needle.len())
}
96
97#[derive(Debug, Clone)]
99pub(crate) struct ParseResult {
100 pub statements: Vec<Statement>,
102 #[allow(dead_code)]
104 pub strategy: ParseStrategy,
105 #[allow(dead_code)]
108 pub synthetic_ranges: Vec<Range<usize>>,
109}
110
111pub(crate) fn try_parse_for_completion(
119 sql: &str,
120 cursor_offset: usize,
121 dialect: Dialect,
122) -> Option<ParseResult> {
123 if let Some(stmts) = try_full_parse(sql, dialect) {
125 return Some(ParseResult {
126 statements: stmts,
127 strategy: ParseStrategy::FullParse,
128 synthetic_ranges: vec![],
129 });
130 }
131
132 if let Some(stmts) = try_truncated_parse(sql, cursor_offset, dialect) {
134 return Some(ParseResult {
135 statements: stmts,
136 strategy: ParseStrategy::Truncated,
137 synthetic_ranges: vec![],
138 });
139 }
140
141 if let Some(stmts) = try_complete_statements(sql, cursor_offset, dialect) {
143 return Some(ParseResult {
144 statements: stmts,
145 strategy: ParseStrategy::CompleteStatementsOnly,
146 synthetic_ranges: vec![],
147 });
148 }
149
150 if let Some((stmts, synthetic)) = try_with_fixes(sql, cursor_offset, dialect) {
152 return Some(ParseResult {
153 statements: stmts,
154 strategy: ParseStrategy::WithFixes,
155 synthetic_ranges: synthetic,
156 });
157 }
158
159 None
160}
161
162pub fn try_full_parse(sql: &str, dialect: Dialect) -> Option<Vec<Statement>> {
164 if sql.trim().is_empty() {
165 return None;
166 }
167
168 let dialect_impl = dialect.to_sqlparser_dialect();
169 Parser::parse_sql(&*dialect_impl, sql)
170 .ok()
171 .filter(|stmts| !stmts.is_empty())
172}
173
174pub fn try_truncated_parse(
176 sql: &str,
177 cursor_offset: usize,
178 dialect: Dialect,
179) -> Option<Vec<Statement>> {
180 if cursor_offset == 0 || cursor_offset > sql.len() {
181 return None;
182 }
183
184 let dialect_impl = dialect.to_sqlparser_dialect();
185 let before_cursor = &sql[..cursor_offset.min(sql.len())];
186
187 let candidates = find_truncation_candidates(before_cursor);
190 for truncation in candidates.into_iter().take(MAX_TRUNCATION_ATTEMPTS) {
191 if truncation == 0 {
192 continue;
193 }
194
195 let truncated = &sql[..truncation];
196 if truncated.trim().is_empty() {
197 continue;
198 }
199
200 if let Ok(stmts) = Parser::parse_sql(&*dialect_impl, truncated) {
201 if !stmts.is_empty() {
202 return Some(stmts);
203 }
204 }
205 }
206
207 None
208}
209
210pub fn try_complete_statements(
212 sql: &str,
213 cursor_offset: usize,
214 dialect: Dialect,
215) -> Option<Vec<Statement>> {
216 let before_cursor = &sql[..cursor_offset.min(sql.len())];
218 let last_semicolon = before_cursor.rfind(';')?;
219
220 let complete_portion = &sql[..=last_semicolon];
221 if complete_portion.trim().is_empty() {
222 return None;
223 }
224
225 let dialect_impl = dialect.to_sqlparser_dialect();
226 Parser::parse_sql(&*dialect_impl, complete_portion)
227 .ok()
228 .filter(|stmts| !stmts.is_empty())
229}
230
231pub fn try_with_fixes(
233 sql: &str,
234 cursor_offset: usize,
235 dialect: Dialect,
236) -> Option<(Vec<Statement>, Vec<Range<usize>>)> {
237 let dialect_impl = dialect.to_sqlparser_dialect();
238
239 let fixes: Vec<SqlFixFn> = vec![
241 fix_trailing_comma,
242 fix_unclosed_parens,
243 fix_incomplete_select,
244 fix_incomplete_from,
245 fix_unclosed_string,
246 ];
247
248 for fix in fixes {
249 if let Some((fixed_sql, synthetic)) = fix(sql, cursor_offset) {
250 if let Ok(stmts) = Parser::parse_sql(&*dialect_impl, &fixed_sql) {
251 if !stmts.is_empty() {
252 return Some((stmts, synthetic));
253 }
254 }
255 }
256 }
257
258 None
259}
260
261fn find_truncation_candidates(sql: &str) -> Vec<usize> {
264 let mut candidates = Vec::new();
265 let bytes = sql.as_bytes();
266
267 let keywords = [
269 "WHERE",
270 "GROUP",
271 "HAVING",
272 "ORDER",
273 "LIMIT",
274 "OFFSET",
275 "UNION",
276 "EXCEPT",
277 "INTERSECT",
278 ];
279
280 for kw in &keywords {
283 for abs_pos in find_keyword_positions(sql, kw) {
284 if abs_pos > 0 && bytes[abs_pos - 1].is_ascii_whitespace() {
286 candidates.push(abs_pos);
287 }
288 }
289 }
290
291 let mut pos = sql.len();
294 while pos > 0 {
295 let byte = bytes[pos - 1];
296
297 if byte.is_ascii() {
300 let ch = byte as char;
301
302 if ch.is_ascii_alphanumeric() || ch == '_' || ch == ')' || ch == '"' || ch == '\'' {
304 candidates.push(pos);
305 }
306 }
307
308 pos -= 1;
309 }
310
311 candidates.sort_by(|a, b| b.cmp(a));
313 candidates.dedup();
314 candidates
315}
316
/// Removes a dangling comma directly before a `FROM` keyword, as in
/// `SELECT a, FROM t`.
///
/// `FROM` is matched ASCII case-insensitively as a whole word, and any
/// whitespace (including newlines) may separate it from the comma. Keywords
/// are examined from the end of the text backwards, and the first one
/// preceded by a comma is repaired. This lets a trailing comma before an
/// outer `FROM` be fixed even when a subquery's `FROM` comes later.
///
/// The comma is replaced by a single space rather than deleted, so every
/// byte offset in the repaired text still matches the user's text. Nothing
/// is inserted, so no synthetic ranges are reported. Returns `None` if no
/// `FROM` is preceded by a comma.
fn fix_trailing_comma(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    let trimmed = sql.trim_end();
    let bytes = trimmed.as_bytes();
    // ASCII upper-casing leaves every byte offset unchanged, so offsets found
    // in `upper` index `trimmed` (and `sql`, of which it is a prefix) directly.
    let upper = trimmed.to_ascii_uppercase();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();

    upper.rmatch_indices("FROM").find_map(|(from_pos, kw)| {
        let end = from_pos + kw.len();
        // Skip `FROM` that starts a longer identifier such as `fromage`.
        if end < bytes.len() && is_word(bytes[end]) {
            return None;
        }
        // Only whitespace may sit between the comma and `FROM`; this also
        // rejects `FROM` that ends an identifier such as `date_from`.
        let before = trimmed[..from_pos].trim_end();
        let comma_pos = before.strip_suffix(',')?.len();

        let mut fixed = sql.to_string();
        fixed.replace_range(comma_pos..comma_pos + 1, " ");
        Some((fixed, Vec::new()))
    })
}
334
/// Appends the `)` characters needed to balance unmatched `(` characters.
///
/// Counts every `(` and `)` in the whole text, including any inside string
/// literals, without regard to their order. If there are more opening than
/// closing parentheses, the missing `)`s are appended. Their byte range is
/// reported as synthetic. Otherwise returns `None`.
fn fix_unclosed_parens(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    // Parentheses are ASCII, so byte counts equal character counts.
    let (opened, closed) = sql.bytes().fold((0usize, 0usize), |(o, c), b| match b {
        b'(' => (o + 1, c),
        b')' => (o, c + 1),
        _ => (o, c),
    });

    let missing = match opened.checked_sub(closed) {
        Some(n) if n > 0 => n,
        _ => return None,
    };

    let start = sql.len();
    let mut fixed = String::with_capacity(start + missing);
    fixed.push_str(sql);
    fixed.push_str(&")".repeat(missing));
    Some((fixed, vec![start..start + missing]))
}
350
/// Inserts a placeholder select item (` 1`) into a `SELECT` whose item list
/// is empty, as in `SELECT FROM t`.
///
/// Both keywords are matched ASCII case-insensitively as whole words. This
/// prevents `SELECT from_date FROM t` from being treated as an empty
/// `SELECT`. Every `SELECT` is examined in order, not just the first, so an
/// empty subquery such as `… IN (SELECT FROM u)` is repaired too. Only the
/// first empty `SELECT` found is fixed.
///
/// Returns the repaired SQL and the byte range of the inserted placeholder.
/// Text after it is shifted by the placeholder's length. Returns `None` when
/// no `SELECT` is followed (after optional whitespace) by `FROM`.
fn fix_incomplete_select(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    const SELECT: &str = "SELECT";
    const PLACEHOLDER: &str = " 1";

    let bytes = sql.as_bytes();
    // ASCII upper-casing leaves every byte offset unchanged, so offsets found
    // in `upper` index `sql` directly.
    let upper = sql.to_ascii_uppercase();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();
    // True when the upper-case keyword `kw` starts at `pos` and is not glued
    // to identifier characters on either side.
    let standalone = |pos: usize, kw: &str| {
        let end = pos + kw.len();
        upper[pos..].starts_with(kw)
            && (pos == 0 || !is_word(bytes[pos - 1]))
            && (end == bytes.len() || !is_word(bytes[end]))
    };

    let insert_pos = upper
        .match_indices(SELECT)
        .map(|(pos, _)| pos)
        .filter(|&pos| standalone(pos, SELECT))
        .map(|pos| pos + SELECT.len())
        .find(|&after| {
            let rest = &sql[after..];
            let gap = rest.len() - rest.trim_start().len();
            standalone(after + gap, "FROM")
        })?;

    let mut fixed = sql.to_string();
    fixed.insert_str(insert_pos, PLACEHOLDER);
    Some((fixed, vec![insert_pos..insert_pos + PLACEHOLDER.len()]))
}
380
381fn fix_incomplete_from(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
383 let trimmed = sql.trim_end();
384
385 if ends_with_keyword(trimmed, "FROM") {
388 let suffix = " _dummy_";
389 let synthetic_start = sql.len();
390 let fixed = format!("{}{}", sql, suffix);
391 return Some((fixed, vec![synthetic_start..synthetic_start + suffix.len()]));
392 }
393
394 None
395}
396
/// Closes a string literal (`'…`) or quoted identifier (`"…`) left open at the
/// end of `sql` by appending the matching quote.
///
/// Quotes are scanned left to right while tracking which kind is currently
/// open. A quote of the other kind inside an open literal or identifier is
/// ordinary text, so `"it's"` does not hide an unclosed `'abc` later on. A
/// doubled quote (`''`) closes and immediately reopens, so it balances out.
/// Backslash escapes and comments are not recognised.
///
/// Returns the repaired SQL and the byte range of the appended quote, or
/// `None` when every quote is closed.
fn fix_unclosed_string(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    // Quote bytes are ASCII and never occur inside multi-byte UTF-8 sequences.
    let mut open: Option<u8> = None;
    for b in sql.bytes() {
        open = match (open, b) {
            (None, b'\'') | (None, b'"') => Some(b),
            (Some(quote), c) if c == quote => None,
            (state, _) => state,
        };
    }
    let quote = open?;

    let start = sql.len();
    let mut fixed = String::with_capacity(start + 1);
    fixed.push_str(sql);
    fixed.push(char::from(quote));
    Some((fixed, vec![start..start + 1]))
}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the whole completion pipeline with the cursor at the end of `sql`
    /// using the generic dialect, failing the test if nothing is recovered.
    fn parse_at_end(sql: &str) -> ParseResult {
        try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("the completion pipeline should recover a statement")
    }

    #[test]
    fn test_full_parse_valid_sql() {
        let statements = try_full_parse("SELECT * FROM users WHERE id = 1", Dialect::Generic)
            .expect("valid SQL should parse");
        assert_eq!(statements.len(), 1);
    }

    #[test]
    fn test_full_parse_invalid_sql() {
        assert!(try_full_parse("SELECT * FROM", Dialect::Generic).is_none());
    }

    #[test]
    fn test_truncated_parse() {
        let sql = "SELECT * FROM users WHERE ";
        assert!(try_truncated_parse(sql, sql.len(), Dialect::Generic).is_some());
    }

    #[test]
    fn test_complete_statements_only() {
        let sql = "SELECT 1; SELECT * FROM";
        let statements = try_complete_statements(sql, sql.len(), Dialect::Generic)
            .expect("the statement before the semicolon should parse");
        assert_eq!(statements.len(), 1);
    }

    #[test]
    fn test_fix_trailing_comma() {
        let sql = "SELECT a, FROM users";
        assert!(try_with_fixes(sql, sql.len(), Dialect::Generic).is_some());
    }

    #[test]
    fn test_fix_unclosed_parens() {
        let sql = "SELECT COUNT(* FROM users";
        let (fixed, synthetic) =
            fix_unclosed_parens(sql, sql.len()).expect("one parenthesis is left open");
        assert!(fixed.ends_with(')'));
        assert_eq!(synthetic.len(), 1);
    }

    #[test]
    fn test_fix_incomplete_select() {
        let sql = "SELECT FROM users";
        let (fixed, synthetic) =
            fix_incomplete_select(sql, sql.len()).expect("the select list is empty");
        assert!(fixed.contains("1"));
        assert_eq!(synthetic.len(), 1);
    }

    #[test]
    fn test_fix_incomplete_from() {
        let sql = "SELECT * FROM";
        let (fixed, _) = fix_incomplete_from(sql, sql.len()).expect("the text ends with FROM");
        assert!(fixed.contains("_dummy_"));
    }

    #[test]
    fn test_fix_unclosed_string() {
        let sql = "SELECT 'hello";
        let (fixed, _) = fix_unclosed_string(sql, sql.len()).expect("a literal is left open");
        assert!(fixed.ends_with('\''));
    }

    #[test]
    fn test_try_parse_for_completion_valid() {
        assert_eq!(parse_at_end("SELECT * FROM users").strategy, ParseStrategy::FullParse);
    }

    #[test]
    fn test_try_parse_for_completion_truncated() {
        let result = parse_at_end("SELECT * FROM users WHERE id = ");
        // The test accepts either strategy.
        assert!(matches!(
            result.strategy,
            ParseStrategy::Truncated | ParseStrategy::FullParse
        ));
    }

    #[test]
    fn test_try_parse_for_completion_with_fixes() {
        // A bare trailing `FROM` is expected to need the fix-up stage.
        assert_eq!(parse_at_end("SELECT * FROM").strategy, ParseStrategy::WithFixes);
    }
}