1#![deny(clippy::unwrap_used)]
2
3use proc_macro::TokenStream;
4
5use darling::util::Flag;
6use darling::{ast, FromDeriveInput, FromField, FromMeta};
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput};
9
10#[derive(FromMeta, Copy, Clone)]
11enum TwoD {
12 Spherical,
13 Cartesian,
14}
15
16#[derive(FromMeta, Default)]
17struct FieldIdentityMetaData {
18 sub: Option<String>,
20 index: Option<i32>,
22 link: Option<String>,
25 order: Option<u8>,
27 unique: Flag,
28 text_weight: Option<u8>,
30 two_d: Option<TwoD>,
31 icase_locale: Option<String>,
33 icase_strength: Option<u8>,
34 name: Option<String>,
35 two_d_bits: Option<u32>,
36 two_d_max: Option<f64>,
37 two_d_min: Option<f64>,
38 lang_field: Flag,
39 pfe: Option<String>,
40}
41
/// Fallback for the `two_d_bits` option when the attribute omits it.
const DEFAULT_2D_BITS: u32 = 26;
/// Fallback for the `two_d_min` option when the attribute omits it.
const DEFAULT_2D_MIN: f64 = -180.0;
/// Fallback for the `two_d_max` option when the attribute omits it.
const DEFAULT_2D_MAX: f64 = 180.0;
45
46#[derive(FromField)]
47#[darling(attributes(db))]
48struct FieldIdentityData {
49 ident: Option<syn::Ident>,
50 id_field: Flag,
51 native_id_field: Flag,
52 indexing: Option<FieldIdentityMetaData>,
53}
54
/// Two-dimensional index kind bundled with the numeric options that the
/// derive forwards to the index options builder.
///
/// NOTE(review): MongoDB documents `bits` / `min` / `max` as options of the
/// legacy `2d` index type, whereas here they are carried by the spherical
/// variant — confirm this pairing is intended.
#[derive(Default, Copy, Clone)]
enum TwoDPacked {
    /// Emitted as a `"2dsphere"` key; its values are passed to `.bits()`,
    /// `.max()` and `.min()` of the generated index options.
    Spherical {
        /// Value forwarded to the `bits` option.
        bits: u32,
        /// Value forwarded to the `max` option.
        max: f64,
        /// Value forwarded to the `min` option.
        min: f64,
    },
    /// Emitted as a `"2d"` key; carries no extra options.
    #[default]
    Cartesian,
}
65
66struct CombinedFieldIdentityData {
67 ident: Option<syn::Ident>,
68 sub: Option<String>,
70 index: Option<i32>,
72 link: Option<String>,
75 order: Option<u8>,
77 unique: Flag,
78 text_weight: Option<u8>,
80 two_d: Option<TwoDPacked>,
81 icase_locale: Option<String>,
82 icase_strength: Option<u8>,
83 name: Option<String>,
84 lang_field: Flag,
85 pfe: Option<String>,
87}
88
89impl From<FieldIdentityData> for CombinedFieldIdentityData {
90 fn from(value: FieldIdentityData) -> Self {
91 let meta = value.indexing.expect("Indexing metadata");
92 CombinedFieldIdentityData {
93 ident: value.ident,
94 sub: meta.sub,
95 index: meta.index,
96 link: meta.link,
97 order: meta.order,
98 unique: meta.unique,
99 text_weight: meta.text_weight,
100 two_d: match meta.two_d {
101 None => None,
102 Some(two_d) => Some(match two_d {
103 TwoD::Spherical => TwoDPacked::Spherical {
104 bits: meta.two_d_bits.unwrap_or(DEFAULT_2D_BITS),
105 max: meta.two_d_max.unwrap_or(DEFAULT_2D_MAX),
106 min: meta.two_d_min.unwrap_or(DEFAULT_2D_MIN),
107 },
108 TwoD::Cartesian => TwoDPacked::Cartesian,
109 }),
110 },
111 icase_locale: meta.icase_locale,
112 icase_strength: meta.icase_strength,
113 name: meta.name,
114 lang_field: meta.lang_field,
115 pfe: meta.pfe,
116 }
117 }
118}
119
120enum IndexType {
121 Numeric(i32),
122 Text(u32),
123 TwoD(TwoDPacked),
124}
125
126struct IndexPair {
127 ident: String,
128 index: IndexType,
129 order_index: u8,
130 is_lang_field: bool,
131}
132
/// Collation settings attached to an index when both a locale and a strength
/// were supplied on one of its fields.
#[derive(Default)]
struct CaseInsensitivity {
    /// Locale passed to the generated collation builder.
    locale: String,
    /// Numeric strength; the generated code maps 1 through 5 onto collation
    /// strength variants and panics at runtime for any other value.
    strength: u8,
}
138
139struct CollatedFieldIdentityData {
140 pairs: Vec<IndexPair>,
141 unique: bool,
142 name: Option<String>,
143 link: Option<String>,
144 case_insensitivity: Option<CaseInsensitivity>,
145 two_d: Option<TwoDPacked>,
146 pfe: Option<String>,
148}
149
150#[derive(FromDeriveInput)]
151#[darling(attributes(db), supports(struct_named), forward_attrs(allow, doc, cfg))]
152struct CollectionIdentityData {
153 ident: syn::Ident,
154 name: String,
155 expiration_secs: Option<u64>,
157 data: ast::Data<(), FieldIdentityData>,
158}
159
160#[proc_macro_derive(CollectionIdentity, attributes(db))]
161pub fn collection_identity(input: TokenStream) -> TokenStream {
162 let input = parse_macro_input!(input as DeriveInput);
163 let collection = match CollectionIdentityData::from_derive_input(&input) {
164 Ok(parsed) => parsed,
165 Err(e) => return e.write_errors().into(),
166 };
167
168 let collection_name = collection.name;
169 let struct_name = collection.ident;
170
171 let mut fields = collection
173 .data
174 .take_struct()
175 .expect("Must be struct")
176 .fields;
177
178 let mut id_field = None;
179 let mut native_id = false;
180
181 for field in &mut fields {
182 if field.id_field.is_present() || field.native_id_field.is_present() {
183 if id_field.is_some()
184 || (field.id_field.is_present() && field.native_id_field.is_present())
185 {
186 panic!("Multiple ID fields not allowed!");
187 }
188
189 id_field = Some(
190 field
191 .ident
192 .as_ref()
193 .expect("ID field identifier")
194 .to_string(),
195 );
196 native_id = field.native_id_field.is_present();
197
198 if !native_id {
199 field.indexing.get_or_insert_default().unique = Flag::present();
200 }
201 }
202 }
203
204 let id_field = id_field.expect("ID field must be present!");
205 let id_field_tok: syn::Ident = syn::parse_str(&id_field).expect("Valid parse of ID field");
206 let (id_field, id_field_value) = if native_id {
207 (
208 format!("_{id_field}"),
209 quote!(self.#id_field_tok.as_ref().unwrap()),
210 )
211 } else {
212 (id_field, quote!(&self.#id_field_tok))
213 };
214
215 let sync_impl = if cfg!(feature = "sync") {
216 quote! {
217 fn save_sync(&self, db: &::goldleaf::mongodb::sync::Database) -> Result<(), ::mongodb::error::Error> {
218 let coll = <::goldleaf::mongodb::sync::Database as ::goldleaf::SyncAutoCollection>::auto_collection::<Self>(db);
219 let res = coll.replace_one(::goldleaf::mongodb::bson::doc! {
220 #id_field: #id_field_value
221 }, self).run()?;
222
223 debug_assert_eq!(res.matched_count, 1, "unable to find structure with identifying field `{}`", #id_field);
224
225 Ok(())
226 }
227 }
228 } else {
229 quote! {}
230 };
231
232 let identity = quote! {
234 #[::goldleaf::async_trait]
235 impl ::goldleaf::CollectionIdentity for #struct_name {
236 const COLLECTION: &'static str = #collection_name;
237
238 async fn save(&self, db: &::goldleaf::mongodb::Database) -> Result<(), ::mongodb::error::Error> {
239 let coll = <::goldleaf::mongodb::Database as ::goldleaf::AutoCollection>::auto_collection::<Self>(db);
240 let res = coll.replace_one(::goldleaf::mongodb::bson::doc! {
241 #id_field: #id_field_value
242 }, self).await?;
243
244 debug_assert_eq!(res.matched_count, 1, "unable to find structure with identifying field `{}`", #id_field);
245
246 Ok(())
247 }
248
249 #sync_impl
250 }
251 };
252
253 let indexing_fields = fields
254 .into_iter()
255 .filter(|f| f.indexing.is_some())
256 .collect::<Vec<_>>();
257 if indexing_fields.is_empty() {
258 return identity.into();
259 }
260 let indexing_fields = indexing_fields
261 .into_iter()
262 .map(CombinedFieldIdentityData::from)
263 .collect::<Vec<_>>();
264
265 let mut identities: Vec<CollatedFieldIdentityData> = vec![];
267 for field in indexing_fields {
268 if let Some(link_id) = &field.link {
270 if let Some(id) = identities
271 .iter_mut()
272 .find(|id| id.link.as_ref().is_some_and(|l| l == link_id))
273 {
274 id.pairs.push(generate_index_pair(&field));
275
276 id.pairs.sort_unstable_by_key(|data| data.order_index);
277
278 id.unique = id.unique || field.unique.is_present();
279 if let Some(name) = field.name {
280 id.name = Some(name);
281 }
282 if let (Some(locale), Some(strength)) = (field.icase_locale, field.icase_strength) {
283 id.case_insensitivity = Some(CaseInsensitivity { locale, strength })
284 }
285
286 if let Some(two_d) = field.two_d {
287 id.two_d = Some(two_d);
288 }
289 }
290 } else {
291 identities.push(CollatedFieldIdentityData {
293 pairs: vec![generate_index_pair(&field)],
294 unique: field.unique.is_present(),
295 name: field.name,
296 link: field.link,
297 case_insensitivity: if let (Some(locale), Some(strength)) =
298 (field.icase_locale, field.icase_strength)
299 {
300 Some(CaseInsensitivity { locale, strength })
301 } else {
302 None
303 },
304 two_d: field.two_d,
305 pfe: field.pfe,
306 })
307 }
308 }
309
310 let docs = identities
312 .iter()
313 .map(|i| {
314 let pairs = i.pairs.iter().map(|p| {
315 let ident = p.ident.clone();
316 match &p.index {
317 IndexType::Numeric(val) => quote! {
318 #ident: #val
319 },
320 IndexType::Text { .. } => quote! {
321 #ident: "text"
322 },
323 IndexType::TwoD(two_d) => match two_d {
324 TwoDPacked::Spherical { .. } => quote! {
325 #ident: "2dsphere"
326 },
327 TwoDPacked::Cartesian => quote! {
328 #ident: "2d"
329 },
330 },
331 }
332 });
333
334 quote! {
335 ::goldleaf::mongodb::bson::doc!{#(#pairs),*}
336 }
337 })
338 .collect::<Vec<_>>();
339
340 let opts = identities.iter().map(|i| {
342 let index_name = i.name.clone().unwrap_or("".to_string());
343 let unique = i.unique;
344
345 let use_two_d = i.two_d.is_some_and(|t| match t {
348 TwoDPacked::Spherical { .. } => true,
349 TwoDPacked::Cartesian => false,
350 });
351 let two_d = i.two_d.unwrap_or_default();
352 let (bits, max, min) = match two_d {
353 TwoDPacked::Spherical { bits, max, min } => (bits, max, min),
354 TwoDPacked::Cartesian => (0, 0f64, 0f64),
355 };
356
357 let pairs = i.pairs.iter().filter_map(|p| match p.index {
359 IndexType::Text(weight) => Some((p, weight)),
360 _ => None,
361 }).map(|(text_pair, weight)| {
362 let ident = text_pair.ident.clone();
363 quote! { #ident: #weight }
364 }).collect::<Vec<_>>();
365
366 let has_weights = !pairs.is_empty();
367
368 let weights = quote! {
369 ::goldleaf::mongodb::bson::doc!{#(#pairs),*}
370 };
371
372 let use_collation = i.case_insensitivity.is_some();
374 let collation = match &i.case_insensitivity {
375 None => quote! {
376 ::goldleaf::mongodb::options::Collation::builder().locale("en".to_string()).build()
377 },
378 Some(case_insensitivity) => {
379 let locale = &case_insensitivity.locale;
380 let strength = case_insensitivity.strength;
381 let strength = quote! {
382 match #strength {
383 1 => ::goldleaf::mongodb::options::CollationStrength::Primary,
384 2 => ::goldleaf::mongodb::options::CollationStrength::Secondary,
385 3 => ::goldleaf::mongodb::options::CollationStrength::Tertiary,
386 4 => ::goldleaf::mongodb::options::CollationStrength::Quaternary,
387 5 => ::goldleaf::mongodb::options::CollationStrength::Identical,
388 _ => panic!("Collation strength out of bounds!")
389 }
390 };
391 quote! {
392 ::goldleaf::mongodb::options::Collation::builder().locale(#locale.to_string()).strength(Some(#strength)).build()
393 }
394 },
395 };
396
397 let language = i.pairs.iter().find_map(|p| if p.is_lang_field { Some(p.ident.clone()) } else { None }).unwrap_or("".to_string());
399
400 let expiration_secs = collection.expiration_secs.unwrap_or(0);
402
403 let has_pfe = i.pfe.is_some();
405 let pfe: proc_macro2::TokenStream = i.pfe.clone().unwrap_or_default().parse().expect("PFE to be parseable");
406 let pfe = quote! {
407 ::goldleaf::mongodb::bson::doc!{#pfe}
408 };
409
410 quote! {
411 ::goldleaf::mongodb::options::IndexOptions::builder()
412 .name(if #index_name.is_empty() {None} else {Some(#index_name.to_string())})
413 .unique(Some(#unique))
414 .expire_after(if #expiration_secs > 0 {Some(::std::time::Duration::from_secs(#expiration_secs))} else {None})
415 .weights(if #has_weights {Some(#weights)} else {None})
416 .bits(if #use_two_d {Some(#bits)} else {None})
417 .max(if #use_two_d {Some(#max)} else {None})
418 .min(if #use_two_d {Some(#min)} else {None})
419 .collation(if #use_collation {Some(#collation)} else {None})
420 .language_override(if #language.is_empty() {None} else {Some(#language.to_string())})
421 .partial_filter_expression(if #has_pfe {Some(#pfe)} else {None})
422 .build()
423 }
424 }).collect::<Vec<_>>();
425
426 let calls = docs.iter().zip(opts.iter()).map(|(doc, opt)| quote! {coll.create_index(::goldleaf::mongodb::IndexModel::builder().keys(#doc).options(Some(#opt)).build()).await?;}).collect::<Vec<_>>();
428
429 let indices = quote! {
431 impl #struct_name {
432 pub async fn create_indices(db: &::goldleaf::mongodb::Database) -> Result<(), ::mongodb::error::Error> {
433 let coll = <::goldleaf::mongodb::Database as ::goldleaf::AutoCollection>::auto_collection::<Self>(db);
434
435 #(#calls)*
436 Ok(())
437 }
438 }
439 };
440
441 let out = quote! {
443 #identity
444
445 #indices
446 };
447
448 out.into()
449}
450
451fn generate_index_pair(field: &CombinedFieldIdentityData) -> IndexPair {
452 IndexPair {
453 ident: match field.sub.as_ref() {
454 None => field.ident.as_ref().expect("Field identifier").to_string(),
455 Some(sub) => format!(
456 "{}.{}",
457 field.ident.as_ref().expect("Field identifier"),
458 sub
459 ),
460 },
461 index: if let Some(text_weight) = field.text_weight {
462 IndexType::Text(text_weight.into())
463 } else if let Some(two_d) = field.two_d.as_ref() {
464 IndexType::TwoD(*two_d)
465 } else {
466 IndexType::Numeric(field.index.unwrap_or(1))
467 },
468 order_index: field.order.unwrap_or(0),
469 is_lang_field: field.lang_field.is_present(),
470 }
471}