postgres_from_row_derive/
lib.rs1use darling::{ast::Data, Error, FromDeriveInput, FromField, ToTokens};
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput, Result};
6
7#[proc_macro_derive(FromRow, attributes(from_row))]
9pub fn derive_from_row(input: TokenStream) -> TokenStream {
10 let derive_input = parse_macro_input!(input as DeriveInput);
11 match try_derive_from_row(&derive_input) {
12 Ok(result) => result,
13 Err(err) => err.write_errors().into(),
14 }
15}
16
17fn try_derive_from_row(input: &DeriveInput) -> std::result::Result<TokenStream, Error> {
19 let from_row_derive = DeriveFromRow::from_derive_input(input)?;
20 Ok(from_row_derive.generate()?)
21}
22
23#[derive(Debug, FromDeriveInput)]
25#[darling(
26 attributes(from_row),
27 forward_attrs(allow, doc, cfg),
28 supports(struct_named)
29)]
30struct DeriveFromRow {
31 ident: syn::Ident,
32 generics: syn::Generics,
33 data: Data<(), FromRowField>,
34}
35
36impl DeriveFromRow {
37 fn validate(&self) -> Result<()> {
39 for field in self.fields() {
40 field.validate()?;
41 }
42
43 Ok(())
44 }
45
46 fn predicates(&self) -> Result<Vec<TokenStream2>> {
48 let mut predicates = Vec::new();
49
50 for field in self.fields() {
51 field.add_predicates(&mut predicates)?;
52 }
53
54 Ok(predicates)
55 }
56
57 fn fields(&self) -> &[FromRowField] {
59 match &self.data {
60 Data::Struct(fields) => &fields.fields,
61 _ => panic!("invalid shape"),
62 }
63 }
64
65 fn generate(self) -> Result<TokenStream> {
67 self.validate()?;
68
69 let ident = &self.ident;
70
71 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
72 let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter();
73 let predicates = self.predicates()?;
74
75 let from_row_fields = self
76 .fields()
77 .iter()
78 .map(|f| f.generate_from_row())
79 .collect::<syn::Result<Vec<_>>>()?;
80
81 let try_from_row_fields = self
82 .fields()
83 .iter()
84 .map(|f| f.generate_try_from_row())
85 .collect::<syn::Result<Vec<_>>>()?;
86
87 Ok(quote! {
88 impl #impl_generics postgres_from_row::FromRow for #ident #ty_generics where #(#original_predicates),* #(#predicates),* {
89
90 fn from_row(row: &postgres_from_row::tokio_postgres::Row) -> Self {
91 Self {
92 #(#from_row_fields),*
93 }
94 }
95
96 fn try_from_row(row: &postgres_from_row::tokio_postgres::Row) -> std::result::Result<Self, postgres_from_row::tokio_postgres::Error> {
97 Ok(Self {
98 #(#try_from_row_fields),*
99 })
100 }
101 }
102 }
103 .into())
104 }
105}
106
107#[derive(Debug, FromField)]
109#[darling(attributes(from_row), forward_attrs(allow, doc, cfg))]
110struct FromRowField {
111 ident: Option<syn::Ident>,
113 ty: syn::Type,
115 #[darling(default)]
118 flatten: bool,
119 try_from: Option<String>,
122 from: Option<String>,
125 rename: Option<String>,
128}
129
130impl FromRowField {
131 fn validate(&self) -> Result<()> {
133 if self.from.is_some() && self.try_from.is_some() {
134 return Err(Error::custom(
135 r#"can't combine `#[from_row(from = "..")]` with `#[from_row(try_from = "..")]`"#,
136 )
137 .into());
138 }
139
140 if self.rename.is_some() && self.flatten {
141 return Err(Error::custom(
142 r#"can't combine `#[from_row(flatten)]` with `#[from_row(rename = "..")]`"#,
143 )
144 .into());
145 }
146
147 Ok(())
148 }
149
150 fn target_ty(&self) -> Result<TokenStream2> {
153 if let Some(from) = &self.from {
154 Ok(from.parse()?)
155 } else if let Some(try_from) = &self.try_from {
156 Ok(try_from.parse()?)
157 } else {
158 Ok(self.ty.to_token_stream())
159 }
160 }
161
162 fn column_name(&self) -> String {
165 self.rename
166 .as_ref()
167 .map(Clone::clone)
168 .unwrap_or_else(|| self.ident.as_ref().unwrap().to_string())
169 }
170
171 fn add_predicates(&self, predicates: &mut Vec<TokenStream2>) -> Result<()> {
179 let target_ty = &self.target_ty()?;
180 let ty = &self.ty;
181
182 predicates.push(if self.flatten {
183 quote! (#target_ty: postgres_from_row::FromRow)
184 } else {
185 quote! (#target_ty: for<'a> postgres_from_row::tokio_postgres::types::FromSql<'a>)
186 });
187
188 if self.from.is_some() {
189 predicates.push(quote!(#ty: std::convert::From<#target_ty>))
190 } else if self.try_from.is_some() {
191 let try_from = quote!(std::convert::TryFrom<#target_ty>);
192
193 predicates.push(quote!(#ty: #try_from));
194 predicates.push(quote!(postgres_from_row::tokio_postgres::Error: std::convert::From<<#ty as #try_from>::Error>));
195 predicates.push(quote!(<#ty as #try_from>::Error: std::fmt::Debug));
196 }
197
198 Ok(())
199 }
200
201 fn generate_from_row(&self) -> Result<TokenStream2> {
203 let ident = self.ident.as_ref().unwrap();
204 let column_name = self.column_name();
205 let field_ty = &self.ty;
206 let target_ty = self.target_ty()?;
207
208 let mut base = if self.flatten {
209 quote!(<#target_ty as postgres_from_row::FromRow>::from_row(row))
210 } else {
211 quote!(postgres_from_row::tokio_postgres::Row::get::<&str, #target_ty>(row, #column_name))
212 };
213
214 if self.from.is_some() {
215 base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base));
216 } else if self.try_from.is_some() {
217 base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base).expect("could not convert column"));
218 };
219
220 Ok(quote!(#ident: #base))
221 }
222
223 fn generate_try_from_row(&self) -> Result<TokenStream2> {
225 let ident = self.ident.as_ref().unwrap();
226 let column_name = self.column_name();
227 let field_ty = &self.ty;
228 let target_ty = self.target_ty()?;
229
230 let mut base = if self.flatten {
231 quote!(<#target_ty as postgres_from_row::FromRow>::try_from_row(row)?)
232 } else {
233 quote!(postgres_from_row::tokio_postgres::Row::try_get::<&str, #target_ty>(row, #column_name)?)
234 };
235
236 if self.from.is_some() {
237 base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base));
238 } else if self.try_from.is_some() {
239 base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base)?);
240 };
241
242 Ok(quote!(#ident: #base))
243 }
244}