rs_odbc_derive/
lib.rs

1use proc_macro::{token_stream, TokenStream};
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{self, parse::Parse, parse::Parser};
5
/// Panic message reported when the shared `odbc_type` / `odbc_bitmask` expansion is handed an
/// item it cannot wrap: a struct that already declares fields, or a union.
const ZST_MSG: &str =
    "`odbc_type` and `odbc_bitmask` can only be applied to a struct without fields or to an enum";
8
9#[proc_macro_derive(Ident, attributes(identifier))]
10pub fn into_identifier(input: TokenStream) -> TokenStream {
11    let ast: syn::DeriveInput = syn::parse(input).unwrap();
12
13    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
14    let type_name = &ast.ident;
15
16    let mut identifier = None;
17    let mut identifier_type = None;
18    for attr in ast.attrs.into_iter() {
19        if attr.path.is_ident("identifier") {
20            if let syn::Meta::List(attr_list) = attr.parse_meta().expect("Missing arguments") {
21                let mut attr_list = attr_list.nested.into_iter();
22
23                if let syn::NestedMeta::Meta(meta) = &attr_list.next().expect("Missing arguments") {
24                    identifier_type = meta.path().get_ident().map(|x| x.to_owned());
25                } else {
26                    panic!("1st argument is not a valid ODBC type");
27                }
28                if let syn::NestedMeta::Lit(lit) = attr_list.next().expect("Missing 2nd argument") {
29                    identifier = Some(lit);
30                } else {
31                    panic!("2nd argument is not a valid literal");
32                }
33            }
34        }
35    }
36
37    let gen = quote! {
38        impl #impl_generics crate::Ident for #type_name #ty_generics #where_clause {
39            type Type = crate::#identifier_type;
40            const IDENTIFIER: Self::Type = #identifier;
41        }
42    };
43
44    gen.into()
45}
46
47fn parse_inner_type(mut args: token_stream::IntoIter) -> Ident {
48    let inner_type: Ident = syn::parse(args.next().unwrap().into()).unwrap();
49
50    if args.next().is_some() {
51        // TODO: Better message
52        panic!("Only one ODBC type can be declared");
53    }
54
55    match inner_type.to_string().as_str() {
56        // TODO: Expand with other types. Maybe use Copy trait. Then implement Ident only if inner_type implements inner
57        // REQUIREMENT1: supported types must have a valid zero-byte representation because of AttrZeroFill
58        // REQUIREMENT2: supported types must have the same representation as SQLPOINTER because of Ident
59        "SQLINTEGER" | "SQLUINTEGER" | "SQLSMALLINT" | "SQLUSMALLINT" | "SQLLEN" | "SQLULEN" => {}
60        unsupported => panic!("{}: unsupported ODBC type", unsupported),
61    }
62
63    inner_type
64}
65
66fn odbc_derive(ast: &mut syn::DeriveInput, inner_type: &Ident) -> TokenStream2 {
67    ast.attrs.extend(
68        syn::Attribute::parse_outer
69            .parse2(quote! { #[derive(Debug, Clone, Copy)] })
70            .unwrap(),
71    );
72
73    let type_name = &ast.ident;
74    let mut ret = match ast.data {
75        syn::Data::Struct(ref mut struct_data) => {
76            ast.attrs.extend(
77                syn::Attribute::parse_outer
78                    .parse2(quote! { #[repr(transparent)] })
79                    .unwrap(),
80            );
81
82            if struct_data.fields.is_empty() {
83                struct_data.fields = syn::Fields::Unnamed(
84                    syn::FieldsUnnamed::parse
85                        .parse2(quote! { (crate::#inner_type) })
86                        .expect(&format!("{}: unknown ODBC type", inner_type)),
87                );
88            } else {
89                panic!("{}", ZST_MSG);
90            }
91
92            quote! {
93                unsafe impl crate::convert::AsMutSQLPOINTER for #type_name {
94                    fn as_mut_SQLPOINTER(&mut self) -> crate::SQLPOINTER {
95                        (self as *mut Self).cast()
96                    }
97                }
98                unsafe impl crate::convert::AsMutSQLPOINTER for std::mem::MaybeUninit<#type_name> {
99                    fn as_mut_SQLPOINTER(&mut self) -> crate::SQLPOINTER {
100                        self.as_mut_ptr().cast()
101                    }
102                }
103
104                impl #type_name {
105                    #[inline]
106                    pub(crate) const fn identifier(&self) -> crate::#inner_type {
107                        self.0
108                    }
109                }
110            }
111        }
112        syn::Data::Enum(ref data) => {
113            let variants = data.variants.iter().map(|v| &v.ident);
114
115            quote! {
116                impl std::convert::TryFrom<crate::#inner_type> for #type_name {
117                    type Error = crate::#inner_type;
118
119                    fn try_from(source: crate::#inner_type) -> Result<Self, Self::Error> {
120                        match source {
121                            #(x if x == #type_name::#variants as crate::#inner_type => Ok(#type_name::#variants)),*,
122                            unknown => Err(unknown),
123                        }
124                    }
125                }
126
127                impl #type_name {
128                    pub(crate) const fn identifier(&self) -> crate::#inner_type {
129                        *self as crate::#inner_type
130                    }
131                }
132            }
133        }
134        _ => panic!("{}", ZST_MSG),
135    };
136
137    ret.extend(quote! {
138        impl crate::Ident for #type_name where crate::#inner_type: crate::Ident {
139            type Type = <crate::#inner_type as crate::Ident>::Type;
140            const IDENTIFIER: Self::Type = <crate::#inner_type as crate::Ident>::IDENTIFIER;
141        }
142
143        impl crate::Scalar for #type_name where crate::#inner_type: crate::Scalar {}
144
145        unsafe impl crate::convert::IntoSQLPOINTER for #type_name {
146            fn into_SQLPOINTER(self) -> crate::SQLPOINTER {
147                Self::identifier(&self) as _
148            }
149        }
150
151        impl crate::attr::AttrZeroAssert for #type_name {
152            #[inline]
153            fn assert_zeroed(&self) {
154                // TODO: Check implementation on types in lib.rs
155                assert_eq!(0, Self::identifier(&self));
156            }
157        }
158
159        #ast
160    });
161
162    ret
163}
164
165#[proc_macro_attribute]
166pub fn odbc_bitmask(args: TokenStream, input: TokenStream) -> TokenStream {
167    let mut ast: syn::DeriveInput = syn::parse(input).unwrap();
168
169    let inner_type = parse_inner_type(args.into_iter());
170    let mut odbc_bitmask = odbc_derive(&mut ast, &inner_type);
171
172    let type_name = &ast.ident;
173    odbc_bitmask.extend(quote! {
174        impl std::ops::BitAnd<#type_name> for #type_name {
175            type Output = crate::#inner_type;
176
177            fn bitand(self, other: #type_name) -> Self::Output {
178                Self::identifier(&self) &Self::identifier(&other)
179            }
180        }
181        impl std::ops::BitAnd<crate::#inner_type> for #type_name {
182            type Output = crate::#inner_type;
183
184            fn bitand(self, other: crate::#inner_type) -> Self::Output {
185                Self::identifier(&self) & other
186            }
187        }
188        impl std::ops::BitAnd<#type_name> for crate::#inner_type {
189            type Output = crate::#inner_type;
190
191            fn bitand(self, other: #type_name) -> Self::Output {
192                other & self
193            }
194        }
195    });
196
197    odbc_bitmask.into()
198}
199
200#[proc_macro_attribute]
201pub fn odbc_type(args: TokenStream, input: TokenStream) -> TokenStream {
202    let mut ast: syn::DeriveInput = syn::parse(input).unwrap();
203
204    ast.attrs.extend(
205        syn::Attribute::parse_outer
206            .parse2(quote! { #[derive(PartialEq, Eq)] })
207            .unwrap(),
208    );
209
210    let inner_type = parse_inner_type(args.into_iter());
211    let mut odbc_type = odbc_derive(&mut ast, &inner_type);
212
213    let type_name = &ast.ident;
214    odbc_type.extend(quote! {
215        impl PartialEq<crate::#inner_type> for #type_name {
216            fn eq(&self, other: &crate::#inner_type) -> bool {
217                self.identifier() == *other
218            }
219        }
220
221        impl PartialEq<#type_name> for crate::#inner_type {
222            fn eq(&self, other: &#type_name) -> bool {
223                other == self
224            }
225        }
226    });
227
228    odbc_type.into()
229}