1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4 IResult, Parser,
5 branch::alt,
6 bytes::complete::{tag, tag_no_case, take_while1},
7 character::complete::{char, digit1, multispace0, multispace1},
8 combinator::{map, opt, recognize, value},
9 sequence::{delimited, preceded},
10};
11
12pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
14 take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)
15}
16
17pub fn parse_interval(input: &str) -> IResult<&str, Value> {
19 let (input, num_str) = digit1(input)?;
20 let amount: i64 = num_str.parse().unwrap_or(0);
21
22 let (input, unit) = alt((
23 value(IntervalUnit::Second, tag_no_case("s")),
24 value(IntervalUnit::Minute, tag_no_case("m")),
25 value(IntervalUnit::Hour, tag_no_case("h")),
26 value(IntervalUnit::Day, tag_no_case("d")),
27 value(IntervalUnit::Week, tag_no_case("w")),
28 value(IntervalUnit::Month, tag_no_case("mo")),
29 value(IntervalUnit::Year, tag_no_case("y")),
30 ))
31 .parse(input)?;
32
33 Ok((input, Value::Interval { amount, unit }))
34}
35
36pub fn parse_value(input: &str) -> IResult<&str, Value> {
38 alt((
39 map(preceded(char('$'), digit1), |d: &str| {
41 Value::Param(d.parse().unwrap_or(0))
42 }),
43 map(
45 preceded(
46 char(':'),
47 take_while1(|c: char| c.is_alphanumeric() || c == '_'),
48 ),
49 |name: &str| Value::NamedParam(name.to_string()),
50 ),
51 value(Value::Bool(true), tag_no_case("true")),
53 value(Value::Bool(false), tag_no_case("false")),
54 value(Value::Null, tag_no_case("null")),
56 parse_triple_quoted_string,
58 parse_json_literal,
60 map(
62 delimited(
63 char('"'),
64 nom::bytes::complete::take_while(|c| c != '"'),
65 char('"'),
66 ),
67 |s: &str| Value::String(s.to_string()),
68 ),
69 map(
71 delimited(
72 char('\''),
73 nom::bytes::complete::take_while(|c| c != '\''),
74 char('\''),
75 ),
76 |s: &str| Value::String(s.to_string()),
77 ),
78 map(
80 recognize((opt(char('-')), digit1, char('.'), digit1)),
81 |s: &str| Value::Float(s.parse().unwrap_or(0.0)),
82 ),
83 parse_interval,
85 map(recognize((opt(char('-')), digit1)), |s: &str| {
87 Value::Int(s.parse().unwrap_or(0))
88 }),
89 ))
90 .parse(input)
91}
92
93fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
95 alt((
96 map(
98 delimited(
99 tag("'''"),
100 nom::bytes::complete::take_until("'''"),
101 tag("'''"),
102 ),
103 |s: &str| Value::String(s.to_string()),
104 ),
105 map(
107 delimited(
108 tag("\"\"\""),
109 nom::bytes::complete::take_until("\"\"\""),
110 tag("\"\"\""),
111 ),
112 |s: &str| Value::String(s.to_string()),
113 ),
114 ))
115 .parse(input)
116}
117
118fn parse_json_literal(input: &str) -> IResult<&str, Value> {
121 let trimmed = input.trim_start();
123 if trimmed.is_empty() {
124 return Err(nom::Err::Error(nom::error::Error::new(
125 input,
126 nom::error::ErrorKind::Tag,
127 )));
128 }
129
130 let (open_char, close_char) = match trimmed.chars().next() {
131 Some('{') => ('{', '}'),
132 Some('[') => ('[', ']'),
133 _ => {
134 return Err(nom::Err::Error(nom::error::Error::new(
135 input,
136 nom::error::ErrorKind::Tag,
137 )));
138 }
139 };
140
141 let mut depth = 0;
143 let mut in_string = false;
144 let mut escape_next = false;
145 let mut end_pos = 0;
146
147 for (i, c) in trimmed.char_indices() {
148 if escape_next {
149 escape_next = false;
150 continue;
151 }
152
153 if c == '\\' && in_string {
154 escape_next = true;
155 continue;
156 }
157
158 if c == '"' {
159 in_string = !in_string;
160 continue;
161 }
162
163 if !in_string {
164 if c == open_char {
165 depth += 1;
166 } else if c == close_char {
167 depth -= 1;
168 if depth == 0 {
169 end_pos = i + 1;
170 break;
171 }
172 }
173 }
174 }
175
176 if depth != 0 || end_pos == 0 {
177 return Err(nom::Err::Error(nom::error::Error::new(
178 input,
179 nom::error::ErrorKind::Eof,
180 )));
181 }
182
183 let json_str = &trimmed[..end_pos];
184 let _remaining = &trimmed[end_pos..];
185
186 let consumed = input.len() - trimmed.len() + end_pos;
188 let remaining_original = &input[consumed..];
189
190 Ok((remaining_original, Value::Json(json_str.to_string())))
191}
192
193pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
195 alt((
196 alt((
198 value(Operator::NotBetween, tag_no_case("not between")),
199 value(Operator::Between, tag_no_case("between")),
200 value(Operator::IsNotNull, tag_no_case("is not null")),
201 value(Operator::IsNull, tag_no_case("is null")),
202 value(Operator::NotIn, tag_no_case("not in")),
203 value(Operator::NotILike, tag_no_case("not ilike")),
204 value(Operator::NotLike, tag_no_case("not like")),
205 value(Operator::SimilarTo, tag_no_case("similar to")),
206 value(Operator::JsonExists, tag_no_case("json_exists")),
207 value(Operator::JsonQuery, tag_no_case("json_query")),
208 value(Operator::JsonValue, tag_no_case("json_value")),
209 value(Operator::Regex, tag_no_case("regex")),
210 value(Operator::ILike, tag_no_case("ilike")),
211 value(Operator::Like, tag_no_case("like")),
212 value(Operator::In, tag_no_case("in")),
213 )),
214 alt((
216 value(Operator::RegexI, tag("~*")),
217 value(Operator::JsonPathText, tag("#>>")),
218 value(Operator::JsonPath, tag("#>")),
219 value(Operator::TextSearch, tag("@@")),
220 value(Operator::KeyExistsAny, tag("?|")),
221 value(Operator::KeyExistsAll, tag("?&")),
222 value(Operator::Contains, tag("@>")),
223 value(Operator::ContainedBy, tag("<@")),
224 value(Operator::Overlaps, tag("&&")),
225 value(Operator::Gte, tag(">=")),
226 value(Operator::Lte, tag("<=")),
227 value(Operator::Ne, tag("!=")),
228 value(Operator::Ne, tag("<>")),
229 )),
230 alt((
232 value(Operator::Eq, tag("=")),
233 value(Operator::Gt, tag(">")),
234 value(Operator::Lt, tag("<")),
235 value(Operator::KeyExists, tag("?")),
236 value(Operator::Fuzzy, tag("~")),
237 )),
238 ))
239 .parse(input)
240}
241
242pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
244 alt((
245 map(
247 (tag_no_case("get"), multispace1, tag_no_case("distinct")),
248 |_| (Action::Get, true),
249 ),
250 value((Action::Get, false), tag_no_case("get")),
252 value((Action::Export, false), tag_no_case("export")),
254 alt((
256 value((Action::Cnt, false), tag_no_case("count")),
257 value((Action::Cnt, false), tag_no_case("cnt")),
258 )),
259 value((Action::Set, false), tag_no_case("set")),
261 alt((
263 value((Action::Del, false), tag_no_case("delete")),
264 value((Action::Del, false), tag_no_case("del")),
265 )),
266 alt((
268 value((Action::Add, false), tag_no_case("insert")),
269 value((Action::Add, false), tag_no_case("add")),
270 )),
271 alt((
273 value((Action::Make, false), tag_no_case("create")),
274 value((Action::Make, false), tag_no_case("make")),
275 )),
276 ))
277 .parse(input)
278}
279
280pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
282 let (input, action) = alt((
283 value(Action::TxnStart, tag_no_case("begin")),
284 value(Action::TxnCommit, tag_no_case("commit")),
285 value(Action::TxnRollback, tag_no_case("rollback")),
286 ))
287 .parse(input)?;
288
289 Ok((
290 input,
291 Qail {
292 action,
293 table: String::new(),
294 columns: vec![],
295 joins: vec![],
296 cages: vec![],
297 distinct: false,
298 distinct_on: vec![],
299 index_def: None,
300 table_constraints: vec![],
301 set_ops: vec![],
302 having: vec![],
303 group_by_mode: GroupByMode::default(),
304 ctes: vec![],
305 returning: None,
306 on_conflict: None,
307 source_query: None,
308 channel: None,
309 payload: None,
310 savepoint_name: None,
311 from_tables: vec![],
312 using_tables: vec![],
313 lock_mode: None,
314 skip_locked: false,
315 fetch: None,
316 default_values: false,
317 overriding: None,
318 sample: None,
319 only_table: false,
320 vector: None,
321 score_threshold: None,
322 vector_name: None,
323 with_vector: false,
324 vector_size: None,
325 distance: None,
326 on_disk: None,
327 function_def: None,
328 trigger_def: None,
329 policy_def: None,
330 },
331 ))
332}
333
334pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
343 alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
344}
345
346fn parse_call_command(input: &str) -> IResult<&str, Qail> {
347 let (input, _) = tag_no_case("call").parse(input)?;
348 let (input, _) = multispace1(input)?;
349
350 let procedure = input.trim().trim_end_matches(';').trim();
351 if procedure.is_empty() {
352 return Err(nom::Err::Error(nom::error::Error::new(
353 input,
354 nom::error::ErrorKind::Eof,
355 )));
356 }
357
358 Ok((
359 "",
360 Qail {
361 action: Action::Call,
362 table: procedure.to_string(),
363 ..Default::default()
364 },
365 ))
366}
367
368fn parse_do_command(input: &str) -> IResult<&str, Qail> {
369 let (input, _) = tag_no_case("do").parse(input)?;
370 let (input, _) = multispace1(input)?;
371
372 let rest = input.trim().trim_end_matches(';').trim();
373 if rest.is_empty() {
374 return Err(nom::Err::Error(nom::error::Error::new(
375 input,
376 nom::error::ErrorKind::Eof,
377 )));
378 }
379
380 let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
382 if let Some(close_idx) = after_open.find("$$") {
383 let body = after_open[..close_idx].to_string();
384 let trailing = after_open[close_idx + 2..].trim();
385 let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
386 trailing[9..].trim().to_string()
387 } else {
388 "plpgsql".to_string()
389 };
390 (body, lang)
391 } else {
392 (rest.to_string(), "plpgsql".to_string())
393 }
394 } else {
395 (rest.to_string(), "plpgsql".to_string())
396 };
397
398 Ok((
399 "",
400 Qail {
401 action: Action::Do,
402 table: language,
403 payload: Some(body),
404 ..Default::default()
405 },
406 ))
407}
408
409fn parse_session_command(input: &str) -> IResult<&str, Qail> {
410 let (input, _) = tag_no_case("session").parse(input)?;
411 let (input, _) = multispace1(input)?;
412
413 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
415 let (input, _) = multispace1(input)?;
416 let (input, key) = parse_identifier(input)?;
417 let (input, _) = multispace0(input)?;
418 let (input, _) = opt(char('=')).parse(input)?;
419 let value = input.trim().trim_end_matches(';').trim();
420 if value.is_empty() {
421 return Err(nom::Err::Error(nom::error::Error::new(
422 input,
423 nom::error::ErrorKind::Eof,
424 )));
425 }
426 let value = strip_matching_quotes(value);
427 return Ok((
428 "",
429 Qail {
430 action: Action::SessionSet,
431 table: key.to_string(),
432 payload: Some(value.to_string()),
433 ..Default::default()
434 },
435 ));
436 }
437
438 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
440 let (input, _) = multispace1(input)?;
441 let (input, key) = parse_identifier(input)?;
442 let trailing = input.trim().trim_end_matches(';').trim();
443 if !trailing.is_empty() {
444 return Err(nom::Err::Error(nom::error::Error::new(
445 input,
446 nom::error::ErrorKind::Tag,
447 )));
448 }
449 return Ok((
450 "",
451 Qail {
452 action: Action::SessionShow,
453 table: key.to_string(),
454 ..Default::default()
455 },
456 ));
457 }
458
459 let (input, _) = tag_no_case("reset").parse(input)?;
461 let (input, _) = multispace1(input)?;
462 let (input, key) = parse_identifier(input)?;
463 let trailing = input.trim().trim_end_matches(';').trim();
464 if !trailing.is_empty() {
465 return Err(nom::Err::Error(nom::error::Error::new(
466 input,
467 nom::error::ErrorKind::Tag,
468 )));
469 }
470 Ok((
471 "",
472 Qail {
473 action: Action::SessionReset,
474 table: key.to_string(),
475 ..Default::default()
476 },
477 ))
478}
479
/// Removes one pair of surrounding quotes when `s` is at least two bytes long
/// and both starts and ends with the same quote character (`'` or `"`).
///
/// Otherwise `s` is returned unchanged. Only the outermost pair is removed;
/// inner quotes are left alone.
fn strip_matching_quotes(s: &str) -> &str {
    let wrapped_in = |quote: char| s.len() >= 2 && s.starts_with(quote) && s.ends_with(quote);
    if wrapped_in('\'') || wrapped_in('"') {
        // Both quote characters are one byte, so these offsets are char boundaries.
        &s[1..s.len() - 1]
    } else {
        s
    }
}