Skip to main content

fnsql_macro/
lib.rs

1//! The `fnsql` crate provides simple type-safe optional wrappers around SQL
2//! queries. Instead of calling type-less `.query()` and `.execute()`, you call
3//! auto-generated unique wrappers that are strongly typed, `.query_<name>()` and
4//! `.execute_<name>()`. However, you manually specify the input and output types,
5//! but only once, with the query, and separately from the code that uses the
6//! query.
7//!
8//! It's a very simple implementation that doesn't force any schema or ORM down
9//! your throat, so if you are already using the `rusqlite` or `postgres` crates,
10//! you can gradually replace your type-less queries with the type-ful wrappers,
11//! or migrate from an opinionated ORM.
12//!
13//! The way to generate these wrappers is to specify input and output types for
14//! each one of the queries. For example, consider the following definitions
15//! specified with `fnsql`, based on the `rusqlite` example:
16//!
17//! ```rust
18//! fnsql::fnsql! {
19//!     #[rusqlite, test]
20//!     create_table_pet() {
21//!         "CREATE TABLE pet (
22//!               id      INTEGER PRIMARY KEY,
23//!               name    TEXT NOT NULL,
24//!               data    BLOB
25//!         )"
26//!     }
27//!
28//!     #[rusqlite, test(with=[create_table_pet])]
29//!     insert_new_pet(name: String, data: Option<Vec<u8>>) {
30//!         "INSERT INTO pet (name, data) VALUES (:name, :data)"
31//!     }
32//!
33//!     #[rusqlite, test(with=[create_table_pet])]
34//!     get_pet_id_data(name: Option<String>) -> [(i32, Option<Vec<u8>>, String)] {
35//!         "SELECT id, data, name FROM pet WHERE pet.name = :name"
36//!     }
37//! }
38//! ```
39//!
40//! The definitions can be used as such (commented out is how the previous
41//! type-less interfaces were used):
42//!
43//! ```rust ignore
44//! let mut conn = rusqlite::Connection::open_in_memory()?;
45//!
46//! conn.execute_create_table_pet()?;
47//! // conn.execute(
48//! //    "CREATE TABLE pet (
49//! //               id              INTEGER PRIMARY KEY,
50//! //               name            TEXT NOT NULL,
51//! //               data            BLOB
52//! //               )",
53//! //     [],
54//! // )?;
55//!
56//! conn.execute_insert_new_pet(&me.name, &me.data)?;
57//! // conn.execute(
58//! //     "INSERT INTO pet (name, data) VALUES (?1, ?2)",
59//! //     params![me.name, me.data],
60//! // )?;
61//!
62//! let mut stmt = conn.prepare_get_pet_id_data()?;
63//! // let mut stmt = conn.prepare("SELECT id, data, name FROM pet WHERE pet.name = :name")?;
64//!
65//! let pet_iter = stmt.query_map(&Some("Max".to_string()), |id, data, name| {
66//!     Ok::<_, rusqlite::Error>(Pet {
67//!         id,
68//!         data,
69//!         name,
70//!     })
71//! })?;
72//! // let pet_iter = stmt.query_map([(":name", "Max".to_string())], |row| {
73//! //     Ok(Pet {
74//! //         id: row.get(0)?,
75//! //         name: row.get(1)?,
76//! //         data: row.get(2)?,
77//! //     })
78//! // })?;
79//! ```
80//!
81//! ## Technical discussion
82//!
83//! The idea with this crate is to allow direct SQL usage but never use inline
84//! queries or have type inference at the call-site. Instead, we declare each query
85//! on top-level, giving each a name and designated accessor methods that derive
86//! from the name.
87//!
88//! - The types of named variables are given in a Rust-like syntax.
89//! - The type of the returned row is also provided.
90//! - `fnsql` makes no assurances that the types match the query;
91//!   you will discover mismatches with `cargo test` and no additional code.
92//! - `fnsql` writes the tests for each of the queries.
93//! - `Arbitrary` is used to generate parameter values.
94//! - If testing one query depends on another, you can specify that with `test(with=[..])`.
95//!
96//! ```text
97//! running 3 tests
98//! test auto_create_table_pet ... ok
99//! test auto_insert_new_pet ... ok
100//! test auto_get_pet_id_data ... ok
101//! ```
102//!
103//! The following is for allowing generated query tests to compile:
104//!
105//! ```toml
106//! [dev-dependencies]
107//! arbitrary = { version = "1", features = ["derive"] }
108//! ```
109//!
110//! ## Limitations
111//!
112//!  * Though it <i>does</i> provide auto-generated tests for validating queries in `cargo test`,
113//!    it does not do any compile-time validation based on the SQL query string.
114//!  * It only supports `rusqlite` and `postgres` for now.
115
116extern crate proc_macro;
117
118use std::collections::HashMap;
119
120use proc_macro::TokenStream;
121use proc_macro2::{Span, TokenStream as Tokens};
122use quote::{quote, ToTokens};
123use regex::{Captures, Regex};
124use syn::{
125    braced, bracketed, parenthesized,
126    parse::{Parse, ParseStream},
127    parse_macro_input,
128    punctuated::Punctuated,
129    token, Ident, LitStr, Token,
130};
131
132struct Queries {
133    list: Vec<Query>,
134}
135
136impl Parse for Queries {
137    fn parse(input: ParseStream) -> syn::Result<Self> {
138        let mut list = vec![];
139        while !input.is_empty() {
140            list.push(input.parse()?)
141        }
142
143        Ok(Queries { list })
144    }
145}
146
/// Database backend a query definition targets; selected by the `rusqlite`
/// or `postgres` attribute.
enum Kind {
    /// Wrappers are generated against the `rusqlite` crate.
    Rusqlite,
    /// Wrappers are generated against the `postgres` crate.
    PostgreSQL,
}
151
152struct Query {
153    name: Ident,
154    params: Vec<Param>,
155    outputs: Vec<Output>,
156    query: syn::LitStr,
157    kind: Kind,
158    test: Option<Vec<String>>,
159    named: bool,
160    conststr: Option<String>,
161}
162
163impl Parse for Query {
164    fn parse(input: ParseStream) -> syn::Result<Self> {
165        let mut kind = None;
166        let mut test = None;
167        let mut named = false;
168        let mut conststr = None;
169
170        if input.peek(Token![#]) {
171            let _: Token![#] = input.parse()?;
172            let content;
173            let _ = bracketed!(content in input);
174            let list: Punctuated<Attr, Token![,]> = content.parse_terminated(Parse::parse)?;
175
176            for attr in list {
177                match attr {
178                    Attr::Kind(attr_kind) => {
179                        kind = Some(attr_kind);
180                    }
181                    Attr::Test(test_attrs) => {
182                        if test.is_none() {
183                            test = Some(vec![]);
184                        }
185                        for test_attr in test_attrs {
186                            match test_attr {
187                                TestAttr::With(v) => {
188                                    test.as_mut().unwrap().extend(v);
189                                }
190                            }
191                        }
192                    }
193                    Attr::Named => {
194                        named = true;
195                    }
196                    Attr::ConstStr(v) => {
197                        conststr = Some(v);
198                    }
199                }
200            }
201        };
202
203        let name = input.parse()?;
204        let kind = match kind {
205            None => panic!("unknown SQL type. Supported: rusqlite"),
206            Some(kind) => kind,
207        };
208        let content;
209        let _ = parenthesized!(content in input);
210        let list: Punctuated<_, Token![,]> = content.parse_terminated(Parse::parse)?;
211        let params = list.into_iter().collect();
212
213        let outputs = if input.peek(Token![->]) {
214            let _: Token![->] = input.parse()?;
215
216            let content;
217            let _ = bracketed!(content in input);
218            {
219                let sub_content;
220                let _ = parenthesized!(sub_content in content);
221                let list: Punctuated<_, Token![,]> = sub_content.parse_terminated(Parse::parse)?;
222                list.into_iter().collect()
223            }
224        } else {
225            vec![]
226        };
227
228        let content;
229        let _ = braced!(content in input);
230        let query = content.parse::<syn::LitStr>()?;
231
232        Ok(Query {
233            name,
234            params,
235            outputs,
236            query,
237            kind,
238            test,
239            named,
240            conststr,
241        })
242    }
243}
244
245impl Query {
246    fn prepend_name(&self, prefix: &'static str) -> Ident {
247        Ident::new(&format!("{}{}", prefix, &self.name), self.name.span())
248    }
249
250    fn params_declr(&self) -> Tokens {
251        let list: Vec<_> = self.params.iter().map(|x| x.expand_declr()).collect();
252        quote! { #(, #list)* }
253    }
254
255    fn outputs_declr(&self) -> Tokens {
256        let list: Vec<_> = self.outputs.iter().map(|x| x.expand_declr()).collect();
257        quote! { #(#list),* }
258    }
259
260    fn outputs_row_get_numbered(&self) -> Tokens {
261        let list: Vec<_> = self
262            .outputs
263            .iter()
264            .enumerate()
265            .map(|(i, _)| {
266                let i = syn::LitInt::new(&format!("{}", i), self.name.span());
267                quote! {row.get(#i)?}
268            })
269            .collect();
270
271        quote! { #(#list),* }
272    }
273
274    fn outputs_row_try_get_numbered(&self) -> Tokens {
275        let list: Vec<_> = self
276            .outputs
277            .iter()
278            .enumerate()
279            .map(|(i, _)| {
280                let i = syn::LitInt::new(&format!("{}", i), self.name.span());
281                quote! {row.try_get(#i)?}
282            })
283            .collect();
284
285        quote! { #(#list),* }
286    }
287
288    fn outputs_mapped_row_closure(&self) -> Tokens {
289        let list = self.outputs_row_get_numbered();
290        quote! { Ok(map(#list)) }
291    }
292
293    fn params_arbitrary(&self) -> (Tokens, Tokens) {
294        let mut gen_lets = vec![];
295        let mut params = vec![];
296
297        let _ = self
298            .params
299            .iter()
300            .enumerate()
301            .map(|(idx, param)| {
302                let ttype = &param.ttype;
303                let owned_ttype = if ttype.to_token_stream().to_string() == "str" {
304                    quote! {String}
305                } else if ttype.to_token_stream().to_string() == "[u8]" {
306                    quote! {Vec<u8>}
307                } else {
308                    quote! {#ttype}
309                };
310                let ident = Ident::new(&format!("i_{}", idx), self.name.span());
311
312                gen_lets.push(quote! {
313                    let #ident: #owned_ttype = arbitrary::Arbitrary::arbitrary(uns).unwrap();
314                });
315                params.push(quote! {&#ident});
316            })
317            .collect::<Vec<()>>();
318
319        (quote! { #(#gen_lets);* }, quote! { #(#params),* })
320    }
321
322    fn params_query(&self) -> Tokens {
323        let list: Vec<_> = self.params.iter().map(|x| x.expand_query(self)).collect();
324        if list.len() == 0 {
325            quote! { [] }
326        } else {
327            quote! { &[#(#list),*] }
328        }
329    }
330
331    fn params_query_ref(&self) -> Tokens {
332        let list: Vec<_> = self.params.iter().map(|x| x.expand_query(self)).collect();
333        if list.len() == 0 {
334            quote! { &[] }
335        } else {
336            quote! { &[#(#list),*] }
337        }
338    }
339
340    fn params_relay(&self) -> Tokens {
341        let list: Vec<_> = self
342            .params
343            .iter()
344            .map(|x| {
345                let name = &x.name;
346                quote! { #name }
347            })
348            .collect();
349        if list.len() == 0 {
350            quote! {}
351        } else {
352            quote! { #(#list),*, }
353        }
354    }
355
356    fn expand(&self) -> Tokens {
357        match self.kind {
358            Kind::Rusqlite => self.sqlite_expand(),
359            Kind::PostgreSQL => self.postgres_expand(),
360        }
361    }
362
363    fn postgres_expand(&self) -> Tokens {
364        #[allow(non_snake_case)]
365        let Client = self.prepend_name("Client_");
366        #[allow(non_snake_case)]
367        let Statement = self.prepend_name("Statement_");
368        let execute_name = self.prepend_name("execute_");
369        let execute_prepared_name = self.prepend_name("execute_prepared_");
370        let prepare_name = self.prepend_name("prepare_");
371        let prepare_cached_name = self.prepend_name("prepare_cached_");
372        let convert_row = self.prepend_name("convert_row_");
373        let query_name = self.prepend_name("query_");
374        let query_prepared_name = self.prepend_name("query_prepared_");
375        let query_one_name = self.prepend_name("query_one_");
376        let query_one_prepared_name = self.prepend_name("query_one_prepared_");
377        let query_opt_name = self.prepend_name("query_opt_");
378        let query_opt_prepared_name = self.prepend_name("query_opt_prepared_");
379        let params_declr = self.params_declr();
380        let params_query_ref = self.params_query_ref();
381        let outputs_declr = self.outputs_declr();
382        let row_try_get_numbered = self.outputs_row_try_get_numbered();
383
384        let query;
385        if self.named {
386            lazy_static::lazy_static! {
387                static ref RE: Regex = Regex::new(":([A-Za-z_][_A-Za-z0-9]*)($|[^_A-Za-z0-9])").unwrap();
388            }
389
390            let params: HashMap<_, _> = self
391                .params
392                .iter()
393                .enumerate()
394                .map(|(idx, param)| (format!("{}", param.name), idx))
395                .collect();
396
397            query = String::from(RE.replace_all(&self.query.value(), |captures: &Captures| {
398                let c1 = captures.get(1).unwrap().as_str();
399                let c2 = captures.get(2).unwrap().as_str();
400                match params.get(c1) {
401                    Some(idx) => format!("${}{}", idx + 1, c2),
402                    None => format!("{}{}", c1, c2),
403                }
404            }));
405        } else {
406            query = self.query.value();
407        };
408        let query = LitStr::new(query.as_str(), self.query.span());
409
410        let const_str = self.conststr.as_ref().map(|name| {
411            let ident = Ident::new(name, Span::call_site());
412            quote! { pub const #ident: &str = #query; }
413        });
414
415        #[cfg(feature = "prepare-cache")]
416        let (prepare_cached_decl, prepare_cached_impl) = {
417            let prepare_cached_decl = quote! {
418                fn #prepare_cached_name(&mut self, cache: &mut fnsql::postgres::Cache) -> Result<#Statement, postgres::Error>;
419            };
420
421            let prepare_cached_impl = quote! {
422                fn #prepare_cached_name(&mut self, cache: &mut fnsql::postgres::Cache) -> Result<#Statement, postgres::Error> {
423                    Ok(#Statement(cache.prepare(#query, self)?))
424                }
425            };
426
427            (prepare_cached_decl, prepare_cached_impl)
428        };
429
430        #[cfg(not(feature = "prepare-cache"))]
431        let (prepare_cached_decl, prepare_cached_impl) = { (quote! {}, quote! {}) };
432
433        let defs = quote! {
434            #[allow(non_camel_case_types)]
435            pub struct #Statement(pub postgres::Statement);
436
437            #[allow(non_camel_case_types)]
438            pub trait #Client {
439                fn #prepare_name(&mut self) -> Result<#Statement, postgres::Error>;
440                #prepare_cached_decl
441                fn #execute_name(&mut self #params_declr) -> Result<u64, postgres::Error>;
442                fn #execute_prepared_name(&mut self, stmt: &#Statement #params_declr)
443                    -> Result<u64, postgres::Error>;
444                fn #query_name(&mut self #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error>;
445                fn #query_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error>;
446                fn #query_one_name(&mut self #params_declr) -> Result<(#outputs_declr), postgres::Error>;
447                fn #query_one_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<(#outputs_declr), postgres::Error>;
448                fn #query_opt_name(&mut self #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error>;
449                fn #query_opt_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error>;
450            }
451
452            pub fn #convert_row(row: postgres::Row) -> Result<(#outputs_declr), postgres::Error> {
453                Ok((#row_try_get_numbered))
454            }
455        };
456
457        let timpl = quote! {
458            fn #prepare_name(&mut self)  -> Result<#Statement, postgres::Error> {
459                self.prepare(#query).map(#Statement)
460            }
461
462            #prepare_cached_impl
463
464            fn #execute_name(&mut self #params_declr) -> Result<u64, postgres::Error> {
465                self.execute(#query, #params_query_ref)
466            }
467
468            fn #execute_prepared_name(&mut self, stmt: &#Statement #params_declr)
469                -> Result<u64, postgres::Error>
470            {
471                self.execute(&stmt.0, #params_query_ref)
472            }
473
474            fn #query_name(&mut self #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error> {
475                let result: Result<Vec<_>, postgres::Error> =
476                    self.query(#query, #params_query_ref)?.into_iter().map(#convert_row).collect();
477                result
478            }
479
480            fn #query_one_name(&mut self #params_declr) -> Result<(#outputs_declr), postgres::Error> {
481                Ok(#convert_row(self.query_one(#query, #params_query_ref)?)?)
482            }
483
484            fn #query_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Vec<(#outputs_declr)>, postgres::Error> {
485                let result: Result<Vec<_>, postgres::Error> =
486                    self.query(&stmt.0, #params_query_ref)?.into_iter().map(#convert_row).collect();
487                result
488            }
489
490            fn #query_one_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<(#outputs_declr), postgres::Error> {
491                Ok(#convert_row(self.query_one(&stmt.0, #params_query_ref)?)?)
492            }
493
494            fn #query_opt_name(&mut self #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error> {
495                match self.query_opt(#query, #params_query_ref)? {
496                    None => Ok(None),
497                    Some(x) => Ok(Some(#convert_row(x)?)),
498                }
499            }
500
501            fn #query_opt_prepared_name(&mut self, stmt: &#Statement #params_declr) -> Result<Option<(#outputs_declr)>, postgres::Error> {
502                match self.query_opt(&stmt.0, #params_query_ref)? {
503                    None => Ok(None),
504                    Some(x) => Ok(Some(#convert_row(x)?)),
505                }
506            }
507        };
508
509        let test_code = self.test_code();
510
511        quote! {
512            #const_str
513            #defs
514
515            impl #Client for postgres::Client {
516                #timpl
517            }
518
519            impl<'a> #Client for postgres::Transaction<'a> {
520                #timpl
521            }
522
523            #test_code
524        }
525    }
526
527    fn sqlite_expand(&self) -> Tokens {
528        let conn_trait_name = self.prepend_name("Connection_");
529        #[allow(non_snake_case)]
530        let StatementType = self.prepend_name("Statement_");
531        #[allow(non_snake_case)]
532        let CachedStatementType = self.prepend_name("CachedStatement_");
533        #[allow(non_snake_case)]
534        let MappedRows = self.prepend_name("MappedRows_");
535        #[allow(non_snake_case)]
536        let Rows = self.prepend_name("Rows_");
537        let prepare_name = self.prepend_name("prepare_");
538        let prepare_cached_name = self.prepend_name("prepare_cached_");
539        let execute_name = self.prepend_name("execute_");
540        let query_row_name = self.prepend_name("query_row_");
541        let params_declr = self.params_declr();
542        let outputs_declr = self.outputs_declr();
543        let row_closure = self.outputs_row_get_numbered();
544        let mapped_row_closure = self.outputs_mapped_row_closure();
545        let params_query = self.params_query();
546        let params_relay = self.params_relay();
547        let query = &self.query;
548
549        let test_code = self.test_code();
550
551        let const_str = self.conststr.as_ref().map(|name| {
552            let ident = Ident::new(name, Span::call_site());
553            quote! { pub const #ident: &str = #query; }
554        });
555
556        quote! {
557            #const_str
558            #[allow(non_camel_case_types)]
559            pub trait #conn_trait_name {
560                fn #prepare_name(&self) -> rusqlite::Result<#StatementType<'_>>;
561                fn #prepare_cached_name(&self) -> rusqlite::Result<#CachedStatementType<'_>>;
562                fn #execute_name(&self #params_declr) -> rusqlite::Result<usize>;
563                fn #query_row_name<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
564                where
565                    F: FnMut(#outputs_declr) -> T;
566            }
567
568            impl #conn_trait_name for rusqlite::Connection {
569                fn #prepare_name(&self) -> rusqlite::Result<#StatementType<'_>> {
570                    self.prepare(#query).map(#StatementType)
571                }
572
573                fn #prepare_cached_name(&self) -> rusqlite::Result<#CachedStatementType<'_>> {
574                    self.prepare_cached(#query).map(#CachedStatementType)
575                }
576
577                fn #execute_name(&self #params_declr) -> rusqlite::Result<usize> {
578                    self.execute(#query, #params_query)
579                }
580
581                fn #query_row_name<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
582                where
583                    F: FnMut(#outputs_declr) -> T,
584                {
585                    let mut stmt = self.#prepare_name()?;
586                    stmt.query_row(#params_relay f)
587                }
588            }
589
590            #[allow(non_camel_case_types)]
591            pub struct #MappedRows<'stmt, F> {
592                rows: rusqlite::Rows<'stmt>,
593                map: F,
594            }
595
596            impl<'stmt, T, F> #MappedRows<'stmt, F>
597            where
598                F: FnMut(#outputs_declr) -> T
599            {
600                pub(crate) fn new(rows: rusqlite::Rows<'stmt>, f: F) -> Self {
601                    Self { rows, map: f }
602                }
603            }
604
605            impl<'stmt, T, F> Iterator for #MappedRows<'stmt, F>
606            where
607                F: FnMut(#outputs_declr) -> T
608            {
609                type Item = rusqlite::Result<T>;
610
611                fn next(&mut self) -> Option<rusqlite::Result<T>> {
612                    let map = &mut self.map;
613                    self.rows
614                        .next()
615                        .transpose()
616                        .map(|row_result| {
617                            row_result.and_then(|row| {
618                                #mapped_row_closure
619                            })
620                        })
621                }
622            }
623
624            #[allow(non_camel_case_types)]
625            pub struct #Rows<'stmt> {
626                rows: rusqlite::Rows<'stmt>,
627            }
628
629            impl<'stmt> #Rows<'stmt> {
630                pub(crate) fn new(rows: rusqlite::Rows<'stmt>) -> Self {
631                    Self { rows }
632                }
633            }
634
635            impl<'stmt> Iterator for #Rows<'stmt> {
636                type Item = rusqlite::Result<(#outputs_declr)>;
637
638                fn next(&mut self) -> Option<Self::Item> {
639                    self.rows
640                        .next()
641                        .transpose()
642                        .map(|row_result| {
643                            row_result.and_then(|row| {
644                                Ok((#row_closure))
645                            })
646                        })
647                }
648            }
649
650            #[allow(non_camel_case_types)]
651            pub struct #StatementType<'a>(pub rusqlite::Statement<'a>);
652
653            impl<'a> #StatementType<'a> {
654                fn query_map<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<#MappedRows<'_, F>>
655                where
656                    F: FnMut(#outputs_declr) -> T,
657                {
658                    let rows = self.0.query(#params_query)?;
659                    Ok(#MappedRows::new(rows, f))
660                }
661
662                fn query_row<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
663                where
664                    F: FnMut(#outputs_declr) -> T,
665                {
666                    let rows = self.query_map(#params_relay f)?;
667                    for item in rows {
668                        return Ok(item?);
669                    }
670                    Err(rusqlite::Error::QueryReturnedNoRows)
671                }
672
673                fn query(&mut self #params_declr) -> rusqlite::Result<#Rows<'_>> {
674                    let rows = self.0.query(#params_query)?;
675                    Ok(#Rows::new(rows))
676                }
677
678                fn execute(&mut self #params_declr) -> rusqlite::Result<()> {
679                    self.0.execute(#params_query)?;
680                    Ok(())
681                }
682            }
683
684            #[allow(non_camel_case_types)]
685            pub struct #CachedStatementType<'a>(pub rusqlite::CachedStatement<'a>);
686
687            impl<'a> #CachedStatementType<'a> {
688                fn query_map<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<#MappedRows<'_, F>>
689                where
690                    F: FnMut(#outputs_declr) -> T,
691                {
692                    let rows = self.0.query(#params_query)?;
693                    Ok(#MappedRows::new(rows, f))
694                }
695
696                fn query_row<F, T>(&mut self #params_declr, f: F) -> rusqlite::Result<T>
697                where
698                    F: FnMut(#outputs_declr) -> T,
699                {
700                    let rows = self.query_map(#params_relay f)?;
701                    for item in rows {
702                        return Ok(item?);
703                    }
704                    Err(rusqlite::Error::QueryReturnedNoRows)
705                }
706
707                fn query(&mut self #params_declr) -> rusqlite::Result<#Rows<'_>> {
708                    let rows = self.0.query(#params_query)?;
709                    Ok(#Rows::new(rows))
710                }
711
712                fn execute(&mut self #params_declr) -> rusqlite::Result<()> {
713                    self.0.execute(#params_query)?;
714                    Ok(())
715                }
716            }
717
718            #test_code
719        }
720    }
721
722    fn test_code(&self) -> Tokens {
723        let test_name = self.prepend_name("auto_");
724        let testsetup_name = self.prepend_name("testsetup_");
725        let (params_arbit_prep, params_arbit) = self.params_arbitrary();
726        let execute_name = self.prepend_name("execute_");
727        let name = syn::LitStr::new(&self.name.to_string(), self.name.span());
728
729        let client_type = match self.kind {
730            Kind::Rusqlite => quote! {rusqlite::Connection},
731            Kind::PostgreSQL => quote! {postgres::Client},
732        };
733        let client_ref_type = match self.kind {
734            Kind::Rusqlite => quote! {&},
735            Kind::PostgreSQL => quote! {&mut},
736        };
737        let ignore_error = match self.kind {
738            Kind::Rusqlite => quote! {Err(rusqlite::Error::ExecuteReturnedResults) => {}},
739            Kind::PostgreSQL => quote! {},
740        };
741        let error_type = match self.kind {
742            Kind::Rusqlite => quote! {rusqlite::Error},
743            Kind::PostgreSQL => quote! {postgres::Error},
744        };
745        let open_client = match self.kind {
746            Kind::Rusqlite => quote! {
747                let conn = #client_type::open_in_memory()?;
748            },
749            Kind::PostgreSQL => quote! {let mut conn = {
750                let mut conn = fnsql::postgres::testing_client().expect("unable to connect testing client");
751                conn.execute("SET search_path TO pg_temp", &[]).unwrap();
752                conn
753            }; },
754        };
755
756        let test = if let Some(depends) = &self.test {
757            let depends = depends.iter().map(|name| {
758                let parent_testsetup_name =
759                    Ident::new(&format!("testsetup_{}", name), self.name.span());
760                quote! {
761                    #parent_testsetup_name(uns, deps, conn)?;
762                }
763            });
764            quote! {
765                #[cfg(test)]
766                fn #testsetup_name(
767                    uns: &mut arbitrary::Unstructured,
768                    deps: &mut std::collections::HashSet<&'static str>,
769                    conn: #client_ref_type #client_type) -> Result<(), #error_type>
770                {
771                    if !deps.insert(#name) {
772                        return Ok(());
773                    }
774
775                    #(#depends);*
776
777                    #params_arbit_prep;
778                    let r = conn.#execute_name(#params_arbit);
779                    match r {
780                        Ok(_) => {}
781                        #ignore_error
782                        Err(err) => {
783                            eprintln!("{:?}", err);
784                            Err(err)?;
785                        },
786                    }
787                    Ok(())
788                }
789
790                #[test]
791                fn #test_name() -> Result<(), #error_type> {
792                    #open_client;
793                    let mut deps = std::collections::HashSet::new();
794                    let raw_data: &[u8] = &[1, 2, 3];
795                    let mut unstructured = arbitrary::Unstructured::new(raw_data);
796
797                    #testsetup_name(&mut unstructured, &mut deps, #client_ref_type conn)?;
798                    Ok(())
799                }
800            }
801        } else {
802            quote! {}
803        };
804        test
805    }
806}
807
808struct Output {
809    ttype: syn::Type,
810}
811
812impl Parse for Output {
813    fn parse(input: ParseStream) -> syn::Result<Self> {
814        let ttype = input.parse()?;
815
816        Ok(Self { ttype })
817    }
818}
819
820impl Output {
821    fn expand_declr(&self) -> Tokens {
822        let ttype = &self.ttype;
823
824        quote! { #ttype }
825    }
826}
827
828struct Param {
829    name: Ident,
830    ttype: syn::Type,
831}
832
833impl Parse for Param {
834    fn parse(input: ParseStream) -> syn::Result<Self> {
835        let name = input.parse()?;
836        let _: Token![:] = input.parse()?;
837        let ttype = input.parse()?;
838
839        Ok(Self { name, ttype })
840    }
841}
842
843impl Param {
844    fn expand_declr(&self) -> Tokens {
845        let name = &self.name;
846        let ttype = &self.ttype;
847
848        quote! { #name: &#ttype }
849    }
850
851    fn expand_query(&self, query: &Query) -> Tokens {
852        let name = &self.name;
853        let specifier = syn::LitStr::new(&format!(":{}", name), name.span());
854
855        match query.kind {
856            Kind::Rusqlite => quote! { (#specifier, &#name as &dyn rusqlite::ToSql) },
857            Kind::PostgreSQL => quote! { &#name as &(dyn postgres::types::ToSql + Sync) },
858        }
859    }
860}
861
862enum Attr {
863    Kind(Kind),
864    Test(Vec<TestAttr>),
865    Named,
866    ConstStr(String),
867}
868
869impl Parse for Attr {
870    fn parse(input: ParseStream) -> syn::Result<Self> {
871        let ident: Ident = input.parse()?;
872        if ident == "rusqlite" {
873            return Ok(Attr::Kind(Kind::Rusqlite));
874        }
875        if ident == "postgres" {
876            return Ok(Attr::Kind(Kind::PostgreSQL));
877        }
878        if ident == "named" {
879            return Ok(Attr::Named);
880        }
881        if ident == "test" {
882            let mut v = vec![];
883
884            if input.peek(token::Paren) {
885                let content;
886                let _ = parenthesized!(content in input);
887                let list: Punctuated<TestAttr, Token![,]> =
888                    content.parse_terminated(Parse::parse)?;
889                v = list.into_iter().collect();
890            };
891
892            return Ok(Attr::Test(v));
893        }
894        if ident == "conststr" {
895            let _: Token![=] = input.parse()?;
896            let name: Ident = input.parse()?;
897            return Ok(Attr::ConstStr(name.to_string()));
898        }
899        panic!("unknown attribute {}", ident);
900    }
901}
902
/// An option accepted by the `Parse` impl for `TestAttr`.
enum TestAttr {
    /// `with = [a, b, ...]`: the listed identifiers, stored as strings.
    With(Vec<String>),
}
906
907impl Parse for TestAttr {
908    fn parse(input: ParseStream) -> syn::Result<Self> {
909        let ident: Ident = input.parse()?;
910        if ident == "with" {
911            let mut v = vec![];
912
913            let _: Token![=] = input.parse()?;
914            let content;
915            let _ = bracketed!(content in input);
916            let list: Punctuated<Ident, Token![,]> = content.parse_terminated(Parse::parse)?;
917            for item in list {
918                v.push(item.to_string());
919            }
920
921            return Ok(TestAttr::With(v));
922        }
923
924        panic!("unknown test attribute {}", ident);
925    }
926}
927
928/// The general structure of the input to the `fnsql` macro is the following:
929///
930/// ```ignore
931/// fnsql! {
932///     #[<sql-engine-type>, [OPTIONAL: test(with=[other-function-a, other-function-b...])], [OPTIONAL: conststr=<const-name>]]
933///     <function-name-a>(param1: type, param2: type...)
934///          [OPTIONAL: -> [(col a type, col b type, ...)]]
935///     {
936///         "SQL QUERY STRING"
937///     }
938///
939///     ...
940/// }
941/// ```
942///
943/// **For examples see the root doc of the `fnsql` crate.**
944///
945/// - Return type is optional, and only meaningful for SQL operations that return row data.
946/// - sql-engine-type: supported backends: `rusqlite` and `postgres`.
/// - Testing is optional - you have to specify the `test` attribute for it.
/// - With `test(with=[...])`, you specify the queries that need execution for this
949///   query to work.
/// - The `named` attribute allows using named arguments, e.g. ':name' with `postgres` in addition to the default position-based arguments of '$1' '$2', etc.
951/// - The `conststr=<name>` attribute generates a `pub const <name>: &str = "SQL";` at the top level.
952
953#[proc_macro]
954pub fn fnsql(input: TokenStream) -> TokenStream {
955    let queries: Queries = parse_macro_input!(input);
956    let queries: Vec<_> = queries.list.iter().map(|x| x.expand()).collect();
957
958    quote! { #(#queries)* }.into()
959}