1use darling::{ast, FromDeriveInput, FromField, FromMeta};
2use quote::{quote, ToTokens};
3
4type Result<T> = std::result::Result<T, darling::Error>;
5
/// A `(major, minor)` version pair, as written in `#[slippi(version = "X.Y")]`.
///
/// The derived ordering compares the major component first and the minor
/// component second, so `Version(3, 10) > Version(3, 9)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Version(u8, u8);
8
9impl FromMeta for Version {
10 fn from_string(value: &str) -> Result<Self> {
11 if let Ok(re) = regex::Regex::new(r"^(\d)\.(\d)$") {
12 if let Some(caps) = re.captures(value) {
13 return Ok(Version(
14 caps.get(1).unwrap().as_str().parse::<u8>().unwrap(),
15 caps.get(2).unwrap().as_str().parse::<u8>().unwrap(),
16 ));
17 }
18 }
19 Err(darling::Error::unsupported_format("X.Y"))
20 }
21}
22
23#[derive(Debug, FromDeriveInput)]
24#[darling(attributes(slippi), supports(struct_any))]
25pub(crate) struct MyInputReceiver {
26 ident: syn::Ident,
27 generics: syn::Generics,
28 data: ast::Data<(), MyFieldReceiver>,
29}
30
31fn if_ver(version: Option<Version>, inner: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
32 match version {
33 Some(version) => {
34 let Version(major, minor) = version;
35 quote!(match version.0 > #major || (version.0 == #major && version.1 >= #minor) {
36 true => Some(#inner),
37 _ => None,
38 },)
39 },
40 _ => quote!(Some(#inner),),
41 }
42}
43
44fn wrapped_type(ty: &syn::Type) -> Option<&syn::Type> {
46 match ty {
47 syn::Type::Path(tpath) => {
48 let segment = &tpath.path.segments[0];
49 match segment.ident.to_string().as_str() {
50 "Option" => match &segment.arguments {
52 syn::PathArguments::AngleBracketed(args) =>
53 match &args.args[0] {
54 syn::GenericArgument::Type(ty) =>
55 Some(ty),
56 _ => None,
57 },
58 _ => None,
59 }
60 _ => None,
61 }
62 },
63 _ => None,
64 }
65}
66
67impl ToTokens for MyInputReceiver {
68 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
69 let MyInputReceiver {
70 ref ident,
71 ref generics,
72 ref data,
73 } = *self;
74
75 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
76 let mut fields = data
77 .as_ref()
78 .take_struct()
79 .expect("Should never be enum")
80 .fields;
81 fields.sort_by_key(|f| f.version);
82
83 let mut arrow_defaults = quote!();
84 let mut arrow_fields = quote!();
85 let mut arrow_builders = quote!();
86 let mut arrow_writers = quote!();
87 let mut arrow_null_writers = quote!();
88 let mut arrow_readers = quote!();
89
90 for (i, f) in fields.into_iter().enumerate() {
91 let ident = &f.ident;
92 let name = ident.as_ref()
93 .map(|n| n.to_string().trim_start_matches("r#").to_string())
94 .unwrap_or(format!("{}", i));
95 let ty = &f.ty;
96 arrow_defaults.extend(
97 quote!(
98 #ident: <#ty as ::peppi_arrow::Arrow>::default(),
99 )
100 );
101 arrow_fields.extend(if_ver(f.version,
102 quote!(::arrow::datatypes::Field::new(
103 #name,
104 <#ty>::data_type(context),
105 <#ty>::is_nullable(),
106 ))
107 ));
108 arrow_builders.extend(if_ver(f.version,
109 quote!(Box::new(<#ty>::builder(len, context))
110 as Box<dyn ::arrow::array::ArrayBuilder>)
111 ));
112 arrow_writers.extend(
113 quote!(
114 let x: Option<usize> = None;
115 if num_fields > #i {
116 self.#ident.write(
117 builder.field_builder::<<#ty as ::peppi_arrow::Arrow>::Builder>(#i).expect(stringify!(Failed to create builder for: #ident)),
118 context,
119 );
120 }
121 )
122 );
123 arrow_null_writers.extend(
124 quote!(
125 if num_fields > #i {
126 <#ty>::write_null(
127 builder.field_builder::<<#ty as ::peppi_arrow::Arrow>::Builder>(#i).expect(stringify!(Failed to create null builder for: #ident)),
128 context,
129 );
130 }
131 )
132 );
133 arrow_readers.extend(
134 if f.version.is_some() {
135 let wrapped = wrapped_type(ty).expect(stringify!(Failed to unwrap type for: #ident));
136 quote!(
137 if struct_array.num_columns() > #i {
138 let mut value = <#wrapped as ::peppi_arrow::Arrow>::default();
139 value.read(struct_array.column(#i).clone(), idx);
140 self.#ident = Some(value);
141 }
142 )
143 } else {
144 quote!(
145 self.#ident.read(struct_array.column(#i).clone(), idx);
146 )
147 }
148 );
149 }
150
151 tokens.extend(quote! {
152 impl #impl_generics ::peppi_arrow::Arrow for #ident #ty_generics #where_clause {
153 type Builder = ::arrow::array::StructBuilder;
154
155 fn default() -> Self {
156 Self {
157 #arrow_defaults
158 }
159 }
160
161 fn fields<C: ::peppi_arrow::Context>(context: C) -> Vec<::arrow::datatypes::Field> {
162 let version = context.slippi_version();
163 vec![#arrow_fields].into_iter().filter_map(|f| f).collect()
164 }
165
166 fn data_type<C: ::peppi_arrow::Context>(context: C) -> ::arrow::datatypes::DataType {
167 ::arrow::datatypes::DataType::Struct(Self::fields(context))
168 }
169
170 fn builder<C: ::peppi_arrow::Context>(len: usize, context: C) -> Self::Builder {
171 let version = context.slippi_version();
172 let fields = Self::fields(context);
173 let builders: Vec<_> = vec![#arrow_builders].into_iter().filter_map(|f| f).collect();
174 ::arrow::array::StructBuilder::new(fields, builders)
175 }
176
177 fn write<C: ::peppi_arrow::Context>(&self, builder: &mut dyn ::arrow::array::ArrayBuilder, context: C) {
178 let builder = builder.as_any_mut().downcast_mut::<Self::Builder>()
179 .expect(stringify!(Failed to downcast builder for: #ident));
180 let num_fields = builder.num_fields();
181 #arrow_writers
182 builder.append(true)
183 .expect(stringify!(Failed to append for: #ident));
184 }
185
186 fn write_null<C: ::peppi_arrow::Context>(builder: &mut dyn ::arrow::array::ArrayBuilder, context: C) {
187 let builder = builder.as_any_mut().downcast_mut::<Self::Builder>()
188 .expect(stringify!(Failed to downcast null builder for: #ident));
189 let num_fields = builder.num_fields();
190 #arrow_null_writers
191 builder.append(false)
192 .expect(stringify!(Failed to append null for: #ident));
193 }
194
195 fn read(&mut self, array: ::arrow::array::ArrayRef, idx: usize) {
196 let struct_array = array.as_any().downcast_ref::<arrow::array::StructArray>()
197 .expect(stringify!(Failed to downcast array for: #ident));
198 #arrow_readers
199 }
200 }
201 });
202 }
203}
204
205#[derive(Debug, FromField)]
206#[darling(attributes(slippi))]
207pub(crate) struct MyFieldReceiver {
208 ident: Option<syn::Ident>,
209 ty: syn::Type,
210 #[darling(default)]
211 version: Option<Version>,
212}
213
214#[proc_macro_derive(Arrow, attributes(slippi))]
215pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
216 let ast = syn::parse(input).expect("Failed to parse item");
217 build_converters(ast).expect("Failed to build converters")
218}
219
220fn build_converters(ast: syn::DeriveInput) -> Result<proc_macro::TokenStream> {
221 let receiver = MyInputReceiver::from_derive_input(&ast).map_err(|e| e.flatten())?;
222 Ok(quote!(#receiver).into())
223}