1use darling::{ast::Style, FromDeriveInput, FromField};
2use proc_macro2::{Span, TokenStream, TokenTree};
3use proc_macro_crate::FoundCrate;
4use quote::{format_ident, quote, ToTokens};
5use syn::{parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics};
6
7#[proc_macro_derive(RowFormat, attributes(row))]
8pub fn derive_row_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9 let mut input = parse_macro_input!(input as DeriveInput);
10 if let Err(err) = fix_field_attrs(&mut input) {
11 return err.into_compile_error().into();
12 }
13
14 let parsed = match RowFormat::from_derive_input(&input) {
15 Ok(parsed) => parsed,
16 Err(err) => return err.write_errors().into(),
17 };
18 let builder = match RowFormatBuilder::new(parsed) {
19 Ok(builder) => builder,
20 Err(err) => return err.into_compile_error().into(),
21 };
22 builder.implement().into()
23}
24
/// darling-parsed form of a `#[derive(RowFormat)]` input type.
///
/// Only structs are accepted (`supports(struct_any)`); struct-level options are read from
/// `#[row(...)]` attributes.
#[derive(Debug, Clone, FromDeriveInput)]
#[darling(attributes(row), supports(struct_any), forward_attrs)]
struct RowFormat {
    /// Name of the annotated type.
    ident: syn::Ident,
    /// Visibility of the annotated type; the generated builder and view types reuse it.
    vis: syn::Visibility,
    /// Generic parameters (and where clause) of the annotated type.
    generics: Generics,
    /// The struct's fields; the `()` enum-variant slot is unused because only structs are supported.
    data: darling::ast::Data<(), Column>,
    /// Optional `#[row(builder = Name)]` override for the generated builder type's name.
    #[darling(default)]
    builder: Option<syn::Ident>,
    /// Optional `#[row(view = Name)]` override for the generated view type's name.
    #[darling(default)]
    view: Option<syn::Ident>,
}
37
/// darling-parsed form of a single field of a `#[derive(RowFormat)]` struct.
#[derive(Debug, Clone, FromField)]
#[darling(attributes(row), forward_attrs)]
struct Column {
    /// Field name; `None` for tuple-struct fields.
    ident: Option<syn::Ident>,
    /// Declared type of the field.
    ty: syn::Type,
    /// Optional `#[row(name = "...")]` column-name override.
    #[darling(default)]
    name: Option<String>,
    /// Optional `#[row(type = ...)]` override of the type whose `RowFormat` implementation is
    /// used for this column. `type` is a keyword, so `fix_field_attrs` rewrites it to the raw
    /// identifier `r#type` before darling parses the attribute; hence the rename.
    #[darling(default, rename = "r#type")]
    as_type: Option<syn::Type>,
}
48
49struct ColumnBuilder {
50 ident: TokenStream,
51 ty: syn::Type,
52 name: String,
53}
54
55impl ColumnBuilder {
56 fn new(col: Column, num: TokenStream) -> Result<Self, syn::Error> {
57 let ident = col
58 .ident
59 .as_ref()
60 .map_or_else(|| num.clone(), |ident| ident.to_token_stream());
61 let name = match (col.name, &col.ident) {
62 (Some(name), _) => name,
63 (None, Some(ident)) => ident.to_string(),
64 _ => {
65 return Err(syn::Error::new_spanned(
66 col.ty.clone(),
67 "missing field name".to_string(),
68 ))
69 }
70 };
71 let ty = col.as_type.unwrap_or(col.ty);
72 Ok(Self { ident, ty, name })
73 }
74}
75
76struct RowFormatBuilder {
77 ident: syn::Ident,
78 vis: syn::Visibility,
79 generics: Generics,
80 style: Style,
81 crt: TokenStream,
82 fields: Vec<ColumnBuilder>,
83 view_name: syn::Ident,
84 builder_name: syn::Ident,
85}
86
87impl RowFormatBuilder {
88 fn new(input: RowFormat) -> Result<Self, syn::Error> {
89 let crt = ella_crate();
90 let generics = Self::with_bounds(input.generics, &crt);
91 let fields = input.data.take_struct().unwrap();
92 let (style, fields) = fields.split();
93 let fields = fields
94 .iter()
95 .enumerate()
96 .map(|(i, f)| ColumnBuilder::new(f.clone(), syn::Index::from(i).to_token_stream()))
97 .collect::<Result<Vec<_>, _>>()?;
98
99 let view_name = input
100 .view
101 .unwrap_or_else(|| format_ident!("_{}View", input.ident));
102 let builder_name = input
103 .builder
104 .unwrap_or_else(|| format_ident!("_{}Builder", input.ident));
105
106 Ok(Self {
107 ident: input.ident,
108 vis: input.vis,
109 generics,
110 style,
111 fields,
112 crt,
113 view_name,
114 builder_name,
115 })
116 }
117
118 fn implement(self) -> TokenStream {
119 let crt = &self.crt;
120 let row = quote! { #crt::common::row };
121 let ident = &self.ident;
122 let generics = &self.generics;
123 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
124
125 let view_name = &self.view_name;
126 let builder_name = &self.builder_name;
127
128 let field_types = self.field_types();
129
130 let impl_builder = self.impl_builder();
131 let impl_view = self.impl_view();
132
133 let num_cols = quote! {
134 #(<#field_types as #row::RowFormat>::COLUMNS +)* 0
135 };
136
137 quote! {
138 #[automatically_derived]
139 impl #impl_generics #crt::common::row::RowFormat for #ident #ty_generics #where_clause {
140 const COLUMNS: usize = #num_cols;
141 type Builder = #builder_name #ty_generics;
142 type View = #view_name #ty_generics;
143
144 fn builder(fields: &[::std::sync::Arc<#crt::derive::Field>]) -> #crt::Result<Self::Builder> {
145 #builder_name::<#ty_generics>::new(fields)
146 }
147
148 fn view(rows: usize, fields: &[::std::sync::Arc<#crt::derive::Field>], arrays: &[#crt::derive::ArrayRef]) -> #crt::Result<Self::View> {
149 #view_name::<#ty_generics>::new(rows, fields, arrays)
150 }
151 }
152
153 #impl_builder
154 #impl_view
155 }
156 }
157
158 fn impl_builder(&self) -> TokenStream {
159 let crt = &self.crt;
160 let row = quote! { #crt::common::row };
161 let vis = &self.vis;
162 let ident = &self.ident;
163 let generics = &self.generics;
164 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
165
166 let builder_name = &self.builder_name;
167 let field_types = self.field_types();
168 let field_idents = self.field_idents();
169 let builder_fields = self
170 .field_names()
171 .into_iter()
172 .map(|name| syn::Ident::new(&name, Span::call_site()))
173 .collect::<Vec<_>>();
174
175 let len = syn::Ident::new("_ella_len", Span::call_site());
176 let doc = format!("[`{}::RowBatchBuilder`] for [`{}`]", row, ident);
177
178 quote! {
179 #[doc = #doc]
180 #[derive(Debug, Clone)]
181 #vis struct #builder_name #generics {
182 #len: usize,
183 #(#builder_fields: <#field_types as #row::RowFormat>::Builder, )*
184 }
185
186 #[automatically_derived]
187 impl #impl_generics #builder_name #ty_generics #where_clause {
188 fn new(mut fields: &[::std::sync::Arc<#crt::derive::Field>]) -> #crt::Result<#builder_name #ty_generics> {
189 if fields.len() != <#ident #ty_generics as #row::RowFormat>::COLUMNS {
190 return Err(#crt::Error::ColumnCount(<#ident #ty_generics as #row::RowFormat>::COLUMNS, fields.len()));
191 }
192
193 #(
194 let cols = <#field_types as #row::RowFormat>::COLUMNS;
195 let #builder_fields = <#field_types as #row::RowFormat>::builder(&fields[..cols])?;
196 fields = &fields[cols..];
197 )*
198
199 Ok(#builder_name {
200 #len: 0,
201 #(#builder_fields, )*
202 })
203 }
204 }
205
206 #[automatically_derived]
207 impl #impl_generics #row::RowBatchBuilder<#ident #ty_generics> for #builder_name #ty_generics #where_clause {
208 #[inline]
209 fn len(&self) -> usize {
210 self.#len
211 }
212
213 fn push(&mut self, row: #ident #ty_generics) {
214 #(
215 <<#field_types as #row::RowFormat>::Builder as #row::RowBatchBuilder<#field_types>>::push(&mut self.#builder_fields, row.#field_idents.into());
216 )*
217 self.#len += 1;
218 }
219
220 fn build_columns(&mut self) -> #crt::Result<::std::vec::Vec<#crt::derive::ArrayRef>> {
221 let mut cols = ::std::vec::Vec::with_capacity(<#ident #ty_generics as #row::RowFormat>::COLUMNS);
223 #(
224 cols.extend(<<#field_types as #row::RowFormat>::Builder as #row::RowBatchBuilder<#field_types>>::build_columns(&mut self.#builder_fields)?);
225 )*
226 self.#len = 0;
227 Ok(cols)
228 }
229 }
230 }
231 }
232
233 fn impl_view(&self) -> TokenStream {
234 let crt = &self.crt;
235 let row = quote! { #crt::common::row };
236 let vis = &self.vis;
237 let ident = &self.ident;
238 let generics = &self.generics;
239 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
240
241 let view_name = &self.view_name;
242 let field_types = self.field_types();
243 let field_idents = self.field_idents();
244 let view_fields = self
245 .field_names()
246 .into_iter()
247 .map(|name| syn::Ident::new(&name, Span::call_site()))
248 .collect::<Vec<_>>();
249
250 let len = syn::Ident::new("_ella_len", Span::call_site());
251 let doc = format!("[`{}::RowFormatView`] for [`{}`]", row, ident);
252
253 let impl_accessors = if self.style.is_tuple() {
254 quote! {
255 fn row(&self, i: usize) -> #ident #ty_generics {
256 #ident(
257 #(<<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row(&self.#view_fields, i).into(), )*
258 )
259 }
260
261 unsafe fn row_unchecked(&self, i: usize) -> #ident #ty_generics {
262 #ident(
263 #(<<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row_unchecked(&self.#view_fields, i).into(), )*
264 )
265 }
266 }
267 } else {
268 quote! {
269 fn row(&self, i: usize) -> #ident #ty_generics {
270 #ident {
271 #(#field_idents: <<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row(&self.#view_fields, i).into(), )*
272 }
273 }
274
275 unsafe fn row_unchecked(&self, i: usize) -> #ident #ty_generics {
276 #ident {
277 #(#field_idents: <<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::row_unchecked(&self.#view_fields, i).into(), )*
278 }
279 }
280 }
281 };
282
283 quote! {
284 #[doc = #doc]
285 #[derive(Debug, Clone)]
286 #vis struct #view_name #generics {
287 #len: usize,
288 #(#view_fields: <#field_types as #row::RowFormat>::View, )*
289 }
290
291 #[automatically_derived]
292 impl #impl_generics #view_name #ty_generics #where_clause {
293 fn new(rows: usize, mut fields: &[::std::sync::Arc<#crt::derive::Field>], mut arrays: &[#crt::derive::ArrayRef]) -> #crt::Result<#view_name #ty_generics> {
294 if arrays.len() != <#ident #ty_generics as #row::RowFormat>::COLUMNS {
295 return Err(#crt::Error::ColumnCount(<#ident as #row::RowFormat>::COLUMNS, fields.len()));
296 }
297
298 #(
299 let cols = <#field_types as #row::RowFormat>::COLUMNS;
300 let #view_fields = <#field_types as #row::RowFormat>::view(rows, &fields[..cols], &arrays[..cols])?;
301 debug_assert_eq!(<<#field_types as #row::RowFormat>::View as #row::RowFormatView<#field_types>>::len(&#view_fields), rows);
302 fields = &fields[cols..];
303 arrays = &arrays[cols..];
304 )*
305
306 Ok(#view_name {
307 #len: rows,
308 #(#view_fields, )*
309 })
310 }
311 }
312
313 #[automatically_derived]
314 impl #impl_generics #row::RowFormatView<#ident #ty_generics> for #view_name #ty_generics #where_clause {
315 #[inline]
316 fn len(&self) -> usize {
317 self.#len
318 }
319
320 #impl_accessors
321 }
322
323 #[automatically_derived]
324 impl #impl_generics ::core::iter::IntoIterator for #view_name #ty_generics #where_clause {
325 type Item = #ident #ty_generics;
326 type IntoIter = #row::RowViewIter<#ident #ty_generics, #view_name #ty_generics>;
327
328 fn into_iter(self) -> Self::IntoIter {
329 #row::RowViewIter::new(self)
330 }
331 }
332 }
333 }
334
335 fn field_idents(&self) -> Vec<TokenStream> {
336 self.fields.iter().map(|c| c.ident.clone()).collect()
337 }
338
339 fn field_types(&self) -> Vec<syn::Type> {
340 self.fields.iter().map(|c| c.ty.clone()).collect()
341 }
342
343 fn field_names(&self) -> Vec<String> {
344 self.fields.iter().map(|c| c.name.clone()).collect()
345 }
346
347 fn with_bounds(mut generics: Generics, crt: &TokenStream) -> Generics {
348 for param in &mut generics.params {
349 if let GenericParam::Type(ref mut param) = *param {
350 param
351 .bounds
352 .push(parse_quote!(#crt::common::row::RowFormat));
353 param.bounds.push(parse_quote!(::core::fmt::Debug));
354 param.bounds.push(parse_quote!(::core::clone::Clone));
355 }
356 }
357 generics
358 }
359}
360
361fn ella_crate() -> TokenStream {
362 let crt = proc_macro_crate::crate_name("ella").expect("ella crate not found in manifest");
363 match crt {
364 FoundCrate::Itself => quote! { ::ella },
365 FoundCrate::Name(name) => {
366 let ident = format_ident!("{name}");
367 quote! { ::#ident }
368 }
369 }
370}
371
372fn fix_field_attrs(input: &mut DeriveInput) -> Result<(), syn::Error> {
376 match &mut input.data {
377 syn::Data::Struct(data) => {
378 for f in &mut data.fields {
379 for attr in &mut f.attrs {
380 if attr.path().is_ident("row") {
381 if let syn::Meta::List(list) = &mut attr.meta {
382 list.tokens = std::mem::take(&mut list.tokens)
383 .into_iter()
384 .map(|token| match token {
385 TokenTree::Ident(ident) if ident == "type" => TokenTree::Ident(
386 proc_macro2::Ident::new_raw("type", ident.span()),
387 ),
388 _ => token,
389 })
390 .collect();
391 }
392 }
393 }
394 }
395 }
396 syn::Data::Enum(data) => {
397 return Err(syn::Error::new(
398 data.enum_token.span,
399 "RowFormat macro does not support enums".to_string(),
400 ))
401 }
402 syn::Data::Union(data) => {
403 return Err(syn::Error::new(
404 data.union_token.span,
405 "RowFormat macro does not support unions".to_string(),
406 ))
407 }
408 }
409 Ok(())
410}