1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::{
5 punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
6 Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments,
7 PathSegment,
8};
9use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound};
10
11use crate::accepts;
12use crate::composites::Field;
13use crate::composites::{append_generic_bound, new_derive_path};
14use crate::enums::Variant;
15use crate::overrides::Overrides;
16
17pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
18 let overrides = Overrides::extract(&input.attrs, true)?;
19
20 if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
21 return Err(Error::new_spanned(
22 &input,
23 "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
24 ));
25 }
26
27 let name = overrides
28 .name
29 .clone()
30 .unwrap_or_else(|| input.ident.to_string());
31
32 let (accepts_body, to_sql_body) = if overrides.transparent {
33 match input.data {
34 Data::Struct(DataStruct {
35 fields: Fields::Unnamed(ref fields),
36 ..
37 }) if fields.unnamed.len() == 1 => {
38 let field = fields.unnamed.first().unwrap();
39 (
40 accepts::transparent_body(field),
41 transparent_body(&input.ident, field),
42 )
43 }
44 _ => {
45 return Err(Error::new_spanned(
46 input,
47 "#[postgres(transparent)] may only be applied to single field tuple structs",
48 ))
49 }
50 }
51 } else if overrides.allow_mismatch {
52 match input.data {
53 Data::Enum(ref data) => {
54 let variants = data
55 .variants
56 .iter()
57 .map(|variant| Variant::parse(variant, overrides.rename_all))
58 .collect::<Result<Vec<_>, _>>()?;
59 (
60 accepts::enum_body(&name, &variants, overrides.allow_mismatch),
61 enum_body(&input.ident, &variants),
62 )
63 }
64 _ => {
65 return Err(Error::new_spanned(
66 input,
67 "#[postgres(allow_mismatch)] may only be applied to enums",
68 ));
69 }
70 }
71 } else {
72 match input.data {
73 Data::Enum(ref data) => {
74 let variants = data
75 .variants
76 .iter()
77 .map(|variant| Variant::parse(variant, overrides.rename_all))
78 .collect::<Result<Vec<_>, _>>()?;
79 (
80 accepts::enum_body(&name, &variants, overrides.allow_mismatch),
81 enum_body(&input.ident, &variants),
82 )
83 }
84 Data::Struct(DataStruct {
85 fields: Fields::Unnamed(ref fields),
86 ..
87 }) if fields.unnamed.len() == 1 => {
88 let field = fields.unnamed.first().unwrap();
89 (
90 domain_accepts_body(&name, field),
91 domain_body(&input.ident, field),
92 )
93 }
94 Data::Struct(DataStruct {
95 fields: Fields::Named(ref fields),
96 ..
97 }) => {
98 let fields = fields
99 .named
100 .iter()
101 .map(|field| Field::parse(field, overrides.rename_all))
102 .collect::<Result<Vec<_>, _>>()?;
103 (
104 accepts::composite_body(&name, "FromSql", &fields),
105 composite_body(&input.ident, &fields),
106 )
107 }
108 _ => {
109 return Err(Error::new_spanned(
110 input,
111 "#[derive(FromSql)] may only be applied to structs, single field tuple structs, and enums",
112 ))
113 }
114 }
115 };
116
117 let ident = &input.ident;
118 let (generics, lifetime) = build_generics(&input.generics);
119 let (impl_generics, _, _) = generics.split_for_impl();
120 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
121 let out = quote! {
122 impl #impl_generics postgres_types::FromSql<#lifetime> for #ident #ty_generics #where_clause {
123 fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8])
124 -> std::result::Result<#ident #ty_generics,
125 std::boxed::Box<dyn std::error::Error +
126 std::marker::Sync +
127 std::marker::Send>> {
128 #to_sql_body
129 }
130
131 fn accepts(type_: &postgres_types::Type) -> bool {
132 #accepts_body
133 }
134 }
135 };
136
137 Ok(out)
138}
139
140fn transparent_body(ident: &Ident, field: &syn::Field) -> TokenStream {
141 let ty = &field.ty;
142 quote! {
143 <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
144 }
145}
146
147fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
148 let variant_names = variants.iter().map(|v| &v.name);
149 let idents = iter::repeat(ident);
150 let variant_idents = variants.iter().map(|v| &v.ident);
151
152 quote! {
153 match std::str::from_utf8(buf)? {
154 #(
155 #variant_names => std::result::Result::Ok(#idents::#variant_idents),
156 )*
157 s => {
158 std::result::Result::Err(
159 std::convert::Into::into(format!("invalid variant `{}`", s)))
160 }
161 }
162 }
163}
164
165fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
167 let ty = &field.ty;
168 let normal_body = accepts::domain_body(name, field);
169
170 quote! {
171 if <#ty as postgres_types::FromSql>::accepts(type_) {
172 return true;
173 }
174
175 #normal_body
176 }
177}
178
179fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
180 let ty = &field.ty;
181 quote! {
182 <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
183 }
184}
185
186fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
187 let temp_vars = &fields
188 .iter()
189 .map(|f| format_ident!("__{}", f.ident))
190 .collect::<Vec<_>>();
191 let field_names = &fields.iter().map(|f| &f.name).collect::<Vec<_>>();
192 let field_idents = &fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
193
194 quote! {
195 let fields = match *_type.kind() {
196 postgres_types::Kind::Composite(ref fields) => fields,
197 _ => unreachable!(),
198 };
199
200 let mut buf = buf;
201 let num_fields = postgres_types::private::read_be_i32(&mut buf)?;
202 if num_fields as usize != fields.len() {
203 return std::result::Result::Err(
204 std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, fields.len())));
205 }
206
207 #(
208 let mut #temp_vars = std::option::Option::None;
209 )*
210
211 for field in fields {
212 let oid = postgres_types::private::read_be_i32(&mut buf)? as u32;
213 if oid != field.type_().oid() {
214 return std::result::Result::Err(std::convert::Into::into("unexpected OID"));
215 }
216
217 match field.name() {
218 #(
219 #field_names => {
220 #temp_vars = std::option::Option::Some(
221 postgres_types::private::read_value(field.type_(), &mut buf)?);
222 }
223 )*
224 _ => unreachable!(),
225 }
226 }
227
228 std::result::Result::Ok(#ident {
229 #(
230 #field_idents: #temp_vars.unwrap(),
231 )*
232 })
233 }
234}
235
236fn build_generics(source: &Generics) -> (Generics, Lifetime) {
237 let lifetime = Lifetime::new("'a", Span::call_site());
239
240 let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
241 out.params.insert(
242 0,
243 GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())),
244 );
245
246 (out, lifetime)
247}
248
249fn new_fromsql_bound(lifetime: &Lifetime) -> TypeParamBound {
250 let mut path_segment: PathSegment = Ident::new("FromSql", Span::call_site()).into();
251 let mut seg_args = Punctuated::new();
252 seg_args.push(GenericArgument::Lifetime(lifetime.to_owned()));
253 path_segment.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
254 colon2_token: None,
255 lt_token: token::Lt::default(),
256 args: seg_args,
257 gt_token: token::Gt::default(),
258 });
259
260 TypeParamBound::Trait(TraitBound {
261 lifetimes: None,
262 modifier: TraitBoundModifier::None,
263 paren_token: None,
264 path: new_derive_path(path_segment),
265 })
266}