1use crate::ast::values::IntervalUnit;
2use crate::ast::*;
3use nom::{
4 IResult, Parser,
5 branch::alt,
6 bytes::complete::{tag, tag_no_case, take_while1},
7 character::complete::{char, digit1, multispace0, multispace1},
8 combinator::{map, opt, recognize, value},
9 sequence::{delimited, preceded},
10};
11
12pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
14 take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)
15}
16
17pub fn parse_interval(input: &str) -> IResult<&str, Value> {
19 let (input, num_str) = digit1(input)?;
20 let amount: i64 = num_str.parse().unwrap_or(0);
21
22 let (input, unit) = alt((
23 value(IntervalUnit::Second, tag_no_case("s")),
24 value(IntervalUnit::Minute, tag_no_case("m")),
25 value(IntervalUnit::Hour, tag_no_case("h")),
26 value(IntervalUnit::Day, tag_no_case("d")),
27 value(IntervalUnit::Week, tag_no_case("w")),
28 value(IntervalUnit::Month, tag_no_case("mo")),
29 value(IntervalUnit::Year, tag_no_case("y")),
30 ))
31 .parse(input)?;
32
33 Ok((input, Value::Interval { amount, unit }))
34}
35
36pub fn parse_value(input: &str) -> IResult<&str, Value> {
38 alt((
39 map(preceded(char('$'), digit1), |d: &str| {
41 Value::Param(d.parse().unwrap_or(0))
42 }),
43 map(
45 preceded(
46 char(':'),
47 take_while1(|c: char| c.is_alphanumeric() || c == '_'),
48 ),
49 |name: &str| Value::NamedParam(name.to_string()),
50 ),
51 value(Value::Bool(true), tag_no_case("true")),
53 value(Value::Bool(false), tag_no_case("false")),
54 value(Value::Null, tag_no_case("null")),
56 parse_triple_quoted_string,
58 parse_json_literal,
60 map(
62 delimited(
63 char('"'),
64 nom::bytes::complete::take_while(|c| c != '"'),
65 char('"'),
66 ),
67 |s: &str| Value::String(s.to_string()),
68 ),
69 map(
71 delimited(
72 char('\''),
73 nom::bytes::complete::take_while(|c| c != '\''),
74 char('\''),
75 ),
76 |s: &str| Value::String(s.to_string()),
77 ),
78 map(
80 recognize((opt(char('-')), digit1, char('.'), digit1)),
81 |s: &str| Value::Float(s.parse().unwrap_or(0.0)),
82 ),
83 parse_interval,
85 map(recognize((opt(char('-')), digit1)), |s: &str| {
87 Value::Int(s.parse().unwrap_or(0))
88 }),
89 ))
90 .parse(input)
91}
92
93fn parse_triple_quoted_string(input: &str) -> IResult<&str, Value> {
95 alt((
96 map(
98 delimited(
99 tag("'''"),
100 nom::bytes::complete::take_until("'''"),
101 tag("'''"),
102 ),
103 |s: &str| Value::String(s.to_string()),
104 ),
105 map(
107 delimited(
108 tag("\"\"\""),
109 nom::bytes::complete::take_until("\"\"\""),
110 tag("\"\"\""),
111 ),
112 |s: &str| Value::String(s.to_string()),
113 ),
114 ))
115 .parse(input)
116}
117
118fn parse_json_literal(input: &str) -> IResult<&str, Value> {
121 let trimmed = input.trim_start();
123 if trimmed.is_empty() {
124 return Err(nom::Err::Error(nom::error::Error::new(
125 input,
126 nom::error::ErrorKind::Tag,
127 )));
128 }
129
130 let (open_char, close_char) = match trimmed.chars().next() {
131 Some('{') => ('{', '}'),
132 Some('[') => ('[', ']'),
133 _ => {
134 return Err(nom::Err::Error(nom::error::Error::new(
135 input,
136 nom::error::ErrorKind::Tag,
137 )));
138 }
139 };
140
141 let mut depth = 0;
143 let mut in_string = false;
144 let mut escape_next = false;
145 let mut end_pos = 0;
146
147 for (i, c) in trimmed.char_indices() {
148 if escape_next {
149 escape_next = false;
150 continue;
151 }
152
153 if c == '\\' && in_string {
154 escape_next = true;
155 continue;
156 }
157
158 if c == '"' {
159 in_string = !in_string;
160 continue;
161 }
162
163 if !in_string {
164 if c == open_char {
165 depth += 1;
166 } else if c == close_char {
167 depth -= 1;
168 if depth == 0 {
169 end_pos = i + 1;
170 break;
171 }
172 }
173 }
174 }
175
176 if depth != 0 || end_pos == 0 {
177 return Err(nom::Err::Error(nom::error::Error::new(
178 input,
179 nom::error::ErrorKind::Eof,
180 )));
181 }
182
183 let json_str = &trimmed[..end_pos];
184 let _remaining = &trimmed[end_pos..];
185
186 let consumed = input.len() - trimmed.len() + end_pos;
188 let remaining_original = &input[consumed..];
189
190 Ok((remaining_original, Value::Json(json_str.to_string())))
191}
192
193pub fn parse_operator(input: &str) -> IResult<&str, Operator> {
195 alt((
196 alt((
198 value(Operator::NotBetween, tag_no_case("not between")),
199 value(Operator::Between, tag_no_case("between")),
200 value(Operator::IsNotNull, tag_no_case("is not null")),
201 value(Operator::IsNull, tag_no_case("is null")),
202 value(Operator::NotIn, tag_no_case("not in")),
203 value(Operator::NotILike, tag_no_case("not ilike")),
204 value(Operator::NotLike, tag_no_case("not like")),
205 value(Operator::SimilarTo, tag_no_case("similar to")),
206 value(Operator::JsonExists, tag_no_case("json_exists")),
207 value(Operator::JsonQuery, tag_no_case("json_query")),
208 value(Operator::JsonValue, tag_no_case("json_value")),
209 value(Operator::Regex, tag_no_case("regex")),
210 value(Operator::ILike, tag_no_case("ilike")),
211 value(Operator::Like, tag_no_case("like")),
212 value(Operator::In, tag_no_case("in")),
213 )),
214 alt((
216 value(Operator::RegexI, tag("~*")),
217 value(Operator::JsonPathText, tag("#>>")),
218 value(Operator::JsonPath, tag("#>")),
219 value(Operator::TextSearch, tag("@@")),
220 value(Operator::KeyExistsAny, tag("?|")),
221 value(Operator::KeyExistsAll, tag("?&")),
222 value(Operator::Contains, tag("@>")),
223 value(Operator::ContainedBy, tag("<@")),
224 value(Operator::Overlaps, tag("&&")),
225 value(Operator::Gte, tag(">=")),
226 value(Operator::Lte, tag("<=")),
227 value(Operator::Ne, tag("!=")),
228 value(Operator::Ne, tag("<>")),
229 )),
230 alt((
232 value(Operator::Eq, tag("=")),
233 value(Operator::Gt, tag(">")),
234 value(Operator::Lt, tag("<")),
235 value(Operator::KeyExists, tag("?")),
236 value(Operator::Fuzzy, tag("~")),
237 )),
238 ))
239 .parse(input)
240}
241
242pub fn parse_action(input: &str) -> IResult<&str, (Action, bool)> {
244 alt((
245 map(
247 (tag_no_case("get"), multispace1, tag_no_case("distinct")),
248 |_| (Action::Get, true),
249 ),
250 value((Action::Get, false), tag_no_case("get")),
252 value((Action::Export, false), tag_no_case("export")),
254 alt((
256 value((Action::Cnt, false), tag_no_case("count")),
257 value((Action::Cnt, false), tag_no_case("cnt")),
258 )),
259 value((Action::Set, false), tag_no_case("set")),
261 alt((
263 value((Action::Del, false), tag_no_case("delete")),
264 value((Action::Del, false), tag_no_case("del")),
265 )),
266 alt((
268 value((Action::Add, false), tag_no_case("insert")),
269 value((Action::Add, false), tag_no_case("add")),
270 )),
271 alt((
273 value((Action::Make, false), tag_no_case("create")),
274 value((Action::Make, false), tag_no_case("make")),
275 )),
276 ))
277 .parse(input)
278}
279
280pub fn parse_txn_command(input: &str) -> IResult<&str, Qail> {
282 let (input, action) = alt((
283 value(Action::TxnStart, tag_no_case("begin")),
284 value(Action::TxnCommit, tag_no_case("commit")),
285 value(Action::TxnRollback, tag_no_case("rollback")),
286 ))
287 .parse(input)?;
288
289 Ok((
290 input,
291 Qail {
292 action,
293 table: String::new(),
294 columns: vec![],
295 joins: vec![],
296 cages: vec![],
297 distinct: false,
298 distinct_on: vec![],
299 index_def: None,
300 table_constraints: vec![],
301 set_ops: vec![],
302 having: vec![],
303 group_by_mode: GroupByMode::default(),
304 ctes: vec![],
305 returning: None,
306 on_conflict: None,
307 source_query: None,
308 channel: None,
309 payload: None,
310 savepoint_name: None,
311 from_tables: vec![],
312 using_tables: vec![],
313 lock_mode: None,
314 skip_locked: false,
315 fetch: None,
316 default_values: false,
317 overriding: None,
318 sample: None,
319 only_table: false,
320 vector: None,
321 score_threshold: None,
322 vector_name: None,
323 with_vector: false,
324 vector_size: None,
325 distance: None,
326 on_disk: None,
327 function_def: None,
328 trigger_def: None,
329 },
330 ))
331}
332
333pub fn parse_procedural_command(input: &str) -> IResult<&str, Qail> {
342 alt((parse_call_command, parse_do_command, parse_session_command)).parse(input)
343}
344
345fn parse_call_command(input: &str) -> IResult<&str, Qail> {
346 let (input, _) = tag_no_case("call").parse(input)?;
347 let (input, _) = multispace1(input)?;
348
349 let procedure = input.trim().trim_end_matches(';').trim();
350 if procedure.is_empty() {
351 return Err(nom::Err::Error(nom::error::Error::new(
352 input,
353 nom::error::ErrorKind::Eof,
354 )));
355 }
356
357 Ok((
358 "",
359 Qail {
360 action: Action::Call,
361 table: procedure.to_string(),
362 ..Default::default()
363 },
364 ))
365}
366
367fn parse_do_command(input: &str) -> IResult<&str, Qail> {
368 let (input, _) = tag_no_case("do").parse(input)?;
369 let (input, _) = multispace1(input)?;
370
371 let rest = input.trim().trim_end_matches(';').trim();
372 if rest.is_empty() {
373 return Err(nom::Err::Error(nom::error::Error::new(
374 input,
375 nom::error::ErrorKind::Eof,
376 )));
377 }
378
379 let (body, language) = if let Some(after_open) = rest.strip_prefix("$$") {
381 if let Some(close_idx) = after_open.find("$$") {
382 let body = after_open[..close_idx].to_string();
383 let trailing = after_open[close_idx + 2..].trim();
384 let lang = if trailing.to_ascii_lowercase().starts_with("language ") {
385 trailing[9..].trim().to_string()
386 } else {
387 "plpgsql".to_string()
388 };
389 (body, lang)
390 } else {
391 (rest.to_string(), "plpgsql".to_string())
392 }
393 } else {
394 (rest.to_string(), "plpgsql".to_string())
395 };
396
397 Ok((
398 "",
399 Qail {
400 action: Action::Do,
401 table: language,
402 payload: Some(body),
403 ..Default::default()
404 },
405 ))
406}
407
408fn parse_session_command(input: &str) -> IResult<&str, Qail> {
409 let (input, _) = tag_no_case("session").parse(input)?;
410 let (input, _) = multispace1(input)?;
411
412 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("set").parse(input) {
414 let (input, _) = multispace1(input)?;
415 let (input, key) = parse_identifier(input)?;
416 let (input, _) = multispace0(input)?;
417 let (input, _) = opt(char('=')).parse(input)?;
418 let value = input.trim().trim_end_matches(';').trim();
419 if value.is_empty() {
420 return Err(nom::Err::Error(nom::error::Error::new(
421 input,
422 nom::error::ErrorKind::Eof,
423 )));
424 }
425 let value = strip_matching_quotes(value);
426 return Ok((
427 "",
428 Qail {
429 action: Action::SessionSet,
430 table: key.to_string(),
431 payload: Some(value.to_string()),
432 ..Default::default()
433 },
434 ));
435 }
436
437 if let Ok((input, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("show").parse(input) {
439 let (input, _) = multispace1(input)?;
440 let (input, key) = parse_identifier(input)?;
441 let trailing = input.trim().trim_end_matches(';').trim();
442 if !trailing.is_empty() {
443 return Err(nom::Err::Error(nom::error::Error::new(
444 input,
445 nom::error::ErrorKind::Tag,
446 )));
447 }
448 return Ok((
449 "",
450 Qail {
451 action: Action::SessionShow,
452 table: key.to_string(),
453 ..Default::default()
454 },
455 ));
456 }
457
458 let (input, _) = tag_no_case("reset").parse(input)?;
460 let (input, _) = multispace1(input)?;
461 let (input, key) = parse_identifier(input)?;
462 let trailing = input.trim().trim_end_matches(';').trim();
463 if !trailing.is_empty() {
464 return Err(nom::Err::Error(nom::error::Error::new(
465 input,
466 nom::error::ErrorKind::Tag,
467 )));
468 }
469 Ok((
470 "",
471 Qail {
472 action: Action::SessionReset,
473 table: key.to_string(),
474 ..Default::default()
475 },
476 ))
477}
478
/// Removes one layer of matching surrounding quotes (`'...'` or `"..."`) from `s`.
///
/// If `s` has fewer than two characters, or its first and last characters are
/// not the same quote character, it is returned unchanged. The result borrows
/// from `s`.
fn strip_matching_quotes(s: &str) -> &str {
    ['\'', '"']
        .iter()
        .find_map(|&quote| {
            s.strip_prefix(quote)
                .and_then(|inner| inner.strip_suffix(quote))
        })
        .unwrap_or(s)
}