// prost_arrow_derive/lib.rs
#![allow(unused_imports)]
extern crate proc_macro;

use std::any::Any;

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
    ext::IdentExt,
    parse::{Parse, ParseStream, Parser},
    punctuated::Punctuated,
    spanned::Spanned,
    Result, *,
};

16#[proc_macro_derive(ToArrow)]
17pub fn rule_system_derive(input: TokenStream) -> TokenStream {
18 let ast = parse_macro_input!(input as _);
19 TokenStream::from(match impl_my_trait(ast) {
20 Ok(it) => it,
21 Err(err) => err.to_compile_error(),
22 })
23}
24
25fn impl_my_trait(ast: DeriveInput) -> Result<TokenStream2> {
26 Ok({
27 let name = ast.ident;
28 let fields = match ast.data {
29 Data::Enum(DataEnum {
30 enum_token: token::Enum { span },
31 ..
32 })
33 | Data::Union(DataUnion {
34 union_token: token::Union { span },
35 ..
36 }) => {
37 return Err(Error::new(span, "Expected a `struct`"));
38 }
39
40 Data::Struct(DataStruct {
41 fields: Fields::Named(it),
42 ..
43 }) => it,
44
45 Data::Struct(_) => {
46 return Err(Error::new(
47 Span::call_site(),
48 "Expected a `struct` with named fields",
49 ));
50 }
51 };
52
53 let prost_fields: Vec<ProstField> = fields.named.into_iter().map(ProstField::new).collect();
54
55 let data_expanded_members = prost_fields.iter().map(|field| {
56 let field_name_str = LitStr::new(&field.name.to_string(), field.span);
57 let datatype = &field.arrow_datatype();
58 let nullable = &field.nullable;
59 quote_spanned! { field.span=>
60 ::arrow_schema::Field::new(
61 #field_name_str,
62 #datatype,
63 #nullable,
64 )
65 }
66 });
67
68 let builder_struct_members = prost_fields.iter().map(|field| {
69 let field_name = &field.name;
70 let inner_type = &field.inner_type;
71 let into_arrow_type = quote!(<#inner_type as ::prost_arrow::ToArrow>);
72 let builder_type = if field.array {
73 quote!(::prost_arrow::ArrowListBuilder::<#inner_type>)
74 } else {
75 quote!(#into_arrow_type::Builder)
76 };
77 quote_spanned! {
78 field.span=> #field_name: #builder_type
79 }
80 });
81
82 let builder_struct_initializers = prost_fields.iter().map(|field| {
83 let field_name = &field.name;
84 let inner_type = &field.inner_type;
85 let into_arrow_type = quote!(<#inner_type as ::prost_arrow::ToArrow>);
86 let builder_type = if field.array {
87 quote!(::prost_arrow::ArrowListBuilder::<#inner_type>)
88 } else {
89 quote!(#into_arrow_type::Builder)
90 };
91 quote_spanned! {
92 field.span=> #field_name: #builder_type::new_with_capacity(capacity)
93 }
94 });
95
96 let builder_append_exprs = prost_fields.iter().map(|field| {
97 let field_name = &field.name;
98
99 if field.nullable {
100 quote_spanned! {
101 field.span=> self.#field_name.append_option(value.#field_name)
102 }
103 } else {
104 quote_spanned! {
105 field.span=> self.#field_name.append_value(value.#field_name)
106 }
107 }
108 });
109
110 let builder_append_none_exprs = prost_fields.iter().map(|field| {
111 let field_name = &field.name;
112
113 quote_spanned! {
114 field.span=> self.#field_name.append_option(None)
115 }
116 });
117
118 let fields_vec = quote! {
119 ::arrow_schema::Fields::from(vec![
120 #(#data_expanded_members ,)*
121 ])
122 };
123
124 let finish_accessors = prost_fields.iter().map(|field| {
125 let field_name = &field.name;
126
127 quote_spanned! {
128 field.span => self.#field_name.finish()
129 }
130 });
131
132 let finish_cloned_accessors = prost_fields.iter().map(|field| {
133 let field_name = &field.name;
134
135 quote_spanned! {
136 field.span => self.#field_name.finish_cloned()
137 }
138 });
139
140 let builder_name = Ident::new(format!("{}Builder", name.to_string()).as_str(), name.span());
141
142 quote! {
143 pub struct #builder_name {
144 null_buffer_builder: ::arrow_buffer::NullBufferBuilder,
145 #(#builder_struct_members ,)*
146 }
147
148 impl ::prost_arrow::ToArrow for #name {
149 type Item = #name;
150 type Builder = #builder_name;
151
152 fn to_datatype()
153 -> ::arrow_schema::DataType
154 {
155 ::arrow_schema::DataType::Struct(#fields_vec)
156 }
157 }
158
159 impl ::prost_arrow::ArrowBuilder<#name> for #builder_name {
160 fn new_with_capacity(capacity: usize) -> Self {
161 Self{
162 null_buffer_builder: ::arrow_buffer::NullBufferBuilder::new(capacity),
163 #(#builder_struct_initializers ,)*
164 }
165 }
166
167 fn append_value(&mut self, value: #name) {
168 #(#builder_append_exprs ;)*
169 self.null_buffer_builder.append(true);
170 }
171
172 fn append_option(&mut self, value: Option<#name>) {
173 match value {
174 Some(v) => {
175 self.append_value(v);
176 },
177 None => {
178 #(#builder_append_none_exprs ;)*
179 self.null_buffer_builder.append(false);
180 },
181 }
182 }
183 }
184
185 impl ::arrow_array::builder::ArrayBuilder for #builder_name {
186 fn len(&self) -> usize {
187 self.null_buffer_builder.len()
188 }
189
190 fn finish(&mut self) -> ::arrow_array::ArrayRef {
191 let fields = #fields_vec;
192 let arrays = vec![
193 #(#finish_accessors ,)*
194 ];
195 let nulls = self.null_buffer_builder.finish();
196 ::std::sync::Arc::new(::arrow_array::StructArray::new(fields, arrays, nulls))
197 }
198
199 fn finish_cloned(&self) -> ::arrow_array::ArrayRef {
200 let fields = #fields_vec;
201 let arrays = vec![
202 #(#finish_cloned_accessors ,)*
203 ];
204 let nulls = self.null_buffer_builder.finish_cloned();
205 ::std::sync::Arc::new(::arrow_array::StructArray::new(fields, arrays, nulls))
206 }
207
208 fn as_any(&self) -> &dyn ::std::any::Any {
209 self
210 }
211
212 fn as_any_mut(&mut self) -> &mut dyn ::std::any::Any {
213 self
214 }
215
216 fn into_box_any(self: Box<Self>) -> Box<dyn ::std::any::Any> {
217 self
218 }
219 }
220 }
221 })
222}
223
224struct ProstField {
225 span: Span,
226 name: Ident,
227 inner_type: TokenStream2,
228 nullable: bool,
229 array: bool,
230}
231
232impl ProstField {
233 fn new(field: Field) -> Self {
234 let (inner_type, nullable, array) = match &field.ty {
235 Type::Path(path) => {
236 let last = path.path.segments.last().expect("has last");
237
238 let inner = match &last.arguments {
241 PathArguments::AngleBracketed(args) => args
242 .args
243 .first()
244 .expect("has one type argument")
245 .into_token_stream(),
246 _ => path.into_token_stream(),
247 };
248
249 let last_ident = last.ident.to_string();
250 let is_vec = last_ident.as_str() == "Vec";
251 let is_binary = is_vec && inner.to_string() == "u8";
252 let nullable = last_ident.as_str() == "Option";
253
254 let (inner, array) = if is_binary {
255 (last.into_token_stream(), false)
256 } else {
257 (inner, is_vec)
258 };
259
260 (inner, nullable, array)
261 }
262
263 other => (other.into_token_stream(), false, false),
264 };
265
266 Self {
267 span: field.span(),
268 name: field.ident.expect("field is named"),
269 inner_type,
270 nullable,
271 array,
272 }
273 }
274
275 fn arrow_datatype(&self) -> TokenStream2 {
276 let inner = &self.inner_type;
277
278 if self.array {
279 quote_spanned! { self.span => ::arrow_schema::DataType::List(
280 ::std::sync::Arc::new(::arrow_schema::Field::new_list_field(
281 <#inner as ::prost_arrow::ToArrow>::to_datatype(),
282 true,
283 )))
284 }
285 } else {
286 quote_spanned!(self.span => <#inner as ::prost_arrow::ToArrow>::to_datatype())
287 }
288 }
289}