1mod pg_composite;
4mod pg_domain;
5mod pg_enum;
6
7pub(crate) fn consume_unknown_meta_value(meta: &syn::meta::ParseNestedMeta) -> syn::Result<()> {
11 if meta.input.peek(syn::Token![=]) {
12 let _ = meta.value()?.parse::<syn::Expr>()?;
13 }
14 Ok(())
15}
16
17use proc_macro::TokenStream;
18use quote::quote;
19use syn::{parse_macro_input, Data, DeriveInput, Fields, LitStr};
20
21#[proc_macro_derive(FromRow, attributes(from_row))]
32pub fn derive_from_row(input: TokenStream) -> TokenStream {
33 let input = parse_macro_input!(input as DeriveInput);
34 match derive_from_row_inner(input) {
35 Ok(tokens) => tokens.into(),
36 Err(err) => err.to_compile_error().into(),
37 }
38}
39
40fn derive_from_row_inner(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
41 let name = &input.ident;
42 let generics = &input.generics;
43 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
44
45 let fields = match &input.data {
46 Data::Struct(data) => match &data.fields {
47 Fields::Named(fields) => &fields.named,
48 _ => {
49 return Err(syn::Error::new_spanned(
50 &input,
51 "FromRow only supports structs with named fields",
52 ));
53 }
54 },
55 _ => {
56 return Err(syn::Error::new_spanned(
57 &input,
58 "FromRow only supports structs",
59 ));
60 }
61 };
62
63 let field_extractions = fields
64 .iter()
65 .map(|field| {
66 let field_name = field.ident.as_ref().unwrap();
67 let field_type = &field.ty;
68 let attrs = FromRowFieldAttrs::parse(field)?;
69
70 let col_name = attrs.rename.unwrap_or_else(|| field_name.to_string());
71
72 if attrs.skip {
73 return Ok(quote! { #field_name: Default::default() });
74 }
75
76 if attrs.flatten {
77 return Ok(quote! {
78 #field_name: <#field_type as resolute::FromRow>::from_row(row)?
79 });
80 }
81
82 if let Some(ref source_type) = attrs.try_from {
83 if is_option_type(field_type) {
84 return Ok(quote! {
85 #field_name: {
86 let __opt: Option<#source_type> = row.get_opt_by_name(#col_name)?;
87 match __opt {
88 Some(__src) => Some(
89 <_ as std::convert::TryFrom<#source_type>>::try_from(__src)
90 .map_err(|e| resolute::TypedError::Decode {
91 column: 0,
92 message: format!("try_from({}): {}", #col_name, e),
93 })?
94 ),
95 None => None,
96 }
97 }
98 });
99 } else {
100 return Ok(quote! {
101 #field_name: {
102 let __src: #source_type = row.get_by_name(#col_name)?;
103 <#field_type as std::convert::TryFrom<#source_type>>::try_from(__src)
104 .map_err(|e| resolute::TypedError::Decode {
105 column: 0,
106 message: format!("try_from({}): {}", #col_name, e),
107 })?
108 }
109 });
110 }
111 }
112
113 if attrs.json {
114 if is_option_type(field_type) {
115 return Ok(quote! {
116 #field_name: {
117 let __opt: Option<serde_json::Value> = row.get_opt_by_name(#col_name)?;
118 match __opt {
119 Some(__v) => Some(
120 serde_json::from_value(__v).map_err(|e| resolute::TypedError::Decode {
121 column: 0,
122 message: format!("json({}): {}", #col_name, e),
123 })?
124 ),
125 None => None,
126 }
127 }
128 });
129 } else {
130 return Ok(quote! {
131 #field_name: {
132 let __v: serde_json::Value = row.get_by_name(#col_name)?;
133 serde_json::from_value(__v).map_err(|e| resolute::TypedError::Decode {
134 column: 0,
135 message: format!("json({}): {}", #col_name, e),
136 })?
137 }
138 });
139 }
140 }
141
142 if attrs.default {
143 if is_option_type(field_type) {
144 return Ok(quote! {
145 #field_name: if row.has_column(#col_name) {
146 row.get_opt_by_name(#col_name)?
147 } else {
148 None
149 }
150 });
151 } else {
152 return Ok(quote! {
153 #field_name: if row.has_column(#col_name) {
154 match row.get_by_name(#col_name) {
155 Ok(v) => v,
156 Err(resolute::TypedError::UnexpectedNull(_)) => Default::default(),
157 Err(e) => return Err(e),
158 }
159 } else {
160 Default::default()
161 }
162 });
163 }
164 }
165
166 if is_option_type(field_type) {
168 Ok(quote! { #field_name: row.get_opt_by_name(#col_name)? })
169 } else {
170 Ok(quote! { #field_name: row.get_by_name(#col_name)? })
171 }
172 })
173 .collect::<syn::Result<Vec<_>>>()?;
174
175 Ok(quote! {
176 impl #impl_generics resolute::FromRow for #name #ty_generics #where_clause {
177 fn from_row(row: &resolute::Row) -> Result<Self, resolute::TypedError> {
178 Ok(Self {
179 #(#field_extractions,)*
180 })
181 }
182 }
183 })
184}
185
186struct FromRowFieldAttrs {
192 rename: Option<String>,
193 skip: bool,
194 default: bool,
195 json: bool,
196 try_from: Option<syn::Type>,
197 flatten: bool,
198}
199
200impl FromRowFieldAttrs {
201 fn parse(field: &syn::Field) -> syn::Result<Self> {
202 let mut attrs = Self {
203 rename: None,
204 skip: false,
205 default: false,
206 json: false,
207 try_from: None,
208 flatten: false,
209 };
210
211 for attr in &field.attrs {
212 if !attr.path().is_ident("from_row") {
213 continue;
214 }
215 attr.parse_nested_meta(|meta| {
216 if meta.path.is_ident("rename") {
217 let value = meta.value()?;
218 let s: LitStr = value.parse()?;
219 attrs.rename = Some(s.value());
220 } else if meta.path.is_ident("skip") {
221 attrs.skip = true;
222 } else if meta.path.is_ident("default") {
223 attrs.default = true;
224 } else if meta.path.is_ident("json") {
225 attrs.json = true;
226 } else if meta.path.is_ident("try_from") {
227 let value = meta.value()?;
228 let s: LitStr = value.parse()?;
229 let ty: syn::Type = syn::parse_str(&s.value()).map_err(|e| {
230 syn::Error::new(
231 s.span(),
232 format!("from_row(try_from = \"...\") must be a valid Rust type: {e}"),
233 )
234 })?;
235 attrs.try_from = Some(ty);
236 } else if meta.path.is_ident("flatten") {
237 attrs.flatten = true;
238 } else {
239 return Err(meta.error("unknown from_row attribute"));
240 }
241 Ok(())
242 })?;
243 }
244
245 if attrs.skip
247 && (attrs.rename.is_some()
248 || attrs.default
249 || attrs.json
250 || attrs.try_from.is_some()
251 || attrs.flatten)
252 {
253 return Err(syn::Error::new_spanned(
254 field,
255 "from_row(skip) cannot be combined with other attributes",
256 ));
257 }
258 if attrs.flatten && (attrs.rename.is_some() || attrs.json || attrs.try_from.is_some()) {
259 return Err(syn::Error::new_spanned(
260 field,
261 "from_row(flatten) cannot be combined with rename, json, or try_from",
262 ));
263 }
264 if attrs.json && attrs.try_from.is_some() {
265 return Err(syn::Error::new_spanned(
266 field,
267 "from_row(json) cannot be combined with try_from",
268 ));
269 }
270
271 Ok(attrs)
272 }
273}
274
275fn is_option_type(ty: &syn::Type) -> bool {
277 if let syn::Type::Path(type_path) = ty {
278 if let Some(seg) = type_path.path.segments.last() {
279 return seg.ident == "Option";
280 }
281 }
282 false
283}
284
285#[proc_macro_derive(PgEnum, attributes(pg_type))]
288pub fn derive_pg_enum(input: TokenStream) -> TokenStream {
289 let input = parse_macro_input!(input as DeriveInput);
290 pg_enum::derive(input)
291}
292
293#[proc_macro_derive(PgComposite, attributes(pg_type))]
296pub fn derive_pg_composite(input: TokenStream) -> TokenStream {
297 let input = parse_macro_input!(input as DeriveInput);
298 pg_composite::derive(input)
299}
300
301#[proc_macro_derive(PgDomain, attributes(pg_type))]
304pub fn derive_pg_domain(input: TokenStream) -> TokenStream {
305 let input = parse_macro_input!(input as DeriveInput);
306 pg_domain::derive(input)
307}
308
309#[proc_macro_attribute]
327pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
328 let input_fn = parse_macro_input!(item as syn::ItemFn);
329
330 let mut migrations: Option<String> = None;
331 let attr_parser = syn::meta::parser(|meta| {
332 if meta.path.is_ident("migrations") {
333 let value = meta.value()?;
334 let s: LitStr = value.parse()?;
335 migrations = Some(s.value());
336 Ok(())
337 } else {
338 Err(meta.error("unknown resolute::test attribute"))
339 }
340 });
341 parse_macro_input!(attr with attr_parser);
342
343 let fn_name = &input_fn.sig.ident;
344 let fn_block = &input_fn.block;
345 let fn_vis = &input_fn.vis;
346 let fn_attrs = &input_fn.attrs;
347
348 let create_db = if let Some(mig_path) = &migrations {
349 quote! {
350 let __test_db = resolute::test_db::TestDb::create_with_migrations(
351 &__addr, &__user, &__pass, #mig_path,
352 ).await.expect("failed to create test database");
353 }
354 } else {
355 quote! {
356 let __test_db = resolute::test_db::TestDb::create(
357 &__addr, &__user, &__pass,
358 ).await.expect("failed to create test database");
359 }
360 };
361
362 let expanded = quote! {
363 #(#fn_attrs)*
364 #[tokio::test]
365 #fn_vis async fn #fn_name() {
366 let __addr = resolute::test_db::test_addr().to_string();
370 let __user = resolute::test_db::test_user().to_string();
371 let __pass = resolute::test_db::test_password().to_string();
372
373 #create_db
374
375 let client = __test_db.client().await.expect("failed to connect to test database");
376
377 let __result = async { #fn_block }.await;
379
380 drop(client);
382 let _ = __test_db.drop_db().await;
383 }
384 };
385
386 TokenStream::from(expanded)
387}