goldleaf_derive/
lib.rs

1#![deny(clippy::unwrap_used)]
2
3use proc_macro::TokenStream;
4
5use darling::util::Flag;
6use darling::{ast, FromDeriveInput, FromField, FromMeta};
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput};
9
10#[derive(FromMeta, Copy, Clone)]
11enum TwoD {
12    Spherical,
13    Cartesian,
14}
15
16#[derive(FromMeta, Default)]
17struct FieldIdentityMetaData {
18    /// Subfield for multikey
19    sub: Option<String>,
20    /// The index number
21    index: Option<i32>,
22    /// Links to another index, creating a compound
23    /// Must be format <name>
24    link: Option<String>,
25    /// The link order
26    order: Option<u8>,
27    unique: Flag,
28    /// Automatically sets to be text index
29    text_weight: Option<u8>,
30    two_d: Option<TwoD>,
31    /// Flags a field as containing the locale information
32    icase_locale: Option<String>,
33    icase_strength: Option<u8>,
34    name: Option<String>,
35    two_d_bits: Option<u32>,
36    two_d_max: Option<f64>,
37    two_d_min: Option<f64>,
38    lang_field: Flag,
39    pfe: Option<String>,
40}
41
/// Fallback for `two_d_bits` when a spherical index does not specify it.
const DEFAULT_2D_BITS: u32 = 26;
/// Fallback for `two_d_min` when a spherical index does not specify it.
const DEFAULT_2D_MIN: f64 = -180.0;
/// Fallback for `two_d_max` when a spherical index does not specify it.
const DEFAULT_2D_MAX: f64 = 180.0;
45
46#[derive(FromField)]
47#[darling(attributes(db))]
48struct FieldIdentityData {
49    ident: Option<syn::Ident>,
50    id_field: Flag,
51    native_id_field: Flag,
52    indexing: Option<FieldIdentityMetaData>,
53}
54
/// Geospatial index flavour bundled with the options that accompany it.
///
/// Defaults to [`TwoDPacked::Cartesian`].
#[derive(Clone, Copy, Default)]
enum TwoDPacked {
    /// `"2dsphere"` index together with its `bits` / `max` / `min` options.
    Spherical {
        /// Value passed to the index options' `bits`.
        bits: u32,
        /// Value passed to the index options' `max`.
        max: f64,
        /// Value passed to the index options' `min`.
        min: f64,
    },
    /// `"2d"` index; carries no extra options.
    #[default]
    Cartesian,
}
65
66struct CombinedFieldIdentityData {
67    ident: Option<syn::Ident>,
68    /// Subfield for multikey
69    sub: Option<String>,
70    /// The index number
71    index: Option<i32>,
72    /// Links to another index, creating a compound
73    /// Must be format <name>
74    link: Option<String>,
75    /// The link order
76    order: Option<u8>,
77    unique: Flag,
78    /// Automatically sets to be text index
79    text_weight: Option<u8>,
80    two_d: Option<TwoDPacked>,
81    icase_locale: Option<String>,
82    icase_strength: Option<u8>,
83    name: Option<String>,
84    lang_field: Flag,
85    /// Partial filter expression
86    pfe: Option<String>,
87}
88
89impl From<FieldIdentityData> for CombinedFieldIdentityData {
90    fn from(value: FieldIdentityData) -> Self {
91        let meta = value.indexing.expect("Indexing metadata");
92        CombinedFieldIdentityData {
93            ident: value.ident,
94            sub: meta.sub,
95            index: meta.index,
96            link: meta.link,
97            order: meta.order,
98            unique: meta.unique,
99            text_weight: meta.text_weight,
100            two_d: match meta.two_d {
101                None => None,
102                Some(two_d) => Some(match two_d {
103                    TwoD::Spherical => TwoDPacked::Spherical {
104                        bits: meta.two_d_bits.unwrap_or(DEFAULT_2D_BITS),
105                        max: meta.two_d_max.unwrap_or(DEFAULT_2D_MAX),
106                        min: meta.two_d_min.unwrap_or(DEFAULT_2D_MIN),
107                    },
108                    TwoD::Cartesian => TwoDPacked::Cartesian,
109                }),
110            },
111            icase_locale: meta.icase_locale,
112            icase_strength: meta.icase_strength,
113            name: meta.name,
114            lang_field: meta.lang_field,
115            pfe: meta.pfe,
116        }
117    }
118}
119
120enum IndexType {
121    Numeric(i32),
122    Text(u32),
123    TwoD(TwoDPacked),
124}
125
126struct IndexPair {
127    ident: String,
128    index: IndexType,
129    order_index: u8,
130    is_lang_field: bool,
131}
132
/// Collation settings attached to an index.
#[derive(Default)]
struct CaseInsensitivity {
    /// Locale handed to the collation builder.
    locale: String,
    /// Numeric strength; the generated code maps 1..=5 onto collation strength
    /// variants and panics for anything else.
    strength: u8,
}
138
139struct CollatedFieldIdentityData {
140    pairs: Vec<IndexPair>,
141    unique: bool,
142    name: Option<String>,
143    link: Option<String>,
144    case_insensitivity: Option<CaseInsensitivity>,
145    two_d: Option<TwoDPacked>,
146    /// Partial filter expression
147    pfe: Option<String>,
148}
149
150#[derive(FromDeriveInput)]
151#[darling(attributes(db), supports(struct_named), forward_attrs(allow, doc, cfg))]
152struct CollectionIdentityData {
153    ident: syn::Ident,
154    name: String,
155    /// When the document will expire, in seconds
156    expiration_secs: Option<u64>,
157    data: ast::Data<(), FieldIdentityData>,
158}
159
160#[proc_macro_derive(CollectionIdentity, attributes(db))]
161pub fn collection_identity(input: TokenStream) -> TokenStream {
162    let input = parse_macro_input!(input as DeriveInput);
163    let collection = match CollectionIdentityData::from_derive_input(&input) {
164        Ok(parsed) => parsed,
165        Err(e) => return e.write_errors().into(),
166    };
167
168    let collection_name = collection.name;
169    let struct_name = collection.ident;
170
171    // Generate indexing if necessary
172    let mut fields = collection
173        .data
174        .take_struct()
175        .expect("Must be struct")
176        .fields;
177
178    let mut id_field = None;
179    let mut native_id = false;
180
181    for field in &mut fields {
182        if field.id_field.is_present() || field.native_id_field.is_present() {
183            if id_field.is_some()
184                || (field.id_field.is_present() && field.native_id_field.is_present())
185            {
186                panic!("Multiple ID fields not allowed!");
187            }
188
189            id_field = Some(
190                field
191                    .ident
192                    .as_ref()
193                    .expect("ID field identifier")
194                    .to_string(),
195            );
196            native_id = field.native_id_field.is_present();
197
198            if !native_id {
199                field.indexing.get_or_insert_default().unique = Flag::present();
200            }
201        }
202    }
203
204    let id_field = id_field.expect("ID field must be present!");
205    let id_field_tok: syn::Ident = syn::parse_str(&id_field).expect("Valid parse of ID field");
206    let (id_field, id_field_value) = if native_id {
207        (
208            format!("_{id_field}"),
209            quote!(self.#id_field_tok.as_ref().unwrap()),
210        )
211    } else {
212        (id_field, quote!(&self.#id_field_tok))
213    };
214
215    let sync_impl = if cfg!(feature = "sync") {
216        quote! {
217            fn save_sync(&self, db: &::goldleaf::mongodb::sync::Database) -> Result<(), ::mongodb::error::Error> {
218                let coll = <::goldleaf::mongodb::sync::Database as ::goldleaf::SyncAutoCollection>::auto_collection::<Self>(db);
219                let res = coll.replace_one(::goldleaf::mongodb::bson::doc! {
220                    #id_field: #id_field_value
221                }, self).run()?;
222
223                debug_assert_eq!(res.matched_count, 1, "unable to find structure with identifying field `{}`", #id_field);
224
225                Ok(())
226            }
227        }
228    } else {
229        quote! {}
230    };
231
232    // Generate collection identity
233    let identity = quote! {
234        #[::goldleaf::async_trait]
235        impl ::goldleaf::CollectionIdentity for #struct_name {
236            const COLLECTION: &'static str = #collection_name;
237
238            async fn save(&self, db: &::goldleaf::mongodb::Database) -> Result<(), ::mongodb::error::Error> {
239                let coll = <::goldleaf::mongodb::Database as ::goldleaf::AutoCollection>::auto_collection::<Self>(db);
240                let res = coll.replace_one(::goldleaf::mongodb::bson::doc! {
241                    #id_field: #id_field_value
242                }, self).await?;
243
244                debug_assert_eq!(res.matched_count, 1, "unable to find structure with identifying field `{}`", #id_field);
245
246                Ok(())
247            }
248
249            #sync_impl
250        }
251    };
252
253    let indexing_fields = fields
254        .into_iter()
255        .filter(|f| f.indexing.is_some())
256        .collect::<Vec<_>>();
257    if indexing_fields.is_empty() {
258        return identity.into();
259    }
260    let indexing_fields = indexing_fields
261        .into_iter()
262        .map(CombinedFieldIdentityData::from)
263        .collect::<Vec<_>>();
264
265    // Collate indices into a more readable struct
266    let mut identities: Vec<CollatedFieldIdentityData> = vec![];
267    for field in indexing_fields {
268        // If this field is linked, try to find other fields with the same link ID and combine them
269        if let Some(link_id) = &field.link {
270            if let Some(id) = identities
271                .iter_mut()
272                .find(|id| id.link.as_ref().is_some_and(|l| l == link_id))
273            {
274                id.pairs.push(generate_index_pair(&field));
275
276                id.pairs.sort_unstable_by_key(|data| data.order_index);
277
278                id.unique = id.unique || field.unique.is_present();
279                if let Some(name) = field.name {
280                    id.name = Some(name);
281                }
282                if let (Some(locale), Some(strength)) = (field.icase_locale, field.icase_strength) {
283                    id.case_insensitivity = Some(CaseInsensitivity { locale, strength })
284                }
285
286                if let Some(two_d) = field.two_d {
287                    id.two_d = Some(two_d);
288                }
289            }
290        } else {
291            // This field is independent, just generate identity data separately
292            identities.push(CollatedFieldIdentityData {
293                pairs: vec![generate_index_pair(&field)],
294                unique: field.unique.is_present(),
295                name: field.name,
296                link: field.link,
297                case_insensitivity: if let (Some(locale), Some(strength)) =
298                    (field.icase_locale, field.icase_strength)
299                {
300                    Some(CaseInsensitivity { locale, strength })
301                } else {
302                    None
303                },
304                two_d: field.two_d,
305                pfe: field.pfe,
306            })
307        }
308    }
309
310    // Generate doc strings
311    let docs = identities
312        .iter()
313        .map(|i| {
314            let pairs = i.pairs.iter().map(|p| {
315                let ident = p.ident.clone();
316                match &p.index {
317                    IndexType::Numeric(val) => quote! {
318                        #ident: #val
319                    },
320                    IndexType::Text { .. } => quote! {
321                        #ident: "text"
322                    },
323                    IndexType::TwoD(two_d) => match two_d {
324                        TwoDPacked::Spherical { .. } => quote! {
325                            #ident: "2dsphere"
326                        },
327                        TwoDPacked::Cartesian => quote! {
328                            #ident: "2d"
329                        },
330                    },
331                }
332            });
333
334            quote! {
335                ::goldleaf::mongodb::bson::doc!{#(#pairs),*}
336            }
337        })
338        .collect::<Vec<_>>();
339
340    // Generate builder strings
341    let opts = identities.iter().map(|i| {
342        let index_name = i.name.clone().unwrap_or("".to_string());
343        let unique = i.unique;
344
345        // Figure out if any index pair has this info
346        // MULTIDIMENSIONAL SEARCHING
347        let use_two_d = i.two_d.is_some_and(|t| match t {
348            TwoDPacked::Spherical { .. } => true,
349            TwoDPacked::Cartesian => false,
350        });
351        let two_d = i.two_d.unwrap_or_default();
352        let (bits, max, min) = match two_d {
353            TwoDPacked::Spherical { bits, max, min } => (bits, max, min),
354            TwoDPacked::Cartesian => (0, 0f64, 0f64),
355        };
356
357        // TEXT WEIGHTS
358        let pairs = i.pairs.iter().filter_map(|p| match p.index {
359            IndexType::Text(weight) => Some((p, weight)),
360            _ => None,
361        }).map(|(text_pair, weight)| {
362            let ident = text_pair.ident.clone();
363            quote! { #ident: #weight }
364        }).collect::<Vec<_>>();
365
366        let has_weights = !pairs.is_empty();
367
368        let weights = quote! {
369            ::goldleaf::mongodb::bson::doc!{#(#pairs),*}
370        };
371
372        // COLLATION
373        let use_collation = i.case_insensitivity.is_some();
374        let collation = match &i.case_insensitivity {
375            None => quote! {
376                ::goldleaf::mongodb::options::Collation::builder().locale("en".to_string()).build()
377            },
378            Some(case_insensitivity) => {
379                let locale = &case_insensitivity.locale;
380                let strength = case_insensitivity.strength;
381                let strength = quote! {
382                    match #strength {
383                        1 => ::goldleaf::mongodb::options::CollationStrength::Primary,
384                        2 => ::goldleaf::mongodb::options::CollationStrength::Secondary,
385                        3 => ::goldleaf::mongodb::options::CollationStrength::Tertiary,
386                        4 => ::goldleaf::mongodb::options::CollationStrength::Quaternary,
387                        5 => ::goldleaf::mongodb::options::CollationStrength::Identical,
388                        _ => panic!("Collation strength out of bounds!")
389                    }
390                };
391                quote! {
392                    ::goldleaf::mongodb::options::Collation::builder().locale(#locale.to_string()).strength(Some(#strength)).build()
393                }
394            },
395        };
396
397        // LANGUAGE OVERRIDES
398        let language = i.pairs.iter().find_map(|p| if p.is_lang_field { Some(p.ident.clone()) } else { None }).unwrap_or("".to_string());
399
400        // EXPIRATION
401        let expiration_secs = collection.expiration_secs.unwrap_or(0);
402
403        // PARTIAL FILTER EXPRESSIONS
404        let has_pfe = i.pfe.is_some();
405        let pfe: proc_macro2::TokenStream = i.pfe.clone().unwrap_or_default().parse().expect("PFE to be parseable");
406        let pfe = quote! {
407            ::goldleaf::mongodb::bson::doc!{#pfe}
408        };
409
410        quote! {
411            ::goldleaf::mongodb::options::IndexOptions::builder()
412                .name(if #index_name.is_empty() {None} else {Some(#index_name.to_string())})
413                .unique(Some(#unique))
414                .expire_after(if #expiration_secs > 0 {Some(::std::time::Duration::from_secs(#expiration_secs))} else {None})
415                .weights(if #has_weights {Some(#weights)} else {None})
416                .bits(if #use_two_d {Some(#bits)} else {None})
417                .max(if #use_two_d {Some(#max)} else {None})
418                .min(if #use_two_d {Some(#min)} else {None})
419                .collation(if #use_collation {Some(#collation)} else {None})
420                .language_override(if #language.is_empty() {None} else {Some(#language.to_string())})
421                .partial_filter_expression(if #has_pfe {Some(#pfe)} else {None})
422                .build()
423        }
424    }).collect::<Vec<_>>();
425
426    // Concatenate strings into function call
427    let calls = docs.iter().zip(opts.iter()).map(|(doc, opt)| quote! {coll.create_index(::goldleaf::mongodb::IndexModel::builder().keys(#doc).options(Some(#opt)).build()).await?;}).collect::<Vec<_>>();
428
429    // Generate quotes
430    let indices = quote! {
431        impl #struct_name {
432            pub async fn create_indices(db: &::goldleaf::mongodb::Database) -> Result<(), ::mongodb::error::Error> {
433                let coll = <::goldleaf::mongodb::Database as ::goldleaf::AutoCollection>::auto_collection::<Self>(db);
434
435                #(#calls)*
436                Ok(())
437            }
438        }
439    };
440
441    // Append tokens
442    let out = quote! {
443        #identity
444
445        #indices
446    };
447
448    out.into()
449}
450
451fn generate_index_pair(field: &CombinedFieldIdentityData) -> IndexPair {
452    IndexPair {
453        ident: match field.sub.as_ref() {
454            None => field.ident.as_ref().expect("Field identifier").to_string(),
455            Some(sub) => format!(
456                "{}.{}",
457                field.ident.as_ref().expect("Field identifier"),
458                sub
459            ),
460        },
461        index: if let Some(text_weight) = field.text_weight {
462            IndexType::Text(text_weight.into())
463        } else if let Some(two_d) = field.two_d.as_ref() {
464            IndexType::TwoD(*two_d)
465        } else {
466            IndexType::Numeric(field.index.unwrap_or(1))
467        },
468        order_index: field.order.unwrap_or(0),
469        is_lang_field: field.lang_field.is_present(),
470    }
471}