postgres_from_row_derive/
lib.rs

1use darling::{ast::Data, Error, FromDeriveInput, FromField, ToTokens};
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput, Result};
6
7/// Calls the fallible entry point and writes any errors to the tokenstream.
8#[proc_macro_derive(FromRow, attributes(from_row))]
9pub fn derive_from_row(input: TokenStream) -> TokenStream {
10    let derive_input = parse_macro_input!(input as DeriveInput);
11    match try_derive_from_row(&derive_input) {
12        Ok(result) => result,
13        Err(err) => err.write_errors().into(),
14    }
15}
16
17/// Fallible entry point for generating a `FromRow` implementation
18fn try_derive_from_row(input: &DeriveInput) -> std::result::Result<TokenStream, Error> {
19    let from_row_derive = DeriveFromRow::from_derive_input(input)?;
20    Ok(from_row_derive.generate()?)
21}
22
23/// Main struct for deriving `FromRow` for a struct.
24#[derive(Debug, FromDeriveInput)]
25#[darling(
26    attributes(from_row),
27    forward_attrs(allow, doc, cfg),
28    supports(struct_named)
29)]
30struct DeriveFromRow {
31    ident: syn::Ident,
32    generics: syn::Generics,
33    data: Data<(), FromRowField>,
34}
35
36impl DeriveFromRow {
37    /// Validates all fields
38    fn validate(&self) -> Result<()> {
39        for field in self.fields() {
40            field.validate()?;
41        }
42
43        Ok(())
44    }
45
46    /// Generates any additional where clause predicates needed for the fields in this struct.
47    fn predicates(&self) -> Result<Vec<TokenStream2>> {
48        let mut predicates = Vec::new();
49
50        for field in self.fields() {
51            field.add_predicates(&mut predicates)?;
52        }
53
54        Ok(predicates)
55    }
56
57    /// Provides a slice of this struct's fields.
58    fn fields(&self) -> &[FromRowField] {
59        match &self.data {
60            Data::Struct(fields) => &fields.fields,
61            _ => panic!("invalid shape"),
62        }
63    }
64
65    /// Generate the `FromRow` implementation.
66    fn generate(self) -> Result<TokenStream> {
67        self.validate()?;
68
69        let ident = &self.ident;
70
71        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
72        let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter();
73        let predicates = self.predicates()?;
74
75        let from_row_fields = self
76            .fields()
77            .iter()
78            .map(|f| f.generate_from_row())
79            .collect::<syn::Result<Vec<_>>>()?;
80
81        let try_from_row_fields = self
82            .fields()
83            .iter()
84            .map(|f| f.generate_try_from_row())
85            .collect::<syn::Result<Vec<_>>>()?;
86
87        Ok(quote! {
88            impl #impl_generics postgres_from_row::FromRow for #ident #ty_generics where #(#original_predicates),* #(#predicates),* {
89
90                fn from_row(row: &postgres_from_row::tokio_postgres::Row) -> Self {
91                    Self {
92                        #(#from_row_fields),*
93                    }
94                }
95
96                fn try_from_row(row: &postgres_from_row::tokio_postgres::Row) -> std::result::Result<Self, postgres_from_row::tokio_postgres::Error> {
97                    Ok(Self {
98                        #(#try_from_row_fields),*
99                    })
100                }
101            }
102        }
103        .into())
104    }
105}
106
107/// A single field inside of a struct that derives `FromRow`
108#[derive(Debug, FromField)]
109#[darling(attributes(from_row), forward_attrs(allow, doc, cfg))]
110struct FromRowField {
111    /// The identifier of this field.
112    ident: Option<syn::Ident>,
113    /// The type specified in this field.
114    ty: syn::Type,
115    /// Wether to flatten this field. Flattening means calling the `FromRow` implementation
116    /// of `self.ty` instead of extracting it directly from the row.
117    #[darling(default)]
118    flatten: bool,
119    /// Optionaly use this type as the target for `FromRow` or `FromSql`, and then
120    /// call `TryFrom::try_from` to convert it the `self.ty`.
121    try_from: Option<String>,
122    /// Optionaly use this type as the target for `FromRow` or `FromSql`, and then
123    /// call `From::from` to convert it the `self.ty`.
124    from: Option<String>,
125    /// Override the name of the actual sql column instead of using `self.ident`.
126    /// Is not compatible with `flatten` since no column is needed there.
127    rename: Option<String>,
128}
129
130impl FromRowField {
131    /// Checks wether this field has a valid combination of attributes
132    fn validate(&self) -> Result<()> {
133        if self.from.is_some() && self.try_from.is_some() {
134            return Err(Error::custom(
135                r#"can't combine `#[from_row(from = "..")]` with `#[from_row(try_from = "..")]`"#,
136            )
137            .into());
138        }
139
140        if self.rename.is_some() && self.flatten {
141            return Err(Error::custom(
142                r#"can't combine `#[from_row(flatten)]` with `#[from_row(rename = "..")]`"#,
143            )
144            .into());
145        }
146
147        Ok(())
148    }
149
150    /// Returns a tokenstream of the type that should be returned from either
151    /// `FromRow` (when using `flatten`) or `FromSql`.
152    fn target_ty(&self) -> Result<TokenStream2> {
153        if let Some(from) = &self.from {
154            Ok(from.parse()?)
155        } else if let Some(try_from) = &self.try_from {
156            Ok(try_from.parse()?)
157        } else {
158            Ok(self.ty.to_token_stream())
159        }
160    }
161
162    /// Returns the name that maps to the actuall sql column
163    /// By default this is the same as the rust field name but can be overwritten by `#[from_row(rename = "..")]`.
164    fn column_name(&self) -> String {
165        self.rename
166            .as_ref()
167            .map(Clone::clone)
168            .unwrap_or_else(|| self.ident.as_ref().unwrap().to_string())
169    }
170
171    /// Pushes the needed where clause predicates for this field.
172    ///
173    /// By default this is `T: for<'a> postgres::types::FromSql<'a>`,
174    /// when using `flatten` it's: `T: postgres_from_row::FromRow`
175    /// and when using either `from` or `try_from` attributes it additionally pushes this bound:
176    /// `T: std::convert::From<R>`, where `T` is the type specified in the struct and `R` is the
177    /// type specified in the `[try]_from` attribute.
178    fn add_predicates(&self, predicates: &mut Vec<TokenStream2>) -> Result<()> {
179        let target_ty = &self.target_ty()?;
180        let ty = &self.ty;
181
182        predicates.push(if self.flatten {
183            quote! (#target_ty: postgres_from_row::FromRow)
184        } else {
185            quote! (#target_ty: for<'a> postgres_from_row::tokio_postgres::types::FromSql<'a>)
186        });
187
188        if self.from.is_some() {
189            predicates.push(quote!(#ty: std::convert::From<#target_ty>))
190        } else if self.try_from.is_some() {
191            let try_from = quote!(std::convert::TryFrom<#target_ty>);
192
193            predicates.push(quote!(#ty: #try_from));
194            predicates.push(quote!(postgres_from_row::tokio_postgres::Error: std::convert::From<<#ty as #try_from>::Error>));
195            predicates.push(quote!(<#ty as #try_from>::Error: std::fmt::Debug));
196        }
197
198        Ok(())
199    }
200
201    /// Generate the line needed to retrievee this field from a row when calling `from_row`.
202    fn generate_from_row(&self) -> Result<TokenStream2> {
203        let ident = self.ident.as_ref().unwrap();
204        let column_name = self.column_name();
205        let field_ty = &self.ty;
206        let target_ty = self.target_ty()?;
207
208        let mut base = if self.flatten {
209            quote!(<#target_ty as postgres_from_row::FromRow>::from_row(row))
210        } else {
211            quote!(postgres_from_row::tokio_postgres::Row::get::<&str, #target_ty>(row, #column_name))
212        };
213
214        if self.from.is_some() {
215            base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base));
216        } else if self.try_from.is_some() {
217            base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base).expect("could not convert column"));
218        };
219
220        Ok(quote!(#ident: #base))
221    }
222
223    /// Generate the line needed to retrieve this field from a row when calling `try_from_row`.
224    fn generate_try_from_row(&self) -> Result<TokenStream2> {
225        let ident = self.ident.as_ref().unwrap();
226        let column_name = self.column_name();
227        let field_ty = &self.ty;
228        let target_ty = self.target_ty()?;
229
230        let mut base = if self.flatten {
231            quote!(<#target_ty as postgres_from_row::FromRow>::try_from_row(row)?)
232        } else {
233            quote!(postgres_from_row::tokio_postgres::Row::try_get::<&str, #target_ty>(row, #column_name)?)
234        };
235
236        if self.from.is_some() {
237            base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base));
238        } else if self.try_from.is_some() {
239            base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base)?);
240        };
241
242        Ok(quote!(#ident: #base))
243    }
244}