1use winnow::combinator::{alt, repeat, trace};
11use winnow::error::ParserError;
12use winnow::stream::{AsBStr, AsChar, Compare, Stream, StreamIsPartial};
13use winnow::token::{any, literal};
14use winnow::Parser;
15
16use crate::types::Element;
17
18use super::bind::bind;
19use super::command::{command_body, command_kind};
20use super::compose::compose;
21
22fn macro_invocation<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
26where
27 Input: StreamIsPartial + Stream + Compare<&'i str>,
28 <Input as Stream>::Slice: AsBStr,
29 <Input as Stream>::Token: AsChar + Clone,
30 Error: ParserError<Input>,
31{
32 trace("macro_invocation", move |input: &mut Input| {
33 literal(":").parse_next(input)?;
34
35 alt((
36 literal("bind(").flat_map(|_| bind).map(Element::Bind),
37 literal("compose(")
38 .flat_map(|_| compose)
39 .map(Element::Compose),
40 |input: &mut Input| {
41 let kind = command_kind(input)?;
42 let cmd = command_body(input, kind)?;
43 Ok(Element::Command(cmd))
44 },
45 ))
46 .parse_next(input)
47 })
48 .parse_next(input)
49}
50
51fn sql_literal<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
57where
58 Input: StreamIsPartial + Stream + Compare<&'i str>,
59 <Input as Stream>::Slice: AsBStr,
60 <Input as Stream>::Token: AsChar + Clone,
61 Error: ParserError<Input>,
62{
63 trace("sql_literal", move |input: &mut Input| {
64 let mut sql = String::new();
65 let mut consumed_comment = false;
66
67 loop {
68 let checkpoint = input.checkpoint();
70 if literal::<_, _, Error>(":").parse_next(input).is_ok() {
71 let is_macro = alt((
73 literal::<_, Input, Error>("bind(").void(),
74 literal::<_, Input, Error>("compose(").void(),
75 literal::<_, Input, Error>("count(").void(),
76 literal::<_, Input, Error>("union(").void(),
77 ))
78 .parse_next(input)
79 .is_ok();
80
81 input.reset(&checkpoint);
83
84 if is_macro {
85 break;
86 }
87 } else {
88 input.reset(&checkpoint);
89 }
90
91 match any::<_, Error>.parse_next(input) {
93 Ok(c) => {
94 let ch = c.as_char();
95 if ch == '#' {
96 consumed_comment = true;
98 loop {
99 match any::<_, Error>.parse_next(input) {
100 Ok(c) if c.clone().as_char() == '\n' => break,
101 Ok(_) => continue,
102 Err(_) => break, }
104 }
105 } else {
106 sql.push(ch);
107 }
108 }
109 Err(_) => break, }
111 }
112
113 if sql.is_empty() && !consumed_comment {
114 return Err(ParserError::from_input(input));
115 }
116
117 Ok(Element::Sql(sql))
118 })
119 .parse_next(input)
120}
121
122fn element<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
124where
125 Input: StreamIsPartial + Stream + Compare<&'i str>,
126 <Input as Stream>::Slice: AsBStr,
127 <Input as Stream>::Token: AsChar + Clone,
128 Error: ParserError<Input>,
129{
130 trace("element", move |input: &mut Input| {
131 alt((macro_invocation, sql_literal)).parse_next(input)
132 })
133 .parse_next(input)
134}
135
136pub fn template<'i, Input, Error>(input: &mut Input) -> Result<Vec<Element>, Error>
140where
141 Input: StreamIsPartial + Stream + Compare<&'i str>,
142 <Input as Stream>::Slice: AsBStr,
143 <Input as Stream>::Token: AsChar + Clone,
144 Error: ParserError<Input>,
145{
146 trace("template", move |input: &mut Input| {
147 let elements: Vec<Element> = repeat(0.., element).parse_next(input)?;
148 Ok(elements)
149 })
150 .parse_next(input)
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Binding, CommandKind, ComposeRef, ComposeTarget};
    use std::path::PathBuf;
    use winnow::error::ContextError;

    type TestInput<'a> = &'a str;

    /// Runs `template` over `source` and panics if parsing fails.
    fn parse(source: &str) -> Vec<Element> {
        let mut input: TestInput = source;
        template::<_, ContextError>.parse_next(&mut input).unwrap()
    }

    /// Builds the expected `Element::Sql` for a piece of literal text.
    fn sql(text: &'static str) -> Element {
        Element::Sql(text.into())
    }

    /// Builds the expected `Element::Bind` for a bare `:bind(name)`:
    /// no value bounds and not nullable.
    fn plain_bind(name: &'static str) -> Element {
        Element::Bind(Binding {
            name: name.into(),
            min_values: None,
            max_values: None,
            nullable: false,
        })
    }

    #[test]
    fn test_plain_sql() {
        assert_eq!(
            parse("SELECT id, name FROM users"),
            vec![sql("SELECT id, name FROM users")]
        );
    }

    #[test]
    fn test_sql_with_bind() {
        assert_eq!(
            parse("SELECT * FROM users WHERE id = :bind(user_id)"),
            vec![sql("SELECT * FROM users WHERE id = "), plain_bind("user_id")]
        );
    }

    #[test]
    fn test_sql_with_compose() {
        let expected_compose = Element::Compose(ComposeRef {
            target: ComposeTarget::Path(PathBuf::from("templates/get_user.tql")),
            slots: vec![],
        });
        assert_eq!(
            parse("SELECT COUNT(*) FROM (\n :compose(templates/get_user.tql)\n)"),
            vec![sql("SELECT COUNT(*) FROM (\n "), expected_compose, sql("\n)")]
        );
    }

    #[test]
    fn test_multiple_binds() {
        assert_eq!(
            parse("WHERE id = :bind(user_id) AND active = :bind(active)"),
            vec![
                sql("WHERE id = "),
                plain_bind("user_id"),
                sql(" AND active = "),
                plain_bind("active"),
            ]
        );
    }

    #[test]
    fn test_colon_not_a_macro() {
        // A colon that does not introduce a known macro stays in the SQL text.
        assert_eq!(
            parse("SELECT '10:30' FROM t"),
            vec![sql("SELECT '10:30' FROM t")]
        );
    }

    #[test]
    fn test_command_in_template() {
        let parsed = parse(":count(templates/get_user.tql)");
        assert_eq!(parsed.len(), 1);
        match &parsed[0] {
            Element::Command(cmd) => {
                assert_eq!(cmd.kind, CommandKind::Count);
                assert_eq!(cmd.sources, vec![PathBuf::from("templates/get_user.tql")]);
            }
            other => panic!("expected Command, got {:?}", other),
        }
    }

    #[test]
    fn test_full_template() {
        let parsed = parse(
            "SELECT id, name, email\nFROM users\nWHERE id = :bind(user_id)\n AND active = :bind(active);",
        );
        assert_eq!(
            parsed,
            vec![
                sql("SELECT id, name, email\nFROM users\nWHERE id = "),
                plain_bind("user_id"),
                sql("\n AND active = "),
                plain_bind("active"),
                sql(";"),
            ]
        );
    }

    #[test]
    fn test_semicolon_after_bind() {
        assert_eq!(
            parse("WHERE id = :bind(user_id);"),
            vec![sql("WHERE id = "), plain_bind("user_id"), sql(";")]
        );
    }

    #[test]
    fn test_empty_input() {
        assert!(parse("").is_empty());
    }

    #[test]
    fn test_comment_standalone_line() {
        assert_eq!(parse("# comment\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_inline() {
        // The comment is dropped together with its terminating newline.
        assert_eq!(
            parse("SELECT 1; # comment\nSELECT 2;"),
            vec![sql("SELECT 1; SELECT 2;")]
        );
    }

    #[test]
    fn test_comment_with_macro_text() {
        // Macro-looking text inside a comment is discarded, not parsed.
        assert_eq!(parse("# :bind(x)\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_before_macro() {
        assert_eq!(
            parse("# get user\nSELECT * FROM users WHERE id = :bind(id);"),
            vec![
                sql("SELECT * FROM users WHERE id = "),
                plain_bind("id"),
                sql(";"),
            ]
        );
    }

    #[test]
    fn test_only_comments() {
        // A skipped comment with no SQL text still yields one empty Sql element.
        assert_eq!(parse("# just a comment"), vec![Element::Sql(String::new())]);
    }

    #[test]
    fn test_multiple_comment_lines() {
        assert_eq!(
            parse("# line 1\n# line 2\nSELECT 1;"),
            vec![sql("SELECT 1;")]
        );
    }

    #[test]
    fn test_comment_at_eof_no_newline() {
        assert_eq!(parse("SELECT 1;\n# trailing"), vec![sql("SELECT 1;\n")]);
    }

    #[test]
    fn test_at_sign_in_sql_literal() {
        assert_eq!(
            parse("SELECT @variable FROM t"),
            vec![sql("SELECT @variable FROM t")]
        );
    }

    #[test]
    fn test_comments_before_macro() {
        // The leading comments produce an empty Sql element ahead of the bind.
        assert_eq!(
            parse("# comment line 1\n# comment line 2\n:bind(id)"),
            vec![Element::Sql(String::new()), plain_bind("id")]
        );
    }
}