mini_query_derive/
lib.rs

1extern crate proc_macro;
2mod attrs;
3
4use attrs::{ContainerAttributes, FieldAttribute, ParseAttributes};
5use core::panic;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, Type};
9
10#[proc_macro_derive(MiniQuery, attributes(mini_query))]
11pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12  let input = parse_macro_input!(input as DeriveInput);
13
14  let struct_name = &input.ident;
15
16  let (_impl_generics, type_generics, where_clause) = &input.generics.split_for_impl();
17
18  let container_attributes = ContainerAttributes::parse_attributes("mini_query", &input.attrs).unwrap();
19
20  let Some(table_name) = &container_attributes.table_name else {
21    panic!("Expected table_name attr.");
22  };
23
24  let Data::Struct(DataStruct {
25    fields: Fields::Named(FieldsNamed { named: fields, .. }),
26    ..
27  }) = input.data
28  else {
29    panic!("Derive(MiniQuery) only applicable to named structs");
30  };
31
32  let mut token_stream = TokenStream::new();
33
34  let mut primary_key = None;
35  let mut field_tokens = Vec::new();
36  let mut field_names = Vec::new();
37  let mut from_impl = Vec::new();
38
39  for field in fields {
40    let field_attributes = FieldAttribute::parse_attributes(container_attributes.attribute(), &field.attrs).unwrap();
41
42    let ty = &field.ty;
43    let Some(field) = field.ident else {
44      return Err(Error::new_spanned(field, "field must be a named field")).unwrap();
45    };
46
47    let name = container_attributes.apply_to_field(&field.to_string());
48
49    // is this field is set as the primary key?
50    // denoted with #[mini_query(primary_key)]
51    if field_attributes.primary_key {
52      primary_key = Some((field.clone(), ty.clone()));
53      from_impl.push(TokenStream::from(quote! { #field: row.get(stringify!(#field)) }));
54      continue;
55    }
56
57    // this block will be skipped on this field if #[mini_query(skip)] is set
58    if let Some(name) = field_attributes.apply_to_field(&name) {
59      field_names.push(name.clone());
60
61      // is the field in question being casted when sent to / from database?
62      // denoted with #[mini_query(cast = i16)]
63      if let Some(cast) = field_attributes.cast.clone() {
64        from_impl.push(TokenStream::from(quote! { #field: row.get::<'a, &str, #cast>(#name).into() }));
65        field_tokens.push(TokenStream::from(quote! { #field as #cast }));
66      } else {
67        from_impl.push(TokenStream::from(quote! { #field: row.get(#name) }));
68        field_tokens.push(TokenStream::from(quote! { #field }));
69      }
70    }
71
72    // build out the get_by_x functions
73    // denoted with #[mini_query(get_by)]
74    if field_attributes.get_by {
75      let field = field.clone();
76      let query = format!("SELECT * FROM {table_name} WHERE {name} = $1");
77      let get_by_fn_name = format_ident!("get_by_{}", field);
78
79      let mut cast = TokenStream::new();
80      if let Some(ty) = field_attributes.cast {
81        cast.extend(TokenStream::from(quote! { as #ty }));
82      }
83
84      // Make things simpler for string field types
85      // Allow the passing of anythign that implements AsRef<str>
86      let (field_type, field_fetch) = coalesce_types(ty);
87
88      token_stream.extend(TokenStream::from(quote! {
89        impl #struct_name #type_generics #where_clause {
90          pub async fn #get_by_fn_name(client: &impl GenericClient, field: #field_type) -> anyhow::Result<Vec<Self>> {
91            #field_fetch
92            let recs = client.query(#query, &[&(*field #cast)]).await?.iter().map(Self::from).collect();
93            Ok(recs)
94          }
95        }
96      }));
97    }
98
99    // Build out the find_by_x functions
100    // Denoted with #[mini_query(find_by)]
101    if field_attributes.find_by {
102      let field = field.clone();
103      let query = format!("SELECT * FROM {table_name} WHERE {name} = $1");
104      let find_by_fn_name = format_ident!("find_by_{}", field);
105
106      // Make things simpler for string field types
107      // Allow the passing of anythign that implements AsRef<str>
108      let (field_type, field_fetch) = coalesce_types(&ty);
109
110      token_stream.extend(TokenStream::from(quote! {
111        impl #struct_name #type_generics #where_clause {
112          pub async fn #find_by_fn_name(client: &impl GenericClient, field: #field_type) -> anyhow::Result<Option<Self>> {
113            #field_fetch
114            let rec = client.query_opt(#query, &[&field]).await?.map(Self::from);
115            Ok(rec)
116          }
117        }
118      }));
119    }
120  }
121
122  let len = field_names.len();
123
124  let ts = TokenStream::from(quote! {
125      impl From<tokio_postgres::Row> for #struct_name #type_generics #where_clause {
126        fn from(row: tokio_postgres::Row) -> Self {
127          Self::from(&row)
128        }
129      }
130      impl<'a> From<&'a tokio_postgres::Row> for #struct_name #type_generics #where_clause {
131        fn from(row: &'a tokio_postgres::Row) -> Self {
132          Self {
133            #(#from_impl),*,
134            ..Default::default()
135          }
136        }
137      }
138  });
139  token_stream.extend(ts);
140
141  let ts = {
142    let field_tokens = field_tokens.clone();
143
144    let dollar_signs: String = (1..=len).map(|i| format!("${i}")).collect::<Vec<String>>().join(",");
145    let insert_query = format!("INSERT INTO {table_name} ({}) VALUES ({dollar_signs})", field_names.join(","));
146    let insert_query_returning = format!("{} RETURNING *", &insert_query);
147    let all_query = format!("SELECT * FROM {table_name}");
148
149    TokenStream::from(quote! {
150      impl #struct_name #type_generics #where_clause {
151        #[doc=concat!("Generated array of field names for `", stringify!(#struct_name #type_generics), "`.")]
152        const FIELD_NAMES: [&'static str; #len] = [#(#field_names),*];
153        pub const __TABLE_NAME__: &'static str = #table_name;
154
155        pub async fn all(client: &impl GenericClient) -> anyhow::Result<Vec<Self>> {
156          let recs = client.query(#all_query, &[]).await?.iter().map(Self::from).collect();
157          Ok(recs)
158        }
159
160        pub async fn quick_insert(&self, client: &impl GenericClient) -> anyhow::Result<Self> {
161          let rec = client.query_one(
162            #insert_query_returning,
163            &[#(&(self.#field_tokens)),*]
164          ).await?;
165
166          Ok(Self::from(rec))
167        }
168
169        pub async fn quick_insert_no_return(&self, client: &impl GenericClient) -> anyhow::Result<()> {
170          client
171            .query(
172              #insert_query,
173              &[#(&(self.#field_tokens)),*]
174            ).await?;
175
176          Ok(())
177        }
178      }
179    })
180  };
181  token_stream.extend(ts);
182
183  if let Some((ident, ty)) = primary_key {
184    let query = format!("SELECT * FROM {} WHERE {} = $1", table_name, ident);
185    let update_query = field_names
186      .iter()
187      .enumerate()
188      .map(|(i, name)| format!("{name}=${}", i + 2))
189      .collect::<Vec<String>>()
190      .join(",");
191    let update_query = format!("UPDATE {table_name} SET {update_query} WHERE id=$1 RETURNING *");
192
193    let ts = TokenStream::from(quote! {
194      impl #struct_name #type_generics #where_clause {
195        pub const __PRIMARY_KEY__: &'static str = stringify!(#ident);
196
197        pub async fn get(client: &impl GenericClient, id: &#ty) -> anyhow::Result<Option<Self>> {
198          let rec = client.query_opt(#query, &[&id]).await?.map(Self::from);
199
200          Ok(rec)
201        }
202
203        pub async fn quick_update(&self, client: &impl GenericClient) -> anyhow::Result<Self> {
204          let rec = client.query_one(#update_query, &[&self.#ident, #(&(self.#field_tokens)),*]).await?;
205
206          Ok(Self::from(rec))
207        }
208      }
209    });
210    token_stream.extend(ts);
211  }
212
213  token_stream.into()
214}
215
216fn coalesce_types(ty: &Type) -> (TokenStream, TokenStream) {
217  let rasterized_type = ty.clone().into_token_stream().to_string();
218  match rasterized_type.as_ref() {
219    "String" => (
220      TokenStream::from(quote! { impl AsRef<str> }),
221      TokenStream::from(quote! { let field = field.as_ref(); }),
222    ),
223    _ => (TokenStream::from(quote! { &#ty }), TokenStream::new()),
224  }
225}