Skip to main content

qail_core/parser/grammar/
base.rs

1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4    IResult, Parser,
5    branch::alt,
6    bytes::complete::{tag, tag_no_case, take_while1},
7    character::complete::{char, digit1, multispace0, multispace1},
8    combinator::{map, map_res, opt, recognize, value},
9    sequence::{delimited, preceded},
10};
11
12/// Parse a bare identifier (column, alias, or parameter name).
13pub fn parse_bare_identifier(input: &str) -> IResult<&str, &str> {
14    let (remaining, ident) =
15        take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_').parse(input)?;
16    if is_valid_ident_part(ident) {
17        Ok((remaining, ident))
18    } else {
19        Err(nom::Err::Error(nom::error::Error::new(
20            input,
21            nom::error::ErrorKind::TakeWhile1,
22        )))
23    }
24}
25
26/// Parse checking identifier (table name, column name, or qualified name like table.column)
27pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
28    let (remaining, ident) =
29        take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.').parse(input)?;
30    if ident.split('.').all(is_valid_ident_part) {
31        Ok((remaining, ident))
32    } else {
33        Err(nom::Err::Error(nom::error::Error::new(
34            input,
35            nom::error::ErrorKind::TakeWhile1,
36        )))
37    }
38}
39
/// Returns `true` when `part` is a single identifier segment: an ASCII letter
/// or underscore followed by any number of ASCII letters, digits or underscores.
/// The empty string is not a valid segment.
fn is_valid_ident_part(part: &str) -> bool {
    // Working on bytes is equivalent here: every byte of a non-ASCII character
    // is >= 0x80 and so fails the ASCII checks just like the character would.
    match part.as_bytes() {
        [head, tail @ ..] => {
            (head.is_ascii_alphabetic() || *head == b'_')
                && tail.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'_')
        }
        [] => false,
    }
}
45
46/// Parse interval shorthand: 24h, 7d, 1w, 30m, 6mo, 1y
47pub fn parse_interval(input: &str) -> IResult<&str, Value> {
48    let (input, amount) = map_res(digit1, str::parse::<i64>).parse(input)?;
49
50    let (input, unit) = alt((
51        value(IntervalUnit::Month, tag_no_case("mo")),
52        value(IntervalUnit::Second, tag_no_case("s")),
53        value(IntervalUnit::Minute, tag_no_case("m")),
54        value(IntervalUnit::Hour, tag_no_case("h")),
55        value(IntervalUnit::Day, tag_no_case("d")),
56        value(IntervalUnit::Week, tag_no_case("w")),
57        value(IntervalUnit::Year, tag_no_case("y")),
58    ))
59    .parse(input)?;
60
61    Ok((input, Value::Interval { amount, unit }))
62}
63
64/// Parse value: string, number, bool, null, $param, :named_param, interval, JSON
65pub fn parse_value(input: &str) -> IResult<&str, Value> {
66    alt((
67        // Parameter: $1, $2
68        map_res(preceded(char('$'), digit1), |d: &str| {
69            d.parse::<usize>().map(Value::Param)
70        }),
71        // Named parameter: :name, :id, :user_id
72        map(preceded(char(':'), parse_bare_identifier), |name: &str| {
73            Value::NamedParam(name.to_string())
74        }),
75        // Boolean
76        value(Value::Bool(true), tag_no_case("true")),
77        value(Value::Bool(false), tag_no_case("false")),
78        // Null
79        value(Value::Null, tag_no_case("null")),
80        // Triple-quoted multi-line string (must come before single/double quotes)
81        parse_triple_quoted_string,
82        // JSON object literal: { ... } or array: [ ... ]
83        parse_json_literal,
84        // String (double quoted) - allow empty strings
85        parse_double_quoted_string,
86        // String (single quoted) - allow empty strings
87        parse_single_quoted_string,
88        // Float (must check before int)
89        map_res(
90            recognize((opt(char('-')), digit1, char('.'), digit1)),
91            |s: &str| {
92                let value = s.parse::<f64>().map_err(|err| err.to_string())?;
93                value
94                    .is_finite()
95                    .then_some(Value::Float(value))
96                    .ok_or_else(|| "float literal must be finite".to_string())
97            },
98        ),
99        // Interval shorthand before plain integers: 24h, 7d, 1w
100        parse_interval,
101        // Integer (last, after interval)
102        map_res(recognize((opt(char('-')), digit1)), |s: &str| {
103            s.parse::<i64>().map(Value::Int)
104        }),
105    ))
106    .parse(input)
107}
108
109fn parse_single_quoted_string(input: &str) -> IResult<&str, Value> {
110    parse_quoted_string(input, '\'')
111}
112
113fn parse_double_quoted_string(input: &str) -> IResult<&str, Value> {
114    parse_quoted_string(input, '"')
115}
116
117fn parse_quoted_string(input: &str, quote: char) -> IResult<&str, Value> {
118    if !input.starts_with(quote) {
119        return Err(nom::Err::Error(nom::error::Error::new(
120            input,
121            nom::error::ErrorKind::Char,
122        )));
123    }
124
125    let mut value = String::new();
126    let mut index = quote.len_utf8();
127
128    while index < input.len() {
129        let Some(ch) = input.get(index..).and_then(|s| s.chars().next()) else {
130            return Err(nom::Err::Error(nom::error::Error::new(
131                input,
132                nom::error::ErrorKind::Char,
133            )));
134        };
135        index += ch.len_utf8();
136
137        if ch == quote {
138            if input[index..].starts_with(quote) {
139                value.push(quote);
140                index += quote.len_utf8();
141            } else {
142                return Ok((&input[index..], Value::String(value)));
143            }
144        } else {
145            value.push(ch);
146        }
147    }
148
149    Err(nom::Err::Error(nom::error::Error::new(
150        input,
151        nom::error::ErrorKind::Eof,
152    )))
153}
154
155/// Parse triple-quoted multi-line string: '''content''' or """content"""
156fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
157    alt((
158        // Triple single quotes
159        map(
160            delimited(
161                tag("'''"),
162                nom::bytes::complete::take_until("'''"),
163                tag("'''"),
164            ),
165            |s: &str| Value::String(s.to_string()),
166        ),
167        // Triple double quotes
168        map(
169            delimited(
170                tag("\"\"\""),
171                nom::bytes::complete::take_until("\"\"\""),
172                tag("\"\"\""),
173            ),
174            |s: &str| Value::String(s.to_string()),
175        ),
176    ))
177    .parse(input)
178}
179
180/// Parse JSON object literal: { key: value, ... } or array: [...]
181/// This captures the entire JSON structure as a string for Value::Json
182fn parse_json_literal(input: &str) -> IResult<&str, Value> {
183    // Determine if it's an object or array
184    let trimmed = input.trim_start();
185    if trimmed.is_empty() {
186        return Err(nom::Err::Error(nom::error::Error::new(
187            input,
188            nom::error::ErrorKind::Tag,
189        )));
190    }
191
192    let (open_char, close_char) = match trimmed.chars().next() {
193        Some('{') => ('{', '}'),
194        Some('[') => ('[', ']'),
195        _ => {
196            return Err(nom::Err::Error(nom::error::Error::new(
197                input,
198                nom::error::ErrorKind::Tag,
199            )));
200        }
201    };
202
203    // Count brackets to find matching close
204    let mut depth = 0;
205    let mut in_string = false;
206    let mut escape_next = false;
207    let mut end_pos = 0;
208
209    for (i, c) in trimmed.char_indices() {
210        if escape_next {
211            escape_next = false;
212            continue;
213        }
214
215        if c == '\\' && in_string {
216            escape_next = true;
217            continue;
218        }
219
220        if c == '"' {
221            in_string = !in_string;
222            continue;
223        }
224
225        if !in_string {
226            if c == open_char {
227                depth += 1;
228            } else if c == close_char {
229                depth -= 1;
230                if depth == 0 {
231                    end_pos = i + 1;
232                    break;
233                }
234            }
235        }
236    }
237
238    if depth != 0 || end_pos == 0 {
239        return Err(nom::Err::Error(nom::error::Error::new(
240            input,
241            nom::error::ErrorKind::Eof,
242        )));
243    }
244
245    let json_str = &trimmed[..end_pos];
246    let _remaining = &trimmed[end_pos..];
247
248    if serde_json::from_str::<serde_json::Value>(json_str).is_err() {
249        return Err(nom::Err::Error(nom::error::Error::new(
250            input,
251            nom::error::ErrorKind::Verify,
252        )));
253    }
254
255    // Calculate how much of original input we consumed (account for leading whitespace)
256    let consumed = input.len() - trimmed.len() + end_pos;
257    let remaining_original = &input[consumed..];
258
259    Ok((remaining_original, Value::Json(json_str.to_string())))
260}
261
262/// Parse comparison operator
263pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
264    alt((
265        // Multi-char keyword operators first
266        alt((
267            value(Operator::NotBetween, tag_no_case("not between")),
268            value(Operator::Between, tag_no_case("between")),
269            value(Operator::IsNotNull, tag_no_case("is not null")),
270            value(Operator::IsNull, tag_no_case("is null")),
271            value(Operator::NotIn, tag_no_case("not in")),
272            value(Operator::NotILike, tag_no_case("not ilike")),
273            value(Operator::NotLike, tag_no_case("not like")),
274            value(Operator::SimilarTo, tag_no_case("similar to")),
275            value(Operator::JsonExists, tag_no_case("json_exists")),
276            value(Operator::JsonQuery, tag_no_case("json_query")),
277            value(Operator::JsonValue, tag_no_case("json_value")),
278            value(Operator::Regex, tag_no_case("regex")),
279            value(Operator::ILike, tag_no_case("ilike")),
280            value(Operator::Like, tag_no_case("like")),
281            value(Operator::In, tag_no_case("in")),
282        )),
283        // Multi-char symbol operators (before shorter prefixes)
284        alt((
285            value(Operator::RegexI, tag("~*")),
286            value(Operator::JsonPathText, tag("#>>")),
287            value(Operator::JsonPath, tag("#>")),
288            value(Operator::TextSearch, tag("@@")),
289            value(Operator::KeyExistsAny, tag("?|")),
290            value(Operator::KeyExistsAll, tag("?&")),
291            value(Operator::Contains, tag("@>")),
292            value(Operator::ContainedBy, tag("<@")),
293            value(Operator::Overlaps, tag("&&")),
294            value(Operator::Gte, tag(">=")),
295            value(Operator::Lte, tag("<=")),
296            value(Operator::Ne, tag("!=")),
297            value(Operator::Ne, tag("<>")),
298        )),
299        // Single char operators
300        alt((
301            value(Operator::Eq, tag("=")),
302            value(Operator::Gt, tag(">")),
303            value(Operator::Lt, tag("<")),
304            value(Operator::KeyExists, tag("?")),
305            value(Operator::Fuzzy, tag("~")),
306        )),
307    ))
308    .parse(input)
309}
310
311/// Parse action keyword: get, export, set, del, add, make, merge, cnt
312pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
313    alt((
314        // get distinct
315        map(
316            (tag_no_case("get"), multispace1, tag_no_case("distinct")),
317            |_| (Action::Get, true),
318        ),
319        // get
320        value((Action::Get, false), tag_no_case("get")),
321        // export
322        value((Action::Export, false), tag_no_case("export")),
323        // cnt / count (must come before general keywords)
324        alt((
325            value((Action::Cnt, false), tag_no_case("count")),
326            value((Action::Cnt, false), tag_no_case("cnt")),
327        )),
328        // set
329        value((Action::Set, false), tag_no_case("set")),
330        // merge
331        value((Action::Merge, false), tag_no_case("merge")),
332        // del / delete
333        alt((
334            value((Action::Del, false), tag_no_case("delete")),
335            value((Action::Del, false), tag_no_case("del")),
336        )),
337        // add / insert
338        alt((
339            value((Action::Add, false), tag_no_case("insert")),
340            value((Action::Add, false), tag_no_case("add")),
341        )),
342        // make / create
343        alt((
344            value((Action::Make, false), tag_no_case("create")),
345            value((Action::Make, false), tag_no_case("make")),
346        )),
347    ))
348    .parse(input)
349}
350
351/// Parse transaction commands: begin, commit, rollback
352pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
353    let (input, action) = alt((
354        value(Action::TxnStart, tag_no_case("begin")),
355        value(Action::TxnCommit, tag_no_case("commit")),
356        value(Action::TxnRollback, tag_no_case("rollback")),
357    ))
358    .parse(input)?;
359
360    Ok((
361        input,
362        Qail {
363            action,
364            table: String::new(),
365            columns: vec![],
366            joins: vec![],
367            cages: vec![],
368            distinct: false,
369            distinct_on: vec![],
370            index_def: None,
371            table_constraints: vec![],
372            set_ops: vec![],
373            having: vec![],
374            group_by_mode: GroupByMode::default(),
375            ctes: vec![],
376            returning: None,
377            on_conflict: None,
378            merge: None,
379            source_query: None,
380            channel: None,
381            payload: None,
382            savepoint_name: None,
383            from_tables: vec![],
384            using_tables: vec![],
385            lock_mode: None,
386            skip_locked: false,
387            fetch: None,
388            default_values: false,
389            overriding: None,
390            sample: None,
391            only_table: false,
392            vector: None,
393            score_threshold: None,
394            vector_name: None,
395            with_vector: false,
396            vector_size: None,
397            distance: None,
398            on_disk: None,
399            function_def: None,
400            trigger_def: None,
401            policy_def: None,
402        },
403    ))
404}
405
406/// Parse procedural/session commands that don't match the regular `action table ...` flow.
407///
408/// Supported forms:
409/// - `call procedure_name(args...)`
410/// - `do $$ ... $$ [language <lang>]`
411/// - `session set <key> = <value>`
412/// - `session show <key>`
413/// - `session reset <key>`
414pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
415    alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
416}
417
418fn parse_call_command(input: &str) -> IResult<&str, Qail> {
419    let (input, _) = tag_no_case("call").parse(input)?;
420    let (input, _) = multispace1(input)?;
421
422    let procedure = input.trim().trim_end_matches(';').trim();
423    if procedure.is_empty() {
424        return Err(nom::Err::Error(nom::error::Error::new(
425            input,
426            nom::error::ErrorKind::Eof,
427        )));
428    }
429    if !is_safe_call_target(procedure) {
430        return Err(nom::Err::Error(nom::error::Error::new(
431            input,
432            nom::error::ErrorKind::Tag,
433        )));
434    }
435
436    Ok((
437        "",
438        Qail {
439            action: Action::Call,
440            table: procedure.to_string(),
441            ..Default::default()
442        },
443    ))
444}
445
446fn is_safe_call_target(procedure: &str) -> bool {
447    let procedure = procedure.trim();
448    if procedure.is_empty()
449        || procedure.contains('\0')
450        || procedure.contains(';')
451        || procedure.contains("--")
452        || procedure.contains("/*")
453        || procedure.contains("*/")
454    {
455        return false;
456    }
457
458    match procedure.split_once('(') {
459        Some((name, args)) if args.ends_with(')') && !args[..args.len() - 1].contains('(') => {
460            is_valid_qualified_ident(name.trim())
461        }
462        None => is_valid_qualified_ident(procedure),
463        _ => false,
464    }
465}
466
/// Returns `true` when `name` is one or more dot-separated identifier
/// segments, each starting with an ASCII letter or underscore and continuing
/// with ASCII letters, digits or underscores.
fn is_valid_qualified_ident(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    for segment in name.split('.') {
        let mut rest = segment.chars();

        let starts_ok = match rest.next() {
            Some(first) => first == '_' || first.is_ascii_alphabetic(),
            None => false, // empty segment, e.g. from `a..b` or a trailing dot
        };
        if !starts_ok {
            return false;
        }

        if rest.any(|c| !(c == '_' || c.is_ascii_alphanumeric())) {
            return false;
        }
    }

    true
}
475
476fn parse_do_command(input: &str) -> IResult<&str, Qail> {
477    let (input, _) = tag_no_case("do").parse(input)?;
478    let (input, _) = multispace1(input)?;
479
480    let rest = input.trim().trim_end_matches(';').trim();
481    if rest.is_empty() {
482        return Err(nom::Err::Error(nom::error::Error::new(
483            input,
484            nom::error::ErrorKind::Eof,
485        )));
486    }
487
488    // Preferred syntax: do $$...$$ [language <lang>]
489    let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
490        if let Some(close_idx) = after_open.find("$$") {
491            let body = after_open[..close_idx].to_string();
492            let trailing = after_open[close_idx + 2..].trim();
493            let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
494                trailing[9..].trim().to_string()
495            } else {
496                "plpgsql".to_string()
497            };
498            (body, lang)
499        } else {
500            (rest.to_string(), "plpgsql".to_string())
501        }
502    } else {
503        (rest.to_string(), "plpgsql".to_string())
504    };
505
506    Ok((
507        "",
508        Qail {
509            action: Action::Do,
510            table: language,
511            payload: Some(body),
512            ..Default::default()
513        },
514    ))
515}
516
517fn parse_session_command(input: &str) -> IResult<&str, Qail> {
518    let (input, _) = tag_no_case("session").parse(input)?;
519    let (input, _) = multispace1(input)?;
520
521    // session set <key> = <value>
522    if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
523        let (input, _) = multispace1(input)?;
524        let (input, key) = parse_session_setting_key(input)?;
525        let (input, _) = multispace0(input)?;
526        let (input, _) = opt(char('=')).parse(input)?;
527        let value = input.trim().trim_end_matches(';').trim();
528        if value.is_empty() {
529            return Err(nom::Err::Error(nom::error::Error::new(
530                input,
531                nom::error::ErrorKind::Eof,
532            )));
533        }
534        let value = strip_matching_quotes(value);
535        return Ok((
536            "",
537            Qail {
538                action: Action::SessionSet,
539                table: key.to_string(),
540                payload: Some(value.to_string()),
541                ..Default::default()
542            },
543        ));
544    }
545
546    // session show <key>
547    if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
548        let (input, _) = multispace1(input)?;
549        let (input, key) = parse_session_setting_key(input)?;
550        let trailing = input.trim().trim_end_matches(';').trim();
551        if !trailing.is_empty() {
552            return Err(nom::Err::Error(nom::error::Error::new(
553                input,
554                nom::error::ErrorKind::Tag,
555            )));
556        }
557        return Ok((
558            "",
559            Qail {
560                action: Action::SessionShow,
561                table: key.to_string(),
562                ..Default::default()
563            },
564        ));
565    }
566
567    // session reset <key>
568    let (input, _) = tag_no_case("reset").parse(input)?;
569    let (input, _) = multispace1(input)?;
570    let (input, key) = parse_session_setting_key(input)?;
571    let trailing = input.trim().trim_end_matches(';').trim();
572    if !trailing.is_empty() {
573        return Err(nom::Err::Error(nom::error::Error::new(
574            input,
575            nom::error::ErrorKind::Tag,
576        )));
577    }
578    Ok((
579        "",
580        Qail {
581            action: Action::SessionReset,
582            table: key.to_string(),
583            ..Default::default()
584        },
585    ))
586}
587
588fn parse_session_setting_key(input: &str) -> IResult<&str, &str> {
589    let end = input
590        .char_indices()
591        .find_map(|(idx, ch)| (ch.is_whitespace() || ch == '=' || ch == ';').then_some(idx))
592        .unwrap_or(input.len());
593    let key = &input[..end];
594    if is_valid_session_setting_key(key) {
595        Ok((&input[end..], key))
596    } else {
597        Err(nom::Err::Error(nom::error::Error::new(
598            input,
599            nom::error::ErrorKind::Tag,
600        )))
601    }
602}
603
/// Returns `true` when `key` is a non-empty, dot-separated list of segments,
/// each beginning with an ASCII letter or underscore and continuing with
/// ASCII letters, digits or underscores (e.g. `statement_timeout`, `app.user_id`).
fn is_valid_session_setting_key(key: &str) -> bool {
    /// Checks one dot-separated segment of the key.
    fn is_segment(segment: &str) -> bool {
        let mut chars = segment.chars();
        match chars.next() {
            Some(first) if first == '_' || first.is_ascii_alphabetic() => {
                chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
            }
            _ => false,
        }
    }

    !key.is_empty() && key.split('.').all(is_segment)
}
612
/// Remove one pair of surrounding quotes from `s` when it both starts and
/// ends with the same quote character (`'` or `"`). Anything else, including
/// a lone quote character or mismatched quotes, is returned unchanged.
fn strip_matching_quotes(s: &str) -> &str {
    ['\'', '"']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}