1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4 IResult, Parser,
5 branch::alt,
6 bytes::complete::{tag, tag_no_case, take_while1},
7 character::complete::{char, digit1, multispace0, multispace1},
8 combinator::{map, map_res, opt, recognize, value},
9 sequence::{delimited, preceded},
10};
11
12pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
14 take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)
15}
16
17pub fn parse_interval(input: &str) -> IResult<&str, Value> {
19 let (input, amount) = map_res(digit1, str::parse::<i64>).parse(input)?;
20
21 let (input, unit) = alt((
22 value(IntervalUnit::Second, tag_no_case("s")),
23 value(IntervalUnit::Minute, tag_no_case("m")),
24 value(IntervalUnit::Hour, tag_no_case("h")),
25 value(IntervalUnit::Day, tag_no_case("d")),
26 value(IntervalUnit::Week, tag_no_case("w")),
27 value(IntervalUnit::Month, tag_no_case("mo")),
28 value(IntervalUnit::Year, tag_no_case("y")),
29 ))
30 .parse(input)?;
31
32 Ok((input, Value::Interval { amount, unit }))
33}
34
35pub fn parse_value(input: &str) -> IResult<&str, Value> {
37 alt((
38 map_res(preceded(char('$'), digit1), |d: &str| {
40 d.parse::<usize>().map(Value::Param)
41 }),
42 map(
44 preceded(
45 char(':'),
46 take_while1(|c: char| c.is_alphanumeric() || c == '_'),
47 ),
48 |name: &str| Value::NamedParam(name.to_string()),
49 ),
50 value(Value::Bool(true), tag_no_case("true")),
52 value(Value::Bool(false), tag_no_case("false")),
53 value(Value::Null, tag_no_case("null")),
55 parse_triple_quoted_string,
57 parse_json_literal,
59 parse_double_quoted_string,
61 parse_single_quoted_string,
63 map_res(
65 recognize((opt(char('-')), digit1, char('.'), digit1)),
66 |s: &str| {
67 let value = s.parse::<f64>().map_err(|err| err.to_string())?;
68 value
69 .is_finite()
70 .then_some(Value::Float(value))
71 .ok_or_else(|| "float literal must be finite".to_string())
72 },
73 ),
74 parse_interval,
76 map_res(recognize((opt(char('-')), digit1)), |s: &str| {
78 s.parse::<i64>().map(Value::Int)
79 }),
80 ))
81 .parse(input)
82}
83
84fn parse_single_quoted_string(input: &str) -> IResult<&str, Value> {
85 parse_quoted_string(input, '\'')
86}
87
88fn parse_double_quoted_string(input: &str) -> IResult<&str, Value> {
89 parse_quoted_string(input, '"')
90}
91
92fn parse_quoted_string(input: &str, quote: char) -> IResult<&str, Value> {
93 if !input.starts_with(quote) {
94 return Err(nom::Err::Error(nom::error::Error::new(
95 input,
96 nom::error::ErrorKind::Char,
97 )));
98 }
99
100 let mut value = String::new();
101 let mut index = quote.len_utf8();
102
103 while index < input.len() {
104 let Some(ch) = input.get(index..).and_then(|s| s.chars().next()) else {
105 return Err(nom::Err::Error(nom::error::Error::new(
106 input,
107 nom::error::ErrorKind::Char,
108 )));
109 };
110 index += ch.len_utf8();
111
112 if ch == quote {
113 if input[index..].starts_with(quote) {
114 value.push(quote);
115 index += quote.len_utf8();
116 } else {
117 return Ok((&input[index..], Value::String(value)));
118 }
119 } else {
120 value.push(ch);
121 }
122 }
123
124 Err(nom::Err::Error(nom::error::Error::new(
125 input,
126 nom::error::ErrorKind::Eof,
127 )))
128}
129
130fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
132 alt((
133 map(
135 delimited(
136 tag("'''"),
137 nom::bytes::complete::take_until("'''"),
138 tag("'''"),
139 ),
140 |s: &str| Value::String(s.to_string()),
141 ),
142 map(
144 delimited(
145 tag("\"\"\""),
146 nom::bytes::complete::take_until("\"\"\""),
147 tag("\"\"\""),
148 ),
149 |s: &str| Value::String(s.to_string()),
150 ),
151 ))
152 .parse(input)
153}
154
155fn parse_json_literal(input: &str) -> IResult<&str, Value> {
158 let trimmed = input.trim_start();
160 if trimmed.is_empty() {
161 return Err(nom::Err::Error(nom::error::Error::new(
162 input,
163 nom::error::ErrorKind::Tag,
164 )));
165 }
166
167 let (open_char, close_char) = match trimmed.chars().next() {
168 Some('{') => ('{', '}'),
169 Some('[') => ('[', ']'),
170 _ => {
171 return Err(nom::Err::Error(nom::error::Error::new(
172 input,
173 nom::error::ErrorKind::Tag,
174 )));
175 }
176 };
177
178 let mut depth = 0;
180 let mut in_string = false;
181 let mut escape_next = false;
182 let mut end_pos = 0;
183
184 for (i, c) in trimmed.char_indices() {
185 if escape_next {
186 escape_next = false;
187 continue;
188 }
189
190 if c == '\\' && in_string {
191 escape_next = true;
192 continue;
193 }
194
195 if c == '"' {
196 in_string = !in_string;
197 continue;
198 }
199
200 if !in_string {
201 if c == open_char {
202 depth += 1;
203 } else if c == close_char {
204 depth -= 1;
205 if depth == 0 {
206 end_pos = i + 1;
207 break;
208 }
209 }
210 }
211 }
212
213 if depth != 0 || end_pos == 0 {
214 return Err(nom::Err::Error(nom::error::Error::new(
215 input,
216 nom::error::ErrorKind::Eof,
217 )));
218 }
219
220 let json_str = &trimmed[..end_pos];
221 let _remaining = &trimmed[end_pos..];
222
223 let consumed = input.len() - trimmed.len() + end_pos;
225 let remaining_original = &input[consumed..];
226
227 Ok((remaining_original, Value::Json(json_str.to_string())))
228}
229
230pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
232 alt((
233 alt((
235 value(Operator::NotBetween, tag_no_case("not between")),
236 value(Operator::Between, tag_no_case("between")),
237 value(Operator::IsNotNull, tag_no_case("is not null")),
238 value(Operator::IsNull, tag_no_case("is null")),
239 value(Operator::NotIn, tag_no_case("not in")),
240 value(Operator::NotILike, tag_no_case("not ilike")),
241 value(Operator::NotLike, tag_no_case("not like")),
242 value(Operator::SimilarTo, tag_no_case("similar to")),
243 value(Operator::JsonExists, tag_no_case("json_exists")),
244 value(Operator::JsonQuery, tag_no_case("json_query")),
245 value(Operator::JsonValue, tag_no_case("json_value")),
246 value(Operator::Regex, tag_no_case("regex")),
247 value(Operator::ILike, tag_no_case("ilike")),
248 value(Operator::Like, tag_no_case("like")),
249 value(Operator::In, tag_no_case("in")),
250 )),
251 alt((
253 value(Operator::RegexI, tag("~*")),
254 value(Operator::JsonPathText, tag("#>>")),
255 value(Operator::JsonPath, tag("#>")),
256 value(Operator::TextSearch, tag("@@")),
257 value(Operator::KeyExistsAny, tag("?|")),
258 value(Operator::KeyExistsAll, tag("?&")),
259 value(Operator::Contains, tag("@>")),
260 value(Operator::ContainedBy, tag("<@")),
261 value(Operator::Overlaps, tag("&&")),
262 value(Operator::Gte, tag(">=")),
263 value(Operator::Lte, tag("<=")),
264 value(Operator::Ne, tag("!=")),
265 value(Operator::Ne, tag("<>")),
266 )),
267 alt((
269 value(Operator::Eq, tag("=")),
270 value(Operator::Gt, tag(">")),
271 value(Operator::Lt, tag("<")),
272 value(Operator::KeyExists, tag("?")),
273 value(Operator::Fuzzy, tag("~")),
274 )),
275 ))
276 .parse(input)
277}
278
279pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
281 alt((
282 map(
284 (tag_no_case("get"), multispace1, tag_no_case("distinct")),
285 |_| (Action::Get, true),
286 ),
287 value((Action::Get, false), tag_no_case("get")),
289 value((Action::Export, false), tag_no_case("export")),
291 alt((
293 value((Action::Cnt, false), tag_no_case("count")),
294 value((Action::Cnt, false), tag_no_case("cnt")),
295 )),
296 value((Action::Set, false), tag_no_case("set")),
298 value((Action::Merge, false), tag_no_case("merge")),
300 alt((
302 value((Action::Del, false), tag_no_case("delete")),
303 value((Action::Del, false), tag_no_case("del")),
304 )),
305 alt((
307 value((Action::Add, false), tag_no_case("insert")),
308 value((Action::Add, false), tag_no_case("add")),
309 )),
310 alt((
312 value((Action::Make, false), tag_no_case("create")),
313 value((Action::Make, false), tag_no_case("make")),
314 )),
315 ))
316 .parse(input)
317}
318
319pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
321 let (input, action) = alt((
322 value(Action::TxnStart, tag_no_case("begin")),
323 value(Action::TxnCommit, tag_no_case("commit")),
324 value(Action::TxnRollback, tag_no_case("rollback")),
325 ))
326 .parse(input)?;
327
328 Ok((
329 input,
330 Qail {
331 action,
332 table: String::new(),
333 columns: vec![],
334 joins: vec![],
335 cages: vec![],
336 distinct: false,
337 distinct_on: vec![],
338 index_def: None,
339 table_constraints: vec![],
340 set_ops: vec![],
341 having: vec![],
342 group_by_mode: GroupByMode::default(),
343 ctes: vec![],
344 returning: None,
345 on_conflict: None,
346 merge: None,
347 source_query: None,
348 channel: None,
349 payload: None,
350 savepoint_name: None,
351 from_tables: vec![],
352 using_tables: vec![],
353 lock_mode: None,
354 skip_locked: false,
355 fetch: None,
356 default_values: false,
357 overriding: None,
358 sample: None,
359 only_table: false,
360 vector: None,
361 score_threshold: None,
362 vector_name: None,
363 with_vector: false,
364 vector_size: None,
365 distance: None,
366 on_disk: None,
367 function_def: None,
368 trigger_def: None,
369 policy_def: None,
370 },
371 ))
372}
373
374pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
383 alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
384}
385
386fn parse_call_command(input: &str) -> IResult<&str, Qail> {
387 let (input, _) = tag_no_case("call").parse(input)?;
388 let (input, _) = multispace1(input)?;
389
390 let procedure = input.trim().trim_end_matches(';').trim();
391 if procedure.is_empty() {
392 return Err(nom::Err::Error(nom::error::Error::new(
393 input,
394 nom::error::ErrorKind::Eof,
395 )));
396 }
397 if !is_safe_call_target(procedure) {
398 return Err(nom::Err::Error(nom::error::Error::new(
399 input,
400 nom::error::ErrorKind::Tag,
401 )));
402 }
403
404 Ok((
405 "",
406 Qail {
407 action: Action::Call,
408 table: procedure.to_string(),
409 ..Default::default()
410 },
411 ))
412}
413
414fn is_safe_call_target(procedure: &str) -> bool {
415 let procedure = procedure.trim();
416 if procedure.is_empty()
417 || procedure.contains('\0')
418 || procedure.contains(';')
419 || procedure.contains("--")
420 || procedure.contains("/*")
421 || procedure.contains("*/")
422 {
423 return false;
424 }
425
426 match procedure.split_once('(') {
427 Some((name, args)) if args.ends_with(')') && !args[..args.len() - 1].contains('(') => {
428 is_valid_qualified_ident(name.trim())
429 }
430 None => is_valid_qualified_ident(procedure),
431 _ => false,
432 }
433}
434
/// Returns `true` when `name` is a dot-separated sequence of identifiers, each starting
/// with an ASCII letter or `_` and continuing with ASCII letters, digits or `_`.
///
/// An empty name, or any empty segment (leading, trailing or doubled dot), is rejected.
fn is_valid_qualified_ident(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    for segment in name.split('.') {
        let mut rest = segment.chars();
        let starts_ok = match rest.next() {
            Some(first) => first == '_' || first.is_ascii_alphabetic(),
            None => false,
        };
        if !starts_ok || rest.any(|ch| ch != '_' && !ch.is_ascii_alphanumeric()) {
            return false;
        }
    }

    true
}
443
444fn parse_do_command(input: &str) -> IResult<&str, Qail> {
445 let (input, _) = tag_no_case("do").parse(input)?;
446 let (input, _) = multispace1(input)?;
447
448 let rest = input.trim().trim_end_matches(';').trim();
449 if rest.is_empty() {
450 return Err(nom::Err::Error(nom::error::Error::new(
451 input,
452 nom::error::ErrorKind::Eof,
453 )));
454 }
455
456 let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
458 if let Some(close_idx) = after_open.find("$$") {
459 let body = after_open[..close_idx].to_string();
460 let trailing = after_open[close_idx + 2..].trim();
461 let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
462 trailing[9..].trim().to_string()
463 } else {
464 "plpgsql".to_string()
465 };
466 (body, lang)
467 } else {
468 (rest.to_string(), "plpgsql".to_string())
469 }
470 } else {
471 (rest.to_string(), "plpgsql".to_string())
472 };
473
474 Ok((
475 "",
476 Qail {
477 action: Action::Do,
478 table: language,
479 payload: Some(body),
480 ..Default::default()
481 },
482 ))
483}
484
485fn parse_session_command(input: &str) -> IResult<&str, Qail> {
486 let (input, _) = tag_no_case("session").parse(input)?;
487 let (input, _) = multispace1(input)?;
488
489 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
491 let (input, _) = multispace1(input)?;
492 let (input, key) = parse_session_setting_key(input)?;
493 let (input, _) = multispace0(input)?;
494 let (input, _) = opt(char('=')).parse(input)?;
495 let value = input.trim().trim_end_matches(';').trim();
496 if value.is_empty() {
497 return Err(nom::Err::Error(nom::error::Error::new(
498 input,
499 nom::error::ErrorKind::Eof,
500 )));
501 }
502 let value = strip_matching_quotes(value);
503 return Ok((
504 "",
505 Qail {
506 action: Action::SessionSet,
507 table: key.to_string(),
508 payload: Some(value.to_string()),
509 ..Default::default()
510 },
511 ));
512 }
513
514 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
516 let (input, _) = multispace1(input)?;
517 let (input, key) = parse_session_setting_key(input)?;
518 let trailing = input.trim().trim_end_matches(';').trim();
519 if !trailing.is_empty() {
520 return Err(nom::Err::Error(nom::error::Error::new(
521 input,
522 nom::error::ErrorKind::Tag,
523 )));
524 }
525 return Ok((
526 "",
527 Qail {
528 action: Action::SessionShow,
529 table: key.to_string(),
530 ..Default::default()
531 },
532 ));
533 }
534
535 let (input, _) = tag_no_case("reset").parse(input)?;
537 let (input, _) = multispace1(input)?;
538 let (input, key) = parse_session_setting_key(input)?;
539 let trailing = input.trim().trim_end_matches(';').trim();
540 if !trailing.is_empty() {
541 return Err(nom::Err::Error(nom::error::Error::new(
542 input,
543 nom::error::ErrorKind::Tag,
544 )));
545 }
546 Ok((
547 "",
548 Qail {
549 action: Action::SessionReset,
550 table: key.to_string(),
551 ..Default::default()
552 },
553 ))
554}
555
556fn parse_session_setting_key(input: &str) -> IResult<&str, &str> {
557 let end = input
558 .char_indices()
559 .find_map(|(idx, ch)| (ch.is_whitespace() || ch == '=' || ch == ';').then_some(idx))
560 .unwrap_or(input.len());
561 let key = &input[..end];
562 if is_valid_session_setting_key(key) {
563 Ok((&input[end..], key))
564 } else {
565 Err(nom::Err::Error(nom::error::Error::new(
566 input,
567 nom::error::ErrorKind::Tag,
568 )))
569 }
570}
571
/// Returns `true` when `key` consists of one or more dot-separated segments, each
/// starting with an ASCII letter or `_` followed by ASCII letters, digits or `_`.
fn is_valid_session_setting_key(key: &str) -> bool {
    /// Checks one dot-free segment. Working on bytes is equivalent here because every
    /// accepted character is ASCII and no byte of a multi-byte character is.
    fn is_segment(segment: &str) -> bool {
        match segment.as_bytes().split_first() {
            Some((&head, tail)) => {
                (head == b'_' || head.is_ascii_alphabetic())
                    && tail
                        .iter()
                        .all(|&byte| byte == b'_' || byte.is_ascii_alphanumeric())
            }
            None => false,
        }
    }

    !key.is_empty() && key.split('.').all(is_segment)
}
580
/// Removes one pair of surrounding quotes from `s` when it both starts and ends with
/// the same quote character (`'` or `"`); otherwise returns `s` unchanged.
///
/// A single quote character on its own is not a pair and is returned as is.
fn strip_matching_quotes(s: &str) -> &str {
    ['\'', '"']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}