fnsql_macro/
lib.rs

1//! The `fnsql` crate provides simple type-safe optional wrappers around SQL
//! queries. Instead of calling type-less `.query()` and `.execute()`, you call
//! auto-generated, uniquely named wrappers that are strongly typed, such as
//! `.query_<name>()` and `.execute_<name>()`. You still specify the input and
//! output types manually, but only once: together with the query, and separately
//! from the code that uses it.
7//!
8//! It's a very simple implementation that doesn't force any schema or ORM down
9//! your throat, so if you are already using the `rusqlite` or `postgres` crates,
10//! you can gradually replace your type-less queries with the type-ful wrappers,
11//! or migrate from an opinionated ORM.
12//!
13//! The way to generate these wrappers is to specify input and output types for
14//! each one of the queries. For example, consider the following definitions
15//! specified with `fnsql`, based on the `rusqlite` example:
16//!
17//! ```rust
18//! fnsql::fnsql! {
19//!     #[rusqlite, test]
20//!     create_table_pet() {
21//!         "CREATE TABLE pet (
22//!               id      INTEGER PRIMARY KEY,
23//!               name    TEXT NOT NULL,
24//!               data    BLOB
25//!         )"
26//!     }
27//!
28//!     #[rusqlite, test(with=[create_table_pet])]
29//!     insert_new_pet(name: String, data: Option<Vec<u8>>) {
30//!         "INSERT INTO pet (name, data) VALUES (:name, :data)"
31//!     }
32//!
33//!     #[rusqlite, test(with=[create_table_pet])]
34//!     get_pet_id_data(name: Option<String>) -> [(i32, Option<Vec<u8>>, String)] {
35//!         "SELECT id, data, name FROM pet WHERE pet.name = :name"
36//!     }
37//! }
38//! ```
39//!
//! The definitions can be used as follows (the commented-out code shows the
//! equivalent calls through the previous type-less interfaces):
42//!
43//! ```rust ignore
44//! let mut conn = rusqlite::Connection::open_in_memory()?;
45//!
46//! conn.execute_create_table_pet()?;
47//! // conn.execute(
48//! //    "CREATE TABLE pet (
49//! //               id              INTEGER PRIMARY KEY,
50//! //               name            TEXT NOT NULL,
51//! //               data            BLOB
52//! //               )",
53//! //     [],
54//! // )?;
55//!
56//! conn.execute_insert_new_pet(&me.name, &me.data)?;
57//! // conn.execute(
58//! //     "INSERT INTO pet (name, data) VALUES (?1, ?2)",
59//! //     params![me.name, me.data],
60//! // )?;
61//!
62//! let mut stmt = conn.prepare_get_pet_id_data()?;
63//! // let mut stmt = conn.prepare("SELECT id, data, name FROM pet WHERE pet.name = :name")?;
64//!
65//! let pet_iter = stmt.query_map(&Some("Max".to_string()), |id, data, name| {
66//!     Ok::<_, rusqlite::Error>(Pet {
67//!         id,
68//!         data,
69//!         name,
70//!     })
71//! })?;
72//! // let pet_iter = stmt.query_map([(":name", "Max".to_string())], |row| {
73//! //     Ok(Pet {
74//! //         id: row.get(0)?,
//! //         data: row.get(1)?,
//! //         name: row.get(2)?,
77//! //     })
78//! // })?;
79//! ```
80//!
81//! ## Technical discussion
82//!
83//! The idea with this crate is to allow direct SQL usage but never use inline
84//! queries or have type inference at the call-site. Instead, we declare each query
85//! on top-level, giving each a name and designated accessor methods that derive
86//! from the name.
87//!
//! - The types of named variables are given in a Rust-like syntax.
//! - The type of the returned row is also provided.
//! - `fnsql` makes no compile-time guarantee that the types match the query;
//!   mismatches are discovered by `cargo test`, without writing additional code.
//! - `fnsql` generates a test for each query marked with the `test` attribute.
//! - `Arbitrary` is used to generate parameter values.
//! - If testing one query depends on another, you can specify that with `test(with=[..])`.
95//!
96//! ```text
97//! running 3 tests
98//! test auto_create_table_pet ... ok
99//! test auto_insert_new_pet ... ok
100//! test auto_get_pet_id_data ... ok
101//! ```
102//!
//! The generated query tests need the following dev-dependency to compile:
104//!
105//! ```toml
106//! [dev-dependencies]
107//! arbitrary = { version = "1", features = ["derive"] }
108//! ```
109//!
110//! ## Limitations
111//!
112//!  * Though it <i>does</i> provide auto-generated tests for validating queries in `cargo test`,
113//!    it does not do any compile-time validation based on the SQL query string.
114//!  * It only supports `rusqlite` and `postgres` for now.
115
116extern crate proc_macro;
117
118use std::collections::HashMap;
119
120use proc_macro::TokenStream;
121use proc_macro2::TokenStream as Tokens;
122use quote::{quote, ToTokens};
123use regex::{Regex, Captures};
124use syn::{
125    braced, bracketed, parenthesized,
126    parse::{Parse, ParseStream},
127    parse_macro_input,
128    punctuated::Punctuated,
129    token, Ident, Token, LitStr,
130};
131
132struct Queries {
133    list: Vec<Query>,
134}
135
136impl Parse for Queries {
137    fn parse(input: ParseStream) -> syn::Result<Self> {
138        let mut list = vec![];
139        while !input.is_empty() {
140            list.push(input.parse()?)
141        }
142
143        Ok(Queries { list })
144    }
145}
146
147enum Kind {
148    Rusqlite,
149    PostgreSQL,
150}
151
152struct Query {
153    name: Ident,
154    params: Vec<Param>,
155    outputs: Vec<Output>,
156    query: syn::LitStr,
157    kind: Kind,
158    test: Option<Vec<String>>,
159    named: bool,
160}
161
162impl Parse for Query {
163    fn parse(input: ParseStream) -> syn::Result<Self> {
164        let mut kind = None;
165        let mut test = None;
166        let mut named = false;
167
168        if input.peek(Token![#]) {
169            let _: Token![#] = input.parse()?;
170            let content;
171            let _ = bracketed!(content in input);
172            let list: Punctuated<Attr, Token![,]> = content.parse_terminated(Parse::parse)?;
173
174            for attr in list {
175                match attr {
176                    Attr::Kind(attr_kind) => {
177                        kind = Some(attr_kind);
178                    }
179                    Attr::Test(test_attrs) => {
180                        if test.is_none() {
181                            test = Some(vec![]);
182                        }
183                        for test_attr in test_attrs {
184                            match test_attr {
185                                TestAttr::With(v) => {
186                                    test.as_mut().unwrap().extend(v);
187                                }
188                            }
189                        }
190                    }
191                    Attr::Named => {
192                        named = true;
193                    },
194                }
195            }
196        };
197
198        let name = input.parse()?;
199        let kind = match kind {
200            None => panic!("unknown SQL type. Supported: rusqlite"),
201            Some(kind) => kind,
202        };
203        let content;
204        let _ = parenthesized!(content in input);
205        let list: Punctuated<_, Token![,]> = content.parse_terminated(Parse::parse)?;
206        let params = list.into_iter().collect();
207
208        let outputs = if input.peek(Token![->]) {
209            let _: Token![->] = input.parse()?;
210
211            let content;
212            let _ = bracketed!(content in input);
213            {
214                let sub_content;
215                let _ = parenthesized!(sub_content in content);
216                let list: Punctuated<_, Token![,]> = sub_content.parse_terminated(Parse::parse)?;
217                list.into_iter().collect()
218            }
219        } else {
220            vec![]
221        };
222
223        let content;
224        let _ = braced!(content in input);
225        let query = content.parse::<syn::LitStr>()?;
226
227        Ok(Query {
228            name,
229            params,
230            outputs,
231            query,
232            kind,
233            test,
234            named,
235        })
236    }
237}
238
239impl Query {
240    fn prepend_name(&self, prefix: &'static str) -> Ident {
241        Ident::new(&format!("{}{}", prefix, &self.name), self.name.span())
242    }
243
244    fn params_declr(&self) -> Tokens {
245        let list: Vec<_> = self.params.iter().map(|x| x.expand_declr()).collect();
246        quote! { #(, #list)* }
247    }
248
249    fn outputs_declr(&self) -> Tokens {
250        let list: Vec<_> = self.outputs.iter().map(|x| x.expand_declr()).collect();
251        quote! { #(#list),* }
252    }
253
254    fn outputs_row_get_numbered(&self) -> Tokens {
255        let list: Vec<_> = self
256            .outputs
257            .iter()
258            .enumerate()
259            .map(|(i, _)| {
260                let i = syn::LitInt::new(&format!("{}", i), self.name.span());
261                quote! {row.get(#i)?}
262            })
263            .collect();
264
265        quote! { #(#list),* }
266    }
267
268    fn outputs_row_try_get_numbered(&self) -> Tokens {
269        let list: Vec<_> = self
270            .outputs
271            .iter()
272            .enumerate()
273            .map(|(i, _)| {
274                let i = syn::LitInt::new(&format!("{}", i), self.name.span());
275                quote! {row.try_get(#i)?}
276            })
277            .collect();
278
279        quote! { #(#list),* }
280    }
281
282    fn outputs_mapped_row_closure(&self) -> Tokens {
283        let list = self.outputs_row_get_numbered();
284        quote! { Ok(map(#list)) }
285    }
286
287    fn params_arbitrary(&self) -> (Tokens, Tokens) {
288        let mut gen_lets = vec![];
289        let mut params = vec![];
290
291        let _ = self
292            .params
293            .iter()
294            .enumerate()
295            .map(|(idx, param)| {
296                let ttype = &param.ttype;
297                let owned_ttype = if ttype.to_token_stream().to_string() == "str" {
298                    quote! {String}
299                } else if ttype.to_token_stream().to_string() == "[u8]" {
300                    quote! {Vec<u8>}
301                } else {
302                    quote! {#ttype}
303                };
304                let ident = Ident::new(&format!("i_{}", idx), self.name.span());
305
306                gen_lets.push(quote! {
307                    let #ident: #owned_ttype = arbitrary::Arbitrary::arbitrary(uns).unwrap();
308                });
309                params.push(quote! {&#ident});
310            })
311            .collect::<Vec<()>>();
312
313        (quote! { #(#gen_lets);* }, quote! { #(#params),* })
314    }
315
316    fn params_query(&self) -> Tokens {
317        let list: Vec<_> = self.params.iter().map(|x| x.expand_query(self)).collect();
318        if list.len() == 0 {
319            quote! { [] }
320        } else {
321            quote! { &[#(#list),*] }
322        }
323    }
324
325    fn params_query_ref(&self) -> Tokens {
326        let list: Vec<_> = self.params.iter().map(|x| x.expand_query(self)).collect();
327        if list.len() == 0 {
328            quote! { &[] }
329        } else {
330            quote! { &[#(#list),*] }
331        }
332    }
333
334    fn params_relay(&self) -> Tokens {
335        let list: Vec<_> = self
336            .params
337            .iter()
338            .map(|x| {
339                let name = &x.name;
340                quote! { #name }
341            })
342            .collect();
343        if list.len() == 0 {
344            quote! {}
345        } else {
346            quote! { #(#list),*, }
347        }
348    }
349
350    fn expand(&self) -> Tokens {
351        match self.kind {
352            Kind::Rusqlite => self.sqlite_expand(),
353            Kind::PostgreSQL => self.postgres_expand(),
354        }
355    }
356
357    fn postgres_expand(&self) -> Tokens {
358        #[allow(non_snake_case)]
359        let Client = self.prepend_name("Client_");
360        #[allow(non_snake_case)]
361        let Statement = self.prepend_name("Statement_");
362        let execute_name = self.prepend_name("execute_");
363        let execute_prepared_name = self.prepend_name("execute_prepared_");
364        let prepare_name = self.prepend_name("prepare_");
365        let prepare_cached_name = self.prepend_name("prepare_cached_");
366        let convert_row = self.prepend_name("convert_row_");
367        let query_name = self.prepend_name("query_");
368        let query_prepared_name = self.prepend_name("query_prepared_");
369        let query_one_name = self.prepend_name("query_one_");
370        let query_one_prepared_name = self.prepend_name("query_one_prepared_");
371        let query_opt_name = self.prepend_name("query_opt_");
372        let query_opt_prepared_name = self.prepend_name("query_opt_prepared_");
373        let params_declr = self.params_declr();
374        let params_query_ref = self.params_query_ref();
375        let outputs_declr = self.outputs_declr();
376        let row_try_get_numbered = self.outputs_row_try_get_numbered();
377
378        let query;
379        if self.named {
380            lazy_static::lazy_static! {
381                static ref RE: Regex = Regex::new(":([A-Za-z_][_A-Za-z0-9]*)($|[^_A-Za-z0-9])").unwrap();
382            }
383
384            let params: HashMap<_, _> = self
385                .params
386                .iter()
387                .enumerate()
388                .map(|(idx, param)| {
389                    (format!("{}", param.name), idx)
390                }).collect();
391
392            query = String::from(RE.replace_all(&self.query.value(), |captures: &Captures| {
393                let c1 = captures.get(1).unwrap().as_str();
394                let c2 = captures.get(2).unwrap().as_str();
395                match params.get(c1) {
396                    Some(idx) => format!("${}{}", idx + 1, c2),
397                    None => format!("{}{}", c1, c2),
398                }
399            }));
400        } else {
401            query = self.query.value();
402        };
403        let query = LitStr::new(query.as_str(), self.query.span());
404
405        #[cfg(feature = "prepare-cache")]
406        let (prepare_cached_decl, prepare_cached_impl) = {
407            let prepare_cached_decl = quote! {
408                fn #prepare_cached_name(&mut self, cache: &mut fnsql::postgres::Cache) -> Result<#Statement, postgres::Error>;
409            };
410
411            let prepare_cached_impl = quote! {
412                fn #prepare_cached_name(&mut self, cache: &mut fnsql::postgres::Cache) -> Result<#Statement, postgres::Error> {
413                    Ok(#Statement(cache.prepare(#query, self)?))
414                }
415            };
416
417            (prepare_cached_decl, prepare_cached_impl)
418        };
419
420        #[cfg(not(feature = "prepare-cache"))]
421        let (prepare_cached_decl, prepare_cached_impl) = {
422            (quote!{}, quote!{})
423        };
424
425        let defs = quote! {
426            #[allow(non_camel_case_types)]
427            pub struct #Statement(pub postgres::Statement);
428
429            #[allow(non_camel_case_types)]
430            pub trait #Client {
431                fn #prepare_name(&mut self) -> Result<#Statement, postgres::Error>;
432                #prepare_cached_decl
433                fn #execute_name(&mut self #params_declr) -> Result<u64, postgres::Error>;
434                fn #execute_prepared_name(&mut self, stmt: &#Statement #params_declr)
435                    -> Result<u64, postgres::Error>;
436                fn #query_name(&mut self #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error>;
437                fn #query_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error>;
438                fn #query_one_name(&mut self #params_declr) -> Result<(#outputs_declr), postgres::Error>;
439                fn #query_one_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<(#outputs_declr), postgres::Error>;
440                fn #query_opt_name(&mut self #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error>;
441                fn #query_opt_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error>;
442            }
443
444            pub fn #convert_row(row: postgres::Row) -> Result<(#outputs_declr), postgres::Error> {
445                Ok((#row_try_get_numbered))
446            }
447        };
448
449        let timpl = quote! {
450            fn #prepare_name(&mut self)  -> Result<#Statement, postgres::Error> {
451                self.prepare(#query).map(#Statement)
452            }
453
454            #prepare_cached_impl
455
456            fn #execute_name(&mut self #params_declr) -> Result<u64, postgres::Error> {
457                self.execute(#query, #params_query_ref)
458            }
459
460            fn #execute_prepared_name(&mut self, stmt: &#Statement #params_declr)
461                -> Result<u64, postgres::Error>
462            {
463                self.execute(&stmt.0, #params_query_ref)
464            }
465
466            fn #query_name(&mut self #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error> {
467                let result: Result<Vec<_>, postgres::Error> =
468                    self.query(#query, #params_query_ref)?.into_iter().map(#convert_row).collect();
469                result
470            }
471
472            fn #query_one_name(&mut self #params_declr) -> Result<(#outputs_declr), postgres::Error> {
473                Ok(#convert_row(self.query_one(#query, #params_query_ref)?)?)
474            }
475
476            fn #query_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error> {
477                let result: Result<Vec<_>, postgres::Error> =
478                    self.query(&stmt.0, #params_query_ref)?.into_iter().map(#convert_row).collect();
479                result
480            }
481
482            fn #query_one_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<(#outputs_declr), postgres::Error> {
483                Ok(#convert_row(self.query_one(&stmt.0, #params_query_ref)?)?)
484            }
485
486            fn #query_opt_name(&mut self #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error> {
487                match self.query_opt(#query, #params_query_ref)? {
488                    None => Ok(None),
489                    Some(x) => Ok(Some(#convert_row(x)?)),
490                }
491            }
492
493            fn #query_opt_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error> {
494                match self.query_opt(&stmt.0, #params_query_ref)? {
495                    None => Ok(None),
496                    Some(x) => Ok(Some(#convert_row(x)?)),
497                }
498            }
499        };
500
501        let test_code = self.test_code();
502
503        quote! {
504            #defs
505
506            impl #Client for postgres::Client {
507                #timpl
508            }
509
510            impl<'a> #Client for postgres::Transaction<'a> {
511                #timpl
512            }
513
514            #test_code
515        }
516    }
517
518    fn sqlite_expand(&self) -> Tokens {
519        let conn_trait_name = self.prepend_name("Connection_");
520        #[allow(non_snake_case)]
521        let StatementType = self.prepend_name("Statement_");
522        #[allow(non_snake_case)]
523        let CachedStatementType = self.prepend_name("CachedStatement_");
524        #[allow(non_snake_case)]
525        let MappedRows = self.prepend_name("MappedRows_");
526        #[allow(non_snake_case)]
527        let Rows = self.prepend_name("Rows_");
528        let prepare_name = self.prepend_name("prepare_");
529        let prepare_cached_name = self.prepend_name("prepare_cached_");
530        let execute_name = self.prepend_name("execute_");
531        let query_row_name = self.prepend_name("query_row_");
532        let params_declr = self.params_declr();
533        let outputs_declr = self.outputs_declr();
534        let row_closure = self.outputs_row_get_numbered();
535        let mapped_row_closure = self.outputs_mapped_row_closure();
536        let params_query = self.params_query();
537        let params_relay = self.params_relay();
538        let query = &self.query;
539
540        let test_code = self.test_code();
541
542        quote! {
543            #[allow(non_camel_case_types)]
544            pub trait #conn_trait_name {
545                fn #prepare_name(&self) -> rusqlite::Result<#StatementType<'_>>;
546                fn #prepare_cached_name(&self) -> rusqlite::Result<#CachedStatementType<'_>>;
547                fn #execute_name(&self #params_declr) -> rusqlite::Result<usize>;
548                fn #query_row_name<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
549                where
550                    F: FnMut(#outputs_declr) -> T;
551            }
552
553            impl #conn_trait_name for rusqlite::Connection {
554                fn #prepare_name(&self) -> rusqlite::Result<#StatementType<'_>> {
555                    self.prepare(#query).map(#StatementType)
556                }
557
558                fn #prepare_cached_name(&self) -> rusqlite::Result<#CachedStatementType<'_>> {
559                    self.prepare_cached(#query).map(#CachedStatementType)
560                }
561
562                fn #execute_name(&self #params_declr) -> rusqlite::Result<usize> {
563                    self.execute(#query, #params_query)
564                }
565
566                fn #query_row_name<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
567                where
568                    F: FnMut(#outputs_declr) -> T,
569                {
570                    let mut stmt = self.#prepare_name()?;
571                    stmt.query_row(#params_relay f)
572                }
573            }
574
575            #[allow(non_camel_case_types)]
576            pub struct #MappedRows<'stmt, F> {
577                rows: rusqlite::Rows<'stmt>,
578                map: F,
579            }
580
581            impl<'stmt, T, F> #MappedRows<'stmt, F>
582            where
583                F: FnMut(#outputs_declr) -> T
584            {
585                pub(crate) fn new(rows: rusqlite::Rows<'stmt>, f: F) -> Self {
586                    Self { rows, map: f }
587                }
588            }
589
590            impl<'stmt, T, F> Iterator for #MappedRows<'stmt, F>
591            where
592                F: FnMut(#outputs_declr) -> T
593            {
594                type Item = rusqlite::Result<T>;
595
596                fn next(&mut self) -> Option<rusqlite::Result<T>> {
597                    let map = &mut self.map;
598                    self.rows
599                        .next()
600                        .transpose()
601                        .map(|row_result| {
602                            row_result.and_then(|row| {
603                                #mapped_row_closure
604                            })
605                        })
606                }
607            }
608
609            #[allow(non_camel_case_types)]
610            pub struct #Rows<'stmt> {
611                rows: rusqlite::Rows<'stmt>,
612            }
613
614            impl<'stmt> #Rows<'stmt> {
615                pub(crate) fn new(rows: rusqlite::Rows<'stmt>) -> Self {
616                    Self { rows }
617                }
618            }
619
620            impl<'stmt> Iterator for #Rows<'stmt> {
621                type Item = rusqlite::Result<(#outputs_declr)>;
622
623                fn next(&mut self) -> Option<Self::Item> {
624                    self.rows
625                        .next()
626                        .transpose()
627                        .map(|row_result| {
628                            row_result.and_then(|row| {
629                                Ok((#row_closure))
630                            })
631                        })
632                }
633            }
634
635            #[allow(non_camel_case_types)]
636            pub struct #StatementType<'a>(pub rusqlite::Statement<'a>);
637
638            impl<'a> #StatementType<'a> {
639                fn query_map<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<#MappedRows<'_, F>>
640                where
641                    F: FnMut(#outputs_declr) -> T,
642                {
643                    let rows = self.0.query(#params_query)?;
644                    Ok(#MappedRows::new(rows, f))
645                }
646
647                fn query_row<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
648                where
649                    F: FnMut(#outputs_declr) -> T,
650                {
651                    let rows = self.query_map(#params_relay f)?;
652                    for item in rows {
653                        return Ok(item?);
654                    }
655                    Err(rusqlite::Error::QueryReturnedNoRows)
656                }
657
658                fn query(&mut self #params_declr) -> rusqlite::Result<#Rows<'_>> {
659                    let rows = self.0.query(#params_query)?;
660                    Ok(#Rows::new(rows))
661                }
662
663                fn execute(&mut self #params_declr) -> rusqlite::Result<()> {
664                    self.0.execute(#params_query)?;
665                    Ok(())
666                }
667            }
668
669            #[allow(non_camel_case_types)]
670            pub struct #CachedStatementType<'a>(pub rusqlite::CachedStatement<'a>);
671
672            impl<'a> #CachedStatementType<'a> {
673                fn query_map<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<#MappedRows<'_, F>>
674                where
675                    F: FnMut(#outputs_declr) -> T,
676                {
677                    let rows = self.0.query(#params_query)?;
678                    Ok(#MappedRows::new(rows, f))
679                }
680
681                fn query_row<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
682                where
683                    F: FnMut(#outputs_declr) -> T,
684                {
685                    let rows = self.query_map(#params_relay f)?;
686                    for item in rows {
687                        return Ok(item?);
688                    }
689                    Err(rusqlite::Error::QueryReturnedNoRows)
690                }
691
692                fn query(&mut self #params_declr) -> rusqlite::Result<#Rows<'_>> {
693                    let rows = self.0.query(#params_query)?;
694                    Ok(#Rows::new(rows))
695                }
696
697                fn execute(&mut self #params_declr) -> rusqlite::Result<()> {
698                    self.0.execute(#params_query)?;
699                    Ok(())
700                }
701            }
702
703            #test_code
704        }
705    }
706
707    fn test_code(&self) -> Tokens {
708        let test_name = self.prepend_name("auto_");
709        let testsetup_name = self.prepend_name("testsetup_");
710        let (params_arbit_prep, params_arbit) = self.params_arbitrary();
711        let execute_name = self.prepend_name("execute_");
712        let name = syn::LitStr::new(&self.name.to_string(), self.name.span());
713
714        let client_type = match self.kind {
715            Kind::Rusqlite => quote!{rusqlite::Connection},
716            Kind::PostgreSQL => quote!{postgres::Client},
717        };
718        let client_ref_type = match self.kind {
719            Kind::Rusqlite => quote!{&},
720            Kind::PostgreSQL => quote!{&mut},
721        };
722        let ignore_error = match self.kind {
723            Kind::Rusqlite => quote!{Err(rusqlite::Error::ExecuteReturnedResults) => {}},
724            Kind::PostgreSQL => quote!{},
725        };
726        let error_type = match self.kind {
727            Kind::Rusqlite => quote!{rusqlite::Error},
728            Kind::PostgreSQL => quote!{postgres::Error},
729        };
730        let open_client = match self.kind {
731            Kind::Rusqlite => quote!{
732                let conn = #client_type::open_in_memory()?;
733            },
734            Kind::PostgreSQL => quote!{let mut conn = {
735                let mut conn = fnsql::postgres::testing_client().expect("unable to connect testing client");
736                conn.execute("SET search_path TO pg_temp", &[]).unwrap();
737                conn
738            }; },
739        };
740
741        let test = if let Some(depends) = &self.test {
742            let depends = depends.iter().map(|name| {
743                let parent_testsetup_name =
744                    Ident::new(&format!("testsetup_{}", name), self.name.span());
745                quote! {
746                    #parent_testsetup_name(uns, deps, conn)?;
747                }
748            });
749            quote! {
750                #[cfg(test)]
751                fn #testsetup_name(
752                    uns: &mut arbitrary::Unstructured,
753                    deps: &mut std::collections::HashSet<&'static str>,
754                    conn: #client_ref_type #client_type) -> Result<(), #error_type>
755                {
756                    if !deps.insert(#name) {
757                        return Ok(());
758                    }
759
760                    #(#depends);*
761
762                    #params_arbit_prep;
763                    let r = conn.#execute_name(#params_arbit);
764                    match r {
765                        Ok(_) => {}
766                        #ignore_error
767                        Err(err) => {
768                            eprintln!("{:?}", err);
769                            Err(err)?;
770                        },
771                    }
772                    Ok(())
773                }
774
775                #[test]
776                fn #test_name() -> Result<(), #error_type> {
777                    #open_client;
778                    let mut deps = std::collections::HashSet::new();
779                    let raw_data: &[u8] = &[1, 2, 3];
780                    let mut unstructured = arbitrary::Unstructured::new(raw_data);
781
782                    #testsetup_name(&mut unstructured, &mut deps, #client_ref_type conn)?;
783                    Ok(())
784                }
785            }
786        } else {
787            quote! {}
788        };
789        test
790    }
791}
792
793struct Output {
794    ttype: syn::Type,
795}
796
797impl Parse for Output {
798    fn parse(input: ParseStream) -> syn::Result<Self> {
799        let ttype = input.parse()?;
800
801        Ok(Self { ttype })
802    }
803}
804
805impl Output {
806    fn expand_declr(&self) -> Tokens {
807        let ttype = &self.ttype;
808
809        quote! { #ttype }
810    }
811}
812
813struct Param {
814    name: Ident,
815    ttype: syn::Type,
816}
817
818impl Parse for Param {
819    fn parse(input: ParseStream) -> syn::Result<Self> {
820        let name = input.parse()?;
821        let _: Token![:] = input.parse()?;
822        let ttype = input.parse()?;
823
824        Ok(Self { name, ttype })
825    }
826}
827
828impl Param {
829    fn expand_declr(&self) -> Tokens {
830        let name = &self.name;
831        let ttype = &self.ttype;
832
833        quote! { #name: &#ttype }
834    }
835
836    fn expand_query(&self, query: &Query) -> Tokens {
837        let name = &self.name;
838        let specifier = syn::LitStr::new(&format!(":{}", name), name.span());
839
840        match query.kind {
841            Kind::Rusqlite => quote! { (#specifier, &#name as &dyn rusqlite::ToSql) },
842            Kind::PostgreSQL => quote! { &#name as &(dyn postgres::types::ToSql + Sync) }
843        }
844    }
845}
846
847enum Attr {
848    Kind(Kind),
849    Test(Vec<TestAttr>),
850    Named,
851}
852
853impl Parse for Attr {
854    fn parse(input: ParseStream) -> syn::Result<Self> {
855        let ident: Ident = input.parse()?;
856        if ident == "rusqlite" {
857            return Ok(Attr::Kind(Kind::Rusqlite));
858        }
859        if ident == "postgres" {
860            return Ok(Attr::Kind(Kind::PostgreSQL));
861        }
862        if ident == "named" {
863            return Ok(Attr::Named);
864        }
865        if ident == "test" {
866            let mut v = vec![];
867
868            if input.peek(token::Paren) {
869                let content;
870                let _ = parenthesized!(content in input);
871                let list: Punctuated<TestAttr, Token![,]> =
872                    content.parse_terminated(Parse::parse)?;
873                v = list.into_iter().collect();
874            };
875
876            return Ok(Attr::Test(v));
877        }
878        panic!("unknown attribute {}", ident);
879    }
880}
881
882enum TestAttr {
883    With(Vec<String>),
884}
885
886impl Parse for TestAttr {
887    fn parse(input: ParseStream) -> syn::Result<Self> {
888        let ident: Ident = input.parse()?;
889        if ident == "with" {
890            let mut v = vec![];
891
892            let _: Token![=] = input.parse()?;
893            let content;
894            let _ = bracketed!(content in input);
895            let list: Punctuated<Ident, Token![,]> = content.parse_terminated(Parse::parse)?;
896            for item in list {
897                v.push(item.to_string());
898            }
899
900            return Ok(TestAttr::With(v));
901        }
902
903        panic!("unknown test attribute {}", ident);
904    }
905}
906
907/// The general structure of the input to the `fnsql` macro is the following:
908///
909/// ```ignore
910/// fnsql! {
911///     #[<sql-engine-type>, [OPTIONAL: test(with=[other-function-a, other-function-b...])]]
912///     <function-name-a>(param1: type, param2: type...)
913///          [OPTIONAL: -> [(col a type, col b type, ...)]]
914///     {
915///         "SQL QUERY STRING"
916///     }
917///
918///     ...
919/// }
920/// ```
921///
922/// **For examples see the root doc of the `fnsql` crate.**
923///
924/// - Return type is optional, and only meaningful for SQL operations that return row data.
925/// - sql-engine-type: supported backends: `rusqlite` and `postgres`.
926/// - Testing is optional - you have to specific the `test` attribute for it.
927/// - With `test(with=[...])`, you specify the quries that need execution for this
928///   query to work.
929/// - The `named` attribute allows using named arguments, e.g. ':name' with `postgres` in additon to the default position-based arguments of '$1' '$2', etc.
930
931#[proc_macro]
932pub fn fnsql(input: TokenStream) -> TokenStream {
933    let queries: Queries = parse_macro_input!(input);
934    let queries: Vec<_> = queries.list.iter().map(|x| x.expand()).collect();
935
936    quote! { #(#queries)* }.into()
937}