Skip to main content

sqlx_data_macros/
lib.rs

1use proc_macro::TokenStream;
2use syn::{TraitItemFn, parse_macro_input};
3
4mod alias_system;
5mod code_generator;
6mod constants;
7mod dml;
8mod error;
9mod fetch;
10mod method_variants;
11mod repo_system;
12mod scope_system;
13mod type_analyzer;
14mod type_system;
15
16#[cfg(test)]
17mod test_framework;
18
19use code_generator::CodeGenerator;
20use dml::DmlParser;
21
22/// Attribute macro for DML (Data Manipulation Language) statements with compile-time validation
23///
24/// This macro generates type-safe, high-performance database operations with automatic parameter binding,
25/// pagination injection, and query optimization. All SQL is validated at compile time using SQLx's
26/// compile-time verification.
27///
28/// # Features
29/// - **Compile-time SQL validation** - Catches SQL errors during compilation
30/// - **Type-safe parameter binding** - Automatic conversion and validation of parameters
31/// - **Pagination support** - Automatic injection of LIMIT/OFFSET clauses
32/// - **Query optimization** - Smart query planning and execution
33/// - **Error handling** - Comprehensive error types with context
34/// - **Generated documentation** - Auto-documented methods with query details
35///
36/// # Usage
37/// ```ignore
38/// use sqlx_data_macros::{dml, repo, Pool, Result};
39/// use sqlx::FromRow;
40///
41/// #[derive(FromRow)]
42/// struct User {
43///     id: i64,
44///     name: String,
45///     email: String,
46/// }
47///
48/// #[repo]
49/// trait UserRepo {
50///     // Simple query returning a single record
51///     #[dml("SELECT * FROM users WHERE id = $1")]
52///     async fn find_by_id(&self, id: i64) -> Result<Option<User>>;
53///
54///     // Query returning multiple records
55///     #[dml("SELECT * FROM users WHERE active = $1")]
56///     async fn find_active_users(&self, active: bool) -> Result<Vec<User>>;
57///
58///     // Insert/Update/Delete operations
59///     #[dml("INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id")]
60///     async fn create_user(&self, name: String, email: String) -> Result<i64>;
61///
62///     // Complex queries with joins
63///     #[dml("SELECT u.* FROM users u JOIN roles r ON u.role_id = r.id WHERE r.name = $1")]
64///     async fn find_users_by_role(&self, role_name: String) -> Result<Vec<User>>;
65///
66///     // Pagination-enabled queries (automatic LIMIT/OFFSET injection)
67///     #[dml("SELECT * FROM users ORDER BY created_at")]
68///     async fn find_all_paged(&self) -> Result<Vec<User>>;
69/// }
70/// ```
71///
72/// # Required Imports
73/// ```ignore
74/// use sqlx_data_macros::{dml, repo};
75/// use sqlx_data_integration::{Pool, Result, Database}; // Core database types
76/// use sqlx::FromRow; // For result mapping
77/// ```
78#[proc_macro_attribute]
79pub fn dml(args: TokenStream, input: TokenStream) -> TokenStream {
80    let method = parse_macro_input!(input as TraitItemFn);
81
82    // Parse the DML arguments directly
83    let parsed_method = match DmlParser::parse_dml_method_with_args(method, args, false) {
84        Ok(method) => method,
85        Err(error) => return error.to_compile_error().into(),
86    };
87
88    // Generate code
89    match CodeGenerator::generate_dml_methods(&parsed_method) {
90        Ok(tokens) => tokens.into(),
91        Err(error) => error.to_compile_error().into(),
92    }
93}
94
95/// Attribute macro for repositories - processes aliases, scopes and adds get_pool method
96#[proc_macro_attribute]
97pub fn repo(args: TokenStream, input: TokenStream) -> TokenStream {
98    let input_trait = parse_macro_input!(input as syn::ItemTrait);
99
100    match repo_system::RepoProcessor::process_trait_with_args(input_trait, args) {
101        Ok(tokens) => tokens.into(),
102        Err(error) => error.to_compile_error().into(),
103    }
104}
105
106/// Attribute macro for generating method variants with different executor types
107///
108/// This macro generates additional method variants that accept different executor types.
109/// Apply this macro to individual methods that need variants.
110///
111/// # Supported Variant Types
112/// - `pool` - Adds `pool: &sqlx_data::Pool` parameter and `_with_pool` suffix
113/// - `tx` - Adds `transaction: &mut sqlx_data::Transaction<'_>` parameter and `_with_tx` suffix
114/// - `conn` - Adds `connection: &mut sqlx_data::Connection` parameter and `_with_conn` suffix
115/// - `exec` - Adds `executor: impl sqlx_data::Executor<'_>` parameter and `_with_executor` suffix
116///
117/// # Usage
118/// ```ignore
119/// use sqlx_data_macros::{generate_versions, dml, repo};
120///
121/// #[repo]
122/// trait UserRepo {
123///     // Original method with variants
124///     #[generate_versions(pool, tx)]
125///     #[dml("DELETE FROM users WHERE id = $1")]
126///     async fn delete_user(&self, id: i64) -> Result<QueryResult>;
127/// }
128/// ```
129///
130/// # Generated Output
131/// The macro generates additional methods alongside the original:
132/// ```ignore
133/// // Original (preserved)
134/// #[dml("DELETE FROM users WHERE id = $1")]
135/// async fn delete_user(&self, id: i64) -> Result<QueryResult>;
136///
137/// // Generated variants
138/// #[dml("DELETE FROM users WHERE id = $1")]
139/// async fn delete_user_with_pool(&self, pool: &sqlx_data::Pool, id: i64) -> Result<QueryResult>;
140///
141/// #[dml("DELETE FROM users WHERE id = $1")]
142/// async fn delete_user_with_tx(&self, transaction: &mut sqlx_data::Transaction<'_>, id: i64) -> Result<QueryResult>;
143/// ```
144#[proc_macro_attribute]
145pub fn generate_versions(args: TokenStream, input: TokenStream) -> TokenStream {
146    let input_method = parse_macro_input!(input as TraitItemFn);
147    let args_tokens = proc_macro2::TokenStream::from(args);
148
149    match method_variants::expand_method_variants(input_method, args_tokens) {
150        Ok(tokens) => TokenStream::from(tokens),
151        Err(error) => error.to_compile_error().into(),
152    }
153}
154
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    /// Builds a `DmlMethod` directly, bypassing the `#[dml]` attribute parser.
    ///
    /// The synthesized trait method is named `method_name`. It takes `&self` followed by
    /// `parameters` in order, and returns `return_type`. It is declared `async` unless
    /// `return_type` is an `impl ...` type with a `Stream` bound. `kind` is always `Select`,
    /// so only SELECT statements should be passed in `sql`.
    ///
    /// # Panics
    /// Panics if `sqlx_data_parser::parse_sql` rejects `sql`.
    fn create_test_dml_method(
        method_name: &str,
        sql: &str,
        parameters: Vec<crate::dml::DmlParameter>,
        return_type: syn::Type,
    ) -> crate::dml::DmlMethod {
        use syn::{FnArg, Pat, PatIdent, PatType, Signature, TraitItemFn};

        let span = proc_macro2::Span::call_site();
        let mut inputs = syn::punctuated::Punctuated::new();

        // `&self` receiver.
        inputs.push(FnArg::Receiver(syn::Receiver {
            attrs: vec![],
            reference: Some((syn::Token![&](span), None)),
            mutability: None,
            self_token: syn::Token![self](span),
            colon_token: None,
            ty: Box::new(parse_quote! { Self }),
        }));

        // One typed `name: Type` argument per DML parameter, in declaration order.
        for param in &parameters {
            let pat = PatIdent {
                attrs: vec![],
                by_ref: None,
                mutability: None,
                ident: syn::Ident::new(&param.name, span),
                subpat: None,
            };

            inputs.push(FnArg::Typed(PatType {
                attrs: vec![],
                pat: Box::new(Pat::Ident(pat)),
                colon_token: syn::Token![:](span),
                ty: Box::new(param.type_.clone()),
            }));
        }

        // Streaming methods return `impl Stream<...>` and are therefore not `async`.
        let is_stream_type = matches!(&return_type, syn::Type::ImplTrait(impl_trait)
            if impl_trait.bounds.iter().any(|bound| matches!(
                bound,
                syn::TypeParamBound::Trait(trait_bound)
                    if trait_bound
                        .path
                        .segments
                        .last()
                        .is_some_and(|seg| seg.ident == "Stream")
            ))
        );

        let sig = Signature {
            constness: None,
            asyncness: if is_stream_type {
                None
            } else {
                Some(syn::Token![async](span))
            },
            unsafety: None,
            abi: None,
            fn_token: syn::Token![fn](span),
            ident: syn::Ident::new(method_name, span),
            generics: syn::Generics::default(),
            paren_token: syn::token::Paren::default(),
            inputs,
            variadic: None,
            output: syn::ReturnType::Type(syn::Token![->](span), Box::new(return_type)),
        };

        let trait_method = TraitItemFn {
            attrs: vec![],
            sig,
            default: None,
            semi_token: Some(syn::Token![;](span)),
        };

        crate::dml::DmlMethod {
            method: trait_method,
            sql_content: sql.to_string(),
            parameters,
            statement: sqlx_data_parser::parse_sql(sql).expect("test SQL should parse"),
            kind: sqlx_data_parser::SqlStatementType::Select,
            is_json_query: false,
            is_multi_insert: false,
            is_unchecked: false,
            has_explicit_instrument: false,
            trait_instrument: false,
            return_info_cache: std::sync::OnceLock::new(),
        }
    }

    /// Builds an ordinary positional parameter (not a pool, dynamic or generic parameter).
    fn plain_param(name: &str, type_: syn::Type) -> crate::dml::DmlParameter {
        crate::dml::DmlParameter {
            name: name.to_string(),
            type_,
            is_pool: false,
            is_dynamic_param: false,
            is_generic: false,
        }
    }

    /// Runs the code generator on `method` and returns the expansion as a string.
    ///
    /// On failure, the test fails and reports the generator's error, which a bare
    /// `assert!(result.is_ok())` would hide.
    fn generate(method: &crate::dml::DmlMethod) -> String {
        crate::code_generator::CodeGenerator::generate_dml_methods(method)
            .expect("code generation should succeed")
            .to_string()
    }

    #[test]
    fn test_dml_macro_basic() {
        let method = create_test_dml_method(
            "find_by_id",
            "SELECT * FROM users WHERE id = $1",
            vec![plain_param("id", parse_quote! { i64 })],
            parse_quote! { Result<User> },
        );

        let generated_code = generate(&method);
        assert!(generated_code.contains("find_by_id_query"));
        assert!(generated_code.contains("find_by_id"));
        // NOTE(review): `TokenStream::to_string()` renders paths with spaces
        // (`sqlx :: query_as !`, see `test_documentation_generation`), so this spaceless
        // match presumably hits the generated doc text rather than the code — confirm.
        assert!(generated_code.contains("sqlx::query_as!"));
    }

    #[test]
    fn test_dml_macro_with_flatten() {
        let method = create_test_dml_method(
            "get_birth_year",
            "SELECT birth_year FROM users WHERE id = $1",
            vec![plain_param("id", parse_quote! { i64 })],
            parse_quote! { Result<Option<i64>> },
        );

        let generated_code = generate(&method);
        assert!(generated_code.contains("get_birth_year_query"));
        assert!(generated_code.contains("get_birth_year"));
        // NOTE(review): spaceless match; see the note in `test_dml_macro_basic`.
        assert!(generated_code.contains("sqlx::query_scalar!"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_f32_casting() {
        let method = create_test_dml_method(
            "group_avg",
            "SELECT birth_year, AVG(age) as avg_age FROM users GROUP BY birth_year",
            vec![],
            parse_quote! { Result<Vec<(Option<u16>, f32)>> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code for tuple casting:\n{}", generated_code);

        // Verify that f64 -> f32 casting is present (SQLite returns f64 for AVG)
        assert!(generated_code.contains("as f32"));
        // Verify that i64 -> u16 casting is present for Option<u16> (SQLite uses i64 for integers)
        assert!(generated_code.contains("as u16"));
        assert!(generated_code.contains("group_avg_query"));
        assert!(generated_code.contains("QueryTuple"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_f64_casting() {
        let method = create_test_dml_method(
            "group_having_avg",
            "SELECT birth_year, AVG(age) as avg_age FROM users WHERE birth_year IS NOT NULL GROUP BY birth_year HAVING AVG(age) > $1",
            vec![plain_param("min_avg", parse_quote! { f32 })],
            parse_quote! { Result<Vec<(Option<u16>, f64)>> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code for f64 casting:\n{}", generated_code);

        // Verify that i64 -> u16 casting is present for Option<u16> (SQLite uses i64 for integers)
        assert!(generated_code.contains("as u16"));
        assert!(generated_code.contains("group_having_avg_query"));
        assert!(generated_code.contains("QueryTuple"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_i64_usize_casting() {
        let method = create_test_dml_method(
            "count_by_year",
            "SELECT birth_year, COUNT(*) as count FROM users GROUP BY birth_year",
            vec![],
            parse_quote! { Result<Vec<(Option<i64>, usize)>> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code for i64/usize casting:\n{}", generated_code);

        // i64 shouldn't need casting (i64 -> i64)
        // usize should need casting (i64 -> usize) in SQLite
        assert!(generated_code.contains("as usize"));
        assert!(generated_code.contains("count_by_year_query"));
    }

    #[test]
    fn test_documentation_generation() {
        let method = create_test_dml_method(
            "find_by_id",
            "SELECT * FROM users WHERE id = $1",
            vec![],
            parse_quote! { Result<User> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code:\n{}", generated_code);

        // Verify that documentation comment is generated on the public method
        assert!(generated_code.contains("# [doc = "));
        assert!(generated_code.contains("Generated by #[dml] macro:"));
        assert!(generated_code.contains("```rust"));
        assert!(generated_code.contains("find_by_id_query"));
        assert!(generated_code.contains("sqlx :: query_as !"));
    }
}