use proc_macro::TokenStream;
use syn::{TraitItemFn, parse_macro_input};
mod alias_system;
mod code_generator;
mod constants;
mod dml;
mod error;
mod fetch;
mod method_variants;
mod repo_system;
mod scope_system;
mod type_analyzer;
mod type_system;
#[cfg(test)]
mod test_framework;
use code_generator::CodeGenerator;
use dml::DmlParser;
/// Attribute macro placed on a single trait method (`#[dml(...)]`).
///
/// The decorated item is parsed as a [`TraitItemFn`] and handed, together with
/// the attribute arguments, to `DmlParser::parse_dml_method_with_args`. The
/// parsed method is then expanded by `CodeGenerator::generate_dml_methods`.
/// If either stage fails, its error is converted with `to_compile_error()` and
/// returned in place of the expansion.
#[proc_macro_attribute]
pub fn dml(args: TokenStream, input: TokenStream) -> TokenStream {
    let trait_fn = parse_macro_input!(input as TraitItemFn);

    // NOTE(review): the trailing `false` is a flag interpreted by `DmlParser`;
    // its meaning is not visible from this file — confirm against the parser.
    let dml_method = match DmlParser::parse_dml_method_with_args(trait_fn, args, false) {
        Ok(dml_method) => dml_method,
        Err(parse_err) => return parse_err.to_compile_error().into(),
    };

    CodeGenerator::generate_dml_methods(&dml_method)
        .map_or_else(|gen_err| gen_err.to_compile_error().into(), Into::into)
}
/// Attribute macro placed on a trait (`#[repo(...)]`).
///
/// The decorated item is parsed as a [`syn::ItemTrait`] and delegated, with the
/// raw attribute arguments, to `repo_system::RepoProcessor::process_trait_with_args`.
/// A processing error is converted with `to_compile_error()` and returned in
/// place of the expansion.
#[proc_macro_attribute]
pub fn repo(args: TokenStream, input: TokenStream) -> TokenStream {
    let repo_trait = parse_macro_input!(input as syn::ItemTrait);
    repo_system::RepoProcessor::process_trait_with_args(repo_trait, args)
        .map_or_else(|proc_err| proc_err.to_compile_error().into(), Into::into)
}
/// Attribute macro placed on a single trait method.
///
/// The attribute arguments are converted to a `proc_macro2::TokenStream` and,
/// together with the parsed [`TraitItemFn`], passed to
/// `method_variants::expand_method_variants`. An expansion error is converted
/// with `to_compile_error()` and returned in place of the expansion.
#[proc_macro_attribute]
pub fn generate_versions(args: TokenStream, input: TokenStream) -> TokenStream {
    let method = parse_macro_input!(input as TraitItemFn);
    let attr_tokens: proc_macro2::TokenStream = args.into();

    match method_variants::expand_method_variants(method, attr_tokens) {
        Err(expand_err) => expand_err.to_compile_error().into(),
        Ok(expanded) => expanded.into(),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for DML code generation. Each test builds a `DmlMethod`
    //! fixture by hand and inspects the stringified output of
    //! `CodeGenerator::generate_dml_methods`.

    use crate::code_generator::CodeGenerator;
    use crate::dml::{DmlMethod, DmlParameter};
    use syn::parse_quote;

    /// Builds an ordinary method parameter: not a pool, not dynamic, not generic.
    fn plain_param(name: &str, type_: syn::Type) -> DmlParameter {
        DmlParameter {
            name: name.to_string(),
            type_,
            is_pool: false,
            is_dynamic_param: false,
            is_generic: false,
        }
    }

    /// Returns true when `ty` is an `impl ...` type with a trait bound whose last
    /// path segment is `Stream` (e.g. `impl Stream<Item = T>` or
    /// `impl futures::Stream<Item = T>`).
    fn is_stream_return(ty: &syn::Type) -> bool {
        let syn::Type::ImplTrait(impl_trait) = ty else {
            return false;
        };
        impl_trait.bounds.iter().any(|bound| {
            matches!(
                bound,
                syn::TypeParamBound::Trait(trait_bound)
                    if trait_bound
                        .path
                        .segments
                        .last()
                        .is_some_and(|seg| seg.ident == "Stream")
            )
        })
    }

    /// Builds a `DmlMethod` fixture named `method_name` that takes `&self`
    /// followed by `parameters`, returns `return_type`, and carries `sql`.
    ///
    /// The method is declared `async fn` unless `return_type` is an
    /// `impl Stream` type, in which case it is a plain `fn`. The fixture is
    /// always tagged as a `Select` statement with every optional flag off.
    ///
    /// # Panics
    /// Panics if `sql` cannot be parsed by `sqlx_data_parser::parse_sql`.
    fn create_test_dml_method(
        method_name: &str,
        sql: &str,
        parameters: Vec<DmlParameter>,
        return_type: syn::Type,
    ) -> DmlMethod {
        use syn::{FnArg, Pat, PatIdent, PatType, Signature, TraitItemFn};

        let span = proc_macro2::Span::call_site();

        let mut inputs = syn::punctuated::Punctuated::new();
        // `&self` receiver.
        inputs.push(FnArg::Receiver(syn::Receiver {
            attrs: vec![],
            reference: Some((syn::Token![&](span), None)),
            mutability: None,
            self_token: syn::Token![self](span),
            colon_token: None,
            ty: Box::new(parse_quote! { Self }),
        }));
        for param in &parameters {
            inputs.push(FnArg::Typed(PatType {
                attrs: vec![],
                pat: Box::new(Pat::Ident(PatIdent {
                    attrs: vec![],
                    by_ref: None,
                    mutability: None,
                    ident: syn::Ident::new(&param.name, span),
                    subpat: None,
                })),
                colon_token: syn::Token![:](span),
                ty: Box::new(param.type_.clone()),
            }));
        }

        let sig = Signature {
            constness: None,
            // Stream-returning methods are plain `fn`s; everything else is `async fn`.
            asyncness: if is_stream_return(&return_type) {
                None
            } else {
                Some(syn::Token![async](span))
            },
            unsafety: None,
            abi: None,
            fn_token: syn::Token![fn](span),
            ident: syn::Ident::new(method_name, span),
            generics: syn::Generics::default(),
            paren_token: syn::token::Paren::default(),
            inputs,
            variadic: None,
            output: syn::ReturnType::Type(syn::Token![->](span), Box::new(return_type)),
        };

        // A trait method declaration without a default body (`fn ...;`).
        let trait_method = TraitItemFn {
            attrs: vec![],
            sig,
            default: None,
            semi_token: Some(syn::Token![;](span)),
        };

        DmlMethod {
            method: trait_method,
            sql_content: sql.to_string(),
            parameters,
            statement: sqlx_data_parser::parse_sql(sql).expect("test SQL must parse"),
            kind: sqlx_data_parser::SqlStatementType::Select,
            is_json_query: false,
            is_multi_insert: false,
            is_unchecked: false,
            has_explicit_instrument: false,
            trait_instrument: false,
            return_info_cache: std::sync::OnceLock::new(),
        }
    }

    /// Runs the generator on `method` and returns the stringified tokens,
    /// panicking with the generator's error (instead of a bare `is_ok()`
    /// failure) if generation fails.
    fn generate(method: &DmlMethod) -> String {
        CodeGenerator::generate_dml_methods(method)
            .expect("code generation should succeed")
            .to_string()
    }

    #[test]
    fn test_dml_macro_basic() {
        let method = create_test_dml_method(
            "find_by_id",
            "SELECT * FROM users WHERE id = $1",
            vec![plain_param("id", parse_quote! { i64 })],
            parse_quote! { Result<User> },
        );
        let generated_code = generate(&method);
        assert!(generated_code.contains("find_by_id_query"));
        assert!(generated_code.contains("find_by_id"));
        assert!(generated_code.contains("sqlx::query_as!"));
    }

    #[test]
    fn test_dml_macro_with_flatten() {
        let method = create_test_dml_method(
            "get_birth_year",
            "SELECT birth_year FROM users WHERE id = $1",
            vec![plain_param("id", parse_quote! { i64 })],
            parse_quote! { Result<Option<i64>> },
        );
        let generated_code = generate(&method);
        assert!(generated_code.contains("get_birth_year_query"));
        assert!(generated_code.contains("get_birth_year"));
        assert!(generated_code.contains("sqlx::query_scalar!"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_f32_casting() {
        let method = create_test_dml_method(
            "group_avg",
            "SELECT birth_year, AVG(age) as avg_age FROM users GROUP BY birth_year",
            vec![],
            parse_quote! { Result<Vec<(Option<u16>, f32)>> },
        );
        let generated_code = generate(&method);
        eprintln!("Generated Code for tuple casting:\n{}", generated_code);
        assert!(generated_code.contains("as f32"));
        assert!(generated_code.contains("as u16"));
        assert!(generated_code.contains("group_avg_query"));
        assert!(generated_code.contains("QueryTuple"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_f64_casting() {
        let method = create_test_dml_method(
            "group_having_avg",
            "SELECT birth_year, AVG(age) as avg_age FROM users WHERE birth_year IS NOT NULL GROUP BY birth_year HAVING AVG(age) > $1",
            vec![plain_param("min_avg", parse_quote! { f32 })],
            parse_quote! { Result<Vec<(Option<u16>, f64)>> },
        );
        let generated_code = generate(&method);
        eprintln!("Generated Code for f64 casting:\n{}", generated_code);
        assert!(generated_code.contains("as u16"));
        assert!(generated_code.contains("group_having_avg_query"));
        assert!(generated_code.contains("QueryTuple"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_i64_usize_casting() {
        let method = create_test_dml_method(
            "count_by_year",
            "SELECT birth_year, COUNT(*) as count FROM users GROUP BY birth_year",
            vec![],
            parse_quote! { Result<Vec<(Option<i64>, usize)>> },
        );
        let generated_code = generate(&method);
        eprintln!("Generated Code for i64/usize casting:\n{}", generated_code);
        assert!(generated_code.contains("as usize"));
        assert!(generated_code.contains("count_by_year_query"));
    }

    #[test]
    fn test_documentation_generation() {
        let method = create_test_dml_method(
            "find_by_id",
            "SELECT * FROM users WHERE id = $1",
            vec![],
            parse_quote! { Result<User> },
        );
        let generated_code = generate(&method);
        eprintln!("Generated Code:\n{}", generated_code);
        assert!(generated_code.contains("# [doc = "));
        assert!(generated_code.contains("Generated by #[dml] macro:"));
        assert!(generated_code.contains("```rust"));
        assert!(generated_code.contains("find_by_id_query"));
        assert!(generated_code.contains("sqlx :: query_as !"));
    }
}