1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![deny(
4 nonstandard_style,
5 rust_2018_idioms,
6 rustdoc::broken_intra_doc_links,
7 rustdoc::private_intra_doc_links
8)]
9#![forbid(non_ascii_idents, unsafe_code)]
10#![warn(
11 deprecated_in_future,
12 missing_copy_implementations,
13 missing_debug_implementations,
14 missing_docs,
15 unreachable_pub,
16 unused_import_braces,
17 unused_labels,
18 unused_lifetimes,
19 unused_qualifications,
20 unused_results
21)]
22#![allow(clippy::uninlined_format_args)]
23
24use std::{env, fs, path::PathBuf};
25
26use darling::FromDeriveInput;
27use proc_macro::TokenStream;
28use proc_macro2::TokenStream as TokenStream2;
29use quote::{quote, ToTokens};
30use sha2::{Digest, Sha512};
31use syn::{Data, DeriveInput};
32use tusker_query_models::{Column, Query as QueryMetadata};
33
34#[derive(FromDeriveInput)]
35#[darling(attributes(query), supports(struct_named))]
36struct QueryTraitOpts {
37 ident: syn::Ident,
38 sql: String,
39 row: syn::Path,
40}
41
42#[proc_macro_derive(Query, attributes(query))]
43pub fn derive_query(input: TokenStream) -> TokenStream {
45 let ast: DeriveInput = syn::parse(input).unwrap();
46 let opts = match QueryTraitOpts::from_derive_input(&ast) {
47 Ok(opts) => opts,
48 Err(err) => return err.write_errors().into(),
49 };
50 match expand_query(&ast, &opts) {
51 Ok(tokens) => tokens.into(),
52 Err(err) => err.to_compile_error().into(),
53 }
54}
55
56fn expand_query(ast: &DeriveInput, opts: &QueryTraitOpts) -> syn::Result<TokenStream2> {
57 let generics = ast.generics.clone();
58 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
59 let Data::Struct(s) = &ast.data else {
60 unreachable!();
61 };
62 let name = &opts.ident;
63 let sql_path = &opts.sql;
64 let row = &opts.row;
65 let params = s.fields.iter().map(|field| {
66 let field_name = field.ident.as_ref().unwrap();
67 quote! {
68 &self.#field_name
69 }
70 });
71
72 let (sidecar_validation, sidecar_dependency) =
73 if let Some(sidecar) = load_sidecar_metadata(sql_path, name)? {
74 (
75 build_query_validation(
76 s.fields.iter().map(|field| &field.ty).collect(),
77 row,
78 &sidecar,
79 )?,
80 quote! {
81 const _: &str = include_str!(concat!(
82 env!("CARGO_MANIFEST_DIR"),
83 "/db/queries/",
84 #sql_path,
85 ".json"
86 ));
87 },
88 )
89 } else {
90 (quote! {}, quote! {})
91 };
92
93 Ok(quote! {
94 impl #impl_generics ::tusker_query::Query for #name #ty_generics #where_clause {
95 const SQL: &'static str = include_str!(concat!(
96 env!("CARGO_MANIFEST_DIR"),
97 "/db/queries/",
98 #sql_path,
99 ".sql"
100 ));
101 type Row = #row;
102 fn as_params(&self) -> Box<[&(dyn ::tokio_postgres::types::ToSql + Sync)]> {
103 #sidecar_validation
104 Box::new([
105 #( #params ),*
106 ])
107 }
108 }
109
110 #sidecar_dependency
111 })
112}
113
114fn load_sidecar_metadata(
115 sql_path: &str,
116 error_target: &impl ToTokens,
117) -> syn::Result<Option<QueryMetadata>> {
118 let manifest_dir = env::var("CARGO_MANIFEST_DIR").map_err(|err| {
119 syn::Error::new_spanned(
120 error_target,
121 format!("Unable to determine CARGO_MANIFEST_DIR: {err}"),
122 )
123 })?;
124 let sql_file = PathBuf::from(&manifest_dir)
125 .join("db/queries")
126 .join(format!("{sql_path}.sql"));
127 let json_file = PathBuf::from(&manifest_dir)
128 .join("db/queries")
129 .join(format!("{sql_path}.json"));
130
131 if !json_file.exists() {
132 return Ok(None);
133 }
134
135 let sql = fs::read(&sql_file).map_err(|err| {
136 syn::Error::new_spanned(
137 error_target,
138 format!(
139 "Unable to read query SQL file {}: {err}",
140 sql_file.display()
141 ),
142 )
143 })?;
144 let json = fs::read(&json_file).map_err(|err| {
145 syn::Error::new_spanned(
146 error_target,
147 format!(
148 "Unable to read query sidecar file {}: {err}",
149 json_file.display()
150 ),
151 )
152 })?;
153 let metadata: QueryMetadata = serde_json::from_slice(&json).map_err(|err| {
154 syn::Error::new_spanned(
155 error_target,
156 format!(
157 "Unable to parse query sidecar file {}: {err}",
158 json_file.display()
159 ),
160 )
161 })?;
162
163 let mut hasher = Sha512::new();
164 hasher.update(&sql);
165 let checksum = hasher.finalize().to_vec();
166 if metadata.checksum != checksum {
167 return Err(syn::Error::new_spanned(
168 error_target,
169 format!(
170 "Query sidecar file {} is out of date. Run `tusker query sync` to refresh it.",
171 json_file.display()
172 ),
173 ));
174 }
175
176 Ok(Some(metadata))
177}
178
179fn build_query_validation(
180 field_types: Vec<&syn::Type>,
181 row: &syn::Path,
182 sidecar: &QueryMetadata,
183) -> syn::Result<TokenStream2> {
184 if sidecar.params.len() != field_types.len() {
185 return Err(syn::Error::new_spanned(
186 row,
187 format!(
188 "Query parameter count mismatch: Rust struct has {} fields but the sidecar expects {} parameters.",
189 field_types.len(),
190 sidecar.params.len()
191 ),
192 ));
193 }
194
195 let param_assertions = field_types
196 .iter()
197 .zip(sidecar.params.iter())
198 .enumerate()
199 .map(|(idx, (field_type, sql_type))| {
200 let marker = sql_type_marker(sql_type).map_err(|message| {
201 syn::Error::new_spanned(
202 field_type,
203 format!(
204 "Unsupported SQL parameter type at position {}: {message}",
205 idx + 1
206 ),
207 )
208 })?;
209 Ok(quote! {
210 __assert_param_type::<#field_type, #marker>();
211 })
212 })
213 .collect::<syn::Result<Vec<_>>>()?;
214
215 let row_assertions = sidecar
216 .columns
217 .iter()
218 .enumerate()
219 .map(|(idx, column)| build_row_assertion(row, idx, column))
220 .collect::<syn::Result<Vec<_>>>()?;
221 let row_len = sidecar.columns.len();
222
223 Ok(quote! {
224 {
225 fn __assert_param_type<T, Sql>()
226 where
227 T: ::tusker_query::types::QueryParamTyped<Sql>,
228 {
229 }
230
231 fn __assert_row_count<Row, const N: usize>()
232 where
233 Row: ::tusker_query::__private::RowFieldCount<N>,
234 {
235 }
236
237 fn __assert_row_type<Row, const I: usize, Sql>()
238 where
239 Row: ::tusker_query::__private::RowFieldType<I>,
240 <Row as ::tusker_query::__private::RowFieldType<I>>::Ty:
241 ::tusker_query::types::QueryRowTyped<Sql>,
242 {
243 }
244
245 fn __assert_nullable_row_type<Row, const I: usize, Sql>()
246 where
247 Row: ::tusker_query::__private::RowFieldType<I>,
248 <Row as ::tusker_query::__private::RowFieldType<I>>::Ty:
249 ::tusker_query::types::QueryNullableRowTyped<Sql>,
250 {
251 }
252
253 fn __assert_maybe_nullable_row_type<Row, const I: usize, Sql>()
254 where
255 Row: ::tusker_query::__private::RowFieldType<I>,
256 <Row as ::tusker_query::__private::RowFieldType<I>>::Ty:
257 ::tusker_query::types::QueryMaybeNullableRowTyped<Sql>,
258 {
259 }
260
261 #(#param_assertions)*
262 __assert_row_count::<#row, #row_len>();
263 #(#row_assertions)*
264 }
265 })
266}
267
268fn build_row_assertion(
269 row: &syn::Path,
270 index: usize,
271 column: &Column,
272) -> syn::Result<TokenStream2> {
273 let marker = sql_type_marker(&column.r#type).map_err(|message| {
274 syn::Error::new_spanned(
275 row,
276 format!(
277 "Unsupported SQL result type for column `{}` at position {}: {message}",
278 column.name,
279 index + 1
280 ),
281 )
282 })?;
283
284 Ok(match column.notnull {
285 Some(true) => {
286 quote! { __assert_row_type::<#row, #index, #marker>(); }
287 }
288 Some(false) => {
289 quote! { __assert_maybe_nullable_row_type::<#row, #index, #marker>(); }
290 }
291 None => {
292 quote! { __assert_maybe_nullable_row_type::<#row, #index, #marker>(); }
293 }
294 })
295}
296
297fn sql_type_marker(sql_type: &str) -> Result<TokenStream2, String> {
298 match sql_type {
299 "bool" => Ok(quote!(::tusker_query::types::PgBool)),
300 "char" => Ok(quote!(::tusker_query::types::PgI8)),
301 "int2" => Ok(quote!(::tusker_query::types::PgI16)),
302 "int4" => Ok(quote!(::tusker_query::types::PgI32)),
303 "int8" | "oid" => Ok(quote!(::tusker_query::types::PgI64)),
304 "float4" => Ok(quote!(::tusker_query::types::PgF32)),
305 "float8" => Ok(quote!(::tusker_query::types::PgF64)),
306 "varchar" | "bpchar" | "text" | "citext" | "name" | "unknown" | "ltree" | "lquery"
307 | "ltxtquery" => Ok(quote!(::tusker_query::types::PgString)),
308 "bytea" => Ok(quote!(::tusker_query::types::PgBytea)),
309 "hstore" => Ok(quote!(::tusker_query::types::PgHstore)),
310 "timestamp" => Ok(quote!(::tusker_query::types::PgTimestamp)),
311 "timestamptz" => Ok(quote!(::tusker_query::types::PgTimestampTz)),
312 "inet" => Ok(quote!(::tusker_query::types::PgInet)),
313 "date" => Ok(quote!(::tusker_query::types::PgDate)),
314 "time" => Ok(quote!(::tusker_query::types::PgTime)),
315 "uuid" => Ok(quote!(::tusker_query::types::PgUuid)),
316 "json" | "jsonb" => Ok(quote!(::tusker_query::types::PgJson)),
317 other => Err(format!("`{other}` is not supported yet")),
318 }
319}
320
321#[derive(FromDeriveInput)]
322#[darling(supports(struct_named))]
323struct FromRowTraitOpts {
324 ident: syn::Ident,
325}
326
327#[proc_macro_derive(FromRow)]
328pub fn derive_from_row(input: TokenStream) -> TokenStream {
330 let ast: DeriveInput = syn::parse(input).unwrap();
331 let opts = match FromRowTraitOpts::from_derive_input(&ast) {
332 Ok(opts) => opts,
333 Err(err) => return err.write_errors().into(),
334 };
335 let generics = ast.generics.clone();
336 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
337 let Data::Struct(s) = ast.data else {
338 unreachable!();
339 };
340 let name = opts.ident;
341 let fields = s.fields.iter().enumerate().map(|(idx, field)| {
342 let field_name = &field.ident;
343 quote! {
344 #field_name: row.get(#idx)
345 }
346 });
347 let field_type_assertions = s.fields.iter().enumerate().map(|(idx, field)| {
348 let field_type = &field.ty;
349 quote! {
350 impl #impl_generics ::tusker_query::__private::RowFieldType<#idx> for #name #ty_generics #where_clause {
351 type Ty = #field_type;
352 }
353 }
354 });
355 let field_count = s.fields.len();
356 quote! {
357 impl #impl_generics ::tusker_query::FromRow for #name #ty_generics #where_clause {
358 fn from_row(row: ::tokio_postgres::Row) -> Self {
359 Self {
360 #( #fields ),*
361 }
362 }
363 }
364
365 impl #impl_generics ::tusker_query::__private::RowFieldCount<#field_count> for #name #ty_generics #where_clause {}
366
367 #( #field_type_assertions )*
368 }
369 .into()
370}