ella_derive/
lib.rs

1use darling::{ast::Style, FromDeriveInput, FromField};
2use proc_macro2::{Span, TokenStream, TokenTree};
3use proc_macro_crate::FoundCrate;
4use quote::{format_ident, quote, ToTokens};
5use syn::{parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics};
6
7#[proc_macro_derive(RowFormat, attributes(row))]
8pub fn derive_row_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9    let mut input = parse_macro_input!(input as DeriveInput);
10    if let Err(err) = fix_field_attrs(&mut input) {
11        return err.into_compile_error().into();
12    }
13
14    let parsed = match RowFormat::from_derive_input(&input) {
15        Ok(parsed) => parsed,
16        Err(err) => return err.write_errors().into(),
17    };
18    let builder = match RowFormatBuilder::new(parsed) {
19        Ok(builder) => builder,
20        Err(err) => return err.into_compile_error().into(),
21    };
22    builder.implement().into()
23}
24
25#[derive(Debug, Clone, FromDeriveInput)]
26#[darling(attributes(row), supports(struct_any), forward_attrs)]
27struct RowFormat {
28    ident: syn::Ident,
29    vis: syn::Visibility,
30    generics: Generics,
31    data: darling::ast::Data<(), Column>,
32    #[darling(default)]
33    builder: Option<syn::Ident>,
34    #[darling(default)]
35    view: Option<syn::Ident>,
36}
37
38#[derive(Debug, Clone, FromField)]
39#[darling(attributes(row), forward_attrs)]
40struct Column {
41    ident: Option<syn::Ident>,
42    ty: syn::Type,
43    #[darling(default)]
44    name: Option<String>,
45    #[darling(default, rename = "r#type")]
46    as_type: Option<syn::Type>,
47}
48
49struct ColumnBuilder {
50    ident: TokenStream,
51    ty: syn::Type,
52    name: String,
53}
54
55impl ColumnBuilder {
56    fn new(col: Column, num: TokenStream) -> Result<Self, syn::Error> {
57        let ident = col
58            .ident
59            .as_ref()
60            .map_or_else(|| num.clone(), |ident| ident.to_token_stream());
61        let name = match (col.name, &col.ident) {
62            (Some(name), _) => name,
63            (None, Some(ident)) => ident.to_string(),
64            _ => {
65                return Err(syn::Error::new_spanned(
66                    col.ty.clone(),
67                    "missing field name".to_string(),
68                ))
69            }
70        };
71        let ty = col.as_type.unwrap_or(col.ty);
72        Ok(Self { ident, ty, name })
73    }
74}
75
76struct RowFormatBuilder {
77    ident: syn::Ident,
78    vis: syn::Visibility,
79    generics: Generics,
80    style: Style,
81    crt: TokenStream,
82    fields: Vec<ColumnBuilder>,
83    view_name: syn::Ident,
84    builder_name: syn::Ident,
85}
86
87impl RowFormatBuilder {
88    fn new(input: RowFormat) -> Result<Self, syn::Error> {
89        let crt = ella_crate();
90        let generics = Self::with_bounds(input.generics, &crt);
91        let fields = input.data.take_struct().unwrap();
92        let (style, fields) = fields.split();
93        let fields = fields
94            .iter()
95            .enumerate()
96            .map(|(i, f)| ColumnBuilder::new(f.clone(), syn::Index::from(i).to_token_stream()))
97            .collect::<Result<Vec<_>, _>>()?;
98
99        let view_name = input
100            .view
101            .unwrap_or_else(|| format_ident!("_{}View", input.ident));
102        let builder_name = input
103            .builder
104            .unwrap_or_else(|| format_ident!("_{}Builder", input.ident));
105
106        Ok(Self {
107            ident: input.ident,
108            vis: input.vis,
109            generics,
110            style,
111            fields,
112            crt,
113            view_name,
114            builder_name,
115        })
116    }
117
118    fn implement(self) -> TokenStream {
119        let crt = &self.crt;
120        let row = quote! { #crt::common::row };
121        let ident = &self.ident;
122        let generics = &self.generics;
123        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
124
125        let view_name = &self.view_name;
126        let builder_name = &self.builder_name;
127
128        let field_types = self.field_types();
129
130        let impl_builder = self.impl_builder();
131        let impl_view = self.impl_view();
132
133        let num_cols = quote! {
134            #(<#field_types as #row::RowFormat>::COLUMNS +)* 0
135        };
136
137        quote! {
138            #[automatically_derived]
139            impl #impl_generics #crt::common::row::RowFormat for #ident #ty_generics #where_clause {
140                const COLUMNS: usize = #num_cols;
141                type Builder = #builder_name #ty_generics;
142                type View = #view_name #ty_generics;
143
144                fn builder(fields: &[::std::sync::Arc<#crt::derive::Field>]) -> #crt::Result<Self::Builder> {
145                    #builder_name::<#ty_generics>::new(fields)
146                }
147
148                fn view(rows: usize, fields: &[::std::sync::Arc<#crt::derive::Field>], arrays: &[#crt::derive::ArrayRef]) -> #crt::Result<Self::View> {
149                    #view_name::<#ty_generics>::new(rows, fields, arrays)
150                }
151            }
152
153            #impl_builder
154            #impl_view
155        }
156    }
157
158    fn impl_builder(&self) -> TokenStream {
159        let crt = &self.crt;
160        let row = quote! { #crt::common::row };
161        let vis = &self.vis;
162        let ident = &self.ident;
163        let generics = &self.generics;
164        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
165
166        let builder_name = &self.builder_name;
167        let field_types = self.field_types();
168        let field_idents = self.field_idents();
169        let builder_fields = self
170            .field_names()
171            .into_iter()
172            .map(|name| syn::Ident::new(&name, Span::call_site()))
173            .collect::<Vec<_>>();
174
175        let len = syn::Ident::new("_ella_len", Span::call_site());
176        let doc = format!("[`{}::RowBatchBuilder`] for [`{}`]", row, ident);
177
178        quote! {
179            #[doc = #doc]
180            #[derive(Debug, Clone)]
181            #vis struct #builder_name #generics {
182                #len: usize,
183                #(#builder_fields: <#field_types as #row::RowFormat>::Builder, )*
184            }
185
186            #[automatically_derived]
187            impl #impl_generics #builder_name #ty_generics #where_clause {
188                fn new(mut fields: &[::std::sync::Arc<#crt::derive::Field>]) -> #crt::Result<#builder_name #ty_generics> {
189                    if fields.len() != <#ident #ty_generics as #row::RowFormat>::COLUMNS {
190                        return Err(#crt::Error::ColumnCount(<#ident #ty_generics as #row::RowFormat>::COLUMNS, fields.len()));
191                    }
192
193                    #(
194                        let cols = <#field_types as #row::RowFormat>::COLUMNS;
195                        let #builder_fields = <#field_types as #row::RowFormat>::builder(&fields[..cols])?;
196                        fields = &fields[cols..];
197                    )*
198
199                    Ok(#builder_name {
200                        #len: 0,
201                        #(#builder_fields, )*
202                    })
203                }
204            }
205
206            #[automatically_derived]
207            impl #impl_generics #row::RowBatchBuilder<#ident #ty_generics> for #builder_name #ty_generics #where_clause {
208                #[inline]
209                fn len(&self) -> usize {
210                    self.#len
211                }
212
213                fn push(&mut self, row: #ident #ty_generics) {
214                    #(
215                        <<#field_types as #row::RowFormat>::Builder as #row::RowBatchBuilder<#field_types>>::push(&mut self.#builder_fields, row.#field_idents.into());
216                    )*
217                    self.#len += 1;
218                }
219
220                fn build_columns(&mut self) -> #crt::Result<::std::vec::Vec<#crt::derive::ArrayRef>> {
221                    // TODO: internal state should be consistent on error
222                    let mut cols = ::std::vec::Vec::with_capacity(<#ident #ty_generics as #row::RowFormat>::COLUMNS);
223                    #(
224                        cols.extend(<<#field_types as #row::RowFormat>::Builder as #row::RowBatchBuilder<#field_types>>::build_columns(&mut self.#builder_fields)?);
225                    )*
226                    self.#len = 0;
227                    Ok(cols)
228                }
229            }
230        }
231    }
232
233    fn impl_view(&self) -> TokenStream {
234        let crt = &self.crt;
235        let row = quote! { #crt::common::row };
236        let vis = &self.vis;
237        let ident = &self.ident;
238        let generics = &self.generics;
239        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
240
241        let view_name = &self.view_name;
242        let field_types = self.field_types();
243        let field_idents = self.field_idents();
244        let view_fields = self
245            .field_names()
246            .into_iter()
247            .map(|name| syn::Ident::new(&name, Span::call_site()))
248            .collect::<Vec<_>>();
249
250        let len = syn::Ident::new("_ella_len", Span::call_site());
251        let doc = format!("[`{}::RowFormatView`] for [`{}`]", row, ident);
252
253        let impl_accessors = if self.style.is_tuple() {
254            quote! {
255                fn row(&self, i: usize) -> #ident #ty_generics {
256                    #ident(
257                        #(<<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row(&self.#view_fields, i).into(), )*
258                    )
259                }
260
261                unsafe fn row_unchecked(&self, i: usize) -> #ident #ty_generics {
262                    #ident(
263                        #(<<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row_unchecked(&self.#view_fields, i).into(), )*
264                    )
265                }
266            }
267        } else {
268            quote! {
269                fn row(&self, i: usize) -> #ident #ty_generics {
270                    #ident {
271                        #(#field_idents: <<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row(&self.#view_fields, i).into(), )*
272                    }
273                }
274
275                unsafe fn row_unchecked(&self, i: usize) -> #ident #ty_generics {
276                    #ident {
277                        #(#field_idents: <<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row_unchecked(&self.#view_fields, i).into(), )*
278                    }
279                }
280            }
281        };
282
283        quote! {
284            #[doc = #doc]
285            #[derive(Debug, Clone)]
286            #vis struct #view_name #generics {
287                #len: usize,
288                #(#view_fields: <#field_types as #row::RowFormat>::View, )*
289            }
290
291            #[automatically_derived]
292            impl #impl_generics #view_name #ty_generics #where_clause {
293                fn new(rows: usize, mut fields: &[::std::sync::Arc<#crt::derive::Field>], mut arrays: &[#crt::derive::ArrayRef]) -> #crt::Result<#view_name #ty_generics> {
294                    if arrays.len() != <#ident #ty_generics as #row::RowFormat>::COLUMNS {
295                        return Err(#crt::Error::ColumnCount(<#ident as #row::RowFormat>::COLUMNS, fields.len()));
296                    }
297
298                    #(
299                        let cols = <#field_types as #row::RowFormat>::COLUMNS;
300                        let #view_fields = <#field_types as #row::RowFormat>::view(rows, &fields[..cols], &arrays[..cols])?;
301                        debug_assert_eq!(<<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::len(&#view_fields), rows);
302                        fields = &fields[cols..];
303                        arrays = &arrays[cols..];
304                    )*
305
306                    Ok(#view_name {
307                        #len: rows,
308                        #(#view_fields, )*
309                    })
310                }
311            }
312
313            #[automatically_derived]
314            impl #impl_generics #row::RowFormatView<#ident #ty_generics> for #view_name #ty_generics #where_clause {
315                #[inline]
316                fn len(&self) -> usize {
317                    self.#len
318                }
319
320                #impl_accessors
321            }
322
323            #[automatically_derived]
324            impl #impl_generics ::core::iter::IntoIterator for #view_name #ty_generics #where_clause {
325                type Item = #ident #ty_generics;
326                type IntoIter = #row::RowViewIter<#ident #ty_generics, #view_name #ty_generics>;
327
328                fn into_iter(self) -> Self::IntoIter {
329                    #row::RowViewIter::new(self)
330                }
331            }
332        }
333    }
334
335    fn field_idents(&self) -> Vec<TokenStream> {
336        self.fields.iter().map(|c| c.ident.clone()).collect()
337    }
338
339    fn field_types(&self) -> Vec<syn::Type> {
340        self.fields.iter().map(|c| c.ty.clone()).collect()
341    }
342
343    fn field_names(&self) -> Vec<String> {
344        self.fields.iter().map(|c| c.name.clone()).collect()
345    }
346
347    fn with_bounds(mut generics: Generics, crt: &TokenStream) -> Generics {
348        for param in &mut generics.params {
349            if let GenericParam::Type(ref mut param) = *param {
350                param
351                    .bounds
352                    .push(parse_quote!(#crt::common::row::RowFormat));
353                param.bounds.push(parse_quote!(::core::fmt::Debug));
354                param.bounds.push(parse_quote!(::core::clone::Clone));
355            }
356        }
357        generics
358    }
359}
360
361fn ella_crate() -> TokenStream {
362    let crt = proc_macro_crate::crate_name("ella").expect("ella crate not found in manifest");
363    match crt {
364        FoundCrate::Itself => quote! { ::ella },
365        FoundCrate::Name(name) => {
366            let ident = format_ident!("{name}");
367            quote! { ::#ident }
368        }
369    }
370}
371
372/// Search field attributes for `type` key and replace with raw `r#type`.
373///
374/// See [TedDriggs/darling#238](https://github.com/TedDriggs/darling/issues/238)
375fn fix_field_attrs(input: &mut DeriveInput) -> Result<(), syn::Error> {
376    match &mut input.data {
377        syn::Data::Struct(data) => {
378            for f in &mut data.fields {
379                for attr in &mut f.attrs {
380                    if attr.path().is_ident("row") {
381                        if let syn::Meta::List(list) = &mut attr.meta {
382                            list.tokens = std::mem::take(&mut list.tokens)
383                                .into_iter()
384                                .map(|token| match token {
385                                    TokenTree::Ident(ident) if ident == "type" => TokenTree::Ident(
386                                        proc_macro2::Ident::new_raw("type", ident.span()),
387                                    ),
388                                    _ => token,
389                                })
390                                .collect();
391                        }
392                    }
393                }
394            }
395        }
396        syn::Data::Enum(data) => {
397            return Err(syn::Error::new(
398                data.enum_token.span,
399                "RowFormat macro does not support enums".to_string(),
400            ))
401        }
402        syn::Data::Union(data) => {
403            return Err(syn::Error::new(
404                data.union_token.span,
405                "RowFormat macro does not support unions".to_string(),
406            ))
407        }
408    }
409    Ok(())
410}