// Source: tantivy_derive_impl/lib.rs

1use darling::{FromDeriveInput, FromField, ast, util};
2use proc_macro2::TokenStream;
3use quote::{ToTokens, format_ident, quote};
4use syn::{DeriveInput, Ident, Type, Visibility, parse_macro_input};
5
6#[derive(Debug, FromField)]
7#[darling(attributes(tantivy))]
8struct Field {
9    vis: Visibility,
10    ident: Option<Ident>,
11    ty: Type,
12    #[darling(default)]
13    coerce: bool,
14    #[darling(default)]
15    fast: bool,
16    #[darling(default)]
17    fieldnorms: bool,
18    #[darling(default)]
19    indexed: bool,
20    #[darling(default)]
21    stored: bool,
22    #[darling(default)]
23    store_target: Option<String>,
24    #[darling(default)]
25    string: bool,
26    #[darling(default)]
27    text: bool,
28    #[darling(default)]
29    fast_tokenizer: Option<String>,
30    #[darling(default)]
31    tokenizer: Option<String>,
32    #[darling(default)]
33    index_option: Option<String>,
34    #[darling(default)]
35    precision: Option<String>,
36}
37
38impl Field {
39    fn parse(
40        &self,
41    ) -> (
42        TokenStream,
43        TokenStream,
44        TokenStream,
45        TokenStream,
46        TokenStream,
47    ) {
48        let Field {
49            ident,
50            ty,
51            coerce,
52            fast,
53            fieldnorms,
54            indexed,
55            stored,
56            string,
57            text,
58            fast_tokenizer,
59            tokenizer,
60            index_option,
61            precision,
62            ..
63        } = self;
64
65        let name = ident.as_ref().expect("must be a named struct").to_string();
66        let name = name.trim_start_matches('_');
67
68        let count_token = quote! {
69            count += <#ty>::count_fields();
70        };
71
72        let from_token = if *stored {
73            quote! {
74                let #ident = <#ty>::extract_from_document(&document, field_id)?;
75                field_id += <#ty>::count_fields();
76            }
77        } else {
78            quote! {
79                field_id += <#ty>::count_fields();
80            }
81        };
82
83        let field_token = if *stored {
84            quote! { #ident, }
85        } else {
86            TokenStream::new()
87        };
88
89        let coerce = if *coerce {
90            quote! { options.set_coerce(true); }
91        } else {
92            TokenStream::new()
93        };
94
95        let fast = if *fast {
96            quote! { options.set_fast(true); }
97        } else {
98            TokenStream::new()
99        };
100
101        let fieldnorms = if *fieldnorms {
102            quote! { options.set_fieldnorms(true); }
103        } else {
104            TokenStream::new()
105        };
106
107        let indexed = if *indexed {
108            quote! { options.set_indexed(true); }
109        } else {
110            TokenStream::new()
111        };
112
113        let stored = if *stored {
114            quote! { options.set_stored(true); }
115        } else {
116            TokenStream::new()
117        };
118
119        let string = if *string {
120            quote! { options.set_string(true); }
121        } else {
122            TokenStream::new()
123        };
124
125        let text = if *text {
126            quote! { options.set_text(true); }
127        } else {
128            TokenStream::new()
129        };
130
131        let index_option = match index_option.as_ref().map(|s| s.as_str()) {
132            Some("basic") => quote! { options.set_index_option(IndexRecordOption::Basic); },
133            Some("frequency") => quote! { options.set_index_option(IndexRecordOption::WithFreqs); },
134            Some("frequency-and-position") => {
135                quote! { options.set_index_option(IndexRecordOption::WithFreqsAndPositions); }
136            }
137            _ => TokenStream::new(),
138        };
139
140        let fast_tokenizer = if let Some(tokenizer) = fast_tokenizer {
141            quote! { options.set_fast_tokenizer(#tokenizer); }
142        } else {
143            TokenStream::new()
144        };
145
146        let tokenizer = if let Some(tokenizer) = tokenizer {
147            quote! { options.set_tokenizer(#tokenizer); }
148        } else {
149            TokenStream::new()
150        };
151
152        let precision = match precision.as_ref().map(|s| s.as_str()) {
153            Some("seconds") => quote! { options.set_precision(DateTimePrecision::Seconds); },
154            Some("milliseconds") => {
155                quote! { options.set_precision(DateTimePrecision::Milliseconds); }
156            }
157            Some("microseconds") => {
158                quote! { options.set_precision(DateTimePrecision::Microseconds); }
159            }
160            Some("nanoseconds") => {
161                quote! { options.set_precision(DateTimePrecision::Nanoseconds); }
162            }
163            _ => TokenStream::new(),
164        };
165
166        let schema_token = quote! {
167            let mut options: tantivy_derive::FieldOptions = Default::default();
168            #coerce
169            #fast
170            #fieldnorms
171            #indexed
172            #stored
173            #string
174            #text
175            #fast_tokenizer
176            #tokenizer
177            #index_option
178            #precision
179            <#ty>::add_field(builder, #name, options);
180        };
181
182        let into_token = quote! {
183            <#ty>::insert_into_document(document, field_id, &value.#ident);
184            field_id += <#ty>::count_fields();
185        };
186
187        (
188            schema_token,
189            count_token,
190            from_token,
191            field_token,
192            into_token,
193        )
194    }
195
196    fn parse_stored(&self) -> TokenStream {
197        let Field {
198            vis,
199            ident,
200            ty,
201            stored,
202            store_target,
203            ..
204        } = self;
205
206        if *stored {
207            if let Some(target) = store_target {
208                quote! {
209                    #vis #ident: #target,
210                }
211            } else {
212                quote! {
213                    #vis #ident: #ty,
214                }
215            }
216        } else {
217            TokenStream::new()
218        }
219    }
220}
221
222#[derive(Debug, FromDeriveInput)]
223#[darling(attributes(tantivy), supports(struct_named))]
224struct Document {
225    ident: Ident,
226    generics: syn::Generics,
227    data: ast::Data<util::Ignored, Field>,
228}
229
230impl ToTokens for Document {
231    fn to_tokens(&self, tokens: &mut TokenStream) {
232        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
233        let name = &self.ident;
234        let stored_name = format_ident!("Stored{name}");
235
236        let fields = self
237            .data
238            .as_ref()
239            .take_struct()
240            .expect("must be struct")
241            .fields;
242
243        let mut schema_tokens = Vec::with_capacity(fields.len());
244        let mut count_tokens = Vec::with_capacity(fields.len());
245        let mut from_tokens = Vec::with_capacity(fields.len());
246        let mut field_tokens = Vec::with_capacity(fields.len());
247        let mut into_tokens = Vec::with_capacity(fields.len());
248
249        for field in fields {
250            let (schema_token, count_token, from_token, field_token, into_token) = field.parse();
251
252            schema_tokens.push(schema_token);
253            count_tokens.push(count_token);
254            from_tokens.push(from_token);
255            field_tokens.push(field_token);
256            into_tokens.push(into_token);
257        }
258
259        tokens.extend(quote! {
260            impl #impl_generics tantivy_derive::Field for #name #ty_generics #where_clause {
261                type Target = #stored_name;
262
263                fn add_field(builder: &mut tantivy::schema::SchemaBuilder, name: &str, options: tantivy_derive::FieldOptions) {
264                    use tantivy::schema::*;
265                    use tantivy_derive::Field as _;
266
267                    #(
268                        #schema_tokens
269                    )*
270                }
271
272                fn count_fields() -> u32 {
273                    let mut count = 0;
274
275                    #(
276                        #count_tokens
277                    )*
278
279                    count
280                }
281
282                fn insert_into_document(
283                    document: &mut tantivy::schema::TantivyDocument,
284                    mut field_id: u32,
285                    value: &Self,
286                ) {
287                    #(
288                        #into_tokens
289                    )*
290                }
291            }
292
293            impl #impl_generics tantivy_derive::Extractable for #name #ty_generics #where_clause {
294                fn extract_from_document(
295                    document: &tantivy::schema::TantivyDocument,
296                    mut field_id: u32,
297                ) -> Option<Self::Target> {
298                    use tantivy_derive::{Extractable as _, Field as _};
299
300                    #(
301                        #from_tokens
302                    )*
303
304                    Some(Self::Target {
305                        #(
306                            #field_tokens
307                        )*
308                    })
309                }
310            }
311
312            impl #impl_generics std::convert::From<#name> for tantivy::schema::TantivyDocument #ty_generics #where_clause {
313                fn from(value: #name) -> tantivy::schema::TantivyDocument {
314                    use tantivy_derive::Field as _;
315
316                    let mut document = tantivy::schema::TantivyDocument::new();
317                    #name::insert_into_document(&mut document, 0, &value);
318                    document
319                }
320            }
321
322            impl #impl_generics std::convert::From<tantivy::schema::TantivyDocument> for #stored_name #ty_generics #where_clause {
323                fn from(document: tantivy::schema::TantivyDocument) -> Self {
324                    use tantivy_derive::{Extractable as _, Field as _};
325
326                    #name::extract_from_document(&document, 0).expect("missing field")
327                }
328            }
329
330            impl #impl_generics tantivy_derive::Schema for #name #ty_generics #where_clause {
331                fn schema() -> tantivy::schema::Schema {
332                    use tantivy::schema::*;
333                    use tantivy_derive::Field as _;
334
335                    let mut builder = Schema::builder();
336                    Self::add_field(&mut builder, "", Default::default());
337                    builder.build()
338                }
339            }
340        });
341    }
342}
343
344#[derive(Debug, FromDeriveInput)]
345#[darling(
346    attributes(tantivy),
347    supports(struct_named),
348    forward_attrs(allow, cfg, derive)
349)]
350struct StoredDocument {
351    ident: Ident,
352    vis: syn::Visibility,
353    data: ast::Data<util::Ignored, Field>,
354    attrs: Vec<syn::Attribute>,
355}
356
357impl ToTokens for StoredDocument {
358    fn to_tokens(&self, tokens: &mut TokenStream) {
359        let name = &self.ident;
360        let vis = &self.vis;
361
362        let fields = self
363            .data
364            .as_ref()
365            .take_struct()
366            .expect("must be struct")
367            .fields;
368
369        let mut field_tokens = Vec::with_capacity(fields.len());
370
371        for field in fields {
372            let token = field.parse_stored();
373
374            field_tokens.push(token);
375        }
376
377        let attrs: Vec<TokenStream> = self.attrs.iter().map(|attr| quote! { #attr }).collect();
378
379        tokens.extend(quote! {
380            #(
381                #attrs
382            )*
383            #vis struct #name {
384                #(
385                    #field_tokens
386                )*
387            }
388        });
389    }
390}
391
392#[proc_macro_derive(Document, attributes(tantivy))]
393pub fn derive_document(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
394    let input = parse_macro_input!(input as DeriveInput);
395    let receiver = Document::from_derive_input(&input).expect("cannot parse");
396    quote!(#receiver).into()
397}
398
399#[proc_macro_attribute]
400pub fn tantivy_document(
401    args: proc_macro::TokenStream,
402    input: proc_macro::TokenStream,
403) -> proc_macro::TokenStream {
404    use syn::{Meta, punctuated::Punctuated};
405    let args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
406    let mut input = parse_macro_input!(input as DeriveInput);
407    let mut struct_name = format_ident!("Stored{}", input.ident);
408
409    for arg in &args {
410        let Ok(arg) = arg.require_name_value() else {
411            continue;
412        };
413        let Some(ident) = arg.path.get_ident() else {
414            continue;
415        };
416        let name = ident.to_string();
417
418        if name != "name" {
419            continue;
420        }
421
422        let syn::Expr::Lit(ref value) = arg.value else {
423            continue;
424        };
425        let syn::Lit::Str(ref value) = value.lit else {
426            continue;
427        };
428
429        struct_name = format_ident!("{}", value.value());
430    }
431
432    let original = quote! { #input };
433
434    input.ident = struct_name;
435    let receiver = StoredDocument::from_derive_input(&input).expect("cannot parse");
436
437    quote! {
438        #[derive(tantivy_derive::Document)]
439        #original
440        #receiver
441    }
442    .into()
443}
444
#[cfg(test)]
mod tests {
    use crate::StoredDocument;
    use darling::FromDeriveInput as _;

    /// The stored struct keeps the input's name, visibility and forwarded `derive`
    /// attribute, and contains only the fields marked `stored`.
    #[test]
    fn stored_document_keeps_only_stored_fields() {
        let input: syn::DeriveInput = syn::parse_str(
            r#"#[derive(Debug)]
            pub struct Document {
                #[tantivy(stored, text)]
                pub title: String,
                #[tantivy(text)]
                pub body: String,
            }"#,
        )
        .unwrap();
        let receiver = StoredDocument::from_derive_input(&input).unwrap();

        // Re-parse the output so the assertions don't depend on token spacing.
        let output: syn::DeriveInput = syn::parse2(quote::quote!(#receiver)).unwrap();

        assert_eq!(output.ident, "Document");
        assert!(matches!(output.vis, syn::Visibility::Public(_)));
        assert_eq!(output.attrs.len(), 1);
        assert!(output.attrs[0].path().is_ident("derive"));

        let syn::Data::Struct(data) = output.data else {
            panic!("expected a struct");
        };
        let names: Vec<String> = data
            .fields
            .iter()
            .map(|field| field.ident.as_ref().unwrap().to_string())
            .collect();
        assert_eq!(names, ["title"]);
    }
}
465}