1use nom::{
16 IResult, Parser,
17 branch::alt,
18 bytes::complete::{tag, tag_no_case, take_while1},
19 character::complete::{char, multispace0, multispace1, not_line_ending},
20 combinator::map,
21 multi::{many0, separated_list0},
22};
23use std::collections::HashSet;
24
/// A parsed query file: the ordered list of `query` / `execute` definitions it contains.
///
/// Built with [`QueryFile::parse`]; a definition can be looked up by name with
/// [`QueryFile::find_query`].
#[derive(Debug, Clone, Default)]
pub struct QueryFile {
    /// Definitions in the order they appear in the source text. When produced by
    /// `QueryFile::parse`, no two share a name (names are compared ignoring ASCII case).
    pub queries: Vec<QueryDef>,
}
31
/// One `query` or `execute` definition parsed from a query file.
#[derive(Debug, Clone)]
pub struct QueryDef {
    /// Definition name: an ASCII letter or `_` followed by ASCII alphanumerics or `_`.
    /// The parser rejects files in which two definitions share a name, ignoring ASCII case.
    pub name: String,
    /// Declared parameters in source order; the parser rejects duplicate parameter names
    /// (compared ignoring ASCII case).
    pub params: Vec<QueryParam>,
    /// `Some` for `query` definitions (taken from the `-> Type` clause), `None` for
    /// `execute` definitions, which declare no result type.
    pub return_type: Option<ReturnType>,
    /// Body text following the `:`, trimmed. The parser requires it to be non-empty and
    /// accepted by the parent module's `parse` function.
    pub body: String,
    /// `true` when the definition was introduced with the `execute` keyword.
    pub is_execute: bool,
}
46
/// A single `name: Type` parameter of a query definition.
///
/// Derives `PartialEq`/`Eq` so parsed parameters can be compared directly, e.g. with
/// `assert_eq!`, instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryParam {
    /// Parameter name: an ASCII letter or `_` followed by ASCII alphanumerics or `_`.
    pub name: String,
    /// The Rust type expression exactly as written, e.g. `Uuid` or `Option<Vec<String>>`.
    pub typ: String,
}
55
/// The declared result shape of a `query` definition, taken from its `-> Type` clause.
///
/// The parser recognizes an outermost `Vec<...>` or `Option<...>` wrapper written without a
/// path prefix and stores only the wrapped type. Any other type, including path-qualified
/// forms such as `std::vec::Vec<T>`, is stored unchanged as [`ReturnType::Single`].
///
/// Derives `PartialEq`/`Eq` so return types can be compared directly rather than through
/// `matches!` patterns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReturnType {
    /// A return type not written as `Vec<T>` or `Option<T>`; holds the full type text.
    Single(String),
    /// A return type written as `Vec<T>`; holds the text of `T`.
    Vec(String),
    /// A return type written as `Option<T>`; holds the text of `T`.
    Option(String),
}
66
67impl QueryFile {
68 pub fn parse(input: &str) -> Result<Self, String> {
70 match parse_query_file(input) {
71 Ok(("", qf)) => Ok(qf),
72 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
73 Err(e) => Err(format!("Parse error: {:?}", e)),
74 }
75 }
76
77 pub fn find_query(&self, name: &str) -> Option<&QueryDef> {
79 self.queries
80 .iter()
81 .find(|q| q.name.eq_ignore_ascii_case(name))
82 }
83}
84
85fn identifier(input: &str) -> IResult<&str, &str> {
91 let (remaining, ident) =
92 take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_').parse(input)?;
93 if ident
94 .chars()
95 .next()
96 .is_some_and(|c| c.is_ascii_alphabetic() || c == '_')
97 {
98 Ok((remaining, ident))
99 } else {
100 Err(nom::Err::Error(nom::error::Error::new(
101 input,
102 nom::error::ErrorKind::TakeWhile1,
103 )))
104 }
105}
106
107fn rust_type_expr(input: &str) -> IResult<&str, &str> {
108 let mut angle_depth = 0usize;
109 let mut end = None;
110
111 for (idx, ch) in input.char_indices() {
112 match ch {
113 '<' => {
114 angle_depth += 1;
115 }
116 '>' => {
117 let Some(next) = angle_depth.checked_sub(1) else {
118 return Err(nom::Err::Error(nom::error::Error::new(
119 input,
120 nom::error::ErrorKind::TakeWhile1,
121 )));
122 };
123 angle_depth = next;
124 }
125 ',' | ')' if angle_depth == 0 => {
126 end = Some(idx);
127 break;
128 }
129 ':' if angle_depth == 0
130 && !input[..idx].ends_with(':')
131 && !input[idx + ch.len_utf8()..].starts_with(':') =>
132 {
133 end = Some(idx);
134 break;
135 }
136 c if c.is_whitespace() && angle_depth == 0 => {
137 end = Some(idx);
138 break;
139 }
140 c if c.is_ascii_alphanumeric() || matches!(c, '_' | ':' | '[' | ']' | '.') => {}
141 _ => {
142 return Err(nom::Err::Error(nom::error::Error::new(
143 input,
144 nom::error::ErrorKind::TakeWhile1,
145 )));
146 }
147 }
148 }
149
150 let end = end.unwrap_or(input.len());
151 if end == 0 || angle_depth != 0 {
152 return Err(nom::Err::Error(nom::error::Error::new(
153 input,
154 nom::error::ErrorKind::TakeWhile1,
155 )));
156 }
157 let typ = &input[..end];
158 if !validate_rust_type_expr(typ) {
159 return Err(nom::Err::Error(nom::error::Error::new(
160 input,
161 nom::error::ErrorKind::TakeWhile1,
162 )));
163 }
164 Ok((&input[end..], typ))
165}
166
167fn validate_rust_type_expr(typ: &str) -> bool {
168 validate_rust_type_generics(typ)
169 && validate_rust_type_paths(typ)
170 && validate_rust_type_group_adjacency(typ)
171}
172
/// Checks that every `<...>` group is closed and holds non-empty, comma-separated arguments.
///
/// Whitespace is ignored. Rejects empty groups (`Vec<>`), leading or trailing commas
/// inside a group (`Vec<,T>`, `Vec<T,>`), a `>` with nothing to close, and unclosed `<`.
/// Commas and other characters outside any group are not inspected here.
fn validate_rust_type_generics(typ: &str) -> bool {
    // One entry per currently open `<`. The entry is `true` once the current argument
    // slot holds something, and `false` while the slot is still empty (just opened, or
    // right after a comma).
    let mut open_groups: Vec<bool> = Vec::new();

    for ch in typ.chars() {
        if ch.is_whitespace() {
            continue;
        }
        match ch {
            '<' => open_groups.push(false),
            '>' => {
                if open_groups.pop() != Some(true) {
                    return false;
                }
                // The closed group is itself an argument of its enclosing group.
                if let Some(enclosing) = open_groups.last_mut() {
                    *enclosing = true;
                }
            }
            _ => {
                let Some(slot_filled) = open_groups.last_mut() else {
                    continue;
                };
                if ch == ',' {
                    if !*slot_filled {
                        return false;
                    }
                    *slot_filled = false;
                } else {
                    *slot_filled = true;
                }
            }
        }
    }

    open_groups.is_empty()
}
211
212fn validate_rust_type_paths(typ: &str) -> bool {
213 let mut token = String::new();
214
215 for ch in typ.chars() {
216 if ch.is_ascii_alphanumeric() || matches!(ch, '_' | ':' | '.') {
217 token.push(ch);
218 continue;
219 }
220
221 if !token.is_empty() {
222 if !validate_rust_type_path_token(&token) {
223 return false;
224 }
225 token.clear();
226 }
227
228 if !(ch.is_whitespace() || matches!(ch, '<' | '>' | ',' | '[' | ']')) {
229 return false;
230 }
231 }
232
233 token.is_empty() || validate_rust_type_path_token(&token)
234}
235
/// Checks one path token such as `std::vec::Vec` or `Uuid`.
///
/// The token must be non-empty, must not start or end with `:` or `.`, and must not
/// contain `:::`, `..`, `.:` or `:.`. Every segment between `::` or `.` separators must be
/// an identifier: an ASCII letter or `_`, followed by ASCII alphanumerics or `_`.
fn validate_rust_type_path_token(token: &str) -> bool {
    const MALFORMED_SEPARATORS: [&str; 4] = [":::", "..", ".:", ":."];

    let well_delimited = !token.is_empty()
        && !MALFORMED_SEPARATORS.iter().any(|bad| token.contains(*bad))
        && !token.starts_with([':', '.'])
        && !token.ends_with([':', '.']);
    if !well_delimited {
        return false;
    }

    let is_identifier = |segment: &str| {
        let mut chars = segment.chars();
        chars
            .next()
            .is_some_and(|first| first.is_ascii_alphabetic() || first == '_')
            && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
    };

    token
        .split("::")
        .flat_map(|piece| piece.split('.'))
        .all(is_identifier)
}
256
/// Checks where `<`, `>` and `,` may appear in a type expression. Whitespace is ignored.
///
/// Rules:
/// - `<` must directly follow an identifier character (the name being parameterized), so
///   `<T>`, `Vec<<T>>`, `A<T><U>` and `[<T>]` are rejected;
/// - `>` must close a non-empty argument list: it may not follow `<` or `,`, and may not
///   appear without an open `<`;
/// - a `>` may only be followed by another `>`, a `,`, a `]`, or the end of the type. This
///   stops a closed argument list from being glued to further path text, as in
///   `Option<User>Vec<X>`, which would otherwise be split into a garbage inner type;
/// - `,` is only allowed inside angle brackets, and never right after `<` or `,`;
/// - the expression may not end with `<` or `,`.
fn validate_rust_type_group_adjacency(typ: &str) -> bool {
    let mut prev_sig: Option<char> = None;
    let mut depth = 0usize;

    for ch in typ.chars().filter(|ch| !ch.is_whitespace()) {
        if prev_sig == Some('>') && !matches!(ch, '>' | ',' | ']') {
            return false;
        }
        match ch {
            '<' => {
                let follows_name =
                    prev_sig.is_some_and(|prev| prev.is_ascii_alphanumeric() || prev == '_');
                if !follows_name {
                    return false;
                }
                depth += 1;
            }
            '>' => {
                if matches!(prev_sig, None | Some('<' | ',')) {
                    return false;
                }
                let Some(next) = depth.checked_sub(1) else {
                    return false;
                };
                depth = next;
            }
            ',' if depth == 0 || matches!(prev_sig, None | Some('<' | ',')) => return false,
            _ => {}
        }
        prev_sig = Some(ch);
    }

    !matches!(prev_sig, Some('<' | ','))
}
287
288fn ws_and_comments(input: &str) -> IResult<&str, ()> {
290 let (input, _) = many0(alt((
291 map(multispace1, |_| ()),
292 map((tag("--"), not_line_ending), |_| ()),
293 )))
294 .parse(input)?;
295 Ok((input, ()))
296}
297
298fn parse_param(input: &str) -> IResult<&str, QueryParam> {
300 let (input, _) = multispace0(input)?;
301 let (input, name) = identifier(input)?;
302 let (input, _) = multispace0(input)?;
303 let (input, _) = char(':').parse(input)?;
304 let (input, _) = multispace0(input)?;
305 let (input, typ) = rust_type_expr(input)?;
306
307 Ok((
308 input,
309 QueryParam {
310 name: name.to_string(),
311 typ: typ.to_string(),
312 },
313 ))
314}
315
316fn parse_params(input: &str) -> IResult<&str, Vec<QueryParam>> {
318 let (input, _) = char('(').parse(input)?;
319 let (input, params) = separated_list0(char(','), parse_param).parse(input)?;
320 let (input, _) = multispace0(input)?;
321 let (input, _) = char(')').parse(input)?;
322
323 Ok((input, params))
324}
325
326fn parse_return_type(input: &str) -> IResult<&str, ReturnType> {
328 let (input, _) = multispace0(input)?;
329 let (input, _) = tag("->").parse(input)?;
330 let (input, _) = multispace0(input)?;
331
332 let (input, typ) = rust_type_expr(input)?;
333 if let Some(inner) = strip_outer_generic(typ, "Vec") {
334 return Ok((input, ReturnType::Vec(inner.to_string())));
335 }
336 if let Some(inner) = strip_outer_generic(typ, "Option") {
337 return Ok((input, ReturnType::Option(inner.to_string())));
338 }
339 Ok((input, ReturnType::Single(typ.to_string())))
340}
341
/// If `typ` is exactly `outer<INNER>`, returns `INNER` with surrounding whitespace trimmed.
///
/// Returns `None` in three cases:
/// - `typ` does not start with `outer<` or does not end with `>`;
/// - the inner text is empty or only whitespace;
/// - the final `>` does not close the `<` that follows `outer`. For example,
///   `Option<User>Vec<X>` must not be treated as `Option` wrapping `User>Vec<X`.
fn strip_outer_generic<'a>(typ: &'a str, outer: &str) -> Option<&'a str> {
    let inner = typ
        .strip_prefix(outer)?
        .strip_prefix('<')?
        .strip_suffix('>')?;

    // The inner text must be balanced on its own and never close more groups than it
    // opened; otherwise the outer brackets do not pair with each other.
    let mut depth = 0usize;
    for ch in inner.chars() {
        match ch {
            '<' => depth += 1,
            '>' => depth = depth.checked_sub(1)?,
            _ => {}
        }
    }

    let inner = inner.trim();
    (depth == 0 && !inner.is_empty()).then_some(inner)
}
349
350fn parse_body(input: &str) -> IResult<&str, &str> {
352 let (input, _) = multispace0(input)?;
353 let (input, _) = char(':').parse(input)?;
354 let (input, _) = multispace0(input)?;
355
356 let mut end = input.len();
358
359 for (i, _) in input.char_indices() {
360 if i == 0 || input.as_bytes().get(i.saturating_sub(1)) == Some(&b'\n') {
361 let line_rest = &input[i..];
363 let trimmed = line_rest.trim_start();
364 if trimmed.starts_with("query ") || trimmed.starts_with("execute ") {
365 let ws_len = line_rest.len() - trimmed.len();
367 end = i + ws_len;
368 break;
369 }
370 }
371 }
372
373 let body = input[..end].trim();
374 Ok((&input[end..], body))
375}
376
377fn parse_query_def(input: &str) -> IResult<&str, QueryDef> {
379 let (input, _) = ws_and_comments(input)?;
380
381 let (input, is_execute) = alt((
382 map(tag_no_case("query"), |_| false),
383 map(tag_no_case("execute"), |_| true),
384 ))
385 .parse(input)?;
386
387 let (input, _) = multispace1(input)?;
388 let (input, name) = identifier(input)?;
389 let (input, params) = parse_params(input)?;
390 if !query_params_are_unique(¶ms) {
391 return Err(nom::Err::Error(nom::error::Error::new(
392 input,
393 nom::error::ErrorKind::Verify,
394 )));
395 }
396
397 let (input, return_type) = if is_execute {
399 (input, None)
400 } else {
401 let (input, rt) = parse_return_type(input)?;
402 (input, Some(rt))
403 };
404
405 let (input, body) = parse_body(input)?;
406 if body.is_empty() || super::parse(body).is_err() {
407 return Err(nom::Err::Error(nom::error::Error::new(
408 input,
409 nom::error::ErrorKind::Verify,
410 )));
411 }
412
413 Ok((
414 input,
415 QueryDef {
416 name: name.to_string(),
417 params,
418 return_type,
419 body: body.to_string(),
420 is_execute,
421 },
422 ))
423}
424
425fn query_params_are_unique(params: &[QueryParam]) -> bool {
426 let mut names = HashSet::new();
427 params
428 .iter()
429 .all(|param| names.insert(param.name.to_ascii_lowercase()))
430}
431
432fn parse_query_file(input: &str) -> IResult<&str, QueryFile> {
434 let (input, _) = ws_and_comments(input)?;
435 let (input, queries) = many0(parse_query_def).parse(input)?;
436 let (input, _) = ws_and_comments(input)?;
437 if !query_names_are_unique(&queries) {
438 return Err(nom::Err::Error(nom::error::Error::new(
439 input,
440 nom::error::ErrorKind::Verify,
441 )));
442 }
443
444 Ok((input, QueryFile { queries }))
445}
446
447fn query_names_are_unique(queries: &[QueryDef]) -> bool {
448 let mut names = HashSet::new();
449 queries
450 .iter()
451 .all(|query| names.insert(query.name.to_ascii_lowercase()))
452}
453
// Unit tests for the query-file parser. Each fixture is a complete query file. Leading
// indentation inside the raw strings does not matter: the parser skips whitespace between
// definitions and trims each body.
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: one `query` with a single parameter and an unwrapped return type.
    #[test]
    fn test_parse_simple_query() {
        let input = r#"
        query find_user(id: Uuid) -> User:
            get users where id = :id
        "#;

        let qf = QueryFile::parse(input).expect("parse failed");
        assert_eq!(qf.queries.len(), 1);

        let q = &qf.queries[0];
        assert_eq!(q.name, "find_user");
        assert!(!q.is_execute);
        assert_eq!(q.params.len(), 1);
        assert_eq!(q.params[0].name, "id");
        assert_eq!(q.params[0].typ, "Uuid");
        assert!(matches!(q.return_type, Some(ReturnType::Single(ref t)) if t == "User"));
        assert!(q.body.contains("get users"));
    }

    // `-> Vec<T>` is classified as `ReturnType::Vec` holding `T`.
    #[test]
    fn test_parse_vec_return() {
        let input = r#"
        query list_orders(user_id: Uuid) -> Vec<Order>:
            get orders where user_id = :user_id order by created_at desc
        "#;

        let qf = QueryFile::parse(input).expect("parse failed");
        let q = &qf.queries[0];
        assert!(matches!(q.return_type, Some(ReturnType::Vec(ref t)) if t == "Order"));
    }

    // `-> Option<T>` is classified as `ReturnType::Option` holding `T`.
    #[test]
    fn test_parse_option_return() {
        let input = r#"
        query find_optional(email: String) -> Option<User>:
            get users where email = :email limit 1
        "#;

        let qf = QueryFile::parse(input).expect("parse failed");
        let q = &qf.queries[0];
        assert!(matches!(q.return_type, Some(ReturnType::Option(ref t)) if t == "User"));
    }

    // Path-qualified and nested generic parameter types are kept verbatim; only the
    // outermost wrapper of the return type is stripped.
    #[test]
    fn test_parse_generic_param_and_nested_return_types() {
        let input = r#"
        query find_many(ids: std::vec::Vec<Uuid>, tags: Option<Vec<String>>) -> Option<Vec<User>>:
            get users where id in :ids
        "#;

        let qf = QueryFile::parse(input).expect("parse failed");
        let q = &qf.queries[0];
        assert_eq!(q.params[0].typ, "std::vec::Vec<Uuid>");
        assert_eq!(q.params[1].typ, "Option<Vec<String>>");
        assert!(matches!(q.return_type, Some(ReturnType::Option(ref t)) if t == "Vec<User>"));
    }

    // A `<` that is never closed must not be accepted as a parameter type.
    #[test]
    fn test_parse_rejects_unbalanced_generic_param_type() {
        let input = r#"
        query broken(ids: Vec<Uuid) -> Vec<User>:
            get users where id in :ids
        "#;

        let err = QueryFile::parse(input).expect_err("unbalanced generic must fail");
        assert!(err.contains("Parse error") || err.contains("Unexpected content"));
    }

    // Query and parameter names may not start with a digit.
    #[test]
    fn test_parse_rejects_invalid_query_and_param_identifiers() {
        let invalid_query_name = r#"
        query 1find_user(id: Uuid) -> User:
            get users where id = :id
        "#;
        QueryFile::parse(invalid_query_name)
            .expect_err("query names must be valid Rust identifiers");

        let invalid_param_name = r#"
        query find_user(1id: Uuid) -> User:
            get users where id = :id
        "#;
        QueryFile::parse(invalid_param_name)
            .expect_err("parameter names must be valid Rust identifiers");
    }

    // Empty generic argument lists and leading/trailing commas inside them are rejected,
    // both in parameter types and in return types.
    #[test]
    fn test_parse_rejects_empty_generic_type_arguments() {
        for input in [
            r#"
        query broken(ids: Vec<>) -> Vec<User>:
            get users where id in :ids
        "#,
            r#"
        query broken(id: Uuid) -> Option<>:
            get users where id = :id
        "#,
            r#"
        query broken(ids: Vec<,Uuid>) -> Vec<User>:
            get users where id in :ids
        "#,
            r#"
        query broken(ids: Vec<Uuid,>) -> Vec<User>:
            get users where id in :ids
        "#,
        ] {
            QueryFile::parse(input).expect_err("empty generic type arguments must fail");
        }
    }

    // Malformed paths: a leading dot, a `::::` run, a segment starting with a digit, and
    // two argument lists glued together.
    #[test]
    fn test_parse_rejects_malformed_type_paths() {
        for input in [
            r#"
        query broken(id: .Uuid) -> User:
            get users where id = :id
        "#,
            r#"
        query broken(id: std::::Uuid) -> User:
            get users where id = :id
        "#,
            r#"
        query broken(id: Uuid) -> Vec<1User>:
            get users where id = :id
        "#,
            r#"
        query broken(id: Uuid) -> Option<User><Other>:
            get users where id = :id
        "#,
        ] {
            QueryFile::parse(input).expect_err("malformed Rust type path must fail");
        }
    }

    // Semantic checks: duplicate query names, duplicate parameter names, empty bodies, and
    // bodies the parent module's parser rejects (raw SQL, malformed QAIL).
    #[test]
    fn test_parse_rejects_duplicate_names_params_and_invalid_bodies() {
        let duplicate_query = r#"
        query find_user(id: Uuid) -> User:
            get users where id = :id

        query find_user(email: String) -> User:
            get users where email = :email
        "#;
        QueryFile::parse(duplicate_query).expect_err("duplicate query names must fail");

        let duplicate_param = r#"
        query find_user(id: Uuid, id: String) -> User:
            get users where id = :id
        "#;
        QueryFile::parse(duplicate_param).expect_err("duplicate params must fail");

        let empty_body = r#"
        query find_user(id: Uuid) -> User:

        "#;
        QueryFile::parse(empty_body).expect_err("empty body must fail");

        let invalid_body = r#"
        query find_user(id: Uuid) -> User:
            SELECT * FROM users WHERE id = :id
        "#;
        QueryFile::parse(invalid_body).expect_err("raw SQL body must fail QAIL parsing");

        let malformed_qail_body = r#"
        query find_user(id: Uuid) -> User:
            get .users where id = :id
        "#;
        QueryFile::parse(malformed_qail_body).expect_err("malformed QAIL body must fail");
    }

    // `execute` definitions have no `-> Type` clause and are flagged as executes.
    #[test]
    fn test_parse_execute() {
        let input = r#"
        execute create_user(email: String, name: String):
            add users fields email, name values :email, :name
        "#;

        let qf = QueryFile::parse(input).expect("parse failed");
        let q = &qf.queries[0];
        assert!(q.is_execute);
        assert!(q.return_type.is_none());
        assert_eq!(q.params.len(), 2);
    }

    // Several definitions after a leading `--` comment are returned in source order.
    #[test]
    fn test_parse_multiple_queries() {
        let input = r#"
        -- User queries
        query find_user(id: Uuid) -> User:
            get users where id = :id

        query list_users() -> Vec<User>:
            get users order by created_at desc

        execute delete_user(id: Uuid):
            del users where id = :id
        "#;

        let qf = QueryFile::parse(input).expect("parse failed");
        assert_eq!(qf.queries.len(), 3);

        assert_eq!(qf.queries[0].name, "find_user");
        assert_eq!(qf.queries[1].name, "list_users");
        assert_eq!(qf.queries[2].name, "delete_user");
    }
}