1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::{
5 punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
6 Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments,
7 PathSegment,
8};
9use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound};
10
11use crate::accepts;
12use crate::composites::Field;
13use crate::composites::{append_generic_bound, new_derive_path};
14use crate::enums::Variant;
15use crate::overrides::Overrides;
16
17pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
18 let overrides = Overrides::extract(&input.attrs, true)?;
19
20 if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
21 return Err(Error::new_spanned(
22 &input,
23 "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
24 ));
25 }
26
27 let name = overrides
28 .name
29 .clone()
30 .unwrap_or_else(|| input.ident.to_string());
31
32 let (accepts_body, to_sql_body) = if overrides.transparent {
33 match input.data {
34 Data::Struct(DataStruct {
35 fields: Fields::Unnamed(ref fields),
36 ..
37 }) if fields.unnamed.len() == 1 => {
38 let field = fields.unnamed.first().unwrap();
39 (
40 accepts::transparent_body(field),
41 transparent_body(&input.ident, field),
42 )
43 }
44 _ => {
45 return Err(Error::new_spanned(
46 input,
47 "#[postgres(transparent)] may only be applied to single field tuple structs",
48 ))
49 }
50 }
51 } else if overrides.allow_mismatch {
52 match input.data {
53 Data::Enum(ref data) => {
54 let variants = data
55 .variants
56 .iter()
57 .map(|variant| Variant::parse(variant, overrides.rename_all))
58 .collect::<Result<Vec<_>, _>>()?;
59 (
60 accepts::enum_body(&name, &variants, overrides.allow_mismatch),
61 enum_body(&input.ident, &variants),
62 )
63 }
64 _ => {
65 return Err(Error::new_spanned(
66 input,
67 "#[postgres(allow_mismatch)] may only be applied to enums",
68 ));
69 }
70 }
71 } else {
72 match input.data {
73 Data::Enum(ref data) => {
74 let variants = data
75 .variants
76 .iter()
77 .map(|variant| Variant::parse(variant, overrides.rename_all))
78 .collect::<Result<Vec<_>, _>>()?;
79 (
80 accepts::enum_body(&name, &variants, overrides.allow_mismatch),
81 enum_body(&input.ident, &variants),
82 )
83 }
84 Data::Struct(DataStruct {
85 fields: Fields::Unnamed(ref fields),
86 ..
87 }) if fields.unnamed.len() == 1 => {
88 let field = fields.unnamed.first().unwrap();
89 (
90 domain_accepts_body(&name, field),
91 domain_body(&input.ident, field),
92 )
93 }
94 Data::Struct(DataStruct {
95 fields: Fields::Named(ref fields),
96 ..
97 }) => {
98 let fields = fields
99 .named
100 .iter()
101 .map(|field| Field::parse(field, overrides.rename_all))
102 .collect::<Result<Vec<_>, _>>()?;
103 (
104 accepts::composite_body(&name, "FromSql", &fields),
105 composite_body(&input.ident, &fields),
106 )
107 }
108 _ => {
109 return Err(Error::new_spanned(
110 input,
111 "#[derive(FromSql)] may only be applied to structs, single field tuple structs, and enums",
112 ))
113 }
114 }
115 };
116
117 let ident = &input.ident;
118 let (generics, lifetime) = build_generics(&input.generics);
119 let (impl_generics, _, _) = generics.split_for_impl();
120 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
121 let out = quote! {
122 impl #impl_generics postgres_types::FromSql<#lifetime> for #ident #ty_generics #where_clause {
123 fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8])
124 -> std::result::Result<#ident #ty_generics,
125 std::boxed::Box<dyn std::error::Error +
126 std::marker::Sync +
127 std::marker::Send>> {
128 #to_sql_body
129 }
130
131 fn accepts(type_: &postgres_types::Type) -> bool {
132 #accepts_body
133 }
134 }
135 };
136
137 Ok(out)
138}
139
140fn transparent_body(ident: &Ident, field: &syn::Field) -> TokenStream {
141 let ty = &field.ty;
142 quote! {
143 <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
144 }
145}
146
147fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
148 let variant_names = variants.iter().map(|v| &v.name);
149 let idents = iter::repeat(ident);
150 let variant_idents = variants.iter().map(|v| &v.ident);
151
152 quote! {
153 match std::str::from_utf8(buf)? {
154 #(
155 #variant_names => std::result::Result::Ok(#idents::#variant_idents),
156 )*
157 s => {
158 std::result::Result::Err(
159 std::convert::Into::into(format!("invalid variant `{}`", s)))
160 }
161 }
162 }
163}
164
165fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
167 let ty = &field.ty;
168 let normal_body = accepts::domain_body(name, field);
169
170 quote! {
171 if <#ty as postgres_types::FromSql>::accepts(type_) {
172 return true;
173 }
174
175 #normal_body
176 }
177}
178
179fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
180 let ty = &field.ty;
181 quote! {
182 <#ty as postgres_types::FromSql>::from_sql(match *_type.kind() {
183 postgres_types::Kind::Domain(ref _type) => _type,
184 _ => _type
185 }, buf).map(#ident)
186 }
187}
188
189fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
190 let temp_vars = &fields
191 .iter()
192 .map(|f| format_ident!("__{}", f.ident))
193 .collect::<Vec<_>>();
194 let field_names = &fields.iter().map(|f| &f.name).collect::<Vec<_>>();
195 let field_idents = &fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
196
197 quote! {
198 let fields = match *_type.kind() {
199 postgres_types::Kind::Composite(ref fields) => fields,
200 _ => unreachable!(),
201 };
202
203 let mut buf = buf;
204 let num_fields = postgres_types::private::read_be_i32(&mut buf)?;
205 if num_fields as usize != fields.len() {
206 return std::result::Result::Err(
207 std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, fields.len())));
208 }
209
210 #(
211 let mut #temp_vars = std::option::Option::None;
212 )*
213
214 for field in fields {
215 let oid = postgres_types::private::read_be_i32(&mut buf)? as u32;
216 if oid != field.type_().oid() {
217 return std::result::Result::Err(std::convert::Into::into("unexpected OID"));
218 }
219
220 match field.name() {
221 #(
222 #field_names => {
223 #temp_vars = std::option::Option::Some(
224 postgres_types::private::read_value(field.type_(), &mut buf)?);
225 }
226 )*
227 _ => unreachable!(),
228 }
229 }
230
231 std::result::Result::Ok(#ident {
232 #(
233 #field_idents: #temp_vars.unwrap(),
234 )*
235 })
236 }
237}
238
239fn build_generics(source: &Generics) -> (Generics, Lifetime) {
240 let lifetime = Lifetime::new("'a", Span::call_site());
242
243 let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
244 out.params.insert(
245 0,
246 GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())),
247 );
248
249 (out, lifetime)
250}
251
252fn new_fromsql_bound(lifetime: &Lifetime) -> TypeParamBound {
253 let mut path_segment: PathSegment = Ident::new("FromSql", Span::call_site()).into();
254 let mut seg_args = Punctuated::new();
255 seg_args.push(GenericArgument::Lifetime(lifetime.to_owned()));
256 path_segment.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
257 colon2_token: None,
258 lt_token: token::Lt::default(),
259 args: seg_args,
260 gt_token: token::Gt::default(),
261 });
262
263 TypeParamBound::Trait(TraitBound {
264 lifetimes: None,
265 modifier: TraitBoundModifier::None,
266 paren_token: None,
267 path: new_derive_path(path_segment),
268 })
269}