1extern crate proc_macro;
2mod attrs;
3
4use attrs::{ContainerAttributes, FieldAttribute, ParseAttributes};
5use core::panic;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, Type};
9
10#[proc_macro_derive(MiniQuery, attributes(mini_query))]
11pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12 let input = parse_macro_input!(input as DeriveInput);
13
14 let struct_name = &input.ident;
15
16 let (_impl_generics, type_generics, where_clause) = &input.generics.split_for_impl();
17
18 let container_attributes = ContainerAttributes::parse_attributes("mini_query", &input.attrs).unwrap();
19
20 let Some(table_name) = &container_attributes.table_name else {
21 panic!("Expected table_name attr.");
22 };
23
24 let Data::Struct(DataStruct {
25 fields: Fields::Named(FieldsNamed { named: fields, .. }),
26 ..
27 }) = input.data
28 else {
29 panic!("Derive(MiniQuery) only applicable to named structs");
30 };
31
32 let mut token_stream = TokenStream::new();
33
34 let mut primary_key = None;
35 let mut field_tokens = Vec::new();
36 let mut field_names = Vec::new();
37 let mut from_impl = Vec::new();
38
39 for field in fields {
40 let field_attributes = FieldAttribute::parse_attributes(container_attributes.attribute(), &field.attrs).unwrap();
41
42 let ty = &field.ty;
43 let Some(field) = field.ident else {
44 return Err(Error::new_spanned(field, "field must be a named field")).unwrap();
45 };
46
47 let name = container_attributes.apply_to_field(&field.to_string());
48
49 if field_attributes.primary_key {
52 primary_key = Some((field.clone(), ty.clone()));
53 from_impl.push(TokenStream::from(quote! { #field: row.get(stringify!(#field)) }));
54 continue;
55 }
56
57 if let Some(name) = field_attributes.apply_to_field(&name) {
59 field_names.push(name.clone());
60
61 if let Some(cast) = field_attributes.cast.clone() {
64 from_impl.push(TokenStream::from(quote! { #field: row.get::<'a, &str, #cast>(#name).into() }));
65 field_tokens.push(TokenStream::from(quote! { #field as #cast }));
66 } else {
67 from_impl.push(TokenStream::from(quote! { #field: row.get(#name) }));
68 field_tokens.push(TokenStream::from(quote! { #field }));
69 }
70 }
71
72 if field_attributes.get_by {
75 let field = field.clone();
76 let query = format!("SELECT * FROM {table_name} WHERE {name} = $1");
77 let get_by_fn_name = format_ident!("get_by_{}", field);
78
79 let mut cast = TokenStream::new();
80 if let Some(ty) = field_attributes.cast {
81 cast.extend(TokenStream::from(quote! { as #ty }));
82 }
83
84 let (field_type, field_fetch) = coalesce_types(ty);
87
88 token_stream.extend(TokenStream::from(quote! {
89 impl #struct_name #type_generics #where_clause {
90 pub async fn #get_by_fn_name(client: &impl GenericClient, field: #field_type) -> anyhow::Result<Vec<Self>> {
91 #field_fetch
92 let recs = client.query(#query, &[&(*field #cast)]).await?.iter().map(Self::from).collect();
93 Ok(recs)
94 }
95 }
96 }));
97 }
98
99 if field_attributes.find_by {
102 let field = field.clone();
103 let query = format!("SELECT * FROM {table_name} WHERE {name} = $1");
104 let find_by_fn_name = format_ident!("find_by_{}", field);
105
106 let (field_type, field_fetch) = coalesce_types(&ty);
109
110 token_stream.extend(TokenStream::from(quote! {
111 impl #struct_name #type_generics #where_clause {
112 pub async fn #find_by_fn_name(client: &impl GenericClient, field: #field_type) -> anyhow::Result<Option<Self>> {
113 #field_fetch
114 let rec = client.query_opt(#query, &[&field]).await?.map(Self::from);
115 Ok(rec)
116 }
117 }
118 }));
119 }
120 }
121
122 let len = field_names.len();
123
124 let ts = TokenStream::from(quote! {
125 impl From<tokio_postgres::Row> for #struct_name #type_generics #where_clause {
126 fn from(row: tokio_postgres::Row) -> Self {
127 Self::from(&row)
128 }
129 }
130 impl<'a> From<&'a tokio_postgres::Row> for #struct_name #type_generics #where_clause {
131 fn from(row: &'a tokio_postgres::Row) -> Self {
132 Self {
133 #(#from_impl),*,
134 ..Default::default()
135 }
136 }
137 }
138 });
139 token_stream.extend(ts);
140
141 let ts = {
142 let field_tokens = field_tokens.clone();
143
144 let dollar_signs: String = (1..=len).map(|i| format!("${i}")).collect::<Vec<String>>().join(",");
145 let insert_query = format!("INSERT INTO {table_name} ({}) VALUES ({dollar_signs})", field_names.join(","));
146 let insert_query_returning = format!("{} RETURNING *", &insert_query);
147 let all_query = format!("SELECT * FROM {table_name}");
148
149 TokenStream::from(quote! {
150 impl #struct_name #type_generics #where_clause {
151 #[doc=concat!("Generated array of field names for `", stringify!(#struct_name #type_generics), "`.")]
152 const FIELD_NAMES: [&'static str; #len] = [#(#field_names),*];
153 pub const __TABLE_NAME__: &'static str = #table_name;
154
155 pub async fn all(client: &impl GenericClient) -> anyhow::Result<Vec<Self>> {
156 let recs = client.query(#all_query, &[]).await?.iter().map(Self::from).collect();
157 Ok(recs)
158 }
159
160 pub async fn quick_insert(&self, client: &impl GenericClient) -> anyhow::Result<Self> {
161 let rec = client.query_one(
162 #insert_query_returning,
163 &[#(&(self.#field_tokens)),*]
164 ).await?;
165
166 Ok(Self::from(rec))
167 }
168
169 pub async fn quick_insert_no_return(&self, client: &impl GenericClient) -> anyhow::Result<()> {
170 client
171 .query(
172 #insert_query,
173 &[#(&(self.#field_tokens)),*]
174 ).await?;
175
176 Ok(())
177 }
178 }
179 })
180 };
181 token_stream.extend(ts);
182
183 if let Some((ident, ty)) = primary_key {
184 let query = format!("SELECT * FROM {} WHERE {} = $1", table_name, ident);
185 let update_query = field_names
186 .iter()
187 .enumerate()
188 .map(|(i, name)| format!("{name}=${}", i + 2))
189 .collect::<Vec<String>>()
190 .join(",");
191 let update_query = format!("UPDATE {table_name} SET {update_query} WHERE id=$1 RETURNING *");
192
193 let ts = TokenStream::from(quote! {
194 impl #struct_name #type_generics #where_clause {
195 pub const __PRIMARY_KEY__: &'static str = stringify!(#ident);
196
197 pub async fn get(client: &impl GenericClient, id: &#ty) -> anyhow::Result<Option<Self>> {
198 let rec = client.query_opt(#query, &[&id]).await?.map(Self::from);
199
200 Ok(rec)
201 }
202
203 pub async fn quick_update(&self, client: &impl GenericClient) -> anyhow::Result<Self> {
204 let rec = client.query_one(#update_query, &[&self.#ident, #(&(self.#field_tokens)),*]).await?;
205
206 Ok(Self::from(rec))
207 }
208 }
209 });
210 token_stream.extend(ts);
211 }
212
213 token_stream.into()
214}
215
216fn coalesce_types(ty: &Type) -> (TokenStream, TokenStream) {
217 let rasterized_type = ty.clone().into_token_stream().to_string();
218 match rasterized_type.as_ref() {
219 "String" => (
220 TokenStream::from(quote! { impl AsRef<str> }),
221 TokenStream::from(quote! { let field = field.as_ref(); }),
222 ),
223 _ => (TokenStream::from(quote! { &#ty }), TokenStream::new()),
224 }
225}