1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Field, Fields};
4
5#[proc_macro_derive(DnfEvaluable, attributes(dnf))]
43pub fn derive_dnf_evaluable(input: TokenStream) -> TokenStream {
44 let input = parse_macro_input!(input as DeriveInput);
45
46 let name = &input.ident;
47
48 let fields = match &input.data {
50 Data::Struct(data) => match &data.fields {
51 Fields::Named(fields) => &fields.named,
52 _ => {
53 return syn::Error::new_spanned(
54 &input,
55 "DnfEvaluable can only be derived for structs with named fields",
56 )
57 .to_compile_error()
58 .into();
59 }
60 },
61 _ => {
62 return syn::Error::new_spanned(&input, "DnfEvaluable can only be derived for structs")
63 .to_compile_error()
64 .into();
65 }
66 };
67
68 let match_arms = fields.iter().filter_map(generate_field_match_arm);
70
71 let nested_match_arms = fields.iter().filter_map(generate_nested_field_match_arm);
73
74 let field_infos = fields.iter().filter_map(generate_field_info);
76
77 let field_value_arms = fields.iter().filter_map(generate_field_value_arm);
79
80 let expanded = quote! {
81 impl dnf::DnfEvaluable for #name {
82 fn evaluate_field(
83 &self,
84 field_name: &str,
85 operator: &dnf::Op,
86 value: &dnf::Value
87 ) -> bool {
88 match field_name {
90 #(#match_arms)*
91 _ => {
92 if let Some(dot_pos) = field_name.find('.') {
94 let (outer, inner) = field_name.split_at(dot_pos);
95 let inner = &inner[1..]; match outer {
97 #(#nested_match_arms)*
98 _ => false,
99 }
100 } else {
101 false }
103 }
104 }
105 }
106
107 fn get_field_value(&self, field_name: &str) -> Option<dnf::Value> {
108 match field_name {
109 #(#field_value_arms)*
110 _ => None,
111 }
112 }
113
114 fn fields() -> impl Iterator<Item = dnf::FieldInfo> {
115 [
116 #(#field_infos),*
117 ].into_iter()
118 }
119 }
120 };
121
122 TokenStream::from(expanded)
123}
124
125fn generate_field_match_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
127 let field_name = field.ident.as_ref()?;
128 let field_type = &field.ty;
129
130 if has_skip_attribute(field) {
132 return None;
133 }
134
135 let type_str = quote!(#field_type).to_string().replace(" ", "");
137
138 let has_iter = get_iter_attribute(field).is_some();
142 if has_nested_attribute(field) || (!has_iter && is_nested_type(&type_str)) {
143 return None;
144 }
145
146 let query_name = get_rename_attribute(field).unwrap_or_else(|| field_name.to_string());
148
149 let value_conversion = generate_value_conversion(field, field_name, field_type);
151
152 Some(quote! {
153 #query_name => #value_conversion,
154 })
155}
156
157fn generate_field_value_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
161 let field_name = field.ident.as_ref()?;
162 let field_type = &field.ty;
163
164 if has_skip_attribute(field) {
166 return None;
167 }
168
169 let type_str = quote!(#field_type).to_string().replace(" ", "");
171
172 let has_iter = get_iter_attribute(field).is_some();
174 if has_nested_attribute(field) || (!has_iter && is_nested_type(&type_str)) {
175 return None;
176 }
177
178 if !is_value_convertible(&type_str) {
181 return None;
182 }
183
184 let query_name = get_rename_attribute(field).unwrap_or_else(|| field_name.to_string());
186
187 let value_conversion = if type_str.starts_with("Option<") {
189 quote! {
191 match &self.#field_name {
192 Some(v) => Some(dnf::Value::from(v)),
193 None => Some(dnf::Value::None),
194 }
195 }
196 } else {
197 quote! {
199 Some(dnf::Value::from(&self.#field_name))
200 }
201 };
202
203 Some(quote! {
204 #query_name => #value_conversion,
205 })
206}
207
/// Reports whether a field of the given type can be turned into a `dnf::Value`.
///
/// `type_str` is the textual form of the type with spaces removed (callers pass
/// `quote!(#ty).to_string().replace(" ", "")`). Accepted shapes:
///
/// * the listed numeric primitives, `bool` and `String`;
/// * references to `str` (`&str`, `&'a str`, `&mut str`) and `Cow<'_, str>`;
/// * `Vec<P>` where `P` is one of the listed primitives;
/// * `HashSet<P>` for the same primitives except `f32`/`f64`;
/// * `Option<T>` where `T` is itself accepted.
///
/// String-like types are matched structurally rather than by "contains `str`",
/// so names such as `&Registry` or `Cow<'a, Instruction>` are not mistaken for
/// string slices.
fn is_value_convertible(type_str: &str) -> bool {
    const PRIMITIVES: [&str; 14] = [
        "i8", "i16", "i32", "i64", "isize", "u8", "u16", "u32", "u64", "usize", "f32", "f64",
        "bool", "String",
    ];

    /// True for `&str`, `&mut str` and `&'lt str` (with or without spaces).
    fn is_str_reference(ty: &str) -> bool {
        let head = match ty.strip_prefix('&').and_then(|rest| rest.strip_suffix("str")) {
            Some(head) => head.trim(),
            None => return false,
        };
        match head.strip_prefix('\'') {
            // With spaces removed the lifetime (and an optional `mut`) is fused
            // to `str`, so all that can be checked is that it is identifier-like.
            Some(lifetime) => {
                !lifetime.is_empty() && lifetime.chars().all(|c| c.is_alphanumeric() || c == '_')
            }
            None => head.is_empty() || head == "mut",
        }
    }

    /// True for `Cow<..., str>`: the last generic argument must be exactly `str`.
    fn is_str_cow(ty: &str) -> bool {
        ty.strip_prefix("Cow<")
            .and_then(|rest| rest.strip_suffix('>'))
            .map_or(false, |args| {
                args.rsplit(',').next().map(str::trim) == Some("str")
            })
    }

    if PRIMITIVES.contains(&type_str) || is_str_reference(type_str) || is_str_cow(type_str) {
        return true;
    }

    if let Some(element) = type_str
        .strip_prefix("Vec<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        return PRIMITIVES.contains(&element);
    }

    if let Some(element) = type_str
        .strip_prefix("HashSet<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        // Floats are excluded from the accepted set element types.
        return PRIMITIVES.contains(&element) && element != "f32" && element != "f64";
    }

    if let Some(wrapped) = type_str
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        return is_value_convertible(wrapped);
    }

    false
}
264
265fn is_nested_type(type_str: &str) -> bool {
269 if type_str.starts_with("Vec<") {
271 if let Some(inner) = type_str
272 .strip_prefix("Vec<")
273 .and_then(|s| s.strip_suffix(">"))
274 {
275 return !is_primitive_or_builtin(inner);
276 }
277 }
278
279 if type_str.starts_with("Option<Vec<") {
281 if let Some(inner) = type_str
282 .strip_prefix("Option<Vec<")
283 .and_then(|s| s.strip_suffix(">>"))
284 {
285 return !is_primitive_or_builtin(inner);
286 }
287 }
288
289 if is_map_type(type_str) {
291 if let Some((_, value_type)) = extract_map_types(type_str) {
292 return !is_primitive_or_builtin(&value_type);
293 }
294 }
295
296 if type_str.starts_with("Option<HashMap<") || type_str.starts_with("Option<BTreeMap<") {
298 if let Some(inner) = type_str
299 .strip_prefix("Option<")
300 .and_then(|s| s.strip_suffix(">"))
301 {
302 if let Some((_, value_type)) = extract_map_types(inner) {
303 return !is_primitive_or_builtin(&value_type);
304 }
305 }
306 }
307
308 false
312}
313
/// Returns `true` when the type string begins with `HashMap<` or `BTreeMap<`.
fn is_map_type(type_str: &str) -> bool {
    ["HashMap<", "BTreeMap<"]
        .iter()
        .any(|&prefix| type_str.starts_with(prefix))
}
318
/// Splits `HashMap<K,V>` / `BTreeMap<K,V>` into its key and value type strings.
///
/// Returns `None` when `type_str` does not start with one of the two map
/// prefixes, does not end with `>`, or has no top-level comma between the key
/// and value types. Both returned strings are trimmed.
///
/// The split happens at the first comma that is not enclosed in any `<...>`,
/// `(...)` or `[...]`, so generic, tuple and array/slice key types stay intact
/// (e.g. `HashMap<(u32,u32),String>` yields `("(u32,u32)", "String")`).
fn extract_map_types(type_str: &str) -> Option<(String, String)> {
    let inner = type_str
        .strip_prefix("HashMap<")
        .or_else(|| type_str.strip_prefix("BTreeMap<"))?
        .strip_suffix('>')?;

    let mut depth: i32 = 0;
    let mut previous = '\0';
    let mut comma_pos = None;
    for (i, c) in inner.char_indices() {
        match c {
            '<' | '(' | '[' => depth += 1,
            // The `>` of a `->` return arrow is not a closing bracket.
            '>' if previous == '-' => {}
            '>' | ')' | ']' => depth -= 1,
            ',' if depth == 0 => {
                comma_pos = Some(i);
                break;
            }
            _ => {}
        }
        previous = c;
    }

    let pos = comma_pos?;
    let key = inner[..pos].trim().to_string();
    let value = inner[pos + 1..].trim().to_string();
    Some((key, value))
}
347
/// Reports whether a map key type is string-like.
///
/// After trimming, accepts exactly `String`, `str` and `&str`, plus any type
/// that starts with `&'` (a reference with a lifetime) and ends with `str`.
fn is_string_key(key_type: &str) -> bool {
    let trimmed = key_type.trim();
    match trimmed {
        "String" | "str" | "&str" => true,
        // Covers `&'a str` both with and without the separating space.
        _ => trimmed.starts_with("&'") && trimmed.ends_with("str"),
    }
}
356
357fn has_skip_attribute(field: &Field) -> bool {
359 for attr in &field.attrs {
360 if attr.path().is_ident("dnf") {
361 let mut has_skip = false;
362 let _ = attr.parse_nested_meta(|meta| {
363 if meta.path.is_ident("skip") {
364 has_skip = true;
365 }
366 Ok(())
367 });
368 if has_skip {
369 return true;
370 }
371 }
372 }
373 false
374}
375
376fn has_nested_attribute(field: &Field) -> bool {
378 for attr in &field.attrs {
379 if attr.path().is_ident("dnf") {
380 let mut has_nested = false;
381 let _ = attr.parse_nested_meta(|meta| {
382 if meta.path.is_ident("nested") {
383 has_nested = true;
384 }
385 Ok(())
386 });
387 if has_nested {
388 return true;
389 }
390 }
391 }
392 false
393}
394
395fn generate_nested_field_match_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
403 let field_name = field.ident.as_ref()?;
404 let field_type = &field.ty;
405
406 if has_skip_attribute(field) {
408 return None;
409 }
410
411 let type_str = quote!(#field_type).to_string().replace(" ", "");
413
414 let has_iter = get_iter_attribute(field).is_some();
416 if has_iter {
417 return None;
418 }
419
420 if !has_nested_attribute(field) && !is_nested_type(&type_str) {
422 return None;
423 }
424
425 let query_name = get_rename_attribute(field).unwrap_or_else(|| field_name.to_string());
427
428 let delegation_code = if type_str.starts_with("Vec<") {
429 quote! {
431 self.#field_name.iter().any(|item| item.evaluate_field(inner, operator, value))
432 }
433 } else if type_str.starts_with("Option<Vec<") {
434 quote! {
436 match &self.#field_name {
437 Some(vec) => vec.iter().any(|item| item.evaluate_field(inner, operator, value)),
438 None => false,
439 }
440 }
441 } else if type_str.starts_with("HashMap<") || type_str.starts_with("BTreeMap<") {
442 quote! {
445 if let Some(rest) = inner.strip_prefix("@values.") {
446 self.#field_name.values().any(|item| item.evaluate_field(rest, operator, value))
448 } else if inner == "@keys" {
449 operator.any(self.#field_name.keys(), value)
451 } else if inner.starts_with("[\"") {
452 if let Some(end_bracket) = inner.find("\"]") {
454 let key = &inner[2..end_bracket];
455 let rest = inner.get(end_bracket + 2..).unwrap_or("").trim_start_matches('.');
456 if rest.is_empty() {
457 false
459 } else {
460 match self.#field_name.get(key) {
461 Some(item) => item.evaluate_field(rest, operator, value),
462 None => false,
463 }
464 }
465 } else {
466 false
467 }
468 } else {
469 false
471 }
472 }
473 } else if type_str.starts_with("Option<HashMap<") || type_str.starts_with("Option<BTreeMap<") {
474 quote! {
476 match &self.#field_name {
477 Some(map) => {
478 if let Some(rest) = inner.strip_prefix("@values.") {
479 map.values().any(|item| item.evaluate_field(rest, operator, value))
480 } else if inner == "@keys" {
481 operator.any(map.keys(), value)
482 } else if inner.starts_with("[\"") {
483 if let Some(end_bracket) = inner.find("\"]") {
484 let key = &inner[2..end_bracket];
485 let rest = inner.get(end_bracket + 2..).unwrap_or("").trim_start_matches('.');
486 if rest.is_empty() {
487 false
488 } else {
489 match map.get(key) {
490 Some(item) => item.evaluate_field(rest, operator, value),
491 None => false,
492 }
493 }
494 } else {
495 false
496 }
497 } else {
498 false
500 }
501 },
502 None => false,
503 }
504 }
505 } else if type_str.starts_with("Option<") {
506 quote! {
508 match &self.#field_name {
509 Some(inner_val) => inner_val.evaluate_field(inner, operator, value),
510 None => false,
511 }
512 }
513 } else {
514 quote! {
516 self.#field_name.evaluate_field(inner, operator, value)
517 }
518 };
519
520 Some(quote! {
521 #query_name => #delegation_code,
522 })
523}
524
525fn get_rename_attribute(field: &Field) -> Option<String> {
527 for attr in &field.attrs {
528 if attr.path().is_ident("dnf") {
529 let mut rename_value = None;
530 let _ = attr.parse_nested_meta(|meta| {
531 if meta.path.is_ident("rename") {
532 if let Ok(value) = meta.value() {
533 if let Ok(lit_str) = value.parse::<syn::LitStr>() {
534 rename_value = Some(lit_str.value());
535 }
536 }
537 }
538 Ok(())
539 });
540 if let Some(name) = rename_value {
541 return Some(name);
542 }
543 }
544 }
545 None
546}
547
548fn get_iter_attribute(field: &Field) -> Option<Option<String>> {
554 for attr in &field.attrs {
555 if attr.path().is_ident("dnf") {
556 let mut has_iter = false;
557 let mut iter_method = None;
558 let _ = attr.parse_nested_meta(|meta| {
559 if meta.path.is_ident("iter") {
560 has_iter = true;
561 if let Ok(value) = meta.value() {
563 if let Ok(lit_str) = value.parse::<syn::LitStr>() {
564 iter_method = Some(lit_str.value());
565 }
566 }
567 }
568 Ok(())
569 });
570 if has_iter {
571 return Some(iter_method);
572 }
573 }
574 }
575 None
576}
577
578fn generate_value_conversion(
585 field: &Field,
586 field_name: &syn::Ident,
587 _field_type: &syn::Type,
588) -> proc_macro2::TokenStream {
589 if let Some(iter_method) = get_iter_attribute(field) {
591 let method = iter_method.unwrap_or_else(|| "iter".to_string());
592 let method_ident = syn::Ident::new(&method, field_name.span());
593 return quote! {
594 operator.any(self.#field_name.#method_ident(), value)
595 };
596 }
597
598 quote! {
600 dnf::DnfField::evaluate(&self.#field_name, operator, value)
601 }
602}
603
604fn is_primitive_or_builtin(type_str: &str) -> bool {
612 let primitives = [
614 "i8", "i16", "i32", "i64", "isize", "u8", "u16", "u32", "u64", "usize", "f32", "f64",
615 "bool", "String",
616 ];
617
618 if primitives.contains(&type_str) {
619 return true;
620 }
621
622 if type_str.starts_with("&") && type_str.contains("str") {
624 return true;
625 }
626
627 if type_str.starts_with("Cow<") && type_str.contains("str") {
629 return true;
630 }
631
632 if type_str.starts_with("Vec<") {
634 if let Some(inner) = type_str.strip_prefix("Vec<") {
635 if let Some(inner) = inner.strip_suffix(">") {
636 return is_primitive_or_builtin(inner);
638 }
639 }
640 }
641
642 if type_str.starts_with("HashSet<") {
645 if let Some(inner) = type_str.strip_prefix("HashSet<") {
646 if let Some(inner) = inner.strip_suffix(">") {
647 if inner == "f32" || inner == "f64" {
649 return false;
650 }
651 return is_primitive_or_builtin(inner);
653 }
654 }
655 }
656
657 if is_map_type(type_str) {
659 if let Some((key_type, value_type)) = extract_map_types(type_str) {
660 return is_string_key(&key_type) && is_primitive_or_builtin(&value_type);
661 }
662 }
663
664 false
665}
666
667fn generate_field_info(field: &Field) -> Option<proc_macro2::TokenStream> {
669 let field_name = field.ident.as_ref()?;
670 let field_type = &field.ty;
671
672 if has_skip_attribute(field) {
674 return None;
675 }
676
677 let query_name = get_rename_attribute(field).unwrap_or_else(|| field_name.to_string());
679
680 let type_str = quote!(#field_type).to_string();
682 let type_str_normalized = type_str.replace(" ", "");
683
684 let field_kind = if get_iter_attribute(field).is_some() {
686 quote! { dnf::FieldKind::Iter }
688 } else if is_map_type(&type_str_normalized) {
689 quote! { dnf::FieldKind::Map }
690 } else if type_str_normalized.starts_with("Vec<")
691 || type_str_normalized.starts_with("HashSet<")
692 || type_str_normalized.starts_with("BTreeSet<")
693 {
694 quote! { dnf::FieldKind::Iter }
695 } else {
696 quote! { dnf::FieldKind::Scalar }
697 };
698
699 Some(quote! {
700 dnf::FieldInfo::with_kind(#query_name, #type_str, #field_kind)
701 })
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `struct_src` as a struct with named fields and returns the
    /// stringified output of `generate_value_conversion` for its first field.
    fn first_field_conversion(struct_src: &str) -> String {
        let tokens: proc_macro2::TokenStream = struct_src.parse().unwrap();
        let parsed: DeriveInput = syn::parse2(tokens).unwrap();

        let named = match &parsed.data {
            Data::Struct(data) => match &data.fields {
                Fields::Named(fields) => &fields.named,
                _ => panic!("Expected named fields"),
            },
            _ => panic!("Expected struct"),
        };

        let field = named.first().unwrap();
        generate_value_conversion(field, field.ident.as_ref().unwrap(), &field.ty).to_string()
    }

    /// Conversion generated for `struct User { field: <type_str> }`.
    fn conversion_for_type(type_str: &str) -> String {
        first_field_conversion(&format!("struct User {{ field: {} }}", type_str))
    }

    #[test]
    fn test_primitives_use_dnf_field() {
        for type_str in ["String", "u32", "i64", "f64", "bool"] {
            let conversion_str = conversion_for_type(type_str);

            assert!(
                conversion_str.contains("DnfField :: evaluate"),
                "Type {} should use DnfField::evaluate(), got: {}",
                type_str,
                conversion_str
            );
        }
    }

    #[test]
    fn test_collections_use_dnf_field() {
        for type_str in ["Vec<String>", "HashSet<i32>"] {
            let conversion_str = conversion_for_type(type_str);

            assert!(
                conversion_str.contains("DnfField :: evaluate"),
                "Collection {} should use DnfField::evaluate(), got: {}",
                type_str,
                conversion_str
            );
        }
    }

    #[test]
    fn test_custom_types_use_dnf_field() {
        for type_str in ["Score", "CustomEnum", "MyStruct"] {
            let conversion_str = conversion_for_type(type_str);

            assert!(
                conversion_str.contains("DnfField :: evaluate"),
                "Custom type {} should use DnfField::evaluate(), got: {}",
                type_str,
                conversion_str
            );
        }
    }

    #[test]
    fn test_iter_attribute_generates_any() {
        let conversion_str =
            first_field_conversion("struct User { #[dnf(iter)] field: LinkedList<String> }");

        assert!(
            conversion_str.contains("any") && conversion_str.contains(". iter ()"),
            "Expected any with .iter(), got: {}",
            conversion_str
        );
    }

    #[test]
    fn test_iter_attribute_with_custom_method() {
        let conversion_str = first_field_conversion(
            "struct User { #[dnf(iter = \"items\")] field: CustomList<i32> }",
        );

        assert!(
            conversion_str.contains("any") && conversion_str.contains(". items ()"),
            "Expected any with .items(), got: {}",
            conversion_str
        );
    }
}