flowscope_core/completion/
parse_strategies.rs

1//! Parse strategies for hybrid SQL completion.
2//!
3//! This module provides multiple strategies for parsing incomplete SQL,
4//! trying them in order of cost until one succeeds.
5//!
6//! Note: This module assumes ASCII SQL keywords and identifiers. Non-ASCII
7//! characters in identifiers are handled correctly, but keyword matching
8//! is ASCII-only for consistent cross-dialect behavior.
9
10// We intentionally create Vec<Range<usize>> with single elements for synthetic ranges
11#![allow(clippy::single_range_in_vec_init)]
12
13use std::ops::Range;
14
15use sqlparser::ast::Statement;
16use sqlparser::parser::Parser;
17
18use crate::types::{Dialect, ParseStrategy};
19
/// Upper bound on how many truncation points `try_truncated_parse` will try.
///
/// Each attempt is a full re-parse, so capping the count keeps pathological
/// SQL from turning completion into an unbounded amount of work.
const MAX_TRUNCATION_ATTEMPTS: usize = 50;

/// Signature shared by the `fix_*` helpers.
///
/// A fix receives the SQL text and the cursor offset and returns the patched
/// SQL together with the byte ranges of any text it inserted, or `None` when
/// the fix does not apply.
type SqlFixFn = fn(sql: &str, cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)>;
25
/// Find every standalone occurrence of `keyword` in `sql` using ASCII
/// case-insensitive matching.
///
/// Matching works directly on the original string bytes. This avoids the index
/// mismatch that `to_uppercase()` can introduce on non-ASCII strings, so the
/// returned indices can be used to slice `sql`.
///
/// A match only counts when it is not glued to surrounding identifier
/// characters. If the keyword starts with an identifier character, the byte
/// before the match must not be one. If the keyword ends with an identifier
/// character, the byte after the match must not be one. Identifier characters
/// are ASCII alphanumerics, `_`, and any non-ASCII byte, since those may belong
/// to a Unicode identifier. As a result, `FROM` does not match inside `fromage`
/// and `ORDER` does not match inside `orders`. Keywords that begin or end with
/// whitespace or punctuation (e.g. `" FROM"`) get no check on that side.
///
/// Returns the byte indices (ascending) where the keyword starts. The result
/// is empty when `keyword` is empty or longer than `sql`.
fn find_keyword_positions(sql: &str, keyword: &str) -> Vec<usize> {
    /// Bytes that may continue an identifier (non-ASCII counts as identifier).
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
    }

    let hay = sql.as_bytes();
    let kw = keyword.as_bytes();
    if kw.is_empty() || hay.len() < kw.len() {
        return Vec::new();
    }

    // Only enforce a boundary on a side where the keyword itself is "wordy".
    let check_before = is_word_byte(kw[0]);
    let check_after = is_word_byte(kw[kw.len() - 1]);

    hay.windows(kw.len())
        .enumerate()
        .filter(|(_, window)| window.eq_ignore_ascii_case(kw))
        .map(|(start, _)| start)
        .filter(|&start| {
            let end = start + kw.len();
            let clean_before = !check_before || start == 0 || !is_word_byte(hay[start - 1]);
            let clean_after = !check_after || end == hay.len() || !is_word_byte(hay[end]);
            clean_before && clean_after
        })
        .collect()
}
54
/// Find the last standalone occurrence of `keyword` in `sql` using ASCII
/// case-insensitive matching.
///
/// A match is rejected when it is glued to identifier characters: if the
/// keyword starts (or ends) with an identifier character, the byte before
/// (or after) the match must not be one. Identifier characters are ASCII
/// alphanumerics, `_`, and non-ASCII bytes. For example, `FROM` is not found
/// inside `fromage`, but `" FROM"` (leading space) still matches after any
/// character.
///
/// Returns the byte index where the last match starts, or `None` if there is
/// no match or `keyword` is empty.
fn rfind_keyword(sql: &str, keyword: &str) -> Option<usize> {
    /// Bytes that may continue an identifier (non-ASCII counts as identifier).
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
    }

    let hay = sql.as_bytes();
    let kw = keyword.as_bytes();
    if kw.is_empty() || hay.len() < kw.len() {
        return None;
    }

    let check_before = is_word_byte(kw[0]);
    let check_after = is_word_byte(kw[kw.len() - 1]);

    (0..=hay.len() - kw.len()).rev().find(|&start| {
        let end = start + kw.len();
        hay[start..end].eq_ignore_ascii_case(kw)
            && (!check_before || start == 0 || !is_word_byte(hay[start - 1]))
            && (!check_after || end == hay.len() || !is_word_byte(hay[end]))
    })
}
79
/// Check whether `sql` ends with `keyword` as a standalone word
/// (ASCII case-insensitive, ignoring trailing whitespace).
///
/// If the keyword starts with an identifier character, the character in front
/// of it must not be one. So `"SELECT * FROM"` ends with `FROM`, but
/// `"SELECT * FROM users_from"` does not. Identifier characters are ASCII
/// alphanumerics, `_`, and non-ASCII bytes. An empty keyword never matches.
fn ends_with_keyword(sql: &str, keyword: &str) -> bool {
    /// Bytes that may continue an identifier (non-ASCII counts as identifier).
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
    }

    let text = sql.trim_end().as_bytes();
    let kw = keyword.as_bytes();
    if kw.is_empty() || text.len() < kw.len() {
        return false;
    }

    let start = text.len() - kw.len();
    text[start..].eq_ignore_ascii_case(kw)
        && (start == 0 || !is_word_byte(kw[0]) || !is_word_byte(text[start - 1]))
}
96
97/// Result of a successful parse attempt
98#[derive(Debug, Clone)]
99pub(crate) struct ParseResult {
100    /// Parsed SQL statements
101    pub statements: Vec<Statement>,
102    /// Strategy that succeeded (reserved for future diagnostic/optimization use)
103    #[allow(dead_code)]
104    pub strategy: ParseStrategy,
105    /// Byte ranges of synthetic (added) content to ignore during extraction
106    /// (reserved for future use when filtering AST nodes that were synthesized)
107    #[allow(dead_code)]
108    pub synthetic_ranges: Vec<Range<usize>>,
109}
110
111/// Try to parse SQL for completion context extraction.
112///
113/// Attempts multiple strategies in order of cost until one succeeds:
114/// 1. Full parse (complete SQL)
115/// 2. Truncated parse (cut at cursor position)
116/// 3. Complete statements only (semicolon-terminated before cursor)
117/// 4. With minimal fixes (patch incomplete SQL)
118pub(crate) fn try_parse_for_completion(
119    sql: &str,
120    cursor_offset: usize,
121    dialect: Dialect,
122) -> Option<ParseResult> {
123    // Strategy 1: Try full parse
124    if let Some(stmts) = try_full_parse(sql, dialect) {
125        return Some(ParseResult {
126            statements: stmts,
127            strategy: ParseStrategy::FullParse,
128            synthetic_ranges: vec![],
129        });
130    }
131
132    // Strategy 2: Try truncated parse
133    if let Some(stmts) = try_truncated_parse(sql, cursor_offset, dialect) {
134        return Some(ParseResult {
135            statements: stmts,
136            strategy: ParseStrategy::Truncated,
137            synthetic_ranges: vec![],
138        });
139    }
140
141    // Strategy 3: Try complete statements only
142    if let Some(stmts) = try_complete_statements(sql, cursor_offset, dialect) {
143        return Some(ParseResult {
144            statements: stmts,
145            strategy: ParseStrategy::CompleteStatementsOnly,
146            synthetic_ranges: vec![],
147        });
148    }
149
150    // Strategy 4: Try with minimal fixes
151    if let Some((stmts, synthetic)) = try_with_fixes(sql, cursor_offset, dialect) {
152        return Some(ParseResult {
153            statements: stmts,
154            strategy: ParseStrategy::WithFixes,
155            synthetic_ranges: synthetic,
156        });
157    }
158
159    None
160}
161
162/// Strategy 1: Parse complete SQL as-is
163pub fn try_full_parse(sql: &str, dialect: Dialect) -> Option<Vec<Statement>> {
164    if sql.trim().is_empty() {
165        return None;
166    }
167
168    let dialect_impl = dialect.to_sqlparser_dialect();
169    Parser::parse_sql(&*dialect_impl, sql)
170        .ok()
171        .filter(|stmts| !stmts.is_empty())
172}
173
174/// Strategy 2: Truncate SQL at a safe point before cursor
175pub fn try_truncated_parse(
176    sql: &str,
177    cursor_offset: usize,
178    dialect: Dialect,
179) -> Option<Vec<Statement>> {
180    if cursor_offset == 0 || cursor_offset > sql.len() {
181        return None;
182    }
183
184    let dialect_impl = dialect.to_sqlparser_dialect();
185    let before_cursor = &sql[..cursor_offset.min(sql.len())];
186
187    // Try progressively shorter truncations until we find one that parses
188    // Limit attempts to prevent DoS with pathological SQL
189    let candidates = find_truncation_candidates(before_cursor);
190    for truncation in candidates.into_iter().take(MAX_TRUNCATION_ATTEMPTS) {
191        if truncation == 0 {
192            continue;
193        }
194
195        let truncated = &sql[..truncation];
196        if truncated.trim().is_empty() {
197            continue;
198        }
199
200        if let Ok(stmts) = Parser::parse_sql(&*dialect_impl, truncated) {
201            if !stmts.is_empty() {
202                return Some(stmts);
203            }
204        }
205    }
206
207    None
208}
209
210/// Strategy 3: Parse only complete statements before cursor
211pub fn try_complete_statements(
212    sql: &str,
213    cursor_offset: usize,
214    dialect: Dialect,
215) -> Option<Vec<Statement>> {
216    // Find the last semicolon before cursor
217    let before_cursor = &sql[..cursor_offset.min(sql.len())];
218    let last_semicolon = before_cursor.rfind(';')?;
219
220    let complete_portion = &sql[..=last_semicolon];
221    if complete_portion.trim().is_empty() {
222        return None;
223    }
224
225    let dialect_impl = dialect.to_sqlparser_dialect();
226    Parser::parse_sql(&*dialect_impl, complete_portion)
227        .ok()
228        .filter(|stmts| !stmts.is_empty())
229}
230
231/// Strategy 4: Apply minimal fixes to make SQL parseable
232pub fn try_with_fixes(
233    sql: &str,
234    cursor_offset: usize,
235    dialect: Dialect,
236) -> Option<(Vec<Statement>, Vec<Range<usize>>)> {
237    let dialect_impl = dialect.to_sqlparser_dialect();
238
239    // Try fixes in order of likelihood
240    let fixes: Vec<SqlFixFn> = vec![
241        fix_trailing_comma,
242        fix_unclosed_parens,
243        fix_incomplete_select,
244        fix_incomplete_from,
245        fix_unclosed_string,
246    ];
247
248    for fix in fixes {
249        if let Some((fixed_sql, synthetic)) = fix(sql, cursor_offset) {
250            if let Ok(stmts) = Parser::parse_sql(&*dialect_impl, &fixed_sql) {
251                if !stmts.is_empty() {
252                    return Some((stmts, synthetic));
253                }
254            }
255        }
256    }
257
258    None
259}
260
261/// Generate candidate truncation points from longest to shortest.
262/// These are positions where SQL might be syntactically complete.
263fn find_truncation_candidates(sql: &str) -> Vec<usize> {
264    let mut candidates = Vec::new();
265    let bytes = sql.as_bytes();
266
267    // SQL keywords that often mark clause boundaries where truncation might work
268    let keywords = [
269        "WHERE",
270        "GROUP",
271        "HAVING",
272        "ORDER",
273        "LIMIT",
274        "OFFSET",
275        "UNION",
276        "EXCEPT",
277        "INTERSECT",
278    ];
279
280    // Find positions right before keywords (truncating before the keyword)
281    // Use ASCII case-insensitive matching to avoid index mismatch with non-ASCII
282    for kw in &keywords {
283        for abs_pos in find_keyword_positions(sql, kw) {
284            // Make sure it's a word boundary (preceded by whitespace)
285            if abs_pos > 0 && bytes[abs_pos - 1].is_ascii_whitespace() {
286                candidates.push(abs_pos);
287            }
288        }
289    }
290
291    // Also try truncating at word boundaries going backwards
292    // Only consider positions that are valid UTF-8 character boundaries
293    let mut pos = sql.len();
294    while pos > 0 {
295        let byte = bytes[pos - 1];
296
297        // Only process ASCII bytes to avoid UTF-8 boundary issues
298        // Non-ASCII bytes (high bit set) are part of multi-byte sequences
299        if byte.is_ascii() {
300            let ch = byte as char;
301
302            // After alphanumeric/identifier chars could be a valid truncation point
303            if ch.is_ascii_alphanumeric() || ch == '_' || ch == ')' || ch == '"' || ch == '\'' {
304                candidates.push(pos);
305            }
306        }
307
308        pos -= 1;
309    }
310
311    // Sort by position descending (try longer truncations first)
312    candidates.sort_by(|a, b| b.cmp(a));
313    candidates.dedup();
314    candidates
315}
316
317/// Fix: Remove trailing comma
318fn fix_trailing_comma(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
319    // Look for patterns like "SELECT a, FROM" or "SELECT a, b, FROM"
320    let trimmed = sql.trim_end();
321
322    // Simple case: trailing comma before FROM
323    // Use ASCII case-insensitive search to find " FROM"
324    if let Some(from_pos) = rfind_keyword(trimmed, " FROM") {
325        let before_from = trimmed[..from_pos].trim_end();
326        if let Some(without_comma) = before_from.strip_suffix(',') {
327            let fixed = format!("{} {}", without_comma, &trimmed[from_pos..]);
328            return Some((fixed, vec![]));
329        }
330    }
331
332    None
333}
334
/// Fix: Close unclosed parentheses.
///
/// Counts `(` and `)` outside of `'...'` and `"..."` quoted sections, so
/// parentheses inside string literals or quoted identifiers are ignored. When
/// more are opened than closed, the missing `)` characters are appended to the
/// end of `sql`.
///
/// Returns the patched SQL and one synthetic range covering the appended
/// parens, or `None` when nothing is missing. Doubled quotes (`'it''s'`) are
/// handled naturally by the quote toggling; backslash escapes are not
/// recognized.
fn fix_unclosed_parens(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    let mut open_quote: Option<char> = None;
    let (mut open, mut close) = (0usize, 0usize);

    for c in sql.chars() {
        match (open_quote, c) {
            (Some(q), _) if c == q => open_quote = None,
            (Some(_), _) => {}
            (None, '\'') | (None, '"') => open_quote = Some(c),
            (None, '(') => open += 1,
            (None, ')') => close += 1,
            (None, _) => {}
        }
    }

    let missing = open.checked_sub(close).filter(|&m| m > 0)?;
    let synthetic_start = sql.len();
    let fixed = format!("{}{}", sql, ")".repeat(missing));
    Some((fixed, vec![synthetic_start..synthetic_start + missing]))
}
350
/// Fix: Add a placeholder projection after every `SELECT` that is directly
/// followed by `FROM`.
///
/// `SELECT FROM t` becomes `SELECT 1 FROM t`: the text `" 1"` is inserted
/// right after each matching `SELECT`. Only ASCII whitespace may separate the
/// two keywords, and both must be standalone words, so
/// `SELECT fromage FROM t` is left alone. Every occurrence is patched, not
/// just the first, so a complete statement earlier in the text does not hide
/// an incomplete one later on.
///
/// Returns the patched SQL plus one synthetic range per insertion. Ranges are
/// byte offsets into the *patched* string. Returns `None` if nothing matched.
fn fix_incomplete_select(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    const PLACEHOLDER: &str = " 1";
    const SELECT: &str = "SELECT";
    const FROM: &str = "FROM";

    /// Bytes that may continue an identifier (non-ASCII counts as identifier).
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
    }

    let bytes = sql.as_bytes();

    // True when `kw` (ASCII, case-insensitive) starts at `pos` as a standalone word.
    let keyword_at = |pos: usize, kw: &str| {
        let end = pos + kw.len();
        end <= bytes.len()
            && bytes[pos..end].eq_ignore_ascii_case(kw.as_bytes())
            && (pos == 0 || !is_word_byte(bytes[pos - 1]))
            && (end == bytes.len() || !is_word_byte(bytes[end]))
    };

    // Byte offsets (in `sql`) just past each SELECT whose next word is FROM.
    let insert_points: Vec<usize> = (0..bytes.len())
        .filter(|&pos| keyword_at(pos, SELECT))
        .map(|pos| pos + SELECT.len())
        .filter(|&after_select| {
            let gap = bytes[after_select..]
                .iter()
                .take_while(|b| b.is_ascii_whitespace())
                .count();
            keyword_at(after_select + gap, FROM)
        })
        .collect();

    if insert_points.is_empty() {
        return None;
    }

    let mut fixed = String::with_capacity(sql.len() + PLACEHOLDER.len() * insert_points.len());
    let mut synthetic = Vec::with_capacity(insert_points.len());
    let mut copied_up_to = 0;
    for point in insert_points {
        // `point` follows an ASCII 'T', so it is a char boundary.
        fixed.push_str(&sql[copied_up_to..point]);
        synthetic.push(fixed.len()..fixed.len() + PLACEHOLDER.len());
        fixed.push_str(PLACEHOLDER);
        copied_up_to = point;
    }
    fixed.push_str(&sql[copied_up_to..]);

    Some((fixed, synthetic))
}
380
381/// Fix: Add dummy table after incomplete FROM
382fn fix_incomplete_from(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
383    let trimmed = sql.trim_end();
384
385    // Check if SQL ends with FROM (possibly with whitespace)
386    // Use ASCII case-insensitive matching
387    if ends_with_keyword(trimmed, "FROM") {
388        let suffix = " _dummy_";
389        let synthetic_start = sql.len();
390        let fixed = format!("{}{}", sql, suffix);
391        return Some((fixed, vec![synthetic_start..synthetic_start + suffix.len()]));
392    }
393
394    None
395}
396
/// Fix: Close an unclosed string literal or quoted identifier.
///
/// Scans `sql` while tracking which quote character (`'` or `"`) is currently
/// open; the other quote character is treated as plain text inside it.
/// Doubled quotes (`'it''s'`) close and immediately reopen, so they are
/// handled. If a quote is still open at the end, the matching quote character
/// is appended. For `"a'b`, that is `"`, not `'`. Backslash escapes are not
/// recognized.
///
/// Returns the patched SQL and a one-byte synthetic range for the appended
/// quote, or `None` when every quote is closed.
fn fix_unclosed_string(sql: &str, _cursor_offset: usize) -> Option<(String, Vec<Range<usize>>)> {
    let mut open_quote: Option<char> = None;
    for c in sql.chars() {
        match open_quote {
            Some(q) if c == q => open_quote = None,
            None if c == '\'' || c == '"' => open_quote = Some(c),
            _ => {}
        }
    }

    let quote = open_quote?;
    let synthetic_start = sql.len();
    let mut fixed = String::with_capacity(sql.len() + 1);
    fixed.push_str(sql);
    fixed.push(quote);
    Some((fixed, vec![synthetic_start..synthetic_start + 1]))
}
417
#[cfg(test)]
mod tests {
    use super::*;

