Skip to main content

qail_core/parser/grammar/
base.rs

1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4    IResult, Parser,
5    branch::alt,
6    bytes::complete::{tag, tag_no_case, take_while1},
7    character::complete::{char, digit1, multispace0, multispace1},
8    combinator::{map, map_res, opt, recognize, value},
9    sequence::{delimited, preceded},
10};
11
12/// Parse checking identifier (table name, column name, or qualified name like table.column)
13pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
14    take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)
15}
16
17/// Parse interval shorthand: 24h, 7d, 1w, 30m, 6mo, 1y
18pub fn parse_interval(input: &str) -> IResult<&str, Value> {
19    let (input, amount) = map_res(digit1, str::parse::<i64>).parse(input)?;
20
21    let (input, unit) = alt((
22        value(IntervalUnit::Second, tag_no_case("s")),
23        value(IntervalUnit::Minute, tag_no_case("m")),
24        value(IntervalUnit::Hour, tag_no_case("h")),
25        value(IntervalUnit::Day, tag_no_case("d")),
26        value(IntervalUnit::Week, tag_no_case("w")),
27        value(IntervalUnit::Month, tag_no_case("mo")),
28        value(IntervalUnit::Year, tag_no_case("y")),
29    ))
30    .parse(input)?;
31
32    Ok((input, Value::Interval { amount, unit }))
33}
34
35/// Parse value: string, number, bool, null, $param, :named_param, interval, JSON
36pub fn parse_value(input: &str) -> IResult<&str, Value> {
37    alt((
38        // Parameter: $1, $2
39        map_res(preceded(char('$'), digit1), |d: &str| {
40            d.parse::<usize>().map(Value::Param)
41        }),
42        // Named parameter: :name, :id, :user_id
43        map(
44            preceded(
45                char(':'),
46                take_while1(|c: char| c.is_alphanumeric() || c == '_'),
47            ),
48            |name: &str| Value::NamedParam(name.to_string()),
49        ),
50        // Boolean
51        value(Value::Bool(true), tag_no_case("true")),
52        value(Value::Bool(false), tag_no_case("false")),
53        // Null
54        value(Value::Null, tag_no_case("null")),
55        // Triple-quoted multi-line string (must come before single/double quotes)
56        parse_triple_quoted_string,
57        // JSON object literal: { ... } or array: [ ... ]
58        parse_json_literal,
59        // String (double quoted) - allow empty strings
60        parse_double_quoted_string,
61        // String (single quoted) - allow empty strings
62        parse_single_quoted_string,
63        // Float (must check before int)
64        map_res(
65            recognize((opt(char('-')), digit1, char('.'), digit1)),
66            |s: &str| {
67                let value = s.parse::<f64>().map_err(|err| err.to_string())?;
68                value
69                    .is_finite()
70                    .then_some(Value::Float(value))
71                    .ok_or_else(|| "float literal must be finite".to_string())
72            },
73        ),
74        // Interval shorthand before plain integers: 24h, 7d, 1w
75        parse_interval,
76        // Integer (last, after interval)
77        map_res(recognize((opt(char('-')), digit1)), |s: &str| {
78            s.parse::<i64>().map(Value::Int)
79        }),
80    ))
81    .parse(input)
82}
83
84fn parse_single_quoted_string(input: &str) -> IResult<&str, Value> {
85    parse_quoted_string(input, '\'')
86}
87
88fn parse_double_quoted_string(input: &str) -> IResult<&str, Value> {
89    parse_quoted_string(input, '"')
90}
91
92fn parse_quoted_string(input: &str, quote: char) -> IResult<&str, Value> {
93    if !input.starts_with(quote) {
94        return Err(nom::Err::Error(nom::error::Error::new(
95            input,
96            nom::error::ErrorKind::Char,
97        )));
98    }
99
100    let mut value = String::new();
101    let mut index = quote.len_utf8();
102
103    while index < input.len() {
104        let Some(ch) = input.get(index..).and_then(|s| s.chars().next()) else {
105            return Err(nom::Err::Error(nom::error::Error::new(
106                input,
107                nom::error::ErrorKind::Char,
108            )));
109        };
110        index += ch.len_utf8();
111
112        if ch == quote {
113            if input[index..].starts_with(quote) {
114                value.push(quote);
115                index += quote.len_utf8();
116            } else {
117                return Ok((&input[index..], Value::String(value)));
118            }
119        } else {
120            value.push(ch);
121        }
122    }
123
124    Err(nom::Err::Error(nom::error::Error::new(
125        input,
126        nom::error::ErrorKind::Eof,
127    )))
128}
129
130/// Parse triple-quoted multi-line string: '''content''' or """content"""
131fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
132    alt((
133        // Triple single quotes
134        map(
135            delimited(
136                tag("'''"),
137                nom::bytes::complete::take_until("'''"),
138                tag("'''"),
139            ),
140            |s: &str| Value::String(s.to_string()),
141        ),
142        // Triple double quotes
143        map(
144            delimited(
145                tag("\"\"\""),
146                nom::bytes::complete::take_until("\"\"\""),
147                tag("\"\"\""),
148            ),
149            |s: &str| Value::String(s.to_string()),
150        ),
151    ))
152    .parse(input)
153}
154
155/// Parse JSON object literal: { key: value, ... } or array: [...]
156/// This captures the entire JSON structure as a string for Value::Json
157fn parse_json_literal(input: &str) -> IResult<&str, Value> {
158    // Determine if it's an object or array
159    let trimmed = input.trim_start();
160    if trimmed.is_empty() {
161        return Err(nom::Err::Error(nom::error::Error::new(
162            input,
163            nom::error::ErrorKind::Tag,
164        )));
165    }
166
167    let (open_char, close_char) = match trimmed.chars().next() {
168        Some('{') => ('{', '}'),
169        Some('[') => ('[', ']'),
170        _ => {
171            return Err(nom::Err::Error(nom::error::Error::new(
172                input,
173                nom::error::ErrorKind::Tag,
174            )));
175        }
176    };
177
178    // Count brackets to find matching close
179    let mut depth = 0;
180    let mut in_string = false;
181    let mut escape_next = false;
182    let mut end_pos = 0;
183
184    for (i, c) in trimmed.char_indices() {
185        if escape_next {
186            escape_next = false;
187            continue;
188        }
189
190        if c == '\\' && in_string {
191            escape_next = true;
192            continue;
193        }
194
195        if c == '"' {
196            in_string = !in_string;
197            continue;
198        }
199
200        if !in_string {
201            if c == open_char {
202                depth += 1;
203            } else if c == close_char {
204                depth -= 1;
205                if depth == 0 {
206                    end_pos = i + 1;
207                    break;
208                }
209            }
210        }
211    }
212
213    if depth != 0 || end_pos == 0 {
214        return Err(nom::Err::Error(nom::error::Error::new(
215            input,
216            nom::error::ErrorKind::Eof,
217        )));
218    }
219
220    let json_str = &trimmed[..end_pos];
221    let _remaining = &trimmed[end_pos..];
222
223    // Calculate how much of original input we consumed (account for leading whitespace)
224    let consumed = input.len() - trimmed.len() + end_pos;
225    let remaining_original = &input[consumed..];
226
227    Ok((remaining_original, Value::Json(json_str.to_string())))
228}
229
230/// Parse comparison operator
231pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
232    alt((
233        // Multi-char keyword operators first
234        alt((
235            value(Operator::NotBetween, tag_no_case("not between")),
236            value(Operator::Between, tag_no_case("between")),
237            value(Operator::IsNotNull, tag_no_case("is not null")),
238            value(Operator::IsNull, tag_no_case("is null")),
239            value(Operator::NotIn, tag_no_case("not in")),
240            value(Operator::NotILike, tag_no_case("not ilike")),
241            value(Operator::NotLike, tag_no_case("not like")),
242            value(Operator::SimilarTo, tag_no_case("similar to")),
243            value(Operator::JsonExists, tag_no_case("json_exists")),
244            value(Operator::JsonQuery, tag_no_case("json_query")),
245            value(Operator::JsonValue, tag_no_case("json_value")),
246            value(Operator::Regex, tag_no_case("regex")),
247            value(Operator::ILike, tag_no_case("ilike")),
248            value(Operator::Like, tag_no_case("like")),
249            value(Operator::In, tag_no_case("in")),
250        )),
251        // Multi-char symbol operators (before shorter prefixes)
252        alt((
253            value(Operator::RegexI, tag("~*")),
254            value(Operator::JsonPathText, tag("#>>")),
255            value(Operator::JsonPath, tag("#>")),
256            value(Operator::TextSearch, tag("@@")),
257            value(Operator::KeyExistsAny, tag("?|")),
258            value(Operator::KeyExistsAll, tag("?&")),
259            value(Operator::Contains, tag("@>")),
260            value(Operator::ContainedBy, tag("<@")),
261            value(Operator::Overlaps, tag("&&")),
262            value(Operator::Gte, tag(">=")),
263            value(Operator::Lte, tag("<=")),
264            value(Operator::Ne, tag("!=")),
265            value(Operator::Ne, tag("<>")),
266        )),
267        // Single char operators
268        alt((
269            value(Operator::Eq, tag("=")),
270            value(Operator::Gt, tag(">")),
271            value(Operator::Lt, tag("<")),
272            value(Operator::KeyExists, tag("?")),
273            value(Operator::Fuzzy, tag("~")),
274        )),
275    ))
276    .parse(input)
277}
278
279/// Parse action keyword: get, export, set, del, add, make, merge, cnt
280pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
281    alt((
282        // get distinct
283        map(
284            (tag_no_case("get"), multispace1, tag_no_case("distinct")),
285            |_| (Action::Get, true),
286        ),
287        // get
288        value((Action::Get, false), tag_no_case("get")),
289        // export
290        value((Action::Export, false), tag_no_case("export")),
291        // cnt / count (must come before general keywords)
292        alt((
293            value((Action::Cnt, false), tag_no_case("count")),
294            value((Action::Cnt, false), tag_no_case("cnt")),
295        )),
296        // set
297        value((Action::Set, false), tag_no_case("set")),
298        // merge
299        value((Action::Merge, false), tag_no_case("merge")),
300        // del / delete
301        alt((
302            value((Action::Del, false), tag_no_case("delete")),
303            value((Action::Del, false), tag_no_case("del")),
304        )),
305        // add / insert
306        alt((
307            value((Action::Add, false), tag_no_case("insert")),
308            value((Action::Add, false), tag_no_case("add")),
309        )),
310        // make / create
311        alt((
312            value((Action::Make, false), tag_no_case("create")),
313            value((Action::Make, false), tag_no_case("make")),
314        )),
315    ))
316    .parse(input)
317}
318
319/// Parse transaction commands: begin, commit, rollback
320pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
321    let (input, action) = alt((
322        value(Action::TxnStart, tag_no_case("begin")),
323        value(Action::TxnCommit, tag_no_case("commit")),
324        value(Action::TxnRollback, tag_no_case("rollback")),
325    ))
326    .parse(input)?;
327
328    Ok((
329        input,
330        Qail {
331            action,
332            table: String::new(),
333            columns: vec![],
334            joins: vec![],
335            cages: vec![],
336            distinct: false,
337            distinct_on: vec![],
338            index_def: None,
339            table_constraints: vec![],
340            set_ops: vec![],
341            having: vec![],
342            group_by_mode: GroupByMode::default(),
343            ctes: vec![],
344            returning: None,
345            on_conflict: None,
346            merge: None,
347            source_query: None,
348            channel: None,
349            payload: None,
350            savepoint_name: None,
351            from_tables: vec![],
352            using_tables: vec![],
353            lock_mode: None,
354            skip_locked: false,
355            fetch: None,
356            default_values: false,
357            overriding: None,
358            sample: None,
359            only_table: false,
360            vector: None,
361            score_threshold: None,
362            vector_name: None,
363            with_vector: false,
364            vector_size: None,
365            distance: None,
366            on_disk: None,
367            function_def: None,
368            trigger_def: None,
369            policy_def: None,
370        },
371    ))
372}
373
374/// Parse procedural/session commands that don't match the regular `action table ...` flow.
375///
376/// Supported forms:
377/// - `call procedure_name(args...)`
378/// - `do $$ ... $$ [language <lang>]`
379/// - `session set <key> = <value>`
380/// - `session show <key>`
381/// - `session reset <key>`
382pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
383    alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
384}
385
386fn parse_call_command(input: &str) -> IResult<&str, Qail> {
387    let (input, _) = tag_no_case("call").parse(input)?;
388    let (input, _) = multispace1(input)?;
389
390    let procedure = input.trim().trim_end_matches(';').trim();
391    if procedure.is_empty() {
392        return Err(nom::Err::Error(nom::error::Error::new(
393            input,
394            nom::error::ErrorKind::Eof,
395        )));
396    }
397    if !is_safe_call_target(procedure) {
398        return Err(nom::Err::Error(nom::error::Error::new(
399            input,
400            nom::error::ErrorKind::Tag,
401        )));
402    }
403
404    Ok((
405        "",
406        Qail {
407            action: Action::Call,
408            table: procedure.to_string(),
409            ..Default::default()
410        },
411    ))
412}
413
414fn is_safe_call_target(procedure: &str) -> bool {
415    let procedure = procedure.trim();
416    if procedure.is_empty()
417        || procedure.contains('\0')
418        || procedure.contains(';')
419        || procedure.contains("--")
420        || procedure.contains("/*")
421        || procedure.contains("*/")
422    {
423        return false;
424    }
425
426    match procedure.split_once('(') {
427        Some((name, args)) if args.ends_with(')') && !args[..args.len() - 1].contains('(') => {
428            is_valid_qualified_ident(name.trim())
429        }
430        None => is_valid_qualified_ident(procedure),
431        _ => false,
432    }
433}
434
/// Check that `name` is a dot-separated sequence of plain ASCII identifiers.
///
/// Every segment must be non-empty, start with an ASCII letter or `_`, and
/// continue with ASCII letters, digits or `_` only.
fn is_valid_qualified_ident(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    for segment in name.split('.') {
        let mut rest = segment.chars();
        match rest.next() {
            Some(first) if first == '_' || first.is_ascii_alphabetic() => {}
            // Empty segment or an invalid leading character.
            _ => return false,
        }
        if rest.any(|c| c != '_' && !c.is_ascii_alphanumeric()) {
            return false;
        }
    }

    true
}
443
444fn parse_do_command(input: &str) -> IResult<&str, Qail> {
445    let (input, _) = tag_no_case("do").parse(input)?;
446    let (input, _) = multispace1(input)?;
447
448    let rest = input.trim().trim_end_matches(';').trim();
449    if rest.is_empty() {
450        return Err(nom::Err::Error(nom::error::Error::new(
451            input,
452            nom::error::ErrorKind::Eof,
453        )));
454    }
455
456    // Preferred syntax: do $$...$$ [language <lang>]
457    let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
458        if let Some(close_idx) = after_open.find("$$") {
459            let body = after_open[..close_idx].to_string();
460            let trailing = after_open[close_idx + 2..].trim();
461            let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
462                trailing[9..].trim().to_string()
463            } else {
464                "plpgsql".to_string()
465            };
466            (body, lang)
467        } else {
468            (rest.to_string(), "plpgsql".to_string())
469        }
470    } else {
471        (rest.to_string(), "plpgsql".to_string())
472    };
473
474    Ok((
475        "",
476        Qail {
477            action: Action::Do,
478            table: language,
479            payload: Some(body),
480            ..Default::default()
481        },
482    ))
483}
484
485fn parse_session_command(input: &str) -> IResult<&str, Qail> {
486    let (input, _) = tag_no_case("session").parse(input)?;
487    let (input, _) = multispace1(input)?;
488
489    // session set <key> = <value>
490    if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
491        let (input, _) = multispace1(input)?;
492        let (input, key) = parse_session_setting_key(input)?;
493        let (input, _) = multispace0(input)?;
494        let (input, _) = opt(char('=')).parse(input)?;
495        let value = input.trim().trim_end_matches(';').trim();
496        if value.is_empty() {
497            return Err(nom::Err::Error(nom::error::Error::new(
498                input,
499                nom::error::ErrorKind::Eof,
500            )));
501        }
502        let value = strip_matching_quotes(value);
503        return Ok((
504            "",
505            Qail {
506                action: Action::SessionSet,
507                table: key.to_string(),
508                payload: Some(value.to_string()),
509                ..Default::default()
510            },
511        ));
512    }
513
514    // session show <key>
515    if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
516        let (input, _) = multispace1(input)?;
517        let (input, key) = parse_session_setting_key(input)?;
518        let trailing = input.trim().trim_end_matches(';').trim();
519        if !trailing.is_empty() {
520            return Err(nom::Err::Error(nom::error::Error::new(
521                input,
522                nom::error::ErrorKind::Tag,
523            )));
524        }
525        return Ok((
526            "",
527            Qail {
528                action: Action::SessionShow,
529                table: key.to_string(),
530                ..Default::default()
531            },
532        ));
533    }
534
535    // session reset <key>
536    let (input, _) = tag_no_case("reset").parse(input)?;
537    let (input, _) = multispace1(input)?;
538    let (input, key) = parse_session_setting_key(input)?;
539    let trailing = input.trim().trim_end_matches(';').trim();
540    if !trailing.is_empty() {
541        return Err(nom::Err::Error(nom::error::Error::new(
542            input,
543            nom::error::ErrorKind::Tag,
544        )));
545    }
546    Ok((
547        "",
548        Qail {
549            action: Action::SessionReset,
550            table: key.to_string(),
551            ..Default::default()
552        },
553    ))
554}
555
556fn parse_session_setting_key(input: &str) -> IResult<&str, &str> {
557    let end = input
558        .char_indices()
559        .find_map(|(idx, ch)| (ch.is_whitespace() || ch == '=' || ch == ';').then_some(idx))
560        .unwrap_or(input.len());
561    let key = &input[..end];
562    if is_valid_session_setting_key(key) {
563        Ok((&input[end..], key))
564    } else {
565        Err(nom::Err::Error(nom::error::Error::new(
566            input,
567            nom::error::ErrorKind::Tag,
568        )))
569    }
570}
571
/// Check that `key` is a dot-separated sequence of plain ASCII identifiers.
///
/// Every segment must be non-empty, start with an ASCII letter or `_`, and
/// continue with ASCII letters, digits or `_` only.
fn is_valid_session_setting_key(key: &str) -> bool {
    /// Validates one segment. Working on bytes is enough because every
    /// accepted character is ASCII and any byte of a non-ASCII character
    /// fails the checks.
    fn segment_ok(segment: &str) -> bool {
        match segment.as_bytes().split_first() {
            Some((&head, tail)) => {
                (head == b'_' || head.is_ascii_alphabetic())
                    && tail
                        .iter()
                        .all(|&byte| byte == b'_' || byte.is_ascii_alphanumeric())
            }
            None => false,
        }
    }

    !key.is_empty() && key.split('.').all(segment_ok)
}
580
/// Remove one pair of matching quotes around `s`.
///
/// When `s` both starts and ends with `'`, or both starts and ends with `"`,
/// and is at least two bytes long, the text between those two quotes is
/// returned. Otherwise `s` is returned unchanged.
fn strip_matching_quotes(s: &str) -> &str {
    for quote in ['\'', '"'] {
        // `strip_suffix` runs on what is left after the opening quote, so a
        // lone quote character is never treated as a matching pair.
        let inner = s
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = inner {
            return inner;
        }
    }
    s
}