postgres_derive/
tosql.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use std::iter;
4use syn::{
5    Data, DataStruct, DeriveInput, Error, Fields, Ident, TraitBound, TraitBoundModifier,
6    TypeParamBound,
7};
8
9use crate::accepts;
10use crate::composites::Field;
11use crate::composites::{append_generic_bound, new_derive_path};
12use crate::enums::Variant;
13use crate::overrides::Overrides;
14
15pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
16    let overrides = Overrides::extract(&input.attrs, true)?;
17
18    if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
19        return Err(Error::new_spanned(
20            &input,
21            "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
22        ));
23    }
24
25    let name = overrides
26        .name
27        .clone()
28        .unwrap_or_else(|| input.ident.to_string());
29
30    let (accepts_body, to_sql_body) = if overrides.transparent {
31        match input.data {
32            Data::Struct(DataStruct {
33                fields: Fields::Unnamed(ref fields),
34                ..
35            }) if fields.unnamed.len() == 1 => {
36                let field = fields.unnamed.first().unwrap();
37
38                (accepts::transparent_body(field), transparent_body())
39            }
40            _ => {
41                return Err(Error::new_spanned(
42                    input,
43                    "#[postgres(transparent)] may only be applied to single field tuple structs",
44                ));
45            }
46        }
47    } else if overrides.allow_mismatch {
48        match input.data {
49            Data::Enum(ref data) => {
50                let variants = data
51                    .variants
52                    .iter()
53                    .map(|variant| Variant::parse(variant, overrides.rename_all))
54                    .collect::<Result<Vec<_>, _>>()?;
55                (
56                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
57                    enum_body(&input.ident, &variants),
58                )
59            }
60            _ => {
61                return Err(Error::new_spanned(
62                    input,
63                    "#[postgres(allow_mismatch)] may only be applied to enums",
64                ));
65            }
66        }
67    } else {
68        match input.data {
69            Data::Enum(ref data) => {
70                let variants = data
71                    .variants
72                    .iter()
73                    .map(|variant| Variant::parse(variant, overrides.rename_all))
74                    .collect::<Result<Vec<_>, _>>()?;
75                (
76                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
77                    enum_body(&input.ident, &variants),
78                )
79            }
80            Data::Struct(DataStruct {
81                fields: Fields::Unnamed(ref fields),
82                ..
83            }) if fields.unnamed.len() == 1 => {
84                let field = fields.unnamed.first().unwrap();
85
86                (accepts::domain_body(&name, field), domain_body())
87            }
88            Data::Struct(DataStruct {
89                fields: Fields::Named(ref fields),
90                ..
91            }) => {
92                let fields = fields
93                    .named
94                    .iter()
95                    .map(|field| Field::parse(field, overrides.rename_all))
96                    .collect::<Result<Vec<_>, _>>()?;
97                (
98                    accepts::composite_body(&name, "ToSql", &fields),
99                    composite_body(&fields),
100                )
101            }
102            _ => {
103                return Err(Error::new_spanned(
104                    input,
105                    "#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums",
106                ));
107            }
108        }
109    };
110
111    let ident = &input.ident;
112    let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound());
113    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
114    let out = quote! {
115        impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause {
116            fn to_sql(&self,
117                      _type: &postgres_types::Type,
118                      buf: &mut postgres_types::private::BytesMut)
119                      -> std::result::Result<postgres_types::IsNull,
120                                             std::boxed::Box<std::error::Error +
121                                                             std::marker::Sync +
122                                                             std::marker::Send>> {
123                #to_sql_body
124            }
125
126            fn accepts(type_: &postgres_types::Type) -> bool {
127                #accepts_body
128            }
129
130            postgres_types::to_sql_checked!();
131        }
132    };
133
134    Ok(out)
135}
136
137fn transparent_body() -> TokenStream {
138    quote! {
139        postgres_types::ToSql::to_sql(&self.0, _type, buf)
140    }
141}
142
143fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
144    let idents = iter::repeat(ident);
145    let variant_idents = variants.iter().map(|v| &v.ident);
146    let variant_names = variants.iter().map(|v| &v.name);
147
148    quote! {
149        let s = match *self {
150            #(
151                #idents::#variant_idents => #variant_names,
152            )*
153        };
154
155        buf.extend_from_slice(s.as_bytes());
156        std::result::Result::Ok(postgres_types::IsNull::No)
157    }
158}
159
160fn domain_body() -> TokenStream {
161    quote! {
162        let type_ = match *_type.kind() {
163            postgres_types::Kind::Domain(ref type_) => type_,
164            _ => unreachable!(),
165        };
166
167        postgres_types::ToSql::to_sql(&self.0, type_, buf)
168    }
169}
170
171fn composite_body(fields: &[Field]) -> TokenStream {
172    let field_names = fields.iter().map(|f| &f.name);
173    let field_idents = fields.iter().map(|f| &f.ident);
174
175    quote! {
176        let fields = match *_type.kind() {
177            postgres_types::Kind::Composite(ref fields) => fields,
178            _ => unreachable!(),
179        };
180
181        buf.extend_from_slice(&(fields.len() as i32).to_be_bytes());
182
183        for field in fields {
184            buf.extend_from_slice(&field.type_().oid().to_be_bytes());
185
186            let base = buf.len();
187            buf.extend_from_slice(&[0; 4]);
188            let r = match field.name() {
189                #(
190                    #field_names => postgres_types::ToSql::to_sql(&self.#field_idents, field.type_(), buf),
191                )*
192                _ => unreachable!(),
193            };
194
195            let count = match r? {
196                postgres_types::IsNull::Yes => -1,
197                postgres_types::IsNull::No => {
198                    let len = buf.len() - base - 4;
199                    if len > i32::max_value() as usize {
200                        return std::result::Result::Err(
201                            std::convert::Into::into("value too large to transmit"));
202                    }
203                    len as i32
204                }
205            };
206
207            buf[base..base + 4].copy_from_slice(&count.to_be_bytes());
208        }
209
210        std::result::Result::Ok(postgres_types::IsNull::No)
211    }
212}
213
214fn new_tosql_bound() -> TypeParamBound {
215    TypeParamBound::Trait(TraitBound {
216        lifetimes: None,
217        modifier: TraitBoundModifier::None,
218        paren_token: None,
219        path: new_derive_path(Ident::new("ToSql", Span::call_site()).into()),
220    })
221}