    // ---- individual strategies ----

    #[test]
    fn test_full_parse_valid_sql() {
        let parsed = try_full_parse("SELECT * FROM users WHERE id = 1", Dialect::Generic);
        let statements = parsed.expect("complete SQL should parse");
        assert_eq!(statements.len(), 1);
    }

    #[test]
    fn test_full_parse_invalid_sql() {
        assert!(try_full_parse("SELECT * FROM", Dialect::Generic).is_none());
    }

    #[test]
    fn test_truncated_parse() {
        let sql = "SELECT * FROM users WHERE ";
        assert!(try_truncated_parse(sql, sql.len(), Dialect::Generic).is_some());
    }

    #[test]
    fn test_complete_statements_only() {
        let sql = "SELECT 1; SELECT * FROM";
        let statements = try_complete_statements(sql, sql.len(), Dialect::Generic)
            .expect("the statement before the semicolon should parse");
        assert_eq!(statements.len(), 1);
    }

    // ---- individual fixes ----

    #[test]
    fn test_fix_trailing_comma() {
        let sql = "SELECT a, FROM users";
        assert!(try_with_fixes(sql, sql.len(), Dialect::Generic).is_some());
    }

    #[test]
    fn test_fix_unclosed_parens() {
        let sql = "SELECT COUNT(* FROM users";
        let (fixed, synthetic) =
            fix_unclosed_parens(sql, sql.len()).expect("a closing paren should be added");
        assert!(fixed.ends_with(')'));
        assert_eq!(synthetic.len(), 1);
    }

