Skip to main content

qail_core/parser/grammar/
base.rs

1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4    IResult, Parser,
5    branch::alt,
6    bytes::complete::{tag, tag_no_case, take_while1},
7    character::complete::{char, digit1, multispace0, multispace1},
8    combinator::{map, opt, recognize, value},
9    sequence::{delimited, preceded},
10};
11
12/// Parse checking identifier (table name, column name, or qualified name like table.column)
13pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
14    take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)
15}
16
17/// Parse interval shorthand: 24h, 7d, 1w, 30m, 6mo, 1y
18pub fn parse_interval(input: &str) -> IResult<&str, Value> {
19    let (input, num_str) = digit1(input)?;
20    let amount: i64 = num_str.parse().unwrap_or(0);
21
22    let (input, unit) = alt((
23        value(IntervalUnit::Second, tag_no_case("s")),
24        value(IntervalUnit::Minute, tag_no_case("m")),
25        value(IntervalUnit::Hour, tag_no_case("h")),
26        value(IntervalUnit::Day, tag_no_case("d")),
27        value(IntervalUnit::Week, tag_no_case("w")),
28        value(IntervalUnit::Month, tag_no_case("mo")),
29        value(IntervalUnit::Year, tag_no_case("y")),
30    ))
31    .parse(input)?;
32
33    Ok((input, Value::Interval { amount, unit }))
34}
35
36/// Parse value: string, number, bool, null, $param, :named_param, interval, JSON
37pub fn parse_value(input: &str) -> IResult<&str, Value> {
38    alt((
39        // Parameter: $1, $2
40        map(preceded(char('$'), digit1), |d: &str| {
41            Value::Param(d.parse().unwrap_or(0))
42        }),
43        // Named parameter: :name, :id, :user_id
44        map(
45            preceded(
46                char(':'),
47                take_while1(|c: char| c.is_alphanumeric() || c == '_'),
48            ),
49            |name: &str| Value::NamedParam(name.to_string()),
50        ),
51        // Boolean
52        value(Value::Bool(true), tag_no_case("true")),
53        value(Value::Bool(false), tag_no_case("false")),
54        // Null
55        value(Value::Null, tag_no_case("null")),
56        // Triple-quoted multi-line string (must come before single/double quotes)
57        parse_triple_quoted_string,
58        // JSON object literal: { ... } or array: [ ... ]
59        parse_json_literal,
60        // String (double quoted) - allow empty strings
61        map(
62            delimited(
63                char('"'),
64                nom::bytes::complete::take_while(|c| c != '"'),
65                char('"'),
66            ),
67            |s: &str| Value::String(s.to_string()),
68        ),
69        // String (single quoted) - allow empty strings
70        map(
71            delimited(
72                char('\''),
73                nom::bytes::complete::take_while(|c| c != '\''),
74                char('\''),
75            ),
76            |s: &str| Value::String(s.to_string()),
77        ),
78        // Float (must check before int)
79        map(
80            recognize((opt(char('-')), digit1, char('.'), digit1)),
81            |s: &str| Value::Float(s.parse().unwrap_or(0.0)),
82        ),
83        // Interval shorthand before plain integers: 24h, 7d, 1w
84        parse_interval,
85        // Integer (last, after interval)
86        map(recognize((opt(char('-')), digit1)), |s: &str| {
87            Value::Int(s.parse().unwrap_or(0))
88        }),
89    ))
90    .parse(input)
91}
92
93/// Parse triple-quoted multi-line string: '''content''' or """content"""
94fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
95    alt((
96        // Triple single quotes
97        map(
98            delimited(
99                tag("'''"),
100                nom::bytes::complete::take_until("'''"),
101                tag("'''"),
102            ),
103            |s: &str| Value::String(s.to_string()),
104        ),
105        // Triple double quotes
106        map(
107            delimited(
108                tag("\"\"\""),
109                nom::bytes::complete::take_until("\"\"\""),
110                tag("\"\"\""),
111            ),
112            |s: &str| Value::String(s.to_string()),
113        ),
114    ))
115    .parse(input)
116}
117
118/// Parse JSON object literal: { key: value, ... } or array: [...]
119/// This captures the entire JSON structure as a string for Value::Json
120fn parse_json_literal(input: &str) -> IResult<&str, Value> {
121    // Determine if it's an object or array
122    let trimmed = input.trim_start();
123    if trimmed.is_empty() {
124        return Err(nom::Err::Error(nom::error::Error::new(
125            input,
126            nom::error::ErrorKind::Tag,
127        )));
128    }
129
130    let (open_char, close_char) = match trimmed.chars().next() {
131        Some('{') => ('{', '}'),
132        Some('[') => ('[', ']'),
133        _ => {
134            return Err(nom::Err::Error(nom::error::Error::new(
135                input,
136                nom::error::ErrorKind::Tag,
137            )));
138        }
139    };
140
141    // Count brackets to find matching close
142    let mut depth = 0;
143    let mut in_string = false;
144    let mut escape_next = false;
145    let mut end_pos = 0;
146
147    for (i, c) in trimmed.char_indices() {
148        if escape_next {
149            escape_next = false;
150            continue;
151        }
152
153        if c == '\\' && in_string {
154            escape_next = true;
155            continue;
156        }
157
158        if c == '"' {
159            in_string = !in_string;
160            continue;
161        }
162
163        if !in_string {
164            if c == open_char {
165                depth += 1;
166            } else if c == close_char {
167                depth -= 1;
168                if depth == 0 {
169                    end_pos = i + 1;
170                    break;
171                }
172            }
173        }
174    }
175
176    if depth != 0 || end_pos == 0 {
177        return Err(nom::Err::Error(nom::error::Error::new(
178            input,
179            nom::error::ErrorKind::Eof,
180        )));
181    }
182
183    let json_str = &trimmed[..end_pos];
184    let _remaining = &trimmed[end_pos..];
185
186    // Calculate how much of original input we consumed (account for leading whitespace)
187    let consumed = input.len() - trimmed.len() + end_pos;
188    let remaining_original = &input[consumed..];
189
190    Ok((remaining_original, Value::Json(json_str.to_string())))
191}
192
193/// Parse comparison operator
194pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
195    alt((
196        // Multi-char keyword operators first
197        alt((
198            value(Operator::NotBetween, tag_no_case("not between")),
199            value(Operator::Between, tag_no_case("between")),
200            value(Operator::IsNotNull, tag_no_case("is not null")),
201            value(Operator::IsNull, tag_no_case("is null")),
202            value(Operator::NotIn, tag_no_case("not in")),
203            value(Operator::NotILike, tag_no_case("not ilike")),
204            value(Operator::NotLike, tag_no_case("not like")),
205            value(Operator::SimilarTo, tag_no_case("similar to")),
206            value(Operator::JsonExists, tag_no_case("json_exists")),
207            value(Operator::JsonQuery, tag_no_case("json_query")),
208            value(Operator::JsonValue, tag_no_case("json_value")),
209            value(Operator::Regex, tag_no_case("regex")),
210            value(Operator::ILike, tag_no_case("ilike")),
211            value(Operator::Like, tag_no_case("like")),
212            value(Operator::In, tag_no_case("in")),
213        )),
214        // Multi-char symbol operators (before shorter prefixes)
215        alt((
216            value(Operator::RegexI, tag("~*")),
217            value(Operator::JsonPathText, tag("#>>")),
218            value(Operator::JsonPath, tag("#>")),
219            value(Operator::TextSearch, tag("@@")),
220            value(Operator::KeyExistsAny, tag("?|")),
221            value(Operator::KeyExistsAll, tag("?&")),
222            value(Operator::Contains, tag("@>")),
223            value(Operator::ContainedBy, tag("<@")),
224            value(Operator::Overlaps, tag("&&")),
225            value(Operator::Gte, tag(">=")),
226            value(Operator::Lte, tag("<=")),
227            value(Operator::Ne, tag("!=")),
228            value(Operator::Ne, tag("<>")),
229        )),
230        // Single char operators
231        alt((
232            value(Operator::Eq, tag("=")),
233            value(Operator::Gt, tag(">")),
234            value(Operator::Lt, tag("<")),
235            value(Operator::KeyExists, tag("?")),
236            value(Operator::Fuzzy, tag("~")),
237        )),
238    ))
239    .parse(input)
240}
241
242/// Parse action keyword: get, export, set, del, add, make, cnt
243pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
244    alt((
245        // get distinct
246        map(
247            (tag_no_case("get"), multispace1, tag_no_case("distinct")),
248            |_| (Action::Get, true),
249        ),
250        // get
251        value((Action::Get, false), tag_no_case("get")),
252        // export
253        value((Action::Export, false), tag_no_case("export")),
254        // cnt / count (must come before general keywords)
255        alt((
256            value((Action::Cnt, false), tag_no_case("count")),
257            value((Action::Cnt, false), tag_no_case("cnt")),
258        )),
259        // set
260        value((Action::Set, false), tag_no_case("set")),
261        // del / delete
262        alt((
263            value((Action::Del, false), tag_no_case("delete")),
264            value((Action::Del, false), tag_no_case("del")),
265        )),
266        // add / insert
267        alt((
268            value((Action::Add, false), tag_no_case("insert")),
269            value((Action::Add, false), tag_no_case("add")),
270        )),
271        // make / create
272        alt((
273            value((Action::Make, false), tag_no_case("create")),
274            value((Action::Make, false), tag_no_case("make")),
275        )),
276    ))
277    .parse(input)
278}
279
280/// Parse transaction commands: begin, commit, rollback
281pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
282    let (input, action) = alt((
283        value(Action::TxnStart, tag_no_case("begin")),
284        value(Action::TxnCommit, tag_no_case("commit")),
285        value(Action::TxnRollback, tag_no_case("rollback")),
286    ))
287    .parse(input)?;
288
289    Ok((
290        input,
291        Qail {
292            action,
293            table: String::new(),
294            columns: vec![],
295            joins: vec![],
296            cages: vec![],
297            distinct: false,
298            distinct_on: vec![],
299            index_def: None,
300            table_constraints: vec![],
301            set_ops: vec![],
302            having: vec![],
303            group_by_mode: GroupByMode::default(),
304            ctes: vec![],
305            returning: None,
306            on_conflict: None,
307            source_query: None,
308            channel: None,
309            payload: None,
310            savepoint_name: None,
311            from_tables: vec![],
312            using_tables: vec![],
313            lock_mode: None,
314            skip_locked: false,
315            fetch: None,
316            default_values: false,
317            overriding: None,
318            sample: None,
319            only_table: false,
320            vector: None,
321            score_threshold: None,
322            vector_name: None,
323            with_vector: false,
324            vector_size: None,
325            distance: None,
326            on_disk: None,
327            function_def: None,
328            trigger_def: None,
329        },
330    ))
331}
332
333/// Parse procedural/session commands that don't match the regular `action table ...` flow.
334///
335/// Supported forms:
336/// - `call procedure_name(args...)`
337/// - `do $$ ... $$ [language <lang>]`
338/// - `session set <key> = <value>`
339/// - `session show <key>`
340/// - `session reset <key>`
341pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
342    alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
343}
344
345fn parse_call_command(input: &str) -> IResult<&str, Qail> {
346    let (input, _) = tag_no_case("call").parse(input)?;
347    let (input, _) = multispace1(input)?;
348
349    let procedure = input.trim().trim_end_matches(';').trim();
350    if procedure.is_empty() {
351        return Err(nom::Err::Error(nom::error::Error::new(
352            input,
353            nom::error::ErrorKind::Eof,
354        )));
355    }
356
357    Ok((
358        "",
359        Qail {
360            action: Action::Call,
361            table: procedure.to_string(),
362            ..Default::default()
363        },
364    ))
365}
366
367fn parse_do_command(input: &str) -> IResult<&str, Qail> {
368    let (input, _) = tag_no_case("do").parse(input)?;
369    let (input, _) = multispace1(input)?;
370
371    let rest = input.trim().trim_end_matches(';').trim();
372    if rest.is_empty() {
373        return Err(nom::Err::Error(nom::error::Error::new(
374            input,
375            nom::error::ErrorKind::Eof,
376        )));
377    }
378
379    // Preferred syntax: do $$...$$ [language <lang>]
380    let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
381        if let Some(close_idx) = after_open.find("$$") {
382            let body = after_open[..close_idx].to_string();
383            let trailing = after_open[close_idx + 2..].trim();
384            let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
385                trailing[9..].trim().to_string()
386            } else {
387                "plpgsql".to_string()
388            };
389            (body, lang)
390        } else {
391            (rest.to_string(), "plpgsql".to_string())
392        }
393    } else {
394        (rest.to_string(), "plpgsql".to_string())
395    };
396
397    Ok((
398        "",
399        Qail {
400            action: Action::Do,
401            table: language,
402            payload: Some(body),
403            ..Default::default()
404        },
405    ))
406}
407
408fn parse_session_command(input: &str) -> IResult<&str, Qail> {
409    let (input, _) = tag_no_case("session").parse(input)?;
410    let (input, _) = multispace1(input)?;
411
412    // session set <key> = <value>
413    if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
414        let (input, _) = multispace1(input)?;
415        let (input, key) = parse_identifier(input)?;
416        let (input, _) = multispace0(input)?;
417        let (input, _) = opt(char('=')).parse(input)?;
418        let value = input.trim().trim_end_matches(';').trim();
419        if value.is_empty() {
420            return Err(nom::Err::Error(nom::error::Error::new(
421                input,
422                nom::error::ErrorKind::Eof,
423            )));
424        }
425        let value = strip_matching_quotes(value);
426        return Ok((
427            "",
428            Qail {
429                action: Action::SessionSet,
430                table: key.to_string(),
431                payload: Some(value.to_string()),
432                ..Default::default()
433            },
434        ));
435    }
436
437    // session show <key>
438    if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
439        let (input, _) = multispace1(input)?;
440        let (input, key) = parse_identifier(input)?;
441        let trailing = input.trim().trim_end_matches(';').trim();
442        if !trailing.is_empty() {
443            return Err(nom::Err::Error(nom::error::Error::new(
444                input,
445                nom::error::ErrorKind::Tag,
446            )));
447        }
448        return Ok((
449            "",
450            Qail {
451                action: Action::SessionShow,
452                table: key.to_string(),
453                ..Default::default()
454            },
455        ));
456    }
457
458    // session reset <key>
459    let (input, _) = tag_no_case("reset").parse(input)?;
460    let (input, _) = multispace1(input)?;
461    let (input, key) = parse_identifier(input)?;
462    let trailing = input.trim().trim_end_matches(';').trim();
463    if !trailing.is_empty() {
464        return Err(nom::Err::Error(nom::error::Error::new(
465            input,
466            nom::error::ErrorKind::Tag,
467        )));
468    }
469    Ok((
470        "",
471        Qail {
472            action: Action::SessionReset,
473            table: key.to_string(),
474            ..Default::default()
475        },
476    ))
477}
478
/// Remove one pair of surrounding quotes from `s` if it both starts and ends with the same
/// quote character (`'` or `"`).
///
/// Otherwise `s` is returned unchanged. A lone quote character is left alone, because it
/// cannot be both the opening and the closing quote.
fn strip_matching_quotes(s: &str) -> &str {
    ['\'', '"']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}