prost_arrow_derive/
lib.rs

1// This is my typical proc-macro prelude
2#![allow(unused_imports)]
3extern crate proc_macro;
4use std::any::Any;
5
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::{quote, quote_spanned, ToTokens};
9use syn::{
10    parse::{Parse, ParseStream, Parser},
11    punctuated::Punctuated,
12    spanned::Spanned,
13    Result, *,
14};
15
16#[proc_macro_derive(ToArrow)]
17pub fn rule_system_derive(input: TokenStream) -> TokenStream {
18    let ast = parse_macro_input!(input as _);
19    TokenStream::from(match impl_my_trait(ast) {
20        Ok(it) => it,
21        Err(err) => err.to_compile_error(),
22    })
23}
24
25fn impl_my_trait(ast: DeriveInput) -> Result<TokenStream2> {
26    Ok({
27        let name = ast.ident;
28        let fields = match ast.data {
29            Data::Enum(DataEnum {
30                enum_token: token::Enum { span },
31                ..
32            })
33            | Data::Union(DataUnion {
34                union_token: token::Union { span },
35                ..
36            }) => {
37                return Err(Error::new(span, "Expected a `struct`"));
38            }
39
40            Data::Struct(DataStruct {
41                fields: Fields::Named(it),
42                ..
43            }) => it,
44
45            Data::Struct(_) => {
46                return Err(Error::new(
47                    Span::call_site(),
48                    "Expected a `struct` with named fields",
49                ));
50            }
51        };
52
53        let prost_fields: Vec<ProstField> = fields.named.into_iter().map(ProstField::new).collect();
54
55        let data_expanded_members = prost_fields.iter().map(|field| {
56            let field_name_str = LitStr::new(&field.name.to_string(), field.span);
57            let datatype = &field.arrow_datatype();
58            let nullable = &field.nullable;
59            quote_spanned! { field.span=>
60                             ::arrow_schema::Field::new(
61                                 #field_name_str,
62                                 #datatype,
63                                 #nullable,
64                             )
65            }
66        });
67
68        let builder_struct_members = prost_fields.iter().map(|field| {
69            let field_name = &field.name;
70            let inner_type = &field.inner_type;
71            let into_arrow_type = quote!(<#inner_type as ::prost_arrow::ToArrow>);
72            let builder_type = if field.array {
73                quote!(::prost_arrow::ArrowListBuilder::<#inner_type>)
74            } else {
75                quote!(#into_arrow_type::Builder)
76            };
77            quote_spanned! {
78                field.span=> #field_name: #builder_type
79            }
80        });
81
82        let builder_struct_initializers = prost_fields.iter().map(|field| {
83            let field_name = &field.name;
84            let inner_type = &field.inner_type;
85            let into_arrow_type = quote!(<#inner_type as ::prost_arrow::ToArrow>);
86            let builder_type = if field.array {
87                quote!(::prost_arrow::ArrowListBuilder::<#inner_type>)
88            } else {
89                quote!(#into_arrow_type::Builder)
90            };
91            quote_spanned! {
92                field.span=> #field_name: #builder_type::new_with_capacity(capacity)
93            }
94        });
95
96        let builder_append_exprs = prost_fields.iter().map(|field| {
97            let field_name = &field.name;
98
99            if field.nullable {
100                quote_spanned! {
101                    field.span=> self.#field_name.append_option(value.#field_name)
102                }
103            } else {
104                quote_spanned! {
105                    field.span=> self.#field_name.append_value(value.#field_name)
106                }
107            }
108        });
109
110        let builder_append_none_exprs = prost_fields.iter().map(|field| {
111            let field_name = &field.name;
112
113            quote_spanned! {
114                field.span=> self.#field_name.append_option(None)
115            }
116        });
117
118        let fields_vec = quote! {
119            ::arrow_schema::Fields::from(vec![
120                #(#data_expanded_members ,)*
121            ])
122        };
123
124        let finish_accessors = prost_fields.iter().map(|field| {
125            let field_name = &field.name;
126
127            quote_spanned! {
128                field.span => self.#field_name.finish()
129            }
130        });
131
132        let finish_cloned_accessors = prost_fields.iter().map(|field| {
133            let field_name = &field.name;
134
135            quote_spanned! {
136                field.span => self.#field_name.finish_cloned()
137            }
138        });
139
140        let builder_name = Ident::new(format!("{}Builder", name.to_string()).as_str(), name.span());
141
142        quote! {
143            pub struct #builder_name {
144                null_buffer_builder: ::arrow_buffer::NullBufferBuilder,
145                #(#builder_struct_members ,)*
146            }
147
148            impl ::prost_arrow::ToArrow for #name {
149                type Item = #name;
150                type Builder = #builder_name;
151
152                fn to_datatype()
153                  -> ::arrow_schema::DataType
154                {
155                    ::arrow_schema::DataType::Struct(#fields_vec)
156                }
157            }
158
159            impl ::prost_arrow::ArrowBuilder<#name> for #builder_name {
160                fn new_with_capacity(capacity: usize) -> Self {
161                    Self{
162                        null_buffer_builder: ::arrow_buffer::NullBufferBuilder::new(capacity),
163                        #(#builder_struct_initializers ,)*
164                    }
165                }
166
167                fn append_value(&mut self, value: #name) {
168                    #(#builder_append_exprs ;)*
169                    self.null_buffer_builder.append(true);
170                }
171
172                fn append_option(&mut self, value: Option<#name>) {
173                    match value {
174                        Some(v) => {
175                            self.append_value(v);
176                        },
177                        None => {
178                            #(#builder_append_none_exprs ;)*
179                            self.null_buffer_builder.append(false);
180                        },
181                    }
182                }
183            }
184
185            impl ::arrow_array::builder::ArrayBuilder for #builder_name {
186                fn len(&self) -> usize {
187                    self.null_buffer_builder.len()
188                }
189
190                fn finish(&mut self) -> ::arrow_array::ArrayRef {
191                    let fields = #fields_vec;
192                    let arrays = vec![
193                        #(#finish_accessors ,)*
194                    ];
195                    let nulls = self.null_buffer_builder.finish();
196                    ::std::sync::Arc::new(::arrow_array::StructArray::new(fields, arrays, nulls))
197                }
198
199                fn finish_cloned(&self) -> ::arrow_array::ArrayRef {
200                    let fields = #fields_vec;
201                    let arrays = vec![
202                        #(#finish_cloned_accessors ,)*
203                    ];
204                    let nulls = self.null_buffer_builder.finish_cloned();
205                    ::std::sync::Arc::new(::arrow_array::StructArray::new(fields, arrays, nulls))
206                }
207
208                fn as_any(&self) -> &dyn ::std::any::Any {
209                    self
210                }
211
212                fn as_any_mut(&mut self) -> &mut dyn ::std::any::Any {
213                    self
214                }
215
216                fn into_box_any(self: Box<Self>) -> Box<dyn ::std::any::Any> {
217                    self
218                }
219            }
220        }
221    })
222}
223
224struct ProstField {
225    span: Span,
226    name: Ident,
227    inner_type: TokenStream2,
228    nullable: bool,
229    array: bool,
230}
231
232impl ProstField {
233    fn new(field: Field) -> Self {
234        let (inner_type, nullable, array) = match &field.ty {
235            Type::Path(path) => {
236                let last = path.path.segments.last().expect("has last");
237
238                // if Vec<u8> then inner should be Vec<u8> and array is false
239
240                let inner = match &last.arguments {
241                    PathArguments::AngleBracketed(args) => args
242                        .args
243                        .first()
244                        .expect("has one type argument")
245                        .into_token_stream(),
246                    _ => path.into_token_stream(),
247                };
248
249                let last_ident = last.ident.to_string();
250                let is_vec = last_ident.as_str() == "Vec";
251                let is_binary = is_vec && inner.to_string() == "u8";
252                let nullable = last_ident.as_str() == "Option";
253
254                let (inner, array) = if is_binary {
255                    (last.into_token_stream(), false)
256                } else {
257                    (inner, is_vec)
258                };
259
260                (inner, nullable, array)
261            }
262
263            other => (other.into_token_stream(), false, false),
264        };
265
266        Self {
267            span: field.span(),
268            name: field.ident.expect("field is named"),
269            inner_type,
270            nullable,
271            array,
272        }
273    }
274
275    fn arrow_datatype(&self) -> TokenStream2 {
276        let inner = &self.inner_type;
277
278        if self.array {
279            quote_spanned! { self.span => ::arrow_schema::DataType::List(
280                ::std::sync::Arc::new(::arrow_schema::Field::new_list_field(
281                    <#inner as ::prost_arrow::ToArrow>::to_datatype(),
282                    true,
283                )))
284            }
285        } else {
286            quote_spanned!(self.span => <#inner as ::prost_arrow::ToArrow>::to_datatype())
287        }
288    }
289}