fire_postgres_derive/
lib.rs

1use ::quote::{quote, ToTokens};
2
3use syn::parse_macro_input;
4use syn::{DataEnum, Field, Fields, FieldsNamed, FieldsUnnamed, Type};
5use syn::{DeriveInput, Error};
6
7use proc_macro::TokenStream as V1TokenStream;
8
9use proc_macro2::{Ident, Span, TokenStream};
10
11use proc_macro_crate::{crate_name, FoundCrate};
12
13type Result<T> = std::result::Result<T, Error>;
14
// inspired by https://github.com/serde-rs/serde/blob/master/serde_derive
16
17#[proc_macro_derive(TableTempl, attributes(len, index, unique))]
18pub fn derive_table_type(input: V1TokenStream) -> V1TokenStream {
19	let input = parse_macro_input!(input as DeriveInput);
20
21	// crate name
22	let name =
23		crate_name("fire-postgres").expect("fire-postgres not in dependencies");
24	let name = match name {
25		FoundCrate::Itself => quote!(crate),
26		FoundCrate::Name(n) => {
27			let ident = Ident::new(&n, Span::call_site());
28			quote!(#ident)
29		}
30	};
31
32	expand(&input, &name).unwrap_or_else(to_compile_error)
33}
34
35fn to_compile_error(error: syn::Error) -> V1TokenStream {
36	let compile_error = syn::Error::to_compile_error(&error);
37	quote!(#compile_error).into()
38}
39
/// Builds a `syn::Error` with message `$msg`, spanned over the tokens that
/// `$input` produces.
macro_rules! err {
	($input:ident, $msg:expr) => {
		Error::new_spanned(ToTokens::into_token_stream($input), $msg)
	};
}
45
46fn expand(
47	input: &DeriveInput,
48	name: &TokenStream,
49) -> Result<proc_macro::TokenStream> {
50	Ok(match &input.data {
51		syn::Data::Enum(data) => parse_enum(input, data, name)?.into(),
52		syn::Data::Struct(data) => match &data.fields {
53			Fields::Named(fields) => {
54				let ident = &input.ident;
55
56				let (len, info_block, data_block, from_block) =
57					parse_named_fields(fields, name)?;
58
59				let table = quote!(#name::table);
60				quote!(
61					impl #table::TableTemplate for #ident {
62						fn table_info() -> #table::Info {
63							{ #info_block }
64						}
65						fn to_data(&self) -> Vec<#table::column::ColumnData<'_>> {
66							use #table::column::ColumnType;
67							{ #data_block }
68						}
69						fn from_data(
70							data: Vec<#table::column::ColumnData>
71						) -> std::result::Result<Self, #table::column::FromDataError> {
72							use #table::column::ColumnType;
73							if data.len() != #len {
74								return Err(#table::column::FromDataError::Custom(
75									"TableTemplate from_data: data isn't long enough"
76								))
77							}
78							let mut data = data.into_iter();
79							{ #from_block }
80						}
81					}
82				)
83				.into()
84			}
85			Fields::Unnamed(fields) => {
86				parse_unnamed_fields(input, fields, name)?.into()
87			}
88			f => return Err(err!(f, "not supported")),
89		},
90		_ => return Err(err!(input, "is not supported")),
91	})
92}
93
94fn parse_named_fields(
95	fields: &FieldsNamed,
96	name: &TokenStream,
97) -> Result<(usize, TokenStream, TokenStream, TokenStream)> {
98	let len = fields.named.len();
99	let mut info_stream = quote!(
100		let mut info = #name::table::Info::with_capacity(#len);
101	);
102	let mut data_stream = quote!(
103		let mut data = Vec::with_capacity(#len);
104	);
105	let mut from_stream = quote!();
106
107	for field in fields.named.iter() {
108		let (col, data, from) = parse_named_field(field, name)?;
109		info_stream.extend(quote!(info.push(#col);));
110		data_stream.extend(quote!(data.push(#data);));
111		from_stream.extend(from);
112	}
113
114	info_stream.extend(quote!(info));
115	data_stream.extend(quote!(data));
116
117	Ok((
118		len,
119		info_stream,
120		data_stream,
121		quote!(Ok(Self {#from_stream})),
122	))
123}
124
125fn parse_named_field(
126	field: &Field,
127	#[allow(unused_variables)] crate_name: &TokenStream,
128) -> Result<(TokenStream, TokenStream, TokenStream)> {
129	let ident = &field.ident;
130	let name = field.ident.as_ref().unwrap().to_string(); // TODO
131
132	let mut len = quote!(None);
133	let table = quote!(#crate_name::table);
134	let index_kind = quote!(#table::column::IndexKind);
135	let mut index = quote!(#index_kind::None);
136
137	// this is name: Type
138	// should build Attributes
139	// println!("parse named ident: {:?} attr: {:?}", name, field.attrs);
140	for attr in &field.attrs {
141		match &attr.path() {
142			p if p.is_ident("len") => {
143				let res: syn::LitInt = attr.parse_args()?;
144				len = quote!(Some(#res));
145			}
146			p if p.is_ident("index") => {
147				let res: syn::Ident = attr.parse_args()?;
148				let index_str = res.to_string();
149				index = match index_str.as_str() {
150					"primary" => quote!(#index_kind::Primary),
151					"unique" => quote!(#index_kind::Unique),
152					"index" => quote!(#index_kind::Index),
153					_ => return Err(err!(res, "not supported index type")),
154				};
155			}
156			p if p.is_ident("unique") => {
157				let res: syn::Ident = attr.parse_args()?;
158				let index_str = res.to_string();
159				index = quote!(#index_kind::NamedUnique(#index_str));
160			}
161			_ => {}
162		}
163	}
164
165	let ty = match &field.ty {
166		Type::Path(t) => t,
167		t => return Err(err!(t, "only type path is supported")),
168	};
169
170	let col = quote!(#table::column::Column::new::<#ty>(#name, #len, #index));
171	let data = quote!(self.#ident.to_data());
172	let from = quote!(#ident: #ty::from_data(data.next().unwrap())?,);
173
174	Ok((col, data, from))
175}
176
177fn parse_unnamed_fields(
178	input: &DeriveInput,
179	fields: &FieldsUnnamed,
180	name: &TokenStream,
181) -> Result<TokenStream> {
182	if fields.unnamed.len() != 1 {
183		return Err(err!(fields, "only single unamed fied supported"));
184	}
185	let field = fields.unnamed.iter().next().unwrap();
186
187	let ident = &input.ident;
188	let ty = &field.ty;
189
190	let table = quote!(#name::table);
191	Ok(quote!(
192		impl #table::column::ColumnType for #ident {
193			fn column_kind() -> #table::column::ColumnKind {
194				<#ty as #table::column::ColumnType>::column_kind()
195			}
196			fn to_data(&self) -> #table::column::ColumnData<'_> {
197				self.0.to_data()
198			}
199			fn from_data(
200				data: #table::column::ColumnData
201			) -> std::result::Result<Self, #table::column::FromDataError> {
202				Ok(Self(<#ty as #table::column::ColumnType>::from_data(data)?))
203			}
204		}
205	))
206}
207
208fn parse_enum(
209	input: &DeriveInput,
210	data: &DataEnum,
211	name: &TokenStream,
212) -> Result<TokenStream> {
213	let mut into_stream = quote!();
214	let mut from_stream = quote!();
215
216	for variant in data.variants.iter() {
217		if variant.fields != Fields::Unit {
218			return Err(err!(variant, "only unit are allowed"));
219		}
220		let ident = &variant.ident;
221		let ident_str = ident.to_string();
222		into_stream.extend(quote!(Self::#ident => #ident_str,));
223		from_stream.extend(quote!(#ident_str => Ok(Self::#ident),));
224	}
225
226	let ident = &input.ident;
227
228	let table = quote!(#name::table);
229	Ok(quote!(
230		impl #ident {
231			pub fn as_str(&self) -> &'static str {
232				match self {
233					#into_stream
234				}
235			}
236			pub fn from_str(
237				s: &str
238			) -> std::result::Result<Self, #table::column::FromDataError> {
239				match s {
240					#from_stream
241					_ => Err(#table::column::FromDataError::Custom(
242						"text doesnt match any enum variant"
243					))
244				}
245			}
246		}
247
248		impl #table::column::ColumnType for #ident {
249			fn column_kind() -> #table::column::ColumnKind {
250				#table::column::ColumnKind::Text
251			}
252			fn to_data(&self) -> #table::column::ColumnData {
253				#table::column::ColumnData::Text(self.as_str().into())
254			}
255			fn from_data(
256				data: #table::column::ColumnData
257			) -> std::result::Result<Self, #table::column::FromDataError> {
258				match data {
259					#table::column::ColumnData::Text(t) => {
260						Self::from_str(t.as_str())
261					},
262					_ => Err(#table::column::FromDataError::ExpectedType(
263						"text for enum"
264					))
265				}
266			}
267		}
268	))
269}