1use proc_macro::{token_stream, TokenStream};
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{self, parse::Parse, parse::Parser};
5
/// Panic message used by both `#[odbc_type]` and `#[odbc_bitmask]` when they are applied to
/// something other than a struct without fields or an enum.
const ZST_MSG: &str =
    "`odbc_type` or `odbc_bitmask` must be implemented on a zero-sized struct or an enum";
8
9#[proc_macro_derive(Ident, attributes(identifier))]
10pub fn into_identifier(input: TokenStream) -> TokenStream {
11 let ast: syn::DeriveInput = syn::parse(input).unwrap();
12
13 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
14 let type_name = &ast.ident;
15
16 let mut identifier = None;
17 let mut identifier_type = None;
18 for attr in ast.attrs.into_iter() {
19 if attr.path.is_ident("identifier") {
20 if let syn::Meta::List(attr_list) = attr.parse_meta().expect("Missing arguments") {
21 let mut attr_list = attr_list.nested.into_iter();
22
23 if let syn::NestedMeta::Meta(meta) = &attr_list.next().expect("Missing arguments") {
24 identifier_type = meta.path().get_ident().map(|x| x.to_owned());
25 } else {
26 panic!("1st argument is not a valid ODBC type");
27 }
28 if let syn::NestedMeta::Lit(lit) = attr_list.next().expect("Missing 2nd argument") {
29 identifier = Some(lit);
30 } else {
31 panic!("2nd argument is not a valid literal");
32 }
33 }
34 }
35 }
36
37 let gen = quote! {
38 impl #impl_generics crate::Ident for #type_name #ty_generics #where_clause {
39 type Type = crate::#identifier_type;
40 const IDENTIFIER: Self::Type = #identifier;
41 }
42 };
43
44 gen.into()
45}
46
47fn parse_inner_type(mut args: token_stream::IntoIter) -> Ident {
48 let inner_type: Ident = syn::parse(args.next().unwrap().into()).unwrap();
49
50 if args.next().is_some() {
51 panic!("Only one ODBC type can be declared");
53 }
54
55 match inner_type.to_string().as_str() {
56 "SQLINTEGER" | "SQLUINTEGER" | "SQLSMALLINT" | "SQLUSMALLINT" | "SQLLEN" | "SQLULEN" => {}
60 unsupported => panic!("{}: unsupported ODBC type", unsupported),
61 }
62
63 inner_type
64}
65
66fn odbc_derive(ast: &mut syn::DeriveInput, inner_type: &Ident) -> TokenStream2 {
67 ast.attrs.extend(
68 syn::Attribute::parse_outer
69 .parse2(quote! { #[derive(Debug, Clone, Copy)] })
70 .unwrap(),
71 );
72
73 let type_name = &ast.ident;
74 let mut ret = match ast.data {
75 syn::Data::Struct(ref mut struct_data) => {
76 ast.attrs.extend(
77 syn::Attribute::parse_outer
78 .parse2(quote! { #[repr(transparent)] })
79 .unwrap(),
80 );
81
82 if struct_data.fields.is_empty() {
83 struct_data.fields = syn::Fields::Unnamed(
84 syn::FieldsUnnamed::parse
85 .parse2(quote! { (crate::#inner_type) })
86 .expect(&format!("{}: unknown ODBC type", inner_type)),
87 );
88 } else {
89 panic!("{}", ZST_MSG);
90 }
91
92 quote! {
93 unsafe impl crate::convert::AsMutSQLPOINTER for #type_name {
94 fn as_mut_SQLPOINTER(&mut self) -> crate::SQLPOINTER {
95 (self as *mut Self).cast()
96 }
97 }
98 unsafe impl crate::convert::AsMutSQLPOINTER for std::mem::MaybeUninit<#type_name> {
99 fn as_mut_SQLPOINTER(&mut self) -> crate::SQLPOINTER {
100 self.as_mut_ptr().cast()
101 }
102 }
103
104 impl #type_name {
105 #[inline]
106 pub(crate) const fn identifier(&self) -> crate::#inner_type {
107 self.0
108 }
109 }
110 }
111 }
112 syn::Data::Enum(ref data) => {
113 let variants = data.variants.iter().map(|v| &v.ident);
114
115 quote! {
116 impl std::convert::TryFrom<crate::#inner_type> for #type_name {
117 type Error = crate::#inner_type;
118
119 fn try_from(source: crate::#inner_type) -> Result<Self, Self::Error> {
120 match source {
121 #(x if x == #type_name::#variants as crate::#inner_type => Ok(#type_name::#variants)),*,
122 unknown => Err(unknown),
123 }
124 }
125 }
126
127 impl #type_name {
128 pub(crate) const fn identifier(&self) -> crate::#inner_type {
129 *self as crate::#inner_type
130 }
131 }
132 }
133 }
134 _ => panic!("{}", ZST_MSG),
135 };
136
137 ret.extend(quote! {
138 impl crate::Ident for #type_name where crate::#inner_type: crate::Ident {
139 type Type = <crate::#inner_type as crate::Ident>::Type;
140 const IDENTIFIER: Self::Type = <crate::#inner_type as crate::Ident>::IDENTIFIER;
141 }
142
143 impl crate::Scalar for #type_name where crate::#inner_type: crate::Scalar {}
144
145 unsafe impl crate::convert::IntoSQLPOINTER for #type_name {
146 fn into_SQLPOINTER(self) -> crate::SQLPOINTER {
147 Self::identifier(&self) as _
148 }
149 }
150
151 impl crate::attr::AttrZeroAssert for #type_name {
152 #[inline]
153 fn assert_zeroed(&self) {
154 assert_eq!(0, Self::identifier(&self));
156 }
157 }
158
159 #ast
160 });
161
162 ret
163}
164
165#[proc_macro_attribute]
166pub fn odbc_bitmask(args: TokenStream, input: TokenStream) -> TokenStream {
167 let mut ast: syn::DeriveInput = syn::parse(input).unwrap();
168
169 let inner_type = parse_inner_type(args.into_iter());
170 let mut odbc_bitmask = odbc_derive(&mut ast, &inner_type);
171
172 let type_name = &ast.ident;
173 odbc_bitmask.extend(quote! {
174 impl std::ops::BitAnd<#type_name> for #type_name {
175 type Output = crate::#inner_type;
176
177 fn bitand(self, other: #type_name) -> Self::Output {
178 Self::identifier(&self) &Self::identifier(&other)
179 }
180 }
181 impl std::ops::BitAnd<crate::#inner_type> for #type_name {
182 type Output = crate::#inner_type;
183
184 fn bitand(self, other: crate::#inner_type) -> Self::Output {
185 Self::identifier(&self) & other
186 }
187 }
188 impl std::ops::BitAnd<#type_name> for crate::#inner_type {
189 type Output = crate::#inner_type;
190
191 fn bitand(self, other: #type_name) -> Self::Output {
192 other & self
193 }
194 }
195 });
196
197 odbc_bitmask.into()
198}
199
200#[proc_macro_attribute]
201pub fn odbc_type(args: TokenStream, input: TokenStream) -> TokenStream {
202 let mut ast: syn::DeriveInput = syn::parse(input).unwrap();
203
204 ast.attrs.extend(
205 syn::Attribute::parse_outer
206 .parse2(quote! { #[derive(PartialEq, Eq)] })
207 .unwrap(),
208 );
209
210 let inner_type = parse_inner_type(args.into_iter());
211 let mut odbc_type = odbc_derive(&mut ast, &inner_type);
212
213 let type_name = &ast.ident;
214 odbc_type.extend(quote! {
215 impl PartialEq<crate::#inner_type> for #type_name {
216 fn eq(&self, other: &crate::#inner_type) -> bool {
217 self.identifier() == *other
218 }
219 }
220
221 impl PartialEq<#type_name> for crate::#inner_type {
222 fn eq(&self, other: &#type_name) -> bool {
223 other == self
224 }
225 }
226 });
227
228 odbc_type.into()
229}