1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use std::iter;
4use syn::{
5 Data, DataStruct, DeriveInput, Error, Fields, Ident, TraitBound, TraitBoundModifier,
6 TypeParamBound,
7};
8
9use crate::accepts;
10use crate::composites::Field;
11use crate::composites::{append_generic_bound, new_derive_path};
12use crate::enums::Variant;
13use crate::overrides::Overrides;
14
15pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
16 let overrides = Overrides::extract(&input.attrs, true)?;
17
18 if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
19 return Err(Error::new_spanned(
20 &input,
21 "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
22 ));
23 }
24
25 let name = overrides
26 .name
27 .clone()
28 .unwrap_or_else(|| input.ident.to_string());
29
30 let (accepts_body, to_sql_body) = if overrides.transparent {
31 match input.data {
32 Data::Struct(DataStruct {
33 fields: Fields::Unnamed(ref fields),
34 ..
35 }) if fields.unnamed.len() == 1 => {
36 let field = fields.unnamed.first().unwrap();
37
38 (accepts::transparent_body(field), transparent_body())
39 }
40 _ => {
41 return Err(Error::new_spanned(
42 input,
43 "#[postgres(transparent)] may only be applied to single field tuple structs",
44 ));
45 }
46 }
47 } else if overrides.allow_mismatch {
48 match input.data {
49 Data::Enum(ref data) => {
50 let variants = data
51 .variants
52 .iter()
53 .map(|variant| Variant::parse(variant, overrides.rename_all))
54 .collect::<Result<Vec<_>, _>>()?;
55 (
56 accepts::enum_body(&name, &variants, overrides.allow_mismatch),
57 enum_body(&input.ident, &variants),
58 )
59 }
60 _ => {
61 return Err(Error::new_spanned(
62 input,
63 "#[postgres(allow_mismatch)] may only be applied to enums",
64 ));
65 }
66 }
67 } else {
68 match input.data {
69 Data::Enum(ref data) => {
70 let variants = data
71 .variants
72 .iter()
73 .map(|variant| Variant::parse(variant, overrides.rename_all))
74 .collect::<Result<Vec<_>, _>>()?;
75 (
76 accepts::enum_body(&name, &variants, overrides.allow_mismatch),
77 enum_body(&input.ident, &variants),
78 )
79 }
80 Data::Struct(DataStruct {
81 fields: Fields::Unnamed(ref fields),
82 ..
83 }) if fields.unnamed.len() == 1 => {
84 let field = fields.unnamed.first().unwrap();
85
86 (accepts::domain_body(&name, field), domain_body())
87 }
88 Data::Struct(DataStruct {
89 fields: Fields::Named(ref fields),
90 ..
91 }) => {
92 let fields = fields
93 .named
94 .iter()
95 .map(|field| Field::parse(field, overrides.rename_all))
96 .collect::<Result<Vec<_>, _>>()?;
97 (
98 accepts::composite_body(&name, "ToSql", &fields),
99 composite_body(&fields),
100 )
101 }
102 _ => {
103 return Err(Error::new_spanned(
104 input,
105 "#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums",
106 ));
107 }
108 }
109 };
110
111 let ident = &input.ident;
112 let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound());
113 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
114 let out = quote! {
115 impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause {
116 fn to_sql(&self,
117 _type: &postgres_types::Type,
118 buf: &mut postgres_types::private::BytesMut)
119 -> std::result::Result<postgres_types::IsNull,
120 std::boxed::Box<std::error::Error +
121 std::marker::Sync +
122 std::marker::Send>> {
123 #to_sql_body
124 }
125
126 fn accepts(type_: &postgres_types::Type) -> bool {
127 #accepts_body
128 }
129
130 postgres_types::to_sql_checked!();
131 }
132 };
133
134 Ok(out)
135}
136
137fn transparent_body() -> TokenStream {
138 quote! {
139 postgres_types::ToSql::to_sql(&self.0, _type, buf)
140 }
141}
142
143fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
144 let idents = iter::repeat(ident);
145 let variant_idents = variants.iter().map(|v| &v.ident);
146 let variant_names = variants.iter().map(|v| &v.name);
147
148 quote! {
149 let s = match *self {
150 #(
151 #idents::#variant_idents => #variant_names,
152 )*
153 };
154
155 buf.extend_from_slice(s.as_bytes());
156 std::result::Result::Ok(postgres_types::IsNull::No)
157 }
158}
159
160fn domain_body() -> TokenStream {
161 quote! {
162 let type_ = match *_type.kind() {
163 postgres_types::Kind::Domain(ref type_) => type_,
164 _ => unreachable!(),
165 };
166
167 postgres_types::ToSql::to_sql(&self.0, type_, buf)
168 }
169}
170
171fn composite_body(fields: &[Field]) -> TokenStream {
172 let field_names = fields.iter().map(|f| &f.name);
173 let field_idents = fields.iter().map(|f| &f.ident);
174
175 quote! {
176 let fields = match *_type.kind() {
177 postgres_types::Kind::Composite(ref fields) => fields,
178 _ => unreachable!(),
179 };
180
181 buf.extend_from_slice(&(fields.len() as i32).to_be_bytes());
182
183 for field in fields {
184 buf.extend_from_slice(&field.type_().oid().to_be_bytes());
185
186 let base = buf.len();
187 buf.extend_from_slice(&[0; 4]);
188 let r = match field.name() {
189 #(
190 #field_names => postgres_types::ToSql::to_sql(&self.#field_idents, field.type_(), buf),
191 )*
192 _ => unreachable!(),
193 };
194
195 let count = match r? {
196 postgres_types::IsNull::Yes => -1,
197 postgres_types::IsNull::No => {
198 let len = buf.len() - base - 4;
199 if len > i32::max_value() as usize {
200 return std::result::Result::Err(
201 std::convert::Into::into("value too large to transmit"));
202 }
203 len as i32
204 }
205 };
206
207 buf[base..base + 4].copy_from_slice(&count.to_be_bytes());
208 }
209
210 std::result::Result::Ok(postgres_types::IsNull::No)
211 }
212}
213
214fn new_tosql_bound() -> TypeParamBound {
215 TypeParamBound::Trait(TraitBound {
216 lifetimes: None,
217 modifier: TraitBoundModifier::None,
218 paren_token: None,
219 path: new_derive_path(Ident::new("ToSql", Span::call_site()).into()),
220 })
221}