    #[test]
    fn test_fix_incomplete_select() {
        let sql = "SELECT FROM users";
        let (fixed, synthetic) =
            fix_incomplete_select(sql, sql.len()).expect("a placeholder should be inserted");
        assert!(fixed.contains("1"));
        assert_eq!(synthetic.len(), 1);
    }

    #[test]
    fn test_fix_incomplete_from() {
        let sql = "SELECT * FROM";
        let (fixed, _) =
            fix_incomplete_from(sql, sql.len()).expect("a placeholder table should be added");
        assert!(fixed.contains("_dummy_"));
    }

    #[test]
    fn test_fix_unclosed_string() {
        let sql = "SELECT 'hello";
        let (fixed, _) =
            fix_unclosed_string(sql, sql.len()).expect("the string should be closed");
        assert!(fixed.ends_with('\''));
    }

    // ---- end-to-end strategy selection ----

    #[test]
    fn test_try_parse_for_completion_valid() {
        let sql = "SELECT * FROM users";
        let outcome = try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("valid SQL should parse");
        assert_eq!(outcome.strategy, ParseStrategy::FullParse);
    }

    #[test]
    fn test_try_parse_for_completion_truncated() {
        let sql = "SELECT * FROM users WHERE id = ";
        let outcome = try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("a truncated prefix should parse");
        // Should fall back to truncated
        assert!(matches!(
            outcome.strategy,
            ParseStrategy::Truncated | ParseStrategy::FullParse
        ));
    }

    #[test]
    fn test_try_parse_for_completion_with_fixes() {
        // "SELECT FROM users" actually parses in sqlparser 0.59 (empty projection is valid)
        // Use a truly invalid SQL that requires fixes
        let sql = "SELECT * FROM";
        let outcome = try_parse_for_completion(sql, sql.len(), Dialect::Generic)
            .expect("a fixed-up version should parse");
        assert_eq!(outcome.strategy, ParseStrategy::WithFixes);
    }
}