Skip to main content

tusker_query_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![deny(
4    nonstandard_style,
5    rust_2018_idioms,
6    rustdoc::broken_intra_doc_links,
7    rustdoc::private_intra_doc_links
8)]
9#![forbid(non_ascii_idents, unsafe_code)]
10#![warn(
11    deprecated_in_future,
12    missing_copy_implementations,
13    missing_debug_implementations,
14    missing_docs,
15    unreachable_pub,
16    unused_import_braces,
17    unused_labels,
18    unused_lifetimes,
19    unused_qualifications,
20    unused_results
21)]
22#![allow(clippy::uninlined_format_args)]
23
24use std::{env, fs, path::PathBuf};
25
26use darling::FromDeriveInput;
27use proc_macro::TokenStream;
28use proc_macro2::TokenStream as TokenStream2;
29use quote::{quote, ToTokens};
30use sha2::{Digest, Sha512};
31use syn::{Data, DeriveInput};
32use tusker_query_models::{Column, Query as QueryMetadata};
33
/// Options darling extracts from a `#[derive(Query)]` input.
///
/// `supports(struct_named)` makes darling reject anything other than a struct with named
/// fields, and `attributes(query)` makes it read the keys below from `#[query(...)]`.
#[derive(FromDeriveInput)]
#[darling(attributes(query), supports(struct_named))]
struct QueryTraitOpts {
    /// Name of the deriving struct; the `Query` impl is written for this type.
    ident: syn::Ident,
    /// `sql` key: query path relative to `db/queries/` in the consuming crate, without the
    /// `.sql` / `.json` extension (the macro appends those).
    sql: String,
    /// `row` key: the type emitted as `Query::Row`.
    row: syn::Path,
}
41
42#[proc_macro_derive(Query, attributes(query))]
43/// Derives `tusker_query::Query` for a named struct.
44pub fn derive_query(input: TokenStream) -> TokenStream {
45    let ast: DeriveInput = syn::parse(input).unwrap();
46    let opts = match QueryTraitOpts::from_derive_input(&ast) {
47        Ok(opts) => opts,
48        Err(err) => return err.write_errors().into(),
49    };
50    match expand_query(&ast, &opts) {
51        Ok(tokens) => tokens.into(),
52        Err(err) => err.to_compile_error().into(),
53    }
54}
55
56fn expand_query(ast: &DeriveInput, opts: &QueryTraitOpts) -> syn::Result<TokenStream2> {
57    let generics = ast.generics.clone();
58    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
59    let Data::Struct(s) = &ast.data else {
60        unreachable!();
61    };
62    let name = &opts.ident;
63    let sql_path = &opts.sql;
64    let row = &opts.row;
65    let params = s.fields.iter().map(|field| {
66        let field_name = field.ident.as_ref().unwrap();
67        quote! {
68            &self.#field_name
69        }
70    });
71
72    let (sidecar_validation, sidecar_dependency) =
73        if let Some(sidecar) = load_sidecar_metadata(sql_path, name)? {
74            (
75                build_query_validation(
76                    s.fields.iter().map(|field| &field.ty).collect(),
77                    row,
78                    &sidecar,
79                )?,
80                quote! {
81                    const _: &str = include_str!(concat!(
82                        env!("CARGO_MANIFEST_DIR"),
83                        "/db/queries/",
84                        #sql_path,
85                        ".json"
86                    ));
87                },
88            )
89        } else {
90            (quote! {}, quote! {})
91        };
92
93    Ok(quote! {
94        impl #impl_generics ::tusker_query::Query for #name #ty_generics #where_clause {
95            const SQL: &'static str = include_str!(concat!(
96                env!("CARGO_MANIFEST_DIR"),
97                "/db/queries/",
98                #sql_path,
99                ".sql"
100            ));
101            type Row = #row;
102            fn as_params(&self) -> Box<[&(dyn ::tokio_postgres::types::ToSql + Sync)]> {
103                #sidecar_validation
104                Box::new([
105                    #( #params ),*
106                ])
107            }
108        }
109
110        #sidecar_dependency
111    })
112}
113
114fn load_sidecar_metadata(
115    sql_path: &str,
116    error_target: &impl ToTokens,
117) -> syn::Result<Option<QueryMetadata>> {
118    let manifest_dir = env::var("CARGO_MANIFEST_DIR").map_err(|err| {
119        syn::Error::new_spanned(
120            error_target,
121            format!("Unable to determine CARGO_MANIFEST_DIR: {err}"),
122        )
123    })?;
124    let sql_file = PathBuf::from(&manifest_dir)
125        .join("db/queries")
126        .join(format!("{sql_path}.sql"));
127    let json_file = PathBuf::from(&manifest_dir)
128        .join("db/queries")
129        .join(format!("{sql_path}.json"));
130
131    if !json_file.exists() {
132        return Ok(None);
133    }
134
135    let sql = fs::read(&sql_file).map_err(|err| {
136        syn::Error::new_spanned(
137            error_target,
138            format!(
139                "Unable to read query SQL file {}: {err}",
140                sql_file.display()
141            ),
142        )
143    })?;
144    let json = fs::read(&json_file).map_err(|err| {
145        syn::Error::new_spanned(
146            error_target,
147            format!(
148                "Unable to read query sidecar file {}: {err}",
149                json_file.display()
150            ),
151        )
152    })?;
153    let metadata: QueryMetadata = serde_json::from_slice(&json).map_err(|err| {
154        syn::Error::new_spanned(
155            error_target,
156            format!(
157                "Unable to parse query sidecar file {}: {err}",
158                json_file.display()
159            ),
160        )
161    })?;
162
163    let mut hasher = Sha512::new();
164    hasher.update(&sql);
165    let checksum = hasher.finalize().to_vec();
166    if metadata.checksum != checksum {
167        return Err(syn::Error::new_spanned(
168            error_target,
169            format!(
170                "Query sidecar file {} is out of date. Run `tusker query sync` to refresh it.",
171                json_file.display()
172            ),
173        ));
174    }
175
176    Ok(Some(metadata))
177}
178
179fn build_query_validation(
180    field_types: Vec<&syn::Type>,
181    row: &syn::Path,
182    sidecar: &QueryMetadata,
183) -> syn::Result<TokenStream2> {
184    if sidecar.params.len() != field_types.len() {
185        return Err(syn::Error::new_spanned(
186            row,
187            format!(
188                "Query parameter count mismatch: Rust struct has {} fields but the sidecar expects {} parameters.",
189                field_types.len(),
190                sidecar.params.len()
191            ),
192        ));
193    }
194
195    let param_assertions = field_types
196        .iter()
197        .zip(sidecar.params.iter())
198        .enumerate()
199        .map(|(idx, (field_type, sql_type))| {
200            let marker = sql_type_marker(sql_type).map_err(|message| {
201                syn::Error::new_spanned(
202                    field_type,
203                    format!(
204                        "Unsupported SQL parameter type at position {}: {message}",
205                        idx + 1
206                    ),
207                )
208            })?;
209            Ok(quote! {
210                __assert_param_type::<#field_type, #marker>();
211            })
212        })
213        .collect::<syn::Result<Vec<_>>>()?;
214
215    let row_assertions = sidecar
216        .columns
217        .iter()
218        .enumerate()
219        .map(|(idx, column)| build_row_assertion(row, idx, column))
220        .collect::<syn::Result<Vec<_>>>()?;
221    let row_len = sidecar.columns.len();
222
223    Ok(quote! {
224        {
225            fn __assert_param_type<T, Sql>()
226            where
227                T: ::tusker_query::types::QueryParamTyped<Sql>,
228            {
229            }
230
231            fn __assert_row_count<Row, const N: usize>()
232            where
233                Row: ::tusker_query::__private::RowFieldCount<N>,
234            {
235            }
236
237            fn __assert_row_type<Row, const I: usize, Sql>()
238            where
239                Row: ::tusker_query::__private::RowFieldType<I>,
240                <Row as ::tusker_query::__private::RowFieldType<I>>::Ty:
241                    ::tusker_query::types::QueryRowTyped<Sql>,
242            {
243            }
244
245            fn __assert_nullable_row_type<Row, const I: usize, Sql>()
246            where
247                Row: ::tusker_query::__private::RowFieldType<I>,
248                <Row as ::tusker_query::__private::RowFieldType<I>>::Ty:
249                    ::tusker_query::types::QueryNullableRowTyped<Sql>,
250            {
251            }
252
253            fn __assert_maybe_nullable_row_type<Row, const I: usize, Sql>()
254            where
255                Row: ::tusker_query::__private::RowFieldType<I>,
256                <Row as ::tusker_query::__private::RowFieldType<I>>::Ty:
257                    ::tusker_query::types::QueryMaybeNullableRowTyped<Sql>,
258            {
259            }
260
261            #(#param_assertions)*
262            __assert_row_count::<#row, #row_len>();
263            #(#row_assertions)*
264        }
265    })
266}
267
268fn build_row_assertion(
269    row: &syn::Path,
270    index: usize,
271    column: &Column,
272) -> syn::Result<TokenStream2> {
273    let marker = sql_type_marker(&column.r#type).map_err(|message| {
274        syn::Error::new_spanned(
275            row,
276            format!(
277                "Unsupported SQL result type for column `{}` at position {}: {message}",
278                column.name,
279                index + 1
280            ),
281        )
282    })?;
283
284    Ok(match column.notnull {
285        Some(true) => {
286            quote! { __assert_row_type::<#row, #index, #marker>(); }
287        }
288        Some(false) => {
289            quote! { __assert_maybe_nullable_row_type::<#row, #index, #marker>(); }
290        }
291        None => {
292            quote! { __assert_maybe_nullable_row_type::<#row, #index, #marker>(); }
293        }
294    })
295}
296
297fn sql_type_marker(sql_type: &str) -> Result<TokenStream2, String> {
298    match sql_type {
299        "bool" => Ok(quote!(::tusker_query::types::PgBool)),
300        "char" => Ok(quote!(::tusker_query::types::PgI8)),
301        "int2" => Ok(quote!(::tusker_query::types::PgI16)),
302        "int4" => Ok(quote!(::tusker_query::types::PgI32)),
303        "int8" | "oid" => Ok(quote!(::tusker_query::types::PgI64)),
304        "float4" => Ok(quote!(::tusker_query::types::PgF32)),
305        "float8" => Ok(quote!(::tusker_query::types::PgF64)),
306        "varchar" | "bpchar" | "text" | "citext" | "name" | "unknown" | "ltree" | "lquery"
307        | "ltxtquery" => Ok(quote!(::tusker_query::types::PgString)),
308        "bytea" => Ok(quote!(::tusker_query::types::PgBytea)),
309        "hstore" => Ok(quote!(::tusker_query::types::PgHstore)),
310        "timestamp" => Ok(quote!(::tusker_query::types::PgTimestamp)),
311        "timestamptz" => Ok(quote!(::tusker_query::types::PgTimestampTz)),
312        "inet" => Ok(quote!(::tusker_query::types::PgInet)),
313        "date" => Ok(quote!(::tusker_query::types::PgDate)),
314        "time" => Ok(quote!(::tusker_query::types::PgTime)),
315        "uuid" => Ok(quote!(::tusker_query::types::PgUuid)),
316        "json" | "jsonb" => Ok(quote!(::tusker_query::types::PgJson)),
317        other => Err(format!("`{other}` is not supported yet")),
318    }
319}
320
/// Options darling extracts from a `#[derive(FromRow)]` input.
///
/// Only the struct's name is needed. `supports(struct_named)` makes darling reject anything
/// other than a struct with named fields.
#[derive(FromDeriveInput)]
#[darling(supports(struct_named))]
struct FromRowTraitOpts {
    /// Name of the deriving struct; the generated impls are written for this type.
    ident: syn::Ident,
}
326
327#[proc_macro_derive(FromRow)]
328/// Derives `tusker_query::FromRow` for a named struct.
329pub fn derive_from_row(input: TokenStream) -> TokenStream {
330    let ast: DeriveInput = syn::parse(input).unwrap();
331    let opts = match FromRowTraitOpts::from_derive_input(&ast) {
332        Ok(opts) => opts,
333        Err(err) => return err.write_errors().into(),
334    };
335    let generics = ast.generics.clone();
336    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
337    let Data::Struct(s) = ast.data else {
338        unreachable!();
339    };
340    let name = opts.ident;
341    let fields = s.fields.iter().enumerate().map(|(idx, field)| {
342        let field_name = &field.ident;
343        quote! {
344            #field_name: row.get(#idx)
345        }
346    });
347    let field_type_assertions = s.fields.iter().enumerate().map(|(idx, field)| {
348        let field_type = &field.ty;
349        quote! {
350            impl #impl_generics ::tusker_query::__private::RowFieldType<#idx> for #name #ty_generics #where_clause {
351                type Ty = #field_type;
352            }
353        }
354    });
355    let field_count = s.fields.len();
356    quote! {
357        impl #impl_generics ::tusker_query::FromRow for #name #ty_generics #where_clause {
358            fn from_row(row: ::tokio_postgres::Row) -> Self {
359                Self {
360                    #( #fields ),*
361                }
362            }
363        }
364
365        impl #impl_generics ::tusker_query::__private::RowFieldCount<#field_count> for #name #ty_generics #where_clause {}
366
367        #( #field_type_assertions )*
368    }
369    .into()
370}