postgres_derive/
fromsql.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::{
5    punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
6    Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments,
7    PathSegment,
8};
9use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound};
10
11use crate::accepts;
12use crate::composites::Field;
13use crate::composites::{append_generic_bound, new_derive_path};
14use crate::enums::Variant;
15use crate::overrides::Overrides;
16
17pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
18    let overrides = Overrides::extract(&input.attrs, true)?;
19
20    if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
21        return Err(Error::new_spanned(
22            &input,
23            "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
24        ));
25    }
26
27    let name = overrides
28        .name
29        .clone()
30        .unwrap_or_else(|| input.ident.to_string());
31
32    let (accepts_body, to_sql_body) = if overrides.transparent {
33        match input.data {
34            Data::Struct(DataStruct {
35                fields: Fields::Unnamed(ref fields),
36                ..
37            }) if fields.unnamed.len() == 1 => {
38                let field = fields.unnamed.first().unwrap();
39                (
40                    accepts::transparent_body(field),
41                    transparent_body(&input.ident, field),
42                )
43            }
44            _ => {
45                return Err(Error::new_spanned(
46                    input,
47                    "#[postgres(transparent)] may only be applied to single field tuple structs",
48                ))
49            }
50        }
51    } else if overrides.allow_mismatch {
52        match input.data {
53            Data::Enum(ref data) => {
54                let variants = data
55                    .variants
56                    .iter()
57                    .map(|variant| Variant::parse(variant, overrides.rename_all))
58                    .collect::<Result<Vec<_>, _>>()?;
59                (
60                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
61                    enum_body(&input.ident, &variants),
62                )
63            }
64            _ => {
65                return Err(Error::new_spanned(
66                    input,
67                    "#[postgres(allow_mismatch)] may only be applied to enums",
68                ));
69            }
70        }
71    } else {
72        match input.data {
73        Data::Enum(ref data) => {
74            let variants = data
75                .variants
76                .iter()
77                .map(|variant| Variant::parse(variant, overrides.rename_all))
78                .collect::<Result<Vec<_>, _>>()?;
79            (
80                accepts::enum_body(&name, &variants, overrides.allow_mismatch),
81                enum_body(&input.ident, &variants),
82            )
83        }
84        Data::Struct(DataStruct {
85            fields: Fields::Unnamed(ref fields),
86            ..
87        }) if fields.unnamed.len() == 1 => {
88            let field = fields.unnamed.first().unwrap();
89            (
90                domain_accepts_body(&name, field),
91                domain_body(&input.ident, field),
92            )
93        }
94        Data::Struct(DataStruct {
95            fields: Fields::Named(ref fields),
96            ..
97        }) => {
98            let fields = fields
99                .named
100                .iter()
101                .map(|field| Field::parse(field, overrides.rename_all))
102                .collect::<Result<Vec<_>, _>>()?;
103            (
104                accepts::composite_body(&name, "FromSql", &fields),
105                composite_body(&input.ident, &fields),
106            )
107        }
108        _ => {
109            return Err(Error::new_spanned(
110                input,
111                "#[derive(FromSql)] may only be applied to structs, single field tuple structs, and enums",
112            ))
113        }
114    }
115    };
116
117    let ident = &input.ident;
118    let (generics, lifetime) = build_generics(&input.generics);
119    let (impl_generics, _, _) = generics.split_for_impl();
120    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
121    let out = quote! {
122        impl #impl_generics postgres_types::FromSql<#lifetime> for #ident #ty_generics #where_clause {
123            fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8])
124                        -> std::result::Result<#ident #ty_generics,
125                                               std::boxed::Box<dyn std::error::Error +
126                                                               std::marker::Sync +
127                                                               std::marker::Send>> {
128                #to_sql_body
129            }
130
131            fn accepts(type_: &postgres_types::Type) -> bool {
132                #accepts_body
133            }
134        }
135    };
136
137    Ok(out)
138}
139
140fn transparent_body(ident: &Ident, field: &syn::Field) -> TokenStream {
141    let ty = &field.ty;
142    quote! {
143        <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
144    }
145}
146
147fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
148    let variant_names = variants.iter().map(|v| &v.name);
149    let idents = iter::repeat(ident);
150    let variant_idents = variants.iter().map(|v| &v.ident);
151
152    quote! {
153        match std::str::from_utf8(buf)? {
154            #(
155                #variant_names => std::result::Result::Ok(#idents::#variant_idents),
156            )*
157            s => {
158                std::result::Result::Err(
159                    std::convert::Into::into(format!("invalid variant `{}`", s)))
160            }
161        }
162    }
163}
164
165// Domains are sometimes but not always just represented by the bare type (!?)
166fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
167    let ty = &field.ty;
168    let normal_body = accepts::domain_body(name, field);
169
170    quote! {
171        if <#ty as postgres_types::FromSql>::accepts(type_) {
172            return true;
173        }
174
175        #normal_body
176    }
177}
178
179fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
180    let ty = &field.ty;
181    quote! {
182        <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
183    }
184}
185
186fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
187    let temp_vars = &fields
188        .iter()
189        .map(|f| format_ident!("__{}", f.ident))
190        .collect::<Vec<_>>();
191    let field_names = &fields.iter().map(|f| &f.name).collect::<Vec<_>>();
192    let field_idents = &fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
193
194    quote! {
195        let fields = match *_type.kind() {
196            postgres_types::Kind::Composite(ref fields) => fields,
197            _ => unreachable!(),
198        };
199
200        let mut buf = buf;
201        let num_fields = postgres_types::private::read_be_i32(&mut buf)?;
202        if num_fields as usize != fields.len() {
203            return std::result::Result::Err(
204                std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, fields.len())));
205        }
206
207        #(
208            let mut #temp_vars = std::option::Option::None;
209        )*
210
211        for field in fields {
212            let oid = postgres_types::private::read_be_i32(&mut buf)? as u32;
213            if oid != field.type_().oid() {
214                return std::result::Result::Err(std::convert::Into::into("unexpected OID"));
215            }
216
217            match field.name() {
218                #(
219                    #field_names => {
220                        #temp_vars = std::option::Option::Some(
221                            postgres_types::private::read_value(field.type_(), &mut buf)?);
222                    }
223                )*
224                _ => unreachable!(),
225            }
226        }
227
228        std::result::Result::Ok(#ident {
229            #(
230                #field_idents: #temp_vars.unwrap(),
231            )*
232        })
233    }
234}
235
236fn build_generics(source: &Generics) -> (Generics, Lifetime) {
237    // don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
238    let lifetime = Lifetime::new("'a", Span::call_site());
239
240    let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
241    out.params.insert(
242        0,
243        GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())),
244    );
245
246    (out, lifetime)
247}
248
249fn new_fromsql_bound(lifetime: &Lifetime) -> TypeParamBound {
250    let mut path_segment: PathSegment = Ident::new("FromSql", Span::call_site()).into();
251    let mut seg_args = Punctuated::new();
252    seg_args.push(GenericArgument::Lifetime(lifetime.to_owned()));
253    path_segment.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
254        colon2_token: None,
255        lt_token: token::Lt::default(),
256        args: seg_args,
257        gt_token: token::Gt::default(),
258    });
259
260    TypeParamBound::Trait(TraitBound {
261        lifetimes: None,
262        modifier: TraitBoundModifier::None,
263        paren_token: None,
264        path: new_derive_path(path_segment),
265    })
266}