1#![allow(clippy::single_range_in_vec_init)]
12
13use std::ops::Range;
14
15use sqlparser::ast::Statement;
16use sqlparser::parser::Parser;
17use sqlparser::tokenizer::{Token, Tokenizer};
18
19use crate::analyzer::helpers::line_col_to_offset;
20use crate::types::{Dialect, ParseStrategy};
21
/// Upper bound on how many truncation points `try_truncated_parse` will try
/// before giving up.
const MAX_TRUNCATION_ATTEMPTS: usize = 50;

/// Largest number of `)` characters `fix_unclosed_parens` is willing to
/// append; input missing more than this is left unrepaired.
const MAX_PAREN_FIXES: usize = 20;

/// Byte length of the ASCII keyword `FROM`, used when slicing past it.
const FROM_KEYWORD_LENGTH: usize = "FROM".len();
31
32type SqlFixFn = fn(&str, usize, Dialect) -> Option<(String, Vec<Range<usize>>)>;
37
/// A word token found by the SQL tokenizer, with where it starts.
#[derive(Clone, Debug)]
struct WordPosition {
    /// Byte offset of the token's first character within the SQL text.
    start: usize,
    /// The token text upper-cased, for case-insensitive keyword comparison.
    value_upper: String,
}
46
47fn tokenize_word_positions(sql: &str, dialect: Dialect) -> Option<Vec<WordPosition>> {
52 let dialect_impl = dialect.to_sqlparser_dialect();
53 let mut tokenizer = Tokenizer::new(&*dialect_impl, sql);
54 let tokens = tokenizer.tokenize_with_location().ok()?;
55
56 let mut positions = Vec::new();
57 for token_with_span in tokens {
58 if let Token::Word(word) = &token_with_span.token {
59 let start = line_col_to_offset(
61 sql,
62 token_with_span.span.start.line as usize,
63 token_with_span.span.start.column as usize,
64 )?;
65 positions.push(WordPosition {
66 start,
67 value_upper: word.value.to_uppercase(),
68 });
69 }
70 }
71 Some(positions)
72}
73
/// Test helper: `find_keyword_positions_with_dialect` using the generic
/// dialect.
#[cfg(test)]
fn find_keyword_positions(sql: &str, keyword: &str) -> Vec<usize> {
    let dialect = Dialect::Generic;
    find_keyword_positions_with_dialect(sql, keyword, dialect)
}
89
90fn find_keyword_positions_with_dialect(sql: &str, keyword: &str, dialect: Dialect) -> Vec<usize> {
92 let keyword_upper = keyword.to_uppercase();
93
94 if let Some(word_positions) = tokenize_word_positions(sql, dialect) {
96 return word_positions
97 .into_iter()
98 .filter(|wp| wp.value_upper == keyword_upper)
99 .map(|wp| wp.start)
100 .collect();
101 }
102
103 find_keyword_positions_fallback(sql, keyword)
105}
106
/// Byte-level fallback for keyword search, used when the tokenizer fails
/// (typical for incomplete SQL such as an unterminated string).
///
/// Returns the ascending byte offsets at which `keyword` occurs as a whole
/// word, compared ASCII case-insensitively. A match must not be preceded or
/// followed by an identifier byte, so `WHERE` is not found inside `nowhere`
/// and `FROM` is not found inside `from_date`. Identifier bytes are ASCII
/// alphanumerics, `_`, and any non-ASCII byte (part of a multi-byte
/// character). String literals and comments are NOT skipped.
///
/// An empty `keyword` yields no matches.
fn find_keyword_positions_fallback(sql: &str, keyword: &str) -> Vec<usize> {
    let haystack = sql.as_bytes();
    let needle = keyword.as_bytes();

    if needle.is_empty() || haystack.len() < needle.len() {
        return Vec::new();
    }

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();

    haystack
        .windows(needle.len())
        .enumerate()
        .filter(|&(start, window)| {
            let end = start + needle.len();
            window.eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && !matches!(haystack.get(end), Some(&next) if is_word_byte(next))
        })
        .map(|(start, _)| start)
        .collect()
}
133
/// Test helper: `rfind_keyword_with_dialect` using the generic dialect.
#[cfg(test)]
fn rfind_keyword(sql: &str, keyword: &str) -> Option<usize> {
    let dialect = Dialect::Generic;
    rfind_keyword_with_dialect(sql, keyword, dialect)
}
147
148fn rfind_keyword_with_dialect(sql: &str, keyword: &str, dialect: Dialect) -> Option<usize> {
153 let keyword_upper = keyword.to_uppercase();
154
155 if let Some(word_positions) = tokenize_word_positions(sql, dialect) {
157 return word_positions
159 .into_iter()
160 .rfind(|wp| wp.value_upper == keyword_upper)
161 .map(|wp| wp.start);
162 }
163
164 rfind_keyword_fallback(sql, keyword)
166}
167
/// Byte-level fallback for finding the *last* whole-word occurrence of
/// `keyword` in `sql`, compared ASCII case-insensitively.
///
/// A match must not be preceded or followed by an identifier byte. Identifier
/// bytes are ASCII alphanumerics, `_`, and non-ASCII bytes. This prevents
/// `FROM` from matching inside `datefrom` or `fromage`. String literals and
/// comments are NOT skipped.
///
/// Returns `None` for an empty `keyword` or when there is no match.
fn rfind_keyword_fallback(sql: &str, keyword: &str) -> Option<usize> {
    let haystack = sql.as_bytes();
    let needle = keyword.as_bytes();

    if needle.is_empty() || haystack.len() < needle.len() {
        return None;
    }

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();

    (0..=haystack.len() - needle.len()).rev().find(|&start| {
        let end = start + needle.len();
        haystack[start..end].eq_ignore_ascii_case(needle)
            && (start == 0 || !is_word_byte(haystack[start - 1]))
            && !matches!(haystack.get(end), Some(&next) if is_word_byte(next))
    })
}
190
/// Returns `true` when `sql`, ignoring trailing whitespace, ends with
/// `keyword` as a whole word (ASCII case-insensitive).
///
/// The keyword must not be glued to a preceding identifier byte (ASCII
/// alphanumeric, `_`, or non-ASCII byte), so `"SELECT datefrom"` does not end
/// with `FROM`. An empty `keyword` never matches.
fn ends_with_keyword(sql: &str, keyword: &str) -> bool {
    let trimmed = sql.trim_end().as_bytes();
    let kw = keyword.as_bytes();

    if kw.is_empty() || trimmed.len() < kw.len() {
        return false;
    }

    let start = trimmed.len() - kw.len();
    if !trimmed[start..].eq_ignore_ascii_case(kw) {
        return false;
    }

    // The keyword must begin a word: nothing identifier-like right before it.
    match start.checked_sub(1).map(|i| trimmed[i]) {
        None => true,
        Some(prev) => !(prev.is_ascii_alphanumeric() || prev == b'_' || !prev.is_ascii()),
    }
}
207
208#[derive(Debug, Clone)]
210pub(crate) struct ParseResult {
211 pub statements: Vec<Statement>,
213 #[allow(dead_code)]
215 pub strategy: ParseStrategy,
216 #[allow(dead_code)]
219 pub synthetic_ranges: Vec<Range<usize>>,
220}
221
222pub(crate) fn try_parse_for_completion(
230 sql: &str,
231 cursor_offset: usize,
232 dialect: Dialect,
233) -> Option<ParseResult> {
234 if let Some(stmts) = try_full_parse(sql, dialect) {
236 return Some(ParseResult {
237 statements: stmts,
238 strategy: ParseStrategy::FullParse,
239 synthetic_ranges: vec![],
240 });
241 }
242
243 if let Some(stmts) = try_truncated_parse(sql, cursor_offset, dialect) {
245 return Some(ParseResult {
246 statements: stmts,
247 strategy: ParseStrategy::Truncated,
248 synthetic_ranges: vec![],
249 });
250 }
251
252 if let Some(stmts) = try_complete_statements(sql, cursor_offset, dialect) {
254 return Some(ParseResult {
255 statements: stmts,
256 strategy: ParseStrategy::CompleteStatementsOnly,
257 synthetic_ranges: vec![],
258 });
259 }
260
261 if let Some((stmts, synthetic)) = try_with_fixes(sql, cursor_offset, dialect) {
263 return Some(ParseResult {
264 statements: stmts,
265 strategy: ParseStrategy::WithFixes,
266 synthetic_ranges: synthetic,
267 });
268 }
269
270 None
271}
272
273pub fn try_full_parse(sql: &str, dialect: Dialect) -> Option<Vec<Statement>> {
275 if sql.trim().is_empty() {
276 return None;
277 }
278
279 let dialect_impl = dialect.to_sqlparser_dialect();
280 Parser::parse_sql(&*dialect_impl, sql)
281 .ok()
282 .filter(|stmts| !stmts.is_empty())
283}
284
285pub fn try_truncated_parse(
287 sql: &str,
288 cursor_offset: usize,
289 dialect: Dialect,
290) -> Option<Vec<Statement>> {
291 if cursor_offset == 0 || cursor_offset > sql.len() {
292 return None;
293 }
294
295 let dialect_impl = dialect.to_sqlparser_dialect();
296 let before_cursor = &sql[..cursor_offset.min(sql.len())];
297
298 let candidates = find_truncation_candidates(before_cursor, dialect);
301 for truncation in candidates.into_iter().take(MAX_TRUNCATION_ATTEMPTS) {
302 if truncation == 0 {
303 continue;
304 }
305
306 let truncated = &sql[..truncation];
307 if truncated.trim().is_empty() {
308 continue;
309 }
310
311 if let Ok(stmts) = Parser::parse_sql(&*dialect_impl, truncated) {
312 if !stmts.is_empty() {
313 return Some(stmts);
314 }
315 }
316 }
317
318 None
319}
320
321pub fn try_complete_statements(
323 sql: &str,
324 cursor_offset: usize,
325 dialect: Dialect,
326) -> Option<Vec<Statement>> {
327 let before_cursor = &sql[..cursor_offset.min(sql.len())];
329 let last_semicolon = before_cursor.rfind(';')?;
330
331 let complete_portion = &sql[..=last_semicolon];
332 if complete_portion.trim().is_empty() {
333 return None;
334 }
335
336 let dialect_impl = dialect.to_sqlparser_dialect();
337 Parser::parse_sql(&*dialect_impl, complete_portion)
338 .ok()
339 .filter(|stmts| !stmts.is_empty())
340}
341
342pub fn try_with_fixes(
344 sql: &str,
345 cursor_offset: usize,
346 dialect: Dialect,
347) -> Option<(Vec<Statement>, Vec<Range<usize>>)> {
348 let dialect_impl = dialect.to_sqlparser_dialect();
349
350 let fixes: Vec<SqlFixFn> = vec![
352 fix_trailing_comma,
353 fix_unclosed_parens,
354 fix_incomplete_select,
355 fix_incomplete_from,
356 fix_unclosed_string,
357 ];
358
359 for fix in fixes {
360 if let Some((fixed_sql, synthetic)) = fix(sql, cursor_offset, dialect) {
361 if let Ok(stmts) = Parser::parse_sql(&*dialect_impl, &fixed_sql) {
362 if !stmts.is_empty() {
363 return Some((stmts, synthetic));
364 }
365 }
366 }
367 }
368
369 None
370}
371
372fn find_truncation_candidates(sql: &str, dialect: Dialect) -> Vec<usize> {
375 let mut candidates = Vec::new();
376 let bytes = sql.as_bytes();
377
378 let keywords = [
380 "WHERE",
381 "GROUP",
382 "HAVING",
383 "ORDER",
384 "LIMIT",
385 "OFFSET",
386 "UNION",
387 "EXCEPT",
388 "INTERSECT",
389 ];
390
391 for kw in &keywords {
394 for abs_pos in find_keyword_positions_with_dialect(sql, kw, dialect) {
395 if abs_pos > 0 && bytes[abs_pos - 1].is_ascii_whitespace() {
397 candidates.push(abs_pos);
398 }
399 }
400 }
401
402 let mut pos = sql.len();
405 while pos > 0 {
406 let byte = bytes[pos - 1];
407
408 if byte.is_ascii() {
411 let ch = byte as char;
412
413 if ch.is_ascii_alphanumeric() || ch == '_' || ch == ')' || ch == '"' || ch == '\'' {
415 candidates.push(pos);
416 }
417 }
418
419 pos -= 1;
420 }
421
422 candidates.sort_by(|a, b| b.cmp(a));
424 candidates.dedup();
425 candidates
426}
427
428fn fix_trailing_comma(
430 sql: &str,
431 _cursor_offset: usize,
432 dialect: Dialect,
433) -> Option<(String, Vec<Range<usize>>)> {
434 let trimmed = sql.trim_end();
436
437 if let Some(from_pos) = rfind_keyword_with_dialect(trimmed, "FROM", dialect) {
440 if from_pos > 0 && trimmed.as_bytes()[from_pos - 1].is_ascii_whitespace() {
442 let before_from = trimmed[..from_pos].trim_end();
443 if let Some(without_comma) = before_from.strip_suffix(',') {
444 let after_from = &trimmed[from_pos + FROM_KEYWORD_LENGTH..];
446 let after_from_trimmed = after_from.trim_start();
447 let fixed = if after_from_trimmed.is_empty() {
448 format!("{} FROM", without_comma)
449 } else {
450 format!("{} FROM {}", without_comma, after_from_trimmed)
451 };
452 return Some((fixed, vec![]));
453 }
454 }
455 }
456
457 None
458}
459
460fn fix_unclosed_parens(
462 sql: &str,
463 _cursor_offset: usize,
464 _dialect: Dialect,
465) -> Option<(String, Vec<Range<usize>>)> {
466 let open = sql.chars().filter(|&c| c == '(').count();
467 let close = sql.chars().filter(|&c| c == ')').count();
468
469 if open > close {
470 let missing = open - close;
471 if missing > MAX_PAREN_FIXES {
473 return None;
474 }
475 let suffix = ")".repeat(missing);
476 let synthetic_start = sql.len();
477 let fixed = format!("{}{}", sql, suffix);
478 return Some((fixed, vec![synthetic_start..synthetic_start + missing]));
479 }
480
481 None
482}
483
484fn fix_incomplete_select(
486 sql: &str,
487 _cursor_offset: usize,
488 dialect: Dialect,
489) -> Option<(String, Vec<Range<usize>>)> {
490 let positions = find_keyword_positions_with_dialect(sql, "SELECT", dialect);
495 if let Some(&select_pos) = positions.first() {
496 let after_select_start = select_pos + 6;
497 if after_select_start <= sql.len() {
498 let after_select = &sql[after_select_start..];
499
500 let from_positions = find_keyword_positions_with_dialect(after_select, "FROM", dialect);
502 if let Some(&from_rel_pos) = from_positions.first() {
503 let between = after_select[..from_rel_pos].trim();
504 if between.is_empty() {
505 let insert_pos = after_select_start;
507 let mut fixed = sql.to_string();
508 fixed.insert_str(insert_pos, " 1");
509 return Some((fixed, vec![insert_pos..insert_pos + 2]));
510 }
511 }
512 }
513 }
514
515 None
516}
517
518fn fix_incomplete_from(
520 sql: &str,
521 _cursor_offset: usize,
522 _dialect: Dialect,
523) -> Option<(String, Vec<Range<usize>>)> {
524 let trimmed = sql.trim_end();
525
526 if ends_with_keyword(trimmed, "FROM") {
531 let suffix = " _dummy_";
532 let synthetic_start = sql.len();
533 let fixed = format!("{}{}", sql, suffix);
534 return Some((fixed, vec![synthetic_start..synthetic_start + suffix.len()]));
535 }
536
537 None
538}
539
540fn fix_unclosed_string(
542 sql: &str,
543 _cursor_offset: usize,
544 _dialect: Dialect,
545) -> Option<(String, Vec<Range<usize>>)> {
546 let single_quotes = sql.chars().filter(|&c| c == '\'').count();
548 let double_quotes = sql.chars().filter(|&c| c == '"').count();
549
550 if single_quotes % 2 != 0 {
551 let synthetic_start = sql.len();
552 let fixed = format!("{}'", sql);
553 return Some((fixed, vec![synthetic_start..synthetic_start + 1]));
554 }
555
556 if double_quotes % 2 != 0 {
557 let synthetic_start = sql.len();
558 let fixed = format!("{}\"", sql);
559 return Some((fixed, vec![synthetic_start..synthetic_start + 1]));
560 }
561
562 None
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_keyword_skips_string_literals() {
        let literal_only = "SELECT 'WHERE is fun' FROM users";
        let found = find_keyword_positions(literal_only, "WHERE");
        assert!(
            found.is_empty(),
            "Should not find WHERE inside string literal, found at: {:?}",
            found
        );

        let real_clause = "SELECT 'text' FROM users WHERE id = 1";
        let found = find_keyword_positions(real_clause, "WHERE");
        assert_eq!(found.len(), 1);
        let start = found[0];
        assert_eq!(&real_clause[start..start + 5], "WHERE");
    }

    #[test]
    fn test_find_keyword_skips_comments() {
        let line_comment = "SELECT * FROM users -- WHERE is commented out";
        assert!(
            find_keyword_positions(line_comment, "WHERE").is_empty(),
            "Should not find WHERE inside line comment"
        );

        let block_comment = "SELECT * /* WHERE */ FROM users";
        assert!(
            find_keyword_positions(block_comment, "WHERE").is_empty(),
            "Should not find WHERE inside block comment"
        );
    }

    #[test]
    fn test_find_keyword_case_insensitive() {
        let lower = "select * from users where id = 1";
        let found = find_keyword_positions(lower, "WHERE");
        assert_eq!(found.len(), 1);
        let start = found[0];
        assert_eq!(&lower[start..start + 5], "where");
    }

    #[test]
    fn test_find_keyword_handles_unicode_prefix() {
        let sql = "SELECT μ, FROM users";
        let expected = vec!["SELECT μ, ".len()];
        assert_eq!(find_keyword_positions(sql, "FROM"), expected);
    }

    #[test]
    fn test_rfind_keyword_token_aware() {
        let sql = "SELECT 'FROM somewhere' FROM users";
        let start = match rfind_keyword(sql, "FROM") {
            Some(start) => start,
            None => panic!("expected to find a FROM keyword"),
        };
        assert_eq!(&sql[start..start + 4], "FROM");
        assert!(start > 20, "Should find actual FROM, not one in string");
    }

    #[test]
    fn test_rfind_keyword_handles_unicode_prefix() {
        let sql = "SELECT μ, FROM users";
        let start = rfind_keyword(sql, "FROM").expect("should find FROM");
        let expected = "SELECT μ, ".len();
        assert_eq!(start, expected, "should account for multi-byte chars");
    }

    #[test]
    fn test_full_parse_valid_sql() {
        let sql = "SELECT * FROM users WHERE id = 1";
        let statement_count = try_full_parse(sql, Dialect::Generic).map(|stmts| stmts.len());
        assert_eq!(statement_count, Some(1));
    }

    #[test]
    fn test_full_parse_invalid_sql() {
        let sql = "SELECT * FROM";
        assert!(try_full_parse(sql, Dialect::Generic).is_none());
    }

    #[test]
    fn test_truncated_parse() {
        let sql = "SELECT * FROM users WHERE ";
        assert!(try_truncated_parse(sql, sql.len(), Dialect::Generic).is_some());
    }

    #[test]
    fn test_complete_statements_only() {
        let sql = "SELECT 1; SELECT * FROM";
        let statement_count =
            try_complete_statements(sql, sql.len(), Dialect::Generic).map(|stmts| stmts.len());
        assert_eq!(statement_count, Some(1));
    }

    #[test]
    fn test_fix_trailing_comma() {
        let sql = "SELECT a, FROM users";
        assert!(try_with_fixes(sql, sql.len(), Dialect::Generic).is_some());
    }

    #[test]
    fn test_fix_unclosed_parens() {
        let sql = "SELECT COUNT(* FROM users";
        let (fixed, synthetic) = fix_unclosed_parens(sql, sql.len(), Dialect::Generic)
            .expect("an unmatched '(' should be repaired");
        assert!(fixed.ends_with(')'));
        assert_eq!(synthetic.len(), 1);
    }

    #[test]
    fn test_fix_incomplete_select() {
        let sql = "SELECT FROM users";
        let (fixed, synthetic) = fix_incomplete_select(sql, sql.len(), Dialect::Generic)
            .expect("an empty projection should be repaired");
        assert!(fixed.contains("1"));
        assert_eq!(synthetic.len(), 1);
    }

    #[test]
    fn test_fix_incomplete_from() {
        let sql = "SELECT * FROM";
        let (fixed, _) = fix_incomplete_from(sql, sql.len(), Dialect::Generic)
            .expect("a bare trailing FROM should be repaired");
        assert!(fixed.contains("_dummy_"));
    }

    #[test]
    fn test_fix_unclosed_string() {
        let sql = "SELECT 'hello";
        let (fixed, _) = fix_unclosed_string(sql, sql.len(), Dialect::Generic)
            .expect("an unterminated literal should be repaired");
        assert!(fixed.ends_with('\''));
    }

    #[test]
    fn test_try_parse_for_completion_valid() {
        let sql = "SELECT * FROM users";
        let parsed = try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("valid SQL should parse");
        assert_eq!(parsed.strategy, ParseStrategy::FullParse);
    }

    #[test]
    fn test_try_parse_for_completion_truncated() {
        let sql = "SELECT * FROM users WHERE id = ";
        let parsed = try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("a truncated prefix should parse");
        assert!(matches!(
            parsed.strategy,
            ParseStrategy::Truncated | ParseStrategy::FullParse
        ));
    }

    #[test]
    fn test_try_parse_for_completion_with_fixes() {
        // A bare trailing FROM can only be recovered by the repair heuristics.
        let sql = "SELECT * FROM";
        let parsed = try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("a repair heuristic should recover this SQL");
        assert_eq!(parsed.strategy, ParseStrategy::WithFixes);
    }
}