use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::{format_ident, quote};
use syn::*;
#[proc_macro_derive(ArRowDeserialize)]
pub fn ar_row_deserialize(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let tokens = match ast.data {
Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) => impl_struct(
&ast.ident,
named
.iter()
.map(|field| {
field
.ident
.as_ref()
.expect("#ident must not have anonymous fields")
})
.collect(),
named.iter().map(|field| &field.ty).collect(),
),
Data::Struct(DataStruct { .. }) => panic!("#ident must have named fields"),
_ => panic!("#ident must be a structure"),
};
tokens
}
/// Builds the code generated by `#[derive(ArRowDeserialize)]` for the struct `ident`.
///
/// `field_names` and `field_types` are parallel lists describing the struct's
/// named fields, in declaration order.
///
/// The returned token stream contains four `impl` blocks for `ident`:
///
/// * `::ar_row::deserialize::CheckableDataType` — checks that an Arrow
///   `DataType` is a `Struct` whose fields match the Rust fields positionally
///   by name, and recursively checks each field's type.
/// * `::ar_row::deserialize::ArRowStruct` — lists the column names, each
///   optionally prefixed with `"<prefix>."`.
/// * `::ar_row::deserialize::ArRowDeserialize` — fills a target of `ident`.
/// * `::ar_row::deserialize::ArRowDeserializeOption` — fills a target of
///   `Option<ident>`.
///
/// The generated code calls `Default::default()` for `ident`, and calls
/// `CheckableDataType`, `ArRowStruct` and `ArRowDeserialize` functions on every
/// field type, so those implementations must exist at the expansion site.
fn impl_struct(ident: &Ident, field_names: Vec<&Ident>, field_types: Vec<&Type>) -> TokenStream {
    let num_fields = field_names.len();

    // `format_ident!` strips a leading `r#`, so a field declared as `r#type`
    // is reported and matched as the column name `type`.
    let unescaped_field_names: Vec<_> = field_names
        .iter()
        .map(|field_name| format_ident!("{}", field_name))
        .collect();

    let check_datatype_impl = quote!(
        impl ::ar_row::deserialize::CheckableDataType for #ident {
            fn check_datatype(datatype: &::ar_row::arrow::datatypes::DataType) -> ::std::result::Result<(), ::std::string::String> {
                use ::ar_row::arrow::datatypes::DataType;
                match datatype {
                    DataType::Struct(fields) => {
                        let mut fields = fields.iter().enumerate();
                        let mut errors = ::std::vec::Vec::new();

                        // One check per Rust field, consuming the Arrow fields in order.
                        #(
                            match fields.next() {
                                ::std::option::Option::Some((i, field)) => {
                                    if field.name() != stringify!(#unescaped_field_names) {
                                        errors.push(format!(
                                            "Field #{} must be called {}, not {}",
                                            i, stringify!(#unescaped_field_names), field.name()));
                                    }
                                    else if let ::std::result::Result::Err(s) = <#field_types as ::ar_row::deserialize::CheckableDataType>::check_datatype(field.data_type()) {
                                        errors.push(format!(
                                            "Field {} cannot be decoded: {}",
                                            stringify!(#unescaped_field_names), s));
                                    }
                                },
                                ::std::option::Option::None => errors.push(format!(
                                    "Field {} is missing",
                                    stringify!(#unescaped_field_names)))
                            }
                        )*

                        if errors.is_empty() {
                            ::std::result::Result::Ok(())
                        }
                        else {
                            // Indent nested (multi-line) errors one level deeper.
                            ::std::result::Result::Err(format!(
                                "{} cannot be decoded:\n\t{}",
                                stringify!(#ident),
                                errors.join("\n").replace("\n", "\n\t")))
                        }
                    }
                    _ => ::std::result::Result::Err(format!(
                        "{} must be decoded from DataType::Struct, not {:?}",
                        stringify!(#ident),
                        datatype))
                }
            }
        }
    );

    let orc_struct_impl = quote!(
        impl ::ar_row::deserialize::ArRowStruct for #ident {
            fn columns_with_prefix(prefix: &str) -> ::std::vec::Vec<::std::string::String> {
                let mut columns = ::std::vec::Vec::with_capacity(#num_fields);

                // A default instance is built only so that each field's type can be
                // inferred by `add_columns` without naming it here.
                let instance: #ident = ::std::default::Default::default();

                #({
                    #[inline(always)]
                    fn add_columns<FieldType: ::ar_row::deserialize::ArRowStruct>(columns: &mut ::std::vec::Vec<::std::string::String>, prefix: &str, _: FieldType) {
                        let mut field_name_prefix = prefix.to_string();
                        if !prefix.is_empty() {
                            field_name_prefix.push('.');
                        }
                        field_name_prefix.push_str(stringify!(#unescaped_field_names));
                        columns.extend(FieldType::columns_with_prefix(&field_name_prefix));
                    }
                    add_columns(&mut columns, prefix, instance.#field_names);
                })*

                columns
            }
        }
    );

    // Statements shared by both generated deserialization functions. They expect
    // `src` and `dst` to be in scope, and leave `src` (the struct array) and
    // `columns` (an iterator over its child columns) in scope for what follows.
    let prelude = quote!(
        use ::std::sync::Arc;
        use ::ar_row::arrow::array::Array;
        use ::ar_row::deserialize::DeserializationError;
        use ::ar_row::deserialize::ArRowDeserialize;
        use ::ar_row::deserialize::DeserializationTarget;

        let src = src.as_struct_opt().ok_or_else(|| {
            DeserializationError::MismatchedColumnDataType(format!(
                "Could not cast {:?} array to struct array",
                src.data_type(),
            ))
        })?;
        let columns = src.columns();
        assert_eq!(
            columns.len(),
            #num_fields,
            "{} has {} fields, but got {} columns.",
            stringify!(#ident), #num_fields, columns.len());
        let mut columns = columns.into_iter();

        if src.len() > dst.len() {
            // NOTE(review): this prints to stdout from generated library code, although
            // the returned error already carries both lengths — consider removing.
            println!("{} src = {} dst = {}", stringify!(#ident), src.len(), dst.len());
            return ::std::result::Result::Err(::ar_row::deserialize::DeserializationError::MismatchedLength { src: src.len(), dst: dst.len() });
        }
    );

    let read_from_array_impl = quote!(
        impl ::ar_row::deserialize::ArRowDeserialize for #ident {
            fn read_from_array<'a, 'b, T> (
                src: impl ::ar_row::arrow::array::Array + ::ar_row::arrow::array::AsArray, mut dst: &'b mut T
            ) -> ::std::result::Result<usize, ::ar_row::deserialize::DeserializationError>
            where
                &'b mut T: ::ar_row::deserialize::DeserializationTarget<'a, Item=#ident> + 'b {
                #prelude

                // Reset every row that holds a value. Iterating a null buffer yields
                // validity bits: `true` means the row is NOT null.
                match src.nulls() {
                    ::std::option::Option::None => {
                        for struct_ in dst.iter_mut() {
                            *struct_ = ::std::default::Default::default()
                        }
                    },
                    ::std::option::Option::Some(nulls) => {
                        for (struct_, b) in dst.iter_mut().zip(nulls) {
                            if b {
                                *struct_ = ::std::default::Default::default()
                            }
                        }
                    }
                }

                // Fill each field from its child column, in declaration order.
                #(
                    let column: &Arc<_> = columns.next().unwrap_or_else(
                        || panic!("Failed to get '{}' column", stringify!(#field_names)));
                    ArRowDeserialize::read_from_array::<::ar_row::deserialize::MultiMap<&mut T, _>>(
                        column.clone(),
                        &mut dst.map(|struct_| &mut struct_.#field_names),
                    )?;
                )*

                ::std::result::Result::Ok(src.len())
            }
        }
    );

    let read_options_from_array_impl = quote!(
        impl ::ar_row::deserialize::ArRowDeserializeOption for #ident {
            fn read_options_from_array<'a, 'b, T> (
                src: impl ::ar_row::arrow::array::Array + ::ar_row::arrow::array::AsArray, mut dst: &'b mut T
            ) -> ::std::result::Result<usize, ::ar_row::deserialize::DeserializationError>
            where
                &'b mut T: ::ar_row::deserialize::DeserializationTarget<'a, Item=::std::option::Option<#ident>> + 'b {
                #prelude

                // Every row that holds a value becomes `Some(default)` so that its
                // fields can be written in place below. Iterating a null buffer yields
                // validity bits: `true` means the row is NOT null.
                match src.nulls() {
                    ::std::option::Option::None => {
                        for struct_ in dst.iter_mut() {
                            *struct_ = ::std::option::Option::Some(::std::default::Default::default())
                        }
                    },
                    ::std::option::Option::Some(nulls) => {
                        for (struct_, b) in dst.iter_mut().zip(nulls) {
                            if b {
                                *struct_ = ::std::option::Option::Some(::std::default::Default::default())
                            }
                        }
                    }
                }

                // Fill each field from its child column, in declaration order.
                #(
                    let column: &Arc<_> = columns.next().unwrap_or_else(
                        || panic!("Failed to get '{}' column", stringify!(#field_names)));
                    ArRowDeserialize::read_from_array::<::ar_row::deserialize::MultiMap<&mut T, _>>(
                        column.clone(),
                        // SAFETY: every non-null row was set to `Some` just above.
                        // NOTE(review): null rows keep whatever the caller stored in `dst`;
                        // this is only sound if those rows are already `Some` or are never
                        // passed to this closure — TODO confirm against `MultiMap` and callers.
                        &mut dst.map(|struct_| &mut unsafe { struct_.as_mut().unwrap_unchecked() }.#field_names),
                    )?;
                )*

                ::std::result::Result::Ok(src.len())
            }
        }
    );

    quote!(
        #check_datatype_impl
        #orc_struct_impl
        #read_from_array_impl
        #read_options_from_array_impl
    )
    .into()
}