peppi_derive/
lib.rs

1use darling::{ast, FromDeriveInput, FromField, FromMeta};
2use quote::{quote, ToTokens};
3
4type Result<T> = std::result::Result<T, darling::Error>;
5
/// A `major.minor` version, as written in a `#[slippi(version = "X.Y")]` attribute.
///
/// The derived ordering is lexicographic on `(major, minor)`; fields are sorted by
/// their (optional) version with it before code generation.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Version(
	/// Major component.
	u8,
	/// Minor component.
	u8,
);
8
9impl FromMeta for Version {
10	fn from_string(value: &str) -> Result<Self> {
11		if let Ok(re) = regex::Regex::new(r"^(\d)\.(\d)$") {
12			if let Some(caps) = re.captures(value) {
13				return Ok(Version(
14					caps.get(1).unwrap().as_str().parse::<u8>().unwrap(),
15					caps.get(2).unwrap().as_str().parse::<u8>().unwrap(),
16				));
17			}
18		}
19		Err(darling::Error::unsupported_format("X.Y"))
20	}
21}
22
23#[derive(Debug, FromDeriveInput)]
24#[darling(attributes(slippi), supports(struct_any))]
25pub(crate) struct MyInputReceiver {
26	ident: syn::Ident,
27	generics: syn::Generics,
28	data: ast::Data<(), MyFieldReceiver>,
29}
30
31fn if_ver(version: Option<Version>, inner: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
32	match version {
33		Some(version) => {
34			let Version(major, minor) = version;
35			quote!(match version.0 > #major || (version.0 == #major && version.1 >= #minor) {
36				true => Some(#inner),
37				_ => None,
38			},)
39		},
40		_ => quote!(Some(#inner),),
41	}
42}
43
44/// Takes an `Option<...>` type and returns the inner type.
45fn wrapped_type(ty: &syn::Type) -> Option<&syn::Type> {
46	match ty {
47		syn::Type::Path(tpath) => {
48			let segment = &tpath.path.segments[0];
49			match segment.ident.to_string().as_str() {
50				"Option" => // FIXME: will miss `std::option::Option`, etc
51					match &segment.arguments {
52						syn::PathArguments::AngleBracketed(args) =>
53							match &args.args[0] {
54								syn::GenericArgument::Type(ty) =>
55									Some(ty),
56								_ => None,
57							},
58						_ => None,
59					}
60				_ => None,
61			}
62		},
63		_ => None,
64	}
65}
66
67impl ToTokens for MyInputReceiver {
68	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
69		let MyInputReceiver {
70			ref ident,
71			ref generics,
72			ref data,
73		} = *self;
74
75		let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
76		let mut fields = data
77			.as_ref()
78			.take_struct()
79			.expect("Should never be enum")
80			.fields;
81		fields.sort_by_key(|f| f.version);
82
83		let mut arrow_defaults = quote!();
84		let mut arrow_fields = quote!();
85		let mut arrow_builders = quote!();
86		let mut arrow_writers = quote!();
87		let mut arrow_null_writers = quote!();
88		let mut arrow_readers = quote!();
89
90		for (i, f) in fields.into_iter().enumerate() {
91			let ident = &f.ident;
92			let name = ident.as_ref()
93				.map(|n| n.to_string().trim_start_matches("r#").to_string())
94				.unwrap_or(format!("{}", i));
95			let ty = &f.ty;
96			arrow_defaults.extend(
97				quote!(
98					#ident: <#ty as ::peppi_arrow::Arrow>::default(),
99				)
100			);
101			arrow_fields.extend(if_ver(f.version,
102				quote!(::arrow::datatypes::Field::new(
103					#name,
104					<#ty>::data_type(context),
105					<#ty>::is_nullable(),
106				))
107			));
108			arrow_builders.extend(if_ver(f.version,
109				quote!(Box::new(<#ty>::builder(len, context))
110					as Box<dyn ::arrow::array::ArrayBuilder>)
111			));
112			arrow_writers.extend(
113				quote!(
114					let x: Option<usize> = None;
115					if num_fields > #i {
116						self.#ident.write(
117							builder.field_builder::<<#ty as ::peppi_arrow::Arrow>::Builder>(#i).expect(stringify!(Failed to create builder for: #ident)),
118							context,
119						);
120					}
121				)
122			);
123			arrow_null_writers.extend(
124				quote!(
125					if num_fields > #i {
126						<#ty>::write_null(
127							builder.field_builder::<<#ty as ::peppi_arrow::Arrow>::Builder>(#i).expect(stringify!(Failed to create null builder for: #ident)),
128							context,
129						);
130					}
131				)
132			);
133			arrow_readers.extend(
134				if f.version.is_some() {
135					let wrapped = wrapped_type(ty).expect(stringify!(Failed to unwrap type for: #ident));
136					quote!(
137						if struct_array.num_columns() > #i {
138							let mut value = <#wrapped as ::peppi_arrow::Arrow>::default();
139							value.read(struct_array.column(#i).clone(), idx);
140							self.#ident = Some(value);
141						}
142					)
143				} else {
144					quote!(
145						self.#ident.read(struct_array.column(#i).clone(), idx);
146					)
147				}
148			);
149		}
150
151		tokens.extend(quote! {
152			impl #impl_generics ::peppi_arrow::Arrow for #ident #ty_generics #where_clause {
153				type Builder = ::arrow::array::StructBuilder;
154
155				fn default() -> Self {
156					Self {
157						#arrow_defaults
158					}
159				}
160
161				fn fields<C: ::peppi_arrow::Context>(context: C) -> Vec<::arrow::datatypes::Field> {
162					let version = context.slippi_version();
163					vec![#arrow_fields].into_iter().filter_map(|f| f).collect()
164				}
165
166				fn data_type<C: ::peppi_arrow::Context>(context: C) -> ::arrow::datatypes::DataType {
167					::arrow::datatypes::DataType::Struct(Self::fields(context))
168				}
169
170				fn builder<C: ::peppi_arrow::Context>(len: usize, context: C) -> Self::Builder {
171					let version = context.slippi_version();
172					let fields = Self::fields(context);
173					let builders: Vec<_> = vec![#arrow_builders].into_iter().filter_map(|f| f).collect();
174					::arrow::array::StructBuilder::new(fields, builders)
175				}
176
177				fn write<C: ::peppi_arrow::Context>(&self, builder: &mut dyn ::arrow::array::ArrayBuilder, context: C) {
178					let builder = builder.as_any_mut().downcast_mut::<Self::Builder>()
179						.expect(stringify!(Failed to downcast builder for: #ident));
180					let num_fields = builder.num_fields();
181					#arrow_writers
182					builder.append(true)
183						.expect(stringify!(Failed to append for: #ident));
184				}
185
186				fn write_null<C: ::peppi_arrow::Context>(builder: &mut dyn ::arrow::array::ArrayBuilder, context: C) {
187					let builder = builder.as_any_mut().downcast_mut::<Self::Builder>()
188						.expect(stringify!(Failed to downcast null builder for: #ident));
189					let num_fields = builder.num_fields();
190					#arrow_null_writers
191					builder.append(false)
192						.expect(stringify!(Failed to append null for: #ident));
193				}
194
195				fn read(&mut self, array: ::arrow::array::ArrayRef, idx: usize) {
196					let struct_array = array.as_any().downcast_ref::<arrow::array::StructArray>()
197						.expect(stringify!(Failed to downcast array for: #ident));
198					#arrow_readers
199				}
200			}
201		});
202	}
203}
204
205#[derive(Debug, FromField)]
206#[darling(attributes(slippi))]
207pub(crate) struct MyFieldReceiver {
208	ident: Option<syn::Ident>,
209	ty: syn::Type,
210	#[darling(default)]
211	version: Option<Version>,
212}
213
214#[proc_macro_derive(Arrow, attributes(slippi))]
215pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
216	let ast = syn::parse(input).expect("Failed to parse item");
217	build_converters(ast).expect("Failed to build converters")
218}
219
220fn build_converters(ast: syn::DeriveInput) -> Result<proc_macro::TokenStream> {
221	let receiver = MyInputReceiver::from_derive_input(&ast).map_err(|e| e.flatten())?;
222	Ok(quote!(#receiver).into())
223}