iridis_message_derive/
lib.rs

//! This crate provides the `ArrowMessage` derive macro.
//! It generates the boilerplate needed to convert a struct or enum to and from an Arrow message.
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{
9    DeriveInput, Field, Fields, Ident, Token, Variant, parse_macro_input, punctuated::Punctuated,
10    token::Comma,
11};
12
13/// Apply this macro to a struct or enum to implement the `ArrowMessage` trait.
14///
15/// All fields of the struct must implement the `ArrowMessage` trait. This is the only
16/// constraint for the struct.
17///
18/// Rust primitives like `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`, `f32`, `f64` and
19/// Optionals of these types already implement the `ArrowMessage` trait, as well as all `arrow::array` types.
20#[proc_macro_derive(ArrowMessage)]
21pub fn from_into_arrow_derive(input: TokenStream) -> TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23
24    let name = input.ident;
25
26    match input.data {
27        syn::Data::Struct(s) => match s.fields {
28            Fields::Named(ref fields) => struct_derive(name, fields.named.clone()),
29            _ => panic!("Only structs with named fields are supported"),
30        },
31        syn::Data::Enum(e) => enum_derive(name, e.variants.clone()),
32        _ => panic!("Only structs and enums are supported"),
33    }
34}
35
36fn struct_derive(name: Ident, fields: Punctuated<Field, Comma>) -> TokenStream {
37    let field_attributes = fields
38        .iter()
39        .map(|field| (&field.ident, &field.ty))
40        .collect::<Vec<_>>();
41
42    let fields_fill = field_attributes.iter().map(|&(field, ty)| {
43        quote! {
44            <#ty>::field(stringify!(#field)),
45        }
46    });
47
48    let union_data_fill = field_attributes.iter().map(|&(field, _)| {
49        quote! {
50            #field: extract_union_data(stringify!(#field), &map, &children)?,
51        }
52    });
53
54    let arrow_data_fill = field_attributes.iter().map(|&(field, _)| {
55        quote! {
56            self.#field.try_into_arrow()?,
57        }
58    });
59
60    let expanded = quote! {
61        impl ArrowMessage for #name {
62            fn field(name: impl Into<String>) -> iridis_message::prelude::thirdparty::arrow_schema::Field {
63                make_union_fields(
64                    name,
65                    vec![
66                        #(#fields_fill)*
67                    ],
68                )
69            }
70
71            fn try_from_arrow(data: iridis_message::prelude::thirdparty::arrow_data::ArrayData) -> iridis_message::prelude::thirdparty::eyre::Result<Self>
72            where
73                Self: Sized,
74            {
75                let (map, children) = unpack_union(data);
76
77                Ok(Self {
78                    #(#union_data_fill)*
79                })
80            }
81
82            fn try_into_arrow(self) -> iridis_message::prelude::thirdparty::eyre::Result<iridis_message::prelude::thirdparty::arrow_array::ArrayRef> {
83                let union_fields = get_union_fields::<Self>()?;
84
85                make_union_array(
86                    union_fields,
87                    vec![
88                        #(#arrow_data_fill)*
89                    ],
90                )
91            }
92        }
93
94        impl TryFrom<iridis_message::prelude::thirdparty::arrow_data::ArrayData> for #name {
95            type Error = iridis_message::prelude::thirdparty::eyre::Report;
96
97            fn try_from(data: iridis_message::prelude::thirdparty::arrow_data::ArrayData) -> iridis_message::prelude::thirdparty::eyre::Result<Self> {
98                #name::try_from_arrow(data)
99            }
100        }
101
102        impl TryFrom<#name> for iridis_message::prelude::thirdparty::arrow_data::ArrayData {
103            type Error = iridis_message::prelude::thirdparty::eyre::Report;
104
105            fn try_from(item: #name) -> iridis_message::prelude::thirdparty::eyre::Result<Self> {
106                use iridis_message::prelude::thirdparty::arrow_array::Array;
107
108                item.try_into_arrow().map(|array| array.into_data())
109            }
110        }
111    };
112
113    TokenStream::from(expanded)
114}
115
116fn enum_derive(name: Ident, variants: Punctuated<Variant, Token![,]>) -> TokenStream {
117    let variants: Vec<_> = variants
118        .iter()
119        .map(|variant| {
120            let variant_name = &variant.ident;
121            let variant_str = variant_name.to_string().to_lowercase(); // Exemple : `Foo` -> "foo"
122            (variant_name, variant_str)
123        })
124        .collect();
125
126    let into_string_arms = variants.iter().map(|(variant_name, variant_str)| {
127        quote! {
128            #name::#variant_name => #variant_str.to_string(),
129        }
130    });
131
132    let try_from_string_arms = variants.iter().map(|(variant_name, variant_str)| {
133        quote! {
134            #variant_str => Ok(#name::#variant_name),
135        }
136    });
137
138    let expanded = quote! {
139        impl #name {
140            pub fn into_string(self) -> String {
141                match self {
142                    #(#into_string_arms)*
143                }
144            }
145
146            pub fn try_from_string(s: String) -> iridis_message::prelude::thirdparty::eyre::Result<Self> {
147                match s.as_str() {
148                    #(#try_from_string_arms)*
149                    _ => Err(iridis_message::prelude::thirdparty::eyre::eyre!("Invalid value for {}: {}", stringify!(#name), s)),
150                }
151            }
152        }
153
154        impl ArrowMessage for #name {
155            fn field(name: impl Into<String>) -> iridis_message::prelude::thirdparty::arrow_schema::Field {
156                String::field(name)
157            }
158
159            fn try_from_arrow(data: iridis_message::prelude::thirdparty::arrow_data::ArrayData) -> iridis_message::prelude::thirdparty::eyre::Result<Self>
160            where
161                Self: Sized,
162            {
163                Encoding::try_from_string(String::try_from_arrow(data)?)
164            }
165
166            fn try_into_arrow(self) -> iridis_message::prelude::thirdparty::eyre::Result<iridis_message::prelude::thirdparty::arrow_array::ArrayRef> {
167                String::try_into_arrow(self.into_string())
168            }
169        }
170
171
172        impl TryFrom<iridis_message::prelude::thirdparty::arrow_data::ArrayData> for #name {
173            type Error = iridis_message::prelude::thirdparty::eyre::Report;
174
175            fn try_from(data: iridis_message::prelude::thirdparty::arrow_data::ArrayData) -> iridis_message::prelude::thirdparty::eyre::Result<Self> {
176                #name::try_from_arrow(data)
177            }
178        }
179
180        impl TryFrom<#name> for iridis_message::prelude::thirdparty::arrow_data::ArrayData {
181            type Error = iridis_message::prelude::thirdparty::eyre::Report;
182
183            fn try_from(item: #name) -> iridis_message::prelude::thirdparty::eyre::Result<Self> {
184                item.try_into_arrow().map(|array| array.into_data())
185            }
186        }
187    };
188
189    TokenStream::from(expanded)
190}