1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4 IResult, Parser,
5 branch::alt,
6 bytes::complete::{tag, tag_no_case, take_while1},
7 character::complete::{char, digit1, multispace0, multispace1},
8 combinator::{map, map_res, opt, recognize, value},
9 sequence::{delimited, preceded},
10};
11
12pub fn parse_bare_identifier(input: &str) -> IResult<&str, &str> {
14 let (remaining, ident) =
15 take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_').parse(input)?;
16 if is_valid_ident_part(ident) {
17 Ok((remaining, ident))
18 } else {
19 Err(nom::Err::Error(nom::error::Error::new(
20 input,
21 nom::error::ErrorKind::TakeWhile1,
22 )))
23 }
24}
25
26pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
28 let (remaining, ident) =
29 take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.').parse(input)?;
30 if ident.split('.').all(is_valid_ident_part) {
31 Ok((remaining, ident))
32 } else {
33 Err(nom::Err::Error(nom::error::Error::new(
34 input,
35 nom::error::ErrorKind::TakeWhile1,
36 )))
37 }
38}
39
/// Returns `true` if `part` is a non-empty ASCII identifier.
///
/// The first character must be an ASCII letter or `_`. Each following
/// character must be an ASCII letter, digit or `_`. Because every accepted
/// character is ASCII, the check can run over raw bytes: any byte of a
/// multi-byte UTF-8 character is >= 0x80 and therefore rejected.
fn is_valid_ident_part(part: &str) -> bool {
    match part.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head == b'_' || head.is_ascii_alphabetic())
                && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
        }
        None => false,
    }
}
45
46pub fn parse_interval(input: &str) -> IResult<&str, Value> {
48 let (input, amount) = map_res(digit1, str::parse::<i64>).parse(input)?;
49
50 let (input, unit) = alt((
51 value(IntervalUnit::Month, tag_no_case("mo")),
52 value(IntervalUnit::Second, tag_no_case("s")),
53 value(IntervalUnit::Minute, tag_no_case("m")),
54 value(IntervalUnit::Hour, tag_no_case("h")),
55 value(IntervalUnit::Day, tag_no_case("d")),
56 value(IntervalUnit::Week, tag_no_case("w")),
57 value(IntervalUnit::Year, tag_no_case("y")),
58 ))
59 .parse(input)?;
60
61 Ok((input, Value::Interval { amount, unit }))
62}
63
64pub fn parse_value(input: &str) -> IResult<&str, Value> {
66 alt((
67 map_res(preceded(char('$'), digit1), |d: &str| {
69 d.parse::<usize>().map(Value::Param)
70 }),
71 map(preceded(char(':'), parse_bare_identifier), |name: &str| {
73 Value::NamedParam(name.to_string())
74 }),
75 value(Value::Bool(true), tag_no_case("true")),
77 value(Value::Bool(false), tag_no_case("false")),
78 value(Value::Null, tag_no_case("null")),
80 parse_triple_quoted_string,
82 parse_json_literal,
84 parse_double_quoted_string,
86 parse_single_quoted_string,
88 map_res(
90 recognize((opt(char('-')), digit1, char('.'), digit1)),
91 |s: &str| {
92 let value = s.parse::<f64>().map_err(|err| err.to_string())?;
93 value
94 .is_finite()
95 .then_some(Value::Float(value))
96 .ok_or_else(|| "float literal must be finite".to_string())
97 },
98 ),
99 parse_interval,
101 map_res(recognize((opt(char('-')), digit1)), |s: &str| {
103 s.parse::<i64>().map(Value::Int)
104 }),
105 ))
106 .parse(input)
107}
108
109fn parse_single_quoted_string(input: &str) -> IResult<&str, Value> {
110 parse_quoted_string(input, '\'')
111}
112
113fn parse_double_quoted_string(input: &str) -> IResult<&str, Value> {
114 parse_quoted_string(input, '"')
115}
116
117fn parse_quoted_string(input: &str, quote: char) -> IResult<&str, Value> {
118 if !input.starts_with(quote) {
119 return Err(nom::Err::Error(nom::error::Error::new(
120 input,
121 nom::error::ErrorKind::Char,
122 )));
123 }
124
125 let mut value = String::new();
126 let mut index = quote.len_utf8();
127
128 while index < input.len() {
129 let Some(ch) = input.get(index..).and_then(|s| s.chars().next()) else {
130 return Err(nom::Err::Error(nom::error::Error::new(
131 input,
132 nom::error::ErrorKind::Char,
133 )));
134 };
135 index += ch.len_utf8();
136
137 if ch == quote {
138 if input[index..].starts_with(quote) {
139 value.push(quote);
140 index += quote.len_utf8();
141 } else {
142 return Ok((&input[index..], Value::String(value)));
143 }
144 } else {
145 value.push(ch);
146 }
147 }
148
149 Err(nom::Err::Error(nom::error::Error::new(
150 input,
151 nom::error::ErrorKind::Eof,
152 )))
153}
154
155fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
157 alt((
158 map(
160 delimited(
161 tag("'''"),
162 nom::bytes::complete::take_until("'''"),
163 tag("'''"),
164 ),
165 |s: &str| Value::String(s.to_string()),
166 ),
167 map(
169 delimited(
170 tag("\"\"\""),
171 nom::bytes::complete::take_until("\"\"\""),
172 tag("\"\"\""),
173 ),
174 |s: &str| Value::String(s.to_string()),
175 ),
176 ))
177 .parse(input)
178}
179
180fn parse_json_literal(input: &str) -> IResult<&str, Value> {
183 let trimmed = input.trim_start();
185 if trimmed.is_empty() {
186 return Err(nom::Err::Error(nom::error::Error::new(
187 input,
188 nom::error::ErrorKind::Tag,
189 )));
190 }
191
192 let (open_char, close_char) = match trimmed.chars().next() {
193 Some('{') => ('{', '}'),
194 Some('[') => ('[', ']'),
195 _ => {
196 return Err(nom::Err::Error(nom::error::Error::new(
197 input,
198 nom::error::ErrorKind::Tag,
199 )));
200 }
201 };
202
203 let mut depth = 0;
205 let mut in_string = false;
206 let mut escape_next = false;
207 let mut end_pos = 0;
208
209 for (i, c) in trimmed.char_indices() {
210 if escape_next {
211 escape_next = false;
212 continue;
213 }
214
215 if c == '\\' && in_string {
216 escape_next = true;
217 continue;
218 }
219
220 if c == '"' {
221 in_string = !in_string;
222 continue;
223 }
224
225 if !in_string {
226 if c == open_char {
227 depth += 1;
228 } else if c == close_char {
229 depth -= 1;
230 if depth == 0 {
231 end_pos = i + 1;
232 break;
233 }
234 }
235 }
236 }
237
238 if depth != 0 || end_pos == 0 {
239 return Err(nom::Err::Error(nom::error::Error::new(
240 input,
241 nom::error::ErrorKind::Eof,
242 )));
243 }
244
245 let json_str = &trimmed[..end_pos];
246 let _remaining = &trimmed[end_pos..];
247
248 if serde_json::from_str::<serde_json::Value>(json_str).is_err() {
249 return Err(nom::Err::Error(nom::error::Error::new(
250 input,
251 nom::error::ErrorKind::Verify,
252 )));
253 }
254
255 let consumed = input.len() - trimmed.len() + end_pos;
257 let remaining_original = &input[consumed..];
258
259 Ok((remaining_original, Value::Json(json_str.to_string())))
260}
261
262pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
264 alt((
265 alt((
267 value(Operator::NotBetween, tag_no_case("not between")),
268 value(Operator::Between, tag_no_case("between")),
269 value(Operator::IsNotNull, tag_no_case("is not null")),
270 value(Operator::IsNull, tag_no_case("is null")),
271 value(Operator::NotIn, tag_no_case("not in")),
272 value(Operator::NotILike, tag_no_case("not ilike")),
273 value(Operator::NotLike, tag_no_case("not like")),
274 value(Operator::SimilarTo, tag_no_case("similar to")),
275 value(Operator::JsonExists, tag_no_case("json_exists")),
276 value(Operator::JsonQuery, tag_no_case("json_query")),
277 value(Operator::JsonValue, tag_no_case("json_value")),
278 value(Operator::Regex, tag_no_case("regex")),
279 value(Operator::ILike, tag_no_case("ilike")),
280 value(Operator::Like, tag_no_case("like")),
281 value(Operator::In, tag_no_case("in")),
282 )),
283 alt((
285 value(Operator::RegexI, tag("~*")),
286 value(Operator::JsonPathText, tag("#>>")),
287 value(Operator::JsonPath, tag("#>")),
288 value(Operator::TextSearch, tag("@@")),
289 value(Operator::KeyExistsAny, tag("?|")),
290 value(Operator::KeyExistsAll, tag("?&")),
291 value(Operator::Contains, tag("@>")),
292 value(Operator::ContainedBy, tag("<@")),
293 value(Operator::Overlaps, tag("&&")),
294 value(Operator::Gte, tag(">=")),
295 value(Operator::Lte, tag("<=")),
296 value(Operator::Ne, tag("!=")),
297 value(Operator::Ne, tag("<>")),
298 )),
299 alt((
301 value(Operator::Eq, tag("=")),
302 value(Operator::Gt, tag(">")),
303 value(Operator::Lt, tag("<")),
304 value(Operator::KeyExists, tag("?")),
305 value(Operator::Fuzzy, tag("~")),
306 )),
307 ))
308 .parse(input)
309}
310
311pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
313 alt((
314 map(
316 (tag_no_case("get"), multispace1, tag_no_case("distinct")),
317 |_| (Action::Get, true),
318 ),
319 value((Action::Get, false), tag_no_case("get")),
321 value((Action::Export, false), tag_no_case("export")),
323 alt((
325 value((Action::Cnt, false), tag_no_case("count")),
326 value((Action::Cnt, false), tag_no_case("cnt")),
327 )),
328 value((Action::Set, false), tag_no_case("set")),
330 value((Action::Merge, false), tag_no_case("merge")),
332 alt((
334 value((Action::Del, false), tag_no_case("delete")),
335 value((Action::Del, false), tag_no_case("del")),
336 )),
337 alt((
339 value((Action::Add, false), tag_no_case("insert")),
340 value((Action::Add, false), tag_no_case("add")),
341 )),
342 alt((
344 value((Action::Make, false), tag_no_case("create")),
345 value((Action::Make, false), tag_no_case("make")),
346 )),
347 ))
348 .parse(input)
349}
350
351pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
353 let (input, action) = alt((
354 value(Action::TxnStart, tag_no_case("begin")),
355 value(Action::TxnCommit, tag_no_case("commit")),
356 value(Action::TxnRollback, tag_no_case("rollback")),
357 ))
358 .parse(input)?;
359
360 Ok((
361 input,
362 Qail {
363 action,
364 table: String::new(),
365 columns: vec![],
366 joins: vec![],
367 cages: vec![],
368 distinct: false,
369 distinct_on: vec![],
370 index_def: None,
371 table_constraints: vec![],
372 set_ops: vec![],
373 having: vec![],
374 group_by_mode: GroupByMode::default(),
375 ctes: vec![],
376 returning: None,
377 on_conflict: None,
378 merge: None,
379 source_query: None,
380 channel: None,
381 payload: None,
382 savepoint_name: None,
383 from_tables: vec![],
384 using_tables: vec![],
385 lock_mode: None,
386 skip_locked: false,
387 fetch: None,
388 default_values: false,
389 overriding: None,
390 sample: None,
391 only_table: false,
392 vector: None,
393 score_threshold: None,
394 vector_name: None,
395 with_vector: false,
396 vector_size: None,
397 distance: None,
398 on_disk: None,
399 function_def: None,
400 trigger_def: None,
401 policy_def: None,
402 },
403 ))
404}
405
406pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
415 alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
416}
417
418fn parse_call_command(input: &str) -> IResult<&str, Qail> {
419 let (input, _) = tag_no_case("call").parse(input)?;
420 let (input, _) = multispace1(input)?;
421
422 let procedure = input.trim().trim_end_matches(';').trim();
423 if procedure.is_empty() {
424 return Err(nom::Err::Error(nom::error::Error::new(
425 input,
426 nom::error::ErrorKind::Eof,
427 )));
428 }
429 if !is_safe_call_target(procedure) {
430 return Err(nom::Err::Error(nom::error::Error::new(
431 input,
432 nom::error::ErrorKind::Tag,
433 )));
434 }
435
436 Ok((
437 "",
438 Qail {
439 action: Action::Call,
440 table: procedure.to_string(),
441 ..Default::default()
442 },
443 ))
444}
445
446fn is_safe_call_target(procedure: &str) -> bool {
447 let procedure = procedure.trim();
448 if procedure.is_empty()
449 || procedure.contains('\0')
450 || procedure.contains(';')
451 || procedure.contains("--")
452 || procedure.contains("/*")
453 || procedure.contains("*/")
454 {
455 return false;
456 }
457
458 match procedure.split_once('(') {
459 Some((name, args)) if args.ends_with(')') && !args[..args.len() - 1].contains('(') => {
460 is_valid_qualified_ident(name.trim())
461 }
462 None => is_valid_qualified_ident(procedure),
463 _ => false,
464 }
465}
466
467fn is_valid_qualified_ident(name: &str) -> bool {
468 !name.is_empty()
469 && name.split('.').all(|part| {
470 let mut chars = part.chars();
471 matches!(chars.next(), Some(ch) if ch.is_ascii_alphabetic() || ch == '_')
472 && chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
473 })
474}
475
476fn parse_do_command(input: &str) -> IResult<&str, Qail> {
477 let (input, _) = tag_no_case("do").parse(input)?;
478 let (input, _) = multispace1(input)?;
479
480 let rest = input.trim().trim_end_matches(';').trim();
481 if rest.is_empty() {
482 return Err(nom::Err::Error(nom::error::Error::new(
483 input,
484 nom::error::ErrorKind::Eof,
485 )));
486 }
487
488 let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
490 if let Some(close_idx) = after_open.find("$$") {
491 let body = after_open[..close_idx].to_string();
492 let trailing = after_open[close_idx + 2..].trim();
493 let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
494 trailing[9..].trim().to_string()
495 } else {
496 "plpgsql".to_string()
497 };
498 (body, lang)
499 } else {
500 (rest.to_string(), "plpgsql".to_string())
501 }
502 } else {
503 (rest.to_string(), "plpgsql".to_string())
504 };
505
506 Ok((
507 "",
508 Qail {
509 action: Action::Do,
510 table: language,
511 payload: Some(body),
512 ..Default::default()
513 },
514 ))
515}
516
517fn parse_session_command(input: &str) -> IResult<&str, Qail> {
518 let (input, _) = tag_no_case("session").parse(input)?;
519 let (input, _) = multispace1(input)?;
520
521 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
523 let (input, _) = multispace1(input)?;
524 let (input, key) = parse_session_setting_key(input)?;
525 let (input, _) = multispace0(input)?;
526 let (input, _) = opt(char('=')).parse(input)?;
527 let value = input.trim().trim_end_matches(';').trim();
528 if value.is_empty() {
529 return Err(nom::Err::Error(nom::error::Error::new(
530 input,
531 nom::error::ErrorKind::Eof,
532 )));
533 }
534 let value = strip_matching_quotes(value);
535 return Ok((
536 "",
537 Qail {
538 action: Action::SessionSet,
539 table: key.to_string(),
540 payload: Some(value.to_string()),
541 ..Default::default()
542 },
543 ));
544 }
545
546 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
548 let (input, _) = multispace1(input)?;
549 let (input, key) = parse_session_setting_key(input)?;
550 let trailing = input.trim().trim_end_matches(';').trim();
551 if !trailing.is_empty() {
552 return Err(nom::Err::Error(nom::error::Error::new(
553 input,
554 nom::error::ErrorKind::Tag,
555 )));
556 }
557 return Ok((
558 "",
559 Qail {
560 action: Action::SessionShow,
561 table: key.to_string(),
562 ..Default::default()
563 },
564 ));
565 }
566
567 let (input, _) = tag_no_case("reset").parse(input)?;
569 let (input, _) = multispace1(input)?;
570 let (input, key) = parse_session_setting_key(input)?;
571 let trailing = input.trim().trim_end_matches(';').trim();
572 if !trailing.is_empty() {
573 return Err(nom::Err::Error(nom::error::Error::new(
574 input,
575 nom::error::ErrorKind::Tag,
576 )));
577 }
578 Ok((
579 "",
580 Qail {
581 action: Action::SessionReset,
582 table: key.to_string(),
583 ..Default::default()
584 },
585 ))
586}
587
588fn parse_session_setting_key(input: &str) -> IResult<&str, &str> {
589 let end = input
590 .char_indices()
591 .find_map(|(idx, ch)| (ch.is_whitespace() || ch == '=' || ch == ';').then_some(idx))
592 .unwrap_or(input.len());
593 let key = &input[..end];
594 if is_valid_session_setting_key(key) {
595 Ok((&input[end..], key))
596 } else {
597 Err(nom::Err::Error(nom::error::Error::new(
598 input,
599 nom::error::ErrorKind::Tag,
600 )))
601 }
602}
603
604fn is_valid_session_setting_key(key: &str) -> bool {
605 !key.is_empty()
606 && key.split('.').all(|part| {
607 let mut chars = part.chars();
608 matches!(chars.next(), Some(ch) if ch.is_ascii_alphabetic() || ch == '_')
609 && chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
610 })
611}
612
/// Removes one pair of surrounding quotes when `s` both starts and ends with
/// the same quote character (`'` or `"`). Otherwise returns `s` unchanged.
/// A lone quote character is never treated as a pair.
fn strip_matching_quotes(s: &str) -> &str {
    if s.len() < 2 {
        return s;
    }
    ['\'', '"']
        .into_iter()
        .find_map(|quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}